2018-07-11 22:47:44 +00:00
import time
import asyncpg
import asyncpg . exceptions
2018-07-16 18:53:41 +00:00
from pluralkit import stats
2018-07-11 22:47:44 +00:00
from pluralkit . bot import logger
async def connect(*, user="postgres", password="postgres", database="postgres", host="db", retry_delay=1.0):
    """Create and return an asyncpg connection pool, retrying until the server accepts connections.

    All parameters are keyword-only and default to the values this module has
    always used, so existing ``connect()`` calls behave as before.

    :param user: database role to log in as.
    :param password: password for *user*.
    :param database: database name.
    :param host: database host name.
    :param retry_delay: seconds to wait between failed connection attempts.
    :returns: the pool created by ``asyncpg.create_pool``.

    Only ``ConnectionError`` and ``CannotConnectNowError`` trigger a retry;
    any other exception propagates to the caller.
    """
    import asyncio

    while True:
        try:
            return await asyncpg.create_pool(user=user, password=password, database=database, host=host)
        except (ConnectionError, asyncpg.exceptions.CannotConnectNowError):
            # The server may still be starting. Back off instead of retrying in
            # a tight loop, which would pin the CPU and flood the server.
            await asyncio.sleep(retry_delay)
2018-07-11 22:49:02 +00:00
2018-07-11 22:47:44 +00:00
def db_wrap(func):
    """Decorate an async DB helper with timing, stats reporting and error logging.

    On success, the call duration is logged at debug level and reported through
    ``stats.report_db_query(name, seconds, True)``, and the helper's result is
    returned.

    If the helper raises ``asyncpg.exceptions.PostgresError``, the failure is
    reported with ``success=False`` and logged with its traceback. The wrapper
    then returns ``None`` instead of re-raising. Other exceptions propagate
    unchanged.
    """
    import functools

    # functools.wraps keeps the helper's __name__/__doc__ on the wrapper, so
    # introspection and tracebacks show the real function instead of "inner".
    @functools.wraps(func)
    async def inner(*args, **kwargs):
        before = time.perf_counter()
        try:
            res = await func(*args, **kwargs)
        except asyncpg.exceptions.PostgresError:
            await stats.report_db_query(func.__name__, time.perf_counter() - before, False)
            logger.exception(" Error from database query {} ".format(func.__name__))
            # Deliberately swallowed: callers receive None on database errors.
            return None
        after = time.perf_counter()

        logger.debug(" - DB call {} took {:.2f} ms ".format(func.__name__, (after - before) * 1000))
        await stats.report_db_query(func.__name__, after - before, True)
        return res
    return inner
@db_wrap
async def create_system(conn, system_name: str, system_hid: str):
    """Insert a new ``systems`` row and return it.

    :param conn: connection (or pool) used to run the query.
    :param system_name: value stored in the ``name`` column.
    :param system_hid: value stored in the ``hid`` column.
    :returns: the inserted row (``returning *``).
    """
    message = " Creating system (name= {} , hid= {} ) ".format(system_name, system_hid)
    logger.debug(message)

    sql = " insert into systems (name, hid) values ($1, $2) returning * "
    created = await conn.fetchrow(sql, system_name, system_hid)
    return created
2018-07-11 22:49:02 +00:00
2018-07-11 22:47:44 +00:00
@db_wrap
async def remove_system(conn, system_id: int):
    """Delete the ``systems`` row whose ``id`` is *system_id*.

    Dependent rows are removed by the database through the
    ``on delete cascade`` foreign keys declared in ``create_tables``.
    """
    logger.debug(" Deleting system (id= {} ) ".format(system_id))
    sql = " delete from systems where id = $1 "
    await conn.execute(sql, system_id)
2018-07-11 22:49:02 +00:00
2018-07-11 22:47:44 +00:00
@db_wrap
async def create_member(conn, system_id: int, member_name: str, member_hid: str):
    """Insert a new ``members`` row belonging to *system_id* and return it.

    :param system_id: id of the owning system (``members.system``).
    :param member_name: value for ``members.name``.
    :param member_hid: value for ``members.hid``.
    :returns: the inserted row (``returning *``).
    """
    logger.debug(" Creating member (system= {} , name= {} , hid= {} ) ".format(
        system_id, member_name, member_hid))

    # Positional order must match the column list: name, system, hid.
    params = (member_name, system_id, member_hid)
    return await conn.fetchrow(" insert into members (name, system, hid) values ($1, $2, $3) returning * ", *params)
2018-07-11 22:47:44 +00:00
@db_wrap
async def delete_member(conn, member_id: int):
    """Delete the ``members`` row whose ``id`` is *member_id*."""
    logger.debug(" Deleting member (id= {} ) ".format(member_id))
    sql = " delete from members where id = $1 "
    await conn.execute(sql, member_id)
2018-07-11 22:49:02 +00:00
2018-07-11 22:47:44 +00:00
@db_wrap
async def link_account(conn, system_id: int, account_id: str):
    """Link account *account_id* to system *system_id* by inserting an ``accounts`` row.

    *account_id* is converted with ``int()`` because ``accounts.uid`` is a bigint.
    """
    logger.debug(" Linking account (account_id= {} , system_id= {} ) ".format(
        account_id, system_id))

    uid = int(account_id)
    await conn.execute(" insert into accounts (uid, system) values ($1, $2) ", uid, system_id)
2018-07-11 22:49:02 +00:00
2018-07-11 22:47:44 +00:00
@db_wrap
async def unlink_account(conn, system_id: int, account_id: str):
    """Remove the ``accounts`` row linking *account_id* to *system_id*.

    The row is only deleted if both the uid and the system match.
    """
    logger.debug(" Unlinking account (account_id= {} , system_id= {} ) ".format(
        account_id, system_id))

    uid = int(account_id)
    await conn.execute(" delete from accounts where uid = $1 and system = $2 ", uid, system_id)
2018-07-11 22:49:02 +00:00
2018-07-11 22:47:44 +00:00
@db_wrap
async def get_linked_accounts(conn, system_id: int):
    """Return the ``uid`` of every account linked to *system_id*, as a list.

    The list is empty when no account is linked.
    """
    rows = await conn.fetch(" select uid from accounts where system = $1 ", system_id)
    # The record key must be the exact column name; a padded key such as
    # " uid " does not exist and raises KeyError.
    return [row["uid"] for row in rows]
2018-07-11 22:49:02 +00:00
2018-07-11 22:47:44 +00:00
@db_wrap
async def get_system_by_account(conn, account_id: str):
    """Return the ``systems`` row linked to account *account_id*.

    The result comes from ``fetchrow`` (asyncpg returns None when nothing matches).
    """
    uid = int(account_id)
    sql = " select systems.* from systems, accounts where accounts.uid = $1 and accounts.system = systems.id "
    return await conn.fetchrow(sql, uid)
2018-07-11 22:49:02 +00:00
2018-07-11 22:47:44 +00:00
@db_wrap
async def get_system_by_hid(conn, system_hid: str):
    """Return the ``systems`` row whose ``hid`` equals *system_hid* (None if absent)."""
    sql = " select * from systems where hid = $1 "
    found = await conn.fetchrow(sql, system_hid)
    return found
2018-07-11 22:49:02 +00:00
2018-07-11 22:47:44 +00:00
@db_wrap
async def get_system(conn, system_id: int):
    """Return the ``systems`` row whose primary key is *system_id* (None if absent)."""
    sql = " select * from systems where id = $1 "
    found = await conn.fetchrow(sql, system_id)
    return found
2018-07-11 22:49:02 +00:00
2018-07-11 22:47:44 +00:00
@db_wrap
async def get_member_by_name(conn, system_id: int, member_name: str):
    """Return the member of *system_id* whose name matches *member_name*.

    Both sides are passed through SQL ``lower()``, so the match is case-insensitive.
    """
    sql = " select * from members where system = $1 and lower(name) = lower($2) "
    found = await conn.fetchrow(sql, system_id, member_name)
    return found
2018-07-11 22:47:44 +00:00
2018-07-11 22:49:02 +00:00
2018-07-11 22:47:44 +00:00
@db_wrap
async def get_member_by_hid_in_system(conn, system_id: int, member_hid: str):
    """Return the member with ``hid`` *member_hid*, but only if it belongs to *system_id*."""
    sql = " select * from members where system = $1 and hid = $2 "
    found = await conn.fetchrow(sql, system_id, member_hid)
    return found
2018-07-11 22:49:02 +00:00
2018-07-11 22:47:44 +00:00
@db_wrap
async def get_member_by_hid(conn, member_hid: str):
    """Return the member whose ``hid`` equals *member_hid*, regardless of system."""
    sql = " select * from members where hid = $1 "
    found = await conn.fetchrow(sql, member_hid)
    return found
2018-07-11 22:49:02 +00:00
2018-07-11 22:47:44 +00:00
@db_wrap
async def get_member(conn, member_id: int):
    """Return the member whose primary key is *member_id* (None if absent)."""
    sql = " select * from members where id = $1 "
    found = await conn.fetchrow(sql, member_id)
    return found
2018-07-14 00:28:15 +00:00
@db_wrap
async def get_members(conn, members: list):
    """Return all member rows whose ``id`` appears in the list *members*.

    The list is bound as a single array parameter and matched with ``= any($1)``.
    """
    sql = " select * from members where id = any($1) "
    rows = await conn.fetch(sql, members)
    return rows
2018-07-11 22:49:02 +00:00
2018-07-11 22:47:44 +00:00
@db_wrap
async def get_message(conn, message_id: str):
    """Return the ``messages`` row whose ``mid`` is *message_id* (None if absent).

    ``messages.mid`` is a bigint. asyncpg requires an ``int`` for such a
    parameter and rejects a ``str``, so the id is converted first. This matches
    how ``add_message`` and ``delete_message`` treat message ids. Passing an
    ``int`` still works.
    """
    return await conn.fetchrow(" select * from messages where mid = $1 ", int(message_id))
2018-07-11 22:49:02 +00:00
2018-07-11 22:47:44 +00:00
@db_wrap
async def update_system_field(conn, system_id: int, field: str, value):
    """Set column *field* of system *system_id* to *value*.

    :param field: column name in ``systems``. It must be a plain SQL identifier
        (ASCII letters, digits and underscores, not starting with a digit).
    :raises ValueError: if *field* is not a plain identifier.
    """
    import re

    # A column name cannot be a bind parameter, so it is spliced into the SQL
    # text. Reject anything that is not a bare identifier, so *field* can
    # never inject SQL.
    if not isinstance(field, str) or not re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", field):
        raise ValueError("invalid system field name: {!r}".format(field))

    logger.debug(" Updating system field (id= {} , {} = {} ) ".format(
        system_id, field, value))
    await conn.execute(" update systems set {} = $1 where id = $2 ".format(field), value, system_id)
2018-07-11 22:49:02 +00:00
2018-07-11 22:47:44 +00:00
@db_wrap
async def update_member_field(conn, member_id: int, field: str, value):
    """Set column *field* of member *member_id* to *value*.

    :param field: column name in ``members``. It must be a plain SQL identifier
        (ASCII letters, digits and underscores, not starting with a digit).
    :raises ValueError: if *field* is not a plain identifier.
    """
    import re

    # A column name cannot be a bind parameter, so it is spliced into the SQL
    # text. Reject anything that is not a bare identifier, so *field* can
    # never inject SQL.
    if not isinstance(field, str) or not re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", field):
        raise ValueError("invalid member field name: {!r}".format(field))

    logger.debug(" Updating member field (id= {} , {} = {} ) ".format(
        member_id, field, value))
    await conn.execute(" update members set {} = $1 where id = $2 ".format(field), value, member_id)
2018-07-11 22:49:02 +00:00
2018-07-11 22:47:44 +00:00
@db_wrap
async def get_all_members(conn, system_id: int):
    """Return every ``members`` row that belongs to *system_id*."""
    sql = " select * from members where system = $1 "
    rows = await conn.fetch(sql, system_id)
    return rows
2018-07-11 22:49:02 +00:00
2018-07-11 22:47:44 +00:00
@db_wrap
async def get_members_exceeding(conn, system_id: int, length: int):
    """Return the members of *system_id* whose name is longer than *length*.

    The comparison is strict: SQL ``length(name) > length``.
    """
    sql = " select * from members where system = $1 and length(name) > $2 "
    rows = await conn.fetch(sql, system_id, length)
    return rows
2018-07-11 22:47:44 +00:00
2018-07-11 22:49:02 +00:00
2018-07-11 22:47:44 +00:00
@db_wrap
async def get_webhook(conn, channel_id: str):
    """Return the stored ``webhook`` and ``token`` for channel *channel_id*.

    *channel_id* is converted with ``int()`` because ``webhooks.channel`` is a
    bigint. The result comes from ``fetchrow`` (None when no webhook is stored).
    """
    channel = int(channel_id)
    sql = " select webhook, token from webhooks where channel = $1 "
    return await conn.fetchrow(sql, channel)
2018-07-11 22:47:44 +00:00
2018-07-11 22:49:02 +00:00
2018-07-11 22:47:44 +00:00
@db_wrap
async def add_webhook(conn, channel_id: str, webhook_id: str, webhook_token: str):
    """Store the webhook id and token to use for channel *channel_id*.

    *channel_id* and *webhook_id* are converted with ``int()`` for the bigint
    columns. The token is stored as-is.
    """
    # The token is a credential: it is deliberately kept out of the log line,
    # so debug logs do not leak it.
    logger.debug(" Adding new webhook (channel= {} , webhook= {} , token=<redacted>) ".format(
        channel_id, webhook_id))
    await conn.execute(" insert into webhooks (channel, webhook, token) values ($1, $2, $3) ", int(channel_id), int(webhook_id), webhook_token)
2018-07-11 22:49:02 +00:00
2018-07-11 22:47:44 +00:00
@db_wrap
async def add_message(conn, message_id: str, channel_id: str, member_id: int, sender_id: str, content: str):
    """Record a message sent on behalf of member *member_id*.

    The message, channel and sender ids are converted with ``int()`` for the
    bigint columns. *member_id* and *content* are stored unchanged.
    """
    logger.debug(" Adding new message (id= {} , channel= {} , member= {} , sender= {} ) ".format(
        message_id, channel_id, member_id, sender_id))

    # Order matches the column list: mid, channel, member, sender, content.
    values = (int(message_id), int(channel_id), member_id, int(sender_id), content)
    await conn.execute(" insert into messages (mid, channel, member, sender, content) values ($1, $2, $3, $4, $5) ", *values)
2018-07-11 22:47:44 +00:00
2018-07-11 22:49:02 +00:00
2018-07-11 22:47:44 +00:00
@db_wrap
async def get_members_by_account(conn, account_id: str):
    """Return every member of the system linked to account *account_id*.

    Each row is a "chimera" that combines selected member columns with the
    owning system's ``tag``, plus its ``name``/``hid`` aliased as
    ``system_name``/``system_hid`` so they do not clash with the member's own
    columns.
    """
    # Positional parameters must be written "$1" with no space; "$ 1" is a
    # Postgres syntax error.
    return await conn.fetch("""select
            members.id, members.hid, members.prefix, members.suffix, members.color, members.name, members.avatar_url,
            systems.tag, systems.name as system_name, systems.hid as system_hid
        from
            systems, members, accounts
        where
            accounts.uid = $1
            and systems.id = accounts.system
            and members.system = systems.id""", int(account_id))
2018-07-11 22:47:44 +00:00
2018-07-11 22:49:02 +00:00
2018-07-11 22:47:44 +00:00
@db_wrap
async def get_message_by_sender_and_id(conn, message_id: str, sender_id: str):
    """Return message *message_id*, but only if *sender_id* sent it.

    The row contains all ``messages`` columns plus the proxied member's
    ``name``/``hid``/``avatar_url`` and the system's name and hid (aliased
    ``system_name``/``system_hid``). The result comes from ``fetchrow``, so it is
    None when nothing matches.
    """
    # Positional parameters must be written "$1"/"$2" with no space.
    return await conn.fetchrow("""select
            messages.*,
            members.name, members.hid, members.avatar_url,
            systems.name as system_name, systems.hid as system_hid
        from
            messages, members, systems
        where
            messages.member = members.id
            and members.system = systems.id
            and mid = $1 and sender = $2""", int(message_id), int(sender_id))
2018-07-11 22:47:44 +00:00
2018-07-11 22:49:02 +00:00
2018-07-11 22:47:44 +00:00
@db_wrap
async def delete_message(conn, message_id: str):
    """Delete the ``messages`` row whose ``mid`` is *message_id* (converted with ``int()``)."""
    logger.debug(" Deleting message (id= {} ) ".format(message_id))
    mid = int(message_id)
    await conn.execute(" delete from messages where mid = $1 ", mid)
2018-07-12 00:14:32 +00:00
@db_wrap
async def front_history(conn, system_id: int, count: int):
    """Return up to *count* of the system's switches, newest ``timestamp`` first.

    Each row has every ``switches`` column plus ``members``: an array of the
    member ids attached to that switch, ordered by ``switch_members.id``
    ascending.
    """
    # Positional parameters must be written "$1"/"$2" with no space.
    return await conn.fetch("""select
            switches.*,
            array(
                select member from switch_members
                where switch_members.switch = switches.id
                order by switch_members.id asc
            ) as members
        from switches
        where switches.system = $1
        order by switches.timestamp desc
        limit $2""", system_id, count)
2018-07-12 00:14:32 +00:00
@db_wrap
async def add_switch(conn, system_id: int):
    """Insert a new switch for *system_id* and return its ``id``."""
    logger.debug(" Adding switch (system= {} ) ".format(system_id))
    res = await conn.fetchrow(" insert into switches (system) values ($1) returning * ", system_id)
    # The record key must be the exact column name; a padded key such as
    # " id " does not exist and raises KeyError.
    return res["id"]
2018-07-12 00:14:32 +00:00
@db_wrap
async def add_switch_member(conn, switch_id: int, member_id: int):
    """Attach member *member_id* to switch *switch_id*."""
    logger.debug(" Adding switch member (switch= {} , member= {} ) ".format(switch_id, member_id))
    sql = " insert into switch_members (switch, member) values ($1, $2) "
    await conn.execute(sql, switch_id, member_id)
2018-07-11 22:49:02 +00:00
2018-07-12 13:03:34 +00:00
@db_wrap
async def get_server_info(conn, server_id: str):
    """Return the ``servers`` settings row for *server_id* (converted with ``int()``)."""
    sid = int(server_id)
    sql = " select * from servers where id = $1 "
    return await conn.fetchrow(sql, sid)
@db_wrap
async def update_server(conn, server_id: str, logging_channel_id: str):
    """Insert or update the log channel setting for *server_id*.

    A falsy *logging_channel_id* (e.g. None or "") is stored as NULL, which
    clears the setting. Otherwise it is converted with ``int()``. An existing
    row is updated through ``on conflict (id) do update``.
    """
    channel = None
    if logging_channel_id:
        channel = int(logging_channel_id)

    logger.debug(" Updating server settings (id= {} , log_channel= {} ) ".format(server_id, channel))
    sql = " insert into servers (id, log_channel) values ($1, $2) on conflict (id) do update set log_channel = $2 "
    await conn.execute(sql, int(server_id), channel)
2018-07-16 18:53:41 +00:00
@db_wrap
async def member_count(conn):
    """Return the total number of rows in ``members``."""
    total = await conn.fetchval(" select count(*) from members ")
    return total
@db_wrap
async def system_count(conn):
    """Return the total number of rows in ``systems``."""
    total = await conn.fetchval(" select count(*) from systems ")
    return total
2018-07-11 22:47:44 +00:00
async def create_tables(conn):
    """Create the bot's tables if they do not already exist.

    Every statement is ``create table if not exists``, so this is safe to run on
    each start-up. It also means changes made here are NOT applied to tables
    that already exist.

    Foreign-key columns are declared ``integer`` rather than ``serial``.
    ``serial`` would give a column that must always reference an existing,
    caller-supplied row its own sequence and auto-increment default. Every
    insert in this module supplies these values explicitly.
    """
    await conn.execute("""create table if not exists systems (
        id          serial primary key,
        hid         char(5) unique not null,
        name        text,
        description text,
        tag         text,
        avatar_url  text,
        created     timestamp not null default current_timestamp
    )""")
    await conn.execute("""create table if not exists members (
        id          serial primary key,
        hid         char(5) unique not null,
        system      integer not null references systems(id) on delete cascade,
        color       char(6),
        avatar_url  text,
        name        text not null,
        birthday    date,
        pronouns    text,
        description text,
        prefix      text,
        suffix      text,
        created     timestamp not null default current_timestamp
    )""")
    await conn.execute("""create table if not exists accounts (
        uid    bigint primary key,
        system integer not null references systems(id) on delete cascade
    )""")
    # NOTE(review): messages.sender references accounts(uid) WITHOUT
    # "on delete cascade". Deleting an account row that still has logged
    # messages is therefore rejected by the foreign key. Confirm this is intended.
    await conn.execute("""create table if not exists messages (
        mid     bigint primary key,
        channel bigint not null,
        member  integer not null references members(id) on delete cascade,
        content text not null,
        sender  bigint not null references accounts(uid)
    )""")
    await conn.execute("""create table if not exists switches (
        id        serial primary key,
        system    integer not null references systems(id) on delete cascade,
        timestamp timestamp not null default current_timestamp
    )""")
    await conn.execute("""create table if not exists switch_members (
        id     serial primary key,
        switch integer not null references switches(id) on delete cascade,
        member integer not null references members(id) on delete cascade
    )""")
    await conn.execute("""create table if not exists webhooks (
        channel bigint primary key,
        webhook bigint not null,
        token   text not null
    )""")
    await conn.execute("""create table if not exists servers (
        id          bigint primary key,
        log_channel bigint
    )""")