List of usage examples for the org.apache.mahout.classifier.sgd.UniformPrior constructor
UniformPrior()
From source file: com.cloudera.knittingboar.sgd.iterativereduce.POLRMasterNode.java
License: Apache License
public void SetupPOLR() { System.err.println("SetupOLR: " + this.num_categories + ", " + this.FeatureVectorSize); LOG.debug("SetupOLR: " + this.num_categories + ", " + this.FeatureVectorSize); this.global_parameter_vector = new GradientBuffer(this.num_categories, this.FeatureVectorSize); String[] predictor_label_names = this.PredictorLabelNames.split(","); String[] variable_types = this.PredictorVariableTypes.split(","); polr_modelparams = new POLRModelParameters(); polr_modelparams.setTargetVariable(this.TargetVariableName); // getStringArgument(cmdLine, // target)); polr_modelparams.setNumFeatures(this.FeatureVectorSize); polr_modelparams.setUseBias(true); // !getBooleanArgument(cmdLine, noBias)); List<String> typeList = Lists.newArrayList(); for (int x = 0; x < variable_types.length; x++) { typeList.add(variable_types[x]); }//w w w. j a va 2s . co m List<String> predictorList = Lists.newArrayList(); for (int x = 0; x < predictor_label_names.length; x++) { predictorList.add(predictor_label_names[x]); } polr_modelparams.setTypeMap(predictorList, typeList); polr_modelparams.setLambda(this.Lambda); // based on defaults - match // command line polr_modelparams.setLearningRate(this.LearningRate); // based on defaults - // match command line // setup record factory stuff here --------- if (RecordFactory.TWENTYNEWSGROUPS_RECORDFACTORY.equals(this.RecordFactoryClassname)) { this.VectorFactory = new TwentyNewsgroupsRecordFactory("\t"); } else if (RecordFactory.RCV1_RECORDFACTORY.equals(this.RecordFactoryClassname)) { this.VectorFactory = new RCV1RecordFactory(); } else { // need to rethink this this.VectorFactory = new CSVBasedDatasetRecordFactory(this.TargetVariableName, polr_modelparams.getTypeMap()); ((CSVBasedDatasetRecordFactory) this.VectorFactory).firstLine(this.ColumnHeaderNames); } polr_modelparams.setTargetCategories(this.VectorFactory.getTargetCategories()); // ----- this normally is generated from the POLRModelParams ------ this.polr = new 
ParallelOnlineLogisticRegression(this.num_categories, this.FeatureVectorSize, new UniformPrior()).alpha(1).stepOffset(1000).decayExponent(0.9).lambda(this.Lambda) .learningRate(this.LearningRate); polr_modelparams.setPOLR(polr); // this.bSetup = true; }
From source file: com.cloudera.knittingboar.sgd.iterativereduce.POLRWorkerNode.java
License: Apache License
private void SetupPOLR() { // do splitting strings into arrays here... String[] predictor_label_names = this.PredictorLabelNames.split(","); String[] variable_types = this.PredictorVariableTypes.split(","); polr_modelparams = new POLRModelParameters(); polr_modelparams.setTargetVariable(this.TargetVariableName); polr_modelparams.setNumFeatures(this.FeatureVectorSize); polr_modelparams.setUseBias(true);// ww w . ja va2 s .c om List<String> typeList = Lists.newArrayList(); for (int x = 0; x < variable_types.length; x++) { typeList.add(variable_types[x]); } List<String> predictorList = Lists.newArrayList(); for (int x = 0; x < predictor_label_names.length; x++) { predictorList.add(predictor_label_names[x]); } // where do these come from? polr_modelparams.setTypeMap(predictorList, typeList); polr_modelparams.setLambda(this.Lambda); // based on defaults - match // command line polr_modelparams.setLearningRate(this.LearningRate); // based on defaults - // match command line // setup record factory stuff here --------- if (RecordFactory.TWENTYNEWSGROUPS_RECORDFACTORY.equals(this.RecordFactoryClassname)) { this.VectorFactory = new TwentyNewsgroupsRecordFactory("\t"); } else if (RecordFactory.RCV1_RECORDFACTORY.equals(this.RecordFactoryClassname)) { this.VectorFactory = new RCV1RecordFactory(); } else { // it defaults to the CSV record factor, but a custom one this.VectorFactory = new CSVBasedDatasetRecordFactory(this.TargetVariableName, polr_modelparams.getTypeMap()); ((CSVBasedDatasetRecordFactory) this.VectorFactory).firstLine(this.ColumnHeaderNames); } polr_modelparams.setTargetCategories(this.VectorFactory.getTargetCategories()); // ----- this normally is generated from the POLRModelParams ------ this.polr = new ParallelOnlineLogisticRegression(this.num_categories, this.FeatureVectorSize, new UniformPrior()).alpha(1).stepOffset(1000).decayExponent(0.9).lambda(this.Lambda) .learningRate(this.LearningRate); polr_modelparams.setPOLR(polr); // this.bSetup = true; }
From source file: com.cloudera.knittingboar.sgd.POLRMasterDriver.java
License: Apache License
/** * Take the newly loaded config junk and setup the local data structures * //from w w w. j a v a 2 s .com */ public void Setup() { this.global_parameter_vector = new GradientBuffer(this.num_categories, this.FeatureVectorSize); String[] predictor_label_names = this.PredictorLabelNames.split(","); String[] variable_types = this.PredictorVariableTypes.split(","); polr_modelparams = new POLRModelParameters(); polr_modelparams.setTargetVariable(this.TargetVariableName); // getStringArgument(cmdLine, // target)); polr_modelparams.setNumFeatures(this.FeatureVectorSize); polr_modelparams.setUseBias(true); // !getBooleanArgument(cmdLine, noBias)); List<String> typeList = Lists.newArrayList(); for (int x = 0; x < variable_types.length; x++) { typeList.add(variable_types[x]); } List<String> predictorList = Lists.newArrayList(); for (int x = 0; x < predictor_label_names.length; x++) { predictorList.add(predictor_label_names[x]); } polr_modelparams.setTypeMap(predictorList, typeList); polr_modelparams.setLambda(this.Lambda); // based on defaults - match // command line polr_modelparams.setLearningRate(this.LearningRate); // based on defaults - // match command line // setup record factory stuff here --------- if (RecordFactory.TWENTYNEWSGROUPS_RECORDFACTORY.equals(this.RecordFactoryClassname)) { this.VectorFactory = new TwentyNewsgroupsRecordFactory("\t"); } else if (RecordFactory.RCV1_RECORDFACTORY.equals(this.RecordFactoryClassname)) { this.VectorFactory = new RCV1RecordFactory(); } else { // need to rethink this this.VectorFactory = new CSVBasedDatasetRecordFactory(this.TargetVariableName, polr_modelparams.getTypeMap()); ((CSVBasedDatasetRecordFactory) this.VectorFactory).firstLine(this.ColumnHeaderNames); } polr_modelparams.setTargetCategories(this.VectorFactory.getTargetCategories()); // ----- this normally is generated from the POLRModelParams ------ // this.polr = new ParallelOnlineLogisticRegression(this.num_categories, // this.FeatureVectorSize, new L1()) this.polr 
= new ParallelOnlineLogisticRegression(this.num_categories, this.FeatureVectorSize, new UniformPrior()).alpha(1).stepOffset(1000).decayExponent(0.9).lambda(this.Lambda) .learningRate(this.LearningRate); polr_modelparams.setPOLR(polr); this.bSetup = true; }
From source file: com.cloudera.knittingboar.sgd.POLRWorkerDriver.java
License: Apache License
/** * called after conf vars are loaded//from w w w . j a va 2 s. co m */ public void Setup() { // do splitting strings into arrays here... String[] predictor_label_names = this.PredictorLabelNames.split(","); String[] variable_types = this.PredictorVariableTypes.split(","); polr_modelparams = new POLRModelParameters(); polr_modelparams.setTargetVariable(this.TargetVariableName); polr_modelparams.setNumFeatures(this.FeatureVectorSize); polr_modelparams.setUseBias(true); List<String> typeList = Lists.newArrayList(); for (int x = 0; x < variable_types.length; x++) { typeList.add(variable_types[x]); } List<String> predictorList = Lists.newArrayList(); for (int x = 0; x < predictor_label_names.length; x++) { predictorList.add(predictor_label_names[x]); } // where do these come from? polr_modelparams.setTypeMap(predictorList, typeList); polr_modelparams.setLambda(this.Lambda); // based on defaults - match // command line polr_modelparams.setLearningRate(this.LearningRate); // based on defaults - // match command line // setup record factory stuff here --------- if (RecordFactory.TWENTYNEWSGROUPS_RECORDFACTORY.equals(this.RecordFactoryClassname)) { this.VectorFactory = new TwentyNewsgroupsRecordFactory("\t"); } else if (RecordFactory.RCV1_RECORDFACTORY.equals(this.RecordFactoryClassname)) { this.VectorFactory = new RCV1RecordFactory(); } else { // it defaults to the CSV record factor, but a custom one this.VectorFactory = new CSVBasedDatasetRecordFactory(this.TargetVariableName, polr_modelparams.getTypeMap()); ((CSVBasedDatasetRecordFactory) this.VectorFactory).firstLine(this.ColumnHeaderNames); } polr_modelparams.setTargetCategories(this.VectorFactory.getTargetCategories()); // ----- this normally is generated from the POLRModelParams ------ // this.polr = new ParallelOnlineLogisticRegression(this.num_categories, // this.FeatureVectorSize, new L1()) this.polr = new ParallelOnlineLogisticRegression(this.num_categories, this.FeatureVectorSize, new 
UniformPrior()).alpha(1).stepOffset(1000).decayExponent(0.9).lambda(this.Lambda) .learningRate(this.LearningRate); polr_modelparams.setPOLR(polr); this.bSetup = true; }
From source file: com.ml.ira.algos.AdaptiveLogisticModelParameters.java
License: Apache License
private static PriorFunction createPrior(String cmd, double priorOption) { if (cmd == null) { return null; }//from w w w .j av a 2s. co m if ("L1".equals(cmd.toUpperCase(Locale.ENGLISH).trim())) { return new L1(); } if ("L2".equals(cmd.toUpperCase(Locale.ENGLISH).trim())) { return new L2(); } if ("UP".equals(cmd.toUpperCase(Locale.ENGLISH).trim())) { return new UniformPrior(); } if ("TP".equals(cmd.toUpperCase(Locale.ENGLISH).trim())) { return new TPrior(priorOption); } if ("EBP".equals(cmd.toUpperCase(Locale.ENGLISH).trim())) { return new ElasticBandPrior(priorOption); } return null; }
From source file: org.deidentifier.arx.aggregates.classification.MultiClassLogisticRegression.java
License: Apache License
/** * Creates a new instance//from w ww .j a v a 2 s .co m * @param specification * @param config */ public MultiClassLogisticRegression(ClassificationDataSpecification specification, ARXLogisticRegressionConfiguration config) { // Store this.config = config; this.specification = specification; // Prepare classifier PriorFunction prior = null; switch (config.getPriorFunction()) { case ELASTIC_BAND: prior = new ElasticBandPrior(); break; case L1: prior = new L1(); break; case L2: prior = new L2(); break; case UNIFORM: prior = new UniformPrior(); break; default: throw new IllegalArgumentException("Unknown prior function"); } this.lr = new OnlineLogisticRegression(this.specification.classMap.size(), config.getVectorLength(), prior); // Configure this.lr.learningRate(config.getLearningRate()); this.lr.alpha(config.getAlpha()); this.lr.lambda(config.getLambda()); this.lr.stepOffset(config.getStepOffset()); this.lr.decayExponent(config.getDecayExponent()); // Prepare encoders this.interceptEncoder = new ConstantValueEncoder("intercept"); this.wordEncoder = new StaticWordValueEncoder("feature"); // Configure this.lr.learningRate(1); this.lr.alpha(1); this.lr.lambda(0.000001); this.lr.stepOffset(10000); this.lr.decayExponent(0.2); }
From source file: tv.floe.metronome.classification.logisticregression.iterativereduce.POLRMasterNode.java
License: Apache License
public void SetupPOLR() { System.err.println("SetupOLR: " + this.num_categories + ", " + this.FeatureVectorSize); LOG.debug("SetupOLR: " + this.num_categories + ", " + this.FeatureVectorSize); this.global_parameter_vector = new ParameterVector(); //this.num_categories, //this.FeatureVectorSize); String[] predictor_label_names = this.PredictorLabelNames.split(","); String[] variable_types = this.PredictorVariableTypes.split(","); polr_modelparams = new POLRModelParameters(); polr_modelparams.setTargetVariable(this.TargetVariableName); // getStringArgument(cmdLine, // target)); polr_modelparams.setNumFeatures(this.FeatureVectorSize); polr_modelparams.setUseBias(true); // !getBooleanArgument(cmdLine, noBias)); List<String> typeList = Lists.newArrayList(); for (int x = 0; x < variable_types.length; x++) { typeList.add(variable_types[x]); }/*from w w w . j a v a 2 s .c o m*/ List<String> predictorList = Lists.newArrayList(); for (int x = 0; x < predictor_label_names.length; x++) { predictorList.add(predictor_label_names[x]); } polr_modelparams.setTypeMap(predictorList, typeList); polr_modelparams.setLambda(this.Lambda); // based on defaults - match // command line polr_modelparams.setLearningRate(this.LearningRate); // based on defaults - // match command line // setup record factory stuff here --------- /* if (RecordFactory.TWENTYNEWSGROUPS_RECORDFACTORY .equals(this.RecordFactoryClassname)) { this.VectorFactory = new TwentyNewsgroupsRecordFactory("\t"); } else */ if (RecordFactory.RCV1_RECORDFACTORY.equals(this.RecordFactoryClassname)) { this.VectorFactory = new RCV1RecordFactory(); } else { // need to rethink this /* this.VectorFactory = new CSVBasedDatasetRecordFactory( this.TargetVariableName, polr_modelparams.getTypeMap()); ((CSVBasedDatasetRecordFactory) this.VectorFactory) .firstLine(this.ColumnHeaderNames); */ } polr_modelparams.setTargetCategories(this.VectorFactory.getTargetCategories()); // ----- this normally is generated from the POLRModelParams ------ 
this.polr = new ParallelOnlineLogisticRegression(this.num_categories, this.FeatureVectorSize, new UniformPrior()).alpha(1).stepOffset(1000).decayExponent(0.9).lambda(this.Lambda) .learningRate(this.LearningRate); polr_modelparams.setPOLR(polr); // this.bSetup = true; }
From source file: tv.floe.metronome.classification.logisticregression.iterativereduce.POLRWorkerNode.java
License: Apache License
/** * TODO:/*from w ww . j a v a 2 s. c om*/ * - throw a mis-configuration exception of some sort * */ private void SetupPOLR() { // do splitting strings into arrays here... String[] predictor_label_names = this.PredictorLabelNames.split(","); String[] variable_types = this.PredictorVariableTypes.split(","); polr_modelparams = new POLRModelParameters(); polr_modelparams.setTargetVariable(this.TargetVariableName); polr_modelparams.setNumFeatures(this.FeatureVectorSize); polr_modelparams.setUseBias(true); List<String> typeList = Lists.newArrayList(); for (int x = 0; x < variable_types.length; x++) { typeList.add(variable_types[x]); } List<String> predictorList = Lists.newArrayList(); for (int x = 0; x < predictor_label_names.length; x++) { predictorList.add(predictor_label_names[x]); } // where do these come from? polr_modelparams.setTypeMap(predictorList, typeList); polr_modelparams.setLambda(this.Lambda); // based on defaults - match // command line polr_modelparams.setLearningRate(this.LearningRate); // based on defaults - // match command line // setup record factory stuff here --------- // ####### disabled this input format, was not a long term solution ########## /* if (RecordFactory.TWENTYNEWSGROUPS_RECORDFACTORY .equals(this.RecordFactoryClassname)) { this.VectorFactory = new TwentyNewsgroupsRecordFactory("\t"); } else */ if (RecordFactory.RCV1_RECORDFACTORY.equals(this.RecordFactoryClassname)) { this.VectorFactory = new RCV1RecordFactory(); } else { // it defaults to the CSV record factor, but a custom one /* this.VectorFactory = new CSVBasedDatasetRecordFactory( this.TargetVariableName, polr_modelparams.getTypeMap()); ((CSVBasedDatasetRecordFactory) this.VectorFactory) .firstLine(this.ColumnHeaderNames); */ // throw new Exception("Invalid Record Factory Class"); } polr_modelparams.setTargetCategories(this.VectorFactory.getTargetCategories()); // ----- this normally is generated from the POLRModelParams ------ this.polr = new 
ParallelOnlineLogisticRegression(this.num_categories, this.FeatureVectorSize, new UniformPrior()).alpha(1).stepOffset(1000).decayExponent(0.9).lambda(this.Lambda) .learningRate(this.LearningRate); polr_modelparams.setPOLR(polr); // this.bSetup = true; }