UnrealEngine/Engine/Plugins/Experimental/NNERuntimeRDG/Shaders/Private/NNEHlslShaders/NNEHlslShadersConv.usf
// Copyright Epic Games, Inc. All Rights Reserved.
#include "/Engine/Public/Platform.ush"
#define WORK_TYPE float
#define BUFFER_TYPE float
#define BUFFER_TO_WORK_TYPE(x) x
#define WORK_TO_BUFFER_TYPE(x) x
// Must correspond to EConvGroupSize defined in NNEHlslShadersConvCS.h
#if GROUP_SIZE == 0
#define NUM_GROUP_THREADS 128
#elif GROUP_SIZE == 1
#define NUM_GROUP_THREADS 256
#elif GROUP_SIZE == 2
#define NUM_GROUP_THREADS 512
#endif
Buffer<BUFFER_TYPE> X; // N x C x XD0 x ... x XDi
Buffer<BUFFER_TYPE> W; // M x C/GROUPS x WD0 x ... x WDi OR WD0 x WD1 x C x M if WEIGHTS_TRANSPOSED (only supported for 4D)
RWBuffer<BUFFER_TYPE> Y; // N x M x YD0 x ... x YDi
Buffer<BUFFER_TYPE> B; // B0, ..., BM
// x: Dilation0, ..., Dilationi
// y: Strides0, ..., Stridesi
// z: The X block start offset of each dimension (derived from Pad0Begin, ..., PadiBegin)
// w: Element j contains Dilationj * XBlockStridej
int4 Dilation_Stride_XBlockStartOffset_DilationXBlockStride[MAX_NUM_DIMENSIONS];
// x: Element j contains PROD(GDj+1,..,GDi) with the last element containing 1, with GDi being the number of thread groups contained by a dimension i
// y: Number of threads in each dimension of a group
// z: Element j contains PROD(GTDj+1,..,GTDi) with the last element containing 1, with GTDi being the number of threads contained by a dimension i inside a group
// w: Element j contains Stridesj * XBlockStridej
int4 GroupStride_GroupShape_GroupThreadStride_StrideXBlockStride[MAX_NUM_DIMENSIONS];
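// Illustrative example for the variable above (hypothetical 2D case): thread group counts {GD0, GD1} = {4, 8} give x components {8, 1};
// a group shape (y) of {16, 16} threads gives z components {16, 1}, matching the 256 threads of GROUP_SIZE == 1.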
// x: YD0, ..., YDi
// y: Element j contains PROD(YDj+1,..,YDi) with the last element containing 1
// z: XD0, ..., XDi
// w: Element j contains PROD(XDj+1,..,XDi) with the last element containing 1
int4 YDimension_YMemoryStride_XDimension_XMemoryStride[MAX_NUM_DIMENSIONS];
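// Illustrative example for the variable above (hypothetical values): Y spatial dimensions {8, 8} and X spatial dimensions {10, 10}
// give memory strides y = {8, 1} and w = {10, 1}, so an output position {r, c} contributes r * 8 + c to YIndex.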
// x: The start stride of the X block in each dimension, i.e. the number of X elements the block start advances per thread group
// y: The memory strides of the X block to be loaded
// z: WD0, ..., WDi
// w: Element j contains WDj * Dilationj * XBlockStridej
int4 XBlockStartStride_XBlockStride_WDimension_WDimensionDilationXBlockStride[MAX_NUM_DIMENSIONS];
// x: 1/GroupStride
// y: 1/GroupThreadStride
// z: 1/XBlockStride
float4 OneDiv_GroupStride_GroupThreadStride_XBlockStride[MAX_NUM_DIMENSIONS];
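// The OneDiv_* reciprocals replace the integer divisions in the index math below by float multiplications.
// Illustrative example (hypothetical values): with a group stride of 8, the reciprocal is 0.125 and a flat
// GroupIndex of 19 decomposes as (int)(0.125 * 19) = 2 with remainder 19 - 2 * 8 = 3.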
int NumWFeatures; // Number of features (aka output channels)
int NumWChannels; // Number of input channels of W (C/GROUPS)
int YBatchStride; // The number of elements in each batch of Y
int YOutputKernelStride; // The number of Y elements per output kernel
int XBatchStride; // The number of elements in each batch of X
int XChannelStride; // The number of elements in each channel of X
int XBlockSize; // The total number of elements in each block to be loaded
int NumChannelBatches; // The number of iterations needed to cover all channels that need to be processed. Equal to ceil(NumWChannels/NumChannelsPerBatch)
int NumChannelsPerBatch; // The number of complete kernel channels that can be loaded in one batch
int WOutputKernelStride; // The total number of W elements per output kernel and thus PROD(C/GROUPS, WD0, .., WDi)
int WChannelBatchSize; // The total number of W elements inside a channel batch and thus the number of channels per batch times PROD(WD0, .., WDi)
int WChannelSize; // Number of elements in a single channel of W and thus PROD(WD0, .., WDi)
float GroupsDivM; // The number of groups divided by the number of output kernels; used to compute the start channel offset
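// Note: for a standard (non-grouped) convolution GROUPS == 1, so GroupsDivM == 1/M and the start channel
// offset computed from it below is always 0.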
groupshared WORK_TYPE SharedMemoryX[NUM_GROUP_THREADS<<NUM_READS_PER_THREAD_POW2];
groupshared WORK_TYPE SharedMemoryW[NUM_GROUP_THREADS];
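// The shared X buffer holds one X block per thread group, with each of the NUM_GROUP_THREADS threads loading up
// to 2^NUM_READS_PER_THREAD_POW2 elements. Illustrative example (hypothetical values): GROUP_SIZE == 1 and
// NUM_READS_PER_THREAD_POW2 == 2 give 256 << 2 = 1024 WORK_TYPE elements of shared X.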
// Must correspond to EConvAlgorithm defined in NNEHlslShadersConvCS.h
#if ALGORITHM < 1
[numthreads(NUM_GROUP_THREADS, 1, 1)]
void Conv(in const uint3 DispatchThreadID : SV_DispatchThreadID, in const uint3 GroupID : SV_GroupID, in const uint3 GroupThreadID : SV_GroupThreadID)
{
// The group Z dimension and Y dimension are used for the batch and the output kernel respectively
// This works as each group only works on the inner dimensions (the number of threads in y and z is set to one, the group counts to N and M respectively)
int BatchIndex = GroupID.z;
int OutputKernelIndex = GroupID.y;
// Input channels and output kernels are divided into groups
// Thus the channel offset for this thread's output kernel can be computed as follows:
// The output kernel index is divided by the number of output kernels per group giving the group index
// This is multiplied by the number of channels in W to get the start index
int InputChannelStartIndex = NumWChannels * (int)(GroupsDivM * (float)OutputKernelIndex);
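// Illustrative example (hypothetical values): GROUPS = 2 and M = 8 give GroupsDivM = 0.25; output kernel 5 then
// belongs to group (int)(0.25 * 5) = 1 and, with NumWChannels = 16, starts reading X at channel 16.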
// Compute the output memory index and dimension indices of this thread inside its group as well as the start dimension indices of the X block to load
int YIndex = BatchIndex * YBatchStride + OutputKernelIndex * YOutputKernelStride;
int GroupThreadDimensionIndex[NUM_DIMENSIONS];// This will be used later to lookup the shared W and X to compute the final result
int XBlockStartDimensionIndex[NUM_DIMENSIONS];
int GroupIndex = GroupID.x;
int GroupThreadIndex = GroupThreadID.x;
UNROLL
for (int i = 0; i < NUM_DIMENSIONS; i++)
{
// Divide by the number of groups contained in one index of this dimension to get this group's index along the dimension
int GroupDimensionIndex = (int)(OneDiv_GroupStride_GroupThreadStride_XBlockStride[i].x * (float)GroupIndex);
GroupIndex -= GroupDimensionIndex * GroupStride_GroupShape_GroupThreadStride_StrideXBlockStride[i].x;
// Same here with the number of threads contained inside the group, giving the nth thread inside the group in this dimension
GroupThreadDimensionIndex[i] = (int)(OneDiv_GroupStride_GroupThreadStride_XBlockStride[i].y * (float)GroupThreadIndex);
GroupThreadIndex -= GroupThreadDimensionIndex[i] * GroupStride_GroupShape_GroupThreadStride_StrideXBlockStride[i].z;
// Compose the global index by the group index times group shape plus thread index inside the group
int ThreadDimensionIndex = GroupDimensionIndex * GroupStride_GroupShape_GroupThreadStride_StrideXBlockStride[i].y + GroupThreadDimensionIndex[i];
// Subtract the padding from the index of the first element of the group to get the start index of the block needed by this group
XBlockStartDimensionIndex[i] = GroupDimensionIndex * XBlockStartStride_XBlockStride_WDimension_WDimensionDilationXBlockStride[i].x + Dilation_Stride_XBlockStartOffset_DilationXBlockStride[i].z;
// ThreadDimensionIndex is always non-negative but needs to be checked so that it does not overflow the output dimension
if (ThreadDimensionIndex >= YDimension_YMemoryStride_XDimension_XMemoryStride[i].x || YIndex < 0)
{
YIndex = -1;
}
else
{
YIndex += ThreadDimensionIndex * YDimension_YMemoryStride_XDimension_XMemoryStride[i].y;
}
}
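// Illustrative example for the decomposition above (hypothetical 2D case): with group strides {8, 1}, GroupIndex 19
// decomposes into group dimension indices {2, 3}; with a group shape of {16, 16} and GroupThreadIndex 37 the
// thread dimension indices inside the group are {2, 5}.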
// Compute the memory indices in global X which the current thread has to load as well as the indices in shared X where to write
int GlobalXIndex[0x1<<NUM_READS_PER_THREAD_POW2];
int SharedXIndex[0x1<<NUM_READS_PER_THREAD_POW2];
int ReadIdx = GroupThreadID.x;
UNROLL
for (int j = 0; j < (int)(0x1<<NUM_READS_PER_THREAD_POW2); j++)
{
// If the read index exceeds the number of elements in the X block, further computation can be skipped
if (ReadIdx >= XBlockSize)
{
// Assign -1 for both to indicate that the element is not used and thus neither has to be read from global nor written to shared memory
GlobalXIndex[j] = -1;
SharedXIndex[j] = -1;
continue;
}
// Turn the flat read index into dimension indices inside the volume to be loaded by this group
int TmpGlobalXIndex = BatchIndex * XBatchStride + InputChannelStartIndex * XChannelStride;
int TmpSharedXIndex = 0;
int TempReadIdx = ReadIdx;
UNROLL
for (int k = 0; k < NUM_DIMENSIONS; k++)
{
int DimensionIndex = (int)(OneDiv_GroupStride_GroupThreadStride_XBlockStride[k].z * (float)TempReadIdx);
int DimensionIndexOffset = DimensionIndex * XBlockStartStride_XBlockStride_WDimension_WDimensionDilationXBlockStride[k].y;
TempReadIdx -= DimensionIndexOffset;
// Update TmpGlobalXIndex only if it is non-negative (it is already outside the volume if negative)
if (TmpGlobalXIndex >= 0)
{
// Check if the global X index is in range
int XIndex = XBlockStartDimensionIndex[k] + DimensionIndex;
if (XIndex < 0 || XIndex >= YDimension_YMemoryStride_XDimension_XMemoryStride[k].z)
{
TmpGlobalXIndex = -1;
}
else
{
TmpGlobalXIndex += XIndex * YDimension_YMemoryStride_XDimension_XMemoryStride[k].w;
}
}
// Update TmpSharedXIndex in any case as a zero has to be written
TmpSharedXIndex += DimensionIndexOffset;
}
// Assign the result, which will be -1 for the global index and non-negative for the shared index if the element lies in the padding area
// This indicates that the global element should not be loaded but a zero has to be written to shared memory
GlobalXIndex[j] = TmpGlobalXIndex;
SharedXIndex[j] = TmpSharedXIndex;
// Increment the read idx by the number of threads inside the group to cover all elements
ReadIdx += NUM_GROUP_THREADS;
}
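// Illustrative example (hypothetical values): with NUM_GROUP_THREADS = 256, NUM_READS_PER_THREAD_POW2 = 2 and
// XBlockSize = 1000, thread 250 computes indices for block elements 250, 506 and 762 and skips its fourth
// read (1018 >= 1000).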
// Write the zeros once to avoid rewriting the same value later
UNROLL
for (int l = 0; l < (int)(0x1<<NUM_READS_PER_THREAD_POW2); l++)
{
// If the global index is out of bounds (padding area) but the shared index is valid (inside the X block), the shared element needs to be initialized to zero
if (GlobalXIndex[l] < 0 && SharedXIndex[l] >= 0)
{
SharedMemoryX[SharedXIndex[l]] = 0;
}
}
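// Note: these zeros only need to be written once, since padded positions never receive a global load in the channel
// loop below (their GlobalXIndex stays negative), so their shared memory value remains 0 across all channels.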
// Iterate over all channel batches
int WChannelOutputKernelOffset = OutputKernelIndex * WOutputKernelStride;
int WChannelBatchOffset = 0;
int WChannelIndex = 0;
#if HAS_B == 0
WORK_TYPE Result = 0;
#else
WORK_TYPE Result = BUFFER_TO_WORK_TYPE(B[OutputKernelIndex]);
#endif
for (int m = 0; m < NumChannelBatches; m++)
{
// Load all channels of W belonging to this batch
int WIndex = WChannelBatchOffset + GroupThreadID.x;
if (GroupThreadID.x < WChannelBatchSize && WIndex < WOutputKernelStride)
{
#if WEIGHTS_TRANSPOSED
int MCHWIndex = WChannelOutputKernelOffset + WIndex;
int Whw = MCHWIndex % WChannelSize;
int Wc = (MCHWIndex / WChannelSize) % NumWChannels;
int Wm = (MCHWIndex / (NumWChannels * WChannelSize));
int HWCMIndex = Whw * NumWChannels * NumWFeatures + Wc * NumWFeatures + Wm;
SharedMemoryW[GroupThreadID.x] = BUFFER_TO_WORK_TYPE(W[HWCMIndex]);
#else
SharedMemoryW[GroupThreadID.x] = BUFFER_TO_WORK_TYPE(W[WChannelOutputKernelOffset + WIndex]);
#endif
WChannelBatchOffset += WChannelBatchSize;
}
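// Illustrative example for the transposed weight lookup above (hypothetical values): with NumWFeatures = 8,
// NumWChannels = 4 and WChannelSize = 9, the MCHW element {m, c, hw} = {1, 2, 5} has flat index 59 and is
// read from HWCM index 5 * 4 * 8 + 2 * 8 + 1 = 177.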
// Each channel batch iterates over all consecutive elements of the loaded kernel channels
// This results in bank-conflict-free lookups, as all threads of the same warp read exactly the same element which is then broadcast
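// Illustrative note: in the accumulation loop below, every thread of a warp reads SharedMemoryW[SharedWIndex]
// with the same SharedWIndex value, so the access is served as a single broadcast instead of serialized bank accesses.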
int SharedWIndex = 0;
// Iterate over all channels in this batch
for (int n = 0; n < NumChannelsPerBatch; n++)
{
// Check if the channel index is valid as channel batches may not be aligned with the effective number of channels
if (WChannelIndex >= NumWChannels)
{
break;
}
WChannelIndex++;
// Load the shared X
UNROLL
for (int o = 0; o < (int)(0x1<<NUM_READS_PER_THREAD_POW2); o++)
{
// If the global index is negative, a zero has already been written; if the shared index is negative, there is no need to load anything
if (GlobalXIndex[o] >= 0 && SharedXIndex[o] >= 0)
{
// Load the global data into shared memory
SharedMemoryX[SharedXIndex[o]] = BUFFER_TO_WORK_TYPE(X[GlobalXIndex[o]]);
// Update the indices for the next iteration by adding the channel stride
GlobalXIndex[o] += XChannelStride;
}
}
// Sync the group before doing any computation
GroupMemoryBarrierWithGroupSync();
// Initialize a dimension index array used to iterate through the kernel dimensions
// Also initialize the start index in shared memory based on the result position
int WDimensionIndex[NUM_DIMENSIONS];
int SharedXIndex = 0;
UNROLL
for (int p = 0; p < NUM_DIMENSIONS; p++)
{
WDimensionIndex[p] = 0;
// The index of the first element (e.g. upper left) required by each output is offset by the convolution stride and then multiplied by the stride in memory
SharedXIndex += GroupThreadDimensionIndex[p] * GroupStride_GroupShape_GroupThreadStride_StrideXBlockStride[p].w;
}
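// Illustrative example (hypothetical 2D case): a group-local thread index of {1, 2}, convolution strides {2, 2}
// and X block memory strides {20, 1} give a start index of 1 * (2 * 20) + 2 * (2 * 1) = 44 into shared X.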
// Compute and accumulate the result
for (int q = 0; q < WChannelSize; q++)
{
Result += SharedMemoryX[SharedXIndex] * SharedMemoryW[SharedWIndex];
// Advance the inner-most dimension
SharedWIndex++;
SharedXIndex += Dilation_Stride_XBlockStartOffset_DilationXBlockStride[NUM_DIMENSIONS-1].x;
WDimensionIndex[NUM_DIMENSIONS-1]++;
// Advance the other dimensions if necessary
UNROLL
for (int r = NUM_DIMENSIONS-1; r > 0; r--)
{
if(WDimensionIndex[r] >= XBlockStartStride_XBlockStride_WDimension_WDimensionDilationXBlockStride[r].z)
{
// Adjust the SharedXIndex by resetting the current dimension and increasing the next outer one
SharedXIndex -= XBlockStartStride_XBlockStride_WDimension_WDimensionDilationXBlockStride[r].w;
SharedXIndex += Dilation_Stride_XBlockStartOffset_DilationXBlockStride[r-1].w;
WDimensionIndex[r] = 0;
WDimensionIndex[r-1]++;
}
else
{
break;
}
}
}
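// Illustrative example for the dimension advance above (hypothetical 3x3 kernel, dilation 1, X block row stride 20):
// after the third tap of a kernel row the inner index has advanced by 3, gets rewound by 3 * 1 * 1 and advanced by
// one dilated row (1 * 20), moving SharedXIndex to the start of the next kernel row.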
// Sync the group before continuing any load
GroupMemoryBarrierWithGroupSync();
}
}
// Write the final result
if (YIndex >= 0)
{
Y[YIndex] = WORK_TO_BUFFER_TYPE(Result);
}
}
#endif