List of usage examples for weka.core Instances numClasses
public int numClasses()
From source file:tr.gov.ulakbim.jDenetX.experiments.wrappers.EvalActiveBoostingID.java
License:Open Source License
public static Instances clusterInstances(Instances data) { XMeans xmeans = new XMeans(); Remove filter = new Remove(); Instances dataClusterer = null;//from w ww.j a v a2 s .com if (data == null) { throw new NullPointerException("Data is null at clusteredInstances method"); } //Get the attributes from the data for creating the sampled_data object ArrayList<Attribute> attrList = new ArrayList<Attribute>(); Enumeration attributes = data.enumerateAttributes(); while (attributes.hasMoreElements()) { attrList.add((Attribute) attributes.nextElement()); } Instances sampled_data = new Instances(data.relationName(), attrList, 0); data.setClassIndex(data.numAttributes() - 1); sampled_data.setClassIndex(data.numAttributes() - 1); filter.setAttributeIndices("" + (data.classIndex() + 1)); data.remove(0);//In Wavelet Stream of MOA always the first element comes without class try { filter.setInputFormat(data); dataClusterer = Filter.useFilter(data, filter); String[] options = new String[4]; options[0] = "-L"; // max. 
iterations options[1] = Integer.toString(noOfClassesInPool - 1); if (noOfClassesInPool > 2) { options[1] = Integer.toString(noOfClassesInPool - 1); xmeans.setMinNumClusters(noOfClassesInPool - 1); } else { options[1] = Integer.toString(noOfClassesInPool); xmeans.setMinNumClusters(noOfClassesInPool); } xmeans.setMaxNumClusters(data.numClasses() + 1); System.out.println("No of classes in the pool: " + noOfClassesInPool); xmeans.setUseKDTree(true); //xmeans.setOptions(options); xmeans.buildClusterer(dataClusterer); System.out.println("Xmeans\n:" + xmeans); } catch (Exception e) { e.printStackTrace(); } //System.out.println("Assignments\n: " + assignments); ClusterEvaluation eval = new ClusterEvaluation(); eval.setClusterer(xmeans); try { eval.evaluateClusterer(data); int classesToClustersMap[] = eval.getClassesToClusters(); //check the classes to cluster map int clusterNo = 0; for (int i = 0; i < data.size(); i++) { clusterNo = xmeans.clusterInstance(dataClusterer.get(i)); //Check if the class value of instance and class value of cluster matches if ((int) data.get(i).classValue() == classesToClustersMap[clusterNo]) { sampled_data.add(data.get(i)); } } } catch (Exception e) { e.printStackTrace(); } return ((Instances) sampled_data); }
From source file:trainableSegmentation.Trainable_Segmentation.java
License:GNU General Public License
/**
 * Create a task that computes the class probability distribution of every
 * instance of a data set.
 *
 * @param data set of instances to evaluate
 * @param classifier classifier whose {@code distributionForInstance} is queried
 * @param counter progress counter; increased by 4000 whenever the instance
 *        index is a multiple of 4000
 * @return task yielding an array indexed as [class][instance]; the task
 *         yields null if the classifier throws
 */
private static Callable<double[][]> probFromInstances(final Instances data,
        final AbstractClassifier classifier, final AtomicInteger counter) {
    return new Callable<double[][]>() {
        public double[][] call() {
            final int total = data.numInstances();
            final int classCount = data.numClasses();
            final double[][] distribution = new double[classCount][total];

            try {
                for (int idx = 0; idx < total; idx++) {
                    if (idx % 4000 == 0) {
                        counter.addAndGet(4000);
                    }
                    final double[] perClass = classifier.distributionForInstance(data.instance(idx));
                    for (int cls = 0; cls < classCount; cls++) {
                        distribution[cls][idx] = perClass[cls];
                    }
                }
            } catch (Exception e) {
                IJ.showMessage("Could not apply Classifier!");
                e.printStackTrace();
                return null;
            }
            return distribution;
        }
    };
}
From source file:trainableSegmentation.WekaSegmentation.java
License:GNU General Public License
/** * Classify a slice in a concurrent way//ww w .j a va 2 s.c om * @param slice image to classify * @param dataInfo empty set of instances containing the data structure (attributes and classes) * @param classifier classifier to use * @param counter counter used to display the progress in the tool bar * @param probabilityMaps flag to calculate probabilities or binary results * @return classification result */ public Callable<ImagePlus> classifySlice(final ImagePlus slice, final Instances dataInfo, final AbstractClassifier classifier, final AtomicInteger counter, final boolean probabilityMaps) { if (Thread.currentThread().isInterrupted()) return null; return new Callable<ImagePlus>() { public ImagePlus call() { // Create feature stack for slice IJ.showStatus("Creating features..."); IJ.log("Creating features of slice " + slice.getTitle() + "..."); final FeatureStack sliceFeatures = new FeatureStack(slice); // Use the same features as the current classifier sliceFeatures.setEnabledFeatures(featureStackArray.getEnabledFeatures()); sliceFeatures.setMaximumSigma(maximumSigma); sliceFeatures.setMinimumSigma(minimumSigma); sliceFeatures.setMembranePatchSize(membranePatchSize); sliceFeatures.setMembraneSize(membraneThickness); if (false == sliceFeatures.updateFeaturesST()) { IJ.log("Classifier execution was interrupted."); return null; } filterFeatureStackByList(featureNames, sliceFeatures); final int width = slice.getWidth(); final int height = slice.getHeight(); final int numClasses = dataInfo.numClasses(); ImageStack classificationResult = new ImageStack(width, height); final int numInstances = width * height; final double[][] probArray; if (probabilityMaps) probArray = new double[numClasses][numInstances]; else probArray = new double[1][numInstances]; IJ.log("Classifying slice " + slice.getTitle() + "..."); for (int x = 0; x < width; x++) for (int y = 0; y < height; y++) { try { if (0 == (x + y * width) % 4000) { if (Thread.currentThread().isInterrupted()) return null; 
counter.addAndGet(4000); } final DenseInstance ins = sliceFeatures.createInstance(x, y, 0); ins.setDataset(dataInfo); if (probabilityMaps) { double[] prob = classifier.distributionForInstance(ins); for (int k = 0; k < numClasses; k++) { probArray[k][x + y * width] = prob[k]; } } else { probArray[0][x + y * width] = classifier.classifyInstance(ins); } } catch (Exception e) { IJ.showMessage("Could not apply Classifier!"); e.printStackTrace(); return null; } } if (probabilityMaps) { for (int k = 0; k < numClasses; k++) classificationResult.addSlice("class-" + (k + 1), new FloatProcessor(width, height, probArray[k])); } else classificationResult.addSlice("result", new FloatProcessor(width, height, probArray[0])); return new ImagePlus("classified-slice", classificationResult); } }; }
From source file:trainableSegmentation.WekaSegmentation.java
License:GNU General Public License
/**
 * Create a task that classifies a list of images one after another.
 *
 * <p>For each image the task builds a feature stack using this object's
 * current feature settings and classifies every pixel. If any image cannot
 * be processed the whole task yields null.
 *
 * @param images list of images to classify
 * @param dataInfo empty set of instances containing the data structure (attributes and classes)
 * @param classifier classifier to use
 * @param counter progress counter; increased by 4000 whenever the linear
 *        pixel index is a multiple of 4000
 * @param probabilityMaps true to produce one probability slice per class,
 *        false to produce a single slice of predicted class values
 * @return the task, or null if the calling thread is already interrupted;
 *         the task itself yields null when interrupted or when the
 *         classifier throws
 */
public Callable<ArrayList<ImagePlus>> classifyListOfImages(final ArrayList<ImagePlus> images,
        final Instances dataInfo, final AbstractClassifier classifier, final AtomicInteger counter,
        final boolean probabilityMaps) {
    if (Thread.currentThread().isInterrupted()) {
        return null;
    }

    return new Callable<ArrayList<ImagePlus>>() {
        public ArrayList<ImagePlus> call() {
            final ArrayList<ImagePlus> classified = new ArrayList<ImagePlus>();

            for (ImagePlus image : images) {
                IJ.showStatus("Creating features...");
                IJ.log("Creating features of slice " + image.getTitle() + ", size = " + image.getWidth() + "x"
                        + image.getHeight() + "...");

                // Feature stack configured like the one the current classifier was trained with
                final FeatureStack features = new FeatureStack(image);
                features.setEnabledFeatures(featureStackArray.getEnabledFeatures());
                features.setMaximumSigma(maximumSigma);
                features.setMinimumSigma(minimumSigma);
                features.setMembranePatchSize(membranePatchSize);
                features.setMembraneSize(membraneThickness);
                if (!features.updateFeaturesST()) {
                    IJ.log("Classifier execution was interrupted.");
                    return null;
                }
                filterFeatureStackByList(featureNames, features);

                final int width = image.getWidth();
                final int height = image.getHeight();
                final int numClasses = dataInfo.numClasses();
                final ImageStack output = new ImageStack(width, height);
                final int pixelCount = width * height;
                final double[][] values = new double[probabilityMaps ? numClasses : 1][pixelCount];

                IJ.log("Classifying slice " + image.getTitle() + "...");

                try {
                    for (int x = 0; x < width; x++) {
                        for (int y = 0; y < height; y++) {
                            final int pixel = x + y * width;
                            if (pixel % 4000 == 0) {
                                if (Thread.currentThread().isInterrupted()) {
                                    return null;
                                }
                                counter.addAndGet(4000);
                            }

                            final DenseInstance instance = features.createInstance(x, y, 0);
                            instance.setDataset(dataInfo);

                            if (!probabilityMaps) {
                                values[0][pixel] = classifier.classifyInstance(instance);
                                continue;
                            }
                            final double[] distribution = classifier.distributionForInstance(instance);
                            for (int cls = 0; cls < numClasses; cls++) {
                                values[cls][pixel] = distribution[cls];
                            }
                        }
                    }
                } catch (Exception e) {
                    IJ.showMessage("Could not apply Classifier!");
                    e.printStackTrace();
                    return null;
                }

                if (!probabilityMaps) {
                    output.addSlice("result", new FloatProcessor(width, height, values[0]));
                } else {
                    for (int cls = 0; cls < numClasses; cls++) {
                        output.addSlice("class-" + (cls + 1), new FloatProcessor(width, height, values[cls]));
                    }
                }
                classified.add(new ImagePlus("classified-image-" + image.getTitle(), output));
            }
            return classified;
        }
    };
}
From source file:trainableSegmentation.WekaSegmentation.java
License:GNU General Public License
/** * Apply current classifier to set of instances * @param data set of instances/*from w ww. j a va 2 s . co m*/ * @param w image width * @param h image height * @param numThreads The number of threads to use. Set to zero for * auto-detection. * @return result image */ public ImagePlus applyClassifier(final Instances data, int w, int h, int numThreads, boolean probabilityMaps) { if (numThreads == 0) numThreads = Prefs.getThreads(); final int numClasses = data.numClasses(); final int numInstances = data.numInstances(); final int numChannels = (probabilityMaps ? numClasses : 1); final int numSlices = (numChannels * numInstances) / (w * h); IJ.showStatus("Classifying image..."); final long start = System.currentTimeMillis(); ExecutorService exe = Executors.newFixedThreadPool(numThreads); final double[][][] results = new double[numThreads][][]; final Instances[] partialData = new Instances[numThreads]; final int partialSize = numInstances / numThreads; Future<double[][]> fu[] = new Future[numThreads]; final AtomicInteger counter = new AtomicInteger(); for (int i = 0; i < numThreads; i++) { if (Thread.currentThread().isInterrupted()) { exe.shutdown(); return null; } if (i == numThreads - 1) partialData[i] = new Instances(data, i * partialSize, numInstances - i * partialSize); else partialData[i] = new Instances(data, i * partialSize, partialSize); AbstractClassifier classifierCopy = null; try { // The Weka random forest classifiers do not need to be duplicated on each thread // (that saves much memory) if (classifier instanceof FastRandomForest || classifier instanceof RandomForest) classifierCopy = classifier; else classifierCopy = (AbstractClassifier) (AbstractClassifier.makeCopy(classifier)); } catch (Exception e) { IJ.log("Error: classifier could not be copied to classify in a multi-thread way."); e.printStackTrace(); } fu[i] = exe.submit(classifyInstances(partialData[i], classifierCopy, counter, probabilityMaps)); } ScheduledExecutorService monitor = 
Executors.newScheduledThreadPool(1); ScheduledFuture task = monitor.scheduleWithFixedDelay(new Runnable() { public void run() { IJ.showProgress(counter.get(), numInstances); } }, 0, 1, TimeUnit.SECONDS); // Join threads for (int i = 0; i < numThreads; i++) { try { results[i] = fu[i].get(); } catch (InterruptedException e) { //e.printStackTrace(); return null; } catch (ExecutionException e) { e.printStackTrace(); return null; } finally { exe.shutdown(); task.cancel(true); monitor.shutdownNow(); IJ.showProgress(1); } } exe.shutdown(); // Create final array double[][] classificationResult; classificationResult = new double[numChannels][numInstances]; for (int i = 0; i < numThreads; i++) for (int c = 0; c < numChannels; c++) System.arraycopy(results[i][c], 0, classificationResult[c], i * partialSize, results[i][c].length); IJ.showProgress(1.0); final long end = System.currentTimeMillis(); IJ.log("Classifying whole image data took: " + (end - start) + "ms"); double[] classifiedSlice = new double[w * h]; final ImageStack classStack = new ImageStack(w, h); for (int i = 0; i < numSlices / numChannels; i++) { for (int c = 0; c < numChannels; c++) { System.arraycopy(classificationResult[c], i * (w * h), classifiedSlice, 0, w * h); ImageProcessor classifiedSliceProcessor = new FloatProcessor(w, h, classifiedSlice); classStack.addSlice(probabilityMaps ? getClassLabels()[c] : "", classifiedSliceProcessor); } } ImagePlus classImg = new ImagePlus(probabilityMaps ? "Probability maps" : "Classification result", classStack); return classImg; }
From source file:trainableSegmentation.WekaSegmentation.java
License:GNU General Public License
/**
 * Create a task that classifies a contiguous range of feature vectors taken
 * from a feature stack array.
 *
 * @param fsa feature stack array with the feature vectors
 * @param dataInfo empty set of instances containing the data structure (attributes and classes)
 * @param first index of the first instance to classify (considering the feature stack array as a 1D array)
 * @param numInstances number of instances to classify in this thread
 * @param classifier current classifier
 * @param counter progress counter; increased by 4000 whenever the local
 *        instance index is a multiple of 4000
 * @param probabilityMaps if true return a probability map for each class instead of a classified image
 * @return the task, or null if the calling thread is already interrupted;
 *         the task yields an array indexed as [channel][instance], or null
 *         when interrupted or when the classifier throws
 */
private static Callable<double[][]> classifyInstances(final FeatureStackArray fsa, final Instances dataInfo,
        final int first, final int numInstances, final AbstractClassifier classifier,
        final AtomicInteger counter, final boolean probabilityMaps) {
    if (Thread.currentThread().isInterrupted()) {
        return null;
    }

    return new Callable<double[][]>() {
        public double[][] call() {
            final int width = fsa.getWidth();
            final int height = fsa.getHeight();
            final int pixelsPerSlice = width * height;
            final int numClasses = dataInfo.numClasses();
            final double[][] output = new double[probabilityMaps ? numClasses : 1][numInstances];

            try {
                for (int idx = 0; idx < numInstances; idx++) {
                    if (idx % 4000 == 0) {
                        if (Thread.currentThread().isInterrupted()) {
                            return null;
                        }
                        counter.addAndGet(4000);
                    }

                    // Translate the 1D position into (slice, x, y)
                    final int position = first + idx;
                    final int sliceIndex = position / pixelsPerSlice;
                    final int inSlice = position % pixelsPerSlice;

                    final DenseInstance instance = fsa.get(sliceIndex).createInstance(inSlice % width,
                            inSlice / width, 0);
                    instance.setDataset(dataInfo);

                    if (!probabilityMaps) {
                        output[0][idx] = classifier.classifyInstance(instance);
                        continue;
                    }
                    final double[] distribution = classifier.distributionForInstance(instance);
                    for (int cls = 0; cls < numClasses; cls++) {
                        output[cls][idx] = distribution[cls];
                    }
                }
            } catch (Exception e) {
                IJ.showMessage("Could not apply Classifier!");
                e.printStackTrace();
                return null;
            }
            return output;
        }
    };
}
From source file:trainableSegmentation.WekaSegmentation.java
License:GNU General Public License
/** * Classify instances concurrently/* www . j ava 2 s . com*/ * * @param data set of instances to classify * @param classifier current classifier * @param counter auxiliary counter to be able to update the progress bar * @param probabilityMaps return a probability map for each class instead of a * classified image * @return classification result */ private static Callable<double[][]> classifyInstances(final Instances data, final AbstractClassifier classifier, final AtomicInteger counter, final boolean probabilityMaps) { if (Thread.currentThread().isInterrupted()) return null; return new Callable<double[][]>() { public double[][] call() { final int numInstances = data.numInstances(); final int numClasses = data.numClasses(); final double[][] classificationResult; if (probabilityMaps) classificationResult = new double[numClasses][numInstances]; else classificationResult = new double[1][numInstances]; for (int i = 0; i < numInstances; i++) { try { if (0 == i % 4000) { if (Thread.currentThread().isInterrupted()) return null; counter.addAndGet(4000); } if (probabilityMaps) { double[] prob = classifier.distributionForInstance(data.get(i)); for (int k = 0; k < numClasses; k++) classificationResult[k][i] = prob[k]; } else { classificationResult[0][i] = classifier.classifyInstance(data.get(i)); } } catch (Exception e) { IJ.showMessage("Could not apply Classifier!"); e.printStackTrace(); return null; } } return classificationResult; } }; }
From source file:tubes2ai.AIJKFFNN.java
/**
 * Builds the feed-forward network for the given training data and trains it
 * for {@code maxIterations} passes over the instances (at least one pass is
 * always made).
 *
 * <p>The network has {@code numAttributes() - 1} input neurons, one output
 * neuron per class and, when {@code nHiddenLayer > 0}, a single hidden layer
 * of {@code nHiddenNeuron} neurons.
 *
 * @param instances training data
 * @throws Exception if the capabilities test rejects the data
 */
@Override
public void buildClassifier(Instances instances) throws Exception {
    getCapabilities().testWithFail(instances);

    /* Initialise each layer */
    int nInputNeuron = instances.numAttributes() - 1;
    int nOutputNeuron = instances.numClasses();
    inputLayer = new Vector<Neuron>(nInputNeuron);
    hiddenLayer = new Vector<Neuron>(nHiddenNeuron);
    outputLayer = new Vector<Neuron>(nOutputNeuron);

    Random random = new Random(getSeed());

    Enumeration<Attribute> attributeEnumeration = instances.enumerateAttributes();
    attributeList = Collections.list(attributeEnumeration);

    /* Fill the layers with neurons carrying default weights */
    for (int k = 0; k < nOutputNeuron; k++) {
        outputLayer.add(new Neuron());
    }
    for (int k = 0; k < nInputNeuron; k++) {
        inputLayer.add(new Neuron());
    }
    /* Only when a hidden layer is configured */
    if (nHiddenLayer > 0) {
        for (int j = 0; j < nHiddenNeuron; j++) {
            hiddenLayer.add(new Neuron());
        }
    }

    /* Link the layers */
    if (nHiddenLayer > 0) {
        linkNeurons(inputLayer, hiddenLayer);
        linkNeurons(hiddenLayer, outputLayer);
    } else {
        linkNeurons(inputLayer, outputLayer);
    }

    /* Initialisation order (input, hidden, output) fixes how the seeded random stream is consumed */
    for (Neuron neuron : inputLayer) {
        neuron.initialize(random);
    }

    inputLayerArray = inputLayer.toArray(new Neuron[0]);

    /*
     * Hidden neurons first, then output neurons. The array is sized from the
     * neurons that actually exist, so it never contains null slots even if
     * nHiddenLayer is larger than the single hidden layer built above.
     */
    outputCalculationArray = new Neuron[hiddenLayer.size() + outputLayer.size()];
    int j = 0;
    for (Neuron neuron : hiddenLayer) {
        outputCalculationArray[j++] = neuron;
    }
    for (Neuron neuron : outputLayer) {
        outputCalculationArray[j++] = neuron;
    }

    if (nHiddenLayer > 0) {
        for (Neuron neuron : hiddenLayer) {
            neuron.initialize(random);
        }
    }
    for (Neuron neuron : outputLayer) {
        neuron.initialize(random);
    }

    /* Learning */
    int iterations = 0;
    do {
        for (Instance instance : instances) {
            /* Feed the instance into the input neurons */
            loadInput(instance);

            /* Compute the error from the output layer back towards the input */
            /* Target is 1 for the neuron of the instance's class, 0 for all others */
            int targetClass = (int) instance.classValue();
            for (int ix = 0; ix < outputLayer.size(); ix++) {
                outputLayer.get(ix).errorFromTarget(ix == targetClass ? 1 : 0);
            }
            if (nHiddenLayer != 0) {
                for (Neuron nHid : hiddenLayer) {
                    nHid.calculateError();
                }
            }

            /* Update weights */
            for (Neuron neuron : outputCalculationArray) {
                neuron.updateWeights(learningRate);
            }
        }

        iterations++;

        if (iterations % 500 == 0) {
            System.out.println("FFNN iteration " + iterations);
        }
    } while (iterations < maxIterations);
}
From source file:uzholdem.classifier.OnlineMultilayerPerceptron.java
License:Open Source License
/**
 * Call this function to build and train a neural network for the training
 * data provided.
 *
 * <p>Outline of what the method does:
 * <ol>
 * <li>Copies the data and drops instances with a missing class; if only the
 * class attribute is present a ZeroR model is built instead and the method
 * returns.</li>
 * <li>Resets all network and GUI state, randomizes the data with
 * {@code m_randomSeed} and optionally applies the nominal-to-binary
 * filter.</li>
 * <li>When {@code m_valSize > 0}, copies the first {@code numInVal}
 * instances into a validation set; training uses the instances from index
 * {@code numInVal} onwards.</li>
 * <li>Runs up to {@code m_numEpochs} epochs. With validation enabled,
 * training stops once the validation error has failed to improve more than
 * {@code m_driftThreshold} times in a row, and the best saved weights are
 * restored.</li>
 * <li>On exit {@code m_instances} is replaced by an empty copy that keeps
 * only the header.</li>
 * </ol>
 *
 * @param i The training data.
 * @throws Exception if can't build classification properly.
 */
public void buildClassifier(Instances i) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(i);

    // remove instances with missing class (works on a copy, the caller's data is untouched)
    i = new Instances(i);
    i.deleteWithMissingClass();

    // only class? -> build ZeroR model
    if (i.numAttributes() == 1) {
        System.err.println(
                "Cannot build model (only class attribute present in data!), " + "using ZeroR model instead!");
        m_ZeroR = new weka.classifiers.rules.ZeroR();
        m_ZeroR.buildClassifier(i);
        return;
    } else {
        m_ZeroR = null;
    }

    // reset every piece of state left over from a previous build
    m_epoch = 0;
    m_error = 0;
    m_instances = null;
    m_currentInstance = null;
    m_controlPanel = null;
    m_nodePanel = null;

    m_outputs = new NeuralEnd[0];
    m_inputs = new NeuralEnd[0];
    m_numAttributes = 0;
    m_numClasses = 0;
    m_neuralNodes = new NeuralConnection[0];

    m_selected = new FastVector(4);
    m_graphers = new FastVector(2);
    m_nextId = 0;
    m_stopIt = true;
    m_stopped = true;
    m_accepted = false;
    m_instances = new Instances(i);
    m_random = new Random(m_randomSeed);
    m_instances.randomize(m_random);

    if (m_useNomToBin) {
        m_nominalToBinaryFilter = new NominalToBinary();
        m_nominalToBinaryFilter.setInputFormat(m_instances);
        m_instances = Filter.useFilter(m_instances, m_nominalToBinaryFilter);
    }
    m_numAttributes = m_instances.numAttributes() - 1;
    m_numClasses = m_instances.numClasses();

    setClassType(m_instances);

    // this sets up the validation set.
    Instances valSet = null;
    // numInVal is needed later: m_valSize is used as a percentage of the instances
    int numInVal = (int) (m_valSize / 100.0 * m_instances.numInstances());
    if (m_valSize > 0) {
        if (numInVal == 0) {
            numInVal = 1;
        }
        // validation set = copy of the first numInVal (already randomized) instances
        valSet = new Instances(m_instances, 0, numInVal);
    }
    ///////////

    setupInputs();

    setupOutputs();
    if (m_autoBuild) {
        setupHiddenLayer();
    }

    /////////////////////////////
    // this sets up the gui for usage
    if (m_gui) {
        m_win = new JFrame();

        m_win.addWindowListener(new WindowAdapter() {
            public void windowClosing(WindowEvent e) {
                // pause training while the dialog is shown, remember the previous state
                boolean k = m_stopIt;
                m_stopIt = true;
                int well = JOptionPane
                        .showConfirmDialog(m_win,
                                "Are You Sure...\n" + "Click Yes To Accept" + " The Neural Network"
                                        + "\n Click No To Return",
                                "Accept Neural Network", JOptionPane.YES_NO_OPTION);

                // 0 is compared as the "Yes" answer: accept the network and let the window close
                if (well == 0) {
                    m_win.setDefaultCloseOperation(JFrame.DISPOSE_ON_CLOSE);
                    m_accepted = true;
                    blocker(false);
                } else {
                    m_win.setDefaultCloseOperation(JFrame.DO_NOTHING_ON_CLOSE);
                }
                m_stopIt = k;
            }
        });

        m_win.getContentPane().setLayout(new BorderLayout());
        m_win.setTitle("Neural Network");
        m_nodePanel = new NodePanel();
        // without the following two lines, the NodePanel.paintComponents(Graphics)
        // method will go berserk if the network doesn't fit completely: it will
        // get called on a constant basis, using 100% of the CPU
        // see the following forum thread:
        // http://forum.java.sun.com/thread.jspa?threadID=580929&messageID=2945011
        m_nodePanel.setPreferredSize(new Dimension(640, 480));
        m_nodePanel.revalidate();

        JScrollPane sp = new JScrollPane(m_nodePanel, JScrollPane.VERTICAL_SCROLLBAR_ALWAYS,
                JScrollPane.HORIZONTAL_SCROLLBAR_NEVER);
        m_controlPanel = new ControlPanel();

        m_win.getContentPane().add(sp, BorderLayout.CENTER);
        m_win.getContentPane().add(m_controlPanel, BorderLayout.SOUTH);
        m_win.setSize(640, 480);
        m_win.setVisible(true);
    }

    // This sets up the initial state of the gui
    if (m_gui) {
        blocker(true);
        m_controlPanel.m_changeEpochs.setEnabled(false);
        m_controlPanel.m_changeLearning.setEnabled(false);
        m_controlPanel.m_changeMomentum.setEnabled(false);
    }

    // For silly situations in which the network gets accepted before training
    // commences
    if (m_numeric) {
        setEndsToLinear();
    }
    if (m_accepted) {
        m_win.dispose();
        m_controlPanel = null;
        m_nodePanel = null;
        // keep only the header of the training data
        m_instances = new Instances(m_instances, 0);
        return;
    }

    // connections done.
    double right = 0;
    double driftOff = 0;
    double lastRight = Double.POSITIVE_INFINITY;
    double bestError = Double.POSITIVE_INFINITY;
    double tempRate;
    double totalWeight = 0;
    double totalValWeight = 0;
    double origRate = m_learningRate; // only used for when reset

    // ensure that at least 1 instance is trained through.
    if (numInVal == m_instances.numInstances()) {
        numInVal--;
    }
    if (numInVal < 0) {
        numInVal = 0;
    }
    // total weight of the training part (instances from numInVal onwards)
    for (int noa = numInVal; noa < m_instances.numInstances(); noa++) {
        if (!m_instances.instance(noa).classIsMissing()) {
            totalWeight += m_instances.instance(noa).weight();
        }
    }
    // total weight of the validation part
    if (m_valSize != 0) {
        for (int noa = 0; noa < valSet.numInstances(); noa++) {
            if (!valSet.instance(noa).classIsMissing()) {
                totalValWeight += valSet.instance(noa).weight();
            }
        }
    }
    m_stopped = false;

    // epochs are counted from 1 so that noa can be used directly as the decay divisor
    for (int noa = 1; noa < m_numEpochs + 1; noa++) {
        right = 0;
        for (int nob = numInVal; nob < m_instances.numInstances(); nob++) {
            m_currentInstance = m_instances.instance(nob);

            if (!m_currentInstance.classIsMissing()) {

                // this is where the network updating (and training occurs, for the
                // training set
                resetNetwork();
                calculateOutputs();
                tempRate = m_learningRate * m_currentInstance.weight();
                if (m_decay) {
                    tempRate /= noa;
                }

                right += (calculateErrors() / m_instances.numClasses()) * m_currentInstance.weight();
                updateNetworkWeights(tempRate, m_momentum);

            }

        }
        right /= totalWeight;
        if (Double.isInfinite(right) || Double.isNaN(right)) {
            if (!m_reset) {
                m_instances = null;
                throw new Exception("Network cannot train. Try restarting with a" + " smaller learning rate.");
            } else {
                // reset the network if possible: retry the whole build with half the learning rate
                if (m_learningRate <= Utils.SMALL)
                    throw new IllegalStateException(
                            "Learning rate got too small (" + m_learningRate + " <= " + Utils.SMALL + ")!");
                m_learningRate /= 2;
                buildClassifier(i);
                // the recursive build is finished; restore the rate the caller configured
                m_learningRate = origRate;
                m_instances = new Instances(m_instances, 0);
                return;
            }
        }

        //////////////////////// do validation testing if applicable
        if (m_valSize != 0) {
            right = 0;
            for (int nob = 0; nob < valSet.numInstances(); nob++) {
                m_currentInstance = valSet.instance(nob);
                if (!m_currentInstance.classIsMissing()) {
                    // the error is only measured on the validation set, no weights are updated
                    resetNetwork();
                    calculateOutputs();
                    right += (calculateErrors() / valSet.numClasses()) * m_currentInstance.weight();
                    // note 'right' could be calculated here just using
                    // the calculate output values. This would be faster.
                    // be less modular
                }

            }

            // the comparisons below use the weighted error sum; it is normalised only afterwards
            if (right < lastRight) {
                if (right < bestError) {
                    bestError = right;
                    // save the network weights at this point
                    for (int noc = 0; noc < m_numClasses; noc++) {
                        m_outputs[noc].saveWeights();
                    }
                    driftOff = 0;
                }
            } else {
                driftOff++;
            }
            lastRight = right;
            // stop when the error drifted off too often or the epochs are used up:
            // go back to the best weights seen so far
            if (driftOff > m_driftThreshold || noa + 1 >= m_numEpochs) {
                for (int noc = 0; noc < m_numClasses; noc++) {
                    m_outputs[noc].restoreWeights();
                }
                m_accepted = true;
            }
            right /= totalValWeight;
        }
        m_epoch = noa;
        m_error = right;
        // shows what the neuralnet is upto if a gui exists.
        updateDisplay();
        // This junction controls what state the gui is in at the end of each
        // epoch, Such as if it is paused, if it is resumable etc...
        if (m_gui) {
            // loops while training is paused, or finished without validation, and not yet accepted
            while ((m_stopIt || (m_epoch >= m_numEpochs && m_valSize == 0)) && !m_accepted) {
                m_stopIt = true;
                m_stopped = true;
                if (m_epoch >= m_numEpochs && m_valSize == 0) {

                    m_controlPanel.m_startStop.setEnabled(false);
                } else {
                    m_controlPanel.m_startStop.setEnabled(true);
                }
                m_controlPanel.m_startStop.setText("Start");
                m_controlPanel.m_startStop.setActionCommand("Start");
                m_controlPanel.m_changeEpochs.setEnabled(true);
                m_controlPanel.m_changeLearning.setEnabled(true);
                m_controlPanel.m_changeMomentum.setEnabled(true);

                blocker(true);
                if (m_numeric) {
                    setEndsToLinear();
                }
            }
            m_controlPanel.m_changeEpochs.setEnabled(false);
            m_controlPanel.m_changeLearning.setEnabled(false);
            m_controlPanel.m_changeMomentum.setEnabled(false);

            m_stopped = false;
            // if the network has been accepted stop the training loop
            if (m_accepted) {
                m_win.dispose();
                m_controlPanel = null;
                m_nodePanel = null;
                m_instances = new Instances(m_instances, 0);
                return;
            }
        }
        if (m_accepted) {
            m_instances = new Instances(m_instances, 0);
            return;
        }
    }
    if (m_gui) {
        m_win.dispose();
        m_controlPanel = null;
        m_nodePanel = null;
    }
    // training data is no longer needed: keep only the header
    m_instances = new Instances(m_instances, 0);
}
From source file:uzholdem.classifier.OnlineMultilayerPerceptron.java
License:Open Source License
public void trainModel(Instances aInstances, int numIterations) throws Exception { // setup m_instances if (this.m_instances == null) { this.m_instances = new Instances(aInstances, 0, aInstances.size()); }//w w w . j a v a2s . com /////////// if (m_useNomToBin) { if (this.m_nominalToBinaryFilter == null) { m_nominalToBinaryFilter = new NominalToBinary(); try { m_nominalToBinaryFilter.setInputFormat(m_instances); } catch (Exception e) { // TODO Auto-generated catch block e.printStackTrace(); return; } } aInstances = Filter.useFilter(aInstances, m_nominalToBinaryFilter); } Instances epochInstances = new Instances(aInstances); epochInstances.randomize(new Random()); Instances valSet = new Instances(aInstances, (int) (aInstances.size() * 0.3)); for (int i = 0; i < valSet.size(); i++) { valSet.add(epochInstances.instance(0)); epochInstances.delete(0); } m_instances = epochInstances; double right = 0; double driftOff = 0; double lastRight = Double.POSITIVE_INFINITY; double bestError = Double.POSITIVE_INFINITY; double tempRate; double totalWeight = 0; double totalValWeight = 0; double origRate = m_learningRate; //only used for when reset int numInVal = valSet.numInstances(); for (int noa = numInVal; noa < m_instances.numInstances(); noa++) { if (!m_instances.instance(noa).classIsMissing()) { totalWeight += m_instances.instance(noa).weight(); } } if (m_valSize != 0) { for (int noa = 0; noa < valSet.numInstances(); noa++) { if (!valSet.instance(noa).classIsMissing()) { totalValWeight += valSet.instance(noa).weight(); } } } m_stopped = false; for (int noa = 1; noa < 50 + 1; noa++) { right = 0; for (int nob = numInVal; nob < m_instances.numInstances(); nob++) { m_currentInstance = m_instances.instance(nob); if (!m_currentInstance.classIsMissing()) { //this is where the network updating (and training occurs, for the //training set resetNetwork(); calculateOutputs(); tempRate = m_learningRate * m_currentInstance.weight(); if (m_decay) { tempRate /= noa; } right += 
(calculateErrors() / m_instances.numClasses()) * m_currentInstance.weight(); updateNetworkWeights(tempRate, m_momentum); } } right /= totalWeight; if (Double.isInfinite(right) || Double.isNaN(right)) { m_instances = null; throw new Exception("Network cannot train. Try restarting with a" + " smaller learning rate."); } ////////////////////////do validation testing if applicable if (m_valSize != 0) { right = 0; for (int nob = 0; nob < valSet.numInstances(); nob++) { m_currentInstance = valSet.instance(nob); if (!m_currentInstance.classIsMissing()) { //this is where the network updating occurs, for the validation set resetNetwork(); calculateOutputs(); right += (calculateErrors() / valSet.numClasses()) * m_currentInstance.weight(); //note 'right' could be calculated here just using //the calculate output values. This would be faster. //be less modular } } if (right < lastRight) { if (right < bestError) { bestError = right; // save the network weights at this point for (int noc = 0; noc < m_numClasses; noc++) { m_outputs[noc].saveWeights(); } driftOff = 0; } } else { driftOff++; } lastRight = right; if (driftOff > m_driftThreshold || noa + 1 >= m_numEpochs) { for (int noc = 0; noc < m_numClasses; noc++) { m_outputs[noc].restoreWeights(); } m_accepted = true; } right /= totalValWeight; } m_epoch = noa; m_error = right; //shows what the neuralnet is upto if a gui exists. if (m_accepted) { m_instances = new Instances(m_instances, 0); return; } } }