KMeans.java :  » Graphic-Library » ImageJ-Plugins-1.4.1 » net » sf » ij_plugins » clustering » Java Open Source

Java Open Source » Graphic Library » ImageJ Plugins 1.4.1 
ImageJ Plugins 1.4.1 » net » sf » ij_plugins » clustering » KMeans.java
/***
 * Image/J Plugins
 * Copyright (C) 2002-2008 Jarek Sacha
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 *
 * Latest release available at http://sourceforge.net/projects/ij-plugins/
 */
package net.sf.ij_plugins.clustering;

import ij.IJ;
import ij.ImageStack;
import ij.process.ByteProcessor;
import ij.process.FloatProcessor;
import net.sf.ij_plugins.multiband.VectorProcessor;

import java.awt.Rectangle;
import java.util.Random;

/**
 * Pixel-based multi-band image segmentation using k-means clustering algorithm.
 * <p>
 * Typical use: create an instance (optionally with a {@link Config}), call {@link #run(ImageStack)}
 * to obtain the label image, then query {@link #getClusterCenters()},
 * {@link #getCentroidValueImage()}, or {@link #getClusterAnimation()}.
 * Instances keep results of the last {@code run} in mutable fields and perform no synchronization.
 *
 * @author Jarek Sacha
 */
public final class KMeans {

    /**
     * Largest number of clusters whose labels (0..255) fit into the 8-bit output image.
     */
    private static final int MAX_NUMBER_OF_CLUSTERS = 256;

    private final Config config;

    private Rectangle roi;
    private ByteProcessor mask;

    private VectorProcessor vp;
    private float[][] clusterCenters;
    private ImageStack clusterAnimation;

    /**
     * Create k-means clustering using default {@link Config} parameters.
     */
    public KMeans() {
        this.config = new Config();
    }

    /**
     * Create k-means clustering using the given parameters. A private copy of {@code config} is
     * made, so later changes to the passed object do not affect this instance.
     *
     * @param config clustering parameters, must not be {@code null}.
     */
    public KMeans(final Config config) {
        this.config = config.duplicate();
    }

    /**
     * Set region of interest. Note: currently stored but not used by {@link #run(ImageStack)}
     * (ROI support is not implemented yet).
     *
     * @param roi region of interest.
     */
    public void setRoi(final Rectangle roi) {
        this.roi = roi;
    }

    /**
     * Set pixel mask. Note: currently stored but not used by {@link #run(ImageStack)}
     * (mask support is not implemented yet).
     *
     * @param mask pixel mask.
     */
    public void setMask(final ByteProcessor mask) {
        this.mask = mask;
    }

    /**
     * Perform k-means clustering of the input <code>stack</code>. Elements of the
     * <code>stack</code> must be of type <code>FloatProcessor</code>.
     *
     * @param stack stack representing a multi-band image.
     * @return segmented image; pixel value is the index of the cluster the pixel belongs to.
     * @throws IllegalArgumentException if the stack is empty or the configured number of clusters
     *                                  is not in the range 1..256.
     */
    public ByteProcessor run(final ImageStack stack) {

        if (stack.getSize() < 1) {
            throw new IllegalArgumentException("Input stack cannot be empty");
        }

        // Cluster labels are written into an 8-bit image, so at most 256 distinct labels fit.
        final int numberOfClusters = config.getNumberOfClusters();
        if (numberOfClusters < 1 || numberOfClusters > MAX_NUMBER_OF_CLUSTERS) {
            throw new IllegalArgumentException("Number of clusters must be between 1 and "
                    + MAX_NUMBER_OF_CLUSTERS + ", got " + numberOfClusters + ".");
        }

        vp = new VectorProcessor(stack);

        // TODO: add support for using ROI. ROI of the first slice is applied to all slices.
        // TODO: Verify that ROI and mask are consistent with the input image.

        // Run clustering
        cluster();

        return encodeSegmentedImage();
    }

    /**
     * Return location of cluster centers.
     *
     * @return copy of the array of cluster centers, first index refers to cluster number;
     *         {@code null} if clustering was not performed yet.
     */
    public float[][] getClusterCenters() {
        if (clusterCenters == null) {
            return null;
        }

        // Deep copy so callers cannot corrupt centers used later by getCentroidValueImage().
        final float[][] copy = new float[clusterCenters.length][];
        for (int i = 0; i < clusterCenters.length; i++) {
            copy[i] = clusterCenters[i].clone();
        }
        return copy;
    }

    /**
     * Return stack representing clustering optimization. This will return not <code>null</code>
     * value only when configuration parameters <code>clusterAnimationEnabled</code> is set to
     * true.
     *
     * @return stack representing cluster optimization, can return <code>null</code>.
     */
    public ImageStack getClusterAnimation() {
        return clusterAnimation;
    }

    /**
     * Returns stack where discovered clusters can be represented by replacing pixel values in a
     * cluster by the value of the centroid of that cluster.
     *
     * @return centroid value image, one float slice per band.
     * @throws IllegalStateException if clustering was not performed yet.
     */
    public ImageStack getCentroidValueImage() {
        if (clusterCenters == null) {
            throw new IllegalStateException("Need to perform clustering first.");
        }

        return encodeCentroidValueImage();
    }


    /**
     * Label each pixel with the index of its closest cluster center.
     */
    private ByteProcessor encodeSegmentedImage() {
        final ByteProcessor dest = new ByteProcessor(vp.getWidth(), vp.getHeight());
        final VectorProcessor.PixelIterator iterator = vp.pixelIterator();
        while (iterator.hasNext()) {
            final float[] v = iterator.next();
            final int c = closestCluster(v, clusterCenters);
            dest.putPixel(iterator.getX(), iterator.getY(), c);
        }
        return dest;
    }

    /**
     * Build a float stack, one slice per band, where each pixel holds the corresponding band value
     * of its closest cluster center.
     */
    private ImageStack encodeCentroidValueImage() {
        final int width = vp.getWidth();
        final int height = vp.getHeight();
        final int numberOfValues = vp.getNumberOfValues();
        final ImageStack s = new ImageStack(width, height);
        for (int i = 0; i < numberOfValues; ++i) {
            // TODO: Band label should be the same as in the input stack
            s.addSlice("Band " + (i + 1), new FloatProcessor(width, height));
        }

        final VectorProcessor.PixelIterator iterator = vp.pixelIterator();
        final Object[] pixels = s.getImageArray();
        while (iterator.hasNext()) {
            final float[] v = iterator.next();
            final float[] center = clusterCenters[closestCluster(v, clusterCenters)];
            final int offset = iterator.getOffset();
            for (int j = 0; j < numberOfValues; ++j) {
                ((float[]) pixels[j])[offset] = center[j];
            }
        }

        return s;
    }


    /**
     * Write current cluster centers, preceded by {@code message}, using {@code IJ.write}.
     */
    private void printClusters(final String message) {
        // NOTE(review): IJ.write is deprecated in later ImageJ releases; consider IJ.log.
        IJ.write(message);
        for (final float[] clusterCenter : clusterCenters) {
            final StringBuilder buffer = new StringBuilder("  (");
            for (final float vv : clusterCenter) {
                buffer.append(" ").append(vv).append(" ");
            }
            buffer.append(")");
            IJ.write(buffer.toString());
        }
    }

    /**
     * Lloyd-style k-means iteration: assign every pixel to its closest center, move each center
     * to the mean of its pixels, and repeat until the summed center displacement is below
     * the configured tolerance.
     */
    private void cluster() {

        // Select initial partitioning - initialize cluster centers
        clusterCenters = generateRandomClusterCenters();
        if (config.isPrintTraceEnabled()) {
            printClusters("Initial clusters");
        }

        if (config.isClusterAnimationEnabled()) {
            clusterAnimation = new ImageStack(vp.getWidth(), vp.getHeight());
            clusterAnimation.addSlice("Initial", encodeSegmentedImage());
        }

        final int numberOfValues = vp.getNumberOfValues();

        // Optimize cluster centers
        boolean converged = false;
        long count = 0;
        while (!converged) {

            final MeanElement[] newClusterMeans = new MeanElement[clusterCenters.length];
            for (int i = 0; i < newClusterMeans.length; i++) {
                newClusterMeans[i] = new MeanElement(numberOfValues);
            }

            // Generate a new partition by assigning each pattern to its closest cluster center
            // and accumulate the centroids of the new clusters.
            final VectorProcessor.PixelIterator iterator = vp.pixelIterator();
            while (iterator.hasNext()) {
                final float[] v = iterator.next();
                newClusterMeans[closestCluster(v, clusterCenters)].add(v);
            }

            // Move centers to new centroids and measure total displacement for convergence check.
            double distanceSum = 0;
            for (int i = 0; i < clusterCenters.length; i++) {
                if (newClusterMeans[i].isEmpty()) {
                    // No pixel is closest to this center. The mean of an empty set is undefined
                    // (0/0 = NaN) and would prevent convergence, so keep the previous center.
                    continue;
                }
                final float[] newClusterCenter = newClusterMeans[i].mean();
                distanceSum += distance(clusterCenters[i], newClusterCenter);
                clusterCenters[i] = newClusterCenter;
            }

            converged = distanceSum < config.getTolerance();

            ++count;

            final String message = "k-means iteration " + count + ", cluster error: " + distanceSum;
            IJ.showStatus(message);
            if (config.isPrintTraceEnabled()) {
                printClusters(message);
            }

            if (config.isClusterAnimationEnabled()) {
                clusterAnimation.addSlice("Iteration " + count, encodeSegmentedImage());
            }
        }
    }

    /**
     * Return index of the closest cluster to point <code>x</code>.
     *
     * @param x              point coordinates.
     * @param clusterCenters cluster centers.
     * @return index of the closest cluster, or -1 if there are no centers or no distance is
     *         smaller than {@code Double.MAX_VALUE} (e.g. all distances are NaN).
     */
    private static int closestCluster(final float[] x, final float[][] clusterCenters) {
        double minDistance = Double.MAX_VALUE;
        int closestCluster = -1;
        for (int i = 0; i < clusterCenters.length; i++) {
            final double d = distance(clusterCenters[i], x);
            if (d < minDistance) {
                minDistance = d;
                closestCluster = i;
            }
        }

        return closestCluster;
    }


    /**
     * Euclidean distance between points <code>a</code> and <code>b</code>. Squared differences
     * are accumulated in double precision.
     *
     * @param a first point.
     * @param b second point, at least as long as {@code a}.
     * @return distance.
     */
    private static double distance(final float[] a, final float[] b) {
        double sum = 0;
        for (int i = 0; i < a.length; i++) {
            final double d = (double) a[i] - b[i];
            sum += d * d;
        }
        return Math.sqrt(sum);
    }

    /**
     * Pick initial cluster centers as values of randomly selected pixels, rejecting candidates
     * closer than the configured tolerance to an already selected center.
     *
     * @return cluster centers.
     * @throws RuntimeException if unique centers could not be found within
     *                          {@code width * height} attempts for a single center.
     */
    private float[][] generateRandomClusterCenters() {

        final Random random = config.isRandomizationSeedEnabled()
                ? new Random(config.getRandomizationSeed())
                : new Random();

        final int width = vp.getWidth();
        final int height = vp.getHeight();
        // long product avoids int overflow for very large images
        final long maxAttempts = (long) width * height;

        final float[][] centers = new float[config.getNumberOfClusters()][];
        for (int i = 0; i < centers.length; i++) {
            centers[i] = new float[vp.getNumberOfValues()];
            // Make sure that each center is unique
            boolean unique = false;
            long count = 0;
            while (!unique) {
                // Initialize center
                final int sampleX = random.nextInt(width);
                final int sampleY = random.nextInt(height);
                vp.get(sampleX, sampleY, centers[i]);

                // Test if it is not a repeat of already selected center.
                unique = true;
                for (int j = 0; j < i; ++j) {
                    if (distance(centers[j], centers[i]) < config.getTolerance()) {
                        unique = false;
                        break;
                    }
                }

                ++count;
                if (count > maxAttempts) {
                    throw new RuntimeException("Unable to initialize " + centers.length +
                            " unique cluster centroids.\n" +
                            "Input image may not have enough unique pixel values.");
                }
            }
        }

        return centers;
    }


    /**
     * Incrementally computed mean of fixed-length vectors. Sums are kept in double precision to
     * limit round-off when averaging many pixels.
     */
    private static final class MeanElement {
        final double[] sum;
        int count;

        public MeanElement(final int elementSize) {
            sum = new double[elementSize];
        }

        public void add(final float[] x) {
            if (x.length != sum.length) {
                throw new IllegalArgumentException("Invalid element size, got " + x.length
                        + ", expecting " + sum.length);
            }

            for (int i = 0; i < x.length; i++) {
                sum[i] += x[i];
            }
            ++count;
        }

        /**
         * @return {@code true} if no element was added, in which case the mean is undefined.
         */
        public boolean isEmpty() {
            return count == 0;
        }

        /**
         * @return mean of added elements; components are NaN if no element was added.
         */
        public float[] mean() {
            final float[] r = new float[sum.length];
            for (int i = 0; i < r.length; i++) {
                r[i] = (float) (sum[i] / count);
            }

            return r;
        }
    }

    /**
     * Configurable parameters of the k-means algorithm.
     * Defaults: seed 48 (enabled), tolerance 0.0001, 4 clusters, animation and trace disabled.
     */
    public static final class Config implements Cloneable {
        /**
         * Seed used to initialize random number generator.
         */
        private int randomizationSeed = 48;
        private boolean randomizationSeedEnabled = true;
        private double tolerance = 0.0001;
        private int numberOfClusters = 4;
        private boolean clusterAnimationEnabled;
        private boolean printTraceEnabled;

        public int getRandomizationSeed() {
            return randomizationSeed;
        }

        public void setRandomizationSeed(final int randomizationSeed) {
            this.randomizationSeed = randomizationSeed;
        }

        /**
         * If <code>true</code>, random number generator will be initialized with a
         * <code>randomizationSeed</code>. If <code>false</code> random number generator will be
         * created with the default {@code java.util.Random} constructor (unseeded).
         *
         * @return {@code true} when randomization seed is enabled.
         * @see #getRandomizationSeed()
         */
        public boolean isRandomizationSeedEnabled() {
            return randomizationSeedEnabled;
        }

        public void setRandomizationSeedEnabled(final boolean randomizationSeedEnabled) {
            this.randomizationSeedEnabled = randomizationSeedEnabled;
        }

        /**
         * @return number of clusters; {@link KMeans#run(ImageStack)} requires a value in 1..256.
         */
        public int getNumberOfClusters() {
            return numberOfClusters;
        }

        public void setNumberOfClusters(final int numberOfClusters) {
            this.numberOfClusters = numberOfClusters;
        }

        /**
         * Return tolerance used to determine cluster centroid distance. This tolerance is used to
         * determine if a centroid changed location between iterations, and if two initial
         * centroids are distinct.
         *
         * @return cluster centroid location tolerance.
         */
        public double getTolerance() {
            return tolerance;
        }

        /**
         * Set centroid location tolerance.
         *
         * @param tolerance cluster centroid location tolerance.
         */
        public void setTolerance(final double tolerance) {
            this.tolerance = tolerance;
        }

        /**
         * Set centroid location tolerance. Kept for backward compatibility; prefer
         * {@link #setTolerance(double)}.
         *
         * @param tolerance cluster centroid location tolerance.
         */
        public void setTolerance(final float tolerance) {
            this.tolerance = tolerance;
        }

        /**
         * Return <code>true</code> if an animation illustrating cluster optimization is
         * enabled.
         *
         * @return {@code true} when cluster animation is enabled.
         */
        public boolean isClusterAnimationEnabled() {
            return clusterAnimationEnabled;
        }

        public void setClusterAnimationEnabled(final boolean clusterAnimationEnabled) {
            this.clusterAnimationEnabled = clusterAnimationEnabled;
        }

        /**
         * Return <code>true</code> if a trace is printed to the ImageJ's Result window.
         *
         * @return {@code true} when printing of trace is enabled.
         */
        public boolean isPrintTraceEnabled() {
            return printTraceEnabled;
        }

        public void setPrintTraceEnabled(final boolean printTraceEnabled) {
            this.printTraceEnabled = printTraceEnabled;
        }

        /**
         * Make duplicate of this object. This a convenience wrapper for {@link #clone()} method.
         * All fields are primitives, so the shallow clone is a complete copy.
         *
         * @return duplicate of this object.
         */
        public Config duplicate() {
            try {
                return (Config) this.clone();
            } catch (CloneNotSupportedException e) {
                throw new RuntimeException("Error cloning object of class " + getClass().getName() + ".", e);
            }
        }
    }

}
java2s.com  | Contact Us | Privacy Policy
Copyright 2009 - 12 Demo Source and Support. All rights reserved.
All other trademarks are property of their respective owners.