Files
UnrealEngine/Engine/Source/Developer/Zen/Private/ZenClient.cpp
2025-05-18 13:04:45 +08:00

914 lines
22 KiB
C++

// Copyright Epic Games, Inc. All Rights Reserved.
#include "ZenClient.h"
#include "Async/Async.h"
#include "Containers/StringView.h"
#include "HAL/Event.h"
#include "Math/UnrealMathUtility.h"
#include "Memory/MemoryView.h"
#include "Misc/StringBuilder.h"
#include "ProfilingDebugging/CountersTrace.h"
#include "ProfilingDebugging/CpuProfilerTrace.h"
#include "Serialization/LargeMemoryWriter.h"
#include "Serialization/MemoryReader.h"
#include "IPAddress.h"
#include "Sockets.h"
#include "SocketSubsystem.h"
// Log category used by every UE_LOG / UE_CLOG call in this file.
DEFINE_LOG_CATEGORY_STATIC(LogZenClient, Log, All);
// Trace counters; both are updated in FZenClient::DispatchMessage when a request receives its final response.
TRACE_DECLARE_INT_COUNTER(ZenClient_StreamedBytes, TEXT("ZenClient/StreamedBytes"));
TRACE_DECLARE_FLOAT_COUNTER(ZenClient_StreamedSeconds, TEXT("ZenClient/StreamedSeconds"));
namespace UE::Zen {
////////////////////////////////////////////////////////////////////////////////
/**
 * Writes the whole of Data to Socket, issuing as many Send calls as it takes.
 *
 * @param Socket Socket to write to.
 * @param Data   Bytes to transmit. A single Send call is capped at MAX_int32 bytes.
 * @return true once the view has been consumed completely, false as soon as a Send call reports failure.
 */
bool SendData(FSocket& Socket, FMemoryView Data)
{
	while (!Data.IsEmpty())
	{
		// FSocket::Send takes a 32-bit count, so larger views go out in several chunks.
		const uint64 ChunkSize = FMath::Clamp<uint64>(Data.GetSize(), 0, MAX_int32);
		const uint8* const ChunkBegin = (const uint8*)Data.GetData();

		int32 SentCount = 0;
		const bool bSendOk = Socket.Send(ChunkBegin, (int32)ChunkSize, SentCount);
		if (!bSendOk)
		{
			return false;
		}

		// Skip what was accepted; a short write simply causes another iteration.
		Data += SentCount;
	}
	return true;
}
/**
 * Builds the HTTP/1.1 request that asks the server to upgrade the connection on "/zen" to a websocket.
 *
 * @param Protocols  Optional sub-protocol names. When non-empty they are emitted as one
 *                   Sec-WebSocket-Protocol header holding a comma-separated list.
 * @param OutRequest Builder the request text is appended to, including the terminating blank line.
 */
void CreateHttpUpgradeRequest(const TArray<FString>& Protocols, FAnsiStringBuilderBase& OutRequest)
{
	const FAnsiStringView Endpoint = ANSITEXTVIEW("/zen");
	OutRequest << ANSITEXTVIEW("GET ") << Endpoint << ANSITEXTVIEW(" HTTP/1.1\r\n");
	OutRequest << ANSITEXTVIEW("Upgrade: websocket\r\n");
	OutRequest << ANSITEXTVIEW("Connection: upgrade\r\n");
	// NOTE(review): the key is a fixed value (base64 of "the sample nonce") rather than a random nonce,
	// and no Host / Sec-WebSocket-Version header is sent - confirm the server does not require them.
	OutRequest << ANSITEXTVIEW("Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n");
	if (Protocols.Num())
	{
		OutRequest << ANSITEXTVIEW("Sec-WebSocket-Protocol: ");
		for (int32 Idx = 0, Count = Protocols.Num(); Idx < Count; ++Idx)
		{
			if (Idx > 0)
			{
				// The header value is a comma-separated list; a bare space would make the
				// whole value a single (invalid) token.
				OutRequest << ANSITEXTVIEW(", ");
			}
			FTCHARToUTF8 Utf8(*Protocols[Idx]);
			OutRequest << FAnsiStringView(Utf8.Get(), Utf8.Length());
		}
		OutRequest << ANSITEXTVIEW("\r\n");
	}
	// Blank line terminates the header section.
	OutRequest << ANSITEXTVIEW("\r\n");
}
/**
 * Reads the HTTP upgrade response header from Socket one line at a time.
 *
 * Each line is first peeked so that only the bytes up to and including its "\r\n" are consumed;
 * anything the server sends after the header stays in the socket for the message reader.
 *
 * @param Socket   Connected socket to read from.
 * @param Response Receives every non-empty header line, each terminated by a single '\n'.
 * @return true when the blank line ending the header was consumed; false when a receive fails,
 *         a line is malformed, or a line does not fit in the 1024 byte line buffer.
 */
bool ReceieveHttpUpgradeResponse(FSocket& Socket, FAnsiStringBuilderBase& Response)
{
	static constexpr int32 MaxLineSize = 1024;
	uint8 Buffer[MaxLineSize];
	for (;;)
	{
		int32 BytesRead = 0;
		// A failed peek (closed connection, socket error) must end the handshake; ignoring it
		// would spin in this loop forever with BytesRead == 0.
		if (Socket.Recv(Buffer, MaxLineSize, BytesRead, ESocketReceiveFlags::Peek) == false)
		{
			return false;
		}
		const FAnsiStringView BufferView(reinterpret_cast<const ANSICHAR*>(Buffer), BytesRead);
		int32 LineEndIndex = 0;
		if (BufferView.FindChar('\n', LineEndIndex) == false)
		{
			if (BytesRead >= MaxLineSize)
			{
				// The buffer is full and still holds no line end: peeking again can never succeed.
				return false;
			}
			// Only part of the line has arrived so far; peek again.
			continue;
		}
		// The response comes from the network, so reject malformed line endings instead of asserting.
		if (LineEndIndex == 0 || BufferView[LineEndIndex - 1] != '\r')
		{
			return false;
		}
		// Consume exactly the line that was peeked, including its "\r\n".
		const int32 LineSize = LineEndIndex + 1;
		if (Socket.Recv(Buffer, LineSize, BytesRead, ESocketReceiveFlags::None) == false || BytesRead != LineSize)
		{
			return false;
		}
		const FAnsiStringView Line = BufferView.Left(LineEndIndex - 1);
		if (Line.Len() == 0)
		{
			// Blank line: end of the HTTP header.
			return true;
		}
		Response << Line << ANSITEXTVIEW("\n");
	}
}
/**
 * Checks that the upgrade response starts with the expected "101 Switching Protocols" status line.
 *
 * @param UpgradeResponse Response header text with '\n' separated lines.
 * @param OutReason       Set to a description of the problem when validation fails.
 * @return true when the status line matches exactly, false otherwise (including when there is no
 *         status line at all).
 */
bool ValidateHandshake(FAnsiStringView UpgradeResponse, FString& OutReason)
{
	int32 NewLineIndex = 0;
	if (UpgradeResponse.FindChar('\n', NewLineIndex) == false)
	{
		// Without a complete first line there is nothing to validate; treat it as a failed handshake
		// rather than silently accepting it.
		OutReason = TEXT("missing HTTP status line");
		return false;
	}
	const FAnsiStringView StatusLine = UpgradeResponse.Left(NewLineIndex);
	if (StatusLine != ANSITEXTVIEW("HTTP/1.1 101 Switching Protocols"))
	{
		FUTF8ToTCHAR Str(StatusLine.GetData(), StatusLine.Len());
		// Build the string from pointer + length so this does not rely on the conversion buffer
		// being NUL-terminated.
		OutReason = FString(FStringView(Str.Get(), Str.Length()));
		return false;
	}
	//TODO: Validate accept hash and protocols
	return true;
}
////////////////////////////////////////////////////////////////////////////////
/** Kind of a FZenWebSocketMessage; stored as a single byte in the message header. */
enum class EZenWebSocketMessageType : uint8
{
	Invalid                = 0, // Default of a fresh header; rejected by FZenWebSocketMessage::TryLoadHeader.
	Notification           = 1,
	Request                = 2, // Used by FZenClient::SendRequest.
	StreamRequest          = 3, // Used by FZenClient::SendStreamRequest.
	Response               = 4, // Dispatched as the final response of a request.
	StreamResponse         = 5, // Dispatched as a non-final response.
	StreamCompleteResponse = 6, // Dispatched as the final response of a request.
	Count                  = 7  // Number of values above; exclusive upper bound when validating a header.
};
/**
 * One message of the Zen websocket protocol: a fixed 24 byte header followed by a
 * compact binary package (FCbPackage) body.
 */
class FZenWebSocketMessage
{
	/** Header layout exactly as it is serialized; the size is pinned by the static_assert below. */
	struct FHeader
	{
		static constexpr uint32_t ExpectedMagic = 0x7a776d68;

		uint64 MessageSize = 0;			// Size of the body in bytes, header excluded.
		uint32 Magic = ExpectedMagic;
		uint32 CorrelationId = 0;		// Pairs a response with the request that caused it.
		uint32 StatusCode = 200;
		EZenWebSocketMessageType MessageType = EZenWebSocketMessageType::Invalid;
		uint8 Reserved[3] = {0};
	};
	static_assert(sizeof(FHeader) == 24);

public:
	/** Number of bytes the header occupies in front of the body. */
	static constexpr uint64 HeaderSize = sizeof(FHeader);

	FZenWebSocketMessage() = default;

	// Header accessors.
	EZenWebSocketMessageType MessageType() const
	{
		return MessageHeader.MessageType;
	}
	void SetMessageType(EZenWebSocketMessageType InType)
	{
		MessageHeader.MessageType = InType;
	}
	uint64 MessageSize() const
	{
		return MessageHeader.MessageSize;
	}
	/** Header plus body size in bytes. */
	uint64 TotalSize() const
	{
		return HeaderSize + MessageHeader.MessageSize;
	}
	uint32 CorrelationId() const
	{
		return MessageHeader.CorrelationId;
	}
	void SetCorrelationId(uint32 Id)
	{
		MessageHeader.CorrelationId = Id;
	}

	// Body accessors.
	FCbPackage& Body()
	{
		return MessageBody;
	}
	/** Hands the body out as an rvalue so the caller can move it away. */
	FCbPackage&& ConsumeBody()
	{
		return MoveTemp(MessageBody);
	}
	void SetBody(FCbPackage&& Body)
	{
		MessageBody = MoveTemp(Body);
	}

	/** Copies and validates a header from HeaderView; see the definition for the checks performed. */
	bool TryLoadHeader(FMemoryView HeaderView);

	/** Writes header and body to Ar. */
	void Save(FArchive& Ar);

	/** Returns a correlation id taken from a counter shared by all messages in the process. */
	static uint32 GenerateCorrelationId()
	{
		return NextCorrelationId.IncrementExchange();
	}

private:
	static TAtomic<uint32> NextCorrelationId;

	FHeader MessageHeader;
	FCbPackage MessageBody;
};
/**
 * Copies the first HeaderSize bytes of HeaderView into this message's header and validates them.
 *
 * The header is overwritten before it is validated, so it holds the received bytes even when
 * this returns false.
 *
 * @param HeaderView View that must hold at least HeaderSize bytes.
 * @return true when the magic matches, the status code is non-zero and the message type lies
 *         strictly between Invalid and Count.
 */
bool FZenWebSocketMessage::TryLoadHeader(FMemoryView HeaderView)
{
	if (HeaderView.GetSize() < HeaderSize)
	{
		return false;
	}

	FMemory::Memcpy(&MessageHeader, HeaderView.GetData(), HeaderSize);

	if (MessageHeader.Magic != FHeader::ExpectedMagic)
	{
		return false;
	}
	// StatusCode is unsigned, so "greater than zero" is the same as "not zero".
	if (MessageHeader.StatusCode == 0)
	{
		return false;
	}

	const uint8 TypeValue = static_cast<uint8>(MessageHeader.MessageType);
	const uint8 InvalidValue = static_cast<uint8>(EZenWebSocketMessageType::Invalid);
	const uint8 CountValue = static_cast<uint8>(EZenWebSocketMessageType::Count);
	return InvalidValue < TypeValue && TypeValue < CountValue;
}
/**
 * Serializes the message (header followed by body) to Ar at the archive's current position.
 *
 * The body size is only known after the body has been written, so a placeholder header goes
 * out first and is overwritten once MessageSize can be filled in. The archive must therefore
 * support Tell/Seek. On return the archive is positioned directly after the body.
 *
 * Side effects: updates MessageSize and forces StatusCode to 200 in this message's header.
 */
void FZenWebSocketMessage::Save(FArchive& Ar)
{
	// Remember where this message starts so the header is patched in place even when the
	// archive already holds data in front of it.
	const int64 Start = Ar.Tell();
	Ar.Serialize(&MessageHeader, FZenWebSocketMessage::HeaderSize);
	MessageBody.Save(Ar);
	const int64 End = Ar.Tell();
	MessageHeader.MessageSize = End - Start - int64(FZenWebSocketMessage::HeaderSize);
	MessageHeader.StatusCode = 200;
	Ar.Seek(Start);
	Ar.Serialize(&MessageHeader, FZenWebSocketMessage::HeaderSize);
	Ar.Seek(End);
}
// Counter behind FZenWebSocketMessage::GenerateCorrelationId; shared by all messages and initialised to 1.
TAtomic<uint32> FZenWebSocketMessage::NextCorrelationId{1};
////////////////////////////////////////////////////////////////////////////////
/** Outcome of one FZenMessageParser::ParseMessage call. */
enum class EParseMessageStatus : uint32
{
	Error    = 0, // Header or body was invalid; the reason is reported through the OutReason argument.
	Continue = 1, // The message is still incomplete; feed more data.
	Done     = 2  // A complete message is available through ConsumeMessage().
};
/**
 * Incremental parser that assembles one FZenWebSocketMessage from chunks of received bytes.
 * After a message has been consumed, Reset() must be called before parsing the next one.
 */
class FZenMessageParser
{
public:
	FZenMessageParser() = default;

	/**
	 * Feeds received bytes to the parser.
	 *
	 * @param MessageData    Bytes to parse.
	 * @param OutBytesParsed Number of bytes of MessageData this call consumed.
	 * @param OutReason      Set to a description of the problem when Error is returned.
	 */
	EParseMessageStatus ParseMessage(FMemoryView MessageData, uint64& OutBytesParsed, FString& OutReason);

	/** Rewinds the byte buffer and replaces the message with a default-constructed one. */
	void Reset();

	/** Hands the parsed message out as an rvalue so the caller can move it away. */
	FZenWebSocketMessage&& ConsumeMessage()
	{
		return MoveTemp(Message);
	}

private:
	FLargeMemoryWriter Ar;			// Raw bytes (header + body) of the message being assembled.
	FZenWebSocketMessage Message;	// Header and body decoded from Ar.
};
/** Prepares the parser for the next message: drops the decoded message and rewinds the byte buffer. */
void FZenMessageParser::Reset()
{
	// The two steps do not depend on each other.
	Message = FZenWebSocketMessage();
	Ar.Seek(0);
}
/**
 * Consumes bytes from MessageData until the current message is complete.
 *
 * Bytes are accumulated in the internal writer: first the fixed-size header, then as many body
 * bytes as the header announced. At most one message is consumed per call; surplus bytes are
 * left in MessageData and reported through OutBytesParsed so the caller can feed them again.
 *
 * @param MessageData    Newly received bytes.
 * @param OutBytesParsed Always set; the number of bytes of MessageData consumed by this call.
 * @param OutReason      Set to a description of the problem when Error is returned.
 * @return Continue while the message is incomplete, Done once header and body were loaded,
 *         Error when the header or the body package is invalid.
 */
EParseMessageStatus FZenMessageParser::ParseMessage(FMemoryView MessageData, uint64& OutBytesParsed, FString& OutReason)
{
	TRACE_CPUPROFILER_EVENT_SCOPE(FZenMessageParser::ParseMessage);
	// Give the out parameter a defined value on every path, not only the ones that write data.
	OutBytesParsed = 0;
	const uint64 PrevPos = Ar.Tell();
	if (Ar.Tell() < FZenWebSocketMessage::HeaderSize)
	{
		// Still collecting the header: take only as many bytes as are missing from it.
		FMemoryView HeaderView = MessageData.Left(FZenWebSocketMessage::HeaderSize - Ar.Tell());
		Ar.Serialize(const_cast<void*>(HeaderView.GetData()), HeaderView.GetSize());
		OutBytesParsed = Ar.Tell() - PrevPos;
		MessageData.RightChopInline(HeaderView.GetSize());
		if (Ar.Tell() < FZenWebSocketMessage::HeaderSize)
		{
			return EParseMessageStatus::Continue;
		}
		if (Message.TryLoadHeader(Ar.GetView()) == false)
		{
			OutReason = TEXT("Invalid websocket header");
			return EParseMessageStatus::Error;
		}
		if (Message.MessageSize() == 0)
		{
			// Header-only message; there is no body to wait for.
			return EParseMessageStatus::Done;
		}
		//Ar.Reserve(Message.MessageSize());
	}
	if (MessageData.IsEmpty() == false)
	{
		// Copy body bytes, but never more than this message still needs.
		const uint64 RemainingMessageSize = Message.MessageSize() - (Ar.Tell() - FZenWebSocketMessage::HeaderSize);
		FMemoryView BodyView = MessageData.Left(RemainingMessageSize);
		Ar.Serialize(const_cast<void*>(BodyView.GetData()), BodyView.GetSize());
		OutBytesParsed = Ar.Tell() - PrevPos;
	}
	EParseMessageStatus Status = EParseMessageStatus::Continue;
	if (Ar.Tell() == FZenWebSocketMessage::HeaderSize + Message.MessageSize())
	{
		Status = EParseMessageStatus::Done;
		FMemoryView BodyView = Ar.GetView();
		BodyView.RightChopInline(FZenWebSocketMessage::HeaderSize);
		// Reset() only seeks the writer back to 0, so its view may still span bytes of an earlier,
		// larger message. Bound the body to what this message's header announced.
		BodyView = BodyView.Left(Message.MessageSize());
		FMemoryReaderView BodyAr(BodyView);
		FCbPackage Body;
		{
			TRACE_CPUPROFILER_EVENT_SCOPE(CbPackage::TryLoad);
			if (Body.TryLoad(BodyAr))
			{
				Message.SetBody(MoveTemp(Body));
			}
			else
			{
				OutReason = TEXT("Invalid CbPackage");
				Status = EParseMessageStatus::Error;
			}
		}
	}
	return Status;
}
////////////////////////////////////////////////////////////////////////////////
class FZenClient final : public IZenClient
{
public:
FZenClient();
virtual ~FZenClient();
virtual bool Connect(FStringView Host, int32 Port) override;
virtual void Disconnect() override;
virtual bool IsConnected() const override { return bConnected; }
virtual bool SendRequest(FCbPackage&& Request, FOnStreamResponse&& OnResponse) override;
virtual bool SendRequest(FCbObject&& Request, FOnStreamResponse&& OnResponse) override;
virtual bool SendStreamRequest(FCbObject&& Request, FOnStreamResponse&& OnStreamResponse) override;
private:
bool SendRequestMessage(FZenWebSocketMessage&& Request, FOnStreamResponse&& OnResponse);
void ReadMessageThreadEntry();
bool HandleMessage(FMemoryView MessageData, FString& OutReason);
void DispatchMessage(FZenWebSocketMessage&& Message);
void CloseConnection(FStringView Reason);
struct FPendingRequest
{
FPendingRequest(FOnStreamResponse&& Callback)
: OnStreamResponse(MoveTemp(Callback))
, StartTime(FPlatformTime::Seconds()) {}
FPendingRequest(const FPendingRequest&) = delete;
FPendingRequest& operator=(const FPendingRequest&) = delete;
FOnStreamResponse OnStreamResponse;
FGraphEventRef LastCompletionEvent;
const double StartTime;
};
struct FDispatchStreamResponseTask
{
FDispatchStreamResponseTask(TSharedPtr<FPendingRequest> InRequest, FCbPackage&& InResponse, bool bInComplete)
: Request(InRequest), Response(MoveTemp(InResponse)), bComplete(bInComplete)
{
check(InRequest);
}
FORCEINLINE TStatId GetStatId() const { return TStatId(); }
ENamedThreads::Type GetDesiredThread() { return ENamedThreads::AnyBackgroundHiPriTask; }
static ESubsequentsMode::Type GetSubsequentsMode() { return ESubsequentsMode::TrackSubsequents; }
FORCEINLINE void DoTask(ENamedThreads::Type CurrentThread, const FGraphEventRef& CompletionGraphEvent)
{
Request->OnStreamResponse({ MoveTemp(Response), bComplete ? EZenStreamStatus::Completed : EZenStreamStatus::Ok });
}
TSharedPtr<FPendingRequest> Request;
FCbPackage Response;
bool bComplete;
};
using FPendingRequests = TMap<uint32_t, TSharedPtr<FPendingRequest>>;
TUniquePtr<FSocket> Socket;
TFuture<void> ReadMessageThread;
TSharedPtr<FInternetAddr> HostAddr;
TUniquePtr<FZenMessageParser> Parser;
FPendingRequests PendingRequests;
FCriticalSection RequestsCriticalSection;
TArray<uint8> ReadBuffer;
TAtomic<bool> bConnected{false};
TAtomic<bool> bStopRequested{false};
};
/** Allocates the message parser and an initial 256 KiB receive buffer. */
FZenClient::FZenClient()
{
	constexpr int32 InitialReadBufferSize = 256 << 10;

	ReadBuffer.SetNum(InitialReadBufferSize);
	Parser = MakeUnique<FZenMessageParser>();
}
/** Closes the connection, if there is one, before the members are destroyed. */
FZenClient::~FZenClient()
{
	// The class is final, so the qualified call selects the same function as the virtual call.
	FZenClient::Disconnect();
}
/**
 * Connects to the server, performs the HTTP websocket upgrade handshake and starts the
 * thread that reads responses.
 *
 * @param Host Address string handed to the socket subsystem's GetAddressFromString.
 * @param Port Port to connect to.
 * @return true when the handshake succeeded and the read thread was started. Every failure is
 *         logged and returns false.
 *
 * NOTE(review): calling this on a client that is still connected is not guarded against here.
 */
bool FZenClient::Connect(FStringView Host, int32 Port)
{
	ISocketSubsystem& SocketSubsystem = *ISocketSubsystem::Get();
	TSharedPtr<FInternetAddr> Addr = SocketSubsystem.GetAddressFromString(FString(Host));
	if (Addr.IsValid() == false)
	{
		// Without this check an address string that cannot be parsed crashes on the dereference below.
		UE_LOG(LogZenClient, Error, TEXT("connect to server FAILED, reason 'invalid host address'"));
		return false;
	}
	Addr->SetPort(Port);
	Socket.Reset(SocketSubsystem.CreateSocket(NAME_Stream, TEXT("ZenClient"), Addr->GetProtocolType()));
	if (Socket.IsValid() == false)
	{
		UE_LOG(LogZenClient, Error, TEXT("connect to server FAILED, reason 'failed to create socket'"));
		return false;
	}
	int32 NewReadBufferSize = 0;
	// Only adopt the size reported back if it is usable; an empty read buffer could never receive anything.
	if (Socket->SetReceiveBufferSize(ReadBuffer.Num(), NewReadBufferSize) && NewReadBufferSize > 0)
	{
		UE_LOG(LogZenClient, Log, TEXT("read buffer size set to %dB"), NewReadBufferSize);
		ReadBuffer.SetNum(NewReadBufferSize);
	}
	//Socket->SetLoopBackFastPath(true);
	Socket->SetNoDelay(true);
	if (Socket->Connect(*Addr) == false)
	{
		UE_LOG(LogZenClient, Error, TEXT("connect to server FAILED"));
		return false;
	}
	HostAddr = Addr;
	TArray<FString> Protocols;
	TAnsiStringBuilder<1024> UpgradeRequest;
	CreateHttpUpgradeRequest(Protocols, UpgradeRequest);
	if (SendData(*Socket, FMemoryView(UpgradeRequest.GetData(), UpgradeRequest.Len())) == false)
	{
		UE_LOG(LogZenClient, Error, TEXT("connect to server FAILED, reason 'failed to send handshake request'"));
		return false;
	}
	TAnsiStringBuilder<1024> UpgradeResponse;
	if (ReceieveHttpUpgradeResponse(*Socket, UpgradeResponse) == false)
	{
		UE_LOG(LogZenClient, Error, TEXT("connect to server FAILED, reason 'failed to receive handshake response'"));
		return false;
	}
	FString Reason;
	if (ValidateHandshake(UpgradeResponse.ToView(), Reason) == false)
	{
		UE_LOG(LogZenClient, Error, TEXT("connect to server FAILED, reason '%s'"), *Reason);
		return false;
	}
	// A previous connection leaves bStopRequested set and possibly a half-parsed message behind.
	// Clear both, otherwise the read thread of a reconnect would exit immediately or mis-parse.
	bStopRequested = false;
	Parser->Reset();
	const uint32 StackSize = 0;
	ReadMessageThread = AsyncThread([this] { return ReadMessageThreadEntry(); }, StackSize, TPri_AboveNormal);
	bConnected = true;
	return true;
}
void FZenClient::Disconnect()
{
CloseConnection(TEXTVIEW(""));
}
/**
 * Serializes RequestMessage, registers OnResponse under the message's correlation id and
 * writes the message to the socket.
 *
 * The callback is registered before the bytes go out so that a fast response already finds it.
 *
 * @param RequestMessage Message to send; its correlation id must already be set.
 * @param OnResponse     Callback for the response(s) to this request.
 * @return true when the whole message was written. On false the callback has been dropped
 *         again and will never be invoked.
 */
bool FZenClient::SendRequestMessage(FZenWebSocketMessage&& RequestMessage, FOnStreamResponse&& OnResponse)
{
	if (Socket.IsValid() == false)
	{
		// Never connected, or already disconnected: fail instead of dereferencing a null socket.
		return false;
	}
	FLargeMemoryWriter Ar;
	RequestMessage.Save(Ar);
	const uint32 CorrelationId = RequestMessage.CorrelationId();
	TSharedPtr<FPendingRequest> PendingRequest = MakeShared<FPendingRequest>(MoveTemp(OnResponse));
	{
		FScopeLock _(&RequestsCriticalSection);
		PendingRequests.Emplace(CorrelationId, MoveTemp(PendingRequest));
	}
	if (SendData(*Socket, Ar.GetView()))
	{
		return true;
	}
	// The request never reached the server, so no response will arrive for it. Unregister it,
	// otherwise the entry would stay in PendingRequests forever.
	{
		FScopeLock _(&RequestsCriticalSection);
		PendingRequests.Remove(CorrelationId);
	}
	return false;
}
/**
 * Sends Request as a message of type Request with a freshly generated correlation id.
 *
 * @param Request    Package that becomes the message body.
 * @param OnResponse Callback for the response.
 * @return Result of SendRequestMessage.
 */
bool FZenClient::SendRequest(FCbPackage&& Request, FOnStreamResponse&& OnResponse)
{
	const uint32 NewCorrelationId = FZenWebSocketMessage::GenerateCorrelationId();

	FZenWebSocketMessage Outgoing;
	Outgoing.SetMessageType(EZenWebSocketMessageType::Request);
	Outgoing.SetCorrelationId(NewCorrelationId);
	Outgoing.SetBody(MoveTemp(Request));

	return SendRequestMessage(MoveTemp(Outgoing), MoveTemp(OnResponse));
}
/** Convenience overload: wraps Request in a package and forwards to the package overload. */
bool FZenClient::SendRequest(FCbObject&& Request, FOnStreamResponse&& OnResponse)
{
	FCbPackage WrappedRequest;
	WrappedRequest.SetObject(MoveTemp(Request));

	return SendRequest(MoveTemp(WrappedRequest), MoveTemp(OnResponse));
}
/**
 * Sends Request as a message of type StreamRequest with a freshly generated correlation id.
 *
 * @param Request          Object that becomes the body of the message.
 * @param OnStreamResponse Callback for the responses.
 * @return Result of SendRequestMessage.
 */
bool FZenClient::SendStreamRequest(FCbObject&& Request, FOnStreamResponse&& OnStreamResponse)
{
	const uint32 NewCorrelationId = FZenWebSocketMessage::GenerateCorrelationId();

	FZenWebSocketMessage Outgoing;
	Outgoing.SetMessageType(EZenWebSocketMessageType::StreamRequest);
	Outgoing.SetCorrelationId(NewCorrelationId);
	Outgoing.SetBody(FCbPackage(MoveTemp(Request)));

	return SendRequestMessage(MoveTemp(Outgoing), MoveTemp(OnStreamResponse));
}
void FZenClient::ReadMessageThreadEntry()
{
FString Reason;
while(bStopRequested == false)
{
int32 BytesRead = 0;
{
if (Socket->Recv(ReadBuffer.GetData(), ReadBuffer.Num(), BytesRead) == false)
{
Reason = TEXT("receive message FAILED");
break;
}
}
if (HandleMessage(FMemoryView(ReadBuffer.GetData(), BytesRead), Reason) == false)
{
break;
}
}
if (bStopRequested == false)
{
CloseConnection(Reason);
}
}
/**
 * Runs received bytes through the parser and dispatches every message completed by them.
 *
 * @param MessageData Bytes received from the socket; may hold a partial message or several messages.
 * @param OutReason   Set by the parser when a message is invalid.
 * @return false when the parser reported an error, true otherwise.
 */
bool FZenClient::HandleMessage(FMemoryView MessageData, FString& OutReason)
{
	TRACE_CPUPROFILER_EVENT_SCOPE(ZenClient::HandleMessage);

	while (!MessageData.IsEmpty())
	{
		uint64 ConsumedBytes = 0;
		const EParseMessageStatus ParseResult = Parser->ParseMessage(MessageData, ConsumedBytes, OutReason);
		MessageData.RightChopInline(ConsumedBytes);

		switch (ParseResult)
		{
		case EParseMessageStatus::Error:
			return false;

		case EParseMessageStatus::Continue:
			// An incomplete message means the parser must have taken everything it was given.
			check(MessageData.IsEmpty());
			return true;

		default:
			break;
		}

		// A complete message: hand it off and prepare the parser for the next one.
		DispatchMessage(Parser->ConsumeMessage());
		Parser->Reset();
	}

	return true;
}
/**
 * Routes a fully parsed message to the callback of the request it belongs to.
 *
 * Responses are delivered by a task graph task. Each task lists the previous task of the same
 * request as its prerequisite, so the responses of one request reach the callback in order.
 * Response and StreamCompleteResponse end a request and remove it from PendingRequests.
 *
 * The message comes from the network, so unexpected content (a notification, an unknown
 * correlation id, an unhandled type) is logged and dropped rather than asserted on.
 */
void FZenClient::DispatchMessage(FZenWebSocketMessage&& Message)
{
	TRACE_CPUPROFILER_EVENT_SCOPE(ZenClient::DispatchMessage);
	const uint64 TotalMessageSize = Message.TotalSize();
	switch (Message.MessageType())
	{
	case EZenWebSocketMessageType::Notification:
	{
		UE_LOG(LogZenClient, Warning, TEXT("dispatch message FAILED, reason 'notifications are not supported'"));
		break;
	}
	case EZenWebSocketMessageType::Response:
	case EZenWebSocketMessageType::StreamResponse:
	case EZenWebSocketMessageType::StreamCompleteResponse:
	{
		const bool bIsStreamComplete =
			Message.MessageType() == EZenWebSocketMessageType::StreamCompleteResponse ||
			Message.MessageType() == EZenWebSocketMessageType::Response;
		TGraphTask<FDispatchStreamResponseTask>* Task = nullptr;
		double StartTime = 0.0;
		{
			FScopeLock _(&RequestsCriticalSection);
			const TSharedPtr<FPendingRequest>* FoundRequest = PendingRequests.Find(Message.CorrelationId());
			if (FoundRequest != nullptr)
			{
				// Copy the pointer first: removing the map entry below invalidates FoundRequest.
				TSharedPtr<FPendingRequest> PendingRequest = *FoundRequest;
				if (bIsStreamComplete)
				{
					PendingRequests.Remove(Message.CorrelationId());
					StartTime = PendingRequest->StartTime;
				}
				FGraphEventArray CompletionEvents;
				if (PendingRequest->LastCompletionEvent.IsValid())
				{
					CompletionEvents.Add(PendingRequest->LastCompletionEvent);
				}
				Task = TGraphTask<FDispatchStreamResponseTask>::CreateTask(&CompletionEvents).ConstructAndHold(PendingRequest, Message.ConsumeBody(), bIsStreamComplete);
				PendingRequest->LastCompletionEvent = Task->GetCompletionEvent();
			}
		}
		if (Task == nullptr)
		{
			UE_LOG(LogZenClient, Warning, TEXT("dispatch message FAILED, reason 'unknown correlation id (%u)'"), Message.CorrelationId());
			break;
		}
		// Release the task only after the lock has been dropped.
		Task->Unlock();
		if (bIsStreamComplete)
		{
			TRACE_COUNTER_ADD(ZenClient_StreamedBytes, TotalMessageSize);
			TRACE_COUNTER_ADD(ZenClient_StreamedSeconds, FPlatformTime::Seconds() - StartTime);
		}
		break;
	}
	default:
		UE_LOG(LogZenClient, Warning, TEXT("dispatch message FAILED, reason 'invalid message type (%d)'"), int(Message.MessageType()));
		break;
	}
}
/**
 * Tears the connection down: requests the read thread to stop, closes the socket, waits for
 * the read thread and releases the socket. Does nothing when the client is not connected.
 *
 * @param Reason Logged as a warning when non-empty.
 */
void FZenClient::CloseConnection(FStringView Reason)
{
	// Only the caller that flips bConnected from true to false performs the shutdown.
	const bool bWasConnected = bConnected.Exchange(false);
	if (!bWasConnected)
	{
		return;
	}

	UE_CLOG(Reason.IsEmpty() == false, LogZenClient, Warning, TEXT("closing connection, reason '%s'"), *FString(Reason));
	check(PendingRequests.Num() == 0);

	bStopRequested = true;
	if (Socket.IsValid())
	{
		Socket->Close();
	}

	// Wait for the read thread before the socket it uses is destroyed.
	ReadMessageThread.Wait();
	Socket.Reset();
}
/** Factory for the IZenClient implementation in this file; the client is returned unconnected. */
TUniquePtr<IZenClient> IZenClient::Create()
{
	TUniquePtr<IZenClient> NewClient = MakeUnique<FZenClient>();
	return NewClient;
}
////////////////////////////////////////////////////////////////////////////////
/**
 * IZenClientPool implementation that owns a fixed number of FZenClient connections.
 *
 * Two allocation strategies exist behind UE_WITH_ZEN_NON_BLOCKING_POOL: a per-entry atomic
 * flag (enabled) and a lock protected free list with an event (disabled).
 */
class FZenClientPool final : public IZenClientPool
{
public:
	FZenClientPool() = default;
	virtual ~FZenClientPool();

	// IZenClientPool interface
	virtual bool Connect(FStringView Host, int32 Port, int32 PoolSize = 8) override;
	virtual void Disconnect() override;
	virtual inline bool IsConnected() const override
	{
		return bConnected;
	}
	virtual bool SendRequest(FCbPackage&& Request, FOnStreamResponse&& OnResponse) override;
	virtual bool SendRequest(FCbObject&& Request, FOnStreamResponse&& OnResponse) override;
	virtual bool SendStreamRequest(FCbObject&& Request, FOnStreamResponse&& OnStreamResponse) override;

private:
#define UE_WITH_ZEN_NON_BLOCKING_POOL (1)
#if UE_WITH_ZEN_NON_BLOCKING_POOL
	/** Pool slot; bAllocated is 1 while a caller holds the entry and 0 otherwise. */
	struct FPoolEntry
	{
		std::atomic<uint8> bAllocated{0u};
		FZenClient Client;
	};
#else
	/** Pool slot; Next links the entries that are currently free. */
	struct FPoolEntry
	{
		FZenClient Client;
		FPoolEntry* Next = nullptr;
	};
#endif
	using FPool = TArray<FPoolEntry>;

	FPoolEntry* GetClient();
	void ReleaseClient(FPoolEntry* Entry);

	FPool Pool;
#if UE_WITH_ZEN_NON_BLOCKING_POOL
	TAtomic<int32> PoolEntryStartIndex{0};	// Incremented per GetClient call to pick the slot probed first.
#else
	FPoolEntry* PoolHead = nullptr;
	FPoolEntry* PoolTail = nullptr;
#endif
	FCriticalSection PoolCS;				// Only used by the free-list variant.
	FEventRef PoolAvailable;				// Only used by the free-list variant.
	TAtomic<bool> bConnected{false};
};
/** Disconnects all pooled clients before the pool is destroyed. */
FZenClientPool::~FZenClientPool()
{
	// The class is final, so the qualified call selects the same function as the virtual call.
	FZenClientPool::Disconnect();
}
/**
 * Creates PoolSize clients and connects every one of them to Host:Port.
 *
 * @param Host     Address string forwarded to FZenClient::Connect.
 * @param Port     Port forwarded to FZenClient::Connect.
 * @param PoolSize Number of connections to open; must be greater than zero.
 * @return true when all clients connected. On failure no client is kept and the pool stays
 *         disconnected, so the call can simply be repeated.
 */
bool FZenClientPool::Connect(FStringView Host, int32 Port, int32 PoolSize)
{
	check(bConnected == false);
	if (PoolSize <= 0)
	{
		// An empty pool must not count as connected: GetClient has no entry to hand out.
		UE_LOG(LogZenClient, Error, TEXT("connect pool FAILED, reason 'invalid pool size (%d)'"), PoolSize);
		return false;
	}
	// Start from fresh clients in case an earlier attempt left entries behind.
	Pool.Empty();
	Pool.SetNum(PoolSize);
	for (FPoolEntry& Entry : Pool)
	{
		if (Entry.Client.Connect(Host, Port) == false)
		{
			// Destroying the entries disconnects the clients that did connect, so nothing is
			// left half-open. Return right away; the range being iterated is gone.
			Pool.Empty();
			return false;
		}
	}
#if UE_WITH_ZEN_NON_BLOCKING_POOL
	for (FPoolEntry& Entry : Pool)
	{
		Entry.bAllocated = 0u;
	}
#else
	// Chain all entries into the free list, head first.
	PoolHead = PoolTail = &Pool[0];
	for (int32 Index = 1; Index < PoolSize; ++Index)
	{
		PoolTail->Next = &Pool[Index];
		PoolTail = &Pool[Index];
	}
#endif
	bConnected = true;
	return true;
}
void FZenClientPool::Disconnect()
{
TRACE_CPUPROFILER_EVENT_SCOPE(ZenClientPool::Disconnect);
if (bConnected.Exchange(false) == true)
{
for (FPoolEntry& Entry : Pool)
{
Entry.Client.Disconnect();
}
Pool.Empty();
}
}
/**
 * Sends Request over one of the pooled clients.
 *
 * The client is held only while the request is being written; it goes back to the pool as soon
 * as the send call returns, not when the response arrives.
 *
 * @return Result of the client's SendRequest, or false when no client could be obtained.
 */
bool FZenClientPool::SendRequest(FCbPackage&& Request, FOnStreamResponse&& OnResponse)
{
	TRACE_CPUPROFILER_EVENT_SCOPE(ZenClientPool::SendRequest);

	FPoolEntry* const PoolEntry = GetClient();
	if (PoolEntry == nullptr)
	{
		return false;
	}

	const bool bSent = PoolEntry->Client.SendRequest(MoveTemp(Request), MoveTemp(OnResponse));
	ReleaseClient(PoolEntry);
	return bSent;
}
/** Convenience overload: wraps Request in a package and forwards to the package overload. */
bool FZenClientPool::SendRequest(FCbObject&& Request, FOnStreamResponse&& OnResponse)
{
	FCbPackage WrappedRequest;
	WrappedRequest.SetObject(MoveTemp(Request));

	return SendRequest(MoveTemp(WrappedRequest), MoveTemp(OnResponse));
}
/**
 * Sends a stream request over one of the pooled clients.
 *
 * The client is returned to the pool as soon as the request has been written.
 *
 * @return Result of the client's SendStreamRequest, or false when no client could be obtained.
 */
bool FZenClientPool::SendStreamRequest(FCbObject&& Request, FOnStreamResponse&& OnStreamResponse)
{
	FPoolEntry* const PoolEntry = GetClient();
	if (PoolEntry == nullptr)
	{
		return false;
	}

	const bool bSent = PoolEntry->Client.SendStreamRequest(MoveTemp(Request), MoveTemp(OnStreamResponse));
	ReleaseClient(PoolEntry);
	return bSent;
}
/**
 * Acquires a pool entry for exclusive use by the caller; pair every success with ReleaseClient.
 *
 * Non-blocking variant: starts at a slot chosen by a shared ticket counter and probes the
 * slots in order, claiming the first free one with a compare-exchange. After two full rounds
 * without success it sleeps for 10 ms and tries again.
 * Free-list variant: pops the list head under the lock, waiting on the event while it is empty.
 *
 * @return The acquired entry, or nullptr when the pool holds no entries (not connected).
 */
FZenClientPool::FPoolEntry* FZenClientPool::GetClient()
{
	TRACE_CPUPROFILER_EVENT_SCOPE(ZenClientPool::WaitForConnection);
	const int32 PoolSize = Pool.Num();
	if (PoolSize == 0)
	{
		// Nothing to hand out. Without this check the modulo below divides by zero.
		return nullptr;
	}
#if UE_WITH_ZEN_NON_BLOCKING_POOL
	// Do the modulo on the unsigned value so the index stays in range once the signed counter wraps.
	const uint32 StartTicket = static_cast<uint32>(PoolEntryStartIndex.IncrementExchange());
	int32 EntryIndex = static_cast<int32>(StartTicket % static_cast<uint32>(PoolSize));
	int32 SpinCount = 0;
	for (;;)
	{
		FPoolEntry& Entry = Pool[EntryIndex];
		// Relaxed pre-check; the compare-exchange below is what actually claims the entry.
		if (Entry.bAllocated.load(std::memory_order_relaxed) == 0u)
		{
			uint8 Expected = 0u;
			if (Entry.bAllocated.compare_exchange_strong(Expected, 1u))
			{
				return &Entry;
			}
		}
		EntryIndex = (EntryIndex + 1) % PoolSize;
		if (++SpinCount >= PoolSize * 2)
		{
			SpinCount = 0;
			FPlatformProcess::Sleep(0.01f);
		}
	}
#else
	for (;;)
	{
		{
			FScopeLock _(&PoolCS);
			if (FPoolEntry* Entry = PoolHead)
			{
				if (Entry->Next)
				{
					PoolHead = Entry->Next;
				}
				else
				{
					PoolHead = PoolTail = nullptr;
				}
				Entry->Next = nullptr;
				return Entry;
			}
		}
		PoolAvailable->Wait();
	}
#endif
}
/**
 * Returns an entry obtained from GetClient to the pool.
 *
 * Non-blocking variant: flips the entry's flag from 1 back to 0.
 * Free-list variant: appends the entry to the list tail under the lock and triggers the event.
 */
void FZenClientPool::ReleaseClient(FZenClientPool::FPoolEntry* Entry)
{
#if UE_WITH_ZEN_NON_BLOCKING_POOL
	// Leaves the flag untouched when the entry is not marked as allocated.
	uint8 ExpectedAllocated = 1u;
	Entry->bAllocated.compare_exchange_strong(ExpectedAllocated, 0u);
#else
	check(Entry->Next == nullptr);
	{
		FScopeLock PoolLock(&PoolCS);
		if (PoolTail != nullptr)
		{
			check(PoolTail->Next == nullptr);
			PoolTail->Next = Entry;
		}
		else
		{
			// The list was empty, so the entry becomes head and tail at once.
			PoolHead = Entry;
		}
		PoolTail = Entry;
	}
	PoolAvailable->Trigger();
#endif
}
/** Factory for the IZenClientPool implementation in this file; the pool is returned unconnected. */
TUniquePtr<IZenClientPool> IZenClientPool::Create()
{
	TUniquePtr<IZenClientPool> NewPool = MakeUnique<FZenClientPool>();
	return NewPool;
}
} // namespace UE::Zen