sentinets.SentiNets.java Source code

Java tutorial

Introduction

Here is the source code for sentinets.SentiNets.java

Source

/*******************************************************************************
 * Copyright (c) 2015 University of Illinois Board of Trustees, All rights reserved.
 * Developed at GSLIS/ the iSchool, by Dr. Jana Diesner, Shubhanshu Mishra, Liang Tao, and Chieh-Li Chin.    
 * This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation; either version 2 of the License, or any later version.
 * This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
 * You should have received a copy of the GNU General Public License along with this program; if not, see <http://www.gnu.org/licenses>.
 *******************************************************************************/
/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package sentinets;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintStream;
import java.util.ArrayList;

import weka.classifiers.meta.FilteredClassifier;
import weka.core.Instances;
import weka.core.converters.CSVLoader;
import weka.core.converters.CSVSaver;
import weka.filters.unsupervised.attribute.Remove;

/**
 * Applies two pre-trained, serialized Weka {@link FilteredClassifier} models
 * (referred to in this file as "E_P" and "S_NS") to a tab-separated input
 * file and writes the predictions as ARFF and TSV files into an output
 * directory.
 *
 * <p>The input is expected to have 16 columns, with the E_P class in column
 * 15 and the S_NS class in column 16 (1-based); these positions are hard-coded
 * in the loader and filter specifications below.
 *
 * <p>Error handling is mostly best-effort (stack trace printed, processing
 * continues), except where continuing could only end in a
 * {@link NullPointerException}; those cases raise an
 * {@link IllegalStateException} carrying the original cause.
 *
 * @author Shubhanshu
 */
public class SentiNets {

    /** Full data set as loaded from the input file; receives the predicted labels. */
    Instances original;
    /** Copy of {@link #unlabled} whose class values are the predictions of the last run. */
    Instances labled;
    /** Feature view handed to the classifier (string/meta columns removed). */
    Instances unlabled;
    /** Feature view without the S_NS class column (used with the E_P model). */
    Instances ep_instances;
    /** Feature view without the E_P class column (used with the S_NS model). */
    Instances sns_instances;
    String inputFile, outputDir;
    /** Mode requested at construction: {@link #E_P}, {@link #S_NS} or {@link #BOTH}. */
    int classifierType;
    /** 0-based index in {@link #original} of the column the current model predicts. */
    int clsIndex;
    String filePrefix;
    FilteredClassifier cls;
    static final int E_P = 0;
    static final int S_NS = 1;
    static final int BOTH = 2;
    private static final String EP_FILE_PREFIX = "E_P_ALL";
    private static final String SNS_FILE_PREFIX = "S_NS_ALL";
    private static final String BOTH_FILE_PREFIX = "ALL";
    private static final String EP_MODEL_FILE = "./data/models/E_P_ALL.model";
    private static final String SNS_MODEL_FILE = "./data/models/S_NS_ALL.model";

    /* 0-based attribute positions inside the 16-column input. */
    private static final int MENTIONS_ATTR = 3 - 1;
    private static final int USER_ATTR = 11 - 1;
    private static final int HASHTAGS_ATTR = 14 - 1;
    private static final int EP_CLASS_ATTR = 15 - 1;
    private static final int SNS_CLASS_ATTR = 16 - 1;

    /**
     * Per-instance class distributions, stored as {@code {{E, P}, {S, NS}}}.
     * Entry {@code i} belongs to instance {@code i} of {@link #original}.
     */
    ArrayList<Double[][]> classDist;

    /**
     * Model selected by the most recent successful {@link #setClassifier(int)}
     * call. Kept separately from {@link #classifierType} because in
     * {@link #BOTH} mode the requested mode never equals the model in use.
     */
    private int activeClassifier;

    /**
     * Runs both classifiers on a fixed sample file.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        String inputFile = "./data/output/cyber_bull.tsv";
        String outputDir = "./data/output";
        SentiNets obj = new SentiNets(inputFile, outputDir, SentiNets.BOTH);
        obj.doPredictions();
    }

    /**
     * Loads the input file and prepares the feature view for the given mode.
     *
     * @param inputFile      path of the tab-separated input file
     * @param outputDir      directory that receives the prediction files
     * @param classifierType {@link #E_P}, {@link #S_NS} or {@link #BOTH}
     * @throws IllegalStateException if the input file cannot be loaded
     */
    public SentiNets(String inputFile, String outputDir, int classifierType) {
        this.inputFile = inputFile;
        this.outputDir = outputDir;
        this.classifierType = classifierType;
        this.activeClassifier = classifierType;
        // Private loader: avoids calling an overridable method from the constructor.
        loadInstances(this.inputFile);
    }

    /**
     * Classifies every loaded instance according to {@link #classifierType} and
     * writes the per-model prediction files, the combined "ALL" file and the
     * per-instance statistics.
     *
     * @throws IllegalStateException in {@link #BOTH} mode, if the per-model
     *                               feature views cannot be derived
     */
    public void doPredictions() {
        classDist = new ArrayList<Double[][]>();
        ensureClassDist(unlabled.numInstances());

        if (classifierType == BOTH) {
            splitClassViews();

            System.out.println("Classifying E_P classes");
            setClassifier(E_P);
            unlabled = ep_instances;
            // Removing the last column dropped the class attribute, so re-declare it.
            unlabled.setClassIndex(unlabled.numAttributes() - 1);
            performClassification();
            labled.renameAttribute(labled.numAttributes() - 1, "e/p");
            writePredictions(labled, EP_FILE_PREFIX);

            System.out.println("Classifying S_NS classes");
            setClassifier(S_NS);
            unlabled = sns_instances;
            performClassification();
            writePredictions(labled, SNS_FILE_PREFIX);
        } else {
            setClassifier(classifierType);
            performClassification();
            writePredictions(labled, filePrefix);
        }
        writeCombinedOutput();
    }

    /**
     * Formats a distribution matrix as four tab-separated values.
     *
     * @param dist matrix laid out as {@code {{E, P}, {S, NS}}}, or {@code null}
     * @return {@code E\tP\tS\tNS}, or the empty string when {@code dist} is {@code null}
     */
    public String printDist(Double[][] dist) {
        if (dist == null) {
            return "";
        }
        return dist[0][0] + "\t" + dist[0][1] + "\t" + dist[1][0] + "\t" + dist[1][1];
    }

    /**
     * Prints a summary of the given instances followed by one line per instance
     * containing user, mentions, hashtags and the recorded class distribution.
     *
     * @param tweetInstances instances using the 16-column input layout
     */
    public void writeStats(Instances tweetInstances) {
        System.out.println("Stats Instances: \n" + tweetInstances.toSummaryString());
        for (int i = 0; i < tweetInstances.numInstances(); i++) {
            String user = tweetInstances.instance(i).stringValue(USER_ATTR);
            String mentions = tweetInstances.instance(i).stringValue(MENTIONS_ATTR);
            String hashtags = tweetInstances.instance(i).stringValue(HASHTAGS_ATTR);
            // printDist() renders a missing distribution as an empty string.
            Double[][] dist = (classDist != null && i < classDist.size()) ? classDist.get(i) : null;
            System.out.println("Tweet Details:\t" + user + "\t" + mentions + "\t" + hashtags + "\t"
                    + printDist(dist));
        }
    }

    /**
     * Classifies every instance of {@link #unlabled} with the current model,
     * storing the label in {@link #labled} and {@link #original} and the class
     * distribution in {@link #classDist}.
     *
     * <p>Best-effort: if classifying an instance fails, the stack trace is
     * printed and the instance keeps label {@code 0} and its previous
     * distribution.
     */
    public void performClassification() {
        labled = new Instances(unlabled);
        labled.setClassIndex(labled.numAttributes() - 1);

        int total = unlabled.numInstances();
        ensureClassDist(total);
        // Row 0 holds the E/P distribution, row 1 the S/NS distribution.
        int distIndex = (activeClassifier == S_NS) ? 1 : 0;

        for (int i = 0; i < total; i++) {
            double clsLabel = 0;
            Double[][] instanceDist = classDist.get(i);
            try {
                double[] dist = cls.distributionForInstance(unlabled.instance(i));
                instanceDist[distIndex][0] = dist[0];
                instanceDist[distIndex][1] = dist[1];
                clsLabel = cls.classifyInstance(unlabled.instance(i));
            } catch (Exception e) {
                e.printStackTrace();
            }
            labled.instance(i).setClassValue(clsLabel);
            original.instance(i).setValue(clsIndex, clsLabel);
        }
    }

    /**
     * Loads the serialized model for the given classifier and records which
     * column of {@link #original} it predicts and which file prefix it uses.
     *
     * <p>An unknown type only prints a message and leaves the current state
     * untouched. If the model file cannot be read, the stack trace is printed
     * and {@link #cls} is left {@code null}.
     *
     * @param classifierType {@link #E_P} or {@link #S_NS}
     */
    public void setClassifier(int classifierType) {
        String modelFile;
        if (classifierType == E_P) {
            modelFile = EP_MODEL_FILE;
            clsIndex = EP_CLASS_ATTR;
            filePrefix = EP_FILE_PREFIX;
        } else if (classifierType == S_NS) {
            modelFile = SNS_MODEL_FILE;
            clsIndex = SNS_CLASS_ATTR;
            filePrefix = SNS_FILE_PREFIX;
        } else {
            System.out.println("Wrong Classifier type");
            return;
        }
        // The parameter shadows the field, so remember the model actually in use.
        activeClassifier = classifierType;
        cls = null;
        try {
            cls = (FilteredClassifier) weka.core.SerializationHelper.read(modelFile);
        } catch (Exception e1) {
            e1.printStackTrace();
        }
    }

    /**
     * Loads {@code inputFile} into {@link #original} and derives the feature
     * view {@link #unlabled} for the configured {@link #classifierType}.
     *
     * @param inputFile path of the tab-separated input file
     * @throws IllegalStateException if the file cannot be loaded
     */
    public void setInstances(String inputFile) {
        loadInstances(inputFile);
    }

    /** Implementation of {@link #setInstances(String)}; private so the constructor can call it safely. */
    private void loadInstances(String inputFile) {
        String[] nominalVals = { "15:e,p", "16:s,na_ns" };
        original = null;
        Exception failure = null;
        try {
            CSVLoader csvSource = new CSVLoader();
            csvSource.setSource(new File(inputFile));
            csvSource.setFieldSeparator("\t");
            csvSource.setNominalAttributes("15-16");
            csvSource.setStringAttributes("3,4,6,8,10-12,14");
            csvSource.setNominalLabelSpecs(nominalVals);
            original = csvSource.getDataSet();
            unlabled = original;

            // Drop the string/meta columns; 9 attributes remain, the last two being the classes.
            Remove r = new Remove();
            r.setAttributeIndices("3-4,6,10-12,14");
            r.setInputFormat(unlabled);
            unlabled = Remove.useFilter(unlabled, r);

            r = new Remove();
            if (classifierType == E_P) {
                System.out.println("Filtering instances for E_P");
                r.setAttributeIndices("9");
                r.setInputFormat(unlabled);
                unlabled = Remove.useFilter(unlabled, r);
            } else if (classifierType == S_NS) {
                System.out.println("Filtering instances for S_NS");
                r.setAttributeIndices("8");
                r.setInputFormat(unlabled);
                unlabled = Remove.useFilter(unlabled, r);
            }
        } catch (Exception e) {
            failure = e;
        }

        if (original == null) {
            // Nothing was loaded; continuing could only end in a NullPointerException.
            throw new IllegalStateException("Unable to load instances from '" + inputFile + "'", failure);
        }
        if (failure != null) {
            // Data was loaded but a filter step failed: keep going with what we have.
            failure.printStackTrace();
        }
        unlabled.setClassIndex(unlabled.numAttributes() - 1);
    }

    /**
     * Writes the instances to {@code <outputDir>/<filePrefix>.arff} and
     * {@code <outputDir>/<filePrefix>.tsv}. I/O errors are printed, not thrown.
     *
     * @param ins        instances to write
     * @param filePrefix file name without extension
     */
    public void writePredictions(Instances ins, String filePrefix) {
        String basePath = outputDir + "/" + filePrefix;
        try {
            // try-with-resources closes the writer even when write() fails.
            try (BufferedWriter writer = new BufferedWriter(new FileWriter(basePath + ".arff"))) {
                writer.write(ins.toString());
                writer.newLine();
            }

            CSVSaver s = new CSVSaver();
            s.setFile(new File(basePath + ".tsv"));
            s.setInstances(ins);
            s.setFieldSeparator("\t");
            s.writeBatch();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /** Builds {@link #ep_instances} and {@link #sns_instances} from the 9-attribute feature view. */
    private void splitClassViews() {
        try {
            Remove dropSns = new Remove();
            dropSns.setAttributeIndices("9");
            dropSns.setInputFormat(unlabled);
            ep_instances = Remove.useFilter(unlabled, dropSns);

            Remove dropEp = new Remove();
            dropEp.setAttributeIndices("8");
            dropEp.setInputFormat(unlabled);
            sns_instances = Remove.useFilter(unlabled, dropEp);
        } catch (Exception e) {
            throw new IllegalStateException("Unable to derive the E_P / S_NS feature views", e);
        }
    }

    /** Renames the class columns of {@link #original}, writes the combined "ALL" files and prints the statistics. */
    private void writeCombinedOutput() {
        try {
            int last = original.numAttributes() - 1;
            original.renameAttribute(last - 1, "e/p");
            original.renameAttribute(last, "s/ns");
            original.renameAttributeValue(last, 1, "ns");

            Remove r = new Remove();
            r.setAttributeIndices("1-3,5,7-10,13-14");
            r.setInputFormat(original);
            writePredictions(Remove.useFilter(original, r), BOTH_FILE_PREFIX);
            writeStats(original);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Makes sure {@link #classDist} holds a non-null matrix for each of the first {@code count} instances. */
    private void ensureClassDist(int count) {
        if (classDist == null) {
            classDist = new ArrayList<Double[][]>();
        }
        for (int i = 0; i < count; i++) {
            if (i >= classDist.size()) {
                classDist.add(emptyDist());
            } else if (classDist.get(i) == null) {
                classDist.set(i, emptyDist());
            }
        }
    }

    /**
     * Creates the placeholder distribution {@code {{0, 0}, {0, 0}}}.
     * NOTE(review): zeros are an assumption; the original placeholder came from
     * a helper that is not part of this file.
     */
    private static Double[][] emptyDist() {
        return new Double[][] { { 0.0, 0.0 }, { 0.0, 0.0 } };
    }

}