Java tutorial
/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.mahout.classifier.rbm.network;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.mahout.classifier.rbm.layer.Layer;
import org.apache.mahout.classifier.rbm.layer.SoftmaxLayer;
import org.apache.mahout.classifier.rbm.model.LabeledSimpleRBM;
import org.apache.mahout.classifier.rbm.model.RBMModel;
import org.apache.mahout.classifier.rbm.model.SimpleRBM;
import org.apache.mahout.common.ClassUtils;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.MatrixWritable;

import com.google.common.io.Closeables;

/**
 * A DeepBoltzmannMachine is a (deep belief) neural network consisting of a stack of
 * restricted boltzmann machines.
 *
 * <p>Layer {@code l} of the network is the visible layer of RBM {@code l}; the topmost
 * layer is the hidden layer of the last RBM, so a stack of {@code n} RBMs exposes
 * {@code n + 1} layers.
 */
public class DeepBoltzmannMachine implements DeepBeliefNetwork, Cloneable {

  /** The restricted boltzmann machines where nr 0 is lowest. */
  private final List<RBMModel> rbms;

  /**
   * Instantiates a new deep boltzmann machine.
   *
   * @param lowestRBM the lowest rbm
   */
  public DeepBoltzmannMachine(RBMModel lowestRBM) {
    rbms = new ArrayList<RBMModel>();
    rbms.add(lowestRBM);
  }

  /**
   * Put a new RBM on the stack.
   *
   * <p>The RBM is only accepted when its visible layer {@code equals} the hidden layer
   * of the RBM currently on top of the stack.
   *
   * @param rbm the RBM
   * @return true, if successful; false if the layers do not match (the stack is unchanged)
   */
  public boolean stackRBM(RBMModel rbm) {
    Layer topHiddenLayer = rbms.get(rbms.size() - 1).getHiddenLayer();
    if (rbm.getVisibleLayer().equals(topHiddenLayer)) {
      rbms.add(rbm);
      return true;
    }
    return false;
  }

  /**
   * Serialize to the output.
   *
   * <p>Written format: the number of RBMs, then for each RBM the class name of its
   * hidden layer followed by a blank (the first RBM is preceded by the class name of its
   * visible layer and a blank) and its weight matrix. The last RBM is written as a
   * {@link LabeledSimpleRBM}, i.e. its weight matrix is followed by its label weight matrix.
   *
   * @param output the output
   * @param conf the conf
   * @throws IOException Signals that an I/O exception has occurred, including a failure
   *         to close the output after an otherwise successful write.
   * @throws ClassCastException if the last RBM is not a {@link LabeledSimpleRBM} or any
   *         other RBM is not a {@link SimpleRBM}
   */
  public void serialize(Path output, Configuration conf) throws IOException {
    FileSystem fs = output.getFileSystem(conf);
    FSDataOutputStream out = fs.create(output, true);
    // Only swallow an exception from close() when the write already failed; a failed
    // close after a successful write means the data may not have been persisted.
    boolean threw = true;
    try {
      int rbmCount = rbms.size();
      new IntWritable(rbmCount).write(out);
      for (int i = 0; i < rbmCount; i++) {
        RBMModel rbm = rbms.get(i);
        if (i == 0) {
          out.writeChars(rbm.getVisibleLayer().getClass().getName() + " ");
        }
        out.writeChars(rbm.getHiddenLayer().getClass().getName() + " ");

        if (i < rbmCount - 1) {
          MatrixWritable.writeMatrix(out, ((SimpleRBM) rbm).getWeightMatrix());
        } else {
          LabeledSimpleRBM topRbm = (LabeledSimpleRBM) rbm;
          MatrixWritable.writeMatrix(out, topRbm.getWeightMatrix());
          MatrixWritable.writeMatrix(out, topRbm.getWeightLabelMatrix());
        }
      }
      threw = false;
    } finally {
      Closeables.close(out, threw);
    }
  }

  /**
   * Materialize from input path.
   *
   * <p>Reads the format produced by {@link #serialize(Path, Configuration)}. Layers are
   * instantiated reflectively through a constructor taking a single {@code int}; the
   * visible layer of every RBM above the lowest one is the hidden layer instance of the
   * RBM below it.
   *
   * @param input the input path
   * @param conf the hadoop config
   * @return the deep boltzmann machine, or {@code null} if the input declares no RBMs
   * @throws IOException Signals that an I/O exception has occurred, or that a
   *         materialized RBM was rejected by {@link #stackRBM(RBMModel)}.
   */
  public static DeepBoltzmannMachine materialize(Path input, Configuration conf) throws IOException {
    FileSystem fs = input.getFileSystem(conf);
    FSDataInputStream in = fs.open(input);
    DeepBoltzmannMachine dbm = null;
    try {
      int rbmSize = in.readInt();
      for (int i = 0; i < rbmSize; i++) {
        // Only the lowest RBM carries the class name of its visible layer.
        String visLayerType = (i == 0) ? readToken(in) : "";
        String hidLayerType = readToken(in);
        Matrix weightMatrix = MatrixWritable.readMatrix(in);

        Layer vl;
        if (i == 0) {
          vl = instantiateLayer(visLayerType, weightMatrix.rowSize());
        } else {
          vl = dbm.rbms.get(dbm.getRbmCount() - 1).getHiddenLayer();
        }
        Layer hl = instantiateLayer(hidLayerType, weightMatrix.columnSize());

        RBMModel rbm;
        if (i < rbmSize - 1) {
          SimpleRBM simpleRbm = new SimpleRBM(vl, hl);
          simpleRbm.setWeightMatrix(weightMatrix);
          rbm = simpleRbm;
        } else {
          // The topmost RBM additionally stores the weights to its label layer.
          Matrix weightLabelMatrix = MatrixWritable.readMatrix(in);
          LabeledSimpleRBM labeledRbm =
              new LabeledSimpleRBM(vl, hl, new SoftmaxLayer(weightLabelMatrix.rowSize()));
          labeledRbm.setWeightMatrix(weightMatrix);
          labeledRbm.setWeightLabelMatrix(weightLabelMatrix);
          rbm = labeledRbm;
        }

        if (i == 0) {
          dbm = new DeepBoltzmannMachine(rbm);
        } else if (!dbm.stackRBM(rbm)) {
          // Never drop a layer silently: the resulting network would be incomplete.
          throw new IOException("Could not stack RBM " + i + " read from " + input
              + ": its visible layer does not match the hidden layer below");
        }
      }
    } finally {
      Closeables.closeQuietly(in);
    }
    return dbm;
  }

  /**
   * Reads 16-bit chars from the stream up to (and consuming) the next blank.
   *
   * @param in the stream to read from
   * @return the characters read before the blank, without the blank itself
   * @throws IOException if reading fails or the stream ends before a blank is found
   */
  private static String readToken(FSDataInputStream in) throws IOException {
    StringBuilder token = new StringBuilder();
    char chr;
    while ((chr = in.readChar()) != ' ') {
      token.append(chr);
    }
    return token.toString();
  }

  /**
   * Reflectively creates a layer through its single-{@code int} constructor.
   *
   * @param layerType fully qualified class name of the layer
   * @param size the value handed to the constructor
   * @return the new layer
   */
  private static Layer instantiateLayer(String layerType, int size) {
    return ClassUtils.instantiateAs(layerType, Layer.class,
        new Class<?>[] { int.class }, new Object[] { size });
  }

  /**
   * Get the i-th RBM.
   *
   * @param i the zero-based position in the stack (0 is the lowest RBM)
   * @return the RBM, or {@code null} if {@code i} is {@code null} or outside
   *         {@code [0, getRbmCount())}
   */
  public RBMModel getRBM(Integer i) {
    if (i == null || i < 0 || i >= rbms.size()) {
      return null;
    }
    return rbms.get(i);
  }

  /**
   * Gets the size of the rbm stack.
   *
   * @return the stacksize of rbms
   */
  public int getRbmCount() {
    return rbms.size();
  }

  /**
   * Gets the layer count.
   *
   * @return the layer count, which is one more than the number of stacked RBMs
   */
  public int getLayerCount() {
    return rbms.size() + 1;
  }

  /**
   * Excites layer {@code l} from its neighbouring layers.
   *
   * <p>If an RBM sits on top of the layer, that RBM excites it as its visible layer.
   * If an RBM sits below the layer, that RBM excites it as its hidden layer; the second
   * argument passed on tells the lower RBM whether the upper RBM has already contributed.
   *
   * @param l the layer index
   * @see org.apache.mahout.classifier.rbm.network.DeepBeliefNetwork#exciteLayer(int)
   */
  @Override
  public void exciteLayer(int l) {
    boolean addInput = l < getRbmCount();

    if (addInput) {
      RBMModel upperRbm = getRBM(l);
      upperRbm.exciteVisibleLayer(1, false);
    }

    if (l > 0) {
      RBMModel lowerRbm = getRBM(l - 1);
      lowerRbm.exciteHiddenLayer(1, addInput);
    }
  }

  /**
   * Gets layer {@code l}: the visible layer of RBM {@code l}, or for the topmost layer
   * the hidden layer of the last RBM.
   *
   * @param l the layer index
   * @return the layer
   * @see org.apache.mahout.classifier.rbm.network.DeepBeliefNetwork#getLayer(int)
   */
  @Override
  public Layer getLayer(int l) {
    if (l < getRbmCount()) {
      return getRBM(l).getVisibleLayer();
    }
    return getRBM(l - 1).getHiddenLayer();
  }

  /**
   * Walks the stack bottom-up, exciting and then updating the hidden layer of each RBM.
   * Every RBM except the topmost is excited with a factor of 2, the topmost with 1.
   *
   * @see org.apache.mahout.classifier.rbm.network.DeepBeliefNetwork#upPass()
   */
  @Override
  public void upPass() {
    int topIndex = getRbmCount() - 1;
    for (int i = 0; i <= topIndex; i++) {
      RBMModel rbm = rbms.get(i);
      // NOTE(review): the doubled factor presumably compensates for the missing
      // top-down input of intermediate layers -- confirm against RBMModel.
      rbm.exciteHiddenLayer((i < topIndex) ? 2 : 1, false);
      rbm.updateHiddenLayer();
    }
  }

  /**
   * Updates layer {@code l}: the visible layer of RBM {@code l}, or for the topmost
   * layer the hidden layer of the last RBM.
   *
   * @param l the layer index
   * @see org.apache.mahout.classifier.rbm.network.DeepBeliefNetwork#updateLayer(int)
   */
  @Override
  public void updateLayer(int l) {
    if (l < getRbmCount()) {
      getRBM(l).updateVisibleLayer();
    } else {
      getRBM(l - 1).updateHiddenLayer();
    }
  }

  /**
   * Creates a copy of this machine from clones of all stacked RBMs. The visible layer
   * of every cloned RBM above the lowest one is set to the hidden layer of the cloned
   * RBM below it, so that neighbouring RBMs of the copy share their common layer.
   *
   * @return the copied machine
   * @see java.lang.Object#clone()
   */
  @Override
  public DeepBoltzmannMachine clone() {
    DeepBoltzmannMachine dbm = new DeepBoltzmannMachine(rbms.get(0).clone());
    for (int i = 1; i < rbms.size(); i++) {
      RBMModel clonedRbm = getRBM(i).clone();
      clonedRbm.setVisibleLayer(dbm.getRBM(i - 1).getHiddenLayer());
      dbm.stackRBM(clonedRbm);
    }
    return dbm;
  }
}