Gathering a sequence of unknown length in dask

I would like to implement a parallel optimization algorithm using dask. My goals are:

  • The main optimization loop should run on a worker.
  • The number of optimization steps is not known in advance, and the optimization step must spawn other tasks.
  • The intermediate results should be auditable, so that I can monitor what is happening.

Example code that satisfies all of the above criteria:

from time import sleep
from distributed import Client, get_client


def f(x):
    sleep(0.5)
    return (x - 1)**3


def derivative(x):
    sleep(1)
    return 3 * (x - 1)**2


def newton_optimization(x, fval, dfdx):
    if abs(fval) < 1e-10:
        return x, None
    x = x - fval / dfdx
    client = get_client()
    fval = client.submit(f, x)
    dfdx = client.submit(derivative, x)
    next_step = client.submit(newton_optimization, x, fval, dfdx)
    return x, next_step


client = Client()
task = client.submit(newton_optimization, 0, -1, 3)  # f(0) = -1, f'(0) = 3

while task is not None:
    i, task = task.result()
    print(i)

client.shutdown()

However, this doesn't feel elegant: for example, to check the current state of the optimization I have to follow the result chain all the way from the start. Is there a better way?
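For reference, stripped of all the dask machinery, what the chain of tasks computes is plain Newton's method with the same `f`, `derivative`, and stopping rule; a minimal sequential version looks like this:

```python
def f(x):
    return (x - 1) ** 3


def derivative(x):
    return 3 * (x - 1) ** 2


def newton(x, tol=1e-10):
    # Same stopping rule as the dask version: stop once |f(x)| is tiny.
    fval = f(x)
    while abs(fval) >= tol:
        x = x - fval / derivative(x)
        fval = f(x)
    return x


print(newton(0))  # converges toward the root x = 1
```

Any dask-based rewrite only changes *where* each `f`/`derivative` evaluation runs and how the iterates are reported, not this iteration itself.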


Solution 1:

Maybe you can use a Queue or some other coordination primitive to keep track of the intermediate results.

Also, based on these Dask docs, using the worker_client() context manager here will be more stable.

So, it'll be something like:

from distributed import Client, worker_client, Queue

def newton_optimization(x, fval, dfdx):
    with worker_client() as client:
        queue = Queue("newton-steps")  # look up the shared queue by name
        if abs(fval) < 1e-10:
            queue.put(None)  # sentinel: tell the driver we are done
            return x
        x = x - fval / dfdx
        fval = client.submit(f, x)
        dfdx = client.submit(derivative, x)
        next_step = client.submit(newton_optimization, x, fval, dfdx)
        queue.put(next_step)
    return x

client = Client()
queue = Queue("newton-steps")
task = client.submit(newton_optimization, 0, -1, 3)  # f(0) = -1, f'(0) = 3

while True:
    future = queue.get()
    if future is None:
        break
    print(future.result())

Note that it's usually best to avoid launching tasks from tasks, because it can cause reliability issues. That said, your workflow seems to genuinely need it, so I'll just share that this feature is well supported in Dask (even though the docs call it experimental -- the docs need to be updated), but do expect some nuances here. :)
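If you only care about auditing the *latest* state rather than collecting every intermediate, a `distributed.Variable` may fit better than a `Queue`. A rough sketch (the loop-in-one-task structure and the `"newton-progress"` name are my own choices, not from your code):

```python
from distributed import Client, Variable, worker_client


def f(x):
    return (x - 1) ** 3


def derivative(x):
    return 3 * (x - 1) ** 2


def newton_optimization(x):
    # Run the whole loop in one long-lived task, spawning the f/derivative
    # evaluations as subtasks, and publish the latest iterate under a
    # well-known name so any client can peek at it.
    with worker_client() as client:
        progress = Variable("newton-progress", client=client)
        fval = client.submit(f, x).result()
        while abs(fval) >= 1e-10:
            dfdx = client.submit(derivative, x).result()
            x = x - fval / dfdx
            progress.set(x)  # overwrite with the newest iterate
            fval = client.submit(f, x).result()
        return x


client = Client(processes=False)
task = client.submit(newton_optimization, 0)
# From the driver (or anywhere), poll the current state at any time with:
# Variable("newton-progress", client=client).get()
result = task.result()
print(result)
client.close()
```

A `Variable` holds only the most recent value, so the driver never has to drain a backlog; the trade-off is that you lose the full history that a `Queue` gives you.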

Thank you for asking this question, it helped start a good discussion here: https://github.com/dask/distributed/issues/5671