com.feedzai.fos.impl.r.RManager.java Source code

Java tutorial

Introduction

Here is the source code for com.feedzai.fos.impl.r.RManager.java

Source

/*
 * $#
 * FOS R implementation
 * 
 * Copyright (C) 2013 Feedzai SA
 * 
 * This software is licensed under the Apache License, Version 2.0 (the "Apache License") or the GNU
 * Lesser General Public License version 3 (the "LGPL License"). You may choose either license to govern
 * your use of this software only upon the condition that you accept all of the terms of either the Apache
 * License or the LGPL License.
 *
 * You may obtain a copy of the Apache License and the LGPL License at:
 *
 * http://www.apache.org/licenses/LICENSE-2.0.txt
 * http://www.gnu.org/licenses/lgpl-3.0.txt
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the Apache License
 * or the LGPL License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
 * either express or implied. See the Apache License and the LGPL License for the specific language governing
 * permissions and limitations under the Apache License and the LGPL License.
 * #$
 */

package com.feedzai.fos.impl.r;

import com.feedzai.fos.api.*;
import com.feedzai.fos.common.validation.NotBlank;
import com.feedzai.fos.common.validation.NotNull;
import com.feedzai.fos.impl.r.config.RManagerConfig;
import com.feedzai.fos.impl.r.config.RModelConfig;
import com.feedzai.fos.impl.r.rserve.FosRserve;
import com.google.common.base.Charsets;
import com.google.common.base.Joiner;
import com.google.common.collect.ImmutableSet;
import com.google.common.io.Files;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.*;
import java.util.zip.GZIPOutputStream;

import static com.feedzai.fos.impl.r.RScorer.rVariableName;

import static com.feedzai.fos.api.util.ManagerUtils.createModelFile;
import static com.feedzai.fos.api.util.ManagerUtils.getUuid;
import static com.google.common.base.Preconditions.checkNotNull;

/**
 * This class provides a R implementation of a FOS Manager.
 * <p/> Model configurations are kept in an in-memory map keyed by model UUID. Every method that reads or
 * writes that map is {@code synchronized} on this instance.
 *
 * @since 1.0.2
 * @author miguel.duarte
 */
public class RManager implements Manager {
    /** R Manager logger */
    private static final Logger logger = LoggerFactory.getLogger(RManager.class);

    /** Handle for the RServer daemon */
    private final FosRserve rserve;

    /** Map that stores RModel configurations for each configured model */
    private final Map<UUID, RModelConfig> modelConfigs = new HashMap<>();

    /** Manager configuration */
    private final RManagerConfig rManagerConfig;

    /** Reference for an R scorer */
    private final RScorer rScorer;

    /**
     * Default libraries for the R server.
     */
    private final Set<String> defaultLibraries = ImmutableSet.of("pmml");

    /**
     * Create a new manager from the given configuration.
     * <p/> Creates the {@link FosRserve} handle and an {@link RScorer} bound to it, configured with the
     * default libraries. No model is loaded here; models are registered through the {@code addModel} and
     * {@code trainAndAdd} methods.
     *
     * @param rManagerConfig the manager configuration (must not be null)
     * @throws FOSException declared by the signature; propagated from the collaborators created here
     */
    public RManager(RManagerConfig rManagerConfig) throws FOSException {
        checkNotNull(rManagerConfig, "Manager config cannot be null");

        this.rManagerConfig = rManagerConfig;
        this.rserve = new FosRserve();

        this.rScorer = new RScorer(rserve, defaultLibraries.toArray(new String[0]));
    }

    /**
     * Stores the given binary model in a model file and registers it with the scorer.
     *
     * @param config the model configuration
     * @param model  the model to add; only {@link ModelBinary} is supported
     * @return the UUID under which the model was registered
     * @throws FOSException if the model is not binary, if no model file location is registered for the
     *                      model's UUID, or if writing the model file fails
     */
    @Override
    public synchronized UUID addModel(ModelConfig config, Model model) throws FOSException {
        if (!(model instanceof ModelBinary)) {
            throw new FOSException("Currently FOS-R only supports binary models.");
        }

        try {
            UUID uuid = getUuid(config);

            // The file location is taken from the configuration already registered under this UUID.
            // NOTE(review): for a UUID that is not registered yet there is no such configuration; this
            // used to surface as a NullPointerException. A manager-level model directory is probably
            // what was intended here - confirm against RManagerConfig.
            RModelConfig existing = modelConfigs.get(uuid);
            if (existing == null) {
                throw new FOSException("Cannot store the binary of model " + uuid
                        + ": no model file location is registered for it");
            }

            File file = createModelFile(existing.getModel(), uuid, model);

            RModelConfig rModelConfig = new RModelConfig(config, rManagerConfig);
            rModelConfig.setId(uuid);
            rModelConfig.setModel(file);

            modelConfigs.put(uuid, rModelConfig);
            rScorer.addOrUpdate(rModelConfig);

            return uuid;
        } catch (IOException e) {
            throw new FOSException(e);
        }
    }

    /**
     * Registers a model whose binary already exists at the path given by the descriptor.
     *
     * @param config     the model configuration
     * @param descriptor the descriptor pointing to the model file; only the BINARY format is supported
     * @return the UUID under which the model was registered
     * @throws FOSException if the descriptor format is not BINARY
     */
    @Override
    public synchronized UUID addModel(ModelConfig config, @NotBlank ModelDescriptor descriptor)
            throws FOSException {
        if (descriptor.getFormat() != ModelDescriptor.Format.BINARY) {
            throw new FOSException("Currently FOS-R only supports binary models.");
        }

        UUID uuid = getUuid(config);

        RModelConfig rModelConfig = new RModelConfig(config, rManagerConfig);
        rModelConfig.setId(uuid);
        rModelConfig.setModel(new File(descriptor.getModelFilePath()));

        modelConfigs.put(uuid, rModelConfig);
        rScorer.addOrUpdate(rModelConfig);

        return uuid;
    }

    /**
     * Unregisters a model from this manager and from the scorer, then deletes its files.
     *
     * @param modelId the UUID of the model to remove
     * @throws FOSException if no model is registered under the given UUID
     */
    @Override
    public synchronized void removeModel(UUID modelId) throws FOSException {
        RModelConfig rModelConfig = requireModelConfig(modelId);
        modelConfigs.remove(modelId);
        rScorer.removeModel(modelId);

        // delete the header & model  file (or else it will be picked up on the next restart)
        deleteFile(rModelConfig.getHeader());
        deleteFile(rModelConfig.getModel());
        deleteFile(rModelConfig.getPMMLModel());
    }

    /**
     * Applies a new configuration to an already registered model and refreshes it in the scorer.
     *
     * @param modelId     the UUID of the model to reconfigure
     * @param modelConfig the new model configuration
     * @throws FOSException if no model is registered under the given UUID
     */
    @Override
    public synchronized void reconfigureModel(UUID modelId, ModelConfig modelConfig) throws FOSException {
        RModelConfig rModelConfig = requireModelConfig(modelId);
        rModelConfig.update(modelConfig);

        rScorer.addOrUpdate(rModelConfig);
    }

    /**
     * Not supported: always fails.
     *
     * @throws FOSException always
     */
    @Override
    public synchronized void reconfigureModel(UUID modelId, ModelConfig modelConfig, Model model)
            throws FOSException {
        throw new FOSException("Model reconfiguration not yet supported for R");
    }

    /**
     * Applies a new configuration and a new model file to an already registered model and refreshes it in
     * the scorer.
     *
     * @param modelId     the UUID of the model to reconfigure
     * @param modelConfig the new model configuration
     * @param descriptor  the descriptor pointing to the new model file; only the BINARY format is supported
     * @throws FOSException if the descriptor format is not BINARY or no model is registered under the UUID
     */
    @Override
    public synchronized void reconfigureModel(UUID modelId, ModelConfig modelConfig,
            @NotBlank ModelDescriptor descriptor) throws FOSException {
        if (descriptor.getFormat() != ModelDescriptor.Format.BINARY) {
            throw new FOSException("Currently FOS-R only supports binary models.");
        }

        File file = new File(descriptor.getModelFilePath());

        RModelConfig rModelConfig = requireModelConfig(modelId);
        rModelConfig.update(modelConfig);
        rModelConfig.setModel(file);

        rScorer.addOrUpdate(rModelConfig);
    }

    /**
     * Lists the registered models.
     *
     * @return a new map from model UUID to its {@link ModelConfig}; changes to the returned map do not
     *         affect this manager
     */
    @Override
    @NotNull
    public synchronized Map<UUID, ModelConfig> listModels() {
        Map<UUID, ModelConfig> result = new HashMap<>(modelConfigs.size());
        for (Map.Entry<UUID, RModelConfig> entry : modelConfigs.entrySet()) {
            result.put(entry.getKey(), entry.getValue().getModelConfig());
        }

        return result;
    }

    /**
     * @return the scorer created by this manager
     */
    @Override
    @NotNull
    public RScorer getScorer() {
        return rScorer;
    }

    /**
     * Trains a model from in-memory instances and registers the trained model.
     * <p/> The instances are dumped to a temporary ARFF file, whose directory is also stored in the
     * configuration as {@code RModelConfig.MODEL_SAVE_PATH}. The temporary ARFF file is deleted once
     * training has finished; the trained model file is kept.
     *
     * @param config    the model configuration (its properties are updated by this call)
     * @param instances the training instances
     * @return the UUID under which the trained model was registered
     * @throws FOSException if the instances cannot be written or training fails
     */
    @Override
    public synchronized UUID trainAndAdd(ModelConfig config, List<Object[]> instances) throws FOSException {
        File instanceFile = null;
        try {
            instanceFile = writeInstancesToTempFile(instances, config.getAttributes());
            config.setProperty(RModelConfig.MODEL_SAVE_PATH, instanceFile.getParent());
            trainFile(config, instanceFile.getAbsolutePath());

            String trainedModelPath = new File(instanceFile.getParent(),
                    instanceFile.getName() + "." + RModelConfig.MODEL_FILE_EXTENSION).getAbsolutePath();
            ModelDescriptor trainedModelDescriptor = new ModelDescriptor(ModelDescriptor.Format.BINARY,
                    trainedModelPath);

            return addModel(config, trainedModelDescriptor);
        } catch (IOException e) {
            throw new FOSException(e);
        } finally {
            // The training data has been loaded by R at this point; do not leave it behind in the temp dir.
            deleteFile(instanceFile);
        }
    }

    /**
     * Dump a training instances lists into a temporary ARFF file.
     * <p/> The file is written with the platform default charset. If writing fails, the partially written
     * file is deleted.
     *
     * @param instances     training instances list
     * @param attributeList attributes describing the columns of each instance
     * @return Temporary file with the dumped training instances
     * @throws IOException if the file cannot be created or written
     */
    private File writeInstancesToTempFile(List<Object[]> instances, List<Attribute> attributeList)
            throws IOException {
        File instanceFile = File.createTempFile("fosrtraining", ".arff");

        boolean written = false;
        try (PrintWriter pw = new PrintWriter(new FileOutputStream(instanceFile))) {
            pw.println("% FOS generated ARFF file");
            pw.println("@relation fosrelation");
            pw.println("");

            for (Attribute attribute : attributeList) {
                pw.print("@attribute " + rVariableName(attribute.getName()) + " ");
                if (attribute instanceof NumericAttribute) {
                    pw.println("REAL");
                } else if (attribute instanceof CategoricalAttribute) {
                    CategoricalAttribute cat = (CategoricalAttribute) attribute;
                    pw.print("{ '");
                    pw.print(Joiner.on("', '").join(cat.getCategoricalInstances()));
                    pw.println("'}");
                }
            }
            pw.println();
            pw.println("@data");
            // Dump instances to file
            for (Object[] instance : instances) {
                /* ? is the missing value constant in ARFF files */
                pw.println(Joiner.on(',').useForNull("?").join(instance));
            }

            // PrintWriter never throws on write errors: flush and query its error flag explicitly.
            written = !pw.checkError();
        } finally {
            if (!written) {
                deleteFile(instanceFile);
            }
        }

        if (!written) {
            throw new IOException("Failed to write training instances to '" + instanceFile.getAbsolutePath() + "'");
        }
        return instanceFile;
    }

    /**
     * Trains a model from a training file and registers the trained model.
     *
     * @param config the model configuration
     * @param path   file with the training instances
     * @return the UUID under which the trained model was registered
     * @throws FOSException if training fails
     */
    @Override
    public synchronized UUID trainAndAddFile(ModelConfig config, String path) throws FOSException {
        trainFile(config, path);

        ModelDescriptor trainedModelDescriptor = new ModelDescriptor(ModelDescriptor.Format.BINARY,
                path + "." + RModelConfig.MODEL_FILE_EXTENSION);

        return addModel(config, trainedModelDescriptor);
    }

    /**
     * Trains a model from in-memory instances without registering it.
     * <p/> The instances are dumped to a temporary ARFF file that is deleted once training has finished.
     *
     * @param config    the model configuration
     * @param instances the training instances
     * @return the trained model
     * @throws FOSException if the instances cannot be written or training fails
     */
    @Override
    public Model train(ModelConfig config, List<Object[]> instances) throws FOSException {
        File instanceFile = null;
        try {
            instanceFile = writeInstancesToTempFile(instances, config.getAttributes());
            return trainFile(config, instanceFile.getAbsolutePath());
        } catch (IOException e) {
            throw new FOSException(e);
        } finally {
            deleteFile(instanceFile);
        }
    }

    /**
     * Generate and evaluate R boilerplate code to train a model. By default the train function is
     * <code>RModelConfig.BUILT_IN_TRAIN_FUNCTION</code>. Another algorithm can be used by overriding
     * <code>RModelConfig.TRAIN_FILE</code> (a script that is evaluated before training) and
     * <code>RModelConfig.TRAIN_FUNCTION</code>.
     *
     * <p/> The trained model is saved to a file named after the training file plus
     * <code>RModelConfig.MODEL_FILE_EXTENSION</code>, in the directory given by
     * <code>RModelConfig.MODEL_SAVE_PATH</code>. The absolute path of that file is stored in the
     * configuration as <code>RModelConfig.MODEL_FILE</code>.
     *
     * Shape of the evaluated code (library, paths, class name and function are illustrative)
     * <pre>
     *    require(somelibrary)
     *    train.data &lt;- read.arff('/tmp/fosrtraining8499205938185291252.arff')
     *    classfn &lt;- as.formula('class ~ .')
     *    model &lt;- trainFunction(formula = classfn, data = train.data, trainArguments)
     *    save(model, file = '/tmp/fosrtraining8499205938185291252.arff.extension')
     * </pre>
     *
     * @param config the model configuration (its properties are updated by this call)
     * @param path   File with the training instances
     * @return a {@link ModelBinary} holding the bytes of the saved model file
     * @throws FOSException if any step of the training fails
     */
    @Override
    public Model trainFile(ModelConfig config, String path) throws FOSException {
        String trainFile = config.getProperty(RModelConfig.TRAIN_FILE);
        String trainFunction = config.getProperty(RModelConfig.TRAIN_FUNCTION);
        if (trainFunction == null) {
            trainFunction = RModelConfig.BUILT_IN_TRAIN_FUNCTION;
        }

        String trainArguments = config.getProperty(RModelConfig.TRAIN_FUNCTION_ARGUMENTS);
        String trainScript = null;

        try {
            if (trainFile != null) {
                trainScript = Files.toString(new File(trainFile), Charsets.UTF_8);
            }

            // The libraries property is optional: without it there is simply nothing to load.
            String libraries = config.getProperty(RModelConfig.LIBRARIES);
            if (libraries != null) {
                for (String library : libraries.trim().split(",")) {
                    library = library.trim();
                    if (library.length() > 0) {
                        rserve.eval(String.format("require(%s)", library));
                    }
                }
            }

            // eval optional train script
            if (trainScript != null) {
                rserve.eval(trainScript);
            }

            List<Attribute> attributes = config.getAttributes();

            File modelSaveFile = new File(config.getProperty(RModelConfig.MODEL_SAVE_PATH),
                    (new File(path).getName()) + "." + RModelConfig.MODEL_FILE_EXTENSION);

            config.setProperty(RModelConfig.MODEL_FILE, modelSaveFile.getAbsolutePath());

            Attribute modelClass = attributes.get(config.getIntProperty(RModelConfig.CLASS_INDEX));

            // load training data
            rserve.eval(String.format("train.data <- read.arff('%s')", escapeForR(path)));
            rserve.eval(String.format("classfn <- as.formula('%s ~ .')", rVariableName(modelClass.getName())));
            rserve.eval(String.format("model <- %s(formula = classfn, data = train.data%s)", trainFunction,
                    trainArguments != null ? ", " + trainArguments : ""));

            rserve.eval(String.format("save(model, file = '%s')", escapeForR(modelSaveFile.getAbsolutePath())));

            return new ModelBinary(Files.toByteArray(modelSaveFile));

        } catch (Throwable e) {
            throw new FOSException(e);
        }

    }

    /**
     * Deletes the temporary PMML file of every registered model, if it exists.
     * <p/> Models without a configured PMML file path are skipped. A file that cannot be deleted is
     * logged and does not abort the loop.
     *
     * @throws FOSException When there are IO problems reading a model's configuration.
     */
    @Override
    public synchronized void close() throws FOSException {
        for (RModelConfig rModelConfig : modelConfigs.values()) {
            String pmmlPath = rModelConfig.getModelConfig().getProperty(RModelConfig.PMML_FILE);
            if (pmmlPath == null) {
                continue;
            }

            File tempPMMLFile = new File(pmmlPath);
            if (tempPMMLFile.exists()) {
                logger.debug("Deleting temporary R PMML file '{}'.", tempPMMLFile.getAbsolutePath());
                deleteFile(tempPMMLFile);
            }
        }
    }

    /**
     * Copies the model file of a registered model to the given path.
     *
     * @param uuid     the UUID of the model to save
     * @param savepath the destination file path
     * @throws FOSException if no model is registered under the given UUID or the copy fails
     */
    @Override
    public synchronized void save(UUID uuid, String savepath) throws FOSException {
        File source = requireModelConfig(uuid).getModel();
        try {
            File destination = new File(savepath);
            Files.copy(source, destination);
        } catch (Exception e) {
            String msg = "Unable to save model " + uuid + " to " + savepath;
            logger.error(msg, e);
            throw new FOSException(msg);
        }
    }

    /**
     * Copies the PMML export of a registered model to the given path, optionally GZIP-compressed.
     * <p/> If the PMML file does not exist yet, the scorer's PMML export call is evaluated in R first.
     *
     * @param uuid     the UUID of the model to export
     * @param filePath the destination file path
     * @param compress whether the destination file must be GZIP-compressed
     * @throws FOSException if no model is registered under the given UUID, no PMML file path is
     *                      configured for it, or the copy fails
     */
    @Override
    public synchronized void saveAsPMML(UUID uuid, String filePath, boolean compress) throws FOSException {
        RModelConfig rModelConfig = requireModelConfig(uuid);

        String pmmlPath = rModelConfig.getModelConfig().getProperty(RModelConfig.PMML_FILE);
        if (pmmlPath == null) {
            throw new FOSException("No PMML file is configured for model " + uuid);
        }

        File source = new File(pmmlPath);
        File target = new File(filePath);

        // If the PMML hasn't already been exported, generate it first.
        if (!source.exists()) {
            rserve.eval(rScorer.getSaveAsPMMLFunctionCall(uuid));
        }

        try {
            if (compress) {
                logger.debug("Copying R PMML file to compressed file '{}'.", target.getAbsolutePath());

                try (FileOutputStream fos = new FileOutputStream(target);
                        GZIPOutputStream gos = new GZIPOutputStream(fos)) {
                    Files.copy(source, gos);
                }
            } else {
                logger.debug("Copying R PMML file to uncompressed file '{}'.", target.getAbsolutePath());

                Files.copy(source, target);
            }
        } catch (IOException e) {
            String msg = "Failed to copy PMML file to destination file '" + filePath + "'.";
            // Log the underlying failure here, since only the message is propagated to the caller.
            logger.error(msg, e);
            throw new FOSException(msg);
        }
    }

    /**
     * Looks up the configuration of a registered model. Callers must hold this manager's lock.
     *
     * @param modelId the UUID of the model
     * @return the registered configuration (never null)
     * @throws FOSException if no model is registered under the given UUID
     */
    private RModelConfig requireModelConfig(UUID modelId) throws FOSException {
        RModelConfig rModelConfig = modelConfigs.get(modelId);
        if (rModelConfig == null) {
            throw new FOSException("Unknown model with UUID " + modelId);
        }
        return rModelConfig;
    }

    /**
     * Best-effort deletion of a file: a null or missing file is ignored, a failed deletion is logged.
     *
     * @param file the file to delete (may be null)
     */
    private static void deleteFile(File file) {
        if (file != null && file.exists() && !file.delete()) {
            logger.warn("Unable to delete file '{}'.", file.getAbsolutePath());
        }
    }

    /**
     * Escapes a value so that it can be embedded in a single-quoted R string literal.
     *
     * @param value the raw value (e.g. a file path)
     * @return the value with backslashes and single quotes escaped
     */
    private static String escapeForR(String value) {
        // Backslashes first, so that the backslash added for a quote is not escaped again.
        return value.replace("\\", "\\\\").replace("'", "\\'");
    }
}