Example usage for org.deeplearning4j.util.ModelSerializer.restoreComputationGraph

List of usage examples for org.deeplearning4j.util.ModelSerializer.restoreComputationGraph

Introduction

This page shows example usage of org.deeplearning4j.util.ModelSerializer.restoreComputationGraph.

Prototype

public static ComputationGraph restoreComputationGraph(@NonNull File file) throws IOException 

Source Link

Document

Loads a computation graph from a file.

Usage

From source file: org.apache.tika.dl.imagerec.DL4JVGG16Net.java

License: Apache License

/**
 * Initializes the VGG16 ImageNet model and the ImageNet label set.
 *
 * <p>When {@code serialize} is enabled, a previously saved model at {@code cacheDir} is restored
 * if that file exists. Otherwise the pretrained model is obtained through the DL4J zoo
 * ({@code VGG16} with {@code PretrainedType.IMAGENET}) and written to {@code cacheDir}, so later
 * runs can restore it. When {@code serialize} is disabled, the pretrained model is always
 * obtained through the DL4J zoo.
 *
 * @param params configuration parameters (not read by this implementation)
 * @throws TikaConfigException if loading, restoring, or saving the model fails; {@code available}
 *     is set to {@code false} before the exception is thrown
 */
@Override
public void initialize(Map<String, Param> params) throws TikaConfigException {
    try {
        if (serialize) {
            if (cacheDir.exists()) {
                model = ModelSerializer.restoreComputationGraph(cacheDir);
                LOG.info("Preprocessed Model Loaded from {}", cacheDir);
            } else {
                LOG.warn("Preprocessed Model doesn't exist at {}", cacheDir);
                // getParentFile() returns null for a bare relative name; nothing to create then.
                java.io.File parentDir = cacheDir.getParentFile();
                if (parentDir != null) {
                    parentDir.mkdirs();
                }
                model = loadPretrainedVgg16();
                LOG.info(
                        "Saving the Loaded model for future use. Saved models are more optimised to consume less resources.");
                ModelSerializer.writeModel(model, cacheDir, true);
            }
        } else {
            LOG.info("Weight graph model loaded via dl4j Helper functions");
            model = loadPretrainedVgg16();
        }
        imageNetLabels = new ImageNetLabels();
        available = true;
    } catch (Exception e) {
        available = false;
        // Use a fixed message: e.getMessage() may be null and is not a format string.
        LOG.warn("Failed to initialize the VGG16 model", e);
        throw new TikaConfigException(e.getMessage(), e);
    }
}

/** Builds the DL4J zoo VGG16 architecture and loads its ImageNet pretrained weights. */
private ComputationGraph loadPretrainedVgg16() throws java.io.IOException {
    ZooModel zooModel = VGG16.builder().build();
    return (ComputationGraph) zooModel.initPretrained(PretrainedType.IMAGENET);
}

From source file: org.apache.tika.parser.recognition.dl4j.DL4JImageRecogniser.java

License: Apache License

/**
 * Loads the image-recognition model selected by {@code modelType} and configures output parsing.
 *
 * <p>Only {@code "VGG16"} is currently supported. When {@code serialize} is {@code "yes"}
 * (case-insensitive, surrounding whitespace ignored), the model is restored from
 * {@code locationToSave} if that file exists. Otherwise it is loaded through
 * {@code TrainedModelHelper} and saved to {@code locationToSave}. When {@code serialize} is
 * {@code "no"}, the model is always loaded through the helper. Afterwards {@code outPattern} is
 * registered as the metadata extraction pattern.
 *
 * @param params configuration parameters (not read by this implementation)
 * @throws TikaConfigException if the model type or serialization setting is unsupported, or if
 *     loading, restoring, or saving the model fails
 */
@Override
public void initialize(Map<String, Param> params) throws TikaConfigException {
    try {
        TrainedModelHelper helper;
        switch (modelType) {
        case "VGG16NOTOP":
            // TODO: hook up VGG16NOTOP once the upstream dl4j issue is resolved. Intended wiring:
            // modelFile = MODEL_DIR_PREPROCESSED/vgg16_notop.zip,
            // locationToSave = MODEL_DIR/tikaPreprocessed/vgg16.zip,
            // helper = new TrainedModelHelper(TrainedModels.VGG16NOTOP).
            throw new TikaConfigException("VGG16NOTOP is not supported right now");
        case "VGG16":
            helper = new TrainedModelHelper(TrainedModels.VGG16);
            modelFile = new File(MODEL_DIR_PREPROCESSED + File.separator + "vgg16.zip");
            locationToSave = new File(
                    MODEL_DIR + File.separator + "tikaPreprocessed" + File.separator + "vgg16.zip");
            break;
        default:
            throw new TikaConfigException("Unknown or unsupported model");
        }

        String serializeMode = serialize.trim().toLowerCase(Locale.ROOT);
        if (serializeMode.equals("yes")) {
            // The serialized cache is both read from and written to locationToSave, so the
            // existence check and directory creation must target that same path.
            if (locationToSave.exists()) {
                model = ModelSerializer.restoreComputationGraph(locationToSave);
                LOG.info("Preprocessed Model Loaded from {}", locationToSave);
            } else {
                LOG.warn("Preprocessed Model doesn't exist at {}", locationToSave);
                // getParentFile() returns null for a bare relative name; nothing to create then.
                File parentDir = locationToSave.getParentFile();
                if (parentDir != null) {
                    parentDir.mkdirs();
                }
                model = helper.loadModel();
                LOG.info(
                        "Saving the Loaded model for future use. Saved models are more optimised to consume less resources.");
                ModelSerializer.writeModel(model, locationToSave, true);
            }
        } else if (serializeMode.equals("no")) {
            LOG.info("Weight graph model loaded via dl4j Helper functions");
            model = helper.loadModel();
        } else {
            throw new TikaConfigException("Configuration Error. serialization can be either yes or no.");
        }
        // Every path that reaches this point has loaded a model; the others threw above.
        available = true;

        HashMap<Pattern, String> patterns = new HashMap<>();
        patterns.put(Pattern.compile(outPattern), null);
        setMetadataExtractionPatterns(patterns);
        setIgnoredLineConsumer(IGNORED_LINE_LOGGER);
    } catch (TikaConfigException e) {
        // Already the declared type; re-wrapping would only nest an identical message.
        throw e;
    } catch (Exception e) {
        LOG.warn("exception occurred while initializing the DL4J image recogniser", e);
        throw new TikaConfigException(e.getMessage(), e);
    }
}