Usage examples for org.deeplearning4j.models.paragraphvectors.ParagraphVectors#nearestLabels
public Collection<String> nearestLabels(INDArray labelVector, int topN)
From source file: com.github.tteofili.p2h.Par2HierTest.java
License: Apache License
@Test public void testP2HOnMTPapers() throws Exception { ParagraphVectors paragraphVectors; LabelAwareIterator iterator;// ww w. ja va 2 s.co m TokenizerFactory tokenizerFactory; ClassPathResource resource = new ClassPathResource("papers/sbc"); // build a iterator for our MT papers dataset iterator = new FilenamesLabelAwareIterator.Builder().addSourceFolder(resource.getFile()).build(); tokenizerFactory = new DefaultTokenizerFactory(); tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor()); Map<String, INDArray> hvs = new TreeMap<>(); Map<String, INDArray> pvs = new TreeMap<>(); paragraphVectors = new ParagraphVectors.Builder().iterate(iterator).tokenizerFactory(tokenizerFactory) .build(); // fit model paragraphVectors.fit(); Par2Hier par2Hier = new Par2Hier(paragraphVectors, method, k); // fit model par2Hier.fit(); Map<String, String[]> comparison = new TreeMap<>(); // extract paragraph vectors similarities WeightLookupTable<VocabWord> lookupTable = paragraphVectors.getLookupTable(); List<String> labels = paragraphVectors.getLabelsSource().getLabels(); for (String label : labels) { INDArray vector = lookupTable.vector(label); pvs.put(label, vector); Collection<String> strings = paragraphVectors.nearestLabels(vector, 2); Collection<String> hstrings = par2Hier.nearestLabels(vector, 2); String[] stringsArray = new String[2]; stringsArray[0] = new LinkedList<>(strings).get(1); stringsArray[1] = new LinkedList<>(hstrings).get(1); comparison.put(label, stringsArray); hvs.put(label, par2Hier.getLookupTable().vector(label)); } System.out.println("--->func(args):pv,p2h"); // measure similarity indexes double[] intraDocumentSimilarity = getIntraDocumentSimilarity(comparison); System.out.println("ids(" + k + "," + method + "):" + Arrays.toString(intraDocumentSimilarity)); double[] depthSimilarity = getDepthSimilarity(comparison); System.out.println("ds(" + k + "," + method + "):" + Arrays.toString(depthSimilarity)); // classification Map<Integer, Map<Integer, Long>> 
pvCounts = new HashMap<>(); Map<Integer, Map<Integer, Long>> p2hCounts = new HashMap<>(); for (String label : labels) { INDArray vector = lookupTable.vector(label); int topN = 1; Collection<String> strings = paragraphVectors.nearestLabels(vector, topN); Collection<String> hstrings = par2Hier.nearestLabels(vector, topN); int labelDepth = label.split("\\.").length - 1; int stringDepth = getClass(strings); int hstringDepth = getClass(hstrings); updateCM(pvCounts, labelDepth, stringDepth); updateCM(p2hCounts, labelDepth, hstringDepth); } ConfusionMatrix pvCM = new ConfusionMatrix(pvCounts); ConfusionMatrix p2hCM = new ConfusionMatrix(p2hCounts); System.out.println("mf1(" + k + "," + method + "):" + pvCM.getF1Measure() + "," + p2hCM.getF1Measure()); System.out.println("acc(" + k + "," + method + "):" + pvCM.getAccuracy() + "," + p2hCM.getAccuracy()); // create a CSV with a raw comparison File pvFile = Files.createFile(Paths.get("target/comparison-" + k + "-" + method + ".csv")).toFile(); FileOutputStream pvOutputStream = new FileOutputStream(pvFile); try { Map<String, INDArray> pvs2 = Par2HierUtils.svdPCA(pvs, 2); Map<String, INDArray> hvs2 = Par2HierUtils.svdPCA(hvs, 2); String pvCSV = asStrings(pvs2, hvs2); IOUtils.write(pvCSV, pvOutputStream); } finally { pvOutputStream.flush(); pvOutputStream.close(); } }