// Copyright Epic Games, Inc. All Rights Reserved.

#pragma once

#include "NNEHlslShadersBase.h"
#include "RenderGraphUtils.h"
#include "ShaderParameterUtils.h"

// NOTE(review): in this copy of the header the template arguments of TArray / TArrayView /
// Buffer / RWBuffer / TShaderPermutationDomain appear to be missing (most likely lost when the
// file was flattened). Confirm against the upstream header and the matching .cpp definitions
// before relying on the declarations below.

namespace UE::NNEHlslShaders::Internal
{
	// Algorithm variant of the ConvTranspose compute shader.
	// Exposed to HLSL as the "ALGORITHM" permutation (see FConvTransposeCS::FConvTransposeAlgorithm).
	// Only a shared-memory implementation exists at the moment; MAX is a sentinel, not a valid value.
	enum class EConvTransposeAlgorithm : uint8
	{
		SharedMemory = 0,
		MAX
	};

	// Thread-group size variant, exposed to HLSL as the "GROUP_SIZE" permutation.
	// Presumably the number of threads per group (converted by GetNumThreadsPerGroup) — confirm in the .cpp.
	// MAX is a sentinel, not a valid value.
	enum class EConvTransposeGroupSize : uint8
	{
		Size128 = 0,
		Size256,
		Size512,
		MAX
	};

	// Padding mode of the ConvTranspose operator (parsed from a string by FConvTransposeCS::LexFromString).
	// MAX is a sentinel, not a valid value.
	enum class EConvTransposeAutoPad : uint8
	{
		NOTSET = 0, // Use the explicit pad values passed in the Pads array
		SAME_UPPER, // Auto-pad so input and output shapes match; any extra padding goes at the end
		SAME_LOWER, // Auto-pad so input and output shapes match; any extra padding goes at the beginning
		VALID, // Set all paddings to zero
		MAX
	};

	// Compile-time limits shared by the shader permutations and the parameter struct.
	class FConvTransposeConstants
	{
	public:
		// Size of every per-dimension parameter array in FConvTransposeCS::FParameters and
		// upper bound of the NUM_STACK_DIMENSIONS permutation.
		static const int32 MAX_NUM_DIMENSIONS{4};
		// Bounds of the NUM_READS_PER_THREAD_POW2 permutation (the value is an exponent of two, per its name).
		static const int32 MIN_NUM_READS_PER_THREAD_POW2{1};
		static const int32 MAX_NUM_READS_PER_THREAD_POW2{3};
	};

	// Global compute shader implementing the ConvTranspose (transposed convolution) operator.
	// Reads X (input), W (weights) and optionally B (bias, enabled by the HAS_B permutation) and writes Y.
	class NNEHLSLSHADERS_API FConvTransposeCS : public FHlslShaderBase
	{
		DECLARE_GLOBAL_SHADER(FConvTransposeCS);
		SHADER_USE_PARAMETER_STRUCT(FConvTransposeCS, FHlslShaderBase);

		// Permutation dimensions; each one maps to the HLSL define named in its macro.
		class FConvTransposeAlgorithm : SHADER_PERMUTATION_ENUM_CLASS("ALGORITHM", EConvTransposeAlgorithm);
		class FConvTransposeGroupSize : SHADER_PERMUTATION_ENUM_CLASS("GROUP_SIZE", EConvTransposeGroupSize);
		// SHADER_PERMUTATION_RANGE_INT takes (Name, Start, Count): this covers 1..MAX_NUM_DIMENSIONS.
		class FConvTransposeNumStackDimensions : SHADER_PERMUTATION_RANGE_INT("NUM_STACK_DIMENSIONS", 1, FConvTransposeConstants::MAX_NUM_DIMENSIONS);
		// NOTE(review): the third argument is a *count*, not an upper bound. Passing MAX here yields
		// MIN..MAX only because MIN == 1; if MIN changes, this should become (MAX - MIN + 1).
		class FConvTransposeNumReadsPerThread : SHADER_PERMUTATION_RANGE_INT("NUM_READS_PER_THREAD_POW2", FConvTransposeConstants::MIN_NUM_READS_PER_THREAD_POW2, FConvTransposeConstants::MAX_NUM_READS_PER_THREAD_POW2);
		// Whether a bias buffer B is bound.
		class FConvTransposeHasB : SHADER_PERMUTATION_BOOL("HAS_B");
		using FPermutationDomain = TShaderPermutationDomain;

	public:
		// Shader parameters. Member order defines the layout seen by the shader; do not reorder.
		BEGIN_SHADER_PARAMETER_STRUCT(FParameters, )
			// Input tensor, weights, output tensor and (optional) bias.
			SHADER_PARAMETER_RDG_BUFFER_SRV(Buffer, X)
			SHADER_PARAMETER_RDG_BUFFER_SRV(Buffer, W)
			SHADER_PARAMETER_RDG_BUFFER_UAV(RWBuffer, Y)
			SHADER_PARAMETER_RDG_BUFFER_SRV(Buffer, B)
			// Per-dimension data (one entry per spatial dimension, up to MAX_NUM_DIMENSIONS).
			// Each vector packs four values in the order spelled out by the member name.
			SHADER_PARAMETER_ARRAY(FIntVector4, Dilation_Stride_XBlockStartOffset_DilationXBlockStride, [FConvTransposeConstants::MAX_NUM_DIMENSIONS])
			SHADER_PARAMETER_ARRAY(FIntVector4, GroupStride_GroupShape_GroupThreadStride_StrideXBlockStride, [FConvTransposeConstants::MAX_NUM_DIMENSIONS])
			SHADER_PARAMETER_ARRAY(FIntVector4, YDimension_YMemoryStride_XDimension_XMemoryStride, [FConvTransposeConstants::MAX_NUM_DIMENSIONS])
			SHADER_PARAMETER_ARRAY(FIntVector4, XBlockStartStride_XBlockStride_WDimension_WDimensionDilationXBlockStride, [FConvTransposeConstants::MAX_NUM_DIMENSIONS])
			// Floating-point per-dimension data; presumably reciprocals used to replace integer
			// divisions in the shader (per the OneDiv prefix) — confirm in FillInParameters.
			SHADER_PARAMETER_ARRAY(FVector4f, OneDiv_GroupStride_GroupThreadStride_OneDivStride, [FConvTransposeConstants::MAX_NUM_DIMENSIONS])
			// Scalar sizes and strides; values are computed by FillInParameters.
			SHADER_PARAMETER(int32, NumWChannels)
			SHADER_PARAMETER(int32, NumOutChannelsDivGroup)
			SHADER_PARAMETER(int32, YBatchStride)
			SHADER_PARAMETER(int32, YOutputKernelStride)
			SHADER_PARAMETER(int32, XBatchStride)
			SHADER_PARAMETER(int32, XChannelStride)
			SHADER_PARAMETER(int32, XBlockSize)
			SHADER_PARAMETER(int32, NumChannelBatches)
			SHADER_PARAMETER(int32, NumChannelsPerBatch)
			SHADER_PARAMETER(int32, WOutputKernelStride)
			SHADER_PARAMETER(int32, WChannelBatchSize)
			SHADER_PARAMETER(int32, WChannelSize)
			SHADER_PARAMETER(float, GroupsDivM)
			SHADER_PARAMETER(float, OneDivGroup)
		END_SHADER_PARAMETER_STRUCT()

		// Hook for adjusting defines/flags per permutation before compilation.
		static void ModifyCompilationEnvironment(const FGlobalShaderPermutationParameters& InParameters, FShaderCompilerEnvironment& OutEnvironment);

		// Output (Y) shape for the given input/weight shapes and operator attributes.
		static TArray GetOutputShape(TArrayView XShape, TArrayView WShape, EConvTransposeAutoPad AutoPad, TArrayView Dilations, TArrayView Strides, TArrayView Pads, TArrayView OutputPadding, int32 Group);

		// Populates Parameters (everything except, presumably, the buffer bindings — confirm in the .cpp)
		// from the shapes and operator attributes.
		static void FillInParameters(EConvTransposeGroupSize GroupSize, TArrayView XShape, TArrayView WShape, bool HasB, EConvTransposeAutoPad AutoPad, int32 Group, TArrayView Dilations, TArrayView Strides, TArrayView Pads, TArrayView OutputPadding, FConvTransposeCS::FParameters& Parameters);

		// Value to use for the NUM_READS_PER_THREAD_POW2 permutation for this configuration.
		static int32 GetNumReadsPerThread(EConvTransposeGroupSize GroupSize, TArrayView WShape, TArrayView Dilations, TArrayView Strides);

		// Per-dimension thread-group shape for the given group size and number of spatial dimensions.
		static TArray GetGroupShape(EConvTransposeGroupSize GroupSize, int32 NumDimensions);

		// Number of thread groups to dispatch to cover YShape with groups of GroupShape.
		static FIntVector GetGroupCount(TArrayView YShape, TArrayView GroupShape);

		// Smallest group size suitable for the given weight shape.
		static EConvTransposeGroupSize GetMinimalGroupSize(TArrayView WShape);

		// Parses an auto-pad attribute string into OutValue.
		static void LexFromString(EConvTransposeAutoPad& OutValue, const TCHAR* StringVal);

	private:
		// Helpers used by the public functions above; implementations live in the .cpp.
		static TArray GetXBlockShape(TArrayView GroupShape, TArrayView WShape, TArrayView Dilations, TArrayView Strides);
		static TArray GetPadding(TArrayView WShape, EConvTransposeAutoPad AutoPad, TArrayView Dilations, TArrayView Strides, TArrayView Pads, TArrayView OutputPadding);
		static int32 GetNumThreadsPerGroup(EConvTransposeGroupSize GroupSize);
		static TArray GetGridShape(TArrayView YShape, TArrayView GroupShape);
	};
} // UE::NNEHlslShaders::Internal