diff --git a/Source/AssetRipper.IO.Files.SourceGenerator/Program.cs b/Source/AssetRipper.IO.Files.SourceGenerator/Program.cs index ba95cae3b..eb0c024b9 100644 --- a/Source/AssetRipper.IO.Files.SourceGenerator/Program.cs +++ b/Source/AssetRipper.IO.Files.SourceGenerator/Program.cs @@ -46,7 +46,7 @@ internal static class Program { string virtualKeyword = api.Type is FileSystemApiType.Sealed ? "" : "virtual "; string parametersWithTypes = string.Join(", ", api.Parameters.Select(parameter => $"{parameter.Item1} {parameter.Item2}")); - writer.WriteLine($"public {virtualKeyword}{api.ReturnType} {api.Name}({parametersWithTypes})"); + writer.WriteLine($"public {virtualKeyword}{api.BaseReturnType} {api.Name}({parametersWithTypes})"); using (new CurlyBrackets(writer)) { if (api.Type is FileSystemApiType.Throw) @@ -97,7 +97,7 @@ internal static class Program } string parametersWithTypes = string.Join(", ", api.Parameters.Select(parameter => $"{parameter.Item1} {parameter.Item2}")); - writer.WriteLine($"public override {api.ReturnType} {api.Name}({parametersWithTypes})"); + writer.WriteLine($"public override {api.DerivedReturnType} {api.Name}({parametersWithTypes})"); using (new CurlyBrackets(writer)) { string returnKeyword = api.VoidReturn ? "" : "return "; @@ -157,6 +157,7 @@ internal static class Program { [nameof(File)] = new() { + new((Func)File.Create), new(File.Delete), new(File.Exists), new(File.OpenRead), @@ -206,7 +207,15 @@ internal static class Program public MethodInfo MethodInfo => Delegate.Method; public FileSystemApiType Type { get; init; } = FileSystemApiType.Throw; public string DeclaringType => MethodInfo.DeclaringType!.GetGlobalQualifiedName(); - public string ReturnType => MethodInfo.ReturnType.GetGlobalQualifiedName(); + public string BaseReturnType + { + get + { + Type returnType = MethodInfo.ReturnType == typeof(FileStream) ? typeof(Stream) : MethodInfo.ReturnType; + return returnType.GetGlobalQualifiedName(); + } + } + public string DerivedReturnType => MethodInfo.ReturnType.GetGlobalQualifiedName(); public bool VoidReturn => MethodInfo.ReturnType == typeof(void); public string Name => MethodInfo.Name; public string FullName => $"{DeclaringType}.{Name}"; diff --git a/Source/AssetRipper.IO.Files.Tests/VirtualFileSystemTests.cs b/Source/AssetRipper.IO.Files.Tests/VirtualFileSystemTests.cs index f3f8b1d65..5ef867f79 100644 --- a/Source/AssetRipper.IO.Files.Tests/VirtualFileSystemTests.cs +++ b/Source/AssetRipper.IO.Files.Tests/VirtualFileSystemTests.cs @@ -109,7 +109,7 @@ public class VirtualFileSystemTests { Assert.That(fs.Directory.Exists("/test"), Is.True); Assert.That(fs.Count, Is.EqualTo(2));// root and test - }); + }); } [Test] @@ -142,9 +142,9 @@ public class VirtualFileSystemTests } [Test] - public void CreatingFileTwiceThrows() + public void CreatingFileTwiceSucceeds() { - Assert.Throws(() => + Assert.DoesNotThrow(() => { VirtualFileSystem fs = new(); fs.Directory.Create("/test"); @@ -170,4 +170,26 @@ public class VirtualFileSystemTests Stream stream = fs.File.Create("/test"); Assert.That(stream.Length, Is.Zero); } + + [Test] + public void ReadWriteTextParity() + { + VirtualFileSystem fs = new(); + string path = "/test"; + string contents = "Hello, world!"; + fs.File.WriteAllText(path, contents); + string readContents = fs.File.ReadAllText(path); + Assert.That(readContents, Is.EqualTo(contents)); + } + + [Test] + public void ReadWriteBytesParity() + { + VirtualFileSystem fs = new(); + string path = "/test"; + byte[] bytes = [0x48, 0x65, 0x6C, 0x6C, 0x6F, 0x2C, 0x20, 0x77, 0x6F, 0x72, 0x6C, 0x64, 0x21]; + fs.File.WriteAllBytes(path, bytes); + byte[] readBytes = fs.File.ReadAllBytes(path); + Assert.That(readBytes, Is.EqualTo(bytes)); + } } diff --git a/Source/AssetRipper.IO.Files/FileSystem.cs b/Source/AssetRipper.IO.Files/FileSystem.cs index 25935c1e5..c2f833532 100644 --- a/Source/AssetRipper.IO.Files/FileSystem.cs +++ b/Source/AssetRipper.IO.Files/FileSystem.cs @@ -13,7 +13,6 @@ public partial class FileSystem public partial class FileImplementation { - public virtual Stream Create(string path) => throw new NotSupportedException(); } public partial class DirectoryImplementation diff --git a/Source/AssetRipper.IO.Files/FileSystem.g.cs b/Source/AssetRipper.IO.Files/FileSystem.g.cs index 73b5e24a5..49d3f1f68 100644 --- a/Source/AssetRipper.IO.Files/FileSystem.g.cs +++ b/Source/AssetRipper.IO.Files/FileSystem.g.cs @@ -10,6 +10,11 @@ public abstract partial class FileSystem { protected DirectoryImplementation Directory => fileSystem.Directory; protected PathImplementation Path => fileSystem.Path; + public virtual global::System.IO.Stream Create(global::System.String path) + { + throw new global::System.NotSupportedException(); + } + public virtual void Delete(global::System.String path) { throw new global::System.NotSupportedException(); @@ -20,12 +25,12 @@ public abstract partial class FileSystem throw new global::System.NotSupportedException(); } - public virtual global::System.IO.FileStream OpenRead(global::System.String path) + public virtual global::System.IO.Stream OpenRead(global::System.String path) { throw new global::System.NotSupportedException(); } - public virtual global::System.IO.FileStream OpenWrite(global::System.String path) + public virtual global::System.IO.Stream OpenWrite(global::System.String path) { throw new global::System.NotSupportedException(); } diff --git a/Source/AssetRipper.IO.Files/LocalFileSystem.cs b/Source/AssetRipper.IO.Files/LocalFileSystem.cs index 65e1ecc88..7963b3cb6 100644 --- a/Source/AssetRipper.IO.Files/LocalFileSystem.cs +++ b/Source/AssetRipper.IO.Files/LocalFileSystem.cs @@ -6,7 +6,6 @@ public partial class LocalFileSystem : FileSystem public partial class LocalFileImplementation { - public override Stream Create(string path) => System.IO.File.Create(path); } public partial class LocalDirectoryImplementation diff --git a/Source/AssetRipper.IO.Files/LocalFileSystem.g.cs b/Source/AssetRipper.IO.Files/LocalFileSystem.g.cs index cd9b6e858..758c34466 100644 --- a/Source/AssetRipper.IO.Files/LocalFileSystem.g.cs +++ b/Source/AssetRipper.IO.Files/LocalFileSystem.g.cs @@ -8,6 +8,11 @@ public sealed partial class LocalFileSystem : FileSystem public sealed partial class LocalFileImplementation(LocalFileSystem fileSystem) : FileImplementation(fileSystem) { + public override global::System.IO.FileStream Create(global::System.String path) + { + return global::System.IO.File.Create(path); + } + public override void Delete(global::System.String path) { global::System.IO.File.Delete(path); diff --git a/Source/AssetRipper.IO.Files/VirtualFileSystem.cs b/Source/AssetRipper.IO.Files/VirtualFileSystem.cs index 70f7a6b28..6b163e2f7 100644 --- a/Source/AssetRipper.IO.Files/VirtualFileSystem.cs +++ b/Source/AssetRipper.IO.Files/VirtualFileSystem.cs @@ -1,5 +1,6 @@ - -using AssetRipper.IO.Files.Streams.Smart; +using AssetRipper.IO.Files.Streams.Smart; +using System.Buffers; +using System.Diagnostics; using System.Text; namespace AssetRipper.IO.Files; @@ -31,24 +32,83 @@ public partial class VirtualFileSystem : FileSystem public partial class VirtualFileImplementation { - public override Stream Create(string path) + public override SmartStream Create(string path) { string directory = fileSystem.GetFullDirectoryName(path); + string fullPath = Path.GetFullPath(path); if (!fileSystem.directories.Contains(directory)) { throw new DirectoryNotFoundException($"Directory '{directory}' not found."); } - if (fileSystem.files.ContainsKey(path)) + if (!fileSystem.files.TryGetValue(fullPath, out SmartStream? stream)) { - throw new IOException($"File '{path}' already exists."); + stream = SmartStream.CreateMemory(); + fileSystem.files.Add(fullPath, stream); + } + else + { + stream.SetLength(0); } - SmartStream stream = SmartStream.CreateMemory(); - fileSystem.files.Add(path, stream); return stream.CreateReference(); } + public SmartStream Open(string path) + { + string directory = fileSystem.GetFullDirectoryName(path); + string fullPath = Path.GetFullPath(path); + if (!fileSystem.directories.Contains(directory)) + { + throw new DirectoryNotFoundException($"Directory '{directory}' not found."); + } + if (!fileSystem.files.TryGetValue(fullPath, out SmartStream? stream)) + { + throw new FileNotFoundException($"File '{path}' not found."); + } + return stream.CreateReference(); + } + public override SmartStream OpenRead(string path) => Open(path); + public override SmartStream OpenWrite(string path) => Open(path); + public override void Delete(string path) + { + string directory = fileSystem.GetFullDirectoryName(path); + string fullPath = Path.GetFullPath(path); + if (!fileSystem.directories.Contains(directory)) + { + throw new DirectoryNotFoundException($"Directory '{directory}' not found."); + } + if (fileSystem.files.Remove(fullPath, out SmartStream? stream)) + { + stream.Dispose(); + } + } public override bool Exists(string path) => fileSystem.files.ContainsKey(path); public override string ReadAllText(string path) => ReadAllText(path, Encoding.UTF8); + public override string ReadAllText(string path, Encoding encoding) => encoding.GetString(ReadAllBytes(path)); public override void WriteAllText(string path, ReadOnlySpan contents) => WriteAllText(path, contents, Encoding.UTF8); + public override void WriteAllText(string path, ReadOnlySpan contents, Encoding encoding) + { + int byteCount = encoding.GetByteCount(contents); + byte[] array = ArrayPool.Shared.Rent(byteCount); + Span span = array.AsSpan(0, byteCount); + int bytesWritten = encoding.GetBytes(contents, span); + Debug.Assert(bytesWritten == byteCount); + WriteAllBytes(path, span); + ArrayPool.Shared.Return(array); + } + public override byte[] ReadAllBytes(string path) + { + using SmartStream stream = Open(path); + byte[] buffer = new byte[stream.Length]; + stream.Position = 0; + stream.ReadExactly(buffer); + return buffer; + } + public override void WriteAllBytes(string path, ReadOnlySpan bytes) + { + using SmartStream stream = Create(path); + stream.SetLength(bytes.Length); + stream.Position = 0; + stream.Write(bytes); + } } public partial class VirtualDirectoryImplementation @@ -70,6 +130,16 @@ public partial class VirtualFileSystem : FileSystem } } + public override IEnumerable EnumerateDirectories(string path, string searchPattern, SearchOption searchOption) + { + throw new NotImplementedException(); + } + + public override IEnumerable EnumerateFiles(string path, string searchPattern, SearchOption searchOption) + { + throw new NotImplementedException(); + } + public override bool Exists(string? path) => fileSystem.directories.Contains(GetFullPath(path)); private string GetFullPath(string? path) => Path.GetFullPath(path);