List of usage examples for org.apache.hadoop.mapred.Reporter#getCounter(String group, String name)
public abstract Counter getCounter(String group, String name);
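Before the project-specific examples below, here is a minimal sketch of the common pattern: the Reporter passed to map()/reduce() hands out a Counters.Counter for a (group, name) pair, which the task then increments. The mapper class and the counter group/names used here ("Example", "LINES_SEEN", "EMPTY_LINES") are illustrative placeholders, not taken from any of the projects listed below.

// Minimal sketch of Reporter.getCounter usage with the old mapred API.
// Class name and counter group/names are hypothetical, for illustration only.
import java.io.IOException;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.MapReduceBase;
import org.apache.hadoop.mapred.Mapper;
import org.apache.hadoop.mapred.OutputCollector;
import org.apache.hadoop.mapred.Reporter;

public class LineCountingMapper extends MapReduceBase
        implements Mapper<LongWritable, Text, Text, LongWritable> {

    @Override
    public void map(LongWritable key, Text value,
            OutputCollector<Text, LongWritable> output, Reporter reporter) throws IOException {
        // Look up (or create) a counter by group and name, then increment it.
        reporter.getCounter("Example", "LINES_SEEN").increment(1);
        if (value.toString().trim().isEmpty()) {
            // Counters are cheap; use them to surface data-quality issues.
            reporter.getCounter("Example", "EMPTY_LINES").increment(1);
            return;
        }
        output.collect(value, new LongWritable(1));
    }
}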
From source file:co.nubetech.hiho.mapred.MySQLLoadDataMapper.java
License:Apache License
@Override
public void map(Text key, FSDataInputStream val,
        OutputCollector<NullWritable, NullWritable> collector, Reporter reporter) throws IOException {
    conn = getConnection();
    com.mysql.jdbc.Statement stmt = null;
    String query;
    String[] columnNames = null;
    if (hasHeaderLine) {
        BufferedReader headerReader = new BufferedReader(new InputStreamReader(val));
        String header = headerReader.readLine();
        if (header == null)
            return;
        columnNames = header.split(",");
        val.seek(header.getBytes(utf8).length + newline.length);
    }
    try {
        stmt = (com.mysql.jdbc.Statement) conn.createStatement(ResultSet.TYPE_SCROLL_SENSITIVE,
                ResultSet.CONCUR_UPDATABLE);
        String tablename = (keyIsTableName ? keyToTablename(key) : "");
        if (disableKeys && !tablename.equals("")) {
            reporter.setStatus("Disabling keys on " + tablename);
            stmt.execute("ALTER TABLE " + tablename + " DISABLE KEYS");
        }
        stmt.setLocalInfileInputStream(val);
        query = "load data local infile 'abc.txt' into table " + tablename + " ";
        query += querySuffix;
        if (hasHeaderLine)
            query += " (" + StringUtils.join(columnNames, ",") + ")";
        reporter.setStatus("Inserting into " + tablename);
        logger.debug("stmt: " + query);
        int rows = stmt.executeUpdate(query);
        logger.debug(rows + " rows updated");
        if (disableKeys && !tablename.equals("")) {
            reporter.setStatus("Re-enabling keys on " + tablename);
            stmt.execute("ALTER TABLE " + tablename + " ENABLE KEYS");
        }
        if (!tablename.equals(""))
            reporter.getCounter("MySQLLoadCounters", "ROWS_INSERTED_TABLE_" + tablename).increment(rows);
        reporter.getCounter("MySQLLoadCounters", "ROWS_INSERTED_TOTAL").increment(rows);
    } catch (Exception e) {
        e.printStackTrace();
        stmt = null;
        throw new IOException(e);
    } finally {
        try {
            if (stmt != null) {
                stmt.close();
            }
        } catch (SQLException s) {
            s.printStackTrace();
        }
    }
}
From source file:com.cloudera.recordservice.mapred.RecordServiceInputFormatBase.java
License:Apache License
/**
 * Populates RecordService counters in ctx from counters.
 */
public static void setCounters(Reporter ctx, TaskStatus.Stats counters) {
    if (ctx == null)
        return;
    ctx.getCounter(COUNTERS_GROUP_NAME, "Records Read").setValue(counters.numRecordsRead);
    ctx.getCounter(COUNTERS_GROUP_NAME, "Records Returned").setValue(counters.numRecordsReturned);
    ctx.getCounter(COUNTERS_GROUP_NAME, "Record Serialization Time(ms)").setValue(counters.serializeTimeMs);
    ctx.getCounter(COUNTERS_GROUP_NAME, "Client Time(ms)").setValue(counters.clientTimeMs);
    if (counters.hdfsCountersSet) {
        ctx.getCounter(COUNTERS_GROUP_NAME, "Bytes Read").setValue(counters.bytesRead);
        ctx.getCounter(COUNTERS_GROUP_NAME, "Decompression Time(ms)").setValue(counters.decompressTimeMs);
        ctx.getCounter(COUNTERS_GROUP_NAME, "Bytes Read Local").setValue(counters.bytesReadLocal);
        ctx.getCounter(COUNTERS_GROUP_NAME, "HDFS Throughput(MB/s)")
                .setValue((long) (counters.hdfsThroughput / (1024 * 1024)));
    }
}
From source file:com.digitalpebble.behemoth.languageidentification.LanguageIdProcessor.java
License:Apache License
public BehemothDocument[] process(BehemothDocument inputDoc, Reporter reporter) {
    // check that it has some text
    if (inputDoc.getText() == null) {
        LOG.info("No text for " + inputDoc.getUrl() + " skipping");
        reporter.getCounter("LANGUAGE ID", "MISSING TEXT").increment(1);
        return new BehemothDocument[] { inputDoc };
    }
    String lang = null;
    // skip docs with empty text
    if (inputDoc.getText().trim().isEmpty()) {
        LOG.info("Empty text for " + inputDoc.getUrl() + " skipping");
        reporter.getCounter("LANGUAGE ID", "EMPTY TEXT").increment(1);
        return new BehemothDocument[] { inputDoc };
    }
    try {
        Detector detector = DetectorFactory.create();
        detector.append(inputDoc.getText());
        lang = detector.detect();
        inputDoc.getMetadata(true).put(languageMDKey, new Text(lang));
    } catch (LangDetectException e) {
        LOG.error("Exception on doc " + inputDoc.getUrl(), e);
        lang = null;
    }
    if (reporter != null && lang != null)
        reporter.getCounter("LANGUAGE DETECTED", lang).increment(1);
    return new BehemothDocument[] { inputDoc };
}
From source file:com.digitalpebble.behemoth.tika.TikaProcessor.java
License:Apache License
/**
 * Process a BehemothDocument with Tika
 *
 * @return an array of documents or null if an exception is encountered
 */
public BehemothDocument[] process(BehemothDocument inputDoc, Reporter reporter) {
    // check that it has some text or content
    if (inputDoc.getContent() == null && inputDoc.getText() == null) {
        LOG.info("No content or text for " + inputDoc.getUrl() + " skipping");
        if (reporter != null)
            reporter.getCounter("TIKA", "NO CONTENT OR TEXT").increment(1);
        return new BehemothDocument[] { inputDoc };
    }
    // determine the content type if missing
    if (inputDoc.getContentType() == null || inputDoc.getContentType().equals("") == true) {
        String mt = null;
        // using the original content
        if (mimeType == null | forceMTDetection) {
            if (inputDoc.getContent() != null) {
                Metadata meta = new Metadata();
                meta.set(Metadata.RESOURCE_NAME_KEY, inputDoc.getUrl());
                MimeType mimetype = null;
                try {
                    MediaType mediaType = detector.detect(new ByteArrayInputStream(inputDoc.getContent()), meta);
                    mimetype = mimetypes.forName(mediaType.getType() + "/" + mediaType.getSubtype());
                } catch (IOException e) {
                    LOG.error("Exception", e);
                } catch (MimeTypeException e) {
                    LOG.error("Exception", e);
                }
                mt = mimetype.getName();
            } else if (mimeType == null && inputDoc.getText() != null) {
                // force it to text
                mt = "text/plain";
            }
        } else {
            mt = mimeType; // allow outside user to specify a mime type if
                           // they know all the content, saves time and
                           // reduces error
        }
        if (mt != null) {
            inputDoc.setContentType(mt);
        }
    }
    // determine which parser to use
    Parser parser = TikaConfig.getDefaultConfig().getParser();
    // skip the processing if the input document already has some text
    if (inputDoc.getText() != null) {
        if (reporter != null)
            reporter.getCounter("TIKA", "TEXT ALREADY AVAILABLE").increment(1);
        return new BehemothDocument[] { inputDoc };
    }
    // filter based on content length
    // optional
    int length = inputDoc.getContent().length;
    if (contentLengthThresholdFilter != -1 && length > contentLengthThresholdFilter) {
        if (reporter != null)
            reporter.getCounter("TIKA", "FILTERED-CONTENT-LENGTH").increment(1);
        return new BehemothDocument[] { inputDoc };
    }
    // otherwise parse the document and retrieve the text, metadata and
    // markup annotations
    InputStream is = new ByteArrayInputStream(inputDoc.getContent());
    Metadata metadata = new Metadata();
    // put the mimetype in the metadata so that Tika can
    // decide which parser to use
    metadata.set(Metadata.CONTENT_TYPE, inputDoc.getContentType());
    String ct = inputDoc.getContentType();
    try {
        if (reporter != null && okCounters)
            reporter.getCounter("MIME-TYPE", ct).increment(1);
    } catch (Exception counterEx) {
        LOG.error("Could not add counter MIME-TYPE:" + ct, counterEx);
        okCounters = false;
    }
    // TODO check config whether want the markup or just the text and
    // metadata?
    BehemothHandler handler = new TikaMarkupHandler();
    boolean doMarkup = config.getBoolean("tika.convert.markup", true);
    if (!doMarkup) {
        handler = new TikaTextHandler();
    }
    ParseContext context = new ParseContext();
    // TODO generalise the approach so that can set any class via context
    String customMapper = config.get("tika.context.HtmlMapper.class");
    if (customMapper != null) {
        try {
            Class<HtmlMapper> customMapperClass = (Class<HtmlMapper>) Class.forName(customMapper);
            // specify a custom HTML mapper via the Context
            context.set(HtmlMapper.class, customMapperClass.newInstance());
        } catch (Exception e) {
            LOG.error("Can't use class " + customMapper + " for HtmlMapper, using default");
        }
    }
    try {
        parser.parse(is, handler, metadata, context);
        processMetadata(inputDoc, metadata);
        processText(inputDoc, handler.getText());
        processMarkupAnnotations(inputDoc, handler.getAnnotations());
        if (reporter != null)
            reporter.getCounter("TIKA", "ANNOTATIONS ADDED").increment(handler.getAnnotations().size());
    } catch (Exception e) {
        LOG.error(inputDoc.getUrl().toString(), e);
        if (reporter != null)
            reporter.getCounter("TIKA", "PARSING_ERROR").increment(1);
        return new BehemothDocument[] { inputDoc };
    } finally {
        try {
            is.close();
        } catch (IOException e) {
        }
    }
    // TODO if the content type is an archive maybe process and return
    // all the subdocuments
    if (reporter != null)
        reporter.getCounter("TIKA", "DOC-PARSED").increment(1);
    return new BehemothDocument[] { inputDoc };
}
From source file:com.ebay.erl.mobius.core.mapred.DefaultMobiusReducer.java
License:Apache License
private void output(Tuple aTuple, OutputCollector<NullWritable, WritableComparable<?>> output, Reporter reporter)
        throws IOException {
    aTuple.setToStringOrdering(this.outputColumnNames);
    if (this._persistantCriteria != null) {
        if (this._persistantCriteria.accept(aTuple, this.conf)) {
            output.collect(NullWritable.get(), aTuple);
            reporter.getCounter("Join/Grouping Records", "EMITTED").increment(1);
        } else {
            reporter.getCounter("Join/Grouping Records", "FILTERED").increment(1);
        }
    } else {
        output.collect(NullWritable.get(), aTuple);
        reporter.getCounter("Join/Grouping Records", "EMITTED").increment(1);
    }
}
From source file:com.TCG.Nutch_DNS.HostDbReducer.java
License:Apache License
public void reduce(Text key, Iterator<CrawlDatum> values, OutputCollector<Text, CrawlDatum> output,
        Reporter reporter) throws IOException {

    CrawlDatum fetch = new CrawlDatum();
    CrawlDatum old = new CrawlDatum();
    boolean fetchSet = false;
    boolean oldSet = false;
    byte[] signature = null;
    boolean multiple = false; // avoid deep copy when only single value exists
    linked.clear();
    org.apache.hadoop.io.MapWritable metaFromParse = null;

    while (values.hasNext()) {
        CrawlDatum datum = values.next();
        if (!multiple && values.hasNext())
            multiple = true;
        if (CrawlDatum.hasDbStatus(datum)) {
            if (!oldSet) {
                if (multiple) {
                    old.set(datum);
                } else {
                    // no need for a deep copy - this is the only value
                    old = datum;
                }
                oldSet = true;
            } else {
                // always take the latest version
                if (old.getFetchTime() < datum.getFetchTime())
                    old.set(datum);
            }
            continue;
        }
        if (CrawlDatum.hasFetchStatus(datum)) {
            if (!fetchSet) {
                if (multiple) {
                    fetch.set(datum);
                } else {
                    fetch = datum;
                }
                fetchSet = true;
            } else {
                // always take the latest version
                if (fetch.getFetchTime() < datum.getFetchTime())
                    fetch.set(datum);
            }
            continue;
        }
        switch (datum.getStatus()) { // collect other info
        case CrawlDatum.STATUS_LINKED:
            CrawlDatum link;
            if (multiple) {
                link = new CrawlDatum();
                link.set(datum);
            } else {
                link = datum;
            }
            linked.insert(link);
            break;
        case CrawlDatum.STATUS_SIGNATURE:
            signature = datum.getSignature();
            break;
        case CrawlDatum.STATUS_PARSE_META:
            metaFromParse = datum.getMetaData();
            break;
        default:
            LOG.warn("Unknown status, key: " + key + ", datum: " + datum);
        }
    }

    // copy the content of the queue into a List
    // in reversed order
    int numLinks = linked.size();
    List<CrawlDatum> linkList = new ArrayList<CrawlDatum>(numLinks);
    for (int i = numLinks - 1; i >= 0; i--) {
        linkList.add(linked.pop());
    }

    // if it doesn't already exist, skip it
    if (!oldSet && !additionsAllowed)
        return;

    // if there is no fetched datum, perhaps there is a link
    if (!fetchSet && linkList.size() > 0) {
        fetch = linkList.get(0);
        fetchSet = true;
    }

    // still no new data - record only unchanged old data, if exists, and return
    if (!fetchSet) {
        if (oldSet) { // at this point at least "old" should be present
            output.collect(key, old);
            reporter.getCounter("CrawlDB status", CrawlDatum.getStatusName(old.getStatus())).increment(1);
        } else {
            LOG.warn("Missing fetch and old value, signature=" + signature);
        }
        return;
    }

    if (signature == null)
        signature = fetch.getSignature();
    long prevModifiedTime = oldSet ? old.getModifiedTime() : 0L;
    long prevFetchTime = oldSet ? old.getFetchTime() : 0L;

    // initialize with the latest version, be it fetch or link
    result.set(fetch);
    if (oldSet) {
        // copy metadata from old, if exists
        if (old.getMetaData().size() > 0) {
            result.putAllMetaData(old);
            // overlay with new, if any
            if (fetch.getMetaData().size() > 0)
                result.putAllMetaData(fetch);
        }
        // set the most recent valid value of modifiedTime
        if (old.getModifiedTime() > 0 && fetch.getModifiedTime() == 0) {
            result.setModifiedTime(old.getModifiedTime());
        }
    }

    switch (fetch.getStatus()) { // determine new status

    case CrawlDatum.STATUS_LINKED: // it was link
        if (oldSet) { // if old exists
            result.set(old); // use it
        } else {
            result = schedule.initializeSchedule(key, result);
            result.setStatus(CrawlDatum.STATUS_DB_UNFETCHED);
            try {
                scfilters.initialScore(key, result);
            } catch (ScoringFilterException e) {
                if (LOG.isWarnEnabled()) {
                    LOG.warn("Cannot filter init score for url " + key + ", using default: " + e.getMessage());
                }
                result.setScore(0.0f);
            }
        }
        break;

    case CrawlDatum.STATUS_FETCH_SUCCESS: // successful fetch
    case CrawlDatum.STATUS_FETCH_REDIR_TEMP: // successful fetch, redirected
    case CrawlDatum.STATUS_FETCH_REDIR_PERM:
    case CrawlDatum.STATUS_FETCH_NOTMODIFIED: // successful fetch, notmodified
        // determine the modification status
        int modified = FetchSchedule.STATUS_UNKNOWN;
        if (fetch.getStatus() == CrawlDatum.STATUS_FETCH_NOTMODIFIED) {
            modified = FetchSchedule.STATUS_NOTMODIFIED;
        } else if (fetch.getStatus() == CrawlDatum.STATUS_FETCH_SUCCESS) {
            // only successful fetches (but not redirects, NUTCH-1422)
            // are detected as "not modified" by signature comparison
            if (oldSet && old.getSignature() != null && signature != null) {
                if (SignatureComparator._compare(old.getSignature(), signature) != 0) {
                    modified = FetchSchedule.STATUS_MODIFIED;
                } else {
                    modified = FetchSchedule.STATUS_NOTMODIFIED;
                }
            }
        }
        // set the schedule
        result = schedule.setFetchSchedule(key, result, prevFetchTime, prevModifiedTime, fetch.getFetchTime(),
                fetch.getModifiedTime(), modified);
        // set the result status and signature
        if (modified == FetchSchedule.STATUS_NOTMODIFIED) {
            result.setStatus(CrawlDatum.STATUS_DB_NOTMODIFIED);
            // NUTCH-1341 The page is not modified according to its signature, let's
            // reset lastModified as well
            result.setModifiedTime(prevModifiedTime);
            if (oldSet)
                result.setSignature(old.getSignature());
        } else {
            switch (fetch.getStatus()) {
            case CrawlDatum.STATUS_FETCH_SUCCESS:
                result.setStatus(CrawlDatum.STATUS_DB_FETCHED);
                break;
            case CrawlDatum.STATUS_FETCH_REDIR_PERM:
                result.setStatus(CrawlDatum.STATUS_DB_REDIR_PERM);
                break;
            case CrawlDatum.STATUS_FETCH_REDIR_TEMP:
                result.setStatus(CrawlDatum.STATUS_DB_REDIR_TEMP);
                break;
            default:
                LOG.warn("Unexpected status: " + fetch.getStatus() + " resetting to old status.");
                if (oldSet)
                    result.setStatus(old.getStatus());
                else
                    result.setStatus(CrawlDatum.STATUS_DB_UNFETCHED);
            }
            result.setSignature(signature);
        }
        // https://issues.apache.org/jira/browse/NUTCH-1656
        if (metaFromParse != null) {
            for (Entry<Writable, Writable> e : metaFromParse.entrySet()) {
                result.getMetaData().put(e.getKey(), e.getValue());
            }
        }
        // if fetchInterval is larger than the system-wide maximum, trigger
        // an unconditional recrawl. This prevents the page to be stuck at
        // NOTMODIFIED state, when the old fetched copy was already removed with
        // old segments.
        if (maxInterval < result.getFetchInterval())
            result = schedule.forceRefetch(key, result, false);
        break;

    case CrawlDatum.STATUS_SIGNATURE:
        if (LOG.isWarnEnabled()) {
            LOG.warn("Lone CrawlDatum.STATUS_SIGNATURE: " + key);
        }
        return;

    case CrawlDatum.STATUS_FETCH_RETRY: // temporary failure
        if (oldSet) {
            result.setSignature(old.getSignature()); // use old signature
        }
        result = schedule.setPageRetrySchedule(key, result, prevFetchTime, prevModifiedTime, fetch.getFetchTime());
        if (result.getRetriesSinceFetch() < retryMax) {
            result.setStatus(CrawlDatum.STATUS_DB_UNFETCHED);
        } else {
            result.setStatus(CrawlDatum.STATUS_DB_GONE);
            result = schedule.setPageGoneSchedule(key, result, prevFetchTime, prevModifiedTime,
                    fetch.getFetchTime());
        }
        break;

    case CrawlDatum.STATUS_FETCH_GONE: // permanent failure
        if (oldSet)
            result.setSignature(old.getSignature()); // use old signature
        result.setStatus(CrawlDatum.STATUS_DB_GONE);
        result = schedule.setPageGoneSchedule(key, result, prevFetchTime, prevModifiedTime, fetch.getFetchTime());
        break;

    default:
        throw new RuntimeException("Unknown status: " + fetch.getStatus() + " " + key);
    }

    try {
        scfilters.updateDbScore(key, oldSet ? old : null, result, linkList);
    } catch (Exception e) {
        if (LOG.isWarnEnabled()) {
            LOG.warn("Couldn't update score, key=" + key + ": " + e);
        }
    }
    // remove generation time, if any
    result.getMetaData().remove(Nutch.WRITABLE_GENERATE_TIME_KEY);
    output.collect(key, result);
    reporter.getCounter("CrawlDB status", CrawlDatum.getStatusName(result.getStatus())).increment(1);
}
From source file:gobblin.metrics.hadoop.HadoopCounterReporterTest.java
License:Apache License
@BeforeClass
public void setUp() throws Exception {
    String contextName = CONTEXT_NAME + "_" + UUID.randomUUID().toString();

    Reporter mockedReporter = Mockito.mock(Reporter.class);

    this.recordsProcessedCount = Mockito.mock(Counters.Counter.class);
    Mockito.when(mockedReporter.getCounter(contextName,
            MetricRegistry.name(RECORDS_PROCESSED, Measurements.COUNT.getName())))
            .thenReturn(this.recordsProcessedCount);

    this.recordProcessRateCount = Mockito.mock(Counters.Counter.class);
    Mockito.when(mockedReporter.getCounter(contextName,
            MetricRegistry.name(RECORD_PROCESS_RATE, Measurements.COUNT.getName())))
            .thenReturn(this.recordProcessRateCount);

    this.recordSizeDistributionCount = Mockito.mock(Counters.Counter.class);
    Mockito.when(mockedReporter.getCounter(contextName,
            MetricRegistry.name(RECORD_SIZE_DISTRIBUTION, Measurements.COUNT.getName())))
            .thenReturn(this.recordSizeDistributionCount);

    this.totalDurationCount = Mockito.mock(Counters.Counter.class);
    Mockito.when(mockedReporter.getCounter(contextName,
            MetricRegistry.name(TOTAL_DURATION, Measurements.COUNT.getName())))
            .thenReturn(this.totalDurationCount);

    this.queueSize = Mockito.mock(Counters.Counter.class);
    Mockito.when(mockedReporter.getCounter(contextName, QUEUE_SIZE)).thenReturn(this.queueSize);

    this.hadoopCounterReporter = HadoopCounterReporter.builder(mockedReporter)
            .convertRatesTo(TimeUnit.SECONDS)
            .convertDurationsTo(TimeUnit.SECONDS)
            .filter(MetricFilter.ALL)
            .build(MetricContext.builder(contextName).buildStrict());
}
From source file:hivemall.fm.FactorizationMachineUDTF.java
License:Apache License
protected void runTrainingIteration(int iterations) throws HiveException {
    final ByteBuffer inputBuf = this._inputBuf;
    final NioStatefullSegment fileIO = this._fileIO;
    assert (inputBuf != null);
    assert (fileIO != null);
    final long numTrainingExamples = _t;
    final boolean adaregr = _va_rand != null;

    final Reporter reporter = getReporter();
    final Counter iterCounter = (reporter == null) ? null
            : reporter.getCounter("hivemall.fm.FactorizationMachines$Counter", "iteration");

    try {
        if (fileIO.getPosition() == 0L) { // run iterations w/o temporary file
            if (inputBuf.position() == 0) {
                return; // no training example
            }
            inputBuf.flip();
            int iter = 2;
            for (; iter <= iterations; iter++) {
                reportProgress(reporter);
                setCounterValue(iterCounter, iter);
                while (inputBuf.remaining() > 0) {
                    int bytes = inputBuf.getInt();
                    assert (bytes > 0) : bytes;
                    int xLength = inputBuf.getInt();
                    final Feature[] x = new Feature[xLength];
                    for (int j = 0; j < xLength; j++) {
                        x[j] = instantiateFeature(inputBuf);
                    }
                    double y = inputBuf.getDouble();
                    // invoke train
                    ++_t;
                    train(x, y, adaregr);
                }
                if (_cvState.isConverged(iter, numTrainingExamples)) {
                    break;
                }
                inputBuf.rewind();
            }
            LOG.info("Performed " + Math.min(iter, iterations) + " iterations of "
                    + NumberUtils.formatNumber(numTrainingExamples)
                    + " training examples on memory (thus " + NumberUtils.formatNumber(_t)
                    + " training updates in total) ");
        } else { // read training examples in the temporary file and invoke train for each example
            // write training examples in buffer to a temporary file
            if (inputBuf.remaining() > 0) {
                writeBuffer(inputBuf, fileIO);
            }
            try {
                fileIO.flush();
            } catch (IOException e) {
                throw new HiveException("Failed to flush a file: " + fileIO.getFile().getAbsolutePath(), e);
            }
            if (LOG.isInfoEnabled()) {
                File tmpFile = fileIO.getFile();
                LOG.info("Wrote " + numTrainingExamples + " records to a temporary file for iterative training: "
                        + tmpFile.getAbsolutePath() + " (" + FileUtils.prettyFileSize(tmpFile) + ")");
            }
            // run iterations
            int iter = 2;
            for (; iter <= iterations; iter++) {
                setCounterValue(iterCounter, iter);
                inputBuf.clear();
                fileIO.resetPosition();
                while (true) {
                    reportProgress(reporter);
                    // TODO prefetch
                    // writes training examples to a buffer in the temporary file
                    final int bytesRead;
                    try {
                        bytesRead = fileIO.read(inputBuf);
                    } catch (IOException e) {
                        throw new HiveException("Failed to read a file: " + fileIO.getFile().getAbsolutePath(), e);
                    }
                    if (bytesRead == 0) { // reached file EOF
                        break;
                    }
                    assert (bytesRead > 0) : bytesRead;
                    // reads training examples from a buffer
                    inputBuf.flip();
                    int remain = inputBuf.remaining();
                    if (remain < INT_BYTES) {
                        throw new HiveException("Illegal file format was detected");
                    }
                    while (remain >= INT_BYTES) {
                        int pos = inputBuf.position();
                        int recordBytes = inputBuf.getInt();
                        remain -= INT_BYTES;
                        if (remain < recordBytes) {
                            inputBuf.position(pos);
                            break;
                        }
                        final int xLength = inputBuf.getInt();
                        final Feature[] x = new Feature[xLength];
                        for (int j = 0; j < xLength; j++) {
                            x[j] = instantiateFeature(inputBuf);
                        }
                        double y = inputBuf.getDouble();
                        // invoke training
                        ++_t;
                        train(x, y, adaregr);
                        remain -= recordBytes;
                    }
                    inputBuf.compact();
                }
                if (_cvState.isConverged(iter, numTrainingExamples)) {
                    break;
                }
            }
            LOG.info("Performed " + Math.min(iter, iterations) + " iterations of "
                    + NumberUtils.formatNumber(numTrainingExamples)
                    + " training examples on a secondary storage (thus " + NumberUtils.formatNumber(_t)
                    + " training updates in total)");
        }
    } finally {
        // delete the temporary file and release resources
        try {
            fileIO.close(true);
        } catch (IOException e) {
            throw new HiveException("Failed to close a file: " + fileIO.getFile().getAbsolutePath(), e);
        }
        this._inputBuf = null;
        this._fileIO = null;
    }
}
From source file:hivemall.GeneralLearnerBaseUDTF.java
License:Apache License
protected final void runIterativeTraining(@Nonnegative final int iterations) throws HiveException {
    final ByteBuffer buf = this.inputBuf;
    final NioStatefulSegment dst = this.fileIO;
    assert (buf != null);
    assert (dst != null);
    final long numTrainingExamples = count;

    final Reporter reporter = getReporter();
    final Counters.Counter iterCounter = (reporter == null) ? null
            : reporter.getCounter("hivemall.GeneralLearnerBase$Counter", "iteration");

    try {
        if (dst.getPosition() == 0L) { // run iterations w/o temporary file
            if (buf.position() == 0) {
                return; // no training example
            }
            buf.flip();
            for (int iter = 2; iter <= iterations; iter++) {
                cvState.next();
                reportProgress(reporter);
                setCounterValue(iterCounter, iter);

                while (buf.remaining() > 0) {
                    int recordBytes = buf.getInt();
                    assert (recordBytes > 0) : recordBytes;
                    int featureVectorLength = buf.getInt();
                    final FeatureValue[] featureVector = new FeatureValue[featureVectorLength];
                    for (int j = 0; j < featureVectorLength; j++) {
                        featureVector[j] = readFeatureValue(buf, featureType);
                    }
                    float target = buf.getFloat();
                    train(featureVector, target);
                }
                buf.rewind();

                if (is_mini_batch) { // Update model with accumulated delta
                    batchUpdate();
                }

                if (cvState.isConverged(numTrainingExamples)) {
                    break;
                }
            }
            logger.info("Performed " + cvState.getCurrentIteration() + " iterations of "
                    + NumberUtils.formatNumber(numTrainingExamples)
                    + " training examples on memory (thus "
                    + NumberUtils.formatNumber(numTrainingExamples * cvState.getCurrentIteration())
                    + " training updates in total) ");
        } else { // read training examples in the temporary file and invoke train for each example
            // write training examples in buffer to a temporary file
            if (buf.remaining() > 0) {
                writeBuffer(buf, dst);
            }
            try {
                dst.flush();
            } catch (IOException e) {
                throw new HiveException("Failed to flush a file: " + dst.getFile().getAbsolutePath(), e);
            }
            if (logger.isInfoEnabled()) {
                File tmpFile = dst.getFile();
                logger.info("Wrote " + numTrainingExamples + " records to a temporary file for iterative training: "
                        + tmpFile.getAbsolutePath() + " (" + FileUtils.prettyFileSize(tmpFile) + ")");
            }

            // run iterations
            for (int iter = 2; iter <= iterations; iter++) {
                cvState.next();
                setCounterValue(iterCounter, iter);

                buf.clear();
                dst.resetPosition();
                while (true) {
                    reportProgress(reporter);
                    // TODO prefetch
                    // writes training examples to a buffer in the temporary file
                    final int bytesRead;
                    try {
                        bytesRead = dst.read(buf);
                    } catch (IOException e) {
                        throw new HiveException("Failed to read a file: " + dst.getFile().getAbsolutePath(), e);
                    }
                    if (bytesRead == 0) { // reached file EOF
                        break;
                    }
                    assert (bytesRead > 0) : bytesRead;

                    // reads training examples from a buffer
                    buf.flip();
                    int remain = buf.remaining();
                    if (remain < SizeOf.INT) {
                        throw new HiveException("Illegal file format was detected");
                    }
                    while (remain >= SizeOf.INT) {
                        int pos = buf.position();
                        int recordBytes = buf.getInt();
                        remain -= SizeOf.INT;
                        if (remain < recordBytes) {
                            buf.position(pos);
                            break;
                        }
                        int featureVectorLength = buf.getInt();
                        final FeatureValue[] featureVector = new FeatureValue[featureVectorLength];
                        for (int j = 0; j < featureVectorLength; j++) {
                            featureVector[j] = readFeatureValue(buf, featureType);
                        }
                        float target = buf.getFloat();
                        train(featureVector, target);
                        remain -= recordBytes;
                    }
                    buf.compact();
                }

                if (is_mini_batch) { // Update model with accumulated delta
                    batchUpdate();
                }

                if (cvState.isConverged(numTrainingExamples)) {
                    break;
                }
            }
            logger.info("Performed " + cvState.getCurrentIteration() + " iterations of "
                    + NumberUtils.formatNumber(numTrainingExamples)
                    + " training examples on a secondary storage (thus "
                    + NumberUtils.formatNumber(numTrainingExamples * cvState.getCurrentIteration())
                    + " training updates in total)");
        }
    } catch (Throwable e) {
        throw new HiveException("Exception caused in the iterative training", e);
    } finally {
        // delete the temporary file and release resources
        try {
            dst.close(true);
        } catch (IOException e) {
            throw new HiveException("Failed to close a file: " + dst.getFile().getAbsolutePath(), e);
        }
        this.inputBuf = null;
        this.fileIO = null;
    }
}
From source file:hivemall.mf.BPRMatrixFactorizationUDTF.java
License:Apache License
private final void runIterativeTraining(@Nonnegative final int iterations) throws HiveException {
    final ByteBuffer inputBuf = this.inputBuf;
    final NioFixedSegment fileIO = this.fileIO;
    assert (inputBuf != null);
    assert (fileIO != null);
    final long numTrainingExamples = count;

    final Reporter reporter = getReporter();
    final Counter iterCounter = (reporter == null) ? null
            : reporter.getCounter("hivemall.mf.BPRMatrixFactorization$Counter", "iteration");

    try {
        if (lastWritePos == 0) { // run iterations w/o temporary file
            if (inputBuf.position() == 0) {
                return; // no training example
            }
            inputBuf.flip();
            int iter = 2;
            for (; iter <= iterations; iter++) {
                reportProgress(reporter);
                setCounterValue(iterCounter, iter);

                while (inputBuf.remaining() > 0) {
                    int u = inputBuf.getInt();
                    int i = inputBuf.getInt();
                    int j = inputBuf.getInt();
                    // invoke train
                    count++;
                    train(u, i, j);
                }
                cvState.multiplyLoss(0.5d);
                cvState.logState(iter, eta());
                if (cvState.isConverged(iter, numTrainingExamples)) {
                    break;
                }
                if (cvState.isLossIncreased()) {
                    etaEstimator.update(1.1f);
                } else {
                    etaEstimator.update(0.5f);
                }
                inputBuf.rewind();
            }
            LOG.info("Performed " + Math.min(iter, iterations) + " iterations of "
                    + NumberUtils.formatNumber(numTrainingExamples)
                    + " training examples on memory (thus " + NumberUtils.formatNumber(count)
                    + " training updates in total) ");
        } else { // read training examples in the temporary file and invoke train for each example
            // write training examples in buffer to a temporary file
            if (inputBuf.position() > 0) {
                writeBuffer(inputBuf, fileIO, lastWritePos);
            } else if (lastWritePos == 0) {
                return; // no training example
            }
            try {
                fileIO.flush();
            } catch (IOException e) {
                throw new HiveException("Failed to flush a file: " + fileIO.getFile().getAbsolutePath(), e);
            }
            if (LOG.isInfoEnabled()) {
                File tmpFile = fileIO.getFile();
                LOG.info("Wrote " + numTrainingExamples + " records to a temporary file for iterative training: "
                        + tmpFile.getAbsolutePath() + " (" + FileUtils.prettyFileSize(tmpFile) + ")");
            }
            // run iterations
            int iter = 2;
            for (; iter <= iterations; iter++) {
                setCounterValue(iterCounter, iter);

                inputBuf.clear();
                long seekPos = 0L;
                while (true) {
                    reportProgress(reporter);
                    // TODO prefetch
                    // writes training examples to a buffer in the temporary file
                    final int bytesRead;
                    try {
                        bytesRead = fileIO.read(seekPos, inputBuf);
                    } catch (IOException e) {
                        throw new HiveException("Failed to read a file: " + fileIO.getFile().getAbsolutePath(), e);
                    }
                    if (bytesRead == 0) { // reached file EOF
                        break;
                    }
                    assert (bytesRead > 0) : bytesRead;
                    seekPos += bytesRead;

                    // reads training examples from a buffer
                    inputBuf.flip();
                    int remain = inputBuf.remaining();
                    assert (remain > 0) : remain;
                    for (; remain >= RECORD_BYTES; remain -= RECORD_BYTES) {
                        int u = inputBuf.getInt();
                        int i = inputBuf.getInt();
                        int j = inputBuf.getInt();
                        // invoke train
                        count++;
                        train(u, i, j);
                    }
                    inputBuf.compact();
                }
                cvState.multiplyLoss(0.5d);
                cvState.logState(iter, eta());
                if (cvState.isConverged(iter, numTrainingExamples)) {
                    break;
                }
                if (cvState.isLossIncreased()) {
                    etaEstimator.update(1.1f);
                } else {
                    etaEstimator.update(0.5f);
                }
            }
            LOG.info("Performed " + Math.min(iter, iterations) + " iterations of "
                    + NumberUtils.formatNumber(numTrainingExamples)
                    + " training examples using a secondary storage (thus " + NumberUtils.formatNumber(count)
                    + " training updates in total)");
        }
    } finally {
        // delete the temporary file and release resources
        try {
            fileIO.close(true);
        } catch (IOException e) {
            throw new HiveException("Failed to close a file: " + fileIO.getFile().getAbsolutePath(), e);
        }
        this.inputBuf = null;
        this.fileIO = null;
    }
}