Files
UnrealEngine/Engine/Plugins/Experimental/NNERuntimeRDG/Shaders/Private/NNEHlslShaders/NNEHlslShadersConvMatmul.usf
2025-05-18 13:04:45 +08:00

236 lines
8.1 KiB
HLSL

// Copyright Epic Games, Inc. All Rights Reserved.
#include "/Engine/Public/Platform.ush"
// Element type used for arithmetic and for the groupshared (LDS) staging buffer.
#define WORK_TYPE float
// Element type of the bound Input/Weight/Bias/Output buffers.
#define BUFFER_TYPE float
// Conversions between buffer and work element types; identity while both are float.
#define BUFFER_TO_WORK_TYPE(x) x
#define WORK_TO_BUFFER_TYPE(x) x
#define LDS_IC 16 // Number of input channels cached in LDS per accumulation step
#define TG_OC 32 // Number of output channels covered by a thread group
#define TG_OP 32 // Number of output pixels covered by a thread group
// Padding of the LDS tile used to stage results before they are written to Output.
// Intended to avoid LDS bank conflicts during the write-out phase.
#define LDS_PADX 1
#define LDS_PADY 2
// Row pitch of the staged result tile: 8 threads in X, each owning 4 pixels plus
// LDS_PADX padding, plus LDS_PADY padding per row.
#define LDS_X_PADDEDSIZE ((4 + LDS_PADX) * 8 + LDS_PADY) //42
// Shared scratch memory, reused for two purposes:
//  - load phase: LDS_IC * TG_OP inputs followed by LDS_IC * TG_OC weights (16*32 + 16*32 = 1024 elements)
//  - write-out phase: 32 rows of LDS_X_PADDEDSIZE elements (32*42 = 1344 elements)
// The declared size is the larger of the two.
groupshared WORK_TYPE LDSBuffer[32 * LDS_X_PADDEDSIZE]; //max (16*32+16*32, 32*42) = 1344 WORK_TYPE
Buffer<BUFFER_TYPE> Input; // Ni x Ci x Hi x Wi
Buffer<BUFFER_TYPE> Weight; // Cw x Ci x Hw x Ww OR Hw x Ww x Ci x Cw if WEIGHTS_TRANSPOSED
Buffer<BUFFER_TYPE> Bias; // Cw
RWBuffer<BUFFER_TYPE> Output; // Ni x Cw x Ho x Wo
// Input tensor dimensions: channels, height, width.
int Ci;
int Hi;
int Wi;
// Output tensor spatial dimensions: height, width.
int Ho;
int Wo;
// Weight tensor dimensions: output channels, kernel height, kernel width.
int Cw;
int Hw;
int Ww;
// Padding subtracted from the output-derived input coordinate (left for W, top for H).
int PadLeft;
int PadTop;
// Step in input pixels between consecutive output pixels along H and W.
int StrideH;
int StrideW;
// Returns the cached weight for (inputChannel, outputChannel).
// The weight tile is stored after the LDS_IC * TG_OP input tile, one row of TG_OC
// output channels per cached input channel.
WORK_TYPE ReadLDSWeights(int inputChannel, int outputChannel)
{
	const int WeightTileBase = LDS_IC * TG_OP;
	const int Slot = WeightTileBase + inputChannel * TG_OC + outputChannel;
	return LDSBuffer[Slot];
}
// Stores one weight into the cached weight tile at (inputChannel, outputChannel).
// The weight tile is stored after the LDS_IC * TG_OP input tile, one row of TG_OC
// output channels per cached input channel.
void WriteLDSWeights(int inputChannel, int outputChannel, WORK_TYPE value)
{
	const int WeightTileBase = LDS_IC * TG_OP;
	const int Slot = WeightTileBase + inputChannel * TG_OC + outputChannel;
	LDSBuffer[Slot] = value;
}
// Returns the cached input value for (inputChannel, outputPixel).
// The input tile starts at the beginning of LDSBuffer, one row of TG_OP pixels
// per cached input channel.
WORK_TYPE ReadLDSInputs(int inputChannel, int outputPixel)
{
	const int Slot = inputChannel * TG_OP + outputPixel;
	return LDSBuffer[Slot];
}
// Stores one input value into the cached input tile at (inputChannel, outputPixel).
// The input tile starts at the beginning of LDSBuffer, one row of TG_OP pixels
// per cached input channel.
void WriteLDSInputs(int inputChannel, int outputPixel, WORK_TYPE value)
{
	const int Slot = inputChannel * TG_OP + outputPixel;
	LDSBuffer[Slot] = value;
}
// 2D convolution computed as a tiled matrix multiply.
//
// Work decomposition:
//  - GroupID.x selects a tile of TG_OP consecutive output pixels, GroupID.y a tile of
//    TG_OC output channels, GroupID.z the batch element.
//  - Each of the 8x8 threads accumulates a 4x4 register tile (4 output channels x 4 output pixels).
//  - For every kernel tap (hw, ww), input channels are consumed LDS_IC at a time:
//    the group stages LDS_IC x TG_OP inputs and LDS_IC x TG_OC weights in LDS, then every
//    thread multiplies-accumulates from LDS into its registers.
//  - Results are transposed through LDS so that each thread writes one pixel column of Output.
//
// The code derives a single input row per group and only varies the W coordinate per thread,
// i.e. it treats the TG_OP pixels of a group as lying on one output row.
// NOTE(review): presumably the dispatcher only selects this kernel when Wo is a multiple of
// TG_OP; the output store is not bounds-checked on the pixel axis - confirm against the caller.
[numthreads(8, 8, 1)]
void ConvMatmul(
	in const uint3 DispatchThreadID : SV_DispatchThreadID,
	in const uint3 GroupID : SV_GroupID,
	in const uint3 GroupThreadID : SV_GroupThreadID,
	in const uint GroupIndex : SV_GroupIndex)
{
	// Origin of this thread's 4x4 register tile, globally (channels) and within the group.
	const int DispatchThreadOutputChannelOffset = 4 * DispatchThreadID.y;
	const int GroupThreadOutputPixelOffset = 4 * GroupThreadID.x;
	const int GroupThreadOutputChannelOffset = 4 * GroupThreadID.y;
	const int GroupThreadIdx = GroupIndex;

	// Values prefixed with Scalar_ depend only on GroupID / uniforms / loop counters,
	// so they are identical for every thread of the group.
	const int Scalar_GroupOutputPixelOffset = TG_OP * GroupID.x;
	const int Scalar_GroupOutputChannelOffset = TG_OC * GroupID.y;
	const int Scalar_BatchOutputOffset = Ho * Wo * Cw * GroupID.z;
	const int Scalar_BatchInputOffset = Hi * Wi * Ci * GroupID.z;

	// Input-space coordinate of the kernel's top-left tap for the first pixel of the group.
	const int Scalar_GroupOutputPixelH = (Scalar_GroupOutputPixelOffset) / Wo * StrideH;
	const int Scalar_GroupOutputPixelWBase = Scalar_GroupOutputPixelOffset % Wo * StrideW;
	const int Scalar_GroupInputPixelKernelTopLeftH = Scalar_GroupOutputPixelH - PadTop;
	const int Scalar_GroupInputPixelKernelTopLeftW = Scalar_GroupOutputPixelWBase - PadLeft;

	const int cw = DispatchThreadOutputChannelOffset;

#if HAS_BIAS
	// Clamp the channel index exactly like the Weight/Input fetches below so that threads
	// whose channels fall past Cw never read outside the Bias buffer. Their accumulators
	// are discarded by the WriteOutChannel < Cw test at write-out.
	WORK_TYPE BiasC0 = BUFFER_TO_WORK_TYPE(Bias[min(cw + 0, Cw - 1)]);
	WORK_TYPE BiasC1 = BUFFER_TO_WORK_TYPE(Bias[min(cw + 1, Cw - 1)]);
	WORK_TYPE BiasC2 = BUFFER_TO_WORK_TYPE(Bias[min(cw + 2, Cw - 1)]);
	WORK_TYPE BiasC3 = BUFFER_TO_WORK_TYPE(Bias[min(cw + 3, Cw - 1)]);
#else
	WORK_TYPE BiasC0 = 0.0f;
	WORK_TYPE BiasC1 = 0.0f;
	WORK_TYPE BiasC2 = 0.0f;
	WORK_TYPE BiasC3 = 0.0f;
#endif

	int i;

	// Accumulators, seeded with the bias.
	// First index is thread output channel offset (on cw)
	// 2nd index is thread pixels offset (on pi)
	WORK_TYPE Values[4][4];
	UNROLL
	for (i = 0; i < 4; ++i)
	{
		Values[0][i] = BiasC0;
		Values[1][i] = BiasC1;
		Values[2][i] = BiasC2;
		Values[3][i] = BiasC3;
	}

	// Offset of the current kernel tap inside the weight tensor (layout dependent, see below).
	int Scalar_KernelIdx = 0;

	// Load-phase role of this thread: threads 0..31 stage input channels [0, LDS_IC/2),
	// threads 32..63 stage [LDS_IC/2, LDS_IC). Within each half, the thread index selects
	// both the pixel column (inputs) and the output channel column (weights).
	const int InputChannelOffset = GroupThreadIdx >= 32 ? LDS_IC/2 : 0;
	const int OutputPixelOffset = GroupThreadIdx % 32;
	const int OutputChannelOffset = GroupThreadIdx % 32;

	// Clamp the channel used for addressing; the validity flag zeroes the value instead.
	int OutputChannel = Scalar_GroupOutputChannelOffset + OutputChannelOffset;
	const bool IsValidOutputChannel = OutputChannel < Cw;
	OutputChannel = min(OutputChannel, Cw - 1);

	for (int hw = 0; hw < Hw; ++hw)
	{
		const int Scalar_InputPixelH = Scalar_GroupInputPixelKernelTopLeftH + hw;
		const bool Scalar_IsValidInputH = (Scalar_InputPixelH >= 0) && (Scalar_InputPixelH < Hi);

		for (int ww = 0; ww < Ww; ++ww)
		{
			// Input pixel sampled by this thread's output pixel for tap (hw, ww).
			// Taps that fall in the padding area contribute zero; the address is clamped
			// so the fetch itself always stays inside the input plane.
			const int InputPixelW = Scalar_GroupInputPixelKernelTopLeftW + ww + OutputPixelOffset * StrideW;
			const bool IsValidInputPixel = Scalar_IsValidInputH && (InputPixelW >= 0) && (InputPixelW < Wi);
			int InputPixelOffset = Scalar_InputPixelH * Wi + InputPixelW;
			InputPixelOffset = clamp(InputPixelOffset, 0, Hi * Wi - 1);
			const int ReadOffsetI = Scalar_BatchInputOffset + InputPixelOffset;

#if WEIGHTS_TRANSPOSED
			// Layout Hw x Ww x Ci x Cw: consecutive input channels are Cw apart.
			const int Scalar_InputChannelStride = Cw;
			const int ReadOffsetW = Scalar_KernelIdx + OutputChannel;
#else
			// Layout Cw x Ci x Hw x Ww: consecutive input channels are Hw * Ww apart.
			const int Scalar_InputChannelStride = Hw * Ww;
			const int ReadOffsetW = OutputChannel * Ci * Scalar_InputChannelStride + Scalar_KernelIdx;
#endif

			for (int ci = 0; ci < Ci; ci += LDS_IC)
			{
				/// Buffer memory to LDS
				// We need to load 1024 floats
				// 512 weights (16 inputs channel, 32 output channels) and
				// 512 inputs (16 inputs channel, 32 inputs pixels)
				// We have 64 threads thus each will read 8 inputs and 8 weights.
				UNROLL
				for (i = 0; i < (LDS_IC / 2); ++i)
				{
					// Channels past Ci are clamped for addressing and zeroed by the validity flag.
					int InputChannel = ci + InputChannelOffset + i;
					const bool IsValidInputChannel = InputChannel < Ci;
					InputChannel = min(InputChannel, Ci - 1);

					WORK_TYPE ValueI = BUFFER_TO_WORK_TYPE(Input[Hi * Wi * InputChannel + ReadOffsetI]);
					ValueI = IsValidInputChannel && IsValidInputPixel ? ValueI : 0.0f;
					WriteLDSInputs(InputChannelOffset + i, OutputPixelOffset, ValueI);

					WORK_TYPE ValueW = BUFFER_TO_WORK_TYPE(Weight[Scalar_InputChannelStride * InputChannel + ReadOffsetW]);
					ValueW = IsValidInputChannel && IsValidOutputChannel ? ValueW : 0.0f;
					WriteLDSWeights(InputChannelOffset + i, OutputChannelOffset, ValueW);
				}

				// Make the staged tiles visible to the whole group before anyone reads them.
				GroupMemoryBarrierWithGroupSync();

				/// LDS to register + inner loop
				// Loop on cached input channels
				#define INNERLOOPUNROLLCOUNT 4
				for (int cachedCiBase = 0; cachedCiBase < LDS_IC; cachedCiBase += INNERLOOPUNROLLCOUNT)
				{
					WORK_TYPE RegWeights[INNERLOOPUNROLLCOUNT][4];
					WORK_TYPE RegInputs[INNERLOOPUNROLLCOUNT][4];

					UNROLL
					for (int unrolledIdx = 0; unrolledIdx < INNERLOOPUNROLLCOUNT; ++unrolledIdx)
					{
						// 8 load from LDS
						UNROLL
						for (i = 0; i < 4; ++i)
						{
							RegWeights[unrolledIdx][i] = ReadLDSWeights(cachedCiBase + unrolledIdx, GroupThreadOutputChannelOffset + i);
							RegInputs[unrolledIdx][i] = ReadLDSInputs(cachedCiBase + unrolledIdx, GroupThreadOutputPixelOffset + i);
						}

						// Inner loop (16 mads)
						UNROLL
						for (i = 0; i < 4; ++i)
						{
							Values[i][0] += RegWeights[unrolledIdx][i] * RegInputs[unrolledIdx][0];
							Values[i][1] += RegWeights[unrolledIdx][i] * RegInputs[unrolledIdx][1];
							Values[i][2] += RegWeights[unrolledIdx][i] * RegInputs[unrolledIdx][2];
							Values[i][3] += RegWeights[unrolledIdx][i] * RegInputs[unrolledIdx][3];
						}
					}
				}

				// Nobody may overwrite the LDS tiles (next iteration, or the write-out
				// staging below) while other threads are still reading them.
				GroupMemoryBarrierWithGroupSync();
			}

			// Advance to the next kernel tap in the weight tensor.
#if WEIGHTS_TRANSPOSED
			Scalar_KernelIdx += Ci * Cw;
#else
			++Scalar_KernelIdx;
#endif
		}
	}

	/// Write results back to the Output buffer (via LDS)
	// Stage each thread's 4x4 tile in LDS: one padded row per output channel of the group.
	UNROLL
	for (i = 0; i < 4; ++i)
	{
		const int LDSWriteOffsetY = 4 * GroupThreadID.y + i;
		const int LDSWriteOffsetX = (LDS_PADX + 4) * GroupThreadID.x;
		LDSBuffer[LDSWriteOffsetY * LDS_X_PADDEDSIZE + LDSWriteOffsetX + 0] = Values[i][0];
		LDSBuffer[LDSWriteOffsetY * LDS_X_PADDEDSIZE + LDSWriteOffsetX + 1] = Values[i][1];
		LDSBuffer[LDSWriteOffsetY * LDS_X_PADDEDSIZE + LDSWriteOffsetX + 2] = Values[i][2];
		LDSBuffer[LDSWriteOffsetY * LDS_X_PADDEDSIZE + LDSWriteOffsetX + 3] = Values[i][3];
	}

	GroupMemoryBarrierWithGroupSync();

	// Write-out role of this thread: threads 0..31 emit the first half of the group's
	// TG_OC output channels, threads 32..63 the second half; the thread index within
	// each half selects the pixel, so consecutive threads store consecutive addresses.
	// The split is a property of the output-channel tile (TG_OC), not of the input-channel
	// cache depth, hence TG_OC / 2 rather than LDS_IC.
	const int WriteOutChannelOffset = GroupThreadIdx >= 32 ? (TG_OC / 2) : 0;
	const int WriteOutPixelOffset = GroupThreadIdx % 32;
	// Column of that pixel inside the padded LDS row (4 pixels + LDS_PADX per source thread).
	const int LDSPixelOffset = (4 + LDS_PADX) * (WriteOutPixelOffset / 4) + WriteOutPixelOffset % 4;

	UNROLL
	for (i = 0; i < (TG_OC / 2); ++i)
	{
		const int WriteOutChannel = Scalar_GroupOutputChannelOffset + i + WriteOutChannelOffset;
		// Skip channels of the tile that lie past the real channel count.
		if (WriteOutChannel < Cw)
		{
			const int WriteOffset = Scalar_BatchOutputOffset + WriteOutChannel * Ho * Wo + Scalar_GroupOutputPixelOffset + WriteOutPixelOffset;
			Output[WriteOffset] = WORK_TO_BUFFER_TYPE(LDSBuffer[(i + WriteOutChannelOffset) * LDS_X_PADDEDSIZE + LDSPixelOffset]);
		}
	}
}