AssetEqualityComparer

This commit is contained in:
ds5678 2024-07-02 22:00:51 -07:00
parent 3c9275f31d
commit c25b083e6c
6 changed files with 181 additions and 3 deletions

View File

@ -7,7 +7,7 @@
</PropertyGroup>
<ItemGroup>
<PackageReference Include="AssetRipper.SourceGenerated" Version="1.0.17" />
<PackageReference Include="AssetRipper.SourceGenerated" Version="1.0.18" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.10.0" />
<PackageReference Include="NUnit" Version="4.1.0" />
<PackageReference Include="NUnit3TestAdapter" Version="4.5.0" />

View File

@ -0,0 +1,161 @@
using AssetRipper.Assets.Metadata;
namespace AssetRipper.Assets.Cloning
{
public sealed class AssetEqualityComparer : IEqualityComparer<IUnityObjectBase>
{
private static readonly Dictionary<UnorderedPair, bool> compareCache = new();
private static readonly Dictionary<UnorderedPair, List<UnorderedPair>> dependentEqualityPairs = new();
public IUnityObjectBase CallingObject { get; private set; } = default!;
public IUnityObjectBase OtherObject { get; private set; } = default!;
/// <summary>
/// Used for source generation.
/// </summary>
/// <param name="pptrFromCallingObject"></param>
/// <param name="pptrFromOtherObject"></param>
/// <returns>True if they're equal, false if they're inequal, or null if it was added to the list of dependent pairs.</returns>
public bool? MaybeAddDependentComparison(IPPtr pptrFromCallingObject, IPPtr pptrFromOtherObject)
{
IUnityObjectBase? x = CallingObject.Collection.TryGetAsset(pptrFromCallingObject.FileID, pptrFromCallingObject.PathID);
IUnityObjectBase? y = OtherObject.Collection.TryGetAsset(pptrFromOtherObject.FileID, pptrFromOtherObject.PathID);
if (ReferenceEquals(x, y)) //Both null or both same instance
{
return true;
}
else if (x is null || y is null || x.GetType() != y.GetType())
{
return false;
}
else if (compareCache.TryGetValue((x, y), out bool value))
{
return value;
}
else
{
if (dependentEqualityPairs.TryGetValue((CallingObject, OtherObject), out List<UnorderedPair>? list))
{
list.Add((x, y));
}
else
{
list = [(x, y)];
dependentEqualityPairs.Add((CallingObject, OtherObject), list);
}
return null;
}
}
public bool Equals(IUnityObjectBase? x, IUnityObjectBase? y)
{
if (ReferenceEquals(x, y)) //Both null or both same instance
{
return true;
}
else if (x is null || y is null || x.GetType() != y.GetType())
{
return false;
}
DoComparison(x, y);
EvaluateDependentEqualityComparisons();
return compareCache[(x, y)];
}
private void EvaluateDependentEqualityComparisons()
{
if (dependentEqualityPairs.Count == 0)
{
return;
}
List<UnorderedPair> pairsToCompare = new();
bool hasChanged;
do
{
hasChanged = false;
pairsToCompare.Clear();
foreach ((UnorderedPair keyPair, List<UnorderedPair> list) in dependentEqualityPairs.ToArray())
{
for (int i = list.Count - 1; i >= 0; i--)
{
UnorderedPair valuePair = list[i];
if (compareCache.TryGetValue(valuePair, out bool value))
{
compareCache[keyPair] = value;
dependentEqualityPairs.Remove(keyPair);
hasChanged = true;
break;
}
else if (!dependentEqualityPairs.ContainsKey(valuePair))
{
pairsToCompare.Add(valuePair);
}
}
}
if (pairsToCompare.Count > 0)
{
hasChanged = true;
foreach (UnorderedPair pair in pairsToCompare)
{
DoComparison(pair.First, pair.Second);
}
}
} while (hasChanged);
}
private void DoComparison(IUnityObjectBase x, IUnityObjectBase y)
{
CallingObject = x;
OtherObject = y;
bool? result = x.AddToEqualityComparer(y, this);
if (result is { } value)
{
dependentEqualityPairs.Remove((x, y));
compareCache[(x, y)] = value;
}
CallingObject = default!;
OtherObject = default!;
}
public int GetHashCode(IUnityObjectBase? obj)
{
if (obj == null)
{
return 0;
}
return HashCode.Combine(obj.GetType(), obj.GetBestName());
}
private readonly record struct UnorderedPair(IUnityObjectBase First, IUnityObjectBase Second)
{
public bool Equals(UnorderedPair other)
{
return (First == other.First && Second == other.Second) || (First == other.Second && Second == other.First);
}
public override int GetHashCode()
{
return First.GetHashCode() ^ Second.GetHashCode();
}
public static implicit operator UnorderedPair((IUnityObjectBase First, IUnityObjectBase Second) pair)
{
return new(pair.First, pair.Second);
}
public static implicit operator (IUnityObjectBase, IUnityObjectBase)(UnorderedPair pair)
{
return (pair.First, pair.Second);
}
}
}
}

View File

@ -1,6 +1,6 @@
namespace AssetRipper.Assets.Generics
{
public sealed class AssetPair<TKey, TValue> : AccessPairBase<TKey, TValue>
public sealed class AssetPair<TKey, TValue> : AccessPairBase<TKey, TValue>, IEquatable<AssetPair<TKey, TValue>>
where TKey : notnull, new()
where TValue : notnull, new()
{
@ -49,6 +49,8 @@
}
}
bool IEquatable<AssetPair<TKey, TValue>>.Equals(AssetPair<TKey, TValue>? other) => Equals(other);
public static implicit operator KeyValuePair<TKey, TValue>(AssetPair<TKey, TValue> pair)
{
return new KeyValuePair<TKey, TValue>(pair.Key, pair.Value);

View File

@ -38,6 +38,16 @@ public interface IUnityAssetBase : IEndianSpanReadable, IAssetWritable
/// <param name="walker">A walker for traversal.</param>
void WalkStandard(AssetWalker walker);
IEnumerable<(string, PPtr)> FetchDependencies();
/// <summary>
/// Compares this object to another object for deep value equality.
/// </summary>
/// <remarks>
/// <paramref name="other"/> is expected to be not null and of the same type as this object.
/// </remarks>
/// <param name="other">The other object.</param>
/// <param name="comparer">The <see cref="AssetEqualityComparer"/> to which any dependent comparisons are added.</param>
/// <returns>Null if it could not be immediately determined</returns>
bool? AddToEqualityComparer(IUnityAssetBase other, AssetEqualityComparer comparer);
}
public static class UnityAssetBaseExtensions
{

View File

@ -55,6 +55,11 @@ public abstract class UnityAssetBase : IUnityAssetBase
}
}
public virtual bool? AddToEqualityComparer(IUnityAssetBase other, AssetEqualityComparer comparer)
{
throw MethodNotSupported();
}
private NotSupportedException MethodNotSupported([CallerMemberName] string? methodName = null)
{
return new NotSupportedException($"{methodName} is not supported for {GetType().FullName}");

View File

@ -13,7 +13,7 @@
<ItemGroup>
<PackageReference Include="AssetRipper.Checksum" Version="1.0.0" />
<PackageReference Include="AssetRipper.SourceGenerated" Version="1.0.17" />
<PackageReference Include="AssetRipper.SourceGenerated" Version="1.0.18" />
</ItemGroup>
</Project>