263 lines
10 KiB
C++
263 lines
10 KiB
C++
//-------------------------------------------------------------------------------------------------------------------------------------------------------------
|
|
//
|
|
// Copyright 2023-2024 Apple Inc.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
//
|
|
//-------------------------------------------------------------------------------------------------------------------------------------------------------------
|
|
|
|
#pragma once
|
|
|
|
#ifndef __METAL_VERSION__
|
|
#ifdef __cplusplus
|
|
#include <cstdint>
|
|
extern "C" {
|
|
#else
|
|
#include <stdint.h>
|
|
#endif // __cplusplus
|
|
#else
|
|
#include <metal_stdlib>
|
|
using metal::visible_function_table;
|
|
using metal::MTLDispatchThreadgroupsIndirectArguments;
|
|
#endif // __METAL_VERSION__
|
|
|
|
|
|
#ifdef __METAL_VERSION__
|
|
#define IR_CONSTANT_PTR(ptr) constant ptr*
|
|
#define IR_DEVICE_PTR(ptr) device ptr*
|
|
#else
|
|
#define IR_CONSTANT_PTR(ptr) uint64_t
|
|
#define IR_DEVICE_PTR(ptr) uint64_t
|
|
#endif // __METAL_VERSION__
|
|
|
|
typedef struct IRShaderIdentifier
|
|
{
|
|
// For HitGroups, index into visible function table containing a converted
|
|
// intersection function.
|
|
uint64_t intersectionShaderHandle;
|
|
// For ray generation, miss, callable shaders, index into visible function
|
|
// table containing the translated function. For HitGroups, index to the
|
|
// converted closest-hit shader.
|
|
uint64_t shaderHandle;
|
|
// GPU address to a buffer containing static samplers for shader records
|
|
uint64_t localRootSignatureSamplersBuffer;
|
|
// Unused
|
|
uint64_t pad0;
|
|
|
|
#if !defined(__METAL_VERSION__) && (__cplusplus)
|
|
IRShaderIdentifier()
|
|
: intersectionShaderHandle(0)
|
|
, shaderHandle(0)
|
|
, localRootSignatureSamplersBuffer(0)
|
|
, pad0(0)
|
|
{
|
|
}
|
|
#endif
|
|
} IRShaderIdentifier;
|
|
|
|
typedef struct IRVirtualAddressRange
|
|
{
|
|
IR_CONSTANT_PTR(IRShaderIdentifier) StartAddress;
|
|
uint64_t SizeInBytes;
|
|
} IRVirtualAddressRange;
|
|
|
|
typedef struct IRVirtualAddressRangeAndStride
|
|
{
|
|
IR_CONSTANT_PTR(IRShaderIdentifier) StartAddress;
|
|
uint64_t SizeInBytes;
|
|
uint64_t StrideInBytes;
|
|
} IRVirtualAddressRangeAndStride;
|
|
|
|
typedef struct IRDispatchRaysDescriptor
|
|
{
|
|
IRVirtualAddressRange RayGenerationShaderRecord;
|
|
IRVirtualAddressRangeAndStride MissShaderTable;
|
|
IRVirtualAddressRangeAndStride HitGroupTable;
|
|
IRVirtualAddressRangeAndStride CallableShaderTable;
|
|
uint Width;
|
|
uint Height;
|
|
uint Depth;
|
|
} IRDispatchRaysDescriptor;
|
|
|
|
#ifdef __METAL_VERSION__
|
|
struct IRDispatchRaysArgument;
|
|
struct top_level_global_ab;
|
|
using top_level_local_ab = uint8_t;
|
|
struct res_desc_heap_ab;
|
|
struct smp_desc_heap_ab;
|
|
using RaygenFunctionType = void(constant top_level_global_ab*, constant top_level_local_ab*, constant res_desc_heap_ab*, constant smp_desc_heap_ab*, constant IRDispatchRaysArgument*, uint3);
|
|
#define RaygenFunctionPointerTable metal::visible_function_table<RaygenFunctionType>
|
|
#define IFT metal::raytracing::intersection_function_table<>
|
|
#define MSLAccelerationStructure metal::raytracing::instance_acceleration_structure
|
|
#else
|
|
#define RaygenFunctionPointerTable resourceid_t
|
|
#define IFT resourceid_t
|
|
#define MSLAccelerationStructure uint64_t
|
|
#endif
|
|
|
|
typedef struct IRDispatchRaysArgument
|
|
{
|
|
IRDispatchRaysDescriptor DispatchRaysDesc;
|
|
IR_CONSTANT_PTR(top_level_global_ab) GRS;
|
|
IR_CONSTANT_PTR(res_desc_heap_ab) ResDescHeap;
|
|
IR_CONSTANT_PTR(smp_desc_heap_ab) SmpDescHeap;
|
|
RaygenFunctionPointerTable VisibleFunctionTable;
|
|
IFT IntersectionFunctionTable;
|
|
IR_CONSTANT_PTR(IFT) IntersectionFunctionTables;
|
|
} IRDispatchRaysArgument;
|
|
|
|
#ifdef IR_RUNTIME_METALCPP
|
|
typedef MTL::DispatchThreadgroupsIndirectArguments dispatchthreadgroupsindirectargs_t;
|
|
#else
|
|
typedef MTLDispatchThreadgroupsIndirectArguments dispatchthreadgroupsindirectargs_t;
|
|
#endif // IR_RUNTIME_METAL_CPP
|
|
|
|
typedef struct IRRaytracingAccelerationStructureGPUHeader
|
|
{
|
|
MSLAccelerationStructure accelerationStructureID;
|
|
IR_DEVICE_PTR(uint32_t) addressOfInstanceContributions;
|
|
uint64_t pad0[4];
|
|
dispatchthreadgroupsindirectargs_t pad1;
|
|
} IRRaytracingAccelerationStructureGPUHeader;
|
|
|
|
typedef struct IRRaytracingInstanceDescriptor
|
|
{
|
|
float Transform[3][4];
|
|
uint32_t InstanceID : 24;
|
|
uint32_t InstanceMask : 8;
|
|
uint32_t InstanceContributionToHitGroupIndex : 24;
|
|
uint32_t Flags : 8;
|
|
#ifndef __METAL_VERSION__
|
|
uint64_t AccelerationStructure;
|
|
#else
|
|
metal::raytracing::instance_acceleration_structure AccelerationStructure;
|
|
#endif // __METAL_VERSION__
|
|
} IRRaytracingInstanceDescriptor;
|
|
|
|
#ifdef __METAL_VERSION__
|
|
void IRRaytracingUpdateInstanceContributions(IRRaytracingAccelerationStructureGPUHeader header,
|
|
device IRRaytracingInstanceDescriptor* instanceDescriptor,
|
|
uint32_t index);
|
|
|
|
#ifdef IR_PRIVATE_IMPLEMENTATION
|
|
void IRRaytracingUpdateInstanceContributions(IRRaytracingAccelerationStructureGPUHeader header,
|
|
device IRRaytracingInstanceDescriptor* instanceDescriptor,
|
|
uint32_t index)
|
|
{
|
|
header.addressOfInstanceContributions[index] = instanceDescriptor[index].InstanceContributionToHitGroupIndex;
|
|
}
|
|
#endif // IR_PRIVATE_IMPLEMENTATION
|
|
#endif // __METAL_VERSION__
|
|
|
|
#ifndef __METAL_VERSION__
|
|
|
|
extern const char* kIRRayDispatchIndirectionKernelName;
|
|
extern const uint64_t kIRRayDispatchArgumentsBindPoint;
|
|
|
|
/**
|
|
* Encode an acceleration structure into the argument buffer.
|
|
* @param entry the pointer to the descriptor table entry to encode the acceleration structure reference into.
|
|
* @param gpu_va the GPU address of the acceleration structure to encode.
|
|
**/
|
|
void IRDescriptorTableSetAccelerationStructure(IRDescriptorTableEntry* entry, uint64_t gpu_va);
|
|
|
|
|
|
/**
|
|
* Encode an instance acceleration structure into a buffer and instance contributions into a separate one.
|
|
* @param headerBuffer pointer to an address where to encode the acceleration structure.
|
|
* @param accelerationStructure resource ID of the instance acceleration structure to encode.
|
|
* @param instanceContributionArrayBuffer pointer to an address where to encode the acceleration structure instance contributions.
|
|
* @param instanceContributions array of instance contributions to hit group index.
|
|
* @param instanceCount number of elements in the instanceContributions array.
|
|
*/
|
|
void IRRaytracingSetAccelerationStructure(uint8_t* headerBuffer,
|
|
resourceid_t accelerationStructure,
|
|
uint8_t* instanceContributionArrayBuffer,
|
|
const uint32_t* instanceContributions,
|
|
uinteger_t instanceCount) IR_OVERLOADABLE;
|
|
|
|
/**
|
|
* Initialize a shader identifier to reference a ray generation, closest-hit, any-hit, miss, or callable shader without a
|
|
* custom intersection function.
|
|
* @param identifier shader identifier to initialize.
|
|
* @param shaderHandle shader handle, corresponding to the index into a visible function table of converted functions.
|
|
*/
|
|
void IRShaderIdentifierInit(IRShaderIdentifier* identifier, uint64_t shaderHandle) IR_OVERLOADABLE;
|
|
|
|
/**
|
|
* Initialize a shader identifier for a HitGroup, providing the closest-hit shader handle and a custom intersection shader handle.
|
|
* @param identifier shader identifier to initialize.
|
|
* @param shaderHandle handle to closest-hit shader, corresponding to the index into a visible function table of converted functions.
|
|
* @param intersectionShaderHandle handle to a custom any-hit, intersection, or combined any-hit and intersection function, corresponding to the index into a visible function table of converted functions.
|
|
*/
|
|
void IRShaderIdentifierInitWithCustomIntersection(IRShaderIdentifier* identifier, uint64_t shaderHandle, uint64_t intersectionShaderHandle) IR_OVERLOADABLE;
|
|
|
|
#ifdef IR_PRIVATE_IMPLEMENTATION
|
|
|
|
const char* kIRRayDispatchIndirectionKernelName = "RaygenIndirection";
|
|
const uint64_t kIRRayDispatchArgumentsBindPoint = 3;
|
|
|
|
IR_INLINE
|
|
void IRShaderIdentifierInit(IRShaderIdentifier* identifier, uint64_t shaderHandle) IR_OVERLOADABLE
|
|
{
|
|
memset(identifier, 0x0, sizeof(IRShaderIdentifier));
|
|
identifier->shaderHandle = shaderHandle;
|
|
}
|
|
|
|
IR_INLINE
|
|
void IRShaderIdentifierInitWithCustomIntersection(IRShaderIdentifier* identifier, uint64_t shaderHandle, uint64_t intersectionShaderHandle) IR_OVERLOADABLE
|
|
{
|
|
memset(identifier, 0x0, sizeof(IRShaderIdentifier));
|
|
identifier->intersectionShaderHandle = intersectionShaderHandle;
|
|
identifier->shaderHandle = shaderHandle;
|
|
}
|
|
|
|
IR_INLINE
|
|
void IRDescriptorTableSetAccelerationStructure(IRDescriptorTableEntry* entry, uint64_t gpu_va)
|
|
{
|
|
entry->gpuVA = gpu_va;
|
|
entry->textureViewID = 0;
|
|
entry->metadata = 0;
|
|
}
|
|
|
|
IR_INLINE
|
|
void IRRaytracingSetAccelerationStructure(uint8_t* headerBuffer,
|
|
resourceid_t accelerationStructure,
|
|
uint8_t* instanceContributionArrayBuffer,
|
|
const uint32_t* instanceContributions,
|
|
uinteger_t instanceCount) IR_OVERLOADABLE
|
|
{
|
|
IRRaytracingAccelerationStructureGPUHeader* header = (IRRaytracingAccelerationStructureGPUHeader*)headerBuffer;
|
|
header->accelerationStructureID = accelerationStructure._impl;
|
|
header->addressOfInstanceContributions = (uint64_t)instanceContributionArrayBuffer;
|
|
uint32_t* bufferInstanceContributions = (uint32_t*)instanceContributionArrayBuffer;
|
|
for (uinteger_t i = 0; i < instanceCount; ++i)
|
|
{
|
|
bufferInstanceContributions[i] = instanceContributions[i];
|
|
}
|
|
}
|
|
|
|
#endif // IR_PRIVATE_IMPLEMENTATION
|
|
|
|
#endif // __METAL_VERSION__
|
|
|
|
#ifndef __METAL_VERSION__
|
|
#ifdef __cplusplus
|
|
}
|
|
#endif //__cplusplus
|
|
#endif // __METAL_VERSION
|
|
|
|
#undef RaygenFunctionPointerTable
|
|
#undef IFT
|
|
#undef IR_CONSTANT_PTR
|