de.isabeldrostfromm.sof.Trainer.java Source code

Java tutorial

Introduction

Here is the source code for de.isabeldrostfromm.sof.Trainer.java

Source

/**
 * Copyright (C) 2013 Isabel Drost-Fromm
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package de.isabeldrostfromm.sof;

import java.io.DataOutput;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.math.Vector;
import org.apache.mahout.vectorizer.encoders.LuceneTextValueEncoder;

import com.google.common.collect.HashMultiset;
import com.google.common.collect.Multiset;

import de.isabeldrostfromm.sof.naive.Vectoriser;

/**
 * {@link ModelTrainer} implementation backed by Mahout's {@link OnlineLogisticRegression}.
 *
 * <p>The trainer itself does not vectorise anything: every {@link Example} handed in by an
 * {@link ExampleProvider} must already carry its feature vector. The model is sized with
 * {@link Vectoriser#getCardinality()}, so presumably the vectors are produced by that
 * vectoriser (the original documentation names {@link LuceneTextValueEncoder} as the text
 * encoder used for this) - verify against the provider in use.
 *
 * TODO fix logging
 */
public class Trainer implements ModelTrainer {

    /** Fixed location the trained model is serialised to by {@link #store(OnlineLogisticRegression)}. */
    private static final String MODEL_PATH = "/tmp/softrainer.model";

    /**
     * Trains a fresh logistic regression model with one pass over the given examples.
     *
     * <p>The model has one category per entry of {@code ModelTargets.STATEVALUES}, uses
     * {@code Vectoriser.getCardinality()} features and an {@link L1} prior.
     *
     * @param provider source of training examples; each example's state is translated to
     *                 its numeric target through {@code ModelTargets.STATES}
     * @return the trained model
     */
    @Override
    public OnlineLogisticRegression train(ExampleProvider provider) {
        OnlineLogisticRegression logReg = new OnlineLogisticRegression(ModelTargets.STATEVALUES.length,
                Vectoriser.getCardinality(), new L1());

        for (Example instance : provider) {
            logReg.train(ModelTargets.STATES.get(instance.getState()), instance.getVector());
        }

        return logReg;
    }

    /**
     * Classifies every example of the provider and returns the most likely state for each.
     *
     * @param model    trained model used for classification
     * @param provider source of the examples to label
     * @return one entry per example, in iteration order of the provider, holding the value
     *         {@code ModelTargets.INDECES} maps the highest scoring category to
     */
    @Override
    public List<String> apply(OnlineLogisticRegression model, ExampleProvider provider) {
        List<String> result = new ArrayList<String>();
        for (Example instance : provider) {
            // classifyFull yields one score per category, so an element's index is the
            // category id used during training. classify() would leave out category 0 and
            // shift every index by one.
            Vector scores = model.classifyFull(instance.getVector());
            double max = -1;
            int maxIndex = -1;
            for (Vector.Element element : scores) {
                if (element.get() > max) {
                    max = element.get();
                    maxIndex = element.index();
                }
            }
            result.add(ModelTargets.INDECES.get(maxIndex));
        }
        return result;
    }

    /**
     * Serialises the model to {@code /tmp/softrainer.model}, replacing any existing file.
     *
     * @param model the model to write
     * @throws IOException if the file cannot be opened, written or closed
     */
    @Override
    public void store(OnlineLogisticRegression model) throws IOException {
        File outFile = new File(MODEL_PATH);
        DataOutputStream stream = new DataOutputStream(new FileOutputStream(outFile));
        try {
            model.write(stream);
            stream.flush();
        } finally {
            // Release the file handle even when writing the model fails.
            stream.close();
        }
    }

}