// Copyright Epic Games, Inc. All Rights Reserved.

using System;
using System.Buffers.Binary;
using System.Diagnostics.CodeAnalysis;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using EpicGames.Core;
using EpicGames.Horde.Compute.Buffers;
using Microsoft.Extensions.Logging;

namespace EpicGames.Horde.Compute
{
	/// <summary>
	/// Implementation of a compute channel. Messages are framed in the underlying buffers as a
	/// <see cref="HeaderLength"/>-byte header (1 byte type, 4 byte little-endian payload length) followed by the payload.
	/// </summary>
	public sealed class AgentMessageChannel : IDisposable
	{
		// Length of a message header. Consists of a 1 byte type field, followed by 4 byte length field.
		const int HeaderLength = 5;

		/// <summary>
		/// Allows creating new messages directly in the send buffer. The payload is written after a
		/// reserved header-sized gap at the current write position; <see cref="Send"/> fills in the header.
		/// </summary>
		class MessageBuilder : IAgentMessageBuilder
		{
			readonly AgentMessageChannel _channel;
			readonly ComputeBufferWriter _sendBufferWriter;
			readonly AgentMessageType _type;
			int _length;

			/// <inheritdoc/>
			public int Length => _length;

			/// <summary>
			/// Constructor
			/// </summary>
			/// <param name="channel">Channel that owns this builder</param>
			/// <param name="sendBufferWriter">Writer for the channel's send buffer</param>
			/// <param name="type">Type of the message being built</param>
			public MessageBuilder(AgentMessageChannel channel, ComputeBufferWriter sendBufferWriter, AgentMessageType type)
			{
				_channel = channel;
				_sendBufferWriter = sendBufferWriter;
				_type = type;
				_length = 0;
			}

			/// <summary>
			/// Releases this builder so that the channel can create another one.
			/// </summary>
			public void Dispose()
			{
				if (_channel._currentBuilder == this)
				{
					_channel._currentBuilder = null;
				}
			}

			/// <summary>
			/// Writes the header for the payload accumulated so far and commits header + payload to the send buffer.
			/// </summary>
			public void Send()
			{
				Span<byte> header = _sendBufferWriter.GetWriteBuffer().Span;
				header[0] = (byte)_type;
				BinaryPrimitives.WriteInt32LittleEndian(header.Slice(1, 4), _length);

				// if (_channel._logger.IsEnabled(LogLevel.Trace))
				// {
				// 	_channel.LogMessageInfo("SEND", _type, _sendBufferWriter.GetWriteBuffer().Slice(HeaderLength, _length).Span);
				// }

				_sendBufferWriter.AdvanceWritePosition(HeaderLength + _length);
				_length = 0;
			}

			/// <inheritdoc/>
			public void Advance(int count) => _length += count;

			/// <inheritdoc/>
			/// <remarks>
			/// <paramref name="sizeHint"/> is not consulted here; the available space is whatever was reserved
			/// by <see cref="AgentMessageChannel.CreateMessageAsync"/>.
			/// </remarks>
			public Memory<byte> GetMemory(int sizeHint = 0) => _sendBufferWriter.GetWriteBuffer().Slice(HeaderLength + _length);

			/// <inheritdoc/>
			public Span<byte> GetSpan(int sizeHint = 0) => GetMemory(sizeHint).Span;
		}

		readonly int _channelId;
		readonly ComputeProtocol _protocol;
		readonly ComputeBufferReader _recvBufferReader;
		readonly ComputeBufferWriter _sendBufferWriter;

		// Can lock chunked memory writer to acquire pointer
		readonly ILogger _logger;

		MessageBuilder? _currentBuilder;
		bool _disposed;

		/// <summary>
		/// The negotiated compute protocol version number
		/// </summary>
		public ComputeProtocol Protocol => _protocol;

		/// <summary>
		/// Constructor
		/// </summary>
		/// <param name="channelId">Identifier for the channel</param>
		/// <param name="protocol">Protocol version number</param>
		/// <param name="recvBufferReader">Reader for incoming messages; a new reference is taken via AddRef</param>
		/// <param name="sendBufferWriter">Writer for outgoing messages; a new reference is taken via AddRef</param>
		/// <param name="logger">Logger for diagnostic output</param>
		public AgentMessageChannel(int channelId, ComputeProtocol protocol, ComputeBufferReader recvBufferReader, ComputeBufferWriter sendBufferWriter, ILogger logger)
		{
			_channelId = channelId;
			_protocol = protocol;
			_recvBufferReader = recvBufferReader.AddRef();
			_sendBufferWriter = sendBufferWriter.AddRef();
			_logger = logger;
		}

		/// <summary>
		/// Constructor. Attaches the given buffers to the socket for this channel id.
		/// </summary>
		/// <param name="socket">Socket to attach the buffers to</param>
		/// <param name="channelId">Identifier for the channel</param>
		/// <param name="recvBuffer">Buffer that incoming messages are read from</param>
		/// <param name="sendBuffer">Buffer that outgoing messages are written to</param>
		/// <param name="logger">Logger for diagnostic output</param>
		public AgentMessageChannel(ComputeSocket socket, int channelId, ComputeBuffer recvBuffer, ComputeBuffer sendBuffer, ILogger logger)
		{
			socket.AttachRecvBuffer(channelId, recvBuffer);
			socket.AttachSendBuffer(channelId, sendBuffer);

			_channelId = channelId;
			_protocol = socket.Protocol;
			_recvBufferReader = recvBuffer.CreateReader();
			_sendBufferWriter = sendBuffer.CreateWriter();
			_logger = logger;
		}

		/// <summary>
		/// Releases the active message builder (if any) and the send/receive buffer handles.
		/// Safe to call more than once; subsequent calls do nothing.
		/// </summary>
		public void Dispose()
		{
			// Guard against releasing the reader/writer references twice.
			if (_disposed)
			{
				return;
			}
			_disposed = true;

			_currentBuilder?.Dispose();
			_sendBufferWriter.Dispose();
			_recvBufferReader.Dispose();
		}

		/// <summary>
		/// Mark the send buffer as complete
		/// </summary>
		public void MarkComplete()
		{
			_sendBufferWriter.MarkComplete();
		}

		/// <summary>
		/// Reads the next message from the receive buffer, waiting until a complete header and payload are available.
		/// </summary>
		/// <param name="cancellationToken">Cancellation token for the operation</param>
		/// <returns>
		/// The next message, whose data is a slice of the receive buffer. Returns a message of type
		/// <see cref="AgentMessageType.None"/> with empty data once the receive buffer reports completion.
		/// </returns>
		public async ValueTask<AgentMessage> ReceiveAsync(CancellationToken cancellationToken)
		{
			while (!_recvBufferReader.IsComplete)
			{
				ReadOnlyMemory<byte> memory = _recvBufferReader.GetReadBuffer();
				if (memory.Length < HeaderLength)
				{
					await _recvBufferReader.WaitToReadAsync(HeaderLength, cancellationToken);
					continue;
				}

				int messageLength = BinaryPrimitives.ReadInt32LittleEndian(memory.Span.Slice(1, 4));
				if (memory.Length < HeaderLength + messageLength)
				{
					await _recvBufferReader.WaitToReadAsync(HeaderLength + messageLength, cancellationToken);
					continue;
				}

				AgentMessageType type = (AgentMessageType)memory.Span[0];
				AgentMessage message = new AgentMessage(type, memory.Slice(HeaderLength, messageLength));
				// if (_logger.IsEnabled(LogLevel.Trace))
				// {
				// 	LogMessageInfo("RECV", message.Type, message.Data.Span);
				// }
				_recvBufferReader.AdvanceReadPosition(HeaderLength + messageLength);
				return message;
			}
			return new AgentMessage(AgentMessageType.None, ReadOnlyMemory<byte>.Empty);
		}

		/// <summary>
		/// Logs a trace line describing a message: direction, channel, type, length and up to the first 16 payload bytes in hex.
		/// </summary>
		[SuppressMessage("CodeQuality", "IDE0051:Remove unused private members", Justification = "Log calls disabled")]
		void LogMessageInfo(string verb, AgentMessageType type, ReadOnlySpan<byte> data)
		{
			StringBuilder bytes = new StringBuilder();
			for (int offset = 0; offset < 16 && offset < data.Length; offset++)
			{
				bytes.Append($" {data[offset]:X2}");
			}
			if (data.Length > 16)
			{
				bytes.Append("..");
			}

			_logger.LogTrace("{Verb} {ChannelId} {Type,-18} [{Length,10:n0}] = {Bytes}", verb, _channelId, type, data.Length, bytes.ToString());
		}

		/// <summary>
		/// Creates a builder for a new message, waiting until the send buffer has room for a header plus <paramref name="maxSize"/> payload bytes.
		/// Only one builder may be active at a time; dispose of it before creating another.
		/// </summary>
		/// <param name="type">Type of the message</param>
		/// <param name="maxSize">Maximum size of the message payload, in bytes</param>
		/// <param name="cancellationToken">Cancellation token for the operation</param>
		/// <returns>New builder for the message</returns>
		/// <exception cref="InvalidOperationException">A previous builder has not been disposed</exception>
		/// <exception cref="ArgumentOutOfRangeException"><paramref name="maxSize"/> is negative</exception>
		public async ValueTask<IAgentMessageBuilder> CreateMessageAsync(AgentMessageType type, int maxSize, CancellationToken cancellationToken)
		{
			if (_currentBuilder != null)
			{
				throw new InvalidOperationException("Only one writer can be active at a time. Dispose of the previous writer first.");
			}
			if (maxSize < 0)
			{
				throw new ArgumentOutOfRangeException(nameof(maxSize), maxSize, "Maximum message size cannot be negative.");
			}

			// The builder writes the header in front of the payload, so reserve room for both.
			await _sendBufferWriter.WaitToWriteAsync(HeaderLength + maxSize, cancellationToken);

			_currentBuilder = new MessageBuilder(this, _sendBufferWriter, type);
			return _currentBuilder;
		}
	}

	/// <summary>
	/// Extension methods for creating agent message channels on a socket, and for sending/receiving messages on them
	/// </summary>
	public static class AgentMessageChannelExtensions
	{
		// Default size of the send and receive buffers for a new channel
		const int DefaultBufferSize = 65536;

		// Default maximum payload size for messages created without an explicit size
		const int DefaultMaxMessageSize = 1024;

		/// <summary>
		/// Creates a message channel with the given identifier
		/// </summary>
		/// <param name="socket">Socket to create a channel for</param>
		/// <param name="channelId">Identifier for the channel</param>
		public static AgentMessageChannel CreateAgentMessageChannel(this ComputeSocket socket, int channelId)
			=> socket.CreateAgentMessageChannel(channelId, DefaultBufferSize);

		/// <summary>
		/// Creates a message channel with the given identifier
		/// </summary>
		/// <param name="socket">Socket to create a channel for</param>
		/// <param name="channelId">Identifier for the channel</param>
		/// <param name="bufferSize">Size of the send and receive buffer</param>
		public static AgentMessageChannel CreateAgentMessageChannel(this ComputeSocket socket, int channelId, int bufferSize)
			=> socket.CreateAgentMessageChannel(channelId, bufferSize, bufferSize);

		/// <summary>
		/// Creates a message channel with the given identifier
		/// </summary>
		/// <param name="socket">Socket to create a channel for</param>
		/// <param name="channelId">Identifier for the channel</param>
		/// <param name="sendBufferSize">Size of the send buffer</param>
		/// <param name="recvBufferSize">Size of the receive buffer</param>
		public static AgentMessageChannel CreateAgentMessageChannel(this ComputeSocket socket, int channelId, int sendBufferSize, int recvBufferSize)
		{
			// NOTE(review): the local buffer references are disposed on return; this assumes the socket
			// attachment and the reader/writer created by the channel keep the buffers alive — confirm.
			using ComputeBuffer sendBuffer = new PooledBuffer(sendBufferSize);
			using ComputeBuffer recvBuffer = new PooledBuffer(recvBufferSize);

			// Named arguments: the constructor takes the receive buffer before the send buffer.
			return new AgentMessageChannel(socket, channelId, recvBuffer: recvBuffer, sendBuffer: sendBuffer, logger: socket.Logger);
		}

		/// <summary>
		/// Reads a message from the channel
		/// </summary>
		/// <param name="channel">Channel to receive on</param>
		/// <param name="type">Expected type of the message</param>
		/// <param name="cancellationToken">Cancellation token for the operation</param>
		/// <returns>Data for a message that was read. Must be disposed.</returns>
		public static async ValueTask<AgentMessage> ReceiveAsync(this AgentMessageChannel channel, AgentMessageType type, CancellationToken cancellationToken = default)
		{
			AgentMessage message = await channel.ReceiveAsync(cancellationToken);
			message.ThrowIfUnexpectedType(type);
			return message;
		}

		/// <summary>
		/// Creates a new builder for a message with a default maximum payload size
		/// </summary>
		/// <param name="channel">Channel to send on</param>
		/// <param name="type">Type of the message</param>
		/// <param name="cancellationToken">Cancellation token for the operation</param>
		/// <returns>New builder for messages</returns>
		public static ValueTask<IAgentMessageBuilder> CreateMessageAsync(this AgentMessageChannel channel, AgentMessageType type, CancellationToken cancellationToken)
		{
			return channel.CreateMessageAsync(type, DefaultMaxMessageSize, cancellationToken);
		}

		/// <summary>
		/// Forwards an existing message across a channel
		/// </summary>
		/// <param name="channel">Channel to send on</param>
		/// <param name="message">The message to be sent</param>
		/// <param name="cancellationToken">Cancellation token for the operation</param>
		public static async ValueTask SendAsync(this AgentMessageChannel channel, AgentMessage message, CancellationToken cancellationToken)
		{
			using (IAgentMessageBuilder builder = await channel.CreateMessageAsync(message.Type, message.Data.Length, cancellationToken))
			{
				builder.WriteFixedLengthBytes(message.Data.Span);
				builder.Send();
			}
		}
	}
}