// Copyright Epic Games, Inc. All Rights Reserved. #include "D3D12WorkGraph.h" #include "Async/ParallelFor.h" #include "Containers/DynamicRHIResourceArray.h" #include "D3D12ExplicitDescriptorCache.h" #include "D3D12RHIPrivate.h" #include "D3D12ResourceCollection.h" #include "D3D12Shader.h" #include "D3D12TextureReference.h" #include "PipelineStateCache.h" #include "ShaderBundles.h" #include "RHIUniformBufferUtilities.h" static bool GShaderBundleSkipDispatch = false; static FAutoConsoleVariableRef CVarShaderBundleSkipDispatch( TEXT("wg.ShaderBundle.SkipDispatch"), GShaderBundleSkipDispatch, TEXT("Whether to dispatch the built shader bundle pipeline (for debugging)"), ECVF_RenderThreadSafe ); FD3D12WorkGraphPipelineState::FD3D12WorkGraphPipelineState(FD3D12Device* Device, const FWorkGraphPipelineStateInitializer& Initializer) : Device(Device) { #if D3D12_RHI_WORKGRAPHS RootArgStrideInBytes = 0; // Use global root signature from specified node, or if non-specified use a fixed root signature. ID3D12RootSignature* GlobalRootSignature = nullptr; const int32 RootShaderIndex = Initializer.GetRootShaderIndex(); if (RootShaderIndex != -1) { if (Initializer.GetShaderTable().IsValidIndex(RootShaderIndex)) { FRHIWorkGraphShader* Shader = Initializer.GetShaderTable()[RootShaderIndex]; if (Shader != nullptr && Shader->GetFrequency() == SF_WorkGraphRoot) { GlobalRootSignature = FD3D12DynamicRHI::ResourceCast(Shader)->RootSignature->GetRootSignature(); } } } if (GlobalRootSignature == nullptr) { FD3D12Adapter* Adapter = Device->GetParentAdapter(); FRHIShaderBindingLayout ShaderBindingLayout; // todo pass ShaderBindingLayout down GlobalRootSignature = Adapter->GetGlobalWorkGraphRootSignature(ShaderBindingLayout)->GetRootSignature(); } RootSignature = GlobalRootSignature; CD3DX12_STATE_OBJECT_DESC StateObjectDesc(D3D12_STATE_OBJECT_TYPE_EXECUTABLE); CD3DX12_GLOBAL_ROOT_SIGNATURE_SUBOBJECT* GlobalRootSignatureSubobject = StateObjectDesc.CreateSubobject(); GlobalRootSignatureSubobject->SetRootSignature(RootSignature); CD3DX12_WORK_GRAPH_SUBOBJECT* WorkGraphSubobject = StateObjectDesc.CreateSubobject(); TCHAR const* ProgramName = Initializer.GetProgramName().IsEmpty() ? TEXT("WorkGraphProgram") : *Initializer.GetProgramName(); WorkGraphSubobject->SetProgramName(ProgramName); D3D12_NODE_ID EntryPoint; { FD3D12WorkGraphShader* EntryShader = FD3D12DynamicRHI::ResourceCast(Initializer.GetShaderTable()[0]); EntryPoint.Name = *EntryShader->EntryPoint; EntryPoint.ArrayIndex = 0; } WorkGraphSubobject->AddEntrypoint(EntryPoint); // Compute Shader Table const int32 ShaderTableNum = Initializer.GetShaderTable().Num(); for (int32 Index = 0; Index < ShaderTableNum; ++Index) { FD3D12WorkGraphShader* NodeShader = FD3D12DynamicRHI::ResourceCast(Initializer.GetShaderTable()[Index]); static const FString EmptyString; FString const& ExportName = NodeShader == nullptr ? EmptyString : NodeShader->EntryPoint; const int32 NodeNameIndex = Initializer.GetNameTable().IndexOfByPredicate([ExportName](FWorkGraphPipelineStateInitializer::FNameMap const& NameMap) { return NameMap.ExportName == ExportName; }); if (NodeNameIndex != INDEX_NONE) { FString const& NodeName = Initializer.GetNameTable()[NodeNameIndex].NodeName; const uint32 NodeArrayIndex = (NodeCountPerName.FindOrAdd(NodeName))++; if (NodeShader != nullptr) { const FString NodePathName = FString::Printf(TEXT("%s_%d"), *NodeName, NodeArrayIndex); CD3DX12_DXIL_LIBRARY_SUBOBJECT* Lib = StateObjectDesc.CreateSubobject(); CD3DX12_SHADER_BYTECODE LibCode(NodeShader->Code.GetData(), NodeShader->Code.Num()); Lib->SetDXILLibrary(&LibCode); Lib->DefineExport(*NodePathName, *NodeShader->EntryPoint); CD3DX12_LOCAL_ROOT_SIGNATURE_SUBOBJECT* LocalRootSignature = StateObjectDesc.CreateSubobject(); LocalRootSignature->SetRootSignature(NodeShader->RootSignature->GetRootSignature()); CD3DX12_SUBOBJECT_TO_EXPORTS_ASSOCIATION_SUBOBJECT* AssociationSubobject = StateObjectDesc.CreateSubobject(); AssociationSubobject->SetSubobjectToAssociate(*LocalRootSignature); AssociationSubobject->AddExport(*NodePathName); CD3DX12_COMMON_COMPUTE_NODE_OVERRIDES* Override = WorkGraphSubobject->CreateCommonComputeNodeOverrides(*NodePathName); Override->NewName(D3D12_NODE_ID{ *NodeName, NodeArrayIndex }); const int32 LocalRootArgumentsTableIndex = RootArgOffsets.Num(); RootArgOffsets.Add(LocalRootArgumentsTableIndex); Override->LocalRootArgumentsTableIndex(LocalRootArgumentsTableIndex); RootArgStrideInBytes = FMath::Max(RootArgStrideInBytes, NodeShader->RootSignature->GetTotalRootSignatureSizeInBytes()); } } } #if D3D12_RHI_WORKGRAPHS_GRAPHICS // Graphics Shader Table const int32 PSOTableNum = Initializer.GetGraphicsPSOTable().Num(); if (PSOTableNum > 0) { CD3DX12_STATE_OBJECT_CONFIG_SUBOBJECT* ConfigSubobject = StateObjectDesc.CreateSubobject(); ConfigSubobject->SetFlags(D3D12_STATE_OBJECT_FLAG_WORK_GRAPHS_USE_GRAPHICS_STATE_FOR_GLOBAL_ROOT_SIGNATURE); } for (int32 Index = 0; Index < PSOTableNum; ++Index) { FGraphicsPipelineStateInitializer const* NodePSO = Initializer.GetGraphicsPSOTable()[Index]; FRHIWorkGraphShader* MeshShader = NodePSO == nullptr ? nullptr : NodePSO->BoundShaderState.GetWorkGraphShader(); FRHIPixelShader* PixelShader = NodePSO == nullptr ? nullptr : NodePSO->BoundShaderState.GetPixelShader(); static const FString EmptyString; FString const& ExportName = MeshShader == nullptr ? EmptyString : FD3D12DynamicRHI::ResourceCast(MeshShader)->EntryPoint; const int32 NodeNameIndex = Initializer.GetNameTable().IndexOfByPredicate([ExportName](FWorkGraphPipelineStateInitializer::FNameMap const& NameMap) { return NameMap.ExportName == ExportName; }); if (NodeNameIndex == INDEX_NONE) { continue; } FString const& NodeName = Initializer.GetNameTable()[NodeNameIndex].NodeName; const uint32 NodeArrayIndex = (NodeCountPerName.FindOrAdd(NodeName))++; if (MeshShader == nullptr || PixelShader == nullptr) { continue; } const FString NodePathName = FString::Printf(TEXT("%s_%d"), *NodeName, NodeArrayIndex); FString MeshShaderName = FString::Printf(TEXT("MeshShader_%d"), NodeArrayIndex); FString PixelShaderName = FString::Printf(TEXT("PixelShader_%d"), NodeArrayIndex); FD3D12RootSignature const* LocalRootSignature = Device->GetParentAdapter()->GetWorkGraphGraphicsRootSignature(NodePSO->BoundShaderState); CD3DX12_LOCAL_ROOT_SIGNATURE_SUBOBJECT* LocalRootSignatureSubobject = StateObjectDesc.CreateSubobject(); LocalRootSignatureSubobject->SetRootSignature(LocalRootSignature->GetRootSignature()); RootArgStrideInBytes = FMath::Max(RootArgStrideInBytes, LocalRootSignature->GetTotalRootSignatureSizeInBytes()); { FD3D12WorkGraphShader* D3D12Shader = FD3D12DynamicRHI::ResourceCast(MeshShader); CD3DX12_DXIL_LIBRARY_SUBOBJECT* LibrarySubobject = StateObjectDesc.CreateSubobject(); CD3DX12_SHADER_BYTECODE ByteCode(D3D12Shader->Code.GetData(), D3D12Shader->Code.Num()); LibrarySubobject->SetDXILLibrary(&ByteCode); LibrarySubobject->DefineExport(*MeshShaderName, *D3D12Shader->EntryPoint); CD3DX12_SUBOBJECT_TO_EXPORTS_ASSOCIATION_SUBOBJECT* ExportAssociationSubobject = StateObjectDesc.CreateSubobject(); ExportAssociationSubobject->SetSubobjectToAssociate(*LocalRootSignatureSubobject); ExportAssociationSubobject->AddExport(*MeshShaderName); } { FD3D12PixelShader* D3D12Shader = FD3D12DynamicRHI::ResourceCast(PixelShader); CD3DX12_DXIL_LIBRARY_SUBOBJECT* LibrarySubobject = StateObjectDesc.CreateSubobject(); CD3DX12_SHADER_BYTECODE ByteCode(D3D12Shader->Code.GetData(), D3D12Shader->Code.Num()); LibrarySubobject->SetDXILLibrary(&ByteCode); LibrarySubobject->DefineExport(*PixelShaderName, *D3D12Shader->EntryPoint); CD3DX12_SUBOBJECT_TO_EXPORTS_ASSOCIATION_SUBOBJECT* ExportAssociationSubobject = StateObjectDesc.CreateSubobject(); ExportAssociationSubobject->SetSubobjectToAssociate(*LocalRootSignatureSubobject); ExportAssociationSubobject->AddExport(*PixelShaderName); } TArray> StateSubobjects; CD3DX12_PRIMITIVE_TOPOLOGY_SUBOBJECT* PrimitiveTopologyState = StateObjectDesc.CreateSubobject(); const EPrimitiveType PrimitiveType = NodePSO->PrimitiveType; PrimitiveTopologyState->SetPrimitiveTopologyType(D3D12PrimitiveTypeToTopologyType(TranslatePrimitiveType(PrimitiveType))); FD3D12RasterizerState* D3D12RasterizerState = FD3D12DynamicRHI::ResourceCast(NodePSO->RasterizerState); CD3DX12_RASTERIZER_SUBOBJECT* RasterizerSubobject = StateObjectDesc.CreateSubobject(); RasterizerSubobject->SetFrontCounterClockwise(D3D12RasterizerState->Desc.FrontCounterClockwise); RasterizerSubobject->SetFillMode(D3D12RasterizerState->Desc.FillMode); RasterizerSubobject->SetCullMode(D3D12RasterizerState->Desc.CullMode); StateSubobjects.Add(RasterizerSubobject); FD3D12DepthStencilState* D3D12DepthStencilState = FD3D12DynamicRHI::ResourceCast(NodePSO->DepthStencilState); CD3DX12_DEPTH_STENCIL_SUBOBJECT* DepthStencilSubobject = StateObjectDesc.CreateSubobject(); DepthStencilSubobject->SetDepthEnable(D3D12DepthStencilState->Desc.DepthEnable); DepthStencilSubobject->SetDepthFunc(D3D12DepthStencilState->Desc.DepthFunc); DepthStencilSubobject->SetDepthWriteMask(D3D12DepthStencilState->Desc.DepthWriteMask); DepthStencilSubobject->SetStencilEnable(D3D12DepthStencilState->Desc.StencilEnable); DepthStencilSubobject->SetStencilReadMask(D3D12DepthStencilState->Desc.StencilReadMask); DepthStencilSubobject->SetStencilWriteMask(D3D12DepthStencilState->Desc.StencilWriteMask); StateSubobjects.Add(DepthStencilSubobject); if (NodePSO->DepthStencilTargetFormat != PF_Unknown) { CD3DX12_DEPTH_STENCIL_FORMAT_SUBOBJECT* DepthStencilFormatSubobject = StateObjectDesc.CreateSubobject(); DepthStencilFormatSubobject->SetDepthStencilFormat((DXGI_FORMAT)GPixelFormats[NodePSO->DepthStencilTargetFormat].PlatformFormat); StateSubobjects.Add(DepthStencilFormatSubobject); } CD3DX12_RENDER_TARGET_FORMATS_SUBOBJECT* RenderTargetFormatSubobject = StateObjectDesc.CreateSubobject(); const int32 NumRenderTargets = NodePSO->ComputeNumValidRenderTargets(); RenderTargetFormatSubobject->SetNumRenderTargets(NumRenderTargets); if (NumRenderTargets > 0) { for (int32 RenderTargetIndex = 0; RenderTargetIndex < NumRenderTargets; ++RenderTargetIndex) { RenderTargetFormatSubobject->SetRenderTargetFormat(RenderTargetIndex, (DXGI_FORMAT)GPixelFormats[NodePSO->RenderTargetFormats[RenderTargetIndex]].PlatformFormat); } } StateSubobjects.Add(RenderTargetFormatSubobject); CD3DX12_GENERIC_PROGRAM_SUBOBJECT* ProgramSubobject = StateObjectDesc.CreateSubobject(); ProgramSubobject->SetProgramName(*NodePathName); ProgramSubobject->AddExport(*MeshShaderName); ProgramSubobject->AddExport(*PixelShaderName); for (CD3DX12_STATE_OBJECT_DESC::SUBOBJECT_HELPER_BASE* StateSubobject : StateSubobjects) { ProgramSubobject->AddSubobject(*StateSubobject); } ProgramSubobject->Finalize(); CD3DX12_MESH_LAUNCH_NODE_OVERRIDES* NodeOverrides = WorkGraphSubobject->CreateMeshLaunchNodeOverrides(*NodePathName); NodeOverrides->NewName(D3D12_NODE_ID{ *NodeName, NodeArrayIndex }); NodeOverrides->MaxInputRecordsPerGraphEntryRecord(1, false); const int32 LocalRootArgumentsTableIndex = RootArgOffsets.Num(); RootArgOffsets.Add(LocalRootArgumentsTableIndex); NodeOverrides->LocalRootArgumentsTableIndex(LocalRootArgumentsTableIndex); } #endif // D3D12_RHI_WORKGRAPHS_GRAPHICS RootArgStrideInBytes = ((RootArgStrideInBytes + 15) & ~15); MaxRootArgOffset = RootArgOffsets.Last(); WorkGraphSubobject->Finalize(); ID3D12Device9* Device9 = (ID3D12Device9*)Device->GetDevice(); HRESULT HResult = Device9->CreateStateObject(StateObjectDesc, IID_PPV_ARGS(StateObject.GetInitReference())); checkf(SUCCEEDED(HResult), TEXT("Failed to create work graph state object. Result=%08x"), uint32(HResult)); TRefCountPtr PipelineProperties; HResult = StateObject->QueryInterface(IID_PPV_ARGS(PipelineProperties.GetInitReference())); checkf(SUCCEEDED(HResult), TEXT("Failed to query pipeline properties from the work graph pipeline state object. Result=%08x"), uint32(HResult)); ProgramIdentifier = PipelineProperties->GetProgramIdentifier(ProgramName); #if D3D12_RHI_WORKGRAPHS_GRAPHICS TRefCountPtr WorkGraphProperties; #else TRefCountPtr WorkGraphProperties; #endif HResult = StateObject->QueryInterface(IID_PPV_ARGS(WorkGraphProperties.GetInitReference())); checkf(SUCCEEDED(HResult), TEXT("Failed to query work graph properties from the work graph pipeline state object. Result=%08x"), uint32(HResult)); UINT WorkGraphIndex = WorkGraphProperties->GetWorkGraphIndex(ProgramName); #if D3D12_RHI_WORKGRAPHS_GRAPHICS WorkGraphProperties->SetMaximumInputRecords(WorkGraphIndex, 1, 1); #endif D3D12_WORK_GRAPH_MEMORY_REQUIREMENTS MemoryRequirements = {}; WorkGraphProperties->GetWorkGraphMemoryRequirements(WorkGraphIndex, &MemoryRequirements); ID3D12Resource* BackingMemoryBufferResource = nullptr; { CD3DX12_RESOURCE_DESC BufferDesc = CD3DX12_RESOURCE_DESC::Buffer(MemoryRequirements.MaxSizeInBytes, D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS, 65536ull); CD3DX12_HEAP_PROPERTIES HeapProperties(D3D12_HEAP_TYPE_DEFAULT); HResult = Device->GetDevice()->CreateCommittedResource( //TODO: don't use raw device? &HeapProperties, D3D12_HEAP_FLAG_NONE, &BufferDesc, D3D12_RESOURCE_STATE_UNORDERED_ACCESS, NULL, __uuidof(ID3D12Resource), (void**)&BackingMemoryBufferResource); checkf(SUCCEEDED(HResult), TEXT("Failed to allocate backing memory for work graph. Result=%08x"), uint32(HResult)); } BackingMemoryAddressRange.StartAddress = BackingMemoryBufferResource->GetGPUVirtualAddress(); BackingMemoryAddressRange.SizeInBytes = MemoryRequirements.MaxSizeInBytes; #if NV_AFTERMATH if (UE::RHICore::Nvidia::Aftermath::IsShaderRegistrationEnabled()) { // Copy shader table for late association for (FRHIWorkGraphShader* Shader : Initializer.GetShaderTable()) { Shaders.Add(TRefCountPtr(Shader)); } } #endif // NV_AFTERMATH #endif // D3D12_RHI_WORKGRAPHS } FWorkGraphPipelineStateRHIRef FD3D12DynamicRHI::RHICreateWorkGraphPipelineState(const FWorkGraphPipelineStateInitializer& Initializer) { FD3D12Device* Device = GetAdapter().GetDevice(0); // All pipelines are created on the first node, as they may be used on any other linked GPU. return new FD3D12WorkGraphPipelineState(Device, Initializer); } #if D3D12_RHI_WORKGRAPHS /** Struct to collect transitions for all shader bundle dispatches. */ struct FShaderBundleBinderOps { Experimental::TSherwoodSet TransitionViewSet; Experimental::TSherwoodSet TransitionClearSet; TArray TransitionSRVs; TArray TransitionUAVs; TArray ClearUAVs; inline void AddResourceTransition(FD3D12ShaderResourceView* SRV) { if (SRV->GetResource()->RequiresResourceStateTracking()) { bool bAlreadyInSet = false; TransitionViewSet.Add(SRV, &bAlreadyInSet); if (!bAlreadyInSet) { TransitionSRVs.Add(SRV); } } } inline void AddResourceTransition(FD3D12UnorderedAccessView* UAV) { if (UAV->GetResource()->RequiresResourceStateTracking()) { bool bAlreadyInSet = false; TransitionViewSet.Add(UAV, &bAlreadyInSet); if (!bAlreadyInSet) { TransitionUAVs.Add(UAV); } } } inline void AddResourceClear(FD3D12UnorderedAccessView* UAV) { bool bAlreadyInSet = false; TransitionClearSet.Add(UAV, &bAlreadyInSet); if (!bAlreadyInSet) { ClearUAVs.Add(UAV); } } }; /** Struct to collect shader bundle bindings. */ struct FWorkGraphShaderBundleBinder { FD3D12CommandContext& Context; FShaderBundleBinderOps& BinderOps; const uint32 GPUIndex; const EShaderFrequency Frequency; const bool bBindlessResources; const bool bBindlessSamplers; uint32 CBVVersions[MAX_CBS]; uint32 SRVVersions[MAX_SRVS]; uint32 UAVVersions[MAX_UAVS]; uint32 SamplerVersions[MAX_SAMPLERS]; uint64 BoundCBVMask = 0; uint64 BoundSRVMask = 0; uint64 BoundUAVMask = 0; uint64 BoundSamplerMask = 0; D3D12_CPU_DESCRIPTOR_HANDLE LocalCBVs[MAX_CBS]; D3D12_CPU_DESCRIPTOR_HANDLE LocalSRVs[MAX_SRVS]; D3D12_CPU_DESCRIPTOR_HANDLE LocalUAVs[MAX_UAVS]; D3D12_CPU_DESCRIPTOR_HANDLE LocalSamplers[MAX_SAMPLERS]; FWorkGraphShaderBundleBinder(FD3D12CommandContext& InContext, EShaderFrequency InFrequency, FShaderBundleBinderOps& InBinderOps, FD3D12ShaderData const* ShaderData) : Context(InContext) , BinderOps(InBinderOps) , GPUIndex(InContext.GetGPUIndex()) , Frequency(InFrequency) , bBindlessResources(ShaderData->UsesBindlessResources()) , bBindlessSamplers(ShaderData->UsesBindlessSamplers()) { } void SetUAV(FRHIUnorderedAccessView* InUnorderedAccessView, uint32 Index, bool bClearResources = false) { FD3D12UnorderedAccessView* UAV = FD3D12CommandContext::RetrieveObject(InUnorderedAccessView, GPUIndex); check(UAV != nullptr); if (bClearResources) { BinderOps.AddResourceClear(UAV); } if (!bBindlessResources) { FD3D12OfflineDescriptor Descriptor = UAV->GetOfflineCpuHandle(); LocalUAVs[Index] = Descriptor; UAVVersions[Index] = Descriptor.GetVersion(); BoundUAVMask |= 1ull << Index; } BinderOps.AddResourceTransition(UAV); } void SetSRV(FRHIShaderResourceView* InShaderResourceView, uint32 Index) { FD3D12ShaderResourceView_RHI* SRV = FD3D12CommandContext::RetrieveObject(InShaderResourceView, GPUIndex); check(SRV != nullptr); if (!bBindlessResources) { FD3D12OfflineDescriptor Descriptor = SRV->GetOfflineCpuHandle(); LocalSRVs[Index] = Descriptor; SRVVersions[Index] = Descriptor.GetVersion(); BoundSRVMask |= 1ull << Index; } BinderOps.AddResourceTransition(SRV); } void SetTexture(FRHITexture* InTexture, uint32 Index) { FD3D12ShaderResourceView* SRV = FD3D12CommandContext::RetrieveTexture(InTexture, GPUIndex)->GetShaderResourceView(); check(SRV != nullptr); if (!bBindlessResources) { FD3D12OfflineDescriptor Descriptor = SRV->GetOfflineCpuHandle(); LocalSRVs[Index] = Descriptor; SRVVersions[Index] = Descriptor.GetVersion(); BoundSRVMask |= 1ull << Index; } BinderOps.AddResourceTransition(SRV); } void SetSampler(FRHISamplerState* InSampler, uint32 Index) { FD3D12SamplerState* Sampler = FD3D12CommandContext::RetrieveObject(InSampler, GPUIndex); check(Sampler != nullptr); if (!bBindlessSamplers) { FD3D12OfflineDescriptor Descriptor = Sampler->OfflineDescriptor; LocalSamplers[Index] = Descriptor; SamplerVersions[Index] = Descriptor.GetVersion(); BoundSamplerMask |= 1ull << Index; } } void SetResourceCollection(FRHIResourceCollection* ResourceCollection, uint32 Index) { FD3D12ResourceCollection* D3D12ResourceCollection = FD3D12CommandContext::RetrieveObject(ResourceCollection, GPUIndex); FD3D12ShaderResourceView* SRV = D3D12ResourceCollection ? D3D12ResourceCollection->GetShaderResourceView() : nullptr; check(bBindlessResources); if (bBindlessResources) { Context.StateCache.QueueBindlessSRV(Frequency, SRV); } } }; /** Wrapper for constant buffer and it's underlying resource allocation. */ struct FAllocatedConstantBuffer { FAllocatedConstantBuffer(FD3D12CommandContext& Context) : ResourceLocation(Context.GetParentDevice()) { } FD3D12ConstantBuffer* ConstantBuffer = nullptr; FD3D12ResourceLocation ResourceLocation; }; // Record bindings from shader bundle parameters. static void RecordBindings( FD3D12CommandContext& Context, EShaderFrequency Frequency, FD3D12ExplicitDescriptorCache& TransientDescriptorCache, FShaderBundleBinderOps& BinderOps, uint32 WorkerIndex, FRHIShader* ShaderRHI, FD3D12ShaderData const* D3D12ShaderData, FRHIBatchedShaderParameters const& Parameters, FD3D12RootSignature const* LocalRootSignature, FAllocatedConstantBuffer const& SharedConstantBuffer, FUint32Vector4 const& Constants, TArrayView RootArgs ) { const uint32 NumSMPs = D3D12ShaderData->ResourceCounts.NumSamplers; const uint32 NumSRVs = D3D12ShaderData->ResourceCounts.NumSRVs; const uint32 NumCBVs = D3D12ShaderData->ResourceCounts.NumCBs; const uint32 NumUAVs = D3D12ShaderData->ResourceCounts.NumUAVs; // With shader root constants, we should never hit this expensive path! // If we hit this, check if the shaders in the bundle had loose // uniform parameters added to it recently, falling into this path. check(!D3D12ShaderData->UsesGlobalUniformBuffer() || SharedConstantBuffer.ConstantBuffer != nullptr); FWorkGraphShaderBundleBinder BundleBinder(Context, Frequency, BinderOps, D3D12ShaderData); FD3D12UniformBuffer* BundleUniformBuffers[MAX_CBS] = { nullptr }; uint32 UniformBufferMask = 0u; const bool bClearUAVResources = false; for (const FRHIShaderParameterResource& Parameter : Parameters.ResourceParameters) { switch (Parameter.Type) { case FRHIShaderParameterResource::EType::Texture: BundleBinder.SetTexture(static_cast(Parameter.Resource), Parameter.Index); break; case FRHIShaderParameterResource::EType::ResourceView: BundleBinder.SetSRV(static_cast(Parameter.Resource), Parameter.Index); break; case FRHIShaderParameterResource::EType::UnorderedAccessView: BundleBinder.SetUAV(static_cast(Parameter.Resource), Parameter.Index, bClearUAVResources); break; case FRHIShaderParameterResource::EType::Sampler: BundleBinder.SetSampler(static_cast(Parameter.Resource), Parameter.Index); break; case FRHIShaderParameterResource::EType::UniformBuffer: BundleUniformBuffers[Parameter.Index] = FD3D12CommandContext::RetrieveObject(Parameter.Resource, 0 /*GpuIndex*/); UniformBufferMask |= (1 << Parameter.Index); break; case FRHIShaderParameterResource::EType::ResourceCollection: BundleBinder.SetResourceCollection(static_cast(Parameter.Resource), Parameter.Index); break; default: checkf(false, TEXT("Unhandled resource type?")); break; } } UE::RHICore::ApplyStaticUniformBuffers(ShaderRHI, Context.GetStaticUniformBuffers(), [&](int32 BufferIndex, FRHIUniformBuffer* Buffer) { BundleUniformBuffers[BufferIndex] = Context.RetrieveObject(Buffer); }); uint32 FakeDirtyUniformBuffers = ~(0u); UE::RHI::Private::SetUniformBufferResourcesFromTables(BundleBinder, *ShaderRHI, FakeDirtyUniformBuffers, BundleUniformBuffers #if ENABLE_RHI_VALIDATION , Context.Tracker #endif ); if (SharedConstantBuffer.ConstantBuffer != nullptr) { check(BundleUniformBuffers[0] == nullptr); BundleBinder.BoundCBVMask |= 1ull << 0; } for (uint32 CBVIndex = 0; CBVIndex < MAX_CBS; ++CBVIndex) { FD3D12UniformBuffer* UniformBuffer = BundleUniformBuffers[CBVIndex]; if (UniformBuffer) { BundleBinder.BoundCBVMask |= 1ull << CBVIndex; } } // Validate that all resources required by the shader are set auto IsCompleteBinding = [](uint32 ExpectedCount, uint64 BoundMask) { if (ExpectedCount > 64) return false; // Bound resource mask can't be represented by uint64 // All bits of the mask [0..ExpectedCount) are expected to be set uint64 ExpectedMask = ExpectedCount == 64 ? ~0ull : ((1ull << ExpectedCount) - 1); return (ExpectedMask & BoundMask) == ExpectedMask; }; check(IsCompleteBinding(D3D12ShaderData->ResourceCounts.NumSRVs, BundleBinder.BoundSRVMask)); check(IsCompleteBinding(D3D12ShaderData->ResourceCounts.NumUAVs, BundleBinder.BoundUAVMask)); check(IsCompleteBinding(D3D12ShaderData->ResourceCounts.NumCBs, BundleBinder.BoundCBVMask)); check(IsCompleteBinding(D3D12ShaderData->ResourceCounts.NumSamplers, BundleBinder.BoundSamplerMask)); if (NumSRVs > 0) { const int32 DescriptorTableBaseIndex = TransientDescriptorCache.Allocate(BundleBinder.LocalSRVs, NumSRVs, D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV, WorkerIndex); check(DescriptorTableBaseIndex != INDEX_NONE); const uint32 BindSlot = LocalRootSignature->SRVRDTBindSlot(Frequency); const uint32 BindSlotOffset = LocalRootSignature->GetBindSlotOffsetInBytes(BindSlot) / 4; const D3D12_GPU_DESCRIPTOR_HANDLE ResourceDescriptorTableBaseGPU = TransientDescriptorCache.ViewHeap.GetDescriptorGPU(DescriptorTableBaseIndex); FMemory::Memcpy(&RootArgs[BindSlotOffset], &ResourceDescriptorTableBaseGPU, sizeof(ResourceDescriptorTableBaseGPU)); } if (NumSMPs > 0) { const int32 DescriptorTableBaseIndex = TransientDescriptorCache.AllocateDeduplicated(BundleBinder.SamplerVersions, BundleBinder.LocalSamplers, NumSMPs, D3D12_DESCRIPTOR_HEAP_TYPE_SAMPLER, WorkerIndex); check(DescriptorTableBaseIndex != INDEX_NONE); const uint32 BindSlot = LocalRootSignature->SamplerRDTBindSlot(Frequency); const uint32 BindSlotOffset = LocalRootSignature->GetBindSlotOffsetInBytes(BindSlot) / 4; const D3D12_GPU_DESCRIPTOR_HANDLE ResourceDescriptorTableBaseGPU = TransientDescriptorCache.SamplerHeap.GetDescriptorGPU(DescriptorTableBaseIndex); FMemory::Memcpy(&RootArgs[BindSlotOffset], &ResourceDescriptorTableBaseGPU, sizeof(ResourceDescriptorTableBaseGPU)); } if (NumUAVs > 0) { const int32 DescriptorTableBaseIndex = TransientDescriptorCache.AllocateDeduplicated(BundleBinder.UAVVersions, BundleBinder.LocalUAVs, NumUAVs, D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV, WorkerIndex); //const int32 DescriptorTableBaseIndex = TransientDescriptorCache.Allocate(BundleBinder.LocalUAVs, NumUAVs, D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV, WorkerIndex); check(DescriptorTableBaseIndex != INDEX_NONE); const uint32 BindSlot = LocalRootSignature->UAVRDTBindSlot(Frequency); const uint32 BindSlotOffset = LocalRootSignature->GetBindSlotOffsetInBytes(BindSlot) / 4; const D3D12_GPU_DESCRIPTOR_HANDLE ResourceDescriptorTableBaseGPU = TransientDescriptorCache.ViewHeap.GetDescriptorGPU(DescriptorTableBaseIndex); FMemory::Memcpy(&RootArgs[BindSlotOffset], &ResourceDescriptorTableBaseGPU, sizeof(ResourceDescriptorTableBaseGPU)); } if (SharedConstantBuffer.ConstantBuffer != nullptr) { const uint32 BindSlot = LocalRootSignature->CBVRDBindSlot(Frequency, 0); const uint32 BindSlotOffset = LocalRootSignature->GetBindSlotOffsetInBytes(BindSlot) / 4; D3D12_GPU_VIRTUAL_ADDRESS Address = SharedConstantBuffer.ResourceLocation.GetGPUVirtualAddress(); FMemory::Memcpy(&RootArgs[BindSlotOffset], &Address, sizeof(Address)); } for (uint32 CBVIndex = 0; CBVIndex < NumCBVs; ++CBVIndex) { const uint32 BindSlot = LocalRootSignature->CBVRDBindSlot(Frequency, CBVIndex); const uint32 BindSlotOffset = LocalRootSignature->GetBindSlotOffsetInBytes(BindSlot) / 4; FD3D12UniformBuffer* UniformBuffer = BundleUniformBuffers[CBVIndex]; if (UniformBuffer) { D3D12_GPU_VIRTUAL_ADDRESS Address = UniformBuffer->ResourceLocation.GetGPUVirtualAddress(); FMemory::Memcpy(&RootArgs[BindSlotOffset], &Address, sizeof(Address)); } } const int8 BindSlot = LocalRootSignature->GetRootConstantsSlot(); if (BindSlot != -1) { const uint32 BindSlotOffset = LocalRootSignature->GetBindSlotOffsetInBytes(BindSlot) / 4; RootArgs[BindSlotOffset] = Constants.X; RootArgs[BindSlotOffset + 1] = Constants.Y; RootArgs[BindSlotOffset + 2] = Constants.Z; RootArgs[BindSlotOffset + 3] = Constants.W; } } struct FD3D12BindlessConstantSetter { FD3D12CommandContext& Context; FD3D12ConstantBuffer& ConstantBuffer; const uint32 GpuIndex; const EShaderFrequency Frequency; FD3D12BindlessConstantSetter(FD3D12CommandContext& InContext, EShaderFrequency InFrequency) : Context(InContext) , ConstantBuffer(InContext.StageConstantBuffers[SF_Compute]) , GpuIndex(InContext.GetGPUIndex()) , Frequency(InFrequency) { } void SetBindlessHandle(const FRHIDescriptorHandle& Handle, uint32 Offset) { if (Handle.IsValid()) { const uint32 BindlessIndex = Handle.GetIndex(); ConstantBuffer.UpdateConstant(reinterpret_cast(&BindlessIndex), Offset, 4); } } void SetUAV(FD3D12UnorderedAccessView* D3D12UnorderedAccessView, uint32 Offset) { SetBindlessHandle(D3D12UnorderedAccessView->GetBindlessHandle(), Offset); Context.StateCache.QueueBindlessUAV(Frequency, D3D12UnorderedAccessView); } void SetUAV(FRHIUnorderedAccessView* InUnorderedAccessView, uint32 Offset) { FD3D12UnorderedAccessView_RHI* D3D12UnorderedAccessView = FD3D12CommandContext::RetrieveObject(InUnorderedAccessView, GpuIndex); SetUAV(static_cast(D3D12UnorderedAccessView), Offset); } void SetSRV(FD3D12ShaderResourceView* D3D12ShaderResourceView, uint32 Offset) { SetBindlessHandle(D3D12ShaderResourceView->GetBindlessHandle(), Offset); Context.StateCache.QueueBindlessSRV(Frequency, D3D12ShaderResourceView); } void SetSRV(FRHIShaderResourceView* InShaderResourceView, uint32 Offset) { FD3D12ShaderResourceView_RHI* D3D12ShaderResourceView = FD3D12CommandContext::RetrieveObject(InShaderResourceView, GpuIndex); SetSRV(static_cast(D3D12ShaderResourceView), Offset); } void SetTexture(FRHITexture* InTexture, uint32 Offset) { FD3D12Texture* D3D12Texture = FD3D12CommandContext::RetrieveTexture(InTexture, GpuIndex); FD3D12ShaderResourceView* D3D12ShaderResourceView = D3D12Texture ? D3D12Texture->GetShaderResourceView() : nullptr; SetBindlessHandle(InTexture->GetDefaultBindlessHandle(), Offset); Context.StateCache.QueueBindlessSRV(Frequency, D3D12ShaderResourceView); } void SetSampler(FRHISamplerState* InSampler, uint32 Offset) { FD3D12SamplerState* D3D12SamplerState = FD3D12CommandContext::RetrieveObject(InSampler, GpuIndex); SetBindlessHandle(D3D12SamplerState->GetBindlessHandle(), Offset); } void SetResourceCollection(FRHIResourceCollection* ResourceCollection, uint32 Offset) { if (FD3D12ResourceCollection* D3D12ResourceCollection = FD3D12CommandContext::RetrieveObject(ResourceCollection, GpuIndex)) { FD3D12ShaderResourceView* D3D12ShaderResourceView = D3D12ResourceCollection->GetShaderResourceView(); Context.StateCache.QueueBindlessSRV(Frequency, D3D12ShaderResourceView); Context.StateCache.QueueBindlessSRVs(Frequency, D3D12ResourceCollection->AllSrvs); // We have to go through each TextureReference to get the most recent version. for (FD3D12RHITextureReference* TextureReference : D3D12ResourceCollection->AllTextureReferences) { if (FD3D12Texture* Texture = GetD3D12TextureFromRHITexture(TextureReference)) { Context.StateCache.QueueBindlessSRV(Frequency, Texture->GetShaderResourceView()); } } } } [[nodiscard]] void Finalize(FAllocatedConstantBuffer& OutConstantBuffer) { OutConstantBuffer.ConstantBuffer = &ConstantBuffer; ConstantBuffer.Version(OutConstantBuffer.ResourceLocation, false); } }; void SetShaderBundleSharedBindlessConstants(FD3D12CommandContext& Context, TConstArrayView SharedBindlessParameters, FAllocatedConstantBuffer& OutConstantBuffer) { if (SharedBindlessParameters.Num()) { FD3D12BindlessConstantSetter Setter(Context, SF_Compute); for (const FRHIShaderParameterResource& Parameter : SharedBindlessParameters) { if (FRHIResource* Resource = Parameter.Resource) { switch (Parameter.Type) { case FRHIShaderParameterResource::EType::Texture: Setter.SetTexture(static_cast(Resource), Parameter.Index); break; case FRHIShaderParameterResource::EType::ResourceView: Setter.SetSRV(static_cast(Resource), Parameter.Index); break; case FRHIShaderParameterResource::EType::UnorderedAccessView: Setter.SetUAV(static_cast(Resource), Parameter.Index); break; case FRHIShaderParameterResource::EType::Sampler: break; case FRHIShaderParameterResource::EType::ResourceCollection: Setter.SetResourceCollection(static_cast(Resource), Parameter.Index); break; } } } Setter.Finalize(OutConstantBuffer); } } #endif // D3D12_RHI_WORKGRAPHS void FD3D12CommandContext::DispatchWorkGraphShaderBundle( FRHIShaderBundle* ShaderBundle, FRHIBuffer* RecordArgBuffer, TConstArrayView SharedBindlessParameters, TConstArrayView Dispatches) { #if D3D12_RHI_WORKGRAPHS TRHICommandList_RecursiveHazardous RHICmdList(this); FD3D12ShaderBundle* D3D12ShaderBundle = static_cast(FD3D12DynamicRHI::ResourceCast(ShaderBundle)); TShaderRef WorkGraphGlobalShader = GetGlobalShaderMap(GMaxRHIFeatureLevel)->GetShader(); FD3D12WorkGraphShader* D3D12EntryShader = FD3D12DynamicRHI::ResourceCast(WorkGraphGlobalShader.GetWorkGraphShader()); const bool bBindlessResources = D3D12EntryShader->UsesBindlessResources(); uint32 ViewDescriptorCount = D3D12EntryShader->ResourceCounts.NumSRVs + D3D12EntryShader->ResourceCounts.NumCBs + D3D12EntryShader->ResourceCounts.NumUAVs; uint32 SamplerDescriptorCount = D3D12EntryShader->ResourceCounts.NumSamplers; const int32 NumRecords = Dispatches.Num(); checkf(NumRecords <= FDispatchShaderBundleWorkGraph::GetMaxShaderBundleSize(), TEXT("Too many entries in a shader bundle (%d). Try increasing 'r.ShaderBundle.MaxSize'"), NumRecords); TArray ValidRecords; ValidRecords.Reserve(NumRecords); TArray LocalNodeShaders; LocalNodeShaders.Reserve(NumRecords + 1); LocalNodeShaders.Add(D3D12EntryShader); for (int32 DispatchIndex = 0; DispatchIndex < NumRecords; ++DispatchIndex) { const FRHIShaderBundleComputeDispatch& Dispatch = Dispatches[DispatchIndex]; FRHIWorkGraphShader* Shader = Dispatch.IsValid() ? Dispatch.WorkGraphShader : nullptr; LocalNodeShaders.Add(Shader); if (Shader != nullptr) { ValidRecords.Add(uint32(DispatchIndex)); if (FD3D12WorkGraphShader* D3D12Shader = FD3D12DynamicRHI::ResourceCast(Shader)) { ViewDescriptorCount += D3D12Shader->ResourceCounts.NumSRVs + D3D12Shader->ResourceCounts.NumCBs + D3D12Shader->ResourceCounts.NumUAVs; SamplerDescriptorCount += D3D12Shader->ResourceCounts.NumSamplers; } } } const int32 NumValidRecords = ValidRecords.Num(); FWorkGraphPipelineStateInitializer Initializer; Initializer.SetProgramName(TEXT("ShaderBundle")); TArray NameTable; NameTable.Add(FWorkGraphPipelineStateInitializer::FNameMap(TEXT("WorkGraphMainCS"), TEXT("WorkGraphMainCS"))); // Entry node. NameTable.Add(FWorkGraphPipelineStateInitializer::FNameMap(TEXT(""), TEXT("ShaderBundleNode"))); // Empty shader slots still increment bundle node index. NameTable.Add(FWorkGraphPipelineStateInitializer::FNameMap(TEXT("MainCS"), TEXT("ShaderBundleNode"))); // Nanite compute materials. NameTable.Add(FWorkGraphPipelineStateInitializer::FNameMap(TEXT("MicropolyRasterize"), TEXT("ShaderBundleNode"))); // Nanite software rasterize. Initializer.SetNameTable(NameTable); Initializer.SetShaderTable(LocalNodeShaders); FWorkGraphPipelineState* WorkGraphPipelineState = PipelineStateCache::GetAndOrCreateWorkGraphPipelineState(RHICmdList, Initializer); FD3D12WorkGraphPipelineState* Pipeline = static_cast(GetRHIWorkGraphPipelineState(WorkGraphPipelineState)); const uint32 NumViewDescriptors = ViewDescriptorCount; const uint32 NumSamplerDescriptors = SamplerDescriptorCount; const uint32 MaxWorkers = 4u; const uint32 NumWorkerThreads = FTaskGraphInterface::Get().GetNumWorkerThreads(); const uint32 MaxTasks = FApp::ShouldUseThreadingForPerformance() ? FMath::Min(NumWorkerThreads, MaxWorkers) : 1u; struct FTaskContext { uint32 WorkerIndex = 0; }; TArray> TaskContexts; for (uint32 WorkerIndex = 0; WorkerIndex < MaxTasks; ++WorkerIndex) { TaskContexts.Add(FTaskContext{ WorkerIndex }); } FD3D12ExplicitDescriptorCache TransientDescriptorCache(GetParentDevice(), MaxTasks /* Worker Count */); TransientDescriptorCache.Init(0, NumViewDescriptors, NumSamplerDescriptors, ERHIBindlessConfiguration::Minimal); TArray> BinderOps; BinderOps.SetNum(MaxTasks); TResourceArray LocalRootArgs; uint32 MinRootArgBufferSizeInDWords = (Pipeline->RootArgStrideInBytes / 4) * (Pipeline->MaxRootArgOffset + 1); LocalRootArgs.AddZeroed(MinRootArgBufferSizeInDWords); FAllocatedConstantBuffer SharedConstantBuffer(*this); SetShaderBundleSharedBindlessConstants(*this, SharedBindlessParameters, SharedConstantBuffer); auto RecordTask = [this, &LocalRootArgs, Pipeline, &TransientDescriptorCache, &ValidRecords, &Dispatches, &BinderOps, &SharedConstantBuffer](FTaskContext& Context, int32 RecordIndex) { const uint32 DispatchIndex = ValidRecords[RecordIndex]; const FRHIShaderBundleComputeDispatch& Dispatch = Dispatches[DispatchIndex]; check(Dispatch.IsValid()); const uint32 ShaderTableIndex = RecordIndex + 1; check(Pipeline->RootArgOffsets.IsValidIndex(ShaderTableIndex)); uint32 RootArgOffset = Pipeline->RootArgOffsets[ShaderTableIndex]; check((Pipeline->RootArgStrideInBytes / 4) * (RootArgOffset + 1) <= (uint32)LocalRootArgs.Num()); FD3D12WorkGraphShader* D3D12WorkGraphShader = FD3D12DynamicRHI::ResourceCast(Dispatch.WorkGraphShader); FD3D12RootSignature const* LocalRootSignature = D3D12WorkGraphShader->RootSignature; RecordBindings( *this, SF_Compute, TransientDescriptorCache, BinderOps[Context.WorkerIndex], Context.WorkerIndex, Dispatch.WorkGraphShader, D3D12WorkGraphShader, *Dispatch.Parameters, LocalRootSignature, SharedConstantBuffer, Dispatch.Constants, MakeArrayView(&LocalRootArgs[RootArgOffset * Pipeline->RootArgStrideInBytes / 4], Pipeline->RootArgStrideInBytes / 4) ); }; // One helper worker task will be created at most per this many work items, plus one worker for current thread (unless running on a task thread), // up to a hard maximum of MaxWorkers. // Internally, parallel for tasks still subdivide the work into smaller chunks and perform fine-grained load-balancing. const int32 ItemsPerTask = 1024; ParallelForWithExistingTaskContext(TEXT("DispatchShaderBundle"), MakeArrayView(TaskContexts), ValidRecords.Num(), ItemsPerTask, RecordTask); // Bind RecordArgBuffer FD3D12Buffer* RecordArgBufferPtr = FD3D12DynamicRHI::ResourceCast(RecordArgBuffer); D3D12_SHADER_RESOURCE_VIEW_DESC SRVDesc = {}; SRVDesc.ViewDimension = D3D12_SRV_DIMENSION_BUFFER; SRVDesc.Format = DXGI_FORMAT_R32_TYPELESS; SRVDesc.Shader4ComponentMapping = D3D12_DEFAULT_SHADER_4_COMPONENT_MAPPING; SRVDesc.Buffer.FirstElement = 0; SRVDesc.Buffer.NumElements = RecordArgBufferPtr->GetSize() >> 2u; SRVDesc.Buffer.StructureByteStride = 0; SRVDesc.Buffer.Flags = D3D12_BUFFER_SRV_FLAG_RAW; TSharedPtr RecordArgBufferSRV; RecordArgBufferSRV = MakeShared(GetParentDevice(), nullptr); // Always single GPU object, so FirstLinkedObject is nullptr RecordArgBufferSRV->CreateView(RecordArgBufferPtr, SRVDesc, FD3D12ShaderResourceView::EFlags::None); uint32 RecordArgBufferBindlessHandle = 0; if (bBindlessResources) { RecordArgBufferBindlessHandle = RecordArgBufferSRV->GetBindlessHandle().GetIndex(); check(RecordArgBufferBindlessHandle != INDEX_NONE); } else { D3D12_CPU_DESCRIPTOR_HANDLE LocalSRVs[MAX_SRVS]; LocalSRVs[WorkGraphGlobalShader->RecordArgBufferParam.GetBaseIndex()] = RecordArgBufferSRV->GetOfflineCpuHandle(); const int32 DescriptorTableBaseIndex = TransientDescriptorCache.Allocate(LocalSRVs, 1, D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV, 0); const D3D12_GPU_DESCRIPTOR_HANDLE ResourceDescriptorTableBaseGPU = TransientDescriptorCache.ViewHeap.GetDescriptorGPU(DescriptorTableBaseIndex); const uint32 BindSlot = D3D12EntryShader->RootSignature->SRVRDTBindSlot(SF_Compute); const uint32 BindSlotOffset = D3D12EntryShader->RootSignature->GetBindSlotOffsetInBytes(BindSlot) / 4; TArrayView RootArgSlice = MakeArrayView(&LocalRootArgs[Pipeline->RootArgOffsets[0] * Pipeline->RootArgStrideInBytes / 4], Pipeline->RootArgStrideInBytes / 4); FMemory::Memcpy(&RootArgSlice[BindSlotOffset], &ResourceDescriptorTableBaseGPU, sizeof(ResourceDescriptorTableBaseGPU)); } // Upload local root arguments table. D3D12_GPU_VIRTUAL_ADDRESS_RANGE_AND_STRIDE NodeLocalRootArgumentsTable{ 0, 0, 0 }; if (ValidRecords.Num() && LocalRootArgs.Num()) { const uint32 DataSize = LocalRootArgs.GetResourceDataSize(); // todo: Check if copy queue is the optimal way to upload the root args. // todo: Use a single buffer owned by the shader bundle RHI object (needs a copy operation that doesn't complain about multiple uploads). const D3D12_RESOURCE_DESC ResourceDesc = CD3DX12_RESOURCE_DESC::Buffer(DataSize, D3D12_RESOURCE_FLAG_NONE); const FRHIBufferCreateDesc CreateDesc = FRHIBufferCreateDesc::Create(TEXT("BundleRecordBuffer"), DataSize, 0, EBufferUsageFlags::Static) .SetInitialState(ERHIAccess::CopyDest) .SetGPUMask(FRHIGPUMask::FromIndex(GetParentDevice()->GetGPUIndex())); FD3D12Buffer* RootArgBuffer = GetParentDevice()->GetParentAdapter()->CreateRHIBuffer( ResourceDesc, 16, CreateDesc, ED3D12ResourceStateMode::MultiState, D3D12_RESOURCE_STATE_COPY_DEST, true ); BatchedSyncPoints.ToWait.Emplace(RootArgBuffer->UploadResourceDataViaCopyQueue(*this, &LocalRootArgs)); TransitionResource(RootArgBuffer->GetResource(), D3D12_RESOURCE_STATE_COPY_DEST, D3D12_RESOURCE_STATE_COMMON, 0); NodeLocalRootArgumentsTable.StartAddress = RootArgBuffer->ResourceLocation.GetGPUVirtualAddress(); NodeLocalRootArgumentsTable.SizeInBytes = RootArgBuffer->ResourceLocation.GetSize(); NodeLocalRootArgumentsTable.StrideInBytes = Pipeline->RootArgStrideInBytes; } // Apply Binder Ops { for (int32 WorkerIndex = 1; WorkerIndex < BinderOps.Num(); ++WorkerIndex) { for (FD3D12ShaderResourceView* SRV : BinderOps[WorkerIndex].TransitionSRVs) { BinderOps[0].AddResourceTransition(SRV); } for (FD3D12UnorderedAccessView* UAV : BinderOps[WorkerIndex].TransitionUAVs) { BinderOps[0].AddResourceTransition(UAV); } BinderOps[WorkerIndex].TransitionSRVs.Empty(); BinderOps[WorkerIndex].TransitionUAVs.Empty(); BinderOps[WorkerIndex].TransitionViewSet.Empty(); BinderOps[WorkerIndex].ClearUAVs.Empty(); BinderOps[WorkerIndex].TransitionClearSet.Empty(); } for (FD3D12UnorderedAccessView* UAV : BinderOps[0].ClearUAVs) { ClearShaderResources(UAV, EShaderParameterTypeMask::SRVMask); } //TODO: check resource view // for (FD3D12ShaderResourceView* SRV : BinderOps[0].TransitionSRVs) // { // TransitionResource(SRV, D3D12_RESOURCE_STATE_NON_PIXEL_SHADER_RESOURCE); // } // // for (FD3D12UnorderedAccessView* UAV : BinderOps[0].TransitionUAVs) // { // TransitionResource(UAV, D3D12_RESOURCE_STATE_UNORDERED_ACCESS); // } } FlushResourceBarriers(); // Apply the transient descriptor heaps. // Note this only uses transient heap for non-bindless. SetExplicitDescriptorCache(TransientDescriptorCache); if (bBindlessResources) { StateCache.ApplyBindlessResources(nullptr, SF_Compute, SF_NumStandardFrequencies); } GraphicsCommandList()->SetComputeRootSignature(Pipeline->RootSignature); // Kick off the work graph D3D12_SET_PROGRAM_DESC SetProgramDesc = {}; SetProgramDesc.Type = D3D12_PROGRAM_TYPE_WORK_GRAPH; SetProgramDesc.WorkGraph.ProgramIdentifier = Pipeline->ProgramIdentifier; SetProgramDesc.WorkGraph.Flags = !Pipeline->bInitialized ? D3D12_SET_WORK_GRAPH_FLAG_INITIALIZE : D3D12_SET_WORK_GRAPH_FLAG_NONE; SetProgramDesc.WorkGraph.BackingMemory = Pipeline->BackingMemoryAddressRange; SetProgramDesc.WorkGraph.NodeLocalRootArgumentsTable = NodeLocalRootArgumentsTable; GraphicsCommandList10()->SetProgram(&SetProgramDesc); FDispatchShaderBundleWorkGraph::FEntryNodeRecord InputRecord = FDispatchShaderBundleWorkGraph::MakeInputRecord(NumRecords, ShaderBundle->ArgOffset, ShaderBundle->ArgStride, RecordArgBufferBindlessHandle); if (!GShaderBundleSkipDispatch) { D3D12_DISPATCH_GRAPH_DESC DispatchGraphDesc = {}; DispatchGraphDesc.Mode = D3D12_DISPATCH_MODE_NODE_CPU_INPUT; DispatchGraphDesc.NodeCPUInput.EntrypointIndex = 0; DispatchGraphDesc.NodeCPUInput.NumRecords = 1; DispatchGraphDesc.NodeCPUInput.RecordStrideInBytes = sizeof(InputRecord); DispatchGraphDesc.NodeCPUInput.pRecords = &InputRecord; GraphicsCommandList10()->DispatchGraph(&DispatchGraphDesc); } // Pipeline state memory should now be initialized. Pipeline->bInitialized = true; // Restore global descriptor heaps if necessary. UnsetExplicitDescriptorCache(); // We did not write through the state cache, so we need to invalidate it so subsequent workloads correctly re-bind state StateCache.DirtyState(); ConditionalSplitCommandList(); #endif // D3D12_RHI_WORKGRAPHS } void FD3D12CommandContext::DispatchWorkGraphShaderBundle( FRHIShaderBundle* ShaderBundle, FRHIBuffer* RecordArgBuffer, const FRHIShaderBundleGraphicsState& BundleState, TConstArrayView SharedBindlessParameters, TConstArrayView Dispatches) { #if D3D12_RHI_WORKGRAPHS TRHICommandList_RecursiveHazardous RHICmdList(this); FD3D12ShaderBundle* D3D12ShaderBundle = static_cast(FD3D12DynamicRHI::ResourceCast(ShaderBundle)); TShaderRef WorkGraphGlobalShader = GetGlobalShaderMap(GMaxRHIFeatureLevel)->GetShader(); FD3D12WorkGraphShader* D3D12EntryShader = FD3D12DynamicRHI::ResourceCast(WorkGraphGlobalShader.GetWorkGraphShader()); const bool bBindlessResources = D3D12EntryShader->UsesBindlessResources(); uint32 ViewDescriptorCount = D3D12EntryShader->ResourceCounts.NumSRVs + D3D12EntryShader->ResourceCounts.NumCBs + D3D12EntryShader->ResourceCounts.NumUAVs; uint32 SamplerDescriptorCount = D3D12EntryShader->ResourceCounts.NumSamplers; TArray LocalNodeShaders; LocalNodeShaders.Add(D3D12EntryShader); const int32 NumRecords = Dispatches.Num(); checkf(NumRecords <= FDispatchShaderBundleWorkGraph::GetMaxShaderBundleSize(), TEXT("Too many entries in a shader bundle (%d). Try increasing 'r.ShaderBundle.MaxSize'"), NumRecords); TArray ValidRecords; ValidRecords.Reserve(NumRecords); TArray LocalPSOs; LocalPSOs.Reserve(NumRecords); for (int32 DispatchIndex = 0; DispatchIndex < NumRecords; ++DispatchIndex) { const FRHIShaderBundleGraphicsDispatch& Dispatch = Dispatches[DispatchIndex]; FGraphicsPipelineStateInitializer const* PSO = Dispatch.IsValid() ? &Dispatch.PipelineInitializer : nullptr; LocalPSOs.Add(PSO); if (PSO != nullptr) { FRHIWorkGraphShader* MeshShader = PSO->BoundShaderState.GetWorkGraphShader(); FRHIPixelShader* PixelShader = PSO->BoundShaderState.GetPixelShader(); if (MeshShader != nullptr && PixelShader != nullptr) { ValidRecords.Add(uint32(DispatchIndex)); if (FD3D12WorkGraphShader* D3D12Shader = FD3D12DynamicRHI::ResourceCast(MeshShader)) { ViewDescriptorCount += D3D12Shader->ResourceCounts.NumSRVs + D3D12Shader->ResourceCounts.NumCBs + D3D12Shader->ResourceCounts.NumUAVs; SamplerDescriptorCount += D3D12Shader->ResourceCounts.NumSamplers; } if (FD3D12PixelShader* D3D12Shader = FD3D12DynamicRHI::ResourceCast(PixelShader)) { ViewDescriptorCount += D3D12Shader->ResourceCounts.NumSRVs + D3D12Shader->ResourceCounts.NumCBs + D3D12Shader->ResourceCounts.NumUAVs; SamplerDescriptorCount += D3D12Shader->ResourceCounts.NumSamplers; } } } } const int32 NumValidRecords = ValidRecords.Num(); FWorkGraphPipelineStateInitializer Initializer; Initializer.SetProgramName(TEXT("ShaderBundle")); TArray NameTable; NameTable.Add(FWorkGraphPipelineStateInitializer::FNameMap(TEXT("WorkGraphMainCS"), TEXT("WorkGraphMainCS"))); // Entry node. NameTable.Add(FWorkGraphPipelineStateInitializer::FNameMap(TEXT(""), TEXT("ShaderBundleNode"))); // Empty shader slots still increment bundle node index. NameTable.Add(FWorkGraphPipelineStateInitializer::FNameMap(TEXT("HWRasterizeMS"), TEXT("ShaderBundleNode"))); // Nanite software rasterize. Initializer.SetNameTable(NameTable); Initializer.SetShaderTable(LocalNodeShaders); Initializer.SetGraphicsPSOTable(LocalPSOs); FWorkGraphPipelineState* WorkGraphPipelineState = PipelineStateCache::GetAndOrCreateWorkGraphPipelineState(RHICmdList, Initializer); FD3D12WorkGraphPipelineState* Pipeline = static_cast(GetRHIWorkGraphPipelineState(WorkGraphPipelineState)); const uint32 NumViewDescriptors = ViewDescriptorCount; const uint32 NumSamplerDescriptors = SamplerDescriptorCount; const uint32 MaxWorkers = 4u; const uint32 NumWorkerThreads = FTaskGraphInterface::Get().GetNumWorkerThreads(); const uint32 MaxTasks = FApp::ShouldUseThreadingForPerformance() ? FMath::Min(NumWorkerThreads, MaxWorkers) : 1u; struct FTaskContext { uint32 WorkerIndex = 0; }; TArray> TaskContexts; for (uint32 WorkerIndex = 0; WorkerIndex < MaxTasks; ++WorkerIndex) { TaskContexts.Add(FTaskContext{ WorkerIndex }); } FD3D12ExplicitDescriptorCache TransientDescriptorCache(GetParentDevice(), MaxTasks /* Worker Count */); TransientDescriptorCache.Init(0, NumViewDescriptors, NumSamplerDescriptors, ERHIBindlessConfiguration::Minimal); TArray> BinderOps; BinderOps.SetNum(MaxTasks); TResourceArray LocalRootArgs; uint32 MinRootArgBufferSizeInDWords = (Pipeline->RootArgStrideInBytes / 4) * (Pipeline->MaxRootArgOffset + 1); LocalRootArgs.AddZeroed(MinRootArgBufferSizeInDWords); FAllocatedConstantBuffer SharedConstantBuffer(*this); SetShaderBundleSharedBindlessConstants(*this, SharedBindlessParameters, SharedConstantBuffer); auto RecordTask = [this, &LocalRootArgs, Pipeline, &TransientDescriptorCache, &ValidRecords, &Dispatches, &BinderOps, &SharedConstantBuffer](FTaskContext& Context, int32 RecordIndex) { uint32 DispatchIndex = ValidRecords[RecordIndex]; const FRHIShaderBundleGraphicsDispatch& Dispatch = Dispatches[DispatchIndex]; check(Dispatch.IsValid()); const uint32 ShaderTableIndex = RecordIndex + 1; check(Pipeline->RootArgOffsets.IsValidIndex(ShaderTableIndex)); const uint32 RootArgOffset = Pipeline->RootArgOffsets[ShaderTableIndex]; check((Pipeline->RootArgStrideInBytes / 4) * (RootArgOffset + 1) <= (uint32)LocalRootArgs.Num()); FRHIWorkGraphShader* MeshShader = Dispatch.PipelineInitializer.BoundShaderState.GetWorkGraphShader(); FD3D12WorkGraphShader* D3D12MeshShader = FD3D12DynamicRHI::ResourceCast(MeshShader); FRHIPixelShader* PixelShader = Dispatch.PipelineInitializer.BoundShaderState.GetPixelShader(); FD3D12PixelShader* D3D12PixelShader = FD3D12DynamicRHI::ResourceCast(PixelShader); FD3D12RootSignature const* LocalRootSignature = GetParentAdapter()->GetWorkGraphGraphicsRootSignature(Dispatch.PipelineInitializer.BoundShaderState); RecordBindings( *this, SF_Pixel, TransientDescriptorCache, BinderOps[Context.WorkerIndex], Context.WorkerIndex, PixelShader, D3D12PixelShader, *Dispatch.Parameters_PS, LocalRootSignature, SharedConstantBuffer, Dispatch.Constants, MakeArrayView(&LocalRootArgs[RootArgOffset * Pipeline->RootArgStrideInBytes / 4], Pipeline->RootArgStrideInBytes / 4) ); RecordBindings( *this, SF_Mesh, TransientDescriptorCache, BinderOps[Context.WorkerIndex], Context.WorkerIndex, MeshShader, D3D12MeshShader, *Dispatch.Parameters_MSVS, LocalRootSignature, SharedConstantBuffer, Dispatch.Constants, MakeArrayView(&LocalRootArgs[RootArgOffset * Pipeline->RootArgStrideInBytes / 4], Pipeline->RootArgStrideInBytes / 4) ); }; const int32 ItemsPerTask = 1024; ParallelForWithExistingTaskContext(TEXT("DispatchShaderBundle"), MakeArrayView(TaskContexts), ValidRecords.Num(), ItemsPerTask, RecordTask); // Apply Binder Ops { for (int32 WorkerIndex = 1; WorkerIndex < BinderOps.Num(); ++WorkerIndex) { for (FD3D12ShaderResourceView* SRV : BinderOps[WorkerIndex].TransitionSRVs) { BinderOps[0].AddResourceTransition(SRV); } for (FD3D12UnorderedAccessView* UAV : BinderOps[WorkerIndex].TransitionUAVs) { BinderOps[0].AddResourceTransition(UAV); } BinderOps[WorkerIndex].TransitionSRVs.Empty(); BinderOps[WorkerIndex].TransitionUAVs.Empty(); BinderOps[WorkerIndex].TransitionViewSet.Empty(); BinderOps[WorkerIndex].ClearUAVs.Empty(); BinderOps[WorkerIndex].TransitionClearSet.Empty(); } for (FD3D12UnorderedAccessView* UAV : BinderOps[0].ClearUAVs) { ClearShaderResources(UAV, EShaderParameterTypeMask::SRVMask); } //TODO: check resource view // for (FD3D12ShaderResourceView* SRV : BinderOps[0].TransitionSRVs) // { // TransitionResource(SRV, D3D12_RESOURCE_STATE_NON_PIXEL_SHADER_RESOURCE); // } // // for (FD3D12UnorderedAccessView* UAV : BinderOps[0].TransitionUAVs) // { // TransitionResource(UAV, D3D12_RESOURCE_STATE_UNORDERED_ACCESS); // } } FlushResourceBarriers(); // Create SRV for RecordArgsBuffer FD3D12Buffer* RecordArgBufferPtr = FD3D12DynamicRHI::ResourceCast(RecordArgBuffer); D3D12_SHADER_RESOURCE_VIEW_DESC SRVDesc = {}; SRVDesc.ViewDimension = D3D12_SRV_DIMENSION_BUFFER; SRVDesc.Format = DXGI_FORMAT_R32_TYPELESS; SRVDesc.Shader4ComponentMapping = D3D12_DEFAULT_SHADER_4_COMPONENT_MAPPING; SRVDesc.Buffer.FirstElement = 0; SRVDesc.Buffer.NumElements = RecordArgBufferPtr->GetSize() >> 2u; SRVDesc.Buffer.StructureByteStride = 0; SRVDesc.Buffer.Flags = D3D12_BUFFER_SRV_FLAG_RAW; TSharedPtr RecordArgBufferSRV; RecordArgBufferSRV = MakeShared(GetParentDevice(), nullptr); // Always single GPU object, so FirstLinkedObject is nullptr RecordArgBufferSRV->CreateView(RecordArgBufferPtr, SRVDesc, FD3D12ShaderResourceView::EFlags::None); // Gather root arguments for shader bundle entry node. uint32 RecordArgBufferBindlessHandle = 0; uint32 BindSlot = 0; D3D12_GPU_DESCRIPTOR_HANDLE ResourceDescriptorTableBaseGPU = {}; if (bBindlessResources) { RecordArgBufferBindlessHandle = RecordArgBufferSRV->GetBindlessHandle().GetIndex(); check(RecordArgBufferBindlessHandle != INDEX_NONE); StateCache.ApplyBindlessResources(nullptr, SF_Compute, SF_NumStandardFrequencies); } else { D3D12_CPU_DESCRIPTOR_HANDLE LocalSRVs[MAX_SRVS]; LocalSRVs[WorkGraphGlobalShader->RecordArgBufferParam.GetBaseIndex()] = RecordArgBufferSRV->GetOfflineCpuHandle(); const int32 DescriptorTableBaseIndex = TransientDescriptorCache.Allocate(LocalSRVs, 1, D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV, 0); ResourceDescriptorTableBaseGPU = TransientDescriptorCache.ViewHeap.GetDescriptorGPU(DescriptorTableBaseIndex); BindSlot = D3D12EntryShader->RootSignature->SRVRDTBindSlot(SF_Compute); const uint32 BindSlotOffset = D3D12EntryShader->RootSignature->GetBindSlotOffsetInBytes(BindSlot) / 4; TArrayView RootArgSlice = MakeArrayView(&LocalRootArgs[Pipeline->RootArgOffsets[0] * Pipeline->RootArgStrideInBytes / 4], Pipeline->RootArgStrideInBytes / 4); FMemory::Memcpy(&RootArgSlice[BindSlotOffset], &ResourceDescriptorTableBaseGPU, sizeof(ResourceDescriptorTableBaseGPU)); } // Upload local root arguments table. D3D12_GPU_VIRTUAL_ADDRESS_RANGE_AND_STRIDE NodeLocalRootArgumentsTable{ 0, 0, 0 }; if (LocalRootArgs.Num()) { const uint32 DataSize = LocalRootArgs.GetResourceDataSize(); // todo: Check if copy queue is the optimal way to upload the root args. // todo: Use a single buffer owned by the shader bundle RHI object (needs a copy operation that doesn't complain about multiple uploads). const D3D12_RESOURCE_DESC ResourceDesc = CD3DX12_RESOURCE_DESC::Buffer(DataSize, D3D12_RESOURCE_FLAG_NONE); const FRHIBufferCreateDesc CreateDesc = FRHIBufferCreateDesc::Create(TEXT("BundleRecordBuffer"), DataSize, 0, EBufferUsageFlags::Static) .SetInitialState(ERHIAccess::CopyDest) .SetGPUMask(FRHIGPUMask::FromIndex(GetParentDevice()->GetGPUIndex())); FD3D12Buffer* RootArgBuffer = GetParentDevice()->GetParentAdapter()->CreateRHIBuffer( ResourceDesc, 16, CreateDesc, ED3D12ResourceStateMode::MultiState, D3D12_RESOURCE_STATE_COPY_DEST, true ); BatchedSyncPoints.ToWait.Emplace(RootArgBuffer->UploadResourceDataViaCopyQueue(*this, &LocalRootArgs)); TransitionResource(RootArgBuffer->GetResource(), D3D12_RESOURCE_STATE_COPY_DEST, D3D12_RESOURCE_STATE_COMMON, 0); NodeLocalRootArgumentsTable.StartAddress = RootArgBuffer->ResourceLocation.GetGPUVirtualAddress(); NodeLocalRootArgumentsTable.SizeInBytes = RootArgBuffer->ResourceLocation.GetSize(); NodeLocalRootArgumentsTable.StrideInBytes = Pipeline->RootArgStrideInBytes; } // Apply the transient descriptor heaps. // Note this only uses transient heap for non-bindless. SetExplicitDescriptorCache(TransientDescriptorCache); // Set graphics state GraphicsCommandList()->SetGraphicsRootSignature(Pipeline->RootSignature); if (!bBindlessResources) { GraphicsCommandList()->SetGraphicsRootDescriptorTable(BindSlot, ResourceDescriptorTableBaseGPU); } GraphicsCommandList()->IASetVertexBuffers(0, D3D12_IA_VERTEX_INPUT_RESOURCE_SLOT_COUNT, nullptr); GraphicsCommandList()->IASetIndexBuffer(nullptr); GraphicsCommandList()->OMSetRenderTargets(0, nullptr, 0, nullptr); D3D12_VIEWPORT Viewport{}; { Viewport.TopLeftX = BundleState.ViewRect.Min.X; Viewport.TopLeftY = BundleState.ViewRect.Min.Y; Viewport.Width = BundleState.ViewRect.Width(); Viewport.Height = BundleState.ViewRect.Height(); Viewport.MinDepth = BundleState.DepthMin; Viewport.MaxDepth = BundleState.DepthMax; } GraphicsCommandList()->RSSetViewports(1, &Viewport); D3D12_RECT Rect{}; { Rect.left = BundleState.ViewRect.Min.X; Rect.top = BundleState.ViewRect.Min.Y; Rect.right = BundleState.ViewRect.Max.X; Rect.bottom = BundleState.ViewRect.Max.Y; } GraphicsCommandList()->RSSetScissorRects(1, &Rect); const D3D_PRIMITIVE_TOPOLOGY PrimitiveTopology = TranslatePrimitiveType(BundleState.PrimitiveType); GraphicsCommandList()->IASetPrimitiveTopology(PrimitiveTopology); GraphicsCommandList()->OMSetStencilRef(BundleState.StencilRef); GraphicsCommandList()->OMSetBlendFactor(BundleState.BlendFactor); // GraphicsCommandList()->OMSetDepthBounds(BundleState.DepthMin, BundleState.DepthMax); // Kick off the work graph D3D12_SET_PROGRAM_DESC SetProgramDesc = {}; SetProgramDesc.Type = D3D12_PROGRAM_TYPE_WORK_GRAPH; SetProgramDesc.WorkGraph.ProgramIdentifier = Pipeline->ProgramIdentifier; SetProgramDesc.WorkGraph.Flags = D3D12_SET_WORK_GRAPH_FLAG_INITIALIZE; SetProgramDesc.WorkGraph.BackingMemory = Pipeline->BackingMemoryAddressRange; SetProgramDesc.WorkGraph.NodeLocalRootArgumentsTable = NodeLocalRootArgumentsTable; GraphicsCommandList10()->SetProgram(&SetProgramDesc); Pipeline->FrameCounter.Set(GetFrameFenceCounter()); const uint32 RecordCount = Dispatches.Num(); FDispatchShaderBundleWorkGraph::FEntryNodeRecord InputRecord = FDispatchShaderBundleWorkGraph::MakeInputRecord(RecordCount, ShaderBundle->ArgOffset, ShaderBundle->ArgStride, RecordArgBufferBindlessHandle); if (!GShaderBundleSkipDispatch) { D3D12_DISPATCH_GRAPH_DESC DispatchGraphDesc = {}; DispatchGraphDesc.Mode = D3D12_DISPATCH_MODE_NODE_CPU_INPUT; DispatchGraphDesc.NodeCPUInput.EntrypointIndex = 0; DispatchGraphDesc.NodeCPUInput.NumRecords = 1; DispatchGraphDesc.NodeCPUInput.RecordStrideInBytes = sizeof(InputRecord); DispatchGraphDesc.NodeCPUInput.pRecords = &InputRecord; GraphicsCommandList10()->DispatchGraph(&DispatchGraphDesc); } // Restore global descriptor heaps if necessary. UnsetExplicitDescriptorCache(); // We did not write through the state cache, so we need to invalidate it so subsequent workloads correctly re-bind state StateCache.DirtyState(); ConditionalSplitCommandList(); #endif // D3D12_RHI_WORKGRAPHS }