com.cloudera.science.ml.client.cmd.KMeansCommand.java Source code

Java tutorial

Introduction

Here is the source code for com.cloudera.science.ml.client.cmd.KMeansCommand.java

Source

/**
 * Copyright (c) 2013, Cloudera, Inc. All Rights Reserved.
 *
 * Cloudera, Inc. licenses this file to you under the Apache License,
 * Version 2.0 (the "License"). You may not use this file except in
 * compliance with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
 * CONDITIONS OF ANY KIND, either express or implied. See the License for
 * the specific language governing permissions and limitations under the
 * License.
 */
package com.cloudera.science.ml.client.cmd;

import java.io.File;
import java.io.IOException;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;

import org.apache.hadoop.conf.Configuration;
import org.apache.mahout.math.Vector;

import com.beust.jcommander.Parameter;
import com.beust.jcommander.Parameters;
import com.beust.jcommander.ParametersDelegate;
import com.beust.jcommander.converters.CommaParameterSplitter;
import com.beust.jcommander.converters.IntegerConverter;
import com.cloudera.science.ml.avro.MLWeightedCenters;
import com.cloudera.science.ml.client.params.RandomParameters;
import com.cloudera.science.ml.client.util.AvroIO;
import com.cloudera.science.ml.core.vectors.Centers;
import com.cloudera.science.ml.core.vectors.VectorConvert;
import com.cloudera.science.ml.core.vectors.Weighted;
import com.cloudera.science.ml.kmeans.core.KMeans;
import com.cloudera.science.ml.kmeans.core.KMeansInitStrategy;
import com.cloudera.science.ml.kmeans.core.KMeansEvaluation;
import com.cloudera.science.ml.kmeans.core.KMeansUpdateStrategy;
import com.cloudera.science.ml.kmeans.core.LloydsUpdateStrategy;
import com.cloudera.science.ml.kmeans.core.MiniBatchUpdateStrategy;
import com.google.common.collect.Lists;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.common.util.concurrent.MoreExecutors;

@Parameters(commandDescription = "Executes k-means++ on Avro vectors stored on the local filesystem")
/**
 * Command that clusters the weighted sketch vectors stored in a local Avro file.
 *
 * <p>For every requested number of clusters K it runs k-means over the points of all sketches
 * and writes the resulting centers to a local Avro file. When the input contains more than one
 * sketch, it also runs a {@link KMeansEvaluation}. That evaluation compares clusterings of the
 * last sketch against clusterings of the remaining sketches.
 */
public class KMeansCommand implements Command {

    @Parameter(names = "--input-file", required = true, description = "The local Avro file that contains the sketches computed by the ksketch command")
    private String sketchFile;

    @Parameter(names = "--clusters", required = true, description = "A CSV containing the number of clusters to create from the sample", splitter = CommaParameterSplitter.class, converter = IntegerConverter.class)
    private List<Integer> clusters = Lists.newArrayList();

    @Parameter(names = "--best-of", description = "Run this many iterations of k-means for each value of K")
    private int bestOf = 5;

    @Parameter(names = "--init-strategy", description = "The k-means initialization strategy (PLUS_PLUS or RANDOM)")
    private String initStrategyName = KMeansInitStrategy.PLUS_PLUS.name();

    @Parameter(names = "--max-iterations", description = "The maximum number of k-means iterations to run (either Lloyd's or mini-batch)")
    private int maxIterations = 100;

    @Parameter(names = "--mini-batch-size", description = "The number of points to include in each mini-batch update (enables mini-batch k-means)")
    private int miniBatchSize = 0;

    @Parameter(names = "--centers-file", required = true, description = "A local file to store the centers that were created into")
    private String centersOutputFile;

    @Parameter(names = "--num-threads", description = "The number of execution threads to use for running the (computationally intensive) k-means algorithm")
    private int numThreads = 1;

    @Parameter(names = "--eval-details-file", description = "Write detailed stats on the cluster stability information to this file")
    private String detailsFileName;

    @Parameter(names = "--eval-stats-file", description = "Write the high-level stats on the cluster stability to this file")
    private String statsFileName = "kmeans_stats.csv";

    @ParametersDelegate
    private RandomParameters randomParams = new RandomParameters();

    @Override
    public String getDescription() {
        return "Executes k-means++ on Avro vectors stored on the local filesystem";
    }

    /**
     * Clusters all sketch points and writes the centers, then optionally evaluates stability.
     *
     * <p>For every K in {@code --clusters}, k-means runs once when K is 1 and {@code --best-of}
     * times otherwise. These runs cover the union of all sketches, and all resulting centers are
     * written to {@code --centers-file}. If the input file holds more than one sketch, the last
     * sketch is used as the test fold and the others as the training fold. Both folds are
     * clustered the same way, and the evaluation stats are written to {@code --eval-stats-file}
     * and to standard output.
     *
     * @param conf the Hadoop configuration (not used by this command)
     * @return 0 on success
     * @throws IOException propagated from reading the sketch file or writing the output files
     * @throws IllegalArgumentException if {@code --init-strategy} names no known strategy
     */
    @Override
    public int execute(Configuration conf) throws IOException {
        KMeansInitStrategy initStrategy = parseInitStrategy(initStrategyName);
        KMeans kmeans = new KMeans(initStrategy, getUpdateStrategy());

        List<MLWeightedCenters> mlwc = AvroIO.read(MLWeightedCenters.class, new File(sketchFile));
        List<List<Weighted<Vector>>> sketches = toSketches(mlwc);
        List<Weighted<Vector>> allPoints = Lists.newArrayList();
        for (List<Weighted<Vector>> sketch : sketches) {
            allPoints.addAll(sketch);
        }

        // The pool's worker threads are non-daemon. The pool must therefore be shut down on every
        // path, or it can keep the JVM alive after this command returns.
        ListeningExecutorService exec = MoreExecutors.listeningDecorator(
            Executors.newFixedThreadPool(Math.max(1, numThreads)));
        try {
            List<Centers> centers = getClusters(exec, allPoints, kmeans);
            AvroIO.write(Lists.transform(centers, VectorConvert.FROM_CENTERS), new File(centersOutputFile));

            if (sketches.size() > 1) {
                // Perform the prediction strength calculations on the folds: all sketches but the
                // last form the training fold, and the last sketch is the test fold.
                List<Weighted<Vector>> train = Lists.newArrayList();
                for (int i = 0; i < sketches.size() - 1; i++) {
                    train.addAll(sketches.get(i));
                }
                List<Weighted<Vector>> test = sketches.get(sketches.size() - 1);
                List<Centers> trainCenters = getClusters(exec, train, kmeans);
                List<Centers> testCenters = getClusters(exec, test, kmeans);
                KMeansEvaluation eval = new KMeansEvaluation(testCenters, test, trainCenters, detailsFileName);
                eval.writeStatsToFile(new File(statsFileName));
                eval.writeStats(System.out);
            }
        } finally {
            // Also interrupts any clustering tasks still running if an earlier step failed.
            exec.shutdownNow();
        }

        return 0;
    }

    /**
     * Resolves an {@code --init-strategy} value to a {@link KMeansInitStrategy}, ignoring case.
     *
     * @param name the user-supplied strategy name, e.g. {@code PLUS_PLUS} or {@code random}
     * @return the matching strategy
     * @throws IllegalArgumentException if no strategy has the given name
     */
    private static KMeansInitStrategy parseInitStrategy(String name) {
        for (KMeansInitStrategy strategy : KMeansInitStrategy.values()) {
            if (strategy.name().equalsIgnoreCase(name)) {
                return strategy;
            }
        }
        throw new IllegalArgumentException("Unknown --init-strategy '" + name
            + "'; expected one of " + Lists.newArrayList(KMeansInitStrategy.values()));
    }

    /**
     * Runs k-means on {@code sketch} for every configured K and waits for all runs to finish.
     *
     * <p>One task is submitted for K = 1 and {@code bestOf} tasks for every other K. The centers
     * are returned in submission order: grouped by K in {@code --clusters} order, then by run.
     *
     * @param exec the executor the clustering tasks are submitted to
     * @param sketch the weighted points to cluster; shared read-only by all submitted tasks
     * @param kmeans the k-means driver used by every task
     * @return one {@link Centers} per submitted run
     * @throws CommandException if any run fails or the calling thread is interrupted
     */
    private List<Centers> getClusters(ListeningExecutorService exec, List<Weighted<Vector>> sketch, KMeans kmeans) {
        List<ListenableFuture<Centers>> futures = Lists.newArrayList();
        for (Integer nc : clusters) {
            int loops = nc == 1 ? 1 : bestOf;
            for (int i = 0; i < loops; i++) {
                // NOTE(review): the seed offset nc + i collides across (K, run) pairs, e.g.
                // (2, 1) and (3, 0). Confirm whether RandomParameters.getRandom makes that matter.
                Random r = randomParams.getRandom(nc + i);
                futures.add(exec.submit(new Clustering(kmeans, sketch, nc, r)));
            }
        }
        try {
            return Futures.allAsList(futures).get();
        } catch (InterruptedException e) {
            // Restore the interrupt flag so code further up the stack can still observe it.
            Thread.currentThread().interrupt();
            throw new CommandException("Interrupted while clustering", e);
        } catch (ExecutionException e) {
            throw new CommandException("Error in clustering", e);
        }
    }

    /**
     * Converts each Avro {@link MLWeightedCenters} record into an in-memory list of weighted
     * vectors, one list per record.
     *
     * <p>The conversion is materialized eagerly. {@code Lists.transform} returns a lazy view that
     * re-applies the conversion function on every element access. These lists are read by many
     * clustering tasks, so the conversion should run only once.
     */
    private static List<List<Weighted<Vector>>> toSketches(List<MLWeightedCenters> mlwc) {
        List<List<Weighted<Vector>>> base = Lists.newArrayList();
        for (MLWeightedCenters wc : mlwc) {
            base.add(Lists.newArrayList(Lists.transform(wc.getCenters(), VectorConvert.TO_WEIGHTED_VEC)));
        }
        return base;
    }

    /**
     * Chooses the k-means update strategy.
     *
     * <p>Returns mini-batch updates when {@code --mini-batch-size} is positive, and Lloyd's
     * updates otherwise. Both are limited to {@code --max-iterations}.
     *
     * <p>NOTE(review): the mini-batch strategy receives a single {@link Random}. The strategy
     * ends up inside the one {@link KMeans} instance shared by all concurrent clustering tasks.
     * With more than one thread, results may therefore depend on scheduling; confirm whether that
     * matters.
     */
    private KMeansUpdateStrategy getUpdateStrategy() {
        if (miniBatchSize > 0) {
            return new MiniBatchUpdateStrategy(maxIterations, miniBatchSize, randomParams.getRandom());
        } else {
            return new LloydsUpdateStrategy(maxIterations);
        }
    }

    /** A task that runs one k-means clustering of a sketch for a fixed number of clusters. */
    private static class Clustering implements Callable<Centers> {

        private final KMeans kmeans;
        private final List<Weighted<Vector>> sketch;
        private final int numClusters;
        private final Random r;

        Clustering(KMeans kmeans, List<Weighted<Vector>> sketch, int numClusters, Random r) {
            this.kmeans = kmeans;
            this.sketch = sketch;
            this.numClusters = numClusters;
            this.r = r;
        }

        /** Delegates to {@link KMeans#compute} with this task's sketch, K and random source. */
        @Override
        public Centers call() throws Exception {
            return kmeans.compute(sketch, numClusters, r);
        }
    }
}