Example usage for weka.filters.unsupervised.attribute Remove setAttributeIndicesArray

List of usage examples for weka.filters.unsupervised.attribute Remove setAttributeIndicesArray

Introduction

In this page you can find the example usage for weka.filters.unsupervised.attribute Remove setAttributeIndicesArray.

Prototype

public void setAttributeIndicesArray(int[] attributes) 

Source Link

Document

Set which attributes are to be deleted (or, if invert selection is enabled, which attributes are to be kept).

Usage

From source file:WrapperSubset.java

License:Open Source License

/**
 * Evaluates a subset of attributes//  ww  w.  j  av  a2  s.c o m
 *
 * @param subset a bitset representing the attribute subset to be evaluated
 * @return the error rate
 * @throws Exception if the subset could not be evaluated
 */
@Override
public double evaluateSubset(BitSet subset) throws Exception {

    //        if (subset.isEmpty())
    //            return 0.0;

    double evalMetric = 0;
    double[] repError = new double[5];
    int numAttributes = 0;
    int i, j;
    Random Rnd = new Random(m_seed);
    Remove delTransform = new Remove();
    delTransform.setInvertSelection(true);
    // copy the instances
    Instances trainCopy = new Instances(m_trainInstances);

    // count attributes set in the BitSet
    for (i = 0; i < m_numAttribs; i++) {
        if (subset.get(i)) {
            numAttributes++;
        }
    }

    // set up an array of attribute indexes for the filter (+1 for the class)
    int[] featArray = new int[numAttributes + 1];

    for (i = 0, j = 0; i < m_numAttribs; i++) {
        if (subset.get(i)) {
            featArray[j++] = i;
        }
    }

    featArray[j] = m_classIndex;
    delTransform.setAttributeIndicesArray(featArray);
    delTransform.setInputFormat(trainCopy);
    trainCopy = Filter.useFilter(trainCopy, delTransform);

    // max of 5 repetitions of cross validation
    for (i = 0; i < 5; i++) {
        m_Evaluation = new Evaluation(trainCopy);
        m_Evaluation.crossValidateModel(m_BaseClassifier, trainCopy, m_folds, Rnd);

        switch (m_evaluationMeasure) {
        case EVAL_DEFAULT:
            repError[i] = m_Evaluation.errorRate();
            //                     if (m_trainInstances.classAttribute().isNominal()) {
            //                     repError[i] = 1.0 - repError[i];
            //                     }
            break;
        case EVAL_ACCURACY:
            repError[i] = m_Evaluation.errorRate();
            //                     if (m_trainInstances.classAttribute().isNominal()) {
            //                     repError[i] = 1.0 - repError[i];
            //                     }
            break;
        case EVAL_RMSE:
            repError[i] = m_Evaluation.rootMeanSquaredError();
            break;
        case EVAL_MAE:
            repError[i] = m_Evaluation.meanAbsoluteError();
            break;
        case EVAL_FMEASURE:
            if (m_IRClassVal < 0) {
                repError[i] = m_Evaluation.weightedFMeasure();
            } else {
                repError[i] = m_Evaluation.fMeasure(m_IRClassVal);
            }
            break;
        case EVAL_AUC:
            if (m_IRClassVal < 0) {
                repError[i] = m_Evaluation.weightedAreaUnderROC();
            } else {
                repError[i] = m_Evaluation.areaUnderROC(m_IRClassVal);
            }
            break;
        case EVAL_AUPRC:
            if (m_IRClassVal < 0) {
                repError[i] = m_Evaluation.weightedAreaUnderPRC();
            } else {
                repError[i] = m_Evaluation.areaUnderPRC(m_IRClassVal);
            }
            break;
        case EVAL_NEW:
            repError[i] = (1.0 - m_Evaluation.errorRate()) + m_IRfactor * m_Evaluation.weightedFMeasure();
            break;
        }

        // check on the standard deviation
        if (!repeat(repError, i + 1)) {
            i++;
            break;
        }
    }

    for (j = 0; j < i; j++) {
        evalMetric += repError[j];
    }

    evalMetric /= i;
    m_Evaluation = null;

    switch (m_evaluationMeasure) {
    case EVAL_DEFAULT:
    case EVAL_ACCURACY:
    case EVAL_RMSE:
    case EVAL_MAE:
        if (m_trainInstances.classAttribute().isNominal()
                && (m_evaluationMeasure == EVAL_DEFAULT || m_evaluationMeasure == EVAL_ACCURACY)) {
            evalMetric = 1 - evalMetric;
        } else {
            evalMetric = -evalMetric; // maximize
        }

        break;
    }

    return evalMetric;
}

From source file:adams.data.instancesanalysis.FastICA.java

License:Open Source License

/**
 * Performs the actual analysis: optionally restricts the data to the
 * configured attribute range, then runs independent component analysis
 * and stores the resulting components and sources as spreadsheets.
 *
 * @param data   the data to analyze
 * @return      null if successful, otherwise error message
 * @throws Exception   if analysis fails
 */
@Override
protected String doAnalyze(Instances data) throws Exception {
    m_Components = null;
    m_Sources = null;

    // keep only the configured attribute range (Remove + inverted selection)
    if (!m_AttributeRange.isAllRange()) {
        if (isLoggingEnabled())
            getLogger().info("Filtering attribute range: " + m_AttributeRange.getRange());
        Remove keepSelected = new Remove();
        keepSelected.setAttributeIndicesArray(m_AttributeRange.getIntIndices());
        keepSelected.setInvertSelection(true);
        keepSelected.setInputFormat(data);
        data = Filter.useFilter(data, keepSelected);
    }

    if (isLoggingEnabled())
        getLogger().info("Performing ICA...");

    String result = null;
    Matrix transformedMatrix = m_ICA.transform(MatrixHelper.wekaToMatrixAlgo(MatrixHelper.getAll(data)));
    if (transformedMatrix == null) {
        result = "Failed to transform data!";
    } else {
        m_Components = MatrixHelper.matrixToSpreadSheet(
                MatrixHelper.matrixAlgoToWeka(m_ICA.getComponents()), "Component-");
        m_Sources = MatrixHelper.matrixToSpreadSheet(
                MatrixHelper.matrixAlgoToWeka(m_ICA.getSources()), "Source-");
    }

    return result;
}

From source file:adams.data.instancesanalysis.PCA.java

License:Open Source License

/**
 * Performs the actual analysis: optionally filters the attribute range,
 * separates supported from unsupported attributes, runs a principal
 * components analysis on the supported ones and stores scores and loadings.
 *
 * @param data   the data to analyze
 * @return      null if successful, otherwise error message
 * @throws Exception   if analysis fails
 */
@Override
protected String doAnalyze(Instances data) throws Exception {
    String result;
    Remove remove;
    PublicPrincipalComponents pca;
    int i;
    Capabilities caps;
    PartitionedMultiFilter2 part;
    Range rangeUnsupported;
    Range rangeSupported;
    TIntList listNominal;
    Range rangeNominal; // NOTE(review): declared but never used
    ArrayList<ArrayList<Double>> coeff;
    Instances filtered;
    SpreadSheet transformed;
    WekaInstancesToSpreadSheet conv;
    String colName;

    result = null;
    m_Loadings = null;
    m_Scores = null;

    // keep only the configured attribute range (Remove + inverted selection)
    if (!m_AttributeRange.isAllRange()) {
        if (isLoggingEnabled())
            getLogger().info("Filtering attribute range: " + m_AttributeRange.getRange());
        remove = new Remove();
        remove.setAttributeIndicesArray(m_AttributeRange.getIntIndices());
        remove.setInvertSelection(true);
        remove.setInputFormat(data);
        data = Filter.useFilter(data, remove);
    }
    if (isLoggingEnabled())
        getLogger().info("Performing PCA...");

    // collect nominal attributes (except the class) if they are to be skipped
    listNominal = new TIntArrayList();
    if (m_SkipNominal) {
        for (i = 0; i < data.numAttributes(); i++) {
            if (i == data.classIndex())
                continue;
            if (data.attribute(i).isNominal())
                listNominal.add(i);
        }
    }

    // check for unsupported attributes: anything PCA can't handle, the class
    // attribute, and (optionally) nominal attributes
    caps = new PublicPrincipalComponents().getCapabilities();
    m_Supported = new TIntArrayList();
    m_Unsupported = new TIntArrayList();
    for (i = 0; i < data.numAttributes(); i++) {
        if (!caps.test(data.attribute(i)) || (i == data.classIndex()) || (listNominal.contains(i)))
            m_Unsupported.add(i);
        else
            m_Supported.add(i);
    }
    data.setClassIndex(-1);

    m_NumAttributes = m_Supported.size();

    // the principal components will delete the attributes without any distinct values.
    // this checks which instances will be kept.
    m_Kept = new ArrayList<>();
    for (i = 0; i < m_Supported.size(); i++) {
        if (data.numDistinctValues(m_Supported.get(i)) > 1)
            m_Kept.add(m_Supported.get(i));
    }

    // build a model using the PublicPrincipalComponents
    pca = new PublicPrincipalComponents();
    pca.setMaximumAttributes(m_MaxAttributes);
    pca.setVarianceCovered(m_Variance);
    pca.setMaximumAttributeNames(m_MaxAttributeNames);
    part = null;
    // when there are unsupported attributes, apply PCA only to the supported
    // range and pass the rest through unchanged (AllFilter)
    if (m_Unsupported.size() > 0) {
        rangeUnsupported = new Range();
        rangeUnsupported.setMax(data.numAttributes());
        rangeUnsupported.setIndices(m_Unsupported.toArray());
        rangeSupported = new Range();
        rangeSupported.setMax(data.numAttributes());
        rangeSupported.setIndices(m_Supported.toArray());
        part = new PartitionedMultiFilter2();
        part.setFilters(new Filter[] { pca, new AllFilter(), });
        part.setRanges(new weka.core.Range[] { new weka.core.Range(rangeSupported.getRange()),
                new weka.core.Range(rangeUnsupported.getRange()), });
    }
    try {
        if (part != null)
            part.setInputFormat(data);
        else
            pca.setInputFormat(data);
    } catch (Exception e) {
        result = Utils.handleException(this, "Failed to set data format", e);
    }

    // apply the filter and convert the result to a spreadsheet
    transformed = null;
    if (result == null) {
        try {
            if (part != null)
                filtered = weka.filters.Filter.useFilter(data, part);
            else
                filtered = weka.filters.Filter.useFilter(data, pca);
        } catch (Exception e) {
            result = Utils.handleException(this, "Failed to apply filter", e);
            filtered = null;
        }
        if (filtered != null) {
            conv = new WekaInstancesToSpreadSheet();
            conv.setInput(filtered);
            result = conv.convert();
            if (result == null) {
                transformed = (SpreadSheet) conv.getOutput();
                // shorten column names again: the partitioned filter prefixes
                // each column with "filtered-<index>-"
                if (part != null) {
                    for (i = 0; i < transformed.getColumnCount(); i++) {
                        colName = transformed.getColumnName(i);
                        colName = colName.replaceFirst("filtered-[0-9]*-", "");
                        transformed.getHeaderRow().getCell(i).setContentAsString(colName);
                    }
                }
            }
        }
    }

    if (result == null) {
        // get the coefficients from the filter
        m_Scores = transformed;
        coeff = pca.getCoefficients();
        m_Loadings = extractLoadings(data, coeff);
        m_Loadings.setName("Loadings for " + data.relationName());
    }

    return result;
}

From source file:adams.data.instancesanalysis.PLS.java

License:Open Source License

/**
 * Performs the actual analysis: removes rows with a missing class value,
 * optionally restricts the attribute range, applies the PLS filter and
 * stores the resulting scores and loadings as spreadsheets.
 *
 * @param data   the data to analyze
 * @return      null if successful, otherwise error message
 * @throws Exception   if analysis fails
 */
@Override
protected String doAnalyze(Instances data) throws Exception {
    m_Loadings = null;
    m_Scores = null;

    // work on a copy, dropping instances without a class value
    data = new Instances(data);
    data.deleteWithMissingClass();

    // keep only the configured attribute range (Remove + inverted selection)
    if (!m_AttributeRange.isAllRange()) {
        if (isLoggingEnabled())
            getLogger().info("Filtering attribute range: " + m_AttributeRange.getRange());
        Remove keepSelected = new Remove();
        keepSelected.setAttributeIndicesArray(m_AttributeRange.getIntIndices());
        keepSelected.setInvertSelection(true);
        keepSelected.setInputFormat(data);
        data = Filter.useFilter(data, keepSelected);
    }
    if (isLoggingEnabled())
        getLogger().info("Performing PLS...");

    weka.filters.supervised.attribute.PLS pls = new weka.filters.supervised.attribute.PLS();
    pls.setAlgorithm(m_Algorithm);
    pls.setInputFormat(data);
    data = Filter.useFilter(data, pls);

    WekaInstancesToSpreadSheet converter = new WekaInstancesToSpreadSheet();
    converter.setInput(data);
    String result = converter.convert();
    if (result == null) {
        // copy the loadings matrix into a spreadsheet, one column per loading
        Matrix matrix = pls.getLoadings();
        SpreadSheet loadings = new DefaultSpreadSheet();
        int numCols = matrix.getColumnDimension();
        for (int col = 0; col < numCols; col++)
            loadings.getHeaderRow().addCell("L-" + (col + 1)).setContentAsString("Loading-" + (col + 1));
        for (int rowIdx = 0; rowIdx < matrix.getRowDimension(); rowIdx++) {
            Row row = loadings.addRow();
            for (int col = 0; col < numCols; col++)
                row.addCell("L-" + (col + 1)).setContent(matrix.get(rowIdx, col));
        }
        m_Loadings = loadings;
        m_Scores = (SpreadSheet) converter.getOutput();
    }

    return result;
}

From source file:adams.flow.transformer.AbstractWekaPredictionsTransformer.java

License:Open Source License

/**
 * Filters the data accordingly to the selected attribute range.
 *
 * @param data   the data to filter/*from www  . ja va  2 s .co  m*/
 * @return      the filtered data, null if filtering failed
 */
protected Instances filterTestData(Instances data) {
    int[] indices;
    Remove remove;

    try {
        m_TestAttributes.setMax(data.numAttributes());
        indices = m_TestAttributes.getIntIndices();
        remove = new Remove();
        remove.setAttributeIndicesArray(indices);
        remove.setInvertSelection(true);
        remove.setInputFormat(data);

        return Filter.useFilter(data, remove);
    } catch (Exception e) {
        getLogger().log(Level.SEVERE, "Failed to filter test data using range: " + m_TestAttributes, e);
        return null;
    }
}

From source file:adams.flow.transformer.WekaInstancesMerge.java

License:Open Source License

/**
 * Executes the flow item: merges the incoming datasets (from filenames,
 * files, single instances or full datasets) either side-by-side or keyed
 * on a unique ID attribute, and forwards the merged dataset as a token.
 *
 * @return      null if everything is fine, otherwise error message
 */
@Override
protected String doExecute() {
    String result;
    String[] filesStr;
    File[] files;
    int i;
    Instances output;
    Instances[] orig;
    Instances[] inst;
    Instance[] rows;
    HashSet ids;
    int max;
    TIntList uniqueList;
    Remove remove;

    result = null;

    // get filenames: the input token may carry file names, files,
    // single rows or whole datasets
    files = null;
    orig = null;
    if (m_InputToken.getPayload() instanceof String[]) {
        filesStr = (String[]) m_InputToken.getPayload();
        files = new File[filesStr.length];
        for (i = 0; i < filesStr.length; i++)
            files[i] = new PlaceholderFile(filesStr[i]);
    } else if (m_InputToken.getPayload() instanceof File[]) {
        files = (File[]) m_InputToken.getPayload();
    } else if (m_InputToken.getPayload() instanceof Instance[]) {
        // wrap each single row in a one-row dataset with its own header
        rows = (Instance[]) m_InputToken.getPayload();
        orig = new Instances[rows.length];
        for (i = 0; i < rows.length; i++) {
            orig[i] = new Instances(rows[i].dataset(), 1);
            orig[i].add((Instance) rows[i].copy());
        }
    } else if (m_InputToken.getPayload() instanceof Instances[]) {
        orig = (Instances[]) m_InputToken.getPayload();
    } else {
        throw new IllegalStateException("Unhandled input type: " + m_InputToken.getPayload().getClass());
    }

    try {
        output = null;

        // simple merge: no unique ID configured, datasets are merged
        // side-by-side in order
        if (m_UniqueID.length() == 0) {
            if (files != null) {
                inst = new Instances[1];
                for (i = 0; i < files.length; i++) {
                    if (isStopped())
                        break;
                    inst[0] = DataSource.read(files[i].getAbsolutePath());
                    inst[0] = prepareData(inst[0], i);
                    if (i == 0) {
                        output = inst[0];
                    } else {
                        if (isLoggingEnabled())
                            getLogger().info("Merging with file #" + (i + 1) + ": " + files[i]);
                        output = Instances.mergeInstances(output, inst[0]);
                    }
                }
            } else if (orig != null) {
                inst = new Instances[1];
                for (i = 0; i < orig.length; i++) {
                    if (isStopped())
                        break;
                    inst[0] = prepareData(orig[i], i);
                    if (i == 0) {
                        output = inst[0];
                    } else {
                        if (isLoggingEnabled())
                            getLogger()
                                    .info("Merging with dataset #" + (i + 1) + ": " + orig[i].relationName());
                        output = Instances.mergeInstances(output, inst[0]);
                    }
                }
            }
        }
        // merge based on row IDs
        else {
            m_AttType = -1;
            max = 0;
            m_UniqueIDAtts = new ArrayList<>();
            // load all datasets first and track the largest row count
            // (used to presize the ID set)
            if (files != null) {
                orig = new Instances[files.length];
                for (i = 0; i < files.length; i++) {
                    if (isStopped())
                        break;
                    if (isLoggingEnabled())
                        getLogger().info("Loading file #" + (i + 1) + ": " + files[i]);
                    orig[i] = DataSource.read(files[i].getAbsolutePath());
                    max = Math.max(max, orig[i].numInstances());
                }
            } else if (orig != null) {
                for (i = 0; i < orig.length; i++)
                    max = Math.max(max, orig[i].numInstances());
            }
            inst = new Instances[orig.length];
            ids = new HashSet(max);
            for (i = 0; i < orig.length; i++) {
                if (isStopped())
                    break;
                if (isLoggingEnabled())
                    getLogger().info("Updating IDs #" + (i + 1));
                updateIDs(i, orig[i], ids);
                if (isLoggingEnabled())
                    getLogger().info("Preparing dataset #" + (i + 1));
                inst[i] = prepareData(orig[i], i);
            }
            output = merge(orig, inst, ids);
            // remove unnecessary unique ID attributes: after merging, every
            // dataset contributed its own copy of the ID column
            if (m_KeepOnlySingleUniqueID) {
                uniqueList = new TIntArrayList();
                for (String att : m_UniqueIDAtts)
                    uniqueList.add(output.attribute(att).index());
                if (uniqueList.size() > 0) {
                    if (isLoggingEnabled())
                        getLogger().info("Removing duplicate unique ID attributes: " + m_UniqueIDAtts);
                    remove = new Remove();
                    remove.setAttributeIndicesArray(uniqueList.toArray());
                    remove.setInputFormat(output);
                    output = Filter.useFilter(output, remove);
                }
            }
        }

        // only forward a token when the flow wasn't stopped mid-merge
        if (!isStopped()) {
            m_OutputToken = new Token(output);
            updateProvenance(m_OutputToken);
        }
    } catch (Exception e) {
        result = handleException("Failed to merge: ", e);
    }

    return result;
}

From source file:adams.ml.data.InstancesView.java

License:Open Source License

/**
 * Returns a spreadsheet containing only output columns, i.e., the class
 * columns.
 *
 * @return      the output features, null if data has no class columns
 */
@Override
public SpreadSheet getOutputs() {
    if (m_Data.classIndex() == -1)
        return null;

    // copy the data and clear the class index so the class attribute
    // can be filtered like any other attribute
    Instances copy = new Instances(m_Data);
    copy.setClassIndex(-1);

    // inverted selection keeps only the (former) class attribute
    Remove keepClassOnly = new Remove();
    keepClassOnly.setAttributeIndicesArray(new int[] { m_Data.classIndex() });
    keepClassOnly.setInvertSelection(true);
    try {
        keepClassOnly.setInputFormat(copy);
        return new InstancesView(Filter.useFilter(copy, keepClassOnly));
    } catch (Exception e) {
        throw new IllegalStateException("Failed to apply Remove filter!", e);
    }
}

From source file:app.RunApp.java

License:Open Source License

/**
 * Load dataset in principal tab: parses the ARFF header (including the
 * multi-view "-V:" specification and Meka-style label declarations),
 * generates an XML label file if needed, loads the dataset and fills the
 * per-view metric tables.
 *
 * @param returnVal the state returned by the file chooser dialog; loading
 *                  only proceeds if it equals JFileChooser.OPEN_DIALOG
 *                  (NOTE(review): usually APPROVE_OPTION is compared here —
 *                  confirm this is intended)
 * @param fileChooser Chooser
 * @param deleteXML Boolean indicating if the generated xml must be removed
 * @return 1 on success, -1 on failure
 */
private int loadDataset(int returnVal, JFileChooser fileChooser, boolean deleteXML) {
    if (returnVal == JFileChooser.OPEN_DIALOG) {
        File f1 = fileChooser.getSelectedFile();
        datasetName = f1.getName();
        // strip the ".arff" extension (5 characters)
        datasetCurrentName = datasetName.substring(0, datasetName.length() - 5);

        String arffFilename = f1.getAbsolutePath();

        // companion XML file with the label definitions
        xmlPath = arffFilename.substring(0, arffFilename.length() - 5) + ".xml";
        xmlFilename = DataIOUtils.getFileName(xmlPath);

        File fileTmp = new File(xmlPath);

        FileReader fr;
        try {
            views.clear();
            ((DefaultTableModel) jTable2.getModel()).getDataVector().removeAllElements();
            ((DefaultTableModel) jTable3.getModel()).getDataVector().removeAllElements();

            fr = new FileReader(arffFilename);
            BufferedReader bf = new BufferedReader(fr);

            String sString = bf.readLine();

            // multi-view dataset: the relation name carries a "-V:" spec like
            // "-V:1-10!11-20" where "!" separates views and each view is a
            // comma/dash list of attribute indices
            if (sString.contains("-V:")) {
                mv = true;

                TabPrincipal.setEnabledAt(7, true);
                String s2 = sString.split("'")[1];
                s2 = s2.split("-V:")[1];
                String[] intervals = s2.split("!");
                Vector<Vector<Integer>> newIntervals = new Vector<>();
                int[] intervalsSize = new int[intervals.length];
                int max = Integer.MIN_VALUE;
                int min = Integer.MAX_VALUE;
                double mean = 0;

                // expand each view's interval spec into explicit indices
                for (int i = 0; i < intervals.length; i++) {
                    newIntervals.add(new Vector<Integer>());
                    String[] aux2;

                    viewsIntervals.put("View " + (i + 1), intervals[i]);

                    if (intervals[i].contains(",")) {
                        aux2 = intervals[i].split(",");
                        for (int j = 0; j < aux2.length; j++) {
                            if (aux2[j].contains("-")) {
                                // range "a-b": add every index in [a, b]
                                int a = Integer.parseInt(aux2[j].split("-")[0]);
                                int b = Integer.parseInt(aux2[j].split("-")[1]);

                                for (int k = a; k <= b; k++) {
                                    newIntervals.get(i).add(k);
                                }
                            } else {
                                newIntervals.get(i).add(Integer.parseInt(aux2[j]));
                            }

                        }
                    } else {
                        if (intervals[i].contains("-")) {
                            int a = Integer.parseInt(intervals[i].split("-")[0]);
                            int b = Integer.parseInt(intervals[i].split("-")[1]);

                            for (int k = a; k <= b; k++) {
                                newIntervals.get(i).add(k);
                            }
                        } else {
                            newIntervals.get(i).add(Integer.parseInt(intervals[i]));
                        }
                    }
                }

                // store the expanded views and gather min/max/mean statistics
                for (int i = 0; i < newIntervals.size(); i++) {

                    Integer[] indices = new Integer[newIntervals.get(i).size()];
                    for (int j = 0; j < newIntervals.get(i).size(); j++) {
                        indices[j] = newIntervals.get(i).get(j);
                    }

                    System.out.println(Arrays.toString(indices));

                    views.put("View " + (i + 1), indices);

                    if (newIntervals.get(i).size() > max) {
                        max = newIntervals.get(i).size();
                    }
                    if (newIntervals.get(i).size() < min) {
                        min = newIntervals.get(i).size();
                    }
                    mean += newIntervals.get(i).size();
                }

                mean /= intervalsSize.length;
                labelNumViewsValue.setText(Integer.toString(intervalsSize.length));
                labelMaxNumAttrViewValue.setText(Integer.toString(max));
                labelMinNumAttrViewValue.setText(Integer.toString(min));
                labelMeanNumAttrViewValue.setText(Double.toString(mean));
            } else {
                TabPrincipal.setEnabledAt(7, false);
                mv = false;
            }

            int labelFound = 0;
            String labelName;
            String[] labelNamesFound;

            // Meka dataset: label names are read from the header and written
            // to a generated XML file (always deleted afterwards)
            if (DataIOUtils.isMeka(sString)) {
                deleteXML = true;
                isMeka = true;

                int labelCount = DataIOUtils.getLabelsFromARFF(sString);

                if (labelCount > 0) {
                    // positive count: labels are the FIRST attributes
                    labelNamesFound = new String[labelCount];

                    while (labelFound < labelCount) {
                        sString = bf.readLine();
                        labelName = DataIOUtils.getLabelNameFromLine(sString);

                        if (labelName != null) {
                            labelNamesFound[labelFound] = labelName;
                            labelFound++;
                        }
                    }
                } else {
                    // negative count: labels are the LAST attributes; keep a
                    // sliding window of the final |labelCount| non-blank lines
                    labelCount = Math.abs(labelCount);
                    labelNamesFound = new String[labelCount];

                    String[] sStrings = new String[labelCount];

                    while (!(sString = bf.readLine()).contains("@data")) {
                        if (!sString.trim().equals("")) {
                            for (int s = 0; s < labelCount - 1; s++) {
                                sStrings[s] = sStrings[s + 1];
                            }
                            sStrings[labelCount - 1] = sString;
                        }
                    }

                    for (int i = 0; i < labelCount; i++) {
                        labelName = DataIOUtils.getLabelNameFromLine(sStrings[i]);

                        if (labelName != null) {
                            labelNamesFound[labelFound] = labelName;
                            labelFound++;
                        }
                    }
                }

                BufferedWriter bwXml = new BufferedWriter(new FileWriter(xmlPath));
                PrintWriter wrXml = new PrintWriter(bwXml);

                DataIOUtils.writeXMLFile(wrXml, labelNamesFound);

                bwXml.close();
                wrXml.close();

                xmlFilename = DataIOUtils.getFilePath(xmlPath);
                fileTmp = new File(xmlPath);
            } else {
                isMeka = false;
            }
        } catch (FileNotFoundException ex) {
            Logger.getLogger(RunApp.class.getName()).log(Level.SEVERE, null, ex);
            return -1;
        } catch (IOException ex) {
            Logger.getLogger(RunApp.class.getName()).log(Level.SEVERE, null, ex);
            return -1;
        }

        // fall back to the default XML location if no companion file exists
        if (!fileTmp.exists()) {
            xmlPath = DataIOUtils.getXMLString(arffFilename);
            xmlFilename = DataIOUtils.getFilePath(xmlPath);
        }

        //Enable tabs
        TabPrincipal.setEnabledAt(1, true);
        TabPrincipal.setEnabledAt(2, true);
        TabPrincipal.setEnabledAt(3, true);
        TabPrincipal.setEnabledAt(4, true);
        TabPrincipal.setEnabledAt(5, true);
        TabPrincipal.setEnabledAt(6, true);

        try {
            File f = new File(xmlFilename);
            if (f.exists() && !f.isDirectory()) {
                //MultiLabelInstances dataset_temp = new MultiLabelInstances(filename_database_arff, xmlFilename);
            } else {
                JOptionPane.showMessageDialog(null, "File could not be loaded.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }
        } catch (Exception ex) {
            Logger.getLogger(RunApp.class.getName()).log(Level.SEVERE, null, ex);
            progressBar.setVisible(false);
            progressFrame.setVisible(false);
            progressFrame.repaint();
            return -1;
        }

        initTableMetrics();
        clearTableMetricsPrincipal();

        // load with the XML label file when available, without it otherwise
        File f = new File(xmlFilename);
        if (f.exists() && !f.isDirectory()) {
            loadDataset(arffFilename, xmlFilename);
        } else {
            loadDataset(arffFilename, null);
        }

        if (deleteXML) {
            File f2 = new File(xmlFilename);
            f2.delete();
        }

        textChooseFile.setText(arffFilename);
    }

    // multi-view datasets: compute per-view metrics and fill the view table
    if (mv) {

        if (((DefaultTableModel) jTable2.getModel()).getRowCount() > 0) {
            ((DefaultTableModel) jTable2.getModel()).getDataVector().removeAllElements();
        }

        for (int i = 0; i < views.size(); i++) {
            MultiLabelInstances view = dataset.clone();

            try {
                Instances inst = view.getDataSet();

                int[] attributes = Utils.toPrimitive(views.get("View " + (i + 1)));

                // keep the view's attributes plus all label attributes
                int[] toKeep = new int[attributes.length + dataset.getNumLabels()];
                System.arraycopy(attributes, 0, toKeep, 0, attributes.length);
                int[] labelIndices = dataset.getLabelIndices();
                System.arraycopy(labelIndices, 0, toKeep, attributes.length, dataset.getNumLabels());

                Remove filterRemove = new Remove();
                filterRemove.setAttributeIndicesArray(toKeep);
                filterRemove.setInvertSelection(true);
                filterRemove.setInputFormat(inst);

                MultiLabelInstances modifiedDataset = new MultiLabelInstances(
                        Filter.useFilter(view.getDataSet(), filterRemove), dataset.getLabelsMetaData());

                // compute the metrics shown in the view table
                LxIxF lif = new LxIxF();
                lif.calculate(modifiedDataset);
                RatioInstancesToAttributes ratioInstAtt = new RatioInstancesToAttributes();
                ratioInstAtt.calculate(modifiedDataset);
                AvgGainRatio avgGainRatio = new AvgGainRatio();
                avgGainRatio.calculate(modifiedDataset);

                ((DefaultTableModel) jTable2.getModel()).addRow(
                        new Object[] { "View " + (i + 1), attributes.length, getMetricValueFormatted(lif),
                                getMetricValueFormatted(ratioInstAtt), getMetricValueFormatted(avgGainRatio) });
            } catch (Exception ex) {
                Logger.getLogger(RunApp.class.getName()).log(Level.SEVERE, null, ex);
            }
        }
    }

    return 1;
}

From source file:data.statistics.MLStatistics.java

License:Open Source License

/**
 * Calculates Phi and Chi-square correlation matrix between every pair of
 * labels, based on their 2x2 contingency tables.
 *
 * @param dataSet
 *            A multi-label dataset.
 * @throws java.lang.Exception
 *             To be handled in an upper level.
 */
public void calculatePhiChi2(MultiLabelInstances dataSet) throws Exception {
    numLabels = dataSet.getNumLabels();

    // The indices of the label attributes
    int[] labelIndices;

    labelIndices = dataSet.getLabelIndices();
    numLabels = dataSet.getNumLabels();
    phi = new double[numLabels][numLabels];
    chi2 = new double[numLabels][numLabels];

    // keep only the label attributes (inverted selection)
    Remove remove = new Remove();
    remove.setInvertSelection(true);
    remove.setAttributeIndicesArray(labelIndices);
    remove.setInputFormat(dataSet.getDataSet());
    Instances result = Filter.useFilter(dataSet.getDataSet(), remove);
    result.setClassIndex(result.numAttributes() - 1);

    for (int i = 0; i < numLabels; i++) {
        // 2x2 contingency counts of label i vs. every label l:
        //   a = both "0", b = i set / l unset,
        //   c = i unset / l set, d = both set
        int a[] = new int[numLabels];
        int b[] = new int[numLabels];
        int c[] = new int[numLabels];
        int d[] = new int[numLabels];
        // marginal totals derived from a..d
        double e[] = new double[numLabels];
        double f[] = new double[numLabels];
        double g[] = new double[numLabels];
        double h[] = new double[numLabels];
        for (int j = 0; j < result.numInstances(); j++) {
            for (int l = 0; l < numLabels; l++) {
                if (result.instance(j).stringValue(i).equals("0")) {
                    if (result.instance(j).stringValue(l).equals("0")) {
                        a[l]++;
                    } else {
                        c[l]++;
                    }
                } else {
                    if (result.instance(j).stringValue(l).equals("0")) {
                        b[l]++;
                    } else {
                        d[l]++;
                    }
                }
            }
        }
        for (int l = 0; l < numLabels; l++) {
            e[l] = a[l] + b[l];
            f[l] = c[l] + d[l];
            g[l] = a[l] + c[l];
            h[l] = b[l] + d[l];
            double mult = e[l] * f[l] * g[l] * h[l];
            double denominator = Math.sqrt(mult);
            double nominator = a[l] * d[l] - b[l] * c[l];
            // NOTE(review): when any marginal is zero the denominator is 0
            // and phi becomes NaN/Infinity — confirm callers handle that
            phi[i][l] = nominator / denominator;
            chi2[i][l] = phi[i][l] * phi[i][l] * (a[l] + b[l] + c[l] + d[l]);
        }
    }
}

From source file:edu.oregonstate.eecs.mcplan.abstraction.EvaluateSimilarityFunction.java

License:Open Source License

/**
 * Builds a cluster contingency table comparing the classifier's cluster
 * assignments (U) against the true class labels (V) on the test set, with
 * each instance represented by its class-free feature vector.
 */
public static ClusterContingencyTable evaluateClassifier(final Classifier classifier, final Instances test) {
    try {
        // TreeMap keeps cluster/label ids in ascending order
        final Map<Integer, Set<RealVector>> clusterGroups = new TreeMap<Integer, Set<RealVector>>();
        final Map<Integer, Set<RealVector>> labelGroups = new TreeMap<Integer, Set<RealVector>>();

        // streaming filter that strips the class attribute from each instance
        final Remove classStripper = new Remove();
        classStripper.setAttributeIndicesArray(new int[] { test.classIndex() });
        classStripper.setInputFormat(test);

        for (final Instance inst : test) {
            classStripper.input(inst);
            final double[] features = classStripper.output().toDoubleArray();

            // group by predicted cluster
            final int cluster = (int) classifier.classifyInstance(inst);
            Set<RealVector> clusterSet = clusterGroups.get(cluster);
            if (clusterSet == null) {
                clusterSet = new HashSet<RealVector>();
                clusterGroups.put(cluster, clusterSet);
            }
            clusterSet.add(new ArrayRealVector(features));

            // group by true class label
            final int trueLabel = (int) inst.classValue();
            Set<RealVector> labelSet = labelGroups.get(trueLabel);
            if (labelSet == null) {
                labelSet = new HashSet<RealVector>();
                labelGroups.put(trueLabel, labelSet);
            }
            labelSet.add(new ArrayRealVector(features));
        }

        final ArrayList<Set<RealVector>> U = new ArrayList<Set<RealVector>>(clusterGroups.values());
        final ArrayList<Set<RealVector>> V = new ArrayList<Set<RealVector>>(labelGroups.values());

        return new ClusterContingencyTable(U, V);
    } catch (final RuntimeException ex) {
        throw ex;
    } catch (final Exception ex) {
        throw new RuntimeException(ex);
    }
}