Move hid generation to the database. Closes #157.

This commit is contained in:
Ske 2020-06-13 00:43:48 +02:00
parent c39c51426f
commit 8ac2f1e4b8
5 changed files with 36 additions and 43 deletions

View File

@ -82,3 +82,30 @@ as $$
left join member_guild on member_guild.member = members.id and member_guild.guild = guild_id left join member_guild on member_guild.member = members.id and member_guild.guild = guild_id
where accounts.uid = account_id where accounts.uid = account_id
$$ language sql stable rows 10; $$ language sql stable rows 10;
create function generate_hid() returns text as $$
select string_agg(substr('abcdefghijklmnopqrstuvwxyz', ceil(random() * 26)::integer, 1), '') from generate_series(1, 5)
$$ language sql volatile;
create function find_free_system_hid() returns text as $$
declare new_hid text;
begin
loop
new_hid := generate_hid();
if not exists (select 1 from systems where hid = new_hid) then return new_hid; end if;
end loop;
end
$$ language plpgsql volatile;
create function find_free_member_hid() returns text as $$
declare new_hid text;
begin
loop
new_hid := generate_hid();
if not exists (select 1 from members where hid = new_hid) then return new_hid; end if;
end loop;
end
$$ language plpgsql volatile;

View File

@ -3,3 +3,6 @@ drop view if exists member_list;
drop function if exists message_context; drop function if exists message_context;
drop function if exists proxy_members; drop function if exists proxy_members;
drop function if exists generate_hid;
drop function if exists find_free_system_hid;
drop function if exists find_free_member_hid;

View File

@ -58,15 +58,9 @@ namespace PluralKit.Core {
} }
public async Task<PKSystem> CreateSystem(string systemName = null) { public async Task<PKSystem> CreateSystem(string systemName = null) {
string hid;
do
{
hid = StringUtils.GenerateHid();
} while (await GetSystemByHid(hid) != null);
PKSystem system; PKSystem system;
using (var conn = await _conn.Obtain()) using (var conn = await _conn.Obtain())
system = await conn.QuerySingleAsync<PKSystem>("insert into systems (hid, name) values (@Hid, @Name) returning *", new { Hid = hid, Name = systemName }); system = await conn.QuerySingleAsync<PKSystem>("insert into systems (hid, name) values (find_free_system_hid(), @Name) returning *", new { Name = systemName });
_logger.Information("Created system {System}", system.Id); _logger.Information("Created system {System}", system.Id);
// New system has no accounts, therefore nothing gets cached, therefore no need to invalidate caches right here // New system has no accounts, therefore nothing gets cached, therefore no need to invalidate caches right here
@ -147,16 +141,9 @@ namespace PluralKit.Core {
} }
public async Task<PKMember> CreateMember(PKSystem system, string name) { public async Task<PKMember> CreateMember(PKSystem system, string name) {
string hid;
do
{
hid = StringUtils.GenerateHid();
} while (await GetMemberByHid(hid) != null);
PKMember member; PKMember member;
using (var conn = await _conn.Obtain()) using (var conn = await _conn.Obtain())
member = await conn.QuerySingleAsync<PKMember>("insert into members (hid, system, name) values (@Hid, @SystemId, @Name) returning *", new { member = await conn.QuerySingleAsync<PKMember>("insert into members (hid, system, name) values (find_free_member_hid(), @SystemId, @Name) returning *", new {
Hid = hid,
SystemID = system.Id, SystemID = system.Id,
Name = name Name = name
}); });

View File

@ -2,6 +2,7 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Collections.Immutable; using System.Collections.Immutable;
using System.Data;
using System.Linq; using System.Linq;
using System.Threading.Tasks; using System.Threading.Tasks;
@ -69,7 +70,7 @@ namespace PluralKit.Core
// if not, roll a new hid and we'll insert one with that // if not, roll a new hid and we'll insert one with that
// (we can't trust the hid given in the member, it might let us overwrite another system's members) // (we can't trust the hid given in the member, it might let us overwrite another system's members)
var existingMember = FindExistingMemberInSystem(member.Hid, member.Name); var existingMember = FindExistingMemberInSystem(member.Hid, member.Name);
string newHid = existingMember?.Hid ?? await FindFreeHid(); string newHid = existingMember?.Hid ?? await _conn.QuerySingleAsync<string>("find_free_member_hid", commandType: CommandType.StoredProcedure);
// Upsert member data and return the ID // Upsert member data and return the ID
QueryBuilder qb = QueryBuilder.Upsert("members", "hid") QueryBuilder qb = QueryBuilder.Upsert("members", "hid")
@ -112,19 +113,6 @@ namespace PluralKit.Core
return null; return null;
} }
private async Task<string> FindFreeHid()
{
string hid;
do
{
hid = await _conn.QuerySingleOrDefaultAsync<string>(
"select @Hid where not exists (select id from members where hid = @Hid)",
new {Hid = StringUtils.GenerateHid()});
} while (hid == null);
return hid;
}
/// <summary> /// <summary>
/// Register switches in bulk. /// Register switches in bulk.
/// </summary> /// </summary>

View File

@ -6,18 +6,6 @@ namespace PluralKit.Core
{ {
public static class StringUtils public static class StringUtils
{ {
public static string GenerateHid()
{
var rnd = new Random();
var charset = "abcdefghijklmnopqrstuvwxyz";
string hid = "";
for (int i = 0; i < 5; i++)
{
hid += charset[rnd.Next(charset.Length)];
}
return hid;
}
public static string GenerateToken() public static string GenerateToken()
{ {
var buf = new byte[48]; // Results in a 64-byte Base64 string (no padding) var buf = new byte[48]; // Results in a 64-byte Base64 string (no padding)