Java tutorial
/* * Copyright (C) 2015 ikonstas * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program. If not, see <http://www.gnu.org/licenses/>. */ package pltag.parser.semantics.discriminative; import pltag.parser.semantics.discriminative.incrAnalyses.IncrementalAnalyses; import com.esotericsoftware.wildcard.Paths; import com.googlecode.concurrentlinkedhashmap.ConcurrentLinkedHashMap; import edu.stanford.nlp.ling.CoreLabel; import edu.stanford.nlp.ling.Label; import edu.stanford.nlp.trees.Tree; import fig.basic.IOUtils; import static fig.basic.IOUtils.openObjIn; import fig.basic.LogInfo; import fig.basic.Pair; import java.io.File; import java.io.ObjectInputStream; import java.io.Serializable; import java.util.ArrayList; import java.util.Arrays; import java.util.BitSet; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.Set; import pltag.parser.BuildBlock; import pltag.parser.Chart; import pltag.parser.ChartEntry; import pltag.parser.Fringe; import pltag.parser.Node; import pltag.parser.Options; import pltag.parser.ParsingTask; import pltag.parser.StringTreeAnalysis; import pltag.parser.semantics.DepTreeState; import pltag.parser.semantics.Dependencies; import pltag.parser.semantics.conll.Argument; import pltag.parser.semantics.conll.Predicate; import pltag.parser.semantics.conll.Proposition; import pltag.parser.semantics.discriminative.incrAnalyses.IncrementalAnalysis; import pltag.parser.semantics.discriminative.incrAnalyses.NbestIncrementalAnalyses; import pltag.util.Utils; /** * * @author sinantie */ public class ExtractFeatures implements Serializable { private static final long serialVersionUID = -1L; public static final int FEAT_BASELINE_SCORE = 0, FEAT_BASELINE_WORD_SCORE = 1, FEAT_ELEM_TREE = 2, FEAT_PREV_ELEM_TREE = 3, FEAT_ELEM_TREE_BIGRAM = 4, FEAT_ELEM_TREE_UNLEX = 5, FEAT_PREV_ELEM_TREE_UNLEX = 6, FEAT_ELEM_TREE_UNLEX_BIGRAM = 7, FEAT_INTEGRATION_POINT = 8, FEAT_FRINGE = 9, FEAT_NUM_FRINGE_NODES = 10, FEAT_NUM_FRINGE_PREDICT_NODES = 11, FEAT_RIGHT_BRANCH_REST = 12, FEAT_RIGHT_BRANCH_SPINE = 13, FEAT_HEAVY = 14, FEAT_NEIGHBOURS_L1 = 15, FEAT_NEIGHBOURS_L2 = 16, FEAT_CO_PAR = 17, FEAT_CO_LEN_PAR = 18, FEAT_IP_ELEM_TREE = 19, FEAT_IP_ELEM_TREE_UNLEX = 20, FEAT_WORD_L2 = 21, FEAT_WORD_L3 = 22, FEAT_SEMANTIC_SCORE = 23, FEAT_SEMANTIC_VEC_SCORES = 24, FEAT_SRL_TRIPLES = 25, FEAT_SRL_TRIPLES_POS = 26, FEAT_SRL_INCOMPLETE_TRIPLES = 27, FEAT_SRL_INCOMPLETE_TRIPLES_POS = 28, FEAT_SRL_ROLE = 29, FEAT_SRL_DEPENDENCY = 30, FEAT_SRL_DEPENDENCY_POS = 31, FEAT_SRL_ROLE_PRED = 32, FEAT_SRL_ROLE_PRED_POS = 33, FEAT_SRL_ROLE_ARG = 34, FEAT_SRL_ROLE_ARG_POS = 35, FEAT_SRL_PRED = 36, FEAT_SRL_PRED_POS = 37, FEAT_SRL_ARG = 38, FEAT_SRL_ARG_POS = 39, FEAT_SRL_FRAME = 40, FEAT_SRL_FRAME_POS = 41, NUM_SYNTACTIC_FEATURES = 12, NUM_GLOBAL_SYNTACTIC_FEATURES = 11, NUM_SRL_FEATURES = 17, NUM_SEMANTIC_FEATURES = 2, NUM_SEMANTIC_VEC_SIZE = 36; // private final CollinsHeadFinder headFinder; private transient final Options opts; @Deprecated private Map<String, Chart> incrAnalyses; private final Map<String, IncrementalAnalyses> incrAnalysesFeatures, incrAnalysesFeaturesSentLevel; private DiscriminativeFeatureIndexers featureIndexers, featureIndexersSentLevel; private final ConcurrentLinkedHashMap<String, String> unlexTreeCache; public ExtractFeatures() { this.opts = null; // unlexTreeCache = new LRUMap<String, String>(1000); incrAnalysesFeatures = new HashMap<String, IncrementalAnalyses>(); incrAnalysesFeaturesSentLevel = new HashMap<String, IncrementalAnalyses>(); unlexTreeCache = new ConcurrentLinkedHashMap.Builder<String, String>().maximumWeightedCapacity(1000) .build(); // headFinder = new CollinsHeadFinder(); } public ExtractFeatures(Options opts) { this.opts = opts; featureIndexers = new DiscriminativeFeatureIndexers(); featureIndexersSentLevel = new DiscriminativeFeatureIndexers(); incrAnalysesFeatures = new HashMap<String, IncrementalAnalyses>(); incrAnalysesFeaturesSentLevel = new HashMap<String, IncrementalAnalyses>(); // unlexTreeCache = new LRUMap<String, String>(1000); unlexTreeCache = new ConcurrentLinkedHashMap.Builder<String, String>().maximumWeightedCapacity(1000) .build(); // headFinder = new CollinsHeadFinder(); } public void execute() { boolean isOracle = opts.semanticsModel == Options.SemanticsModelType.oracle || opts.semanticsModel == Options.SemanticsModelType.oracleAllRoles; if (isOracle) { incrAnalyses = readCharts(opts.inputPaths.get(0) + ".oracle.obj.gz", true); processCharts(incrAnalyses, true); } else { File f = new File(opts.inputPaths.get(0)); Paths paths = new Paths(f.getParent(), f.getName() + "-*"); for (String path : paths.getPaths()) { incrAnalyses = readCharts(path, false); processCharts(incrAnalyses, false); } } } /** * Load feature indexers from serialized object file * * @param saveIncrAnalysesFeatures */ public void loadFeatureIndexers(boolean saveIncrAnalysesFeatures, boolean isDirectory) { if (isDirectory) // featureIndexers = DiscriminativeFeatureIndexers.loadFeatureIndexers(opts.discriminativeFeatureIndexers); featureIndexers = new DiscriminativeFeatureIndexers(opts.discriminativeFeatureIndexers); else // featureIndexers = (DiscriminativeFeatureIndexers) IOUtils.readObjFileHard(opts.discriminativeFeatureIndexers + ".obj.gz"); featureIndexers = new DiscriminativeFeatureIndexers((DiscriminativeFeatureIndexers) IOUtils .readObjFileHard(opts.discriminativeFeatureIndexers + ".obj.gz")); if (saveIncrAnalysesFeatures) { if (isDirectory) // featureIndexersSentLevel = DiscriminativeFeatureIndexers.loadFeatureIndexers(opts.discriminativeFeatureIndexers); featureIndexersSentLevel = new DiscriminativeFeatureIndexers(opts.discriminativeFeatureIndexers); else // featureIndexersSentLevel = (DiscriminativeFeatureIndexers) IOUtils.readObjFileHard(opts.discriminativeFeatureIndexers + ".obj.gz");// + "_sentenceLevel.obj.gz"); featureIndexersSentLevel = new DiscriminativeFeatureIndexers((DiscriminativeFeatureIndexers) IOUtils .readObjFileHard(opts.discriminativeFeatureIndexers + ".obj.gz"));// + "_sentenceLevel.obj.gz"); } } public void setFeatureIndexers(DiscriminativeFeatureIndexers featureIndexers) { this.featureIndexers = featureIndexers; } /** * Save feature indexers to file * * @param extend * @return */ public void saveFeatureIndexers(boolean extend) { // Commented out saving indexers for full chart features, as it is a very lengthy process // featureIndexers.saveFlattenedFeatureIndexers(opts.discriminativeFeatureIndexers + (extend ? "_extended" : "")); featureIndexersSentLevel.saveFlattenedFeatureIndexers( opts.discriminativeFeatureIndexers + "_sentenceLevel" + (extend ? "_extended" : "")); try { // Commented out saving indexers for full chart features, as it is a very lengthy process // IOUtils.writeObjFileHard(opts.discriminativeFeatureIndexers + (extend ? "_extended" : "") + ".obj.gz", featureIndexers); IOUtils.writeObjFileHard( opts.discriminativeFeatureIndexers + "_sentenceLevel" + (extend ? "_extended" : "") + ".obj.gz", featureIndexersSentLevel); //&& } catch (Exception e) { LogInfo.error(e); } // IOUtils.writeObjFileEasy(opts.incrAnalysesFeaturesPath, incrAnalysesFeatures); } // public boolean saveFeatureIndexersFlat(boolean extend) // { // String indexerFilename = opts.discriminativeFeatureIndexers + (extend ? "_extended" : "") + ".obj.gz"; // String indexerSentLevelFilename = opts.discriminativeFeatureIndexers + "_sentenceLevel" + (extend ? "_extended" : "") + ".obj.gz"; // // } public void loadFeatures(String path) { // incrAnalysesFeatures = (Map<String, IncrementalAnalyses>) IOUtils.readObjFileHard(path); try { ObjectInputStream in = openObjIn(path); Map<String, IncrementalAnalyses> map = (Map<String, IncrementalAnalyses>) in.readObject(); // for(String key : map.keySet()) // System.out.println(key); incrAnalysesFeatures.putAll(map); } catch (Exception e) { e.printStackTrace(); } } public boolean saveFeatures(String path, String pathSentLevel) { // Commented out saving of full incremental features as it is a very length process. // return IOUtils.writeObjFileEasy(path, incrAnalysesFeatures) && // IOUtils.writeObjFileEasy(pathSentLevel, incrAnalysesFeaturesSentLevel); return IOUtils.writeObjFileEasy(pathSentLevel, incrAnalysesFeaturesSentLevel); } @Deprecated private Map<String, Chart> readCharts(String inputPath, boolean isOracle) { Utils.logs( "Reading " + (isOracle ? "oracle" : "") + " incremental analyses from file " + inputPath + " ..."); Map<String, Chart> map = (Map<String, Chart>) IOUtils.readObjFileEasy(inputPath); if (map != null) { System.out.println("Done"); System.out.println(map.size()); // remove } else { LogInfo.error("Error loading " + (isOracle ? "oracle" : "") + " incremental analyses object file"); map = new HashMap(); } return map; } @Deprecated private void processCharts(Map<String, Chart> incrAnalyses, boolean isOracle) { for (Entry<String, Chart> entry : incrAnalyses.entrySet()) { incrAnalysesFeatures.put(entry.getKey(), extractFeaturesFromChart(entry.getValue(), isOracle)); } } /** * Process chart entries and extract features for prediction model, outputting features from arg,max analysis at word and sentence * level * @param name * @param chart * @param analyses */ public void processChart(String name, Chart chart, List<StringTreeAnalysis> analyses) { // word-level feature extraction: arg,max // IncrementalAnalyses incr = extractFeaturesFromChart(chart, false); // synchronized(incrAnalysesFeatures) // { // incrAnalysesFeatures.put(name, incr); // } // sentence-level feature extraction: arg,max IncrementalAnalyses incrSentence = extractFeaturesFromChartSentenceLevel(chart, analyses); synchronized (incrAnalysesFeaturesSentLevel) { incrAnalysesFeaturesSentLevel.put(name, incrSentence); } } /** * * Process chart entries and extract features for oracle model, outputting features from gold-standard (word-Level) and oracle,max F1 * (sentenceLevel analysis with max F1 score) * @param name * @param chart * @param analyses * @param maxF1Analysis */ public void processChartOracle(String name, Chart chart, List<StringTreeAnalysis> analyses, StringTreeAnalysis maxF1Analysis) { IncrementalAnalyses incr = extractFeaturesFromChart(chart, true); synchronized (incrAnalysesFeatures) { incrAnalysesFeatures.put(name, incr); } IncrementalAnalyses incrMaxF1 = extractFeaturesFromChartOracle(chart, maxF1Analysis); synchronized (incrAnalysesFeaturesSentLevel) { incrAnalysesFeaturesSentLevel.put(name, incrMaxF1); } } public IncrementalAnalyses extractFeaturesFromChart(Chart chart, boolean isOracle) { // int count = 0; int words = chart.length(); IncrementalAnalyses analyses = new IncrementalAnalyses(words); for (int i = 0; i < words; i++) { ArrayList<ChartEntry> nbest = chart.get(i); int nbestPos = 0; NbestIncrementalAnalyses nbestAnalyses; if (isOracle) { nbestAnalyses = new NbestIncrementalAnalyses(1); nbestAnalyses.addIncrementalAnalysis(extractFeaturesFromChartEntry(chart, nbest.get(0), nbest.get(0).getBestProbability(), i == words - 1, i, featureIndexers, true), 0); } else { nbestAnalyses = new NbestIncrementalAnalyses(nbest.size()); for (ChartEntry entry : nbest) { // IncrementalAnalysis incrAnalysis = extractFeaturesFromChartEntry(chart, entry, entry.getBestProbability(), i, true); // if(incrAnalysis == null) // count++; // nbestAnalyses.addIncrementalAnalysis(incrAnalysis, nbestPos++); nbestAnalyses.addIncrementalAnalysis(extractFeaturesFromChartEntry(chart, entry, entry.getBestProbability(), i == words - 1, i, featureIndexers, true), nbestPos++); } } analyses.addNbestAnalyses(nbestAnalyses, i); } // System.out.println("Null incr analyses: " + count); return analyses; } /** * * Extract features from chart at a sentence level. The method extract the backpointer traces * at the last word, and extracts features from the chart entries for each trace. The result is * an <code>IncrementalAnalyses</code> object with a fixed number of n-best entries for each word. * Each entry corresponds to the n-th part of the trace of the n-th sentence, where n is the number * of analyses at the end of the sentence. * @param chart * @param treeAnalyses * @return */ public IncrementalAnalyses extractFeaturesFromChartSentenceLevel(Chart chart, List<StringTreeAnalysis> treeAnalyses) { int words = chart.length(); int numOfsentences = treeAnalyses.size(); IncrementalAnalyses analyses = new IncrementalAnalyses(words); for (int i = 0; i < words; i++) { analyses.addNbestAnalyses(new NbestIncrementalAnalyses(numOfsentences), i); } // get built traces from analyses (the format is [wordId_1, chartSlice_1.entryId][wordId_2, chartSlize_2.entryId]... int col = 0; for (StringTreeAnalysis analysis : treeAnalyses) { Map<Integer, Trace> traceMap = reFormatTrace(analysis.getBuiltTrace()); // System.out.println("Complete analysis: " + analysis.getProbability() + " " + ParsingTask.convertStringTree(analysis)); // System.out.println(analysis.getBuiltTrace()); for (Entry<Integer, Trace> entry : traceMap.entrySet()) { Integer timestamp = entry.getKey(); Trace trace = entry.getValue(); ArrayList<ChartEntry> nbest = chart.get(timestamp); // analyses.getNbestAnalyses()[timestamp].addIncrementalAnalysis(extractFeaturesFromChartEntry(nbest.get(trace.entryPos), trace.bbPos, trace.analysisPos, timestamp, true), col); analyses.getNbestAnalyses()[timestamp].addIncrementalAnalysis( extractFeaturesFromChartEntry(chart, nbest.get(trace.entryPos), trace.analysisProb, timestamp == words - 1, timestamp, featureIndexersSentLevel, true), col); } col++; } return analyses; } /** * * Extracts features for analysis with highest evalb F1 score at the end of the sentence. * @param chart * @param analysis * @return */ public IncrementalAnalyses extractFeaturesFromChartOracle(Chart chart, StringTreeAnalysis analysis) { return extractFeaturesFromChartSentenceLevel(chart, Arrays.asList(new StringTreeAnalysis[] { analysis })); } /** * * Reformat built traces from analyses. The input format is [wordId_1, chartSlice_1.entryId][wordId_2, chartSlice_2.entryId]... * and output format is a map with the word_id as the key and chartSlice.entryId as the value. * @param trace * @return */ private Map<Integer, Trace> reFormatTrace(String trace) { Map<Integer, Trace> map = new HashMap(); for (String token : trace.split("]")) { String[] pair = token.split(","); Integer wordId = Integer.valueOf(pair[0].substring(1)); String[] ceIndexBbPos = pair[1].substring(1).split(" "); Integer entryId = Integer.valueOf(ceIndexBbPos[0]); Integer bbPos = Integer.valueOf(ceIndexBbPos[1]); // Integer analysisPos = Integer.valueOf(ceIndexBbPos[2]); Double analysisProb = Double.valueOf(ceIndexBbPos[2]); // map.put(wordId, new Trace(entryId, bbPos, analysisProb)); map.put(wordId, new Trace(entryId, bbPos, analysisProb)); } return map; } public IncrementalAnalysis extractFeaturesFromChartEntry(Chart chart, ChartEntry entry, double analysisProb, boolean endOfSent, int timestamp, boolean train) { return extractFeaturesFromChartEntry(chart, entry, analysisProb, endOfSent, timestamp, featureIndexers, train); } public IncrementalAnalysis extractFeaturesFromChartEntry(Chart chart, ChartEntry entry, double analysisProb, boolean endOfSent, int timestamp, DiscriminativeFeatureIndexers featureIndexers, boolean train) { IncrementalAnalysis analysis = new IncrementalAnalysis(); // analysis.setBaselineScore(entry.getBestProbability()); // analysis.setBaselineWordScore(entry.getBuildBlocks().get(0).getProbability()); if (!extractSyntacticFeatures(chart, entry, analysisProb, endOfSent, analysis, timestamp, featureIndexers, train)) return null; DepTreeState treeState = ((DepTreeState) entry.getTreeState()); extractSrlFeatures(treeState.getDependencies(), analysis, timestamp, featureIndexers, train); // extractSemanticFeatures(analysis); // TODO: implement only if we need semantics on the fly // double partialSyntacticScore = computeSyntacticScore(entry); // double partialSrlScore = computeSemanticScore(treeState.getDependencies()); return analysis; } private boolean extractSyntacticFeatures(Chart chart, ChartEntry entry, double analysisProb, boolean endOfSent, IncrementalAnalysis analysis, int timestamp, DiscriminativeFeatureIndexers featureIndexers, boolean train) { // contains information on the current tree and pointer to the previous tree List<BuildBlock> bbList = entry.getBuildBlocks(); if (!bbList.isEmpty()) { Pair<StringTreeAnalysis, BuildBlock> currentAnalysisPair = getCurrentAnalysis(bbList, analysisProb); if (currentAnalysisPair == null) { ArrayList<ArrayList<BuildBlock>>[] emptyBBlist = ParsingTask.getEmptyBBList(chart, entry); // create backpointers on buildblocks to be exploited in the feature extraction process chart.setTime(System.currentTimeMillis()); chart.getAnalyses(entry, emptyBBlist, "", true, analysisProb, false); // LogInfo.error("Current buildblock not found"); currentAnalysisPair = getCurrentAnalysis(bbList, analysisProb); if (currentAnalysisPair == null) return false; } BuildBlock curBb = currentAnalysisPair.getSecond(); StringTreeAnalysis sta = currentAnalysisPair.getFirst(); // BuildBlock curBb = bbList.get(buildBlockPos); // StringTreeAnalysis sta = curBb.getAnalysis(analysisProb); // System.out.println("\t" + Utils.normaliseTree(ParsingTask.convertStringTree(sta))); analysis.setBaselineScore(sta.getProbability()); analysis.setBaselineWordScore(curBb.getProbability()); extractTreeFeatures(curBb, analysis, false, featureIndexers, train); extractFringeFeatures(entry, analysis, featureIndexers, train); // get previous entry (if any) if (timestamp == 0) { analysis.setPrevElemTree("S", featureIndexers, train); analysis.setPrevElemTreeUnlex("S", featureIndexers, train); } else { List<BuildBlock> prevBbList = curBb.getPrevChartEntry().getBuildBlocks(); if (!prevBbList.isEmpty()) { Pair<StringTreeAnalysis, BuildBlock> prevAnalysisPair = getCurrentAnalysis(prevBbList, analysis.getBaselineScore() - analysis.getBaselineWordScore()); if (prevAnalysisPair == null) { // LogInfo.error("Current prev buildblock not found"); return false; } // System.out.println("\t\t" + Utils.normaliseTree(ParsingTask.convertStringTree(prevAnalysisPair.getFirst()))); extractTreeFeatures(prevAnalysisPair.getSecond(), analysis, true, featureIndexers, train); } } // get bigrams analysis.setElemTreeBigram(analysis.getPrevElemTree() + " " + analysis.getElemTree(), featureIndexers, train); analysis.setElemTreeUnlexBigram(analysis.getPrevElemTreeUnlex() + " " + analysis.getElemTreeUnlex(), featureIndexers, train); // get integration point analysis.setIntegrationPoint(curBb.getIpi().getCategory(), featureIndexers, train); // global syntactic features if (opts.useGlobalSyntacticFeatures) extractStringTreeAnalysisFeatures(sta, analysis, timestamp + 1, endOfSent, featureIndexers, train); } return true; } private Pair<StringTreeAnalysis, BuildBlock> getCurrentAnalysis(List<BuildBlock> bbList, double probability) { // Set<Pair<StringTreeAnalysis, BuildBlock>> set = new TreeSet(new Pair.FirstComparator()); // List<Pair<StringTreeAnalysis, BuildBlock>> list = new ArrayList(); for (BuildBlock bb : bbList) { for (Pair<StringTreeAnalysis, BuildBlock> pair : bb.getAnalyses()) { if (Math.abs(pair.getFirst().getProbability() - probability) < 0.00000001) { // System.out.println("\t" + Utils.normaliseTree(ParsingTask.convertStringTree(pair.getFirst()))); return pair; } } } return null; } private void extractTreeFeatures(BuildBlock bb, IncrementalAnalysis analysis, boolean prev, DiscriminativeFeatureIndexers featureIndexers, boolean train) { // get elementary tree // String elemTree = bb.getElemTreeString(); String elemTree = bb.getElemTreeString(); if (!prev) analysis.setElemTree(elemTree, featureIndexers, train); else analysis.setPrevElemTree(elemTree, featureIndexers, train); // get unlex elementary tree // String elemTreeUnlex = bb.getElemTree().toStringUnlex(false); // TODO: Use LRU cache String elemTreeUnlex = getUnlexTree(elemTree, bb); if (!prev) analysis.setElemTreeUnlex(elemTreeUnlex, featureIndexers, train); else analysis.setPrevElemTreeUnlex(elemTreeUnlex, featureIndexers, train); // } private void extractFringeFeatures(ChartEntry entry, IncrementalAnalysis analysis, DiscriminativeFeatureIndexers featureIndexers, boolean train) { // get raw fringe Fringe fringe = entry.getTreeState().getFringe(); // create indexed open-left and open-right fringe lists, substitution node and up/down indices (to indicate prediction nodes) byte[] fringeUpIndex = new byte[fringe.size()]; byte[] fringeDownIndex = new byte[fringe.size()]; int[] openRight = new int[fringe.getAdjNodesOpenRight().size()]; int[] openLeft = new int[fringe.getAdjNodesOpenLeft().size()]; int i = 0; for (Node node : fringe.getAdjNodesOpenRight()) { openRight[i] = featureIndexers.getCategoryIndex(node.getCategory(), train); fringeUpIndex[i] = node.getUpIndex(); fringeDownIndex[i] = node.getDownIndex(); i++; } int j = 0; for (Node node : fringe.getAdjNodesOpenLeft()) { openLeft[j] = featureIndexers.getCategoryIndex(node.getCategory(), train); fringeUpIndex[i] = node.getUpIndex(); fringeDownIndex[i] = node.getDownIndex(); i++; j++; } Node substNode = fringe.getSubstNode(); if (substNode != null) { fringeUpIndex[i] = substNode.getUpIndex(); fringeDownIndex[i] = substNode.getDownIndex(); } analysis.setFringeOpenRight(openRight); analysis.setFringeOpenLeft(openLeft); analysis.setFringeSubstNode( featureIndexers.getCategoryIndex(substNode == null ? "U" : substNode.getCategory(), train)); analysis.setFringeUpIndex(fringeUpIndex); analysis.setFringeDownIndex(fringeDownIndex); analysis.setNumFringeNodes(fringe.size()); analysis.setNumPredictFringeNodes(fringe.numOfPredictNodes()); } private void extractStringTreeAnalysisFeatures(StringTreeAnalysis sta, IncrementalAnalysis analysis, int numOfWords, boolean endOfSent, DiscriminativeFeatureIndexers featureIndexers, boolean train) { String stringAnalysisTree = ParsingTask.convertStringTree(sta); String inputTree = Utils.removeSubtreesAfterWord(Utils.normaliseTree(stringAnalysisTree), numOfWords); extractStringTreeAnalysisFeatures(inputTree, stringAnalysisTree, endOfSent, analysis, featureIndexers, train); } private void extractStringTreeAnalysisFeatures(String inputTree, String stringAnalysisTree, boolean endOfSent, IncrementalAnalysis analysis, DiscriminativeFeatureIndexers featureIndexers, boolean train) { // System.out.println(inputTree); Tree tree = Tree.valueOf(inputTree); List<Tree> leaves = tree.getLeaves(); Tree currentWord = leaves.get(leaves.size() - 1); int currentWordIndex = featureIndexers.getWordIndex(currentWord.nodeString(), train); // right branch (2): 1. Count number of nodes from root to rightmost non-terminal, 2. Count rest nodes // compute domination path from root to rightmost leaf. Subtract 2 from size to exclude root and pre-terminal int pathSize = tree.dominationPath(currentWord.parent(tree)).size(); analysis.setRightBranchSpine(pathSize > 2 ? pathSize - 2 : 0); int rest = tree.size() - analysis.getRightBranchSpine() - leaves.size() * 2 - 1; analysis.setRightBranchRest(rest > 0 ? rest : 0); // Subtract the number of terminals, pre-terminals (leaves.size()*2) and root symbol // get list of rightmost complete non-terminals. We don't compute the whole list of complete non-terminals, but only the ones that have been just completed, // hence lie at the rightmost position of the tree. Since we compute the features piecemeal, by the end of the sentence we will have computed all complete // non-terminals, depending on the training strategy. Used for features: heavy, neighbours, and edges Tree analysisTree = Tree.valueOf(stringAnalysisTree); analysisTree.indexLeaves(); List<Tree> complete = getListOfRightMostCompleteNonTerminals(analysisTree); String[] heavyStr = new String[complete.size()]; String[] neighboursL1Str = new String[complete.size()]; String[] neighboursL2Str = new String[complete.size()]; int i = 0; for (Tree subTree : complete) { // heavy feature int categoryIndex = featureIndexers.getCategoryIndex(subTree.nodeString(), train); List<Label> yield = subTree.yield(); String yieldSize = yield.size() > 5 ? "5+" : String.valueOf(yield.size()); heavyStr[i] = String.format("%s %s %s", categoryIndex, yieldSize, endOfSent ? "y" : "n"); // neighbours l1, l2 features int leftmostLeafId = ((CoreLabel) yield.get(0)).index(); if (leftmostLeafId > 1) { int l1CategoryId = featureIndexers .getCategoryIndex(leaves.get(leftmostLeafId - 2).parent(tree).nodeString(), train); if (leftmostLeafId > 2) { neighboursL1Str[i] = String.format("%s %s %s", categoryIndex, yieldSize, l1CategoryId); int l2CategoryId = featureIndexers .getCategoryIndex(leaves.get(leftmostLeafId - 3).parent(tree).nodeString(), train); neighboursL2Str[i] = String.format("%s %s %s %s", categoryIndex, yieldSize, l2CategoryId, l1CategoryId); } else { neighboursL2Str[i] = String.format("%s %s SoS %s", categoryIndex, yieldSize, l1CategoryId); } } else // leftmost leaf is at the beginning of the sentence { neighboursL1Str[i] = String.format("%s %s SoS", categoryIndex, yieldSize); neighboursL2Str[i] = String.format("%s %s SoS SoS", categoryIndex, yieldSize); } // coPar and coLenPar features Tree[] children = subTree.children(); if (children.length > 2) { // found structure: (X (A ...) (CC and/or) (B ...)) if (children.length == 3 && children[1].nodeString().startsWith("CC")) { analysis.setCoPar(getCoParString(children[0], children[2]), featureIndexers, train); analysis.setCoLenPar(getCoLenParString(children[0], children[2], endOfSent), featureIndexers, train); } // found structure ((CC either) (A ...) (CC or) (B...)) else if (children.length == 4 && children[0].nodeString().startsWith("CC") && children[2].nodeString().startsWith("CC")) { analysis.setCoPar(getCoParString(children[1], children[3]), featureIndexers, train); analysis.setCoLenPar(getCoLenParString(children[1], children[3], endOfSent), featureIndexers, train); } // found structure ((A ...) (, ,) (CC but) (B...)) else if (children.length == 4 && children[1].nodeString().equals(",") && children[2].nodeString().startsWith("CC")) { analysis.setCoPar(getCoParString(children[0], children[3]), featureIndexers, train); analysis.setCoLenPar(getCoLenParString(children[0], children[3], endOfSent), featureIndexers, train); } } i++; } analysis.setHeavy(heavyStr, featureIndexers, train); analysis.setNeighboursL1(neighboursL1Str, featureIndexers, train); analysis.setNeighboursL2(neighboursL2Str, featureIndexers, train); // compute word + L=2 ancestor nodes, L=3 ancestor nodes Tree preTerminal = currentWord.parent(tree); Tree wordL2 = preTerminal.parent(tree); if (wordL2 != null) { int preTerminalIndex = featureIndexers.getCategoryIndex(preTerminal.nodeString(), train); int wordL2Index = featureIndexers.getCategoryIndex(wordL2.nodeString(), train); analysis.setWordL2(String.format("%s %s %s", currentWordIndex, preTerminalIndex, wordL2Index), featureIndexers, train); Tree wordL3 = wordL2.parent(tree); if (wordL3 != null) { analysis.setWordL3(String.format("%s %s %s %s", currentWordIndex, preTerminalIndex, wordL2Index, featureIndexers.getCategoryIndex(wordL3.nodeString())), featureIndexers, train); } } // get integration point + elem tree (Parent-emulation feature) analysis.setIpElemTree(String.format("%s,%s", analysis.getIntegrationPoint(), analysis.getElemTree()), featureIndexers, train); analysis.setIpElemTreeUnlex( String.format("%s,%s", analysis.getIntegrationPoint(), analysis.getElemTreeUnlex()), featureIndexers, train); } /** * Identify the list of rightmost non-terminals that span a complete subtree, i.e., one that * a) the leaf of its' rightmost child is a word, OR * b) the index of the leaf of its' rightmost is a word AND is the last in the yield (AND this leaf is the last word - optional, as this condition breeches incrementality). * @param analysisTree * @return */ private List<Tree> getListOfRightMostCompleteNonTerminals(Tree tree) { List<Tree> list = new ArrayList(); List<Tree> leaves = tree.getLeaves(); // check if the last leaf is a word. Tree currentWord = leaves.get(leaves.size() - 1); if (currentWord.nodeString().endsWith("<>")) { Tree parent = currentWord.parent(tree); while (parent != tree) { if (parent.isPhrasal()) { list.add(parent); } parent = parent.parent(tree); } list.add(tree); } return list; } private String getCoParString(Tree left, Tree right) { if (sameLabelAtLevel(left, right, 0)) { if (sameLabelAtLevel(left, right, 1)) { if (sameLabelAtLevel(left, right, 2)) { if (sameLabelAtLevel(left, right, 3)) return "1111"; else return "1110"; } else { return "1100"; } } // if else { return "1000"; } } else { return "0000"; } } private String getCoLenParString(Tree left, Tree right, boolean endOfSent) { int leftSize = left.yield().size(); int rightSize = right.yield().size(); int diff = Math.abs(rightSize - leftSize); return String.format("%s %s", diff > 5 ? "5+" : String.valueOf(diff), endOfSent ? "y" : "n"); } private boolean sameLabelAtLevel(Tree left, Tree right, int level) { if (left.isLeaf() || right.isLeaf()) return false; if (level == 0) return left.nodeString().equals(right.nodeString()); Tree[] leftChildren = left.children(); Tree[] rightChildren = right.children(); if (leftChildren.length != rightChildren.length) return false; for (int i = 0; i < leftChildren.length; i++) { if (!sameLabelAtLevel(leftChildren[i], rightChildren[i], level - 1)) return false; } return true; } private void extractSrlFeatures(Dependencies dependencies, IncrementalAnalysis analysis, int timestamp, DiscriminativeFeatureIndexers featureIndexers, boolean train) { // get the set of all dependencies for more future use with more complex semantic-aware features Pair<Proposition[], Proposition[]> completeIncompleteProps = dependencies.toPropositions(featureIndexers, train); // analysis.setPropositions(propositions); List<String> triples = new ArrayList(); List<String> triplesPos = new ArrayList(); BitSet triplesArgComplete = new BitSet(); // Get collection of propositions at a specific timestamp. The list includes all semantic triples // with either their argument or predicate at the given timestamp int i = 0; for (Proposition prop : completeIncompleteProps.getFirst()) { Predicate pred = prop.getPredicate(); int predTimestamp = pred.getTimestamp(); if (predTimestamp == timestamp) // current word is a predicate. Add all triples in the proposition and exit { for (Argument arg : prop.getArguments()) { triples.add(String.format("%s,%s,%s", arg.getRoleInt(), arg.getFormInt(), pred.getLemmaInt())); triplesPos.add( String.format("%s,%s,%s", arg.getRoleInt(), arg.getPosTagInt(), pred.getPosTagInt())); i++; } break; } else // search all arguments { for (Argument arg : prop.getArguments()) { // current word is an argument. Add triple and go to the next proposition in case the same argument // is in an another proposition as well if (arg.getTimestamp() == timestamp) { triples.add( String.format("%s,%s,%s", arg.getRoleInt(), arg.getFormInt(), pred.getLemmaInt())); triplesPos.add(String.format("%s,%s,%s", arg.getRoleInt(), arg.getPosTagInt(), pred.getPosTagInt())); triplesArgComplete.set(i++); break; } } } // else } // for all propositions analysis.setSrlTriples(triples.toArray(new String[0]), featureIndexers, train); analysis.setSrlTriplesPos(triplesPos.toArray(new String[0]), featureIndexers, train); analysis.setArgComplete(triplesArgComplete); // deal with incomplete propositions List<String> incompleteTriples = new ArrayList(); List<String> incompleteTriplesPos = new ArrayList(); for (Proposition prop : completeIncompleteProps.getSecond()) { Predicate pred = prop.getPredicate(); int predTimestamp = pred.getTimestamp(); if (predTimestamp == -1) // predicate incomplete { Argument arg = prop.getArguments().get(0); if (arg.getTimestamp() == timestamp && arg.getFormInt() != -1) { incompleteTriples.add(String.format("%s,-", arg.getFormInt())); incompleteTriplesPos.add(String.format("%s,-", arg.getPosTagInt())); } } else if (predTimestamp == timestamp && pred.getLemmaInt() != -1) { incompleteTriples.add(String.format("-,%s", pred.getLemmaInt())); incompleteTriplesPos.add(String.format("-,%s", pred.getPosTagInt())); } } analysis.setSrlIncompleteTriples(incompleteTriples.toArray(new String[0]), featureIndexers, train); analysis.setSrlIncompletePosTriples(incompleteTriplesPos.toArray(new String[0]), featureIndexers, train); } private String getUnlexTree(String elemTree, BuildBlock bb) { String unlexFromCache; // synchronized(unlexTreeCache) // { unlexFromCache = unlexTreeCache.get(elemTree); // } if (unlexFromCache != null) return unlexFromCache; String unlexFromStringTree = bb.getElemTree().toStringUnlex(false); // synchronized(unlexTreeCache) // { unlexTreeCache.put(elemTree, unlexFromStringTree); // } return unlexFromStringTree; } public boolean isEmpty() { return incrAnalysesFeatures.isEmpty(); } public boolean isEmptySentenceLevel() { return incrAnalysesFeaturesSentLevel.isEmpty(); } public int size() { return incrAnalysesFeatures.size(); } public int sizeSentenceLevel() { return incrAnalysesFeaturesSentLevel.size(); } public void clearFeatures() { incrAnalysesFeatures.clear(); incrAnalysesFeaturesSentLevel.clear(); } public Set<Entry<String, IncrementalAnalyses>> getIncrAnalysesFeaturesSet() { return incrAnalysesFeatures.entrySet(); } public IncrementalAnalyses getIncrAnalysesFeatures(String name) { return incrAnalysesFeatures.get(name); } public void addAll(ExtractFeatures featuresIn) { incrAnalysesFeatures.putAll(featuresIn.incrAnalysesFeatures); } public DiscriminativeFeatureIndexers getFeatureIndexers() { return featureIndexers; } public Map<String, IncrementalAnalyses> getIncrAnalysesFeaturesSentLevel() { return incrAnalysesFeaturesSentLevel; } public void extractStringTreeAnalysisFeaturesTest(String inputTree, String stringAnalysisTree, boolean endOfSent, DiscriminativeFeatureIndexers featureIndexers) { extractStringTreeAnalysisFeatures(inputTree, stringAnalysisTree, endOfSent, new IncrementalAnalysis(), featureIndexers, true); } class Trace { int entryPos, bbPos, analysisPos; double analysisProb; public Trace(int entryPos, int bbPos, int analysisPos, double analysisProb) { this.entryPos = entryPos; this.bbPos = bbPos; this.analysisPos = analysisPos; this.analysisProb = analysisProb; } public Trace(int entryPos, int bbPos, int analysisPos) { this(entryPos, bbPos, analysisPos, -1); } public Trace(int entryPos, int bbPos, double analysisProb) { this(entryPos, bbPos, -1, analysisProb); } @Override public String toString() { return String.format("[%s %s %s]", entryPos, bbPos, analysisPos != -1 ? analysisPos : analysisProb); } } }