List of usage examples for weka.core Instances setClassIndex
public void setClassIndex(int classIndex)
From source file:edu.utexas.cs.tactex.utils.RegressionUtils.java
License:Open Source License
/** * adding y attributes without giving it values */// ww w . j ava 2 s . co m public static Instances addYforWeka(Instances xInsts) { // add another column for y int n = xInsts.numAttributes(); xInsts.insertAttributeAt(new Attribute(Integer.toString(n)), n); // last attribute is y value, the class 'label' xInsts.setClassIndex(n); return xInsts; }
From source file:edu.washington.cs.knowitall.summarization.RedundancyClassifier.java
License:Open Source License
public Instances setupInstances(StringReader testReader) { Instances instances = null; try {//from w w w . ja v a 2s . c om instances = new Instances(testReader); } catch (IOException e) { e.printStackTrace(); } instances.setClassIndex(instances.numAttributes() - 1); testReader.close(); return instances; }
From source file:edu.washington.cs.knowitall.utilities.Classifier.java
License:Open Source License
/**
 * Sets up the instances from the reader and marks the last attribute as the
 * class attribute.
 *
 * The original version printed the stack trace on a parse IOException and
 * then dereferenced the still-null Instances (guaranteed NPE), and called
 * System.exit(1) if closing failed — both replaced: the reader is closed
 * via try-with-resources and failures surface as UncheckedIOException so
 * the caller decides how to react.
 *
 * @param instanceReader the source of the instances; closed before returning
 * @return the instances object with class index = numAttributes() - 1
 * @throws UncheckedIOException if the data cannot be read or parsed
 */
public Instances setupInstances(Reader instanceReader) {
    try (Reader reader = instanceReader) {
        Instances instances = new Instances(reader);
        // Convention here: the class label is always the last column.
        instances.setClassIndex(instances.numAttributes() - 1);
        return instances;
    } catch (IOException e) {
        throw new UncheckedIOException("could not read instances", e);
    }
}
From source file:elh.eus.absa.CLI.java
License:Open Source License
/**
 * Main access to the train-atc functionalities.
 * Trains ATC (aspect-category classification) with two separate classifiers:
 * one for the entity category ("entCat") and one for the attribute category
 * ("attCat"). Each classifier is cross-validated and its model saved to
 * "elixa-atc_{ent|att}-&lt;lang&gt;.model" in the working directory.
 *
 * @param inputStream training corpus
 * @throws IOException if the corpus cannot be read
 */
public final void trainATC(final InputStream inputStream) throws IOException {
    // Load training parameters from the CLI arguments.
    String paramFile = parsedArguments.getString("params");
    String corpusFormat = parsedArguments.getString("corpusFormat");
    int foldNum = Integer.parseInt(parsedArguments.getString("foldNum"));
    String lang = parsedArguments.getString("language");
    boolean nullSentenceOpinions = parsedArguments.getBoolean("nullSentences");
    CorpusReader reader = new CorpusReader(inputStream, corpusFormat, nullSentenceOpinions, lang);
    // "3" is an opaque Features mode flag — meaning not visible here; see Features.
    Features atcTrain = new Features(reader, paramFile, "3");
    Instances traindata = atcTrain.loadInstances(true, "atc");
    // The loaded data carries several candidate class attributes
    // (entCat | attCat | entAttCat | polarityCat); each classifier below
    // selects its own and filters the others out.
    WekaWrapper classifyEnts;
    WekaWrapper classifyAtts;
    try {
        // --- First classifier: entities (class = "entCat") ---
        Instances traindataEnt = new Instances(traindata);
        traindataEnt.setClassIndex(traindataEnt.attribute("entCat").index());
        classifyEnts = new WekaWrapper(traindataEnt, true);
        // IMPORTANT: filter indexes are +1 because weka's Remove filter
        // counts attributes from 1. Removes the competing class columns.
        String filtRange = String.valueOf(traindata.attribute("attCat").index() + 1) + ","
                + String.valueOf(traindata.attribute("entAttCat").index() + 1);
        classifyEnts.filterAttribute(filtRange);
        System.out.println("trainATC: entity classifier results -> ");
        classifyEnts.crossValidate(foldNum);
        classifyEnts.saveModel("elixa-atc_ent-" + lang + ".model");
        // --- Second classifier: attributes (class = "attCat") ---
        Instances traindataAtt = new Instances(traindata);
        traindataAtt.setClassIndex(traindataAtt.attribute("attCat").index());
        classifyAtts = new WekaWrapper(traindataAtt, true);
        // Only "entAttCat" needs removing here: "attCat" is now the class.
        filtRange = String.valueOf(traindataAtt.attribute("entAttCat").index() + 1);
        classifyAtts.filterAttribute(filtRange);
        System.out.println("trainATC: attribute classifier results -> ");
        classifyAtts.crossValidate(foldNum);
        classifyAtts.saveModel("elixa-atc_att-" + lang + ".model");
    } catch (Exception e) {
        // NOTE(review): broad catch silently continues after any training
        // failure; only the stack trace records it.
        e.printStackTrace();
    }
    System.err.println("DONE CLI train-atc");
}
From source file:elh.eus.absa.CLI.java
License:Open Source License
/**
 * Main access to the train-atc functionalities. Trains ATC using a double
 * one-vs-all classifier: first over entity categories ("entCat"), then —
 * after re-labelling the instances with the predicted entities — over
 * attribute categories ("attCat"), producing E#A aspect categories. The
 * resulting opinions are written out in SemEval-2015 format.
 *
 * @param inputStream training corpus (or ignored when --testOnly with a testset)
 * @throws IOException if the corpus cannot be read
 */
public final void trainATC2(final InputStream inputStream) throws IOException {
    // Load training parameters from the CLI arguments.
    String paramFile = parsedArguments.getString("params");
    String testFile = parsedArguments.getString("testset");
    String paramFile2 = parsedArguments.getString("params2");
    String corpusFormat = parsedArguments.getString("corpusFormat");
    String lang = parsedArguments.getString("language");
    boolean nullSentenceOpinions = parsedArguments.getBoolean("nullSentences");
    boolean onlyTest = parsedArguments.getBoolean("testOnly");
    // Minimum one-vs-all confidence for accepting a class in pass 1 / pass 2.
    double threshold = 0.5;
    double threshold2 = 0.5;
    // NOTE(review): hard-coded, user-specific model directory.
    String modelsPath = "/home/inaki/elixa-atp/ovsaModels";
    CorpusReader reader = new CorpusReader(inputStream, corpusFormat, nullSentenceOpinions, lang);
    Features atcTrain = new Features(reader, paramFile, "3");
    Instances traindata = atcTrain.loadInstances(true, "atc");
    // In test-only mode, replace the corpus with the test set if it exists.
    if (onlyTest) {
        if (FileUtilsElh.checkFile(testFile)) {
            System.err.println("read from test file");
            reader = new CorpusReader(new FileInputStream(new File(testFile)), corpusFormat,
                    nullSentenceOpinions, lang);
            atcTrain.setCorpus(reader);
            traindata = atcTrain.loadInstances(true, "atc");
        }
    }
    WekaWrapper onevsall;
    try {
        // --- Pass 1: one-vs-all over entity categories ---
        Instances entdata = new Instances(traindata);
        // Drop the competing class attributes, then set "entCat" as class.
        entdata.deleteAttributeAt(entdata.attribute("attCat").index());
        entdata.deleteAttributeAt(entdata.attribute("entAttCat").index());
        entdata.setClassIndex(entdata.attribute("entCat").index());
        onevsall = new WekaWrapper(entdata, true);
        if (!onlyTest) {
            onevsall.trainOneVsAll(modelsPath, paramFile + "entCat");
            System.out.println("trainATC: one vs all models ready");
        }
        onevsall.setTestdata(entdata);
        // instance index -> (class label -> confidence)
        HashMap<Integer, HashMap<String, Double>> ovsaRes = onevsall.predictOneVsAll(modelsPath,
                paramFile + "entCat");
        System.out.println("trainATC: one vs all predictions ready");
        // Invert opinion-id -> instance-id into instance-id -> opinion-id.
        HashMap<Integer, String> instOps = new HashMap<Integer, String>();
        for (String oId : atcTrain.getOpinInst().keySet()) {
            instOps.put(atcTrain.getOpinInst().get(oId), oId);
        }
        // Reload features for pass 2 with the second parameter file.
        atcTrain = new Features(reader, paramFile2, "3");
        entdata = atcTrain.loadInstances(true, "attTrain2_data");
        entdata.deleteAttributeAt(entdata.attribute("entAttCat").index());
        Attribute insAtt = entdata.attribute("instanceId");
        // Largest instanceId, used to mint ids for duplicated instances below.
        double maxInstId = entdata.kthSmallestValue(insAtt, entdata.numDistinctValues(insAtt) - 1);
        System.err.println("last instance has index: " + maxInstId);
        // Re-label each instance with every entity class that cleared the
        // threshold; extra classes clone the instance (multi-label expansion).
        // NOTE: entdata grows while this loop runs — numInstances() is
        // re-evaluated each iteration, so clones are also visited.
        for (int ins = 0; ins < entdata.numInstances(); ins++) {
            System.err.println("ins" + ins);
            int i = (int) entdata.instance(ins).value(insAtt);
            Instance currentInst = entdata.instance(ins);
            String sId = reader.getOpinion(instOps.get(i)).getsId();
            String oId = instOps.get(i);
            reader.removeSentenceOpinions(sId);
            int oSubId = 0;
            for (String cl : ovsaRes.get(i).keySet()) {
                if (ovsaRes.get(i).get(cl) > threshold) {
                    if (oSubId >= 1) {
                        // Additional accepted class: clone the instance and
                        // give the clone a fresh instanceId past maxInstId.
                        Instance newIns = new SparseInstance(currentInst);
                        newIns.setDataset(entdata);
                        entdata.add(newIns);
                        newIns.setValue(insAtt, maxInstId + oSubId);
                        newIns.setClassValue(cl);
                        instOps.put((int) maxInstId + oSubId, oId);
                    } else {
                        // First accepted class: relabel the instance in place.
                        currentInst.setClassValue(cl);
                    }
                    oSubId++;
                }
            }
        }
        // --- Pass 2: one-vs-all over attribute categories ---
        entdata.setClass(entdata.attribute("attCat"));
        onevsall = new WekaWrapper(entdata, true);
        if (!onlyTest) {
            onevsall.trainOneVsAll(modelsPath, paramFile + "attCat");
            System.out.println("trainATC: one vs all attcat models ready");
        }
        // NOTE(review): models were trained under prefix paramFile+"attCat"
        // but predictions load prefix paramFile+"entAttCat" — confirm this
        // mismatch is intentional.
        ovsaRes = onevsall.predictOneVsAll(modelsPath, paramFile + "entAttCat");
        insAtt = entdata.attribute("instanceId");
        // NOTE(review): pass 1 used numDistinctValues(insAtt) - 1 here;
        // insAtt.numValues() may not be equivalent — verify.
        maxInstId = entdata.kthSmallestValue(insAtt, insAtt.numValues());
        System.err.println("last instance has index: " + maxInstId);
        for (int ins = 0; ins < entdata.numInstances(); ins++) {
            System.err.println("ins: " + ins);
            int i = (int) entdata.instance(ins).value(insAtt);
            Instance currentInst = entdata.instance(ins);
            String sId = reader.getOpinion(instOps.get(i)).getsId();
            String oId = instOps.get(i);
            reader.removeSentenceOpinions(sId);
            int oSubId = 0;
            for (String cl : ovsaRes.get(i).keySet()) {
                if (ovsaRes.get(i).get(cl) > threshold2) {
                    // NOTE(review): threshold == threshold2 == 0.5, so this
                    // inner check is currently redundant — kept as written.
                    if (ovsaRes.get(i).get(cl) > threshold) {
                        if (oSubId >= 1) {
                            // Additional accepted class: add a new E#A opinion.
                            String label = currentInst.stringValue(entdata.attribute("entAtt")) + "#" + cl;
                            // Opinion(id, trgt, offsetFrom, offsetTo, polarity, cat, sId)
                            Opinion op = new Opinion(oId + "_" + oSubId, "", 0, 0, "", label, sId);
                            reader.addOpinion(op);
                        } else {
                            // First accepted class: replace the original opinion.
                            String label = currentInst.stringValue(entdata.attribute("entAtt")) + "#" + cl;
                            reader.removeOpinion(oId);
                            Opinion op = new Opinion(oId + "_" + oSubId, "", 0, 0, "", label, sId);
                            reader.addOpinion(op);
                        }
                        oSubId++;
                    }
                }
            }
        }
        reader.print2Semeval2015format(paramFile + "entAttCat.xml");
    } catch (Exception e) {
        // NOTE(review): broad catch — any failure only logs a stack trace.
        e.printStackTrace();
    }
    System.err.println("DONE CLI train-atc2 (oneVsAll)");
}
From source file:elh.eus.absa.CLI.java
License:Open Source License
/**
 * Trains ATC using a single one-vs-all classifier over the combined E#A
 * aspect categories ("entAttCat"). Every class whose one-vs-all confidence
 * exceeds the threshold yields an opinion; results are written out in
 * SemEval-2015 format.
 *
 * @param inputStream training corpus (or ignored when --testOnly with a testset)
 * @throws IOException if the corpus cannot be read
 */
public final void trainATCsingleCategory(final InputStream inputStream) throws IOException {
    // Load training parameters from the CLI arguments.
    String paramFile = parsedArguments.getString("params");
    String testFile = parsedArguments.getString("testset");
    String corpusFormat = parsedArguments.getString("corpusFormat");
    String lang = parsedArguments.getString("language");
    boolean nullSentenceOpinions = parsedArguments.getBoolean("nullSentences");
    boolean onlyTest = parsedArguments.getBoolean("testOnly");
    // Minimum one-vs-all confidence for accepting a class prediction.
    double threshold = 0.5;
    // NOTE(review): hard-coded, user-specific model directory.
    String modelsPath = "/home/inaki/Proiektuak/BOM/SEMEVAL2015/ovsaModels";
    CorpusReader reader = new CorpusReader(inputStream, corpusFormat, nullSentenceOpinions, lang);
    Features atcTrain = new Features(reader, paramFile, "3");
    Instances traindata = atcTrain.loadInstances(true, "atc");
    // In test-only mode, replace the corpus with the test set if it exists.
    if (onlyTest) {
        if (FileUtilsElh.checkFile(testFile)) {
            System.err.println("read from test file");
            reader = new CorpusReader(new FileInputStream(new File(testFile)), corpusFormat,
                    nullSentenceOpinions, lang);
            atcTrain.setCorpus(reader);
            traindata = atcTrain.loadInstances(true, "atc");
        }
    }
    WekaWrapper onevsall;
    try {
        // Keep only the combined class: drop the competing class attributes
        // and set "entAttCat" as the class. Mutates traindata in place.
        traindata.deleteAttributeAt(traindata.attribute("attCat").index());
        traindata.deleteAttributeAt(traindata.attribute("entCat").index());
        traindata.setClassIndex(traindata.attribute("entAttCat").index());
        onevsall = new WekaWrapper(traindata, true);
        if (!onlyTest) {
            onevsall.trainOneVsAll(modelsPath, paramFile + "entAttCat");
            System.out.println("trainATC: one vs all models ready");
        }
        onevsall.setTestdata(traindata);
        // instance index -> (class label -> confidence)
        HashMap<Integer, HashMap<String, Double>> ovsaRes = onevsall.predictOneVsAll(modelsPath,
                paramFile + "entAttCat");
        System.out.println("trainATC: one vs all predictions ready");
        // Invert opinion-id -> instance-id into instance-id -> opinion-id.
        HashMap<Integer, String> kk = new HashMap<Integer, String>();
        for (String oId : atcTrain.getOpinInst().keySet()) {
            kk.put(atcTrain.getOpinInst().get(oId), oId);
        }
        // Debug dump of the class labels seen for instance 1.
        Object[] ll = ovsaRes.get(1).keySet().toArray();
        for (Object l : ll) {
            System.err.print((String) l + " - ");
        }
        System.err.print("\n");
        // Replace each sentence's opinions with one opinion per accepted class.
        for (int i : ovsaRes.keySet()) {
            String sId = reader.getOpinion(kk.get(i)).getsId();
            reader.removeSentenceOpinions(sId);
            int oSubId = 0;
            for (String cl : ovsaRes.get(i).keySet()) {
                if (ovsaRes.get(i).get(cl) > threshold) {
                    oSubId++;
                    // Opinion(id, trgt, offsetFrom, offsetTo, polarity, cat, sId)
                    Opinion op = new Opinion(kk.get(i) + "_" + oSubId, "", 0, 0, "", cl, sId);
                    reader.addOpinion(op);
                }
            }
        }
        reader.print2Semeval2015format(paramFile + "entAttCat.xml");
    } catch (Exception e) {
        // NOTE(review): broad catch — any failure only logs a stack trace.
        e.printStackTrace();
    }
    System.err.println("DONE CLI train-atc2 (oneVsAll)");
}
From source file:elh.eus.absa.WekaWrapper.java
License:Open Source License
/**
 * Trains one-vs-all models over the given training data: for each value of
 * the class attribute, relabels the dataset to {that value, "UNKNOWN"},
 * trains a binary classifier, and stores the model as
 * {@code modelpath/prefix_<value>.model}. The original training data is
 * restored afterwards.
 *
 * @param modelpath directory to store each model for the one vs. all method
 * @param prefix prefix the models should have (each model will have the name
 *               of its class appended)
 * @throws Exception propagated from training, saving, or testing a model
 */
public void trainOneVsAll(String modelpath, String prefix) throws Exception {
    // Snapshot so the wrapper's training data can be restored at the end.
    Instances orig = new Instances(traindata);
    Enumeration<Object> classValues = traindata.classAttribute().enumerateValues();
    String classAtt = traindata.classAttribute().name();
    while (classValues.hasMoreElements()) {
        String v = (String) classValues.nextElement();
        System.err.println("trainer onevsall for class " + v + " classifier");
        // "dummy" is a placeholder label needed because of weka's sparse
        // data format problems — never train a model for it.
        if (v.equalsIgnoreCase("dummy")) {
            continue;
        }
        // Fresh copy per class value so the relabelling doesn't accumulate.
        Instances ovsa = new Instances(orig);
        // Declare the binary replacement class attribute along with its
        // values ("dummy" again reserved for the sparse-format issue).
        ArrayList<String> classVal = new ArrayList<String>();
        classVal.add("dummy");
        classVal.add(v);
        classVal.add("UNKNOWN");
        ovsa.insertAttributeAt(new Attribute(classAtt + "2", classVal), ovsa.numAttributes());
        // Relabel every instance: current class value keeps v, all other
        // labels collapse to "UNKNOWN".
        for (int i = 0; i < ovsa.numInstances(); i++) {
            Instance inst = ovsa.instance(i);
            String instClass = inst.stringValue(ovsa.attribute(classAtt).index());
            if (instClass.equalsIgnoreCase(v)) {
                inst.setValue(ovsa.attribute(classAtt + "2").index(), v);
            } else {
                inst.setValue(ovsa.attribute(classAtt + "2").index(), "UNKNOWN");
            }
        }
        // Swap in the new class attribute. Order matters: point the class
        // index at the new attribute BEFORE deleting the old one (weka will
        // not delete the current class attribute), then rename it back so
        // the models see the original attribute name.
        ovsa.setClassIndex(ovsa.attribute(classAtt + "2").index());
        ovsa.deleteAttributeAt(ovsa.attribute(classAtt).index());
        ovsa.renameAttribute(ovsa.attribute(classAtt + "2").index(), classAtt);
        ovsa.setClassIndex(ovsa.attribute(classAtt).index());
        // Build the classifier and store the model, then self-test on the
        // same (training) data.
        setTraindata(ovsa);
        saveModel(modelpath + File.separator + prefix + "_" + v + ".model");
        setTestdata(ovsa);
        testModel(modelpath + File.separator + prefix + "_" + v + ".model");
        System.err.println("trained onevsall " + v + " classifier");
    }
    // Restore the wrapper's original training data.
    setTraindata(orig);
}
From source file:entities.WekaBaselineBOWFeatureVector.java
/**
 * Assembles a Weka dataset from two lists of baseline BOW feature vectors,
 * persists it to ./data/test.arff, and returns it. The last attribute is
 * used as the class attribute.
 *
 * Assumes vList is non-empty: the relation name is taken from the first
 * vector's label.
 *
 * @param vList  first batch of feature vectors (supplies the relation name)
 * @param vList2 second batch of feature vectors, appended after the first
 * @return the populated dataset
 * @throws IOException if the ARFF file cannot be written
 */
public Instances fillInstanceSet(ArrayList<BaselineBOWFeatureVector> vList,
        ArrayList<BaselineBOWFeatureVector> vList2) throws IOException {
    ArrayList<Attribute> attributes = initializeWekaFeatureVector();
    // Capacity is only a hint; both lists end up in the same dataset.
    Instances dataset = new Instances(vList.get(0).getLabel(), attributes, vList.size());
    dataset.setClassIndex(dataset.numAttributes() - 1);
    for (BaselineBOWFeatureVector vector : vList) {
        dataset.add(fillFeatureVector(vector, dataset));
    }
    for (BaselineBOWFeatureVector vector : vList2) {
        dataset.add(fillFeatureVector(vector, dataset));
    }
    // Persist the assembled dataset in ARFF format.
    ArffSaver saver = new ArffSaver();
    saver.setInstances(dataset);
    saver.setFile(new File("./data/test.arff"));
    saver.writeBatch();
    return dataset;
}
From source file:entities.WekaBOWFeatureVector.java
/**
 * Assembles a Weka dataset from two lists of BOW feature vectors, persists
 * it to ./data/test.arff, and returns it. The last attribute is used as the
 * class attribute.
 *
 * Assumes vList is non-empty: the relation name is taken from the first
 * vector's label.
 *
 * @param vList  first batch of feature vectors (supplies the relation name)
 * @param vList2 second batch of feature vectors, appended after the first
 * @return the populated dataset
 * @throws IOException if the ARFF file cannot be written
 */
public Instances fillInstanceSet(ArrayList<BOWFeatureVector> vList,
        ArrayList<BOWFeatureVector> vList2) throws IOException {
    ArrayList<Attribute> attributes = initializeWekaFeatureVector();
    // Capacity is only a hint; both lists end up in the same dataset.
    Instances dataset = new Instances(vList.get(0).getLabel(), attributes, vList.size());
    dataset.setClassIndex(dataset.numAttributes() - 1);
    for (BOWFeatureVector vector : vList) {
        dataset.add(fillFeatureVector(vector, dataset));
    }
    for (BOWFeatureVector vector : vList2) {
        dataset.add(fillFeatureVector(vector, dataset));
    }
    // Persist the assembled dataset in ARFF format.
    ArffSaver saver = new ArffSaver();
    saver.setInstances(dataset);
    saver.setFile(new File("./data/test.arff"));
    saver.writeBatch();
    return dataset;
}
From source file:entities.WekaHMMFeatureVector.java
public Instances fillInstanceSet(ArrayList<HMMFeatureVector> vList, ArrayList<HMMFeatureVector> vList2) throws IOException { //FastVector fvWekaAttributesHmm = new FastVector(3); ArrayList<Attribute> attributes = initializeWekaFeatureVector(); Instances isSet = new Instances("dataset", attributes, vList.size()); isSet.setClassIndex(isSet.numAttributes() - 1); for (HMMFeatureVector HMMv : vList) { Instance i = fillFeatureVector(HMMv, isSet); isSet.add(i);//from w w w . j a v a 2 s . c o m } for (HMMFeatureVector HMMv : vList2) { Instance i = fillFeatureVector(HMMv, isSet); isSet.add(i); } ArffSaver saver = new ArffSaver(); saver.setInstances(isSet); saver.setFile(new File("./data/test.arff")); saver.writeBatch(); return isSet; }