Java tutorial
/* * Copyright (C) 2015 SoftIndex LLC. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package net.wpm.codegen; import org.objectweb.asm.Type; import org.objectweb.asm.commons.GeneratorAdapter; import org.objectweb.asm.commons.Method; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import net.wpm.codegen.ClassBuilder; import net.wpm.codegen.ClassScope; import net.wpm.codegen.Context; import net.wpm.codegen.Expression; import net.wpm.codegen.utils.DefiningClassLoader; import net.wpm.codegen.utils.DefiningClassWriter; import net.wpm.codegen.utils.Preconditions; import java.io.FileOutputStream; import java.io.IOException; import java.lang.reflect.InvocationTargetException; import java.nio.file.Path; import java.util.*; import java.util.concurrent.atomic.AtomicInteger; import static net.wpm.codegen.Expressions.sequence; import static java.util.Arrays.asList; import static net.wpm.codegen.Utils.loadAndCast; import static org.objectweb.asm.Opcodes.*; import static org.objectweb.asm.Type.getInternalName; import static org.objectweb.asm.Type.getType; import static org.objectweb.asm.commons.Method.getMethod; /** * Intends for dynamic description of the behaviour of the object in runtime * * @param <T> type of item */ @SuppressWarnings("unchecked") public class ClassBuilder<T> { private final Logger logger = LoggerFactory.getLogger(this.getClass()); public static final String DEFAULT_CLASS_NAME = ClassBuilder.class.getPackage().getName() + ".Class"; private static final AtomicInteger COUNTER = new AtomicInteger(); private final DefiningClassLoader classLoader; private Path bytecodeSaveDir; private final ClassScope<T> scope; private Expression staticConstructor = sequence(Collections.EMPTY_LIST); private Expression constructor = null; private final Map<String, Class<?>> fields = new LinkedHashMap<String, Class<?>>(); private final Map<String, Class<?>> staticFields = new LinkedHashMap<String, Class<?>>(); private final Map<String, Class<?>> staticConstants = new LinkedHashMap<String, Class<?>>(); private final Map<Method, Expression> methods = new LinkedHashMap<Method, Expression>(); private final Map<Method, Expression> staticMethods = new LinkedHashMap<Method, Expression>(); public ClassBuilder<T> setBytecodeSaveDir(Path bytecodeSaveDir) { this.bytecodeSaveDir = bytecodeSaveDir; return this; } public static class AsmClassKey<T> { private final Set<Class<?>> parentClasses; private final Map<String, Class<?>> fields; private final Map<String, Class<?>> staticFields; private final Map<Method, Expression> expressionMap; private final Map<Method, Expression> expressionStaticMap; public AsmClassKey(Set<Class<?>> parentClasses, Map<String, Class<?>> fields, Map<String, Class<?>> staticFields, Map<Method, Expression> expressionMap, Map<Method, Expression> expressionStaticMap) { this.parentClasses = parentClasses; this.fields = fields; this.staticFields = staticFields; this.expressionMap = expressionMap; this.expressionStaticMap = expressionStaticMap; } public Set<Class<?>> getParentClasses() { return parentClasses; } @Override public String toString() { return "AsmClassKey{" + "parentClasses=" + parentClasses + ", fields=" + fields + ", staticFields=" + staticFields + ", expressionMap=" + expressionMap + ", expressionStaticMap=" + expressionStaticMap + '}'; } @Override public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; AsmClassKey<?> that = (AsmClassKey<?>) o; return Objects.equals(parentClasses, that.parentClasses) && Objects.equals(fields, that.fields) && Objects.equals(staticFields, that.staticFields) && Objects.equals(expressionMap, that.expressionMap) && Objects.equals(expressionStaticMap, that.expressionStaticMap); } @Override public int hashCode() { return Objects.hash(parentClasses, fields, expressionMap, expressionStaticMap); } } /** * Creates a new instance of AsmFunctionFactory * * @param classLoader class loader * @param type type of dynamic class */ public ClassBuilder(DefiningClassLoader classLoader, Class<T> type) { this(classLoader, type, Collections.EMPTY_LIST); } public ClassBuilder(DefiningClassLoader classLoader, Class<T> mainType, List<Class<?>> types) { this.classLoader = classLoader; this.scope = new ClassScope<T>(mainType, types); } public ClassBuilder<T> constructor(Expression expression) { constructor = expression; return this; } /** * Creates a new field for a dynamic class * * @param field name of field * @param fieldClass type of field * @return changed AsmFunctionFactory */ public ClassBuilder<T> field(String field, Class<?> fieldClass) { fields.put(field, fieldClass); scope.addField(field, fieldClass); return this; } /** * Creates a new static field for a dynamic class * * @param field * @param fieldClass * @return changed AsmBuilder */ public ClassBuilder<T> staticField(String field, Class<?> fieldClass) { staticFields.put(field, fieldClass); scope.addStaticField(field, fieldClass); return this; } /** * Creates a new static field for a dynamic class * * @param field * @param fieldClass * @return changed AsmBuilder */ public ClassBuilder<T> staticConstant(String field, Class<?> fieldClass) { staticConstants.put(field, fieldClass); scope.addStaticField(field, fieldClass); return this; } /** * Creates a new method for a dynamic class * * @param method new method for class * @param expression function which will be processed * @return changed AsmFunctionFactory */ public ClassBuilder<T> method(Method method, Expression expression) { methods.put(method, expression); scope.addMethod(method); return this; } public ClassBuilder<T> staticMethod(Method method, Expression expression) { staticMethods.put(method, expression); scope.addStaticMethod(method); return this; } /** * Creates a new method for a dynamic class * * @param methodName name of method * @param returnClass type which returns this method * @param argumentTypes list of types of arguments * @param expression function which will be processed * @return changed AsmFunctionFactory */ public ClassBuilder<T> method(String methodName, Class<?> returnClass, List<? extends Class<?>> argumentTypes, Expression expression) { Type[] types = new Type[argumentTypes.size()]; for (int i = 0; i < argumentTypes.size(); i++) { types[i] = getType(argumentTypes.get(i)); } return method(new Method(methodName, getType(returnClass), types), expression); } /** * Create a new static method for a dynamic class * * @param methodName name of method * @param returnClass type which returns this method * @param argumentTypes list of types of arguments * @param expression function which will be processed * @return changed AsmFunctionFactory */ public ClassBuilder<T> staticMethod(String methodName, Class<?> returnClass, List<? extends Class<?>> argumentTypes, Expression expression) { Type[] types = new Type[argumentTypes.size()]; for (int i = 0; i < argumentTypes.size(); i++) { types[i] = getType(argumentTypes.get(i)); } return staticMethod(new Method(methodName, getType(returnClass), types), expression); } /** * Create a new static initialization block for a dynamic class. Overwrites the existing one. * * @param expression function which will be processed * @return changed AsmFunctionFactory */ public ClassBuilder<T> staticInitializationBlock(Expression expression) { staticConstructor = sequence(staticConstructor, expression); return staticMethod("<clinit>", void.class, Collections.EMPTY_LIST, staticConstructor); } /** * Creates a new method for a dynamic class. The method must be part of the provided interfaces or abstract class. * * @param methodName name of method * @param expression function which will be processed * @return changed AsmFunctionFactory */ public ClassBuilder<T> method(String methodName, Expression expression) { if (methodName.contains("(")) { Method method = Method.getMethod(methodName); return method(method, expression); } Method foundMethod = null; List<List<java.lang.reflect.Method>> listOfMethods = new ArrayList<List<java.lang.reflect.Method>>(); listOfMethods.add(asList(Object.class.getMethods())); for (Class<?> type : scope.getParentClasses()) { listOfMethods.add(asList(type.getMethods())); listOfMethods.add(asList(type.getDeclaredMethods())); } for (List<java.lang.reflect.Method> list : listOfMethods) { for (java.lang.reflect.Method m : list) { if (m.getName().equals(methodName)) { Method method = getMethod(m); if (foundMethod != null && !method.equals(foundMethod)) throw new IllegalArgumentException("Method " + method + " collides with " + foundMethod); foundMethod = method; } } } Preconditions.check(foundMethod != null, "Could not find method '" + methodName + "'"); return method(foundMethod, expression); } /** * Returns a new class which is created in a dynamic way * * @return completed class */ public Class<T> build() { return build(null); } public Class<T> build(String className) { synchronized (classLoader) { AsmClassKey<T> key = new AsmClassKey<T>(scope.getParentClasses(), fields, staticFields, methods, staticMethods); Class<?> cachedClass = classLoader.getClassByKey(key); if (cachedClass != null) { logger.trace("Fetching {} for key {} from cache", cachedClass, key); return (Class<T>) cachedClass; } return defineNewClass(key, className); } } /** * Returns a new class which is created in a dynamic way * * @param key key * @return completed class */ private Class<T> defineNewClass(AsmClassKey<T> key, String newClassName) { DefiningClassWriter cw = new DefiningClassWriter(classLoader); String className; if (newClassName == null) { className = DEFAULT_CLASS_NAME + COUNTER.incrementAndGet(); } else { className = newClassName; } Type classType = getType('L' + className.replace('.', '/') + ';'); // contains all classes (abstract and interfaces) final Set<Class<?>> parentClasses = scope.getParentClasses(); final String[] internalNames = new String[parentClasses.size()]; int pos = 0; for (Class<?> clazz : parentClasses) internalNames[pos++] = getInternalName(clazz); if (scope.getMainType().isInterface()) { cw.visit(V1_6, ACC_PUBLIC + ACC_FINAL + ACC_SUPER, classType.getInternalName(), null, "java/lang/Object", internalNames); } else { cw.visit(V1_6, ACC_PUBLIC + ACC_FINAL + ACC_SUPER, classType.getInternalName(), null, internalNames[0], Arrays.copyOfRange(internalNames, 1, internalNames.length)); } { Method m = getMethod("void <init> ()"); GeneratorAdapter g = new GeneratorAdapter(ACC_PUBLIC, m, null, null, cw); if (constructor != null) { Context ctx = new Context(classLoader, g, classType, scope.getParentClasses(), scope.getFields(), scope.getStaticFields(), m.getArgumentTypes(), m, scope.getMethods(), scope.getStaticMethods()); loadAndCast(ctx, constructor, m.getReturnType()); } g.loadThis(); if (scope.getMainType().isInterface()) { g.invokeConstructor(getType(Object.class), m); } else { g.invokeConstructor(getType(scope.getMainType()), m); } g.returnValue(); g.endMethod(); } for (String field : fields.keySet()) { Class<?> fieldClass = fields.get(field); cw.visitField(ACC_PUBLIC, field, getType(fieldClass).getDescriptor(), null, null); } for (String field : staticFields.keySet()) { Class<?> fieldClass = staticFields.get(field); cw.visitField(ACC_PUBLIC + ACC_STATIC, field, getType(fieldClass).getDescriptor(), null, null); } for (String field : staticConstants.keySet()) { Class<?> fieldClass = staticConstants.get(field); cw.visitField(ACC_PUBLIC + ACC_STATIC + ACC_FINAL, field, getType(fieldClass).getDescriptor(), null, null); } for (Method m : staticMethods.keySet()) { try { GeneratorAdapter g = new GeneratorAdapter(ACC_PUBLIC + ACC_STATIC + ACC_FINAL, m, null, null, cw); Context ctx = new Context(classLoader, g, classType, scope.getParentClasses(), Collections.EMPTY_MAP, scope.getStaticFields(), m.getArgumentTypes(), m, scope.getMethods(), scope.getStaticMethods()); Expression expression = staticMethods.get(m); loadAndCast(ctx, expression, m.getReturnType()); g.returnValue(); g.endMethod(); } catch (Exception e) { throw new RuntimeException("Unable to implement " + m.getName() + m.getDescriptor(), e); } } for (Method m : methods.keySet()) { try { GeneratorAdapter g = new GeneratorAdapter(ACC_PUBLIC + ACC_FINAL, m, null, null, cw); Context ctx = new Context(classLoader, g, classType, scope.getParentClasses(), scope.getFields(), scope.getStaticFields(), m.getArgumentTypes(), m, scope.getMethods(), scope.getStaticMethods()); Expression expression = methods.get(m); loadAndCast(ctx, expression, m.getReturnType()); g.returnValue(); g.endMethod(); } catch (Exception e) { throw new RuntimeException("Unable to implement " + m.getName() + m.getDescriptor(), e); } } if (bytecodeSaveDir != null) { FileOutputStream fos = null; try { fos = new FileOutputStream(bytecodeSaveDir.resolve(className + ".class").toFile()); fos.write(cw.toByteArray()); fos.close(); } catch (IOException e) { try { if (fos != null) fos.close(); } catch (IOException ioe) { throw new RuntimeException(ioe); } throw new RuntimeException(e); } } cw.visitEnd(); Class<?> definedClass = classLoader.defineClass(className, key, cw.toByteArray()); logger.trace("Defined new {} for key {}", definedClass, key); return (Class<T>) definedClass; } /** * Returns a new instance of a dynamic class * * @return new instance of the class which was created before in a dynamic way */ public T buildClassAndCreateNewInstance() { try { return build().newInstance(); } catch (InstantiationException e) { throw new RuntimeException(e); } catch (IllegalAccessException e) { throw new RuntimeException(e); } } @SuppressWarnings("rawtypes") public T buildClassAndCreateNewInstance(Object... constructorParameters) { Class[] constructorParameterTypes = new Class[constructorParameters.length]; for (int i = 0; i < constructorParameters.length; i++) { constructorParameterTypes[i] = constructorParameters[i].getClass(); } return buildClassAndCreateNewInstance(constructorParameterTypes, constructorParameters); } @SuppressWarnings("rawtypes") public T buildClassAndCreateNewInstance(Class[] constructorParameterTypes, Object[] constructorParameters) { try { return build().getConstructor(constructorParameterTypes).newInstance(constructorParameters); } catch (InstantiationException e) { throw new RuntimeException(e); } catch (IllegalAccessException e) { throw new RuntimeException(e); } catch (NoSuchMethodException e) { throw new RuntimeException(e); } catch (InvocationTargetException e) { throw new RuntimeException(e); } } }