Example usage for org.deeplearning4j.nn.modelimport.keras KerasModel getComputationGraph


Introduction

This page shows an example usage of getComputationGraph from the KerasModel class in org.deeplearning4j.nn.modelimport.keras.

Prototype

public ComputationGraph getComputationGraph()
        throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException 


Documentation

Build a ComputationGraph from this Keras Model configuration and import weights.

Usage

From source file: org.apache.tika.dl.imagerec.DL4JInceptionV3Net.java

License: Apache License

@Override
public void initialize(Map<String, Param> params) throws TikaConfigException {

    //STEP 1: resolve weights file, download if necessary
    modelWeightsPath = mayBeDownloadFile(modelWeightsPath);

    //STEP 2: Load labels map
    try (InputStream stream = retrieveResource(mayBeDownloadFile(labelFile))) {
        this.labelMap = loadClassIndex(stream);
    } catch (IOException | ParseException e) {
        LOG.error("Could not load labels map", e);
        return;
    }

    //STEP 3: initialize the graph
    try {
        this.imageLoader = new NativeImageLoader(imgHeight, imgWidth, imgChannels);
        LOG.info("Going to load Inception network...");
        long st = System.currentTimeMillis();

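        // Read the Keras model (configuration and weights) from one HDF5 file.
        // enforceTrainingConfig(false): unsupported training-only settings are
        // tolerated instead of failing the import (the model is used for inference)
        KerasModelBuilder builder = new KerasModel().modelBuilder().modelHdf5Filename(modelWeightsPath)
                .enforceTrainingConfig(false);
        // Input shape in channels-last order: height, width, channels (3 = RGB)
        builder.inputShape(new int[] { imgHeight, imgWidth, 3 });
        KerasModel model = builder.buildModel();
        // Build the ComputationGraph from the Keras configuration and import the weights
        this.graph = model.getComputationGraph();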

        long time = System.currentTimeMillis() - st;
        LOG.info("Loaded the Inception model. Time taken={}ms", time);
    } catch (IOException | InvalidKerasConfigurationException | UnsupportedKerasConfigurationException e) {
        throw new TikaConfigException(e.getMessage(), e);
    }
}
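STEP 1 of the usage example calls `mayBeDownloadFile`, whose body is not shown on this page. The following is a minimal sketch of what such a helper might do, assuming remote weights are given as http(s) URLs and cached in a local directory. `ModelFileResolver`, `resolveOrDownload` and the cache layout are hypothetical names, not Tika's API.

```java
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;

public class ModelFileResolver {

    /**
     * Hypothetical stand-in for STEP 1: local paths are returned unchanged,
     * http(s) URLs are downloaded into cacheDir (or reused if already cached).
     */
    public static String resolveOrDownload(String location, Path cacheDir) throws IOException {
        if (!location.startsWith("http://") && !location.startsWith("https://")) {
            return location; // already a local file path
        }
        URI uri = URI.create(location);
        // Name the cached copy after the last segment of the URL path
        String path = uri.getPath();
        String fileName = path.substring(path.lastIndexOf('/') + 1);
        if (fileName.isEmpty()) {
            throw new IllegalArgumentException("URL has no file name: " + location);
        }
        Path target = cacheDir.resolve(fileName);
        if (Files.exists(target)) {
            return target.toString(); // downloaded on an earlier run
        }
        Files.createDirectories(cacheDir);
        // Download to a temporary file first so an interrupted transfer
        // never leaves a partial file under the final name
        Path tmp = Files.createTempFile(cacheDir, fileName, ".part");
        try (InputStream in = uri.toURL().openStream()) {
            Files.copy(in, tmp, StandardCopyOption.REPLACE_EXISTING);
            Files.move(tmp, target, StandardCopyOption.REPLACE_EXISTING);
        } finally {
            Files.deleteIfExists(tmp); // no-op after a successful move
        }
        return target.toString();
    }

    public static void main(String[] args) throws IOException {
        Path cache = Paths.get(System.getProperty("java.io.tmpdir"), "model-cache");
        // A local path is passed through without touching the network
        System.out.println(resolveOrDownload("/opt/models/inception_v3.h5", cache));
    }
}
```

Writing to a `.part` file and then moving it means a crash mid-download cannot be mistaken for a valid cached model on the next start.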
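STEP 2 reads the label file through `loadClassIndex`, which is also not shown; the original catches `ParseException`, which suggests it relies on a JSON library. Assuming the label file uses the Keras-style `imagenet_class_index.json` layout (`{"0": ["n01440764", "tench"], ...}`), a dependency-free sketch of the parsing could look like this. `ClassIndexParser` and `parseClassIndex` are hypothetical, and the regex only handles this flat layout (no escaped quotes), so a real JSON parser is the safer choice in production.

```java
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.TreeMap;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class ClassIndexParser {

    // Matches one entry of the form "<index>": ["<synset id>", "<label>"]
    private static final Pattern ENTRY = Pattern.compile(
            "\"(\\d+)\"\\s*:\\s*\\[\\s*\"([^\"]*)\"\\s*,\\s*\"([^\"]*)\"\\s*\\]");

    /** Hypothetical stand-in for loadClassIndex: maps class index to human-readable label. */
    public static Map<Integer, String> parseClassIndex(InputStream stream) throws IOException {
        String json = new String(stream.readAllBytes(), StandardCharsets.UTF_8);
        Map<Integer, String> labels = new TreeMap<>();
        Matcher m = ENTRY.matcher(json);
        while (m.find()) {
            // group(1) = class index, group(3) = readable label (group(2) is the synset id)
            labels.put(Integer.parseInt(m.group(1)), m.group(3));
        }
        if (labels.isEmpty()) {
            throw new IOException("No class-index entries found");
        }
        return labels;
    }

    public static void main(String[] args) throws IOException {
        String sample = "{\"0\": [\"n01440764\", \"tench\"], \"1\": [\"n01443537\", \"goldfish\"]}";
        Map<Integer, String> labels = parseClassIndex(
                new ByteArrayInputStream(sample.getBytes(StandardCharsets.UTF_8)));
        System.out.println(labels); // {0=tench, 1=goldfish}
    }
}
```

A `TreeMap` keeps the labels ordered by class index, so the i-th output of the network can be looked up directly with `labels.get(i)`.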