UnrealEngine/Engine/Plugins/Experimental/LearningAgents/Source/LearningTraining/Public/LearningPPOTrainer.h

// Copyright Epic Games, Inc. All Rights Reserved.

#pragma once

#include "LearningArray.h"
#include "LearningLog.h"
#include "LearningTrainer.h"
#include "LearningSharedMemory.h"

#include "Commandlets/Commandlet.h"
#include "Templates/SharedPointer.h"

#include "LearningPPOTrainer.generated.h"

class ULearningNeuralNetworkData;
UCLASS()
class LEARNINGTRAINING_API ULearningSocketPPOTrainerServerCommandlet : public UCommandlet
{
	GENERATED_BODY()

	ULearningSocketPPOTrainerServerCommandlet(const FObjectInitializer& ObjectInitializer);

	/** Runs the commandlet */
	virtual int32 Main(const FString& Params) override;
};
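
// Usage note (illustrative, not part of this header): like other UCommandlet subclasses, this
// server is typically launched from the command line via the editor binary using the standard
// "-run=" convention, e.g.
//
//     UnrealEditor-Cmd.exe MyProject.uproject -run=LearningSocketPPOTrainerServer
//
// The arguments the server itself expects are parsed inside Main and are not documented in this
// header, so treat the line above only as a sketch of the general commandlet invocation pattern.
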
namespace UE::Learning
{
	struct IExternalTrainer;
	struct FReplayBuffer;
	struct FResetInstanceBuffer;
	struct FEpisodeBuffer;

	enum class ECompletionMode : uint8;
	/**
	 * Settings used for training with PPO
	 */
	struct FPPOTrainerTrainingSettings
	{
		// Number of iterations to train the network for. Controls the overall training time.
		// Training for about 100000 iterations should give you a well-trained network, but
		// closer to 1000000 iterations or more is required for an exhaustively trained network.
		uint32 IterationNum = 1000000;

		// Learning rate of the policy network. Typical values are between 0.001f and 0.0001f
		float LearningRatePolicy = 0.0001f;

		// Learning rate of the critic network. To avoid instability, the critic should generally
		// have a larger learning rate than the policy.
		float LearningRateCritic = 0.001f;

		// Amount by which to multiply the learning rate every 1000 iterations.
		float LearningRateDecay = 1.0f;

		// Amount of weight decay to apply to the network. Larger values encourage network
		// weights to be smaller, but too large a value can cause the network weights to collapse to all zeros.
		float WeightDecay = 0.0001f;

		// Batch size to use for training the policy. Large batch sizes are much more computationally efficient when training on the GPU.
		uint32 PolicyBatchSize = 1024;

		// Batch size to use for training the critic. Large batch sizes are much more computationally efficient when training on the GPU.
		uint32 CriticBatchSize = 4096;

		// The number of consecutive steps of observations and actions over which to train the policy. Increasing this value
		// will encourage the policy to use its memory effectively. Too large and training can become slow and unstable.
		uint32 PolicyWindow = 16;
		// Number of training iterations to perform per buffer of experience gathered. This should be large enough for
		// the critic and policy to be effectively updated, but too large and it will simply slow down training.
		uint32 IterationsPerGather = 32;

		// Number of iterations of training to perform to warm-up the Critic. This helps speed up and stabilize training
		// at the beginning when the Critic may be producing predictions at the wrong order of magnitude.
		uint32 CriticWarmupIterations = 8;

		// Clipping ratio to apply to policy updates. Keeps the training "on-policy".
		// Larger values may speed up training at the cost of stability. Conversely, too small
		// a value can prevent the policy from learning an optimal policy.
		float EpsilonClip = 0.2f;

		// Weight used to regularize predicted returns. Encourages the critic not to over- or under-estimate returns.
		float ReturnRegularizationWeight = 0.0001f;

		// Weight for the loss used to train the policy via the PPO surrogate objective.
		float ActionSurrogateWeight = 1.0f;

		// Weight used to regularize actions. Larger values will encourage exploration and smaller actions, but too large a value will cause
		// noisy actions centered around zero.
		float ActionRegularizationWeight = 0.001f;

		// Weighting used for the entropy bonus. Larger values encourage larger action
		// noise and therefore greater exploration but can make actions very noisy.
		float ActionEntropyWeight = 0.0f;

		// This is used in Generalized Advantage Estimation, where larger values will tend to assign more credit to recent actions. Typical
		// values should be between 0.9 and 1.0.
		float GaeLambda = 0.95f;

		// When true, advantages are normalized. This tends to make training more robust to
		// adjustments of the scale of rewards.
		bool bAdvantageNormalization = true;

		// The minimum advantage to allow. Setting this below zero will encourage the policy to
		// move away from bad actions, but can introduce instability.
		float AdvantageMin = 0.0f;

		// The maximum advantage to allow. Making this smaller may increase training stability
		// at the cost of some training speed.
		float AdvantageMax = 10.0f;
		// If true, uses gradient norm max clipping. Set this to true if training is unstable, otherwise leave it false.
		bool bUseGradNormMaxClipping = false;

		// The maximum gradient norm to clip updates to.
		float GradNormMax = 0.5f;

		// Number of steps to trim from the start of each episode during training. This can
		// be useful if, for example, some reset process takes several steps or you know your
		// starting states are not entirely valid.
		int32 TrimEpisodeStartStepNum = 0;

		// Number of steps to trim from the end of each episode during training. This can be
		// useful if you know the last few steps of an episode are not valid or contain incorrect
		// information.
		int32 TrimEpisodeEndStepNum = 0;

		// Random seed to use for training.
		uint32 Seed = 1234;

		// The discount factor causes future rewards to be scaled down so that the policy will
		// favor near-term rewards over potentially uncertain long-term rewards. Larger values
		// encourage the system to "look ahead" but make training more difficult.
		float DiscountFactor = 0.99f;

		// Which device to use for training.
		ETrainerDevice Device = ETrainerDevice::GPU;

		// Whether to use TensorBoard for logging and tracking the training progress.
		//
		// TensorBoard will only work if it is installed in Unreal Engine's python environment. This can be done by
		// enabling the "Tensorboard" plugin in your project.
		bool bUseTensorboard = false;

		// Whether to save snapshots of the trained networks every 1000 iterations.
		bool bSaveSnapshots = false;
	};
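
	// Illustrative sketch (not part of this header): a caller might override a few of the defaults
	// above before starting training. The specific values below are assumptions chosen for the
	// example only, not recommendations from this file.
	//
	//     UE::Learning::FPPOTrainerTrainingSettings Settings;
	//     Settings.IterationNum = 100000;      // shorter run for an initial experiment
	//     Settings.CriticBatchSize = 8192;     // larger batches if GPU memory allows
	//     Settings.bUseTensorboard = true;     // requires the Tensorboard plugin to be enabled
	//     Settings.bSaveSnapshots = true;      // keep periodic snapshots of the networks
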
	namespace PPOTrainer
	{
		/**
		 * Train a policy while gathering experience
		 *
		 * @param ExternalTrainer External Trainer
		 * @param ReplayBuffer Replay Buffer
		 * @param EpisodeBuffer Episode Buffer
		 * @param ResetBuffer Reset Buffer
		 * @param PolicyNetwork Policy Network to use
		 * @param CriticNetwork Critic Network to use
		 * @param EncoderNetwork Encoder Network to use
		 * @param DecoderNetwork Decoder Network to use
		 * @param ObservationVectorBuffer Buffer to read/write observation vectors into
		 * @param ActionVectorBuffer Buffer to read/write action vectors into
		 * @param PreEvaluationMemoryStateVectorBuffer Buffer to read/write pre-evaluation memory state vectors into
		 * @param MemoryStateVectorBuffer Buffer to read/write (post-evaluation) memory state vectors into
		 * @param RewardBuffer Buffer to read/write rewards into
		 * @param CompletionBuffer Buffer to read/write completions into
		 * @param EpisodeCompletionBuffer Additional buffer to record completions from full episode buffers
		 * @param AllCompletionBuffer Additional buffer to record all completions from full episodes and normal completions
		 * @param ResetFunction Function to run for resetting the environment
		 * @param ObservationFunction Function to run for evaluating observations
		 * @param PolicyFunction Function to run for evaluating the policy
		 * @param ActionFunction Function to run for evaluating actions
		 * @param UpdateFunction Function to run for updating the environment
		 * @param RewardFunction Function to run for evaluating rewards
		 * @param CompletionFunction Function to run for evaluating completions
		 * @param Instances Set of instances to run training for
		 * @param ObservationSchema Observation schema
		 * @param ObservationSchemaElement Observation schema element
		 * @param ActionSchema Action schema
		 * @param ActionSchemaElement Action schema element
		 * @param TrainerSettings Settings to use for training
		 * @param bRequestTrainingStopSignal Optional signal that can be set to indicate training should be stopped
		 * @param PolicyNetworkLock Optional Lock to use when updating the policy network
		 * @param CriticNetworkLock Optional Lock to use when updating the critic network
		 * @param EncoderNetworkLock Optional Lock to use when updating the encoder network
		 * @param DecoderNetworkLock Optional Lock to use when updating the decoder network
		 * @param bPolicyNetworkUpdatedSignal Optional signal that will be set when the policy network is updated
		 * @param bCriticNetworkUpdatedSignal Optional signal that will be set when the critic network is updated
		 * @param bEncoderNetworkUpdatedSignal Optional signal that will be set when the encoder network is updated
		 * @param bDecoderNetworkUpdatedSignal Optional signal that will be set when the decoder network is updated
		 * @param LogSettings Logging settings
		 * @returns Trainer response in case of errors during communication, otherwise Success
		 */
		LEARNINGTRAINING_API ETrainerResponse Train(
			IExternalTrainer* ExternalTrainer,
			FReplayBuffer& ReplayBuffer,
			FEpisodeBuffer& EpisodeBuffer,
			FResetInstanceBuffer& ResetBuffer,
			ULearningNeuralNetworkData& PolicyNetwork,
			ULearningNeuralNetworkData& CriticNetwork,
			ULearningNeuralNetworkData& EncoderNetwork,
			ULearningNeuralNetworkData& DecoderNetwork,
			TLearningArrayView<2, float> ObservationVectorBuffer,
			TLearningArrayView<2, float> ActionVectorBuffer,
			TLearningArrayView<2, float> PreEvaluationMemoryStateVectorBuffer,
			TLearningArrayView<2, float> MemoryStateVectorBuffer,
			TLearningArrayView<1, float> RewardBuffer,
			TLearningArrayView<1, ECompletionMode> CompletionBuffer,
			TLearningArrayView<1, ECompletionMode> EpisodeCompletionBuffer,
			TLearningArrayView<1, ECompletionMode> AllCompletionBuffer,
			const TFunctionRef<void(const FIndexSet Instances)> ResetFunction,
			const TFunctionRef<void(const FIndexSet Instances)> ObservationFunction,
			const TFunctionRef<void(const FIndexSet Instances)> PolicyFunction,
			const TFunctionRef<void(const FIndexSet Instances)> ActionFunction,
			const TFunctionRef<void(const FIndexSet Instances)> UpdateFunction,
			const TFunctionRef<void(const FIndexSet Instances)> RewardFunction,
			const TFunctionRef<void(const FIndexSet Instances)> CompletionFunction,
			const FIndexSet Instances,
			const Observation::FSchema& ObservationSchema,
			const Observation::FSchemaElement& ObservationSchemaElement,
			const Action::FSchema& ActionSchema,
			const Action::FSchemaElement& ActionSchemaElement,
			const FPPOTrainerTrainingSettings& TrainerSettings = FPPOTrainerTrainingSettings(),
			TAtomic<bool>* bRequestTrainingStopSignal = nullptr,
			FRWLock* PolicyNetworkLock = nullptr,
			FRWLock* CriticNetworkLock = nullptr,
			FRWLock* EncoderNetworkLock = nullptr,
			FRWLock* DecoderNetworkLock = nullptr,
			TAtomic<bool>* bPolicyNetworkUpdatedSignal = nullptr,
			TAtomic<bool>* bCriticNetworkUpdatedSignal = nullptr,
			TAtomic<bool>* bEncoderNetworkUpdatedSignal = nullptr,
			TAtomic<bool>* bDecoderNetworkUpdatedSignal = nullptr,
			const ELogSetting LogSettings = ELogSetting::Normal);
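
		// Illustrative sketch (not from this header): the general shape of a call to PPOTrainer::Train.
		// All of the networks, buffers, schemas, instance set, and settings are assumed to have been
		// created and sized elsewhere; only the call structure and the per-instance callbacks are shown.
		//
		//     using namespace UE::Learning;
		//
		//     const ETrainerResponse Response = PPOTrainer::Train(
		//         &ExternalTrainer, ReplayBuffer, EpisodeBuffer, ResetBuffer,
		//         PolicyNetwork, CriticNetwork, EncoderNetwork, DecoderNetwork,
		//         ObservationVectorBuffer, ActionVectorBuffer,
		//         PreEvaluationMemoryStateVectorBuffer, MemoryStateVectorBuffer,
		//         RewardBuffer, CompletionBuffer, EpisodeCompletionBuffer, AllCompletionBuffer,
		//         [&](const FIndexSet Instances) { /* reset the environment for Instances */ },
		//         [&](const FIndexSet Instances) { /* write observations into ObservationVectorBuffer */ },
		//         [&](const FIndexSet Instances) { /* evaluate the policy network */ },
		//         [&](const FIndexSet Instances) { /* apply ActionVectorBuffer to the environment */ },
		//         [&](const FIndexSet Instances) { /* step the environment forward */ },
		//         [&](const FIndexSet Instances) { /* write rewards into RewardBuffer */ },
		//         [&](const FIndexSet Instances) { /* write completions into CompletionBuffer */ },
		//         Instances,
		//         ObservationSchema, ObservationSchemaElement,
		//         ActionSchema, ActionSchemaElement,
		//         Settings);
		//
		//     if (Response != ETrainerResponse::Success)
		//     {
		//         // Handle communication errors with the external trainer here.
		//     }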
	}
}