com.ejisto.core.classloading.ClassTransformerImpl.java Source code

Java tutorial

Introduction

Here is the source code for com.ejisto.core.classloading.ClassTransformerImpl.java

Source

/*
 * Ejisto, a powerful developer assistant
 *
 * Copyright (C) 2010-2013 Celestino Bellone
 *
 * Ejisto is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Ejisto is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

package com.ejisto.core.classloading;

import com.ejisto.core.classloading.javassist.EjistoMethodFilter;
import com.ejisto.core.classloading.javassist.ObjectEditor;
import com.ejisto.modules.dao.entities.MockedField;
import com.ejisto.modules.repository.MockedFieldsRepository;
import com.ejisto.sl.ClassTransformer;
import javassist.*;
import javassist.bytecode.AccessFlag;
import org.apache.commons.lang3.StringUtils;
import org.apache.log4j.Logger;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.lang.instrument.IllegalClassFormatException;
import java.security.ProtectionDomain;
import java.util.*;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.stream.Collectors;

import static com.ejisto.constants.StringConstants.EJISTO_CLASS_TRANSFORMER_CATEGORY;
import static com.ejisto.modules.web.MockedFieldRequest.requestAllClasses;
import static java.lang.String.format;
import static java.lang.Thread.currentThread;

public class ClassTransformerImpl implements ClassTransformer {

    private static final Logger LOGGER = Logger.getLogger(EJISTO_CLASS_TRANSFORMER_CATEGORY.getValue());
    private final String contextPath;
    private final Collection<String> registeredClassNames;
    private final MockedFieldsRepository mockedFieldsRepository;
    /** Lazily created Javassist pool; cleared by {@link #resetClassPool()} and rebuilt on next use. */
    private final AtomicReference<ClassPool> classPoolContainer = new AtomicReference<>();

    /**
     * Creates a transformer bound to a single context path.
     * The set of instrumentable class names is loaded once, at construction time.
     *
     * @param contextPath            context path whose mocked fields drive the instrumentation
     * @param mockedFieldsRepository repository providing the configured {@link MockedField}s
     */
    public ClassTransformerImpl(String contextPath, MockedFieldsRepository mockedFieldsRepository) {
        this.contextPath = contextPath;
        this.mockedFieldsRepository = mockedFieldsRepository;
        this.registeredClassNames = loadAllRegisteredClassNames(contextPath, mockedFieldsRepository);
    }

    /**
     * Loads every mocked field registered for {@code contextPath} and collects the distinct
     * names of the classes that own them.
     *
     * @return the set of class names (as stored in {@link MockedField#getClassName()})
     */
    private static Collection<String> loadAllRegisteredClassNames(String contextPath,
            MockedFieldsRepository mockedFieldsRepository) {
        Collection<MockedField> fields = mockedFieldsRepository.load(requestAllClasses(contextPath));
        Set<String> classes = fields.stream().map(MockedField::getClassName).collect(Collectors.toSet());
        if (LOGGER.isTraceEnabled()) {
            // guarded: rendering every field is wasted work when trace logging is off
            LOGGER.trace(format("filtered classes for %s: %s of %s", contextPath, classes, fields));
        }
        return classes;
    }

    /** Logs {@code s} at TRACE level, if enabled. */
    private static void trace(String s) {
        if (LOGGER.isTraceEnabled()) {
            LOGGER.trace(s);
        }
    }

    /**
     * Returns the shared class pool, creating it on first use. The pool resolves classes through the
     * context class loader of the thread that first creates it.
     * The returned value is never {@code null}, even if {@link #resetClassPool()} runs concurrently.
     */
    private ClassPool getClassPool() {
        ClassPool pool = classPoolContainer.get();
        while (pool == null) {
            ClassPool candidate = new ClassPool();
            candidate.appendClassPath(new LoaderClassPath(currentThread().getContextClassLoader()));
            if (classPoolContainer.compareAndSet(null, candidate)) {
                return candidate;
            }
            // another thread installed a pool first (or a reset happened): re-read and retry if needed
            pool = classPoolContainer.get();
        }
        return pool;
    }

    /**
     * Instruments the named class and defines it through Javassist's {@code CtClass.toClass()}.
     * The {@code CtClass} is always detached from the pool afterwards, even when definition fails.
     *
     * @param className class name, either dotted or in internal ({@code a/b/C}) form
     * @return the defined, instrumented class
     */
    public Class<?> transform(String className) throws CannotCompileException, NotFoundException {
        CtClass clazz = instrument(className);
        try {
            return clazz.toClass();
        } finally {
            clazz.detach();
        }
    }

    /**
     * Prepares the named class for mocking. The steps are:
     * removes its {@code final} modifier, ensures it has a no-arg constructor,
     * adds missing mocked properties, and then applies the {@link ObjectEditor}.
     */
    private CtClass instrument(String className) throws NotFoundException, CannotCompileException {
        CtClass clazz = load(className);
        removeFinalModifier(clazz);
        addDefaultConstructor(clazz);
        List<MockedField> configuredFields = getFieldsFor(className);
        addMissingProperties(clazz, configuredFields);
        ObjectEditor editor = new ObjectEditor(new EjistoMethodFilter(contextPath, configuredFields));
        clazz.instrument(editor);
        return clazz;
    }

    /** Returns the mocked fields configured for {@code className} in this transformer's context. */
    private List<MockedField> getFieldsFor(String className) {
        return mockedFieldsRepository.load(contextPath, className);
    }

    /**
     * Adds a no-arg constructor (calling only {@code super()}) when the class declares none.
     * Private constructors are considered too. Adding a second {@code ()V} constructor would
     * make Javassist reject the class as a duplicate member.
     */
    private void addDefaultConstructor(CtClass clazz) throws NotFoundException, CannotCompileException {
        if (declaresNoArgConstructor(clazz)) {
            return;
        }
        // Before adding a new empty constructor, the final modifier is removed from every field
        // declared by this class, because the new constructor does not assign them.
        // Inherited fields belong to other classes and are left untouched.
        removeFinalModifierFromFields(clazz.getDeclaredFields());
        CtConstructor defaultConstructor = new CtConstructor(new CtClass[0], clazz);
        defaultConstructor.setBody(null);//default constructor only calls "super()"
        clazz.addConstructor(defaultConstructor);
    }

    /** Tells whether {@code clazz} declares a constructor without parameters, whatever its visibility. */
    private static boolean declaresNoArgConstructor(CtClass clazz) throws NotFoundException {
        for (CtConstructor constructor : clazz.getDeclaredConstructors()) {
            if (constructor.getParameterTypes().length == 0) {
                return true;
            }
        }
        return false;
    }

    /** Clears the {@code final} modifier on each of the given fields that carries it. */
    private void removeFinalModifierFromFields(CtField[] fields) {
        Arrays.stream(fields).filter(field -> Modifier.isFinal(field.getModifiers()))
                .forEach(field -> field.setModifiers(Modifier.clear(field.getModifiers(), Modifier.FINAL)));
    }

    /** Clears the {@code final} modifier on the class itself, if present, so it can be subclassed. */
    private void removeFinalModifier(CtClass clazz) {
        int modifiers = clazz.getModifiers();
        if (Modifier.isFinal(modifiers)) {
            clazz.setModifiers(Modifier.clear(modifiers, Modifier.FINAL));
        }
    }

    /**
     * Fetches {@code className} (dotted or internal form) from the pool.
     * A frozen {@code CtClass} is defrosted first so it can be modified again.
     */
    private CtClass load(String className) throws NotFoundException {
        CtClass clazz = getClassPool().get(getCanonicalClassName(className));
        if (clazz.isFrozen()) {
            clazz.defrost();
        }
        return clazz;
    }

    /**
     * Adds a private field plus getter/setter for every configured field the class does not
     * already expose (see {@code CtClass.getField}).
     *
     * @return {@code true} if at least one property was created
     */
    public boolean addMissingProperties(CtClass clazz, List<MockedField> configuredFields)
            throws CannotCompileException, NotFoundException {
        boolean added = false;
        trace("trying to add missing properties");
        for (MockedField field : configuredFields) {
            added |= createPropertyIfNotFound(clazz, field);
        }
        return added;
    }

    /** Same as {@link #addMissingProperties(byte[], String, List, Consumer)} with a no-op consumer. */
    byte[] addMissingProperties(byte[] original, String className, List<MockedField> fields)
            throws CannotCompileException, NotFoundException, IOException {
        return addMissingProperties(original, className, fields, x -> {
        });
    }

    /**
     * Adds missing mocked properties to a class. The class is parsed from {@code original}
     * when that is non-null; otherwise it is loaded from the pool by {@code className}.
     *
     * @param consumer invoked with the modified class before its bytecode is produced; it is
     *                 not invoked when nothing was added
     * @return the new bytecode, or {@code original} unchanged (possibly {@code null}) if no
     *         property was added
     */
    byte[] addMissingProperties(byte[] original, String className, List<MockedField> fields,
            Consumer<CtClass> consumer) throws CannotCompileException, NotFoundException, IOException {
        CtClass clazz = null;
        try {
            if (original != null) {
                clazz = getClassPool().makeClass(new ByteArrayInputStream(original));
            } else {
                clazz = load(className);
            }
            if (addMissingProperties(clazz, fields)) {
                clazz.rebuildClassFile();
                consumer.accept(clazz);
                return clazz.toBytecode();
            }
            return original;
        } finally {
            if (clazz != null) {
                clazz.detach();
            }
        }
    }

    /** Discards the current class pool; a fresh one is created on the next request. */
    public void resetClassPool() {
        classPoolContainer.set(null);
    }

    /**
     * Creates the property described by {@code field} if {@code clazz.getField} cannot find it.
     *
     * @return {@code true} if the property was created
     */
    private boolean createPropertyIfNotFound(CtClass clazz, MockedField field)
            throws CannotCompileException, NotFoundException {
        try {
            clazz.getField(field.getFieldName());
        } catch (NotFoundException e) {
            createMissingProperty(clazz, field);
            return true;
        }
        return false;
    }

    /** Adds a private field of the mocked type, plus a {@code getX}/{@code setX} accessor pair, to the class. */
    private void createMissingProperty(CtClass clazz, MockedField mockedField)
            throws CannotCompileException, NotFoundException {
        trace("creating property " + mockedField.getFieldName());
        CtField ctField = new CtField(load(mockedField.getFieldType()), mockedField.getFieldName(), clazz);
        ctField.setModifiers(AccessFlag.PRIVATE);
        clazz.addField(ctField);
        String methodSuffix = StringUtils.capitalize(mockedField.getFieldName());
        trace("creating getter: get" + methodSuffix);
        CtMethod getter = CtNewMethod.getter("get" + methodSuffix, ctField);
        trace(format("created [%s]", getter.getSignature()));
        clazz.addMethod(getter);
        trace("creating setter...");
        CtMethod setter = CtNewMethod.setter("set" + methodSuffix, ctField);
        trace(format("created [%s]", setter.getSignature()));
        clazz.addMethod(setter);
        trace("done.");
    }

    /**
     * {@code ClassFileTransformer} entry point.
     * A class is transformed only when it is registered and has mocked fields configured for this context.
     *
     * @return the transformed bytecode, or {@code null} to leave the class unchanged
     */
    @Override
    public byte[] transform(ClassLoader loader, String className, Class<?> classBeingRedefined,
            ProtectionDomain protectionDomain, byte[] classfileBuffer) throws IllegalClassFormatException {
        if (!isInstrumentableClass(className)) {
            return null;
        }
        trace(className + " is instrumentable. Loading fields...");
        List<MockedField> fields = getFieldsFor(getCanonicalClassName(className));
        boolean hasFields = fields != null && !fields.isEmpty();
        trace(className + " has registered fields: " + hasFields);
        if (!hasFields) {
            return null;
        }
        return transform(className, classfileBuffer, fields);
    }

    /**
     * Tells whether {@code name} (dotted or internal form) is one of the registered classes.
     * A {@code null} name, which a {@code ClassFileTransformer} may receive for VM-anonymous
     * classes, is never instrumentable.
     */
    @Override
    public boolean isInstrumentableClass(String name) {
        return name != null && registeredClassNames.contains(getCanonicalClassName(name));
    }

    /** Converts an internal class name ({@code a/b/C}) to dotted form ({@code a.b.C}); literal char replace, no regex. */
    private String getCanonicalClassName(String path) {
        return path.replace('/', '.');
    }

    /**
     * Instruments a class and returns its new bytecode.
     * The class is parsed from {@code classFileBuffer} when that is provided; otherwise it is
     * loaded from the pool. The {@code CtClass} is always detached afterwards, so a failed
     * transformation does not leave a stale entry in the shared pool.
     *
     * @throws IllegalClassFormatException on any failure; the original exception is kept as the cause
     */
    private byte[] transform(String className, byte[] classFileBuffer, List<MockedField> mockedFields)
            throws IllegalClassFormatException {
        CtClass clazz = null;
        try {
            trace("retrieving " + className + " from pool");
            ClassPool classPool = getClassPool();
            if (classFileBuffer == null) {
                clazz = classPool.get(getCanonicalClassName(className));
            } else {
                trace("class file bytes already loaded. Reusing it...");
                clazz = classPool.makeClass(new ByteArrayInputStream(classFileBuffer));
            }
            trace("instrumenting " + className);
            clazz.instrument(new ObjectEditor(new EjistoMethodFilter(contextPath, mockedFields)));
            trace("removing final modifier (if present)");
            removeFinalModifier(clazz);
            trace("adding default constructor, if none present ");
            addDefaultConstructor(clazz);
            trace("done. Returning bytecode");
            clazz.rebuildClassFile();
            return clazz.toBytecode();
        } catch (Exception e) {
            LOGGER.error("error during transformation of class " + className, e);
            // IllegalClassFormatException has no (String, Throwable) constructor: attach the cause explicitly
            IllegalClassFormatException failure = new IllegalClassFormatException(e.getMessage());
            failure.initCause(e);
            throw failure;
        } finally {
            if (clazz != null) {
                clazz.detach();
            }
        }
    }
}