Files
UnrealEngine/Engine/Plugins/Experimental/LearningAgents/Source/LearningAgentsTraining/Public/LearningAgentsCompletions.h
2025-05-18 13:04:45 +08:00

259 lines
14 KiB
C++

// Copyright Epic Games, Inc. All Rights Reserved.
#pragma once
#include "Kismet/BlueprintFunctionLibrary.h"
#include "LearningAgentsCompletions.generated.h"
// Forward declaration: the listener is only used through pointers in this header
// (the VisualLoggerListener parameters below), so the full definition is not needed here.
class ULearningAgentsManagerListener;
namespace UE::Learning
{
// Opaque forward declaration of the core UE::Learning completion enum. Because the
// underlying type (uint8) is given, it can be passed and returned by value in the
// conversion function declarations below without including its defining header.
enum class ECompletionMode : uint8;
}
/**
 * Completion modes for episodes.
 *
 * Values can be queried and combined with the helpers in ULearningAgentsCompletions
 * (IsCompletion*, CompletionOr, CompletionAnd, CompletionNot) and converted to/from
 * UE::Learning::ECompletionMode with the functions in UE::Learning::Agents.
 */
UENUM(BlueprintType, Category = "LearningAgents", meta = (ScriptName = "LearningAgentsCompletionEnum"))
enum class ELearningAgentsCompletion : uint8
{
/** Episode is still running. */
Running UMETA(DisplayName = "Running"),
/** Episode ended while still in progress. The critic will be used to estimate the final return. */
Truncation UMETA(DisplayName = "Truncation"),
/** Episode ended and zero reward is expected for all future steps. */
Termination UMETA(DisplayName = "Termination"),
};
/** Conversion helpers between the Learning Agents completion enum and the core UE::Learning completion enum. */
namespace UE::Learning::Agents
{
/** Get the learning agents completion from the UE::Learning completion. */
LEARNINGAGENTSTRAINING_API ELearningAgentsCompletion GetLearningAgentsCompletion(const ECompletionMode CompletionMode);
/** Get the UE::Learning completion from the learning agents completion. */
LEARNINGAGENTSTRAINING_API ECompletionMode GetCompletionMode(const ELearningAgentsCompletion Completion);
}
/**
 * Blueprint function library of static helpers for querying, combining and creating
 * ELearningAgentsCompletion values.
 *
 * The Make* functions can optionally send debug data to the visual logger. This requires
 * bVisualLoggerEnabled to be true and VisualLoggerListener to be set.
 */
UCLASS(BlueprintType)
class LEARNINGAGENTSTRAINING_API ULearningAgentsCompletions : public UBlueprintFunctionLibrary
{
GENERATED_BODY()
public:
/** Returns true if a completion is running, otherwise false. */
UFUNCTION(BlueprintPure, Category = "LearningAgents")
static bool IsCompletionRunning(const ELearningAgentsCompletion Completion);
/** Returns true if a completion is either truncated or terminated (i.e. not running), otherwise false. */
UFUNCTION(BlueprintPure, Category = "LearningAgents")
static bool IsCompletionCompleted(const ELearningAgentsCompletion Completion);
/** Returns true if a completion is truncated, otherwise false. */
UFUNCTION(BlueprintPure, Category = "LearningAgents")
static bool IsCompletionTruncation(const ELearningAgentsCompletion Completion);
/** Returns true if a completion is terminated, otherwise false. */
UFUNCTION(BlueprintPure, Category = "LearningAgents")
static bool IsCompletionTermination(const ELearningAgentsCompletion Completion);
/**
 * Returns a termination if either input is a termination, otherwise a truncation if either input is a truncation,
 * otherwise returns running.
 */
UFUNCTION(BlueprintPure, Category = "LearningAgents", meta = (CommutativeAssociativeBinaryOperator, DisplayName="Completion OR", CompactNodeTitle = "OR"))
static ELearningAgentsCompletion CompletionOr(ELearningAgentsCompletion A, ELearningAgentsCompletion B);
/**
 * Returns a termination if both inputs are a termination, otherwise a truncation if both inputs are either a
 * truncation or termination, otherwise returns running.
 */
UFUNCTION(BlueprintPure, Category = "LearningAgents", meta = (CommutativeAssociativeBinaryOperator, DisplayName = "Completion AND", CompactNodeTitle="AND"))
static ELearningAgentsCompletion CompletionAnd(ELearningAgentsCompletion A, ELearningAgentsCompletion B);
/**
 * Returns running if the input A is either a termination or truncation. If A is running, returns the completion
 * specified by NotRunningType (a termination by default).
 */
UFUNCTION(BlueprintPure, Category = "LearningAgents", meta = (DisplayName = "Completion NOT", CompactNodeTitle = "NOT"))
static ELearningAgentsCompletion CompletionNot(ELearningAgentsCompletion A, ELearningAgentsCompletion NotRunningType = ELearningAgentsCompletion::Termination);
/**
* Make a completion.
*
* @param CompletionType The type of completion to make.
* @param Tag The tag for the completion. Used for debugging.
* @param bVisualLoggerEnabled When true, debug data will be sent to the visual logger.
* @param VisualLoggerListener The listener object which is making this completion. This must be set to use logging.
* @param VisualLoggerAgentId The agent id associated with this completion.
* @param VisualLoggerLocation A location for the visual logger information in the world.
* @param VisualLoggerColor The color for the visual logger display.
* @return The resulting completion.
*/
UFUNCTION(BlueprintPure, Category = "LearningAgents", meta = (AdvancedDisplay = 1, DefaultToSelf = "VisualLoggerListener"))
static ELearningAgentsCompletion MakeCompletion(
const ELearningAgentsCompletion CompletionType = ELearningAgentsCompletion::Termination,
const FName Tag = TEXT("Completion"),
const bool bVisualLoggerEnabled = false,
ULearningAgentsManagerListener* VisualLoggerListener = nullptr,
const int32 VisualLoggerAgentId = -1,
const FVector VisualLoggerLocation = FVector::ZeroVector,
const FLinearColor VisualLoggerColor = FLinearColor::Yellow);
/**
* Make a completion based on some condition.
*
* @param bCondition When true, returns the given CompletionType, otherwise returns Running.
* @param CompletionType The type of completion to make.
* @param Tag The tag for the completion. Used for debugging.
* @param bVisualLoggerEnabled When true, debug data will be sent to the visual logger.
* @param VisualLoggerListener The listener object which is making this completion. This must be set to use logging.
* @param VisualLoggerAgentId The agent id associated with this completion.
* @param VisualLoggerLocation A location for the visual logger information in the world.
* @param VisualLoggerColor The color for the visual logger display.
* @return The resulting completion.
*/
UFUNCTION(BlueprintPure, Category = "LearningAgents", meta = (AdvancedDisplay = 2, DefaultToSelf = "VisualLoggerListener"))
static ELearningAgentsCompletion MakeCompletionOnCondition(
const bool bCondition,
const ELearningAgentsCompletion CompletionType = ELearningAgentsCompletion::Termination,
const FName Tag = TEXT("ConditionCompletion"),
const bool bVisualLoggerEnabled = false,
ULearningAgentsManagerListener* VisualLoggerListener = nullptr,
const int32 VisualLoggerAgentId = -1,
const FVector VisualLoggerLocation = FVector::ZeroVector,
const FLinearColor VisualLoggerColor = FLinearColor::Yellow);
/**
* Make a completion when a time goes above a threshold.
*
* @param Time The current time.
* @param TimeThreshold The time threshold above which to complete with the given CompletionType.
* @param CompletionType The type of completion to make.
* @param Tag The tag for the completion. Used for debugging.
* @param bVisualLoggerEnabled When true, debug data will be sent to the visual logger.
* @param VisualLoggerListener The listener object which is making this completion. This must be set to use logging.
* @param VisualLoggerAgentId The agent id associated with this completion.
* @param VisualLoggerLocation A location for the visual logger information in the world.
* @param VisualLoggerColor The color for the visual logger display.
* @return The resulting completion.
*/
UFUNCTION(BlueprintPure, Category = "LearningAgents", meta = (AdvancedDisplay = 3, DefaultToSelf = "VisualLoggerListener"))
static ELearningAgentsCompletion MakeCompletionOnTimeElapsed(
const float Time,
const float TimeThreshold = 10.0f,
const ELearningAgentsCompletion CompletionType = ELearningAgentsCompletion::Truncation,
const FName Tag = TEXT("TimeElapsedCompletion"),
const bool bVisualLoggerEnabled = false,
ULearningAgentsManagerListener* VisualLoggerListener = nullptr,
const int32 VisualLoggerAgentId = -1,
const FVector VisualLoggerLocation = FVector::ZeroVector,
const FLinearColor VisualLoggerColor = FLinearColor::Yellow);
/**
* Make a completion when the number of episode steps recorded exceeds some threshold.
*
* @param EpisodeSteps The number of steps recorded.
* @param MaxEpisodeSteps The step threshold above which to complete with the given CompletionType.
* @param CompletionType The type of completion to make.
* @param Tag The tag for the completion. Used for debugging.
* @param bVisualLoggerEnabled When true, debug data will be sent to the visual logger.
* @param VisualLoggerListener The listener object which is making this completion. This must be set to use logging.
* @param VisualLoggerAgentId The agent id associated with this completion.
* @param VisualLoggerLocation A location for the visual logger information in the world.
* @param VisualLoggerColor The color for the visual logger display.
* @return The resulting completion.
*/
UFUNCTION(BlueprintPure, Category = "LearningAgents", meta = (AdvancedDisplay = 3, DefaultToSelf = "VisualLoggerListener"))
static ELearningAgentsCompletion MakeCompletionOnEpisodeStepsRecorded(
const int32 EpisodeSteps,
const int32 MaxEpisodeSteps = 64,
const ELearningAgentsCompletion CompletionType = ELearningAgentsCompletion::Truncation,
const FName Tag = TEXT("EpisodeStepsRecordedCompletion"),
const bool bVisualLoggerEnabled = false,
ULearningAgentsManagerListener* VisualLoggerListener = nullptr,
const int32 VisualLoggerAgentId = -1,
const FVector VisualLoggerLocation = FVector::ZeroVector,
const FLinearColor VisualLoggerColor = FLinearColor::Yellow);
/**
* Make a completion when the distance between two locations is below some threshold.
*
* @param LocationA The first location.
* @param LocationB The second location.
* @param DistanceThreshold The distance threshold below which to complete with the given CompletionType.
* @param CompletionType The type of completion to make.
* @param Tag The tag for the completion. Used for debugging.
* @param bVisualLoggerEnabled When true, debug data will be sent to the visual logger.
* @param VisualLoggerListener The listener object which is making this completion. This must be set to use logging.
* @param VisualLoggerAgentId The agent id associated with this completion.
* @param VisualLoggerLocation A location for the visual logger information in the world.
* @param VisualLoggerColor The color for the visual logger display.
* @return The resulting completion.
*/
UFUNCTION(BlueprintPure, Category = "LearningAgents", meta = (AdvancedDisplay = 4, DefaultToSelf = "VisualLoggerListener"))
static ELearningAgentsCompletion MakeCompletionOnLocationDifferenceBelowThreshold(
const FVector LocationA,
const FVector LocationB,
const float DistanceThreshold = 100.0f,
const ELearningAgentsCompletion CompletionType = ELearningAgentsCompletion::Termination,
const FName Tag = TEXT("LocationDifferenceBelowThresholdCompletion"),
const bool bVisualLoggerEnabled = false,
ULearningAgentsManagerListener* VisualLoggerListener = nullptr,
const int32 VisualLoggerAgentId = -1,
const FVector VisualLoggerLocation = FVector::ZeroVector,
const FLinearColor VisualLoggerColor = FLinearColor::Yellow);
/**
* Make a completion when the distance between two locations is above some threshold.
*
* @param LocationA The first location.
* @param LocationB The second location.
* @param DistanceThreshold The distance threshold above which to complete with the given CompletionType.
* @param CompletionType The type of completion to make.
* @param Tag The tag for the completion. Used for debugging.
* @param bVisualLoggerEnabled When true, debug data will be sent to the visual logger.
* @param VisualLoggerListener The listener object which is making this completion. This must be set to use logging.
* @param VisualLoggerAgentId The agent id associated with this completion.
* @param VisualLoggerLocation A location for the visual logger information in the world.
* @param VisualLoggerColor The color for the visual logger display.
* @return The resulting completion.
*/
UFUNCTION(BlueprintPure, Category = "LearningAgents", meta = (AdvancedDisplay = 4, DefaultToSelf = "VisualLoggerListener"))
static ELearningAgentsCompletion MakeCompletionOnLocationDifferenceAboveThreshold(
const FVector LocationA,
const FVector LocationB,
const float DistanceThreshold = 100.0f,
const ELearningAgentsCompletion CompletionType = ELearningAgentsCompletion::Termination,
const FName Tag = TEXT("LocationDifferenceAboveThresholdCompletion"),
const bool bVisualLoggerEnabled = false,
ULearningAgentsManagerListener* VisualLoggerListener = nullptr,
const int32 VisualLoggerAgentId = -1,
const FVector VisualLoggerLocation = FVector::ZeroVector,
const FLinearColor VisualLoggerColor = FLinearColor::Yellow);
/**
* Make a completion when a location moves outside of the given bounds.
*
* @param Location The location.
* @param BoundsTransform The transform of the bounds object.
* @param BoundsMins The minimums of the bounds object (presumably in the local space of BoundsTransform — confirm against the implementation).
* @param BoundsMaxs The maximums of the bounds object (presumably in the local space of BoundsTransform — confirm against the implementation).
* @param CompletionType The type of completion to make.
* @param Tag The tag for the completion. Used for debugging.
* @param bVisualLoggerEnabled When true, debug data will be sent to the visual logger.
* @param VisualLoggerListener The listener object which is making this completion. This must be set to use logging.
* @param VisualLoggerAgentId The agent id associated with this completion.
* @param VisualLoggerLocation A location for the visual logger information in the world.
* @param VisualLoggerColor The color for the visual logger display.
* @return The resulting completion.
*/
UFUNCTION(BlueprintPure, Category = "LearningAgents", meta = (AdvancedDisplay = 5, DefaultToSelf = "VisualLoggerListener"))
static ELearningAgentsCompletion MakeCompletionOnLocationOutsideBounds(
const FVector Location,
const FTransform BoundsTransform = FTransform(),
const FVector BoundsMins = FVector(-100.0f, -100.0f, -100.0f),
const FVector BoundsMaxs = FVector(+100.0f, +100.0f, +100.0f),
const ELearningAgentsCompletion CompletionType = ELearningAgentsCompletion::Termination,
const FName Tag = TEXT("LocationOutsideBoundsCompletion"),
const bool bVisualLoggerEnabled = false,
ULearningAgentsManagerListener* VisualLoggerListener = nullptr,
const int32 VisualLoggerAgentId = -1,
const FVector VisualLoggerLocation = FVector::ZeroVector,
const FLinearColor VisualLoggerColor = FLinearColor::Yellow);
};