diff --git a/PluralKit.Bot/Init.cs b/PluralKit.Bot/Init.cs index 7e72b3fe..8de838e3 100644 --- a/PluralKit.Bot/Init.cs +++ b/PluralKit.Bot/Init.cs @@ -35,7 +35,7 @@ namespace PluralKit.Bot // "Connect to the database" (ie. set off database migrations and ensure state) logger.Information("Connecting to database"); - await services.Resolve().InitializeDatabase(); + await services.Resolve().ApplyMigrations(); // Init the bot instance itself, register handlers and such to the client before beginning to connect logger.Information("Initializing bot"); diff --git a/PluralKit.Core/Database/Database.cs b/PluralKit.Core/Database/Database.cs index ae669f9d..d0f1975f 100644 --- a/PluralKit.Core/Database/Database.cs +++ b/PluralKit.Core/Database/Database.cs @@ -1,5 +1,6 @@ using System; using System.Data; +using System.IO; using System.Threading.Tasks; using App.Metrics; @@ -16,6 +17,9 @@ namespace PluralKit.Core { internal class Database: IDatabase { + private const string RootPath = "PluralKit.Core.Database"; // "resource path" root for SQL files + private const int TargetSchemaVersion = 7; + private readonly CoreConfig _config; private readonly ILogger _logger; private readonly IMetrics _metrics; @@ -52,7 +56,7 @@ namespace PluralKit.Core NpgsqlConnection.GlobalTypeMapper.MapComposite("proxy_tag"); NpgsqlConnection.GlobalTypeMapper.MapEnum("privacy_level"); } - + public async Task Obtain() { // Mark the request (for a handle, I guess) in the metrics @@ -64,6 +68,69 @@ namespace PluralKit.Core await conn.OpenAsync(); return conn; } + + public async Task ApplyMigrations() + { + // Run everything in a transaction + await using var conn = await Obtain(); + await using var tx = await conn.BeginTransactionAsync(); + + // Before applying migrations, clean out views/functions to prevent type errors + await ExecuteSqlFile($"{RootPath}.clean.sql", conn, tx); + + // Apply all migrations between the current database version and the target version + await ApplyMigrations(conn, tx); + + // Now, reapply views/functions (we deleted them above, no need to worry about conflicts) + await ExecuteSqlFile($"{RootPath}.views.sql", conn, tx); + await ExecuteSqlFile($"{RootPath}.Functions.functions.sql", conn, tx); + + // Finally, commit tx + await tx.CommitAsync(); + } + + private async Task ApplyMigrations(IPKConnection conn, IDbTransaction tx) + { + var currentVersion = await GetCurrentDatabaseVersion(conn); + _logger.Information("Current schema version: {CurrentVersion}", currentVersion); + for (var migration = currentVersion + 1; migration <= TargetSchemaVersion; migration++) + { + _logger.Information("Applying schema migration {MigrationId}", migration); + await ExecuteSqlFile($"{RootPath}.Migrations.{migration}.sql", conn, tx); + } + } + + private async Task ExecuteSqlFile(string resourceName, IPKConnection conn, IDbTransaction tx = null) + { + await using var stream = typeof(Database).Assembly.GetManifestResourceStream(resourceName); + if (stream == null) throw new ArgumentException($"Invalid resource name '{resourceName}'"); + + using var reader = new StreamReader(stream); + var query = await reader.ReadToEndAsync(); + + await conn.ExecuteAsync(query, transaction: tx); + + // If the above creates new enum/composite types, we must tell Npgsql to reload the internal type caches + // This will propagate to every other connection as well, since it marks the global type mapper collection dirty. + ((PKConnection) conn).ReloadTypes(); + } + + private async Task GetCurrentDatabaseVersion(IPKConnection conn) + { + // First, check if the "info" table exists (it may not, if this is a *really* old database) + var hasInfoTable = + await conn.QuerySingleOrDefaultAsync( + "select count(*) from information_schema.tables where table_name = 'info'") == 1; + + // If we have the table, read the schema version + if (hasInfoTable) + return await conn.QuerySingleOrDefaultAsync("select schema_version from info"); + + // If not, we return version "-1" + // This means migration 0 will get executed, getting us into a consistent state + // Then, migration 1 gets executed, which creates the info table and sets version to 1 + return -1; + } private class PassthroughTypeHandler: SqlMapper.TypeHandler { diff --git a/PluralKit.Core/Database/IDatabase.cs b/PluralKit.Core/Database/IDatabase.cs index 3cb44af0..2ac84d5e 100644 --- a/PluralKit.Core/Database/IDatabase.cs +++ b/PluralKit.Core/Database/IDatabase.cs @@ -4,6 +4,7 @@ namespace PluralKit.Core { public interface IDatabase { + Task ApplyMigrations(); Task Obtain(); } } \ No newline at end of file diff --git a/PluralKit.Core/Database/Schemas.cs b/PluralKit.Core/Database/Schemas.cs deleted file mode 100644 index ef11ccd3..00000000 --- a/PluralKit.Core/Database/Schemas.cs +++ /dev/null @@ -1,92 +0,0 @@ -using System; -using System.Data; -using System.IO; -using System.Threading.Tasks; - -using Dapper; - -using Npgsql; - -using Serilog; - -namespace PluralKit.Core -{ - public class Schemas - { - private const string RootPath = "PluralKit.Core.Database"; // "resource path" root for SQL files - private const int TargetSchemaVersion = 7; - - private IDatabase _conn; - private ILogger _logger; - - public Schemas(IDatabase conn, ILogger logger) - { - _conn = conn; - _logger = logger.ForContext(); - } - - public async Task InitializeDatabase() - { - // Run everything in a transaction - await using var conn = await _conn.Obtain(); - await using var tx = await conn.BeginTransactionAsync(); - - // Before applying migrations, clean out views/functions to prevent type errors - await ExecuteSqlFile($"{RootPath}.clean.sql", conn, tx); - - // Apply all migrations between the current database version and the target version - await ApplyMigrations(conn, tx); - - // Now, reapply views/functions (we deleted them above, no need to worry about conflicts) - await ExecuteSqlFile($"{RootPath}.views.sql", conn, tx); - await ExecuteSqlFile($"{RootPath}.Functions.functions.sql", conn, tx); - - // Finally, commit tx - await tx.CommitAsync(); - } - - private async Task ApplyMigrations(IPKConnection conn, IDbTransaction tx) - { - var currentVersion = await GetCurrentDatabaseVersion(conn); - _logger.Information("Current schema version: {CurrentVersion}", currentVersion); - for (var migration = currentVersion + 1; migration <= TargetSchemaVersion; migration++) - { - _logger.Information("Applying schema migration {MigrationId}", migration); - await ExecuteSqlFile($"{RootPath}.Migrations.{migration}.sql", conn, tx); - } - } - - private async Task ExecuteSqlFile(string resourceName, IPKConnection conn, IDbTransaction tx = null) - { - await using var stream = typeof(Schemas).Assembly.GetManifestResourceStream(resourceName); - if (stream == null) throw new ArgumentException($"Invalid resource name '{resourceName}'"); - - using var reader = new StreamReader(stream); - var query = await reader.ReadToEndAsync(); - - await conn.ExecuteAsync(query, transaction: tx); - - // If the above creates new enum/composite types, we must tell Npgsql to reload the internal type caches - // This will propagate to every other connection as well, since it marks the global type mapper collection dirty. - // TODO: find a way to get around the cast to our internal tracker wrapper... this could break if that ever changes - conn.ReloadTypes(); - } - - private async Task GetCurrentDatabaseVersion(IPKConnection conn) - { - // First, check if the "info" table exists (it may not, if this is a *really* old database) - var hasInfoTable = - await conn.QuerySingleOrDefaultAsync( - "select count(*) from information_schema.tables where table_name = 'info'") == 1; - - // If we have the table, read the schema version - if (hasInfoTable) - return await conn.QuerySingleOrDefaultAsync("select schema_version from info"); - - // If not, we return version "-1" - // This means migration 0 will get executed, getting us into a consistent state - // Then, migration 1 gets executed, which creates the info table and sets version to 1 - return -1; - } - } -} \ No newline at end of file diff --git a/PluralKit.Core/Database/Wrappers/IPKConnection.cs b/PluralKit.Core/Database/Wrappers/IPKConnection.cs index f82971d2..87cc22a5 100644 --- a/PluralKit.Core/Database/Wrappers/IPKConnection.cs +++ b/PluralKit.Core/Database/Wrappers/IPKConnection.cs @@ -22,9 +22,7 @@ namespace PluralKit.Core public NpgsqlBinaryImporter BeginBinaryImport(string copyFromCommand); public NpgsqlBinaryExporter BeginBinaryExport(string copyToCommand); - - public void ReloadTypes(); - + [Obsolete] new void Open(); [Obsolete] new void Close(); diff --git a/PluralKit.Core/Modules.cs b/PluralKit.Core/Modules.cs index 6229286c..a76bfa14 100644 --- a/PluralKit.Core/Modules.cs +++ b/PluralKit.Core/Modules.cs @@ -23,9 +23,8 @@ namespace PluralKit.Core protected override void Load(ContainerBuilder builder) { builder.RegisterType().SingleInstance(); - builder.RegisterType().AsSelf().SingleInstance(); + builder.RegisterType().As().SingleInstance(); builder.RegisterType().AsSelf().As(); - builder.RegisterType().AsSelf(); builder.Populate(new ServiceCollection().AddMemoryCache()); }