org.caleydo.core.util.impute.KNNImpute.java Source code

Java tutorial

Introduction

Here is the source code for org.caleydo.core.util.impute.KNNImpute.java

Source

/*******************************************************************************
 * Caleydo - Visualization for Molecular Biology - http://caleydo.org
 * Copyright (c) The Caleydo Team. All rights reserved.
 * Licensed under the new BSD license, available at http://caleydo.org/license
 *******************************************************************************/
package org.caleydo.core.util.impute;

import static java.lang.Double.isInfinite;
import static java.lang.Double.isNaN;

import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Collection;
import java.util.List;
import java.util.Random;
import java.util.TreeSet;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveAction;
import java.util.concurrent.RecursiveTask;

import org.apache.commons.lang.StringUtils;
import org.caleydo.core.io.KNNImputeDescription;
import org.caleydo.core.util.collection.Pair;

import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableTable;
import com.google.common.collect.ImmutableTable.Builder;
import com.google.common.collect.Table;
import com.google.common.io.CharStreams;

/**
 * @author Samuel Gratzl
 *
 */
public class KNNImpute extends RecursiveTask<Table<Integer, Integer, Float>> {
    private static final long serialVersionUID = 1605596542323204167L;
    private final KNNImputeDescription desc;
    private final Random r;

    private final int samples;
    private final ImmutableList<Gene> genes;
    private final LoadingCache<Integer, Sample> samples_ = CacheBuilder.newBuilder()
            .build(new CacheLoader<Integer, Sample>() {
                @Override
                public Sample load(Integer key) {
                    return computeSample(key.intValue());
                }
            });

    public KNNImpute(KNNImputeDescription desc, ImmutableList<Gene> genes) {
        this.desc = desc;
        this.genes = genes;
        this.samples = this.genes.get(0).size();
        this.r = new Random(desc.getRng_seed());
    }

    @Override
    protected Table<Integer, Integer, Float> compute() {
        //      data
        //      An expression matrix with genes in the rows, samples in the columns
        //      k
        //      Number of neighbors to be used in the imputation (default=10)
        //      rowmax
        //      The maximum percent missing data allowed in any row (default 50%). For any
        //      rows with more than rowmax% missing are imputed using the overall mean per
        //      sample.
        //      colmax
        //      The maximum percent missing data allowed in any column (default 80%). If
        //      any column has more than colmax% missing data, the program halts and reports
        //      an error.
        //      maxp
        //      The largest block of genes imputed using the knn algorithm inside impute.knn
        //      (default 1500); larger blocks are divided by two-means clustering (recursively)
        //      prior to imputation. If maxp=p, only knn imputation is done.
        //      rng.seed
        //      The seed used for the random number generator (default 362436069) for repro-
        //      ducibility.
        //      impute.knn uses k-nearest neighbors in the space of genes to impute missing expression values.
        //      For each gene with missing values, we find the k nearest neighbors using a Euclidean metric, con-
        //      fined to the columns for which that gene is NOT missing. Each candidate neighbor might be missing
        //      some of the coordinates used to calculate the distance. In this case we average the distance from
        //      the non-missing coordinates. Having found the k nearest neighbors for a gene, we impute the miss-
        //      ing elements by averaging those (non-missing) elements of its neighbors. This can fail if ALL the
        //      neighbors are missing in a particular element. In this case we use the overall column mean for that
        //      block of genes.
        //      Since nearest neighbor imputation costs O(plog(p)) operations per gene, where p is the number
        //      of rows, the computational time can be excessive for large p and a large number of missing rows.
        //      Our strategy is to break blocks with more than maxp genes into two smaller blocks using two-mean
        //      clustering. This is done recursively till all blocks have less than maxp genes. For each block, k-
        //      nearest neighbor imputation is done separately. We have set the default value of maxp to 1500.
        //      Depending on the speed of the machine, and number of samples, this number might be increased.
        //      Making it too small is counter-productive, because the number of two-mean clustering algorithms
        //      will increase.

        if (toomanyNaNsInAColumn())
            throw new IllegalStateException();

        final float rowMax = desc.getRowmax();
        final boolean validRowMax = !Float.isInfinite(rowMax) && !Float.isNaN(rowMax);
        final int max = validRowMax ? Math.round(desc.getRowmax() * samples) : 0;

        //list of possible
        List<Gene> neighborhood;
        int withMissing = 0;
        Collection<ForkJoinTask<Void>> tasks = new ArrayList<>();
        if (!validRowMax) {
            neighborhood = genes;// all genes
        } else {
            neighborhood = new ArrayList<>(genes.size());
            for (Gene gene : genes) {
                if (gene.getNaNs() == 0) {// nothing to impute
                    neighborhood.add(gene);
                } else if (validRowMax && gene.getNaNs() > max) { // too many nans use the sample mean
                    tasks.add(new ImputeSampleMean(gene));
                    //not a good neighbor
                } else {
                    // neighbor but something needs to be done
                    neighborhood.add(gene);
                    withMissing++;
                }
            }
        }

        if (withMissing > 0)
            tasks.add(new ImputeKNNMean(neighborhood));
        invokeAll(tasks);

        ImmutableTable.Builder<Integer, Integer, Float> b = ImmutableTable.builder();
        for (Gene gene : genes) {
            if (gene.isAnySet()) {
                gene.fillImpute(b);
            }
        }
        return b.build();
    }

    /**
     * input by KNN mean and split in two cluster if neighborhood is too large
     *
     * @param neighborhood
     */
    private class ImputeKNNMean extends RecursiveAction {
        private static final long serialVersionUID = 1L;
        private final List<Gene> neighborhood;

        public ImputeKNNMean(List<Gene> neighborhood) {
            this.neighborhood = neighborhood;
        }

        @Override
        protected void compute() {
            if (neighborhood.size() <= desc.getMaxp()) {
                Collection<ImputeKNNMeanImpl> tasks = new ArrayList<>();
                for (Gene gene : neighborhood) {
                    if (gene.getNaNs() > 0)
                        tasks.add(new ImputeKNNMeanImpl(neighborhood, gene));
                }
                invokeAll(tasks);
            } else {
                Pair<List<Gene>, List<Gene>> r = twoMeanClusterSplit(neighborhood);
                invokeAll(new ImputeKNNMean(r.getFirst()), new ImputeKNNMean(r.getSecond()));
            }
        }
    }

    /**
     * split the neighbor hood in two groups based on 2 k-means
     *
     * @param neighborhood
     * @return
     */
    private Pair<List<Gene>, List<Gene>> twoMeanClusterSplit(List<Gene> neighborhood) {
        final int n = neighborhood.size();

        final int maxit = desc.getMaxit();
        final double eps = desc.getEps();

        int a_start = r.nextInt(n);
        int b_start = r.nextInt(n);
        Gene a_center = new Gene(1, -1, Arrays.copyOf(neighborhood.get(a_start).data, samples));
        Gene b_center = new Gene(1, -1, Arrays.copyOf(neighborhood.get(b_start).data, samples));
        float[] a_center_pong = new float[samples];
        Arrays.fill(a_center_pong, Float.NaN);
        float[] b_center_pong = new float[samples];
        Arrays.fill(b_center_pong, Float.NaN);

        float[] tmp;
        BitSet partOf_a = new BitSet(n);

        double d_old = 0;
        for (int i = 0; i < maxit; ++i) {
            int j = 0;
            int changed = 0;
            double d_new = 0;
            for (Gene gene : neighborhood) {
                final double a_distance = distance(a_center, gene);
                final double b_distance = distance(b_center, gene);
                final boolean in_a = a_distance < b_distance;
                if (partOf_a.get(j) != in_a) {
                    changed++;
                    partOf_a.set(j, in_a);
                }
                d_new += in_a ? a_distance : b_distance;
                tmp = in_a ? a_center_pong : b_center_pong;
                // shift new center
                for (int k = 0; k < samples; ++k) {
                    if (!gene.isNaN(k)) {
                        if (Float.isNaN(tmp[k]))
                            tmp[k] = gene.get(k);
                        else
                            tmp[k] += gene.get(k);
                    }
                }
                j++;
            }
            if (changed == 0 || d_new == 0)
                break;
            final double ratio = Math.abs(d_new - d_old) / d_old;
            if (i > 0 && ratio < eps)
                break;
            d_old = d_new;
            int a_n = partOf_a.cardinality();
            int b_n = n - a_n;
            if (a_n == 0 || b_n == 0) {
                // FIXME
            }
            updateCenter(a_center, a_center_pong, a_n);
            updateCenter(b_center, b_center_pong, b_n);
        }

        return split(neighborhood, partOf_a);
    }

    private static Pair<List<Gene>, List<Gene>> split(List<Gene> neighborhood, BitSet partOf_a) {
        final int a_n = partOf_a.cardinality();
        final int b_n = neighborhood.size() - a_n;

        final List<Gene> a = new ArrayList<>(a_n);
        final List<Gene> b = new ArrayList<>(b_n);

        for (int i = 0; i < neighborhood.size(); ++i) {
            if (partOf_a.get(i))
                a.add(neighborhood.get(i));
            else
                b.add(neighborhood.get(i));
        }
        return Pair.make(a, b);
    }

    private void updateCenter(Gene center, float[] values, int n) {
        float[] data = center.data;
        for (int i = 0; i < samples; ++i)
            data[i] = values[i] / n;
        Arrays.fill(values, Float.NaN); // reset
    }

    private class ImputeKNNMeanImpl extends RecursiveAction {
        private static final long serialVersionUID = 1L;
        private final List<Gene> neighborhood;
        private final Gene gene;

        public ImputeKNNMeanImpl(List<Gene> neighborhood, Gene gene) {
            this.neighborhood = neighborhood;
            this.gene = gene;
        }

        @Override
        protected void compute() {
            TreeSet<NeighorGene> neighbors = new TreeSet<>();
            final int k = desc.getK();
            for (Gene other : neighborhood) {
                if (other == gene)
                    continue;
                double distance = distance(gene, other);
                if (isInfinite(distance))
                    continue;
                neighbors.add(new NeighorGene(other, distance));
                if (neighbors.size() > k)
                    neighbors.pollLast();
            }
            for (int sample = 0; sample < samples; ++sample) {
                if (gene.isNaN(sample)) {
                    double value = mean(neighbors, sample);
                    if (Double.isNaN(value)) // This can fail if ALL the neighbors are missing in a particular element.
                        // In
                        // this case we use the overall column mean for that block of genes.
                        value = getSample(sample).getMean();
                    gene.setNextNaN(value);
                }
            }
        }
    }

    private double distance(Gene target, Gene neighbor) {
        double acc = 0;
        int n = 0;
        for (int sample = 0; sample < samples; ++sample) {
            if (target.isNaN(sample) || neighbor.isNaN(sample)) // skip missing
                continue;
            double dx = target.get(sample) - neighbor.get(sample);
            acc += dx * dx;
            n++;
        }
        if (n > 0) {
            return acc / n;// FIXME according to the fortran code, this is not the eucledian distance
            // return Math.sqrt(acc);
        }
        return Double.POSITIVE_INFINITY;
    }

    private Sample computeSample(final int sample) {
        int nans = 0;
        double sum = 0;
        int n = 0;
        for (Gene gene : genes) {
            double v = gene.get(sample);
            if (isNaN(v))
                nans++;
            else {
                sum += v;
                n++;
            }
        }
        return new Sample(sum / n, nans);
    }

    private static double mean(Iterable<NeighorGene> neighbors, int sample) {
        double sum = 0;
        int n = 0;
        for (NeighorGene gene : neighbors) {
            double v = gene.get(sample);
            if (!isNaN(v)) {
                sum += v;
                n++;
            }
        }
        return n == 0 ? Double.NaN : sum / n;
    }

    private class ImputeSampleMean extends RecursiveAction {
        private static final long serialVersionUID = 1L;
        private final Gene gene;

        public ImputeSampleMean(Gene gene) {
            this.gene = gene;
        }

        @Override
        protected void compute() {
            for (int sample = 0; sample < samples; ++sample) {
                if (gene.isNaN(sample))
                    gene.setNextNaN(getSample(sample).getMean());
            }
        }
    }

    /**
     * @param sample
     * @return
     */
    private Sample getSample(int sample) {
        return samples_.getUnchecked(sample);
    }

    /**
     * @return
     */
    private boolean toomanyNaNsInAColumn() {
        float colmax = desc.getColmax();
        if (Float.isInfinite(colmax) || Float.isNaN(colmax))
            return false;
        int max = Math.round(colmax * genes.size());
        for (int i = 0; i < samples; ++i) {
            int nans = getSample(i).getNans();
            if (nans > max)
                return true;
        }
        return false;
    }

    public static final class Gene {
        protected final int gene;
        // as we are working from left to right, we just store the replacements for each NaN
        private final float[] nanReplacements;
        private final float[] data;
        private int nanSetCounter = 0;

        public Gene(int gene, int nans, float[] data) {
            this.gene = gene;
            this.nanReplacements = nans < 0 ? null : new float[nans];
            this.data = data;
        }

        /**
         * @param b
         */
        public void fillImpute(Builder<Integer, Integer, Float> b) {
            if (!isAnySet())
                return;
            assert nanSetCounter == nanReplacements.length;
            int j = 0;
            for (int i = 0; i < data.length; ++i)
                if (Float.isNaN(data[i]))
                    b.put(gene, i, nanReplacements[j++]);
        }

        public boolean isAnySet() {
            return nanSetCounter > 0;
        }

        /**
         * @return
         */
        public int size() {
            return data.length;
        }

        public int getNaNs() {
            return nanReplacements.length;
        }

        public void setNextNaN(double value) {
            nanReplacements[nanSetCounter++] = (float) value;
        }

        public boolean isNaN(int sample) {
            return Float.isNaN(data[sample]);
        }

        /**
         * @param sample
         * @return
         */
        public float get(int sample) {
            return data[sample];
        }

        @Override
        public String toString() {
            StringBuilder builder = new StringBuilder();
            builder.append("Gene [gene=");
            builder.append(gene);
            builder.append(", nanSetCounter=");
            builder.append(nanSetCounter);
            builder.append(", nanReplacements=");
            builder.append(Arrays.toString(nanReplacements));
            builder.append("]");
            return builder.toString();
        }
    }

    private static final class NeighorGene implements Comparable<NeighorGene> {
        private final Gene gene;
        protected final double distance;

        public NeighorGene(Gene gene, double distance) {
            this.gene = gene;
            this.distance = distance;
        }

        public double get(int sample) {
            return gene.get(sample);
        }

        @Override
        public int compareTo(NeighorGene o) {
            int r = Double.compare(distance, o.distance);
            if (r == 0)
                return o.gene.gene - o.gene.gene;
            return r;
        }

        @Override
        public String toString() {
            StringBuilder builder = new StringBuilder();
            builder.append("NeighorGene [gene=");
            builder.append(gene);
            builder.append(", distance=");
            builder.append(distance);
            builder.append("]");
            return builder.toString();
        }

    }

    private static final class Sample {
        private final double mean;
        private final int nans;

        public Sample(double mean, int nans) {
            this.mean = mean;
            this.nans = nans;
        }

        public double getMean() {
            return mean;
        }

        public int getNans() {
            return nans;
        }

        @Override
        public String toString() {
            StringBuilder builder = new StringBuilder();
            builder.append("Sample [mean=");
            builder.append(mean);
            builder.append(", nans=");
            builder.append(nans);
            builder.append("]");
            return builder.toString();
        }

    }

    public static void main(String[] args) throws IOException {
        ImmutableList.Builder<Gene> b = ImmutableList.builder();
        List<String> lines = CharStreams
                .readLines(new InputStreamReader(KNNImpute.class.getResourceAsStream("khan.csv")));
        lines = lines.subList(1, lines.size());
        int j = 0;
        for (String line : lines) {
            String[] l = line.split(";");
            float[] d = new float[l.length];
            int nans = 0;
            for (int i = 0; i < l.length; ++i) {
                if ("NA".equals(l[i])) {
                    nans++;
                    d[i] = Float.NaN;
                } else {
                    d[i] = Float.parseFloat(l[i]);
                }
            }
            b.add(new Gene(j++, nans, d));
        }
        final KNNImputeDescription desc2 = new KNNImputeDescription();
        desc2.setMaxp(100000);
        KNNImpute r = new KNNImpute(desc2, b.build());
        ForkJoinPool p = new ForkJoinPool();
        p.invoke(r);
        try (PrintWriter w = new PrintWriter("khan.imputed.csv")) {
            w.println(StringUtils.repeat("sample", ";", r.samples));
            for (Gene g : r.genes) {
                float[] d = g.data;
                int nan = 0;
                w.print(Float.isNaN(d[0]) ? g.nanReplacements[nan++] : d[0]);
                for (int i = 1; i < d.length; ++i)
                    w.append(';').append(String.valueOf(Float.isNaN(d[i]) ? g.nanReplacements[nan++] : d[i]));
                w.println();
            }
        }
    }
}