List of usage examples for the org.apache.mahout.classifier.sgd AdaptiveLogisticRegression constructor
public AdaptiveLogisticRegression(int numCategories, int numFeatures, PriorFunction prior)
public AdaptiveLogisticRegression(int numCategories, int numFeatures, PriorFunction prior)
From source file:com.memonews.mahout.sentiment.SentimentModelTrainer.java
License:Apache License
/**
 * Trains a sentiment classifier over a directory of labelled text files.
 *
 * <p>Usage: {@code main <trainingDir> [modelPath]}. Each subdirectory of the
 * training directory is interned as one category label and every file inside
 * it is used as a training document. After training, the best model found is
 * serialized to {@code modelPath} (default {@code target/model}) and the top
 * word counts are printed.
 *
 * @param args args[0] = training directory; args[1] = optional model output path
 * @throws IOException if a training file cannot be read or the model cannot be written
 */
public static void main(final String[] args) throws IOException {
    final File trainingDir = new File(args[0]);
    final String modelPath = args.length > 1 ? args[1] : "target/model";

    final Multiset<String> wordCounts = HashMultiset.create();
    final Dictionary categories = new Dictionary();
    final SentimentModelHelper helper = new SentimentModelHelper();
    helper.getEncoder().setProbes(2);

    // Binary classifier with an L1 prior; feature count comes from the helper.
    final AdaptiveLogisticRegression learner =
            new AdaptiveLogisticRegression(2, SentimentModelHelper.FEATURES, new L1());
    learner.setInterval(800);
    learner.setAveragingWindow(500);

    // Every subdirectory is a category; its name becomes the label.
    final List<File> trainingFiles = Lists.newArrayList();
    for (final File category : trainingDir.listFiles()) {
        if (category.isDirectory()) {
            categories.intern(category.getName());
            trainingFiles.addAll(Arrays.asList(category.listFiles()));
        }
    }
    Collections.shuffle(trainingFiles);
    System.out.printf("%d training files\n", trainingFiles.size());

    final SGDInfo info = new SGDInfo();
    int seen = 0;
    for (final File file : trainingFiles) {
        // The parent directory name identifies the true category.
        final int actual = categories.intern(file.getParentFile().getName());
        final Vector features = helper.encodeFeatureVector(file, wordCounts);
        learner.train(actual, features);
        seen++;
        final State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best = learner.getBest();
        SGDHelper.analyzeState(info, 0, seen, best);
    }
    learner.close();
    SGDHelper.dissect(0, categories, learner, trainingFiles, wordCounts);
    System.out.println("exiting main");

    ModelSerializer.writeBinary(modelPath,
            learner.getBest().getPayload().getLearner().getModels().get(0));

    // Print the word-frequency distribution, largest counts first (top 1001).
    final List<Integer> counts = Lists.newArrayList();
    System.out.printf("Word counts\n");
    for (final String word : wordCounts.elementSet()) {
        counts.add(wordCounts.count(word));
    }
    Collections.sort(counts, Ordering.natural().reverse());
    int rank = 0;
    for (final Integer count : counts) {
        System.out.printf("%d\t%d\n", rank, count);
        rank++;
        if (rank > 1000) {
            break;
        }
    }
}
From source file:com.ml.ira.algos.AdaptiveLogisticModelParameters.java
License:Apache License
/**
 * Returns the lazily constructed {@link AdaptiveLogisticRegression}, building
 * and configuring it from this object's parameter fields on first call and
 * caching it in {@code alr} for subsequent calls.
 *
 * @return the shared, fully configured classifier instance
 */
public AdaptiveLogisticRegression createAdaptiveLogisticRegression() {
    // Already built: hand back the cached instance.
    if (alr != null) {
        return alr;
    }
    alr = new AdaptiveLogisticRegression(getMaxTargetCategories(), getNumFeatures(),
            createPrior(prior, priorOption));
    alr.setInterval(interval);
    alr.setAveragingWindow(averageWindow);
    alr.setThreadCount(threads);
    alr.setAucEvaluator(createAUC(auc));
    return alr;
}
From source file:com.tdunning.ch16.train.TrainNewsGroups.java
License:Apache License
public static void main(String[] args) throws IOException { File base = new File(args[0]); int leakType = 0; if (args.length > 1) { leakType = Integer.parseInt(args[1]); }// ww w .j ava 2s.co m Dictionary newsGroups = new Dictionary(); encoder.setProbes(2); AdaptiveLogisticRegression learningAlgorithm = new AdaptiveLogisticRegression(20, FEATURES, new L1()); learningAlgorithm.setInterval(800); learningAlgorithm.setAveragingWindow(500); List<File> files = Lists.newArrayList(); File[] directories = base.listFiles(); Arrays.sort(directories, Ordering.usingToString()); for (File newsgroup : directories) { if (newsgroup.isDirectory()) { newsGroups.intern(newsgroup.getName()); files.addAll(Arrays.asList(newsgroup.listFiles())); } } Collections.shuffle(files); System.out.printf("%d training files\n", files.size()); System.out.printf("%s\n", Arrays.asList(directories)); double averageLL = 0; double averageCorrect = 0; int k = 0; double step = 0; int[] bumps = { 1, 2, 5 }; for (File file : files) { String ng = file.getParentFile().getName(); int actual = newsGroups.intern(ng); Vector v = encodeFeatureVector(file); learningAlgorithm.train(actual, v); k++; int bump = bumps[(int) Math.floor(step) % bumps.length]; int scale = (int) Math.pow(10, Math.floor(step / bumps.length)); State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best = learningAlgorithm.getBest(); double maxBeta; double nonZeros; double positive; double norm; double lambda = 0; double mu = 0; if (best != null) { CrossFoldLearner state = best.getPayload().getLearner(); averageCorrect = state.percentCorrect(); averageLL = state.logLikelihood(); OnlineLogisticRegression model = state.getModels().get(0); // finish off pending regularization model.close(); Matrix beta = model.getBeta(); maxBeta = beta.aggregate(Functions.MAX, Functions.ABS); nonZeros = beta.aggregate(Functions.PLUS, new DoubleFunction() { @Override public double apply(double v) { return Math.abs(v) > 1.0e-6 ? 
1 : 0; } }); positive = beta.aggregate(Functions.PLUS, new DoubleFunction() { @Override public double apply(double v) { return v > 0 ? 1 : 0; } }); norm = beta.aggregate(Functions.PLUS, Functions.ABS); lambda = learningAlgorithm.getBest().getMappedParams()[0]; mu = learningAlgorithm.getBest().getMappedParams()[1]; } else { maxBeta = 0; nonZeros = 0; positive = 0; norm = 0; } if (k % (bump * scale) == 0) { if (learningAlgorithm.getBest() != null) { ModelSerializer.writeBinary("/tmp/news-group-" + k + ".model", learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0)); } step += 0.25; System.out.printf("%.2f\t%.2f\t%.2f\t%.2f\t%.8g\t%.8g\t", maxBeta, nonZeros, positive, norm, lambda, mu); System.out.printf("%d\t%.3f\t%.2f\t%s\n", k, averageLL, averageCorrect * 100, LEAK_LABELS[leakType % 3]); } } learningAlgorithm.close(); dissect(newsGroups, learningAlgorithm, files); System.out.println("exiting main"); ModelSerializer.writeBinary("/tmp/news-group.model", learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0)); }
From source file:gov.llnl.ontology.mains.ExtendWordNet.java
License:Open Source License
/** * Trains the hypernym and cousin predictor models based on the evidence * gathered for positive and negative relationships. Returns true if both * models could be trained.//from w w w . j av a 2 s . c o m */ public boolean trainModel(int numPasses) { // Train the hypernym predictor. AdaptiveLogisticRegression model = new AdaptiveLogisticRegression(2, basis.numDimensions(), new L1()); for (int i = 0; i < numPasses; ++i) { trainHypernyms(knownPositives.map(), model, 1); trainHypernyms(knownNegatives.map(), model, 0); } // Get the best predictor for hypernyms from the trainer. If no trainer // could be found, return false. State<Wrapper, CrossFoldLearner> best = model.getBest(); if (best == null) return false; hypernymPredictor = best.getPayload().getLearner().getModels().get(0); /* // Train the cousin predictor using the similarity scores. model = new AdaptiveLogisticRegression( 2, numSimilarityScores, new L1()); for (int i = 0; i < numPasses; ++i) { trainCousins(knownPositives.map(), model); trainCousins(knownNegatives.map(), model); } // Get the best cousin predictor model from the trainer. If no trainer // could be found, return false. best = model.getBest(); if (best == null) return false; cousinPredictor = best.getPayload().getLearner().getModels().get(0); */ return true; }
From source file:gov.llnl.ontology.mains.TrainLogisticRegression.java
License:Open Source License
public static void main(String[] args) throws Exception { MRArgOptions options = new MRArgOptions(); options.addOption('w', "wordnetDir", "Specifies the wordnet directory", true, "PATH", "Required"); options.addOption('b', "basisMapping", "Specifies a serialzied basis mapping", true, "FILE", "Required"); options.addOption('n', "numPasses", "Specifies the number of training passes to make. " + "(Default: 5)", true, "INT", "Optional"); options.parseOptions(args);/*from w ww. j ava 2 s . com*/ if (options.numPositionalArgs() != 1 || !options.hasOption('w') || !options.hasOption('b')) { System.out.println("usage: java TrainLogisticRegression [OPTIONS] <out>\n" + options.prettyPrint()); System.exit(1); } EvidenceTable table = options.evidenceTable(); Scan scan = new Scan(); table.setupScan(scan, options.sourceCorpus()); StringBasisMapping basis = SerializableUtil.load(new File(options.getStringOption('b'))); basis.setReadOnly(true); int numDimensions = basis.numDimensions(); AdaptiveLogisticRegression model = new AdaptiveLogisticRegression(2, numDimensions, new L1()); int numPasses = options.getIntOption('n'); for (int i = 0; i < numPasses; ++i) { Iterator<Result> resultIter = table.iterator(scan); while (resultIter.hasNext()) { Result row = resultIter.next(); HypernymStatus status = table.getHypernymStatus(row); if (status == HypernymStatus.TERMS_MISSING || status == HypernymStatus.NOVEL_HYPONYM || status == HypernymStatus.NOVEL_HYPERNYM) continue; // Extract a CompactSparse vector representing the number of // dependency paths. SparseDoubleVector vector = new CompactSparseVector(numDimensions); Counter<String> pathCounts = table.getDependencyPaths(row); for (Map.Entry<String, Integer> entry : pathCounts) { int dimension = basis.getDimension(entry.getKey()); if (dimension >= 0) vector.set(dimension, entry.getValue()); } int classLabel = (status == HypernymStatus.KNOWN_HYPERNYM) ? 
1 : 0; model.train(classLabel, new MahoutSparseVector(vector, numDimensions)); } } State<Wrapper, CrossFoldLearner> best = model.getBest(); if (best == null) { System.err.println("The Learner could not be learned"); System.exit(1); } OnlineLearner classifier = best.getPayload().getLearner().getModels().get(0); SerializableUtil.save(classifier, new File(options.getPositionalArg(0))); }
From source file:opennlp.addons.mahout.AdaptiveLogisticRegressionTrainer.java
License:Apache License
@Override public MaxentModel doTrain(DataIndexer indexer) throws IOException { // TODO: Lets use the predMap here as well for encoding int numberOfOutcomes = indexer.getOutcomeLabels().length; int numberOfFeatures = indexer.getPredLabels().length; AdaptiveLogisticRegression pa = new AdaptiveLogisticRegression(numberOfOutcomes, numberOfFeatures, new L1()); // TODO: Make these parameters configurable ... // what are good values ?! pa.setInterval(800);//from w w w . java2 s . c o m pa.setAveragingWindow(500); for (int k = 0; k < iterations; k++) { trainOnlineLearner(indexer, pa); // What should be reported at the end of every iteration ?! System.out.println("Iteration " + (k + 1)); } pa.close(); return new VectorClassifierModel(pa.getBest().getPayload().getLearner(), indexer.getOutcomeLabels(), createPrepMap(indexer)); }
From source file:opennlp.addons.mahout.LogisticRegressionTrainer.java
License:Apache License
@Override public MaxentModel doTrain(DataIndexer indexer) throws IOException { // TODO: Lets use the predMap here as well for encoding int outcomes[] = indexer.getOutcomeList(); int cardinality = indexer.getPredLabels().length; AdaptiveLogisticRegression pa = new AdaptiveLogisticRegression(indexer.getOutcomeLabels().length, cardinality, new L1()); pa.setInterval(800);/*from w w w .j a va2s. c o m*/ pa.setAveragingWindow(500); // PassiveAggressive pa = new PassiveAggressive(indexer.getOutcomeLabels().length, cardinality); // pa.learningRate(10000); // OnlineLogisticRegression pa = new OnlineLogisticRegression(indexer.getOutcomeLabels().length, cardinality, // new L1()); // // pa.alpha(1).stepOffset(250) // .decayExponent(0.9) // .lambda(3.0e-5) // .learningRate(3000); // TODO: Should we do both ?! AdaptiveLogisticRegression ?! for (int k = 0; k < iterations; k++) { trainOnlineLearner(indexer, pa); // What should be reported at the end of every iteration ?! System.out.println("Iteration " + (k + 1)); } pa.close(); Map<String, Integer> predMap = new HashMap<String, Integer>(); String predLabels[] = indexer.getPredLabels(); for (int i = 0; i < predLabels.length; i++) { predMap.put(predLabels[i], i); } return new VectorClassifierModel(pa.getBest().getPayload().getLearner(), indexer.getOutcomeLabels(), predMap); // return new VectorClassifierModel(pa, indexer.getOutcomeLabels(), predMap); }