How can I prevent synchronous continuations on a Task?
I have some library (socket networking) code that provides a Task-based API for pending responses to requests, based on TaskCompletionSource<T>. However, there's an annoyance in the TPL in that it seems to be impossible to prevent synchronous continuations. What I would like to be able to do is either:
- tell a TaskCompletionSource<T> that it should not allow callers to attach with TaskContinuationOptions.ExecuteSynchronously, or
- set the result (SetResult/TrySetResult) in a way that specifies that TaskContinuationOptions.ExecuteSynchronously should be ignored, using the pool instead
Specifically, the issue I have is that the incoming data is being processed by a dedicated reader, and if a caller can attach with TaskContinuationOptions.ExecuteSynchronously they can stall the reader (which affects more than just them). Previously, I have worked around this with some hackery that detects whether any continuations are present and, if they are, pushes the completion onto the ThreadPool; however, this has a significant impact if the caller has saturated their work queue, as the completion will not get processed in a timely fashion. If they are using Task.Wait() (or similar), they will then essentially deadlock themselves. Likewise, this is why the reader is on a dedicated thread rather than using workers.
So, before I try and nag the TPL team: am I missing an option?
Key points:
- I don't want external callers to be able to hijack my thread
- I can't use the ThreadPool as an implementation, as it needs to work when the pool is saturated
The example below produces the following output (ordering may vary based on timing):
Continuation on: Main thread
Press [return]
Continuation on: Thread pool
The problem is the fact that a random caller managed to get a continuation on "Main thread". In the real code, this would be interrupting the primary reader; bad things!
Code:
using System;
using System.Threading;
using System.Threading.Tasks;
static class Program
{
    static void Identify()
    {
        var thread = Thread.CurrentThread;
        string name = thread.IsThreadPoolThread
            ? "Thread pool" : thread.Name;
        if (string.IsNullOrEmpty(name))
            name = "#" + thread.ManagedThreadId;
        Console.WriteLine("Continuation on: " + name);
    }
    static void Main()
    {
        Thread.CurrentThread.Name = "Main thread";
        var source = new TaskCompletionSource<int>();
        var task = source.Task;
        task.ContinueWith(delegate {
            Identify();
        });
        task.ContinueWith(delegate {
            Identify();
        }, TaskContinuationOptions.ExecuteSynchronously);
        source.TrySetResult(123);
        Console.WriteLine("Press [return]");
        Console.ReadLine();
    }
}
Solution 1:
New in .NET 4.6:
.NET 4.6 contains a new TaskCreationOptions value: RunContinuationsAsynchronously.
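As a minimal sketch (assuming .NET Framework 4.6+ or .NET Core; the RunContinuationsAsyncDemo class and its method are hypothetical helpers, not part of the answer), here is the question's scenario with the option applied. A caller still asks for ExecuteSynchronously, but the continuation is queued rather than run on the thread that calls TrySetResult.

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

// Hypothetical helper: re-runs the question's scenario with
// TaskCreationOptions.RunContinuationsAsynchronously set on the TCS.
public static class RunContinuationsAsyncDemo
{
    // Returns true when a caller's ExecuteSynchronously continuation did NOT
    // run on the thread that called TrySetResult.
    public static bool SubscriberRanOffSetterThread()
    {
        // TaskCreationOptions, not TaskContinuationOptions: passing the latter
        // would silently bind to the TaskCompletionSource(object state) constructor.
        var source = new TaskCompletionSource<int>(
            TaskCreationOptions.RunContinuationsAsynchronously);

        int setterThreadId = Environment.CurrentManagedThreadId;
        int continuationThreadId = -1;
        var ran = new ManualResetEventSlim(false);

        // A caller asking to run synchronously on the completing thread.
        Task continuation = source.Task.ContinueWith(t =>
        {
            continuationThreadId = Environment.CurrentManagedThreadId;
            ran.Set();
        }, TaskContinuationOptions.ExecuteSynchronously);

        source.TrySetResult(123); // the continuation is queued, not run inline

        // Block on the event rather than Task.Wait, because Task.Wait can inline
        // a queued task onto the waiting thread and blur the result.
        ran.Wait();
        continuation.Wait(); // the delegate is already running, so no inlining here
        ran.Dispose();

        return continuationThreadId != setterThreadId;
    }
}
```

The only change from the question's code is the constructor argument; the TrySetResult call site stays the same.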
Since you're willing to use Reflection to access private fields...
You can mark the TCS's Task with the TASK_STATE_THREAD_WAS_ABORTED flag, which causes none of its continuations to be inlined.
const int TASK_STATE_THREAD_WAS_ABORTED = 134217728;
var stateField = typeof(Task).GetField("m_stateFlags", BindingFlags.NonPublic | BindingFlags.Instance);
stateField.SetValue(task, (int) stateField.GetValue(task) | TASK_STATE_THREAD_WAS_ABORTED);
Edit:
Instead of using Reflection.Emit, I suggest you use expressions. This is much more readable and has the advantage of being PCL-compatible:
var taskParameter = Expression.Parameter(typeof (Task));
const string stateFlagsFieldName = "m_stateFlags";
var setter =
Expression.Lambda<Action<Task>>(
Expression.Assign(Expression.Field(taskParameter, stateFlagsFieldName),
Expression.Or(Expression.Field(taskParameter, stateFlagsFieldName),
Expression.Constant(TASK_STATE_THREAD_WAS_ABORTED))), taskParameter).Compile();
Without using Reflection:
If anyone's interested, I've figured out a way to do this without Reflection, but it is a bit "dirty" as well, and of course carries a non-negligible perf penalty:
try
{
    Thread.CurrentThread.Abort();
}
catch (ThreadAbortException)
{
    source.TrySetResult(123);
    Thread.ResetAbort();
}
Solution 2:
I don't think there's anything in TPL which would provide explicit API control over TaskCompletionSource.SetResult continuations. I decided to keep my initial answer for controlling this behavior for async/await scenarios.
Here is another solution which imposes asynchrony upon ContinueWith, if the tcs.SetResult-triggered continuation takes place on the same thread that SetResult was called on:
using System;
using System.Collections.Concurrent;
using System.Threading;
using System.Threading.Tasks;
public static class TaskExt
{
    static readonly ConcurrentDictionary<Task, Thread> s_tcsTasks =
        new ConcurrentDictionary<Task, Thread>();
    // SetResultAsync
    static public void SetResultAsync<TResult>(
        this TaskCompletionSource<TResult> @this,
        TResult result)
    {
        s_tcsTasks.TryAdd(@this.Task, Thread.CurrentThread);
        try
        {
            @this.SetResult(result);
        }
        finally
        {
            Thread thread;
            s_tcsTasks.TryRemove(@this.Task, out thread);
        }
    }
    // ContinueWithAsync, TODO: more overloads
    static public Task ContinueWithAsync<TResult>(
        this Task<TResult> @this,
        Action<Task<TResult>> action,
        TaskContinuationOptions continuationOptions = TaskContinuationOptions.None)
    {
        return @this.ContinueWith((Func<Task<TResult>, Task>)(t =>
        {
            Thread thread = null;
            s_tcsTasks.TryGetValue(t, out thread);
            if (Thread.CurrentThread == thread)
            {
                // same thread which called SetResultAsync, avoid potential deadlocks
                // using thread pool
                return Task.Run(() => action(t));
                // not using thread pool (TaskCreationOptions.LongRunning creates a dedicated thread)
                // return Task.Factory.StartNew(() => action(t), TaskCreationOptions.LongRunning);
            }
            else
            {
                // continue on the same thread
                var task = new Task(() => action(t));
                task.RunSynchronously();
                // return the completed task itself so that a fault in action
                // propagates through Unwrap()
                return task;
            }
        }), continuationOptions).Unwrap();
    }
}
Updated to address the comment:
I don't control the caller - I can't get them to use a specific continue-with variant: if I could, the problem would not exist in the first place
I wasn't aware you don't control the caller. Nevertheless, if you don't control it, you're probably not passing the TaskCompletionSource object directly to the caller either. Logically, you'd be passing the token part of it, i.e. tcs.Task. In that case, the solution might be even easier: add another extension method to the above:
// ImposeAsync, TODO: more overloads
static public Task<TResult> ImposeAsync<TResult>(this Task<TResult> @this)
{
    return @this.ContinueWith(new Func<Task<TResult>, Task<TResult>>(antecedent =>
    {
        Thread thread = null;
        s_tcsTasks.TryGetValue(antecedent, out thread);
        if (Thread.CurrentThread == thread)
        {
            // continue on a pool thread
            return antecedent.ContinueWith(t => t,
                TaskContinuationOptions.None).Unwrap();
        }
        else
        {
            return antecedent;
        }
    }), TaskContinuationOptions.ExecuteSynchronously).Unwrap();
}
Use:
// library code
var source = new TaskCompletionSource<int>();
var task = source.Task.ImposeAsync();
// ...
// client code
task.ContinueWith(delegate
{
    Identify();
}, TaskContinuationOptions.ExecuteSynchronously);
// ...
// library code
source.SetResultAsync(123);
This actually works for both await and ContinueWith and is free of reflection hacks.
Solution 3:
What about instead of doing
var task = source.Task;
you do this instead
var task = source.Task.ContinueWith<Int32>( x => x.Result );
Thus you are always adding one continuation which will be executed asynchronously and then it doesn't matter if the subscribers want a continuation in the same context. It's sort of currying the task, isn't it?
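A sketch of this wrapping approach, using a hypothetical WrappedTaskDemo helper that mirrors the question's scenario. Passing TaskScheduler.Default explicitly is an addition of mine: the one-liner above implicitly uses TaskScheduler.Current, which could be a custom scheduler if the library code itself runs under one.

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

// Hypothetical helper: hand callers a continuation of source.Task
// rather than source.Task itself.
public static class WrappedTaskDemo
{
    public static Task<int> Expose(TaskCompletionSource<int> source)
    {
        // No ExecuteSynchronously here, so when source.Task completes this wrapper
        // is queued to the thread pool instead of being run inline by TrySetResult.
        return source.Task.ContinueWith(
            x => x.Result,
            CancellationToken.None,
            TaskContinuationOptions.None,
            TaskScheduler.Default);
    }

    // Returns true when a caller's ExecuteSynchronously continuation, attached
    // to the wrapper, did NOT run on the thread that called TrySetResult.
    public static bool SubscriberRanOffSetterThread()
    {
        var source = new TaskCompletionSource<int>();
        Task<int> exposed = Expose(source);

        int setterThreadId = Environment.CurrentManagedThreadId;
        int continuationThreadId = -1;
        var ran = new ManualResetEventSlim(false);

        Task continuation = exposed.ContinueWith(t =>
        {
            continuationThreadId = Environment.CurrentManagedThreadId;
            ran.Set();
        }, TaskContinuationOptions.ExecuteSynchronously);

        source.TrySetResult(123);

        // Wait on the event rather than the task: Task.Wait can inline
        // queued work onto the waiting thread.
        ran.Wait();
        continuation.Wait();
        ran.Dispose();

        return continuationThreadId != setterThreadId;
    }
}
```

If source faults or is cancelled, x.Result throws an AggregateException, so the wrapper ends up faulted (a cancellation also surfaces as a fault) with the original exception nested one level deeper. The wrapper itself also runs on the thread pool, which the question's key points rule out when the pool is saturated.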
Solution 4:
The simulated-abort approach looked really good, but led to the TPL hijacking threads in some scenarios.
I then had an implementation similar to checking the continuation object, except that it checked for any continuation at all (there are actually too many scenarios for the given code to work well); but that meant that even things like Task.Wait resulted in a thread-pool lookup.
Ultimately, after inspecting lots and lots of IL, the only safe and useful scenario turned out to be the SetOnInvokeMres scenario (manual-reset-event-slim continuation). There are lots of other scenarios:
- some aren't safe, and lead to thread hijacking
- the rest aren't useful, as they ultimately lead to the thread-pool
So in the end, I opted to check for a non-null continuation-object: if it is null, fine (no continuations); if it is non-null, special-case check for SetOnInvokeMres. If it is that, fine (safe to invoke); otherwise, let the thread-pool perform the TrySetComplete, without telling the task to do anything special like spoofing abort. Task.Wait uses the SetOnInvokeMres approach, which is the specific scenario we want to try really hard not to deadlock.
// requires System.Reflection and System.Reflection.Emit
Type taskType = typeof(Task);
FieldInfo continuationField = taskType.GetField("m_continuationObject", BindingFlags.Instance | BindingFlags.NonPublic);
Type safeScenario = taskType.GetNestedType("SetOnInvokeMres", BindingFlags.NonPublic);
if (continuationField != null && continuationField.FieldType == typeof(object) && safeScenario != null)
{
    // emits: task.m_continuationObject == null || task.m_continuationObject is SetOnInvokeMres
    var method = new DynamicMethod("IsSyncSafe", typeof(bool), new[] { typeof(Task) }, typeof(Task), true);
    var il = method.GetILGenerator();
    il.Emit(OpCodes.Ldarg_0);
    il.Emit(OpCodes.Ldfld, continuationField);
    Label nonNull = il.DefineLabel(), goodReturn = il.DefineLabel();
    // check if null
    il.Emit(OpCodes.Brtrue_S, nonNull);
    il.MarkLabel(goodReturn);
    il.Emit(OpCodes.Ldc_I4_1);
    il.Emit(OpCodes.Ret);
    // check if is a SetOnInvokeMres - if so, we're OK
    il.MarkLabel(nonNull);
    il.Emit(OpCodes.Ldarg_0);
    il.Emit(OpCodes.Ldfld, continuationField);
    il.Emit(OpCodes.Isinst, safeScenario);
    il.Emit(OpCodes.Brtrue_S, goodReturn);
    il.Emit(OpCodes.Ldc_I4_0);
    il.Emit(OpCodes.Ret);
    // IsSyncSafe is a static Func<Task, bool> field declared elsewhere
    IsSyncSafe = (Func<Task, bool>)method.CreateDelegate(typeof(Func<Task, bool>));