Example usage for weka.core Instances setClassIndex

List of usage examples for weka.core Instances setClassIndex

Introduction

In this page you can find the example usage for weka.core Instances setClassIndex.

Prototype

public void setClassIndex(int classIndex) 

Source Link

Document

Sets the class index of the set.

Usage

From source file:com.sliit.neuralnetwork.RecurrentNN.java

/**
 * Trains the recurrent network on the CSV at {@code filePath}, alternating the
 * stratified train/test folds each epoch, and records a per-epoch ROC point
 * (FPR/TPR of class 1) into the {@code roc} field.
 *
 * @param modelName name used to load/save the serialized network
 * @param filePath  path to the input CSV data set (last column is the class)
 * @param outputs   number of output labels for the DL4J iterator
 * @param inputsTot index of the label column for the DL4J iterator
 * @return the evaluation statistics string of the final epoch
 * @throws NeuralException wrapping any failure during loading, filtering or training
 */
public String trainModel(String modelName, String filePath, int outputs, int inputsTot) throws NeuralException {
    System.out.println("calling trainModel");
    try {
        System.out.println("Neural Network Training start");

        // Restore a previously saved network if one exists; otherwise build a new one.
        loadSaveNN(modelName, false);
        if (model == null) {
            buildModel();
        }

        // Load the CSV data set; the last attribute is the class.
        File fileGeneral = new File(filePath);
        CSVLoader loader = new CSVLoader();
        loader.setSource(fileGeneral);
        Instances instances = loader.getDataSet();
        instances.setClassIndex(instances.numAttributes() - 1);

        // Stratified 5-fold split (seed 1): fold 1 is the test set,
        // the inverted selection (remaining folds) is the training set.
        StratifiedRemoveFolds stratified = new StratifiedRemoveFolds();
        String[] options = { "-N", "5", "-F", "1", "-S", "1" };
        stratified.setOptions(options);
        stratified.setInputFormat(instances);
        stratified.setInvertSelection(false);
        Instances testInstances = Filter.useFilter(instances, stratified);
        stratified.setInvertSelection(true);
        Instances trainInstances = Filter.useFilter(instances, stratified);

        // Persist both folds as fresh CSV files next to the input file.
        String directory = fileGeneral.getParent();
        File trainFile = new File(directory + "/" + "normtrainadded.csv");
        File testFile = new File(directory + "/" + "normtestadded.csv");
        if (trainFile.exists()) {
            trainFile.delete();
        }
        trainFile.createNewFile();
        if (testFile.exists()) {
            testFile.delete();
        }
        testFile.createNewFile();
        CSVSaver saver = new CSVSaver();
        saver.setFile(trainFile);
        saver.setInstances(trainInstances);
        saver.writeBatch();
        saver = new CSVSaver();
        saver.setFile(testFile);
        saver.setInstances(testInstances);
        saver.writeBatch();

        // Wrap both CSV files as DL4J sequence iterators (mini-batch size 2).
        SequenceRecordReader recordReader = new CSVSequenceRecordReader(0, ",");
        recordReader.initialize(new org.datavec.api.split.FileSplit(trainFile));
        SequenceRecordReader testReader = new CSVSequenceRecordReader(0, ",");
        testReader.initialize(new org.datavec.api.split.FileSplit(testFile));
        DataSetIterator iterator = new org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator(
                recordReader, 2, outputs, inputsTot, false);
        DataSetIterator testIterator = new org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator(
                testReader, 2, outputs, inputsTot, false);

        roc = new ArrayList<Map<String, Double>>();
        String statMsg = "";
        Evaluation evaluation;

        // Train for 100 epochs, alternating which fold is fitted and which is
        // evaluated, and record the ROC point of class 1 after every epoch.
        for (int i = 0; i < 100; i++) {
            if (i % 2 == 0) {
                model.fit(iterator);
                evaluation = model.evaluate(testIterator);
            } else {
                model.fit(testIterator);
                evaluation = model.evaluate(iterator);
            }
            Map<String, Double> map = new HashMap<String, Double>();
            Map<Integer, Integer> falsePositives = evaluation.falsePositives();
            Map<Integer, Integer> trueNegatives = evaluation.trueNegatives();
            Map<Integer, Integer> truePositives = evaluation.truePositives();
            Map<Integer, Integer> falseNegatives = evaluation.falseNegatives();
            // getOrDefault avoids an NPE when class 1 is absent from a map, and
            // double arithmetic fixes the original integer division, which
            // truncated both rates to 0 (or occasionally 1).
            double fp = falsePositives.getOrDefault(1, 0);
            double tn = trueNegatives.getOrDefault(1, 0);
            double tp = truePositives.getOrDefault(1, 0);
            double fn = falseNegatives.getOrDefault(1, 0);
            double fpr = (fp + tn) == 0 ? 0.0 : fp / (fp + tn);
            double tpr = (tp + fn) == 0 ? 0.0 : tp / (tp + fn);
            map.put("FPR", fpr);
            map.put("TPR", tpr);
            roc.add(map);
            statMsg = evaluation.stats();
            iterator.reset();
            testIterator.reset();
        }

        // Save the trained network and return the final epoch's statistics.
        loadSaveNN(modelName, true);
        System.out.println("ROC " + roc);
        return statMsg;

    } catch (Exception e) {
        e.printStackTrace();
        System.out.println("Error occurred while building neural network: " + e.getMessage());
        throw new NeuralException(e.getLocalizedMessage(), e);
    }
}

From source file:com.sliit.normalize.NormalizeDataset.java

/**
 * Normalizes the current data set (reduced-dimension file if present, else the
 * temp file, else the raw CSV), binarizes the class labels ("normal." maps to
 * "0", everything else to "1"), writes the result, and cleans up intermediates.
 *
 * @return the path of the normalized output file, or "" on failure
 */
public String normalizeDataset() {
    System.out.println("start normalizing data");

    String filePathOut = "";
    try {
        // Prefer the dimension-reduced file, then the temp file, then the raw CSV.
        CSVLoader loader = new CSVLoader();
        if (reducedDiemensionFile != null) {
            loader.setSource(reducedDiemensionFile);
        } else if (tempFIle != null && tempFIle.exists()) {
            loader.setSource(tempFIle);
        } else {
            loader.setSource(csvFile);
        }
        Instances dataInstance = loader.getDataSet();
        Normalize normalize = new Normalize();
        dataInstance.setClassIndex(dataInstance.numAttributes() - 1);
        normalize.setInputFormat(dataInstance);

        String directory = csvFile.getParent();
        outputFile = new File(directory + "/" + "normalized" + csvFile.getName());
        if (!outputFile.exists()) {
            outputFile.createNewFile();
        }
        CSVSaver saver = new CSVSaver();
        saver.setFile(outputFile);

        // Push every instance through the filter, then drain the filtered output.
        // The original loops started at index 1 and silently dropped the first
        // instance from the normalized data set.
        for (int i = 0; i < dataInstance.numInstances(); i++) {
            normalize.input(dataInstance.instance(i));
        }
        normalize.batchFinished();
        Instances outPut = new Instances(dataInstance, 0);
        for (int i = 0; i < dataInstance.numInstances(); i++) {
            outPut.add(normalize.output());
        }

        // Binarize the class labels: "normal." -> "0", anything else -> "1".
        Attribute attribute = dataInstance.attribute(outPut.numAttributes() - 1);
        for (int j = 0; j < attribute.numValues(); j++) {
            if (attribute.value(j).equals("normal.")) {
                outPut.renameAttributeValue(attribute, attribute.value(j), "0");
            } else {
                outPut.renameAttributeValue(attribute, attribute.value(j), "1");
            }
        }
        saver.setInstances(outPut);
        saver.writeBatch();
        writeToNewFile(directory);
        // NOTE(review): no separator between directory and file name — presumably
        // writeToNewFile builds this exact path; confirm against writeToNewFile
        // before changing.
        filePathOut = directory + "norm" + csvFile.getName();

        // Remove intermediate files; the final artifact is written by writeToNewFile.
        if (tempFIle != null) {
            tempFIle.delete();
        }
        if (reducedDiemensionFile != null) {
            reducedDiemensionFile.delete();
        }
        outputFile.delete();
    } catch (IOException e) {
        log.error("Error occurred:" + e.getMessage());
    } catch (Exception e) {
        log.error("Error occurred:" + e.getMessage());
    }
    return filePathOut;
}

From source file:com.sliit.normalize.NormalizeDataset.java

/**
 * Reduces the data set to at most 10 attributes with a Gaussian random
 * projection, writing the projected instances to {@code tempwhite.csv}.
 * Data sets with 10 or fewer attributes are left untouched.
 *
 * @return the resulting number of attributes (0 on failure)
 */
public int whiteningData() {
    System.out.println("whiteningData");

    int nums = 0;
    try {
        // Load from the temp file when present, otherwise from the original CSV.
        if (tempFIle != null && tempFIle.exists()) {
            csv.setSource(tempFIle);
        } else {
            csv.setSource(csvFile);
        }
        Instances instances = csv.getDataSet();

        if (instances.numAttributes() > 10) {
            instances.setClassIndex(instances.numAttributes() - 1);
            RandomProjection random = new RandomProjection();
            random.setDistribution(
                    new SelectedTag(RandomProjection.GAUSSIAN, RandomProjection.TAGS_DSTRS_TYPE));
            // Configure the filter BEFORE setInputFormat(): Weka filters fix their
            // output format there, so options set afterwards (the original set these
            // inside the write loop) are not reliably applied.
            random.setNumberOfAttributes(10);
            random.setReplaceMissingValues(true);

            reducedDiemensionFile = new File(csvFile.getParent() + "/tempwhite.csv");
            if (!reducedDiemensionFile.exists()) {
                reducedDiemensionFile.createNewFile();
            }
            random.setInputFormat(instances);

            // try-with-resources closes the writer even when a filter call throws
            // (the original leaked the writer on any exception).
            try (BufferedWriter writer = new BufferedWriter(new FileWriter(reducedDiemensionFile))) {
                for (int i = 0; i < instances.numInstances(); i++) {
                    random.input(instances.instance(i));
                    writer.write(random.output().toString());
                    writer.newLine();
                }
                writer.flush();
            }
            nums = random.getNumberOfAttributes();
        } else {
            nums = instances.numAttributes();
        }
    } catch (IOException e) {
        log.error("Error occurred:" + e.getMessage());
    } catch (Exception e) {
        log.error("Error occurred:" + e.getMessage());
    }
    return nums;
}

From source file:com.sliit.rules.RuleContainer.java

/**
 * Runs the rule model over the CSV at {@code filePath} and returns the class
 * label predicted for the FIRST instance of the file.
 *
 * @param filePath path to the test CSV (last column is the class)
 * @return the predicted class label of the first test instance
 * @throws Exception if loading or evaluation fails
 */
public String predictionResult(String filePath) throws Exception {

    // Load the test CSV; the last attribute is the class to predict.
    File testPath = new File(filePath);
    CSVLoader loader = new CSVLoader();
    loader.setSource(testPath);
    Instances testInstances = loader.getDataSet();
    testInstances.setClassIndex(testInstances.numAttributes() - 1);

    Evaluation eval = new Evaluation(testInstances);
    eval.evaluateModel(ruleMoldel, testInstances);
    ArrayList<Prediction> predictions = eval.predictions();
    int predictedVal = (int) predictions.get(0).predicted();
    // Map the predicted index through the test set's own class attribute.
    // The original read the `instances` field instead, which may be unset
    // (null) when the model was loaded without the training data.
    String cdetails = testInstances.classAttribute().value(predictedVal);
    return cdetails;
}

From source file:com.sliit.views.DataVisualizerPanel.java

/**
 * Cross-validates a NaiveBayes classifier on the data set named in the
 * path text field, builds the ROC curve for class 0, and displays it in
 * a new Swing frame.
 */
void getRocCurve() {
    try {
        // try-with-resources closes the file reader; the original leaked it.
        Instances data;
        try (BufferedReader reader = new BufferedReader(new FileReader(datasetPathText.getText()))) {
            data = new Instances(reader);
        }
        data.setClassIndex(data.numAttributes() - 1);

        // 10-fold cross-validation (seed 1) to collect predictions.
        Classifier cl = new NaiveBayes();
        Evaluation eval = new Evaluation(data);
        eval.crossValidateModel(cl, data, 10, new Random(1));

        // Build the ROC curve for class index 0.
        ThresholdCurve tc = new ThresholdCurve();
        int classIndex = 0;
        Instances result = tc.getCurve(eval.predictions(), classIndex);

        // Plot the curve.
        ThresholdVisualizePanel vmc = new ThresholdVisualizePanel();
        vmc.setROCString("(Area under ROC = " + Utils.doubleToString(tc.getROCArea(result), 4) + ")");
        vmc.setName(result.relationName());
        PlotData2D tempd = new PlotData2D(result);
        tempd.setPlotName(result.relationName());
        tempd.addInstanceNumberAttribute();
        // Connect every point to its predecessor (the first point has none).
        boolean[] cp = new boolean[result.numInstances()];
        for (int n = 1; n < cp.length; n++) {
            cp[n] = true;
        }
        tempd.setConnectPoints(cp);
        vmc.addPlot(tempd);

        // Display the curve in its own disposable frame.
        String plotName = vmc.getName();
        final javax.swing.JFrame jf = new javax.swing.JFrame("Weka Classifier Visualize: " + plotName);
        jf.setSize(500, 400);
        jf.getContentPane().setLayout(new BorderLayout());
        jf.getContentPane().add(vmc, BorderLayout.CENTER);
        jf.addWindowListener(new java.awt.event.WindowAdapter() {
            public void windowClosing(java.awt.event.WindowEvent e) {
                jf.dispose();
            }
        });
        jf.setVisible(true);
    } catch (IOException ex) {
        Logger.getLogger(DataVisualizerPanel.class.getName()).log(Level.SEVERE, null, ex);
    } catch (Exception ex) {
        Logger.getLogger(DataVisualizerPanel.class.getName()).log(Level.SEVERE, null, ex);
    }
}

From source file:com.sliit.views.KNNView.java

/**
 * Cross-validates a NaiveBayes classifier on the data set named in the
 * predictor panel's model text field, builds the ROC curve for class 0,
 * and renders it inside {@code rocPanel}.
 */
void getRocCurve() {
    try {
        // try-with-resources closes the file reader; the original leaked it.
        Instances data;
        try (BufferedReader reader = new BufferedReader(
                new java.io.FileReader(PredictorPanel.modalText.getText()))) {
            data = new Instances(reader);
        }
        data.setClassIndex(data.numAttributes() - 1);

        // 10-fold cross-validation (seed 1) to collect predictions.
        Classifier cl = new NaiveBayes();
        Evaluation eval = new Evaluation(data);
        eval.crossValidateModel(cl, data, 10, new Random(1));

        // Build the ROC curve for class index 0.
        ThresholdCurve tc = new ThresholdCurve();
        int classIndex = 0;
        Instances result = tc.getCurve(eval.predictions(), classIndex);

        // Plot the curve.
        ThresholdVisualizePanel vmc = new ThresholdVisualizePanel();
        vmc.setROCString("(Area under ROC = " + Utils.doubleToString(tc.getROCArea(result), 4) + ")");
        vmc.setName(result.relationName());
        PlotData2D tempd = new PlotData2D(result);
        tempd.setPlotName(result.relationName());
        tempd.addInstanceNumberAttribute();
        // Connect every point to its predecessor (the first point has none).
        boolean[] cp = new boolean[result.numInstances()];
        for (int n = 1; n < cp.length; n++) {
            cp[n] = true;
        }
        tempd.setConnectPoints(cp);
        vmc.addPlot(tempd);

        // Swap the plot into the panel and relayout.
        rocPanel.removeAll();
        rocPanel.add(vmc, "vmc", 0);
        rocPanel.revalidate();

    } catch (IOException ex) {
        // Log under this class, not DataVisualizerPanel (copy-paste slip in the original).
        Logger.getLogger(KNNView.class.getName()).log(Level.SEVERE, null, ex);
    } catch (Exception ex) {
        Logger.getLogger(KNNView.class.getName()).log(Level.SEVERE, null, ex);
    }
}

From source file:com.sliit.views.SVMView.java

/**
 * draw ROC curve/* w  ww  .ja  va2  s.  c o m*/
 */
void getRocCurve() {
    try {
        Instances data;
        data = new Instances(new BufferedReader(new FileReader(PredictorPanel.modalText.getText())));
        data.setClassIndex(data.numAttributes() - 1);

        //train classifier
        Classifier cl = new NaiveBayes();
        Evaluation eval = new Evaluation(data);
        eval.crossValidateModel(cl, data, 10, new Random(1));

        // generate curve
        ThresholdCurve tc = new ThresholdCurve();
        int classIndex = 0;
        Instances result = tc.getCurve(eval.predictions(), classIndex);

        // plot curve
        ThresholdVisualizePanel vmc = new ThresholdVisualizePanel();
        vmc.setROCString("(Area under ROC = " + Utils.doubleToString(tc.getROCArea(result), 4) + ")");
        vmc.setName(result.relationName());
        PlotData2D tempd = new PlotData2D(result);
        tempd.setPlotName(result.relationName());
        tempd.addInstanceNumberAttribute();
        // specify which points are connected
        boolean[] cp = new boolean[result.numInstances()];
        for (int n = 1; n < cp.length; n++) {
            cp[n] = true;
        }
        tempd.setConnectPoints(cp);
        // add plot
        vmc.addPlot(tempd);

        //            rocPanel.removeAll();
        //            rocPanel.add(vmc, "vmc", 0);
        //            rocPanel.revalidate();
    } catch (IOException ex) {
        Logger.getLogger(DataVisualizerPanel.class.getName()).log(Level.SEVERE, null, ex);
    } catch (Exception ex) {
        Logger.getLogger(DataVisualizerPanel.class.getName()).log(Level.SEVERE, null, ex);
    }
}

From source file:com.spread.experiment.tempuntilofficialrelease.ClassificationViaClustering108.java

License:Open Source License

/**
 * Builds the classifier: clusters the class-free copy of the training data,
 * then derives a clusters-to-classes mapping (or per-cluster class
 * distributions when all clusters are labeled). Falls back to ZeroR when the
 * data contains nothing but the class attribute.
 *
 * @param data the training instances
 * @throws Exception if something goes wrong
 */
@Override
public void buildClassifier(Instances data) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    // save original header (needed for clusters to classes output)
    m_OriginalHeader = data.stringFreeStructure();

    // remove class attribute for clusterer — the clusterer must never see
    // the label, so cluster a copy with the class column deleted
    Instances clusterData = new Instances(data);
    clusterData.setClassIndex(-1);
    clusterData.deleteAttributeAt(data.classIndex());
    m_ClusteringHeader = clusterData.stringFreeStructure();

    if (m_ClusteringHeader.numAttributes() == 0) {
        // nothing left to cluster on: default to majority-class prediction
        System.err.println("Data contains only class attribute, defaulting to ZeroR model.");
        m_ZeroR = new ZeroR();
        m_ZeroR.buildClassifier(data);
    } else {
        m_ZeroR = null;

        // build clusterer on a copy of the configured prototype
        m_ActualClusterer = AbstractClusterer.makeCopy(m_Clusterer);
        m_ActualClusterer.buildClusterer(clusterData);

        if (!getLabelAllClusters()) {

            // determine classes-to-clusters mapping:
            // count class occurrences per cluster, then let Weka's
            // mapClasses search pick the best one-to-one assignment
            ClusterEvaluation eval = new ClusterEvaluation();
            eval.setClusterer(m_ActualClusterer);
            eval.evaluateClusterer(clusterData);
            double[] clusterAssignments = eval.getClusterAssignments();
            // counts[cluster][class] = #instances of `class` assigned to `cluster`
            int[][] counts = new int[eval.getNumClusters()][m_OriginalHeader.numClasses()];
            int[] clusterTotals = new int[eval.getNumClusters()];
            // last slot of best/current holds the error of the mapping
            double[] best = new double[eval.getNumClusters() + 1];
            double[] current = new double[eval.getNumClusters() + 1];
            // data and clusterAssignments are index-aligned: assignment i
            // belongs to data.instance(i)
            for (int i = 0; i < data.numInstances(); i++) {
                Instance instance = data.instance(i);
                if (!instance.classIsMissing()) {
                    counts[(int) clusterAssignments[i]][(int) instance.classValue()]++;
                    clusterTotals[(int) clusterAssignments[i]]++;
                }
            }
            // seed the search with the worst possible error
            best[eval.getNumClusters()] = Double.MAX_VALUE;
            ClusterEvaluation.mapClasses(eval.getNumClusters(), 0, counts, clusterTotals, current, best, 0);
            m_ClustersToClasses = new double[best.length];
            System.arraycopy(best, 0, m_ClustersToClasses, 0, best.length);
        } else {
            // label ALL clusters: accumulate soft cluster memberships per class,
            // then normalize each cluster's row into a class distribution.
            // clusterData and data are index-aligned (same order, class removed).
            m_ClusterClassProbs = new double[m_ActualClusterer.numberOfClusters()][data.numClasses()];
            for (int i = 0; i < data.numInstances(); i++) {
                Instance clusterInstance = clusterData.instance(i);
                Instance originalInstance = data.instance(i);
                if (!originalInstance.classIsMissing()) {
                    double[] probs = m_ActualClusterer.distributionForInstance(clusterInstance);
                    for (int j = 0; j < probs.length; j++) {
                        m_ClusterClassProbs[j][(int) originalInstance.classValue()] += probs[j];
                    }
                }
            }
            for (int i = 0; i < m_ClusterClassProbs.length; i++) {
                Utils.normalize(m_ClusterClassProbs[i]);
            }
        }
    }
}

From source file:com.tum.classifiertest.FastRfUtils.java

License:Open Source License

/**
 * Loads a data set into memory, defaulting the class attribute to the last
 * column when the source did not define one.
 *
 * @param location the location of the dataset
 * @return the dataset with a class index set
 * @throws Exception if the data set cannot be read
 */
public static Instances readInstances(String location) throws Exception {
    final Instances dataset =
            new weka.core.converters.ConverterUtils.DataSource(location).getDataSet();
    boolean classUnset = dataset.classIndex() == -1;
    if (classUnset) {
        dataset.setClassIndex(dataset.numAttributes() - 1);
    }
    return dataset;
}

From source file:com.yahoo.labs.samoa.instances.SamoaToWekaInstanceConverter.java

License:Apache License

/**
* Weka instances information./*  w  ww  .j a va 2 s . c  o m*/
*
* @param instances the instances
* @return the weka.core. instances
*/
public weka.core.Instances wekaInstancesInformation(Instances instances) {
    weka.core.Instances wekaInstances;
    ArrayList<weka.core.Attribute> attInfo = new ArrayList<weka.core.Attribute>();
    for (int i = 0; i < instances.numAttributes(); i++) {
        attInfo.add(wekaAttribute(i, instances.attribute(i)));
    }
    wekaInstances = new weka.core.Instances(instances.getRelationName(), attInfo, 0);
    if (instances.instanceInformation.numOutputAttributes() == 1) {
        wekaInstances.setClassIndex(instances.classIndex());
    } else {
        //Assign a classIndex to a MultiLabel instance for compatibility reasons
        wekaInstances.setClassIndex(instances.instanceInformation.numOutputAttributes() - 1); //instances.numAttributes()-1); //Last
    }
    //System.out.println(attInfo.get(3).name());
    //System.out.println(attInfo.get(3).isNominal());
    //System.out.println(wekaInstances.attribute(3).name());
    //System.out.println(wekaInstances.attribute(3).isNominal());
    return wekaInstances;
}