com.cloudera.oryx.kmeans.common.pmml.KMeansPMML.java Source code

Java tutorial

Introduction

Here is the source code for com.cloudera.oryx.kmeans.common.pmml.KMeansPMML.java

Source

/*
 * Copyright (c) 2013, Cloudera, Inc. All Rights Reserved.
 *
 * Cloudera, Inc. licenses this file to you under the Apache License,
 * Version 2.0 (the "License"). You may not use this file except in
 * compliance with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
 * CONDITIONS OF ANY KIND, either express or implied. See the License for
 * the specific language governing permissions and limitations under the
 * License.
 */

package com.cloudera.oryx.kmeans.common.pmml;

import com.cloudera.oryx.common.math.Vectors;
import com.cloudera.oryx.kmeans.common.Centers;
import com.google.common.base.Splitter;
import com.google.common.collect.Lists;
import org.apache.commons.math3.linear.RealVector;
import org.dmg.pmml.Array;
import org.dmg.pmml.Cluster;
import org.dmg.pmml.ClusteringModel;
import org.dmg.pmml.DataDictionary;
import org.dmg.pmml.IOUtil;
import org.dmg.pmml.Model;
import org.dmg.pmml.PMML;
import org.xml.sax.SAXException;

import javax.xml.bind.JAXBException;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.List;

import com.cloudera.oryx.common.io.IOUtils;

/**
 * Convenience methods to read/write a {@link ClusteringModel} description of a k-means clustering.
 */
public final class KMeansPMML {

    // PMML always delimits with space
    private static final Splitter SPACE = Splitter.on(' ');

    private KMeansPMML() {
    }

    public static PMML read(File f) throws IOException {
        InputStream in = IOUtils.openMaybeDecompressing(f);
        try {
            return read(in);
        } catch (JAXBException jaxbe) {
            throw new IOException(jaxbe);
        } catch (SAXException saxe) {
            throw new IOException(saxe);
        } finally {
            in.close();
        }
    }

    public static PMML read(InputStream in) throws JAXBException, SAXException {
        return IOUtil.unmarshal(in);
    }

    public static void write(File f, DataDictionary dictionary, List<? extends Model> models) throws IOException {
        OutputStream out = IOUtils.buildGZIPOutputStream(new FileOutputStream(f));
        try {
            write(out, dictionary, models);
        } catch (JAXBException jaxbe) {
            throw new IOException(jaxbe);
        } finally {
            out.close();
        }
    }

    public static Centers toCenters(ClusteringModel cm) {
        int dims = cm.getClusteringFields().size();
        boolean sparse = cm.getMiningSchema().getMiningFields().size() * 2 < dims;
        List<RealVector> vecs = Lists.newArrayListWithExpectedSize(cm.getClusters().size());
        for (Cluster c : cm.getClusters()) {
            vecs.add(createCenter(c.getArray(), sparse, dims));
        }
        return new Centers(vecs);
    }

    private static RealVector createCenter(Array array, boolean sparse, int fieldCount) {
        if (array.getN() != fieldCount) {
            return null;
        }
        RealVector v = sparse ? Vectors.sparse(fieldCount) : Vectors.dense(fieldCount);
        int i = 0;
        for (String token : SPACE.split(array.getValue())) {
            double t = Double.parseDouble(token);
            if (t != 0.0) {
                v.setEntry(i, t);
            }
            i++;
        }
        return v;
    }

    public static void write(OutputStream out, DataDictionary dictionary, List<? extends Model> models)
            throws JAXBException {
        PMML pmml = new PMML(null, dictionary, "4.1");
        pmml.getModels().addAll(models);
        IOUtil.marshal(pmml, out);
    }

}