Example usage for edu.stanford.nlp.classify SVMLightClassifier SVMLightClassifier

List of usage examples for edu.stanford.nlp.classify SVMLightClassifier SVMLightClassifier

Introduction

On this page you can find example usage of the edu.stanford.nlp.classify.SVMLightClassifier constructor.

Prototype

public SVMLightClassifier(ClassicCounter<Pair<F, L>> weightCounter, ClassicCounter<L> thresholds) 

Source Link

Usage

From source file: gr.aueb.cs.nlp.wordtagger.classifier.SVMWindows64Factory.java

License:Open Source License

/**
 * Trains a classifier by delegating to an external SVM learner executable.
 *
 * <p>The dataset is written to a temporary file in SVM-light format, the learner
 * command is assembled from this factory's configuration fields and executed, and
 * the resulting model file is read back and converted into an
 * {@link SVMLightClassifier}. The learner executable is chosen as follows: the
 * struct learner when the dataset has more than two classes, otherwise the perf
 * learner when {@code useSVMPerf} is set, otherwise the light learner.
 *
 * <p>Side effects visible in this method: temporary files are created (and
 * registered for deletion on JVM exit when {@code deleteTempFilesOnExit} is set);
 * when {@code useAlphaFile} is set, the {@code alphaFile} field is replaced by the
 * newly created alpha file; when {@code doEval} is set, the external classify
 * command is run and the trained classifier's scores are written to a second
 * temporary file.
 *
 * @param dataset the training data; its label and feature indices are used to
 *                map the learned weights back to labels and features
 * @return the trained classifier, with a Platt sigmoid attached when
 *         {@code useSigmoid} is set
 * @throws RuntimeException wrapping any exception raised while writing files,
 *                          running the external commands, or reading the model
 */
public SVMLightClassifier<L, F> trainClassifierBasic(GeneralDataset<L, F> dataset) {
    Index<L> labelIndex = dataset.labelIndex();
    Index<F> featureIndex = dataset.featureIndex;
    boolean multiclass = (dataset.numClasses() > 2);
    try {
        // this is the file that the model will be saved to
        File modelFile = File.createTempFile("svm-", ".model");
        if (deleteTempFilesOnExit) {
            modelFile.deleteOnExit();
        }

        // this is the file that the svm light formatted dataset
        // will be printed to
        File dataFile = File.createTempFile("svm-", ".data");
        if (deleteTempFilesOnExit) {
            dataFile.deleteOnExit();
        }

        // print the dataset; close the writer even if printing fails so the
        // file handle is not leaked
        PrintWriter pw = new PrintWriter(new FileWriter(dataFile));
        try {
            dataset.printSVMLightFormat(pw);
        } finally {
            pw.close();
        }

        // -v sets the verbosity (0 makes it not verbose)
        // -m 5000 gives it a larger cache, for faster training
        // NOTE(review): the meaning of "-w 3 -t 0 -g 7" depends on which learner
        // binary is selected -- confirm against that tool's documentation
        String cmd = (multiclass ? svmStructLearn : (useSVMPerf ? svmPerfLearn : svmLightLearn)) + " -v "
                + svmLightVerbosity + " -m 5000 -w 3 -t 0 -g 7 ";

        // set the value of C if we have one specified
        if (C > 0.0) {
            cmd = cmd + " -c " + C + " "; // C value
        } else if (useSVMPerf) {
            cmd = cmd + " -c " + 0.01 + " "; // It's required to specify this parameter for SVM perf
        }

        // Alpha File: write alphas to a fresh temp file (-a) and, if a previous
        // alpha file exists, pass it along (-y); then remember the new one
        if (useAlphaFile) {
            File newAlphaFile = File.createTempFile("svm-", ".alphas");
            if (deleteTempFilesOnExit) {
                newAlphaFile.deleteOnExit();
            }
            cmd = cmd + " -a " + newAlphaFile.getAbsolutePath();
            if (alphaFile != null) {
                cmd = cmd + " -y " + alphaFile.getAbsolutePath();
            }
            alphaFile = newAlphaFile;
        }

        // File and Model Data
        cmd = cmd + " " + dataFile.getAbsolutePath() + " " + modelFile.getAbsolutePath();

        if (verbose) {
            System.err.println("<< " + cmd + " >>");
        }

        // The command string is split on whitespace into an argument list, so
        // paths containing whitespace would be broken into several arguments.
        SystemUtils.run(new ProcessBuilder(whitespacePattern.split(cmd)), new PrintWriter(System.err),
                new PrintWriter(System.err));

        if (doEval) {
            // run the external classifier over the training data
            File predictFile = File.createTempFile("svm-", ".pred");
            if (deleteTempFilesOnExit) {
                predictFile.deleteOnExit();
            }
            String evalCmd = (multiclass ? svmStructClassify
                    : (useSVMPerf ? svmPerfClassify : svmLightClassify)) + " " + dataFile.getAbsolutePath()
                    + " " + modelFile.getAbsolutePath() + " " + predictFile.getAbsolutePath();
            if (verbose) {
                System.err.println("<< " + evalCmd + " >>");
            }
            SystemUtils.run(new ProcessBuilder(whitespacePattern.split(evalCmd)), new PrintWriter(System.err),
                    new PrintWriter(System.err));
        }

        // read in the model file
        Pair<Double, ClassicCounter<Integer>> weightsAndThresh = readModel(modelFile, multiclass);
        double threshold = weightsAndThresh.first();
        ClassicCounter<Pair<F, L>> weights = convertWeights(weightsAndThresh.second(), featureIndex, labelIndex,
                multiclass);
        ClassicCounter<L> thresholds = new ClassicCounter<L>();
        if (!multiclass) {
            // binary case: the two labels get thresholds of opposite sign
            thresholds.setCount(labelIndex.get(0), -threshold);
            thresholds.setCount(labelIndex.get(1), threshold);
        }
        SVMLightClassifier<L, F> classifier = new SVMLightClassifier<L, F>(weights, thresholds);

        if (doEval) {
            // dump the in-process classifier's scores for every datum
            File predictFile = File.createTempFile("svm-", ".pred2");
            if (deleteTempFilesOnExit) {
                predictFile.deleteOnExit();
            }
            NumberFormat nf = NumberFormat.getNumberInstance();
            nf.setMaximumFractionDigits(5);
            PrintWriter pw2 = new PrintWriter(predictFile);
            try {
                for (Datum<L, F> datum : dataset) {
                    Counter<L> scores = classifier.scoresOf(datum);
                    pw2.println(Counters.toString(scores, nf));
                }
            } finally {
                // close even if scoring throws, so the file handle is not leaked
                pw2.close();
            }
        }

        if (useSigmoid) {
            if (verbose) {
                System.out.print("fitting sigmoid...");
            }
            classifier.setPlatt(fitSigmoid(classifier, dataset));
            if (verbose) {
                System.out.println("done");
            }
        }

        return classifier;
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
}