Consume all messages in a System.Threading.Channels.Channel

Suppose I have a many-producers, single-consumer unbounded Channel, with a consumer loop like this:

await foreach (var message in channel.Reader.ReadAllAsync(cts.Token))
{
    await consume(message);
}
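
The channel itself is an unbounded one, created along these lines (the exact options may differ; the point is a single reader and multiple writers):

var channel = Channel.CreateUnbounded<Message>(new UnboundedChannelOptions
{
    SingleReader = true,  // only the one consumer loop reads
    SingleWriter = false  // many producers write concurrently
});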

The problem is that the consume function does some IO, and potentially some network access too, so before one message is consumed many more may be produced. And since the IO resources can't be accessed concurrently, I can't have multiple consumers, nor can I throw the consume function into a fire-and-forget Task.

The consume function is such that it can easily be modified to take multiple messages and handle them all in a batch. So my question is whether there's a way to make the consumer take all the messages currently in the channel's queue whenever it accesses it, something like this:

while (true)
{
    Message[] messages = await channel.Reader.TakeAll();
    await consumeAll(messages);
}

Edit: One option I can come up with is:

List<Message> messages = new();
await foreach (var message in channel.Reader.ReadAllAsync(cts.Token))
{
    await consume(message);
    while (channel.Reader.TryRead(out var msg))
        messages.Add(msg);
    if (messages.Count > 0)
    {
        await consumeAll(messages);
        messages.Clear();
    }
}

But I feel like there should be a better way to do this.


After reading Stephen Toub's primer on channels, I had a stab at writing an extension method that should do what you need (it's been a while since I did any C#, so this was fun).

using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Channels;

public static class ChannelReaderEx
{
    public static async IAsyncEnumerable<IEnumerable<T>> ReadBatchesAsync<T>(
        this ChannelReader<T> reader, 
        [EnumeratorCancellation] CancellationToken cancellationToken = default
    )
    {
        while (await reader.WaitToReadAsync(cancellationToken).ConfigureAwait(false))
        {
            yield return reader.Flush().ToList();
        }
    }

    public static IEnumerable<T> Flush<T>(this ChannelReader<T> reader)
    {
        while (reader.TryRead(out T item))
        {
            yield return item;
        }
    }
}

which can be used like this:

await foreach (var batch in channel.Reader.ReadBatchesAsync())
{
    await ConsumeBatch(batch);
}
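
Since the consumer in the question passes cts.Token to ReadAllAsync, the same cancellation token can be passed to ReadBatchesAsync (cts being the CancellationTokenSource from the question):

await foreach (var batch in channel.Reader.ReadBatchesAsync(cts.Token))
{
    await ConsumeBatch(batch);
}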

Solving this problem at the ChannelReader<T> level, as in spender's excellent answer, is practical and sufficient, but solving it at the IAsyncEnumerable<T> level might be a solution with a broader range of applications. Below is an extension method BufferImmediate for asynchronous sequences that yields non-empty buffers containing all the elements that are immediately available at the time the sequence is pulled:

// Requires: using System; using System.Collections.Generic;
// using System.Runtime.CompilerServices; using System.Runtime.ExceptionServices;
// using System.Threading; using System.Threading.Tasks;
/// <summary>
/// Splits the elements of a sequence into chunks that contain all the elements
/// that are immediately available.
/// </summary>
public static IAsyncEnumerable<IList<TSource>> BufferImmediate<TSource>(
    this IAsyncEnumerable<TSource> source, int maxSize = Int32.MaxValue)
{
    if (source == null) throw new ArgumentNullException(nameof(source));
    if (maxSize < 1) throw new ArgumentOutOfRangeException(nameof(maxSize));
    return Implementation();

    async IAsyncEnumerable<IList<TSource>> Implementation(
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        ValueTask<bool> moveNext = default;
        var enumerator = source.GetAsyncEnumerator(cancellationToken);
        try
        {
            moveNext = enumerator.MoveNextAsync();
            var buffer = new List<TSource>();
            ExceptionDispatchInfo error = null;
            while (true)
            {
                if ((!moveNext.IsCompleted && buffer.Count > 0)
                    || buffer.Count >= maxSize)
                {
                    yield return buffer.ToArray();
                    buffer.Clear();
                }
                else
                {
                    // Await a copy, to prevent a second await on finally.
                    var moveNextCopy = moveNext;
                    moveNext = default;
                    bool moved;
                    try { moved = await moveNextCopy.ConfigureAwait(false); }
                    catch (Exception ex)
                    {
                        error = ExceptionDispatchInfo.Capture(ex); break;
                    }
                    if (!moved) break;
                    buffer.Add(enumerator.Current);
                    try { moveNext = enumerator.MoveNextAsync(); }
                    catch (Exception ex)
                    {
                        error = ExceptionDispatchInfo.Capture(ex); break;
                    }
                }
            }
            if (buffer.Count > 0) yield return buffer.ToArray();
            error?.Throw();
        }
        finally
        {
            // The finally runs when an enumerator created by this method is disposed.
            // Prevent fire-and-forget, otherwise the DisposeAsync() might throw.
            // Swallow MoveNextAsync errors, but propagate DisposeAsync errors.
            try { await moveNext.ConfigureAwait(false); } catch { }
            await enumerator.DisposeAsync().ConfigureAwait(false);
        }
    }
}

Usage example:

await foreach (var batch in channel.Reader.ReadAllAsync().BufferImmediate())
{
    await ConsumeBatch(batch);
}
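
The optional maxSize parameter caps the size of each emitted batch, in case consuming an arbitrarily large backlog in a single call is undesirable (the value 100 below is just an arbitrary example):

await foreach (var batch in channel.Reader.ReadAllAsync().BufferImmediate(maxSize: 100))
{
    await ConsumeBatch(batch); // each batch contains at most 100 messages
}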

The above implementation is non-destructive, meaning that no elements that have been consumed from the source sequence can be lost. If the source sequence fails or the enumeration is canceled, any buffered elements will be emitted before the error is propagated.
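
As a small hypothetical demonstration of this behavior, here is a source (not part of the original question) that yields two elements and then fails. The batch [1, 2] is emitted before the exception surfaces:

// Hypothetical demo source.
static async IAsyncEnumerable<int> FailingSource()
{
    yield return 1;
    yield return 2;
    await Task.Yield();
    throw new InvalidOperationException("Oops!");
}

try
{
    await foreach (var batch in FailingSource().BufferImmediate())
        Console.WriteLine($"Received batch: [{String.Join(", ", batch)}]");
}
catch (InvalidOperationException ex)
{
    Console.WriteLine($"Failed: {ex.Message}"); // printed after the [1, 2] batch
}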