2018-07-24 22:47:57 +02:00
import asyncio
import functools
import logging
from collections import namedtuple
from datetime import datetime
from typing import List, Optional, Tuple
2018-07-12 00:47:44 +02:00
import time
import asyncpg
import asyncpg . exceptions
2018-08-02 00:36:50 +02:00
from discord . utils import snowflake_time
2018-07-12 00:47:44 +02:00
2018-07-24 22:47:57 +02:00
from pluralkit import System , Member , stats
2018-07-12 00:47:44 +02:00
2018-07-24 22:47:57 +02:00
logger = logging . getLogger ( " pluralkit.db " )
2018-08-22 19:50:32 +02:00
async def connect(username, password, database, host, port, retry_delay: float = 1.0):
    """Create an asyncpg connection pool, retrying until the server accepts connections.

    ``ConnectionError`` and asyncpg's ``CannotConnectNowError`` (e.g. the server is
    still starting up) are logged and retried after ``retry_delay`` seconds; any
    other exception propagates to the caller.

    :param retry_delay: seconds to wait between attempts (optional, default 1).
    :return: the pool returned by ``asyncpg.create_pool``.
    """
    while True:
        try:
            return await asyncpg.create_pool(user=username, password=password, database=database, host=host, port=port)
        except (ConnectionError, asyncpg.exceptions.CannotConnectNowError) as e:
            # A refused connection fails almost instantly; without a pause this loop
            # would busy-spin at full CPU while the database is unavailable.
            logger.warning("Could not connect to database ({}), retrying in {}s".format(e, retry_delay))
            await asyncio.sleep(retry_delay)
def db_wrap(func):
    """Decorate an async DB helper with timing and error logging.

    On success the call's duration is logged at DEBUG and its result returned.
    If the helper raises ``asyncpg.exceptions.PostgresError``, the error is logged
    with its traceback and swallowed, and the wrapper returns ``None``.
    Any other exception propagates unchanged.
    """
    # Preserve the wrapped coroutine's __name__/__doc__ so logs, tracebacks and
    # introspection refer to the real helper instead of "inner".
    @functools.wraps(func)
    async def inner(*args, **kwargs):
        before = time.perf_counter()
        try:
            res = await func(*args, **kwargs)
        except asyncpg.exceptions.PostgresError:
            # TODO: report the failed query to the bot's stats object once it is reachable here.
            logger.exception("Error from database query {}".format(func.__name__))
            return None
        after = time.perf_counter()
        logger.debug("- DB call {} took {:.2f} ms".format(func.__name__, (after - before) * 1000))
        # TODO: report the query timing to the bot's stats object once it is reachable here.
        return res
    return inner
@db_wrap
async def create_system(conn, system_name: str, system_hid: str) -> System:
    """Insert a new system row and return it as a ``System`` (``None`` if no row came back)."""
    logger.debug("Creating system (name={}, hid={})".format(system_name, system_hid))
    query = "insert into systems (name, hid) values ($1, $2) returning *"
    created = await conn.fetchrow(query, system_name, system_hid)
    if not created:
        return None
    return System(**created)
2018-07-12 00:47:44 +02:00
2018-07-12 00:49:02 +02:00
2018-07-12 00:47:44 +02:00
@db_wrap
async def remove_system(conn, system_id: int):
    """Delete the system row with the given id."""
    logger.debug("Deleting system (id={})".format(system_id))
    sql = "delete from systems where id = $1"
    await conn.execute(sql, system_id)
2018-07-12 00:49:02 +02:00
2018-07-12 00:47:44 +02:00
@db_wrap
async def create_member(conn, system_id: int, member_name: str, member_hid: str) -> Member:
    """Insert a member owned by ``system_id`` and return it as a ``Member`` (``None`` if no row came back)."""
    logger.debug("Creating member (system={}, name={}, hid={})".format(system_id, member_name, member_hid))
    inserted = await conn.fetchrow(
        "insert into members (name, system, hid) values ($1, $2, $3) returning *",
        member_name, system_id, member_hid,
    )
    if not inserted:
        return None
    return Member(**inserted)
2018-07-12 00:49:02 +02:00
2018-07-12 00:47:44 +02:00
@db_wrap
async def delete_member(conn, member_id: int):
    """Delete the member row with the given id."""
    logger.debug("Deleting member (id={})".format(member_id))
    sql = "delete from members where id = $1"
    await conn.execute(sql, member_id)
2018-07-12 00:49:02 +02:00
2018-07-12 00:47:44 +02:00
@db_wrap
async def link_account(conn, system_id: int, account_id: str):
    """Insert an ``accounts`` row tying ``account_id`` (stored as an integer uid) to ``system_id``."""
    logger.debug("Linking account (account_id={}, system_id={})".format(account_id, system_id))
    uid = int(account_id)
    await conn.execute("insert into accounts (uid, system) values ($1, $2)", uid, system_id)
2018-07-12 00:49:02 +02:00
2018-07-12 00:47:44 +02:00
@db_wrap
async def unlink_account(conn, system_id: int, account_id: str):
    """Delete the ``accounts`` row tying ``account_id`` to ``system_id``, if present."""
    logger.debug("Unlinking account (account_id={}, system_id={})".format(account_id, system_id))
    uid = int(account_id)
    await conn.execute("delete from accounts where uid = $1 and system = $2", uid, system_id)
2018-07-12 00:49:02 +02:00
2018-07-12 00:47:44 +02:00
@db_wrap
async def get_linked_accounts(conn, system_id: int) -> List[int]:
    """Return the uids of every account linked to ``system_id``."""
    records = await conn.fetch("select uid from accounts where system = $1", system_id)
    return [record["uid"] for record in records]
2018-07-12 00:49:02 +02:00
2018-07-12 00:47:44 +02:00
@db_wrap
async def get_system_by_account(conn, account_id: str) -> System:
    """Return the system linked to ``account_id``, or ``None`` if the account is not linked."""
    query = "select systems.* from systems, accounts where accounts.uid = $1 and accounts.system = systems.id"
    found = await conn.fetchrow(query, int(account_id))
    if not found:
        return None
    return System(**found)
2018-07-12 00:49:02 +02:00
2018-07-12 00:47:44 +02:00
@db_wrap
async def get_system_by_hid(conn, system_hid: str) -> System:
    """Return the system whose hid equals ``system_hid``, or ``None``."""
    found = await conn.fetchrow("select * from systems where hid = $1", system_hid)
    if not found:
        return None
    return System(**found)
2018-07-12 00:47:44 +02:00
2018-07-12 00:49:02 +02:00
2018-07-12 00:47:44 +02:00
@db_wrap
async def get_system(conn, system_id: int) -> System:
    """Return the system with primary key ``system_id``, or ``None``."""
    found = await conn.fetchrow("select * from systems where id = $1", system_id)
    if not found:
        return None
    return System(**found)
2018-07-12 00:47:44 +02:00
2018-07-12 00:49:02 +02:00
2018-07-12 00:47:44 +02:00
@db_wrap
async def get_member_by_name(conn, system_id: int, member_name: str) -> Member:
    """Return the member of ``system_id`` whose name matches ``member_name`` case-insensitively, or ``None``."""
    # lower() on both sides makes the comparison case-insensitive.
    query = "select * from members where system = $1 and lower(name) = lower($2)"
    found = await conn.fetchrow(query, system_id, member_name)
    if not found:
        return None
    return Member(**found)
2018-07-12 00:47:44 +02:00
2018-07-12 00:49:02 +02:00
2018-07-12 00:47:44 +02:00
@db_wrap
async def get_member_by_hid_in_system(conn, system_id: int, member_hid: str) -> Member:
    """Return the member with hid ``member_hid`` only if it belongs to ``system_id``; otherwise ``None``."""
    query = "select * from members where system = $1 and hid = $2"
    found = await conn.fetchrow(query, system_id, member_hid)
    if not found:
        return None
    return Member(**found)
2018-07-12 00:47:44 +02:00
2018-07-12 00:49:02 +02:00
2018-07-12 00:47:44 +02:00
@db_wrap
async def get_member_by_hid(conn, member_hid: str) -> Member:
    """Return the member whose hid equals ``member_hid`` (any system), or ``None``."""
    found = await conn.fetchrow("select * from members where hid = $1", member_hid)
    if not found:
        return None
    return Member(**found)
2018-07-12 00:47:44 +02:00
2018-07-12 00:49:02 +02:00
2018-07-12 00:47:44 +02:00
@db_wrap
async def get_member(conn, member_id: int) -> Member:
    """Return the member with primary key ``member_id``, or ``None``."""
    found = await conn.fetchrow("select * from members where id = $1", member_id)
    if not found:
        return None
    return Member(**found)
2018-07-12 00:49:02 +02:00
2018-07-12 00:47:44 +02:00
@db_wrap
async def get_members(conn, members: list) -> List[Member]:
    """Return every member whose id appears in the ``members`` list of ids."""
    found = await conn.fetch("select * from members where id = any($1)", members)
    return list(map(lambda record: Member(**record), found))
2018-07-12 00:49:02 +02:00
2018-07-12 00:47:44 +02:00
@db_wrap
async def update_system_field(conn, system_id: int, field: str, value):
    """Set one column of the system row ``system_id`` to ``value``.

    ``value`` is sent as a bound parameter. ``field`` is a column name and has to be
    spliced into the SQL text (identifiers cannot be bound), so it is validated first.

    :raises ValueError: if ``field`` is not a bare identifier.
    """
    # Reject anything other than a plain identifier so the formatted column name
    # can never carry extra SQL (spaces, quotes, semicolons, comments...).
    if not field.isidentifier():
        raise ValueError("Invalid system field name: {!r}".format(field))
    logger.debug("Updating system field (id={}, {}={})".format(system_id, field, value))
    await conn.execute("update systems set {} = $1 where id = $2".format(field), value, system_id)
2018-07-12 00:49:02 +02:00
2018-07-12 00:47:44 +02:00
@db_wrap
async def update_member_field(conn, member_id: int, field: str, value):
    """Set one column of the member row ``member_id`` to ``value``.

    ``value`` is sent as a bound parameter. ``field`` is a column name and has to be
    spliced into the SQL text (identifiers cannot be bound), so it is validated first.

    :raises ValueError: if ``field`` is not a bare identifier.
    """
    # Reject anything other than a plain identifier so the formatted column name
    # can never carry extra SQL (spaces, quotes, semicolons, comments...).
    if not field.isidentifier():
        raise ValueError("Invalid member field name: {!r}".format(field))
    logger.debug("Updating member field (id={}, {}={})".format(member_id, field, value))
    await conn.execute("update members set {} = $1 where id = $2".format(field), value, member_id)
2018-07-12 00:49:02 +02:00
2018-07-12 00:47:44 +02:00
@db_wrap
async def get_all_members(conn, system_id: int) -> List[Member]:
    """Return every member belonging to ``system_id``."""
    found = await conn.fetch("select * from members where system = $1", system_id)
    return list(map(lambda record: Member(**record), found))
2018-07-12 00:49:02 +02:00
2018-07-12 00:47:44 +02:00
@db_wrap
async def get_members_exceeding(conn, system_id: int, length: int) -> List[Member]:
    """Return the members of ``system_id`` whose name is longer than ``length`` characters."""
    query = "select * from members where system = $1 and length(name) > $2"
    found = await conn.fetch(query, system_id, length)
    return list(map(lambda record: Member(**record), found))
2018-07-12 00:47:44 +02:00
2018-07-12 00:49:02 +02:00
2018-07-12 00:47:44 +02:00
@db_wrap
async def get_webhook(conn, channel_id: str) -> Optional[Tuple[str, str]]:
    """Look up the stored webhook for a channel.

    :return: ``(webhook_id, token)`` with the webhook id converted to ``str``, or
        ``None`` when the channel has no row in ``webhooks``.
    """
    row = await conn.fetchrow("select webhook, token from webhooks where channel = $1", int(channel_id))
    if not row:
        return None
    return str(row["webhook"]), row["token"]
2018-07-12 00:47:44 +02:00
2018-07-12 00:49:02 +02:00
2018-07-12 00:47:44 +02:00
@db_wrap
async def add_webhook(conn, channel_id: str, webhook_id: str, webhook_token: str):
    """Store the webhook id and token for a channel.

    The token is deliberately left out of the debug log: it is a credential for
    the webhook and should not end up in log files.
    """
    logger.debug("Adding new webhook (channel={}, webhook={})".format(channel_id, webhook_id))
    await conn.execute(
        "insert into webhooks (channel, webhook, token) values ($1, $2, $3)",
        int(channel_id), int(webhook_id), webhook_token,
    )
2018-07-24 22:47:57 +02:00
@db_wrap
async def delete_webhook(conn, channel_id: str):
    """Remove the stored webhook for a channel, if any."""
    channel = int(channel_id)
    await conn.execute("delete from webhooks where channel = $1", channel)
2018-07-12 00:49:02 +02:00
2018-07-12 00:47:44 +02:00
@db_wrap
async def add_message(conn, message_id: str, channel_id: str, member_id: int, sender_id: str, content: str):
    """Record a message row linking the message/channel ids to the member and sender."""
    logger.debug("Adding new message (id={}, channel={}, member={}, sender={})".format(
        message_id, channel_id, member_id, sender_id))
    sql = "insert into messages (mid, channel, member, sender, content) values ($1, $2, $3, $4, $5)"
    await conn.execute(sql, int(message_id), int(channel_id), member_id, int(sender_id), content)
2018-07-12 00:47:44 +02:00
2018-07-24 22:47:57 +02:00
class ProxyMember(namedtuple("ProxyMember", "id hid prefix suffix color name avatar_url tag system_name system_hid")):
    """Row shape produced by ``get_members_by_account``: member columns plus system tag/name/hid."""
    # Member columns
    id: int
    hid: str
    prefix: str
    suffix: str
    color: str
    name: str
    avatar_url: str
    # System columns (aliased in the query)
    tag: str
    system_name: str
    system_hid: str
2018-07-12 00:49:02 +02:00
2018-07-12 00:47:44 +02:00
@db_wrap
async def get_members_by_account(conn, account_id: str) -> List[ProxyMember]:
    """Return every member of the system linked to ``account_id``.

    Each result is a "chimera" ``ProxyMember`` combining member columns with the
    owning system's tag, name and hid.
    """
    query = """select
        members.id, members.hid, members.prefix, members.suffix, members.color, members.name, members.avatar_url,
        systems.tag, systems.name as system_name, systems.hid as system_hid
    from
        systems, members, accounts
    where
        accounts.uid = $1
        and systems.id = accounts.system
        and members.system = systems.id"""
    records = await conn.fetch(query, int(account_id))
    return list(map(lambda record: ProxyMember(**record), records))
class MessageInfo(namedtuple("MessageInfo", ["mid", "channel", "member", "content", "sender", "name", "hid", "avatar_url", "system_name", "system_hid"])):
    """A ``messages`` row joined with its member's name/hid/avatar and its system's name/hid.

    The namedtuple typename is "MessageInfo" so that ``repr`` and the base-class
    name match this class (it was previously misnamed "MemberInfo").
    """
    # Columns of the messages table
    mid: int
    channel: int
    member: int
    content: str
    sender: int
    # Member columns
    name: str
    hid: str
    avatar_url: str
    # System columns (aliased in the query)
    system_name: str
    system_hid: str

    def to_json(self):
        """Return a JSON-serializable dict; integer ids are emitted as strings.

        ``timestamp`` is derived from the message id via ``snowflake_time``.
        """
        return {
            "id": str(self.mid),
            "channel": str(self.channel),
            "member": self.hid,
            "system": self.system_hid,
            "message_sender": str(self.sender),
            "content": self.content,
            "timestamp": snowflake_time(self.mid).isoformat()
        }
2018-07-24 22:47:57 +02:00
@db_wrap
async def get_message_by_sender_and_id(conn, message_id: str, sender_id: str) -> MessageInfo:
    """Return the stored message ``message_id`` only if it was sent by ``sender_id``; otherwise ``None``."""
    query = """select
        messages.*,
        members.name, members.hid, members.avatar_url,
        systems.name as system_name, systems.hid as system_hid
    from
        messages, members, systems
    where
        messages.member = members.id
        and members.system = systems.id
        and mid = $1
        and sender = $2"""
    found = await conn.fetchrow(query, int(message_id), int(sender_id))
    if not found:
        return None
    return MessageInfo(**found)
2018-07-12 00:47:44 +02:00
2018-07-12 00:49:02 +02:00
2018-07-12 00:47:44 +02:00
@db_wrap
async def get_message(conn, message_id: str) -> MessageInfo:
    """Return the stored message ``message_id`` joined with its member and system, or ``None``."""
    query = """select
        messages.*,
        members.name, members.hid, members.avatar_url,
        systems.name as system_name, systems.hid as system_hid
    from
        messages, members, systems
    where
        messages.member = members.id
        and members.system = systems.id
        and mid = $1"""
    found = await conn.fetchrow(query, int(message_id))
    if not found:
        return None
    return MessageInfo(**found)
2018-07-12 00:47:44 +02:00
2018-07-12 00:49:02 +02:00
2018-07-12 00:47:44 +02:00
@db_wrap
async def delete_message(conn, message_id: str):
    """Delete the stored message row with the given id."""
    logger.debug("Deleting message (id={})".format(message_id))
    mid = int(message_id)
    await conn.execute("delete from messages where mid = $1", mid)
2018-07-12 02:14:32 +02:00
@db_wrap
async def front_history(conn, system_id: int, count: int):
    """Return up to ``count`` of the system's switches, newest first.

    Each record holds the ``switches`` columns plus ``members``: an array of the
    switch's member ids ordered by ``switch_members.id``.
    """
    query = """select
        switches.*,
        array(
            select member from switch_members
            where switch_members.switch = switches.id
            order by switch_members.id asc
        ) as members
    from switches
    where switches.system = $1
    order by switches.timestamp desc
    limit $2"""
    history = await conn.fetch(query, system_id, count)
    return history
2018-07-12 02:14:32 +02:00
@db_wrap
async def add_switch(conn, system_id: int):
    """Insert a switch for ``system_id`` and return the new switch's id."""
    logger.debug("Adding switch (system={})".format(system_id))
    inserted = await conn.fetchrow("insert into switches (system) values ($1) returning *", system_id)
    new_id = inserted["id"]
    return new_id
2018-07-12 02:14:32 +02:00
2018-07-20 22:56:32 +02:00
@db_wrap
async def move_last_switch(conn, system_id: int, switch_id: int, new_time: datetime):
    """Set the timestamp of switch ``switch_id`` to ``new_time`` (only if it belongs to ``system_id``)."""
    logger.debug("Moving latest switch (system={}, id={}, new_time={})".format(system_id, switch_id, new_time))
    sql = "update switches set timestamp = $1 where system = $2 and id = $3"
    await conn.execute(sql, new_time, system_id, switch_id)
2018-07-12 02:14:32 +02:00
@db_wrap
async def add_switch_member(conn, switch_id: int, member_id: int):
    """Attach ``member_id`` to the switch ``switch_id``."""
    logger.debug("Adding switch member (switch={}, member={})".format(switch_id, member_id))
    sql = "insert into switch_members (switch, member) values ($1, $2)"
    await conn.execute(sql, switch_id, member_id)
2018-07-12 00:49:02 +02:00
2018-07-12 15:03:34 +02:00
@db_wrap
async def get_server_info(conn, server_id: str):
    """Return the raw ``servers`` record for ``server_id``, or ``None`` if there is none."""
    settings = await conn.fetchrow("select * from servers where id = $1", int(server_id))
    return settings
@db_wrap
async def update_server(conn, server_id: str, logging_channel_id: str):
    """Upsert a server's settings, setting its log channel (a falsy id stores NULL)."""
    channel = None
    if logging_channel_id:
        channel = int(logging_channel_id)
    logger.debug("Updating server settings (id={}, log_channel={})".format(server_id, channel))
    sql = "insert into servers (id, log_channel) values ($1, $2) on conflict (id) do update set log_channel = $2"
    await conn.execute(sql, int(server_id), channel)
2018-07-16 20:53:41 +02:00
@db_wrap
async def member_count(conn) -> int:
    """Return the total number of rows in ``members``."""
    total = await conn.fetchval("select count(*) from members")
    return total
@db_wrap
async def system_count(conn) -> int:
    """Return the total number of rows in ``systems``."""
    total = await conn.fetchval("select count(*) from systems")
    return total
2018-07-21 01:22:07 +02:00
@db_wrap
async def message_count(conn) -> int:
    """Return the total number of rows in ``messages``."""
    total = await conn.fetchval("select count(*) from messages")
    return total
@db_wrap
async def account_count(conn) -> int:
    """Return the total number of rows in ``accounts``."""
    total = await conn.fetchval("select count(*) from accounts")
    return total
2018-07-12 00:47:44 +02:00
async def create_tables(conn):
    """Create every table this module queries, skipping any that already exist.

    Foreign-key columns are declared ``int`` rather than ``serial``. ``serial`` is
    shorthand for "integer with an owned sequence as its default", which is wrong
    for a reference column. It creates a pointless sequence per column, and an
    insert that omits the column would silently reference a sequence-generated
    id instead of failing. Every statement uses ``create table if not exists``,
    so tables that already exist are left unchanged.
    """
    await conn.execute("""create table if not exists systems (
        id serial primary key,
        hid char(5) unique not null,
        name text,
        description text,
        tag text,
        avatar_url text,
        created timestamp not null default current_timestamp
    )""")
    await conn.execute("""create table if not exists members (
        id serial primary key,
        hid char(5) unique not null,
        system int not null references systems(id) on delete cascade,
        color char(6),
        avatar_url text,
        name text not null,
        birthday date,
        pronouns text,
        description text,
        prefix text,
        suffix text,
        created timestamp not null default current_timestamp
    )""")
    await conn.execute("""create table if not exists accounts (
        uid bigint primary key,
        system int not null references systems(id) on delete cascade
    )""")
    await conn.execute("""create table if not exists messages (
        mid bigint primary key,
        channel bigint not null,
        member int not null references members(id) on delete cascade,
        content text not null,
        sender bigint not null
    )""")
    await conn.execute("""create table if not exists switches (
        id serial primary key,
        system int not null references systems(id) on delete cascade,
        timestamp timestamp not null default current_timestamp
    )""")
    await conn.execute("""create table if not exists switch_members (
        id serial primary key,
        switch int not null references switches(id) on delete cascade,
        member int not null references members(id) on delete cascade
    )""")
    await conn.execute("""create table if not exists webhooks (
        channel bigint primary key,
        webhook bigint not null,
        token text not null
    )""")
    await conn.execute("""create table if not exists servers (
        id bigint primary key,
        log_channel bigint
    )""")