statechum.analysis.learning.experiments.PairSelection.WekaDataCollector.java Source code

Java tutorial

Introduction

Here is the source code for statechum.analysis.learning.experiments.PairSelection.WekaDataCollector.java

Source

/* Copyright (c) 2013 The University of Sheffield.
 * 
 * This file is part of StateChum
 * 
 * StateChum is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 * 
 * StateChum is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with StateChum.  If not, see <http://www.gnu.org/licenses/>.
 */
package statechum.analysis.learning.experiments.PairSelection;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

import statechum.Configuration.ScoreMode;
import statechum.DeterministicDirectedSparseGraph.CmpVertex;
import statechum.analysis.learning.PairScore;
import statechum.analysis.learning.StatePair;
import statechum.analysis.learning.experiments.PairSelection.PairQualityLearner.PairMeasurements;
import statechum.analysis.learning.rpnicore.LearnerGraph;
import weka.classifiers.Classifier;
import weka.core.Attribute;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;

public class WekaDataCollector {
    /** Classifier consulted by {@link #getPairQuality(int[])}. It is not assigned anywhere in this class,
     * so it has to be set by the code using this collector before pair quality is requested. */
    Classifier classifier;

    /** The maximal number of attributes to use as part of a conditional statement. Where 0, no conditionals are considered, for QSM (Score/Red/Blue) it has to be 2.
     */
    int maxLevel;

    /**
     * The length of an instance, taking {@link WekaDataCollector#maxLevel} into account.
     */
    int instanceLength;

    /** Attributes associated with each instance. These are not the same as attributes of comparators because comparators are only considering individual comparisons and here we are working with if-then chains. */
    Attribute[] attributesOfAnInstance;

    /** Returns the number of attributes in each instance, not counting the class attribute. */
    public int getInstanceLength() {
        return instanceLength;
    }

    /**
     * Begins construction of an instance of pair classifier. The class attribute gets two values,
     * {@code "true"} (index 0) followed by {@code "false"} (index 1).
     * {@link #initialise(String, int, List, int)} has to be called before the collector can be used.
     */
    public WekaDataCollector() {
        FastVector vecBool = new FastVector(2);
        vecBool.addElement(Boolean.TRUE.toString());
        vecBool.addElement(Boolean.FALSE.toString());
        classAttribute = new Attribute("class", vecBool);
    }

    /** {@code n} is the number of comparators (one per assessor). {@code L} is the length of a single section of an instance:
     * {@code n} comparator outcomes followed by {@code n} assessor rankings. */
    protected int n, L;

    /** Number of values for attributes to consider. */
    protected final static int V = 2;

    /**
     * Completes construction of an instance of pair classifier. Comparators contain attributes that are tied into the training set when it is constructed in this method.
     * All arguments are validated before any state is modified, so a rejected call leaves the collector uninitialised.
     * 
     * @param trainingSetName the name for a training set.
     * @param capacity the maximal number of elements in the training set
     * @param argAssessor a collection of assessors to use.
     * @param level the maximal number of attributes to use as part of a conditional statement.
     * @throws IllegalArgumentException if the collector was already initialised, there are too many assessors to be tracked in a bitmask,
     * too many levels for the number of assessors, or the resulting instance would be too long.
     */
    public void initialise(String trainingSetName, int capacity, List<PairRank> argAssessor, int level) {
        if (assessors != null)
            throw new IllegalArgumentException("WekaDataCollector should not be re-initialised");

        final int attributeCount = argAssessor.size();
        // Attributes used on a path through conditionals are recorded as bits of a long.
        if (attributeCount > Long.SIZE - 1)
            throw new IllegalArgumentException("attributes will not fit into long");
        if (V * attributeCount < level)
            throw new IllegalArgumentException("too many levels for the considered number of attributes");

        final long sectionWidth = attributeCount * 2L;// the length of the attributes accumulated for each pair.
        long instanceLen = sectionWidth, sectionLen = 1;
        for (int i = 0; i < level; ++i) {
            sectionLen *= V * (attributeCount - i);
            instanceLen += sectionWidth * sectionLen;
            // Both values are used as int array indices; the instance additionally needs one slot for the class attribute.
            if (sectionLen > Integer.MAX_VALUE || instanceLen >= Integer.MAX_VALUE)
                throw new IllegalArgumentException("too many attributes per instance");
        }

        assessors = argAssessor;
        measurementsForUnfilteredCollectionOfPairs.valueAverage = new double[attributeCount];
        measurementsForUnfilteredCollectionOfPairs.valueSD = new double[attributeCount];
        comparators = new ArrayList<PairComparator>(attributeCount);
        for (PairRank pr : assessors)
            comparators.add(new PairComparator(pr));

        maxLevel = level;
        n = attributeCount;// the number of indices we go through
        L = (int) sectionWidth;
        instanceLength = (int) instanceLen;

        FastVector attributes = new FastVector(instanceLength + 1);
        attributesOfAnInstance = new Attribute[instanceLength];
        fillInAttributeNames(attributesOfAnInstance, 0, 0, 1, 0, "", 0);
        for (int i = 0; i < instanceLength; ++i)
            attributes.addElement(attributesOfAnInstance[i]);
        attributes.addElement(classAttribute);
        trainingData = new Instances(trainingSetName, attributes, capacity);
        trainingData.setClass(classAttribute);
    }

    /**
     * Recursively assigns a named attribute to every position of an instance.
     * Each section holds {@link #L} attributes: the comparator ("REL") attributes followed by the assessor attributes.
     * For every comparator not yet used on the path to the current section, two child sections are created at the next level:
     * one for the condition {@code ==-1} (even offset) and one for {@code ==1} (odd offset). This layout has to agree with
     * {@link #fillInEntry}.
     *
     * @param whatToFillIn array to populate with attributes.
     * @param section_start index at which sections of the current level start.
     * @param idx_in_section index of the section to fill in, within the current level.
     * @param section_size the number of sections in the current level.
     * @param xyz bitmask of comparators already used in conditions on the path to this section.
     * @param pathToThisLevel textual form of the conditions leading to this section, used as a prefix of attribute names.
     * @param currentLevel depth of the current section, 0 for the top level.
     */
    protected void fillInAttributeNames(Attribute[] whatToFillIn, int section_start, int idx_in_section,
            int section_size, long xyz, String pathToThisLevel, int currentLevel) {
        final int sectionPlusOffset = section_start + L * idx_in_section;

        int i = 0;
        for (PairComparator cmp : comparators)
            whatToFillIn[sectionPlusOffset + i++] = cmp.getAttribute()
                    .copy(pathToThisLevel + cmp.getAttribute().name());
        for (PairRank pr : assessors)
            whatToFillIn[sectionPlusOffset + i++] = pr.getAttribute()
                    .copy(pathToThisLevel + pr.getAttribute().name());

        if (currentLevel < maxLevel) {
            final int rowNumber = V * (n - currentLevel);
            final int nextSectionStart = section_start + L * section_size;
            int z = 0;
            for (int attr = 0; attr < n; ++attr) {
                // Must be a long shift: an int shift produces a wrong mask for attr >= 31.
                final long positionalBit = 1L << attr;
                if ((xyz & positionalBit) == 0) {// this attribute was not already used on a path to the current section
                    final String attrName = comparators.get(attr).getAttribute().name();
                    fillInAttributeNames(whatToFillIn, nextSectionStart, idx_in_section * rowNumber + z + 0,
                            section_size * rowNumber, xyz | positionalBit,
                            pathToThisLevel + " if " + attrName + "==-1 then ",
                            currentLevel + 1);
                    fillInAttributeNames(whatToFillIn, nextSectionStart, idx_in_section * rowNumber + z + 1,
                            section_size * rowNumber, xyz | positionalBit,
                            pathToThisLevel + " if " + attrName + "==1 then ",
                            currentLevel + 1);
                    z += V;
                }
            }
        }
    }

    /**
     * Converts an outcome of a comparison or ranking (one of -2..2) into the index of the corresponding nominal value of an attribute.
     * Despite the name, the returned value is that index (as a double, the form used by Weka instances), not a string.
     *
     * @param assessmentResult value to convert.
     * @param attribute attribute whose nominal values are searched.
     * @return index of the value in {@code attribute}.
     * @throws IllegalArgumentException if the value is outside -2..2 or not defined for the attribute.
     */
    protected double convertAssessmentResultToString(int assessmentResult, Attribute attribute) {
        String value = null;
        switch (assessmentResult) {
        case 2:
            value = TWO;
            break;
        case 1:
            value = ONE;
            break;
        case 0:
            value = ZERO;
            break;
        case -1:
            value = MINUSONE;
            break;
        case -2:
            value = MINUSTWO;
            break;

        default:
            throw new IllegalArgumentException(
                    "invalid comparison value " + assessmentResult + " for attribute " + attribute);

        }
        double outcome = attribute.indexOfValue(value);
        if (outcome < 0)
            throw new IllegalArgumentException("value " + value + " was not defined for attribute " + attribute);

        return outcome;
    }

    /**  Constructs a Weka {@link Instance} for a pair of interest.
     * 
     * @param comparisonResults metrics related to the considered pair; must have exactly {@link #getInstanceLength()} elements.
     * @param classification whether this is a correct pair
     * @return an instance of a test or a training sample, associated with {@link #trainingData}. 
     * @throws IllegalArgumentException if the length of the results is wrong or a result is not a valid value for its attribute.
     */
    Instance constructInstance(int[] comparisonResults, boolean classification) {
        if (comparisonResults.length != instanceLength)
            throw new IllegalArgumentException("results' length does not match the number of comparators");

        double[] instanceValues = new double[instanceLength + 1];
        for (int i = 0; i < instanceLength; ++i)
            instanceValues[i] = convertAssessmentResultToString(comparisonResults[i], attributesOfAnInstance[i]);

        instanceValues[instanceLength] = trainingData.classAttribute()
                .indexOfValue(Boolean.toString(classification));
        Instance outcome = new Instance(1, instanceValues);
        outcome.setDataset(trainingData);
        return outcome;
    }

    /** Given an outcome of a comparison of a pair to other pairs, attempts to estimate how significant the result is.
     * This is useful for prioritising states to be selected as red.
     * <p>
     * The returned value is the element at index 0 of the distribution computed by {@link #classifier}, which corresponds to the first
     * value of the class attribute, {@code "true"} (see the constructor). The class value of the instance passed to the classifier
     * is set to {@code false} only as a placeholder.
     *  
     * @param comparisonResults the outcome of {@link #fillInPairDetails(int[], PairScore, Collection)}.
     * @return a non-negative "quality" of a pair. 
     * @throws Exception if the classifier fails to compute the distribution.
     */
    double getPairQuality(int[] comparisonResults) throws Exception {
        return classifier.distributionForInstance(constructInstance(comparisonResults, false))[0];
    }

    /** Nominal attribute with values "true" and "false" denoting whether a pair is a correct merge. */
    final Attribute classAttribute;
    /** Dataset that instances are attached to and training instances are added to; created by {@link #initialise}. */
    public Instances trainingData;

    static final String MINUSTWO = "-2";
    static final String MINUSONE = "-1";
    static final String ZERO = "0";
    static final String ONE = "1";
    static final String TWO = "2";

    protected List<PairComparator> comparators;
    protected List<PairRank> assessors;

    /** Per-pair measurements and per-assessor average/standard deviation computed for a specific collection of pairs. */
    class MeasurementsForCollectionOfPairs {
        Map<StatePair, PairMeasurements> measurementsForComparators = new HashMap<StatePair, PairMeasurements>();
        double valueAverage[] = new double[0], valueSD[] = new double[0];
    }

    /** Maps the blue state of each pair to the value computed by {@code PairQualityLearner.computeTreeSize} for that state. */
    Map<CmpVertex, Integer> treeForComparators = new TreeMap<CmpVertex, Integer>();
    MeasurementsForCollectionOfPairs measurementsForUnfilteredCollectionOfPairs = new MeasurementsForCollectionOfPairs();

    LearnerGraph tentativeGraph = null;

    /** Records the tentative graph, computes tree sizes for the blue states of the supplied pairs and then computes
     * {@link #measurementsForUnfilteredCollectionOfPairs} for all of the supplied pairs.
     *
     * @param pairs pairs to build sets for
     * @param graph the tentative graph the pairs belong to.
     */
    void buildSetsForComparatorsThatDoNotDependOnFiltering(Collection<PairScore> pairs, LearnerGraph graph) {
        treeForComparators.clear();
        tentativeGraph = graph;

        for (PairScore pair : pairs) {
            if (!treeForComparators.containsKey(pair.getQ()))
                treeForComparators.put(pair.getQ(), PairQualityLearner.computeTreeSize(graph, pair.getQ()));
        }
        buildSetsForComparatorsDependingOnFiltering(measurementsForUnfilteredCollectionOfPairs, pairs);
    }

    /** Given a collection of pairs and a tentative automaton, constructs auxiliary structures used by comparators and stores it as an instance variable.
     * The graph used for construction is the one that was passed earlier to {@link WekaDataCollector#buildSetsForComparatorsThatDoNotDependOnFiltering(Collection, LearnerGraph)}.
     * Where the collection of pairs is empty, averages and standard deviations are left at zero.
     * @param pairs pairs to build sets for
     * @param measurements where to store the result of measurement.
     */
    void buildSetsForComparatorsDependingOnFiltering(MeasurementsForCollectionOfPairs measurements,
            Collection<PairScore> pairs) {
        measurements.measurementsForComparators.clear();
        if (measurements.valueAverage.length < n)
            measurements.valueAverage = new double[n];
        if (measurements.valueSD.length < n)
            measurements.valueSD = new double[n];

        Arrays.fill(measurements.valueAverage, 0);
        Arrays.fill(measurements.valueSD, 0);

        for (PairScore pair : pairs) {
            PairMeasurements m = new PairMeasurements();
            // Counts the pairs sharing the red state with this one; starting at -1 discounts the pair itself.
            m.nrOfAlternatives = -1;
            for (PairScore p : pairs) {
                if (p.getR() == pair.getR())
                    ++m.nrOfAlternatives;
            }

            Collection<CmpVertex> adjacentOutgoingBlue = tentativeGraph.transitionMatrix.get(pair.getQ()).values(),
                    adjacentOutgoingRed = tentativeGraph.transitionMatrix.get(pair.getR()).values();
            m.adjacent = adjacentOutgoingBlue.contains(pair.getR()) || adjacentOutgoingRed.contains(pair.getQ());

            // The compatibility score is computed in COMPATIBILITY mode; the original mode is restored even if computation fails.
            ScoreMode origScore = tentativeGraph.config.getLearnerScoreMode();
            tentativeGraph.config.setLearnerScoreMode(ScoreMode.COMPATIBILITY);
            try {
                m.compatibilityScore = tentativeGraph.pairscores.computePairCompatibilityScore(pair);
            } finally {
                tentativeGraph.config.setLearnerScoreMode(origScore);
            }

            measurements.measurementsForComparators.put(pair, m);
        }

        if (assessors != null && !pairs.isEmpty()) {
            final int pairCount = pairs.size();
            for (PairScore pair : pairs)
                for (int i = 0; i < assessors.size(); ++i) {
                    // accumulated as double so that squaring large values cannot overflow a long
                    final double value = assessors.get(i).getValue(pair);
                    measurements.valueAverage[i] += value;
                    measurements.valueSD[i] += value * value;
                }

            for (int i = 0; i < assessors.size(); ++i) {
                measurements.valueAverage[i] /= pairCount;
                // Rounding can make E[x^2]-E[x]^2 marginally negative; clamp to avoid a NaN standard deviation.
                final double variance = measurements.valueSD[i] / pairCount
                        - measurements.valueAverage[i] * measurements.valueAverage[i];
                measurements.valueSD[i] = Math.sqrt(Math.max(0, variance));
            }
        }
    }

    /** Used to denote a value corresponding to an "inconclusive" verdict where a comparator returns values of greater for some points and less for others. */
    public static final int comparison_inconclusive = -10;

    /** Compares a pair with each of the supplied pairs using a single comparator.
     *
     * @param cmp comparator to use.
     * @param pair pair of interest.
     * @param others pairs to compare with (may include {@code pair} itself).
     * @return 1 if {@code pair} is never below and at least once above others, -1 if it is never above and at least once below,
     * 0 if it is equal to all of them and {@link #comparison_inconclusive} if it is above some and below others.
     */
    int comparePairWithOthers(PairComparator cmp, PairScore pair, Collection<PairScore> others) {
        int comparisonResult = 0;
        for (PairScore w : others) {// it does not matter if w==pair, the comparison result will be zero so it will not affect anything
            int newValue = cmp.compare(pair, w);
            assert newValue != comparison_inconclusive;
            // comparisonResults[i] can be 1,0,-1, same for newValue
            if (newValue > 0) {
                if (comparisonResult < 0) {
                    comparisonResult = comparison_inconclusive;
                    break;
                }
                comparisonResult = newValue;
            } else if (newValue < 0) {
                if (comparisonResult > 0) {
                    comparisonResult = comparison_inconclusive;
                    break;
                }
                comparisonResult = newValue;
            }
        }
        return comparisonResult;
    }

    /** Given a pair and a collection of possible pairs to merge, compares the specified pairs to others to determine its attributes that may make it more likely to be a valid merge.
     * Where the returned value is +1 or -1 in a specific cell, this means that the pair of interest is not dominated in the specific component by all other pairs.
     * The outcome of 1 means that it is equal to some other pairs and above others but never below.
     * In a similar way, -1 means that it does not dominate any other pairs.
     * Inconclusive outcomes are recorded as 0.
     * 
     * @param pair pair to consider
     * @param others other pairs (possibly, both valid and invalid mergers).
     * @param whatToFillIn array to populate with results
     * @param offset the starting position to fill in.
     */
    void comparePairWithOthers(PairScore pair, Collection<PairScore> others, int[] whatToFillIn, int offset) {
        assert !comparators.isEmpty();

        int position = offset;
        for (PairComparator cmp : comparators) {
            final int result = comparePairWithOthers(cmp, pair, others);
            whatToFillIn[position++] = result == comparison_inconclusive ? 0 : result;
        }
    }

    /** Assesses a supplied pair based on the values.
     * 
     * @param pair pair to consider
     * @param measurements set of measurements to use for assessment
     * @param whatToFillIn array to populate with results
     * @param offset the starting position to fill in.
     */
    void assessPair(PairScore pair, MeasurementsForCollectionOfPairs measurements, int[] whatToFillIn, int offset) {
        assert !assessors.isEmpty();
        for (int i = 0; i < assessors.size(); ++i)
            whatToFillIn[i + offset] = assessors.get(i).getRanking(pair, measurements.valueAverage[i],
                    measurements.valueSD[i]);
    }

    /**
     * Recursively fills in comparison and ranking results for a pair of interest, using the layout of {@link #fillInAttributeNames}.
     * For every comparator not yet used on the path to the current section for which the pair of interest obtained a non-zero outcome,
     * the collection of pairs is narrowed to those with the same outcome on that comparator. Where more than one pair remains, the
     * corresponding section of the next level is filled in against the narrowed collection, with measurements recomputed for it.
     * Sections whose conditions do not hold are not written.
     *
     * @param whatToFillIn array to populate with results.
     * @param section_start index at which sections of the current level start.
     * @param idx_in_section index of the section to fill in, within the current level.
     * @param section_size the number of sections in the current level.
     * @param xyz bitmask of comparators already used in conditions on the path to this section.
     * @param pairOfInterest pair to compute results for; expected to be among {@code pairs}.
     * @param pairs pairs to compare with at this level.
     * @param measurements measurements for {@code pairs}.
     * @param currentLevel depth of the current section, 0 for the top level.
     */
    protected void fillInEntry(int[] whatToFillIn, int section_start, int idx_in_section, int section_size,
            long xyz, PairScore pairOfInterest, Collection<PairScore> pairs,
            MeasurementsForCollectionOfPairs measurements, int currentLevel) {
        final int sectionPlusOffset = section_start + L * idx_in_section;
        comparePairWithOthers(pairOfInterest, pairs, whatToFillIn, sectionPlusOffset);
        assessPair(pairOfInterest, measurements, whatToFillIn, sectionPlusOffset + n);
        if (currentLevel < maxLevel) {
            final int rowNumber = V * (n - currentLevel);
            final int nextSectionStart = section_start + L * section_size;
            int z = 0;
            for (int attr = 0; attr < n; ++attr) {
                // Must be a long shift: an int shift produces a wrong mask for attr >= 31.
                final long positionalBit = 1L << attr;
                if ((xyz & positionalBit) == 0) {// this attribute was not already used on a path to the current instance of fillInEntry
                    final int attributeREL = whatToFillIn[sectionPlusOffset + attr];
                    if (attributeREL != 0) {
                        assert attributeREL == 1 || attributeREL == -1;
                        Collection<PairScore> others = new ArrayList<PairScore>(pairs.size());
                        for (PairScore other : pairs) {
                            int comparisonOnAttribute_i = comparePairWithOthers(comparators.get(attr), other,
                                    pairs);
                            if (comparisonOnAttribute_i == attributeREL) // we only compare our vertex with those that are also distinguished by the specified attribute
                                others.add(other);
                        }
                        if (others.size() > 1) {
                            // Averages and standard deviations at the next level are those of the filtered collection.
                            MeasurementsForCollectionOfPairs measurementsForFilteredPairs = new MeasurementsForCollectionOfPairs();
                            buildSetsForComparatorsDependingOnFiltering(measurementsForFilteredPairs, others);
                            // the value of 2 below is a reflection that we only distinguish between two different relative values. If SD part were considered, there would be a lot more values.
                            fillInEntry(whatToFillIn, nextSectionStart,
                                    idx_in_section * rowNumber + z + (attributeREL > 0 ? 1 : 0),
                                    section_size * rowNumber, xyz | positionalBit, pairOfInterest, others,
                                    measurementsForFilteredPairs, currentLevel + 1);
                        }
                    }
                    z += V;
                }
            }
        }
    }

    /** Fills in the array with comparison results. For correct operation, the supplied pair of interest has to be included in the collection of pairs.
     * The first {@link #getInstanceLength()} elements are cleared first, so sections for conditions that do not hold are always zero
     * even if the array is reused.
     *
     * @throws IllegalArgumentException if the array is shorter than {@link #getInstanceLength()}.
     */
    public void fillInPairDetails(int[] whatToFillIn, PairScore pairOfInterest, Collection<PairScore> pairs) {
        if (whatToFillIn.length < getInstanceLength())
            throw new IllegalArgumentException("array is too short");
        Arrays.fill(whatToFillIn, 0, getInstanceLength(), 0);
        fillInEntry(whatToFillIn, 0, 0, 1, 0, pairOfInterest, pairs, measurementsForUnfilteredCollectionOfPairs, 0);
    }

    /** Given a collection of pairs from a tentative graph, this method generates Weka data instances and adds them to the Weka dataset.
     * Pairs where either state is not an accept state are ignored. Each remaining pair is compared with all other remaining pairs,
     * and is labelled as correct if {@code PairQualityLearner.SplitSetOfPairsIntoRightAndWrong} places it among the correct pairs.
     * 
     * @param pairs pairs to add
     * @param currentGraph the current graph
     * @param correctGraph the graph we are trying to learn by merging states in currentGraph.
     */
    public void updateDatasetWithPairs(Collection<PairScore> pairs, LearnerGraph currentGraph,
            LearnerGraph correctGraph) {
        // NOTE(review): averages/SDs are computed over all pairs, including those filtered out below - confirm this is intended.
        buildSetsForComparatorsThatDoNotDependOnFiltering(pairs, currentGraph);

        List<PairScore> correctPairs = new LinkedList<PairScore>(), wrongPairs = new LinkedList<PairScore>();
        List<PairScore> pairsToConsider = new LinkedList<PairScore>();
        for (PairScore p : pairs)
            if (p.getQ().isAccept() && p.getR().isAccept())
                pairsToConsider.add(p);// only consider non-negatives
        PairQualityLearner.SplitSetOfPairsIntoRightAndWrong(currentGraph, correctGraph, pairsToConsider,
                correctPairs, wrongPairs);

        for (PairScore p : pairsToConsider) {
            int[] comparisonResults = new int[instanceLength];
            fillInPairDetails(comparisonResults, p, pairsToConsider);// only compare with other non-negatives
            boolean correctPair = correctPairs.contains(p);
            trainingData.add(constructInstance(comparisonResults, correctPair));
        }
    }

    /**
     * Provides helper methods in order to train a classifier to recognise good/bad pairs
     * <hr/>
     * It is a nested class to permit access to instance variables. This seems natural because elements of this class need access to data obtained from the transition matrix. 
     *
     */
    public abstract class PairRankingSupport {
        /** Weka attribute associated with this comparator. */
        final Attribute att;

        public Attribute getAttribute() {
            return att;
        }

        /** Creates a nominal Weka attribute with the supplied name and values. */
        protected PairRankingSupport(String name, String[] range) {
            FastVector vecA = new FastVector(range.length);
            for (String v : range)
                vecA.addElement(v);
            att = new Attribute(name, vecA);
        }

        @Override
        public String toString() {
            return att.name();
        }

        /** Returns measurements for the supplied pair, taken from the unfiltered collection of pairs (null if the pair was not measured). */
        public PairMeasurements measurementsForCurrentStack(PairScore p) {
            return measurementsForUnfilteredCollectionOfPairs.measurementsForComparators.get(p);
        }

        /** Returns the tree size recorded for the supplied state. It has to be the blue state of one of the pairs passed to
         * {@link WekaDataCollector#buildSetsForComparatorsThatDoNotDependOnFiltering}; otherwise unboxing a missing value throws {@link NullPointerException}. */
        public int treeRootedAt(CmpVertex p) {
            return treeForComparators.get(p);
        }

        public LearnerGraph tentativeGraph() {
            return tentativeGraph;
        }
    }

    /** Used to compute values permitting one to train a classifier to recognise good/bad pairs. 
     * 
     */
    public class PairComparator extends PairRankingSupport implements Comparator<PairScore> {
        protected final PairRank assessor;

        protected PairComparator(PairRank argAssessor) {
            super("REL " + argAssessor.getAttribute().name(), new String[] { ZERO, ONE, MINUSONE });
            assessor = argAssessor;
        }

        /** Returns -1, 0 or 1 depending on how the values of the assessor for the two pairs compare.
         * Values are compared directly rather than subtracted, so that extreme values cannot overflow. */
        @Override
        public int compare(PairScore o1, PairScore o2) {
            final long first = assessor.getValue(o1), second = assessor.getValue(o2);
            return first < second ? -1 : (first > second ? 1 : 0);
        }
    }

    /** {@link PairComparator} permits one to compare pairs with each other. This one aims to give a rank to each pair in a collection of pairs, by either retrieving specific attributes or 
     * doing the average/standard deviation thresholding.
     * @author kirill
     *
     */
    public abstract class PairRank extends PairRankingSupport {
        protected PairRank(String name) {
            super(name, new String[] { ZERO, ONE, MINUSONE, TWO, MINUSTWO });
        }

        /** Ranks a pair relative to the average and standard deviation of a collection of pairs.
         * For absolute rankers, the value of the pair is returned as-is (cast to int). Otherwise the result is
         * 2 / -2 if the value is more than two standard deviations above / below the average,
         * 1 / -1 if it is more than one standard deviation above / below it, and 0 otherwise.
         * 
         * @param pair pair to rank.
         * @param average average value across the collection of pairs.
         * @param sd standard deviation of values across the collection of pairs.
         * @return the ranking of the pair.
         */
        public int getRanking(PairScore pair, double average, double sd) {
            long value = getValue(pair);
            if (isAbsolute())
                return (int) value;

            if (value > average + sd) {
                if (value > average + sd + sd)
                    return 2;
                return 1;
            }
            if (value < average - sd) {
                if (value < average - sd - sd)
                    return -2;
                return -1;
            }
            return 0;
        }

        /** Returns true if {@link PairRank#getRanking} should not use average/standard deviation in order to normalise results across different sets of pairs. This is important where we aim to distinguish between zero/above-zero scores. */
        abstract public boolean isAbsolute();

        /** Obtains a value from a supplied pair that can be used in order to calculate the ranking. */
        abstract public long getValue(PairScore pair);
    }
}