eu.project.ttc.engines.BilingualAligner.java Source code

Java tutorial

Introduction

Here is the source code for eu.project.ttc.engines.BilingualAligner.java

Source

/*******************************************************************************
 * Copyright 2015-2016 - CNRS (Centre National de Recherche Scientifique)
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 *
 *******************************************************************************/
package eu.project.ttc.engines;

import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.base.Joiner;
import com.google.common.base.MoreObjects;
import com.google.common.base.Objects;
import com.google.common.base.Preconditions;
import com.google.common.collect.ComparisonChain;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.MinMaxPriorityQueue;
import com.google.common.collect.Sets;
import com.google.common.primitives.Ints;

import eu.project.ttc.metrics.ExplainedValue;
import eu.project.ttc.metrics.Explanation;
import eu.project.ttc.metrics.IExplanation;
import eu.project.ttc.metrics.SimilarityDistance;
import eu.project.ttc.metrics.TextExplanation;
import eu.project.ttc.models.ContextVector;
import eu.project.ttc.models.Term;
import eu.project.ttc.models.TermIndex;
import eu.project.ttc.models.index.CustomTermIndex;
import eu.project.ttc.models.index.TermIndexes;
import eu.project.ttc.models.index.TermMeasure;
import eu.project.ttc.models.index.TermValueProviders;
import eu.project.ttc.resources.BilingualDictionary;
import eu.project.ttc.utils.AlignerUtils;
import eu.project.ttc.utils.IteratorUtils;
import eu.project.ttc.utils.TermSuiteConstants;

/**
 * Aligns terms of a source terminology with translation candidates taken from
 * a target terminology.
 *
 * <p>
 * The available alignment methods (see {@link AlignmentMethod}) are:
 * <ul>
 * <li><b>dictionary</b>: the source lemma is translated with the
 * {@link BilingualDictionary} and looked up in the target terminology;</li>
 * <li><b>distributional</b>: the source context vector is translated with the
 * dictionary and compared, with a {@link SimilarityDistance}, to the context
 * vectors of the target single-word terms;</li>
 * <li><b>compositional</b>: both lemmas of a two-lemma source term are translated
 * with the dictionary and recombined into existing target terms;</li>
 * <li><b>semi-distributional</b>: one lemma is translated with the dictionary, the
 * other with the dictionary-then-distributional method, and both results are
 * recombined into existing target terms.</li>
 * </ul>
 *
 * @author Damien Cram
 *
 */
public class BilingualAligner {

    private static final Logger LOGGER = LoggerFactory.getLogger(BilingualAligner.class);
    private static final String MSG_TERM_NOT_NULL = "Source term must not be null";
    private static final String MSG_REQUIRES_SIZE_2_LEMMAS = "The term %s must have exactly two single-word terms (single-word terms: %s)";
    private static final String MSG_SEVERAL_VECTORS_NOT_COMPUTED = "Several terms have no context vectors in target terminology (nb terms with vector: {}, nb terms without vector: {})";
    private static final String ERR_VECTOR_NOT_SET = "Cannot align on term %s. Cause: context vector not set.";
    private static final String ERR_MSG_BAD_SOURCE_LEMMA_SET_SIZE = "Unexpected size for a source lemma set: %s. Expected size: 1 or 2";

    /**
     * The bonus factor applied to dictionary candidates when they are
     * merged with distributional candidates.
     *
     * NOTE(review): this constant is not used by this class. Dictionary candidates
     * only receive the WR-based specificity bonus. It is kept public for API
     * compatibility.
     */
    public static final double DICO_CANDIDATE_BONUS_FACTOR = 30;

    private final BilingualDictionary dico;
    private final TermIndex sourceTermino;
    private final TermIndex targetTermino;

    /** Similarity measure used to compare context vectors (replaceable via {@link #setDistance}). */
    private SimilarityDistance distance;

    /**
     * Creates an aligner.
     *
     * @param dico
     *          the bilingual dictionary used to translate lemmas and context vectors
     * @param sourceTermino
     *          the terminology containing the terms to align
     * @param targetTermino
     *          the terminology in which translation candidates are searched
     * @param distance
     *          the similarity measure used to compare context vectors
     */
    public BilingualAligner(BilingualDictionary dico, TermIndex sourceTermino, TermIndex targetTermino,
            SimilarityDistance distance) {
        this.dico = dico;
        this.sourceTermino = sourceTermino;
        this.targetTermino = targetTermino;
        this.distance = distance;
    }

    /**
     * Overrides the default distance measure.
     * 
     * @param distance
     *          an object implementing the similarity distance
     */
    public void setDistance(SimilarityDistance distance) {
        this.distance = distance;
    }

    /**
     * Translates the source term with the help of the dictionary
     * and computes the list of <code>nbCandidates</code> closest candidate
     * terms in the target terminology.
     *
     * <p>
     * Dictionary candidates (see {@link #alignDico}) are truncated to
     * <code>nbCandidates</code>, normalized and multiplied by a specificity bonus.
     * They are then merged with the distributional candidates
     * (see {@link #alignDistributional}). The merged list is sorted, truncated and
     * normalized once more.
     *
     * <p>
     * <code>sourceTerm</code>'s context vector must be computed. It is expected to be
     * normalized, as are the context vectors of the target terms. Target terms
     * without a context vector are ignored.
     *
     * @param sourceTerm
     *          the term to align with target term index
     * @param nbCandidates
     *          the number of {@link TranslationCandidate} to return in the returned list
     * @param minCandidateFrequency
     *          the minimum frequency of a distributional target candidate
     * @return
     *          A list of {@link TranslationCandidate} sorted by score desc. Each
     *          {@link TranslationCandidate} is a container for a target term index's term 
     *          and its translation score.
     * @throws NullPointerException
     *          if <code>sourceTerm</code> is null
     * @throws IllegalArgumentException
     *          if <code>sourceTerm</code>'s context vector is not computed
     */
    public List<TranslationCandidate> alignDicoThenDistributional(Term sourceTerm, int nbCandidates,
            int minCandidateFrequency) {
        checkNotNull(sourceTerm);
        Preconditions.checkArgument(sourceTerm.isContextVectorComputed(), ERR_VECTOR_NOT_SET,
                sourceTerm.getGroupingKey());

        // 1- direct translations of the term found in the dictionary
        // (sortTruncateNormalize returns a fresh mutable list, safe to extend)
        List<TranslationCandidate> mergedCandidates = sortTruncateNormalize(nbCandidates,
                alignDico(sourceTerm, Integer.MAX_VALUE));
        applySpecificityBonus(targetTermino, mergedCandidates);

        // 2- alignment against all single-word terms of the target terminology
        mergedCandidates.addAll(alignDistributional(sourceTerm, nbCandidates, minCandidateFrequency));

        // 3- sort, truncate and normalize the merged candidates
        return sortTruncateNormalize(nbCandidates, mergedCandidates);
    }

    /**
     * Aligns the source term against every single-word term of the target
     * terminology by comparing context vectors.
     *
     * <p>
     * The source context vector is first translated with the dictionary
     * (most-specific translation strategy). It is then compared, with the current
     * {@link SimilarityDistance}, to the context vector of each target single-word
     * term whose frequency is at least <code>minCandidateFrequency</code>. Target
     * terms without a computed context vector are skipped and counted in a warning.
     *
     * @param sourceTerm
     *          the term to align
     * @param nbCandidates
     *          the maximum number of candidates to return
     * @param minCandidateFrequency
     *          the minimum frequency of a target candidate
     * @return the best candidates sorted by score desc. When their sum is positive,
     *          scores are normalized to sum to 1.
     */
    public List<TranslationCandidate> alignDistributional(Term sourceTerm, int nbCandidates,
            int minCandidateFrequency) {
        // Bounded queue: when full, it evicts its greatest element. Under
        // TranslationCandidate's ordering (score desc), that is the lowest-score candidate.
        Queue<TranslationCandidate> bestCandidates = MinMaxPriorityQueue.maximumSize(nbCandidates).create();
        ContextVector translatedSourceVector = AlignerUtils.translateVector(sourceTerm.getContextVector(), dico,
                AlignerUtils.TRANSLATION_STRATEGY_MOST_SPECIFIC, targetTermino);
        int nbVectorsNotComputed = 0;
        int nbVectorsComputed = 0;
        for (Term targetTerm : IteratorUtils.toIterable(targetTermino.singleWordTermIterator())) {
            if (targetTerm.getFrequency() < minCandidateFrequency)
                continue;
            if (targetTerm.isContextVectorComputed()) {
                nbVectorsComputed++;
                ExplainedValue v = distance.getExplainedValue(translatedSourceVector,
                        targetTerm.getContextVector());
                bestCandidates.add(new TranslationCandidate(targetTerm, AlignmentMethod.DISTRIBUTIONAL,
                        v.getValue(), v.getExplanation()));
            } else {
                nbVectorsNotComputed++;
            }
        }
        if (nbVectorsNotComputed > 0) {
            LOGGER.warn(MSG_SEVERAL_VECTORS_NOT_COMPUTED, nbVectorsComputed, nbVectorsNotComputed);
        }

        // The queue's iteration order is unspecified, so sort explicitly.
        List<TranslationCandidate> alignedCandidates = Lists.newArrayList(bestCandidates);
        Collections.sort(alignedCandidates);
        normalizeCandidateScores(alignedCandidates);
        return alignedCandidates;
    }

    /**
     * Aligns a source term using the methods suited to its structure.
     *
     * <p>
     * The source term is decomposed into sets of single-lemma terms
     * ({@link AlignerUtils#getSingleLemmaTerms}):
     * <ul>
     * <li>each set of size 1 is aligned with {@link #alignDicoThenDistributional},
     * using a candidate pool of <code>3 * nbCandidates</code>;</li>
     * <li>each set of size 2 is aligned compositionally. If no candidate at all has
     * been found so far, it is also aligned semi-distributionally.</li>
     * </ul>
     * Candidates are then deduplicated on their target term (first occurrence kept)
     * and sorted, truncated and normalized.
     *
     * @param sourceTerm
     *          the term to align
     * @param nbCandidates
     *          the maximum number of candidates to return
     * @param minCandidateFrequency
     *          the minimum frequency of a target candidate
     * @return the candidates sorted by score desc
     * @throws NullPointerException
     *          if <code>sourceTerm</code> is null
     * @throws IllegalStateException
     *          if a single-lemma set has neither 1 nor 2 elements
     */
    public List<TranslationCandidate> align(Term sourceTerm, int nbCandidates, int minCandidateFrequency) {
        checkNotNull(sourceTerm);
        List<TranslationCandidate> mergedCandidates = Lists.newArrayList();
        List<List<Term>> sourceLemmaSets = AlignerUtils.getSingleLemmaTerms(sourceTermino, sourceTerm);
        for (List<Term> sourceLemmaSet : sourceLemmaSets) {
            Preconditions.checkState(sourceLemmaSet.size() == 1 || sourceLemmaSet.size() == 2,
                    ERR_MSG_BAD_SOURCE_LEMMA_SET_SIZE, sourceLemmaSet);
            if (sourceLemmaSet.size() == 1) {
                mergedCandidates.addAll(alignDicoThenDistributional(sourceLemmaSet.get(0), 3 * nbCandidates,
                        minCandidateFrequency));
            } else {
                Term lemmaTerm1 = sourceLemmaSet.get(0);
                Term lemmaTerm2 = sourceLemmaSet.get(1);
                try {
                    mergedCandidates.addAll(
                            alignCompositionalSize2(lemmaTerm1, lemmaTerm2, nbCandidates, minCandidateFrequency));
                } catch (RequiresSize2Exception ignored) {
                    // compositional alignment is not applicable to this lemma pair: best effort
                }
                // Fall back on the semi-distributional method only when no candidate
                // at all has been found so far, for this or any previous lemma set.
                if (mergedCandidates.isEmpty()) {
                    try {
                        mergedCandidates.addAll(alignSemiDistributionalSize2Syntagmatic(lemmaTerm1, lemmaTerm2,
                                nbCandidates, minCandidateFrequency));
                    } catch (RequiresSize2Exception ignored) {
                        // semi-distributional alignment is not applicable either: best effort
                    }
                }
            }
        }

        removeDuplicatesOnTerm(mergedCandidates);
        return sortTruncateNormalize(nbCandidates, mergedCandidates);
    }

    /*
     * Sorts the candidates (score desc) and sets their 1-based rank. Keeps the first
     * nbCandidates ones and normalizes their scores. Returns a new mutable list,
     * not a view over a larger list. Ranks are set on every input candidate,
     * including those that are truncated away.
     */
    private List<TranslationCandidate> sortTruncateNormalize(int nbCandidates,
            Collection<TranslationCandidate> candidates) {
        List<TranslationCandidate> sorted = Lists.newArrayList(candidates);
        Collections.sort(sorted);
        for (int i = 0; i < sorted.size(); i++)
            sorted.get(i).setRank(i + 1);
        List<TranslationCandidate> truncated = Lists
                .newArrayList(sorted.subList(0, Ints.min(nbCandidates, sorted.size())));
        normalizeCandidateScores(truncated);
        return truncated;
    }

    /*
     * Multiplies each candidate's score by a bonus factor derived from the WR
     * (specificity) measure of its term in the given term index. Candidates are
     * rescaled, not filtered.
     */
    private void applySpecificityBonus(TermIndex termIndex, List<TranslationCandidate> candidates) {
        for (TranslationCandidate candidate : candidates) {
            double wr = termIndex.getWRMeasure().getValue(candidate.getTerm());
            candidate.setScore(candidate.getScore() * getSpecificityBonusFactor(wr));
        }
    }

    /*
     * Maps a WR value to a score multiplier:
     * wr <= 1 -> 0.5, <= 2 -> 1, <= 10 -> 1.5, <= 100 -> 2, otherwise -> 5.
     */
    private static double getSpecificityBonusFactor(double wr) {
        if (wr <= 1)
            return 0.5;
        else if (wr <= 2)
            return 1;
        else if (wr <= 10)
            return 1.5;
        else if (wr <= 100)
            return 2;
        else
            return 5;
    }

    /**
     * Finds the target terms whose lower-cased lemma is a dictionary translation
     * of the source term's lemma. Each term is scored by the similarity between the
     * translated source context vector and its own context vector. Target terms
     * without a computed context vector are ignored.
     *
     * @param sourceTerm
     *          the term to translate
     * @param nbCandidates
     *          currently ignored: all dictionary candidates are returned
     * @return the dictionary candidates, unsorted and not normalized
     */
    public List<TranslationCandidate> alignDico(Term sourceTerm, int nbCandidates) {
        List<TranslationCandidate> dicoCandidates = Lists.newArrayList();
        Collection<String> translations = dico.getTranslations(sourceTerm.getLemma());

        ContextVector translatedSourceVector = AlignerUtils.translateVector(sourceTerm.getContextVector(), dico,
                AlignerUtils.TRANSLATION_STRATEGY_MOST_SPECIFIC, targetTermino);

        for (String candidateLemma : translations) {
            List<Term> terms = targetTermino.getCustomIndex(TermIndexes.LEMMA_LOWER_CASE).getTerms(candidateLemma);
            for (Term candidateTerm : terms) {
                // only terms with a context vector can be scored by the distance
                if (candidateTerm.isContextVectorComputed())
                    dicoCandidates.add(new TranslationCandidate(candidateTerm, AlignmentMethod.DICTIONARY,
                            distance.getValue(translatedSourceVector, candidateTerm.getContextVector()),
                            Explanation.emptyExplanation()));
            }
        }

        return dicoCandidates;
    }

    /**
     * @param sourceTerm
     *          the term to test
     * @return <code>true</code> iff at least one single-lemma decomposition of
     *          <code>sourceTerm</code> has exactly two elements
     */
    public boolean canAlignCompositional(Term sourceTerm) {
        return AlignerUtils.getSingleLemmaTerms(sourceTermino, sourceTerm).stream()
                .anyMatch(slTerms -> slTerms.size() == 2);
    }

    /**
     * Aligns the source term with the compositional method, applied to each of
     * its two-lemma decompositions.
     *
     * @param sourceTerm
     *          the term to align
     * @param nbCandidates
     *          the maximum number of candidates to return
     * @param minCandidateFrequency
     *          the minimum frequency of a target candidate
     * @return the candidates sorted by score desc
     * @throws IllegalArgumentException
     *          if {@link #canAlignCompositional(Term)} is false
     */
    public List<TranslationCandidate> alignCompositional(Term sourceTerm, int nbCandidates,
            int minCandidateFrequency) {
        Preconditions.checkArgument(canAlignCompositional(sourceTerm),
                "Cannot align <%s> with compositional method", sourceTerm);

        List<TranslationCandidate> candidates = Lists.newArrayList();
        for (List<Term> singleLemmaTerms : AlignerUtils.getSingleLemmaTerms(sourceTermino, sourceTerm)) {
            if (singleLemmaTerms.size() == 2) {
                candidates.addAll(alignCompositionalSize2(singleLemmaTerms.get(0), singleLemmaTerms.get(1),
                        nbCandidates, minCandidateFrequency));
            }
        }

        return sortTruncateNormalize(nbCandidates, candidates);
    }

    /**
     * @param sourceTerm
     *          the term to test
     * @return <code>true</code> iff at least one single-lemma decomposition of
     *          <code>sourceTerm</code> has exactly two elements
     */
    public boolean canAlignSemiDistributional(Term sourceTerm) {
        return AlignerUtils.getSingleLemmaTerms(sourceTermino, sourceTerm).stream()
                .anyMatch(slTerms -> slTerms.size() == 2);
    }

    /**
     * Aligns the source term with the semi-distributional method, applied to each
     * of its two-lemma decompositions.
     *
     * @param sourceTerm
     *          the term to align
     * @param nbCandidates
     *          the maximum number of candidates to return
     * @param minCandidateFrequency
     *          the minimum frequency of a target candidate
     * @return the candidates sorted by score desc
     * @throws IllegalArgumentException
     *          if {@link #canAlignSemiDistributional(Term)} is false
     */
    public List<TranslationCandidate> alignSemiDistributional(Term sourceTerm, int nbCandidates,
            int minCandidateFrequency) {
        Preconditions.checkArgument(canAlignSemiDistributional(sourceTerm),
                "Cannot align <%s> with semi-distributional method", sourceTerm);

        List<TranslationCandidate> candidates = Lists.newArrayList();
        for (List<Term> singleLemmaTerms : AlignerUtils.getSingleLemmaTerms(sourceTermino, sourceTerm)) {
            if (singleLemmaTerms.size() == 2) {
                candidates.addAll(alignSemiDistributionalSize2Syntagmatic(singleLemmaTerms.get(0),
                        singleLemmaTerms.get(1), nbCandidates, minCandidateFrequency));
            }
        }

        return sortTruncateNormalize(nbCandidates, candidates);
    }

    /**
     * Translates both lemmas with the dictionary and recombines their translations
     * into existing target terms (see {@link #combineCandidates}).
     *
     * @param lemmaTerm1
     *          first single-lemma term
     * @param lemmaTerm2
     *          second single-lemma term
     * @param nbCandidates
     *          the maximum number of candidates to return
     * @param minCandidateFrequency
     *          currently unused by this method
     * @return the candidates sorted by score desc
     */
    public List<TranslationCandidate> alignCompositionalSize2(Term lemmaTerm1, Term lemmaTerm2, int nbCandidates,
            int minCandidateFrequency) {
        List<TranslationCandidate> dicoCandidates1 = alignDico(lemmaTerm1, Integer.MAX_VALUE);
        List<TranslationCandidate> dicoCandidates2 = alignDico(lemmaTerm2, Integer.MAX_VALUE);
        return sortTruncateNormalize(nbCandidates,
                combineCandidates(dicoCandidates1, dicoCandidates2, AlignmentMethod.COMPOSITIONAL));
    }

    /**
     * Signals that an alignment method requires a term made of exactly two
     * single-word terms.
     */
    public static class RequiresSize2Exception extends RuntimeException {
        private static final long serialVersionUID = 1L;
        private final Term term;
        private final List<Term> swtTerms;

        public RequiresSize2Exception(Term term, List<Term> swtTerms) {
            super();
            this.term = term;
            this.swtTerms = swtTerms;
        }

        @Override
        public String getMessage() {
            return String.format(MSG_REQUIRES_SIZE_2_LEMMAS, term,
                    Joiner.on(TermSuiteConstants.COMMA).join(swtTerms));
        }
    }

    /**
     * Joins two lists of single-word candidates into the target terms that contain
     * both lemmas (in either order). The specificity (wrLog) of each combined term
     * is used as its candidate score.
     * 
     * FIXME Bad way of scoring candidates. They should be scored by similarity of context vectors with the source context vector
     * 
     * @param candidates1
     *          candidates for the first lemma
     * @param candidates2
     *          candidates for the second lemma
     * @param method
     *          the alignment method assigned to the combined candidates
     * @return the combined candidates. Duplicates are collapsed by
     *          {@link TranslationCandidate#equals(Object)}.
     */
    private Collection<TranslationCandidate> combineCandidates(Collection<TranslationCandidate> candidates1,
            Collection<TranslationCandidate> candidates2, AlignmentMethod method) {
        Collection<TranslationCandidate> combination = Sets.newHashSet();
        TermMeasure wrLog = targetTermino.getWRLogMeasure();
        wrLog.compute();
        if (candidates1.isEmpty() || candidates2.isEmpty())
            return combination;

        CustomTermIndex lemmaLemmaIndex = targetTermino.getCustomIndex(TermIndexes.WORD_COUPLE_LEMMA_LEMMA);
        for (TranslationCandidate candidate1 : candidates1) {
            String lemma1 = candidate1.getTerm().getLemma();
            for (TranslationCandidate candidate2 : candidates2) {
                String lemma2 = candidate2.getTerm().getLemma();
                /*
                 * 1- Target terms containing both lemmas, in either order. The lists
                 * returned by the index are copied so that the index is never modified.
                 */
                List<Term> candidateCombinedTerms = Lists.newArrayList(lemmaLemmaIndex.getTerms(lemma1 + "+" + lemma2));
                candidateCombinedTerms.addAll(lemmaLemmaIndex.getTerms(lemma2 + "+" + lemma1));
                if (candidateCombinedTerms.isEmpty())
                    continue;

                /*
                 * 2- Create candidates from the shortest combined terms
                 */
                for (Term t : keepTermsWithFewestLemmaLemmaKeys(candidateCombinedTerms)) {
                    double specificity = wrLog.getValue(t);
                    combination.add(new TranslationCandidate(t, method, specificity,
                            new TextExplanation(String.format("Spcificit: %.1f", specificity))));
                }
            }
        }
        return combination;
    }

    /*
     * Avoids retrieving too long terms: returns, in their original order, the
     * terms having the lowest number of lemma+lemma keys in the target terminology.
     * The input list must not be empty.
     */
    private List<Term> keepTermsWithFewestLemmaLemmaKeys(List<Term> terms) {
        final Map<Term, Collection<String>> termLemmaLemmaKeys = Maps.newHashMap();
        for (Term t : terms)
            termLemmaLemmaKeys.put(t, TermValueProviders.WORD_LEMMA_LEMMA_PROVIDER.getClasses(targetTermino, t));

        int minimumNbClasses = Integer.MAX_VALUE;
        for (Collection<String> keys : termLemmaLemmaKeys.values())
            minimumNbClasses = Math.min(minimumNbClasses, keys.size());

        List<Term> filteredTerms = Lists.newArrayList();
        for (Term t : terms)
            if (termLemmaLemmaKeys.get(t).size() == minimumNbClasses)
                filteredTerms.add(t);
        return filteredTerms;
    }

    private void checkNotNull(Term sourceTerm) {
        Preconditions.checkNotNull(sourceTerm, MSG_TERM_NOT_NULL);
    }

    /**
     * Semi-distributional alignment of a two-lemma term. It is applied in both
     * directions: each lemma is translated with the dictionary while the other one
     * is aligned with {@link #alignDicoThenDistributional}. Candidates are
     * deduplicated on their target term.
     *
     * @param lemmaTerm1
     *          first single-lemma term
     * @param lemmaTerm2
     *          second single-lemma term
     * @param nbCandidates
     *          the maximum number of candidates to return
     * @param minCandidateFrequency
     *          currently unused by this method
     * @return the candidates sorted by score desc
     */
    public List<TranslationCandidate> alignSemiDistributionalSize2Syntagmatic(Term lemmaTerm1, Term lemmaTerm2,
            int nbCandidates, int minCandidateFrequency) {
        List<TranslationCandidate> candidates = Lists.newArrayList();
        candidates.addAll(semiDistributional(lemmaTerm1, lemmaTerm2));
        candidates.addAll(semiDistributional(lemmaTerm2, lemmaTerm1));

        removeDuplicatesOnTerm(candidates);
        return sortTruncateNormalize(nbCandidates, candidates);
    }

    /*
     * Removes, in place, every candidate whose term already appeared earlier in
     * the list. The first occurrence of each term is kept.
     */
    private void removeDuplicatesOnTerm(List<TranslationCandidate> candidates) {
        Set<Term> seenTerms = Sets.newHashSet();
        candidates.removeIf(candidate -> !seenTerms.add(candidate.getTerm()));
    }

    /*
     * Combines the dictionary candidates of dicoTerm with the
     * dictionary-then-distributional candidates of vectorTerm.
     */
    private Collection<? extends TranslationCandidate> semiDistributional(Term dicoTerm, Term vectorTerm) {
        List<TranslationCandidate> dicoCandidates = alignDico(dicoTerm, Integer.MAX_VALUE);

        if (dicoCandidates.isEmpty())
            // Optimisation: no combination is possible, so skip the costly distributional alignment
            return Lists.newArrayList();

        List<TranslationCandidate> vectorCandidates = alignDicoThenDistributional(vectorTerm, Integer.MAX_VALUE, 1);
        return combineCandidates(dicoCandidates, vectorCandidates, AlignmentMethod.SEMI_DISTRIBUTIONAL);
    }

    /*
     * Divides every score by the sum of scores, when that sum is positive.
     */
    private void normalizeCandidateScores(List<TranslationCandidate> candidates) {
        double sum = 0;
        for (TranslationCandidate cand : candidates)
            sum += cand.getScore();

        if (sum > 0d)
            for (TranslationCandidate cand : candidates)
                cand.setScore(cand.getScore() / sum);
    }

    /**
     * The method that produced a {@link TranslationCandidate}.
     */
    public static enum AlignmentMethod {
        DICTIONARY("dico", "dictionary"), DISTRIBUTIONAL("dist", "distributional"), COMPOSITIONAL("comp",
                "compositional"), SEMI_DISTRIBUTIONAL("s-dist", "semi-distributional");

        private final String shortName;
        private final String longName;

        private AlignmentMethod(String shortName, String longName) {
            this.shortName = shortName;
            this.longName = longName;
        }

        public String getShortName() {
            return shortName;
        }

        public String getLongName() {
            return longName;
        }
    }

    /**
     * A target term proposed as a translation. It carries the method that produced
     * it, its score, its rank (-1 until set) and an explanation of the score.
     *
     * <p>
     * Natural ordering: score descending, then term (natural order of {@link Term}).
     * Equality and hash code are based on term and score. Since the score is
     * mutable, a candidate should not be modified while it is held in a hash-based
     * collection.
     */
    public static class TranslationCandidate implements Comparable<TranslationCandidate> {
        private final IExplanation explanation;
        private final AlignmentMethod method;
        private final Term term;
        private int rank = -1;
        private double score;

        private TranslationCandidate(Term term, AlignmentMethod method, double score, IExplanation explanation) {
            super();
            this.term = term;
            this.score = score;
            this.method = method;
            this.explanation = explanation;
        }

        public void setScore(double score) {
            this.score = score;
        }

        public void setRank(int rank) {
            this.rank = rank;
        }

        public int getRank() {
            return rank;
        }

        @Override
        public int compareTo(TranslationCandidate o) {
            // score desc first (note the swapped operands), then term asc
            return ComparisonChain.start().compare(o.score, score).compare(term, o.term).result();
        }

        public AlignmentMethod getMethod() {
            return method;
        }

        public double getScore() {
            return score;
        }

        public Term getTerm() {
            return term;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj)
                return true;
            if (!(obj instanceof TranslationCandidate))
                return false;
            TranslationCandidate other = (TranslationCandidate) obj;
            return Objects.equal(other.score, this.score) && Objects.equal(other.term, this.term);
        }

        public IExplanation getExplanation() {
            return explanation;
        }

        @Override
        public int hashCode() {
            return Objects.hashCode(term, score);
        }

        @Override
        public String toString() {
            return MoreObjects.toStringHelper(this).addValue(this.term.getGroupingKey())
                    .addValue(this.method.toString()).add("s", String.format("%.2f", this.score)).toString();
        }
    }

    /**
     * @return the bilingual dictionary used by this aligner
     */
    public BilingualDictionary getDico() {
        return this.dico;
    }
}