// Copyright Epic Games, Inc. All Rights Reserved.
|
|
|
|
#include "NNEHlslShadersElementWiseBinaryCS.h"
|
|
#include "NNEHlslShadersLog.h"
|
|
|
|
namespace UE::NNEHlslShaders::Internal
|
|
{
|
|
void TElementWiseBinaryCS::ModifyCompilationEnvironment(const FGlobalShaderPermutationParameters& InParameters, FShaderCompilerEnvironment& OutEnvironment)
|
|
{
|
|
FGlobalShader::ModifyCompilationEnvironment(InParameters, OutEnvironment);
|
|
|
|
OutEnvironment.SetDefine(TEXT("THREADGROUP_SIZE_X"), FElementWiseBinaryConstants::NUM_GROUP_THREADS);
|
|
|
|
FPermutationDomain PermutationVector(InParameters.PermutationId);
|
|
|
|
const FString OpFunc = GetOpFunc(PermutationVector.Get<FOperatorType>());
|
|
|
|
OutEnvironment.SetDefine(TEXT("ELEMENTWISE_OP(X,Y)"), *OpFunc);
|
|
}
|
|
|
|
/**
 * Returns the expression text (in terms of the macro arguments X and Y) that
 * implements the given element-wise binary operator.
 *
 * @param OpType Operator to look up.
 * @return The expression string for OpType, or an empty string when the
 *         operator has no entry in the table or OpType lies outside
 *         [0, EElementWiseBinaryOperatorType::MAX). A warning is logged in
 *         both of those cases.
 */
const FString TElementWiseBinaryCS::GetOpFunc(EElementWiseBinaryOperatorType OpType)
{
	constexpr int32 NumOps = static_cast<int32>(EElementWiseBinaryOperatorType::MAX);
	const int32 OpIndex = static_cast<int32>(OpType);

	// Default-constructed FStrings are empty, so operators without an OP(...)
	// line below keep an empty entry and trigger the warning further down.
	FString OpTable[NumOps];

	// NOTE(review): the commented-out entries are kept as found. Before enabling
	// them, check the expressions: LessOrEqual is written with '<', and the last
	// 'Or' entry uses '^=' (presumably intended for an Xor operator -- confirm).
#define OP(OpName, OpFunc) OpTable[static_cast<int32>(EElementWiseBinaryOperatorType::OpName)] = OpFunc
	OP(Add, TEXT("((X)+(Y))"));
	//OP(And, TEXT("((X)&&(Y))"));
	OP(Div, TEXT("((X)/(Y))"));
	//OP(Equal, TEXT("((X)==(Y))"));
	//OP(Greater, TEXT("((X)>(Y))"));
	//OP(GreaterOrEqual, TEXT("((X)>=(Y))"));
	//OP(Less, TEXT("((X)<(Y))"));
	//OP(LessOrEqual, TEXT("((X)<(Y))"));
	OP(Mod, TEXT("((X)%(Y))"));
	OP(Mul, TEXT("((X)*(Y))"));
	//OP(Or, TEXT("((X)||(Y))"));
	OP(Prelu, TEXT("prelu(X,Y)"));
	OP(Pow, TEXT("safe_pow(X,Y)"));
	OP(Sub, TEXT("((X)-(Y))"));
	//OP(Or, TEXT("((X)^=(Y))"));
#undef OP

	// Guard the lookup: MAX itself (or any other out-of-range value) must not
	// index past the end of the table.
	const bool bIndexInRange = (OpIndex >= 0) && (OpIndex < NumOps);
	FString OpFunc = bIndexInRange ? OpTable[OpIndex] : FString();

	if (OpFunc.IsEmpty())
	{
		UE_LOG(LogNNERuntimeRDGHlsl, Warning, TEXT("Undefined ElementWise Binary operator name for operator:%d"), OpIndex);
	}

	return OpFunc;
}
|
|
|
|
// Binds TElementWiseBinaryCS to the "ElementWiseBinary" entry point of
// NNEHlslShadersElementWiseBinary.usf, with SF_Compute as the shader frequency.
IMPLEMENT_GLOBAL_SHADER(TElementWiseBinaryCS, "/NNEHlslShaders/NNEHlslShadersElementWiseBinary.usf", "ElementWiseBinary", SF_Compute);
|
|
} // UE::NNEHlslShaders::Internal
|