// Copyright Epic Games, Inc. All Rights Reserved. #include "UbaSessionServer.h" #include "UbaApplicationRules.h" #include "UbaConfig.h" #include "UbaNetworkServer.h" #include "UbaProcess.h" #include "UbaProcessStartInfoHolder.h" #include "UbaScheduler.h" #include "UbaStorage.h" namespace uba { class SessionServer::RemoteProcess final : public Process { public: RemoteProcess(SessionServer* server, const ProcessStartInfo& si, u32 processId, float weight_) : m_server(server) , m_startInfo(si) , m_processId(processId) , m_done(true) { m_startInfo.weight = weight_; } ~RemoteProcess() { if (m_knownInputsDone.IsCreated()) m_knownInputsDone.IsSet(50*1000); delete[] m_knownInputs; } virtual const ProcessStartInfo& GetStartInfo() const override { return m_startInfo; } virtual u32 GetId() override { return m_processId; } virtual u32 GetExitCode() override { UBA_ASSERT(m_done.IsSet(0)); return m_exitCode; } virtual bool HasExited() override { return m_done.IsSet(0); } virtual bool WaitForExit(u32 millisecondsTimeout) override { return m_done.IsSet(millisecondsTimeout); } virtual u64 GetTotalProcessorTime() const override { return m_processorTime; } virtual u64 GetTotalWallTime() const override { return m_wallTime; } virtual const Vector& GetLogLines() const override { return m_logLines; } virtual const Vector& GetTrackedInputs() const override { return m_trackedInputs; } virtual const Vector& GetTrackedOutputs() const override { return m_trackedOutputs; }; virtual void Cancel(bool terminate) override { if (m_cancelled) return; m_cancelled = true; m_exitCode = ProcessCancelExitCode; if (auto s = m_server) s->OnCancelled(this); else m_done.Set(); ProcessHandle h; h.m_process = this; CallProcessExit(h); h.m_process = nullptr; } virtual const tchar* GetExecutingHost() const override { return m_executingHost.c_str(); } virtual bool IsRemote() const override { return true; } virtual ProcessExecutionType GetExecutionType() const override { return ProcessExecutionType_Remote; } virtual bool IsChild() override { return false; } void CallProcessExit(ProcessHandle& h) { SCOPED_FUTEX(m_exitedLock, lock); if (!m_startInfo.exitedFunc) return; auto exitedFunc = m_startInfo.exitedFunc; auto userData = m_startInfo.userData; m_startInfo.exitedFunc = nullptr; m_startInfo.userData = nullptr; ProcessExitedResponse response = ProcessExitedResponse_None; exitedFunc(userData, h, response); } SessionServer* m_server; ProcessStartInfoHolder m_startInfo; Futex m_exitedLock; u32 m_processId; u32 m_exitCode = ~0u; u64 m_processorTime = 0; u64 m_wallTime = 0; Event m_done; Vector m_logLines; Vector m_trackedInputs; Vector m_trackedOutputs; bool m_cancelled = false; bool m_allowCrossArchitecture = false; u32 m_clientId = ~0u; u32 m_sessionId = 0; TString m_executingHost; struct KnownInput { CasKey key; u32 mappingAlignment = 0; bool allowProxy = true; }; KnownInput* m_knownInputs = nullptr; u32 m_knownInputsCount = 0; Event m_knownInputsDone; }; void SessionServerCreateInfo::Apply(const Config& config) { SessionCreateInfo::Apply(config); if (const ConfigTable* table = config.GetTable(TC("Session"))) { table->GetValueAsBool(remoteLogEnabled, TC("RemoteLogEnabled")); table->GetValueAsBool(remoteTraceEnabled, TC("RemoteTraceEnabled")); table->GetValueAsBool(nameToHashTableEnabled, TC("NameToHashTableEnabled")); table->GetValueAsBool(traceIOEnabled, TC("TraceIOEnabled")); } } bool GetCrossArchitectureDir(Logger& logger, StringBufferBase& dir, bool reportError) { // UBT has the path win-x64/native or win-arm64/native bool isUbtPath = dir.EndsWith(TCV("native")); if (isUbtPath) dir.Resize(dir.count - 7); // Remove native and slash const tchar* archPath[2] = { TC("x64"), TC("arm64") }; if (!dir.EndsWith(archPath[IsArmBinary])) return reportError ? logger.Error(TC("Module dir is not under supported folder (%s) to be able to run cross architecutres, can't figure out matching x64/arm64 folder"), dir.data) : false; dir.Resize(dir.count - TStrlen(archPath[IsArmBinary])).Append(archPath[!IsArmBinary]); if (isUbtPath) dir.Append(PathSeparator).Append(TCV("native")); return true; } SessionServer::SessionServer(const SessionServerCreateInfo& info, const u8* environment, u32 environmentSize) : Session(info, TC("UbaSessionServer"), false, info.server) , m_server(info.server) , m_maxRemoteProcessCount(~0u) { m_server.RegisterOnClientDisconnected(ServiceId, [this](const Guid& clientUid, u32 clientId) { OnDisconnected(clientUid, clientId); }); m_server.RegisterService(ServiceId, [this](const ConnectionInfo& connectionInfo, const WorkContext& workContext, MessageInfo& messageInfo, BinaryReader& reader, BinaryWriter& writer) { switch (messageInfo.type) { #define UBA_SESSION_MESSAGE(x) case SessionMessageType_##x: return Handle##x(connectionInfo, workContext, reader, writer); UBA_SESSION_MESSAGES #undef UBA_SESSION_MESSAGE } UBA_ASSERT(false); return false; }, [](u8 type) { switch (type) { #define UBA_SESSION_MESSAGE(x) case SessionMessageType_##x: return AsView(TC("")#x); UBA_SESSION_MESSAGES #undef UBA_SESSION_MESSAGE default: return ToView(TC("Unknown")); } } ); if (environmentSize) { m_environmentMemory.resize(environmentSize); memcpy(m_environmentMemory.data(), environment, environmentSize); } m_uiLanguage = GetUserDefaultUILanguage(); m_resetCas = info.resetCas; m_remoteExecutionEnabled = info.remoteExecutionEnabled; m_nameToHashTableEnabled = info.nameToHashTableEnabled; m_memKillLoadPercent = info.memKillLoadPercent; m_remoteLogEnabled = info.remoteLogEnabled; m_remoteTraceEnabled = info.remoteTraceEnabled; m_traceIOEnabled = info.traceIOEnabled; if (m_resetCas) m_storage.Reset(); m_storage.SetTrace(&m_trace, m_detailedTrace); if (m_detailedTrace) m_server.SetWorkTracker(&m_trace); m_memoryThreadEvent.Create(true); if (info.checkMemory) { m_allowWaitOnMem = info.allowWaitOnMem; m_allowKillOnMem = info.allowKillOnMem; u64 memAvail; u64 memTotal; if (GetMemoryInfo(memAvail, memTotal)) { m_memAvail = memAvail; m_memTotal = memTotal; m_memRequiredToSpawn = Min(u64(double(m_memTotal) * double(100 - info.memWaitLoadPercent) / 100.0), 35ull * 1024 * 1024 * 1024); } m_memoryThread.Start([this]() { ThreadMemoryCheckLoop(); return 0; }, TC("UbaMemTrackLoop")); } #if PLATFORM_WINDOWS m_localEnvironmentVariables.insert(TC("TMP")); m_localEnvironmentVariables.insert(TC("TEMP")); #else m_localEnvironmentVariables.insert(TC("TMPDIR")); #endif StringBuffer<> detoursFile; if (!GetDirectoryOfCurrentModule(m_logger, detoursFile)) { UBA_ASSERT(false); return; } u32 dirLength = detoursFile.count;(void)dirLength; detoursFile.Append(PathSeparator).Append(UBA_DETOURS_LIBRARY); #if PLATFORM_WINDOWS char temp[1024]; detoursFile.Parse(temp, sizeof_array(temp)); m_detoursLibrary[IsArmBinary] = temp; if (GetCrossArchitectureDir(m_logger, detoursFile.Resize(dirLength), false)) { detoursFile.Append(PathSeparator).Append(UBA_DETOURS_LIBRARY).Parse(temp, sizeof_array(temp)); m_detoursLibrary[!IsArmBinary] = temp; } #else m_detoursLibrary[IsArmBinary] = detoursFile.data; #endif if (!Create(info)) { UBA_ASSERT(false); return; } } SessionServer::~SessionServer() { m_memoryThreadEvent.Set(); m_memoryThread.Wait(); StopTraceThread(); m_server.SetWorkTracker(nullptr); m_server.UnregisterOnClientDisconnected(ServiceId); m_server.UnregisterService(ServiceId); SCOPED_CRITICAL_SECTION(m_remoteProcessAndSessionLock, lock); for (ProcessHandle& p : m_queuedRemoteProcesses) { ((RemoteProcess*)p.m_process)->m_server = nullptr; p.Cancel(true); } m_queuedRemoteProcesses.clear(); for (const ProcessHandle& p : m_activeRemoteProcesses) { ((RemoteProcess*)p.m_process)->m_server = nullptr; p.Cancel(true); } m_activeRemoteProcesses.clear(); if (m_trace.IsWriting()) { StackBinaryWriter writer; WriteSummary(writer, [&](Logger& logger) { PrintSummary(logger); m_storage.PrintSummary(logger); m_server.PrintSummary(logger); KernelStats::GetGlobal().Print(logger, true); PrintContentionSummary(logger); }); m_trace.SessionSummary(0, writer.GetData(), writer.GetPosition()); } for (auto s : m_clientSessions) { s->~ClientSession(); aligned_free(s); } m_clientSessions.clear(); #if 0 for (auto& kv : m_directoryTable.m_lookup) { DirectoryTable::Directory& dir = kv.second; DirectoryTable::EntryLookup files(m_directoryTable.m_memoryBlock); m_directoryTable.PopulateDirectoryRecursive(StringKeyHasher(), dir.tableOffset, 0, files); for (auto& fileKv : files) { BinaryReader reader(m_directoryTable.m_memory, fileKv.second); StringBuffer<> filename; reader.ReadString(filename); m_logger.Info(filename.data); } } #endif } ProcessHandle SessionServer::RunProcessRacing(u32 raceAgainstRemoteProcessId) { // TODO: Implement return {}; } ProcessHandle SessionServer::RunProcessRemote(const ProcessStartInfo& startInfo, float weight, const void* knownInputs, u32 knownInputsCount, bool allowCrossArchitecture) { //TrackWorkScope tws(m_trace, AsView(TC("RunProcessRemote")), ColorWork); UBA_ASSERT(!startInfo.startSuspended); FlushDeadProcesses(); ValidateStartInfo(startInfo); u32 processId = CreateProcessId(); RemoteProcess* remoteProcess = new RemoteProcess(this, startInfo, processId, weight); auto rules = GetRules(remoteProcess->m_startInfo); remoteProcess->m_startInfo.rules = rules; remoteProcess->m_allowCrossArchitecture = allowCrossArchitecture; if (knownInputsCount) { remoteProcess->m_knownInputsDone.Create(true); auto kiBegin = (const tchar*)knownInputs; auto kiEnd = kiBegin; for (u32 i=0;i!=knownInputsCount; ++i) kiEnd += TStrlen(kiEnd) + 1; u32 knownInputsBytes = u32((kiEnd - kiBegin)*sizeof(tchar)); void* knownInputsCopy = malloc(knownInputsBytes); memcpy(knownInputsCopy, knownInputs, knownInputsBytes); #ifdef __clang_analyzer__ free(knownInputsCopy); // doesn't seem to understand it is handed over #endif m_server.AddWork([remoteProcess, this, knownInputsCopy, knownInputsCount, rules](const WorkContext& context) { auto keys = remoteProcess->m_knownInputs = new RemoteProcess::KnownInput[knownInputsCount]; struct Container { struct iterator { iterator() : ptr(nullptr), index(0) {} iterator(const tchar* p, u32 i) : ptr(p), index(i) {} iterator operator++(int) { auto prev = ptr; ptr += TStrlen(ptr) + 1; return iterator(prev, index++); } const tchar* operator*() { return ptr; } bool operator==(const iterator& o) const { return index == o.index; } const tchar* ptr; u32 index; }; Container(const tchar* b, u32 c) : ptr(b), count(c) {} iterator begin() const { return iterator(ptr, 0); } iterator end() const { return iterator(ptr, count); } u32 size() { return count; } const tchar* ptr; u32 count; }; Container container((const tchar*)knownInputsCopy, knownInputsCount); Atomic keysIndex; const TString& workingDir = remoteProcess->m_startInfo.workingDirStr; m_server.ParallelFor(knownInputsCount, container, [&](const WorkContext&, auto& it) { //TrackWorkScope tws2(m_trace, AsView(TC("KnownInputs")), ColorWork); StringBuffer<> fileName; FixPath(*it, workingDir.c_str(), u32(workingDir.size()), fileName); StringKey fileNameKey = CaseInsensitiveFs ? ToStringKeyLower(fileName) : ToStringKey(fileName); //tws2.AddHint(fileName); // Make sure cas entry exists and caskey is calculated (cas content creation is deferred in case client already has it) CasKey casKey; if (!StoreCasFile(casKey, fileNameKey, fileName.data) || casKey == CasKeyZero) return; UBA_ASSERT(keysIndex < knownInputsCount); auto& ki = keys[keysIndex++]; ki.key = casKey; ki.mappingAlignment = GetMemoryMapAlignment(fileName, true); ki.allowProxy = rules->AllowStorageProxy(fileName); // Update name to hash table SCOPED_WRITE_LOCK(m_nameToHashLookupLock, lock); CasKey& lookupCasKey = m_nameToHashLookup[fileNameKey]; if (lookupCasKey != casKey) { lookupCasKey = casKey; BinaryWriter w(m_nameToHashTableMem.memory, m_nameToHashTableMem.writtenSize, NameToHashMemSize); m_nameToHashTableMem.AllocateNoLock(sizeof(StringKey) + sizeof(CasKey), 1, TC("NameToHashTable")); w.WriteStringKey(fileNameKey); w.WriteCasKey(lookupCasKey); } }, AsView(TC("KnownInputsLoop")), true); remoteProcess->m_knownInputsCount = keysIndex; remoteProcess->m_knownInputsDone.Set(); free(knownInputsCopy); }, 1, TC("KnownInputs")); } ProcessHandle h(remoteProcess); // Keep ref count up even if process is removed by callbacks etc. SCOPED_CRITICAL_SECTION(m_remoteProcessAndSessionLock, lock); m_queuedRemoteProcesses.push_back(remoteProcess); SCOPED_READ_LOCK(m_remoteProcessReturnedEventLock, lock2); if (m_remoteProcessReturnedEvent) { if (!m_remoteExecutionEnabled) { m_logger.Info(TC("Process queued for remote but remote execution was disabled, returning process to queue")); m_remoteProcessReturnedEvent(*remoteProcess); } else if (!m_connectionCount) { m_logger.Info(TC("Process queued for remote but there are no active connections, returning process to queue")); m_remoteProcessReturnedEvent(*remoteProcess); } } return h; } void SessionServer::DisableRemoteExecution() { SCOPED_CRITICAL_SECTION(m_remoteProcessAndSessionLock, lock); if (m_remoteExecutionEnabled) m_logger.Info(TC("Disable remote execution (remote sessions will finish current processes)")); m_remoteExecutionEnabled = false; m_trace.RemoteExecutionDisabled(); } bool SessionServer::IsRemoteExecutionDisabled() { return !m_remoteExecutionEnabled; } void SessionServer::ReenableRemoteExecution() { SCOPED_CRITICAL_SECTION(m_remoteProcessAndSessionLock, lock); if (m_remoteExecutionEnabled) return; m_logger.Info(TC("Reenabled remote execution")); m_remoteExecutionEnabled = true; //m_trace.RemoteExecutionDisabled(); } void SessionServer::SetCustomCasKeyFromTrackedInputs(const tchar* fileName_, const tchar* workingDir_, const u8* trackedInputs, u32 trackedInputsBytes) { StringBuffer<> workingDir; FixFileName(workingDir, workingDir_, nullptr); if (workingDir[workingDir.count - 1] != '\\') workingDir.Append(TCV("\\")); StringBuffer<> fileName; FixFileName(fileName, fileName_, workingDir.data); StringKey fileNameKey = ToStringKey(fileName); SCOPED_FUTEX(m_customCasKeysLock, lock); auto insres = m_customCasKeys.try_emplace(fileNameKey); CustomCasKey& customKey = insres.first->second; customKey.casKey = CasKeyZero; customKey.workingDir = workingDir.data; customKey.trackedInputs.resize(trackedInputsBytes); memcpy(customKey.trackedInputs.data(), trackedInputs, trackedInputsBytes); //m_logger.Debug(TC("Registered file using custom cas %s (%s)"), fileName_, GuidToString(fileNameHash).str); } bool SessionServer::GetCasKeyFromTrackedInputs(CasKey& out, const tchar* fileName, const tchar* workingDir, const u8* data, u32 dataLen) { u64 workingDirLen = TStrlen(workingDir); BinaryReader reader(data); CasKeyHasher hasher; while (reader.GetPosition() < dataLen) { tchar str[512] = { 0 }; reader.ReadString(str, sizeof_array(str)); tchar* path = str; tchar temp[512]; if (str[1] != ':' && (TStrstr(str, TC(".dll")) || TStrstr(str, TC(".exe")))) { bool res = SearchPathW(NULL, str, NULL, 512, temp, NULL); UBA_ASSERT(res); if (!res) return false; path = temp; } StringBuffer<> inputFileName; FixPath(path, workingDir, workingDirLen, inputFileName); if (inputFileName.StartsWith(m_tempPath.data)) continue; if (inputFileName.Equals(fileName)) continue; if (inputFileName.StartsWith(m_systemPath.data)) continue; CasKey casKey; bool deferCreation = true; if (!m_storage.StoreCasFile(casKey, path, CasKeyZero, deferCreation)) return false; UBA_ASSERTF(casKey != CasKeyZero, TC("Failed to store cas for %s when calculating key for tracked inputs on %s"), path, fileName); hasher.Update(&casKey, sizeof(CasKey)); } out = ToCasKey(hasher, m_storage.StoreCompressed()); return true; } void SessionServer::SetRemoteProcessSlotAvailableEvent(const Function& remoteProcessSlotAvailableEvent) { SCOPED_WRITE_LOCK(m_remoteProcessSlotAvailableEventLock, lock); m_remoteProcessSlotAvailableEvent = remoteProcessSlotAvailableEvent; } void SessionServer::SetRemoteProcessReturnedEvent(const Function& remoteProcessReturnedEvent) { SCOPED_WRITE_LOCK(m_remoteProcessReturnedEventLock, lock); m_remoteProcessReturnedEvent = remoteProcessReturnedEvent; } void SessionServer::WaitOnAllTasks() { while (true) { SCOPED_CRITICAL_SECTION(m_remoteProcessAndSessionLock, lock); if (m_activeRemoteProcesses.empty() && m_queuedRemoteProcesses.empty()) break; lock.Leave(); Sleep(200); } bool isEmpty = false; while (!isEmpty) { Vector processes; { SCOPED_FUTEX(m_processesLock, lock); isEmpty = m_processes.empty(); processes.reserve(m_processes.size()); for (auto& pair : m_processes) processes.push_back(pair.second); } for (auto& process : processes) process.WaitForExit(100000); } FlushDeadProcesses(); } void SessionServer::SetMaxRemoteProcessCount(u32 count) { m_maxRemoteProcessCount.exchange(count); } u32 SessionServer::BeginExternalProcess(const tchar* description, const tchar* breadcrumbs) { u32 processId = CreateProcessId(); m_trace.ProcessAdded(0, processId, ToView(description), ToView(breadcrumbs)); return processId; } void SessionServer::EndExternalProcess(u32 id, u32 exitCode) { StackBinaryWriter<1024> statsWriter; ProcessStats processStats; processStats.Write(statsWriter); m_trace.ProcessExited(id, exitCode, statsWriter.GetData(), statsWriter.GetPosition(), Vector()); } void SessionServer::UpdateProgress(u32 processesTotal, u32 processesDone, u32 errorCount) { m_trace.ProgressUpdate(processesTotal, processesDone, errorCount); } void SessionServer::UpdateStatus(u32 statusRow, u32 statusColumn, const tchar* statusText, LogEntryType statusType, const tchar* statusLink) { m_trace.StatusUpdate(statusRow, statusColumn, statusText, statusType, statusLink); } void SessionServer::AddProcessBreadcrumbs(u32 processId, const tchar* breadcrumbs, bool deleteOld) { m_trace.ProcessAddBreadcrumbs(processId, ToView(breadcrumbs), deleteOld); } NetworkServer& SessionServer::GetServer() { return m_server; } void SessionServer::RegisterNetworkTrafficProvider(const NetworkTrafficProvider& provider) { SCOPED_CRITICAL_SECTION(m_remoteProcessAndSessionLock, lock); m_provider = provider; } void SessionServer::RegisterCrossArchitectureMapping(const tchar* from, const tchar* to) { m_crossArchitectureMappings.emplace_back(CrossArchitectureMapping{from, to}); } void SessionServer::SetOuterScheduler(Scheduler* scheduler) { UBA_ASSERT(!m_outerScheduler || !scheduler); m_outerScheduler = scheduler; } Scheduler* SessionServer::GetOuterScheduler() { return m_outerScheduler; } void SessionServer::OnDisconnected(const Guid& clientUid, u32 clientId) { u32 returnCount = 0; SCOPED_CRITICAL_SECTION(m_remoteProcessAndSessionLock, queueLock); for (auto it=m_activeRemoteProcesses.begin(); it!=m_activeRemoteProcesses.end();) { RemoteProcess* remoteProcess = (RemoteProcess*)it->m_process; if (remoteProcess->m_clientId != clientId) { ++it; continue; } m_queuedRemoteProcesses.push_front(*it); it = m_activeRemoteProcesses.erase(it); remoteProcess->m_executingHost.clear(); m_trace.ProcessReturned(remoteProcess->m_processId, AsView(TC("Disconnected"))); ProcessHandle h = ProcessRemoved(remoteProcess->m_processId); if (!h.m_process) m_logger.Warning(TC("Trying to remove process on client %u that does not exist in active list.. investigate me"), clientId); ++returnCount; remoteProcess->m_clientId = ~0u; remoteProcess->m_sessionId = 0; if (m_remoteProcessReturnedEvent) m_remoteProcessReturnedEvent(*remoteProcess); } m_returnedRemoteProcessCount += returnCount; u32 sessionId = 0; StringBuffer<> sessionName; for (auto sptr : m_clientSessions) { ++sessionId; auto& s = *sptr; if (s.clientId != clientId) continue; if (!returnCount && !s.hasNotification && !s.enabled) m_trace.SessionNotification(sessionId, TC("Done")); m_trace.SessionDisconnect(sessionId); sessionName.Append(s.name); UBA_ASSERTF(s.usedSlotCount == returnCount || m_logger.isMuted, TC("Used slot count different than return count (%u vs %u)"), s.usedSlotCount, returnCount); s.usedSlotCount -= returnCount; if (s.enabled) m_availableRemoteSlotCount -= s.processSlotCount - returnCount; s.enabled = false; s.connected = false; --m_connectionCount; } if (returnCount) { if (sessionName.IsEmpty()) sessionName.Append(TCV("")); m_logger.Info(TC("Client session %s (%s) disconnected. Returned %u process(s) to queue"), sessionName.data, GuidToString(clientUid).str, returnCount); } if (m_connectionCount) return; if (!m_queuedRemoteProcesses.empty()) { if (m_remoteProcessReturnedEvent) { m_logger.Info(TC("No client sessions connected and there are %llu processes left in the remote queue. Will return all queued remote processes"), m_queuedRemoteProcesses.size()); List temp(m_queuedRemoteProcesses); for (ProcessHandle& remoteProcess : temp) m_remoteProcessReturnedEvent(*remoteProcess.m_process); } else { m_logger.Info(TC("No client sessions connected and there are %llu processes left in the remote queue. processes will be picked up when remote connection is established"), m_queuedRemoteProcesses.size()); } } if (!m_activeRemoteProcesses.empty()) { // This path has been seen in the wild (Once over million of runs).... // And the theory is that m_logger.Error(TC("No client sessions connected but there are %llu active remote processes. This should not happen, there is a bug in the code!!"), m_activeRemoteProcesses.size()); } } bool SessionServer::HandleConnect(const ConnectionInfo& connectionInfo, const WorkContext& workContext, BinaryReader& reader, BinaryWriter& writer) { StringBuffer<128> name; reader.ReadString(name); u32 clientVersion = reader.ReadU32(); bool isClientArm = false; if (clientVersion >= 36) isClientArm = reader.ReadBool(); m_logger.Detail(TC("Client session %s connected (Id: %u, Uid: %s%s)"), name.data, connectionInfo.GetId(), GuidToString(connectionInfo.GetUid()).str, (isClientArm ? TC(", IsArm: true") : TC(""))); CasKey clientAgentKey = reader.ReadCasKey(); CasKey clientDetoursKey = reader.ReadCasKey();(void)clientDetoursKey; CasKey detoursBinaryKey[2]; CasKey& agentBinaryKey = m_agentBinaryKey[isClientArm]; bool binAsVersion = clientAgentKey != CasKeyZero; { SCOPED_FUTEX(m_binKeysLock, lock); detoursBinaryKey[0] = m_detoursBinaryKey[0]; detoursBinaryKey[1] = m_detoursBinaryKey[1]; StringBuffer<> detoursLib; bool deferCreation = true; // Handle x64 for both architectures and arm64 only for arm64 clients for (u32 i=0; i!=(isClientArm ? 2u : 1u); ++i) { if (detoursBinaryKey[i] != CasKeyZero) continue; detoursLib.Clear().Append(m_detoursLibrary[i].c_str()); if (!m_storage.StoreCasFile(detoursBinaryKey[i], detoursLib.data, CasKeyZero, deferCreation) || detoursBinaryKey[i] == CasKeyZero) return m_logger.Error(TC("Failed to create cas for %s"), detoursLib.data); m_detoursBinaryKey[i] = detoursBinaryKey[i]; } if (binAsVersion && agentBinaryKey == CasKeyZero) { StringBuffer<> agentDir; if (!GetDirectoryOfCurrentModule(m_logger, agentDir)) return false; if (IsArmBinary != isClientArm) if (!GetCrossArchitectureDir(m_logger, agentDir, true)) return false; UBA_ASSERT(IsWindows); agentDir.Append(PathSeparator).Append(UBA_AGENT_EXECUTABLE); if (!m_storage.StoreCasFile(agentBinaryKey, agentDir.data, CasKeyZero, deferCreation) || agentBinaryKey == CasKeyZero) { // This is hacky but uba binary is not copied by ubt anymore. StringBuffer<> dir2; if (!GetAlternativeUbaPath(m_logger, dir2, agentDir, IsWindows && isClientArm)) return false; dir2.Append(UBA_AGENT_EXECUTABLE); if (!m_storage.StoreCasFile(agentBinaryKey, dir2.data, CasKeyZero, deferCreation) || agentBinaryKey == CasKeyZero) return m_logger.Error(TC("Failed to create cas for %s"), dir2.data); } } } StringBuffer<> tempBuffer; auto& disconnectResponse = tempBuffer; if (binAsVersion && clientAgentKey != agentBinaryKey) { m_logger.Warning(TC("UbaAgent binaries mismatch. Disconnecting %s"), name.data); disconnectResponse.Appendf(TC("UbaAgent binaries mismatch. Disconnecting...")); } else if (clientVersion != SessionNetworkVersion) { m_logger.Warning(TC("Version mismatch. Server is on version %u while client is on %u. Disconnecting %s"), SessionNetworkVersion, clientVersion, name.data); disconnectResponse.Appendf(TC("Version mismatch. Server is on version %u while client is on %u. Disconnecting..."), SessionNetworkVersion, clientVersion); } writer.WriteBool(disconnectResponse.IsEmpty()); if (!disconnectResponse.IsEmpty()) { writer.WriteString(disconnectResponse); writer.WriteCasKey(agentBinaryKey); writer.WriteCasKey(detoursBinaryKey[0]); if (isClientArm) writer.WriteCasKey(detoursBinaryKey[1]); return true; } u32 processSlotCount = reader.ReadU32(); bool dedicated = reader.ReadBool(); StringBuffer<256> info; reader.ReadString(info); u64 memAvail = reader.ReadU64(); u64 memTotal = reader.ReadU64(); u32 cpuLoadValue = reader.ReadU32(); float cpuLoad = *(float*)&cpuLoadValue; // I have no explanation for this. On linux we get a shutdown crash when running through UBT if session is allocated with normal new // For now we will work around it by using aligned_alloc which seems to be working on all platforms auto& session = *new (aligned_alloc(alignof(ClientSession), sizeof(ClientSession))) ClientSession(); SCOPED_CRITICAL_SECTION(m_remoteProcessAndSessionLock, lock); m_clientSessions.push_back(&session); u32 sessionId = u32(m_clientSessions.size()); session.name = name.data; session.clientId = connectionInfo.GetId(); session.processSlotCount = processSlotCount; session.dedicated = dedicated; session.isArm = isClientArm; session.memAvail = memAvail; session.memTotal = memTotal; session.cpuLoad = cpuLoad; m_availableRemoteSlotCount += processSlotCount; ++m_connectionCount; if (!InitializeNameToHashTable()) return false; writer.WriteCasKey(m_detoursBinaryKey[0]); if (isClientArm) writer.WriteCasKey(m_detoursBinaryKey[1]); writer.WriteBool(m_resetCas); writer.WriteU32(sessionId); writer.WriteU32(m_uiLanguage); writer.WriteBool(m_storeIntermediateFilesCompressed); writer.WriteBool(m_detailedTrace); writer.WriteBool(m_remoteLogEnabled); writer.WriteBool(m_remoteTraceEnabled); writer.WriteBool(m_readIntermediateFilesCompressed); auto& computerName = tempBuffer.Clear(); GetComputerNameW(computerName); writer.WriteString(computerName); WriteRemoteEnvironmentVariables(writer); m_trace.SessionAdded(sessionId, connectionInfo.GetId(), name, info); // Must be inside lock for TraceSessionUpdate() to not include m_trace.SessionUpdate(sessionId, 1, 0, 0, 0, memAvail, memTotal, cpuLoad); lock.Leave(); return true; } bool SessionServer::HandleEnsureBinaryFile(const ConnectionInfo& connectionInfo, const WorkContext& workContext, BinaryReader& reader, BinaryWriter& writer) { bool clientIsArm = reader.ReadBool(); StringBuffer<> fileName; reader.ReadString(fileName); StringKey fileNameKey = reader.ReadStringKey(); TString applicationDir = reader.ReadString(); TString workingDir = reader.ReadString(); StringBuffer<> lookupStr; lookupStr.Append(fileName).Append(applicationDir).Append(workingDir).Append('#'); lookupStr.MakeLower(); StringKey lookupKey = ToStringKeyNoCheck(lookupStr.data, lookupStr.count); SCOPED_FUTEX(m_applicationDataLock, lock); auto insres = m_applicationData.try_emplace(lookupKey); ApplicationData& data = insres.first->second; lock.Leave(); SCOPED_FUTEX(data.lock, lock2); if (!data.bytes.empty()) { writer.WriteBytes(data.bytes.data(), data.bytes.size()); return true; } Vector loaderPaths; while (reader.GetLeft()) loaderPaths.push_back(reader.ReadString()); CasKey casKey = CasKeyZero; StringBuffer<> absoluteFile; auto FixCrossArchitecture = [this, clientIsArm](StringBuffer<>& absoluteFile) { if (clientIsArm == IsArmBinary) return; for (auto& mapping : m_crossArchitectureMappings) if (absoluteFile.StartsWith(mapping.from.c_str())) { StringBuffer<> temp; temp.Append(absoluteFile.data + mapping.from.size()); absoluteFile.Clear().Append(mapping.to).Append(temp); break; } }; if (!loaderPaths.empty()) { for (auto& loaderPath : loaderPaths) { StringBuffer<> fullPath; #if PLATFORM_LINUX if (loaderPath[0] != '/') // TODO: Revisit this.. should be done in a less hacky way. #endif fullPath.Append(applicationDir).EnsureEndsWithSlash(); fullPath.Append(loaderPath).EnsureEndsWithSlash().Append(fileName); if (GetFileAttributesW(fullPath.data) == INVALID_FILE_ATTRIBUTES) continue; FixPath(fullPath.data, nullptr, 0, absoluteFile); FixCrossArchitecture(absoluteFile); fileNameKey = ToStringKeyLower(absoluteFile); if (!StoreCasFile(casKey, fileNameKey, absoluteFile.data)) return false; break; } #if 0 if (casKey == CasKeyZero) { m_logger.Warning(TC("HandleEnsureBinaryFile - Failed to find file %s"), fileName.data); for (auto& loaderPath : loaderPaths) { m_logger.Warning(TC(" LoaderPath %s"), loaderPath.c_str()); } } #endif } else if (SearchPathForFile(m_logger, absoluteFile, fileName.data, workingDir, applicationDir)) { if (!absoluteFile.StartsWith(m_systemPath.data) || !IsKnownSystemFile(absoluteFile.data)) { FixCrossArchitecture(absoluteFile); fileNameKey = ToStringKeyLower(absoluteFile); if (!StoreCasFile(casKey, fileNameKey, absoluteFile.data)) return false; } } u64 startPos = writer.GetPosition(); writer.WriteCasKey(casKey); writer.WriteString(absoluteFile); u64 bytesSize = writer.GetPosition() - startPos; data.bytes.resize(bytesSize); memcpy(data.bytes.data(), writer.GetData() + startPos, bytesSize); return true; } bool SessionServer::HandleGetApplication(const ConnectionInfo& connectionInfo, const WorkContext& workContext, BinaryReader& reader, BinaryWriter& writer) { u32 processId = reader.ReadU32(); (void)processId; StringBuffer<> applicationName; reader.ReadString(applicationName); StringKey applicationKey = ToStringKeyLower(applicationName); SCOPED_FUTEX(m_applicationDataLock, lock); auto insres = m_applicationData.try_emplace(applicationKey); ApplicationData& data = insres.first->second; lock.Leave(); SCOPED_FUTEX(data.lock, lock2); if (!data.bytes.empty()) { writer.WriteBytes(data.bytes.data(), data.bytes.size()); return true; } u64 startPos = writer.GetPosition(); Vector modules; if (!GetBinaryModules(modules, applicationName.data)) return false; writer.WriteU32(m_systemPath.count); writer.WriteU32(u32(modules.size())); for (BinaryModule& m : modules) { CasKey casKey; if (!StoreCasFile(casKey, StringKeyZero, m.path.c_str())) return false; writer.WriteString(m.path); writer.WriteU32(m.fileAttributes); writer.WriteBool(m.isSystem); writer.WriteCasKey(casKey); #if PLATFORM_MAC writer.WriteU32(m.minOsVersion); #endif } u64 bytesSize = writer.GetPosition() - startPos; data.bytes.resize(bytesSize); memcpy(data.bytes.data(), writer.GetData() + startPos, bytesSize); return true; } bool SessionServer::HandleGetFileFromServer(const ConnectionInfo& connectionInfo, const WorkContext& workContext, BinaryReader& reader, BinaryWriter& writer) { u32 processId = reader.ReadU32(); (void)processId; StringBuffer<> fileName; reader.ReadString(fileName); StringKey fileNameKey = reader.ReadStringKey(); workContext.tracker.AddHint(StringView(fileName).GetFileName()); CasKey casKey; if (!StoreCasFile(casKey, fileNameKey, fileName.data)) return false; if (casKey == CasKeyZero) { // TODO: Should this instead use DirectoryTable? (it is currently not properly populated for lookups) u32 attr = GetFileAttributesW(fileName.data); if (attr == INVALID_FILE_ATTRIBUTES || !IsDirectory(attr)) { // Not finding a file is a valid path. Some applications try with a path and if fails try another path //m_logger.Error(TC("Failed to create cas for %s (not found)"), fileName.data); writer.WriteCasKey(casKey); return true; } casKey = CasKeyIsDirectory; } u64 serverTime; if (m_nameToHashInitialized && casKey != CasKeyIsDirectory) { SCOPED_WRITE_LOCK(m_nameToHashLookupLock, lock); serverTime = GetTime(); CasKey& lookupCasKey = m_nameToHashLookup[fileNameKey]; if (lookupCasKey != casKey) { lookupCasKey = casKey; BinaryWriter w(m_nameToHashTableMem.memory, m_nameToHashTableMem.writtenSize, NameToHashMemSize); m_nameToHashTableMem.AllocateNoLock(sizeof(StringKey) + sizeof(CasKey), 1, TC("NameToHashTable")); w.WriteStringKey(fileNameKey); w.WriteCasKey(casKey); } } else serverTime = GetTime(); writer.WriteCasKey(casKey); writer.WriteU64(serverTime); return true; } bool SessionServer::HandleGetLongPathName(const ConnectionInfo& connectionInfo, const WorkContext& workContext, BinaryReader& reader, BinaryWriter& writer) { #if PLATFORM_WINDOWS StringBuffer<> shortPath; reader.ReadString(shortPath); StringBuffer<> longPath; longPath.count = ::GetLongPathNameW(shortPath.data, longPath.data, longPath.capacity); writer.WriteU32(GetLastError()); writer.WriteString(longPath); return true; #else return false; #endif } bool SessionServer::HandleSendFileToServer(const ConnectionInfo& connectionInfo, const WorkContext& workContext, BinaryReader& reader, BinaryWriter& writer) { u32 clientId = connectionInfo.GetId(); u32 processId = reader.ReadU32(); StringBuffer<> destination; reader.ReadString(destination); StringKey destinationKey = reader.ReadStringKey(); u32 attributes = reader.ReadU32(); UBA_ASSERT(attributes); CasKey casKey = reader.ReadCasKey(); Storage::RetrieveResult res; bool success = m_storage.RetrieveCasFile(res, casKey, destination.data, nullptr, 1, true, clientId); casKey = res.casKey; if (!success) { auto logType = connectionInfo.ShouldDisconnect() ? LogEntryType_Info : LogEntryType_Warning; m_logger.Logf(logType, TC("Failed to retrieve cas for %s from client %u (Needed to write %s)"), CasKeyString(casKey).str, clientId, destination.data); } auto writeResponse = MakeGuard([&]() { writer.WriteBool(success); }); bool shouldWriteToDisk = ShouldWriteToDisk(destination); if (success) { if (destination.StartsWith(TC(""))) { StringBuffer<> logPath; logPath.Append(m_sessionLogDir).Append(destination.data + 5); if (!m_storage.CopyOrLink(casKey, logPath.data, attributes)) m_logger.Error(TC("Failed to copy cas from %s to %s"), CasKeyString(casKey).str, logPath.data); else if (!m_storage.DropCasFile(casKey, false, logPath.data)) m_logger.Error(TC("Failed to drop cas %s"), CasKeyString(casKey).str); return true; } if (destination.StartsWith(TC(""))) { StringBuffer<> ubaPath; ubaPath.Append(m_sessionLogDir); ClientSession* session = nullptr; for (auto& s : m_clientSessions) if (s->clientId == clientId) session = s; if (session) ubaPath.Append(session->name); else ubaPath.Append(TCV("Connection")).AppendValue(clientId); ubaPath.Append(TCV(".uba")); m_storage.CopyOrLink(casKey, ubaPath.data, attributes); m_storage.DropCasFile(casKey, false, ubaPath.data); return true; } if (shouldWriteToDisk) { bool writeCompressed = false; ProcessHandle h = GetProcess(processId); if (!h.IsValid()) { success = false; m_logger.Info(TC("Failed to find process for id %u when receiving SendFileToServer message"), processId); return false; } RootsHandle rootsHandle = h.GetStartInfo().rootsHandle; auto& rules = *h.GetStartInfo().rules; Storage::FormattingFunc formattingFunc; bool escapeSpaces; if (HasVfs(rootsHandle) && rules.ShouldDevirtualizeFile(destination, escapeSpaces)) { formattingFunc = [&](MemoryBlock& destData, const void* sourceData, u64 sourceSize, const tchar* hint) { return DevirtualizeDepsFile(rootsHandle, destData, sourceData, sourceSize, escapeSpaces, hint); }; } else if (m_storeIntermediateFilesCompressed) { writeCompressed = g_globalRules.FileCanBeCompressed(destination); } success = m_storage.CopyOrLink(casKey, destination.data, attributes, writeCompressed, formattingFunc); if (!success) m_logger.Error(TC("Failed to copy cas from %s to %s (%s)"), CasKeyString(casKey).str, destination.data, GetProcessDescription(processId).c_str()); else TraceWrittenFile(processId, destination); } else { success = m_storage.FakeCopy(casKey, destination.data); if (!success) m_logger.Error(TC("Failed to fake copy cas from %s to %s (%s)"), CasKeyString(casKey).str, destination.data, GetProcessDescription(processId).c_str()); SCOPED_WRITE_LOCK(m_receivedFilesLock, lock); m_receivedFiles.try_emplace(destinationKey, casKey); } } if (success) { bool invalidateStorage = false; // No need, already handled in m_storage.CopyOrLink RegisterCreateFileForWrite(StringKeyZero, destination, shouldWriteToDisk, 0, 0, invalidateStorage); SCOPED_FUTEX_READ(m_processesLock, lock); auto findIt = m_processes.find(processId); if (findIt != m_processes.end()) { ProcessHandle h(findIt->second); lock.Leave(); auto& process = *(RemoteProcess*)h.m_process; if (process.m_startInfo.trackInputs) { u64 bytes = GetStringWriteSize(destination.data, destination.count); u64 prevSize = process.m_trackedOutputs.size(); process.m_trackedOutputs.resize(prevSize + bytes); BinaryWriter w2(process.m_trackedOutputs.data(), prevSize, prevSize + bytes); w2.WriteString(destination); } } } return true; } bool SessionServer::HandleDeleteFile(const ConnectionInfo& connectionInfo, const WorkContext& workContext, BinaryReader& reader, BinaryWriter& writer) { StringKey fileNameKey = reader.ReadStringKey(); StringBuffer<> fileName; reader.ReadString(fileName); bool result = uba::DeleteFileW(fileName.data); u32 errorCode = GetLastError(); if (result) RegisterDeleteFile(fileNameKey, fileName); writer.WriteBool(result); writer.WriteU32(errorCode); return true; } bool SessionServer::HandleCopyFile(const ConnectionInfo& connectionInfo, const WorkContext& workContext, BinaryReader& reader, BinaryWriter& writer) { StringKey fromNameKey = reader.ReadStringKey(); (void)fromNameKey; StringBuffer<> fromName; reader.ReadString(fromName); StringKey toNameKey = reader.ReadStringKey(); StringBuffer<> toName; reader.ReadString(toName); bool result = uba::CopyFileW(fromName.data, toName.data, false); u32 errorCode = GetLastError(); if (result) RegisterCreateFileForWrite(toNameKey, toName, true); writer.WriteU32(errorCode); return true; } bool SessionServer::HandleCreateDirectory(const ConnectionInfo& connectionInfo, const WorkContext& workContext, BinaryReader& reader, BinaryWriter& writer) { CreateDirectoryMessage msg; reader.ReadString(msg.name); CreateDirectoryResponse response; if (!Session::CreateDirectory(response, msg)) return false; writer.WriteBool(response.result); writer.WriteU32(response.errorCode); return true; } bool SessionServer::HandleRemoveDirectory(const ConnectionInfo& connectionInfo, const WorkContext& workContext, BinaryReader& reader, BinaryWriter& writer) { RemoveDirectoryMessage msg; reader.ReadString(msg.name); RemoveDirectoryResponse response; if (!Session::RemoveDirectory(response, msg)) return false; writer.WriteBool(response.result); writer.WriteU32(response.errorCode); return true; } bool SessionServer::HandleListDirectory(const ConnectionInfo& connectionInfo, const WorkContext& workContext, BinaryReader& reader, BinaryWriter& writer) { u32 sessionId = reader.ReadU32(); u32 sessionIndex = sessionId - 1; SCOPED_CRITICAL_SECTION(m_remoteProcessAndSessionLock, lock); if (sessionIndex >= m_clientSessions.size()) return m_logger.Error(TC("Got ListDirectory message from connection using bad sessionid (%u/%llu)"), sessionIndex, m_clientSessions.size()); ClientSession& session = *m_clientSessions[sessionIndex]; lock.Leave(); StringBuffer<> dirName; reader.ReadString(dirName); StringKey dirKey = reader.ReadStringKey(); ListDirectoryResponse out; GetListDirectoryInfo(out, dirName, dirKey); writer.WriteU32(out.tableOffset); WriteDirectoryTable(session, reader, writer); return true; } bool SessionServer::HandleGetDirectoriesFromServer(const ConnectionInfo& connectionInfo, const WorkContext& workContext, BinaryReader& reader, BinaryWriter& writer) { u32 sessionId = reader.ReadU32(); u32 sessionIndex = sessionId - 1; SCOPED_CRITICAL_SECTION(m_remoteProcessAndSessionLock, lock); if (sessionIndex >= m_clientSessions.size()) return m_logger.Error(TC("Got GetDirectories message from connection using bad sessionid (%u/%llu)"), sessionIndex, m_clientSessions.size()); ClientSession& session = *m_clientSessions[sessionIndex]; lock.Leave(); WriteDirectoryTable(session, reader, writer); return true; } bool SessionServer::HandleGetNameToHashFromServer(const ConnectionInfo& connectionInfo, const WorkContext& workContext, BinaryReader& reader, BinaryWriter& writer) { u32 requestedSize = reader.ReadU32(); SCOPED_READ_LOCK(m_nameToHashLookupLock, lock); if (requestedSize == ~0u) { requestedSize = u32(m_nameToHashTableMem.writtenSize); writer.WriteU32(requestedSize); } writer.WriteU64(GetTime()); lock.Leave(); WriteNameToHashTable(reader, writer, requestedSize); return true; } bool SessionServer::HandleProcessAvailable(const ConnectionInfo& connectionInfo, const WorkContext& workContext, BinaryReader& reader, BinaryWriter& writer) { u32 sessionId = reader.ReadU32(); u32 sessionIndex = sessionId - 1; SCOPED_CRITICAL_SECTION(m_remoteProcessAndSessionLock, sessionsLock); if (sessionIndex >= m_clientSessions.size()) return m_logger.Error(TC("Got ProcessAvailable message from connection using bad sessionid (%u/%llu)"), sessionIndex, m_clientSessions.size()); ClientSession& session = *m_clientSessions[sessionIndex]; sessionsLock.Leave(); bool isCrossArchitecture = IsArmBinary != session.isArm; u32 weight32 = reader.ReadU32(); float availableWeight = *(float*)&weight32; Vector processesWithKnownInputsToSend; float weightLeft = availableWeight; u32 addCount = 0; SCOPED_FUTEX(m_fillUpOneAtTheTimeLock, fillLock); // This is a lock to group files better (all clients connect at the same time); while (weightLeft > 0) { RemoteProcess* process = DequeueProcess(session, sessionId, connectionInfo.GetId()); if (process == nullptr) break; auto& startInfo = process->m_startInfo; StringBuffer<> applicationOverride; if (isCrossArchitecture) { auto returnProcess = [&]() { m_queuedRemoteProcesses.push_front({process}); if (m_remoteProcessReturnedEvent) m_remoteProcessReturnedEvent(*process); }; if (!process->m_allowCrossArchitecture) { // There is a risk cross architecture client dequeues another client's process that is not cross architecture // and in that case we just have to return it and try again later.. (break out) returnProcess(); break; } for (auto& mapping : m_crossArchitectureMappings) if (StartsWith(startInfo.application, mapping.from.c_str())) { applicationOverride.Append(mapping.to).Append(startInfo.application + mapping.from.size()); break; } if (applicationOverride.count) { //m_logger.Info(TC("Couldn't find cross architecture mapping for %s"), startInfo.application); //returnProcess(); //break; if (!FileExists(m_logger, applicationOverride.data)) { m_logger.Info(TC("Couldn't find cross architecture executable %s"), applicationOverride.data); returnProcess(); break; } } } ProcessAdded(*process, sessionId); writer.WriteU32(process->m_processId); startInfo.Write(writer, applicationOverride); if (process->m_knownInputsDone.IsCreated()) processesWithKnownInputsToSend.push_back(process); ++addCount; if (writer.GetCapacityLeft() < 5000) // Arbitrary number to cover all parameters above break; weightLeft -= startInfo.weight; } fillLock.Leave(); u32 neededDirectoryTableSize = GetDirectoryTableSize(); u32 neededHashTableSize; { SCOPED_READ_LOCK(m_nameToHashLookupLock, l); neededHashTableSize = u32(m_nameToHashTableMem.writtenSize); } sessionsLock.Enter(); //if (addCount) // m_logger.Debug(TC("Gave %u processes to %s using up %.1f weight out of %.1f available"), addCount, session.name.c_str(), availableWeight - weightLeft, availableWeight); bool remoteExecutionEnabled = m_remoteExecutionEnabled || !m_queuedRemoteProcesses.empty(); if (!remoteExecutionEnabled) { if (session.enabled) m_availableRemoteSlotCount -= session.processSlotCount - session.usedSlotCount; session.enabled = false; m_logger.Detail(TC("Disable remote execution on %s because remote execution has been disabled and queue is empty (will finish %u processes)"), session.name.c_str(), session.usedSlotCount); } // If this client session has 0 active processes and m_maxRemoteProcessCount < total available compute - client session, then we can disconnect this client if (remoteExecutionEnabled && !addCount && m_maxRemoteProcessCount != ~0u) { if (!session.dedicated && !session.usedSlotCount) { if (m_maxRemoteProcessCount < m_availableRemoteSlotCount - session.processSlotCount) { if (session.enabled) m_availableRemoteSlotCount -= session.processSlotCount - session.usedSlotCount; session.enabled = false; remoteExecutionEnabled = false; m_logger.Info(TC("Disable remote execution on %s because host session has enough help (%u left and %u remote slots)"), session.name.c_str(), m_maxRemoteProcessCount.load(), m_availableRemoteSlotCount); } } } sessionsLock.Leave(); writer.WriteU32(remoteExecutionEnabled ? SessionProcessAvailableResponse_None : SessionProcessAvailableResponse_RemoteExecutionDisabled); // Write in the needed dir and hash table offset to be up-to-date (to potentially avoid additional messages from client writer.WriteU32(neededDirectoryTableSize); writer.WriteU32(neededHashTableSize); // Collect known inputs Vector knownInputsToSend; for (auto process : processesWithKnownInputsToSend) if (process->m_knownInputsDone.IsSet(50*1000)) for (auto kiIt = process->m_knownInputs, kiEnd = kiIt + process->m_knownInputsCount; kiIt!=kiEnd; ++kiIt) if (session.sentKeys.insert(kiIt->key).second) knownInputsToSend.push_back(kiIt); // Send caskeys of known inputs so client can start retrieving them straight away u32 kiCapacity = u32(writer.GetCapacityLeft() - sizeof(u32)) / sizeof(RemoteProcess::KnownInput); u32 toSendCount = Min(kiCapacity, u32(knownInputsToSend.size())); writer.WriteU32(toSendCount); for (auto kv : knownInputsToSend) { if (!toSendCount--) break; writer.WriteCasKey(kv->key); writer.WriteU32(kv->mappingAlignment); writer.WriteBool(kv->allowProxy); } return true; } bool SessionServer::HandleProcessInputs(const ConnectionInfo& connectionInfo, const WorkContext& workContext, BinaryReader& reader, BinaryWriter& writer) { u32 processId = u32(reader.Read7BitEncoded()); ProcessHandle h(GetProcess(processId)); if (!h.IsValid()) { m_logger.Info(TC("Failed to find process for id %u when receiving custom message"), processId); return false; } auto& process = *(RemoteProcess*)h.m_process; auto& inputs = process.m_trackedInputs; u64 size = inputs.size(); if (u64 addCapacity = reader.Read7BitEncoded()) inputs.reserve(size + addCapacity); u64 toRead = reader.GetLeft(); inputs.resize(size + toRead); reader.ReadBytes(inputs.data() + size, toRead); return true; } bool SessionServer::HandleProcessFinished(const ConnectionInfo& connectionInfo, const WorkContext& workContext, BinaryReader& reader, BinaryWriter& writer) { u32 processId = reader.ReadU32(); ProcessHandle h = ProcessRemoved(processId); if (!h.m_process) { m_logger.Info(TC("Client finished process with id %u that is not found on server"), processId); return false; } auto& process = *(RemoteProcess*)h.m_process; SCOPED_CRITICAL_SECTION(m_remoteProcessAndSessionLock, cs2); if (!m_activeRemoteProcesses.erase(&process)) { cs2.Leave(); m_logger.Info(TC("Got finished process but process was not in active remote processes. Was there a disconnect happening directly after but executed before?")); return false; } u32 sessionIndex = process.m_sessionId - 1; if (sessionIndex >= m_clientSessions.size()) return m_logger.Error(TC("Got ProcessFinished message from connection using bad sessionid (%u/%llu)"), sessionIndex, m_clientSessions.size()); auto& session = *m_clientSessions[sessionIndex]; ++m_finishedRemoteProcessCount; --session.usedSlotCount; if (session.enabled) ++m_availableRemoteSlotCount; process.m_clientId = ~0u; cs2.Leave(); u32 exitCode = reader.ReadU32(); u32 logLineCount = reader.ReadU32(); process.m_exitCode = exitCode; process.m_logLines.reserve(logLineCount); while (logLineCount-- != 0) { TString text = reader.ReadString(); LogEntryType type = LogEntryType(reader.ReadByte()); process.m_logLines.push_back({ std::move(text), type }); } if (auto func = process.m_startInfo.logLineFunc) for (auto& line : process.m_logLines) func(process.m_startInfo.logLineUserData, line.text.c_str(), u32(line.text.size()), line.type); u32 id = process.m_processId; Vector emptyLines; auto& logLines = (exitCode != 0 || m_detailedTrace) ? process.m_logLines : emptyLines; m_trace.ProcessExited(id, exitCode, reader.GetPositionData(), reader.GetLeft(), logLines); ProcessStats processStats; processStats.Read(reader, ~0u); //SessionStats sessionStats; //sessionStats.Read(reader); //StorageStats storageStats; //storageStats.Read(reader); process.m_processorTime = processStats.cpuTime; process.m_wallTime = processStats.wallTime; process.m_server = nullptr; process.m_done.Set(); process.CallProcessExit(h); return true; } bool SessionServer::HandleProcessReturned(const ConnectionInfo& connectionInfo, const WorkContext& workContext, BinaryReader& reader, BinaryWriter& writer) { u32 processId = reader.ReadU32(); StringBuffer<> reason; reader.ReadString(reason); ProcessHandle h = ProcessRemoved(processId); RemoteProcess* process = (RemoteProcess*)h.m_process; if (!process) { m_logger.Warning(TC("Client %s returned process %u that is not found on server (%s)"), GuidToString(connectionInfo.GetUid()).str, processId, reason.data); return true; } SCOPED_CRITICAL_SECTION(m_remoteProcessAndSessionLock, cs2); if (!m_activeRemoteProcesses.erase(process)) { cs2.Leave(); m_logger.Warning(TC("Got returned process %u from client %s but process was not in active remote processes. Was there a disconnect happening directly after but executed before?"), processId, GuidToString(connectionInfo.GetUid()).str); return true; } u32 sessionIndex = process->m_sessionId - 1; if (sessionIndex >= m_clientSessions.size()) return m_logger.Error(TC("Got ProcessReturned message from connection using bad sessionid (%u/%llu)"), sessionIndex, m_clientSessions.size()); auto& session = *m_clientSessions[sessionIndex]; --session.usedSlotCount; if (session.enabled) ++m_availableRemoteSlotCount; m_logger.Detail(TC("Client %s returned process %u to queue (%s)"), session.name.c_str(), processId, reason.data); ++m_returnedRemoteProcessCount; process->m_executingHost.clear(); process->m_clientId = ~0u; process->m_sessionId = 0; m_trace.ProcessReturned(process->m_processId, reason); m_queuedRemoteProcesses.push_front(h); if (m_remoteProcessReturnedEvent) m_remoteProcessReturnedEvent(*process); return true; } bool SessionServer::HandleGetRoots(const ConnectionInfo& connectionInfo, const WorkContext& workContext, BinaryReader& reader, BinaryWriter& writer) { RootsHandle rootsHandle = reader.ReadU64(); auto rootsEntry = GetRootsEntry(rootsHandle); if (!rootsEntry) return false; writer.WriteBytes(rootsEntry->memory.data(), rootsEntry->memory.size()); return true; } bool SessionServer::HandleVirtualAllocFailed(const ConnectionInfo& connectionInfo, const WorkContext& workContext, BinaryReader& reader, BinaryWriter& writer) { m_logger.Error(TC("VIRTUAL ALLOC FAILING ON REMOTE MACHINE %s !"), GuidToString(connectionInfo.GetUid()).str); return true; } bool SessionServer::HandleGetTraceInformation(const ConnectionInfo& connectionInfo, const WorkContext& workContext, BinaryReader& reader, BinaryWriter& writer) { u32 remotePos = reader.ReadU32(); u32 localPos; { SCOPED_FUTEX_READ(m_trace.m_memoryLock, l); localPos = u32(m_trace.m_memoryPos); } writer.WriteU32(localPos); u32 toWrite = Min(localPos - remotePos, u32(writer.GetCapacityLeft())); writer.WriteBytes(m_trace.m_memoryBegin + remotePos, toWrite); return true; } bool SessionServer::HandlePing(const ConnectionInfo& connectionInfo, const WorkContext& workContext, BinaryReader& reader, BinaryWriter& writer) { LOG_STALL_SCOPE(m_logger, 5, TC("HandlePing took more than %s")); u32 sessionId = reader.ReadU32(); u64 lastPing = reader.ReadU64(); u64 memAvail = reader.ReadU64(); u64 memTotal = reader.ReadU64(); u32 cpuLoadValue = reader.ReadU32(); u64 pingTime = GetTime(); u32 sessionIndex = sessionId - 1; SCOPED_CRITICAL_SECTION(m_remoteProcessAndSessionLock, lock); if (sessionIndex >= m_clientSessions.size()) return m_logger.Error(TC("Got Pingmessage from connection using bad sessionid (%u/%llu)"), sessionIndex, m_clientSessions.size()); auto& session = *m_clientSessions[sessionIndex]; session.pingTime = pingTime; session.lastPing = lastPing; session.memAvail = memAvail; session.memTotal = memTotal; session.cpuLoad = *(float*)&cpuLoadValue; writer.WriteBool(session.abort); writer.WriteBool(session.crashdump); session.crashdump = false; return true; } bool SessionServer::HandleNotification(const ConnectionInfo& connectionInfo, const WorkContext& workContext, BinaryReader& reader, BinaryWriter& writer) { u32 sessionId = reader.ReadU32(); u32 sessionIndex = sessionId - 1; { SCOPED_CRITICAL_SECTION(m_remoteProcessAndSessionLock, lock); if (sessionIndex < m_clientSessions.size()) m_clientSessions[sessionIndex]->hasNotification = true; } StringBuffer<1024> str; reader.ReadString(str); m_trace.SessionNotification(sessionId, str.data); return true; } bool SessionServer::HandleGetNextProcess(const ConnectionInfo& connectionInfo, const WorkContext& workContext, BinaryReader& reader, BinaryWriter& writer) { u32 processId = reader.ReadU32(); u32 prevExitCode = reader.ReadU32(); ProcessHandle h(GetProcess(processId)); if (!h.IsValid()) { m_logger.Info(TC("Failed to find process for id %u when receiving GetNextProcess message"), processId); return false; } auto& remoteProcess = *(RemoteProcess*)h.m_process; SCOPED_FUTEX(remoteProcess.m_exitedLock, exitedLock); NextProcessInfo nextProcess; bool newProcess; remoteProcess.m_exitCode = prevExitCode; remoteProcess.m_done.Set(); bool success = GetNextProcess(remoteProcess, newProcess, nextProcess, prevExitCode, reader); remoteProcess.m_exitCode = ~0u; remoteProcess.m_done.Reset(); if (!success) return false; writer.WriteBool(newProcess); if (newProcess) { writer.WriteString(nextProcess.arguments); writer.WriteString(nextProcess.workingDir); writer.WriteString(nextProcess.description); writer.WriteString(nextProcess.logFile); } return true; } bool SessionServer::HandleCustom(const ConnectionInfo& connectionInfo, const WorkContext& workContext, BinaryReader& reader, BinaryWriter& writer) { u32 processId = reader.ReadU32(); ProcessHandle h(GetProcess(processId)); if (!h.IsValid()) { m_logger.Info(TC("Failed to find process for id %u when receiving custom message"), processId); return false; } auto& remoteProcess = *(RemoteProcess*)h.m_process; SCOPED_FUTEX(remoteProcess.m_exitedLock, exitedLock); CustomMessage(remoteProcess, reader, writer); return true; } bool SessionServer::HandleUpdateEnvironment(const ConnectionInfo& connectionInfo, const WorkContext& workContext, BinaryReader& reader, BinaryWriter& writer) { u32 processId = reader.ReadU32(); ProcessHandle h(GetProcess(processId)); if (!h.IsValid()) { m_logger.Info(TC("Failed to find process for id %u when receiving update environment message"), processId); return false; } StringBuffer<> reason; reader.ReadString(reason); m_trace.ProcessEnvironmentUpdated(processId, reason, reader.GetPositionData(), reader.GetLeft(), ToView(h.GetStartInfo().breadcrumbs)); return true; } bool SessionServer::HandleSummary(const ConnectionInfo& connectionInfo, const WorkContext& workContext, BinaryReader& reader, BinaryWriter& writer) { u32 sessionId = reader.ReadU32(); m_trace.SessionSummary(sessionId, reader.GetPositionData(), reader.GetLeft()); return true; } bool SessionServer::HandleCommand(const ConnectionInfo& connectionInfo, const WorkContext& workContext, BinaryReader& reader, BinaryWriter& writer) { StringBuffer<128> command; reader.ReadString(command); auto WriteString = [&](const tchar* str, LogEntryType type = LogEntryType_Info) { writer.WriteByte(type); writer.WriteString(str); }; if (command.Equals(TCV("status"))) { u32 totalUsed = 0; u32 totalSlots = 0; SCOPED_CRITICAL_SECTION(m_remoteProcessAndSessionLock, queueLock); u64 time = GetTime(); for (auto& s : m_clientSessions) { if (!s->enabled) continue; WriteString(StringBuffer<>().Appendf(TC("Session %u (%s)"), s->clientId, s->name.c_str()).data); WriteString(StringBuffer<>().Appendf(TC(" Process slots used %u/%u"), s->usedSlotCount, s->processSlotCount).data); if (s->pingTime) WriteString(StringBuffer<>().Appendf(TC(" Last ping %s ago"), TimeToText(time - s->pingTime).str).data); totalUsed += s->usedSlotCount; totalSlots += s->processSlotCount; } WriteString(StringBuffer<>().Appendf(TC("Total remote slots used %u/%u"), totalUsed, totalSlots).data); } if (command.Equals(TCV("crashdump"))) { WriteString(TC("Requesting crashdumps from all remotes on next ping")); SCOPED_CRITICAL_SECTION(m_remoteProcessAndSessionLock, queueLock); for (auto& s : m_clientSessions) s->crashdump = true; } else if (command.StartsWith(TC("abort"))) { bool abortWithProxy = command.Equals(TCV("abortproxy")); bool abortUseProxy = command.Equals(TCV("abortnonproxy")); if (!abortWithProxy && !abortUseProxy) { abortWithProxy = true; abortUseProxy = true; } SCOPED_CRITICAL_SECTION(m_remoteProcessAndSessionLock, queueLock); u32 abortCount = 0; for (auto& s : m_clientSessions) { if (!s->enabled || s->abort) continue; bool hasProxy = m_storage.HasProxy(s->clientId); if (abortWithProxy && hasProxy) s->abort = true; else if (abortUseProxy && !hasProxy) s->abort = true; if (s->abort) ++abortCount; } WriteString(StringBuffer<>().Appendf(TC("Aborting: %u remote sessions"), abortCount).data); } else if (command.Equals(TCV("disableremote"))) { DisableRemoteExecution(); WriteString(StringBuffer<>().Appendf(TC("Remote execution is disabled")).data); } else { WriteString(StringBuffer<>().Appendf(TC("Unknown command: %s"), command.data).data, LogEntryType_Error); } writer.WriteByte(255); return true; } bool SessionServer::HandleSHGetKnownFolderPath(const ConnectionInfo& connectionInfo, const WorkContext& workContext, BinaryReader& reader, BinaryWriter& writer) { #if PLATFORM_WINDOWS GUID kfid; reader.ReadBytes(&kfid, sizeof(GUID)); u32 flags = reader.ReadU32(); PWSTR str; static HMODULE moduleHandle = LoadLibrary(L"Shell32.dll"); using SHGetKnownFolderPathFunc = HRESULT(const GUID& rfid, DWORD dwFlags, HANDLE hToken, PWSTR* ppszPath); static SHGetKnownFolderPathFunc* SHGetKnownFolderPath = (SHGetKnownFolderPathFunc*)GetProcAddress(moduleHandle, "SHGetKnownFolderPath");; HRESULT res = SHGetKnownFolderPath(kfid, flags, NULL, &str); writer.WriteU32(res); if (res == S_OK) { writer.WriteString(str); CoTaskMemFree(str); } #endif return true; } bool SessionServer::StoreCasFile(CasKey& out, const StringKey& fileNameKey, const tchar* fileName) { CasKey casKeyOverride = CasKeyZero; bool deferCreation = true; { SCOPED_FUTEX(m_customCasKeysLock, lock); auto findIt = m_customCasKeys.find(fileNameKey); if (findIt != m_customCasKeys.end()) { CustomCasKey& customKey = findIt->second; if (customKey.casKey == CasKeyZero) { if (!GetCasKeyFromTrackedInputs(customKey.casKey, fileName, customKey.workingDir.c_str(), customKey.trackedInputs.data(), u32(customKey.trackedInputs.size()))) return false; UBA_ASSERTF(customKey.casKey != CasKeyZero, TC("This should never happen!!")); //m_logger.Debug(TC("Calculated custom key: %s (%s)"), GuidToString(customKey.casKey).str, fileName); } casKeyOverride = customKey.casKey; } } if (!m_storage.StoreCasFile(out, fileName, casKeyOverride, deferCreation)) // We can defer the creation of the cas file since client might already have it return false; return true;//out != CasKeyZero; } bool SessionServer::WriteDirectoryTable(ClientSession& session, BinaryReader& reader, BinaryWriter& writer) { auto& dirTable = m_directoryTable; SCOPED_FUTEX(session.dirTablePosLock, lock2); //m_logger.Info(TC("WritePos: %llu"), session.dirTablePos); writer.WriteU32(session.dirTablePos); // We can figure out on the other side if everything was written based on if the message is full or not. u32 toSend = GetDirectoryTableSize() - session.dirTablePos; if (toSend == 0) return true; u32 capacityLeft = u32(writer.GetCapacityLeft()); if (capacityLeft < toSend) toSend = capacityLeft; writer.WriteBytes(dirTable.m_memory + session.dirTablePos, toSend); session.dirTablePos += toSend; return true; } bool SessionServer::WriteNameToHashTable(BinaryReader& reader, BinaryWriter& writer, u32 requestedSize) { u32 remoteTableSize = reader.ReadU32(); u32 toSend = requestedSize - remoteTableSize; if (toSend == 0) return true; u32 capacityLeft = u32(writer.GetCapacityLeft()); if (capacityLeft < toSend) toSend = capacityLeft; writer.WriteBytes(m_nameToHashTableMem.memory + remoteTableSize, toSend); return true; } void SessionServer::ThreadMemoryCheckLoop() { u64 lastMessageTime = 0; while (true) { if (m_memoryThreadEvent.IsSet(1000)) break; #if 0 SCOPED_CRITICAL_SECTION(m_remoteProcessAndSessionLock, queueLock); m_logger.Info(TC("RemoteQueue: %llu Active: %llu ConnectionCount: %u"), m_queuedRemoteProcesses.size(), m_activeRemoteProcesses.size(), m_connectionCount); if (!m_connectionCount && !m_activeRemoteProcesses.empty()) { for (auto& i : m_activeRemoteProcesses) { m_logger.Info(TC("ACTIVE PROCESS: %s"), i.GetStartInfo().GetDescription()); } break; } #endif u64 memAvail; u64 memTotal; if (!GetMemoryInfo(memAvail, memTotal)) m_memRequiredToSpawn = 0; m_memAvail = memAvail; bool allGood = false; while (memAvail >= m_memRequiredToSpawn) { SCOPED_FUTEX(m_waitingProcessesLock, lock); WaitingProcess* wp = m_oldestWaitingProcess; if (!wp) { allGood = true; break; } m_oldestWaitingProcess = wp->next; if (m_newestWaitingProcess == wp) m_newestWaitingProcess = nullptr; wp->event.Set(); memAvail -= m_memRequiredToSpawn; } if (allGood) continue; u64 time = GetTime(); if (TimeToMs(time - lastMessageTime) > 5*1000) { lastMessageTime = time; u32 delayCount = 0; SCOPED_FUTEX(m_waitingProcessesLock, lock); for (auto it = m_oldestWaitingProcess; it; it = it->next) ++delayCount; lock.Leave(); if (delayCount) { m_logger.BeginScope(); m_logger.Info(TC("Delaying %u processes from spawning due to memory pressure (Available: %s Total: %s)"), delayCount, BytesToText(m_memAvail).str, BytesToText(m_memTotal).str); #if PLATFORM_WINDOWS static bool hasBeenRunOnce; if (!hasBeenRunOnce) { hasBeenRunOnce = true; m_logger.Info(TC("NOTE - To mitigate this spawn delay it is recommended to make page file larger until you don't see these messages again (Or reduce number of max parallel processes)")); m_logger.Info(TC(" Set max page file to a large number (like 128gb). It will not use disk space unless you actually start using that amount of committed memory")); m_logger.Info(TC(" Also note, this is \"committed\" memory. Not memory in use. So you necessarily don't need more physical memory")); MEMORYSTATUSEX memStatus = { sizeof(memStatus) }; GlobalMemoryStatusEx(&memStatus); m_logger.Info(TC(" MaxPage: %s"), BytesToText(m_maxPageSize)); m_logger.Info(TC(" TotalPhys: %s"), BytesToText(memStatus.ullTotalPhys)); m_logger.Info(TC(" AvailPhys: %s"), BytesToText(memStatus.ullAvailPhys)); m_logger.Info(TC(" TotalPage: %s"), BytesToText(memStatus.ullTotalPageFile)); m_logger.Info(TC(" AvailPage: %s"), BytesToText(memStatus.ullAvailPageFile)); } #endif m_logger.EndScope(); } } if (!m_allowKillOnMem) continue; // TODO: This code path is not implemented yet... the cancel need to end up in a Requeue call. UBA_ASSERT(false); u64 memRequiredFree = u64(double(memTotal) * double(100 - m_memKillLoadPercent) / 100.0); if (m_memAvail < memRequiredFree) { u64 newestTime = 0; ProcessImpl* newestProcess = nullptr; SCOPED_FUTEX(m_processesLock, lock); for (auto& kv : m_processes) { ProcessHandle& h = kv.second; if (h.IsRemote()) continue; auto& p = *(ProcessImpl*)h.m_process; if (p.m_startTime <= newestTime) continue; newestTime = p.m_startTime; newestProcess = &p; } if (newestProcess) { newestProcess->Cancel(true); newestProcess->WaitForExit(3000); } m_logger.Info(TC("Killed process due to memory pressure (Available: %s Total: %s)"), BytesToText(m_memAvail).str, BytesToText(m_memTotal).str); } } SCOPED_FUTEX(m_waitingProcessesLock, lock); for (auto it = m_oldestWaitingProcess; it; it = it->next) it->event.Set(); m_oldestWaitingProcess = nullptr; m_newestWaitingProcess = nullptr; } SessionServer::RemoteProcess* SessionServer::DequeueProcess(ClientSession& session, u32 sessionId, u32 clientId) { //TrackWorkScope tws(m_trace, AsView(TC("DequeueProcess")), ColorWork); SCOPED_READ_LOCK(m_remoteProcessSlotAvailableEventLock, lock); bool hasCalledCallback = !m_remoteProcessSlotAvailableEvent; while (true) { SCOPED_CRITICAL_SECTION(m_remoteProcessAndSessionLock, queueLock); if (!session.connected) // This should not be possible { m_logger.Warning(TC("Dequeing process to session that is not connected. This should never happen. Report to Epic (%u)"), clientId); return nullptr; } while (!m_queuedRemoteProcesses.empty()) { auto processHandle = m_queuedRemoteProcesses.front(); auto process = (RemoteProcess*)processHandle.m_process; m_queuedRemoteProcesses.pop_front(); if (process->m_cancelled) continue; if (session.enabled) --m_availableRemoteSlotCount; ++session.usedSlotCount; process->m_clientId = clientId; process->m_sessionId = sessionId; process->m_executingHost = session.name; UBA_ASSERT(!process->m_cancelled); m_activeRemoteProcesses.insert(process); return process; } queueLock.Leave(); if (hasCalledCallback) return nullptr; m_remoteProcessSlotAvailableEvent(IsArmBinary != session.isArm); hasCalledCallback = true; } return nullptr; } void SessionServer::OnCancelled(RemoteProcess* process) { ProcessHandle h(process); SCOPED_CRITICAL_SECTION(m_remoteProcessAndSessionLock, queueLock); process->m_server = nullptr; if (process->m_clientId == ~0u) { for (auto it=m_queuedRemoteProcesses.begin(); it!=m_queuedRemoteProcesses.end(); ++it) { if (it->m_process != process) continue; m_queuedRemoteProcesses.erase(it); break; } } else { u32 sessionIndex = process->m_sessionId - 1; UBA_ASSERT(sessionIndex < m_clientSessions.size()); ClientSession& session = *m_clientSessions[sessionIndex]; --session.usedSlotCount; m_activeRemoteProcesses.erase(process); { SCOPED_FUTEX(m_processesLock, lock); m_processes.erase(process->m_processId); } queueLock.Leave(); StackBinaryWriter<1024> writer; ProcessStats().Write(writer); SessionStats().Write(writer); StorageStats().Write(writer); KernelStats().Write(writer); m_trace.ProcessExited(process->m_processId, process->m_exitCode, writer.GetData(), writer.GetPosition(), Vector()); } process->m_done.Set(); } ProcessHandle SessionServer::ProcessRemoved(u32 processId) { SCOPED_FUTEX(m_processesLock, lock); auto findIt = m_processes.find(processId); if (findIt == m_processes.end()) return {}; ProcessHandle h(findIt->second); m_processes.erase(findIt); return h; } ProcessHandle SessionServer::GetProcess(u32 processId) { SCOPED_FUTEX_READ(m_processesLock, lock); auto findIt = m_processes.find(processId); if (findIt == m_processes.end()) return {}; return ProcessHandle(findIt->second); } TString SessionServer::GetProcessDescription(u32 processId) { StringBuffer<512> str; SCOPED_FUTEX_READ(m_processesLock, lock); auto findIt = m_processes.find(processId); if (findIt == m_processes.end()) return str.Appendf(TC(""), processId).data; return str.Appendf(TC("%s"), findIt->second.GetStartInfo().GetDescription()).data; } bool SessionServer::PrepareProcess(ProcessImpl& process, bool isChild, StringBufferBase& outRealApplication, const tchar*& outRealWorkingDir) { if (!Session::PrepareProcess(process, isChild, outRealApplication, outRealWorkingDir)) return false; if (!m_memTotal || !m_allowWaitOnMem || isChild) return true; if (m_memAvail >= m_memRequiredToSpawn) return true; u64 startWait = GetTime(); WaitingProcess wp; wp.event.Create(true); SCOPED_FUTEX(m_waitingProcessesLock, lock); if (m_memoryThreadEvent.IsSet(0)) return false; if (!m_oldestWaitingProcess) m_oldestWaitingProcess = ℘ else m_newestWaitingProcess->next = ℘ m_newestWaitingProcess = ℘ lock.Leave(); wp.event.IsSet(); u64 waitTime = GetTime() - startWait; m_logger.Info(TC("Waited %s for memory pressure to go down (Available: %s Total: %s)"), TimeToText(waitTime).str, BytesToText(m_memAvail).str, BytesToText(m_memTotal).str); return true; } bool SessionServer::CreateFile(CreateFileResponse& out, const CreateFileMessage& msg) { if (!m_shouldWriteToDisk && ((msg.access & FileAccess_Write) == 0)) { SCOPED_READ_LOCK(m_receivedFilesLock, lock); auto findIt = m_receivedFiles.find(msg.fileNameKey); if (findIt != m_receivedFiles.end()) { u64 memoryMapAlignment = GetMemoryMapAlignment(msg.fileName); if (!memoryMapAlignment) memoryMapAlignment = 4096; MemoryMap map; if (!CreateMemoryMapFromView(map, msg.fileNameKey, msg.fileName.data, findIt->second, memoryMapAlignment)) return false; out.directoryTableSize = GetDirectoryTableSize(); out.mappedFileTableSize = GetFileMappingSize(); out.fileName.Append(map.name); out.size = map.size; return true; } } return Session::CreateFile(out, msg); } void SessionServer::FileEntryAdded(StringKey fileNameKey, u64 lastWritten, u64 size) { SCOPED_WRITE_LOCK(m_nameToHashLookupLock, lock); if (!m_nameToHashInitialized) return; Storage::CachedFileInfo cachedInfo; if (!m_storage.VerifyAndGetCachedFileInfo(cachedInfo, fileNameKey, lastWritten, size)) if (m_nameToHashLookup.find(fileNameKey) == m_nameToHashLookup.end()) return; CasKey& lookupCasKey = m_nameToHashLookup[fileNameKey]; if (lookupCasKey == cachedInfo.casKey) return; #if 0 //m_debugLogger->Info(TC("NAMETOHASHADD %s %s\n"), KeyToString(fileNameKey).data, CasKeyString(lookupCasKey).str); #endif lookupCasKey = cachedInfo.casKey; BinaryWriter w(m_nameToHashTableMem.memory, m_nameToHashTableMem.writtenSize, NameToHashMemSize); m_nameToHashTableMem.AllocateNoLock(sizeof(StringKey) + sizeof(CasKey), 1, TC("NameToHashTable")); w.WriteStringKey(fileNameKey); w.WriteCasKey(lookupCasKey); } bool SessionServer::RunSpecialProgram(ProcessImpl& process, BinaryReader& reader, BinaryWriter& writer) { TString application = reader.ReadString(); TString cmdLine = reader.ReadLongString(); TString workingDir = reader.ReadString(); UBA_ASSERT(StringView(application).Contains(TCV("UbaCli.exe"))); StringBuffer<> jsonFile; ParseArguments(cmdLine.data(), cmdLine.size(), [&](const tchar* arg, u32 argLen) { StringView sv(arg, argLen); if (sv.Contains(TCV(".json"))) jsonFile.Append(workingDir).EnsureEndsWithSlash().Append(sv); }); if (jsonFile.IsEmpty()) return false; ProcessImpl* rootProcess = &process; while (rootProcess->m_parentProcess) rootProcess = rootProcess->m_parentProcess; auto& startInfo = rootProcess->GetStartInfo(); UBA_ASSERTF(m_outerScheduler, TC("No outer scheduler set")); return m_outerScheduler->EnqueueFromSpecialJson(jsonFile.data, workingDir.c_str(), TC("UbaDistributor"), startInfo.rootsHandle, startInfo.userData); } void SessionServer::PrintSessionStats(Logger& logger) { Session::PrintSessionStats(logger); if (m_nameToHashLookup.size()) logger.Info(TC(" NameToHashLookup %7u %9s"), u32(m_nameToHashLookup.size()), BytesToText(m_nameToHashTableMem.writtenSize).str); logger.Info(TC(" Remote processes finished %8u"), m_finishedRemoteProcessCount); logger.Info(TC(" Remote processes returned %8u"), m_returnedRemoteProcessCount); logger.Info(TC("")); } void SessionServer::TraceSessionUpdate() { u32 sessionIndex = 1; u64 serverSend = m_server.GetTotalSentBytes(); u64 serverRecv = m_server.GetTotalRecvBytes(); SCOPED_CRITICAL_SECTION(m_remoteProcessAndSessionLock, lock); for (auto sptr : m_clientSessions) { auto& s = *sptr; NetworkServer::ClientStats stats; m_server.GetClientStats(stats, s.clientId); if (stats.connectionCount && (stats.send || stats.recv)) m_trace.SessionUpdate(sessionIndex, stats.connectionCount, stats.send, stats.recv, s.lastPing, s.memAvail, s.memTotal, s.cpuLoad); ++sessionIndex; } if (m_provider) { u64 send; u64 recv; m_provider(send, recv); serverSend += send; serverRecv += recv; } lock.Leave(); float cpuLoad = UpdateCpuLoad(); u64 memAvail = m_memAvail; u64 memTotal = m_memTotal; if (m_traceIOEnabled) { for (auto& volume : m_volumeCache.volumes) { if (volume.drives.empty()) continue; u8 busyPercent; u32 readCount; u64 readBytes; u32 writeCount; u64 writeBytes; if (!volume.UpdateStats(busyPercent, readCount, readBytes, writeCount, writeBytes)) continue; if (!busyPercent && !readCount && !readBytes && !writeCount && !writeBytes) continue; m_trace.DriveUpdate(volume.drives[0], busyPercent, readCount, readBytes, writeCount, writeBytes); } } m_trace.SessionUpdate(0, 0, serverSend, serverRecv, 0, memAvail, memTotal, cpuLoad); } void SessionServer::WriteRemoteEnvironmentVariables(BinaryWriter& writer) { if (!m_remoteEnvironmentVariables.empty()) { writer.WriteBytes(m_remoteEnvironmentVariables.data(), m_remoteEnvironmentVariables.size()); return; } u64 startPos = writer.GetPosition(); #if PLATFORM_WINDOWS auto strs = GetEnvironmentStringsW(); auto freeStrs = MakeGuard([strs]() { FreeEnvironmentStringsW(strs); }); #else auto strs = (const char*)GetProcessEnvironmentVariables(); #endif for (auto it = strs; *it; it += TStrlen(it) + 1) { StringBuffer<> varName; varName.Append(it, TStrchr(it, '=') - it); if (!varName.IsEmpty() && !varName.Equals(TCV("CL")) && !varName.Equals(TCV("_CL_"))) if (m_localEnvironmentVariables.find(varName.data) == m_localEnvironmentVariables.end()) writer.WriteString(it); } writer.WriteString(TC("")); u64 size = writer.GetPosition() - startPos; m_remoteEnvironmentVariables.resize(size); memcpy(m_remoteEnvironmentVariables.data(), writer.GetData() + startPos, size); } bool SessionServer::InitializeNameToHashTable() { if (!m_nameToHashTableEnabled || m_nameToHashInitialized) return true; SCOPED_WRITE_LOCK(m_nameToHashLookupLock, lock); m_nameToHashTableMem.Init(NameToHashMemSize); m_nameToHashInitialized = true; lock.Leave(); auto& dirTable = m_directoryTable; { Vector dirs; SCOPED_READ_LOCK(dirTable.m_lookupLock, dirsLock); dirs.reserve(dirTable.m_lookup.size()); for (auto& kv : dirTable.m_lookup) dirs.push_back(&kv.second); dirsLock.Leave(); for (auto dirPtr : dirs) { DirectoryTable::Directory& dir = *dirPtr; SCOPED_READ_LOCK(dir.lock, dirLock); for (auto& fileKv : dir.files) { StringKey fileNameKey = fileKv.first; BinaryReader reader(dirTable.m_memory, fileKv.second); u64 lastWritten = reader.ReadU64(); u32 attr = reader.ReadU32(); if (IsDirectory(attr)) continue; reader.Skip(sizeof(u32) + sizeof(u64)); u64 size = reader.ReadU64(); FileEntryAdded(fileNameKey, lastWritten, size); } } } SCOPED_WRITE_LOCK(m_nameToHashLookupLock, lock2); u64 entryCount = m_nameToHashLookup.size(); lock2.Leave(); m_logger.Debug(TC("Prepopulated NameToHash table with %u entries"), entryCount); return true; } bool SessionServer::HandleDebugFileNotFoundError(const ConnectionInfo& connectionInfo, const WorkContext& workContext, BinaryReader& reader, BinaryWriter& writer) { #if PLATFORM_WINDOWS StringBuffer<> errorPath; reader.ReadString(errorPath); StringBuffer<> workDir; reader.ReadString(workDir); StringView searchString = errorPath; if (searchString.data[0] == '.' && searchString.data[1] == '.') { searchString.data += 3; searchString.count -= 3; } auto LogLine = [&](const StringView& text) { m_logger.Log(LogEntryType_Warning, text.data, text.count); }; // Make a copy of dir table since we can't populate it MemoryBlock block(64*1024*1024); DirectoryTable dirTable(block); u8* dirMem; u32 dirMemSize; { SCOPED_READ_LOCK(m_directoryTable.m_memoryLock, lock); dirMem = m_directoryTableMem; dirMemSize = m_directoryTable.m_memorySize; } dirTable.Init(dirMem, 0, dirMemSize); u32 foundCount = 0; dirTable.TraverseAllFilesNoLock([&](const DirectoryTable::EntryInformation& info, const StringBufferBase& path, u32 dirOffset) { if (!path.EndsWith(searchString)) return; if (path[path.count - searchString.count - 1] != PathSeparator) return; auto ToString = [](bool b) { return b ? TC("true") : TC("false"); }; ++foundCount; StringBuffer<> logStr; logStr.Appendf(TC("File %s found in directory table at offset %u of %u while searching for matches for %s (File size %llu attr %u)"), path.data, dirOffset, dirTable.m_memorySize, searchString.data, info.size, info.attributes); LogLine(logStr); StringKey fileNameKey = ToStringKey(path); { SCOPED_FUTEX_READ(m_fileMappingTableLookupLock, mlock); auto findIt = m_fileMappingTableLookup.find(fileNameKey); if (findIt != m_fileMappingTableLookup.end()) { auto& entry = findIt->second; SCOPED_FUTEX_READ(entry.lock, entryCs); logStr.Clear().Appendf(TC("File %s found in mapping table table."), path.data); if (entry.handled) { StringBuffer<128> mappingName; if (entry.mapping.IsValid()) Storage::GetMappingString(mappingName, entry.mapping, entry.mappingOffset); else mappingName.Append(TCV("Not valid")); logStr.Appendf(TC(" Success: %s Size: %u IsDir: %s Mapping name: %s Mapping offset: %u"), ToString(entry.success), entry.size, ToString(entry.isDir), mappingName.data, entry.mappingOffset); } else { logStr.Appendf(TC(" Entry not handled")); } } else logStr.Clear().Appendf(TC("File %s not found in mapping table table."), path.data); LogLine(logStr); } { SCOPED_READ_LOCK(m_nameToHashLookupLock, hlock); auto findIt = m_nameToHashLookup.find(fileNameKey); if (findIt != m_nameToHashLookup.end()) logStr.Clear().Appendf(TC("File %s found in name-to-hash lookup. CasKey is %s"), path.data, CasKeyString(findIt->second).str); else logStr.Clear().Appendf(TC("File %s not found in name-to-hash lookup"), path.data); LogLine(logStr); } }); if (!foundCount) { StringBuffer<> logStr; logStr.Appendf(TC("No matching entry found in directory table while searching for matches for %s. DirTable size: %u"), searchString.data, GetDirectoryTableSize()); LogLine(logStr); /* if (errorPath.StartsWith(TC("..\\Intermediate"))) { StringBuffer<> fullPath; FixPath(errorPath.data, workDir.data, workDir.count, fullPath); GetFileAttributes } */ } #endif return true; } bool SessionServer::HandleHostRun(const ConnectionInfo& connectionInfo, const WorkContext& workContext, BinaryReader& reader, BinaryWriter& writer) { return HostRun(reader, writer); } bool SessionServer::HandleGetSymbols(const ConnectionInfo& connectionInfo, const WorkContext& workContext, BinaryReader& reader, BinaryWriter& writer) { TString application = reader.ReadString(); bool isClientArm = reader.ReadBool(); if (IsArmBinary != isClientArm) { writer.WriteString(TC("Can't resolve callstack on cross architectures")); return true; } GetSymbols(application.c_str(), isClientArm, reader, writer); if constexpr (DownloadDebugSymbols) { CasKey detoursSymbolsKey; StringBuffer<> dir; if (GetDirectoryOfCurrentModule(m_logger, dir)) { bool deferCreation = true; auto ChangeToSymbolExtension = [](StringBufferBase& str) -> StringBufferBase& { IsWindows ? str.Resize(str.count - 3).Append("pdb") : str.Resize(str.count - 2).Append("debug"); return str; }; if (!m_storage.StoreCasFile(detoursSymbolsKey, ChangeToSymbolExtension(dir).data, CasKeyZero, deferCreation) || detoursSymbolsKey == CasKeyZero) { StringBuffer<> dir2; if (GetAlternativeUbaPath(m_logger, dir2, dir, IsWindows && isClientArm)) m_storage.StoreCasFile(detoursSymbolsKey, ChangeToSymbolExtension(dir2.Append(UBA_DETOURS_LIBRARY)).data, CasKeyZero, deferCreation); } } writer.WriteCasKey(detoursSymbolsKey); } return true; } }