edu.cmu.lti.oaqa.framework.eval.passage.PassageMAPEvalAggregator.java Source code

Java tutorial

Introduction

Here is the source code for edu.cmu.lti.oaqa.framework.eval.passage.PassageMAPEvalAggregator.java

Source

/*
 *  Copyright 2012 Carnegie Mellon University
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */

package edu.cmu.lti.oaqa.framework.eval.passage;

import java.sql.SQLException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.cas.CASException;
import org.apache.uima.resource.ResourceInitializationException;
import org.apache.uima.resource.ResourceSpecifier;
import org.apache.uima.resource.Resource_ImplBase;
import org.oaqa.model.Passage;

import com.google.common.base.Function;
import com.google.common.collect.Maps;
import com.google.common.collect.Ordering;
import com.google.common.collect.Range;
import com.google.common.collect.RangeSet;
import com.google.common.collect.Sets;
import com.google.common.collect.TreeRangeSet;

import edu.cmu.lti.oaqa.ecd.BaseExperimentBuilder;
import edu.cmu.lti.oaqa.framework.eval.Key;
import edu.cmu.lti.oaqa.framework.eval.retrieval.EvaluationAggregator;
import edu.cmu.lti.oaqa.framework.eval.retrieval.EvaluationHelper;

public class PassageMAPEvalAggregator extends Resource_ImplBase implements EvaluationAggregator<Passage> {

    private PassageMAPEvalPersistenceProvider persistence;

    @Override
    public boolean initialize(ResourceSpecifier aSpecifier, Map<String, Object> tuples)
            throws ResourceInitializationException {
        String pp = (String) tuples.get("persistence-provider");
        if (pp == null) {
            throw new ResourceInitializationException(
                    new IllegalArgumentException("Must provide a parameter of type <persistence-provider>"));
        }
        this.persistence = BaseExperimentBuilder.loadProvider(pp, PassageMAPEvalPersistenceProvider.class);
        return true;
    }

    @Override
    public void update(Key key, String sequenceId, List<Passage> docs, List<Passage> gs, Ordering<Passage> ordering,
            Function<Passage, String> toIdString) throws AnalysisEngineProcessException {
        PassageMAPCounts cnt = count(docs, gs, ordering, toIdString);
        try {
            persistence.deletePassageAggrEval(key, sequenceId);
            persistence.insertPartialCounts(key, sequenceId, cnt);
        } catch (SQLException e) {
            throw new AnalysisEngineProcessException(e);
        }
    }

    private PassageMAPCounts count(List<Passage> docs, List<Passage> gs, Ordering<Passage> ordering,
            Function<Passage, String> toIdString) {
        Set<String> gsSet = EvaluationHelper.getStringSet(gs, toIdString);
        List<Passage> legalDocs = checkLegalSpan(docs);
        List<String> docsArray = EvaluationHelper.getUniqeDocIdList(legalDocs, ordering, toIdString);
        float docavep = EvaluationHelper.getAvgMAP(docsArray, gsSet);
        float psgavep = getAvgPsgMAP(legalDocs, gs);
        float psg2avep = getAvgPsg2MAP(legalDocs, gs);
        float aspavep = getAvgAspMAP(legalDocs, gs);
        return new PassageMAPCounts(docavep, psgavep, psg2avep, aspavep, 1);
    }

    private List<Passage> checkLegalSpan(List<Passage> docs) {
        List<Passage> legalDocs = new ArrayList<Passage>();
        for (Passage doc : docs) {
            try {
                legalDocs.add(isSpanLegal(doc) ? doc : new Passage(doc.getCAS().getJCas()));
            } catch (CASException e) {
                e.printStackTrace();
            }
        }
        return legalDocs;
    }

    private boolean isSpanLegal(Passage doc) {
        // TODO: Legal span check
        return true;
    }

    private float getAvgPsg2MAP(List<Passage> docs, List<Passage> gs) {
        if (gs.size() == 0) {
            return 0;
        }
        Map<String, RangeSet<Integer>> gsId2Spans = Maps.newHashMap();
        Map<String, RangeSet<Integer>> trackGsId2Spans = Maps.newHashMap();
        for (Passage g : gs) {
            String id = g.getUri();
            if (!gsId2Spans.containsKey(id)) {
                gsId2Spans.put(id, TreeRangeSet.<Integer>create());
                trackGsId2Spans.put(id, TreeRangeSet.<Integer>create());
            }
            gsId2Spans.get(g.getUri()).add(Range.closedOpen(g.getBegin(), g.getEnd()));
            trackGsId2Spans.get(g.getUri()).add(Range.closedOpen(g.getBegin(), g.getEnd()));
        }
        int totalChars = 0;
        int overlapLength = 0;
        float sumPrecision = 0;
        for (Passage doc : docs) {
            Range<Integer> docRange = Range.closedOpen(doc.getBegin(), doc.getEnd());
            String docId = doc.getUri();
            if (!gsId2Spans.containsKey(docId) || gsId2Spans.get(docId).encloses(docRange)) {
                totalChars += docRange.upperEndpoint() - docRange.lowerEndpoint();
                continue;
            }
            for (int offset = docRange.lowerEndpoint(); offset < docRange.upperEndpoint(); offset++) {
                if (gsId2Spans.containsKey(docId) && gsId2Spans.get(docId).contains(offset)) {
                    if (trackGsId2Spans.get(docId).contains(offset)) {
                        trackGsId2Spans.get(docId).remove(Range.singleton(offset));
                        // +1
                        totalChars++;
                        overlapLength++;
                        sumPrecision += (float) overlapLength / (float) totalChars;
                    }
                } else {
                    // -1
                    totalChars++;
                }
            }
        }
        int count = 0;
        for (RangeSet<Integer> spans : gsId2Spans.values()) {
            for (Range<Integer> span : spans.asRanges()) {
                count += span.upperEndpoint() - span.lowerEndpoint();
            }
        }
        return (float) sumPrecision / (float) count;
    }

    private float getAvgPsgMAP(List<Passage> docs, List<Passage> gs) {
        if (gs.size() == 0) {
            return 0;
        }
        int totalChars = 0;
        int overlapLength = 0;
        float sumPrecision = 0;
        int count = 0;
        Set<Passage> foundGoldTriplets = Sets.newHashSet();
        for (Passage doc : docs) {
            Range<Integer> docRange = Range.closedOpen(doc.getBegin(), doc.getEnd());
            totalChars += docRange.upperEndpoint() - docRange.lowerEndpoint();
            for (Passage g : gs) {
                if (!g.getUri().equals(doc.getUri()))
                    continue;
                Range<Integer> gRange = Range.closedOpen(g.getBegin(), g.getEnd());
                if (!docRange.isConnected(gRange)) {
                    continue;
                }
                Range<Integer> overlap = docRange.intersection(gRange);
                if (overlap.isEmpty()) {
                    continue;
                }
                overlapLength += overlap.upperEndpoint() - overlap.lowerEndpoint();
                sumPrecision += (float) overlapLength / (float) totalChars;
                count++;
                foundGoldTriplets.add(g);
                break;
            }
        }
        int numZeros = Sets.difference(Sets.newHashSet(gs), foundGoldTriplets).size();
        return (float) sumPrecision / (float) (count + numZeros);
    }

    private float getAvgAspMAP(List<Passage> docs, List<Passage> gs) {
        if (gs.size() == 0) {
            return 0;
        }
        float avep = 0;
        HashSet<String> aspectsFound = new HashSet<String>();
        HashSet<String> uniqueAspectsByTopic = new HashSet<String>();
        for (int j = 0; j < gs.size(); ++j) {
            String[] temp = gs.get(j).getAspects().split("\\|");
            for (int k = 0; k < temp.length; ++k)
                if (!uniqueAspectsByTopic.contains(temp[k]))
                    uniqueAspectsByTopic.add(temp[k]);
        }
        int numerator = 0;
        int denominator = 0;
        float sumPrecision = 0;

        for (int i = 0; i < docs.size(); ++i) {
            String Aspect = "";
            int releFlag = 0;
            String docid = docs.get(i).getUri();
            for (int j = 0; j < gs.size(); ++j) {
                if (!gs.get(j).getUri().equals(docid))
                    continue;
                if ((docs.get(i).getBegin() >= gs.get(j).getBegin())
                        && (docs.get(i).getEnd() <= gs.get(j).getEnd())) {
                    releFlag = 1;
                    Aspect = gs.get(j).getAspects();
                    break;
                } else if ((docs.get(i).getBegin() < gs.get(j).getBegin())
                        && (docs.get(i).getEnd() <= gs.get(j).getEnd())
                        && (docs.get(i).getEnd() >= gs.get(j).getBegin())) {
                    releFlag = 1;
                    Aspect = gs.get(j).getAspects();
                    break;
                } else if ((docs.get(i).getBegin() >= gs.get(j).getBegin())
                        && (docs.get(i).getBegin() <= gs.get(j).getEnd())
                        && (docs.get(i).getEnd() > gs.get(j).getEnd())) {
                    releFlag = 1;
                    Aspect = gs.get(j).getAspects();
                    break;
                } else if ((docs.get(i).getBegin() < gs.get(j).getBegin())
                        && (docs.get(i).getEnd() > gs.get(j).getEnd())) {
                    releFlag = 1;
                    Aspect = gs.get(j).getAspects();
                    break;
                }
            }
            if (releFlag == 1) {
                ArrayList<String> DocsaspectList = new ArrayList<String>();
                String[] split = Aspect.split("\\|");
                for (int j = 0; j < split.length; ++j) {
                    if (!DocsaspectList.contains(split[j])) {
                        DocsaspectList.add(split[j]);
                    }
                }
                int numNewAspects = 0;
                for (int j = 0; j < DocsaspectList.size(); ++j) {
                    if (!aspectsFound.contains(DocsaspectList.get(j)))
                        numNewAspects++;
                }
                if (numNewAspects > 0) {
                    numerator++;
                    denominator++;
                    sumPrecision += numNewAspects * (float) numerator / (float) denominator;
                }

                for (int j = 0; j < split.length; ++j) {
                    if (!aspectsFound.contains(split[j]))
                        aspectsFound.add(split[j]);
                }
            }
            if (releFlag == 0)
                denominator++;
        }
        avep = (float) sumPrecision / (float) uniqueAspectsByTopic.size();
        return avep;
    }
}