eu.amidst.huginlink.learning.ParallelTAN.java Source code

Introduction

Here is the source code for eu.amidst.huginlink.learning.ParallelTAN.java.

Source

/*
 *
 *
 *    Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements.
 *    See the NOTICE file distributed with this work for additional information regarding copyright ownership.
 *    The ASF licenses this file to You under the Apache License, Version 2.0 (the "License"); you may not use
 *    this file except in compliance with the License.  You may obtain a copy of the License at
 *
 *            http://www.apache.org/licenses/LICENSE-2.0
 *
 *    Unless required by applicable law or agreed to in writing, software distributed under the License is
 *    distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *    See the License for the specific language governing permissions and limitations under the License.
 *
 *
 */

package eu.amidst.huginlink.learning;

import COM.hugin.HAPI.*;
import com.google.common.base.Stopwatch;
import eu.amidst.core.datastream.DataInstance;
import eu.amidst.core.datastream.DataOnMemory;
import eu.amidst.core.datastream.DataStream;
import eu.amidst.core.distribution.Multinomial;
import eu.amidst.core.learning.parametric.ParallelMLMissingData;
import eu.amidst.core.learning.parametric.ParallelMaximumLikelihood;
import eu.amidst.core.models.BayesianNetwork;
import eu.amidst.core.models.DAG;
import eu.amidst.core.utils.*;
import eu.amidst.core.variables.Variable;
import eu.amidst.core.variables.Variables;
import eu.amidst.huginlink.converters.BNConverterToAMIDST;
import eu.amidst.huginlink.converters.BNConverterToHugin;
import eu.amidst.huginlink.inference.HuginInference;

import java.io.IOException;

/**
 * This class links to the <a href="https://www.hugin.com">Hugin</a> functionality for learning a TAN model in parallel.
 * Note that Hugin can only learn the TAN model from a data set that is completely loaded into RAM. When the data set
 * does not fit into memory, AMIDST proceeds as follows: the structure is learned from a smaller data set produced by
 * <a href="https://en.wikipedia.org/wiki/Reservoir_sampling">reservoir sampling</a>, and the parameters of the TAN
 * model are then learned over the whole data set with AMIDST's parallel maximum likelihood learner ({@link ParallelMLMissingData}).
 *
 * <p>For further details on the implementation of the parallel TAN algorithm, see the following paper:</p>
 *
 * <i> Madsen, A.L. et al. A New Method for Vertical Parallelisation of TAN Learning Based on Balanced Incomplete
 * Block Designs. Probabilistic Graphical Models. Lecture Notes in Computer Science Volume 8754, 2014, pp 302-317. </i>
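 *
 * <p>A minimal usage sketch (the variable names below are hypothetical and should be replaced by attribute
 * names that actually exist in the data stream):</p>
 * <pre>{@code
 * DataStream<DataInstance> data = ...;        // any data stream with discrete attributes
 * ParallelTAN tan = new ParallelTAN();
 * tan.setParallelMode(true);
 * tan.setNumCores(4);
 * tan.setNumSamplesOnMemory(1000);            // size of the reservoir sample used for structure learning
 * tan.setNameRoot("Root");                    // hypothetical attribute acting as the tree root
 * tan.setNameTarget("ClassVar");              // hypothetical class variable
 * BayesianNetwork model = tan.learn(data);
 * }</pre>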
 */
public class ParallelTAN implements AmidstOptionsHandler {

    /** Represents the number of samples to be kept in memory (used for the structural learning step). */
    private int numSamplesOnMemory;

    /** Represents the number of cores to be exploited during the parallel learning. */
    private int numCores;

    /** Represents the batch size. */
    private int batchSize;

    /** Represents the name of the variable acting as a root of the tree in the TAN model. */
    String nameRoot;

    /** Represents the name of the class variable in the TAN model. */
    String nameTarget;

    /** Indicates if the learning is performed in parallel or not. */
    boolean parallelMode;

    /** Represents the learned Bayesian network. */
    BayesianNetwork learnedBN;

    /** Represents the inference engine used for predictions. */
    HuginInference inference = null;

    /** Represents the class variable of the model. */
    private Variable targetVar;

    /**
     * Class constructor.
     */
    public ParallelTAN() {
    }

    /**
     * Sets the running mode, i.e., either parallel or sequential.
     * @param parallelMode true if the running mode is parallel, false otherwise.
     */
    public void setParallelMode(boolean parallelMode) {
        this.parallelMode = parallelMode;
        if (parallelMode) {
            // Use all available cores when running in parallel mode.
            this.numCores = Runtime.getRuntime().availableProcessors();
        } else {
            this.numCores = 1;
        }
    }

    /**
     * Returns the batch size to be used during the learning process.
     * @return the batch size.
     */
    public int getBatchSize() {
        return batchSize;
    }

    /**
     * Sets the batch size.
     * @param batchSize the batch size.
     */
    public void setBatchSize(int batchSize) {
        this.batchSize = batchSize;
    }

    /**
     * Returns the number of samples kept in memory.
     * @return the number of samples kept in memory.
     */
    public int getNumSamplesOnMemory() {
        return numSamplesOnMemory;
    }

    /**
     * Sets the number of samples to be kept in memory.
     * @param numSamplesOnMemory_ the number of samples to be kept in memory.
     */
    public void setNumSamplesOnMemory(int numSamplesOnMemory_) {
        this.numSamplesOnMemory = numSamplesOnMemory_;
    }

    /**
     * Returns the number of cores to be used during the learning process.
     * @return the number of cores.
     */
    public int getNumCores() {
        return numCores;
    }

    /**
     * Sets the number of cores to be used during the learning process.
     * @param numCores_ the number of cores.
     */
    public void setNumCores(int numCores_) {
        if (numCores_ > Runtime.getRuntime().availableProcessors()) {
            this.numCores = Runtime.getRuntime().availableProcessors();
        } else {
            this.numCores = numCores_;
        }
    }

    /**
     * Sets the name of the variable acting as a root of the tree in the TAN model.
     * @param nameRoot the name of the root variable.
     */
    public void setNameRoot(String nameRoot) {
        this.nameRoot = nameRoot;
    }

    /**
     * Sets the name of the class variable in the TAN model.
     * @param nameTarget the name of the class variable.
     */
    public void setNameTarget(String nameTarget) {
        this.nameTarget = nameTarget;
    }

    /**
     * Returns the name of the class variable in the TAN model.
     * @return the name of the class variable.
     */
    public String getNameTarget() {
        return this.nameTarget;
    }

    /**
     * Predicts the class membership probabilities for a given instance.
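     *
     * <p>A minimal sketch (assuming {@code tan} is a ParallelTAN whose {@code learn} method has already been
     * called, and {@code testInstance} is a hypothetical instance whose class value is missing):</p>
     * <pre>{@code
     * double[] classProbs = tan.predict(testInstance);   // one probability per class label
     * }</pre>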
     * @param instance the data instance to be classified.
     * @return an array of doubles containing the estimated membership probabilities of the data instance for each class label.
     */
    public double[] predict(DataInstance instance) {
        if (learnedBN == null)
            throw new IllegalArgumentException("The model has not been learned");

        if (!Utils.isMissingValue(instance.getValue(targetVar)))
            System.out.println("Class Variable can not be set.");

        this.inference.setEvidence(instance);
        this.inference.runInference();
        Multinomial dist = this.inference.getPosterior(targetVar);

        return dist.getParameters();
    }

    /**
     * Learns a TAN structure from data using the Chow-Liu algorithm included in the Hugin API.
     * Parallel learning is performed only if the parallel mode was set to true.
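     *
     * <p>A minimal sketch (assuming {@code tan} is a configured ParallelTAN and {@code data} is a data stream whose
     * attributes include the configured root and target variables):</p>
     * <pre>{@code
     * DAG tanStructure = tan.learnDAG(data);   // structure only; parameters are not estimated here
     * }</pre>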
     * @param dataStream a stream of data instances to be processed during the TAN structural learning.
     * @return a <code>DAG</code> structure in AMIDST format.
     * @throws ExceptionHugin if a Hugin API error occurs.
     */
    public DAG learnDAG(DataStream dataStream) throws ExceptionHugin {
        Variables modelHeader = new Variables(dataStream.getAttributes());
        this.targetVar = modelHeader.getVariableByName(this.nameTarget);
        DAG dag = new DAG(modelHeader);
        BayesianNetwork bn = new BayesianNetwork(dag);

        Domain huginNetwork = null;

        try {
            huginNetwork = BNConverterToHugin.convertToHugin(bn);

        } catch (ExceptionHugin exceptionHugin) {
            System.out.println("ParallelTan LearnDAG Error 1");
            exceptionHugin.printStackTrace();
            throw new IllegalStateException("Hugin Exception: " + exceptionHugin.getMessage());
        }
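        // Hugin needs all the training cases in memory, so the stream is first down-sampled with
        // reservoir sampling; only this sample is used for the structural learning step.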
        DataOnMemory dataOnMemory = ReservoirSampling.samplingNumberOfSamples(this.numSamplesOnMemory, dataStream);

        // Set the number of cases
        int numCases = dataOnMemory.getNumberOfDataInstances();
        try {
            huginNetwork.setNumberOfCases(numCases);
            System.out.println("Learning DAG with " + this.numCores + " cores.");
            huginNetwork.setConcurrencyLevel(this.numCores);
        } catch (ExceptionHugin exceptionHugin) {
            System.out.println("ParallelTan LearnDAG Error 2");
            exceptionHugin.printStackTrace();
            throw new IllegalStateException("Hugin Exception: " + exceptionHugin.getMessage());
        }

        NodeList nodeList = huginNetwork.getNodes();

        // It is more efficient to iterate over the matrix of values this way: first over variables, then over cases.
        for (int i = 0; i < nodeList.size(); i++) {
            Variable var = bn.getDAG().getVariables().getVariableById(i);
            Node n = nodeList.get(i);

            try {
                if (n.getKind().compareTo(NetworkModel.H_KIND_DISCRETE) == 0) {
                    ((DiscreteChanceNode) n).getExperienceTable();
                    for (int j = 0; j < numCases; j++) {
                        double state = dataOnMemory.getDataInstance(j).getValue(var);
                        if (!Utils.isMissingValue(state)) {
                            ((DiscreteChanceNode) n).setCaseState(j, (int) state);
                        }
                    }
                } else {
                    ((ContinuousChanceNode) n).getExperienceTable();
                    for (int j = 0; j < numCases; j++) {
                        double value = dataOnMemory.getDataInstance(j).getValue(var);
                        if (!Utils.isMissingValue(value))
                            ((ContinuousChanceNode) n).setCaseValue(j, value);
                    }
                }
            } catch (ExceptionHugin exceptionHugin) {
                System.out.println("ParallelTan LearnDAG Error 3 with node " + Integer.toString(i));
                exceptionHugin.printStackTrace();
                throw new IllegalStateException("Hugin Exception: " + exceptionHugin.getMessage());
            }
            //            if(i==0 || i==6) {
            //                System.out.println("Node " + i);
            //                System.out.println(Arrays.toString(n.getTable().getData()));
            //
            //                System.out.println(Arrays.toString (((DiscreteChanceNode) n).getExperienceTable().getData()));
            //                System.out.println(((DiscreteChanceNode) n).getEnteredValue());
            //                System.out.println(n.getChildren().stream().count());
            //            }
        }

        //Structural learning
        DiscreteChanceNode root, target;
        try {
            root = (DiscreteChanceNode) huginNetwork.getNodeByName(nameRoot);
            target = (DiscreteChanceNode) huginNetwork.getNodeByName(nameTarget);

        } catch (ExceptionHugin exceptionHugin) {
            System.out.println("ParallelTan LearnDAG Error 4");
            exceptionHugin.printStackTrace();
            throw new IllegalStateException("Hugin Exception: " + exceptionHugin.getMessage());
        }

        try {
            //root.getHome().getNodes().stream().forEach(node -> System.out.println(node.getName())
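            // Hugin's Chow-Liu tree learning produces the TAN structure; nameRoot identifies the root of the
            // tree and nameTarget the class variable (see the field documentation above).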
            huginNetwork.learnChowLiuTree(root, target);
        } catch (ExceptionHugin exceptionHugin) {
            System.out.println("ParallelTan LearnDAG Error 5");
            exceptionHugin.printStackTrace();
            throw new IllegalStateException("Hugin Exception: " + exceptionHugin.getMessage());
        }

        DAG dagLearned;
        try {
            dagLearned = (BNConverterToAMIDST.convertToAmidst(huginNetwork)).getDAG();
            dagLearned.getVariables().setAttributes(dataStream.getAttributes());
        } catch (ExceptionHugin exceptionHugin) {
            System.out.println("ParallelTan LearnDAG Error 6");
            exceptionHugin.printStackTrace();
            throw new IllegalStateException("Hugin Exception: " + exceptionHugin.getMessage());
        }

        return dagLearned;

    }

    /**
     * Learns a TAN model from data: the structure is obtained via {@link #learnDAG(DataStream)} and the
     * parameters are estimated with {@link eu.amidst.core.learning.parametric.ParallelMLMissingData}
     * (with Laplace smoothing).
     * @param dataStream a stream of data instances for learning both the structure and the parameters.
     * @return a <code>BayesianNetwork</code> object in AMIDST format.
     */
    public BayesianNetwork learn(DataStream<DataInstance> dataStream) {
        try {
            ParallelMLMissingData parameterLearningAlgorithm = new ParallelMLMissingData();
            parameterLearningAlgorithm.setLaplace(true);
            parameterLearningAlgorithm.setParallelMode(this.parallelMode);
            parameterLearningAlgorithm.setDAG(this.learnDAG(dataStream));
            parameterLearningAlgorithm.setDataStream(dataStream);
            parameterLearningAlgorithm.initLearning();
            parameterLearningAlgorithm.runLearning();
            learnedBN = parameterLearningAlgorithm.getLearntBayesianNetwork();

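            // Keep a Hugin inference engine over the learned model so that predict() can later query
            // the posterior distribution of the class variable.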
            this.inference = new HuginInference();
            this.inference.setModel(this.learnedBN);
            return this.learnedBN;
        } catch (ExceptionHugin ex) {
            throw new IllegalStateException("Hugin Exception: " + ex.getMessage());
        }
    }

    /**
     * Learns a TAN model from data using a given batch size for the parallel parameter learning: the structure is
     * obtained via {@link #learnDAG(DataStream)} and the parameters are estimated with
     * {@link eu.amidst.core.learning.parametric.ParallelMLMissingData}.
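     *
     * <p>For example (hypothetical names, assuming the root and target variable names have already been set):</p>
     * <pre>{@code
     * BayesianNetwork model = tan.learn(data, 1000);   // parameter learning in batches of 1000 instances
     * }</pre>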
     * @param dataStream a stream of data instances for learning both the structure and the parameters.
     * @param batchSize the size of the batch (window) used by the parallel ML algorithm.
     * @return a <code>BayesianNetwork</code> object in AMIDST format.
     * @throws ExceptionHugin if a Hugin API error occurs.
     */
    public BayesianNetwork learn(DataStream<DataInstance> dataStream, int batchSize) throws ExceptionHugin {
        ParallelMLMissingData parameterLearningAlgorithm = new ParallelMLMissingData();
        parameterLearningAlgorithm.setWindowsSize(batchSize);
        parameterLearningAlgorithm.setParallelMode(this.parallelMode);
        parameterLearningAlgorithm.setDAG(this.learnDAG(dataStream));
        parameterLearningAlgorithm.setDataStream(dataStream);
        parameterLearningAlgorithm.initLearning();
        parameterLearningAlgorithm.runLearning();
        learnedBN = parameterLearningAlgorithm.getLearntBayesianNetwork();

        this.inference = new HuginInference();
        this.inference.setModel(this.learnedBN);
        return this.learnedBN;
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public String listOptions() {

        return this.classNameID() + ",\\" + "-numSamplesOnMemory, 1000, Number of samples on memory\\"
                + "-numCores," + Runtime.getRuntime().availableProcessors() + ", Number of cores\\"
                + "-windowsSize, 1000, Batch size\\" + "-nameRoot, root, Name of root variable.\\"
                + "-nameTarget, target, Name of target variable\\" + "-parallelMode, true, Run in parallel\\";
    }

    /**
     * {@inheritDoc}
     */
    public void loadOptions() {
        this.setNumSamplesOnMemory(this.getIntOption("-numSamplesOnMemory"));
        this.setNumCores(this.getIntOption("-numCores"));
        this.setBatchSize(this.getIntOption("-windowsSize"));
        this.setNameRoot(this.getOption("-nameRoot"));
        this.setNameTarget(this.getOption("-nameTarget"));
        this.setParallelMode(this.getBooleanOption("-parallelMode"));
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public String listOptionsRecursively() {
        return this.listOptions() + "\n" + BayesianNetwork.listOptionsRecursively() + "\n"
                + AmidstOptionsHandler.listOptionsRecursively(BayesianNetworkSampler.class);
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public String classNameID() {
        return "ParallelTAN";
    }

    public static void main(String[] args) throws IOException {

        OptionParser.setArgsOptions(ParallelTAN.class, args);

        BayesianNetworkGenerator.loadOptions();

        BayesianNetworkGenerator.setNumberOfGaussianVars(0);
        BayesianNetworkGenerator.setNumberOfMultinomialVars(2000, 10);
        BayesianNetworkGenerator.setSeed(0);

        BayesianNetwork bn = BayesianNetworkGenerator.generateNaiveBayes(2);

        int sampleSize = 5000;
        BayesianNetworkSampler sampler = new BayesianNetworkSampler(bn);
        sampler.loadOptions();

        DataStream<DataInstance> data = sampler.sampleToDataStream(sampleSize);

        for (int i = 1; i <= 4; i++) {
            int samplesOnMemory = 1000;
            int numCores = i;
            System.out.println("Learning TAN: " + samplesOnMemory + " samples on memory, " + numCores + " core/s ...");
            ParallelTAN tan = new ParallelTAN();
            tan.setOptions(args);
            //tan.loadOptionsFromFile("configurationFiles/conf.txt");
            tan.setNumCores(numCores);
            tan.setNumSamplesOnMemory(samplesOnMemory);
            tan.setNameRoot(bn.getVariables().getListOfVariables().get(0).getName());
            tan.setNameTarget(bn.getVariables().getListOfVariables().get(1).getName());
            Stopwatch watch = Stopwatch.createStarted();
            BayesianNetwork model = tan.learn(data);
            System.out.println(watch.stop());
        }
    }
}