Example usage for org.apache.commons.math3.distribution PoissonDistribution PoissonDistribution

List of usage examples for org.apache.commons.math3.distribution PoissonDistribution PoissonDistribution

Introduction

On this page you can find example usages of the PoissonDistribution constructor of the class org.apache.commons.math3.distribution.PoissonDistribution.

Prototype

public PoissonDistribution(double p) throws NotStrictlyPositiveException 

Source Link

Document

Creates a new Poisson distribution with the specified mean.

Usage

From source file:net.openhft.smoothie.MathDecisions.java

/**
 * Computes a per-entry footprint figure for a segment layout.
 *
 * <p>A Poisson distribution with mean {@code averageEntries} is used to obtain the cumulative
 * probability at {@code cap}. The rounded-up segment size is weighted once by that probability and
 * twice by its complement, the size of the segments array ({@code refSize << upFrontScale}) is added,
 * and the sum is divided by {@code averageEntries}.
 *
 * @param averageEntries mean of the Poisson distribution and the divisor of the result
 * @param refSize        size of one reference, as used in the header, slot and array terms
 * @param upFrontScale   left-shift applied to {@code refSize} to get the segments array size
 * @param cap            value at which the cumulative probability is evaluated; also scales the slot term
 * @return the combined size divided by {@code averageEntries}
 */
private static double footprint(int averageEntries, int refSize, int upFrontScale, int cap) {
    double withinCapProb = new PoissonDistribution(averageEntries).cumulativeProbability(cap);

    int headerSize = 8 + refSize;
    int roundedSegmentSize = objectSizeRoundUp(
            (headerSize
                    + (Segment.HASH_TABLE_SIZE * 2)
                    + 8 /* bit set */
                    + 4 /* tier */
                    + (cap * 2 * refSize)));

    // The complement of the within-cap probability is counted with a factor of two.
    double weightedSegmentsSize = (withinCapProb + (1 - withinCapProb) * 2) * roundedSegmentSize;
    int refArraySize = refSize << upFrontScale;

    double combinedSize = weightedSegmentsSize + refArraySize;
    return combinedSize / averageEntries;
}

From source file:mastodon.algorithms.SALinearAlgorithm.java

/**
 * Selects the pruning count for the current iteration and cools the temperature by one step.
 *
 * <p>The entries of {@code stepIterations} are accumulated until the running total exceeds
 * {@code iterationCounter}; the 1-based index of that entry is the pruning count. When it differs
 * from {@code currPrunedSpeciesCount}, the count, the Poisson distribution {@code pd}, the cooling
 * rate and the current temperature are all reset. The temperature is multiplied by the cooling
 * rate on every call.
 */
protected void choosePruningCount() {
    int stageEnd = 0;
    int stage = 0;
    while (stage < stepIterations.length) {
        stageEnd += stepIterations[stage];
        if (iterationCounter < stageEnd) {
            int stagePruneCount = stage + 1;
            if (stagePruneCount != currPrunedSpeciesCount) {
                currPrunedSpeciesCount = stagePruneCount;

                // PoissonDistribution cannot take a mean of 0, so pruning a single taxon uses 1.0.
                double mean = currPrunedSpeciesCount > 1 ? 0.5 * (currPrunedSpeciesCount - 1) : 1.0;
                pd = new PoissonDistribution(mean);

                // Rate chosen so that initTemp reaches finalTemp after stepIterations[stage] multiplications.
                double tempRatio = finalTemp / initTemp;
                coolingRate = Math.pow(tempRatio, 1.0 / stepIterations[stage]);
                currTemp = initTemp;
            }
            break;
        }
        stage++;
    }

    currTemp *= coolingRate;
}

From source file:mastodon.algorithms.SABisectionAlgorithm.java

/**
 * Runs once per iteration; every {@code stepIterations} iterations it starts a new bisection step.
 *
 * <p>At the start of a step the current pruning count and best scores are printed. After the first
 * step, {@code kLeft} or {@code kRight} is moved to the current count depending on whether
 * {@code maxScore[0]} is below {@code minMapScore}, and the count becomes their midpoint. The score
 * bookkeeping is then reset, a random pruning of that many distinct taxa is drawn and scored, the
 * Poisson distribution {@code pd} is rebuilt and the temperature is reset to {@code initTemp}.
 * The temperature is multiplied by {@code coolingRate} on every call.
 */
protected void choosePruningCount() {
    boolean startOfStep = iterationCounter % stepIterations == 0;
    if (startOfStep) {
        System.out.println(currPrunedSpeciesCount);
        System.out.println(maxScore[0] + " " + maxScore[1]);

        // Bisect only once a previous step has produced a score to judge.
        if (iterationCounter > 0) {
            if (maxScore[0] < minMapScore) {
                kLeft = currPrunedSpeciesCount;
            } else {
                kRight = currPrunedSpeciesCount;
            }
            currPrunedSpeciesCount = (int) ((kRight + kLeft) / 2);
        }

        maxScore = new double[2];
        maxScorePruning = new HashMap<BitSet, double[]>();
        currPruning = new BitSet();

        // Draw random taxon indices, redrawing whenever the index is already set.
        for (int picked = 0; picked < currPrunedSpeciesCount; picked++) {
            int taxon = (int) (Random.nextDouble() * bts.getTaxaCount());
            while (currPruning.get(taxon)) {
                taxon = (int) (Random.nextDouble() * bts.getTaxaCount());
            }
            currPruning.set(taxon);
        }

        prevPruning = (BitSet) currPruning.clone();
        prevScore = bts.pruneFast(currPruning);
        bts.unPrune();
        maxScorePruning.put(prevPruning, prevScore);

        // PoissonDistribution cannot take a mean of 0, so pruning a single taxon uses 1.0.
        double mean = currPrunedSpeciesCount > 1 ? 0.5 * (currPrunedSpeciesCount - 1) : 1.0;
        pd = new PoissonDistribution(mean);
        currTemp = initTemp;
    }

    currTemp *= coolingRate;
}

From source file:com.github.rinde.rinsim.scenario.generator.PoissonProcessTest.java

/**
 * Checks whether the observations conform to a Poisson process with the
 * specified intensity. Uses a chi square test with the specified confidence.
 * The null hypothesis is that the observations are the result of a poisson
 * process./*from  w w  w  . j  a  va 2s. co m*/
 * @param observations
 * @param intensity
 * @param confidence
 * @return <code>true</code> if the observations
 */
static boolean isPoissonProcess(Frequency observations, double intensity, double length, double confidence) {
    final PoissonDistribution pd = new PoissonDistribution(length * intensity);

    final Iterator<?> it = observations.valuesIterator();
    final long[] observed = new long[observations.getUniqueCount()];
    final double[] expected = new double[observations.getUniqueCount()];

    int index = 0;
    while (it.hasNext()) {
        final Long l = (Long) it.next();
        observed[index] = observations.getCount(l);
        expected[index] = pd.probability(l.intValue()) * observations.getSumFreq();
        if (expected[index] == 0) {
            return false;
        }
        index++;
    }
    final double chi = TestUtils.chiSquareTest(expected, observed);
    return !(chi < confidence);
}

From source file:hivemall.anomaly.ChangeFinder2DTest.java

/**
 * Feeds 10000 three-dimensional samples into a {@code ChangeFinder2D} and prints, per step, the
 * sample and the two scores written to {@code outScores}. Each coordinate is a draw from its own
 * Poisson distribution (means 10, 5 and 20) scaled by one shared uniform random factor.
 */
@Test
public void testPoissonDist() throws HiveException {
    final int examples = 10000;
    final int dims = 3;
    final PoissonDistribution[] poisson = { new PoissonDistribution(10.d), new PoissonDistribution(5.d),
            new PoissonDistribution(20.d) };
    final Random rand = new Random(42);

    // xList is a view over x, so writing into x updates what is handed to the change finder.
    final Double[] x = new Double[dims];
    final List<Double> xList = Arrays.asList(x);

    Parameters params = new Parameters();
    params.set(LossFunction.logloss);
    params.r1 = 0.01d;
    params.k = 6;
    params.T1 = 10;
    params.T2 = 5;

    PrimitiveObjectInspector elementOI = PrimitiveObjectInspectorFactory.javaDoubleObjectInspector;
    ListObjectInspector listOI = ObjectInspectorFactory.getStandardListObjectInspector(elementOI);
    final ChangeFinder2D cf = new ChangeFinder2D(params, listOI);
    final double[] outScores = new double[2];

    println("# time x0 x1 x2 outlier change");
    for (int step = 0; step < examples; step++) {
        final double scale = rand.nextDouble();
        for (int d = 0; d < dims; d++) {
            x[d] = scale * poisson[d].sample();
        }

        cf.update(xList, outScores);
        printf("%d %f %f %f %f %f%n", step, x[0], x[1], x[2], outScores[0], outScores[1]);
    }
}

From source file:ml.shifu.shifu.core.dtrain.lr.LogisticRegressionWorker.java

/**
 * Prepares this worker from the settings carried by {@code context}.
 *
 * <p>Loads the config files, derives the input/output/candidate counts, reads the validation,
 * stratified-sampling, trainer-id and k-fold settings, creates the Poisson samplers, allocates the
 * training and validation lists within a share of the max heap, builds the input splitter and
 * registers a shutdown hook that closes both lists.
 *
 * @param context worker context whose properties supply the settings read here
 * @throws IllegalStateException if the resulting input count is zero
 */
@Override
public void init(WorkerContext<LogisticRegressionParams, LogisticRegressionParams> context) {
    loadConfigFiles(context.getProps());

    // Slots 0, 1 and 2 of the returned array are read as input, output and candidate counts.
    int[] counts = DTrainUtils.getInputOutputCandidateCounts(modelConfig.getNormalizeType(),
            this.columnConfigList);
    // When slot 0 is zero the candidate count is used as the input count.
    if (counts[0] == 0) {
        this.inputNum = counts[2];
    } else {
        this.inputNum = counts[0];
    }
    this.outputNum = counts[1];
    this.candidateNum = counts[2];

    this.isSpecificValidation = (modelConfig.getValidationDataSetRawPath() != null
            && !"".equals(modelConfig.getValidationDataSetRawPath()));
    this.isStratifiedSampling = this.modelConfig.getTrain().getStratifiedSample();
    this.trainerId = Integer.valueOf(context.getProps().getProperty(CommonConstants.SHIFU_TRAINER_ID, "0"));

    Integer kFold = this.modelConfig.getTrain().getNumKFold();
    if (kFold != null && kFold > 0) {
        isKFoldCV = true;
    }

    if (this.inputNum == 0) {
        throw new IllegalStateException(
                "No any variables are selected, please try variable select step firstly.");
    }

    this.rng = new PoissonDistribution(1.0d);
    Double upSampleWeight = modelConfig.getTrain().getUpSampleWeight();
    boolean upSamplingEnabled = Double.compare(upSampleWeight, 1d) != 0;
    if (upSamplingEnabled) {
        // mean is upSampleWeight - 1; the sample is later incremented by 1 so it is never zero
        LOG.info("Enable up sampling with weight {}.", upSampleWeight);
        this.upSampleRng = new PoissonDistribution(upSampleWeight - 1);
    }

    double memoryFraction = Double.valueOf(context.getProps().getProperty("guagua.data.memoryFraction", "0.6"));
    LOG.info("Max heap memory: {}, fraction: {}", Runtime.getRuntime().maxMemory(), memoryFraction);
    double crossValidationRate = this.modelConfig.getValidSetRate();
    String tmpFolder = context.getProps().getProperty("guagua.data.tmpfolder", "tmp");

    // Split the memory budget: fixed 0.6 / 0.4 when a validation path is configured,
    // otherwise according to the validation set rate.
    double trainShare;
    double validShare;
    if (StringUtils.isNotBlank(modelConfig.getValidationDataSetRawPath())) {
        trainShare = 0.6;
        validShare = 0.4;
    } else {
        trainShare = (1 - crossValidationRate);
        validShare = crossValidationRate;
    }
    this.trainingData = new BytableMemoryDiskList<Data>(
            (long) (Runtime.getRuntime().maxMemory() * memoryFraction * trainShare),
            tmpFolder + File.separator + "train-" + System.currentTimeMillis(), Data.class.getName());
    this.validationData = new BytableMemoryDiskList<Data>(
            (long) (Runtime.getRuntime().maxMemory() * memoryFraction * validShare),
            tmpFolder + File.separator + "test-" + System.currentTimeMillis(), Data.class.getName());

    // create Splitter
    String delimiter = context.getProps().getProperty(Constants.SHIFU_OUTPUT_DATA_DELIMITER);
    this.splitter = MapReduceUtils.generateShifuOutputSplitter(delimiter);

    // No better place was found to close the two data lists, so they are closed from a shutdown hook.
    Runnable closeDataLists = new Runnable() {
        @Override
        public void run() {
            LogisticRegressionWorker.this.validationData.close();
            LogisticRegressionWorker.this.trainingData.close();
        }
    };
    Runtime.getRuntime().addShutdownHook(new Thread(closeDataLists));
}

From source file:ml.shifu.shifu.core.dtrain.nn.AbstractNNWorker.java

/**
 * Initializes this worker from the settings carried by {@code context}.
 *
 * <p>In order, this method: stores the properties and loads the config files; reads the trainer id and
 * the training parameters (replaced by the grid-search parameters for this trainer when hyper parameters
 * are present); sets the k-fold, Poisson-sampler and up-sampling state; reads epochs per iteration,
 * dropout rate and mini-batch count; derives the input/output/candidate counts and the feature subset;
 * reads the weight initializer and loss name; allocates the training and validation data sets; and
 * creates the input splitter.
 *
 * @param context worker context whose properties supply the settings read here
 * @throws RuntimeException if {@code initDiskDataSet()} fails with an {@link IOException}
 * @throws GuaguaRuntimeException if creating the memory-backed data sets fails with an {@link IOException}
 */
@Override
public void init(WorkerContext<NNParams, NNParams> context) {
    // load props firstly
    this.props = context.getProps();

    loadConfigFiles(context.getProps());

    this.trainerId = Integer.valueOf(context.getProps().getProperty(CommonConstants.SHIFU_TRAINER_ID, "0"));
    GridSearch gs = new GridSearch(modelConfig.getTrain().getParams(),
            modelConfig.getTrain().getGridConfigFileContent());
    // default to the configured train params; grid search overrides them per trainer id
    this.validParams = this.modelConfig.getTrain().getParams();
    if (gs.hasHyperParam()) {
        this.validParams = gs.getParams(trainerId);
        LOG.info("Start grid search master with params: {}", validParams);
    }

    Integer kCrossValidation = this.modelConfig.getTrain().getNumKFold();
    if (kCrossValidation != null && kCrossValidation > 0) {
        isKFoldCV = true;
        LOG.info("Cross validation is enabled by kCrossValidation: {}.", kCrossValidation);
    }

    // enabled only when the property value equals "true", ignoring case
    this.poissonSampler = Boolean.TRUE.toString()
            .equalsIgnoreCase(context.getProps().getProperty(NNConstants.NN_POISON_SAMPLER));
    this.rng = new PoissonDistribution(1.0d);
    Double upSampleWeight = modelConfig.getTrain().getUpSampleWeight();
    // NOTE(review): Double.compare unboxes upSampleWeight, so a null value would throw here - confirm
    // getUpSampleWeight() never returns null.
    if (Double.compare(upSampleWeight, 1d) != 0 && (modelConfig.isRegression()
            || (modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll()))) {
        // set mean to upSampleWeight - 1 and get sample + 1 to make sure no zero sample value
        // NOTE(review): the PoissonDistribution constructor declares NotStrictlyPositiveException, so a
        // weight <= 1 (other than exactly 1) presumably fails here - confirm it is validated upstream.
        LOG.info("Enable up sampling with weight {}.", upSampleWeight);
        this.upSampleRng = new PoissonDistribution(upSampleWeight - 1);
    }
    // a missing epochsPerIteration setting defaults to 1
    Integer epochsPerIterationInteger = this.modelConfig.getTrain().getEpochsPerIteration();
    this.epochsPerIteration = epochsPerIterationInteger == null ? 1 : epochsPerIterationInteger.intValue();
    LOG.info("epochsPerIteration in worker is :{}", epochsPerIteration);

    // Object elmObject = validParams.get(DTrainUtils.IS_ELM);
    // isELM = elmObject == null ? false : "true".equalsIgnoreCase(elmObject.toString());
    // LOG.info("Check isELM: {}", isELM);

    // dropoutRate keeps its current value when the parameter is absent
    Object dropoutRateObj = validParams.get(CommonConstants.DROPOUT_RATE);
    if (dropoutRateObj != null) {
        this.dropoutRate = Double.valueOf(dropoutRateObj.toString());
    }
    LOG.info("'dropoutRate' in worker is :{}", this.dropoutRate);

    Object miniBatchO = validParams.get(CommonConstants.MINI_BATCH);
    if (miniBatchO != null) {
        int miniBatchs;
        try {
            miniBatchs = Integer.parseInt(miniBatchO.toString());
        } catch (Exception e) {
            // an unparsable value falls back to a single batch
            miniBatchs = 1;
        }
        // negative values become 1, values above 1000 are capped at 1000, anything else (including 0) is
        // used as is
        if (miniBatchs < 0) {
            this.batchs = 1;
        } else if (miniBatchs > 1000) {
            this.batchs = 1000;
        } else {
            this.batchs = miniBatchs;
        }
        LOG.info("'miniBatchs' in worker is : {}, batchs is {} ", miniBatchs, batchs);
    }

    // slots 0, 1 and 2 of the returned array are read as input, output and candidate counts
    int[] inputOutputIndex = DTrainUtils.getInputOutputCandidateCounts(modelConfig.getNormalizeType(),
            this.columnConfigList);
    // when slot 0 is zero the candidate count is used as the input node count
    this.inputNodeCount = inputOutputIndex[0] == 0 ? inputOutputIndex[2] : inputOutputIndex[0];
    // if is one vs all classification, outputNodeCount is set to 1, if classes=2, outputNodeCount is also 1
    int classes = modelConfig.getTags().size();
    this.outputNodeCount = (isLinearTarget || modelConfig.isRegression()) ? inputOutputIndex[1]
            : (modelConfig.getTrain().isOneVsAll() ? inputOutputIndex[1] : (classes == 2 ? 1 : classes));
    this.candidateCount = inputOutputIndex[2];
    boolean isAfterVarSelect = inputOutputIndex[0] != 0;
    LOG.info("isAfterVarSelect {}: Input count {}, output count {}, candidate count {}", isAfterVarSelect,
            inputNodeCount, outputNodeCount, candidateCount);
    // cache all feature list for sampling features
    this.allFeatures = NormalUtils.getAllFeatureList(columnConfigList, isAfterVarSelect);
    // the subset property is a comma-separated list of integer feature indexes; blank means all features
    String subsetStr = context.getProps().getProperty(CommonConstants.SHIFU_NN_FEATURE_SUBSET);
    if (StringUtils.isBlank(subsetStr)) {
        this.subFeatures = this.allFeatures;
    } else {
        String[] splits = subsetStr.split(",");
        this.subFeatures = new ArrayList<Integer>(splits.length);
        for (String split : splits) {
            int featureIndex = Integer.parseInt(split);
            this.subFeatures.add(featureIndex);
        }
    }
    this.subFeatureSet = new HashSet<Integer>(this.subFeatures);
    LOG.info("subFeatures size is {}", subFeatures.size());
    this.featureInputsCnt = DTrainUtils.getFeatureInputsCnt(this.modelConfig, this.columnConfigList,
            this.subFeatureSet);

    // weight initializer defaults to "default" when not configured
    this.wgtInit = "default";
    Object wgtInitObj = validParams.get(CommonConstants.WEIGHT_INITIALIZER);
    if (wgtInitObj != null) {
        this.wgtInit = wgtInitObj.toString();
    }

    // loss name defaults to "squared" when not configured
    Object lossObj = validParams.get("Loss");
    this.lossStr = lossObj != null ? lossObj.toString() : "squared";
    LOG.info("Loss str is {}", this.lossStr);

    this.isDry = Boolean.TRUE.toString()
            .equalsIgnoreCase(context.getProps().getProperty(CommonConstants.SHIFU_DRY_DTRAIN));
    // true when a non-empty validation data path is configured
    this.isSpecificValidation = (modelConfig.getValidationDataSetRawPath() != null
            && !"".equals(modelConfig.getValidationDataSetRawPath()));
    this.isStratifiedSampling = this.modelConfig.getTrain().getStratifiedSample();
    if (isOnDisk()) {
        LOG.info("NNWorker is loading data into disk.");
        try {
            initDiskDataSet();
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        // cannot find a good place to close these two data set, using Shutdown hook
        Runtime.getRuntime().addShutdownHook(new Thread(new Runnable() {
            @Override
            public void run() {
                ((BufferedFloatMLDataSet) (AbstractNNWorker.this.trainingData)).close();
                ((BufferedFloatMLDataSet) (AbstractNNWorker.this.validationData)).close();
            }
        }));
    } else {
        LOG.info("NNWorker is loading data into memory.");
        // share of the max heap given to the data sets; defaults to 0.6
        double memoryFraction = Double
                .valueOf(context.getProps().getProperty("guagua.data.memoryFraction", "0.6"));
        long memoryStoreSize = (long) (Runtime.getRuntime().maxMemory() * memoryFraction);
        LOG.info("Max heap memory: {}, fraction: {}", Runtime.getRuntime().maxMemory(), memoryFraction);
        double crossValidationRate = this.modelConfig.getValidSetRate();
        try {
            if (StringUtils.isNotBlank(modelConfig.getValidationDataSetRawPath())) {
                // fixed 0.6 and 0.4 of max memory for trainingData and validationData
                this.trainingData = new MemoryDiskFloatMLDataSet((long) (memoryStoreSize * 0.6),
                        DTrainUtils.getTrainingFile().toString(), this.featureInputsCnt, this.outputNodeCount);
                this.validationData = new MemoryDiskFloatMLDataSet((long) (memoryStoreSize * 0.4),
                        DTrainUtils.getTestingFile().toString(), this.featureInputsCnt, this.outputNodeCount);
            } else {
                // no separate validation path: split the memory by the validation set rate
                this.trainingData = new MemoryDiskFloatMLDataSet(
                        (long) (memoryStoreSize * (1 - crossValidationRate)),
                        DTrainUtils.getTrainingFile().toString(), this.featureInputsCnt, this.outputNodeCount);
                this.validationData = new MemoryDiskFloatMLDataSet(
                        (long) (memoryStoreSize * crossValidationRate), DTrainUtils.getTestingFile().toString(),
                        this.featureInputsCnt, this.outputNodeCount);
            }
            // cannot find a good place to close these two data set, using Shutdown hook
            Runtime.getRuntime().addShutdownHook(new Thread(new Runnable() {
                @Override
                public void run() {
                    ((MemoryDiskFloatMLDataSet) (AbstractNNWorker.this.trainingData)).close();
                    ((MemoryDiskFloatMLDataSet) (AbstractNNWorker.this.validationData)).close();
                }
            }));
        } catch (IOException e) {
            throw new GuaguaRuntimeException(e);
        }
    }

    // create Splitter
    String delimiter = context.getProps().getProperty(Constants.SHIFU_OUTPUT_DATA_DELIMITER);
    this.splitter = MapReduceUtils.generateShifuOutputSplitter(delimiter);
}

From source file:ml.shifu.shifu.core.dtrain.dt.DTWorker.java

/**
 * Initializes this worker from the settings carried by {@code context}.
 *
 * <p>In order, this method: loads the model and column configs; builds the category-value to bin-index
 * mapping for categorical columns; creates the input splitter; sets the k-fold and up-sampling state;
 * creates the worker thread pool and registers a completion callback that shuts it down; resolves the
 * training parameters (grid search aware); allocates the training and validation lists within a share
 * of the max heap; selects the impurity and loss implementations; reads the GBDT-only settings; and
 * finally prepares tree recovery for fail-over or continuous training.
 *
 * @param context worker context whose properties supply the settings read here
 * @throws RuntimeException if loading the model or column config fails with an {@link IOException}
 */
@Override
public void init(WorkerContext<DTMasterParams, DTWorkerParams> context) {
    Properties props = context.getProps();
    try {
        // source type defaults to HDFS when the property is not set
        SourceType sourceType = SourceType
                .valueOf(props.getProperty(CommonConstants.MODELSET_SOURCE_TYPE, SourceType.HDFS.toString()));
        this.modelConfig = CommonUtils.loadModelConfig(props.getProperty(CommonConstants.SHIFU_MODEL_CONFIG),
                sourceType);
        this.columnConfigList = CommonUtils
                .loadColumnConfigList(props.getProperty(CommonConstants.SHIFU_COLUMN_CONFIG), sourceType);
    } catch (IOException e) {
        throw new RuntimeException(e);
    }

    // for each categorical column with bin categories: map every flattened category value to the index of
    // the bin category entry it came from, keyed by column number
    this.columnCategoryIndexMapping = new HashMap<Integer, Map<String, Integer>>();
    for (ColumnConfig config : this.columnConfigList) {
        if (config.isCategorical()) {
            if (config.getBinCategory() != null) {
                Map<String, Integer> tmpMap = new HashMap<String, Integer>();
                for (int i = 0; i < config.getBinCategory().size(); i++) {
                    List<String> catVals = CommonUtils.flattenCatValGrp(config.getBinCategory().get(i));
                    for (String cval : catVals) {
                        tmpMap.put(cval, i);
                    }
                }
                this.columnCategoryIndexMapping.put(config.getColumnNum(), tmpMap);
            }
        }
    }

    this.hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList);

    // create Splitter
    String delimiter = context.getProps().getProperty(Constants.SHIFU_OUTPUT_DATA_DELIMITER);
    this.splitter = MapReduceUtils.generateShifuOutputSplitter(delimiter);

    Integer kCrossValidation = this.modelConfig.getTrain().getNumKFold();
    if (kCrossValidation != null && kCrossValidation > 0) {
        isKFoldCV = true;
        LOG.info("Cross validation is enabled by kCrossValidation: {}.", kCrossValidation);
    }

    Double upSampleWeight = modelConfig.getTrain().getUpSampleWeight();
    // NOTE(review): Double.compare unboxes upSampleWeight, so a null value would throw here - confirm
    // getUpSampleWeight() never returns null.
    if (Double.compare(upSampleWeight, 1d) != 0 && (modelConfig.isRegression()
            || (modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll()))) {
        // set mean to upSampleWeight -1 and get sample + 1 to make sure no zero sample value
        // NOTE(review): the PoissonDistribution constructor declares NotStrictlyPositiveException, so a
        // weight <= 1 (other than exactly 1) presumably fails here - confirm it is validated upstream.
        LOG.info("Enable up sampling with weight {}.", upSampleWeight);
        this.upSampleRng = new PoissonDistribution(upSampleWeight - 1);
    }

    // enabled only when the property value equals "true", ignoring case
    this.isContinuousEnabled = Boolean.TRUE.toString()
            .equalsIgnoreCase(context.getProps().getProperty(CommonConstants.CONTINUOUS_TRAINING));

    this.workerThreadCount = modelConfig.getTrain().getWorkerThreadCount();
    this.threadPool = Executors.newFixedThreadPool(this.workerThreadCount);
    // enable shut down logic
    context.addCompletionCallBack(new WorkerCompletionCallBack<DTMasterParams, DTWorkerParams>() {
        @Override
        public void callback(WorkerContext<DTMasterParams, DTWorkerParams> context) {
            DTWorker.this.threadPool.shutdownNow();
            try {
                // wait at most 2 seconds for running tasks to stop
                DTWorker.this.threadPool.awaitTermination(2, TimeUnit.SECONDS);
            } catch (InterruptedException e) {
                // restore the interrupt flag for the calling thread
                Thread.currentThread().interrupt();
            }
        }
    });

    this.trainerId = Integer.valueOf(context.getProps().getProperty(CommonConstants.SHIFU_TRAINER_ID, "0"));
    this.isOneVsAll = modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll();

    GridSearch gs = new GridSearch(modelConfig.getTrain().getParams(),
            modelConfig.getTrain().getGridConfigFileContent());
    // default to the configured train params; grid search overrides them per trainer id
    Map<String, Object> validParams = this.modelConfig.getTrain().getParams();
    if (gs.hasHyperParam()) {
        validParams = gs.getParams(this.trainerId);
        LOG.info("Start grid search worker with params: {}", validParams);
    }

    // NOTE(review): "TreeNum" (and "Impurity", "MinInstancesPerNode", "MinInfoGain", "Loss" below) are
    // dereferenced without a null check - presumably always present in the params; confirm.
    this.treeNum = Integer.valueOf(validParams.get("TreeNum").toString());

    // share of the max heap given to the data lists; defaults to 0.6
    double memoryFraction = Double.valueOf(context.getProps().getProperty("guagua.data.memoryFraction", "0.6"));
    LOG.info("Max heap memory: {}, fraction: {}", Runtime.getRuntime().maxMemory(), memoryFraction);

    double validationRate = this.modelConfig.getValidSetRate();
    if (StringUtils.isNotBlank(modelConfig.getValidationDataSetRawPath())) {
        // fixed 0.6 and 0.4 of max memory for trainingData and validationData
        this.trainingData = new MemoryLimitedList<Data>(
                (long) (Runtime.getRuntime().maxMemory() * memoryFraction * 0.6), new ArrayList<Data>());
        this.validationData = new MemoryLimitedList<Data>(
                (long) (Runtime.getRuntime().maxMemory() * memoryFraction * 0.4), new ArrayList<Data>());
    } else {
        if (Double.compare(validationRate, 0d) != 0) {
            // no separate validation path: split the memory by the validation set rate
            this.trainingData = new MemoryLimitedList<Data>(
                    (long) (Runtime.getRuntime().maxMemory() * memoryFraction * (1 - validationRate)),
                    new ArrayList<Data>());
            this.validationData = new MemoryLimitedList<Data>(
                    (long) (Runtime.getRuntime().maxMemory() * memoryFraction * validationRate),
                    new ArrayList<Data>());
        } else {
            // validation rate is 0: all of the memory goes to training data and validationData is not
            // assigned here
            this.trainingData = new MemoryLimitedList<Data>(
                    (long) (Runtime.getRuntime().maxMemory() * memoryFraction), new ArrayList<Data>());
        }
    }

    int[] inputOutputIndex = DTrainUtils.getNumericAndCategoricalInputAndOutputCounts(this.columnConfigList);
    // numerical + categorical = # of all input
    this.inputCount = inputOutputIndex[0] + inputOutputIndex[1];
    // regression outputNodeCount is 1, binaryClassfication, it is 1, OneVsAll it is 1, Native classification it is
    // 1, with index of 0,1,2,3 denotes different classes
    this.isAfterVarSelect = (inputOutputIndex[3] == 1);
    // true when a non-empty validation data path is configured
    this.isManualValidation = (modelConfig.getValidationDataSetRawPath() != null
            && !"".equals(modelConfig.getValidationDataSetRawPath()));

    // non-classification models use 2 as the class count
    int numClasses = this.modelConfig.isClassification() ? this.modelConfig.getTags().size() : 2;
    String imStr = validParams.get("Impurity").toString();
    int minInstancesPerNode = Integer.valueOf(validParams.get("MinInstancesPerNode").toString());
    double minInfoGain = Double.valueOf(validParams.get("MinInfoGain").toString());
    // impurity is chosen by case-insensitive name; any unrecognized name falls back to Variance
    if (imStr.equalsIgnoreCase("entropy")) {
        impurity = new Entropy(numClasses, minInstancesPerNode, minInfoGain);
    } else if (imStr.equalsIgnoreCase("gini")) {
        impurity = new Gini(numClasses, minInstancesPerNode, minInfoGain);
    } else if (imStr.equalsIgnoreCase("friedmanmse")) {
        impurity = new FriedmanMSE(minInstancesPerNode, minInfoGain);
    } else {
        impurity = new Variance(minInstancesPerNode, minInfoGain);
    }

    this.isRF = ALGORITHM.RF.toString().equalsIgnoreCase(modelConfig.getAlgorithm());
    this.isGBDT = ALGORITHM.GBT.toString().equalsIgnoreCase(modelConfig.getAlgorithm());

    // loss is chosen by case-insensitive name; any other value is treated as a class name to instantiate
    String lossStr = validParams.get("Loss").toString();
    if (lossStr.equalsIgnoreCase("log")) {
        this.loss = new LogLoss();
    } else if (lossStr.equalsIgnoreCase("absolute")) {
        this.loss = new AbsoluteLoss();
    } else if (lossStr.equalsIgnoreCase("halfgradsquared")) {
        this.loss = new HalfGradSquaredLoss();
    } else if (lossStr.equalsIgnoreCase("squared")) {
        this.loss = new SquaredLoss();
    } else {
        try {
            this.loss = (Loss) ClassUtils.newInstance(Class.forName(lossStr));
        } catch (ClassNotFoundException e) {
            LOG.warn("Class not found for {}, using default SquaredLoss", lossStr);
            this.loss = new SquaredLoss();
        }
    }

    // learning rate, sampling-with-replacement and dropout rate are only read for GBDT
    if (this.isGBDT) {
        this.learningRate = Double.valueOf(validParams.get(CommonConstants.LEARNING_RATE).toString());
        Object swrObj = validParams.get("GBTSampleWithReplacement");
        if (swrObj != null) {
            this.gbdtSampleWithReplacement = Boolean.TRUE.toString().equalsIgnoreCase(swrObj.toString());
        }

        Object dropoutObj = validParams.get(CommonConstants.DROPOUT_RATE);
        if (dropoutObj != null) {
            this.dropOutRate = Double.valueOf(dropoutObj.toString());
        }
    }

    this.isStratifiedSampling = this.modelConfig.getTrain().getStratifiedSample();

    // checkpoint folder defaults to "tmp/cp_" followed by the application id
    this.checkpointOutput = new Path(context.getProps()
            .getProperty(CommonConstants.SHIFU_DT_MASTER_CHECKPOINT_FOLDER, "tmp/cp_" + context.getAppId()));

    LOG.info(
            "Worker init params:isAfterVarSel={}, treeNum={}, impurity={}, loss={}, learningRate={}, gbdtSampleWithReplacement={}, isRF={}, isGBDT={}, isStratifiedSampling={}, isKFoldCV={}, kCrossValidation={}, dropOutRate={}",
            isAfterVarSelect, treeNum, impurity.getClass().getName(), loss.getClass().getName(),
            this.learningRate, this.gbdtSampleWithReplacement, this.isRF, this.isGBDT,
            this.isStratifiedSampling, this.isKFoldCV, kCrossValidation, this.dropOutRate);

    // for fail over, load existing trees
    if (!context.isFirstIteration()) {
        if (this.isGBDT) {
            // set flag here and recover later in doComputing, this is to make sure recover after load part which
            // can load latest trees in #doCompute
            isNeedRecoverGBDTPredict = true;
        } else {
            // RF , trees are recovered from last master results
            recoverTrees = context.getLastMasterResult().getTrees();
        }
    }

    // continuous GBDT training: on the first iteration, start from the trees of an existing model
    if (context.isFirstIteration() && this.isContinuousEnabled && this.isGBDT) {
        Path modelPath = new Path(context.getProps().getProperty(CommonConstants.GUAGUA_OUTPUT));
        TreeModel existingModel = null;
        try {
            existingModel = (TreeModel) ModelSpecLoaderUtils.loadModel(modelConfig, modelPath,
                    ShifuFileUtils.getFileSystemBySourceType(this.modelConfig.getDataSet().getSource()));
        } catch (IOException e) {
            // a load failure is logged and treated the same as having no existing model
            LOG.error("Error in get existing model, will ignore and start from scratch", e);
        }
        if (existingModel == null) {
            LOG.warn("No model is found even set to continuous model training.");
            return;
        } else {
            recoverTrees = existingModel.getTrees();
            LOG.info("Loading existing {} trees", recoverTrees.size());
        }
    }
}

From source file:ml.shifu.shifu.core.dtrain.wdl.WDLWorker.java

/**
 * Initializes this worker from the settings carried by {@code context}.
 *
 * <p>In order, this method: loads the model and column configs; initializes the category index map;
 * creates the input splitter; sets the k-fold, Poisson-sampler and up-sampling state; reads the trainer
 * id; allocates the training and validation lists within a share of the max heap; derives the input
 * counts; and builds the {@code WideAndDeep} network from the training parameters.
 *
 * @param context worker context whose properties supply the settings read here
 * @throws RuntimeException if loading the model or column config fails with an {@link IOException}
 */
@SuppressWarnings({ "unchecked", "unused" })
@Override
public void init(WorkerContext<WDLParams, WDLParams> context) {
    Properties props = context.getProps();
    try {
        // source type defaults to HDFS when the property is not set
        SourceType sourceType = SourceType
                .valueOf(props.getProperty(CommonConstants.MODELSET_SOURCE_TYPE, SourceType.HDFS.toString()));
        this.modelConfig = CommonUtils.loadModelConfig(props.getProperty(CommonConstants.SHIFU_MODEL_CONFIG),
                sourceType);
        this.columnConfigList = CommonUtils
                .loadColumnConfigList(props.getProperty(CommonConstants.SHIFU_COLUMN_CONFIG), sourceType);
    } catch (IOException e) {
        throw new RuntimeException(e);
    }

    this.initCateIndexMap();
    this.hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList);

    // create Splitter
    String delimiter = context.getProps().getProperty(Constants.SHIFU_OUTPUT_DATA_DELIMITER);
    this.splitter = MapReduceUtils.generateShifuOutputSplitter(delimiter);

    Integer kCrossValidation = this.modelConfig.getTrain().getNumKFold();
    if (kCrossValidation != null && kCrossValidation > 0) {
        isKFoldCV = true;
        LOG.info("Cross validation is enabled by kCrossValidation: {}.", kCrossValidation);
    }

    // enabled only when the property value equals "true", ignoring case
    this.poissonSampler = Boolean.TRUE.toString()
            .equalsIgnoreCase(context.getProps().getProperty(NNConstants.NN_POISON_SAMPLER));
    this.rng = new PoissonDistribution(1d);
    Double upSampleWeight = modelConfig.getTrain().getUpSampleWeight();
    // NOTE(review): the != comparison unboxes upSampleWeight, so a null value would throw here - confirm
    // getUpSampleWeight() never returns null.
    if (upSampleWeight != 1d && (modelConfig.isRegression()
            || (modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll()))) {
        // set mean to upSampleWeight - 1 and get sample + 1 to make sure no zero sample value
        // NOTE(review): the PoissonDistribution constructor declares NotStrictlyPositiveException, so a
        // weight <= 1 (other than exactly 1) presumably fails here - confirm it is validated upstream.
        LOG.info("Enable up sampling with weight {}.", upSampleWeight);
        this.upSampleRng = new PoissonDistribution(upSampleWeight - 1);
    }

    this.trainerId = Integer.valueOf(context.getProps().getProperty(CommonConstants.SHIFU_TRAINER_ID, "0"));

    // share of the max heap given to the data lists; defaults to 0.6
    double memoryFraction = Double.valueOf(context.getProps().getProperty("guagua.data.memoryFraction", "0.6"));
    LOG.info("Max heap memory: {}, fraction: {}", Runtime.getRuntime().maxMemory(), memoryFraction);

    double validationRate = this.modelConfig.getValidSetRate();
    if (StringUtils.isNotBlank(modelConfig.getValidationDataSetRawPath())) {
        // fixed 0.6 and 0.4 of max memory for trainingData and validationData
        this.trainingData = new MemoryLimitedList<Data>(
                (long) (Runtime.getRuntime().maxMemory() * memoryFraction * 0.6), new ArrayList<Data>());
        this.validationData = new MemoryLimitedList<Data>(
                (long) (Runtime.getRuntime().maxMemory() * memoryFraction * 0.4), new ArrayList<Data>());
    } else {
        if (validationRate != 0d) {
            // no separate validation path: split the memory by the validation set rate
            this.trainingData = new MemoryLimitedList<Data>(
                    (long) (Runtime.getRuntime().maxMemory() * memoryFraction * (1 - validationRate)),
                    new ArrayList<Data>());
            this.validationData = new MemoryLimitedList<Data>(
                    (long) (Runtime.getRuntime().maxMemory() * memoryFraction * validationRate),
                    new ArrayList<Data>());
        } else {
            // validation rate is 0: all of the memory goes to training data and validationData is not
            // assigned here
            this.trainingData = new MemoryLimitedList<Data>(
                    (long) (Runtime.getRuntime().maxMemory() * memoryFraction), new ArrayList<Data>());
        }
    }

    int[] inputOutputIndex = DTrainUtils.getNumericAndCategoricalInputAndOutputCounts(this.columnConfigList);
    // numerical + categorical = # of all input
    this.numInputs = inputOutputIndex[0];
    this.inputCount = inputOutputIndex[0] + inputOutputIndex[1];
    // regression outputNodeCount is 1, binaryClassfication, it is 1, OneVsAll it is 1, Native classification it is
    // 1, with index of 0,1,2,3 denotes different classes
    this.isAfterVarSelect = (inputOutputIndex[3] == 1);
    // true when a non-empty validation data path is configured
    this.isManualValidation = (modelConfig.getValidationDataSetRawPath() != null
            && !"".equals(modelConfig.getValidationDataSetRawPath()));

    this.isStratifiedSampling = this.modelConfig.getTrain().getStratifiedSample();

    this.validParams = this.modelConfig.getTrain().getParams();

    // Build wide and deep graph
    // NOTE(review): the params below are cast and dereferenced without null checks (embedColumnIds is
    // iterated, NUM_HIDDEN_LAYERS is unboxed, WDL_L2_REG is cast to Double) - presumably always
    // configured; confirm.
    List<Integer> embedColumnIds = (List<Integer>) this.validParams.get(CommonConstants.NUM_EMBED_COLUMN_IDS);
    Integer embedOutputs = (Integer) this.validParams.get(CommonConstants.NUM_EMBED_OUTPUTS);
    // one output size per embed column, all the same value; the loop variable itself is not read
    List<Integer> embedOutputList = new ArrayList<Integer>();
    for (Integer cId : embedColumnIds) {
        embedOutputList.add(embedOutputs == null ? CommonConstants.DEFAULT_EMBEDING_OUTPUT : embedOutputs);
    }
    List<Integer> numericalIds = DTrainUtils.getNumericalIds(this.columnConfigList, isAfterVarSelect);
    List<Integer> wideColumnIds = DTrainUtils.getCategoricalIds(columnConfigList, isAfterVarSelect);
    Map<Integer, Integer> idBinCateSizeMap = DTrainUtils.getIdBinCategorySizeMap(columnConfigList);
    // numLayers is read but not passed to the WideAndDeep constructor below
    int numLayers = (Integer) this.validParams.get(CommonConstants.NUM_HIDDEN_LAYERS);
    List<String> actFunc = (List<String>) this.validParams.get(CommonConstants.ACTIVATION_FUNC);
    List<Integer> hiddenNodes = (List<Integer>) this.validParams.get(CommonConstants.NUM_HIDDEN_NODES);
    Float l2reg = ((Double) this.validParams.get(CommonConstants.WDL_L2_REG)).floatValue();
    this.wnd = new WideAndDeep(idBinCateSizeMap, numInputs, numericalIds, embedColumnIds, embedOutputList,
            wideColumnIds, hiddenNodes, actFunc, l2reg);
}

From source file:ml.shifu.shifu.core.dtrain.lr.LogisticRegressionWorker.java

/**
 * Draws the bagging weight to apply to a single training record.
 *
 * <p>
 * Without replacement, the record is either kept (weight 1) or dropped (weight 0) by comparing a uniform draw
 * against the sample rate. With replacement, the weight is a draw from a Poisson distribution whose mean is the
 * sample rate. Generators are created lazily and cached: one per class value under stratified sampling, otherwise
 * a single shared one stored under key 0.
 *
 * @param label
 *            the record label; the class value is obtained by adding 0.01 and truncating to an int
 * @return the sampling weight of the record
 */
protected float sampleWeights(float label) {
    // the rate is pinned to 1 when only negatives are sampled or when k-fold cross validation is active
    final double rate;
    if (modelConfig.getTrain().getSampleNegOnly() || this.isKFoldCV) {
        rate = 1d;
    } else {
        rate = modelConfig.getTrain().getBaggingSampleRate();
    }

    final int classValue = (int) (label + 0.01f);
    final boolean withReplacement = modelConfig.isBaggingWithReplacement();
    // stratified sampling keeps one generator per class; otherwise every record shares the generator at key 0
    final int key = this.isStratifiedSampling ? classValue : 0;

    if (withReplacement) {
        // sampling with replacement: the weight is the number of times the record is drawn
        PoissonDistribution poisson = this.baggingRngMap.get(key);
        if (poisson == null) {
            poisson = new PoissonDistribution(rate);
            this.baggingRngMap.put(key, poisson);
        }
        return poisson.sample();
    }

    // sampling without replacement: keep or drop the record
    Random uniform = baggingRandomMap.get(key);
    if (uniform == null) {
        uniform = DTrainUtils.generateRandomBySampleSeed(modelConfig.getTrain().getBaggingSampleSeed(),
                CommonConstants.NOT_CONFIGURED_BAGGING_SEED);
        baggingRandomMap.put(key, uniform);
    }
    return uniform.nextDouble() <= rate ? 1f : 0f;
}