Usage examples for org.apache.hadoop.io.IOUtils.closeStream
public static void closeStream(java.io.Closeable stream)
From source file: ml.shifu.guagua.mapreduce.example.sum.SumOutput.java
License: Apache License
/**
 * Gets the output file setting and writes the final sum value to that HDFS file.
 *
 * <p>Failures are logged rather than thrown, so a failed write does not abort the application.
 *
 * @param context
 *            master context whose properties hold "guagua.sum.output" and whose master result holds the sum
 */
@Override
public void postApplication(
        final MasterContext<GuaguaWritableAdapter<LongWritable>, GuaguaWritableAdapter<LongWritable>> context) {
    LOG.info("SumOutput starts to write final sum value to file.");
    Path out = new Path(context.getProps().getProperty("guagua.sum.output"));
    LOG.info("Writing results to {}", out.toString());
    PrintWriter pw = null;
    try {
        FSDataOutputStream fos = FileSystem.get(new Configuration()).create(out);
        pw = new PrintWriter(fos);
        pw.println(context.getMasterResult().getWritable().get());
        // PrintWriter never throws on write/flush/close; it only records a failure flag. Closing here (rather
        // than only in finally) lets checkError() also report a failure of the final HDFS close.
        pw.close();
        if (pw.checkError()) {
            LOG.error("Error in writing output to {}.", out);
        }
    } catch (IOException e) {
        LOG.error("Error in writing output.", e);
    } finally {
        // No-op when the writer was already closed above; releases the stream on any early exit.
        IOUtils.closeStream(pw);
    }
}
From source file: ml.shifu.guagua.yarn.example.sum.SumOutput.java
License: Apache License
/**
 * Reads the output path setting and writes the final sum value to that HDFS file.
 *
 * <p>Failures are logged rather than thrown, so a failed write does not abort the application.
 *
 * @param context
 *            master context whose properties hold "guagua.sum.output" and whose master result holds the sum
 */
@Override
public void postApplication(
        final MasterContext<GuaguaWritableAdapter<LongWritable>, GuaguaWritableAdapter<LongWritable>> context) {
    LOG.info("SumOutput starts to write model to files.");
    Path out = new Path(context.getProps().getProperty("guagua.sum.output"));
    LOG.info("Writing results to {}", out.toString());
    PrintWriter pw = null;
    try {
        FSDataOutputStream fos = FileSystem.get(new Configuration()).create(out);
        pw = new PrintWriter(fos);
        pw.println(context.getMasterResult().getWritable().get());
        // PrintWriter never throws on write/flush/close; it only records a failure flag. Closing here (rather
        // than only in finally) lets checkError() also report a failure of the final HDFS close.
        pw.close();
        if (pw.checkError()) {
            LOG.error("Error in writing output to {}.", out);
        }
    } catch (IOException e) {
        LOG.error("Error in writing output.", e);
    } finally {
        // No-op when the writer was already closed above; releases the stream on any early exit.
        IOUtils.closeStream(pw);
    }
}
From source file: ml.shifu.shifu.core.dtrain.dt.BinaryDTSerializer.java
License: Apache License
public static void save(ModelConfig modelConfig, List<ColumnConfig> columnConfigList, List<List<TreeNode>> baggingTrees, String loss, int inputCount, OutputStream output) throws IOException { DataOutputStream fos = null;/* ww w . j a v a2 s . c o m*/ try { fos = new DataOutputStream(new GZIPOutputStream(output)); // version fos.writeInt(CommonConstants.TREE_FORMAT_VERSION); fos.writeUTF(modelConfig.getAlgorithm()); fos.writeUTF(loss); fos.writeBoolean(modelConfig.isClassification()); fos.writeBoolean(modelConfig.getTrain().isOneVsAll()); fos.writeInt(inputCount); Map<Integer, String> columnIndexNameMapping = new HashMap<Integer, String>(); Map<Integer, List<String>> columnIndexCategoricalListMapping = new HashMap<Integer, List<String>>(); Map<Integer, Double> numericalMeanMapping = new HashMap<Integer, Double>(); for (ColumnConfig columnConfig : columnConfigList) { if (columnConfig.isFinalSelect()) { columnIndexNameMapping.put(columnConfig.getColumnNum(), columnConfig.getColumnName()); } if (columnConfig.isCategorical() && CollectionUtils.isNotEmpty(columnConfig.getBinCategory())) { columnIndexCategoricalListMapping.put(columnConfig.getColumnNum(), columnConfig.getBinCategory()); } if (columnConfig.isNumerical() && columnConfig.getMean() != null) { numericalMeanMapping.put(columnConfig.getColumnNum(), columnConfig.getMean()); } } if (columnIndexNameMapping.size() == 0) { boolean hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList); for (ColumnConfig columnConfig : columnConfigList) { if (CommonUtils.isGoodCandidate(columnConfig, hasCandidates)) { columnIndexNameMapping.put(columnConfig.getColumnNum(), columnConfig.getColumnName()); } } } // serialize numericalMeanMapping fos.writeInt(numericalMeanMapping.size()); for (Entry<Integer, Double> entry : numericalMeanMapping.entrySet()) { fos.writeInt(entry.getKey()); // for some feature, it is null mean value, it is not selected, just set to 0d to avoid NPE fos.writeDouble(entry.getValue() == null ? 
0d : entry.getValue()); } // serialize columnIndexNameMapping fos.writeInt(columnIndexNameMapping.size()); for (Entry<Integer, String> entry : columnIndexNameMapping.entrySet()) { fos.writeInt(entry.getKey()); fos.writeUTF(entry.getValue()); } // serialize columnIndexCategoricalListMapping fos.writeInt(columnIndexCategoricalListMapping.size()); for (Entry<Integer, List<String>> entry : columnIndexCategoricalListMapping.entrySet()) { List<String> categories = entry.getValue(); if (categories != null) { fos.writeInt(entry.getKey()); fos.writeInt(categories.size()); for (String category : categories) { // There is 16k limitation when using writeUTF() function. // if the category value is larger than 10k, write a marker -1 and write bytes instead of // writeUTF; // in read part logic should be changed also to readByte not readUTF according to the marker if (category.length() < Constants.MAX_CATEGORICAL_VAL_LEN) { fos.writeUTF(category); } else { fos.writeShort(UTF_BYTES_MARKER); // marker here byte[] bytes = category.getBytes("UTF-8"); fos.writeInt(bytes.length); for (int i = 0; i < bytes.length; i++) { fos.writeByte(bytes[i]); } } } } } Map<Integer, Integer> columnMapping = getColumnMapping(columnConfigList); fos.writeInt(columnMapping.size()); for (Entry<Integer, Integer> entry : columnMapping.entrySet()) { fos.writeInt(entry.getKey()); fos.writeInt(entry.getValue()); } // after model version 4 (>=4), IndependentTreeModel support bagging, here write a default RF/GBT size 1 fos.writeInt(baggingTrees.size()); for (int i = 0; i < baggingTrees.size(); i++) { List<TreeNode> trees = baggingTrees.get(i); int treeLength = trees.size(); fos.writeInt(treeLength); for (TreeNode treeNode : trees) { treeNode.write(fos); } } } catch (IOException e) { LOG.error("Error in writing output.", e); } finally { IOUtils.closeStream(fos); } }
From source file: ml.shifu.shifu.core.dtrain.dt.DTMaster.java
License: Apache License
/** * Write {@link #trees}, {@link #toDoQueue} and MasterParams to HDFS. *//* ww w . ja v a 2s .c om*/ private void writeStatesToHdfs(Path out, DTMasterParams masterParams, List<TreeNode> trees, boolean isLeafWise, Queue<TreeNode> toDoQueue, Queue<TreeNode> toSplitQueue) { FSDataOutputStream fos = null; try { fos = FileSystem.get(new Configuration()).create(out); // trees int treeLength = trees.size(); fos.writeInt(treeLength); for (TreeNode treeNode : trees) { treeNode.write(fos); } // todo queue fos.writeInt(toDoQueue.size()); for (TreeNode treeNode : toDoQueue) { treeNode.write(fos); } if (isLeafWise && toSplitQueue != null) { fos.writeInt(toSplitQueue.size()); for (TreeNode treeNode : toSplitQueue) { treeNode.write(fos); } } // master result masterParams.write(fos); } catch (Throwable e) { LOG.error("Error in writing output.", e); } finally { IOUtils.closeStream(fos); fos = null; } }
From source file: ml.shifu.shifu.core.dtrain.dt.DTOutput.java
License: Apache License
/**
 * Writes the final trees and, for grid search or k-fold CV, the average validation error, then releases the
 * progress output stream.
 *
 * @param context
 *            master context providing job properties and the final master result
 */
@Override
public void postApplication(MasterContext<DTMasterParams, DTWorkerParams> context) {
    try {
        List<TreeNode> trees = context.getMasterResult().getTrees();
        if (this.isGBDT) {
            trees = context.getMasterResult().getTmpTrees();
        }
        if (LOG.isDebugEnabled()) {
            // The placeholder is required, otherwise SLF4J drops the argument and the trees are never logged.
            LOG.debug("final trees {}", trees);
        }

        Path out = new Path(context.getProps().getProperty(CommonConstants.GUAGUA_OUTPUT));
        writeModelToFileSystem(trees, out);

        if (this.isGsMode || this.isKFoldCV) {
            Path valErrOutput = new Path(context.getProps().getProperty(CommonConstants.GS_VALIDATION_ERROR));
            // NOTE(review): if the validation count can be 0 this division yields NaN/Infinity (or throws for
            // integral types); confirm the types returned by getValidationError()/getValidationCount().
            writeValErrorToFileSystem(
                    context.getMasterResult().getValidationError() / context.getMasterResult().getValidationCount(),
                    valErrOutput);
        }
    } finally {
        // Release the progress stream even if writing the model fails with a runtime exception.
        IOUtils.closeStream(this.progressOutput);
    }
}
From source file: ml.shifu.shifu.core.dtrain.dt.DTOutput.java
License: Apache License
private void writeValErrorToFileSystem(double valError, Path out) { FSDataOutputStream fos = null;//w w w .ja va 2s . c o m try { fos = FileSystem.get(new Configuration()).create(out); LOG.info("Writing valerror to {}", out); fos.write((valError + "").getBytes("UTF-8")); } catch (IOException e) { LOG.error("Error in writing output.", e); } finally { IOUtils.closeStream(fos); } }
From source file: ml.shifu.shifu.core.dtrain.lr.LogisticRegressionOutput.java
License: Apache License
@Override public void postApplication(MasterContext<LogisticRegressionParams, LogisticRegressionParams> context) { IOUtils.closeStream(this.progressOutput); // for dry mode, we don't save models files. if (this.isDry) { return;/*w w w.j a va 2s. c o m*/ } if (optimizedWeights == null) { optimizedWeights = context.getMasterResult().getParameters(); } Path out = new Path(context.getProps().getProperty(CommonConstants.GUAGUA_OUTPUT)); writeModelWeightsToFileSystem(optimizedWeights, out); if (this.isKFoldCV || this.isGsMode) { Path valErrOutput = new Path(context.getProps().getProperty(CommonConstants.GS_VALIDATION_ERROR)); writeValErrorToFileSystem(context.getMasterResult().getTestError(), valErrOutput); } IOUtils.closeStream(this.progressOutput); }
From source file: ml.shifu.shifu.core.dtrain.lr.LogisticRegressionOutput.java
License: Apache License
private void writeModelWeightsToFileSystem(double[] weights, Path out) { if (weights == null || weights.length <= 0) { return;// w w w .j a v a2 s.c o m } FSDataOutputStream fos = null; PrintWriter pw = null; try { fos = FileSystem.get(new Configuration()).create(out); LOG.info("Writing results to {}", out); if (out != null) { pw = new PrintWriter(fos); pw.println(Arrays.toString(weights)); } } catch (IOException e) { LOG.error("Error in writing output.", e); } finally { IOUtils.closeStream(pw); } }
From source file: ml.shifu.shifu.core.dtrain.nn.BinaryNNSerializer.java
License: Apache License
public static void save(ModelConfig modelConfig, List<ColumnConfig> columnConfigList, List<BasicML> basicNetworks, FileSystem fs, Path output) throws IOException { DataOutputStream fos = null;// w w w .jav a 2s . co m try { fos = new DataOutputStream(new GZIPOutputStream(fs.create(output))); // version fos.writeInt(CommonConstants.NN_FORMAT_VERSION); // write normStr String normStr = modelConfig.getNormalize().getNormType().toString(); ml.shifu.shifu.core.dtrain.StringUtils.writeString(fos, normStr); // compute columns needed Map<Integer, String> columnIndexNameMapping = getIndexNameMapping(columnConfigList); // write column stats to output List<NNColumnStats> csList = new ArrayList<NNColumnStats>(); for (ColumnConfig cc : columnConfigList) { if (columnIndexNameMapping.containsKey(cc.getColumnNum())) { NNColumnStats cs = new NNColumnStats(); cs.setCutoff(modelConfig.getNormalizeStdDevCutOff()); cs.setColumnType(cc.getColumnType()); cs.setMean(cc.getMean()); cs.setStddev(cc.getStdDev()); cs.setColumnNum(cc.getColumnNum()); cs.setColumnName(cc.getColumnName()); cs.setBinCategories(cc.getBinCategory()); cs.setBinBoundaries(cc.getBinBoundary()); cs.setBinPosRates(cc.getBinPosRate()); cs.setBinCountWoes(cc.getBinCountWoe()); cs.setBinWeightWoes(cc.getBinWeightedWoe()); // TODO cache such computation double[] meanAndStdDev = Normalizer.calculateWoeMeanAndStdDev(cc, false); cs.setWoeMean(meanAndStdDev[0]); cs.setWoeStddev(meanAndStdDev[1]); double[] WgtMeanAndStdDev = Normalizer.calculateWoeMeanAndStdDev(cc, true); cs.setWoeWgtMean(WgtMeanAndStdDev[0]); cs.setWoeWgtStddev(WgtMeanAndStdDev[1]); csList.add(cs); } } fos.writeInt(csList.size()); for (NNColumnStats cs : csList) { cs.write(fos); } // write column index mapping Map<Integer, Integer> columnMapping = getColumnMapping(columnConfigList); fos.writeInt(columnMapping.size()); for (Entry<Integer, Integer> entry : columnMapping.entrySet()) { fos.writeInt(entry.getKey()); fos.writeInt(entry.getValue()); } // persist 
network, set it as list fos.writeInt(basicNetworks.size()); for (BasicML network : basicNetworks) { new PersistBasicFloatNetwork().saveNetwork(fos, (BasicFloatNetwork) network); } } finally { IOUtils.closeStream(fos); } }
From source file: ml.shifu.shifu.core.dtrain.nn.NNOutput.java
License: Apache License
@Override public void postApplication(MasterContext<NNParams, NNParams> context) { IOUtils.closeStream(this.progressOutput); // for dry mode, we don't save models files. if (this.isDry) { return;// w w w. j a va2s. c o m } if (optimizedWeights != null) { Path out = new Path(context.getProps().getProperty(CommonConstants.GUAGUA_OUTPUT)); // TODO do we need to check IOException and retry again to make sure such important model is saved // successfully. writeModelWeightsToFileSystem(optimizedWeights, out, true); } if (this.gridSearch.hasHyperParam() || this.isKFoldCV) { Path valErrOutput = new Path(context.getProps().getProperty(CommonConstants.GS_VALIDATION_ERROR)); writeValErrorToFileSystem(context.getMasterResult().getTestError(), valErrOutput); } }