1 | |
|
2 | |
|
3 | |
|
4 | |
|
5 | |
|
6 | |
|
7 | |
|
8 | |
|
9 | |
|
10 | |
|
11 | |
|
12 | |
|
13 | |
|
14 | |
|
15 | |
|
16 | |
|
17 | |
|
18 | |
|
19 | |
package org.boretti.drools.integration.drools5.implementation; |
20 | |
|
21 | |
import java.lang.reflect.Array; |
22 | |
import java.util.ArrayList; |
23 | |
import java.util.HashMap; |
24 | |
import java.util.Iterator; |
25 | |
import java.util.List; |
26 | |
import java.util.Map; |
27 | |
import java.util.Map.Entry; |
28 | |
|
29 | |
import org.apache.log4j.Logger; |
30 | |
import org.boretti.drools.integration.drools5.DroolsInterface; |
31 | |
import org.drools.RuleBase; |
32 | |
import org.drools.StatefulSession; |
33 | |
import org.drools.StatelessSession; |
34 | |
|
35 | |
|
36 | |
|
37 | |
|
38 | |
|
39 | |
|
40 | |
abstract class DroolsInterfaceImplementation<T> implements DroolsInterface{ |
41 | |
|
42 | 114 | private final Logger logger = Logger.getLogger(this.getClass()); |
43 | |
|
44 | |
private Class<T> clazz; |
45 | |
|
46 | |
private RuleBase ruleBase; |
47 | |
|
48 | |
private StatefulSession session; |
49 | |
|
50 | |
private static final String FIELD_NAME_LOGGER = "droolsLogger"; |
51 | |
|
52 | 114 | private Map<String,DroolsVariableImplementation> fields = new HashMap<String,DroolsVariableImplementation>(); |
53 | |
|
54 | |
@Override |
55 | |
public RuleBase getRuleBase() { |
56 | 92 | return ruleBase; |
57 | |
} |
58 | |
|
59 | 114 | protected DroolsInterfaceImplementation(Class<T> clazz,RuleBase ruleBase) { |
60 | 114 | this.clazz=clazz; |
61 | 114 | this.ruleBase=ruleBase; |
62 | 114 | fields.put(FIELD_NAME_LOGGER,new DroolsVariableImplementation(true,Logger.getLogger(clazz))); |
63 | 114 | } |
64 | |
|
65 | |
public Class<T> getClazz() { |
66 | 0 | return clazz; |
67 | |
} |
68 | |
|
69 | |
@Override |
70 | |
public StatefulSession getCurrentSession() { |
71 | 42 | return session; |
72 | |
} |
73 | |
|
74 | |
void setCurrentSession(StatefulSession session) { |
75 | 10 | this.session=session; |
76 | 10 | } |
77 | |
|
78 | |
Iterator<?> runStatelessSession(List<Object> fact) { |
79 | 86 | StatelessSession dss = getRuleBase().newStatelessSession(); |
80 | 86 | for(Entry<String,DroolsVariableImplementation> es:fields.entrySet()) { |
81 | 98 | String field = es.getKey(); |
82 | 98 | DroolsVariableImplementation data = es.getValue(); |
83 | 98 | if (data.isGlobal()) dss.setGlobal(field, data.getData()); |
84 | 8 | else fact.add(data.getData()); |
85 | 98 | } |
86 | 86 | new DroolsWorkingMemoryLogger(dss); |
87 | 86 | return dss.executeWithResults(fact).iterateObjects(); |
88 | |
} |
89 | |
|
90 | |
Iterator<?> runStatefulSession(List<Object> fact,boolean requiredExisting) { |
91 | 58 | if (requiredExisting && session==null) { |
92 | 6 | String msg = "No session found when required"; |
93 | 6 | logger.error(msg); |
94 | 6 | throw new IllegalArgumentException(msg); |
95 | |
} |
96 | 52 | if (session==null) { |
97 | 16 | session=ruleBase.newStatefulSession(); |
98 | 16 | for(Entry<String,DroolsVariableImplementation> es:fields.entrySet()) { |
99 | 20 | String field = es.getKey(); |
100 | 20 | DroolsVariableImplementation data = es.getValue(); |
101 | 20 | if (data.isGlobal()) session.setGlobal(field, data.getData()); |
102 | 2 | else session.insert(data.getData()); |
103 | 20 | } |
104 | 16 | new DroolsWorkingMemoryLogger(session); |
105 | |
} |
106 | 52 | for(Object o:fact) session.insert(o); |
107 | 52 | session.fireAllRules(); |
108 | 52 | return session.iterateObjects(); |
109 | |
} |
110 | |
|
111 | |
Object getResult(Class<?> returnType,Iterator<?> result) { |
112 | 138 | return getInternalResult(returnType,result); |
113 | |
} |
114 | |
|
115 | |
Object getInternalResult(Class<?> returnType,Iterator<?> result) { |
116 | 138 | if (returnType.equals(Void.TYPE)) return null; |
117 | 134 | if (logger.isDebugEnabled()) logger.debug("Type is "+returnType); |
118 | 134 | if (result!=null) { |
119 | 134 | if (returnType.isArray()) { |
120 | 4 | Class<?> real = returnType.getComponentType(); |
121 | 4 | if (logger.isDebugEnabled()) logger.debug("Type is array of "+real); |
122 | 4 | List<Object> lst = new ArrayList<Object>(); |
123 | 4 | Iterator<?> i = result; |
124 | 16 | while(i.hasNext()) { |
125 | 12 | Object o = i.next(); |
126 | 12 | if (logger.isDebugEnabled()) logger.debug("In evaluation object is "+o); |
127 | 12 | if (isTypeCompatible(real,o.getClass())) lst.add(o); |
128 | 12 | } |
129 | 4 | return lst.toArray((Object[])Array.newInstance(real,lst.size())); |
130 | |
} else { |
131 | 130 | Iterator<?> i = result; |
132 | 150 | while(i.hasNext()) { |
133 | 132 | Object o = i.next(); |
134 | 132 | if (logger.isDebugEnabled()) logger.debug("In evaluation object is "+o); |
135 | 132 | if (isTypeCompatible(returnType,o.getClass())) return o; |
136 | 20 | } |
137 | |
} |
138 | |
} |
139 | 18 | return null; |
140 | |
} |
141 | |
|
142 | |
|
143 | |
|
144 | |
|
145 | |
|
146 | |
|
147 | |
|
148 | |
|
149 | |
private boolean isTypeCompatible(Class<?> type,Class<?> src) { |
150 | 144 | if (logger.isDebugEnabled()) logger.debug("type is "+type+" ; src is "+src); |
151 | 144 | if (type.equals(src)) return true; |
152 | 44 | if (type.isAssignableFrom(src)) return true; |
153 | 44 | if (type.isPrimitive()) { |
154 | 20 | if (type.equals(Boolean.TYPE) && src.equals(Boolean.class)) return true; |
155 | 16 | if (src.equals(Boolean.TYPE) && type.equals(Boolean.class)) return true; |
156 | 16 | if (type.equals(Integer.TYPE) && src.equals(Integer.class)) return true; |
157 | 12 | if (src.equals(Integer.TYPE) && type.equals(Integer.class)) return true; |
158 | 12 | if (type.equals(Long.TYPE) && src.equals(Long.class)) return true; |
159 | 8 | if (src.equals(Long.TYPE) && type.equals(Long.class)) return true; |
160 | 8 | if (type.equals(Short.TYPE) && src.equals(Short.class)) return true; |
161 | 8 | if (src.equals(Short.TYPE) && type.equals(Short.class)) return true; |
162 | 8 | if (type.equals(Byte.TYPE) && src.equals(Byte.class)) return true; |
163 | 8 | if (src.equals(Byte.TYPE) && type.equals(Byte.class)) return true; |
164 | 8 | if (type.equals(Float.TYPE) && src.equals(Float.class)) return true; |
165 | 4 | if (src.equals(Float.TYPE) && type.equals(Float.class)) return true; |
166 | 4 | if (type.equals(Double.TYPE) && src.equals(Double.class)) return true; |
167 | 0 | if (src.equals(Double.TYPE) && type.equals(Double.class)) return true; |
168 | 0 | if (type.equals(Character.TYPE) && src.equals(Character.class)) return true; |
169 | 0 | if (src.equals(Character.TYPE) && type.equals(Character.class)) return true; |
170 | |
} |
171 | 24 | return false; |
172 | |
} |
173 | |
|
174 | |
|
175 | |
|
176 | |
|
177 | |
Map<String, DroolsVariableImplementation> getFields() { |
178 | 20 | return fields; |
179 | |
} |
180 | |
|
181 | |
|
182 | |
|
183 | |
|
184 | |
@SuppressWarnings("unchecked") |
185 | |
@Override |
186 | |
public <T1> T1 getFieldByName(String name, Class<T1> clazz) { |
187 | 0 | DroolsVariableImplementation dvi = fields.get(name); |
188 | 0 | if (dvi==null) return null; |
189 | 0 | if (isTypeCompatible(clazz,dvi.getData().getClass())) return (T1)dvi.getData(); |
190 | 0 | String msg = "Expected "+clazz+"; found "+dvi.getClass().getClass(); |
191 | 0 | logger.error(msg); |
192 | 0 | throw new IllegalArgumentException(msg); |
193 | |
} |
194 | |
|
195 | |
} |