// Copyright Epic Games, Inc. All Rights Reserved.

#include "/Engine/Public/Platform.ush"

#define WORK_TYPE float
#define BUFFER_TYPE float
#define BUFFER_TO_WORK_TYPE(x) x
#define WORK_TO_BUFFER_TYPE(x) x

// Must correspond to EConvGroupSize defined in NNEHlslShadersConvCS.h
#if GROUP_SIZE == 0
#define NUM_GROUP_THREADS 128
#elif GROUP_SIZE == 1
#define NUM_GROUP_THREADS 256
#elif GROUP_SIZE == 2
#define NUM_GROUP_THREADS 512
#endif

Buffer<BUFFER_TYPE> X; // N x C x XD0 x ... x XDi
Buffer<BUFFER_TYPE> W; // M x C/GROUPS x WD0 x ... x WDi OR WD0 x WD1 x C x M if WEIGHTS_TRANSPOSED (only supported for 4D)
RWBuffer<BUFFER_TYPE> Y; // N x M x YD0 x ... x YDi
Buffer<BUFFER_TYPE> B; // B0, ..., BM

// x: Dilation0, ..., Dilationi
// y: Stride0, ..., Stridei
// z: Pad0Begin, ..., PadiBegin
// w: Element j contains Dilationj * XBlockStridej, i.e. the step inside the shared X block when advancing kernel dimension j
int4 Dilation_Stride_XBlockStartOffset_DilationXBlockStride[MAX_NUM_DIMENSIONS];

// x: Element j contains PROD(GDj+1,..,GDi) with the last element containing 1, with GDi being the number of thread groups contained by dimension i
// y: Number of threads in each dimension of a group
// z: Element j contains PROD(GTDj+1,..,GTDi) with the last element containing 1, with GTDi being the number of threads contained by dimension i inside a group
// w: Element j contains Stridej * XBlockStridej
int4 GroupStride_GroupShape_GroupThreadStride_StrideXBlockStride[MAX_NUM_DIMENSIONS];

// x: YD0, ..., YDi
// y: Element j contains PROD(YDj+1,..,YDi) with the last element containing 1
// z: XD0, ..., XDi
// w: Element j contains PROD(XDj+1,..,XDi) with the last element containing 1
int4 YDimension_YMemoryStride_XDimension_XMemoryStride[MAX_NUM_DIMENSIONS];

// x: Number of elements in each dimension of an X block
// y: The strides of the X block to be loaded
// z: WD0, ..., WDi
// w: Element j contains WDj * Dilationj * XBlockStridej
int4 XBlockStartStride_XBlockStride_WDimension_WDimensionDilationXBlockStride[MAX_NUM_DIMENSIONS];

// x: 1/GroupStride
// y: 1/GroupThreadStride
// z: 1/XBlockStride
float4 OneDiv_GroupStride_GroupThreadStride_XBlockStride[MAX_NUM_DIMENSIONS];

int NumWFeatures; // Number of features (aka output channels)
int NumWChannels; // Number of input channels
int YBatchStride;
int YOutputKernelStride;
int XBatchStride; // The number of elements in each batch of X
int XChannelStride; // The number of elements in each channel of X
int XBlockSize; // The total number of elements in each block to be loaded
int NumChannelBatches; // The number of iterations needed to cover all channels that need to be processed. Equal to ceil(NumWChannels / NumChannelsPerBatch)
int NumChannelsPerBatch; // The number of complete kernel channels that can be loaded in one batch
int WOutputKernelStride; // The total number of W elements per output kernel and thus PROD(C/GROUPS, WD0, .., WDi)
int WChannelBatchSize; // The total number of W elements inside a channel batch and thus the number of channels per batch times PROD(WD0, .., WDi)
int WChannelSize; // Number of elements in a single channel of W and thus PROD(WDi)
float GroupsDivM; // The number of groups divided by the number of output kernels, used to compute the start channel offset
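// Illustrative example of the per dimension constants above; the concrete values are hypothetical and only show the packing scheme.
// For a 2D convolution (NUM_DIMENSIONS == 2) with X of shape N x C x 8 x 8 and Y of shape N x M x 8 x 8:
//   YDimension_YMemoryStride_XDimension_XMemoryStride[0] = int4(8, 8, 8, 8)  (outer spatial dimension, stride = product of the inner dimensions)
//   YDimension_YMemoryStride_XDimension_XMemoryStride[1] = int4(8, 1, 8, 1)  (innermost dimension, stride = 1)
// Batch and channel offsets are not part of these strides; they are applied separately via XBatchStride/XChannelStride and YBatchStride/YOutputKernelStride.
// The OneDiv_* reciprocals allow flat indices to be decomposed per dimension with a float multiply instead of an integer division,
// e.g. DimensionIndex = (int)(OneDiv_GroupStride_GroupThreadStride_XBlockStride[k].z * (float)FlatIndex).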
// NOTE: READS_PER_THREAD_POW2 is assumed to be provided as a compile time define (log2 of the maximum number of X elements each thread loads per channel).
// The entry point signature, the preprocessor guard and the GroupID layout (x: flattened spatial groups, y: output kernel, z: batch) are assumptions and must match the dispatch code.
groupshared WORK_TYPE SharedMemoryX[NUM_GROUP_THREADS << READS_PER_THREAD_POW2];
groupshared WORK_TYPE SharedMemoryW[NUM_GROUP_THREADS];

#if NUM_DIMENSIONS <= MAX_NUM_DIMENSIONS

[numthreads(NUM_GROUP_THREADS, 1, 1)]
void Conv(in const uint3 GroupID : SV_GroupID, in const uint3 GroupThreadID : SV_GroupThreadID)
{
	// The batch and the output kernel (aka output channel) this group works on
	const int BatchIndex = GroupID.z;
	const int OutputKernelIndex = GroupID.y;

	// For grouped convolution: the first input channel required by this output kernel
	const int InputChannelStartIndex = (int)(GroupsDivM * OutputKernelIndex) * NumWChannels;

	// Decompose the flat group and thread indices into per dimension output indices and compute the resulting index into Y
	int GroupThreadDimensionIndex[NUM_DIMENSIONS];
	int XBlockStartDimensionIndex[NUM_DIMENSIONS];

	int YIndex = BatchIndex * YBatchStride + OutputKernelIndex * YOutputKernelStride;
	int TempGroupIdx = GroupID.x;
	int TempGroupThreadIdx = GroupThreadID.x;

	UNROLL
	for (int i = 0; i < NUM_DIMENSIONS; i++)
	{
		// Per dimension index of this group and of this thread inside the group
		int GroupDimensionIndex = (int)(OneDiv_GroupStride_GroupThreadStride_XBlockStride[i].x * (float)TempGroupIdx);
		TempGroupIdx -= GroupDimensionIndex * GroupStride_GroupShape_GroupThreadStride_StrideXBlockStride[i].x;

		GroupThreadDimensionIndex[i] = (int)(OneDiv_GroupStride_GroupThreadStride_XBlockStride[i].y * (float)TempGroupThreadIdx);
		TempGroupThreadIdx -= GroupThreadDimensionIndex[i] * GroupStride_GroupShape_GroupThreadStride_StrideXBlockStride[i].z;

		// The output element this thread is responsible for in dimension i
		int ThreadDimensionIndex = GroupDimensionIndex * GroupStride_GroupShape_GroupThreadStride_StrideXBlockStride[i].y + GroupThreadDimensionIndex[i];

		// The first X element (e.g. upper left corner) covered by this group's block in dimension i, offset by the padding
		XBlockStartDimensionIndex[i] = GroupDimensionIndex * GroupStride_GroupShape_GroupThreadStride_StrideXBlockStride[i].y * Dilation_Stride_XBlockStartOffset_DilationXBlockStride[i].y - Dilation_Stride_XBlockStartOffset_DilationXBlockStride[i].z;

		// Invalidate the output index if this thread falls outside of Y
		if (ThreadDimensionIndex >= YDimension_YMemoryStride_XDimension_XMemoryStride[i].x || YIndex < 0)
		{
			YIndex = -1;
		}
		else
		{
			YIndex += ThreadDimensionIndex * YDimension_YMemoryStride_XDimension_XMemoryStride[i].y;
		}
	}

	// Compute the memory indices in global X which the current thread has to load as well as the indices in shared X where to write
	int GlobalXIndex[0x1 << READS_PER_THREAD_POW2];
	int SharedXIndex[0x1 << READS_PER_THREAD_POW2];

	int ReadIdx = GroupThreadID.x;

	UNROLL
	for (int j = 0; j < (int)(0x1 << READS_PER_THREAD_POW2); j++)
	{
		if (ReadIdx >= XBlockSize)
		{
			// Assign -1 for both to indicate that the element is not used and thus neither has to be read from global nor written to shared memory
			GlobalXIndex[j] = -1;
			SharedXIndex[j] = -1;
			continue;
		}

		// Turn the flat read index into per dimension indices inside the volume to be loaded by this group
		int TmpGlobalXIndex = BatchIndex * XBatchStride + InputChannelStartIndex * XChannelStride;
		int TmpSharedXIndex = 0;
		int TempReadIdx = ReadIdx;

		UNROLL
		for (int k = 0; k < NUM_DIMENSIONS; k++)
		{
			int DimensionIndex = (int)(OneDiv_GroupStride_GroupThreadStride_XBlockStride[k].z * (float)TempReadIdx);
			int DimensionIndexOffset = DimensionIndex * XBlockStartStride_XBlockStride_WDimension_WDimensionDilationXBlockStride[k].y;
			TempReadIdx -= DimensionIndexOffset;

			// Update TmpGlobalXIndex only if non-negative (it is already outside the volume if negative)
			if (TmpGlobalXIndex >= 0)
			{
				// Check if the global X index is in range
				int XIndex = XBlockStartDimensionIndex[k] + DimensionIndex;
				if (XIndex < 0 || XIndex >= YDimension_YMemoryStride_XDimension_XMemoryStride[k].z)
				{
					TmpGlobalXIndex = -1;
				}
				else
				{
					TmpGlobalXIndex += XIndex * YDimension_YMemoryStride_XDimension_XMemoryStride[k].w;
				}
			}

			// Update TmpSharedXIndex in any case as a zero has to be written
			TmpSharedXIndex += DimensionIndexOffset;
		}

		// Assign the result, which will be -1 for the global and non-negative for the shared index if the element lies in the padding area
		// This indicates that the global element should not be loaded but a zero has to be written to shared memory
		GlobalXIndex[j] = TmpGlobalXIndex;
		SharedXIndex[j] = TmpSharedXIndex;

		// Increment the read idx by the number of threads inside the group to cover all elements
		ReadIdx += NUM_GROUP_THREADS;
	}

	// Write the zeros once to avoid rewriting the same value later
	UNROLL
	for (int l = 0; l < (int)(0x1 << READS_PER_THREAD_POW2); l++)
	{
		if (SharedXIndex[l] >= 0)
		{
			SharedMemoryX[SharedXIndex[l]] = 0;
		}
	}
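	// Illustrative example of the channel batching below; the concrete values are hypothetical.
	// With NumWChannels == 5 and NumChannelsPerBatch == 2, NumChannelBatches == ceil(5 / 2) == 3. Each outer iteration loads
	// WChannelBatchSize == NumChannelsPerBatch * WChannelSize consecutive W elements into SharedMemoryW (one element per thread),
	// so the last batch is only partially valid, which is why the inner loop checks WChannelIndex against NumWChannels.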
	// Iterate over all channel batches
	int WChannelOutputKernelOffset = OutputKernelIndex * WOutputKernelStride;
	int WChannelBatchOffset = 0;
	int WChannelIndex = 0;

#if HAS_B == 0
	WORK_TYPE Result = 0;
#else
	WORK_TYPE Result = BUFFER_TO_WORK_TYPE(B[OutputKernelIndex]);
#endif

	for (int m = 0; m < NumChannelBatches; m++)
	{
		// Load all channels of W belonging to this batch
		int WIndex = WChannelBatchOffset + GroupThreadID.x;
		if (GroupThreadID.x < WChannelBatchSize && WIndex < WOutputKernelStride)
		{
#if WEIGHTS_TRANSPOSED
			int MCHWIndex = WChannelOutputKernelOffset + WIndex;
			int Whw = MCHWIndex % WChannelSize;
			int Wc = (MCHWIndex / WChannelSize) % NumWChannels;
			int Wm = MCHWIndex / (NumWChannels * WChannelSize);
			int HWCMIndex = Whw * NumWChannels * NumWFeatures + Wc * NumWFeatures + Wm;
			SharedMemoryW[GroupThreadID.x] = BUFFER_TO_WORK_TYPE(W[HWCMIndex]);
#else
			SharedMemoryW[GroupThreadID.x] = BUFFER_TO_WORK_TYPE(W[WChannelOutputKernelOffset + WIndex]);
#endif
			WChannelBatchOffset += WChannelBatchSize;
		}

		// Each channel batch iterates over all consecutive elements of the loaded kernel channels
		// This results in bank conflict free lookups, as all threads of the same warp read exactly the same element which is then broadcast
		int SharedWIndex = 0;

		// Iterate over all channels in this batch
		for (int n = 0; n < NumChannelsPerBatch; n++)
		{
			// Check if the channel index is valid as channel batches may not be aligned with the effective number of channels
			if (WChannelIndex >= NumWChannels)
			{
				break;
			}
			WChannelIndex++;

			// Load the shared X
			UNROLL
			for (int o = 0; o < (int)(0x1 << READS_PER_THREAD_POW2); o++)
			{
				if (GlobalXIndex[o] >= 0 && SharedXIndex[o] >= 0)
				{
					// Load the global data into shared memory
					SharedMemoryX[SharedXIndex[o]] = BUFFER_TO_WORK_TYPE(X[GlobalXIndex[o]]);

					// Update the indices for the next iteration by adding the channel stride
					GlobalXIndex[o] += XChannelStride;
				}
			}

			// Sync the group before doing any computation
			GroupMemoryBarrierWithGroupSync();

			// Initialize a dimension index array used to iterate through the kernel dimensions
			// Also initialize the start index in shared memory based on the result position
			int WDimensionIndex[NUM_DIMENSIONS];
			int SharedXIndex = 0;

			UNROLL
			for (int p = 0; p < NUM_DIMENSIONS; p++)
			{
				WDimensionIndex[p] = 0;

				// The index of the first element (e.g. upper left) required by each output is offset by the convolution stride and then multiplied by the stride in memory
				SharedXIndex += GroupThreadDimensionIndex[p] * GroupStride_GroupShape_GroupThreadStride_StrideXBlockStride[p].w;
			}

			// Compute and accumulate the result
			for (int q = 0; q < WChannelSize; q++)
			{
				Result += SharedMemoryX[SharedXIndex] * SharedMemoryW[SharedWIndex];

				// Advance the inner-most dimension
				SharedWIndex++;
				SharedXIndex += Dilation_Stride_XBlockStartOffset_DilationXBlockStride[NUM_DIMENSIONS-1].x;
				WDimensionIndex[NUM_DIMENSIONS-1]++;

				// Advance the other dimensions if necessary
				UNROLL
				for (int r = NUM_DIMENSIONS-1; r > 0; r--)
				{
					if (WDimensionIndex[r] >= XBlockStartStride_XBlockStride_WDimension_WDimensionDilationXBlockStride[r].z)
					{
						// Adjust the SharedXIndex by resetting the current dimension and increasing the next outer one
						SharedXIndex -= XBlockStartStride_XBlockStride_WDimension_WDimensionDilationXBlockStride[r].w;
						SharedXIndex += Dilation_Stride_XBlockStartOffset_DilationXBlockStride[r-1].w;

						WDimensionIndex[r] = 0;
						WDimensionIndex[r-1]++;
					}
					else
					{
						break;
					}
				}
			}

			// Sync the group before continuing any load
			GroupMemoryBarrierWithGroupSync();
		}
	}

	// Write the final result
	if (YIndex >= 0)
	{
		Y[YIndex] = WORK_TO_BUFFER_TYPE(Result);
	}
}

#endif // NUM_DIMENSIONS <= MAX_NUM_DIMENSIONS
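// Illustrative walk through of the accumulation loop above; the concrete values are hypothetical.
// For a 3x3 kernel with dilation 1 (NUM_DIMENSIONS == 2), each tap advances SharedXIndex by
// Dilation_Stride_XBlockStartOffset_DilationXBlockStride[1].x == 1 along the innermost block dimension. Once WDimensionIndex[1]
// reaches WD1 == 3, the carry logic subtracts WDimensionDilationXBlockStride (3 * 1 * 1) to rewind that dimension and adds
// Dilation_Stride_XBlockStartOffset_DilationXBlockStride[0].w (dilation times the block row stride) to step to the next kernel
// row, like a carry in a mixed radix counter.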