opennlp.tools.util.model.BaseModel.java Source code

Java tutorial

Introduction

Here is the source code for opennlp.tools.util.model.BaseModel.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package opennlp.tools.util.model;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.net.URL;
import java.nio.file.Path;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Properties;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;
import java.util.zip.ZipOutputStream;

import opennlp.tools.util.BaseToolFactory;
import opennlp.tools.util.InvalidFormatException;
import opennlp.tools.util.Version;
import opennlp.tools.util.ext.ExtensionLoader;

/**
 * This model is a common based which can be used by the components
 * model classes.
 *
 * TODO:
 * Provide sub classes access to serializers already in constructor
 */
public abstract class BaseModel implements ArtifactProvider, Serializable {

    protected static final String MANIFEST_ENTRY = "manifest.properties";
    protected static final String FACTORY_NAME = "factory";

    private static final String MANIFEST_VERSION_PROPERTY = "Manifest-Version";
    private static final String COMPONENT_NAME_PROPERTY = "Component-Name";
    private static final String VERSION_PROPERTY = "OpenNLP-Version";
    private static final String TIMESTAMP_PROPERTY = "Timestamp";
    private static final String LANGUAGE_PROPERTY = "Language";

    public static final String TRAINING_CUTOFF_PROPERTY = "Training-Cutoff";
    public static final String TRAINING_ITERATIONS_PROPERTY = "Training-Iterations";
    public static final String TRAINING_EVENTHASH_PROPERTY = "Training-Eventhash";

    private static String SERIALIZER_CLASS_NAME_PREFIX = "serializer-class-";

    private Map<String, ArtifactSerializer> artifactSerializers = new HashMap<>();

    protected Map<String, Object> artifactMap = new HashMap<>();

    protected BaseToolFactory toolFactory;

    private String componentName;

    private boolean subclassSerializersInitiated = false;
    private boolean finishedLoadingArtifacts = false;

    private boolean isLoadedFromSerialized;

    private BaseModel(String componentName, boolean isLoadedFromSerialized) {
        this.isLoadedFromSerialized = isLoadedFromSerialized;

        this.componentName = Objects.requireNonNull(componentName, "componentName must not be null!");
    }

    /**
     * Initializes the current instance. The sub-class constructor should call the
     * method {@link #checkArtifactMap()} to check the artifact map is OK.
     * <p>
     * Sub-classes will have access to custom artifacts and serializers provided
     * by the factory.
     *
     * @param componentName
     *          the component name
     * @param languageCode
     *          the language code
     * @param manifestInfoEntries
     *          additional information in the manifest
     * @param factory
     *          the factory
     */
    protected BaseModel(String componentName, String languageCode, Map<String, String> manifestInfoEntries,
            BaseToolFactory factory) {

        this(componentName, false);

        Objects.requireNonNull(languageCode, "languageCode must not be null");

        createBaseArtifactSerializers(artifactSerializers);

        Properties manifest = new Properties();
        manifest.setProperty(MANIFEST_VERSION_PROPERTY, "1.0");
        manifest.setProperty(LANGUAGE_PROPERTY, languageCode);
        manifest.setProperty(VERSION_PROPERTY, Version.currentVersion().toString());
        manifest.setProperty(TIMESTAMP_PROPERTY, Long.toString(System.currentTimeMillis()));
        manifest.setProperty(COMPONENT_NAME_PROPERTY, componentName);

        if (manifestInfoEntries != null) {
            for (Map.Entry<String, String> entry : manifestInfoEntries.entrySet()) {
                manifest.setProperty(entry.getKey(), entry.getValue());
            }
        }

        artifactMap.put(MANIFEST_ENTRY, manifest);
        finishedLoadingArtifacts = true;

        if (factory != null) {
            setManifestProperty(FACTORY_NAME, factory.getClass().getCanonicalName());
            artifactMap.putAll(factory.createArtifactMap());

            // new manifest entries
            Map<String, String> entries = factory.createManifestEntries();
            for (Entry<String, String> entry : entries.entrySet()) {
                setManifestProperty(entry.getKey(), entry.getValue());
            }
        }

        try {
            initializeFactory();
        } catch (InvalidFormatException e) {
            throw new IllegalArgumentException("Could not initialize tool factory. ", e);
        }
        loadArtifactSerializers();
    }

    /**
     * Initializes the current instance. The sub-class constructor should call the
     * method {@link #checkArtifactMap()} to check the artifact map is OK.
     *
     * @param componentName
     *          the component name
     * @param languageCode
     *          the language code
     * @param manifestInfoEntries
     *          additional information in the manifest
     */
    protected BaseModel(String componentName, String languageCode, Map<String, String> manifestInfoEntries) {
        this(componentName, languageCode, manifestInfoEntries, null);
    }

    /**
     * Initializes the current instance.
     *
     * @param componentName the component name
     * @param in the input stream containing the model
     *
     * @throws IOException
     */
    protected BaseModel(String componentName, InputStream in) throws IOException {
        this(componentName, true);

        loadModel(in);
    }

    protected BaseModel(String componentName, File modelFile) throws IOException {
        this(componentName, true);

        try (InputStream in = new BufferedInputStream(new FileInputStream(modelFile))) {
            loadModel(in);
        }
    }

    protected BaseModel(String componentName, URL modelURL) throws IOException {
        this(componentName, true);

        try (InputStream in = new BufferedInputStream(modelURL.openStream())) {
            loadModel(in);
        }
    }

    private void loadModel(InputStream in) throws IOException {

        Objects.requireNonNull(in, "in must not be null");

        createBaseArtifactSerializers(artifactSerializers);

        if (!in.markSupported()) {
            in = new BufferedInputStream(in);
        }

        // TODO: Discuss this solution, the buffering should
        int MODEL_BUFFER_SIZE_LIMIT = Integer.MAX_VALUE;
        in.mark(MODEL_BUFFER_SIZE_LIMIT);

        final ZipInputStream zip = new ZipInputStream(in);

        // The model package can contain artifacts which are serialized with 3rd party
        // serializers which are configured in the manifest file. To be able to load
        // the model the manifest must be read first, and afterwards all the artifacts
        // can be de-serialized.

        // The ordering of artifacts in a zip package is not guaranteed. The stream is first
        // read until the manifest appears, reseted, and read again to load all artifacts.

        boolean isSearchingForManifest = true;

        ZipEntry entry;
        while ((entry = zip.getNextEntry()) != null && isSearchingForManifest) {

            if ("manifest.properties".equals(entry.getName())) {
                // TODO: Probably better to use the serializer here directly!
                ArtifactSerializer factory = artifactSerializers.get("properties");
                artifactMap.put(entry.getName(), factory.create(zip));
                isSearchingForManifest = false;
            }

            zip.closeEntry();
        }

        initializeFactory();

        loadArtifactSerializers();

        // The Input Stream should always be reset-able because if markSupport returns
        // false it is wrapped before hand into an Buffered InputStream
        in.reset();

        finishLoadingArtifacts(in);

        checkArtifactMap();
    }

    private void initializeFactory() throws InvalidFormatException {
        String factoryName = getManifestProperty(FACTORY_NAME);
        if (factoryName == null) {
            // load the default factory
            Class<? extends BaseToolFactory> factoryClass = getDefaultFactory();
            if (factoryClass != null) {
                this.toolFactory = BaseToolFactory.create(factoryClass, this);
            }
        } else {
            try {
                this.toolFactory = BaseToolFactory.create(factoryName, this);
            } catch (InvalidFormatException e) {
                throw new IllegalArgumentException(e);
            }
        }
    }

    /**
     * Sub-classes should override this method if their module has a default
     * BaseToolFactory sub-class.
     *
     * @return the default {@link BaseToolFactory} for the module, or null if none.
     */
    protected Class<? extends BaseToolFactory> getDefaultFactory() {
        return null;
    }

    /**
     * Loads the artifact serializers.
     */
    private void loadArtifactSerializers() {
        if (!subclassSerializersInitiated)
            createArtifactSerializers(artifactSerializers);
        subclassSerializersInitiated = true;
    }

    /**
     * Finish loading the artifacts now that it knows all serializers.
     */
    private void finishLoadingArtifacts(InputStream in) throws IOException {

        final ZipInputStream zip = new ZipInputStream(in);

        Map<String, Object> artifactMap = new HashMap<>();

        ZipEntry entry;
        while ((entry = zip.getNextEntry()) != null) {

            // Note: The manifest.properties file will be read here again,
            // there should be no need to prevent that.

            String entryName = entry.getName();
            String extension = getEntryExtension(entryName);

            ArtifactSerializer factory = artifactSerializers.get(extension);

            String artifactSerializerClazzName = getManifestProperty(SERIALIZER_CLASS_NAME_PREFIX + entryName);

            if (artifactSerializerClazzName != null) {
                factory = ExtensionLoader.instantiateExtension(ArtifactSerializer.class,
                        artifactSerializerClazzName);
            }

            if (factory != null) {
                artifactMap.put(entryName, factory.create(zip));
            } else {
                throw new InvalidFormatException("Unknown artifact format: " + extension);
            }

            zip.closeEntry();
        }

        this.artifactMap.putAll(artifactMap);

        finishedLoadingArtifacts = true;
    }

    /**
     * Extracts the "." extension from an entry name.
     *
     * @param entry the entry name which contains the extension
     *
     * @return the extension
     *
     * @throws InvalidFormatException if no extension can be extracted
     */
    private String getEntryExtension(String entry) throws InvalidFormatException {
        int extensionIndex = entry.lastIndexOf('.') + 1;

        if (extensionIndex == -1 || extensionIndex >= entry.length())
            throw new InvalidFormatException("Entry name must have type extension: " + entry);

        return entry.substring(extensionIndex);
    }

    protected ArtifactSerializer getArtifactSerializer(String resourceName) {
        try {
            return artifactSerializers.get(getEntryExtension(resourceName));
        } catch (InvalidFormatException e) {
            throw new IllegalStateException(e);
        }
    }

    protected static Map<String, ArtifactSerializer> createArtifactSerializers() {
        Map<String, ArtifactSerializer> serializers = new HashMap<>();

        GenericModelSerializer.register(serializers);
        PropertiesSerializer.register(serializers);
        DictionarySerializer.register(serializers);
        serializers.put("txt", new ByteArraySerializer());
        serializers.put("html", new ByteArraySerializer());

        return serializers;
    }

    /**
     * Registers all {@link ArtifactSerializer} for their artifact file name extensions.
     * The registered {@link ArtifactSerializer} are used to create and serialize
     * resources in the model package.
     *
     * Override this method to register custom {@link ArtifactSerializer}s.
     *
     * Note:
     * Subclasses should generally invoke super.createArtifactSerializers at the beginning
     * of this method.
     *
     * This method is called during construction.
     *
     * @param serializers the key of the map is the file extension used to lookup
     *     the {@link ArtifactSerializer}.
     */
    protected void createArtifactSerializers(Map<String, ArtifactSerializer> serializers) {
        if (this.toolFactory != null)
            serializers.putAll(this.toolFactory.createArtifactSerializersMap());
    }

    private void createBaseArtifactSerializers(Map<String, ArtifactSerializer> serializers) {
        serializers.putAll(createArtifactSerializers());
    }

    /**
     * Validates the parsed artifacts. If something is not
     * valid subclasses should throw an {@link InvalidFormatException}.
     *
     * Note:
     * Subclasses should generally invoke super.validateArtifactMap at the beginning
     * of this method.
     *
     * @throws InvalidFormatException
     */
    protected void validateArtifactMap() throws InvalidFormatException {
        if (!(artifactMap.get(MANIFEST_ENTRY) instanceof Properties))
            throw new InvalidFormatException("Missing the " + MANIFEST_ENTRY + "!");

        // First check version, everything else might change in the future
        String versionString = getManifestProperty(VERSION_PROPERTY);

        if (versionString != null) {
            Version version;

            try {
                version = Version.parse(versionString);
            } catch (NumberFormatException e) {
                throw new InvalidFormatException("Unable to parse model version '" + versionString + "'!", e);
            }

            // Version check is only performed if current version is not the dev/debug version
            if (!Version.currentVersion().equals(Version.DEV_VERSION)) {
                // Major and minor version must match, revision might be
                // this check allows for the use of models of n minor release behind current minor release
                if (Version.currentVersion().getMajor() != version.getMajor()
                        || Version.currentVersion().getMinor() - 4 > version.getMinor()) {
                    throw new InvalidFormatException("Model version " + version + " is not supported by this ("
                            + Version.currentVersion() + ") version of OpenNLP!");
                }

                // Reject loading a snapshot model with a non-snapshot version
                if (!Version.currentVersion().isSnapshot() && version.isSnapshot()) {
                    throw new InvalidFormatException("Model version " + version
                            + " is a snapshot - snapshot models are not supported by this non-snapshot version ("
                            + Version.currentVersion() + ") of OpenNLP!");
                }
            }
        } else {
            throw new InvalidFormatException(
                    "Missing " + VERSION_PROPERTY + " property in " + MANIFEST_ENTRY + "!");
        }

        if (getManifestProperty(COMPONENT_NAME_PROPERTY) == null)
            throw new InvalidFormatException(
                    "Missing " + COMPONENT_NAME_PROPERTY + " property in " + MANIFEST_ENTRY + "!");

        if (!getManifestProperty(COMPONENT_NAME_PROPERTY).equals(componentName))
            throw new InvalidFormatException("The " + componentName + " cannot load a model for the "
                    + getManifestProperty(COMPONENT_NAME_PROPERTY) + "!");

        if (getManifestProperty(LANGUAGE_PROPERTY) == null)
            throw new InvalidFormatException(
                    "Missing " + LANGUAGE_PROPERTY + " property in " + MANIFEST_ENTRY + "!");

        // Validate the factory. We try to load it using the ExtensionLoader. It
        // will return the factory, null or raise an exception
        String factoryName = getManifestProperty(FACTORY_NAME);
        if (factoryName != null) {
            try {
                if (ExtensionLoader.instantiateExtension(BaseToolFactory.class, factoryName) == null) {
                    throw new InvalidFormatException(
                            "Could not load an user extension specified by the model: " + factoryName);
                }
            } catch (Exception e) {
                throw new InvalidFormatException(
                        "Could not load an user extension specified by the model: " + factoryName, e);
            }
        }

        // validate artifacts declared by the factory
        if (toolFactory != null) {
            toolFactory.validateArtifactMap();
        }
    }

    /**
     * Checks the artifact map.
     * <p>
     * A subclass should call this method from a constructor which accepts the individual
     * artifact map items, to validate that these items form a valid model.
     * <p>
     * If the artifacts are not valid an IllegalArgumentException will be thrown.
     */
    protected void checkArtifactMap() {
        if (!finishedLoadingArtifacts)
            throw new IllegalStateException(
                    "The method BaseModel.finishLoadingArtifacts(..) was not called by BaseModel sub-class.");
        try {
            validateArtifactMap();
        } catch (InvalidFormatException e) {
            throw new IllegalArgumentException(e);
        }
    }

    /**
     * Retrieves the value to the given key from the manifest.properties
     * entry.
     *
     * @param key
     *
     * @return the value
     */
    public final String getManifestProperty(String key) {
        Properties manifest = (Properties) artifactMap.get(MANIFEST_ENTRY);

        return manifest.getProperty(key);
    }

    /**
     * Sets a given value for a given key to the manifest.properties entry.
     *
     * @param key
     * @param value
     */
    protected final void setManifestProperty(String key, String value) {
        Properties manifest = (Properties) artifactMap.get(MANIFEST_ENTRY);

        manifest.setProperty(key, value);
    }

    /**
     * Retrieves the language code of the material which
     * was used to train the model or x-unspecified if
     * non was set.
     *
     * @return the language code of this model
     */
    public final String getLanguage() {
        return getManifestProperty(LANGUAGE_PROPERTY);
    }

    /**
     * Retrieves the OpenNLP version which was used
     * to create the model.
     *
     * @return the version
     */
    public final Version getVersion() {
        String version = getManifestProperty(VERSION_PROPERTY);

        return Version.parse(version);
    }

    /**
     * Serializes the model to the given {@link OutputStream}.
     *
     * @param out stream to write the model to
     * @throws IOException
     */
    @SuppressWarnings("unchecked")
    public final void serialize(OutputStream out) throws IOException {
        if (!subclassSerializersInitiated) {
            throw new IllegalStateException(
                    "The method BaseModel.loadArtifactSerializers() was not called by BaseModel subclass constructor.");
        }

        for (Entry<String, Object> entry : artifactMap.entrySet()) {
            final String name = entry.getKey();
            final Object artifact = entry.getValue();
            if (artifact instanceof SerializableArtifact) {

                SerializableArtifact serializableArtifact = (SerializableArtifact) artifact;

                String artifactSerializerName = serializableArtifact.getArtifactSerializerClass().getName();

                setManifestProperty(SERIALIZER_CLASS_NAME_PREFIX + name, artifactSerializerName);
            }
        }

        ZipOutputStream zip = new ZipOutputStream(out);

        for (Entry<String, Object> entry : artifactMap.entrySet()) {
            String name = entry.getKey();
            zip.putNextEntry(new ZipEntry(name));

            Object artifact = entry.getValue();

            ArtifactSerializer serializer = getArtifactSerializer(name);

            // If model is serialize-able always use the provided serializer
            if (artifact instanceof SerializableArtifact) {

                SerializableArtifact serializableArtifact = (SerializableArtifact) artifact;

                String artifactSerializerName = serializableArtifact.getArtifactSerializerClass().getName();

                serializer = ExtensionLoader.instantiateExtension(ArtifactSerializer.class, artifactSerializerName);
            }

            if (serializer == null) {
                throw new IllegalStateException("Missing serializer for " + name);
            }

            serializer.serialize(artifactMap.get(name), zip);

            zip.closeEntry();
        }

        zip.finish();
        zip.flush();
    }

    public final void serialize(File model) throws IOException {
        try (OutputStream out = new BufferedOutputStream(new FileOutputStream(model))) {
            serialize(out);
        }
    }

    public final void serialize(Path model) throws IOException {
        serialize(model.toFile());
    }

    @SuppressWarnings("unchecked")
    public <T> T getArtifact(String key) {
        Object artifact = artifactMap.get(key);
        if (artifact == null)
            return null;
        return (T) artifact;
    }

    public boolean isLoadedFromSerialized() {
        return isLoadedFromSerialized;
    }

    // These methods are required to serialize/deserialize the model because
    // many of the included objects in this model are not Serializable.
    // An alternative to this solution is to make all included objects
    // Serializable and remove the writeObject and readObject methods.
    // This will allow the usage of final for fields that should not change.

    private void writeObject(ObjectOutputStream out) throws IOException {
        out.writeUTF(componentName);
        this.serialize(out);
    }

    private void readObject(final ObjectInputStream in) throws IOException {

        isLoadedFromSerialized = true;
        artifactSerializers = new HashMap<>();
        artifactMap = new HashMap<>();

        componentName = in.readUTF();

        this.loadModel(in);
    }
}