Smart enum generator

This commit is contained in:
ds5678 2025-03-30 22:42:03 -07:00
parent fab3242a75
commit cfd2d822c7
9 changed files with 476 additions and 34 deletions

View File

@ -1,9 +1,10 @@
<Project Sdk="Microsoft.NET.Sdk">
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<IsTrimmable>true</IsTrimmable>
<OutputPath>..\0Bins\Other\AssetRipper.IO.Files\$(Configuration)\</OutputPath>
<IntermediateOutputPath>..\0Bins\obj\AssetRipper.IO.Files\$(Configuration)\</IntermediateOutputPath>
<EmitCompilerGeneratedFiles>true</EmitCompilerGeneratedFiles>
</PropertyGroup>
<ItemGroup>
@ -14,4 +15,8 @@
<PackageReference Include="ZstdSharp.Port" Version="0.8.5" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\AssetRipper.SmartEnums\AssetRipper.SmartEnums.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="false" />
</ItemGroup>
</Project>

View File

@ -1,9 +1,14 @@
namespace AssetRipper.IO.Files
using AssetRipper.SmartEnums;
namespace AssetRipper.IO.Files;
[SmartEnum]
public readonly partial record struct BuildTarget
{
/// <summary>
/// <see href="https://github.com/Unity-Technologies/UnityCsReference/blob/master/Editor/Mono/BuildTarget.cs"/>
/// </summary>
public enum BuildTarget : uint
private enum Internal : uint
{
ValidPlayer = 1,
/// <summary>
@ -146,29 +151,30 @@ namespace AssetRipper.IO.Files
AnyPlayer = 0xFFFFFFFF,
}
public static class PlatformExtensions
public bool IsStandalone
{
public static bool IsCompatible(this BuildTarget _this, BuildTarget comp)
get
{
return _this == comp || (_this.IsStandalone() && comp.IsStandalone());
}
public static bool IsStandalone(this BuildTarget _this)
{
switch (_this)
switch (this)
{
case BuildTarget.StandaloneWinPlayer:
case BuildTarget.StandaloneWin64Player:
case BuildTarget.StandaloneLinux:
case BuildTarget.StandaloneLinux64:
case BuildTarget.StandaloneLinuxUniversal:
case BuildTarget.StandaloneOSXIntel:
case BuildTarget.StandaloneOSXIntel64:
case BuildTarget.StandaloneOSXPPC:
case BuildTarget.StandaloneOSXUniversal:
case StandaloneWinPlayer:
case StandaloneWin64Player:
case StandaloneLinux:
case StandaloneLinux64:
case StandaloneLinuxUniversal:
case StandaloneOSXIntel:
case StandaloneOSXIntel64:
case StandaloneOSXPPC:
case StandaloneOSXUniversal:
return true;
default:
return false;
}
return false;
}
}
public bool IsCompatible(BuildTarget comp)
{
return this == comp || (IsStandalone && comp.IsStandalone);
}
}

View File

@ -32,11 +32,11 @@ namespace AssetRipper.IO.Files.BundleFiles.FileStream
{
get
{
return Flags.GetCompression();
return Flags.CompressionType;
}
private set
{
Flags = (Flags & ~StorageBlockFlags.CompressionTypeMask) | (StorageBlockFlags.CompressionTypeMask & (StorageBlockFlags)value);
Flags = Flags.WithCompressionType(value);
}
}
}

View File

@ -1,23 +1,45 @@
namespace AssetRipper.IO.Files.BundleFiles.FileStream
using AssetRipper.SmartEnums;
namespace AssetRipper.IO.Files.BundleFiles.FileStream;
[SmartEnum]
public readonly partial record struct StorageBlockFlags
{
[Flags]
public enum StorageBlockFlags
private enum Internal
{
CompressionTypeMask = 0x3F,
Streamed = 0x40,
}
public static class StorageBlockFlagsExtensions
public CompressionType CompressionType
{
public static CompressionType GetCompression(this StorageBlockFlags _this)
get
{
return (CompressionType)(_this & StorageBlockFlags.CompressionTypeMask);
}
public static bool IsStreamed(this StorageBlockFlags _this)
{
return (_this & StorageBlockFlags.Streamed) != 0;
return (CompressionType)(this & CompressionTypeMask);
}
}
public bool IsStreamed
{
get
{
return (this & Streamed) != 0;
}
}
public StorageBlockFlags WithCompressionType(CompressionType compressionType)
{
return (this & ~CompressionTypeMask) | (StorageBlockFlags)compressionType;
}
public static explicit operator StorageBlockFlags(CompressionType compressionType)
{
return (StorageBlockFlags)(int)compressionType;
}
public static explicit operator CompressionType(StorageBlockFlags flags)
{
return (CompressionType)(int)(flags);
}
}

View File

@ -0,0 +1,29 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>netstandard2.0</TargetFramework>
<IsAotCompatible>false</IsAotCompatible>
<AppendTargetFrameworkToOutputPath>true</AppendTargetFrameworkToOutputPath>
<EnforceExtendedAnalyzerRules>true</EnforceExtendedAnalyzerRules>
<PolySharpExcludeGeneratedTypes>System.Runtime.CompilerServices.ModuleInitializerAttribute</PolySharpExcludeGeneratedTypes>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Microsoft.CodeAnalysis.Analyzers" Version="3.11.0">
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>
<PackageReference Include="Microsoft.CodeAnalysis.CSharp" Version="4.13.0" />
<PackageReference Include="PolySharp" Version="1.15.0">
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>
<PackageReference Include="SourceGenerator.Foundations" Version="2.0.13" />
</ItemGroup>
<ItemGroup>
<!-- Generator dependencies -->
<PackageReference Include="AssetRipper.Text.SourceGeneration" Version="1.2.2" PrivateAssets="all" GeneratePathProperty="true" />
</ItemGroup>
</Project>

View File

@ -0,0 +1,108 @@
# AssetRipper.SmartEnums
User code:
```cs
[SmartEnum]
public readonly partial record struct MyEnum
{
private enum __Internal : uint
{
Value0,
Value1
}
public void CustomMethod() {}
}
```
Generated code:
```cs
readonly partial record struct MyEnum :
global::System.IParsable<MyEnum>,
global::System.Numerics.IBitwiseOperators<MyEnum, MyEnum, MyEnum>,
global::System.Numerics.IComparisonOperators<MyEnum, MyEnum, bool>,
global::System.Numerics.IEqualityOperators<MyEnum, MyEnum, bool>,
global::System.Numerics.IShiftOperators<MyEnum, int, MyEnum>
{
private readonly uint __value;
public MyEnum(uint value) => __value = value;
/// <inheritdoc cref="__Internal.Value0"/>
public const uint Value0 = (uint)__Internal.Value0;
/// <inheritdoc cref="__Internal.Value1"/>
public const uint Value1 = (uint)__Internal.Value1;
public static implicit operator uint(MyEnum value) => value.__value;
public static implicit operator MyEnum(uint value) => new(value);
public static MyEnum operator &(MyEnum left, MyEnum right) => new(left.__value & right.__value);
public static MyEnum operator |(MyEnum left, MyEnum right) => new(left.__value | right.__value);
public static MyEnum operator ^(MyEnum left, MyEnum right) => new(left.__value ^ right.__value);
public static MyEnum operator ~(MyEnum value) => new(~value.__value);
public static bool operator <(MyEnum left, MyEnum right) => left.__value < right.__value;
public static bool operator >(MyEnum left, MyEnum right) => left.__value > right.__value;
public static bool operator <=(MyEnum left, MyEnum right) => left.__value <= right.__value;
public static bool operator >=(MyEnum left, MyEnum right) => left.__value >= right.__value;
public static MyEnum operator <<(MyEnum value, int count) => new(value.__value << count);
public static MyEnum operator >>(MyEnum value, int count) => new(value.__value >> count);
public static MyEnum operator >>>(MyEnum value, int count) => new(value.__value >>> count);
public override string ToString() => __value switch
{
Value0 => nameof(Value0),
Value1 => nameof(Value1),
_ => __value.ToString(),
};
public static MyEnum Parse(string s) => Parse(s, null);
public static MyEnum Parse(string s, IFormatProvider? provider) => s switch
{
nameof(MyEnum.Value0) => MyEnum.Value0,
nameof(MyEnum.Value1) => MyEnum.Value1,
_ => uint.Parse(s, provider),
};
public static bool TryParse(string? s, out MyEnum result) => TryParse(s, null, out result);
public static bool TryParse(string? s, IFormatProvider? provider, out MyEnum result)
{
switch (s)
{
case nameof(MyEnum.Value0):
result = MyEnum.Value0;
return true;
case nameof(MyEnum.Value1):
result = MyEnum.Value1;
return true;
default:
if (uint.TryParse(s, provider, out uint value))
{
result = new(value);
return true;
}
result = default;
return false;
}
}
public static global::System.ReadOnlySpan<MyEnum> GetValues() => __ValueCache.Values;
public static global::System.ReadOnlySpan<uint> GetUnderlyingValues() =>
[
Value0,
Value1,
];
public uint GetUnderlyingValue() => __value;
}
file static class __ValueCache
{
public static MyEnum[] Values = new MyEnum[]
{
MyEnum.Value0,
MyEnum.Value1,
};
}
```

View File

@ -0,0 +1,18 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
namespace AssetRipper.SmartEnums;
internal static class RoslynExtensions
{
public static bool IsPartial(this MemberDeclarationSyntax c)
{
return c.Modifiers.Any(SyntaxKind.PartialKeyword);
}
public static bool IsReadOnly(this MemberDeclarationSyntax c)
{
return c.Modifiers.Any(SyntaxKind.ReadOnlyKeyword);
}
}

View File

@ -0,0 +1,24 @@
using System.Collections;
namespace AssetRipper.SmartEnums;
internal readonly struct Sequence<T>(T[] values) : IEquatable<Sequence<T>>, IReadOnlyList<T>
{
public T this[int index] => Values[index];
public T[] Values { get; } = values;
public int Count => Values.Length;
public bool Equals(Sequence<T> other) => Values.SequenceEqual(other.Values);
public override bool Equals(object obj) => obj is Sequence<T> other && Equals(other);
public override int GetHashCode() => Values.Length; // Doesn't need to be good
public static implicit operator Sequence<T>(T[] values) => new(values);
public IEnumerator<T> GetEnumerator() => ((IEnumerable<T>)Values).GetEnumerator();
IEnumerator IEnumerable.GetEnumerator() => Values.GetEnumerator();
}

View File

@ -0,0 +1,230 @@
using AssetRipper.Text.SourceGeneration;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using SGF;
using System.CodeDom.Compiler;
namespace AssetRipper.SmartEnums;
[IncrementalGenerator]
public sealed class SmartEnumGenerator() : IncrementalGenerator(nameof(SmartEnumGenerator))
{
private readonly record struct SmartEnumData(string Namespace, string StructName, string EnumName, string? EnumType, Sequence<string> EnumValues)
{
public IEnumerable<string> GetAllNames()
{
yield return StructName;
yield return EnumName;
foreach (string enumValue in EnumValues)
{
yield return enumValue;
}
}
}
public override void OnInitialize(SgfInitializationContext context)
{
context.RegisterPostInitializationOutput(InjectAttribute);
IncrementalValuesProvider<SmartEnumData> valueProvider = context.SyntaxProvider.ForAttributeWithMetadataName("AssetRipper.SmartEnums.SmartEnumAttribute",
static (node, ct) =>
{
return node is RecordDeclarationSyntax s
&& s.IsPartial()
&& s.IsReadOnly()
&& s.Parent is BaseNamespaceDeclarationSyntax
&& s.ChildNodes().OfType<EnumDeclarationSyntax>().Count() == 1;
},
static (context, ct) =>
{
RecordDeclarationSyntax structDeclaration = (RecordDeclarationSyntax)context.TargetNode;
BaseNamespaceDeclarationSyntax namespaceDeclaration = (BaseNamespaceDeclarationSyntax)structDeclaration.Parent!;
EnumDeclarationSyntax enumDeclaration = structDeclaration.ChildNodes().OfType<EnumDeclarationSyntax>().Single();
string structName = structDeclaration.Identifier.Text;
string @namespace = namespaceDeclaration.Name.ToString();
string enumName = enumDeclaration.Identifier.Text;
string? enumType = enumDeclaration.BaseList?.Types.Single().ToString();
string[] enumValues = enumDeclaration.Members.OfType<EnumMemberDeclarationSyntax>().Select(e => e.Identifier.Text).ToArray();
return new SmartEnumData(@namespace, structName, enumName, enumType, enumValues);
});
context.RegisterSourceOutput(valueProvider, GenerateEnum);
}
private static void GenerateEnum(SgfSourceProductionContext context, SmartEnumData enumData)
{
StringWriter stringWriter = new() { NewLine = "\n" };
using IndentedTextWriter writer = IndentedTextWriterFactory.Create(stringWriter);
string structName = enumData.StructName;
string enumName = enumData.EnumName;
string enumType = enumData.EnumType ?? "int";
string[] enumValues = enumData.EnumValues.Values;
string valueName = GetNonConflictingName("__value", enumData.GetAllNames());
writer.WriteGeneratedCodeWarning();
writer.WriteLineNoTabs();
writer.WriteFileScopedNamespace(enumData.Namespace);
writer.WriteLineNoTabs();
writer.WriteLine($"readonly partial record struct {structName} :");
using (new Indented(writer))
{
writer.WriteLine($"global::System.IParsable<{structName}>,");
writer.WriteLine($"global::System.Numerics.IBitwiseOperators<{structName}, {structName}, {structName}>,");
writer.WriteLine($"global::System.Numerics.IComparisonOperators<{structName}, {structName}, bool>,");
writer.WriteLine($"global::System.Numerics.IEqualityOperators<{structName}, {structName}, bool>,");
writer.WriteLine($"global::System.Numerics.IShiftOperators<{structName}, int, {structName}>");
}
using (new CurlyBrackets(writer))
{
writer.WriteLine($"private readonly {enumType} {valueName};");
writer.WriteLineNoTabs();
// Constructor
writer.WriteLine($"public {structName}({enumType} value) => {valueName} = value;");
writer.WriteLineNoTabs();
foreach (string enumValue in enumValues)
{
writer.WriteInheritDocumentation($"{enumName}.{enumValue}");
writer.WriteLine($"public const {enumType} {enumValue} = ({enumType}){enumName}.{enumValue};");
writer.WriteLineNoTabs();
}
// Implicit operators
writer.WriteLine($"public static implicit operator {enumType}({structName} value) => value.{valueName};");
writer.WriteLine($"public static implicit operator {structName}({enumType} value) => new(value);");
writer.WriteLineNoTabs();
// Bitwise operators
writer.WriteLine($"public static {structName} operator &({structName} left, {structName} right) => new(left.{valueName} & right.{valueName});");
writer.WriteLine($"public static {structName} operator |({structName} left, {structName} right) => new(left.{valueName} | right.{valueName});");
writer.WriteLine($"public static {structName} operator ^({structName} left, {structName} right) => new(left.{valueName} ^ right.{valueName});");
writer.WriteLine($"public static {structName} operator ~({structName} value) => new(~value.{valueName});");
writer.WriteLineNoTabs();
// Comparison operators
writer.WriteLine($"public static bool operator <({structName} left, {structName} right) => left.{valueName} < right.{valueName};");
writer.WriteLine($"public static bool operator >({structName} left, {structName} right) => left.{valueName} > right.{valueName};");
writer.WriteLine($"public static bool operator <=({structName} left, {structName} right) => left.{valueName} <= right.{valueName};");
writer.WriteLine($"public static bool operator >=({structName} left, {structName} right) => left.{valueName} >= right.{valueName};");
writer.WriteLineNoTabs();
// Shift operators
writer.WriteLine($"public static {structName} operator <<({structName} value, int count) => new(value.{valueName} << count);");
writer.WriteLine($"public static {structName} operator >>({structName} value, int count) => new(value.{valueName} >> count);");
writer.WriteLine($"public static {structName} operator >>>({structName} value, int count) => new(value.{valueName} >>> count);");
writer.WriteLineNoTabs();
// ToString
writer.WriteLine($"public override string ToString() => {valueName} switch");
using (new CurlyBracketsWithSemicolon(writer))
{
foreach (string enumValue in enumValues)
{
writer.WriteLine($"{enumValue} => nameof({enumValue}),");
}
writer.WriteLine($"_ => {valueName}.ToString(),");
}
writer.WriteLineNoTabs();
// IParsable
writer.WriteLine($"public static {structName} Parse(string s) => Parse(s, null);");
writer.WriteLine($"public static {structName} Parse(string s, IFormatProvider? provider) => s switch");
using (new CurlyBracketsWithSemicolon(writer))
{
foreach (string enumValue in enumValues)
{
writer.WriteLine($"nameof({structName}.{enumValue}) => {structName}.{enumValue},");
}
writer.WriteLine($"_ => {enumType}.Parse(s, provider),");
}
writer.WriteLine($"public static bool TryParse(string? s, out {structName} result) => TryParse(s, null, out result);");
writer.WriteLine($"public static bool TryParse(string? s, IFormatProvider? provider, out {structName} result)");
using (new CurlyBrackets(writer))
{
writer.WriteLine("switch (s)");
using (new CurlyBrackets(writer))
{
foreach (string enumValue in enumValues)
{
writer.WriteLine($"case nameof({structName}.{enumValue}):");
using (new Indented(writer))
{
writer.WriteLine($"result = {structName}.{enumValue};");
writer.WriteLine("return true;");
}
}
writer.WriteLine("default:");
using (new Indented(writer))
{
writer.WriteLine($"if ({enumType}.TryParse(s, provider, out {enumType} value))");
using (new CurlyBrackets(writer))
{
writer.WriteLine("result = new(value);");
writer.WriteLine("return true;");
}
writer.WriteLine("result = default;");
writer.WriteLine("return false;");
}
}
}
writer.WriteLineNoTabs();
writer.WriteLine($"public static global::System.ReadOnlySpan<{structName}> GetValues() => __ValueCache.Values;");
writer.WriteLine($"public static global::System.ReadOnlySpan<{enumType}> GetUnderlyingValues() =>");
writer.WriteLine('[');
using (new Indented(writer))
{
foreach (string enumValue in enumValues)
{
writer.WriteLine($"{enumValue},");
}
}
writer.WriteLine("];");
writer.WriteLine($"public {enumType} GetUnderlyingValue() => {valueName};");
}
writer.WriteLine("file static class __ValueCache");
using (new CurlyBrackets(writer))
{
writer.WriteLine($"public static {structName}[] Values = new {structName}[]");
using (new CurlyBracketsWithSemicolon(writer))
{
foreach (string enumValue in enumValues)
{
writer.WriteLine($"{structName}.{enumValue},");
}
}
}
context.AddSource($"{structName}.cs", stringWriter.ToString());
}
private static void InjectAttribute(IncrementalGeneratorPostInitializationContext context)
{
context.AddSource("SmartEnumAttribute.cs", """
namespace AssetRipper.SmartEnums;
//[global::Microsoft.CodeAnalysis.Embedded]
[global::System.AttributeUsage(global::System.AttributeTargets.Struct)]
internal sealed class SmartEnumAttribute : global::System.Attribute
{
}
""");
//context.AddEmbeddedAttributeDefinition(); //To do: this api isn't available yet
//https://github.com/dotnet/roslyn/blob/main/docs/features/incremental-generators.cookbook.md#put-microsoftcodeanalysisembeddedattribute-on-generated-marker-types
}
private static string GetNonConflictingName(string name, IEnumerable<string> existingNames)
{
while (existingNames.Contains(name))
{
name = "_" + name;
}
return name;
}
}