diff --git a/PluralKit.API/Controllers/SystemController.cs b/PluralKit.API/Controllers/SystemController.cs index 9024509c..4108e352 100644 --- a/PluralKit.API/Controllers/SystemController.cs +++ b/PluralKit.API/Controllers/SystemController.cs @@ -36,10 +36,10 @@ namespace PluralKit.API.Controllers private SystemStore _systems; private MemberStore _members; private SwitchStore _switches; - private IDbConnection _conn; + private DbConnectionFactory _conn; private TokenAuthService _auth; - public SystemController(SystemStore systems, MemberStore members, SwitchStore switches, IDbConnection conn, TokenAuthService auth) + public SystemController(SystemStore systems, MemberStore members, SwitchStore switches, DbConnectionFactory conn, TokenAuthService auth) { _systems = systems; _members = members; @@ -74,15 +74,18 @@ namespace PluralKit.API.Controllers var system = await _systems.GetByHid(hid); if (system == null) return NotFound("System not found."); - var res = await _conn.QueryAsync( - @"select *, array( + using (var conn = _conn.Obtain()) + { + var res = await conn.QueryAsync( + @"select *, array( select members.hid from switch_members, members where switch_members.switch = switches.id and members.id = switch_members.member ) as members from switches where switches.system = @System and switches.timestamp < @Before order by switches.timestamp desc - limit 100;", new { System = system.Id, Before = before }); - return Ok(res); + limit 100;", new {System = system.Id, Before = before}); + return Ok(res); + } } [HttpGet("{hid}/fronters")] @@ -142,7 +145,10 @@ namespace PluralKit.API.Controllers return BadRequest("New members identical to existing fronters."); // Resolve member objects for all given IDs - var membersList = (await _conn.QueryAsync("select * from members where hid = any(@Hids)", new {Hids = param.Members})).ToList(); + IEnumerable membersList; + using (var conn = _conn.Obtain()) + membersList = (await conn.QueryAsync("select * from members where hid = any(@Hids)", new {Hids = param.Members})).ToList(); + foreach (var member in membersList) if (member.System != _auth.CurrentSystem.Id) return BadRequest($"Cannot switch to member '{member.Hid}' not in system."); diff --git a/PluralKit.API/Startup.cs b/PluralKit.API/Startup.cs index 077bde80..ff452f02 100644 --- a/PluralKit.API/Startup.cs +++ b/PluralKit.API/Startup.cs @@ -33,22 +33,17 @@ namespace PluralKit.API services.AddMvc(opts => { }) .SetCompatibilityVersion(CompatibilityVersion.Version_2_2) .AddJsonOptions(opts => { opts.SerializerSettings.BuildSerializerSettings(); }); - + services .AddTransient() .AddTransient() .AddTransient() .AddTransient() - + .AddScoped() - + .AddTransient(_ => Configuration.GetSection("PluralKit").Get() ?? new CoreConfig()) - .AddScoped(svc => - { - var conn = new NpgsqlConnection(svc.GetRequiredService().Database); - conn.Open(); - return conn; - }); + .AddSingleton(svc => new DbConnectionFactory(svc.GetRequiredService().Database)); } // This method gets called by the runtime. Use this method to configure the HTTP request pipeline. diff --git a/PluralKit.Bot/Bot.cs b/PluralKit.Bot/Bot.cs index 71c04a39..e64c5373 100644 --- a/PluralKit.Bot/Bot.cs +++ b/PluralKit.Bot/Bot.cs @@ -1,5 +1,6 @@ using System; using System.Data; +using System.Data.Common; using System.Diagnostics; using System.IO; using System.Linq; @@ -32,8 +33,8 @@ namespace PluralKit.Bot using (var services = BuildServiceProvider()) { Console.WriteLine("- Connecting to database..."); - var connection = services.GetRequiredService() as NpgsqlConnection; - await Schema.CreateTables(connection); + using (var conn = services.GetRequiredService().Obtain()) + await Schema.CreateTables(conn); Console.WriteLine("- Connecting to Discord..."); var client = services.GetRequiredService() as DiscordSocketClient; @@ -51,13 +52,7 @@ namespace PluralKit.Bot .AddTransient(_ => _config.GetSection("PluralKit").Get() ?? new CoreConfig()) .AddTransient(_ => _config.GetSection("PluralKit").GetSection("Bot").Get() ?? new BotConfig()) - .AddTransient(svc => - { - - var conn = new NpgsqlConnection(svc.GetRequiredService().Database); - conn.Open(); - return conn; - }) + .AddTransient(svc => new DbConnectionFactory(svc.GetRequiredService().Database)) .AddSingleton() .AddSingleton() @@ -170,9 +165,10 @@ namespace PluralKit.Bot // If it does, fetch the sender's system (because most commands need that) into the context, // and start command execution // Note system may be null if user has no system, hence `OrDefault` - var connection = serviceScope.ServiceProvider.GetService(); - var system = await connection.QueryFirstOrDefaultAsync("select systems.* from systems, accounts where accounts.uid = @Id and systems.id = accounts.system", new { Id = arg.Author.Id }); - await _commands.ExecuteAsync(new PKCommandContext(_client, arg, connection, system), argPos, serviceScope.ServiceProvider); + PKSystem system; + using (var conn = serviceScope.ServiceProvider.GetService().Obtain()) + system = await conn.QueryFirstOrDefaultAsync("select systems.* from systems, accounts where accounts.uid = @Id and systems.id = accounts.system", new { Id = arg.Author.Id }); + await _commands.ExecuteAsync(new PKCommandContext(_client, arg, system), argPos, serviceScope.ServiceProvider); } else { diff --git a/PluralKit.Bot/Services/LogChannelService.cs b/PluralKit.Bot/Services/LogChannelService.cs index 93ccb638..6915c526 100644 --- a/PluralKit.Bot/Services/LogChannelService.cs +++ b/PluralKit.Bot/Services/LogChannelService.cs @@ -11,13 +11,13 @@ namespace PluralKit.Bot { public class LogChannelService { private IDiscordClient _client; - private IDbConnection _connection; + private DbConnectionFactory _conn; private EmbedService _embed; - public LogChannelService(IDiscordClient client, IDbConnection connection, EmbedService embed) + public LogChannelService(IDiscordClient client, DbConnectionFactory conn, EmbedService embed) { this._client = client; - this._connection = connection; + this._conn = conn; this._embed = embed; } @@ -30,9 +30,14 @@ namespace PluralKit.Bot { } public async Task GetLogChannel(IGuild guild) { - var server = await _connection.QueryFirstOrDefaultAsync("select * from servers where id = @Id", new { Id = guild.Id }); - if (server?.LogChannel == null) return null; - return await _client.GetChannelAsync(server.LogChannel.Value) as ITextChannel; + using (var conn = _conn.Obtain()) + { + var server = + await conn.QueryFirstOrDefaultAsync("select * from servers where id = @Id", + new {Id = guild.Id}); + if (server?.LogChannel == null) return null; + return await _client.GetChannelAsync(server.LogChannel.Value) as ITextChannel; + } } public async Task SetLogChannel(IGuild guild, ITextChannel newLogChannel) { @@ -41,7 +46,12 @@ namespace PluralKit.Bot { LogChannel = newLogChannel?.Id }; - await _connection.QueryAsync("insert into servers (id, log_channel) values (@Id, @LogChannel) on conflict (id) do update set log_channel = @LogChannel", def); + using (var conn = _conn.Obtain()) + { + await conn.QueryAsync( + "insert into servers (id, log_channel) values (@Id, @LogChannel) on conflict (id) do update set log_channel = @LogChannel", + def); + } } } } \ No newline at end of file diff --git a/PluralKit.Bot/Services/ProxyService.cs b/PluralKit.Bot/Services/ProxyService.cs index 083461a4..17a61c1d 100644 --- a/PluralKit.Bot/Services/ProxyService.cs +++ b/PluralKit.Bot/Services/ProxyService.cs @@ -28,17 +28,17 @@ namespace PluralKit.Bot class ProxyService { private IDiscordClient _client; - private IDbConnection _connection; + private DbConnectionFactory _conn; private LogChannelService _logger; private WebhookCacheService _webhookCache; private MessageStore _messageStorage; private EmbedService _embeds; - public ProxyService(IDiscordClient client, WebhookCacheService webhookCache, IDbConnection connection, LogChannelService logger, MessageStore messageStorage, EmbedService embeds) + public ProxyService(IDiscordClient client, WebhookCacheService webhookCache, DbConnectionFactory conn, LogChannelService logger, MessageStore messageStorage, EmbedService embeds) { _client = client; _webhookCache = webhookCache; - _connection = connection; + _conn = conn; _logger = logger; _messageStorage = messageStorage; _embeds = embeds; @@ -76,11 +76,16 @@ namespace PluralKit.Bot return null; } - public async Task HandleMessageAsync(IMessage message) { - var results = await _connection.QueryAsync( - "select members.*, systems.* from members, systems, accounts where members.system = systems.id and accounts.system = systems.id and accounts.uid = @Uid", - (member, system) => - new ProxyDatabaseResult { Member = member, System = system }, new { Uid = message.Author.Id }); + public async Task HandleMessageAsync(IMessage message) + { + IEnumerable results; + using (var conn = _conn.Obtain()) + { + results = await conn.QueryAsync( + "select members.*, systems.* from members, systems, accounts where members.system = systems.id and accounts.system = systems.id and accounts.uid = @Uid", + (member, system) => + new ProxyDatabaseResult {Member = member, System = system}, new {Uid = message.Author.Id}); + } // Find a member with proxy tags matching the message var match = GetProxyTagMatch(message.Content, results); diff --git a/PluralKit.Bot/Utils.cs b/PluralKit.Bot/Utils.cs index 341541b6..c2ef2def 100644 --- a/PluralKit.Bot/Utils.cs +++ b/PluralKit.Bot/Utils.cs @@ -152,14 +152,12 @@ namespace PluralKit.Bot /// Subclass of ICommandContext with PK-specific additional fields and functionality public class PKCommandContext : SocketCommandContext { - public IDbConnection Connection { get; } public PKSystem SenderSystem { get; } private object _entity; - public PKCommandContext(DiscordSocketClient client, SocketUserMessage msg, IDbConnection connection, PKSystem system) : base(client, msg) + public PKCommandContext(DiscordSocketClient client, SocketUserMessage msg, PKSystem system) : base(client, msg) { - Connection = connection; SenderSystem = system; } diff --git a/PluralKit.Core/Stores.cs b/PluralKit.Core/Stores.cs index 012794fa..084f1b6b 100644 --- a/PluralKit.Core/Stores.cs +++ b/PluralKit.Core/Stores.cs @@ -10,81 +10,96 @@ using NodaTime; namespace PluralKit { public class SystemStore { - private IDbConnection conn; + private DbConnectionFactory _conn; - public SystemStore(IDbConnection conn) { - this.conn = conn; + public SystemStore(DbConnectionFactory conn) { + this._conn = conn; } public async Task Create(string systemName = null) { // TODO: handle HID collision case var hid = Utils.GenerateHid(); - return await conn.QuerySingleAsync("insert into systems (hid, name) values (@Hid, @Name) returning *", new { Hid = hid, Name = systemName }); + + using (var conn = _conn.Obtain()) + return await conn.QuerySingleAsync("insert into systems (hid, name) values (@Hid, @Name) returning *", new { Hid = hid, Name = systemName }); } public async Task Link(PKSystem system, ulong accountId) { - await conn.ExecuteAsync("insert into accounts (uid, system) values (@Id, @SystemId)", new { Id = accountId, SystemId = system.Id }); + using (var conn = _conn.Obtain()) + await conn.ExecuteAsync("insert into accounts (uid, system) values (@Id, @SystemId)", new { Id = accountId, SystemId = system.Id }); } public async Task Unlink(PKSystem system, ulong accountId) { - await conn.ExecuteAsync("delete from accounts where uid = @Id and system = @SystemId", new { Id = accountId, SystemId = system.Id }); + using (var conn = _conn.Obtain()) + await conn.ExecuteAsync("delete from accounts where uid = @Id and system = @SystemId", new { Id = accountId, SystemId = system.Id }); } public async Task GetByAccount(ulong accountId) { - return await conn.QuerySingleOrDefaultAsync("select systems.* from systems, accounts where accounts.system = systems.id and accounts.uid = @Id", new { Id = accountId }); + using (var conn = _conn.Obtain()) + return await conn.QuerySingleOrDefaultAsync("select systems.* from systems, accounts where accounts.system = systems.id and accounts.uid = @Id", new { Id = accountId }); } public async Task GetByHid(string hid) { - return await conn.QuerySingleOrDefaultAsync("select * from systems where systems.hid = @Hid", new { Hid = hid.ToLower() }); + using (var conn = _conn.Obtain()) + return await conn.QuerySingleOrDefaultAsync("select * from systems where systems.hid = @Hid", new { Hid = hid.ToLower() }); } public async Task GetByToken(string token) { - return await conn.QuerySingleOrDefaultAsync("select * from systems where token = @Token", new { Token = token }); + using (var conn = _conn.Obtain()) + return await conn.QuerySingleOrDefaultAsync("select * from systems where token = @Token", new { Token = token }); } public async Task GetById(int id) { - return await conn.QuerySingleOrDefaultAsync("select * from systems where id = @Id", new { Id = id }); + using (var conn = _conn.Obtain()) + return await conn.QuerySingleOrDefaultAsync("select * from systems where id = @Id", new { Id = id }); } public async Task Save(PKSystem system) { - await conn.ExecuteAsync("update systems set name = @Name, description = @Description, tag = @Tag, avatar_url = @AvatarUrl, token = @Token, ui_tz = @UiTz where id = @Id", system); + using (var conn = _conn.Obtain()) + await conn.ExecuteAsync("update systems set name = @Name, description = @Description, tag = @Tag, avatar_url = @AvatarUrl, token = @Token, ui_tz = @UiTz where id = @Id", system); } public async Task Delete(PKSystem system) { - await conn.ExecuteAsync("delete from systems where id = @Id", system); + using (var conn = _conn.Obtain()) + await conn.ExecuteAsync("delete from systems where id = @Id", system); } public async Task> GetLinkedAccountIds(PKSystem system) { - return await conn.QueryAsync("select uid from accounts where system = @Id", new { Id = system.Id }); + using (var conn = _conn.Obtain()) + return await conn.QueryAsync("select uid from accounts where system = @Id", new { Id = system.Id }); } } public class MemberStore { - private IDbConnection conn; + private DbConnectionFactory _conn; - public MemberStore(IDbConnection conn) { - this.conn = conn; + public MemberStore(DbConnectionFactory conn) { + this._conn = conn; } public async Task Create(PKSystem system, string name) { // TODO: handle collision var hid = Utils.GenerateHid(); - return await conn.QuerySingleAsync("insert into members (hid, system, name) values (@Hid, @SystemId, @Name) returning *", new { - Hid = hid, - SystemID = system.Id, - Name = name - }); + + using (var conn = _conn.Obtain()) + return await conn.QuerySingleAsync("insert into members (hid, system, name) values (@Hid, @SystemId, @Name) returning *", new { + Hid = hid, + SystemID = system.Id, + Name = name + }); } public async Task GetByHid(string hid) { - return await conn.QuerySingleOrDefaultAsync("select * from members where hid = @Hid", new { Hid = hid.ToLower() }); + using (var conn = _conn.Obtain()) + return await conn.QuerySingleOrDefaultAsync("select * from members where hid = @Hid", new { Hid = hid.ToLower() }); } public async Task GetByName(PKSystem system, string name) { // QueryFirst, since members can (in rare cases) share names - return await conn.QueryFirstOrDefaultAsync("select * from members where lower(name) = lower(@Name) and system = @SystemID", new { Name = name, SystemID = system.Id }); + using (var conn = _conn.Obtain()) + return await conn.QueryFirstOrDefaultAsync("select * from members where lower(name) = lower(@Name) and system = @SystemID", new { Name = name, SystemID = system.Id }); } public async Task> GetUnproxyableMembers(PKSystem system) { @@ -96,20 +111,24 @@ namespace PluralKit { } public async Task> GetBySystem(PKSystem system) { - return await conn.QueryAsync("select * from members where system = @SystemID", new { SystemID = system.Id }); + using (var conn = _conn.Obtain()) + return await conn.QueryAsync("select * from members where system = @SystemID", new { SystemID = system.Id }); } public async Task Save(PKMember member) { - await conn.ExecuteAsync("update members set name = @Name, description = @Description, color = @Color, avatar_url = @AvatarUrl, birthday = @Birthday, pronouns = @Pronouns, prefix = @Prefix, suffix = @Suffix where id = @Id", member); + using (var conn = _conn.Obtain()) + await conn.ExecuteAsync("update members set name = @Name, description = @Description, color = @Color, avatar_url = @AvatarUrl, birthday = @Birthday, pronouns = @Pronouns, prefix = @Prefix, suffix = @Suffix where id = @Id", member); } public async Task Delete(PKMember member) { - await conn.ExecuteAsync("delete from members where id = @Id", member); + using (var conn = _conn.Obtain()) + await conn.ExecuteAsync("delete from members where id = @Id", member); } public async Task MessageCount(PKMember member) { - return await conn.QuerySingleAsync("select count(*) from messages where member = @Id", member); + using (var conn = _conn.Obtain()) + return await conn.QuerySingleAsync("select count(*) from messages where member = @Id", member); } } @@ -127,59 +146,63 @@ namespace PluralKit { public PKSystem System; } - private IDbConnection _connection; + private DbConnectionFactory _conn; - public MessageStore(IDbConnection connection) { - this._connection = connection; + public MessageStore(DbConnectionFactory conn) { + this._conn = conn; } public async Task Store(ulong senderId, ulong messageId, ulong channelId, PKMember member) { - await _connection.ExecuteAsync("insert into messages(mid, channel, member, sender) values(@MessageId, @ChannelId, @MemberId, @SenderId)", new { - MessageId = messageId, - ChannelId = channelId, - MemberId = member.Id, - SenderId = senderId - }); + using (var conn = _conn.Obtain()) + await conn.ExecuteAsync("insert into messages(mid, channel, member, sender) values(@MessageId, @ChannelId, @MemberId, @SenderId)", new { + MessageId = messageId, + ChannelId = channelId, + MemberId = member.Id, + SenderId = senderId + }); } public async Task Get(ulong id) { - return (await _connection.QueryAsync("select messages.*, members.*, systems.* from messages, members, systems where mid = @Id and messages.member = members.id and systems.id = members.system", (msg, member, system) => new StoredMessage - { - Message = msg, - System = system, - Member = member - }, new { Id = id })).FirstOrDefault(); + using (var conn = _conn.Obtain()) + return (await conn.QueryAsync("select messages.*, members.*, systems.* from messages, members, systems where mid = @Id and messages.member = members.id and systems.id = members.system", (msg, member, system) => new StoredMessage + { + Message = msg, + System = system, + Member = member + }, new { Id = id })).FirstOrDefault(); } public async Task Delete(ulong id) { - await _connection.ExecuteAsync("delete from messages where mid = @Id", new { Id = id }); + using (var conn = _conn.Obtain()) + await conn.ExecuteAsync("delete from messages where mid = @Id", new { Id = id }); } } public class SwitchStore { - private IDbConnection _connection; + private DbConnectionFactory _conn; - public SwitchStore(IDbConnection connection) + public SwitchStore(DbConnectionFactory conn) { - _connection = connection; + _conn = conn; } public async Task RegisterSwitch(PKSystem system, IEnumerable members) { // Use a transaction here since we're doing multiple executed commands in one - using (var tx = _connection.BeginTransaction()) + using (var conn = _conn.Obtain()) + using (var tx = conn.BeginTransaction()) { // First, we insert the switch itself - var sw = await _connection.QuerySingleAsync("insert into switches(system) values (@System) returning *", + var sw = await conn.QuerySingleAsync("insert into switches(system) values (@System) returning *", new {System = system.Id}); // Then we insert each member in the switch in the switch_members table // TODO: can we parallelize this or send it in bulk somehow? foreach (var member in members) { - await _connection.ExecuteAsync( + await conn.ExecuteAsync( "insert into switch_members(switch, member) values(@Switch, @Member)", new {Switch = sw.Id, Member = member.Id}); } @@ -193,33 +216,38 @@ namespace PluralKit { { // TODO: refactor the PKSwitch data structure to somehow include a hydrated member list // (maybe when we get caching in?) - return await _connection.QueryAsync("select * from switches where system = @System order by timestamp desc limit @Count", new {System = system.Id, Count = count}); + using (var conn = _conn.Obtain()) + return await conn.QueryAsync("select * from switches where system = @System order by timestamp desc limit @Count", new {System = system.Id, Count = count}); } public async Task> GetSwitchMemberIds(PKSwitch sw) { - return await _connection.QueryAsync("select member from switch_members where switch = @Switch", - new {Switch = sw.Id}); + using (var conn = _conn.Obtain()) + return await conn.QueryAsync("select member from switch_members where switch = @Switch", + new {Switch = sw.Id}); } public async Task> GetSwitchMembers(PKSwitch sw) { - return await _connection.QueryAsync( - "select * from switch_members, members where switch_members.member = members.id and switch_members.switch = @Switch", - new {Switch = sw.Id}); + using (var conn = _conn.Obtain()) + return await conn.QueryAsync( + "select * from switch_members, members where switch_members.member = members.id and switch_members.switch = @Switch", + new {Switch = sw.Id}); } public async Task GetLatestSwitch(PKSystem system) => (await GetSwitches(system, 1)).FirstOrDefault(); public async Task MoveSwitch(PKSwitch sw, Instant time) { - await _connection.ExecuteAsync("update switches set timestamp = @Time where id = @Id", - new {Time = time, Id = sw.Id}); + using (var conn = _conn.Obtain()) + await conn.ExecuteAsync("update switches set timestamp = @Time where id = @Id", + new {Time = time, Id = sw.Id}); } public async Task DeleteSwitch(PKSwitch sw) { - await _connection.ExecuteAsync("delete from switches where id = @Id", new {Id = sw.Id}); + using (var conn = _conn.Obtain()) + await conn.ExecuteAsync("delete from switches where id = @Id", new {Id = sw.Id}); } public struct SwitchListEntry @@ -242,11 +270,15 @@ namespace PluralKit { // query DB for all members involved in any of the switches above and collect into a dictionary for future use // this makes sure the return list has the same instances of PKMember throughout, which is important for the dictionary // key used in GetPerMemberSwitchDuration below - var memberObjects = (await _connection.QueryAsync( - "select distinct members.* from members, switch_members where switch_members.switch = any(@Switches) and switch_members.member = members.id", // lol postgres specific `= any()` syntax - new {Switches = switchesInRange.Select(sw => sw.Id).ToList()})) - .ToDictionary(m => m.Id); - + Dictionary memberObjects; + using (var conn = _conn.Obtain()) + { + memberObjects = (await conn.QueryAsync( + "select distinct members.* from members, switch_members where switch_members.switch = any(@Switches) and switch_members.member = members.id", // lol postgres specific `= any()` syntax + new {Switches = switchesInRange.Select(sw => sw.Id).ToList()})) + .ToDictionary(m => m.Id); + } + // we create the entry objects var outList = new List(); diff --git a/PluralKit.Core/Utils.cs b/PluralKit.Core/Utils.cs index 2e49620a..f50cf248 100644 --- a/PluralKit.Core/Utils.cs +++ b/PluralKit.Core/Utils.cs @@ -333,4 +333,19 @@ namespace PluralKit return (T) value; } } + + public class DbConnectionFactory + { + private string _connectionString; + + public DbConnectionFactory(string connectionString) + { + _connectionString = connectionString; + } + + public IDbConnection Obtain() + { + return new NpgsqlConnection(_connectionString); + } + } }