Files
UnrealEngine/Engine/Plugins/Experimental/NNERuntimeRDG/Shaders/Private/NNEHlslShaders/NNEHlslShadersCumSum.usf
2025-05-18 13:04:45 +08:00

322 lines
11 KiB
HLSL

// Copyright Epic Games, Inc. All Rights Reserved.
//
// Implementation based on Single-pass Parallel Prefix Scan with Decoupled Look-back
// Original research paper: https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back
#include "/Engine/Public/Platform.ush"
// Scalar type used for the intermediate sums (wave/group aggregates and prefixes).
#define WORK_TYPE float
// Element type of the Input/Output typed buffers.
#define BUFFER_TYPE float
// Identity conversion hooks; they are not referenced anywhere in the visible part of this file.
#define READ(x) x
#define WRITE(x) x
uint NumThreadGroupsPerScan; // Each group's x-axis refers to a single prefix scan (cumulative sum)
uint NumThreadGroupsY; // Each group's y-index refers to pre-axis dimensions
uint NumThreadGroupsZ; // Each group's z-index refers to post-axis dimensions
// Values taken by FPartitionDescriptor::StatusFlag:
// nothing has been published for the partition yet (value written by the init shader)
#define STATUS_INVALID 0
// FPartitionDescriptor::Aggregate has been published (sum of the partition's own elements only)
#define STATUS_AGGREGATE_AVAILABLE 1
// FPartitionDescriptor::InclusivePrefix has been published (sum of all elements up to and including the partition)
#define STATUS_PREFIX_AVAILABLE 2
/*
There is one global partition counter per scan, indexed by the scan ID.
Each thread group taking part in a scan atomically increments that scan's counter (see ReservePartitionIndex).
The value before the increment is the partition index assigned to the thread group that did the increment.
The init shader resets every counter to zero.
*/
globallycoherent RWStructuredBuffer<uint> GlobalPartitionIndex;
/*
Elements to scan are partitioned. Each thread group computes the prefix scan of a partition.
Results to be propagated to the next groups are written in the partition descriptor in device memory.
It needs to be `globallycoherent` so that memory fences flush updates across groups.
There are as many descriptors as thread groups; the descriptor of a partition lives at
ScanID * NumThreadGroupsPerScan + PartitionIndex (see GetPDIdx).
*/
struct FPartitionDescriptor
{
// One of STATUS_INVALID / STATUS_AGGREGATE_AVAILABLE / STATUS_PREFIX_AVAILABLE
int StatusFlag;
// Sum of the elements of this partition only; valid once StatusFlag is STATUS_AGGREGATE_AVAILABLE
WORK_TYPE Aggregate;
// Sum of all the elements of the scan up to and including this partition; valid once StatusFlag is STATUS_PREFIX_AVAILABLE
WORK_TYPE InclusivePrefix;
// Padding only: brings the struct to four 32-bit fields (with WORK_TYPE float). Always written as 0.
int PadToQWord;
};
globallycoherent RWStructuredBuffer<FPartitionDescriptor> PartitionDescriptor;
#ifdef INIT_SHADER
uint NumInitThreadGroups; // Number of thread groups dispatched for InitCumSum; used to derive the loop stride

/*
Resets the look-back state before a CumSum dispatch:
every partition descriptor is cleared (status STATUS_INVALID, zero aggregate/prefix) and
every per-scan partition counter in GlobalPartitionIndex is set back to zero.
Each thread strides over the descriptors, so any number of init groups covers all of them.
*/
[numthreads(INIT_THREADGROUP_SIZE, 1, 1)]
void InitCumSum(uint3 DispatchThreadID : SV_DispatchThreadID)
{
	// One descriptor per CumSum thread group
	const uint NumDescriptors = NumThreadGroupsPerScan * NumThreadGroupsY * NumThreadGroupsZ;
	// Total number of init threads in flight
	const uint ThreadStride = INIT_THREADGROUP_SIZE * NumInitThreadGroups;

	for (uint DescriptorIdx = DispatchThreadID.x; DescriptorIdx < NumDescriptors; DescriptorIdx += ThreadStride)
	{
#if METAL_SM6_PROFILE
		// This profile writes the descriptor field by field instead of assigning a whole struct
		PartitionDescriptor[DescriptorIdx].StatusFlag = STATUS_INVALID;
		PartitionDescriptor[DescriptorIdx].Aggregate = (WORK_TYPE) 0;
		PartitionDescriptor[DescriptorIdx].InclusivePrefix = (WORK_TYPE) 0;
		PartitionDescriptor[DescriptorIdx].PadToQWord = 0;
#else
		FPartitionDescriptor ClearedDescriptor;
		ClearedDescriptor.StatusFlag = STATUS_INVALID;
		ClearedDescriptor.Aggregate = (WORK_TYPE) 0;
		ClearedDescriptor.InclusivePrefix = (WORK_TYPE) 0;
		ClearedDescriptor.PadToQWord = 0;
		PartitionDescriptor[DescriptorIdx] = ClearedDescriptor;
#endif

		// The thread that clears the first descriptor of a scan also resets that scan's partition counter
		const bool bFirstDescriptorOfScan = (DescriptorIdx % NumThreadGroupsPerScan) == 0;
		if (bFirstDescriptorOfScan)
		{
			const uint ScanIdx = DescriptorIdx / NumThreadGroupsPerScan;
			GlobalPartitionIndex[ScanIdx] = 0u;
		}
	}
}
#else // !INIT_SHADER
// Tensor to scan (read through GetScanInput) and tensor receiving the cumulative sums (written through SetScanOutput)
Buffer<BUFFER_TYPE> Input;
RWBuffer<BUFFER_TYPE> Output;
// Not referenced in the visible part of this file
#define STRIDE_IDX 0
// Smallest wave size assumed when sizing WaveAggregate below
#define MIN_WAVE_LANES 4U
// Number of elements scanned by one thread group
#define PARTITION_SIZE (THREADGROUP_SIZE * VALUES_PER_THREAD)
// Number of waves per thread group (integer division; evaluated with the runtime wave size)
#define NUM_WAVEAGGREGATES (THREADGROUP_SIZE / WaveGetLaneCount())
// Index of the last wave's entry in WaveAggregate
#define LAST_WAVEAGGREGATE_ID (NUM_WAVEAGGREGATES - 1)
// Number of elements of one scan; the last partition of a scan is truncated against it
uint NumScanValues;
// Not referenced in the visible part of this file
uint Axis;
// Distance, in elements, between two consecutive values of the same scan in Input/Output
// NOTE(review): CumSum derives the scan start from NumThreadGroupsZ, so this is presumably equal to NumThreadGroupsZ - confirm against the dispatching code
uint AxisStride;
/* Group shared memory variables */
// Partition index reserved by thread 0 of the group (see ReservePartitionIndex)
groupshared uint GroupPartitionIndex;
// Sum of all the partitions preceding this group's partition, written by DecoupledLookback
groupshared WORK_TYPE GroupExclusivePrefix;
groupshared WORK_TYPE LocalPrefix[PARTITION_SIZE]; // Wave-local prefixes for each element of the partition (before adding the group's exclusive prefix)
groupshared WORK_TYPE WaveAggregate[THREADGROUP_SIZE / MIN_WAVE_LANES]; // Wave-wide aggregates (also group-local) to speed up reduction phase
/*
Reads the element at position IdxWithinAxis of the scan that starts at ScanStartGlobalIdx.
Consecutive elements of a scan are AxisStride apart in Input.
*/
inline WORK_TYPE GetScanInput(uint IdxWithinAxis, uint ScanStartGlobalIdx)
{
	const uint GlobalIdx = ScanStartGlobalIdx + IdxWithinAxis * AxisStride;
	return Input[GlobalIdx];
}
/*
Writes Value at position IdxWithinAxis of the scan that starts at ScanStartGlobalIdx.
Uses the same addressing as GetScanInput, but targets Output.
*/
inline void SetScanOutput(uint IdxWithinAxis, WORK_TYPE Value, uint ScanStartGlobalIdx)
{
	const uint GlobalIdx = ScanStartGlobalIdx + IdxWithinAxis * AxisStride;
	Output[GlobalIdx] = Value;
}
/*
Returns the index, within the thread group, of the wave that owns thread GroupThreadID
(threads are taken to be assigned to waves in consecutive runs of WaveGetLaneCount()).
*/
inline uint GetWaveIndex(uint GroupThreadID)
{
	const uint LanesPerWave = WaveGetLaneCount();
	return GroupThreadID / LanesPerWave;
}
/*
Returns the position, within the scan, of the first element of partition PartitionIndex.
*/
inline uint PartitionOffset(uint PartitionIndex)
{
	const uint ElementsPerPartition = PARTITION_SIZE;
	return PartitionIndex * ElementsPerPartition;
}
/*
Returns the position, within the partition, of the first element handled by the wave that owns GroupThreadID.
Every wave handles a contiguous run of VALUES_PER_THREAD * WaveGetLaneCount() elements.
*/
inline uint WavePartitionOffset(uint GroupThreadID)
{
	const uint WaveIdx = GetWaveIndex(GroupThreadID);
	return WaveIdx * VALUES_PER_THREAD * WaveGetLaneCount();
}
/*
Performs a wave-local prefix scan (and reduction) of VALUES_PER_THREAD elements per thread, for each wave.
Scan results are stored in groupshared memory array LocalPrefix.
The aggregated value for each wave is stored in groupshared memory array WaveAggregate.

GroupThreadID       Index of the calling thread within its group.
PartitionIndex      Partition handled by the group; selects which slice of the scan is read.
ScanStartGlobalIdx  Index in Input of the first element of the scan.
bPartial            When true, elements at or beyond NumScanValues are not read and count as 0.

Each wave walks its contiguous slice of the partition in VALUES_PER_THREAD steps of WaveGetLaneCount() elements,
so at every step the lanes of a wave cover consecutive elements.
*/
inline void WaveScanReduce(uint GroupThreadID, uint PartitionIndex, uint ScanStartGlobalIdx, bool bPartial)
{
// Lane to read from: the previous lane, except lane 0 which wraps around to the last lane of the wave
const uint ShiftLaneIdx = WaveGetLaneIndex() != 0u ? (WaveGetLaneIndex() - 1u) : (WaveGetLaneCount() - 1u);
// Number of elements remaining in the scan from the start of this partition; only consulted when bPartial is true
const uint PartialSize = NumScanValues - PartitionOffset(PartitionIndex);
// Running sum of the steps already processed by this wave
WORK_TYPE Aggregate = 0;
[unroll]
for (uint Idx = WaveGetLaneIndex() + WavePartitionOffset(GroupThreadID), WordIdx = 0; WordIdx < VALUES_PER_THREAD; Idx += WaveGetLaneCount(), ++WordIdx)
{
// Out-of-range elements of a partial partition contribute 0 to the sums
WORK_TYPE Value = !bPartial || Idx < PartialSize ?
GetScanInput(Idx + PartitionOffset(PartitionIndex), ScanStartGlobalIdx)
: 0;
// Value + WavePrefixSum(Value) is the inclusive sum of this step up to the current lane.
// After the shuffle, lanes other than 0 hold the inclusive sum of their previous lane,
// while lane 0 holds the inclusive sum of the last lane, i.e. the sum of the whole step.
const WORK_TYPE ShiftLaneValue = WaveReadLaneAt(Value + WavePrefixSum(Value), ShiftLaneIdx); // Wave prefix sum and circular shuffle
// Inclusive prefix within the wave's slice: own value + previous lanes of this step + all previous steps
LocalPrefix[Idx] = Value + (WaveGetLaneIndex() != 0u ? ShiftLaneValue : 0) + Aggregate;
// Lane 0 holds the sum of the step (see above); accumulate it for the following steps
Aggregate += WaveReadLaneAt(ShiftLaneValue, 0);
}
// Aggregate is the same on every lane at this point; one lane publishes the sum of the wave's slice
if (WaveGetLaneIndex() == 0u)
{
WaveAggregate[GetWaveIndex(GroupThreadID)] = Aggregate;
}
}
/*
Performs a group-local inclusive prefix scan of the wave aggregates in groupshared memory:
on return WaveAggregate[i] holds the sum of the aggregates of waves 0..i, so the last entry
is the aggregate of the whole partition.

Only the first WaveGetLaneCount() threads of the group do the work. The entries are processed in
chunks of one wave; the sum of the chunks already processed is carried in a register and broadcast
with a wave intrinsic. The function therefore needs no group barrier, which matters because barriers
with group sync must not be called from non-uniform control flow such as the branch below.
Lanes beyond NUM_WAVEAGGREGATES neither read nor write WaveAggregate: they only feed zeros to the wave scan.

The caller must make the wave aggregates visible (group barrier) before calling, and must not rely on
the results being visible to other waves before the next group barrier.
*/
inline void ScanWaveAggregate(uint GroupThreadID)
{
	const uint ScanSize = NUM_WAVEAGGREGATES;
	const uint LaneCount = WaveGetLaneCount();

	if (GroupThreadID < LaneCount)
	{
		// Sum of all the entries of the previous chunks (identical on every lane)
		WORK_TYPE Carry = (WORK_TYPE) 0;

		// The loop bounds are the same for every lane, so the wave intrinsics below always run with the full wave.
		// A single iteration is enough when WaveGetLaneCount() * WaveGetLaneCount() >= THREADGROUP_SIZE (most common case).
		for (uint ChunkOffset = 0u; ChunkOffset < ScanSize; ChunkOffset += LaneCount)
		{
			const uint AggregateIdx = ChunkOffset + GroupThreadID;
			const bool bInRange = AggregateIdx < ScanSize;

			WORK_TYPE Value = (WORK_TYPE) 0;
			if (bInRange)
			{
				Value = WaveAggregate[AggregateIdx];
			}

			// Inclusive scan within the chunk, seeded with the sum of the previous chunks
			WORK_TYPE Inclusive = Value + WavePrefixSum(Value);
			if (ChunkOffset != 0u)
			{
				Inclusive += Carry;
			}

			if (bInRange)
			{
				WaveAggregate[AggregateIdx] = Inclusive;
			}

			// The last lane holds the sum of everything scanned so far (out-of-range lanes contributed 0)
			Carry = WaveReadLaneAt(Inclusive, LaneCount - 1u);
		}
	}
}
/*
Reserves the next free partition of scan ScanID for the calling thread group.
Thread 0 atomically increments the scan's counter; the value before the increment is stored in
groupshared GroupPartitionIndex. The caller must issue a group barrier before other threads read it.
*/
inline void ReservePartitionIndex(uint GroupThreadID, uint ScanID)
{
	// A single atomic per group: every other thread leaves immediately
	if (GroupThreadID != 0)
	{
		return;
	}

	InterlockedAdd(GlobalPartitionIndex[ScanID], 1u, GroupPartitionIndex);
}
/*
Publishes the inclusive prefix of the partition whose descriptor is at index PDIdx.
The value is written first and the status flag last, with a device memory barrier in between,
so that the flag does not become visible before the value it announces.
*/
inline void SetPartitionInclusivePrefix(uint PDIdx, WORK_TYPE Value)
{
PartitionDescriptor[PDIdx].InclusivePrefix = Value;
DeviceMemoryBarrier(); // To guarantee coherency avoid that status gets written before inclusive prefix does
PartitionDescriptor[PDIdx].StatusFlag = STATUS_PREFIX_AVAILABLE;
}
/*
Publishes the aggregate (sum of the partition's own elements) of the partition whose descriptor is at index PDIdx.
The value is written first and the status flag last, with a device memory barrier in between,
so that the flag does not become visible before the value it announces.
*/
inline void SetPartitionAggregate(uint PDIdx, WORK_TYPE Value)
{
PartitionDescriptor[PDIdx].Aggregate = Value;
DeviceMemoryBarrier(); // To guarantee coherency avoid that status gets written before aggregate does
PartitionDescriptor[PDIdx].StatusFlag = STATUS_AGGREGATE_AVAILABLE;
}
/*
Returns the index in PartitionDescriptor of partition PartitionIndex of scan ScanID.
The descriptors of a scan are contiguous, NumThreadGroupsPerScan per scan.
*/
inline uint GetPDIdx(uint PartitionIndex, uint ScanID)
{
	const uint FirstDescriptorOfScan = ScanID * NumThreadGroupsPerScan;
	return FirstDescriptorOfScan + PartitionIndex;
}
/*
Updates the current group's partition descriptor in device memory once the partition aggregate has been computed.
Must be called after ScanWaveAggregate, when the last entry of WaveAggregate holds the sum of the whole partition.
Partition 0 has no predecessor, so its aggregate is directly published as its inclusive prefix;
every other partition publishes it as an aggregate.
*/
inline void BroadcastGroupAggregate(uint GroupThreadID, uint PartitionIndex, uint ScanID)
{
	// The publishing thread is the one whose ID equals the index of the last wave aggregate
	if (GroupThreadID != LAST_WAVEAGGREGATE_ID)
	{
		return;
	}

	// Last wave aggregate corresponds to group (partition) aggregate
	const WORK_TYPE GroupAggregate = WaveAggregate[GroupThreadID];
	const uint DescriptorIdx = GetPDIdx(PartitionIndex, ScanID);

	if (PartitionIndex != 0)
	{
		SetPartitionAggregate(DescriptorIdx, GroupAggregate);
	}
	else
	{
		SetPartitionInclusivePrefix(DescriptorIdx, GroupAggregate);
	}
}
/*
Computes the partition exclusive prefix by iteratively summing up prefixes/aggregates of previous partitions in backward order.
Also computes and sets (broadcasts) the inclusive prefix of the partition once the exclusive prefix is available.

The walk starts at the partition right before PartitionIndex and, for each predecessor:
 - if its inclusive prefix is available, adds it and stops: the exclusive prefix is complete;
 - if only its aggregate is available, adds it and moves one partition further back;
 - otherwise (nothing published yet) re-reads the same descriptor until something shows up.
The result is stored in groupshared GroupExclusivePrefix on every exit path; the caller must issue a
group barrier before other threads read it. CumSum calls this from a single thread per group.

NOTE(review): the whole descriptor is loaded into LookBackPD before the DeviceMemoryBarrier() calls below,
so the payload fields are not re-read after the barrier; the code appears to rely on the descriptor load
observing flag and payload consistently - confirm on the targeted platforms.
*/
inline void DecoupledLookback(uint PartitionIndex, uint ScanID)
{
// Sum of the predecessors visited so far
WORK_TYPE PreviousReduction = (WORK_TYPE) 0;
// Signed so that the walk can step below partition 0 and terminate
int LookBackIndex = (int) PartitionIndex - 1;
while (LookBackIndex >= 0)
{
const FPartitionDescriptor LookBackPD = PartitionDescriptor[GetPDIdx(LookBackIndex, ScanID)];
if (LookBackPD.StatusFlag == STATUS_PREFIX_AVAILABLE)
{
DeviceMemoryBarrier(); // Required between reading lookback's StatusFlag and InclusivePrefix for coherency
// The predecessor's inclusive prefix already covers everything before it: the walk is over
PreviousReduction += LookBackPD.InclusivePrefix;
GroupExclusivePrefix = PreviousReduction;
// After ScanWaveAggregate the last wave aggregate is the aggregate of this partition
SetPartitionInclusivePrefix(
GetPDIdx(PartitionIndex, ScanID),
PreviousReduction + WaveAggregate[LAST_WAVEAGGREGATE_ID]
);
return;
}
else if(LookBackPD.StatusFlag == STATUS_AGGREGATE_AVAILABLE)
{
DeviceMemoryBarrier(); // Required between reading lookback's StatusFlag and Aggregate for coherency
PreviousReduction += LookBackPD.Aggregate;
LookBackIndex--;
}
// STATUS_INVALID: LookBackIndex is left unchanged, so the same descriptor is polled again
}
// Reached when every predecessor down to partition 0 only exposed an aggregate
GroupExclusivePrefix = PreviousReduction;
SetPartitionInclusivePrefix(
GetPDIdx(PartitionIndex, ScanID),
PreviousReduction + WaveAggregate[LAST_WAVEAGGREGATE_ID]
);
}
/*
Computes the final prefix scan by adding wave-local prefixes to wave-wise exclusive prefixes.
Each thread revisits the LocalPrefix entries it produced in WaveScanReduce, adds ExclusivePrefix
(sum of everything preceding the thread's wave) and writes the result to Output.
When bPartial is true, a thread stops at its first element that lies beyond the end of the scan.
*/
void SeededWaveScan(uint GroupThreadID, uint PartitionIndex, WORK_TYPE ExclusivePrefix, uint ScanStartGlobalIdx, bool bPartial)
{
	// Position of the partition's first element within the scan, and number of elements left from there
	const uint PartitionStart = PartitionOffset(PartitionIndex);
	const uint NumRemaining = NumScanValues - PartitionStart;

	[unroll]
	for (uint LocalIdx = WaveGetLaneIndex() + WavePartitionOffset(GroupThreadID), Step = 0;
		Step < VALUES_PER_THREAD && !(bPartial && LocalIdx >= NumRemaining);
		LocalIdx += WaveGetLaneCount(), ++Step)
	{
		const WORK_TYPE FinalPrefix = LocalPrefix[LocalIdx] + ExclusivePrefix;
		SetScanOutput(LocalIdx + PartitionStart, FinalPrefix, ScanStartGlobalIdx);
	}
}
/*
Entry point: each thread group computes the cumulative sum of one partition of one scan.
GroupID.y / GroupID.z select the scan (pre-axis / post-axis position); the partition is not derived
from GroupID.x but reserved at run time through an atomic counter (see ReservePartitionIndex).

Steps:
 1. reserve a partition index for the group;
 2. wave-local scan of the partition into LocalPrefix, one aggregate per wave into WaveAggregate;
 3. scan of the wave aggregates, then publication of the partition aggregate for the following groups;
 4. thread 0 looks back at the previous partitions to get the group's exclusive prefix;
 5. every thread adds its wave's exclusive prefix to its local prefixes and writes the output.
*/
[numthreads(THREADGROUP_SIZE, 1, 1)]
void CumSum(in const uint3 GroupID : SV_GroupID, in const uint3 GroupThreadIDVec : SV_GroupThreadID)
{
	const uint PreAxisGroupID = GroupID.y;
	const uint PostAxisGroupID = GroupID.z;
	const uint GroupThreadID = GroupThreadIDVec.x;
	const uint ScanID = PreAxisGroupID * NumThreadGroupsZ + PostAxisGroupID;

	// Global index to first element of the scan
	uint ScanStartGlobalIdx = PreAxisGroupID * (NumScanValues * NumThreadGroupsZ) + PostAxisGroupID;

	ReservePartitionIndex(GroupThreadID, ScanID);
	GroupMemoryBarrierWithGroupSync(); // Required to make GroupPartitionIndex visible to the whole group

	const uint PartitionIndex = GroupPartitionIndex;
	// Only the last partition of a scan can extend past the end of the scan
	const bool bPartial = PartitionIndex == NumThreadGroupsPerScan - 1;

	WaveScanReduce(GroupThreadID, PartitionIndex, ScanStartGlobalIdx, bPartial);
	GroupMemoryBarrierWithGroupSync(); // Required to make every wave's aggregate visible before scanning them

	ScanWaveAggregate(GroupThreadID);
	BroadcastGroupAggregate(GroupThreadID, PartitionIndex, ScanID);

	// GroupExclusivePrefix must have a single writer between two barriers: if every thread zeroed it here,
	// a wave scheduled late could overwrite the value that thread 0 stores during the look-back.
	if (GroupThreadID == 0)
	{
		if (PartitionIndex != 0)
		{
			// Stores the exclusive prefix in GroupExclusivePrefix on every exit path
			DecoupledLookback(PartitionIndex, ScanID);
		}
		else
		{
			// Nothing precedes the first partition
			GroupExclusivePrefix = (WORK_TYPE) 0;
		}
	}
	GroupMemoryBarrierWithGroupSync(); // Required to sync all threads in the group with respect to GroupExclusivePrefix

	// Compute the wave-wise exclusive prefix: everything before the partition plus the waves before this one
	// (after ScanWaveAggregate, WaveAggregate[i] is the sum of waves 0..i)
	const WORK_TYPE ExclusivePrefix = GroupExclusivePrefix +
		(GroupThreadID >= WaveGetLaneCount() ? WaveAggregate[GetWaveIndex(GroupThreadID) - 1] : (WORK_TYPE) 0)
	;

	SeededWaveScan(GroupThreadID, PartitionIndex, ExclusivePrefix, ScanStartGlobalIdx, bPartial);
}
#endif //ifdef INIT_SHADER