List of usage examples for weka.core Instances delete
public void delete(int index)
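Before the examples taken from real projects below, here is a minimal, self-contained sketch of the call itself (the file name is a placeholder; any dataset readable by Weka works): delete(int) removes the instance at the given zero-based position, and every later instance shifts down by one.

import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class DeleteAtIndex {
    public static void main(String[] args) throws Exception {
        // "iris.arff" is a placeholder path
        Instances data = DataSource.read("iris.arff");

        System.out.println("Before: " + data.numInstances() + " instances");
        data.delete(0); // removes the instance at position 0; later instances shift down by one
        System.out.println("After:  " + data.numInstances() + " instances");
    }
}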
From source file:mulan.data.IterativeStratification.java
License:Open Source License
private Instances[] takeTheInstancesOfTheLabel(Instances workingSet, int numLabels, int[] labelIndices,
        int[] desiredLabel) {
    // Index [0] of the returned array holds the instances filtered for the desired label,
    // index [1] holds the remaining workingSet
    Instances[] returnedInstances = new Instances[2];
    Instances filteredInstancesOfLabel = new Instances(workingSet, 0);
    int numInstances = workingSet.numInstances();
    boolean[] trueLabels = new boolean[numLabels];
    int[] removedIndexes = new int[desiredLabel[1]];
    int count = 0;
    // First, filter the instances annotated with label desiredLabel[0]
    // and keep the indexes of the filtered instances
    for (int instanceIndex = 0; instanceIndex < numInstances; instanceIndex++) {
        Instance instance = workingSet.instance(instanceIndex);
        trueLabels = getTrueLabels(instance, numLabels, labelIndices);
        if (trueLabels[desiredLabel[0]]) {
            filteredInstancesOfLabel.add(instance);
            removedIndexes[count] = instanceIndex;
            count++;
        }
    }
    // Using the indexes of the filtered instances, remove them from the working set.
    // CAUTION: iterate in reverse order so that earlier deletions do not shift
    // the indexes of the instances still to be removed
    for (int k = count - 1; k >= 0; k--) {
        workingSet.delete(removedIndexes[k]);
    }
    returnedInstances[0] = filteredInstancesOfLabel;
    returnedInstances[1] = workingSet;
    return returnedInstances;
}
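The part of this example specific to delete is the second loop: indices are collected first and then deleted from the highest to the lowest, so that each deletion does not invalidate the indices still pending. A minimal sketch of just that pattern (the predicate that selects which instances to drop is a placeholder):

import java.util.ArrayList;
import java.util.List;
import weka.core.Instances;

public class ReverseOrderDelete {
    /** Removes every instance whose first attribute value is missing (placeholder predicate). */
    public static void removeFlagged(Instances data) {
        List<Integer> toRemove = new ArrayList<Integer>();
        for (int i = 0; i < data.numInstances(); i++) {
            if (data.instance(i).isMissing(0)) { // placeholder predicate
                toRemove.add(i);
            }
        }
        // Delete from the highest index down so earlier deletions do not
        // shift the positions of instances still scheduled for removal.
        for (int k = toRemove.size() - 1; k >= 0; k--) {
            data.delete(toRemove.get(k));
        }
    }
}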
From source file:org.openml.webapplication.algorithm.InstancesHelper.java
License:Open Source License
@SuppressWarnings("unchecked")
public static void stratify(Instances dataset) {
    int numClasses = dataset.classAttribute().numValues();
    int numInstances = dataset.numInstances();
    double[] classRatios = classRatios(dataset);
    double[] currentRatios = new double[numClasses];
    int[] currentCounts = new int[numClasses];
    List<Instance>[] instancesSorted = new LinkedList[numClasses];
    for (int i = 0; i < numClasses; ++i) {
        instancesSorted[i] = new LinkedList<Instance>();
    }
    // first, sort all instances based on class into different lists
    for (int i = 0; i < numInstances; ++i) {
        Instance current = dataset.instance(i);
        instancesSorted[(int) current.classValue()].add(current);
    }
    // now empty the original dataset; all instances are stored in the linked lists
    for (int i = 0; i < numInstances; i++) {
        dataset.delete(dataset.numInstances() - 1);
    }
    for (int i = 0; i < numInstances; ++i) {
        int idx = biggestDifference(classRatios, currentRatios);
        dataset.add(instancesSorted[idx].remove(0));
        currentCounts[idx]++;
        for (int j = 0; j < currentRatios.length; ++j) {
            currentRatios[j] = (currentCounts[j] * 1.0) / (i + 1);
        }
    }
}
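The emptying loop above always deletes the last instance, which avoids shifting the remaining elements on every call. A minimal sketch of that step in isolation, assuming an already-populated Instances object:

import weka.core.Instances;

public class DrainDataset {
    /** Empties the dataset by repeatedly deleting its last instance. */
    public static void drain(Instances dataset) {
        while (dataset.numInstances() > 0) {
            // deleting the last index means no remaining instances have to be shifted
            dataset.delete(dataset.numInstances() - 1);
        }
        // Note: dataset.delete() with no argument clears all instances in one call.
    }
}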
From source file:sirius.trainer.step4.RunClassifierWithNoLocationIndex.java
License:Open Source License
public static Object jackKnifeClassifierOneWithNoLocationIndex(JInternalFrame parent, ApplicationData applicationData,
        JTextArea classifierOneDisplayTextArea, GenericObjectEditor m_ClassifierEditor, double ratio, GraphPane myGraph,
        ClassifierResults classifierResults, int range, double threshold, boolean outputClassifier,
        String classifierName, String[] classifierOptions, boolean returnClassifier, int randomNumberForClassifier) {
    try {
        StatusPane statusPane = applicationData.getStatusPane();
        long totalTimeStart = System.currentTimeMillis(), totalTimeElapsed;
        Classifier tempClassifier;
        if (m_ClassifierEditor != null)
            tempClassifier = (Classifier) m_ClassifierEditor.getValue();
        else
            tempClassifier = Classifier.forName(classifierName, classifierOptions);
        // Assume that the class attribute is the last attribute - this should be the case
        // for all Sirius-produced Arff files
        // split the instances into positive and negative
        Instances posInst = new Instances(applicationData.getDataset1Instances());
        posInst.setClassIndex(posInst.numAttributes() - 1);
        for (int x = 0; x < posInst.numInstances();)
            if (posInst.instance(x).stringValue(posInst.numAttributes() - 1).equalsIgnoreCase("pos"))
                x++;
            else
                posInst.delete(x);
        posInst.deleteAttributeType(Attribute.STRING);
        Instances negInst = new Instances(applicationData.getDataset1Instances());
        negInst.setClassIndex(negInst.numAttributes() - 1);
        for (int x = 0; x < negInst.numInstances();)
            if (negInst.instance(x).stringValue(negInst.numAttributes() - 1).equalsIgnoreCase("neg"))
                x++;
            else
                negInst.delete(x);
        negInst.deleteAttributeType(Attribute.STRING);
        // Train classifier one with the full dataset first, then do cross-validation to gauge its accuracy
        long trainTimeStart = 0, trainTimeElapsed = 0;
        if (statusPane != null)
            statusPane.setText("Training Classifier One... May take a while... Please wait...");
        // Record start time
        trainTimeStart = System.currentTimeMillis();
        Instances fullInst = new Instances(applicationData.getDataset1Instances());
        fullInst.setClassIndex(fullInst.numAttributes() - 1);
        Classifier classifierOne;
        if (m_ClassifierEditor != null)
            classifierOne = (Classifier) m_ClassifierEditor.getValue();
        else
            classifierOne = Classifier.forName(classifierName, classifierOptions);
        if (outputClassifier)
            classifierOne.buildClassifier(fullInst);
        // Record total time used to build classifier one
        trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
        // Training done
        String tclassifierName;
        if (m_ClassifierEditor != null)
            tclassifierName = m_ClassifierEditor.getValue().getClass().getName();
        else
            tclassifierName = classifierName;
        if (classifierResults != null) {
            classifierResults.updateList(classifierResults.getClassifierList(), "Classifier: ", tclassifierName);
            classifierResults.updateList(classifierResults.getClassifierList(), "Training Data: ",
                    " Jack Knife Validation");
            classifierResults.updateList(classifierResults.getClassifierList(), "Time Used: ",
                    Utils.doubleToString(trainTimeElapsed / 1000.0, 2) + " seconds");
        }
        String classifierOneFilename = applicationData.getWorkingDirectory() + File.separator + "ClassifierOne_"
                + randomNumberForClassifier + ".scores";
        BufferedWriter outputCrossValidation = new BufferedWriter(new FileWriter(classifierOneFilename));
        //Instances foldTrainingInstance;
        //Instances foldTestingInstance;
        int positiveDataset1FromInt = applicationData.getPositiveDataset1FromField();
        int positiveDataset1ToInt = applicationData.getPositiveDataset1ToField();
        int negativeDataset1FromInt = applicationData.getNegativeDataset1FromField();
        int negativeDataset1ToInt = applicationData.getNegativeDataset1ToField();
        Step1TableModel positiveStep1TableModel = applicationData.getPositiveStep1TableModel();
        Step1TableModel negativeStep1TableModel = applicationData.getNegativeStep1TableModel();
        FastaFileManipulation fastaFile = new FastaFileManipulation(positiveStep1TableModel, negativeStep1TableModel,
                positiveDataset1FromInt, positiveDataset1ToInt, negativeDataset1FromInt, negativeDataset1ToInt,
                applicationData.getWorkingDirectory());
        FastaFormat fastaFormat;
        String header[] = new String[fullInst.numInstances()];
        String data[] = new String[fullInst.numInstances()];
        int counter = 0;
        while ((fastaFormat = fastaFile.nextSequence("pos")) != null) {
            header[counter] = fastaFormat.getHeader();
            data[counter] = fastaFormat.getSequence();
            counter++;
        }
        while ((fastaFormat = fastaFile.nextSequence("neg")) != null) {
            header[counter] = fastaFormat.getHeader();
            data[counter] = fastaFormat.getSequence();
            counter++;
        }
        // run jack knife validation
        for (int x = 0; x < fullInst.numInstances(); x++) {
            if (applicationData.terminateThread == true) {
                if (statusPane != null)
                    statusPane.setText("Interrupted - Classifier One Training Completed");
                outputCrossValidation.close();
                return classifierOne;
            }
            if (statusPane != null)
                statusPane.setText("Running " + (x + 1) + " / " + fullInst.numInstances());
            Instances trainPosInst = new Instances(posInst);
            Instances trainNegInst = new Instances(negInst);
            Instance testInst;
            // split data into training and testing
            if (x < trainPosInst.numInstances()) {
                testInst = posInst.instance(x);
                trainPosInst.delete(x);
            } else {
                testInst = negInst.instance(x - posInst.numInstances());
                trainNegInst.delete(x - posInst.numInstances());
            }
            Instances trainInstances;
            if (trainPosInst.numInstances() < trainNegInst.numInstances()) {
                trainInstances = new Instances(trainPosInst);
                int max = (int) (ratio * trainPosInst.numInstances());
                if (ratio == -1)
                    max = trainNegInst.numInstances();
                Random rand = new Random(1);
                for (int y = 0; y < trainNegInst.numInstances() && y < max; y++) {
                    int index = rand.nextInt(trainNegInst.numInstances());
                    trainInstances.add(trainNegInst.instance(index));
                    trainNegInst.delete(index);
                }
            } else {
                trainInstances = new Instances(trainNegInst);
                int max = (int) (ratio * trainNegInst.numInstances());
                if (ratio == -1)
                    max = trainPosInst.numInstances();
                Random rand = new Random(1);
                for (int y = 0; y < trainPosInst.numInstances() && y < max; y++) {
                    int index = rand.nextInt(trainPosInst.numInstances());
                    trainInstances.add(trainPosInst.instance(index));
                    trainPosInst.delete(index);
                }
            }
            Classifier foldClassifier = tempClassifier;
            foldClassifier.buildClassifier(trainInstances);
            double[] results = foldClassifier.distributionForInstance(testInst);
            int classIndex = testInst.classIndex();
            String classValue = testInst.toString(classIndex);
            outputCrossValidation.write(header[x]);
            outputCrossValidation.newLine();
            outputCrossValidation.write(data[x]);
            outputCrossValidation.newLine();
            if (classValue.equals("pos"))
                outputCrossValidation.write("pos,0=" + results[0]);
            else if (classValue.equals("neg"))
                outputCrossValidation.write("neg,0=" + results[0]);
            else {
                outputCrossValidation.close();
                throw new Error("Invalid Class Type!");
            }
            outputCrossValidation.newLine();
            outputCrossValidation.flush();
        }
        outputCrossValidation.close();
        PredictionStats classifierOneStatsOnJackKnife = new PredictionStats(classifierOneFilename, range, threshold);
        totalTimeElapsed = System.currentTimeMillis() - totalTimeStart;
        if (classifierResults != null)
            classifierResults.updateList(classifierResults.getResultsList(), "Total Time Used: ",
                    Utils.doubleToString(totalTimeElapsed / 60000, 2) + " minutes "
                            + Utils.doubleToString((totalTimeElapsed / 1000.0) % 60.0, 2) + " seconds");
        //if(classifierOneDisplayTextArea != null)
        classifierOneStatsOnJackKnife.updateDisplay(classifierResults, classifierOneDisplayTextArea, true);
        applicationData.setClassifierOneStats(classifierOneStatsOnJackKnife);
        if (myGraph != null)
            myGraph.setMyStats(classifierOneStatsOnJackKnife);
        if (statusPane != null)
            statusPane.setText("Done!");
        if (returnClassifier)
            return classifierOne;
        else
            return classifierOneStatsOnJackKnife;
    } catch (Exception e) {
        e.printStackTrace();
        JOptionPane.showMessageDialog(parent, e.getMessage(), "ERROR", JOptionPane.ERROR_MESSAGE);
        return null;
    }
}
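Stripped of the Sirius-specific bookkeeping, the leave-one-out (jack knife) use of delete above comes down to: copy the dataset, read the held-out instance from the original, and delete that index from the copy. A minimal sketch under those assumptions (J48 is just a placeholder learner; the dataset is assumed to already have its class index set):

import weka.classifiers.Classifier;
import weka.classifiers.trees.J48;
import weka.core.Instance;
import weka.core.Instances;

public class LeaveOneOutSketch {
    /** Runs leave-one-out over data, which must already have its class index set. */
    public static void run(Instances data) throws Exception {
        for (int x = 0; x < data.numInstances(); x++) {
            Instances train = new Instances(data);  // working copy; the original stays intact
            Instance test = data.instance(x);       // held-out instance, read from the original
            train.delete(x);                        // drop the held-out instance from the copy

            Classifier fold = new J48();            // placeholder learner; any Weka classifier works
            fold.buildClassifier(train);
            double[] dist = fold.distributionForInstance(test);
            System.out.println("Instance " + x + ": " + java.util.Arrays.toString(dist));
        }
    }
}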
From source file:uzholdem.classifier.OnlineMultilayerPerceptron.java
License:Open Source License
public void trainModel(Instances aInstances, int numIterations) throws Exception {

    // setup m_instances
    if (this.m_instances == null) {
        this.m_instances = new Instances(aInstances, 0, aInstances.size());
    }

    if (m_useNomToBin) {
        if (this.m_nominalToBinaryFilter == null) {
            m_nominalToBinaryFilter = new NominalToBinary();
            try {
                m_nominalToBinaryFilter.setInputFormat(m_instances);
            } catch (Exception e) {
                // TODO Auto-generated catch block
                e.printStackTrace();
                return;
            }
        }
        aInstances = Filter.useFilter(aInstances, m_nominalToBinaryFilter);
    }

    Instances epochInstances = new Instances(aInstances);
    epochInstances.randomize(new Random());

    // hold out roughly 30% of the shuffled data as a validation set by moving
    // instances from the front of the shuffled copy into valSet
    int valSize = (int) (aInstances.size() * 0.3);
    Instances valSet = new Instances(aInstances, valSize);
    for (int i = 0; i < valSize; i++) {
        valSet.add(epochInstances.instance(0));
        epochInstances.delete(0);
    }
    m_instances = epochInstances;

    double right = 0;
    double driftOff = 0;
    double lastRight = Double.POSITIVE_INFINITY;
    double bestError = Double.POSITIVE_INFINITY;
    double tempRate;
    double totalWeight = 0;
    double totalValWeight = 0;
    double origRate = m_learningRate; // only used for when reset

    int numInVal = valSet.numInstances();
    for (int noa = numInVal; noa < m_instances.numInstances(); noa++) {
        if (!m_instances.instance(noa).classIsMissing()) {
            totalWeight += m_instances.instance(noa).weight();
        }
    }
    if (m_valSize != 0) {
        for (int noa = 0; noa < valSet.numInstances(); noa++) {
            if (!valSet.instance(noa).classIsMissing()) {
                totalValWeight += valSet.instance(noa).weight();
            }
        }
    }
    m_stopped = false;

    for (int noa = 1; noa < 50 + 1; noa++) {
        right = 0;
        for (int nob = numInVal; nob < m_instances.numInstances(); nob++) {
            m_currentInstance = m_instances.instance(nob);
            if (!m_currentInstance.classIsMissing()) {
                // this is where the network updating (and training) occurs, for the training set
                resetNetwork();
                calculateOutputs();
                tempRate = m_learningRate * m_currentInstance.weight();
                if (m_decay) {
                    tempRate /= noa;
                }
                right += (calculateErrors() / m_instances.numClasses()) * m_currentInstance.weight();
                updateNetworkWeights(tempRate, m_momentum);
            }
        }
        right /= totalWeight;
        if (Double.isInfinite(right) || Double.isNaN(right)) {
            m_instances = null;
            throw new Exception("Network cannot train. Try restarting with a smaller learning rate.");
        }

        // do validation testing if applicable
        if (m_valSize != 0) {
            right = 0;
            for (int nob = 0; nob < valSet.numInstances(); nob++) {
                m_currentInstance = valSet.instance(nob);
                if (!m_currentInstance.classIsMissing()) {
                    // this is where the network updating occurs, for the validation set
                    resetNetwork();
                    calculateOutputs();
                    right += (calculateErrors() / valSet.numClasses()) * m_currentInstance.weight();
                    // note 'right' could be calculated here just using the calculated
                    // output values. This would be faster but less modular.
                }
            }
            if (right < lastRight) {
                if (right < bestError) {
                    bestError = right;
                    // save the network weights at this point
                    for (int noc = 0; noc < m_numClasses; noc++) {
                        m_outputs[noc].saveWeights();
                    }
                    driftOff = 0;
                }
            } else {
                driftOff++;
            }
            lastRight = right;
            if (driftOff > m_driftThreshold || noa + 1 >= m_numEpochs) {
                for (int noc = 0; noc < m_numClasses; noc++) {
                    m_outputs[noc].restoreWeights();
                }
                m_accepted = true;
            }
            right /= totalValWeight;
        }
        m_epoch = noa;
        m_error = right; // shows what the neural net is up to if a GUI exists

        if (m_accepted) {
            m_instances = new Instances(m_instances, 0);
            return;
        }
    }
}
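The validation-split step above moves instances between two Instances objects by adding to one and then deleting index 0 from the other; delete(int) returns void, so the instance has to be added before it is removed. A minimal sketch of that move pattern (the 30% fraction and the names are illustrative only):

import java.util.Random;
import weka.core.Instances;

public class HoldoutSplit {
    /** Moves roughly the given fraction of the shuffled source into a new validation set. */
    public static Instances holdOut(Instances source, double fraction, long seed) {
        source.randomize(new Random(seed));
        int n = (int) (source.numInstances() * fraction);
        Instances valSet = new Instances(source, n);  // empty set sharing the same header
        for (int i = 0; i < n; i++) {
            valSet.add(source.instance(0));  // add() stores a shallow copy of the instance
            source.delete(0);                // delete(int) returns void, so add first, then delete
        }
        return valSet;
    }
}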