com.joliciel.talismane.machineLearning.linearsvm.LinearSVMOneVsRestModel.java Source code

Java tutorial

Introduction

Here is the source code for com.joliciel.talismane.machineLearning.linearsvm.LinearSVMOneVsRestModel.java

Source

///////////////////////////////////////////////////////////////////////////////
//Copyright (C) 2014 Joliciel Informatique
//
//This file is part of Talismane.
//
//Talismane is free software: you can redistribute it and/or modify
//it under the terms of the GNU Affero General Public License as published by
//the Free Software Foundation, either version 3 of the License, or
//(at your option) any later version.
//
//Talismane is distributed in the hope that it will be useful,
//but WITHOUT ANY WARRANTY; without even the implied warranty of
//MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
//GNU Affero General Public License for more details.
//
//You should have received a copy of the GNU Affero General Public License
//along with Talismane.  If not, see <http://www.gnu.org/licenses/>.
//////////////////////////////////////////////////////////////////////////////
package com.joliciel.talismane.machineLearning.linearsvm;

import gnu.trove.map.TObjectIntMap;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.Reader;
import java.io.UnsupportedEncodingException;
import java.io.Writer;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;
import java.util.zip.ZipOutputStream;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import com.joliciel.talismane.machineLearning.AbstractClassificationModel;
import com.joliciel.talismane.machineLearning.ClassificationObserver;
import com.joliciel.talismane.machineLearning.DecisionFactory;
import com.joliciel.talismane.machineLearning.DecisionMaker;
import com.joliciel.talismane.machineLearning.MachineLearningAlgorithm;
import com.joliciel.talismane.machineLearning.Outcome;
import com.joliciel.talismane.utils.JolicielException;
import com.joliciel.talismane.utils.LogUtils;
import com.joliciel.talismane.utils.io.UnclosableWriter;

import de.bwaldvogel.liblinear.Model;

/**
 * A one-vs-rest linear SVM classification model.
 * <p>
 * Holds the list of liblinear {@link Model}s added via {@link #addModel(Model)},
 * together with the feature index map and the outcome list that are handed to the
 * {@link LinearSVMOneVsRestDecisionMaker} built by {@link #getDecisionMaker()}.
 * </p>
 *
 * @param <T> the outcome type handled by this model
 */
class LinearSVMOneVsRestModel<T extends Outcome> extends AbstractClassificationModel<T> {
    private static final Log LOG = LogFactory.getLog(LinearSVMOneVsRestModel.class);
    private List<Model> models = new ArrayList<Model>();
    private TObjectIntMap<String> featureIndexMap = null;
    private List<String> outcomes = null;
    // Lazily built sorted view of outcomes; reset to null whenever outcomes is replaced.
    private transient Set<String> outcomeNames = null;

    /**
     * Default constructor for factory.
     */
    LinearSVMOneVsRestModel() {
    }

    /**
     * Construct a model from the settings used for training. The individual
     * {@link Model}s are added afterwards via {@link #addModel(Model)}.
     *
     * @param descriptors the descriptors to store on this model
     * @param decisionFactory the decision factory to store on this model
     * @param trainingParameters the training parameters to store on this model
     */
    LinearSVMOneVsRestModel(Map<String, List<String>> descriptors, DecisionFactory<T> decisionFactory,
            Map<String, Object> trainingParameters) {
        super();
        this.setDescriptors(descriptors);
        this.setDecisionFactory(decisionFactory);
        this.setTrainingParameters(trainingParameters);
    }

    /**
     * Append a model to the end of this model's list.
     *
     * @param model the model to add
     */
    public void addModel(Model model) {
        this.models.add(model);
    }

    /**
     * Build a decision maker over the current models, feature index map and outcomes,
     * using this model's decision factory.
     *
     * @return a new decision maker
     */
    @Override
    public DecisionMaker<T> getDecisionMaker() {
        LinearSVMOneVsRestDecisionMaker<T> decisionMaker = new LinearSVMOneVsRestDecisionMaker<T>(models,
                this.featureIndexMap, this.outcomes);
        decisionMaker.setDecisionFactory(this.getDecisionFactory());
        return decisionMaker;
    }

    /**
     * Not supported for this algorithm.
     *
     * @throws JolicielException always
     */
    @Override
    public ClassificationObserver<T> getDetailedAnalysisObserver(File file) {
        throw new JolicielException("No detailed analysis observer currently available for linear SVM.");
    }

    /**
     * Write every model to the stream as a zip archive, one entry per model, named
     * <code>model0</code>, <code>model1</code>, ... in list order.
     * <p>
     * The zip stream wrapping <code>outputStream</code> is deliberately not closed
     * here, since that would also close the stream supplied by the caller.
     * </p>
     *
     * @param outputStream the stream to write to
     * @throws RuntimeException wrapping any {@link IOException} raised while writing
     */
    @Override
    public void writeModelToStream(OutputStream outputStream) {
        try {
            ZipOutputStream zos = new ZipOutputStream(outputStream);
            // STORED has the value 0, which as a compression level means "no compression".
            zos.setLevel(ZipOutputStream.STORED);
            int i = 0;
            for (Model model : models) {
                if (LOG.isDebugEnabled()) {
                    // The outcome is only used for logging: never let a missing or short
                    // outcome list prevent the model from being written.
                    String outcome = (outcomes != null && i < outcomes.size()) ? outcomes.get(i) : "<unknown>";
                    LOG.debug("Writing model " + i + " for outcome " + outcome);
                }
                zos.putNextEntry(new ZipEntry("model" + i));
                Writer writer = new OutputStreamWriter(zos, "UTF-8");
                // Shield the zip stream from any close() performed while saving the model.
                Writer unclosableWriter = new UnclosableWriter(writer);
                model.save(unclosableWriter);
                // Push any characters still buffered in the encoder into the current
                // entry before it is closed.
                writer.flush();
                zos.closeEntry();
                zos.flush();
                i++;
            }
        } catch (UnsupportedEncodingException e) {
            LogUtils.logError(LOG, e);
            throw new RuntimeException(e);
        } catch (IOException e) {
            LogUtils.logError(LOG, e);
            throw new RuntimeException(e);
        }
    }

    /**
     * Replace the current models with the ones read from a zip archive in the stream,
     * loading one model per zip entry, in entry order.
     *
     * @param inputStream the stream to read from
     * @throws RuntimeException wrapping any {@link IOException} raised while reading
     */
    @Override
    public void loadModelFromStream(InputStream inputStream) {
        try {
            models = new ArrayList<Model>();
            ZipInputStream zis = new ZipInputStream(inputStream);
            ZipEntry zipEntry = null;
            while ((zipEntry = zis.getNextEntry()) != null) {
                LOG.debug("Reading " + zipEntry.getName());
                // NOTE(review): the reader is not closed here, as closing it would close zis.
                // If Model.load itself closes the reader, the next getNextEntry() would fail
                // - confirm against the liblinear version in use.
                Reader reader = new InputStreamReader(zis, "UTF-8");
                Model model = Model.load(reader);
                models.add(model);
            }
        } catch (UnsupportedEncodingException e) {
            LogUtils.logError(LOG, e);
            throw new RuntimeException(e);
        } catch (IOException e) {
            LogUtils.logError(LOG, e);
            throw new RuntimeException(e);
        }
    }

    @Override
    public MachineLearningAlgorithm getAlgorithm() {
        return MachineLearningAlgorithm.LinearSVMOneVsRest;
    }

    /**
     * A map of feature names to unique indexes.
     *
     * @return the feature index map, or null if none has been set or loaded
     */
    public TObjectIntMap<String> getFeatureIndexMap() {
        return featureIndexMap;
    }

    public void setFeatureIndexMap(TObjectIntMap<String> featureIndexMap) {
        this.featureIndexMap = featureIndexMap;
    }

    /**
     * A list of outcomes, where the indexes are the ones used by the binary model.
     *
     * @return the outcome list, or null if none has been set or loaded
     */
    public List<String> getOutcomes() {
        return outcomes;
    }

    /**
     * Replace the outcome list and discard the cached outcome names, so that
     * {@link #getOutcomeNames()} reflects the new list.
     *
     * @param outcomes the new outcome list
     */
    public void setOutcomes(List<String> outcomes) {
        this.outcomes = outcomes;
        this.outcomeNames = null;
    }

    /**
     * Read the feature index map or the outcome list from the stream, depending on the
     * name of the zip entry.
     *
     * @param inputStream the stream positioned at the entry's content
     * @param zipEntry the entry being read
     * @return true if the entry was recognised and loaded, false otherwise
     * @throws RuntimeException wrapping any {@link IOException} or
     *         {@link ClassNotFoundException} raised while reading
     */
    @SuppressWarnings("unchecked")
    @Override
    protected boolean loadDataFromStream(InputStream inputStream, ZipEntry zipEntry) {
        try {
            // NOTE(review): native Java deserialization - only safe when the model file
            // comes from a trusted source.
            // The ObjectInputStreams are not closed: closing them would close inputStream,
            // which this method did not open.
            boolean loaded = true;
            if (zipEntry.getName().equals("featureIndexMap.obj")) {
                ObjectInputStream in = new ObjectInputStream(inputStream);
                featureIndexMap = (TObjectIntMap<String>) in.readObject();
            } else if (zipEntry.getName().equals("outcomes.obj")) {
                ObjectInputStream in = new ObjectInputStream(inputStream);
                outcomes = (List<String>) in.readObject();
                // The outcome list has changed: drop any names cached from a previous list.
                outcomeNames = null;
            } else {
                loaded = false;
            }
            return loaded;
        } catch (ClassNotFoundException e) {
            LogUtils.logError(LOG, e);
            throw new RuntimeException(e);
        } catch (IOException e) {
            LogUtils.logError(LOG, e);
            throw new RuntimeException(e);
        }
    }

    /**
     * Write the feature index map and the outcome list to the zip stream as two
     * serialized entries, <code>featureIndexMap.obj</code> and <code>outcomes.obj</code>.
     *
     * @param zos the zip stream to write to; it is flushed but not closed
     * @throws RuntimeException wrapping any {@link IOException} raised while writing
     */
    @Override
    public void writeDataToStream(ZipOutputStream zos) {
        try {
            zos.putNextEntry(new ZipEntry("featureIndexMap.obj"));
            ObjectOutputStream out = new ObjectOutputStream(zos);

            try {
                out.writeObject(featureIndexMap);
            } finally {
                // Flush rather than close: closing would also close zos.
                out.flush();
            }

            zos.flush();

            // putNextEntry closes the previous entry before starting the new one.
            zos.putNextEntry(new ZipEntry("outcomes.obj"));
            out = new ObjectOutputStream(zos);
            try {
                out.writeObject(outcomes);
            } finally {
                out.flush();
            }

            zos.flush();
        } catch (IOException e) {
            LogUtils.logError(LOG, e);
            throw new RuntimeException(e);
        }

    }

    /**
     * The outcomes as a sorted set, built on first use and cached until the outcome
     * list is replaced.
     *
     * @return the sorted outcome names; an empty set if no outcomes have been set or
     *         loaded yet
     */
    @Override
    public Set<String> getOutcomeNames() {
        if (this.outcomeNames == null) {
            if (this.outcomes == null) {
                // Nothing to report yet; not cached, so that outcomes set or loaded
                // later are picked up.
                return new TreeSet<String>();
            }
            this.outcomeNames = new TreeSet<String>(this.outcomes);
        }
        return this.outcomeNames;
    }

    /**
     * The models held by this one-vs-rest model.
     *
     * @return the internal model list (not a copy)
     */
    public List<Model> getModels() {
        return models;
    }

}