Files
UnrealEngine/Engine/Plugins/Experimental/LearningAgents/Source/LearningTraining/Private/LearningTrainer.cpp
2025-05-18 13:04:45 +08:00

498 lines
17 KiB
C++

// Copyright Epic Games, Inc. All Rights Reserved.
#include "LearningTrainer.h"
#include "LearningObservation.h"
#include "LearningAction.h"
#include "HAL/Platform.h"
#include "Dom/JsonObject.h"
#include "Misc/Paths.h"
namespace UE::Learning
{
FSubprocess::~FSubprocess()
{
Terminate();
}
bool FSubprocess::Launch(const FString& Path, const FString& Params, const ESubprocessFlags Flags)
{
ensureMsgf(!bIsLaunched, TEXT("Subprocess already launched."));
Terminate();
const bool bCreatePipes = !(Flags & ESubprocessFlags::NoRedirectOutput);
const bool bHideWindow = !(Flags & ESubprocessFlags::ShowWindow);
if (bCreatePipes && !FPlatformProcess::CreatePipe(ReadPipe, WritePipe))
{
return false;
}
ProcessHandle = FPlatformProcess::CreateProc(*Path, *Params, false, bHideWindow, bHideWindow, nullptr, 0, *FPaths::RootDir(), WritePipe, ReadPipe);
bIsLaunched = true;
return true;
}
bool FSubprocess::IsRunning() const
{
return bIsLaunched && FPlatformProcess::IsProcRunning(const_cast<FProcHandle&>(ProcessHandle));
}
void FSubprocess::Terminate()
{
if (IsRunning())
{
UE_LOG(LogLearning, Display, TEXT("Terminating Subprocess..."));
FPlatformProcess::TerminateProc(ProcessHandle, true);
}
Update();
}
bool FSubprocess::Update()
{
// Do nothing if the process is not launched
if (!bIsLaunched)
{
return false;
}
// Append the process stdout to the buffer
OutputBuffer += FPlatformProcess::ReadPipe(ReadPipe);
// Output all the complete lines
int32 LineStartIdx = 0;
for (int32 Idx = 0; Idx < OutputBuffer.Len(); Idx++)
{
if (OutputBuffer[Idx] == '\r' || OutputBuffer[Idx] == '\n')
{
UE_LOG(LogLearning, Display, TEXT("Subprocess: %s"), *OutputBuffer.Mid(LineStartIdx, Idx - LineStartIdx));
if (OutputBuffer[Idx] == '\r' && Idx + 1 < OutputBuffer.Len() && OutputBuffer[Idx + 1] == '\n')
{
Idx++;
}
LineStartIdx = Idx + 1;
}
}
// Remove all the complete lines from the buffer
OutputBuffer.MidInline(LineStartIdx, MAX_int32, EAllowShrinking::Yes);
// If the process is no longer running then close the pipes
if (!IsRunning())
{
FPlatformProcess::ClosePipe(ReadPipe, WritePipe);
ReadPipe = nullptr;
WritePipe = nullptr;
bIsLaunched = false;
return false;
}
return true;
}
}
namespace UE::Learning::Trainer
{
TSharedPtr<FJsonObject> ConvertObservationSchemaToJSON(
const Observation::FSchema& ObservationSchema,
const Observation::FSchemaElement& ObservationSchemaElement)
{
TSharedPtr<FJsonObject> Object = MakeShared<FJsonObject>();
Object->SetNumberField(TEXT("VectorSize"), ObservationSchema.GetObservationVectorSize(ObservationSchemaElement));
Object->SetNumberField(TEXT("EncodedSize"), ObservationSchema.GetEncodedVectorSize(ObservationSchemaElement));
switch (ObservationSchema.GetType(ObservationSchemaElement))
{
case Observation::EType::Null:
{
Object->SetStringField(TEXT("Type"), TEXT("Null"));
break;
}
case Observation::EType::Continuous:
{
const Observation::FSchemaContinuousParameters Parameters = ObservationSchema.GetContinuous(ObservationSchemaElement);
Object->SetStringField(TEXT("Type"), TEXT("Continuous"));
Object->SetNumberField(TEXT("Num"), Parameters.Num);
break;
}
case Observation::EType::DiscreteExclusive:
{
const Observation::FSchemaDiscreteExclusiveParameters Parameters = ObservationSchema.GetDiscreteExclusive(ObservationSchemaElement);
Object->SetStringField(TEXT("Type"), TEXT("DiscreteExclusive"));
Object->SetNumberField(TEXT("Num"), Parameters.Num);
break;
}
case Observation::EType::DiscreteInclusive:
{
const Observation::FSchemaDiscreteInclusiveParameters Parameters = ObservationSchema.GetDiscreteInclusive(ObservationSchemaElement);
Object->SetStringField(TEXT("Type"), TEXT("DiscreteInclusive"));
Object->SetNumberField(TEXT("Num"), Parameters.Num);
break;
}
case Observation::EType::NamedDiscreteExclusive:
{
const Observation::FSchemaNamedDiscreteExclusiveParameters Parameters = ObservationSchema.GetNamedDiscreteExclusive(ObservationSchemaElement);
Object->SetStringField(TEXT("Type"), TEXT("NamedDiscreteExclusive"));
TArray<TSharedPtr<FJsonValue>> ElementNames;
for (int32 ElementIdx = 0; ElementIdx < Parameters.ElementNames.Num(); ElementIdx++)
{
ElementNames.Add(MakeShared<FJsonValueString>(Parameters.ElementNames[ElementIdx].ToString()));
}
Object->SetArrayField(TEXT("ElementNames"), ElementNames);
break;
}
case Observation::EType::NamedDiscreteInclusive:
{
const Observation::FSchemaNamedDiscreteInclusiveParameters Parameters = ObservationSchema.GetNamedDiscreteInclusive(ObservationSchemaElement);
Object->SetStringField(TEXT("Type"), TEXT("NamedDiscreteInclusive"));
TArray<TSharedPtr<FJsonValue>> ElementNames;
for (int32 ElementIdx = 0; ElementIdx < Parameters.ElementNames.Num(); ElementIdx++)
{
ElementNames.Add(MakeShared<FJsonValueString>(Parameters.ElementNames[ElementIdx].ToString()));
}
Object->SetArrayField(TEXT("ElementNames"), ElementNames);
break;
}
case Observation::EType::And:
{
const Observation::FSchemaAndParameters Parameters = ObservationSchema.GetAnd(ObservationSchemaElement);
Object->SetStringField(TEXT("Type"), TEXT("And"));
TSharedPtr<FJsonObject> SubObject = MakeShared<FJsonObject>();
for (int32 SubElementIdx = 0; SubElementIdx < Parameters.Elements.Num(); SubElementIdx++)
{
TSharedPtr<FJsonObject> SubElement = ConvertObservationSchemaToJSON(ObservationSchema, Parameters.Elements[SubElementIdx]);
SubElement->SetNumberField(TEXT("Index"), SubElementIdx);
SubObject->SetObjectField(Parameters.ElementNames[SubElementIdx].ToString(), SubElement);
}
Object->SetObjectField(TEXT("Elements"), SubObject);
break;
}
case Observation::EType::OrExclusive:
{
const Observation::FSchemaOrExclusiveParameters Parameters = ObservationSchema.GetOrExclusive(ObservationSchemaElement);
Object->SetStringField(TEXT("Type"), TEXT("OrExclusive"));
Object->SetNumberField(TEXT("EncodingSize"), Parameters.EncodingSize);
TSharedPtr<FJsonObject> SubObject = MakeShared<FJsonObject>();
for (int32 SubElementIdx = 0; SubElementIdx < Parameters.Elements.Num(); SubElementIdx++)
{
TSharedPtr<FJsonObject> SubElement = ConvertObservationSchemaToJSON(ObservationSchema, Parameters.Elements[SubElementIdx]);
SubElement->SetNumberField(TEXT("Index"), SubElementIdx);
SubObject->SetObjectField(Parameters.ElementNames[SubElementIdx].ToString(), SubElement);
}
Object->SetObjectField(TEXT("Elements"), SubObject);
break;
}
case Observation::EType::OrInclusive:
{
const Observation::FSchemaOrInclusiveParameters Parameters = ObservationSchema.GetOrInclusive(ObservationSchemaElement);
Object->SetStringField(TEXT("Type"), TEXT("OrInclusive"));
Object->SetNumberField(TEXT("AttentionEncodingSize"), Parameters.AttentionEncodingSize);
Object->SetNumberField(TEXT("AttentionHeadNum"), Parameters.AttentionHeadNum);
Object->SetNumberField(TEXT("ValueEncodingSize"), Parameters.ValueEncodingSize);
TSharedPtr<FJsonObject> SubObject = MakeShared<FJsonObject>();
for (int32 SubElementIdx = 0; SubElementIdx < Parameters.Elements.Num(); SubElementIdx++)
{
TSharedPtr<FJsonObject> SubElement = ConvertObservationSchemaToJSON(ObservationSchema, Parameters.Elements[SubElementIdx]);
SubElement->SetNumberField(TEXT("Index"), SubElementIdx);
SubObject->SetObjectField(Parameters.ElementNames[SubElementIdx].ToString(), SubElement);
}
Object->SetObjectField(TEXT("Elements"), SubObject);
break;
}
case Observation::EType::Array:
{
const Observation::FSchemaArrayParameters Parameters = ObservationSchema.GetArray(ObservationSchemaElement);
Object->SetStringField(TEXT("Type"), TEXT("Array"));
Object->SetNumberField(TEXT("Num"), Parameters.Num);
Object->SetObjectField(TEXT("Element"), ConvertObservationSchemaToJSON(ObservationSchema, Parameters.Element));
break;
}
case Observation::EType::Set:
{
const Observation::FSchemaSetParameters Parameters = ObservationSchema.GetSet(ObservationSchemaElement);
Object->SetStringField(TEXT("Type"), TEXT("Set"));
Object->SetNumberField(TEXT("MaxNum"), Parameters.MaxNum);
Object->SetNumberField(TEXT("AttentionEncodingSize"), Parameters.AttentionEncodingSize);
Object->SetNumberField(TEXT("AttentionHeadNum"), Parameters.AttentionHeadNum);
Object->SetNumberField(TEXT("ValueEncodingSize"), Parameters.ValueEncodingSize);
Object->SetObjectField(TEXT("Element"), ConvertObservationSchemaToJSON(ObservationSchema, Parameters.Element));
break;
}
case Observation::EType::Encoding:
{
const Observation::FSchemaEncodingParameters Parameters = ObservationSchema.GetEncoding(ObservationSchemaElement);
Object->SetStringField(TEXT("Type"), TEXT("Encoding"));
Object->SetNumberField(TEXT("EncodingSize"), Parameters.EncodingSize);
Object->SetObjectField(TEXT("Element"), ConvertObservationSchemaToJSON(ObservationSchema, Parameters.Element));
break;
}
default:
checkNoEntry();
}
return Object;
}
TSharedPtr<FJsonObject> ConvertActionSchemaToJSON(
const Action::FSchema& ActionSchema,
const Action::FSchemaElement& ActionSchemaElement)
{
TSharedPtr<FJsonObject> Object = MakeShared<FJsonObject>();
Object->SetNumberField(TEXT("VectorSize"), ActionSchema.GetActionVectorSize(ActionSchemaElement));
Object->SetNumberField(TEXT("DistributionSize"), ActionSchema.GetActionDistributionVectorSize(ActionSchemaElement));
Object->SetNumberField(TEXT("EncodedSize"), ActionSchema.GetEncodedVectorSize(ActionSchemaElement));
Object->SetNumberField(TEXT("ModifierSize"), ActionSchema.GetActionModifierVectorSize(ActionSchemaElement));
switch (ActionSchema.GetType(ActionSchemaElement))
{
case Action::EType::Null:
{
Object->SetStringField(TEXT("Type"), TEXT("Null"));
break;
}
case Action::EType::Continuous:
{
const Action::FSchemaContinuousParameters Parameters = ActionSchema.GetContinuous(ActionSchemaElement);
Object->SetStringField(TEXT("Type"), TEXT("Continuous"));
Object->SetNumberField(TEXT("Num"), Parameters.Num);
break;
}
case Action::EType::DiscreteExclusive:
{
const Action::FSchemaDiscreteExclusiveParameters Parameters = ActionSchema.GetDiscreteExclusive(ActionSchemaElement);
Object->SetStringField(TEXT("Type"), TEXT("DiscreteExclusive"));
Object->SetNumberField(TEXT("Num"), Parameters.Num);
break;
}
case Action::EType::DiscreteInclusive:
{
const Action::FSchemaDiscreteInclusiveParameters Parameters = ActionSchema.GetDiscreteInclusive(ActionSchemaElement);
Object->SetStringField(TEXT("Type"), TEXT("DiscreteInclusive"));
Object->SetNumberField(TEXT("Num"), Parameters.Num);
break;
}
case Action::EType::NamedDiscreteExclusive:
{
const Action::FSchemaNamedDiscreteExclusiveParameters Parameters = ActionSchema.GetNamedDiscreteExclusive(ActionSchemaElement);
Object->SetStringField(TEXT("Type"), TEXT("NamedDiscreteExclusive"));
TArray<TSharedPtr<FJsonValue>> ElementNames;
for (int32 ElementIdx = 0; ElementIdx < Parameters.ElementNames.Num(); ElementIdx++)
{
ElementNames.Add(MakeShared<FJsonValueString>(Parameters.ElementNames[ElementIdx].ToString()));
}
Object->SetArrayField(TEXT("ElementNames"), ElementNames);
break;
}
case Action::EType::NamedDiscreteInclusive:
{
const Action::FSchemaNamedDiscreteInclusiveParameters Parameters = ActionSchema.GetNamedDiscreteInclusive(ActionSchemaElement);
Object->SetStringField(TEXT("Type"), TEXT("NamedDiscreteInclusive"));
TArray<TSharedPtr<FJsonValue>> ElementNames;
for (int32 ElementIdx = 0; ElementIdx < Parameters.ElementNames.Num(); ElementIdx++)
{
ElementNames.Add(MakeShared<FJsonValueString>(Parameters.ElementNames[ElementIdx].ToString()));
}
Object->SetArrayField(TEXT("ElementNames"), ElementNames);
break;
}
case Action::EType::And:
{
const Action::FSchemaAndParameters Parameters = ActionSchema.GetAnd(ActionSchemaElement);
Object->SetStringField(TEXT("Type"), TEXT("And"));
TSharedPtr<FJsonObject> SubObject = MakeShared<FJsonObject>();
for (int32 SubElementIdx = 0; SubElementIdx < Parameters.Elements.Num(); SubElementIdx++)
{
TSharedPtr<FJsonObject> SubElement = ConvertActionSchemaToJSON(ActionSchema, Parameters.Elements[SubElementIdx]);
SubElement->SetNumberField(TEXT("Index"), SubElementIdx);
SubObject->SetObjectField(Parameters.ElementNames[SubElementIdx].ToString(), SubElement);
}
Object->SetObjectField(TEXT("Elements"), SubObject);
break;
}
case Action::EType::OrExclusive:
{
const Action::FSchemaOrExclusiveParameters Parameters = ActionSchema.GetOrExclusive(ActionSchemaElement);
Object->SetStringField(TEXT("Type"), TEXT("OrExclusive"));
TSharedPtr<FJsonObject> SubObject = MakeShared<FJsonObject>();
for (int32 SubElementIdx = 0; SubElementIdx < Parameters.Elements.Num(); SubElementIdx++)
{
TSharedPtr<FJsonObject> SubElement = ConvertActionSchemaToJSON(ActionSchema, Parameters.Elements[SubElementIdx]);
SubElement->SetNumberField(TEXT("Index"), SubElementIdx);
SubObject->SetObjectField(Parameters.ElementNames[SubElementIdx].ToString(), SubElement);
}
Object->SetObjectField(TEXT("Elements"), SubObject);
break;
}
case Action::EType::OrInclusive:
{
const Action::FSchemaOrInclusiveParameters Parameters = ActionSchema.GetOrInclusive(ActionSchemaElement);
Object->SetStringField(TEXT("Type"), TEXT("OrInclusive"));
TSharedPtr<FJsonObject> SubObject = MakeShared<FJsonObject>();
for (int32 SubElementIdx = 0; SubElementIdx < Parameters.Elements.Num(); SubElementIdx++)
{
TSharedPtr<FJsonObject> SubElement = ConvertActionSchemaToJSON(ActionSchema, Parameters.Elements[SubElementIdx]);
SubElement->SetNumberField(TEXT("Index"), SubElementIdx);
SubObject->SetObjectField(Parameters.ElementNames[SubElementIdx].ToString(), SubElement);
}
Object->SetObjectField(TEXT("Elements"), SubObject);
break;
}
case Action::EType::Array:
{
const Action::FSchemaArrayParameters Parameters = ActionSchema.GetArray(ActionSchemaElement);
Object->SetStringField(TEXT("Type"), TEXT("Array"));
Object->SetNumberField(TEXT("Num"), Parameters.Num);
Object->SetObjectField(TEXT("Element"), ConvertActionSchemaToJSON(ActionSchema, Parameters.Element));
break;
}
case Action::EType::Encoding:
{
const Action::FSchemaEncodingParameters Parameters = ActionSchema.GetEncoding(ActionSchemaElement);
Object->SetStringField(TEXT("Type"), TEXT("Encoding"));
Object->SetNumberField(TEXT("EncodingSize"), Parameters.EncodingSize);
Object->SetObjectField(TEXT("Element"), ConvertActionSchemaToJSON(ActionSchema, Parameters.Element));
break;
}
default:
checkNoEntry();
}
return Object;
}
const TCHAR* GetDeviceString(const ETrainerDevice Device)
{
switch (Device)
{
case ETrainerDevice::GPU: return TEXT("GPU");
case ETrainerDevice::CPU: return TEXT("CPU");
default: checkNoEntry(); return TEXT("Unknown");
}
}
const TCHAR* GetResponseString(const ETrainerResponse Response)
{
switch (Response)
{
case ETrainerResponse::Success: return TEXT("Success");
case ETrainerResponse::Unexpected: return TEXT("Unexpected communication received");
case ETrainerResponse::Completed: return TEXT("Training completed");
case ETrainerResponse::Stopped: return TEXT("Training stopped");
case ETrainerResponse::Timeout: return TEXT("Communication timeout");
default: checkNoEntry(); return TEXT("Unknown");
}
}
float DiscountFactorFromHalfLife(const float HalfLife, const float DeltaTime)
{
return FMath::Pow(0.5f, DeltaTime / FMath::Max(HalfLife, UE_SMALL_NUMBER));
}
float DiscountFactorFromHalfLifeSteps(const int32 HalfLifeSteps)
{
checkf(HalfLifeSteps >= 1, TEXT("Number of HalfLifeSteps should be at least 1 but got %i"), HalfLifeSteps);
return FMath::Pow(0.5f, 1.0f / FMath::Max(HalfLifeSteps, 1));
}
FString GetPythonExecutablePath(const FString& IntermediateDir)
{
checkf(PLATFORM_WINDOWS || PLATFORM_MAC || PLATFORM_LINUX, TEXT("Python only supported on Windows, Mac, and Linux."));
return IntermediateDir / TEXT("PipInstall") / (PLATFORM_WINDOWS ? TEXT("Scripts/python.exe") : TEXT("bin/python3"));
}
FString GetSitePackagesPath(const FString& EngineDir)
{
checkf(PLATFORM_WINDOWS || PLATFORM_MAC || PLATFORM_LINUX, TEXT("Python only supported on Windows, Mac, and Linux."));
return EngineDir / TEXT("Plugins/Experimental/PythonFoundationPackages/Content/Python/Lib") / FPlatformMisc::GetUBTPlatform() / TEXT("site-packages");
}
FString GetPythonContentPath(const FString& EngineDir)
{
return EngineDir / TEXT("Plugins/Experimental/LearningAgents/Content/Python/");
}
FString GetProjectPythonContentPath()
{
return FPaths::ProjectContentDir() / TEXT("Python/");
}
FString GetIntermediatePath(const FString& IntermediateDir)
{
return IntermediateDir / TEXT("LearningAgents");
}
}