// Copyright Epic Games, Inc. All Rights Reserved.

using System;
using System.Threading;
using System.Threading.Tasks;
using EpicGames.Horde;
using EpicGames.Horde.Jobs;
using EpicGames.Horde.Logs;
using EpicGames.Horde.Storage;
using EpicGames.Horde.Storage.Backends;
using EpicGames.Horde.Storage.Bundles;
using Grpc.Core;
using Horde.Common.Rpc;
using HordeCommon.Rpc.Messages;
using HordeCommon.Rpc.Tasks;
using JobDriver.Execution;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.VisualStudio.TestTools.UnitTesting;

namespace JobDriver.Tests
{
	/// <summary>
	/// Tests for step and job abort handling in <see cref="JobExecutor"/>.
	/// </summary>
	/// <remarks>
	/// NOTE(review): several calls in this copy of the file carry no generic type arguments
	/// (the parameterless <c>AddSingleton()</c> calls, <c>Configure(...)</c> and
	/// <c>Assert.ThrowsExceptionAsync(...)</c>), and none of those overloads exist without them.
	/// They look stripped rather than intentional; restore them from version control instead of guessing.
	/// </remarks>
	[TestClass]
	public class WorkerServiceTest
	{
		private readonly ServiceCollection _serviceCollection;

		private readonly JobId _jobId = JobId.Parse("65bd0655591b5d5d7d047b58");
		private readonly JobStepBatchId _batchId = new JobStepBatchId(0x1234);
		private readonly JobStepId _stepId1 = new JobStepId(1);
		private readonly JobStepId _stepId2 = new JobStepId(2);
		private readonly JobStepId _stepId3 = new JobStepId(3);
		private readonly LogId _logId = LogId.Parse("65bd0655591b5d5d7d047b00");

		/// <summary>
		/// <see cref="IServerLogger"/> that reports every level as enabled and discards all output.
		/// Scopes are delegated to <see cref="NullLogger"/>.
		/// </summary>
		private sealed class FakeServerLogger : IServerLogger
		{
			public IDisposable? BeginScope<TState>(TState state) where TState : notnull => NullLogger.Instance.BeginScope(state);

			public ValueTask DisposeAsync() => new ValueTask();

			public bool IsEnabled(LogLevel logLevel) => true;

			public void Log<TState>(LogLevel logLevel, EventId eventId, TState state, Exception? exception, Func<TState, Exception?, string> formatter) { }

			public Task StopAsync() => Task.CompletedTask;
		}

		/// <summary>
		/// Shared executor whose steps wait 1 ms and then report <see cref="JobStepOutcome.Success"/>.
		/// </summary>
		/// <remarks>
		/// Not referenced by the tests in this class; presumably used by other test classes.
		/// Declared <c>readonly</c> so no test can replace the shared instance for the rest of the run.
		/// </remarks>
		internal static readonly JobExecutor NullExecutor = new SimpleTestExecutor(async (step, logger, cancellationToken) =>
		{
			await Task.Delay(1, cancellationToken);
			return JobStepOutcome.Success;
		});

		/// <summary>
		/// Creates the service collection the tests build on: logging, the <c>AddHorde()</c> services,
		/// four singleton registrations, and settings whose executor is set to <c>TestExecutor.Name</c>.
		/// </summary>
		public WorkerServiceTest()
		{
			_serviceCollection = new ServiceCollection();
			_serviceCollection.AddLogging();
			_serviceCollection.AddHorde();
			_serviceCollection.AddSingleton();
			_serviceCollection.AddSingleton();
			_serviceCollection.AddSingleton();
			_serviceCollection.AddSingleton();

			_serviceCollection.Configure(settings =>
			{
				settings.Executor = TestExecutor.Name; // Not really used since the executor is overridden in the tests
			});
		}

		/// <summary>
		/// Checks how <c>ExecuteStepAsync</c> reacts to each of its two cancellation tokens.
		/// Cancelling the first token (<c>cancelSource</c>) mid-step is expected to throw.
		/// Cancelling only the step token (<c>stepCancelSource</c>) is expected to return
		/// <see cref="JobStepOutcome.Failure"/> with <see cref="JobStepState.Aborted"/>.
		/// </summary>
		[TestMethod]
		public async Task AbortExecuteStepTestAsync()
		{
			// Case 1: cancel the first token while the step is running; the call should throw.
			{
				using CancellationTokenSource cancelSource = new CancellationTokenSource();
				using CancellationTokenSource stepCancelSource = new CancellationTokenSource();
				using JobExecutor executor = new SimpleTestExecutor(async (stepResponse, logger, cancelToken) =>
				{
					cancelSource.CancelAfter(10);
					await Task.Delay(5000, cancelToken);
					return JobStepOutcome.Success;
				});

				await Assert.ThrowsExceptionAsync(() => executor.ExecuteStepAsync(null!, NullLogger.Instance, cancelSource.Token, stepCancelSource.Token));
			}

			// Case 2: cancel only the step token; the call should return an aborted failure instead of throwing.
			{
				using CancellationTokenSource cancelSource = new CancellationTokenSource();
				using CancellationTokenSource stepCancelSource = new CancellationTokenSource();
				using JobExecutor executor = new SimpleTestExecutor(async (stepResponse, logger, cancelToken) =>
				{
					stepCancelSource.CancelAfter(10);
					await Task.Delay(5000, cancelToken);
					return JobStepOutcome.Success;
				});

				(JobStepOutcome stepOutcome, JobStepState stepState) = await executor.ExecuteStepAsync(null!, NullLogger.Instance, cancelSource.Token, stepCancelSource.Token);
				Assert.AreEqual(JobStepOutcome.Failure, stepOutcome);
				Assert.AreEqual(JobStepState.Aborted, stepState);
			}
		}

		/// <summary>
		/// Intended to check that when the server requests an abort for the second of three steps,
		/// that step is reported as Failure/Aborted while the first and third complete successfully.
		/// </summary>
		/// <remarks>
		/// NOTE(review): <c>client</c> and <c>executeJobTask</c> are populated but never passed to
		/// <c>executor</c> or the service provider in this method. The assertions on
		/// <c>client.UpdateStepRequests</c> therefore cannot observe anything the executor does, which is
		/// presumably why the test is ignored. The third <c>RpcGetStepResponse</c> constructor argument
		/// is assumed to be the abort-requested flag; confirm against the stub.
		/// </remarks>
		[Ignore]
		[TestMethod]
		public async Task AbortExecuteJobTestAsync()
		{
			using CancellationTokenSource source = new CancellationTokenSource();
			CancellationToken token = source.Token;

			ExecuteJobTask executeJobTask = new ExecuteJobTask
			{
				JobId = _jobId.ToString(),
				BatchId = _batchId.ToString(),
				LogId = _logId.ToString(),
				JobName = "jobName1",
				JobOptions = new RpcJobOptions { Executor = SimpleTestExecutor.Name },
				AutoSdkWorkspace = new RpcAgentWorkspace(),
				Workspace = new RpcAgentWorkspace()
			};

			JobRpcClientStub client = new JobRpcClientStub(NullLogger.Instance);
			// await using FakeHordeRpcServer fakeServer = new();
			// await using ISession session = FakeServerSessionFactory.CreateSession(null!);

			client.BeginStepResponses.Enqueue(new RpcBeginStepResponse { Name = "stepName1", StepId = _stepId1.ToString() });
			client.BeginStepResponses.Enqueue(new RpcBeginStepResponse { Name = "stepName2", StepId = _stepId2.ToString() });
			client.BeginStepResponses.Enqueue(new RpcBeginStepResponse { Name = "stepName3", StepId = _stepId3.ToString() });

			// Step 2 is the one the server asks to abort.
			RpcGetStepRequest step2Req = new RpcGetStepRequest(_jobId, _batchId, _stepId2);
			RpcGetStepResponse step2Res = new RpcGetStepResponse(JobStepOutcome.Unspecified, JobStepState.Unspecified, true);
			client.GetStepResponses[step2Req] = step2Res;

			using SimpleTestExecutor executor = new SimpleTestExecutor(async (step, logger, cancelToken) =>
			{
				await Task.Delay(50, cancelToken);
				return JobStepOutcome.Success;
			});
			_serviceCollection.AddSingleton(x => new SimpleTestExecutorFactory(executor));
			await using ServiceProvider serviceProvider = _serviceCollection.BuildServiceProvider();

			// Poll for aborts every millisecond so step 2 sees the abort well within its 50 ms run.
			executor._stepAbortPollInterval = TimeSpan.FromMilliseconds(1);
			await executor.ExecuteAsync(NullLogger.Instance, token);

			Assert.AreEqual(3, client.UpdateStepRequests.Count);
			Assert.AreEqual(JobStepOutcome.Success, (JobStepOutcome)client.UpdateStepRequests[0].Outcome);
			Assert.AreEqual(JobStepState.Completed, (JobStepState)client.UpdateStepRequests[0].State);
			Assert.AreEqual(JobStepOutcome.Failure, (JobStepOutcome)client.UpdateStepRequests[1].Outcome);
			Assert.AreEqual(JobStepState.Aborted, (JobStepState)client.UpdateStepRequests[1].State);
			Assert.AreEqual(JobStepOutcome.Success, (JobStepOutcome)client.UpdateStepRequests[2].Outcome);
			Assert.AreEqual(JobStepState.Completed, (JobStepState)client.UpdateStepRequests[2].State);
		}

		/// <summary>
		/// Intended to check that <c>PollForStepAbortAsync</c> keeps polling after GetStep throws an
		/// <see cref="RpcException"/>, and cancels the step once GetStep reports an abort request.
		/// </summary>
		/// <remarks>
		/// NOTE(review): the stub's <c>_getStepFunc</c> is configured, but <c>null!</c> is passed as the
		/// first argument to <c>PollForStepAbortAsync</c> and the stub is not referenced again. Confirm how
		/// (or whether) the stub is meant to reach the poller.
		/// </remarks>
		[Ignore]
		[TestMethod]
		public async Task PollForStepAbortFailureTestAsync()
		{
			using JobExecutor executor = new SimpleTestExecutor(async (step, logger, cancelToken) =>
			{
				await Task.Delay(50, cancelToken);
				return JobStepOutcome.Success;
			});
			_serviceCollection.AddSingleton(x => new SimpleTestExecutorFactory(executor));
			await using ServiceProvider serviceProvider = _serviceCollection.BuildServiceProvider();
			// JobHandler jobHandler = serviceProvider.GetRequiredService();
			executor._stepAbortPollInterval = TimeSpan.FromMilliseconds(5);

			JobRpcClientStub client = new JobRpcClientStub(NullLogger.Instance);
			int getStepCalls = 0;
			client._getStepFunc = (request) =>
			{
				// Poll 1: no abort. Poll 2: transient RPC failure. Poll 3: abort requested. Any further poll is a test bug.
				return ++getStepCalls switch
				{
					1 => new RpcGetStepResponse { AbortRequested = false },
					2 => throw new RpcException(new Status(StatusCode.Cancelled, "Fake cancel from test")),
					3 => new RpcGetStepResponse { AbortRequested = true },
					_ => throw new InvalidOperationException("Should never reach here")
				};
			};

			using CancellationTokenSource stepPollCancelSource = new CancellationTokenSource();
			using CancellationTokenSource stepCancelSource = new CancellationTokenSource();
			TaskCompletionSource stepFinishedSource = new TaskCompletionSource();
			await executor.PollForStepAbortAsync(null!, _jobId, _batchId, _stepId2, stepCancelSource, stepFinishedSource.Task, NullLogger.Instance, stepPollCancelSource.Token);

			Assert.IsTrue(stepCancelSource.IsCancellationRequested);
		}
	}

	/*
	internal class FakeServerSessionFactory : ISessionFactory
	{
		readonly FakeHordeRpcServer _fakeServer;

		public FakeServerSessionFactory(FakeHordeRpcServer fakeServer) => _fakeServer = fakeServer;

		public Task CreateAsync(CancellationToken cancellationToken)
		{
			return Task.FromResult(CreateSession(_fakeServer.GetHordeClient()));
		}

		public static ISession CreateSession(IHordeClient hordeClient)
		{
			Mock fakeSession = new Mock(MockBehavior.Strict);
			fakeSession.Setup(x => x.HordeClient).Returns(hordeClient);
			fakeSession.Setup(x => x.AgentId).Returns(new EpicGames.Horde.Agents.AgentId("LocalAgent"));
			fakeSession.Setup(x => x.SessionId).Returns(new EpicGames.Horde.Agents.Sessions.SessionId(default));
			fakeSession.Setup(x => x.DisposeAsync()).Returns(new ValueTask());
			fakeSession.Setup(x => x.WorkingDir).Returns(DirectoryReference.Combine(DirectoryReference.GetCurrentDirectory(), Guid.NewGuid().ToString()));
			return fakeSession.Object;
		}
	}
	*/

	// NOTE(review): everything up to #endif is excluded from compilation. It also still refers to
	// _hordeClient / FakeHordeClient, which are commented out inside it, so it would not build if re-enabled as-is.
#if false
	/// <summary>
	/// Fake implementation of a HordeRpc gRPC server.
	/// Provides a corresponding gRPC client class that can be used with the WorkerService
	/// to test client-server interactions.
	/// </summary>
	internal class FakeHordeRpcServer : IAsyncDisposable
	{
		private readonly string _serverName;
		private readonly bool _isStopping = false;
		private readonly Dictionary _leases = new();
		private readonly Dictionary _streamIdToStreamResponse = new();
		private readonly Dictionary _jobIdToJobResponse = new();
		private readonly ILogger _logger;
		public readonly TaskCompletionSource CreateSessionReceived = new();
		public readonly TaskCompletionSource UpdateSessionReceived = new();
		// private readonly FakeHordeClient _hordeClient;

		/*
		private class FakeHordeClient : IHordeClient
		{
			readonly FakeHordeRpcServer _server;

			public Uri ServerUrl => new Uri("http://horde-server");

			public FakeHordeClient(FakeHordeRpcServer server) => _server = server;

			public Task LoginAsync(bool allowLogin, CancellationToken cancellationToken) => throw new NotImplementedException();
			public HordeHttpClient CreateHttpClient() => throw new NotImplementedException();
			public IComputeClient CreateComputeClient() => throw new NotImplementedException();
			public IStorageClient CreateStorageClient(string relativePath, string? accessToken = null) => throw new NotImplementedException();
			public ValueTask DisposeAsync() => default;
			public Task GetAccessTokenAsync(bool interactive, CancellationToken cancellationToken = default) => Task.FromResult(null);
			public Task CreateGrpcChannelAsync(CancellationToken cancellationToken = default) => throw new NotImplementedException();

			public Task CreateGrpcClientAsync(CancellationToken cancellationToken = default) where TClient : ClientBase
			{
				if (typeof(TClient) == typeof(HordeRpc.HordeRpcClient))
				{
					return Task.FromResult((TClient)(object)new FakeHordeRpcClient(_server));
				}
				if (typeof(TClient) == typeof(JobRpc.JobRpcClient))
				{
					return Task.FromResult((TClient)(object)new FakeJobRpcClient(_server));
				}
				throw new NotImplementedException();
			}

			public bool HasValidAccessToken() => throw new NotImplementedException();
			public IServerLogger CreateServerLogger(LogId logId, LogLevel minimumLevel = LogLevel.Information) => throw new NotImplementedException();
		}

		private class FakeHordeRpcClient : HordeRpc.HordeRpcClient
		{
			private readonly FakeHordeRpcServer _outer;

			public FakeHordeRpcClient(FakeHordeRpcServer outer)
			{
				_outer = outer;
			}

			public override AsyncDuplexStreamingCall UpdateSession(Metadata headers = null!, DateTime? deadline = null, CancellationToken cancellationToken = default)
			{
				return _outer.GetUpdateSessionCall(CancellationToken.None);
			}
		}
		*/

		private class FakeJobRpcClient : JobRpc.JobRpcClient
		{
			private readonly FakeHordeRpcServer _outer;

			public FakeJobRpcClient(FakeHordeRpcServer outer)
			{
				_outer = outer;
			}

			public override AsyncUnaryCall GetJobAsync(RpcGetJobRequest request, CallOptions options)
			{
				if (_outer._jobIdToJobResponse.TryGetValue(JobId.Parse(request.JobId), out RpcGetJobResponse? jobResponse))
				{
					return JobRpcClientStub.Wrap(jobResponse);
				}
				throw new RpcException(new Status(StatusCode.NotFound, $"Job ID {request.JobId} not found"));
			}
		}

		public FakeHordeRpcServer()
		{
			_serverName = "FakeServer";
			_logger = NullLogger.Instance;
			_hordeClient = new FakeHordeClient(this);
		}

		public void AddTestLease(string leaseId)
		{
			if (_leases.ContainsKey(leaseId))
			{
				throw new ArgumentException($"Lease ID {leaseId} already exists");
			}

			TestTask testTask = new();
			_leases[leaseId] = new RpcLease
			{
				Id = leaseId,
				State = RpcLeaseState.Pending,
				Payload = Any.Pack(testTask)
			};
		}

		public RpcLease GetLease(string leaseId)
		{
			return _leases[leaseId];
		}

		public void AddStream(StreamId streamId, string streamName)
		{
			if (_streamIdToStreamResponse.ContainsKey(streamId))
			{
				throw new Exception($"Stream ID {streamId} already added");
			}

			_streamIdToStreamResponse[streamId] = new RpcGetStreamResponse
			{
				Name = streamName,
			};
		}

		public void AddAgentType(StreamId streamId, string agentType)
		{
			if (!_streamIdToStreamResponse.TryGetValue(streamId, out RpcGetStreamResponse? streamResponse))
			{
				throw new Exception($"Stream ID {streamId} not found");
			}

			string tempDir = Path.Join(Path.GetTempPath(), $"horde-agent-type-{agentType}-" + Guid.NewGuid().ToString()[..8]);
			Directory.CreateDirectory(tempDir);

			streamResponse.AgentTypes[agentType] = new RpcGetAgentTypeResponse
			{
				TempStorageDir = tempDir
			};
		}

		public void AddJob(JobId jobId, StreamId streamId, int change, int preflightChange)
		{
			if (!_streamIdToStreamResponse.ContainsKey(streamId))
			{
				throw new Exception($"Stream ID {streamId} not found");
			}

			_jobIdToJobResponse[jobId] = new RpcGetJobResponse
			{
				StreamId = streamId.ToString(),
				Change = change,
				PreflightChange = preflightChange
			};
		}

		public IHordeClient GetHordeClient()
		{
			return _hordeClient;
		}

		public RpcCreateSessionResponse OnCreateSessionRequest(RpcCreateSessionRequest request)
		{
			CreateSessionReceived.TrySetResult(true);
			_logger.LogInformation("OnCreateSessionRequest: {AgentId} {Status}", request.Id, request.Status);
			RpcCreateSessionResponse response = new()
			{
				AgentId = "bogusAgentId",
				Token = "bogusToken",
				SessionId = "bogusSessionId",
				ExpiryTime = Timestamp.FromDateTime(DateTime.UtcNow.AddHours(3)),
			};
			return response;
		}

		public AsyncDuplexStreamingCall GetQueryServerStateCall(CancellationToken cancellationToken)
		{
			FakeAsyncStreamReader responseStream = new(cancellationToken);
			FakeClientStreamWriter requestStream = new(onComplete: () =>
			{
				responseStream.Complete();
				return Task.CompletedTask;
			});

			responseStream.Write(new RpcQueryServerStateResponse { Name = _serverName, Stopping = _isStopping });

			return new(
				requestStream,
				responseStream,
				Task.FromResult(new Metadata()),
				() => Status.DefaultSuccess,
				() => new Metadata(),
				() => { /*isDisposed = true;*/ });
		}

		public AsyncDuplexStreamingCall GetUpdateSessionCall(CancellationToken cancellationToken)
		{
			FakeAsyncStreamReader responseStream = new(cancellationToken);

			async Task OnRequest(RpcUpdateSessionRequest request)
			{
				UpdateSessionReceived.TrySetResult(true);
				foreach (RpcLease agentLease in request.Leases)
				{
					RpcLease serverLease = _leases[agentLease.Id];
					serverLease.State = agentLease.State;
					serverLease.Outcome = agentLease.Outcome;
					serverLease.Output = agentLease.Output;
				}

				_logger.LogInformation("OnUpdateSessionRequest: {AgentId} {SessionId} {Status}", request.AgentId, request.SessionId, request.Status);
				await Task.Delay(100, cancellationToken);
				RpcUpdateSessionResponse response = new()
				{
					ExpiryTime = Timestamp.FromDateTime(DateTime.UtcNow + TimeSpan.FromMinutes(120))
				};
				response.Leases.AddRange(_leases.Values.Where(x => x.State != RpcLeaseState.Completed));
				await responseStream.Write(response);
			}

			FakeClientStreamWriter requestStream = new(OnRequest, () =>
			{
				responseStream.Complete();
				return Task.CompletedTask;
			});

			return new(
				requestStream,
				responseStream,
				Task.FromResult(new Metadata()),
				() => Status.DefaultSuccess,
				() => new Metadata(),
				() => { });
		}

		public static AsyncUnaryCall CreateAsyncUnaryCall(TResponse response)
		{
			return new AsyncUnaryCall(
				Task.FromResult(response),
				Task.FromResult(new Metadata()),
				() => Status.DefaultSuccess,
				() => new Metadata(),
				() => { });
		}

		public async ValueTask DisposeAsync()
		{
			await _hordeClient.DisposeAsync();

			foreach (RpcGetStreamResponse stream in _streamIdToStreamResponse.Values)
			{
				foreach (RpcGetAgentTypeResponse agentType in stream.AgentTypes.Values)
				{
					if (Directory.Exists(agentType.TempStorageDir))
					{
						Directory.Delete(agentType.TempStorageDir, true);
					}
				}
			}
		}
	}

	/// <summary>
	/// Fake stream reader used for testing gRPC clients
	/// </summary>
	/// <typeparam name="T">Message type reader will handle</typeparam>
	internal class FakeAsyncStreamReader : IAsyncStreamReader where T : class
	{
		private readonly Channel _channel = System.Threading.Channels.Channel.CreateUnbounded();
		private T? _current;
		private readonly CancellationToken? _cancellationTokenOverride;

		public FakeAsyncStreamReader(CancellationToken? cancellationTokenOverride = null)
		{
			_cancellationTokenOverride = cancellationTokenOverride;
		}

		public Task Write(T message)
		{
			if (!_channel.Writer.TryWrite(message))
			{
				throw new InvalidOperationException("Unable to write message.");
			}

			return Task.CompletedTask;
		}

		public void Complete()
		{
			_channel.Writer.Complete();
		}

		/// <inheritdoc/>
		public async Task MoveNext(CancellationToken cancellationToken)
		{
			if (_cancellationTokenOverride != null)
			{
				cancellationToken = _cancellationTokenOverride.Value;
			}

			if (await _channel.Reader.WaitToReadAsync(cancellationToken))
			{
				if (_channel.Reader.TryRead(out T? message))
				{
					_current = message;
					return true;
				}
			}

			_current = null!;
			return false;
		}

		/// <inheritdoc/>
		public T Current
		{
			get
			{
				if (_current == null)
				{
					throw new InvalidOperationException("No current element is available.");
				}

				return _current;
			}
		}
	}

	/// <summary>
	/// Fake stream writer used for testing gRPC clients
	/// </summary>
	/// <typeparam name="T">Message type writer will handle</typeparam>
	internal class FakeClientStreamWriter : IClientStreamWriter where T : class
	{
		private readonly Func? _onWrite;
		private readonly Func? _onComplete;
		private bool _isCompleted;

		public FakeClientStreamWriter(Func? onWrite = null, Func? onComplete = null)
		{
			_onWrite = onWrite;
			_onComplete = onComplete;
		}

		/// <inheritdoc/>
		public async Task WriteAsync(T message)
		{
			if (_isCompleted)
			{
				throw new InvalidOperationException("Stream is marked as complete");
			}

			if (_onWrite != null)
			{
				await _onWrite(message);
			}
		}

		/// <inheritdoc/>
		public WriteOptions? WriteOptions { get; set; }

		/// <inheritdoc/>
		public async Task CompleteAsync()
		{
			_isCompleted = true;
			if (_onComplete != null)
			{
				await _onComplete();
			}
		}
	}
#endif
}