// AbstractSequenceClassifier -- a framework for probabilistic sequence models.
// Copyright (c) 2002-2008 The Board of Trustees of
// The Leland Stanford Junior University. All Rights Reserved.
//
// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
// as published by the Free Software Foundation; either version 2
// of the License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
//
// For more information, bug reports, fixes, contact:
// Christopher Manning
// Dept of Computer Science, Gates 1A
// Stanford CA 94305-9010
// USA
// java-nlp-support@lists.stanford.edu
// http://nlp.stanford.edu/downloads/crf-classifier.shtml
package edu.stanford.nlp.ie;
import edu.stanford.nlp.ling.*;
import edu.stanford.nlp.ling.CoreAnnotations.AnswerAnnotation;
import edu.stanford.nlp.ling.CoreAnnotations.BeginPositionAnnotation;
import edu.stanford.nlp.ling.CoreAnnotations.EndPositionAnnotation;
import edu.stanford.nlp.ling.CoreAnnotations.GoldAnswerAnnotation;
import edu.stanford.nlp.ling.CoreAnnotations.PositionAnnotation;
import edu.stanford.nlp.ling.CoreAnnotations.UnknownAnnotation;
import edu.stanford.nlp.objectbank.ObjectBank;
import edu.stanford.nlp.objectbank.ResettableReaderIteratorFactory;
import edu.stanford.nlp.util.Function;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.Timing;
import edu.stanford.nlp.util.StringUtils;
import edu.stanford.nlp.util.Triple;
import edu.stanford.nlp.sequences.FeatureFactory;
import edu.stanford.nlp.sequences.DocumentReaderAndWriter;
import edu.stanford.nlp.sequences.LatticeWriter;
import edu.stanford.nlp.sequences.PlainTextDocumentReaderAndWriter;
import edu.stanford.nlp.sequences.SeqClassifierFlags;
import edu.stanford.nlp.sequences.ObjectBankWrapper;
import edu.stanford.nlp.sequences.TrueCasingDocumentReaderAndWriter;
import edu.stanford.nlp.sequences.KBestSequenceFinder;
import edu.stanford.nlp.sequences.ViterbiSearchGraphBuilder;
import edu.stanford.nlp.sequences.SequenceModel;
import edu.stanford.nlp.sequences.SequenceSampler;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Sampler;
import edu.stanford.nlp.io.RegExFileFilter;
import edu.stanford.nlp.fsm.DFSA;
import java.io.*;
import java.util.*;
import java.util.regex.Pattern;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.zip.GZIPInputStream;
/**
 * This class provides common functionality for (probabilistic) sequence
 * models. It is a superclass of our CMM and CRF sequence classifiers,
 * and is even used in the (deterministic) NumberSequenceClassifier.
 * See implementing classes for more information.
 * <p>
 * Subclasses implement {@link #test(List)}, {@link #train(ObjectBank)},
 * {@link #printProbsDocument(List)}, {@link #serializeClassifier(String)} and
 * {@link #loadClassifier(InputStream, Properties)}, and should call
 * {@link #init(SeqClassifierFlags)} (or {@link #init(Properties)}) from
 * their constructors.
 *
 * @author Jenny Finkel
 * @author Dan Klein
 * @author Christopher Manning
 * @author Dan Cer
 */
public abstract class AbstractSequenceClassifier implements Function<String, String> {

  /** Directory inside the jar file in which bundled classifiers are stored. */
  public static final String JAR_CLASSIFIER_PATH = "/classifiers/";

  /** Configuration flags; set by init() or by loading a classifier. */
  public SeqClassifierFlags flags;
  /** Mapping between label strings and their integer indices. */
  public Index<String> classIndex; // = null;
  /** Reads input documents and prints answers; re-created from flags.readerAndWriter by reinit(). */
  protected DocumentReaderAndWriter readerAndWriter; // = null;
  /** Instantiated reflectively from flags.featureFactory in init(). */
  public FeatureFactory featureFactory;
  /** Padding token whose answer and gold answer are set to the background symbol in reinit(). */
  protected CoreLabel pad;
  /** Set to flags.maxLeft + 1 by init(SeqClassifierFlags). */
  public int windowSize;
  /** Passed to every ObjectBankWrapper this class creates. */
  protected Set<String> knownLCWords = new HashSet<String>();

  /** This does nothing. An implementing class should call
   *  init() in its constructor.
   */
  public AbstractSequenceClassifier() {
  }

  /**
   * Builds a {@link SeqClassifierFlags} from the given properties and
   * initializes this classifier with it.
   *
   * @param props Properties used to configure the flags
   */
  protected void init(Properties props) {
    SeqClassifierFlags newFlags = new SeqClassifierFlags();
    newFlags.setProperties(props);
    init(newFlags);
  }

  /**
   * Initializes this classifier from the given flags: stores them, creates
   * the pad token, sets windowSize, instantiates the feature factory named by
   * {@code flags.featureFactory}, and then calls {@link #reinit()}.
   *
   * @param flags The flags to configure this classifier with
   * @throws RuntimeException if the feature factory class cannot be instantiated
   */
  protected void init(SeqClassifierFlags flags) {
    this.flags = flags;
    pad = new CoreLabel();
    windowSize = flags.maxLeft + 1;
    try {
      featureFactory = (FeatureFactory) Class.forName(flags.featureFactory).newInstance();
    } catch (Exception e) {
      // Keep the cause so reflective failures remain diagnosable.
      throw new RuntimeException("Could not instantiate FeatureFactory " + flags.featureFactory
          + ": " + e.getMessage(), e);
    }
    reinit();
  }

  /**
   * Re-derives state from the current {@link #flags}. This should be called
   * after there have been changes to the flags, such as after deserializing
   * a classifier; loadClassifier implementations are expected to call it.
   * <p>
   * It assumes that {@code flags}, {@code pad} and {@code featureFactory}
   * already exist. It sets the pad's answer and gold answer to the background
   * symbol, re-creates {@code readerAndWriter} from
   * {@code flags.readerAndWriter}, and re-initializes both the reader/writer
   * and the existing feature factory with the flags.
   * <p>
   * <i>Implementation note:</i> At the moment this method doesn't set
   * windowSize or create a new featureFactory, since they are being
   * serialized separately in the file, but we should probably stop
   * serializing them and just reinitialize them from the flags.
   *
   * @throws RuntimeException if the reader/writer class cannot be instantiated
   */
  protected void reinit() {
    pad.set(AnswerAnnotation.class, flags.backgroundSymbol);
    pad.set(GoldAnswerAnnotation.class, flags.backgroundSymbol);
    try {
      readerAndWriter = (DocumentReaderAndWriter) Class.forName(flags.readerAndWriter).newInstance();
    } catch (Exception e) {
      throw new RuntimeException(e.getMessage(), e);
    }
    readerAndWriter.init(flags);
    featureFactory.init(flags);
  }

  /** @return The label used for tokens outside any class ({@code flags.backgroundSymbol}). */
  public String backgroundSymbol() {
    return flags.backgroundSymbol;
  }

  /** @return A fresh, mutable set of all labels in {@link #classIndex}. */
  public Set<String> labels() {
    return new HashSet<String>(classIndex.objectsList());
  }

  /**
   * Classifies a sentence given as a list of words.
   *
   * @param sentence The words to be classified
   * @return A new list with one CoreLabel per input word (word and position
   *         set), where the classifier output for each token is stored in
   *         its AnswerAnnotation.
   */
  public List<CoreLabel> testSentence(List<? extends HasWord> sentence) {
    List<CoreLabel> document = new ArrayList<CoreLabel>();
    int i = 0;
    for (HasWord word : sentence) {
      CoreLabel wi = new CoreLabel();
      wi.setWord(word.word());
      wi.set(PositionAnnotation.class, Integer.toString(i));
      wi.set(AnswerAnnotation.class, backgroundSymbol());
      document.add(wi);
      i++;
    }
    ObjectBankWrapper wrapper = new ObjectBankWrapper(flags, null, knownLCWords);
    wrapper.processDocument(document);
    test(document);
    return document;
  }

  /**
   * Returns a sequence model for the given document. This default
   * implementation is unsupported; subclasses that support k-best decoding,
   * sampling or search graphs override it.
   *
   * @throws UnsupportedOperationException always, unless overridden
   */
  public SequenceModel getSequenceModel(List<? extends CoreLabel> doc) {
    throw new UnsupportedOperationException();
  }

  /**
   * Returns a sampler that draws label sequences for {@code input} from the
   * model returned by {@link #getSequenceModel(List)}. Each sample is a list
   * of copies of the input tokens with their AnswerAnnotation set to the
   * sampled label.
   */
  public Sampler<List<CoreLabel>> getSampler(final List<? extends CoreLabel> input) {
    return new Sampler<List<CoreLabel>>() {
      SequenceModel model = getSequenceModel(input);
      SequenceSampler sampler = new SequenceSampler();

      public List<CoreLabel> drawSample() {
        int[] sampleArray = sampler.bestSequence(model);
        List<CoreLabel> sample = new ArrayList<CoreLabel>();
        int i = 0;
        for (CoreLabel word : input) {
          CoreLabel newWord = new CoreLabel(word);
          newWord.set(AnswerAnnotation.class, classIndex.get(sampleArray[i++]));
          sample.add(newWord);
        }
        return sample;
      }
    };
  }

  /**
   * Computes the k best label sequences for a document.
   *
   * @param doc The document to label
   * @param answerField The annotation in which each guessed label is stored
   * @param k How many sequences to return
   * @return A counter from labeled copies of the document to the score
   *         reported by {@link KBestSequenceFinder}; empty for an empty document
   */
  public Counter<List<CoreLabel>> testKBest(List<CoreLabel> doc, Class<? extends CoreAnnotation<String>> answerField, int k) {
    if (doc.isEmpty()) {
      return new ClassicCounter<List<CoreLabel>>();
    }

    // i'm sorry that this is so hideous - JRF
    ObjectBankWrapper obw = new ObjectBankWrapper(flags, null, knownLCWords);
    doc = obw.processDocument(doc);

    SequenceModel model = getSequenceModel(doc);

    KBestSequenceFinder tagInference = new KBestSequenceFinder();
    Counter<int[]> bestSequences = tagInference.kBestSequences(model, k);

    Counter<List<CoreLabel>> kBest = new ClassicCounter<List<CoreLabel>>();

    for (int[] seq : bestSequences.keySet()) {
      List<CoreLabel> kth = new ArrayList<CoreLabel>();
      // The sequence array is offset by the model's left window.
      int pos = model.leftWindow();
      for (CoreLabel fi : doc) {
        CoreLabel newFL = new CoreLabel(fi);
        String guess = classIndex.get(seq[pos]);
        fi.remove(AnswerAnnotation.class); // because fake answers will get added during testing
        newFL.set(answerField, guess);
        pos++;
        kth.add(newFL);
      }
      kBest.setCount(kth, bestSequences.getCount(seq));
    }

    return kBest;
  }

  /**
   * Builds the Viterbi search graph for a document.
   *
   * @param doc The document to decode
   * @param answerField Unused by this implementation
   * @return The search graph, or an empty DFSA for an empty document
   */
  public DFSA getViterbiSearchGraph(List<CoreLabel> doc, Class<? extends CoreAnnotation<String>> answerField) {
    if (doc.isEmpty()) {
      return new DFSA(null);
    }
    ObjectBankWrapper obw = new ObjectBankWrapper(flags, null, knownLCWords);
    doc = obw.processDocument(doc);
    SequenceModel model = getSequenceModel(doc);
    return ViterbiSearchGraphBuilder.getGraph(model, classIndex);
  }

  /**
   * Classifies a list of CoreLabels, with special handling when the reader
   * is a {@link TrueCasingDocumentReaderAndWriter}: words are lowercased
   * before classification and then recased in {@code sentence} according to
   * the predicted casing class (INIT_UPPER, LOWER or UPPER; the first token
   * is always capitalized). For any other reader, each word's NER field is
   * set to the predicted answer.
   *
   * @param sentence A list of CoreLabels to be classified; modified in place
   * @return {@code sentence}, with words recased or NER labels set
   */
  public List<CoreLabel> testSentenceWithCasing(List<CoreLabel> sentence) {
    // The lowercasing done on input and the recasing done on output must use
    // the same test; previously the output side compared flags.readerAndWriter
    // to a non-existent class name and so never recased.
    boolean trueCasing = readerAndWriter instanceof TrueCasingDocumentReaderAndWriter;

    List<CoreLabel> document = new ArrayList<CoreLabel>();
    int i = 0;
    for (CoreLabel word : sentence) {
      CoreLabel wi = new CoreLabel();
      if (trueCasing) {
        wi.setWord(word.word().toLowerCase());
        if (flags.useUnknown) {
          wi.set(UnknownAnnotation.class, (TrueCasingDocumentReaderAndWriter.known(wi.word()) ? "false" : "true"));
        }
      } else {
        wi.setWord(word.word());
      }
      wi.set(PositionAnnotation.class, Integer.toString(i));
      wi.set(AnswerAnnotation.class, backgroundSymbol());
      document.add(wi);
      i++;
    }
    test(document);

    for (int j = 0; j < document.size(); j++) {
      CoreLabel wi = document.get(j);
      CoreLabel word = sentence.get(j);
      String answer = wi.get(AnswerAnnotation.class);
      if (trueCasing) {
        String w = word.word();
        // Guard against empty tokens before taking the first character.
        if (w.length() > 0 && ("INIT_UPPER".equals(answer) || "0".equals(wi.get(PositionAnnotation.class)))) {
          w = w.substring(0, 1).toUpperCase() + w.substring(1).toLowerCase();
        } else if ("LOWER".equals(answer)) {
          w = w.toLowerCase();
        } else if ("UPPER".equals(answer)) {
          w = w.toUpperCase();
        }
        word.setWord(w);
      } else {
        word.setNER(answer);
      }
    }

    return sentence;
  }

  /**
   * Classifies the sentences in a string, reading it with a
   * {@link PlainTextDocumentReaderAndWriter}. The classifier's own
   * reader/writer is restored afterwards, even if classification fails.
   *
   * @param sentences The sentence(s) to be classified
   * @return One list of classified CoreLabels per document read
   */
  public List<List<CoreLabel>> testSentences(String sentences) {
    DocumentReaderAndWriter oldRW = readerAndWriter;
    readerAndWriter = new PlainTextDocumentReaderAndWriter();
    try {
      ObjectBank<List<CoreLabel>> documents = makeObjectBank(sentences, true);
      List<List<CoreLabel>> result = new ArrayList<List<CoreLabel>>();
      for (List<CoreLabel> document : documents) {
        test(document);
        result.add(new ArrayList<CoreLabel>(document));
      }
      return result;
    } finally {
      readerAndWriter = oldRW;
    }
  }

  /**
   * Classifies the sentences in a file, reading it with a
   * {@link PlainTextDocumentReaderAndWriter}. The classifier's own
   * reader/writer is restored afterwards, even if classification fails.
   *
   * @param filename Contains the sentence(s) to be classified
   * @return One list of classified CoreLabels per document read
   */
  public List<List<CoreLabel>> testFile(String filename) {
    DocumentReaderAndWriter oldRW = readerAndWriter;
    readerAndWriter = new PlainTextDocumentReaderAndWriter();
    try {
      ObjectBank<List<CoreLabel>> documents = makeObjectBank(filename, true);
      List<List<CoreLabel>> result = new ArrayList<List<CoreLabel>>();
      for (List<CoreLabel> document : documents) {
        test(document);
        result.add(new ArrayList<CoreLabel>(document));
      }
      return result;
    } finally {
      readerAndWriter = oldRW;
    }
  }

  /**
   * Maps a String input to an XML-formatted rendition of applying NER to
   * the String. Implements the Function interface. Calls
   * testStringInlineXML(String) [q.v.].
   */
  public String apply(String in) {
    return testStringInlineXML(in);
  }

  /**
   * Builds an ObjectBank over {@code text} using a temporary
   * {@link PlainTextDocumentReaderAndWriter}, restoring the classifier's own
   * reader/writer before returning (the ObjectBank keeps the reader it was
   * constructed with).
   */
  private ObjectBank<List<CoreLabel>> makePlainTextObjectBank(String text) {
    DocumentReaderAndWriter saved = readerAndWriter;
    readerAndWriter = new PlainTextDocumentReaderAndWriter();
    try {
      return makeObjectBank(text, true);
    } finally {
      readerAndWriter = saved;
    }
  }

  /**
   * Classify the contents of a {@link String}. Plain text or XML is
   * expected and the {@link PlainTextDocumentReaderAndWriter} is used. Output
   * is in inline XML format (e.g. &lt;PERSON&gt;Bill Smith&lt;/PERSON&gt;
   * went to &lt;LOCATION&gt;Paris&lt;/LOCATION&gt; .)
   *
   * @param sentences The string to be classified
   * @return A {@link String} annotated with classification information.
   */
  public String testStringInlineXML(String sentences) {
    ObjectBank<List<CoreLabel>> documents = makePlainTextObjectBank(sentences);
    StringBuilder sb = new StringBuilder();
    for (List<CoreLabel> doc : documents) {
      test(doc);
      sb.append(PlainTextDocumentReaderAndWriter.getAnswersInlineXML(doc));
    }
    return sb.toString();
  }

  /**
   * Classify the contents of a {@link String}. Plain text or XML is
   * expected and the {@link PlainTextDocumentReaderAndWriter} is used. Output
   * is in XML format.
   *
   * @param sentences The string to be classified
   * @return A {@link String} annotated with classification information.
   */
  public String testStringXML(String sentences) {
    ObjectBank<List<CoreLabel>> documents = makePlainTextObjectBank(sentences);
    StringBuilder sb = new StringBuilder();
    for (List<CoreLabel> doc : documents) {
      test(doc);
      sb.append(PlainTextDocumentReaderAndWriter.getAnswersXML(doc));
    }
    return sb.toString();
  }

  /**
   * Classify the contents of a {@link String}. Plain text or XML is
   * expected and the {@link PlainTextDocumentReaderAndWriter} is used. Output
   * looks like: My/O name/O is/O Bill/PERSON Smith/PERSON ./O
   *
   * @param sentences The string to be classified
   * @return A {@link String} annotated with classification information.
   */
  public String testString(String sentences) {
    ObjectBank<List<CoreLabel>> documents = makePlainTextObjectBank(sentences);
    StringBuilder sb = new StringBuilder();
    for (List<CoreLabel> doc : documents) {
      test(doc);
      sb.append(PlainTextDocumentReaderAndWriter.getAnswers(doc));
    }
    return sb.toString();
  }

  /**
   * Classify the contents of a {@link String} and return the entities found
   * with their character offsets. Plain text or XML is expected and the
   * {@link PlainTextDocumentReaderAndWriter} is used. A maximal run of
   * adjacent tokens with the same non-background label forms one entity.
   *
   * @param sentences The string to be classified
   * @return A list of (label, begin offset, end offset) triples, where the
   *         offsets come from the first token's BeginPositionAnnotation and
   *         the last token's EndPositionAnnotation.
   */
  public List<Triple<String,Integer,Integer>> testStringAndGetCharacterOffsets(String sentences) {
    ObjectBank<List<CoreLabel>> documents = makePlainTextObjectBank(sentences);
    // Use the configured background label rather than assuming "O".
    String background = backgroundSymbol();
    List<Triple<String,Integer,Integer>> entities = new ArrayList<Triple<String,Integer,Integer>>();
    for (List<CoreLabel> doc : documents) {
      String prevEntityType = background;
      Triple<String,Integer,Integer> prevEntity = null;
      test(doc);
      for (CoreLabel fl : doc) {
        String guessedAnswer = fl.get(AnswerAnnotation.class);
        if (guessedAnswer.equals(background)) {
          if (prevEntity != null) {
            entities.add(prevEntity);
            prevEntity = null;
          }
        } else if (!guessedAnswer.equals(prevEntityType)) {
          if (prevEntity != null) {
            entities.add(prevEntity);
          }
          prevEntity = new Triple<String,Integer,Integer>(guessedAnswer,
              (Integer) fl.get(BeginPositionAnnotation.class),
              (Integer) fl.get(EndPositionAnnotation.class));
        } else {
          // Same label as the previous token: extend the current entity.
          prevEntity.setThird((Integer) fl.get(EndPositionAnnotation.class));
        }
        prevEntityType = guessedAnswer;
      }
      if (prevEntity != null) {
        entities.add(prevEntity);
      }
    }
    return entities;
  }

  /**
   * ONLY USE IF LOADED A CHINESE WORD SEGMENTER!!!!!
   *
   * @param sentence The string to be classified
   * @return List of words
   */
  public List<String> segmentString(String sentence) {
    ObjectBank<List<CoreLabel>> docs = makeObjectBank(sentence);

    // Output is produced with readerAndWriter.printAnswers() rather than by
    // concatenating answers by hand, so that the reader/writer's segmentation
    // post-processing is applied. That post-processing (whitespace
    // normalization) relies on flags.testFile, so the input is written to a
    // temporary file and flags.testFile points at it for the duration of this
    // call. (Approach originally by Daniel Cer.)
    // NOTE(review): FileWriter uses the platform default encoding — confirm
    // this matches how the segmenter reads flags.testFile.
    String oldTestFile = flags.testFile;
    try {
      try {
        File tempFile = File.createTempFile("segmentString", ".txt");
        tempFile.deleteOnExit();
        flags.testFile = tempFile.getPath();
        FileWriter tempWriter = new FileWriter(tempFile);
        try {
          tempWriter.write(sentence);
        } finally {
          tempWriter.close();
        }
      } catch (IOException e) {
        System.err.println("Warning(segmentString): " +
                           "couldn't create temporary file for flags.testFile");
        flags.testFile = "";
      }

      StringWriter stringWriter = new StringWriter();
      PrintWriter stringPrintWriter = new PrintWriter(stringWriter);
      for (List<CoreLabel> doc : docs) {
        test(doc);
        readerAndWriter.printAnswers(doc, stringPrintWriter);
        stringPrintWriter.println();
      }
      stringPrintWriter.close();
      String segmented = stringWriter.toString();
      return Arrays.asList(segmented.split("\\s"));
    } finally {
      // Always restore the caller's testFile, even if classification fails.
      flags.testFile = oldTestFile;
    }
  }

  /**
   * Classify a {@link List} of {@link CoreLabel}s.
   *
   * @param document A {@link List} of {@link CoreLabel}s.
   * @return the same {@link List}, but with the elements annotated
   *         with their answers (with <code>setAnswer()</code>).
   */
  public abstract List<CoreLabel> test(List<CoreLabel> document);

  /**
   * Trains on the data named by the flags: {@code flags.trainFiles} (relative
   * to {@code flags.baseTrainDir}) if set, else the comma-separated
   * {@code flags.trainFileList} if set, else {@code flags.trainFile}.
   */
  public void train() {
    if (flags.trainFiles != null) {
      train(flags.baseTrainDir, flags.trainFiles);
    } else if (flags.trainFileList != null) {
      String[] files = flags.trainFileList.split(",");
      train(files);
    } else {
      train(flags.trainFile);
    }
  }

  /** Trains on the documents in one file (sets flags.ocrTrain). */
  public void train(String filename) {
    // only for the OCR data does this matter
    flags.ocrTrain = true;
    train(makeObjectBank(filename));
  }

  /** Trains on files in baseTrainDir whose names match the trainFiles regex (sets flags.ocrTrain). */
  public void train(String baseTrainDir, String trainFiles) {
    // only for the OCR data does this matter
    flags.ocrTrain = true;
    train(makeObjectBank(baseTrainDir, trainFiles, true));
  }

  /** Trains on the listed files (sets flags.ocrTrain). */
  public void train(String[] trainFileList) {
    // only for the OCR data does this matter
    flags.ocrTrain = true;
    train(makeObjectBank(trainFileList, true));
  }

  /** Trains the classifier on the given documents. */
  public abstract void train(ObjectBank<List<CoreLabel>> docs);

  /** Same as {@code makeObjectBank(filenameOrString, false)}. */
  public ObjectBank<List<CoreLabel>> makeObjectBank(String filenameOrString) {
    return makeObjectBank(filenameOrString, false);
  }

  /**
   * Sets up an ObjectBank over {@code filenameOrString} using the current
   * {@link #readerAndWriter}. The argument is handed to
   * ResettableReaderIteratorFactory; judging by its name it may be either a
   * file name or literal text — TODO confirm.
   *
   * @param filenameOrString The data source
   * @param quietly If true, print nothing. Callers such as testString pass
   *        the whole input text here, so it must not be echoed to stderr.
   * @return The documents, wrapped in an ObjectBankWrapper
   */
  public ObjectBank<List<CoreLabel>> makeObjectBank(String filenameOrString, boolean quietly) {
    if (!quietly) {
      System.err.print("Reading data using ");
      System.err.println(flags.readerAndWriter);

      if (flags.inputEncoding == null) {
        System.err.println("Getting data from " + filenameOrString + " (default encoding)");
      } else {
        System.err.println("Getting data from " + filenameOrString + " (" + flags.inputEncoding + " encoding)");
      }
    }
    return new ObjectBankWrapper(flags, new ObjectBank<List<CoreLabel>>(new ResettableReaderIteratorFactory(filenameOrString), readerAndWriter), knownLCWords);
  }

  /**
   * Sets up an ObjectBank over the listed files using the current
   * {@link #readerAndWriter}. The file count is always printed to stderr.
   *
   * @param trainFileList Paths of the files to read
   * @param quietly Currently unused
   * @return The documents, wrapped in an ObjectBankWrapper
   */
  public ObjectBank<List<CoreLabel>> makeObjectBank(String[] trainFileList, boolean quietly) {
    Collection<File> files = new ArrayList<File>();
    for (String trainFile : trainFileList) {
      files.add(new File(trainFile));
    }
    System.err.printf("trainFileList contains %d files.\n", files.size());
    return new ObjectBankWrapper(flags, new ObjectBank<List<CoreLabel>>(new ResettableReaderIteratorFactory(files), readerAndWriter), knownLCWords);
  }

  /**
   * Sets up an ObjectBank over the regular files in {@code baseDir} whose
   * names match {@code filePattern}. When {@code flags.inputEncoding} is set,
   * each file is opened immediately as a reader in that encoding.
   *
   * @param baseDir Directory to search
   * @param filePattern Regular expression for file names
   * @param quietly Currently unused
   * @return The documents, wrapped in an ObjectBankWrapper
   * @throws RuntimeException if no files match (including when baseDir is
   *         not a readable directory) or a file cannot be opened
   */
  public ObjectBank<List<CoreLabel>> makeObjectBank(String baseDir, String filePattern, boolean quietly) {
    try {
      File path = new File(baseDir);
      FileFilter filter = new RegExFileFilter(Pattern.compile(filePattern));
      File[] origFiles = path.listFiles(filter);
      // listFiles() returns null when baseDir is not a readable directory.
      if (origFiles == null) {
        origFiles = new File[0];
      }
      // Holds Files, or Readers when an explicit input encoding is given.
      Collection files = new ArrayList();
      for (File file : origFiles) {
        if (file.isFile()) {
          if (flags.inputEncoding == null) {
            System.err.println("Getting data from " + file + " (default encoding)");
            files.add(file);
          } else {
            System.err.println("Getting data from " + file + " (" + flags.inputEncoding + " encoding)");
            files.add(new BufferedReader(new InputStreamReader(new FileInputStream(file), flags.inputEncoding)));
          }
        }
      }

      if (files.isEmpty()) {
        String msg = "no matching files: " + baseDir + "\t" + filePattern;
        System.err.println(msg);
        throw new RuntimeException(msg);
      }
      return new ObjectBankWrapper(flags, new ObjectBank<List<CoreLabel>>(new ResettableReaderIteratorFactory(files), readerAndWriter), knownLCWords);
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
  }

  /**
   * Sets up an ObjectBank over the given files using the current
   * {@link #readerAndWriter}.
   *
   * @throws RuntimeException if {@code files} is empty
   */
  public ObjectBank<List<CoreLabel>> makeObjectBank(Collection<File> files) {
    if (files.isEmpty()) {
      String msg = "Attempt to make ObjectBank with empty file list";
      System.err.println(msg);
      throw new RuntimeException(msg);
    }
    return new ObjectBankWrapper(flags, new ObjectBank<List<CoreLabel>>(new ResettableReaderIteratorFactory(files), readerAndWriter), knownLCWords);
  }

  /** Set up an ObjectBank that will allow one to iterate over a
   *  collection of documents obtained from the passed in Reader.
   *  Each document will be represented as a list of CoreLabel.
   *  If the ObjectBank iterator() is called until hasNext() returns false,
   *  then the Reader will be read till end of file, but no
   *  reading is done at the time of this call.  Reading is done using the
   *  reading method specified in <code>flags.documentReader</code>,
   *  and for some reader choices, the column mapping given in
   *  <code>flags.map</code>.
   *
   * @param in Input data
   * @param quietly Print fewer messages if this is true (use when calling
   *        it repeatedly on small bits of text)
   * @return The list of documents
   */
  protected ObjectBank<List<CoreLabel>> makeObjectBank(BufferedReader in, boolean quietly) {
    if (!quietly) {
      System.err.print("Reading data using ");
      System.err.println(flags.readerAndWriter);
    }
    return new ObjectBankWrapper(flags, new ObjectBank<List<CoreLabel>>(new ResettableReaderIteratorFactory(in), readerAndWriter), knownLCWords);
  }

  /** Same as {@code makeObjectBank(in, false)}. */
  public ObjectBank<List<CoreLabel>> makeObjectBank(BufferedReader in) {
    return makeObjectBank(in, false);
  }

  /**
   * Takes the file, reads it in, and prints out the likelihood of
   * each possible label at each point.
   *
   * @param filename The path to the specified file
   */
  public void printProbs(String filename) {
    // only for the OCR data does this matter
    flags.ocrTrain = false;

    ObjectBank<List<CoreLabel>> docs = makeObjectBank(filename);
    printProbsDocuments(docs);
  }

  /**
   * Takes a {@link List} of documents and prints the likelihood
   * of each possible label at each point.
   *
   * @param documents A {@link List} of {@link List} of {@link CoreLabel}s.
   */
  public void printProbsDocuments(ObjectBank<List<CoreLabel>> documents) {
    for (List<CoreLabel> doc : documents) {
      printProbsDocument(doc);
      System.out.println();
    }
  }

  /** Prints the likelihood of each possible label at each point of one document. */
  public abstract void printProbsDocument(List<CoreLabel> document);

  /** Load a test file, run the classifier on it, and then print the answers
   *  to stdout (with timing to stderr).  This uses the value of
   *  flags.documentReader to determine testFile format.
   *
   *  @param testFile The file to test on.
   */
  public void testAndWriteAnswers(String testFile) throws Exception {
    ObjectBank<List<CoreLabel>> documents = makeObjectBank(testFile);
    testAndWriteAnswers(documents);
  }

  /** Like {@link #testAndWriteAnswers(String)}, for files in baseDir matching filePattern. */
  public void testAndWriteAnswers(String baseDir, String filePattern) throws Exception {
    ObjectBank<List<CoreLabel>> documents = makeObjectBank(baseDir, filePattern, true);
    testAndWriteAnswers(documents);
  }

  /** Like {@link #testAndWriteAnswers(String)}, for an explicit collection of files. */
  public void testAndWriteAnswers(Collection<File> testFiles) throws Exception {
    ObjectBank<List<CoreLabel>> documents = makeObjectBank(testFiles);
    testAndWriteAnswers(documents);
  }

  /** Classifies each document, writes its answers, and reports throughput to stderr. */
  private void testAndWriteAnswers(ObjectBank<List<CoreLabel>> documents) throws Exception {
    Timing timer = new Timing();
    int numWords = 0;
    int numDocs = 0;
    for (List<CoreLabel> doc : documents) {
      test(doc);
      numWords += doc.size();
      writeAnswers(doc);
      numDocs++;
    }
    long millis = timer.stop();
    printTimingSummary(StringUtils.getShortClassName(this), numWords, numDocs, millis);
  }

  /** Prints the standard "tagged N words in M documents at X words per second." line to stderr. */
  private static void printTimingSummary(String taggerName, int numWords, int numDocs, long millis) {
    double wordsPerSec = numWords / (((double) millis) / 1000);
    NumberFormat nf = new DecimalFormat("0.00");
    System.err.println(taggerName + " tagged " + numWords + " words in " + numDocs +
                       " documents at " + nf.format(wordsPerSec) +
                       " words per second.");
  }

  /** Load a test file, run the classifier on it, and then print the k best
   *  answers for each document to stdout, each wrapped in a
   *  &lt;sentence id=... k=... logProb=... prob=...&gt; element (with timing
   *  to stderr).  This uses the value of flags.documentReader to determine
   *  testFile format.
   *
   *  @param testFile The file to test on.
   *  @param k How many sequences to print per document
   */
  public void testAndWriteAnswersKBest(String testFile, int k) throws Exception {
    Timing timer = new Timing();
    ObjectBank<List<CoreLabel>> documents = makeObjectBank(testFile);
    int numWords = 0;
    int numSentences = 0;

    for (List<CoreLabel> doc : documents) {
      Counter<List<CoreLabel>> kBest = testKBest(doc, AnswerAnnotation.class, k);
      numWords += doc.size();
      List<List<CoreLabel>> sorted = Counters.toSortedList(kBest);
      int n = 1;
      for (List<CoreLabel> l : sorted) {
        double logProb = kBest.getCount(l);
        System.out.println("<sentence id=" + numSentences + " k=" + n + " logProb=" + logProb + " prob=" + Math.exp(logProb) + ">");
        writeAnswers(l);
        System.out.println("</sentence>");
        n++;
      }
      numSentences++;
    }

    long millis = timer.stop();
    printTimingSummary(this.getClass().getName(), numWords, numSentences, millis);
  }

  /** Load a test file, run the classifier on it, and then write a Viterbi
   *  search graph for each sequence, to
   *  {@code searchGraphPrefix.N.lattice} (AT&amp;T FSM format) and, when the
   *  reader/writer is a {@link LatticeWriter}, {@code searchGraphPrefix.N.wlattice}.
   *
   *  @param testFile The file to test on.
   *  @param searchGraphPrefix Path prefix for the output files
   */
  public void testAndWriteViterbiSearchGraph(String testFile, String searchGraphPrefix)
      throws Exception {
    Timing timer = new Timing();
    ObjectBank<List<CoreLabel>> documents = makeObjectBank(testFile);
    int numWords = 0;
    int numSentences = 0;

    for (List<CoreLabel> doc : documents) {
      DFSA tagLattice = getViterbiSearchGraph(doc, AnswerAnnotation.class);
      numWords += doc.size();
      PrintWriter latticeWriter = new PrintWriter(new FileOutputStream(searchGraphPrefix + "." + numSentences + ".wlattice"));
      try {
        PrintWriter vsgWriter = new PrintWriter(new FileOutputStream(searchGraphPrefix + "." + numSentences + ".lattice"));
        try {
          if (readerAndWriter instanceof LatticeWriter) {
            ((LatticeWriter) readerAndWriter).printLattice(tagLattice, doc, latticeWriter);
          }
          tagLattice.printAttFsmFormat(vsgWriter);
        } finally {
          vsgWriter.close();
        }
      } finally {
        latticeWriter.close();
      }
      numSentences++;
    }

    long millis = timer.stop();
    printTimingSummary(this.getClass().getName(), numWords, numSentences, millis);
  }

  /** Write the classifications of the Sequence classifier out
   *  to stdout in a format
   *  determined by the DocumentReaderAndWriter used.
   *  If the flag <code>outputEncoding</code> is defined, the output
   *  is written in that character encoding, otherwise in the system default
   *  character encoding. Nothing is written when
   *  <code>flags.lowerNewgeneThreshold</code> is set or
   *  <code>flags.numRuns</code> is greater than 1.
   *
   * @param doc Documents to write out
   * @throws Exception If an IO problem
   */
  public void writeAnswers(List<CoreLabel> doc) throws Exception {
    if (flags.lowerNewgeneThreshold) {
      return;
    }
    if (flags.numRuns <= 1) {
      // System.out is deliberately not closed here.
      PrintWriter out;
      if (flags.outputEncoding == null) {
        out = new PrintWriter(System.out, true);
      } else {
        out = new PrintWriter(new OutputStreamWriter(System.out, flags.outputEncoding), true);
      }
      readerAndWriter.printAnswers(doc, out);
      out.println();
      out.flush();
    }
  }

  /** Serializes this classifier to the given path. */
  public abstract void serializeClassifier(String serializePath);

  /**
   * Loads a classifier from the given input stream. On any failure the stack
   * trace is printed and the JVM exits with status 1.
   */
  public void loadClassifierNoExceptions(BufferedInputStream in) {
    // load the classifier
    try {
      loadClassifier(in);
    } catch (Exception e) {
      e.printStackTrace();
      System.exit(1);
    }
  }

  /** Same as {@code loadClassifier(in, null)}. */
  public void loadClassifier(InputStream in) throws IOException, ClassCastException, ClassNotFoundException {
    loadClassifier(in, null);
  }

  /** Load a classifier from the specified input stream.
   *  The classifier is reinitialized from the flags serialized in the
   *  classifier.
   *
   * @param in The InputStream to load the serialized classifier from
   * @param props This Properties object will be used to update the SeqClassifierFlags which
   *              are read from the serialized classifier
   *
   * @throws IOException
   * @throws ClassCastException
   * @throws ClassNotFoundException
   */
  public abstract void loadClassifier(InputStream in, Properties props) throws IOException, ClassCastException, ClassNotFoundException;

  /**
   * Loads a classifier from the file specified by loadPath. If loadPath
   * ends in .gz, uses a GZIPInputStream, else uses a regular FileInputStream.
   */
  public void loadClassifier(String loadPath) throws ClassCastException, IOException, ClassNotFoundException {
    loadClassifier(new File(loadPath));
  }

  /** Loads a classifier from loadPath, exiting the JVM with status 1 on failure. */
  public void loadClassifierNoExceptions(String loadPath) {
    loadClassifierNoExceptions(new File(loadPath));
  }

  /** Loads a classifier from loadPath with property overrides, exiting the JVM with status 1 on failure. */
  public void loadClassifierNoExceptions(String loadPath, Properties props) {
    loadClassifierNoExceptions(new File(loadPath), props);
  }

  /** Same as {@code loadClassifier(file, null)}. */
  public void loadClassifier(File file) throws ClassCastException, IOException, ClassNotFoundException {
    loadClassifier(file, null);
  }

  /**
   * Loads a classifier from the given file. If the file name
   * ends in .gz, uses a GZIPInputStream, else uses a regular FileInputStream.
   * The stream is closed even if loading fails.
   */
  public void loadClassifier(File file, Properties props) throws ClassCastException, IOException, ClassNotFoundException {
    Timing.startDoing("Loading classifier from " + file.getAbsolutePath());
    InputStream in = new FileInputStream(file);
    try {
      // Reassigned as wrappers are added, so finally closes the outermost one
      // (or the raw file stream if a wrapper constructor failed).
      if (file.getName().endsWith(".gz")) {
        in = new GZIPInputStream(in);
      }
      in = new BufferedInputStream(in);
      loadClassifier(in, props);
    } finally {
      in.close();
    }
    Timing.endDoing();
  }

  /** Loads a classifier from the given file, exiting the JVM with status 1 on failure. */
  public void loadClassifierNoExceptions(File file) {
    loadClassifierNoExceptions(file, null);
  }

  /**
   * Loads a classifier from the given file with property overrides. On any
   * failure an error message and stack trace are printed and the JVM exits
   * with status 1.
   */
  public void loadClassifierNoExceptions(File file, Properties props) {
    try {
      loadClassifier(file, props);
    } catch (Exception e) {
      System.err.println("Error deserializing " + file.getAbsolutePath());
      e.printStackTrace();
      System.exit(1);
    }
  }

  /**
   * This function will load a classifier that is stored inside a jar file
   * (if it is so stored).  The classifier should be specified as its full
   * filename, but the path in the jar file (<code>/classifiers/</code>) is
   * coded in this class.  If the classifier is not stored in the jar file
   * or this is not run from inside a jar file, then this function will
   * throw a RuntimeException.
   *
   * @param modelName The name of the model file.  Iff it ends in .gz, then
   *             it is assumed to be gzip compressed.
   * @param props A Properties object which can override certain properties
   *             in the serialized file, such as the DocumentReaderAndWriter.
   *             You can pass in <code>null</code> to override nothing.
   */
  public void loadJarClassifier(String modelName, Properties props) {
    Timing.startDoing("Loading JAR-internal classifier " + modelName);
    String resource = JAR_CLASSIFIER_PATH + modelName;
    InputStream is = this.getClass().getResourceAsStream(resource);
    // getResourceAsStream returns null (rather than throwing) when missing.
    if (is == null) {
      throw new RuntimeException("Classifier resource not found: " + resource);
    }
    try {
      try {
        if (modelName.endsWith(".gz")) {
          is = new GZIPInputStream(is);
        }
        is = new BufferedInputStream(is);
        loadClassifier(is, props);
      } finally {
        is.close();
      }
      Timing.endDoing();
    } catch (Exception e) {
      throw new RuntimeException(e.getMessage(), e);
    }
  }
}
|