// Copyright Epic Games, Inc. All Rights Reserved. using System; using System.Buffers; using System.Buffers.Binary; using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Net.Sockets; using System.Threading; using System.Threading.Tasks; using EpicGames.Core; using EpicGames.Horde.Compute.Buffers; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; namespace EpicGames.Horde.Compute { /// /// Socket for sending and reciving data using a "push" model. The application can attach multiple writers to accept received data. /// public abstract class ComputeSocket { /// /// The current protocol number /// public abstract ComputeProtocol Protocol { get; } /// /// Logger for diagnostic messages /// public abstract ILogger Logger { get; } /// /// Attaches a buffer to receive data. /// /// Channel to receive data on /// Writer for the buffer to store received data public abstract void AttachRecvBuffer(int channelId, ComputeBuffer recvBuffer); /// /// Attaches a buffer to send data. /// /// Channel to receive data on /// Reader for the buffer to send data from public abstract void AttachSendBuffer(int channelId, ComputeBuffer sendBuffer); /// /// Creates a channel using a socket and receive buffer /// /// Channel id to send and receive data public ComputeChannel CreateChannel(int channelId) { using SharedMemoryBuffer recvBuffer = SharedMemoryBuffer.CreateNew(null, 65536); using SharedMemoryBuffer sendBuffer = SharedMemoryBuffer.CreateNew(null, 65536); return CreateChannel(channelId, recvBuffer, sendBuffer); } /// /// Creates a channel using a socket and receive buffer /// /// Channel id to send and receive data /// Buffer for receiving data /// Buffer for sending data public ComputeChannel CreateChannel(int channelId, ComputeBuffer recvBuffer, ComputeBuffer sendBuffer) { AttachRecvBuffer(channelId, recvBuffer); AttachSendBuffer(channelId, sendBuffer); using ComputeBufferReader recvBufferReader = recvBuffer.CreateReader(); using ComputeBufferWriter sendBufferWriter = sendBuffer.CreateWriter(); return new ComputeChannel(recvBufferReader, sendBufferWriter); } } internal enum IpcMessage { AttachRecvBuffer = 0, AttachSendBuffer = 1, } /// /// Provides functionality for attaching buffers for compute workers /// public sealed class WorkerComputeSocket : ComputeSocket, IDisposable { /// /// Name of the environment variable for passing the name of the compute channel /// public const string IpcEnvVar = "UE_HORDE_COMPUTE_IPC"; readonly ComputeBufferWriter _commandBufferWriter; readonly List _buffers = new List(); readonly ILogger _logger; /// public override ComputeProtocol Protocol => ComputeProtocol.Unknown; /// public override ILogger Logger => _logger; /// /// Creates a socket for a worker /// private WorkerComputeSocket(ComputeBufferWriter commandBufferWriter, ILogger logger) { _commandBufferWriter = commandBufferWriter; _logger = logger; } /// /// Opens a socket which allows a worker to communicate with the Horde Agent /// public static WorkerComputeSocket Open() => Open(NullLogger.Instance); /// /// Opens a socket which allows a worker to communicate with the Horde Agent /// /// Logger for diagnostic messages public static WorkerComputeSocket Open(ILogger logger) { string? baseName = Environment.GetEnvironmentVariable(IpcEnvVar); if (baseName == null) { throw new InvalidOperationException($"Environment variable {IpcEnvVar} is not defined; cannot connect as worker."); } return Open(baseName, logger); } /// /// Opens a socket which allows a worker to communicate with the Horde Agent /// /// Name of the command buffer /// Logger for diagnostic messages public static WorkerComputeSocket Open(string commandBufferName, ILogger logger) { using SharedMemoryBuffer commandBuffer = SharedMemoryBuffer.OpenExisting(commandBufferName); return new WorkerComputeSocket(commandBuffer.CreateWriter(), logger); } /// public void Dispose() { _commandBufferWriter.Dispose(); } /// public override void AttachRecvBuffer(int channelId, ComputeBuffer buffer) { _buffers.Add(buffer.AddRef()); string bufferName = ((SharedMemoryBufferDetail)buffer._detail).Name; AttachBuffer(IpcMessage.AttachRecvBuffer, channelId, bufferName); } /// public override void AttachSendBuffer(int channelId, ComputeBuffer buffer) { _buffers.Add(buffer.AddRef()); string bufferName = ((SharedMemoryBufferDetail)buffer._detail).Name; AttachBuffer(IpcMessage.AttachSendBuffer, channelId, bufferName); } void AttachBuffer(IpcMessage message, int channelId, string bufferName) { MemoryWriter writer = new MemoryWriter(_commandBufferWriter.GetWriteBuffer()); writer.WriteUnsignedVarInt((int)message); writer.WriteUnsignedVarInt(channelId); writer.WriteString(bufferName); _commandBufferWriter.AdvanceWritePosition(writer.Length); } } /// /// Operates a server that a child process can open a to. /// public sealed class WorkerComputeSocketBridge : IAsyncDisposable { readonly SharedMemoryBuffer _ipcBuffer; readonly ComputeBufferReader _ipcBufferReader; readonly BackgroundTask _backgroundTask; readonly ILogger _logger; /// /// Name of the buffer to pass via /// public string BufferName => _ipcBuffer.Name; /// /// Constructor /// private WorkerComputeSocketBridge(SharedMemoryBuffer ipcBuffer, ComputeBufferReader ipcBufferReader, BackgroundTask backgroundTask, ILogger logger) { _ipcBuffer = ipcBuffer; _ipcBufferReader = ipcBufferReader; _backgroundTask = backgroundTask; _logger = logger; } /// /// Creates a new server for /// /// Socket to connect to /// Logger for errors /// New server instance public static async Task CreateAsync(ComputeSocket socket, ILogger logger) { SharedMemoryBuffer? ipcBuffer = null; ComputeBufferReader? ipcBufferReader = null; BackgroundTask? backgroundTask = null; try { ipcBuffer = SharedMemoryBuffer.CreateNew(null, 1, 64 * 1024); ipcBufferReader = ipcBuffer.CreateReader(); backgroundTask = BackgroundTask.StartNew(ctx => ProcessIpcMessagesAsync(socket, ipcBufferReader, logger, ctx)); return new WorkerComputeSocketBridge(ipcBuffer, ipcBufferReader, backgroundTask, logger); } catch { if (backgroundTask != null) { await backgroundTask.DisposeAsync(); } ipcBufferReader?.Dispose(); ipcBuffer?.Dispose(); throw; } } /// public async ValueTask DisposeAsync() { await _backgroundTask.DisposeAsync(); _logger.LogDebug("Ipc message loop is complete"); _ipcBufferReader.Dispose(); _logger.LogDebug("Closed IPC reader"); _ipcBuffer.Dispose(); _logger.LogDebug("Closed IPC buffer"); } static async Task ProcessIpcMessagesAsync(ComputeSocket socket, ComputeBufferReader ipcReader, ILogger logger, CancellationToken cancellationToken) { List buffers = new(); try { List<(int, ComputeBufferWriter)> writers = new List<(int, ComputeBufferWriter)>(); while (await ipcReader.WaitToReadAsync(1, cancellationToken)) { ReadOnlyMemory memory = ipcReader.GetReadBuffer(); MemoryReader reader = new MemoryReader(memory); IpcMessage message = (IpcMessage)reader.ReadUnsignedVarInt(); try { switch (message) { case IpcMessage.AttachSendBuffer: { int channelId = (int)reader.ReadUnsignedVarInt(); string name = reader.ReadString(); logger.LogDebug("Attaching send buffer for channel {ChannelId} to {Name}", channelId, name); SharedMemoryBuffer buffer = SharedMemoryBuffer.OpenExisting(name); buffers.Add(buffer); socket.AttachSendBuffer(channelId, buffer); } break; case IpcMessage.AttachRecvBuffer: { int channelId = (int)reader.ReadUnsignedVarInt(); string name = reader.ReadString(); logger.LogDebug("Attaching recv buffer for channel {ChannelId} to {Name}", channelId, name); SharedMemoryBuffer buffer = SharedMemoryBuffer.OpenExisting(name); buffers.Add(buffer); socket.AttachRecvBuffer(channelId, buffer); } break; default: throw new InvalidOperationException($"Invalid IPC message: {message}"); } } catch (Exception ex) { logger.LogError(ex, "Exception while processing messages from child process: {Message}", ex.Message); } ipcReader.AdvanceReadPosition(memory.Length - reader.RemainingMemory.Length); } } catch (OperationCanceledException) { logger.LogDebug("Ipc message loop cancelled"); } finally { foreach (SharedMemoryBuffer buffer in buffers) { buffer.Dispose(); } } } } /// /// Manages a set of readers and writers to buffers across a transport layer /// public class RemoteComputeSocket : ComputeSocket, IAsyncDisposable { enum ControlMessageType { Detach = -2, KeepAlive = -3, } class RecvBuffer : IDisposable { public ComputeBufferWriter? _writer; public readonly SemaphoreSlim Semaphore = new SemaphoreSlim(1); public int _refCount = 1; public RecvBuffer(ComputeBufferWriter writer) => _writer = writer; public void AddRef() => Interlocked.Increment(ref _refCount); public void Release() { if (Interlocked.Decrement(ref _refCount) == 0) { Dispose(); } } public void Dispose() { _writer?.Dispose(); Semaphore.Dispose(); } } class SendBuffer : IDisposable { public ComputeBufferReader? _reader; public readonly SemaphoreSlim Semaphore = new SemaphoreSlim(1); [SuppressMessage("Usage", "CA2213:Disposable fields should be disposed")] public readonly BackgroundTask Task; int _refCount = 1; public SendBuffer(ComputeBufferReader reader, Func func) { _reader = reader; Task = BackgroundTask.StartNew(ctx => func(this, ctx)); } public void AddRef() => Interlocked.Increment(ref _refCount); public void Release() { if (Interlocked.Decrement(ref _refCount) == 0) { Dispose(); } } public void Dispose() { _reader?.Dispose(); Semaphore.Dispose(); } } readonly object _lockObject = new object(); bool _complete; bool _isClosing; readonly ComputeTransport _transport; readonly ComputeProtocol _protocol; readonly ILogger _logger; BackgroundTask? _recvTask; readonly Dictionary _recvBuffers = new Dictionary(); readonly SemaphoreSlim _sendSemaphore = new SemaphoreSlim(1, 1); readonly Dictionary _sendBuffers = new Dictionary(); /// public override ComputeProtocol Protocol => _protocol; /// public override ILogger Logger => _logger; /// /// Constructor /// /// Transport to communicate with the remote /// The protocol version number /// Logger for trace output public RemoteComputeSocket(ComputeTransport transport, ComputeProtocol protocol, ILogger logger) { _transport = transport; _protocol = protocol; _logger = logger; _recvTask = new BackgroundTask(ctx => RunRecvTaskAsync(_transport, ctx)); } /// /// Attempt to gracefully close the current connection and shutdown both ends of the transport /// public async ValueTask CloseAsync(CancellationToken cancellationToken) { lock (_lockObject) { if (_isClosing) { return; } _isClosing = true; } try { // Close the transport layer, freeing the remote end to shutdown. await _transport.MarkCompleteAsync(cancellationToken); // Close all the buffers await DetachAllBuffersAsync(true, true, cancellationToken); // Wait for the reader to stop if (_recvTask != null) { await _recvTask.DisposeAsync(); _recvTask = null; } } catch (SocketException se) when (se.SocketErrorCode is SocketError.ConnectionReset or SocketError.ConnectionAborted) { _logger.LogInformation("Socket already closed. Ignoring exception"); } } /// public async ValueTask DisposeAsync() { await CloseAsync(CancellationToken.None); _sendSemaphore.Dispose(); GC.SuppressFinalize(this); } /// /// Sends a keep alive message to the remote machine. Does not wait for a response. Designed to keep a connection open when the remote is eagerly trying to close it. /// /// Cancellation token for the operation public async Task SendKeepAliveMessageAsync(CancellationToken cancellationToken) { _logger.LogDebug("Sending ping message"); await SendInternalAsync(0, (int)ControlMessageType.KeepAlive, ReadOnlyMemory.Empty, cancellationToken); } async Task RunRecvTaskAsync(ComputeTransport transport, CancellationToken cancellationToken) { try { _logger.LogDebug("Started socket reader"); List detachTasks = new List(); byte[] header = new byte[8]; try { Memory last = Memory.Empty; // Process messages from the remote for (; ; ) { detachTasks.RemoveCompleteTasks(); // Read the next packet header if (!await transport.RecvOptionalAsync(header, cancellationToken)) { _logger.LogDebug("End of socket"); break; } // Parse the target buffer and packet size int id = BinaryPrimitives.ReadInt32LittleEndian(header); int size = BinaryPrimitives.ReadInt32LittleEndian(header.AsSpan(4)); // Dispatch it to the correct place if (size >= 0) { await ReadPacketAsync(transport, id, size, cancellationToken); } else if (size == (int)ControlMessageType.Detach) { detachTasks.Add(DetachRecvBufferAsync(id, cancellationToken)); } else if (size == (int)ControlMessageType.KeepAlive) { _logger.LogDebug("Received ping message"); } else { _logger.LogDebug("Unrecognized control message: {Message}", size); } } } catch (OperationCanceledException) { } // Mark all buffers as complete lock (_lockObject) { _complete = true; foreach (int channelIdx in _recvBuffers.Keys) { detachTasks.Add(DetachRecvBufferAsync(channelIdx, cancellationToken)); } } // Wait for all the detach tasks to finish if (detachTasks.Count > 0) { _logger.LogDebug("Waiting for detach tasks to complete..."); await Task.WhenAll(detachTasks).WaitAsync(cancellationToken); } _logger.LogDebug("Closing reader"); } catch (Exception e) { _logger.LogInformation(e, "Exception in background receive"); throw; } } async Task ReadPacketAsync(ComputeTransport transport, int id, int size, CancellationToken cancellationToken) { if (!await TryReadPacketAsync(transport, id, size, cancellationToken)) { _logger.LogInformation("Discarding {Size} bytes received on compute channel {Id}; no buffer attached?", size, id); int bufferSize = Math.Min(size, 65536); using (IMemoryOwner buffer = MemoryPool.Shared.Rent(bufferSize)) { for (int remaining = size; remaining > 0;) { int chunkSize = Math.Min(bufferSize, remaining); await transport.RecvAsync(buffer.Memory.Slice(0, chunkSize), cancellationToken); remaining -= chunkSize; } } } } async ValueTask TryReadPacketAsync(ComputeTransport transport, int id, int size, CancellationToken cancellationToken) { // Try to get the receive buffer for this channel RecvBuffer? recvBuffer; lock (_lockObject) { if (_recvBuffers.TryGetValue(id, out recvBuffer)) { recvBuffer.AddRef(); } else { return false; } } // Lock it and read a packet bool result = false; try { await recvBuffer.Semaphore.WaitAsync(cancellationToken); try { ComputeBufferWriter? writer = recvBuffer._writer; if (writer != null) { Memory memory = writer.GetWriteBuffer(); while (memory.Length < size) { _logger.LogDebug("No space in buffer {Id}, flushing", id); await writer.WaitToWriteAsync(size, cancellationToken); memory = writer.GetWriteBuffer(); } for (int offset = 0; offset < size;) { int read = await transport.RecvAsync(memory.Slice(offset, size - offset), cancellationToken); if (read == 0) { // Return true to avoid logic in ReadPacketAsync() from trying to read the whole message from a closed stream _logger.LogDebug("Unexpected end of stream while parsing message for channel {Id}; discarding message.", id); return true; } offset += read; } writer.AdvanceWritePosition(size); result = true; } } finally { recvBuffer.Semaphore.Release(); } } finally { recvBuffer.Release(); } return result; } class SendSegment : ReadOnlySequenceSegment { public void Set(ReadOnlyMemory memory, ReadOnlySequenceSegment? next, long runningIndex) { Memory = memory; Next = next; RunningIndex = runningIndex; } } readonly byte[] _header = new byte[8]; readonly SendSegment _headerSegment = new SendSegment(); readonly SendSegment _bodySegment = new SendSegment(); /// async ValueTask SendAsync(int id, ReadOnlyMemory memory, CancellationToken cancellationToken = default) { if (memory.Length > 0) { await SendInternalAsync(id, memory.Length, memory, cancellationToken); } } /// public ValueTask MarkCompleteAsync(int id, CancellationToken cancellationToken = default) => SendInternalAsync(id, (int)ControlMessageType.Detach, ReadOnlyMemory.Empty, cancellationToken); async ValueTask SendInternalAsync(int id, int size, ReadOnlyMemory memory, CancellationToken cancellationToken) { await _sendSemaphore.WaitAsync(cancellationToken); try { BinaryPrimitives.WriteInt32LittleEndian(_header, id); BinaryPrimitives.WriteInt32LittleEndian(_header.AsSpan(4), size); _headerSegment.Set(_header, _bodySegment, 0); _bodySegment.Set(memory, null, _header.Length); ReadOnlySequence sequence = new ReadOnlySequence(_headerSegment, 0, _bodySegment, memory.Length); await _transport.SendAsync(sequence, cancellationToken); } finally { _sendSemaphore.Release(); } } /// public override void AttachRecvBuffer(int channelId, ComputeBuffer recvBuffer) { _logger.LogDebug("Attaching recv buffer {Id}", channelId); lock (_lockObject) { if (_complete) { throw new InvalidOperationException($"Cannot attach new buffer to channel {channelId} after socket is closed"); } if (_recvBuffers.ContainsKey(channelId)) { throw new InvalidOperationException($"Buffer is already attached to channel {channelId}"); } _recvBuffers.Add(channelId, new RecvBuffer(recvBuffer.CreateWriter())); // Only start the receive task once we have a buffer to receive data, otherwise we discard data from the remote if (_recvTask is { Task: null }) { _recvTask.Start(); } } } [SuppressMessage("Reliability", "CA2000:Dispose objects before losing scope")] async Task DetachRecvBufferAsync(int id, CancellationToken cancellationToken) { _logger.LogDebug("Detaching recv buffer {Id}", id); // Get the current receive buffer RecvBuffer? recvBuffer; lock (_lockObject) { if (!_recvBuffers.TryGetValue(id, out recvBuffer)) { _logger.LogDebug("Buffer {Id} has already been detached", id); return; } recvBuffer.AddRef(); // Note: adding extra ref here } // Release the writer await recvBuffer.Semaphore.WaitAsync(cancellationToken); try { recvBuffer._writer?.Dispose(); recvBuffer._writer = null; } finally { recvBuffer.Semaphore.Release(); recvBuffer.Release(); // Ref added above } // Remove the buffer lock (_lockObject) { if (_recvBuffers.Remove(id, out recvBuffer)) { recvBuffer.Release(); // Original ref from _recvBuffers } } } async Task DetachAllBuffersAsync(bool recvBuffers, bool sendBuffers, CancellationToken cancellationToken) { int[] recvChannelIds; int[] sendChannelIds; lock (_lockObject) { _complete = true; recvChannelIds = _recvBuffers.Keys.ToArray(); sendChannelIds = _sendBuffers.Keys.ToArray(); } List tasks = new List(); if (recvBuffers) { foreach (int recvChannelId in recvChannelIds) { tasks.Add(DetachRecvBufferAsync(recvChannelId, cancellationToken)); } } if (sendBuffers) { foreach (int sendChannelId in sendChannelIds) { tasks.Add(DetachSendBufferAsync(sendChannelId, cancellationToken)); } } await Task.WhenAll(tasks).WaitAsync(cancellationToken); } /// public override void AttachSendBuffer(int channelId, ComputeBuffer sendBuffer) { _logger.LogDebug("Attaching send buffer {Id}", channelId); lock (_lockObject) { if (_sendBuffers.ContainsKey(channelId)) { throw new InvalidOperationException($"Buffer is already attached to channel {channelId}"); } ComputeBufferReader sendBufferReader = sendBuffer.CreateReader(); SendBuffer sendBufferInfo = new SendBuffer(sendBufferReader, (buffer, ctx) => SendFromBufferAsync(channelId, buffer, ctx)); _sendBuffers.Add(channelId, sendBufferInfo); } } [SuppressMessage("Reliability", "CA2000:Dispose objects before losing scope")] async Task DetachSendBufferAsync(int channelId, CancellationToken cancellationToken) { SendBuffer? sendBuffer = null; try { // Get the current send buffer state lock (_lockObject) { if (!_sendBuffers.TryGetValue(channelId, out sendBuffer)) { _logger.LogWarning("No buffer is attached to channel {ChannelId}", channelId); return; } sendBuffer.AddRef(); } // Try to stop the send task TimeSpan timeout = TimeSpan.FromSeconds(10); try { await sendBuffer.Task.StopAsync(cancellationToken).WaitAsync(timeout, cancellationToken); } catch (OperationCanceledException) { throw; } catch (TimeoutException ex) { _logger.LogInformation(ex, "Send task did not terminate gracefully in {Timeout} ms. Tearing down socket buffers...", timeout.TotalMilliseconds); } catch (Exception ex) { _logger.LogError(ex, "Exception while stopping send buffer task: {Message}", ex.Message); } // Release the reader await sendBuffer.Semaphore.WaitAsync(cancellationToken); try { sendBuffer._reader?.Dispose(); sendBuffer._reader = null; } finally { sendBuffer.Semaphore.Release(); } } finally { sendBuffer?.Release(); // Added above } // Force the send task to complete try { await sendBuffer.Task.DisposeAsync(); } catch (Exception ex) { _logger.LogError(ex, "Exception disposing send buffer task: {Message}", ex.Message); } // Remove the buffer from the dictionary lock (_lockObject) { if (_sendBuffers.Remove(channelId, out sendBuffer)) { sendBuffer.Release(); // For _sendBuffers } } } async Task SendFromBufferAsync(int channelId, SendBuffer sendBuffer, CancellationToken cancellationToken) { try { ComputeBufferReader reader = sendBuffer._reader!; while (!cancellationToken.IsCancellationRequested) { ReadOnlyMemory memory = reader.GetReadBuffer(); if (memory.Length > 0 && !_isClosing) { await SendAsync(channelId, memory, cancellationToken); reader.AdvanceReadPosition(memory.Length); } if (reader.IsComplete) { await MarkCompleteAsync(channelId, cancellationToken); break; } await reader.WaitToReadAsync(1, cancellationToken); } } catch (OperationCanceledException ex) { _logger.LogDebug(ex, "Background send task cancelled for channel {ChannelId}", channelId); } catch (Exception ex) { if (!_isClosing) { _logger.LogInformation(ex, "Exception in background send: {Message}", ex.Message); } } } } }