Files
UnrealEngine/Engine/Source/Programs/UnrealBuildAccelerator/Common/Private/UbaNetworkServer.cpp
2025-05-18 13:04:45 +08:00

1435 lines
40 KiB
C++

// Copyright Epic Games, Inc. All Rights Reserved.
#include "UbaNetworkServer.h"
#include "UbaConfig.h"
#include "UbaCrypto.h"
#include "UbaBinaryReaderWriter.h"
#include "UbaPlatform.h"
namespace uba
{
// Applies settings from a config table onto this create-info.
// NOTE(review): the body is empty in this revision, so neither `config` nor `tableName`
// is read and no member is modified - confirm this is intentional.
void NetworkServerCreateInfo::Apply(Config& config, const tchar* tableName)
{
}
// Scratch state handed to a worker for one unit of work: the event used to wake the
// worker, the receive buffer, the response write buffer and the header fields of the
// message being processed.
struct NetworkServer::WorkerContext
{
	// Allocates a response buffer of server.m_sendSize bytes.
	WorkerContext(NetworkServer& s) : server(s), workAvailable(false)
	{
		writeMemSize = server.m_sendSize;
		writeMem = new u8[writeMemSize];
	}

	~WorkerContext()
	{
		delete[] writeMem;
	}

	// writeMem is owned through a raw pointer; a copy would end in a double delete[].
	WorkerContext(const WorkerContext&) = delete;
	void operator=(const WorkerContext&) = delete;

	NetworkServer& server;
	Event workAvailable;             // Set when the worker has something to do (or should re-check its loop flag)
	u8* writeMem = nullptr;          // Response buffer, writeMemSize bytes, owned by this context
	u32 writeMemSize = 0;
	Vector<u8> buffer;               // Receive buffer for the message body; grown on demand
	Connection* connection = nullptr; // Connection the pending message belongs to; null means "additional work only"
	u32 dataSize = 0;                // Size in bytes of the received message body
	u8 serviceId = 0;
	u8 messageType = 0;
	u16 id = 0;                      // Message id taken from the message header
};
// One pooled worker thread plus its links (m_nextWorker/m_prevWorker) in the server's
// intrusive available/active worker lists.
class NetworkServer::Worker
{
public:
Worker() {}
// Clears m_loop and signals the context event so ThreadWorker's loop exits, waits for
// the thread, then frees the context. The worker must not be marked in use.
~Worker()
{
UBA_ASSERT(!m_inUse);
m_context->connection = nullptr;
m_loop = false;
m_context->workAvailable.Set();
m_thread.Wait();
delete m_context;
m_context = nullptr;
}
// Allocates the context and starts the thread running ThreadWorker.
void Start(NetworkServer& server)
{
m_context = new WorkerContext(server);
m_loop = true;
m_thread.Start([&]() { ThreadWorker(server); return 0; }, TC("UbaWrkNetwSrv"));
}
// Asks the thread to exit. While the worker is still marked in use, repeatedly signals it
// and waits on the thread with a timeout of 5 (presumably milliseconds - confirm against
// Thread::Wait). The available-workers lock is released around each wait.
void Stop(NetworkServer& server)
{
m_loop = false;
SCOPED_FUTEX(server.m_availableWorkersLock, lock);
while (m_inUse)
{
m_context->workAvailable.Set();
lock.Leave();
if (m_thread.Wait(5))
break;
lock.Enter();
}
}
void ThreadWorker(NetworkServer& server);
void Update(WorkerContext& context);
void DoAdditionalWorkAndSignalAvailable(NetworkServer& server);
// Intrusive list links; the lists are rooted in NetworkServer (m_firstAvailableWorker / m_firstActiveWorker).
Worker* m_nextWorker = nullptr;
Worker* m_prevWorker = nullptr;
// Context currently used by this worker. DoAdditionalWork temporarily swaps in a stack-allocated one.
WorkerContext* m_context = nullptr;
Atomic<bool> m_loop;  // Cleared to make ThreadWorker leave its loop
Atomic<bool> m_inUse; // True while on the active list, false while on the available list
Thread m_thread;
Worker(const Worker&) = delete;
};
// Worker owning the current thread. Set by Worker::ThreadWorker for the duration of its loop
// and reset to nullptr when the loop ends; read by NetworkServer::DoAdditionalWork.
thread_local NetworkServer::Worker* t_worker;
// Server-side state for one backend connection.
//
// Drives the receive state machine by swapping the backend's receive callbacks:
//   [crypto handshake] -> version -> client uid -> message header/body (repeating).
// Also tracks how many workers are still using the connection (m_activeWorkerCount) so the
// backend shutdown can be deferred until the last one releases it.
class NetworkServer::Connection
{
public:
	// Registers disconnect/data-sent/timeout callbacks on the backend connection and arms the
	// first receive step (handshake when requiresCrypto, otherwise the 4-byte version).
	// A valid cryptoKey is owned by the connection and destroyed in the destructor.
	Connection(NetworkServer& server, NetworkBackend& backend, void* backendConnection, const sockaddr& remoteSockAddr, bool requiresCrypto, CryptoKey cryptoKey, u32 id)
	:	m_server(server)
	,	m_backend(backend)
	,	m_remoteSockAddr(remoteSockAddr)
	,	m_cryptoKey(cryptoKey)
	,	m_disconnectCallbackCalled(true)
	,	m_id(id)
	,	m_backendConnection(backendConnection)
	{
		// The connection itself holds one reference; it is dropped in Disconnect().
		m_activeWorkerCount = 1;

		m_backend.SetDisconnectCallback(m_backendConnection, this, [](void* context, const Guid& connectionUid, void* connection)
			{
				auto& conn = *(Connection*)context;
				conn.Disconnect(TC("Backend"));
				conn.m_disconnectCallbackCalled.Set();
			});

		m_backend.SetDataSentCallback(m_backendConnection, this, [](void* context, u32 bytes)
			{
				auto& conn = *(Connection*)context;
				if (auto c = conn.m_client)
					c->recvBytes += bytes; // Bytes sent by the server are accounted as the client's received bytes
				conn.m_server.m_sendBytes += bytes;
			});

		m_backend.SetRecvTimeout(m_backendConnection, m_server.m_receiveTimeoutMs, this, [](void* context, u32 timeoutMs, const tchar* recvHint, const tchar* hint)
			{
				auto& conn = *(Connection*)context;
				u32 clientId = ~0u;
				if (auto c = conn.m_client)
					clientId = c->id;
				conn.m_server.m_logger.Warning(TC("Connection %u (Client %u) timed out after %u seconds (%s%s)"), conn.m_id, clientId, timeoutMs/1000, recvHint, hint);
				return false;
			});

		if (requiresCrypto)
			m_backend.SetRecvCallbacks(m_backendConnection, this, 1, ReceiveHandshakeHeader, ReceiveHandshakeBody, TC("ReceiveHandshake"));
		else
			m_backend.SetRecvCallbacks(m_backendConnection, this, 4, ReceiveVersion, nullptr, TC("ReceiveVersion"));
	}

	~Connection()
	{
		Stop();
		if (m_cryptoKey)
			Crypto::DestroyKey(m_cryptoKey);
	}

	// Marks the connection for disconnect and drops the connection's own reference.
	// Only the first call has any effect. If workers are still active the actual backend
	// shutdown happens when the last one calls Release().
	void Disconnect(const tchar* reason)
	{
		if (m_disconnectCalled.fetch_add(1) != 0)
			return;
		SetShouldDisconnect();
		int activeWorkerCount = --m_activeWorkerCount;
		if (!activeWorkerCount) // Will disconnect in send if there are active workers
			TestDisconnect();
		//else
		//	m_server.m_logger.Detail(TC("Connection %u disconnected while it has %u active workers (%s)"), m_id, activeWorkerCount, reason);
	}

	// Disconnects and waits for all active workers to release the connection.
	// Returns false (after logging and dumping callstacks) if workers are still active after 3 seconds.
	bool Stop()
	{
		Disconnect(TC("Stop"));

		u64 startTimer = GetTime();
		while (m_activeWorkerCount)
		{
			if (TimeToMs(GetTime() - startTimer) > 3000)
			{
				m_server.m_logger.Error(TC("Connection %u has waited 3 seconds to stop... something is stuck (Active worker count: %u)"), m_id, m_activeWorkerCount.load());
				PrintAllCallstacks(m_server.m_logger);
				return false;
			}
			Sleep(1);
		}

		if (!m_disconnectCallbackCalled.IsSet(30*1000)) // This should never time out!
		{
			m_server.m_logger.Warning(TC("Disconnect callback event timed out. This should never happen!!"));
			PrintAllCallstacks(m_server.m_logger);
		}
		return true;
	}

	// Sends the one-byte status `value` followed by the server uid.
	// In this class 0 is sent on acceptance, 1 on version mismatch and 3 when new clients are not allowed.
	bool SendInitialResponse(u8 value)
	{
		u8 data[32];
		*data = value;
		*(Guid*)(data+1) = m_server.m_uid;
		NetworkBackend::SendContext context(NetworkBackend::SendFlags_Async);
		return m_backend.Send(m_server.m_logger, m_backendConnection, data, 1 + sizeof(Guid), context, TC("UidResponse"));
	}

	// Header callback for the crypto handshake: allocates the buffer the encrypted handshake
	// string is received into. The buffer is freed by ReceiveHandshakeBody.
	static bool ReceiveHandshakeHeader(void* context, const Guid& connectionUid, u8* headerData, void*& outBodyContext, u8*& outBodyData, u32& outBodySize)
	{
		u8* handshakeData = new u8[sizeof(EncryptionHandshakeString)];
		outBodyData = handshakeData;
		outBodySize = sizeof(EncryptionHandshakeString);
		return true;
	}

	// Body callback for the crypto handshake. Verifies that the received bytes decrypt to
	// EncryptionHandshakeString, either with the connection's preset key or with one of the
	// server's registered (non-expired) keys, then moves on to the version step.
	static bool ReceiveHandshakeBody(void* context, bool recvError, u8* headerData, void* bodyContext, u8* bodyData, u32 bodySize)
	{
		auto& conn = *(Connection*)context;
		u8* handshakeData = bodyData;
		auto g = MakeGuard([handshakeData]() { delete[] handshakeData; });
		auto& logger = conn.m_server.m_logger;

		if (bodySize != sizeof(EncryptionHandshakeString))
			return logger.Warning(TC("Connection %u Crypto mismatch... (body size was %u, expected %u)"), conn.m_id, bodySize, u32(sizeof(EncryptionHandshakeString))); // Cast: sizeof is size_t, format is %u

		// Decrypts a copy so the received data can be tested against several keys.
		auto TestHandshake = [&](CryptoKey key)
			{
				u8 temp[sizeof(EncryptionHandshakeString)];
				memcpy(temp, handshakeData, sizeof(temp));
				if (!Crypto::Decrypt(logger, key, temp, sizeof(EncryptionHandshakeString)))
					return false;
				return memcmp(temp, EncryptionHandshakeString, sizeof(EncryptionHandshakeString)) == 0;
			};

		if (conn.m_cryptoKey != InvalidCryptoKey)
		{
			if (!TestHandshake(conn.m_cryptoKey))
				return logger.Warning(TC("Connection %u Crypto mismatch... (Handshake string is encrypted with different key)"), conn.m_id);
		}
		else
		{
			SCOPED_FUTEX(conn.m_server.m_cryptoKeysLock, lock);
			auto& keys = conn.m_server.m_cryptoKeys;
			u64 time = GetTime();
			for (auto it=keys.begin(); it!=keys.end();)
			{
				auto& entry = *it;
				if (entry.expirationTime < time)
				{
					// The list owns the registered keys (see ~NetworkServer); destroy the key
					// before dropping the entry, otherwise it is leaked.
					Crypto::DestroyKey(entry.key);
					it = keys.erase(it);
					continue;
				}
				++it;

				// Each connection gets its own duplicate of the matching key.
				CryptoKey key = Crypto::DuplicateKey(logger, entry.key);
				auto keyGuard = MakeGuard([&]() { Crypto::DestroyKey(key); });
				if (!TestHandshake(key))
					continue;
				keyGuard.Cancel();
				conn.m_cryptoKey = key;
				break;
			}
			if (conn.m_cryptoKey == InvalidCryptoKey)
				return logger.Warning(TC("Connection %u Crypto mismatch... (Handshake string is encrypted with different key than any registered keys)"), conn.m_id);
		}

		conn.m_backend.SetRecvCallbacks(conn.m_backendConnection, &conn, 4, ReceiveVersion, nullptr, TC("ReceiveVersion"));
		return true;
	}

	// Receives the 4-byte client protocol version. On mismatch responds with status 1 and fails.
	static bool ReceiveVersion(void* context, const Guid& connectionUid, u8* headerData, void*& outBodyContext, u8*& outBodyData, u32& outBodySize)
	{
		auto& conn = *(Connection*)context;
		u32 clientVersion = *(u32*)headerData;
		if (clientVersion != SystemNetworkVersion)
		{
			conn.SendInitialResponse(1);
			return false;
		}
		conn.m_backend.SetRecvCallbacks(conn.m_backendConnection, &conn, sizeof(Guid), ReceiveClientUid, nullptr, TC("ReceiveClientUid"));
		return true;
	}

	// Receive-timeout callback used when keep-alive is enabled. Each timeout sends a keep-alive
	// probe and re-arms a shorter timeout; the connection is reported as timed out once
	// KeepAliveProbeCount consecutive timeouts have happened without any message in between.
	static bool RecvTimeout(void* context, u32 timeoutMs, const tchar* recvHint, const tchar* hint)
	{
		auto& conn = *(Connection*)context;
		++conn.m_recvTimeoutCount;
		conn.SendKeepAlive();
		conn.m_backend.SetRecvTimeout(conn.m_backendConnection, KeepAliveIntervalSeconds*1000, context, RecvTimeout);
		if (conn.m_recvTimeoutCount < KeepAliveProbeCount)
			return true;

		constexpr u32 totalTimeoutSeconds = KeepAliveIdleSeconds + KeepAliveIntervalSeconds*KeepAliveProbeCount;
		u32 clientId = ~0u;
		if (auto c = conn.m_client)
			clientId = c->id;
		conn.m_server.m_logger.Warning(TC("Connection %u (Client %u) timed out after %u seconds (%s%s)"), conn.m_id, clientId, totalTimeoutSeconds, recvHint, hint);
		return false;
	}

	// Receives the client uid, rejects unknown clients when new clients are disallowed (status 3),
	// switches the connection to message mode, acknowledges with status 0 and binds the
	// connection to its (possibly newly created) Client entry.
	static bool ReceiveClientUid(void* context, const Guid& connectionUid, u8* headerData, void*& outBodyContext, u8*& outBodyData, u32& outBodySize)
	{
		auto& conn = *(Connection*)context;
		auto& server = conn.m_server;
		Guid clientUid = *(Guid*)headerData;

		if (!server.m_allowNewClients)
		{
			SCOPED_READ_LOCK(server.m_clientsLock, clientsLock);
			bool found = false;
			for (auto& kv : server.m_clients)
				found |= kv.second.uid == clientUid;
			if (!found)
			{
				conn.SendInitialResponse(3);
				return false;
			}
		}

		constexpr u32 HeaderSize = 6;
		conn.m_backend.SetRecvCallbacks(conn.m_backendConnection, &conn, HeaderSize, ReceiveMessageHeader, ReceiveMessageBody, TC("ReceiveMessage"));

		// With keep-alive enabled the receive timeout becomes the keep-alive idle time
		if (server.m_useKeepAlive)
			conn.m_backend.SetRecvTimeout(conn.m_backendConnection, KeepAliveIdleSeconds*1000, context, RecvTimeout);

		if (!conn.SendInitialResponse(0))
			return false;

		SCOPED_FUTEX(conn.m_shutdownLock, shutdownLock);
		SCOPED_WRITE_LOCK(server.m_clientsLock, clientsLock);

		// Reuse the id of an existing client with the same uid, otherwise hand out a new id.
		u32 clientId = 0;
		for (auto& kv : server.m_clients)
			if (kv.second.uid == clientUid)
				clientId = kv.second.id;
		if (!clientId)
			clientId = ++server.m_clientCounter;
		Client& client = server.m_clients.try_emplace(clientId, clientUid, clientId).first->second;
		++client.refCount;
		clientsLock.Leave();

		conn.m_client = &client;

		if (client.connectionCount.fetch_add(1) == 0)
		{
			// First connection of this client
			if (server.m_onConnectionFunction)
				server.m_onConnectionFunction(clientUid, clientId);
			if (server.m_logConnections)
				server.m_logger.Detail(TC("Client %u (%s) connected on connection %s"), clientId, GuidToString(clientUid).str, GuidToString(connectionUid).str);
		}
		else
		{
			if (server.m_logConnections)
				server.m_logger.Detail(TC("Client %u (%s) additional connection %s connected"), clientId, GuidToString(clientUid).str, GuidToString(connectionUid).str);
		}
		return true;
	}

	// Parses the 6-byte message header, validates it, and reserves a worker whose buffer the
	// message body is received into.
	// Header layout as decoded below:
	//   byte 0        : service id (top 2 bits) | message type (low 6 bits)
	//   byte 1        : high byte of the message id
	//   u32 at byte 2 : low 24 bits = body size, high 8 bits = low byte of the message id
	static bool ReceiveMessageHeader(void* context, const Guid& connectionUid, u8* headerData, void*& outBodyContext, u8*& outBodyData, u32& outBodySize)
	{
		auto& conn = *(Connection*)context;

		u8 serviceIdAndMessageType = headerData[0];
		u8 serviceId = serviceIdAndMessageType >> 6;
		u8 messageType = serviceIdAndMessageType & 0b111111;
		u16 messageId = u16(headerData[1] << 8) | u16((*(u32*)(headerData + 2) & 0xff000000) >> 24);
		u32 messageSize = *(u32*)(headerData + 2) & 0x00ffffff;

		if (messageSize > SendMaxSize)
			return conn.m_server.m_logger.Error(TC("Client %u Got message size %u which is larger than max %u. Protocol error? (serviceId %u, messageType %u, messageId %u)"), conn.m_client->id, messageSize, SendMaxSize, u32(serviceId), u32(messageType), u32(messageId));

		// Compare against the number of entries, not the byte size of the array.
		if (serviceId >= sizeof_array(conn.m_server.m_workerFunctions))
			return conn.m_server.m_logger.Error(TC("Client %u Got message with service id %u which is out of range. Protocol error?"), conn.m_client->id, serviceId);

		// Any incoming message resets the keep-alive probing back to the idle timeout.
		if (conn.m_recvTimeoutCount)
		{
			conn.m_recvTimeoutCount = 0;
			conn.m_backend.SetRecvTimeout(conn.m_backendConnection, KeepAliveIdleSeconds*1000, context, RecvTimeout);
		}

		if (serviceId == SystemServiceId && messageType == SystemMessageType_KeepAlive) // Keep alive
			return true;

		LOG_STALL_SCOPE(conn.m_server.m_logger, 5, TC("PopWorker took more than %s"));

		//m_logger.Debug(TC("Recv: %u, %u, %u, %u"), serviceId, messageType, id, size);

		Worker* worker = conn.m_server.PopWorker();
		if (!worker)
			return false;
		if (!worker->m_context)
			return conn.m_server.m_logger.Error(TC("Client %u - Popped worker which has no context"), conn.m_client->id);

		auto& wc = *worker->m_context;
		wc.id = messageId;
		wc.serviceId = serviceId;
		wc.messageType = messageType;
		wc.dataSize = messageSize;
		wc.connection = &conn;
		if (wc.buffer.size() < messageSize)
			wc.buffer.resize(size_t(Min(messageSize + 1024u, SendMaxSize)));

		outBodyContext = worker;
		outBodyData = wc.buffer.data();
		outBodySize = messageSize;
		return true;
	}

	// Called when the message body has been received (or receiving failed). On success the
	// reserved worker is woken to process the message; on error the worker is returned to the pool.
	static bool ReceiveMessageBody(void* context, bool recvError, u8* headerData, void* bodyContext, u8* bodyData, u32 bodySize)
	{
		auto& conn = *(Connection*)context;
		auto worker = (Worker*)bodyContext;
		if (recvError)
		{
			conn.m_server.PushWorker(worker);
			return false;
		}

		auto& wc = *worker->m_context;
		conn.m_client->sendBytes += wc.dataSize; // Bytes received by the server are accounted as the client's sent bytes
		conn.m_server.m_recvBytes += wc.dataSize;
		++conn.m_server.m_recvCount;

		// The worker holds a reference on the connection until it calls Release().
		++conn.m_activeWorkerCount;
		wc.workAvailable.Set();
		return true;
	}

	// Sends data on the backend connection and flags the connection for disconnect on failure.
	void Send(const void* data, u32 bytes, const tchar* sendHint)
	{
		TimerScope ts(m_sendTimer);
		NetworkBackend::SendContext context;
		if (!m_backend.Send(m_server.m_logger, m_backendConnection, data, bytes, context, sendHint))
			SetShouldDisconnect();
	}

	// Flags the connection for disconnect. Returns true if this call was the one that set the flag.
	bool SetShouldDisconnect()
	{
		SCOPED_FUTEX(m_shutdownLock, lock);
		bool isConnected = !m_shouldDisconnect;
		m_shouldDisconnect = true;
		return isConnected;
	}

	// Drops one worker reference; the last reference triggers the deferred disconnect check.
	void Release()
	{
		if (--m_activeWorkerCount == 0)
			TestDisconnect();
	}

	// Shuts down the backend connection if a disconnect has been requested and not yet performed.
	// When this was the client's last connection the registered disconnect functions are invoked.
	void TestDisconnect()
	{
		SCOPED_FUTEX(m_shutdownLock, lock);
		if (!m_shouldDisconnect)
			return;
		if (m_disconnected)
			return;
		lock.Leave();

		m_backend.Shutdown(m_backendConnection);

		if (m_client && m_client->connectionCount.fetch_sub(1) == 1)
		{
			SCOPED_READ_LOCK(m_server.m_onDisconnectFunctionsLock, l);
			for (auto& entry : m_server.m_onDisconnectFunctions)
				entry.function(m_client->uid, m_client->id);
			if (m_server.m_logConnections)
				m_server.m_logger.Detail(TC("Client %u (%s) disconnected"), m_client->id, GuidToString(m_client->uid).str);
		}
		m_disconnected = true;
	}

	// Sends a 5-byte response-style header with message id 0 and body size MessageKeepAliveSize.
	bool SendKeepAlive()
	{
		NetworkBackend::SendContext sendContext;
		constexpr u32 HeaderSize = 5;
		u16 messageId = 0;
		u32 bodySize = MessageKeepAliveSize;
		u8 data[5];
		data[0] = messageId >> 8;
		*(u32*)(data + 1) = bodySize | u32(messageId << 24);
		return m_backend.Send(m_server.m_logger, m_backendConnection, data, HeaderSize, sendContext, TC("KeepAlive"));
	}

	NetworkServer& m_server;
	NetworkBackend& m_backend;
	Futex m_shutdownLock;               // Guards m_shouldDisconnect
	Client* m_client = nullptr;         // Set once the client uid has been received
	sockaddr m_remoteSockAddr;
	CryptoKey m_cryptoKey;              // Owned; destroyed in the destructor when valid
	Event m_disconnectCallbackCalled;   // Set by the backend disconnect callback
	Atomic<int> m_activeWorkerCount;    // 1 for the connection itself + 1 per worker processing a message
	Atomic<int> m_disconnectCalled;     // Non-zero once Disconnect() has been entered
	Atomic<bool> m_disconnected;        // True once the backend connection has been shut down
	u32 m_id = 0;
	u32 m_recvTimeoutCount = 0;         // Consecutive keep-alive timeouts without any received message
	bool m_shouldDisconnect = false;
	void* m_backendConnection = nullptr;

	Timer m_sendTimer;
	Timer m_encryptTimer;
	Timer m_decryptTimer;

	Connection(const Connection& o) = delete;
	void operator=(const Connection& o) = delete;
};
// Returns the uid of the client bound to the wrapped connection.
// Dereferences m_client without a null check, exactly like before.
const Guid& ConnectionInfo::GetUid() const
{
	auto* connection = (NetworkServer::Connection*)internalData;
	return connection->m_client->uid;
}
// Returns the id of the client bound to the wrapped connection.
// Dereferences m_client without a null check, exactly like before.
u32 ConnectionInfo::GetId() const
{
	auto* connection = (NetworkServer::Connection*)internalData;
	return connection->m_client->id;
}
// Writes the remote IPv4 address of the connection as text into `out`.
// Returns false if the conversion fails. Only implemented on Windows; asserts elsewhere.
bool ConnectionInfo::GetName(StringBufferBase& out) const
{
	#if PLATFORM_WINDOWS
	auto& remoteSockAddr = ((NetworkServer::Connection*)internalData)->m_remoteSockAddr;
	// InetNtopW wants a pointer to the IN_ADDR itself. Passing the enclosing sockaddr would
	// format the family/port bytes instead of the address.
	auto& remoteAddrIn = reinterpret_cast<sockaddr_in&>(remoteSockAddr);
	if (!InetNtopW(AF_INET, &remoteAddrIn.sin_addr, out.data, out.capacity))
		return false;
	out.count = u32(wcslen(out.data));
	return true;
	#else
	UBA_ASSERT(false);
	return false;
	#endif
}
// Reports whether the wrapped connection has been flagged for disconnect.
// The flag is read under the connection's shutdown lock.
bool ConnectionInfo::ShouldDisconnect() const
{
	auto* connection = (NetworkServer::Connection*)internalData;
	SCOPED_FUTEX(connection->m_shutdownLock, lock);
	const bool flagged = connection->m_shouldDisconnect;
	return flagged;
}
// Processes the message currently stored in `context`, if any: decrypts the body, dispatches
// it to the registered service function, encrypts and sends the response (only when the
// message id is non-zero) and finally releases the worker's reference on the connection.
// A null context.connection means the worker was woken for additional work only.
void NetworkServer::Worker::Update(WorkerContext& context)
{
auto& server = context.server;
// This is only additional work
if (!context.connection)
return;
auto& connection = *context.connection;
// Clear the pointer so the same message is not processed again on the next wake-up
context.connection = nullptr;
WorkerRec& rec = server.m_workerFunctions[context.serviceId];
TrackWorkScope tws(server, rec.toString(context.messageType), ColorWork);
CryptoKey cryptoKey = connection.m_cryptoKey;
if (cryptoKey)
{
//TrackHintScope ths(tws, TCV("Decrypt"));
TimerScope ts(connection.m_decryptTimer);
// Body is decrypted in place; failure drops the connection without sending a response
if (!Crypto::Decrypt(server.m_logger, cryptoKey, context.buffer.data(), context.dataSize))
{
connection.SetShouldDisconnect();
connection.Release();
return;
}
}
BinaryReader reader(context.buffer.data(), 0, context.dataSize);
constexpr u32 HeaderSize = 5; // 2 byte id, 3 bytes size
BinaryWriter writer(context.writeMem, 0, context.writeMemSize);
// Reserve room for the response header; it is filled in after the body has been written
u8* idAndSizePtr = writer.AllocWrite(HeaderSize);
u32 size;
MessageInfo mi;
mi.type = context.messageType;
mi.connectionId = connection.m_id;
mi.messageId = context.id;
{
//TrackHintScope ths(tws, TCV("HandleMessage"));
if (!rec.func)
{
server.m_logger.Error(TC("WORKER FUNCTION NOT FOUND. id: %u, serviceid: %u type: %s, client: %u"), context.id, context.serviceId, rec.toString(context.messageType).data, connection.m_client->id);
connection.SetShouldDisconnect();
size = MessageErrorSize;
}
else if (!rec.func({&connection}, {tws}, mi, reader, writer))
{
// Only log if this failure is what flagged the connection for disconnect
if (connection.SetShouldDisconnect())
{
#if UBA_DEBUG
server.m_logger.Error(TC("WORKER FUNCTION FAILED. id: %u, serviceid: %u type: %s, client: %u"), context.id, context.serviceId, rec.toString(context.messageType).data, connection.m_client->id);
#endif
}
size = MessageErrorSize;
}
else
{
size = u32(writer.GetPosition());
}
}
// The handler receives `mi` by reference; a zero message id at this point means no response is sent here
if (mi.messageId)
{
UBA_ASSERT(size < (1 << 24));
u32 bodySize = u32(size - HeaderSize);
if (cryptoKey && size != MessageErrorSize && bodySize)
{
//TrackHintScope ths(tws, TCV("Encrypt"));
TimerScope ts(connection.m_encryptTimer);
u8* bodyData = writer.GetData() + HeaderSize;
if (!Crypto::Encrypt(server.m_logger, cryptoKey, bodyData, bodySize))
{
connection.SetShouldDisconnect();
size = MessageErrorSize;
bodySize = u32(size - HeaderSize);
}
}
// Response header: byte 0 = high byte of id, u32 at byte 1 = body size (low 24 bits) | low byte of id (high 8 bits)
idAndSizePtr[0] = context.id >> 8;
*(u32*)(idAndSizePtr + 1) = bodySize | u32(context.id << 24);
// This can happen for proxy servers in a valid situation
//if (size == MessageErrorSize)
// UBA_ASSERT(false);
//TrackHintScope ths(tws, TCV("Send"));
// On error only the header is sent
connection.Send(writer.GetData(), size == MessageErrorSize ? HeaderSize : size, TC("MessageResponse"));
}
connection.Release();
}
// Thread entry point of a worker. Publishes itself through t_worker, then loops: wait for
// the context event, process the pending message (if any), drain additional work and put
// itself back on the available list. The loop ends when m_loop is cleared.
void NetworkServer::Worker::ThreadWorker(NetworkServer& server)
{
ElevateCurrentThreadPriority();
t_worker = this;
// Waits without timeout; m_loop is checked after every wake-up
while (m_context->workAvailable.IsSet(~0u) && m_loop)
{
Update(*m_context);
DoAdditionalWorkAndSignalAvailable(m_context->server);
}
t_worker = nullptr;
if (m_inUse) // I have no idea how this can happen.. should not be possible. There is a path somewhere where it can leave while still being in use
server.PushWorker(this);
}
// Runs queued additional work until the queue is empty and then returns this worker to the
// available list. The final emptiness check and the push are done while holding both the
// available-workers lock and the additional-work lock.
void NetworkServer::Worker::DoAdditionalWorkAndSignalAvailable(NetworkServer& server)
{
while (true)
{
// Drain the additional work queue, running each item outside the lock
while (true)
{
AdditionalWork work;
SCOPED_FUTEX(server.m_additionalWorkLock, lock);
if (server.m_additionalWork.empty())
break;
work = std::move(server.m_additionalWork.front());
server.m_additionalWork.pop_front();
lock.Leave();
#if UBA_TRACK_WORK
TrackWorkScope tws(server, work.desc, ColorWork);
#else
TrackWorkScope tws;
#endif
work.func({tws});
}
// Both locks need to be taken to verify whether additional work
// is present before making ourselves available, to avoid
// a race where AddWork would not see this thread in the
// available list after adding some work.
SCOPED_FUTEX(server.m_availableWorkersLock, lock1);
SCOPED_FUTEX_READ(server.m_additionalWorkLock, lock2);
// Verify there is no additional work while we hold both locks
// and only add ourselves as available if no additional work is present.
if (!server.m_additionalWork.empty())
continue;
server.PushWorkerNoLock(this);
break;
}
}
// Decimal names for message types 0..12, indexed by type.
const tchar* g_typeStr[] = { TC("0"), TC("1"), TC("2"), TC("3"), TC("4"), TC("5"), TC("6"), TC("7"), TC("8"), TC("9"), TC("10"), TC("11"), TC("12") };

// Default type-to-name function: returns the decimal string for types up to 12 and a
// fixed placeholder text for anything larger.
static StringView GetMessageTypeToName(u8 type)
{
	if (type > 12)
		return ToView(TC("NUMBER HIGHER THAN 12"));
	return ToView(g_typeStr[type]);
}
// Creates the server from `info`: resolves the worker count and send size, copies the
// timeout/logging/keep-alive settings, installs the system service handler and creates the
// server uid. `outCtorSuccess` is set to false only if the uid could not be created.
NetworkServer::NetworkServer(bool& outCtorSuccess, const NetworkServerCreateInfo& info, const tchar* name)
:	m_logger(info.logWriter, name)
{
	outCtorSuccess = true;

	// Zero means one worker per logical processor; anything else is clamped to [1, 1024].
	const u32 requestedWorkers = info.workerCount;
	m_maxWorkerCount = requestedWorkers == 0 ? u32(GetLogicalProcessorCount()) : Min(Max(requestedWorkers, (u32)(1u)), (u32)(1024u));

	#if UBA_DEBUG
	m_logger.Info(TC("Created in DEBUG"));
	#endif

	// Keep the send size within [4 KiB, SendMaxSize] and tell the user if it was changed.
	const u32 clampedSendSize = Min(Max(info.sendSize, (u32)(4*1024)), (u32)(SendMaxSize));
	if (clampedSendSize != info.sendSize)
		m_logger.Detail(TC("Adjusted msg size to %u to stay inside limits"), clampedSendSize);
	m_sendSize = clampedSendSize;

	m_receiveTimeoutMs = info.receiveTimeoutSeconds * 1000;
	m_logConnections = info.logConnections;

	#if PLATFORM_MAC
	m_useKeepAlive = true; // Always run keep alive on mac since the built-in one has a probe interval of 1 minute or something like that... so timeout is always 10 minutes
	#else
	m_useKeepAlive = info.useKeepAlive;
	#endif

	// Service 0 is the system service and is always handled by the server itself.
	WorkerRec& systemRec = m_workerFunctions[SystemServiceId];
	systemRec.toString = GetMessageTypeToName;
	systemRec.func = [this](const ConnectionInfo& connectionInfo, const WorkContext& workContext, MessageInfo& messageInfo, BinaryReader& reader, BinaryWriter& writer)
		{
			return HandleSystemMessage(connectionInfo, messageInfo.type, reader, writer);
		};

	if (!CreateGuid(m_uid))
		outCtorSuccess = false;
}
// Expects all connections to be gone already; stops and deletes the workers and destroys
// every crypto key that is still registered.
NetworkServer::~NetworkServer()
{
	UBA_ASSERT(m_connections.empty());

	FlushWorkers();

	for (auto it = m_cryptoKeys.begin(), end = m_cryptoKeys.end(); it != end; ++it)
		Crypto::DestroyKey(it->key);
}
// Starts listening on ip:port through `backend`. Every accepted backend connection is handed
// to AddConnection without a preset crypto key.
bool NetworkServer::StartListen(NetworkBackend& backend, u16 port, const tchar* ip, bool requiresCrypto)
{
	auto onConnected = [this, &backend, requiresCrypto](void* connection, const sockaddr& remoteSockAddr)
		{
			return AddConnection(backend, connection, remoteSockAddr, requiresCrypto, InvalidCryptoKey);
		};
	return backend.StartListen(m_logger, port, ip, onConnected);
}
// From now on connections presenting a client uid that is not already known are rejected
// (see the m_allowNewClients check when the client uid is received).
void NetworkServer::DisallowNewClients()
{
	m_allowNewClients = false;
}
// Stops all connections and workers and resets the server to a state without connections.
// Order: disable workers and release pending PopWorker waiters, drop pending outgoing
// connection attempts, stop every connection (aborting the process if one cannot be
// stopped), flush the workers, then clear the connection list and re-enable workers.
void NetworkServer::DisconnectClients()
{
{
SCOPED_FUTEX(m_availableWorkersLock, lock);
m_workersEnabled = false;
// Wake every thread waiting in PopWorker; no worker is assigned, so they see workers are disabled
while (PopWorkerRequest* req = m_firstRequest)
{
m_firstRequest = req->next;
req->next = nullptr;
req->ev.Set();
}
m_lastRequest = nullptr;
}
{
SCOPED_FUTEX(m_addConnectionsLock, lock);
m_addConnections.clear();
}
{
SCOPED_WRITE_LOCK(m_connectionsLock, lock);
bool success = true;
for (auto& c : m_connections)
{
success = c.Stop() && success;
// Fold the per-connection timers into the server totals before the connections go away
m_sendTimer += c.m_sendTimer;
m_encryptTimer += c.m_encryptTimer;
m_decryptTimer += c.m_decryptTimer;
}
lock.Leave();
// If stopping connections fail we need to abort because we will most likely run into a deadlock when deleting the workers.
if (!success)
{
m_logger.Info(TC("Failed to stop connection(s) in a graceful way. Will abort process"));
abort(); // TODO: Does this produce core dump on windows?
}
}
FlushWorkers();
SCOPED_WRITE_LOCK(m_connectionsLock, lock);
m_connections.clear();
m_allClientsDisconnected = true;
m_workersEnabled = true;
}
bool NetworkServer::RegisterCryptoKey(const u8* cryptoKey128, u64 expirationTime)
{
CryptoKey key = Crypto::CreateKey(m_logger, cryptoKey128);
if (key == InvalidCryptoKey)
return false;
SCOPED_FUTEX(m_cryptoKeysLock, lock);
m_cryptoKeys.push_back(CryptoEntry{key, expirationTime});
return true;
}
// Serializes `config` as text into m_clientsConfig.
void NetworkServer::SetClientsConfig(const Config& config)
{
	config.SaveToText(m_logger, m_clientsConfig);
}
// Actively connects to a client at ip:port through `backend`. The connect runs on a thread
// stored in m_addConnections and this function waits for it to finish before returning.
// When cryptoKey128 is provided the connection requires crypto with a key created from it.
// Returns false if workers are disabled, key creation fails, or the connect fails.
bool NetworkServer::AddClient(NetworkBackend& backend, const tchar* ip, u16 port, const u8* cryptoKey128)
{
SCOPED_FUTEX(m_addConnectionsLock, lock);
if (!m_workersEnabled)
return false;
// Drop entries of connect attempts that have already finished
for (auto it = m_addConnections.begin(); it != m_addConnections.end();)
{
if (it->Wait(0))
it = m_addConnections.erase(it);
else
++it;
}
CryptoKey cryptoKey = InvalidCryptoKey;
if (cryptoKey128)
{
cryptoKey = Crypto::CreateKey(m_logger, cryptoKey128);
if (cryptoKey == InvalidCryptoKey)
return false;
}
Event done(true);
bool success = false;
// `success` and `done` are captured by reference; this is safe only because this function
// does not return before `done` has been set. The ip string is copied into the closure.
m_addConnections.emplace_back([this, &success , &done, &backend, ip2 = TString(ip), port, cryptoKey]()
{
// TODO: Should this retry?
success = backend.Connect(m_logger, ip2.c_str(), [this, &backend, cryptoKey](void* connection, const sockaddr& remoteSocketAddr, bool* timedOut)
{
return AddConnection(backend, connection, remoteSocketAddr, cryptoKey != InvalidCryptoKey, cryptoKey);
}, port, nullptr);
// The key is destroyed here only when the connect failed
if (!success)
Crypto::DestroyKey(cryptoKey);
done.Set();
return 0;
});
// Block until the connect thread has finished
done.IsSet();
return success;
}
// Prunes connect attempts that have finished and reports whether any is still running.
bool NetworkServer::HasConnectInProgress()
{
	SCOPED_FUTEX(m_addConnectionsLock, lock);

	auto pending = m_addConnections.begin();
	while (pending != m_addConnections.end())
	{
		if (!pending->Wait(0))
		{
			++pending;
			continue;
		}
		pending = m_addConnections.erase(pending);
	}
	return !m_addConnections.empty();
}
// Logs a summary of the server statistics (connections, send/receive totals, crypto timers,
// worker usage and send size) to `logger`. Prints nothing if m_maxActiveConnections is zero.
void NetworkServer::PrintSummary(Logger& logger)
{
if (!m_maxActiveConnections)
return;
// Include workers that are currently alive in the high-water mark
m_maxCreatedWorkerCount = Max(m_createdWorkerCount, m_maxCreatedWorkerCount);
StringBuffer<> workers;
workers.Appendf(TC("%u/%u"), m_maxCreatedWorkerCount, m_maxWorkerCount);
logger.Info(TC(" ----- Uba server stats summary ------"));
logger.Info(TC(" MaxActiveConnections %6u"), m_maxActiveConnections);
logger.Info(TC(" SendTotal %8u %9s"), m_sendTimer.count.load(), TimeToText(m_sendTimer.time).str);
logger.Info(TC(" Bytes %9s"), BytesToText(m_sendBytes).str);
logger.Info(TC(" RecvTotal %8u %9s"), m_recvCount.load(), BytesToText(m_recvBytes.load()).str);
// Crypto lines are only printed when encryption or decryption actually happened
if (m_encryptTimer.count || m_decryptTimer.count)
{
logger.Info(TC(" EncryptTotal %8u %9s"), m_encryptTimer.count.load(), TimeToText(m_encryptTimer.time).str);
logger.Info(TC(" DecryptTotal %8u %9s"), m_decryptTimer.count.load(), TimeToText(m_decryptTimer.time).str);
}
logger.Info(TC(" WorkerCount %9s"), workers.data);
logger.Info(TC(" SendSize Set/Max %9s %9s"), BytesToText(m_sendSize).str, BytesToText(SendMaxSize).str);
logger.Info(TC(""));
}
// Registers the message handler for `serviceId`. Service id 0 is reserved for the system and
// a service may only be registered once. When no type-to-name function is supplied the
// default numeric one is used.
void NetworkServer::RegisterService(u8 serviceId, const WorkerFunction& function, TypeToNameFunction* typeToNameFunc)
{
	UBA_ASSERTF(serviceId != 0, TC("ServiceId 0 is reserved by system"));

	WorkerRec& rec = m_workerFunctions[serviceId];
	UBA_ASSERT(!rec.func);
	rec.func = function;

	if (typeToNameFunc)
		rec.toString = typeToNameFunc;
	else
		rec.toString = GetMessageTypeToName;
}
// Removes the handler of `serviceId`. Must not be called while connections exist.
// The type-to-name function is intentionally kept so stats can still be printed.
void NetworkServer::UnregisterService(u8 serviceId)
{
	SCOPED_WRITE_LOCK(m_connectionsLock, lock);
	UBA_ASSERTF(m_connections.empty(), TC("Unregistering service while still having live connections"));

	m_workerFunctions[serviceId].func = {};
	//m_workerFunctions[serviceId].toString = nullptr; // Keep this for now, we want to be able to output stats
}
// Installs the single client-connected callback. `id` is not used by this implementation;
// asserts that no callback is installed yet.
void NetworkServer::RegisterOnClientConnected(u8 id, const OnConnectionFunction& func)
{
	UBA_ASSERT(!m_onConnectionFunction);
	m_onConnectionFunction = func;
}
// Clears the client-connected callback. `id` is not used by this implementation.
// Must not be called while connections exist.
void NetworkServer::UnregisterOnClientConnected(u8 id)
{
	SCOPED_WRITE_LOCK(m_connectionsLock, lock);
	UBA_ASSERT(m_connections.empty());
	m_onConnectionFunction = {};
}
// Appends a client-disconnected callback tagged with `id` (used later to unregister it).
void NetworkServer::RegisterOnClientDisconnected(u8 id, const OnDisconnectFunction& func)
{
	SCOPED_WRITE_LOCK(m_onDisconnectFunctionsLock, l);
	m_onDisconnectFunctions.emplace_back(OnDisconnectEntry{id, func});
}
// Removes the first client-disconnected callback registered with `id`, if there is one.
void NetworkServer::UnregisterOnClientDisconnected(u8 id)
{
	SCOPED_WRITE_LOCK(m_onDisconnectFunctionsLock, l);

	auto entry = m_onDisconnectFunctions.begin();
	while (entry != m_onDisconnectFunctions.end() && entry->id != id)
		++entry;

	if (entry != m_onDisconnectFunctions.end())
		m_onDisconnectFunctions.erase(entry);
}
// Queues `work` `count` times as additional work (at the front of the queue when
// highPriority) and then wakes up to `count` workers to process it.
// `desc` must be a non-empty string; it is only stored when a work tracker is set.
// NOTE(review): `color` is not used in this function.
void NetworkServer::AddWork(const WorkFunction& work, u32 count, const tchar* desc, const Color& color, bool highPriority)
{
UBA_ASSERT(*desc);
SCOPED_FUTEX(m_additionalWorkLock, lock);
for (u32 i = 0; i != count; ++i)
{
if (highPriority)
{
m_additionalWork.push_front({ work });
if (m_workTracker)
m_additionalWork.front().desc = desc;
}
else
{
m_additionalWork.push_back({ work });
if (m_workTracker)
m_additionalWork.back().desc = desc;
}
}
lock.Leave();
SCOPED_FUTEX(m_availableWorkersLock, lock2);
// With workers disabled the work stays queued and nobody is woken here
if (!m_workersEnabled)
return;
while (count--)
{
Worker* worker = PopWorkerNoLock();
if (!worker)
break;
UBA_ASSERT(worker->m_inUse);
// A null connection tells Worker::Update there is no message, only additional work
worker->m_context->connection = nullptr;
worker->m_context->workAvailable.Set();
}
}
// Runs up to `count` units of additional work on the calling thread, stopping early as soon
// as DoAdditionalWork reports there was nothing to do.
void NetworkServer::DoWork(u32 count)
{
	for (; count; --count)
	{
		if (!DoAdditionalWork())
			return;
	}
}
// Returns the maximum number of workers this server may create (not the number currently alive).
u32 NetworkServer::GetWorkerCount()
{
	return m_maxWorkerCount;
}
// Gives access to the server's logger.
MutableLogger& NetworkServer::GetLogger()
{
	return m_logger;
}
// Returns the number of bytes sent over all connections since the last stats reset.
u64 NetworkServer::GetTotalSentBytes()
{
	return m_sendBytes;
}
// Returns the number of message body bytes received since the last stats reset.
u64 NetworkServer::GetTotalRecvBytes()
{
	return m_recvBytes;
}
// Returns the server-level send timer. Per-connection send timers are folded into it when
// connections are stopped or removed.
Timer& NetworkServer::GetTotalSentTimer()
{
	return m_sendTimer;
}
// Returns the number of known clients, read under the clients lock.
u32 NetworkServer::GetClientCount()
{
	// Terminated with ';' like every other SCOPED_READ_LOCK use in this file, so the
	// statement does not depend on the macro expansion ending in one.
	SCOPED_READ_LOCK(m_clientsLock, lock);
	return u32(m_clients.size());
}
// Counts the connections that have not been shut down yet.
u32 NetworkServer::GetConnectionCount()
{
	SCOPED_READ_LOCK(m_connectionsLock, lock);

	u32 liveConnections = 0;
	for (auto& connection : m_connections)
	{
		if (connection.m_disconnected)
			continue;
		++liveConnections;
	}
	return liveConnections;
}
// Adds the send/recv byte counters and connection count of client `clientId` to `out`.
// `out` is left untouched when the client is unknown.
void NetworkServer::GetClientStats(ClientStats& out, u32 clientId)
{
	SCOPED_READ_LOCK(m_clientsLock, lock);

	auto findIt = m_clients.find(clientId);
	if (findIt != m_clients.end())
	{
		Client& client = findIt->second;
		out.send += client.sendBytes;
		out.recv += client.recvBytes;
		out.connectionCount += client.connectionCount;
	}
}
// Returns true if client `clientId` is known and has at least one open connection.
bool NetworkServer::IsConnected(u32 clientId)
{
	SCOPED_READ_LOCK(m_clientsLock, lock);

	auto findIt = m_clients.find(clientId);
	if (findIt != m_clients.end())
	{
		Client& client = findIt->second;
		return client.connectionCount > 0;
	}
	return false;
}
// Resets the server-level send timer and the sent/received byte counters.
// Other counters (receive count, encrypt/decrypt timers) are not touched.
void NetworkServer::ResetTotalStats()
{
	m_sendBytes = 0;
	m_recvBytes = 0;
	m_sendTimer = {};
}
// Runs one unit of work on the calling thread. Returns true if something was done.
//
// If additional work is queued, one item is popped and executed. Otherwise, if every worker
// has already been created and the caller is itself a worker thread, the caller's worker is
// temporarily put back into the pool with a stack-allocated context so it can pick up work
// while its original context is still in use by the work that called this function.
bool NetworkServer::DoAdditionalWork()
{
SCOPED_FUTEX(m_additionalWorkLock, lock);
if (m_additionalWork.empty())
{
lock.Leave();
SCOPED_FUTEX(m_availableWorkersLock, lock2);
// Only lend out the calling worker when no more workers can be created
if (m_createdWorkerCount != m_maxWorkerCount)
return false;
lock2.Leave();
auto worker = t_worker;
// Not called from a worker thread
if (!worker)
return false;
auto oldContext = worker->m_context;
// Temporary context so the original one (still in use further up the stack) is left untouched
WorkerContext context(*this);
worker->m_context = &context;
PushWorker(worker);
// Wait with a timeout of 10 (presumably milliseconds - confirm against Event::IsSet) for work to be handed over
bool workAvail = context.workAvailable.IsSet(10);
lock2.Enter();
if (worker->m_inUse)
{
// Someone popped this worker; wait for the wake-up signal (if not already seen) and process it
lock2.Leave();
if (!workAvail)
context.workAvailable.IsSet(~0u);
worker->Update(context);
UBA_ASSERT(worker->m_inUse);
}
else
{
// Take worker back from free list
if (m_firstAvailableWorker == worker)
m_firstAvailableWorker = worker->m_nextWorker;
else
worker->m_prevWorker->m_nextWorker = worker->m_nextWorker;
if (worker->m_nextWorker)
worker->m_nextWorker->m_prevWorker = worker->m_prevWorker;
// ...and link it back in at the head of the active list
worker->m_prevWorker = nullptr;
worker->m_nextWorker = m_firstActiveWorker;
if (m_firstActiveWorker)
m_firstActiveWorker->m_prevWorker = worker;
m_firstActiveWorker = worker;
worker->m_inUse = true;
}
worker->m_context = oldContext;
return true;
}
AdditionalWork work = std::move(m_additionalWork.front());
m_additionalWork.pop_front();
// Run the work outside the lock
lock.Leave();
#if UBA_TRACK_WORK
TrackWorkScope tws(*this, work.desc, ColorWork);
#else
TrackWorkScope tws;
#endif
work.func({tws});
return true;
}
bool NetworkServer::SendResponse(const MessageInfo& info, const u8* body, u32 bodySize)
{
UBA_ASSERT(info.connectionId);
UBA_ASSERT(info.messageId);
LOG_STALL_SCOPE(m_logger, 5, TC("NetworkServer::SendResponse took more than %s"));
SCOPED_READ_LOCK(m_connectionsLock, lock);
Connection* found = nullptr;
for (auto& it : m_connections)
{
if (it.m_id != info.connectionId)
continue;
if (!it.m_disconnected)
found = &it;
break;
}
if (!found)
return false;
Connection& connection = *found;
u8 buffer[SendMaxSize];
constexpr u32 HeaderSize = 5; // 2 byte id, 3 bytes size
BinaryWriter writer(buffer, 0, sizeof_array(buffer));
u8* idAndSizePtr = writer.AllocWrite(HeaderSize);
if (body)
{
writer.WriteBytes(body, bodySize);
if (connection.m_cryptoKey && bodySize)
{
TimerScope ts(connection.m_encryptTimer);
u8* bodyData = writer.GetData() + HeaderSize;
if (!Crypto::Encrypt(m_logger, connection.m_cryptoKey, bodyData, bodySize))
{
connection.SetShouldDisconnect();
return false;
}
}
}
else
{
bodySize = MessageErrorSize;
connection.SetShouldDisconnect();
}
idAndSizePtr[0] = info.messageId >> 8;
*(u32*)(idAndSizePtr + 1) = bodySize | u32(info.messageId << 24);
connection.Send(writer.GetData(), u32(writer.GetPosition()), TC("MessageResponse"));
return true;
}
bool NetworkServer::SendKeepAlive()
{
SCOPED_READ_LOCK(m_connectionsLock, lock);
for (auto& it : m_connections)
if (!it.SendKeepAlive())
return false;
return true;
}
// Returns a worker marked as in use, blocking until one is available.
// Returns nullptr when workers are disabled.
// If no worker can be popped or created, a request is queued and the caller waits for
// PushWorkerNoLock to hand over a worker directly (or for DisconnectClients to wake it
// without one, in which case the loop re-checks whether workers are enabled).
NetworkServer::Worker* NetworkServer::PopWorker()
{
while (true)
{
SCOPED_FUTEX(m_availableWorkersLock, lock);
if (!m_workersEnabled)
return nullptr;
if (auto worker = PopWorkerNoLock())
return worker;
// The request lives on this stack frame; it is unlinked by whoever sets its event
PopWorkerRequest req;
req.ev.Create(true);
if (!m_firstRequest)
m_firstRequest = &req;
else
m_lastRequest->next = &req;
m_lastRequest = &req;
lock.Leave();
req.ev.IsSet();
if (req.worker)
return req.worker;
}
}
// Takes a worker from the head of the available list, or creates and starts a new one if
// the list is empty and the worker limit has not been reached. The worker is linked in at
// the head of the active list and marked in use. Returns nullptr if none can be provided.
// NOTE(review): the "NoLock" suffix and the callers in this file suggest m_availableWorkersLock
// must already be held by the caller - confirm.
NetworkServer::Worker* NetworkServer::PopWorkerNoLock()
{
Worker* worker = m_firstAvailableWorker;
if (worker)
{
// Unlink from the available list
m_firstAvailableWorker = worker->m_nextWorker;
if (m_firstAvailableWorker)
m_firstAvailableWorker->m_prevWorker = nullptr;
}
else
{
if (m_createdWorkerCount == m_maxWorkerCount)
return nullptr;
worker = new Worker();
worker->Start(*this);
++m_createdWorkerCount;
}
// Link in at the head of the active list
if (m_firstActiveWorker)
m_firstActiveWorker->m_prevWorker = worker;
worker->m_nextWorker = m_firstActiveWorker;
m_firstActiveWorker = worker;
worker->m_inUse = true;
return worker;
}
// Locking wrapper: acquires m_availableWorkersLock for the duration of the call
// and forwards the worker to PushWorkerNoLock().
void NetworkServer::PushWorker(Worker* worker)
{
	SCOPED_FUTEX(m_availableWorkersLock, workersGuard);
	PushWorkerNoLock(worker);
}
// Gives back a worker that is currently flagged m_inUse.
// If a PopWorker() request is pending, the worker is handed straight to the oldest
// request (it stays in the active list and remains in use) and that request's event
// is set. Otherwise the worker is unlinked from the active list, pushed onto the
// front of the available list and its m_inUse flag is cleared.
// Takes no lock itself; PushWorker() calls it with m_availableWorkersLock held.
void NetworkServer::PushWorkerNoLock(Worker* worker)
{
	UBA_ASSERT(worker->m_inUse);

	PopWorkerRequest* pending = m_firstRequest;
	if (pending)
	{
		// Dequeue the oldest request and wake its waiter with this worker.
		m_firstRequest = pending->next;
		if (m_firstRequest == nullptr)
			m_lastRequest = nullptr;
		pending->worker = worker;
		pending->ev.Set();
		return;
	}

	// Unlink from the active list.
	Worker* const before = worker->m_prevWorker;
	Worker* const after = worker->m_nextWorker;
	if (before)
		before->m_nextWorker = after;
	else
		m_firstActiveWorker = after;
	if (after)
		after->m_prevWorker = before;

	// Link in as the new head of the available list.
	Worker* const availableHead = m_firstAvailableWorker;
	if (availableHead)
		availableHead->m_prevWorker = worker;
	worker->m_prevWorker = nullptr;
	worker->m_nextWorker = availableHead;
	worker->m_inUse = false;
	m_firstAvailableWorker = worker;
}
void NetworkServer::FlushWorkers()
{
SCOPED_FUTEX(m_availableWorkersLock, lock);
while (auto worker = m_firstActiveWorker)
{
lock.Leave();
worker->Stop(*this);
lock.Enter();
}
UBA_ASSERT(m_firstActiveWorker == nullptr);
auto worker = m_firstAvailableWorker;
while (worker)
{
auto temp = worker;
worker = worker->m_nextWorker;
delete temp;
}
m_firstAvailableWorker = nullptr;
m_maxCreatedWorkerCount = Max(m_createdWorkerCount, m_maxCreatedWorkerCount);
m_createdWorkerCount = 0;
}
void NetworkServer::RemoveDisconnectedConnections()
{
bool clientRefCountChanged = false;
for (auto it=m_connections.begin(); it!=m_connections.end();)
{
Connection& con = *it;
if (!con.m_disconnected)
{
++it;
continue;
}
m_sendTimer += con.m_sendTimer;
auto& backend = con.m_backend;
void* backendConnection = con.m_backendConnection;
if (auto client = con.m_client)
{
--client->refCount;
clientRefCountChanged = true;
}
it = m_connections.erase(it);
backend.DeleteConnection(backendConnection);
}
if (!clientRefCountChanged)
return;
SCOPED_WRITE_LOCK(m_clientsLock, lock);
for (auto it=m_clients.begin(); it!=m_clients.end();)
{
if (it->second.refCount)
++it;
else
it = m_clients.erase(it);
}
}
// Handles the system message types known to the server.
//   connectionInfo - the connection the message arrived on
//   messageType    - one of the SystemMessageType_* values
//   reader         - message payload
//   writer         - response payload (only written by FetchConfig)
// Returns true if the message type was recognized, false otherwise.
//
// SystemMessageType_SetConnectionCount: reads the desired connection count. If the
// client currently has fewer connections (established + queued), the missing number
// of connect jobs is appended to m_addConnections and the client's
// queuedConnectionCount is increased by the same amount.
//
// SystemMessageType_FetchConfig: writes the size of m_clientsConfig (7-bit encoded)
// followed by its bytes.
bool NetworkServer::HandleSystemMessage(const ConnectionInfo& connectionInfo, u8 messageType, BinaryReader& reader, BinaryWriter& writer)
{
	switch (messageType)
	{
	case SystemMessageType_SetConnectionCount:
	{
		LOG_STALL_SCOPE(m_logger, 5, TC("SystemMessageType_SetConnectionCount took more than %s"));
		u32 connectionCount = reader.ReadU32();
		u32 clientId = connectionInfo.GetId();

		SCOPED_READ_LOCK(m_clientsLock, lock);
		auto findIt = m_clients.find(clientId);
		if (findIt == m_clients.end())
			return true;
		Client& c = findIt->second;
		u32 currentCount = c.connectionCount + c.queuedConnectionCount;
		if (currentCount >= connectionCount)
			return true;
		u32 toAdd = connectionCount - currentCount;
		c.queuedConnectionCount += toAdd;
		m_logger.Detail(TC("Client %u requested %u connections. Has %u, queue %u"), c.id, connectionCount, c.connectionCount.load(), c.queuedConnectionCount);
		lock.Leave();

		// Id of the connection this request arrived on; the queued jobs use it to
		// look up the remote address, backend and crypto key to connect with.
		u32 connectionId = ((NetworkServer::Connection*)connectionInfo.internalData)->m_id;

		SCOPED_FUTEX(m_addConnectionsLock, lock2);
		for (u32 i = 0; i != toAdd; ++i)
		{
			// clientId must be captured in addition to connectionId: m_clients is looked up
			// by client id (see the find above), so the queued count can only be released
			// on the right entry by using clientId, not the connection id.
			m_addConnections.emplace_back([this, connectionId, clientId]()
				{
					// Runs on every exit path of this job and releases the queued slot
					// that was reserved for it above.
					auto cg = MakeGuard([&]()
						{
							SCOPED_READ_LOCK(m_clientsLock, lock);
							auto findIt = m_clients.find(clientId);
							if (findIt != m_clients.end())
								--findIt->second.queuedConnectionCount;
						});

					Connection* conn = nullptr;
					SCOPED_READ_LOCK(m_connectionsLock, lock);
					for (auto& c : m_connections)
						if (c.m_id == connectionId)
							conn = &c;
					if (!conn || conn->m_disconnected)
						return 0;

					// Copy everything needed from the connection before releasing the lock.
					auto& backend = conn->m_backend;
					auto remoteSockAddr = conn->m_remoteSockAddr;
					CryptoKey cryptoKey = InvalidCryptoKey;
					if (conn->m_cryptoKey)
					{
						cryptoKey = Crypto::DuplicateKey(m_logger, conn->m_cryptoKey);
						if (cryptoKey == InvalidCryptoKey)
							return 0;
					}
					lock.Leave();

					bool success = backend.Connect(m_logger, remoteSockAddr, [this, &backend, &cryptoKey](void* connection, const sockaddr& remoteSocketAddr, bool* timedOut)
						{
							return AddConnection(backend, connection, remoteSocketAddr, cryptoKey != InvalidCryptoKey, cryptoKey);
						}, nullptr);
					// The duplicated key is destroyed here when Connect reports failure.
					if (!success)
						Crypto::DestroyKey(cryptoKey);
					return 0;
				});
		}
		return true;
	}
	case SystemMessageType_FetchConfig:
	{
		writer.Write7BitEncoded(m_clientsConfig.size());
		writer.WriteBytes(m_clientsConfig.data(), m_clientsConfig.size());
		return true;
	}
	}
	return false;
}
// Registers a new backend connection with the server.
// Takes the m_connectionsLock write lock, removes already disconnected connections,
// and then either appends a new Connection (with the next value of
// m_connectionIdCounter as id) and returns true, or - when workers are disabled or
// all clients are flagged as disconnected - rejects the connection and returns false.
// m_maxActiveConnections is raised to the new size of m_connections if that is larger.
bool NetworkServer::AddConnection(NetworkBackend& backend, void* backendConnection, const sockaddr& remoteSocketAddr, bool requiresCrypto, CryptoKey cryptoKey)
{
	LOG_STALL_SCOPE(m_logger, 5, TC("NetworkServer::AddConnection took more than %s"));
	SCOPED_WRITE_LOCK(m_connectionsLock, connectionsGuard);
	RemoveDisconnectedConnections();

	const bool accepting = m_workersEnabled && !m_allClientsDisconnected;
	if (accepting)
	{
		m_connections.emplace_back(*this, backend, backendConnection, remoteSocketAddr, requiresCrypto, cryptoKey, m_connectionIdCounter++);
		m_maxActiveConnections = Max(m_maxActiveConnections, u32(m_connections.size()));
		return true;
	}

	// Rejected: install do-nothing callbacks on the backend connection.
	// This is only done to prevent errors in the log.
	backend.SetDisconnectCallback(backendConnection, nullptr, [](void*, const Guid&, void*) {});
	backend.SetRecvCallbacks(backendConnection, nullptr, 0, [](void*, const Guid&, u8*, void*&, u8*&, u32&) { return false; }, nullptr, TC("Disconnecting"));
	return false;
}
}