com.cloudera.oryx.rdf.computation.RDFDistributedGenerationRunner.java Source code

Java tutorial

Introduction

Here is the source code for com.cloudera.oryx.rdf.computation.RDFDistributedGenerationRunner.java

Source

/*
 * Copyright (c) 2013, Cloudera, Inc. All Rights Reserved.
 *
 * Cloudera, Inc. licenses this file to you under the Apache License,
 * Version 2.0 (the "License"). You may not use this file except in
 * compliance with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
 * CONDITIONS OF ANY KIND, either express or implied. See the License for
 * the specific language governing permissions and limitations under the
 * License.
 */

package com.cloudera.oryx.rdf.computation;

import com.google.common.collect.Maps;
import org.apache.commons.math3.stat.descriptive.moment.Mean;
import org.dmg.pmml.IOUtil;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningModel;
import org.dmg.pmml.Model;
import org.dmg.pmml.PMML;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.xml.sax.InputSource;
import org.xml.sax.SAXException;

import javax.xml.bind.JAXBException;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.io.StringReader;
import java.util.Collections;
import java.util.List;
import java.util.Map;

import com.cloudera.oryx.common.io.IOUtils;
import com.cloudera.oryx.common.iterator.FileLineIterable;
import com.cloudera.oryx.common.servcomp.Namespaces;
import com.cloudera.oryx.common.servcomp.Store;
import com.cloudera.oryx.computation.common.DependsOn;
import com.cloudera.oryx.computation.common.DistributedGenerationRunner;
import com.cloudera.oryx.computation.common.JobStep;
import com.cloudera.oryx.computation.common.JobStepConfig;
import com.cloudera.oryx.rdf.computation.build.BuildTreesStep;
import com.cloudera.oryx.rdf.computation.build.MergeNewOldStep;

/**
 * @author Sean Owen
 */
public final class RDFDistributedGenerationRunner extends DistributedGenerationRunner {

    private static final Logger log = LoggerFactory.getLogger(RDFDistributedGenerationRunner.class);

    @Override
    protected List<DependsOn<Class<? extends JobStep>>> getPreDependencies() {
        return Collections.emptyList();
    }

    @Override
    protected List<DependsOn<Class<? extends JobStep>>> getIterationDependencies() {
        return Collections.singletonList(
                DependsOn.<Class<? extends JobStep>>nextAfterFirst(BuildTreesStep.class, MergeNewOldStep.class));
    }

    @Override
    protected List<DependsOn<Class<? extends JobStep>>> getPostDependencies() {
        return Collections.emptyList();
    }

    @Override
    protected JobStepConfig buildConfig(int iteration) {
        return new RDFJobStepConfig(getInstanceDir(), getGenerationID(), getLastGenerationID(), iteration);
    }

    @Override
    protected void doPost() throws IOException {

        String instanceGenerationPrefix = Namespaces.getInstanceGenerationPrefix(getInstanceDir(),
                getGenerationID());
        String outputPathKey = instanceGenerationPrefix + "trees/";
        Store store = Store.get();
        PMML joinedForest = null;

        // TODO This is still loading all trees into memory, which can be quite large.
        // To do better we would have to manage XML output more directly.

        Map<String, Mean> columnNameToMeanImportance = Maps.newHashMap();

        for (String treePrefix : store.list(outputPathKey, true)) {
            log.info("Reading trees from file {}", treePrefix);
            for (String treePMMLAsLine : new FileLineIterable(store.readFrom(treePrefix))) {
                PMML treePMML;
                try {
                    treePMML = IOUtil.unmarshal(new InputSource(new StringReader(treePMMLAsLine)));
                } catch (SAXException e) {
                    throw new IOException(e);
                } catch (JAXBException e) {
                    throw new IOException(e);
                }

                if (joinedForest == null) {
                    joinedForest = treePMML;
                    updateMeanImportances(columnNameToMeanImportance, treePMML.getModels().get(0));
                } else {
                    MiningModel existingModel = (MiningModel) joinedForest.getModels().get(0);
                    MiningModel nextModel = (MiningModel) treePMML.getModels().get(0);
                    updateMeanImportances(columnNameToMeanImportance, nextModel);
                    existingModel.getSegmentation().getSegments().addAll(nextModel.getSegmentation().getSegments());
                }
            }
        }

        // Stitch together feature importances
        for (MiningField field : joinedForest.getModels().get(0).getMiningSchema().getMiningFields()) {
            String name = field.getName().getValue();
            Mean importance = columnNameToMeanImportance.get(name);
            if (importance == null) {
                field.setImportance(null);
            } else {
                field.setImportance(importance.getResult());
            }
        }

        log.info("Writing combined model file");
        File tempJoinedForestFile = File.createTempFile("model-", ".pmml.gz");
        tempJoinedForestFile.deleteOnExit();
        OutputStream out = IOUtils.buildGZIPOutputStream(new FileOutputStream(tempJoinedForestFile));
        try {
            IOUtil.marshal(joinedForest, out);
        } catch (JAXBException e) {
            throw new IOException(e);
        } finally {
            out.close();
        }

        log.info("Uploading combined model file");
        store.upload(instanceGenerationPrefix + "model.pmml.gz", tempJoinedForestFile, false);
        IOUtils.delete(tempJoinedForestFile);
    }

    private static void updateMeanImportances(Map<String, Mean> columnNameToMeanImportance, Model model) {
        for (MiningField field : model.getMiningSchema().getMiningFields()) {
            Double importance = field.getImportance();
            if (importance != null) {
                String fieldName = field.getName().getValue();
                Mean mean = columnNameToMeanImportance.get(fieldName);
                if (mean == null) {
                    mean = new Mean();
                    columnNameToMeanImportance.put(fieldName, mean);
                }
                mean.increment(importance);
            }
        }
    }

}