Files
UnrealEngine/Engine/Plugins/Experimental/NNERuntimeRDG/Source/NNEHlslShaders/Internal/NNEHlslShadersConvTransposeCS.h
2025-05-18 13:04:45 +08:00

109 lines
5.4 KiB
C++

// Copyright Epic Games, Inc. All Rights Reserved.
#pragma once
#include "NNEHlslShadersBase.h"
#include "RenderGraphUtils.h"
#include "ShaderParameterUtils.h"
namespace UE::NNEHlslShaders::Internal
{
/** Algorithm variants of the ConvTranspose compute shader; selects the "ALGORITHM" permutation of FConvTransposeCS. */
enum class EConvTransposeAlgorithm : uint8
{
	/** Currently the only variant. NOTE(review): presumably stages data through group-shared memory - confirm in the shader source. */
	SharedMemory = 0,

	/** Sentinel: number of entries declared above. */
	MAX
};
/**
 * Thread-group size choices of the ConvTranspose compute shader; selects the "GROUP_SIZE" permutation of FConvTransposeCS.
 * NOTE(review): the numeric suffix presumably is the number of threads per group (see FConvTransposeCS::GetNumThreadsPerGroup) - confirm in the implementation.
 */
enum class EConvTransposeGroupSize : uint8
{
	Size128 = 0,
	Size256,
	Size512,

	/** Sentinel: number of entries declared above. */
	MAX
};
/** Padding policy of the ConvTranspose operator. */
enum class EConvTransposeAutoPad : uint8
{
	/** Use the pad values passed in the array. */
	NOTSET = 0,

	/** Auto-pad to match input and output shape, with potential extra padding at the end. */
	SAME_UPPER,

	/** Auto-pad to match input and output shape, with potential extra padding at the beginning. */
	SAME_LOWER,

	/** Set all paddings to zero. */
	VALID,

	/** Sentinel: number of entries declared above. */
	MAX
};
/**
 * Compile-time limits used by FConvTransposeCS for its permutation ranges and per-dimension parameter array sizes.
 *
 * The members are constexpr (and therefore implicitly inline) so that they can be ODR-used,
 * e.g. bound to a const reference, without requiring an out-of-class definition.
 */
class FConvTransposeConstants
{
public:
	/** Upper bound of the "NUM_STACK_DIMENSIONS" permutation range and length of the per-dimension shader parameter arrays. */
	static constexpr int32 MAX_NUM_DIMENSIONS{4};

	/** Lower bound of the "NUM_READS_PER_THREAD_POW2" permutation range. */
	static constexpr int32 MIN_NUM_READS_PER_THREAD_POW2{1};

	/** Upper bound of the "NUM_READS_PER_THREAD_POW2" permutation range. */
	static constexpr int32 MAX_NUM_READS_PER_THREAD_POW2{3};
};
/**
 * Global compute shader declaration for the NNE HLSL ConvTranspose operator.
 *
 * This header only declares the permutation domain, the shader parameter struct and a set of static
 * helpers. The helpers are defined elsewhere, so notes on them below are hedged where the behavior
 * cannot be read from the declarations alone.
 */
class NNEHLSLSHADERS_API FConvTransposeCS : public FHlslShaderBase
{
DECLARE_GLOBAL_SHADER(FConvTransposeCS);
SHADER_USE_PARAMETER_STRUCT(FConvTransposeCS, FHlslShaderBase);
// Permutation dimensions; the quoted string is the define name associated with each dimension.
class FConvTransposeAlgorithm : SHADER_PERMUTATION_ENUM_CLASS("ALGORITHM", EConvTransposeAlgorithm);
class FConvTransposeGroupSize : SHADER_PERMUTATION_ENUM_CLASS("GROUP_SIZE", EConvTransposeGroupSize);
// Integer range 1 .. FConvTransposeConstants::MAX_NUM_DIMENSIONS.
class FConvTransposeNumStackDimensions : SHADER_PERMUTATION_RANGE_INT("NUM_STACK_DIMENSIONS", 1, FConvTransposeConstants::MAX_NUM_DIMENSIONS);
// Integer range MIN_NUM_READS_PER_THREAD_POW2 .. MAX_NUM_READS_PER_THREAD_POW2.
// NOTE(review): the "_POW2" suffix suggests the value is an exponent (reads = 2^value) - confirm in the shader source.
class FConvTransposeNumReadsPerThread : SHADER_PERMUTATION_RANGE_INT("NUM_READS_PER_THREAD_POW2", FConvTransposeConstants::MIN_NUM_READS_PER_THREAD_POW2, FConvTransposeConstants::MAX_NUM_READS_PER_THREAD_POW2);
class FConvTransposeHasB : SHADER_PERMUTATION_BOOL("HAS_B");
using FPermutationDomain = TShaderPermutationDomain<FConvTransposeAlgorithm, FConvTransposeGroupSize, FConvTransposeNumStackDimensions, FConvTransposeNumReadsPerThread, FConvTransposeHasB>;
public:
BEGIN_SHADER_PARAMETER_STRUCT(FParameters, )
// Float buffers: X, W and B are read through SRVs, Y is written through a UAV.
// NOTE(review): presumably X = input, W = weights, B = optional bias, Y = output (ONNX ConvTranspose naming) - confirm.
SHADER_PARAMETER_RDG_BUFFER_SRV(Buffer<float>, X)
SHADER_PARAMETER_RDG_BUFFER_SRV(Buffer<float>, W)
SHADER_PARAMETER_RDG_BUFFER_UAV(RWBuffer<float>, Y)
// NOTE(review): B is declared unconditionally; whether it must be bound presumably depends on the HAS_B permutation - confirm in the dispatch code.
SHADER_PARAMETER_RDG_BUFFER_SRV(Buffer<float>, B)
// Per-dimension arrays with FConvTransposeConstants::MAX_NUM_DIMENSIONS entries each.
// NOTE(review): each name appears to list the four packed vector components separated by underscores - verify the component order against FillInParameters.
SHADER_PARAMETER_ARRAY(FIntVector4, Dilation_Stride_XBlockStartOffset_DilationXBlockStride, [FConvTransposeConstants::MAX_NUM_DIMENSIONS])
SHADER_PARAMETER_ARRAY(FIntVector4, GroupStride_GroupShape_GroupThreadStride_StrideXBlockStride, [FConvTransposeConstants::MAX_NUM_DIMENSIONS])
SHADER_PARAMETER_ARRAY(FIntVector4, YDimension_YMemoryStride_XDimension_XMemoryStride, [FConvTransposeConstants::MAX_NUM_DIMENSIONS])
SHADER_PARAMETER_ARRAY(FIntVector4, XBlockStartStride_XBlockStride_WDimension_WDimensionDilationXBlockStride, [FConvTransposeConstants::MAX_NUM_DIMENSIONS])
// Float counterpart of the arrays above. NOTE(review): "OneDiv" presumably denotes precomputed reciprocals - confirm which components are reciprocals.
SHADER_PARAMETER_ARRAY(FVector4f, OneDiv_GroupStride_GroupThreadStride_OneDivStride, [FConvTransposeConstants::MAX_NUM_DIMENSIONS])
// Scalar integer parameters (channel counts, strides and block sizes, going by their names).
SHADER_PARAMETER(int32, NumWChannels)
SHADER_PARAMETER(int32, NumOutChannelsDivGroup)
SHADER_PARAMETER(int32, YBatchStride)
SHADER_PARAMETER(int32, YOutputKernelStride)
SHADER_PARAMETER(int32, XBatchStride)
SHADER_PARAMETER(int32, XChannelStride)
SHADER_PARAMETER(int32, XBlockSize)
SHADER_PARAMETER(int32, NumChannelBatches)
SHADER_PARAMETER(int32, NumChannelsPerBatch)
SHADER_PARAMETER(int32, WOutputKernelStride)
SHADER_PARAMETER(int32, WChannelBatchSize)
SHADER_PARAMETER(int32, WChannelSize)
// Scalar float parameters. NOTE(review): names suggest precomputed ratios (Groups / M and 1 / Group) - confirm.
SHADER_PARAMETER(float, GroupsDivM)
SHADER_PARAMETER(float, OneDivGroup)
END_SHADER_PARAMETER_STRUCT()
// Receives the permutation parameters and a mutable compiler environment; defined elsewhere.
static void ModifyCompilationEnvironment(const FGlobalShaderPermutationParameters& InParameters, FShaderCompilerEnvironment& OutEnvironment);
// Returns a shape as an int32 array computed from the input/weight shapes and the operator attributes.
// NOTE(review): presumably the shape of Y - confirm in the implementation.
static TArray<int32> GetOutputShape(TArrayView<const uint32> XShape, TArrayView<const uint32> WShape, EConvTransposeAutoPad AutoPad, TArrayView<const int32> Dilations, TArrayView<const int32> Strides, TArrayView<const int32> Pads, TArrayView<const int32> OutputPadding, int32 Group);
// Takes the operator configuration and a mutable FParameters reference (Parameters) as the last argument.
// NOTE(review): presumably fills the non-buffer members of Parameters; the buffers X/W/Y/B are not passed in here.
static void FillInParameters(EConvTransposeGroupSize GroupSize, TArrayView<const uint32> XShape, TArrayView<const uint32> WShape, bool HasB, EConvTransposeAutoPad AutoPad, int32 Group, TArrayView<const int32> Dilations, TArrayView<const int32> Strides, TArrayView<const int32> Pads, TArrayView<const int32> OutputPadding, FConvTransposeCS::FParameters& Parameters);
// NOTE(review): the result presumably feeds the NUM_READS_PER_THREAD_POW2 permutation - confirm whether it is the count or its log2.
static int32 GetNumReadsPerThread(EConvTransposeGroupSize GroupSize, TArrayView<const uint32> WShape, TArrayView<const int32> Dilations, TArrayView<const int32> Strides);
// Returns a per-dimension group shape for the given group size and number of dimensions.
static TArray<int32> GetGroupShape(EConvTransposeGroupSize GroupSize, int32 NumDimensions);
// NOTE(review): presumably the number of thread groups to dispatch to cover YShape with GroupShape - confirm.
static FIntVector GetGroupCount(TArrayView<const int32> YShape, TArrayView<const int32> GroupShape);
// NOTE(review): presumably the smallest group size usable for the given weight shape - confirm.
static EConvTransposeGroupSize GetMinimalGroupSize(TArrayView<const uint32> WShape);
// Writes the result to OutValue. NOTE(review): presumably parses StringVal into an auto-pad mode; handling of unknown strings is not visible here.
static void LexFromString(EConvTransposeAutoPad& OutValue, const TCHAR* StringVal);
private:
// Internal helpers; only declared here.
static TArray<int32> GetXBlockShape(TArrayView<const int32> GroupShape, TArrayView<const uint32> WShape, TArrayView<const int32> Dilations, TArrayView<const int32> Strides);
// NOTE(review): presumably resolves the effective padding for the given AutoPad mode - confirm.
static TArray<int32> GetPadding(TArrayView<const uint32> WShape, EConvTransposeAutoPad AutoPad, TArrayView<const int32> Dilations, TArrayView<const int32> Strides, TArrayView<const int32> Pads, TArrayView<const int32> OutputPadding);
static int32 GetNumThreadsPerGroup(EConvTransposeGroupSize GroupSize);
static TArray<int32> GetGridShape(TArrayView<const int32> YShape, TArrayView<const int32> GroupShape);
};
} // UE::NNEHlslShaders::Internal