Usage examples for the org.apache.commons.math3.distribution.PoissonDistribution constructor PoissonDistribution(double p)
public PoissonDistribution(double p) throws NotStrictlyPositiveException
From source file:net.openhft.smoothie.MathDecisions.java
/**
 * Estimates the memory footprint per entry (presumably in bytes) of a table whose segments hold a
 * Poisson-distributed number of entries with mean {@code averageEntries}.
 *
 * @param averageEntries mean number of entries per segment; also the divisor of the result
 * @param refSize size of an object reference
 * @param upFrontScale log2 of the number of slots in the segments array
 * @param cap entry capacity of a single segment
 * @return (expected segment memory + segments array size) / averageEntries
 */
private static double footprint(int averageEntries, int refSize, int upFrontScale, int cap) {
    // Probability that a segment's entry count does not exceed its capacity.
    double withinCapProbability = new PoissonDistribution(averageEntries).cumulativeProbability(cap);

    int headerBytes = 8 + refSize;
    int rawSegmentBytes = headerBytes
            + (Segment.HASH_TABLE_SIZE * 2)
            + 8 // bit set
            + 4 // tier
            + (cap * 2 * refSize);
    int segmentBytes = objectSizeRoundUp(rawSegmentBytes);

    // Segments that overflow their cap are counted as two segments' worth of memory.
    double expectedSegmentBytes = (withinCapProbability + (1 - withinCapProbability) * 2) * segmentBytes;
    int segmentsArrayBytes = refSize << upFrontScale;
    return (expectedSegmentBytes + segmentsArrayBytes) / averageEntries;
}
From source file:mastodon.algorithms.SALinearAlgorithm.java
protected void choosePruningCount() { int position = 0; for (int i = 0; i < stepIterations.length; i++) { position += stepIterations[i];//from w ww. j a v a2s. co m if (iterationCounter < position) { if (i + 1 != currPrunedSpeciesCount) { currPrunedSpeciesCount = i + 1; double mean = 1.0; //needed when pruning 1 taxon (can't have a mean of 0 in PoissonDistribution()) if (currPrunedSpeciesCount > 1) { mean = 0.5 * (currPrunedSpeciesCount - 1); } pd = new PoissonDistribution(mean); coolingRate = Math.pow(finalTemp / initTemp, 1.0 / stepIterations[i]); currTemp = initTemp; } break; } } currTemp *= coolingRate; }
From source file:mastodon.algorithms.SABisectionAlgorithm.java
protected void choosePruningCount() { if (iterationCounter % stepIterations == 0) { System.out.println(currPrunedSpeciesCount); System.out.println(maxScore[0] + " " + maxScore[1]); if (iterationCounter > 0) { if (maxScore[0] < minMapScore) { kLeft = currPrunedSpeciesCount; } else { kRight = currPrunedSpeciesCount; }/* w ww. ja v a 2s . c o m*/ currPrunedSpeciesCount = (int) ((kRight + kLeft) / 2); } maxScore = new double[2]; maxScorePruning = new HashMap<BitSet, double[]>(); currPruning = new BitSet(); for (int i = 0; i < currPrunedSpeciesCount; i++) { int choice = 0; do { choice = (int) (Random.nextDouble() * bts.getTaxaCount()); } while (currPruning.get(choice)); currPruning.set(choice); } prevPruning = (BitSet) currPruning.clone(); prevScore = bts.pruneFast(currPruning); bts.unPrune(); maxScorePruning.put(prevPruning, prevScore); double mean = 1.0; //needed when pruning 1 taxon (can't have a mean of 0 in PoissonDistribution()) if (currPrunedSpeciesCount > 1) { mean = 0.5 * (currPrunedSpeciesCount - 1); } pd = new PoissonDistribution(mean); currTemp = initTemp; } currTemp *= coolingRate; }
From source file:com.github.rinde.rinsim.scenario.generator.PoissonProcessTest.java
/** * Checks whether the observations conform to a Poisson process with the * specified intensity. Uses a chi square test with the specified confidence. * The null hypothesis is that the observations are the result of a poisson * process./*from w w w . j a va 2s. co m*/ * @param observations * @param intensity * @param confidence * @return <code>true</code> if the observations */ static boolean isPoissonProcess(Frequency observations, double intensity, double length, double confidence) { final PoissonDistribution pd = new PoissonDistribution(length * intensity); final Iterator<?> it = observations.valuesIterator(); final long[] observed = new long[observations.getUniqueCount()]; final double[] expected = new double[observations.getUniqueCount()]; int index = 0; while (it.hasNext()) { final Long l = (Long) it.next(); observed[index] = observations.getCount(l); expected[index] = pd.probability(l.intValue()) * observations.getSumFreq(); if (expected[index] == 0) { return false; } index++; } final double chi = TestUtils.chiSquareTest(expected, observed); return !(chi < confidence); }
From source file:hivemall.anomaly.ChangeFinder2DTest.java
@Test public void testPoissonDist() throws HiveException { final int examples = 10000; final int dims = 3; final PoissonDistribution[] poisson = new PoissonDistribution[] { new PoissonDistribution(10.d), new PoissonDistribution(5.d), new PoissonDistribution(20.d) }; final Random rand = new Random(42); final Double[] x = new Double[dims]; final List<Double> xList = Arrays.asList(x); Parameters params = new Parameters(); params.set(LossFunction.logloss);// w w w . j a v a 2 s . c om params.r1 = 0.01d; params.k = 6; params.T1 = 10; params.T2 = 5; PrimitiveObjectInspector oi = PrimitiveObjectInspectorFactory.javaDoubleObjectInspector; ListObjectInspector listOI = ObjectInspectorFactory.getStandardListObjectInspector(oi); final ChangeFinder2D cf = new ChangeFinder2D(params, listOI); final double[] outScores = new double[2]; println("# time x0 x1 x2 outlier change"); for (int i = 0; i < examples; i++) { double r = rand.nextDouble(); x[0] = r * poisson[0].sample(); x[1] = r * poisson[1].sample(); x[2] = r * poisson[2].sample(); cf.update(xList, outScores); printf("%d %f %f %f %f %f%n", i, x[0], x[1], x[2], outScores[0], outScores[1]); } }
From source file:ml.shifu.shifu.core.dtrain.lr.LogisticRegressionWorker.java
@Override public void init(WorkerContext<LogisticRegressionParams, LogisticRegressionParams> context) { loadConfigFiles(context.getProps()); int[] inputOutputIndex = DTrainUtils.getInputOutputCandidateCounts(modelConfig.getNormalizeType(), this.columnConfigList); this.inputNum = inputOutputIndex[0] == 0 ? inputOutputIndex[2] : inputOutputIndex[0]; this.outputNum = inputOutputIndex[1]; this.candidateNum = inputOutputIndex[2]; this.isSpecificValidation = (modelConfig.getValidationDataSetRawPath() != null && !"".equals(modelConfig.getValidationDataSetRawPath())); this.isStratifiedSampling = this.modelConfig.getTrain().getStratifiedSample(); this.trainerId = Integer.valueOf(context.getProps().getProperty(CommonConstants.SHIFU_TRAINER_ID, "0")); Integer kCrossValidation = this.modelConfig.getTrain().getNumKFold(); if (kCrossValidation != null && kCrossValidation > 0) { isKFoldCV = true;//from w w w . jav a 2 s . co m } if (this.inputNum == 0) { throw new IllegalStateException( "No any variables are selected, please try variable select step firstly."); } this.rng = new PoissonDistribution(1.0d); Double upSampleWeight = modelConfig.getTrain().getUpSampleWeight(); if (Double.compare(upSampleWeight, 1d) != 0) { // set mean to upSampleWeight -1 and get sample + 1 to make sure no zero sample value LOG.info("Enable up sampling with weight {}.", upSampleWeight); this.upSampleRng = new PoissonDistribution(upSampleWeight - 1); } double memoryFraction = Double.valueOf(context.getProps().getProperty("guagua.data.memoryFraction", "0.6")); LOG.info("Max heap memory: {}, fraction: {}", Runtime.getRuntime().maxMemory(), memoryFraction); double crossValidationRate = this.modelConfig.getValidSetRate(); String tmpFolder = context.getProps().getProperty("guagua.data.tmpfolder", "tmp"); if (StringUtils.isNotBlank(modelConfig.getValidationDataSetRawPath())) { // fixed 0.6 and 0.4 of max memory for trainingData and validationData this.trainingData = new BytableMemoryDiskList<Data>( (long) 
(Runtime.getRuntime().maxMemory() * memoryFraction * 0.6), tmpFolder + File.separator + "train-" + System.currentTimeMillis(), Data.class.getName()); this.validationData = new BytableMemoryDiskList<Data>( (long) (Runtime.getRuntime().maxMemory() * memoryFraction * 0.4), tmpFolder + File.separator + "test-" + System.currentTimeMillis(), Data.class.getName()); } else { this.trainingData = new BytableMemoryDiskList<Data>( (long) (Runtime.getRuntime().maxMemory() * memoryFraction * (1 - crossValidationRate)), tmpFolder + File.separator + "train-" + System.currentTimeMillis(), Data.class.getName()); this.validationData = new BytableMemoryDiskList<Data>( (long) (Runtime.getRuntime().maxMemory() * memoryFraction * crossValidationRate), tmpFolder + File.separator + "test-" + System.currentTimeMillis(), Data.class.getName()); } // create Splitter String delimiter = context.getProps().getProperty(Constants.SHIFU_OUTPUT_DATA_DELIMITER); this.splitter = MapReduceUtils.generateShifuOutputSplitter(delimiter); // cannot find a good place to close these two data set, using Shutdown hook Runtime.getRuntime().addShutdownHook(new Thread(new Runnable() { @Override public void run() { LogisticRegressionWorker.this.validationData.close(); LogisticRegressionWorker.this.trainingData.close(); } })); }
From source file:ml.shifu.shifu.core.dtrain.nn.AbstractNNWorker.java
@Override public void init(WorkerContext<NNParams, NNParams> context) { // load props firstly this.props = context.getProps(); loadConfigFiles(context.getProps()); this.trainerId = Integer.valueOf(context.getProps().getProperty(CommonConstants.SHIFU_TRAINER_ID, "0")); GridSearch gs = new GridSearch(modelConfig.getTrain().getParams(), modelConfig.getTrain().getGridConfigFileContent()); this.validParams = this.modelConfig.getTrain().getParams(); if (gs.hasHyperParam()) { this.validParams = gs.getParams(trainerId); LOG.info("Start grid search master with params: {}", validParams); }/*from w w w . ja va 2s . co m*/ Integer kCrossValidation = this.modelConfig.getTrain().getNumKFold(); if (kCrossValidation != null && kCrossValidation > 0) { isKFoldCV = true; LOG.info("Cross validation is enabled by kCrossValidation: {}.", kCrossValidation); } this.poissonSampler = Boolean.TRUE.toString() .equalsIgnoreCase(context.getProps().getProperty(NNConstants.NN_POISON_SAMPLER)); this.rng = new PoissonDistribution(1.0d); Double upSampleWeight = modelConfig.getTrain().getUpSampleWeight(); if (Double.compare(upSampleWeight, 1d) != 0 && (modelConfig.isRegression() || (modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll()))) { // set mean to upSampleWeight -1 and get sample + 1to make sure no zero sample value LOG.info("Enable up sampling with weight {}.", upSampleWeight); this.upSampleRng = new PoissonDistribution(upSampleWeight - 1); } Integer epochsPerIterationInteger = this.modelConfig.getTrain().getEpochsPerIteration(); this.epochsPerIteration = epochsPerIterationInteger == null ? 1 : epochsPerIterationInteger.intValue(); LOG.info("epochsPerIteration in worker is :{}", epochsPerIteration); // Object elmObject = validParams.get(DTrainUtils.IS_ELM); // isELM = elmObject == null ? 
false : "true".equalsIgnoreCase(elmObject.toString()); // LOG.info("Check isELM: {}", isELM); Object dropoutRateObj = validParams.get(CommonConstants.DROPOUT_RATE); if (dropoutRateObj != null) { this.dropoutRate = Double.valueOf(dropoutRateObj.toString()); } LOG.info("'dropoutRate' in worker is :{}", this.dropoutRate); Object miniBatchO = validParams.get(CommonConstants.MINI_BATCH); if (miniBatchO != null) { int miniBatchs; try { miniBatchs = Integer.parseInt(miniBatchO.toString()); } catch (Exception e) { miniBatchs = 1; } if (miniBatchs < 0) { this.batchs = 1; } else if (miniBatchs > 1000) { this.batchs = 1000; } else { this.batchs = miniBatchs; } LOG.info("'miniBatchs' in worker is : {}, batchs is {} ", miniBatchs, batchs); } int[] inputOutputIndex = DTrainUtils.getInputOutputCandidateCounts(modelConfig.getNormalizeType(), this.columnConfigList); this.inputNodeCount = inputOutputIndex[0] == 0 ? inputOutputIndex[2] : inputOutputIndex[0]; // if is one vs all classification, outputNodeCount is set to 1, if classes=2, outputNodeCount is also 1 int classes = modelConfig.getTags().size(); this.outputNodeCount = (isLinearTarget || modelConfig.isRegression()) ? inputOutputIndex[1] : (modelConfig.getTrain().isOneVsAll() ? inputOutputIndex[1] : (classes == 2 ? 
1 : classes)); this.candidateCount = inputOutputIndex[2]; boolean isAfterVarSelect = inputOutputIndex[0] != 0; LOG.info("isAfterVarSelect {}: Input count {}, output count {}, candidate count {}", isAfterVarSelect, inputNodeCount, outputNodeCount, candidateCount); // cache all feature list for sampling features this.allFeatures = NormalUtils.getAllFeatureList(columnConfigList, isAfterVarSelect); String subsetStr = context.getProps().getProperty(CommonConstants.SHIFU_NN_FEATURE_SUBSET); if (StringUtils.isBlank(subsetStr)) { this.subFeatures = this.allFeatures; } else { String[] splits = subsetStr.split(","); this.subFeatures = new ArrayList<Integer>(splits.length); for (String split : splits) { int featureIndex = Integer.parseInt(split); this.subFeatures.add(featureIndex); } } this.subFeatureSet = new HashSet<Integer>(this.subFeatures); LOG.info("subFeatures size is {}", subFeatures.size()); this.featureInputsCnt = DTrainUtils.getFeatureInputsCnt(this.modelConfig, this.columnConfigList, this.subFeatureSet); this.wgtInit = "default"; Object wgtInitObj = validParams.get(CommonConstants.WEIGHT_INITIALIZER); if (wgtInitObj != null) { this.wgtInit = wgtInitObj.toString(); } Object lossObj = validParams.get("Loss"); this.lossStr = lossObj != null ? 
lossObj.toString() : "squared"; LOG.info("Loss str is {}", this.lossStr); this.isDry = Boolean.TRUE.toString() .equalsIgnoreCase(context.getProps().getProperty(CommonConstants.SHIFU_DRY_DTRAIN)); this.isSpecificValidation = (modelConfig.getValidationDataSetRawPath() != null && !"".equals(modelConfig.getValidationDataSetRawPath())); this.isStratifiedSampling = this.modelConfig.getTrain().getStratifiedSample(); if (isOnDisk()) { LOG.info("NNWorker is loading data into disk."); try { initDiskDataSet(); } catch (IOException e) { throw new RuntimeException(e); } // cannot find a good place to close these two data set, using Shutdown hook Runtime.getRuntime().addShutdownHook(new Thread(new Runnable() { @Override public void run() { ((BufferedFloatMLDataSet) (AbstractNNWorker.this.trainingData)).close(); ((BufferedFloatMLDataSet) (AbstractNNWorker.this.validationData)).close(); } })); } else { LOG.info("NNWorker is loading data into memory."); double memoryFraction = Double .valueOf(context.getProps().getProperty("guagua.data.memoryFraction", "0.6")); long memoryStoreSize = (long) (Runtime.getRuntime().maxMemory() * memoryFraction); LOG.info("Max heap memory: {}, fraction: {}", Runtime.getRuntime().maxMemory(), memoryFraction); double crossValidationRate = this.modelConfig.getValidSetRate(); try { if (StringUtils.isNotBlank(modelConfig.getValidationDataSetRawPath())) { // fixed 0.6 and 0.4 of max memory for trainingData and validationData this.trainingData = new MemoryDiskFloatMLDataSet((long) (memoryStoreSize * 0.6), DTrainUtils.getTrainingFile().toString(), this.featureInputsCnt, this.outputNodeCount); this.validationData = new MemoryDiskFloatMLDataSet((long) (memoryStoreSize * 0.4), DTrainUtils.getTestingFile().toString(), this.featureInputsCnt, this.outputNodeCount); } else { this.trainingData = new MemoryDiskFloatMLDataSet( (long) (memoryStoreSize * (1 - crossValidationRate)), DTrainUtils.getTrainingFile().toString(), this.featureInputsCnt, this.outputNodeCount); 
this.validationData = new MemoryDiskFloatMLDataSet( (long) (memoryStoreSize * crossValidationRate), DTrainUtils.getTestingFile().toString(), this.featureInputsCnt, this.outputNodeCount); } // cannot find a good place to close these two data set, using Shutdown hook Runtime.getRuntime().addShutdownHook(new Thread(new Runnable() { @Override public void run() { ((MemoryDiskFloatMLDataSet) (AbstractNNWorker.this.trainingData)).close(); ((MemoryDiskFloatMLDataSet) (AbstractNNWorker.this.validationData)).close(); } })); } catch (IOException e) { throw new GuaguaRuntimeException(e); } } // create Splitter String delimiter = context.getProps().getProperty(Constants.SHIFU_OUTPUT_DATA_DELIMITER); this.splitter = MapReduceUtils.generateShifuOutputSplitter(delimiter); }
From source file:ml.shifu.shifu.core.dtrain.dt.DTWorker.java
/**
 * Initialises the decision-tree worker (RF or GBT, per {@code modelConfig.getAlgorithm()}).
 *
 * <p>Steps, in order:
 * <ol>
 * <li>Load the model and column configs from the paths in the worker props.</li>
 * <li>Build, for each categorical column, a map from category value to bin index.</li>
 * <li>Set up the splitter, k-fold flag, optional up-sampling, thread pool and its shutdown callback.</li>
 * <li>Resolve the (possibly grid-searched) params, allocate the memory-limited data lists, and
 *     configure impurity, loss and GBT-specific settings.</li>
 * <li>Arrange tree recovery for fail-over or continuous training.</li>
 * </ol>
 */
@Override
public void init(WorkerContext<DTMasterParams, DTWorkerParams> context) {
    Properties props = context.getProps();
    try {
        // The model-set source type defaults to HDFS when not given in the props.
        SourceType sourceType = SourceType
                .valueOf(props.getProperty(CommonConstants.MODELSET_SOURCE_TYPE, SourceType.HDFS.toString()));
        this.modelConfig = CommonUtils.loadModelConfig(props.getProperty(CommonConstants.SHIFU_MODEL_CONFIG),
                sourceType);
        this.columnConfigList = CommonUtils
                .loadColumnConfigList(props.getProperty(CommonConstants.SHIFU_COLUMN_CONFIG), sourceType);
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
    // Map: column number -> (category value -> bin index). Values flattened from a bin-category group share
    // that group's index. Only categorical columns with bin categories are included.
    this.columnCategoryIndexMapping = new HashMap<Integer, Map<String, Integer>>();
    for (ColumnConfig config : this.columnConfigList) {
        if (config.isCategorical()) {
            if (config.getBinCategory() != null) {
                Map<String, Integer> tmpMap = new HashMap<String, Integer>();
                for (int i = 0; i < config.getBinCategory().size(); i++) {
                    List<String> catVals = CommonUtils.flattenCatValGrp(config.getBinCategory().get(i));
                    for (String cval : catVals) {
                        tmpMap.put(cval, i);
                    }
                }
                this.columnCategoryIndexMapping.put(config.getColumnNum(), tmpMap);
            }
        }
    }
    this.hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList);
    // Create the splitter used to parse the normalized output records.
    String delimiter = context.getProps().getProperty(Constants.SHIFU_OUTPUT_DATA_DELIMITER);
    this.splitter = MapReduceUtils.generateShifuOutputSplitter(delimiter);
    Integer kCrossValidation = this.modelConfig.getTrain().getNumKFold();
    if (kCrossValidation != null && kCrossValidation > 0) {
        isKFoldCV = true;
        LOG.info("Cross validation is enabled by kCrossValidation: {}.", kCrossValidation);
    }
    Double upSampleWeight = modelConfig.getTrain().getUpSampleWeight();
    // NOTE(review): Double.compare unboxes upSampleWeight, so a null weight throws NullPointerException here.
    // A weight below 1 makes the Poisson mean non-positive, which the PoissonDistribution constructor rejects
    // (it declares NotStrictlyPositiveException). Consider validating explicitly, as other workers may do.
    if (Double.compare(upSampleWeight, 1d) != 0 && (modelConfig.isRegression()
            || (modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll()))) {
        // The Poisson mean is upSampleWeight - 1; the sample is meant to be used as sample + 1 so no record
        // ends up with a zero weight.
        LOG.info("Enable up sampling with weight {}.", upSampleWeight);
        this.upSampleRng = new PoissonDistribution(upSampleWeight - 1);
    }
    this.isContinuousEnabled = Boolean.TRUE.toString()
            .equalsIgnoreCase(context.getProps().getProperty(CommonConstants.CONTINUOUS_TRAINING));
    this.workerThreadCount = modelConfig.getTrain().getWorkerThreadCount();
    // NOTE(review): presumably workerThreadCount must be positive for newFixedThreadPool; confirm it is validated.
    this.threadPool = Executors.newFixedThreadPool(this.workerThreadCount);
    // On worker completion, stop the pool and wait up to 2 seconds for its tasks to finish.
    context.addCompletionCallBack(new WorkerCompletionCallBack<DTMasterParams, DTWorkerParams>() {
        @Override
        public void callback(WorkerContext<DTMasterParams, DTWorkerParams> context) {
            DTWorker.this.threadPool.shutdownNow();
            try {
                DTWorker.this.threadPool.awaitTermination(2, TimeUnit.SECONDS);
            } catch (InterruptedException e) {
                // Restore the interrupt flag for the caller.
                Thread.currentThread().interrupt();
            }
        }
    });
    this.trainerId = Integer.valueOf(context.getProps().getProperty(CommonConstants.SHIFU_TRAINER_ID, "0"));
    this.isOneVsAll = modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll();
    // With hyper-parameters configured, this trainer uses the grid-search combination selected by trainerId.
    GridSearch gs = new GridSearch(modelConfig.getTrain().getParams(),
            modelConfig.getTrain().getGridConfigFileContent());
    Map<String, Object> validParams = this.modelConfig.getTrain().getParams();
    if (gs.hasHyperParam()) {
        validParams = gs.getParams(this.trainerId);
        LOG.info("Start grid search worker with params: {}", validParams);
    }
    this.treeNum = Integer.valueOf(validParams.get("TreeNum").toString());
    double memoryFraction = Double.valueOf(context.getProps().getProperty("guagua.data.memoryFraction", "0.6"));
    LOG.info("Max heap memory: {}, fraction: {}", Runtime.getRuntime().maxMemory(), memoryFraction);
    double validationRate = this.modelConfig.getValidSetRate();
    if (StringUtils.isNotBlank(modelConfig.getValidationDataSetRawPath())) {
        // A separate validation set gets a fixed 60/40 split of the heap budget.
        this.trainingData = new MemoryLimitedList<Data>(
                (long) (Runtime.getRuntime().maxMemory() * memoryFraction * 0.6), new ArrayList<Data>());
        this.validationData = new MemoryLimitedList<Data>(
                (long) (Runtime.getRuntime().maxMemory() * memoryFraction * 0.4), new ArrayList<Data>());
    } else {
        if (Double.compare(validationRate, 0d) != 0) {
            // Split the budget by the validation rate.
            this.trainingData = new MemoryLimitedList<Data>(
                    (long) (Runtime.getRuntime().maxMemory() * memoryFraction * (1 - validationRate)),
                    new ArrayList<Data>());
            this.validationData = new MemoryLimitedList<Data>(
                    (long) (Runtime.getRuntime().maxMemory() * memoryFraction * validationRate),
                    new ArrayList<Data>());
        } else {
            // No validation data: the training list gets the whole budget and validationData is not assigned here.
            this.trainingData = new MemoryLimitedList<Data>(
                    (long) (Runtime.getRuntime().maxMemory() * memoryFraction), new ArrayList<Data>());
        }
    }
    int[] inputOutputIndex = DTrainUtils.getNumericAndCategoricalInputAndOutputCounts(this.columnConfigList);
    // Index 0 is the numerical input count and index 1 the categorical one; their sum is all inputs.
    this.inputCount = inputOutputIndex[0] + inputOutputIndex[1];
    // Index 3 is 1 when variable selection has already been done.
    this.isAfterVarSelect = (inputOutputIndex[3] == 1);
    this.isManualValidation = (modelConfig.getValidationDataSetRawPath() != null
            && !"".equals(modelConfig.getValidationDataSetRawPath()));
    // Regression is treated as 2 classes for impurity construction.
    int numClasses = this.modelConfig.isClassification() ? this.modelConfig.getTags().size() : 2;
    String imStr = validParams.get("Impurity").toString();
    int minInstancesPerNode = Integer.valueOf(validParams.get("MinInstancesPerNode").toString());
    double minInfoGain = Double.valueOf(validParams.get("MinInfoGain").toString());
    // Any unrecognised impurity name falls back to Variance.
    if (imStr.equalsIgnoreCase("entropy")) {
        impurity = new Entropy(numClasses, minInstancesPerNode, minInfoGain);
    } else if (imStr.equalsIgnoreCase("gini")) {
        impurity = new Gini(numClasses, minInstancesPerNode, minInfoGain);
    } else if (imStr.equalsIgnoreCase("friedmanmse")) {
        impurity = new FriedmanMSE(minInstancesPerNode, minInfoGain);
    } else {
        impurity = new Variance(minInstancesPerNode, minInfoGain);
    }
    this.isRF = ALGORITHM.RF.toString().equalsIgnoreCase(modelConfig.getAlgorithm());
    this.isGBDT = ALGORITHM.GBT.toString().equalsIgnoreCase(modelConfig.getAlgorithm());
    String lossStr = validParams.get("Loss").toString();
    // Built-in losses by name; anything else is treated as a Loss class name to instantiate reflectively.
    if (lossStr.equalsIgnoreCase("log")) {
        this.loss = new LogLoss();
    } else if (lossStr.equalsIgnoreCase("absolute")) {
        this.loss = new AbsoluteLoss();
    } else if (lossStr.equalsIgnoreCase("halfgradsquared")) {
        this.loss = new HalfGradSquaredLoss();
    } else if (lossStr.equalsIgnoreCase("squared")) {
        this.loss = new SquaredLoss();
    } else {
        try {
            this.loss = (Loss) ClassUtils.newInstance(Class.forName(lossStr));
        } catch (ClassNotFoundException e) {
            // NOTE(review): only ClassNotFoundException falls back; other failures from newInstance propagate.
            LOG.warn("Class not found for {}, using default SquaredLoss", lossStr);
            this.loss = new SquaredLoss();
        }
    }
    if (this.isGBDT) {
        this.learningRate = Double.valueOf(validParams.get(CommonConstants.LEARNING_RATE).toString());
        Object swrObj = validParams.get("GBTSampleWithReplacement");
        if (swrObj != null) {
            this.gbdtSampleWithReplacement = Boolean.TRUE.toString().equalsIgnoreCase(swrObj.toString());
        }
        Object dropoutObj = validParams.get(CommonConstants.DROPOUT_RATE);
        if (dropoutObj != null) {
            this.dropOutRate = Double.valueOf(dropoutObj.toString());
        }
    }
    this.isStratifiedSampling = this.modelConfig.getTrain().getStratifiedSample();
    this.checkpointOutput = new Path(context.getProps()
            .getProperty(CommonConstants.SHIFU_DT_MASTER_CHECKPOINT_FOLDER, "tmp/cp_" + context.getAppId()));
    LOG.info(
            "Worker init params:isAfterVarSel={}, treeNum={}, impurity={}, loss={}, learningRate={}, gbdtSampleWithReplacement={}, isRF={}, isGBDT={}, isStratifiedSampling={}, isKFoldCV={}, kCrossValidation={}, dropOutRate={}",
            isAfterVarSelect, treeNum, impurity.getClass().getName(), loss.getClass().getName(),
            this.learningRate, this.gbdtSampleWithReplacement, this.isRF, this.isGBDT, this.isStratifiedSampling,
            this.isKFoldCV, kCrossValidation, this.dropOutRate);
    // Fail-over (not the first iteration): restore previously built trees.
    if (!context.isFirstIteration()) {
        if (this.isGBDT) {
            // GBT: only set a flag here. Recovery happens later in doCompute, after the data-loading phase, so
            // the latest trees can be loaded there.
            isNeedRecoverGBDTPredict = true;
        } else {
            // RF: trees are recovered from the last master result.
            recoverTrees = context.getLastMasterResult().getTrees();
        }
    }
    // Continuous GBT training: start from an existing model at the output path if one can be loaded.
    if (context.isFirstIteration() && this.isContinuousEnabled && this.isGBDT) {
        Path modelPath = new Path(context.getProps().getProperty(CommonConstants.GUAGUA_OUTPUT));
        TreeModel existingModel = null;
        try {
            existingModel = (TreeModel) ModelSpecLoaderUtils.loadModel(modelConfig, modelPath,
                    ShifuFileUtils.getFileSystemBySourceType(this.modelConfig.getDataSet().getSource()));
        } catch (IOException e) {
            LOG.error("Error in get existing model, will ignore and start from scratch", e);
        }
        if (existingModel == null) {
            LOG.warn("No model is found even set to continuous model training.");
            // This is the last statement of the method, so the return only makes the early exit explicit.
            return;
        } else {
            recoverTrees = existingModel.getTrees();
            LOG.info("Loading existing {} trees", recoverTrees.size());
        }
    }
}
From source file:ml.shifu.shifu.core.dtrain.wdl.WDLWorker.java
@SuppressWarnings({ "unchecked", "unused" }) @Override// w w w . j a va 2 s .c om public void init(WorkerContext<WDLParams, WDLParams> context) { Properties props = context.getProps(); try { SourceType sourceType = SourceType .valueOf(props.getProperty(CommonConstants.MODELSET_SOURCE_TYPE, SourceType.HDFS.toString())); this.modelConfig = CommonUtils.loadModelConfig(props.getProperty(CommonConstants.SHIFU_MODEL_CONFIG), sourceType); this.columnConfigList = CommonUtils .loadColumnConfigList(props.getProperty(CommonConstants.SHIFU_COLUMN_CONFIG), sourceType); } catch (IOException e) { throw new RuntimeException(e); } this.initCateIndexMap(); this.hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList); // create Splitter String delimiter = context.getProps().getProperty(Constants.SHIFU_OUTPUT_DATA_DELIMITER); this.splitter = MapReduceUtils.generateShifuOutputSplitter(delimiter); Integer kCrossValidation = this.modelConfig.getTrain().getNumKFold(); if (kCrossValidation != null && kCrossValidation > 0) { isKFoldCV = true; LOG.info("Cross validation is enabled by kCrossValidation: {}.", kCrossValidation); } this.poissonSampler = Boolean.TRUE.toString() .equalsIgnoreCase(context.getProps().getProperty(NNConstants.NN_POISON_SAMPLER)); this.rng = new PoissonDistribution(1d); Double upSampleWeight = modelConfig.getTrain().getUpSampleWeight(); if (upSampleWeight != 1d && (modelConfig.isRegression() || (modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll()))) { // set mean to upSampleWeight -1 and get sample + 1to make sure no zero sample value LOG.info("Enable up sampling with weight {}.", upSampleWeight); this.upSampleRng = new PoissonDistribution(upSampleWeight - 1); } this.trainerId = Integer.valueOf(context.getProps().getProperty(CommonConstants.SHIFU_TRAINER_ID, "0")); double memoryFraction = Double.valueOf(context.getProps().getProperty("guagua.data.memoryFraction", "0.6")); LOG.info("Max heap memory: {}, fraction: {}", 
Runtime.getRuntime().maxMemory(), memoryFraction); double validationRate = this.modelConfig.getValidSetRate(); if (StringUtils.isNotBlank(modelConfig.getValidationDataSetRawPath())) { // fixed 0.6 and 0.4 of max memory for trainingData and validationData this.trainingData = new MemoryLimitedList<Data>( (long) (Runtime.getRuntime().maxMemory() * memoryFraction * 0.6), new ArrayList<Data>()); this.validationData = new MemoryLimitedList<Data>( (long) (Runtime.getRuntime().maxMemory() * memoryFraction * 0.4), new ArrayList<Data>()); } else { if (validationRate != 0d) { this.trainingData = new MemoryLimitedList<Data>( (long) (Runtime.getRuntime().maxMemory() * memoryFraction * (1 - validationRate)), new ArrayList<Data>()); this.validationData = new MemoryLimitedList<Data>( (long) (Runtime.getRuntime().maxMemory() * memoryFraction * validationRate), new ArrayList<Data>()); } else { this.trainingData = new MemoryLimitedList<Data>( (long) (Runtime.getRuntime().maxMemory() * memoryFraction), new ArrayList<Data>()); } } int[] inputOutputIndex = DTrainUtils.getNumericAndCategoricalInputAndOutputCounts(this.columnConfigList); // numerical + categorical = # of all input this.numInputs = inputOutputIndex[0]; this.inputCount = inputOutputIndex[0] + inputOutputIndex[1]; // regression outputNodeCount is 1, binaryClassfication, it is 1, OneVsAll it is 1, Native classification it is // 1, with index of 0,1,2,3 denotes different classes this.isAfterVarSelect = (inputOutputIndex[3] == 1); this.isManualValidation = (modelConfig.getValidationDataSetRawPath() != null && !"".equals(modelConfig.getValidationDataSetRawPath())); this.isStratifiedSampling = this.modelConfig.getTrain().getStratifiedSample(); this.validParams = this.modelConfig.getTrain().getParams(); // Build wide and deep graph List<Integer> embedColumnIds = (List<Integer>) this.validParams.get(CommonConstants.NUM_EMBED_COLUMN_IDS); Integer embedOutputs = (Integer) this.validParams.get(CommonConstants.NUM_EMBED_OUTPUTS); 
List<Integer> embedOutputList = new ArrayList<Integer>(); for (Integer cId : embedColumnIds) { embedOutputList.add(embedOutputs == null ? CommonConstants.DEFAULT_EMBEDING_OUTPUT : embedOutputs); } List<Integer> numericalIds = DTrainUtils.getNumericalIds(this.columnConfigList, isAfterVarSelect); List<Integer> wideColumnIds = DTrainUtils.getCategoricalIds(columnConfigList, isAfterVarSelect); Map<Integer, Integer> idBinCateSizeMap = DTrainUtils.getIdBinCategorySizeMap(columnConfigList); int numLayers = (Integer) this.validParams.get(CommonConstants.NUM_HIDDEN_LAYERS); List<String> actFunc = (List<String>) this.validParams.get(CommonConstants.ACTIVATION_FUNC); List<Integer> hiddenNodes = (List<Integer>) this.validParams.get(CommonConstants.NUM_HIDDEN_NODES); Float l2reg = ((Double) this.validParams.get(CommonConstants.WDL_L2_REG)).floatValue(); this.wnd = new WideAndDeep(idBinCateSizeMap, numInputs, numericalIds, embedColumnIds, embedOutputList, wideColumnIds, hiddenNodes, actFunc, l2reg); }
From source file:ml.shifu.shifu.core.dtrain.lr.LogisticRegressionWorker.java
protected float sampleWeights(float label) { float sampleWeights = 1f; // sample negative or kFoldCV, sample rate is 1d double sampleRate = (modelConfig.getTrain().getSampleNegOnly() || this.isKFoldCV) ? 1d : modelConfig.getTrain().getBaggingSampleRate(); int classValue = (int) (label + 0.01f); if (!modelConfig.isBaggingWithReplacement()) { Random random = null;/*from w w w . j ava 2 s.com*/ if (this.isStratifiedSampling) { random = baggingRandomMap.get(classValue); if (random == null) { random = DTrainUtils.generateRandomBySampleSeed(modelConfig.getTrain().getBaggingSampleSeed(), CommonConstants.NOT_CONFIGURED_BAGGING_SEED); baggingRandomMap.put(classValue, random); } } else { random = baggingRandomMap.get(0); if (random == null) { random = DTrainUtils.generateRandomBySampleSeed(modelConfig.getTrain().getBaggingSampleSeed(), CommonConstants.NOT_CONFIGURED_BAGGING_SEED); baggingRandomMap.put(0, random); } } if (random.nextDouble() <= sampleRate) { sampleWeights = 1f; } else { sampleWeights = 0f; } } else { // bagging with replacement sampling in training data set, take PoissonDistribution for sampling with // replacement if (this.isStratifiedSampling) { PoissonDistribution rng = this.baggingRngMap.get(classValue); if (rng == null) { rng = new PoissonDistribution(sampleRate); this.baggingRngMap.put(classValue, rng); } sampleWeights = rng.sample(); } else { PoissonDistribution rng = this.baggingRngMap.get(0); if (rng == null) { rng = new PoissonDistribution(sampleRate); this.baggingRngMap.put(0, rng); } sampleWeights = rng.sample(); } } return sampleWeights; }