// Copyright Epic Games, Inc. All Rights Reserved. #include "LearningTrainer.h" #include "LearningObservation.h" #include "LearningAction.h" #include "HAL/Platform.h" #include "Dom/JsonObject.h" #include "Misc/Paths.h" namespace UE::Learning { FSubprocess::~FSubprocess() { Terminate(); } bool FSubprocess::Launch(const FString& Path, const FString& Params, const ESubprocessFlags Flags) { ensureMsgf(!bIsLaunched, TEXT("Subprocess already launched.")); Terminate(); const bool bCreatePipes = !(Flags & ESubprocessFlags::NoRedirectOutput); const bool bHideWindow = !(Flags & ESubprocessFlags::ShowWindow); if (bCreatePipes && !FPlatformProcess::CreatePipe(ReadPipe, WritePipe)) { return false; } ProcessHandle = FPlatformProcess::CreateProc(*Path, *Params, false, bHideWindow, bHideWindow, nullptr, 0, *FPaths::RootDir(), WritePipe, ReadPipe); bIsLaunched = true; return true; } bool FSubprocess::IsRunning() const { return bIsLaunched && FPlatformProcess::IsProcRunning(const_cast(ProcessHandle)); } void FSubprocess::Terminate() { if (IsRunning()) { UE_LOG(LogLearning, Display, TEXT("Terminating Subprocess...")); FPlatformProcess::TerminateProc(ProcessHandle, true); } Update(); } bool FSubprocess::Update() { // Do nothing if the process is not launched if (!bIsLaunched) { return false; } // Append the process stdout to the buffer OutputBuffer += FPlatformProcess::ReadPipe(ReadPipe); // Output all the complete lines int32 LineStartIdx = 0; for (int32 Idx = 0; Idx < OutputBuffer.Len(); Idx++) { if (OutputBuffer[Idx] == '\r' || OutputBuffer[Idx] == '\n') { UE_LOG(LogLearning, Display, TEXT("Subprocess: %s"), *OutputBuffer.Mid(LineStartIdx, Idx - LineStartIdx)); if (OutputBuffer[Idx] == '\r' && Idx + 1 < OutputBuffer.Len() && OutputBuffer[Idx + 1] == '\n') { Idx++; } LineStartIdx = Idx + 1; } } // Remove all the complete lines from the buffer OutputBuffer.MidInline(LineStartIdx, MAX_int32, EAllowShrinking::Yes); // If the process is no longer running then close the pipes if (!IsRunning()) { FPlatformProcess::ClosePipe(ReadPipe, WritePipe); ReadPipe = nullptr; WritePipe = nullptr; bIsLaunched = false; return false; } return true; } } namespace UE::Learning::Trainer { TSharedPtr ConvertObservationSchemaToJSON( const Observation::FSchema& ObservationSchema, const Observation::FSchemaElement& ObservationSchemaElement) { TSharedPtr Object = MakeShared(); Object->SetNumberField(TEXT("VectorSize"), ObservationSchema.GetObservationVectorSize(ObservationSchemaElement)); Object->SetNumberField(TEXT("EncodedSize"), ObservationSchema.GetEncodedVectorSize(ObservationSchemaElement)); switch (ObservationSchema.GetType(ObservationSchemaElement)) { case Observation::EType::Null: { Object->SetStringField(TEXT("Type"), TEXT("Null")); break; } case Observation::EType::Continuous: { const Observation::FSchemaContinuousParameters Parameters = ObservationSchema.GetContinuous(ObservationSchemaElement); Object->SetStringField(TEXT("Type"), TEXT("Continuous")); Object->SetNumberField(TEXT("Num"), Parameters.Num); break; } case Observation::EType::DiscreteExclusive: { const Observation::FSchemaDiscreteExclusiveParameters Parameters = ObservationSchema.GetDiscreteExclusive(ObservationSchemaElement); Object->SetStringField(TEXT("Type"), TEXT("DiscreteExclusive")); Object->SetNumberField(TEXT("Num"), Parameters.Num); break; } case Observation::EType::DiscreteInclusive: { const Observation::FSchemaDiscreteInclusiveParameters Parameters = ObservationSchema.GetDiscreteInclusive(ObservationSchemaElement); Object->SetStringField(TEXT("Type"), TEXT("DiscreteInclusive")); Object->SetNumberField(TEXT("Num"), Parameters.Num); break; } case Observation::EType::NamedDiscreteExclusive: { const Observation::FSchemaNamedDiscreteExclusiveParameters Parameters = ObservationSchema.GetNamedDiscreteExclusive(ObservationSchemaElement); Object->SetStringField(TEXT("Type"), TEXT("NamedDiscreteExclusive")); TArray> ElementNames; for (int32 ElementIdx = 0; ElementIdx < Parameters.ElementNames.Num(); ElementIdx++) { ElementNames.Add(MakeShared(Parameters.ElementNames[ElementIdx].ToString())); } Object->SetArrayField(TEXT("ElementNames"), ElementNames); break; } case Observation::EType::NamedDiscreteInclusive: { const Observation::FSchemaNamedDiscreteInclusiveParameters Parameters = ObservationSchema.GetNamedDiscreteInclusive(ObservationSchemaElement); Object->SetStringField(TEXT("Type"), TEXT("NamedDiscreteInclusive")); TArray> ElementNames; for (int32 ElementIdx = 0; ElementIdx < Parameters.ElementNames.Num(); ElementIdx++) { ElementNames.Add(MakeShared(Parameters.ElementNames[ElementIdx].ToString())); } Object->SetArrayField(TEXT("ElementNames"), ElementNames); break; } case Observation::EType::And: { const Observation::FSchemaAndParameters Parameters = ObservationSchema.GetAnd(ObservationSchemaElement); Object->SetStringField(TEXT("Type"), TEXT("And")); TSharedPtr SubObject = MakeShared(); for (int32 SubElementIdx = 0; SubElementIdx < Parameters.Elements.Num(); SubElementIdx++) { TSharedPtr SubElement = ConvertObservationSchemaToJSON(ObservationSchema, Parameters.Elements[SubElementIdx]); SubElement->SetNumberField(TEXT("Index"), SubElementIdx); SubObject->SetObjectField(Parameters.ElementNames[SubElementIdx].ToString(), SubElement); } Object->SetObjectField(TEXT("Elements"), SubObject); break; } case Observation::EType::OrExclusive: { const Observation::FSchemaOrExclusiveParameters Parameters = ObservationSchema.GetOrExclusive(ObservationSchemaElement); Object->SetStringField(TEXT("Type"), TEXT("OrExclusive")); Object->SetNumberField(TEXT("EncodingSize"), Parameters.EncodingSize); TSharedPtr SubObject = MakeShared(); for (int32 SubElementIdx = 0; SubElementIdx < Parameters.Elements.Num(); SubElementIdx++) { TSharedPtr SubElement = ConvertObservationSchemaToJSON(ObservationSchema, Parameters.Elements[SubElementIdx]); SubElement->SetNumberField(TEXT("Index"), SubElementIdx); SubObject->SetObjectField(Parameters.ElementNames[SubElementIdx].ToString(), SubElement); } Object->SetObjectField(TEXT("Elements"), SubObject); break; } case Observation::EType::OrInclusive: { const Observation::FSchemaOrInclusiveParameters Parameters = ObservationSchema.GetOrInclusive(ObservationSchemaElement); Object->SetStringField(TEXT("Type"), TEXT("OrInclusive")); Object->SetNumberField(TEXT("AttentionEncodingSize"), Parameters.AttentionEncodingSize); Object->SetNumberField(TEXT("AttentionHeadNum"), Parameters.AttentionHeadNum); Object->SetNumberField(TEXT("ValueEncodingSize"), Parameters.ValueEncodingSize); TSharedPtr SubObject = MakeShared(); for (int32 SubElementIdx = 0; SubElementIdx < Parameters.Elements.Num(); SubElementIdx++) { TSharedPtr SubElement = ConvertObservationSchemaToJSON(ObservationSchema, Parameters.Elements[SubElementIdx]); SubElement->SetNumberField(TEXT("Index"), SubElementIdx); SubObject->SetObjectField(Parameters.ElementNames[SubElementIdx].ToString(), SubElement); } Object->SetObjectField(TEXT("Elements"), SubObject); break; } case Observation::EType::Array: { const Observation::FSchemaArrayParameters Parameters = ObservationSchema.GetArray(ObservationSchemaElement); Object->SetStringField(TEXT("Type"), TEXT("Array")); Object->SetNumberField(TEXT("Num"), Parameters.Num); Object->SetObjectField(TEXT("Element"), ConvertObservationSchemaToJSON(ObservationSchema, Parameters.Element)); break; } case Observation::EType::Set: { const Observation::FSchemaSetParameters Parameters = ObservationSchema.GetSet(ObservationSchemaElement); Object->SetStringField(TEXT("Type"), TEXT("Set")); Object->SetNumberField(TEXT("MaxNum"), Parameters.MaxNum); Object->SetNumberField(TEXT("AttentionEncodingSize"), Parameters.AttentionEncodingSize); Object->SetNumberField(TEXT("AttentionHeadNum"), Parameters.AttentionHeadNum); Object->SetNumberField(TEXT("ValueEncodingSize"), Parameters.ValueEncodingSize); Object->SetObjectField(TEXT("Element"), ConvertObservationSchemaToJSON(ObservationSchema, Parameters.Element)); break; } case Observation::EType::Encoding: { const Observation::FSchemaEncodingParameters Parameters = ObservationSchema.GetEncoding(ObservationSchemaElement); Object->SetStringField(TEXT("Type"), TEXT("Encoding")); Object->SetNumberField(TEXT("EncodingSize"), Parameters.EncodingSize); Object->SetObjectField(TEXT("Element"), ConvertObservationSchemaToJSON(ObservationSchema, Parameters.Element)); break; } default: checkNoEntry(); } return Object; } TSharedPtr ConvertActionSchemaToJSON( const Action::FSchema& ActionSchema, const Action::FSchemaElement& ActionSchemaElement) { TSharedPtr Object = MakeShared(); Object->SetNumberField(TEXT("VectorSize"), ActionSchema.GetActionVectorSize(ActionSchemaElement)); Object->SetNumberField(TEXT("DistributionSize"), ActionSchema.GetActionDistributionVectorSize(ActionSchemaElement)); Object->SetNumberField(TEXT("EncodedSize"), ActionSchema.GetEncodedVectorSize(ActionSchemaElement)); Object->SetNumberField(TEXT("ModifierSize"), ActionSchema.GetActionModifierVectorSize(ActionSchemaElement)); switch (ActionSchema.GetType(ActionSchemaElement)) { case Action::EType::Null: { Object->SetStringField(TEXT("Type"), TEXT("Null")); break; } case Action::EType::Continuous: { const Action::FSchemaContinuousParameters Parameters = ActionSchema.GetContinuous(ActionSchemaElement); Object->SetStringField(TEXT("Type"), TEXT("Continuous")); Object->SetNumberField(TEXT("Num"), Parameters.Num); break; } case Action::EType::DiscreteExclusive: { const Action::FSchemaDiscreteExclusiveParameters Parameters = ActionSchema.GetDiscreteExclusive(ActionSchemaElement); Object->SetStringField(TEXT("Type"), TEXT("DiscreteExclusive")); Object->SetNumberField(TEXT("Num"), Parameters.Num); break; } case Action::EType::DiscreteInclusive: { const Action::FSchemaDiscreteInclusiveParameters Parameters = ActionSchema.GetDiscreteInclusive(ActionSchemaElement); Object->SetStringField(TEXT("Type"), TEXT("DiscreteInclusive")); Object->SetNumberField(TEXT("Num"), Parameters.Num); break; } case Action::EType::NamedDiscreteExclusive: { const Action::FSchemaNamedDiscreteExclusiveParameters Parameters = ActionSchema.GetNamedDiscreteExclusive(ActionSchemaElement); Object->SetStringField(TEXT("Type"), TEXT("NamedDiscreteExclusive")); TArray> ElementNames; for (int32 ElementIdx = 0; ElementIdx < Parameters.ElementNames.Num(); ElementIdx++) { ElementNames.Add(MakeShared(Parameters.ElementNames[ElementIdx].ToString())); } Object->SetArrayField(TEXT("ElementNames"), ElementNames); break; } case Action::EType::NamedDiscreteInclusive: { const Action::FSchemaNamedDiscreteInclusiveParameters Parameters = ActionSchema.GetNamedDiscreteInclusive(ActionSchemaElement); Object->SetStringField(TEXT("Type"), TEXT("NamedDiscreteInclusive")); TArray> ElementNames; for (int32 ElementIdx = 0; ElementIdx < Parameters.ElementNames.Num(); ElementIdx++) { ElementNames.Add(MakeShared(Parameters.ElementNames[ElementIdx].ToString())); } Object->SetArrayField(TEXT("ElementNames"), ElementNames); break; } case Action::EType::And: { const Action::FSchemaAndParameters Parameters = ActionSchema.GetAnd(ActionSchemaElement); Object->SetStringField(TEXT("Type"), TEXT("And")); TSharedPtr SubObject = MakeShared(); for (int32 SubElementIdx = 0; SubElementIdx < Parameters.Elements.Num(); SubElementIdx++) { TSharedPtr SubElement = ConvertActionSchemaToJSON(ActionSchema, Parameters.Elements[SubElementIdx]); SubElement->SetNumberField(TEXT("Index"), SubElementIdx); SubObject->SetObjectField(Parameters.ElementNames[SubElementIdx].ToString(), SubElement); } Object->SetObjectField(TEXT("Elements"), SubObject); break; } case Action::EType::OrExclusive: { const Action::FSchemaOrExclusiveParameters Parameters = ActionSchema.GetOrExclusive(ActionSchemaElement); Object->SetStringField(TEXT("Type"), TEXT("OrExclusive")); TSharedPtr SubObject = MakeShared(); for (int32 SubElementIdx = 0; SubElementIdx < Parameters.Elements.Num(); SubElementIdx++) { TSharedPtr SubElement = ConvertActionSchemaToJSON(ActionSchema, Parameters.Elements[SubElementIdx]); SubElement->SetNumberField(TEXT("Index"), SubElementIdx); SubObject->SetObjectField(Parameters.ElementNames[SubElementIdx].ToString(), SubElement); } Object->SetObjectField(TEXT("Elements"), SubObject); break; } case Action::EType::OrInclusive: { const Action::FSchemaOrInclusiveParameters Parameters = ActionSchema.GetOrInclusive(ActionSchemaElement); Object->SetStringField(TEXT("Type"), TEXT("OrInclusive")); TSharedPtr SubObject = MakeShared(); for (int32 SubElementIdx = 0; SubElementIdx < Parameters.Elements.Num(); SubElementIdx++) { TSharedPtr SubElement = ConvertActionSchemaToJSON(ActionSchema, Parameters.Elements[SubElementIdx]); SubElement->SetNumberField(TEXT("Index"), SubElementIdx); SubObject->SetObjectField(Parameters.ElementNames[SubElementIdx].ToString(), SubElement); } Object->SetObjectField(TEXT("Elements"), SubObject); break; } case Action::EType::Array: { const Action::FSchemaArrayParameters Parameters = ActionSchema.GetArray(ActionSchemaElement); Object->SetStringField(TEXT("Type"), TEXT("Array")); Object->SetNumberField(TEXT("Num"), Parameters.Num); Object->SetObjectField(TEXT("Element"), ConvertActionSchemaToJSON(ActionSchema, Parameters.Element)); break; } case Action::EType::Encoding: { const Action::FSchemaEncodingParameters Parameters = ActionSchema.GetEncoding(ActionSchemaElement); Object->SetStringField(TEXT("Type"), TEXT("Encoding")); Object->SetNumberField(TEXT("EncodingSize"), Parameters.EncodingSize); Object->SetObjectField(TEXT("Element"), ConvertActionSchemaToJSON(ActionSchema, Parameters.Element)); break; } default: checkNoEntry(); } return Object; } const TCHAR* GetDeviceString(const ETrainerDevice Device) { switch (Device) { case ETrainerDevice::GPU: return TEXT("GPU"); case ETrainerDevice::CPU: return TEXT("CPU"); default: checkNoEntry(); return TEXT("Unknown"); } } const TCHAR* GetResponseString(const ETrainerResponse Response) { switch (Response) { case ETrainerResponse::Success: return TEXT("Success"); case ETrainerResponse::Unexpected: return TEXT("Unexpected communication received"); case ETrainerResponse::Completed: return TEXT("Training completed"); case ETrainerResponse::Stopped: return TEXT("Training stopped"); case ETrainerResponse::Timeout: return TEXT("Communication timeout"); default: checkNoEntry(); return TEXT("Unknown"); } } float DiscountFactorFromHalfLife(const float HalfLife, const float DeltaTime) { return FMath::Pow(0.5f, DeltaTime / FMath::Max(HalfLife, UE_SMALL_NUMBER)); } float DiscountFactorFromHalfLifeSteps(const int32 HalfLifeSteps) { checkf(HalfLifeSteps >= 1, TEXT("Number of HalfLifeSteps should be at least 1 but got %i"), HalfLifeSteps); return FMath::Pow(0.5f, 1.0f / FMath::Max(HalfLifeSteps, 1)); } FString GetPythonExecutablePath(const FString& IntermediateDir) { checkf(PLATFORM_WINDOWS || PLATFORM_MAC || PLATFORM_LINUX, TEXT("Python only supported on Windows, Mac, and Linux.")); return IntermediateDir / TEXT("PipInstall") / (PLATFORM_WINDOWS ? TEXT("Scripts/python.exe") : TEXT("bin/python3")); } FString GetSitePackagesPath(const FString& EngineDir) { checkf(PLATFORM_WINDOWS || PLATFORM_MAC || PLATFORM_LINUX, TEXT("Python only supported on Windows, Mac, and Linux.")); return EngineDir / TEXT("Plugins/Experimental/PythonFoundationPackages/Content/Python/Lib") / FPlatformMisc::GetUBTPlatform() / TEXT("site-packages"); } FString GetPythonContentPath(const FString& EngineDir) { return EngineDir / TEXT("Plugins/Experimental/LearningAgents/Content/Python/"); } FString GetProjectPythonContentPath() { return FPaths::ProjectContentDir() / TEXT("Python/"); } FString GetIntermediatePath(const FString& IntermediateDir) { return IntermediateDir / TEXT("LearningAgents"); } }