PluralKit/src/pluralkit/system.py

190 lines
7.4 KiB
Python

import random
import re
import string
from datetime import datetime
from collections.__init__ import namedtuple
from typing import Optional, List, Tuple
from pluralkit import db, errors
from pluralkit.member import Member
from pluralkit.switch import Switch
from pluralkit.utils import generate_hid, contains_custom_emoji, validate_avatar_url_or_raise
class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "avatar_url", "token", "created"])):
    """Immutable record describing one system, plus the operations available on it.

    Instances are named tuples, so their fields cannot be reassigned. The
    ``set_*`` methods write the new value through the ``db`` layer only; the
    instance they are called on keeps its old field values.
    """

    id: int  # key passed to the db layer for every query about this system
    hid: str  # value created by generate_hid(); exposed as "id" by to_json()
    name: str
    description: str
    tag: str  # system tag; shortens the allowed member name length (see get_member_name_limit)
    avatar_url: str
    token: str  # value looked up by get_by_token() and replaced by refresh_token()
    created: datetime

    @staticmethod
    async def get_by_account(conn, account_id: int) -> Optional["System"]:
        """Return the system linked to ``account_id`` as reported by the db layer."""
        return await db.get_system_by_account(conn, account_id)

    @staticmethod
    async def get_by_token(conn, token: str) -> Optional["System"]:
        """Return the system owning ``token`` as reported by the db layer."""
        return await db.get_system_by_token(conn, token)

    @staticmethod
    async def create_system(conn, account_id: int, system_name: Optional[str] = None) -> "System":
        """Create a new system and link ``account_id`` to it.

        The existence check, the insert and the account link all run inside a
        single transaction, so a failure at any step leaves nothing behind.

        Raises:
            errors.ExistingSystemError: if the account already has a system.
        """
        async with conn.transaction():
            if await System.get_by_account(conn, account_id):
                raise errors.ExistingSystemError()

            new_hid = generate_hid()
            new_system = await db.create_system(conn, system_name, new_hid)
            await db.link_account(conn, new_system.id, account_id)
            return new_system

    async def set_name(self, conn, new_name: Optional[str]):
        """Store ``new_name`` as this system's name."""
        await db.update_system_field(conn, self.id, "name", new_name)

    async def set_description(self, conn, new_description: Optional[str]):
        """Store ``new_description`` as this system's description.

        Raises:
            errors.DescriptionTooLongError: if the description exceeds 1024 characters.
        """
        # Explicit length error
        if new_description and len(new_description) > 1024:
            raise errors.DescriptionTooLongError()

        await db.update_system_field(conn, self.id, "description", new_description)

    async def set_tag(self, conn, new_tag: Optional[str]):
        """Store ``new_tag`` as this system's tag; a falsy value skips validation.

        Raises:
            errors.TagTooLongError: if the tag exceeds 32 characters.
            errors.CustomEmojiError: if ``contains_custom_emoji`` flags the tag.
        """
        if new_tag:
            # Explicit length error
            if len(new_tag) > 32:
                raise errors.TagTooLongError()

            if contains_custom_emoji(new_tag):
                raise errors.CustomEmojiError()

        await db.update_system_field(conn, self.id, "tag", new_tag)

    async def set_avatar(self, conn, new_avatar_url: Optional[str]):
        """Store ``new_avatar_url`` as this system's avatar; a falsy value skips validation.

        A non-empty URL is first passed to ``validate_avatar_url_or_raise``,
        whose exceptions propagate to the caller.
        """
        if new_avatar_url:
            validate_avatar_url_or_raise(new_avatar_url)

        await db.update_system_field(conn, self.id, "avatar_url", new_avatar_url)

    async def link_account(self, conn, new_account_id: int):
        """Link ``new_account_id`` to this system.

        Raises:
            errors.AccountInOwnSystemError: if the account is already linked to this system.
            errors.AccountAlreadyLinkedError: if the account is linked to another system.
        """
        async with conn.transaction():
            existing_system = await System.get_by_account(conn, new_account_id)

            if existing_system:
                if existing_system.id == self.id:
                    raise errors.AccountInOwnSystemError()
                raise errors.AccountAlreadyLinkedError(existing_system)

            await db.link_account(conn, self.id, new_account_id)

    async def unlink_account(self, conn, account_id: int):
        """Unlink ``account_id`` from this system.

        Raises:
            errors.UnlinkingLastAccountError: if the system has exactly one linked account.
        """
        async with conn.transaction():
            linked_accounts = await db.get_linked_accounts(conn, self.id)
            if len(linked_accounts) == 1:
                raise errors.UnlinkingLastAccountError()

            await db.unlink_account(conn, self.id, account_id)

    async def get_linked_account_ids(self, conn) -> List[int]:
        """Return the accounts linked to this system as reported by the db layer."""
        return await db.get_linked_accounts(conn, self.id)

    async def delete(self, conn):
        """Remove this system through the db layer."""
        await db.remove_system(conn, self.id)

    async def refresh_token(self, conn) -> str:
        """Generate a new 64-character alphanumeric token, store it and return it."""
        # The token is a credential, so draw it from the OS-backed CSPRNG rather
        # than the default Mersenne Twister generator.
        alphabet = string.ascii_letters + string.digits
        new_token = "".join(random.SystemRandom().choices(alphabet, k=64))
        await db.update_system_field(conn, self.id, "token", new_token)
        return new_token

    async def create_member(self, conn, member_name: str) -> "Member":
        """Create a member called ``member_name`` in this system and return it.

        Raises:
            errors.MemberNameTooLongError: if the name exceeds ``get_member_name_limit()``.
        """
        # Validate first so a rejected name costs nothing else.
        if len(member_name) > self.get_member_name_limit():
            raise errors.MemberNameTooLongError(tag_present=bool(self.tag))

        # TODO: figure out what to do if this errors out on collision on generate_hid
        new_hid = generate_hid()
        return await db.create_member(conn, self.id, member_name, new_hid)

    async def get_members(self, conn) -> List["Member"]:
        """Return all members of this system as reported by the db layer."""
        return await db.get_all_members(conn, self.id)

    async def get_switches(self, conn, count) -> List["Switch"]:
        """Returns the latest `count` switches logged for this system, ordered latest to earliest."""
        return [Switch(**s) for s in await db.front_history(conn, self.id, count)]

    async def get_latest_switch(self, conn) -> Optional["Switch"]:
        """Returns the latest switch logged for this system, or None if no switches have been logged"""
        switches = await self.get_switches(conn, 1)
        return switches[0] if switches else None

    async def add_switch(self, conn, members: List["Member"]):
        """Log a new switch for this system with ``members`` attached.

        The switch and all its member rows are written in one transaction.
        """
        async with conn.transaction():
            switch_id = await db.add_switch(conn, self.id)

            # TODO: batch query here
            for member in members:
                await db.add_switch_member(conn, switch_id, member.id)

    def get_member_name_limit(self) -> int:
        """Returns the maximum length a member's name or nickname is allowed to be in order for the member to be proxied. Depends on the system tag."""
        if self.tag:
            # The extra 1 is presumably the separator between name and tag - TODO confirm.
            return 32 - len(self.tag) - 1
        return 32

    @staticmethod
    def _split_leading_mentions(message: str) -> Tuple[str, str]:
        """Split ``message`` into ``(leading_mentions, remaining_text)``.

        ``leading_mentions`` is the run of mention/emoji tokens (including the
        whitespace following them) at the very start of the message, or ``""``
        when there is none, in which case ``remaining_text`` is the untouched
        message. When mentions are found, the remainder is whitespace-stripped.
        """
        mention_match = re.match(r"^(<(@|@!|#|@&|a?:\w+:)\d+>\s*)+", message)
        if not mention_match:
            return "", message
        return mention_match.group(0), message[mention_match.end():].strip()

    @staticmethod
    def _strip_proxy_tags(text: str, prefix: str, suffix: str) -> Optional[str]:
        """Return ``text`` without the surrounding ``prefix``/``suffix``, or None if it does not match."""
        if not (text.startswith(prefix) and text.endswith(suffix)):
            return None
        # Reject overlapping tags, e.g. text "a" with prefix "a" and suffix "a".
        if len(text) < len(prefix) + len(suffix):
            return None
        # An explicit end index avoids the "-0" end-slice special case.
        return text[len(prefix):len(text) - len(suffix)]

    async def match_proxy(self, conn, message: str) -> Optional[Tuple["Member", str]]:
        """Tries to find a member with proxy tags matching the given message. Returns the member and the inner contents.

        Mentions at the very start of the message are set aside before matching
        and prepended to the returned inner contents. Returns None when no
        member's tags match.
        """
        members = await db.get_all_members(conn, self.id)

        # Separate leading mentions once, up front, so the proxy tags are
        # matched against the text that follows them.
        leading_mentions, text = self._split_leading_mentions(message)

        # Sort by specificity (members with both prefix and suffix defined go higher)
        # This will make sure more "precise" proxy tags get tried first and match properly
        members = sorted(members, key=lambda x: int(bool(x.prefix)) + int(bool(x.suffix)), reverse=True)

        for member in members:
            proxy_prefix = member.prefix or ""
            proxy_suffix = member.suffix or ""

            if not proxy_prefix and not proxy_suffix:
                # A member with neither a prefix nor a suffix would match any message
                continue

            inner_message = self._strip_proxy_tags(text, proxy_prefix, proxy_suffix)
            if inner_message is not None:
                # Add the stripped mentions back if there are any
                return member, leading_mentions + inner_message

        return None

    def to_json(self):
        """Return the public fields of this system as a dict, with ``hid`` exposed under ``"id"``."""
        return {
            "id": self.hid,
            "name": self.name,
            "description": self.description,
            "tag": self.tag,
            "avatar_url": self.avatar_url
        }