List of usage examples for org.deeplearning4j.models.embeddings.wordvectors WordVectors hasWord
boolean hasWord(String word);
From source file:org.knime.ext.textprocessing.dl4j.nodes.embeddings.apply.WordVectorApplyNodeModel.java
License:Open Source License
/**
 * Replaces each word contained in a document with its corresponding word vector. Words that are
 * not present in the supplied {@link WordVectors} vocabulary are skipped and recorded in
 * {@code m_unknownWords}. The output is a {@link ListCell} containing {@link ListCell}s that hold
 * the word vectors as {@link DoubleCell}s.
 *
 * @param wordVec the {@link WordVectors} model to look vectors up in
 * @param document the raw document text to tokenize and convert
 * @return {@link ListCell} of {@link ListCell}s of {@link DoubleCell}s containing converted words
 */
private ListCell replaceWordsByWordVector(final WordVectors wordVec, final String document) {
    // Tokenize with the same common preprocessing (lowercasing, punctuation stripping) used
    // during training so tokens line up with the model vocabulary.
    final TokenizerFactory factory = new DefaultTokenizerFactory();
    factory.setTokenPreProcessor(new CommonPreprocessor());
    final Tokenizer tokenizer = factory.create(document);

    final List<ListCell> vectorCells = new ArrayList<ListCell>();
    for (final String token : tokenizer.getTokens()) {
        if (token.isEmpty()) {
            continue; // preprocessing may reduce a token to the empty string
        }
        if (wordVec.hasWord(token)) {
            vectorCells.add(wordToListCell(wordVec, token));
        } else {
            m_unknownWords.add(token);
        }
    }
    return CollectionCellFactory.createListCell(vectorCells);
}
From source file:org.knime.ext.textprocessing.dl4j.nodes.embeddings.apply.WordVectorApplyNodeModel.java
License:Open Source License
/** * Calculates the mean vector of all word vectors of all words contained in a document. * * @param wordVec the {@link WordVectors} model to use * @param document the document for which the mean should be calculated * @return {@link INDArray} containing the mean vector of the document *///from w ww . jav a 2 s .co m private INDArray calculateDocumentMean(final WordVectors wordVec, final String document) { final TokenizerFactory tokenizerFac = new DefaultTokenizerFactory(); tokenizerFac.setTokenPreProcessor(new CommonPreprocessor()); final Tokenizer t = tokenizerFac.create(document); final List<String> tokens = t.getTokens(); int numberOfWordsMatchingWithVoc = 0; for (final String token : tokens) { if (wordVec.hasWord(token)) { numberOfWordsMatchingWithVoc++; } } final INDArray documentWordVectors = Nd4j.create(numberOfWordsMatchingWithVoc, wordVec.lookupTable().layerSize()); int i = 0; for (final String token : tokens) { if (!token.isEmpty()) { if (wordVec.hasWord(token)) { documentWordVectors.putRow(i, wordVec.getWordVectorMatrix(token)); i++; } else { m_unknownWords.add(token); } } } final INDArray documentMeanVector = documentWordVectors.mean(0); return documentMeanVector; }
From source file:org.knime.ext.textprocessing.dl4j.nodes.embeddings.apply.WordVectorApplyNodeModel2.java
License:Open Source License
/**
 * Converts the document in one input row into either a mean word vector or a per-word list of
 * word vectors, depending on the node's "calculate mean" setting, and appends the result to the
 * row's cells. Tokens missing from the vocabulary increment {@code m_unknownWordsCtr}; every
 * token increments {@code m_totalWordsCtr}. If no token matches the vocabulary a
 * {@link MissingCell} is appended instead.
 *
 * @param row the input row to process
 * @param documentColumnIndex index of the column holding the document text
 * @param wordVectors the {@link WordVectors} model to look vectors up in
 * @return a new {@link DefaultRow} with the converted document appended
 * @throws DataCellConversionException if the document cell cannot be converted to a String
 * @throws IllegalStateException propagated from downstream conversion
 */
private DataRow processRow(final DataRow row, final int documentColumnIndex,
        final WordVectors wordVectors) throws DataCellConversionException, IllegalStateException {
    final List<DataCell> outputCells = TableUtils.toListOfCells(row);
    final DataCell documentCell = row.getCell(documentColumnIndex);
    final String documentText = ConverterUtils.convertDataCellToJava(documentCell, String.class);

    // Split the document and keep only tokens the model knows, updating the word counters.
    final List<String> knownTokens = new ArrayList<String>();
    final Tokenizer tokenizer = new DefaultTokenizerFactory().create(documentText);
    for (final String token : tokenizer.getTokens()) {
        if (wordVectors.hasWord(token)) {
            knownTokens.add(token);
        } else {
            m_unknownWordsCtr++;
        }
        m_totalWordsCtr++;
    }

    if (knownTokens.isEmpty()) {
        // Nothing to convert — flag the row rather than producing an empty vector.
        outputCells.add(new MissingCell("No tokens in row " + row.getKey() + " match the vocabulary!"));
        return new DefaultRow(row.getKey(), outputCells);
    }

    final ListCell convertedDocument;
    if (m_calculateMean.getBooleanValue()) {
        final INDArray meanVector = calculateDocumentMean(wordVectors, knownTokens);
        convertedDocument =
            CollectionCellFactory.createListCell(NDArrayUtils.toListOfDoubleCells(meanVector));
    } else {
        convertedDocument = replaceWordsByWordVector(wordVectors, knownTokens);
    }
    outputCells.add(convertedDocument);
    return new DefaultRow(row.getKey(), outputCells);
}