// Copyright Epic Games, Inc. All Rights Reserved. #include "MassEntityManager.h" #include "MassProcessingContext.h" #include "MassEntityTestTypes.h" #include "MassEntitySettings.h" #include "MassExecutor.h" #include "MassEntityView.h" #include "MassExecutionContext.h" #include "MassProcessingContext.h" #define LOCTEXT_NAMESPACE "MassTest" UE_DISABLE_OPTIMIZATION_SHIP //----------------------------------------------------------------------// // tests //----------------------------------------------------------------------// namespace FMassMultiThreadingTest { template static FName GetProcessorName() { return T::StaticClass()->GetFName(); } struct FMTTestBase : FEntityTestBase { using Super = FExecutionTestBase; UMassCompositeProcessor* CompositeProcessor = nullptr; TArray Processors; TArray Result; FGraphEventRef FinishEvent; TArray Entities; virtual bool Update() override { // cannot do this in SetUp without adding a new virtual function for subtests to override if (CompositeProcessor == nullptr) { CompositeProcessor = NewObject(); CompositeProcessor->SetGroupName(TEXT("Test")); CompositeProcessor->SetProcessors(MakeArrayView((UMassProcessor**)Processors.GetData(), Processors.Num())); FMassProcessingContext Context(*EntityManager); FinishEvent = UE::Mass::Executor::TriggerParallelTasks(*CompositeProcessor, MoveTemp(Context), []() {}); } if (FinishEvent->IsComplete()) { // signal that we're done with this test return true; } return false; } }; struct FMTTrivial : FMTTestBase { using Super = FMTTestBase; const int32 NumToCreate = 200; int32 NumProcessed = 0; virtual bool SetUp() override { if (!Super::SetUp()) { return false; } EntityManager->BatchCreateEntities(IntsArchetype, NumToCreate, Entities); Processors.Reset(); { UMassTestProcessorBase* Proc = Processors.Add_GetRef(NewTestProcessor(EntityManager)); Proc->EntityQuery.AddRequirement(EMassFragmentAccess::ReadOnly); Proc->ForEachEntityChunkExecutionFunction = [this](FMassExecutionContext& Context) { NumProcessed += Context.GetNumEntities(); }; } return true; } virtual void VerifyLatentResults() override { AITEST_EQUAL_LATENT("Expected to process all the created entities.", NumToCreate, NumProcessed); } }; IMPLEMENT_AI_LATENT_TEST(FMTTrivial, "System.Mass.Multithreading.Trivial"); struct FMTBasic : FMTTestBase { using Super = FMTTestBase; const int32 NumToCreate = 200; int32 NumProcessed = 0; virtual bool SetUp() override { if (!Super::SetUp()) { return false; } EntityManager->BatchCreateEntities(FloatsIntsArchetype, NumToCreate, Entities); Processors.Reset(); { UMassTestProcessorBase* Proc = Processors.Add_GetRef(NewTestProcessor(EntityManager)); Proc->GetMutableExecutionOrder().ExecuteAfter.Add(GetProcessorName()); Proc->EntityQuery.AddRequirement(EMassFragmentAccess::ReadOnly); Proc->EntityQuery.AddRequirement(EMassFragmentAccess::ReadWrite); Proc->ForEachEntityChunkExecutionFunction = [this](FMassExecutionContext& Context) { const TArrayView IntsList = Context.GetMutableFragmentView(); const TConstArrayView FloatsList = Context.GetFragmentView(); for (int32 i = 0; i < Context.GetNumEntities(); ++i) { IntsList[i].Value = int(FloatsList[i].Value) + IntsList[i].Value; } }; } { UMassTestProcessorBase* Proc = Processors.Add_GetRef(NewTestProcessor(EntityManager)); Proc->GetMutableExecutionOrder().ExecuteAfter.Add(GetProcessorName()); Proc->EntityQuery.AddRequirement(EMassFragmentAccess::ReadOnly); Proc->EntityQuery.AddRequirement(EMassFragmentAccess::ReadWrite); Proc->ForEachEntityChunkExecutionFunction = [this](FMassExecutionContext& Context) { const TConstArrayView IntsList = Context.GetFragmentView(); const TArrayView FloatsList = Context.GetMutableFragmentView(); for (int32 i = 0; i < Context.GetNumEntities(); ++i) { FloatsList[i].Value = float(IntsList[i].Value * IntsList[i].Value); } }; } { UMassTestProcessorBase* Proc = Processors.Add_GetRef(NewTestProcessor(EntityManager)); Proc->EntityQuery.AddRequirement(EMassFragmentAccess::ReadWrite); Proc->ForEachEntityChunkExecutionFunction = [this](FMassExecutionContext& Context) { int Index = 0; const TArrayView IntsList = Context.GetMutableFragmentView(); for (int32 i = 0; i < Context.GetNumEntities(); ++i) { IntsList[i].Value = Index++; } }; } return true; } virtual void VerifyLatentResults() override { for (int i = 0; i < Entities.Num(); ++i) { FMassEntityView View(FloatsIntsArchetype, Entities[i]); AITEST_EQUAL_LATENT(TEXT("Should have predicted values"), View.GetFragmentData().Value, i*i + i); } } }; IMPLEMENT_AI_LATENT_TEST(FMTBasic, "System.Mass.Multithreading.Basic"); } // FMassMultiThreadingTest UE_ENABLE_OPTIMIZATION_SHIP #undef LOCTEXT_NAMESPACE