eu.amidst.core.utils.BayesianNetworkSampler.java Source code

Java tutorial

Introduction

Here is the source code for eu.amidst.core.utils.BayesianNetworkSampler.java

Source

/*
 *
 *
 *    Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements.
 *    See the NOTICE file distributed with this work for additional information regarding copyright ownership.
 *    The ASF licenses this file to You under the Apache License, Version 2.0 (the "License"); you may not use
 *    this file except in compliance with the License.  You may obtain a copy of the License at
 *
 *            http://www.apache.org/licenses/LICENSE-2.0
 *
 *    Unless required by applicable law or agreed to in writing, software distributed under the License is
 *    distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *    See the License for the specific language governing permissions and limitations under the License.
 *
 *
 */

package eu.amidst.core.utils;

import com.google.common.base.Stopwatch;
import eu.amidst.core.datastream.Attribute;
import eu.amidst.core.datastream.Attributes;
import eu.amidst.core.datastream.DataInstance;
import eu.amidst.core.datastream.DataStream;
import eu.amidst.core.io.BayesianNetworkLoader;
import eu.amidst.core.io.DataStreamWriter;
import eu.amidst.core.models.BayesianNetwork;
import eu.amidst.core.variables.Assignment;
import eu.amidst.core.variables.HashMapAssignment;
import eu.amidst.core.variables.Variable;

import java.io.Serializable;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

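/**
 * Samples data from a {@link BayesianNetwork}.
 * Each sample is drawn by visiting the variables in a topological order of the network's DAG and
 * sampling every variable from its conditional distribution given the values already assigned.
 * Individual variables can additionally be marked as hidden (always set to the missing value) or
 * be set to the missing value with a given probability.
 * This class implements the interface {@link eu.amidst.core.utils.AmidstOptionsHandler}.
 */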
public class BayesianNetworkSampler implements AmidstOptionsHandler, Serializable {

    /** Represents the serial version ID for serializing the object. */
    private static final long serialVersionUID = 4107783324901370839L;

    /** Represents the {@link BayesianNetwork} object from which the data will be sampled. */
    private BayesianNetwork network;

    /** Represents the list of variables given in a causal order. */
    private List<Variable> causalOrder;

    /** Represents the initial seed for random sampling. */
    private int seed = 0;

    /** Represents the {@link java.util.Random} object used when deciding whether a noisy variable is set to missing. */
    private Random random = new Random(seed);

    /**
     * Represents a {@code Map} whose keys are the hidden variables.
     * Only the key set is read by this class; the stored value is always {@code true}.
     */
    private final Map<Variable, Boolean> hiddenVars = new HashMap<>();

    /** Represents a {@code Map} from each noisy variable to the probability of setting its value to missing. */
    private final Map<Variable, Double> marNoise = new HashMap<>();

    /**
     * Builds a sampler for the given {@link BayesianNetwork}.
     * The topological order of the network's DAG is computed once here and reused for every sample.
     * @param network1 the {@link BayesianNetwork} object to sample from.
     */
    public BayesianNetworkSampler(BayesianNetwork network1) {
        this.network = network1;
        this.causalOrder = Utils.getTopologicalOrder(network1.getDAG());
    }

    /**
     * Produces a lazily evaluated {@code Stream} with {@code nSamples} sampled {@link Assignment}s.
     * A {@code LocalRandomGenerator} is created from the current seed on every call; each element is
     * drawn with the {@link Random} it hands out and is then passed through {@link #filter(Assignment)}.
     * @param nSamples the number of samples to generate.
     * @return a {@code Stream} of randomly sampled {@link Assignment}s.
     */
    private Stream<Assignment> getSampleStream(int nSamples) {
        final LocalRandomGenerator generator = new LocalRandomGenerator(seed);
        Stream<Assignment> drawn = IntStream.range(0, nSamples)
                .mapToObj(index -> sample(network, causalOrder, generator.current()));
        return drawn.map(candidate -> filter(candidate));
    }

    /**
     * Marks a {@link Variable} as hidden, so that its value is replaced by the missing value in every sample.
     * @param var the {@link Variable} object to hide.
     */
    public void setHiddenVar(Variable var) {
        hiddenVars.put(var, Boolean.TRUE);
    }

    /**
     * Marks a {@link Variable} as noisy: in each sample its value is replaced by the missing value
     * whenever a uniform random draw falls below {@code noiseProb}.
     * @param var the {@link Variable} object to mark.
     * @param noiseProb a double that represents the noise probability.
     */
    public void setMARVar(Variable var, double noiseProb) {
        marNoise.put(var, Double.valueOf(noiseProb));
    }

    /**
     * Applies the hidden/noisy configuration to a sampled {@link Assignment}.
     * Every hidden variable is set to {@code Utils.missingValue()}; every noisy variable is set to
     * that same value when a fresh {@code random.nextDouble()} draw is below its noise probability.
     * The assignment is modified in place.
     * @param assignment the {@link Assignment} object to filter.
     * @return the same {@link Assignment} instance, after filtering.
     */
    private Assignment filter(Assignment assignment) {
        for (Variable hidden : hiddenVars.keySet()) {
            assignment.setValue(hidden, Utils.missingValue());
        }

        for (Map.Entry<Variable, Double> noisy : marNoise.entrySet()) {
            // One random draw per noisy variable, in the map's iteration order.
            boolean drop = random.nextDouble() < noisy.getValue();
            if (drop) {
                assignment.setValue(noisy.getKey(), Utils.missingValue());
            }
        }

        return assignment;
    }

    /**
     * Draws {@code nSamples} samples and collects them into a {@code List}.
     * @param nSamples an {@code int} that represents the number of samples.
     * @return a {@code List} of randomly sampled {@link Assignment}s.
     */
    private List<Assignment> getSampleList(int nSamples) {
        Stream<Assignment> samples = getSampleStream(nSamples);
        return samples.collect(Collectors.toList());
    }

    /**
     * Returns an {@code Iterable} over randomly sampled {@link Assignment}s.
     * Each call to {@code iterator()} on the result requests a new sample stream of size {@code nSamples}.
     * @param nSamples an {@code int} that represents the number of samples.
     * @return an {@code Iterable} over randomly sampled {@link Assignment}s.
     */
    private Iterable<Assignment> getSampleIterator(int nSamples) {
        return () -> getSampleStream(nSamples).iterator();
    }

    /**
     * Stores the seed and replaces the internal {@link Random} with a new one created from it.
     * @param seed an {@code int} that represents the seed value.
     */
    public void setSeed(int seed) {
        this.random = new Random(seed);
        this.seed = seed;
    }

    /**
     * Builds a {@link DataStream} view over samples drawn from this BayesianNetworkSampler.
     * The internal {@link Random} is re-created from the current seed when this method is called.
     * No sample is generated here: each call to {@code stream()} on the returned object requests a
     * sample stream of {@code nSamples} assignments and wraps every assignment as a {@link DataInstance}.
     * The returned object reports itself as not restartable, and its {@code close()} and
     * {@code restart()} methods do nothing.
     * @param nSamples an {@code int} that represents the number of samples in the data stream.
     * @return a {@link DataStream} of {@link DataInstance}s.
     */
    public DataStream<DataInstance> sampleToDataStream(int nSamples) {
        class TemporalDataStream implements DataStream<DataInstance>, Serializable {
            /** Represents the serial version ID for serializing the object. */
            private static final long serialVersionUID = -3436599636425587512L;

            /** One attribute per network variable, built in the constructor. */
            Attributes atts;
            /** The sampler that produces the underlying assignments. */
            BayesianNetworkSampler sampler;
            /** Number of samples produced by each call to {@code stream()}. */
            int nSamples;

            TemporalDataStream(BayesianNetworkSampler owner, int size) {
                this.sampler = owner;
                this.nSamples = size;
                // Each variable becomes an attribute carrying the variable's ID, name and state space type.
                List<Attribute> perVariable = owner.network.getVariables().getListOfVariables().stream()
                        .map(v -> new Attribute(v.getVarID(), v.getName(), v.getStateSpaceType()))
                        .collect(Collectors.toList());
                this.atts = new Attributes(perVariable);
            }

            @Override
            public Attributes getAttributes() {
                return this.atts;
            }

            @Override
            public Stream<DataInstance> stream() {
                class TemporalDataInstance implements DataInstance, Serializable {

                    /** Represents the serial version ID for serializing the object. */
                    private static final long serialVersionUID = -3436599636425587512L;

                    /** The sampled assignment backing this instance. */
                    Assignment assignment;
                    /** The attributes of the enclosing data stream. */
                    Attributes attributes;
                    /** The network variables, looked up by attribute index. */
                    List<Variable> variables;

                    TemporalDataInstance(Assignment sampled, Attributes streamAtts) {
                        this.assignment = sampled;
                        this.attributes = streamAtts;
                        this.variables = sampler.network.getVariables().getListOfVariables();
                    }

                    @Override
                    public double getValue(Variable var) {
                        return assignment.getValue(var);
                    }

                    @Override
                    public void setValue(Variable var, double value) {
                        assignment.setValue(var, value);
                    }

                    @Override
                    public Attributes getAttributes() {
                        return this.attributes;
                    }

                    @Override
                    public Set<Variable> getVariables() {
                        return this.assignment.getVariables();
                    }

                    @Override
                    public double getValue(Attribute att) {
                        Variable target = variables.get(att.getIndex());
                        return assignment.getValue(target);
                    }

                    @Override
                    public void setValue(Attribute att, double value) {
                        // Writes to special attributes are silently ignored.
                        if (att.isSpecialAttribute()) {
                            return;
                        }
                        Variable target = variables.get(att.getIndex());
                        assignment.setValue(target, value);
                    }

                    @Override
                    public double[] toArray() {
                        final int count = attributes.getNumberOfAttributes();
                        final double[] result = new double[count];
                        for (int i = 0; i < count; i++) {
                            Attribute current = attributes.getFullListOfAttributes().get(i);
                            result[i] = getValue(current);
                        }
                        return result;
                    }

                    @Override
                    public String toString() {
                        return outputString();
                    }
                }

                Stream<Assignment> samples = sampler.getSampleStream(nSamples);
                return samples.map(a -> new TemporalDataInstance(a, this.atts));
            }

            @Override
            public void close() {
                // Nothing to release.
            }

            @Override
            public boolean isRestartable() {
                return false;
            }

            @Override
            public void restart() {
                // Intentionally empty: this stream is reported as not restartable.
            }
        }

        // Reset the generator used by the missing-value filter to the current seed.
        this.random = new Random(this.seed);

        return new TemporalDataStream(this, nSamples);
    }

    /**
     * Draws one {@link Assignment} from a {@link BayesianNetwork}.
     * Variables are visited in the given order; each one is sampled from the univariate distribution
     * obtained from its conditional distribution and the values assigned so far, and the sampled
     * value is stored before moving on to the next variable.
     * @param network a {@link BayesianNetwork} object.
     * @param causalOrder a list of variables given in a causal order.
     * @param random a {@link Random} object.
     * @return the sampled {@link Assignment}.
     */
    private static Assignment sample(BayesianNetwork network, List<Variable> causalOrder, Random random) {
        final HashMapAssignment drawn = new HashMapAssignment(network.getNumberOfVars());
        causalOrder.forEach(variable -> drawn.setValue(variable,
                network.getConditionalDistribution(variable).getUnivariateDistribution(drawn).sample(random)));
        return drawn;
    }

    /**
     * {@inheritDoc}
     * The returned text is the class name ID followed by the description of the {@code -seed} option.
     */
    @Override
    public String listOptions() {
        final String header = classNameID() + ",\\";
        final String seedOption = "-seed, 0, seed for random number generator\\";
        return header + seedOption;
    }

    /**
     * {@inheritDoc}
     * Concatenates this sampler's own options and the network's recursive option listing,
     * separated by a newline.
     */
    @Override
    public String listOptionsRecursively() {
        final String own = listOptions();
        final String nested = network.listOptionsRecursively();
        return own + "\n" + nested;
    }

    /**
     * {@inheritDoc}
     * Reads the {@code -seed} option and applies it through {@link #setSeed(int)}, so that both the
     * stored seed and the {@link Random} used by the missing-value filter reflect the loaded value.
     */
    @Override
    public void loadOptions() {
        // Assigning only the seed field would leave the Random seeded with the previous value.
        setSeed(getIntOption("-seed"));
    }

    /**
     * {@inheritDoc}
     * @return the fixed identifier {@code "BayesianNetworkSampler"}.
     */
    @Override
    public String classNameID() {
        final String id = "BayesianNetworkSampler";
        return id;
    }

    /**
     * Small demo: loads a network from file, writes 100 samples to an ARFF file, prints the elapsed
     * time, and then prints a few more samples obtained through the iterator, list and stream helpers.
     * @param args command-line arguments (not used).
     * @throws Exception if loading the network or writing the data fails.
     */
    public static void main(String[] args) throws Exception {

        Stopwatch timer = Stopwatch.createStarted();

        BayesianNetwork asia = BayesianNetworkLoader.loadFromFile("./networks/dataWeka/asia.bn");

        BayesianNetworkSampler bnSampler = new BayesianNetworkSampler(asia);
        bnSampler.setSeed(0);

        DataStream<DataInstance> sampledData = bnSampler.sampleToDataStream(100);
        DataStreamWriter.writeDataToFile(sampledData, "./datasets/simulated/asisa-samples.arff");

        System.out.println(timer.stop());

        // Same samples requested three ways: as an Iterable, as a List and as a Stream.
        bnSampler.getSampleIterator(10).forEach(a -> System.out.println(a.outputString()));
        System.out.println();

        bnSampler.getSampleList(2).forEach(a -> System.out.println(a.outputString()));
        System.out.println();

        bnSampler.getSampleStream(2)
                .map(a -> a.outputString())
                .forEach(line -> System.out.println(line));
    }

}