com.davidbracewell.ml.classification.linear.LibLinearModel.java Source code

Java tutorial

Introduction

Here is the source code for com.davidbracewell.ml.classification.linear.LibLinearModel.java

Source

/*
 * (c) 2005 David B. Bracewell
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.davidbracewell.ml.classification.linear;

import com.davidbracewell.math.linear.VectorMath;
import com.davidbracewell.ml.Instance;
import com.davidbracewell.ml.classification.ClassificationResult;
import com.google.common.base.Preconditions;
import de.bwaldvogel.liblinear.Feature;
import de.bwaldvogel.liblinear.FeatureNode;
import de.bwaldvogel.liblinear.Linear;
import de.bwaldvogel.liblinear.Model;

import java.text.DecimalFormat;

/**
 * @author David B. Bracewell
 */
public class LibLinearModel extends LinearModel {

    private static final long serialVersionUID = -8604575597819794701L;

    /** The wrapped liblinear model; {@code null} means untrained (see {@link #isTrained()}). */
    Model model;

    /**
     * Converts an instance into liblinear's sparse feature array.
     *
     * <p>One {@link FeatureNode} is emitted per index returned by
     * {@link VectorMath#nonZeroIndexes}, in that order. Each zero-based index {@code i} is written
     * as {@code i + 1} because liblinear numbers features starting at 1.
     * NOTE(review): liblinear expects ascending indexes; presumably {@code nonZeroIndexes} returns
     * them sorted — TODO confirm.
     *
     * @param vector the instance to convert; must not be {@code null}
     * @return the instance's non-zero entries as liblinear features
     * @throws NullPointerException if {@code vector} is {@code null}
     */
    public static Feature[] toFeature(Instance vector) {
        int[] keys = VectorMath.nonZeroIndexes(Preconditions.checkNotNull(vector));
        Feature[] feature = new Feature[keys.length];
        for (int i = 0; i < keys.length; i++) {
            int fi = keys[i];
            feature[i] = new FeatureNode(fi + 1, vector.get(fi));
        }
        return feature;
    }

    /**
     * Scores an instance with the wrapped liblinear model.
     *
     * <p>Models that support probabilities are scored with {@link Linear#predictProbability}. All
     * other models are scored with {@link Linear#predictValues}. liblinear reports the scores in the
     * order of {@link Model#getLabels()}. They are therefore moved to the slot of their label so
     * the returned array lines up with the target feature's alphabet.
     *
     * @param instance the instance to classify
     * @return the per-class scores, indexed by target-feature alphabet position
     */
    @Override
    protected ClassificationResult classifyImpl(Instance instance) {
        int alphabetSize = getTargetFeature().alphabetSize();
        // Convert once; both prediction paths consume the same sparse vector.
        Feature[] features = toFeature(instance);

        double[] p = new double[alphabetSize];
        if (model.isProbabilityModel()) {
            Linear.predictProbability(model, features, p);
        } else {
            Linear.predictValues(model, features, p);
        }

        // Re-arrange the scores from liblinear's label order into alphabet order.
        double[] prime = new double[alphabetSize];
        int[] labels = model.getLabels();
        for (int i = 0; i < labels.length; i++) {
            prime[labels[i]] = p[i];
        }

        return new ClassificationResult(getTargetFeature(), prime);
    }

    /**
     * Reports whether a liblinear model has been assigned.
     *
     * @return {@code true} if {@link #model} is non-null
     */
    @Override
    public boolean isTrained() {
        return model != null;
    }

    /**
     * Prints the learned weights to standard output as a tab-separated table.
     *
     * <p>The header row names the class label of each weight column. Each following row holds one
     * feature's weights. A final {@code BIAS} row is printed when the model has a bias term
     * ({@code model.getBias() >= 0}).
     *
     * @throws IllegalStateException if the model has not been trained
     */
    public void printParams() {
        Preconditions.checkState(isTrained(), "Model has not been trained");

        double[] weights = model.getFeatureWeights();
        int[] labels = model.getLabels();
        int nrFeature = model.getNrFeature();
        boolean hasBias = model.getBias() >= 0;
        int dimension = hasBias ? nrFeature + 1 : nrFeature;

        // Weights are stored feature-major: weights[feature * nrW + column]. liblinear keeps a
        // single column for ordinary two-class models but one column per class for multiclass
        // models, so derive the column count from the array instead of assuming nrClass - 1.
        int nrW = dimension == 0 ? 0 : weights.length / dimension;

        DecimalFormat format = new DecimalFormat("+0.00000;-#");
        com.davidbracewell.ml.Feature target = getFeatures().getTargetFeature();
        int headerColumns = Math.min(nrW, labels.length);
        for (int j = 0; j < headerColumns; j++) {
            System.out.print("\t" + target.valueAtIndex(labels[j]));
        }
        System.out.println();

        for (int i = 0; i < nrFeature; i++) {
            com.davidbracewell.ml.Feature f = getFeatures().get(i);
            System.out.print(f.getName());
            printWeightRow(weights, i * nrW, nrW, format);
        }
        if (hasBias) {
            // The bias weights occupy the row after the last real feature.
            System.out.print("BIAS");
            printWeightRow(weights, nrFeature * nrW, nrW, format);
        }
    }

    /**
     * Prints {@code count} weights starting at {@code offset}, each preceded by a tab, then ends
     * the line.
     */
    private static void printWeightRow(double[] weights, int offset, int count, DecimalFormat format) {
        for (int j = 0; j < count; j++) {
            System.out.print("\t" + format.format(weights[offset + j]));
        }
        System.out.println();
    }

}//END OF LibLinearModel