// Copyright Epic Games, Inc. All Rights Reserved.
using System;
using System.Collections.Generic;
using System.Runtime.ExceptionServices;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
namespace EpicGames.Core
{
/// <summary>
/// Manages a set of async tasks forming a pipeline, which can be cancelled and awaited as a complete unit. The first exception thrown by an individual task is captured and reported back to the main thread.
/// </summary>
public sealed class AsyncPipeline : IAsyncDisposable
{
	// Wrapper tasks created by AddTask; each runs a user delegate through RunGuardedAsync.
	readonly List<Task> _tasks;

	// Token supplied by the caller; lets WaitForCompletionAsync distinguish external cancellation from task failure.
	readonly CancellationToken _cancellationToken;

	// Linked to _cancellationToken. Also cancelled when a task fails or the pipeline is disposed.
	// Set to null by DisposeAsync.
	CancellationTokenSource _cancellationSource;

	// First failure captured from any task. Only the first writer wins (Interlocked.CompareExchange).
	ExceptionDispatchInfo? _exceptionDispatchInfo;

	/// <summary>
	/// Tests whether the pipeline has failed, i.e. whether any task has thrown an exception that was not
	/// caused by the pipeline being cancelled.
	/// </summary>
	public bool IsFaulted => _exceptionDispatchInfo != null;

	/// <summary>
	/// Constructor
	/// </summary>
	/// <param name="cancellationToken">Cancellation token for the pipeline</param>
	public AsyncPipeline(CancellationToken cancellationToken)
	{
		_tasks = [];
		_cancellationToken = cancellationToken;
		_cancellationSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
	}

	/// <inheritdoc/>
	/// <remarks>
	/// Cancels all outstanding tasks, waits for them to finish, then releases the token source.
	/// Captured task failures are not rethrown here. Calling this more than once is harmless.
	/// </remarks>
	public async ValueTask DisposeAsync()
	{
		if (_cancellationSource != null)
		{
			await _cancellationSource.CancelAsync();
		}
		if (_tasks.Count > 0)
		{
			try
			{
				await Task.WhenAll(_tasks);
			}
			catch (OperationCanceledException)
			{
				// Ignore during dispose; tasks that never started complete as cancelled
			}
			_tasks.Clear();
		}
		if (_cancellationSource != null)
		{
			_cancellationSource.Dispose();
			_cancellationSource = null!;
		}
	}

	/// <summary>
	/// Adds a new task to the pipeline.
	/// </summary>
	/// <param name="taskFunc">Method to execute. It receives the pipeline's cancellation token.</param>
	/// <returns>
	/// Task that completes when <paramref name="taskFunc"/> finishes. Failures are captured by the pipeline
	/// (see <see cref="WaitForCompletionAsync"/>) rather than faulting this task. The task completes as
	/// cancelled if the pipeline was already cancelled when it was added.
	/// </returns>
	/// <exception cref="ObjectDisposedException">The pipeline has been disposed.</exception>
	public Task AddTask(Func<CancellationToken, Task> taskFunc)
	{
		ObjectDisposedException.ThrowIf(_cancellationSource is null, this);

		Task task = Task.Run(() => RunGuardedAsync(taskFunc), _cancellationSource.Token);
		_tasks.Add(task);
		return task;
	}

	/// <summary>
	/// Adds several tasks to the pipeline.
	/// </summary>
	/// <param name="count">Number of tasks to run</param>
	/// <param name="taskFunc">Method to execute</param>
	/// <returns>The tasks that were added, in creation order.</returns>
	/// <exception cref="ArgumentOutOfRangeException"><paramref name="count"/> is negative.</exception>
	public Task[] AddTasks(int count, Func<CancellationToken, Task> taskFunc)
	{
		ArgumentOutOfRangeException.ThrowIfNegative(count);

		Task[] tasks = new Task[count];
		for (int idx = 0; idx < count; idx++)
		{
			tasks[idx] = AddTask(taskFunc);
		}
		return tasks;
	}

	// Runs taskFunc with the pipeline token. Cancellation is ignored only when the pipeline itself was
	// cancelled. Any other exception is recorded (first one wins) and cancels the remaining tasks.
	async Task RunGuardedAsync(Func<CancellationToken, Task> taskFunc)
	{
		CancellationToken token = _cancellationSource.Token;
		try
		{
			await taskFunc(token);
		}
		catch (OperationCanceledException) when (token.IsCancellationRequested)
		{
			// Expected: the pipeline was cancelled externally, by another task failing, or by dispose
		}
		catch (Exception ex)
		{
			// Cheap pre-check avoids capturing stack info when another task already failed
			if (_exceptionDispatchInfo == null)
			{
				ExceptionDispatchInfo dispatchInfo = ExceptionDispatchInfo.Capture(ex);
				Interlocked.CompareExchange(ref _exceptionDispatchInfo, dispatchInfo, null);
			}
			await _cancellationSource.CancelAsync();
		}
	}

	/// <summary>
	/// Waits for all tasks to complete, and throws any exceptions
	/// </summary>
	/// <exception cref="OperationCanceledException">The token passed to the constructor was cancelled.</exception>
	/// <remarks>
	/// If a task failed, its original exception is rethrown with its stack trace preserved.
	/// </remarks>
	public async Task WaitForCompletionAsync()
	{
		try
		{
			await Task.WhenAll(_tasks);
		}
		catch (OperationCanceledException) when (_cancellationToken.IsCancellationRequested || _exceptionDispatchInfo != null)
		{
			// Tasks added after the pipeline was cancelled never start and complete as cancelled.
			// Report the underlying cause below instead of the resulting TaskCanceledException.
		}
		_cancellationToken.ThrowIfCancellationRequested();
		_exceptionDispatchInfo?.Throw();
	}
}
/// <summary>
/// Extension methods for async pipelines
/// </summary>
public static class AsyncPipelineExtensions
{
	/// <summary>
	/// Adds a worker to process items from a channel
	/// </summary>
	/// <typeparam name="T">Item type</typeparam>
	/// <param name="pipeline">Pipeline to add the worker to</param>
	/// <param name="reader">Source for the items</param>
	/// <param name="taskFunc">Action to execute for each item</param>
	/// <returns>The worker task added to the pipeline</returns>
	public static Task AddTask<T>(this AsyncPipeline pipeline, ChannelReader<T> reader, Func<T, CancellationToken, Task> taskFunc)
		=> pipeline.AddTask(ctx => ProcessItemsAsync(reader, taskFunc, ctx));

	/// <summary>
	/// Adds a worker to process items from a channel
	/// </summary>
	/// <typeparam name="TInput">Input item type</typeparam>
	/// <typeparam name="TOutput">Output item type</typeparam>
	/// <param name="pipeline">Pipeline to add the worker to</param>
	/// <param name="reader">Reader for input items</param>
	/// <param name="writer">Writer for output items. It is not completed by the worker.</param>
	/// <param name="taskFunc">Action to execute for each item</param>
	/// <returns>The worker task added to the pipeline</returns>
	public static Task AddTask<TInput, TOutput>(this AsyncPipeline pipeline, ChannelReader<TInput> reader, ChannelWriter<TOutput> writer, Func<TInput, CancellationToken, Task<TOutput>> taskFunc)
		=> pipeline.AddTask(ctx => ProcessItemsAsync(reader, writer, taskFunc, ctx));

	/// <summary>
	/// Adds a worker to process items from a channel
	/// </summary>
	/// <typeparam name="T">Input item type</typeparam>
	/// <param name="pipeline">Pipeline to add the worker to</param>
	/// <param name="count">Number of workers to add</param>
	/// <param name="reader">Reader for input items</param>
	/// <param name="taskFunc">Action to execute for each item</param>
	/// <returns>The worker tasks, in creation order</returns>
	/// <exception cref="ArgumentOutOfRangeException"><paramref name="count"/> is negative.</exception>
	public static Task[] AddTasks<T>(this AsyncPipeline pipeline, int count, ChannelReader<T> reader, Func<T, CancellationToken, Task> taskFunc)
	{
		ArgumentOutOfRangeException.ThrowIfNegative(count);

		Task[] tasks = new Task[count];
		for (int idx = 0; idx < count; idx++)
		{
			tasks[idx] = AddTask(pipeline, reader, taskFunc);
		}
		return tasks;
	}

	/// <summary>
	/// Adds a worker to process items from a channel
	/// </summary>
	/// <typeparam name="TInput">Input item type</typeparam>
	/// <typeparam name="TOutput">Output item type</typeparam>
	/// <param name="pipeline">Pipeline to add the worker to</param>
	/// <param name="count">Number of workers to add</param>
	/// <param name="reader">Reader for input items</param>
	/// <param name="writer">Writer for output items. It is not completed by the workers.</param>
	/// <param name="taskFunc">Action to execute for each item</param>
	/// <returns>The worker tasks, in creation order</returns>
	/// <exception cref="ArgumentOutOfRangeException"><paramref name="count"/> is negative.</exception>
	public static Task[] AddTasks<TInput, TOutput>(this AsyncPipeline pipeline, int count, ChannelReader<TInput> reader, ChannelWriter<TOutput> writer, Func<TInput, CancellationToken, Task<TOutput>> taskFunc)
	{
		ArgumentOutOfRangeException.ThrowIfNegative(count);

		Task[] tasks = new Task[count];
		for (int idx = 0; idx < count; idx++)
		{
			tasks[idx] = AddTask(pipeline, reader, writer, taskFunc);
		}
		return tasks;
	}

	// Takes one item at a time from the reader and runs taskFunc on it, until the channel is completed.
	// Cancellation is observed by WaitToReadAsync before each item.
	static async Task ProcessItemsAsync<T>(ChannelReader<T> reader, Func<T, CancellationToken, Task> taskFunc, CancellationToken cancellationToken)
	{
		while (await reader.WaitToReadAsync(cancellationToken))
		{
			// Another worker sharing the reader may have taken the item first
			if (reader.TryRead(out T? item))
			{
				await taskFunc(item, cancellationToken);
			}
		}
	}

	// Same loop as above, but forwards each result to the writer. The writer is left open so that
	// several workers can share it.
	static async Task ProcessItemsAsync<TInput, TOutput>(ChannelReader<TInput> reader, ChannelWriter<TOutput> writer, Func<TInput, CancellationToken, Task<TOutput>> taskFunc, CancellationToken cancellationToken)
	{
		while (await reader.WaitToReadAsync(cancellationToken))
		{
			if (reader.TryRead(out TInput? input))
			{
				TOutput output = await taskFunc(input, cancellationToken);
				await writer.WriteAsync(output, cancellationToken);
			}
		}
	}
}
}