Usage examples for org.deeplearning4j.models.embeddings.WeightLookupTable#vector(String)
INDArray vector(String word);
From source file:com.github.tteofili.p2h.Par2HierTest.java
License:Apache License
@Test public void testP2HOnMTPapers() throws Exception { ParagraphVectors paragraphVectors;/*from ww w. ja v a 2s. c o m*/ LabelAwareIterator iterator; TokenizerFactory tokenizerFactory; ClassPathResource resource = new ClassPathResource("papers/sbc"); // build a iterator for our MT papers dataset iterator = new FilenamesLabelAwareIterator.Builder().addSourceFolder(resource.getFile()).build(); tokenizerFactory = new DefaultTokenizerFactory(); tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor()); Map<String, INDArray> hvs = new TreeMap<>(); Map<String, INDArray> pvs = new TreeMap<>(); paragraphVectors = new ParagraphVectors.Builder().iterate(iterator).tokenizerFactory(tokenizerFactory) .build(); // fit model paragraphVectors.fit(); Par2Hier par2Hier = new Par2Hier(paragraphVectors, method, k); // fit model par2Hier.fit(); Map<String, String[]> comparison = new TreeMap<>(); // extract paragraph vectors similarities WeightLookupTable<VocabWord> lookupTable = paragraphVectors.getLookupTable(); List<String> labels = paragraphVectors.getLabelsSource().getLabels(); for (String label : labels) { INDArray vector = lookupTable.vector(label); pvs.put(label, vector); Collection<String> strings = paragraphVectors.nearestLabels(vector, 2); Collection<String> hstrings = par2Hier.nearestLabels(vector, 2); String[] stringsArray = new String[2]; stringsArray[0] = new LinkedList<>(strings).get(1); stringsArray[1] = new LinkedList<>(hstrings).get(1); comparison.put(label, stringsArray); hvs.put(label, par2Hier.getLookupTable().vector(label)); } System.out.println("--->func(args):pv,p2h"); // measure similarity indexes double[] intraDocumentSimilarity = getIntraDocumentSimilarity(comparison); System.out.println("ids(" + k + "," + method + "):" + Arrays.toString(intraDocumentSimilarity)); double[] depthSimilarity = getDepthSimilarity(comparison); System.out.println("ds(" + k + "," + method + "):" + Arrays.toString(depthSimilarity)); // classification Map<Integer, Map<Integer, 
Long>> pvCounts = new HashMap<>(); Map<Integer, Map<Integer, Long>> p2hCounts = new HashMap<>(); for (String label : labels) { INDArray vector = lookupTable.vector(label); int topN = 1; Collection<String> strings = paragraphVectors.nearestLabels(vector, topN); Collection<String> hstrings = par2Hier.nearestLabels(vector, topN); int labelDepth = label.split("\\.").length - 1; int stringDepth = getClass(strings); int hstringDepth = getClass(hstrings); updateCM(pvCounts, labelDepth, stringDepth); updateCM(p2hCounts, labelDepth, hstringDepth); } ConfusionMatrix pvCM = new ConfusionMatrix(pvCounts); ConfusionMatrix p2hCM = new ConfusionMatrix(p2hCounts); System.out.println("mf1(" + k + "," + method + "):" + pvCM.getF1Measure() + "," + p2hCM.getF1Measure()); System.out.println("acc(" + k + "," + method + "):" + pvCM.getAccuracy() + "," + p2hCM.getAccuracy()); // create a CSV with a raw comparison File pvFile = Files.createFile(Paths.get("target/comparison-" + k + "-" + method + ".csv")).toFile(); FileOutputStream pvOutputStream = new FileOutputStream(pvFile); try { Map<String, INDArray> pvs2 = Par2HierUtils.svdPCA(pvs, 2); Map<String, INDArray> hvs2 = Par2HierUtils.svdPCA(hvs, 2); String pvCSV = asStrings(pvs2, hvs2); IOUtils.write(pvCSV, pvOutputStream); } finally { pvOutputStream.flush(); pvOutputStream.close(); } }
From source file:com.github.tteofili.p2h.Par2HierUtils.java
License:Apache License
/**
 * Recursively builds the hierarchical vector (hv) for {@code node}, memoizing
 * results in {@code hvs}.
 * Base case: on a leaf, hv = pv (the plain paragraph vector).
 * On a non-leaf node with n children: hv = pv combined with k centroids of the
 * n child hierarchical vectors (combination depends on {@code method}).
 *
 * @param lookupTable source of the plain paragraph vectors
 * @param trie        trie over all labels, used to find a node's descendants by prefix
 * @param node        label whose hierarchical vector is requested
 * @param k           number of centroids extracted from the children's vectors
 * @param hvs         memoization map (label -> hv); updated in place
 * @param method      CLUSTER or SUM aggregation of pv and centroids
 * @return the hierarchical vector for {@code node}
 */
private static INDArray getPar2HierVector(WeightLookupTable<VocabWord> lookupTable, PatriciaTrie<String> trie,
        String node, int k, Map<String, INDArray> hvs, Method method) {
    // memoized result from an earlier recursive visit
    if (hvs.containsKey(node)) {
        return hvs.get(node);
    }
    INDArray hv = lookupTable.vector(node);
    String[] split = node.split(REGEX);
    Collection<String> descendants = new HashSet<>();
    if (split.length == 2) {
        // collect direct children only: same number of '.' separators as "prefix + one segment"
        String separator = ".";
        String prefix = node.substring(0, node.indexOf(split[1])) + separator;
        SortedMap<String, String> sortedMap = trie.prefixMap(prefix);
        for (Map.Entry<String, String> entry : sortedMap.entrySet()) {
            if (prefix.lastIndexOf(separator) == entry.getKey().lastIndexOf(separator)) {
                descendants.add(entry.getValue());
            }
        }
    } else {
        // node doesn't match the two-part label pattern -> treated as a leaf
        descendants = Collections.emptyList();
    }
    if (descendants.size() == 0) {
        // leaf: hv is just the paragraph vector
        hvs.put(node, hv);
        return hv;
    } else {
        // stack the children's hierarchical vectors row-wise
        INDArray chvs = Nd4j.zeros(descendants.size(), hv.columns());
        int i = 0;
        for (String desc : descendants) {
            // child hierarchical vector (recursive)
            INDArray chv = getPar2HierVector(lookupTable, trie, desc, k, hvs, method);
            chvs.putRow(i, chv);
            i++;
        }
        double[][] centroids;
        if (chvs.rows() > k) {
            // more children than requested centroids: truncate via SVD to k rows
            centroids = Par2HierUtils.getTruncatedVT(chvs, k);
        } else if (chvs.rows() == 1) {
            // single child: its vector is the only centroid
            centroids = Par2HierUtils.getDoubles(chvs.getRow(0));
        } else {
            // 1 < rows <= k: collapse to a single centroid
            centroids = Par2HierUtils.getTruncatedVT(chvs, 1);
        }
        switch (method) {
        case CLUSTER:
            // stack pv + centroids and reduce the whole matrix to one vector
            INDArray matrix = Nd4j.zeros(centroids.length + 1, hv.columns());
            matrix.putRow(0, hv);
            for (int c = 0; c < centroids.length; c++) {
                matrix.putRow(c + 1, Nd4j.create(centroids[c]));
            }
            hv = Nd4j.create(Par2HierUtils.getTruncatedVT(matrix, 1));
            break;
        case SUM:
            // NOTE(review): addi is in-place, so this mutates the vector held by
            // lookupTable as well — looks intentional upstream, but confirm
            for (double[] centroid : centroids) {
                hv.addi(Nd4j.create(centroid));
            }
            break;
        }
        hvs.put(node, hv);
        return hv;
    }
}