// Copyright Epic Games, Inc. All Rights Reserved.

#include "NeuralPostProcessing.h"

#include "Interfaces/IPluginManager.h"
#include "NNE.h"
#include "NNEModelData.h"
#include "NNERuntimeRDG.h"
#include "Engine/NeuralProfile.h"
#include "NeuralPostProcessModelInstance.h"
#include "NeuralPostProcessingCS.h"
#include "PostProcess/NeuralPostProcessInterface.h"
#include "PixelShaderUtils.h"
#include "RenderGraphEvent.h"

#define LOCTEXT_NAMESPACE "FNeuralPostProcessingModule"

#if WITH_EDITOR
DEFINE_LOG_CATEGORY(LogNeuralPostProcessing);
#endif

using namespace NeuralPostProcessing;

namespace
{
    TAutoConsoleVariable<int32> CVarNeuralPostProcessApply(
        TEXT("r.Neuralpostprocess.Apply"),
        1,
        TEXT(" 0: disabled\n")
        TEXT(" 1: enabled (default)"),
        ECVF_RenderThreadSafe | ECVF_Scalability);

    int32 GNeuralPostProcessTileOverlap = -1;
    FAutoConsoleVariableRef CVarNeuralPostProcessTileOverlap(
        TEXT("r.Neuralpostprocess.TileOverlap"),
        GNeuralPostProcessTileOverlap,
        TEXT(" <0: Use the overlap from the profile\n"),
        ECVF_RenderThreadSafe | ECVF_Scalability);

    TAutoConsoleVariable<int32> CVarNeuralPostProcessTileOverlapResolveType(
        TEXT("r.Neuralpostprocess.TileOverlap.ResolveType"),
        -1,
        TEXT(" 0: Ignore the overlapped region when concatenating the resolved subframes.\n")
        TEXT(" 1: Feather overlapping regions based on bilinear filtering. Works best when the input and output dimensions are the same.\n")
        TEXT(" <0: Use the ResolveType from the profile."),
        ECVF_RenderThreadSafe | ECVF_Scalability);

    TAutoConsoleVariable<int32> CVarNeuralPostProcessTileOverlapVisualize(
        TEXT("r.Neuralpostprocess.TileOverlap.Visualize"),
        0,
        TEXT(" Visualize the overlap region if not zero\n"),
        ECVF_RenderThreadSafe | ECVF_Scalability);

    TAutoConsoleVariable<float> CVarNeuralPostProcessTileOverlapVisualizeIntensity(
        TEXT("r.Neuralpostprocess.TileOverlap.Visualize.Intensity"),
        1.0f,
        TEXT(" Adjust the intensity of the overlap for visualization\n"),
        ECVF_RenderThreadSafe | ECVF_Scalability);

    DECLARE_GPU_STAT(NeuralPostProcessing);
}

static FIntPoint ClampOverlap(FIntPoint TileOverlap, FIntPoint BufferSizeXY)
{
    // Use the TileOverlap from the CVar if it is greater than or equal to zero.
    int32 OverlapFromCVar = GNeuralPostProcessTileOverlap;
    bool ShouldOverrideOverlapWithCVar = OverlapFromCVar >= 0;
    if (ShouldOverrideOverlapWithCVar)
    {
        TileOverlap = FIntPoint(OverlapFromCVar);
    }

    // Max tile border is 1/4 of the dimension.
    FIntPoint MaxTileOverlap = BufferSizeXY / 4;
    TileOverlap.X = FMath::Clamp(TileOverlap.X, 0, MaxTileOverlap.X);
    TileOverlap.Y = FMath::Clamp(TileOverlap.Y, 0, MaxTileOverlap.Y);

    return TileOverlap;
}
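// Example: with r.Neuralpostprocess.TileOverlap set to 64 and a 128x256 network input, the overlap is
// clamped to (32, 64), since the tile border may not exceed a quarter of each buffer dimension.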
/**
 * ==============================================================================================================
 * FNeuralPostProcessModelInstanceManager: holds the mapping between neural profiles and neural network models.
 *
 *           /|\                                                      /|\
 *            | USE                                                    | USE
 *            |                                                        |
 *  |---ENGINE_API FNeuralProfileManager---|      |----RENDERING_API FNeuralPostProcess----|
 *  |       (INeuralProfileManager)        |      |      (INeuralPostProcessInterface)     |
 *  |      Manages the neural profile      |      |  Interface accessed by RDGBuilder in   |
 *  |                                      |      |  the post process material:            |
 *  |--------------------------------------|      |    AllocateBuffer()                    |
 *                                                |    Apply(): apply neural networks      |
 *                                                |----------------------------------------|
 * ==============================================================================================================
 */
class FNeuralPostProcessModelInstanceManager
{
public:
    static FNeuralPostProcessModelInstanceManager* Get()
    {
        static FNeuralPostProcessModelInstanceManager Manager;
        return &Manager;
    }

    UNeuralPostProcessModelInstance* GetModelInstance(int32 ProfileId)
    {
        if (NeuralPostProcessModelInstances.Contains(ProfileId))
        {
            return NeuralPostProcessModelInstances[ProfileId];
        }
        else
        {
            return nullptr;
        }
    }

    void SetModelInstance(int32 ProfileId, UNeuralPostProcessModelInstance* Instance)
    {
        NeuralPostProcessModelInstances.Emplace(ProfileId, Instance);
    }

    UNeuralPostProcessModelInstance* Remove(int32 ProfileId)
    {
        return NeuralPostProcessModelInstances.FindAndRemoveChecked(ProfileId);
    }

    ~FNeuralPostProcessModelInstanceManager()
    {
        NeuralPostProcessModelInstances.Reset();
    }

public:
    TMap<int32, UNeuralPostProcessModelInstance*> NeuralPostProcessModelInstances;
};
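// Engine-facing implementation of NeuralProfile::INeuralProfileManager. Each neural profile slot
// (AllocationId) maps to at most one UNeuralPostProcessModelInstance, stored in the manager above.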
class FNeuralProfileManager : public NeuralProfile::INeuralProfileManager
{
public:
    FNeuralProfileManager() {}

    virtual void UpdateModel(int32 AllocationId, UObject* NNEModelData, FString RuntimeName) override
    {
        check(IsInRenderingThread());

        UNNEModelData* RawNNEModelData = Cast<UNNEModelData>(NNEModelData);

        if (!IsValid(RawNNEModelData))
        {
#if WITH_EDITOR
            UE_LOG(LogNeuralPostProcessing, Error, TEXT("NNEModelData is invalid at Slot %d."), AllocationId);
#endif
            FNeuralPostProcessModelInstanceManager::Get()->SetModelInstance(AllocationId, nullptr);
            return;
        }

        UNeuralPostProcessModelInstance* ModelInstance = NewObject<UNeuralPostProcessModelInstance>();
        ModelInstance->Update(RawNNEModelData, RuntimeName);

        FNeuralPostProcessModelInstanceManager::Get()->SetModelInstance(AllocationId, ModelInstance->IsValid() ? ModelInstance : nullptr);
    }

    virtual void UpdateTileType(int32 AllocationId, ENeuralModelTileType ModelTileType) override
    {
        UNeuralPostProcessModelInstance* Instance = FNeuralPostProcessModelInstanceManager::Get()->GetModelInstance(AllocationId);
        if (Instance)
        {
            Instance->UpdateModelTileType(ModelTileType);
        }
    }

    virtual bool UpdateBatchSize(int32 AllocationId, int32 BatchSize) override
    {
        UNeuralPostProcessModelInstance* Instance = FNeuralPostProcessModelInstanceManager::Get()->GetModelInstance(AllocationId);
        if (Instance)
        {
            return Instance->ModifyInputShape(0, BatchSize);
        }
        return false;
    }

    virtual void UpdateTileOverlap(int32 AllocationId, FIntPoint TileOverlap) override
    {
        UNeuralPostProcessModelInstance* Instance = FNeuralPostProcessModelInstanceManager::Get()->GetModelInstance(AllocationId);
        if (Instance)
        {
            UE::NNE::FTensorShape InputShape = Instance->GetResolvedInputTensorShape();
            FIntPoint BufferSizeXY = FIntPoint((int32)InputShape.GetData()[3], (int32)InputShape.GetData()[2]);

            Instance->UpdateTileOverlap(ClampOverlap(TileOverlap, BufferSizeXY));
        }
    }

    virtual void UpdateTileOverlapResolveType(int32 AllocationId, ETileOverlapResolveType TileOverlapResolveType) override
    {
        UNeuralPostProcessModelInstance* Instance = FNeuralPostProcessModelInstanceManager::Get()->GetModelInstance(AllocationId);
        if (Instance)
        {
            Instance->UpdateTileOverlapResolveType(TileOverlapResolveType);
        }
    }

    virtual void RemoveModel(int32 AllocationId) override
    {
        FNeuralPostProcessModelInstanceManager::Get()->Remove(AllocationId);
    }

    virtual FIntVector4 GetInputDimension(UObject* NNEModelData, FString RuntimeName) override
    {
        UE::NNE::FTensorShape TensorShape;
        FIntVector4 InputDimension = FIntVector4(-1, -1, -1, -1);

        if (UNNEModelData* ModelData = Cast<UNNEModelData>(NNEModelData))
        {
            // Need to create the ModelInstance in order to get the dimension in case the Model is not created
            // in the rendering thread.
            TSharedPtr<UE::NNE::IModelInstanceRDG> ModelInstance = CreateNNEModelInstance(ModelData, RuntimeName);

            if (ModelInstance)
            {
                TConstArrayView<UE::NNE::FTensorDesc> InputTensorDescs = ModelInstance->GetInputTensorDescs();
                UE::NNE::FSymbolicTensorShape InputTensorShape = InputTensorDescs[0].GetShape();

                // Only input tensors of rank 4 are supported.
                checkf(InputTensorShape.Rank() == 4, TEXT("Neural Post Processing requires models with input shape N x channel x height x width!"));

                InputDimension = FIntVector4(
                    InputTensorShape.GetData()[0],
                    InputTensorShape.GetData()[1],
                    InputTensorShape.GetData()[2],
                    InputTensorShape.GetData()[3]);
            }

            ModelInstance.Reset();
        }

        return InputDimension;
    }

    virtual FIntVector4 GetOutputDimension(UObject* NNEModelData, FString RuntimeName) override
    {
        UE::NNE::FTensorShape TensorShape;
        FIntVector4 OutputDimension = FIntVector4(-1, -1, -1, -1);

        if (UNNEModelData* ModelData = Cast<UNNEModelData>(NNEModelData))
        {
            TSharedPtr<UE::NNE::IModelInstanceRDG> ModelInstance = CreateNNEModelInstance(ModelData, RuntimeName);

            if (ModelInstance)
            {
                TConstArrayView<UE::NNE::FTensorDesc> OutputTensorDescs = ModelInstance->GetOutputTensorDescs();
                UE::NNE::FSymbolicTensorShape OutputTensorShape = OutputTensorDescs[0].GetShape();

                checkf(OutputTensorShape.Rank() == 4, TEXT("Neural Post Processing requires models with output shape N x channel x height x width!"));

                OutputDimension = FIntVector4(
                    OutputTensorShape.GetData()[0],
                    OutputTensorShape.GetData()[1],
                    OutputTensorShape.GetData()[2],
                    OutputTensorShape.GetData()[3]);
            }

            ModelInstance.Reset();
        }

        return OutputDimension;
    }

    ~FNeuralProfileManager() {}
};
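// Returns the total tile count for a fixed tile layout (1, 4, 16, or 64 tiles). Auto returns -1: the tile
// count is then derived from the viewport size and the network input dimensions when the input buffer is
// allocated (see AllocateInputBuffer_RenderingThread below).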
int32 GetTotalModelTileCount(ENeuralModelTileType ModelTileSize)
{
    int TotalTileCount = 1;

    //@TODO: uncomment and implement the corresponding function.
    switch (ModelTileSize)
    {
    case ENeuralModelTileType::TwoByTwo:
        TotalTileCount = 2 * 2;
        break;
    case ENeuralModelTileType::FourByFour:
        TotalTileCount = 4 * 4;
        break;
    case ENeuralModelTileType::EightByEight:
        TotalTileCount = 8 * 8;
        break;
    case ENeuralModelTileType::Auto:
        TotalTileCount = -1; // determined at runtime by the viewport size and the W,H dimensions of the network
        break;
    case ENeuralModelTileType::OneByOne:
    default:
        break;
    }

    return TotalTileCount;
}

static void AllocateInputBuffer_RenderingThread(
    FRDGBuilder& GraphBuilder,
    const FScreenPassTextureViewport& Viewport,
    int32 ProfileId,
    FRDGBufferRef& InputNeuralBuffer,
    FVector4f& InputBufferDimension)
{
    check(IsInRenderingThread());
    check(ProfileId >= 0);

    UNeuralPostProcessModelInstance* Model = FNeuralPostProcessModelInstanceManager::Get()->GetModelInstance(ProfileId);

    if (!IsValid(Model))
    {
        InputNeuralBuffer = nullptr;
        return;
    }

    UE::NNE::FTensorShape InputShape = Model->GetResolvedInputTensorShape();
    InputBufferDimension = FVector4f(InputShape.GetData()[0], InputShape.GetData()[1], InputShape.GetData()[2], InputShape.GetData()[3]);

    ENeuralModelTileType ModelTileType = Model->GetModelTileType();
    FScreenPassTextureViewportParameters ViewportParameters = GetScreenPassTextureViewportParameters(Viewport);
    int32 BatchDim = InputBufferDimension.X;

    // Calculate the number of tiles
    int NumDispatches = 1;
    FIntPoint TileDimension = FIntPoint(1, 1);
    {
        if (ModelTileType == ENeuralModelTileType::Auto)
        {
            // Allocate the number of tiles based on the size of the viewport, the size of the buffer, and the tile configuration.
            FIntPoint InputBufferSizeXY = FIntPoint(InputBufferDimension.W, InputBufferDimension.Z);
            FIntPoint TileOverlapBorder = ClampOverlap(Model->GetTileOverlap(), InputBufferSizeXY);
            FIntPoint EffectiveBufferWH = InputBufferSizeXY - TileOverlapBorder * 2;

            FVector2f TileSizeWH = FMath::DivideAndRoundUp(ViewportParameters.ViewportSize.IntPoint(), EffectiveBufferWH);
            TileDimension = FIntPoint(TileSizeWH.X, TileSizeWH.Y);
            NumDispatches = FMath::DivideAndRoundUp(TileDimension.X * TileDimension.Y, BatchDim);
        }
        else
        {
            // Use a fixed number of tiles
            int32 TileCount = GetTotalModelTileCount(ModelTileType);
            int32 SideLength = (int32)FMath::Sqrt((float)TileCount);
            TileDimension = FIntPoint(SideLength, SideLength);
            NumDispatches = FMath::DivideAndRoundUp(TileCount, BatchDim);
        }

        Model->UpdateDispatchSize(NumDispatches);
        Model->UpdateTileDimension(TileDimension);
    }

    // Create the buffer
    Model->CreateRDGBuffersIfNeeded(GraphBuilder, true);

    // Output the buffer and dimension for use.
    InputNeuralBuffer = Model->GetTiledInputBuffer();
    InputBufferDimension.X *= NumDispatches;

    // Clear the input buffer at the border if we use ENeuralModelTileType::Auto
    // and the bottom and right tiles partially cover the outside.
    if (ModelTileType == ENeuralModelTileType::Auto)
    {
        //@TODO: only clear part of the regions at the borders.
        AddClearUAVPass(GraphBuilder, GraphBuilder.CreateUAV(InputNeuralBuffer, PF_R32_FLOAT), 0);
    }
}
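// Chooses how overlapped tile regions are resolved when tiles are copied back to the texture: the profile
// setting is applied first, feathering is downgraded to Ignore when there is no overlap, and the console
// variable can force either mode (e.g. "r.Neuralpostprocess.TileOverlap.ResolveType 1" forces feathering,
// 0 forces Ignore, and negative values keep the profile setting).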
FCopyBetweenTextureAndOverlappedTileBufferCS::EOverlapResolveType GetOverlapResolveType(FIntPoint TileOverlap, ETileOverlapResolveType TileOverlapResolveType)
{
    typedef FCopyBetweenTextureAndOverlappedTileBufferCS SHADER;

    SHADER::EOverlapResolveType OverlapResolveType = SHADER::EOverlapResolveType::Ignore;

    // Set the TileOverlap resolve type from the profile config
    switch (TileOverlapResolveType)
    {
    case ETileOverlapResolveType::Feathering:
        OverlapResolveType = SHADER::EOverlapResolveType::Feathering;
        break;
    case ETileOverlapResolveType::Ignore:
    default:
        break;
    }

    // If both overlaps are zero, switch to Ignore for performance
    if (OverlapResolveType == SHADER::EOverlapResolveType::Feathering && (TileOverlap.X == 0 && TileOverlap.Y == 0))
    {
        OverlapResolveType = SHADER::EOverlapResolveType::Ignore;
    }

    // Override by console variable
    int32 OverlapResolveTypeOverride = CVarNeuralPostProcessTileOverlapResolveType.GetValueOnRenderThread();
    if (OverlapResolveTypeOverride == 0)
    {
        OverlapResolveType = SHADER::EOverlapResolveType::Ignore;
    }
    else if (OverlapResolveTypeOverride > 0)
    {
        OverlapResolveType = SHADER::EOverlapResolveType::Feathering;
    }

    return OverlapResolveType;
}

FIntPoint ScaleValueFromSourceToTargetSize(FIntPoint Value, FIntPoint SourceSize, FIntPoint TargetSize)
{
    FVector2f ScaledValue = FVector2f(Value.X * TargetSize.X * 1.0f / SourceSize.X, Value.Y * TargetSize.Y * 1.0f / SourceSize.Y);
    return FIntPoint(FMath::CeilToInt(ScaledValue.X), FMath::CeilToInt(ScaledValue.Y));
}

FIntPoint GetScaledTileOverlap(FIntPoint TileOverlap, FIntPoint SourceSize, FIntPoint TargetSize)
{
    // The scaled overlap is designed to return integers.
    return ScaleValueFromSourceToTargetSize(TileOverlap, SourceSize, TargetSize);
}

static void ApplyNeuralNetworks_RenderingThread(
    FRDGBuilder& GraphBuilder,
    int32 ProfileId,
    FRDGTextureRef NeuralTexture,
    FIntRect ViewRect,
    FRDGBufferRef InputSourceType,
    FRDGBufferRef& OutputNeuralBuffer,
    FVector4f& BufferDimension)
{
    check(IsInRenderingThread());
    check(ProfileId >= 0);

    UNeuralPostProcessModelInstance* Model = FNeuralPostProcessModelInstanceManager::Get()->GetModelInstance(ProfileId);

    if (!IsValid(Model))
    {
        return;
    }

    RDG_EVENT_SCOPE_STAT(GraphBuilder, NeuralPostProcessing, "NeuralPostProcessing");
    RDG_GPU_STAT_SCOPE(GraphBuilder, NeuralPostProcessing);

    FIntPoint TextureSize = NeuralTexture->Desc.Extent;
    FGlobalShaderMap* GlobalShaderMap = GetGlobalShaderMap(GMaxRHIFeatureLevel);
    float Scale = 1.0f;

    UE::NNE::FTensorShape InputShape = Model->GetResolvedInputTensorShape();
    FIntPoint NeuralNetworkInputSize = { (int32)InputShape.GetData()[3], (int32)InputShape.GetData()[2] }; // Width, height
    FIntPoint TileOverlap = ClampOverlap(Model->GetTileOverlap(), NeuralNetworkInputSize);
    FIntPoint TileDimension = Model->GetTileDimension();

    const FScreenPassTextureViewport NeuralPostProcessViewport(NeuralTexture, ViewRect);
    const FScreenPassTextureViewportParameters InputViewportParameters = GetScreenPassTextureViewportParameters(NeuralPostProcessViewport);

    const ENeuralModelTileType ModelTileType = Model->GetModelTileType();
    const bool bRequireInputBufferScale = ModelTileType != ENeuralModelTileType::Auto;
    const ETileOverlapResolveType TileOverlapResolveType = Model->GetTileOverlapResolveType();

    struct FIntermediateTexture
    {
        FRDGTextureRef Texture;
        FIntPoint Extent;
    };
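    // Helper that picks the texture the tiled copy passes read from or write to: when a rescale is required,
    // an intermediate texture of (network size - 2 * overlap) * TileDimension texels is created; otherwise
    // the neural texture itself is used at viewport size.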
    auto GetIntermediateTexture = [&](FIntPoint InNetworkSize, FIntPoint InOverLap, bool bRequireBufferScale, const TCHAR* Name) -> FIntermediateTexture
    {
        FIntermediateTexture InputIntermediateTexture;

        if (bRequireBufferScale)
        {
            FIntPoint IntermediateTextureExtent = (InNetworkSize - InOverLap * 2) * TileDimension;
            FRDGTextureDesc IntermediateTextureDesc = NeuralTexture->Desc;
            IntermediateTextureDesc.Extent = IntermediateTextureExtent;

            // @TODO: conditionally create the intermediate buffer texture
            InputIntermediateTexture.Texture = GraphBuilder.CreateTexture(IntermediateTextureDesc, Name);
            InputIntermediateTexture.Extent = IntermediateTextureExtent;
        }
        else
        {
            InputIntermediateTexture.Texture = NeuralTexture;
            InputIntermediateTexture.Extent = InputViewportParameters.ViewportSize.IntPoint();
        }

        return InputIntermediateTexture;
    };

    // 1. Preprocess the input data by copying from texture to buffer if required
    {
        // Build the indirect dispatch parameters
        FRDGBufferRef IndirectDispatchBuffer = GraphBuilder.CreateBuffer(
            FRDGBufferDesc::CreateIndirectDesc<FRHIDispatchIndirectParameters>(1),
            TEXT("NeuralPostProcessing.IndirectDispatchBuffer"));

        FIntermediateTexture IntermediateTexture = GetIntermediateTexture(NeuralNetworkInputSize, TileOverlap,
            bRequireInputBufferScale, TEXT("NeuralPostProcessing.IntermediateBufferTexture"));

        // Downscale
        if (bRequireInputBufferScale)
        {
            // Build the indirect dispatch parameters with InputSourceType
            {
                typedef FNeuralPostProcessingBuildIndirectDispatchArgsCS ARGSETUPSHADER;

                ARGSETUPSHADER::FParameters* PassParameters = GraphBuilder.AllocParameters<ARGSETUPSHADER::FParameters>();
                PassParameters->TargetDimension = IntermediateTexture.Extent;
                PassParameters->SourceType = GraphBuilder.CreateSRV(InputSourceType, EPixelFormat::PF_R32_UINT);
                PassParameters->RWIndirectDispatchArgsBuffer = GraphBuilder.CreateUAV(IndirectDispatchBuffer, EPixelFormat::PF_R32_UINT);

                TShaderMapRef<ARGSETUPSHADER> ComputeShader(GlobalShaderMap);
                FComputeShaderUtils::AddPass(GraphBuilder, FRDGEventName(TEXT("NeuralPostProcessing::BuildIndirectArgs(Dispatch)")),
                    ComputeShader, PassParameters, FIntVector(1, 1, 1));
            }

            typedef FDownScaleTextureCS SHADER;
            SHADER::FParameters* PassParameters = GraphBuilder.AllocParameters<SHADER::FParameters>();
            {
                PassParameters->Source = GetNeuralPostProcessInput(NeuralTexture, InputViewportParameters);
                PassParameters->SourceTextureSampler = TStaticSamplerState<SF_Bilinear, AM_Clamp, AM_Clamp, AM_Clamp>::GetRHI();
                PassParameters->TargetWidth = IntermediateTexture.Extent.X;
                PassParameters->TargetHeight = IntermediateTexture.Extent.Y;
                PassParameters->TargetTexture = GraphBuilder.CreateUAV(FRDGTextureUAVDesc(IntermediateTexture.Texture));
                PassParameters->IndirectDispatchArgsBuffer = IndirectDispatchBuffer;
            }

            TShaderMapRef<SHADER> ComputeShader(GlobalShaderMap);
            FComputeShaderUtils::AddPass(
                GraphBuilder,
                RDG_EVENT_NAME("NeuralPostProcessing::DownScale (%dx%d -> %dx%d)",
                    (int32)InputViewportParameters.ViewportSize.X, (int32)InputViewportParameters.ViewportSize.Y,
                    IntermediateTexture.Extent.X, IntermediateTexture.Extent.Y),
                ERDGPassFlags::Compute, //| ERDGPassFlags::NeverCull
                ComputeShader,
                PassParameters,
                PassParameters->IndirectDispatchArgsBuffer,
                0);
        }

        // Copy from the texture to the overlapped tile buffer
        {
            int32 NumOfChannel = (int32)InputShape.GetData()[1];

            // Build the indirect dispatch parameters with InputSourceType
            {
                typedef FNeuralPostProcessingBuildIndirectDispatchArgsCS ARGSETUPSHADER;

                ARGSETUPSHADER::FParameters* PassParameters = GraphBuilder.AllocParameters<ARGSETUPSHADER::FParameters>();
                PassParameters->TargetDimension = NeuralNetworkInputSize * TileDimension;
                PassParameters->SourceType = GraphBuilder.CreateSRV(InputSourceType, EPixelFormat::PF_R32_UINT);
                PassParameters->RWIndirectDispatchArgsBuffer = GraphBuilder.CreateUAV(IndirectDispatchBuffer, EPixelFormat::PF_R32_UINT);
                TShaderMapRef<ARGSETUPSHADER> ComputeShader(GlobalShaderMap);
                FComputeShaderUtils::AddPass(GraphBuilder, FRDGEventName(TEXT("NeuralPostProcessing::BuildIndirectArgs(Dispatch)")),
                    ComputeShader, PassParameters, FIntVector(1, 1, 1));
            }

            typedef FCopyBetweenTextureAndOverlappedTileBufferCS SHADER;
            SHADER::FParameters* PassParameters = GraphBuilder.AllocParameters<SHADER::FParameters>();
            {
                PassParameters->SourceWidth = IntermediateTexture.Extent.X;
                PassParameters->SourceHeight = IntermediateTexture.Extent.Y;
                PassParameters->RWSourceTexture = GraphBuilder.CreateUAV(FRDGTextureUAVDesc(IntermediateTexture.Texture));
                PassParameters->SourceTextureSampler = TStaticSamplerState<SF_Bilinear, AM_Clamp, AM_Clamp, AM_Clamp>::GetRHI();
                PassParameters->TargetOverlappedTileWidth = NeuralNetworkInputSize.X;
                PassParameters->TargetOverlappedTileHeight = NeuralNetworkInputSize.Y;
                PassParameters->ViewTileDimension = TileDimension;
                PassParameters->TileOverlap = FVector2f(TileOverlap.X, TileOverlap.Y);
                PassParameters->NumOfChannel = NumOfChannel;
                PassParameters->TargetBuffer = GraphBuilder.CreateUAV(FRDGBufferUAVDesc(Model->GetTiledInputBuffer(), PF_R32_FLOAT));
                PassParameters->IndirectDispatchArgsBuffer = IndirectDispatchBuffer;
            }

            SHADER::FPermutationDomain ComputeShaderPermutationVector;
            ComputeShaderPermutationVector.Set(SHADER::EDirection::ToOverlappedTiles);
            ComputeShaderPermutationVector.Set(SHADER::EOverlapResolveType::Ignore);
            TShaderMapRef<SHADER> ComputeShader(GlobalShaderMap, ComputeShaderPermutationVector);

            FComputeShaderUtils::AddPass(
                GraphBuilder,
                RDG_EVENT_NAME("NeuralPostProcessing::CopyToTiledBuffer (TileOverlap=%d,%d)", TileOverlap.X, TileOverlap.Y),
                ERDGPassFlags::Compute,
                ComputeShader,
                PassParameters,
                PassParameters->IndirectDispatchArgsBuffer,
                0);
        }
    }
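    // Note: setting r.Neuralpostprocess.Apply to 0 skips the model inference in step 2; the tiled input
    // buffer is then passed through in place of the network output, so the copy-back and readback steps
    // below still have data to consume.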
    // 2. Run the network
    const bool bNeuralPostProcessApply = CVarNeuralPostProcessApply.GetValueOnRenderThread() > 0;
    if (bNeuralPostProcessApply)
    {
        Model->Execute(GraphBuilder);
    }

    // 3. Pass the output from the neural network and fill the scene color texture
    // when the batch and channel dimensions match between the input and output buffer.
    auto ShouldUpdateNeuralTexture = [&]() -> bool
    {
        UE::NNE::FTensorShape OutputShape = Model->GetResolvedOutputTensorShape();
        UE::NNE::FTensorShape InputShape = Model->GetResolvedInputTensorShape();

        return OutputShape.GetData()[0] == InputShape.GetData()[0] &&
            OutputShape.GetData()[1] == InputShape.GetData()[1] &&
            Model->GetDispatchSize() >= 1;
    };

    if (ShouldUpdateNeuralTexture())
    {
        FRDGBufferRef OutputBuffer = bNeuralPostProcessApply ? Model->GetTiledOutputBuffer() : Model->GetTiledInputBuffer();
        FRDGBufferUAVRef OutputBufferUAV = GraphBuilder.CreateUAV(FRDGBufferUAVDesc(OutputBuffer, PF_R32_FLOAT));

        UE::NNE::FTensorShape OutputShape = bNeuralPostProcessApply ? Model->GetResolvedOutputTensorShape() : Model->GetResolvedInputTensorShape();
        FIntPoint NeuralNetworkOutputSize = { (int32)OutputShape.GetData()[3], (int32)OutputShape.GetData()[2] };
        FIntPoint ScaledTileOverlap = GetScaledTileOverlap(TileOverlap, NeuralNetworkInputSize, NeuralNetworkOutputSize);

        // The output buffer needs rescaling when a fixed tile layout is used or when the network's output size
        // differs from its input size.
        const bool bRequireOutputBufferScale = ModelTileType != ENeuralModelTileType::Auto ||
            (NeuralNetworkOutputSize.X != NeuralNetworkInputSize.X || NeuralNetworkOutputSize.Y != NeuralNetworkInputSize.Y);

        FIntermediateTexture IntermediateTexture = GetIntermediateTexture(NeuralNetworkOutputSize, ScaledTileOverlap,
            bRequireOutputBufferScale, TEXT("NeuralPostProcessing.IntermediateOutputBufferTexture"));

        // Copy from the overlapped tile buffer to the texture
        {
            int32 NumOfChannel = (int32)InputShape.GetData()[1];

            typedef FCopyBetweenTextureAndOverlappedTileBufferCS SHADER;
            SHADER::FParameters* PassParameters = GraphBuilder.AllocParameters<SHADER::FParameters>();
            {
                PassParameters->SourceWidth = IntermediateTexture.Extent.X;
                PassParameters->SourceHeight = IntermediateTexture.Extent.Y;
                PassParameters->RWSourceTexture = GraphBuilder.CreateUAV(FRDGTextureUAVDesc(IntermediateTexture.Texture));
                PassParameters->SourceTextureSampler = TStaticSamplerState<SF_Bilinear, AM_Clamp, AM_Clamp, AM_Clamp>::GetRHI();
                PassParameters->TargetOverlappedTileWidth = NeuralNetworkOutputSize.X;
                PassParameters->TargetOverlappedTileHeight = NeuralNetworkOutputSize.Y;
                PassParameters->ViewTileDimension = TileDimension;
                PassParameters->TileOverlap = FVector2f(ScaledTileOverlap.X, ScaledTileOverlap.Y);
                PassParameters->NumOfChannel = NumOfChannel;
                PassParameters->bVisualizeOverlap = CVarNeuralPostProcessTileOverlapVisualize.GetValueOnRenderThread() != 0;
                PassParameters->OverlapVisualizeIntensity = FMath::Max(0.0001f, CVarNeuralPostProcessTileOverlapVisualizeIntensity.GetValueOnRenderThread());
                PassParameters->TargetBuffer = OutputBufferUAV;
            }

            SHADER::FPermutationDomain ComputeShaderPermutationVector;
            ComputeShaderPermutationVector.Set(SHADER::EDirection::FromOverlappedTiles);
            ComputeShaderPermutationVector.Set(GetOverlapResolveType(TileOverlap, TileOverlapResolveType));
            TShaderMapRef<SHADER> ComputeShader(GlobalShaderMap, ComputeShaderPermutationVector);

            FComputeShaderUtils::AddPass(
                GraphBuilder,
                RDG_EVENT_NAME("NeuralPostProcessing::CopyFromTiledBuffer (TileOverlap=%d,%d)", TileOverlap.X, TileOverlap.Y),
                ERDGPassFlags::Compute,
                ComputeShader,
                PassParameters,
                FComputeShaderUtils::GetGroupCount(NeuralNetworkOutputSize * TileDimension, NEURAL_POST_PROCESSING_THREAD_GROUP_SIZE));
        }
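        // When the tiles were recomposed into an intermediate texture (bRequireOutputBufferScale), the result
        // is rescaled back to the viewport below, upscaling or downscaling depending on which side is larger.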
        // Scale and write to the target texture
        if (bRequireOutputBufferScale)
        {
            FIntPoint SourceExtent = IntermediateTexture.Extent;

            // Update the effective source width and height if we use Auto.
            if (ModelTileType == ENeuralModelTileType::Auto)
            {
                int ModX = InputViewportParameters.ViewportSize.IntPoint().X % (NeuralNetworkInputSize.X - 2 * TileOverlap.X);
                int ModY = InputViewportParameters.ViewportSize.IntPoint().Y % (NeuralNetworkInputSize.Y - 2 * TileOverlap.Y);
                FIntPoint ScaledSize = ScaleValueFromSourceToTargetSize(FIntPoint(ModX, ModY), NeuralNetworkInputSize, NeuralNetworkOutputSize);

                SourceExtent = (NeuralNetworkOutputSize - ScaledTileOverlap * 2) * (TileDimension - 1) + ScaledSize;
            }

            const bool bUpScale = SourceExtent[0] < ViewRect.Size()[0] && SourceExtent[1] < ViewRect.Size()[1];

            if (bUpScale)
            {
                typedef FUpscaleTexture SHADER;
                SHADER::FParameters* ProcessOutputParameters = GraphBuilder.AllocParameters<SHADER::FParameters>();
                ProcessOutputParameters->Source_Texture = IntermediateTexture.Texture;
                ProcessOutputParameters->ViewportSize = InputViewportParameters.ViewportSize.IntPoint();
                ProcessOutputParameters->SourceWidth = SourceExtent.X;
                ProcessOutputParameters->SourceHeight = SourceExtent.Y;
                ProcessOutputParameters->RenderTargets[0] = FRenderTargetBinding(NeuralTexture, ERenderTargetLoadAction::ELoad);

                TShaderMapRef<SHADER> WriteOutputShader(GlobalShaderMap);

                FPixelShaderUtils::AddFullscreenPass(
                    GraphBuilder,
                    GlobalShaderMap,
                    RDG_EVENT_NAME("NeuralPostProcessing::Upscale (%dx%d -> %dx%d)",
                        SourceExtent.X, SourceExtent.Y,
                        (int32)InputViewportParameters.ViewportSize.X, (int32)InputViewportParameters.ViewportSize.Y),
                    WriteOutputShader,
                    ProcessOutputParameters,
                    ViewRect);
            }
            else // Downscale
            {
                FScreenPassTextureViewport IntermediateTextureViewport(IntermediateTexture.Texture, FIntRect(0, 0, SourceExtent.X, SourceExtent.Y));
                FScreenPassTextureViewportParameters IntermediateTextureViewportParameters = GetScreenPassTextureViewportParameters(IntermediateTextureViewport);

                typedef FDownScaleTexture SHADER;
                SHADER::FParameters* PassParameters = GraphBuilder.AllocParameters<SHADER::FParameters>();
                {
                    PassParameters->Source = GetNeuralPostProcessInput(IntermediateTexture.Texture, IntermediateTextureViewportParameters);
                    PassParameters->SourceTextureSampler = TStaticSamplerState<SF_Bilinear, AM_Clamp, AM_Clamp, AM_Clamp>::GetRHI();
                    PassParameters->TargetWidth = InputViewportParameters.ViewportSize.X;
                    PassParameters->TargetHeight = InputViewportParameters.ViewportSize.Y;
                    PassParameters->RenderTargets[0] = FRenderTargetBinding(NeuralTexture, ERenderTargetLoadAction::ELoad);
                }

                TShaderMapRef<SHADER> WriteOutputShader(GlobalShaderMap);

                FPixelShaderUtils::AddFullscreenPass(
                    GraphBuilder,
                    GlobalShaderMap,
                    RDG_EVENT_NAME("NeuralPostProcessing::Downscale (%dx%d -> %dx%d)",
                        SourceExtent.X, SourceExtent.Y,
                        (int32)InputViewportParameters.ViewportSize.X, (int32)InputViewportParameters.ViewportSize.Y),
                    WriteOutputShader,
                    PassParameters,
                    ViewRect);
            }
        }
    }

    // 4. Read back the buffer, so the user can directly decode it in the post process material.
    OutputNeuralBuffer = bNeuralPostProcessApply ? Model->GetTiledOutputBuffer() : Model->GetTiledInputBuffer();

    UE::NNE::FTensorShape OutputShape = bNeuralPostProcessApply ?
        Model->GetResolvedOutputTensorShape() : Model->GetResolvedInputTensorShape();

    BufferDimension = FVector4f(OutputShape.GetData()[0], OutputShape.GetData()[1], OutputShape.GetData()[2], OutputShape.GetData()[3]);
    BufferDimension.X *= Model->GetDispatchSize();
}

class FNeuralPostProcess : public INeuralPostProcessInterface
{
public:
    virtual void Apply(FRDGBuilder& GraphBuilder, int32 NeuralProfileId, FRDGTexture* NeuralTexture, FIntRect ViewRect,
        FRDGBufferRef InputSourceType, FRDGBufferRef& OutputNeuralBuffer, FVector4f& BufferDimension) override
    {
        ApplyNeuralNetworks_RenderingThread(
            GraphBuilder,
            NeuralProfileId,
            NeuralTexture,
            ViewRect,
            InputSourceType,
            OutputNeuralBuffer,
            BufferDimension);
    }

    virtual void AllocateBuffer(FRDGBuilder& GraphBuilder, const FScreenPassTextureViewport& Viewport, int32 NeuralProfileId,
        FRDGBufferRef& InputNeuralBuffer, FVector4f& InputBufferDimension) override
    {
        AllocateInputBuffer_RenderingThread(
            GraphBuilder,
            Viewport,
            NeuralProfileId,
            InputNeuralBuffer,
            InputBufferDimension);
    }
};

void FNeuralPostProcessingModule::StartupModule()
{
#if WITH_EDITOR
    UE_LOG(LogNeuralPostProcessing, Log, TEXT("NeuralPostProcessing starting up"));
#endif

    const TSharedPtr<IPlugin> Plugin = IPluginManager::Get().FindPlugin(TEXT("NeuralRendering"));
    if (Plugin.IsValid())
    {
        FString ModuleDir = Plugin->GetBaseDir() + TEXT("/Source/NeuralPostProcessing");
        AddShaderSourceDirectoryMapping(TEXT("/NeuralRendering"), FPaths::Combine(ModuleDir, TEXT("Shaders")));
    }
    else
    {
#if WITH_EDITOR
        UE_LOG(LogNeuralPostProcessing, Error, TEXT("Shaders directory not added. Failed to find NeuralPostProcessing plugin"));
#endif
    }

    GNeuralProfileManager.Reset(new FNeuralProfileManager());
    GNeuralPostProcess.Reset(new FNeuralPostProcess());
}

void FNeuralPostProcessingModule::ShutdownModule()
{
#if WITH_EDITOR
    UE_LOG(LogNeuralPostProcessing, Log, TEXT("NeuralPostProcessing shutting down"));
#endif

    GNeuralProfileManager.Reset(nullptr);
    GNeuralPostProcess.Reset(nullptr);
}

#undef LOCTEXT_NAMESPACE

IMPLEMENT_MODULE(FNeuralPostProcessingModule, NeuralPostProcessing)