acromusashi.stream.ml.clustering.kmeans.KmeansCalculator.java Source code

Java tutorial

Introduction

Here is the source code for acromusashi.stream.ml.clustering.kmeans.KmeansCalculator.java

Source

/**
* Copyright (c) Acroquest Technology Co, Ltd. All Rights Reserved.
* Please read the associated COPYRIGHTS file for more details.
*
* THE SOFTWARE IS PROVIDED BY Acroquest Technology Co., Ltd.,
* WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
* BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDER BE LIABLE FOR ANY
* CLAIM, DAMAGES SUFFERED BY LICENSEE AS A RESULT OF USING, MODIFYING
* OR DISTRIBUTING THIS SOFTWARE OR ITS DERIVATIVES.
*/
package acromusashi.stream.ml.clustering.kmeans;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Random;
import java.util.Set;
import java.util.TreeMap;

import org.apache.commons.collections.ComparatorUtils;
import org.apache.commons.math.util.MathUtils;

import acromusashi.stream.ml.clustering.kmeans.entity.CentroidMapping;
import acromusashi.stream.ml.clustering.kmeans.entity.CentroidsComparator;
import acromusashi.stream.ml.clustering.kmeans.entity.KmeansDataSet;
import acromusashi.stream.ml.clustering.kmeans.entity.KmeansPoint;
import acromusashi.stream.ml.clustering.kmeans.entity.KmeansResult;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;

/**
 * Static utilities for KMeans clustering: building a model (KMeans++ seeding followed by iterative
 * assignment/update steps), classifying points against a model, and merging two models.
 * 
 * @author kimura
 */
public final class KmeansCalculator {

    /**
     * Utility class; not meant to be instantiated.
     */
    private KmeansCalculator() {
    }

    /**
     * Builds a KMeans model from the given points.<br>
     * Initial centroids are chosen with KMeans++. Then, on each iteration, every point is assigned to its
     * nearest centroid and each centroid is moved to the mean of its assigned points (a centroid with no
     * assigned points stays where it is). Iteration stops once no centroid moves farther than
     * {@code convergenceThres}, or after {@code maxIteration} iterations.
     * 
     * @param pointList points to cluster
     * @param clusterNum number of clusters; must be positive
     * @param maxIteration maximum number of assignment/update iterations
     * @param convergenceThres maximum centroid movement (Euclidean distance) that counts as converged
     * @return the created model, or {@code null} if {@code pointList} has fewer points than {@code clusterNum}
     * @throws IllegalArgumentException if {@code clusterNum} is not positive
     */
    public static KmeansDataSet createDataModel(List<KmeansPoint> pointList, int clusterNum, int maxIteration,
            double convergenceThres) {
        if (clusterNum <= 0) {
            throw new IllegalArgumentException("clusterNum must be positive: " + clusterNum);
        }

        // Not enough points to form the requested number of clusters.
        if (pointList.size() < clusterNum) {
            return null;
        }

        List<KmeansPoint> centroids = createInitialCentroids(pointList, clusterNum);
        long[] clusteredNum = new long[clusterNum];

        for (int exeIndex = 0; exeIndex < maxIteration; exeIndex++) {
            // Element i holds the points currently assigned to centroid i.
            List<List<KmeansPoint>> assignments = new ArrayList<>(clusterNum);
            for (int centroidIndex = 0; centroidIndex < clusterNum; centroidIndex++) {
                assignments.add(new ArrayList<KmeansPoint>());
            }

            for (KmeansPoint targetPoint : pointList) {
                KmeansResult result = nearestCentroid(targetPoint, centroids);
                assignments.get(result.getCentroidIndex()).add(targetPoint);
            }

            // Iterate by index so that newCentroids[i] is guaranteed to replace centroids[i].
            List<KmeansPoint> newCentroids = new ArrayList<>(clusterNum);
            for (int centroidIndex = 0; centroidIndex < clusterNum; centroidIndex++) {
                List<KmeansPoint> members = assignments.get(centroidIndex);
                if (members.isEmpty()) {
                    // An empty cluster keeps its previous centroid.
                    newCentroids.add(centroids.get(centroidIndex));
                } else {
                    newCentroids.add(calculateCentroid(members));
                }
                clusteredNum[centroidIndex] = members.size();
            }

            boolean converged = isConvergenced(centroids, newCentroids, convergenceThres);
            centroids = newCentroids;

            if (converged) {
                break;
            }
        }

        double[][] centroidPoints = new double[clusterNum][];
        for (int centroidIndex = 0; centroidIndex < clusterNum; centroidIndex++) {
            centroidPoints[centroidIndex] = centroids.get(centroidIndex).getDataPoint();
        }

        KmeansDataSet createdModel = new KmeansDataSet();
        createdModel.setCentroids(centroidPoints);
        createdModel.setClusteredNum(clusteredNum);
        return createdModel;
    }

    /**
     * Calculates the centroid (coordinate-wise arithmetic mean) of the given points.
     * 
     * @param basePoints points to average; must be non-empty and all of the same dimension
     * @return a new point holding the mean coordinates
     * @throws IllegalArgumentException if {@code basePoints} is empty
     */
    public static KmeansPoint calculateCentroid(List<KmeansPoint> basePoints) {
        if (basePoints.isEmpty()) {
            throw new IllegalArgumentException("basePoints must not be empty");
        }

        double[] firstDataPoint = basePoints.get(0).getDataPoint();
        double[] centroidSum = Arrays.copyOf(firstDataPoint, firstDataPoint.length);

        for (int pointIndex = 1; pointIndex < basePoints.size(); pointIndex++) {
            double[] dataPoint = basePoints.get(pointIndex).getDataPoint();
            for (int coordinateIndex = 0; coordinateIndex < centroidSum.length; coordinateIndex++) {
                centroidSum[coordinateIndex] += dataPoint[coordinateIndex];
            }
        }

        KmeansPoint centroid = new KmeansPoint();
        // sub() divides each coordinate by the point count.
        centroid.setDataPoint(sub(centroidSum, basePoints.size()));
        return centroid;
    }

    /**
     * Checks whether the centroids have converged.
     * 
     * @param basePoints centroids before the update
     * @param newPoints centroids after the update, in the same order as {@code basePoints}
     * @param convergenceThres maximum Euclidean distance a centroid may move and still count as converged
     * @return {@code true} if no centroid moved farther than {@code convergenceThres}, otherwise {@code false}
     */
    public static boolean isConvergenced(List<KmeansPoint> basePoints, List<KmeansPoint> newPoints,
            double convergenceThres) {
        for (int pointIndex = 0; pointIndex < basePoints.size(); pointIndex++) {
            double distance = MathUtils.distance(basePoints.get(pointIndex).getDataPoint(),
                    newPoints.get(pointIndex).getDataPoint());
            if (distance > convergenceThres) {
                return false;
            }
        }
        return true;
    }

    /**
     * Finds the centroid nearest (by Euclidean distance) to the target point. {@code null} centroids are
     * skipped; if none is usable, index 0 and distance {@link Double#MAX_VALUE} are reported.
     * 
     * @param targetPoint point to classify
     * @param centroids candidate centroids; must be non-empty
     * @return the result holding the point, the nearest centroid index and coordinates, and the distance
     */
    public static KmeansResult nearestCentroid(double[] targetPoint, double[][] centroids) {
        int nearestCentroidIndex = 0;
        double minDistance = Double.MAX_VALUE;

        for (int index = 0; index < centroids.length; index++) {
            double[] candidate = centroids[index];
            if (candidate == null) {
                continue;
            }
            double distance = MathUtils.distance(targetPoint, candidate);
            if (distance < minDistance) {
                minDistance = distance;
                nearestCentroidIndex = index;
            }
        }

        KmeansResult result = new KmeansResult();
        result.setDataPoint(targetPoint);
        result.setCentroidIndex(nearestCentroidIndex);
        result.setCentroid(centroids[nearestCentroidIndex]);
        result.setDistance(minDistance);
        return result;
    }

    /**
     * Finds the centroid nearest (by Euclidean distance) to the target point. Centroids that are
     * {@code null} or have no coordinates are skipped; if none is usable, index 0 and distance
     * {@link Double#MAX_VALUE} are reported.
     * 
     * @param targetPoint point to classify
     * @param centroids candidate centroids; must be non-empty
     * @return the result holding the point, the nearest centroid index and coordinates, and the distance
     */
    public static KmeansResult nearestCentroid(KmeansPoint targetPoint, List<KmeansPoint> centroids) {
        double[] targetData = targetPoint.getDataPoint();
        int nearestCentroidIndex = 0;
        double minDistance = Double.MAX_VALUE;

        for (int index = 0; index < centroids.size(); index++) {
            KmeansPoint candidate = centroids.get(index);
            if (candidate == null || candidate.getDataPoint() == null) {
                continue;
            }
            double distance = MathUtils.distance(targetData, candidate.getDataPoint());
            if (distance < minDistance) {
                minDistance = distance;
                nearestCentroidIndex = index;
            }
        }

        KmeansPoint nearest = centroids.get(nearestCentroidIndex);

        KmeansResult result = new KmeansResult();
        result.setDataPoint(targetData);
        result.setCentroidIndex(nearestCentroidIndex);
        // Guard against the case where no usable centroid existed and element 0 is null.
        result.setCentroid(nearest == null ? null : nearest.getDataPoint());
        result.setDistance(minDistance);
        return result;
    }

    /**
     * Classifies a point against a KMeans model by finding the model's nearest centroid.
     * 
     * @param targetPoint point to classify
     * @param dataSet KMeans model whose centroids are compared
     * @return the classification result (nearest centroid index, its coordinates and the distance)
     */
    public static KmeansResult classify(KmeansPoint targetPoint, KmeansDataSet dataSet) {
        return nearestCentroid(targetPoint.getDataPoint(), dataSet.getCentroids());
    }

    /**
     * Chooses initial centroids with KMeans++ seeding, using a fresh {@link Random}.
     * 
     * @param basePoints points to choose centroids from; not modified
     * @param clusterNum number of centroids to choose; must not exceed {@code basePoints.size()}
     * @return the chosen centroids
     * @throws IllegalArgumentException if {@code clusterNum} exceeds the number of points
     */
    public static List<KmeansPoint> createInitialCentroids(List<KmeansPoint> basePoints, int clusterNum) {
        return createInitialCentroids(basePoints, clusterNum, new Random());
    }

    /**
     * Chooses initial centroids with KMeans++ seeding.<br>
     * The first centroid is drawn uniformly at random. Each subsequent centroid is drawn with probability
     * proportional to the squared distance from each candidate to its nearest already-chosen centroid.
     * 
     * @param basePoints points to choose centroids from; not modified
     * @param clusterNum number of centroids to choose; must not exceed {@code basePoints.size()}
     * @param random random source (pass a seeded instance for reproducible seeding)
     * @return the chosen centroids
     * @throws IllegalArgumentException if {@code clusterNum} exceeds the number of points
     */
    public static List<KmeansPoint> createInitialCentroids(List<KmeansPoint> basePoints, int clusterNum,
            Random random) {
        if (basePoints.size() < clusterNum) {
            throw new IllegalArgumentException("clusterNum (" + clusterNum + ") exceeds the number of points ("
                    + basePoints.size() + ")");
        }

        List<KmeansPoint> resultList = new ArrayList<>();
        // Work on a copy so the caller's list is left untouched.
        List<KmeansPoint> pointList = new ArrayList<>(basePoints);
        resultList.add(pointList.remove(random.nextInt(pointList.size())));

        for (int centroidIndex = 1; centroidIndex < clusterNum; centroidIndex++) {
            // dxs[i] = sum of squared nearest-centroid distances of candidates 0..i (non-decreasing).
            double[] dxs = computeDxs(pointList, resultList);
            double totalWeight = dxs[dxs.length - 1];

            int index = 0;
            if (totalWeight > 0.0d) {
                // Candidate i owns the interval [dxs[i-1], dxs[i]) whose width is its squared distance,
                // so the first cumulative value strictly greater than r identifies the chosen candidate.
                index = firstGreaterIndex(dxs, random.nextDouble() * totalWeight);
            } else {
                // Every candidate coincides with a chosen centroid; prefer one that is not chosen yet.
                while (index < pointList.size() - 1 && resultList.contains(pointList.get(index))) {
                    index++;
                }
            }

            resultList.add(pointList.get(index));
        }

        return resultList;
    }

    /**
     * Returns the index of the first element of {@code ascending} that is strictly greater than
     * {@code key}, or the last index if no element is greater.
     * 
     * @param ascending non-empty array sorted in non-decreasing order
     * @param key value to locate
     * @return index of the first element greater than {@code key}
     */
    private static int firstGreaterIndex(double[] ascending, double key) {
        int low = 0;
        int high = ascending.length - 1;
        while (low < high) {
            int mid = (low + high) >>> 1;
            if (ascending[mid] > key) {
                high = mid;
            } else {
                low = mid + 1;
            }
        }
        return low;
    }

    /**
     * Computes the cumulative KMeans++ weights of the candidate points.
     * 
     * @param basePoints candidate points
     * @param centroids centroids chosen so far
     * @return array whose element i is the sum, over candidates 0..i, of the squared distance from each
     *         candidate to its nearest centroid
     */
    public static double[] computeDxs(List<KmeansPoint> basePoints, List<KmeansPoint> centroids) {
        double[] dxs = new double[basePoints.size()];
        double sum = 0.0d;

        for (int pointIndex = 0; pointIndex < basePoints.size(); pointIndex++) {
            KmeansPoint targetPoint = basePoints.get(pointIndex);
            double[] nearestCentroid = nearestCentroid(targetPoint, centroids).getCentroid();
            double dx = MathUtils.distance(targetPoint.getDataPoint(), nearestCentroid);
            sum += dx * dx;
            dxs[pointIndex] = sum;
        }

        return dxs;
    }

    /**
     * Merges two KMeans models.<br>
     * Procedure:<br>
     * <ol>
     * <li>Compute the Euclidean distance between every pair of the first n centroids of each model, where
     * n is the smaller of the two centroid counts.</li>
     * <li>Sort the pairs with {@link CentroidsComparator} (presumably by ascending distance; confirm), then
     * greedily pair centroids, skipping pairs whose base or target centroid is already paired.</li>
     * <li>Average each pair of centroids to form the merged centroids.</li>
     * </ol>
     * Only the centroids are set on the returned model.
     * 
     * @param baseKmeans base KMeans model
     * @param targetKmeans KMeans model merged into the base
     * @return the merged model
     */
    public static final KmeansDataSet mergeKmeans(KmeansDataSet baseKmeans, KmeansDataSet targetKmeans) {
        double[][] baseCentroids = baseKmeans.getCentroids();
        double[][] targetCentroids = targetKmeans.getCentroids();
        int centroidNum = Math.min(baseCentroids.length, targetCentroids.length);

        List<CentroidMapping> allDistance = calculateDistances(baseCentroids, targetCentroids, centroidNum);

        Collections.sort(allDistance, new CentroidsComparator());
        Map<Integer, Integer> resultMapping = createCentroidMappings(centroidNum, allDistance);

        KmeansDataSet merged = new KmeansDataSet();
        merged.setCentroids(mergeCentroids(baseCentroids, targetCentroids, resultMapping));
        return merged;
    }

    /**
     * Merges per-cluster counts according to a centroid mapping.
     * 
     * @param baseCounts counts of the base model, indexed by base centroid index
     * @param targetCounts counts of the target model, indexed by target centroid index
     * @param resultMapping base centroid index to target centroid index; keys must be 0..size-1
     * @return list whose element at each base index is the sum of the base and mapped target counts
     */
    protected static List<Long> mergeCounts(List<Long> baseCounts, List<Long> targetCounts,
            Map<Integer, Integer> resultMapping) {
        int countNum = resultMapping.size();
        List<Long> mergedCounts = new ArrayList<>(countNum);
        for (int count = 0; count < countNum; count++) {
            mergedCounts.add(0L);
        }

        for (Entry<Integer, Integer> resultEntry : resultMapping.entrySet()) {
            mergedCounts.set(resultEntry.getKey(),
                    baseCounts.get(resultEntry.getKey()) + targetCounts.get(resultEntry.getValue()));
        }

        return mergedCounts;
    }

    /**
     * Concatenates two lists of initial points.
     * 
     * @param basePoints base points (appear first)
     * @param targetPoints target points (appended after the base points)
     * @return a new list holding all base points followed by all target points
     */
    protected static List<double[]> mergeInitPoints(List<double[]> basePoints, List<double[]> targetPoints) {
        List<double[]> mergedFeatures = new ArrayList<>(basePoints.size() + targetPoints.size());
        mergedFeatures.addAll(basePoints);
        mergedFeatures.addAll(targetPoints);
        return mergedFeatures;
    }

    /**
     * Merges centroids according to a centroid mapping.<br>
     * Each merged centroid is the element-wise average of a base centroid and its mapped target centroid.<br>
     * 
     * @param baseCentroids base model centroids
     * @param targetCentroids target model centroids
     * @param resultMapping base centroid index to target centroid index; keys must be 0..size-1
     * @return merged centroids, indexed by base centroid index
     */
    public static double[][] mergeCentroids(double[][] baseCentroids, double[][] targetCentroids,
            Map<Integer, Integer> resultMapping) {
        double[][] mergedCentroids = new double[resultMapping.size()][];

        for (Map.Entry<Integer, Integer> targetEntry : resultMapping.entrySet()) {
            double[] baseCentroid = baseCentroids[targetEntry.getKey()];
            double[] targetCentroid = targetCentroids[targetEntry.getValue()];
            mergedCentroids[targetEntry.getKey()] = average(baseCentroid, targetCentroid);
        }

        return mergedCentroids;
    }

    /**
     * Greedily builds a one-to-one mapping between base and target centroids.<br>
     * Pairs are taken in the order of {@code allDistance}; a pair is skipped if its base or target index is
     * already mapped. Stops after {@code centroidNum} pairs.
     * 
     * @param centroidNum number of pairs to create
     * @param allDistance candidate pairs, already sorted in priority order
     * @return mapping from base centroid index to target centroid index, sorted by base index
     */
    protected static Map<Integer, Integer> createCentroidMappings(int centroidNum,
            List<CentroidMapping> allDistance) {
        Set<Integer> baseSet = new HashSet<>();
        Set<Integer> targetSet = new HashSet<>();
        Map<Integer, Integer> resultMapping = new TreeMap<>();
        int mappingNum = 0;

        for (CentroidMapping targetDistance : allDistance) {
            // Each base and target centroid may be used only once.
            if (baseSet.contains(targetDistance.getBaseIndex())
                    || targetSet.contains(targetDistance.getTargetIndex())) {
                continue;
            }

            baseSet.add(targetDistance.getBaseIndex());
            targetSet.add(targetDistance.getTargetIndex());
            resultMapping.put(targetDistance.getBaseIndex(), targetDistance.getTargetIndex());
            mappingNum++;

            if (mappingNum >= centroidNum) {
                break;
            }
        }

        return resultMapping;
    }

    /**
     * Computes the Euclidean distance between every pair of the first {@code centroidNum} base and target
     * centroids.
     * 
     * @param baseCentroids base model centroids
     * @param targetCentroids target model centroids
     * @param centroidNum number of centroids taken from each model
     * @return one mapping per (base, target) pair, holding both indexes and their distance
     */
    protected static List<CentroidMapping> calculateDistances(double[][] baseCentroids, double[][] targetCentroids,
            int centroidNum) {
        List<CentroidMapping> allDistance = new ArrayList<>(centroidNum * centroidNum);

        for (int baseIndex = 0; baseIndex < centroidNum; baseIndex++) {
            for (int targetIndex = 0; targetIndex < centroidNum; targetIndex++) {
                CentroidMapping centroidMapping = new CentroidMapping();
                centroidMapping.setBaseIndex(baseIndex);
                centroidMapping.setTargetIndex(targetIndex);
                double distance = MathUtils.distance(baseCentroids[baseIndex], targetCentroids[targetIndex]);
                centroidMapping.setEuclideanDistance(distance);
                allDistance.add(centroidMapping);
            }
        }
        return allDistance;
    }

    /**
     * Computes the element-wise arithmetic mean of two vectors.
     * 
     * @param base first vector; its length determines the result length
     * @param target second vector; must be at least as long as {@code base}
     * @return a new array where element i is {@code (base[i] + target[i]) / 2}
     */
    protected static double[] average(double[] base, double[] target) {
        int dataNum = base.length;
        double[] average = new double[dataNum];

        for (int index = 0; index < dataNum; index++) {
            average[index] = (base[index] + target[index]) / 2.0;
        }

        return average;
    }

    /**
     * Divides every element of a vector by a number. Despite its name, this performs division, not
     * subtraction.
     * 
     * @param base vector to divide
     * @param subNumber divisor
     * @return a new array where element i is {@code base[i] / subNumber}
     */
    protected static double[] sub(double[] base, double subNumber) {
        int dataNum = base.length;
        double[] result = new double[dataNum];

        for (int index = 0; index < dataNum; index++) {
            result[index] = base[index] / subNumber;
        }

        return result;
    }
}