de.tudarmstadt.ukp.experiments.argumentation.sequence.annotator.OnlyFilesMatchingPredictionsReader.java Source code

Java tutorial

Introduction

Here is the source code for de.tudarmstadt.ukp.experiments.argumentation.sequence.annotator.OnlyFilesMatchingPredictionsReader.java

Source

/*
 * Copyright 2016
 * Ubiquitous Knowledge Processing (UKP) Lab
 * Technische Universität Darmstadt
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package de.tudarmstadt.ukp.experiments.argumentation.sequence.annotator;

import de.tudarmstadt.ukp.dkpro.argumentation.misc.uima.JCasUtil2;
import de.tudarmstadt.ukp.dkpro.core.api.resources.CompressionUtils;
import de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Token;
import de.tudarmstadt.ukp.dkpro.core.io.xmi.XmiReader;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVParser;
import org.apache.commons.csv.CSVRecord;
import org.apache.commons.io.IOUtils;
import org.apache.uima.UimaContext;
import org.apache.uima.cas.CAS;
import org.apache.uima.cas.CASException;
import org.apache.uima.cas.impl.XmiCasDeserializer;
import org.apache.uima.fit.descriptor.ConfigurationParameter;
import org.apache.uima.fit.factory.TypeSystemDescriptionFactory;
import org.apache.uima.jcas.JCas;
import org.apache.uima.resource.ResourceInitializationException;
import org.apache.uima.util.CasCreationUtils;
import org.xml.sax.SAXException;

import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.Reader;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.util.*;

/**
 * @author Ivan Habernal
 */
public class OnlyFilesMatchingPredictionsReader extends XmiReader {
    public static final String PARAM_TOKEN_LEVEL_PREDICTIONS_CSV_FILE = "tokenLevelPredictionsCsvFile";
    @ConfigurationParameter(name = PARAM_TOKEN_LEVEL_PREDICTIONS_CSV_FILE, mandatory = true)
    File tokenLevelPredictionsCsvFile;

    List<Sequence> sequences = new ArrayList<>();

    // cache -- sequences in a map by their first token to narrow down brute-force search later
    Map<String, List<Sequence>> firstTokenMap = new HashMap<>();

    @Override
    public void initialize(UimaContext context) throws ResourceInitializationException {
        try {
            sequences = OnlyFilesMatchingPredictionsReader.extractSequences(tokenLevelPredictionsCsvFile);
            firstTokenMap = OnlyFilesMatchingPredictionsReader.updateFirstTokenCacheFromSequences(sequences);
        } catch (IOException e) {
            throw new ResourceInitializationException(e);
        }

        // must be called at the end!
        super.initialize(context);
    }

    @Override
    protected Collection<Resource> scan(String aBase, Collection<String> aIncludes, Collection<String> aExcludes)
            throws IOException {
        Collection<Resource> resources = super.scan(aBase, aIncludes, aExcludes);

        // filter the resources according to metadata
        Iterator<Resource> resourceIterator = resources.iterator();
        while (resourceIterator.hasNext()) {
            Resource res = resourceIterator.next();

            CAS cas;
            try {
                cas = CasCreationUtils.createCas(TypeSystemDescriptionFactory.createTypeSystemDescription(), null,
                        null);
            } catch (ResourceInitializationException e) {
                throw new IOException(e);
            }

            initCas(cas, res);

            InputStream is = null;

            try {
                is = CompressionUtils.getInputStream(res.getLocation(), res.getInputStream());
                XmiCasDeserializer.deserialize(is, cas, (Boolean) this.getConfigParameterValue(PARAM_LENIENT));

                // remove it?
                Sequence sequence = findSequence(cas.getJCas(), firstTokenMap);
                if (sequence == null) {
                    resourceIterator.remove();
                }
            } catch (SAXException | CASException e) {
                throw new IOException(e);
            } finally {
                IOUtils.closeQuietly(is);
            }
        }

        System.err.println("Returning " + resources.size() + " resources");

        return resources;
    }

    /**
     * Finds the labeled sequence given exact match with tokens in the jcas.
     *
     * @param jCas  jcas
     * @param cache cache
     * @return sequence sequence
     * @throws java.util.NoSuchElementException if no such sequence exists
     */
    public static Sequence findSequence(JCas jCas, Map<String, List<Sequence>> cache)
            throws NoSuchElementException {
        List<Token> tokens = JCasUtil2.selectTokensAsList(jCas);
        String firstToken = tokens.get(0).getCoveredText();

        //        Sequence result = null;
        if (firstToken == null) {
            //            throw new IllegalStateException("First token is null");
            return null;
        }

        List<Sequence> candidateSequence = cache.get(firstToken);

        if (candidateSequence == null) {
            //            throw new IllegalStateException(
            //                    "Cannot find sentence starting with token " + firstToken);
            return null;
        }

        List<Sequence> resultCandidates = new ArrayList<>(candidateSequence);
        int tokenIndex = 1;

        //        while (resultCandidates.size() > 1) {
        while (tokenIndex < tokens.size() && resultCandidates.size() > 0) {
            String jCasToken = tokens.get(tokenIndex).getCoveredText();

            // iterate over all candidates
            Iterator<Sequence> iterator = resultCandidates.iterator();
            while (iterator.hasNext()) {
                Sequence sequence = iterator.next();

                String token = null;

                if (tokenIndex < sequence.getTokens().size()) {
                    TokenEntry tokenEntry = sequence.getTokens().get(tokenIndex);
                    token = tokenEntry.getToken();
                }

                if (token == null || !token.equals(jCasToken)) {
                    iterator.remove();
                }
            }

            tokenIndex++;
        }

        if (resultCandidates.isEmpty()) {
            return null;
        }

        Sequence result = resultCandidates.get(0);

        if (result.getTokens().size() != tokens.size()) {
            System.err.println(result);
            System.err.println(jCas.getDocumentText().substring(tokens.get(0).getBegin(),
                    tokens.get(tokens.size() - 1).getEnd()));
            throw new IllegalStateException("Number of tokens in sentence (" + tokens.size()
                    + ") differs from number of tokens in sequence (" + result.getTokens().size() + ")");
        }

        return result;
    }

    public static List<Sequence> extractSequences(File tokenLevelPredictionsCsvFile1) throws IOException {
        List<Sequence> result = new ArrayList<>();
        // load the CSV
        CSVParser csvParser = new CSVParser(new FileReader(tokenLevelPredictionsCsvFile1),
                CSVFormat.DEFAULT.withCommentMarker('#'));

        String prevSeqId = null;
        Sequence currentSequence = new Sequence();

        int tokenCounter = 0;

        for (CSVRecord csvRecord : csvParser) {
            // row for particular instance (token)
            String predictedTag = csvRecord.get(1);
            String token = csvRecord.get(2);
            String seqId = csvRecord.get(3);

            TokenEntry tokenEntry = new TokenEntry(token, predictedTag);

            // if the token belongs to the previous seqId, add it to the sequence
            if (prevSeqId == null || seqId.equals(prevSeqId)) {
                currentSequence.getTokens().add(tokenEntry);
            } else {
                // otherwise start a new sequence
                result.add(currentSequence);

                currentSequence = new Sequence();
                currentSequence.getTokens().add(tokenEntry);
            }

            prevSeqId = seqId;
            tokenCounter++;
        }

        // don't forget to add the last sequence
        result.add(currentSequence);

        System.out.println("Loaded " + result.size() + " sequences with total " + tokenCounter + " tokens.");

        return result;
    }

    public static Map<String, List<Sequence>> updateFirstTokenCacheFromSequences(List<Sequence> sequences) {
        Map<String, List<Sequence>> result = new HashMap<>();

        // update the cache
        for (Sequence sequence : sequences) {
            String firstToken = sequence.getTokens().get(0).getToken();

            if (!result.containsKey(firstToken)) {
                result.put(firstToken, new ArrayList<Sequence>());
            }

            result.get(firstToken).add(sequence);
        }

        return result;
    }
}