Files
UnrealEngine/Engine/Plugins/Experimental/LearningAgents/Source/Learning/Private/LearningCritic.cpp
2025-05-18 13:04:45 +08:00

64 lines
2.1 KiB
C++

// Copyright Epic Games, Inc. All Rights Reserved.
#include "LearningCritic.h"
namespace UE::Learning
{
// Constructs a critic that evaluates the given neural network on rows made of an
// encoded observation vector followed by a memory state.
//
// InMaxInstanceNum         - number of rows allocated in the internal Input buffer; also forwarded to the FNeuralNetworkFunction.
// InObservationEncodedNum  - number of encoded observation values stored per instance.
// InMemoryStateNum         - number of memory state values stored per instance.
// InNeuralNetwork          - network to evaluate; asserted to accept ObservationEncodedNum + MemoryStateNum inputs and to produce exactly one output.
// InInferenceSettings      - forwarded unchanged to the FNeuralNetworkFunction.
FNeuralNetworkCritic::FNeuralNetworkCritic(
	const int32 InMaxInstanceNum,
	const int32 InObservationEncodedNum,
	const int32 InMemoryStateNum,
	const TSharedPtr<FNeuralNetwork>& InNeuralNetwork,
	const FNeuralNetworkInferenceSettings& InInferenceSettings)
	: ObservationEncodedNum(InObservationEncodedNum)
	, MemoryStateNum(InMemoryStateNum)
{
	// The network must consume one concatenated [observation | memory state] row and emit a single value.
	check(InNeuralNetwork->GetInputSize() == ObservationEncodedNum + MemoryStateNum);
	check(InNeuralNetwork->GetOutputSize() == 1);

	// Wrap the network so it can be evaluated for up to InMaxInstanceNum instances.
	NeuralNetworkFunction = MakeShared<FNeuralNetworkFunction>(InMaxInstanceNum, InNeuralNetwork, InInferenceSettings);

	// Width of a single row of the network input buffer.
	const int32 NetworkInputNum = ObservationEncodedNum + MemoryStateNum;

	// Allocate the input buffer without initialization, then clear it explicitly.
	Input.SetNumUninitialized({ InMaxInstanceNum, NetworkInputNum });
	Array::Zero(Input);
}
// Evaluates the critic network for the given set of instances.
//
// OutputReturns                   - receives one value per instance; passed to the network as a two-dimensional view with a single column.
// InputObservationVectorsEncoded  - per-instance encoded observation vectors, copied into the front of each network input row.
// InputMemoryState                - per-instance memory state, copied in directly after the observation values.
// Instances                       - the instance indices to process.
void FNeuralNetworkCritic::Evaluate(
	TLearningArrayView<1, float> OutputReturns,
	const TLearningArrayView<2, const float> InputObservationVectorsEncoded,
	const TLearningArrayView<2, const float> InputMemoryState,
	const FIndexSet Instances)
{
	TRACE_CPUPROFILER_EVENT_SCOPE(Learning::FNeuralNetworkCritic::Evaluate);

	// Run the array checks on both inputs before they are used
	Array::Check(InputObservationVectorsEncoded, Instances);
	Array::Check(InputMemoryState, Instances);

	// Pack each instance's row as [ encoded observation | memory state ]
	for (const int32 InstanceIdx : Instances)
	{
		auto NetworkInputRow = Input[InstanceIdx];

		Array::Copy(NetworkInputRow.Slice(0, ObservationEncodedNum), InputObservationVectorsEncoded[InstanceIdx]);
		Array::Copy(NetworkInputRow.Slice(ObservationEncodedNum, MemoryStateNum), InputMemoryState[InstanceIdx]);
	}

	// View the flat output as (Num, 1) so it matches the two-dimensional output the network function writes to
	TLearningArrayView<2, float> OutputReturnsColumn(OutputReturns.GetData(), { OutputReturns.Num(), 1 });

	// Run the network over the packed input rows
	NeuralNetworkFunction->Evaluate(OutputReturnsColumn, Input, Instances);

	// Run the array check on the produced values
	Array::Check(OutputReturns, Instances);
}
// Replaces the neural network used by this critic by forwarding NewNeuralNetwork
// to the underlying FNeuralNetworkFunction.
// NOTE(review): unlike the constructor, no input/output size checks are made here -
// confirm whether FNeuralNetworkFunction::UpdateNeuralNetwork validates the new network.
void FNeuralNetworkCritic::UpdateNeuralNetwork(const TSharedPtr<FNeuralNetwork>& NewNeuralNetwork)
{
NeuralNetworkFunction->UpdateNeuralNetwork(NewNeuralNetwork);
}
// Returns a const reference to the shared pointer of the neural network, as
// reported by the underlying FNeuralNetworkFunction.
const TSharedPtr<FNeuralNetwork>& FNeuralNetworkCritic::GetNeuralNetwork() const
{
return NeuralNetworkFunction->GetNeuralNetwork();
}
}