info.debatty.spark.knngraphs.builder.Online.java Source code

Java tutorial

Introduction

Here is the source code for info.debatty.spark.knngraphs.builder.Online.java

Source

/*
 * The MIT License
 *
 * Copyright 2016 Thibault Debatty.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
package info.debatty.spark.knngraphs.builder;

import info.debatty.java.graphs.Graph;
import info.debatty.java.graphs.Neighbor;
import info.debatty.java.graphs.NeighborList;
import info.debatty.java.graphs.Node;
import info.debatty.java.graphs.SimilarityInterface;
import info.debatty.spark.knngraphs.ApproximateSearch;
import info.debatty.spark.knngraphs.BalancedKMedoidsPartitioner;
import java.security.InvalidParameterException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import scala.Tuple2;

/**
 *
 * @author Thibault Debatty
 * @param <T>
 */
public class Online<T> {

    /**
     * Key used to store the sequence number of the nodes. Used by the window
     * algorithm to remove the nodes.
     */
    public static final String NODE_SEQUENCE_KEY = "ONLINE_SEQ_KEY";

    private static final int PARTITIONING_ITERATIONS = 5;
    private static final int DEFAULT_SEARCH_SPEEDUP = 4;
    private static final double DEFAULT_MEDOID_UPDATE_RATIO = 0.1;

    // Number of nodes to add before performing a checkpoint
    // (to strip RDD DAG)
    private static final int ITERATIONS_BETWEEN_CHECKPOINTS = 100;

    // the search algorithm also contains a reference to the current graph
    private final ApproximateSearch<T> searcher;
    private final int k;
    private final SimilarityInterface<T> similarity;
    // Number of nodes to add before recomputing centroids
    private double medoid_update_ratio = DEFAULT_MEDOID_UPDATE_RATIO;

    private final long[] partitions_size;
    private final LinkedList<JavaRDD<Graph<T>>> previous_rdds;

    private int search_speedup = DEFAULT_SEARCH_SPEEDUP;
    private long nodes_added_or_removed;
    private long nodes_before_update_medoids;
    private int window_size = 0;

    /**
     *
     * @param k number of edges per node
     * @param similarity similarity to use for computing edges
     * @param sc spark context
     * @param initial_graph initial graph
     * @param partitioning_medoids number of partitions
     */
    public Online(final int k, final SimilarityInterface<T> similarity, final JavaSparkContext sc,
            final JavaPairRDD<Node<T>, NeighborList> initial_graph, final int partitioning_medoids) {

        this.nodes_added_or_removed = 0;
        this.similarity = similarity;
        this.k = k;

        // Use the distributed search algorithm to partition the graph
        this.searcher = new ApproximateSearch<T>(initial_graph, PARTITIONING_ITERATIONS, partitioning_medoids,
                similarity);

        sc.setCheckpointDir("/tmp/checkpoints");

        this.partitions_size = getPartitionsSize(searcher.getGraph());
        this.previous_rdds = new LinkedList<JavaRDD<Graph<T>>>();
        this.nodes_before_update_medoids = computeNodesBeforeUpdate();
    }

    /**
     * Get the size of the window.
     * @return
     */
    public final int getWindowSize() {
        return window_size;
    }

    /**
     * Set the size of the window (for removing a point when a new point is
     * added to graph).
     * @param window_size
     */
    public final void setWindowSize(final int window_size) {
        this.window_size = window_size;
    }

    /**
     * Get the total number of nodes in the online graph.
     * @return total number of nodes in the graph
     */
    public final long getSize() {
        long agg = 0;
        for (long value : partitions_size) {
            agg += value;
        }
        return agg;
    }

    /**
     * Set the speedup of the search step to add a node (default: 4).
     * @param search_speedup speedup
     */
    public final void setSearchSpeedup(final int search_speedup) {
        this.search_speedup = search_speedup;
    }

    /**
     * Set the ratio of nodes to add to the graph before recomputing the
     * medoids (default: 0.1).
     * @param update_ratio [0.0 ...] (0 = disable medoid update)
     */
    public final void setMedoidUpdateRatio(final double update_ratio) {
        if (update_ratio < 0) {
            throw new InvalidParameterException("Update ratio must be >= 0!");
        }
        this.medoid_update_ratio = update_ratio;
        this.nodes_before_update_medoids = computeNodesBeforeUpdate();
    }

    /**
     * Add a node to the graph using fast distributed algorithm.
     * @param node to add to the graph
     */
    public final void fastAdd(final Node<T> node) {

        if (window_size != 0) {
            fastRemove((Integer) node.getAttribute(NODE_SEQUENCE_KEY) - window_size);
        }

        // Find the neighbors of this node
        NeighborList neighborlist = searcher.search(node, k, search_speedup);

        // Assign the node to a partition (most similar medoid, with partition
        // size constraint)
        searcher.assign(node, partitions_size);

        // bookkeeping: update the counts
        partitions_size[(Integer) node.getAttribute(BalancedKMedoidsPartitioner.PARTITION_KEY)]++;

        // update the existing graph edges
        JavaRDD<Graph<T>> updated_graph = searcher.getGraph()
                .map(new UpdateFunction<T>(node, neighborlist, similarity));

        // Add the new <Node, NeighborList> to the distributed graph
        updated_graph = updated_graph.map(new AddNode(node, neighborlist));

        // From now on, use the new graph...
        searcher.setGraph(updated_graph);

        //  truncate RDD DAG (would cause a stack overflow, even with caching)
        if ((nodes_added_or_removed % ITERATIONS_BETWEEN_CHECKPOINTS) == 0) {
            updated_graph.checkpoint();
        }

        // Keep a track of updated RDD to unpersist after two iterations
        previous_rdds.add(updated_graph);
        if (nodes_added_or_removed > 2) {
            previous_rdds.pop().unpersist();
        }

        nodes_added_or_removed++;

        nodes_before_update_medoids--;
        if (nodes_before_update_medoids == 0) {
            searcher.getPartitioner().computeNewMedoids(updated_graph);
            nodes_before_update_medoids = computeNodesBeforeUpdate();
        }
    }

    /**
     * Remove a node using fast approximate algorithm.
     * @param node_to_remove
     */
    public final void fastRemove(final Node<T> node_to_remove) {
        // find the list of nodes to update
        List<Node<T>> nodes_to_update = searcher.getGraph().flatMap(new FindNodesToUpdate(node_to_remove))
                .collect();

        // build the list of candidates
        LinkedList<Node<T>> initial_candidates = new LinkedList<Node<T>>();
        initial_candidates.add(node_to_remove);
        initial_candidates.addAll(nodes_to_update);

        // In spark 1.6.0 the list returned by collect causes an
        // UnsupportedOperationException when you try to remove :(
        LinkedList<Node<T>> candidates = new LinkedList<Node<T>>(
                searcher.getGraph().flatMap(new SearchNeighbors(initial_candidates)).collect());

        // Find the partition corresponding to node_to_remove
        // The balanced kmedoids partitioner wrote this information in the
        // attributes of the node, in the distributed graph
        // hence not necessarily in node_to_remove provided as parameter...
        // This is dirty :(
        int partition_of_node_to_remove = 0;
        for (Node<T> node : candidates) {
            if (node.equals(node_to_remove)) {
                if (null != node.getAttribute(BalancedKMedoidsPartitioner.PARTITION_KEY)) {
                    partition_of_node_to_remove = (Integer) node
                            .getAttribute(BalancedKMedoidsPartitioner.PARTITION_KEY);
                    break;
                }
            }
        }

        while (candidates.contains(node_to_remove)) {
            candidates.remove(node_to_remove);
        }
        // update the graph and remove the node
        JavaRDD<Graph<T>> updated_graph = searcher.getGraph()
                .map(new RemoveUpdate(node_to_remove, nodes_to_update, candidates));
        searcher.setGraph(updated_graph);

        // bookkeeping: update the counts
        partitions_size[partition_of_node_to_remove]--;

        //  truncate RDD DAG (would cause a stack overflow, even with caching)
        if ((nodes_added_or_removed % ITERATIONS_BETWEEN_CHECKPOINTS) == 0) {
            updated_graph.checkpoint();
        }

        // Keep a track of updated RDD to unpersist after two iterations
        previous_rdds.add(updated_graph);
        if (nodes_added_or_removed > 2) {
            previous_rdds.pop().unpersist();
        }

        nodes_added_or_removed++;
    }

    /**
     * Remove a node using the node sequence number instead of the node itself.
     * Used by the sliding window algorithm.
     * @param node_sequence
     */
    private void fastRemove(final long node_sequence) {
        // This is not really efficient :(
        List<Node<T>> nodes = searcher.getGraph().flatMap(new FindNode(node_sequence)).collect();

        if (nodes.isEmpty()) {
            System.out.println("Node sequence not found: " + node_sequence);
            return;
        }

        fastRemove(nodes.get(0));
    }

    /**
     * Get the current graph, represented as a RDD of Graph.
     * @return the current graph
     */
    public final JavaRDD<Graph<T>> getDistributedGraph() {
        return searcher.getGraph();
    }

    /**
     * Get the current graph, converted to a RDD of Tuples (Node, NeighborList).
     * @return
     */
    public final JavaPairRDD<Node<T>, NeighborList> getGraph() {
        return searcher.getGraph().flatMapToPair(new MergeGraphs());
    }

    private long[] getPartitionsSize(final JavaRDD<Graph<T>> graph) {

        List<Long> counts_list = graph.map(new SubgraphSizeFunction()).collect();

        long[] result = new long[counts_list.size()];
        for (int i = 0; i < result.length; i++) {
            result[i] = counts_list.get(i);
        }
        return result;
    }

    /**
     * Compute the number of nodes that can be added before we recompute the
     * medoids (depends on current size and medoid_update_ratio).
     */
    private long computeNodesBeforeUpdate() {
        if (medoid_update_ratio == 0.0) {
            return Long.MAX_VALUE;
        }

        return (long) (getSize() * medoid_update_ratio);

    }
}

/**
 * Used to count the number of nodes in each partition, when we initialize the
 * distributed online graph. Returns the size of each subgraph.
 * @author Thibault Debatty
 * @param <T>
 */
class SubgraphSizeFunction<T> implements Function<Graph<T>, Long> {

    public Long call(final Graph<T> subgraph) {
        return new Long(subgraph.size());
    }
}

/**
 *
 * @author Thibault Debatty
 * @param <T>
 */
class AddNode<T> implements Function<Graph<T>, Graph<T>> {
    private final Node<T> node;
    private final NeighborList neighborlist;

    AddNode(final Node<T> node, final NeighborList neighborlist) {
        this.node = node;
        this.neighborlist = neighborlist;
    }

    public Graph<T> call(final Graph<T> graph) {
        Node<T> one_node = graph.getNodes().iterator().next();

        if (node.getAttribute(BalancedKMedoidsPartitioner.PARTITION_KEY)
                .equals(one_node.getAttribute(BalancedKMedoidsPartitioner.PARTITION_KEY))) {
            graph.put(node, neighborlist);
        }

        return graph;
    }
}

/**
 * Used to find the node corresponding to a given sequence number.
 * @author Thibault Debatty
 * @param <T>
 */
class FindNode<T> implements FlatMapFunction<Graph<T>, Node<T>> {
    private final long sequence_number;

    FindNode(final long sequence_number) {
        this.sequence_number = sequence_number;
    }

    public Iterable<Node<T>> call(final Graph<T> subgraph) {

        LinkedList<Node<T>> result = new LinkedList<Node<T>>();
        for (Node<T> node : subgraph.getNodes()) {
            Integer node_sequence = (Integer) node.getAttribute(Online.NODE_SEQUENCE_KEY);
            //System.out.println(node_sequence);
            if (node_sequence == sequence_number) {
                result.add(node);
                return result;
            }
        }
        return result;
    }
}

/**
 * Update the graph when adding a node.
 * @author Thibault Debatty
 * @param <T>
 */
class UpdateFunction<T> implements Function<Graph<T>, Graph<T>> {

    private static final int UPDATE_DEPTH = 2;

    private final NeighborList neighborlist;
    private final SimilarityInterface<T> similarity;
    private final Node<T> node;

    UpdateFunction(final Node<T> node, final NeighborList neighborlist, final SimilarityInterface<T> similarity) {

        this.node = node;
        this.neighborlist = neighborlist;
        this.similarity = similarity;
    }

    public Graph<T> call(final Graph<T> local_graph) {

        // Nodes to analyze at this iteration
        LinkedList<Node<T>> analyze = new LinkedList<Node<T>>();

        // Nodes to analyze at next iteration
        LinkedList<Node<T>> next_analyze = new LinkedList<Node<T>>();

        // List of already analyzed nodes
        HashMap<Node<T>, Boolean> visited = new HashMap<Node<T>, Boolean>();

        // Fill the list of nodes to analyze
        for (Neighbor neighbor : neighborlist) {
            analyze.add(neighbor.node);
        }

        for (int depth = 0; depth < UPDATE_DEPTH; depth++) {
            while (!analyze.isEmpty()) {
                Node other = analyze.pop();
                NeighborList other_neighborlist = local_graph.get(other);

                // This part of the graph is in another partition :-(
                if (other_neighborlist == null) {
                    continue;
                }

                // Add neighbors to the list of nodes to analyze
                // at next iteration
                for (Neighbor other_neighbor : other_neighborlist) {
                    if (!visited.containsKey(other_neighbor.node)) {
                        next_analyze.add(other_neighbor.node);
                    }
                }

                // Try to add the new node (if sufficiently similar)
                other_neighborlist.add(new Neighbor(node, similarity.similarity(node.value, (T) other.value)));

                visited.put(other, Boolean.TRUE);
            }

            analyze = next_analyze;
            next_analyze = new LinkedList<Node<T>>();
        }

        return local_graph;
    }
}

/**
 * In this Spark implementation, the distributed graph is stored as a RDD of
 * subgraphs, this function collects the subgraphs and returns a single graph,
 * represented as an RDD of tuples (Node, NeighborList).
 * This function is used by the method Online.getGraph().
 * @author Thibault Debatty
 * @param <T>
 */
class MergeGraphs<T> implements PairFlatMapFunction<Graph<T>, Node<T>, NeighborList> {

    public Iterable<Tuple2<Node<T>, NeighborList>> call(final Graph<T> graph) {

        ArrayList<Tuple2<Node<T>, NeighborList>> list = new ArrayList<Tuple2<Node<T>, NeighborList>>(graph.size());

        for (Map.Entry<Node<T>, NeighborList> entry : graph.entrySet()) {
            list.add(new Tuple2<Node<T>, NeighborList>(entry.getKey(), entry.getValue()));
        }

        return list;
    }
}

/**
 * Used by fastRemove to find the nodes that should be updated.
 * @author Thibault Debatty
 * @param <T>
 */
class FindNodesToUpdate<T> implements FlatMapFunction<Graph<T>, Node<T>> {
    private final Node<T> node_to_remove;

    FindNodesToUpdate(final Node<T> node_to_remove) {
        this.node_to_remove = node_to_remove;
    }

    public Iterable<Node<T>> call(final Graph<T> subgraph) {
        LinkedList<Node<T>> nodes_to_update = new LinkedList<Node<T>>();
        for (Node<T> node : subgraph.getNodes()) {
            if (subgraph.get(node).containsNode(node_to_remove)) {
                nodes_to_update.add(node);
            }
        }

        return nodes_to_update;
    }
}

/**
 * Search neighbors from a list of starting points, up to a fixed depth.
 * Used in Online.fastRemove(node) to search the candidates.
 * @author Thibault Debatty
 * @param <T>
 */
class SearchNeighbors<T> implements FlatMapFunction<Graph<T>, Node<T>> {
    private static final int SEARCH_DEPTH = 3;

    private final LinkedList<Node<T>> starting_points;

    SearchNeighbors(final LinkedList<Node<T>> initial_candidates) {
        this.starting_points = initial_candidates;
    }

    public Iterable<Node<T>> call(final Graph<T> subgraph) {
        return subgraph.findNeighbors(starting_points, SEARCH_DEPTH);
    }
}

/**
 * When removing a node, update the subgraphs: remove the node, and assign
 * a new neighbor to nodes that had this node as neighbor.
 * @author Thibault Debatty
 * @param <T>
 */
class RemoveUpdate<T> implements Function<Graph<T>, Graph<T>> {
    private final Node<T> node_to_remove;
    private final List<Node<T>> nodes_to_update;
    private final List<Node<T>> candidates;

    RemoveUpdate(final Node<T> node_to_remove, final List<Node<T>> nodes_to_update,
            final List<Node<T>> candidates) {

        this.node_to_remove = node_to_remove;
        this.nodes_to_update = nodes_to_update;
        this.candidates = candidates;

    }

    public Graph<T> call(final Graph<T> subgraph) {

        // Remove the node (if present in this subgraph)
        subgraph.getHashMap().remove(node_to_remove);

        for (Node<T> node_to_update : nodes_to_update) {
            if (!subgraph.containsKey(node_to_update)) {
                // This node belongs to another subgraph => skip
                continue;
            }

            NeighborList nl_to_update = subgraph.get(node_to_update);

            // Remove the old node
            nl_to_update.removeNode(node_to_remove);

            // Replace the old node by the best candidate
            for (Node<T> candidate : candidates) {
                if (candidate.equals(node_to_update)) {
                    continue;
                }

                double similarity = subgraph.getSimilarity().similarity(node_to_update.value, candidate.value);

                nl_to_update.add(new Neighbor(candidate, similarity));
            }
        }

        return subgraph;
    }
}