// Copyright Epic Games, Inc. All Rights Reserved.

/*==============================================================================
	RadixSortShaders.usf: Compute shader implementation of radix sort.
==============================================================================*/

#include "Common.ush"

/*------------------------------------------------------------------------------
	Compile time parameters:
		RADIX_BITS - The number of bits to inspect during each pass.
		THREAD_COUNT - The number of threads to launch per workgroup.
		KEYS_PER_LOOP - The number of keys to process simultaneously in a single
			thread.
		MAX_GROUP_COUNT - The maximum number of groups that will be used for the
			upsweep and downsweep.

	Notes:
		The upsweep and downsweep kernels must never be invoked with more than
		MAX_GROUP_COUNT workgroups.

		The spine kernel does not respect THREAD_COUNT and will launch the number
		of threads needed to efficiently process
		MAX_GROUP_COUNT * (1 << RADIX_BITS) counters. The spine kernel should
		always be invoked with a single workgroup.

		The tile size is defined as THREAD_COUNT * KEYS_PER_LOOP. When computing
		the number of tiles and groups needed to execute the sort, CPU code must
		respect the tile size with which these shaders have been compiled.
------------------------------------------------------------------------------*/

/** Full group sync as some GPUs will hang without it, see UE-282658. */
#define PPS_BARRIER GroupMemoryBarrierWithGroupSync

/** The number of digits per radix. */
#define DIGIT_COUNT (1 << RADIX_BITS)
/** Bitmask to retrieve the digit of a key. */
#define DIGIT_MASK (DIGIT_COUNT - 1)
/** The number of words per bank. This is used to compute padding to avoid bank conflicts. */
#define WORDS_PER_BANK (32)
/** The padded size of a bank. */
#define PADDED_BANK_SIZE (WORDS_PER_BANK + 1)

/*------------------------------------------------------------------------------
	Common functionality.
------------------------------------------------------------------------------*/

/** Specifies which shaders use the raking technique for computing parallel prefix sums. */
#define USES_RAKING (RADIX_SORT_UPSWEEP || RADIX_SORT_SPINE || RADIX_SORT_DOWNSWEEP)

#if (RADIX_SORT_UPSWEEP || RADIX_SORT_DOWNSWEEP)

	/** The size of a single tile. */
	#define TILE_SIZE (THREAD_COUNT * KEYS_PER_LOOP)
	/** The number of counters required per digit. */
	#define COUNTERS_PER_DIGIT (THREAD_COUNT)
	/** The number of banks per digit. Counters are segregated in to banks to avoid bank conflicts. */
	#define BANKS_PER_DIGIT (COUNTERS_PER_DIGIT / WORDS_PER_BANK)
	/** The number of raking threads. */
	#define RAKING_THREAD_COUNT (64)
	/** The number of raking threads used per digit. */
	#define RAKING_THREADS_PER_DIGIT (RAKING_THREAD_COUNT / DIGIT_COUNT)
	/** The required padding for raking thread counters. */
	#define RAKING_COUNTER_PADDING (RAKING_THREADS_PER_DIGIT >> 1)
	/** The number of raking counters required per digit. */
	#define RAKING_COUNTERS_PER_DIGIT (RAKING_THREADS_PER_DIGIT + RAKING_COUNTER_PADDING)
	/** The number of counters processed by each raking thread. */
	#define COUNTERS_PER_RAKING_THREAD (COUNTERS_PER_DIGIT / RAKING_THREADS_PER_DIGIT)
	/** The number of raking threads per bank. */
	#define RAKING_THREADS_PER_BANK (WORDS_PER_BANK / COUNTERS_PER_RAKING_THREAD)

#elif RADIX_SORT_SPINE // #if (RADIX_SORT_UPSWEEP || RADIX_SORT_DOWNSWEEP)

	/**
	 * The number of counters per digit: one per workgroup of the upsweep.
	 * NOTE(review): this comment previously claimed the shader uses twice the
	 * number of required counters to work around a bug, but the value below is
	 * exactly MAX_GROUP_COUNT - confirm whether any doubling is applied elsewhere.
	 */
	#define COUNTERS_PER_DIGIT (MAX_GROUP_COUNT)
	/** The number of banks required per digit. */
	#define BANKS_PER_DIGIT (COUNTERS_PER_DIGIT / WORDS_PER_BANK)
	/** The number of raking threads. */
	#define RAKING_THREAD_COUNT (BANKS_PER_DIGIT * DIGIT_COUNT)
	/** The number of raking threads used per digit. */
	#define RAKING_THREADS_PER_DIGIT (RAKING_THREAD_COUNT / DIGIT_COUNT)
	/** The amount of padding required for raking counters. */
	#define RAKING_COUNTER_PADDING (RAKING_THREADS_PER_DIGIT >> 1)
	/** The number of raking counters required per digit. */
	#define RAKING_COUNTERS_PER_DIGIT (RAKING_THREADS_PER_DIGIT + RAKING_COUNTER_PADDING)
	/** The number of counters processed by each raking thread. */
	#define COUNTERS_PER_RAKING_THREAD (COUNTERS_PER_DIGIT / RAKING_THREADS_PER_DIGIT)
	/** The number of raking threads per bank. */
	#define RAKING_THREADS_PER_BANK (WORDS_PER_BANK / COUNTERS_PER_RAKING_THREAD)

#endif // #elif RADIX_SORT_SPINE

#if USES_RAKING

/**
 * Computes the bank offset to rake for a given thread.
 * @param ThreadId - Index of the thread within its workgroup.
 * @returns the index of the first counter (in padded-bank layout) that the thread rakes.
 */
uint GetBankToRakeOffset( const uint ThreadId )
{
	return (ThreadId / RAKING_THREADS_PER_BANK) * PADDED_BANK_SIZE	// Which bank to rake.
		+ (ThreadId & (RAKING_THREADS_PER_BANK - 1))				// Which counters within the bank to rake.
		* COUNTERS_PER_RAKING_THREAD;
}

/**
 * Computes the digit to rake for a given thread.
 * @param ThreadId - Index of the thread within its workgroup.
 * @returns the digit raked by the thread, or DIGIT_COUNT for non-raking threads.
 */
uint GetRakingDigit( const uint ThreadId )
{
	if ( ThreadId < RAKING_THREAD_COUNT )
	{
		return ThreadId / RAKING_THREADS_PER_DIGIT;
	}
	else
	{
		return DIGIT_COUNT;
	}
}

/**
 * Computes the index to rake for a given thread.
 * @param ThreadId - Index of the thread within its workgroup.
 * @returns the thread's slot in LocalRakingTotals. Non-raking threads all get
 *          the extra slot at RAKING_COUNTERS_PER_DIGIT * DIGIT_COUNT.
 */
uint GetRakingIndex( const uint ThreadId )
{
	if ( ThreadId < RAKING_THREAD_COUNT )
	{
		return GetRakingDigit( ThreadId ) * RAKING_COUNTERS_PER_DIGIT	// Which digit.
			+ (ThreadId & (RAKING_THREADS_PER_DIGIT - 1))				// Thread offset within digit.
			+ (RAKING_COUNTER_PADDING);									// Add in padding.
	}
	else
	{
		return RAKING_COUNTERS_PER_DIGIT * DIGIT_COUNT;
	}
}

#endif // #if USES_RAKING

/*------------------------------------------------------------------------------
	The offset clearing kernel. This kernel just zeroes out the offsets buffer.

	Note that MAX_GROUP_COUNT * DIGIT_COUNT must be a multiple of THREAD_COUNT.
------------------------------------------------------------------------------*/

#if RADIX_SORT_CLEAR_OFFSETS

#define OFFSET_COUNT (MAX_GROUP_COUNT * DIGIT_COUNT)
#define OFFSETS_PER_THREAD (OFFSET_COUNT / THREAD_COUNT)

/** Output buffer to clear. */
RWBuffer<uint> OutOffsets;

/**
 * Zeroes OFFSET_COUNT entries of OutOffsets. Only the thread index within the
 * group is used for addressing, so a single workgroup covers the whole buffer.
 */
[numthreads(THREAD_COUNT,1,1)]
void RadixSort_ClearOffsets( uint3 GroupThreadId : SV_GroupThreadID )
{
	// Determine the thread ID.
	const uint ThreadId = GroupThreadId.x;

	// Clear all offsets. Consecutive threads write consecutive elements.
	for ( uint OffsetIndex = 0; OffsetIndex < OFFSETS_PER_THREAD; ++OffsetIndex )
	{
		OutOffsets[OffsetIndex * THREAD_COUNT + ThreadId] = 0;
	}
}

#endif // #if RADIX_SORT_CLEAR_OFFSETS

/*------------------------------------------------------------------------------
	The upsweep sorting kernel. This kernel performs an upsweep scan on all
	tiles allocated to this group and computes per-digit totals. These totals
	are output to the offsets buffer.
------------------------------------------------------------------------------*/

#if RADIX_SORT_UPSWEEP

/** Input keys to be sorted. */
Buffer<uint> InKeys;
/** Output buffer for offsets. */
RWBuffer<uint> OutOffsets;

/** Local storage for the digit counters. */
groupshared uint LocalCounters[BANKS_PER_DIGIT * DIGIT_COUNT * PADDED_BANK_SIZE];
/** Local storage for raking totals. The extra slot is the one used by non-raking threads. */
groupshared uint LocalRakingTotals[RAKING_COUNTERS_PER_DIGIT * DIGIT_COUNT + 1];

/**
 * The upsweep sorting kernel computing a scan upsweep on tiles per thread group.
 * Writes DIGIT_COUNT per-digit key counts for this group to OutOffsets.
 */
[numthreads(THREAD_COUNT,1,1)]
void RadixSort_Upsweep(
	uint3 GroupThreadId : SV_GroupThreadID,
	uint3 GroupIdXYZ : SV_GroupID )
{
	uint i;
	uint FirstTileIndex;
	uint TileCountForGroup;
	uint ExtraKeysForGroup;

	// Determine group and thread IDs.
	const uint ThreadId = GroupThreadId.x;
	const uint GroupId = GroupIdXYZ.x;
	const uint BankOffset = ThreadId / (WORDS_PER_BANK);
	const uint CounterOffset = ThreadId & (WORDS_PER_BANK - 1);
	const uint BankToRakeOffset = GetBankToRakeOffset( ThreadId );
	const uint RakingIndex = GetRakingIndex( ThreadId );

#if 0
	// Disabled debugging aid: dumps the uniform parameters in to the offsets buffer.
	if ( RadixSortUB.ExtraKeyCount == 16 && RadixSortUB.TilesPerGroup != 0 )
	{
		if ( ThreadId == 0 )
		{
			OutOffsets[DIGIT_COUNT*1] = RadixSortUB.RadixShift;
			OutOffsets[DIGIT_COUNT*2] = RadixSortUB.TilesPerGroup;
			OutOffsets[DIGIT_COUNT*3] = RadixSortUB.ExtraTileCount;
			OutOffsets[DIGIT_COUNT*4] = RadixSortUB.ExtraKeyCount;
			OutOffsets[DIGIT_COUNT*5] = RadixSortUB.GroupCount;
		}
	}
#endif

	// Clear local counters. Each thread owns one counter per digit.
	for ( uint DigitIndex = 0; DigitIndex < DIGIT_COUNT; ++DigitIndex )
	{
		const uint BankIndex = DigitIndex * BANKS_PER_DIGIT + BankOffset;
		LocalCounters[BankIndex * PADDED_BANK_SIZE + CounterOffset] = 0;
	}

	// Clear the raking counter padding.
	if ( ThreadId < RAKING_THREAD_COUNT )
	{
		[unroll]
		for ( i = 1; i <= RAKING_COUNTER_PADDING; ++i )
		{
			LocalRakingTotals[RakingIndex - i] = 0;
		}
	}

	// The number of tiles to process in this group.
	// The first ExtraTileCount groups each take one additional tile.
	if ( GroupId < RadixSortUB.ExtraTileCount )
	{
		FirstTileIndex = GroupId * (RadixSortUB.TilesPerGroup + 1);
		TileCountForGroup = RadixSortUB.TilesPerGroup + 1;
	}
	else
	{
		FirstTileIndex = GroupId * RadixSortUB.TilesPerGroup + RadixSortUB.ExtraTileCount;
		TileCountForGroup = RadixSortUB.TilesPerGroup;
	}

	// The last group has to process any keys after the last tile.
	if ( GroupId == (RadixSortUB.GroupCount - 1) )
	{
		ExtraKeysForGroup = RadixSortUB.ExtraKeyCount;
	}
	else
	{
		ExtraKeysForGroup = 0;
	}

	// Key range for this group.
	uint GroupKeyBegin = FirstTileIndex * TILE_SIZE;
	uint GroupKeyEnd = GroupKeyBegin + TileCountForGroup * TILE_SIZE;

	// Acquire LocalCounters.
	GroupMemoryBarrierWithGroupSync();

	// Accumulate digit counters for the tiles assigned to this group.
	while ( GroupKeyBegin < GroupKeyEnd )
	{
		const uint Key = InKeys[GroupKeyBegin + ThreadId];
		const uint Digit = (Key >> RadixSortUB.RadixShift) & DIGIT_MASK;
		const uint BankIndex = Digit * BANKS_PER_DIGIT + BankOffset;
		LocalCounters[BankIndex * PADDED_BANK_SIZE + CounterOffset] += 1;
		//LocalCounters[BankOffset * PADDED_BANK_SIZE + CounterOffset] += 1;
		GroupKeyBegin += THREAD_COUNT;
	}

	// Accumulate digit counters for any additional keys assigned to this group.
	GroupKeyBegin = GroupKeyEnd;
	GroupKeyEnd += ExtraKeysForGroup;
	while ( GroupKeyBegin < GroupKeyEnd )
	{
		// The extra keys do not fill a whole tile, so guard against reading past the end.
		if ( GroupKeyBegin + ThreadId < GroupKeyEnd )
		{
			const uint Key = InKeys[GroupKeyBegin + ThreadId];
			const uint Digit = (Key >> RadixSortUB.RadixShift) & DIGIT_MASK;
			const uint BankIndex = Digit * BANKS_PER_DIGIT + BankOffset;
			LocalCounters[BankIndex * PADDED_BANK_SIZE + CounterOffset] += 1;
			//LocalCounters[BankOffset * PADDED_BANK_SIZE + CounterOffset] += 100;
		}
		GroupKeyBegin += THREAD_COUNT;
	}

	// Acquire LocalCounters.
	GroupMemoryBarrierWithGroupSync();

	// Reduce.
	uint Total = 0;
	if ( ThreadId < RAKING_THREAD_COUNT )
	{
		// Perform a serial reduction on this raking thread's counters.
		Total = LocalCounters[BankToRakeOffset];
		[unroll]
		for ( i = 1; i < COUNTERS_PER_RAKING_THREAD; ++i )
		{
			Total += LocalCounters[BankToRakeOffset + i];
		}

		// Place the total in the raking counter.
		LocalRakingTotals[RakingIndex] = Total;
	}

	// Scan raking totals. The barrier must be reached by every thread, so the
	// loop runs for the whole group; non-raking threads only touch the extra slot.
	// NOTE(review): the read and the write below are not separated by a barrier;
	// presumably this relies on the raking threads of a digit running in lockstep - confirm.
	for ( uint RakingOffset = 1; RakingOffset < RAKING_THREADS_PER_DIGIT; RakingOffset <<= 1 )
	{
		// Acquire LocalRakingTotals.
		PPS_BARRIER();
		Total += LocalRakingTotals[RakingIndex - RakingOffset];
		LocalRakingTotals[RakingIndex] = Total;
	}

	// Acquire LocalRakingTotals.
	GroupMemoryBarrierWithGroupSync();

	// Store digit totals for this group.
	// The last raking counter of each digit holds that digit's inclusive total.
	if ( ThreadId < DIGIT_COUNT )
	{
		OutOffsets[GroupId * DIGIT_COUNT + ThreadId] = LocalRakingTotals[RAKING_COUNTERS_PER_DIGIT * (ThreadId+1) - 1];
	}
}

#endif // #if RADIX_SORT_UPSWEEP

/*------------------------------------------------------------------------------
	The spine sorting kernel. This kernel performs a parallel prefix sum on
	the offsets computed by each work group in the upsweep. The outputs will be
	used by individual work groups during the downsweep to compute the final
	location of keys.
------------------------------------------------------------------------------*/

#if RADIX_SORT_SPINE

/** The amount of padding to add for scanning digits. */
#define DIGIT_SCAN_PADDING (DIGIT_COUNT >> 1)

/** The number of threads required for this phase. */
#if COUNTERS_PER_DIGIT > RAKING_THREAD_COUNT
	#define SPINE_THREAD_COUNT COUNTERS_PER_DIGIT
#else
	#define SPINE_THREAD_COUNT RAKING_THREAD_COUNT
#endif

/**
 * The number of per-thread slots in LocalTotals. Every thread in the group
 * writes LocalTotals[ThreadId + DIGIT_SCAN_PADDING], so the array must cover
 * the whole group. 128 is kept as the minimum so existing configurations get
 * the same allocation as before.
 */
#if SPINE_THREAD_COUNT > 128
	#define LOCAL_TOTALS_SLOTS (SPINE_THREAD_COUNT)
#else
	#define LOCAL_TOTALS_SLOTS (128)
#endif

/** Input offsets to be scanned. */
Buffer<uint> InOffsets;
/** Output buffer for scanned offsets. */
RWBuffer<uint> OutOffsets;

/** Local storage for the digit counters. */
groupshared uint LocalOffsets[BANKS_PER_DIGIT * DIGIT_COUNT * PADDED_BANK_SIZE];
/** Local storage for raking totals. The extra slot is the one used by non-raking threads. */
groupshared uint LocalRakingTotals[RAKING_COUNTERS_PER_DIGIT * DIGIT_COUNT + 1];
/** Local storage for the per-digit totals. */
groupshared uint LocalTotals[LOCAL_TOTALS_SLOTS + DIGIT_SCAN_PADDING];

/**
 * Spine parallel prefix sum kernel.
 * Reads MAX_GROUP_COUNT * DIGIT_COUNT counts from InOffsets and writes their
 * exclusive prefix sums, ordered by digit first and group second, to OutOffsets.
 */
[numthreads(SPINE_THREAD_COUNT,1,1)]
void RadixSort_Spine( uint3 GroupThreadId : SV_GroupThreadID )
{
	uint i;

	// Determine the index for this thread.
	const uint ThreadId = GroupThreadId.x;
	const uint BankOffset = ThreadId / WORDS_PER_BANK;
	const uint CounterOffset = ThreadId & (WORDS_PER_BANK - 1);
	const uint BankToRakeOffset = GetBankToRakeOffset( ThreadId );
	const uint RakingDigit = GetRakingDigit( ThreadId );
	const uint RakingIndex = GetRakingIndex( ThreadId );

	// Load offsets. Thread N loads the counts of upsweep group N, one per digit.
	if ( ThreadId < MAX_GROUP_COUNT )
	{
		for ( uint DigitIndex = 0; DigitIndex < DIGIT_COUNT; ++DigitIndex )
		{
			const uint BankIndex = DigitIndex * BANKS_PER_DIGIT + BankOffset;
			LocalOffsets[BankIndex * PADDED_BANK_SIZE + CounterOffset] = InOffsets[ThreadId * DIGIT_COUNT + DigitIndex];
		}
	}
	else
	{
		for ( uint DigitIndex = 0; DigitIndex < DIGIT_COUNT; ++DigitIndex )
		{
			const uint BankIndex = DigitIndex * BANKS_PER_DIGIT + BankOffset;
			LocalOffsets[BankIndex * PADDED_BANK_SIZE + CounterOffset] = 0;
		}
	}

	// Clear the raking counter padding.
	if ( ThreadId < RAKING_THREAD_COUNT )
	{
		[unroll]
		for ( i = 1; i <= RAKING_COUNTER_PADDING; ++i )
		{
			LocalRakingTotals[RakingIndex - i] = 0;
		}
	}

	// Clear LocalTotals padding.
	if ( ThreadId < DIGIT_SCAN_PADDING )
	{
		LocalTotals[ThreadId] = 0;
	}

	// Acquire LocalOffsets, LocalTotals, and LocalRakingTotals.
	GroupMemoryBarrierWithGroupSync();

	// Reduce.
	uint Total = 0;
	if ( ThreadId < RAKING_THREAD_COUNT )
	{
		// Perform a serial reduction on this raking thread's counters.
		Total = LocalOffsets[BankToRakeOffset];
		[unroll]
		for ( i = 1; i < COUNTERS_PER_RAKING_THREAD; ++i )
		{
			Total += LocalOffsets[BankToRakeOffset + i];
		}

		// Place the total in the raking counter.
		LocalRakingTotals[RakingIndex] = Total;
	}

	// Scan raking totals.
	// NOTE(review): the read and the write below are not separated by a barrier;
	// presumably this relies on the raking threads of a digit running in lockstep - confirm.
	{
		[unroll]
		for ( uint RakingOffset = 1; RakingOffset < RAKING_THREADS_PER_DIGIT; RakingOffset <<= 1 )
		{
			// Acquire LocalRakingTotals.
			PPS_BARRIER();
			Total = Total + LocalRakingTotals[RakingIndex - RakingOffset];
			LocalRakingTotals[RakingIndex] = Total;
		}
	}

	// Acquire LocalRakingTotals.
	GroupMemoryBarrierWithGroupSync();

	// The last raking counter of each digit holds that digit's inclusive total.
	// Threads beyond DIGIT_COUNT contribute zero.
	uint DigitSum = 0;
	if(ThreadId < DIGIT_COUNT)
	{
		DigitSum = LocalRakingTotals[RAKING_COUNTERS_PER_DIGIT * (ThreadId+1) - 1];
	}
	LocalTotals[ThreadId + DIGIT_SCAN_PADDING] = DigitSum;

	// Scan local totals.
	{
		[unroll]
		for ( uint RakingOffset = 1; RakingOffset < DIGIT_COUNT; RakingOffset <<= 1 )
		{
			// Acquire LocalTotals.
			PPS_BARRIER();
			DigitSum += LocalTotals[ThreadId + DIGIT_SCAN_PADDING - RakingOffset];
			// Every thread must finish reading before any thread overwrites its slot.
			GroupMemoryBarrierWithGroupSync();
			LocalTotals[ThreadId + DIGIT_SCAN_PADDING] = DigitSum;
		}
	}

	// Acquire LocalTotals.
	GroupMemoryBarrierWithGroupSync();

	// Serial scan.
	if ( ThreadId < RAKING_THREAD_COUNT )
	{
		// Start from the total of all preceding digits plus the total of all
		// preceding raking threads within this digit (zero padding for the first).
		uint Offset = LocalTotals[RakingDigit + DIGIT_SCAN_PADDING - 1] + LocalRakingTotals[RakingIndex - 1];
		[unroll]
		for ( i = 0; i < COUNTERS_PER_RAKING_THREAD; ++i )
		{
			// Replace each counter with the running (exclusive) offset.
			uint Counter = LocalOffsets[BankToRakeOffset + i];
			LocalOffsets[BankToRakeOffset + i] = Offset;
			Offset += Counter;
		}
	}

	// Acquire LocalOffsets.
	GroupMemoryBarrierWithGroupSync();

	// Store offsets.
	if ( ThreadId < MAX_GROUP_COUNT )
	{
		for ( uint DigitIndex = 0; DigitIndex < DIGIT_COUNT; ++DigitIndex )
		{
			const uint BankIndex = DigitIndex * BANKS_PER_DIGIT + BankOffset;
			OutOffsets[ThreadId * DIGIT_COUNT + DigitIndex] = LocalOffsets[BankIndex * PADDED_BANK_SIZE + CounterOffset];
		}
	}
}

#endif // #if RADIX_SORT_SPINE

/*------------------------------------------------------------------------------
	The downsweep sorting kernel. This kernel reads the per-work group partial
	sums in to LocalTotals. The kernel then recomputes much of the work done
	during the upsweep, this time computing a full set of prefix sums so that
	keys can be scattered in to global memory.
------------------------------------------------------------------------------*/

#if RADIX_SORT_DOWNSWEEP

/** The amount of padding to add for scanning digits. */
#define DIGIT_SCAN_PADDING (DIGIT_COUNT >> 1)
/** The amount of storage required to store counters. */
#define MAX_COUNTER_STORAGE (BANKS_PER_DIGIT * DIGIT_COUNT * PADDED_BANK_SIZE)
/** The maximum amount of storage required for temporary keys. */
#define MAX_KEY_STORAGE (THREAD_COUNT * KEYS_PER_LOOP)
/** The amount of scratch storage required, aliased by counters and keys. */
#define SCRATCH_STORAGE (MAX_COUNTER_STORAGE > MAX_KEY_STORAGE ? MAX_COUNTER_STORAGE : MAX_KEY_STORAGE)

/**
 * The number of per-thread slots in LocalTotals. Every thread in the group
 * writes LocalTotals[ThreadId + DIGIT_SCAN_PADDING], so the array must cover
 * the whole group. 128 is kept as the minimum so existing configurations get
 * the same allocation as before.
 */
#if THREAD_COUNT > 128
	#define LOCAL_TOTALS_SLOTS (THREAD_COUNT)
#else
	#define LOCAL_TOTALS_SLOTS (128)
#endif

/** Input keys to be sorted. */
Buffer<uint> InKeys;
/** Input values to be transferred with keys. */
Buffer<uint> InValues;
/** Input offsets. */
Buffer<uint> InOffsets;
/** Output buffer for sorted keys. */
RWBuffer<uint> OutKeys;
/** Output buffer for sorted values. */
RWBuffer<uint> OutValues;

/** Local scratch storage for scattering. Should be SCRATCH_STORAGE elements, but DirectCompute doesn't like the ternary operator being used for array sizes. */
groupshared uint LocalScratch[MAX_COUNTER_STORAGE];
/** Local storage for raking totals. The extra slot is the one used by non-raking threads. */
groupshared uint LocalRakingTotals[RAKING_COUNTERS_PER_DIGIT * DIGIT_COUNT + 1];
/** Local storage for the per-digit totals. */
groupshared uint LocalTotals[LOCAL_TOTALS_SLOTS + DIGIT_SCAN_PADDING];
/** Local storage for the input offsets. */
groupshared uint LocalOffsets[DIGIT_COUNT];

/**
 * Obtain digits for the set of keys.
 * @param Keys - The keys held by this thread.
 * @param Digits - Receives the digit of each key for the current radix shift.
 */
void GetDigits( uint Keys[KEYS_PER_LOOP], out uint Digits[KEYS_PER_LOOP] )
{
	uint KeyIndex;

	[unroll]
	for ( KeyIndex = 0; KeyIndex < KEYS_PER_LOOP; ++KeyIndex )
	{
		Digits[KeyIndex] = (Keys[KeyIndex] >> RadixSortUB.RadixShift) & DIGIT_MASK;
	}
}

/**
 * Scan keys to compute ranks. Ranks are stored in LocalScratch.
 * Contains group-wide barriers, so it must be called by every thread of the group.
 * @param ThreadId - Index of the thread within its workgroup.
 * @param Digits - The digits of the keys held by this thread.
 * @param BankToRakeOffset - Result of GetBankToRakeOffset for this thread.
 * @param RakingDigit - Result of GetRakingDigit for this thread.
 * @param RakingIndex - Result of GetRakingIndex for this thread.
 * @param DigitTotalThisTile - For threads below DIGIT_COUNT, receives the number
 *        of keys in the tile whose digit equals ThreadId; zero for other threads.
 */
void ScanKeys(
	const uint ThreadId,
	uint Digits[KEYS_PER_LOOP],
	const uint BankToRakeOffset,
	const uint RakingDigit,
	const uint RakingIndex,
	out uint DigitTotalThisTile
	)
{
	uint i;
	uint KeyIndex;
	uint DigitIndex;
	uint RakingOffset;

	DigitTotalThisTile = 0;

	const uint BankOffset = ThreadId / WORDS_PER_BANK;
	const uint CounterOffset = ThreadId & (WORDS_PER_BANK - 1);

	// Clear local counters.
	for ( DigitIndex = 0; DigitIndex < DIGIT_COUNT; ++DigitIndex )
	{
		const uint BankIndex = DigitIndex * BANKS_PER_DIGIT + BankOffset;
		LocalScratch[BankIndex * PADDED_BANK_SIZE + CounterOffset] = 0;
	}

	// Increment counters for each key.
	[unroll]
	for ( KeyIndex = 0; KeyIndex < KEYS_PER_LOOP; ++KeyIndex )
	{
		const uint BankIndex = Digits[KeyIndex] * BANKS_PER_DIGIT + BankOffset;
		LocalScratch[BankIndex * PADDED_BANK_SIZE + CounterOffset] += 1;
	}

	// Acquire counters.
	GroupMemoryBarrierWithGroupSync();

	// Reduce.
	uint Total = 0;
	if ( ThreadId < RAKING_THREAD_COUNT )
	{
		// Perform a serial reduction on this raking thread's counters.
		Total = LocalScratch[BankToRakeOffset];
		[unroll]
		for ( i = 1; i < COUNTERS_PER_RAKING_THREAD; ++i )
		{
			Total += LocalScratch[BankToRakeOffset + i];
		}

		// Place the total in the raking counter.
		LocalRakingTotals[RakingIndex] = Total;
	}

	// Scan raking totals.
	// NOTE(review): the read and the write below are not separated by a barrier;
	// presumably this relies on the raking threads of a digit running in lockstep - confirm.
	[unroll]
	for ( RakingOffset = 1; RakingOffset < RAKING_THREADS_PER_DIGIT; RakingOffset <<= 1 )
	{
		// Acquire LocalRakingTotals.
		PPS_BARRIER();
		Total += LocalRakingTotals[RakingIndex - RakingOffset];
		LocalRakingTotals[RakingIndex] = Total;
	}

	// Acquire LocalRakingTotals.
	GroupMemoryBarrierWithGroupSync();

	// The last raking counter of each digit holds that digit's inclusive total.
	uint DigitSum = 0;
	if(ThreadId < DIGIT_COUNT)
	{
		DigitSum = LocalRakingTotals[RAKING_COUNTERS_PER_DIGIT * (ThreadId+1) - 1];
	}
	DigitTotalThisTile = DigitSum;
	LocalTotals[ThreadId + DIGIT_SCAN_PADDING] = DigitSum;

	// Scan local totals.
	[unroll]
	for ( RakingOffset = 1; RakingOffset < DIGIT_COUNT; RakingOffset <<= 1 )
	{
		// Acquire LocalTotals.
		PPS_BARRIER();
		DigitSum += LocalTotals[ThreadId + DIGIT_SCAN_PADDING - RakingOffset];
		// Every thread must finish reading before any thread overwrites its slot.
		GroupMemoryBarrierWithGroupSync();
		LocalTotals[ThreadId + DIGIT_SCAN_PADDING] = DigitSum;
	}

	// Acquire LocalTotals.
	GroupMemoryBarrierWithGroupSync();

	// Serial scan.
	if ( ThreadId < RAKING_THREAD_COUNT )
	{
		// Start from the total of all preceding digits plus the total of all
		// preceding raking threads within this digit (zero padding for the first).
		uint Offset = LocalTotals[RakingDigit + DIGIT_SCAN_PADDING - 1] + LocalRakingTotals[RakingIndex - 1];
		[unroll]
		for ( i = 0; i < COUNTERS_PER_RAKING_THREAD; ++i )
		{
			// Replace each counter with the running (exclusive) offset.
			uint Counter = LocalScratch[BankToRakeOffset + i];
			LocalScratch[BankToRakeOffset + i] = Offset;
			Offset += Counter;
		}
	}

	// Acquire ranks in LocalScratch.
	GroupMemoryBarrierWithGroupSync();
}

/**
 * Compute offsets for the given digits.
 * Contains a group-wide barrier, so it must be called by every thread of the group.
 * @param ThreadId - Index of the thread within its workgroup.
 * @param Digits - The digits of the keys held by this thread.
 * @param KeyOffsets - Receives the tile-local destination index of each key.
 */
void ComputeOffsets(
	const uint ThreadId,
	uint Digits[KEYS_PER_LOOP],
	out uint KeyOffsets[KEYS_PER_LOOP]
	)
{
	uint KeyIndex;

	const uint BankOffset = ThreadId / WORDS_PER_BANK;
	const uint CounterOffset = ThreadId & (WORDS_PER_BANK - 1);

	// Compute key offsets. Each key takes the current rank of its digit's counter
	// and bumps it so the thread's next key with the same digit lands one slot later.
	[unroll]
	for ( KeyIndex = 0; KeyIndex < KEYS_PER_LOOP; ++KeyIndex )
	{
		const uint BankIndex = Digits[KeyIndex] * BANKS_PER_DIGIT + BankOffset;
		const uint Offset = LocalScratch[BankIndex * PADDED_BANK_SIZE + CounterOffset];
		KeyOffsets[KeyIndex] = Offset;
		LocalScratch[BankIndex * PADDED_BANK_SIZE + CounterOffset] = Offset + 1;
	}

	// Release LocalScratch so we can scatter keys in to that memory.
	GroupMemoryBarrierWithGroupSync();
}

/**
 * Scatter to local memory.
 * Contains a group-wide barrier, so it must be called by every thread of the group.
 * @param ThreadId - Index of the thread within its workgroup (unused).
 * @param Values - The elements held by this thread.
 * @param Offsets - The LocalScratch index at which to store each element.
 */
void ScatterLocal(
	const uint ThreadId,
	uint Values[KEYS_PER_LOOP],
	uint Offsets[KEYS_PER_LOOP]
	)
{
	uint KeyIndex;

	// Scatter keys to local memory.
	[unroll]
	for ( KeyIndex = 0; KeyIndex < KEYS_PER_LOOP; ++KeyIndex )
	{
		LocalScratch[Offsets[KeyIndex]] = Values[KeyIndex];
	}

	// Acquire values in LocalScratch.
	GroupMemoryBarrierWithGroupSync();
}

/**
 * Downsweep scan and scatter kernel.
 * For each tile assigned to the group, ranks the keys by digit, reorders keys
 * and values through LocalScratch, and writes them to OutKeys / OutValues at
 * the positions given by the scanned offsets in InOffsets.
 */
[numthreads(THREAD_COUNT,1,1)]
void RadixSort_Downsweep(
	uint3 GroupThreadId : SV_GroupThreadID,
	uint3 GroupIdXYZ : SV_GroupID )
{
	uint Keys[KEYS_PER_LOOP];
	uint KeyOffsets[KEYS_PER_LOOP];
	uint Digits[KEYS_PER_LOOP];
	uint DigitTotalThisTile = 0;
	uint FirstTileIndex;
	uint TileCountForGroup;
	uint KeyIndex;
	uint i;

	// Determine global and local thread IDs.
	const uint ThreadId = GroupThreadId.x;
	const uint GroupId = GroupIdXYZ.x;
	const uint BankOffset = ThreadId / WORDS_PER_BANK;
	const uint CounterOffset = ThreadId & (WORDS_PER_BANK - 1);
	const uint BankToRakeOffset = GetBankToRakeOffset( ThreadId );
	const uint RakingDigit = GetRakingDigit( ThreadId );
	const uint RakingIndex = GetRakingIndex( ThreadId );

	// Clear the raking counter padding.
	if ( ThreadId < RAKING_THREAD_COUNT )
	{
		[unroll]
		for ( i = 1; i <= RAKING_COUNTER_PADDING; ++i )
		{
			LocalRakingTotals[RakingIndex - i] = 0;
		}
	}

	// Read work group prefix sums in to local memory.
	if ( ThreadId < DIGIT_COUNT )
	{
		LocalOffsets[ThreadId] = InOffsets[GroupId * DIGIT_COUNT + ThreadId];
	}

	// Clear the padding on local totals.
	if ( ThreadId < DIGIT_SCAN_PADDING )
	{
		LocalTotals[ThreadId] = 0;
	}

	// The number of tiles to process in this group.
	// The first ExtraTileCount groups each take one additional tile.
	if ( GroupId < RadixSortUB.ExtraTileCount )
	{
		FirstTileIndex = GroupId * (RadixSortUB.TilesPerGroup + 1);
		TileCountForGroup = RadixSortUB.TilesPerGroup + 1;
	}
	else
	{
		FirstTileIndex = GroupId * RadixSortUB.TilesPerGroup + RadixSortUB.ExtraTileCount;
		TileCountForGroup = RadixSortUB.TilesPerGroup;
	}

	// Key range for this group.
	uint GroupKeyBegin = FirstTileIndex * TILE_SIZE;
	uint GroupKeyEnd = GroupKeyBegin + TileCountForGroup * TILE_SIZE;

	// Process each tile sequentially.
	while ( GroupKeyBegin < GroupKeyEnd )
	{
		// Load keys.
		[unroll]
		for ( KeyIndex = 0; KeyIndex < KEYS_PER_LOOP; ++KeyIndex )
		{
			Keys[KeyIndex] = InKeys[GroupKeyBegin + ThreadId * KEYS_PER_LOOP + KeyIndex];
		}

		// Scan keys and compute offsets.
		GetDigits( Keys, Digits );
		ScanKeys( ThreadId, Digits, BankToRakeOffset, RakingDigit, RakingIndex, DigitTotalThisTile );
		ComputeOffsets( ThreadId, Digits, KeyOffsets );

		// Scatter to local memory.
		ScatterLocal( ThreadId, Keys, KeyOffsets );

		// Obtain coalesced keys.
		[unroll]
		for ( KeyIndex = 0; KeyIndex < KEYS_PER_LOOP; ++KeyIndex )
		{
			Keys[KeyIndex] = LocalScratch[ThreadId * KEYS_PER_LOOP + KeyIndex];
		}
		GetDigits( Keys, Digits );

		// Scatter keys to global memory.
		// Destination = the group's running offset for the digit, plus the key's
		// position within the tile, minus the tile-local start of that digit.
		[unroll]
		for ( KeyIndex = 0; KeyIndex < KEYS_PER_LOOP; ++KeyIndex )
		{
			const uint Digit = Digits[KeyIndex];
			const uint GlobalScatterIndex = LocalOffsets[Digit] + ThreadId * KEYS_PER_LOOP + KeyIndex - LocalTotals[Digit + DIGIT_SCAN_PADDING - 1];
			OutKeys[GlobalScatterIndex] = Keys[KeyIndex];
		}

		// Load values. The Keys array is reused to hold them.
		[unroll]
		for ( KeyIndex = 0; KeyIndex < KEYS_PER_LOOP; ++KeyIndex )
		{
			Keys[KeyIndex] = InValues[GroupKeyBegin + ThreadId * KEYS_PER_LOOP + KeyIndex];
		}

		// Release LocalScratch so values can be scattered to local memory.
		GroupMemoryBarrierWithGroupSync();

		// Scatter to local memory.
		ScatterLocal( ThreadId, Keys, KeyOffsets );

		// Obtain coalesced values.
		[unroll]
		for ( KeyIndex = 0; KeyIndex < KEYS_PER_LOOP; ++KeyIndex )
		{
			Keys[KeyIndex] = LocalScratch[ThreadId * KEYS_PER_LOOP + KeyIndex];
		}

		// Scatter values to global memory.
		// Digits still holds the digits of the coalesced keys at the same positions.
		[unroll]
		for ( KeyIndex = 0; KeyIndex < KEYS_PER_LOOP; ++KeyIndex )
		{
			const uint Digit = Digits[KeyIndex];
			const uint GlobalScatterIndex = LocalOffsets[Digit] + ThreadId * KEYS_PER_LOOP + KeyIndex - LocalTotals[Digit + DIGIT_SCAN_PADDING - 1];
			OutValues[GlobalScatterIndex] = Keys[KeyIndex];
		}

		// Release LocalOffsets so it can be updated.
		GroupMemoryBarrierWithGroupSync();

		// Increment local offsets so the next tile has correct offsets.
		if ( ThreadId < DIGIT_COUNT )
		{
			LocalOffsets[ThreadId] += DigitTotalThisTile;
		}

		GroupKeyBegin += TILE_SIZE;
	}

	// The last group has to process any keys after the last tile.
	if ( GroupId == (RadixSortUB.GroupCount - 1) && RadixSortUB.ExtraKeyCount > 0 )
	{
		// Load keys. Slots past the end of the input are filled with 0xFFFFFFFF.
		[unroll]
		for ( KeyIndex = 0; KeyIndex < KEYS_PER_LOOP; ++KeyIndex )
		{
			if ( ThreadId * KEYS_PER_LOOP + KeyIndex < RadixSortUB.ExtraKeyCount )
			{
				Keys[KeyIndex] = InKeys[GroupKeyBegin + ThreadId * KEYS_PER_LOOP + KeyIndex];
			}
			else
			{
				Keys[KeyIndex] = 0xFFFFFFFF;
			}
		}

		// Scan keys and compute offsets.
		GetDigits( Keys, Digits );
		ScanKeys( ThreadId, Digits, BankToRakeOffset, RakingDigit, RakingIndex, DigitTotalThisTile );
		ComputeOffsets( ThreadId, Digits, KeyOffsets );

		// Scatter to local memory.
		ScatterLocal( ThreadId, Keys, KeyOffsets );

		// Obtain coalesced keys.
		[unroll]
		for ( KeyIndex = 0; KeyIndex < KEYS_PER_LOOP; ++KeyIndex )
		{
			Keys[KeyIndex] = LocalScratch[ThreadId * KEYS_PER_LOOP + KeyIndex];
		}
		GetDigits( Keys, Digits );

		// Scatter keys to global memory. Only the first ExtraKeyCount positions are written.
		[unroll]
		for ( KeyIndex = 0; KeyIndex < KEYS_PER_LOOP; ++KeyIndex )
		{
			if ( ThreadId * KEYS_PER_LOOP + KeyIndex < RadixSortUB.ExtraKeyCount )
			{
				const uint Digit = Digits[KeyIndex];
				const uint GlobalScatterIndex = LocalOffsets[Digit] + ThreadId * KEYS_PER_LOOP + KeyIndex - LocalTotals[Digit + DIGIT_SCAN_PADDING - 1];
				OutKeys[GlobalScatterIndex] = Keys[KeyIndex];
			}
		}

		// Load values. Slots past the end of the input get a recognizable filler.
		[unroll]
		for ( KeyIndex = 0; KeyIndex < KEYS_PER_LOOP; ++KeyIndex )
		{
			if ( ThreadId * KEYS_PER_LOOP + KeyIndex < RadixSortUB.ExtraKeyCount )
			{
				Keys[KeyIndex] = InValues[GroupKeyBegin + ThreadId * KEYS_PER_LOOP + KeyIndex];
			}
			else
			{
				Keys[KeyIndex] = 0x0BADF00D;
			}
		}

		// Release LocalScratch so values can be scattered to local memory.
		GroupMemoryBarrierWithGroupSync();

		// Scatter to local memory.
		ScatterLocal( ThreadId, Keys, KeyOffsets );

		// Obtain coalesced values.
		[unroll]
		for ( KeyIndex = 0; KeyIndex < KEYS_PER_LOOP; ++KeyIndex )
		{
			Keys[KeyIndex] = LocalScratch[ThreadId * KEYS_PER_LOOP + KeyIndex];
		}

		// Scatter values to global memory. Only the first ExtraKeyCount positions are written.
		[unroll]
		for ( KeyIndex = 0; KeyIndex < KEYS_PER_LOOP; ++KeyIndex )
		{
			if ( ThreadId * KEYS_PER_LOOP + KeyIndex < RadixSortUB.ExtraKeyCount )
			{
				const uint Digit = Digits[KeyIndex];
				const uint GlobalScatterIndex = LocalOffsets[Digit] + ThreadId * KEYS_PER_LOOP + KeyIndex - LocalTotals[Digit + DIGIT_SCAN_PADDING - 1];
				OutValues[GlobalScatterIndex] = Keys[KeyIndex];
			}
		}
	}
}

#endif // #if RADIX_SORT_DOWNSWEEP