A collection of interconnected futures in async Rust with tokio

Problem statement

I'd like to implement a directed acyclic computation graph framework in async Rust, i.e. an interconnected graph of computation "nodes", each of which takes inputs from predecessor nodes and produces outputs for successor nodes. I was planning to implement this by spawning a collection of Futures, one for each computation node, where each future can depend on the outputs of others. However, in trying to implement this with async I've become hopelessly lost in compiler errors.

Minimal example

Here's an attempt at a minimal example of what I want to do. There's a single input list of floats, values, and the task is to build a new list, output, where output[i] = values[i] + output[i - 2] (and output[i] = values[i] for the first two elements). This is what I've tried:

use std::sync;

fn some_complicated_expensive_fn(val1: f32, val2: f32) -> f32 {
    val1 + val2
}

fn example_async(values: &Vec<f32>) -> Vec<f32> {
    let runtime = tokio::runtime::Runtime::new().unwrap();

    let join_handles = sync::Arc::new(sync::Mutex::new(Vec::<tokio::task::JoinHandle<f32>>::new()));
    for (i, value) in values.iter().enumerate() {
        let future = {
            let join_handles = join_handles.clone();
            async move {
                if i < 2 {
                    *value
                } else {
                    let prev_value = join_handles.lock().unwrap()[i - 2].await.unwrap();
                    some_complicated_expensive_fn(*value, prev_value)
                }
            }
        };
        join_handles.lock().unwrap().push(runtime.spawn(future));
    }
    join_handles
        .lock()
        .unwrap()
        .iter_mut()
        .map(|join_handle| runtime.block_on(join_handle).unwrap())
        .collect()
}

#[cfg(test)]
mod tests {
    #[test]
    fn test_example() {
        let values = vec![1., 2., 3., 4., 5., 6.];
        println!("{:?}", super::example_async(&values));
    }
}

I get errors about the MutexGuard (i.e. the locked Mutex) not being Send:

error: future cannot be sent between threads safely
  --> sim/src/compsim/runtime.rs:23:51
   |
23 |         join_handles.lock().unwrap().push(runtime.spawn(future));
   |                                                   ^^^^^ future created by async block is not `Send`
   |
   = help: within `impl Future`, the trait `Send` is not implemented for `std::sync::MutexGuard<'_, Vec<tokio::task::JoinHandle<f32>>>`
note: future is not `Send` as this value is used across an await
  --> sim/src/compsim/runtime.rs:18:38
   |
18 |                     let prev_value = join_handles.lock().unwrap()[i - 2].await.unwrap();
   |                                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ first, await occurs here, with `join_handles.lock().unwrap()` maybe used later...
note: `join_handles.lock().unwrap()` is later dropped here
  --> sim/src/compsim/runtime.rs:18:88
   |
18 |                     let prev_value = join_handles.lock().unwrap()[i - 2].await.unwrap();
   |                                      ----------------------------                      ^
   |                                      |
   |                                      has type `std::sync::MutexGuard<'_, Vec<tokio::task::JoinHandle<f32>>>` which is not `Send`
help: consider moving this into a `let` binding to create a shorter lived borrow
  --> sim/src/compsim/runtime.rs:18:38
   |
18 |                     let prev_value = join_handles.lock().unwrap()[i - 2].await.unwrap();
   |                                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

This makes sense, and I see in the Tokio docs that you can use a tokio::sync::Mutex instead, but a) I'm not sure how, and b) I'm wondering if there's a better overall approach that I'm missing. Help greatly appreciated! Thanks.
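
For reference, here is the same recurrence computed sequentially, with no async at all. It is a minimal sketch meant only as a baseline to check concurrent versions against; example_sequential is a hypothetical name, and some_complicated_expensive_fn is the stand-in from the question.

```rust
/// Stand-in for the expensive per-node computation (same as in the question).
fn some_complicated_expensive_fn(val1: f32, val2: f32) -> f32 {
    val1 + val2
}

/// Hypothetical sequential baseline:
/// output[i] = values[i] for i < 2, otherwise
/// output[i] = some_complicated_expensive_fn(values[i], output[i - 2]).
fn example_sequential(values: &[f32]) -> Vec<f32> {
    let mut output: Vec<f32> = Vec::with_capacity(values.len());
    for (i, &value) in values.iter().enumerate() {
        let result = if i < 2 {
            value
        } else {
            // output[i - 2] is already computed because we go left to right.
            some_complicated_expensive_fn(value, output[i - 2])
        };
        output.push(result);
    }
    output
}
```

For the test input 1. through 6. this returns [1.0, 2.0, 4.0, 6.0, 9.0, 12.0]. Because output[i] depends only on output[i - 2], the even and odd indices form two independent chains, so this particular graph offers at most two-way parallelism.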


Solution 1:

The compiler is complaining that you can't hold the lock on join_handles across an await point. After .await the task might be resumed on a different thread, but a std::sync::Mutex must be unlocked on the same thread that locked it, so its MutexGuard is not Send. You could resolve this by making the lock shorter-lived, e.g. by keeping each handle in an Option and taking it out of the vector before the await. But then you run into the issue that awaiting a JoinHandle consumes it: you receive the value that was returned by the task, and you lose the handle, so you can't put it back into the vector for anyone else. (This is a consequence of Rust values having a single owner: once the handle passes the value to you, it no longer has it and has become useless.)
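
Both points can be seen without pulling in tokio. The sketch below is a std-only analog in which std::thread::JoinHandle stands in for tokio's JoinHandle (its join(self) consumes the handle, much as .await consumes a tokio JoinHandle). make_slots and take_and_join are hypothetical helpers for illustration.

```rust
use std::sync::Mutex;
use std::thread::{self, JoinHandle};

/// Hypothetical helper: spawns three threads (results 0.0, 10.0, 20.0)
/// and stores each handle in its own Option slot behind a Mutex.
fn make_slots() -> Vec<Mutex<Option<JoinHandle<f32>>>> {
    (0..3)
        .map(|i| Mutex::new(Some(thread::spawn(move || i as f32 * 10.0))))
        .collect()
}

/// Hypothetical helper: takes the handle out of slot `i` and waits for it.
/// Returns None if an earlier caller already consumed the handle.
fn take_and_join(slots: &[Mutex<Option<JoinHandle<f32>>>], i: usize) -> Option<f32> {
    // The MutexGuard is a temporary dropped at the end of this statement,
    // so the lock is released before the blocking wait below.
    let handle = slots[i].lock().unwrap().take()?;
    // join(self) consumes the handle: nothing is left to put back in the slot.
    Some(handle.join().unwrap())
}
```

With slots = make_slots(), the first take_and_join(&slots, 1) returns Some(10.0); a second call on the same slot returns None, because the first caller moved the handle, and the result it carried, out of the slot.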

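A JoinHandle basically works like a one-shot channel for the result of the spawned task. Since each result is needed in one more place (by the task that depends on it), you can separately create a vector of one-shot channels that holds a second copy of each result, which the dependent tasks can await. The code also copies each value into its task with values.iter().copied(), because spawned tasks must be 'static and can't borrow from values.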

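use std::sync::{Arc, Mutex};

// some_complicated_expensive_fn is the one from the question.
pub fn example_async(values: &[f32]) -> Vec<f32> {
    let runtime = tokio::runtime::Runtime::new().unwrap();

    // One oneshot channel per node: node i sends on txs[i], node i + 2 receives from rxs[i].
    let (txs, rxs): (Vec<_>, Vec<_>) = (0..values.len())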
        .map(|_| {
            let (tx, rx) = tokio::sync::oneshot::channel();
            (Mutex::new(Some(tx)), Mutex::new(Some(rx)))
        })
        .unzip();
    let txs = Arc::new(txs);
    let rxs = Arc::new(rxs);

    let mut join_handles = vec![];
    for (i, value) in values.iter().copied().enumerate() {
        let txs = Arc::clone(&txs);
        let rxs = Arc::clone(&rxs);
        let future = async move {
            let result = if i < 2 {
                value
            } else {
                let prev_rx = rxs[i - 2].lock().unwrap().take().unwrap();
                let prev_value = prev_rx.await.unwrap();
                some_complicated_expensive_fn(value, prev_value)
            };
            let tx = txs[i].lock().unwrap().take().unwrap();
            tx.send(result).unwrap(); // here you'd use result.clone() for non-Copy result
            result
        };
        join_handles.push(runtime.spawn(future));
    }
    join_handles
        .into_iter()
        .map(|handle| runtime.block_on(handle).unwrap())
        .collect()
}
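
As a cross-check of the channel design, here is a sketch of the same graph using OS threads and std::sync::mpsc channels instead of tokio tasks and oneshot channels. That swap is an assumption made so the snippet runs with the standard library alone, and example_threads is a hypothetical name. Because the whole graph is known before anything is spawned, each node can simply be given ownership of its predecessor's receiver, so no Arc or Mutex is needed.

```rust
use std::sync::mpsc;
use std::thread;

/// Stand-in for the expensive per-node computation (same as in the question).
fn some_complicated_expensive_fn(val1: f32, val2: f32) -> f32 {
    val1 + val2
}

/// Hypothetical thread-per-node analog of the oneshot-channel design.
/// Node i receives output[i - 2] on channel i - 2 and sends its own result on channel i.
fn example_threads(values: &[f32]) -> Vec<f32> {
    let n = values.len();
    // One channel per node: node i owns senders[i]; node i + 2 will own receivers[i].
    let (senders, receivers): (Vec<_>, Vec<_>) =
        (0..n).map(|_| mpsc::channel::<f32>()).unzip();
    let mut receivers: Vec<Option<mpsc::Receiver<f32>>> =
        receivers.into_iter().map(Some).collect();

    let mut handles = Vec::with_capacity(n);
    for (i, (value, tx)) in values.iter().copied().zip(senders).enumerate() {
        // Move the predecessor's receiver into this node; the first two nodes have none.
        let prev_rx = if i < 2 { None } else { receivers[i - 2].take() };
        handles.push(thread::spawn(move || {
            let result = match prev_rx {
                None => value,
                // Blocks until node i - 2 has sent its result.
                Some(rx) => some_complicated_expensive_fn(value, rx.recv().unwrap()),
            };
            // The last two nodes send into channels nobody reads; `let _` ignores the Result.
            let _ = tx.send(result);
            result
        }));
    }
    // Collect results in index order, like the final block_on loop in the answer.
    handles.into_iter().map(|h| h.join().unwrap()).collect()
}
```

example_threads(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) returns [1.0, 2.0, 4.0, 6.0, 9.0, 12.0]. One OS thread per node will not scale to large graphs, which is where tokio tasks are the better fit; the same move-the-receiver-into-the-task idea should carry over there too, since a tokio::sync::oneshot::Receiver can be moved into an async move block.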
