Example usage for the org.deeplearning4j.models.word2vec.VocabWord constructor VocabWord

List of usage examples for the org.deeplearning4j.models.word2vec.VocabWord constructor VocabWord

Introduction

On this page you can find example usages of the org.deeplearning4j.models.word2vec.VocabWord constructor VocabWord(double wordFrequency, String word).

Prototype

public VocabWord(double wordFrequency, @NonNull String word) 

Source Link

Usage

From source file:de.mpii.docsimilarity.mr.utils.io.WordVectorSerializer.java

License:Apache License

/**
 * Reads a word2vec model stored in the plain-text format.
 *
 * <p>The first line is a header of the form {@code "<words> <layerSize>"}. Every following
 * non-blank line holds a word followed by {@code layerSize} space-separated float components.
 * Each vector is scaled to unit length ({@code Transforms.unitVec}) before it is stored, and
 * every word is added to the vocabulary with frequency 1. The file is decompressed with gzip
 * when {@code GzipUtils.isCompressedFilename} accepts its name.
 *
 * <p>If the file holds fewer vectors than the header declares, the remaining rows of the
 * lookup table are left as created by {@code Nd4j.create}.
 *
 * @param modelFile the text model to read, optionally gzip-compressed
 * @return a {@link Word2Vec} whose vocabulary and lookup table contain the vectors of the file
 * @throws FileNotFoundException if {@code modelFile} cannot be opened
 * @throws IOException if reading fails, the file is empty, the header has fewer than two
 *     fields, a vector line has a field count other than {@code layerSize + 1}, or the file
 *     contains more vectors than the header declares
 * @throws NumberFormatException if a header field or vector component is not a number
 */
private static Word2Vec readTextModel(File modelFile) throws IOException, NumberFormatException {
    Word2Vec ret = new Word2Vec();
    // The raw file stream is its own resource so that it is closed even when the
    // GZIPInputStream constructor throws (e.g. a ".gz" name on data that is not gzip).
    try (FileInputStream fileIn = new FileInputStream(modelFile);
            BufferedReader reader = new BufferedReader(new InputStreamReader(
                    GzipUtils.isCompressedFilename(modelFile.getName())
                            ? new GZIPInputStream(fileIn)
                            : fileIn))) {
        String line = reader.readLine();
        if (line == null) {
            throw new IOException("Empty word2vec model file: " + modelFile);
        }
        String[] initial = line.split(" ");
        if (initial.length < 2) {
            throw new IOException("Malformed word2vec header \"" + line + "\" in " + modelFile);
        }
        int words = Integer.parseInt(initial[0]);
        int layerSize = Integer.parseInt(initial[1]);
        INDArray syn0 = Nd4j.create(words, layerSize);

        VocabCache cache = new InMemoryLookupCache(false);

        int currLine = 0;
        while ((line = reader.readLine()) != null) {
            if (line.trim().isEmpty()) {
                continue; // tolerate blank lines, e.g. a trailing newline at end of file
            }
            if (currLine >= words) {
                throw new IOException("More than the " + words
                        + " vectors declared in the header of " + modelFile);
            }
            String[] split = line.split(" ");
            // Explicit check instead of `assert`, which is disabled by default at runtime and
            // would let malformed rows through to putRow.
            if (split.length != layerSize + 1) {
                throw new IOException("Expected " + (layerSize + 1) + " fields but found "
                        + split.length + " on vector line " + (currLine + 1) + " of " + modelFile);
            }
            String word = split[0];

            float[] vector = new float[layerSize];
            for (int i = 1; i < split.length; i++) {
                vector[i - 1] = Float.parseFloat(split[i]);
            }

            syn0.putRow(currLine, Transforms.unitVec(Nd4j.create(vector)));

            cache.addWordToIndex(cache.numWords(), word);
            cache.addToken(new VocabWord(1, word));
            cache.putVocabWord(word);

            currLine++;
        }

        InMemoryLookupTable lookupTable = (InMemoryLookupTable) new InMemoryLookupTable.Builder()
                .cache(cache).vectorLength(layerSize).build();
        lookupTable.setSyn0(syn0);

        ret.setVocab(cache);
        ret.setLookupTable(lookupTable);
    }
    return ret;
}

From source file:de.mpii.docsimilarity.mr.utils.io.WordVectorSerializer.java

License:Apache License

/**
 * Reads a word2vec model stored in the binary format.
 *
 * <p>The stream starts with two tokens, the vocabulary size and the vector dimension. Each is
 * read with {@code readString} and parsed as an int. Then, for every word, the stream holds a
 * token read with {@code readString} followed by {@code size} floats read with
 * {@code readFloat}. Both helpers are defined elsewhere in this class. Each vector is scaled to
 * unit length before it is stored, and each word is added to the vocabulary with frequency 1.
 * The file is decompressed with gzip when {@code GzipUtils.isCompressedFilename} accepts its
 * name.
 *
 * @param modelFile the binary model to read, optionally gzip-compressed
 * @param linebreaks if true, one extra byte (the line break) is consumed after every vector
 * @return a {@link Word2Vec} whose vocabulary and lookup table contain the vectors of the file
 * @throws FileNotFoundException if {@code modelFile} cannot be opened
 * @throws IOException if reading fails
 * @throws NumberFormatException if a header token is not an integer
 */
private static Word2Vec readBinaryModel(File modelFile, boolean linebreaks)
        throws NumberFormatException, IOException {
    InMemoryLookupTable lookupTable;
    VocabCache cache;
    INDArray syn0;
    int words, size;
    // fileIn is its own resource so it is closed even if the GZIPInputStream constructor throws
    // (e.g. the name ends in ".gz" but the content is not gzip data).
    try (FileInputStream fileIn = new FileInputStream(modelFile);
            BufferedInputStream bis = new BufferedInputStream(
                    GzipUtils.isCompressedFilename(modelFile.getName())
                            ? new GZIPInputStream(fileIn)
                            : fileIn);
            DataInputStream dis = new DataInputStream(bis)) {
        words = Integer.parseInt(readString(dis));
        size = Integer.parseInt(readString(dis));
        syn0 = Nd4j.create(words, size);
        cache = new InMemoryLookupCache(false);
        lookupTable = (InMemoryLookupTable) new InMemoryLookupTable.Builder().cache(cache).vectorLength(size)
                .build();

        String word;
        for (int i = 0; i < words; i++) {

            word = readString(dis);
            log.trace("Loading " + word + " with word " + i);

            float[] vector = new float[size];

            for (int j = 0; j < size; j++) {
                vector[j] = readFloat(dis);
            }

            syn0.putRow(i, Transforms.unitVec(Nd4j.create(vector)));

            cache.addWordToIndex(cache.numWords(), word);
            cache.addToken(new VocabWord(1, word));
            cache.putVocabWord(word);

            if (linebreaks) {
                dis.readByte(); // line break
            }
        }
    }

    Word2Vec ret = new Word2Vec();

    lookupTable.setSyn0(syn0);
    ret.setVocab(cache);
    ret.setLookupTable(lookupTable);
    return ret;

}

From source file:de.mpii.docsimilarity.mr.utils.io.WordVectorSerializer.java

License:Apache License

/**
 * Read a binary word2vec file./*from   w  w w  .j  av a 2  s. c  o  m*/
 *
 * @param modelFile
 *            the File to read
 * @return a {@link Word2Vec model}
 * @throws NumberFormatException
 * @throws IOException
 * @throws FileNotFoundException
 */
public static Word2Vec readBinaryModel(String modelFile, FSDataInputStream modelstream,
        Set<String> requiredTerms) throws NumberFormatException, IOException {
    boolean linebreaks = DEFAULT_LINEBREAKS;
    InMemoryLookupTable lookupTable;
    VocabCache cache;
    INDArray syn0;
    int words, size;
    int count = 0;
    try (BufferedInputStream bis = new BufferedInputStream(
            GzipUtils.isCompressedFilename(modelFile) ? new GZIPInputStream(modelstream) : modelstream);
            DataInputStream dis = new DataInputStream(bis)) {
        words = Integer.parseInt(readString(dis));
        size = Integer.parseInt(readString(dis));
        System.out.println("words " + words + ", size " + size);
        syn0 = Nd4j.create(words, size);
        cache = new InMemoryLookupCache(false);
        lookupTable = (InMemoryLookupTable) new InMemoryLookupTable.Builder().cache(cache).vectorLength(size)
                .build();

        String word;

        for (int i = 0; i < words; i++) {

            word = readString(dis);

            log.trace("Loading " + word + " with word " + i);

            float[] vector = new float[size];

            for (int j = 0; j < size; j++) {
                vector[j] = readFloat(dis);
            }
            if (requiredTerms.contains(word)) {
                syn0.putRow(i, Transforms.unitVec(Nd4j.create(vector)));
                cache.addWordToIndex(cache.numWords(), word);
                cache.addToken(new VocabWord(1, word));
                cache.putVocabWord(word);
                count++;
            }

            if (linebreaks) {
                dis.readByte(); // line break
            }
        }
    }

    Word2Vec ret = new Word2Vec();

    lookupTable.setSyn0(syn0);
    ret.setVocab(cache);
    ret.setLookupTable(lookupTable);
    System.out.println("Load " + count + " terms in word2vec.");
    return ret;

}

From source file:de.mpii.docsimilarity.mr.utils.io.WordVectorSerializer.java

License:Apache License

/**
 * Loads an in memory cache from the given path (sets syn0 and the vocab)
 *
 * @param vectorsFile/* ww w .j  a va 2 s  .c o m*/
 *            the path of the file to load
 * @return
 * @throws FileNotFoundException
 */
public static Pair<InMemoryLookupTable, VocabCache> loadTxt(File vectorsFile) throws FileNotFoundException {
    BufferedReader write = new BufferedReader(new FileReader(vectorsFile));
    VocabCache cache = new InMemoryLookupCache();

    InMemoryLookupTable lookupTable;

    LineIterator iter = IOUtils.lineIterator(write);
    List<INDArray> arrays = new ArrayList<>();
    while (iter.hasNext()) {
        String line = iter.nextLine();
        String[] split = line.split(" ");
        String word = split[0];
        VocabWord word1 = new VocabWord(1.0, word);
        cache.addToken(word1);
        cache.addWordToIndex(cache.numWords(), word);
        word1.setIndex(cache.numWords());
        cache.putVocabWord(word);
        INDArray row = Nd4j.create(Nd4j.createBuffer(split.length - 1));
        for (int i = 1; i < split.length; i++) {
            row.putScalar(i - 1, Float.parseFloat(split[i]));
        }
        arrays.add(row);
    }

    INDArray syn = Nd4j.create(new int[] { arrays.size(), arrays.get(0).columns() });
    for (int i = 0; i < syn.rows(); i++) {
        syn.putRow(i, arrays.get(i));
    }

    lookupTable = (InMemoryLookupTable) new InMemoryLookupTable.Builder().vectorLength(arrays.get(0).columns())
            .useAdaGrad(false).cache(cache).build();
    Nd4j.clearNans(syn);
    lookupTable.setSyn0(syn);

    iter.close();

    return new Pair<>(lookupTable, cache);
}

From source file:edu.umd.umiacs.clip.tools.scor.WordVectorUtils.java

License:Apache License

/**
 * Loads word vectors from a plain-text file with one {@code word v1 v2 ...} entry per line.
 *
 * <p>Lines with at most two space-separated fields are skipped. This drops a word2vec-style
 * {@code "<count> <dim>"} header as well as blank lines. Vocabulary indices follow file order,
 * and every word gets frequency 1.0. NaN entries are handled with {@code Nd4j.clearNans}. The
 * lookup table is built without AdaGrad and without hierarchical softmax.
 *
 * @param vectorsFile text file to read
 * @param normalize optional flag; when the first value is {@code true}, every row is divided by
 *     its L2 norm
 * @return word vectors backed by an {@link InMemoryLookupTable}
 * @throws IllegalArgumentException if the file contains no vector lines
 */
public static WordVectors loadTxt(File vectorsFile, boolean... normalize) {
    AbstractCache cache = new AbstractCache<>();
    INDArray[] arrays;
    // The line stream may hold an open file handle; close it once all vectors are read.
    try (java.util.stream.Stream<String> fileLines = lines(vectorsFile.toPath())) {
        arrays = fileLines.map(line -> line.split(" "))
                .filter(fields -> fields.length > 2)
                .map(split -> {
                    VocabWord word = new VocabWord(1.0, split[0]);
                    word.setIndex(cache.numWords());
                    cache.addToken(word);
                    cache.addWordToIndex(word.getIndex(), split[0]);
                    // Plain loop: a parallel stream per line only adds fork/join overhead here.
                    float[] vector = new float[split.length - 1];
                    for (int i = 1; i < split.length; i++) {
                        vector[i - 1] = parseFloat(split[i]);
                    }
                    return Nd4j.create(vector);
                }).toArray(INDArray[]::new);
    }

    if (arrays.length == 0) {
        throw new IllegalArgumentException("No word vectors found in " + vectorsFile);
    }

    INDArray syn = Nd4j.vstack(arrays);

    InMemoryLookupTable lookupTable = new InMemoryLookupTable.Builder().vectorLength(arrays[0].columns())
            .useAdaGrad(false).cache(cache).useHierarchicSoftmax(false).build();
    Nd4j.clearNans(syn);
    if (normalize.length > 0 && normalize[0]) {
        syn.diviColumnVector(syn.norm2(1));
        // An all-zero row has norm 0, so the division yields 0/0 = NaN; clean those up again.
        Nd4j.clearNans(syn);
    }

    lookupTable.setSyn0(syn);

    WordVectorsImpl vectors = new WordVectorsImpl();
    vectors.setLookupTable(lookupTable);
    vectors.setVocab(cache);
    return vectors;
}