// Copyright Epic Games, Inc. All Rights Reserved.

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Runtime.ExceptionServices;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using EpicGames.Core;
using EpicGames.Horde.Compute.Buffers;
using EpicGames.Horde.Storage;
using EpicGames.Horde.Storage.Bundles;
using EpicGames.Horde.Storage.Nodes;
using Microsoft.Extensions.Logging;

namespace EpicGames.Horde.Compute
{
	/// <summary>
	/// Implements the remote end of a compute worker.
	/// </summary>
	public class AgentMessageHandler
	{
		readonly DirectoryReference _sandboxDir;
		readonly Dictionary<string, string?> _envVars;
		readonly bool _executeInProcess;
		readonly string? _wineExecutablePath;
		readonly string? _containerEngineExecutable;
		readonly ILogger _logger;

		/// <summary>
		/// Constructor
		/// </summary>
		/// <param name="sandboxDir">Directory to use for reading/writing files</param>
		/// <param name="envVars">Environment variables to set for any child processes</param>
		/// <param name="executeInProcess">Whether to execute any external assemblies in the current process</param>
		/// <param name="wineExecutablePath">Path to Wine executable. If null, execution under Wine is disabled</param>
		/// <param name="containerEngineExecutable">Path to container engine executable, e.g /usr/bin/podman. If null, execution inside a container is disabled</param>
		/// <param name="logger">Logger for diagnostics</param>
		public AgentMessageHandler(DirectoryReference sandboxDir, Dictionary<string, string?>? envVars, bool executeInProcess, string? wineExecutablePath, string? containerEngineExecutable, ILogger logger)
		{
			_sandboxDir = sandboxDir;
			_envVars = envVars ?? new Dictionary<string, string?>();
			_executeInProcess = executeInProcess;
			_wineExecutablePath = wineExecutablePath;
			_containerEngineExecutable = containerEngineExecutable;
			_logger = logger;
		}

		/// <summary>
		/// Runs the worker using commands sent along the given socket
		/// </summary>
		/// <param name="socket">Socket to read from</param>
		/// <param name="cancellationToken">Cancellation token for the operation</param>
		public async Task RunAsync(ComputeSocket socket, CancellationToken cancellationToken)
		{
			// Since we allow forking message channels, we want to ensure that errors on one channel are propagated back here, and terminate the whole connection.
			// To do that, we take first exception thrown and rethrow it with the original callstack here, while also forcing all other tasks to terminate via a
			// shared cancellation token.
			using CancellationTokenSource cancellationSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);

			ExceptionDispatchInfo? exceptionInfo = null;

			async Task PostExceptionAsync(Exception ex)
			{
				// Capture stack from call site. Only the first posted exception is kept.
				Interlocked.CompareExchange(ref exceptionInfo, ExceptionDispatchInfo.Capture(ex), null);
				await cancellationSource.CancelAsync();
			}

			await RunAsync(socket, 0, 4 * 1024 * 1024, PostExceptionAsync, cancellationSource.Token);

			// Throw the regular cancellation exception if requested
			cancellationToken.ThrowIfCancellationRequested();

			// Otherwise throw any exception posted by a child task
#pragma warning disable CA1508 // Static analyzer doesn't understand how this can be non-null
			exceptionInfo?.Throw();
#pragma warning restore CA1508
		}

		/// <summary>
		/// Message loop for a single channel. Forked channels run this method recursively on their own task.
		/// </summary>
		/// <param name="socket">Socket to create the channel on</param>
		/// <param name="channelId">Identifier of the channel to service</param>
		/// <param name="bufferSize">Size of the buffer to create for the channel</param>
		/// <param name="postException">Callback used to report the first failure back to the root call</param>
		/// <param name="cancellationToken">Cancellation token shared by all channels of this connection</param>
		async Task RunAsync(ComputeSocket socket, int channelId, int bufferSize, Func<Exception, Task> postException, CancellationToken cancellationToken)
		{
			List<Task> childTasks = new List<Task>();

			using AgentMessageChannel channel = socket.CreateAgentMessageChannel(channelId, bufferSize);
			try
			{
				await channel.AttachAsync(cancellationToken);

				for (; ; )
				{
					using AgentMessage message = await channel.ReceiveAsync(cancellationToken);
					_logger.LogDebug("Compute Channel {ChannelId}: {MessageType}", channelId, message.Type);

					switch (message.Type)
					{
						case AgentMessageType.None:
							return;
						case AgentMessageType.Ping:
							await channel.PingAsync(cancellationToken);
							break;
						case AgentMessageType.Fork:
							{
								ForkMessage fork = message.ParseForkMessage();
								childTasks.Add(Task.Run(() => RunAsync(socket, fork.ChannelId, fork.BufferSize, postException, cancellationToken), cancellationToken));
							}
							break;
						case AgentMessageType.WriteFiles:
							{
								UploadFilesMessage writeFiles = message.ParseUploadFilesMessage();
								await WriteFilesAsync(channel, writeFiles.Name, writeFiles.Locator, cancellationToken: cancellationToken);
							}
							break;
						case AgentMessageType.DeleteFiles:
							{
								DeleteFilesMessage deleteFiles = message.ParseDeleteFilesMessage();
								DeleteFiles(deleteFiles.Filter);
							}
							break;
						case AgentMessageType.ExecuteV1:
							{
								ExecuteProcessMessage ep = message.ParseExecuteProcessV1Message();
								await ExecuteProcessAsync(socket, channel, ep.Executable, ep.Arguments, ep.WorkingDir, ep.ContainerImageUrl, ep.EnvVars, ep.Flags, cancellationToken);
							}
							break;
						case AgentMessageType.ExecuteV2:
							{
								ExecuteProcessMessage ep = message.ParseExecuteProcessV2Message();
								await ExecuteProcessAsync(socket, channel, ep.Executable, ep.Arguments, ep.WorkingDir, ep.ContainerImageUrl, ep.EnvVars, ep.Flags, cancellationToken);
							}
							break;
						case AgentMessageType.ExecuteV3:
							{
								ExecuteProcessMessage ep = message.ParseExecuteProcessV3Message();
								await ExecuteProcessAsync(socket, channel, ep.Executable, ep.Arguments, ep.WorkingDir, ep.ContainerImageUrl, ep.EnvVars, ep.Flags, cancellationToken);
							}
							break;
						case AgentMessageType.XorRequest:
							{
								XorRequestMessage xorRequest = message.AsXorRequest();
								await RunXorAsync(channel, xorRequest.Data, xorRequest.Value, cancellationToken);
							}
							break;
						default:
							message.ThrowIfUnexpectedType();
							return;
					}
				}
			}
			catch (OperationCanceledException ex)
			{
				// Ignore cancellations; we will re-throw from the root RunAsync() method.
				_logger.LogDebug(ex, "Compute Channel {ChannelId}: Cancelled.", channelId);
			}
			catch (Exception ex)
			{
				_logger.LogInformation(ex, "Compute Channel {ChannelId}: Exception: {Message}", channelId, ex.Message);
				await channel.SendExceptionAsync(ex, cancellationToken);
				await postException(ex);
			}
			finally
			{
				// Do not leave this channel until every channel forked from it has finished
				await Task.WhenAll(childTasks);
			}
		}

		/// <summary>
		/// Replies to an XOR request with a response message containing each source byte XOR'd with the given value
		/// </summary>
		static async ValueTask RunXorAsync(AgentMessageChannel channel, ReadOnlyMemory<byte> source, byte value, CancellationToken cancellationToken)
		{
			using IAgentMessageBuilder response = await channel.CreateMessageAsync(AgentMessageType.XorResponse, source.Length, cancellationToken);
			XorData(source.Span, response.GetSpanAndAdvance(source.Length), value);
			response.Send();
		}

		/// <summary>
		/// Writes source[i] ^ value into target[i] for every byte of the source span
		/// </summary>
		static void XorData(ReadOnlySpan<byte> source, Span<byte> target, byte value)
		{
			for (int idx = 0; idx < source.Length; idx++)
			{
				target[idx] = (byte)(source[idx] ^ value);
			}
		}

		/// <summary>
		/// Extracts a directory tree, read back through the message channel, into a path under the sandbox and acknowledges with a WriteFilesResponse message
		/// </summary>
		/// <param name="channel">Channel used both as the storage backend and for the response</param>
		/// <param name="path">Output path, relative to the sandbox directory</param>
		/// <param name="locator">Locator of the directory blob to extract</param>
		/// <param name="options">Optional serializer options for reading the blob</param>
		/// <param name="cancellationToken">Cancellation token for the operation</param>
		/// <exception cref="InvalidOperationException">The resolved output directory is not under the sandbox directory</exception>
		async Task WriteFilesAsync(AgentMessageChannel channel, string path, BlobLocator locator, BlobSerializerOptions? options = null, CancellationToken cancellationToken = default)
		{
			using AgentStorageBackend innerStore = new AgentStorageBackend(channel);
			await using BundleCache cache = new BundleCache(new BundleCacheOptions { HeaderCacheSize = 10 * 1024 * 1024, PacketCacheSize = 128 * 1024 * 1024 });

			BundleOptions bundleOptions = ComputeProtocolUtilities.GetBundleOptions(channel.Protocol);
			BundleStorageNamespace store = new BundleStorageNamespace(innerStore, cache, bundleOptions, _logger);

			// NOTE(review): the DirectoryNode type argument is inferred from how the ref is consumed below - confirm against the storage API
			IBlobRef<DirectoryNode> directoryRef = store.CreateBlobRef<DirectoryNode>(locator, options);

			DirectoryReference outputDir = DirectoryReference.Combine(_sandboxDir, path);
			if (!outputDir.IsUnderDirectory(_sandboxDir))
			{
				throw new InvalidOperationException("Cannot write files outside sandbox");
			}

			await directoryRef.ExtractAsync(outputDir.ToDirectoryInfo(), new ExtractOptions(), _logger, cancellationToken);

			// The verification result is not acted upon here; any mismatches are reported through the log only
			await VerifyFilesAsync(outputDir, directoryRef, cancellationToken);

			using (IAgentMessageBuilder message = await channel.CreateMessageAsync(AgentMessageType.WriteFilesResponse, cancellationToken))
			{
				message.Send();
			}
		}

		/// <summary>
		/// Recursively checks that every file listed in a directory node exists on disk and has the expected hash
		/// </summary>
		/// <param name="outputDir">Directory the node was extracted to</param>
		/// <param name="directoryRef">Reference to the directory node to verify against</param>
		/// <param name="cancellationToken">Cancellation token for the operation</param>
		/// <returns>True if all files exist and match their expected hash, false otherwise</returns>
		async Task<bool> VerifyFilesAsync(DirectoryReference outputDir, IBlobRef<DirectoryNode> directoryRef, CancellationToken cancellationToken = default)
		{
			bool result = true;

			DirectoryNode directoryNode = await directoryRef.ReadBlobAsync(cancellationToken);
			foreach (FileEntry fileEntry in directoryNode.Files)
			{
				FileReference file = FileReference.Combine(outputDir, fileEntry.Name);
				if (!FileReference.Exists(file))
				{
					_logger.LogError("Extracted file {File} does not exist", file);
					result = false;
				}
				else
				{
					await using FileStream stream = FileReference.Open(file, FileMode.Open, FileAccess.Read);
					IoHash hash = await IoHash.ComputeAsync(stream, cancellationToken);

					if (hash == fileEntry.Hash)
					{
						_logger.LogInformation("Hash of {File} is correct ({Hash})", file, hash);
					}
					else
					{
						_logger.LogError("Hash mismatch for {File}; expected {ExpectedHash}, got {ActualHash}", file, fileEntry.Hash, hash);
						result = false;
					}
				}
			}

			foreach (DirectoryEntry directoryEntry in directoryNode.Directories)
			{
				result &= await VerifyFilesAsync(DirectoryReference.Combine(outputDir, directoryEntry.Name), directoryEntry.Handle, cancellationToken);
			}

			return result;
		}

		/// <summary>
		/// Deletes all files in the sandbox directory selected by the given filter rules
		/// </summary>
		/// <param name="deleteFiles">Filter rules selecting the files to delete</param>
		void DeleteFiles(IReadOnlyList<string> deleteFiles)
		{
			FileFilter filter = new FileFilter(deleteFiles);

			List<FileReference> files = filter.ApplyToDirectory(_sandboxDir, false);
			foreach (FileReference file in files)
			{
				FileUtils.ForceDeleteFile(file);
			}
		}

		/// <summary>
		/// Executes a process using the strategy appropriate for the platform and request, reporting any failure to the remote over the channel
		/// </summary>
		async Task ExecuteProcessAsync(
			ComputeSocket socket,
			AgentMessageChannel channel,
			string executable,
			IReadOnlyList<string> arguments,
			string? workingDir,
			string? containerImageUrl,
			IReadOnlyDictionary<string, string?>? envVars,
			ExecuteProcessFlags flags,
			CancellationToken cancellationToken)
		{
			try
			{
				if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
				{
					await ExecuteProcessWindowsAsync(socket, channel, executable, arguments, workingDir, envVars, flags, cancellationToken);
				}
				else if (containerImageUrl != null)
				{
					await ExecuteProcessInContainerAsync(channel, executable, arguments, workingDir, containerImageUrl, envVars, flags, cancellationToken);
				}
				else
				{
					await ExecuteProcessInternalAsync(channel, executable, arguments, workingDir, envVars, flags, cancellationToken);
				}
			}
			catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
			{
				_logger.LogInformation("Compute process execution cancelled");
				await channel.SendExceptionAsync(new ComputeExecutionCancelledException(), cancellationToken);
			}
			catch (Exception ex)
			{
				await channel.SendExceptionAsync(ex, cancellationToken);
			}
		}

		/// <summary>
		/// Executes a process on Windows, bridging the compute socket to the child process and passing the bridge's buffer name through an environment variable
		/// </summary>
		async Task ExecuteProcessWindowsAsync(ComputeSocket socket, AgentMessageChannel channel, string executable, IReadOnlyList<string> arguments, string? workingDir, IReadOnlyDictionary<string, string?>? envVars, ExecuteProcessFlags flags, CancellationToken cancellationToken)
		{
			Dictionary<string, string?> newEnvVars = new Dictionary<string, string?>(_envVars);
			if (envVars != null)
			{
				foreach ((string name, string? value) in envVars)
				{
					// Use the indexer rather than Add(): a request variable with the same name as an agent-level
					// variable must override it instead of throwing ArgumentException.
					newEnvVars[name] = value;
				}
			}

			await using (WorkerComputeSocketBridge server = await WorkerComputeSocketBridge.CreateAsync(socket, _logger))
			{
				newEnvVars[WorkerComputeSocket.IpcEnvVar] = server.BufferName;

				_logger.LogInformation("Launching {Executable} {Arguments}", CommandLineArguments.Quote(executable), CommandLineArguments.Join(arguments));
				await ExecuteProcessInternalAsync(channel, executable, arguments, workingDir, newEnvVars, flags, cancellationToken);
				_logger.LogInformation("Finished executing process");
			}

			_logger.LogInformation("Child process has shut down");
		}

		/// <summary>
		/// Reads IPC messages from a child process and attaches the shared memory buffers it names to the socket.
		/// All buffers opened here are disposed when the loop ends.
		/// </summary>
		/// <param name="socket">Socket to attach buffers to</param>
		/// <param name="ipcReader">Reader for messages from the child process</param>
		/// <param name="cancellationTokens">Tokens which will each terminate the loop when cancelled</param>
		/// <param name="logger">Logger for diagnostics</param>
		internal static async Task ProcessIpcMessagesAsync(ComputeSocket socket, ComputeBufferReader ipcReader, CancellationToken[] cancellationTokens, ILogger logger)
		{
			using CancellationTokenSource cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationTokens);
			CancellationToken cancellationToken = cancellationTokenSource.Token;

			List<SharedMemoryBuffer> buffers = new();
			try
			{
				while (await ipcReader.WaitToReadAsync(1, cancellationToken))
				{
					ReadOnlyMemory<byte> memory = ipcReader.GetReadBuffer();
					MemoryReader reader = new MemoryReader(memory);

					IpcMessage message = (IpcMessage)reader.ReadUnsignedVarInt();
					try
					{
						switch (message)
						{
							case IpcMessage.AttachSendBuffer:
								{
									int channelId = (int)reader.ReadUnsignedVarInt();
									string name = reader.ReadString();
									logger.LogDebug("Attaching send buffer for channel {ChannelId} to {Name}", channelId, name);

									SharedMemoryBuffer buffer = SharedMemoryBuffer.OpenExisting(name);
									buffers.Add(buffer);

									socket.AttachSendBuffer(channelId, buffer);
								}
								break;
							case IpcMessage.AttachRecvBuffer:
								{
									int channelId = (int)reader.ReadUnsignedVarInt();
									string name = reader.ReadString();
									logger.LogDebug("Attaching recv buffer for channel {ChannelId} to {Name}", channelId, name);

									SharedMemoryBuffer buffer = SharedMemoryBuffer.OpenExisting(name);
									buffers.Add(buffer);

									socket.AttachRecvBuffer(channelId, buffer);
								}
								break;
							default:
								throw new InvalidOperationException($"Invalid IPC message: {message}");
						}
					}
					catch (Exception ex)
					{
						// A bad message is logged and skipped; the loop keeps servicing the child process
						logger.LogError(ex, "Exception while processing messages from child process: {Message}", ex.Message);
					}

					// Consume exactly the bytes that were parsed from the read buffer
					ipcReader.AdvanceReadPosition(memory.Length - reader.RemainingMemory.Length);
				}
			}
			catch (OperationCanceledException)
			{
				logger.LogDebug("Ipc message loop cancelled");
			}
			finally
			{
				foreach (SharedMemoryBuffer buffer in buffers)
				{
					buffer.Dispose();
				}
			}
		}

		// Helper class to take raw UTF8 output and merge it into log lines
		class ProcessOutputWriter
		{
			readonly string _prefix;
			readonly ByteArrayBuilder _lineBuffer = new ByteArrayBuilder();
			readonly ILogger _logger;

			public ProcessOutputWriter(string prefix, ILogger logger)
			{
				_prefix = prefix;
				_logger = logger;
			}

			// Logs every complete line in the span; a trailing partial line is buffered until its newline arrives
			public void WriteBytes(ReadOnlySpan<byte> span)
			{
				for (; ; )
				{
					int newlineIdx = span.IndexOf((byte)'\n');
					if (newlineIdx == -1)
					{
						_lineBuffer.WriteFixedLengthBytes(span);
						break;
					}

					// Strip the CR of a CRLF line ending
					ReadOnlySpan<byte> line = span.Slice(0, newlineIdx);
					if (line.Length > 0 && line[^1] == (byte)'\r')
					{
						line = line.Slice(0, line.Length - 1);
					}

					if (_lineBuffer.Length > 0)
					{
						_lineBuffer.WriteFixedLengthBytes(line);
						_logger.LogInformation("{Prefix}: {Line}", _prefix, Encoding.UTF8.GetString(_lineBuffer.AsMemory().Span));
						_lineBuffer.Clear();
					}
					else
					{
						_logger.LogInformation("{Prefix}: {Line}", _prefix, Encoding.UTF8.GetString(line));
					}

					span = span.Slice(newlineIdx + 1);
				}
			}
		}

		/// <summary>
		/// Loads and runs a managed assembly inside the current process, then sends its exit code over the channel.
		/// The first argument is the assembly path relative to the sandbox; the remaining arguments are passed to its entry point.
		/// Environment variables and the working directory of the current process are changed for the duration of the call and restored afterwards.
		/// </summary>
		async Task ExecuteProcessAssemblyAsync(AgentMessageChannel channel, IReadOnlyList<string> arguments, string? workingDir, IReadOnlyDictionary<string, string?>? envVars, CancellationToken cancellationToken)
		{
			List<(string, string?)> prevEnvVars = new List<(string, string?)>();
			string prevWorkingDir = Directory.GetCurrentDirectory();
			try
			{
				// Setup happens inside the try block so that a failure part-way through is still rolled back below
				if (envVars != null)
				{
					foreach ((string key, string? value) in envVars)
					{
						prevEnvVars.Add((key, Environment.GetEnvironmentVariable(key)));
						Environment.SetEnvironmentVariable(key, value);
					}
				}

				Directory.SetCurrentDirectory(GetWorkingDirAbsPath(workingDir));

				string assemblyPath = FileReference.Combine(_sandboxDir, arguments[0]).FullName;
				string[] mainArgs = arguments.Skip(1).ToArray();
				_logger.LogWarning("Note: Loading and running {Assembly} in process", assemblyPath);

				TaskCompletionSource<int> resultTcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
				Thread thread = new Thread(() =>
				{
					try
					{
						resultTcs.SetResult(AppDomain.CurrentDomain.ExecuteAssembly(assemblyPath, mainArgs));
					}
					catch (Exception ex)
					{
						// Forward the failure to the awaiting caller. Left unhandled on this thread it would
						// take down the whole process and the task below would never complete.
						resultTcs.TrySetException(ex);
					}
				});
				thread.Start();

				int result = await resultTcs.Task;
				await channel.SendExecuteResultAsync(result, cancellationToken);
			}
			finally
			{
				Directory.SetCurrentDirectory(prevWorkingDir);

				foreach ((string key, string? value) in prevEnvVars)
				{
					Environment.SetEnvironmentVariable(key, value);
				}
			}
		}

		/// <summary>
		/// Executes a process inside a container by invoking the configured container engine with a "run" command line.
		/// The sandbox directory is mounted read/write at the same path and the resolved environment is passed through a temporary env file.
		/// </summary>
		async Task ExecuteProcessInContainerAsync(AgentMessageChannel channel, string executable, IReadOnlyList<string> arguments, string? workingDir, string containerImageUrl, IReadOnlyDictionary<string, string?>? envVars, ExecuteProcessFlags flags, CancellationToken cancellationToken)
		{
			if (!RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
			{
				throw new Exception("Only Linux is supported for executing a process inside a container");
			}

			if (_containerEngineExecutable == null)
			{
				throw new Exception("Container execution requested but agent has no container engine configured");
			}

			string resolvedExecutable = FileReference.Combine(_sandboxDir, executable).FullName;
			uint linuxUid = getuid();
			uint linuxGid = getgid();

			// Resolve env vars here even if they are resolved later in ExecuteProcessInternalAsync
			// The environment file must be written at this step
			Dictionary<string, string> resolvedEnvVars = ResolveEnvVars(envVars);

			string envFilePath = Path.GetTempFileName();
			try
			{
				StringBuilder sb = new();
				foreach ((string key, string value) in resolvedEnvVars)
				{
					sb.AppendLine($"{key}={value}");
				}
				await File.WriteAllTextAsync(envFilePath, sb.ToString(), cancellationToken);

				List<string> resolvedArguments = new()
				{
					"run",
					"--tty", // Allocate a pseudo-TTY
					"--rm", // Ensure container is removed after run
					$"--user={linuxUid}:{linuxGid}", // Run container as current user (important for mounted dirs)
					$"--volume={_sandboxDir}:{_sandboxDir}:rw",
					"--env-file=" + envFilePath,
				};

				if (flags.HasFlag(ExecuteProcessFlags.ReplaceContainerEntrypoint))
				{
					resolvedArguments.Add("--entrypoint=" + resolvedExecutable);
					resolvedArguments.Add(containerImageUrl);
				}
				else
				{
					resolvedArguments.Add(containerImageUrl);
					resolvedArguments.Add(resolvedExecutable); // Add executable as first argument and assume the entrypoint inside the container image will handle this
				}

				resolvedArguments.AddRange(arguments);

				_logger.LogInformation("Executing {File} {Arguments} in container", _containerEngineExecutable, arguments);

				// Skip forwarding of env vars as they are explicitly set above as arguments to container run
				await ExecuteProcessInternalAsync(channel, _containerEngineExecutable, resolvedArguments, workingDir, new Dictionary<string, string?>(), flags, cancellationToken);
			}
			finally
			{
				// The env file holds the complete resolved environment; remove it once the container engine has
				// exited (or failed to launch) rather than leaving one file behind per execution.
				try
				{
					File.Delete(envFilePath);
				}
				catch (Exception ex) when (ex is IOException or UnauthorizedAccessException)
				{
					_logger.LogWarning(ex, "Unable to delete temporary env file {Path}", envFilePath);
				}
			}
		}

		/// <summary>
		/// Launches a process (optionally through Wine, or in-process for dotnet when enabled), forwarding its output over the channel
		/// and to the log as it arrives, followed by the exit code once the output stream ends.
		/// </summary>
		async Task ExecuteProcessInternalAsync(AgentMessageChannel channel, string executable, IReadOnlyList<string> arguments, string? workingDir, IReadOnlyDictionary<string, string?>? envVars, ExecuteProcessFlags flags, CancellationToken cancellationToken)
		{
			string resolvedExecutable = GetExecutableAbsPath(executable);
			string resolvedWorkingDir = GetWorkingDirAbsPath(workingDir);

			if (_executeInProcess && Path.GetFileNameWithoutExtension(resolvedExecutable).Equals("dotnet", StringComparison.OrdinalIgnoreCase))
			{
				await ExecuteProcessAssemblyAsync(channel, arguments, workingDir, envVars, cancellationToken);
			}
			else
			{
				string resolvedCommandLine = CommandLineArguments.Join(arguments);
				if (flags.HasFlag(ExecuteProcessFlags.UseWine) && _wineExecutablePath != null)
				{
					// Path to the original Windows executable is prepended to the argument list so Wine can run it
					resolvedCommandLine = CommandLineArguments.Join(new[] { resolvedExecutable }.Concat(arguments).ToList());
					resolvedExecutable = _wineExecutablePath;
				}

				Dictionary<string, string> resolvedEnvVars = ResolveEnvVars(envVars);

				if (!File.Exists(resolvedExecutable))
				{
					_logger.LogWarning("Executable {Path} does not exist", resolvedExecutable);
				}

				if (!Directory.Exists(resolvedWorkingDir))
				{
					_logger.LogWarning("Working dir {Path} does not exist", resolvedWorkingDir);
				}

				using ManagedProcessGroup group = new ManagedProcessGroup();
				using ManagedProcess process = new ManagedProcess(group, resolvedExecutable, resolvedCommandLine, resolvedWorkingDir, resolvedEnvVars, null, ProcessPriorityClass.Normal);

				byte[] buffer = new byte[1024];
				ProcessOutputWriter outputWriter = new ProcessOutputWriter($"{Path.GetFileNameWithoutExtension(resolvedExecutable)}> ", _logger);

				for (; ; )
				{
					// Use WaitAsync() as ReadAsync() does not respect the cancellation token when reading
					int length = await process.ReadAsync(buffer, 0, buffer.Length, cancellationToken).AsTask().WaitAsync(cancellationToken);
					if (length == 0)
					{
						// End of output: wait for the process to exit and report its exit code
						await process.WaitForExitAsync(cancellationToken);
						await channel.SendExecuteResultAsync(process.ExitCode, cancellationToken);
						return;
					}

					ReadOnlyMemory<byte> output = buffer.AsMemory(0, length);
					await channel.SendExecuteOutputAsync(output, cancellationToken);
					outputWriter.WriteBytes(output.Span);
				}
			}
		}

		// Resolves an executable path relative to the sandbox directory
		private string GetExecutableAbsPath(string relPath)
		{
			return FileReference.Combine(_sandboxDir, relPath).FullName;
		}

		// Resolves a working directory relative to the sandbox directory; null maps to the sandbox directory itself
		private string GetWorkingDirAbsPath(string? relPath)
		{
			return DirectoryReference.Combine(_sandboxDir, relPath ?? String.Empty).FullName;
		}

		/// <summary>
		/// Flattens and merges available env vars to be used for compute process execution.
		/// Starts from the current process environment, applies the agent-level variables (null values are skipped),
		/// then the supplied variables (a null value removes the variable).
		/// </summary>
		/// <param name="envVars">Optional extra env vars</param>
		/// <returns>Merged environment variables</returns>
		private Dictionary<string, string> ResolveEnvVars(IReadOnlyDictionary<string, string?>? envVars)
		{
			Dictionary<string, string> resolvedEnvVars = ManagedProcess.GetCurrentEnvVars();

			foreach ((string key, string? value) in _envVars)
			{
				if (value != null)
				{
					resolvedEnvVars[key] = value;
				}
			}

			if (envVars != null)
			{
				foreach ((string key, string? value) in envVars)
				{
					if (value == null)
					{
						resolvedEnvVars.Remove(key);
					}
					else
					{
						resolvedEnvVars[key] = value;
					}
				}
			}

			return resolvedEnvVars;
		}

		/// <summary>
		/// Get user identity (Linux only)
		/// </summary>
		/// <returns>Real user ID of the calling process</returns>
		[DllImport("libc", SetLastError = true)]
		internal static extern uint getuid();

		/// <summary>
		/// Get group identity (Linux only)
		/// </summary>
		/// <returns>Real group ID of the calling process</returns>
		[DllImport("libc", SetLastError = true)]
		internal static extern uint getgid();
	}
}