// Java tutorial
/******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ package org.deeplearning4j.datasets.datavec; import lombok.Getter; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import org.datavec.api.io.WritableConverter; import org.datavec.api.io.converters.SelfWritableConverter; import org.datavec.api.records.Record; import org.datavec.api.records.metadata.RecordMetaData; import org.datavec.api.records.metadata.RecordMetaDataComposableMap; import org.datavec.api.records.reader.RecordReader; import org.datavec.api.records.reader.SequenceRecordReader; import org.datavec.api.records.reader.impl.ConcatenatingRecordReader; import org.datavec.api.records.reader.impl.collection.CollectionRecordReader; import org.datavec.api.writable.Writable; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.DataSetPreProcessor; import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.factory.Nd4j; import java.io.IOException; import java.io.Serializable; import java.util.ArrayList; import java.util.Collections; import java.util.Iterator; import java.util.List; /** * Record reader dataset iterator. 
Takes a DataVec {@link RecordReader} as input, and handles the conversion to ND4J
 * DataSet objects as well as producing minibatches from individual records.<br>
 * <br>
 * Multiple constructors are available, and a {@link Builder} class is also available.<br>
 * <br>
 * Example 1: Image classification, batch size 32, 10 classes<br>
 * <pre>
 * {@code RecordReader rr = new ImageRecordReader(28,28,3); //28x28 RGB images
 * rr.initialize(new FileSplit(new File("/path/to/directory")));
 *
 * DataSetIterator iter = new RecordReaderDataSetIterator.Builder(rr, 32)
 *     //Label index (first arg): Always value 1 when using ImageRecordReader. For CSV etc: use index of the column
 *     // that contains the label (should contain an integer value, 0 to nClasses-1 inclusive). Column indexes start
 *     // at 0. Number of classes (second arg): number of label classes (i.e., 10 for MNIST - 10 digits)
 *     .classification(1, nClasses)
 *     .preProcessor(new ImagePreProcessingScaler())  //For normalization of image values 0-255 to 0-1
 *     .build()
 * }
 * </pre>
 * <br>
 * <br>
 * Example 2: Multi-output regression from CSV, batch size 128<br>
 * <pre>
 * {@code RecordReader rr = new CSVRecordReader(0, ','); //Skip 0 header lines, comma separated
 * rr.initialize(new FileSplit(new File("/path/to/myCsv.txt")));
 *
 * DataSetIterator iter = new RecordReaderDataSetIterator.Builder(rr, 128)
 *     //Specify the columns that the regression labels/targets appear in. Note that all other columns will be
 *     // treated as features.
Column indexes start at 0
 *     .regression(labelColFrom, labelColTo)
 *     .build()
 * }
 * </pre>
 *
 * @author Adam Gibson
 */
@Slf4j
public class RecordReaderDataSetIterator implements DataSetIterator {
    private static final String READER_KEY = "reader";

    @Getter
    protected RecordReader recordReader;
    protected WritableConverter converter;
    protected int batchSize = 10;
    protected int maxNumBatches = -1;
    protected int batchNum = 0;
    protected int labelIndex = -1;
    protected int labelIndexTo = -1;
    protected int numPossibleLabels = -1;
    protected Iterator<List<Writable>> sequenceIter;
    protected DataSet last;
    protected boolean useCurrent = false;
    protected boolean regression = false;
    @Getter
    protected DataSetPreProcessor preProcessor;
    @Getter
    private boolean collectMetaData = false;

    //Lazily-initialized delegate that performs the actual record -> dataset conversion (built on first use)
    private RecordReaderMultiDataSetIterator underlying;
    //True when the label columns sit in the middle of the record, so features come from 2 disjoint column ranges
    //that must be concatenated back together in mdsToDataSet()
    private boolean underlyingIsDisjoint;

    /**
     * Constructor for classification, where:<br>
     * (a) the label index is assumed to be the very last Writable/column, and<br>
     * (b) the number of classes is inferred from RecordReader.getLabels()<br>
     * Note that if RecordReader.getLabels() returns null, no output labels will be produced
     *
     * @param recordReader Record reader to use as the source of data
     * @param batchSize    Minibatch size, for each call of .next()
     */
    public RecordReaderDataSetIterator(RecordReader recordReader, int batchSize) {
        this(recordReader, new SelfWritableConverter(), batchSize, -1, -1,
                        recordReader.getLabels() == null ? -1 : recordReader.getLabels().size(), -1, false);
    }

    /**
     * Main constructor for classification. This will convert the input class index (at position labelIndex, with
     * integer values 0 to numPossibleLabels-1 inclusive) to the appropriate one-hot output/labels representation.
     *
     * @param recordReader      RecordReader: provides the source of the data
     * @param batchSize         Batch size (number of examples) for the output DataSet objects
     * @param labelIndex        Index of the label Writable (usually an IntWritable), as obtained by recordReader.next()
     * @param numPossibleLabels Number of classes (possible labels) for classification
     */
    public RecordReaderDataSetIterator(RecordReader recordReader, int batchSize, int labelIndex,
                    int numPossibleLabels) {
        this(recordReader, new SelfWritableConverter(), batchSize, labelIndex, labelIndex, numPossibleLabels, -1,
                        false);
    }

    /**
     * Constructor for classification, where the maximum number of returned batches is limited to the specified value
     *
     * @param recordReader      the recordreader to use
     * @param batchSize         Batch size (number of examples) for the output DataSet objects
     * @param labelIndex        the index/column of the label (for classification)
     * @param numPossibleLabels the number of possible labels for classification. Not used if regression == true
     * @param maxNumBatches     The maximum number of batches to return between resets. Set to -1 to return all
     *                          available data
     */
    public RecordReaderDataSetIterator(RecordReader recordReader, int batchSize, int labelIndex, int numPossibleLabels,
                    int maxNumBatches) {
        this(recordReader, new SelfWritableConverter(), batchSize, labelIndex, labelIndex, numPossibleLabels,
                        maxNumBatches, false);
    }

    /**
     * Main constructor for multi-label regression (i.e., regression with multiple outputs). Can also be used for
     * single output regression with labelIndexFrom == labelIndexTo
     *
     * @param recordReader   RecordReader to get data from
     * @param batchSize      Minibatch size
     * @param labelIndexFrom Index of the first regression target
     * @param labelIndexTo   Index of the last regression target, inclusive
     * @param regression     Require regression = true. Mainly included to avoid clashing with other constructors
     *                       previously defined
     * @throws IllegalArgumentException if regression == false
     */
    public RecordReaderDataSetIterator(RecordReader recordReader, int batchSize, int labelIndexFrom, int labelIndexTo,
                    boolean regression) {
        this(recordReader, new SelfWritableConverter(), batchSize, labelIndexFrom, labelIndexTo, -1, -1, regression);
        if (!regression) {
            throw new IllegalArgumentException("This constructor is only for creating regression iterators. "
                            + "If you're doing classification you need to use another constructor that "
                            + "(implicitly) specifies numPossibleLabels");
        }
    }

    /**
     * Main constructor
     *
     * @param recordReader      the recordreader to use
     * @param converter         Converter. May be null.
     * @param batchSize         Minibatch size - number of examples returned for each call of .next()
     * @param labelIndexFrom    the index of the label (for classification), or the first index of the labels for
     *                          multi-output regression
     * @param labelIndexTo      only used if regression == true. The last index <i>inclusive</i> of the multi-output
     *                          regression
     * @param numPossibleLabels the number of possible labels for classification. Not used if regression == true
     * @param maxNumBatches     Maximum number of batches to return
     * @param regression        if true: regression. If false: classification (assume labelIndexFrom is the class it
     *                          belongs to)
     */
    public RecordReaderDataSetIterator(RecordReader recordReader, WritableConverter converter, int batchSize,
                    int labelIndexFrom, int labelIndexTo, int numPossibleLabels, int maxNumBatches,
                    boolean regression) {
        this.recordReader = recordReader;
        this.converter = converter;
        this.batchSize = batchSize;
        this.maxNumBatches = maxNumBatches;
        this.labelIndex = labelIndexFrom;
        this.labelIndexTo = labelIndexTo;
        this.numPossibleLabels = numPossibleLabels;
        this.regression = regression;
    }

    /**
     * Constructor used by {@link Builder#build()} - copies all configured values from the builder.
     */
    protected RecordReaderDataSetIterator(Builder b) {
        this.recordReader = b.recordReader;
        this.converter = b.converter;
        this.batchSize = b.batchSize;
        this.maxNumBatches = b.maxNumBatches;
        this.labelIndex = b.labelIndex;
        this.labelIndexTo = b.labelIndexTo;
        this.numPossibleLabels = b.numPossibleLabels;
        this.regression = b.regression;
        this.preProcessor = b.preProcessor;
        //Bug fix: this assignment was missing, so Builder.collectMetaData(true) was silently ignored
        this.collectMetaData = b.collectMetaData;
    }

    /**
     * When set to true: metadata for the current examples will be present in the returned DataSet.
     * Disabled by default.
     *
     * @param collectMetaData Whether to collect metadata or not
     */
    public void setCollectMetaData(boolean collectMetaData) {
        //Propagate to the delegate if it already exists; otherwise initializeUnderlying() picks the flag up later
        if (underlying != null) {
            underlying.setCollectMetaData(collectMetaData);
        }
        this.collectMetaData = collectMetaData;
    }

    //Builds the underlying RecordReaderMultiDataSetIterator on first use, consuming (and restoring) one record
    //from the reader to determine the record size
    private void initializeUnderlying() {
        if (underlying == null) {
            Record next = recordReader.nextRecord();
            initializeUnderlying(next);
        }
    }

    //Configures and builds the underlying RecordReaderMultiDataSetIterator, using the given (first) record to
    //determine the total number of columns
    private void initializeUnderlying(Record next) {
        int totalSize = next.getRecord().size();

        //allow people to specify label index as -1 and infer the last possible label
        if (numPossibleLabels >= 1 && labelIndex < 0) {
            labelIndex = totalSize - 1;
            labelIndexTo = labelIndex;
        }

        if (recordReader.resetSupported()) {
            recordReader.reset();
        } else {
            //Hack around the fact that we need the first record to initialize the underlying RRMDSI, but can't reset
            // the original reader
            recordReader = new ConcatenatingRecordReader(
                            new CollectionRecordReader(Collections.singletonList(next.getRecord())), recordReader);
        }

        RecordReaderMultiDataSetIterator.Builder builder = new RecordReaderMultiDataSetIterator.Builder(batchSize);
        if (recordReader instanceof SequenceRecordReader) {
            builder.addSequenceReader(READER_KEY, (SequenceRecordReader) recordReader);
        } else {
            builder.addReader(READER_KEY, recordReader);
        }

        if (regression) {
            builder.addOutput(READER_KEY, labelIndex, labelIndexTo);
        } else if (numPossibleLabels >= 1) {
            builder.addOutputOneHot(READER_KEY, labelIndex, numPossibleLabels);
        }

        //Inputs: assume to be all of the other writables
        //In general: can't assume label indices are all at the start or end (even though 99% of the time they are)
        //If they are: easy. If not: use 2 inputs in the underlying as a workaround, and concat them
        if (labelIndex >= 0 && (labelIndex == 0 || labelIndexTo == totalSize - 1)) {
            //Labels are first or last -> one contiguous input range in the underlying iterator
            //(Removed an unreachable "labelIndex < 0" branch here: the enclosing condition guarantees labelIndex >= 0)
            int inputFrom;
            int inputTo;
            if (labelIndex == 0) {
                inputFrom = labelIndexTo + 1;
                inputTo = totalSize - 1;
            } else {
                inputFrom = 0;
                inputTo = labelIndex - 1;
            }
            builder.addInput(READER_KEY, inputFrom, inputTo);

            underlyingIsDisjoint = false;
        } else if (labelIndex >= 0) {
            //Labels are in the middle of the record -> 2 disjoint input ranges, concatenated in mdsToDataSet()
            Preconditions.checkState(labelIndex < next.getRecord().size(),
                            "Invalid label (from) index: index must be in range 0 to first record size of (0 to %s inclusive), got %s",
                            next.getRecord().size() - 1, labelIndex);
            Preconditions.checkState(labelIndexTo < next.getRecord().size(),
                            "Invalid label (to) index: index must be in range 0 to first record size of (0 to %s inclusive), got %s",
                            next.getRecord().size() - 1, labelIndexTo);

            //Multiple inputs
            int firstFrom = 0;
            int firstTo = labelIndex - 1;
            int secondFrom = labelIndexTo + 1;
            int secondTo = totalSize - 1;

            builder.addInput(READER_KEY, firstFrom, firstTo);
            builder.addInput(READER_KEY, secondFrom, secondTo);

            underlyingIsDisjoint = true;
        } else {
            //No labels - only features
            builder.addInput(READER_KEY);
            underlyingIsDisjoint = false;
        }

        underlying = builder.build();

        if (collectMetaData) {
            underlying.setCollectMetaData(true);
        }
    }

    //Converts the underlying iterator's MultiDataSet (1-2 feature arrays, 0-1 label arrays) to a single DataSet,
    //re-joining disjoint feature ranges, unwrapping metadata, and applying the preprocessor if set
    private DataSet mdsToDataSet(MultiDataSet mds) {
        INDArray f;
        INDArray fm;
        if (underlyingIsDisjoint) {
            //Rare case: 2 input arrays -> concat
            INDArray f1 = getOrNull(mds.getFeatures(), 0);
            INDArray f2 = getOrNull(mds.getFeatures(), 1);
            fm = getOrNull(mds.getFeaturesMaskArrays(), 0); //Per-example masking only on the input -> same for both

            //Can assume 2d features here
            f = Nd4j.hstack(f1, f2);
        } else {
            //Standard case
            f = getOrNull(mds.getFeatures(), 0);
            fm = getOrNull(mds.getFeaturesMaskArrays(), 0);
        }

        INDArray l = getOrNull(mds.getLabels(), 0);
        INDArray lm = getOrNull(mds.getLabelsMaskArrays(), 0);

        DataSet ds = new DataSet(f, l, fm, lm);

        if (collectMetaData) {
            //Unwrap the composable metadata map back to the single reader's metadata
            List<Serializable> temp = mds.getExampleMetaData();
            List<Serializable> temp2 = new ArrayList<>(temp.size());
            for (Serializable s : temp) {
                RecordMetaDataComposableMap m = (RecordMetaDataComposableMap) s;
                temp2.add(m.getMeta().get(READER_KEY));
            }
            ds.setExampleMetaData(temp2);
        }

        //Edge case, for backward compatibility:
        //If labelIdx == -1 && numPossibleLabels == -1 -> no labels -> set labels array to features array
        if (labelIndex == -1 && numPossibleLabels == -1 && ds.getLabels() == null) {
            ds.setLabels(ds.getFeatures());
        }

        if (preProcessor != null) {
            preProcessor.preProcess(ds);
        }

        return ds;
    }

    @Override
    public DataSet next(int num) {
        if (useCurrent) {
            //Return the DataSet cached by inputColumns()/totalOutcomes() instead of consuming a new batch
            //NOTE(review): `last` was already preprocessed in mdsToDataSet() when it was produced, so a
            // non-idempotent preprocessor is applied twice on this path - TODO confirm this is intended
            useCurrent = false;
            if (preProcessor != null)
                preProcessor.preProcess(last);
            return last;
        }
        if (underlying == null) {
            initializeUnderlying();
        }

        batchNum++;
        return mdsToDataSet(underlying.next(num));
    }

    //Package private
    static INDArray getOrNull(INDArray[] arr, int idx) {
        if (arr == null || arr.length == 0) {
            return null;
        }
        return arr[idx];
    }

    @Override
    public int inputColumns() {
        //Fetch (and cache) a batch if we haven't seen one yet; useCurrent makes the next next() call return it
        if (last == null) {
            DataSet next = next();
            last = next;
            useCurrent = true;
            return next.numInputs();
        } else
            return last.numInputs();
    }

    @Override
    public int totalOutcomes() {
        //Same caching pattern as inputColumns()
        if (last == null) {
            DataSet next = next();
            last = next;
            useCurrent = true;
            return next.numOutcomes();
        } else
            return last.numOutcomes();
    }

    @Override
    public boolean resetSupported() {
        //Delegate to the underlying iterator (building it first if required)
        if (underlying == null) {
            initializeUnderlying();
        }
        return underlying.resetSupported();
    }

    @Override
    public boolean asyncSupported() {
        return true;
    }

    @Override
    public void reset() {
        batchNum = 0;
        if (underlying != null) {
            underlying.reset();
        }

        last = null;
        useCurrent = false;
    }

    @Override
    public int batch() {
        return batchSize;
    }

    @Override
    public void setPreProcessor(org.nd4j.linalg.dataset.api.DataSetPreProcessor preProcessor) {
        this.preProcessor = preProcessor;
    }

    @Override
    public boolean hasNext() {
        //More data available AND the max-batches-per-reset limit (if any) has not been hit
        return (((sequenceIter != null && sequenceIter.hasNext()) || recordReader.hasNext())
                        && (maxNumBatches < 0 || batchNum < maxNumBatches));
    }

    @Override
    public DataSet next() {
        return next(batchSize);
    }

    @Override
    public void remove() {
        throw new UnsupportedOperationException();
    }

    @Override
    public List<String> getLabels() {
        return recordReader.getLabels();
    }

    /**
     * Load a single example to a DataSet, using the provided RecordMetaData.
     * Note that it is more efficient to load multiple instances at once, using {@link #loadFromMetaData(List)}
     *
     * @param recordMetaData RecordMetaData to load from. Should have been produced by the given record reader
     * @return DataSet with the specified example
     * @throws IOException If an error occurs during loading of the data
     */
    public DataSet loadFromMetaData(RecordMetaData recordMetaData) throws IOException {
        return loadFromMetaData(Collections.singletonList(recordMetaData));
    }

    /**
     * Load multiple examples to a DataSet, using the provided RecordMetaData instances.
     *
     * @param list List of RecordMetaData instances to load from. Should have been produced by the record reader
     *             provided to the RecordReaderDataSetIterator constructor
     * @return DataSet with the specified examples
     * @throws IOException If an error occurs during loading of the data
     */
    public DataSet loadFromMetaData(List<RecordMetaData> list) throws IOException {
        if (underlying == null) {
            Record r = recordReader.loadFromMetaData(list.get(0));
            initializeUnderlying(r);
        }

        //Convert back to composable:
        List<RecordMetaData> l = new ArrayList<>(list.size());
        for (RecordMetaData m : list) {
            l.add(new RecordMetaDataComposableMap(Collections.singletonMap(READER_KEY, m)));
        }
        MultiDataSet m = underlying.loadFromMetaData(l);

        return mdsToDataSet(m);
    }

    /**
     * Builder class for RecordReaderDataSetIterator
     */
    public static class Builder {
        protected RecordReader recordReader;
        protected WritableConverter converter;
        protected int batchSize;
        protected int maxNumBatches = -1;
        protected int labelIndex = -1;
        protected int labelIndexTo = -1;
        protected int numPossibleLabels = -1;
        protected boolean regression = false;
        protected DataSetPreProcessor preProcessor;
        private boolean collectMetaData = false;
        //Tracks whether classification(...) or regression(...) has been configured (set but not currently validated)
        private boolean clOrRegCalled = false;

        /**
         * @param rr        Underlying record reader to source data from
         * @param batchSize Batch size to use
         */
        public Builder(@NonNull RecordReader rr, int batchSize) {
            this.recordReader = rr;
            this.batchSize = batchSize;
        }

        public Builder writableConverter(WritableConverter converter) {
            this.converter = converter;
            return this;
        }

        /**
         * Optional argument, usually not used. If set, can be used to limit the maximum number of minibatches that
         * will be returned (between resets). If not set, will always return as many minibatches as there is data
         * available.
         *
         * @param maxNumBatches Maximum number of minibatches per epoch / reset
         */
        public Builder maxNumBatches(int maxNumBatches) {
            this.maxNumBatches = maxNumBatches;
            return this;
        }

        /**
         * Use this for single output regression (i.e., 1 output/regression target)
         *
         * @param labelIndex Column index that contains the regression target (indexes start at 0)
         */
        public Builder regression(int labelIndex) {
            return regression(labelIndex, labelIndex);
        }

        /**
         * Use this for multiple output regression (1 or more output/regression targets). Note that all regression
         * targets must be contiguous (i.e., positions x to y, without gaps)
         *
         * @param labelIndexFrom Column index of the first regression target (indexes start at 0)
         * @param labelIndexTo   Column index of the last regression target (inclusive)
         */
        public Builder regression(int labelIndexFrom, int labelIndexTo) {
            this.labelIndex = labelIndexFrom;
            this.labelIndexTo = labelIndexTo;
            this.regression = true;
            clOrRegCalled = true;
            return this;
        }

        /**
         * Use this for classification
         *
         * @param labelIndex Column index that contains the label (indexes start from 0). The column must contain
         *                   integer values 0 to numClasses-1
         * @param numClasses Number of label classes (i.e., number of categories/classes in the dataset)
         */
        public Builder classification(int labelIndex, int numClasses) {
            this.labelIndex = labelIndex;
            this.labelIndexTo = labelIndex;
            this.numPossibleLabels = numClasses;
            this.regression = false;
            clOrRegCalled = true;
            return this;
        }

        /**
         * Optional arg. Allows the preprocessor to be set
         * @param preProcessor Preprocessor to use
         */
        public Builder preProcessor(DataSetPreProcessor preProcessor) {
            this.preProcessor = preProcessor;
            return this;
        }

        /**
         * When set to true: metadata for the current examples will be present in the returned DataSet.
         * Disabled by default.
         *
         * @param collectMetaData Whether metadata should be collected or not
         */
        public Builder collectMetaData(boolean collectMetaData) {
            this.collectMetaData = collectMetaData;
            return this;
        }

        public RecordReaderDataSetIterator build() {
            return new RecordReaderDataSetIterator(this);
        }
    }
}