Files
UnrealEngine/Engine/Plugins/Experimental/LearningAgents/Source/Learning/Public/LearningOptimizationPlanner.h
2025-05-18 13:04:45 +08:00

135 lines
6.3 KiB
C++

// Copyright Epic Games, Inc. All Rights Reserved.
#pragma once
#include "LearningArray.h"
#include "LearningLog.h"
namespace UE::Learning
{
struct IOptimizer;
struct FProgress;
/**
* Basic buffer containing some additional data used during optimization planning.
*
* Holds a rank-3 float array of samples and a rank-1 float array of losses.
*/
struct LEARNING_API FOptimizationPlannerBuffer
{
/**
* Resize the buffer for the given dimensions. The definition is not part of this header.
*
* @param SampleNum Number of samples
* @param StepNum Number of steps
* @param ActionVectorDimensionNum Number of dimensions in an action vector
*/
void Resize(
const int32 SampleNum,
const int32 StepNum,
const int32 ActionVectorDimensionNum);
// NOTE(review): presumably of shape (SampleNum, StepNum, ActionVectorDimensionNum) - confirm against the definition of Resize.
TLearningArray<3, float> Samples;
// NOTE(review): presumably one loss per sample, i.e. of shape (SampleNum) - confirm against the definition of Resize.
TLearningArray<1, float> Losses;
};
namespace OptimizationPlanner
{
/**
* Runs a sequence of action vectors for a set of instances.
*
* Each of the callback functions takes an FIndexSet of instances and returns nothing.
*
* @param ActionVectorBuffer Buffer to write action vectors into
* @param ResetFunction Function to run for resetting the environment
* @param ActionFunction Function to run for evaluating actions
* @param UpdateFunction Function to run for updating the environment
* @param ActionVectors Action Vectors of shape (StepNum, InstanceNum, ActionVectorDimNum)
* @param Instances Set of instances to run action vectors for
*/
LEARNING_API void RunPlan(
TLearningArrayView<2, float> ActionVectorBuffer,
const TFunctionRef<void(const FIndexSet Instances)> ResetFunction,
const TFunctionRef<void(const FIndexSet Instances)> ActionFunction,
const TFunctionRef<void(const FIndexSet Instances)> UpdateFunction,
const TLearningArrayView<3, const float> ActionVectors,
const FIndexSet Instances);
/**
* Run the optimization based planner on the provided action vectors to maximize the reward.
*
* Each of the callback functions takes an FIndexSet of instances and returns nothing.
*
* @param InOutActionVectors Action vectors to adjust of shape (StepNum, ActionVectorDimNum)
* @param OptimizationPlannerBuffer Buffer for the optimization planner
* @param Optimizer Optimizer to use
* @param ActionVectorBuffer Buffer to read/write action vectors into
* @param RewardBuffer Buffer to read/write rewards into
* @param IterationNum Number of iterations to run the planner for
* @param ResetFunction Function to run for resetting the environment
* @param ActionFunction Function to run for evaluating actions
* @param UpdateFunction Function to run for updating the environment
* @param RewardFunction Function to run for evaluating rewards
* @param SampleInstances Instances to use for sampling during the optimization
* @param LogSettings Log settings to use. Defaults to ELogSetting::Normal.
* @param Progress Optional progress object to record progress in. Defaults to nullptr.
* @param ActionVectorsLock Optional lock to use when updating InOutActionVectors. Defaults to nullptr.
* @param bActionVectorsUpdatedFlag Optional signal to set when updating InOutActionVectors. Defaults to nullptr.
*/
LEARNING_API void Plan(
TLearningArrayView<2, float> InOutActionVectors,
FOptimizationPlannerBuffer& OptimizationPlannerBuffer,
IOptimizer& Optimizer,
TLearningArrayView<2, float> ActionVectorBuffer,
TLearningArrayView<1, float> RewardBuffer,
const int32 IterationNum,
const TFunctionRef<void(const FIndexSet Instances)> ResetFunction,
const TFunctionRef<void(const FIndexSet Instances)> ActionFunction,
const TFunctionRef<void(const FIndexSet Instances)> UpdateFunction,
const TFunctionRef<void(const FIndexSet Instances)> RewardFunction,
const FIndexSet SampleInstances,
const ELogSetting LogSettings = ELogSetting::Normal,
FProgress* Progress = nullptr,
FRWLock* ActionVectorsLock = nullptr,
TAtomic<bool>* bActionVectorsUpdatedFlag = nullptr);
/**
* Run the optimization based planner windowed on the provided action vectors to maximize the reward.
*
* In windowed mode the optimization planner will only optimize a fixed window of steps over the action
* vectors at a time. This gives the optimizer an easier job, but limits how far the system can plan ahead.
*
* Most of the callback functions take an FIndexSet of instances; ResetFromInstanceFunction additionally
* takes the index of a single instance.
*
* @param InOutActionVectors Action vectors to adjust of shape (StepNum, ActionVectorDimNum)
* @param OptimizationPlannerBuffer Buffer for the optimization planner
* @param Optimizer Optimizer to use
* @param ActionVectorBuffer Buffer to read/write action vectors into
* @param RewardBuffer Buffer to read/write rewards into
* @param InitialIterationsNum Number of iterations to run the planner for on the initial window
* @param FinalIterationsNum Number of iterations to run the planner for on the final window
* @param WindowIterationsNum Number of iterations to run the planner for on the intermediate windows
* @param WindowStepNum Number of steps to consider in the window
* @param ResetFunction Function to run for resetting the environment
* @param ResetFromInstanceFunction Function to run to reset the planning instances from the state of a single instance
* @param ActionFunction Function to run for evaluating actions
* @param UpdateFunction Function to run for updating the environment
* @param RewardFunction Function to run for evaluating rewards
* @param PlanInstance The instance used to execute the plan. Should not be included in the SampleInstances.
* @param SampleInstances Instances to use for sampling during the optimization
* @param LogSettings Log settings to use. Defaults to ELogSetting::Normal.
* @param Progress Optional progress object to record progress in. Defaults to nullptr.
* @param ActionVectorsLock Optional lock to use when updating InOutActionVectors. Defaults to nullptr.
* @param bActionVectorsUpdatedFlag Optional signal to set when updating InOutActionVectors. Defaults to nullptr.
*/
LEARNING_API void PlanWindowed(
TLearningArrayView<2, float> InOutActionVectors,
FOptimizationPlannerBuffer& OptimizationPlannerBuffer,
IOptimizer& Optimizer,
TLearningArrayView<2, float> ActionVectorBuffer,
TLearningArrayView<1, float> RewardBuffer,
const int32 InitialIterationsNum,
const int32 FinalIterationsNum,
const int32 WindowIterationsNum,
const int32 WindowStepNum,
const TFunctionRef<void(const FIndexSet Instances)> ResetFunction,
const TFunctionRef<void(const FIndexSet Instances, const int32 Instance)> ResetFromInstanceFunction,
const TFunctionRef<void(const FIndexSet Instances)> ActionFunction,
const TFunctionRef<void(const FIndexSet Instances)> UpdateFunction,
const TFunctionRef<void(const FIndexSet Instances)> RewardFunction,
const int32 PlanInstance,
const FIndexSet SampleInstances,
const ELogSetting LogSettings = ELogSetting::Normal,
FProgress* Progress = nullptr,
FRWLock* ActionVectorsLock = nullptr,
TAtomic<bool>* bActionVectorsUpdatedFlag = nullptr);
}
}