I'm trying to change Stephen Toub's ForEachAsync<T> extension method into an extension which returns a result...
Stephen's extension:
public static Task ForEachAsync<T>(this IEnumerable<T> source, int dop, Func<T, Task> body)
{
    return Task.WhenAll(
        from partition in Partitioner.Create(source).GetPartitions(dop)
        select Task.Run(async delegate {
            using (partition)
                while (partition.MoveNext())
                    await body(partition.Current);
        }));
}
My approach (not working: the tasks get executed, but the result is wrong):
public static Task<TResult[]> ForEachAsync<T, TResult>(this IList<T> source,
    int degreeOfParallelism, Func<T, Task<TResult>> body)
{
    return Task.WhenAll<TResult>(
        from partition in Partitioner.Create(source).GetPartitions(degreeOfParallelism)
        select Task.Run<TResult>(async () =>
        {
            using (partition)
                while (partition.MoveNext())
                    await body(partition.Current); // When I "return await",
                                                   // I get good results but only one per partition
            return default(TResult);
        }));
}
I know I somehow have to return the results from the last part (with WhenAll?), but I haven't figured out how to do it yet.
Update: The result I get is just degreeOfParallelism null values (I guess because of default(TResult)), even though all the tasks get executed. I also tried return await body(...); then the results were correct, but only degreeOfParallelism tasks got executed.
Now that the Parallel.ForEachAsync API has become part of the standard libraries (.NET 6), it makes sense to implement a variant that returns a Task<TResult[]>, based on this API. Here is an implementation that targets .NET 8:
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;

/// <summary>
/// Executes a foreach loop on an enumerable sequence, in which iterations may run
/// in parallel, and returns the results of all iterations in the original order.
/// </summary>
public static Task<TResult[]> ForEachAsync<TSource, TResult>(
    IEnumerable<TSource> source,
    ParallelOptions parallelOptions,
    Func<TSource, CancellationToken, ValueTask<TResult>> body)
{
    ArgumentNullException.ThrowIfNull(source);
    ArgumentNullException.ThrowIfNull(parallelOptions);
    ArgumentNullException.ThrowIfNull(body);
    List<TResult> results = new();
    if (source.TryGetNonEnumeratedCount(out int count)) results.Capacity = count;
    // Pair each element with its position in the source sequence.
    IEnumerable<(TSource, int)> withIndexes = source.Select((x, i) => (x, i));
    return Parallel.ForEachAsync(withIndexes, parallelOptions, async (entry, ct) =>
    {
        (TSource item, int index) = entry;
        TResult result = await body(item, ct).ConfigureAwait(false);
        lock (results)
        {
            // Grow the list so that the slot for this index exists, then fill it.
            if (index >= results.Count)
                CollectionsMarshal.SetCount(results, index + 1);
            results[index] = result;
        }
    }).ContinueWith(t =>
    {
        // Propagate the outcome of the loop: success, failure or cancellation.
        TaskCompletionSource<TResult[]> tcs = new();
        switch (t.Status)
        {
            case TaskStatus.RanToCompletion:
                lock (results) tcs.SetResult(results.ToArray()); break;
            case TaskStatus.Faulted:
                tcs.SetException(t.Exception.InnerExceptions); break;
            case TaskStatus.Canceled:
                tcs.SetCanceled(new TaskCanceledException(t).CancellationToken); break;
            default: throw new UnreachableException();
        }
        Debug.Assert(tcs.Task.IsCompleted);
        return tcs.Task;
    }, default, TaskContinuationOptions.DenyChildAttach |
        TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default).Unwrap();
}
This implementation supports all the options and the functionality of the Parallel.ForEachAsync overload that has an IEnumerable<T> as source. Its behavior in case of errors and cancellation is identical. The results are arranged in the same order as the associated elements in the source sequence.
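To see the index-pairing idea in isolation, here is a minimal sketch that assumes the source is an array, so its length is known up front (the implementation above can only learn this opportunistically via TryGetNonEnumeratedCount). With a preallocated array, every iteration writes to its own slot, so neither a lock nor SetCount is needed. The words, the delay and the degree of parallelism are arbitrary placeholders.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;

string[] words = { "alpha", "beta", "gamma", "delta" };
// The length is known up front, so the results can go straight into an array.
int[] lengths = new int[words.Length];

// Pair each element with its position, as the implementation above does.
IEnumerable<(string, int)> withIndexes = words.Select((x, i) => (x, i));

await Parallel.ForEachAsync(withIndexes,
    new ParallelOptions { MaxDegreeOfParallelism = 2 },
    async (entry, ct) =>
    {
        (string word, int index) = entry;
        await Task.Delay(10, ct); // stand-in for real asynchronous work
        // Each index is written by exactly one iteration, so no lock is needed.
        lengths[index] = word.Length;
    });

Console.WriteLine(string.Join(",", lengths)); // 5,4,5,5
```

Whatever order the iterations complete in, each result lands at the position of its source element.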
CollectionsMarshal.SetCount is an advanced API that was introduced in .NET 8. It changes the Count of a List<T>, exposing uninitialized data when the count is increased. For a less modern (and slightly less performant) approach that runs on .NET 6, see the 5th revision of this answer.
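A single-threaded sketch of the grow-then-assign step that the implementation performs inside lock (results). The Store helper is hypothetical and exists only for this illustration. In the real code, the same two statements run under a lock because iterations complete concurrently and in any order.

```csharp
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;

List<string> results = new();

// Hypothetical helper: the grow-then-assign step from the implementation above,
// without the lock (this sketch is single-threaded).
void Store(int index, string value)
{
    // Make `index` a valid position by raising Count, then write the slot.
    if (index >= results.Count)
        CollectionsMarshal.SetCount(results, index + 1);
    results[index] = value;
}

// Results arrive out of order: index 2 completes first.
Store(2, "c");
Store(0, "a");
Store(1, "b");

Console.WriteLine(results.Count);             // 3
Console.WriteLine(string.Join(",", results)); // a,b,c
```

Slots 0 and 1 exist (but are not yet meaningful) after the first call; they are filled in by the later calls.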
Your LINQ query can only ever produce as many results as there are partitions: you're projecting each partition into a single result.
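A small synchronous sketch that makes this concrete. The data (10 items) and the partition count (3) are arbitrary. GetPartitions(3) returns exactly three enumerators, so projecting each one into a single value gives three results, even though all ten items are visited.

```csharp
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;

int[] source = Enumerable.Range(1, 10).ToArray();
int visited = 0;

// One enumerator per partition: GetPartitions(3) returns exactly 3 of them.
IList<IEnumerator<int>> partitions = Partitioner.Create(source).GetPartitions(3);

// Like the query in the question, map each partition to ONE value.
int[] onePerPartition = partitions.Select(partition =>
{
    int last = default;
    using (partition)
    {
        while (partition.MoveNext())
        {
            visited++;              // every item is processed...
            last = partition.Current;
        }
    }
    return last;                    // ...but only one value per partition survives
}).ToArray();

Console.WriteLine($"{visited} items visited, {onePerPartition.Length} results");
// 10 items visited, 3 results
```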
If you don't care about the order, you just need to assemble the results of each partition into a list, then flatten them afterwards.
public static async Task<TResult[]> ExecuteInParallel<T, TResult>(this IList<T> source, int degreeOfParallelism, Func<T, Task<TResult>> body)
{
    var lists = await Task.WhenAll<List<TResult>>(
        Partitioner.Create(source).GetPartitions(degreeOfParallelism)
            .Select(partition => Task.Run<List<TResult>>(async () =>
            {
                var list = new List<TResult>();
                using (partition)
                {
                    while (partition.MoveNext())
                    {
                        list.Add(await body(partition.Current));
                    }
                }
                return list;
            })));
    return lists.SelectMany(list => list).ToArray();
}
(I've renamed this from ForEachAsync because ForEach sounds imperative, which suits the Func<T, Task> in the original, whereas this method fetches results. A foreach loop doesn't have a result; this does.)
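ExecuteInParallel flattens the per-partition lists, so the results are grouped by partition and generally don't match the source order. If order does matter and the source is an IList<T> (so its Count is known), one option is to partition the indexes instead of the items and have each iteration write into its own slot of a preallocated array. The sketch below is a hypothetical variant, ExecuteInParallelOrdered, not part of the answer above. It is written as a local function in a top-level program (so it is not an extension method), and the sample data and degree of parallelism are arbitrary.

```csharp
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;

// Hypothetical order-preserving variant: partitions range over indexes, and each
// result is stored at the index of its source item.
async Task<TResult[]> ExecuteInParallelOrdered<T, TResult>(
    IList<T> source, int degreeOfParallelism, Func<T, Task<TResult>> body)
{
    var results = new TResult[source.Count];
    await Task.WhenAll(
        Partitioner.Create(Enumerable.Range(0, source.Count))
            .GetPartitions(degreeOfParallelism)
            .Select(partition => Task.Run(async () =>
            {
                using (partition)
                {
                    while (partition.MoveNext())
                    {
                        int index = partition.Current;
                        // Distinct indexes mean no two tasks write the same slot.
                        results[index] = await body(source[index]);
                    }
                }
            })));
    return results;
}

int[] squares = await ExecuteInParallelOrdered(
    Enumerable.Range(1, 20).ToList(), 4,
    async x => { await Task.Delay(10); return x * x; });

Console.WriteLine(string.Join(",", squares.Take(5))); // 1,4,9,16,25
```

Because awaiting Task.WhenAll completes only after every partition task has finished, reading the array afterwards sees all the writes.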