// UnrealEngine/Engine/Shaders/Private/WorkDistribution.ush
// Copyright Epic Games, Inc. All Rights Reserved.

#pragma once

#include "/Engine/Public/Platform.ush"
#include "WaveOpUtil.ush"

/*
Provides functions to distribute uneven amounts of work uniformly across a wave.
Work won't be distributed any wider than the wave that generated it.

The following must be defined:

void DoWork( FWorkContext Context, FWorkSourceType WorkSource, uint LocalItemIndex );
*/
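
/*
A minimal usage sketch (illustrative only: FWorkContext, FWorkSourceType, and
ProcessItem below are hypothetical placeholders that the including shader would
define; they are not part of this file):

	struct FWorkContext		{ uint ClusterIndex; };
	struct FWorkSourceType	{ uint NodeIndex; };

	void DoWork( FWorkContext Context, FWorkSourceType WorkSource, uint LocalItemIndex )
	{
		// Runs once per work item, with items from all sources packed densely across the wave.
		ProcessItem( Context.ClusterIndex, WorkSource.NodeIndex, LocalItemIndex );
	}
*/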
#ifdef GENERATE_WORK
/*
This version can continuously generate work using the function:

uint GenerateWork( FWorkContext Context, uint GroupIndex, inout FWorkSourceType WorkSource, inout bool bDone )
{
	Set WorkSource if there is a valid source of work.
	if( No more work is left on this thread )
	{
		bDone = true;
	}
	return NumWorkItems;
}

Once a full wave's worth of work has been queued, it is consumed.
*/
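
/*
A hedged sketch of a GenerateWork implementation (illustrative: CandidateNodes,
NodeCursor, Context.NumNodes, and GetChildCount are assumed names, not part of
this file):

	uint GenerateWork( FWorkContext Context, uint GroupIndex, inout FWorkSourceType WorkSource, inout bool bDone )
	{
		if( bDone )
			return 0;	// This lane is still called while other lanes keep generating.

		uint NodeIndex;
		InterlockedAdd( NodeCursor[0], 1, NodeIndex );	// Claim the next candidate node.
		if( NodeIndex >= Context.NumNodes )
		{
			bDone = true;	// No sources left on this thread.
			return 0;
		}

		WorkSource.NodeIndex = CandidateNodes[ NodeIndex ];
		return GetChildCount( WorkSource.NodeIndex );	// Its children become this wave's work items.
	}
*/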
groupshared FWorkSourceType	WorkQueueSource[ THREADGROUP_SIZE * 2 ];
groupshared uint			WorkQueueAccum[ THREADGROUP_SIZE * 2 ];
groupshared uint			WorkBoundary[ THREADGROUP_SIZE ];

void DistributeWork( FWorkContext Context, uint GroupIndex )
{
	const uint LaneCount = WaveGetLaneCount();
	const uint LaneIndex = GroupIndex & ( LaneCount - 1 );
	const uint QueueOffset = GroupIndex & ~( LaneCount - 1 );
	const uint QueueSize = LaneCount * 2;
	const uint QueueMask = QueueSize - 1;

#define QUEUE_INDEX(i) ( QueueOffset*2 + ( (i) & QueueMask ) )

	bool bDone = false;

	int WorkRead = 0;
	int WorkWrite = 0;
	int SourceRead = 0;
	int SourceWrite = 0;

	// Initialize the wraparound slot read by QUEUE_INDEX( QueueIndex - 1 ) when
	// QueueIndex == 0, so the first source's FirstWorkItem resolves to 0.
	WorkQueueAccum[ QueueOffset*2 + QueueMask ] = 0;

	while( true )
	{
		// Need to queue more work to fill the wave?
		while( WorkWrite - WorkRead < LaneCount && WaveActiveAnyTrue( !bDone ) )
		{
			FWorkSourceType WorkSource;

			// Generate work and record the source.
			// When sources run out, set bDone = true.
			uint NumWorkItems = GenerateWork( Context, GroupIndex, WorkSource, bDone );

			// Queue work
			uint FirstWorkItem = WorkWrite + WavePrefixSum( NumWorkItems );
			uint WorkAccum = FirstWorkItem + NumWorkItems;	// Could use an inclusive sum instead.

			WorkWrite = WaveReadLaneAt( WorkAccum, LaneCount - 1 );

			bool bHasWork = NumWorkItems != 0;
			uint QueueIndex = SourceWrite + WavePrefixCountBits( bHasWork );
			if( bHasWork )
			{
				WorkQueueSource[ QUEUE_INDEX( QueueIndex ) ] = WorkSource;
				WorkQueueAccum[ QUEUE_INDEX( QueueIndex ) ] = WorkAccum;
			}
			SourceWrite += WaveActiveCountBits( bHasWork );
		}

		// Any work left?
		if( WorkWrite == WorkRead )
			break;

		// TODO: read and write bytes instead (ds_write_b8, ds_read_u8_d16)
		WorkBoundary[ GroupIndex ] = 0;

		GroupMemoryBarrier();

		if( SourceRead + LaneIndex < SourceWrite )
		{
			// Mark the last work item of each source.
			uint LastItemIndex = WorkQueueAccum[ QUEUE_INDEX( SourceRead + LaneIndex ) ] - WorkRead - 1;
			if( LastItemIndex < LaneCount )
				WorkBoundary[ QueueOffset + LastItemIndex ] = 1;
		}

		GroupMemoryBarrier();

		bool bIsBoundary = WorkBoundary[ GroupIndex ];
		uint QueueIndex = SourceRead + WavePrefixCountBits( bIsBoundary );

		// Distribute work
		if( WorkRead + LaneIndex < WorkWrite )
		{
			uint FirstWorkItem = WorkQueueAccum[ QUEUE_INDEX( QueueIndex - 1 ) ];
			uint LocalItemIndex = WorkRead + LaneIndex - FirstWorkItem;

			FWorkSourceType WorkSource = WorkQueueSource[ QUEUE_INDEX( QueueIndex ) ];

			DoWork( Context, WorkSource, LocalItemIndex );
		}

		// Did 1 wave of work.
		WorkRead = min( WorkRead + LaneCount, WorkWrite );
		SourceRead += WaveActiveCountBits( bIsBoundary );
	}

#undef QUEUE_INDEX
}
#elif 0
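// Returns true on a lane whose SrcIndex bit was flagged by any lane's DstIndex.
// Each lane contributes the bit for its DstIndex to a wave-wide OR (a ballot),
// then tests the bit for its own SrcIndex; waves wider than 32 lanes use a
// uint2 ballot split at lane 32.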
bool WaveFlagLaneAt( uint DstIndex, uint SrcIndex )
{
	const uint LaneCount = WaveGetLaneCount();

	uint DstMask = 1 << ( DstIndex & 31 );
	uint SrcMask = 1 << ( SrcIndex & 31 );

	DstMask = DstIndex < LaneCount ? DstMask : 0;

	if( LaneCount > 32 )
	{
		bool bDstLow = DstIndex < 32;
		bool bSrcLow = SrcIndex < 32;

		uint2 WaveBits = 0;
		WaveBits.x = WaveActiveBitOr( bDstLow ? DstMask : 0 );
		WaveBits.y = WaveActiveBitOr( bDstLow ? 0 : DstMask );

		return SrcMask & ( bSrcLow ? WaveBits.x : WaveBits.y );
	}
	else
	{
		return WaveActiveBitOr( DstMask ) & SrcMask;
	}
}
// Simpler version where threads can only generate work once.
// The work is generated before calling DistributeWork and passed in directly,
// so a GenerateWork function doesn't need to be defined.

groupshared FWorkSourceType	WorkQueueSource[ THREADGROUP_SIZE ];
groupshared uint			WorkQueueAccum[ THREADGROUP_SIZE ];
groupshared uint			WorkBoundary[ THREADGROUP_SIZE ];

void DistributeWork( FWorkContext Context, uint GroupIndex, FWorkSourceType WorkSource, uint NumWorkItems )
{
	const uint LaneCount = WaveGetLaneCount();
	const uint LaneIndex = GroupIndex & ( LaneCount - 1 );
	const uint QueueOffset = GroupIndex & ~( LaneCount - 1 );

	int WorkRead = 0;
	int WorkWrite = 0;
	int SourceRead = 0;

	uint WorkAccum = 0;

	if( WaveActiveAnyTrue( NumWorkItems != 0 ) )
	{
		// Queue work
		uint FirstWorkItem = WavePrefixSum( NumWorkItems );
		WorkAccum = FirstWorkItem + NumWorkItems;	// Could use an inclusive sum instead.

		WorkWrite = WaveReadLaneAt( WorkAccum, LaneCount - 1 );

		bool bHasWork = NumWorkItems != 0;
		uint QueueIndex = WavePrefixCountBits( bHasWork );
		if( bHasWork )
		{
			WorkQueueSource[ QueueOffset + QueueIndex ] = WorkSource;
			WorkQueueAccum[ QueueOffset + QueueIndex ] = WorkAccum;
		}
	}

	// Pull work from the queue
	while( WorkRead < WorkWrite )
	{
		// TODO: read and write bytes instead (ds_write_b8, ds_read_u8_d16)
		WorkBoundary[ GroupIndex ] = 0;

		GroupMemoryBarrier();

		// Mark the last work item of each source.
		uint LastItemIndex = WorkAccum - WorkRead - 1;
		if( LastItemIndex < LaneCount )
			WorkBoundary[ QueueOffset + LastItemIndex ] = 1;

		GroupMemoryBarrier();

		bool bIsBoundary = WorkBoundary[ GroupIndex ];
		uint QueueIndex = SourceRead + WavePrefixCountBits( bIsBoundary );

		if( WorkRead + LaneIndex < WorkWrite )
		{
			uint FirstWorkItem = QueueIndex > 0 ? WorkQueueAccum[ QueueOffset + QueueIndex - 1 ] : 0;
			uint LocalItemIndex = WorkRead + LaneIndex - FirstWorkItem;

			FWorkSourceType WorkSource = WorkQueueSource[ QueueOffset + QueueIndex ];

			DoWork( Context, WorkSource, LocalItemIndex );
		}

		// Did 1 wave of work.
		WorkRead += LaneCount;
		SourceRead += WaveActiveCountBits( bIsBoundary );
	}
}
#else
groupshared uint WorkBoundary[ THREADGROUP_SIZE ];
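
// Templated version: each lane is a potential work source, and child tasks are
// built on demand. The FTask interface below is inferred from the call sites in
// this file (exact semantics depend on the task type):
//
//	FTask CreateChild( uint SourceIndex );								// Build the task for the child generated by lane SourceIndex.
//	void RunChild( FTask Parent, bool bActive, uint LocalItemIndex );	// bActive is false on lanes past the end of the work.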
template< typename FTask >
void DistributeWork( FTask Task, uint GroupIndex, uint NumWorkItems )
{
	const uint LaneCount = WaveGetLaneCount();
	const uint LaneIndex = GroupIndex & ( LaneCount - 1 );
	const uint QueueOffset = GroupIndex & ~( LaneCount - 1 );

	int WorkHead = 0;
	int WorkTail = 0;
	int SourceHead = 0;

	uint WorkSource = LaneIndex;
	uint WorkAccum = 0;

	if( WaveActiveAnyTrue( NumWorkItems != 0 ) )
	{
		// Queue work
		uint FirstWorkItem = WavePrefixSum( NumWorkItems );
		WorkAccum = FirstWorkItem + NumWorkItems;	// Could use Inclusive sum instead.

		WorkTail = WaveReadLaneAt( WorkAccum, LaneCount - 1 );

		bool bHasWork = NumWorkItems != 0;
		uint QueueIndex = WavePrefixCountBits( bHasWork );

		// Compact
		if( WaveActiveAnyTrue( NumWorkItems == 0 ) )	// Might know this is impossible
		{
			// Compact LaneIndex
#if 0	// COMPILER_SUPPORTS_WAVE_PERMUTE
			QueueIndex = bHasWork ? QueueIndex : LaneCount - 1;
			WorkSource = WaveWriteLaneAt( QueueIndex, LaneIndex );
#else
			if( bHasWork )
				WorkBoundary[ QueueOffset + QueueIndex ] = LaneIndex;

			GroupMemoryBarrier();

			WorkSource = WorkBoundary[ GroupIndex ];
#endif
			WorkAccum = WaveReadLaneAt( WorkAccum, WorkSource );

			// Push invalid lanes off the end to prevent writes to WorkBoundary and bank conflicts.
			if( LaneIndex >= WaveActiveCountBits( bHasWork ) )
				WorkAccum = WorkTail + LaneCount;
		}
	}

	// Pull work from queue
	while( WorkHead < WorkTail )
	{
		// TODO read and write bytes instead (ds_write_b8, ds_read_u8_d16)
		WorkBoundary[ GroupIndex ] = 0;

		GroupMemoryBarrier();

		// Mark the last work item of each source
		uint LastItemIndex = WorkAccum - WorkHead - 1;
		if( LastItemIndex < LaneCount )
			WorkBoundary[ QueueOffset + LastItemIndex ] = 1;

		GroupMemoryBarrier();

		bool bIsBoundary = WorkBoundary[ GroupIndex ];
		uint QueueIndex = SourceHead + WavePrefixCountBits( bIsBoundary );
		uint SourceIndex = WaveReadLaneAt( WorkSource, QueueIndex );
		uint FirstWorkItem = select( QueueIndex > 0, WaveReadLaneAt( WorkAccum, QueueIndex - 1 ), 0 );
		uint LocalItemIndex = WorkHead + LaneIndex - FirstWorkItem;

		FTask ChildTask = Task.CreateChild( SourceIndex );

		bool bActive = ( WorkHead + LaneIndex < WorkTail );
		ChildTask.RunChild( Task, bActive, LocalItemIndex );

		// Did 1 wave of work
		WorkHead += LaneCount;
		SourceHead += WaveActiveCountBits( bIsBoundary );
	}
}
#endif
#if 1

struct FWorkQueueState
{
	uint	ReadOffset;
	uint	WriteOffset;
	int		TaskCount;		// Can temporarily be conservatively higher
};
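
// FOutputQueue and FInputQueue are the producer and consumer ends of the same
// persistent queue: one pass appends entries via FOutputQueue.Add(), and a
// later pass drains them via FInputQueue.Remove(). StateIndex selects which
// FWorkQueueState slot in StateBuffer tracks this queue's offsets.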
struct FOutputQueue
{
	RWByteAddressBuffer						DataBuffer;
	RWStructuredBuffer< FWorkQueueState >	StateBuffer;	// Ideally this would live in GDS, but we don't have that level of API control.

	uint	StateIndex;
	uint	Size;

	uint Add()
	{
		uint WriteOffset;
		WaveInterlockedAddScalar_( StateBuffer[ StateIndex ].WriteOffset, 1, WriteOffset );
		// TODO: Copy WriteOffset to TaskCount
		WaveInterlockedAddScalar( StateBuffer[ StateIndex ].TaskCount, 1 );
		return WriteOffset;
	}

	uint DataBuffer_Load( uint Address )
	{
		return DataBuffer.Load( Address );
	}

	uint4 DataBuffer_Load4( uint Address )
	{
		return DataBuffer.Load4( Address );
	}

	void DataBuffer_Store4( uint Address, uint4 Values )
	{
		DataBuffer.Store4( Address, Values );
	}

	FWorkQueueState GetState( uint Index )
	{
		return StateBuffer[ Index ];
	}
};
struct FInputQueue
{
	ByteAddressBuffer						DataBuffer;
	RWStructuredBuffer< FWorkQueueState >	StateBuffer;	// Ideally this would live in GDS, but we don't have that level of API control.

	uint	StateIndex;
	uint	Size;

	uint Remove()
	{
		uint ReadOffset;
		WaveInterlockedAddScalar_( StateBuffer[ StateIndex ].ReadOffset, 1, ReadOffset );
		return ReadOffset;
	}

	uint Num()
	{
		return StateBuffer[ StateIndex ].WriteOffset;
	}
};
struct FGlobalWorkQueue
{
#if COMPILER_NEEDS_GLOBALLYCOHERENT_LOCALS
	RWCoherentByteAddressBufferRef						DataBuffer_Private;
	RWCoherentStructuredBufferRef( FWorkQueueState )	StateBuffer_Private;	// Ideally this would live in GDS, but we don't have that level of API control.

	RWCoherentByteAddressBuffer GetDataBuffer()	{ return (RWCoherentByteAddressBuffer)DataBuffer_Private; }
	RWCoherentStructuredBuffer( FWorkQueueState ) GetStateBuffer()	{ return (RWCoherentStructuredBuffer( FWorkQueueState ))StateBuffer_Private; }
#else
	RWCoherentByteAddressBufferRef						DataBuffer;
	RWCoherentStructuredBufferRef( FWorkQueueState )	StateBuffer;	// Ideally this would live in GDS, but we don't have that level of API control.
#endif

	uint	StateIndex;
	uint	Size;

	uint Add()
	{
#if COMPILER_NEEDS_GLOBALLYCOHERENT_LOCALS
		RWCoherentStructuredBuffer( FWorkQueueState ) StateBuffer = GetStateBuffer();
#endif
		uint WriteCount = WaveActiveCountBits( true );
		uint WriteOffset = 0;
		if( WaveIsFirstLane() )
		{
			InterlockedAdd( StateBuffer[ StateIndex ].WriteOffset, WriteCount, WriteOffset );
			InterlockedAdd( StateBuffer[ StateIndex ].TaskCount, (int)WriteCount );
		}
		return WaveReadLaneFirst( WriteOffset ) + WavePrefixCountBits( true );
	}

	uint Remove()
	{
#if COMPILER_NEEDS_GLOBALLYCOHERENT_LOCALS
		RWCoherentStructuredBuffer( FWorkQueueState ) StateBuffer = GetStateBuffer();
#endif
		uint ReadOffset;
		WaveInterlockedAddScalar_( StateBuffer[ StateIndex ].ReadOffset, 1, ReadOffset );
		return ReadOffset;
	}

	// Only call this after the current task has completely finished adding work!
	void ReleaseTask()
	{
#if COMPILER_NEEDS_GLOBALLYCOHERENT_LOCALS
		RWCoherentStructuredBuffer( FWorkQueueState ) StateBuffer = GetStateBuffer();
#endif
		WaveInterlockedAddScalar( StateBuffer[ StateIndex ].TaskCount, -1 );
	}

	bool IsEmpty()
	{
#if COMPILER_NEEDS_GLOBALLYCOHERENT_LOCALS
		RWCoherentStructuredBuffer( FWorkQueueState ) StateBuffer = GetStateBuffer();
#endif
#if 1
		return StateBuffer[ StateIndex ].TaskCount == 0;
#else
		uint Count = 0;
		if( WaveIsFirstLane() )
		{
			InterlockedAdd( StateBuffer[ StateIndex ].TaskCount, 0, Count );
		}
		return WaveReadLaneFirst( Count ) == 0;
#endif
	}

	uint DataBuffer_Load( uint Address )
	{
#if COMPILER_NEEDS_GLOBALLYCOHERENT_LOCALS
		RWCoherentByteAddressBuffer DataBuffer = GetDataBuffer();
#endif
		return DataBuffer.Load( Address );
	}

	uint4 DataBuffer_Load4( uint Address )
	{
#if COMPILER_NEEDS_GLOBALLYCOHERENT_LOCALS
		RWCoherentByteAddressBuffer DataBuffer = GetDataBuffer();
#endif
		return DataBuffer.Load4( Address );
	}

	void DataBuffer_Store4( uint Address, uint4 Values )
	{
#if COMPILER_NEEDS_GLOBALLYCOHERENT_LOCALS
		RWCoherentByteAddressBuffer DataBuffer = GetDataBuffer();
#endif
		DataBuffer.Store4( Address, Values );
	}

	FWorkQueueState GetState( uint Index )
	{
#if COMPILER_NEEDS_GLOBALLYCOHERENT_LOCALS
		RWCoherentStructuredBuffer( FWorkQueueState ) StateBuffer = GetStateBuffer();
#endif
		return StateBuffer[ Index ];
	}
};
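
// Persistent-threads task loop: every wave atomically claims a slot with
// Remove(), spins (DeviceMemoryBarrier + ShaderYield) until the producer's data
// for that slot becomes visible, runs the task, and exits once TaskCount drains
// to zero. The FTask interface, inferred from the call sites below:
//
//	bool Load( FGlobalWorkQueue Queue, uint Offset );	// Returns true once the slot's data is ready.
//	void Run( FGlobalWorkQueue Queue );					// Does the work; presumably calls Add() for new tasks, then ReleaseTask().
//	void Clear( FGlobalWorkQueue Queue, uint Offset );	// Re-zeroes the slot for the next pass.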
template< typename FTask >
void GlobalTaskLoop( FGlobalWorkQueue GlobalWorkQueue )
{
	bool bTaskComplete = true;

	uint TaskReadOffset = 0;

	while( true )
	{
		if( WaveActiveAllTrue( bTaskComplete ) )
		{
			TaskReadOffset = GlobalWorkQueue.Remove();
			bTaskComplete = TaskReadOffset >= GlobalWorkQueue.Size;
			if( WaveActiveAllTrue( bTaskComplete ) )
				break;
		}

		FTask Task;
		bool bTaskReady = false;
		if( !bTaskComplete )
		{
			bTaskReady = Task.Load( GlobalWorkQueue, TaskReadOffset );
		}

		if( WaveActiveAnyTrue( bTaskReady ) )
		{
			if( bTaskReady )
			{
				Task.Run( GlobalWorkQueue );

				// Clear the processed element so we leave the buffer cleared for the next pass.
				Task.Clear( GlobalWorkQueue, TaskReadOffset );

				bTaskComplete = true;
			}
		}
		else
		{
			if( GlobalWorkQueue.IsEmpty() )
				break;
			else
			{
				DeviceMemoryBarrier();
				ShaderYield();
			}
		}
	}
}
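
// Variant for tasks that spawn a variable number of children: Run() returns the
// child count, and DistributeWork spreads those children evenly across the wave
// via FTask::CreateChild / RunChild.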
template< typename FTask >
void GlobalTaskLoopVariable( FGlobalWorkQueue GlobalWorkQueue, uint GroupIndex )
{
	bool bTaskComplete = true;

	uint TaskReadOffset = 0;

	while( true )
	{
		if( WaveActiveAllTrue( bTaskComplete ) )
		{
			TaskReadOffset = GlobalWorkQueue.Remove();
			bTaskComplete = TaskReadOffset >= GlobalWorkQueue.Size;
			if( WaveActiveAllTrue( bTaskComplete ) )
				break;
		}

		FTask Task = (FTask)0;
		bool bTaskReady = false;
		if( !bTaskComplete )
		{
			bTaskReady = Task.Load( GlobalWorkQueue, TaskReadOffset );
		}

		if( WaveActiveAnyTrue( bTaskReady ) )
		{
			uint NumChildren = 0;
			if( bTaskReady )
			{
				NumChildren = Task.Run();
			}

			DistributeWork( Task, GroupIndex, NumChildren );

			if( bTaskReady )
			{
				// Clear the processed element so we leave the buffer cleared for the next pass.
				Task.Clear( GlobalWorkQueue, TaskReadOffset );

				bTaskComplete = true;
			}
		}
		else
		{
			if( GlobalWorkQueue.IsEmpty() )
				break;
			else
			{
				DeviceMemoryBarrier();
				ShaderYield();
			}
		}
	}
}

#endif
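
/*
A minimal, hypothetical FTask sketch for GlobalTaskLoop (the name, slot layout,
and the "zero means empty" readiness convention are illustrative assumptions,
not part of this file). Each queue entry here is one uint4 slot in DataBuffer:

	struct FExampleTask
	{
		uint4 Payload;

		bool Load( FGlobalWorkQueue GlobalWorkQueue, uint Offset )
		{
			Payload = GlobalWorkQueue.DataBuffer_Load4( Offset * 16 );
			return Payload.x != 0;	// Ready once the producer's store has landed.
		}

		uint Run( FGlobalWorkQueue GlobalWorkQueue )
		{
			// Process Payload. New tasks would be appended by reserving slots
			// with GlobalWorkQueue.Add() and storing their data, after which
			// ReleaseTask() retires this task's TaskCount reference.
			GlobalWorkQueue.ReleaseTask();
			return 0;
		}

		void Clear( FGlobalWorkQueue GlobalWorkQueue, uint Offset )
		{
			// Re-zero the slot so the buffer is left cleared for the next pass.
			GlobalWorkQueue.DataBuffer_Store4( Offset * 16, uint4( 0, 0, 0, 0 ) );
		}
	};
*/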