diff --git a/Source/AssetRipper.IO.Files/AssetRipper.IO.Files.csproj b/Source/AssetRipper.IO.Files/AssetRipper.IO.Files.csproj index 0b06221df..f1769715a 100644 --- a/Source/AssetRipper.IO.Files/AssetRipper.IO.Files.csproj +++ b/Source/AssetRipper.IO.Files/AssetRipper.IO.Files.csproj @@ -1,9 +1,10 @@ - + true ..\0Bins\Other\AssetRipper.IO.Files\$(Configuration)\ ..\0Bins\obj\AssetRipper.IO.Files\$(Configuration)\ + true @@ -14,4 +15,8 @@ + + + + diff --git a/Source/AssetRipper.IO.Files/BuildTarget.cs b/Source/AssetRipper.IO.Files/BuildTarget.cs index d0f6984e8..2a04571c1 100644 --- a/Source/AssetRipper.IO.Files/BuildTarget.cs +++ b/Source/AssetRipper.IO.Files/BuildTarget.cs @@ -1,9 +1,14 @@ -namespace AssetRipper.IO.Files +using AssetRipper.SmartEnums; + +namespace AssetRipper.IO.Files; + +[SmartEnum] +public readonly partial record struct BuildTarget { /// /// /// - public enum BuildTarget : uint + private enum Internal : uint { ValidPlayer = 1, /// @@ -146,29 +151,30 @@ namespace AssetRipper.IO.Files AnyPlayer = 0xFFFFFFFF, } - public static class PlatformExtensions + public bool IsStandalone { - public static bool IsCompatible(this BuildTarget _this, BuildTarget comp) + get { - return _this == comp || (_this.IsStandalone() && comp.IsStandalone()); - } - - public static bool IsStandalone(this BuildTarget _this) - { - switch (_this) + switch (this) { - case BuildTarget.StandaloneWinPlayer: - case BuildTarget.StandaloneWin64Player: - case BuildTarget.StandaloneLinux: - case BuildTarget.StandaloneLinux64: - case BuildTarget.StandaloneLinuxUniversal: - case BuildTarget.StandaloneOSXIntel: - case BuildTarget.StandaloneOSXIntel64: - case BuildTarget.StandaloneOSXPPC: - case BuildTarget.StandaloneOSXUniversal: + case StandaloneWinPlayer: + case StandaloneWin64Player: + case StandaloneLinux: + case StandaloneLinux64: + case StandaloneLinuxUniversal: + case StandaloneOSXIntel: + case StandaloneOSXIntel64: + case StandaloneOSXPPC: + case StandaloneOSXUniversal: return true; + default: + return false; } - return false; } } + + public bool IsCompatible(BuildTarget comp) + { + return this == comp || (IsStandalone && comp.IsStandalone); + } } diff --git a/Source/AssetRipper.IO.Files/BundleFiles/FileStream/StorageBlock.cs b/Source/AssetRipper.IO.Files/BundleFiles/FileStream/StorageBlock.cs index fb8bdc58b..c91e7457c 100644 --- a/Source/AssetRipper.IO.Files/BundleFiles/FileStream/StorageBlock.cs +++ b/Source/AssetRipper.IO.Files/BundleFiles/FileStream/StorageBlock.cs @@ -32,11 +32,11 @@ namespace AssetRipper.IO.Files.BundleFiles.FileStream { get { - return Flags.GetCompression(); + return Flags.CompressionType; } private set { - Flags = (Flags & ~StorageBlockFlags.CompressionTypeMask) | (StorageBlockFlags.CompressionTypeMask & (StorageBlockFlags)value); + Flags = Flags.WithCompressionType(value); } } } diff --git a/Source/AssetRipper.IO.Files/BundleFiles/FileStream/StorageBlockFlags.cs b/Source/AssetRipper.IO.Files/BundleFiles/FileStream/StorageBlockFlags.cs index 34e7f2766..14858191b 100644 --- a/Source/AssetRipper.IO.Files/BundleFiles/FileStream/StorageBlockFlags.cs +++ b/Source/AssetRipper.IO.Files/BundleFiles/FileStream/StorageBlockFlags.cs @@ -1,23 +1,45 @@ -namespace AssetRipper.IO.Files.BundleFiles.FileStream +using AssetRipper.SmartEnums; + +namespace AssetRipper.IO.Files.BundleFiles.FileStream; + +[SmartEnum] +public readonly partial record struct StorageBlockFlags { - [Flags] - public enum StorageBlockFlags + private enum Internal { CompressionTypeMask = 0x3F, Streamed = 0x40, } - public static class StorageBlockFlagsExtensions + public CompressionType CompressionType { - public static CompressionType GetCompression(this StorageBlockFlags _this) + get { - return (CompressionType)(_this & StorageBlockFlags.CompressionTypeMask); - } - - public static bool IsStreamed(this StorageBlockFlags _this) - { - return (_this & StorageBlockFlags.Streamed) != 0; + return (CompressionType)(this & CompressionTypeMask); } } + + public bool IsStreamed + { + get + { + return (this & Streamed) != 0; + } + } + + public StorageBlockFlags WithCompressionType(CompressionType compressionType) + { + return (this & ~CompressionTypeMask) | (StorageBlockFlags)compressionType; + } + + public static explicit operator StorageBlockFlags(CompressionType compressionType) + { + return (StorageBlockFlags)(int)compressionType; + } + + public static explicit operator CompressionType(StorageBlockFlags flags) + { + return (CompressionType)(int)(flags); + } } diff --git a/Source/AssetRipper.SmartEnums/AssetRipper.SmartEnums.csproj b/Source/AssetRipper.SmartEnums/AssetRipper.SmartEnums.csproj new file mode 100644 index 000000000..d833fda3f --- /dev/null +++ b/Source/AssetRipper.SmartEnums/AssetRipper.SmartEnums.csproj @@ -0,0 +1,29 @@ + + + + netstandard2.0 + false + true + true + System.Runtime.CompilerServices.ModuleInitializerAttribute + + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + + + + + + + + + diff --git a/Source/AssetRipper.SmartEnums/README.md b/Source/AssetRipper.SmartEnums/README.md new file mode 100644 index 000000000..621cc8fcf --- /dev/null +++ b/Source/AssetRipper.SmartEnums/README.md @@ -0,0 +1,108 @@ +# AssetRipper.SmartEnums + +User code: + +```cs +[SmartEnum] +public readonly partial record struct MyEnum +{ + private enum __Internal : uint + { + Value0, + Value1 + } + + public void CustomMethod() {} +} +``` + +Generated code: + +```cs +readonly partial record struct MyEnum : + global::System.IParsable, + global::System.Numerics.IBitwiseOperators, + global::System.Numerics.IComparisonOperators, + global::System.Numerics.IEqualityOperators, + global::System.Numerics.IShiftOperators +{ + private readonly uint __value; + + public MyEnum(uint value) => __value = value; + + /// + public const uint Value0 = (uint)__Internal.Value0; + + /// + public const uint Value1 = (uint)__Internal.Value1; + + public static implicit operator uint(MyEnum value) => value.__value; + public static implicit operator MyEnum(uint value) => new(value); + + public static MyEnum operator &(MyEnum left, MyEnum right) => new(left.__value & right.__value); + public static MyEnum operator |(MyEnum left, MyEnum right) => new(left.__value | right.__value); + public static MyEnum operator ^(MyEnum left, MyEnum right) => new(left.__value ^ right.__value); + public static MyEnum operator ~(MyEnum value) => new(~value.__value); + + public static bool operator <(MyEnum left, MyEnum right) => left.__value < right.__value; + public static bool operator >(MyEnum left, MyEnum right) => left.__value > right.__value; + public static bool operator <=(MyEnum left, MyEnum right) => left.__value <= right.__value; + public static bool operator >=(MyEnum left, MyEnum right) => left.__value >= right.__value; + + public static MyEnum operator <<(MyEnum value, int count) => new(value.__value << count); + public static MyEnum operator >>(MyEnum value, int count) => new(value.__value >> count); + public static MyEnum operator >>>(MyEnum value, int count) => new(value.__value >>> count); + + public override string ToString() => __value switch + { + Value0 => nameof(Value0), + Value1 => nameof(Value1), + _ => __value.ToString(), + }; + + public static MyEnum Parse(string s) => Parse(s, null); + public static MyEnum Parse(string s, IFormatProvider? provider) => s switch + { + nameof(MyEnum.Value0) => MyEnum.Value0, + nameof(MyEnum.Value1) => MyEnum.Value1, + _ => uint.Parse(s, provider), + }; + public static bool TryParse(string? s, out MyEnum result) => TryParse(s, null, out result); + public static bool TryParse(string? s, IFormatProvider? provider, out MyEnum result) + { + switch (s) + { + case nameof(MyEnum.Value0): + result = MyEnum.Value0; + return true; + case nameof(MyEnum.Value1): + result = MyEnum.Value1; + return true; + default: + if (uint.TryParse(s, provider, out uint value)) + { + result = new(value); + return true; + } + result = default; + return false; + } + } + + public static global::System.ReadOnlySpan GetValues() => __ValueCache.Values; + public static global::System.ReadOnlySpan GetUnderlyingValues() => + [ + Value0, + Value1, + ]; + public uint GetUnderlyingValue() => __value; +} +file static class __ValueCache +{ + public static MyEnum[] Values = new MyEnum[] + { + MyEnum.Value0, + MyEnum.Value1, + }; +} +``` \ No newline at end of file diff --git a/Source/AssetRipper.SmartEnums/RoslynExtensions.cs b/Source/AssetRipper.SmartEnums/RoslynExtensions.cs new file mode 100644 index 000000000..ba50b30c2 --- /dev/null +++ b/Source/AssetRipper.SmartEnums/RoslynExtensions.cs @@ -0,0 +1,18 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; + +namespace AssetRipper.SmartEnums; + +internal static class RoslynExtensions +{ + public static bool IsPartial(this MemberDeclarationSyntax c) + { + return c.Modifiers.Any(SyntaxKind.PartialKeyword); + } + + public static bool IsReadOnly(this MemberDeclarationSyntax c) + { + return c.Modifiers.Any(SyntaxKind.ReadOnlyKeyword); + } +} diff --git a/Source/AssetRipper.SmartEnums/Sequence.cs b/Source/AssetRipper.SmartEnums/Sequence.cs new file mode 100644 index 000000000..1b0d0ac5c --- /dev/null +++ b/Source/AssetRipper.SmartEnums/Sequence.cs @@ -0,0 +1,24 @@ +using System.Collections; + +namespace AssetRipper.SmartEnums; + +internal readonly struct Sequence(T[] values) : IEquatable>, IReadOnlyList +{ + public T this[int index] => Values[index]; + + public T[] Values { get; } = values; + + public int Count => Values.Length; + + public bool Equals(Sequence other) => Values.SequenceEqual(other.Values); + + public override bool Equals(object obj) => obj is Sequence other && Equals(other); + + public override int GetHashCode() => Values.Length; // Doesn't need to be good + + public static implicit operator Sequence(T[] values) => new(values); + + public IEnumerator GetEnumerator() => ((IEnumerable)Values).GetEnumerator(); + + IEnumerator IEnumerable.GetEnumerator() => Values.GetEnumerator(); +} diff --git a/Source/AssetRipper.SmartEnums/SmartEnumGenerator.cs b/Source/AssetRipper.SmartEnums/SmartEnumGenerator.cs new file mode 100644 index 000000000..87f5d97b8 --- /dev/null +++ b/Source/AssetRipper.SmartEnums/SmartEnumGenerator.cs @@ -0,0 +1,230 @@ +using AssetRipper.Text.SourceGeneration; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using SGF; +using System.CodeDom.Compiler; + +namespace AssetRipper.SmartEnums; + +[IncrementalGenerator] +public sealed class SmartEnumGenerator() : IncrementalGenerator(nameof(SmartEnumGenerator)) +{ + private readonly record struct SmartEnumData(string Namespace, string StructName, string EnumName, string? EnumType, Sequence EnumValues) + { + public IEnumerable GetAllNames() + { + yield return StructName; + yield return EnumName; + foreach (string enumValue in EnumValues) + { + yield return enumValue; + } + } + } + + public override void OnInitialize(SgfInitializationContext context) + { + context.RegisterPostInitializationOutput(InjectAttribute); + + IncrementalValuesProvider valueProvider = context.SyntaxProvider.ForAttributeWithMetadataName("AssetRipper.SmartEnums.SmartEnumAttribute", + static (node, ct) => + { + return node is RecordDeclarationSyntax s + && s.IsPartial() + && s.IsReadOnly() + && s.Parent is BaseNamespaceDeclarationSyntax + && s.ChildNodes().OfType().Count() == 1; + }, + static (context, ct) => + { + RecordDeclarationSyntax structDeclaration = (RecordDeclarationSyntax)context.TargetNode; + BaseNamespaceDeclarationSyntax namespaceDeclaration = (BaseNamespaceDeclarationSyntax)structDeclaration.Parent!; + EnumDeclarationSyntax enumDeclaration = structDeclaration.ChildNodes().OfType().Single(); + + string structName = structDeclaration.Identifier.Text; + string @namespace = namespaceDeclaration.Name.ToString(); + string enumName = enumDeclaration.Identifier.Text; + string? enumType = enumDeclaration.BaseList?.Types.Single().ToString(); + string[] enumValues = enumDeclaration.Members.OfType().Select(e => e.Identifier.Text).ToArray(); + + return new SmartEnumData(@namespace, structName, enumName, enumType, enumValues); + }); + + context.RegisterSourceOutput(valueProvider, GenerateEnum); + } + + private static void GenerateEnum(SgfSourceProductionContext context, SmartEnumData enumData) + { + StringWriter stringWriter = new() { NewLine = "\n" }; + using IndentedTextWriter writer = IndentedTextWriterFactory.Create(stringWriter); + + string structName = enumData.StructName; + string enumName = enumData.EnumName; + string enumType = enumData.EnumType ?? "int"; + string[] enumValues = enumData.EnumValues.Values; + + string valueName = GetNonConflictingName("__value", enumData.GetAllNames()); + + writer.WriteGeneratedCodeWarning(); + writer.WriteLineNoTabs(); + writer.WriteFileScopedNamespace(enumData.Namespace); + writer.WriteLineNoTabs(); + writer.WriteLine($"readonly partial record struct {structName} :"); + using (new Indented(writer)) + { + writer.WriteLine($"global::System.IParsable<{structName}>,"); + writer.WriteLine($"global::System.Numerics.IBitwiseOperators<{structName}, {structName}, {structName}>,"); + writer.WriteLine($"global::System.Numerics.IComparisonOperators<{structName}, {structName}, bool>,"); + writer.WriteLine($"global::System.Numerics.IEqualityOperators<{structName}, {structName}, bool>,"); + writer.WriteLine($"global::System.Numerics.IShiftOperators<{structName}, int, {structName}>"); + } + using (new CurlyBrackets(writer)) + { + writer.WriteLine($"private readonly {enumType} {valueName};"); + writer.WriteLineNoTabs(); + + // Constructor + writer.WriteLine($"public {structName}({enumType} value) => {valueName} = value;"); + writer.WriteLineNoTabs(); + + foreach (string enumValue in enumValues) + { + writer.WriteInheritDocumentation($"{enumName}.{enumValue}"); + writer.WriteLine($"public const {enumType} {enumValue} = ({enumType}){enumName}.{enumValue};"); + writer.WriteLineNoTabs(); + } + + // Implicit operators + writer.WriteLine($"public static implicit operator {enumType}({structName} value) => value.{valueName};"); + writer.WriteLine($"public static implicit operator {structName}({enumType} value) => new(value);"); + writer.WriteLineNoTabs(); + + // Bitwise operators + writer.WriteLine($"public static {structName} operator &({structName} left, {structName} right) => new(left.{valueName} & right.{valueName});"); + writer.WriteLine($"public static {structName} operator |({structName} left, {structName} right) => new(left.{valueName} | right.{valueName});"); + writer.WriteLine($"public static {structName} operator ^({structName} left, {structName} right) => new(left.{valueName} ^ right.{valueName});"); + writer.WriteLine($"public static {structName} operator ~({structName} value) => new(~value.{valueName});"); + writer.WriteLineNoTabs(); + + // Comparison operators + writer.WriteLine($"public static bool operator <({structName} left, {structName} right) => left.{valueName} < right.{valueName};"); + writer.WriteLine($"public static bool operator >({structName} left, {structName} right) => left.{valueName} > right.{valueName};"); + writer.WriteLine($"public static bool operator <=({structName} left, {structName} right) => left.{valueName} <= right.{valueName};"); + writer.WriteLine($"public static bool operator >=({structName} left, {structName} right) => left.{valueName} >= right.{valueName};"); + writer.WriteLineNoTabs(); + + // Shift operators + writer.WriteLine($"public static {structName} operator <<({structName} value, int count) => new(value.{valueName} << count);"); + writer.WriteLine($"public static {structName} operator >>({structName} value, int count) => new(value.{valueName} >> count);"); + writer.WriteLine($"public static {structName} operator >>>({structName} value, int count) => new(value.{valueName} >>> count);"); + writer.WriteLineNoTabs(); + + // ToString + writer.WriteLine($"public override string ToString() => {valueName} switch"); + using (new CurlyBracketsWithSemicolon(writer)) + { + foreach (string enumValue in enumValues) + { + writer.WriteLine($"{enumValue} => nameof({enumValue}),"); + } + writer.WriteLine($"_ => {valueName}.ToString(),"); + } + writer.WriteLineNoTabs(); + + // IParsable + writer.WriteLine($"public static {structName} Parse(string s) => Parse(s, null);"); + writer.WriteLine($"public static {structName} Parse(string s, IFormatProvider? provider) => s switch"); + using (new CurlyBracketsWithSemicolon(writer)) + { + foreach (string enumValue in enumValues) + { + writer.WriteLine($"nameof({structName}.{enumValue}) => {structName}.{enumValue},"); + } + writer.WriteLine($"_ => {enumType}.Parse(s, provider),"); + } + writer.WriteLine($"public static bool TryParse(string? s, out {structName} result) => TryParse(s, null, out result);"); + writer.WriteLine($"public static bool TryParse(string? s, IFormatProvider? provider, out {structName} result)"); + using (new CurlyBrackets(writer)) + { + writer.WriteLine("switch (s)"); + using (new CurlyBrackets(writer)) + { + foreach (string enumValue in enumValues) + { + writer.WriteLine($"case nameof({structName}.{enumValue}):"); + using (new Indented(writer)) + { + writer.WriteLine($"result = {structName}.{enumValue};"); + writer.WriteLine("return true;"); + } + } + writer.WriteLine("default:"); + using (new Indented(writer)) + { + writer.WriteLine($"if ({enumType}.TryParse(s, provider, out {enumType} value))"); + using (new CurlyBrackets(writer)) + { + writer.WriteLine("result = new(value);"); + writer.WriteLine("return true;"); + } + writer.WriteLine("result = default;"); + writer.WriteLine("return false;"); + } + } + } + writer.WriteLineNoTabs(); + + writer.WriteLine($"public static global::System.ReadOnlySpan<{structName}> GetValues() => __ValueCache.Values;"); + writer.WriteLine($"public static global::System.ReadOnlySpan<{enumType}> GetUnderlyingValues() =>"); + writer.WriteLine('['); + using (new Indented(writer)) + { + foreach (string enumValue in enumValues) + { + writer.WriteLine($"{enumValue},"); + } + } + writer.WriteLine("];"); + writer.WriteLine($"public {enumType} GetUnderlyingValue() => {valueName};"); + } + writer.WriteLine("file static class __ValueCache"); + using (new CurlyBrackets(writer)) + { + writer.WriteLine($"public static {structName}[] Values = new {structName}[]"); + using (new CurlyBracketsWithSemicolon(writer)) + { + foreach (string enumValue in enumValues) + { + writer.WriteLine($"{structName}.{enumValue},"); + } + } + } + + context.AddSource($"{structName}.cs", stringWriter.ToString()); + } + + private static void InjectAttribute(IncrementalGeneratorPostInitializationContext context) + { + context.AddSource("SmartEnumAttribute.cs", """ + namespace AssetRipper.SmartEnums; + + //[global::Microsoft.CodeAnalysis.Embedded] + [global::System.AttributeUsage(global::System.AttributeTargets.Struct)] + internal sealed class SmartEnumAttribute : global::System.Attribute + { + } + """); + + //context.AddEmbeddedAttributeDefinition(); //To do: this api isn't available yet + //https://github.com/dotnet/roslyn/blob/main/docs/features/incremental-generators.cookbook.md#put-microsoftcodeanalysisembeddedattribute-on-generated-marker-types + } + + private static string GetNonConflictingName(string name, IEnumerable existingNames) + { + while (existingNames.Contains(name)) + { + name = "_" + name; + } + return name; + } +}