Example usage for weka.core Instances randomize

Introduction

On this page you can find example usage of weka.core Instances.randomize.

Prototype

public void randomize(Random random) 

Document

Shuffles the instances in the set so that they are ordered randomly.
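
A minimal, self-contained sketch of the typical pattern (the file path, seed, and split ratio are illustrative assumptions, not taken from the example below): shuffle a dataset with a seeded java.util.Random so the order is reproducible, then split it into training and test sets.

import java.util.Random;

import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class RandomizeExample {
    public static void main(String[] args) throws Exception {
        // Load an ARFF file (placeholder path) and mark the last
        // attribute as the class.
        Instances data = DataSource.read("/path/to/data.arff");
        data.setClassIndex(data.numAttributes() - 1);

        // Shuffle in place; a fixed seed makes the order reproducible.
        data.randomize(new Random(42));

        // 70/30 train/test split over the shuffled order.
        int trainSize = (int) Math.round(data.numInstances() * 0.7);
        Instances train = new Instances(data, 0, trainSize);
        Instances test = new Instances(data, trainSize, data.numInstances() - trainSize);

        System.out.println("train: " + train.numInstances() + ", test: " + test.numInstances());
    }
}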

Usage

From source file: uzholdem.classifier.OnlineMultilayerPerceptron.java

License: Open Source License

import java.util.Random;

import weka.core.Instances;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.NominalToBinary;

public void trainModel(Instances aInstances, int numIterations) throws Exception {

    // setup m_instances
    if (this.m_instances == null) {

        this.m_instances = new Instances(aInstances, 0, aInstances.size());
    }

    if (m_useNomToBin) {
        if (this.m_nominalToBinaryFilter == null) {
            m_nominalToBinaryFilter = new NominalToBinary();
            try {
                m_nominalToBinaryFilter.setInputFormat(m_instances);
            } catch (Exception e) {
                // The filter could not be initialized for this data; abort training.
                e.printStackTrace();
                return;
            }
        }
        aInstances = Filter.useFilter(aInstances, m_nominalToBinaryFilter);
    }

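    // Train on a shuffled copy of the data. An unseeded Random produces a
    // different order on every run; pass a seed for reproducible results.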
    Instances epochInstances = new Instances(aInstances);
    epochInstances.randomize(new Random());

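    // Move the first 30% of the shuffled instances into a validation set.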
    int valSize = (int) (aInstances.size() * 0.3);
    Instances valSet = new Instances(aInstances, valSize);
    for (int i = 0; i < valSize; i++) {
        valSet.add(epochInstances.instance(0));
        epochInstances.delete(0);
    }

    m_instances = epochInstances;
    double right = 0;
    double driftOff = 0;
    double lastRight = Double.POSITIVE_INFINITY;
    double bestError = Double.POSITIVE_INFINITY;
    double tempRate;
    double totalWeight = 0;
    double totalValWeight = 0;
    double origRate = m_learningRate; // only used when the network is reset

    int numInVal = valSet.numInstances();

    for (int noa = numInVal; noa < m_instances.numInstances(); noa++) {
        if (!m_instances.instance(noa).classIsMissing()) {
            totalWeight += m_instances.instance(noa).weight();
        }
    }
    if (m_valSize != 0) {
        for (int noa = 0; noa < valSet.numInstances(); noa++) {
            if (!valSet.instance(noa).classIsMissing()) {
                totalValWeight += valSet.instance(noa).weight();
            }
        }
    }
    m_stopped = false;

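    // NOTE: this loop runs a fixed 50 epochs; the numIterations parameter
    // is not used here.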
    for (int noa = 1; noa < 50 + 1; noa++) {
        right = 0;
        for (int nob = numInVal; nob < m_instances.numInstances(); nob++) {
            m_currentInstance = m_instances.instance(nob);

            if (!m_currentInstance.classIsMissing()) {

                //this is where the network updating (and training) occurs, for the
                //training set
                resetNetwork();
                calculateOutputs();
                tempRate = m_learningRate * m_currentInstance.weight();
                if (m_decay) {
                    tempRate /= noa;
                }

                right += (calculateErrors() / m_instances.numClasses()) * m_currentInstance.weight();
                updateNetworkWeights(tempRate, m_momentum);

            }

        }
        right /= totalWeight;
        if (Double.isInfinite(right) || Double.isNaN(right)) {

            m_instances = null;
            throw new Exception("Network cannot train. Try restarting with a" + " smaller learning rate.");

        }

        // do validation testing if applicable
        if (m_valSize != 0) {
            right = 0;
            for (int nob = 0; nob < valSet.numInstances(); nob++) {
                m_currentInstance = valSet.instance(nob);
                if (!m_currentInstance.classIsMissing()) {
                    //this is where the network updating occurs, for the validation set
                    resetNetwork();
                    calculateOutputs();
                    right += (calculateErrors() / valSet.numClasses()) * m_currentInstance.weight();
                    //note: 'right' could be calculated here directly from the
                    //calculated output values. This would be faster, but
                    //less modular.
                }

            }

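            // Early-stopping bookkeeping: remember the best weights seen so
            // far and count epochs without improvement ("drift").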
            if (right < lastRight) {
                if (right < bestError) {
                    bestError = right;
                    // save the network weights at this point
                    for (int noc = 0; noc < m_numClasses; noc++) {
                        m_outputs[noc].saveWeights();
                    }
                    driftOff = 0;
                }
            } else {
                driftOff++;
            }
            lastRight = right;
            if (driftOff > m_driftThreshold || noa + 1 >= m_numEpochs) {
                for (int noc = 0; noc < m_numClasses; noc++) {
                    m_outputs[noc].restoreWeights();
                }
                m_accepted = true;
            }
            right /= totalValWeight;
        }
        m_epoch = noa;
        m_error = right;
        //shows what the neural net is up to, if a GUI exists

        if (m_accepted) {
            m_instances = new Instances(m_instances, 0);
            return;
        }
    }

}
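
For context, a minimal sketch of how this method might be driven (the no-argument constructor, file path, and iteration count are assumptions for illustration; they are not part of the source above):

import java.util.Random;

import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class TrainDemo {
    public static void main(String[] args) throws Exception {
        // Placeholder path; any ARFF dataset with a class attribute works.
        Instances data = DataSource.read("/path/to/dataset.arff");
        data.setClassIndex(data.numAttributes() - 1);

        // Shuffle up front with a fixed seed for a reproducible run.
        data.randomize(new Random(1));

        OnlineMultilayerPerceptron mlp = new OnlineMultilayerPerceptron();
        mlp.trainModel(data, 50);
    }
}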