Java Thread Tutorial - Java Fork/Join Framework

The fork/join framework speeds up computations by spreading work across the multiple processors or cores of a machine.

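It is designed for problems that can be split recursively into smaller, independent subtasks whose results are then combined, an approach known as divide and conquer.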

The fork/join framework creates a pool of threads to execute the subtasks.

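The pool uses a technique called work stealing: each worker thread keeps its own queue of tasks, and a thread that runs out of work, or that is waiting for a subtask to finish, executes pending subtasks taken from the queues of other threads.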
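As a minimal sketch using only the standard java.util.concurrent API, the program below creates pools and prints their target parallelism. The no-argument ForkJoinPool constructor sizes the pool to the number of available processors, and ForkJoinPool.commonPool() (Java 8 and later) returns a shared, JVM-wide pool. The class name PoolInfo is illustrative, and the printed numbers depend on the machine.

```java
import java.util.concurrent.ForkJoinPool;

class PoolInfo {
  public static void main(String[] args) {
    // The no-argument constructor uses one worker per available processor
    ForkJoinPool pool = new ForkJoinPool();
    System.out.println("Processors:  " + Runtime.getRuntime().availableProcessors());
    System.out.println("Parallelism: " + pool.getParallelism());

    // A pool with an explicit target parallelism of 2
    ForkJoinPool smallPool = new ForkJoinPool(2);
    System.out.println("Small pool:  " + smallPool.getParallelism());

    // Shared pool (Java 8+); calling shutdown() on it has no effect
    System.out.println("Common pool: " + ForkJoinPool.commonPool().getParallelism());

    pool.shutdown();
    smallPool.shutdown();
  }
}
```

Worker threads are started lazily as tasks arrive, so creating a pool that is never used costs little.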

The following four classes in the java.util.concurrent package are central to using the fork/join framework:

  • ForkJoinPool
  • ForkJoinTask
  • RecursiveAction
  • RecursiveTask

An instance of the ForkJoinPool class represents a thread pool. An instance of the ForkJoinTask class represents a task.
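The sketch below shows the two objects working together, assuming a hypothetical task class Square that simply squares a number. invoke() hands a task to the pool and blocks until its result is available, while submit() returns the task immediately so the caller can join() it later.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;

// Hypothetical task that returns the square of n
class Square extends RecursiveTask<Integer> {
  private final int n;

  Square(int n) {
    this.n = n;
  }

  @Override
  protected Integer compute() {
    return n * n;
  }
}

class RunTasks {
  public static void main(String[] args) {
    ForkJoinPool pool = new ForkJoinPool();

    // invoke() runs the task and blocks until its result is ready
    int a = pool.invoke(new Square(4));

    // submit() returns the task at once; join() waits for its result
    ForkJoinTask<Integer> pending = pool.submit(new Square(5));
    int b = pending.join();

    System.out.println(a + " " + b); // prints "16 25"
    pool.shutdown();
  }
}
```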

The ForkJoinTask class is abstract, and programs rarely extend it directly. Instead, they extend one of its two abstract subclasses, RecursiveAction or RecursiveTask.

Java 8 added another abstract subclass of ForkJoinTask, called CountedCompleter.

The framework supports two types of tasks: a task that does not yield a result and a task that yields a result.

An instance of the RecursiveAction class represents a task that does not yield a result. An instance of the RecursiveTask class represents a task that yields a result.
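A minimal RecursiveAction sketch, assuming a hypothetical DoubleAll task that doubles every element of an int array in place. Because the task returns nothing, its compute() method is declared void, and the work shows up as side effects on the shared array. The THRESHOLD of 4 is an illustrative value chosen so that the small demo array is actually split; it is not a tuned setting. The static helper invokeAll(), inherited from ForkJoinTask, runs both halves and returns when both are done.

```java
import java.util.Arrays;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;

// Hypothetical task with no result: doubles every element of data[from, to) in place
class DoubleAll extends RecursiveAction {
  private static final int THRESHOLD = 4; // illustrative cutoff, not a tuned value
  private final int[] data;
  private final int from;
  private final int to;

  DoubleAll(int[] data, int from, int to) {
    this.data = data;
    this.from = from;
    this.to = to;
  }

  @Override
  protected void compute() {
    if (to - from <= THRESHOLD) {
      // Small range: do the work directly
      for (int i = from; i < to; i++) {
        data[i] *= 2;
      }
      return;
    }
    int mid = (from + to) >>> 1;
    // Split into two halves and wait for both to finish
    invokeAll(new DoubleAll(data, from, mid), new DoubleAll(data, mid, to));
  }
}

class DoubleAllDemo {
  public static void main(String[] args) {
    int[] data = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
    ForkJoinPool.commonPool().invoke(new DoubleAll(data, 0, data.length));
    System.out.println(Arrays.toString(data)); // prints [2, 4, 6, 8, 10, 12, 14, 16, 18, 20]
  }
}
```

The two halves touch disjoint parts of the array, so no extra synchronization is needed.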

A CountedCompleter task may or may not yield a result.
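CountedCompleter takes a different approach from join(): each task keeps a pending count of unfinished children, and a task completes only when that count has dropped to zero, at which point completion propagates to its parent. The sketch below is modelled on the ForEach pattern described in the CountedCompleter API documentation; the class name ForEachIndex and its IntConsumer parameter are hypothetical. This version yields no result (CountedCompleter&lt;Void&gt;) and accumulates into a thread-safe AtomicLong instead.

```java
import java.util.concurrent.CountedCompleter;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.IntConsumer;

// Hypothetical completer with no result: applies op to every index in [lo, hi)
class ForEachIndex extends CountedCompleter<Void> {
  private final IntConsumer op;
  private final int lo;
  private final int hi;

  ForEachIndex(CountedCompleter<?> parent, IntConsumer op, int lo, int hi) {
    super(parent); // the parent is notified when this task completes
    this.op = op;
    this.lo = lo;
    this.hi = hi;
  }

  @Override
  public void compute() {
    int l = lo;
    int h = hi;
    // Repeatedly split off the right half as a forked child
    while (h - l >= 2) {
      int mid = (l + h) >>> 1;
      addToPendingCount(1); // one more child must finish before this task can complete
      new ForEachIndex(this, op, mid, h).fork();
      h = mid;
    }
    if (h > l) {
      op.accept(l);
    }
    // Decrement the pending count, or complete this task and move up to the parent if it is zero
    propagateCompletion();
  }
}

class CountedCompleterDemo {
  public static void main(String[] args) {
    AtomicLong total = new AtomicLong();
    // invoke() returns once the root task has completed, which requires all children to finish
    ForkJoinPool.commonPool().invoke(new ForEachIndex(null, i -> total.addAndGet(i), 1, 101));
    System.out.println(total.get()); // prints 5050
  }
}
```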

Both RecursiveAction and RecursiveTask declare an abstract compute() method. In RecursiveAction it returns void; in RecursiveTask<V> it returns a value of type V.

To define a task, extend one of these classes and implement compute(), typically by solving small inputs directly and splitting larger ones into subtasks.

Example

Two methods of the ForkJoinTask class are central to running a task:

  • fork() launches a subtask for asynchronous execution.
  • join() waits for a task to complete and returns its result.

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;
public class Main {
  public static void main(String[] args) {
    ForkJoinPool pool = new ForkJoinPool();
    IntSum task = new IntSum(3);
    long sum = pool.invoke(task);
    System.out.println("Sum is " + sum);
  }
}

// A task that yields a Long: the sum of count values
class IntSum extends RecursiveTask<Long> {
  private final int count;

  public IntSum(int count) {
    this.count = count;
  }

  @Override
  protected Long compute() {
    long result = 0;

    if (this.count <= 0) {
      return 0L; // nothing to sum
    } else if (this.count == 1) {
      // Base case: a single value is computed directly
      return (long) this.getRandomInteger();
    }
    // Otherwise, launch one single-value subtask per value
    List<RecursiveTask<Long>> forks = new ArrayList<>();
    for (int i = 0; i < this.count; i++) {
      IntSum subTask = new IntSum(1);
      subTask.fork(); // launch the subtask asynchronously
      forks.add(subTask);
    }
    // Wait for every subtask to finish and combine the results
    for (RecursiveTask<Long> subTask : forks) {
      result = result + subTask.join();
    }
    return result;
  }

  // Stands in for a random number; it always returns 2,
  // so the program's output is predictable
  public int getRandomInteger() {
    return 2;
  }
}

The code above produces the following output:

Sum is 6
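The example above forks one single-value subtask per item, which keeps the demonstration simple but creates many tiny tasks. A more common shape splits the input range in half until it falls below a threshold. The sketch below assumes a hypothetical ArraySum task over a long[] array, with an illustrative THRESHOLD of 1,000 that has not been tuned. It forks the left half, computes the right half in the current thread, and then joins the left half.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

// Hypothetical task: sums numbers[from, to) by splitting the range in half
class ArraySum extends RecursiveTask<Long> {
  private static final int THRESHOLD = 1_000; // illustrative cutoff, not a tuned value
  private final long[] numbers;
  private final int from;
  private final int to;

  ArraySum(long[] numbers, int from, int to) {
    this.numbers = numbers;
    this.from = from;
    this.to = to;
  }

  @Override
  protected Long compute() {
    if (to - from <= THRESHOLD) {
      // Small enough: sum sequentially
      long sum = 0;
      for (int i = from; i < to; i++) {
        sum += numbers[i];
      }
      return sum;
    }
    int mid = (from + to) >>> 1;
    ArraySum left = new ArraySum(numbers, from, mid);
    ArraySum right = new ArraySum(numbers, mid, to);
    left.fork();                     // run the left half asynchronously
    long rightSum = right.compute(); // do the right half in this thread
    long leftSum = left.join();      // wait for the left half's result
    return leftSum + rightSum;
  }
}

class ArraySumDemo {
  public static void main(String[] args) {
    long[] numbers = new long[10_000];
    for (int i = 0; i < numbers.length; i++) {
      numbers[i] = i + 1; // 1, 2, ..., 10000
    }
    long total = ForkJoinPool.commonPool().invoke(new ArraySum(numbers, 0, numbers.length));
    System.out.println("Sum is " + total); // prints "Sum is 50005000"
  }
}
```

Computing one half directly instead of forking both keeps the current thread busy rather than leaving it idle while it waits in join().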