UnrealEngine/Engine/Plugins/Experimental/NNERuntimeRDG/Shaders/Private/NNEHlslShaders/NNEHlslShadersScatterND.usf

// Copyright Epic Games, Inc. All Rights Reserved.

#include "/Engine/Public/Platform.ush"
#include "NNEHlslShadersBroadcastHelper.ush"
#define WORK_TYPE float
#define BUFFER_TYPE float
#define READ(x) x
#define WRITE(x) x

Buffer<int> InputIndices; // int64 indices stored as consecutive pairs of 32-bit words, low word first
Buffer<BUFFER_TYPE> InputUpdates; // Update values, laid out as indices.shape[0:q-2] ++ data.shape[k:r-1]
RWBuffer<BUFFER_TYPE> Output; // Assumed to be pre-initialized with a copy of the data tensor

uint4 DataTensorInfo[NUM_DIMENSIONS]; // Per dimension: [stride, size, ...], indexed with STRIDE_IDX/SIZE_IDX below
uint Num; // Total number of update elements; one thread per scattered scalar
uint ThreadCountX; // Dispatch width, used to linearize the 2D thread id
uint PartialIndexRank; // k = indices.shape[q-1], the number of leading data dimensions addressed by one index tuple
uint SliceSize; // Product of data.shape[k..r-1], i.e. the element count of one updated slice
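
// Example (ONNX, illustrative): with data.shape = [4, 4, 4] (r = 3) and
// indices.shape = [2, 1] (q = 2, so k = indices.shape[q-1] = 1), each index
// tuple selects a whole [4, 4] slice: SliceSize = 4 * 4 = 16,
// updates.shape = [2, 4, 4], and Num = 32 scattered elements.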

#define STRIDE_IDX 0
#define SIZE_IDX 1

// Corresponds to EScatterNDReductionType defined in NNEHlslShadersScatterNDCS.h
#define TYPE_NONE 0
#define TYPE_ADD 1
#define TYPE_MUL 2
#define TYPE_MAX 3
#define TYPE_MIN 4
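
// Map a possibly negative ONNX index into [0, DimSize): e.g. an index of -1
// with DimSize = 4 normalizes to 3.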
uint NormalizeIdx(int Idx, uint DimSize)
{
	return Idx < 0 ? (uint) (Idx + DimSize) : (uint) Idx;
}

// int64 is only available from SM 6.0, so indices are read as two 32-bit
// words and truncated to int32 manually: for any value that fits in 32 bits,
// the low word already holds the correct two's complement bit pattern.
int CastInt64ToInt32(uint LSW, int MSW)
{
	// MSW is unused by the truncation but kept to mirror the storage layout.
	return (int) LSW;
}

#define STATIC_LOOP(Var, From, To) \
	[unroll] \
	for(uint Var = From; Var < To; ++Var)
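
// Decompose a flat element index into per-dimension coordinates by dividing
// by each dimension's stride, e.g. with strides [12, 4, 1] a flat index of 17
// yields the iterator [1, 1, 1].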
void FromIdxToIterator(const uint Idx, const uint4 TensorInfo[NUM_DIMENSIONS], out uint Iterator[NUM_DIMENSIONS])
{
	uint Offset = Idx;
	STATIC_LOOP(DimIdx, 0, NUM_DIMENSIONS)
	{
		uint Remainder;
		DivMod(Offset, TensorInfo[DimIdx][STRIDE_IDX], Iterator[DimIdx], Remainder);
		Offset = Remainder;
	}
}
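
// Inverse of FromIdxToIterator: recombine per-dimension coordinates into a
// flat index as the dot product with the dimension strides.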
uint FromIteratorToIdx(const uint4 TensorInfo[NUM_DIMENSIONS], const uint Iterator[NUM_DIMENSIONS])
{
	uint Offset = 0;
	STATIC_LOOP(DimIdx, 0, NUM_DIMENSIONS)
	{
		Offset += Iterator[DimIdx] * TensorInfo[DimIdx][STRIDE_IDX];
	}
	return Offset;
}
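
// Combine one update element into the output according to the compile-time
// reduction mode. Note that the read-modify-write is not atomic: per the ONNX
// spec, indices must not contain duplicate entries when reduction is 'none',
// and duplicate indices within one dispatch would race for the other modes.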
void ApplyReduction(const uint OutputIdx, const uint UpdatesIdx)
{
#if REDUCE_OPERATOR_TYPE == TYPE_NONE
	Output[OutputIdx] = WRITE(READ(InputUpdates[UpdatesIdx]));
#elif REDUCE_OPERATOR_TYPE == TYPE_ADD
	Output[OutputIdx] = WRITE(READ(Output[OutputIdx]) + READ(InputUpdates[UpdatesIdx]));
#elif REDUCE_OPERATOR_TYPE == TYPE_MUL
	Output[OutputIdx] = WRITE(READ(Output[OutputIdx]) * READ(InputUpdates[UpdatesIdx]));
#elif REDUCE_OPERATOR_TYPE == TYPE_MAX
	Output[OutputIdx] = WRITE(max(READ(Output[OutputIdx]), READ(InputUpdates[UpdatesIdx])));
#elif REDUCE_OPERATOR_TYPE == TYPE_MIN
	Output[OutputIdx] = WRITE(min(READ(Output[OutputIdx]), READ(InputUpdates[UpdatesIdx])));
#endif
}
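
// One thread per update element: a 2D dispatch (presumably split over Y when
// Num exceeds the per-dimension dispatch limit) is linearized via ThreadCountX.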
[numthreads(THREADGROUP_SIZE, 1, 1)]
void ScatterND(in const uint3 DispatchThreadID : SV_DispatchThreadID)
{
	// See https://github.com/onnx/onnx/blob/main/docs/Operators.md#ScatterND
	// Parallelizing over updates, of shape: indices.shape[0:q-2] ++ data.shape[k:r-1]
	const uint Idx = DispatchThreadID.y * ThreadCountX + DispatchThreadID.x;
	if (Idx < Num)
	{
		uint ReplacementIdx = Idx / SliceSize; // Which index tuple (i.e. which target slice) this element belongs to
		uint ElemIdx = Idx % SliceSize; // Offset of the element within its slice

		// Decompose the within-slice offset using the data strides. ElemIdx is
		// smaller than every stride of the leading k dimensions, so those
		// entries come out as zero and are overwritten below.
		uint OutputIterator[NUM_DIMENSIONS];
		FromIdxToIterator(ElemIdx, DataTensorInfo, OutputIterator);

		// Then retrieve and set the index of the slice from the int64 index
		// tuple, stored as consecutive pairs of 32-bit words (low word first)
		for (uint DimIdx = 0; DimIdx < PartialIndexRank; ++DimIdx)
		{
			uint IndicesAccessIdx = ReplacementIdx * PartialIndexRank + DimIdx;
			int RawIdx = CastInt64ToInt32(
				asuint(READ(InputIndices[IndicesAccessIdx * 2])),
				READ(InputIndices[IndicesAccessIdx * 2 + 1]));
			OutputIterator[DimIdx] = NormalizeIdx(RawIdx, DataTensorInfo[DimIdx][SIZE_IDX]);
		}

		uint OutputIdx = FromIteratorToIdx(DataTensorInfo, OutputIterator);
		ApplyReduction(OutputIdx, Idx);
	}
}