Files
UnrealEngine/Engine/Plugins/Experimental/LearningAgents/Source/LearningTraining/Public/LearningTrainer.h
2025-05-18 13:04:45 +08:00

215 lines
5.2 KiB
C++

// Copyright Epic Games, Inc. All Rights Reserved.
#pragma once
#include "LearningArray.h"
#include "LearningLog.h"
#include "Templates/SharedPointer.h"
#include "HAL/PlatformProcess.h"
class FJsonObject;
namespace UE::Learning
{
namespace Observation
{
struct FSchema;
struct FSchemaElement;
}
namespace Action
{
struct FSchema;
struct FSchemaElement;
}
/**
* Training Device
*/
enum class ETrainerDevice : uint8
{
CPU = 0,
GPU = 1,
};
/**
* Type of response from a Trainer
*/
enum class ETrainerResponse : uint8
{
// The communication was successful
Success = 0,
// The communication send or received was unexpected
Unexpected = 1,
// Training is complete
Completed = 2,
// Training is stopped
Stopped = 3,
// The communication timed-out
Timeout = 4,
// The communication timed-out for a network signal
NetworkSignalTimeout = 5,
};
/**
* Subprocess flags
*/
enum class ESubprocessFlags : uint8
{
None = 0,
// If to show the sub-process console window
ShowWindow = 1 << 0,
// If to avoid redirecting the sub-process output to the output log
NoRedirectOutput = 1 << 1,
};
ENUM_CLASS_FLAGS(ESubprocessFlags)
/**
* Simple managed subprocess similar to FMonitoredProcess
*/
struct LEARNINGTRAINING_API FSubprocess
{
/** Will Terminate the subprocess if it is running. */
~FSubprocess();
/**
* Launches a new subprocess.
*
* @param Path The path of the executable to launch.
* @param Params The command line parameters.
* @param Flags Subprocess flags.
* @returns true if the launch was successful, otherwise false
*/
bool Launch(const FString& Path, const FString& Params, const ESubprocessFlags Flags);
/**
* Return true if the Subprocess is launched and running, otherwise false.
*/
bool IsRunning() const;
/**
* Terminates the subprocess
*/
void Terminate();
/**
* Outputs anything the subprocess has written to stdout to the log line-by-line and returns true if the subprocess is still running.
*/
bool Update();
private:
// Buffer for the subprocess' stdout
FString OutputBuffer;
// If a subprocess has been launched
bool bIsLaunched = false;
// Subprocess handle
FProcHandle ProcessHandle;
// Read pipe for subprocess stdout
void* ReadPipe = nullptr;
// Write pipe for subprocess stdin
void* WritePipe = nullptr;
};
namespace Trainer
{
/**
* Default Timeout to use during communication.
*/
static constexpr float DefaultTimeout = 10.0f;
/**
* Default Log Settings to use during communication.
*/
static constexpr ELogSetting DefaultLogSettings = ELogSetting::Normal;
/**
* Default IP to use for networked training
*/
static constexpr const TCHAR* DefaultIp = TEXT("127.0.0.1");
/**
* Default Port to use for networked training
*/
static constexpr uint32 DefaultPort = 48491;
/**
* Converts a ETrainerDevice into a string.
*/
LEARNINGTRAINING_API const TCHAR* GetDeviceString(const ETrainerDevice Device);
/**
* Converts a ETrainerResponse into a string for use in logging and error messages.
*/
LEARNINGTRAINING_API const TCHAR* GetResponseString(const ETrainerResponse Response);
/**
* Compute the discount factor that corresponds to a particular HalfLife and DeltaTime.
*
* @param HalfLife Time by which the reward should be discounted by half
* @param DeltaTime DeltaTime taken upon each step of the environment
* @returns Corresponding discount factor
*/
LEARNINGTRAINING_API float DiscountFactorFromHalfLife(const float HalfLife, const float DeltaTime);
/**
* Compute the discount factor that corresponds to a particular HalfLife provided in terms of number of steps
*
* @param HalfLifeSteps Number of steps taken at which the reward should be discounted by half
* @returns Corresponding discount factor
*/
LEARNINGTRAINING_API float DiscountFactorFromHalfLifeSteps(const int32 HalfLifeSteps);
/**
* Gets the python executable path from the engine directory.
*/
LEARNINGTRAINING_API FString GetPythonExecutablePath(const FString& EngineDir);
/**
* Gets the PythonFoundationPackages site-packages path from the engine directory.
*/
LEARNINGTRAINING_API FString GetSitePackagesPath(const FString& EngineDir);
/**
* Gets the LearningAgents Content path from the engine directory.
*/
LEARNINGTRAINING_API FString GetPythonContentPath(const FString& EngineDir);
/**
* Gets the project's Python Content path.
*/
LEARNINGTRAINING_API FString GetProjectPythonContentPath();
/**
* Gets the LearningAgents Intermediate path from the intermediate directory.
*/
LEARNINGTRAINING_API FString GetIntermediatePath(const FString& IntermediateDir);
/**
* Converts an observation schema element into a JSON representation.
*/
LEARNINGTRAINING_API TSharedPtr<FJsonObject> ConvertObservationSchemaToJSON(
const Observation::FSchema& ObservationSchema,
const Observation::FSchemaElement& ObservationSchemaElement);
/**
* Converts an action schema element into a JSON representation.
*/
LEARNINGTRAINING_API TSharedPtr<FJsonObject> ConvertActionSchemaToJSON(
const Action::FSchema& ActionSchema,
const Action::FSchemaElement& ActionSchemaElement);
}
}