WordVectorSerializerTest.java — source code listing (Java tutorial)

Introduction

Below is the complete source code for
org.deeplearning4j.models.word2vec.WordVectorSerializerTest.java.

Source

/*
 *
 *  * Copyright 2015 Skymind,Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 *
 */

package org.deeplearning4j.models.word2vec;

import com.google.common.primitives.Doubles;
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang.ArrayUtils;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.InMemoryLookupCache;
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.sentenceiterator.UimaSentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.io.ClassPathResource;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertTrue;

/**
 * @author jeffreytang
 */
/**
 * Round-trip tests for {@link WordVectorSerializer}: loading Google word2vec models
 * (text and binary), writing word vectors back out and re-reading them, and full
 * {@link Word2Vec} model persistence including weight-table integrity and the ability
 * to resume training on a restored model.
 *
 * @author jeffreytang
 */
public class WordVectorSerializerTest {

    // Fixture files resolved from the test classpath in before().
    private File textFile, binaryFile;
    private String pathToWriteto;

    private static final Logger logger = LoggerFactory.getLogger(WordVectorSerializerTest.class);

    /**
     * Resolves the fixture files and removes any index directory a previous run left behind.
     */
    @Before
    public void before() throws Exception {
        if (textFile == null) {
            textFile = new ClassPathResource("word2vecserialization/google_news_30.txt").getFile();
        }
        if (binaryFile == null) {
            binaryFile = new ClassPathResource("word2vecserialization/google_news_30.bin.gz").getFile();
        }
        pathToWriteto = new ClassPathResource("word2vecserialization/testing_word2vec_serialization.txt").getFile()
                .getAbsolutePath();
        // Stale index from a previous run would interfere with training-based tests.
        FileUtils.deleteDirectory(new File("word2vec-index"));
    }

    /**
     * Loads the same small sample model in text and binary form and verifies both
     * loaders agree on similarity scores.
     */
    @Test
    @Ignore
    public void testLoaderTextSmall() throws Exception {
        WordVectors vecModel = WordVectorSerializer.loadGoogleModel(
                new ClassPathResource("word2vec/googleload/sample_vec.txt").getFile(), false, true);
        WordVectors vectorsBinary = WordVectorSerializer
                .loadGoogleModel(new ClassPathResource("word2vec/googleload/sample_vec.bin").getFile(), true, true);

        Collection<String> nearestBinary = vectorsBinary.wordsNearest("database", 10);
        System.out.println(nearestBinary);

        // Both loaders read the same model, so similarities must match closely.
        assertEquals(vecModel.similarity("DBMS", "DBMS's"), vectorsBinary.similarity("DBMS", "DBMS's"), 1e-1);
    }

    /**
     * Loads the 30-word Google model in text format and checks the vocabulary.
     */
    @Test
    @Ignore
    public void testLoaderText() throws IOException {
        WordVectors vec = WordVectorSerializer.loadGoogleModel(textFile, false);
        assertEquals(30, vec.vocab().numWords());
        assertTrue(vec.vocab().hasToken("Morgan_Freeman"));
        assertTrue(vec.vocab().hasToken("JA_Montalbano"));
    }

    /**
     * Loads the 30-word Google model in gzipped binary format and checks vocabulary,
     * vector length, and a couple of known component values.
     */
    @Test
    public void testLoaderBinary() throws IOException {
        WordVectors vec = WordVectorSerializer.loadGoogleModel(binaryFile, true);
        assertEquals(30, vec.vocab().numWords());
        assertTrue(vec.vocab().hasToken("Morgan_Freeman"));
        assertTrue(vec.vocab().hasToken("JA_Montalbano"));
        double[] wordVector1 = vec.getWordVector("Morgan_Freeman");
        double[] wordVector2 = vec.getWordVector("JA_Montalbano");
        assertEquals(300, wordVector1.length);
        assertEquals(300, wordVector2.length);
        assertEquals(0.044423, wordVector1[0], 1e-3);
        assertEquals(0.051964, wordVector2[0], 1e-3);
    }

    /**
     * Writes vectors from lookup table + cache to a text file and verifies the
     * re-loaded vectors match the originals.
     */
    @Test
    @Ignore
    public void testWriteWordVectors() throws IOException {
        WordVectors vec = WordVectorSerializer.loadGoogleModel(binaryFile, true);
        InMemoryLookupTable lookupTable = (InMemoryLookupTable) vec.lookupTable();
        InMemoryLookupCache lookupCache = (InMemoryLookupCache) vec.vocab();
        WordVectorSerializer.writeWordVectors(lookupTable, lookupCache, pathToWriteto);

        WordVectors wordVectors = WordVectorSerializer.loadTxtVectors(new File(pathToWriteto));
        double[] wordVector1 = wordVectors.getWordVector("Morgan_Freeman");
        double[] wordVector2 = wordVectors.getWordVector("JA_Montalbano");
        assertEquals(300, wordVector1.length);
        assertEquals(300, wordVector2.length);
        assertEquals(0.044423, wordVector1[0], 1e-3);
        assertEquals(0.051964, wordVector2[0], 1e-3);
    }

    /**
     * Writes vectors directly from a {@link Word2Vec} instance and verifies the
     * re-loaded matrices equal the originals exactly.
     */
    @Test
    @Ignore
    public void testWriteWordVectorsFromWord2Vec() throws IOException {
        WordVectors vec = WordVectorSerializer.loadGoogleModel(binaryFile, true);
        WordVectorSerializer.writeWordVectors((Word2Vec) vec, pathToWriteto);

        WordVectors wordVectors = WordVectorSerializer.loadTxtVectors(new File(pathToWriteto));
        INDArray wordVector1 = wordVectors.getWordVectorMatrix("Morgan_Freeman");
        INDArray wordVector2 = wordVectors.getWordVectorMatrix("JA_Montalbano");
        assertEquals(vec.getWordVectorMatrix("Morgan_Freeman"), wordVector1);
        assertEquals(vec.getWordVectorMatrix("JA_Montalbano"), wordVector2);
        assertEquals(300, wordVector1.length());
        assertEquals(300, wordVector2.length());
        assertEquals(0.044423, wordVector1.getDouble(0), 1e-3);
        assertEquals(0.051964, wordVector2.getDouble(0), 1e-3);
    }

    /**
     * Builds a {@link WordVectors} view from a lookup table and vocab cache and
     * verifies lookups behave like the source model.
     */
    @Test
    @Ignore
    public void testFromTableAndVocab() throws IOException {

        WordVectors vec = WordVectorSerializer.loadGoogleModel(textFile, false);
        InMemoryLookupTable lookupTable = (InMemoryLookupTable) vec.lookupTable();
        InMemoryLookupCache lookupCache = (InMemoryLookupCache) vec.vocab();

        WordVectors wordVectors = WordVectorSerializer.fromTableAndVocab(lookupTable, lookupCache);
        double[] wordVector1 = wordVectors.getWordVector("Morgan_Freeman");
        double[] wordVector2 = wordVectors.getWordVector("JA_Montalbano");
        assertEquals(300, wordVector1.length);
        assertEquals(300, wordVector2.length);
        assertEquals(0.044423, wordVector1[0], 1e-3);
        assertEquals(0.051964, wordVector2[0], 1e-3);
    }

    /**
     * End-to-end check of writeFullModel/loadFullModel: trains a model, saves it,
     * restores it, verifies configuration, vocab ordering and every weight matrix
     * survive the round trip, then resumes training on the restored model and
     * checks the two models drift apart yet stay similar.
     */
    @Test
    public void testFullModelSerialization() throws Exception {
        File inputFile = new ClassPathResource("/big/raw_sentences.txt").getFile();
        SentenceIterator iter = UimaSentenceIterator.createWithPath(inputFile.getAbsolutePath());
        // Split on white spaces in the line to get words
        TokenizerFactory t = new DefaultTokenizerFactory();
        t.setTokenPreProcessor(new CommonPreprocessor());

        InMemoryLookupCache cache = new InMemoryLookupCache(false);
        WeightLookupTable table = new InMemoryLookupTable.Builder().vectorLength(100).useAdaGrad(false)
                .negative(5.0).cache(cache).lr(0.025f).build();

        Word2Vec vec = new Word2Vec.Builder().minWordFrequency(5).iterations(1).epochs(1).layerSize(100)
                .lookupTable(table).stopWords(new ArrayList<String>()).useAdaGrad(false).negativeSample(5)
                .vocabCache(cache).seed(42)
                //                .workers(6)
                .windowSize(5).iterate(iter).tokenizerFactory(t).build();

        assertEquals(new ArrayList<String>(), vec.getStopWords());
        vec.fit();

        // Hoist the repeated cast used throughout the comparisons below.
        InMemoryLookupTable sourceTable = (InMemoryLookupTable) table;

        logger.info("Original word 0: {}", cache.wordFor(cache.wordAtIndex(0)));

        logger.info("Closest Words:");
        Collection<String> lst = vec.wordsNearest("day", 10);
        System.out.println(lst);

        WordVectorSerializer.writeFullModel(vec, "tempModel.txt");

        File modelFile = new File("tempModel.txt");
        // Make sure the temp file is cleaned up even if an assertion below fails.
        modelFile.deleteOnExit();

        assertTrue(modelFile.exists());
        assertTrue(modelFile.length() > 0);

        Word2Vec vec2 = WordVectorSerializer.loadFullModel("tempModel.txt");

        assertNotEquals(null, vec2);

        assertEquals(vec.getConfiguration(), vec2.getConfiguration());

        logger.info("Source ExpTable: {}", ArrayUtils.toString(sourceTable.getExpTable()));
        logger.info("Dest  ExpTable: {}",
                ArrayUtils.toString(((InMemoryLookupTable) vec2.getLookupTable()).getExpTable()));
        assertTrue(ArrayUtils.isEquals(sourceTable.getExpTable(),
                ((InMemoryLookupTable) vec2.getLookupTable()).getExpTable()));

        InMemoryLookupTable restoredTable = (InMemoryLookupTable) vec2.lookupTable();

        // Vocabulary ordering must be preserved by the round trip.
        assertEquals(cache.wordAtIndex(1), restoredTable.getVocab().wordAtIndex(1));
        assertEquals(cache.wordAtIndex(7), restoredTable.getVocab().wordAtIndex(7));
        assertEquals(cache.wordAtIndex(15), restoredTable.getVocab().wordAtIndex(15));

        /*
        these tests needed only to make sure INDArray equality is working properly
         */
        double[] array1 = new double[] { 0.323232325, 0.65756575, 0.12315, 0.12312315, 0.1232135, 0.12312315,
                0.4343423425, 0.15 };
        double[] array2 = new double[] { 0.423232325, 0.25756575, 0.12375, 0.12311315, 0.1232035, 0.12318315,
                0.4343493425, 0.25 };
        assertNotEquals(Nd4j.create(array1), Nd4j.create(array2));
        assertEquals(Nd4j.create(array1), Nd4j.create(array1));

        INDArray rSyn0_1 = restoredTable.getSyn0().slice(1);
        INDArray oSyn0_1 = sourceTable.getSyn0().slice(1);

        logger.info("Restored syn0: {}", rSyn0_1);
        logger.info("Original syn0: {}", oSyn0_1);

        assertEquals(oSyn0_1, rSyn0_1);

        // Verify syn0/syn1 row order matches the original for every vocab word.
        int cnt = 0;
        for (VocabWord word : cache.vocabWords()) {
            INDArray rSyn0 = restoredTable.getSyn0().slice(word.getIndex());
            INDArray oSyn0 = sourceTable.getSyn0().slice(word.getIndex());

            assertEquals(oSyn0, rSyn0);
            assertEquals(1.0, arraysSimilarity(rSyn0, oSyn0), 0.001);

            INDArray rSyn1 = restoredTable.getSyn1().slice(word.getIndex());
            INDArray oSyn1 = sourceTable.getSyn1().slice(word.getIndex());

            assertEquals(oSyn1, rSyn1);
            if (arraysSimilarity(rSyn1, oSyn1) < 0.98) {
                logger.info("Restored syn1: {}", rSyn1);
                logger.info("Original  syn1: {}", oSyn1);
            }
            // we exclude word 222 since it has syn1 full of zeroes
            if (cnt != 222)
                assertEquals(1.0, arraysSimilarity(rSyn1, oSyn1), 0.001);

            if (sourceTable.getSyn1Neg() != null) {
                INDArray rSyn1Neg = restoredTable.getSyn1Neg().slice(word.getIndex());
                INDArray oSyn1Neg = sourceTable.getSyn1Neg().slice(word.getIndex());

                assertEquals(oSyn1Neg, rSyn1Neg);
                //                assertEquals(1.0, arraysSimilarity(rSyn1Neg, oSyn1Neg), 0.001);
            }
            assertEquals(word.getHistoricalGradient(),
                    restoredTable.getVocab().wordFor(word.getWord()).getHistoricalGradient());

            cnt++;
        }

        // At this moment the whole model is transferred; resume training on the restored copy.
        iter = UimaSentenceIterator.createWithPath(inputFile.getAbsolutePath());

        vec2.setTokenizerFactory(t);
        vec2.setSentenceIter(iter);

        vec2.fit();

        INDArray day1 = vec.getWordVectorMatrix("day");
        INDArray day2 = vec2.getWordVectorMatrix("day");

        INDArray night1 = vec.getWordVectorMatrix("night");
        INDArray night2 = vec2.getWordVectorMatrix("night");

        double simD = arraysSimilarity(day1, day2);
        double simN = arraysSimilarity(night1, night2);

        logger.info("Vec1 day: {}", day1);
        logger.info("Vec2 day: {}", day2);

        logger.info("Vec1 night: {}", night1);
        logger.info("Vec2 night: {}", night2);

        logger.info("Day/day cross-model similarity: {}", simD);
        logger.info("Night/night cross-model similarity: {}", simN);

        logger.info("Vec1 day/night similarity: {}", vec.similarity("day", "night"));
        logger.info("Vec2 day/night similarity: {}", vec2.similarity("day", "night"));

        // check if cross-model values are not the same
        assertNotEquals(1.0, simD, 0.001);
        assertNotEquals(1.0, simN, 0.001);

        // check if cross-model values are still close to each other
        assertTrue(simD > 0.70);
        assertTrue(simN > 0.70);

        modelFile.delete();
    }

    /**
     * Writes word vectors to an {@link java.io.OutputStream} and verifies a round
     * trip through {@code loadTxtVectors}.
     */
    @Test
    public void testOutputStream() throws Exception {
        File file = File.createTempFile("tmp_ser", "ssa");
        file.deleteOnExit();

        File inputFile = new ClassPathResource("/big/raw_sentences.txt").getFile();
        SentenceIterator iter = new BasicLineIterator(inputFile);
        // Split on white spaces in the line to get words
        TokenizerFactory t = new DefaultTokenizerFactory();
        t.setTokenPreProcessor(new CommonPreprocessor());

        InMemoryLookupCache cache = new InMemoryLookupCache(false);
        WeightLookupTable table = new InMemoryLookupTable.Builder().vectorLength(100).useAdaGrad(false)
                .negative(5.0).cache(cache).lr(0.025f).build();

        Word2Vec vec = new Word2Vec.Builder().minWordFrequency(5).iterations(1).epochs(1).layerSize(100)
                .lookupTable(table).stopWords(new ArrayList<String>()).useAdaGrad(false).negativeSample(5)
                .vocabCache(cache).seed(42)
                //                .workers(6)
                .windowSize(5).iterate(iter).tokenizerFactory(t).build();

        assertEquals(new ArrayList<String>(), vec.getStopWords());
        vec.fit();

        INDArray day1 = vec.getWordVectorMatrix("day");

        // try-with-resources guarantees the stream is flushed and closed even on failure
        try (FileOutputStream fos = new FileOutputStream(file)) {
            WordVectorSerializer.writeWordVectors(vec, fos);
        }

        WordVectors vec2 = WordVectorSerializer.loadTxtVectors(file);

        INDArray day2 = vec2.getWordVectorMatrix("day");

        assertEquals(day1, day2);
    }

    /**
     * Cosine similarity between two vectors: 1.0 for equal arrays, -1 when either
     * unit vector cannot be computed, otherwise the dot product of the unit vectors.
     */
    private double arraysSimilarity(INDArray array1, INDArray array2) {
        if (array1.equals(array2))
            return 1.0;

        INDArray vector = Transforms.unitVec(array1);
        INDArray vector2 = Transforms.unitVec(array2);
        if (vector == null || vector2 == null)
            return -1;
        return Nd4j.getBlasWrapper().dot(vector, vector2);

    }

}