List of usage examples for the weka.core.Instance method value
public double value(Attribute att);
From source file:main.NaiveBayes.java
License:Open Source License
/** * Calculates the class membership probabilities for the given test instance. * //from w ww.j a va2s. c o m * @param instance the instance to be classified * @return predicted class probability distribution * @exception Exception if there is a problem generating the prediction */ @Override public double[] distributionForInstance(Instance instance) throws Exception { if (m_UseDiscretization) { m_Disc.input(instance); instance = m_Disc.output(); } double[] probs = new double[m_NumClasses]; for (int j = 0; j < m_NumClasses; j++) { probs[j] = m_ClassDistribution.getProbability(j); } Enumeration<Attribute> enumAtts = instance.enumerateAttributes(); int attIndex = 0; while (enumAtts.hasMoreElements()) { Attribute attribute = enumAtts.nextElement(); if (!instance.isMissing(attribute)) { double temp, max = 0; for (int j = 0; j < m_NumClasses; j++) { temp = Math.max(1e-75, Math.pow(m_Distributions[attIndex][j].getProbability(instance.value(attribute)), m_Instances.attribute(attIndex).weight())); probs[j] *= temp; if (probs[j] > max) { max = probs[j]; } if (Double.isNaN(probs[j])) { throw new Exception("NaN returned from estimator for attribute " + attribute.name() + ":\n" + m_Distributions[attIndex][j].toString()); } } if ((max > 0) && (max < 1e-75)) { // Danger of probability underflow for (int j = 0; j < m_NumClasses; j++) { probs[j] *= 1e75; } } } attIndex++; } // Display probabilities Utils.normalize(probs); return probs; }
From source file:maui.main.MauiTopicExtractor.java
License:Open Source License
/** * Builds the model from the files//from w w w . ja v a 2 s .c o m */ public void extractKeyphrases(HashSet<String> fileNames, VocabularyStore store) throws Exception { // Check whether there is actually any data if (fileNames.size() == 0) { throw new Exception("Couldn't find any data in " + inputDirectoryName); } mauiFilter.setVocabularyName(getVocabularyName()); mauiFilter.setVocabularyFormat(getVocabularyFormat()); mauiFilter.setDocumentLanguage(getDocumentLanguage()); mauiFilter.setStemmer(getStemmer()); mauiFilter.setStopwords(getStopwords()); if (wikipedia != null) { mauiFilter.setWikipedia(wikipedia); } else if (wikipediaServer.equals("localhost") && wikipediaDatabase.equals("database")) { mauiFilter.setWikipedia(wikipedia); } else { mauiFilter.setWikipedia(wikipediaServer, wikipediaDatabase, cacheWikipediaData, wikipediaDataDirectory); } if (!vocabularyName.equals("none") && !vocabularyName.equals("wikipedia")) { mauiFilter.loadThesaurus(getStemmer(), getStopwords(), store); } FastVector atts = new FastVector(3); atts.addElement(new Attribute("filename", (FastVector) null)); atts.addElement(new Attribute("doc", (FastVector) null)); atts.addElement(new Attribute("keyphrases", (FastVector) null)); Instances data = new Instances("keyphrase_training_data", atts, 0); System.err.println("-- Extracting keyphrases... 
"); Vector<Double> correctStatistics = new Vector<Double>(); Vector<Double> precisionStatistics = new Vector<Double>(); Vector<Double> recallStatistics = new Vector<Double>(); for (String fileName : fileNames) { double[] newInst = new double[3]; newInst[0] = (double) data.attribute(0).addStringValue(fileName); ; File documentTextFile = new File(inputDirectoryName + "/" + fileName + ".txt"); File documentTopicsFile = new File(inputDirectoryName + "/" + fileName + ".key"); try { InputStreamReader is; if (!documentEncoding.equals("default")) { is = new InputStreamReader(new FileInputStream(documentTextFile), documentEncoding); } else { is = new InputStreamReader(new FileInputStream(documentTextFile)); } // Reading the file content StringBuffer txtStr = new StringBuffer(); int c; while ((c = is.read()) != -1) { txtStr.append((char) c); } is.close(); // Adding the text of the document to the instance newInst[1] = (double) data.attribute(1).addStringValue(txtStr.toString()); } catch (Exception e) { System.err.println("Problem with reading " + documentTextFile); e.printStackTrace(); newInst[1] = Instance.missingValue(); } try { InputStreamReader is; if (!documentEncoding.equals("default")) { is = new InputStreamReader(new FileInputStream(documentTopicsFile), documentEncoding); } else { is = new InputStreamReader(new FileInputStream(documentTopicsFile)); } // Reading the content of the keyphrase file StringBuffer keyStr = new StringBuffer(); int c; while ((c = is.read()) != -1) { keyStr.append((char) c); } // Adding the topics to the file newInst[2] = (double) data.attribute(2).addStringValue(keyStr.toString()); } catch (Exception e) { if (debugMode) { System.err.println("No existing topics for " + documentTextFile); } newInst[2] = Instance.missingValue(); } data.add(new Instance(1.0, newInst)); mauiFilter.input(data.instance(0)); data = data.stringFreeStructure(); if (debugMode) { System.err.println("-- Processing document: " + fileName); } Instance[] topRankedInstances = 
new Instance[topicsPerDocument]; Instance inst; // Iterating over all extracted keyphrases (inst) while ((inst = mauiFilter.output()) != null) { int index = (int) inst.value(mauiFilter.getRankIndex()) - 1; if (index < topicsPerDocument) { topRankedInstances[index] = inst; } } if (debugMode) { System.err.println("-- Keyphrases and feature values:"); } FileOutputStream out = null; PrintWriter printer = null; if (!documentTopicsFile.exists()) { out = new FileOutputStream(documentTopicsFile); if (!documentEncoding.equals("default")) { printer = new PrintWriter(new OutputStreamWriter(out, documentEncoding)); } else { printer = new PrintWriter(out); } } double numExtracted = 0, numCorrect = 0; wikipedia = mauiFilter.getWikipedia(); HashMap<Article, Integer> topics = null; if (printGraph) { topics = new HashMap<Article, Integer>(); } int p = 0; String root = ""; for (int i = 0; i < topicsPerDocument; i++) { if (topRankedInstances[i] != null) { if (!topRankedInstances[i].isMissing(topRankedInstances[i].numAttributes() - 1)) { numExtracted += 1.0; } if ((int) topRankedInstances[i].value(topRankedInstances[i].numAttributes() - 1) == 1) { numCorrect += 1.0; } if (printer != null) { String topic = topRankedInstances[i].stringValue(mauiFilter.getOutputFormIndex()); printer.print(topic); if (printGraph) { Article article = wikipedia.getArticleByTitle(topic); if (article == null) { article = wikipedia.getMostLikelyArticle(topic, new CaseFolder()); } if (article != null) { if (root == "") { root = article.getTitle(); } topics.put(article, new Integer(p)); } else { if (debugMode) { System.err.println( "Couldn't find article for " + topic + " in " + documentTopicsFile); } } p++; } if (additionalInfo) { printer.print("\t"); printer.print(topRankedInstances[i].stringValue(mauiFilter.getNormalizedFormIndex())); printer.print("\t"); printer.print(Utils.doubleToString( topRankedInstances[i].value(mauiFilter.getProbabilityIndex()), 4)); } printer.println(); } if (debugMode) { 
System.err.println(topRankedInstances[i]); } } } if (printGraph) { String graphFile = documentTopicsFile.getAbsolutePath().replace(".key", ".gv"); computeGraph(topics, root, graphFile); } if (numExtracted > 0) { if (debugMode) { System.err.println("-- " + numCorrect + " correct"); } double totalCorrect = mauiFilter.getTotalCorrect(); correctStatistics.addElement(new Double(numCorrect)); precisionStatistics.addElement(new Double(numCorrect / numExtracted)); recallStatistics.addElement(new Double(numCorrect / totalCorrect)); } if (printer != null) { printer.flush(); printer.close(); out.close(); } } if (correctStatistics.size() != 0) { double[] st = new double[correctStatistics.size()]; for (int i = 0; i < correctStatistics.size(); i++) { st[i] = correctStatistics.elementAt(i).doubleValue(); } double avg = Utils.mean(st); double stdDev = Math.sqrt(Utils.variance(st)); if (correctStatistics.size() == 1) { System.err.println("\n-- Evaluation results based on 1 document:"); } else { System.err.println("\n-- Evaluation results based on " + correctStatistics.size() + " documents:"); } System.err.println("Avg. 
number of correct keyphrases per document: " + Utils.doubleToString(avg, 2) + " +/- " + Utils.doubleToString(stdDev, 2)); st = new double[precisionStatistics.size()]; for (int i = 0; i < precisionStatistics.size(); i++) { st[i] = precisionStatistics.elementAt(i).doubleValue(); } double avgPrecision = Utils.mean(st); double stdDevPrecision = Math.sqrt(Utils.variance(st)); System.err.println("Precision: " + Utils.doubleToString(avgPrecision * 100, 2) + " +/- " + Utils.doubleToString(stdDevPrecision * 100, 2)); st = new double[recallStatistics.size()]; for (int i = 0; i < recallStatistics.size(); i++) { st[i] = recallStatistics.elementAt(i).doubleValue(); } double avgRecall = Utils.mean(st); double stdDevRecall = Math.sqrt(Utils.variance(st)); System.err.println("Recall: " + Utils.doubleToString(avgRecall * 100, 2) + " +/- " + Utils.doubleToString(stdDevRecall * 100, 2)); double fMeasure = 2 * avgRecall * avgPrecision / (avgRecall + avgPrecision); System.err.println("F-Measure: " + Utils.doubleToString(fMeasure * 100, 2)); System.err.println(""); } mauiFilter.batchFinished(); }
From source file:meka.classifiers.multilabel.CDN.java
License:Open Source License
@Override public double[] distributionForInstance(Instance x) throws Exception { int L = x.classIndex(); //ArrayList<double[]> collection = new ArrayList<double[]>(100); double y[] = new double[L]; // for collectiing marginal int sequence[] = A.make_sequence(L); double likelihood[] = new double[L]; for (int i = 0; i < I; i++) { Collections.shuffle(Arrays.asList(sequence)); for (int j : sequence) { // x = [x,y[1],...,y[j-1],y[j+1],...,y[L]] x.setDataset(D_templates[j]); // q = h_j(x) i.e. p(y_j | x) double dist[] = h[j].distributionForInstance(x); int k = A.samplePMF(dist, m_R); x.setValue(j, k);// w w w . ja v a 2 s .c o m likelihood[j] = dist[k]; // likelihood double s = Utils.sum(likelihood); // collect // and where is is good if (i > (I - I_c)) { y[j] += x.value(j); } // else still burning in } } // finish, calculate marginals for (int j = 0; j < L; j++) { y[j] /= I_c; } return y; }
From source file:meka.classifiers.multilabel.incremental.RTUpdateable.java
License:Open Source License
@Override public void updateClassifier(Instance x) throws Exception { int L = x.classIndex(); for (int j = 0; j < L; j++) { if (x.value(j) > 0.0) { Instance x_j = convertInstance(x); x_j.setClassValue(j);//from w ww . j a va 2 s .com ((UpdateableClassifier) m_Classifier).updateClassifier(x_j); } } }
From source file:meka.classifiers.multitarget.CCp.java
License:Open Source License
@Override public double[] distributionForInstance(Instance x) throws Exception { int L = x.classIndex(); confidences = new double[L]; root.classify(x);//from w w w .j a va2 s . c om double y[] = new double[L * 2]; for (int j = 0; j < L; j++) { y[j] = x.value(j); y[j + L] = confidences[j]; // <--- this is the extra line } return y; }
From source file:meka.core.Metrics.java
License:Open Source License
/** Get Data for Plotting PR and ROC curves. */ public static Instances curveDataMacroAveraged(int Y[][], double P[][]) { // Note: 'Threshold' contains the probability threshold that gives rise to the previous performance values. Instances curveData[] = curveData(Y, P); int L = curveData.length; int noNullIndex = -1; for (int i = 0; i < curveData.length; i++) { if (curveData[i] == null) { L--;//from w ww . j a v a 2 s . c om } else { if (noNullIndex == -1) { // checking for the first curveData that is not null (=does not consist of // only missing values or 0s) noNullIndex = i; } } } Instances avgCurve = new Instances(curveData[noNullIndex], 0); int D = avgCurve.numAttributes(); for (double t = 0.0; t < 1.; t += 0.01) { Instance x = (Instance) curveData[noNullIndex].instance(0).copy(); //System.out.println("x1\n"+x); boolean firstloop = true; for (int j = 0; j < L; j++) { // if there are only missing values in a column, curveData[j] is null if (curveData[j] == null) { continue; } int i = ThresholdCurve.getThresholdInstance(curveData[j], t); if (firstloop) { // reset for (int a = 0; a < D; a++) { x.setValue(a, curveData[j].instance(i).value(a) * 1. / L); } firstloop = false; } else { // add for (int a = 0; a < D; a++) { double v = x.value(a); x.setValue(a, v + curveData[j].instance(i).value(a) * 1. / L); } } } //System.out.println("x2\n"+x); avgCurve.add(x); } /* System.out.println(avgCurve); System.exit(1); // Average everything for (int i = 0; i < avgCurve.numInstances(); i++) { for(int j = 0; j < L; j++) { for (int a = 0; a < D; a++) { double o = avgCurve.instance(i).value(a); avgCurve.instance(i).setValue(a, o / L); } } } */ return avgCurve; }
From source file:meka.core.MLUtils.java
License:Open Source License
/** * Instance with L labels to double[] of length L. * Rounds to the nearest whole number./*from w w w . j av a2 s . co m*/ */ public static final double[] toDoubleArray(Instance x, int L) { double a[] = new double[L]; for (int i = 0; i < L; i++) { a[i] = Math.round(x.value(i)); } return a; }
From source file:meka.core.MLUtils.java
License:Open Source License
/**
 * ToBitString - returns a String representation of the first L values of x,
 * each rounded to the nearest integer, e.g., x = [0,0,1,0,1,0,0,0] gives "00101000".
 * NOTE: It may be better to use a sparse representation for some applications.
 *
 * @param x the instance to read from
 * @param L the number of leading values (labels) to encode
 * @return the concatenated rounded values
 */
public static final String toBitString(Instance x, int L) {
    StringBuilder bits = new StringBuilder(L);
    for (int j = 0; j < L; j++) {
        long nearest = Math.round(x.value(j));
        bits.append((int) nearest);
    }
    return bits.toString();
}
From source file:meka.core.MLUtils.java
License:Open Source License
/** * To Sub Indices Set - return the indices out of 'sub_indices', in x, whose values are greater than 1. *///from ww w. ja v a 2 s . c o m public static final List toSubIndicesSet(Instance x, int sub_indices[]) { List<Integer> y_list = new ArrayList<Integer>(); for (int j : sub_indices) { if (x.value(j) > 0.) { y_list.add(j); } } return y_list; }
From source file:meka.core.MLUtils.java
License:Open Source License
/** * To Indices Set - return the indices in x, whose values are greater than 1. *///from w w w . jav a 2s . c om public static final List<Integer> toIndicesSet(Instance x, int L) { List<Integer> y_list = new ArrayList<Integer>(); for (int j = 0; j < L; j++) { if (x.value(j) > 0.) { y_list.add(j); } } return y_list; }