Example usage for org.deeplearning4j.models.paragraphvectors ParagraphVectors fit

List of usage examples for org.deeplearning4j.models.paragraphvectors ParagraphVectors fit

Introduction

On this page you can find an example usage for org.deeplearning4j.models.paragraphvectors ParagraphVectors fit.

Prototype

@Override
    public void fit() 

Source Link

Usage

From source file:com.github.tteofili.p2h.Par2HierTest.java

License:Apache License

/**
 * Trains a plain {@code ParagraphVectors} model and a {@code Par2Hier} model on the
 * MT papers dataset, then compares the two on intra-document similarity, depth
 * similarity and label-depth classification, and finally writes a PCA-reduced (2D)
 * vector comparison to a CSV file under {@code target/}.
 *
 * @throws Exception if the dataset cannot be read, training fails, or the CSV
 *                   cannot be written
 */
@Test
public void testP2HOnMTPapers() throws Exception {
    ClassPathResource resource = new ClassPathResource("papers/sbc");

    // build an iterator for our MT papers dataset
    LabelAwareIterator iterator = new FilenamesLabelAwareIterator.Builder()
            .addSourceFolder(resource.getFile()).build();

    TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
    tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor());

    Map<String, INDArray> hvs = new TreeMap<>(); // Par2Hier vector per label
    Map<String, INDArray> pvs = new TreeMap<>(); // plain paragraph vector per label

    ParagraphVectors paragraphVectors = new ParagraphVectors.Builder().iterate(iterator)
            .tokenizerFactory(tokenizerFactory).build();

    // fit the baseline paragraph-vector model
    paragraphVectors.fit();

    // fit the hierarchical variant on top of the trained paragraph vectors
    // (method and k are test-class fields configuring the Par2Hier smoothing)
    Par2Hier par2Hier = new Par2Hier(paragraphVectors, method, k);
    par2Hier.fit();

    // label -> { nearest label per PV, nearest label per P2H }
    Map<String, String[]> comparison = new TreeMap<>();

    // extract paragraph vectors similarities
    WeightLookupTable<VocabWord> lookupTable = paragraphVectors.getLookupTable();
    List<String> labels = paragraphVectors.getLabelsSource().getLabels();
    for (String label : labels) {
        INDArray vector = lookupTable.vector(label);
        pvs.put(label, vector);
        // ask for 2 neighbours: index 0 is the query label itself, index 1 the true nearest
        Collection<String> strings = paragraphVectors.nearestLabels(vector, 2);
        Collection<String> hstrings = par2Hier.nearestLabels(vector, 2);
        String[] stringsArray = new String[2];
        stringsArray[0] = new LinkedList<>(strings).get(1);
        stringsArray[1] = new LinkedList<>(hstrings).get(1);
        comparison.put(label, stringsArray);
        hvs.put(label, par2Hier.getLookupTable().vector(label));
    }

    System.out.println("--->func(args):pv,p2h");

    // measure similarity indexes
    double[] intraDocumentSimilarity = getIntraDocumentSimilarity(comparison);
    System.out.println("ids(" + k + "," + method + "):" + Arrays.toString(intraDocumentSimilarity));
    double[] depthSimilarity = getDepthSimilarity(comparison);
    System.out.println("ds(" + k + "," + method + "):" + Arrays.toString(depthSimilarity));

    // classification: predict each label's depth from its nearest neighbour's depth
    Map<Integer, Map<Integer, Long>> pvCounts = new HashMap<>();
    Map<Integer, Map<Integer, Long>> p2hCounts = new HashMap<>();
    for (String label : labels) {
        INDArray vector = lookupTable.vector(label);
        int topN = 1;
        Collection<String> strings = paragraphVectors.nearestLabels(vector, topN);
        Collection<String> hstrings = par2Hier.nearestLabels(vector, topN);
        // hierarchy depth = number of dot separators in the label
        int labelDepth = label.split("\\.").length - 1;

        int stringDepth = getClass(strings);
        int hstringDepth = getClass(hstrings);

        updateCM(pvCounts, labelDepth, stringDepth);
        updateCM(p2hCounts, labelDepth, hstringDepth);
    }

    ConfusionMatrix pvCM = new ConfusionMatrix(pvCounts);
    ConfusionMatrix p2hCM = new ConfusionMatrix(p2hCounts);

    System.out.println("mf1(" + k + "," + method + "):" + pvCM.getF1Measure() + "," + p2hCM.getF1Measure());
    System.out.println("acc(" + k + "," + method + "):" + pvCM.getAccuracy() + "," + p2hCM.getAccuracy());

    // create a CSV with a raw comparison; try-with-resources guarantees close() even if
    // PCA or serialization throws (the old finally called flush() first, so a failing
    // flush() would have skipped close() and leaked the stream)
    File pvFile = Files.createFile(Paths.get("target/comparison-" + k + "-" + method + ".csv")).toFile();
    try (FileOutputStream pvOutputStream = new FileOutputStream(pvFile)) {
        Map<String, INDArray> pvs2 = Par2HierUtils.svdPCA(pvs, 2);
        Map<String, INDArray> hvs2 = Par2HierUtils.svdPCA(hvs, 2);
        String pvCSV = asStrings(pvs2, hvs2);
        IOUtils.write(pvCSV, pvOutputStream);
    }
}

From source file:org.knime.ext.textprocessing.dl4j.nodes.embeddings.learn.d2v.Doc2VecLearnerNodeModel.java

License:Open Source License

/**
 * Trains a Doc2Vec ({@code ParagraphVectors}) model on the document/label columns of the
 * input table and wraps the trained model in a file-store port object.
 *
 * @param inObjects input port objects; index 0 must be a {@link BufferedDataTable}
 * @param exec      execution context used to create the output file store
 * @return a single-element array holding the trained word-vector port object
 * @throws Exception if the table is empty, a setting is invalid, or training fails
 */
@Override
protected WordVectorFileStorePortObject[] executeDL4JMemorySafe(final PortObject[] inObjects,
        final ExecutionContext exec) throws Exception {
    final BufferedDataTable table = (BufferedDataTable) inObjects[0];

    TableUtils.checkForEmptyTable(table);

    // column selection
    final String labelColumnName = m_wordVecParameterSettings
            .getString(WordVectorLearnerParameter.LABEL_COLUMN);
    final String documentColumnName = m_wordVecParameterSettings
            .getString(WordVectorLearnerParameter.DOCUMENT_COLUMN);

    final String sequenceAlgo = m_wordVecParameterSettings
            .getString(WordVectorLearnerParameter.SEQUENCE_LEARNING_ALGO);

    // training parameters
    final int trainingIterations = m_wordVecParameterSettings
            .getInteger(WordVectorLearnerParameter.TRAINING_ITERATIONS);
    final int minWordFrequency = m_wordVecParameterSettings
            .getInteger(WordVectorLearnerParameter.MIN_WORD_FREQUENCY);
    final int layerSize = m_wordVecParameterSettings.getInteger(WordVectorLearnerParameter.LAYER_SIZE);
    final int seed = m_wordVecParameterSettings.getInteger(WordVectorLearnerParameter.SEED);
    final double learningRate = m_wordVecParameterSettings.getDouble(WordVectorLearnerParameter.LEARNING_RATE);
    final double sampling = m_wordVecParameterSettings.getDouble(WordVectorLearnerParameter.SAMPLING);
    double negativeSampling = m_wordVecParameterSettings
            .getDouble(WordVectorLearnerParameter.NEGATIVE_SAMPLING);
    final double minLearningRate = m_wordVecParameterSettings
            .getDouble(WordVectorLearnerParameter.MIN_LEARNING_RATE);
    final int windowSize = m_wordVecParameterSettings.getInteger(WordVectorLearnerParameter.WINDOW_SIZE);
    final int epochs = m_wordVecParameterSettings.getInteger(WordVectorLearnerParameter.EPOCHS);
    final int batchSize = m_wordVecParameterSettings.getInteger(WordVectorLearnerParameter.BATCH_SIZE);

    final boolean skipMissing = m_wordVecParameterSettings
            .getBoolean(WordVectorLearnerParameter.SKIP_MISSING_CELLS);
    final boolean useHS = m_wordVecParameterSettings
            .getBoolean(WordVectorLearnerParameter.USE_HIERARCHICAL_SOFTMAX);

    final TokenizerFactory t = new DefaultTokenizerFactory();

    final BufferedDataTableLabelledDocumentIterator docIter = new BufferedDataTableLabelledDocumentIterator(
            table, documentColumnName, labelColumnName, skipMissing);

    // Hierarchical softmax and negative sampling should not be used at the same time.
    if (useHS) {
        negativeSampling = 0.0;
    }

    // build doc2vec model
    final ParagraphVectors d2v = new ParagraphVectors.Builder().learningRate(learningRate)
            .minLearningRate(minLearningRate).seed(seed).layerSize(layerSize).batchSize(batchSize)
            .windowSize(windowSize).minWordFrequency(minWordFrequency).iterations(trainingIterations)
            .epochs(epochs).iterate(docIter).tokenizerFactory(t).trainElementsRepresentation(true)
            .allowParallelTokenization(false).sequenceLearningAlgorithm(parseSequenceAlgo(sequenceAlgo))
            .useHierarchicSoftmax(useHS).negativeSample(negativeSampling).sampling(sampling).build();

    try {
        d2v.fit();
    } finally {
        // release the table iterator even when training throws (previously it leaked on failure)
        docIter.close();
    }

    final WordVectorFileStorePortObject outPortObject = WordVectorFileStorePortObject.create(d2v,
            new WordVectorPortObjectSpec(WordVectorTrainingMode.DOC2VEC),
            exec.createFileStore(UUID.randomUUID().toString()));
    return new WordVectorFileStorePortObject[] { outPortObject };
}

From source file:org.knime.ext.textprocessing.dl4j.nodes.embeddings.learn.WordVectorLearnerNodeModel.java

License:Open Source License

/**
 * Trains word vectors on the input table in the configured mode: DOC2VEC trains a
 * {@code ParagraphVectors} model over document/label columns, WORD2VEC trains a
 * {@code Word2Vec} model over the document column only.
 *
 * @param inObjects input port objects; index 0 must be a {@link BufferedDataTable}
 * @param exec      the node execution context
 * @return a single-element array holding the trained word-vector port object
 * @throws Exception if the table is empty, the mode is unknown, or training fails
 */
@Override
protected WordVectorPortObject[] execute(final PortObject[] inObjects, final ExecutionContext exec)
        throws Exception {
    final BufferedDataTable table = (BufferedDataTable) inObjects[0];

    TableUtils.checkForEmptyTable(table);

    final WordVectorTrainingMode mode = WordVectorTrainingMode.valueOf(
            m_wordVecParameterSettings.getString(WordVectorLearnerParameter.WORD_VECTOR_TRAINING_MODE));
    final String labelColumnName = m_dataParameterSettings.getString(DataParameter.LABEL_COLUMN);
    final String documentColumnName = m_dataParameterSettings.getString(DataParameter.DOCUMENT_COLUMN);
    WordVectors wordVectors = null;

    // training parameters
    final int trainingIterations = m_learnerParameterSettings.getInteger(LearnerParameter.TRAINING_ITERATIONS);
    final int minWordFrequency = m_wordVecParameterSettings
            .getInteger(WordVectorLearnerParameter.MIN_WORD_FREQUENCY);
    final int layerSize = m_wordVecParameterSettings.getInteger(WordVectorLearnerParameter.LAYER_SIZE);
    final int seed = m_learnerParameterSettings.getInteger(LearnerParameter.SEED);
    final double learningRate = m_learnerParameterSettings.getDouble(LearnerParameter.GLOBAL_LEARNING_RATE);
    final double minLearningRate = m_wordVecParameterSettings
            .getDouble(WordVectorLearnerParameter.MIN_LEARNING_RATE);
    final int windowSize = m_wordVecParameterSettings.getInteger(WordVectorLearnerParameter.WINDOW_SIZE);
    final int epochs = m_dataParameterSettings.getInteger(DataParameter.EPOCHS);
    final int batchSize = m_dataParameterSettings.getInteger(DataParameter.BATCH_SIZE);

    // sentence tokenizer and preprocessing
    final boolean usePreproc = m_wordVecParameterSettings
            .getBoolean(WordVectorLearnerParameter.USE_BASIC_PREPROCESSING);
    final TokenizerFactory t = new DefaultTokenizerFactory();
    if (usePreproc) {
        t.setTokenPreProcessor(new CommonPreprocessor());
    }

    switch (mode) {
    case DOC2VEC:
        final BufferedDataTableLabelledDocumentIterator docIter = new BufferedDataTableLabelledDocumentIterator(
                table, documentColumnName, labelColumnName);

        // build doc2vec model
        final ParagraphVectors d2v = new ParagraphVectors.Builder().learningRate(learningRate)
                .minLearningRate(minLearningRate).seed(seed).layerSize(layerSize).batchSize(batchSize)
                .windowSize(windowSize).minWordFrequency(minWordFrequency).iterations(trainingIterations)
                .epochs(epochs).iterate(docIter).trainWordVectors(true).tokenizerFactory(t)
                .allowParallelTokenization(false).build();

        try {
            d2v.fit();
        } finally {
            // close the table iterator even on failure, matching the Doc2Vec learner node
            // (previously this branch never closed it)
            docIter.close();
        }
        wordVectors = d2v;

        break;

    case WORD2VEC:
        final SentenceIterator sentenceIter = new BufferedDataTableSentenceIterator(table, documentColumnName);

        // build word2vec model
        final Word2Vec w2v = new Word2Vec.Builder().learningRate(learningRate).minLearningRate(minLearningRate)
                .seed(seed).layerSize(layerSize).batchSize(batchSize).windowSize(windowSize)
                .minWordFrequency(minWordFrequency).iterations(trainingIterations).epochs(epochs)
                .iterate(sentenceIter).tokenizerFactory(t).allowParallelTokenization(false).build();

        w2v.fit();
        wordVectors = w2v;

        break;

    default:
        throw new InvalidSettingsException("No case defined for WordVectorTrainingMode: " + mode);
    }

    final WordVectorPortObject outPortObject = new WordVectorPortObject(wordVectors, m_outputSpec);
    return new WordVectorPortObject[] { outPortObject };
}