
Simplify working with parallel tasks in C# (updated)



There is no doubt that the async/await pattern has significantly simplified working with asynchronous operations in C#. However, this simplification applies only when asynchronous operations are executed sequentially. If we need to execute several asynchronous operations simultaneously (e.g. to call several microservices), there are not many built-in options, and most likely Task.WhenAll will be used:


Task<SomeType1> someAsyncOp1 = SomeAsyncOperation1();
Task<SomeType2> someAsyncOp2 = SomeAsyncOperation2();
Task<SomeType3> someAsyncOp3 = SomeAsyncOperation3();
Task<SomeType4> someAsyncOp4 = SomeAsyncOperation4();
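// someAsyncOp3 is missing here (easy to forget); someAsyncOp3.Result below may then block the thread
await Task.WhenAll(someAsyncOp1, someAsyncOp2, someAsyncOp4);
var result = new SomeContainer(
     someAsyncOp1.Result, someAsyncOp2.Result, someAsyncOp3.Result, someAsyncOp4.Result);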

This is a working solution, but it is quite verbose and not very reliable (it is easy to forget to add a new task to “WhenAll”). I would prefer something like this instead:


var result =  await 
    from r1 in SomeAsyncOperation1()
    from r2 in SomeAsyncOperation2()
    from r3 in SomeAsyncOperation3()
    from r4 in SomeAsyncOperation4()
    select new SomeContainer(r1, r2, r3, r4);

Below I will explain what is needed to make this construct work.


First, remember that C# query syntax is just syntactic sugar over chains of method calls: the C# compiler translates the previous statement into (roughly) the following chain of SelectMany calls:


SomeAsyncOperation1()/*Task<T1>*/
  .SelectMany(
      (r1/*T1*/) => SomeAsyncOperation2()/*Task<T2>*/,
      (r1/*T1*/, r2/*T2*/) => new {r1, r2}/*Anon type 1*/)
  .SelectMany(
      (t/*Anon type 1*/) => SomeAsyncOperation3()/*Task<T3>*/,
      (t/*Anon type 1*/, r3/*T3*/) => new {t, r3}/*Anon type 2*/)
  .SelectMany(
      (t/*Anon type 2*/) => SomeAsyncOperation4()/*Task<T4>*/, 
      (t/*Anon type 2*/, r4/*T4*/) => new SomeContainer(t.t.r1, t.t.r2, t.r3, r4));


By default, the C# compiler will reject this code, since the standard SelectMany extension methods are defined for IEnumerable<T> (and IQueryable<T>), not for Task<T>. However, nothing prevents us from creating our own SelectMany overloads to make the code compile.
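Query syntax is pattern-based: the compiler only looks for an accessible method named SelectMany with a suitable shape. The sketch below demonstrates this with a deliberately naive, sequential SelectMany for Task<T>. It is not the parallel implementation developed in this article; Op and NaiveDemo are hypothetical stand-ins for SomeAsyncOperationN.

```csharp
using System;
using System.Threading.Tasks;

// Deliberately naive and SEQUENTIAL: each task is started only after the
// previous one has completed. It exists only to show that query syntax binds
// to any SelectMany with the right shape (not the article's parallel version).
public static class NaiveTaskQueryExtensions
{
    public static async Task<TRes> SelectMany<TCur, TNext, TRes>(
        this Task<TCur> source,
        Func<TCur, Task<TNext>> getNextTaskFunc,
        Func<TCur, TNext, TRes> mapperFunc)
    {
        TCur cur = await source;                  // wait for the previous result...
        TNext next = await getNextTaskFunc(cur);  // ...then start and await the next task
        return mapperFunc(cur, next);
    }
}

public static class NaiveDemo
{
    // Hypothetical async operation: waits a little, then returns value.
    static async Task<int> Op(int value)
    {
        await Task.Delay(10);
        return value;
    }

    public static Task<int> Run() =>
        from r1 in Op(1)
        from r2 in Op(2)
        from r3 in Op(r2 + 1)   // may use r2, since everything runs one after another
        select r1 + r2 + r3;
}
```

NaiveDemo.Run() produces 1 + 2 + 3 = 6, but the operations run strictly one after another, which is exactly what we want to avoid.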


Each SelectMany call in the chain takes two arguments (in addition to the source it is called on):


  • the first argument is a function that returns the next asynchronous operation

(t/*Anon type 1*/) => SomeAsyncOperation3()/*Task<T3>*/


  • the second argument is a function that combines the results of the previous asynchronous operations with the result of the operation returned by the first function.

(t/*Anon type 1*/, r3/*T3*/) => new {t, r3}/*Anon type 2*/


So we can call the first functions to obtain all the tasks to pass to Task.WhenAll, and then call the second (mapping) functions to build the result.


To collect the tasks and the mapping functions, our SelectMany overload needs to return an object that holds references to a task and a mapping function. In addition, SelectMany receives a reference to the previous such object as its this argument (since SelectMany is an extension method). We keep that reference as well, so the objects form a linked list containing all the required data.


static class TaskAllExtensions
{
    public static ITaskAccumulator<TRes> SelectMany<TCur, TNext, TRes>(
        this ITaskAccumulator<TCur> source, 
        Func<TCur, Task<TNext>> getNextTaskFunc, 
        Func<TCur, TNext, TRes> mapperFunc) 
    => 
        new TaskAccumulator<TCur, TNext, TRes>(
            prev: source, 
            currentTask: getNextTaskFunc(default(TCur)), 
            mapper: mapperFunc);
}

class TaskAccumulator<TPrev, TCur, TRes> : ITaskAccumulator<TRes>
{
    public readonly ITaskAccumulator<TPrev> Prev;

    public readonly Task<TCur> CurrentTask;

    public readonly Func<TPrev, TCur, TRes> Mapper;
    ...
}


The initial accumulator

The initial accumulator differs from the subsequent ones: it cannot refer to a previous accumulator, but it must hold a reference to the first task:


static class TaskAllExtensions
{
    ...
    public static ITaskAccumulator<TRes> SelectMany<TCur, TNext, TRes>(
        this Task<TCur> source, 
        Func<TCur, Task<TNext>> getNextTaskFunc, 
        Func<TCur, TNext, TRes> mapperFunc) 
    => 
        new TaskAccumulatorInitial<TCur, TNext, TRes>(
            task1: source, 
            task2: getNextTaskFunc(default(TCur)), 
            mapper: mapperFunc);
    ...
}

class TaskAccumulatorInitial<TPrev, TCur, TRes> : ITaskAccumulator<TRes>
{
    public readonly Task<TPrev> Task1;

    public readonly Task<TCur> Task2;

    public readonly Func<TPrev, TCur, TRes> Mapper;
    ...
}

Now we can get the result by adding these methods:


class TaskAccumulator<TPrev, TCur, TRes> : ITaskAccumulator<TRes>
{
    public async Task<TRes> Result()
    {
        await Task.WhenAll(this.Tasks);
        return this.ResultSync();
    }

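    // public: these members implement ITaskAccumulator<T>, which is how Prev.Tasks and Prev.ResultSync() are reachable
    public IEnumerable<Task> Tasks 
        => new Task[] { this.CurrentTask }.Concat(this.Prev.Tasks);

    public TRes ResultSync() 
        => this.Mapper(this.Prev.ResultSync(), this.CurrentTask.Result);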
    ...
    public readonly ITaskAccumulator<TPrev> Prev;
    public readonly Task<TCur> CurrentTask;
    public readonly Func<TPrev, TCur, TRes> Mapper;        
}

  • The Tasks property returns all tasks from the entire linked list.
  • ResultSync() recursively applies the mapper functions to the task results (all tasks are assumed to be already completed).
  • Result() waits for all tasks (via await Task.WhenAll(Tasks)) and returns the result of ResultSync().

Also, we can add a simple extension method to make await work with ITaskAccumulator:


static class TaskAllExtensions
{
    ...
    public static TaskAwaiter<T> GetAwaiter<T>(this ITaskAccumulator<T> source)
        => source.Result().GetAwaiter();
}

Now the code works:


var result =  await 
    from r1 in SomeAsyncOperation1()
    from r2 in SomeAsyncOperation2()
    from r3 in SomeAsyncOperation3()
    from r4 in SomeAsyncOperation4()
    select new SomeContainer(r1, r2, r3, r4);
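The pieces above can be assembled into one self-contained sketch. The article does not show the definition of ITaskAccumulator<T>; the interface below is an assumption inferred from the members the classes use (Result(), Tasks, ResultSync()). Op and AccumulatorDemo are hypothetical stand-ins for SomeAsyncOperationN.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;

// Assumed interface shape (not shown in the article).
public interface ITaskAccumulator<TRes>
{
    Task<TRes> Result();
    IEnumerable<Task> Tasks { get; }
    TRes ResultSync();   // valid only after all Tasks have completed
}

sealed class TaskAccumulatorInitial<TPrev, TCur, TRes> : ITaskAccumulator<TRes>
{
    readonly Task<TPrev> _task1;
    readonly Task<TCur> _task2;
    readonly Func<TPrev, TCur, TRes> _mapper;

    public TaskAccumulatorInitial(Task<TPrev> task1, Task<TCur> task2, Func<TPrev, TCur, TRes> mapper)
        => (_task1, _task2, _mapper) = (task1, task2, mapper);

    public async Task<TRes> Result() { await Task.WhenAll(Tasks); return ResultSync(); }
    public IEnumerable<Task> Tasks => new Task[] { _task1, _task2 };
    public TRes ResultSync() => _mapper(_task1.Result, _task2.Result);
}

sealed class TaskAccumulator<TPrev, TCur, TRes> : ITaskAccumulator<TRes>
{
    readonly ITaskAccumulator<TPrev> _prev;
    readonly Task<TCur> _currentTask;
    readonly Func<TPrev, TCur, TRes> _mapper;

    public TaskAccumulator(ITaskAccumulator<TPrev> prev, Task<TCur> currentTask, Func<TPrev, TCur, TRes> mapper)
        => (_prev, _currentTask, _mapper) = (prev, currentTask, mapper);

    public async Task<TRes> Result() { await Task.WhenAll(Tasks); return ResultSync(); }
    // Walks the linked list: this node's task plus all previous ones.
    public IEnumerable<Task> Tasks => new Task[] { _currentTask }.Concat(_prev.Tasks);
    public TRes ResultSync() => _mapper(_prev.ResultSync(), _currentTask.Result);
}

public static class TaskAllExtensions
{
    // The next task is created immediately, so only default(TCur) can be passed.
    public static ITaskAccumulator<TRes> SelectMany<TCur, TNext, TRes>(
        this Task<TCur> source, Func<TCur, Task<TNext>> getNextTaskFunc, Func<TCur, TNext, TRes> mapperFunc)
        => new TaskAccumulatorInitial<TCur, TNext, TRes>(source, getNextTaskFunc(default(TCur)), mapperFunc);

    public static ITaskAccumulator<TRes> SelectMany<TCur, TNext, TRes>(
        this ITaskAccumulator<TCur> source, Func<TCur, Task<TNext>> getNextTaskFunc, Func<TCur, TNext, TRes> mapperFunc)
        => new TaskAccumulator<TCur, TNext, TRes>(source, getNextTaskFunc(default(TCur)), mapperFunc);

    public static TaskAwaiter<T> GetAwaiter<T>(this ITaskAccumulator<T> source) => source.Result().GetAwaiter();
}

public static class AccumulatorDemo
{
    // Hypothetical async operation: waits delayMs, then returns value.
    static async Task<int> Op(int value, int delayMs)
    {
        await Task.Delay(delayMs);
        return value;
    }

    public static async Task<(int Sum, long ElapsedMs)> Run()
    {
        var sw = System.Diagnostics.Stopwatch.StartNew();
        int sum = await (
            from r1 in Op(1, 200)
            from r2 in Op(2, 200)
            from r3 in Op(3, 200)
            select r1 + r2 + r3);
        return (sum, sw.ElapsedMilliseconds);
    }
}
```

Because all three operations are started while the chain is being built, AccumulatorDemo.Run() should take roughly 200 ms rather than 600 ms.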

However, there is a problem here: C# allows an intermediate result to be used as an argument for subsequent operations. For example:


    from r2 in SomeAsyncOperation2()
    from r3 in SomeAsyncOperation3(r2)

Such code will lead to a NullReferenceException, since r2 is not yet resolved at the moment SomeAsyncOperation3 is called; the function that produces the next task is invoked with a default value:


... 
task2:getNextTaskFunc(default(TCur)),
...

This happens because all the tasks are started in parallel, before any of them has completed.
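The effect can be reproduced in isolation. The sketch below uses a simplified, hypothetical EagerExtensions.SelectMany that behaves like the parallel overloads above (the next task is created immediately with default(TCur)). With a value type, the dependent operation silently receives 0 instead of the real r2; when the argument comes from an anonymous type built by an earlier from clause, that default is null and the access throws a NullReferenceException.

```csharp
using System;
using System.Threading.Tasks;

public static class EagerExtensions
{
    // Simplified stand-in for the parallel SelectMany: the next task is
    // created before `source` completes, so the selector gets default(TCur).
    public static async Task<TRes> SelectMany<TCur, TNext, TRes>(
        this Task<TCur> source,
        Func<TCur, Task<TNext>> getNextTaskFunc,
        Func<TCur, TNext, TRes> mapperFunc)
    {
        Task<TNext> next = getNextTaskFunc(default(TCur)); // runs synchronously, before any await
        return mapperFunc(await source, await next);
    }
}

public static class DefaultArgumentDemo
{
    // Hypothetical operations: Op2 produces 42, Op3 reports the argument it received.
    static async Task<int> Op2() { await Task.Delay(50); return 42; }
    static Task<int> Op3(int r2) => Task.FromResult(r2);

    public static Task<(int R2, int SeenByOp3)> Run() =>
        from r2 in Op2()
        from r3 in Op3(r2)
        select (R2: r2, SeenByOp3: r3);
}
```

Here r2 ends up as 42, while Op3 saw 0.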


Unfortunately, I do not see a solution to this problem in the current version of the C# language, but we can mitigate it by dividing tasks into two groups:


  1. Tasks that are executed in parallel
  2. Tasks that are executed sequentially (and can use all previous results).

To do that, let's introduce two simple wrappers over a task:


public struct ParallelTaskWrapper<T>
{
    public readonly Task<T> Task;

    internal ParallelTaskWrapper(Task<T> task) => this.Task = task;
}

public struct SequentialTaskWrapper<T>
{
    public readonly Task<T> Task;

    public SequentialTaskWrapper(Task<T> task) => this.Task = task;
}

Helpers
public static ParallelTaskWrapper<T> AsParallel<T>(this Task<T> task)
{
    return new ParallelTaskWrapper<T>(task);
}

public static SequentialTaskWrapper<T> AsSequential<T>(this Task<T> task)
{
    return new SequentialTaskWrapper<T>(task);
}

The only purpose of these wrappers is to determine which SelectMany overload is used:


public static ITaskAccumulator<TRes> SelectMany<TCur, TNext, TRes>(
    this ITaskAccumulator<TCur> source, 
    Func<TCur, ParallelTaskWrapper<TNext>> exec, 
    Func<TCur, TNext, TRes> mapper)
...    

and (this one is a new overload):


public static ITaskAccumulator<TRes> SelectMany<TCur, TNext, TRes>(
    this ITaskAccumulator<TCur> source, 
    Func<TCur, SequentialTaskWrapper<TNext>> exec, 
    Func<TCur, TNext, TRes> mapper)
{
    return new SingleTask<TRes>(BuildTask());

    async Task<TRes> BuildTask()
    {
        var arg1 = await source.Result();
        var arg2 = await exec(arg1).Task;
        return mapper(arg1, arg2);
    }
}

SingleTask
internal class SingleTask<T> : ITaskAccumulator<T>
{
    private readonly Task<T> _task;
    private readonly Task[] _tasks;

    public SingleTask(Task<T> task)
    {
        this._task = task;
        this._tasks = new Task[] { task };
    }

    public Task<T> Result() => this._task;
    public IEnumerable<Task> Tasks => this._tasks;
    public T ResultSync() => this._task.Result;
}

As you can see, all the previous tasks are resolved by var arg1/*Anon Type X*/ = await source.Result(); so their results can be used to obtain the next task, and the code below works properly:


var result =  await 
    from r1 in SomeAsyncOperation1().AsParallel()
    from r2 in SomeAsyncOperation2().AsParallel()
    from r3 in SomeAsyncOperation3().AsParallel()
    from r4 in SomeAsyncOperation4(r1, r2, r3).AsSequential()
    from r5 in SomeAsyncOperation5().AsParallel()
    select new SomeContainer(r1, r2, r3, r4, r5);

Update (Getting rid of Task.WhenAll)


We introduced the task accumulator to collect a list of tasks so that we could call Task.WhenAll on them. But do we really need it? Actually, we do not!
The point is that once we have received a reference to a task, that task has already started executing, so all the tasks below are already running in parallel (the code from the beginning):


Task<SomeType1> someAsyncOp1 = SomeAsyncOperation1();
Task<SomeType2> someAsyncOp2 = SomeAsyncOperation2();
Task<SomeType3> someAsyncOp3 = SomeAsyncOperation3();
Task<SomeType4> someAsyncOp4 = SomeAsyncOperation4();

But instead of Task.WhenAll we can use several awaits:


SomeType1 op1Result = await someAsyncOp1;
SomeType2 op2Result = await someAsyncOp2;
SomeType3 op3Result = await someAsyncOp3;
SomeType4 op4Result = await someAsyncOp4;

await returns the result immediately if a task has already completed, or waits until the asynchronous operation completes, so this code takes about the same amount of time as the version with Task.WhenAll.
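A quick way to check this claim is to time three already-started tasks that are awaited one after another. In this sketch Op and HotTaskDemo are hypothetical; each operation takes about 300 ms, so sequential execution would need about 900 ms, while the measured time should be close to 300 ms.

```csharp
using System.Diagnostics;
using System.Threading.Tasks;

public static class HotTaskDemo
{
    // Hypothetical async operation: waits delayMs, then returns value.
    static async Task<int> Op(int value, int delayMs)
    {
        await Task.Delay(delayMs);
        return value;
    }

    public static async Task<(int Sum, long ElapsedMs)> Run()
    {
        var sw = Stopwatch.StartNew();

        // Each operation starts as soon as its method is called.
        Task<int> op1 = Op(1, 300);
        Task<int> op2 = Op(2, 300);
        Task<int> op3 = Op(3, 300);

        // Awaiting one by one does not serialize them: by the time op1
        // completes, op2 and op3 have (almost) completed as well.
        int r1 = await op1;
        int r2 = await op2;
        int r3 = await op3;

        return (r1 + r2 + r3, sw.ElapsedMilliseconds);
    }
}
```

One difference from Task.WhenAll is worth keeping in mind: if op1 faults, the exception is thrown at the first await, without waiting for op2 and op3, and any exceptions from them go unobserved; Task.WhenAll waits for all tasks to complete first.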


That fact allows us to significantly simplify the code and get rid of the task accumulator:


using System;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;

static class TaskAllExtensions
{
    public static ParallelTaskWrapper<TRes> SelectMany<TCur, TNext, TRes>(
        this ParallelTaskWrapper<TCur> source, 
        Func<TCur, ParallelTaskWrapper<TNext>> exec, 
        Func<TCur, TNext, TRes> mapper)
    {
        async Task<TRes> GetResult()
        {
            var nextTask = exec(default(TCur)); // <-- Important: the next task starts right away, without waiting for source
            return mapper(await source.Task, await nextTask);
        }
        return new ParallelTaskWrapper<TRes>(GetResult());
    }

    public static ParallelTaskWrapper<TRes> SelectMany<TCur, TNext, TRes>(
        this ParallelTaskWrapper<TCur> source, 
        Func<TCur, SequentialTaskWrapper<TNext>> exec, 
        Func<TCur, TNext, TRes> mapper)
    {
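        async Task<TRes> GetResult()
        {
            var arg1 = await source;             // all previous results are resolved here...
            var arg2 = await exec(arg1).Task;    // ...so the next task can use them
            return mapper(arg1, arg2);
        }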
        return new ParallelTaskWrapper<TRes>(GetResult());
    }

    public static TaskAwaiter<T> GetAwaiter<T>(
        this ParallelTaskWrapper<T> source)
        => 
        source.Task.GetAwaiter();
}
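For experimenting, here is a self-contained sketch that assembles the final version: the two wrappers, the AsParallel/AsSequential helpers, both SelectMany overloads and GetAwaiter. It follows the code above, with the sequential overload awaiting source only once; Op and FinalDemo are hypothetical stand-ins for SomeAsyncOperationN.

```csharp
using System;
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;

public struct ParallelTaskWrapper<T>
{
    public readonly Task<T> Task;
    internal ParallelTaskWrapper(Task<T> task) => this.Task = task;
}

public struct SequentialTaskWrapper<T>
{
    public readonly Task<T> Task;
    public SequentialTaskWrapper(Task<T> task) => this.Task = task;
}

public static class TaskAllExtensions
{
    public static ParallelTaskWrapper<T> AsParallel<T>(this Task<T> task) => new ParallelTaskWrapper<T>(task);
    public static SequentialTaskWrapper<T> AsSequential<T>(this Task<T> task) => new SequentialTaskWrapper<T>(task);

    // Parallel step: the next task starts immediately, with a default argument.
    public static ParallelTaskWrapper<TRes> SelectMany<TCur, TNext, TRes>(
        this ParallelTaskWrapper<TCur> source,
        Func<TCur, ParallelTaskWrapper<TNext>> exec,
        Func<TCur, TNext, TRes> mapper)
    {
        async Task<TRes> GetResult()
        {
            var nextTask = exec(default(TCur));   // runs synchronously, before any await
            return mapper(await source.Task, await nextTask.Task);
        }
        return new ParallelTaskWrapper<TRes>(GetResult());
    }

    // Sequential step: waits for all previous results, then starts the next task.
    public static ParallelTaskWrapper<TRes> SelectMany<TCur, TNext, TRes>(
        this ParallelTaskWrapper<TCur> source,
        Func<TCur, SequentialTaskWrapper<TNext>> exec,
        Func<TCur, TNext, TRes> mapper)
    {
        async Task<TRes> GetResult()
        {
            var arg1 = await source.Task;
            var arg2 = await exec(arg1).Task;
            return mapper(arg1, arg2);
        }
        return new ParallelTaskWrapper<TRes>(GetResult());
    }

    public static TaskAwaiter<T> GetAwaiter<T>(this ParallelTaskWrapper<T> source) => source.Task.GetAwaiter();
}

public static class FinalDemo
{
    // Hypothetical async operation: waits delayMs, then returns value.
    static async Task<int> Op(int value, int delayMs)
    {
        await Task.Delay(delayMs);
        return value;
    }

    public static async Task<(int Sum, long ElapsedMs)> Run()
    {
        var sw = Stopwatch.StartNew();
        int sum = await (
            from r1 in Op(1, 200).AsParallel()
            from r2 in Op(2, 200).AsParallel()
            from r3 in Op(r1 + r2, 200).AsSequential() // waits for r1 and r2
            from r4 in Op(4, 200).AsParallel()          // starts right away
            select r1 + r2 + r3 + r4);
        return (sum, sw.ElapsedMilliseconds);
    }
}
```

In FinalDemo.Run(), r1, r2 and r4 run concurrently and r3 starts only after r1 and r2 are known, so the sum is 1 + 2 + 3 + 4 = 10 and the elapsed time should be about 400 ms instead of 800 ms.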

That is it.


All the code can be found on GitHub


...


Developers who are familiar with functional programming languages might notice that the approach described above resembles the “Monad” design pattern. This is no surprise, since C# query notation is roughly equivalent to “do” notation in Haskell, which, in turn, is syntactic sugar for working with monads. If you are not yet familiar with this design pattern, I hope this demonstration will encourage you to learn more about monads and functional programming.
