// UnrealEngine/Engine/Plugins/Experimental/LearningAgents/Source/LearningAgentsTraining/Public/LearningAgentsTrainingEnvironment.h
// Copyright Epic Games, Inc. All Rights Reserved.
#pragma once
#include "LearningAgentsManagerListener.h"
#include "LearningAgentsCompletions.h"
#include "LearningAgentsTrainingEnvironment.generated.h"
namespace UE::Learning
{
struct FResetInstanceBuffer;
}
UCLASS(Abstract, HideDropdown, BlueprintType, Blueprintable)
class LEARNINGAGENTSTRAINING_API ULearningAgentsTrainingEnvironment : public ULearningAgentsManagerListener
{
GENERATED_BODY()
// ----- Setup -----
public:
// These constructors/destructors are needed to make forward declarations happy
ULearningAgentsTrainingEnvironment();
ULearningAgentsTrainingEnvironment(FVTableHelper& Helper);
virtual ~ULearningAgentsTrainingEnvironment();
/**
* Constructs the training environment and runs the setup functions for rewards and completions.
*/
UFUNCTION(BlueprintCallable, Category = "LearningAgents", meta = (DeterminesOutputType = "Class", AutoCreateRefTerm = "TrainerSettings"))
static ULearningAgentsTrainingEnvironment* MakeTrainingEnvironment(
UPARAM(ref) ULearningAgentsManager*& InManager,
TSubclassOf<ULearningAgentsTrainingEnvironment> Class,
const FName Name = TEXT("TrainingEnvironment"));
/**
* Initializes the training environment and runs the setup functions for rewards and completions.
*/
UFUNCTION(BlueprintCallable, Category = "LearningAgents", meta = (AutoCreateRefTerm = "TrainerSettings"))
void SetupTrainingEnvironment(UPARAM(ref) ULearningAgentsManager*& InManager);
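// Example (illustrative sketch, not part of this header): constructing a training environment from native
// code. "UMyTrainingEnvironment" and the Manager variable are hypothetical; only the MakeTrainingEnvironment
// signature declared above is used.
//
//   ULearningAgentsManager* Manager = ...; // the manager your agents were added to
//   ULearningAgentsTrainingEnvironment* Environment = ULearningAgentsTrainingEnvironment::MakeTrainingEnvironment(
//       Manager, UMyTrainingEnvironment::StaticClass(), TEXT("MyTrainingEnvironment"));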
public:
//~ Begin ULearningAgentsManagerListener Interface
virtual void OnAgentsAdded_Implementation(const TArray<int32>& AgentIds) override;
virtual void OnAgentsRemoved_Implementation(const TArray<int32>& AgentIds) override;
virtual void OnAgentsReset_Implementation(const TArray<int32>& AgentIds) override;
virtual void OnAgentsManagerTick_Implementation(const TArray<int32>& AgentIds, const float DeltaTime) override;
//~ End ULearningAgentsManagerListener Interface
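// Example (illustrative sketch, not part of this header): subclasses commonly extend these listener
// callbacks to cache per-agent state used by the reward, completion, and reset overrides below.
// "UMyTrainingEnvironment", its "AgentActors" TMap<int32, AActor*>, and "SpawnTransforms"
// TMap<int32, FTransform> are hypothetical, and GetAgent(...) is assumed to be the id-to-object lookup
// inherited from ULearningAgentsManagerListener; adapt the lookup to your project.
//
//   void UMyTrainingEnvironment::OnAgentsAdded_Implementation(const TArray<int32>& AgentIds)
//   {
//       Super::OnAgentsAdded_Implementation(AgentIds);
//       for (const int32 AgentId : AgentIds)
//       {
//           AActor* AgentActor = Cast<AActor>(GetAgent(AgentId, AActor::StaticClass()));
//           AgentActors.Add(AgentId, AgentActor);
//           SpawnTransforms.Add(AgentId, AgentActor ? AgentActor->GetActorTransform() : FTransform::Identity);
//       }
//   }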
// ----- Rewards -----
public:
/**
* This callback should be overridden by the Training Environment to gather the reward value for the given agent.
*
* @param OutReward Output reward for the given agent.
* @param AgentId Agent id to gather reward for.
*/
UFUNCTION(BlueprintNativeEvent, Category = "LearningAgents", Meta = (ForceAsFunction))
void GatherAgentReward(float& OutReward, const int32 AgentId);
/**
* This callback can be overridden by the Training Environment to gather all the reward values for the given set of
* agents. By default this will call GatherAgentReward on each agent.
*
* @param OutRewards Output rewards for each agent in AgentIds.
* @param AgentIds Agents to gather rewards for.
*/
UFUNCTION(BlueprintNativeEvent, Category = "LearningAgents", Meta = (ForceAsFunction))
void GatherAgentRewards(TArray<float>& OutRewards, const TArray<int32>& AgentIds);
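// Example (illustrative sketch, not part of this header): a native override of the per-agent reward
// callback. The distance-based reward, "GoalLocation", and "AgentActors" (see the listener sketch above)
// are hypothetical; only the _Implementation signature follows from the BlueprintNativeEvent declared above.
//
//   void UMyTrainingEnvironment::GatherAgentReward_Implementation(float& OutReward, const int32 AgentId)
//   {
//       // Hypothetical shaping: reward approaches 1 as the agent nears the goal and falls toward 0 far away.
//       const AActor* AgentActor = AgentActors.FindRef(AgentId);
//       OutReward = AgentActor ? FMath::Exp(-FVector::Dist(AgentActor->GetActorLocation(), GoalLocation) / 1000.0f) : 0.0f;
//   }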
// ----- Completions -----
public:
/**
* This callback should be overridden by the Training Environment to gather the completion for the given agent.
*
* @param OutCompletion Output completion for the given agent.
* @param AgentId Agent id to gather completion for.
*/
UFUNCTION(BlueprintNativeEvent, Category = "LearningAgents", Meta = (ForceAsFunction))
void GatherAgentCompletion(ELearningAgentsCompletion& OutCompletion, const int32 AgentId);
/**
* This callback can be overridden by the Training Environment to gather all the completions for the given set of
* agents. By default this will call GatherAgentCompletion on each agent.
*
* @param OutCompletions Output completions for each agent in AgentIds.
* @param AgentIds Agents to gather completions for.
*/
UFUNCTION(BlueprintNativeEvent, Category = "LearningAgents", Meta = (ForceAsFunction))
void GatherAgentCompletions(TArray<ELearningAgentsCompletion>& OutCompletions, const TArray<int32>& AgentIds);
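// Example (illustrative sketch, not part of this header): a native override of the per-agent completion
// callback. The kill-plane condition and "AgentActors" are hypothetical, and the enum entries shown
// (Running / Termination) are assumed from ELearningAgentsCompletion in LearningAgentsCompletions.h.
//
//   void UMyTrainingEnvironment::GatherAgentCompletion_Implementation(ELearningAgentsCompletion& OutCompletion, const int32 AgentId)
//   {
//       // Hypothetical: terminate the episode if the agent has fallen below a kill plane.
//       const AActor* AgentActor = AgentActors.FindRef(AgentId);
//       OutCompletion = (AgentActor && AgentActor->GetActorLocation().Z < -1000.0f)
//           ? ELearningAgentsCompletion::Termination
//           : ELearningAgentsCompletion::Running;
//   }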
// ----- Resets -----
public:
/**
* This callback should be overridden by the Training Environment to reset the episode for the given agent.
*
* @param AgentId The id of the agent that needs resetting.
*/
UFUNCTION(BlueprintNativeEvent, Category = "LearningAgents", Meta = (ForceAsFunction))
void ResetAgentEpisode(const int32 AgentId);
/**
* This callback can be overridden by the Training Environment to reset the episodes of all agents in the given set.
* By default this will call ResetAgentEpisode on each agent.
*
* @param AgentIds The ids of the agents that need resetting.
*/
UFUNCTION(BlueprintNativeEvent, Category = "LearningAgents", Meta = (ForceAsFunction))
void ResetAgentEpisodes(const TArray<int32>& AgentIds);
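// Example (illustrative sketch, not part of this header): a native override of the per-agent reset
// callback. "AgentActors" and "SpawnTransforms" are the hypothetical maps from the listener sketch above.
//
//   void UMyTrainingEnvironment::ResetAgentEpisode_Implementation(const int32 AgentId)
//   {
//       // Hypothetical: move the agent back to the transform recorded when it was added.
//       if (AActor* AgentActor = AgentActors.FindRef(AgentId))
//       {
//           AgentActor->SetActorTransform(SpawnTransforms.FindRef(AgentId));
//       }
//   }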
// ----- Training Process -----
public:
/**
* Call this function when it is time to evaluate the rewards for your agents. This should be done at the beginning
* of each iteration of your training loop, after the initial step; in other words, after taking an action, let the
* agents reach the next state before evaluating the rewards.
*/
UFUNCTION(BlueprintCallable, Category = "LearningAgents")
void GatherRewards();
/**
* Call this function when it is time to evaluate the completions for your agents. This should be done at the beginning
* of each iteration of your training loop, after the initial step; in other words, after taking an action, let the
* agents reach the next state before evaluating the completions.
*/
UFUNCTION(BlueprintCallable, Category = "LearningAgents")
void GatherCompletions();
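// Example (illustrative sketch, not part of this header): where these calls sit in a manually driven
// training tick. "AMyTrainingManager" and its "TrainingEnvironment" member are hypothetical; the
// surrounding observation/policy/trainer steps are only indicated as comments.
//
//   void AMyTrainingManager::Tick(float DeltaTime)
//   {
//       Super::Tick(DeltaTime);
//       // Rewards and completions describe the transition caused by the previous action, so gather
//       // them at the start of the iteration, before choosing the next action.
//       TrainingEnvironment->GatherRewards();
//       TrainingEnvironment->GatherCompletions();
//       // ... gather observations, run the policy, perform actions, and process experience here ...
//   }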
/**
* Returns true if GatherRewards has been called and the reward has already been set for the given agent.
*/
UFUNCTION(BlueprintPure, Category = "LearningAgents", meta = (AgentId = "-1"))
bool HasReward(const int32 AgentId) const;
/**
* Returns true if GatherCompletions has been called and the completion has already been set for the given agent.
*/
UFUNCTION(BlueprintPure, Category = "LearningAgents", meta = (AgentId = "-1"))
bool HasCompletion(const int32 AgentId) const;
/**
* Gets the current reward for an agent. Should be called only after GatherRewards.
*
* @param AgentId The AgentId to look up the reward for
* @returns The reward
*/
UFUNCTION(BlueprintPure, Category = "LearningAgents", meta = (AgentId = "-1"))
float GetReward(const int32 AgentId) const;
/**
* Gets the current completion for an agent. Should be called only after GatherCompletions.
*
* @param AgentId The AgentId to look up the completion for
* @returns The completion type
*/
UFUNCTION(BlueprintPure, Category = "LearningAgents", meta = (AgentId = "-1"))
ELearningAgentsCompletion GetCompletion(const int32 AgentId) const;
/**
* Gets the current elapsed episode time for the given agent.
*
* @param AgentId The AgentId to look up the episode time for
* @returns The elapsed episode time
*/
UFUNCTION(BlueprintPure, Category = "LearningAgents", meta = (AgentId = "-1"))
float GetEpisodeTime(const int32 AgentId) const;
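// Example (illustrative sketch, not part of this header): reading back per-agent values for logging.
// "LogAgentProgress" is a hypothetical free function; it only uses the accessors declared above.
//
//   void LogAgentProgress(const ULearningAgentsTrainingEnvironment* Environment, const TArray<int32>& AgentIds)
//   {
//       for (const int32 AgentId : AgentIds)
//       {
//           if (Environment->HasReward(AgentId) && Environment->HasCompletion(AgentId))
//           {
//               UE_LOG(LogTemp, Display, TEXT("Agent %d: reward %.3f, episode time %.1fs"),
//                   AgentId, Environment->GetReward(AgentId), Environment->GetEpisodeTime(AgentId));
//           }
//       }
//   }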
// ----- Non-blueprint public interface -----
public:
/** Gets the rewards as a const array view. */
const TLearningArrayView<1, const float> GetRewardArrayView() const;
/** Gets the reward iteration value for the given agent id. */
uint64 GetRewardIteration(const int32 AgentId) const;
/** Gets the agent completion mode for the given agent id. */
UE::Learning::ECompletionMode GetAgentCompletion(const int32 AgentId) const;
/** Gets the agent completions as a const array view. */
const TLearningArrayView<1, const UE::Learning::ECompletionMode> GetAgentCompletions() const;
/** Gets the combined completions for all agents as a const array view. */
const TLearningArrayView<1, const UE::Learning::ECompletionMode> GetAllCompletions() const;
/** Computes the combined completion buffer, covering agents that have been completed manually and those that have reached the maximum episode length. */
void SetAllCompletions(UE::Learning::FIndexSet AgentSet);
/** Gets the episode completions as a mutable array view. */
TLearningArrayView<1, UE::Learning::ECompletionMode> GetEpisodeCompletions();
/** Gets the completion iteration value for the given agent id. */
uint64 GetCompletionIteration(const int32 AgentId) const;
/** Gets the reset instance buffer. */
UE::Learning::FResetInstanceBuffer& GetResetBuffer() const;
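// Example (illustrative sketch, not part of this header): native systems can use the iteration counters
// to detect stale values. "LastSeenRewardIteration" is a hypothetical TMap<int32, uint64> kept by the
// caller; only GetRewardIteration comes from the interface above.
//
//   const uint64 Iteration = Environment->GetRewardIteration(AgentId);
//   const bool bRewardIsFresh = Iteration != LastSeenRewardIteration.FindRef(AgentId);
//   LastSeenRewardIteration.Add(AgentId, Iteration);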
// ----- Private Data -----
private:
/** Callback Reward Output */
TArray<float> RewardBuffer;
/** Callback Completion Output */
TArray<ELearningAgentsCompletion> CompletionBuffer;
/** Reward Buffer */
TLearningArray<1, float> Rewards;
/** Agent Completions Buffer */
TLearningArray<1, UE::Learning::ECompletionMode> AgentCompletions;
/** Episode Completions Buffer */
TLearningArray<1, UE::Learning::ECompletionMode> EpisodeCompletions;
/** All Completions Buffer */
TLearningArray<1, UE::Learning::ECompletionMode> AllCompletions;
/** Agent episode times */
TLearningArray<1, float> EpisodeTimes;
/** Reset Instance Buffer */
TUniquePtr<UE::Learning::FResetInstanceBuffer> ResetBuffer;
/** Number of times rewards have been evaluated, tracked per agent */
TLearningArray<1, uint64, TInlineAllocator<32>> RewardIteration;
/** Number of times completions have been evaluated, tracked per agent */
TLearningArray<1, uint64, TInlineAllocator<32>> CompletionIteration;
};