// Copyright Epic Games, Inc. All Rights Reserved.

#include "/Engine/Public/Platform.ush"

#define WORK_TYPE float
#define BUFFER_TYPE float
#define READ(x) x
#define WRITE(x) x

WORK_TYPE Alpha;
WORK_TYPE Beta;
int TransA;
int TransB;
uint M;
uint N;
uint K;
uint MxK;
uint KxN;
uint MxN;
uint CWidth;
uint CHeight;

WORK_TYPE CScalar;

Buffer<BUFFER_TYPE> A;
Buffer<BUFFER_TYPE> B;
Buffer<BUFFER_TYPE> C;
RWBuffer<BUFFER_TYPE> Y;

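// This shader computes Y = Alpha * op(A) * op(B) + Beta * C, where op() transposes its
// argument when TransA respectively TransB is non-zero, op(A) is M x K, op(B) is K x N and
// Y is M x N. A and B may additionally be stacked (batched) along up to
// MAX_NUM_STACK_DIMENSIONS leading dimensions; GroupID.z selects which matrix of the stack
// a thread group works on.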
// x: StackShapeA[0], ..., StackShapeA[NUM_STACK_DIMENSIONS - 1]
// y: StackShapeB[0], ..., StackShapeB[NUM_STACK_DIMENSIONS - 1]
// z: StackStrideA[0], ..., StackStrideA[NUM_STACK_DIMENSIONS - 1]
// w: StackStrideB[0], ..., StackStrideB[NUM_STACK_DIMENSIONS - 1]
uint4 StackShapeA_StackShapeB_StackStrideA_StackStrideB[MAX_NUM_STACK_DIMENSIONS];

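// Computes the element offsets into A, B and Y of the matrix selected by GroupID.z.
// The output stack shape is the element-wise maximum of the two input stack shapes, so a
// stack dimension of size 1 is broadcast against the corresponding dimension of the other
// input. For illustration (not part of the shader): with NUM_STACK_DIMENSIONS == 2,
// StackShapeA = [2, 1] and StackShapeB = [1, 3], the output stack is [2, 3]; GroupID.z == 4
// decodes to output coordinate [1, 1], which maps to coordinate [1, 0] in A and [0, 1] in B.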
uint3 GetMatrixStackOffsets(const uint3 GroupID)
{
	// Must correspond to FGemmNumDimensions defined in NNEHlslShadersGemmCS.h
#if NUM_STACK_DIMENSIONS == 0
	return uint3(0, 0, 0);
#else
	uint stackCoordA[NUM_STACK_DIMENSIONS];
	uint stackCoordB[NUM_STACK_DIMENSIONS];

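	// Row-major strides of the broadcast output stack shape (element-wise max of both shapes)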
	uint stackStrideOutput[NUM_STACK_DIMENSIONS];
	stackStrideOutput[NUM_STACK_DIMENSIONS - 1] = 1;
	for (int i = NUM_STACK_DIMENSIONS - 2; i >= 0; i--)
	{
		uint maxDim = max(StackShapeA_StackShapeB_StackStrideA_StackStrideB[i + 1][0], StackShapeA_StackShapeB_StackStrideA_StackStrideB[i + 1][1]);
		stackStrideOutput[i] = maxDim * stackStrideOutput[i + 1];
	}

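	// Decode the linear stack index GroupID.z into per-dimension coordinates of the output
	// stack; the modulo by the input shapes implements broadcasting of size-1 dimensions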
	uint idx = GroupID.z;
	for (int i = 0; i < NUM_STACK_DIMENSIONS; i++)
	{
		uint coord = idx / stackStrideOutput[i];
		stackCoordA[i] = coord % StackShapeA_StackShapeB_StackStrideA_StackStrideB[i][0];
		stackCoordB[i] = coord % StackShapeA_StackShapeB_StackStrideA_StackStrideB[i][1];

		idx = idx % stackStrideOutput[i];
	}

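	// Accumulate the per-tensor coordinates into linear matrix indices using the provided strides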
	uint stackOffsetA = 0;
	uint stackOffsetB = 0;
	for (int i = 0; i < NUM_STACK_DIMENSIONS; i++)
	{
		stackOffsetA += stackCoordA[i] * StackShapeA_StackShapeB_StackStrideA_StackStrideB[i][2];
		stackOffsetB += stackCoordB[i] * StackShapeA_StackShapeB_StackStrideA_StackStrideB[i][3];
	}

	return uint3(
		stackOffsetA * MxK,
		stackOffsetB * KxN,
		GroupID.z * MxN);
#endif
}

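// Returns the Beta * C contribution of one output element: C_SCALAR == 0 reads a full C
// tensor (broadcast across rows and columns via the modulo with CHeight and CWidth),
// C_SCALAR == 1 uses the single scalar CScalar, and any other value means there is no C input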
WORK_TYPE GetBetaTimesC(const uint3 DispatchThreadID)
{
	// Must correspond to EGemmCScalar defined in NNEHlslShadersGemmCS.h
#if C_SCALAR == 0
	return Beta * READ(C[(DispatchThreadID.y % CHeight) * N + (DispatchThreadID.x % CWidth)]);
#elif C_SCALAR == 1
	return Beta * CScalar;
#else
	return 0;
#endif
}

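// The Gemm entry point below is compiled in one of three variants selected by ALGORITHM:
//  - ALGORITHM 0-3: naive, one thread per output element, reading A and B directly
//  - ALGORITHM 4-6: tiled, a thread group stages GROUP_SIZE x GROUP_SIZE blocks of op(A) and
//    op(B) in group shared memory
//  - ALGORITHM 7-14: tiled with register blocking, each thread accumulating LOAD_PER_THREAD
//    output elements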
// Must correspond to EGemmAlgorithm defined in NNEHlslShadersGemmCS.h
#if ALGORITHM < 4

#if ALGORITHM == 0
[numthreads(8, 8, 1)]
#elif ALGORITHM == 1
[numthreads(16, 16, 1)]
#elif ALGORITHM == 2
[numthreads(32, 32, 1)]
#elif ALGORITHM == 3
[numthreads(256, 1, 1)]
#endif
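// Naive variant: thread (x, y) computes Y[y][x] as a full dot product over K; the
// ALGORITHM value only changes the thread group shape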
void Gemm(in const uint3 DispatchThreadID : SV_DispatchThreadID, in const uint3 GroupID : SV_GroupID)
{
	if (DispatchThreadID.y >= M || DispatchThreadID.x >= N)
	{
		return;
	}

	uint AIdx;
	uint AIdxIncrement;
	if (TransA == 0)
	{
		AIdx = DispatchThreadID.y * K; // Index of the first element in the row
		AIdxIncrement = 1; // The loop goes over the whole row
	}
	else
	{
		AIdx = DispatchThreadID.y; // Index of the first element in the column
		AIdxIncrement = M; // The loop goes over the column, thus skip a full row
	}

	uint BIdx;
	uint BIdxIncrement;
	if (TransB == 0)
	{
		BIdx = DispatchThreadID.x; // Index of the first element in the column
		BIdxIncrement = N; // The loop goes over the column, thus skip a full row
	}
	else
	{
		BIdx = DispatchThreadID.x * K; // Index of the first element in the row
		BIdxIncrement = 1; // The loop goes over the whole row
	}

	const uint3 StackOffsets = GetMatrixStackOffsets(GroupID);
	AIdx += StackOffsets.x;
	BIdx += StackOffsets.y;

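	// Dot product over the K dimension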
	WORK_TYPE Result = 0.0;
	for (int k = 0; k < K; k++)
	{
		Result += READ(A[AIdx]) * READ(B[BIdx]);
		AIdx += AIdxIncrement;
		BIdx += BIdxIncrement;
	}

	Y[StackOffsets.z + DispatchThreadID.y * N + DispatchThreadID.x] = WRITE(Alpha * Result + GetBetaTimesC(DispatchThreadID));
}

#elif ALGORITHM < 7

#if ALGORITHM == 4
#define GROUP_SIZE 8
#elif ALGORITHM == 5
#define GROUP_SIZE 16
#elif ALGORITHM == 6
#define GROUP_SIZE 32
#endif

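// Tiled variant: each thread group computes one GROUP_SIZE x GROUP_SIZE tile of Y. The K
// dimension is processed in GROUP_SIZE wide steps; in every step the group cooperatively
// loads one tile of op(A) and one tile of op(B) into shared memory (zero-padding out of
// range elements), synchronizes, and accumulates the partial dot products from shared memory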
groupshared WORK_TYPE SharedMemoryA[GROUP_SIZE * GROUP_SIZE];
groupshared WORK_TYPE SharedMemoryB[GROUP_SIZE * GROUP_SIZE];

[numthreads(GROUP_SIZE, GROUP_SIZE, 1)]
void Gemm(in const uint3 DispatchThreadID : SV_DispatchThreadID, in const uint3 GroupThreadID : SV_GroupThreadID, in const uint3 GroupID : SV_GroupID)
{
	// Calculate the number of steps each group needs to make to cover the full length of the matrix
	uint NumGroupSteps = (uint)ceil((float)K / (float)GROUP_SIZE);

	// Compute some variables required at multiple places
	uint SharedMemoryLineStartIdx = GroupThreadID.y * GROUP_SIZE;
	uint SharedMemoryIdx = SharedMemoryLineStartIdx + GroupThreadID.x;

	// Compute the lookup start indices and the increment which needs to be applied for each sub block
	uint AIdx;
	uint AIdxIncrement;
	if (TransA == 0)
	{
		AIdx = DispatchThreadID.y * K + GroupThreadID.x;
		AIdxIncrement = GROUP_SIZE;
	}
	else
	{
		AIdx = GroupThreadID.x * M + DispatchThreadID.y;
		AIdxIncrement = GROUP_SIZE * M;
	}

	uint BIdx;
	uint BIdxIncrement;
	if (TransB == 0)
	{
		BIdx = GroupThreadID.y * N + DispatchThreadID.x;
		BIdxIncrement = GROUP_SIZE * N;
	}
	else
	{
		BIdx = DispatchThreadID.x * K + GroupThreadID.y;
		BIdxIncrement = GROUP_SIZE;
	}

	const uint3 StackOffsets = GetMatrixStackOffsets(GroupID);
	AIdx += StackOffsets.x;
	BIdx += StackOffsets.y;

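	// Each iteration handles one GROUP_SIZE wide slice of the K dimension and adds its
	// contribution to this thread's single output element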
	// Multiply and accumulate the blocks along one dimension of the matrices
	WORK_TYPE Result = 0.0;
	uint GroupIdx = 0;
	for (int GroupStep = 0; GroupStep < NumGroupSteps; GroupStep++)
	{
		// Only sync and update indices starting with the second iteration
		if (GroupStep > 0)
		{
			// Increment the indices
			GroupIdx += GROUP_SIZE;
			AIdx += AIdxIncrement;
			BIdx += BIdxIncrement;

			// Make sure computation is done in the whole group before overwriting shared memory
			GroupMemoryBarrierWithGroupSync();
		}

		// Load the shared memory after applying the transpose and set overflow elements to zero
		WORK_TYPE Temp = 0;
		if (DispatchThreadID.y < M && GroupIdx + GroupThreadID.x < K)
		{
			Temp = READ(A[AIdx]);
		}
		SharedMemoryA[SharedMemoryIdx] = Temp;

		Temp = 0;
		if (DispatchThreadID.x < N && GroupIdx + GroupThreadID.y < K)
		{
			Temp = READ(B[BIdx]);
		}
		SharedMemoryB[SharedMemoryIdx] = Temp;

		// Make sure all memory has been loaded before running the multiplication
		GroupMemoryBarrierWithGroupSync();

		// Multiply the groups and accumulate the result
		// This works also at matrix borders, as overflow elements are set to zero
		uint SharedMemoryBIdx = GroupThreadID.x;
		for (int kGroup = 0; kGroup < GROUP_SIZE; kGroup++)
		{
			Result += SharedMemoryA[SharedMemoryLineStartIdx + kGroup] * SharedMemoryB[SharedMemoryBIdx];
			SharedMemoryBIdx += GROUP_SIZE;
		}
	}

	// Only now return out-of-bounds threads as they were needed to load shared memory above
	if (DispatchThreadID.y >= M || DispatchThreadID.x >= N)
	{
		return;
	}

	Y[StackOffsets.z + DispatchThreadID.y * N + DispatchThreadID.x] = WRITE(Alpha * Result + GetBetaTimesC(DispatchThreadID));
}

#else

#if ALGORITHM == 7
#define GROUP_SIZE_X 16
#define GROUP_SIZE_Y 1
#define LOAD_PER_THREAD 16
#elif ALGORITHM == 8
#define GROUP_SIZE_X 16
#define GROUP_SIZE_Y 2
#define LOAD_PER_THREAD 8
#elif ALGORITHM == 9
#define GROUP_SIZE_X 32
#define GROUP_SIZE_Y 1
#define LOAD_PER_THREAD 32
#elif ALGORITHM == 10
#define GROUP_SIZE_X 32
#define GROUP_SIZE_Y 2
#define LOAD_PER_THREAD 16
#elif ALGORITHM == 11
#define GROUP_SIZE_X 32
#define GROUP_SIZE_Y 4
#define LOAD_PER_THREAD 8
#elif ALGORITHM == 12
#define GROUP_SIZE_X 64
#define GROUP_SIZE_Y 2
#define LOAD_PER_THREAD 32
#elif ALGORITHM == 13
#define GROUP_SIZE_X 64
#define GROUP_SIZE_Y 4
#define LOAD_PER_THREAD 16
#elif ALGORITHM == 14
#define GROUP_SIZE_X 64
#define GROUP_SIZE_Y 8
#define LOAD_PER_THREAD 8
#endif

// Use the larger of the two group sizes to define the shared memory size
// Threads will do multiple loads over the smaller dimension to fill the memory completely
groupshared WORK_TYPE SharedMemoryA[GROUP_SIZE_X * GROUP_SIZE_X];
groupshared WORK_TYPE SharedMemoryB[GROUP_SIZE_X * GROUP_SIZE_X];

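// Register blocked variant: each thread group computes a GROUP_SIZE_X x GROUP_SIZE_X tile of
// Y with only GROUP_SIZE_X * GROUP_SIZE_Y threads, so every thread accumulates
// LOAD_PER_THREAD = GROUP_SIZE_X / GROUP_SIZE_Y output elements in registers, amortizing
// shared memory traffic across several results per thread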
[numthreads(GROUP_SIZE_X, GROUP_SIZE_Y, 1)]
void Gemm(in const uint3 DispatchThreadID : SV_DispatchThreadID, in const uint3 GroupThreadID : SV_GroupThreadID, in const uint3 GroupID : SV_GroupID)
{
	// Initialize the result
	WORK_TYPE Results[LOAD_PER_THREAD];
	for (int LoadIdx = 0; LoadIdx < LOAD_PER_THREAD; LoadIdx++)
	{
		Results[LoadIdx] = 0.0f;
	}

	// Calculate the number of steps each group needs to make to cover the full length of the matrix
	uint NumGroupSteps = (uint)ceil((float)K / (float)GROUP_SIZE_X);

	// Compute some variables required at multiple places
	uint SharedMemoryTransposedLineIdx = GroupThreadID.x * GROUP_SIZE_X;
	uint SharedMemoryStartIdx = GroupThreadID.y * GROUP_SIZE_X + GroupThreadID.x;
	uint SharedIdxIncrement = GROUP_SIZE_Y * GROUP_SIZE_X;
	uint GroupStartRow = GroupID.y * GROUP_SIZE_X;
	uint GroupLoadOffset = GroupThreadID.y * LOAD_PER_THREAD;

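	// Each thread accumulates the LOAD_PER_THREAD consecutive elements of output row
	// GroupStartRow + GroupThreadID.x starting at column GroupID.x * GROUP_SIZE_X + GroupLoadOffset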
	// Compute the lookup start indices and the increment which needs to be applied for each sub block
	uint AIdx;
	uint AIdxOuterIncrement; // Increment added in the group loop
	uint AIdxInnerIncrement; // Increment added in the load loop
	if (TransA == 0)
	{
		AIdx = (GroupStartRow + GroupThreadID.y) * K + GroupThreadID.x;
		AIdxOuterIncrement = GROUP_SIZE_X;
		AIdxInnerIncrement = GROUP_SIZE_Y * K;
	}
	else
	{
		AIdx = GroupThreadID.x * M + (GroupStartRow + GroupThreadID.y);
		AIdxOuterIncrement = GROUP_SIZE_X * M;
		AIdxInnerIncrement = GROUP_SIZE_Y;
	}

	uint BIdx;
	uint BIdxOuterIncrement; // Increment added in the group loop
	uint BIdxInnerIncrement; // Increment added in the load loop
	if (TransB == 0)
	{
		BIdx = GroupThreadID.y * N + DispatchThreadID.x;
		BIdxOuterIncrement = GROUP_SIZE_X * N;
		BIdxInnerIncrement = GROUP_SIZE_Y * N;
	}
	else
	{
		BIdx = DispatchThreadID.x * K + GroupThreadID.y;
		BIdxOuterIncrement = GROUP_SIZE_X;
		BIdxInnerIncrement = GROUP_SIZE_Y;
	}

	const uint3 StackOffsets = GetMatrixStackOffsets(GroupID);
	AIdx += StackOffsets.x;
	BIdx += StackOffsets.y;

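	// Each outer step covers one GROUP_SIZE_X wide slice of K: the group first fills both
	// shared memory tiles (each thread performing LOAD_PER_THREAD loads per tile), then every
	// thread updates its LOAD_PER_THREAD partial sums from shared memory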
	// Multiply and accumulate the blocks
	uint GroupIdx = 0;
	for (uint GroupStep = 0; GroupStep < NumGroupSteps; GroupStep++)
	{
		// Prepare some shared variables needed inside this loop
		uint SharedIdx = SharedMemoryStartIdx;
		uint AIdxOffset = 0;
		uint BIdxOffset = 0;

		uint GroupRow = GroupThreadID.y;
		uint ThreadIDY = GroupStartRow + GroupRow;

		// Only sync and update indices starting with the second iteration
		if (GroupStep > 0)
		{
			// Increment the indices
			GroupIdx += GROUP_SIZE_X;
			AIdx += AIdxOuterIncrement;
			BIdx += BIdxOuterIncrement;

			// Make sure computation is done in the whole group before overwriting shared memory
			GroupMemoryBarrierWithGroupSync();
		}

		for (int LoadIdx = 0; LoadIdx < LOAD_PER_THREAD; LoadIdx++)
		{
			// Load the shared memory after applying the transpose and set overflow elements to zero
			WORK_TYPE Temp = 0;
			if (ThreadIDY < M && GroupIdx + GroupThreadID.x < K)
			{
				Temp = READ(A[AIdx + AIdxOffset]);
			}
			SharedMemoryA[SharedIdx] = Temp;

			Temp = 0;
			if (GroupIdx + GroupRow < K && DispatchThreadID.x < N)
			{
				Temp = READ(B[BIdx + BIdxOffset]);
			}
			SharedMemoryB[SharedIdx] = Temp;

			// Increment the thread y coordinate virtually by one group size
			GroupRow += GROUP_SIZE_Y;
			ThreadIDY += GROUP_SIZE_Y;

			// Increment the shared index pointer to advance to the next subblock
			SharedIdx += SharedIdxIncrement;
			AIdxOffset += AIdxInnerIncrement;
			BIdxOffset += BIdxInnerIncrement;
		}

		// Make sure all memory has been loaded before running the multiplication
		GroupMemoryBarrierWithGroupSync();

		// Multiply and accumulate
		uint StartB = GroupLoadOffset;
		for (int kGroup = 0; kGroup < GROUP_SIZE_X; kGroup++)
		{
			WORK_TYPE AVal = SharedMemoryA[SharedMemoryTransposedLineIdx + kGroup];
			for (int LoadIdx = 0; LoadIdx < LOAD_PER_THREAD; LoadIdx++)
			{
				Results[LoadIdx] += AVal * SharedMemoryB[StartB + LoadIdx];
			}
			StartB += GROUP_SIZE_X;
		}
	}

	// Only now return out-of-bounds threads as they were needed to load shared memory above
	uint Row = GroupStartRow + GroupThreadID.x;
	if (Row >= M)
	{
		return;
	}

	// Save back the results
	uint RowIdx = Row * N;
	uint ColumnStart = GroupID.x * GROUP_SIZE_X + GroupLoadOffset;
	for (int LoadIdx = 0; LoadIdx < LOAD_PER_THREAD; LoadIdx++)
	{
		uint Column = ColumnStart + LoadIdx;

		// Return as soon as the column is out of bounds as further iterations will be out of bounds too
		if (Column >= N)
		{
			return;
		}

		Y[StackOffsets.z + RowIdx + Column] = WRITE(Alpha * Results[LoadIdx] + GetBetaTimesC(uint3(Column, Row, 0)));
	}
}

#endif