edu.umd.umiacs.clip.tools.scor.TFIDF.java Source code

Java tutorial

Introduction

Here is the source code for edu.umd.umiacs.clip.tools.scor.TFIDF.java

Source

/**
 * Tools Scor
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package edu.umd.umiacs.clip.tools.scor;

import static edu.umd.umiacs.clip.tools.io.AllFiles.readAllLines;
import gnu.trove.map.TObjectIntMap;
import gnu.trove.map.hash.TObjectIntHashMap;
import java.io.IOException;
import java.io.UncheckedIOException;
import static java.lang.Integer.min;
import static java.lang.Math.sqrt;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static java.util.function.Function.identity;
import static java.util.stream.Collectors.reducing;
import static java.util.stream.Collectors.groupingBy;
import static java.util.stream.Collectors.toMap;
import java.util.stream.Stream;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.util.BytesRef;

/**
 *
 * @author Mossaab Bagdouri
 */
public class TFIDF extends Scorer {

    private transient TObjectIntMap<String> DF;
    protected final int N;
    private IndexReader ir;
    private String field;

    public TFIDF(String dfPath) {
        DF = new TObjectIntHashMap<>();
        List<String> lines = readAllLines(dfPath);
        N = new Integer(lines.get(0).split("\\s+")[1]);
        for (int i = 1; i < lines.size(); i++) {
            String[] pair = lines.get(i).split("\\s+");
            int df = new Integer(pair[1]);
            DF.put(pair[0], df);
        }
    }

    public TFIDF(IndexReader ir, String field) {
        this.ir = ir;
        this.field = field;
        N = ir.numDocs();
    }

    public TFIDF(TObjectIntMap<String> DF, int N) {
        this.DF = DF;
        this.N = N;
    }

    public TObjectIntMap<String> getDF() {
        return DF;
    }

    public int getN() {
        return N;
    }

    public Map<String, Integer> tf(String doc) {
        return Stream.of(doc.split(" ")).filter(word -> !word.isEmpty())
                .collect(groupingBy(identity(), reducing(0, e -> 1, Integer::sum)));
    }

    public Map<String, Integer> tf(int docid) {
        Map<String, Integer> map = new HashMap<>();
        try {
            Terms terms = ir.getTermVector(docid, field);
            if (terms != null) {
                TermsEnum iter = terms.iterator();
                BytesRef element;
                while ((element = iter.next()) != null) {
                    map.put(element.utf8ToString(), (int) iter.totalTermFreq());
                }
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return map;
    }

    private Map<String, Float> tfidf(Object doc) {
        return ((doc instanceof Integer) ? tf((int) doc) : tf((String) doc)).entrySet().stream()
                .collect(toMap(entry -> entry.getKey(), entry -> entry.getValue() * idf(df(entry.getKey()))));
    }

    private float dot(Map<String, Float> v1, Map<String, Float> v2) {
        return (float) v1.keySet().parallelStream().filter(v2::containsKey)
                .mapToDouble(word -> v1.get(word) * v2.get(word)).sum();
    }

    private float norm(Map<String, Float> vec) {
        return vec.isEmpty() ? 1 : (float) sqrt(vec.values().parallelStream().mapToDouble(v -> v * v).sum());
    }

    public int df(String word) {
        if (DF == null) {
            try {
                return ir.docFreq(new Term(field, word));
            } catch (IOException e) {
                throw new UncheckedIOException(e);
            }
        }
        return DF.containsKey(word) ? min(N, DF.get(word)) : 1;
    }

    private float idf(float df) {
        return (float) Math.log(N / df);
    }

    @Override
    public float scoreProcessed(Object query, Object text) {
        Map<String, Float> queryVec = (Map<String, Float>) query;
        Map<String, Float> textVec = (Map<String, Float>) text;
        return dot(queryVec, textVec) / (norm(queryVec) * norm(textVec));
    }

    @Override
    public Object getProcessedQuery(String query) {
        return tf(query);
    }

    @Override
    public Object getProcessedText(String text) {
        return tfidf(text);
    }

    @Override
    public Object getProcessedText(int docid) {
        return tfidf(docid);
    }

    public String getField() {
        return field;
    }
}