Example usage for weka.core Instances get

List of usage examples for weka.core Instances get

Introduction

In this page you can find the example usage for weka.core Instances get.

Prototype



@Override
public Instance get(int index) 

Source Link

Document

Returns the instance at the given position.

Usage

From source file:jmetal.test.survivalanalysis.GenerateSurvivalGraph.java

License:Open Source License

/** 
 * Evaluates a solution.
 * @param solution The solution to evaluate
 */
public void evaluate(Solution solution) {
    // Decision variable: bit string selecting which features participate in
    // the clustering (consumed by filterByChromosome).
    Binary variable;
    int counterSelectedFeatures;

    DataSource source;

    // Objectives default to MAX_VALUE so an evaluation failure is obvious.
    double testStatistic = Double.MAX_VALUE;
    double pValue = Double.MAX_VALUE;
    double ArithmeticHarmonicCutScore = Double.MAX_VALUE;
    //double statScore;
    REXP x; // holds values returned by the embedded R engine

    variable = ((Binary) solution.getDecisionVariables()[0]);

    counterSelectedFeatures = 0;

    try {
        // read the data file 
        source = new DataSource(this.dataFileName);
        Instances data = source.getDataSet();
        //System.out.print("Data read successfully. ");
        //System.out.print("Number of attributes: " + data.numAttributes());
        //System.out.println(". Number of instances: " + data.numInstances());

        // save the attribute 'T' and 'Censor'
        // (assumed to be the last two attributes of the data set)
        attTime = data.attribute(data.numAttributes() - 2);
        attCensor = data.attribute(data.numAttributes() - 1);

        // First filter the attributes based on chromosome
        Instances tmpData = this.filterByChromosome(data, solution);

        // Now filter the attribute 'T' and 'Censor'
        Remove filter = new Remove();
        // remove the two last attributes : 'T' and 'Censor'
        // (weka attribute ranges are 1-based, hence these two index values)
        filter.setAttributeIndices("" + (tmpData.numAttributes() - 1) + "," + tmpData.numAttributes());
        //System.out.println("After chromosome filtering no of attributes: " + tmpData.numAttributes());
        filter.setInputFormat(tmpData);
        Instances dataClusterer = Filter.useFilter(tmpData, filter);

        // filtering complete

        // List the selected features/attributes
        Enumeration<Attribute> attributeList = dataClusterer.enumerateAttributes();
        System.out.println("Selected attributes/features: ");
        while (attributeList.hasMoreElements()) {
            Attribute att = attributeList.nextElement();
            System.out.print(att.name() + ",");
        }

        System.out.println();

        /*
        // debug: write the filtered dataset
                
         ArffSaver saver = new ArffSaver();
         saver.setInstances(dataClusterer);
         saver.setFile(new File("filteered-data.arff"));
         saver.writeBatch();
        // end debug
                
        */

        // train hierarchical clusterer

        HierarchicalClusterer clusterer = new HierarchicalClusterer();
        clusterer.setOptions(new String[] { "-L", this.HC_LinkType });
        //Link type (Single, Complete, Average, Mean, Centroid, Ward, Adjusted complete, Neighbor Joining)
        //[SINGLE|COMPLETE|AVERAGE|MEAN|CENTROID|WARD|ADJCOMPLETE|NEIGHBOR_JOINING]

        //clusterer.setDebug(true);
        clusterer.setNumClusters(2); // two survival groups are compared downstream
        clusterer.setDistanceFunction(new EuclideanDistance());
        clusterer.setDistanceIsBranchLength(false); // ?? Should it be changed to false? (Noman)

        clusterer.buildClusterer(dataClusterer);

        double[][] distanceMatrix = clusterer.getDistanceMatrix();

        // Cluster evaluation:
        ClusterEvaluation eval = new ClusterEvaluation();
        eval.setClusterer(clusterer);

        if (this.testDataFileName != null) {

            DataSource testSource = new DataSource(this.testDataFileName);

            Instances tmpTestData = testSource.getDataSet();
            tmpTestData.setClassIndex(tmpTestData.numAttributes() - 1);
            //testSource.

            // First filter the attributes based on chromosome
            Instances testData = this.filterByChromosome(tmpTestData, solution);
            //String[] options = new String[2];
            //options[0] = "-t";
            //options[1] = "/some/where/somefile.arff";
            //eval.
            //System.out.println(eval.evaluateClusterer(testData, options));
            eval.evaluateClusterer(testData);
            System.out.println("\nCluster evluation for this solution(" + this.testDataFileName + "): "
                    + eval.clusterResultsToString());
        }

        // First analyze using my library function

        // save the cluster assignments

        int[] clusterAssignment = new int[dataClusterer.numInstances()];
        int classOneCnt = 0;
        int classTwoCnt = 0;
        for (int i = 0; i < dataClusterer.numInstances(); ++i) {
            clusterAssignment[i] = clusterer.clusterInstance(dataClusterer.get(i));
            if (clusterAssignment[i] == 0) {
                ++classOneCnt;
            } else if (clusterAssignment[i] == 1) {
                ++classTwoCnt;
            }
            //System.out.println("Instance " + i + ": " + clusterAssignment[i]);
        }

        System.out.println("Class 1 cnt: " + classOneCnt + " Class 2 cnt: " + classTwoCnt);

        // create arrays with time (event occurrence time) and censor data for use with jstat LogRankTest
        double[] time1 = new double[classOneCnt];
        double[] censor1 = new double[classOneCnt];
        double[] time2 = new double[classTwoCnt];
        double[] censor2 = new double[classTwoCnt];

        // NOTE(review): row i of 'data' is assumed to correspond to row i of
        // 'dataClusterer' (attribute filtering should preserve row order) — confirm.
        //data = source.getDataSet();
        for (int i = 0, cnt1 = 0, cnt2 = 0; i < dataClusterer.numInstances(); ++i) {
            //clusterAssignment[i] = clusterer.clusterInstance(dataClusterer.get(i));
            if (clusterAssignment[i] == 0) {
                time1[cnt1] = data.get(i).value(attTime);
                censor1[cnt1++] = data.get(i).value(attCensor);
                //System.out.println("i: " + i + " T: " + time1[cnt1-1]);
            } else if (clusterAssignment[i] == 1) {
                time2[cnt2] = data.get(i).value(attTime);
                //System.out.println("i: " + i + " T: " + time2[cnt2-1]);
                censor2[cnt2++] = data.get(i).value(attCensor);
                ;
            }
            //System.out.println("Instance " + i + ": " + clusterAssignment[i]);
        }

        //Instances[] classInstances = separateClassInstances(clusterAssignment, this.dataFileName,solution);
        //System.out.println("Class instances seperated");

        // calculate log rank test and p values

        LogRankTest testclass1 = new LogRankTest(time1, time2, censor1, censor2);
        double[] scores = testclass1.logRank();
        testStatistic = scores[0];
        pValue = scores[2];

        ArithmeticHarmonicCutScore = this.getArithmeticHarmonicCutScore(distanceMatrix, clusterAssignment);
        //debug:
        System.out.println("Calculation by myLibrary:\n testStatistic: " + scores[0] + " pValue: " + scores[2]
                + " Arithmetic Harmonic Cut Score: " + ArithmeticHarmonicCutScore);
        //end debug
        //WilcoxonTest testclass1 = new WilcoxonTest(time1, censor1, time2, censor2);
        //testStatistic = testclass1.testStatistic;
        //pValue = testclass1.pValue;true

        // Now analyze calling R for Log Rank test, Parallelization not possible

        // Build R vectors (time, censor, group) element by element; the last
        // element is appended separately to close each vector with ')'.
        String strT = "time <- c(";
        String strC = "censor <- c(";
        String strG = "group <- c(";

        for (int i = 0; i < dataClusterer.numInstances() - 1; ++i) {
            strT = strT + (int) data.get(i).value(attTime) + ",";
            strG = strG + clusterer.clusterInstance(dataClusterer.get(i)) + ",";
            strC = strC + (int) data.get(i).value(attCensor) + ",";
        }

        int tmpi = dataClusterer.numInstances() - 1;
        strT = strT + (int) data.get(tmpi).value(attTime) + ")";
        strG = strG + clusterer.clusterInstance(dataClusterer.get(tmpi)) + ")";
        strC = strC + (int) data.get(tmpi).value(attCensor) + ")";

        this.re.eval(strT);
        this.re.eval(strC);
        this.re.eval(strG);

        //debug
        //System.out.println(strT);
        //System.out.println(strC);
        //System.out.println(strG);
        //end debug

        /** If you are calling surv_test from coin library */
        /*v
        re.eval("library(coin)");
        re.eval("grp <- factor (group)");
        re.eval("result <- surv_test(Surv(time,censor)~grp,distribution=\"exact\")");
                
        x=re.eval("statistic(result)");
        testStatistic = x.asDouble();
        //x=re.eval("pvalue(result)");
        //pValue = x.asDouble();
        //System.out.println("StatScore: " + statScore + "pValue: " + pValue);
         */

        /** If you are calling survdiff from survival library (much faster) */
        re.eval("library(survival)");
        re.eval("res2 <- survdiff(Surv(time,censor)~group,rho=0)");
        x = re.eval("res2$chisq");
        testStatistic = x.asDouble(); // overwrites the myLibrary value computed above
        //System.out.println(x);
        x = re.eval("pchisq(res2$chisq, df=1, lower.tail = FALSE)");
        //x = re.eval("1.0 - pchisq(res2$chisq, df=1)");
        pValue = x.asDouble();
        //debug:
        //System.out.println("Calculation by R: StatScore: " + testStatistic + "pValue: " + pValue);
        //end debug

        System.out.println("Calculation by R:");
        System.out.println("StatScore: " + testStatistic + "  pValue: " + pValue);

        // Render the Kaplan-Meier survival plot for both groups plus the whole
        // cohort to a per-solution JPEG via R's survfit/plot.
        re.eval("timestrata1.surv <- survfit( Surv(time, censor)~ strata(group), conf.type=\"log-log\")");
        re.eval("timestrata1.surv1 <- survfit( Surv(time, censor)~ 1, conf.type=\"none\")");
        String evalStr = "jpeg('SurvivalPlot-" + this.SolutionID + ".jpg')";
        re.eval(evalStr);
        re.eval("plot(timestrata1.surv, col=c(2,3), xlab=\"Time\", ylab=\"Survival Probability\")");
        re.eval("par(new=T)");
        re.eval("plot(timestrata1.surv1,col=1)");
        re.eval("legend(0.2, c(\"Group1\",\"Group2\",\"Whole\"))");
        re.eval("dev.off()");

        System.out.println("\nCluster Assignments:");
        for (int i = 0; i < dataClusterer.numInstances(); ++i) {
            System.out.println("Instance " + i + ": " + clusterAssignment[i]);
        }

    } catch (Exception e) {
        // NOTE(review): every failure (not only file I/O) lands here; the
        // message is misleading and System.exit kills the whole JVM.
        System.err.println("Can't open the data file.");
        e.printStackTrace();
        System.exit(1);
    }

}

From source file:jmetal.test.survivalanalysis.GenerateSurvivalGraphOld.java

License:Open Source License

/** 
 * Evaluates a solution - actually generates the survival graph.
 * @param solution The solution to evaluate
 */
public void evaluate(Solution solution) {
    // Decision variable: bit string selecting which features participate in
    // the clustering (consumed by filterByChromosome).
    Binary variable;
    int counterSelectedFeatures;

    DataSource source;

    // Objectives default to MAX_VALUE so an evaluation failure is obvious.
    double testStatistic = Double.MAX_VALUE;
    double pValue = Double.MAX_VALUE;
    //double statScore;
    REXP x; // holds values returned by the embedded R engine

    variable = ((Binary) solution.getDecisionVariables()[0]);

    counterSelectedFeatures = 0;

    System.out.println("\nSolution ID " + this.SolutionID);

    try {
        // read the data file 
        source = new DataSource(this.dataFileName);
        Instances data = source.getDataSet();
        //System.out.print("Data read successfully. ");
        //System.out.print("Number of attributes: " + data.numAttributes());
        //System.out.println(". Number of instances: " + data.numInstances());

        // save the attribute 'T' and 'Censor'
        // (assumed to be the last two attributes of the data set)
        attTime = data.attribute(data.numAttributes() - 2);
        attCensor = data.attribute(data.numAttributes() - 1);

        // First filter the attributes based on chromosome
        Instances tmpData = this.filterByChromosome(data, solution);

        // Now filter the attribute 'T' and 'Censor'
        Remove filter = new Remove();
        // remove the two last attributes : 'T' and 'Censor'
        // (weka attribute ranges are 1-based, hence these two index values)
        filter.setAttributeIndices("" + (tmpData.numAttributes() - 1) + "," + tmpData.numAttributes());
        //System.out.println("After chromosome filtering no of attributes: " + tmpData.numAttributes());
        filter.setInputFormat(tmpData);
        Instances dataClusterer = Filter.useFilter(tmpData, filter);

        Enumeration<Attribute> attributeList = dataClusterer.enumerateAttributes();
        System.out.println("Selected attributes: ");
        while (attributeList.hasMoreElements()) {
            Attribute att = attributeList.nextElement();
            System.out.print(att.name() + ",");
        }

        System.out.println();
        // filtering complete

        // Debug: write the filtered dataset
        /*
        ArffSaver saver = new ArffSaver();
        saver.setInstances(dataClusterer);
        saver.setFile(new File("filteered-data.arff"));
        saver.writeBatch();
         */

        // train hierarchical clusterer

        HierarchicalClusterer clusterer = new HierarchicalClusterer();
        clusterer.setOptions(new String[] { "-L", "COMPLETE" }); // complete linkage clustering
        //clusterer.setDebug(true);
        clusterer.setNumClusters(2); // two survival groups are compared downstream
        clusterer.setDistanceFunction(new EuclideanDistance());
        //clusterer.setDistanceFunction(new ChebyshevDistance());
        clusterer.setDistanceIsBranchLength(false);

        clusterer.buildClusterer(dataClusterer);

        // Cluster evaluation:
        ClusterEvaluation eval = new ClusterEvaluation();
        eval.setClusterer(clusterer);

        if (this.testDataFileName != null) {

            DataSource testSource = new DataSource(this.testDataFileName);

            Instances tmpTestData = testSource.getDataSet();
            tmpTestData.setClassIndex(tmpTestData.numAttributes() - 1);
            //testSource.

            // First filter the attributes based on chromosome
            Instances testData = this.filterByChromosome(tmpTestData, solution);
            //String[] options = new String[2];
            //options[0] = "-t";
            //options[1] = "/some/where/somefile.arff";
            //eval.
            //System.out.println(eval.evaluateClusterer(testData, options));
            eval.evaluateClusterer(testData);
            System.out.println("\nCluster evluation for this solution: " + eval.clusterResultsToString());
        }

        // Print the cluster assignments:

        // save the cluster assignments
        //if (printClusterAssignment==true){
        int[] clusterAssignment = new int[dataClusterer.numInstances()];
        int classOneCnt = 0;
        int classTwoCnt = 0;
        for (int i = 0; i < dataClusterer.numInstances(); ++i) {
            clusterAssignment[i] = clusterer.clusterInstance(dataClusterer.get(i));
            if (clusterAssignment[i] == 0) {
                ++classOneCnt;
            } else if (clusterAssignment[i] == 1) {
                ++classTwoCnt;
            }
            //System.out.println("Instance " + i + ": " + clusterAssignment[i]);
        }

        System.out.println("Class 1 cnt: " + classOneCnt + " Class 2 cnt: " + classTwoCnt);
        //}

        /*
                
                         
                 // create arrays with time (event occurrence time) and censor data for use with jstat LogRankTest
                 double[] time1 = new double[classOneCnt];   
                 double[] censor1 = new double[classOneCnt];
                 double[] time2 = new double[classTwoCnt];
                 double[] censor2 = new double[classTwoCnt];
                
                
                 //data = source.getDataSet();
                 for (int i=0, cnt1=0, cnt2=0; i<dataClusterer.numInstances(); ++i){
                    clusterAssignment[i] = clusterer.clusterInstance(dataClusterer.get(i));
                    if (clusterAssignment[i]==0){
                       time1[cnt1] = data.get(i).value(attTime);
                       censor1[cnt1++] = 1;
                       //System.out.println("i: " + i + " T: " + time1[cnt1-1]);
                    }
                    else if (clusterAssignment[i]==1){
                       time2[cnt2] = data.get(i).value(attTime);
                       //System.out.println("i: " + i + " T: " + time2[cnt2-1]);
                       censor2[cnt2++] = 1;
                    }
                    //System.out.println("Instance " + i + ": " + clusterAssignment[i]);
                 }
                
                
                
                 //Instances[] classInstances = separateClassInstances(clusterAssignment, this.dataFileName,solution);
                 //System.out.println("Class instances seperated");
                
                 // calculate log rank test and p values
                         
                 //LogRankTest testclass1 = new LogRankTest(time1, censor1, time2, censor2);
                 //testStatistic = testclass1.testStatistic;
                 //pValue = testclass1.pValue;
                
                
                 WilcoxonTest testclass1 = new WilcoxonTest(time1, censor1, time2, censor2);
                 testStatistic = testclass1.testStatistic;
                 pValue = testclass1.pValue;true
        */

        // Build R vectors (time1, censor1, group1) element by element; the
        // last element is appended separately to close each vector with ')'.
        String strT = "time1 <- c(";
        String strC = "censor1 <- c(";
        String strG = "group1 <- c(";

        for (int i = 0; i < dataClusterer.numInstances() - 1; ++i) {
            strT = strT + (int) data.get(i).value(attTime) + ",";
            strG = strG + clusterer.clusterInstance(dataClusterer.get(i)) + ",";
            strC = strC + (int) data.get(i).value(attCensor) + ",";

        }

        int tmpi = dataClusterer.numInstances() - 1;
        strT = strT + (int) data.get(tmpi).value(attTime) + ")";
        strG = strG + clusterer.clusterInstance(dataClusterer.get(tmpi)) + ")";
        strC = strC + (int) data.get(tmpi).value(attCensor) + ")";

        this.re.eval(strT);
        this.re.eval(strC);
        this.re.eval(strG);

        // for MyLogRankTest

        double[] time1 = new double[classOneCnt];
        double[] time2 = new double[classTwoCnt];
        double[] censor1 = new double[classOneCnt];
        double[] censor2 = new double[classTwoCnt];

        int i1 = 0, i2 = 0;

        for (int i = 0; i < dataClusterer.numInstances(); ++i) {

            // NOTE(review): these three appends extend strT/strG/strC after
            // the vectors were already closed with ')' and evaluated in R;
            // the appended text is never used again — dead leftover code.
            strT = strT + (int) data.get(i).value(attTime) + ",";
            strG = strG + clusterer.clusterInstance(dataClusterer.get(i)) + ",";
            strC = strC + (int) data.get(i).value(attCensor) + ",";

            if (clusterer.clusterInstance(dataClusterer.get(i)) == 0) {
                time1[i1] = data.get(i).value(attTime);
                censor1[i1] = data.get(i).value(attCensor);
                ++i1;
            } else {
                time2[i2] = data.get(i).value(attTime);
                censor2[i2] = data.get(i).value(attCensor);
                ++i2;
            }

        }

        /** If you are calling surv_test from coin library */
        /*v
        re.eval("library(coin)");
        re.eval("grp <- factor (group)");
        re.eval("result <- surv_test(Surv(time,censor)~grp,distribution=\"exact\")");
                
        x=re.eval("statistic(result)");
        testStatistic = x.asDouble();
        //x=re.eval("pvalue(result)");
        //pValue = x.asDouble();
        //System.out.println("StatScore: " + statScore + "pValue: " + pValue);
        */

        /** If you are calling survdiff from survival library (much faster) */
        re.eval("library(survival)");
        re.eval("res21 <- survdiff(Surv(time1,censor1)~group1,rho=0)");
        x = re.eval("res21$chisq");
        testStatistic = x.asDouble();
        //System.out.println(x);
        x = re.eval("pchisq(res21$chisq, df=1, lower.tail = FALSE)");
        //x = re.eval("1.0 - pchisq(res2$chisq, df=1)");
        pValue = x.asDouble();
        System.out.println("Results from R:");
        System.out.println("StatScore: " + testStatistic + "  pValue: " + pValue);

        // Render the Kaplan-Meier survival plot for both groups plus the whole
        // cohort to a per-solution JPEG via R's survfit/plot.
        re.eval("timestrata1.surv <- survfit( Surv(time1, censor1)~ strata(group1), conf.type=\"log-log\")");
        re.eval("timestrata1.surv1 <- survfit( Surv(time1, censor1)~ 1, conf.type=\"none\")");
        String evalStr = "jpeg('SurvivalPlot-" + this.SolutionID + ".jpg')";
        re.eval(evalStr);
        re.eval("plot(timestrata1.surv, col=c(2,3), xlab=\"Time\", ylab=\"Survival Probability\")");
        re.eval("par(new=T)");
        re.eval("plot(timestrata1.surv1,col=1)");
        re.eval("legend(0.2, c(\"Group1\",\"Group2\",\"Whole\"))");
        re.eval("dev.off()");

        System.out.println("Results from my code: ");
        LogRankTest lrt = new LogRankTest(time1, time2, censor1, censor2);
        double[] results = lrt.logRank();
        System.out.println("Statistics: " + results[0] + " variance: " + results[1] + " pValue: " + results[2]);

    } catch (Exception e) {
        // NOTE(review): every failure (not only file I/O) lands here; the
        // message is misleading and System.exit kills the whole JVM.
        System.err.println("Can't open the data file.");
        e.printStackTrace();
        System.exit(1);
    }

    /**********
     *  Current Implementation considers two objectives
     *  1. pvalue to be minimized / statistical score to be maximized
     *  2. Number of Features to be maximized/minimized
     */

}

From source file:lfsom.data.LFSData.java

License:Apache License

/**
 * Loads the data from a file (CSV by default; ARFF, JSON, Matlab, XRFF or
 * C4.5 are selected by file extension).
 * 
 * @param fileName path of the data file to load
 */

public LFSData(String fileName) {
    // Choose the weka loader implementation matching the file extension;
    // CSV is the fallback when no known extension matches.
    Class<? extends AbstractFileLoader> claseCargador = CSVLoader.class;

    if (fileName.endsWith(ArffLoader.FILE_EXTENSION)) {
        claseCargador = ArffLoader.class;
    } else if (fileName.endsWith(JSONLoader.FILE_EXTENSION)) {
        claseCargador = JSONLoader.class;
    } else if (fileName.endsWith(MatlabLoader.FILE_EXTENSION)) {
        claseCargador = MatlabLoader.class;
    } else if (fileName.endsWith(XRFFLoader.FILE_EXTENSION)) {
        claseCargador = XRFFLoader.class;
    } else if (fileName.endsWith(C45Loader.FILE_EXTENSION)) {
        claseCargador = C45Loader.class;
    }

    try {
        // Bounded type removes the raw-type cast the original needed.
        AbstractFileLoader cargador = claseCargador.getConstructor().newInstance();
        boolean cambio_col = false;

        cargador.setSource(new File(fileName));

        Instances data1 = cargador.getDataSet();

        // Raw data matrix: one row per instance, one column per attribute.
        double[][] matrix2 = new double[data1.size()][data1.numAttributes()];

        for (int i = 0; i < data1.size(); i++) {
            matrix2[i] = data1.get(i).toDoubleArray();
        }

        // Keep only columns with non-negligible standard deviation (constant
        // columns carry no information). colVale maps each original column to
        // its index in the reduced matrix, or -1 if the column is dropped.
        Integer[] colVale;
        dim = 0;

        if (data1.size() > 0) {
            colVale = new Integer[matrix2[0].length];
            double[] stdevX = StatisticSample.stddeviation(matrix2);

            for (int k = 0; k < matrix2[0].length; k++) {
                if (Math.abs(stdevX[k]) >= 0.000000001) {
                    colVale[k] = dim;
                    dim++;
                } else {
                    colVale[k] = -1;
                    cambio_col = true;
                }
            }

        } else {
            // Empty data set: keep every column as-is.
            dim = data1.numAttributes();
            colVale = new Integer[dim];
            for (int k = 0; k < dim; k++) {
                colVale[k] = k;
            }
        }

        double[][] matrixAssign = new double[matrix2.length][dim];

        if (cambio_col) {
            // Copy the surviving columns into their compacted positions.
            for (int k = 0; k < matrix2.length; k++) {
                for (int w = 0; w < matrix2[0].length; w++) {
                    if (colVale[w] != -1) {
                        matrixAssign[k][colVale[w]] = matrix2[k][w];
                    }
                }
            }
        } else {
            matrixAssign = matrix2;
        }

        // Attribute names for the surviving columns.
        setLabels(new String[dim]);
        for (int i = 0; i < data1.numAttributes(); i++) {
            if (colVale[i] != -1) {
                getLabels()[colVale[i]] = data1.attribute(i).name();
            }
        }

        // NOTE(review): debug dump to a hard-coded Windows path; this fails on
        // any machine without d:/tmp. Consider removing or making configurable.
        // try-with-resources ensures the writer is closed even if write fails;
        // StringBuilder avoids the O(n^2) string concatenation of the original.
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < matrixAssign.length; i++) {
            sb.append(matrixAssign[i][0]);
            for (int k = 1; k < matrixAssign[i].length; k++) {
                sb.append(',').append(matrixAssign[i][k]);
            }
            sb.append('\n');
        }
        try (BufferedWriter br = new BufferedWriter(new FileWriter("d:/tmp/fich.csv"))) {
            br.write(sb.toString());
        }

        setMatrix(matrixAssign);

    } catch (Exception e) {
        e.printStackTrace();
        System.exit(1);
    }
}

From source file:lu.lippmann.cdb.lab.kmeans.KmeansImproved.java

License:Open Source License

/**
 * Computes the R-squared score of a k-means clustering: the ratio of the
 * size-weighted between-cluster scatter to the total scatter, both measured
 * as squared distances to the grand centroid.
 * 
 * @param kMeans the fitted clusterer
 * @return the R-squared value (between-cluster / total scatter)
 */
private double R2(SimpleKMeans kMeans) {
    // R^2 = (weighted between-cluster scatter) / (total scatter), both taken
    // as squared distances to the grand centroid of the full data set.
    final int numClusters = kMeans.getNumClusters();
    final int[] clusterSizes = kMeans.getClusterSizes();
    final Instances clusterCentroids = kMeans.getClusterCentroids();

    final int n = instances.numInstances();

    // Grand centroid: attribute-wise mean over every instance.
    final double[] mean = new double[instances.numAttributes()];
    for (int row = 0; row < n; row++) {
        final double[] vals = instances.get(row).toDoubleArray();
        for (int col = 0; col < vals.length; col++) {
            mean[col] += vals[col];
        }
    }
    for (int col = 0; col < mean.length; col++) {
        mean[col] = mean[col] / n;
    }

    final Instance grandCentroid = new DenseInstance(1.0, mean);

    // Total scatter: mean squared distance of each instance to the grand centroid.
    double total = 0;
    for (int row = 0; row < n; row++) {
        total += Math.pow(distance.distance(instances.instance(row), grandCentroid), 2);
    }
    total = total / n;

    // Between-cluster scatter: size-weighted squared distance of every
    // cluster centroid to the grand centroid.
    double inter = 0;
    for (int c = 0; c < numClusters; c++) {
        final double weight = (0.0 + clusterSizes[c]) / n;
        inter += weight * Math.pow(distance.distance(clusterCentroids.get(c), grandCentroid), 2);
    }

    return (inter / total);
}

From source file:lu.lippmann.cdb.lab.mds.ClassicMDS.java

License:Open Source License

/**
 * Computes the pairwise distance matrix between instances (or between
 * k-means centroids when the data set is collapsed to at most maxInstances).
 */
public static CollapsedInstances distanceBetweenInstances(final Instances instances,
        final MDSDistancesEnum distEnum, final int maxInstances, final boolean ignoreClassInDistance)
        throws Exception {
    // Select the weka distance function matching the requested metric.
    final NormalizableDistance dist;
    if (distEnum.equals(MDSDistancesEnum.EUCLIDEAN)) {
        dist = new EuclideanDistance(instances);
        //dist.setDontNormalize(true);
        //dist.setAttributeIndices("1");
        //dist.setInvertSelection(true);
    } else if (distEnum.equals(MDSDistancesEnum.MANHATTAN)) {
        dist = new ManhattanDistance(instances);
    } else if (distEnum.equals(MDSDistancesEnum.MINKOWSKI)) {
        final MinkowskiDistance minkowski = new MinkowskiDistance(instances);
        // The order of the Minkowski norm is carried in the enum's parameters.
        final String[] parameters = MDSDistancesEnum.MINKOWSKI.getParameters();
        minkowski.setOrder(Double.parseDouble(parameters[0]));
        dist = minkowski;
    } else if (distEnum.equals(MDSDistancesEnum.CHEBYSHEV)) {
        dist = new ChebyshevDistance(instances);
    //} else if (distEnum.equals(MDSDistancesEnum.DT)) { dist = new DTDistance(instances);
    } else {
        throw new IllegalStateException();
    }

    // Exclude the class attribute from the distance computation if requested
    // (weka attribute index strings are 1-based, hence the +1).
    if (ignoreClassInDistance && instances.classIndex() != -1) {
        dist.setAttributeIndices("" + (instances.classIndex() + 1));
        dist.setInvertSelection(true);
    }

    final int numInstances = instances.numInstances();
    final boolean collapsed = (numInstances > maxInstances)
            && (distEnum.equals(MDSDistancesEnum.EUCLIDEAN) || distEnum.equals(MDSDistancesEnum.MANHATTAN));

    KmeansResult mapCentroids = null;
    final SimpleMatrix distances;

    if (collapsed) {
        // Too many instances: collapse to at most maxInstances k-means
        // centroids and compute distances between centroids instead.
        mapCentroids = getSimplifiedInstances(instances, dist, maxInstances);
        final List<Instance> centroids = mapCentroids.getCentroids();
        final int size = centroids.size();

        distances = new SimpleMatrix(size, size);
        for (int i = 0; i < size; i++) {
            for (int j = i + 1; j < size; j++) {
                final double d = dist.distance(centroids.get(i), centroids.get(j));
                distances.set(i, j, d);
                distances.set(j, i, d); // the matrix is symmetric
            }
        }
    } else {
        distances = new SimpleMatrix(numInstances, numInstances);
        for (int i = 0; i < numInstances; i++) {
            for (int j = i + 1; j < numInstances; j++) {
                final double d = dist.distance(instances.get(i), instances.get(j));
                distances.set(i, j, d);
                distances.set(j, i, d); // the matrix is symmetric
            }
        }
    }
    return new CollapsedInstances(instances, mapCentroids, distances, collapsed);
}

From source file:lu.lippmann.cdb.lab.mds.MDSViewBuilder.java

License:Open Source License

/**
 * Builds a scatter-plot panel of the MDS projection, with a button that
 * forwards the instances inside the current zoom range to the listener.
 */
public static JXPanel buildMDSViewFromDataSet(final Instances instances, final MDSResult mdsResult,
        final int maxInstances, final Listener<Instances> listener, final String... attrNameToUseAsPointTitle)
        throws Exception {

    final XYSeriesCollection dataset = new XYSeriesCollection();

    // Scatter plot of the 2D projection; the legend is shown only when no
    // point-title attribute was supplied.
    final JFreeChart chart = ChartFactory.createScatterPlot("", // title 
            "X", "Y", // axis labels 
            dataset, // dataset 
            PlotOrientation.VERTICAL, attrNameToUseAsPointTitle.length == 0, // legend? 
            true, // tooltips? yes 
            false // URLs? no 
    );

    final XYPlot xyPlot = (XYPlot) chart.getPlot();

    // Axis tick labels are meaningless for an MDS projection; hide them.
    xyPlot.setBackgroundPaint(Color.WHITE);
    xyPlot.getDomainAxis().setTickLabelsVisible(false);
    xyPlot.getRangeAxis().setTickLabelsVisible(false);

    // Title carries the Kruskal stress when the result is not normalized.
    //FIXME : should be different for Shih
    if (!mdsResult.isNormalized()) {
        String stress = FormatterUtil.DECIMAL_FORMAT
                .format(ClassicMDS.getKruskalStressFromMDSResult(mdsResult));
        chart.setTitle(mdsResult.getCInstances().isCollapsed()
                ? "Collapsed MDS(Instances=" + maxInstances + ",Stress=" + stress + ")"
                : "MDS(Stress=" + stress + ")");
    } else {
        chart.setTitle(mdsResult.getCInstances().isCollapsed() ? "Collapsed MDS(Instances=" + maxInstances + ")"
                : "MDS");
    }

    final SimpleMatrix coordinates = mdsResult.getCoordinates();
    buildFilteredSeries(mdsResult, xyPlot, attrNameToUseAsPointTitle);

    final ChartPanel chartPanel = new ChartPanel(chart);
    chartPanel.setMouseWheelEnabled(true);
    chartPanel.setPreferredSize(new Dimension(1200, 900));
    chartPanel.setBorder(new TitledBorder("MDS Projection"));
    chartPanel.setBackground(Color.WHITE);

    // "Select data": collect every instance whose projected point lies inside
    // the currently visible axis ranges and hand them to the listener.
    final JButton selectionButton = new JButton("Select data");
    selectionButton.addActionListener(new ActionListener() {

        @Override
        public void actionPerformed(ActionEvent e) {
            final org.jfree.data.Range XDomainRange = xyPlot.getDomainAxis().getRange();
            final org.jfree.data.Range YDomainRange = xyPlot.getRangeAxis().getRange();
            final Instances cInstances = mdsResult.getCollapsedInstances();
            final Instances selectedInstances = new Instances(cInstances, 0);
            List<Instances> clusters = null;
            if (mdsResult.getCInstances().isCollapsed()) {
                clusters = mdsResult.getCInstances().getCentroidMap().getClusters();
            }
            for (int i = 0; i < cInstances.numInstances(); i++) {
                // NOTE(review): iterates over cInstances but indexes the
                // original 'instances' here — looks suspicious when the data
                // is collapsed (row counts may differ); confirm intent.
                final Instance centroid = instances.instance(i);
                if (XDomainRange.contains(coordinates.get(i, 0))
                        && YDomainRange.contains(coordinates.get(i, 1))) {
                    if (mdsResult.getCInstances().isCollapsed()) {
                        // Collapsed view: a selected point stands for a whole
                        // cluster — add every member of that cluster.
                        if (clusters != null) {
                            final Instances elementsOfCluster = clusters.get(i);
                            final int nbElements = elementsOfCluster.numInstances();
                            for (int k = 0; k < nbElements; k++) {
                                selectedInstances.add(elementsOfCluster.get(k));
                            }
                        }
                    } else {
                        selectedInstances.add(centroid);
                    }
                }
            }
            if (listener != null) {
                listener.onAction(selectedInstances);
            }
        }
    });

    final JXPanel allPanel = new JXPanel();
    allPanel.setLayout(new BorderLayout());
    allPanel.add(chartPanel, BorderLayout.CENTER);
    final JXPanel southPanel = new JXPanel();
    southPanel.add(selectionButton);
    allPanel.add(southPanel, BorderLayout.SOUTH);
    return allPanel;
}

From source file:machinelearningcw.EnhancedLinearPerceptron.java

/**
 * Computes the per-attribute mean and standard deviation over the whole
 * dataset and stores them in the {@code means} and {@code std} fields.
 * The last attribute is excluded (presumably the class attribute — confirm
 * against the enclosing classifier's convention).
 *
 * @param instances the dataset whose attribute statistics are computed
 */
public void calculateMeansAndSTDev(Instances instances) {
    final int numAttrs = instances.numAttributes() - 1;
    means = new double[numAttrs];
    std = new double[numAttrs];
    for (int attr = 0; attr < numAttrs; attr++) {
        final Stats stats = new Stats();
        for (int row = 0; row < instances.numInstances(); row++) {
            stats.add(instances.get(row).value(attr));
        }
        stats.calculateDerived(); // populates stats.mean and stats.stdDev
        means[attr] = stats.mean;
        std[attr] = stats.stdDev;
    }
}

From source file:machinelearningproject.RFTree.java

/**
 * Produces a bootstrap sample of the dataset: a new dataset of the same
 * size, drawn uniformly at random with replacement.
 *
 * @param instances the source dataset
 * @return a new {@link Instances} holding {@code instances.numInstances()}
 *         randomly drawn rows (duplicates possible); empty input yields an
 *         empty sample
 */
public Instances bootstrap(Instances instances) {
    final int n = instances.numInstances();
    Instances randomInstances = new Instances(instances, n);

    // Fix: reuse one Random instead of constructing a new one per draw.
    // Allocating java.util.Random inside the loop is wasteful and, because
    // each instance is freshly seeded, can reduce the diversity of the
    // sample on fast iterations.
    final Random random = new Random();
    for (int i = 0; i < n; i++) {
        randomInstances.add(instances.get(random.nextInt(n)));
    }

    return randomInstances;
}

From source file:machinelearningproject.RFTree.java

/**
 * Recursively induces an ID3-style decision tree over the given dataset,
 * splitting on the attribute with the largest information gain.
 *
 * @param instances training data; its class index must already be set
 * @return the root of the induced subtree, or {@code null} for an empty dataset
 * @throws Exception propagated from the entropy / information-gain helpers
 */
@Override
public Tree buildTree(Instances instances) throws Exception {
    Tree tree = new Tree();
    ArrayList<String> availableAttributes = new ArrayList();
    int largestInfoGainAttrIdx = -1;
    double largestInfoGainAttrValue = 0.0;

    // Random-forest style feature sub-sampling: k = round(sqrt(#attributes)).
    int numAttr = instances.numAttributes();
    int k = (int) round(sqrt(numAttr));
    ArrayList<Integer> randomIdx = randomFraction(numAttr);

    // NOTE(review): randomIdx is computed but never used — the loop below
    // collects the FIRST k attribute indices rather than the k randomly
    // sampled ones. Likely intended:
    // availableAttributes.add(instances.attribute(randomIdx.get(idx)).name());
    // Confirm against randomFraction()'s contract before changing.
    for (int idx = 0; idx < k; idx++) {
        if (idx != instances.classIndex()) {
            availableAttributes.add(instances.attribute(idx).name());
        }
    }

    if (instances.numInstances() == 0) {
        return null;
    } else if (calculateClassEntropy(instances) == 0.0) {
        // Zero entropy: every example has the same class — make a leaf with it.
        tree.attributeName = instances.get(0).stringValue(instances.classIndex());
    } else if (availableAttributes.isEmpty()) {
        // No attributes left to split on — leaf labelled with the mode class.
        tree.attributeName = getModeClass(instances, instances.classIndex());
    } else {
        // NOTE(review): information gain is evaluated over ALL attributes,
        // not just the k collected in availableAttributes above — verify
        // whether the sub-sampled set was meant to constrain this search.
        for (int idx = 0; idx < instances.numAttributes(); idx++) {
            if (idx != instances.classIndex()) {
                double attrInfoGain = calculateInformationGain(instances, idx, instances.classIndex());
                if (largestInfoGainAttrValue < attrInfoGain) {
                    largestInfoGainAttrIdx = idx;
                    largestInfoGainAttrValue = attrInfoGain;
                }
            }
        }

        if (largestInfoGainAttrIdx != -1) {
            tree.attributeName = instances.attribute(largestInfoGainAttrIdx).name();
            // Collect the distinct (string) values of the winning attribute.
            ArrayList<String> attrValues = new ArrayList();
            for (int i = 0; i < instances.numInstances(); i++) {
                Instance instance = instances.get(i);
                String attrValue = instance.stringValue(largestInfoGainAttrIdx);
                if (attrValues.isEmpty() || !attrValues.contains(attrValue)) {
                    attrValues.add(attrValue);
                }
            }

            // One child per attribute value: keep only the matching rows,
            // drop the split attribute, and recurse on the reduced dataset.
            for (String attrValue : attrValues) {
                Node node = new Node(attrValue);
                Instances copyInstances = new Instances(instances);
                copyInstances.setClassIndex(instances.classIndex());
                int i = 0;
                while (i < copyInstances.numInstances()) {
                    Instance instance = copyInstances.get(i);
                    // Delete rows whose value differs; the i-- followed by the
                    // unconditional i++ re-examines the index a deletion
                    // shifted a new element into.
                    if (!instance.stringValue(largestInfoGainAttrIdx).equals(attrValue)) {
                        copyInstances.delete(i);
                        i--;
                    }
                    i++;
                }
                copyInstances.deleteAttributeAt(largestInfoGainAttrIdx);
                node.subTree = buildTree(copyInstances);
                tree.nodes.add(node);
            }
        }
    }

    return tree;
}

From source file:machinelearningproject.Tree.java

/**
 * Returns the most frequent class value (the mode) among the given instances.
 *
 * @param instances the dataset to scan
 * @param classIdx  index of the (string-valued) class attribute
 * @return the class value with the highest frequency; ties are broken by
 *         {@link HashMap} iteration order (as in the original); an empty
 *         dataset yields the empty string
 */
public String getModeClass(Instances instances, int classIdx) {
    HashMap<String, Integer> classMap = new HashMap<>();
    int numInstances = instances.size();

    // merge() replaces the original redundant isEmpty()/containsKey dance:
    // insert 1 on first sight, otherwise add 1 to the existing count.
    for (int i = 0; i < numInstances; i++) {
        classMap.merge(instances.get(i).stringValue(classIdx), 1, Integer::sum);
    }

    // Fix: removed the leftover debug System.out.println that printed every
    // class count on each call.
    String modeClass = "";
    int count = 0;
    for (String key : classMap.keySet()) {
        int freq = classMap.get(key);
        if (count < freq) {
            modeClass = key;
            count = freq;
        }
    }
    return modeClass;
}