Files
UnrealEngine/Engine/Plugins/Experimental/NNERuntimeRDG/Shaders/Private/NNEHlslShaders/NNEHlslShadersConvWinogradMMM.usf
2025-05-18 13:04:45 +08:00

148 lines
5.2 KiB
HLSL

// Copyright Epic Games, Inc. All Rights Reserved.
#include "/Engine/Public/Platform.ush"
// ---------------------------------------------------------------------------
// Tiling configuration
//
// One thread group produces a BLOCK_ELEM_COUNT_N (rows, N) x BLOCK_ELEM_COUNT_M
// (columns, M) tile of one output matrix and walks K in chunks of K_STEP rows.
// Each thread accumulates a 4 (N) x 4 (M) sub-tile in registers, so the group
// holds exactly one thread per sub-tile: BLOCK_COUNT = BLOCK_COUNT_M * BLOCK_COUNT_N.
// The global buffers are accessed as WORK_TYPE_VECTOR (VECTOR_SIZE scalars per
// element); the *_VECTOR_ELEM_COUNT_* macros are tile extents in those units.
//
// BLOCK_ELEM_COUNT_N, WORK_TYPE and WORK_TYPE_VECTOR are not defined in this file;
// they are presumably supplied by the C++ shader compilation environment.
// ---------------------------------------------------------------------------
// K rows staged in group shared memory per iteration of the K loop.
#define K_STEP 16
// M scalars covered by one thread group.
#define BLOCK_ELEM_COUNT_M 64
// Scalars per WORK_TYPE_VECTOR buffer element.
#define VECTOR_SIZE 2
#define BLOCK_VECTOR_ELEM_COUNT_M (BLOCK_ELEM_COUNT_M / VECTOR_SIZE)
#define BLOCK_VECTOR_ELEM_COUNT_N (BLOCK_ELEM_COUNT_N / VECTOR_SIZE)
// Number of 4-wide register sub-tiles along M and N.
#define BLOCK_COUNT_M (BLOCK_ELEM_COUNT_M / 4)
#define BLOCK_COUNT_N (BLOCK_ELEM_COUNT_N / 4)
#define BLOCK_COUNT (BLOCK_COUNT_M * BLOCK_COUNT_N)
// One thread per M vector along X; Y is chosen so that the group has BLOCK_COUNT threads.
#define THREADGROUP_SIZE_X BLOCK_VECTOR_ELEM_COUNT_M
#define THREADGROUP_SIZE_Y (BLOCK_COUNT / THREADGROUP_SIZE_X)
#define THREADGROUP_SIZE (THREADGROUP_SIZE_X * THREADGROUP_SIZE_Y)

// Compile-time validation of the tiling configuration. An unsupported
// configuration would otherwise compile and produce wrong results:
//  - BLOCK_ELEM_COUNT_N must be a positive multiple of 8. Only then is BLOCK_COUNT
//    a multiple of THREADGROUP_SIZE_X, giving THREADGROUP_SIZE == BLOCK_COUNT.
//    Otherwise some 4x4 sub-tiles are never computed, or THREADGROUP_SIZE_Y is 0.
//  - K_STEP must be a multiple of THREADGROUP_SIZE_Y. The input staging loop
//    advances the read position by whole multiples of THREADGROUP_SIZE_Y rows,
//    while the K loop advances by K_STEP. A mismatch silently skips K rows.
//  - The register tile is assembled from exactly two WORK_TYPE_VECTOR values
//    (.xy / .zw), so VECTOR_SIZE must be 2.
#ifndef BLOCK_ELEM_COUNT_N
#error "BLOCK_ELEM_COUNT_N must be defined by the shader compilation environment."
#endif
#if VECTOR_SIZE != 2
#error "ConvWinogradMMM only supports VECTOR_SIZE == 2."
#endif
// #elif keeps the modulo below from being evaluated when THREADGROUP_SIZE_Y would be 0.
#if BLOCK_ELEM_COUNT_N < 8 || (BLOCK_ELEM_COUNT_N % 8) != 0
#error "BLOCK_ELEM_COUNT_N must be a positive multiple of 8."
#elif (K_STEP % THREADGROUP_SIZE_Y) != 0
#error "K_STEP must be a multiple of THREADGROUP_SIZE_Y (BLOCK_ELEM_COUNT_N / 8)."
#endif

// Winograd does not compile on D3D12 SM6 + bindless if using vector<x,y> atm
// hence the define is directly set to the vectorized type in ModifyCompilationEnvironment()
// rather than defined as a templated vector here
//#define WORK_TYPE_VECTOR vector <WORK_TYPE, VECTOR_SIZE>

// Global buffers. Every buffer is vectorized along its innermost dimension.
// Input and Output are vectorized along M; Weight is vectorized along N.
Buffer<WORK_TYPE_VECTOR> Input; // Ni x 36 x K x M
Buffer<WORK_TYPE_VECTOR> Weight; // 36 x K x N
RWBuffer<WORK_TYPE_VECTOR> Output; // Ni x 36 x N x M

// Shared by both phases of the kernel:
//  - K loop: K_STEP x BLOCK_VECTOR_ELEM_COUNT_M staged input vectors.
//  - Epilogue: BLOCK_ELEM_COUNT_N x BLOCK_VECTOR_ELEM_COUNT_M result vectors,
//    used to transpose the register tiles into coalesced row stores.
// That reuse is why the row count is the max of the two.
groupshared WORK_TYPE_VECTOR LDS_InputOutput[max(K_STEP, BLOCK_ELEM_COUNT_N)][BLOCK_VECTOR_ELEM_COUNT_M];
// K_STEP rows of the group's N tile of weights, in weight-vector units.
groupshared WORK_TYPE_VECTOR LDS_Weight[K_STEP][BLOCK_VECTOR_ELEM_COUNT_N];

// Extents. M is compared against an M vector index, and N against a weight-vector
// index along N. They are therefore presumably expressed in WORK_TYPE_VECTOR units.
// NOTE(review): confirm against the C++ dispatch code.
int M;
int N;
// Number of K rows to accumulate over (K is not vectorized).
int K;
// Strides are counted in WORK_TYPE_VECTOR buffer elements.
// A Matrix*Stride is the distance between consecutive matrices (SV_GroupID.z; for
// Weight, SV_GroupID.z % 36). It also bounds each matrix: indices at or past
// matrixStart + Matrix*Stride are treated as out of range.
int MatrixInputStride;
// Distance between consecutive K rows of an input matrix.
int KInputStride;
int MatrixWeightStride;
// Distance between consecutive K rows of a weight matrix.
int KWeightStride;
int MatrixOutputStride;
// Distance between consecutive N rows of an output matrix.
int NOutputStride;
// ConvWinogradMMM
// Batched matrix multiply used by the Winograd convolution path. For each matrix
// z = SV_GroupID.z it computes
//   Output[z][n][m] = sum_k Weight[z % 36][k][n] * Input[z][k][m]
// i.e. Output (N x M) = transpose(Weight) (N x K) * Input (K x M). One weight matrix
// is used per value of z % 36 and is shared by every Ni.
//
// Dispatch layout:
//   SV_GroupID.x : M tile (BLOCK_VECTOR_ELEM_COUNT_M input vectors, via SV_DispatchThreadID.x)
//   SV_GroupID.y : N tile (BLOCK_VECTOR_ELEM_COUNT_N weight vectors = BLOCK_ELEM_COUNT_N output rows)
//   SV_GroupID.z : matrix index (Ni * 36 + tile, per the buffer layout comments)
// Group shared memory: LDS_InputOutput and LDS_Weight.
//
// Bounds handling: there is no per-row check against K or N. Instead, every index
// is compared with the last index of its matrix (Scalar_*IndexMax). Loads past it
// return 0 and stores past it are dropped. An out-of-range M or N column is pushed
// past that limit by adding a whole matrix stride.
// NOTE(review): zero-filling of K rows >= K (in the last K_STEP chunk) and dropping
// of output rows >= N therefore rely on the matrices being tightly packed (e.g.
// MatrixInputStride == K * KInputStride). Confirm with the C++ dispatch code.
//
// UNROLL / ISOLATE are defined outside this file (presumably in Platform.ush).
[numthreads(THREADGROUP_SIZE_X, THREADGROUP_SIZE_Y, 1)]
void ConvWinogradMMM(
in const uint3 DispatchThreadID : SV_DispatchThreadID,
in const uint3 GroupID : SV_GroupID,
in const uint3 GroupThreadID : SV_GroupThreadID,
in const uint GroupIndex : SV_GroupIndex)
{
// Start of this group's N tile, in weight-vector units (N / VECTOR_SIZE).
const int Scalar_NBase = GroupID.y * BLOCK_VECTOR_ELEM_COUNT_N;
// First buffer index of the matrices handled by this group. Each Ni reuses the
// same 36 weight matrices.
const int Scalar_InputMatrix = GroupID.z * MatrixInputStride;
const int Scalar_WeightMatrix = (GroupID.z % 36) * MatrixWeightStride;
const int Scalar_OutputMatrix = GroupID.z * MatrixOutputStride;
// Last valid buffer index of each matrix; anything beyond is treated as out of range.
const int Scalar_InputIndexMax = Scalar_InputMatrix + MatrixInputStride - 1;
const int Scalar_WeightIndexMax = Scalar_WeightMatrix + MatrixWeightStride - 1;
const int Scalar_OutputIndexMax = Scalar_OutputMatrix + MatrixOutputStride - 1;
// Register sub-tile owned by this thread (independent of the load mapping below).
// It covers LDS columns MBlockBase, MBlockBase + 1 (4 M scalars) and weight
// columns NBlockBase, NBlockBase + 1 (4 N scalars); 4 / VECTOR_SIZE == 2 vectors per side.
const int MBlock = GroupIndex % BLOCK_COUNT_M;
const int MBlockBase = MBlock * (4 / VECTOR_SIZE);
const int NBlock = GroupIndex / BLOCK_COUNT_M;
const int NBlockBase = NBlock * (4 / VECTOR_SIZE);
// Accumulator: rows are N, columns are M.
matrix < WORK_TYPE, 4, 4 > Values = 0;
// Input staging mapping: thread (x, y) loads M vector x of K rows y,
// y + THREADGROUP_SIZE_Y, ... of each K_STEP chunk.
const int MInputIndex = DispatchThreadID.x;
int InputIndex = Scalar_InputMatrix;
InputIndex += GroupThreadID.y * KInputStride;
// An out-of-range M column is pushed past the matrix end so that its loads read as 0.
InputIndex += MInputIndex < M ? MInputIndex : MatrixInputStride;
// Weight staging mapping: the group is viewed as rows of BLOCK_VECTOR_ELEM_COUNT_N
// threads. Each such row loads one K row of the group's N tile.
const int NWeightIndex = Scalar_NBase + GroupIndex % BLOCK_VECTOR_ELEM_COUNT_N;
int WeightIndex = Scalar_WeightMatrix;
WeightIndex += (GroupIndex / BLOCK_VECTOR_ELEM_COUNT_N) * KWeightStride;
// An out-of-range N column is pushed past the matrix end so that its loads read as 0.
WeightIndex += NWeightIndex < N ? NWeightIndex : MatrixWeightStride;
for (int KBase = 0; KBase < K; KBase += K_STEP)
{
//Load Input and Weights
// Clamp the read address so the buffer access itself stays inside the matrix, then
// replace the value with 0 when the unclamped index was out of range.
UNROLL
for (int KOffset = 0; KOffset < K_STEP; KOffset += THREADGROUP_SIZE_Y)
{
const int InputIndexClamped = min(InputIndex, Scalar_InputIndexMax);
WORK_TYPE_VECTOR InputValue = Input[InputIndexClamped];
InputValue = InputIndex <= Scalar_InputIndexMax ? InputValue : 0;
LDS_InputOutput[KOffset + GroupThreadID.y][GroupThreadID.x] = InputValue;
InputIndex += THREADGROUP_SIZE_Y * KInputStride;
}
UNROLL
for (int KOffset = 0; KOffset < K_STEP; KOffset += (THREADGROUP_SIZE / BLOCK_VECTOR_ELEM_COUNT_N))
{
const int WeightIndexClamped = min(WeightIndex, Scalar_WeightIndexMax);
WORK_TYPE_VECTOR WeightValue = Weight[WeightIndexClamped];
WeightValue = WeightIndex <= Scalar_WeightIndexMax ? WeightValue : 0;
LDS_Weight[KOffset + GroupIndex / BLOCK_VECTOR_ELEM_COUNT_N][GroupIndex % BLOCK_VECTOR_ELEM_COUNT_N] = WeightValue;
WeightIndex += (THREADGROUP_SIZE / BLOCK_VECTOR_ELEM_COUNT_N) * KWeightStride;
}
// All staging writes must be visible before any thread reads LDS.
GroupMemoryBarrierWithGroupSync();
ISOLATE
{
//Multiply
// Per K row: outer product of the thread's 4 weights (a column, N) and
// 4 inputs (a row, M), accumulated into the 4x4 tile. Each register row/column
// is built from two WORK_TYPE_VECTOR values, so this assumes VECTOR_SIZE == 2.
UNROLL
for (int KIndex = 0; KIndex < K_STEP; KIndex++)
{
matrix < WORK_TYPE, 1, 4 > Reg_Input = { LDS_InputOutput[KIndex][MBlockBase], LDS_InputOutput[KIndex][MBlockBase + 1] };
matrix < WORK_TYPE, 4, 1 > Reg_Weight = { LDS_Weight[KIndex][NBlockBase], LDS_Weight[KIndex][NBlockBase + 1] };
Values += mul(Reg_Weight, Reg_Input);
}
}
// All reads of this chunk must finish before LDS is overwritten, whether by the
// next chunk or by the epilogue below.
GroupMemoryBarrierWithGroupSync();
}
// Move values from Reg to LDS
// Row NOffset of the accumulator becomes LDS row NBlock * 4 + NOffset, occupying
// columns MBlockBase (.xy) and MBlockBase + 1 (.zw).
UNROLL
for (int NOffset = 0; NOffset < 4; NOffset++)
{
const int NLDSIndex = NBlock * 4 + NOffset;
LDS_InputOutput[NLDSIndex][MBlockBase + 0] = Values[NOffset].xy;
LDS_InputOutput[NLDSIndex][MBlockBase + 1] = Values[NOffset].zw;
}
GroupMemoryBarrierWithGroupSync();
// Move values from LDS to VRAM
// The output is not vectorized along N: each of the BLOCK_ELEM_COUNT_N tile rows is
// one N scalar. Hence the N base is converted from weight-vector units via VECTOR_SIZE.
// Thread (x, y) stores M vector x of rows y, y + THREADGROUP_SIZE_Y, ...
const int MOutputIndex = DispatchThreadID.x;
int OutputIndex = Scalar_OutputMatrix;
OutputIndex += Scalar_NBase * VECTOR_SIZE * NOutputStride;
OutputIndex += GroupThreadID.y * NOutputStride;
// An out-of-range M column is pushed past the matrix end so that its stores are dropped.
OutputIndex += MOutputIndex < M ? MOutputIndex : MatrixOutputStride;
UNROLL
for (int NOffset = 0; NOffset < BLOCK_ELEM_COUNT_N; NOffset += THREADGROUP_SIZE_Y)
{
if (OutputIndex <= Scalar_OutputIndexMax)
{
Output[OutputIndex] = LDS_InputOutput[NOffset + GroupThreadID.y][GroupThreadID.x];
}
OutputIndex += THREADGROUP_SIZE_Y * NOutputStride;
}
}