Files
UnrealEngine/Engine/Source/Programs/Shared/EpicGames.Horde.Tests/AesTransportTests.cs
2025-05-18 13:04:45 +08:00

236 lines
7.3 KiB
C#

// Copyright Epic Games, Inc. All Rights Reserved.
using System;
using System.Buffers;
using System.Collections.Generic;
using System.IO;
using System.IO.Pipelines;
using System.Linq;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;
using EpicGames.Horde.Compute;
using EpicGames.Horde.Compute.Transports;
using EpicGames.Horde.Tests.Compute;
using Microsoft.VisualStudio.TestTools.UnitTesting;
namespace EpicGames.Horde.Tests;
[TestClass]
public sealed class AesTransportTests : IAsyncDisposable
{
private const int SmallBufferThreshold = 1000;
private const int MinChunkSizeForLargeBuffers = 10;
private readonly byte[] _key;
private readonly MemoryComputeTransport _memoryTransport;
private readonly AesTransport _aesTransport;
private readonly Random _random = new(42);
private readonly int[] _bufferSizes = [1, 2, 3, 10, 200000];
public AesTransportTests()
{
_key = AesTransport.CreateKey();
_memoryTransport = new MemoryComputeTransport();
_aesTransport = new AesTransport(_memoryTransport, _key);
}
public async ValueTask DisposeAsync()
{
await _memoryTransport.DisposeAsync();
await _aesTransport.DisposeAsync();
}
[TestMethod]
public async Task SendAndReceive_VariousBufferSizesOverTcp_SuccessfullyTransferredAsync()
{
using CancellationTokenSource cts = new(ComputeSocketTests.TestTimoutMs);
(Socket clientSocket, Socket serverSocket) = await ComputeSocketTests.CreateSocketsAsync(cts.Token);
await using TcpTransport clientTransport = new(clientSocket);
await using TcpTransport serverTransport = new(serverSocket);
await using AesTransport encryptedClient = new(clientTransport, _key);
await using AesTransport encryptedServer = new(serverTransport, _key);
foreach (int bufferSize in _bufferSizes)
{
await SendAndReceiveAsync(encryptedClient, encryptedServer, bufferSize, cts.Token);
}
}
[TestMethod]
public async Task SendAndReceive_VariousBufferSizesInMemory_SuccessfullyTransferredAsync()
{
using CancellationTokenSource cts = new(ComputeSocketTests.TestTimoutMs);
await using MemoryComputeTransport memTransport = new();
await using AesTransport encryptedClient = new(memTransport, _key);
await using AesTransport encryptedServer = new(memTransport, _key);
foreach (int bufferSize in _bufferSizes)
{
await SendAndReceiveAsync(encryptedClient, encryptedServer, bufferSize, cts.Token);
}
}
[TestMethod]
public async Task SendAndReceive_Torture_SuccessfullyTransferredAsync()
{
using CancellationTokenSource cts = new(ComputeSocketTests.TestTimoutMs);
await using MemoryComputeTransport memTransport = new();
await using AesTransport clientAes = new(memTransport, _key);
await using AesTransport serverAes = new(memTransport, _key);
foreach (int bufferSize in _bufferSizes)
{
List<int> evilChunkSizes = [1, 2, 3, bufferSize - 1, bufferSize, bufferSize + 1, bufferSize * 5];
evilChunkSizes = evilChunkSizes
.Where(size => size > 0)
.Where(size => bufferSize < SmallBufferThreshold || size >= MinChunkSizeForLargeBuffers)
.ToList();
foreach (int sendChunk in evilChunkSizes)
{
foreach (int recvChunk in evilChunkSizes)
{
await SendAndMultipleReceiveAsync(clientAes, serverAes, bufferSize, () => sendChunk, () => recvChunk, cts.Token);
}
}
}
}
private async Task SendAndReceiveAsync(AesTransport client, AesTransport server, int bufferSize, CancellationToken cancellationToken)
{
// Arrange
byte[] sentData = new byte[bufferSize];
_random.NextBytes(sentData);
byte[] serverReceivedData = new byte[sentData.Length];
byte[] clientReceivedData = new byte[sentData.Length];
// Act
await client.SendAsync(sentData, cancellationToken);
await server.SendAsync(sentData, cancellationToken);
int serverBytesRead = await server.RecvAsync(serverReceivedData, cancellationToken);
int clientBytesRead = await client.RecvAsync(clientReceivedData, cancellationToken);
// Assert
Assert.AreEqual(sentData.Length, serverBytesRead);
Assert.AreEqual(sentData.Length, clientBytesRead);
CollectionAssert.AreEqual(sentData, serverReceivedData);
CollectionAssert.AreEqual(sentData, clientReceivedData);
}
private async Task SendAndMultipleReceiveAsync(ComputeTransport client, ComputeTransport server, int bufferSize,
Func<int> getSendChunkSize, Func<int> getRecvChunkSize, CancellationToken cancellationToken)
{
// Arrange
byte[] testData = new byte[bufferSize];
_random.NextBytes(testData);
using MemoryStream received = new(testData.Length);
int bytesSent = 0;
// Act
Task sendTask = Task.Run(async () =>
{
while (bytesSent < testData.Length)
{
int bytesLeft = testData.Length - bytesSent;
int bytesToSend = Math.Min(getSendChunkSize(), bytesLeft);
await client.SendAsync(testData.AsMemory(bytesSent, bytesToSend), cancellationToken);
bytesSent += bytesToSend;
}
}, cancellationToken);
Task recvTask = Task.Run(async () =>
{
while (received.Length < testData.Length)
{
byte[] chunk = new byte[getRecvChunkSize()];
int bytesRead = await server.RecvAsync(chunk, cancellationToken);
await received.WriteAsync(chunk, 0, bytesRead, cancellationToken);
}
}, cancellationToken);
await Task.WhenAll(sendTask, recvTask);
// Assert
Assert.AreEqual(testData.Length, received.Length);
CollectionAssert.AreEqual(testData, received.ToArray());
}
[TestMethod]
public async Task Constructor_InvalidKeyLength_ThrowsArgumentExceptionAsync()
{
await Assert.ThrowsExceptionAsync<ArgumentException>(async () =>
{
byte[] invalidKey = new byte[16];
await using AesTransport _ = new (_memoryTransport, invalidKey);
});
}
[TestMethod]
public async Task MarkComplete_PropagatedToInnerTransportAsync()
{
// Act
await _aesTransport.MarkCompleteAsync(CancellationToken.None);
// Assert
Assert.IsTrue(_memoryTransport.IsCompleted);
}
}
// Helper class for testing
internal class MemoryComputeTransport : ComputeTransport
{
private readonly Pipe _pipe;
public bool IsCompleted { get; private set; }
public MemoryComputeTransport()
{
PipeOptions options = new(pauseWriterThreshold: 10 * 1024 * 1024); // Must fit buffer sizes being tested
_pipe = new Pipe(options);
}
/// <inheritdoc/>
public override async ValueTask<int> RecvAsync(Memory<byte> buffer, CancellationToken cancellationToken)
{
ReadResult result = await _pipe.Reader.ReadAsync(cancellationToken);
if (result.Buffer.IsEmpty && result.IsCompleted)
{
return 0;
}
int bytesToCopy = (int)Math.Min(buffer.Length, result.Buffer.Length);
result.Buffer.Slice(0, bytesToCopy).CopyTo(buffer.Span);
_pipe.Reader.AdvanceTo(result.Buffer.GetPosition(bytesToCopy));
return bytesToCopy;
}
/// <inheritdoc/>
public override async ValueTask SendAsync(ReadOnlySequence<byte> buffer, CancellationToken cancellationToken)
{
foreach (ReadOnlyMemory<byte> segment in buffer)
{
FlushResult result = await _pipe.Writer.WriteAsync(segment, cancellationToken);
if (result.IsCompleted)
{
break;
}
}
}
/// <inheritdoc/>
public override async ValueTask MarkCompleteAsync(CancellationToken cancellationToken)
{
IsCompleted = true;
await _pipe.Writer.CompleteAsync();
}
/// <inheritdoc/>
public override async ValueTask DisposeAsync()
{
await _pipe.Writer.CompleteAsync();
await _pipe.Reader.CompleteAsync();
GC.SuppressFinalize(this);
}
}