de.tudarmstadt.ukp.dkpro.core.mallet.internal.wordembeddings.MalletEmbeddingsUtils.java Source code

Java tutorial

Introduction

Here is the source code for de.tudarmstadt.ukp.dkpro.core.mallet.internal.wordembeddings.MalletEmbeddingsUtils.java

Source

/*
 * Copyright 2016
 * Ubiquitous Knowledge Processing (UKP) Lab
 * Technische Universität Darmstadt
 * <p>
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * <p>
 * http://www.apache.org/licenses/LICENSE-2.0
 * <p>
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package de.tudarmstadt.ukp.dkpro.core.mallet.internal.wordembeddings;

import de.tudarmstadt.ukp.dkpro.core.api.resources.CompressionUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * Helper methods for reading word embeddings in text format.
 */
public class MalletEmbeddingsUtils {
    private static final Log LOG = LogFactory.getLog(MalletEmbeddingsUtils.class);

    /**
     * Read an embeddings file in text format.
     * <p>
     * If hasHeader is set to true, the first line is expected to contain the size and dimensionality
     * of the vectors. This is typically true for files generated by Word2Vec (in text format).
     * <p>
     * The file is opened through {@link CompressionUtils#getInputStream(String, InputStream)}
     * (presumably so that compressed files are decoded based on their name). All streams opened
     * here are closed before this method returns, also when reading fails.
     *
     * @param file      the input file
     * @param hasHeader if true, read size and dimensionality from the first line
     * @return a {@code Map<String, double[]>} mapping each token to a vector.
     * @throws IOException if the input file cannot be read or the header line is malformed
     * @see #readEmbeddingFileTxt(InputStream, boolean)
     */
    public static Map<String, double[]> readEmbeddingFileTxt(File file, boolean hasHeader) throws IOException {
        LOG.info("Reading embeddings from file " + file);
        // Both streams are managed here so the FileInputStream is closed even if wrapping it fails.
        try (InputStream fileStream = new FileInputStream(file);
                InputStream is = CompressionUtils.getInputStream(file.getAbsolutePath(), fileStream)) {
            return readEmbeddingFileTxt(is, hasHeader);
        }
    }

    /**
     * Read embeddings in text format from an InputStream.
     * Each line is expected to have a space-separated list {@code <token> <value1> <value2> ...}.
     * Blank lines are ignored. The input is decoded as UTF-8, and the given stream is closed when
     * this method returns.
     * <p>
     * When assertions are enabled ({@code -ea}), the number of embeddings and the vector lengths
     * are checked against the header (if present) or against each other (if not).
     *
     * @param inputStream an {@link InputStream}
     * @param hasHeader   if true, read size and dimensionality from the first line
     * @return a {@code Map<String, double[]>} mapping each token to a vector.
     * @throws IOException           if the input cannot be read, or if {@code hasHeader} is true and
     *                               the header line is missing or malformed
     * @throws IllegalStateException if the same token occurs on more than one line
     * @throws NumberFormatException if a vector component cannot be parsed as a double
     */
    public static Map<String, double[]> readEmbeddingFileTxt(InputStream inputStream, boolean hasHeader)
            throws IOException {
        final int size;
        final int dimensions;
        final Map<String, double[]> embeddings;

        // Closing the reader also closes the given stream; try-with-resources ensures this happens
        // even when parsing fails halfway through.
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(inputStream, StandardCharsets.UTF_8))) {
            if (hasHeader) {
                int[] header = parseHeader(reader.readLine());
                size = header[0];
                dimensions = header[1];
            } else {
                size = -1;
                dimensions = -1;
            }

            embeddings = reader.lines()
                    // a blank line would otherwise yield an empty token with an empty vector
                    .filter(line -> !line.trim().isEmpty())
                    .map(MalletEmbeddingsUtils::lineToEmbedding)
                    .collect(Collectors.toMap(Pair::getKey, Pair::getValue));
        } catch (UncheckedIOException e) {
            // BufferedReader.lines() wraps read errors; report them as the declared IOException
            throw e.getCause();
        }

        /* assert size and dimension (only evaluated when assertions are enabled) */
        if (hasHeader) {
            /* check that size read matches header information */
            LOG.debug("Checking number and vector sizes for all embeddings.");
            assert size == embeddings.size();
            assert embeddings.values().stream().allMatch(vector -> dimensions == vector.length);
        } else if (!embeddings.isEmpty()) {
            // guarded: an empty input has no reference vector to compare against
            LOG.debug("Checking vector sizes for all embeddings.");
            int firstLength = embeddings.values().iterator().next().length;
            assert embeddings.values().stream().allMatch(vector -> firstLength == vector.length);
        }

        return embeddings;
    }

    /**
     * Parse a header line of the form {@code <size> <dimensions>}.
     *
     * @param line the first line of the input, or {@code null} if the input was empty
     * @return a two-element array {@code {size, dimensions}}
     * @throws IOException if the line is missing or does not consist of exactly two integers
     */
    private static int[] parseHeader(String line) throws IOException {
        if (line == null) {
            throw new IOException("Expected a header line with size and dimensionality, but input is empty");
        }
        String[] fields = line.trim().split("\\s+");
        if (fields.length != 2) {
            throw new IOException("Malformed header line, expected '<size> <dimensions>': " + line);
        }
        try {
            return new int[] { Integer.parseInt(fields[0]), Integer.parseInt(fields[1]) };
        } catch (NumberFormatException e) {
            throw new IOException("Malformed header line, expected '<size> <dimensions>': " + line, e);
        }
    }

    /**
     * Convert a single line in the expected format ({@code <token> <value1> ... <valueN>}) into a
     * pair holding the token and the corresponding vector.
     * <p>
     * The line is split on single space characters; trailing spaces are tolerated because
     * {@link String#split(String)} drops trailing empty strings.
     *
     * @param line a line
     * @return a {@link Pair} of the token (first field) and its vector (all remaining fields)
     * @throws NumberFormatException if a vector component is not a valid double
     */
    private static Pair<String, double[]> lineToEmbedding(String line) {
        String[] array = line.split(" ");
        int size = array.length;
        double[] vector = Arrays.stream(array, 1, size).mapToDouble(Double::parseDouble).toArray();
        return Pair.of(array[0], vector);
    }
}