// Copyright Epic Games, Inc. All Rights Reserved. #pragma once #include "LearningAgentsManagerListener.h" #include "LearningAgentsCompletions.h" #include "LearningAgentsTrainingEnvironment.generated.h" namespace UE::Learning { struct FResetInstanceBuffer; } UCLASS(Abstract, HideDropdown, BlueprintType, Blueprintable) class LEARNINGAGENTSTRAINING_API ULearningAgentsTrainingEnvironment : public ULearningAgentsManagerListener { GENERATED_BODY() // ----- Setup ----- public: // These constructors/destructors are needed to make forward declarations happy ULearningAgentsTrainingEnvironment(); ULearningAgentsTrainingEnvironment(FVTableHelper& Helper); virtual ~ULearningAgentsTrainingEnvironment(); /** * Constructs the training environment and runs the setup functions for rewards and completions. */ UFUNCTION(BlueprintCallable, Category = "LearningAgents", meta = (DeterminesOutputType = "Class", AutoCreateRefTerm = "TrainerSettings")) static ULearningAgentsTrainingEnvironment* MakeTrainingEnvironment( UPARAM(ref) ULearningAgentsManager*& InManager, TSubclassOf Class, const FName Name = TEXT("TrainingEnvironment")); /** * Initializes the training environment and runs the setup functions for rewards and completions. */ UFUNCTION(BlueprintCallable, Category = "LearningAgents", meta = (AutoCreateRefTerm = "TrainerSettings")) void SetupTrainingEnvironment(UPARAM(ref) ULearningAgentsManager*& InManager); public: //~ Begin ULearningAgentsManagerListener Interface virtual void OnAgentsAdded_Implementation(const TArray& AgentIds) override; virtual void OnAgentsRemoved_Implementation(const TArray& AgentIds) override; virtual void OnAgentsReset_Implementation(const TArray& AgentIds) override; virtual void OnAgentsManagerTick_Implementation(const TArray& AgentIds, const float DeltaTime) override; //~ End ULearningAgentsManagerListener Interface // ----- Rewards ----- public: /** * This callback should be overridden by the Trainer and gathers the reward value for the given agent. * * @param OutReward Output reward for the given agent. * @param AgentId Agent id to gather reward for. */ UFUNCTION(BlueprintNativeEvent, Category = "LearningAgents", Meta = (ForceAsFunction)) void GatherAgentReward(float& OutReward, const int32 AgentId); /** * This callback can be overridden by the Trainer and gathers all the reward values for the given set of agents. By default this will call * GatherAgentReward on each agent. * * @param OutRewards Output rewards for each agent in AgentIds * @param AgentIds Agents to gather rewards for. */ UFUNCTION(BlueprintNativeEvent, Category = "LearningAgents", Meta = (ForceAsFunction)) void GatherAgentRewards(TArray& OutRewards, const TArray& AgentIds); // ----- Completions ----- public: /** * This callback should be overridden by the Trainer and gathers the completion for a given agent. * * @param OutCompletion Output completion for the given agent. * @param AgentId Agent id to gather completion for. */ UFUNCTION(BlueprintNativeEvent, Category = "LearningAgents", Meta = (ForceAsFunction)) void GatherAgentCompletion(ELearningAgentsCompletion& OutCompletion, const int32 AgentId); /** * This callback can be overridden by the Trainer and gathers all the completions for the given set of agents. By default this will call * GatherAgentCompletion on each agent. * * @param OutCompletions Output completions for each agent in AgentIds * @param AgentIds Agents to gather completions for. */ UFUNCTION(BlueprintNativeEvent, Category = "LearningAgents", Meta = (ForceAsFunction)) void GatherAgentCompletions(TArray& OutCompletions, const TArray& AgentIds); // ----- Resets ----- public: /** * This callback should be overridden by the Trainer and resets the episode for the given agent. * * @param AgentId The id of the agent that need resetting. */ UFUNCTION(BlueprintNativeEvent, Category = "LearningAgents", Meta = (ForceAsFunction)) void ResetAgentEpisode(const int32 AgentId); /** * This callback can be overridden by the Trainer and resets all episodes for each agent in the given set. By default this will call * ResetAgentEpisode on each agent. * * @param AgentIds The ids of the agents that need resetting. */ UFUNCTION(BlueprintNativeEvent, Category = "LearningAgents", Meta = (ForceAsFunction)) void ResetAgentEpisodes(const TArray& AgentIds); // ----- Training Process ----- public: /** * Call this function when it is time to evaluate the rewards for your agents. This should be done at the beginning * of each iteration of your training loop after the initial step, i.e. after taking an action, you want to get into * the next state before evaluating the rewards. */ UFUNCTION(BlueprintCallable, Category = "LearningAgents") void GatherRewards(); /** * Call this function when it is time to evaluate the completions for your agents. This should be done at the beginning * of each iteration of your training loop after the initial step, i.e. after taking an action, you want to get into * the next state before evaluating the completions. */ UFUNCTION(BlueprintCallable, Category = "LearningAgents") void GatherCompletions(); /** * Returns true if GatherRewards has been called and the reward already set for the given agent. */ UFUNCTION(BlueprintPure, Category = "LearningAgents", meta = (AgentId = "-1")) bool HasReward(const int32 AgentId) const; /** * Returns true if GatherCompletions has been called and the completion already set for the given agent. */ UFUNCTION(BlueprintPure, Category = "LearningAgents", meta = (AgentId = "-1")) bool HasCompletion(const int32 AgentId) const; /** * Gets the current reward for an agent. Should be called only after GatherRewards. * * @param AgentId The AgentId to look-up the reward for * @returns The reward */ UFUNCTION(BlueprintPure, Category = "LearningAgents", meta = (AgentId = "-1")) float GetReward(const int32 AgentId) const; /** * Gets the current completion for an agent. Should be called only after GatherCompletions. * * @param AgentId The AgentId to look-up the completion for * @returns The completion type */ UFUNCTION(BlueprintPure, Category = "LearningAgents", meta = (AgentId = "-1")) ELearningAgentsCompletion GetCompletion(const int32 AgentId) const; /** * Gets the current elapsed episode time for the given agent. * * @param AgentId The AgentId to look-up the episode time for * @returns The elapsed episode time */ UFUNCTION(BlueprintPure, Category = "LearningAgents", meta = (AgentId = "-1")) float GetEpisodeTime(const int32 AgentId) const; // ----- Non-blueprint public interface ----- public: /** Gets the rewards as a const array view. */ const TLearningArrayView<1, const float> GetRewardArrayView() const; /** Gets the reward iteration value for the given agent id. */ uint64 GetRewardIteration(const int32 AgentId) const; /** Gets the agent completion mode for the given agent id. */ UE::Learning::ECompletionMode GetAgentCompletion(const int32 AgentId) const; /** Gets the agent completions as a const array view. */ const TLearningArrayView<1, const UE::Learning::ECompletionMode> GetAgentCompletions() const; /** Gets the all completions as a const array view. */ const TLearningArrayView<1, const UE::Learning::ECompletionMode> GetAllCompletions() const; /** Computes a combined completion buffer for agents that have been completed manually and those which have reached the maximum episode length. */ void SetAllCompletions(UE::Learning::FIndexSet AgentSet); /** Gets the episode completions as a mutable array view. */ TLearningArrayView<1, UE::Learning::ECompletionMode> GetEpisodeCompletions(); /** Gets the completion iteration value for the given agent id. */ uint64 GetCompletionIteration(const int32 AgentId) const; UE::Learning::FResetInstanceBuffer& GetResetBuffer() const; // ----- Private Data ----- private: /** Callback Reward Output */ TArray RewardBuffer; /** Callback Completion Output */ TArray CompletionBuffer; /** Reward Buffer */ TLearningArray<1, float> Rewards; /** Agent Completions Buffer */ TLearningArray<1, UE::Learning::ECompletionMode> AgentCompletions; /** Episode Completions Buffer */ TLearningArray<1, UE::Learning::ECompletionMode> EpisodeCompletions; /** All Completions Buffer */ TLearningArray<1, UE::Learning::ECompletionMode> AllCompletions; /** Agent episode times */ TLearningArray<1, float> EpisodeTimes; TUniquePtr ResetBuffer; /** Number of times rewards have been evaluated for all agents */ TLearningArray<1, uint64, TInlineAllocator<32>> RewardIteration; /** Number of times completions have been evaluated for all agents */ TLearningArray<1, uint64, TInlineAllocator<32>> CompletionIteration; };