Files
UnrealEngine/Engine/Plugins/Experimental/NNERuntimeRDG/Source/NNERuntimeRDGData/Internal/NNERuntimeRDGDataAttributeTensor.h
2025-05-18 13:04:45 +08:00

61 lines
1.3 KiB
C++

// Copyright Epic Games, Inc. All Rights Reserved.
#pragma once
#include "Containers/ArrayView.h"
#include "NNETypes.h"
#include "Serialization/Archive.h"
namespace UE::NNERuntimeRDGData::Internal
{
/**
 * Self-contained tensor value: an element data type, a shape, and the raw
 * bytes of the elements. Instances are built with Make() and can be streamed
 * through an FArchive with operator<<.
 */
class FAttributeTensor
{
protected:
	/**
	 * Element type of the tensor. Explicitly initialised so that a
	 * default-constructed tensor (for example one that is about to be filled
	 * from an archive) never exposes an indeterminate value via GetDataType().
	 */
	ENNETensorDataType DataType = ENNETensorDataType::None;

	/** Extent of each dimension; inline storage is sized for NNE::FTensorShape::MaxRank entries. */
	TArray<uint32, TInlineAllocator<NNE::FTensorShape::MaxRank>> Shape;

	/** Raw bytes of the tensor elements. */
	TArray<uint8> Data;

public:
	/** @return The element data type stored in this tensor. */
	ENNETensorDataType GetDataType() const
	{
		return DataType;
	}

	/** @return A read-only view over the raw bytes owned by this tensor. */
	TConstArrayView<uint8> GetData() const
	{
		return Data;
	}

	/** @return The stored dimensions converted to an NNE::FTensorShape. */
	NNE::FTensorShape GetShape() const
	{
		return NNE::FTensorShape::Make(Shape);
	}

	/**
	 * Builds a tensor that holds its own copy of the given shape and bytes.
	 *
	 * @param Shape    Dimensions of the tensor; its volume must fit in a uint32.
	 * @param DataType Element type of the tensor.
	 * @param Data     Raw element bytes; the byte count must equal the element
	 *                 size of DataType multiplied by the volume of Shape.
	 * @return The newly constructed tensor.
	 *
	 * Both requirements are enforced with check().
	 */
	static FAttributeTensor Make(const NNE::FTensorShape& Shape,
		ENNETensorDataType DataType,
		TConstArrayView<uint8> Data)
	{
		const uint64 Volume = Shape.Volume();
		check(Volume <= TNumericLimits<uint32>::Max());
		check(Data.NumBytes() == NNE::GetTensorDataTypeSizeInBytes(DataType) * Volume);

		// The arrays of a freshly constructed tensor are already empty, so the
		// inputs can be appended directly without resetting first.
		FAttributeTensor Tensor;
		Tensor.DataType = DataType;
		Tensor.Shape.Append(Shape.GetData());
		Tensor.Data.Append(Data);
		return Tensor;
	}

	/**
	 * Streams the tensor through the archive. The fields are visited in the
	 * order DataType, Shape, Data; that order defines the serialized layout
	 * and must not be changed.
	 */
	friend FArchive& operator<<(FArchive& Ar, FAttributeTensor& Tensor)
	{
		Ar << Tensor.DataType;
		Ar << Tensor.Shape;
		Ar << Tensor.Data;
		return Ar;
	}
};
} // namespace UE::NNERuntimeRDGData::Internal