// Copyright Epic Games, Inc. All Rights Reserved.
using System;
using System.Collections.Generic;
using System.IO;
using System.Net;
using System.Net.Http;
using System.Net.Http.Json;
using System.Net.Sockets;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using EpicGames.Core;
using EpicGames.Horde.Agents;
using EpicGames.Horde.Compute.Transports;
using Microsoft.Extensions.Logging;
namespace EpicGames.Horde.Compute.Clients
{
/// <summary>
/// Handshake request message for tunneling server
/// </summary>
/// <param name="Host">Target host to relay traffic to/from</param>
/// <param name="Port">Target port</param>
public record TunnelHandshakeRequest(string Host, int Port)
{
	const int Version = 1;
	const string Name = "HANDSHAKE-REQ";

	/// <summary>
	/// Serialize the message
	/// </summary>
	/// <returns>A tab-separated string of the form <c>HANDSHAKE-REQ\t{Version}\t{Host}\t{Port}</c></returns>
	public string Serialize()
	{
		return $"{Name}\t{Version}\t{Host}\t{Port}";
	}

	/// <summary>
	/// Deserialize the message
	/// </summary>
	/// <param name="text">A raw string to deserialize; null is treated as empty</param>
	/// <returns>A request message</returns>
	/// <exception cref="FormatException">The text does not have exactly four tab-separated fields, the name is not
	/// <c>HANDSHAKE-REQ</c>, or the version or port fields are not integers</exception>
	public static TunnelHandshakeRequest Deserialize(string? text)
	{
		string[] parts = (text ?? "").Split('\t');

		// The version field must be an integer, but its value is not compared against Version
		if (parts.Length != 4 || parts[0] != Name || !Int32.TryParse(parts[1], out int _) || !Int32.TryParse(parts[3], out int port))
		{
			throw new FormatException("Failed deserializing handshake request. Content: " + text);
		}
		return new TunnelHandshakeRequest(parts[2], port);
	}
}
/// <summary>
/// Handshake response message for tunneling server
/// </summary>
/// <param name="IsSuccess">Whether successful or not</param>
/// <param name="Message">Message with additional information describing the outcome</param>
public record TunnelHandshakeResponse(bool IsSuccess, string Message)
{
	const int Version = 1;
	const string Name = "HANDSHAKE-RES";

	/// <summary>
	/// Serialize the message
	/// </summary>
	/// <returns>A tab-separated string of the form <c>HANDSHAKE-RES\t{Version}\t{IsSuccess}\t{Message}</c></returns>
	public string Serialize()
	{
		return $"{Name}\t{Version}\t{IsSuccess}\t{Message}";
	}

	/// <summary>
	/// Deserialize the message
	/// </summary>
	/// <param name="text">A raw string to deserialize; null is treated as empty</param>
	/// <returns>A response message</returns>
	/// <exception cref="FormatException">The text does not have four tab-separated fields, the name is not
	/// <c>HANDSHAKE-RES</c>, the version is not an integer, or the success flag is not a boolean</exception>
	public static TunnelHandshakeResponse Deserialize(string? text)
	{
		// Split into at most four fields: the message is the last field and may itself contain tabs,
		// so anything after the third separator belongs to it (keeps Serialize/Deserialize a round trip)
		string[] parts = (text ?? "").Split('\t', 4);

		// Check the name for consistency with TunnelHandshakeRequest.Deserialize; the version value is not compared
		if (parts.Length != 4 || parts[0] != Name || !Int32.TryParse(parts[1], out int _) || !Boolean.TryParse(parts[2], out bool isSuccess))
		{
			throw new FormatException("Failed deserializing handshake response. Content: " + text);
		}
		return new TunnelHandshakeResponse(isSuccess, parts[3]);
	}
}
/// <summary>
/// Exception thrown by <see cref="ServerComputeClient"/>
/// </summary>
public class ServerComputeClientException : ComputeException
{
	/// <summary>
	/// Creates the exception, wrapping the failure that caused it
	/// </summary>
	/// <param name="message">Description of the error</param>
	/// <param name="innerException">The underlying exception, if any</param>
	public ServerComputeClientException(string? message, Exception? innerException)
		: base(message, innerException)
	{
	}

	/// <summary>
	/// Creates the exception with a message only
	/// </summary>
	/// <param name="message">Description of the error</param>
	public ServerComputeClientException(string message)
		: base(message)
	{
	}
}
/// <summary>
/// Helper class to enlist remote resources to perform compute-intensive tasks.
/// </summary>
public sealed class ServerComputeClient : IComputeClient, IDisposable
{
	/// <summary>
	/// Length of the nonce sent as part of handshaking between initiator and remote
	/// </summary>
	public const int NonceLength = 64;

	// Snapshot of an established connection, yielded by ConnectAsync while the connection is open
	record LeaseInfo(
		ClusterId Cluster,
		IReadOnlyList Properties,
		IReadOnlyDictionary AssignedResources,
		RemoteComputeSocket Socket,
		string Ip,
		ConnectionMode ConnectionMode,
		IReadOnlyDictionary Ports);

	/// <summary>
	/// Lease handed back to callers. It holds the enumerator produced by <see cref="ConnectAsync"/>, which stays suspended
	/// at its <c>yield return</c> while the lease is alive; every property reads from the <see cref="LeaseInfo"/> it yielded.
	/// A background task sends a keep-alive message over the socket every five seconds until the lease is closed or disposed.
	/// </summary>
	class LeaseImpl : IComputeLease
	{
		public ClusterId Cluster => _source.Current.Cluster;
		public IReadOnlyList Properties => _source.Current.Properties;
		public IReadOnlyDictionary AssignedResources => _source.Current.AssignedResources;
		public RemoteComputeSocket Socket => _source.Current.Socket;
		public string Ip => _source.Current.Ip;
		public ConnectionMode ConnectionMode => _source.Current.ConnectionMode;
		public IReadOnlyDictionary Ports => _source.Current.Ports;

		private readonly IAsyncEnumerator _source;
		private BackgroundTask? _pingTask;

		public LeaseImpl(IAsyncEnumerator source)
		{
			_source = source;
			_pingTask = BackgroundTask.StartNew(PingAsync);
		}

		/// <inheritdoc/>
		public async ValueTask DisposeAsync()
		{
			await StopPingAsync();
			try
			{
				// Resume the iterator past its yield so the socket/transport scopes in ConnectAsync are unwound
				await _source.MoveNextAsync();
			}
			finally
			{
				// Release the enumerator even if unwinding the connection throws
				await _source.DisposeAsync();
			}
		}

		/// <inheritdoc/>
		public async ValueTask CloseAsync(CancellationToken cancellationToken)
		{
			await StopPingAsync();
			await Socket.CloseAsync(cancellationToken);
		}

		// Stops the keep-alive background task if it is still running
		async ValueTask StopPingAsync()
		{
			if (_pingTask != null)
			{
				await _pingTask.DisposeAsync();
				_pingTask = null;
			}
		}

		// Sends a keep-alive message every five seconds until the background task's token is cancelled
		async Task PingAsync(CancellationToken cancellationToken)
		{
			while (!cancellationToken.IsCancellationRequested)
			{
				await Socket.SendKeepAliveMessageAsync(cancellationToken);
				await Task.Delay(TimeSpan.FromSeconds(5.0), cancellationToken);
			}
		}
	}

	readonly HttpClient _httpClient;
	readonly string _sessionId;
	readonly ILogger _logger;
	readonly ExternalIpResolver _externalIpResolver;

	/// <summary>
	/// Constructor
	/// </summary>
	/// <param name="httpClient">Http client used to talk to the Horde server</param>
	/// <param name="logger">Logger for diagnostic messages</param>
	public ServerComputeClient(HttpClient httpClient, ILogger logger) : this(httpClient, null, logger)
	{
	}

	/// <summary>
	/// Constructor
	/// </summary>
	/// <param name="httpClient">Http client used to talk to the Horde server</param>
	/// <param name="sessionId">Arbitrary ID used for identifying this compute client. If not provided, a random one will be generated</param>
	/// <param name="logger">Logger for diagnostic messages</param>
	public ServerComputeClient(HttpClient httpClient, string? sessionId, ILogger logger)
	{
		_httpClient = httpClient;
		_sessionId = sessionId ?? Guid.NewGuid().ToString();
		_logger = logger;
		_externalIpResolver = new ExternalIpResolver(_httpClient);
	}

	/// <inheritdoc/>
	public void Dispose()
	{
		// Nothing to release: the HttpClient is supplied by the caller and is not disposed here
	}

	/// <inheritdoc/>
	public async Task GetClusterAsync(Requirements? requirements, string? requestId, ConnectionMetadataRequest? connection, ILogger logger, CancellationToken cancellationToken = default)
	{
		AssignComputeRequest request = new()
		{
			Requirements = requirements,
			RequestId = requestId,
			Connection = connection,
			Protocol = (int)ComputeProtocol.Latest
		};

		using HttpResponseMessage httpResponse = await HordeHttpRequest.PostAsync(_httpClient, "api/v2/compute/_cluster", request, cancellationToken);
		if (!httpResponse.IsSuccessStatusCode)
		{
			string body = await httpResponse.Content.ReadAsStringAsync(cancellationToken);
			throw new ComputeClientException($"Unable to find suitable cluster. HTTP status code {httpResponse.StatusCode}: {body}");
		}

		GetClusterResponse? response = await httpResponse.Content.ReadFromJsonAsync(HordeHttpClient.JsonSerializerOptions, cancellationToken);
		if (response == null)
		{
			throw new InvalidOperationException("Empty response when resolving compute cluster");
		}
		return response.ClusterId;
	}

	/// <inheritdoc/>
	public async Task TryAssignWorkerAsync(ClusterId? clusterId, Requirements? requirements, string? requestId, ConnectionMetadataRequest? connection, ILogger logger, CancellationToken cancellationToken)
	{
		IAsyncEnumerator source = ConnectAsync(clusterId, requirements, requestId, connection, logger, cancellationToken).GetAsyncEnumerator(cancellationToken);

		// Ownership of the enumerator passes to the lease once one is created; otherwise it is disposed here,
		// including when connecting throws
		bool ownedByLease = false;
		try
		{
			if (!await source.MoveNextAsync())
			{
				return null;
			}

			LeaseImpl lease = new LeaseImpl(source);
			ownedByLease = true;
			return lease;
		}
		catch (Polly.Timeout.TimeoutRejectedException ex)
		{
			_logger.LogInformation(ex, "Unable to assign worker from pool {ClusterId} (timeout)", clusterId);
			return null;
		}
		finally
		{
			if (!ownedByLease)
			{
				await source.DisposeAsync();
			}
		}
	}

	/// <inheritdoc/>
	public async Task DeclareResourceNeedsAsync(ClusterId clusterId, string pool, Dictionary resourceNeeds, CancellationToken cancellationToken = default)
	{
		ResourceNeedsMessage request = new() { SessionId = _sessionId, Pool = pool, ResourceNeeds = resourceNeeds };
		using HttpResponseMessage response = await HordeHttpRequest.PostAsync(_httpClient, $"api/v2/compute/{clusterId}/resource-needs", request, cancellationToken);
		response.EnsureSuccessStatusCode();
	}

	/// <summary>
	/// Requests a compute resource from the server and connects to the assigned agent.
	/// </summary>
	/// <remarks>
	/// Yields at most one <see cref="LeaseInfo"/>. The socket, transport and compute socket stay open while the enumerator is
	/// suspended at the yield, and are released when it is resumed or disposed. Yields nothing when the server responds
	/// with 503 (ServiceUnavailable) or 429 (TooManyRequests).
	/// </remarks>
	/// <exception cref="NoComputeAgentsFoundException">The server responded with 404</exception>
	/// <exception cref="ComputeClientException">Authentication failed, access was forbidden, or the server reported an internal error</exception>
	/// <exception cref="ServerComputeClientException">The connection mode could not be resolved, or a socket error occurred while connecting</exception>
	async IAsyncEnumerable ConnectAsync(ClusterId? clusterId, Requirements? requirements, string? requestId, ConnectionMetadataRequest? connection, ILogger workerLogger, [EnumeratorCancellation] CancellationToken cancellationToken)
	{
		_logger.LogDebug("Requesting compute resource");

		// Assign a compute worker
		AssignComputeRequest request = new AssignComputeRequest();
		request.Requirements = requirements;
		request.RequestId = requestId;
		request.Connection = connection;
		request.Protocol = (int)ComputeProtocol.Latest;
		if (connection is { ModePreference: ConnectionMode.Relay })
		{
			// request.Connection refers to the same object, so the resolved public IP is sent with the request
			connection.ClientPublicIp = (await _externalIpResolver.GetExternalIpAddressAsync(cancellationToken)).ToString();
		}

		AssignComputeResponse? response;
		string path = clusterId == null ? "api/v2/compute" : $"api/v2/compute/{clusterId}";
		using (HttpResponseMessage httpResponse = await HordeHttpRequest.PostAsync(_httpClient, path, request, cancellationToken))
		{
			if (httpResponse.StatusCode == HttpStatusCode.NotFound)
			{
				throw new NoComputeAgentsFoundException(clusterId ?? new ClusterId("null"), requirements);
			}
			if (httpResponse.StatusCode is HttpStatusCode.ServiceUnavailable or HttpStatusCode.TooManyRequests)
			{
				_logger.LogDebug("No compute resource is available. Reason: {Reason}", await httpResponse.Content.ReadAsStringAsync(cancellationToken));
				yield break;
			}
			if (httpResponse.StatusCode == HttpStatusCode.Unauthorized)
			{
				string content = await ReadErrorContentAsync(httpResponse, cancellationToken);
				throw new ComputeClientException($"Bad authentication credentials. Check or refresh token. (HTTP status {httpResponse.StatusCode}, response: {content})");
			}
			if (httpResponse.StatusCode == HttpStatusCode.Forbidden)
			{
				LogEvent? logEvent = await httpResponse.Content.ReadFromJsonAsync(HordeHttpClient.JsonSerializerOptions, cancellationToken);
				if (logEvent != null)
				{
					throw new ComputeClientException($"{logEvent.Message} (HTTP status {httpResponse.StatusCode})");
				}
			}
			if (httpResponse.StatusCode == HttpStatusCode.InternalServerError)
			{
				string content = await ReadErrorContentAsync(httpResponse, cancellationToken);
				throw new ComputeClientException($"InternalServerError requesting compute resources: \"{content}\"");
			}

			httpResponse.EnsureSuccessStatusCode();
			response = await httpResponse.Content.ReadFromJsonAsync(HordeHttpClient.JsonSerializerOptions, cancellationToken);
			if (response == null)
			{
				throw new InvalidOperationException("Empty response when requesting compute resource");
			}
		}

		// Trim large nonce strings for logging
		string nonce = response.Nonce;
		nonce = nonce.Length <= 8 ? nonce : $"{nonce[..4]}...{nonce[^4..]}";

		(string host, int port) agentAddress = (response.Ip, response.Port); // Canonical address of agent not accounting for relays
		(string host, int port) connectionAddress = agentAddress; // De facto address to connect to, accounting for relays and tunnels
		if (response.ConnectionMode == ConnectionMode.Relay && !String.IsNullOrEmpty(response.ConnectionAddress))
		{
			agentAddress.port = response.Ports[ConnectionMetadataPort.ComputeId].AgentPort;
			connectionAddress = (response.ConnectionAddress, response.Ports[ConnectionMetadataPort.ComputeId].Port);
		}
		else if (response.ConnectionMode == ConnectionMode.Tunnel && !String.IsNullOrEmpty(response.ConnectionAddress))
		{
			// Tunnel connections go to the tunnel server; the agent address is passed to it during the handshake
			connectionAddress = ParseHostPort(response.ConnectionAddress);
		}

		// Connect to the remote machine
		workerLogger.LogDebug(
			"Connecting to {AgentId}. Agent={AgentHost}:{AgentPort} Connection={ConnectionMode}/{ConnectionHost}:{ConnectionPort} Encryption={Encryption} LeaseId={LeaseId} Requirements={Requirements} PublicIp={PublicIp} Nonce={Nonce}",
			response.AgentId,
			agentAddress.host,
			agentAddress.port,
			response.ConnectionMode,
			connectionAddress.host,
			connectionAddress.port,
			response.Encryption,
			response.LeaseId,
			request.Requirements,
			request.Connection?.ClientPublicIp,
			nonce);

		Socket connectedSocket;
		try
		{
			connectedSocket = await OpenSocketAsync(response, connectionAddress.host, connectionAddress.port, cancellationToken);
		}
		catch (SocketException se)
		{
			throw new ServerComputeClientException($"Unable to connect to {connectionAddress.host}:{connectionAddress.port} with mode {response.ConnectionMode}", se);
		}
		using Socket socket = connectedSocket;

		// Send the nonce issued with the assignment as the first message on the connection
		byte[] nonceData = StringUtils.ParseHexString(response.Nonce);
		await socket.SendMessageAsync(nonceData, SocketFlags.None, cancellationToken);
		workerLogger.LogInformation("Connected to {AgentId} ({Ip}) under lease {LeaseId} (agent version: {AgentVersion})", response.AgentId, response.Ip, response.LeaseId, response.AgentVersion ?? "unknown");

		response.Properties = [..response.Properties, $"{KnownPropertyNames.LeaseId}={response.LeaseId}"];

		await using ComputeTransport transport = await CreateTransportAsync(socket, response, cancellationToken);
		await using RemoteComputeSocket computeSocket = new(transport, (ComputeProtocol)response.Protocol, workerLogger);
		yield return new LeaseInfo(response.ClusterId, response.Properties, response.AssignedResources, computeSocket, response.Ip, response.ConnectionMode, response.Ports);
	}

	/// <summary>
	/// Opens a TCP connection for the response's connection mode. Direct mode connects to the given address. Tunnel mode
	/// connects to the given address and then performs the tunnel handshake. Relay mode connects with retries and records
	/// the relay address as <c>response.Ip</c>.
	/// </summary>
	/// <returns>A connected socket owned by the caller. Any socket created here is disposed if this method throws.</returns>
	/// <exception cref="ServerComputeClientException">The connection mode is unknown, or tunnel/relay mode has no connection address</exception>
	async Task<Socket> OpenSocketAsync(AssignComputeResponse response, string host, int port, CancellationToken cancellationToken)
	{
		bool tunnel;
		switch (response.ConnectionMode)
		{
			case ConnectionMode.Direct:
				tunnel = false;
				break;
			case ConnectionMode.Tunnel when !String.IsNullOrEmpty(response.ConnectionAddress):
				tunnel = true;
				break;
			case ConnectionMode.Relay when !String.IsNullOrEmpty(response.ConnectionAddress):
				// The relay address is what gets reported as the lease IP
				response.Ip = response.ConnectionAddress;
				return await ConnectWithRetryAsync(host, port, TimeSpan.FromSeconds(5), 3, cancellationToken);
			default:
				throw new ServerComputeClientException($"Unable to resolve connection mode ({response.ConnectionMode} via {response.ConnectionAddress ?? "none"})");
		}

		Socket socket = new(SocketType.Stream, ProtocolType.Tcp);
		try
		{
			await socket.ConnectAsync(host, port, cancellationToken);
			if (tunnel)
			{
				await TunnelHandshakeAsync(socket, response, cancellationToken);
			}
			return socket;
		}
		catch
		{
			socket.Dispose();
			throw;
		}
	}

	/// <summary>
	/// Connects to host:port and makes up to <paramref name="maxRetries"/> attempts. Each attempt is bounded by
	/// <paramref name="timeout"/>, and the method waits one second between attempts.
	/// </summary>
	/// <remarks>
	/// A fresh socket is created for every attempt. A socket whose pending connect was cancelled (timed out) or failed
	/// may not be usable for another connect.
	/// </remarks>
	/// <returns>The connected socket, owned by the caller</returns>
	/// <exception cref="TimeoutException">The final attempt timed out</exception>
	/// <exception cref="SocketException">The final attempt failed with a socket error</exception>
	/// <exception cref="OperationCanceledException">The caller's token was cancelled</exception>
	private async Task<Socket> ConnectWithRetryAsync(string host, int port, TimeSpan timeout, int maxRetries, CancellationToken cancellationToken)
	{
		TimeSpan retryDelay = TimeSpan.FromSeconds(1);
		for (int attempt = 1; ; attempt++)
		{
			Socket socket = new(SocketType.Stream, ProtocolType.Tcp);
			try
			{
				using CancellationTokenSource attemptCancellationSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
				attemptCancellationSource.CancelAfter(timeout);
				await socket.ConnectAsync(host, port, attemptCancellationSource.Token);
				return socket;
			}
			catch (OperationCanceledException) when (!cancellationToken.IsCancellationRequested)
			{
				// Only the per-attempt timeout fired; caller cancellation falls through to the rethrow below
				socket.Dispose();
				if (attempt >= maxRetries)
				{
					throw new TimeoutException($"Failed to connect {host}:{port} after {maxRetries} attempts");
				}
				_logger.LogInformation("Timed out connecting to {Host}:{Port} after {Timeout} ms. Waiting {RetryDelay} ms before retrying...",
					host, port, (int)timeout.TotalMilliseconds, (int)retryDelay.TotalMilliseconds);
			}
			catch (SocketException se)
			{
				socket.Dispose();
				if (attempt >= maxRetries)
				{
					throw;
				}
				_logger.LogInformation("Unable to connect to {Host}:{Port} within {Timeout} ms. Error code: {Error}. Waiting {RetryDelay} ms before retrying...",
					host, port, (int)timeout.TotalMilliseconds, se.SocketErrorCode, (int)retryDelay.TotalMilliseconds);
			}
			catch
			{
				socket.Dispose();
				throw;
			}

			await Task.Delay(retryDelay, cancellationToken);
		}
	}

	/// <summary>
	/// Reads the body of an error response for inclusion in an exception message. Returns "None" if the body cannot
	/// be read, unless the caller's token was cancelled.
	/// </summary>
	private static async Task<string> ReadErrorContentAsync(HttpResponseMessage httpResponse, CancellationToken cancellationToken)
	{
		try
		{
			return await httpResponse.Content.ReadAsStringAsync(cancellationToken);
		}
		catch (Exception) when (!cancellationToken.IsCancellationRequested)
		{
			return "None";
		}
	}

	/// <summary>
	/// Wraps the connected socket in the transport selected by <c>response.Encryption</c>: SSL (authenticated before
	/// returning), AES over TCP, or plain TCP for <c>None</c> and any other value.
	/// </summary>
	private static async Task CreateTransportAsync(Socket socket, AssignComputeResponse response, CancellationToken cancellationToken)
	{
		switch (response.Encryption)
		{
			case Encryption.Ssl:
			case Encryption.SslEcdsaP256:
				TcpSslTransport sslTransport = new(socket, StringUtils.ParseHexString(response.Certificate), false);
				await sslTransport.AuthenticateAsync(cancellationToken);
				return sslTransport;
			case Encryption.Aes:
#pragma warning disable CA2000 // Dispose objects before losing scope
				TcpTransport tcpTransport = new(socket);
				return new AesTransport(tcpTransport, StringUtils.ParseHexString(response.Key));
#pragma warning restore CA2000 // Restore CA2000
			case Encryption.None:
			default:
				return new TcpTransport(socket);
		}
	}

	/// <summary>
	/// Splits a "host:port" string on ':'. Addresses containing more than one ':' (such as IPv6 literals) are not
	/// handled: only the first two segments are used.
	/// </summary>
	/// <exception cref="Exception">The address has no port segment or the port is not an integer</exception>
	private static (string host, int port) ParseHostPort(string address)
	{
		try
		{
			string[] parts = address.Split(":");
			string host = parts[0];
			int port = Int32.Parse(parts[1]);
			return (host, port);
		}
		catch (Exception e)
		{
			throw new Exception($"Unable to parse host and port for address: {address}", e);
		}
	}

	/// <summary>
	/// Performs the tunnel server handshake on an already-connected socket. Sends a <see cref="TunnelHandshakeRequest"/>
	/// naming the agent's IP and port, then waits up to 15 seconds for a single <see cref="TunnelHandshakeResponse"/> line.
	/// </summary>
	/// <exception cref="TimeoutException">No response line arrived within 15 seconds</exception>
	/// <exception cref="OperationCanceledException">The caller's token was cancelled while waiting for the response</exception>
	/// <exception cref="Exception">The tunnel server reported failure</exception>
	private static async Task TunnelHandshakeAsync(Socket socket, AssignComputeResponse response, CancellationToken cancellationToken)
	{
		// The stream does not own the socket, so disposing the reader/writer leaves it open for the compute transport
		await using NetworkStream ns = new(socket, false);
		using StreamReader reader = new(ns);
		await using StreamWriter writer = new(ns) { AutoFlush = true };

		string request = new TunnelHandshakeRequest(response.Ip, response.Port).Serialize();
		await writer.WriteLineAsync(request.ToCharArray(), cancellationToken);

		string exceptionMetadata = $"Connection: {response.ConnectionAddress} Target: {response.Ip}:{response.Port}";

#pragma warning disable CA2016 // Forward the 'CancellationToken' parameter to methods
		Task readTask = reader.ReadLineAsync();
#pragma warning restore CA2016 // Forward the 'CancellationToken' parameter to methods

		// Linked to the caller's token, and cancelled once the wait is over so the delay timer does not outlive the handshake
		using CancellationTokenSource delayCancellationSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
		Task timeoutTask = Task.Delay(15000, delayCancellationSource.Token);
		Task completedTask = await Task.WhenAny(readTask, timeoutTask);
		delayCancellationSource.Cancel();

		if (completedTask == timeoutTask)
		{
			// The delay also completes when the caller cancels; surface that as cancellation rather than a timeout
			cancellationToken.ThrowIfCancellationRequested();
			throw new TimeoutException($"Timed out reading tunnel handshake response. {exceptionMetadata}");
		}

		TunnelHandshakeResponse handshakeResponse = TunnelHandshakeResponse.Deserialize(await readTask);
		if (!handshakeResponse.IsSuccess)
		{
			throw new Exception($"Tunnel handshake failed! Reason: {handshakeResponse.Message} {exceptionMetadata}");
		}
	}
}
/// <summary>
/// Exception indicating that no matching compute agents were found
/// </summary>
public sealed class NoComputeAgentsFoundException : ComputeClientException
{
	/// <summary>
	/// The compute cluster that was searched
	/// </summary>
	public ClusterId ClusterId { get; }

	/// <summary>
	/// The agent requirements that no agent satisfied, if any were given
	/// </summary>
	public Requirements? Requirements { get; }

	/// <summary>
	/// Constructor
	/// </summary>
	/// <param name="clusterId">The compute cluster requested</param>
	/// <param name="requirements">Requested agent requirements</param>
	public NoComputeAgentsFoundException(ClusterId clusterId, Requirements? requirements)
		: base(BuildMessage(clusterId, requirements))
	{
		Requirements = requirements;
		ClusterId = clusterId;
	}

	// Formats the exception message from the requested cluster and requirements
	static string BuildMessage(ClusterId clusterId, Requirements? requirements)
		=> $"No compute agents found matching '{requirements}' in cluster '{clusterId}'";
}
}