edu.cuhk.hccl.cmd.CosineDocumentSimilarity.java Source code

Java tutorial

Introduction

Here is the source code for edu.cuhk.hccl.cmd.CosineDocumentSimilarity.java

Source

/**
 * Copyright (C) 2014 Pengfei Liu <pfliu@se.cuhk.edu.hk>
 * The Chinese University of Hong Kong.
 *
 * This file is part of smart-search-web.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *  http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package edu.cuhk.hccl.cmd;

import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.RealVector;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.util.BytesRef;

/**
 * Computes the cosine similarity between two documents, each given as a Lucene
 * term vector ({@link Terms}).
 * <p>
 * Both term vectors are projected onto the union of their terms, so the two
 * resulting frequency vectors share the same dimensions and can be compared.
 * <p>
 * This class is based on
 * http://stackoverflow.com/questions/1844194/get-cosine-similarity-between-two-documents-in-lucene
 */
public class CosineDocumentSimilarity {

    /** Union of the terms of both documents; fixes the dimensions of v1 and v2. */
    private final Set<String> terms = new HashSet<String>();
    /** L1-normalised term-frequency vector of the first document. */
    private final RealVector v1;
    /** L1-normalised term-frequency vector of the second document. */
    private final RealVector v2;

    /**
     * Builds the frequency vectors of both documents.
     *
     * @param vector1 term vector of the first document; {@code null} is
     *                treated as a document without terms
     * @param vector2 term vector of the second document; {@code null} is
     *                treated as a document without terms
     * @throws IOException if enumerating the terms fails
     */
    public CosineDocumentSimilarity(Terms vector1, Terms vector2) throws IOException {
        // Both frequency maps must be read before either vector is built:
        // reading them is what fills `terms` with the complete union.
        Map<String, Integer> f1 = getTermFrequencies(vector1);
        Map<String, Integer> f2 = getTermFrequencies(vector2);
        v1 = toRealVector(f1);
        v2 = toRealVector(f2);
    }

    /**
     * Returns the cosine of the angle between the two document vectors.
     *
     * @return the cosine similarity, or {@code 0.0} when at least one of the
     *         documents has an all-zero vector (instead of {@code NaN})
     */
    public double getCosineSimilarity() {
        double normProduct = v1.getNorm() * v2.getNorm();
        if (normProduct == 0.0) {
            // An all-zero vector has no direction; avoid 0/0 = NaN.
            return 0.0;
        }
        return v1.dotProduct(v2) / normProduct;
    }

    /**
     * Reads every term of the given term vector with its frequency and
     * registers the term in the shared {@code terms} set.
     *
     * @param vector term vector to read; may be {@code null}
     * @return map from term text to its frequency; empty if {@code vector} is
     *         {@code null} or has no terms
     * @throws IOException if enumerating the terms fails
     */
    private Map<String, Integer> getTermFrequencies(Terms vector) throws IOException {
        Map<String, Integer> frequencies = new HashMap<String, Integer>();
        if (vector == null) {
            // No term vector available: treat as an empty document rather
            // than failing with a NullPointerException.
            return frequencies;
        }
        TermsEnum termsEnum = vector.iterator(null);
        BytesRef text;
        while ((text = termsEnum.next()) != null) {
            String term = text.utf8ToString();
            int freq = (int) termsEnum.totalTermFreq();
            frequencies.put(term, freq);
            terms.add(term);
        }
        return frequencies;
    }

    /**
     * Converts a term-frequency map into a vector over the shared
     * {@code terms} set, scaled so that its entries sum to one in absolute
     * value.
     *
     * @param map term frequencies of one document
     * @return the L1-normalised vector, or the unscaled all-zero vector when
     *         the document contributes no frequencies
     */
    private RealVector toRealVector(Map<String, Integer> map) {
        RealVector vector = new ArrayRealVector(terms.size());
        int i = 0;
        // `terms` is not modified between the two calls made by the
        // constructor, so both vectors see the same iteration order.
        for (String term : terms) {
            Integer freq = map.get(term);
            vector.setEntry(i++, freq != null ? freq : 0);
        }
        double l1Norm = vector.getL1Norm();
        if (l1Norm == 0.0) {
            // Dividing by zero would fill the vector with NaN.
            return vector;
        }
        return vector.mapDivide(l1Norm);
    }
}