Example usage of edu.stanford.nlp.classify.SVMLightClassifier.setPlatt

List of usage examples of edu.stanford.nlp.classify.SVMLightClassifier.setPlatt

Introduction

On this page you can find an example usage of edu.stanford.nlp.classify.SVMLightClassifier.setPlatt.

Prototype

public void setPlatt(LinearClassifier<L, L> platt) 

Source Link

Usage

From source file: gr.aueb.cs.nlp.wordtagger.classifier.SVMWindows64Factory.java

License: Open Source License

/**
 * Trains a classifier by running an external SVM learner and reading its model back in.
 *
 * <p>The dataset is written to a temporary file in SVM-light format. The external learner is
 * then run on it:
 * <ul>
 *   <li>{@code svmStructLearn} when the dataset has more than two classes;</li>
 *   <li>otherwise {@code svmPerfLearn} or {@code svmLightLearn}, depending on
 *       {@code useSVMPerf}.</li>
 * </ul>
 * The model file it produces is parsed with {@code readModel} and converted into an
 * {@link SVMLightClassifier}.
 *
 * <p>Optional steps:
 * <ul>
 *   <li>If {@code doEval} is set, the matching classify program is run on the training data,
 *       and the new classifier's scores are also written to a temporary file.</li>
 *   <li>If {@code useSigmoid} is set, a Platt sigmoid is fitted and attached via
 *       {@code setPlatt}.</li>
 * </ul>
 *
 * @param dataset the training data
 * @return the trained classifier
 * @throws RuntimeException wrapping any exception raised while writing temporary files,
 *     running the external programs, or reading the model
 */
public SVMLightClassifier<L, F> trainClassifierBasic(GeneralDataset<L, F> dataset) {
    Index<L> labelIndex = dataset.labelIndex();
    Index<F> featureIndex = dataset.featureIndex;
    boolean multiclass = (dataset.numClasses() > 2);
    try {
        // the file that the model will be saved to
        File modelFile = createSvmTempFile(".model");
        // the file that the SVM-light formatted dataset will be printed to
        File dataFile = createSvmTempFile(".data");

        // print the dataset; the writer is closed even if printing fails
        PrintWriter pw = new PrintWriter(new FileWriter(dataFile));
        try {
            dataset.printSVMLightFormat(pw);
        } finally {
            pw.close();
        }

        // -v sets the verbosity
        // -m 5000 gives the learner a larger cache, for faster training
        String learnOptions = (multiclass ? svmStructLearn : (useSVMPerf ? svmPerfLearn : svmLightLearn))
                + " -v " + svmLightVerbosity + " -m 5000 -w 3 -t 0 -g 7";

        // set the value of C if we have one specified
        if (C > 0.0) {
            learnOptions = learnOptions + " -c " + C;
        } else if (useSVMPerf) {
            // it's required to specify this parameter for SVM perf
            learnOptions = learnOptions + " -c " + 0.01;
        }

        // The executable and the fixed options are split on whitespace.
        // File paths are appended afterwards as single arguments, so a temp directory whose
        // path contains spaces does not get split into several bogus arguments.
        java.util.List<String> learnCmd = splitSvmCommand(learnOptions);

        // Alpha file
        if (useAlphaFile) {
            File newAlphaFile = createSvmTempFile(".alphas");
            learnCmd.add("-a");
            learnCmd.add(newAlphaFile.getAbsolutePath());
            if (alphaFile != null) {
                // pass the alpha file written by the previous call (alphaFile is updated below)
                learnCmd.add("-y");
                learnCmd.add(alphaFile.getAbsolutePath());
            }
            alphaFile = newAlphaFile;
        }

        // File and model data
        learnCmd.add(dataFile.getAbsolutePath());
        learnCmd.add(modelFile.getAbsolutePath());

        if (verbose) {
            System.err.println("<< " + joinSvmCommand(learnCmd) + " >>");
        }
        SystemUtils.run(new ProcessBuilder(learnCmd), new PrintWriter(System.err),
                new PrintWriter(System.err));

        if (doEval) {
            File predictFile = createSvmTempFile(".pred");
            java.util.List<String> evalCmd = splitSvmCommand(
                    multiclass ? svmStructClassify : (useSVMPerf ? svmPerfClassify : svmLightClassify));
            evalCmd.add(dataFile.getAbsolutePath());
            evalCmd.add(modelFile.getAbsolutePath());
            evalCmd.add(predictFile.getAbsolutePath());
            if (verbose) {
                System.err.println("<< " + joinSvmCommand(evalCmd) + " >>");
            }
            SystemUtils.run(new ProcessBuilder(evalCmd), new PrintWriter(System.err),
                    new PrintWriter(System.err));
        }

        // read in the model file
        Pair<Double, ClassicCounter<Integer>> weightsAndThresh = readModel(modelFile, multiclass);
        double threshold = weightsAndThresh.first();
        ClassicCounter<Pair<F, L>> weights = convertWeights(weightsAndThresh.second(), featureIndex, labelIndex,
                multiclass);
        ClassicCounter<L> thresholds = new ClassicCounter<L>();
        if (!multiclass) {
            // binary case: the two labels get thresholds of opposite sign
            thresholds.setCount(labelIndex.get(0), -threshold);
            thresholds.setCount(labelIndex.get(1), threshold);
        }
        SVMLightClassifier<L, F> classifier = new SVMLightClassifier<L, F>(weights, thresholds);

        if (doEval) {
            // write the classifier's scores for every training datum to a temp file
            File predictFile = createSvmTempFile(".pred2");
            PrintWriter pw2 = new PrintWriter(predictFile);
            try {
                NumberFormat nf = NumberFormat.getNumberInstance();
                nf.setMaximumFractionDigits(5);
                for (Datum<L, F> datum : dataset) {
                    Counter<L> scores = classifier.scoresOf(datum);
                    pw2.println(Counters.toString(scores, nf));
                }
            } finally {
                pw2.close();
            }
        }

        if (useSigmoid) {
            if (verbose) {
                System.out.print("fitting sigmoid...");
            }
            classifier.setPlatt(fitSigmoid(classifier, dataset));
            if (verbose) {
                System.out.println("done");
            }
        }

        return classifier;
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
}

/**
 * Creates a temporary file named {@code svm-*<suffix>}.
 *
 * <p>The file is scheduled for deletion on JVM exit when {@code deleteTempFilesOnExit}
 * is set.
 *
 * @param suffix the file-name suffix, e.g. {@code ".model"}
 * @return the newly created file
 * @throws java.io.IOException if the file cannot be created
 */
private File createSvmTempFile(String suffix) throws java.io.IOException {
    File file = File.createTempFile("svm-", suffix);
    if (deleteTempFilesOnExit) {
        file.deleteOnExit();
    }
    return file;
}

/**
 * Splits a command string on {@code whitespacePattern} into a mutable argument list.
 *
 * <p>Empty tokens are dropped.
 *
 * @param command the executable plus any options, separated by whitespace
 * @return a new, mutable list of the non-empty tokens
 */
private java.util.List<String> splitSvmCommand(String command) {
    java.util.List<String> tokens = new java.util.ArrayList<String>();
    for (String token : whitespacePattern.split(command)) {
        if (token.length() > 0) {
            tokens.add(token);
        }
    }
    return tokens;
}

/**
 * Joins command arguments with single spaces, for verbose logging only.
 *
 * @param args the command arguments
 * @return the arguments separated by single spaces
 */
private static String joinSvmCommand(java.util.List<String> args) {
    StringBuilder sb = new StringBuilder();
    for (int i = 0; i < args.size(); i++) {
        if (i > 0) {
            sb.append(' ');
        }
        sb.append(args.get(i));
    }
    return sb.toString();
}