org.apache.mahout.classifier.chi_rw.data.Dataset.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.mahout.classifier.chi_rw.data.Dataset.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.classifier.chi_rw.data;

import com.google.common.base.Preconditions;
import com.google.common.io.Closeables;
import com.google.common.primitives.Doubles;

import org.apache.commons.lang.ArrayUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.WritableUtils;
import org.apache.mahout.classifier.chi_rw.Chi_RWUtils;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Contains information about the attributes of a dataset: their types, the distinct values of the
 * CATEGORICAL attributes, the values of the NUMERICAL attributes and a {min, max} range for every
 * non-IGNORED attribute.
 */
public class Dataset implements Writable {

    /**
     * Attributes type
     */
    public enum Attribute {
        IGNORED, NUMERICAL, CATEGORICAL, LABEL;

        public boolean isNumerical() {
            return this == NUMERICAL;
        }

        public boolean isCategorical() {
            return this == CATEGORICAL;
        }

        public boolean isLabel() {
            return this == LABEL;
        }

        public boolean isIgnored() {
            return this == IGNORED;
        }
    }

    /**
     * type of every non-IGNORED attribute. The LABEL is stored here as CATEGORICAL (classification)
     * or NUMERICAL (regression)
     */
    private Attribute[] attributes;

    /**
     * original column indices of the IGNORED attributes
     */
    private int[] ignored;

    /**
     * distinct values (CATEGORICAL attributes only, null for the other attributes)
     */
    private String[][] values;

    /**
     * values recorded for each NUMERICAL attribute (null for the other attributes)
     */
    private double[][] nvalues;

    /**
     * {min, max} of every attribute. For CATEGORICAL attributes this is {0, number of distinct values - 1}
     */
    private double[][] minmaxvalues;

    /**
     * index of the label attribute in the loaded data (without ignored attributes)
     */
    private int labelId;

    /**
     * number of instances in the dataset
     */
    private int nbInstances;

    private Dataset() {
    }

    /**
     * Should only be called by a DataLoader.
     * <p>
     * Note: the LABEL entry of {@code attrs} is overwritten in place with NUMERICAL (regression) or
     * CATEGORICAL (classification).
     *
     * @param attrs       attributes description, one entry per column of the raw data
     * @param values      distinct values for all CATEGORICAL attributes (and for the label when classifying)
     * @param nvalues     values for all NUMERICAL attributes (and for the label when regressing)
     * @param nbInstances number of instances in the dataset
     * @param regression  true if the label is numerical, false if it is categorical
     * @throws IllegalArgumentException if a value list required by an attribute is missing
     * @throws IllegalStateException    if there is no LABEL attribute, or more than one
     */
    Dataset(Attribute[] attrs, List<String>[] values, ArrayList<Double>[] nvalues, int nbInstances,
            boolean regression) {
        validateValues(attrs, values, nvalues, regression);

        int nbattrs = countAttributes(attrs);

        attributes = new Attribute[nbattrs];
        this.values = new String[nbattrs][];
        this.nvalues = new double[nbattrs][];
        this.minmaxvalues = new double[nbattrs][2];
        ignored = new int[attrs.length - nbattrs]; // nbignored = total - nbattrs

        labelId = -1;
        int ignoredId = 0;
        int ind = 0;
        for (int attr = 0; attr < attrs.length; attr++) {
            if (attrs[attr].isIgnored()) {
                ignored[ignoredId++] = attr;
                continue;
            }

            if (attrs[attr].isLabel()) {
                if (labelId != -1) {
                    throw new IllegalStateException("Label found more than once");
                }
                labelId = ind;
                // from here on the label is handled like any other attribute of its type
                attrs[attr] = regression ? Attribute.NUMERICAL : Attribute.CATEGORICAL;
            }

            if (attrs[attr].isCategorical()) {
                this.values[ind] = values[attr].toArray(new String[values[attr].size()]);
                // categorical values are coded 0..n-1, hence this range
                this.minmaxvalues[ind][0] = 0;
                this.minmaxvalues[ind][1] = values[attr].size() - 1;
            } else if (attrs[attr].isNumerical()) {
                this.nvalues[ind] = Doubles.toArray(nvalues[attr]);
                this.minmaxvalues[ind][0] = getMinAttribute(this.nvalues[ind]);
                this.minmaxvalues[ind][1] = getMaxAttribute(this.nvalues[ind]);
            }

            attributes[ind++] = attrs[attr];
        }

        if (labelId == -1) {
            throw new IllegalStateException("Label not found");
        }

        this.nbInstances = nbInstances;
    }

    /**
     * Returns the smallest element of {@code values}.
     *
     * @param values a non-empty array
     * @throws IllegalArgumentException if {@code values} is empty
     */
    public double getMinAttribute(double[] values) {
        Preconditions.checkArgument(values.length > 0, "cannot compute the minimum of an empty array");
        double min = values[0];
        for (int i = 1; i < values.length; i++) {
            if (values[i] < min) {
                min = values[i];
            }
        }
        return min;
    }

    /**
     * Returns the largest element of {@code values}.
     *
     * @param values a non-empty array
     * @throws IllegalArgumentException if {@code values} is empty
     */
    public double getMaxAttribute(double[] values) {
        Preconditions.checkArgument(values.length > 0, "cannot compute the maximum of an empty array");
        double max = values[0];
        for (int i = 1; i < values.length; i++) {
            if (values[i] > max) {
                max = values[i];
            }
        }
        return max;
    }

    /**
     * @return the {min, max} table indexed by attribute. This is the internal array, not a copy
     */
    public double[][] getRanges() {
        return minmaxvalues;
    }

    /**
     * @return number of distinct values of a CATEGORICAL attribute
     */
    public int nbValues(int attr) {
        return values[attr].length;
    }

    /**
     * @return distinct values of a CATEGORICAL attribute (internal array, null for other attributes)
     */
    public String[] getValues(int attr) {
        return values[attr];
    }

    /**
     * @return values of a NUMERICAL attribute (internal array, null for other attributes)
     */
    public double[] getNValues(int attr) {
        return nvalues[attr];
    }

    /**
     * @return a copy of the distinct label values. Only meaningful for classification: in regression
     *         the label has no categorical values and this throws a NullPointerException
     */
    public String[] labels() {
        return Arrays.copyOf(values[labelId], nblabels());
    }

    /**
     * @return number of distinct label values (classification only)
     */
    public int nblabels() {
        return values[labelId].length;
    }

    public int getLabelId() {
        return labelId;
    }

    /**
     * @return the value stored in {@code instance} at the label's attribute index
     */
    public double getLabel(Instance instance) {
        return instance.get(getLabelId());
    }

    public int nbInstances() {
        return nbInstances;
    }

    /**
     * Returns the code used to represent the label value in the data
     *
     * @param label label's value to code
     * @return label's code, or -1 if the value is unknown
     */
    public int labelCode(String label) {
        return ArrayUtils.indexOf(values[labelId], label);
    }

    /**
     * Returns the label value in the data.
     * This method can be used when the criterion variable is a categorical attribute.
     *
     * @param code label's code
     * @return label's value, or "unknown" when {@code code} is NaN
     */
    public String getLabelString(double code) {
        // handle the case (prediction is NaN)
        if (Double.isNaN(code)) {
            return "unknown";
        }
        return values[labelId][(int) code];
    }

    /**
     * Converts a token to its corresponding int code for a given attribute
     *
     * @param attr attribute's index (must be CATEGORICAL)
     * @return token's code, or -1 if the token is unknown
     */
    public int valueOf(int attr, String token) {
        Preconditions.checkArgument(!isNumerical(attr), "Only for CATEGORICAL attributes");
        // check the attribute's own row: the outer array is always allocated
        Preconditions.checkArgument(values != null && values[attr] != null, "Values not found");
        return ArrayUtils.indexOf(values[attr], token);
    }

    /**
     * @return original column indices of the IGNORED attributes (internal array, not a copy)
     */
    public int[] getIgnored() {
        return ignored;
    }

    /**
     * @return number of attributes that are not IGNORED
     */
    private static int countAttributes(Attribute[] attrs) {
        int nbattrs = 0;

        for (Attribute attr : attrs) {
            if (!attr.isIgnored()) {
                nbattrs++;
            }
        }

        return nbattrs;
    }

    /**
     * Checks that every attribute has the value list the constructor will read: {@code values} for
     * CATEGORICAL attributes (and the label when classifying), {@code nvalues} for NUMERICAL attributes
     * (and the label when regressing).
     */
    private static void validateValues(Attribute[] attrs, List<String>[] values, ArrayList<Double>[] nvalues,
                                       boolean regression) {
        Preconditions.checkArgument(attrs.length == values.length, "attrs.length != values.length");
        for (int attr = 0; attr < attrs.length; attr++) {
            boolean label = attrs[attr].isLabel();
            boolean categorical = attrs[attr].isCategorical() || (label && !regression);
            boolean numerical = attrs[attr].isNumerical() || (label && regression);
            Preconditions.checkArgument(!categorical || values[attr] != null,
                    "values not found for attribute " + attr);
            Preconditions.checkArgument(!numerical || nvalues[attr] != null,
                    "numerical values not found for attribute " + attr);
        }
    }

    /**
     * @return number of attributes
     */
    public int nbAttributes() {
        return attributes.length;
    }

    /**
     * Is this a numerical attribute ?
     *
     * @param attr index of the attribute to check
     * @return true if the attribute is numerical
     */
    public boolean isNumerical(int attr) {
        return attributes[attr].isNumerical();
    }

    /**
     * Two datasets are equal when they have the same attribute types, value tables, ranges, label
     * index and number of instances.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof Dataset)) {
            return false;
        }

        Dataset dataset = (Dataset) obj;

        return labelId == dataset.labelId
                && nbInstances == dataset.nbInstances
                && Arrays.equals(attributes, dataset.attributes)
                && Arrays.deepEquals(values, dataset.values)
                && Arrays.deepEquals(nvalues, dataset.nvalues)
                && Arrays.deepEquals(minmaxvalues, dataset.minmaxvalues);
    }

    @Override
    public int hashCode() {
        // uses exactly the fields compared by equals()
        int hashCode = labelId + 31 * nbInstances;
        hashCode = 31 * hashCode + Arrays.hashCode(attributes);
        hashCode = 31 * hashCode + Arrays.deepHashCode(values);
        hashCode = 31 * hashCode + Arrays.deepHashCode(nvalues);
        hashCode = 31 * hashCode + Arrays.deepHashCode(minmaxvalues);
        return hashCode;
    }

    /**
     * Loads the dataset from a file
     *
     * @throws java.io.IOException if the file cannot be opened or read
     */
    public static Dataset load(Configuration conf, Path path) throws IOException {
        FileSystem fs = path.getFileSystem(conf);
        FSDataInputStream input = fs.open(path);
        try {
            return read(input);
        } finally {
            Closeables.closeQuietly(input);
        }
    }

    public static Dataset read(DataInput in) throws IOException {
        Dataset dataset = new Dataset();

        dataset.readFields(in);
        return dataset;
    }

    /**
     * Reads the fields in the exact order produced by {@link #write(DataOutput)}.
     */
    @Override
    public void readFields(DataInput in) throws IOException {
        int nbAttributes = in.readInt();
        attributes = new Attribute[nbAttributes];
        for (int attr = 0; attr < nbAttributes; attr++) {
            String name = WritableUtils.readString(in);
            attributes[attr] = Attribute.valueOf(name);
        }

        ignored = Chi_RWUtils.readIntArray(in);

        // only CATEGORICAL attributes have values
        values = new String[nbAttributes][];
        for (int attr = 0; attr < nbAttributes; attr++) {
            if (attributes[attr].isCategorical()) {
                values[attr] = WritableUtils.readStringArray(in);
            }
        }

        // only NUMERICAL attributes have values
        nvalues = new double[nbAttributes][];
        for (int attr = 0; attr < nbAttributes; attr++) {
            if (attributes[attr].isNumerical()) {
                nvalues[attr] = Chi_RWUtils.readDoubleArray(in);
            }
        }

        // every attribute has a range
        minmaxvalues = new double[nbAttributes][];
        for (int attr = 0; attr < nbAttributes; attr++) {
            minmaxvalues[attr] = Chi_RWUtils.readDoubleArray(in);
        }

        labelId = in.readInt();
        nbInstances = in.readInt();
    }

    /**
     * Writes the dataset. Each section is driven by the attribute types, mirroring
     * {@link #readFields(DataInput)}, so the stream always has the layout the reader expects.
     */
    @Override
    public void write(DataOutput out) throws IOException {
        out.writeInt(attributes.length); // nb attributes
        for (Attribute attr : attributes) {
            WritableUtils.writeString(out, attr.name());
        }

        Chi_RWUtils.writeArray(out, ignored);

        // only CATEGORICAL attributes have values
        for (int attr = 0; attr < attributes.length; attr++) {
            if (attributes[attr].isCategorical()) {
                WritableUtils.writeStringArray(out, values[attr]);
            }
        }

        // only NUMERICAL attributes have values
        for (int attr = 0; attr < attributes.length; attr++) {
            if (attributes[attr].isNumerical()) {
                Chi_RWUtils.writeArray(out, nvalues[attr]);
            }
        }

        // every attribute has a range
        for (double[] range : minmaxvalues) {
            Chi_RWUtils.writeArray(out, range);
        }

        out.writeInt(labelId);
        out.writeInt(nbInstances);
    }

}