List of usage examples for weka.core Instances add
@Override public boolean add(Instance instance)
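For orientation before the project examples, here is a minimal self-contained sketch of the call (not taken from any of the projects below), assuming the Weka 3.7+ API where DenseInstance is the standard concrete Instance implementation; the class name AddExample and the attribute names are made up for illustration:

    import java.util.ArrayList;
    import weka.core.Attribute;
    import weka.core.DenseInstance;
    import weka.core.Instances;

    public class AddExample {
        public static void main(String[] args) {
            // declare two numeric attributes and an empty dataset
            ArrayList<Attribute> atts = new ArrayList<Attribute>();
            atts.add(new Attribute("x"));
            atts.add(new Attribute("y"));
            Instances data = new Instances("demo", atts, 0);

            // add() appends a shallow copy of the instance, attaches the copy
            // to the dataset, and returns true (java.util.List semantics)
            boolean added = data.add(new DenseInstance(1.0, new double[] { 1.0, 2.0 }));
            System.out.println(added + "\n" + data);
        }
    }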
From source file:LVCoref.WekaWrapper.java
License:Open Source License
public static void main1(String[] args) throws Exception {
    FastVector atts;
    FastVector attsRel;
    FastVector attVals;
    FastVector attValsRel;
    Instances data;
    Instances dataRel;
    double[] vals;
    double[] valsRel;
    int i;

    // 1. set up attributes
    atts = new FastVector();
    // - numeric
    atts.addElement(new Attribute("att1"));
    // - nominal
    attVals = new FastVector();
    for (i = 0; i < 5; i++)
        attVals.addElement("val" + (i + 1));
    atts.addElement(new Attribute("att2", attVals));
    // - string
    atts.addElement(new Attribute("att3", (FastVector) null));
    // - date
    atts.addElement(new Attribute("att4", "yyyy-MM-dd"));
    // - relational
    attsRel = new FastVector();
    // -- numeric
    attsRel.addElement(new Attribute("att5.1"));
    // -- nominal
    attValsRel = new FastVector();
    for (i = 0; i < 5; i++)
        attValsRel.addElement("val5." + (i + 1));
    attsRel.addElement(new Attribute("att5.2", attValsRel));
    dataRel = new Instances("att5", attsRel, 0);
    atts.addElement(new Attribute("att5", dataRel, 0));

    // 2. create Instances object
    data = new Instances("MyRelation", atts, 0);

    // 3. fill with data
    // first instance
    vals = new double[data.numAttributes()];
    // - numeric
    vals[0] = Math.PI;
    // - nominal
    vals[1] = attVals.indexOf("val3");
    // - string
    vals[2] = data.attribute(2).addStringValue("This is a string!");
    // - date
    vals[3] = data.attribute(3).parseDate("2001-11-09");
    // - relational
    dataRel = new Instances(data.attribute(4).relation(), 0);
    // -- first instance
    valsRel = new double[2];
    valsRel[0] = Math.PI + 1;
    valsRel[1] = attValsRel.indexOf("val5.3");
    dataRel.add(new Instance(1.0, valsRel));
    // -- second instance
    valsRel = new double[2];
    valsRel[0] = Math.PI + 2;
    valsRel[1] = attValsRel.indexOf("val5.2");
    dataRel.add(new Instance(1.0, valsRel));
    vals[4] = data.attribute(4).addRelation(dataRel);
    // add
    data.add(new Instance(1.0, vals));

    // second instance
    vals = new double[data.numAttributes()]; // important: needs NEW array!
    // - numeric
    vals[0] = Math.E;
    // - nominal
    vals[1] = attVals.indexOf("val1");
    // - string
    vals[2] = data.attribute(2).addStringValue("And another one!");
    // - date
    vals[3] = data.attribute(3).parseDate("2000-12-01");
    // - relational
    dataRel = new Instances(data.attribute(4).relation(), 0);
    // -- first instance
    valsRel = new double[2];
    valsRel[0] = Math.E + 1;
    valsRel[1] = attValsRel.indexOf("val5.4");
    dataRel.add(new Instance(1.0, valsRel));
    // -- second instance
    valsRel = new double[2];
    valsRel[0] = Math.E + 2;
    valsRel[1] = attValsRel.indexOf("val5.1");
    dataRel.add(new Instance(1.0, valsRel));
    vals[4] = data.attribute(4).addRelation(dataRel);
    // add
    data.add(new Instance(1.0, vals));

    // 4. output data
    System.out.println(data);
}
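Note that this example is written against the pre-3.7 Weka API: FastVector has since been deprecated in favour of java.util.ArrayList, and weka.core.Instance became an interface, so on Weka 3.7+ the calls above would read data.add(new DenseInstance(1.0, vals)) instead of data.add(new Instance(1.0, vals)).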
From source file:machinelearningproject.RFTree.java
public Instances bootstrap(Instances instances) {
    // sample numInstances() instances uniformly at random, with replacement
    Instances randomInstances = new Instances(instances, instances.numInstances());
    Random random = new Random(); // create the generator once, not per draw
    for (int i = 0; i < instances.numInstances(); i++) {
        int rand = random.nextInt(instances.numInstances());
        randomInstances.add(instances.get(rand));
    }
    return randomInstances;
}
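This is a bootstrap sample: n instances drawn with replacement from an n-instance dataset, so the same instance may be passed to add() several times. That is safe here because, at least in current Weka releases, add() appends a shallow copy of each instance rather than the instance itself.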
From source file:machinelearning_cw.MachineLearning_CW.java
/**
 * Tests the accuracy of a classifier against a collection of datasets
 * by resampling.
 *
 * @param classifier The classifier to be tested
 * @param trainingDatasets A collection of Instances objects containing
 *                         the training data for different datasets.
 * @param testDatasets A collection of Instances objects containing
 *                     the test data for different datasets.
 * @param t The number of times the data should be sampled
 * @throws Exception
 */
public static void performClassifierAccuracyTests(Classifier classifier,
        ArrayList<Instances> trainingDatasets, ArrayList<Instances> testDatasets,
        int t) throws Exception {
    ArrayList<Double> accuracies = new ArrayList<Double>();
    Random randomGenerator = new Random();

    for (int i = 0; i < trainingDatasets.size(); i++) {
        Instances train = trainingDatasets.get(i);
        Instances test = testDatasets.get(i);

        /* Test by resampling. First, merge train and test data */
        for (int j = 0; j < t; j++) {
            Instances mergedDataSet = mergeDataSets(train, test);
            train.clear();
            test.clear();

            /* Randomly sample n instances from the merged dataset
             * (without replacement) to form the train set */
            int n = mergedDataSet.size() / 2;
            for (int k = 0; k < n; k++) {
                int indexToRemove = randomGenerator.nextInt(mergedDataSet.size());
                train.add(mergedDataSet.remove(indexToRemove));
            }

            /* Reserve the remaining data as test data. Drain from the front:
             * removing index k while k advances would skip every other instance. */
            while (!mergedDataSet.isEmpty()) {
                test.add(mergedDataSet.remove(0));
            }

            /* Train classifier. Recalculates k */
            classifier.buildClassifier(train);

            /* Measure and record the accuracy of the classifier on
             * the test set */
            double accuracy = Helpers.findClassifierAccuracy(classifier, test);
            accuracies.add(accuracy);
        }

        double accuracyAverage = average(accuracies);
        System.out.println(accuracyAverage);
    }
}
From source file:machinelearning_cw.MachineLearning_CW.java
public static Instances mergeDataSets(Instances datasetA, Instances datasetB) {
    Instances mergedDataSet = new Instances(datasetA);
    for (Instance inst : datasetB) {
        mergedDataSet.add(inst);
    }
    return mergedDataSet;
}
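mergeDataSets relies on the two datasets having an identical attribute structure: add() does not validate the incoming instance against the receiving header, so values from an incompatible dataset would silently be reinterpreted against the wrong attributes. Instances.checkInstance(Instance) can be used for an explicit compatibility check if that is a concern.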
From source file:marytts.tools.newlanguage.LTSTrainer.java
License:Open Source License
/**
 * Train the tree, using binary decision nodes.
 *
 * @param minLeafData
 *            the minimum number of instances that have to occur in at least two subsets induced by split
 * @return bigTree
 * @throws IOException
 *             IOException
 */
public CART trainTree(int minLeafData) throws IOException {
    Map<String, List<String[]>> grapheme2align = new HashMap<String, List<String[]>>();
    for (String gr : this.graphemeSet) {
        grapheme2align.put(gr, new ArrayList<String[]>());
    }

    Set<String> phChains = new HashSet<String>();

    // for every alignment pair collect counts
    for (int i = 0; i < this.inSplit.size(); i++) {
        StringPair[] alignment = this.getAlignment(i);
        for (int inNr = 0; inNr < alignment.length; inNr++) {
            // System.err.println(alignment[inNr]);
            // quotation signs needed to represent empty string
            String outAlNr = "'" + alignment[inNr].getString2() + "'";
            // TODO: don't consider alignments to more than three characters
            if (outAlNr.length() > 5)
                continue;
            phChains.add(outAlNr);
            // storing context and target
            String[] datapoint = new String[2 * context + 2];
            for (int ct = 0; ct < 2 * context + 1; ct++) {
                int pos = inNr - context + ct;
                if (pos >= 0 && pos < alignment.length) {
                    datapoint[ct] = alignment[pos].getString1();
                } else {
                    datapoint[ct] = "null";
                }
            }
            // set target
            datapoint[2 * context + 1] = outAlNr;
            // add datapoint
            grapheme2align.get(alignment[inNr].getString1()).add(datapoint);
        }
    }

    // for conversion need feature definition file
    FeatureDefinition fd = this.graphemeFeatureDef(phChains);
    int centerGrapheme = fd.getFeatureIndex("att" + (context + 1));
    List<CART> stl = new ArrayList<CART>(fd.getNumberOfValues(centerGrapheme));

    for (String gr : fd.getPossibleValues(centerGrapheme)) {
        System.out.println("  Training decision tree for: " + gr);
        logger.debug("  Training decision tree for: " + gr);

        ArrayList<Attribute> attributeDeclarations = new ArrayList<Attribute>();

        // attributes with values
        for (int att = 1; att <= context * 2 + 1; att++) {
            // ...collect possible values
            ArrayList<String> attVals = new ArrayList<String>();
            String featureName = "att" + att;
            for (String usableGrapheme : fd.getPossibleValues(fd.getFeatureIndex(featureName))) {
                attVals.add(usableGrapheme);
            }
            attributeDeclarations.add(new Attribute(featureName, attVals));
        }

        List<String[]> datapoints = grapheme2align.get(gr);

        // maybe training is faster with targets limited to grapheme
        Set<String> graphSpecPh = new HashSet<String>();
        for (String[] dp : datapoints) {
            graphSpecPh.add(dp[dp.length - 1]);
        }

        // target attribute
        // ...collect possible values
        ArrayList<String> targetVals = new ArrayList<String>();
        for (String phc : graphSpecPh) { // todo: use either fd or phChains
            targetVals.add(phc);
        }
        attributeDeclarations.add(new Attribute(TrainedLTS.PREDICTED_STRING_FEATURENAME, targetVals));

        // now, create the dataset adding the datapoints
        Instances data = new Instances(gr, attributeDeclarations, 0);

        // datapoints
        for (String[] point : datapoints) {
            Instance currInst = new DenseInstance(data.numAttributes());
            currInst.setDataset(data);
            for (int i = 0; i < point.length; i++) {
                currInst.setValue(i, point[i]);
            }
            data.add(currInst);
        }

        // Make the last attribute be the class
        data.setClassIndex(data.numAttributes() - 1);

        // build the tree without using the J48 wrapper class
        // standard parameters are:
        // binary split selection with minimum x instances at the leaves, tree is pruned,
        // confidence value, subtree raising, cleanup, don't collapse
        // Here a modified version of C45PruneableClassifierTree is used that allows
        // unary classes (see Issue #51)
        C45PruneableClassifierTree decisionTree;
        try {
            decisionTree = new C45PruneableClassifierTreeWithUnary(
                    new BinC45ModelSelection(minLeafData, data, true), true, 0.25f, true, true, false);
            decisionTree.buildClassifier(data);
        } catch (Exception e) {
            throw new RuntimeException("couldn't train decisiontree using weka: ", e);
        }

        CART maryTree = TreeConverter.c45toStringCART(decisionTree, fd, data);
        stl.add(maryTree);
    }

    DecisionNode.ByteDecisionNode rootNode = new DecisionNode.ByteDecisionNode(centerGrapheme, stl.size(), fd);
    for (CART st : stl) {
        rootNode.addDaughter(st.getRootNode());
    }

    Properties props = new Properties();
    props.setProperty("lowercase", String.valueOf(convertToLowercase));
    props.setProperty("stress", String.valueOf(considerStress));
    props.setProperty("context", String.valueOf(context));

    CART bigTree = new CART(rootNode, fd, props);

    return bigTree;
}
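Worth noting in the datapoint loop above: currInst.setDataset(data) is called before the setValue(i, String) calls, and it must be, since setting a nominal value from a string requires access to the dataset's attribute definitions to map the label to its internal index. data.add(currInst) then appends a copy of the fully populated instance.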
From source file:marytts.tools.voiceimport.PauseDurationTrainer.java
License:Open Source License
public boolean compute() throws Exception {
    // object to store all instances
    Instances data = null;
    FeatureDefinition fd = null;

    // pause durations are added at the end
    // all of them are collected first
    // then discretized
    List<Integer> durs = new ArrayList<Integer>();

    for (int i = 0; i < bnl.getLength(); i++) {
        VectorsAndDefinition features = this.readFeaturesFor(bnl.getName(i));
        if (null == features)
            continue;

        List<FeatureVector> vectors = features.getFv();
        fd = features.getFd();

        if (data == null)
            data = initData(fd);

        // reader for label file.
        BufferedReader lab = new BufferedReader(new FileReader(getProp(LABFILES) + bnl.getName(i) + labExt));

        List<String> labSyms = new ArrayList<String>();
        List<Integer> labDurs = new ArrayList<Integer>();

        int prevTime = 0;
        int currTime = 0;

        String line;
        while ((line = lab.readLine()) != null) {
            if (line.startsWith("#"))
                continue;

            String[] lineLmnts = line.split("\\s+");
            if (lineLmnts.length != 3)
                throw new IllegalArgumentException(
                        "Expected three columns in label file, got " + lineLmnts.length);

            labSyms.add(lineLmnts[2]);

            // collect durations
            currTime = (int) (1000 * Float.parseFloat(lineLmnts[0]));
            int dur = currTime - prevTime;
            labDurs.add(dur);
            prevTime = currTime;
        }

        int symbolFeature = fd.getFeatureIndex("phone");
        int breakindexFeature = fd.getFeatureIndex("breakindex");

        int currLabelNr = 0;

        // treatment of first pause(s)...
        while (labSyms.get(currLabelNr).equals("_"))
            currLabelNr++;

        for (FeatureVector fv : vectors) {
            String fvSym = fv.getFeatureAsString(symbolFeature, fd);

            // all pauses on feature vector side are ignored, they are captured within boundary treatment
            if (fvSym.equals("_"))
                continue;

            if (!fvSym.equals(labSyms.get(currLabelNr)))
                throw new IllegalArgumentException("Phone symbol of label file ("
                        + labSyms.get(currLabelNr) + ") and of feature vector (" + fvSym
                        + ") don't correspond. Run CorrectedTranscriptionAligner first.");

            int pauseDur = 0;

            // durations are taken from pauses on label side
            if ((currLabelNr + 1) < labSyms.size() && labSyms.get(currLabelNr + 1).equals("_")) {
                currLabelNr++;
                pauseDur = labDurs.get(currLabelNr);
            }

            int bi = fv.getFeatureAsInt(breakindexFeature);

            if (bi > 1) {
                // add new training point with fv
                durs.add(pauseDur);
                data.add(createInstance(data, fd, fv));
            } // for each break index > 1

            currLabelNr++;
        } // for each featurevector
    } // for each file

    // set duration target attribute
    data = enterDurations(data, durs);

    // train classifier
    StringPredictionTree wagonTree = trainTree(data, fd);

    FileWriter fw = new FileWriter(getProp(TRAINEDTREE));
    fw.write(wagonTree.toString());
    fw.close();

    return true;
}
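The durs list and the dataset grow in lockstep: every data.add(createInstance(data, fd, fv)) is paired with a durs.add(pauseDur), so enterDurations can later discretize the collected pause durations and attach them index-by-index as the target value of each added instance.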
From source file:matres.MatResUI.java
private void doClassification() {
    J48 m_treeResiko;
    J48 m_treeAksi;
    NaiveBayes m_nbResiko;
    NaiveBayes m_nbAksi;
    FastVector m_fvInstanceRisks;
    FastVector m_fvInstanceActions;

    InputStream isRiskTree = getClass().getResourceAsStream("data/ResikoTree.model");
    InputStream isRiskNB = getClass().getResourceAsStream("data/ResikoNB.model");
    InputStream isActionTree = getClass().getResourceAsStream("data/AksiTree.model");
    InputStream isActionNB = getClass().getResourceAsStream("data/AksiNB.model");

    m_treeResiko = new J48();
    m_treeAksi = new J48();
    m_nbResiko = new NaiveBayes();
    m_nbAksi = new NaiveBayes();

    try {
        //m_treeResiko = (J48) weka.core.SerializationHelper.read("ResikoTree.model");
        m_treeResiko = (J48) weka.core.SerializationHelper.read(isRiskTree);
        //m_nbResiko = (NaiveBayes) weka.core.SerializationHelper.read("ResikoNB.model");
        m_nbResiko = (NaiveBayes) weka.core.SerializationHelper.read(isRiskNB);
        //m_treeAksi = (J48) weka.core.SerializationHelper.read("AksiTree.model");
        m_treeAksi = (J48) weka.core.SerializationHelper.read(isActionTree);
        //m_nbAksi = (NaiveBayes) weka.core.SerializationHelper.read("AksiNB.model");
        m_nbAksi = (NaiveBayes) weka.core.SerializationHelper.read(isActionNB);
    } catch (Exception ex) {
        Logger.getLogger(MatResUI.class.getName()).log(Level.SEVERE, null, ex);
    }

    System.out.println("Setting up an Instance...");

    // Values for LIKELIHOOD OF OCCURRENCE
    FastVector fvLO = new FastVector(5);
    fvLO.addElement("> 10 in 1 year");
    fvLO.addElement("1 - 10 in 1 year");
    fvLO.addElement("1 in 1 year to 1 in 10 years");
    fvLO.addElement("1 in 10 years to 1 in 100 years");
    fvLO.addElement("1 in more than 100 years");
    // Values for SAFETY
    FastVector fvSafety = new FastVector(5);
    fvSafety.addElement("near miss");
    fvSafety.addElement("first aid injury, medical aid injury");
    fvSafety.addElement("lost time injury / temporary disability");
    fvSafety.addElement("permanent disability");
    fvSafety.addElement("fatality");
    // Values for EXTRA FUEL COST
    FastVector fvEFC = new FastVector(5);
    fvEFC.addElement("< 100 million rupiah");
    fvEFC.addElement("0,1 - 1 billion rupiah");
    fvEFC.addElement("1 - 10 billion rupiah");
    fvEFC.addElement("10 - 100 billion rupiah");
    fvEFC.addElement("> 100 billion rupiah");
    // Values for SYSTEM RELIABILITY
    FastVector fvSR = new FastVector(5);
    fvSR.addElement("< 100 MWh");
    fvSR.addElement("0,1 - 1 GWh");
    fvSR.addElement("1 - 10 GWh");
    fvSR.addElement("10 - 100 GWh");
    fvSR.addElement("> 100 GWh");
    // Values for EQUIPMENT COST
    FastVector fvEC = new FastVector(5);
    fvEC.addElement("< 50 million rupiah");
    fvEC.addElement("50 - 500 million rupiah");
    fvEC.addElement("0,5 - 5 billion rupiah");
    fvEC.addElement("5 - 50 billion rupiah");
    fvEC.addElement("> 50 billion rupiah");
    // Values for CUSTOMER SATISFACTION SOCIAL FACTOR
    FastVector fvCSSF = new FastVector(5);
    fvCSSF.addElement("Complaint from the VIP customer");
    fvCSSF.addElement("Complaint from industrial customer");
    fvCSSF.addElement("Complaint from community");
    fvCSSF.addElement("Complaint from community that have potential riot");
    fvCSSF.addElement("High potential riot");
    // Values for RISK
    FastVector fvRisk = new FastVector(4);
    fvRisk.addElement("Low");
    fvRisk.addElement("Moderate");
    fvRisk.addElement("High");
    fvRisk.addElement("Extreme");
    // Values for ACTION
    FastVector fvAction = new FastVector(3);
    fvAction.addElement("Life Extension Program");
    fvAction.addElement("Repair/Refurbish");
    fvAction.addElement("Replace/Run to Fail + Investment");

    // Defining Attributes, including Class(es) Attributes
    Attribute attrLO = new Attribute("LO", fvLO);
    Attribute attrSafety = new Attribute("Safety", fvSafety);
    Attribute attrEFC = new Attribute("EFC", fvEFC);
    Attribute attrSR = new Attribute("SR", fvSR);
    Attribute attrEC = new Attribute("EC", fvEC);
    Attribute attrCSSF = new Attribute("CSSF", fvCSSF);
    Attribute attrRisk = new Attribute("Risk", fvRisk);
    Attribute attrAction = new Attribute("Action", fvAction);

    m_fvInstanceRisks = new FastVector(7);
    m_fvInstanceRisks.addElement(attrLO);
    m_fvInstanceRisks.addElement(attrSafety);
    m_fvInstanceRisks.addElement(attrEFC);
    m_fvInstanceRisks.addElement(attrSR);
    m_fvInstanceRisks.addElement(attrEC);
    m_fvInstanceRisks.addElement(attrCSSF);
    m_fvInstanceRisks.addElement(attrRisk);

    m_fvInstanceActions = new FastVector(7);
    m_fvInstanceActions.addElement(attrLO);
    m_fvInstanceActions.addElement(attrSafety);
    m_fvInstanceActions.addElement(attrEFC);
    m_fvInstanceActions.addElement(attrSR);
    m_fvInstanceActions.addElement(attrEC);
    m_fvInstanceActions.addElement(attrCSSF);
    m_fvInstanceActions.addElement(attrAction);

    Instances dataRisk = new Instances("A-Risk-instance-to-classify", m_fvInstanceRisks, 0);
    Instances dataAction = new Instances("An-Action-instance-to-classify", m_fvInstanceActions, 0);

    double[] riskValues = new double[dataRisk.numAttributes()];
    double[] actionValues = new double[dataRisk.numAttributes()];

    String strLO = (String) m_cmbLO.getSelectedItem();
    String strSafety = (String) m_cmbSafety.getSelectedItem();
    String strEFC = (String) m_cmbEFC.getSelectedItem();
    String strSR = (String) m_cmbSR.getSelectedItem();
    String strEC = (String) m_cmbEC.getSelectedItem();
    String strCSSF = (String) m_cmbCSSF.getSelectedItem();

    Instance instRisk = new DenseInstance(7);
    Instance instAction = new DenseInstance(7);

    if (strLO.equals("-- none --")) {
        instRisk.setMissing(0);
        instAction.setMissing(0);
    } else {
        instRisk.setValue((Attribute) m_fvInstanceRisks.elementAt(0), strLO);
        instAction.setValue((Attribute) m_fvInstanceActions.elementAt(0), strLO);
    }
    if (strSafety.equals("-- none --")) {
        instRisk.setMissing(1);
        instAction.setMissing(1);
    } else {
        instRisk.setValue((Attribute) m_fvInstanceRisks.elementAt(1), strSafety);
        instAction.setValue((Attribute) m_fvInstanceActions.elementAt(1), strSafety);
    }
    if (strEFC.equals("-- none --")) {
        instRisk.setMissing(2);
        instAction.setMissing(2);
    } else {
        instRisk.setValue((Attribute) m_fvInstanceRisks.elementAt(2), strEFC);
        instAction.setValue((Attribute) m_fvInstanceActions.elementAt(2), strEFC);
    }
    if (strSR.equals("-- none --")) {
        instRisk.setMissing(3);
        instAction.setMissing(3);
    } else {
        instRisk.setValue((Attribute) m_fvInstanceRisks.elementAt(3), strSR);
        instAction.setValue((Attribute) m_fvInstanceActions.elementAt(3), strSR);
    }
    if (strEC.equals("-- none --")) {
        instRisk.setMissing(4);
        instAction.setMissing(4);
    } else {
        instRisk.setValue((Attribute) m_fvInstanceRisks.elementAt(4), strEC);
        instAction.setValue((Attribute) m_fvInstanceActions.elementAt(4), strEC);
    }
    if (strCSSF.equals("-- none --")) {
        instRisk.setMissing(5);
        instAction.setMissing(5);
    } else {
        instAction.setValue((Attribute) m_fvInstanceActions.elementAt(5), strCSSF);
        instRisk.setValue((Attribute) m_fvInstanceRisks.elementAt(5), strCSSF);
    }
    instRisk.setMissing(6);
    instAction.setMissing(6);

    dataRisk.add(instRisk);
    instRisk.setDataset(dataRisk);
    dataRisk.setClassIndex(dataRisk.numAttributes() - 1);

    dataAction.add(instAction);
    instAction.setDataset(dataAction);
    dataAction.setClassIndex(dataAction.numAttributes() - 1);

    System.out.println("Instance Resiko: " + dataRisk.instance(0));
    System.out.println("\tNum Attributes : " + dataRisk.numAttributes());
    System.out.println("\tNum instances  : " + dataRisk.numInstances());
    System.out.println("Instance Action: " + dataAction.instance(0));
    System.out.println("\tNum Attributes : " + dataAction.numAttributes());
    System.out.println("\tNum instances  : " + dataAction.numInstances());

    int classIndexRisk = 0;
    int classIndexAction = 0;
    String strClassRisk = null;
    String strClassAction = null;

    try {
        //classIndexRisk = (int) m_treeResiko.classifyInstance(dataRisk.instance(0));
        classIndexRisk = (int) m_treeResiko.classifyInstance(instRisk);
        classIndexAction = (int) m_treeAksi.classifyInstance(instAction);
    } catch (Exception ex) {
        Logger.getLogger(MatResUI.class.getName()).log(Level.SEVERE, null, ex);
    }

    strClassRisk = (String) fvRisk.elementAt(classIndexRisk);
    strClassAction = (String) fvAction.elementAt(classIndexAction);

    System.out.println("[Risk   Class Index: " + classIndexRisk + " Class Label: " + strClassRisk + "]");
    System.out.println("[Action Class Index: " + classIndexAction + " Class Label: " + strClassAction + "]");

    if (strClassRisk != null) {
        m_txtRisk.setText(strClassRisk);
    }

    double[] riskDist = null;
    double[] actionDist = null;
    try {
        riskDist = m_nbResiko.distributionForInstance(dataRisk.instance(0));
        actionDist = m_nbAksi.distributionForInstance(dataAction.instance(0));

        String strProb;
        // set up RISK progress bars
        m_jBarRiskLow.setValue((int) (100 * riskDist[0]));
        m_jBarRiskLow.setString(String.format("%6.3f%%", 100 * riskDist[0]));
        m_jBarRiskModerate.setValue((int) (100 * riskDist[1]));
        m_jBarRiskModerate.setString(String.format("%6.3f%%", 100 * riskDist[1]));
        m_jBarRiskHigh.setValue((int) (100 * riskDist[2]));
        m_jBarRiskHigh.setString(String.format("%6.3f%%", 100 * riskDist[2]));
        m_jBarRiskExtreme.setValue((int) (100 * riskDist[3]));
        m_jBarRiskExtreme.setString(String.format("%6.3f%%", 100 * riskDist[3]));
    } catch (Exception ex) {
        Logger.getLogger(MatResUI.class.getName()).log(Level.SEVERE, null, ex);
    }

    double predictedProb = 0.0;
    String predictedClass = "";

    // Loop over all the prediction labels in the distribution.
    for (int predictionDistributionIndex = 0;
            predictionDistributionIndex < riskDist.length;
            predictionDistributionIndex++) {
        // Get this distribution index's class label.
        String predictionDistributionIndexAsClassLabel =
                dataRisk.classAttribute().value(predictionDistributionIndex);
        int classIndex = dataRisk.classAttribute().indexOfValue(predictionDistributionIndexAsClassLabel);

        // Get the probability.
        double predictionProbability = riskDist[predictionDistributionIndex];

        if (predictionProbability > predictedProb) {
            predictedProb = predictionProbability;
            predictedClass = predictionDistributionIndexAsClassLabel;
        }

        System.out.printf("[%2d %10s : %6.3f]", classIndex,
                predictionDistributionIndexAsClassLabel, predictionProbability);
    }

    m_txtRiskNB.setText(predictedClass);
}
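A detail that is easy to miss here: add() appends a copy of instRisk, so the subsequent instRisk.setDataset(dataRisk) only affects the local object. That is exactly what this code needs, because classifyInstance(instRisk) is later called on the local instance, which must carry a dataset reference to resolve its attributes and class index.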
From source file:maui.main.MauiModelBuilder.java
License:Open Source License
/**
 * Builds the model from the training data
 */
public void buildModel(HashSet<String> fileNames, VocabularyStore store) throws Exception {

    // Check whether there is actually any data
    if (fileNames.size() == 0) {
        throw new Exception("Couldn't find any data in " + inputDirectoryName);
    }

    System.err.println("-- Building the model... ");

    FastVector atts = new FastVector(3);
    atts.addElement(new Attribute("filename", (FastVector) null));
    atts.addElement(new Attribute("document", (FastVector) null));
    atts.addElement(new Attribute("keyphrases", (FastVector) null));
    Instances data = new Instances("keyphrase_training_data", atts, 0);

    // Build model
    mauiFilter = new MauiFilter();
    mauiFilter.setDebug(getDebug());
    mauiFilter.setMaxPhraseLength(getMaxPhraseLength());
    mauiFilter.setMinPhraseLength(getMinPhraseLength());
    mauiFilter.setMinNumOccur(getMinNumOccur());
    mauiFilter.setStemmer(getStemmer());
    mauiFilter.setDocumentLanguage(getDocumentLanguage());
    mauiFilter.setVocabularyName(getVocabularyName());
    mauiFilter.setVocabularyFormat(getVocabularyFormat());
    mauiFilter.setStopwords(getStopwords());

    if (wikipedia != null) {
        mauiFilter.setWikipedia(wikipedia);
    } else if (wikipediaServer.equals("localhost") && wikipediaDatabase.equals("database")) {
        mauiFilter.setWikipedia(wikipedia);
    } else {
        mauiFilter.setWikipedia(wikipediaServer, wikipediaDatabase, cacheWikipediaData, wikipediaDataDirectory);
    }

    if (classifier != null) {
        mauiFilter.setClassifier(classifier);
    }

    mauiFilter.setInputFormat(data);

    // set feature configurations
    mauiFilter.setBasicFeatures(useBasicFeatures);
    mauiFilter.setKeyphrasenessFeature(useKeyphrasenessFeature);
    mauiFilter.setFrequencyFeatures(useFrequencyFeatures);
    mauiFilter.setPositionsFeatures(usePositionsFeatures);
    mauiFilter.setLengthFeature(useLengthFeature);
    mauiFilter.setThesaurusFeatures(useNodeDegreeFeature);
    mauiFilter.setBasicWikipediaFeatures(useBasicWikipediaFeatures);
    mauiFilter.setAllWikipediaFeatures(useAllWikipediaFeatures);
    mauiFilter.setThesaurusFeatures(useNodeDegreeFeature);
    mauiFilter.setClassifier(classifier);

    mauiFilter.setContextSize(contextSize);
    mauiFilter.setMinKeyphraseness(minKeyphraseness);
    mauiFilter.setMinSenseProbability(minSenseProbability);

    if (!vocabularyName.equals("none") && !vocabularyName.equals("wikipedia")) {
        mauiFilter.loadThesaurus(getStemmer(), getStopwords(), store);
    }

    System.err.println("-- Reading the input documents... ");

    for (String fileName : fileNames) {

        double[] newInst = new double[3];

        newInst[0] = (double) data.attribute(0).addStringValue(fileName);

        File documentTextFile = new File(inputDirectoryName + "/" + fileName + ".txt");
        File documentTopicsFile = new File(inputDirectoryName + "/" + fileName + ".key");

        try {
            InputStreamReader is;
            if (!documentEncoding.equals("default")) {
                is = new InputStreamReader(new FileInputStream(documentTextFile), documentEncoding);
            } else {
                is = new InputStreamReader(new FileInputStream(documentTextFile));
            }

            // Reading the file content
            StringBuffer txtStr = new StringBuffer();
            int c;
            while ((c = is.read()) != -1) {
                txtStr.append((char) c);
            }
            is.close();

            // Adding the text of the document to the instance
            newInst[1] = (double) data.attribute(1).addStringValue(txtStr.toString());
        } catch (Exception e) {
            System.err.println("Problem with reading " + documentTextFile);
            e.printStackTrace();
            newInst[1] = Instance.missingValue();
        }

        try {
            InputStreamReader is;
            if (!documentEncoding.equals("default")) {
                is = new InputStreamReader(new FileInputStream(documentTopicsFile), documentEncoding);
            } else {
                is = new InputStreamReader(new FileInputStream(documentTopicsFile));
            }

            // Reading the content of the keyphrase file
            StringBuffer keyStr = new StringBuffer();
            int c;
            while ((c = is.read()) != -1) {
                keyStr.append((char) c);
            }

            // Adding the topics to the file
            newInst[2] = (double) data.attribute(2).addStringValue(keyStr.toString());
        } catch (Exception e) {
            System.err.println("Problem with reading " + documentTopicsFile);
            e.printStackTrace();
            newInst[2] = Instance.missingValue();
        }

        data.add(new Instance(1.0, newInst));

        mauiFilter.input(data.instance(0));
        data = data.stringFreeStructure();
    }

    mauiFilter.batchFinished();

    while (mauiFilter.output() != null) {
    }
}
From source file:maui.main.MauiTopicExtractor.java
License:Open Source License
/**
 * Extracts keyphrases from the given files
 */
public void extractKeyphrases(HashSet<String> fileNames, VocabularyStore store) throws Exception {

    // Check whether there is actually any data
    if (fileNames.size() == 0) {
        throw new Exception("Couldn't find any data in " + inputDirectoryName);
    }

    mauiFilter.setVocabularyName(getVocabularyName());
    mauiFilter.setVocabularyFormat(getVocabularyFormat());
    mauiFilter.setDocumentLanguage(getDocumentLanguage());
    mauiFilter.setStemmer(getStemmer());
    mauiFilter.setStopwords(getStopwords());

    if (wikipedia != null) {
        mauiFilter.setWikipedia(wikipedia);
    } else if (wikipediaServer.equals("localhost") && wikipediaDatabase.equals("database")) {
        mauiFilter.setWikipedia(wikipedia);
    } else {
        mauiFilter.setWikipedia(wikipediaServer, wikipediaDatabase, cacheWikipediaData, wikipediaDataDirectory);
    }

    if (!vocabularyName.equals("none") && !vocabularyName.equals("wikipedia")) {
        mauiFilter.loadThesaurus(getStemmer(), getStopwords(), store);
    }

    FastVector atts = new FastVector(3);
    atts.addElement(new Attribute("filename", (FastVector) null));
    atts.addElement(new Attribute("doc", (FastVector) null));
    atts.addElement(new Attribute("keyphrases", (FastVector) null));
    Instances data = new Instances("keyphrase_training_data", atts, 0);

    System.err.println("-- Extracting keyphrases... ");

    Vector<Double> correctStatistics = new Vector<Double>();
    Vector<Double> precisionStatistics = new Vector<Double>();
    Vector<Double> recallStatistics = new Vector<Double>();

    for (String fileName : fileNames) {

        double[] newInst = new double[3];

        newInst[0] = (double) data.attribute(0).addStringValue(fileName);

        File documentTextFile = new File(inputDirectoryName + "/" + fileName + ".txt");
        File documentTopicsFile = new File(inputDirectoryName + "/" + fileName + ".key");

        try {
            InputStreamReader is;
            if (!documentEncoding.equals("default")) {
                is = new InputStreamReader(new FileInputStream(documentTextFile), documentEncoding);
            } else {
                is = new InputStreamReader(new FileInputStream(documentTextFile));
            }

            // Reading the file content
            StringBuffer txtStr = new StringBuffer();
            int c;
            while ((c = is.read()) != -1) {
                txtStr.append((char) c);
            }
            is.close();

            // Adding the text of the document to the instance
            newInst[1] = (double) data.attribute(1).addStringValue(txtStr.toString());
        } catch (Exception e) {
            System.err.println("Problem with reading " + documentTextFile);
            e.printStackTrace();
            newInst[1] = Instance.missingValue();
        }

        try {
            InputStreamReader is;
            if (!documentEncoding.equals("default")) {
                is = new InputStreamReader(new FileInputStream(documentTopicsFile), documentEncoding);
            } else {
                is = new InputStreamReader(new FileInputStream(documentTopicsFile));
            }

            // Reading the content of the keyphrase file
            StringBuffer keyStr = new StringBuffer();
            int c;
            while ((c = is.read()) != -1) {
                keyStr.append((char) c);
            }

            // Adding the topics to the file
            newInst[2] = (double) data.attribute(2).addStringValue(keyStr.toString());
        } catch (Exception e) {
            if (debugMode) {
                System.err.println("No existing topics for " + documentTextFile);
            }
            newInst[2] = Instance.missingValue();
        }

        data.add(new Instance(1.0, newInst));

        mauiFilter.input(data.instance(0));

        data = data.stringFreeStructure();
        if (debugMode) {
            System.err.println("-- Processing document: " + fileName);
        }

        Instance[] topRankedInstances = new Instance[topicsPerDocument];
        Instance inst;

        // Iterating over all extracted keyphrases (inst)
        while ((inst = mauiFilter.output()) != null) {
            int index = (int) inst.value(mauiFilter.getRankIndex()) - 1;
            if (index < topicsPerDocument) {
                topRankedInstances[index] = inst;
            }
        }

        if (debugMode) {
            System.err.println("-- Keyphrases and feature values:");
        }

        FileOutputStream out = null;
        PrintWriter printer = null;

        if (!documentTopicsFile.exists()) {
            out = new FileOutputStream(documentTopicsFile);
            if (!documentEncoding.equals("default")) {
                printer = new PrintWriter(new OutputStreamWriter(out, documentEncoding));
            } else {
                printer = new PrintWriter(out);
            }
        }

        double numExtracted = 0, numCorrect = 0;
        wikipedia = mauiFilter.getWikipedia();

        HashMap<Article, Integer> topics = null;
        if (printGraph) {
            topics = new HashMap<Article, Integer>();
        }

        int p = 0;
        String root = "";
        for (int i = 0; i < topicsPerDocument; i++) {
            if (topRankedInstances[i] != null) {
                if (!topRankedInstances[i].isMissing(topRankedInstances[i].numAttributes() - 1)) {
                    numExtracted += 1.0;
                }
                if ((int) topRankedInstances[i].value(topRankedInstances[i].numAttributes() - 1) == 1) {
                    numCorrect += 1.0;
                }
                if (printer != null) {
                    String topic = topRankedInstances[i].stringValue(mauiFilter.getOutputFormIndex());
                    printer.print(topic);

                    if (printGraph) {
                        Article article = wikipedia.getArticleByTitle(topic);
                        if (article == null) {
                            article = wikipedia.getMostLikelyArticle(topic, new CaseFolder());
                        }
                        if (article != null) {
                            if (root.isEmpty()) { // compare content, not references
                                root = article.getTitle();
                            }
                            topics.put(article, new Integer(p));
                        } else {
                            if (debugMode) {
                                System.err.println("Couldn't find article for " + topic + " in " + documentTopicsFile);
                            }
                        }
                        p++;
                    }

                    if (additionalInfo) {
                        printer.print("\t");
                        printer.print(topRankedInstances[i].stringValue(mauiFilter.getNormalizedFormIndex()));
                        printer.print("\t");
                        printer.print(Utils.doubleToString(
                                topRankedInstances[i].value(mauiFilter.getProbabilityIndex()), 4));
                    }
                    printer.println();
                }
                if (debugMode) {
                    System.err.println(topRankedInstances[i]);
                }
            }
        }

        if (printGraph) {
            String graphFile = documentTopicsFile.getAbsolutePath().replace(".key", ".gv");
            computeGraph(topics, root, graphFile);
        }

        if (numExtracted > 0) {
            if (debugMode) {
                System.err.println("-- " + numCorrect + " correct");
            }
            double totalCorrect = mauiFilter.getTotalCorrect();
            correctStatistics.addElement(new Double(numCorrect));
            precisionStatistics.addElement(new Double(numCorrect / numExtracted));
            recallStatistics.addElement(new Double(numCorrect / totalCorrect));
        }

        if (printer != null) {
            printer.flush();
            printer.close();
            out.close();
        }
    }

    if (correctStatistics.size() != 0) {
        double[] st = new double[correctStatistics.size()];
        for (int i = 0; i < correctStatistics.size(); i++) {
            st[i] = correctStatistics.elementAt(i).doubleValue();
        }
        double avg = Utils.mean(st);
        double stdDev = Math.sqrt(Utils.variance(st));

        if (correctStatistics.size() == 1) {
            System.err.println("\n-- Evaluation results based on 1 document:");
        } else {
            System.err.println("\n-- Evaluation results based on " + correctStatistics.size() + " documents:");
        }
        System.err.println("Avg. number of correct keyphrases per document: "
                + Utils.doubleToString(avg, 2) + " +/- " + Utils.doubleToString(stdDev, 2));

        st = new double[precisionStatistics.size()];
        for (int i = 0; i < precisionStatistics.size(); i++) {
            st[i] = precisionStatistics.elementAt(i).doubleValue();
        }
        double avgPrecision = Utils.mean(st);
        double stdDevPrecision = Math.sqrt(Utils.variance(st));

        System.err.println("Precision: " + Utils.doubleToString(avgPrecision * 100, 2)
                + " +/- " + Utils.doubleToString(stdDevPrecision * 100, 2));

        st = new double[recallStatistics.size()];
        for (int i = 0; i < recallStatistics.size(); i++) {
            st[i] = recallStatistics.elementAt(i).doubleValue();
        }
        double avgRecall = Utils.mean(st);
        double stdDevRecall = Math.sqrt(Utils.variance(st));

        System.err.println("Recall: " + Utils.doubleToString(avgRecall * 100, 2)
                + " +/- " + Utils.doubleToString(stdDevRecall * 100, 2));

        double fMeasure = 2 * avgRecall * avgPrecision / (avgRecall + avgPrecision);
        System.err.println("F-Measure: " + Utils.doubleToString(fMeasure * 100, 2));

        System.err.println("");
    }
    mauiFilter.batchFinished();
}
From source file:meansagnes.MyKMeans.java
@Override
public void buildClusterer(Instances data) throws Exception {
    currentIteration = 0;

    replaceMissingFilter = new ReplaceMissingValues();
    instances = new Instances(data);
    instances.setClassIndex(-1);
    replaceMissingFilter.setInputFormat(instances);
    instances = Filter.useFilter(instances, replaceMissingFilter);

    distanceFunction.setInstances(instances);
    clusterCentroids = new Instances(instances, numCluster);
    clusterAssignments = new int[instances.numInstances()];

    // randomly pick instances to serve as the initial centroids
    Random randomizer = new Random(getSeed());
    int[] instanceAsCentroid = new int[numCluster];
    for (int i = 0; i < numCluster; i++) {
        instanceAsCentroid[i] = -1;
    }
    for (int i = 0; i < numCluster; i++) {
        int centroidCluster = randomizer.nextInt(instances.numInstances());
        boolean found = false;
        for (int j = 0; j < i && !found; j++) {
            if (instanceAsCentroid[j] == centroidCluster) {
                i--; // already used as a centroid: retry this slot
                found = true;
            }
        }
        if (!found) {
            clusterCentroids.add(instances.instance(centroidCluster));
            instanceAsCentroid[i] = centroidCluster;
        }
    }

    double[][] distancesToCentroid = new double[numCluster][instances.numInstances()];
    double[] minDistancesToCentroid = new double[instances.numInstances()];
    boolean converged = false;
    Instances prevCentroids;

    while (!converged) {
        currentIteration++;

        // check distance to each centroid to decide clustering result
        for (int i = 0; i < numCluster; i++) { // i is cluster index
            for (int j = 0; j < instances.numInstances(); j++) { // j is instance index
                distancesToCentroid[i][j] =
                        distanceFunction.distance(clusterCentroids.instance(i), instances.instance(j));
            }
        }
        for (int j = 0; j < instances.numInstances(); j++) {
            minDistancesToCentroid[j] = distancesToCentroid[0][j];
            clusterAssignments[j] = 0;
        }
        for (int j = 0; j < instances.numInstances(); j++) {
            for (int i = 1; i < numCluster; i++) {
                if (minDistancesToCentroid[j] > distancesToCentroid[i][j]) {
                    minDistancesToCentroid[j] = distancesToCentroid[i][j];
                    clusterAssignments[j] = i;
                }
            }
        }

        for (int i = 0; i < numCluster; i++) {
            System.out.println(clusterCentroids.instance(i));
        }

        // update centroids
        prevCentroids = clusterCentroids;
        clusterCentroids = new Instances(instances, numCluster);
        clusteredInstances = new Instances[numCluster];
        for (int i = 0; i < numCluster; i++) {
            clusteredInstances[i] = new Instances(instances, 0);
        }
        for (int i = 0; i < instances.numInstances(); i++) {
            clusteredInstances[clusterAssignments[i]].add(instances.instance(i));
            System.out.println(instances.instance(i).toString() + " : " + clusterAssignments[i]);
        }

        if (currentIteration == maxIterations) {
            converged = true;
        }

        Instances newCentroids = new Instances(instances, numCluster);
        for (int i = 0; i < numCluster; i++) {
            newCentroids.add(moveCentroid(clusteredInstances[i]));
        }
        clusterCentroids = newCentroids;

        boolean centroidChanged = false;
        for (int i = 0; i < numCluster; i++) {
            if (distanceFunction.distance(prevCentroids.instance(i), clusterCentroids.instance(i)) > 0) {
                centroidChanged = true;
            }
        }
        if (!centroidChanged) {
            converged = true;
        }
        System.out.println("\n\n");
    }

    clusterSizes = new int[numCluster];
    for (int i = 0; i < numCluster; i++) {
        clusterSizes[i] = clusteredInstances[i].numInstances();
    }
    distanceFunction.clean();
}
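Both new Instances(instances, numCluster) and new Instances(instances, 0) create empty datasets that merely copy the header of instances (the int argument is only an initial capacity), which is why every centroid and every clustered instance has to be put in explicitly via add().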