3609 lines
120 KiB
C++
3609 lines
120 KiB
C++
// Copyright Epic Games, Inc. All Rights Reserved.
|
|
|
|
#include "LearningAction.h"
|
|
|
|
#include "LearningRandom.h"
|
|
|
|
#include "NNERuntimeBasicCpuBuilder.h"
|
|
|
|
namespace UE::Learning::Action
|
|
{
|
|
namespace Private
|
|
{
|
|
static inline bool ContainsDuplicates(const TArrayView<const FName> ElementNames)
|
|
{
|
|
TSet<FName, DefaultKeyFuncs<FName>, TInlineSetAllocator<32>> ElementNameSet;
|
|
ElementNameSet.Append(ElementNames);
|
|
return ElementNames.Num() != ElementNameSet.Num();
|
|
}
|
|
|
|
static inline bool CheckAllValid(const FSchema& Schema, const TArrayView<const FSchemaElement> Elements)
|
|
{
|
|
for (const FSchemaElement SubElement : Elements)
|
|
{
|
|
if (!Schema.IsValid(SubElement)) { return false; }
|
|
}
|
|
return true;
|
|
}
|
|
|
|
static inline int32 GetMaxActionVectorSize(const FSchema& Schema, const TArrayView<const FSchemaElement> Elements)
|
|
{
|
|
int32 Size = 0;
|
|
for (const FSchemaElement SubElement : Elements)
|
|
{
|
|
Size = FMath::Max(Size, Schema.GetActionVectorSize(SubElement));
|
|
}
|
|
return Size;
|
|
}
|
|
|
|
static inline int32 GetTotalActionVectorSize(const FSchema& Schema, const TArrayView<const FSchemaElement> Elements)
|
|
{
|
|
int32 Size = 0;
|
|
for (const FSchemaElement SubElement : Elements)
|
|
{
|
|
Size += Schema.GetActionVectorSize(SubElement);
|
|
}
|
|
return Size;
|
|
}
|
|
|
|
static inline int32 GetTotalEncodedActionVectorSize(const FSchema& Schema, const TArrayView<const FSchemaElement> Elements)
|
|
{
|
|
int32 Size = 0;
|
|
for (const FSchemaElement SubElement : Elements)
|
|
{
|
|
Size += Schema.GetEncodedVectorSize(SubElement);
|
|
}
|
|
return Size;
|
|
}
|
|
|
|
static inline int32 GetTotalActionDistributionVectorSize(const FSchema& Schema, const TArrayView<const FSchemaElement> Elements)
|
|
{
|
|
int32 Size = 0;
|
|
for (const FSchemaElement SubElement : Elements)
|
|
{
|
|
Size += Schema.GetActionDistributionVectorSize(SubElement);
|
|
}
|
|
return Size;
|
|
}
|
|
|
|
static inline int32 GetTotalActionModifierVectorSize(const FSchema& Schema, const TArrayView<const FSchemaElement> Elements)
|
|
{
|
|
int32 Size = 0;
|
|
for (const FSchemaElement SubElement : Elements)
|
|
{
|
|
Size += Schema.GetActionModifierVectorSize(SubElement);
|
|
}
|
|
return Size;
|
|
}
|
|
|
|
static inline bool CheckAllValid(const FObject& Object, const TArrayView<const FObjectElement> Elements)
|
|
{
|
|
for (const FObjectElement SubElement : Elements)
|
|
{
|
|
if (!Object.IsValid(SubElement)) { return false; }
|
|
}
|
|
return true;
|
|
}
|
|
|
|
static inline bool CheckPriorProbabilitiesExclusive(const TArrayView<const float> PriorProbabilities, const float Epsilon = UE_KINDA_SMALL_NUMBER)
|
|
{
|
|
if (PriorProbabilities.Num() == 0) { return true; }
|
|
|
|
for (int32 Idx = 0; Idx < PriorProbabilities.Num(); Idx++)
|
|
{
|
|
if (PriorProbabilities[Idx] < 0.0f || PriorProbabilities[Idx] > 1.0f)
|
|
{
|
|
return false;
|
|
}
|
|
}
|
|
|
|
float Total = 0.0f;
|
|
for (int32 Idx = 0; Idx < PriorProbabilities.Num(); Idx++)
|
|
{
|
|
Total += PriorProbabilities[Idx];
|
|
}
|
|
|
|
return FMath::Abs(Total - 1.0f) < Epsilon;
|
|
}
|
|
|
|
static inline bool CheckPriorProbabilitiesInclusive(const TArrayView<const float> PriorProbabilities)
|
|
{
|
|
if (PriorProbabilities.Num() == 0) { return true; }
|
|
|
|
for (int32 Idx = 0; Idx < PriorProbabilities.Num(); Idx++)
|
|
{
|
|
if (PriorProbabilities[Idx] < 0.0f || PriorProbabilities[Idx] > 1.0f)
|
|
{
|
|
return false;
|
|
}
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
static inline bool CheckAllValid(const FModifier& Object, const TArrayView<const FModifierElement> Elements)
|
|
{
|
|
for (const FModifierElement SubElement : Elements)
|
|
{
|
|
if (!Object.IsValid(SubElement)) { return false; }
|
|
}
|
|
return true;
|
|
}
|
|
|
|
static inline bool CheckExclusiveMaskValid(const TArrayView<const bool> Mask)
|
|
{
|
|
for (int32 MaskIdx = 0; MaskIdx < Mask.Num(); MaskIdx++)
|
|
{
|
|
if (!Mask[MaskIdx]) { return true; }
|
|
}
|
|
return false;
|
|
}
|
|
|
|
static inline float Logit(const float X)
|
|
{
|
|
return FMath::Loge(FMath::Max(X / FMath::Max(1.0f - X, FLT_MIN), FLT_MIN));
|
|
}
|
|
}
|
|
|
|
// Appends a Null element. It has no type-specific data and records size 0 for
// the encoded, action and distribution vectors, and size 1 for the modifier vector.
FSchemaElement FSchema::CreateNull(const FName Tag)
{
	const int32 Index = Types.Add(EType::Null);
	Tags.Add(Tag);

	// No per-type payload for Null.
	TypeDataIndices.Add(INDEX_NONE);

	EncodedVectorSizes.Add(0);
	ActionVectorSizes.Add(0);
	ActionDistributionVectorSizes.Add(0);
	ActionModifierVectorSizes.Add(1);

	return { Index, Generation };
}
|
|
|
|
// Appends a Continuous element of Parameters.Num values with the given Scale.
// Recorded sizes: encoded = 2*Num, action = Num, distribution = 2*Num,
// modifier = 1 + 2*Num.
FSchemaElement FSchema::CreateContinuous(const FSchemaContinuousParameters Parameters, const FName Tag)
{
	check(Parameters.Num >= 0);
	check(Parameters.Scale >= 0.0f);

	FContinuousData Payload;
	Payload.Num = Parameters.Num;
	Payload.Scale = Parameters.Scale;
	const int32 PayloadIndex = ContinuousData.Add(Payload);

	const int32 Index = Types.Add(EType::Continuous);
	Tags.Add(Tag);
	EncodedVectorSizes.Add(2 * Parameters.Num);
	ActionVectorSizes.Add(Parameters.Num);
	ActionDistributionVectorSizes.Add(2 * Parameters.Num);
	ActionModifierVectorSizes.Add(1 + 2 * Parameters.Num);
	TypeDataIndices.Add(PayloadIndex);

	return { Index, Generation };
}
|
|
|
|
// Appends a DiscreteExclusive element with Parameters.Num options. The priors are
// copied into the shared PriorProbabilities pool, and the element stores its offset there.
FSchemaElement FSchema::CreateDiscreteExclusive(const FSchemaDiscreteExclusiveParameters Parameters, const FName Tag)
{
	check(Parameters.PriorProbabilities.Num() == Parameters.Num);
	check(Private::CheckPriorProbabilitiesExclusive(Parameters.PriorProbabilities));

	// The offset must be captured before the priors are appended.
	FDiscreteExclusiveData Payload;
	Payload.Num = Parameters.Num;
	Payload.PriorProbabilitiesOffset = PriorProbabilities.Num();
	PriorProbabilities.Append(Parameters.PriorProbabilities);
	const int32 PayloadIndex = DiscreteExclusiveData.Add(Payload);

	const int32 Index = Types.Add(EType::DiscreteExclusive);
	Tags.Add(Tag);
	EncodedVectorSizes.Add(Parameters.Num);
	ActionVectorSizes.Add(Parameters.Num);
	ActionDistributionVectorSizes.Add(Parameters.Num);
	ActionModifierVectorSizes.Add(1 + Parameters.Num);
	TypeDataIndices.Add(PayloadIndex);

	return { Index, Generation };
}
|
|
|
|
// Appends a DiscreteInclusive element with Parameters.Num options. The priors are
// copied into the shared PriorProbabilities pool, and the element stores its offset there.
FSchemaElement FSchema::CreateDiscreteInclusive(const FSchemaDiscreteInclusiveParameters Parameters, const FName Tag)
{
	check(Parameters.PriorProbabilities.Num() == Parameters.Num);
	check(Private::CheckPriorProbabilitiesInclusive(Parameters.PriorProbabilities));

	// The offset must be captured before the priors are appended.
	FDiscreteInclusiveData Payload;
	Payload.Num = Parameters.Num;
	Payload.PriorProbabilitiesOffset = PriorProbabilities.Num();
	PriorProbabilities.Append(Parameters.PriorProbabilities);
	const int32 PayloadIndex = DiscreteInclusiveData.Add(Payload);

	const int32 Index = Types.Add(EType::DiscreteInclusive);
	Tags.Add(Tag);
	EncodedVectorSizes.Add(Parameters.Num);
	ActionVectorSizes.Add(Parameters.Num);
	ActionDistributionVectorSizes.Add(Parameters.Num);
	ActionModifierVectorSizes.Add(1 + Parameters.Num);
	TypeDataIndices.Add(PayloadIndex);

	return { Index, Generation };
}
|
|
|
|
// Appends a NamedDiscreteExclusive element with one option per name in
// Parameters.ElementNames. The names go into SubElementNames, and a default
// FSchemaElement is stored per name so names and sub-objects stay index-aligned.
FSchemaElement FSchema::CreateNamedDiscreteExclusive(const FSchemaNamedDiscreteExclusiveParameters Parameters, const FName Tag)
{
	const int32 OptionNum = Parameters.ElementNames.Num();

	check(Parameters.PriorProbabilities.Num() == OptionNum);
	check(Private::CheckPriorProbabilitiesExclusive(Parameters.PriorProbabilities));
	check(!Private::ContainsDuplicates(Parameters.ElementNames));

	FNamedDiscreteExclusiveData Payload;
	Payload.Num = OptionNum;
	Payload.PriorProbabilitiesOffset = PriorProbabilities.Num();
	Payload.ElementsOffset = SubElementNames.Num();

	PriorProbabilities.Append(Parameters.PriorProbabilities);
	SubElementNames.Append(Parameters.ElementNames);
	for (int32 Remaining = Payload.Num; Remaining > 0; --Remaining)
	{
		SubElementObjects.Add(FSchemaElement());
	}

	const int32 PayloadIndex = NamedDiscreteExclusiveData.Add(Payload);

	const int32 Index = Types.Add(EType::NamedDiscreteExclusive);
	Tags.Add(Tag);
	EncodedVectorSizes.Add(OptionNum);
	ActionVectorSizes.Add(OptionNum);
	ActionDistributionVectorSizes.Add(OptionNum);
	ActionModifierVectorSizes.Add(1 + OptionNum);
	TypeDataIndices.Add(PayloadIndex);

	return { Index, Generation };
}
|
|
|
|
// Appends a NamedDiscreteInclusive element with one option per name in
// Parameters.ElementNames. The names go into SubElementNames, and a default
// FSchemaElement is stored per name so names and sub-objects stay index-aligned.
FSchemaElement FSchema::CreateNamedDiscreteInclusive(const FSchemaNamedDiscreteInclusiveParameters Parameters, const FName Tag)
{
	const int32 OptionNum = Parameters.ElementNames.Num();

	check(Parameters.PriorProbabilities.Num() == OptionNum);
	check(Private::CheckPriorProbabilitiesInclusive(Parameters.PriorProbabilities));
	check(!Private::ContainsDuplicates(Parameters.ElementNames));

	FNamedDiscreteInclusiveData Payload;
	Payload.Num = OptionNum;
	Payload.PriorProbabilitiesOffset = PriorProbabilities.Num();
	Payload.ElementsOffset = SubElementNames.Num();

	PriorProbabilities.Append(Parameters.PriorProbabilities);
	SubElementNames.Append(Parameters.ElementNames);
	for (int32 Remaining = Payload.Num; Remaining > 0; --Remaining)
	{
		SubElementObjects.Add(FSchemaElement());
	}

	const int32 PayloadIndex = NamedDiscreteInclusiveData.Add(Payload);

	const int32 Index = Types.Add(EType::NamedDiscreteInclusive);
	Tags.Add(Tag);
	EncodedVectorSizes.Add(OptionNum);
	ActionVectorSizes.Add(OptionNum);
	ActionDistributionVectorSizes.Add(OptionNum);
	ActionModifierVectorSizes.Add(1 + OptionNum);
	TypeDataIndices.Add(PayloadIndex);

	return { Index, Generation };
}
|
|
|
|
// Appends an And element over Parameters.Elements. Each recorded vector size is
// the sum of the children's sizes; the modifier size adds one extra leading entry.
// Child names and handles are stored contiguously starting at ElementsOffset.
FSchemaElement FSchema::CreateAnd(const FSchemaAndParameters Parameters, const FName Tag)
{
	const int32 ChildNum = Parameters.Elements.Num();

	check(ChildNum == Parameters.ElementNames.Num());
	check(!Private::ContainsDuplicates(Parameters.ElementNames));
	check(Private::CheckAllValid(*this, Parameters.Elements));

	FAndData Payload;
	Payload.Num = ChildNum;
	Payload.ElementsOffset = SubElementObjects.Num();

	SubElementNames.Append(Parameters.ElementNames);
	SubElementObjects.Append(Parameters.Elements);

	const int32 PayloadIndex = AndData.Add(Payload);

	const int32 Index = Types.Add(EType::And);
	Tags.Add(Tag);
	EncodedVectorSizes.Add(Private::GetTotalEncodedActionVectorSize(*this, Parameters.Elements));
	ActionVectorSizes.Add(Private::GetTotalActionVectorSize(*this, Parameters.Elements));
	ActionDistributionVectorSizes.Add(Private::GetTotalActionDistributionVectorSize(*this, Parameters.Elements));
	ActionModifierVectorSizes.Add(1 + Private::GetTotalActionModifierVectorSize(*this, Parameters.Elements));
	TypeDataIndices.Add(PayloadIndex);

	return { Index, Generation };
}
|
|
|
|
// Appends an OrExclusive element over Parameters.Elements, each with a prior.
// Recorded sizes: encoded and distribution are the children's totals plus one entry
// per child. The action size is the largest child's action size plus one entry per
// child. The modifier size is 1 + one entry per child + the children's total.
FSchemaElement FSchema::CreateOrExclusive(const FSchemaOrExclusiveParameters Parameters, const FName Tag)
{
	const int32 ChildNum = Parameters.Elements.Num();

	check(ChildNum == Parameters.ElementNames.Num());
	check(!Private::ContainsDuplicates(Parameters.ElementNames));
	check(Private::CheckAllValid(*this, Parameters.Elements));
	check(Parameters.PriorProbabilities.Num() == ChildNum);
	check(Private::CheckPriorProbabilitiesExclusive(Parameters.PriorProbabilities));

	FOrExclusiveData Payload;
	Payload.Num = ChildNum;
	Payload.ElementsOffset = SubElementObjects.Num();
	Payload.PriorProbabilitiesOffset = PriorProbabilities.Num();

	SubElementNames.Append(Parameters.ElementNames);
	SubElementObjects.Append(Parameters.Elements);
	PriorProbabilities.Append(Parameters.PriorProbabilities);

	const int32 PayloadIndex = OrExclusiveData.Add(Payload);

	const int32 EncodedSize = Private::GetTotalEncodedActionVectorSize(*this, Parameters.Elements) + ChildNum;
	const int32 ActionSize = Private::GetMaxActionVectorSize(*this, Parameters.Elements) + ChildNum;
	const int32 DistributionSize = Private::GetTotalActionDistributionVectorSize(*this, Parameters.Elements) + ChildNum;
	const int32 ModifierSize = 1 + ChildNum + Private::GetTotalActionModifierVectorSize(*this, Parameters.Elements);

	const int32 Index = Types.Add(EType::OrExclusive);
	Tags.Add(Tag);
	EncodedVectorSizes.Add(EncodedSize);
	ActionVectorSizes.Add(ActionSize);
	ActionDistributionVectorSizes.Add(DistributionSize);
	ActionModifierVectorSizes.Add(ModifierSize);
	TypeDataIndices.Add(PayloadIndex);

	return { Index, Generation };
}
|
|
|
|
// Appends an OrInclusive element over Parameters.Elements, each with a prior.
// Recorded sizes: encoded, action and distribution are the children's totals plus
// one entry per child. The modifier size is 1 + one entry per child + the children's total.
FSchemaElement FSchema::CreateOrInclusive(const FSchemaOrInclusiveParameters Parameters, const FName Tag)
{
	const int32 ChildNum = Parameters.Elements.Num();

	check(ChildNum == Parameters.ElementNames.Num());
	check(!Private::ContainsDuplicates(Parameters.ElementNames));
	check(Private::CheckAllValid(*this, Parameters.Elements));
	check(Parameters.PriorProbabilities.Num() == ChildNum);
	check(Private::CheckPriorProbabilitiesInclusive(Parameters.PriorProbabilities));

	FOrInclusiveData Payload;
	Payload.Num = ChildNum;
	Payload.ElementsOffset = SubElementObjects.Num();
	Payload.PriorProbabilitiesOffset = PriorProbabilities.Num();

	SubElementNames.Append(Parameters.ElementNames);
	SubElementObjects.Append(Parameters.Elements);
	PriorProbabilities.Append(Parameters.PriorProbabilities);

	const int32 PayloadIndex = OrInclusiveData.Add(Payload);

	const int32 EncodedSize = Private::GetTotalEncodedActionVectorSize(*this, Parameters.Elements) + ChildNum;
	const int32 ActionSize = Private::GetTotalActionVectorSize(*this, Parameters.Elements) + ChildNum;
	const int32 DistributionSize = Private::GetTotalActionDistributionVectorSize(*this, Parameters.Elements) + ChildNum;
	const int32 ModifierSize = 1 + ChildNum + Private::GetTotalActionModifierVectorSize(*this, Parameters.Elements);

	const int32 Index = Types.Add(EType::OrInclusive);
	Tags.Add(Tag);
	EncodedVectorSizes.Add(EncodedSize);
	ActionVectorSizes.Add(ActionSize);
	ActionDistributionVectorSizes.Add(DistributionSize);
	ActionModifierVectorSizes.Add(ModifierSize);
	TypeDataIndices.Add(PayloadIndex);

	return { Index, Generation };
}
|
|
|
|
// Appends an Array element: Parameters.Num repetitions of Parameters.Element.
// The child is stored once, under the name NAME_None. Each recorded size is the
// child's size times Num; the modifier size adds one extra leading entry.
FSchemaElement FSchema::CreateArray(const FSchemaArrayParameters Parameters, const FName Tag)
{
	check(IsValid(Parameters.Element));
	check(Parameters.Num >= 0);

	FArrayData Payload;
	Payload.Num = Parameters.Num;
	Payload.ElementIndex = SubElementObjects.Num();

	SubElementNames.Add(NAME_None);
	SubElementObjects.Add(Parameters.Element);

	const int32 PayloadIndex = ArrayData.Add(Payload);

	const int32 Index = Types.Add(EType::Array);
	Tags.Add(Tag);
	EncodedVectorSizes.Add(GetEncodedVectorSize(Parameters.Element) * Parameters.Num);
	ActionVectorSizes.Add(GetActionVectorSize(Parameters.Element) * Parameters.Num);
	ActionDistributionVectorSizes.Add(GetActionDistributionVectorSize(Parameters.Element) * Parameters.Num);
	ActionModifierVectorSizes.Add(1 + GetActionModifierVectorSize(Parameters.Element) * Parameters.Num);
	TypeDataIndices.Add(PayloadIndex);

	return { Index, Generation };
}
|
|
|
|
// Appends an Encoding element wrapping Parameters.Element. The encoded size is
// Parameters.EncodingSize. Action and distribution sizes are copied from the child,
// and the modifier size is the child's plus one extra leading entry.
FSchemaElement FSchema::CreateEncoding(const FSchemaEncodingParameters Parameters, const FName Tag)
{
	check(IsValid(Parameters.Element));

	FEncodingData Payload;
	Payload.EncodingSize = Parameters.EncodingSize;
	Payload.LayerNum = Parameters.LayerNum;
	Payload.ActivationFunction = Parameters.ActivationFunction;
	Payload.ElementIndex = SubElementObjects.Num();

	SubElementNames.Add(NAME_None);
	SubElementObjects.Add(Parameters.Element);

	const int32 PayloadIndex = EncodingData.Add(Payload);

	const int32 Index = Types.Add(EType::Encoding);
	Tags.Add(Tag);
	EncodedVectorSizes.Add(Payload.EncodingSize);
	ActionVectorSizes.Add(GetActionVectorSize(Parameters.Element));
	ActionDistributionVectorSizes.Add(GetActionDistributionVectorSize(Parameters.Element));
	ActionModifierVectorSizes.Add(1 + GetActionModifierVectorSize(Parameters.Element));
	TypeDataIndices.Add(PayloadIndex);

	return { Index, Generation };
}
|
|
|
|
bool FSchema::IsValid(const FSchemaElement Element) const
|
|
{
|
|
return Element.Generation == Generation && Element.Index != INDEX_NONE;
|
|
}
|
|
|
|
// Type the element was created with.
EType FSchema::GetType(const FSchemaElement Element) const
{
	check(IsValid(Element));
	const int32 Slot = Element.Index;
	return Types[Slot];
}
|
|
|
|
// Tag passed when the element was created.
FName FSchema::GetTag(const FSchemaElement Element) const
{
	check(IsValid(Element));
	const int32 Slot = Element.Index;
	return Tags[Slot];
}
|
|
|
|
// Encoded vector size recorded for the element at creation time.
int32 FSchema::GetEncodedVectorSize(const FSchemaElement Element) const
{
	check(IsValid(Element));
	const int32 Slot = Element.Index;
	return EncodedVectorSizes[Slot];
}
|
|
|
|
// Action vector size recorded for the element at creation time.
int32 FSchema::GetActionVectorSize(const FSchemaElement Element) const
{
	check(IsValid(Element));
	const int32 Slot = Element.Index;
	return ActionVectorSizes[Slot];
}
|
|
|
|
// Action distribution vector size recorded for the element at creation time.
int32 FSchema::GetActionDistributionVectorSize(const FSchemaElement Element) const
{
	check(IsValid(Element));
	const int32 Slot = Element.Index;
	return ActionDistributionVectorSizes[Slot];
}
|
|
|
|
// Action modifier vector size recorded for the element at creation time.
int32 FSchema::GetActionModifierVectorSize(const FSchemaElement Element) const
{
	check(IsValid(Element));
	const int32 Slot = Element.Index;
	return ActionModifierVectorSizes[Slot];
}
|
|
|
|
// Rebuilds the parameters of a Continuous element from its stored payload.
FSchemaContinuousParameters FSchema::GetContinuous(const FSchemaElement Element) const
{
	check(IsValid(Element) && GetType(Element) == EType::Continuous);

	const int32 PayloadIndex = TypeDataIndices[Element.Index];
	const FContinuousData& Stored = ContinuousData[PayloadIndex];

	FSchemaContinuousParameters Result;
	Result.Scale = Stored.Scale;
	Result.Num = Stored.Num;
	return Result;
}
|
|
|
|
// Rebuilds the parameters of a DiscreteExclusive element. The returned priors are
// a view into this schema's PriorProbabilities pool, not a copy.
FSchemaDiscreteExclusiveParameters FSchema::GetDiscreteExclusive(const FSchemaElement Element) const
{
	check(IsValid(Element) && GetType(Element) == EType::DiscreteExclusive);

	const int32 PayloadIndex = TypeDataIndices[Element.Index];
	const FDiscreteExclusiveData& Stored = DiscreteExclusiveData[PayloadIndex];
	const float* const PriorsBegin = PriorProbabilities.GetData() + Stored.PriorProbabilitiesOffset;

	FSchemaDiscreteExclusiveParameters Result;
	Result.Num = Stored.Num;
	Result.PriorProbabilities = TArrayView<const float>(PriorsBegin, Stored.Num);
	return Result;
}
|
|
|
|
// Rebuilds the parameters of a DiscreteInclusive element. The returned priors are
// a view into this schema's PriorProbabilities pool, not a copy.
FSchemaDiscreteInclusiveParameters FSchema::GetDiscreteInclusive(const FSchemaElement Element) const
{
	check(IsValid(Element) && GetType(Element) == EType::DiscreteInclusive);

	const int32 PayloadIndex = TypeDataIndices[Element.Index];
	const FDiscreteInclusiveData& Stored = DiscreteInclusiveData[PayloadIndex];
	const float* const PriorsBegin = PriorProbabilities.GetData() + Stored.PriorProbabilitiesOffset;

	FSchemaDiscreteInclusiveParameters Result;
	Result.Num = Stored.Num;
	Result.PriorProbabilities = TArrayView<const float>(PriorsBegin, Stored.Num);
	return Result;
}
|
|
|
|
// Rebuilds the parameters of a NamedDiscreteExclusive element. The names and
// priors are views into this schema's storage, not copies.
FSchemaNamedDiscreteExclusiveParameters FSchema::GetNamedDiscreteExclusive(const FSchemaElement Element) const
{
	check(IsValid(Element) && GetType(Element) == EType::NamedDiscreteExclusive);

	const int32 PayloadIndex = TypeDataIndices[Element.Index];
	const FNamedDiscreteExclusiveData& Stored = NamedDiscreteExclusiveData[PayloadIndex];
	const FName* const NamesBegin = SubElementNames.GetData() + Stored.ElementsOffset;
	const float* const PriorsBegin = PriorProbabilities.GetData() + Stored.PriorProbabilitiesOffset;

	FSchemaNamedDiscreteExclusiveParameters Result;
	Result.ElementNames = TArrayView<const FName>(NamesBegin, Stored.Num);
	Result.PriorProbabilities = TArrayView<const float>(PriorsBegin, Stored.Num);
	return Result;
}
|
|
|
|
// Rebuilds the parameters of a NamedDiscreteInclusive element. The names and
// priors are views into this schema's storage, not copies.
FSchemaNamedDiscreteInclusiveParameters FSchema::GetNamedDiscreteInclusive(const FSchemaElement Element) const
{
	check(IsValid(Element) && GetType(Element) == EType::NamedDiscreteInclusive);

	const int32 PayloadIndex = TypeDataIndices[Element.Index];
	const FNamedDiscreteInclusiveData& Stored = NamedDiscreteInclusiveData[PayloadIndex];
	const FName* const NamesBegin = SubElementNames.GetData() + Stored.ElementsOffset;
	const float* const PriorsBegin = PriorProbabilities.GetData() + Stored.PriorProbabilitiesOffset;

	FSchemaNamedDiscreteInclusiveParameters Result;
	Result.ElementNames = TArrayView<const FName>(NamesBegin, Stored.Num);
	Result.PriorProbabilities = TArrayView<const float>(PriorsBegin, Stored.Num);
	return Result;
}
|
|
|
|
// Rebuilds the parameters of an And element. Names and child handles share the same
// offset in SubElementNames and SubElementObjects and are returned as views.
FSchemaAndParameters FSchema::GetAnd(const FSchemaElement Element) const
{
	check(IsValid(Element) && GetType(Element) == EType::And);

	const int32 PayloadIndex = TypeDataIndices[Element.Index];
	const FAndData& Stored = AndData[PayloadIndex];
	const FName* const NamesBegin = SubElementNames.GetData() + Stored.ElementsOffset;
	const FSchemaElement* const ChildrenBegin = SubElementObjects.GetData() + Stored.ElementsOffset;

	FSchemaAndParameters Result;
	Result.ElementNames = TArrayView<const FName>(NamesBegin, Stored.Num);
	Result.Elements = TArrayView<const FSchemaElement>(ChildrenBegin, Stored.Num);
	return Result;
}
|
|
|
|
// Rebuilds the parameters of an OrExclusive element. Names, child handles and
// priors are returned as views into this schema's storage.
FSchemaOrExclusiveParameters FSchema::GetOrExclusive(const FSchemaElement Element) const
{
	check(IsValid(Element) && GetType(Element) == EType::OrExclusive);

	const int32 PayloadIndex = TypeDataIndices[Element.Index];
	const FOrExclusiveData& Stored = OrExclusiveData[PayloadIndex];
	const FName* const NamesBegin = SubElementNames.GetData() + Stored.ElementsOffset;
	const FSchemaElement* const ChildrenBegin = SubElementObjects.GetData() + Stored.ElementsOffset;
	const float* const PriorsBegin = PriorProbabilities.GetData() + Stored.PriorProbabilitiesOffset;

	FSchemaOrExclusiveParameters Result;
	Result.ElementNames = TArrayView<const FName>(NamesBegin, Stored.Num);
	Result.Elements = TArrayView<const FSchemaElement>(ChildrenBegin, Stored.Num);
	Result.PriorProbabilities = TArrayView<const float>(PriorsBegin, Stored.Num);
	return Result;
}
|
|
|
|
// Rebuilds the parameters of an OrInclusive element. Names, child handles and
// priors are returned as views into this schema's storage.
FSchemaOrInclusiveParameters FSchema::GetOrInclusive(const FSchemaElement Element) const
{
	check(IsValid(Element) && GetType(Element) == EType::OrInclusive);

	const int32 PayloadIndex = TypeDataIndices[Element.Index];
	const FOrInclusiveData& Stored = OrInclusiveData[PayloadIndex];
	const FName* const NamesBegin = SubElementNames.GetData() + Stored.ElementsOffset;
	const FSchemaElement* const ChildrenBegin = SubElementObjects.GetData() + Stored.ElementsOffset;
	const float* const PriorsBegin = PriorProbabilities.GetData() + Stored.PriorProbabilitiesOffset;

	FSchemaOrInclusiveParameters Result;
	Result.ElementNames = TArrayView<const FName>(NamesBegin, Stored.Num);
	Result.Elements = TArrayView<const FSchemaElement>(ChildrenBegin, Stored.Num);
	Result.PriorProbabilities = TArrayView<const float>(PriorsBegin, Stored.Num);
	return Result;
}
|
|
|
|
// Rebuilds the parameters of an Array element: the repeat count and the repeated child.
FSchemaArrayParameters FSchema::GetArray(const FSchemaElement Element) const
{
	check(IsValid(Element) && GetType(Element) == EType::Array);

	const int32 PayloadIndex = TypeDataIndices[Element.Index];
	const FArrayData& Stored = ArrayData[PayloadIndex];

	FSchemaArrayParameters Result;
	Result.Element = SubElementObjects[Stored.ElementIndex];
	Result.Num = Stored.Num;
	return Result;
}
|
|
|
|
// Rebuilds the parameters of an Encoding element: the wrapped child plus the
// stored encoding size, layer count and activation function.
FSchemaEncodingParameters FSchema::GetEncoding(const FSchemaElement Element) const
{
	check(IsValid(Element) && GetType(Element) == EType::Encoding);

	const int32 PayloadIndex = TypeDataIndices[Element.Index];
	const FEncodingData& Stored = EncodingData[PayloadIndex];

	FSchemaEncodingParameters Result;
	Result.EncodingSize = Stored.EncodingSize;
	Result.LayerNum = Stored.LayerNum;
	Result.ActivationFunction = Stored.ActivationFunction;
	Result.Element = SubElementObjects[Stored.ElementIndex];
	return Result;
}
|
|
|
|
// Current generation counter. Handles whose Generation differs fail IsValid().
uint32 FSchema::GetGeneration() const
{
	const uint32 CurrentGeneration = Generation;
	return CurrentGeneration;
}
|
|
|
|
void FSchema::Empty()
|
|
{
|
|
Types.Empty();
|
|
Tags.Empty();
|
|
EncodedVectorSizes.Empty();
|
|
ActionVectorSizes.Empty();
|
|
ActionDistributionVectorSizes.Empty();
|
|
TypeDataIndices.Empty();
|
|
|
|
ContinuousData.Empty();
|
|
DiscreteExclusiveData.Empty();
|
|
DiscreteInclusiveData.Empty();
|
|
NamedDiscreteExclusiveData.Empty();
|
|
NamedDiscreteInclusiveData.Empty();
|
|
AndData.Empty();
|
|
OrExclusiveData.Empty();
|
|
OrInclusiveData.Empty();
|
|
ArrayData.Empty();
|
|
EncodingData.Empty();
|
|
|
|
SubElementNames.Empty();
|
|
SubElementObjects.Empty();
|
|
PriorProbabilities.Empty();
|
|
|
|
Generation++;
|
|
}
|
|
|
|
bool FSchema::IsEmpty() const
|
|
{
|
|
return Types.IsEmpty();
|
|
}
|
|
|
|
void FSchema::Reset()
|
|
{
|
|
Types.Reset();
|
|
Tags.Reset();
|
|
EncodedVectorSizes.Reset();
|
|
ActionVectorSizes.Reset();
|
|
ActionDistributionVectorSizes.Reset();
|
|
TypeDataIndices.Reset();
|
|
|
|
ContinuousData.Reset();
|
|
DiscreteExclusiveData.Reset();
|
|
DiscreteInclusiveData.Reset();
|
|
NamedDiscreteExclusiveData.Reset();
|
|
NamedDiscreteInclusiveData.Reset();
|
|
AndData.Reset();
|
|
OrExclusiveData.Reset();
|
|
OrInclusiveData.Reset();
|
|
ArrayData.Reset();
|
|
EncodingData.Reset();
|
|
|
|
SubElementNames.Reset();
|
|
SubElementObjects.Reset();
|
|
PriorProbabilities.Reset();
|
|
|
|
Generation++;
|
|
}
|
|
|
|
// Appends a Null object. It owns no data, so each range is empty and positioned at
// the current end of its pool (continuous values, discrete values, sub-elements).
FObjectElement FObject::CreateNull(const FName Tag)
{
	const int32 Index = Types.Add(EType::Null);
	Tags.Add(Tag);

	ContinuousDataOffsets.Add(ContinuousValues.Num());
	DiscreteDataOffsets.Add(DiscreteValues.Num());
	ElementDataOffsets.Add(SubElementObjects.Num());

	ContinuousDataNums.Add(0);
	DiscreteDataNums.Add(0);
	ElementDataNums.Add(0);

	return { Index, Generation };
}
|
|
|
|
// Appends a Continuous object. Parameters.Values is copied into the shared
// ContinuousValues pool, and the object records its offset and count there.
FObjectElement FObject::CreateContinuous(const FObjectContinuousParameters Parameters, const FName Tag)
{
	// The start of this object's run must be captured before appending.
	const int32 ValuesOffset = ContinuousValues.Num();
	ContinuousValues.Append(Parameters.Values);

	const int32 Index = Types.Add(EType::Continuous);
	Tags.Add(Tag);

	ContinuousDataOffsets.Add(ValuesOffset);
	ContinuousDataNums.Add(Parameters.Values.Num());

	DiscreteDataOffsets.Add(DiscreteValues.Num());
	DiscreteDataNums.Add(0);

	ElementDataOffsets.Add(SubElementObjects.Num());
	ElementDataNums.Add(0);

	return { Index, Generation };
}
|
|
|
|
// Appends a DiscreteExclusive object holding a single chosen index, stored as a
// one-entry run in the DiscreteValues pool.
FObjectElement FObject::CreateDiscreteExclusive(const FObjectDiscreteExclusiveParameters Parameters, const FName Tag)
{
	// The position must be captured before the index is appended.
	const int32 IndexOffset = DiscreteValues.Num();
	DiscreteValues.Add(Parameters.DiscreteIndex);

	const int32 Index = Types.Add(EType::DiscreteExclusive);
	Tags.Add(Tag);

	ContinuousDataOffsets.Add(ContinuousValues.Num());
	ContinuousDataNums.Add(0);

	DiscreteDataOffsets.Add(IndexOffset);
	DiscreteDataNums.Add(1);

	ElementDataOffsets.Add(SubElementObjects.Num());
	ElementDataNums.Add(0);

	return { Index, Generation };
}
|
|
|
|
// Appends a DiscreteInclusive object. Parameters.DiscreteIndices is copied into
// the DiscreteValues pool, and the object records its offset and count there.
FObjectElement FObject::CreateDiscreteInclusive(const FObjectDiscreteInclusiveParameters Parameters, const FName Tag)
{
	// The start of this object's run must be captured before appending.
	const int32 IndicesOffset = DiscreteValues.Num();
	DiscreteValues.Append(Parameters.DiscreteIndices);

	const int32 Index = Types.Add(EType::DiscreteInclusive);
	Tags.Add(Tag);

	ContinuousDataOffsets.Add(ContinuousValues.Num());
	ContinuousDataNums.Add(0);

	DiscreteDataOffsets.Add(IndicesOffset);
	DiscreteDataNums.Add(Parameters.DiscreteIndices.Num());

	ElementDataOffsets.Add(SubElementObjects.Num());
	ElementDataNums.Add(0);

	return { Index, Generation };
}
|
|
|
|
// Appends a NamedDiscreteExclusive object whose value is a single chosen name.
// The name is stored in SubElementNames next to a default FObjectElement placeholder,
// so names and sub-objects stay index-aligned.
FObjectElement FObject::CreateNamedDiscreteExclusive(const FObjectNamedDiscreteExclusiveParameters Parameters, const FName Tag)
{
	const int32 ChildOffset = SubElementObjects.Num();
	SubElementObjects.Add(FObjectElement());
	SubElementNames.Add(Parameters.ElementName);

	const int32 Index = Types.Add(EType::NamedDiscreteExclusive);
	Tags.Add(Tag);

	ContinuousDataOffsets.Add(ContinuousValues.Num());
	ContinuousDataNums.Add(0);

	DiscreteDataOffsets.Add(DiscreteValues.Num());
	DiscreteDataNums.Add(0);

	ElementDataOffsets.Add(ChildOffset);
	ElementDataNums.Add(1);

	return { Index, Generation };
}
|
|
|
|
// Appends a NamedDiscreteInclusive object whose value is a set of chosen names.
// Each name is stored in SubElementNames next to a default FObjectElement
// placeholder, so names and sub-objects stay index-aligned.
FObjectElement FObject::CreateNamedDiscreteInclusive(const FObjectNamedDiscreteInclusiveParameters Parameters, const FName Tag)
{
	const int32 ChosenNum = Parameters.ElementNames.Num();
	const int32 ChildOffset = SubElementObjects.Num();

	for (int32 Remaining = ChosenNum; Remaining > 0; --Remaining)
	{
		SubElementObjects.Add(FObjectElement());
	}
	SubElementNames.Append(Parameters.ElementNames);

	const int32 Index = Types.Add(EType::NamedDiscreteInclusive);
	Tags.Add(Tag);

	ContinuousDataOffsets.Add(ContinuousValues.Num());
	ContinuousDataNums.Add(0);

	DiscreteDataOffsets.Add(DiscreteValues.Num());
	DiscreteDataNums.Add(0);

	ElementDataOffsets.Add(ChildOffset);
	ElementDataNums.Add(ChosenNum);

	return { Index, Generation };
}
|
|
|
|
// Appends an And object whose children and their names are stored contiguously,
// starting at the recorded element offset.
FObjectElement FObject::CreateAnd(const FObjectAndParameters Parameters, const FName Tag)
{
	const int32 ChildNum = Parameters.Elements.Num();

	check(ChildNum == Parameters.ElementNames.Num());
	check(!Private::ContainsDuplicates(Parameters.ElementNames));
	check(Private::CheckAllValid(*this, Parameters.Elements));

	const int32 ChildOffset = SubElementObjects.Num();
	SubElementObjects.Append(Parameters.Elements);
	SubElementNames.Append(Parameters.ElementNames);

	const int32 Index = Types.Add(EType::And);
	Tags.Add(Tag);

	ContinuousDataOffsets.Add(ContinuousValues.Num());
	ContinuousDataNums.Add(0);

	DiscreteDataOffsets.Add(DiscreteValues.Num());
	DiscreteDataNums.Add(0);

	ElementDataOffsets.Add(ChildOffset);
	ElementDataNums.Add(ChildNum);

	return { Index, Generation };
}
|
|
|
|
// Appends an OrExclusive object holding exactly one child, Parameters.Element,
// under Parameters.ElementName.
FObjectElement FObject::CreateOrExclusive(const FObjectOrExclusiveParameters Parameters, const FName Tag)
{
	check(IsValid(Parameters.Element));

	const int32 ChildOffset = SubElementObjects.Num();
	SubElementObjects.Add(Parameters.Element);
	SubElementNames.Add(Parameters.ElementName);

	const int32 Index = Types.Add(EType::OrExclusive);
	Tags.Add(Tag);

	ContinuousDataOffsets.Add(ContinuousValues.Num());
	ContinuousDataNums.Add(0);

	DiscreteDataOffsets.Add(DiscreteValues.Num());
	DiscreteDataNums.Add(0);

	ElementDataOffsets.Add(ChildOffset);
	ElementDataNums.Add(1);

	return { Index, Generation };
}
|
|
|
|
// Appends an OrInclusive object whose children and their names are stored
// contiguously, starting at the recorded element offset.
FObjectElement FObject::CreateOrInclusive(const FObjectOrInclusiveParameters Parameters, const FName Tag)
{
	const int32 ChildNum = Parameters.Elements.Num();

	check(ChildNum == Parameters.ElementNames.Num());
	check(!Private::ContainsDuplicates(Parameters.ElementNames));
	check(Private::CheckAllValid(*this, Parameters.Elements));

	const int32 ChildOffset = SubElementObjects.Num();
	SubElementObjects.Append(Parameters.Elements);
	SubElementNames.Append(Parameters.ElementNames);

	const int32 Index = Types.Add(EType::OrInclusive);
	Tags.Add(Tag);

	ContinuousDataOffsets.Add(ContinuousValues.Num());
	ContinuousDataNums.Add(0);

	DiscreteDataOffsets.Add(DiscreteValues.Num());
	DiscreteDataNums.Add(0);

	ElementDataOffsets.Add(ChildOffset);
	ElementDataNums.Add(ChildNum);

	return { Index, Generation };
}
|
|
|
|
// Appends an Array object whose items are Parameters.Elements. Items are
// anonymous: each gets NAME_None in SubElementNames, so names and sub-objects
// stay index-aligned.
FObjectElement FObject::CreateArray(const FObjectArrayParameters Parameters, const FName Tag)
{
	check(Private::CheckAllValid(*this, Parameters.Elements));

	const int32 ItemNum = Parameters.Elements.Num();
	const int32 ChildOffset = SubElementObjects.Num();

	for (int32 Remaining = ItemNum; Remaining > 0; --Remaining)
	{
		SubElementNames.Add(NAME_None);
	}
	SubElementObjects.Append(Parameters.Elements);

	const int32 Index = Types.Add(EType::Array);
	Tags.Add(Tag);

	ContinuousDataOffsets.Add(ContinuousValues.Num());
	ContinuousDataNums.Add(0);

	DiscreteDataOffsets.Add(DiscreteValues.Num());
	DiscreteDataNums.Add(0);

	ElementDataOffsets.Add(ChildOffset);
	ElementDataNums.Add(ItemNum);

	return { Index, Generation };
}
|
|
|
|
// Appends an Encoding object wrapping a single anonymous child, Parameters.Element.
FObjectElement FObject::CreateEncoding(const FObjectEncodingParameters Parameters, const FName Tag)
{
	check(IsValid(Parameters.Element));

	const int32 ChildOffset = SubElementObjects.Num();
	SubElementNames.Add(NAME_None);
	SubElementObjects.Add(Parameters.Element);

	const int32 Index = Types.Add(EType::Encoding);
	Tags.Add(Tag);

	ContinuousDataOffsets.Add(ContinuousValues.Num());
	ContinuousDataNums.Add(0);

	DiscreteDataOffsets.Add(DiscreteValues.Num());
	DiscreteDataNums.Add(0);

	ElementDataOffsets.Add(ChildOffset);
	ElementDataNums.Add(1);

	return { Index, Generation };
}
|
|
|
|
bool FObject::IsValid(const FObjectElement Element) const
|
|
{
|
|
return Element.Generation == Generation && Element.Index != INDEX_NONE;
|
|
}
|
|
|
|
EType FObject::GetType(const FObjectElement Element) const
{
	// Reject handles from another generation or with INDEX_NONE before indexing
	check(IsValid(Element));

	const int32 ElementIndex = Element.Index;
	return Types[ElementIndex];
}
|
|
|
|
FName FObject::GetTag(const FObjectElement Element) const
{
	// Reject handles from another generation or with INDEX_NONE before indexing
	check(IsValid(Element));

	const int32 ElementIndex = Element.Index;
	return Tags[ElementIndex];
}
|
|
|
|
FObjectContinuousParameters FObject::GetContinuous(const FObjectElement Element) const
{
	check(IsValid(Element) && GetType(Element) == EType::Continuous);

	// Non-owning view into this element's slice of the shared ContinuousValues buffer
	const int32 FirstValue = ContinuousDataOffsets[Element.Index];
	const int32 ValueCount = ContinuousDataNums[Element.Index];

	FObjectContinuousParameters Result;
	Result.Values = TArrayView<const float>(ContinuousValues.GetData() + FirstValue, ValueCount);
	return Result;
}
|
|
|
|
FObjectDiscreteExclusiveParameters FObject::GetDiscreteExclusive(const FObjectElement Element) const
{
	check(IsValid(Element) && GetType(Element) == EType::DiscreteExclusive);

	// The selected index is the first value in this element's discrete range
	const int32 FirstDiscrete = DiscreteDataOffsets[Element.Index];

	FObjectDiscreteExclusiveParameters Result;
	Result.DiscreteIndex = DiscreteValues[FirstDiscrete];
	return Result;
}
|
|
|
|
FObjectDiscreteInclusiveParameters FObject::GetDiscreteInclusive(const FObjectElement Element) const
{
	check(IsValid(Element) && GetType(Element) == EType::DiscreteInclusive);

	// Non-owning view of every selected index stored for this element
	const int32 FirstDiscrete = DiscreteDataOffsets[Element.Index];
	const int32 DiscreteCount = DiscreteDataNums[Element.Index];

	FObjectDiscreteInclusiveParameters Result;
	Result.DiscreteIndices = TArrayView<const int32>(DiscreteValues.GetData() + FirstDiscrete, DiscreteCount);
	return Result;
}
|
|
|
|
FObjectNamedDiscreteExclusiveParameters FObject::GetNamedDiscreteExclusive(const FObjectElement Element) const
{
	check(IsValid(Element) && GetType(Element) == EType::NamedDiscreteExclusive);

	// The chosen name is the first entry of this element's sub-element name range
	const int32 FirstSubElement = ElementDataOffsets[Element.Index];

	FObjectNamedDiscreteExclusiveParameters Result;
	Result.ElementName = SubElementNames[FirstSubElement];
	return Result;
}
|
|
|
|
FObjectNamedDiscreteInclusiveParameters FObject::GetNamedDiscreteInclusive(const FObjectElement Element) const
{
	check(IsValid(Element) && GetType(Element) == EType::NamedDiscreteInclusive);

	// Non-owning view of every selected name stored for this element
	const int32 FirstSubElement = ElementDataOffsets[Element.Index];
	const int32 SubElementCount = ElementDataNums[Element.Index];

	FObjectNamedDiscreteInclusiveParameters Result;
	Result.ElementNames = TArrayView<const FName>(SubElementNames.GetData() + FirstSubElement, SubElementCount);
	return Result;
}
|
|
|
|
FObjectAndParameters FObject::GetAnd(const FObjectElement Element) const
{
	check(IsValid(Element) && GetType(Element) == EType::And);

	// Names and objects share one offset/count, so the two views are index-aligned
	const int32 FirstSubElement = ElementDataOffsets[Element.Index];
	const int32 SubElementCount = ElementDataNums[Element.Index];

	FObjectAndParameters Result;
	Result.ElementNames = TArrayView<const FName>(SubElementNames.GetData() + FirstSubElement, SubElementCount);
	Result.Elements = TArrayView<const FObjectElement>(SubElementObjects.GetData() + FirstSubElement, SubElementCount);
	return Result;
}
|
|
|
|
FObjectOrExclusiveParameters FObject::GetOrExclusive(const FObjectElement Element) const
{
	check(IsValid(Element) && GetType(Element) == EType::OrExclusive);

	// The chosen (name, object) pair is the first entry of this element's sub-element range
	const int32 FirstSubElement = ElementDataOffsets[Element.Index];

	FObjectOrExclusiveParameters Result;
	Result.ElementName = SubElementNames[FirstSubElement];
	Result.Element = SubElementObjects[FirstSubElement];
	return Result;
}
|
|
|
|
FObjectOrInclusiveParameters FObject::GetOrInclusive(const FObjectElement Element) const
{
	check(IsValid(Element) && GetType(Element) == EType::OrInclusive);

	// Names and objects share one offset/count, so the two views are index-aligned
	const int32 FirstSubElement = ElementDataOffsets[Element.Index];
	const int32 SubElementCount = ElementDataNums[Element.Index];

	FObjectOrInclusiveParameters Result;
	Result.ElementNames = TArrayView<const FName>(SubElementNames.GetData() + FirstSubElement, SubElementCount);
	Result.Elements = TArrayView<const FObjectElement>(SubElementObjects.GetData() + FirstSubElement, SubElementCount);
	return Result;
}
|
|
|
|
FObjectArrayParameters FObject::GetArray(const FObjectElement Element) const
{
	check(IsValid(Element) && GetType(Element) == EType::Array);

	// Non-owning view over the array's sub-objects
	const int32 FirstSubElement = ElementDataOffsets[Element.Index];
	const int32 SubElementCount = ElementDataNums[Element.Index];

	FObjectArrayParameters Result;
	Result.Elements = TArrayView<const FObjectElement>(SubElementObjects.GetData() + FirstSubElement, SubElementCount);
	return Result;
}
|
|
|
|
FObjectEncodingParameters FObject::GetEncoding(const FObjectElement Element) const
{
	check(IsValid(Element) && GetType(Element) == EType::Encoding);

	// The wrapped object is the single entry of this element's sub-element range
	const int32 WrappedSlot = ElementDataOffsets[Element.Index];

	FObjectEncodingParameters Result;
	Result.Element = SubElementObjects[WrappedSlot];
	return Result;
}
|
|
|
|
uint32 FObject::GetGeneration() const
{
	// Incremented by Empty() and Reset(); handles store the value current at their creation
	return this->Generation;
}
|
|
|
|
void FObject::Empty()
|
|
{
|
|
Types.Empty();
|
|
Tags.Empty();
|
|
ContinuousDataOffsets.Empty();
|
|
ContinuousDataNums.Empty();
|
|
DiscreteDataOffsets.Empty();
|
|
DiscreteDataNums.Empty();
|
|
ElementDataOffsets.Empty();
|
|
ElementDataNums.Empty();
|
|
|
|
ContinuousValues.Empty();
|
|
DiscreteValues.Empty();
|
|
SubElementObjects.Empty();
|
|
SubElementNames.Empty();
|
|
|
|
Generation++;
|
|
}
|
|
|
|
bool FObject::IsEmpty() const
|
|
{
|
|
return Types.IsEmpty();
|
|
}
|
|
|
|
void FObject::Reset()
|
|
{
|
|
Types.Reset();
|
|
Tags.Reset();
|
|
ContinuousDataOffsets.Reset();
|
|
ContinuousDataNums.Reset();
|
|
DiscreteDataOffsets.Reset();
|
|
DiscreteDataNums.Reset();
|
|
ElementDataOffsets.Reset();
|
|
ElementDataNums.Reset();
|
|
|
|
ContinuousValues.Reset();
|
|
DiscreteValues.Reset();
|
|
SubElementObjects.Reset();
|
|
SubElementNames.Reset();
|
|
|
|
Generation++;
|
|
}
|
|
|
|
|
|
FModifierElement FModifier::CreateNull(const FName Tag)
{
	// A null modifier owns nothing: every range is recorded as empty at the current buffer end
	const int32 NewIndex = Types.Add(EType::Null);
	Tags.Add(Tag);

	ContinuousDataOffsets.Add(ContinuousMaskeds.Num());
	ContinuousDataNums.Add(0);

	DiscreteDataOffsets.Add(DiscreteValues.Num());
	DiscreteDataNums.Add(0);

	ElementDataOffsets.Add(SubElementModifiers.Num());
	ElementDataNums.Add(0);

	MaskedDataOffsets.Add(MaskedElementNames.Num());
	MaskedDataNums.Add(0);

	return { NewIndex, Generation };
}
|
|
|
|
FModifierElement FModifier::CreateContinuous(const FModifierContinuousParameters Parameters, const FName Tag)
{
	// GetContinuous reads ContinuousMaskeds and ContinuousMaskedValues through ONE shared offset/count
	// pair. If Masked and MaskedValues had different lengths, the two buffers would drift apart, and this
	// and every later continuous modifier would read misaligned ranges. Reject that up front.
	check(Parameters.Masked.Num() == Parameters.MaskedValues.Num());

	const int32 Index = Types.Add(EType::Continuous);
	Tags.Add(Tag);

	// Continuous payload: one masked flag and one masked value per dimension
	ContinuousDataOffsets.Add(ContinuousMaskeds.Num());
	ContinuousDataNums.Add(Parameters.MaskedValues.Num());

	// No discrete, sub-element or masked-name payload
	DiscreteDataOffsets.Add(DiscreteValues.Num());
	DiscreteDataNums.Add(0);

	ElementDataOffsets.Add(SubElementModifiers.Num());
	ElementDataNums.Add(0);

	MaskedDataOffsets.Add(MaskedElementNames.Num());
	MaskedDataNums.Add(0);

	ContinuousMaskeds.Append(Parameters.Masked);
	ContinuousMaskedValues.Append(Parameters.MaskedValues);

	return { Index, Generation };
}
|
|
|
|
FModifierElement FModifier::CreateDiscreteExclusive(const FModifierDiscreteExclusiveParameters Parameters, const FName Tag)
{
	const int32 MaskedCount = Parameters.MaskedIndices.Num();

	const int32 NewIndex = Types.Add(EType::DiscreteExclusive);
	Tags.Add(Tag);

	// No continuous payload
	ContinuousDataOffsets.Add(ContinuousMaskeds.Num());
	ContinuousDataNums.Add(0);

	// Discrete payload: the masked indices, appended below
	DiscreteDataOffsets.Add(DiscreteValues.Num());
	DiscreteDataNums.Add(MaskedCount);

	// No sub-elements and no masked names
	ElementDataOffsets.Add(SubElementModifiers.Num());
	ElementDataNums.Add(0);
	MaskedDataOffsets.Add(MaskedElementNames.Num());
	MaskedDataNums.Add(0);

	DiscreteValues.Append(Parameters.MaskedIndices);

	return { NewIndex, Generation };
}
|
|
|
|
FModifierElement FModifier::CreateDiscreteInclusive(const FModifierDiscreteInclusiveParameters Parameters, const FName Tag)
{
	const int32 MaskedCount = Parameters.MaskedIndices.Num();

	const int32 NewIndex = Types.Add(EType::DiscreteInclusive);
	Tags.Add(Tag);

	// No continuous payload
	ContinuousDataOffsets.Add(ContinuousMaskeds.Num());
	ContinuousDataNums.Add(0);

	// Discrete payload: the masked indices, appended below
	DiscreteDataOffsets.Add(DiscreteValues.Num());
	DiscreteDataNums.Add(MaskedCount);

	// No sub-elements and no masked names
	ElementDataOffsets.Add(SubElementModifiers.Num());
	ElementDataNums.Add(0);
	MaskedDataOffsets.Add(MaskedElementNames.Num());
	MaskedDataNums.Add(0);

	DiscreteValues.Append(Parameters.MaskedIndices);

	return { NewIndex, Generation };
}
|
|
|
|
FModifierElement FModifier::CreateNamedDiscreteExclusive(const FModifierNamedDiscreteExclusiveParameters Parameters, const FName Tag)
{
	const int32 MaskedNameCount = Parameters.MaskedElementNames.Num();

	const int32 NewIndex = Types.Add(EType::NamedDiscreteExclusive);
	Tags.Add(Tag);

	// No continuous or discrete payload
	ContinuousDataOffsets.Add(ContinuousMaskeds.Num());
	ContinuousDataNums.Add(0);
	DiscreteDataOffsets.Add(DiscreteValues.Num());
	DiscreteDataNums.Add(0);

	// The masked names are stored in the sub-element name range
	ElementDataOffsets.Add(SubElementModifiers.Num());
	ElementDataNums.Add(MaskedNameCount);
	MaskedDataOffsets.Add(MaskedElementNames.Num());
	MaskedDataNums.Add(0);

	SubElementNames.Append(Parameters.MaskedElementNames);

	// Pad SubElementModifiers with default handles so it stays index-aligned with SubElementNames
	for (int32 Remaining = MaskedNameCount; Remaining > 0; --Remaining)
	{
		SubElementModifiers.Add(FModifierElement());
	}

	return { NewIndex, Generation };
}
|
|
|
|
FModifierElement FModifier::CreateNamedDiscreteInclusive(const FModifierNamedDiscreteInclusiveParameters Parameters, const FName Tag)
{
	const int32 MaskedNameCount = Parameters.MaskedElementNames.Num();

	const int32 NewIndex = Types.Add(EType::NamedDiscreteInclusive);
	Tags.Add(Tag);

	// No continuous or discrete payload
	ContinuousDataOffsets.Add(ContinuousMaskeds.Num());
	ContinuousDataNums.Add(0);
	DiscreteDataOffsets.Add(DiscreteValues.Num());
	DiscreteDataNums.Add(0);

	// The masked names are stored in the sub-element name range
	ElementDataOffsets.Add(SubElementModifiers.Num());
	ElementDataNums.Add(MaskedNameCount);
	MaskedDataOffsets.Add(MaskedElementNames.Num());
	MaskedDataNums.Add(0);

	SubElementNames.Append(Parameters.MaskedElementNames);

	// Pad SubElementModifiers with default handles so it stays index-aligned with SubElementNames
	for (int32 Remaining = MaskedNameCount; Remaining > 0; --Remaining)
	{
		SubElementModifiers.Add(FModifierElement());
	}

	return { NewIndex, Generation };
}
|
|
|
|
FModifierElement FModifier::CreateAnd(const FModifierAndParameters Parameters, const FName Tag)
{
	// Names pair up one-to-one with sub-modifiers, must be unique, and all sub-modifiers must be live
	check(Parameters.Elements.Num() == Parameters.ElementNames.Num());
	check(!Private::ContainsDuplicates(Parameters.ElementNames));
	check(Private::CheckAllValid(*this, Parameters.Elements));

	const int32 NewIndex = Types.Add(EType::And);
	Tags.Add(Tag);

	// No continuous or discrete payload
	ContinuousDataOffsets.Add(ContinuousMaskeds.Num());
	ContinuousDataNums.Add(0);
	DiscreteDataOffsets.Add(DiscreteValues.Num());
	DiscreteDataNums.Add(0);

	// Named sub-modifier range; no masked names
	ElementDataOffsets.Add(SubElementModifiers.Num());
	ElementDataNums.Add(Parameters.Elements.Num());
	MaskedDataOffsets.Add(MaskedElementNames.Num());
	MaskedDataNums.Add(0);

	SubElementModifiers.Append(Parameters.Elements);
	SubElementNames.Append(Parameters.ElementNames);

	return { NewIndex, Generation };
}
|
|
|
|
FModifierElement FModifier::CreateOrExclusive(const FModifierOrExclusiveParameters Parameters, const FName Tag)
{
	// Names pair up with sub-modifiers; names and masked names are each unique; sub-modifiers are live
	check(Parameters.Elements.Num() == Parameters.ElementNames.Num());
	check(!Private::ContainsDuplicates(Parameters.ElementNames));
	check(!Private::ContainsDuplicates(Parameters.MaskedElements));
	check(Private::CheckAllValid(*this, Parameters.Elements));

	const int32 NewIndex = Types.Add(EType::OrExclusive);
	Tags.Add(Tag);

	// No continuous or discrete payload
	ContinuousDataOffsets.Add(ContinuousMaskeds.Num());
	ContinuousDataNums.Add(0);
	DiscreteDataOffsets.Add(DiscreteValues.Num());
	DiscreteDataNums.Add(0);

	// Named sub-modifier range plus a separate range of masked element names
	ElementDataOffsets.Add(SubElementModifiers.Num());
	ElementDataNums.Add(Parameters.Elements.Num());
	MaskedDataOffsets.Add(MaskedElementNames.Num());
	MaskedDataNums.Add(Parameters.MaskedElements.Num());

	SubElementModifiers.Append(Parameters.Elements);
	SubElementNames.Append(Parameters.ElementNames);
	MaskedElementNames.Append(Parameters.MaskedElements);

	return { NewIndex, Generation };
}
|
|
|
|
FModifierElement FModifier::CreateOrInclusive(const FModifierOrInclusiveParameters Parameters, const FName Tag)
{
	// Names pair up with sub-modifiers; names and masked names are each unique; sub-modifiers are live
	check(Parameters.Elements.Num() == Parameters.ElementNames.Num());
	check(!Private::ContainsDuplicates(Parameters.ElementNames));
	check(!Private::ContainsDuplicates(Parameters.MaskedElements));
	check(Private::CheckAllValid(*this, Parameters.Elements));

	const int32 NewIndex = Types.Add(EType::OrInclusive);
	Tags.Add(Tag);

	// No continuous or discrete payload
	ContinuousDataOffsets.Add(ContinuousMaskeds.Num());
	ContinuousDataNums.Add(0);
	DiscreteDataOffsets.Add(DiscreteValues.Num());
	DiscreteDataNums.Add(0);

	// Named sub-modifier range plus a separate range of masked element names
	ElementDataOffsets.Add(SubElementModifiers.Num());
	ElementDataNums.Add(Parameters.Elements.Num());
	MaskedDataOffsets.Add(MaskedElementNames.Num());
	MaskedDataNums.Add(Parameters.MaskedElements.Num());

	SubElementModifiers.Append(Parameters.Elements);
	SubElementNames.Append(Parameters.ElementNames);
	MaskedElementNames.Append(Parameters.MaskedElements);

	return { NewIndex, Generation };
}
|
|
|
|
FModifierElement FModifier::CreateArray(const FModifierArrayParameters Parameters, const FName Tag)
{
	// Every referenced sub-modifier must belong to the current generation of this FModifier
	check(Private::CheckAllValid(*this, Parameters.Elements));

	const int32 ElementCount = Parameters.Elements.Num();

	const int32 NewIndex = Types.Add(EType::Array);
	Tags.Add(Tag);

	// No continuous or discrete payload
	ContinuousDataOffsets.Add(ContinuousMaskeds.Num());
	ContinuousDataNums.Add(0);
	DiscreteDataOffsets.Add(DiscreteValues.Num());
	DiscreteDataNums.Add(0);

	// Unnamed sub-modifier range; no masked names
	ElementDataOffsets.Add(SubElementModifiers.Num());
	ElementDataNums.Add(ElementCount);
	MaskedDataOffsets.Add(MaskedElementNames.Num());
	MaskedDataNums.Add(0);

	// Pad SubElementNames with NAME_None so it stays index-aligned with SubElementModifiers
	for (int32 Remaining = ElementCount; Remaining > 0; --Remaining)
	{
		SubElementNames.Add(NAME_None);
	}
	SubElementModifiers.Append(Parameters.Elements);

	return { NewIndex, Generation };
}
|
|
|
|
FModifierElement FModifier::CreateEncoding(const FModifierEncodingParameters Parameters, const FName Tag)
{
	// The wrapped modifier must be a live element of this FModifier
	check(IsValid(Parameters.Element));

	const int32 NewIndex = Types.Add(EType::Encoding);
	Tags.Add(Tag);

	// No continuous or discrete payload
	ContinuousDataOffsets.Add(ContinuousMaskeds.Num());
	ContinuousDataNums.Add(0);
	DiscreteDataOffsets.Add(DiscreteValues.Num());
	DiscreteDataNums.Add(0);

	// Exactly one (unnamed) sub-element: the wrapped modifier; no masked names
	ElementDataOffsets.Add(SubElementModifiers.Num());
	ElementDataNums.Add(1);
	MaskedDataOffsets.Add(MaskedElementNames.Num());
	MaskedDataNums.Add(0);

	SubElementNames.Add(NAME_None);
	SubElementModifiers.Add(Parameters.Element);

	return { NewIndex, Generation };
}
|
|
|
|
bool FModifier::IsValid(const FModifierElement Element) const
|
|
{
|
|
return Element.Generation == Generation && Element.Index != INDEX_NONE;
|
|
}
|
|
|
|
EType FModifier::GetType(const FModifierElement Element) const
{
	// Reject handles from another generation or with INDEX_NONE before indexing
	check(IsValid(Element));

	const int32 ElementIndex = Element.Index;
	return Types[ElementIndex];
}
|
|
|
|
FName FModifier::GetTag(const FModifierElement Element) const
{
	// Reject handles from another generation or with INDEX_NONE before indexing
	check(IsValid(Element));

	const int32 ElementIndex = Element.Index;
	return Tags[ElementIndex];
}
|
|
|
|
FModifierContinuousParameters FModifier::GetContinuous(const FModifierElement Element) const
{
	check(IsValid(Element) && GetType(Element) == EType::Continuous);

	// Flags and values share one offset/count pair, so both views cover the same dimensions
	const int32 FirstValue = ContinuousDataOffsets[Element.Index];
	const int32 ValueCount = ContinuousDataNums[Element.Index];

	FModifierContinuousParameters Result;
	Result.Masked = TArrayView<const bool>(ContinuousMaskeds.GetData() + FirstValue, ValueCount);
	Result.MaskedValues = TArrayView<const float>(ContinuousMaskedValues.GetData() + FirstValue, ValueCount);
	return Result;
}
|
|
|
|
FModifierDiscreteExclusiveParameters FModifier::GetDiscreteExclusive(const FModifierElement Element) const
{
	check(IsValid(Element) && GetType(Element) == EType::DiscreteExclusive);

	// Non-owning view of the masked indices stored for this element
	const int32 FirstDiscrete = DiscreteDataOffsets[Element.Index];
	const int32 DiscreteCount = DiscreteDataNums[Element.Index];

	FModifierDiscreteExclusiveParameters Result;
	Result.MaskedIndices = TArrayView<const int32>(DiscreteValues.GetData() + FirstDiscrete, DiscreteCount);
	return Result;
}
|
|
|
|
FModifierDiscreteInclusiveParameters FModifier::GetDiscreteInclusive(const FModifierElement Element) const
{
	check(IsValid(Element) && GetType(Element) == EType::DiscreteInclusive);

	// Non-owning view of the masked indices stored for this element
	const int32 FirstDiscrete = DiscreteDataOffsets[Element.Index];
	const int32 DiscreteCount = DiscreteDataNums[Element.Index];

	FModifierDiscreteInclusiveParameters Result;
	Result.MaskedIndices = TArrayView<const int32>(DiscreteValues.GetData() + FirstDiscrete, DiscreteCount);
	return Result;
}
|
|
|
|
FModifierNamedDiscreteExclusiveParameters FModifier::GetNamedDiscreteExclusive(const FModifierElement Element) const
{
	check(IsValid(Element) && GetType(Element) == EType::NamedDiscreteExclusive);

	// Masked names live in this element's sub-element name range
	const int32 FirstSubElement = ElementDataOffsets[Element.Index];
	const int32 SubElementCount = ElementDataNums[Element.Index];

	FModifierNamedDiscreteExclusiveParameters Result;
	Result.MaskedElementNames = TArrayView<const FName>(SubElementNames.GetData() + FirstSubElement, SubElementCount);
	return Result;
}
|
|
|
|
FModifierNamedDiscreteInclusiveParameters FModifier::GetNamedDiscreteInclusive(const FModifierElement Element) const
{
	check(IsValid(Element) && GetType(Element) == EType::NamedDiscreteInclusive);

	// Masked names live in this element's sub-element name range
	const int32 FirstSubElement = ElementDataOffsets[Element.Index];
	const int32 SubElementCount = ElementDataNums[Element.Index];

	FModifierNamedDiscreteInclusiveParameters Result;
	Result.MaskedElementNames = TArrayView<const FName>(SubElementNames.GetData() + FirstSubElement, SubElementCount);
	return Result;
}
|
|
|
|
FModifierAndParameters FModifier::GetAnd(const FModifierElement Element) const
{
	check(IsValid(Element) && GetType(Element) == EType::And);

	// Names and sub-modifiers share one offset/count, so the two views are index-aligned
	const int32 FirstSubElement = ElementDataOffsets[Element.Index];
	const int32 SubElementCount = ElementDataNums[Element.Index];

	FModifierAndParameters Result;
	Result.ElementNames = TArrayView<const FName>(SubElementNames.GetData() + FirstSubElement, SubElementCount);
	Result.Elements = TArrayView<const FModifierElement>(SubElementModifiers.GetData() + FirstSubElement, SubElementCount);
	return Result;
}
|
|
|
|
FModifierOrExclusiveParameters FModifier::GetOrExclusive(const FModifierElement Element) const
{
	check(IsValid(Element) && GetType(Element) == EType::OrExclusive);

	const int32 ElementIndex = Element.Index;

	// Named sub-modifier range (names and modifiers index-aligned)
	const int32 FirstSubElement = ElementDataOffsets[ElementIndex];
	const int32 SubElementCount = ElementDataNums[ElementIndex];

	// Separate range of masked element names
	const int32 FirstMasked = MaskedDataOffsets[ElementIndex];
	const int32 MaskedCount = MaskedDataNums[ElementIndex];

	FModifierOrExclusiveParameters Result;
	Result.ElementNames = TArrayView<const FName>(SubElementNames.GetData() + FirstSubElement, SubElementCount);
	Result.Elements = TArrayView<const FModifierElement>(SubElementModifiers.GetData() + FirstSubElement, SubElementCount);
	Result.MaskedElements = TArrayView<const FName>(MaskedElementNames.GetData() + FirstMasked, MaskedCount);
	return Result;
}
|
|
|
|
FModifierOrInclusiveParameters FModifier::GetOrInclusive(const FModifierElement Element) const
{
	check(IsValid(Element) && GetType(Element) == EType::OrInclusive);

	const int32 ElementIndex = Element.Index;

	// Named sub-modifier range (names and modifiers index-aligned)
	const int32 FirstSubElement = ElementDataOffsets[ElementIndex];
	const int32 SubElementCount = ElementDataNums[ElementIndex];

	// Separate range of masked element names
	const int32 FirstMasked = MaskedDataOffsets[ElementIndex];
	const int32 MaskedCount = MaskedDataNums[ElementIndex];

	FModifierOrInclusiveParameters Result;
	Result.ElementNames = TArrayView<const FName>(SubElementNames.GetData() + FirstSubElement, SubElementCount);
	Result.Elements = TArrayView<const FModifierElement>(SubElementModifiers.GetData() + FirstSubElement, SubElementCount);
	Result.MaskedElements = TArrayView<const FName>(MaskedElementNames.GetData() + FirstMasked, MaskedCount);
	return Result;
}
|
|
|
|
FModifierArrayParameters FModifier::GetArray(const FModifierElement Element) const
{
	check(IsValid(Element) && GetType(Element) == EType::Array);

	// Non-owning view over the array's sub-modifiers
	const int32 FirstSubElement = ElementDataOffsets[Element.Index];
	const int32 SubElementCount = ElementDataNums[Element.Index];

	FModifierArrayParameters Result;
	Result.Elements = TArrayView<const FModifierElement>(SubElementModifiers.GetData() + FirstSubElement, SubElementCount);
	return Result;
}
|
|
|
|
FModifierEncodingParameters FModifier::GetEncoding(const FModifierElement Element) const
{
	check(IsValid(Element) && GetType(Element) == EType::Encoding);

	// The wrapped modifier is the single entry of this element's sub-element range
	const int32 WrappedSlot = ElementDataOffsets[Element.Index];

	FModifierEncodingParameters Result;
	Result.Element = SubElementModifiers[WrappedSlot];
	return Result;
}
|
|
|
|
uint32 FModifier::GetGeneration() const
{
	// Incremented by Empty() and Reset(); handles store the value current at their creation
	return this->Generation;
}
|
|
|
|
void FModifier::Empty()
|
|
{
|
|
Types.Empty();
|
|
Tags.Empty();
|
|
ContinuousDataOffsets.Empty();
|
|
ContinuousDataNums.Empty();
|
|
DiscreteDataOffsets.Empty();
|
|
DiscreteDataNums.Empty();
|
|
ElementDataOffsets.Empty();
|
|
ElementDataNums.Empty();
|
|
MaskedDataOffsets.Empty();
|
|
MaskedDataNums.Empty();
|
|
|
|
ContinuousMaskeds.Empty();
|
|
ContinuousMaskedValues.Empty();
|
|
DiscreteValues.Empty();
|
|
SubElementModifiers.Empty();
|
|
SubElementNames.Empty();
|
|
MaskedElementNames.Empty();
|
|
|
|
Generation++;
|
|
}
|
|
|
|
bool FModifier::IsEmpty() const
|
|
{
|
|
return Types.IsEmpty();
|
|
}
|
|
|
|
void FModifier::Reset()
|
|
{
|
|
Types.Reset();
|
|
Tags.Reset();
|
|
ContinuousDataOffsets.Reset();
|
|
ContinuousDataNums.Reset();
|
|
DiscreteDataOffsets.Reset();
|
|
DiscreteDataNums.Reset();
|
|
ElementDataOffsets.Reset();
|
|
ElementDataNums.Reset();
|
|
MaskedDataOffsets.Reset();
|
|
MaskedDataNums.Reset();
|
|
|
|
ContinuousMaskeds.Reset();
|
|
ContinuousMaskedValues.Reset();
|
|
DiscreteValues.Reset();
|
|
SubElementModifiers.Reset();
|
|
SubElementNames.Reset();
|
|
MaskedElementNames.Reset();
|
|
|
|
Generation++;
|
|
}
|
|
|
|
|
|
namespace Private
|
|
{
|
|
static inline NNE::RuntimeBasic::FModelBuilder::EActivationFunction GetNNEActivationFunction(const EEncodingActivationFunction ActivationFunction)
|
|
{
|
|
switch (ActivationFunction)
|
|
{
|
|
case EEncodingActivationFunction::ReLU: return NNE::RuntimeBasic::FModelBuilder::EActivationFunction::ReLU;
|
|
case EEncodingActivationFunction::ELU: return NNE::RuntimeBasic::FModelBuilder::EActivationFunction::ELU;
|
|
case EEncodingActivationFunction::TanH: return NNE::RuntimeBasic::FModelBuilder::EActivationFunction::TanH;
|
|
case EEncodingActivationFunction::GELU: return NNE::RuntimeBasic::FModelBuilder::EActivationFunction::GELU;
|
|
default: checkNoEntry(); return NNE::RuntimeBasic::FModelBuilder::EActivationFunction::ReLU;
|
|
}
|
|
}
|
|
|
|
static inline int32 HashFNameStable(const FName Name)
|
|
{
|
|
const FString NameString = Name.ToString().ToLower();
|
|
return (int32)CityHash32(
|
|
(const char*)NameString.GetCharArray().GetData(),
|
|
NameString.GetCharArray().GetTypeSize() *
|
|
NameString.GetCharArray().Num());
|
|
}
|
|
|
|
static inline int32 HashInt(const int32 Int)
|
|
{
|
|
return (int32)CityHash32((const char*)&Int, sizeof(int32));
|
|
}
|
|
|
|
static inline int32 HashCombine(const TArrayView<const int32> Hashes)
|
|
{
|
|
return (int32)CityHash32((const char*)Hashes.GetData(), Hashes.Num() * Hashes.GetTypeSize());
|
|
}
|
|
|
|
static inline int32 HashElements(
|
|
const FSchema& Schema,
|
|
const TArrayView<const FName> SchemaElementNames,
|
|
const int32 Salt)
|
|
{
|
|
// Note: Here we xor all entries together.
|
|
// This makes the hash in invariant to the ordering of names which is actually what we want
|
|
// since this array is representing a set-like structure and it is fine to pass elements in a different order.
|
|
|
|
int32 Hash = 0x9de53147;
|
|
for (int32 SchemaElementIdx = 0; SchemaElementIdx < SchemaElementNames.Num(); SchemaElementIdx++)
|
|
{
|
|
Hash ^= HashFNameStable(SchemaElementNames[SchemaElementIdx]);
|
|
}
|
|
|
|
return Hash;
|
|
}
|
|
|
|
static inline int32 HashElements(
|
|
const FSchema& Schema,
|
|
const TArrayView<const FName> SchemaElementNames,
|
|
const TArrayView<const FSchemaElement> SchemaElements,
|
|
const int32 Salt)
|
|
{
|
|
// Note: Here we xor all entries together.
|
|
// This makes the hash in invariant to the ordering of pairs of names and elements
|
|
// which is actually what we want since these two arrays are representing a map-like
|
|
// structure and it is fine to pass keys and values in a different order.
|
|
|
|
int32 Hash = 0x5b3bbe4d;
|
|
for (int32 SchemaElementIdx = 0; SchemaElementIdx < SchemaElements.Num(); SchemaElementIdx++)
|
|
{
|
|
Hash ^= HashCombine({ HashFNameStable(SchemaElementNames[SchemaElementIdx]), GetSchemaObjectsCompatibilityHash(Schema, SchemaElements[SchemaElementIdx], Salt) });
|
|
}
|
|
|
|
return Hash;
|
|
}
|
|
}
|
|
|
|
int32 GetSchemaObjectsCompatibilityHash(
	const FSchema& Schema,
	const FSchemaElement SchemaElement,
	const int32 Salt)
{
	check(Schema.IsValid(SchemaElement));

	const EType ElementKind = Schema.GetType(SchemaElement);

	// Every non-encoding hash starts from the salt combined with the element type
	const int32 BaseHash = Private::HashCombine({ Salt, Private::HashInt((int32)ElementKind) });

	// Mixes one extra value into the base hash
	const auto CombineWithBase = [BaseHash](const int32 Extra)
	{
		return Private::HashCombine({ BaseHash, Extra });
	};

	switch (ElementKind)
	{
	case EType::Null:
		return BaseHash;

	case EType::Continuous:
		return CombineWithBase(Private::HashInt(Schema.GetContinuous(SchemaElement).Num));

	case EType::DiscreteExclusive:
		return CombineWithBase(Private::HashInt(Schema.GetDiscreteExclusive(SchemaElement).Num));

	case EType::DiscreteInclusive:
		return CombineWithBase(Private::HashInt(Schema.GetDiscreteInclusive(SchemaElement).Num));

	case EType::NamedDiscreteExclusive:
	{
		const FSchemaNamedDiscreteExclusiveParameters Named = Schema.GetNamedDiscreteExclusive(SchemaElement);
		return CombineWithBase(Private::HashElements(Schema, Named.ElementNames, Salt));
	}

	case EType::NamedDiscreteInclusive:
	{
		const FSchemaNamedDiscreteInclusiveParameters Named = Schema.GetNamedDiscreteInclusive(SchemaElement);
		return CombineWithBase(Private::HashElements(Schema, Named.ElementNames, Salt));
	}

	case EType::And:
	{
		const FSchemaAndParameters Composite = Schema.GetAnd(SchemaElement);
		return CombineWithBase(Private::HashElements(Schema, Composite.ElementNames, Composite.Elements, Salt));
	}

	case EType::OrExclusive:
	{
		const FSchemaOrExclusiveParameters Composite = Schema.GetOrExclusive(SchemaElement);
		return CombineWithBase(Private::HashElements(Schema, Composite.ElementNames, Composite.Elements, Salt));
	}

	case EType::OrInclusive:
	{
		const FSchemaOrInclusiveParameters Composite = Schema.GetOrInclusive(SchemaElement);
		return CombineWithBase(Private::HashElements(Schema, Composite.ElementNames, Composite.Elements, Salt));
	}

	case EType::Array:
	{
		// Three-way combine: base, element count, then the hash of the repeated element
		const FSchemaArrayParameters ArrayParameters = Schema.GetArray(SchemaElement);
		return Private::HashCombine({ BaseHash, Private::HashInt(ArrayParameters.Num), GetSchemaObjectsCompatibilityHash(Schema, ArrayParameters.Element, Salt) });
	}

	case EType::Encoding:
		// Encodings do not affect compatibility: hash the wrapped element alone, without BaseHash
		return GetSchemaObjectsCompatibilityHash(Schema, Schema.GetEncoding(SchemaElement).Element, Salt);

	default:
		checkNoEntry();
		return 0;
	}
}
|
|
|
|
bool AreSchemaObjectsCompatible(
|
|
const FSchema& SchemaA,
|
|
const FSchemaElement SchemaElementA,
|
|
const FSchema& SchemaB,
|
|
const FSchemaElement SchemaElementB)
|
|
{
|
|
check(SchemaA.IsValid(SchemaElementA));
|
|
check(SchemaB.IsValid(SchemaElementB));
|
|
|
|
const EType SchemaElementTypeA = SchemaA.GetType(SchemaElementA);
|
|
const EType SchemaElementTypeB = SchemaB.GetType(SchemaElementB);
|
|
|
|
// If any element is an encoding element we forward the comparison to the sub-element since encoding elements don't affect compatibility
|
|
if (SchemaElementTypeA == EType::Encoding) { return AreSchemaObjectsCompatible(SchemaA, SchemaA.GetEncoding(SchemaElementA).Element, SchemaB, SchemaElementB); }
|
|
if (SchemaElementTypeB == EType::Encoding) { return AreSchemaObjectsCompatible(SchemaA, SchemaElementA, SchemaB, SchemaB.GetEncoding(SchemaElementB).Element); }
|
|
|
|
// Otherwise if types don't match we immediately know elements are incompatible
|
|
if (SchemaElementTypeA != SchemaElementTypeB) { return false; }
|
|
|
|
// This is an early-out since if the input sizes are different we are definitely incompatible
|
|
if (SchemaA.GetActionVectorSize(SchemaElementA) != SchemaB.GetActionVectorSize(SchemaElementB)) { return false; }
|
|
|
|
switch (SchemaElementTypeA)
|
|
{
|
|
case EType::Null: return true;
|
|
|
|
case EType::Continuous: return SchemaA.GetContinuous(SchemaElementA).Num == SchemaB.GetContinuous(SchemaElementB).Num;
|
|
|
|
case EType::DiscreteExclusive: return SchemaA.GetDiscreteExclusive(SchemaElementA).Num == SchemaB.GetDiscreteExclusive(SchemaElementB).Num;
|
|
|
|
case EType::DiscreteInclusive: return SchemaA.GetDiscreteInclusive(SchemaElementA).Num == SchemaB.GetDiscreteInclusive(SchemaElementB).Num;
|
|
|
|
case EType::NamedDiscreteExclusive:
|
|
{
|
|
const FSchemaNamedDiscreteExclusiveParameters ParametersA = SchemaA.GetNamedDiscreteExclusive(SchemaElementA);
|
|
const FSchemaNamedDiscreteExclusiveParameters ParametersB = SchemaB.GetNamedDiscreteExclusive(SchemaElementB);
|
|
|
|
if (ParametersA.ElementNames.Num() != ParametersB.ElementNames.Num()) { return false; }
|
|
|
|
for (int32 SchemaElementAIdx = 0; SchemaElementAIdx < ParametersA.ElementNames.Num(); SchemaElementAIdx++)
|
|
{
|
|
const int32 SchemaElementBIdx = ParametersB.ElementNames.Find(ParametersA.ElementNames[SchemaElementAIdx]);
|
|
if (SchemaElementBIdx == INDEX_NONE) { return false; }
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
case EType::NamedDiscreteInclusive:
|
|
{
|
|
const FSchemaNamedDiscreteInclusiveParameters ParametersA = SchemaA.GetNamedDiscreteInclusive(SchemaElementA);
|
|
const FSchemaNamedDiscreteInclusiveParameters ParametersB = SchemaB.GetNamedDiscreteInclusive(SchemaElementB);
|
|
|
|
if (ParametersA.ElementNames.Num() != ParametersB.ElementNames.Num()) { return false; }
|
|
|
|
for (int32 SchemaElementAIdx = 0; SchemaElementAIdx < ParametersA.ElementNames.Num(); SchemaElementAIdx++)
|
|
{
|
|
const int32 SchemaElementBIdx = ParametersB.ElementNames.Find(ParametersA.ElementNames[SchemaElementAIdx]);
|
|
if (SchemaElementBIdx == INDEX_NONE) { return false; }
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
case EType::And:
|
|
{
|
|
const FSchemaAndParameters ParametersA = SchemaA.GetAnd(SchemaElementA);
|
|
const FSchemaAndParameters ParametersB = SchemaB.GetAnd(SchemaElementB);
|
|
|
|
if (ParametersA.Elements.Num() != ParametersB.Elements.Num()) { return false; }
|
|
|
|
for (int32 SchemaElementAIdx = 0; SchemaElementAIdx < ParametersA.Elements.Num(); SchemaElementAIdx++)
|
|
{
|
|
const int32 SchemaElementBIdx = ParametersB.ElementNames.Find(ParametersA.ElementNames[SchemaElementAIdx]);
|
|
if (SchemaElementBIdx == INDEX_NONE) { return false; }
|
|
if (!AreSchemaObjectsCompatible(SchemaA, ParametersA.Elements[SchemaElementAIdx], SchemaB, ParametersB.Elements[SchemaElementBIdx])) { return false; }
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
case EType::OrExclusive:
|
|
{
|
|
const FSchemaOrExclusiveParameters ParametersA = SchemaA.GetOrExclusive(SchemaElementA);
|
|
const FSchemaOrExclusiveParameters ParametersB = SchemaB.GetOrExclusive(SchemaElementB);
|
|
|
|
if (ParametersA.Elements.Num() != ParametersB.Elements.Num()) { return false; }
|
|
|
|
for (int32 SchemaElementAIdx = 0; SchemaElementAIdx < ParametersA.Elements.Num(); SchemaElementAIdx++)
|
|
{
|
|
const int32 SchemaElementBIdx = ParametersB.ElementNames.Find(ParametersA.ElementNames[SchemaElementAIdx]);
|
|
if (SchemaElementBIdx == INDEX_NONE) { return false; }
|
|
if (!AreSchemaObjectsCompatible(SchemaA, ParametersA.Elements[SchemaElementAIdx], SchemaB, ParametersB.Elements[SchemaElementBIdx])) { return false; }
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
case EType::OrInclusive:
|
|
{
|
|
const FSchemaOrInclusiveParameters ParametersA = SchemaA.GetOrInclusive(SchemaElementA);
|
|
const FSchemaOrInclusiveParameters ParametersB = SchemaB.GetOrInclusive(SchemaElementB);
|
|
|
|
if (ParametersA.Elements.Num() != ParametersB.Elements.Num()) { return false; }
|
|
|
|
for (int32 SchemaElementAIdx = 0; SchemaElementAIdx < ParametersA.Elements.Num(); SchemaElementAIdx++)
|
|
{
|
|
const int32 SchemaElementBIdx = ParametersB.ElementNames.Find(ParametersA.ElementNames[SchemaElementAIdx]);
|
|
if (SchemaElementBIdx == INDEX_NONE) { return false; }
|
|
if (!AreSchemaObjectsCompatible(SchemaA, ParametersA.Elements[SchemaElementAIdx], SchemaB, ParametersB.Elements[SchemaElementBIdx])) { return false; }
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
case EType::Array:
|
|
{
|
|
const FSchemaArrayParameters ParametersA = SchemaA.GetArray(SchemaElementA);
|
|
const FSchemaArrayParameters ParametersB = SchemaB.GetArray(SchemaElementB);
|
|
|
|
return (ParametersA.Num == ParametersB.Num) && AreSchemaObjectsCompatible(SchemaA, ParametersA.Element, SchemaB, ParametersB.Element);
|
|
}
|
|
|
|
case EType::Encoding:
|
|
{
|
|
checkf(false, TEXT("Encoding elements should always be forwarded..."));
|
|
return false;
|
|
}
|
|
|
|
default:
|
|
{
|
|
checkNoEntry();
|
|
return false;
|
|
}
|
|
|
|
}
|
|
}
|
|
|
|
void MakeDecoderNetworkModelBuilderElementFromSchema(
|
|
NNE::RuntimeBasic::FModelBuilderElement& OutElement,
|
|
NNE::RuntimeBasic::FModelBuilder& Builder,
|
|
const FSchema& Schema,
|
|
const FSchemaElement SchemaElement,
|
|
const FNetworkSettings& NetworkSettings)
|
|
{
|
|
const EType SchemaElementType = Schema.GetType(SchemaElement);
|
|
|
|
switch (SchemaElementType)
|
|
{
|
|
|
|
case EType::Null:
|
|
{
|
|
OutElement = Builder.MakeCopy(0);
|
|
break;
|
|
}
|
|
|
|
case EType::Continuous:
|
|
{
|
|
const int32 ValueNum = Schema.GetContinuous(SchemaElement).Num * 2;
|
|
|
|
OutElement = Builder.MakeDenormalize(
|
|
ValueNum,
|
|
Builder.MakeValuesZero(ValueNum),
|
|
Builder.MakeValuesOne(ValueNum));
|
|
break;
|
|
}
|
|
|
|
case EType::DiscreteExclusive:
|
|
{
|
|
const FSchemaDiscreteExclusiveParameters Parameters = Schema.GetDiscreteExclusive(SchemaElement);
|
|
|
|
TArray<float, TInlineAllocator<16>> LogPriorProbabilities;
|
|
LogPriorProbabilities.Append(Parameters.PriorProbabilities);
|
|
for (int32 Idx = 0; Idx < Parameters.Num; Idx++)
|
|
{
|
|
// Clamp zero probabilities to the smallest (positive) float. This is approximately equal to a probability of 1:1e38
|
|
LogPriorProbabilities[Idx] = FMath::Loge(FMath::Max(LogPriorProbabilities[Idx], FLT_MIN));
|
|
}
|
|
|
|
OutElement = Builder.MakeDenormalize(
|
|
Parameters.Num,
|
|
Builder.MakeValuesCopy(LogPriorProbabilities),
|
|
Builder.MakeValuesOne(Parameters.Num));
|
|
break;
|
|
}
|
|
|
|
case EType::DiscreteInclusive:
|
|
{
|
|
const FSchemaDiscreteInclusiveParameters Parameters = Schema.GetDiscreteInclusive(SchemaElement);
|
|
|
|
TArray<float, TInlineAllocator<16>> LogPriorProbabilities;
|
|
LogPriorProbabilities.Append(Parameters.PriorProbabilities);
|
|
for (int32 Idx = 0; Idx < Parameters.Num; Idx++)
|
|
{
|
|
LogPriorProbabilities[Idx] = Private::Logit(LogPriorProbabilities[Idx]);
|
|
}
|
|
|
|
OutElement = Builder.MakeDenormalize(
|
|
Parameters.Num,
|
|
Builder.MakeValuesCopy(LogPriorProbabilities),
|
|
Builder.MakeValuesOne(Parameters.Num));
|
|
break;
|
|
}
|
|
|
|
case EType::NamedDiscreteExclusive:
|
|
{
|
|
const FSchemaNamedDiscreteExclusiveParameters Parameters = Schema.GetNamedDiscreteExclusive(SchemaElement);
|
|
|
|
TArray<float, TInlineAllocator<16>> LogPriorProbabilities;
|
|
LogPriorProbabilities.Append(Parameters.PriorProbabilities);
|
|
for (int32 Idx = 0; Idx < Parameters.ElementNames.Num(); Idx++)
|
|
{
|
|
// Clamp zero probabilities to the smallest (positive) float. This is approximately equal to a probability of 1:1e38
|
|
LogPriorProbabilities[Idx] = FMath::Loge(FMath::Max(LogPriorProbabilities[Idx], FLT_MIN));
|
|
}
|
|
|
|
OutElement = Builder.MakeDenormalize(
|
|
Parameters.ElementNames.Num(),
|
|
Builder.MakeValuesCopy(LogPriorProbabilities),
|
|
Builder.MakeValuesOne(Parameters.ElementNames.Num()));
|
|
break;
|
|
}
|
|
|
|
case EType::NamedDiscreteInclusive:
|
|
{
|
|
const FSchemaNamedDiscreteInclusiveParameters Parameters = Schema.GetNamedDiscreteInclusive(SchemaElement);
|
|
|
|
TArray<float, TInlineAllocator<16>> LogPriorProbabilities;
|
|
LogPriorProbabilities.Append(Parameters.PriorProbabilities);
|
|
for (int32 Idx = 0; Idx < Parameters.ElementNames.Num(); Idx++)
|
|
{
|
|
LogPriorProbabilities[Idx] = Private::Logit(LogPriorProbabilities[Idx]);
|
|
}
|
|
|
|
OutElement = Builder.MakeDenormalize(
|
|
Parameters.ElementNames.Num(),
|
|
Builder.MakeValuesCopy(LogPriorProbabilities),
|
|
Builder.MakeValuesOne(Parameters.ElementNames.Num()));
|
|
break;
|
|
}
|
|
|
|
case EType::And:
|
|
{
|
|
const FSchemaAndParameters Parameters = Schema.GetAnd(SchemaElement);
|
|
|
|
TArray<NNE::RuntimeBasic::FModelBuilderElement, TInlineAllocator<8>> BuilderLayers;
|
|
BuilderLayers.Reserve(Parameters.Elements.Num());
|
|
for (const FSchemaElement SubElement : Parameters.Elements)
|
|
{
|
|
NNE::RuntimeBasic::FModelBuilderElement BuilderSubElement;
|
|
MakeDecoderNetworkModelBuilderElementFromSchema(BuilderSubElement, Builder, Schema, SubElement, NetworkSettings);
|
|
BuilderLayers.Emplace(BuilderSubElement);
|
|
}
|
|
|
|
OutElement = Builder.MakeConcat(BuilderLayers);
|
|
break;
|
|
}
|
|
|
|
case EType::OrExclusive:
|
|
{
|
|
const FSchemaOrExclusiveParameters Parameters = Schema.GetOrExclusive(SchemaElement);
|
|
|
|
TArray<NNE::RuntimeBasic::FModelBuilderElement, TInlineAllocator<8>> BuilderLayers;
|
|
BuilderLayers.Reserve(Parameters.Elements.Num() + 1);
|
|
for (const FSchemaElement SubElement : Parameters.Elements)
|
|
{
|
|
NNE::RuntimeBasic::FModelBuilderElement BuilderSubElement;
|
|
MakeDecoderNetworkModelBuilderElementFromSchema(BuilderSubElement, Builder, Schema, SubElement, NetworkSettings);
|
|
BuilderLayers.Emplace(BuilderSubElement);
|
|
}
|
|
|
|
TArray<float, TInlineAllocator<16>> LogPriorProbabilities;
|
|
LogPriorProbabilities.Append(Parameters.PriorProbabilities);
|
|
for (int32 Idx = 0; Idx < Parameters.PriorProbabilities.Num(); Idx++)
|
|
{
|
|
// Clamp zero probabilities to the smallest (positive) float. This is approximately equal to a probability of 1:1e38
|
|
LogPriorProbabilities[Idx] = FMath::Loge(FMath::Max(LogPriorProbabilities[Idx], FLT_MIN));
|
|
}
|
|
|
|
BuilderLayers.Emplace(Builder.MakeDenormalize(
|
|
LogPriorProbabilities.Num(),
|
|
Builder.MakeValuesCopy(LogPriorProbabilities),
|
|
Builder.MakeValuesOne(LogPriorProbabilities.Num())));
|
|
|
|
OutElement = Builder.MakeConcat(BuilderLayers);
|
|
break;
|
|
}
|
|
|
|
case EType::OrInclusive:
|
|
{
|
|
const FSchemaOrInclusiveParameters Parameters = Schema.GetOrInclusive(SchemaElement);
|
|
|
|
TArray<NNE::RuntimeBasic::FModelBuilderElement, TInlineAllocator<8>> BuilderLayers;
|
|
BuilderLayers.Reserve(Parameters.Elements.Num() + 1);
|
|
for (const FSchemaElement SubElement : Parameters.Elements)
|
|
{
|
|
NNE::RuntimeBasic::FModelBuilderElement BuilderSubElement;
|
|
MakeDecoderNetworkModelBuilderElementFromSchema(BuilderSubElement, Builder, Schema, SubElement, NetworkSettings);
|
|
BuilderLayers.Emplace(BuilderSubElement);
|
|
}
|
|
|
|
TArray<float, TInlineAllocator<16>> LogPriorProbabilities;
|
|
LogPriorProbabilities.Append(Parameters.PriorProbabilities);
|
|
for (int32 Idx = 0; Idx < Parameters.PriorProbabilities.Num(); Idx++)
|
|
{
|
|
LogPriorProbabilities[Idx] = Private::Logit(LogPriorProbabilities[Idx]);
|
|
}
|
|
|
|
BuilderLayers.Emplace(Builder.MakeDenormalize(
|
|
LogPriorProbabilities.Num(),
|
|
Builder.MakeValuesCopy(LogPriorProbabilities),
|
|
Builder.MakeValuesOne(LogPriorProbabilities.Num())));
|
|
|
|
OutElement = Builder.MakeConcat(BuilderLayers);
|
|
break;
|
|
}
|
|
|
|
case EType::Array:
|
|
{
|
|
const FSchemaArrayParameters Parameters = Schema.GetArray(SchemaElement);
|
|
|
|
NNE::RuntimeBasic::FModelBuilderElement BuilderSubElement;
|
|
MakeDecoderNetworkModelBuilderElementFromSchema(BuilderSubElement, Builder, Schema, Parameters.Element, NetworkSettings);
|
|
OutElement = Builder.MakeArray(Parameters.Num, BuilderSubElement);
|
|
break;
|
|
}
|
|
|
|
case EType::Encoding:
|
|
{
|
|
const FSchemaEncodingParameters Parameters = Schema.GetEncoding(SchemaElement);
|
|
|
|
const int32 SubElementEncodedSize = Schema.GetEncodedVectorSize(Parameters.Element);
|
|
|
|
NNE::RuntimeBasic::FModelBuilderElement BuilderSubElement;
|
|
MakeDecoderNetworkModelBuilderElementFromSchema(BuilderSubElement, Builder, Schema, Parameters.Element, NetworkSettings);
|
|
|
|
NNE::RuntimeBasic::FModelBuilder::FLinearLayerSettings LinearLayerSettings;
|
|
LinearLayerSettings.Type = NetworkSettings.bUseCompressedLinearLayers ?
|
|
NNE::RuntimeBasic::FModelBuilder::ELinearLayerType::Compressed :
|
|
NNE::RuntimeBasic::FModelBuilder::ELinearLayerType::Normal;
|
|
|
|
switch (NetworkSettings.WeightInitialization)
|
|
{
|
|
case EWeightInitialization::KaimingGaussian: LinearLayerSettings.WeightInitializationSettings.Type =
|
|
NNE::RuntimeBasic::FModelBuilder::EWeightInitializationType::KaimingGaussian; break;
|
|
case EWeightInitialization::KaimingUniform: LinearLayerSettings.WeightInitializationSettings.Type =
|
|
NNE::RuntimeBasic::FModelBuilder::EWeightInitializationType::KaimingUniform; break;
|
|
default: checkNoEntry();
|
|
}
|
|
|
|
OutElement = Builder.MakeSequence({
|
|
Builder.MakeActivation(Parameters.EncodingSize, Private::GetNNEActivationFunction(Parameters.ActivationFunction)),
|
|
Builder.MakeMLP(
|
|
Parameters.EncodingSize,
|
|
SubElementEncodedSize,
|
|
Parameters.EncodingSize,
|
|
Parameters.LayerNum + 1, // Add 1 to account for input layer
|
|
Private::GetNNEActivationFunction(Parameters.ActivationFunction),
|
|
false,
|
|
LinearLayerSettings),
|
|
BuilderSubElement,
|
|
});
|
|
|
|
break;
|
|
}
|
|
|
|
default:
|
|
{
|
|
checkNoEntry();
|
|
}
|
|
}
|
|
|
|
checkf(OutElement.GetInputSize() == Schema.GetEncodedVectorSize(SchemaElement),
|
|
TEXT("Decoder Network Input unexpected size. Got %i, expected %i according to Schema."),
|
|
OutElement.GetInputSize(), Schema.GetEncodedVectorSize(SchemaElement));
|
|
|
|
checkf(OutElement.GetOutputSize() == Schema.GetActionDistributionVectorSize(SchemaElement),
|
|
TEXT("Decoder Network Output unexpected size. Got %i, expected %i according to Schema."),
|
|
OutElement.GetOutputSize(), Schema.GetActionDistributionVectorSize(SchemaElement));
|
|
}
|
|
|
|
/**
 * Builds the decoder network for SchemaElement and serializes it into OutFileData.
 *
 * @param OutFileData       Receives the serialized network.
 * @param OutInputSize      Receives the network input size reported by the builder.
 * @param OutOutputSize     Receives the network output size reported by the builder.
 * @param Schema            Schema containing SchemaElement (SchemaElement must be valid).
 * @param SchemaElement     Root element to build the decoder for.
 * @param NetworkSettings   Settings forwarded to the decoder construction.
 * @param Seed              Seed the model builder is constructed with (presumably drives weight initialization).
 */
void GenerateDecoderNetworkFileDataFromSchema(
	TArray<uint8>& OutFileData,
	uint32& OutInputSize,
	uint32& OutOutputSize,
	const FSchema& Schema,
	const FSchemaElement SchemaElement,
	const FNetworkSettings& NetworkSettings,
	const uint32 Seed)
{
	check(Schema.IsValid(SchemaElement));

	NNE::RuntimeBasic::FModelBuilder ModelBuilder(Seed);

	// Assemble the full decoder hierarchy starting at the root schema element.
	NNE::RuntimeBasic::FModelBuilderElement DecoderRoot;
	MakeDecoderNetworkModelBuilderElementFromSchema(DecoderRoot, ModelBuilder, Schema, SchemaElement, NetworkSettings);

	// Write out the assembled network and its input/output sizes, then reset the builder.
	ModelBuilder.WriteFileDataAndReset(OutFileData, OutInputSize, OutOutputSize, DecoderRoot);
}
|
|
|
|
void SampleVectorFromDistributionVector(
|
|
uint32& InOutRandomState,
|
|
TLearningArrayView<1, float> OutActionVector,
|
|
const TLearningArrayView<1, const float> ActionDistributionVector,
|
|
const TLearningArrayView<1, const float> ActionModifierVector,
|
|
const FSchema& Schema,
|
|
const FSchemaElement SchemaElement,
|
|
const float ActionNoiseScale)
|
|
{
|
|
check(Schema.IsValid(SchemaElement));
|
|
|
|
const EType SchemaElementType = Schema.GetType(SchemaElement);
|
|
|
|
switch (SchemaElementType)
|
|
{
|
|
case EType::Null: break;
|
|
|
|
case EType::Continuous:
|
|
{
|
|
const int32 ValueNum = Schema.GetContinuous(SchemaElement).Num;
|
|
check(ValueNum == OutActionVector.Num());
|
|
check(ValueNum * 2 == ActionDistributionVector.Num());
|
|
check(1 + ValueNum * 2 == ActionModifierVector.Num());
|
|
|
|
if (ActionModifierVector[0])
|
|
{
|
|
TLearningArray<1, bool, TInlineAllocator<32>> Masked;
|
|
Masked.SetNumUninitialized({ ValueNum });
|
|
for (int32 ValueIdx = 0; ValueIdx < ValueNum; ValueIdx++)
|
|
{
|
|
Masked[ValueIdx] = ActionModifierVector[1 + ValueIdx] == 1.0f;
|
|
}
|
|
|
|
Random::SampleDistributionIndependantNormalMasked(
|
|
OutActionVector,
|
|
InOutRandomState,
|
|
ActionDistributionVector.Slice(0, ValueNum),
|
|
ActionDistributionVector.Slice(ValueNum, ValueNum),
|
|
Masked,
|
|
ActionModifierVector.Slice(1 + ValueNum, ValueNum),
|
|
ActionNoiseScale);
|
|
}
|
|
else
|
|
{
|
|
Random::SampleDistributionIndependantNormal(
|
|
OutActionVector,
|
|
InOutRandomState,
|
|
ActionDistributionVector.Slice(0, ValueNum),
|
|
ActionDistributionVector.Slice(ValueNum, ValueNum),
|
|
ActionNoiseScale);
|
|
}
|
|
|
|
break;
|
|
}
|
|
|
|
case EType::DiscreteExclusive:
|
|
{
|
|
const int32 ValueNum = Schema.GetDiscreteExclusive(SchemaElement).Num;
|
|
check(ValueNum == OutActionVector.Num());
|
|
check(ValueNum == ActionDistributionVector.Num());
|
|
check(1 + ValueNum == ActionModifierVector.Num());
|
|
|
|
if (ActionModifierVector[0])
|
|
{
|
|
TLearningArray<1, bool, TInlineAllocator<32>> Masked;
|
|
Masked.SetNumUninitialized({ ValueNum });
|
|
for (int32 ValueIdx = 0; ValueIdx < ValueNum; ValueIdx++)
|
|
{
|
|
Masked[ValueIdx] = ActionModifierVector[1 + ValueIdx] == 1.0f;
|
|
}
|
|
check(Private::CheckExclusiveMaskValid(Masked));
|
|
|
|
Random::SampleDistributionMultinoulliMasked(
|
|
OutActionVector,
|
|
InOutRandomState,
|
|
ActionDistributionVector,
|
|
Masked,
|
|
ActionNoiseScale);
|
|
}
|
|
else
|
|
{
|
|
Random::SampleDistributionMultinoulli(
|
|
OutActionVector,
|
|
InOutRandomState,
|
|
ActionDistributionVector,
|
|
ActionNoiseScale);
|
|
}
|
|
|
|
break;
|
|
}
|
|
|
|
case EType::DiscreteInclusive:
|
|
{
|
|
const int32 ValueNum = Schema.GetDiscreteInclusive(SchemaElement).Num;
|
|
check(ValueNum == OutActionVector.Num());
|
|
check(ValueNum == ActionDistributionVector.Num());
|
|
check(1 + ValueNum == ActionModifierVector.Num());
|
|
|
|
if (ActionModifierVector[0])
|
|
{
|
|
TLearningArray<1, bool, TInlineAllocator<32>> Masked;
|
|
Masked.SetNumUninitialized({ ValueNum });
|
|
for (int32 ValueIdx = 0; ValueIdx < ValueNum; ValueIdx++)
|
|
{
|
|
Masked[ValueIdx] = ActionModifierVector[1 + ValueIdx] == 1.0f;
|
|
}
|
|
|
|
Random::SampleDistributionBernoulliMasked(
|
|
OutActionVector,
|
|
InOutRandomState,
|
|
ActionDistributionVector,
|
|
Masked,
|
|
ActionNoiseScale);
|
|
}
|
|
else
|
|
{
|
|
Random::SampleDistributionBernoulli(
|
|
OutActionVector,
|
|
InOutRandomState,
|
|
ActionDistributionVector,
|
|
ActionNoiseScale);
|
|
}
|
|
|
|
break;
|
|
}
|
|
|
|
case EType::NamedDiscreteExclusive:
|
|
{
|
|
const int32 ValueNum = Schema.GetNamedDiscreteExclusive(SchemaElement).ElementNames.Num();
|
|
check(ValueNum == OutActionVector.Num());
|
|
check(ValueNum == ActionDistributionVector.Num());
|
|
check(1 + ValueNum == ActionModifierVector.Num());
|
|
|
|
if (ActionModifierVector[0])
|
|
{
|
|
TLearningArray<1, bool, TInlineAllocator<32>> Masked;
|
|
Masked.SetNumUninitialized({ ValueNum });
|
|
for (int32 ValueIdx = 0; ValueIdx < ValueNum; ValueIdx++)
|
|
{
|
|
Masked[ValueIdx] = ActionModifierVector[1 + ValueIdx] == 1.0f;
|
|
}
|
|
check(Private::CheckExclusiveMaskValid(Masked));
|
|
|
|
Random::SampleDistributionMultinoulliMasked(
|
|
OutActionVector,
|
|
InOutRandomState,
|
|
ActionDistributionVector,
|
|
Masked,
|
|
ActionNoiseScale);
|
|
}
|
|
else
|
|
{
|
|
Random::SampleDistributionMultinoulli(
|
|
OutActionVector,
|
|
InOutRandomState,
|
|
ActionDistributionVector,
|
|
ActionNoiseScale);
|
|
}
|
|
|
|
break;
|
|
}
|
|
|
|
case EType::NamedDiscreteInclusive:
|
|
{
|
|
const int32 ValueNum = Schema.GetNamedDiscreteInclusive(SchemaElement).ElementNames.Num();
|
|
check(ValueNum == OutActionVector.Num());
|
|
check(ValueNum == ActionDistributionVector.Num());
|
|
check(1 + ValueNum == ActionModifierVector.Num());
|
|
|
|
if (ActionModifierVector[0])
|
|
{
|
|
TLearningArray<1, bool, TInlineAllocator<32>> Masked;
|
|
Masked.SetNumUninitialized({ ValueNum });
|
|
for (int32 ValueIdx = 0; ValueIdx < ValueNum; ValueIdx++)
|
|
{
|
|
Masked[ValueIdx] = ActionModifierVector[1 + ValueIdx] == 1.0f;
|
|
}
|
|
|
|
Random::SampleDistributionBernoulliMasked(
|
|
OutActionVector,
|
|
InOutRandomState,
|
|
ActionDistributionVector,
|
|
Masked,
|
|
ActionNoiseScale);
|
|
}
|
|
else
|
|
{
|
|
Random::SampleDistributionBernoulli(
|
|
OutActionVector,
|
|
InOutRandomState,
|
|
ActionDistributionVector,
|
|
ActionNoiseScale);
|
|
}
|
|
|
|
break;
|
|
}
|
|
|
|
case EType::And:
|
|
{
|
|
const FSchemaAndParameters Parameters = Schema.GetAnd(SchemaElement);
|
|
|
|
int32 SubElementActionVectorOffset = 0;
|
|
int32 SubElementActionDistributionVectorOffset = 0;
|
|
int32 SubElementActionModifierVectorOffset = 1;
|
|
|
|
for (const FSchemaElement SubElement : Parameters.Elements)
|
|
{
|
|
const int32 SubElementActionVectorSize = Schema.GetActionVectorSize(SubElement);
|
|
const int32 SubElementActionDistributionVectorSize = Schema.GetActionDistributionVectorSize(SubElement);
|
|
const int32 SubElementActionModifierVectorSize = Schema.GetActionModifierVectorSize(SubElement);
|
|
|
|
SampleVectorFromDistributionVector(
|
|
InOutRandomState,
|
|
OutActionVector.Slice(SubElementActionVectorOffset, SubElementActionVectorSize),
|
|
ActionDistributionVector.Slice(SubElementActionDistributionVectorOffset, SubElementActionDistributionVectorSize),
|
|
ActionModifierVector.Slice(SubElementActionModifierVectorOffset, SubElementActionModifierVectorSize),
|
|
Schema,
|
|
SubElement,
|
|
ActionNoiseScale);
|
|
|
|
SubElementActionVectorOffset += SubElementActionVectorSize;
|
|
SubElementActionDistributionVectorOffset += SubElementActionDistributionVectorSize;
|
|
SubElementActionModifierVectorOffset += SubElementActionModifierVectorSize;
|
|
}
|
|
|
|
check(SubElementActionVectorOffset == OutActionVector.Num());
|
|
check(SubElementActionDistributionVectorOffset == ActionDistributionVector.Num());
|
|
check(SubElementActionModifierVectorOffset == ActionModifierVector.Num());
|
|
|
|
break;
|
|
}
|
|
|
|
case EType::OrExclusive:
|
|
{
|
|
const FSchemaOrExclusiveParameters Parameters = Schema.GetOrExclusive(SchemaElement);
|
|
|
|
const int32 SubElementActionVectorMax = Private::GetMaxActionVectorSize(Schema, Parameters.Elements);
|
|
const int32 SubElementActionDistributionVectorTotal = Private::GetTotalActionDistributionVectorSize(Schema, Parameters.Elements);
|
|
const int32 SubElementActionModifierVectorTotal = Private::GetTotalActionModifierVectorSize(Schema, Parameters.Elements);
|
|
const int32 ElementNum = Parameters.Elements.Num();
|
|
|
|
check(SubElementActionVectorMax + ElementNum == OutActionVector.Num());
|
|
check(SubElementActionDistributionVectorTotal + ElementNum == ActionDistributionVector.Num());
|
|
check(1 + ElementNum + SubElementActionModifierVectorTotal == ActionModifierVector.Num());
|
|
|
|
// Zero main part of vector
|
|
Array::Zero(OutActionVector.Slice(0, SubElementActionVectorMax));
|
|
|
|
if (ActionModifierVector[0])
|
|
{
|
|
TLearningArray<1, bool, TInlineAllocator<32>> Masked;
|
|
Masked.SetNumUninitialized({ ElementNum });
|
|
for (int32 ElementIdx = 0; ElementIdx < ElementNum; ElementIdx++)
|
|
{
|
|
Masked[ElementIdx] = ActionModifierVector[1 + ElementIdx] == 1.0f;
|
|
}
|
|
check(Private::CheckExclusiveMaskValid(Masked));
|
|
|
|
// Sample which sub-element to generate
|
|
Random::SampleDistributionMultinoulliMasked(
|
|
OutActionVector.Slice(SubElementActionVectorMax, ElementNum),
|
|
InOutRandomState,
|
|
ActionDistributionVector.Slice(SubElementActionDistributionVectorTotal, ElementNum),
|
|
Masked,
|
|
ActionNoiseScale);
|
|
}
|
|
else
|
|
{
|
|
// Sample which sub-element to generate
|
|
Random::SampleDistributionMultinoulli(
|
|
OutActionVector.Slice(SubElementActionVectorMax, ElementNum),
|
|
InOutRandomState,
|
|
ActionDistributionVector.Slice(SubElementActionDistributionVectorTotal, ElementNum),
|
|
ActionNoiseScale);
|
|
}
|
|
|
|
int32 SubElementsSampled = 0;
|
|
int32 SubElementActionDistributionVectorOffset = 0;
|
|
int32 SubElementActionModifierVectorOffset = 1 + ElementNum;
|
|
|
|
for (int32 SubElementIdx = 0; SubElementIdx < ElementNum; SubElementIdx++)
|
|
{
|
|
const FSchemaElement SubElement = Parameters.Elements[SubElementIdx];
|
|
const int32 SubElementActionVectorSize = Schema.GetActionVectorSize(SubElement);
|
|
const int32 SubElementActionDistributionVectorSize = Schema.GetActionDistributionVectorSize(SubElement);
|
|
const int32 SubElementActionModifierVectorSize = Schema.GetActionModifierVectorSize(SubElement);
|
|
|
|
check(SubElementActionVectorSize <= SubElementActionVectorMax);
|
|
|
|
if (OutActionVector[SubElementActionVectorMax + SubElementIdx])
|
|
{
|
|
// Sample Sub-Element
|
|
SampleVectorFromDistributionVector(
|
|
InOutRandomState,
|
|
OutActionVector.Slice(0, SubElementActionVectorSize),
|
|
ActionDistributionVector.Slice(SubElementActionDistributionVectorOffset, SubElementActionDistributionVectorSize),
|
|
ActionModifierVector.Slice(SubElementActionModifierVectorOffset, SubElementActionModifierVectorSize),
|
|
Schema,
|
|
SubElement,
|
|
ActionNoiseScale);
|
|
|
|
SubElementsSampled++;
|
|
}
|
|
|
|
SubElementActionDistributionVectorOffset += SubElementActionDistributionVectorSize;
|
|
SubElementActionModifierVectorOffset += SubElementActionModifierVectorSize;
|
|
}
|
|
|
|
check(SubElementsSampled == 1); // Exactly one sub-element should have been sampled
|
|
check(SubElementActionDistributionVectorOffset == SubElementActionDistributionVectorTotal);
|
|
check(SubElementActionModifierVectorOffset == 1 + ElementNum + SubElementActionModifierVectorTotal);
|
|
|
|
break;
|
|
}
|
|
|
|
case EType::OrInclusive:
|
|
{
|
|
const FSchemaOrInclusiveParameters Parameters = Schema.GetOrInclusive(SchemaElement);
|
|
|
|
const int32 SubElementActionVectorTotal = Private::GetTotalActionVectorSize(Schema, Parameters.Elements);
|
|
const int32 SubElementActionDistributionVectorTotal = Private::GetTotalActionDistributionVectorSize(Schema, Parameters.Elements);
|
|
const int32 SubElementActionModifierVectorTotal = Private::GetTotalActionModifierVectorSize(Schema, Parameters.Elements);
|
|
const int32 ElementNum = Parameters.Elements.Num();
|
|
|
|
check(SubElementActionVectorTotal + ElementNum == OutActionVector.Num());
|
|
check(SubElementActionDistributionVectorTotal + ElementNum == ActionDistributionVector.Num());
|
|
check(1 + ElementNum + SubElementActionModifierVectorTotal == ActionModifierVector.Num());
|
|
|
|
// Zero main part of vector
|
|
Array::Zero(OutActionVector.Slice(0, SubElementActionVectorTotal));
|
|
|
|
if (ActionModifierVector[0])
|
|
{
|
|
TLearningArray<1, bool, TInlineAllocator<32>> Masked;
|
|
Masked.SetNumUninitialized({ ElementNum });
|
|
for (int32 ElementIdx = 0; ElementIdx < ElementNum; ElementIdx++)
|
|
{
|
|
Masked[ElementIdx] = ActionModifierVector[1 + ElementIdx] == 1.0f;
|
|
}
|
|
|
|
// Sample which sub-elements to generate
|
|
Random::SampleDistributionBernoulliMasked(
|
|
OutActionVector.Slice(SubElementActionVectorTotal, ElementNum),
|
|
InOutRandomState,
|
|
ActionDistributionVector.Slice(SubElementActionDistributionVectorTotal, ElementNum),
|
|
Masked,
|
|
ActionNoiseScale);
|
|
}
|
|
else
|
|
{
|
|
// Sample which sub-elements to generate
|
|
Random::SampleDistributionBernoulli(
|
|
OutActionVector.Slice(SubElementActionVectorTotal, ElementNum),
|
|
InOutRandomState,
|
|
ActionDistributionVector.Slice(SubElementActionDistributionVectorTotal, ElementNum),
|
|
ActionNoiseScale);
|
|
}
|
|
|
|
int32 SubElementActionVectorOffset = 0;
|
|
int32 SubElementActionDistributionVectorOffset = 0;
|
|
int32 SubElementActionModifierVectorOffset = 1 + ElementNum;
|
|
|
|
for (int32 SubElementIdx = 0; SubElementIdx < ElementNum; SubElementIdx++)
|
|
{
|
|
const FSchemaElement SubElement = Parameters.Elements[SubElementIdx];
|
|
const int32 SubElementActionVectorSize = Schema.GetActionVectorSize(SubElement);
|
|
const int32 SubElementActionDistributionVectorSize = Schema.GetActionDistributionVectorSize(SubElement);
|
|
const int32 SubElementActionModifierVectorSize = Schema.GetActionModifierVectorSize(SubElement);
|
|
|
|
if (OutActionVector[SubElementActionVectorTotal + SubElementIdx])
|
|
{
|
|
// Sample sub-elements
|
|
SampleVectorFromDistributionVector(
|
|
InOutRandomState,
|
|
OutActionVector.Slice(SubElementActionVectorOffset, SubElementActionVectorSize),
|
|
ActionDistributionVector.Slice(SubElementActionDistributionVectorOffset, SubElementActionDistributionVectorSize),
|
|
ActionModifierVector.Slice(SubElementActionModifierVectorOffset, SubElementActionModifierVectorSize),
|
|
Schema,
|
|
SubElement,
|
|
ActionNoiseScale);
|
|
}
|
|
|
|
SubElementActionVectorOffset += SubElementActionVectorSize;
|
|
SubElementActionDistributionVectorOffset += SubElementActionDistributionVectorSize;
|
|
SubElementActionModifierVectorOffset += SubElementActionModifierVectorSize;
|
|
}
|
|
|
|
check(SubElementActionVectorOffset == SubElementActionVectorTotal);
|
|
check(SubElementActionDistributionVectorOffset == SubElementActionDistributionVectorTotal);
|
|
check(SubElementActionModifierVectorOffset == 1 + ElementNum + SubElementActionModifierVectorTotal);
|
|
|
|
break;
|
|
}
|
|
|
|
case EType::Array:
|
|
{
|
|
const FSchemaArrayParameters Parameters = Schema.GetArray(SchemaElement);
|
|
|
|
const int32 SubElementActionVectorSize = Schema.GetActionVectorSize(Parameters.Element);
|
|
const int32 SubElementActionDistributionVectorSize = Schema.GetActionDistributionVectorSize(Parameters.Element);
|
|
const int32 SubElementActionModifierVectorSize = Schema.GetActionModifierVectorSize(Parameters.Element);
|
|
|
|
check(SubElementActionVectorSize * Parameters.Num == OutActionVector.Num());
|
|
check(SubElementActionDistributionVectorSize * Parameters.Num == ActionDistributionVector.Num());
|
|
check(1 + SubElementActionModifierVectorSize * Parameters.Num == ActionModifierVector.Num());
|
|
|
|
for (int32 ElementIdx = 0; ElementIdx < Parameters.Num; ElementIdx++)
|
|
{
|
|
SampleVectorFromDistributionVector(
|
|
InOutRandomState,
|
|
OutActionVector.Slice(ElementIdx * SubElementActionVectorSize, SubElementActionVectorSize),
|
|
ActionDistributionVector.Slice(ElementIdx * SubElementActionDistributionVectorSize, SubElementActionDistributionVectorSize),
|
|
ActionModifierVector.Slice(1 + ElementIdx * SubElementActionModifierVectorSize, SubElementActionModifierVectorSize),
|
|
Schema,
|
|
Parameters.Element,
|
|
ActionNoiseScale);
|
|
}
|
|
|
|
break;
|
|
}
|
|
|
|
case EType::Encoding:
|
|
{
|
|
const FSchemaEncodingParameters Parameters = Schema.GetEncoding(SchemaElement);
|
|
const int32 SubElementActionModifierVectorSize = Schema.GetActionModifierVectorSize(Parameters.Element);
|
|
|
|
SampleVectorFromDistributionVector(
|
|
InOutRandomState,
|
|
OutActionVector,
|
|
ActionDistributionVector,
|
|
ActionModifierVector.Slice(1, SubElementActionModifierVectorSize),
|
|
Schema,
|
|
Parameters.Element,
|
|
ActionNoiseScale);
|
|
|
|
break;
|
|
}
|
|
|
|
}
|
|
}
|
|
|
|
void SetVectorFromObject(
|
|
TLearningArrayView<1, float> OutActionVector,
|
|
const FSchema& Schema,
|
|
const FSchemaElement SchemaElement,
|
|
const FObject& Object,
|
|
const FObjectElement ObjectElement)
|
|
{
|
|
check(Schema.IsValid(SchemaElement));
|
|
check(Object.IsValid(ObjectElement));
|
|
check(OutActionVector.Num() == Schema.GetActionVectorSize(SchemaElement));
|
|
|
|
// Check that the types match
|
|
|
|
const EType SchemaElementType = Schema.GetType(SchemaElement);
|
|
const EType ObjectElementType = Object.GetType(ObjectElement);
|
|
check(ObjectElementType == SchemaElementType);
|
|
|
|
// Zero Action Vector
|
|
|
|
Array::Zero(OutActionVector);
|
|
|
|
// Logic for each specific element type
|
|
|
|
switch (SchemaElementType)
|
|
{
|
|
|
|
case EType::Null: return;
|
|
|
|
case EType::Continuous:
|
|
{
|
|
// Check the input sizes match
|
|
|
|
const FSchemaContinuousParameters SchemaParameters = Schema.GetContinuous(SchemaElement);
|
|
|
|
TArrayView<const float> ActionValues = Object.GetContinuous(ObjectElement).Values;
|
|
check(Schema.GetActionVectorSize(SchemaElement) == ActionValues.Num());
|
|
check(Schema.GetActionVectorSize(SchemaElement) == OutActionVector.Num());
|
|
check(Schema.GetActionVectorSize(SchemaElement) == SchemaParameters.Num);
|
|
|
|
// Copy in and scale the values from the action object
|
|
|
|
const int32 ValueNum = SchemaParameters.Num;
|
|
const float ValueScale = FMath::Max(SchemaParameters.Scale, UE_SMALL_NUMBER);
|
|
|
|
for (int32 ValueIdx = 0; ValueIdx < ValueNum; ValueIdx++)
|
|
{
|
|
OutActionVector[ValueIdx] = ActionValues[ValueIdx] / ValueScale;
|
|
}
|
|
|
|
return;
|
|
}
|
|
|
|
case EType::DiscreteExclusive:
|
|
{
|
|
const int32 ActionValue = Object.GetDiscreteExclusive(ObjectElement).DiscreteIndex;
|
|
check(Schema.GetActionVectorSize(SchemaElement) > ActionValue && ActionValue >= 0);
|
|
check(Schema.GetActionVectorSize(SchemaElement) == OutActionVector.Num());
|
|
|
|
// Set the single value in the action vector
|
|
|
|
OutActionVector[ActionValue] = 1.0f;
|
|
return;
|
|
}
|
|
|
|
case EType::DiscreteInclusive:
|
|
{
|
|
const TArrayView<const int32> ActionValues = Object.GetDiscreteInclusive(ObjectElement).DiscreteIndices;
|
|
check(Schema.GetActionVectorSize(SchemaElement) >= ActionValues.Num());
|
|
check(Schema.GetActionVectorSize(SchemaElement) == OutActionVector.Num());
|
|
|
|
// Set values in the action vector
|
|
|
|
for (int32 ActionValueIdx = 0; ActionValueIdx < ActionValues.Num(); ActionValueIdx++)
|
|
{
|
|
check(Schema.GetActionVectorSize(SchemaElement) > ActionValues[ActionValueIdx] && ActionValues[ActionValueIdx] >= 0);
|
|
OutActionVector[ActionValues[ActionValueIdx]] = 1.0f;
|
|
}
|
|
|
|
return;
|
|
}
|
|
|
|
case EType::NamedDiscreteExclusive:
|
|
{
|
|
const TArrayView<const FName> SchemaNames = Schema.GetNamedDiscreteExclusive(SchemaElement).ElementNames;
|
|
const FName ActionValue = Object.GetNamedDiscreteExclusive(ObjectElement).ElementName;
|
|
check(Schema.GetActionVectorSize(SchemaElement) == OutActionVector.Num());
|
|
|
|
// Set the single value in the action vector
|
|
const int32 ActionIndex = SchemaNames.Find(ActionValue);
|
|
check(ActionIndex != INDEX_NONE);
|
|
OutActionVector[ActionIndex] = 1.0f;
|
|
return;
|
|
}
|
|
|
|
case EType::NamedDiscreteInclusive:
|
|
{
|
|
const TArrayView<const FName> SchemaNames = Schema.GetNamedDiscreteInclusive(SchemaElement).ElementNames;
|
|
const TArrayView<const FName> ActionValues = Object.GetNamedDiscreteInclusive(ObjectElement).ElementNames;
|
|
check(Schema.GetActionVectorSize(SchemaElement) >= ActionValues.Num());
|
|
check(Schema.GetActionVectorSize(SchemaElement) == OutActionVector.Num());
|
|
|
|
// Set values in the action vector
|
|
for (int32 ActionValueIdx = 0; ActionValueIdx < ActionValues.Num(); ActionValueIdx++)
|
|
{
|
|
const int32 ActionIndex = SchemaNames.Find(ActionValues[ActionValueIdx]);
|
|
check(ActionIndex != INDEX_NONE);
|
|
OutActionVector[ActionIndex] = 1.0f;
|
|
}
|
|
return;
|
|
}
|
|
|
|
case EType::And:
|
|
{
|
|
// Check the number of sub-elements match
|
|
|
|
const FSchemaAndParameters SchemaParameters = Schema.GetAnd(SchemaElement);
|
|
const FObjectAndParameters ObjectParameters = Object.GetAnd(ObjectElement);
|
|
check(SchemaParameters.Elements.Num() == ObjectParameters.Elements.Num());
|
|
|
|
// Set the Sub-elements
|
|
|
|
int32 SubElementOffset = 0;
|
|
|
|
for (int32 SchemaElementIdx = 0; SchemaElementIdx < SchemaParameters.Elements.Num(); SchemaElementIdx++)
|
|
{
|
|
const int32 ObjectElementIndex = ObjectParameters.ElementNames.Find(SchemaParameters.ElementNames[SchemaElementIdx]);
|
|
check(ObjectElementIndex != INDEX_NONE);
|
|
|
|
const int32 SubElementSize = Schema.GetActionVectorSize(SchemaParameters.Elements[SchemaElementIdx]);
|
|
|
|
SetVectorFromObject(
|
|
OutActionVector.Slice(SubElementOffset, SubElementSize),
|
|
Schema,
|
|
SchemaParameters.Elements[SchemaElementIdx],
|
|
Object,
|
|
ObjectParameters.Elements[ObjectElementIndex]);
|
|
|
|
SubElementOffset += SubElementSize;
|
|
}
|
|
|
|
check(SubElementOffset == OutActionVector.Num());
|
|
return;
|
|
}
|
|
|
|
case EType::OrExclusive:
|
|
{
|
|
// Check only one sub-element is given and index is valid
|
|
|
|
const FSchemaOrExclusiveParameters SchemaParameters = Schema.GetOrExclusive(SchemaElement);
|
|
const FObjectOrExclusiveParameters ObjectParameters = Object.GetOrExclusive(ObjectElement);
|
|
|
|
const int32 SchemaElementIndex = SchemaParameters.ElementNames.Find(ObjectParameters.ElementName);
|
|
check(SchemaElementIndex != INDEX_NONE);
|
|
|
|
// Set the sub-element
|
|
|
|
const int32 SubElementSize = Schema.GetActionVectorSize(SchemaParameters.Elements[SchemaElementIndex]);
|
|
|
|
SetVectorFromObject(
|
|
OutActionVector.Slice(0, SubElementSize),
|
|
Schema,
|
|
SchemaParameters.Elements[SchemaElementIndex],
|
|
Object,
|
|
ObjectParameters.Element);
|
|
|
|
// Set Mask
|
|
|
|
const int32 MaxSubElementSize = Private::GetMaxActionVectorSize(Schema, SchemaParameters.Elements);
|
|
|
|
OutActionVector[MaxSubElementSize + SchemaElementIndex] = 1.0f;
|
|
|
|
check(OutActionVector.Num() == MaxSubElementSize + SchemaParameters.Elements.Num());
|
|
return;
|
|
}
|
|
|
|
case EType::OrInclusive:
|
|
{
|
|
// Check all indices are in range
|
|
|
|
const FSchemaOrInclusiveParameters SchemaParameters = Schema.GetOrInclusive(SchemaElement);
|
|
const FObjectOrInclusiveParameters ObjectParameters = Object.GetOrInclusive(ObjectElement);
|
|
check(ObjectParameters.Elements.Num() <= SchemaParameters.Elements.Num());
|
|
|
|
// Update sub-elements
|
|
|
|
int32 SubElementOffset = 0;
|
|
|
|
for (int32 ObjectElementIdx = 0; ObjectElementIdx < ObjectParameters.Elements.Num(); ObjectElementIdx++)
|
|
{
|
|
const int32 SchemaElementIdx = SchemaParameters.ElementNames.Find(ObjectParameters.ElementNames[ObjectElementIdx]);
|
|
check(SchemaElementIdx != INDEX_NONE);
|
|
|
|
const int32 SubElementSize = Schema.GetActionVectorSize(SchemaParameters.Elements[SchemaElementIdx]);
|
|
|
|
SetVectorFromObject(
|
|
OutActionVector.Slice(SubElementOffset, SubElementSize),
|
|
Schema,
|
|
SchemaParameters.Elements[SchemaElementIdx],
|
|
Object,
|
|
ObjectParameters.Elements[ObjectElementIdx]);
|
|
|
|
SubElementOffset += SubElementSize;
|
|
}
|
|
|
|
// Set Mask
|
|
|
|
check(SubElementOffset + SchemaParameters.Elements.Num() == OutActionVector.Num());
|
|
|
|
for (int32 ObjectElementIdx = 0; ObjectElementIdx < ObjectParameters.Elements.Num(); ObjectElementIdx++)
|
|
{
|
|
const int32 SchemaElementIdx = SchemaParameters.ElementNames.Find(ObjectParameters.ElementNames[ObjectElementIdx]);
|
|
check(SchemaElementIdx != INDEX_NONE);
|
|
|
|
OutActionVector[SubElementOffset + SchemaElementIdx] = 1.0f;
|
|
}
|
|
|
|
return;
|
|
}
|
|
|
|
case EType::Array:
|
|
{
|
|
// Check number of array elements is correct
|
|
|
|
const FSchemaArrayParameters SchemaParameters = Schema.GetArray(SchemaElement);
|
|
const FObjectArrayParameters ObjectParameters = Object.GetArray(ObjectElement);
|
|
check(SchemaParameters.Num == ObjectParameters.Elements.Num());
|
|
|
|
// Update sub-elements
|
|
|
|
const int32 SubElementSize = Schema.GetActionVectorSize(SchemaParameters.Element);
|
|
|
|
for (int32 ElementIdx = 0; ElementIdx < SchemaParameters.Num; ElementIdx++)
|
|
{
|
|
SetVectorFromObject(
|
|
OutActionVector.Slice(ElementIdx * SubElementSize, SubElementSize),
|
|
Schema,
|
|
SchemaParameters.Element,
|
|
Object,
|
|
ObjectParameters.Elements[ElementIdx]);
|
|
}
|
|
|
|
return;
|
|
}
|
|
|
|
case EType::Encoding:
|
|
{
|
|
const FSchemaEncodingParameters SchemaParameters = Schema.GetEncoding(SchemaElement);
|
|
const FObjectEncodingParameters ObjectParameters = Object.GetEncoding(ObjectElement);
|
|
|
|
SetVectorFromObject(
|
|
OutActionVector,
|
|
Schema,
|
|
SchemaParameters.Element,
|
|
Object,
|
|
ObjectParameters.Element);
|
|
|
|
return;
|
|
}
|
|
|
|
default:
|
|
{
|
|
checkNoEntry();
|
|
return;
|
|
}
|
|
|
|
}
|
|
}
|
|
|
|
void GetObjectFromVector(
	FObject& OutObject,
	FObjectElement& OutObjectElement,
	const FSchema& Schema,
	const FSchemaElement SchemaElement,
	const TLearningArrayView<1, const float> ActionVector)
{
	// Decodes ActionVector, which must be exactly Schema.GetActionVectorSize(SchemaElement) long, into a newly
	// created element of OutObject carrying the schema element's tag. Compound types recurse into sub-slices.
	check(Schema.IsValid(SchemaElement));

	const EType Type = Schema.GetType(SchemaElement);
	const FName Tag = Schema.GetTag(SchemaElement);

	const int32 VectorNum = ActionVector.Num();
	check(VectorNum == Schema.GetActionVectorSize(SchemaElement));

	switch (Type)
	{
	case EType::Null:
	{
		OutObjectElement = OutObject.CreateNull(Tag);
		return;
	}

	case EType::Continuous:
	{
		const FSchemaContinuousParameters Params = Schema.GetContinuous(SchemaElement);
		check(VectorNum == Params.Num);

		// Multiply back by the schema scale, clamped away from zero.
		const float Scale = FMath::Max(Params.Scale, UE_SMALL_NUMBER);

		TLearningArray<1, float, TInlineAllocator<32>> Values;
		Values.SetNumUninitialized({ Params.Num });
		for (int32 Index = 0; Index < Params.Num; ++Index)
		{
			Values[Index] = Scale * ActionVector[Index];
		}

		OutObjectElement = OutObject.CreateContinuous({ MakeArrayView(Values.GetData(), Values.Num()) }, Tag);
		return;
	}

	case EType::DiscreteExclusive:
	{
		check(VectorNum == Schema.GetDiscreteExclusive(SchemaElement).Num);

		// The first entry equal to one is the chosen index; entries are validated up to and including it.
		int32 ChosenIndex = INDEX_NONE;
		for (int32 Index = 0; Index < VectorNum && ChosenIndex == INDEX_NONE; ++Index)
		{
			const float Value = ActionVector[Index];
			check(Value == 0.0f || Value == 1.0f);
			if (Value != 0.0f)
			{
				ChosenIndex = Index;
			}
		}
		check(ChosenIndex != INDEX_NONE);

		OutObjectElement = OutObject.CreateDiscreteExclusive({ ChosenIndex }, Tag);
		return;
	}

	case EType::DiscreteInclusive:
	{
		check(VectorNum == Schema.GetDiscreteInclusive(SchemaElement).Num);

		// Every entry equal to one is a chosen index.
		TArray<int32, TInlineAllocator<8>> ChosenIndices;
		ChosenIndices.Reserve(VectorNum);
		for (int32 Index = 0; Index < VectorNum; ++Index)
		{
			const float Value = ActionVector[Index];
			check(Value == 0.0f || Value == 1.0f);
			if (Value != 0.0f)
			{
				ChosenIndices.Add(Index);
			}
		}

		OutObjectElement = OutObject.CreateDiscreteInclusive({ ChosenIndices }, Tag);
		return;
	}

	case EType::NamedDiscreteExclusive:
	{
		const TArrayView<const FName> OptionNames = Schema.GetNamedDiscreteExclusive(SchemaElement).ElementNames;
		check(VectorNum == OptionNames.Num());

		// The name at the first entry equal to one is the chosen name.
		FName ChosenName = NAME_None;
		for (int32 Index = 0; Index < VectorNum && ChosenName == NAME_None; ++Index)
		{
			const float Value = ActionVector[Index];
			check(Value == 0.0f || Value == 1.0f);
			if (Value != 0.0f)
			{
				ChosenName = OptionNames[Index];
				break;
			}
		}
		check(ChosenName != NAME_None);

		OutObjectElement = OutObject.CreateNamedDiscreteExclusive({ ChosenName }, Tag);
		return;
	}

	case EType::NamedDiscreteInclusive:
	{
		const TArrayView<const FName> OptionNames = Schema.GetNamedDiscreteInclusive(SchemaElement).ElementNames;
		check(VectorNum == OptionNames.Num());

		// Every entry equal to one contributes its name.
		TArray<FName, TInlineAllocator<8>> ChosenNames;
		ChosenNames.Reserve(VectorNum);
		for (int32 Index = 0; Index < VectorNum; ++Index)
		{
			const float Value = ActionVector[Index];
			check(Value == 0.0f || Value == 1.0f);
			if (Value != 0.0f)
			{
				ChosenNames.Add(OptionNames[Index]);
			}
		}

		OutObjectElement = OutObject.CreateNamedDiscreteInclusive({ ChosenNames }, Tag);
		return;
	}

	case EType::And:
	{
		const FSchemaAndParameters Params = Schema.GetAnd(SchemaElement);
		const int32 ChildNum = Params.Elements.Num();

		// Children are stored back to back in schema order.
		TArray<FObjectElement, TInlineAllocator<8>> Children;
		Children.SetNumUninitialized(ChildNum);

		int32 Offset = 0;
		for (int32 ChildIdx = 0; ChildIdx < ChildNum; ++ChildIdx)
		{
			const FSchemaElement ChildSchema = Params.Elements[ChildIdx];
			const int32 ChildSize = Schema.GetActionVectorSize(ChildSchema);

			GetObjectFromVector(OutObject, Children[ChildIdx], Schema, ChildSchema, ActionVector.Slice(Offset, ChildSize));

			Offset += ChildSize;
		}
		check(Offset == VectorNum);

		OutObjectElement = OutObject.CreateAnd({ Params.ElementNames, Children }, Tag);
		return;
	}

	case EType::OrExclusive:
	{
		const FSchemaOrExclusiveParameters Params = Schema.GetOrExclusive(SchemaElement);

		// Layout: [ one shared slot sized for the largest option | one mask entry per option ].
		const int32 MaskOffset = Private::GetMaxActionVectorSize(Schema, Params.Elements);

		// The first mask entry equal to one selects the active option.
		int32 ActiveIdx = INDEX_NONE;
		for (int32 OptionIdx = 0; OptionIdx < Params.Elements.Num() && ActiveIdx == INDEX_NONE; ++OptionIdx)
		{
			const float MaskValue = ActionVector[MaskOffset + OptionIdx];
			check(MaskValue == 0.0f || MaskValue == 1.0f);
			if (MaskValue != 0.0f)
			{
				ActiveIdx = OptionIdx;
			}
		}
		check(ActiveIdx != INDEX_NONE);

		const FSchemaElement ActiveSchema = Params.Elements[ActiveIdx];

		FObjectElement Child;
		GetObjectFromVector(OutObject, Child, Schema, ActiveSchema, ActionVector.Slice(0, Schema.GetActionVectorSize(ActiveSchema)));

		OutObjectElement = OutObject.CreateOrExclusive({ Params.ElementNames[ActiveIdx], Child }, Tag);
		return;
	}

	case EType::OrInclusive:
	{
		const FSchemaOrInclusiveParameters Params = Schema.GetOrInclusive(SchemaElement);
		const int32 OptionNum = Params.Elements.Num();

		// Layout: [ every option's slot in schema order | one mask entry per option ].
		const int32 MaskOffset = Private::GetTotalActionVectorSize(Schema, Params.Elements);

		TArray<FName, TInlineAllocator<8>> PresentNames;
		TArray<FObjectElement, TInlineAllocator<8>> PresentChildren;
		PresentNames.Reserve(OptionNum);
		PresentChildren.Reserve(OptionNum);

		int32 Offset = 0;
		for (int32 OptionIdx = 0; OptionIdx < OptionNum; ++OptionIdx)
		{
			const FSchemaElement OptionSchema = Params.Elements[OptionIdx];
			const int32 OptionSize = Schema.GetActionVectorSize(OptionSchema);

			const float MaskValue = ActionVector[MaskOffset + OptionIdx];
			check(MaskValue == 0.0f || MaskValue == 1.0f);

			// Only options whose mask entry is set are decoded; the slot is skipped either way.
			if (MaskValue != 0.0f)
			{
				FObjectElement Child;
				GetObjectFromVector(OutObject, Child, Schema, OptionSchema, ActionVector.Slice(Offset, OptionSize));

				PresentNames.Add(Params.ElementNames[OptionIdx]);
				PresentChildren.Add(Child);
			}

			Offset += OptionSize;
		}
		check(Offset + OptionNum == VectorNum);

		OutObjectElement = OutObject.CreateOrInclusive({ PresentNames, PresentChildren }, Tag);
		return;
	}

	case EType::Array:
	{
		const FSchemaArrayParameters Params = Schema.GetArray(SchemaElement);

		TArray<FObjectElement, TInlineAllocator<8>> Items;
		Items.SetNumUninitialized(Params.Num);

		// Items share one schema element and are stored back to back.
		const int32 ItemSize = Schema.GetActionVectorSize(Params.Element);
		for (int32 ItemIdx = 0; ItemIdx < Params.Num; ++ItemIdx)
		{
			GetObjectFromVector(OutObject, Items[ItemIdx], Schema, Params.Element, ActionVector.Slice(ItemIdx * ItemSize, ItemSize));
		}

		OutObjectElement = OutObject.CreateArray({ Items }, Tag);
		return;
	}

	case EType::Encoding:
	{
		// An encoding wraps a single element that uses the whole vector.
		FObjectElement Child;
		GetObjectFromVector(OutObject, Child, Schema, Schema.GetEncoding(SchemaElement).Element, ActionVector);

		OutObjectElement = OutObject.CreateEncoding({ Child }, Tag);
		return;
	}

	default:
	{
		checkNoEntry();
		OutObjectElement = FObjectElement();
		return;
	}
	}
}
|
|
|
|
void SetVectorFromModifier(
	TLearningArrayView<1, float> OutActionModifierVector,
	const FSchema& Schema,
	const FSchemaElement SchemaElement,
	const FModifier& Modifier,
	const FModifierElement ModifierElement)
{
	// Encodes ModifierElement (interpreted through SchemaElement) into OutActionModifierVector, whose size must be
	// Schema.GetActionModifierVectorSize(SchemaElement). Entry 0 is a presence flag; type-specific data follows.
	check(Schema.IsValid(SchemaElement));
	check(Modifier.IsValid(ModifierElement));
	check(OutActionModifierVector.Num() == Schema.GetActionModifierVectorSize(SchemaElement));

	// A modifier is either Null or of exactly the schema's type.
	const EType SchemaType = Schema.GetType(SchemaElement);
	const EType ModifierType = Modifier.GetType(ModifierElement);
	check(ModifierType == EType::Null || ModifierType == SchemaType);

	// Start from all zeros; a Null modifier is encoded as exactly that.
	Array::Zero(OutActionModifierVector);
	if (ModifierType == EType::Null) { return; }

	OutActionModifierVector[0] = 1.0f;

	// OrExclusive and OrInclusive share one layout:
	// [ presence flag | one masked flag per option | each option's sub-modifier slot, in schema order ]
	// Options missing from the modifier keep their zeroed slot.
	const auto WriteOrModifier = [&](const auto& SchemaParams, const auto& ModifierParams)
	{
		const int32 OptionNum = SchemaParams.Elements.Num();
		check(OutActionModifierVector.Num() == 1 + OptionNum + Private::GetTotalActionModifierVectorSize(Schema, SchemaParams.Elements));

		for (int32 MaskedIdx = 0; MaskedIdx < ModifierParams.MaskedElements.Num(); ++MaskedIdx)
		{
			const int32 OptionIdx = SchemaParams.ElementNames.Find(ModifierParams.MaskedElements[MaskedIdx]);
			check(OptionIdx != INDEX_NONE);
			OutActionModifierVector[1 + OptionIdx] = 1.0f;
		}

		int32 Offset = 1 + OptionNum;
		for (int32 OptionIdx = 0; OptionIdx < OptionNum; ++OptionIdx)
		{
			const int32 OptionSize = Schema.GetActionModifierVectorSize(SchemaParams.Elements[OptionIdx]);
			const int32 ModifierIdx = ModifierParams.ElementNames.Find(SchemaParams.ElementNames[OptionIdx]);

			if (ModifierIdx != INDEX_NONE)
			{
				SetVectorFromModifier(
					OutActionModifierVector.Slice(Offset, OptionSize),
					Schema,
					SchemaParams.Elements[OptionIdx],
					Modifier,
					ModifierParams.Elements[ModifierIdx]);
			}

			Offset += OptionSize;
		}
		check(Offset == OutActionModifierVector.Num());
	};

	switch (SchemaType)
	{
	case EType::Null:
		// Unreachable: ModifierType is non-Null here and must equal SchemaType.
		checkNoEntry();
		break;

	case EType::Continuous:
	{
		const int32 ValueNum = Schema.GetContinuous(SchemaElement).Num;

		const TArrayView<const bool> Masked = Modifier.GetContinuous(ModifierElement).Masked;
		const TArrayView<const float> MaskedValues = Modifier.GetContinuous(ModifierElement).MaskedValues;
		check(Masked.Num() == ValueNum);
		check(MaskedValues.Num() == ValueNum);
		check(Schema.GetActionModifierVectorSize(SchemaElement) == 1 + Masked.Num() + MaskedValues.Num());

		// After the presence flag: [ masked flags (ValueNum) | masked values (ValueNum) ].
		const int32 FlagsOffset = 1;
		const int32 ValuesOffset = 1 + ValueNum;
		for (int32 Index = 0; Index < ValueNum; ++Index)
		{
			OutActionModifierVector[FlagsOffset + Index] = Masked[Index] ? 1.0f : 0.0f;
			OutActionModifierVector[ValuesOffset + Index] = MaskedValues[Index];
		}
		return;
	}

	case EType::DiscreteExclusive:
	{
		const int32 OptionNum = Schema.GetDiscreteExclusive(SchemaElement).Num;
		const TArrayView<const int32> MaskIndices = Modifier.GetDiscreteExclusive(ModifierElement).MaskedIndices;
		check(OptionNum >= MaskIndices.Num());

		for (const int32 MaskIndex : MaskIndices)
		{
			check(MaskIndex >= 0 && MaskIndex < OptionNum);
			OutActionModifierVector[1 + MaskIndex] = 1.0f;
		}
		return;
	}

	case EType::DiscreteInclusive:
	{
		const int32 OptionNum = Schema.GetDiscreteInclusive(SchemaElement).Num;
		const TArrayView<const int32> MaskIndices = Modifier.GetDiscreteInclusive(ModifierElement).MaskedIndices;
		check(OptionNum >= MaskIndices.Num());

		for (const int32 MaskIndex : MaskIndices)
		{
			check(MaskIndex >= 0 && MaskIndex < OptionNum);
			OutActionModifierVector[1 + MaskIndex] = 1.0f;
		}
		return;
	}

	case EType::NamedDiscreteExclusive:
	{
		const TArrayView<const FName> OptionNames = Schema.GetNamedDiscreteExclusive(SchemaElement).ElementNames;
		const TArrayView<const FName> MaskNames = Modifier.GetNamedDiscreteExclusive(ModifierElement).MaskedElementNames;
		check(OptionNames.Num() >= MaskNames.Num());

		for (const FName MaskName : MaskNames)
		{
			const int32 OptionIdx = OptionNames.Find(MaskName);
			check(OptionIdx != INDEX_NONE);
			OutActionModifierVector[1 + OptionIdx] = 1.0f;
		}
		return;
	}

	case EType::NamedDiscreteInclusive:
	{
		const TArrayView<const FName> OptionNames = Schema.GetNamedDiscreteInclusive(SchemaElement).ElementNames;
		const TArrayView<const FName> MaskNames = Modifier.GetNamedDiscreteInclusive(ModifierElement).MaskedElementNames;
		check(OptionNames.Num() >= MaskNames.Num());

		for (const FName MaskName : MaskNames)
		{
			const int32 OptionIdx = OptionNames.Find(MaskName);
			check(OptionIdx != INDEX_NONE);
			OutActionModifierVector[1 + OptionIdx] = 1.0f;
		}
		return;
	}

	case EType::And:
	{
		const FSchemaAndParameters SchemaParams = Schema.GetAnd(SchemaElement);
		const FModifierAndParameters ModifierParams = Modifier.GetAnd(ModifierElement);
		check(OutActionModifierVector.Num() == 1 + Private::GetTotalActionModifierVectorSize(Schema, SchemaParams.Elements));

		// Each schema child owns a fixed slot; children absent from the modifier keep their zeroed slot.
		int32 Offset = 1;
		for (int32 ChildIdx = 0; ChildIdx < SchemaParams.Elements.Num(); ++ChildIdx)
		{
			const FSchemaElement ChildSchema = SchemaParams.Elements[ChildIdx];
			const int32 ChildSize = Schema.GetActionModifierVectorSize(ChildSchema);
			const int32 ModifierIdx = ModifierParams.ElementNames.Find(SchemaParams.ElementNames[ChildIdx]);

			if (ModifierIdx != INDEX_NONE)
			{
				SetVectorFromModifier(
					OutActionModifierVector.Slice(Offset, ChildSize),
					Schema,
					ChildSchema,
					Modifier,
					ModifierParams.Elements[ModifierIdx]);
			}

			Offset += ChildSize;
		}
		check(Offset == OutActionModifierVector.Num());
		return;
	}

	case EType::OrExclusive:
	{
		WriteOrModifier(Schema.GetOrExclusive(SchemaElement), Modifier.GetOrExclusive(ModifierElement));
		return;
	}

	case EType::OrInclusive:
	{
		WriteOrModifier(Schema.GetOrInclusive(SchemaElement), Modifier.GetOrInclusive(ModifierElement));
		return;
	}

	case EType::Array:
	{
		const FSchemaArrayParameters SchemaParams = Schema.GetArray(SchemaElement);
		const FModifierArrayParameters ModifierParams = Modifier.GetArray(ModifierElement);
		check(SchemaParams.Num == ModifierParams.Elements.Num());

		// Items are stored back to back after the presence flag.
		const int32 ItemSize = Schema.GetActionModifierVectorSize(SchemaParams.Element);
		for (int32 ItemIdx = 0, Offset = 1; ItemIdx < SchemaParams.Num; ++ItemIdx, Offset += ItemSize)
		{
			SetVectorFromModifier(
				OutActionModifierVector.Slice(Offset, ItemSize),
				Schema,
				SchemaParams.Element,
				Modifier,
				ModifierParams.Elements[ItemIdx]);
		}
		return;
	}

	case EType::Encoding:
	{
		const FSchemaEncodingParameters SchemaParams = Schema.GetEncoding(SchemaElement);
		const FModifierEncodingParameters ModifierParams = Modifier.GetEncoding(ModifierElement);

		// The wrapped element's modifier follows the presence flag.
		SetVectorFromModifier(
			OutActionModifierVector.Slice(1, Schema.GetActionModifierVectorSize(SchemaParams.Element)),
			Schema,
			SchemaParams.Element,
			Modifier,
			ModifierParams.Element);
		return;
	}

	default:
	{
		checkNoEntry();
		return;
	}
	}
}
|
|
|
|
void GetModifierFromVector(
	FModifier& OutModifier,
	FModifierElement& OutModifierElement,
	const FSchema& Schema,
	const FSchemaElement SchemaElement,
	const TLearningArrayView<1, const float> ActionModifierVector)
{
	// Decodes ActionModifierVector, laid out the way SetVectorFromModifier writes it, into a newly created
	// element of OutModifier carrying the schema element's tag. Compound types recurse into sub-slices.
	// ActionModifierVector must be exactly Schema.GetActionModifierVectorSize(SchemaElement) long.

	check(Schema.IsValid(SchemaElement));

	// Get Type and Tag

	const EType SchemaElementType = Schema.GetType(SchemaElement);
	const FName SchemaElementTag = Schema.GetTag(SchemaElement);

	// Get Action Modifier Vector Size

	const int32 ActionModifierVectorSize = ActionModifierVector.Num();
	check(ActionModifierVectorSize == Schema.GetActionModifierVectorSize(SchemaElement));

	// Every flag in a modifier vector is written as exactly 0.0f or 1.0f. All flags are decoded through this
	// helper so that every branch (Continuous included) validates them in the same way.
	const auto ReadFlag = [&ActionModifierVector](const int32 Position) -> bool
	{
		const float Flag = ActionModifierVector[Position];
		check(Flag == 0.0f || Flag == 1.0f);
		return Flag == 1.0f;
	};

	// We always have at least one element in the ActionModifierVector, which says whether the element is
	// provided. If this first flag is zero, nothing below is masked and we just return the null element.

	check(ActionModifierVectorSize > 0);

	if (!ReadFlag(0))
	{
		OutModifierElement = OutModifier.CreateNull(SchemaElementTag);
		return;
	}

	// Logic for each specific element type

	switch (SchemaElementType)
	{

	case EType::Null:
	{
		OutModifierElement = OutModifier.CreateNull(SchemaElementTag);
		return;
	}

	case EType::Continuous:
	{
		const FSchemaContinuousParameters SchemaParameters = Schema.GetContinuous(SchemaElement);
		check(ActionModifierVectorSize == 1 + 2 * SchemaParameters.Num);

		const int32 ValueNum = SchemaParameters.Num;

		// After the presence flag: [ masked flags (ValueNum) | masked values (ValueNum) ]
		TLearningArray<1, bool, TInlineAllocator<32>> ActionMasked;
		TLearningArray<1, float, TInlineAllocator<32>> ActionMaskedValues;
		ActionMasked.SetNumUninitialized({ ValueNum });
		ActionMaskedValues.SetNumUninitialized({ ValueNum });

		for (int32 ValueIdx = 0; ValueIdx < ValueNum; ValueIdx++)
		{
			ActionMasked[ValueIdx] = ReadFlag(1 + ValueIdx);
			ActionMaskedValues[ValueIdx] = ActionModifierVector[1 + ValueNum + ValueIdx];
		}

		OutModifierElement = OutModifier.CreateContinuous({
			MakeArrayView(ActionMasked.GetData(), ActionMasked.Num()),
			MakeArrayView(ActionMaskedValues.GetData(), ActionMaskedValues.Num()) }, SchemaElementTag);

		return;
	}

	case EType::DiscreteExclusive:
	{
		const int32 ValueNum = Schema.GetDiscreteExclusive(SchemaElement).Num;
		check(ActionModifierVectorSize == 1 + ValueNum);

		// Find Indices
		TArray<int32, TInlineAllocator<8>> MaskedIndices;
		MaskedIndices.Reserve(ValueNum);
		for (int32 Idx = 0; Idx < ValueNum; Idx++)
		{
			if (ReadFlag(1 + Idx))
			{
				MaskedIndices.Add(Idx);
			}
		}

		OutModifierElement = OutModifier.CreateDiscreteExclusive({ MaskedIndices }, SchemaElementTag);
		return;
	}

	case EType::DiscreteInclusive:
	{
		const int32 ValueNum = Schema.GetDiscreteInclusive(SchemaElement).Num;
		check(ActionModifierVectorSize == 1 + ValueNum);

		// Find Indices
		TArray<int32, TInlineAllocator<8>> MaskedIndices;
		MaskedIndices.Reserve(ValueNum);
		for (int32 Idx = 0; Idx < ValueNum; Idx++)
		{
			if (ReadFlag(1 + Idx))
			{
				MaskedIndices.Add(Idx);
			}
		}

		OutModifierElement = OutModifier.CreateDiscreteInclusive({ MaskedIndices }, SchemaElementTag);
		return;
	}

	case EType::NamedDiscreteExclusive:
	{
		const TArrayView<const FName> ElementNames = Schema.GetNamedDiscreteExclusive(SchemaElement).ElementNames;
		check(ActionModifierVectorSize == 1 + ElementNames.Num());

		// Find Names
		TArray<FName, TInlineAllocator<8>> MaskedNames;
		MaskedNames.Reserve(ElementNames.Num());
		for (int32 Idx = 0; Idx < ElementNames.Num(); Idx++)
		{
			if (ReadFlag(1 + Idx))
			{
				MaskedNames.Add(ElementNames[Idx]);
			}
		}

		OutModifierElement = OutModifier.CreateNamedDiscreteExclusive({ MaskedNames }, SchemaElementTag);
		return;
	}

	case EType::NamedDiscreteInclusive:
	{
		const TArrayView<const FName> ElementNames = Schema.GetNamedDiscreteInclusive(SchemaElement).ElementNames;
		check(ActionModifierVectorSize == 1 + ElementNames.Num());

		// Find Names
		TArray<FName, TInlineAllocator<8>> MaskedNames;
		MaskedNames.Reserve(ElementNames.Num());
		for (int32 Idx = 0; Idx < ElementNames.Num(); Idx++)
		{
			if (ReadFlag(1 + Idx))
			{
				MaskedNames.Add(ElementNames[Idx]);
			}
		}

		OutModifierElement = OutModifier.CreateNamedDiscreteInclusive({ MaskedNames }, SchemaElementTag);
		return;
	}

	case EType::And:
	{
		const FSchemaAndParameters Parameters = Schema.GetAnd(SchemaElement);

		// Create Sub-elements: each schema child owns a fixed slot after the presence flag, in schema order.
		// Children whose own presence flag is zero decode to Null.

		TArray<FModifierElement, TInlineAllocator<8>> SubElements;
		SubElements.SetNumUninitialized(Parameters.Elements.Num());

		int32 SubElementOffset = 1;
		for (int32 SchemaElementIdx = 0; SchemaElementIdx < Parameters.Elements.Num(); SchemaElementIdx++)
		{
			const int32 SubElementSize = Schema.GetActionModifierVectorSize(Parameters.Elements[SchemaElementIdx]);

			GetModifierFromVector(
				OutModifier,
				SubElements[SchemaElementIdx],
				Schema,
				Parameters.Elements[SchemaElementIdx],
				ActionModifierVector.Slice(SubElementOffset, SubElementSize));

			SubElementOffset += SubElementSize;
		}
		check(SubElementOffset == ActionModifierVectorSize);

		OutModifierElement = OutModifier.CreateAnd({ Parameters.ElementNames, SubElements }, SchemaElementTag);
		return;
	}

	case EType::OrExclusive:
	{
		const FSchemaOrExclusiveParameters Parameters = Schema.GetOrExclusive(SchemaElement);
		const int32 SubElementNum = Parameters.Elements.Num();

		// Extract Mask Elements: one flag per option directly after the presence flag

		TArray<FName, TInlineAllocator<8>> MaskedElements;
		MaskedElements.Reserve(SubElementNum);

		for (int32 Idx = 0; Idx < SubElementNum; Idx++)
		{
			if (ReadFlag(1 + Idx))
			{
				MaskedElements.Add(Parameters.ElementNames[Idx]);
			}
		}

		// Create Sub-elements: each option's slot follows the mask flags, in schema order

		TArray<FModifierElement, TInlineAllocator<8>> SubElements;
		SubElements.SetNumUninitialized(SubElementNum);

		int32 SubElementOffset = 1 + SubElementNum;
		for (int32 SchemaElementIdx = 0; SchemaElementIdx < SubElementNum; SchemaElementIdx++)
		{
			const int32 SubElementSize = Schema.GetActionModifierVectorSize(Parameters.Elements[SchemaElementIdx]);

			GetModifierFromVector(
				OutModifier,
				SubElements[SchemaElementIdx],
				Schema,
				Parameters.Elements[SchemaElementIdx],
				ActionModifierVector.Slice(SubElementOffset, SubElementSize));

			SubElementOffset += SubElementSize;
		}
		check(SubElementOffset == ActionModifierVectorSize);

		OutModifierElement = OutModifier.CreateOrExclusive({ Parameters.ElementNames, SubElements, MaskedElements }, SchemaElementTag);
		return;
	}

	case EType::OrInclusive:
	{
		const FSchemaOrInclusiveParameters Parameters = Schema.GetOrInclusive(SchemaElement);
		const int32 SubElementNum = Parameters.Elements.Num();

		// Extract Mask Elements: one flag per option directly after the presence flag

		TArray<FName, TInlineAllocator<8>> MaskedElements;
		MaskedElements.Reserve(SubElementNum);

		for (int32 Idx = 0; Idx < SubElementNum; Idx++)
		{
			if (ReadFlag(1 + Idx))
			{
				MaskedElements.Add(Parameters.ElementNames[Idx]);
			}
		}

		// Create Sub-elements: each option's slot follows the mask flags, in schema order

		TArray<FModifierElement, TInlineAllocator<8>> SubElements;
		SubElements.SetNumUninitialized(SubElementNum);

		int32 SubElementOffset = 1 + SubElementNum;
		for (int32 SchemaElementIdx = 0; SchemaElementIdx < SubElementNum; SchemaElementIdx++)
		{
			const int32 SubElementSize = Schema.GetActionModifierVectorSize(Parameters.Elements[SchemaElementIdx]);

			GetModifierFromVector(
				OutModifier,
				SubElements[SchemaElementIdx],
				Schema,
				Parameters.Elements[SchemaElementIdx],
				ActionModifierVector.Slice(SubElementOffset, SubElementSize));

			SubElementOffset += SubElementSize;
		}
		check(SubElementOffset == ActionModifierVectorSize);

		OutModifierElement = OutModifier.CreateOrInclusive({ Parameters.ElementNames, SubElements, MaskedElements }, SchemaElementTag);
		return;
	}

	case EType::Array:
	{
		const FSchemaArrayParameters Parameters = Schema.GetArray(SchemaElement);

		TArray<FModifierElement, TInlineAllocator<8>> SubElements;
		SubElements.SetNumUninitialized(Parameters.Num);

		// Create sub-elements: items are stored back to back after the presence flag

		const int32 SubElementSize = Schema.GetActionModifierVectorSize(Parameters.Element);

		for (int32 ElementIdx = 0; ElementIdx < Parameters.Num; ElementIdx++)
		{
			GetModifierFromVector(
				OutModifier,
				SubElements[ElementIdx],
				Schema,
				Parameters.Element,
				ActionModifierVector.Slice(1 + ElementIdx * SubElementSize, SubElementSize));
		}

		OutModifierElement = OutModifier.CreateArray({ SubElements }, SchemaElementTag);
		return;
	}

	case EType::Encoding:
	{
		const FSchemaEncodingParameters Parameters = Schema.GetEncoding(SchemaElement);
		const int32 SubElementSize = Schema.GetActionModifierVectorSize(Parameters.Element);

		// The wrapped element's modifier follows the presence flag
		FModifierElement SubElement;
		GetModifierFromVector(
			OutModifier,
			SubElement,
			Schema,
			Parameters.Element,
			ActionModifierVector.Slice(1, SubElementSize));

		OutModifierElement = OutModifier.CreateEncoding({ SubElement }, SchemaElementTag);
		return;
	}

	default:
	{
		checkNoEntry();
		OutModifierElement = FModifierElement();
		return;
	}
	}
}
|
|
|
|
} |