2376 lines
78 KiB
C++
2376 lines
78 KiB
C++
// Copyright Epic Games, Inc. All Rights Reserved.
|
|
|
|
#include "LearningObservation.h"
|
|
#include "LearningRandom.h"
|
|
|
|
#include "NNERuntimeBasicCpuBuilder.h"
|
|
|
|
namespace UE::Learning::Observation
|
|
{
|
|
namespace Private
|
|
{
|
|
static inline bool ContainsDuplicates(const TArrayView<const FName> ElementNames)
|
|
{
|
|
TSet<FName, DefaultKeyFuncs<FName>, TInlineSetAllocator<32>> ElementNameSet;
|
|
ElementNameSet.Append(ElementNames);
|
|
return ElementNames.Num() != ElementNameSet.Num();
|
|
}
|
|
|
|
static inline bool CheckAllValid(const FSchema& Schema, const TArrayView<const FSchemaElement> Elements)
|
|
{
|
|
for (const FSchemaElement SubElement : Elements)
|
|
{
|
|
if (!Schema.IsValid(SubElement)) { return false; }
|
|
}
|
|
return true;
|
|
}
|
|
|
|
static inline int32 GetMaxObservationVectorSize(const FSchema& Schema, const TArrayView<const FSchemaElement> Elements)
|
|
{
|
|
int32 Size = 0;
|
|
for (const FSchemaElement SubElement : Elements)
|
|
{
|
|
Size = FMath::Max(Size, Schema.GetObservationVectorSize(SubElement));
|
|
}
|
|
return Size;
|
|
}
|
|
|
|
static inline int32 GetTotalObservationVectorSize(const FSchema& Schema, const TArrayView<const FSchemaElement> Elements)
|
|
{
|
|
int32 Size = 0;
|
|
for (const FSchemaElement SubElement : Elements)
|
|
{
|
|
Size += Schema.GetObservationVectorSize(SubElement);
|
|
}
|
|
return Size;
|
|
}
|
|
|
|
static inline int32 GetTotalEncodedObservationVectorSize(const FSchema& Schema, const TArrayView<const FSchemaElement> Elements)
|
|
{
|
|
int32 Size = 0;
|
|
for (const FSchemaElement SubElement : Elements)
|
|
{
|
|
Size += Schema.GetEncodedVectorSize(SubElement);
|
|
}
|
|
return Size;
|
|
}
|
|
|
|
static inline bool CheckAllValid(const FObject& Object, const TArrayView<const FObjectElement> Elements)
|
|
{
|
|
for (const FObjectElement SubElement : Elements)
|
|
{
|
|
if (!Object.IsValid(SubElement)) { return false; }
|
|
}
|
|
return true;
|
|
}
|
|
}
|
|
|
|
FSchemaElement FSchema::CreateNull(const FName Tag)
|
|
{
|
|
const int32 Index = Types.Add(EType::Null);
|
|
Tags.Add(Tag);
|
|
ObservationVectorSizes.Add(0);
|
|
EncodedVectorSizes.Add(0);
|
|
TypeDataIndices.Add(INDEX_NONE);
|
|
|
|
return { Index, Generation };
|
|
}
|
|
|
|
FSchemaElement FSchema::CreateContinuous(const FSchemaContinuousParameters Parameters, const FName Tag)
|
|
{
|
|
check(Parameters.Num >= 0);
|
|
check(Parameters.Scale >= 0.0f);
|
|
|
|
FContinuousData ElementData;
|
|
ElementData.Num = Parameters.Num;
|
|
ElementData.Scale = Parameters.Scale;
|
|
|
|
const int32 Index = Types.Add(EType::Continuous);
|
|
Tags.Add(Tag);
|
|
ObservationVectorSizes.Add(Parameters.Num);
|
|
EncodedVectorSizes.Add(Parameters.Num);
|
|
TypeDataIndices.Add(ContinuousData.Add(ElementData));
|
|
|
|
return { Index, Generation };
|
|
}
|
|
|
|
FSchemaElement FSchema::CreateDiscreteExclusive(const FSchemaDiscreteExclusiveParameters Parameters, const FName Tag)
|
|
{
|
|
FDiscreteExclusiveData ElementData;
|
|
ElementData.Num = Parameters.Num;
|
|
|
|
const int32 Index = Types.Add(EType::DiscreteExclusive);
|
|
Tags.Add(Tag);
|
|
ObservationVectorSizes.Add(Parameters.Num);
|
|
EncodedVectorSizes.Add(Parameters.Num);
|
|
TypeDataIndices.Add(DiscreteExclusiveData.Add(ElementData));
|
|
|
|
return { Index, Generation };
|
|
}
|
|
|
|
FSchemaElement FSchema::CreateDiscreteInclusive(const FSchemaDiscreteInclusiveParameters Parameters, const FName Tag)
|
|
{
|
|
FDiscreteInclusiveData ElementData;
|
|
ElementData.Num = Parameters.Num;
|
|
|
|
const int32 Index = Types.Add(EType::DiscreteInclusive);
|
|
Tags.Add(Tag);
|
|
ObservationVectorSizes.Add(Parameters.Num);
|
|
EncodedVectorSizes.Add(Parameters.Num);
|
|
TypeDataIndices.Add(DiscreteInclusiveData.Add(ElementData));
|
|
|
|
return { Index, Generation };
|
|
}
|
|
|
|
FSchemaElement FSchema::CreateNamedDiscreteExclusive(const FSchemaNamedDiscreteExclusiveParameters Parameters, const FName Tag)
|
|
{
|
|
check(!Private::ContainsDuplicates(Parameters.ElementNames));
|
|
|
|
FNamedDiscreteExclusiveData ElementData;
|
|
ElementData.Num = Parameters.ElementNames.Num();
|
|
ElementData.ElementsOffset = SubElementObjects.Num();
|
|
|
|
SubElementNames.Append(Parameters.ElementNames);
|
|
for (int32 Idx = 0; Idx < ElementData.Num; Idx++) { SubElementObjects.Add(FSchemaElement()); }
|
|
|
|
const int32 Index = Types.Add(EType::NamedDiscreteExclusive);
|
|
Tags.Add(Tag);
|
|
ObservationVectorSizes.Add(Parameters.ElementNames.Num());
|
|
EncodedVectorSizes.Add(Parameters.ElementNames.Num());
|
|
TypeDataIndices.Add(NamedDiscreteExclusiveData.Add(ElementData));
|
|
|
|
return { Index, Generation };
|
|
}
|
|
|
|
FSchemaElement FSchema::CreateNamedDiscreteInclusive(const FSchemaNamedDiscreteInclusiveParameters Parameters, const FName Tag)
|
|
{
|
|
check(!Private::ContainsDuplicates(Parameters.ElementNames));
|
|
|
|
FNamedDiscreteInclusiveData ElementData;
|
|
ElementData.Num = Parameters.ElementNames.Num();
|
|
ElementData.ElementsOffset = SubElementObjects.Num();
|
|
|
|
SubElementNames.Append(Parameters.ElementNames);
|
|
for (int32 Idx = 0; Idx < ElementData.Num; Idx++) { SubElementObjects.Add(FSchemaElement()); }
|
|
|
|
const int32 Index = Types.Add(EType::NamedDiscreteInclusive);
|
|
Tags.Add(Tag);
|
|
ObservationVectorSizes.Add(Parameters.ElementNames.Num());
|
|
EncodedVectorSizes.Add(Parameters.ElementNames.Num());
|
|
TypeDataIndices.Add(NamedDiscreteInclusiveData.Add(ElementData));
|
|
|
|
return { Index, Generation };
|
|
}
|
|
|
|
FSchemaElement FSchema::CreateAnd(const FSchemaAndParameters Parameters, const FName Tag)
|
|
{
|
|
check(Parameters.Elements.Num() == Parameters.ElementNames.Num());
|
|
check(!Private::ContainsDuplicates(Parameters.ElementNames));
|
|
check(Private::CheckAllValid(*this, Parameters.Elements));
|
|
|
|
FAndData ElementData;
|
|
ElementData.Num = Parameters.Elements.Num();
|
|
ElementData.ElementsOffset = SubElementObjects.Num();
|
|
|
|
SubElementNames.Append(Parameters.ElementNames);
|
|
SubElementObjects.Append(Parameters.Elements);
|
|
|
|
const int32 Index = Types.Add(EType::And);
|
|
Tags.Add(Tag);
|
|
ObservationVectorSizes.Add(Private::GetTotalObservationVectorSize(*this, Parameters.Elements));
|
|
EncodedVectorSizes.Add(Private::GetTotalEncodedObservationVectorSize(*this, Parameters.Elements));
|
|
TypeDataIndices.Add(AndData.Add(ElementData));
|
|
|
|
return { Index, Generation };
|
|
}
|
|
|
|
FSchemaElement FSchema::CreateOrExclusive(const FSchemaOrExclusiveParameters Parameters, const FName Tag)
|
|
{
|
|
check(Parameters.Elements.Num() == Parameters.ElementNames.Num());
|
|
check(!Private::ContainsDuplicates(Parameters.ElementNames));
|
|
check(Private::CheckAllValid(*this, Parameters.Elements));
|
|
|
|
FOrExclusiveData ElementData;
|
|
ElementData.Num = Parameters.Elements.Num();
|
|
ElementData.ElementsOffset = SubElementObjects.Num();
|
|
ElementData.EncodingSize = Parameters.EncodingSize;
|
|
|
|
SubElementNames.Append(Parameters.ElementNames);
|
|
SubElementObjects.Append(Parameters.Elements);
|
|
|
|
const int32 Index = Types.Add(EType::OrExclusive);
|
|
Tags.Add(Tag);
|
|
ObservationVectorSizes.Add(Private::GetMaxObservationVectorSize(*this, Parameters.Elements) + Parameters.Elements.Num());
|
|
EncodedVectorSizes.Add(Parameters.EncodingSize + Parameters.Elements.Num());
|
|
TypeDataIndices.Add(OrExclusiveData.Add(ElementData));
|
|
|
|
return { Index, Generation };
|
|
}
|
|
|
|
FSchemaElement FSchema::CreateOrInclusive(const FSchemaOrInclusiveParameters Parameters, const FName Tag)
|
|
{
|
|
check(Parameters.Elements.Num() == Parameters.ElementNames.Num());
|
|
check(!Private::ContainsDuplicates(Parameters.ElementNames));
|
|
check(Private::CheckAllValid(*this, Parameters.Elements));
|
|
|
|
FOrInclusiveData ElementData;
|
|
ElementData.Num = Parameters.Elements.Num();
|
|
ElementData.ElementsOffset = SubElementObjects.Num();
|
|
ElementData.AttentionEncodingSize = Parameters.AttentionEncodingSize;
|
|
ElementData.AttentionHeadNum = Parameters.AttentionHeadNum;
|
|
ElementData.ValueEncodingSize = Parameters.ValueEncodingSize;
|
|
|
|
SubElementNames.Append(Parameters.ElementNames);
|
|
SubElementObjects.Append(Parameters.Elements);
|
|
|
|
const int32 Index = Types.Add(EType::OrInclusive);
|
|
Tags.Add(Tag);
|
|
ObservationVectorSizes.Add(Private::GetTotalObservationVectorSize(*this, Parameters.Elements) + Parameters.Elements.Num());
|
|
EncodedVectorSizes.Add(Parameters.AttentionHeadNum * Parameters.ValueEncodingSize + Parameters.Elements.Num());
|
|
TypeDataIndices.Add(OrInclusiveData.Add(ElementData));
|
|
|
|
return { Index, Generation };
|
|
}
|
|
|
|
FSchemaElement FSchema::CreateArray(const FSchemaArrayParameters Parameters, const FName Tag)
|
|
{
|
|
check(Parameters.Num >= 0);
|
|
check(IsValid(Parameters.Element));
|
|
|
|
FArrayData ElementData;
|
|
ElementData.Num = Parameters.Num;
|
|
ElementData.ElementIndex = SubElementObjects.Num();
|
|
|
|
SubElementNames.Add(NAME_None);
|
|
SubElementObjects.Add(Parameters.Element);
|
|
|
|
const int32 Index = Types.Add(EType::Array);
|
|
Tags.Add(Tag);
|
|
ObservationVectorSizes.Add(GetObservationVectorSize(Parameters.Element) * Parameters.Num);
|
|
EncodedVectorSizes.Add(GetEncodedVectorSize(Parameters.Element) * Parameters.Num);
|
|
TypeDataIndices.Add(ArrayData.Add(ElementData));
|
|
|
|
return { Index, Generation };
|
|
}
|
|
|
|
FSchemaElement FSchema::CreateSet(const FSchemaSetParameters Parameters, const FName Tag)
|
|
{
|
|
check(IsValid(Parameters.Element));
|
|
|
|
FSetData ElementData;
|
|
ElementData.MaxNum = Parameters.MaxNum;
|
|
ElementData.ElementIndex = SubElementObjects.Num();
|
|
ElementData.AttentionEncodingSize = Parameters.AttentionEncodingSize;
|
|
ElementData.AttentionHeadNum = Parameters.AttentionHeadNum;
|
|
ElementData.ValueEncodingSize = Parameters.ValueEncodingSize;
|
|
|
|
SubElementNames.Add(NAME_None);
|
|
SubElementObjects.Add(Parameters.Element);
|
|
|
|
const int32 Index = Types.Add(EType::Set);
|
|
Tags.Add(Tag);
|
|
ObservationVectorSizes.Add(GetObservationVectorSize(Parameters.Element) * Parameters.MaxNum + Parameters.MaxNum);
|
|
EncodedVectorSizes.Add(Parameters.ValueEncodingSize * Parameters.AttentionHeadNum + 1);
|
|
TypeDataIndices.Add(SetData.Add(ElementData));
|
|
|
|
return { Index, Generation };
|
|
}
|
|
|
|
FSchemaElement FSchema::CreateEncoding(const FSchemaEncodingParameters Parameters, const FName Tag)
|
|
{
|
|
check(IsValid(Parameters.Element));
|
|
|
|
FEncodingData ElementData;
|
|
ElementData.ElementIndex = SubElementObjects.Num();
|
|
ElementData.EncodingSize = Parameters.EncodingSize;
|
|
ElementData.LayerNum = Parameters.LayerNum;
|
|
ElementData.ActivationFunction = Parameters.ActivationFunction;
|
|
|
|
SubElementNames.Add(NAME_None);
|
|
SubElementObjects.Add(Parameters.Element);
|
|
|
|
const int32 Index = Types.Add(EType::Encoding);
|
|
Tags.Add(Tag);
|
|
ObservationVectorSizes.Add(GetObservationVectorSize(Parameters.Element));
|
|
EncodedVectorSizes.Add(Parameters.EncodingSize);
|
|
TypeDataIndices.Add(EncodingData.Add(ElementData));
|
|
|
|
return { Index, Generation };
|
|
}
|
|
|
|
bool FSchema::IsValid(const FSchemaElement Element) const
|
|
{
|
|
return Element.Generation == Generation && Element.Index != INDEX_NONE;
|
|
}
|
|
|
|
EType FSchema::GetType(const FSchemaElement Element) const
|
|
{
|
|
check(IsValid(Element));
|
|
return Types[Element.Index];
|
|
}
|
|
|
|
FName FSchema::GetTag(const FSchemaElement Element) const
|
|
{
|
|
check(IsValid(Element));
|
|
return Tags[Element.Index];
|
|
}
|
|
|
|
int32 FSchema::GetObservationVectorSize(const FSchemaElement Element) const
|
|
{
|
|
check(IsValid(Element));
|
|
return ObservationVectorSizes[Element.Index];
|
|
}
|
|
|
|
int32 FSchema::GetEncodedVectorSize(const FSchemaElement Element) const
|
|
{
|
|
check(IsValid(Element));
|
|
return EncodedVectorSizes[Element.Index];
|
|
}
|
|
|
|
FSchemaContinuousParameters FSchema::GetContinuous(const FSchemaElement Element) const
|
|
{
|
|
check(IsValid(Element) && GetType(Element) == EType::Continuous);
|
|
const FContinuousData& ElementData = ContinuousData[TypeDataIndices[Element.Index]];
|
|
|
|
FSchemaContinuousParameters Parameters;
|
|
Parameters.Num = ElementData.Num;
|
|
Parameters.Scale = ElementData.Scale;
|
|
return Parameters;
|
|
}
|
|
|
|
FSchemaDiscreteExclusiveParameters FSchema::GetDiscreteExclusive(const FSchemaElement Element) const
|
|
{
|
|
check(IsValid(Element) && GetType(Element) == EType::DiscreteExclusive);
|
|
const FDiscreteExclusiveData& ElementData = DiscreteExclusiveData[TypeDataIndices[Element.Index]];
|
|
|
|
FSchemaDiscreteExclusiveParameters Parameters;
|
|
Parameters.Num = ElementData.Num;
|
|
return Parameters;
|
|
}
|
|
|
|
FSchemaDiscreteInclusiveParameters FSchema::GetDiscreteInclusive(const FSchemaElement Element) const
|
|
{
|
|
check(IsValid(Element) && GetType(Element) == EType::DiscreteInclusive);
|
|
const FDiscreteInclusiveData& ElementData = DiscreteInclusiveData[TypeDataIndices[Element.Index]];
|
|
|
|
FSchemaDiscreteInclusiveParameters Parameters;
|
|
Parameters.Num = ElementData.Num;
|
|
return Parameters;
|
|
}
|
|
|
|
FSchemaNamedDiscreteExclusiveParameters FSchema::GetNamedDiscreteExclusive(const FSchemaElement Element) const
|
|
{
|
|
check(IsValid(Element) && GetType(Element) == EType::NamedDiscreteExclusive);
|
|
const FNamedDiscreteExclusiveData& ElementData = NamedDiscreteExclusiveData[TypeDataIndices[Element.Index]];
|
|
|
|
FSchemaNamedDiscreteExclusiveParameters Parameters;
|
|
Parameters.ElementNames = TArrayView<const FName>(SubElementNames.GetData() + ElementData.ElementsOffset, ElementData.Num);
|
|
return Parameters;
|
|
}
|
|
|
|
FSchemaNamedDiscreteInclusiveParameters FSchema::GetNamedDiscreteInclusive(const FSchemaElement Element) const
|
|
{
|
|
check(IsValid(Element) && GetType(Element) == EType::NamedDiscreteInclusive);
|
|
const FNamedDiscreteInclusiveData& ElementData = NamedDiscreteInclusiveData[TypeDataIndices[Element.Index]];
|
|
|
|
FSchemaNamedDiscreteInclusiveParameters Parameters;
|
|
Parameters.ElementNames = TArrayView<const FName>(SubElementNames.GetData() + ElementData.ElementsOffset, ElementData.Num);
|
|
return Parameters;
|
|
}
|
|
|
|
FSchemaAndParameters FSchema::GetAnd(const FSchemaElement Element) const
|
|
{
|
|
check(IsValid(Element) && GetType(Element) == EType::And);
|
|
const FAndData& ElementData = AndData[TypeDataIndices[Element.Index]];
|
|
|
|
FSchemaAndParameters Parameters;
|
|
Parameters.ElementNames = TArrayView<const FName>(SubElementNames.GetData() + ElementData.ElementsOffset, ElementData.Num);
|
|
Parameters.Elements = TArrayView<const FSchemaElement>(SubElementObjects.GetData() + ElementData.ElementsOffset, ElementData.Num);
|
|
return Parameters;
|
|
}
|
|
|
|
FSchemaOrExclusiveParameters FSchema::GetOrExclusive(const FSchemaElement Element) const
|
|
{
|
|
check(IsValid(Element) && GetType(Element) == EType::OrExclusive);
|
|
const FOrExclusiveData& ElementData = OrExclusiveData[TypeDataIndices[Element.Index]];
|
|
|
|
FSchemaOrExclusiveParameters Parameters;
|
|
Parameters.ElementNames = TArrayView<const FName>(SubElementNames.GetData() + ElementData.ElementsOffset, ElementData.Num);
|
|
Parameters.Elements = TArrayView<const FSchemaElement>(SubElementObjects.GetData() + ElementData.ElementsOffset, ElementData.Num);
|
|
Parameters.EncodingSize = ElementData.EncodingSize;
|
|
return Parameters;
|
|
}
|
|
|
|
FSchemaOrInclusiveParameters FSchema::GetOrInclusive(const FSchemaElement Element) const
|
|
{
|
|
check(IsValid(Element) && GetType(Element) == EType::OrInclusive);
|
|
const FOrInclusiveData& ElementData = OrInclusiveData[TypeDataIndices[Element.Index]];
|
|
|
|
FSchemaOrInclusiveParameters Parameters;
|
|
Parameters.ElementNames = TArrayView<const FName>(SubElementNames.GetData() + ElementData.ElementsOffset, ElementData.Num);
|
|
Parameters.Elements = TArrayView<const FSchemaElement>(SubElementObjects.GetData() + ElementData.ElementsOffset, ElementData.Num);
|
|
Parameters.AttentionEncodingSize = ElementData.AttentionEncodingSize;
|
|
Parameters.AttentionHeadNum = ElementData.AttentionHeadNum;
|
|
Parameters.ValueEncodingSize = ElementData.ValueEncodingSize;
|
|
return Parameters;
|
|
}
|
|
|
|
FSchemaArrayParameters FSchema::GetArray(const FSchemaElement Element) const
|
|
{
|
|
check(IsValid(Element) && GetType(Element) == EType::Array);
|
|
const FArrayData& ElementData = ArrayData[TypeDataIndices[Element.Index]];
|
|
|
|
FSchemaArrayParameters Parameters;
|
|
Parameters.Num = ElementData.Num;
|
|
Parameters.Element = SubElementObjects[ElementData.ElementIndex];
|
|
return Parameters;
|
|
}
|
|
|
|
FSchemaSetParameters FSchema::GetSet(const FSchemaElement Element) const
|
|
{
|
|
check(IsValid(Element) && GetType(Element) == EType::Set);
|
|
const FSetData& ElementData = SetData[TypeDataIndices[Element.Index]];
|
|
|
|
FSchemaSetParameters Parameters;
|
|
Parameters.MaxNum = ElementData.MaxNum;
|
|
Parameters.Element = SubElementObjects[ElementData.ElementIndex];
|
|
Parameters.AttentionEncodingSize = ElementData.AttentionEncodingSize;
|
|
Parameters.AttentionHeadNum = ElementData.AttentionHeadNum;
|
|
Parameters.ValueEncodingSize = ElementData.ValueEncodingSize;
|
|
return Parameters;
|
|
}
|
|
|
|
FSchemaEncodingParameters FSchema::GetEncoding(const FSchemaElement Element) const
|
|
{
|
|
check(IsValid(Element) && GetType(Element) == EType::Encoding);
|
|
const FEncodingData& ElementData = EncodingData[TypeDataIndices[Element.Index]];
|
|
|
|
FSchemaEncodingParameters Parameters;
|
|
Parameters.Element = SubElementObjects[ElementData.ElementIndex];
|
|
Parameters.EncodingSize = ElementData.EncodingSize;
|
|
Parameters.LayerNum = ElementData.LayerNum;
|
|
Parameters.ActivationFunction = ElementData.ActivationFunction;
|
|
return Parameters;
|
|
}
|
|
|
|
uint32 FSchema::GetGeneration() const
|
|
{
|
|
return Generation;
|
|
}
|
|
|
|
void FSchema::Empty()
|
|
{
|
|
Types.Empty();
|
|
Tags.Empty();
|
|
ObservationVectorSizes.Empty();
|
|
EncodedVectorSizes.Empty();
|
|
TypeDataIndices.Empty();
|
|
|
|
ContinuousData.Empty();
|
|
DiscreteExclusiveData.Empty();
|
|
DiscreteInclusiveData.Empty();
|
|
NamedDiscreteExclusiveData.Empty();
|
|
NamedDiscreteInclusiveData.Empty();
|
|
AndData.Empty();
|
|
OrExclusiveData.Empty();
|
|
OrInclusiveData.Empty();
|
|
ArrayData.Empty();
|
|
SetData.Empty();
|
|
EncodingData.Empty();
|
|
|
|
SubElementNames.Empty();
|
|
SubElementObjects.Empty();
|
|
|
|
Generation++;
|
|
}
|
|
|
|
bool FSchema::IsEmpty() const
|
|
{
|
|
return Types.IsEmpty();
|
|
}
|
|
|
|
void FSchema::Reset()
|
|
{
|
|
Types.Reset();
|
|
Tags.Reset();
|
|
ObservationVectorSizes.Reset();
|
|
EncodedVectorSizes.Reset();
|
|
TypeDataIndices.Reset();
|
|
|
|
ContinuousData.Reset();
|
|
DiscreteExclusiveData.Reset();
|
|
DiscreteInclusiveData.Reset();
|
|
NamedDiscreteExclusiveData.Reset();
|
|
NamedDiscreteInclusiveData.Reset();
|
|
AndData.Reset();
|
|
OrExclusiveData.Reset();
|
|
OrInclusiveData.Reset();
|
|
ArrayData.Reset();
|
|
SetData.Reset();
|
|
EncodingData.Reset();
|
|
|
|
SubElementNames.Reset();
|
|
SubElementObjects.Reset();
|
|
|
|
Generation++;
|
|
}
|
|
|
|
FObjectElement FObject::CreateNull(const FName Tag)
|
|
{
|
|
const int32 Index = Types.Add(EType::Null);
|
|
Tags.Add(Tag);
|
|
|
|
ContinuousDataOffsets.Add(ContinuousValues.Num());
|
|
ContinuousDataNums.Add(0);
|
|
|
|
DiscreteDataOffsets.Add(DiscreteValues.Num());
|
|
DiscreteDataNums.Add(0);
|
|
|
|
SubElementDataOffsets.Add(SubElementObjects.Num());
|
|
SubElementDataNums.Add(0);
|
|
|
|
return { Index, Generation };
|
|
}
|
|
|
|
FObjectElement FObject::CreateContinuous(const FObjectContinuousParameters Parameters, const FName Tag)
|
|
{
|
|
const int32 Index = Types.Add(EType::Continuous);
|
|
Tags.Add(Tag);
|
|
|
|
ContinuousDataOffsets.Add(ContinuousValues.Num());
|
|
ContinuousDataNums.Add(Parameters.Values.Num());
|
|
|
|
DiscreteDataOffsets.Add(DiscreteValues.Num());
|
|
DiscreteDataNums.Add(0);
|
|
|
|
SubElementDataOffsets.Add(SubElementObjects.Num());
|
|
SubElementDataNums.Add(0);
|
|
|
|
ContinuousValues.Append(Parameters.Values);
|
|
|
|
return { Index, Generation };
|
|
}
|
|
|
|
FObjectElement FObject::CreateDiscreteExclusive(const FObjectDiscreteExclusiveParameters Parameters, const FName Tag)
|
|
{
|
|
const int32 Index = Types.Add(EType::DiscreteExclusive);
|
|
Tags.Add(Tag);
|
|
|
|
ContinuousDataOffsets.Add(ContinuousValues.Num());
|
|
ContinuousDataNums.Add(0);
|
|
|
|
DiscreteDataOffsets.Add(DiscreteValues.Num());
|
|
DiscreteDataNums.Add(1);
|
|
|
|
SubElementDataOffsets.Add(SubElementObjects.Num());
|
|
SubElementDataNums.Add(0);
|
|
|
|
DiscreteValues.Add(Parameters.DiscreteIndex);
|
|
|
|
return { Index, Generation };
|
|
}
|
|
|
|
FObjectElement FObject::CreateDiscreteInclusive(const FObjectDiscreteInclusiveParameters Parameters, const FName Tag)
|
|
{
|
|
const int32 Index = Types.Add(EType::DiscreteInclusive);
|
|
Tags.Add(Tag);
|
|
|
|
ContinuousDataOffsets.Add(ContinuousValues.Num());
|
|
ContinuousDataNums.Add(0);
|
|
|
|
DiscreteDataOffsets.Add(DiscreteValues.Num());
|
|
DiscreteDataNums.Add(Parameters.DiscreteIndices.Num());
|
|
|
|
SubElementDataOffsets.Add(SubElementObjects.Num());
|
|
SubElementDataNums.Add(0);
|
|
|
|
DiscreteValues.Append(Parameters.DiscreteIndices);
|
|
|
|
return { Index, Generation };
|
|
}
|
|
|
|
FObjectElement FObject::CreateNamedDiscreteExclusive(const FObjectNamedDiscreteExclusiveParameters Parameters, const FName Tag)
|
|
{
|
|
const int32 Index = Types.Add(EType::NamedDiscreteExclusive);
|
|
Tags.Add(Tag);
|
|
|
|
ContinuousDataOffsets.Add(ContinuousValues.Num());
|
|
ContinuousDataNums.Add(0);
|
|
|
|
DiscreteDataOffsets.Add(DiscreteValues.Num());
|
|
DiscreteDataNums.Add(0);
|
|
|
|
SubElementDataOffsets.Add(SubElementObjects.Num());
|
|
SubElementDataNums.Add(1);
|
|
|
|
SubElementNames.Add(Parameters.ElementName);
|
|
SubElementObjects.Add(FObjectElement());
|
|
|
|
return { Index, Generation };
|
|
}
|
|
|
|
FObjectElement FObject::CreateNamedDiscreteInclusive(const FObjectNamedDiscreteInclusiveParameters Parameters, const FName Tag)
|
|
{
|
|
check(!Private::ContainsDuplicates(Parameters.ElementNames));
|
|
|
|
const int32 Index = Types.Add(EType::NamedDiscreteInclusive);
|
|
Tags.Add(Tag);
|
|
|
|
ContinuousDataOffsets.Add(ContinuousValues.Num());
|
|
ContinuousDataNums.Add(0);
|
|
|
|
DiscreteDataOffsets.Add(DiscreteValues.Num());
|
|
DiscreteDataNums.Add(0);
|
|
|
|
SubElementDataOffsets.Add(SubElementObjects.Num());
|
|
SubElementDataNums.Add(Parameters.ElementNames.Num());
|
|
|
|
SubElementNames.Append(Parameters.ElementNames);
|
|
for (int32 Idx = 0; Idx < Parameters.ElementNames.Num(); Idx++) { SubElementObjects.Add(FObjectElement()); }
|
|
|
|
return { Index, Generation };
|
|
}
|
|
|
|
FObjectElement FObject::CreateAnd(const FObjectAndParameters Parameters, const FName Tag)
|
|
{
|
|
check(Parameters.Elements.Num() == Parameters.ElementNames.Num());
|
|
check(!Private::ContainsDuplicates(Parameters.ElementNames));
|
|
check(Private::CheckAllValid(*this, Parameters.Elements));
|
|
|
|
const int32 Index = Types.Add(EType::And);
|
|
Tags.Add(Tag);
|
|
|
|
ContinuousDataOffsets.Add(ContinuousValues.Num());
|
|
ContinuousDataNums.Add(0);
|
|
|
|
DiscreteDataOffsets.Add(DiscreteValues.Num());
|
|
DiscreteDataNums.Add(0);
|
|
|
|
SubElementDataOffsets.Add(SubElementObjects.Num());
|
|
SubElementDataNums.Add(Parameters.Elements.Num());
|
|
|
|
SubElementNames.Append(Parameters.ElementNames);
|
|
SubElementObjects.Append(Parameters.Elements);
|
|
|
|
return { Index, Generation };
|
|
}
|
|
|
|
FObjectElement FObject::CreateOrExclusive(const FObjectOrExclusiveParameters Parameters, const FName Tag)
|
|
{
|
|
check(IsValid(Parameters.Element));
|
|
|
|
const int32 Index = Types.Add(EType::OrExclusive);
|
|
Tags.Add(Tag);
|
|
|
|
ContinuousDataOffsets.Add(ContinuousValues.Num());
|
|
ContinuousDataNums.Add(0);
|
|
|
|
DiscreteDataOffsets.Add(DiscreteValues.Num());
|
|
DiscreteDataNums.Add(0);
|
|
|
|
SubElementDataOffsets.Add(SubElementObjects.Num());
|
|
SubElementDataNums.Add(1);
|
|
|
|
SubElementNames.Add(Parameters.ElementName);
|
|
SubElementObjects.Add(Parameters.Element);
|
|
|
|
return { Index, Generation };
|
|
}
|
|
|
|
FObjectElement FObject::CreateOrInclusive(const FObjectOrInclusiveParameters Parameters, const FName Tag)
|
|
{
|
|
check(Parameters.Elements.Num() == Parameters.ElementNames.Num());
|
|
check(!Private::ContainsDuplicates(Parameters.ElementNames));
|
|
check(Private::CheckAllValid(*this, Parameters.Elements));
|
|
|
|
const int32 Index = Types.Add(EType::OrInclusive);
|
|
Tags.Add(Tag);
|
|
|
|
ContinuousDataOffsets.Add(ContinuousValues.Num());
|
|
ContinuousDataNums.Add(0);
|
|
|
|
DiscreteDataOffsets.Add(DiscreteValues.Num());
|
|
DiscreteDataNums.Add(0);
|
|
|
|
SubElementDataOffsets.Add(SubElementObjects.Num());
|
|
SubElementDataNums.Add(Parameters.Elements.Num());
|
|
|
|
SubElementNames.Append(Parameters.ElementNames);
|
|
SubElementObjects.Append(Parameters.Elements);
|
|
|
|
return { Index, Generation };
|
|
}
|
|
|
|
FObjectElement FObject::CreateArray(const FObjectArrayParameters Parameters, const FName Tag)
|
|
{
|
|
check(Private::CheckAllValid(*this, Parameters.Elements));
|
|
|
|
const int32 Index = Types.Add(EType::Array);
|
|
Tags.Add(Tag);
|
|
|
|
ContinuousDataOffsets.Add(ContinuousValues.Num());
|
|
ContinuousDataNums.Add(0);
|
|
|
|
DiscreteDataOffsets.Add(DiscreteValues.Num());
|
|
DiscreteDataNums.Add(0);
|
|
|
|
SubElementDataOffsets.Add(SubElementObjects.Num());
|
|
SubElementDataNums.Add(Parameters.Elements.Num());
|
|
|
|
for (int32 ElementIdx = 0; ElementIdx < Parameters.Elements.Num(); ElementIdx++)
|
|
{
|
|
SubElementNames.Add(NAME_Name);
|
|
}
|
|
SubElementObjects.Append(Parameters.Elements);
|
|
|
|
return { Index, Generation };
|
|
}
|
|
|
|
FObjectElement FObject::CreateSet(const FObjectSetParameters Parameters, const FName Tag)
|
|
{
|
|
check(Private::CheckAllValid(*this, Parameters.Elements));
|
|
|
|
const int32 Index = Types.Add(EType::Set);
|
|
Tags.Add(Tag);
|
|
|
|
ContinuousDataOffsets.Add(ContinuousValues.Num());
|
|
ContinuousDataNums.Add(0);
|
|
|
|
DiscreteDataOffsets.Add(DiscreteValues.Num());
|
|
DiscreteDataNums.Add(0);
|
|
|
|
SubElementDataOffsets.Add(SubElementObjects.Num());
|
|
SubElementDataNums.Add(Parameters.Elements.Num());
|
|
|
|
for (int32 ElementIdx = 0; ElementIdx < Parameters.Elements.Num(); ElementIdx++)
|
|
{
|
|
SubElementNames.Add(NAME_Name);
|
|
}
|
|
SubElementObjects.Append(Parameters.Elements);
|
|
|
|
return { Index, Generation };
|
|
}
|
|
|
|
FObjectElement FObject::CreateEncoding(const FObjectEncodingParameters Parameters, const FName Tag)
|
|
{
|
|
check(IsValid(Parameters.Element));
|
|
|
|
const int32 Index = Types.Add(EType::Encoding);
|
|
Tags.Add(Tag);
|
|
|
|
ContinuousDataOffsets.Add(ContinuousValues.Num());
|
|
ContinuousDataNums.Add(0);
|
|
|
|
DiscreteDataOffsets.Add(DiscreteValues.Num());
|
|
DiscreteDataNums.Add(0);
|
|
|
|
SubElementDataOffsets.Add(SubElementObjects.Num());
|
|
SubElementDataNums.Add(1);
|
|
|
|
SubElementNames.Add(NAME_Name);
|
|
SubElementObjects.Add(Parameters.Element);
|
|
|
|
return { Index, Generation };
|
|
}
|
|
|
|
bool FObject::IsValid(const FObjectElement Element) const
|
|
{
|
|
return Element.Generation == Generation && Element.Index != INDEX_NONE;
|
|
}
|
|
|
|
EType FObject::GetType(const FObjectElement Element) const
|
|
{
|
|
check(IsValid(Element));
|
|
return Types[Element.Index];
|
|
}
|
|
|
|
FName FObject::GetTag(const FObjectElement Element) const
|
|
{
|
|
check(IsValid(Element));
|
|
return Tags[Element.Index];
|
|
}
|
|
|
|
FObjectContinuousParameters FObject::GetContinuous(const FObjectElement Element) const
|
|
{
|
|
check(IsValid(Element) && GetType(Element) == EType::Continuous);
|
|
|
|
FObjectContinuousParameters Parameters;
|
|
Parameters.Values = TArrayView<const float>(ContinuousValues.GetData() + ContinuousDataOffsets[Element.Index], ContinuousDataNums[Element.Index]);
|
|
return Parameters;
|
|
}
|
|
|
|
FObjectDiscreteExclusiveParameters FObject::GetDiscreteExclusive(const FObjectElement Element) const
|
|
{
|
|
check(IsValid(Element) && GetType(Element) == EType::DiscreteExclusive);
|
|
|
|
FObjectDiscreteExclusiveParameters Parameters;
|
|
Parameters.DiscreteIndex = DiscreteValues[DiscreteDataOffsets[Element.Index]];
|
|
return Parameters;
|
|
}
|
|
|
|
FObjectDiscreteInclusiveParameters FObject::GetDiscreteInclusive(const FObjectElement Element) const
|
|
{
|
|
check(IsValid(Element) && GetType(Element) == EType::DiscreteInclusive);
|
|
|
|
FObjectDiscreteInclusiveParameters Parameters;
|
|
Parameters.DiscreteIndices = TArrayView<const int32>(DiscreteValues.GetData() + DiscreteDataOffsets[Element.Index], DiscreteDataNums[Element.Index]);
|
|
return Parameters;
|
|
}
|
|
|
|
FObjectNamedDiscreteExclusiveParameters FObject::GetNamedDiscreteExclusive(const FObjectElement Element) const
|
|
{
|
|
check(IsValid(Element) && GetType(Element) == EType::NamedDiscreteExclusive);
|
|
|
|
FObjectNamedDiscreteExclusiveParameters Parameters;
|
|
Parameters.ElementName = SubElementNames[SubElementDataOffsets[Element.Index]];
|
|
return Parameters;
|
|
}
|
|
|
|
FObjectNamedDiscreteInclusiveParameters FObject::GetNamedDiscreteInclusive(const FObjectElement Element) const
|
|
{
|
|
check(IsValid(Element) && GetType(Element) == EType::NamedDiscreteInclusive);
|
|
|
|
FObjectNamedDiscreteInclusiveParameters Parameters;
|
|
Parameters.ElementNames = TArrayView<const FName>(SubElementNames.GetData() + SubElementDataOffsets[Element.Index], SubElementDataNums[Element.Index]);
|
|
return Parameters;
|
|
}
|
|
|
|
FObjectAndParameters FObject::GetAnd(const FObjectElement Element) const
|
|
{
|
|
check(IsValid(Element) && GetType(Element) == EType::And);
|
|
|
|
FObjectAndParameters Parameters;
|
|
Parameters.ElementNames = TArrayView<const FName>(SubElementNames.GetData() + SubElementDataOffsets[Element.Index], SubElementDataNums[Element.Index]);
|
|
Parameters.Elements = TArrayView<const FObjectElement>(SubElementObjects.GetData() + SubElementDataOffsets[Element.Index], SubElementDataNums[Element.Index]);
|
|
return Parameters;
|
|
}
|
|
|
|
FObjectOrExclusiveParameters FObject::GetOrExclusive(const FObjectElement Element) const
|
|
{
|
|
check(IsValid(Element) && GetType(Element) == EType::OrExclusive);
|
|
|
|
FObjectOrExclusiveParameters Parameters;
|
|
Parameters.ElementName = SubElementNames[SubElementDataOffsets[Element.Index]];
|
|
Parameters.Element = SubElementObjects[SubElementDataOffsets[Element.Index]];
|
|
return Parameters;
|
|
}
|
|
|
|
FObjectOrInclusiveParameters FObject::GetOrInclusive(const FObjectElement Element) const
|
|
{
|
|
check(IsValid(Element) && GetType(Element) == EType::OrInclusive);
|
|
|
|
FObjectOrInclusiveParameters Parameters;
|
|
Parameters.ElementNames = TArrayView<const FName>(SubElementNames.GetData() + SubElementDataOffsets[Element.Index], SubElementDataNums[Element.Index]);
|
|
Parameters.Elements = TArrayView<const FObjectElement>(SubElementObjects.GetData() + SubElementDataOffsets[Element.Index], SubElementDataNums[Element.Index]);
|
|
return Parameters;
|
|
}
|
|
|
|
FObjectArrayParameters FObject::GetArray(const FObjectElement Element) const
|
|
{
|
|
check(IsValid(Element) && GetType(Element) == EType::Array);
|
|
|
|
FObjectArrayParameters Parameters;
|
|
Parameters.Elements = TArrayView<const FObjectElement>(SubElementObjects.GetData() + SubElementDataOffsets[Element.Index], SubElementDataNums[Element.Index]);
|
|
return Parameters;
|
|
}
|
|
|
|
FObjectSetParameters FObject::GetSet(const FObjectElement Element) const
|
|
{
|
|
check(IsValid(Element) && GetType(Element) == EType::Set);
|
|
|
|
FObjectSetParameters Parameters;
|
|
Parameters.Elements = TArrayView<const FObjectElement>(SubElementObjects.GetData() + SubElementDataOffsets[Element.Index], SubElementDataNums[Element.Index]);
|
|
return Parameters;
|
|
}
|
|
|
|
FObjectEncodingParameters FObject::GetEncoding(const FObjectElement Element) const
|
|
{
|
|
check(IsValid(Element) && GetType(Element) == EType::Encoding);
|
|
|
|
FObjectEncodingParameters Parameters;
|
|
Parameters.Element = SubElementObjects[SubElementDataOffsets[Element.Index]];
|
|
return Parameters;
|
|
}
|
|
|
|
uint32 FObject::GetGeneration() const
|
|
{
|
|
return Generation;
|
|
}
|
|
|
|
void FObject::Empty()
|
|
{
|
|
Types.Empty();
|
|
Tags.Empty();
|
|
|
|
ContinuousDataOffsets.Empty();
|
|
ContinuousDataNums.Empty();
|
|
|
|
DiscreteDataOffsets.Reset();
|
|
DiscreteDataNums.Reset();
|
|
|
|
SubElementDataOffsets.Empty();
|
|
SubElementDataNums.Empty();
|
|
|
|
ContinuousValues.Empty();
|
|
DiscreteValues.Empty();
|
|
SubElementNames.Empty();
|
|
SubElementObjects.Empty();
|
|
|
|
Generation++;
|
|
}
|
|
|
|
bool FObject::IsEmpty() const
|
|
{
|
|
return Types.IsEmpty();
|
|
}
|
|
|
|
void FObject::Reset()
|
|
{
|
|
Types.Reset();
|
|
Tags.Reset();
|
|
|
|
ContinuousDataOffsets.Reset();
|
|
ContinuousDataNums.Reset();
|
|
|
|
DiscreteDataOffsets.Reset();
|
|
DiscreteDataNums.Reset();
|
|
|
|
SubElementDataOffsets.Reset();
|
|
SubElementDataNums.Reset();
|
|
|
|
ContinuousValues.Reset();
|
|
DiscreteValues.Reset();
|
|
SubElementNames.Reset();
|
|
SubElementObjects.Reset();
|
|
|
|
Generation++;
|
|
}
|
|
|
|
namespace Private
|
|
{
|
|
static inline NNE::RuntimeBasic::FModelBuilder::EActivationFunction GetNNEActivationFunction(const EEncodingActivationFunction ActivationFunction)
|
|
{
|
|
switch (ActivationFunction)
|
|
{
|
|
case EEncodingActivationFunction::ReLU: return NNE::RuntimeBasic::FModelBuilder::EActivationFunction::ReLU;
|
|
case EEncodingActivationFunction::ELU: return NNE::RuntimeBasic::FModelBuilder::EActivationFunction::ELU;
|
|
case EEncodingActivationFunction::TanH: return NNE::RuntimeBasic::FModelBuilder::EActivationFunction::TanH;
|
|
case EEncodingActivationFunction::GELU: return NNE::RuntimeBasic::FModelBuilder::EActivationFunction::GELU;
|
|
default: checkNoEntry(); return NNE::RuntimeBasic::FModelBuilder::EActivationFunction::ReLU;
|
|
}
|
|
}
|
|
|
|
static inline int32 HashFNameStable(const FName Name)
|
|
{
|
|
const FString NameString = Name.ToString().ToLower();
|
|
return (int32)CityHash32(
|
|
(const char*)NameString.GetCharArray().GetData(),
|
|
NameString.GetCharArray().GetTypeSize() *
|
|
NameString.GetCharArray().Num());
|
|
}
|
|
|
|
static inline int32 HashInt(const int32 Int)
|
|
{
|
|
return (int32)CityHash32((const char*)&Int, sizeof(int32));
|
|
}
|
|
|
|
static inline int32 HashCombine(const TArrayView<const int32> Hashes)
|
|
{
|
|
return (int32)CityHash32((const char*)Hashes.GetData(), Hashes.Num() * Hashes.GetTypeSize());
|
|
}
|
|
|
|
static inline int32 HashElements(
|
|
const FSchema& Schema,
|
|
const TArrayView<const FName> SchemaElementNames,
|
|
const int32 Salt)
|
|
{
|
|
// Note: Here we xor all entries together.
|
|
// This makes the hash in invariant to the ordering of names which is actually what we want
|
|
// since this array is representing a set-like structure and it is fine to pass elements in a different order.
|
|
|
|
int32 Hash = 0x5592716a;
|
|
for (int32 SchemaElementIdx = 0; SchemaElementIdx < SchemaElementNames.Num(); SchemaElementIdx++)
|
|
{
|
|
Hash ^= HashFNameStable(SchemaElementNames[SchemaElementIdx]);
|
|
}
|
|
|
|
return Hash;
|
|
}
|
|
|
|
static inline int32 HashElements(
|
|
const FSchema& Schema,
|
|
const TArrayView<const FName> SchemaElementNames,
|
|
const TArrayView<const FSchemaElement> SchemaElements,
|
|
const int32 Salt)
|
|
{
|
|
// Note: Here we xor all entries together.
|
|
// This makes the hash in invariant to the ordering of pairs of names and elements
|
|
// which is actually what we want since these two arrays are representing a map-like
|
|
// structure and it is fine to pass keys and values in a different order.
|
|
|
|
int32 Hash = 0x5b3bbe4d;
|
|
for (int32 SchemaElementIdx = 0; SchemaElementIdx < SchemaElements.Num(); SchemaElementIdx++)
|
|
{
|
|
Hash ^= HashCombine({ HashFNameStable(SchemaElementNames[SchemaElementIdx]), GetSchemaObjectsCompatibilityHash(Schema, SchemaElements[SchemaElementIdx], Salt) });
|
|
}
|
|
|
|
return Hash;
|
|
}
|
|
}
|
|
|
|
int32 GetSchemaObjectsCompatibilityHash(
|
|
const FSchema& Schema,
|
|
const FSchemaElement SchemaElement,
|
|
const int32 Salt)
|
|
{
|
|
check(Schema.IsValid(SchemaElement));
|
|
const EType SchemaElementType = Schema.GetType(SchemaElement);
|
|
|
|
const int32 Hash = Private::HashCombine({ Salt, Private::HashInt((int32)SchemaElementType) });
|
|
|
|
switch (SchemaElementType)
|
|
{
|
|
case EType::Null: return Hash;
|
|
|
|
case EType::Continuous: return Private::HashCombine({ Hash, Private::HashInt(Schema.GetContinuous(SchemaElement).Num) });
|
|
|
|
case EType::DiscreteExclusive: return Private::HashCombine({ Hash, Private::HashInt(Schema.GetDiscreteExclusive(SchemaElement).Num) });
|
|
case EType::DiscreteInclusive: return Private::HashCombine({ Hash, Private::HashInt(Schema.GetDiscreteInclusive(SchemaElement).Num) });
|
|
|
|
case EType::NamedDiscreteExclusive:
|
|
{
|
|
const FSchemaNamedDiscreteExclusiveParameters Parameters = Schema.GetNamedDiscreteExclusive(SchemaElement);
|
|
return Private::HashCombine({ Hash, Private::HashElements(Schema, Parameters.ElementNames, Salt) });
|
|
}
|
|
|
|
case EType::NamedDiscreteInclusive:
|
|
{
|
|
const FSchemaNamedDiscreteInclusiveParameters Parameters = Schema.GetNamedDiscreteInclusive(SchemaElement);
|
|
return Private::HashCombine({ Hash, Private::HashElements(Schema, Parameters.ElementNames, Salt) });
|
|
}
|
|
|
|
case EType::And:
|
|
{
|
|
const FSchemaAndParameters Parameters = Schema.GetAnd(SchemaElement);
|
|
return Private::HashCombine({ Hash, Private::HashElements(Schema, Parameters.ElementNames, Parameters.Elements, Salt) });
|
|
}
|
|
|
|
case EType::OrExclusive:
|
|
{
|
|
const FSchemaOrExclusiveParameters Parameters = Schema.GetOrExclusive(SchemaElement);
|
|
return Private::HashCombine({ Hash, Private::HashElements(Schema, Parameters.ElementNames, Parameters.Elements, Salt) });
|
|
}
|
|
|
|
case EType::OrInclusive:
|
|
{
|
|
const FSchemaOrInclusiveParameters Parameters = Schema.GetOrInclusive(SchemaElement);
|
|
return Private::HashCombine({ Hash, Private::HashElements(Schema, Parameters.ElementNames, Parameters.Elements, Salt) });
|
|
}
|
|
|
|
case EType::Array:
|
|
{
|
|
const FSchemaArrayParameters Parameters = Schema.GetArray(SchemaElement);
|
|
return Private::HashCombine({ Hash, Private::HashInt(Parameters.Num), GetSchemaObjectsCompatibilityHash(Schema, Parameters.Element, Salt) });
|
|
}
|
|
|
|
case EType::Set:
|
|
{
|
|
const FSchemaSetParameters Parameters = Schema.GetSet(SchemaElement);
|
|
return Private::HashCombine({ Hash, Private::HashInt(Parameters.MaxNum), GetSchemaObjectsCompatibilityHash(Schema, Parameters.Element, Salt) });
|
|
}
|
|
|
|
case EType::Encoding:
|
|
{
|
|
const FSchemaEncodingParameters Parameters = Schema.GetEncoding(SchemaElement);
|
|
return GetSchemaObjectsCompatibilityHash(Schema, Parameters.Element, Salt);
|
|
}
|
|
|
|
default:
|
|
{
|
|
checkNoEntry();
|
|
return 0;
|
|
}
|
|
|
|
}
|
|
}
|
|
|
|
bool AreSchemaObjectsCompatible(
|
|
const FSchema& SchemaA,
|
|
const FSchemaElement SchemaElementA,
|
|
const FSchema& SchemaB,
|
|
const FSchemaElement SchemaElementB)
|
|
{
|
|
check(SchemaA.IsValid(SchemaElementA));
|
|
check(SchemaB.IsValid(SchemaElementB));
|
|
|
|
const EType SchemaElementTypeA = SchemaA.GetType(SchemaElementA);
|
|
const EType SchemaElementTypeB = SchemaB.GetType(SchemaElementB);
|
|
|
|
// If any element is an encoding element we forward the comparison to the sub-element since encoding elements don't affect compatibility
|
|
if (SchemaElementTypeA == EType::Encoding) { return AreSchemaObjectsCompatible(SchemaA, SchemaA.GetEncoding(SchemaElementA).Element, SchemaB, SchemaElementB); }
|
|
if (SchemaElementTypeB == EType::Encoding) { return AreSchemaObjectsCompatible(SchemaA, SchemaElementA, SchemaB, SchemaB.GetEncoding(SchemaElementB).Element); }
|
|
|
|
// Otherwise if types don't match we immediately know elements are incompatible
|
|
if (SchemaElementTypeA != SchemaElementTypeB) { return false; }
|
|
|
|
// This is an early-out since if the input sizes are different we are definitely incompatible
|
|
if (SchemaA.GetObservationVectorSize(SchemaElementA) != SchemaB.GetObservationVectorSize(SchemaElementB)) { return false; }
|
|
|
|
switch (SchemaElementTypeA)
|
|
{
|
|
case EType::Null: return true;
|
|
|
|
case EType::Continuous: return SchemaA.GetContinuous(SchemaElementA).Num == SchemaB.GetContinuous(SchemaElementB).Num;
|
|
|
|
case EType::DiscreteExclusive: return SchemaA.GetDiscreteExclusive(SchemaElementA).Num == SchemaB.GetDiscreteExclusive(SchemaElementB).Num;
|
|
case EType::DiscreteInclusive: return SchemaA.GetDiscreteInclusive(SchemaElementA).Num == SchemaB.GetDiscreteInclusive(SchemaElementB).Num;
|
|
|
|
case EType::NamedDiscreteExclusive:
|
|
{
|
|
const FSchemaNamedDiscreteExclusiveParameters ParametersA = SchemaA.GetNamedDiscreteExclusive(SchemaElementA);
|
|
const FSchemaNamedDiscreteExclusiveParameters ParametersB = SchemaB.GetNamedDiscreteExclusive(SchemaElementB);
|
|
|
|
if (ParametersA.ElementNames.Num() != ParametersB.ElementNames.Num()) { return false; }
|
|
|
|
for (int32 SchemaElementAIdx = 0; SchemaElementAIdx < ParametersA.ElementNames.Num(); SchemaElementAIdx++)
|
|
{
|
|
const int32 SchemaElementBIdx = ParametersB.ElementNames.Find(ParametersA.ElementNames[SchemaElementAIdx]);
|
|
if (SchemaElementBIdx == INDEX_NONE) { return false; }
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
case EType::NamedDiscreteInclusive:
|
|
{
|
|
const FSchemaNamedDiscreteInclusiveParameters ParametersA = SchemaA.GetNamedDiscreteInclusive(SchemaElementA);
|
|
const FSchemaNamedDiscreteInclusiveParameters ParametersB = SchemaB.GetNamedDiscreteInclusive(SchemaElementB);
|
|
|
|
if (ParametersA.ElementNames.Num() != ParametersB.ElementNames.Num()) { return false; }
|
|
|
|
for (int32 SchemaElementAIdx = 0; SchemaElementAIdx < ParametersA.ElementNames.Num(); SchemaElementAIdx++)
|
|
{
|
|
const int32 SchemaElementBIdx = ParametersB.ElementNames.Find(ParametersA.ElementNames[SchemaElementAIdx]);
|
|
if (SchemaElementBIdx == INDEX_NONE) { return false; }
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
case EType::And:
|
|
{
|
|
const FSchemaAndParameters ParametersA = SchemaA.GetAnd(SchemaElementA);
|
|
const FSchemaAndParameters ParametersB = SchemaB.GetAnd(SchemaElementB);
|
|
|
|
if (ParametersA.Elements.Num() != ParametersB.Elements.Num()) { return false; }
|
|
|
|
for (int32 SchemaElementAIdx = 0; SchemaElementAIdx < ParametersA.Elements.Num(); SchemaElementAIdx++)
|
|
{
|
|
const int32 SchemaElementBIdx = ParametersB.ElementNames.Find(ParametersA.ElementNames[SchemaElementAIdx]);
|
|
if (SchemaElementBIdx == INDEX_NONE) { return false; }
|
|
if (!AreSchemaObjectsCompatible(SchemaA, ParametersA.Elements[SchemaElementAIdx], SchemaB, ParametersB.Elements[SchemaElementBIdx])) { return false; }
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
case EType::OrExclusive:
|
|
{
|
|
const FSchemaOrExclusiveParameters ParametersA = SchemaA.GetOrExclusive(SchemaElementA);
|
|
const FSchemaOrExclusiveParameters ParametersB = SchemaB.GetOrExclusive(SchemaElementB);
|
|
|
|
if (ParametersA.Elements.Num() != ParametersB.Elements.Num()) { return false; }
|
|
|
|
for (int32 SchemaElementAIdx = 0; SchemaElementAIdx < ParametersA.Elements.Num(); SchemaElementAIdx++)
|
|
{
|
|
const int32 SchemaElementBIdx = ParametersB.ElementNames.Find(ParametersA.ElementNames[SchemaElementAIdx]);
|
|
if (SchemaElementBIdx == INDEX_NONE) { return false; }
|
|
if (!AreSchemaObjectsCompatible(SchemaA, ParametersA.Elements[SchemaElementAIdx], SchemaB, ParametersB.Elements[SchemaElementBIdx])) { return false; }
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
case EType::OrInclusive:
|
|
{
|
|
const FSchemaOrInclusiveParameters ParametersA = SchemaA.GetOrInclusive(SchemaElementA);
|
|
const FSchemaOrInclusiveParameters ParametersB = SchemaB.GetOrInclusive(SchemaElementB);
|
|
|
|
if (ParametersA.Elements.Num() != ParametersB.Elements.Num()) { return false; }
|
|
|
|
for (int32 SchemaElementAIdx = 0; SchemaElementAIdx < ParametersA.Elements.Num(); SchemaElementAIdx++)
|
|
{
|
|
const int32 SchemaElementBIdx = ParametersB.ElementNames.Find(ParametersA.ElementNames[SchemaElementAIdx]);
|
|
if (SchemaElementBIdx == INDEX_NONE) { return false; }
|
|
if (!AreSchemaObjectsCompatible(SchemaA, ParametersA.Elements[SchemaElementAIdx], SchemaB, ParametersB.Elements[SchemaElementBIdx])) { return false; }
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
case EType::Array:
|
|
{
|
|
const FSchemaArrayParameters ParametersA = SchemaA.GetArray(SchemaElementA);
|
|
const FSchemaArrayParameters ParametersB = SchemaB.GetArray(SchemaElementB);
|
|
|
|
return (ParametersA.Num == ParametersB.Num) && AreSchemaObjectsCompatible(SchemaA, ParametersA.Element, SchemaB, ParametersB.Element);
|
|
}
|
|
|
|
case EType::Set:
|
|
{
|
|
const FSchemaSetParameters ParametersA = SchemaA.GetSet(SchemaElementA);
|
|
const FSchemaSetParameters ParametersB = SchemaB.GetSet(SchemaElementB);
|
|
|
|
return (ParametersA.MaxNum == ParametersB.MaxNum) && AreSchemaObjectsCompatible(SchemaA, ParametersA.Element, SchemaB, ParametersB.Element);
|
|
}
|
|
|
|
case EType::Encoding:
|
|
{
|
|
checkf(false, TEXT("Encoding elements should always be forwarded..."));
|
|
return false;
|
|
}
|
|
|
|
default:
|
|
{
|
|
checkNoEntry();
|
|
return false;
|
|
}
|
|
|
|
}
|
|
}
|
|
|
|
void MakeEncoderNetworkModelBuilderElementFromSchema(
|
|
NNE::RuntimeBasic::FModelBuilderElement& OutElement,
|
|
NNE::RuntimeBasic::FModelBuilder& Builder,
|
|
const FSchema& Schema,
|
|
const FSchemaElement SchemaElement,
|
|
const FNetworkSettings& NetworkSettings)
|
|
{
|
|
const EType SchemaElementType = Schema.GetType(SchemaElement);
|
|
|
|
switch (SchemaElementType)
|
|
{
|
|
case EType::Null:
|
|
{
|
|
OutElement = Builder.MakeCopy(0);
|
|
break;
|
|
}
|
|
|
|
case EType::Continuous:
|
|
{
|
|
const int32 ValueNum = Schema.GetContinuous(SchemaElement).Num;
|
|
|
|
OutElement = Builder.MakeDenormalize(
|
|
ValueNum,
|
|
Builder.MakeValuesZero(ValueNum),
|
|
Builder.MakeValuesOne(ValueNum));
|
|
break;
|
|
}
|
|
|
|
case EType::DiscreteExclusive:
|
|
{
|
|
const int32 ValueNum = Schema.GetDiscreteExclusive(SchemaElement).Num;
|
|
|
|
OutElement = Builder.MakeDenormalize(
|
|
ValueNum,
|
|
Builder.MakeValuesZero(ValueNum),
|
|
Builder.MakeValuesOne(ValueNum));
|
|
break;
|
|
}
|
|
|
|
case EType::DiscreteInclusive:
|
|
{
|
|
const int32 ValueNum = Schema.GetDiscreteInclusive(SchemaElement).Num;
|
|
|
|
OutElement = Builder.MakeDenormalize(
|
|
ValueNum,
|
|
Builder.MakeValuesZero(ValueNum),
|
|
Builder.MakeValuesOne(ValueNum));
|
|
break;
|
|
}
|
|
|
|
case EType::NamedDiscreteExclusive:
|
|
{
|
|
const int32 ValueNum = Schema.GetNamedDiscreteExclusive(SchemaElement).ElementNames.Num();
|
|
|
|
OutElement = Builder.MakeDenormalize(
|
|
ValueNum,
|
|
Builder.MakeValuesZero(ValueNum),
|
|
Builder.MakeValuesOne(ValueNum));
|
|
break;
|
|
}
|
|
|
|
case EType::NamedDiscreteInclusive:
|
|
{
|
|
const int32 ValueNum = Schema.GetNamedDiscreteInclusive(SchemaElement).ElementNames.Num();
|
|
|
|
OutElement = Builder.MakeDenormalize(
|
|
ValueNum,
|
|
Builder.MakeValuesZero(ValueNum),
|
|
Builder.MakeValuesOne(ValueNum));
|
|
break;
|
|
}
|
|
|
|
case EType::And:
|
|
{
|
|
const FSchemaAndParameters Parameters = Schema.GetAnd(SchemaElement);
|
|
|
|
TArray<NNE::RuntimeBasic::FModelBuilderElement, TInlineAllocator<8>> BuilderLayers;
|
|
BuilderLayers.Reserve(Parameters.Elements.Num());
|
|
for (const FSchemaElement SubElement : Parameters.Elements)
|
|
{
|
|
NNE::RuntimeBasic::FModelBuilderElement BuilderSubElement;
|
|
MakeEncoderNetworkModelBuilderElementFromSchema(BuilderSubElement, Builder, Schema, SubElement, NetworkSettings);
|
|
BuilderLayers.Emplace(BuilderSubElement);
|
|
}
|
|
|
|
OutElement = Builder.MakeConcat(BuilderLayers);
|
|
break;
|
|
}
|
|
|
|
case EType::OrExclusive:
|
|
{
|
|
const FSchemaOrExclusiveParameters Parameters = Schema.GetOrExclusive(SchemaElement);
|
|
|
|
TArray<NNE::RuntimeBasic::FModelBuilderElement, TInlineAllocator<8>> BuilderSubLayers;
|
|
TArray<NNE::RuntimeBasic::FModelBuilderElement, TInlineAllocator<8>> BuilderEncoders;
|
|
BuilderSubLayers.Reserve(Parameters.Elements.Num());
|
|
BuilderEncoders.Reserve(Parameters.Elements.Num());
|
|
for (const FSchemaElement SubElement : Parameters.Elements)
|
|
{
|
|
const int32 SubElementEncodedSize = Schema.GetEncodedVectorSize(SubElement);
|
|
NNE::RuntimeBasic::FModelBuilderElement BuilderSubElement;
|
|
MakeEncoderNetworkModelBuilderElementFromSchema(BuilderSubElement, Builder, Schema, SubElement, NetworkSettings);
|
|
BuilderSubLayers.Emplace(BuilderSubElement);
|
|
|
|
NNE::RuntimeBasic::FModelBuilder::FLinearLayerSettings LinearLayerSettings;
|
|
LinearLayerSettings.Type = NetworkSettings.bUseCompressedLinearLayers ?
|
|
NNE::RuntimeBasic::FModelBuilder::ELinearLayerType::Compressed :
|
|
NNE::RuntimeBasic::FModelBuilder::ELinearLayerType::Normal;
|
|
|
|
switch (NetworkSettings.WeightInitialization)
|
|
{
|
|
case EWeightInitialization::KaimingGaussian: LinearLayerSettings.WeightInitializationSettings.Type =
|
|
NNE::RuntimeBasic::FModelBuilder::EWeightInitializationType::KaimingGaussian; break;
|
|
case EWeightInitialization::KaimingUniform: LinearLayerSettings.WeightInitializationSettings.Type =
|
|
NNE::RuntimeBasic::FModelBuilder::EWeightInitializationType::KaimingUniform; break;
|
|
default: checkNoEntry();
|
|
}
|
|
|
|
BuilderEncoders.Emplace(Builder.MakeLinearLayer(SubElementEncodedSize, Parameters.EncodingSize, LinearLayerSettings));
|
|
}
|
|
|
|
OutElement = Builder.MakeAggregateOrExclusive(Parameters.EncodingSize, BuilderSubLayers, BuilderEncoders);
|
|
break;
|
|
}
|
|
|
|
case EType::OrInclusive:
|
|
{
|
|
const FSchemaOrInclusiveParameters Parameters = Schema.GetOrInclusive(SchemaElement);
|
|
|
|
TArray<NNE::RuntimeBasic::FModelBuilderElement, TInlineAllocator<8>> BuilderSubLayers;
|
|
TArray<NNE::RuntimeBasic::FModelBuilderElement, TInlineAllocator<8>> BuilderQueryLayers;
|
|
TArray<NNE::RuntimeBasic::FModelBuilderElement, TInlineAllocator<8>> BuilderKeyLayers;
|
|
TArray<NNE::RuntimeBasic::FModelBuilderElement, TInlineAllocator<8>> BuilderValueLayers;
|
|
BuilderSubLayers.Reserve(Parameters.Elements.Num());
|
|
BuilderQueryLayers.Reserve(Parameters.Elements.Num());
|
|
BuilderValueLayers.Reserve(Parameters.Elements.Num());
|
|
for (const FSchemaElement SubElement : Parameters.Elements)
|
|
{
|
|
const int32 SubElementEncodedSize = Schema.GetEncodedVectorSize(SubElement);
|
|
NNE::RuntimeBasic::FModelBuilderElement BuilderSubElement;
|
|
MakeEncoderNetworkModelBuilderElementFromSchema(BuilderSubElement, Builder, Schema, SubElement, NetworkSettings);
|
|
BuilderSubLayers.Emplace(BuilderSubElement);
|
|
|
|
NNE::RuntimeBasic::FModelBuilder::FLinearLayerSettings LinearLayerSettings;
|
|
LinearLayerSettings.Type = NetworkSettings.bUseCompressedLinearLayers ?
|
|
NNE::RuntimeBasic::FModelBuilder::ELinearLayerType::Compressed :
|
|
NNE::RuntimeBasic::FModelBuilder::ELinearLayerType::Normal;
|
|
|
|
switch (NetworkSettings.WeightInitialization)
|
|
{
|
|
case EWeightInitialization::KaimingGaussian: LinearLayerSettings.WeightInitializationSettings.Type =
|
|
NNE::RuntimeBasic::FModelBuilder::EWeightInitializationType::KaimingGaussian; break;
|
|
case EWeightInitialization::KaimingUniform: LinearLayerSettings.WeightInitializationSettings.Type =
|
|
NNE::RuntimeBasic::FModelBuilder::EWeightInitializationType::KaimingUniform; break;
|
|
default: checkNoEntry();
|
|
}
|
|
|
|
BuilderQueryLayers.Emplace(Builder.MakeLinearLayer(SubElementEncodedSize, Parameters.AttentionHeadNum * Parameters.AttentionEncodingSize, LinearLayerSettings));
|
|
BuilderKeyLayers.Emplace(Builder.MakeLinearLayer(SubElementEncodedSize, Parameters.AttentionHeadNum * Parameters.AttentionEncodingSize, LinearLayerSettings));
|
|
BuilderValueLayers.Emplace(Builder.MakeLinearLayer(SubElementEncodedSize, Parameters.AttentionHeadNum * Parameters.ValueEncodingSize, LinearLayerSettings));
|
|
}
|
|
|
|
OutElement = Builder.MakeAggregateOrInclusive(
|
|
Parameters.ValueEncodingSize,
|
|
Parameters.AttentionEncodingSize,
|
|
Parameters.AttentionHeadNum,
|
|
BuilderSubLayers,
|
|
BuilderQueryLayers,
|
|
BuilderKeyLayers,
|
|
BuilderValueLayers);
|
|
|
|
break;
|
|
}
|
|
|
|
case EType::Array:
|
|
{
|
|
const FSchemaArrayParameters Parameters = Schema.GetArray(SchemaElement);
|
|
|
|
NNE::RuntimeBasic::FModelBuilderElement BuilderSubElement;
|
|
MakeEncoderNetworkModelBuilderElementFromSchema(BuilderSubElement, Builder, Schema, Parameters.Element, NetworkSettings);
|
|
OutElement = Builder.MakeArray(Parameters.Num, BuilderSubElement);
|
|
break;
|
|
}
|
|
|
|
case EType::Set:
|
|
{
|
|
const FSchemaSetParameters Parameters = Schema.GetSet(SchemaElement);
|
|
|
|
const int32 SubElementEncodedSize = Schema.GetEncodedVectorSize(Parameters.Element);
|
|
|
|
NNE::RuntimeBasic::FModelBuilderElement BuilderSubElement;
|
|
MakeEncoderNetworkModelBuilderElementFromSchema(BuilderSubElement, Builder, Schema, Parameters.Element, NetworkSettings);
|
|
|
|
NNE::RuntimeBasic::FModelBuilder::FLinearLayerSettings LinearLayerSettings;
|
|
LinearLayerSettings.Type = NetworkSettings.bUseCompressedLinearLayers ?
|
|
NNE::RuntimeBasic::FModelBuilder::ELinearLayerType::Compressed :
|
|
NNE::RuntimeBasic::FModelBuilder::ELinearLayerType::Normal;
|
|
|
|
switch (NetworkSettings.WeightInitialization)
|
|
{
|
|
case EWeightInitialization::KaimingGaussian: LinearLayerSettings.WeightInitializationSettings.Type =
|
|
NNE::RuntimeBasic::FModelBuilder::EWeightInitializationType::KaimingGaussian; break;
|
|
case EWeightInitialization::KaimingUniform: LinearLayerSettings.WeightInitializationSettings.Type =
|
|
NNE::RuntimeBasic::FModelBuilder::EWeightInitializationType::KaimingUniform; break;
|
|
default: checkNoEntry();
|
|
}
|
|
|
|
OutElement = Builder.MakeAggregateSet(
|
|
Parameters.MaxNum,
|
|
Parameters.ValueEncodingSize,
|
|
Parameters.AttentionEncodingSize,
|
|
Parameters.AttentionHeadNum,
|
|
BuilderSubElement,
|
|
Builder.MakeLinearLayer(SubElementEncodedSize, Parameters.AttentionHeadNum * Parameters.AttentionEncodingSize, LinearLayerSettings),
|
|
Builder.MakeLinearLayer(SubElementEncodedSize, Parameters.AttentionHeadNum * Parameters.AttentionEncodingSize, LinearLayerSettings),
|
|
Builder.MakeLinearLayer(SubElementEncodedSize, Parameters.AttentionHeadNum * Parameters.ValueEncodingSize, LinearLayerSettings));
|
|
|
|
break;
|
|
}
|
|
|
|
case EType::Encoding:
|
|
{
|
|
const FSchemaEncodingParameters Parameters = Schema.GetEncoding(SchemaElement);
|
|
|
|
const int32 SubElementEncodedSize = Schema.GetEncodedVectorSize(Parameters.Element);
|
|
|
|
NNE::RuntimeBasic::FModelBuilderElement BuilderSubElement;
|
|
MakeEncoderNetworkModelBuilderElementFromSchema(BuilderSubElement, Builder, Schema, Parameters.Element, NetworkSettings);
|
|
|
|
NNE::RuntimeBasic::FModelBuilder::FLinearLayerSettings LinearLayerSettings;
|
|
LinearLayerSettings.Type = NetworkSettings.bUseCompressedLinearLayers ?
|
|
NNE::RuntimeBasic::FModelBuilder::ELinearLayerType::Compressed :
|
|
NNE::RuntimeBasic::FModelBuilder::ELinearLayerType::Normal;
|
|
|
|
switch (NetworkSettings.WeightInitialization)
|
|
{
|
|
case EWeightInitialization::KaimingGaussian: LinearLayerSettings.WeightInitializationSettings.Type =
|
|
NNE::RuntimeBasic::FModelBuilder::EWeightInitializationType::KaimingGaussian; break;
|
|
case EWeightInitialization::KaimingUniform: LinearLayerSettings.WeightInitializationSettings.Type =
|
|
NNE::RuntimeBasic::FModelBuilder::EWeightInitializationType::KaimingUniform; break;
|
|
default: checkNoEntry();
|
|
}
|
|
|
|
OutElement = Builder.MakeSequence({
|
|
BuilderSubElement,
|
|
Builder.MakeMLP(
|
|
SubElementEncodedSize,
|
|
Parameters.EncodingSize,
|
|
Parameters.EncodingSize,
|
|
Parameters.LayerNum + 1, // Add 1 to account for input layer
|
|
Private::GetNNEActivationFunction(Parameters.ActivationFunction),
|
|
true,
|
|
LinearLayerSettings)
|
|
});
|
|
|
|
break;
|
|
}
|
|
|
|
default:
|
|
{
|
|
checkNoEntry();
|
|
}
|
|
}
|
|
|
|
checkf(OutElement.GetInputSize() == Schema.GetObservationVectorSize(SchemaElement),
|
|
TEXT("Encoder Network Input unexpected size for %s. Got %i, expected %i according to Schema."),
|
|
*Schema.GetTag(SchemaElement).ToString(), OutElement.GetInputSize(), Schema.GetObservationVectorSize(SchemaElement));
|
|
|
|
checkf(OutElement.GetOutputSize() == Schema.GetEncodedVectorSize(SchemaElement),
|
|
TEXT("Encoder Network Output unexpected size for %s. Got %i, expected %i according to Schema."),
|
|
*Schema.GetTag(SchemaElement).ToString(), OutElement.GetOutputSize(), Schema.GetEncodedVectorSize(SchemaElement));
|
|
}
|
|
|
|
void GenerateEncoderNetworkFileDataFromSchema(
|
|
TArray<uint8>& OutFileData,
|
|
uint32& OutInputSize,
|
|
uint32& OutOutputSize,
|
|
const FSchema& Schema,
|
|
const FSchemaElement SchemaElement,
|
|
const FNetworkSettings& NetworkSettings,
|
|
const uint32 Seed)
|
|
{
|
|
check(Schema.IsValid(SchemaElement));
|
|
|
|
NNE::RuntimeBasic::FModelBuilder Builder(Seed);
|
|
NNE::RuntimeBasic::FModelBuilderElement Element;
|
|
MakeEncoderNetworkModelBuilderElementFromSchema(Element, Builder, Schema, SchemaElement, NetworkSettings);
|
|
Builder.WriteFileDataAndReset(OutFileData, OutInputSize, OutOutputSize, Element);
|
|
}
|
|
|
|
void SetVectorFromObject(
|
|
TLearningArrayView<1, float> OutObservationVector,
|
|
const FSchema& Schema,
|
|
const FSchemaElement SchemaElement,
|
|
const FObject& Object,
|
|
const FObjectElement ObjectElement)
|
|
{
|
|
check(Schema.IsValid(SchemaElement));
|
|
check(Object.IsValid(ObjectElement));
|
|
check(OutObservationVector.Num() == Schema.GetObservationVectorSize(SchemaElement));
|
|
|
|
// Check that the types match
|
|
|
|
const EType SchemaElementType = Schema.GetType(SchemaElement);
|
|
const EType ObjectElementType = Object.GetType(ObjectElement);
|
|
check(ObjectElementType == SchemaElementType);
|
|
|
|
// Zero Observation Vector
|
|
|
|
Array::Zero(OutObservationVector);
|
|
|
|
// Logic for each specific element type
|
|
|
|
switch (SchemaElementType)
|
|
{
|
|
|
|
case EType::Null: return;
|
|
|
|
case EType::Continuous:
|
|
{
|
|
// Check the input sizes match
|
|
|
|
const FSchemaContinuousParameters SchemaParameters = Schema.GetContinuous(SchemaElement);
|
|
|
|
const TArrayView<const float> ObservationValues = Object.GetContinuous(ObjectElement).Values;
|
|
check(Schema.GetObservationVectorSize(SchemaElement) == ObservationValues.Num());
|
|
check(Schema.GetObservationVectorSize(SchemaElement) == OutObservationVector.Num());
|
|
check(Schema.GetObservationVectorSize(SchemaElement) == SchemaParameters.Num);
|
|
|
|
// Copy in and scale the values from the observation object
|
|
|
|
const int32 ValueNum = SchemaParameters.Num;
|
|
const float ValueScale = FMath::Max(SchemaParameters.Scale, UE_SMALL_NUMBER);
|
|
|
|
for (int32 ValueIdx = 0; ValueIdx < ValueNum; ValueIdx++)
|
|
{
|
|
OutObservationVector[ValueIdx] = ObservationValues[ValueIdx] / ValueScale;
|
|
}
|
|
|
|
return;
|
|
}
|
|
|
|
case EType::DiscreteExclusive:
|
|
{
|
|
const int32 ObservationValue = Object.GetDiscreteExclusive(ObjectElement).DiscreteIndex;
|
|
check(Schema.GetObservationVectorSize(SchemaElement) > ObservationValue && ObservationValue >= 0);
|
|
check(Schema.GetObservationVectorSize(SchemaElement) == OutObservationVector.Num());
|
|
|
|
// Set the single value in the observation vector
|
|
|
|
OutObservationVector[ObservationValue] = 1.0f;
|
|
return;
|
|
}
|
|
|
|
case EType::DiscreteInclusive:
|
|
{
|
|
const TArrayView<const int32> ObservationValues = Object.GetDiscreteInclusive(ObjectElement).DiscreteIndices;
|
|
check(Schema.GetObservationVectorSize(SchemaElement) >= ObservationValues.Num());
|
|
check(Schema.GetObservationVectorSize(SchemaElement) == OutObservationVector.Num());
|
|
|
|
// Set values in the observation vector
|
|
|
|
for (int32 ObservationValueIdx = 0; ObservationValueIdx < ObservationValues.Num(); ObservationValueIdx++)
|
|
{
|
|
check(Schema.GetObservationVectorSize(SchemaElement) > ObservationValues[ObservationValueIdx] && ObservationValues[ObservationValueIdx] >= 0);
|
|
OutObservationVector[ObservationValues[ObservationValueIdx]] = 1.0f;
|
|
}
|
|
|
|
return;
|
|
}
|
|
|
|
case EType::NamedDiscreteExclusive:
|
|
{
|
|
const TArrayView<const FName> SchemaNames = Schema.GetNamedDiscreteExclusive(SchemaElement).ElementNames;
|
|
const FName ObservationValue = Object.GetNamedDiscreteExclusive(ObjectElement).ElementName;
|
|
check(Schema.GetObservationVectorSize(SchemaElement) == OutObservationVector.Num());
|
|
|
|
// Set the single value in the observation vector
|
|
const int32 ObservationIndex = SchemaNames.Find(ObservationValue);
|
|
check(ObservationIndex != INDEX_NONE);
|
|
OutObservationVector[ObservationIndex] = 1.0f;
|
|
return;
|
|
}
|
|
|
|
case EType::NamedDiscreteInclusive:
|
|
{
|
|
const TArrayView<const FName> SchemaNames = Schema.GetNamedDiscreteInclusive(SchemaElement).ElementNames;
|
|
const TArrayView<const FName> ObservationValues = Object.GetNamedDiscreteInclusive(ObjectElement).ElementNames;
|
|
check(Schema.GetObservationVectorSize(SchemaElement) >= ObservationValues.Num());
|
|
check(Schema.GetObservationVectorSize(SchemaElement) == OutObservationVector.Num());
|
|
|
|
// Set values in the observation vector
|
|
for (int32 ObservationValueIdx = 0; ObservationValueIdx < ObservationValues.Num(); ObservationValueIdx++)
|
|
{
|
|
const int32 ObservationIndex = SchemaNames.Find(ObservationValues[ObservationValueIdx]);
|
|
check(ObservationIndex != INDEX_NONE);
|
|
OutObservationVector[ObservationIndex] = 1.0f;
|
|
}
|
|
return;
|
|
}
|
|
|
|
case EType::And:
|
|
{
|
|
// Check the number of sub-elements match
|
|
|
|
const FSchemaAndParameters SchemaParameters = Schema.GetAnd(SchemaElement);
|
|
const FObjectAndParameters ObjectParameters = Object.GetAnd(ObjectElement);
|
|
check(SchemaParameters.Elements.Num() == ObjectParameters.Elements.Num());
|
|
|
|
// Update Sub-elements
|
|
|
|
int32 SubElementOffset = 0;
|
|
|
|
for (int32 SchemaElementIdx = 0; SchemaElementIdx < SchemaParameters.Elements.Num(); SchemaElementIdx++)
|
|
{
|
|
const int32 ObjectElementIndex = ObjectParameters.ElementNames.Find(SchemaParameters.ElementNames[SchemaElementIdx]);
|
|
check(ObjectElementIndex != INDEX_NONE);
|
|
|
|
const int32 SubElementSize = Schema.GetObservationVectorSize(SchemaParameters.Elements[SchemaElementIdx]);
|
|
|
|
SetVectorFromObject(
|
|
OutObservationVector.Slice(SubElementOffset, SubElementSize),
|
|
Schema,
|
|
SchemaParameters.Elements[SchemaElementIdx],
|
|
Object,
|
|
ObjectParameters.Elements[ObjectElementIndex]);
|
|
|
|
SubElementOffset += SubElementSize;
|
|
}
|
|
|
|
check(SubElementOffset == OutObservationVector.Num());
|
|
return;
|
|
}
|
|
|
|
case EType::OrExclusive:
|
|
{
|
|
// Check only one sub-element is given and index is valid
|
|
|
|
const FSchemaOrExclusiveParameters SchemaParameters = Schema.GetOrExclusive(SchemaElement);
|
|
const FObjectOrExclusiveParameters ObjectParameters = Object.GetOrExclusive(ObjectElement);
|
|
|
|
const int32 SchemaElementIndex = SchemaParameters.ElementNames.Find(ObjectParameters.ElementName);
|
|
check(SchemaElementIndex != INDEX_NONE);
|
|
|
|
// Update sub-element
|
|
|
|
const int32 SubElementSize = Schema.GetObservationVectorSize(SchemaParameters.Elements[SchemaElementIndex]);
|
|
|
|
SetVectorFromObject(
|
|
OutObservationVector.Slice(0, SubElementSize),
|
|
Schema,
|
|
SchemaParameters.Elements[SchemaElementIndex],
|
|
Object,
|
|
ObjectParameters.Element);
|
|
|
|
// Set Mask
|
|
|
|
const int32 MaxSubElementSize = Private::GetMaxObservationVectorSize(Schema, SchemaParameters.Elements);
|
|
|
|
OutObservationVector[MaxSubElementSize + SchemaElementIndex] = 1.0f;
|
|
|
|
check(OutObservationVector.Num() == MaxSubElementSize + SchemaParameters.Elements.Num());
|
|
return;
|
|
}
|
|
|
|
case EType::OrInclusive:
|
|
{
|
|
// Check all indices are in range
|
|
|
|
const FSchemaOrInclusiveParameters SchemaParameters = Schema.GetOrInclusive(SchemaElement);
|
|
const FObjectOrInclusiveParameters ObjectParameters = Object.GetOrInclusive(ObjectElement);
|
|
check(ObjectParameters.Elements.Num() <= SchemaParameters.Elements.Num());
|
|
|
|
// Update sub-elements
|
|
|
|
int32 SubElementOffset = 0;
|
|
|
|
for (int32 SchemaElementIdx = 0; SchemaElementIdx < SchemaParameters.Elements.Num(); SchemaElementIdx++)
|
|
{
|
|
const int32 SubElementSize = Schema.GetObservationVectorSize(SchemaParameters.Elements[SchemaElementIdx]);
|
|
const int32 ObjectElementIdx = ObjectParameters.ElementNames.Find(SchemaParameters.ElementNames[SchemaElementIdx]);
|
|
|
|
if (ObjectElementIdx != INDEX_NONE)
|
|
{
|
|
SetVectorFromObject(
|
|
OutObservationVector.Slice(SubElementOffset, SubElementSize),
|
|
Schema,
|
|
SchemaParameters.Elements[SchemaElementIdx],
|
|
Object,
|
|
ObjectParameters.Elements[ObjectElementIdx]);
|
|
}
|
|
|
|
SubElementOffset += SubElementSize;
|
|
}
|
|
|
|
// Set Mask
|
|
|
|
check(SubElementOffset + SchemaParameters.Elements.Num() == OutObservationVector.Num());
|
|
|
|
for (int32 ObjectElementIdx = 0; ObjectElementIdx < ObjectParameters.Elements.Num(); ObjectElementIdx++)
|
|
{
|
|
const int32 SchemaElementIdx = SchemaParameters.ElementNames.Find(ObjectParameters.ElementNames[ObjectElementIdx]);
|
|
check(SchemaElementIdx != INDEX_NONE);
|
|
|
|
OutObservationVector[SubElementOffset + SchemaElementIdx] = 1.0f;
|
|
}
|
|
|
|
return;
|
|
}
|
|
|
|
case EType::Array:
|
|
{
|
|
// Check number of array elements is correct
|
|
|
|
const FSchemaArrayParameters SchemaParameters = Schema.GetArray(SchemaElement);
|
|
const FObjectArrayParameters ObjectParameters = Object.GetArray(ObjectElement);
|
|
check(SchemaParameters.Num == ObjectParameters.Elements.Num());
|
|
|
|
// Update sub-elements
|
|
|
|
const int32 SubElementSize = Schema.GetObservationVectorSize(SchemaParameters.Element);
|
|
|
|
for (int32 ElementIdx = 0; ElementIdx < SchemaParameters.Num; ElementIdx++)
|
|
{
|
|
SetVectorFromObject(
|
|
OutObservationVector.Slice(ElementIdx * SubElementSize, SubElementSize),
|
|
Schema,
|
|
SchemaParameters.Element,
|
|
Object,
|
|
ObjectParameters.Elements[ElementIdx]);
|
|
}
|
|
|
|
return;
|
|
}
|
|
|
|
case EType::Set:
|
|
{
|
|
// Check number of set elements is correct
|
|
|
|
const FSchemaSetParameters SchemaParameters = Schema.GetSet(SchemaElement);
|
|
const FObjectSetParameters ObjectParameters = Object.GetSet(ObjectElement);
|
|
check(SchemaParameters.MaxNum >= ObjectParameters.Elements.Num());
|
|
|
|
// Update sub-elements
|
|
|
|
int32 SubElementOffset = 0;
|
|
|
|
const int32 SubElementSize = Schema.GetObservationVectorSize(SchemaParameters.Element);
|
|
|
|
for (int32 ElementIdx = 0; ElementIdx < ObjectParameters.Elements.Num(); ElementIdx++)
|
|
{
|
|
SetVectorFromObject(
|
|
OutObservationVector.Slice(SubElementOffset, SubElementSize),
|
|
Schema,
|
|
SchemaParameters.Element,
|
|
Object,
|
|
ObjectParameters.Elements[ElementIdx]);
|
|
|
|
SubElementOffset += SubElementSize;
|
|
}
|
|
|
|
SubElementOffset = SubElementSize * SchemaParameters.MaxNum;
|
|
|
|
// Set Mask
|
|
|
|
Array::Set(OutObservationVector.Slice(SubElementOffset, ObjectParameters.Elements.Num()), 1.0f);
|
|
|
|
check(SubElementOffset + SchemaParameters.MaxNum == OutObservationVector.Num());
|
|
return;
|
|
}
|
|
|
|
case EType::Encoding:
|
|
{
|
|
const FSchemaEncodingParameters SchemaParameters = Schema.GetEncoding(SchemaElement);
|
|
const FObjectEncodingParameters ObjectParameters = Object.GetEncoding(ObjectElement);
|
|
|
|
SetVectorFromObject(
|
|
OutObservationVector,
|
|
Schema,
|
|
SchemaParameters.Element,
|
|
Object,
|
|
ObjectParameters.Element);
|
|
|
|
return;
|
|
}
|
|
|
|
default:
|
|
{
|
|
checkNoEntry();
|
|
return;
|
|
}
|
|
|
|
}
|
|
}
|
|
|
|
void GetObjectFromVector(
|
|
FObject& OutObject,
|
|
FObjectElement& OutObjectElement,
|
|
const FSchema& Schema,
|
|
const FSchemaElement SchemaElement,
|
|
const TLearningArrayView<1, const float> ObservationVector)
|
|
{
|
|
check(Schema.IsValid(SchemaElement));
|
|
|
|
// Check that the types match
|
|
|
|
const EType SchemaElementType = Schema.GetType(SchemaElement);
|
|
const FName SchemaElementTag = Schema.GetTag(SchemaElement);
|
|
|
|
// Get Observation Vector Size
|
|
|
|
const int32 ObservationVectorSize = ObservationVector.Num();
|
|
check(ObservationVectorSize == Schema.GetObservationVectorSize(SchemaElement));
|
|
|
|
// Logic for each specific element type
|
|
|
|
switch (SchemaElementType)
|
|
{
|
|
|
|
case EType::Null:
|
|
{
|
|
OutObjectElement = OutObject.CreateNull(SchemaElementTag);
|
|
return;
|
|
}
|
|
|
|
case EType::Continuous:
|
|
{
|
|
const FSchemaContinuousParameters SchemaParameters = Schema.GetContinuous(SchemaElement);
|
|
check(ObservationVectorSize == SchemaParameters.Num);
|
|
|
|
const int32 ValueNum = SchemaParameters.Num;
|
|
const float ValueScale = FMath::Max(SchemaParameters.Scale, UE_SMALL_NUMBER);
|
|
|
|
TLearningArray<1, float, TInlineAllocator<32>> ObservationValues;
|
|
ObservationValues.SetNumUninitialized({ ValueNum });
|
|
for (int32 ValueIdx = 0; ValueIdx < ValueNum; ValueIdx++)
|
|
{
|
|
ObservationValues[ValueIdx] = ValueScale * ObservationVector[ValueIdx];
|
|
}
|
|
|
|
OutObjectElement = OutObject.CreateContinuous({ MakeArrayView(ObservationValues.GetData(), ObservationValues.Num()) }, SchemaElementTag);
|
|
return;
|
|
}
|
|
|
|
case EType::DiscreteExclusive:
|
|
{
|
|
check(ObservationVectorSize == Schema.GetDiscreteExclusive(SchemaElement).Num);
|
|
|
|
// Find Index
|
|
int32 ExclusiveIndex = INDEX_NONE;
|
|
for (int32 Idx = 0; Idx < ObservationVectorSize; Idx++)
|
|
{
|
|
check(ObservationVector[Idx] == 0.0f || ObservationVector[Idx] == 1.0f);
|
|
if (ObservationVector[Idx])
|
|
{
|
|
ExclusiveIndex = Idx;
|
|
break;
|
|
}
|
|
}
|
|
check(ExclusiveIndex != INDEX_NONE);
|
|
|
|
OutObjectElement = OutObject.CreateDiscreteExclusive({ ExclusiveIndex }, SchemaElementTag);
|
|
return;
|
|
}
|
|
|
|
case EType::DiscreteInclusive:
|
|
{
|
|
check(ObservationVectorSize == Schema.GetDiscreteInclusive(SchemaElement).Num);
|
|
|
|
// Find Indices
|
|
TArray<int32, TInlineAllocator<8>> InclusiveIndices;
|
|
InclusiveIndices.Reserve(ObservationVectorSize);
|
|
for (int32 Idx = 0; Idx < ObservationVectorSize; Idx++)
|
|
{
|
|
check(ObservationVector[Idx] == 0.0f || ObservationVector[Idx] == 1.0f);
|
|
if (ObservationVector[Idx])
|
|
{
|
|
InclusiveIndices.Add(Idx);
|
|
}
|
|
}
|
|
|
|
OutObjectElement = OutObject.CreateDiscreteInclusive({ InclusiveIndices }, SchemaElementTag);
|
|
return;
|
|
}
|
|
|
|
|
|
case EType::NamedDiscreteExclusive:
|
|
{
|
|
const TArrayView<const FName> SchemaNames = Schema.GetNamedDiscreteExclusive(SchemaElement).ElementNames;
|
|
check(ObservationVectorSize == Schema.GetNamedDiscreteExclusive(SchemaElement).ElementNames.Num());
|
|
|
|
// Find Name
|
|
FName ExclusiveName = NAME_None;
|
|
for (int32 Idx = 0; Idx < ObservationVectorSize; Idx++)
|
|
{
|
|
check(ObservationVector[Idx] == 0.0f || ObservationVector[Idx] == 1.0f);
|
|
if (ObservationVector[Idx])
|
|
{
|
|
ExclusiveName = SchemaNames[Idx];
|
|
break;
|
|
}
|
|
}
|
|
check(ExclusiveName != NAME_None);
|
|
|
|
OutObjectElement = OutObject.CreateNamedDiscreteExclusive({ ExclusiveName }, SchemaElementTag);
|
|
return;
|
|
}
|
|
|
|
case EType::NamedDiscreteInclusive:
|
|
{
|
|
const TArrayView<const FName> SchemaNames = Schema.GetNamedDiscreteInclusive(SchemaElement).ElementNames;
|
|
check(ObservationVectorSize == Schema.GetNamedDiscreteInclusive(SchemaElement).ElementNames.Num());
|
|
|
|
// Find Names
|
|
TArray<FName, TInlineAllocator<8>> InclusiveNames;
|
|
InclusiveNames.Reserve(ObservationVectorSize);
|
|
for (int32 Idx = 0; Idx < ObservationVectorSize; Idx++)
|
|
{
|
|
check(ObservationVector[Idx] == 0.0f || ObservationVector[Idx] == 1.0f);
|
|
if (ObservationVector[Idx])
|
|
{
|
|
InclusiveNames.Add(SchemaNames[Idx]);
|
|
}
|
|
}
|
|
|
|
OutObjectElement = OutObject.CreateNamedDiscreteInclusive({ InclusiveNames }, SchemaElementTag);
|
|
return;
|
|
}
|
|
|
|
case EType::And:
|
|
{
|
|
const FSchemaAndParameters Parameters = Schema.GetAnd(SchemaElement);
|
|
|
|
// Create Sub-elements
|
|
|
|
TArray<FObjectElement, TInlineAllocator<8>> SubElements;
|
|
SubElements.SetNumUninitialized(Parameters.Elements.Num());
|
|
|
|
int32 SubElementOffset = 0;
|
|
for (int32 SchemaElementIdx = 0; SchemaElementIdx < Parameters.Elements.Num(); SchemaElementIdx++)
|
|
{
|
|
const int32 SubElementSize = Schema.GetObservationVectorSize(Parameters.Elements[SchemaElementIdx]);
|
|
|
|
GetObjectFromVector(
|
|
OutObject,
|
|
SubElements[SchemaElementIdx],
|
|
Schema,
|
|
Parameters.Elements[SchemaElementIdx],
|
|
ObservationVector.Slice(SubElementOffset, SubElementSize));
|
|
|
|
SubElementOffset += SubElementSize;
|
|
}
|
|
check(SubElementOffset == ObservationVectorSize);
|
|
|
|
OutObjectElement = OutObject.CreateAnd({ Parameters.ElementNames, SubElements }, SchemaElementTag);
|
|
return;
|
|
}
|
|
|
|
case EType::OrExclusive:
|
|
{
|
|
const FSchemaOrExclusiveParameters Parameters = Schema.GetOrExclusive(SchemaElement);
|
|
|
|
// Find active element
|
|
|
|
const int32 MaxSubElementSize = Private::GetMaxObservationVectorSize(Schema, Parameters.Elements);
|
|
|
|
int32 SchemaElementIndex = INDEX_NONE;
|
|
for (int32 SubElementIdx = 0; SubElementIdx < Parameters.Elements.Num(); SubElementIdx++)
|
|
{
|
|
check(ObservationVector[MaxSubElementSize + SubElementIdx] == 0.0f || ObservationVector[MaxSubElementSize + SubElementIdx] == 1.0f);
|
|
if (ObservationVector[MaxSubElementSize + SubElementIdx])
|
|
{
|
|
SchemaElementIndex = SubElementIdx;
|
|
break;
|
|
}
|
|
}
|
|
check(SchemaElementIndex != INDEX_NONE);
|
|
|
|
// Create sub-element
|
|
|
|
const int32 SubElementSize = Schema.GetObservationVectorSize(Parameters.Elements[SchemaElementIndex]);
|
|
|
|
FObjectElement SubElement;
|
|
GetObjectFromVector(
|
|
OutObject,
|
|
SubElement,
|
|
Schema,
|
|
Parameters.Elements[SchemaElementIndex],
|
|
ObservationVector.Slice(0, SubElementSize));
|
|
|
|
OutObjectElement = OutObject.CreateOrExclusive({ Parameters.ElementNames[SchemaElementIndex], SubElement }, SchemaElementTag);
|
|
return;
|
|
}
|
|
|
|
case EType::OrInclusive:
|
|
{
|
|
const FSchemaOrInclusiveParameters Parameters = Schema.GetOrInclusive(SchemaElement);
|
|
|
|
// Find total sub-element size
|
|
|
|
const int32 TotalSubElementSize = Private::GetTotalObservationVectorSize(Schema, Parameters.Elements);
|
|
|
|
// Create sub-elements
|
|
|
|
TArray<FName, TInlineAllocator<8>> SubElementNames;
|
|
TArray<FObjectElement, TInlineAllocator<8>> SubElements;
|
|
SubElementNames.Reserve(Parameters.Elements.Num());
|
|
SubElements.Reserve(Parameters.Elements.Num());
|
|
|
|
int32 SubElementOffset = 0;
|
|
for (int32 SubElementIdx = 0; SubElementIdx < Parameters.Elements.Num(); SubElementIdx++)
|
|
{
|
|
const int32 SubElementSize = Schema.GetObservationVectorSize(Parameters.Elements[SubElementIdx]);
|
|
|
|
check(
|
|
ObservationVector[TotalSubElementSize + SubElementIdx] == 0.0f ||
|
|
ObservationVector[TotalSubElementSize + SubElementIdx] == 1.0f);
|
|
|
|
if (ObservationVector[TotalSubElementSize + SubElementIdx] == 1.0f)
|
|
{
|
|
FObjectElement SubElement;
|
|
GetObjectFromVector(
|
|
OutObject,
|
|
SubElement,
|
|
Schema,
|
|
Parameters.Elements[SubElementIdx],
|
|
ObservationVector.Slice(SubElementOffset, SubElementSize));
|
|
|
|
SubElementNames.Add(Parameters.ElementNames[SubElementIdx]);
|
|
SubElements.Add(SubElement);
|
|
}
|
|
|
|
SubElementOffset += SubElementSize;
|
|
}
|
|
check(SubElementOffset + Parameters.Elements.Num() == ObservationVectorSize);
|
|
|
|
OutObjectElement = OutObject.CreateOrInclusive({ SubElementNames, SubElements }, SchemaElementTag);
|
|
return;
|
|
}
|
|
|
|
case EType::Array:
|
|
{
|
|
const FSchemaArrayParameters Parameters = Schema.GetArray(SchemaElement);
|
|
|
|
TArray<FObjectElement, TInlineAllocator<8>> SubElements;
|
|
SubElements.SetNumUninitialized(Parameters.Num);
|
|
|
|
// Create sub-elements
|
|
|
|
const int32 SubElementSize = Schema.GetObservationVectorSize(Parameters.Element);
|
|
|
|
for (int32 ElementIdx = 0; ElementIdx < Parameters.Num; ElementIdx++)
|
|
{
|
|
GetObjectFromVector(
|
|
OutObject,
|
|
SubElements[ElementIdx],
|
|
Schema,
|
|
Parameters.Element,
|
|
ObservationVector.Slice(ElementIdx * SubElementSize, SubElementSize));
|
|
}
|
|
|
|
OutObjectElement = OutObject.CreateArray({ SubElements }, SchemaElementTag);
|
|
return;
|
|
}
|
|
|
|
case EType::Set:
|
|
{
|
|
const FSchemaSetParameters Parameters = Schema.GetSet(SchemaElement);
|
|
|
|
const int32 SubElementSize = Schema.GetObservationVectorSize(Parameters.Element);
|
|
|
|
// Create sub-elements
|
|
|
|
TArray<FObjectElement, TInlineAllocator<8>> SubElements;
|
|
SubElements.Reserve(Parameters.MaxNum);
|
|
|
|
for (int32 SubElementIdx = 0; SubElementIdx < Parameters.MaxNum; SubElementIdx++)
|
|
{
|
|
check(
|
|
ObservationVector[SubElementSize * Parameters.MaxNum + SubElementIdx] == 0.0f ||
|
|
ObservationVector[SubElementSize * Parameters.MaxNum + SubElementIdx] == 1.0f);
|
|
|
|
if (ObservationVector[SubElementSize * Parameters.MaxNum + SubElementIdx] == 0.0f)
|
|
{
|
|
break;
|
|
}
|
|
|
|
FObjectElement SubElement;
|
|
|
|
GetObjectFromVector(
|
|
OutObject,
|
|
SubElement,
|
|
Schema,
|
|
Parameters.Element,
|
|
ObservationVector.Slice(SubElementIdx * SubElementSize, SubElementSize));
|
|
|
|
SubElements.Add(SubElement);
|
|
}
|
|
|
|
OutObjectElement = OutObject.CreateSet({ SubElements }, SchemaElementTag);
|
|
return;
|
|
}
|
|
|
|
case EType::Encoding:
|
|
{
|
|
const FSchemaEncodingParameters Parameters = Schema.GetEncoding(SchemaElement);
|
|
|
|
FObjectElement SubElement;
|
|
GetObjectFromVector(
|
|
OutObject,
|
|
SubElement,
|
|
Schema,
|
|
Parameters.Element,
|
|
ObservationVector);
|
|
|
|
OutObjectElement = OutObject.CreateEncoding({ SubElement }, SchemaElementTag);
|
|
return;
|
|
}
|
|
|
|
default:
|
|
{
|
|
checkNoEntry();
|
|
OutObjectElement = FObjectElement();
|
|
return;
|
|
}
|
|
}
|
|
}
|
|
|
|
void AddGaussianNoiseToVector(
|
|
uint32& InOutRandomState,
|
|
TLearningArrayView<1, float> InOutObservationVector,
|
|
const FSchema& Schema,
|
|
const FSchemaElement SchemaElement,
|
|
const float NoiseScale)
|
|
{
|
|
check(Schema.IsValid(SchemaElement));
|
|
|
|
const EType SchemaElementType = Schema.GetType(SchemaElement);
|
|
|
|
const int32 ObservationVectorSize = InOutObservationVector.Num();
|
|
check(ObservationVectorSize == Schema.GetObservationVectorSize(SchemaElement));
|
|
|
|
switch (SchemaElementType)
|
|
{
|
|
|
|
case EType::Null:
|
|
{
|
|
return;
|
|
}
|
|
|
|
case EType::Continuous:
|
|
{
|
|
const FSchemaContinuousParameters SchemaParameters = Schema.GetContinuous(SchemaElement);
|
|
check(ObservationVectorSize == SchemaParameters.Num);
|
|
|
|
TLearningArray<1, float, TInlineAllocator<32>> NoiseValues;
|
|
NoiseValues.SetNumUninitialized({ SchemaParameters.Num });
|
|
Random::SampleGaussianArray(NoiseValues, InOutRandomState, 0.0f, NoiseScale);
|
|
for (int32 ValueIdx = 0; ValueIdx < SchemaParameters.Num; ValueIdx++)
|
|
{
|
|
InOutObservationVector[ValueIdx] += NoiseValues[ValueIdx];
|
|
}
|
|
|
|
return;
|
|
}
|
|
|
|
case EType::DiscreteExclusive:
|
|
{
|
|
return;
|
|
}
|
|
|
|
case EType::DiscreteInclusive:
|
|
{
|
|
return;
|
|
}
|
|
|
|
case EType::NamedDiscreteExclusive:
|
|
{
|
|
return;
|
|
}
|
|
|
|
case EType::NamedDiscreteInclusive:
|
|
{
|
|
return;
|
|
}
|
|
|
|
case EType::And:
|
|
{
|
|
const FSchemaAndParameters Parameters = Schema.GetAnd(SchemaElement);
|
|
|
|
int32 SubElementOffset = 0;
|
|
for (int32 SchemaElementIdx = 0; SchemaElementIdx < Parameters.Elements.Num(); SchemaElementIdx++)
|
|
{
|
|
const int32 SubElementSize = Schema.GetObservationVectorSize(Parameters.Elements[SchemaElementIdx]);
|
|
|
|
AddGaussianNoiseToVector(
|
|
InOutRandomState,
|
|
InOutObservationVector.Slice(SubElementOffset, SubElementSize),
|
|
Schema,
|
|
Parameters.Elements[SchemaElementIdx],
|
|
NoiseScale);
|
|
|
|
SubElementOffset += SubElementSize;
|
|
}
|
|
check(SubElementOffset == ObservationVectorSize);
|
|
|
|
return;
|
|
}
|
|
|
|
case EType::OrExclusive:
|
|
{
|
|
const FSchemaOrExclusiveParameters Parameters = Schema.GetOrExclusive(SchemaElement);
|
|
const int32 MaxSubElementSize = Private::GetMaxObservationVectorSize(Schema, Parameters.Elements);
|
|
|
|
int32 SchemaElementIndex = INDEX_NONE;
|
|
for (int32 SubElementIdx = 0; SubElementIdx < Parameters.Elements.Num(); SubElementIdx++)
|
|
{
|
|
check(InOutObservationVector[MaxSubElementSize + SubElementIdx] == 0.0f || InOutObservationVector[MaxSubElementSize + SubElementIdx] == 1.0f);
|
|
if (InOutObservationVector[MaxSubElementSize + SubElementIdx])
|
|
{
|
|
SchemaElementIndex = SubElementIdx;
|
|
break;
|
|
}
|
|
}
|
|
check(SchemaElementIndex != INDEX_NONE);
|
|
|
|
const int32 SubElementSize = Schema.GetObservationVectorSize(Parameters.Elements[SchemaElementIndex]);
|
|
|
|
AddGaussianNoiseToVector(
|
|
InOutRandomState,
|
|
InOutObservationVector.Slice(0, SubElementSize),
|
|
Schema,
|
|
Parameters.Elements[SchemaElementIndex],
|
|
NoiseScale);
|
|
|
|
return;
|
|
}
|
|
|
|
case EType::OrInclusive:
|
|
{
|
|
const FSchemaOrInclusiveParameters Parameters = Schema.GetOrInclusive(SchemaElement);
|
|
const int32 TotalSubElementSize = Private::GetTotalObservationVectorSize(Schema, Parameters.Elements);
|
|
|
|
int32 SubElementOffset = 0;
|
|
for (int32 SubElementIdx = 0; SubElementIdx < Parameters.Elements.Num(); SubElementIdx++)
|
|
{
|
|
const int32 SubElementSize = Schema.GetObservationVectorSize(Parameters.Elements[SubElementIdx]);
|
|
|
|
check(
|
|
InOutObservationVector[TotalSubElementSize + SubElementIdx] == 0.0f ||
|
|
InOutObservationVector[TotalSubElementSize + SubElementIdx] == 1.0f);
|
|
|
|
if (InOutObservationVector[TotalSubElementSize + SubElementIdx] == 1.0f)
|
|
{
|
|
AddGaussianNoiseToVector(
|
|
InOutRandomState,
|
|
InOutObservationVector.Slice(SubElementOffset, SubElementSize),
|
|
Schema,
|
|
Parameters.Elements[SubElementIdx],
|
|
NoiseScale);
|
|
}
|
|
|
|
SubElementOffset += SubElementSize;
|
|
}
|
|
check(SubElementOffset + Parameters.Elements.Num() == ObservationVectorSize);
|
|
|
|
return;
|
|
}
|
|
|
|
case EType::Array:
|
|
{
|
|
const FSchemaArrayParameters Parameters = Schema.GetArray(SchemaElement);
|
|
const int32 SubElementSize = Schema.GetObservationVectorSize(Parameters.Element);
|
|
|
|
for (int32 ElementIdx = 0; ElementIdx < Parameters.Num; ElementIdx++)
|
|
{
|
|
AddGaussianNoiseToVector(
|
|
InOutRandomState,
|
|
InOutObservationVector.Slice(ElementIdx* SubElementSize, SubElementSize),
|
|
Schema,
|
|
Parameters.Element,
|
|
NoiseScale);
|
|
}
|
|
|
|
return;
|
|
}
|
|
|
|
case EType::Set:
|
|
{
|
|
const FSchemaSetParameters Parameters = Schema.GetSet(SchemaElement);
|
|
const int32 SubElementSize = Schema.GetObservationVectorSize(Parameters.Element);
|
|
|
|
for (int32 SubElementIdx = 0; SubElementIdx < Parameters.MaxNum; SubElementIdx++)
|
|
{
|
|
check(
|
|
InOutObservationVector[SubElementSize * Parameters.MaxNum + SubElementIdx] == 0.0f ||
|
|
InOutObservationVector[SubElementSize * Parameters.MaxNum + SubElementIdx] == 1.0f);
|
|
|
|
if (InOutObservationVector[SubElementSize * Parameters.MaxNum + SubElementIdx] == 0.0f)
|
|
{
|
|
break;
|
|
}
|
|
|
|
AddGaussianNoiseToVector(
|
|
InOutRandomState,
|
|
InOutObservationVector.Slice(SubElementIdx * SubElementSize, SubElementSize),
|
|
Schema,
|
|
Parameters.Element,
|
|
NoiseScale);
|
|
}
|
|
|
|
return;
|
|
}
|
|
|
|
case EType::Encoding:
|
|
{
|
|
AddGaussianNoiseToVector(
|
|
InOutRandomState,
|
|
InOutObservationVector,
|
|
Schema,
|
|
Schema.GetEncoding(SchemaElement).Element,
|
|
NoiseScale);
|
|
|
|
return;
|
|
}
|
|
|
|
default:
|
|
{
|
|
checkNoEntry();
|
|
return;
|
|
}
|
|
}
|
|
}
|
|
|
|
} |