Files
UnrealEngine/Engine/Source/Programs/UnrealBuildAccelerator/Common/Private/UbaNetworkClient.cpp
2025-05-18 13:04:45 +08:00

851 lines
24 KiB
C++

// Copyright Epic Games, Inc. All Rights Reserved.
#include "UbaNetworkClient.h"
#include "UbaConfig.h"
#include "UbaCrypto.h"
#include "UbaBinaryReaderWriter.h"
#include "UbaNetworkBackendTcp.h"
#include "UbaNetworkMessage.h"
#include <stdlib.h>
#include <stdio.h>
namespace uba
{
// Applies optional overrides stored in the config table named `tableName` to this create-info.
// Currently only "DesiredConnectionCount" is looked up. When the table does not exist, nothing is touched.
void NetworkClientCreateInfo::Apply(Config& config, const tchar* tableName)
{
	if (const ConfigTable* table = config.GetTable(tableName))
		table->GetValueAsU32(desiredConnectionCount, TC("DesiredConnectionCount"));
}
// Creates a network client.
// - The requested send size is clamped to [4KB, SendMaxSize]. A Detail line is logged when clamping changed it.
// - When info.cryptoKey128 is provided, a crypto key is created from it. outCtorSuccess becomes false if that fails.
// - SetGetPrefix (called while initializing m_logger) creates this client's uid and the logger prefix.
NetworkClient::NetworkClient(bool& outCtorSuccess, const NetworkClientCreateInfo& info, const tchar* name)
:	WorkManagerImpl(info.workerCount == ~0u ? GetLogicalProcessorCount() : info.workerCount, TC("UbaWrk/NetClnt"))
,	m_logWriter(info.logWriter)
,	m_logger(info.logWriter, SetGetPrefix(name))
,	m_isConnected(true)
,	m_isOrWasConnected(true)
{
	outCtorSuccess = true;

	const u32 minSendSize = 4*1024;
	const u32 clampedSendSize = Min(Max(info.sendSize, minSendSize), u32(SendMaxSize));
	if (clampedSendSize != info.sendSize)
		m_logger.Detail(TC("Adjusted msg size to %u to stay inside limits"), clampedSendSize);

	m_sendSize = clampedSendSize;
	m_receiveTimeoutSeconds = info.receiveTimeoutSeconds;
	m_desiredConnectionCount = info.desiredConnectionCount;

	// No connection exists yet, so the round-robin iterator used by Send starts at end().
	m_connectionsIt = m_connections.end();

	if (info.cryptoKey128)
	{
		m_cryptoKey = Crypto::CreateKey(m_logger, info.cryptoKey128);
		if (m_cryptoKey == InvalidCryptoKey)
			outCtorSuccess = false;
	}

#if UBA_TRACK_NETWORK_TIMES
	m_startTime = GetTime();
#endif
}
// Destroys the client. All connections must already be gone (asserted); the assert message hints when
// Disconnect() was never called. Releases the crypto key if one was created.
NetworkClient::~NetworkClient()
{
	UBA_ASSERTF(m_connections.empty(), TC("Client still has connections (%llu). %s"), m_connections.size(), m_isDisconnecting ? TC("") : TC("Disconnect has not been called"));

	if (!m_cryptoKey)
		return;
	Crypto::DestroyKey(m_cryptoKey);
}
// Connects to a server at ip:port through `backend`. Each backend connection that gets established is handed to
// AddConnection, which runs the handshake. `timedOut` is passed to the backend, and the timedOut pointer the backend
// supplies to the callback is forwarded to AddConnection.
bool NetworkClient::Connect(NetworkBackend& backend, const tchar* ip, u16 port, bool* timedOut)
{
	return backend.Connect(m_logger, ip,
		[&](void* connection, const sockaddr& /*remoteSocketAddr*/, bool* connectTimedOut)
		{
			return AddConnection(backend, connection, connectTimedOut);
		},
		port, timedOut);
}
// Runs the client side of the handshake on a freshly established backend connection. On success the connection is
// registered as a regular connection through ConnectedCallback.
//
// Sent by the client: [byte 1 + EncryptionHandshakeString encrypted with m_cryptoKey] (only when a crypto key is set),
// then SystemNetworkVersion (u32) and this client's uid (m_uid).
// Expected from the server: 1 error byte followed by the server's Guid.
//
// Values of RecvContext::error:
//   255 - initial value, nothing received yet
//   0   - success
//   1   - version mismatch (error byte from the server)
//   2   - server failed to receive client uid (error byte from the server)
//   3   - server does not allow new clients (error byte from the server)
//   4   - disconnected before the handshake completed, or ConnectedCallback refused the connection
//   5   - server uid was zero
//   6   - server uid differs from the one this client already accepted
//   any other non-zero value is reported as a generic failure
// For errors 3 and 4 with a non-null `timedOut`, *timedOut is set to true after a 1 second sleep so that retrying
// callers treat the failure like a timeout and keep retrying.
bool NetworkClient::AddConnection(NetworkBackend& backend, void* backendConnection, bool* timedOut)
{
	// State shared with the temporary handshake callbacks below. It lives on this stack frame, so the destructor
	// blocks on exitScopeEvent until one of the callbacks has signalled that it no longer uses it.
	struct RecvContext
	{
		RecvContext(NetworkClient& c, NetworkBackend& b, void* bc) : client(c), backend(b), backendConnection(bc), recvEvent(true), exitScopeEvent(true)
		{
			error = 255;
		}
		~RecvContext()
		{
			// A failed handshake must not leave the connection alive. Presumably the shutdown also triggers the
			// disconnect callback below, which sets exitScopeEvent (TODO confirm in backend).
			if (error)
				backend.Shutdown(backendConnection);
			exitScopeEvent.IsSet(~0u);
		}
		NetworkClient& client;
		NetworkBackend& backend;
		void* backendConnection;
		Event recvEvent;
		Event exitScopeEvent;
		Atomic<u8> error;
	};

	backend.SetRecvTimeout(backendConnection, m_receiveTimeoutSeconds*1000, nullptr, nullptr);

	RecvContext rc(*this, backend, backendConnection);

	// The only way out of this function is to get a call to one of the below callbacks since exitScopeEvent must be set.
	backend.SetDisconnectCallback(backendConnection, &rc, [](void* context, const Guid& connectionUid, void* connection)
		{
			auto& rc = *(RecvContext*)context;
			// 255 means the connection dropped before any handshake response arrived (e.g. the server closed it).
			// Report that as "server disconnected" (4) so that it is not mistaken for success once recvEvent is set.
			// Errors already set by the recv callback are kept.
			if (rc.error == 0 || rc.error == 255)
				rc.error = 4;
			rc.recvEvent.Set();
			rc.exitScopeEvent.Set();
		});

	// Handshake response: 1 error byte + server Guid.
	backend.SetRecvCallbacks(backendConnection, &rc, 1 + sizeof(Guid), [](void* context, const Guid& connectionUid, u8* headerData, void*& outBodyContext, u8*& outBodyData, u32& outBodySize)
		{
			auto& rc = *(RecvContext*)context;
			rc.error = *headerData;

			Guid serverUid = *(Guid*)(headerData+1);
			if (serverUid == Guid())
				rc.error = 5;

			if (!rc.error)
			{
				SCOPED_FUTEX(rc.client.m_serverUidLock, lock);
				if (rc.client.m_serverUid == Guid())
					rc.client.m_serverUid = serverUid;
				else if (rc.client.m_serverUid != serverUid) // Seems like two different servers tried to connect to this client.. keep the first one and ignore the others
					rc.error = 6;
			}

			// On success ConnectedCallback replaces these handshake callbacks with the regular ones.
			if (!rc.error)
				if (!rc.client.ConnectedCallback(rc.backend, rc.backendConnection))
					rc.error = 4;

			// Returning false reports the failure to the backend. The events are then left to the disconnect callback.
			if (rc.error != 0)
				return false;

			rc.recvEvent.Set();
			rc.exitScopeEvent.Set();
			return true;
		}, nullptr, TC("Connecting"));

	StackBinaryWriter<1024> handshakeData;

	if (m_cryptoKey)
	{
		handshakeData.WriteByte(1);
		// If we have a crypto key we start by sending a predefined 128 bytes blob that is encrypted.
		// If server decrypt it to the same blob, we're good on that part
		u8* encryptedBuffer = handshakeData.AllocWrite(sizeof(EncryptionHandshakeString));//[1024];
		memcpy(encryptedBuffer, EncryptionHandshakeString, sizeof(EncryptionHandshakeString));
		if (!Crypto::Encrypt(m_logger, m_cryptoKey, encryptedBuffer, sizeof(EncryptionHandshakeString)))
			return false;
	}

	handshakeData.WriteU32(SystemNetworkVersion);
	handshakeData.WriteBytes(&m_uid, sizeof(m_uid));

	NetworkBackend::SendContext sendContext;
	if (!backend.Send(m_logger, backendConnection, handshakeData.GetData(), u32(handshakeData.GetPosition()), sendContext, TC("Handshake")))
		return false;

	if (!rc.recvEvent.IsSet(20*1000)) // This can not happen. Since both callbacks are using rc we can't leave this function until we know we are not in the callbacks
		return m_logger.Info(TC("Timed out after 20 seconds waiting for connection response from server.")).ToFalse();

	// Set as soon as the server answered the handshake in any way, including rejections handled below.
	m_isOrWasConnected.Set();

	if (rc.error == 1) // Bad version
		return m_logger.Error(TC("Version mismatch with server"));

	if (rc.error == 2)
		return m_logger.Error(TC("Server failed to receive client uid"));

	if (rc.error == 3)
	{
		if (!timedOut)
			return m_logger.Error(TC("Server does not allow new clients"));
		*timedOut = true;
		Sleep(1000); // Kind of ugly, but we want the retry-clients to keep retrying so we pretend it is a timeout
		return false;
	}

	if (rc.error == 4)
	{
		if (!timedOut)
			return m_logger.Error(TC("Server disconnected"));
		*timedOut = true;
		Sleep(1000); // Kind of ugly, but we want the retry-clients to keep retrying so we pretend it is a timeout
		return false;
	}

	if (rc.error == 5)
		return m_logger.Error(TC("A connection to a server with uid zero was requested."));

	if (rc.error == 6)
		return m_logger.Warning(TC("A connection to a server with different uid was requested. Ignore"));

	if (rc.error != 0) // Any other code (e.g. one introduced by a newer server) must never be treated as a successful connection
		return m_logger.Error(TC("Failed to connect to server (error %u)"), u32(rc.error.load()));

	// Only the first live connection triggers the on-connected notifications.
	if (m_connectionCount.fetch_add(1) != 0)
		return true;

	SCOPED_FUTEX(m_onConnectedFunctionsLock, lock);
	for (auto& f : m_onConnectedFunctions)
		f();
	m_isConnected.Set();
	lock.Leave();

	return true;
}
// Client -> server header: 1 byte service id/message type, 1 byte message id (high byte),
// 4 bytes holding the 24-bit body size and the message id's low byte.
constexpr u32 SendHeaderSize = 1 + 1 + 4;
// Server -> client header: 1 byte message id (high byte), 4 bytes holding the 24-bit body size and the id's low byte.
constexpr u32 ReceiveHeaderSize = 1 + 4;
void NetworkClient::DisconnectCallback(void* context, const Guid& connectionUid, void* connection)
{
auto& c = *(Connection*)context;
c.owner.OnDisconnected(c, 1);
c.disconnectedEvent.Set();
}
// Called from the handshake recv callback in AddConnection once the server has accepted this client.
// Registers a new Connection for `backendConnection` and makes it the next one Send will pick (m_connectionsIt).
// It then switches the backend callbacks from the handshake handlers to the regular ones
// (DisconnectCallback, ReceiveResponseHeader/ReceiveResponseBody).
// Returns false (and registers nothing) when Disconnect() has already started.
// NOTE(review): a raw Connection* is handed to the backend while later connections are still emplaced, so this
// assumes m_connections keeps element addresses stable on emplace_back (e.g. a list). TODO confirm container type.
bool NetworkClient::ConnectedCallback(NetworkBackend& backend, void* backendConnection)
{
	SCOPED_WRITE_LOCK(m_connectionsLock, lock);
	if (m_isDisconnecting)
		return false;
	m_connections.emplace_back(*this);
	Connection* connection = &m_connections.back();
	connection->backendConnection = backendConnection;
	connection->connected = 1;
	connection->backend = &backend;
	// `l` lives until the end of the function. Send also takes m_connectionsItLock to pick a connection, so it
	// cannot use this one before the callbacks below are installed, even though the write lock is released early.
	SCOPED_FUTEX(m_connectionsItLock, l); // Take this lock to make sure callbacks are set before connection is used
	m_connectionsIt = --m_connections.end();
	m_logger.Detail(TC("Connected to server... (0x%p)"), backendConnection);
	lock.Leave();
	backend.SetRecvTimeout(backendConnection, m_receiveTimeoutSeconds*1000, nullptr, nullptr);
	backend.SetDisconnectCallback(backendConnection, connection, DisconnectCallback);
	backend.SetRecvCallbacks(backendConnection, connection, ReceiveHeaderSize, ReceiveResponseHeader, ReceiveResponseBody, TC("ReceiveMessageResponse"));
	return true;
}
// Backend header callback for an established connection (`context` is the Connection*).
// Parses the ReceiveHeaderSize (5) byte response header:
//   byte 0           - high byte of the message id
//   u32 at offset 1  - low 24 bits: body size, high 8 bits: low byte of the message id
// Keep-alive headers (size == MessageKeepAliveSize) are answered immediately. Otherwise the header is matched to
// the pending NetworkMessage by id. Error and empty responses complete the message right away. For responses with a
// body, the message's response buffer is handed back through outBodyContext/outBodyData/outBodySize so that
// ReceiveResponseBody can complete it.
// Returning false reports a failure to the backend (presumably the connection is then closed; confirm in backend).
bool NetworkClient::ReceiveResponseHeader(void* context, const Guid& connectionUid, u8* headerData, void*& outBodyContext, u8*& outBodyData, u32& outBodySize)
{
	auto& connection = *(Connection*)context;
	auto& client = connection.owner;
#if UBA_TRACK_NETWORK_TIMES
	connection.lastHeaderRecvTime = GetTime();
#endif
	// Reassemble the 16-bit id from byte 0 (high) and the top byte of the u32 (low); the rest of the u32 is the size.
	u16 messageId = u16(headerData[0] << 8) | u16((*(u32*)(headerData + 1) & 0xff000000) >> 24);
	u32 messageSize = *(u32*)(headerData + 1) & 0x00FFFFFF;
	if (messageSize == MessageKeepAliveSize) // Keep alive message
	{
		// Reply with a fixed 6 byte (SendHeaderSize) keep-alive message.
		u8 data[6] = { SystemMessageType_KeepAlive, 0, 1, 0, 0, 0 };
		NetworkBackend::SendContext sendContext;
		return connection.backend->Send(client.m_logger, connection.backendConnection, data, sizeof(data), sendContext, TC("KeepAliveNoResponse"));
	}
	NetworkMessage* msg;
	{
		LOG_STALL_SCOPE(client.m_logger, 5, TC("Took more than %s to get message from id"));
		SCOPED_READ_LOCK(client.m_activeMessagesLock, lock);
		if (!connection.connected)
			return false;
		if (messageId >= client.m_activeMessages.size())
			return client.m_logger.Error(TC("Message id %u is higher than max %u"), messageId, u32(client.m_activeMessages.size()));
		msg = client.m_activeMessages[messageId];
	}
	// NOTE(review): msg is used after m_activeMessagesLock has been released. This relies on the message staying
	// registered (and alive) until Done() is called for it.
	if (!msg)
		return false;
	if (messageSize == MessageErrorSize || messageSize == MessageErrorSize - ReceiveHeaderSize) // ReceiveHeaderSize is removed from size in server send
	{
		msg->m_error = 1;
		msg->Done();
		return true;
	}
	if (!messageSize)
	{
		// Empty response: no body follows, so complete the message here.
		++client.m_recvCount;
		msg->Done();
		return true;
	}
	if (messageSize > msg->m_responseCapacity)
	{
		// The response does not fit in the caller's buffer. Fail the message and report the error to the backend.
		u8 serviceIdAndMessageType = msg->m_sendWriter->GetData()[0];
		u8 serviceId = serviceIdAndMessageType >> 6;
		u8 messageType = serviceIdAndMessageType & 0b111111;
		client.m_logger.Error(TC("Message size is %u but reader capacity is only %u (serviceId %u, messageType %u)"), messageSize, msg->m_responseCapacity, u32(serviceId), u32(messageType));
		msg->m_error = 1;
		msg->Done();
		return false;
	}
	// Let the backend receive the body straight into the message's response buffer.
	msg->m_responseSize = messageSize;
	outBodyContext = msg;
	outBodyData = (u8*)msg->m_response;
	outBodySize = messageSize;
	++client.m_recvCount;
	client.m_recvBytes += ReceiveHeaderSize + messageSize;
	return true;
}
bool NetworkClient::ReceiveResponseBody(void* context, bool recvError, u8* headerData, void* bodyContext, u8* bodyData, u32 bodySize)
{
auto& msg = *(NetworkMessage*)bodyContext;
if (recvError)
msg.m_error = 2;
#if UBA_TRACK_NETWORK_TIMES
if (msg.m_connection)
msg.m_connection->lastBodyRecvTime = GetTime();
#endif
msg.Done();
return true;
}
// Disconnects from all servers. Only the first call does the work; later calls return early because
// m_isDisconnecting is already set.
// For every connection, OnDisconnected(c, 0) is called, and then this waits (without timeout) for c.disconnectedEvent,
// which DisconnectCallback sets. After that, all backend connections are deleted and the list is cleared.
// When `flushWork` is true, the scope guard calls FlushWork with a 30 second timeout on every exit path (including the
// early return) and logs all callstacks if the work does not finish in time.
// NOTE(review): m_isDisconnecting is tested and set while holding only the read lock, so two concurrent Disconnect
// calls are not mutually excluded by it. Presumably the read lock is used so that code running from the disconnect
// callbacks can still take the read lock (e.g. via Send); confirm before changing it.
void NetworkClient::Disconnect(bool flushWork)
{
	auto fg = MakeGuard([&]()
		{
			if (!flushWork)
				return;
			if (FlushWork(30*1000))
				return;
			m_logger.Error(TC("NetworkClient has waited 30 seconds for all work to finish... something is stuck"));
			PrintAllCallstacks(m_logger);
		});
	{
		SCOPED_READ_LOCK(m_connectionsLock, lock);
		if (m_isDisconnecting)
			return;
		m_isDisconnecting = true; // From now on ConnectedCallback refuses new connections
		for (auto& c : m_connections)
		{
			OnDisconnected(c, 0);
			c.disconnectedEvent.IsSet(~0u);
		}
	}
	{
		SCOPED_WRITE_LOCK(m_connectionsLock, lock2);
		for (auto& c : m_connections)
			c.backend->DeleteConnection(c.backendConnection);
		m_connections.clear();
		m_connectionsIt = m_connections.end();
	}
}
// Listens on ip:port through `backend`. Every incoming connection runs the same handshake as Connect (AddConnection),
// but with no timedOut output, so errors 3/4 are reported as errors instead of being treated as timeouts.
bool NetworkClient::StartListen(NetworkBackend& backend, u16 port, const tchar* ip)
{
	return backend.StartListen(m_logger, port, ip,
		[&](void* connection, const sockaddr& /*remoteSockAddr*/)
		{
			return AddConnection(backend, connection, nullptr);
		});
}
// Sends SystemMessageType_SetConnectionCount with `count` as its u32 body. No response is expected;
// returns whether the send succeeded.
bool NetworkClient::SetConnectionCount(u32 count)
{
	StackBinaryWriter<64> payload;
	// Constructing the message writes the header, so the body must be written afterwards.
	NetworkMessage message(*this, SystemServiceId, SystemMessageType_SetConnectionCount, payload);
	payload.WriteU32(count);
	return message.Send();
}
bool NetworkClient::SendKeepAlive()
{
StackBinaryWriter<64> writer;
NetworkMessage msg(*this, SystemServiceId, SystemMessageType_KeepAlive, writer);
return msg.Send();
}
// Requests the server's config (SystemMessageType_FetchConfig) and loads it into `config`.
// Response layout: 7-bit encoded text length followed by that many bytes of text.
// Returns false if the request fails, the announced length exceeds the received data, or loading fails.
bool NetworkClient::FetchConfig(Config& config)
{
	StackBinaryWriter<64> writer;
	NetworkMessage msg(*this, SystemServiceId, SystemMessageType_FetchConfig, writer);
	writer.WriteByte(0); // Need to have a body
	StackBinaryReader<SendMaxSize> reader;
	if (!msg.Send(reader))
		return false;
	u64 textLen = reader.Read7BitEncoded();
	// The length comes from the server. Never let it make LoadFromText read past the received response.
	if (textLen > u64(reader.GetLeft()))
		return m_logger.Error(TC("Config text length %llu exceeds received data (%llu bytes left)"), textLen, u64(reader.GetLeft()));
	return config.LoadFromText(m_logger, (const char*)reader.GetPositionData(), textLen);
}
// Returns whether the client currently has at least one live connection, waiting on m_isConnected for up to
// `waitTimeoutMs`. m_isConnected is set when the first connection is added and reset when the last one disconnects.
bool NetworkClient::IsConnected(u32 waitTimeoutMs)
{
	const bool connected = m_isConnected.IsSet(waitTimeoutMs);
	return connected;
}
// Returns whether any handshake ever got an answer from a server, waiting on m_isOrWasConnected for up to
// `waitTimeoutMs`. The event is set in AddConnection even for rejected handshakes and is not reset anywhere in this file.
bool NetworkClient::IsOrWasConnected(u32 waitTimeoutMs)
{
	const bool answered = m_isOrWasConnected.IsSet(waitTimeoutMs);
	return answered;
}
// Logs send/receive/crypto statistics, the peak active message id, the connection count and the send size limits.
// Only the connection count is read under m_connectionsLock.
void NetworkClient::PrintSummary(Logger& logger)
{
	u32 connectionsCount = 0;
	{
		SCOPED_READ_LOCK(m_connectionsLock, lock);
		connectionsCount = u32(m_connections.size());
	}

	logger.Info(TC(" ----- Uba client stats summary ------"));
	logger.Info(TC(" SendTotal %8u %9s"), m_sendTimer.count.load(), TimeToText(m_sendTimer.time).str);
	logger.Info(TC(" Bytes %9s"), BytesToText(m_sendBytes).str);
	logger.Info(TC(" RecvTotal %8u %9s"), m_recvCount.load(), BytesToText(m_recvBytes).str);

	// Crypto timers are only meaningful when traffic is encrypted.
	if (m_cryptoKey)
	{
		logger.Info(TC(" EncryptTotal %8u %9s"), m_encryptTimer.count.load(), TimeToText(m_encryptTimer.time).str);
		logger.Info(TC(" DecryptTotal %8u %9s"), m_decryptTimer.count.load(), TimeToText(m_decryptTimer.time).str);
	}

	logger.Info(TC(" MaxActiveMessages %8u"), m_activeMessageIdMax);
	logger.Info(TC(" Connections %8u"), connectionsCount);
	logger.Info(TC(" SendSize Set/Max %9s %9s"), BytesToText(m_sendSize).str, BytesToText(SendMaxSize).str);
	logger.Info(TC(""));
}
// Logs diagnostic information about the client's network state, intended for finding stuck locks or messages.
// Acquires m_connectionsLock, m_connectionsItLock and (when `full`) m_activeMessagesLock one after another. Each
// acquisition is wrapped in a LogStallScope that reports how long it took (presumably only when it exceeds the
// threshold argument `1`; confirm units). The backend connections are grouped per backend under the connections
// lock and passed to NetworkBackend::Validate after all locks have been released.
// With `full`, every connection is logged as well (with UBA_TRACK_NETWORK_TIMES: last send/recv times relative to
// client start), together with every active message and its age.
// NOTE(review): write locks are taken although the data is only read. This is presumably deliberate, so that a lock
// held by a stuck reader also shows up as a stall; confirm before switching to read locks.
void NetworkClient::ValidateNetwork(Logger& logger, bool full)
{
	UnorderedMap<NetworkBackend*, Vector<void*>> backends;
	u32 connectionsCount;
	{
		LogStallScope lss(logger, LogEntryType_Info, 1, TC(" Connections lock took %s"));
		SCOPED_WRITE_LOCK(m_connectionsLock, lock);
		lss.Leave();
		connectionsCount = u32(m_connections.size());
		u32 connectionIndex = 0;
		for (auto& c : m_connections)
		{
			if (full)
			{
#if UBA_TRACK_NETWORK_TIMES
				logger.Info(TC(" Connection %u - LastSend %s, LastHeaderRecv: %s, LastBodyRecv: %s"), connectionIndex++, TimeToText(c.lastSendTime - m_startTime).str, TimeToText(c.lastHeaderRecvTime - m_startTime).str, TimeToText(c.lastBodyRecvTime - m_startTime).str);
#else
				logger.Info(TC(" Connection %u"), connectionIndex++);
#endif
			}
			backends[c.backend].push_back(c.backendConnection);
		}
	}
	{
		// Only checks that the iterator lock can be taken; nothing is read under it.
		LogStallScope lss(logger, LogEntryType_Info, 1, TC(" ConnectionsIterator lock took %s"));
		SCOPED_FUTEX(m_connectionsItLock, lock);
	}
	if (full)
	{
		LogStallScope lss(logger, LogEntryType_Info, 1, TC(" ActiveMessages lock took %s"));
		SCOPED_WRITE_LOCK(m_activeMessagesLock, lock);
		lss.Leave();
		logger.Info(TC(" Active messages"));
		u64 now = GetTime();
		for (auto m : m_activeMessages)
			if (m)
			{
				// Without UBA_TRACK_NETWORK_TIMES sendTime stays 0, so the printed age is simply `now`.
				u64 sendTime = 0;
#if UBA_TRACK_NETWORK_TIMES
				sendTime = m->m_sendTime;
#endif
				logger.Info(TC(" %s (%u): %s"), MessageToString(m->GetServiceId(), m->GetMessageType()).data, u32(m->m_id), TimeToText(now - sendTime).str);
			}
	}
	for (auto& kv : backends)
		kv.first->Validate(logger, kv.second, full);
}
// Registers `function` to be called whenever the client goes from zero to one live connection (see AddConnection).
// If the client is already connected, `function` is also invoked right away, after the lock has been released.
void NetworkClient::RegisterOnConnected(const OnConnectedFunction& function)
{
	SCOPED_FUTEX(m_onConnectedFunctionsLock, lock);
	m_onConnectedFunctions.push_back(function);
	if (m_isConnected.IsSet(0))
	{
		lock.Leave();
		function();
	}
}
void NetworkClient::RegisterOnDisconnected(const OnDisconnectedFunction& function)
{
SCOPED_WRITE_LOCK(m_onDisconnectedFunctionsLock, lock);
m_onDisconnectedFunctions.push_back(function);
}
// Installs the single handler invoked by InvokeVersionMismatch, replacing any previously installed one.
// The handler is taken over by move and is not protected by a lock here.
void NetworkClient::RegisterOnVersionMismatch(OnVersionMismatchFunction&& function)
{
	m_versionMismatchFunction = std::move(function);
}
// Forwards the given exe/dll cas keys to the registered version-mismatch handler, if there is one.
void NetworkClient::InvokeVersionMismatch(const CasKey& exeKey, const CasKey& dllKey)
{
	if (!m_versionMismatchFunction)
		return;
	m_versionMismatchFunction(exeKey, dllKey);
}
// Number of bytes NetworkMessage::Init reserves for the header in front of every outgoing message.
u64 NetworkClient::GetMessageHeaderSize()
{
	return u64(SendHeaderSize);
}
// Number of bytes of the header that precedes every response received from the server.
u64 NetworkClient::GetMessageReceiveHeaderSize()
{
	return u64(ReceiveHeaderSize);
}
// Maximum outgoing message size: the send size clamped to [4KB, SendMaxSize] in the constructor.
u64 NetworkClient::GetMessageMaxSize()
{
	return u64(m_sendSize);
}
// Returns the backend of the first registered connection, or nullptr when there is none.
NetworkBackend* NetworkClient::GetFirstConnectionBackend()
{
	SCOPED_READ_LOCK(m_connectionsLock, connectionLock);
	return m_connections.empty() ? nullptr : m_connections.front().backend;
}
// Marks `connection` as disconnected and fails every in-flight message that was sent on it.
// The first call for a connection (connected goes 1 -> 0) does three things: it logs, shuts the backend connection
// down, and, if this was the last live connection, resets m_isConnected and runs the registered on-disconnected
// functions. Every call (including repeated ones) then sets m_error = 3 on the connection's active messages and
// completes them with Done(false). That variant does not take m_activeMessagesLock again, because it is already held
// here, which means their m_doneFunc runs while that lock is held.
// `reason` is only used in the log line.
void NetworkClient::OnDisconnected(Connection& connection, u32 reason)
{
	if (connection.connected.exchange(0) == 1)
	{
		m_logger.Detail(TC("Disconnected from server... (0x%p) (%u)"), connection.backendConnection, reason);
		connection.backend->Shutdown(connection.backendConnection);
		if (m_connectionCount.fetch_sub(1) == 1)
		{
			m_isConnected.Reset();
			SCOPED_READ_LOCK(m_onDisconnectedFunctionsLock, lock);
			for (auto& f : m_onDisconnectedFunctions)
				f();
		}
	}
	SCOPED_WRITE_LOCK(m_activeMessagesLock, lock);
	for (auto m : m_activeMessages)
	{
		if (m && m->m_connection == &connection)
		{
			m->m_error = 3;
			m->Done(false);
		}
	}
}
// Sends `message` on the next live connection, picked round-robin via m_connectionsIt.
// - response/responseCapacity: buffer the response body is received into. With nullptr no response is expected and
//   no message id is allocated (id 0 is written into the header).
// - When a response is expected, the message is registered in m_activeMessages under a free 16-bit id so that
//   ReceiveResponseHeader can route the response back.
// - async == true: returns right after the send; completion is reported through message.Done() / m_doneFunc.
//   async == false with a response: blocks for up to 10 minutes on a local event and then decrypts the response if
//   a crypto key is set.
// Returns false on failure, with message.m_error holding a code that identifies where it failed (see assignments).
// Note that OnDisconnected also fails this message (overwriting m_error with 3) if it is already registered.
bool NetworkClient::Send(NetworkMessage& message, void* response, u32 responseCapacity, bool async)
{
	SCOPED_READ_LOCK(m_connectionsLock, connectionLock);
	SCOPED_FUTEX(m_connectionsItLock, connectionItLock);
	if (m_connectionsIt == m_connections.end())
	{
		if (m_isDisconnecting)
			message.m_error = 11;
		else if (!m_connections.empty())
			message.m_error = 12; // should never happen
		else
			message.m_error = 6;
		// NOTE(review): no message id has been assigned on this path, and Done() only calls m_doneFunc when m_id != 0.
		// Async callers therefore presumably rely on the false return value here; confirm.
		if (async)
			message.Done(false);
		return false;
	}
	// Skip connections that has disconnected
	// If every connection is disconnected, the loop stops back at the starting connection and the send continues on
	// it anyway (it is then expected to fail further down).
	Connection* connectionPtr = &*m_connectionsIt;
	Connection* connectionPtrStart = connectionPtr;
	while (true)
	{
		++m_connectionsIt;
		if (m_connectionsIt == m_connections.end())
			m_connectionsIt = m_connections.begin();
		if (connectionPtr->connected)
			break;
		connectionPtr = &*m_connectionsIt;
		if (connectionPtr == connectionPtrStart)
			break;
	}
	connectionItLock.Leave();
	connectionLock.Leave();
	Connection& connection = *connectionPtr;
	message.m_response = response;
	message.m_responseCapacity = responseCapacity;
	message.m_connection = &connection;
	BinaryWriter& writer = *message.m_sendWriter;
	u16 messageId = 0;
	Event gotResponse;
	if (response)
	{
		if (!async)
		{
			if (!gotResponse.Create(true))
			{
				m_logger.Error(TC("Failed to create event, this should not happen?!?"));
				message.m_error = 13;
				OnDisconnected(connection, 13);
				return false;
			}
		}
		LOG_STALL_SCOPE(m_logger, 5, TC("Took more than %s to get message id"));
		// Allocate a message id: reuse a returned one if available, otherwise grow m_activeMessageIdMax.
		while (true)
		{
			SCOPED_WRITE_LOCK(m_activeMessagesLock, lock);
			if (m_availableMessageIds.empty())
			{
				if (!connection.connected)
				{
					message.m_error = 7;
					if (async)
						message.Done(false);
					return false;
				}
				if (m_activeMessageIdMax == 65534)
				{
					// All ids are in use. Back off for a random 100-999 ms (the log text says 1 second) and retry.
					lock.Leave();
					m_logger.Info(TC("Reached max limit of active message ids (65534). Waiting 1 second"));
					Sleep(100u + u32(rand()) % 900u);
					continue;
				}
				messageId = m_activeMessageIdMax++;
				// Grow the lookup table in chunks so that the new id is always a valid index.
				if (m_activeMessages.size() < m_activeMessageIdMax)
					m_activeMessages.resize(size_t(m_activeMessageIdMax) + 1024);
			}
			else
			{
				messageId = m_availableMessageIds.back();
				m_availableMessageIds.pop_back();
			}
			UBA_ASSERT(!m_activeMessages[messageId]);
			m_activeMessages[messageId] = &message;
			message.m_id = messageId;
			message.m_sendContext.flags = NetworkBackend::SendFlags_ExternalWait;
			if (!async)
			{
				// Synchronous sends are completed by Done() signalling the local gotResponse event.
				UBA_ASSERT(!message.m_doneFunc);
				message.m_doneUserData = &gotResponse;
				message.m_doneFunc = [](bool error, void* userData) { ((Event*)userData)->Set(); };
			}
			break;
		}
	}
	UBA_ASSERT(messageId < 65535);
	// Fill in the header reserved by NetworkMessage::Init:
	//   data[1]          - high byte of the message id
	//   u32 at data + 2  - low 24 bits: body size (sendSize minus the 6 byte header), high 8 bits: low byte of the id
	u32 sendSize = u32(writer.GetPosition());
	u8* data = writer.GetData();
	data[1] = messageId >> 8;
	u32 dataSize = sendSize - 6;
	UBA_ASSERTF(dataSize || data[0] == 1, TC("NetworkMessage must have data size of at least 1."));
	*(u32*)(data + 2) = dataSize | u32(messageId) << 24;
	//m_logger.Debug(TC("Send: %u, %u, %u, %u"), data[0], data[1], data[2], sendSize - 7);
	// Only the body is encrypted; the header stays readable for the server.
	u32 bodySize = sendSize - SendHeaderSize;
	if (m_cryptoKey && bodySize)
	{
		TimerScope ts(m_encryptTimer);
		if (!Crypto::Encrypt(m_logger, m_cryptoKey, data + SendHeaderSize, bodySize))
		{
			message.m_error = 8;
			OnDisconnected(connection, 8);
			return false;
		}
	}
	m_sendBytes += sendSize;
	{
		TimerScope ts(m_sendTimer);
#if UBA_TRACK_NETWORK_TIMES
		connection.lastSendTime = ts.start;
		message.m_sendTime = ts.start;
#endif
		if (!connection.backend->Send(m_logger, connection.backendConnection, data, sendSize, message.m_sendContext, TC("Message")))
		{
			message.m_error = 9;
			OnDisconnected(connection, 9);
			return false;
		}
	}
	if (async)
		return true;
	if (response)
	{
		u64 waitStart = GetTime();
		u32 timeoutMs = 10 * 60 * 1000;
		if (!gotResponse.IsSet(timeoutMs))
		{
			m_logger.Error(TC("Timed out after %s waiting for message response from server. (%s)"), TimeToText(GetTime() - waitStart, true).str, MessageToString(message.GetServiceId(), message.GetMessageType()).data);
			message.m_error = 4;
			OnDisconnected(connection, 4);
		}
		else if (m_cryptoKey && !message.m_error && message.m_responseSize)
		{
			TimerScope ts(m_decryptTimer);
			if (!Crypto::Decrypt(m_logger, m_cryptoKey, (u8*)message.m_response, message.m_responseSize))
			{
				message.m_error = 5;
				OnDisconnected(connection, 5);
			}
		}
	}
	return !message.m_error;
}
// Creates this client's uid and builds the logger prefix "<originalPrefix> (<uid>)".
// Called from the constructor while initializing m_logger. The returned pointer is owned by m_prefix.
const tchar* NetworkClient::SetGetPrefix(const tchar* originalPrefix)
{
	CreateGuid(m_uid);

	StringBuffer<512> prefix;
	prefix.Appendf(TC("%s (%s)"), originalPrefix, GuidToString(m_uid).str);

	m_prefix = prefix.data;
	return m_prefix.c_str();
}
// Creates a message bound to `client` and immediately writes its header into `sendWriter`, which must be empty.
// The body is written into the same writer afterwards.
NetworkMessage::NetworkMessage(NetworkClient& client, u8 serviceId, u8 messageType, BinaryWriter& sendWriter)
{
	this->Init(client, serviceId, messageType, sendWriter);
}
// A message must not be destroyed while it still owns a message id, i.e. while it is registered for a response.
// Done() clears the id.
NetworkMessage::~NetworkMessage()
{
	UBA_ASSERT(m_id == 0);
}
// Binds the message to `client`/`sendWriter` and reserves the SendHeaderSize (6) byte header at the start of the writer.
// Header layout; only byte 0 is written here, the rest is filled in by NetworkClient::Send:
//   byte 0      - service id in the top 2 bits, message type in the low 6 bits
//   byte 1      - high byte of the message id
//   bytes 2-5   - u32: low 24 bits body size, high 8 bits low byte of the message id
void NetworkMessage::Init(NetworkClient& client, u8 serviceId, u8 messageType, BinaryWriter& sendWriter)
{
	// The header must be the first thing in the writer, and both ids must fit their bit fields.
	UBA_ASSERT(sendWriter.GetPosition() == 0);
	UBA_ASSERT((serviceId & 0b11) == serviceId);
	UBA_ASSERT((messageType & 0b111111) == messageType);

	m_client = &client;
	m_sendWriter = &sendWriter;

	u8* header = sendWriter.AllocWrite(SendHeaderSize);
	header[0] = u8(serviceId << 6) | messageType;
}
// Sends the message synchronously without expecting a response. Returns whether the send succeeded.
bool NetworkMessage::Send()
{
	const bool async = false;
	return m_client->Send(*this, nullptr, 0, async);
}
// Sends the message and waits for the response, which is received into `response` starting at its current position
// (capacity = bytes left in the reader). On success the reader's size is extended by the response size.
bool NetworkMessage::Send(BinaryReader& response)
{
	const bool ok = m_client->Send(*this, (u8*)response.GetPositionData(), u32(response.GetLeft()), false);
	if (ok)
		response.SetSize(response.GetPosition() + m_responseSize);
	return ok;
}
// Same as Send(response), with the whole round trip accounted in `outTimer`.
bool NetworkMessage::Send(BinaryReader& response, Timer& outTimer)
{
	TimerScope timing(outTimer);
	return Send(response);
}
// Sends the message without waiting. The response is received into `response` at its current position, and
// `func(error, userData)` is invoked from Done() when the message completes or fails after an id was assigned.
// Call ProcessAsyncResults afterwards to decrypt the response and expose it in the reader.
bool NetworkMessage::SendAsync(BinaryReader& response, DoneFunc* func, void* userData)
{
	UBA_ASSERT(!m_doneFunc);
	m_doneUserData = userData;
	m_doneFunc = func;

	u8* responseBuffer = (u8*)response.GetPositionData();
	const u32 responseCapacity = u32(response.GetLeft());
	return m_client->Send(*this, responseBuffer, responseCapacity, true);
}
// Finishes an async send: fails if the message recorded an error, decrypts the response in place when the client
// uses a crypto key (error 10 on failure), and extends `response` by the response size.
bool NetworkMessage::ProcessAsyncResults(BinaryReader& response)
{
	if (m_error)
		return false;

	NetworkClient& client = *m_client;
	if (client.m_cryptoKey)
	{
		UBA_ASSERT(!response.GetPosition());
		TimerScope ts(client.m_decryptTimer);
		if (!Crypto::Decrypt(client.m_logger, client.m_cryptoKey, (u8*)m_response, m_responseSize))
		{
			m_error = 10;
			return false;
		}
	}

	response.SetSize(response.GetPosition() + m_responseSize);
	return true;
}
// Service id stored in the top 2 bits of header byte 0, or 0 when the message has no writer.
u8 NetworkMessage::GetServiceId()
{
	if (!m_sendWriter)
		return 0;
	return u8(m_sendWriter->GetData()[0] >> 6);
}
// Message type stored in the low 6 bits of header byte 0, or 0 when the message has no writer.
u8 NetworkMessage::GetMessageType()
{
	if (!m_sendWriter)
		return 0;
	return u8(m_sendWriter->GetData()[0] & 0b111111);
}
// Completes the message: returns its id to the client's free list, clears its m_activeMessages slot and, only for
// the call that actually released the id, invokes m_doneFunc(hasError, m_doneUserData).
// `shouldLock` selects whether m_activeMessagesLock is taken here. Pass false when the caller already holds it
// (OnDisconnected, Send). Calls after the id has been returned (m_id == 0) do nothing.
void NetworkMessage::Done(bool shouldLock)
{
	bool hasId = false;
	auto returnId = [&]()
		{
			if (m_id)
			{
				m_client->m_availableMessageIds.push_back(m_id);
				m_client->m_activeMessages[m_id] = nullptr;
				m_id = 0;
				hasId = true;
			}
		};

	if (shouldLock)
	{
		SCOPED_WRITE_LOCK(m_client->m_activeMessagesLock, lock);
		returnId();
	}
	else
	{
		returnId();
	}

	// SendAsync accepts a null DoneFunc*, so guard against calling through a null function pointer.
	if (hasId && m_doneFunc)
		m_doneFunc(m_error != 0, m_doneUserData);
}
}