Files
UnrealEngine/Engine/Shaders/Private/ThreadGroupPrefixSum.ush
2025-05-18 13:04:45 +08:00

163 lines
5.7 KiB
HLSL

// Copyright Epic Games, Inc. All Rights Reserved.
#pragma once
/**
* ThreadGroupPrefixSum(Value, GroupThreadIndex [, GroupSumOut])
*
* Calculates the exclusive prefix sum (and optionally group sum) of a given value across threads of a thread group.
*
* EXAMPLE USES:
* float GroupSum;
* float PrefixSum = ThreadGroupPrefixSum(ValueFloat, GroupThreadIndex, GroupSum);
*
* int PrefixSum = ThreadGroupPrefixSum(ValueInt, GroupThreadIndex);
*
* NOTES:
* - (Exclusive) Prefix Sum means the sum of the values of all threads whose group thread index is LESS than the current thread's
* - Only scalar types are currently supported
* - All threads in the group must be active when calling this method; it cannot be used in a branch (or after an early return)
* that doesn't include the entire group.
*/
#ifndef NUM_THREADS_PER_GROUP
#error NUM_THREADS_PER_GROUP must be defined, and must be equal to the thread group size of the caller of ThreadGroupPrefixSum
#endif
// Scratch memory shared by every helper in this file: two rows of NUM_THREADS_PER_GROUP uints,
// indexed as [Row][GroupThreadIndex].
// - ThreadGroupPrefixSum alternates between the two rows on each pass of its loop.
// - The ThreadGroupReduce* helpers only ever touch row 0.
// A single uint array is re-used for all value types (each helper converts to/from uint through the
// cast functions it is instantiated with), so a caller that utilizes multiple overloads does not pay
// for one workspace per type.
groupshared uint ThreadGroupPrefixSumWorkspace[2][NUM_THREADS_PER_GROUP];
/**
 * DECLARE_THREAD_GROUP_PREFIX_SUM(ValType, CastToUint, CastFromUint)
 *
 * Emits the two ThreadGroupPrefixSum overloads for one scalar type.
 *   ValType      - scalar type of the values being summed
 *   CastToUint   - expression/function converting a ValType to the uint stored in the workspace
 *   CastFromUint - expression/function converting a workspace uint back to ValType
 *
 * ThreadGroupPrefixSum(Value, GroupThreadIndex, GroupSum)
 *   Returns the sum of Value over all threads with a lower GroupThreadIndex (zero for thread 0) and
 *   writes the sum over the whole group to GroupSum.
 * ThreadGroupPrefixSum(Value, GroupThreadIndex)
 *   Same, for callers that do not need the group sum.
 *
 * Implementation notes:
 * - An inclusive scan is built by repeatedly adding the element `i` slots below, doubling `i` on every
 *   pass and ping-ponging between the two workspace rows so a pass never reads what it is writing.
 * - The loop runs while `i < NUM_THREADS_PER_GROUP`, so the last pass is also taken when the group
 *   size is not a power of two. For power-of-two sizes the passes are unchanged.
 * - Results are copied into locals and followed by a barrier before returning, so a subsequent call
 *   to any helper sharing ThreadGroupPrefixSumWorkspace cannot overwrite a slot that a slower thread
 *   has not read yet.
 * - The convenience overload initializes its dummy GroupSum because it is handed to an inout parameter.
 */
#define DECLARE_THREAD_GROUP_PREFIX_SUM(ValType, CastToUint, CastFromUint) \
ValType ThreadGroupPrefixSum(ValType Value, uint GroupThreadIndex, inout ValType GroupSum) \
{ \
	uint Curr = 0; \
	ThreadGroupPrefixSumWorkspace[Curr][GroupThreadIndex] = CastToUint(Value); \
	GroupMemoryBarrierWithGroupSync(); \
	for (uint i = 1U; i < uint(NUM_THREADS_PER_GROUP); i *= 2U) \
	{ \
		const uint Next = 1U - Curr; \
		if (GroupThreadIndex < i) \
		{ \
			ThreadGroupPrefixSumWorkspace[Next][GroupThreadIndex] = \
				ThreadGroupPrefixSumWorkspace[Curr][GroupThreadIndex]; \
		} \
		else \
		{ \
			ThreadGroupPrefixSumWorkspace[Next][GroupThreadIndex] = CastToUint( \
				CastFromUint(ThreadGroupPrefixSumWorkspace[Curr][GroupThreadIndex]) + \
				CastFromUint(ThreadGroupPrefixSumWorkspace[Curr][GroupThreadIndex - i]) \
			); \
		} \
		Curr = Next; \
		GroupMemoryBarrierWithGroupSync(); \
	} \
	ValType ExclusiveSum = CastFromUint(0); \
	if (GroupThreadIndex != 0U) \
	{ \
		ExclusiveSum = CastFromUint(ThreadGroupPrefixSumWorkspace[Curr][GroupThreadIndex - 1U]); \
	} \
	GroupSum = CastFromUint(ThreadGroupPrefixSumWorkspace[Curr][NUM_THREADS_PER_GROUP - 1]); \
	GroupMemoryBarrierWithGroupSync(); \
	return ExclusiveSum; \
} \
ValType ThreadGroupPrefixSum(ValType Value, uint GroupThreadIndex) \
{ \
	ValType GroupSum = CastFromUint(0); \
	return ThreadGroupPrefixSum(Value, GroupThreadIndex, GroupSum); \
}
DECLARE_THREAD_GROUP_PREFIX_SUM(float, asuint, asfloat)
DECLARE_THREAD_GROUP_PREFIX_SUM(int, uint, int)
DECLARE_THREAD_GROUP_PREFIX_SUM(uint, uint, uint)
/**
 * DECLARE_THREAD_GROUP_REDUCE_SUM(ValType, CastToUint, CastFromUint)
 *
 * Emits ThreadGroupReduceSum(Value, GroupThreadIndex) for one scalar type: returns, on every thread,
 * the sum of Value over all threads of the group. Uses row 0 of ThreadGroupPrefixSumWorkspace.
 *
 * Implementation notes:
 * - `Active` is the number of leading slots that still hold partial results. Each pass folds the
 *   upper floor(Active / 2) slots onto the lowest ones and keeps ceil(Active / 2) slots, so no
 *   element is dropped when the group size is odd or not a power of two. For power-of-two sizes
 *   the passes are identical to a plain halving loop.
 * - A thread only ever writes its own slot and reads a slot at or above `Stride`, which no thread
 *   writes during that pass, so one barrier per pass is sufficient.
 * - The result is read into a local and followed by a barrier so that a subsequent call to any helper
 *   sharing the workspace cannot overwrite slot 0 before every thread has read it.
 */
#define DECLARE_THREAD_GROUP_REDUCE_SUM(ValType, CastToUint, CastFromUint) \
ValType ThreadGroupReduceSum(ValType Value, uint GroupThreadIndex) \
{ \
	ValType Result = Value; \
	ThreadGroupPrefixSumWorkspace[0][GroupThreadIndex] = CastToUint(Value); \
	GroupMemoryBarrierWithGroupSync(); \
	for (uint Active = uint(NUM_THREADS_PER_GROUP); Active > 1U; Active -= Active / 2U) \
	{ \
		const uint Pairs = Active / 2U; \
		const uint Stride = Active - Pairs; \
		if (GroupThreadIndex < Pairs) \
		{ \
			Result += CastFromUint(ThreadGroupPrefixSumWorkspace[0][GroupThreadIndex + Stride]); \
			ThreadGroupPrefixSumWorkspace[0][GroupThreadIndex] = CastToUint(Result); \
		} \
		GroupMemoryBarrierWithGroupSync(); \
	} \
	Result = CastFromUint(ThreadGroupPrefixSumWorkspace[0][0]); \
	GroupMemoryBarrierWithGroupSync(); \
	return Result; \
}
DECLARE_THREAD_GROUP_REDUCE_SUM(float, asuint, asfloat)
DECLARE_THREAD_GROUP_REDUCE_SUM(int, uint, int)
DECLARE_THREAD_GROUP_REDUCE_SUM(uint, uint, uint)
#undef DECLARE_THREAD_GROUP_REDUCE_SUM
/**
 * DECLARE_THREAD_GROUP_REDUCE_MAX(ValType, CastToUint, CastFromUint)
 *
 * Emits ThreadGroupReduceMax(Value, GroupThreadIndex) for one scalar type: returns, on every thread,
 * the maximum of Value over all threads of the group. Uses row 0 of ThreadGroupPrefixSumWorkspace.
 *
 * Implementation notes:
 * - `Active` is the number of leading slots that still hold partial results. Each pass folds the
 *   upper floor(Active / 2) slots onto the lowest ones and keeps ceil(Active / 2) slots, so no
 *   element is dropped when the group size is odd or not a power of two. For power-of-two sizes
 *   the passes are identical to a plain halving loop.
 * - A thread only ever writes its own slot and reads a slot at or above `Stride`, which no thread
 *   writes during that pass, so one barrier per pass is sufficient.
 * - The result is read into a local and followed by a barrier so that a subsequent call to any helper
 *   sharing the workspace cannot overwrite slot 0 before every thread has read it.
 */
#define DECLARE_THREAD_GROUP_REDUCE_MAX(ValType, CastToUint, CastFromUint) \
ValType ThreadGroupReduceMax(ValType Value, uint GroupThreadIndex) \
{ \
	ValType Result = Value; \
	ThreadGroupPrefixSumWorkspace[0][GroupThreadIndex] = CastToUint(Value); \
	GroupMemoryBarrierWithGroupSync(); \
	for (uint Active = uint(NUM_THREADS_PER_GROUP); Active > 1U; Active -= Active / 2U) \
	{ \
		const uint Pairs = Active / 2U; \
		const uint Stride = Active - Pairs; \
		if (GroupThreadIndex < Pairs) \
		{ \
			Result = max(Result, CastFromUint(ThreadGroupPrefixSumWorkspace[0][GroupThreadIndex + Stride])); \
			ThreadGroupPrefixSumWorkspace[0][GroupThreadIndex] = CastToUint(Result); \
		} \
		GroupMemoryBarrierWithGroupSync(); \
	} \
	Result = CastFromUint(ThreadGroupPrefixSumWorkspace[0][0]); \
	GroupMemoryBarrierWithGroupSync(); \
	return Result; \
}
DECLARE_THREAD_GROUP_REDUCE_MAX(float, asuint, asfloat)
DECLARE_THREAD_GROUP_REDUCE_MAX(int, uint, int)
DECLARE_THREAD_GROUP_REDUCE_MAX(uint, uint, uint)
#undef DECLARE_THREAD_GROUP_REDUCE_MAX
/**
 * DECLARE_THREAD_GROUP_REDUCE_MIN(ValType, CastToUint, CastFromUint)
 *
 * Emits ThreadGroupReduceMin(Value, GroupThreadIndex) for one scalar type: returns, on every thread,
 * the minimum of Value over all threads of the group. Uses row 0 of ThreadGroupPrefixSumWorkspace.
 *
 * Implementation notes:
 * - `Active` is the number of leading slots that still hold partial results. Each pass folds the
 *   upper floor(Active / 2) slots onto the lowest ones and keeps ceil(Active / 2) slots, so no
 *   element is dropped when the group size is odd or not a power of two. For power-of-two sizes
 *   the passes are identical to a plain halving loop.
 * - A thread only ever writes its own slot and reads a slot at or above `Stride`, which no thread
 *   writes during that pass, so one barrier per pass is sufficient.
 * - The result is read into a local and followed by a barrier so that a subsequent call to any helper
 *   sharing the workspace cannot overwrite slot 0 before every thread has read it.
 */
#define DECLARE_THREAD_GROUP_REDUCE_MIN(ValType, CastToUint, CastFromUint) \
ValType ThreadGroupReduceMin(ValType Value, uint GroupThreadIndex) \
{ \
	ValType Result = Value; \
	ThreadGroupPrefixSumWorkspace[0][GroupThreadIndex] = CastToUint(Value); \
	GroupMemoryBarrierWithGroupSync(); \
	for (uint Active = uint(NUM_THREADS_PER_GROUP); Active > 1U; Active -= Active / 2U) \
	{ \
		const uint Pairs = Active / 2U; \
		const uint Stride = Active - Pairs; \
		if (GroupThreadIndex < Pairs) \
		{ \
			Result = min(Result, CastFromUint(ThreadGroupPrefixSumWorkspace[0][GroupThreadIndex + Stride])); \
			ThreadGroupPrefixSumWorkspace[0][GroupThreadIndex] = CastToUint(Result); \
		} \
		GroupMemoryBarrierWithGroupSync(); \
	} \
	Result = CastFromUint(ThreadGroupPrefixSumWorkspace[0][0]); \
	GroupMemoryBarrierWithGroupSync(); \
	return Result; \
}
DECLARE_THREAD_GROUP_REDUCE_MIN(float, asuint, asfloat)
DECLARE_THREAD_GROUP_REDUCE_MIN(int, uint, int)
DECLARE_THREAD_GROUP_REDUCE_MIN(uint, uint, uint)
/**
 * ThreadGroupReduceOr(Value, GroupThreadIndex)
 *
 * Returns, on every thread, the bitwise OR of Value over all threads of the group.
 * Uses row 0 of ThreadGroupPrefixSumWorkspace; like the other helpers in this file it issues group
 * barriers, so every thread of the group must reach the call.
 *
 * Value            - this thread's contribution
 * GroupThreadIndex - index of this thread inside the group, in [0, NUM_THREADS_PER_GROUP)
 */
uint ThreadGroupReduceOr(uint Value, uint GroupThreadIndex)
{
	uint Result = Value;
	ThreadGroupPrefixSumWorkspace[0][GroupThreadIndex] = Value;
	GroupMemoryBarrierWithGroupSync();

	// `Active` is the number of leading slots still holding partial results. Each pass folds the upper
	// floor(Active / 2) slots onto the lowest ones and keeps ceil(Active / 2) slots, so nothing is
	// dropped when the group size is odd or not a power of two (identical to plain halving otherwise).
	for (uint Active = uint(NUM_THREADS_PER_GROUP); Active > 1U; Active -= Active / 2U)
	{
		const uint Pairs = Active / 2U;
		const uint Stride = Active - Pairs;
		if (GroupThreadIndex < Pairs)
		{
			// Reads a slot at or above Stride, which no thread writes during this pass.
			Result |= ThreadGroupPrefixSumWorkspace[0][GroupThreadIndex + Stride];
			ThreadGroupPrefixSumWorkspace[0][GroupThreadIndex] = Result;
		}
		GroupMemoryBarrierWithGroupSync();
	}

	// Take a private copy, then synchronize, so a subsequent call to any helper sharing the workspace
	// cannot overwrite slot 0 before every thread has read it.
	Result = ThreadGroupPrefixSumWorkspace[0][0];
	GroupMemoryBarrierWithGroupSync();
	return Result;
}