// Copyright Epic Games, Inc. All Rights Reserved.

using System;
using System.Collections.Generic;
using System.Runtime.ExceptionServices;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;

namespace EpicGames.Core
{
	/// <summary>
	/// Manages a set of async tasks forming a pipeline, which can be cancelled and awaited as a complete unit. The first exception thrown by an individual task is captured and rethrown from <see cref="WaitForCompletionAsync"/>.
	/// </summary>
	public sealed class AsyncPipeline : IAsyncDisposable
	{
		readonly List<Task> _tasks;
		readonly CancellationToken _cancellationToken;

		// Linked to _cancellationToken; also cancelled when any task faults. Null once the pipeline has been disposed.
		CancellationTokenSource? _cancellationSource;

		// First exception thrown by any task; set at most once via Interlocked.CompareExchange
		ExceptionDispatchInfo? _exceptionDispatchInfo;

		/// <summary>
		/// Tests whether the pipeline has failed
		/// </summary>
		public bool IsFaulted => _exceptionDispatchInfo != null;

		/// <summary>
		/// Constructor
		/// </summary>
		/// <param name="cancellationToken">Cancellation token for the pipeline. Cancelling it cancels every task in the pipeline.</param>
		public AsyncPipeline(CancellationToken cancellationToken)
		{
			_tasks = [];
			_cancellationToken = cancellationToken;
			_cancellationSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
		}

		/// <summary>
		/// Cancels every task in the pipeline, waits for them to finish and releases the cancellation source.
		/// Exceptions captured from tasks are not rethrown here; use <see cref="WaitForCompletionAsync"/> to observe them.
		/// Calling this more than once has no further effect.
		/// </summary>
		public async ValueTask DisposeAsync()
		{
			CancellationTokenSource? cancellationSource = _cancellationSource;
			if (cancellationSource == null)
			{
				return;
			}

			await cancellationSource.CancelAsync();
			try
			{
				if (_tasks.Count > 0)
				{
					try
					{
						await Task.WhenAll(_tasks);
					}
					catch (OperationCanceledException)
					{
						// Ignore during dispose; the tasks were cancelled above
					}
					_tasks.Clear();
				}
			}
			finally
			{
				// Only reached once every task has completed, so no task can still be using the source.
				// Disposing unregisters the linked source from the caller's token.
				cancellationSource.Dispose();
				_cancellationSource = null;
			}
		}

		/// <summary>
		/// Adds a new task to the pipeline.
		/// </summary>
		/// <param name="taskFunc">Method to execute. It receives the pipeline's cancellation token, which is signalled when the pipeline is cancelled or any task fails.</param>
		/// <returns>
		/// Task that completes when <paramref name="taskFunc"/> finishes. Exceptions are captured by the pipeline rather than surfaced through this task.
		/// The task is cancelled without running <paramref name="taskFunc"/> if the pipeline has already been cancelled.
		/// </returns>
		/// <exception cref="ObjectDisposedException">The pipeline has been disposed.</exception>
		public Task AddTask(Func<CancellationToken, Task> taskFunc)
		{
			ArgumentNullException.ThrowIfNull(taskFunc);
			CancellationTokenSource cancellationSource = _cancellationSource ?? throw new ObjectDisposedException(nameof(AsyncPipeline));

			// Pass the source explicitly so the running task never reads the mutable field (which DisposeAsync clears)
			Task task = Task.Run(() => RunGuardedAsync(cancellationSource, taskFunc), cancellationSource.Token);
			_tasks.Add(task);
			return task;
		}

		/// <summary>
		/// Adds several tasks to the pipeline.
		/// </summary>
		/// <param name="count">Number of tasks to run. Must not be negative.</param>
		/// <param name="taskFunc">Method to execute</param>
		/// <returns>The tasks that were added, one per invocation of <paramref name="taskFunc"/></returns>
		public Task[] AddTasks(int count, Func<CancellationToken, Task> taskFunc)
		{
			ArgumentOutOfRangeException.ThrowIfNegative(count);

			Task[] tasks = new Task[count];
			for (int idx = 0; idx < count; idx++)
			{
				tasks[idx] = AddTask(taskFunc);
			}
			return tasks;
		}

		/// <summary>
		/// Runs a single task body, recording the first failure and cancelling the rest of the pipeline when it faults.
		/// </summary>
		async Task RunGuardedAsync(CancellationTokenSource cancellationSource, Func<CancellationToken, Task> taskFunc)
		{
			try
			{
				await taskFunc(cancellationSource.Token);
			}
			catch (OperationCanceledException) when (cancellationSource.IsCancellationRequested)
			{
				// Expected: the pipeline is being torn down (external cancellation, a sibling failure, or disposal).
				// Cancellations from any other source fall through and are reported as failures below.
			}
			catch (Exception ex)
			{
				if (_exceptionDispatchInfo == null)
				{
					ExceptionDispatchInfo dispatchInfo = ExceptionDispatchInfo.Capture(ex);
					Interlocked.CompareExchange(ref _exceptionDispatchInfo, dispatchInfo, null);
				}
				await cancellationSource.CancelAsync();
			}
		}

		/// <summary>
		/// Waits for all tasks to complete, and throws any exceptions
		/// </summary>
		/// <exception cref="OperationCanceledException">The cancellation token passed to the constructor was cancelled.</exception>
		public async Task WaitForCompletionAsync()
		{
			try
			{
				await Task.WhenAll(_tasks);
			}
			catch (OperationCanceledException) when (_cancellationToken.IsCancellationRequested || _exceptionDispatchInfo != null)
			{
				// Tasks added after the pipeline was cancelled never start and complete as cancelled;
				// report the underlying cause below instead of the resulting TaskCanceledException.
			}

			_cancellationToken.ThrowIfCancellationRequested();
			_exceptionDispatchInfo?.Throw();
		}
	}

	/// <summary>
	/// Extension methods for async pipelines
	/// </summary>
	public static class AsyncPipelineExtensions
	{
		/// <summary>
		/// Adds a worker to process items from a channel
		/// </summary>
		/// <typeparam name="T">Item type</typeparam>
		/// <param name="pipeline">Pipeline to add the worker to</param>
		/// <param name="reader">Source for the items</param>
		/// <param name="taskFunc">Action to execute for each item</param>
		/// <returns>Task for the worker, which runs until the reader is completed or the pipeline is cancelled</returns>
		public static Task AddTask<T>(this AsyncPipeline pipeline, ChannelReader<T> reader, Func<T, CancellationToken, Task> taskFunc)
			=> pipeline.AddTask(ctx => ProcessItemsAsync(reader, taskFunc, ctx));

		/// <summary>
		/// Adds a worker to process items from a channel
		/// </summary>
		/// <typeparam name="TInput">Input item type</typeparam>
		/// <typeparam name="TOutput">Output item type</typeparam>
		/// <param name="pipeline">Pipeline to add the worker to</param>
		/// <param name="reader">Reader for input items</param>
		/// <param name="writer">Writer for output items. It is not completed by the worker; the caller completes it once all workers have finished.</param>
		/// <param name="taskFunc">Action to execute for each item</param>
		/// <returns>Task for the worker, which runs until the reader is completed or the pipeline is cancelled</returns>
		public static Task AddTask<TInput, TOutput>(this AsyncPipeline pipeline, ChannelReader<TInput> reader, ChannelWriter<TOutput> writer, Func<TInput, CancellationToken, Task<TOutput>> taskFunc)
			=> pipeline.AddTask(ctx => ProcessItemsAsync(reader, writer, taskFunc, ctx));

		/// <summary>
		/// Adds several workers to process items from a channel
		/// </summary>
		/// <typeparam name="T">Input item type</typeparam>
		/// <param name="pipeline">Pipeline to add the workers to</param>
		/// <param name="count">Number of workers to add. Must not be negative.</param>
		/// <param name="reader">Reader for input items</param>
		/// <param name="taskFunc">Action to execute for each item</param>
		/// <returns>Tasks for the added workers</returns>
		public static Task[] AddTasks<T>(this AsyncPipeline pipeline, int count, ChannelReader<T> reader, Func<T, CancellationToken, Task> taskFunc)
		{
			ArgumentOutOfRangeException.ThrowIfNegative(count);

			Task[] tasks = new Task[count];
			for (int idx = 0; idx < count; idx++)
			{
				tasks[idx] = AddTask(pipeline, reader, taskFunc);
			}
			return tasks;
		}

		/// <summary>
		/// Adds several workers to process items from a channel
		/// </summary>
		/// <typeparam name="TInput">Input item type</typeparam>
		/// <typeparam name="TOutput">Output item type</typeparam>
		/// <param name="pipeline">Pipeline to add the workers to</param>
		/// <param name="count">Number of workers to add. Must not be negative.</param>
		/// <param name="reader">Reader for input items</param>
		/// <param name="writer">Writer for output items</param>
		/// <param name="taskFunc">Action to execute for each item</param>
		/// <returns>Tasks for the added workers</returns>
		public static Task[] AddTasks<TInput, TOutput>(this AsyncPipeline pipeline, int count, ChannelReader<TInput> reader, ChannelWriter<TOutput> writer, Func<TInput, CancellationToken, Task<TOutput>> taskFunc)
		{
			ArgumentOutOfRangeException.ThrowIfNegative(count);

			Task[] tasks = new Task[count];
			for (int idx = 0; idx < count; idx++)
			{
				tasks[idx] = AddTask(pipeline, reader, writer, taskFunc);
			}
			return tasks;
		}

		// Reads items until the reader is completed, invoking taskFunc for each one.
		// TryRead can lose the race to another worker sharing the same reader, in which case we simply wait again.
		static async Task ProcessItemsAsync<T>(ChannelReader<T> reader, Func<T, CancellationToken, Task> taskFunc, CancellationToken cancellationToken)
		{
			while (await reader.WaitToReadAsync(cancellationToken))
			{
				T? item;
				if (reader.TryRead(out item))
				{
					await taskFunc(item, cancellationToken);
				}
			}
		}

		// Reads items until the reader is completed, transforming each one and writing the result to the output channel.
		// The writer is deliberately left open since several workers may share it.
		static async Task ProcessItemsAsync<TInput, TOutput>(ChannelReader<TInput> reader, ChannelWriter<TOutput> writer, Func<TInput, CancellationToken, Task<TOutput>> taskFunc, CancellationToken cancellationToken)
		{
			while (await reader.WaitToReadAsync(cancellationToken))
			{
				TInput? input;
				if (reader.TryRead(out input))
				{
					TOutput output = await taskFunc(input, cancellationToken);
					await writer.WriteAsync(output, cancellationToken);
				}
			}
		}
	}
}