// Copyright Epic Games, Inc. All Rights Reserved.

using System;
using System.Collections.Generic;
using System.IO;
using System.Net;
using System.Net.Http;
using System.Net.Http.Json;
using System.Net.Sockets;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using EpicGames.Core;
using EpicGames.Horde.Agents;
using EpicGames.Horde.Compute.Transports;
using Microsoft.Extensions.Logging;

namespace EpicGames.Horde.Compute.Clients
{
	/// <summary>
	/// Handshake request message for tunneling server
	/// </summary>
	/// <param name="Host">Target host to relay traffic to/from</param>
	/// <param name="Port">Target port</param>
	public record TunnelHandshakeRequest(string Host, int Port)
	{
		const int Version = 1;
		const string Name = "HANDSHAKE-REQ";

		/// <summary>
		/// Serialize the message as four tab-separated fields: name, version, host and port
		/// </summary>
		/// <returns>A string based representation</returns>
		public string Serialize()
		{
			return $"{Name}\t{Version}\t{Host}\t{Port}";
		}

		/// <summary>
		/// Deserialize the message
		/// </summary>
		/// <param name="text">A raw string to deserialize</param>
		/// <returns>A request message</returns>
		/// <exception cref="Exception">Thrown if the text does not have exactly four fields, the wrong message name, or a non-numeric version or port</exception>
		public static TunnelHandshakeRequest Deserialize(string? text)
		{
			string[] parts = (text ?? "").Split('\t');
			if (parts.Length != 4 || parts[0] != Name || !Int32.TryParse(parts[1], out int _) || !Int32.TryParse(parts[3], out int port))
			{
				throw new Exception("Failed deserializing handshake request. Content: " + text);
			}
			return new TunnelHandshakeRequest(parts[2], port);
		}
	}

	/// <summary>
	/// Handshake response message for tunneling server
	/// </summary>
	/// <param name="IsSuccess">Whether successful or not</param>
	/// <param name="Message">Message with additional information describing the outcome</param>
	public record TunnelHandshakeResponse(bool IsSuccess, string Message)
	{
		const int Version = 1;
		const string Name = "HANDSHAKE-RES";

		/// <summary>
		/// Serialize the message as four tab-separated fields: name, version, success flag and message
		/// </summary>
		/// <returns>A string based representation</returns>
		public string Serialize()
		{
			return $"{Name}\t{Version}\t{IsSuccess}\t{Message}";
		}

		/// <summary>
		/// Deserialize the message
		/// </summary>
		/// <param name="text">A raw string to deserialize</param>
		/// <returns>A response message</returns>
		/// <exception cref="Exception">Thrown if the text has fewer than four fields, the wrong message name, or an unparsable version or success flag</exception>
		public static TunnelHandshakeResponse Deserialize(string? text)
		{
			// Limit the split to four fields: the message is free-form text and comes last, so any
			// tab characters inside it must be kept as part of the message rather than treated as separators.
			string[] parts = (text ?? "").Split('\t', 4);
			if (parts.Length != 4 || parts[0] != Name || !Int32.TryParse(parts[1], out int _) || !Boolean.TryParse(parts[2], out bool isSuccess))
			{
				throw new Exception("Failed deserializing handshake response. Content: " + text);
			}
			return new TunnelHandshakeResponse(isSuccess, parts[3]);
		}
	}

	/// <summary>
	/// Exception for ServerComputeClient
	/// </summary>
	public class ServerComputeClientException : ComputeException
	{
		/// <inheritdoc/>
		public ServerComputeClientException(string message) : base(message)
		{
		}

		/// <inheritdoc/>
		public ServerComputeClientException(string? message, Exception? innerException) : base(message, innerException)
		{
		}
	}

	/// <summary>
	/// Helper class to enlist remote resources to perform compute-intensive tasks.
	/// </summary>
	public sealed class ServerComputeClient : IComputeClient, IDisposable
	{
		/// <summary>
		/// Length of the nonce sent as part of handshaking between initiator and remote
		/// </summary>
		public const int NonceLength = 64;

		// Snapshot of everything a lease exposes; produced by ConnectAsync once the connection is established
		record LeaseInfo(
			ClusterId Cluster,
			IReadOnlyList<string> Properties,
			IReadOnlyDictionary<string, int> AssignedResources,
			RemoteComputeSocket Socket,
			string Ip,
			ConnectionMode ConnectionMode,
			IReadOnlyDictionary<string, ConnectionMetadataPort> Ports);

		// Lease handed out to callers. Wraps the ConnectAsync enumerator, which is suspended at its
		// single yield for as long as the lease is alive, and sends keep-alive messages in the background.
		class LeaseImpl : IComputeLease
		{
			public ClusterId Cluster => _source.Current.Cluster;
			public IReadOnlyList<string> Properties => _source.Current.Properties;
			public IReadOnlyDictionary<string, int> AssignedResources => _source.Current.AssignedResources;
			public RemoteComputeSocket Socket => _source.Current.Socket;
			public string Ip => _source.Current.Ip;
			public ConnectionMode ConnectionMode => _source.Current.ConnectionMode;
			public IReadOnlyDictionary<string, ConnectionMetadataPort> Ports => _source.Current.Ports;

			private readonly IAsyncEnumerator<LeaseInfo> _source;
			private BackgroundTask? _pingTask;

			public LeaseImpl(IAsyncEnumerator<LeaseInfo> source)
			{
				_source = source;
				_pingTask = BackgroundTask.StartNew(PingAsync);
			}

			/// <inheritdoc/>
			public async ValueTask DisposeAsync()
			{
				if (_pingTask != null)
				{
					await _pingTask.DisposeAsync();
					_pingTask = null;
				}

				// Resume ConnectAsync past its yield so that the resources it declared with
				// 'using' / 'await using' (socket, transport, compute socket) are released.
				await _source.MoveNextAsync();
				await _source.DisposeAsync();
			}

			/// <inheritdoc/>
			public async ValueTask CloseAsync(CancellationToken cancellationToken)
			{
				if (_pingTask != null)
				{
					await _pingTask.DisposeAsync();
					_pingTask = null;
				}
				await Socket.CloseAsync(cancellationToken);
			}

			// Sends a keep-alive message every five seconds until cancelled
			async Task PingAsync(CancellationToken cancellationToken)
			{
				while (!cancellationToken.IsCancellationRequested)
				{
					await Socket.SendKeepAliveMessageAsync(cancellationToken);
					await Task.Delay(TimeSpan.FromSeconds(5.0), cancellationToken);
				}
			}
		}

		readonly HttpClient _httpClient;
		readonly CancellationTokenSource _cancellationSource = new CancellationTokenSource();
		readonly string _sessionId;
		readonly ILogger _logger;
		readonly ExternalIpResolver _externalIpResolver;

		/// <summary>
		/// Constructor
		/// </summary>
		/// <param name="httpClient">Factory for constructing http client instances</param>
		/// <param name="logger">Logger for diagnostic messages</param>
		public ServerComputeClient(HttpClient httpClient, ILogger logger) : this(httpClient, null, logger)
		{
		}

		/// <summary>
		/// Constructor
		/// </summary>
		/// <param name="httpClient">Factory for constructing http client instances</param>
		/// <param name="sessionId">Arbitrary ID used for identifying this compute client. If not provided, a random one will be generated</param>
		/// <param name="logger">Logger for diagnostic messages</param>
		public ServerComputeClient(HttpClient httpClient, string? sessionId, ILogger logger)
		{
			_httpClient = httpClient;
			_sessionId = sessionId ?? Guid.NewGuid().ToString();
			_logger = logger;
			_externalIpResolver = new ExternalIpResolver(_httpClient);
		}

		/// <inheritdoc/>
		public void Dispose()
		{
			_cancellationSource.Dispose();
		}

		/// <inheritdoc/>
		public async Task<ClusterId> GetClusterAsync(Requirements? requirements, string? requestId, ConnectionMetadataRequest? connection, ILogger logger, CancellationToken cancellationToken = default)
		{
			AssignComputeRequest request = new()
			{
				Requirements = requirements,
				RequestId = requestId,
				Connection = connection,
				Protocol = (int)ComputeProtocol.Latest
			};

			// Link the caller's token with the client-wide one so the request itself can be cancelled by the caller
			using CancellationTokenSource requestCancellation = CancellationTokenSource.CreateLinkedTokenSource(_cancellationSource.Token, cancellationToken);
			using HttpResponseMessage httpResponse = await HordeHttpRequest.PostAsync(_httpClient, "api/v2/compute/_cluster", request, requestCancellation.Token);
			if (!httpResponse.IsSuccessStatusCode)
			{
				string body = await httpResponse.Content.ReadAsStringAsync(cancellationToken);
				throw new ComputeClientException($"Unable to find suitable cluster. HTTP status code {httpResponse.StatusCode}: {body}");
			}

			GetClusterResponse? response = await httpResponse.Content.ReadFromJsonAsync<GetClusterResponse>(HordeHttpClient.JsonSerializerOptions, cancellationToken);
			if (response == null)
			{
				throw new InvalidOperationException();
			}

			return response.ClusterId;
		}

		/// <inheritdoc/>
		public async Task<IComputeLease?> TryAssignWorkerAsync(ClusterId? clusterId, Requirements? requirements, string? requestId, ConnectionMetadataRequest? connection, ILogger logger, CancellationToken cancellationToken)
		{
			try
			{
				IAsyncEnumerator<LeaseInfo> source = ConnectAsync(clusterId, requirements, requestId, connection, logger, cancellationToken).GetAsyncEnumerator(cancellationToken);
				if (!await source.MoveNextAsync())
				{
					// ConnectAsync finished without yielding: no compute resource was available
					await source.DisposeAsync();
					return null;
				}
				return new LeaseImpl(source);
			}
			catch (Polly.Timeout.TimeoutRejectedException ex)
			{
				_logger.LogInformation(ex, "Unable to assign worker from pool {ClusterId} (timeout)", clusterId);
				return null;
			}
		}

		/// <inheritdoc/>
		public async Task DeclareResourceNeedsAsync(ClusterId clusterId, string pool, Dictionary<string, int> resourceNeeds, CancellationToken cancellationToken = default)
		{
			ResourceNeedsMessage request = new() { SessionId = _sessionId, Pool = pool, ResourceNeeds = resourceNeeds };

			// Link the caller's token with the client-wide one so the request itself can be cancelled by the caller
			using CancellationTokenSource requestCancellation = CancellationTokenSource.CreateLinkedTokenSource(_cancellationSource.Token, cancellationToken);
			using HttpResponseMessage response = await HordeHttpRequest.PostAsync(_httpClient, $"api/v2/compute/{clusterId}/resource-needs", request, requestCancellation.Token);
			response.EnsureSuccessStatusCode();
		}

		// Requests a compute worker from the server, connects to it and yields a single LeaseInfo.
		// The enumerator stays suspended at the yield for the lifetime of the lease; resuming or
		// disposing it releases the socket and transports declared below.
		async IAsyncEnumerable<LeaseInfo> ConnectAsync(ClusterId? clusterId, Requirements? requirements, string? requestId, ConnectionMetadataRequest? connection, ILogger workerLogger, [EnumeratorCancellation] CancellationToken cancellationToken)
		{
			_logger.LogDebug("Requesting compute resource");

			// Assign a compute worker
			AssignComputeRequest request = new AssignComputeRequest();
			request.Requirements = requirements;
			request.RequestId = requestId;
			request.Connection = connection;
			request.Protocol = (int)ComputeProtocol.Latest;

			if (connection is { ModePreference: ConnectionMode.Relay })
			{
				connection.ClientPublicIp = (await _externalIpResolver.GetExternalIpAddressAsync(cancellationToken)).ToString();
			}

			AssignComputeResponse? response;
			string path = clusterId == null ? "api/v2/compute" : $"api/v2/compute/{clusterId}";

			// Link the caller's token with the client-wide one so the request itself can be cancelled by the caller.
			// Scoped to this block so the link is released before the (long-lived) yield below.
			using (CancellationTokenSource requestCancellation = CancellationTokenSource.CreateLinkedTokenSource(_cancellationSource.Token, cancellationToken))
			using (HttpResponseMessage httpResponse = await HordeHttpRequest.PostAsync(_httpClient, path, request, requestCancellation.Token))
			{
				if (httpResponse.StatusCode == HttpStatusCode.NotFound)
				{
					throw new NoComputeAgentsFoundException(clusterId ?? new ClusterId("null"), requirements);
				}

				if (httpResponse.StatusCode is HttpStatusCode.ServiceUnavailable or HttpStatusCode.TooManyRequests)
				{
					_logger.LogDebug("No compute resource is available. Reason: {Reason}", await httpResponse.Content.ReadAsStringAsync(cancellationToken));
					yield break;
				}

				if (httpResponse.StatusCode == HttpStatusCode.Unauthorized)
				{
					string? content;
					try
					{
						content = await httpResponse.Content.ReadAsStringAsync(cancellationToken);
					}
					catch
					{
						// Best effort only; the body is just extra detail for the exception message
						content = "None";
					}
					throw new ComputeClientException($"Bad authentication credentials. Check or refresh token. (HTTP status {httpResponse.StatusCode}, response: {content})");
				}

				if (httpResponse.StatusCode == HttpStatusCode.Forbidden)
				{
					LogEvent? logEvent = await httpResponse.Content.ReadFromJsonAsync<LogEvent>(HordeHttpClient.JsonSerializerOptions, cancellationToken);
					if (logEvent != null)
					{
						throw new ComputeClientException($"{logEvent.Message} (HTTP status {httpResponse.StatusCode})");
					}
					// No structured error in the body: fall through to EnsureSuccessStatusCode below
				}

				if (httpResponse.StatusCode == HttpStatusCode.InternalServerError)
				{
					string? content;
					try
					{
						content = await httpResponse.Content.ReadAsStringAsync(cancellationToken);
					}
					catch
					{
						// Best effort only; the body is just extra detail for the exception message
						content = "None";
					}
					throw new ComputeClientException($"InternalServerError requesting compute resources: \"{content}\"");
				}

				httpResponse.EnsureSuccessStatusCode();
				response = await httpResponse.Content.ReadFromJsonAsync<AssignComputeResponse>(HordeHttpClient.JsonSerializerOptions, cancellationToken);
				if (response == null)
				{
					throw new InvalidOperationException();
				}
			}

			string nonce = response.Nonce;
			nonce = nonce.Length <= 8 ? nonce : $"{nonce[..4]}...{nonce[^4..]}"; // Trim large nonce strings

			(string host, int port) agentAddress = (response.Ip, response.Port); // Canonical address of agent not accounting for relays
			(string host, int port) connectionAddress = agentAddress; // De facto address of agent, accounting for relays

			if (response.ConnectionMode == ConnectionMode.Relay && !String.IsNullOrEmpty(response.ConnectionAddress))
			{
				agentAddress.port = response.Ports[ConnectionMetadataPort.ComputeId].AgentPort;
				connectionAddress = (response.ConnectionAddress, response.Ports[ConnectionMetadataPort.ComputeId].Port);
			}

			// Connect to the remote machine
			using Socket socket = new (SocketType.Stream, ProtocolType.Tcp);

			workerLogger.LogDebug(
				"Connecting to {AgentId}. Agent={AgentHost}:{AgentPort} Connection={ConnectionMode}/{ConnectionHost}:{ConnectionPort} Encryption={Encryption} LeaseId={LeaseId} Requirements={Requirements} PublicIp={PublicIp} Nonce={Nonce}",
				response.AgentId,
				agentAddress.host,
				agentAddress.port,
				response.ConnectionMode,
				connectionAddress.host,
				connectionAddress.port,
				response.Encryption,
				response.LeaseId,
				request.Requirements,
				request.Connection?.ClientPublicIp,
				nonce);

			try
			{
				switch (response.ConnectionMode)
				{
					case ConnectionMode.Direct:
						await socket.ConnectAsync(connectionAddress.host, connectionAddress.port, cancellationToken);
						break;

					case ConnectionMode.Tunnel when !String.IsNullOrEmpty(response.ConnectionAddress):
						(connectionAddress.host, connectionAddress.port) = ParseHostPort(response.ConnectionAddress);
						await socket.ConnectAsync(connectionAddress.host, connectionAddress.port, cancellationToken);
						await TunnelHandshakeAsync(socket, response, cancellationToken);
						break;

					case ConnectionMode.Relay when !String.IsNullOrEmpty(response.ConnectionAddress):
						response.Ip = response.ConnectionAddress;
						await ConnectWithRetryAsync(socket, connectionAddress.host, connectionAddress.port, TimeSpan.FromSeconds(5), 3, cancellationToken);
						break;

					default:
						throw new Exception($"Unable to resolve connection mode ({response.ConnectionMode} via {response.ConnectionAddress ?? "none"})");
				}
			}
			catch (SocketException se)
			{
				throw new ServerComputeClientException($"Unable to connect to {connectionAddress.host}:{connectionAddress.port} with mode {response.ConnectionMode}", se);
			}

			// Send the nonce
			byte[] nonceData = StringUtils.ParseHexString(response.Nonce);
			await socket.SendMessageAsync(nonceData, SocketFlags.None, cancellationToken);
			workerLogger.LogInformation("Connected to {AgentId} ({Ip}) under lease {LeaseId} (agent version: {AgentVersion})", response.AgentId, response.Ip, response.LeaseId, response.AgentVersion ?? "unknown");

			response.Properties = [..response.Properties, $"{KnownPropertyNames.LeaseId}={response.LeaseId}"];

			await using ComputeTransport transport = await CreateTransportAsync(socket, response, cancellationToken);
			await using RemoteComputeSocket computeSocket = new(transport, (ComputeProtocol)response.Protocol, workerLogger);
			yield return new LeaseInfo(response.ClusterId, response.Properties, response.AssignedResources, computeSocket, response.Ip, response.ConnectionMode, response.Ports);
		}

		// Attempts to connect up to maxRetries times, bounding each attempt by the given timeout and
		// waiting one second between attempts. Throws TimeoutException if the final attempt times out
		// and rethrows the SocketException if the final attempt fails with a socket error.
		private async Task ConnectWithRetryAsync(Socket socket, string host, int port, TimeSpan timeout, int maxRetries, CancellationToken cancellationToken)
		{
			TimeSpan retryDelay = TimeSpan.FromSeconds(1);
			for (int attempt = 1; attempt <= maxRetries; attempt++)
			{
				try
				{
					using CancellationTokenSource timeoutCts = new (timeout);
					using CancellationTokenSource linkedCts = CancellationTokenSource.CreateLinkedTokenSource(timeoutCts.Token, cancellationToken);
					await socket.ConnectAsync(host, port, linkedCts.Token);
					return;
				}
				catch (OperationCanceledException) when (!cancellationToken.IsCancellationRequested)
				{
					// Only the per-attempt timeout fired. Cancellation requested by the caller is deliberately
					// excluded by the filter so it propagates as-is instead of being reported as a timeout.
					if (attempt == maxRetries)
					{
						throw new TimeoutException($"Failed to connect {host}:{port} after {maxRetries} attempts");
					}
				}
				catch (SocketException se)
				{
					if (attempt == maxRetries)
					{
						throw;
					}

					_logger.LogInformation("Unable to connect to {Host}:{Port} within {Timeout} ms. Error code: {Error}. Waiting {RetryDelay} ms before retrying...",
						host, port, (int)timeout.TotalMilliseconds, se.SocketErrorCode, (int)retryDelay.TotalMilliseconds);
				}

				await Task.Delay(retryDelay, cancellationToken);
			}
		}

		// Wraps the connected socket in the transport matching the encryption mode chosen by the server
		private static async Task<ComputeTransport> CreateTransportAsync(Socket socket, AssignComputeResponse response, CancellationToken cancellationToken)
		{
			switch (response.Encryption)
			{
				case Encryption.Ssl:
				case Encryption.SslEcdsaP256:
					TcpSslTransport sslTransport = new(socket, StringUtils.ParseHexString(response.Certificate), false);
					await sslTransport.AuthenticateAsync(cancellationToken);
					return sslTransport;

				case Encryption.Aes:
#pragma warning disable CA2000 // Dispose objects before losing scope
					TcpTransport tcpTransport = new(socket);
					return new AesTransport(tcpTransport, StringUtils.ParseHexString(response.Key));
#pragma warning restore CA2000 // Restore CA2000

				case Encryption.None:
				default:
					return new TcpTransport(socket);
			}
		}

		// Splits a "host:port" string into its two components; any parse failure is wrapped in an Exception naming the address
		private static (string host, int port) ParseHostPort(string address)
		{
			try
			{
				string[] parts = address.Split(":");
				string host = parts[0];
				int port = Int32.Parse(parts[1]);
				return (host, port);
			}
			catch (Exception e)
			{
				throw new Exception($"Unable to parse host and port for address: {address}", e);
			}
		}

		// Sends a TunnelHandshakeRequest for the agent's address over the already-connected socket and
		// waits up to 15 seconds for a one-line TunnelHandshakeResponse. Throws TimeoutException if no
		// response arrives in time, OperationCanceledException if the caller cancels, and Exception if
		// the server reports failure.
		private static async Task TunnelHandshakeAsync(Socket socket, AssignComputeResponse response, CancellationToken cancellationToken)
		{
			// ownsSocket is false: the socket must stay open for the compute transport after the handshake
			await using NetworkStream ns = new(socket, false);
			using StreamReader reader = new(ns);
			await using StreamWriter writer = new(ns) { AutoFlush = true };

			string request = new TunnelHandshakeRequest(response.Ip, response.Port).Serialize();
			await writer.WriteLineAsync(request.ToCharArray(), cancellationToken);

			string exceptionMetadata = $"Connection: {response.ConnectionAddress} Target: {response.Ip}:{response.Port}";

#pragma warning disable CA2016 // Forward the 'CancellationToken' parameter to methods
			Task<string?> readTask = reader.ReadLineAsync();
#pragma warning restore CA2016 // Forward the 'CancellationToken' parameter to methods
			Task timeoutTask = Task.Delay(15000, cancellationToken);
			if (await Task.WhenAny(readTask, timeoutTask) == timeoutTask)
			{
				// The delay task also completes when the caller cancels; report that as cancellation rather than a timeout
				cancellationToken.ThrowIfCancellationRequested();
				throw new TimeoutException($"Timed out reading tunnel handshake response. {exceptionMetadata}");
			}

			TunnelHandshakeResponse handshakeResponse = TunnelHandshakeResponse.Deserialize(await readTask);
			if (!handshakeResponse.IsSuccess)
			{
				throw new Exception($"Tunnel handshake failed! Reason: {handshakeResponse.Message} {exceptionMetadata}");
			}
		}
	}

	/// <summary>
	/// Exception indicating that no matching compute agents were found
	/// </summary>
	public sealed class NoComputeAgentsFoundException : ComputeClientException
	{
		/// <summary>
		/// The compute cluster requested
		/// </summary>
		public ClusterId ClusterId { get; }

		/// <summary>
		/// Requested agent requirements
		/// </summary>
		public Requirements? Requirements { get; }

		/// <summary>
		/// Constructor
		/// </summary>
		/// <param name="clusterId">The compute cluster requested</param>
		/// <param name="requirements">Requested agent requirements</param>
		public NoComputeAgentsFoundException(ClusterId clusterId, Requirements? requirements)
			: base($"No compute agents found matching '{requirements}' in cluster '{clusterId}'")
		{
			ClusterId = clusterId;
			Requirements = requirements;
		}
	}
}