// Copyright Epic Games, Inc. All Rights Reserved.

#include "LearningPPOTrainer.h"

#include "LearningArray.h"
#include "LearningExternalTrainer.h"
#include "LearningLog.h"
#include "LearningNeuralNetwork.h"
#include "LearningExperience.h"
#include "LearningProgress.h"
#include "LearningSharedMemory.h"
#include "LearningSharedMemoryTraining.h"
#include "LearningSocketTraining.h"
#include "LearningObservation.h"
#include "LearningAction.h"

#include "Misc/Guid.h"
#include "Misc/FileHelper.h"
#include "Misc/CommandLine.h"
#include "Misc/Parse.h"
#include "Misc/Paths.h"
#include "Dom/JsonObject.h"

#include "Sockets.h"
#include "Common/TcpSocketBuilder.h"
#include "SocketSubsystem.h"

ULearningSocketPPOTrainerServerCommandlet::ULearningSocketPPOTrainerServerCommandlet(const FObjectInitializer& ObjectInitializer) : Super(ObjectInitializer) {}

int32 ULearningSocketPPOTrainerServerCommandlet::Main(const FString& Commandline)
{
	UE_LOG(LogLearning, Display, TEXT("Running PPO Training Server Commandlet..."));

#if WITH_EDITOR
	TArray<FString> Tokens;
	TArray<FString> Switches;
	TMap<FString, FString> Params;
	UCommandlet::ParseCommandLine(*Commandline, Tokens, Switches, Params);

	const FString* PythonExecutiblePathParam = Params.Find(TEXT("PythonExecutiblePath"));
	const FString* PythonContentPathParam = Params.Find(TEXT("PythonContentPath"));
	const FString* IntermediatePathParam = Params.Find(TEXT("IntermediatePath"));
	const FString* IpAddressParam = Params.Find(TEXT("IpAddress"));
	const FString* PortParam = Params.Find(TEXT("Port"));
	const FString* LogSettingsParam = Params.Find(TEXT("LogSettings"));

	const FString PythonExecutiblePath = PythonExecutiblePathParam ? *PythonExecutiblePathParam : UE::Learning::Trainer::GetPythonExecutablePath(FPaths::ProjectIntermediateDir());
	const FString PythonContentPath = PythonContentPathParam ? *PythonContentPathParam : UE::Learning::Trainer::GetPythonContentPath(FPaths::EngineDir());
	const FString IntermediatePath = IntermediatePathParam ? *IntermediatePathParam : UE::Learning::Trainer::GetIntermediatePath(FPaths::ProjectIntermediateDir());
	const TCHAR* IpAddress = IpAddressParam ? *(*IpAddressParam) : UE::Learning::Trainer::DefaultIp;
	const uint32 Port = PortParam ? FCString::Atoi(*(*PortParam)) : UE::Learning::Trainer::DefaultPort;

	UE::Learning::ELogSetting LogSettings = UE::Learning::ELogSetting::Normal;
	if (LogSettingsParam)
	{
		if (*LogSettingsParam == TEXT("Normal"))
		{
			LogSettings = UE::Learning::ELogSetting::Normal;
		}
		else if (*LogSettingsParam == TEXT("Silent"))
		{
			LogSettings = UE::Learning::ELogSetting::Silent;
		}
		else
		{
			checkNoEntry();
			return 1;
		}
	}

	UE_LOG(LogLearning, Display, TEXT("--- PPO Training Server Arguments ---"));
	UE_LOG(LogLearning, Display, TEXT("PythonExecutiblePath: %s"), *PythonExecutiblePath);
	UE_LOG(LogLearning, Display, TEXT("PythonContentPath: %s"), *PythonContentPath);
	UE_LOG(LogLearning, Display, TEXT("IntermediatePath: %s"), *IntermediatePath);
	UE_LOG(LogLearning, Display, TEXT("IpAddress: %s"), IpAddress);
	UE_LOG(LogLearning, Display, TEXT("Port: %i"), Port);
	UE_LOG(LogLearning, Display, TEXT("LogSettings: %s"), LogSettings == UE::Learning::ELogSetting::Normal ? TEXT("Normal") : TEXT("Silent"));

	UE::Learning::FSocketTrainerServerProcess ServerProcess(
		TEXT("Training"),
		UE::Learning::Trainer::GetProjectPythonContentPath(),
		TEXT("train_ppo"),
		PythonExecutiblePath,
		PythonContentPath,
		IntermediatePath,
		IpAddress,
		Port,
		UE::Learning::Trainer::DefaultTimeout,
		UE::Learning::ESubprocessFlags::None,
		LogSettings);

	while (ServerProcess.IsRunning())
	{
		FPlatformProcess::Sleep(0.01f);
	}
#else
	checkNoEntry();
#endif

	return 0;
}
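// Example of launching this training server from the command line (a sketch, not taken from
// this file: the commandlet name is derived by the usual UCommandlet convention of dropping
// the "U" prefix and "Commandlet" suffix, and all paths/values below are placeholders; any
// argument omitted falls back to the defaults resolved in Main above):
//
//   UnrealEditor-Cmd.exe <ProjectPath> -run=LearningSocketPPOTrainerServer
//       -PythonExecutiblePath="<PathToPythonExecutable>"
//       -PythonContentPath="<PathToPythonContent>"
//       -IntermediatePath="<PathToIntermediateDir>"
//       -IpAddress=<IpAddress> -Port=<Port> -LogSettings=Normal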
namespace UE::Learning::PPOTrainer
{
	ETrainerResponse Train(
		IExternalTrainer* ExternalTrainer,
		FReplayBuffer& ReplayBuffer,
		FEpisodeBuffer& EpisodeBuffer,
		FResetInstanceBuffer& ResetBuffer,
		ULearningNeuralNetworkData& PolicyNetwork,
		ULearningNeuralNetworkData& CriticNetwork,
		ULearningNeuralNetworkData& EncoderNetwork,
		ULearningNeuralNetworkData& DecoderNetwork,
		TLearningArrayView<2, float> ObservationVectorBuffer,
		TLearningArrayView<2, float> ActionVectorBuffer,
		TLearningArrayView<2, float> PreEvaluationMemoryStateVectorBuffer,
		TLearningArrayView<2, float> MemoryStateVectorBuffer,
		TLearningArrayView<1, float> RewardBuffer,
		TLearningArrayView<1, ECompletionMode> CompletionBuffer,
		TLearningArrayView<1, ECompletionMode> EpisodeCompletionBuffer,
		TLearningArrayView<1, ECompletionMode> AllCompletionBuffer,
		const TFunctionRef<void(const FIndexSet Instances)> ResetFunction,
		const TFunctionRef<void(const FIndexSet Instances)> ObservationFunction,
		const TFunctionRef<void(const FIndexSet Instances)> PolicyFunction,
		const TFunctionRef<void(const FIndexSet Instances)> ActionFunction,
		const TFunctionRef<void(const FIndexSet Instances)> UpdateFunction,
		const TFunctionRef<void(const FIndexSet Instances)> RewardFunction,
		const TFunctionRef<void(const FIndexSet Instances)> CompletionFunction,
		const FIndexSet Instances,
		const Observation::FSchema& ObservationSchema,
		const Observation::FSchemaElement& ObservationSchemaElement,
		const Action::FSchema& ActionSchema,
		const Action::FSchemaElement& ActionSchemaElement,
		const FPPOTrainerTrainingSettings& TrainerSettings,
		TAtomic<bool>* bRequestTrainingStopSignal,
		FRWLock* PolicyNetworkLock,
		FRWLock* CriticNetworkLock,
		FRWLock* EncoderNetworkLock,
		FRWLock* DecoderNetworkLock,
		TAtomic<bool>* bPolicyNetworkUpdatedSignal,
		TAtomic<bool>* bCriticNetworkUpdatedSignal,
		TAtomic<bool>* bEncoderNetworkUpdatedSignal,
		TAtomic<bool>* bDecoderNetworkUpdatedSignal,
		const ELogSetting LogSettings)
	{
		TRACE_CPUPROFILER_EVENT_SCOPE(Learning::PPOTrainer::Train);

		ETrainerResponse Response = ETrainerResponse::Success;

		if (LogSettings != ELogSetting::Silent)
		{
			UE_LOG(LogLearning, Display, TEXT("Sending initial Policy..."));
		}

		const int32 PolicyNetworkId = ExternalTrainer->AddNetwork(PolicyNetwork);
		const int32 CriticNetworkId = ExternalTrainer->AddNetwork(CriticNetwork);
		const int32 EncoderNetworkId = ExternalTrainer->AddNetwork(EncoderNetwork);
		const int32 DecoderNetworkId = ExternalTrainer->AddNetwork(DecoderNetwork);
		const int32 ReplayBufferId = ExternalTrainer->AddReplayBuffer(ReplayBuffer);

		const int32 ObservationVectorDimensionNum = ReplayBuffer.GetObservations(0).Num<1>();
		const int32 ActionVectorDimensionNum = ReplayBuffer.GetActions(0).Num<1>();
		const int32 MemoryStateVectorDimensionNum = ReplayBuffer.GetMemoryStates(0).Num<1>();

		// Write Data Config
		TSharedRef<FJsonObject> DataConfigObject = MakeShared<FJsonObject>();
		const int32 ObservationSchemaId = 0;
		const int32 ActionSchemaId = 0;

		// Add Neural Network Config Entries
		TArray<TSharedPtr<FJsonValue>> NetworkObjects;

		// Policy
		{
			TSharedPtr<FJsonObject> NetworkObject = MakeShared<FJsonObject>();
			NetworkObject->SetNumberField(TEXT("Id"), PolicyNetworkId);
			NetworkObject->SetStringField(TEXT("Name"), "Policy");
			NetworkObject->SetNumberField(TEXT("MaxByteNum"), PolicyNetwork.GetSnapshotByteNum());
			TSharedRef<FJsonValueObject> JsonValue = MakeShared<FJsonValueObject>(NetworkObject);
			NetworkObjects.Add(JsonValue);
		}
		// Critic
		{
			TSharedPtr<FJsonObject> NetworkObject = MakeShared<FJsonObject>();
			NetworkObject->SetNumberField(TEXT("Id"), CriticNetworkId);
			NetworkObject->SetStringField(TEXT("Name"), "Critic");
			NetworkObject->SetNumberField(TEXT("MaxByteNum"), CriticNetwork.GetSnapshotByteNum());
			NetworkObject->SetNumberField(TEXT("InputSchemaId"), ObservationSchemaId);
			TSharedRef<FJsonValueObject> JsonValue = MakeShared<FJsonValueObject>(NetworkObject);
			NetworkObjects.Add(JsonValue);
		}

		// Encoder
		{
			TSharedPtr<FJsonObject> NetworkObject = MakeShared<FJsonObject>();
			NetworkObject->SetNumberField(TEXT("Id"), EncoderNetworkId);
			NetworkObject->SetStringField(TEXT("Name"), "Encoder");
			NetworkObject->SetNumberField(TEXT("MaxByteNum"), EncoderNetwork.GetSnapshotByteNum());
			NetworkObject->SetNumberField(TEXT("InputSchemaId"), ObservationSchemaId);
			TSharedRef<FJsonValueObject> JsonValue = MakeShared<FJsonValueObject>(NetworkObject);
			NetworkObjects.Add(JsonValue);
		}

		// Decoder
		{
			TSharedPtr<FJsonObject> NetworkObject = MakeShared<FJsonObject>();
			NetworkObject->SetNumberField(TEXT("Id"), DecoderNetworkId);
			NetworkObject->SetStringField(TEXT("Name"), "Decoder");
			NetworkObject->SetNumberField(TEXT("MaxByteNum"), DecoderNetwork.GetSnapshotByteNum());
			NetworkObject->SetNumberField(TEXT("OutputSchemaId"), ActionSchemaId);
			TSharedRef<FJsonValueObject> JsonValue = MakeShared<FJsonValueObject>(NetworkObject);
			NetworkObjects.Add(JsonValue);
		}

		DataConfigObject->SetArrayField(TEXT("Networks"), NetworkObjects);

		// Add Replay Buffers Config Entries
		TArray<TSharedPtr<FJsonValue>> ReplayBufferObjects;
		TSharedRef<FJsonValueObject> ReplayBufferJsonValue = MakeShared<FJsonValueObject>(ReplayBuffer.AsJsonConfig(ReplayBufferId));
		ReplayBufferObjects.Add(ReplayBufferJsonValue);
		DataConfigObject->SetArrayField(TEXT("ReplayBuffers"), ReplayBufferObjects);

		// Schemas
		TSharedPtr<FJsonObject> SchemasObject = MakeShared<FJsonObject>();

		// Add the observation schemas
		TArray<TSharedPtr<FJsonValue>> ObservationSchemaObjects;
		{
			// For this PPO trainer, add the one observation schema we have
			TSharedPtr<FJsonObject> ObservationSchemaObject = MakeShared<FJsonObject>();
			ObservationSchemaObject->SetNumberField(TEXT("Id"), ObservationSchemaId);
			ObservationSchemaObject->SetStringField(TEXT("Name"), "Default");
			ObservationSchemaObject->SetObjectField(TEXT("Schema"), UE::Learning::Trainer::ConvertObservationSchemaToJSON(ObservationSchema, ObservationSchemaElement));
			TSharedRef<FJsonValueObject> JsonValue = MakeShared<FJsonValueObject>(ObservationSchemaObject);
			ObservationSchemaObjects.Add(JsonValue);
		}
		SchemasObject->SetArrayField(TEXT("Observations"), ObservationSchemaObjects);

		// Add the action schemas
		TArray<TSharedPtr<FJsonValue>> ActionSchemaObjects;
		{
			// For this PPO trainer, add the one action schema we have
			TSharedPtr<FJsonObject> ActionSchemaObject = MakeShared<FJsonObject>();
			ActionSchemaObject->SetNumberField(TEXT("Id"), ActionSchemaId);
			ActionSchemaObject->SetStringField(TEXT("Name"), "Default");
			ActionSchemaObject->SetObjectField(TEXT("Schema"), UE::Learning::Trainer::ConvertActionSchemaToJSON(ActionSchema, ActionSchemaElement));
			TSharedRef<FJsonValueObject> JsonValue = MakeShared<FJsonValueObject>(ActionSchemaObject);
			ActionSchemaObjects.Add(JsonValue);
		}
		SchemasObject->SetArrayField(TEXT("Actions"), ActionSchemaObjects);

		DataConfigObject->SetObjectField(TEXT("Schemas"), SchemasObject);
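		// For reference, the data config assembled above serializes to JSON of roughly the
		// following shape. The Id and MaxByteNum values shown are hypothetical examples; the
		// real values come from IExternalTrainer::AddNetwork and GetSnapshotByteNum above.
		//
		//   {
		//     "Networks": [
		//       { "Id": 0, "Name": "Policy",  "MaxByteNum": 65536 },
		//       { "Id": 1, "Name": "Critic",  "MaxByteNum": 65536, "InputSchemaId": 0 },
		//       { "Id": 2, "Name": "Encoder", "MaxByteNum": 65536, "InputSchemaId": 0 },
		//       { "Id": 3, "Name": "Decoder", "MaxByteNum": 65536, "OutputSchemaId": 0 }
		//     ],
		//     "ReplayBuffers": [ { ... } ],
		//     "Schemas": {
		//       "Observations": [ { "Id": 0, "Name": "Default", "Schema": { ... } } ],
		//       "Actions": [ { "Id": 0, "Name": "Default", "Schema": { ... } } ]
		//     }
		//   }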
		// Add PPO Specific Config Entries
		TSharedRef<FJsonObject> TrainingConfigObject = MakeShared<FJsonObject>();
		TrainingConfigObject->SetStringField(TEXT("TaskName"), TEXT("Training"));
		TrainingConfigObject->SetStringField(TEXT("TrainerMethod"), TEXT("PPO"));
		TrainingConfigObject->SetStringField(TEXT("TimeStamp"), *FDateTime::Now().ToFormattedString(TEXT("%Y-%m-%d_%H-%M-%S")));
		TrainingConfigObject->SetNumberField(TEXT("IterationNum"), TrainerSettings.IterationNum);
		TrainingConfigObject->SetNumberField(TEXT("LearningRatePolicy"), TrainerSettings.LearningRatePolicy);
		TrainingConfigObject->SetNumberField(TEXT("LearningRateCritic"), TrainerSettings.LearningRateCritic);
		TrainingConfigObject->SetNumberField(TEXT("LearningRateDecay"), TrainerSettings.LearningRateDecay);
		TrainingConfigObject->SetNumberField(TEXT("WeightDecay"), TrainerSettings.WeightDecay);
		TrainingConfigObject->SetNumberField(TEXT("PolicyBatchSize"), TrainerSettings.PolicyBatchSize);
		TrainingConfigObject->SetNumberField(TEXT("CriticBatchSize"), TrainerSettings.CriticBatchSize);
		TrainingConfigObject->SetNumberField(TEXT("PolicyWindow"), TrainerSettings.PolicyWindow);
		TrainingConfigObject->SetNumberField(TEXT("IterationsPerGather"), TrainerSettings.IterationsPerGather);
		TrainingConfigObject->SetNumberField(TEXT("CriticWarmupIterations"), TrainerSettings.CriticWarmupIterations);
		TrainingConfigObject->SetNumberField(TEXT("EpsilonClip"), TrainerSettings.EpsilonClip);
		TrainingConfigObject->SetNumberField(TEXT("ActionSurrogateWeight"), TrainerSettings.ActionSurrogateWeight);
		TrainingConfigObject->SetNumberField(TEXT("ActionRegularizationWeight"), TrainerSettings.ActionRegularizationWeight);
		TrainingConfigObject->SetNumberField(TEXT("ActionEntropyWeight"), TrainerSettings.ActionEntropyWeight);
		TrainingConfigObject->SetNumberField(TEXT("ReturnRegularizationWeight"), TrainerSettings.ReturnRegularizationWeight);
		TrainingConfigObject->SetNumberField(TEXT("GaeLambda"), TrainerSettings.GaeLambda);
		TrainingConfigObject->SetBoolField(TEXT("AdvantageNormalization"), TrainerSettings.bAdvantageNormalization);
		TrainingConfigObject->SetNumberField(TEXT("AdvantageMin"), TrainerSettings.AdvantageMin);
		TrainingConfigObject->SetNumberField(TEXT("AdvantageMax"), TrainerSettings.AdvantageMax);
		TrainingConfigObject->SetBoolField(TEXT("UseGradNormMaxClipping"), TrainerSettings.bUseGradNormMaxClipping);
		TrainingConfigObject->SetNumberField(TEXT("GradNormMax"), TrainerSettings.GradNormMax);
		TrainingConfigObject->SetNumberField(TEXT("TrimEpisodeStartStepNum"), TrainerSettings.TrimEpisodeStartStepNum);
		TrainingConfigObject->SetNumberField(TEXT("TrimEpisodeEndStepNum"), TrainerSettings.TrimEpisodeEndStepNum);
		TrainingConfigObject->SetNumberField(TEXT("Seed"), TrainerSettings.Seed);
		TrainingConfigObject->SetNumberField(TEXT("DiscountFactor"), TrainerSettings.DiscountFactor);
		TrainingConfigObject->SetStringField(TEXT("Device"), UE::Learning::Trainer::GetDeviceString(TrainerSettings.Device));
		TrainingConfigObject->SetBoolField(TEXT("UseTensorBoard"), TrainerSettings.bUseTensorboard);
		TrainingConfigObject->SetBoolField(TEXT("SaveSnapshots"), TrainerSettings.bSaveSnapshots);

		ExternalTrainer->SendConfigs(DataConfigObject, TrainingConfigObject, LogSettings);

		Response = ExternalTrainer->SendNetwork(PolicyNetworkId, PolicyNetwork, PolicyNetworkLock);

		if (Response != ETrainerResponse::Success)
		{
			if (LogSettings != ELogSetting::Silent)
			{
				UE_LOG(LogLearning, Error, TEXT("Error sending initial policy to trainer: %s. Check log for errors."), Trainer::GetResponseString(Response));
			}

			ExternalTrainer->Terminate();
			return Response;
		}

		// Send initial Critic
		if (LogSettings != ELogSetting::Silent)
		{
			UE_LOG(LogLearning, Display, TEXT("Sending initial Critic..."));
		}

		Response = ExternalTrainer->SendNetwork(CriticNetworkId, CriticNetwork, CriticNetworkLock);

		if (Response != ETrainerResponse::Success)
		{
			if (LogSettings != ELogSetting::Silent)
			{
				UE_LOG(LogLearning, Error, TEXT("Error sending initial critic to trainer: %s. Check log for errors."), Trainer::GetResponseString(Response));
			}

			ExternalTrainer->Terminate();
			return Response;
		}

		// Send initial Encoder
		if (LogSettings != ELogSetting::Silent)
		{
			UE_LOG(LogLearning, Display, TEXT("Sending initial Encoder..."));
		}

		Response = ExternalTrainer->SendNetwork(EncoderNetworkId, EncoderNetwork, EncoderNetworkLock);

		if (Response != ETrainerResponse::Success)
		{
			if (LogSettings != ELogSetting::Silent)
			{
				UE_LOG(LogLearning, Error, TEXT("Error sending initial encoder to trainer: %s. Check log for errors."), Trainer::GetResponseString(Response));
			}

			ExternalTrainer->Terminate();
			return Response;
		}

		// Send initial Decoder
		if (LogSettings != ELogSetting::Silent)
		{
			UE_LOG(LogLearning, Display, TEXT("Sending initial Decoder..."));
		}

		Response = ExternalTrainer->SendNetwork(DecoderNetworkId, DecoderNetwork, DecoderNetworkLock);

		if (Response != ETrainerResponse::Success)
		{
			if (LogSettings != ELogSetting::Silent)
			{
				UE_LOG(LogLearning, Error, TEXT("Error sending initial decoder to trainer: %s. Check log for errors."), Trainer::GetResponseString(Response));
			}

			ExternalTrainer->Terminate();
			return Response;
		}
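		// The loop below alternates between this process and the external trainer: gather
		// experience until the replay buffer is full, send the buffer to the trainer, then
		// receive updated Policy, Critic, Encoder, and Decoder networks. It exits either when
		// bRequestTrainingStopSignal is raised (after sending a stop message to the trainer)
		// or when the trainer reports ETrainerResponse::Completed while the policy is received.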
		// Start Training Loop
		while (true)
		{
			if (bRequestTrainingStopSignal && (*bRequestTrainingStopSignal))
			{
				*bRequestTrainingStopSignal = false;

				if (LogSettings != ELogSetting::Silent)
				{
					UE_LOG(LogLearning, Display, TEXT("Stopping Training..."));
				}

				Response = ExternalTrainer->SendStop();

				if (Response != ETrainerResponse::Success)
				{
					if (LogSettings != ELogSetting::Silent)
					{
						UE_LOG(LogLearning, Error, TEXT("Error sending stop signal to trainer: %s. Check log for errors."), Trainer::GetResponseString(Response));
					}

					ExternalTrainer->Terminate();
					return Response;
				}

				break;
			}
			else
			{
				Experience::GatherExperienceUntilReplayBufferFull(
					ReplayBuffer,
					EpisodeBuffer,
					ResetBuffer,
					{ ObservationVectorBuffer },
					{ ActionVectorBuffer },
					{ PreEvaluationMemoryStateVectorBuffer },
					{ MemoryStateVectorBuffer },
					{ RewardBuffer },
					CompletionBuffer,
					EpisodeCompletionBuffer,
					AllCompletionBuffer,
					ResetFunction,
					{ ObservationFunction },
					{ PolicyFunction },
					{ ActionFunction },
					{ UpdateFunction },
					{ RewardFunction },
					CompletionFunction,
					Instances);

				Response = ExternalTrainer->SendReplayBuffer(ReplayBufferId, ReplayBuffer);

				if (Response != ETrainerResponse::Success)
				{
					if (LogSettings != ELogSetting::Silent)
					{
						UE_LOG(LogLearning, Error, TEXT("Error sending replay buffer to trainer: %s. Check log for errors."), Trainer::GetResponseString(Response));
					}

					ExternalTrainer->Terminate();
					return Response;
				}
			}

			// Update Policy
			Response = ExternalTrainer->ReceiveNetwork(PolicyNetworkId, PolicyNetwork, PolicyNetworkLock);

			if (Response == ETrainerResponse::Completed)
			{
				if (LogSettings != ELogSetting::Silent)
				{
					UE_LOG(LogLearning, Display, TEXT("Trainer completed training."));
				}
				break;
			}
			else if (Response != ETrainerResponse::Success)
			{
				if (LogSettings != ELogSetting::Silent)
				{
					UE_LOG(LogLearning, Error, TEXT("Error receiving policy from trainer: %s. Check log for errors."), Trainer::GetResponseString(Response));
				}
				break;
			}

			if (bPolicyNetworkUpdatedSignal)
			{
				*bPolicyNetworkUpdatedSignal = true;
			}

			// Update Critic
			Response = ExternalTrainer->ReceiveNetwork(CriticNetworkId, CriticNetwork, CriticNetworkLock);

			if (Response != ETrainerResponse::Success)
			{
				if (LogSettings != ELogSetting::Silent)
				{
					UE_LOG(LogLearning, Error, TEXT("Error receiving critic from trainer: %s. Check log for errors."), Trainer::GetResponseString(Response));
				}
				break;
			}
			if (bCriticNetworkUpdatedSignal)
			{
				*bCriticNetworkUpdatedSignal = true;
			}

			// Update Encoder
			Response = ExternalTrainer->ReceiveNetwork(EncoderNetworkId, EncoderNetwork, EncoderNetworkLock);

			if (Response != ETrainerResponse::Success)
			{
				if (LogSettings != ELogSetting::Silent)
				{
					UE_LOG(LogLearning, Error, TEXT("Error receiving encoder from trainer: %s. Check log for errors."), Trainer::GetResponseString(Response));
				}
				break;
			}

			if (bEncoderNetworkUpdatedSignal)
			{
				*bEncoderNetworkUpdatedSignal = true;
			}

			// Update Decoder
			Response = ExternalTrainer->ReceiveNetwork(DecoderNetworkId, DecoderNetwork, DecoderNetworkLock);

			if (Response != ETrainerResponse::Success)
			{
				if (LogSettings != ELogSetting::Silent)
				{
					UE_LOG(LogLearning, Error, TEXT("Error receiving decoder from trainer: %s. Check log for errors."), Trainer::GetResponseString(Response));
				}
				break;
			}

			if (bDecoderNetworkUpdatedSignal)
			{
				*bDecoderNetworkUpdatedSignal = true;
			}
		}

		// Allow some time for trainer to shut down gracefully before we kill it...
		Response = ExternalTrainer->Wait();

		if (Response != ETrainerResponse::Success)
		{
			if (LogSettings != ELogSetting::Silent)
			{
				UE_LOG(LogLearning, Error, TEXT("Error waiting for trainer to exit: %s. Check log for errors."), Trainer::GetResponseString(Response));
			}
		}

		ExternalTrainer->Terminate();

		if (LogSettings != ELogSetting::Silent)
		{
			UE_LOG(LogLearning, Display, TEXT("Training Task Done!"));
		}

		return ETrainerResponse::Success;
	}
}
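// A minimal sketch of a call into UE::Learning::PPOTrainer::Train, for orientation only.
// Every name below is a placeholder for an object the caller is assumed to have already
// created; the trailing pointers are the optional stop-signal, lock, and updated-flag
// arguments, shown here passed as null:
//
//   const UE::Learning::ETrainerResponse Result = UE::Learning::PPOTrainer::Train(
//       ExternalTrainer, ReplayBuffer, EpisodeBuffer, ResetBuffer,
//       PolicyNetwork, CriticNetwork, EncoderNetwork, DecoderNetwork,
//       ObservationVectorBuffer, ActionVectorBuffer,
//       PreEvaluationMemoryStateVectorBuffer, MemoryStateVectorBuffer,
//       RewardBuffer, CompletionBuffer, EpisodeCompletionBuffer, AllCompletionBuffer,
//       ResetFunction, ObservationFunction, PolicyFunction, ActionFunction,
//       UpdateFunction, RewardFunction, CompletionFunction,
//       Instances, ObservationSchema, ObservationSchemaElement,
//       ActionSchema, ActionSchemaElement, TrainerSettings,
//       /*bRequestTrainingStopSignal*/ nullptr,
//       /*Network locks*/ nullptr, nullptr, nullptr, nullptr,
//       /*Network updated signals*/ nullptr, nullptr, nullptr, nullptr,
//       UE::Learning::ELogSetting::Normal);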