edu.cmu.lti.oaqa.baseqa.answer.collective_score.scorers.ShapeDistanceCollectiveAnswerScorer.java Source code

Java tutorial

Introduction

Here is the source code for edu.cmu.lti.oaqa.baseqa.answer.collective_score.scorers.ShapeDistanceCollectiveAnswerScorer.java

Source

/*
 * Open Advancement Question Answering (OAQA) Project Copyright 2016 Carnegie Mellon University
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 * in compliance with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License
 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied. See the License for the specific language governing permissions and limitations
 * under the License.
 */

package edu.cmu.lti.oaqa.baseqa.answer.collective_score.scorers;

import com.google.common.collect.*;
import com.google.common.primitives.Doubles;
import edu.cmu.lti.oaqa.baseqa.learning_base.AbstractScorer;
import edu.cmu.lti.oaqa.type.answer.Answer;
import edu.cmu.lti.oaqa.util.TypeUtil;
import org.apache.commons.lang.StringUtils;
import org.apache.uima.jcas.JCas;
import org.apache.uima.resource.ResourceInitializationException;
import org.apache.uima.resource.ResourceSpecifier;

import java.util.List;
import java.util.Map;
import java.util.OptionalDouble;
import java.util.stream.DoubleStream;

/**
 * A collective answer scorer that calculates the edit distance between each pair of answer texts,
 * after the original texts are transformed into their shape forms, by placing [A-Z] with "A", [a-z]
 * with "a", [0-9] with "0", and [^A-Za-z0-9] with "_".
 *
 * @see EditDistanceCollectiveAnswerScorer
 *
 * @author <a href="mailto:ziy@cs.cmu.edu">Zi Yang</a> created on 5/15/15
 */
public class ShapeDistanceCollectiveAnswerScorer extends AbstractScorer<Answer> {

    private Iterable<Integer> topLimits;

    private List<Answer> answers;

    private Table<Answer, Answer, Double> distances;

    private Table<Answer, Answer, Double> bdistances;

    @SuppressWarnings("unchecked")
    @Override
    public boolean initialize(ResourceSpecifier aSpecifier, Map<String, Object> aAdditionalParams)
            throws ResourceInitializationException {
        boolean ret = super.initialize(aSpecifier, aAdditionalParams);
        topLimits = (Iterable<Integer>) getParameterValue("top-limit");
        return ret;
    }

    @SuppressWarnings("unchecked")
    @Override
    public void prepare(JCas jcas) {
        answers = TypeUtil.getRankedAnswers(jcas);
        distances = HashBasedTable.create();
        bdistances = HashBasedTable.create();
        ImmutableSet<Answer> answerSet = ImmutableSet.copyOf(answers);
        SetMultimap<Answer, String> answer2shapes = HashMultimap.create();
        answers.forEach(answer -> TypeUtil.getCandidateAnswerVariantNames(answer).stream()
                .map(ShapeDistanceCollectiveAnswerScorer::shape)
                .forEach(shape -> answer2shapes.put(answer, shape)));
        for (List<Answer> pair : Sets.cartesianProduct(answerSet, answerSet)) {
            Answer answer1 = pair.get(0);
            Answer answer2 = pair.get(1);
            if (answer1.equals(answer2)) {
                distances.put(answer1, answer2, 1.0);
                bdistances.put(answer1, answer2, 1.0);
            } else {
                OptionalDouble distance = Sets
                        .cartesianProduct(answer2shapes.get(answer1), answer2shapes.get(answer2)).stream()
                        .mapToDouble(shapepair -> getDistance(shapepair.get(0), shapepair.get(1))).min();
                if (distance.isPresent()) {
                    distances.put(answer1, answer2, 1.0 - distance.getAsDouble());
                    bdistances.put(answer1, answer2, distance.getAsDouble() == 0.0 ? 1.0 : 0.0);
                }
            }
        }
    }

    @Override
    public Map<String, Double> score(JCas jcas, Answer answer) {
        ImmutableMap.Builder<String, Double> builder = ImmutableMap.builder();
        Map<Answer, Double> neighbor2distance = distances.row(answer);
        Map<Answer, Double> neighbor2bdistance = bdistances.row(answer);
        for (int topLimit : topLimits) {
            double[] distances = answers.subList(0, Math.min(answers.size(), topLimit)).stream()
                    .mapToDouble(neighbor -> neighbor2distance.getOrDefault(neighbor, 0.0)).toArray();
            builder.put("shape-distance-min-" + topLimit, Doubles.min(distances));
            builder.put("shape-distance-max-" + topLimit, Doubles.max(distances));
            builder.put("shape-distance-avg-" + topLimit, DoubleStream.of(distances).average().orElse(0.0));
            double[] bdistances = answers.subList(0, Math.min(answers.size(), topLimit)).stream()
                    .mapToDouble(neighbor -> neighbor2bdistance.getOrDefault(neighbor, 0.0)).toArray();
            builder.put("shape-bdistance-min-" + topLimit, Doubles.min(bdistances));
            builder.put("shape-bdistance-max-" + topLimit, Doubles.max(bdistances));
            builder.put("shape-bdistance-avg-" + topLimit, DoubleStream.of(bdistances).average().orElse(0.0));
        }
        return builder.build();
    }

    private double getDistance(String text1, String text2) {
        int distance = StringUtils.getLevenshteinDistance(text1, text2);
        return (double) distance / Math.max(text1.length(), text2.length());
    }

    private static String shape(String text) {
        return text.replaceAll("[A-Z]", "A").replaceAll("[a-z]", "a").replaceAll("[0-9]", "0")
                .replaceAll("[^A-Za-z0-9]", "_");
    }

}