Monday, January 22, 2018

Writing a Scheduler for Tasks with DAG Dependencies using CompletableFuture and Async I/O

I had to write a task scheduler in Java recently. The tasks had dependencies on each other, and those dependencies formed a Directed Acyclic Graph (DAG) when viewed as a graph. Each task must forward its results to the nodes that depend on it, otherwise those nodes can't start executing. For example, suppose we have a bunch of tasks with dependencies like so:

B --> A
C --> B
D --> B, C
E --> A, C, D
F --> A, E

Here, B needs A's results to start, C needs B's results to start, D needs both B's and C's results to start, E needs A's, C's, and D's results to start, and finally F needs A's and E's results to start. The whole graph completes execution when F finishes, and the final result of the graph is whatever F produces.
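To make the example concrete, here is a minimal sketch of how this graph could be written down in code. The names DagExample, dependencies(), and outEdges() are made up for illustration and are not part of the scheduler itself. The map goes from each task to the tasks it depends on, and the helper inverts it to answer "who is waiting on me?".

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical illustration: the example graph as a "task -> dependencies" map.
class DagExample {

    // Each key needs the results of every task in its value list before it can start.
    static Map<String, List<String>> dependencies() {
        Map<String, List<String>> deps = new LinkedHashMap<>();
        deps.put("A", Collections.emptyList());
        deps.put("B", Arrays.asList("A"));
        deps.put("C", Arrays.asList("B"));
        deps.put("D", Arrays.asList("B", "C"));
        deps.put("E", Arrays.asList("A", "C", "D"));
        deps.put("F", Arrays.asList("A", "E"));
        return deps;
    }

    // Inverts the map: for each task, which tasks are waiting on its results?
    static Map<String, List<String>> outEdges(Map<String, List<String>> deps) {
        Map<String, List<String>> out = new LinkedHashMap<>();
        for (String task : deps.keySet()) {
            out.put(task, new ArrayList<>());
        }
        for (Map.Entry<String, List<String>> e : deps.entrySet()) {
            for (String dependency : e.getValue()) {
                out.computeIfAbsent(dependency, k -> new ArrayList<>()).add(e.getKey());
            }
        }
        return out;
    }

    public static void main(String[] args) {
        System.out.println(outEdges(dependencies()));
    }
}
```

For this graph, A's results go to B, E, and F, while nothing waits on F.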

So, how do we design a scheduler (and executor) for such a scenario?

Well, here are some of the things I did/used to write the scheduler:

1. Create a graph of nodes where each node represents a task.

2. Make sure the graph is a DAG (Directed Acyclic Graph): i) there are no cycles in the graph, and ii) there is at least one node that does not depend on any other node. We call such nodes root nodes, and execution starts with them. Save the indegrees of all nodes; the root nodes are the nodes with an indegree of 0. We also save all the nodes whose outdegree is 0 and call them leaf nodes. Our task scheduler stops when all leaf nodes have been scheduled.

3. I used Java 8's CompletableFuture async APIs to maximize parallelism and ease of development. If you haven't used them yet, Java 8's CompletableFuture APIs are full of awesomeness! When you have many tasks to run in parallel, they offer a rich feature set (chaining stages, combining futures, and running stages on an executor of your choice) that covers most of what a parallel/concurrent program needs.

4. The key algorithm I used for maintaining the task dependencies is Topological Sorting. Topological sorting is a nifty algorithm that orders a set of nodes with dependencies among them so that every node comes after the nodes it depends on, as long as the nodes do not form a cycle. By following a strategy similar to topological sorting, I was able to start executing tasks at the root node(s) and gradually advance the execution to the dependent nodes that were waiting for the results of the nodes they depended on. To traverse the DAG, I used completion callbacks: when a task finishes, its callback decrements the indegree of every task that depends on it.
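Steps 2 and 4 can be sketched as follows. This is a minimal, single-threaded illustration using Kahn's algorithm, not the scheduler's actual code. DagCheck and its method names are hypothetical, and the sketch assumes the graph is given as a map from each task name to the names of the tasks it depends on, with every dependency also present as a key.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical helper: validates a dependency map and derives indegrees and leaves.
class DagCheck {

    // Indegree of a task = number of tasks it depends on. Roots have indegree 0.
    static Map<String, Integer> indegrees(Map<String, List<String>> deps) {
        Map<String, Integer> indegrees = new LinkedHashMap<>();
        for (Map.Entry<String, List<String>> e : deps.entrySet()) {
            indegrees.put(e.getKey(), e.getValue().size());
        }
        return indegrees;
    }

    // Leaves are the tasks that no other task depends on (outdegree 0).
    static List<String> leaves(Map<String, List<String>> deps) {
        List<String> leaves = new ArrayList<>(deps.keySet());
        for (List<String> dependencies : deps.values()) {
            leaves.removeAll(dependencies);
        }
        return leaves;
    }

    // Kahn's algorithm: returns a topological order, or throws if there is a cycle.
    static List<String> topologicalOrder(Map<String, List<String>> deps) {
        Map<String, Integer> remaining = indegrees(deps);
        Map<String, List<String>> dependents = new HashMap<>();
        for (Map.Entry<String, List<String>> e : deps.entrySet()) {
            for (String dependency : e.getValue()) {
                dependents.computeIfAbsent(dependency, k -> new ArrayList<>()).add(e.getKey());
            }
        }
        Deque<String> ready = new ArrayDeque<>();
        for (Map.Entry<String, Integer> e : remaining.entrySet()) {
            if (e.getValue() == 0) {
                ready.add(e.getKey()); // root node
            }
        }
        List<String> order = new ArrayList<>();
        while (!ready.isEmpty()) {
            String task = ready.remove();
            order.add(task);
            for (String dependent : dependents.getOrDefault(task, Collections.emptyList())) {
                // One more dependency of this dependent is satisfied.
                if (remaining.merge(dependent, -1, Integer::sum) == 0) {
                    ready.add(dependent);
                }
            }
        }
        if (order.size() != deps.size()) {
            throw new IllegalStateException("Not a DAG: some tasks never reached indegree 0");
        }
        return order;
    }
}
```

The scheduler follows the same indegree bookkeeping, except that the "ready" collection is a blocking queue and the decrements happen in callbacks on worker threads.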

Let's look at the algorithm for executing tasks in a topologically sorted graph of tasks:

i) Create a blocking queue of tasks. At the start, add all tasks with an indegree of 0 to the queue. From then on, a task is added to the queue once all its dependencies are met. We use a blocking queue so that the task scheduler waits when there are no tasks available to execute. We quit when all leaf nodes have been scheduled for execution.

ii) Enter an infinite loop for executing tasks. This loop looks somewhat like this:


while (true) {
    // Wait for a task to become available. take() blocks on the task queue;
    // tasks are added to the queue only once their indegree has dropped to 0.
    TaskNode currentTask = taskQueue.take();

    CompletableFuture<T> taskFuture = scheduleTask(currentTask, taskQueue,
                                                  taskResults,
                                                  nodeIndegrees, functionProvider);
    futures.put(currentTask.getName(), taskFuture);

    if (this.leafNodes.contains(currentTask.getName())) {
        leafNodeScheduled++;
        if (leafNodeScheduled == this.leafNodes.size()) {
            // We've scheduled all leaf nodes in the task set. Done!
            break;
        }
    }
}

T is the type of the results that the tasks will produce. The scheduleTask() method actually schedules the tasks asynchronously:

private CompletableFuture<T> scheduleTask(final TaskNode task,
                                          final BlockingQueue<TaskNode> taskQueue,
                                          final Map<String, T> taskResults,
                                          final Map<String, Integer> nodeIndegrees,
                                          final FunctionProvider<T> functionProvider) {
    String taskName = task.getName();

    // Create a dummy future to start the task with desired inputs from all
    // dependencies (tasks that this task depends on)

    List<T> dependenciesRecords = new ArrayList<>();
    for (String dependencyName : task.getDependencies()) {
        dependenciesRecords.addAll(taskResults.getOrDefault(dependencyName, new ArrayList<>()));
    }
    CompletableFuture<T> startCf = CompletableFuture.supplyAsync(() -> dependenciesRecords);
    PostCompletionTask postCompletionTask = new PostCompletionTask(task, taskQueue, taskResults, nodeIndegrees);
    Function<T, T> function = functionProvider.newFunction(taskName);
    CompletableFuture<T> taskFuture = startCf
                                       .thenApplyAsync(function, this.dagExecutor)
                                       .thenApplyAsync(postCompletionTask::whenDone, this.dagExecutor);

    return taskFuture;
} 

We do a bunch of things to make sure the task starts with the results from all the tasks it depends on. Whenever a task completes, we save its results in a Map, so that we can feed those results to the next task. The simplest async job using Java 8's CompletableFuture is:

CompletableFuture<String> future  = CompletableFuture.supplyAsync(() -> "Hello");

Here, we have a completable future that simply returns the string "Hello". In the scheduler, we create a similar async task that returns the results of all the previous jobs that the current job depends on:

CompletableFuture<T> startCf = CompletableFuture.supplyAsync(() -> dependenciesRecords);

Then we start the current job, which is described by some Function. We chain the current job with a post-completion callback so that we can decrement the indegree of each dependent task and start a dependent task when its indegree reaches 0.

PostCompletionTask postCompletionTask = new PostCompletionTask(task, taskQueue, taskResults, nodeIndegrees);
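The three-stage chain (inputs, task, callback) can be tried in isolation. The sketch below is a stand-in under simplifying assumptions: ChainSketch, the summing "work" function, and the "task" map key are made up for illustration, and the stages use concrete Integer types instead of the scheduler's generic T.

```java
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.Function;

// Hypothetical illustration of chaining a task and a post-completion callback.
class ChainSketch {

    static int run(List<Integer> inputs, Map<String, Integer> results) {
        ExecutorService executor = Executors.newFixedThreadPool(2);
        try {
            // Stand-in for the task's real work: add up the inputs.
            Function<List<Integer>, Integer> work =
                    xs -> xs.stream().mapToInt(Integer::intValue).sum();
            // Stand-in for the post-completion callback: record the result, pass it on.
            Function<Integer, Integer> whenDone = result -> {
                results.put("task", result);
                return result;
            };
            CompletableFuture<Integer> taskFuture = CompletableFuture
                    .supplyAsync(() -> inputs)           // stage 1: hand over the inputs
                    .thenApplyAsync(work, executor)      // stage 2: run the task
                    .thenApplyAsync(whenDone, executor); // stage 3: run the callback
            return taskFuture.join(); // block only for the demo
        } finally {
            executor.shutdown();
        }
    }
}
```

Because the callback returns its argument unchanged, the future returned by the chain still carries the task's own result, which is what lets further stages be appended later.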

Here is what the PostCompletionTask looks like:

private class PostCompletionTask {
    private final TaskNode taskNode;
    private final BlockingQueue<TaskNode> taskQueue;
    private final Map<String, T> taskResults;
    private final Map<String, Integer> nodeIndegrees;

    public PostCompletionTask(final TaskNode taskNode,
                              final BlockingQueue<TaskNode> taskQueue,
                              final Map<String, T> taskResults,
                              final Map<String, Integer> nodeIndegrees) {
        this.taskNode = taskNode;
        this.taskQueue = taskQueue;
        this.taskResults = taskResults;
        this.nodeIndegrees = nodeIndegrees;
    }

    public T whenDone(T t) {
        // Store this task's results before releasing any dependents, so that a
        // dependent is never scheduled before the results it needs are in the map.
        this.taskResults.put(this.taskNode.getName(), t);

        Collection<TaskNode> outNodes = edges.get(this.taskNode.getName());

        for (TaskNode outNode : outNodes) {
            synchronized (this.nodeIndegrees) {
                // We need to synchronize on the indegree map because multiple threads
                // can execute this block and may simultaneously see the value going to 0!

                int currCount = this.nodeIndegrees.getOrDefault(outNode.getName(), -1);
                this.nodeIndegrees.put(outNode.getName(), currCount - 1);
                if ((currCount - 1) == 0) {
                    this.taskQueue.offer(outNode);
                }
            }
        }

        // Return the original results of the task whose completion we're handling.
        return t;
    }
}

The main job of the completion callback (whenDone()) is to decrement the indegree of all the nodes that depend on the current node (whose completion callback we are executing). In doing so, if we find a node whose indegree became 0, we add that node to the blocking task queue to be picked up by our scheduler code for scheduling and execution. This is one of the core steps of topological sorting. When the indegree of a node becomes 0, that node is up next for action!

Note that we also save the task's results in a global map so that we can look them up when we schedule a task that depends on this task. We return the results from the callback in case we want to chain this post-completion task to other async tasks.
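Putting the pieces together, here is a small self-contained sketch of the whole idea that can be run against the example graph. It is a simplified stand-in, not the code from work: MiniDagScheduler is a hypothetical name, tasks are identified by plain strings, every task's "work" is just 1 plus the sum of its dependencies' results, and the loop stops after scheduling every task instead of counting leaf nodes. It also assumes the graph is acyclic and that tasks never throw; otherwise take() would block forever.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;

// Hypothetical, simplified scheduler for illustration only.
class MiniDagScheduler {

    // deps maps each task name to the names of the tasks it depends on.
    static Map<String, Integer> run(Map<String, List<String>> deps) {
        Map<String, Integer> indegrees = new HashMap<>();
        Map<String, List<String>> dependents = new HashMap<>();
        BlockingQueue<String> ready = new LinkedBlockingQueue<>();
        for (Map.Entry<String, List<String>> e : deps.entrySet()) {
            indegrees.put(e.getKey(), e.getValue().size());
            if (e.getValue().isEmpty()) {
                ready.add(e.getKey()); // root node
            }
            for (String dependency : e.getValue()) {
                dependents.computeIfAbsent(dependency, k -> new ArrayList<>()).add(e.getKey());
            }
        }

        Map<String, Integer> results = new ConcurrentHashMap<>();
        List<CompletableFuture<Integer>> futures = new ArrayList<>();
        ExecutorService executor = Executors.newFixedThreadPool(4);
        try {
            for (int scheduled = 0; scheduled < deps.size(); scheduled++) {
                String task = ready.take(); // blocks until some task has indegree 0
                List<Integer> inputs = new ArrayList<>();
                for (String dependency : deps.get(task)) {
                    inputs.add(results.get(dependency));
                }
                futures.add(CompletableFuture
                        .supplyAsync(() -> inputs, executor)
                        // Stand-in for the real work: 1 + sum of the dependencies' results.
                        .thenApplyAsync(xs -> 1 + xs.stream().mapToInt(Integer::intValue).sum(), executor)
                        .thenApplyAsync(result -> {
                            // Publish the result first, then release the dependents.
                            results.put(task, result);
                            synchronized (indegrees) {
                                for (String d : dependents.getOrDefault(task, Collections.emptyList())) {
                                    if (indegrees.merge(d, -1, Integer::sum) == 0) {
                                        ready.add(d);
                                    }
                                }
                            }
                            return result;
                        }, executor));
            }
            // Everything is scheduled; wait for the last tasks to finish.
            CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Interrupted while waiting for a task", e);
        } finally {
            executor.shutdown();
        }
        return results;
    }
}
```

With the example graph from the top of the post, A produces 1, B produces 2, C produces 3, D produces 6, E produces 11, and F, the final result, produces 13.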

Overall, it was a great way for me to learn about the CompletableFuture async APIs in Java 8. Applying graph algorithms on top made it even more enjoyable.

Unfortunately, this is code I wrote for work, so I can't share the entire code of the scheduler. My goal here is to highlight the algorithms and tools used in the process without revealing any proprietary code.

I hope I was able to illustrate how one can write a task scheduler using Java's CompletableFuture APIs. Please let me know if you have any questions.

Adios!