2025-09-17 16:40:02 -07:00

550 lines
16 KiB
C#

using AssetRipper.AssemblyDumper.Utils;
using RangeClassList = System.Collections.Generic.List<System.Collections.Generic.KeyValuePair<AssetRipper.Numerics.Range<AssetRipper.Primitives.UnityVersion>, AssetRipper.AssemblyDumper.UniversalClass>>;
namespace AssetRipper.AssemblyDumper.Passes;
/// <summary>
/// Links every class to its base class, then, for each abstract class, splits its version history at every
/// version where the abstract class or one of its derived classes starts or ends, and moves fields that the
/// derived classes agree on into cloned copies of the abstract class.
/// </summary>
internal static class Pass005_SplitAbstractClasses
{
    /// <summary>
    /// Type IDs of abstract classes already handled by <see cref="DoOtherStuff"/>.
    /// An abstract class is only processed once none of its abstract derived classes are missing from this set.
    /// </summary>
    private static readonly HashSet<int> processedClasses = new();

    /// <summary>
    /// Maximum number of passes over the remaining abstract class IDs before giving up
    /// (guards against abstract classes that can never become processable).
    /// </summary>
    private const int MaxRunCount = 1000;

    /// <summary>
    /// A field is only considered for hoisting when the fraction of derived classes that have it,
    /// averaged over the sections in which it appears at all, is strictly greater than this value.
    /// </summary>
    private const float MinimumProportion = 0.7f;

    /// <summary>
    /// Runs the pass: assigns inheritance links, then splits and fills in the abstract classes.
    /// </summary>
    public static void DoPass()
    {
        // Diagnostic only: uncomment to print which class IDs are (sometimes) abstract.
        //ListAbstractClassIds();
        AssignInheritance();
        DoOtherStuff();
    }

    /// <summary>
    /// Writes every abstract class ID to the console, marking whether all of its version entries are abstract
    /// (null entries count as abstract) or only some of them.
    /// </summary>
    private static void ListAbstractClassIds()
    {
        HashSet<int> abstractIds = GetAbstractClassIds();
        foreach (int abstractId in abstractIds.OrderBy(i => i))
        {
            VersionedList<UniversalClass> list = SharedState.Instance.ClassInformation[abstractId];
            if (list.All(c => c.Value?.IsAbstract ?? true))
            {
                Console.WriteLine($"\t{abstractId} abstract");
            }
            else
            {
                Console.WriteLine($"\t{abstractId} abstract sometimes");
            }
        }
    }

    /// <summary>
    /// Processes every abstract class bottom-up: a class is handled only when none of its abstract derived
    /// classes are still unprocessed. For each handled class, its abstract version ranges are divided into
    /// sections, shared fields are approved, compatible neighbouring sections are merged, and the class's
    /// versioned list is rewritten with cloned classes that carry the approved fields.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// Thrown when abstract classes remain after <see cref="MaxRunCount"/> passes.
    /// </exception>
    private static void DoOtherStuff()
    {
        HashSet<int> abstractIds = GetAbstractClassIds();
        int runCount = 0;
        while (abstractIds.Count > 0)
        {
            // Only fail when there is still work left after the maximum number of passes.
            if (runCount >= MaxRunCount)
            {
                throw new InvalidOperationException("Hit max run count");
            }
            runCount++;

            // Iterate over a snapshot because processed IDs are removed from the set inside the loop.
            foreach (int abstractId in abstractIds.ToList())
            {
                VersionedList<UniversalClass> abstractClassList = SharedState.Instance.ClassInformation[abstractId];
                if (abstractClassList.AnyDerivedAbstractAndUnprocessed())
                {
                    // Abstract derived classes must be finalized first; retry on a later pass.
                    continue;
                }

                RangeClassList rangeClassList = abstractClassList.MakeRangeClassList();
                List<Section> sections = rangeClassList.MakeSectionList();
                UnifySectionFields(sections, rangeClassList);
                sections = GetMergedSections(sections);
                sections.ReplaceClassesWithClones();
                sections.FixInheritance();
                sections.ApplyApprovedFields();
                abstractClassList.UpdateWithSectionData(sections);
                abstractIds.Remove(abstractId);
                processedClasses.Add(abstractId);
            }
        }
    }

    /// <summary>
    /// Gives every section its own <c>DeepClone()</c> of its class and swaps that clone into the base class's
    /// derived-class collection in place of the original.
    /// </summary>
    private static void ReplaceClassesWithClones(this List<Section> sectionList)
    {
        foreach (Section section in sectionList)
        {
            UniversalClass? baseClass = section.Class.BaseClass;
            // Remove is a no-op when an earlier section sharing the same class already swapped the original out.
            baseClass?.DerivedClasses.Remove(section.Class);
            section.Class = section.Class.DeepClone();
            baseClass?.DerivedClasses.Add(section.Class);
        }
    }

    /// <summary>
    /// Rebuilds the links between section classes and their derived classes. All existing links are cleared
    /// first; each derived class is then attached to the first section (in list order) that lists it.
    /// </summary>
    private static void FixInheritance(this List<Section> sectionList)
    {
        foreach (Section section in sectionList)
        {
            section.Class.DerivedClasses.Clear();
            foreach (UniversalClass derivedClass in section.DerivedClasses)
            {
                derivedClass.BaseClass = null;
            }
        }

        // A derived class can appear in several sections; only the first one claims it.
        // NOTE(review): later sections that also list it get no link to it — confirm this is intended.
        foreach (Section section in sectionList)
        {
            foreach (UniversalClass derivedClass in section.DerivedClasses)
            {
                if (derivedClass.BaseClass is null)
                {
                    derivedClass.BaseClass = section.Class;
                    section.Class.DerivedClasses.Add(derivedClass);
                }
            }
        }
    }

    /// <summary>
    /// Appends every approved release/editor field node to the root nodes of its section's class, creating the
    /// root nodes if needed. A field whose name the root already contains is not added a second time.
    /// </summary>
    private static void ApplyApprovedFields(this List<Section> sectionList)
    {
        foreach (Section section in sectionList)
        {
            section.Class.InitializeRootNodes();
            foreach ((string fieldName, (UniversalNode? releaseNode, UniversalNode? editorNode)) in section.ApprovedFields)
            {
                // GetAllFieldNames unions names over all abstract entries, so the root may already have this field.
                if (releaseNode is not null && !section.Class.ReleaseRootNode!.TryGetSubNodeByName(fieldName, out _))
                {
                    section.Class.ReleaseRootNode!.SubNodes.Add(releaseNode);
                }
                if (editorNode is not null && !section.Class.EditorRootNode!.TryGetSubNodeByName(fieldName, out _))
                {
                    section.Class.EditorRootNode!.SubNodes.Add(editorNode);
                }
            }
        }
    }

    /// <summary>
    /// Creates empty "Base" editor and release root nodes (type name = class name, version 1)
    /// for whichever of the two is currently null.
    /// </summary>
    private static void InitializeRootNodes(this UniversalClass abstractClass)
    {
        abstractClass.EditorRootNode ??= CreateRootNode(abstractClass.Name);
        abstractClass.ReleaseRootNode ??= CreateRootNode(abstractClass.Name);

        static UniversalNode CreateRootNode(string typeName) => new()
        {
            Name = "Base",
            OriginalName = "Base",
            TypeName = typeName,
            SubNodes = new(),
            Version = 1,
        };
    }

    /// <summary>
    /// Rewrites <paramref name="versionedList"/> so that each abstract entry is replaced by the classes of the
    /// sections covering its range. Null and non-abstract entries are copied through unchanged.
    /// </summary>
    /// <param name="versionedList">The abstract class's versioned list; cleared and refilled in place.</param>
    /// <param name="sectionList">Sections ordered by start version, each lying inside a single abstract entry's range.</param>
    /// <exception cref="InvalidOperationException">
    /// Thrown when a section straddles an entry boundary, or the sections do not line up with the abstract entries.
    /// </exception>
    private static void UpdateWithSectionData(this VersionedList<UniversalClass> versionedList, List<Section> sectionList)
    {
        List<KeyValuePair<UnityVersion, UniversalClass?>> originalList = versionedList.ToList();
        versionedList.Clear();
        int i = 0, j = 0;
        while (i < originalList.Count)
        {
            (UnityVersion originalStart, UniversalClass? originalClass) = originalList[i];
            if (originalClass is null || !originalClass.IsAbstract)
            {
                versionedList.Add(originalStart, originalClass);
                i++;
                continue;
            }

            if (j >= sectionList.Count)
            {
                throw new InvalidOperationException("Ran out of sections before all abstract entries were replaced.");
            }

            // An entry lasts until the next entry starts; the last one lasts until the maximum version.
            UnityVersion originalEnd = i == originalList.Count - 1 ? UnityVersion.MaxVersion : originalList[i + 1].Key;
            Section currentSection = sectionList[j];
            if (originalEnd <= currentSection.Range.Start)
            {
                // The current section belongs to a later entry; this abstract entry is dropped.
                i++;
            }
            else if (originalStart <= currentSection.Range.Start && currentSection.Range.End <= originalEnd)
            {
                versionedList.Add(currentSection.Range.Start, currentSection.Class);
                j++;
                if (originalEnd == currentSection.Range.End)
                {
                    i++;
                }
            }
            else
            {
                throw new InvalidOperationException();
            }
        }

        if (j < sectionList.Count)
        {
            throw new InvalidOperationException("Some sections lie beyond the end of the versioned list.");
        }
    }

    /// <summary>
    /// Merges runs of consecutive sections for which <see cref="CanBeMerged"/> holds into a single section
    /// spanning their union. The input sections are not modified.
    /// </summary>
    /// <returns>The merged sections, in the original order; empty when <paramref name="sectionList"/> is empty.</returns>
    private static List<Section> GetMergedSections(List<Section> sectionList)
    {
        List<Section> mergedSections = new();
        if (sectionList.Count == 0)
        {
            return mergedSections;
        }

        // Work on a clone so widening the range / derived set does not mutate the caller's sections.
        Section currentSection = sectionList[0].Clone();
        for (int i = 1; i < sectionList.Count; i++)
        {
            Section nextSection = sectionList[i];
            if (CanBeMerged(currentSection, nextSection))
            {
                currentSection.Range = currentSection.Range.MakeUnion(nextSection.Range);
                currentSection.DerivedClasses.UnionWith(nextSection.DerivedClasses);
            }
            else
            {
                mergedSections.Add(currentSection);
                currentSection = nextSection.Clone();
            }
        }
        mergedSections.Add(currentSection);
        return mergedSections;
    }

    /// <summary>
    /// Decides, for each candidate field name, whether it is approved in every section. A field is approved when
    /// (a) the average proportion of derived classes having it, over the sections that contain it, exceeds
    /// <see cref="MinimumProportion"/>, and (b) in every section, all derived classes that have the field agree
    /// with the first node found, for both release and editor roots.
    /// Approved pairs may hold null on either side when no derived class in that section has the field in that root.
    /// </summary>
    private static void UnifySectionFields(List<Section> sections, RangeClassList rangeClassList)
    {
        foreach (string fieldName in rangeClassList.GetAllFieldNames())
        {
            int count = 0;
            double proportionSum = 0;
            foreach (Section section in sections)
            {
                if (section.HasField(fieldName, out float proportion))
                {
                    count++;
                    proportionSum += proportion;
                }
            }

            // No section contains the field: nothing to approve (and 0/0 would yield NaN, which passes the threshold check).
            if (count == 0)
            {
                continue;
            }

            float averageProportion = (float)(proportionSum / count);
            if (averageProportion <= MinimumProportion)
            {
                continue;
            }

            bool useField = true;
            Dictionary<Section, (UniversalNode?, UniversalNode?)> nodeDictionary = new();
            foreach (Section section in sections)
            {
                UniversalNode? releaseNode = FindSharedNode(section, fieldName, c => c.ReleaseRootNode, out bool releaseConsistent);
                UniversalNode? editorNode = FindSharedNode(section, fieldName, c => c.EditorRootNode, out bool editorConsistent);
                if (!releaseConsistent || !editorConsistent)
                {
                    useField = false;
                    break;
                }
                nodeDictionary.Add(section, (releaseNode, editorNode));
            }

            if (useField)
            {
                foreach ((Section section, (UniversalNode?, UniversalNode?) pair) in nodeDictionary)
                {
                    section.ApprovedFields.Add(fieldName, pair);
                }
            }
        }
    }

    /// <summary>
    /// Finds the first sub node named <paramref name="fieldName"/> among the roots that
    /// <paramref name="getRoot"/> selects from the section's derived classes.
    /// </summary>
    /// <param name="isConsistent">
    /// False when some derived class has a field of that name that <c>UniversalNodeComparer.Equals</c> reports as different.
    /// </param>
    /// <returns>The first matching node, or null when no derived class has the field in that root.</returns>
    private static UniversalNode? FindSharedNode(Section section, string fieldName, Func<UniversalClass, UniversalNode?> getRoot, out bool isConsistent)
    {
        UniversalNode? sharedNode = section.DerivedClasses
            .SelectMany(c => getRoot(c)?.SubNodes ?? Enumerable.Empty<UniversalNode>())
            .FirstOrDefault(n => n.Name == fieldName);

        isConsistent = true;
        if (sharedNode is null)
        {
            return null;
        }

        foreach (UniversalClass derivedClass in section.DerivedClasses)
        {
            UniversalNode? root = getRoot(derivedClass);
            if (root is not null
                && root.TryGetSubNodeByName(fieldName, out UniversalNode? derivedNode)
                && !UniversalNodeComparer.Equals(sharedNode, derivedNode, false))
            {
                isConsistent = false;
                break;
            }
        }
        return sharedNode;
    }

    /// <summary>
    /// Cuts the abstract ranges at every boundary version from <see cref="GetAllUnityVersions"/> and creates one
    /// section per interval <c>[versions[j], versions[j + 1])</c> that lies inside an abstract range (the last
    /// interval ends at <see cref="UnityVersion.MaxVersion"/>).
    /// </summary>
    /// <exception cref="InvalidOperationException">Thrown if an interval only partially overlaps an abstract range.</exception>
    private static List<Section> MakeSectionList(this RangeClassList rangeClassList)
    {
        List<UnityVersion> versions = rangeClassList.GetAllUnityVersions();
        Dictionary<UniversalClass, UnityVersionRange> derivedRangeDictionary = rangeClassList
            .SelectMany(pair => pair.Value.DerivedClasses)
            .Distinct()
            .ToDictionary(derived => derived, derived => GetRangeForClass(derived));
        List<Section> sections = new();
        int i = 0, j = 0;
        // Two-pointer walk: i over abstract ranges, j over boundary intervals.
        while (i < rangeClassList.Count && j < versions.Count)
        {
            UnityVersionRange originalRange = rangeClassList[i].Key;
            UnityVersion currentStart = versions[j];
            UnityVersion currentEnd = j == versions.Count - 1 ? UnityVersion.MaxVersion : versions[j + 1];
            if (currentEnd <= originalRange.Start)
            {
                j++;
            }
            else if (originalRange.End <= currentStart)
            {
                i++;
            }
            else if (originalRange.Start <= currentStart && currentEnd <= originalRange.End)
            {
                sections.Add(Section.Create(rangeClassList[i].Value, new UnityVersionRange(currentStart, currentEnd), derivedRangeDictionary));
                j++;
            }
            else
            {
                throw new InvalidOperationException();
            }
        }
        return sections;
    }

    /// <summary>
    /// Collects the names of sub nodes that a derived class has in its release or editor root but that the
    /// abstract class it derives from (within this list) lacks in the corresponding root.
    /// </summary>
    private static HashSet<string> GetAllFieldNames(this RangeClassList rangeClassList)
    {
        HashSet<string> fieldNames = new();
        foreach ((_, UniversalClass universalClass) in rangeClassList)
        {
            foreach (UniversalClass derivedClass in universalClass.DerivedClasses)
            {
                AddNamesMissingFromBase(derivedClass.ReleaseRootNode, universalClass.ReleaseRootNode);
                AddNamesMissingFromBase(derivedClass.EditorRootNode, universalClass.EditorRootNode);
            }
        }
        return fieldNames;

        // Adds each sub node name of derivedRoot that baseRoot does not contain (all of them when baseRoot is null).
        void AddNamesMissingFromBase(UniversalNode? derivedRoot, UniversalNode? baseRoot)
        {
            if (derivedRoot is null)
            {
                return;
            }
            foreach (UniversalNode subnode in derivedRoot.SubNodes)
            {
                if (baseRoot is null || !baseRoot.TryGetSubNodeByName(subnode.Name, out _))
                {
                    fieldNames.Add(subnode.Name);
                }
            }
        }
    }

    /// <summary>
    /// Returns the sorted, distinct boundary versions: every abstract range's start and end, plus each derived
    /// class's start when it is after the first range's start and its end when it is before the last range's end.
    /// </summary>
    private static List<UnityVersion> GetAllUnityVersions(this RangeClassList rangeClassList)
    {
        UnityVersion minimumVersion = rangeClassList[0].Key.Start;
        UnityVersion maximumVersion = rangeClassList[rangeClassList.Count - 1].Key.End;
        HashSet<UnityVersion> versionHashSet = new();
        foreach ((UnityVersionRange range, UniversalClass universalClass) in rangeClassList)
        {
            versionHashSet.Add(range.Start);
            versionHashSet.Add(range.End);
            foreach (UniversalClass derivedClass in universalClass.DerivedClasses)
            {
                UnityVersionRange derivedRange = GetRangeForClass(derivedClass);
                if (minimumVersion < derivedRange.Start)
                {
                    versionHashSet.Add(derivedRange.Start);
                }
                if (derivedRange.End < maximumVersion)
                {
                    versionHashSet.Add(derivedRange.End);
                }
            }
        }
        List<UnityVersion> versionList = versionHashSet.ToList();
        versionList.Sort();
        return versionList;
    }

    /// <summary>
    /// Looks up the version range over which <paramref name="universalClass"/> is the active entry of its type ID.
    /// </summary>
    private static UnityVersionRange GetRangeForClass(UniversalClass universalClass)
    {
        return SharedState.Instance.ClassInformation[universalClass.TypeID].GetRangeForItem(universalClass);
    }

    /// <summary>
    /// Returns the (range, class) pairs of the non-null abstract entries of <paramref name="abstractClassList"/>, in list order.
    /// </summary>
    private static RangeClassList MakeRangeClassList(this VersionedList<UniversalClass> abstractClassList)
    {
        RangeClassList result = new();
        for (int i = 0; i < abstractClassList.Count; i++)
        {
            UniversalClass? universalClass = abstractClassList[i].Value;
            if (universalClass is not null && universalClass.IsAbstract)
            {
                result.Add(new KeyValuePair<UnityVersionRange, UniversalClass>(abstractClassList.GetRange(i), universalClass));
            }
        }
        return result;
    }

    /// <summary>
    /// True when any abstract entry of the list has a derived class that is itself abstract and whose type ID
    /// is not yet in <see cref="processedClasses"/>.
    /// </summary>
    private static bool AnyDerivedAbstractAndUnprocessed(this VersionedList<UniversalClass> abstractClassList)
    {
        return abstractClassList
            .Select(pair => pair.Value)
            .Where(universalClass => universalClass is not null && universalClass.IsAbstract)
            .SelectMany(universalClass => universalClass!.DerivedClasses)
            .Any(derivedClass => derivedClass.IsAbstract && !processedClasses.Contains(derivedClass.TypeID));
    }

    /// <summary>
    /// Returns the type IDs that have at least one non-null abstract entry.
    /// </summary>
    private static HashSet<int> GetAbstractClassIds()
    {
        return SharedState.Instance.ClassInformation
            .Where(dictPair => dictPair.Value.Any(listPair => listPair.Value?.IsAbstract ?? false))
            .Select(pair => pair.Key)
            .ToHashSet();
    }

    /// <summary>
    /// For every non-null class with a non-empty <c>BaseString</c>, resolves the base class by name at the
    /// class's start version and links the two in both directions.
    /// </summary>
    private static void AssignInheritance()
    {
        foreach ((_, VersionedList<UniversalClass> list) in SharedState.Instance.ClassInformation)
        {
            foreach ((UnityVersion startVersion, UniversalClass? universalClass) in list)
            {
                if (!string.IsNullOrEmpty(universalClass?.BaseString))
                {
                    UniversalClass baseClass = GetClass(universalClass.BaseString, startVersion);
                    universalClass.BaseClass = baseClass;
                    baseClass.DerivedClasses.Add(universalClass);
                }
            }
        }
    }

    /// <summary>
    /// Finds the single class named <paramref name="name"/> that is active at <paramref name="version"/>
    /// among the type IDs registered for that name.
    /// </summary>
    /// <exception cref="InvalidOperationException">Thrown by <c>Single</c> when zero or several classes match.</exception>
    private static UniversalClass GetClass(string name, UnityVersion version)
    {
        return SharedState.Instance.NameToTypeID[name]
            .Select(id => SharedState.Instance.ClassInformation[id].TryFindMatch(name, version))
            .Where(c => c is not null)
            .Single()!;
    }

    /// <summary>
    /// Returns the entry active at <paramref name="version"/> if it exists and is named <paramref name="name"/>; otherwise null.
    /// </summary>
    private static UniversalClass? TryFindMatch(this VersionedList<UniversalClass> list, string name, UnityVersion version)
    {
        UniversalClass? result = list.GetItemForVersion(version);
        return result is not null && result.Name == name ? result : null;
    }

    /// <summary>
    /// Two sections can be merged when they share the same class, their ranges can be unioned, and they have
    /// exactly the same approved field names with equal release and editor nodes.
    /// </summary>
    private static bool CanBeMerged(Section section1, Section section2)
    {
        if (!section1.Class.Equals(section2.Class))
        {
            return false;
        }
        if (!section1.Range.CanUnion(section2.Range))
        {
            return false;
        }
        if (section1.ApprovedFields.Count != section2.ApprovedFields.Count)
        {
            return false;
        }
        foreach ((string fieldName, (UniversalNode? releaseNode1, UniversalNode? editorNode1)) in section1.ApprovedFields)
        {
            // Equal counts do not imply equal key sets: a field missing from section2 must block the merge.
            if (!section2.ApprovedFields.TryGetValue(fieldName, out (UniversalNode?, UniversalNode?) pair2)
                || !UniversalNodeComparer.Equals(releaseNode1, pair2.Item1, false)
                || !UniversalNodeComparer.Equals(editorNode1, pair2.Item2, false))
            {
                return false;
            }
        }
        return true;
    }

    /// <summary>
    /// A version range of one abstract class, the derived classes whose own ranges contain it, and the fields
    /// approved for hoisting into the abstract class over that range.
    /// </summary>
    private sealed class Section
    {
        public Section(UniversalClass @class, UnityVersionRange range, HashSet<UniversalClass> derivedClasses)
        {
            Class = @class;
            Range = range;
            DerivedClasses = derivedClasses;
        }

        /// <summary>
        /// Creates a section for <paramref name="class"/> over <paramref name="range"/>, keeping only the derived
        /// classes whose range in <paramref name="dictionary"/> contains <paramref name="range"/>.
        /// </summary>
        public static Section Create(UniversalClass @class, UnityVersionRange range, Dictionary<UniversalClass, UnityVersionRange> dictionary)
        {
            HashSet<UniversalClass> derivedClasses = @class.DerivedClasses.Where(derived => dictionary[derived].Contains(range)).ToHashSet();
            return new Section(@class, range, derivedClasses);
        }

        public UniversalClass Class { get; set; }

        public UnityVersionRange Range { get; set; }

        public HashSet<UniversalClass> DerivedClasses { get; }

        /// <summary>
        /// FieldName : (ReleaseNode?, EditorNode?)
        /// </summary>
        public Dictionary<string, (UniversalNode?, UniversalNode?)> ApprovedFields { get; } = new();

        /// <summary>
        /// Copies the section. The derived-class set and approved-field dictionary are new collections;
        /// the class and node objects they reference are shared.
        /// </summary>
        public Section Clone()
        {
            Section newSection = new Section(Class, Range, new HashSet<UniversalClass>(DerivedClasses));
            foreach ((string fieldName, (UniversalNode?, UniversalNode?) pair) in ApprovedFields)
            {
                newSection.ApprovedFields.Add(fieldName, pair);
            }
            return newSection;
        }

        /// <summary>
        /// Reports whether any derived class has <paramref name="fieldName"/> in its release or editor root.
        /// </summary>
        /// <param name="proportion">Fraction of derived classes having the field; 0 when there are no derived classes.</param>
        public bool HasField(string fieldName, out float proportion)
        {
            int sum = 0;
            foreach (UniversalClass derivedClass in DerivedClasses)
            {
                if (derivedClass.ReleaseRootNode?.TryGetSubNodeByName(fieldName) is not null
                    || derivedClass.EditorRootNode?.TryGetSubNodeByName(fieldName) is not null)
                {
                    sum++;
                }
            }
            proportion = DerivedClasses.Count == 0 ? 0f : (float)sum / DerivedClasses.Count;
            return sum > 0;
        }

        public override string ToString() => $"{Class.Name} {Range}";
    }
}