Files
UnrealEngine/Engine/Plugins/Experimental/NeuralRendering/Source/NeuralPostProcessing/Private/NeuralPostProcessModelInstance.h
2025-05-18 13:04:45 +08:00

102 lines
3.1 KiB
C++

// Copyright Epic Games, Inc. All Rights Reserved.
#pragma once
#include "NNE.h"
#include "NNETypes.h"
#include "NNERuntimeRDG.h"
#include "NNEModelData.h"
#include "NeuralPostProcessModelInstance.generated.h"
#if WITH_EDITOR
DECLARE_LOG_CATEGORY_EXTERN(LogNeuralPostProcessing, Log, All);
#endif
//Util functions
TSharedPtr<UE::NNE::IModelInstanceRDG> CreateNNEModelInstance(UNNEModelData* NNEModelData, FString RuntimeName);
enum class ENeuralModelTileType : uint8;
enum class ETileOverlapResolveType : uint8;
UCLASS()
class UNeuralPostProcessModelInstance : public UObject
{
GENERATED_UCLASS_BODY()
public:
// Engine thread
// If the model data and the Runtime name is different, recreate
// the model
void Update(UNNEModelData* NNEModelData, FString RuntimeName);
// Render thread functions
void CreateRDGBuffers(class FRDGBuilder& GraphBuilder);
// Create the input and output buffer on demand
// useful to use multiple post processing material to read/write
void CreateRDGBuffersIfNeeded(class FRDGBuilder& GraphBuilder, bool bForceCreate = false);
void Execute(class FRDGBuilder& GraphBuilder);
UE::NNE::FTensorShape GetResolvedInputTensorShape();
UE::NNE::FTensorShape GetResolvedOutputTensorShape();
//TODO: different type support
FRDGBufferRef GetInputBuffer();
FRDGBufferRef GetOutputBuffer();
FRDGBufferRef GetTiledInputBuffer();
FRDGBufferRef GetTiledOutputBuffer();
void UpdateDispatchSize(int InDispatchSize){ DispatchSize = InDispatchSize;}
int GetDispatchSize() const{ return DispatchSize;}
void UpdateTileDimension(FIntPoint InTileDim) { TileDim = InTileDim; };
FIntPoint GetTileDimension() const { return TileDim; };
void UpdateModelTileType(ENeuralModelTileType InTileType) { ModelTileSize = InTileType; }
ENeuralModelTileType GetModelTileType() const { return ModelTileSize; }
void UpdateTileOverlap(FIntPoint InTileOverlap){ TileOverlap = InTileOverlap;}
FIntPoint GetTileOverlap() const { return TileOverlap;}
void UpdateTileOverlapResolveType(ETileOverlapResolveType InTileOverlapResolveType) { TileOverlapResolveType = InTileOverlapResolveType; }
ETileOverlapResolveType GetTileOverlapResolveType() const { return TileOverlapResolveType; }
bool IsValid() { return ModelInstanceRDG.IsValid();}
bool ModifyInputShape(int Dim, int Size);
private:
void CreateDefaultNNEModel(UNNEModelData* NNEModelData, FString RuntimeName);
// Input Buffer for neural post process network pass
FRDGBufferRef RDGInputBuffer;
// Output Buffer for Neural post process network pass
FRDGBufferRef RDGOutputBuffer;
// Temporary buffers to emulate batchsize
FRDGBufferRef RDGTiledInputBuffer;
FRDGBufferRef RDGTiledOutputBuffer;
ENeuralModelTileType ModelTileSize;
int DispatchSize = 1;
FIntPoint TileDim = FIntPoint(1,1);
// The NNE RDG Model
TSharedPtr<UE::NNE::IModelInstanceRDG> ModelInstanceRDG;
// Resolved Input Tensorshape
UE::NNE::FTensorShape ResolvedInputTensorShape;
// Output Tensorshape
UE::NNE::FTensorShape ResolvedOutputTensorShape;
FIntVector4 DimensionOverride;
FIntPoint TileOverlap;
ETileOverlapResolveType TileOverlapResolveType;
};