// Copyright Epic Games, Inc. All Rights Reserved. #include "NNEModelData.h" #include "NNE.h" #include "Serialization/CustomVersion.h" #include "UObject/AssetRegistryTagsContext.h" #include "UObject/WeakInterfacePtr.h" namespace UE::NNE::ModelData { enum Version : uint32 { V0 = 0, // Initial V1 = 1, // TargetRuntimes and AssetImportData V2 = 2, // Re-arrange fields and store only ModelData in cooked assets V3 = 3, // Adding AdditionalFileData V4 = 4, // Support for > 2GB models // New versions can be added above this line VersionPlusOne, Latest = VersionPlusOne - 1 }; const FGuid GUID(0x9513202e, 0xeba1b279, 0xf17fe5ba, 0xab90c3f2); FCustomVersionRegistration NNEModelDataVersion(GUID, Version::Latest, TEXT("NNEModelDataVersion"));// Always save with the latest version FString GetRuntimesAsString(TArrayView Runtimes) { if (Runtimes.Num() == 0) { return TEXT("All"); } FString RuntimesAsOneString; bool bIsFirstRuntime = true; for (const FString& Runtime : Runtimes) { if (!bIsFirstRuntime) { RuntimesAsOneString += TEXT(", "); } RuntimesAsOneString += Runtime; bIsFirstRuntime = false; } return RuntimesAsOneString; } inline TSharedPtr CreateModelData(const FString& RuntimeName, const FString& FileType, const TArray64& FileData, const TMap>& AdditionalFileData, const FGuid& FileId, const ITargetPlatform* TargetPlatform) { TWeakInterfacePtr NNERuntime = UE::NNE::GetRuntime(RuntimeName); if (NNERuntime.IsValid()) { INNERuntime::ECanCreateModelDataStatus CanCreateModelDataStatus = NNERuntime->CanCreateModelData(FileType, FileData, AdditionalFileData, FileId, TargetPlatform); if (CanCreateModelDataStatus == INNERuntime::ECanCreateModelDataStatus::Ok) { return NNERuntime->CreateModelData(FileType, FileData, AdditionalFileData, FileId, TargetPlatform); } else if (CanCreateModelDataStatus == INNERuntime::ECanCreateModelDataStatus::FailFileIdNotSupported) { UE_LOG(LogNNE, Display, TEXT("Runtime %s does not support Filetype: %s, skipping the model data creation for model with id %s "), *RuntimeName, *FileType, *FileId.ToString(EGuidFormats::Digits).ToLower()); } else { UE_LOG(LogNNE, Warning, TEXT("Runtime %s cannot create the model data with id %s (Filetype: %s)"), *RuntimeName, *FileId.ToString(EGuidFormats::Digits).ToLower(), *FileType); } } else { UE_LOG(LogNNE, Error, TEXT("No runtime '%s' found. Valid runtimes are: "), *RuntimeName); TArray Runtimes = UE::NNE::GetAllRuntimeNames(); for (int32 i = 0; i < Runtimes.Num(); i++) { UE_LOG(LogNNE, Error, TEXT("- %s"), *Runtimes[i]); } } return TSharedPtr(); } } // UE::NNE::ModelData namespace UE::NNE { FSharedModelData::FSharedModelData(FSharedBuffer InData, uint32 InMemoryAlignment) : Data(InData), MemoryAlignment(InMemoryAlignment) { checkf(Data.IsOwned(), TEXT("InData data must be ownned!")); checkf(MemoryAlignment <= 1 || (((uintptr_t)(const void*)(InData.GetData())) % MemoryAlignment == 0), TEXT("InData must be aligned with InMemoryAlignment!")) } FSharedModelData::FSharedModelData() { MemoryAlignment = 0; } TConstArrayView64 FSharedModelData::GetView() const { return TConstArrayView64(static_cast(Data.GetData()), Data.GetSize()); } uint32 FSharedModelData::GetMemoryAlignment() const { return MemoryAlignment; } } // UE::NNE void UNNEModelData::GetAssetRegistryTags(FAssetRegistryTagsContext Context) const { Context.AddTag(FAssetRegistryTag("TargetRuntimes", UE::NNE::ModelData::GetRuntimesAsString(GetTargetRuntimes()), FAssetRegistryTag::TT_Alphabetical)); Super::GetAssetRegistryTags(Context); } void UNNEModelData::Serialize(FArchive& Ar) { // Store the asset version (no effect in load) Ar.UsingCustomVersion(UE::NNE::ModelData::GUID); if (Ar.IsSaving() || Ar.IsCountingMemory()) { bool bWriteModelData = true; if (Ar.IsCooking()) { // Optimize storage: FileData and AdditionalFileData are not required anymore because we have the model and can cook it for every runtime TArray TmpTargetRuntimes; Ar << TmpTargetRuntimes; FString TmpFileType; Ar << TmpFileType; TArray64 TmpFileData; Ar << TmpFileData; int32 NumAdditionalFileDataItems = 0; Ar << NumAdditionalFileDataItems; FGuid TmpGuid; Ar << TmpGuid; // Cooking must recreate all model data but only if file data is still available if (FileData.Num() > 0) { ModelData.Reset(); // No target runtime means all currently registered ones TArray> CookRuntimeNames; if (GetTargetRuntimes().IsEmpty()) { CookRuntimeNames.Append(UE::NNE::GetAllRuntimeNames()); } else { CookRuntimeNames.Append(GetTargetRuntimes()); } TMap> AdditionalFileDataView; TSet Keys; AdditionalFileData.GetKeys(Keys); for (auto& Key : Keys) { AdditionalFileDataView.Emplace(Key, AdditionalFileData[Key]); } for (const FString& RuntimeName : CookRuntimeNames) { TSharedPtr SharedModelData = UE::NNE::ModelData::CreateModelData(RuntimeName, FileType, FileData, AdditionalFileDataView, FileId, Ar.GetArchiveState().CookingTarget()); if (SharedModelData.IsValid() && SharedModelData->GetView().Num() > 0) { ModelData.Add(RuntimeName, SharedModelData); } } } } else { // Only cooked assets optimize storage Ar << TargetRuntimes; Ar << FileType; Ar << FileData; int32 NumAdditionalFileDataItems = AdditionalFileData.Num(); Ar << NumAdditionalFileDataItems; TSet Keys; AdditionalFileData.GetKeys(Keys); for (auto& Key : Keys) { Ar << Key; Ar << AdditionalFileData[Key]; } Ar << FileId; #if WITH_EDITOR // In editor (when not cooking), no model data is stored as model data can always be recreated and unnecessary data in subversion control should be avoided bWriteModelData = false; #endif //WITH_EDITOR } if (bWriteModelData) { TArray RuntimeNames; ModelData.GetKeys(RuntimeNames); int32 NumItems = RuntimeNames.Num(); Ar << NumItems; for (int32 i = 0; i < NumItems; i++) { Ar << RuntimeNames[i]; uint32 MemoryAlignment = ModelData[RuntimeNames[i]]->GetMemoryAlignment(); Ar << MemoryAlignment; uint64 DataSize = ModelData[RuntimeNames[i]]->GetView().Num(); Ar << DataSize; Ar.Serialize((void*)ModelData[RuntimeNames[i]]->GetView().GetData(), DataSize); } } else { int32 NumItems = 0; Ar << NumItems; } } else if (Ar.IsLoading()) { TObjectPtr AssetImportData; int32 NumItems = 0; int32 NumAdditionalFileDataItems = 0; FString Name; uint32 MemoryAlignment = 0; uint64 DataSize = 0; TArray64 Data; TArray Data_Size32Bits; void* RawData = nullptr; int32 Index = 0; switch (Ar.CustomVer(UE::NNE::ModelData::GUID)) { case UE::NNE::ModelData::Version::V0: TargetRuntimes.Empty(); Ar << FileType; Ar << Data_Size32Bits; FileData = MoveTemp(Data_Size32Bits); Ar << FileId; Ar << NumItems; for (Index = 0; Index < NumItems; Index++) { Ar << Name; Ar << Data_Size32Bits; ModelData.Add(Name, MakeShared(MakeSharedBufferFromArray(MoveTemp(Data_Size32Bits)), 0)); } UE_LOG(LogNNE, Warning, TEXT("[DEPRECATION] The asset %s (v0) is deprecated. Please right-click the asset and select 'Save' to update it to the latest version."), *this->GetName()); break; case UE::NNE::ModelData::Version::V1: TargetRuntimes.Empty(); if (!Ar.IsLoadingFromCookedPackage()) { Ar << TargetRuntimes; Ar << AssetImportData; } Ar << FileType; Ar << Data_Size32Bits; FileData = MoveTemp(Data_Size32Bits); Ar << FileId; Ar << NumItems; for (Index = 0; Index < NumItems; Index++) { Ar << Name; Ar << Data_Size32Bits; ModelData.Add(Name, MakeShared(MakeSharedBufferFromArray(MoveTemp(Data_Size32Bits)), 0)); } UE_LOG(LogNNE, Warning, TEXT("[DEPRECATION] The asset %s (v1) is deprecated. Please right-click the asset and select 'Save' to update it to the latest version."), *this->GetName()); break; case UE::NNE::ModelData::Version::V2: Ar << TargetRuntimes; Ar << FileType; Ar << Data_Size32Bits; FileData = MoveTemp(Data_Size32Bits); Ar << FileId; Ar << NumItems; for (Index = 0; Index < NumItems; Index++) { Ar << Name; Ar << MemoryAlignment; Ar << DataSize; RawData = FMemory::Malloc(DataSize, MemoryAlignment); Ar.Serialize(RawData, DataSize); ModelData.Add(Name, MakeShared(FSharedBuffer::TakeOwnership(RawData, DataSize, FMemory::Free), MemoryAlignment)); } break; case UE::NNE::ModelData::Version::V3: Ar << TargetRuntimes; Ar << FileType; Ar << Data_Size32Bits; FileData = MoveTemp(Data_Size32Bits); Ar << NumAdditionalFileDataItems; AdditionalFileData.Empty(); for (Index = 0; Index < NumAdditionalFileDataItems; Index++) { Ar << Name; Ar << Data_Size32Bits; Data = MoveTemp(Data_Size32Bits); AdditionalFileData.Add(Name, Data); } Ar << FileId; Ar << NumItems; for (Index = 0; Index < NumItems; Index++) { Ar << Name; Ar << MemoryAlignment; Ar << DataSize; RawData = FMemory::Malloc(DataSize, MemoryAlignment); Ar.Serialize(RawData, DataSize); ModelData.Add(Name, MakeShared(FSharedBuffer::TakeOwnership(RawData, DataSize, FMemory::Free), MemoryAlignment)); } break; case UE::NNE::ModelData::Version::V4: Ar << TargetRuntimes; Ar << FileType; Ar << FileData; Ar << NumAdditionalFileDataItems; AdditionalFileData.Empty(); for (Index = 0; Index < NumAdditionalFileDataItems; Index++) { Ar << Name; Ar << Data; AdditionalFileData.Add(Name, Data); } Ar << FileId; Ar << NumItems; for (Index = 0; Index < NumItems; Index++) { Ar << Name; Ar << MemoryAlignment; Ar << DataSize; RawData = FMemory::Malloc(DataSize, MemoryAlignment); Ar.Serialize(RawData, DataSize); ModelData.Add(Name, MakeShared(FSharedBuffer::TakeOwnership(RawData, DataSize, FMemory::Free), MemoryAlignment)); } break; default: UE_LOG(LogNNE, Error, TEXT("Unknown asset version %d: Deserialisation failed, please reimport the original model."), Ar.CustomVer(UE::NNE::ModelData::GUID)); break; } } } void UNNEModelData::Init(const FString& Type, TConstArrayView64 Buffer, const TMap>& AdditionalBuffers) { TargetRuntimes.Empty(); FileType = Type; FileData = Buffer; AdditionalFileData.Empty(); for (auto& Element : AdditionalBuffers) { AdditionalFileData.Emplace(Element.Key, AdditionalBuffers[Element.Key]); } FPlatformMisc::CreateGuid(FileId); ModelData.Empty(); } TArrayView UNNEModelData::GetTargetRuntimes() const { return TargetRuntimes; } void UNNEModelData::SetTargetRuntimes(TArrayView RuntimeNames) { TargetRuntimes = RuntimeNames; if (RuntimeNames.Num() > 0) { TArray> CookedRuntimes; ModelData.GetKeys(CookedRuntimes); for (const FString& Runtime : CookedRuntimes) { if (!TargetRuntimes.Contains(Runtime)) { ModelData.Remove(Runtime); } } ModelData.Compact(); } } FString UNNEModelData::GetFileType() const { return FileType; } TConstArrayView64 UNNEModelData::GetFileData() const { return FileData; } TConstArrayView64 UNNEModelData::GetAdditionalFileData(const FString& Key) const { return AdditionalFileData.Contains(Key) ? AdditionalFileData[Key] : TConstArrayView(); } void UNNEModelData::ClearFileDataAndFileType() { FileType = ""; FileData.Empty(); AdditionalFileData.Empty(); } FGuid UNNEModelData::GetFileId() const { return FileId; } TSharedPtr UNNEModelData::GetModelData(const FString& RuntimeName) { // Check model data is supporting the requested target runtime TArrayView TargetRuntimesNames = GetTargetRuntimes(); if (!TargetRuntimesNames.IsEmpty() && !TargetRuntimesNames.Contains(RuntimeName)) { UE_LOG(LogNNE, Error, TEXT("Runtime '%s' is not among the target runtimes. Target runtimes are: "), *RuntimeName); for (const FString& TargetRuntimesName : TargetRuntimesNames) { UE_LOG(LogNNE, Error, TEXT("- %s"), *TargetRuntimesName); } return TSharedPtr(); } // Check if we have a local cache hit TSharedPtr* LocalDataPtr = ModelData.Find(RuntimeName); if (LocalDataPtr) { return *LocalDataPtr; } // After this point FileData is required to either get the cache id or recreate it from scratch if (FileData.Num() < 1) { UE_LOG(LogNNE, Error, TEXT("Cannot create model data from empty file data.")); return TSharedPtr(); } TMap> AdditionalFileDataView; TSet Keys; AdditionalFileData.GetKeys(Keys); for (auto& Key : Keys) { AdditionalFileDataView.Emplace(Key, AdditionalFileData[Key]); } // Try to create the model TSharedPtr CreatedData = UE::NNE::ModelData::CreateModelData(RuntimeName, FileType, FileData, AdditionalFileDataView, FileId, nullptr); if (!CreatedData.IsValid() || CreatedData->GetView().Num() < 1) { return TSharedPtr(); } // Cache the model ModelData.Add(RuntimeName, CreatedData); return CreatedData; } void UNNEModelData::ClearModelData() { ModelData.Empty(); }