// Copyright Epic Games, Inc. All Rights Reserved. #include "LearningObservation.h" #include "LearningRandom.h" #include "NNERuntimeBasicCpuBuilder.h" namespace UE::Learning::Observation { namespace Private { static inline bool ContainsDuplicates(const TArrayView ElementNames) { TSet, TInlineSetAllocator<32>> ElementNameSet; ElementNameSet.Append(ElementNames); return ElementNames.Num() != ElementNameSet.Num(); } static inline bool CheckAllValid(const FSchema& Schema, const TArrayView Elements) { for (const FSchemaElement SubElement : Elements) { if (!Schema.IsValid(SubElement)) { return false; } } return true; } static inline int32 GetMaxObservationVectorSize(const FSchema& Schema, const TArrayView Elements) { int32 Size = 0; for (const FSchemaElement SubElement : Elements) { Size = FMath::Max(Size, Schema.GetObservationVectorSize(SubElement)); } return Size; } static inline int32 GetTotalObservationVectorSize(const FSchema& Schema, const TArrayView Elements) { int32 Size = 0; for (const FSchemaElement SubElement : Elements) { Size += Schema.GetObservationVectorSize(SubElement); } return Size; } static inline int32 GetTotalEncodedObservationVectorSize(const FSchema& Schema, const TArrayView Elements) { int32 Size = 0; for (const FSchemaElement SubElement : Elements) { Size += Schema.GetEncodedVectorSize(SubElement); } return Size; } static inline bool CheckAllValid(const FObject& Object, const TArrayView Elements) { for (const FObjectElement SubElement : Elements) { if (!Object.IsValid(SubElement)) { return false; } } return true; } } FSchemaElement FSchema::CreateNull(const FName Tag) { const int32 Index = Types.Add(EType::Null); Tags.Add(Tag); ObservationVectorSizes.Add(0); EncodedVectorSizes.Add(0); TypeDataIndices.Add(INDEX_NONE); return { Index, Generation }; } FSchemaElement FSchema::CreateContinuous(const FSchemaContinuousParameters Parameters, const FName Tag) { check(Parameters.Num >= 0); check(Parameters.Scale >= 0.0f); FContinuousData ElementData; 
ElementData.Num = Parameters.Num; ElementData.Scale = Parameters.Scale; const int32 Index = Types.Add(EType::Continuous); Tags.Add(Tag); ObservationVectorSizes.Add(Parameters.Num); EncodedVectorSizes.Add(Parameters.Num); TypeDataIndices.Add(ContinuousData.Add(ElementData)); return { Index, Generation }; } FSchemaElement FSchema::CreateDiscreteExclusive(const FSchemaDiscreteExclusiveParameters Parameters, const FName Tag) { FDiscreteExclusiveData ElementData; ElementData.Num = Parameters.Num; const int32 Index = Types.Add(EType::DiscreteExclusive); Tags.Add(Tag); ObservationVectorSizes.Add(Parameters.Num); EncodedVectorSizes.Add(Parameters.Num); TypeDataIndices.Add(DiscreteExclusiveData.Add(ElementData)); return { Index, Generation }; } FSchemaElement FSchema::CreateDiscreteInclusive(const FSchemaDiscreteInclusiveParameters Parameters, const FName Tag) { FDiscreteInclusiveData ElementData; ElementData.Num = Parameters.Num; const int32 Index = Types.Add(EType::DiscreteInclusive); Tags.Add(Tag); ObservationVectorSizes.Add(Parameters.Num); EncodedVectorSizes.Add(Parameters.Num); TypeDataIndices.Add(DiscreteInclusiveData.Add(ElementData)); return { Index, Generation }; } FSchemaElement FSchema::CreateNamedDiscreteExclusive(const FSchemaNamedDiscreteExclusiveParameters Parameters, const FName Tag) { check(!Private::ContainsDuplicates(Parameters.ElementNames)); FNamedDiscreteExclusiveData ElementData; ElementData.Num = Parameters.ElementNames.Num(); ElementData.ElementsOffset = SubElementObjects.Num(); SubElementNames.Append(Parameters.ElementNames); for (int32 Idx = 0; Idx < ElementData.Num; Idx++) { SubElementObjects.Add(FSchemaElement()); } const int32 Index = Types.Add(EType::NamedDiscreteExclusive); Tags.Add(Tag); ObservationVectorSizes.Add(Parameters.ElementNames.Num()); EncodedVectorSizes.Add(Parameters.ElementNames.Num()); TypeDataIndices.Add(NamedDiscreteExclusiveData.Add(ElementData)); return { Index, Generation }; } FSchemaElement 
FSchema::CreateNamedDiscreteInclusive(const FSchemaNamedDiscreteInclusiveParameters Parameters, const FName Tag) { check(!Private::ContainsDuplicates(Parameters.ElementNames)); FNamedDiscreteInclusiveData ElementData; ElementData.Num = Parameters.ElementNames.Num(); ElementData.ElementsOffset = SubElementObjects.Num(); SubElementNames.Append(Parameters.ElementNames); for (int32 Idx = 0; Idx < ElementData.Num; Idx++) { SubElementObjects.Add(FSchemaElement()); } const int32 Index = Types.Add(EType::NamedDiscreteInclusive); Tags.Add(Tag); ObservationVectorSizes.Add(Parameters.ElementNames.Num()); EncodedVectorSizes.Add(Parameters.ElementNames.Num()); TypeDataIndices.Add(NamedDiscreteInclusiveData.Add(ElementData)); return { Index, Generation }; } FSchemaElement FSchema::CreateAnd(const FSchemaAndParameters Parameters, const FName Tag) { check(Parameters.Elements.Num() == Parameters.ElementNames.Num()); check(!Private::ContainsDuplicates(Parameters.ElementNames)); check(Private::CheckAllValid(*this, Parameters.Elements)); FAndData ElementData; ElementData.Num = Parameters.Elements.Num(); ElementData.ElementsOffset = SubElementObjects.Num(); SubElementNames.Append(Parameters.ElementNames); SubElementObjects.Append(Parameters.Elements); const int32 Index = Types.Add(EType::And); Tags.Add(Tag); ObservationVectorSizes.Add(Private::GetTotalObservationVectorSize(*this, Parameters.Elements)); EncodedVectorSizes.Add(Private::GetTotalEncodedObservationVectorSize(*this, Parameters.Elements)); TypeDataIndices.Add(AndData.Add(ElementData)); return { Index, Generation }; } FSchemaElement FSchema::CreateOrExclusive(const FSchemaOrExclusiveParameters Parameters, const FName Tag) { check(Parameters.Elements.Num() == Parameters.ElementNames.Num()); check(!Private::ContainsDuplicates(Parameters.ElementNames)); check(Private::CheckAllValid(*this, Parameters.Elements)); FOrExclusiveData ElementData; ElementData.Num = Parameters.Elements.Num(); ElementData.ElementsOffset = 
SubElementObjects.Num(); ElementData.EncodingSize = Parameters.EncodingSize; SubElementNames.Append(Parameters.ElementNames); SubElementObjects.Append(Parameters.Elements); const int32 Index = Types.Add(EType::OrExclusive); Tags.Add(Tag); ObservationVectorSizes.Add(Private::GetMaxObservationVectorSize(*this, Parameters.Elements) + Parameters.Elements.Num()); EncodedVectorSizes.Add(Parameters.EncodingSize + Parameters.Elements.Num()); TypeDataIndices.Add(OrExclusiveData.Add(ElementData)); return { Index, Generation }; } FSchemaElement FSchema::CreateOrInclusive(const FSchemaOrInclusiveParameters Parameters, const FName Tag) { check(Parameters.Elements.Num() == Parameters.ElementNames.Num()); check(!Private::ContainsDuplicates(Parameters.ElementNames)); check(Private::CheckAllValid(*this, Parameters.Elements)); FOrInclusiveData ElementData; ElementData.Num = Parameters.Elements.Num(); ElementData.ElementsOffset = SubElementObjects.Num(); ElementData.AttentionEncodingSize = Parameters.AttentionEncodingSize; ElementData.AttentionHeadNum = Parameters.AttentionHeadNum; ElementData.ValueEncodingSize = Parameters.ValueEncodingSize; SubElementNames.Append(Parameters.ElementNames); SubElementObjects.Append(Parameters.Elements); const int32 Index = Types.Add(EType::OrInclusive); Tags.Add(Tag); ObservationVectorSizes.Add(Private::GetTotalObservationVectorSize(*this, Parameters.Elements) + Parameters.Elements.Num()); EncodedVectorSizes.Add(Parameters.AttentionHeadNum * Parameters.ValueEncodingSize + Parameters.Elements.Num()); TypeDataIndices.Add(OrInclusiveData.Add(ElementData)); return { Index, Generation }; } FSchemaElement FSchema::CreateArray(const FSchemaArrayParameters Parameters, const FName Tag) { check(Parameters.Num >= 0); check(IsValid(Parameters.Element)); FArrayData ElementData; ElementData.Num = Parameters.Num; ElementData.ElementIndex = SubElementObjects.Num(); SubElementNames.Add(NAME_None); SubElementObjects.Add(Parameters.Element); const int32 Index = 
Types.Add(EType::Array); Tags.Add(Tag); ObservationVectorSizes.Add(GetObservationVectorSize(Parameters.Element) * Parameters.Num); EncodedVectorSizes.Add(GetEncodedVectorSize(Parameters.Element) * Parameters.Num); TypeDataIndices.Add(ArrayData.Add(ElementData)); return { Index, Generation }; } FSchemaElement FSchema::CreateSet(const FSchemaSetParameters Parameters, const FName Tag) { check(IsValid(Parameters.Element)); FSetData ElementData; ElementData.MaxNum = Parameters.MaxNum; ElementData.ElementIndex = SubElementObjects.Num(); ElementData.AttentionEncodingSize = Parameters.AttentionEncodingSize; ElementData.AttentionHeadNum = Parameters.AttentionHeadNum; ElementData.ValueEncodingSize = Parameters.ValueEncodingSize; SubElementNames.Add(NAME_None); SubElementObjects.Add(Parameters.Element); const int32 Index = Types.Add(EType::Set); Tags.Add(Tag); ObservationVectorSizes.Add(GetObservationVectorSize(Parameters.Element) * Parameters.MaxNum + Parameters.MaxNum); EncodedVectorSizes.Add(Parameters.ValueEncodingSize * Parameters.AttentionHeadNum + 1); TypeDataIndices.Add(SetData.Add(ElementData)); return { Index, Generation }; } FSchemaElement FSchema::CreateEncoding(const FSchemaEncodingParameters Parameters, const FName Tag) { check(IsValid(Parameters.Element)); FEncodingData ElementData; ElementData.ElementIndex = SubElementObjects.Num(); ElementData.EncodingSize = Parameters.EncodingSize; ElementData.LayerNum = Parameters.LayerNum; ElementData.ActivationFunction = Parameters.ActivationFunction; SubElementNames.Add(NAME_None); SubElementObjects.Add(Parameters.Element); const int32 Index = Types.Add(EType::Encoding); Tags.Add(Tag); ObservationVectorSizes.Add(GetObservationVectorSize(Parameters.Element)); EncodedVectorSizes.Add(Parameters.EncodingSize); TypeDataIndices.Add(EncodingData.Add(ElementData)); return { Index, Generation }; } bool FSchema::IsValid(const FSchemaElement Element) const { return Element.Generation == Generation && Element.Index != INDEX_NONE; } 
EType FSchema::GetType(const FSchemaElement Element) const { check(IsValid(Element)); return Types[Element.Index]; } FName FSchema::GetTag(const FSchemaElement Element) const { check(IsValid(Element)); return Tags[Element.Index]; } int32 FSchema::GetObservationVectorSize(const FSchemaElement Element) const { check(IsValid(Element)); return ObservationVectorSizes[Element.Index]; } int32 FSchema::GetEncodedVectorSize(const FSchemaElement Element) const { check(IsValid(Element)); return EncodedVectorSizes[Element.Index]; } FSchemaContinuousParameters FSchema::GetContinuous(const FSchemaElement Element) const { check(IsValid(Element) && GetType(Element) == EType::Continuous); const FContinuousData& ElementData = ContinuousData[TypeDataIndices[Element.Index]]; FSchemaContinuousParameters Parameters; Parameters.Num = ElementData.Num; Parameters.Scale = ElementData.Scale; return Parameters; } FSchemaDiscreteExclusiveParameters FSchema::GetDiscreteExclusive(const FSchemaElement Element) const { check(IsValid(Element) && GetType(Element) == EType::DiscreteExclusive); const FDiscreteExclusiveData& ElementData = DiscreteExclusiveData[TypeDataIndices[Element.Index]]; FSchemaDiscreteExclusiveParameters Parameters; Parameters.Num = ElementData.Num; return Parameters; } FSchemaDiscreteInclusiveParameters FSchema::GetDiscreteInclusive(const FSchemaElement Element) const { check(IsValid(Element) && GetType(Element) == EType::DiscreteInclusive); const FDiscreteInclusiveData& ElementData = DiscreteInclusiveData[TypeDataIndices[Element.Index]]; FSchemaDiscreteInclusiveParameters Parameters; Parameters.Num = ElementData.Num; return Parameters; } FSchemaNamedDiscreteExclusiveParameters FSchema::GetNamedDiscreteExclusive(const FSchemaElement Element) const { check(IsValid(Element) && GetType(Element) == EType::NamedDiscreteExclusive); const FNamedDiscreteExclusiveData& ElementData = NamedDiscreteExclusiveData[TypeDataIndices[Element.Index]]; FSchemaNamedDiscreteExclusiveParameters 
Parameters; Parameters.ElementNames = TArrayView(SubElementNames.GetData() + ElementData.ElementsOffset, ElementData.Num); return Parameters; } FSchemaNamedDiscreteInclusiveParameters FSchema::GetNamedDiscreteInclusive(const FSchemaElement Element) const { check(IsValid(Element) && GetType(Element) == EType::NamedDiscreteInclusive); const FNamedDiscreteInclusiveData& ElementData = NamedDiscreteInclusiveData[TypeDataIndices[Element.Index]]; FSchemaNamedDiscreteInclusiveParameters Parameters; Parameters.ElementNames = TArrayView(SubElementNames.GetData() + ElementData.ElementsOffset, ElementData.Num); return Parameters; } FSchemaAndParameters FSchema::GetAnd(const FSchemaElement Element) const { check(IsValid(Element) && GetType(Element) == EType::And); const FAndData& ElementData = AndData[TypeDataIndices[Element.Index]]; FSchemaAndParameters Parameters; Parameters.ElementNames = TArrayView(SubElementNames.GetData() + ElementData.ElementsOffset, ElementData.Num); Parameters.Elements = TArrayView(SubElementObjects.GetData() + ElementData.ElementsOffset, ElementData.Num); return Parameters; } FSchemaOrExclusiveParameters FSchema::GetOrExclusive(const FSchemaElement Element) const { check(IsValid(Element) && GetType(Element) == EType::OrExclusive); const FOrExclusiveData& ElementData = OrExclusiveData[TypeDataIndices[Element.Index]]; FSchemaOrExclusiveParameters Parameters; Parameters.ElementNames = TArrayView(SubElementNames.GetData() + ElementData.ElementsOffset, ElementData.Num); Parameters.Elements = TArrayView(SubElementObjects.GetData() + ElementData.ElementsOffset, ElementData.Num); Parameters.EncodingSize = ElementData.EncodingSize; return Parameters; } FSchemaOrInclusiveParameters FSchema::GetOrInclusive(const FSchemaElement Element) const { check(IsValid(Element) && GetType(Element) == EType::OrInclusive); const FOrInclusiveData& ElementData = OrInclusiveData[TypeDataIndices[Element.Index]]; FSchemaOrInclusiveParameters Parameters; Parameters.ElementNames = 
TArrayView(SubElementNames.GetData() + ElementData.ElementsOffset, ElementData.Num); Parameters.Elements = TArrayView(SubElementObjects.GetData() + ElementData.ElementsOffset, ElementData.Num); Parameters.AttentionEncodingSize = ElementData.AttentionEncodingSize; Parameters.AttentionHeadNum = ElementData.AttentionHeadNum; Parameters.ValueEncodingSize = ElementData.ValueEncodingSize; return Parameters; } FSchemaArrayParameters FSchema::GetArray(const FSchemaElement Element) const { check(IsValid(Element) && GetType(Element) == EType::Array); const FArrayData& ElementData = ArrayData[TypeDataIndices[Element.Index]]; FSchemaArrayParameters Parameters; Parameters.Num = ElementData.Num; Parameters.Element = SubElementObjects[ElementData.ElementIndex]; return Parameters; } FSchemaSetParameters FSchema::GetSet(const FSchemaElement Element) const { check(IsValid(Element) && GetType(Element) == EType::Set); const FSetData& ElementData = SetData[TypeDataIndices[Element.Index]]; FSchemaSetParameters Parameters; Parameters.MaxNum = ElementData.MaxNum; Parameters.Element = SubElementObjects[ElementData.ElementIndex]; Parameters.AttentionEncodingSize = ElementData.AttentionEncodingSize; Parameters.AttentionHeadNum = ElementData.AttentionHeadNum; Parameters.ValueEncodingSize = ElementData.ValueEncodingSize; return Parameters; } FSchemaEncodingParameters FSchema::GetEncoding(const FSchemaElement Element) const { check(IsValid(Element) && GetType(Element) == EType::Encoding); const FEncodingData& ElementData = EncodingData[TypeDataIndices[Element.Index]]; FSchemaEncodingParameters Parameters; Parameters.Element = SubElementObjects[ElementData.ElementIndex]; Parameters.EncodingSize = ElementData.EncodingSize; Parameters.LayerNum = ElementData.LayerNum; Parameters.ActivationFunction = ElementData.ActivationFunction; return Parameters; } uint32 FSchema::GetGeneration() const { return Generation; } void FSchema::Empty() { Types.Empty(); Tags.Empty(); ObservationVectorSizes.Empty(); 
EncodedVectorSizes.Empty(); TypeDataIndices.Empty(); ContinuousData.Empty(); DiscreteExclusiveData.Empty(); DiscreteInclusiveData.Empty(); NamedDiscreteExclusiveData.Empty(); NamedDiscreteInclusiveData.Empty(); AndData.Empty(); OrExclusiveData.Empty(); OrInclusiveData.Empty(); ArrayData.Empty(); SetData.Empty(); EncodingData.Empty(); SubElementNames.Empty(); SubElementObjects.Empty(); Generation++; } bool FSchema::IsEmpty() const { return Types.IsEmpty(); } void FSchema::Reset() { Types.Reset(); Tags.Reset(); ObservationVectorSizes.Reset(); EncodedVectorSizes.Reset(); TypeDataIndices.Reset(); ContinuousData.Reset(); DiscreteExclusiveData.Reset(); DiscreteInclusiveData.Reset(); NamedDiscreteExclusiveData.Reset(); NamedDiscreteInclusiveData.Reset(); AndData.Reset(); OrExclusiveData.Reset(); OrInclusiveData.Reset(); ArrayData.Reset(); SetData.Reset(); EncodingData.Reset(); SubElementNames.Reset(); SubElementObjects.Reset(); Generation++; } FObjectElement FObject::CreateNull(const FName Tag) { const int32 Index = Types.Add(EType::Null); Tags.Add(Tag); ContinuousDataOffsets.Add(ContinuousValues.Num()); ContinuousDataNums.Add(0); DiscreteDataOffsets.Add(DiscreteValues.Num()); DiscreteDataNums.Add(0); SubElementDataOffsets.Add(SubElementObjects.Num()); SubElementDataNums.Add(0); return { Index, Generation }; } FObjectElement FObject::CreateContinuous(const FObjectContinuousParameters Parameters, const FName Tag) { const int32 Index = Types.Add(EType::Continuous); Tags.Add(Tag); ContinuousDataOffsets.Add(ContinuousValues.Num()); ContinuousDataNums.Add(Parameters.Values.Num()); DiscreteDataOffsets.Add(DiscreteValues.Num()); DiscreteDataNums.Add(0); SubElementDataOffsets.Add(SubElementObjects.Num()); SubElementDataNums.Add(0); ContinuousValues.Append(Parameters.Values); return { Index, Generation }; } FObjectElement FObject::CreateDiscreteExclusive(const FObjectDiscreteExclusiveParameters Parameters, const FName Tag) { const int32 Index = Types.Add(EType::DiscreteExclusive); 
Tags.Add(Tag); ContinuousDataOffsets.Add(ContinuousValues.Num()); ContinuousDataNums.Add(0); DiscreteDataOffsets.Add(DiscreteValues.Num()); DiscreteDataNums.Add(1); SubElementDataOffsets.Add(SubElementObjects.Num()); SubElementDataNums.Add(0); DiscreteValues.Add(Parameters.DiscreteIndex); return { Index, Generation }; } FObjectElement FObject::CreateDiscreteInclusive(const FObjectDiscreteInclusiveParameters Parameters, const FName Tag) { const int32 Index = Types.Add(EType::DiscreteInclusive); Tags.Add(Tag); ContinuousDataOffsets.Add(ContinuousValues.Num()); ContinuousDataNums.Add(0); DiscreteDataOffsets.Add(DiscreteValues.Num()); DiscreteDataNums.Add(Parameters.DiscreteIndices.Num()); SubElementDataOffsets.Add(SubElementObjects.Num()); SubElementDataNums.Add(0); DiscreteValues.Append(Parameters.DiscreteIndices); return { Index, Generation }; } FObjectElement FObject::CreateNamedDiscreteExclusive(const FObjectNamedDiscreteExclusiveParameters Parameters, const FName Tag) { const int32 Index = Types.Add(EType::NamedDiscreteExclusive); Tags.Add(Tag); ContinuousDataOffsets.Add(ContinuousValues.Num()); ContinuousDataNums.Add(0); DiscreteDataOffsets.Add(DiscreteValues.Num()); DiscreteDataNums.Add(0); SubElementDataOffsets.Add(SubElementObjects.Num()); SubElementDataNums.Add(1); SubElementNames.Add(Parameters.ElementName); SubElementObjects.Add(FObjectElement()); return { Index, Generation }; } FObjectElement FObject::CreateNamedDiscreteInclusive(const FObjectNamedDiscreteInclusiveParameters Parameters, const FName Tag) { check(!Private::ContainsDuplicates(Parameters.ElementNames)); const int32 Index = Types.Add(EType::NamedDiscreteInclusive); Tags.Add(Tag); ContinuousDataOffsets.Add(ContinuousValues.Num()); ContinuousDataNums.Add(0); DiscreteDataOffsets.Add(DiscreteValues.Num()); DiscreteDataNums.Add(0); SubElementDataOffsets.Add(SubElementObjects.Num()); SubElementDataNums.Add(Parameters.ElementNames.Num()); SubElementNames.Append(Parameters.ElementNames); for (int32 
Idx = 0; Idx < Parameters.ElementNames.Num(); Idx++) { SubElementObjects.Add(FObjectElement()); } return { Index, Generation }; } FObjectElement FObject::CreateAnd(const FObjectAndParameters Parameters, const FName Tag) { check(Parameters.Elements.Num() == Parameters.ElementNames.Num()); check(!Private::ContainsDuplicates(Parameters.ElementNames)); check(Private::CheckAllValid(*this, Parameters.Elements)); const int32 Index = Types.Add(EType::And); Tags.Add(Tag); ContinuousDataOffsets.Add(ContinuousValues.Num()); ContinuousDataNums.Add(0); DiscreteDataOffsets.Add(DiscreteValues.Num()); DiscreteDataNums.Add(0); SubElementDataOffsets.Add(SubElementObjects.Num()); SubElementDataNums.Add(Parameters.Elements.Num()); SubElementNames.Append(Parameters.ElementNames); SubElementObjects.Append(Parameters.Elements); return { Index, Generation }; } FObjectElement FObject::CreateOrExclusive(const FObjectOrExclusiveParameters Parameters, const FName Tag) { check(IsValid(Parameters.Element)); const int32 Index = Types.Add(EType::OrExclusive); Tags.Add(Tag); ContinuousDataOffsets.Add(ContinuousValues.Num()); ContinuousDataNums.Add(0); DiscreteDataOffsets.Add(DiscreteValues.Num()); DiscreteDataNums.Add(0); SubElementDataOffsets.Add(SubElementObjects.Num()); SubElementDataNums.Add(1); SubElementNames.Add(Parameters.ElementName); SubElementObjects.Add(Parameters.Element); return { Index, Generation }; } FObjectElement FObject::CreateOrInclusive(const FObjectOrInclusiveParameters Parameters, const FName Tag) { check(Parameters.Elements.Num() == Parameters.ElementNames.Num()); check(!Private::ContainsDuplicates(Parameters.ElementNames)); check(Private::CheckAllValid(*this, Parameters.Elements)); const int32 Index = Types.Add(EType::OrInclusive); Tags.Add(Tag); ContinuousDataOffsets.Add(ContinuousValues.Num()); ContinuousDataNums.Add(0); DiscreteDataOffsets.Add(DiscreteValues.Num()); DiscreteDataNums.Add(0); SubElementDataOffsets.Add(SubElementObjects.Num()); 
SubElementDataNums.Add(Parameters.Elements.Num()); SubElementNames.Append(Parameters.ElementNames); SubElementObjects.Append(Parameters.Elements); return { Index, Generation }; } FObjectElement FObject::CreateArray(const FObjectArrayParameters Parameters, const FName Tag) { check(Private::CheckAllValid(*this, Parameters.Elements)); const int32 Index = Types.Add(EType::Array); Tags.Add(Tag); ContinuousDataOffsets.Add(ContinuousValues.Num()); ContinuousDataNums.Add(0); DiscreteDataOffsets.Add(DiscreteValues.Num()); DiscreteDataNums.Add(0); SubElementDataOffsets.Add(SubElementObjects.Num()); SubElementDataNums.Add(Parameters.Elements.Num()); for (int32 ElementIdx = 0; ElementIdx < Parameters.Elements.Num(); ElementIdx++) { SubElementNames.Add(NAME_Name); } SubElementObjects.Append(Parameters.Elements); return { Index, Generation }; } FObjectElement FObject::CreateSet(const FObjectSetParameters Parameters, const FName Tag) { check(Private::CheckAllValid(*this, Parameters.Elements)); const int32 Index = Types.Add(EType::Set); Tags.Add(Tag); ContinuousDataOffsets.Add(ContinuousValues.Num()); ContinuousDataNums.Add(0); DiscreteDataOffsets.Add(DiscreteValues.Num()); DiscreteDataNums.Add(0); SubElementDataOffsets.Add(SubElementObjects.Num()); SubElementDataNums.Add(Parameters.Elements.Num()); for (int32 ElementIdx = 0; ElementIdx < Parameters.Elements.Num(); ElementIdx++) { SubElementNames.Add(NAME_Name); } SubElementObjects.Append(Parameters.Elements); return { Index, Generation }; } FObjectElement FObject::CreateEncoding(const FObjectEncodingParameters Parameters, const FName Tag) { check(IsValid(Parameters.Element)); const int32 Index = Types.Add(EType::Encoding); Tags.Add(Tag); ContinuousDataOffsets.Add(ContinuousValues.Num()); ContinuousDataNums.Add(0); DiscreteDataOffsets.Add(DiscreteValues.Num()); DiscreteDataNums.Add(0); SubElementDataOffsets.Add(SubElementObjects.Num()); SubElementDataNums.Add(1); SubElementNames.Add(NAME_Name); 
SubElementObjects.Add(Parameters.Element); return { Index, Generation }; } bool FObject::IsValid(const FObjectElement Element) const { return Element.Generation == Generation && Element.Index != INDEX_NONE; } EType FObject::GetType(const FObjectElement Element) const { check(IsValid(Element)); return Types[Element.Index]; } FName FObject::GetTag(const FObjectElement Element) const { check(IsValid(Element)); return Tags[Element.Index]; } FObjectContinuousParameters FObject::GetContinuous(const FObjectElement Element) const { check(IsValid(Element) && GetType(Element) == EType::Continuous); FObjectContinuousParameters Parameters; Parameters.Values = TArrayView(ContinuousValues.GetData() + ContinuousDataOffsets[Element.Index], ContinuousDataNums[Element.Index]); return Parameters; } FObjectDiscreteExclusiveParameters FObject::GetDiscreteExclusive(const FObjectElement Element) const { check(IsValid(Element) && GetType(Element) == EType::DiscreteExclusive); FObjectDiscreteExclusiveParameters Parameters; Parameters.DiscreteIndex = DiscreteValues[DiscreteDataOffsets[Element.Index]]; return Parameters; } FObjectDiscreteInclusiveParameters FObject::GetDiscreteInclusive(const FObjectElement Element) const { check(IsValid(Element) && GetType(Element) == EType::DiscreteInclusive); FObjectDiscreteInclusiveParameters Parameters; Parameters.DiscreteIndices = TArrayView(DiscreteValues.GetData() + DiscreteDataOffsets[Element.Index], DiscreteDataNums[Element.Index]); return Parameters; } FObjectNamedDiscreteExclusiveParameters FObject::GetNamedDiscreteExclusive(const FObjectElement Element) const { check(IsValid(Element) && GetType(Element) == EType::NamedDiscreteExclusive); FObjectNamedDiscreteExclusiveParameters Parameters; Parameters.ElementName = SubElementNames[SubElementDataOffsets[Element.Index]]; return Parameters; } FObjectNamedDiscreteInclusiveParameters FObject::GetNamedDiscreteInclusive(const FObjectElement Element) const { check(IsValid(Element) && GetType(Element) == 
EType::NamedDiscreteInclusive); FObjectNamedDiscreteInclusiveParameters Parameters; Parameters.ElementNames = TArrayView(SubElementNames.GetData() + SubElementDataOffsets[Element.Index], SubElementDataNums[Element.Index]); return Parameters; } FObjectAndParameters FObject::GetAnd(const FObjectElement Element) const { check(IsValid(Element) && GetType(Element) == EType::And); FObjectAndParameters Parameters; Parameters.ElementNames = TArrayView(SubElementNames.GetData() + SubElementDataOffsets[Element.Index], SubElementDataNums[Element.Index]); Parameters.Elements = TArrayView(SubElementObjects.GetData() + SubElementDataOffsets[Element.Index], SubElementDataNums[Element.Index]); return Parameters; } FObjectOrExclusiveParameters FObject::GetOrExclusive(const FObjectElement Element) const { check(IsValid(Element) && GetType(Element) == EType::OrExclusive); FObjectOrExclusiveParameters Parameters; Parameters.ElementName = SubElementNames[SubElementDataOffsets[Element.Index]]; Parameters.Element = SubElementObjects[SubElementDataOffsets[Element.Index]]; return Parameters; } FObjectOrInclusiveParameters FObject::GetOrInclusive(const FObjectElement Element) const { check(IsValid(Element) && GetType(Element) == EType::OrInclusive); FObjectOrInclusiveParameters Parameters; Parameters.ElementNames = TArrayView(SubElementNames.GetData() + SubElementDataOffsets[Element.Index], SubElementDataNums[Element.Index]); Parameters.Elements = TArrayView(SubElementObjects.GetData() + SubElementDataOffsets[Element.Index], SubElementDataNums[Element.Index]); return Parameters; } FObjectArrayParameters FObject::GetArray(const FObjectElement Element) const { check(IsValid(Element) && GetType(Element) == EType::Array); FObjectArrayParameters Parameters; Parameters.Elements = TArrayView(SubElementObjects.GetData() + SubElementDataOffsets[Element.Index], SubElementDataNums[Element.Index]); return Parameters; } FObjectSetParameters FObject::GetSet(const FObjectElement Element) const { 
check(IsValid(Element) && GetType(Element) == EType::Set); FObjectSetParameters Parameters; Parameters.Elements = TArrayView(SubElementObjects.GetData() + SubElementDataOffsets[Element.Index], SubElementDataNums[Element.Index]); return Parameters; } FObjectEncodingParameters FObject::GetEncoding(const FObjectElement Element) const { check(IsValid(Element) && GetType(Element) == EType::Encoding); FObjectEncodingParameters Parameters; Parameters.Element = SubElementObjects[SubElementDataOffsets[Element.Index]]; return Parameters; } uint32 FObject::GetGeneration() const { return Generation; } void FObject::Empty() { Types.Empty(); Tags.Empty(); ContinuousDataOffsets.Empty(); ContinuousDataNums.Empty(); DiscreteDataOffsets.Reset(); DiscreteDataNums.Reset(); SubElementDataOffsets.Empty(); SubElementDataNums.Empty(); ContinuousValues.Empty(); DiscreteValues.Empty(); SubElementNames.Empty(); SubElementObjects.Empty(); Generation++; } bool FObject::IsEmpty() const { return Types.IsEmpty(); } void FObject::Reset() { Types.Reset(); Tags.Reset(); ContinuousDataOffsets.Reset(); ContinuousDataNums.Reset(); DiscreteDataOffsets.Reset(); DiscreteDataNums.Reset(); SubElementDataOffsets.Reset(); SubElementDataNums.Reset(); ContinuousValues.Reset(); DiscreteValues.Reset(); SubElementNames.Reset(); SubElementObjects.Reset(); Generation++; } namespace Private { static inline NNE::RuntimeBasic::FModelBuilder::EActivationFunction GetNNEActivationFunction(const EEncodingActivationFunction ActivationFunction) { switch (ActivationFunction) { case EEncodingActivationFunction::ReLU: return NNE::RuntimeBasic::FModelBuilder::EActivationFunction::ReLU; case EEncodingActivationFunction::ELU: return NNE::RuntimeBasic::FModelBuilder::EActivationFunction::ELU; case EEncodingActivationFunction::TanH: return NNE::RuntimeBasic::FModelBuilder::EActivationFunction::TanH; case EEncodingActivationFunction::GELU: return NNE::RuntimeBasic::FModelBuilder::EActivationFunction::GELU; default: checkNoEntry(); 
// Remainder of GetNNEActivationFunction: fallback value returned after checkNoEntry().
        return NNE::RuntimeBasic::FModelBuilder::EActivationFunction::ReLU;
    }
}

/**
 * Hashes a name from its lower-cased string form, feeding the string's raw character array
 * (element size times element count) to CityHash32.
 */
static inline int32 HashFNameStable(const FName Name)
{
    const FString NameString = Name.ToString().ToLower();
    return (int32)CityHash32(
        (const char*)NameString.GetCharArray().GetData(),
        NameString.GetCharArray().GetTypeSize() * NameString.GetCharArray().Num());
}

/** Hashes the bytes of a single int32 with CityHash32. */
static inline int32 HashInt(const int32 Int)
{
    return (int32)CityHash32((const char*)&Int, sizeof(int32));
}

/**
 * Combines several hashes by running CityHash32 over their raw bytes, so the result depends on order.
 * NOTE(review): element type of the view reconstructed from usage - confirm against the original file.
 */
static inline int32 HashCombine(const TArrayView<const int32> Hashes)
{
    return (int32)CityHash32((const char*)Hashes.GetData(), Hashes.Num() * Hashes.GetTypeSize());
}

/**
 * Order-independent hash of a set of element names.
 * Schema and Salt are accepted for symmetry with the name/element overload but are not used here.
 */
static inline int32 HashElements(
    const FSchema& Schema,
    const TArrayView<const FName> SchemaElementNames,
    const int32 Salt)
{
    // Note: Here we xor all entries together.
    // This makes the hash invariant to the ordering of names which is actually what we want
    // since this array is representing a set-like structure and it is fine to pass elements in a different order.

    int32 Hash = 0x5592716a;
    for (int32 SchemaElementIdx = 0; SchemaElementIdx < SchemaElementNames.Num(); SchemaElementIdx++)
    {
        Hash ^= HashFNameStable(SchemaElementNames[SchemaElementIdx]);
    }
    return Hash;
}

/**
 * Order-independent hash of a set of (name, element) pairs.
 * SchemaElementNames is indexed in step with SchemaElements, so it must be at least as long.
 */
static inline int32 HashElements(
    const FSchema& Schema,
    const TArrayView<const FName> SchemaElementNames,
    const TArrayView<const FSchemaElement> SchemaElements,
    const int32 Salt)
{
    // Note: Here we xor all entries together.
    // This makes the hash invariant to the ordering of pairs of names and elements
    // which is actually what we want since these two arrays are representing a map-like
    // structure and it is fine to pass keys and values in a different order.
int32 Hash = 0x5b3bbe4d; for (int32 SchemaElementIdx = 0; SchemaElementIdx < SchemaElements.Num(); SchemaElementIdx++) { Hash ^= HashCombine({ HashFNameStable(SchemaElementNames[SchemaElementIdx]), GetSchemaObjectsCompatibilityHash(Schema, SchemaElements[SchemaElementIdx], Salt) }); } return Hash; } } int32 GetSchemaObjectsCompatibilityHash( const FSchema& Schema, const FSchemaElement SchemaElement, const int32 Salt) { check(Schema.IsValid(SchemaElement)); const EType SchemaElementType = Schema.GetType(SchemaElement); const int32 Hash = Private::HashCombine({ Salt, Private::HashInt((int32)SchemaElementType) }); switch (SchemaElementType) { case EType::Null: return Hash; case EType::Continuous: return Private::HashCombine({ Hash, Private::HashInt(Schema.GetContinuous(SchemaElement).Num) }); case EType::DiscreteExclusive: return Private::HashCombine({ Hash, Private::HashInt(Schema.GetDiscreteExclusive(SchemaElement).Num) }); case EType::DiscreteInclusive: return Private::HashCombine({ Hash, Private::HashInt(Schema.GetDiscreteInclusive(SchemaElement).Num) }); case EType::NamedDiscreteExclusive: { const FSchemaNamedDiscreteExclusiveParameters Parameters = Schema.GetNamedDiscreteExclusive(SchemaElement); return Private::HashCombine({ Hash, Private::HashElements(Schema, Parameters.ElementNames, Salt) }); } case EType::NamedDiscreteInclusive: { const FSchemaNamedDiscreteInclusiveParameters Parameters = Schema.GetNamedDiscreteInclusive(SchemaElement); return Private::HashCombine({ Hash, Private::HashElements(Schema, Parameters.ElementNames, Salt) }); } case EType::And: { const FSchemaAndParameters Parameters = Schema.GetAnd(SchemaElement); return Private::HashCombine({ Hash, Private::HashElements(Schema, Parameters.ElementNames, Parameters.Elements, Salt) }); } case EType::OrExclusive: { const FSchemaOrExclusiveParameters Parameters = Schema.GetOrExclusive(SchemaElement); return Private::HashCombine({ Hash, Private::HashElements(Schema, Parameters.ElementNames, 
Parameters.Elements, Salt) });
		}

		case EType::OrInclusive:
		{
			const FSchemaOrInclusiveParameters Parameters = Schema.GetOrInclusive(SchemaElement);
			return Private::HashCombine({ Hash, Private::HashElements(Schema, Parameters.ElementNames, Parameters.Elements, Salt) });
		}

		case EType::Array:
		{
			const FSchemaArrayParameters Parameters = Schema.GetArray(SchemaElement);
			return Private::HashCombine({ Hash, Private::HashInt(Parameters.Num), GetSchemaObjectsCompatibilityHash(Schema, Parameters.Element, Salt) });
		}

		case EType::Set:
		{
			const FSchemaSetParameters Parameters = Schema.GetSet(SchemaElement);
			return Private::HashCombine({ Hash, Private::HashInt(Parameters.MaxNum), GetSchemaObjectsCompatibilityHash(Schema, Parameters.Element, Salt) });
		}

		case EType::Encoding:
		{
			// Encoding elements contribute nothing themselves: the hash of the wrapped sub-element is returned unchanged
			const FSchemaEncodingParameters Parameters = Schema.GetEncoding(SchemaElement);
			return GetSchemaObjectsCompatibilityHash(Schema, Parameters.Element, Salt);
		}

		default:
		{
			checkNoEntry();
			return 0;
		}
		}
	}

	/**
	 * Checks whether two schema elements, possibly from different schemas, are compatible.
	 *
	 * Encoding elements on either side are skipped by recursing into their sub-element. After that
	 * the two elements must have the same type, the same observation vector size, and matching
	 * type-specific parameters. Named sub-elements are matched by name, so their order may differ.
	 *
	 * @param SchemaA         Schema that owns SchemaElementA.
	 * @param SchemaElementA  First element. Must be valid in SchemaA.
	 * @param SchemaB         Schema that owns SchemaElementB.
	 * @param SchemaElementB  Second element. Must be valid in SchemaB.
	 * @returns               true if the elements are compatible, false otherwise.
	 */
	bool AreSchemaObjectsCompatible(
		const FSchema& SchemaA,
		const FSchemaElement SchemaElementA,
		const FSchema& SchemaB,
		const FSchemaElement SchemaElementB)
	{
		check(SchemaA.IsValid(SchemaElementA));
		check(SchemaB.IsValid(SchemaElementB));

		const EType SchemaElementTypeA = SchemaA.GetType(SchemaElementA);
		const EType SchemaElementTypeB = SchemaB.GetType(SchemaElementB);

		// If any element is an encoding element we forward the comparison to the sub-element since encoding elements don't affect compatibility

		if (SchemaElementTypeA == EType::Encoding)
		{
			return AreSchemaObjectsCompatible(SchemaA, SchemaA.GetEncoding(SchemaElementA).Element, SchemaB, SchemaElementB);
		}

		if (SchemaElementTypeB == EType::Encoding)
		{
			return AreSchemaObjectsCompatible(SchemaA, SchemaElementA, SchemaB, SchemaB.GetEncoding(SchemaElementB).Element);
		}

		// Otherwise if types don't match we immediately know elements are incompatible

		if (SchemaElementTypeA != SchemaElementTypeB)
		{
			return false;
		}

		// This is an early-out since if the input sizes are different we are definitely incompatible

		if (SchemaA.GetObservationVectorSize(SchemaElementA) != SchemaB.GetObservationVectorSize(SchemaElementB))
		{
			return false;
		}

		switch (SchemaElementTypeA)
		{
		case EType::Null: return true;

		case EType::Continuous: return SchemaA.GetContinuous(SchemaElementA).Num == SchemaB.GetContinuous(SchemaElementB).Num;

		case EType::DiscreteExclusive: return SchemaA.GetDiscreteExclusive(SchemaElementA).Num == SchemaB.GetDiscreteExclusive(SchemaElementB).Num;

		case EType::DiscreteInclusive: return SchemaA.GetDiscreteInclusive(SchemaElementA).Num == SchemaB.GetDiscreteInclusive(SchemaElementB).Num;

		case EType::NamedDiscreteExclusive:
		{
			const FSchemaNamedDiscreteExclusiveParameters ParametersA = SchemaA.GetNamedDiscreteExclusive(SchemaElementA);
			const FSchemaNamedDiscreteExclusiveParameters ParametersB = SchemaB.GetNamedDiscreteExclusive(SchemaElementB);

			if (ParametersA.ElementNames.Num() != ParametersB.ElementNames.Num())
			{
				return false;
			}

			// Same count, so it is enough that every name in A is also present in B
			for (int32 SchemaElementAIdx = 0; SchemaElementAIdx < ParametersA.ElementNames.Num(); SchemaElementAIdx++)
			{
				const int32 SchemaElementBIdx = ParametersB.ElementNames.Find(ParametersA.ElementNames[SchemaElementAIdx]);

				if (SchemaElementBIdx == INDEX_NONE)
				{
					return false;
				}
			}

			return true;
		}

		case EType::NamedDiscreteInclusive:
		{
			const FSchemaNamedDiscreteInclusiveParameters ParametersA = SchemaA.GetNamedDiscreteInclusive(SchemaElementA);
			const FSchemaNamedDiscreteInclusiveParameters ParametersB = SchemaB.GetNamedDiscreteInclusive(SchemaElementB);

			if (ParametersA.ElementNames.Num() != ParametersB.ElementNames.Num())
			{
				return false;
			}

			// Same count, so it is enough that every name in A is also present in B
			for (int32 SchemaElementAIdx = 0; SchemaElementAIdx < ParametersA.ElementNames.Num(); SchemaElementAIdx++)
			{
				const int32 SchemaElementBIdx = ParametersB.ElementNames.Find(ParametersA.ElementNames[SchemaElementAIdx]);

				if (SchemaElementBIdx == INDEX_NONE)
				{
					return false;
				}
			}

			return true;
		}

		case EType::And:
		{
			const FSchemaAndParameters ParametersA = SchemaA.GetAnd(SchemaElementA);
			const FSchemaAndParameters ParametersB = SchemaB.GetAnd(SchemaElementB);

			if (ParametersA.Elements.Num() != ParametersB.Elements.Num())
			{
				return false;
			}

			// Pair sub-elements by name, then require each pair to be compatible
			for (int32 SchemaElementAIdx = 0; SchemaElementAIdx < ParametersA.Elements.Num(); SchemaElementAIdx++)
			{
				const int32 SchemaElementBIdx = ParametersB.ElementNames.Find(ParametersA.ElementNames[SchemaElementAIdx]);

				if (SchemaElementBIdx == INDEX_NONE)
				{
					return false;
				}

				if (!AreSchemaObjectsCompatible(SchemaA, ParametersA.Elements[SchemaElementAIdx], SchemaB, ParametersB.Elements[SchemaElementBIdx]))
				{
					return false;
				}
			}

			return true;
		}

		case EType::OrExclusive:
		{
			const FSchemaOrExclusiveParameters ParametersA = SchemaA.GetOrExclusive(SchemaElementA);
			const FSchemaOrExclusiveParameters ParametersB = SchemaB.GetOrExclusive(SchemaElementB);

			if (ParametersA.Elements.Num() != ParametersB.Elements.Num())
			{
				return false;
			}

			// Pair sub-elements by name, then require each pair to be compatible
			for (int32 SchemaElementAIdx = 0; SchemaElementAIdx < ParametersA.Elements.Num(); SchemaElementAIdx++)
			{
				const int32 SchemaElementBIdx = ParametersB.ElementNames.Find(ParametersA.ElementNames[SchemaElementAIdx]);

				if (SchemaElementBIdx == INDEX_NONE)
				{
					return false;
				}

				if (!AreSchemaObjectsCompatible(SchemaA, ParametersA.Elements[SchemaElementAIdx], SchemaB, ParametersB.Elements[SchemaElementBIdx]))
				{
					return false;
				}
			}

			return true;
		}

		case EType::OrInclusive:
		{
			const FSchemaOrInclusiveParameters ParametersA = SchemaA.GetOrInclusive(SchemaElementA);
			const FSchemaOrInclusiveParameters ParametersB = SchemaB.GetOrInclusive(SchemaElementB);

			if (ParametersA.Elements.Num() != ParametersB.Elements.Num())
			{
				return false;
			}

			// Pair sub-elements by name, then require each pair to be compatible
			for (int32 SchemaElementAIdx = 0; SchemaElementAIdx < ParametersA.Elements.Num(); SchemaElementAIdx++)
			{
				const int32 SchemaElementBIdx = ParametersB.ElementNames.Find(ParametersA.ElementNames[SchemaElementAIdx]);

				if (SchemaElementBIdx == INDEX_NONE)
				{
					return false;
				}

				if (!AreSchemaObjectsCompatible(SchemaA, ParametersA.Elements[SchemaElementAIdx], SchemaB, ParametersB.Elements[SchemaElementBIdx]))
				{
					return
false; } } return true; } case EType::Array: { const FSchemaArrayParameters ParametersA = SchemaA.GetArray(SchemaElementA); const FSchemaArrayParameters ParametersB = SchemaB.GetArray(SchemaElementB); return (ParametersA.Num == ParametersB.Num) && AreSchemaObjectsCompatible(SchemaA, ParametersA.Element, SchemaB, ParametersB.Element); } case EType::Set: { const FSchemaSetParameters ParametersA = SchemaA.GetSet(SchemaElementA); const FSchemaSetParameters ParametersB = SchemaB.GetSet(SchemaElementB); return (ParametersA.MaxNum == ParametersB.MaxNum) && AreSchemaObjectsCompatible(SchemaA, ParametersA.Element, SchemaB, ParametersB.Element); } case EType::Encoding: { checkf(false, TEXT("Encoding elements should always be forwarded...")); return false; } default: { checkNoEntry(); return false; } } } void MakeEncoderNetworkModelBuilderElementFromSchema( NNE::RuntimeBasic::FModelBuilderElement& OutElement, NNE::RuntimeBasic::FModelBuilder& Builder, const FSchema& Schema, const FSchemaElement SchemaElement, const FNetworkSettings& NetworkSettings) { const EType SchemaElementType = Schema.GetType(SchemaElement); switch (SchemaElementType) { case EType::Null: { OutElement = Builder.MakeCopy(0); break; } case EType::Continuous: { const int32 ValueNum = Schema.GetContinuous(SchemaElement).Num; OutElement = Builder.MakeDenormalize( ValueNum, Builder.MakeValuesZero(ValueNum), Builder.MakeValuesOne(ValueNum)); break; } case EType::DiscreteExclusive: { const int32 ValueNum = Schema.GetDiscreteExclusive(SchemaElement).Num; OutElement = Builder.MakeDenormalize( ValueNum, Builder.MakeValuesZero(ValueNum), Builder.MakeValuesOne(ValueNum)); break; } case EType::DiscreteInclusive: { const int32 ValueNum = Schema.GetDiscreteInclusive(SchemaElement).Num; OutElement = Builder.MakeDenormalize( ValueNum, Builder.MakeValuesZero(ValueNum), Builder.MakeValuesOne(ValueNum)); break; } case EType::NamedDiscreteExclusive: { const int32 ValueNum = 
Schema.GetNamedDiscreteExclusive(SchemaElement).ElementNames.Num(); OutElement = Builder.MakeDenormalize( ValueNum, Builder.MakeValuesZero(ValueNum), Builder.MakeValuesOne(ValueNum)); break; } case EType::NamedDiscreteInclusive: { const int32 ValueNum = Schema.GetNamedDiscreteInclusive(SchemaElement).ElementNames.Num(); OutElement = Builder.MakeDenormalize( ValueNum, Builder.MakeValuesZero(ValueNum), Builder.MakeValuesOne(ValueNum)); break; } case EType::And: { const FSchemaAndParameters Parameters = Schema.GetAnd(SchemaElement); TArray> BuilderLayers; BuilderLayers.Reserve(Parameters.Elements.Num()); for (const FSchemaElement SubElement : Parameters.Elements) { NNE::RuntimeBasic::FModelBuilderElement BuilderSubElement; MakeEncoderNetworkModelBuilderElementFromSchema(BuilderSubElement, Builder, Schema, SubElement, NetworkSettings); BuilderLayers.Emplace(BuilderSubElement); } OutElement = Builder.MakeConcat(BuilderLayers); break; } case EType::OrExclusive: { const FSchemaOrExclusiveParameters Parameters = Schema.GetOrExclusive(SchemaElement); TArray> BuilderSubLayers; TArray> BuilderEncoders; BuilderSubLayers.Reserve(Parameters.Elements.Num()); BuilderEncoders.Reserve(Parameters.Elements.Num()); for (const FSchemaElement SubElement : Parameters.Elements) { const int32 SubElementEncodedSize = Schema.GetEncodedVectorSize(SubElement); NNE::RuntimeBasic::FModelBuilderElement BuilderSubElement; MakeEncoderNetworkModelBuilderElementFromSchema(BuilderSubElement, Builder, Schema, SubElement, NetworkSettings); BuilderSubLayers.Emplace(BuilderSubElement); NNE::RuntimeBasic::FModelBuilder::FLinearLayerSettings LinearLayerSettings; LinearLayerSettings.Type = NetworkSettings.bUseCompressedLinearLayers ? 
NNE::RuntimeBasic::FModelBuilder::ELinearLayerType::Compressed : NNE::RuntimeBasic::FModelBuilder::ELinearLayerType::Normal; switch (NetworkSettings.WeightInitialization) { case EWeightInitialization::KaimingGaussian: LinearLayerSettings.WeightInitializationSettings.Type = NNE::RuntimeBasic::FModelBuilder::EWeightInitializationType::KaimingGaussian; break; case EWeightInitialization::KaimingUniform: LinearLayerSettings.WeightInitializationSettings.Type = NNE::RuntimeBasic::FModelBuilder::EWeightInitializationType::KaimingUniform; break; default: checkNoEntry(); } BuilderEncoders.Emplace(Builder.MakeLinearLayer(SubElementEncodedSize, Parameters.EncodingSize, LinearLayerSettings)); } OutElement = Builder.MakeAggregateOrExclusive(Parameters.EncodingSize, BuilderSubLayers, BuilderEncoders); break; } case EType::OrInclusive: { const FSchemaOrInclusiveParameters Parameters = Schema.GetOrInclusive(SchemaElement); TArray> BuilderSubLayers; TArray> BuilderQueryLayers; TArray> BuilderKeyLayers; TArray> BuilderValueLayers; BuilderSubLayers.Reserve(Parameters.Elements.Num()); BuilderQueryLayers.Reserve(Parameters.Elements.Num()); BuilderValueLayers.Reserve(Parameters.Elements.Num()); for (const FSchemaElement SubElement : Parameters.Elements) { const int32 SubElementEncodedSize = Schema.GetEncodedVectorSize(SubElement); NNE::RuntimeBasic::FModelBuilderElement BuilderSubElement; MakeEncoderNetworkModelBuilderElementFromSchema(BuilderSubElement, Builder, Schema, SubElement, NetworkSettings); BuilderSubLayers.Emplace(BuilderSubElement); NNE::RuntimeBasic::FModelBuilder::FLinearLayerSettings LinearLayerSettings; LinearLayerSettings.Type = NetworkSettings.bUseCompressedLinearLayers ? 
NNE::RuntimeBasic::FModelBuilder::ELinearLayerType::Compressed : NNE::RuntimeBasic::FModelBuilder::ELinearLayerType::Normal; switch (NetworkSettings.WeightInitialization) { case EWeightInitialization::KaimingGaussian: LinearLayerSettings.WeightInitializationSettings.Type = NNE::RuntimeBasic::FModelBuilder::EWeightInitializationType::KaimingGaussian; break; case EWeightInitialization::KaimingUniform: LinearLayerSettings.WeightInitializationSettings.Type = NNE::RuntimeBasic::FModelBuilder::EWeightInitializationType::KaimingUniform; break; default: checkNoEntry(); } BuilderQueryLayers.Emplace(Builder.MakeLinearLayer(SubElementEncodedSize, Parameters.AttentionHeadNum * Parameters.AttentionEncodingSize, LinearLayerSettings)); BuilderKeyLayers.Emplace(Builder.MakeLinearLayer(SubElementEncodedSize, Parameters.AttentionHeadNum * Parameters.AttentionEncodingSize, LinearLayerSettings)); BuilderValueLayers.Emplace(Builder.MakeLinearLayer(SubElementEncodedSize, Parameters.AttentionHeadNum * Parameters.ValueEncodingSize, LinearLayerSettings)); } OutElement = Builder.MakeAggregateOrInclusive( Parameters.ValueEncodingSize, Parameters.AttentionEncodingSize, Parameters.AttentionHeadNum, BuilderSubLayers, BuilderQueryLayers, BuilderKeyLayers, BuilderValueLayers); break; } case EType::Array: { const FSchemaArrayParameters Parameters = Schema.GetArray(SchemaElement); NNE::RuntimeBasic::FModelBuilderElement BuilderSubElement; MakeEncoderNetworkModelBuilderElementFromSchema(BuilderSubElement, Builder, Schema, Parameters.Element, NetworkSettings); OutElement = Builder.MakeArray(Parameters.Num, BuilderSubElement); break; } case EType::Set: { const FSchemaSetParameters Parameters = Schema.GetSet(SchemaElement); const int32 SubElementEncodedSize = Schema.GetEncodedVectorSize(Parameters.Element); NNE::RuntimeBasic::FModelBuilderElement BuilderSubElement; MakeEncoderNetworkModelBuilderElementFromSchema(BuilderSubElement, Builder, Schema, Parameters.Element, NetworkSettings); 
NNE::RuntimeBasic::FModelBuilder::FLinearLayerSettings LinearLayerSettings; LinearLayerSettings.Type = NetworkSettings.bUseCompressedLinearLayers ? NNE::RuntimeBasic::FModelBuilder::ELinearLayerType::Compressed : NNE::RuntimeBasic::FModelBuilder::ELinearLayerType::Normal; switch (NetworkSettings.WeightInitialization) { case EWeightInitialization::KaimingGaussian: LinearLayerSettings.WeightInitializationSettings.Type = NNE::RuntimeBasic::FModelBuilder::EWeightInitializationType::KaimingGaussian; break; case EWeightInitialization::KaimingUniform: LinearLayerSettings.WeightInitializationSettings.Type = NNE::RuntimeBasic::FModelBuilder::EWeightInitializationType::KaimingUniform; break; default: checkNoEntry(); } OutElement = Builder.MakeAggregateSet( Parameters.MaxNum, Parameters.ValueEncodingSize, Parameters.AttentionEncodingSize, Parameters.AttentionHeadNum, BuilderSubElement, Builder.MakeLinearLayer(SubElementEncodedSize, Parameters.AttentionHeadNum * Parameters.AttentionEncodingSize, LinearLayerSettings), Builder.MakeLinearLayer(SubElementEncodedSize, Parameters.AttentionHeadNum * Parameters.AttentionEncodingSize, LinearLayerSettings), Builder.MakeLinearLayer(SubElementEncodedSize, Parameters.AttentionHeadNum * Parameters.ValueEncodingSize, LinearLayerSettings)); break; } case EType::Encoding: { const FSchemaEncodingParameters Parameters = Schema.GetEncoding(SchemaElement); const int32 SubElementEncodedSize = Schema.GetEncodedVectorSize(Parameters.Element); NNE::RuntimeBasic::FModelBuilderElement BuilderSubElement; MakeEncoderNetworkModelBuilderElementFromSchema(BuilderSubElement, Builder, Schema, Parameters.Element, NetworkSettings); NNE::RuntimeBasic::FModelBuilder::FLinearLayerSettings LinearLayerSettings; LinearLayerSettings.Type = NetworkSettings.bUseCompressedLinearLayers ? 
NNE::RuntimeBasic::FModelBuilder::ELinearLayerType::Compressed : NNE::RuntimeBasic::FModelBuilder::ELinearLayerType::Normal; switch (NetworkSettings.WeightInitialization) { case EWeightInitialization::KaimingGaussian: LinearLayerSettings.WeightInitializationSettings.Type = NNE::RuntimeBasic::FModelBuilder::EWeightInitializationType::KaimingGaussian; break; case EWeightInitialization::KaimingUniform: LinearLayerSettings.WeightInitializationSettings.Type = NNE::RuntimeBasic::FModelBuilder::EWeightInitializationType::KaimingUniform; break; default: checkNoEntry(); } OutElement = Builder.MakeSequence({ BuilderSubElement, Builder.MakeMLP( SubElementEncodedSize, Parameters.EncodingSize, Parameters.EncodingSize, Parameters.LayerNum + 1, // Add 1 to account for input layer Private::GetNNEActivationFunction(Parameters.ActivationFunction), true, LinearLayerSettings) }); break; } default: { checkNoEntry(); } } checkf(OutElement.GetInputSize() == Schema.GetObservationVectorSize(SchemaElement), TEXT("Encoder Network Input unexpected size for %s. Got %i, expected %i according to Schema."), *Schema.GetTag(SchemaElement).ToString(), OutElement.GetInputSize(), Schema.GetObservationVectorSize(SchemaElement)); checkf(OutElement.GetOutputSize() == Schema.GetEncodedVectorSize(SchemaElement), TEXT("Encoder Network Output unexpected size for %s. 
Got %i, expected %i according to Schema."), *Schema.GetTag(SchemaElement).ToString(), OutElement.GetOutputSize(), Schema.GetEncodedVectorSize(SchemaElement)); } void GenerateEncoderNetworkFileDataFromSchema( TArray& OutFileData, uint32& OutInputSize, uint32& OutOutputSize, const FSchema& Schema, const FSchemaElement SchemaElement, const FNetworkSettings& NetworkSettings, const uint32 Seed) { check(Schema.IsValid(SchemaElement)); NNE::RuntimeBasic::FModelBuilder Builder(Seed); NNE::RuntimeBasic::FModelBuilderElement Element; MakeEncoderNetworkModelBuilderElementFromSchema(Element, Builder, Schema, SchemaElement, NetworkSettings); Builder.WriteFileDataAndReset(OutFileData, OutInputSize, OutOutputSize, Element); } void SetVectorFromObject( TLearningArrayView<1, float> OutObservationVector, const FSchema& Schema, const FSchemaElement SchemaElement, const FObject& Object, const FObjectElement ObjectElement) { check(Schema.IsValid(SchemaElement)); check(Object.IsValid(ObjectElement)); check(OutObservationVector.Num() == Schema.GetObservationVectorSize(SchemaElement)); // Check that the types match const EType SchemaElementType = Schema.GetType(SchemaElement); const EType ObjectElementType = Object.GetType(ObjectElement); check(ObjectElementType == SchemaElementType); // Zero Observation Vector Array::Zero(OutObservationVector); // Logic for each specific element type switch (SchemaElementType) { case EType::Null: return; case EType::Continuous: { // Check the input sizes match const FSchemaContinuousParameters SchemaParameters = Schema.GetContinuous(SchemaElement); const TArrayView ObservationValues = Object.GetContinuous(ObjectElement).Values; check(Schema.GetObservationVectorSize(SchemaElement) == ObservationValues.Num()); check(Schema.GetObservationVectorSize(SchemaElement) == OutObservationVector.Num()); check(Schema.GetObservationVectorSize(SchemaElement) == SchemaParameters.Num); // Copy in and scale the values from the observation object const int32 ValueNum = 
SchemaParameters.Num;
			// Clamp the scale from below so the division cannot be by zero
			const float ValueScale = FMath::Max(SchemaParameters.Scale, UE_SMALL_NUMBER);

			for (int32 ValueIdx = 0; ValueIdx < ValueNum; ValueIdx++)
			{
				OutObservationVector[ValueIdx] = ObservationValues[ValueIdx] / ValueScale;
			}

			return;
		}

		case EType::DiscreteExclusive:
		{
			const int32 ObservationValue = Object.GetDiscreteExclusive(ObjectElement).DiscreteIndex;
			check(Schema.GetObservationVectorSize(SchemaElement) > ObservationValue && ObservationValue >= 0);
			check(Schema.GetObservationVectorSize(SchemaElement) == OutObservationVector.Num());

			// Set the single value in the observation vector

			OutObservationVector[ObservationValue] = 1.0f;

			return;
		}

		case EType::DiscreteInclusive:
		{
			const TArrayView ObservationValues = Object.GetDiscreteInclusive(ObjectElement).DiscreteIndices;
			check(Schema.GetObservationVectorSize(SchemaElement) >= ObservationValues.Num());
			check(Schema.GetObservationVectorSize(SchemaElement) == OutObservationVector.Num());

			// Set values in the observation vector

			for (int32 ObservationValueIdx = 0; ObservationValueIdx < ObservationValues.Num(); ObservationValueIdx++)
			{
				check(Schema.GetObservationVectorSize(SchemaElement) > ObservationValues[ObservationValueIdx] && ObservationValues[ObservationValueIdx] >= 0);
				OutObservationVector[ObservationValues[ObservationValueIdx]] = 1.0f;
			}

			return;
		}

		case EType::NamedDiscreteExclusive:
		{
			const TArrayView SchemaNames = Schema.GetNamedDiscreteExclusive(SchemaElement).ElementNames;
			const FName ObservationValue = Object.GetNamedDiscreteExclusive(ObjectElement).ElementName;
			check(Schema.GetObservationVectorSize(SchemaElement) == OutObservationVector.Num());

			// Set the single value in the observation vector

			// The slot index is the position of the name in the schema's name list
			const int32 ObservationIndex = SchemaNames.Find(ObservationValue);
			check(ObservationIndex != INDEX_NONE);
			OutObservationVector[ObservationIndex] = 1.0f;

			return;
		}

		case EType::NamedDiscreteInclusive:
		{
			const TArrayView SchemaNames = Schema.GetNamedDiscreteInclusive(SchemaElement).ElementNames;
			const TArrayView ObservationValues = Object.GetNamedDiscreteInclusive(ObjectElement).ElementNames;
			check(Schema.GetObservationVectorSize(SchemaElement) >= ObservationValues.Num());
			check(Schema.GetObservationVectorSize(SchemaElement) == OutObservationVector.Num());

			// Set values in the observation vector

			for (int32 ObservationValueIdx = 0; ObservationValueIdx < ObservationValues.Num(); ObservationValueIdx++)
			{
				const int32 ObservationIndex = SchemaNames.Find(ObservationValues[ObservationValueIdx]);
				check(ObservationIndex != INDEX_NONE);
				OutObservationVector[ObservationIndex] = 1.0f;
			}

			return;
		}

		case EType::And:
		{
			// Check the number of sub-elements match

			const FSchemaAndParameters SchemaParameters = Schema.GetAnd(SchemaElement);
			const FObjectAndParameters ObjectParameters = Object.GetAnd(ObjectElement);
			check(SchemaParameters.Elements.Num() == ObjectParameters.Elements.Num());

			// Update Sub-elements

			// Sub-elements are written back to back in schema order; the object's entry is looked up by name
			int32 SubElementOffset = 0;

			for (int32 SchemaElementIdx = 0; SchemaElementIdx < SchemaParameters.Elements.Num(); SchemaElementIdx++)
			{
				const int32 ObjectElementIndex = ObjectParameters.ElementNames.Find(SchemaParameters.ElementNames[SchemaElementIdx]);
				check(ObjectElementIndex != INDEX_NONE);

				const int32 SubElementSize = Schema.GetObservationVectorSize(SchemaParameters.Elements[SchemaElementIdx]);

				SetVectorFromObject(
					OutObservationVector.Slice(SubElementOffset, SubElementSize),
					Schema,
					SchemaParameters.Elements[SchemaElementIdx],
					Object,
					ObjectParameters.Elements[ObjectElementIndex]);

				SubElementOffset += SubElementSize;
			}

			check(SubElementOffset == OutObservationVector.Num());

			return;
		}

		case EType::OrExclusive:
		{
			// Check only one sub-element is given and index is valid

			const FSchemaOrExclusiveParameters SchemaParameters = Schema.GetOrExclusive(SchemaElement);
			const FObjectOrExclusiveParameters ObjectParameters = Object.GetOrExclusive(ObjectElement);

			const int32 SchemaElementIndex = SchemaParameters.ElementNames.Find(ObjectParameters.ElementName);
			check(SchemaElementIndex != INDEX_NONE);

			// Update sub-element

			// The chosen sub-element always starts at offset 0
			const int32 SubElementSize = Schema.GetObservationVectorSize(SchemaParameters.Elements[SchemaElementIndex]);

			SetVectorFromObject(
				OutObservationVector.Slice(0, SubElementSize),
				Schema,
				SchemaParameters.Elements[SchemaElementIndex],
				Object,
				ObjectParameters.Element);

			// Set Mask

			// The mask (one slot per option) starts after the largest sub-element size
			const int32 MaxSubElementSize = Private::GetMaxObservationVectorSize(Schema, SchemaParameters.Elements);

			OutObservationVector[MaxSubElementSize + SchemaElementIndex] = 1.0f;

			check(OutObservationVector.Num() == MaxSubElementSize + SchemaParameters.Elements.Num());

			return;
		}

		case EType::OrInclusive:
		{
			// Check all indices are in range

			const FSchemaOrInclusiveParameters SchemaParameters = Schema.GetOrInclusive(SchemaElement);
			const FObjectOrInclusiveParameters ObjectParameters = Object.GetOrInclusive(ObjectElement);
			check(ObjectParameters.Elements.Num() <= SchemaParameters.Elements.Num());

			// Update sub-elements

			// Every schema sub-element owns a slice; slices for sub-elements absent from the object stay zero
			int32 SubElementOffset = 0;

			for (int32 SchemaElementIdx = 0; SchemaElementIdx < SchemaParameters.Elements.Num(); SchemaElementIdx++)
			{
				const int32 SubElementSize = Schema.GetObservationVectorSize(SchemaParameters.Elements[SchemaElementIdx]);

				const int32 ObjectElementIdx = ObjectParameters.ElementNames.Find(SchemaParameters.ElementNames[SchemaElementIdx]);

				if (ObjectElementIdx != INDEX_NONE)
				{
					SetVectorFromObject(
						OutObservationVector.Slice(SubElementOffset, SubElementSize),
						Schema,
						SchemaParameters.Elements[SchemaElementIdx],
						Object,
						ObjectParameters.Elements[ObjectElementIdx]);
				}

				SubElementOffset += SubElementSize;
			}

			// Set Mask

			// The mask (one slot per schema sub-element) follows all the slices
			check(SubElementOffset + SchemaParameters.Elements.Num() == OutObservationVector.Num());

			for (int32 ObjectElementIdx = 0; ObjectElementIdx < ObjectParameters.Elements.Num(); ObjectElementIdx++)
			{
				const int32 SchemaElementIdx = SchemaParameters.ElementNames.Find(ObjectParameters.ElementNames[ObjectElementIdx]);
				check(SchemaElementIdx != INDEX_NONE);

				OutObservationVector[SubElementOffset + SchemaElementIdx] = 1.0f;
			}

			return;
		}

		case EType::Array:
		{
			// Check number of array elements is correct

			const FSchemaArrayParameters SchemaParameters = Schema.GetArray(SchemaElement);
			const FObjectArrayParameters ObjectParameters = Object.GetArray(ObjectElement);
			check(SchemaParameters.Num == ObjectParameters.Elements.Num());

			// Update sub-elements

			const int32 SubElementSize = Schema.GetObservationVectorSize(SchemaParameters.Element);

			for (int32 ElementIdx = 0; ElementIdx < SchemaParameters.Num; ElementIdx++)
			{
				SetVectorFromObject(
					OutObservationVector.Slice(ElementIdx * SubElementSize, SubElementSize),
					Schema,
					SchemaParameters.Element,
					Object,
					ObjectParameters.Elements[ElementIdx]);
			}

			return;
		}

		case EType::Set:
		{
			// Check number of set elements is correct

			const FSchemaSetParameters SchemaParameters = Schema.GetSet(SchemaElement);
			const FObjectSetParameters ObjectParameters = Object.GetSet(ObjectElement);
			check(SchemaParameters.MaxNum >= ObjectParameters.Elements.Num());

			// Update sub-elements

			int32 SubElementOffset = 0;

			const int32 SubElementSize = Schema.GetObservationVectorSize(SchemaParameters.Element);

			for (int32 ElementIdx = 0; ElementIdx < ObjectParameters.Elements.Num(); ElementIdx++)
			{
				SetVectorFromObject(
					OutObservationVector.Slice(SubElementOffset, SubElementSize),
					Schema,
					SchemaParameters.Element,
					Object,
					ObjectParameters.Elements[ElementIdx]);

				SubElementOffset += SubElementSize;
			}

			// Jump past all MaxNum element slots (used or not) to where the mask begins
			SubElementOffset = SubElementSize * SchemaParameters.MaxNum;

			// Set Mask

			// Only the first Elements.Num() mask slots are set; the rest stay zero from the initial clear
			Array::Set(OutObservationVector.Slice(SubElementOffset, ObjectParameters.Elements.Num()), 1.0f);

			check(SubElementOffset + SchemaParameters.MaxNum == OutObservationVector.Num());

			return;
		}

		case EType::Encoding:
		{
			// Encoding elements add nothing to the vector: the wrapped sub-element is written to the whole output
			const FSchemaEncodingParameters SchemaParameters = Schema.GetEncoding(SchemaElement);
			const FObjectEncodingParameters ObjectParameters = Object.GetEncoding(ObjectElement);

			SetVectorFromObject(
				OutObservationVector,
				Schema,
				SchemaParameters.Element,
				Object,
				ObjectParameters.Element);

			return;
		}

		default:
		{
			checkNoEntry();
			return;
		}
		}
	}

	void
GetObjectFromVector( FObject& OutObject, FObjectElement& OutObjectElement, const FSchema& Schema, const FSchemaElement SchemaElement, const TLearningArrayView<1, const float> ObservationVector) { check(Schema.IsValid(SchemaElement)); // Check that the types match const EType SchemaElementType = Schema.GetType(SchemaElement); const FName SchemaElementTag = Schema.GetTag(SchemaElement); // Get Observation Vector Size const int32 ObservationVectorSize = ObservationVector.Num(); check(ObservationVectorSize == Schema.GetObservationVectorSize(SchemaElement)); // Logic for each specific element type switch (SchemaElementType) { case EType::Null: { OutObjectElement = OutObject.CreateNull(SchemaElementTag); return; } case EType::Continuous: { const FSchemaContinuousParameters SchemaParameters = Schema.GetContinuous(SchemaElement); check(ObservationVectorSize == SchemaParameters.Num); const int32 ValueNum = SchemaParameters.Num; const float ValueScale = FMath::Max(SchemaParameters.Scale, UE_SMALL_NUMBER); TLearningArray<1, float, TInlineAllocator<32>> ObservationValues; ObservationValues.SetNumUninitialized({ ValueNum }); for (int32 ValueIdx = 0; ValueIdx < ValueNum; ValueIdx++) { ObservationValues[ValueIdx] = ValueScale * ObservationVector[ValueIdx]; } OutObjectElement = OutObject.CreateContinuous({ MakeArrayView(ObservationValues.GetData(), ObservationValues.Num()) }, SchemaElementTag); return; } case EType::DiscreteExclusive: { check(ObservationVectorSize == Schema.GetDiscreteExclusive(SchemaElement).Num); // Find Index int32 ExclusiveIndex = INDEX_NONE; for (int32 Idx = 0; Idx < ObservationVectorSize; Idx++) { check(ObservationVector[Idx] == 0.0f || ObservationVector[Idx] == 1.0f); if (ObservationVector[Idx]) { ExclusiveIndex = Idx; break; } } check(ExclusiveIndex != INDEX_NONE); OutObjectElement = OutObject.CreateDiscreteExclusive({ ExclusiveIndex }, SchemaElementTag); return; } case EType::DiscreteInclusive: { check(ObservationVectorSize == 
Schema.GetDiscreteInclusive(SchemaElement).Num); // Find Indices TArray> InclusiveIndices; InclusiveIndices.Reserve(ObservationVectorSize); for (int32 Idx = 0; Idx < ObservationVectorSize; Idx++) { check(ObservationVector[Idx] == 0.0f || ObservationVector[Idx] == 1.0f); if (ObservationVector[Idx]) { InclusiveIndices.Add(Idx); } } OutObjectElement = OutObject.CreateDiscreteInclusive({ InclusiveIndices }, SchemaElementTag); return; } case EType::NamedDiscreteExclusive: { const TArrayView SchemaNames = Schema.GetNamedDiscreteExclusive(SchemaElement).ElementNames; check(ObservationVectorSize == Schema.GetNamedDiscreteExclusive(SchemaElement).ElementNames.Num()); // Find Name FName ExclusiveName = NAME_None; for (int32 Idx = 0; Idx < ObservationVectorSize; Idx++) { check(ObservationVector[Idx] == 0.0f || ObservationVector[Idx] == 1.0f); if (ObservationVector[Idx]) { ExclusiveName = SchemaNames[Idx]; break; } } check(ExclusiveName != NAME_None); OutObjectElement = OutObject.CreateNamedDiscreteExclusive({ ExclusiveName }, SchemaElementTag); return; } case EType::NamedDiscreteInclusive: { const TArrayView SchemaNames = Schema.GetNamedDiscreteInclusive(SchemaElement).ElementNames; check(ObservationVectorSize == Schema.GetNamedDiscreteInclusive(SchemaElement).ElementNames.Num()); // Find Names TArray> InclusiveNames; InclusiveNames.Reserve(ObservationVectorSize); for (int32 Idx = 0; Idx < ObservationVectorSize; Idx++) { check(ObservationVector[Idx] == 0.0f || ObservationVector[Idx] == 1.0f); if (ObservationVector[Idx]) { InclusiveNames.Add(SchemaNames[Idx]); } } OutObjectElement = OutObject.CreateNamedDiscreteInclusive({ InclusiveNames }, SchemaElementTag); return; } case EType::And: { const FSchemaAndParameters Parameters = Schema.GetAnd(SchemaElement); // Create Sub-elements TArray> SubElements; SubElements.SetNumUninitialized(Parameters.Elements.Num()); int32 SubElementOffset = 0; for (int32 SchemaElementIdx = 0; SchemaElementIdx < Parameters.Elements.Num(); 
SchemaElementIdx++) { const int32 SubElementSize = Schema.GetObservationVectorSize(Parameters.Elements[SchemaElementIdx]); GetObjectFromVector( OutObject, SubElements[SchemaElementIdx], Schema, Parameters.Elements[SchemaElementIdx], ObservationVector.Slice(SubElementOffset, SubElementSize)); SubElementOffset += SubElementSize; } check(SubElementOffset == ObservationVectorSize); OutObjectElement = OutObject.CreateAnd({ Parameters.ElementNames, SubElements }, SchemaElementTag); return; } case EType::OrExclusive: { const FSchemaOrExclusiveParameters Parameters = Schema.GetOrExclusive(SchemaElement); // Find active element const int32 MaxSubElementSize = Private::GetMaxObservationVectorSize(Schema, Parameters.Elements); int32 SchemaElementIndex = INDEX_NONE; for (int32 SubElementIdx = 0; SubElementIdx < Parameters.Elements.Num(); SubElementIdx++) { check(ObservationVector[MaxSubElementSize + SubElementIdx] == 0.0f || ObservationVector[MaxSubElementSize + SubElementIdx] == 1.0f); if (ObservationVector[MaxSubElementSize + SubElementIdx]) { SchemaElementIndex = SubElementIdx; break; } } check(SchemaElementIndex != INDEX_NONE); // Create sub-element const int32 SubElementSize = Schema.GetObservationVectorSize(Parameters.Elements[SchemaElementIndex]); FObjectElement SubElement; GetObjectFromVector( OutObject, SubElement, Schema, Parameters.Elements[SchemaElementIndex], ObservationVector.Slice(0, SubElementSize)); OutObjectElement = OutObject.CreateOrExclusive({ Parameters.ElementNames[SchemaElementIndex], SubElement }, SchemaElementTag); return; } case EType::OrInclusive: { const FSchemaOrInclusiveParameters Parameters = Schema.GetOrInclusive(SchemaElement); // Find total sub-element size const int32 TotalSubElementSize = Private::GetTotalObservationVectorSize(Schema, Parameters.Elements); // Create sub-elements TArray> SubElementNames; TArray> SubElements; SubElementNames.Reserve(Parameters.Elements.Num()); SubElements.Reserve(Parameters.Elements.Num()); int32 
SubElementOffset = 0; for (int32 SubElementIdx = 0; SubElementIdx < Parameters.Elements.Num(); SubElementIdx++) { const int32 SubElementSize = Schema.GetObservationVectorSize(Parameters.Elements[SubElementIdx]); check( ObservationVector[TotalSubElementSize + SubElementIdx] == 0.0f || ObservationVector[TotalSubElementSize + SubElementIdx] == 1.0f); if (ObservationVector[TotalSubElementSize + SubElementIdx] == 1.0f) { FObjectElement SubElement; GetObjectFromVector( OutObject, SubElement, Schema, Parameters.Elements[SubElementIdx], ObservationVector.Slice(SubElementOffset, SubElementSize)); SubElementNames.Add(Parameters.ElementNames[SubElementIdx]); SubElements.Add(SubElement); } SubElementOffset += SubElementSize; } check(SubElementOffset + Parameters.Elements.Num() == ObservationVectorSize); OutObjectElement = OutObject.CreateOrInclusive({ SubElementNames, SubElements }, SchemaElementTag); return; } case EType::Array: { const FSchemaArrayParameters Parameters = Schema.GetArray(SchemaElement); TArray> SubElements; SubElements.SetNumUninitialized(Parameters.Num); // Create sub-elements const int32 SubElementSize = Schema.GetObservationVectorSize(Parameters.Element); for (int32 ElementIdx = 0; ElementIdx < Parameters.Num; ElementIdx++) { GetObjectFromVector( OutObject, SubElements[ElementIdx], Schema, Parameters.Element, ObservationVector.Slice(ElementIdx * SubElementSize, SubElementSize)); } OutObjectElement = OutObject.CreateArray({ SubElements }, SchemaElementTag); return; } case EType::Set: { const FSchemaSetParameters Parameters = Schema.GetSet(SchemaElement); const int32 SubElementSize = Schema.GetObservationVectorSize(Parameters.Element); // Create sub-elements TArray> SubElements; SubElements.Reserve(Parameters.MaxNum); for (int32 SubElementIdx = 0; SubElementIdx < Parameters.MaxNum; SubElementIdx++) { check( ObservationVector[SubElementSize * Parameters.MaxNum + SubElementIdx] == 0.0f || ObservationVector[SubElementSize * Parameters.MaxNum + SubElementIdx] == 
1.0f); if (ObservationVector[SubElementSize * Parameters.MaxNum + SubElementIdx] == 0.0f) { break; } FObjectElement SubElement; GetObjectFromVector( OutObject, SubElement, Schema, Parameters.Element, ObservationVector.Slice(SubElementIdx * SubElementSize, SubElementSize)); SubElements.Add(SubElement); } OutObjectElement = OutObject.CreateSet({ SubElements }, SchemaElementTag); return; } case EType::Encoding: { const FSchemaEncodingParameters Parameters = Schema.GetEncoding(SchemaElement); FObjectElement SubElement; GetObjectFromVector( OutObject, SubElement, Schema, Parameters.Element, ObservationVector); OutObjectElement = OutObject.CreateEncoding({ SubElement }, SchemaElementTag); return; } default: { checkNoEntry(); OutObjectElement = FObjectElement(); return; } } } void AddGaussianNoiseToVector( uint32& InOutRandomState, TLearningArrayView<1, float> InOutObservationVector, const FSchema& Schema, const FSchemaElement SchemaElement, const float NoiseScale) { check(Schema.IsValid(SchemaElement)); const EType SchemaElementType = Schema.GetType(SchemaElement); const int32 ObservationVectorSize = InOutObservationVector.Num(); check(ObservationVectorSize == Schema.GetObservationVectorSize(SchemaElement)); switch (SchemaElementType) { case EType::Null: { return; } case EType::Continuous: { const FSchemaContinuousParameters SchemaParameters = Schema.GetContinuous(SchemaElement); check(ObservationVectorSize == SchemaParameters.Num); TLearningArray<1, float, TInlineAllocator<32>> NoiseValues; NoiseValues.SetNumUninitialized({ SchemaParameters.Num }); Random::SampleGaussianArray(NoiseValues, InOutRandomState, 0.0f, NoiseScale); for (int32 ValueIdx = 0; ValueIdx < SchemaParameters.Num; ValueIdx++) { InOutObservationVector[ValueIdx] += NoiseValues[ValueIdx]; } return; } case EType::DiscreteExclusive: { return; } case EType::DiscreteInclusive: { return; } case EType::NamedDiscreteExclusive: { return; } case EType::NamedDiscreteInclusive: { return; } case EType::And: { const 
FSchemaAndParameters Parameters = Schema.GetAnd(SchemaElement); int32 SubElementOffset = 0; for (int32 SchemaElementIdx = 0; SchemaElementIdx < Parameters.Elements.Num(); SchemaElementIdx++) { const int32 SubElementSize = Schema.GetObservationVectorSize(Parameters.Elements[SchemaElementIdx]); AddGaussianNoiseToVector( InOutRandomState, InOutObservationVector.Slice(SubElementOffset, SubElementSize), Schema, Parameters.Elements[SchemaElementIdx], NoiseScale); SubElementOffset += SubElementSize; } check(SubElementOffset == ObservationVectorSize); return; } case EType::OrExclusive: { const FSchemaOrExclusiveParameters Parameters = Schema.GetOrExclusive(SchemaElement); const int32 MaxSubElementSize = Private::GetMaxObservationVectorSize(Schema, Parameters.Elements); int32 SchemaElementIndex = INDEX_NONE; for (int32 SubElementIdx = 0; SubElementIdx < Parameters.Elements.Num(); SubElementIdx++) { check(InOutObservationVector[MaxSubElementSize + SubElementIdx] == 0.0f || InOutObservationVector[MaxSubElementSize + SubElementIdx] == 1.0f); if (InOutObservationVector[MaxSubElementSize + SubElementIdx]) { SchemaElementIndex = SubElementIdx; break; } } check(SchemaElementIndex != INDEX_NONE); const int32 SubElementSize = Schema.GetObservationVectorSize(Parameters.Elements[SchemaElementIndex]); AddGaussianNoiseToVector( InOutRandomState, InOutObservationVector.Slice(0, SubElementSize), Schema, Parameters.Elements[SchemaElementIndex], NoiseScale); return; } case EType::OrInclusive: { const FSchemaOrInclusiveParameters Parameters = Schema.GetOrInclusive(SchemaElement); const int32 TotalSubElementSize = Private::GetTotalObservationVectorSize(Schema, Parameters.Elements); int32 SubElementOffset = 0; for (int32 SubElementIdx = 0; SubElementIdx < Parameters.Elements.Num(); SubElementIdx++) { const int32 SubElementSize = Schema.GetObservationVectorSize(Parameters.Elements[SubElementIdx]); check( InOutObservationVector[TotalSubElementSize + SubElementIdx] == 0.0f || 
InOutObservationVector[TotalSubElementSize + SubElementIdx] == 1.0f); if (InOutObservationVector[TotalSubElementSize + SubElementIdx] == 1.0f) { AddGaussianNoiseToVector( InOutRandomState, InOutObservationVector.Slice(SubElementOffset, SubElementSize), Schema, Parameters.Elements[SubElementIdx], NoiseScale); } SubElementOffset += SubElementSize; } check(SubElementOffset + Parameters.Elements.Num() == ObservationVectorSize); return; } case EType::Array: { const FSchemaArrayParameters Parameters = Schema.GetArray(SchemaElement); const int32 SubElementSize = Schema.GetObservationVectorSize(Parameters.Element); for (int32 ElementIdx = 0; ElementIdx < Parameters.Num; ElementIdx++) { AddGaussianNoiseToVector( InOutRandomState, InOutObservationVector.Slice(ElementIdx* SubElementSize, SubElementSize), Schema, Parameters.Element, NoiseScale); } return; } case EType::Set: { const FSchemaSetParameters Parameters = Schema.GetSet(SchemaElement); const int32 SubElementSize = Schema.GetObservationVectorSize(Parameters.Element); for (int32 SubElementIdx = 0; SubElementIdx < Parameters.MaxNum; SubElementIdx++) { check( InOutObservationVector[SubElementSize * Parameters.MaxNum + SubElementIdx] == 0.0f || InOutObservationVector[SubElementSize * Parameters.MaxNum + SubElementIdx] == 1.0f); if (InOutObservationVector[SubElementSize * Parameters.MaxNum + SubElementIdx] == 0.0f) { break; } AddGaussianNoiseToVector( InOutRandomState, InOutObservationVector.Slice(SubElementIdx * SubElementSize, SubElementSize), Schema, Parameters.Element, NoiseScale); } return; } case EType::Encoding: { AddGaussianNoiseToVector( InOutRandomState, InOutObservationVector, Schema, Schema.GetEncoding(SchemaElement).Element, NoiseScale); return; } default: { checkNoEntry(); return; } } } }