Upgrade various store methods to IAsyncEnumerable

This commit is contained in:
Ske 2020-01-18 00:02:17 +01:00
parent 9a3355eb4b
commit 8a689ac0f2
7 changed files with 62 additions and 38 deletions

View File

@ -113,11 +113,11 @@ namespace PluralKit.API.Controllers
var sw = await _data.GetLatestSwitch(system); var sw = await _data.GetLatestSwitch(system);
if (sw == null) return NotFound("System has no registered switches."); if (sw == null) return NotFound("System has no registered switches.");
var members = await _data.GetSwitchMembers(sw); var members = _data.GetSwitchMembers(sw);
return Ok(new FrontersReturn return Ok(new FrontersReturn
{ {
Timestamp = sw.Timestamp, Timestamp = sw.Timestamp,
Members = members.Select(m => m.ToJson(_auth.ContextFor(system))) Members = await members.Select(m => m.ToJson(_auth.ContextFor(system))).ToListAsync()
}); });
} }
@ -151,10 +151,10 @@ namespace PluralKit.API.Controllers
var latestSwitch = await _data.GetLatestSwitch(_auth.CurrentSystem); var latestSwitch = await _data.GetLatestSwitch(_auth.CurrentSystem);
if (latestSwitch != null) if (latestSwitch != null)
{ {
var latestSwitchMembers = await _data.GetSwitchMembers(latestSwitch); var latestSwitchMembers = _data.GetSwitchMembers(latestSwitch);
// Bail if this switch is identical to the latest one // Bail if this switch is identical to the latest one
if (latestSwitchMembers.Select(m => m.Hid).SequenceEqual(param.Members)) if (await latestSwitchMembers.Select(m => m.Hid).SequenceEqualAsync(param.Members.ToAsyncEnumerable()))
return BadRequest("New members identical to existing fronters."); return BadRequest("New members identical to existing fronters.");
} }

View File

@ -58,9 +58,9 @@ namespace PluralKit.Bot.Commands
var lastSwitch = await _data.GetLatestSwitch(ctx.System); var lastSwitch = await _data.GetLatestSwitch(ctx.System);
if (lastSwitch != null) if (lastSwitch != null)
{ {
var lastSwitchMembers = await _data.GetSwitchMembers(lastSwitch); var lastSwitchMembers = _data.GetSwitchMembers(lastSwitch);
// Make sure the requested switch isn't identical to the last one // Make sure the requested switch isn't identical to the last one
if (lastSwitchMembers.Select(m => m.Id).SequenceEqual(members.Select(m => m.Id))) if (await lastSwitchMembers.Select(m => m.Id).SequenceEqualAsync(members.Select(m => m.Id).ToAsyncEnumerable()))
throw Errors.SameSwitch(members); throw Errors.SameSwitch(members);
} }
@ -86,13 +86,13 @@ namespace PluralKit.Bot.Commands
if (time.ToInstant() > SystemClock.Instance.GetCurrentInstant()) throw Errors.SwitchTimeInFuture; if (time.ToInstant() > SystemClock.Instance.GetCurrentInstant()) throw Errors.SwitchTimeInFuture;
// Fetch the last two switches for the system to do bounds checking on // Fetch the last two switches for the system to do bounds checking on
var lastTwoSwitches = (await _data.GetSwitches(ctx.System, 2)).ToArray(); var lastTwoSwitches = await _data.GetSwitches(ctx.System).Take(2).ToListAsync();
// If we don't have a switch to move, don't bother // If we don't have a switch to move, don't bother
if (lastTwoSwitches.Length == 0) throw Errors.NoRegisteredSwitches; if (lastTwoSwitches.Count == 0) throw Errors.NoRegisteredSwitches;
// If there's a switch *behind* the one we move, we check to make srue we're not moving the time further back than that // If there's a switch *behind* the one we move, we check to make srue we're not moving the time further back than that
if (lastTwoSwitches.Length == 2) if (lastTwoSwitches.Count == 2)
{ {
if (lastTwoSwitches[1].Timestamp > time.ToInstant()) if (lastTwoSwitches[1].Timestamp > time.ToInstant())
throw Errors.SwitchMoveBeforeSecondLast(lastTwoSwitches[1].Timestamp.InZone(tz)); throw Errors.SwitchMoveBeforeSecondLast(lastTwoSwitches[1].Timestamp.InZone(tz));
@ -100,8 +100,8 @@ namespace PluralKit.Bot.Commands
// Now we can actually do the move, yay! // Now we can actually do the move, yay!
// But, we do a prompt to confirm. // But, we do a prompt to confirm.
var lastSwitchMembers = await _data.GetSwitchMembers(lastTwoSwitches[0]); var lastSwitchMembers = _data.GetSwitchMembers(lastTwoSwitches[0]);
var lastSwitchMemberStr = string.Join(", ", lastSwitchMembers.Select(m => m.Name)); var lastSwitchMemberStr = string.Join(", ", await lastSwitchMembers.Select(m => m.Name).ToListAsync());
var lastSwitchTimeStr = Formats.ZonedDateTimeFormat.Format(lastTwoSwitches[0].Timestamp.InZone(ctx.System.Zone)); var lastSwitchTimeStr = Formats.ZonedDateTimeFormat.Format(lastTwoSwitches[0].Timestamp.InZone(ctx.System.Zone));
var lastSwitchDeltaStr = Formats.DurationFormat.Format(SystemClock.Instance.GetCurrentInstant() - lastTwoSwitches[0].Timestamp); var lastSwitchDeltaStr = Formats.DurationFormat.Format(SystemClock.Instance.GetCurrentInstant() - lastTwoSwitches[0].Timestamp);
var newSwitchTimeStr = Formats.ZonedDateTimeFormat.Format(time); var newSwitchTimeStr = Formats.ZonedDateTimeFormat.Format(time);
@ -132,23 +132,23 @@ namespace PluralKit.Bot.Commands
} }
// Fetch the last two switches for the system to do bounds checking on // Fetch the last two switches for the system to do bounds checking on
var lastTwoSwitches = (await _data.GetSwitches(ctx.System, 2)).ToArray(); var lastTwoSwitches = await _data.GetSwitches(ctx.System).Take(2).ToListAsync();
if (lastTwoSwitches.Length == 0) throw Errors.NoRegisteredSwitches; if (lastTwoSwitches.Count == 0) throw Errors.NoRegisteredSwitches;
var lastSwitchMembers = await _data.GetSwitchMembers(lastTwoSwitches[0]); var lastSwitchMembers = _data.GetSwitchMembers(lastTwoSwitches[0]);
var lastSwitchMemberStr = string.Join(", ", lastSwitchMembers.Select(m => m.Name)); var lastSwitchMemberStr = string.Join(", ", await lastSwitchMembers.Select(m => m.Name).ToListAsync());
var lastSwitchDeltaStr = Formats.DurationFormat.Format(SystemClock.Instance.GetCurrentInstant() - lastTwoSwitches[0].Timestamp); var lastSwitchDeltaStr = Formats.DurationFormat.Format(SystemClock.Instance.GetCurrentInstant() - lastTwoSwitches[0].Timestamp);
IUserMessage msg; IUserMessage msg;
if (lastTwoSwitches.Length == 1) if (lastTwoSwitches.Count == 1)
{ {
msg = await ctx.Reply( msg = await ctx.Reply(
$"{Emojis.Warn} This will delete the latest switch ({lastSwitchMemberStr.SanitizeMentions()}, {lastSwitchDeltaStr} ago). You have no other switches logged. Is this okay?"); $"{Emojis.Warn} This will delete the latest switch ({lastSwitchMemberStr.SanitizeMentions()}, {lastSwitchDeltaStr} ago). You have no other switches logged. Is this okay?");
} }
else else
{ {
var secondSwitchMembers = await _data.GetSwitchMembers(lastTwoSwitches[1]); var secondSwitchMembers = _data.GetSwitchMembers(lastTwoSwitches[1]);
var secondSwitchMemberStr = string.Join(", ", secondSwitchMembers.Select(m => m.Name)); var secondSwitchMemberStr = string.Join(", ", await secondSwitchMembers.Select(m => m.Name).ToListAsync());
var secondSwitchDeltaStr = Formats.DurationFormat.Format(SystemClock.Instance.GetCurrentInstant() - lastTwoSwitches[1].Timestamp); var secondSwitchDeltaStr = Formats.DurationFormat.Format(SystemClock.Instance.GetCurrentInstant() - lastTwoSwitches[1].Timestamp);
msg = await ctx.Reply( msg = await ctx.Reply(
$"{Emojis.Warn} This will delete the latest switch ({lastSwitchMemberStr.SanitizeMentions()}, {lastSwitchDeltaStr} ago). The next latest switch is {secondSwitchMemberStr.SanitizeMentions()} ({secondSwitchDeltaStr} ago). Is this okay?"); $"{Emojis.Warn} This will delete the latest switch ({lastSwitchMemberStr.SanitizeMentions()}, {lastSwitchDeltaStr} ago). The next latest switch is {secondSwitchMemberStr.SanitizeMentions()} ({secondSwitchDeltaStr} ago). Is this okay?");

View File

@ -239,10 +239,13 @@ namespace PluralKit.Bot.Commands
if (system == null) throw Errors.NoSystemError; if (system == null) throw Errors.NoSystemError;
ctx.CheckSystemPrivacy(system, system.FrontHistoryPrivacy); ctx.CheckSystemPrivacy(system, system.FrontHistoryPrivacy);
var sws = (await _data.GetSwitches(system, 10)).ToList(); var sws = _data.GetSwitches(system).Take(10);
if (sws.Count == 0) throw Errors.NoRegisteredSwitches; var embed = await _embeds.CreateFrontHistoryEmbed(sws, system.Zone);
await ctx.Reply(embed: await _embeds.CreateFrontHistoryEmbed(sws, system.Zone)); // Moving the count check to the CreateFrontHistoryEmbed function to avoid a double-iteration
// If embed == null, then there's no switches, so error
if (embed == null) throw Errors.NoRegisteredSwitches;
await ctx.Reply(embed: embed);
} }
public async Task SystemFrontPercent(Context ctx, PKSystem system) public async Task SystemFrontPercent(Context ctx, PKSystem system)

View File

@ -36,7 +36,7 @@ namespace PluralKit.Bot {
var latestSwitch = await _data.GetLatestSwitch(system); var latestSwitch = await _data.GetLatestSwitch(system);
if (latestSwitch != null && system.FrontPrivacy.CanAccess(ctx)) if (latestSwitch != null && system.FrontPrivacy.CanAccess(ctx))
{ {
var switchMembers = (await _data.GetSwitchMembers(latestSwitch)).ToList(); var switchMembers = await _data.GetSwitchMembers(latestSwitch).ToListAsync();
if (switchMembers.Count > 0) if (switchMembers.Count > 0)
eb.AddField("Fronter".ToQuantity(switchMembers.Count(), ShowQuantityAs.None), eb.AddField("Fronter".ToQuantity(switchMembers.Count(), ShowQuantityAs.None),
string.Join(", ", switchMembers.Select(m => m.Name))); string.Join(", ", switchMembers.Select(m => m.Name)));
@ -115,7 +115,7 @@ namespace PluralKit.Bot {
public async Task<Embed> CreateFronterEmbed(PKSwitch sw, DateTimeZone zone) public async Task<Embed> CreateFronterEmbed(PKSwitch sw, DateTimeZone zone)
{ {
var members = (await _data.GetSwitchMembers(sw)).ToList(); var members = await _data.GetSwitchMembers(sw).ToListAsync();
var timeSinceSwitch = SystemClock.Instance.GetCurrentInstant() - sw.Timestamp; var timeSinceSwitch = SystemClock.Instance.GetCurrentInstant() - sw.Timestamp;
return new EmbedBuilder() return new EmbedBuilder()
.WithColor(members.FirstOrDefault()?.Color?.ToDiscordColor() ?? Color.Blue) .WithColor(members.FirstOrDefault()?.Color?.ToDiscordColor() ?? Color.Blue)
@ -124,15 +124,15 @@ namespace PluralKit.Bot {
.Build(); .Build();
} }
public async Task<Embed> CreateFrontHistoryEmbed(IEnumerable<PKSwitch> sws, DateTimeZone zone) public async Task<Embed> CreateFrontHistoryEmbed(IAsyncEnumerable<PKSwitch> sws, DateTimeZone zone)
{ {
var outputStr = ""; var outputStr = "";
PKSwitch lastSw = null; PKSwitch lastSw = null;
foreach (var sw in sws) await foreach (var sw in sws)
{ {
// Fetch member list and format // Fetch member list and format
var members = (await _data.GetSwitchMembers(sw)).ToList(); var members = await _data.GetSwitchMembers(sw).ToListAsync();
var membersStr = members.Any() ? string.Join(", ", members.Select(m => m.Name)) : "no fronter"; var membersStr = members.Any() ? string.Join(", ", members.Select(m => m.Name)) : "no fronter";
var switchSince = SystemClock.Instance.GetCurrentInstant() - sw.Timestamp; var switchSince = SystemClock.Instance.GetCurrentInstant() - sw.Timestamp;
@ -156,6 +156,9 @@ namespace PluralKit.Bot {
lastSw = sw; lastSw = sw;
} }
if (lastSw == null)
return null;
return new EmbedBuilder() return new EmbedBuilder()
.WithTitle("Past switches") .WithTitle("Past switches")
.WithDescription(outputStr) .WithDescription(outputStr)

View File

@ -25,6 +25,7 @@
<PackageReference Include="Serilog.Sinks.Async" Version="1.4.1-dev-00071" /> <PackageReference Include="Serilog.Sinks.Async" Version="1.4.1-dev-00071" />
<PackageReference Include="Serilog.Sinks.Console" Version="4.0.0-dev-00834" /> <PackageReference Include="Serilog.Sinks.Console" Version="4.0.0-dev-00834" />
<PackageReference Include="Serilog.Sinks.File" Version="4.1.0" /> <PackageReference Include="Serilog.Sinks.File" Version="4.1.0" />
<PackageReference Include="System.Interactive.Async" Version="4.0.0" />
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>

View File

@ -1,4 +1,5 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.Data.Common;
using System.Linq; using System.Linq;
using System.Threading.Tasks; using System.Threading.Tasks;
@ -289,7 +290,7 @@ namespace PluralKit {
/// Gets switches from a system. /// Gets switches from a system.
/// </summary> /// </summary>
/// <returns>An enumerable of the *count* latest switches in the system, in latest-first order. May contain fewer elements than requested.</returns> /// <returns>An enumerable of the *count* latest switches in the system, in latest-first order. May contain fewer elements than requested.</returns>
Task<IEnumerable<PKSwitch>> GetSwitches(PKSystem system, int count); IAsyncEnumerable<PKSwitch> GetSwitches(PKSystem system);
/// <summary> /// <summary>
/// Gets the latest (temporally; closest to now) switch of a given system. /// Gets the latest (temporally; closest to now) switch of a given system.
@ -299,7 +300,7 @@ namespace PluralKit {
/// <summary> /// <summary>
/// Gets the members a given switch consists of. /// Gets the members a given switch consists of.
/// </summary> /// </summary>
Task<IEnumerable<PKMember>> GetSwitchMembers(PKSwitch sw); IAsyncEnumerable<PKMember> GetSwitchMembers(PKSwitch sw);
/// <summary> /// <summary>
/// Gets a list of fronters over a given period of time. /// Gets a list of fronters over a given period of time.
@ -787,7 +788,9 @@ namespace PluralKit {
public async Task AddSwitchesBulk(PKSystem system, IEnumerable<ImportedSwitch> switches) public async Task AddSwitchesBulk(PKSystem system, IEnumerable<ImportedSwitch> switches)
{ {
// Read existing switches to enforce unique timestamps // Read existing switches to enforce unique timestamps
var priorSwitches = await GetSwitches(system); var priorSwitches = new List<PKSwitch>();
await foreach (var sw in GetSwitches(system)) priorSwitches.Add(sw);
var lastSwitchId = priorSwitches.Any() var lastSwitchId = priorSwitches.Any()
? priorSwitches.Max(x => x.Id) ? priorSwitches.Max(x => x.Id)
: 0; : 0;
@ -855,12 +858,13 @@ namespace PluralKit {
_logger.Information("Completed bulk import of switches for system {0}", system.Hid); _logger.Information("Completed bulk import of switches for system {0}", system.Hid);
} }
public async Task<IEnumerable<PKSwitch>> GetSwitches(PKSystem system, int count = 9999999) public IAsyncEnumerable<PKSwitch> GetSwitches(PKSystem system)
{ {
// TODO: refactor the PKSwitch data structure to somehow include a hydrated member list // TODO: refactor the PKSwitch data structure to somehow include a hydrated member list
// (maybe when we get caching in?) // (maybe when we get caching in?)
using (var conn = await _conn.Obtain()) return _conn.QueryStreamAsync<PKSwitch>(
return await conn.QueryAsync<PKSwitch>("select * from switches where system = @System order by timestamp desc limit @Count", new {System = system.Id, Count = count}); "select * from switches where system = @System order by timestamp desc",
new {System = system.Id});
} }
public async Task<IEnumerable<SwitchMembersListEntry>> GetSwitchMembersList(PKSystem system, Instant start, Instant end) public async Task<IEnumerable<SwitchMembersListEntry>> GetSwitchMembersList(PKSystem system, Instant start, Instant end)
@ -899,15 +903,15 @@ namespace PluralKit {
} }
} }
public async Task<IEnumerable<PKMember>> GetSwitchMembers(PKSwitch sw) public IAsyncEnumerable<PKMember> GetSwitchMembers(PKSwitch sw)
{ {
using (var conn = await _conn.Obtain()) return _conn.QueryStreamAsync<PKMember>(
return await conn.QueryAsync<PKMember>( "select * from switch_members, members where switch_members.member = members.id and switch_members.switch = @Switch order by switch_members.id",
"select * from switch_members, members where switch_members.member = members.id and switch_members.switch = @Switch order by switch_members.id", new {Switch = sw.Id});
new {Switch = sw.Id});
} }
public async Task<PKSwitch> GetLatestSwitch(PKSystem system) => (await GetSwitches(system, 1)).FirstOrDefault(); public async Task<PKSwitch> GetLatestSwitch(PKSystem system) =>
await GetSwitches(system).FirstOrDefaultAsync();
public async Task MoveSwitch(PKSwitch sw, Instant time) public async Task MoveSwitch(PKSwitch sw, Instant time)
{ {

View File

@ -669,4 +669,17 @@ namespace PluralKit
EventId = Guid.NewGuid(); EventId = Guid.NewGuid();
} }
} }
public static class ConnectionUtils
{
public static async IAsyncEnumerable<T> QueryStreamAsync<T>(this DbConnectionFactory connFactory, string sql, object param)
{
using var conn = await connFactory.Obtain();
var reader = await conn.ExecuteReaderAsync(sql, param);
var parser = reader.GetRowParser<T>();
while (reader.Read())
yield return parser(reader);
}
}
} }