// Copyright Epic Games, Inc. All Rights Reserved.

using System;
using System.Buffers;
using System.Buffers.Binary;
using System.Collections.Generic;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using EpicGames.Core;
using EpicGames.Horde.Storage;

#pragma warning disable CA1054 // URI-like parameters should not be strings
#pragma warning disable CA1056 // Change string to URI

namespace EpicGames.Horde.Compute
{
	/// <summary>
	/// Type of a compute message
	/// </summary>
	public enum AgentMessageType
	{
		/// <summary>
		/// No message was received (end of stream)
		/// </summary>
		None = 0x00,

		/// <summary>
		/// No-op message sent to keep the connection alive. Remote should reply with the same message.
		/// </summary>
		Ping = 0x01,

		/// <summary>
		/// Sent in place of a regular response if an error occurs on the remote
		/// </summary>
		Exception = 0x02,

		/// <summary>
		/// Fork the message loop into a new channel
		/// </summary>
		Fork = 0x03,

		/// <summary>
		/// Sent as the first message on a channel to notify the remote that the remote end is attached
		/// </summary>
		Attach = 0x04,

		#region Process Management

		/// <summary>
		/// Extract files on the remote machine (Initiator -> Remote)
		/// </summary>
		WriteFiles = 0x10,

		/// <summary>
		/// Notification that files have been extracted (Remote -> Initiator)
		/// </summary>
		WriteFilesResponse = 0x11,

		/// <summary>
		/// Deletes files on the remote machine (Initiator -> Remote)
		/// </summary>
		DeleteFiles = 0x12,

		/// <summary>
		/// Execute a process in a sandbox (Initiator -> Remote)
		/// </summary>
		ExecuteV1 = 0x16,

		/// <summary>
		/// Execute a process in a sandbox, with execution flags (Initiator -> Remote)
		/// </summary>
		ExecuteV2 = 0x22,

		/// <summary>
		/// Execute a process in a sandbox, with execution flags and an optional container image (Initiator -> Remote)
		/// </summary>
		ExecuteV3 = 0x23,

		/// <summary>
		/// Returns output from the child process to the caller (Remote -> Initiator)
		/// </summary>
		ExecuteOutput = 0x17,

		/// <summary>
		/// Returns the process exit code (Remote -> Initiator)
		/// </summary>
		ExecuteResult = 0x18,

		#endregion

		#region Storage

		/// <summary>
		/// Reads a blob from storage
		/// </summary>
		ReadBlob = 0x20,

		/// <summary>
		/// Response to a <see cref="ReadBlob"/> request.
		/// </summary>
		ReadBlobResponse = 0x21,

		#endregion

		#region Test Requests

		/// <summary>
		/// Xor a block of data with a value
		/// </summary>
		XorRequest = 0xf0,

		/// <summary>
		/// Result from an <see cref="XorRequest"/> request.
		/// </summary>
		XorResponse = 0xf1,

		#endregion
	}

	/// <summary>
	/// Flags describing how to execute a compute task process on the agent
	/// </summary>
	[Flags]
	public enum ExecuteProcessFlags
	{
		/// <summary>
		/// No execute flags set
		/// </summary>
		None = 0,

		/// <summary>
		/// Request execution to be wrapped under Wine when running on Linux.
		/// Agent still reserves the right to refuse it (e.g no Wine executable configured, mismatching OS etc)
		/// </summary>
		UseWine = 1,

		/// <summary>
		/// Use compute process executable as entrypoint for container
		/// If not set, path to the executable is passed as the first parameter to the container invocation
		/// </summary>
		ReplaceContainerEntrypoint = 2,
	}

	/// <summary>
	/// Standard implementation of a message. The payload is copied into a buffer rented from
	/// <see cref="MemoryPool{T}.Shared"/>, so instances must be disposed to return that buffer.
	/// </summary>
	public sealed class AgentMessage : IMemoryReader, IDisposable
	{
		/// <summary>
		/// Type of the message
		/// </summary>
		public AgentMessageType Type { get; }

		/// <summary>
		/// Data that was read
		/// </summary>
		public ReadOnlyMemory<byte> Data { get; }

		readonly IMemoryOwner<byte> _memoryOwner;
		int _position;

		/// <summary>
		/// Constructor. Copies <paramref name="data"/> into a pooled buffer owned by this message.
		/// </summary>
		/// <param name="type">Type of the message</param>
		/// <param name="data">Payload to copy</param>
		public AgentMessage(AgentMessageType type, ReadOnlyMemory<byte> data)
		{
			_memoryOwner = MemoryPool<byte>.Shared.Rent(data.Length);
			data.CopyTo(_memoryOwner.Memory);

			Type = type;
			// The rented buffer may be larger than requested; expose only the copied bytes.
			Data = _memoryOwner.Memory.Slice(0, data.Length);
		}

		/// <inheritdoc/>
		public void Dispose()
		{
			_memoryOwner.Dispose();
		}

		/// <inheritdoc/>
		/// <remarks>Returns all remaining unread payload bytes; <paramref name="minSize"/> is not checked.</remarks>
		public ReadOnlyMemory<byte> GetMemory(int minSize = 1) => Data.Slice(_position);

		/// <inheritdoc/>
		public void Advance(int length) => _position += length;
	}

	/// <summary>
	/// Exception thrown when an invalid message is received
	/// </summary>
	public sealed class InvalidAgentMessageException : ComputeException
	{
		/// <summary>
		/// Constructor
		/// </summary>
		/// <param name="actualMessage">The message that was received</param>
		/// <param name="expectedType">The message type that was expected, if any</param>
		/// <param name="remoteException">Exception reported by the remote, if the received message was an exception message</param>
		public InvalidAgentMessageException(AgentMessage actualMessage, AgentMessageType? expectedType, ComputeRemoteException? remoteException)
			: base($"Unexpected message {actualMessage.Type}" + (expectedType != null ? $". Wanted {expectedType}" : ""), remoteException)
		{
		}
	}

	/// <summary>
	/// Exception thrown when a compute execution is cancelled
	/// </summary>
	public sealed class ComputeExecutionCancelledException : ComputeException
	{
		private const string Text = "Compute execution cancelled";

		/// <summary>
		/// Constructor
		/// </summary>
		public ComputeExecutionCancelledException() : base(Text)
		{
		}

		private ComputeExecutionCancelledException(Exception innerException) : base(Text, innerException)
		{
		}

		/// <summary>
		/// Try constructing and throwing if the exception message matches a cancellation exception
		/// </summary>
		/// <param name="em">Deserialized exception message</param>
		/// <exception cref="ComputeExecutionCancelledException">If message matches</exception>
		public static void TryThrow(ExceptionMessage em)
		{
			if (em.Message == Text)
			{
				throw new ComputeExecutionCancelledException(new ComputeRemoteException(em));
			}
		}
	}

	/// <summary>
	/// Writer for compute messages
	/// </summary>
	public interface IAgentMessageBuilder : IMemoryWriter, IDisposable
	{
		/// <summary>
		/// Sends the current message
		/// </summary>
		void Send();
	}

	/// <summary>
	/// Message for reporting an error
	/// </summary>
	/// <param name="Message">Exception message</param>
	/// <param name="Description">Extended description of the error</param>
	public readonly record struct ExceptionMessage(string Message, string Description);

	/// <summary>
	/// Message requesting that the message loop be forked
	/// </summary>
	/// <param name="ChannelId">New channel to communicate on</param>
	/// <param name="BufferSize">Size of the buffer</param>
	public readonly record struct ForkMessage(int ChannelId, int BufferSize);

	/// <summary>
	/// Extract files from a bundle to a path in the remote sandbox
	/// </summary>
	/// <param name="Name">Path to extract the files to</param>
	/// <param name="Locator">Locator for the tree to extract</param>
	public readonly record struct UploadFilesMessage(string Name, BlobLocator Locator);

	/// <summary>
	/// Deletes files or directories in the remote
	/// </summary>
	/// <param name="Filter">Filter for files to delete</param>
	public readonly record struct DeleteFilesMessage(IReadOnlyList<string> Filter);

	/// <summary>
	/// Message to execute a new child process
	/// </summary>
	/// <param name="Executable">Executable path</param>
	/// <param name="Arguments">Arguments for the executable</param>
	/// <param name="WorkingDir">Working directory to execute in</param>
	/// <param name="EnvVars">Environment variables for the child process. Null values unset variables.</param>
	/// <param name="Flags">Additional execution flags</param>
	/// <param name="ContainerImageUrl">URL to container image. If specified, process will be executed inside this container</param>
	public readonly record struct ExecuteProcessMessage(string Executable, IReadOnlyList<string> Arguments, string? WorkingDir, IReadOnlyDictionary<string, string?> EnvVars, ExecuteProcessFlags Flags, string? ContainerImageUrl);

	/// <summary>
	/// Response from executing a child process
	/// </summary>
	/// <param name="ExitCode">Exit code for the process</param>
	public readonly record struct ExecuteProcessResponseMessage(int ExitCode);

	/// <summary>
	/// Creates a blob read request
	/// </summary>
	/// <param name="Locator">Locator for the blob to read</param>
	/// <param name="Offset">Starting offset within the blob</param>
	/// <param name="Length">Length of data to read</param>
	public readonly record struct ReadBlobMessage(BlobLocator Locator, int Offset, int Length);

	/// <summary>
	/// Message for running an XOR command
	/// </summary>
	/// <param name="Data">Data to xor</param>
	/// <param name="Value">Value to XOR with</param>
	public readonly record struct XorRequestMessage(ReadOnlyMemory<byte> Data, byte Value);

	/// <summary>
	/// Wraps various requests across compute channels
	/// </summary>
	public static class AgentMessageExtensions
	{
		/// <summary>
		/// Closes the remote message loop by sending a <see cref="AgentMessageType.None"/> message
		/// </summary>
		public static async ValueTask CloseAsync(this AgentMessageChannel channel, CancellationToken cancellationToken = default)
		{
			using IAgentMessageBuilder message = await channel.CreateMessageAsync(AgentMessageType.None, cancellationToken);
			message.Send();
		}

		/// <summary>
		/// Sends a ping message to the remote
		/// </summary>
		public static async ValueTask PingAsync(this AgentMessageChannel channel, CancellationToken cancellationToken = default)
		{
			using IAgentMessageBuilder message = await channel.CreateMessageAsync(AgentMessageType.Ping, cancellationToken);
			message.Send();
		}

		/// <summary>
		/// Sends an exception response to the remote
		/// </summary>
		public static ValueTask SendExceptionAsync(this AgentMessageChannel channel, Exception ex, CancellationToken cancellationToken = default) => SendExceptionAsync(channel, ex.Message, ex.ToString(), cancellationToken);

		/// <summary>
		/// Sends an exception response to the remote
		/// </summary>
		/// <param name="channel">Channel to send on</param>
		/// <param name="description">Exception message</param>
		/// <param name="trace">Full exception text, including the stack trace</param>
		/// <param name="cancellationToken">Cancellation token for the operation</param>
		public static async ValueTask SendExceptionAsync(this AgentMessageChannel channel, string description, string trace, CancellationToken cancellationToken = default)
		{
			using IAgentMessageBuilder message = await channel.CreateMessageAsync(AgentMessageType.Exception, cancellationToken);
			message.WriteString(description);
			message.WriteString(trace);
			message.Send();
		}

		/// <summary>
		/// Parses a message as an <see cref="ExceptionMessage"/>
		/// </summary>
		public static ExceptionMessage ParseExceptionMessage(this AgentMessage message)
		{
			string msg = message.ReadString();
			string description = message.ReadString();
			return new ExceptionMessage(msg, description);
		}

		/// <summary>
		/// Requests that the remote message loop be forked
		/// </summary>
		public static async ValueTask ForkAsync(this AgentMessageChannel channel, int channelId, int bufferSize, CancellationToken cancellationToken = default)
		{
			using IAgentMessageBuilder message = await channel.CreateMessageAsync(AgentMessageType.Fork, cancellationToken);
			message.WriteInt32(channelId);
			message.WriteInt32(bufferSize);
			message.Send();
		}

		/// <summary>
		/// Parses a fork request message
		/// </summary>
		public static ForkMessage ParseForkMessage(this AgentMessage message)
		{
			int channelId = message.ReadInt32();
			int bufferSize = message.ReadInt32();
			return new ForkMessage(channelId, bufferSize);
		}

		/// <summary>
		/// Notifies the remote that a buffer has been attached
		/// </summary>
		public static async ValueTask AttachAsync(this AgentMessageChannel channel, CancellationToken cancellationToken = default)
		{
			using IAgentMessageBuilder message = await channel.CreateMessageAsync(AgentMessageType.Attach, cancellationToken);
			message.Send();
		}

		/// <summary>
		/// Waits until an attached notification is received along the channel
		/// </summary>
		/// <param name="channel">Channel to receive on</param>
		/// <param name="cancellationToken">Cancellation token for the operation</param>
		/// <exception cref="InvalidAgentMessageException">If the next message is not an <see cref="AgentMessageType.Attach"/> message</exception>
		public static async ValueTask WaitForAttachAsync(this AgentMessageChannel channel, CancellationToken cancellationToken = default)
		{
			using AgentMessage message = await channel.ReceiveAsync(cancellationToken);
			message.ThrowIfUnexpectedType(AgentMessageType.Attach);
		}

		/// <summary>
		/// Throw an exception if message is not of expected type
		/// </summary>
		/// <param name="message">Agent message to extend</param>
		/// <param name="expectedType">Optional type to expect. If not specified, assume type was unwanted no matter what</param>
		/// <exception cref="InvalidAgentMessageException">If the message type does not match. When the message is an
		/// <see cref="AgentMessageType.Exception"/> message, the parsed remote exception is attached as the inner exception.</exception>
		public static void ThrowIfUnexpectedType(this AgentMessage message, AgentMessageType? expectedType = null)
		{
			if (message.Type == expectedType)
			{
				return;
			}

			ComputeRemoteException? cre = message.Type == AgentMessageType.Exception ? new ComputeRemoteException(message.ParseExceptionMessage()) : null;
			throw new InvalidAgentMessageException(message, expectedType, cre);
		}

		#region Process

		/// <summary>
		/// Services <see cref="AgentMessageType.ReadBlob"/> requests from the remote until a message of any other type
		/// arrives, and returns that message. The caller takes ownership of (and must dispose) the returned message.
		/// </summary>
		static async Task<AgentMessage> RunStorageServerAsync(this AgentMessageChannel channel, IStorageBackend storage, CancellationToken cancellationToken = default)
		{
			for (; ; )
			{
				AgentMessage message = await channel.ReceiveAsync(cancellationToken);
				if (message.Type != AgentMessageType.ReadBlob)
				{
					return message;
				}

				try
				{
					ReadBlobMessage readBlob = message.ParseReadBlobRequest();
					await SendBlobDataAsync(channel, readBlob, storage, cancellationToken);
				}
				finally
				{
					message.Dispose();
				}
			}
		}

		/// <summary>
		/// Creates a sandbox on the remote machine
		/// </summary>
		/// <param name="channel">Current channel</param>
		/// <param name="path">Path to extract the files to on the remote</param>
		/// <param name="locator">Locator for the tree to extract</param>
		/// <param name="storage">Storage backend used to serve blob read requests from the remote while it extracts</param>
		/// <param name="cancellationToken">Cancellation token for the operation</param>
		public static async Task UploadFilesAsync(this AgentMessageChannel channel, string path, BlobLocator locator, IStorageBackend storage, CancellationToken cancellationToken = default)
		{
			using (IAgentMessageBuilder request = await channel.CreateMessageAsync(AgentMessageType.WriteFiles, cancellationToken))
			{
				request.WriteString(path);
				// HACK: Currently deployed agents have a hash check in BundleNodeLocator.Parse() which does not check length before checking for the '@' character separating the hash from locator.
				request.WriteString($"{IoHash.Zero}@{locator}");
				request.Send();
			}

			using AgentMessage response = await RunStorageServerAsync(channel, storage, cancellationToken);
			response.ThrowIfUnexpectedType(AgentMessageType.WriteFilesResponse);
		}

		/// <summary>
		/// Parses a message as a <see cref="UploadFilesMessage"/>
		/// </summary>
		public static UploadFilesMessage ParseUploadFilesMessage(this AgentMessage message)
		{
			string name = message.ReadString();
			string path = message.ReadString();

			// Strip the "hash@" prefix written by UploadFilesAsync for compatibility with older agents.
			int atIdx = path.IndexOf('@', StringComparison.Ordinal);
			if (atIdx != -1)
			{
				path = path.Substring(atIdx + 1);
			}

			BlobLocator locator = new BlobLocator(path);
			return new UploadFilesMessage(name, locator);
		}

		/// <summary>
		/// Destroys a sandbox on the remote machine
		/// </summary>
		/// <param name="channel">Current channel</param>
		/// <param name="paths">Paths of files or directories to delete</param>
		/// <param name="cancellationToken">Cancellation token for the operation</param>
		public static async ValueTask DeleteFilesAsync(this AgentMessageChannel channel, IReadOnlyList<string> paths, CancellationToken cancellationToken)
		{
			using IAgentMessageBuilder request = await channel.CreateMessageAsync(AgentMessageType.DeleteFiles, cancellationToken);
			request.WriteList(paths, MemoryWriterExtensions.WriteString);
			request.Send();
		}

		/// <summary>
		/// Parses a message as a <see cref="DeleteFilesMessage"/>
		/// </summary>
		public static DeleteFilesMessage ParseDeleteFilesMessage(this AgentMessage message)
		{
			List<string> files = message.ReadList(MemoryReaderExtensions.ReadString);
			return new DeleteFilesMessage(files);
		}

		/// <summary>
		/// Executes a remote process (using ExecuteV1)
		/// </summary>
		/// <param name="channel">Current channel</param>
		/// <param name="executable">Executable to run, relative to the sandbox root</param>
		/// <param name="arguments">Arguments for the child process</param>
		/// <param name="workingDir">Working directory for the process</param>
		/// <param name="envVars">Environment variables for the child process</param>
		/// <param name="cancellationToken">Cancellation token for the operation</param>
		public static async Task<AgentManagedProcess> ExecuteAsync(this AgentMessageChannel channel, string executable, IReadOnlyList<string> arguments, string? workingDir, IReadOnlyDictionary<string, string?>? envVars, CancellationToken cancellationToken = default)
		{
			using (IAgentMessageBuilder request = await channel.CreateMessageAsync(AgentMessageType.ExecuteV1, cancellationToken))
			{
				request.WriteString(executable);
				request.WriteList(arguments, MemoryWriterExtensions.WriteString);
				request.WriteOptionalString(workingDir);
				request.WriteDictionary(envVars ?? new Dictionary<string, string?>(), MemoryWriterExtensions.WriteString, MemoryWriterExtensions.WriteOptionalString);
				request.Send();
			}
			return new AgentManagedProcess(channel);
		}

		/// <summary>
		/// Executes a remote process (using ExecuteV2)
		/// </summary>
		/// <param name="channel">Current channel</param>
		/// <param name="executable">Executable to run, relative to the sandbox root</param>
		/// <param name="arguments">Arguments for the child process</param>
		/// <param name="workingDir">Working directory for the process</param>
		/// <param name="envVars">Environment variables for the child process</param>
		/// <param name="flags">Additional execution flags</param>
		/// <param name="cancellationToken">Cancellation token for the operation</param>
		public static async Task<AgentManagedProcess> ExecuteAsync(this AgentMessageChannel channel, string executable, IReadOnlyList<string> arguments, string? workingDir, IReadOnlyDictionary<string, string?>? envVars, ExecuteProcessFlags flags = ExecuteProcessFlags.None, CancellationToken cancellationToken = default)
		{
			using (IAgentMessageBuilder request = await channel.CreateMessageAsync(AgentMessageType.ExecuteV2, cancellationToken))
			{
				request.WriteString(executable);
				request.WriteList(arguments, MemoryWriterExtensions.WriteString);
				request.WriteOptionalString(workingDir);
				request.WriteDictionary(envVars ?? new Dictionary<string, string?>(), MemoryWriterExtensions.WriteString, MemoryWriterExtensions.WriteOptionalString);
				request.WriteInt32((int)flags);
				request.Send();
			}
			return new AgentManagedProcess(channel);
		}

		/// <summary>
		/// Executes a remote process (using ExecuteV3)
		/// </summary>
		/// <param name="channel">Current channel</param>
		/// <param name="executable">Executable to run, relative to the sandbox root</param>
		/// <param name="arguments">Arguments for the child process</param>
		/// <param name="workingDir">Working directory for the process</param>
		/// <param name="envVars">Environment variables for the child process</param>
		/// <param name="flags">Additional execution flags</param>
		/// <param name="containerImageUrl">Optional container image URL. If set, execution will happen inside this container</param>
		/// <param name="cancellationToken">Cancellation token for the operation</param>
		public static async Task<AgentManagedProcess> ExecuteAsync(this AgentMessageChannel channel, string executable, IReadOnlyList<string> arguments, string? workingDir, IReadOnlyDictionary<string, string?>? envVars, ExecuteProcessFlags flags, string? containerImageUrl, CancellationToken cancellationToken = default)
		{
			using (IAgentMessageBuilder request = await channel.CreateMessageAsync(AgentMessageType.ExecuteV3, cancellationToken))
			{
				request.WriteString(executable);
				request.WriteList(arguments, MemoryWriterExtensions.WriteString);
				request.WriteOptionalString(workingDir);
				request.WriteDictionary(envVars ?? new Dictionary<string, string?>(), MemoryWriterExtensions.WriteString, MemoryWriterExtensions.WriteOptionalString);
				request.WriteInt32((int)flags);
				// An empty string on the wire means "no container" (see ParseExecuteProcessV3Message).
				request.WriteString(containerImageUrl ?? "");
				request.Send();
			}
			return new AgentManagedProcess(channel);
		}

		/// <summary>
		/// Parses an <see cref="AgentMessageType.ExecuteV1"/> message as a <see cref="ExecuteProcessMessage"/>
		/// </summary>
		public static ExecuteProcessMessage ParseExecuteProcessV1Message(this AgentMessage message)
		{
			string executable = message.ReadString();
			List<string> arguments = message.ReadList(MemoryReaderExtensions.ReadString);
			string? workingDir = message.ReadOptionalString();
			Dictionary<string, string?> envVars = message.ReadDictionary(MemoryReaderExtensions.ReadString, MemoryReaderExtensions.ReadOptionalString);
			return new ExecuteProcessMessage(executable, arguments, workingDir, envVars, ExecuteProcessFlags.None, null);
		}

		/// <summary>
		/// Parses an <see cref="AgentMessageType.ExecuteV2"/> message as a <see cref="ExecuteProcessMessage"/>
		/// </summary>
		public static ExecuteProcessMessage ParseExecuteProcessV2Message(this AgentMessage message)
		{
			string executable = message.ReadString();
			List<string> arguments = message.ReadList(MemoryReaderExtensions.ReadString);
			string? workingDir = message.ReadOptionalString();
			Dictionary<string, string?> envVars = message.ReadDictionary(MemoryReaderExtensions.ReadString, MemoryReaderExtensions.ReadOptionalString);
			ExecuteProcessFlags flags = (ExecuteProcessFlags)message.ReadInt32();
			return new ExecuteProcessMessage(executable, arguments, workingDir, envVars, flags, null);
		}

		/// <summary>
		/// Parses an <see cref="AgentMessageType.ExecuteV3"/> message as a <see cref="ExecuteProcessMessage"/>
		/// </summary>
		public static ExecuteProcessMessage ParseExecuteProcessV3Message(this AgentMessage message)
		{
			string executable = message.ReadString();
			List<string> arguments = message.ReadList(MemoryReaderExtensions.ReadString);
			string? workingDir = message.ReadOptionalString();
			Dictionary<string, string?> envVars = message.ReadDictionary(MemoryReaderExtensions.ReadString, MemoryReaderExtensions.ReadOptionalString);
			ExecuteProcessFlags flags = (ExecuteProcessFlags)message.ReadInt32();
			string containerImageUrl = message.ReadString();
			return new ExecuteProcessMessage(executable, arguments, workingDir, envVars, flags, String.IsNullOrEmpty(containerImageUrl) ? null : containerImageUrl);
		}

		/// <summary>
		/// Sends output from a child process
		/// </summary>
		public static async ValueTask SendExecuteOutputAsync(this AgentMessageChannel channel, ReadOnlyMemory<byte> data, CancellationToken cancellationToken = default)
		{
			using IAgentMessageBuilder message = await channel.CreateMessageAsync(AgentMessageType.ExecuteOutput, data.Length + 20, cancellationToken);
			message.WriteFixedLengthBytes(data.Span);
			message.Send();
		}

		/// <summary>
		/// Sends a response from executing a child process
		/// </summary>
		/// <param name="channel">Channel to send on</param>
		/// <param name="exitCode">Exit code from the process</param>
		/// <param name="cancellationToken">Cancellation token for the operation</param>
		public static async ValueTask SendExecuteResultAsync(this AgentMessageChannel channel, int exitCode, CancellationToken cancellationToken = default)
		{
			using IAgentMessageBuilder builder = await channel.CreateMessageAsync(AgentMessageType.ExecuteResult, cancellationToken);
			builder.WriteInt32(exitCode);
			builder.Send();
		}

		/// <summary>
		/// Parses a message as a <see cref="ExecuteProcessResponseMessage"/>
		/// </summary>
		public static ExecuteProcessResponseMessage ParseExecuteProcessResponse(this AgentMessage message)
		{
			int exitCode = message.ReadInt32();
			return new ExecuteProcessResponseMessage(exitCode);
		}

		#endregion

		#region Storage

		/// <summary>
		/// Parses a message as a <see cref="ReadBlobMessage"/>
		/// </summary>
		/// <param name="message">Message to parse</param>
		/// <returns>The locator, offset and length of the requested data</returns>
		public static ReadBlobMessage ParseReadBlobRequest(this AgentMessage message)
		{
			BlobLocator locator = new BlobLocator(message.ReadUtf8String());
			int offset = (int)message.ReadUnsignedVarInt();
			int length = (int)message.ReadUnsignedVarInt();
			return new ReadBlobMessage(locator, offset, length);
		}

		/// <summary>
		/// Reads a blob from the remote
		/// </summary>
		/// <param name="channel">Channel to write to</param>
		/// <param name="path">Path for the blob</param>
		/// <param name="offset">Offset within the blob</param>
		/// <param name="length">Length of data to return</param>
		/// <param name="cancellationToken">Cancellation token for the operation</param>
		/// <returns>Buffer containing the blob data, assembled from one or more response chunks</returns>
		/// <exception cref="InvalidAgentMessageException">If a response is not a <see cref="AgentMessageType.ReadBlobResponse"/> message</exception>
		/// <exception cref="InvalidDataException">If a response chunk header is malformed or out of range</exception>
		public static async Task<ReadOnlyMemory<byte>> ReadBlobAsync(this AgentMessageChannel channel, string path, int offset, int length, CancellationToken cancellationToken = default)
		{
			using (IAgentMessageBuilder request = await channel.CreateMessageAsync(AgentMessageType.ReadBlob, cancellationToken))
			{
				request.WriteString(path);
				request.WriteUnsignedVarInt(offset);
				request.WriteUnsignedVarInt(length);
				request.Send();
			}

			// Each response chunk is laid out as: [int32 chunkOffset][int32 totalLength][payload...] (see SendBlobDataAsync).
			const int HeaderSize = 8;

			byte[]? buffer = null;
			for (; ; )
			{
				// Every AgentMessage owns a pooled buffer; dispose each response once its payload has been copied out,
				// on both the success and failure paths, so the memory is returned to the pool.
				using AgentMessage response = await channel.ReceiveAsync(cancellationToken);
				response.ThrowIfUnexpectedType(AgentMessageType.ReadBlobResponse);

				if (response.Data.Length < HeaderSize)
				{
					throw new InvalidDataException($"Blob response is too short ({response.Data.Length} bytes)");
				}

				int chunkOffset = BinaryPrimitives.ReadInt32LittleEndian(response.Data.Span.Slice(0, 4));
				int totalLength = BinaryPrimitives.ReadInt32LittleEndian(response.Data.Span.Slice(4, 4));
				int chunkLength = response.Data.Length - HeaderSize;

				if (totalLength < 0)
				{
					throw new InvalidDataException($"Blob response has invalid total length ({totalLength})");
				}

				buffer ??= new byte[totalLength];
				if (chunkOffset < 0 || chunkOffset > buffer.Length - chunkLength)
				{
					throw new InvalidDataException($"Blob response chunk (offset {chunkOffset}, length {chunkLength}) exceeds blob length {buffer.Length}");
				}

				response.Data.Slice(HeaderSize).CopyTo(buffer.AsMemory(chunkOffset));

				if (chunkOffset + chunkLength == totalLength)
				{
					return buffer;
				}
			}
		}

		/// <summary>
		/// Writes blob data to a compute channel
		/// </summary>
		/// <param name="channel">Channel to write to</param>
		/// <param name="message">The read request</param>
		/// <param name="storage">Storage backend to retrieve the blob from</param>
		/// <param name="cancellationToken">Cancellation token for the operation</param>
		public static Task SendBlobDataAsync(this AgentMessageChannel channel, ReadBlobMessage message, IStorageBackend storage, CancellationToken cancellationToken = default)
		{
			return SendBlobDataAsync(channel, message.Locator, message.Offset, message.Length, storage, cancellationToken);
		}

		/// <summary>
		/// Writes blob data to a compute channel, split into chunks of at most 512 KiB. Each chunk is prefixed with
		/// its offset and the total length as little-endian 32-bit integers. At least one chunk is always sent.
		/// </summary>
		/// <param name="channel">Channel to write to</param>
		/// <param name="locator">Locator for the blob to send</param>
		/// <param name="offset">Starting offset of the data</param>
		/// <param name="length">Length of the data; zero requests an unbounded read</param>
		/// <param name="storage">Storage backend to retrieve the blob from</param>
		/// <param name="cancellationToken">Cancellation token for the operation</param>
		public static async Task SendBlobDataAsync(this AgentMessageChannel channel, BlobLocator locator, int offset, int length, IStorageBackend storage, CancellationToken cancellationToken = default)
		{
			using Stream stream = await storage.OpenBlobAsync(locator, offset, (length == 0) ? null : length, cancellationToken);

			const int MaxChunkSize = 512 * 1024;
			for (int chunkOffset = 0; ;)
			{
				int chunkLength = (int)Math.Min(stream.Length - chunkOffset, MaxChunkSize);
				using (IAgentMessageBuilder response = await channel.CreateMessageAsync(AgentMessageType.ReadBlobResponse, chunkLength + 128, cancellationToken))
				{
					response.WriteInt32(chunkOffset);
					response.WriteInt32((int)stream.Length);

					Memory<byte> memory = response.GetMemoryAndAdvance(chunkLength);
					await stream.ReadFixedLengthBytesAsync(memory, cancellationToken);

					response.Send();
				}

				chunkOffset += chunkLength;
				if (chunkOffset == stream.Length)
				{
					break;
				}
			}
		}

		#endregion

		#region Test Messages

		/// <summary>
		/// Send a message to request that a byte string be xor'ed with a particular value
		/// </summary>
		public static async ValueTask SendXorRequestAsync(this AgentMessageChannel channel, ReadOnlyMemory<byte> data, byte value, CancellationToken cancellationToken = default)
		{
			using IAgentMessageBuilder message = await channel.CreateMessageAsync(AgentMessageType.XorRequest, cancellationToken);
			message.WriteFixedLengthBytes(data.Span);
			message.WriteUInt8(value);
			message.Send();
		}

		/// <summary>
		/// Parse a message as an XOR request. The last byte of the payload is the XOR value; the preceding bytes are the data.
		/// </summary>
		public static XorRequestMessage AsXorRequest(this AgentMessage message)
		{
			ReadOnlyMemory<byte> data = message.Data;
			return new XorRequestMessage(data[0..^1], data.Span[^1]);
		}

		#endregion
	}
}