List of usage examples for org.deeplearning4j.models.paragraphvectors.ParagraphVectors.fit()
@Override
public void fit()
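Before the full examples, here is a minimal, self-contained sketch of a typical fit() call, assembled from the builder options that recur in the examples below. It is a sketch under assumptions, not code from any listed source file: the input path raw_sentences.txt (one sentence per line) and the DOC_ label prefix are illustrative placeholders.

import org.deeplearning4j.models.paragraphvectors.ParagraphVectors;
import org.deeplearning4j.text.documentiterator.LabelsSource;
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;

public class ParagraphVectorsFitSketch {
    public static void main(String[] args) throws Exception {
        // one sentence per line; the file name is a placeholder
        SentenceIterator iter = new BasicLineIterator("raw_sentences.txt");

        TokenizerFactory t = new DefaultTokenizerFactory();
        t.setTokenPreProcessor(new CommonPreprocessor());

        // auto-generate labels DOC_0, DOC_1, ... for unlabeled lines
        LabelsSource source = new LabelsSource("DOC_");

        ParagraphVectors vec = new ParagraphVectors.Builder()
                .minWordFrequency(1)
                .epochs(1)
                .layerSize(100)
                .learningRate(0.025)
                .labelsSource(source)
                .windowSize(5)
                .iterate(iter)
                .tokenizerFactory(t)
                .build();

        // builds the vocabulary and trains the paragraph vectors
        vec.fit();

        // retrieve the trained vector for the first line
        System.out.println(vec.getWordVectorMatrix("DOC_0"));
    }
}

After fit() returns, the model can be queried through the same lookup-table and nearest-label calls used in the examples below.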
From source file: com.github.tteofili.p2h.Par2HierTest.java
License: Apache License
@Test
public void testP2HOnMTPapers() throws Exception {
    ParagraphVectors paragraphVectors;
    LabelAwareIterator iterator;
    TokenizerFactory tokenizerFactory;
    ClassPathResource resource = new ClassPathResource("papers/sbc");

    // build an iterator for our MT papers dataset
    iterator = new FilenamesLabelAwareIterator.Builder()
            .addSourceFolder(resource.getFile())
            .build();

    tokenizerFactory = new DefaultTokenizerFactory();
    tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor());

    Map<String, INDArray> hvs = new TreeMap<>();
    Map<String, INDArray> pvs = new TreeMap<>();

    paragraphVectors = new ParagraphVectors.Builder()
            .iterate(iterator)
            .tokenizerFactory(tokenizerFactory)
            .build();

    // fit the paragraph vectors model
    paragraphVectors.fit();

    // 'method' and 'k' are fields of the (parameterized) test class
    Par2Hier par2Hier = new Par2Hier(paragraphVectors, method, k);
    // fit the Par2Hier model
    par2Hier.fit();

    Map<String, String[]> comparison = new TreeMap<>();

    // extract paragraph vector similarities
    WeightLookupTable<VocabWord> lookupTable = paragraphVectors.getLookupTable();
    List<String> labels = paragraphVectors.getLabelsSource().getLabels();
    for (String label : labels) {
        INDArray vector = lookupTable.vector(label);
        pvs.put(label, vector);
        Collection<String> strings = paragraphVectors.nearestLabels(vector, 2);
        Collection<String> hstrings = par2Hier.nearestLabels(vector, 2);
        // index 0 is the label itself, so take the second-nearest entry
        String[] stringsArray = new String[2];
        stringsArray[0] = new LinkedList<>(strings).get(1);
        stringsArray[1] = new LinkedList<>(hstrings).get(1);
        comparison.put(label, stringsArray);
        hvs.put(label, par2Hier.getLookupTable().vector(label));
    }

    System.out.println("--->func(args):pv,p2h");

    // measure similarity indexes
    double[] intraDocumentSimilarity = getIntraDocumentSimilarity(comparison);
    System.out.println("ids(" + k + "," + method + "):" + Arrays.toString(intraDocumentSimilarity));

    double[] depthSimilarity = getDepthSimilarity(comparison);
    System.out.println("ds(" + k + "," + method + "):" + Arrays.toString(depthSimilarity));

    // classification
    Map<Integer, Map<Integer, Long>> pvCounts = new HashMap<>();
    Map<Integer, Map<Integer, Long>> p2hCounts = new HashMap<>();
    for (String label : labels) {
        INDArray vector = lookupTable.vector(label);
        int topN = 1;
        Collection<String> strings = paragraphVectors.nearestLabels(vector, topN);
        Collection<String> hstrings = par2Hier.nearestLabels(vector, topN);
        int labelDepth = label.split("\\.").length - 1;
        int stringDepth = getClass(strings);
        int hstringDepth = getClass(hstrings);
        updateCM(pvCounts, labelDepth, stringDepth);
        updateCM(p2hCounts, labelDepth, hstringDepth);
    }
    ConfusionMatrix pvCM = new ConfusionMatrix(pvCounts);
    ConfusionMatrix p2hCM = new ConfusionMatrix(p2hCounts);
    System.out.println("mf1(" + k + "," + method + "):" + pvCM.getF1Measure() + "," + p2hCM.getF1Measure());
    System.out.println("acc(" + k + "," + method + "):" + pvCM.getAccuracy() + "," + p2hCM.getAccuracy());

    // create a CSV with a raw comparison
    File pvFile = Files.createFile(Paths.get("target/comparison-" + k + "-" + method + ".csv")).toFile();
    FileOutputStream pvOutputStream = new FileOutputStream(pvFile);
    try {
        Map<String, INDArray> pvs2 = Par2HierUtils.svdPCA(pvs, 2);
        Map<String, INDArray> hvs2 = Par2HierUtils.svdPCA(hvs, 2);
        String pvCSV = asStrings(pvs2, hvs2);
        IOUtils.write(pvCSV, pvOutputStream);
    } finally {
        pvOutputStream.flush();
        pvOutputStream.close();
    }
}
From source file: org.knime.ext.textprocessing.dl4j.nodes.embeddings.learn.d2v.Doc2VecLearnerNodeModel.java
License: Open Source License
@Override
protected WordVectorFileStorePortObject[] executeDL4JMemorySafe(final PortObject[] inObjects,
        final ExecutionContext exec) throws Exception {
    final BufferedDataTable table = (BufferedDataTable) inObjects[0];
    TableUtils.checkForEmptyTable(table);

    final String labelColumnName = m_wordVecParameterSettings.getString(WordVectorLearnerParameter.LABEL_COLUMN);
    final String documentColumnName = m_wordVecParameterSettings.getString(WordVectorLearnerParameter.DOCUMENT_COLUMN);
    final String sequenceAlgo = m_wordVecParameterSettings.getString(WordVectorLearnerParameter.SEQUENCE_LEARNING_ALGO);

    // training parameters
    final int trainingIterations = m_wordVecParameterSettings.getInteger(WordVectorLearnerParameter.TRAINING_ITERATIONS);
    final int minWordFrequency = m_wordVecParameterSettings.getInteger(WordVectorLearnerParameter.MIN_WORD_FREQUENCY);
    final int layerSize = m_wordVecParameterSettings.getInteger(WordVectorLearnerParameter.LAYER_SIZE);
    final int seed = m_wordVecParameterSettings.getInteger(WordVectorLearnerParameter.SEED);
    final double learningRate = m_wordVecParameterSettings.getDouble(WordVectorLearnerParameter.LEARNING_RATE);
    final double sampling = m_wordVecParameterSettings.getDouble(WordVectorLearnerParameter.SAMPLING);
    double negativeSampling = m_wordVecParameterSettings.getDouble(WordVectorLearnerParameter.NEGATIVE_SAMPLING);
    final double minLearningRate = m_wordVecParameterSettings.getDouble(WordVectorLearnerParameter.MIN_LEARNING_RATE);
    final int windowSize = m_wordVecParameterSettings.getInteger(WordVectorLearnerParameter.WINDOW_SIZE);
    final int epochs = m_wordVecParameterSettings.getInteger(WordVectorLearnerParameter.EPOCHS);
    final int batchSize = m_wordVecParameterSettings.getInteger(WordVectorLearnerParameter.BATCH_SIZE);
    final boolean skipMissing = m_wordVecParameterSettings.getBoolean(WordVectorLearnerParameter.SKIP_MISSING_CELLS);
    final boolean useHS = m_wordVecParameterSettings.getBoolean(WordVectorLearnerParameter.USE_HIERARCHICAL_SOFTMAX);

    final TokenizerFactory t = new DefaultTokenizerFactory();
    final BufferedDataTableLabelledDocumentIterator docIter =
            new BufferedDataTableLabelledDocumentIterator(table, documentColumnName, labelColumnName, skipMissing);

    // Hierarchical softmax and negative sampling should not be used at the same time.
    if (useHS) {
        negativeSampling = 0.0;
    }

    // build doc2vec model
    final ParagraphVectors d2v = new ParagraphVectors.Builder()
            .learningRate(learningRate)
            .minLearningRate(minLearningRate)
            .seed(seed)
            .layerSize(layerSize)
            .batchSize(batchSize)
            .windowSize(windowSize)
            .minWordFrequency(minWordFrequency)
            .iterations(trainingIterations)
            .epochs(epochs)
            .iterate(docIter)
            .tokenizerFactory(t)
            .trainElementsRepresentation(true)
            .allowParallelTokenization(false)
            .sequenceLearningAlgorithm(parseSequenceAlgo(sequenceAlgo))
            .useHierarchicSoftmax(useHS)
            .negativeSample(negativeSampling)
            .sampling(sampling)
            .build();

    d2v.fit();
    docIter.close();

    final WordVectorFileStorePortObject outPortObject = WordVectorFileStorePortObject.create(d2v,
            new WordVectorPortObjectSpec(WordVectorTrainingMode.DOC2VEC),
            exec.createFileStore(UUID.randomUUID().toString()));
    return new WordVectorFileStorePortObject[] { outPortObject };
}
From source file: org.knime.ext.textprocessing.dl4j.nodes.embeddings.learn.WordVectorLearnerNodeModel.java
License: Open Source License
@Override
protected WordVectorPortObject[] execute(final PortObject[] inObjects, final ExecutionContext exec)
        throws Exception {
    final BufferedDataTable table = (BufferedDataTable) inObjects[0];
    TableUtils.checkForEmptyTable(table);

    final WordVectorTrainingMode mode = WordVectorTrainingMode.valueOf(
            m_wordVecParameterSettings.getString(WordVectorLearnerParameter.WORD_VECTOR_TRAINING_MODE));
    final String labelColumnName = m_dataParameterSettings.getString(DataParameter.LABEL_COLUMN);
    final String documentColumnName = m_dataParameterSettings.getString(DataParameter.DOCUMENT_COLUMN);

    WordVectors wordVectors = null;

    // training parameters
    final int trainingIterations = m_learnerParameterSettings.getInteger(LearnerParameter.TRAINING_ITERATIONS);
    final int minWordFrequency = m_wordVecParameterSettings.getInteger(WordVectorLearnerParameter.MIN_WORD_FREQUENCY);
    final int layerSize = m_wordVecParameterSettings.getInteger(WordVectorLearnerParameter.LAYER_SIZE);
    final int seed = m_learnerParameterSettings.getInteger(LearnerParameter.SEED);
    final double learningRate = m_learnerParameterSettings.getDouble(LearnerParameter.GLOBAL_LEARNING_RATE);
    final double minLearningRate = m_wordVecParameterSettings.getDouble(WordVectorLearnerParameter.MIN_LEARNING_RATE);
    final int windowSize = m_wordVecParameterSettings.getInteger(WordVectorLearnerParameter.WINDOW_SIZE);
    final int epochs = m_dataParameterSettings.getInteger(DataParameter.EPOCHS);
    final int batchSize = m_dataParameterSettings.getInteger(DataParameter.BATCH_SIZE);

    // sentence tokenizer and preprocessing
    final boolean usePreproc = m_wordVecParameterSettings.getBoolean(WordVectorLearnerParameter.USE_BASIC_PREPROCESSING);
    final TokenizerFactory t = new DefaultTokenizerFactory();
    if (usePreproc) {
        t.setTokenPreProcessor(new CommonPreprocessor());
    }

    switch (mode) {
    case DOC2VEC:
        final LabelAwareIterator docIter =
                new BufferedDataTableLabelledDocumentIterator(table, documentColumnName, labelColumnName);
        // build doc2vec model
        final ParagraphVectors d2v = new ParagraphVectors.Builder()
                .learningRate(learningRate)
                .minLearningRate(minLearningRate)
                .seed(seed)
                .layerSize(layerSize)
                .batchSize(batchSize)
                .windowSize(windowSize)
                .minWordFrequency(minWordFrequency)
                .iterations(trainingIterations)
                .epochs(epochs)
                .iterate(docIter)
                .trainWordVectors(true)
                .tokenizerFactory(t)
                .allowParallelTokenization(false)
                .build();
        d2v.fit();
        wordVectors = d2v;
        break;
    case WORD2VEC:
        final SentenceIterator sentenceIter = new BufferedDataTableSentenceIterator(table, documentColumnName);
        // build word2vec model
        final Word2Vec w2v = new Word2Vec.Builder()
                .learningRate(learningRate)
                .minLearningRate(minLearningRate)
                .seed(seed)
                .layerSize(layerSize)
                .batchSize(batchSize)
                .windowSize(windowSize)
                .minWordFrequency(minWordFrequency)
                .iterations(trainingIterations)
                .epochs(epochs)
                .iterate(sentenceIter)
                .tokenizerFactory(t)
                .allowParallelTokenization(false)
                .build();
        w2v.fit();
        wordVectors = w2v;
        break;
    default:
        throw new InvalidSettingsException("No case defined for WordVectorTrainingMode: " + mode);
    }

    final WordVectorPortObject outPortObject = new WordVectorPortObject(wordVectors, m_outputSpec);
    return new WordVectorPortObject[] { outPortObject };
}