// Copyright Epic Games, Inc. All Rights Reserved.
using System;
using System.Buffers.Binary;
using System.Diagnostics.CodeAnalysis;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using EpicGames.Core;
using EpicGames.Horde.Compute.Buffers;
using Microsoft.Extensions.Logging;
namespace EpicGames.Horde.Compute
{
/// <summary>
/// Implementation of a compute channel
/// </summary>
public sealed class AgentMessageChannel : IDisposable
{
// Length of a message header. Consists of a 1 byte type field, followed by 4 byte length field.
const int HeaderLength = 5;
/// <summary>
/// Allows creating new messages in rented memory
/// </summary>
class MessageBuilder : IAgentMessageBuilder
{
    readonly AgentMessageChannel _channel;
    readonly ComputeBufferWriter _sendBufferWriter;
    readonly AgentMessageType _type;

    // Number of payload bytes committed so far through Advance(); the header is not counted.
    int _length;

    /// <inheritdoc/>
    public int Length => _length;

    /// <summary>
    /// Constructor
    /// </summary>
    /// <param name="channel">Channel that owns this builder</param>
    /// <param name="sendBufferWriter">Writer the message is assembled in</param>
    /// <param name="type">Type written into the message header by <see cref="Send"/></param>
    public MessageBuilder(AgentMessageChannel channel, ComputeBufferWriter sendBufferWriter, AgentMessageType type)
    {
        _channel = channel;
        _sendBufferWriter = sendBufferWriter;
        _type = type;
        _length = 0;
    }

    /// <summary>
    /// Releases the channel's "active builder" slot if this instance currently occupies it,
    /// which allows a new builder to be created on the channel.
    /// </summary>
    public void Dispose()
    {
        if (_channel._currentBuilder == this)
        {
            _channel._currentBuilder = null;
        }
    }

    /// <summary>
    /// Writes the message header in front of the payload and commits header and payload to the send buffer.
    /// The payload length is reset to zero afterwards.
    /// </summary>
    public void Send()
    {
        // Header layout: 1 byte message type followed by a 4 byte little-endian payload length.
        Span<byte> header = _sendBufferWriter.GetWriteBuffer().Span;
        header[0] = (byte)_type;
        BinaryPrimitives.WriteInt32LittleEndian(header.Slice(1, 4), _length);
        // if (_channel._logger.IsEnabled(LogLevel.Trace))
        // {
        // _channel.LogMessageInfo("SEND", _type, _sendBufferWriter.GetWriteBuffer().Slice(HeaderLength, _length).Span);
        // }
        _sendBufferWriter.AdvanceWritePosition(HeaderLength + _length);
        _length = 0;
    }

    /// <inheritdoc/>
    public void Advance(int count) => _length += count;

    /// <inheritdoc/>
    // Returns the write buffer past the reserved header and the payload written so far.
    // Note that sizeHint is not consulted here.
    public Memory<byte> GetMemory(int sizeHint = 0) => _sendBufferWriter.GetWriteBuffer().Slice(HeaderLength + _length);

    /// <inheritdoc/>
    public Span<byte> GetSpan(int sizeHint = 0) => GetMemory(sizeHint).Span;
}
// Identifier of this channel; in the visible code it is only read when formatting trace output.
readonly int _channelId;
// Value exposed through the Protocol property.
readonly ComputeProtocol _protocol;
// Reader that incoming messages are parsed from.
readonly ComputeBufferReader _recvBufferReader;
// Writer that outgoing messages are assembled in.
readonly ComputeBufferWriter _sendBufferWriter;
// Can lock chunked memory writer to acquire pointer
readonly ILogger _logger;
// The builder handed out by CreateMessageAsync that has not been disposed yet, or null if there is none.
MessageBuilder? _currentBuilder;
/// <summary>
/// The negotiated compute protocol version number
/// </summary>
public ComputeProtocol Protocol => _protocol;
/// <summary>
/// Creates a channel on top of an existing reader/writer pair.
/// </summary>
/// <param name="channelId">Identifier for the channel</param>
/// <param name="protocol">Protocol version number</param>
/// <param name="recvBufferReader">Reader to receive messages from; the channel stores the result of calling AddRef() on it</param>
/// <param name="sendBufferWriter">Writer to send messages to; the channel stores the result of calling AddRef() on it</param>
/// <param name="logger">Logger for diagnostic output</param>
public AgentMessageChannel(int channelId, ComputeProtocol protocol, ComputeBufferReader recvBufferReader, ComputeBufferWriter sendBufferWriter, ILogger logger)
{
    (_channelId, _protocol, _logger) = (channelId, protocol, logger);

    // The channel disposes these in Dispose(), so it keeps its own references rather than the caller's.
    _recvBufferReader = recvBufferReader.AddRef();
    _sendBufferWriter = sendBufferWriter.AddRef();
}
/// <summary>
/// Creates a channel by attaching a pair of buffers to a socket.
/// </summary>
/// <param name="socket">Socket the buffers are attached to; also supplies the protocol value</param>
/// <param name="channelId">Identifier for the channel</param>
/// <param name="recvBuffer">Buffer attached as the receive buffer; a reader is created over it</param>
/// <param name="sendBuffer">Buffer attached as the send buffer; a writer is created over it</param>
/// <param name="logger">Logger for diagnostic output</param>
public AgentMessageChannel(ComputeSocket socket, int channelId, ComputeBuffer recvBuffer, ComputeBuffer sendBuffer, ILogger logger)
{
    // Register both buffers with the socket before creating the local reader/writer.
    socket.AttachRecvBuffer(channelId, recvBuffer);
    socket.AttachSendBuffer(channelId, sendBuffer);

    (_channelId, _protocol, _logger) = (channelId, socket.Protocol, logger);

    _recvBufferReader = recvBuffer.CreateReader();
    _sendBufferWriter = sendBuffer.CreateWriter();
}
/// <summary>
/// Releases the active message builder, if any, then disposes the send writer and the receive reader.
/// </summary>
public void Dispose()
{
    MessageBuilder? activeBuilder = _currentBuilder;
    if (activeBuilder != null)
    {
        activeBuilder.Dispose();
    }

    _sendBufferWriter.Dispose();
    _recvBufferReader.Dispose();
}
/// <summary>
/// Mark the send buffer as complete
/// </summary>
public void MarkComplete() => _sendBufferWriter.MarkComplete();
/// <summary>
/// Reads the next message from the receive buffer, waiting until a whole message (header and payload) is available.
/// </summary>
/// <param name="cancellationToken">Cancellation token for the operation</param>
/// <returns>
/// The next message, or a message of type <see cref="AgentMessageType.None"/> with empty data once the
/// receive buffer reports that it is complete.
/// </returns>
public async ValueTask<AgentMessage> ReceiveAsync(CancellationToken cancellationToken)
{
    while (!_recvBufferReader.IsComplete)
    {
        ReadOnlyMemory<byte> memory = _recvBufferReader.GetReadBuffer();

        // Need the full header before the payload length can be read.
        if (memory.Length < HeaderLength)
        {
            await _recvBufferReader.WaitToReadAsync(HeaderLength, cancellationToken);
            continue;
        }

        // Header layout: 1 byte message type followed by a 4 byte little-endian payload length.
        int messageLength = BinaryPrimitives.ReadInt32LittleEndian(memory.Span.Slice(1, 4));
        int totalLength = HeaderLength + messageLength;

        // Wait until the payload has arrived as well, then re-read the buffer from the top of the loop.
        if (memory.Length < totalLength)
        {
            await _recvBufferReader.WaitToReadAsync(totalLength, cancellationToken);
            continue;
        }

        AgentMessageType type = (AgentMessageType)memory.Span[0];

        // NOTE(review): the payload is handed to AgentMessage as a slice of the read buffer and the read
        // position is advanced straight afterwards - confirm AgentMessage takes its own copy of the data.
        AgentMessage message = new AgentMessage(type, memory.Slice(HeaderLength, messageLength));
        // if (_logger.IsEnabled(LogLevel.Trace))
        // {
        // LogMessageInfo("RECV", message.Type, message.Data.Span);
        // }
        _recvBufferReader.AdvanceReadPosition(totalLength);
        return message;
    }

    return new AgentMessage(AgentMessageType.None, ReadOnlyMemory<byte>.Empty);
}
/// <summary>
/// Writes a trace log entry describing a message: its type, its length and a hex dump of its leading bytes.
/// </summary>
/// <param name="verb">Label written at the start of the log entry</param>
/// <param name="type">Type of the message</param>
/// <param name="data">Message payload; at most the first 16 bytes are dumped</param>
[SuppressMessage("CodeQuality", "IDE0051:Remove unused private members", Justification = "Log calls disabled")]
void LogMessageInfo(string verb, AgentMessageType type, ReadOnlySpan<byte> data)
{
    const int MaxBytesShown = 16;

    StringBuilder bytes = new StringBuilder();

    int shownCount = Math.Min(MaxBytesShown, data.Length);
    for (int offset = 0; offset < shownCount; offset++)
    {
        bytes.Append($" {data[offset]:X2}");
    }

    // Mark the dump as truncated when the payload is longer than what was printed.
    if (data.Length > MaxBytesShown)
    {
        bytes.Append("..");
    }

    _logger.LogTrace("{Verb} {ChannelId} {Type,-18} [{Length,10:n0}] = {Bytes}", verb, _channelId, type, data.Length, bytes.ToString());
}
/// <summary>
/// Creates a builder for a new outgoing message, waiting until the send buffer has room for it.
/// </summary>
/// <param name="type">Type of the message</param>
/// <param name="maxSize">Maximum size of the message payload, not counting the header</param>
/// <param name="cancellationToken">Cancellation token for the operation</param>
/// <returns>Builder for the message. It must be disposed before another message can be created.</returns>
/// <exception cref="InvalidOperationException">A previously created builder has not been disposed yet</exception>
public async ValueTask<IAgentMessageBuilder> CreateMessageAsync(AgentMessageType type, int maxSize, CancellationToken cancellationToken)
{
    if (_currentBuilder != null)
    {
        throw new InvalidOperationException("Only one writer can be active at a time. Dispose of the previous writer first.");
    }

    // The builder places the payload after a HeaderLength-byte header and commits both together, so the
    // header has to be reserved on top of the payload size. Otherwise a caller that fills all maxSize
    // bytes would run past the space that was waited for.
    await _sendBufferWriter.WaitToWriteAsync(HeaderLength + maxSize, cancellationToken);

    MessageBuilder builder = new MessageBuilder(this, _sendBufferWriter, type);
    _currentBuilder = builder;
    return builder;
}
}
/// <summary>
/// Extension methods to allow creating channels from leases
/// </summary>
public static class AgentMessageChannelExtensions
{
/// <summary>
/// Creates a message channel with the given identifier, using 65536-byte send and receive buffers
/// </summary>
/// <param name="socket">Socket to create a channel for</param>
/// <param name="channelId">Identifier for the channel</param>
/// <returns>The new channel</returns>
public static AgentMessageChannel CreateAgentMessageChannel(this ComputeSocket socket, int channelId)
{
    const int DefaultBufferSize = 65536;
    return socket.CreateAgentMessageChannel(channelId, DefaultBufferSize);
}
/// <summary>
/// Creates a message channel with the given identifier, using the same size for both buffers
/// </summary>
/// <param name="socket">Socket to create a channel for</param>
/// <param name="channelId">Identifier for the channel</param>
/// <param name="bufferSize">Size of the send and receive buffer</param>
/// <returns>The new channel</returns>
public static AgentMessageChannel CreateAgentMessageChannel(this ComputeSocket socket, int channelId, int bufferSize)
{
    int sendBufferSize = bufferSize;
    int recvBufferSize = bufferSize;
    return socket.CreateAgentMessageChannel(channelId, sendBufferSize, recvBufferSize);
}
/// <summary>
/// Creates a message channel with the given identifier
/// </summary>
/// <param name="socket">Socket to create a channel for</param>
/// <param name="channelId">Identifier for the channel</param>
/// <param name="sendBufferSize">Size of the send buffer</param>
/// <param name="recvBufferSize">Size of the receive buffer</param>
/// <returns>The new channel</returns>
public static AgentMessageChannel CreateAgentMessageChannel(this ComputeSocket socket, int channelId, int sendBufferSize, int recvBufferSize)
{
    // NOTE(review): these local handles are disposed when the method returns; presumably the socket and the
    // channel's reader/writer keep the underlying buffers alive - confirm against ComputeBuffer.
    using ComputeBuffer sendBuffer = new PooledBuffer(sendBufferSize);
    using ComputeBuffer recvBuffer = new PooledBuffer(recvBufferSize);

    // The constructor takes the receive buffer BEFORE the send buffer. Named arguments keep each buffer
    // (and therefore each requested size) bound to the matching role.
    return new AgentMessageChannel(socket, channelId, recvBuffer: recvBuffer, sendBuffer: sendBuffer, logger: socket.Logger);
}
/// <summary>
/// Reads a message from the channel and checks that it has the expected type
/// </summary>
/// <param name="channel">Channel to receive on</param>
/// <param name="type">Expected type of the message</param>
/// <param name="cancellationToken">Cancellation token for the operation</param>
/// <returns>Data for a message that was read. Must be disposed.</returns>
public static async ValueTask<AgentMessage> ReceiveAsync(this AgentMessageChannel channel, AgentMessageType type, CancellationToken cancellationToken = default)
{
    AgentMessage received = await channel.ReceiveAsync(cancellationToken);

    // The type check is delegated to the message itself.
    received.ThrowIfUnexpectedType(type);
    return received;
}
/// <summary>
/// Creates a new builder for a message, requesting a maximum size of 1024 bytes
/// </summary>
/// <param name="channel">Channel to send on</param>
/// <param name="type">Type of the message</param>
/// <param name="cancellationToken">Cancellation token for the operation</param>
/// <returns>New builder for messages</returns>
public static ValueTask<IAgentMessageBuilder> CreateMessageAsync(this AgentMessageChannel channel, AgentMessageType type, CancellationToken cancellationToken)
{
    const int DefaultMaxSize = 1024;
    return channel.CreateMessageAsync(type, DefaultMaxSize, cancellationToken);
}
/// <summary>
/// Forwards an existing message across a channel
/// </summary>
/// <param name="channel">Channel to send on</param>
/// <param name="message">The message to be sent</param>
/// <param name="cancellationToken">Cancellation token for the operation</param>
public static async ValueTask SendAsync(this AgentMessageChannel channel, AgentMessage message, CancellationToken cancellationToken)
{
    // The builder is disposed when the method exits, which frees the channel for the next message.
    using IAgentMessageBuilder builder = await channel.CreateMessageAsync(message.Type, message.Data.Length, cancellationToken);

    // Copy the payload across unchanged, then commit it.
    builder.WriteFixedLengthBytes(message.Data.Span);
    builder.Send();
}
}
}