Example usage for the org.apache.mahout.classifier.sgd.UniformPrior constructor

List of usage examples for the org.apache.mahout.classifier.sgd.UniformPrior constructor

Introduction

On this page you can find example usages of the org.apache.mahout.classifier.sgd.UniformPrior constructor.

Prototype

UniformPrior

Source Link

Usage

From source file:com.cloudera.knittingboar.sgd.iterativereduce.POLRMasterNode.java

License:Apache License

/**
 * Prepares the master node for training: creates the global gradient buffer,
 * populates the POLR model parameters from the configured fields, selects a
 * record factory and constructs the ParallelOnlineLogisticRegression model.
 */
public void SetupPOLR() {

    final String setupMessage = "SetupOLR: " + this.num_categories + ", " + this.FeatureVectorSize;
    System.err.println(setupMessage);
    LOG.debug(setupMessage);

    this.global_parameter_vector = new GradientBuffer(this.num_categories, this.FeatureVectorSize);

    // Predictor names and their types are configured as comma-separated strings.
    final String[] labelNames = this.PredictorLabelNames.split(",");
    final String[] typeNames = this.PredictorVariableTypes.split(",");

    polr_modelparams = new POLRModelParameters();
    polr_modelparams.setTargetVariable(this.TargetVariableName);
    polr_modelparams.setNumFeatures(this.FeatureVectorSize);
    polr_modelparams.setUseBias(true);

    List<String> typeList = Lists.newArrayList();
    for (String typeName : typeNames) {
        typeList.add(typeName);
    }

    List<String> predictorList = Lists.newArrayList();
    for (String labelName : labelNames) {
        predictorList.add(labelName);
    }

    polr_modelparams.setTypeMap(predictorList, typeList);
    // Lambda and learning rate are taken from the configured fields.
    polr_modelparams.setLambda(this.Lambda);
    polr_modelparams.setLearningRate(this.LearningRate);

    // Select the record factory named by RecordFactoryClassname; anything
    // unrecognised falls back to the CSV-based factory.
    final String factoryName = this.RecordFactoryClassname;
    if (RecordFactory.TWENTYNEWSGROUPS_RECORDFACTORY.equals(factoryName)) {
        this.VectorFactory = new TwentyNewsgroupsRecordFactory("\t");
    } else if (RecordFactory.RCV1_RECORDFACTORY.equals(factoryName)) {
        this.VectorFactory = new RCV1RecordFactory();
    } else {
        CSVBasedDatasetRecordFactory csvFactory = new CSVBasedDatasetRecordFactory(
                this.TargetVariableName, polr_modelparams.getTypeMap());
        // Publish the factory before feeding it the header line, as before.
        this.VectorFactory = csvFactory;
        csvFactory.firstLine(this.ColumnHeaderNames);
    }

    polr_modelparams.setTargetCategories(this.VectorFactory.getTargetCategories());

    // Build the model directly here rather than through POLRModelParameters.
    this.polr = new ParallelOnlineLogisticRegression(
            this.num_categories, this.FeatureVectorSize, new UniformPrior())
                    .alpha(1)
                    .stepOffset(1000)
                    .decayExponent(0.9)
                    .lambda(this.Lambda)
                    .learningRate(this.LearningRate);

    polr_modelparams.setPOLR(polr);
}

From source file:com.cloudera.knittingboar.sgd.iterativereduce.POLRWorkerNode.java

License:Apache License

/**
 * Prepares the worker node for training: populates the POLR model parameters
 * from the configured fields, selects a record factory and constructs the
 * ParallelOnlineLogisticRegression model.
 */
private void SetupPOLR() {

    // The predictor names and types are stored as comma-separated strings.
    final String[] labelNames = this.PredictorLabelNames.split(",");
    final String[] typeNames = this.PredictorVariableTypes.split(",");

    polr_modelparams = new POLRModelParameters();
    polr_modelparams.setTargetVariable(this.TargetVariableName);
    polr_modelparams.setNumFeatures(this.FeatureVectorSize);
    polr_modelparams.setUseBias(true);

    List<String> typeList = Lists.newArrayList();
    for (String typeName : typeNames) {
        typeList.add(typeName);
    }

    List<String> predictorList = Lists.newArrayList();
    for (String labelName : labelNames) {
        predictorList.add(labelName);
    }

    polr_modelparams.setTypeMap(predictorList, typeList);
    // Lambda and learning rate are taken from the configured fields.
    polr_modelparams.setLambda(this.Lambda);
    polr_modelparams.setLearningRate(this.LearningRate);

    // Pick the record factory by its configured name; the CSV-based factory
    // is the default for any other value.
    final String factoryName = this.RecordFactoryClassname;
    if (RecordFactory.TWENTYNEWSGROUPS_RECORDFACTORY.equals(factoryName)) {
        this.VectorFactory = new TwentyNewsgroupsRecordFactory("\t");
    } else if (RecordFactory.RCV1_RECORDFACTORY.equals(factoryName)) {
        this.VectorFactory = new RCV1RecordFactory();
    } else {
        CSVBasedDatasetRecordFactory csvFactory = new CSVBasedDatasetRecordFactory(
                this.TargetVariableName, polr_modelparams.getTypeMap());
        // Assign the field first, then hand over the header line, as before.
        this.VectorFactory = csvFactory;
        csvFactory.firstLine(this.ColumnHeaderNames);
    }

    polr_modelparams.setTargetCategories(this.VectorFactory.getTargetCategories());

    // Build the model directly here rather than through POLRModelParameters.
    this.polr = new ParallelOnlineLogisticRegression(
            this.num_categories, this.FeatureVectorSize, new UniformPrior())
                    .alpha(1)
                    .stepOffset(1000)
                    .decayExponent(0.9)
                    .lambda(this.Lambda)
                    .learningRate(this.LearningRate);

    polr_modelparams.setPOLR(polr);
}

From source file:com.cloudera.knittingboar.sgd.POLRMasterDriver.java

License:Apache License

/**
 * Builds the local data structures from the already-loaded configuration:
 * the global gradient buffer, the POLR model parameters, the record factory
 * and the ParallelOnlineLogisticRegression model. Marks the driver as set up
 * once everything has been created.
 */
public void Setup() {

    this.global_parameter_vector = new GradientBuffer(this.num_categories, this.FeatureVectorSize);

    // Predictor names and their types are configured as comma-separated strings.
    final String[] labelNames = this.PredictorLabelNames.split(",");
    final String[] typeNames = this.PredictorVariableTypes.split(",");

    polr_modelparams = new POLRModelParameters();
    polr_modelparams.setTargetVariable(this.TargetVariableName);
    polr_modelparams.setNumFeatures(this.FeatureVectorSize);
    polr_modelparams.setUseBias(true);

    List<String> typeList = Lists.newArrayList();
    for (String typeName : typeNames) {
        typeList.add(typeName);
    }

    List<String> predictorList = Lists.newArrayList();
    for (String labelName : labelNames) {
        predictorList.add(labelName);
    }

    polr_modelparams.setTypeMap(predictorList, typeList);
    // Lambda and learning rate are taken from the configured fields.
    polr_modelparams.setLambda(this.Lambda);
    polr_modelparams.setLearningRate(this.LearningRate);

    // Select the record factory named by RecordFactoryClassname; anything
    // unrecognised falls back to the CSV-based factory.
    final String factoryName = this.RecordFactoryClassname;
    if (RecordFactory.TWENTYNEWSGROUPS_RECORDFACTORY.equals(factoryName)) {
        this.VectorFactory = new TwentyNewsgroupsRecordFactory("\t");
    } else if (RecordFactory.RCV1_RECORDFACTORY.equals(factoryName)) {
        this.VectorFactory = new RCV1RecordFactory();
    } else {
        CSVBasedDatasetRecordFactory csvFactory = new CSVBasedDatasetRecordFactory(
                this.TargetVariableName, polr_modelparams.getTypeMap());
        // Publish the factory before feeding it the header line, as before.
        this.VectorFactory = csvFactory;
        csvFactory.firstLine(this.ColumnHeaderNames);
    }

    polr_modelparams.setTargetCategories(this.VectorFactory.getTargetCategories());

    // Build the model directly here rather than through POLRModelParameters.
    // Alternative prior that was previously in use here: new L1() in place of
    // new UniformPrior().
    this.polr = new ParallelOnlineLogisticRegression(
            this.num_categories, this.FeatureVectorSize, new UniformPrior())
                    .alpha(1)
                    .stepOffset(1000)
                    .decayExponent(0.9)
                    .lambda(this.Lambda)
                    .learningRate(this.LearningRate);

    polr_modelparams.setPOLR(polr);
    this.bSetup = true;
}

From source file:com.cloudera.knittingboar.sgd.POLRWorkerDriver.java

License:Apache License

/**
 * Called after the configuration variables are loaded. Populates the POLR
 * model parameters, selects a record factory, constructs the
 * ParallelOnlineLogisticRegression model and marks the driver as set up.
 */
public void Setup() {

    // The predictor names and types are stored as comma-separated strings.
    final String[] labelNames = this.PredictorLabelNames.split(",");
    final String[] typeNames = this.PredictorVariableTypes.split(",");

    polr_modelparams = new POLRModelParameters();
    polr_modelparams.setTargetVariable(this.TargetVariableName);
    polr_modelparams.setNumFeatures(this.FeatureVectorSize);
    polr_modelparams.setUseBias(true);

    List<String> typeList = Lists.newArrayList();
    for (String typeName : typeNames) {
        typeList.add(typeName);
    }

    List<String> predictorList = Lists.newArrayList();
    for (String labelName : labelNames) {
        predictorList.add(labelName);
    }

    polr_modelparams.setTypeMap(predictorList, typeList);
    // Lambda and learning rate are taken from the configured fields.
    polr_modelparams.setLambda(this.Lambda);
    polr_modelparams.setLearningRate(this.LearningRate);

    // Pick the record factory by its configured name; the CSV-based factory
    // is the default for any other value.
    final String factoryName = this.RecordFactoryClassname;
    if (RecordFactory.TWENTYNEWSGROUPS_RECORDFACTORY.equals(factoryName)) {
        this.VectorFactory = new TwentyNewsgroupsRecordFactory("\t");
    } else if (RecordFactory.RCV1_RECORDFACTORY.equals(factoryName)) {
        this.VectorFactory = new RCV1RecordFactory();
    } else {
        CSVBasedDatasetRecordFactory csvFactory = new CSVBasedDatasetRecordFactory(
                this.TargetVariableName, polr_modelparams.getTypeMap());
        // Assign the field first, then hand over the header line, as before.
        this.VectorFactory = csvFactory;
        csvFactory.firstLine(this.ColumnHeaderNames);
    }

    polr_modelparams.setTargetCategories(this.VectorFactory.getTargetCategories());

    // Build the model directly here rather than through POLRModelParameters.
    // Alternative prior that was previously in use here: new L1() in place of
    // new UniformPrior().
    this.polr = new ParallelOnlineLogisticRegression(
            this.num_categories, this.FeatureVectorSize, new UniformPrior())
                    .alpha(1)
                    .stepOffset(1000)
                    .decayExponent(0.9)
                    .lambda(this.Lambda)
                    .learningRate(this.LearningRate);

    polr_modelparams.setPOLR(polr);

    this.bSetup = true;
}

From source file:com.ml.ira.algos.AdaptiveLogisticModelParameters.java

License:Apache License

/**
 * Translates a prior-function code into a new prior instance.
 *
 * <p>The code is upper-cased (English locale) and trimmed before matching.
 * Recognised codes are {@code L1}, {@code L2}, {@code UP} (uniform prior),
 * {@code TP} (t-prior) and {@code EBP} (elastic band prior).
 *
 * @param cmd         the prior-function code; may be {@code null}
 * @param priorOption constructor argument passed to the {@code TP} and
 *                    {@code EBP} priors; ignored for the other codes
 * @return the matching prior, or {@code null} when {@code cmd} is
 *         {@code null} or not a recognised code
 */
private static PriorFunction createPrior(String cmd, double priorOption) {
    if (cmd == null) {
        return null;
    }

    // Normalise once; every comparison below uses the same canonical form.
    final String code = cmd.toUpperCase(Locale.ENGLISH).trim();

    if (code.equals("L1")) {
        return new L1();
    } else if (code.equals("L2")) {
        return new L2();
    } else if (code.equals("UP")) {
        return new UniformPrior();
    } else if (code.equals("TP")) {
        return new TPrior(priorOption);
    } else if (code.equals("EBP")) {
        return new ElasticBandPrior(priorOption);
    } else {
        return null;
    }
}

From source file:org.deidentifier.arx.aggregates.classification.MultiClassLogisticRegression.java

License:Apache License

/**
 * Creates a new instance.
 *
 * <p>The training parameters (learning rate, alpha, lambda, step offset and
 * decay exponent) are taken from {@code config}. Earlier, a second block of
 * hard-coded values was applied afterwards and silently replaced every one
 * of these settings; that override has been removed so the configuration is
 * actually honoured.
 *
 * @param specification data specification; the size of its class map is
 *                      passed to the underlying classifier
 * @param config        configuration providing the prior function, the
 *                      vector length and the training parameters
 * @throws IllegalArgumentException if the configured prior function is not
 *                                  one of the handled constants
 */
public MultiClassLogisticRegression(ClassificationDataSpecification specification,
        ARXLogisticRegressionConfiguration config) {

    // Store
    this.config = config;
    this.specification = specification;

    // Prepare classifier: map the configured prior to its implementation
    PriorFunction prior = null;
    switch (config.getPriorFunction()) {
    case ELASTIC_BAND:
        prior = new ElasticBandPrior();
        break;
    case L1:
        prior = new L1();
        break;
    case L2:
        prior = new L2();
        break;
    case UNIFORM:
        prior = new UniformPrior();
        break;
    default:
        throw new IllegalArgumentException("Unknown prior function");
    }
    this.lr = new OnlineLogisticRegression(this.specification.classMap.size(), config.getVectorLength(), prior);

    // Configure the classifier from the user-supplied configuration. These
    // values must not be overwritten afterwards.
    this.lr.learningRate(config.getLearningRate());
    this.lr.alpha(config.getAlpha());
    this.lr.lambda(config.getLambda());
    this.lr.stepOffset(config.getStepOffset());
    this.lr.decayExponent(config.getDecayExponent());

    // Prepare encoders
    this.interceptEncoder = new ConstantValueEncoder("intercept");
    this.wordEncoder = new StaticWordValueEncoder("feature");
}

From source file:tv.floe.metronome.classification.logisticregression.iterativereduce.POLRMasterNode.java

License:Apache License

/**
 * Prepares the master node for training: creates the global parameter
 * vector, populates the POLR model parameters from the configured fields,
 * selects a record factory and constructs the
 * ParallelOnlineLogisticRegression model.
 *
 * <p>Only the RCV1 record factory is created here; the TwentyNewsgroups and
 * CSV-based factories are disabled. For any other configured name a factory
 * must already be present in {@code VectorFactory}.
 *
 * @throws IllegalStateException if the configured record factory is not RCV1
 *                               and no record factory is already set
 */
public void SetupPOLR() {

    final String setupMessage = "SetupOLR: " + this.num_categories + ", " + this.FeatureVectorSize;
    System.err.println(setupMessage);
    LOG.debug(setupMessage);

    this.global_parameter_vector = new ParameterVector();

    // Predictor names and their types are configured as comma-separated strings.
    String[] predictor_label_names = this.PredictorLabelNames.split(",");
    String[] variable_types = this.PredictorVariableTypes.split(",");

    polr_modelparams = new POLRModelParameters();
    polr_modelparams.setTargetVariable(this.TargetVariableName);
    polr_modelparams.setNumFeatures(this.FeatureVectorSize);
    polr_modelparams.setUseBias(true);

    List<String> typeList = Lists.newArrayList();
    for (String variableType : variable_types) {
        typeList.add(variableType);
    }

    List<String> predictorList = Lists.newArrayList();
    for (String predictorName : predictor_label_names) {
        predictorList.add(predictorName);
    }

    polr_modelparams.setTypeMap(predictorList, typeList);
    // Lambda and learning rate are taken from the configured fields.
    polr_modelparams.setLambda(this.Lambda);
    polr_modelparams.setLearningRate(this.LearningRate);

    // setup record factory stuff here ---------

    if (RecordFactory.RCV1_RECORDFACTORY.equals(this.RecordFactoryClassname)) {

        this.VectorFactory = new RCV1RecordFactory();

    } else if (this.VectorFactory == null) {

        // Without this guard the dereference below fails with a bare
        // NullPointerException that gives no hint about the misconfiguration.
        throw new IllegalStateException(
                "Unsupported record factory class: " + this.RecordFactoryClassname);

    }

    polr_modelparams.setTargetCategories(this.VectorFactory.getTargetCategories());

    // ----- this normally is generated from the POLRModelParams ------

    this.polr = new ParallelOnlineLogisticRegression(this.num_categories, this.FeatureVectorSize,
            new UniformPrior()).alpha(1).stepOffset(1000).decayExponent(0.9).lambda(this.Lambda)
                    .learningRate(this.LearningRate);

    polr_modelparams.setPOLR(polr);

}

From source file:tv.floe.metronome.classification.logisticregression.iterativereduce.POLRWorkerNode.java

License:Apache License

/**
 * TODO:/*from  w  ww  .  j  a  v a 2 s. c  om*/
 * - throw a mis-configuration exception of some sort
 * 
 */
private void SetupPOLR() {

    // do splitting strings into arrays here...
    String[] predictor_label_names = this.PredictorLabelNames.split(",");
    String[] variable_types = this.PredictorVariableTypes.split(",");

    polr_modelparams = new POLRModelParameters();
    polr_modelparams.setTargetVariable(this.TargetVariableName);
    polr_modelparams.setNumFeatures(this.FeatureVectorSize);
    polr_modelparams.setUseBias(true);

    List<String> typeList = Lists.newArrayList();
    for (int x = 0; x < variable_types.length; x++) {
        typeList.add(variable_types[x]);
    }

    List<String> predictorList = Lists.newArrayList();
    for (int x = 0; x < predictor_label_names.length; x++) {
        predictorList.add(predictor_label_names[x]);
    }

    // where do these come from?
    polr_modelparams.setTypeMap(predictorList, typeList);
    polr_modelparams.setLambda(this.Lambda); // based on defaults - match
                                             // command line
    polr_modelparams.setLearningRate(this.LearningRate); // based on defaults -
                                                         // match command line

    // setup record factory stuff here ---------

    // ####### disabled this input format, was not a long term solution ##########
    /*    
    if (RecordFactory.TWENTYNEWSGROUPS_RECORDFACTORY
        .equals(this.RecordFactoryClassname)) {
              
      this.VectorFactory = new TwentyNewsgroupsRecordFactory("\t");
              
    } else */

    if (RecordFactory.RCV1_RECORDFACTORY.equals(this.RecordFactoryClassname)) {

        this.VectorFactory = new RCV1RecordFactory();

    } else {

        // it defaults to the CSV record factor, but a custom one
        /*      
              this.VectorFactory = new CSVBasedDatasetRecordFactory(
                  this.TargetVariableName, polr_modelparams.getTypeMap());
                      
              ((CSVBasedDatasetRecordFactory) this.VectorFactory)
                  .firstLine(this.ColumnHeaderNames);
          */

        //   throw new Exception("Invalid Record Factory Class");

    }

    polr_modelparams.setTargetCategories(this.VectorFactory.getTargetCategories());

    // ----- this normally is generated from the POLRModelParams ------

    this.polr = new ParallelOnlineLogisticRegression(this.num_categories, this.FeatureVectorSize,
            new UniformPrior()).alpha(1).stepOffset(1000).decayExponent(0.9).lambda(this.Lambda)
                    .learningRate(this.LearningRate);

    polr_modelparams.setPOLR(polr);

    // this.bSetup = true;
}