Example usage for org.deeplearning4j.models.embeddings WeightLookupTable vector

List of usage examples for org.deeplearning4j.models.embeddings WeightLookupTable vector

Introduction

On this page you can find example usage for org.deeplearning4j.models.embeddings WeightLookupTable vector.

Prototype

INDArray vector(String word);

Source Link

Usage

From source file:com.github.tteofili.p2h.Par2HierTest.java

License:Apache License

/**
 * End-to-end comparison of plain ParagraphVectors (pv) against Par2Hier (p2h)
 * on the MT papers dataset: fits both models, compares nearest-label results,
 * prints similarity / classification metrics, and writes a 2D-PCA comparison CSV.
 * Relies on the enclosing class's {@code k} and {@code method} fields.
 */
@Test
public void testP2HOnMTPapers() throws Exception {
    ParagraphVectors paragraphVectors;
    LabelAwareIterator iterator;
    TokenizerFactory tokenizerFactory;
    ClassPathResource resource = new ClassPathResource("papers/sbc");

    // build an iterator for our MT papers dataset
    iterator = new FilenamesLabelAwareIterator.Builder().addSourceFolder(resource.getFile()).build();

    tokenizerFactory = new DefaultTokenizerFactory();
    tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor());

    // per-label hierarchical vectors (p2h) and paragraph vectors (pv)
    Map<String, INDArray> hvs = new TreeMap<>();
    Map<String, INDArray> pvs = new TreeMap<>();

    paragraphVectors = new ParagraphVectors.Builder().iterate(iterator).tokenizerFactory(tokenizerFactory)
            .build();

    // fit the plain paragraph-vectors model
    paragraphVectors.fit();

    Par2Hier par2Hier = new Par2Hier(paragraphVectors, method, k);

    // fit the hierarchical model on top of the trained paragraph vectors
    par2Hier.fit();

    // label -> { nearest label under pv, nearest label under p2h }
    Map<String, String[]> comparison = new TreeMap<>();

    // extract paragraph vectors similarities
    WeightLookupTable<VocabWord> lookupTable = paragraphVectors.getLookupTable();
    List<String> labels = paragraphVectors.getLabelsSource().getLabels();
    for (String label : labels) {
        INDArray vector = lookupTable.vector(label);
        pvs.put(label, vector);
        Collection<String> strings = paragraphVectors.nearestLabels(vector, 2);
        Collection<String> hstrings = par2Hier.nearestLabels(vector, 2);
        // index 1 skips the first neighbour — presumably the label itself; TODO confirm
        String[] stringsArray = new String[2];
        stringsArray[0] = new LinkedList<>(strings).get(1);
        stringsArray[1] = new LinkedList<>(hstrings).get(1);
        comparison.put(label, stringsArray);
        hvs.put(label, par2Hier.getLookupTable().vector(label));
    }

    System.out.println("--->func(args):pv,p2h");

    // measure similarity indexes
    double[] intraDocumentSimilarity = getIntraDocumentSimilarity(comparison);
    System.out.println("ids(" + k + "," + method + "):" + Arrays.toString(intraDocumentSimilarity));
    double[] depthSimilarity = getDepthSimilarity(comparison);
    System.out.println("ds(" + k + "," + method + "):" + Arrays.toString(depthSimilarity));

    // classification: confusion matrices over label depths for both models
    Map<Integer, Map<Integer, Long>> pvCounts = new HashMap<>();
    Map<Integer, Map<Integer, Long>> p2hCounts = new HashMap<>();
    for (String label : labels) {

        INDArray vector = lookupTable.vector(label);
        int topN = 1;
        Collection<String> strings = paragraphVectors.nearestLabels(vector, topN);
        Collection<String> hstrings = par2Hier.nearestLabels(vector, topN);
        // depth of a label is the number of '.'-separated segments below the root
        int labelDepth = label.split("\\.").length - 1;

        int stringDepth = getClass(strings);
        int hstringDepth = getClass(hstrings);

        updateCM(pvCounts, labelDepth, stringDepth);
        updateCM(p2hCounts, labelDepth, hstringDepth);
    }

    ConfusionMatrix pvCM = new ConfusionMatrix(pvCounts);
    ConfusionMatrix p2hCM = new ConfusionMatrix(p2hCounts);

    System.out.println("mf1(" + k + "," + method + "):" + pvCM.getF1Measure() + "," + p2hCM.getF1Measure());
    System.out.println("acc(" + k + "," + method + "):" + pvCM.getAccuracy() + "," + p2hCM.getAccuracy());

    // create a CSV with a raw comparison; try-with-resources guarantees the
    // stream is closed (and thus flushed) even if PCA or serialization throws
    File pvFile = Files.createFile(Paths.get("target/comparison-" + k + "-" + method + ".csv")).toFile();
    try (FileOutputStream pvOutputStream = new FileOutputStream(pvFile)) {
        Map<String, INDArray> pvs2 = Par2HierUtils.svdPCA(pvs, 2);
        Map<String, INDArray> hvs2 = Par2HierUtils.svdPCA(hvs, 2);
        String pvCSV = asStrings(pvs2, hvs2);
        IOUtils.write(pvCSV, pvOutputStream);
    }
}

From source file:com.github.tteofili.p2h.Par2HierUtils.java

License:Apache License

/**
 * Computes the hierarchical vector (hv) for {@code node}, memoizing results in {@code hvs}.
 * Base case: on a leaf, hv = pv (the plain paragraph vector).
 * Recursive case: on a non-leaf node with n children, hv combines pv with k centroids
 * of the n children's hierarchical vectors, per {@code method}.
 *
 * @param lookupTable lookup table holding the plain paragraph vectors
 * @param trie        trie of node labels, used to find a node's direct children by prefix
 * @param node        label of the node whose hierarchical vector is requested
 * @param k           maximum number of centroids extracted from the children's vectors
 * @param hvs         memoization cache of already-computed hierarchical vectors (also updated here)
 * @param method      CLUSTER (SVD-truncate pv stacked with centroids) or SUM (add centroids to pv)
 * @return the hierarchical vector for {@code node}
 */
private static INDArray getPar2HierVector(WeightLookupTable<VocabWord> lookupTable, PatriciaTrie<String> trie,
        String node, int k, Map<String, INDArray> hvs, Method method) {
    // memoized result from an earlier recursive call
    if (hvs.containsKey(node)) {
        return hvs.get(node);
    }
    INDArray hv = lookupTable.vector(node);
    String[] split = node.split(REGEX);
    Collection<String> descendants = new HashSet<>();
    if (split.length == 2) {
        // collect direct children only: trie entries under this node's prefix whose
        // last '.' sits at the same position as the prefix's (i.e. same depth)
        String separator = ".";
        String prefix = node.substring(0, node.indexOf(split[1])) + separator;

        SortedMap<String, String> sortedMap = trie.prefixMap(prefix);

        for (Map.Entry<String, String> entry : sortedMap.entrySet()) {
            if (prefix.lastIndexOf(separator) == entry.getKey().lastIndexOf(separator)) {
                descendants.add(entry.getValue());
            }
        }
    } else {
        descendants = Collections.emptyList();
    }
    if (descendants.size() == 0) {
        // leaf: hierarchical vector equals the plain paragraph vector
        hvs.put(node, hv);
        return hv;
    } else {
        // stack the children's hierarchical vectors into a matrix, one row per child
        INDArray chvs = Nd4j.zeros(descendants.size(), hv.columns());
        int i = 0;
        for (String desc : descendants) {
            // child hierarchical vector
            INDArray chv = getPar2HierVector(lookupTable, trie, desc, k, hvs, method);
            chvs.putRow(i, chv);
            i++;
        }

        // reduce the children's vectors to at most k centroids via truncated SVD;
        // a single child is taken as-is
        double[][] centroids;
        if (chvs.rows() > k) {
            centroids = Par2HierUtils.getTruncatedVT(chvs, k);
        } else if (chvs.rows() == 1) {
            centroids = Par2HierUtils.getDoubles(chvs.getRow(0));
        } else {
            centroids = Par2HierUtils.getTruncatedVT(chvs, 1);
        }
        switch (method) {
        case CLUSTER:
            // stack pv on top of the centroids and SVD-truncate down to one vector
            INDArray matrix = Nd4j.zeros(centroids.length + 1, hv.columns());
            matrix.putRow(0, hv);
            for (int c = 0; c < centroids.length; c++) {
                matrix.putRow(c + 1, Nd4j.create(centroids[c]));
            }
            hv = Nd4j.create(Par2HierUtils.getTruncatedVT(matrix, 1));
            break;
        case SUM:
            // NOTE(review): addi mutates hv in place; if lookupTable.vector returns the
            // table's backing storage this also alters the stored paragraph vector — verify
            for (double[] centroid : centroids) {
                hv.addi(Nd4j.create(centroid));
            }
            break;
        }

        hvs.put(node, hv);
        return hv;
    }
}