Usage examples for org.apache.mahout.ep.State#getPayload()
public T getPayload()
From source file:com.memonews.mahout.sentiment.SGDHelper.java
License:Apache License
static void analyzeState(final SGDInfo info, final int leakType, final int k, final State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best) throws IOException { final int bump = info.getBumps()[(int) Math.floor(info.getStep()) % info.getBumps().length]; final int scale = (int) Math.pow(10, Math.floor(info.getStep() / info.getBumps().length)); double maxBeta; double nonZeros; double positive; double norm;/*from w ww . j ava 2s. c om*/ double lambda = 0; double mu = 0; if (best != null) { final CrossFoldLearner state = best.getPayload().getLearner(); info.setAverageCorrect(state.percentCorrect()); info.setAverageLL(state.logLikelihood()); final OnlineLogisticRegression model = state.getModels().get(0); // finish off pending regularization model.close(); final Matrix beta = model.getBeta(); maxBeta = beta.aggregate(Functions.MAX, Functions.ABS); nonZeros = beta.aggregate(Functions.PLUS, new DoubleFunction() { @Override public double apply(final double v) { return Math.abs(v) > 1.0e-6 ? 1 : 0; } }); positive = beta.aggregate(Functions.PLUS, new DoubleFunction() { @Override public double apply(final double v) { return v > 0 ? 1 : 0; } }); norm = beta.aggregate(Functions.PLUS, Functions.ABS); lambda = best.getMappedParams()[0]; mu = best.getMappedParams()[1]; } else { maxBeta = 0; nonZeros = 0; positive = 0; norm = 0; } if (k % (bump * scale) == 0) { if (best != null) { ModelSerializer.writeBinary("/tmp/news-group-" + k + ".model", best.getPayload().getLearner().getModels().get(0)); } info.setStep(info.getStep() + 0.25); System.out.printf("%.2f\t%.2f\t%.2f\t%.2f\t%.8g\t%.8g\t", maxBeta, nonZeros, positive, norm, lambda, mu); System.out.printf("%d\t%.3f\t%.2f\t%s\n", k, info.getAverageLL(), info.getAverageCorrect() * 100, LEAK_LABELS[leakType % 3]); } }
From source file:com.tdunning.ch16.train.TrainNewsGroups.java
License:Apache License
public static void main(String[] args) throws IOException { File base = new File(args[0]); int leakType = 0; if (args.length > 1) { leakType = Integer.parseInt(args[1]); }/*from w ww . j a va2 s. c om*/ Dictionary newsGroups = new Dictionary(); encoder.setProbes(2); AdaptiveLogisticRegression learningAlgorithm = new AdaptiveLogisticRegression(20, FEATURES, new L1()); learningAlgorithm.setInterval(800); learningAlgorithm.setAveragingWindow(500); List<File> files = Lists.newArrayList(); File[] directories = base.listFiles(); Arrays.sort(directories, Ordering.usingToString()); for (File newsgroup : directories) { if (newsgroup.isDirectory()) { newsGroups.intern(newsgroup.getName()); files.addAll(Arrays.asList(newsgroup.listFiles())); } } Collections.shuffle(files); System.out.printf("%d training files\n", files.size()); System.out.printf("%s\n", Arrays.asList(directories)); double averageLL = 0; double averageCorrect = 0; int k = 0; double step = 0; int[] bumps = { 1, 2, 5 }; for (File file : files) { String ng = file.getParentFile().getName(); int actual = newsGroups.intern(ng); Vector v = encodeFeatureVector(file); learningAlgorithm.train(actual, v); k++; int bump = bumps[(int) Math.floor(step) % bumps.length]; int scale = (int) Math.pow(10, Math.floor(step / bumps.length)); State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best = learningAlgorithm.getBest(); double maxBeta; double nonZeros; double positive; double norm; double lambda = 0; double mu = 0; if (best != null) { CrossFoldLearner state = best.getPayload().getLearner(); averageCorrect = state.percentCorrect(); averageLL = state.logLikelihood(); OnlineLogisticRegression model = state.getModels().get(0); // finish off pending regularization model.close(); Matrix beta = model.getBeta(); maxBeta = beta.aggregate(Functions.MAX, Functions.ABS); nonZeros = beta.aggregate(Functions.PLUS, new DoubleFunction() { @Override public double apply(double v) { return Math.abs(v) > 1.0e-6 ? 
1 : 0; } }); positive = beta.aggregate(Functions.PLUS, new DoubleFunction() { @Override public double apply(double v) { return v > 0 ? 1 : 0; } }); norm = beta.aggregate(Functions.PLUS, Functions.ABS); lambda = learningAlgorithm.getBest().getMappedParams()[0]; mu = learningAlgorithm.getBest().getMappedParams()[1]; } else { maxBeta = 0; nonZeros = 0; positive = 0; norm = 0; } if (k % (bump * scale) == 0) { if (learningAlgorithm.getBest() != null) { ModelSerializer.writeBinary("/tmp/news-group-" + k + ".model", learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0)); } step += 0.25; System.out.printf("%.2f\t%.2f\t%.2f\t%.2f\t%.8g\t%.8g\t", maxBeta, nonZeros, positive, norm, lambda, mu); System.out.printf("%d\t%.3f\t%.2f\t%s\n", k, averageLL, averageCorrect * 100, LEAK_LABELS[leakType % 3]); } } learningAlgorithm.close(); dissect(newsGroups, learningAlgorithm, files); System.out.println("exiting main"); ModelSerializer.writeBinary("/tmp/news-group.model", learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0)); }
From source file:gov.llnl.ontology.mains.ExtendWordNet.java
License:Open Source License
/** * Trains the hypernym and cousin predictor models based on the evidence * gathered for positive and negative relationships. Returns true if both * models could be trained.//w w w. j ava 2 s .co m */ public boolean trainModel(int numPasses) { // Train the hypernym predictor. AdaptiveLogisticRegression model = new AdaptiveLogisticRegression(2, basis.numDimensions(), new L1()); for (int i = 0; i < numPasses; ++i) { trainHypernyms(knownPositives.map(), model, 1); trainHypernyms(knownNegatives.map(), model, 0); } // Get the best predictor for hypernyms from the trainer. If no trainer // could be found, return false. State<Wrapper, CrossFoldLearner> best = model.getBest(); if (best == null) return false; hypernymPredictor = best.getPayload().getLearner().getModels().get(0); /* // Train the cousin predictor using the similarity scores. model = new AdaptiveLogisticRegression( 2, numSimilarityScores, new L1()); for (int i = 0; i < numPasses; ++i) { trainCousins(knownPositives.map(), model); trainCousins(knownNegatives.map(), model); } // Get the best cousin predictor model from the trainer. If no trainer // could be found, return false. best = model.getBest(); if (best == null) return false; cousinPredictor = best.getPayload().getLearner().getModels().get(0); */ return true; }
From source file:gov.llnl.ontology.mains.TrainLogisticRegression.java
License:Open Source License
public static void main(String[] args) throws Exception { MRArgOptions options = new MRArgOptions(); options.addOption('w', "wordnetDir", "Specifies the wordnet directory", true, "PATH", "Required"); options.addOption('b', "basisMapping", "Specifies a serialzied basis mapping", true, "FILE", "Required"); options.addOption('n', "numPasses", "Specifies the number of training passes to make. " + "(Default: 5)", true, "INT", "Optional"); options.parseOptions(args);//from w w w.j a va 2 s .c om if (options.numPositionalArgs() != 1 || !options.hasOption('w') || !options.hasOption('b')) { System.out.println("usage: java TrainLogisticRegression [OPTIONS] <out>\n" + options.prettyPrint()); System.exit(1); } EvidenceTable table = options.evidenceTable(); Scan scan = new Scan(); table.setupScan(scan, options.sourceCorpus()); StringBasisMapping basis = SerializableUtil.load(new File(options.getStringOption('b'))); basis.setReadOnly(true); int numDimensions = basis.numDimensions(); AdaptiveLogisticRegression model = new AdaptiveLogisticRegression(2, numDimensions, new L1()); int numPasses = options.getIntOption('n'); for (int i = 0; i < numPasses; ++i) { Iterator<Result> resultIter = table.iterator(scan); while (resultIter.hasNext()) { Result row = resultIter.next(); HypernymStatus status = table.getHypernymStatus(row); if (status == HypernymStatus.TERMS_MISSING || status == HypernymStatus.NOVEL_HYPONYM || status == HypernymStatus.NOVEL_HYPERNYM) continue; // Extract a CompactSparse vector representing the number of // dependency paths. SparseDoubleVector vector = new CompactSparseVector(numDimensions); Counter<String> pathCounts = table.getDependencyPaths(row); for (Map.Entry<String, Integer> entry : pathCounts) { int dimension = basis.getDimension(entry.getKey()); if (dimension >= 0) vector.set(dimension, entry.getValue()); } int classLabel = (status == HypernymStatus.KNOWN_HYPERNYM) ? 
1 : 0; model.train(classLabel, new MahoutSparseVector(vector, numDimensions)); } } State<Wrapper, CrossFoldLearner> best = model.getBest(); if (best == null) { System.err.println("The Learner could not be learned"); System.exit(1); } OnlineLearner classifier = best.getPayload().getLearner().getModels().get(0); SerializableUtil.save(classifier, new File(options.getPositionalArg(0))); }