it.acubelab.smaph.learn.GenerateModel.java Source code

Java tutorial

Introduction

Here is the source code for it.acubelab.smaph.learn.GenerateModel.java

Source

/**
 *  Copyright 2014 Marco Cornolti
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 * 
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */

package it.acubelab.smaph.learn;

import it.unipi.di.acube.batframework.metrics.MetricsResultSet;
import it.unipi.di.acube.batframework.systemPlugins.WATAnnotator;
import it.unipi.di.acube.batframework.utils.FreebaseApi;
import it.unipi.di.acube.batframework.utils.WikipediaApiInterface;
import it.acubelab.smaph.SmaphAnnotator;
import it.acubelab.smaph.SmaphAnnotatorDebugger;
import it.acubelab.smaph.SmaphConfig;
import it.cnr.isti.hpc.erd.WikipediaToFreebase;

import java.util.*;

import libsvm.svm;
import libsvm.svm_model;
import libsvm.svm_problem;

import org.apache.commons.lang3.tuple.Triple;

/**
 * Command-line driver that trains libsvm models over a grid of feature sets,
 * (gamma, C) pairs and (positive, negative) weight pairs.
 *
 * <p>
 * For every combination, the scaled feature ranges are dumped to
 * {@code <base>.range}, the trained model is saved to {@code <base>.model}, and
 * the metrics computed on the test examples are collected and printed once all
 * models have been trained. The {@code <base>} names are produced by
 * {@link #getModelFileNameBaseEF(Integer[], double, double, double, double, double)}.
 */
public class GenerateModel {

    /**
     * Trains, saves and evaluates one model per configuration of the
     * hard-coded search grid, then prints the collected results to standard
     * output. Progress is reported on standard error.
     *
     * @param args
     *            ignored.
     * @throws Exception
     *             propagated unchanged from configuration loading, example
     *             gathering, training or file output.
     */
    public static void main(String[] args) throws Exception {
        // Force '.' as decimal separator in everything printed below.
        Locale.setDefault(Locale.US);
        String freebKey = "";

        SmaphConfig.setConfigFile("smaph-config.xml");
        String bingKey = SmaphConfig.getDefaultBingKey();

        WikipediaApiInterface wikiApi = new WikipediaApiInterface("wid.cache", "redirect.cache");
        FreebaseApi freebApi = new FreebaseApi(freebKey, "freeb.cache");

        // Each row is a {gamma, C} pair (see how the rows are unpacked in the
        // training loop). Only the uncommented rows are trained.
        double[][] paramsToTest = new double[][] {
                /*
                 * Previously explored:
                 * {0.035, 0.5 }, {0.035, 1 }, {0.035, 4 }, {0.035, 8 }, {0.035, 10 },
                 * {0.035, 16 }, {0.714, .5 }, {0.714, 1 }, {0.714, 4 }, {0.714, 8 },
                 * {0.714, 10 }, {0.714, 16 }, {0.9, .5 }, {0.9, 1 }, {0.9, 4 }, {0.9, 8
                 * }, {0.9, 10 }, {0.9, 16 },
                 *
                 * { 1.0/15.0, 1 }, { 1.0/27.0, 1 },
                 */

                /*
                 * {0.01, 1}, {0.01, 5}, {0.01, 10}, {0.03, 1}, {0.03, 5}, {0.03, 10},
                 * {0.044, 1}, {0.044, 5}, {0.044, 10}, {0.06, 1}, {0.06, 5}, {0.06,
                 * 10},
                 */
                { 0.03, 5 }, };

        // Each row is a {wPos, wNeg} pair handed to TuneModel.trainModel
        // (presumably the class weights of positive/negative examples).
        double[][] weightsToTest = new double[][] {
                /*
                 * Previously explored: { 3, 4 }
                 */
                { 3.8, 3 }, { 3.8, 4 }, { 3.8, 5 }, { 3.8, 6 }, { 3.8, 7 }, { 3.8, 8 }, { 3.8, 9 }, { 3.8, 10 },

        };

        // Each row is one set of feature ids to train a model on.
        Integer[][] featuresSetsToTest = new Integer[][] {
                // Previously explored (includes feature 8):
                // { 1, 2, 3, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25 },
                { 1, 2, 3, 6, 7, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25 },
                /*
                 * { 1, 2, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18},
                 */

        }; // <-- MIND THIS when changing the experiment
        int wikiSearchTopK = 10; // <-- MIND THIS
        String filePrefix = "_ANW"; // <-- MIND THIS (appended to the file base name)

        WikipediaToFreebase wikiToFreebase = new WikipediaToFreebase("mapdb");
        List<ModelConfigurationResult> mcrs = new ArrayList<>();

        // As written, this loop runs exactly once, with a threshold of 0.7.
        for (double editDistanceThr = 0.7; editDistanceThr <= 0.7; editDistanceThr += 0.7) {
            SmaphAnnotator bingAnnotator = GenerateTrainingAndTest.getDefaultBingAnnotator(wikiApi, wikiToFreebase,
                    editDistanceThr, wikiSearchTopK, bingKey);
            WATAnnotator.setCache("wikisense.cache");
            SmaphAnnotator.setCache(SmaphConfig.getDefaultBingCache());

            BinaryExampleGatherer trainEntityFilterGatherer = new BinaryExampleGatherer();
            BinaryExampleGatherer testEntityFilterGatherer = new BinaryExampleGatherer();
            GenerateTrainingAndTest.gatherExamplesTrainingAndDevel(bingAnnotator, trainEntityFilterGatherer,
                    testEntityFilterGatherer, wikiApi, wikiToFreebase, freebApi);

            SmaphAnnotator.unSetCache();

            // Switch point: choose here which gatherers feed training/testing.
            BinaryExampleGatherer trainGatherer = trainEntityFilterGatherer;
            BinaryExampleGatherer testGatherer = testEntityFilterGatherer;

            int totalModels = weightsToTest.length * featuresSetsToTest.length * paramsToTest.length;
            int count = 0;
            for (Integer[] ftrToTestArray : featuresSetsToTest) {
                // Alternative once used instead of the grid below:
                // gamma = 1.0 / ftrToTestArray.length, C = 1.
                for (double[] paramsToTestArray : paramsToTest) {
                    double gamma = paramsToTestArray[0];
                    double C = paramsToTestArray[1];
                    for (double[] weightsPosNeg : weightsToTest) {
                        double wPos = weightsPosNeg[0];
                        double wNeg = weightsPosNeg[1];
                        Vector<Integer> features = new Vector<>(Arrays.asList(ftrToTestArray));

                        // Left: scaled problem; middle/right: the two arrays
                        // later passed to dumpRanges and to the test scaling
                        // (named mins/maxs by the variable).
                        Triple<svm_problem, double[], double[]> ftrsMinsMaxs = TuneModel
                                .getScaledTrainProblem(features, trainGatherer);
                        svm_problem trainProblem = ftrsMinsMaxs.getLeft();

                        // To name the file with the EQF scheme instead, use
                        // getModelFileNameBaseEQF(features.toArray(new Integer[0]), wPos, wNeg).
                        String fileBase = getModelFileNameBaseEF(features.toArray(new Integer[0]), wPos, wNeg,
                                editDistanceThr, gamma, C) + filePrefix;

                        LibSvmUtils.dumpRanges(ftrsMinsMaxs.getMiddle(), ftrsMinsMaxs.getRight(),
                                fileBase + ".range");
                        svm_model model = TuneModel.trainModel(wPos, wNeg, features, trainProblem, gamma, C);
                        svm.svm_save_model(fileBase + ".model", model);

                        // The test examples are scaled with the ranges obtained
                        // from the training problem.
                        MetricsResultSet metrics = TuneModel.ParameterTester.computeMetrics(model,
                                TuneModel.getScaledTestProblems(features, testGatherer, ftrsMinsMaxs.getMiddle(),
                                        ftrsMinsMaxs.getRight()));

                        int tp = metrics.getGlobalTp();
                        int fp = metrics.getGlobalFp();
                        int fn = metrics.getGlobalFn();
                        float microF1 = metrics.getMicroF1();
                        float macroF1 = metrics.getMacroF1();
                        float macroRec = metrics.getMacroRecall();
                        float macroPrec = metrics.getMacroPrecision();
                        int totVects = testGatherer.getExamplesCount();
                        // Everything that is not a tp/fp/fn (presumably the
                        // true negatives).
                        int remaining = totVects - tp - fp - fn;
                        mcrs.add(new ModelConfigurationResult(features, wPos, wNeg, editDistanceThr, tp, fp, fn,
                                remaining, microF1, macroF1, macroRec, macroPrec));

                        System.err.printf("Trained %d/%d models.%n", ++count, totalModels);
                    }
                }
            }
        }

        // Report: macro precision/recall/F1 as percentages, the tested weight
        // pairs, the readable form of each result, the tested {gamma, C} pairs.
        for (ModelConfigurationResult mcr : mcrs) {
            System.out.printf("%.5f%%\t%.5f%%\t%.5f%%%n", mcr.getMacroPrecision() * 100, mcr.getMacroRecall() * 100,
                    mcr.getMacroF1() * 100);
        }
        for (double[] weightPosNeg : weightsToTest) {
            System.out.printf("%.5f\t%.5f%n", weightPosNeg[0], weightPosNeg[1]);
        }
        for (ModelConfigurationResult mcr : mcrs) {
            System.out.println(mcr.getReadable());
        }
        for (double[] paramGammaC : paramsToTest) {
            System.out.printf("%.5f\t%.5f%n", paramGammaC[0], paramGammaC[1]);
        }
        WATAnnotator.flush();
    }

    /**
     * Builds the base name (without extension) of the files of a model named
     * with the "EF" scheme:
     * {@code models/model_<features>_<wPos>_<wNeg>_<editDistance>_<gamma>_<C>}.
     *
     * <p>
     * Feature ids are sorted in ascending order and comma-separated. Numbers
     * are always formatted with {@code '.'} as decimal separator, regardless of
     * the default locale, so that the same configuration maps to the same file
     * name on every machine.
     *
     * @param ftrs
     *            the feature ids; the array is not modified.
     * @param wPos
     *            first weight, printed with 5 decimals.
     * @param wNeg
     *            second weight, printed with 5 decimals.
     * @param editDistance
     *            edit distance threshold, printed with 3 decimals.
     * @param gamma
     *            gamma parameter, printed with 8 decimals.
     * @param C
     *            C parameter, printed with 8 decimals.
     * @return the file name base.
     */
    public static String getModelFileNameBaseEF(Integer[] ftrs, double wPos, double wNeg, double editDistance,
            double gamma, double C) {
        return "models/model_" + joinSortedFeatures(ftrs)
                + String.format(Locale.US, "_%.5f_%.5f_%.3f_%.8f_%.8f", wPos, wNeg, editDistance, gamma, C);
    }

    /**
     * Builds the base name (without extension) of the files of a model named
     * with the "EQF" scheme: {@code models/EQ_model_<features>_<wPos>_<wNeg>}.
     *
     * <p>
     * Feature ids are sorted in ascending order and comma-separated. Numbers
     * are always formatted with {@code '.'} as decimal separator, regardless of
     * the default locale.
     *
     * @param ftrs
     *            the feature ids; the array is not modified.
     * @param wPos
     *            first weight, printed with 5 decimals.
     * @param wNeg
     *            second weight, printed with 5 decimals.
     * @return the file name base.
     */
    public static String getModelFileNameBaseEQF(Integer[] ftrs, double wPos, double wNeg) {
        return "models/EQ_model_" + joinSortedFeatures(ftrs)
                + String.format(Locale.US, "_%.5f_%.5f", wPos, wNeg);
    }

    /**
     * Returns the given feature ids sorted in ascending order and joined by
     * commas. The separator is placed by position, so repeated ids (including
     * a repeated largest id) are still separated from each other.
     */
    private static String joinSortedFeatures(Integer[] ftrs) {
        Integer[] sorted = ftrs.clone(); // leave the caller's array untouched
        Arrays.sort(sorted);
        StringBuilder joined = new StringBuilder();
        for (int i = 0; i < sorted.length; i++) {
            if (i > 0) {
                joined.append(',');
            }
            joined.append(sorted[i]);
        }
        return joined.toString();
    }
}