com.feedzai.fos.impl.r.RandomForestPMMLProducerConsumerTest.java Source code

Java tutorial

Introduction

Here is the source code for the class com.feedzai.fos.impl.r.RandomForestPMMLProducerConsumerTest (file RandomForestPMMLProducerConsumerTest.java).

Source

/*
 * $#
 * FOS R implementation
 * 
 * Copyright (C) 2013 Feedzai SA
 * 
 * This software is licensed under the Apache License, Version 2.0 (the "Apache License") or the GNU
 * Lesser General Public License version 3 (the "GPL License"). You may choose either license to govern
 * your use of this software only upon the condition that you accept all of the terms of either the Apache
 * License or the LGPL License.
 *
 * You may obtain a copy of the Apache License and the LGPL License at:
 *
 * http://www.apache.org/licenses/LICENSE-2.0.txt
 * http://www.gnu.org/licenses/lgpl-3.0.txt
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the Apache License
 * or the LGPL License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
 * either express or implied. See the Apache License and the LGPL License for the specific language governing
 * permissions and limitations under the Apache License and the LGPL License.
 * #$
 */
package com.feedzai.fos.impl.r;

import com.feedzai.fos.api.Attribute;
import com.feedzai.fos.api.FOSException;
import com.feedzai.fos.api.ModelConfig;
import com.feedzai.fos.api.config.FosConfig;
import com.feedzai.fos.impl.r.config.RManagerConfig;
import com.feedzai.fos.impl.r.config.RModelConfig;
import com.feedzai.fos.impl.r.rserve.FosRserve;
import org.apache.commons.configuration.BaseConfiguration;
import org.jpmml.model.JAXBUtil;
import org.junit.BeforeClass;
import org.junit.Test;

import javax.xml.transform.stream.StreamSource;
import java.io.File;
import java.io.FileInputStream;
import java.nio.file.Files;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.zip.GZIPInputStream;

/**
 * @author Ricardo Ferreira (ricardo.ferreira@feedzai.com)
 */
/**
 * Tests that train an R {@code randomForest} model through {@link RManager}, export it as PMML
 * (plain and gzip-compressed) and check that the exported document can be parsed back with
 * JPMML's {@link JAXBUtil}.
 */
public class RandomForestPMMLProducerConsumerTest {

    /**
     * Rserve handle created once for the whole test class.
     * NOTE(review): none of the tests reference it directly; presumably constructing it makes an
     * Rserve instance available to {@link RManager} — confirm.
     */
    static FosRserve rserve;

    @BeforeClass
    public static void init() throws FOSException {
        rserve = new FosRserve();
    }

    /**
     * Builds a model configuration that trains with R's {@code randomForest} function over the
     * attributes returned by {@link RIntegrationTest#getAttributes()}.
     *
     * @return the configuration, with the last attribute used as the class index
     */
    private ModelConfig setupConfig() {
        Map<String, String> properties = new HashMap<>();
        List<Attribute> attributes = RIntegrationTest.getAttributes();
        ModelConfig modelConfig = new ModelConfig(attributes, properties);
        modelConfig.setProperty(RModelConfig.LIBRARIES, "randomForest, foreign");
        modelConfig.setProperty(RModelConfig.TRAIN_FUNCTION, "randomForest");
        modelConfig.setProperty(RModelConfig.TRAIN_FUNCTION_ARGUMENTS, "na.action = na.roughfix");

        // The class attribute is the last one in the attribute list.
        modelConfig.setProperty(RModelConfig.CLASS_INDEX, String.valueOf(attributes.size() - 1));

        return modelConfig;
    }

    /**
     * Creates an {@link RManager} configured with {@link RManagerFactory} as the factory.
     *
     * @return a new manager instance
     * @throws FOSException if the manager cannot be created
     */
    private RManager setupManager() throws FOSException {
        BaseConfiguration configuration = new BaseConfiguration();

        configuration.setProperty(FosConfig.FACTORY_NAME, RManagerFactory.class.getName());

        FosConfig config = new FosConfig(configuration);

        RManagerConfig rManagerConfig = new RManagerConfig(config);

        return new RManager(rManagerConfig);
    }

    @Test
    public void testUncompressed() throws Exception {
        trainSaveAndLoadPMML(false);
    }

    @Test
    public void testCompressed() throws Exception {
        trainSaveAndLoadPMML(true);
    }

    /**
     * Trains a model, saves it as PMML to a temporary file and unmarshals the file again.
     *
     * <p>The temporary file is always deleted, even when saving or parsing fails.
     *
     * @param compressed whether {@link RManager#saveAsPMML} should write compressed output; the
     *                   file is then read back through a {@link GZIPInputStream}
     * @throws Exception if training, saving or unmarshalling fails, which fails the calling test
     */
    private void trainSaveAndLoadPMML(boolean compressed) throws Exception {
        ModelConfig modelConfig = setupConfig();
        RManager rManager = setupManager();

        UUID uuid = rManager.trainAndAdd(modelConfig, RIntegrationTest.getTrainingInstances());

        File targetFile = Files.createTempFile("targetPMML", ".xml").toFile();
        try {
            // Save the model as PMML and load it.
            rManager.saveAsPMML(uuid, targetFile.getAbsolutePath(), compressed);

            try (FileInputStream fis = new FileInputStream(targetFile)) {
                if (compressed) {
                    try (GZIPInputStream gis = new GZIPInputStream(fis)) {
                        JAXBUtil.unmarshalPMML(new StreamSource(gis));
                    }
                } else {
                    JAXBUtil.unmarshalPMML(new StreamSource(fis));
                }
            }
        } finally {
            // Clean up in finally so a failing run does not leave temp files behind.
            Files.deleteIfExists(targetFile.toPath());
        }
    }

}