323 lines
13 KiB
Python
323 lines
13 KiB
Python
import random
|
|
import re
|
|
import string
|
|
from collections.__init__ import namedtuple
|
|
from datetime import datetime
|
|
from typing import Optional, List, Tuple
|
|
|
|
import pytz
|
|
|
|
from pluralkit import db, errors
|
|
from pluralkit.member import Member
|
|
from pluralkit.switch import Switch
|
|
from pluralkit.utils import generate_hid, contains_custom_emoji, validate_avatar_url_or_raise
|
|
|
|
class TupperboxImportResult(namedtuple("TupperboxImportResult", ["updated", "created", "tags"])):
|
|
pass
|
|
|
|
class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "avatar_url", "token", "created", "ui_tz"])):
|
|
id: int
|
|
hid: str
|
|
name: str
|
|
description: str
|
|
tag: str
|
|
avatar_url: str
|
|
token: str
|
|
created: datetime
|
|
# pytz-compatible time zone name, usually Olson-style (eg. Europe/Amsterdam)
|
|
ui_tz: str
|
|
|
|
@staticmethod
|
|
async def get_by_id(conn, system_id: int) -> Optional["System"]:
|
|
return await db.get_system(conn, system_id)
|
|
|
|
@staticmethod
|
|
async def get_by_account(conn, account_id: int) -> Optional["System"]:
|
|
return await db.get_system_by_account(conn, account_id)
|
|
|
|
@staticmethod
|
|
async def get_by_token(conn, token: str) -> Optional["System"]:
|
|
return await db.get_system_by_token(conn, token)
|
|
|
|
@staticmethod
|
|
async def get_by_hid(conn, hid: str) -> Optional["System"]:
|
|
return await db.get_system_by_hid(conn, hid)
|
|
|
|
@staticmethod
|
|
async def create_system(conn, account_id: int, system_name: Optional[str] = None) -> "System":
|
|
async with conn.transaction():
|
|
existing_system = await System.get_by_account(conn, account_id)
|
|
if existing_system:
|
|
raise errors.ExistingSystemError()
|
|
|
|
new_hid = generate_hid()
|
|
|
|
async with conn.transaction():
|
|
new_system = await db.create_system(conn, system_name, new_hid)
|
|
await db.link_account(conn, new_system.id, account_id)
|
|
|
|
return new_system
|
|
|
|
async def set_name(self, conn, new_name: Optional[str]):
|
|
await db.update_system_field(conn, self.id, "name", new_name)
|
|
|
|
async def set_description(self, conn, new_description: Optional[str]):
|
|
# Explicit length error
|
|
if new_description and len(new_description) > 1024:
|
|
raise errors.DescriptionTooLongError()
|
|
|
|
await db.update_system_field(conn, self.id, "description", new_description)
|
|
|
|
async def set_tag(self, conn, new_tag: Optional[str]):
|
|
if new_tag:
|
|
# Explicit length error
|
|
if len(new_tag) > 32:
|
|
raise errors.TagTooLongError()
|
|
|
|
if contains_custom_emoji(new_tag):
|
|
raise errors.CustomEmojiError()
|
|
|
|
await db.update_system_field(conn, self.id, "tag", new_tag)
|
|
|
|
async def set_avatar(self, conn, new_avatar_url: Optional[str]):
|
|
if new_avatar_url:
|
|
validate_avatar_url_or_raise(new_avatar_url)
|
|
|
|
await db.update_system_field(conn, self.id, "avatar_url", new_avatar_url)
|
|
|
|
async def link_account(self, conn, new_account_id: int):
|
|
async with conn.transaction():
|
|
existing_system = await System.get_by_account(conn, new_account_id)
|
|
|
|
if existing_system:
|
|
if existing_system.id == self.id:
|
|
raise errors.AccountInOwnSystemError()
|
|
|
|
raise errors.AccountAlreadyLinkedError(existing_system)
|
|
|
|
await db.link_account(conn, self.id, new_account_id)
|
|
|
|
async def unlink_account(self, conn, account_id: int):
|
|
async with conn.transaction():
|
|
linked_accounts = await db.get_linked_accounts(conn, self.id)
|
|
if len(linked_accounts) == 1:
|
|
raise errors.UnlinkingLastAccountError()
|
|
|
|
await db.unlink_account(conn, self.id, account_id)
|
|
|
|
async def get_linked_account_ids(self, conn) -> List[int]:
|
|
return await db.get_linked_accounts(conn, self.id)
|
|
|
|
async def delete(self, conn):
|
|
await db.remove_system(conn, self.id)
|
|
|
|
async def refresh_token(self, conn) -> str:
|
|
new_token = "".join(random.choices(string.ascii_letters + string.digits, k=64))
|
|
await db.update_system_field(conn, self.id, "token", new_token)
|
|
return new_token
|
|
|
|
async def get_token(self, conn) -> str:
|
|
if self.token:
|
|
return self.token
|
|
return await self.refresh_token(conn)
|
|
|
|
async def create_member(self, conn, member_name: str) -> Member:
|
|
# TODO: figure out what to do if this errors out on collision on generate_hid
|
|
new_hid = generate_hid()
|
|
|
|
if len(member_name) > self.get_member_name_limit():
|
|
raise errors.MemberNameTooLongError(tag_present=bool(self.tag))
|
|
|
|
member = await db.create_member(conn, self.id, member_name, new_hid)
|
|
return member
|
|
|
|
async def get_members(self, conn) -> List[Member]:
|
|
return await db.get_all_members(conn, self.id)
|
|
|
|
async def get_switches(self, conn, count) -> List[Switch]:
|
|
"""Returns the latest `count` switches logged for this system, ordered latest to earliest."""
|
|
return [Switch(**s) for s in await db.front_history(conn, self.id, count)]
|
|
|
|
async def get_latest_switch(self, conn) -> Optional[Switch]:
|
|
"""Returns the latest switch logged for this system, or None if no switches have been logged"""
|
|
switches = await self.get_switches(conn, 1)
|
|
if switches:
|
|
return switches[0]
|
|
else:
|
|
return None
|
|
|
|
async def add_switch(self, conn, members: List[Member]) -> Switch:
|
|
"""
|
|
Logs a new switch for a system.
|
|
|
|
:raises: MembersAlreadyFrontingError, DuplicateSwitchMembersError
|
|
"""
|
|
new_ids = [member.id for member in members]
|
|
|
|
last_switch = await self.get_latest_switch(conn)
|
|
|
|
# If we have a switch logged before, make sure this isn't a dupe switch
|
|
if last_switch:
|
|
last_switch_members = await last_switch.fetch_members(conn)
|
|
last_ids = [member.id for member in last_switch_members]
|
|
|
|
# We don't compare by set() here because swapping multiple is a valid operation
|
|
if last_ids == new_ids:
|
|
raise errors.MembersAlreadyFrontingError(members)
|
|
|
|
# Check for dupes
|
|
if len(set(new_ids)) != len(new_ids):
|
|
raise errors.DuplicateSwitchMembersError()
|
|
|
|
async with conn.transaction():
|
|
switch_id = await db.add_switch(conn, self.id)
|
|
|
|
# TODO: batch query here
|
|
for member in members:
|
|
await db.add_switch_member(conn, switch_id, member.id)
|
|
|
|
return await self.get_latest_switch(conn)
|
|
|
|
def get_member_name_limit(self) -> int:
|
|
"""Returns the maximum length a member's name or nickname is allowed to be in order for the member to be proxied. Depends on the system tag."""
|
|
if self.tag:
|
|
return 32 - len(self.tag) - 1
|
|
else:
|
|
return 32
|
|
|
|
async def match_proxy(self, conn, message: str) -> Optional[Tuple[Member, str]]:
|
|
"""Tries to find a member with proxy tags matching the given message. Returns the member and the inner contents."""
|
|
members = await db.get_all_members(conn, self.id)
|
|
|
|
# Sort by specificity (members with both prefix and suffix defined go higher)
|
|
# This will make sure more "precise" proxy tags get tried first and match properly
|
|
members = sorted(members, key=lambda x: int(bool(x.prefix)) + int(bool(x.suffix)), reverse=True)
|
|
|
|
for member in members:
|
|
proxy_prefix = member.prefix or ""
|
|
proxy_suffix = member.suffix or ""
|
|
|
|
if not proxy_prefix and not proxy_suffix:
|
|
# If the member has neither a prefix or a suffix, cancel early
|
|
# Otherwise it'd match any message no matter what
|
|
continue
|
|
|
|
# Check if the message matches these tags
|
|
if message.startswith(proxy_prefix) and message.endswith(proxy_suffix):
|
|
# If the message starts with a mention, "separate" that and match the bit after
|
|
mention_match = re.match(r"^(<(@|@!|#|@&|a?:\w+:)\d+>\s*)+", message)
|
|
leading_mentions = ""
|
|
if mention_match:
|
|
message = message[mention_match.span(0)[1]:].strip()
|
|
leading_mentions = mention_match.group(0)
|
|
|
|
# Extract the inner message (special case because -0 is invalid as an end slice)
|
|
if len(proxy_suffix) == 0:
|
|
inner_message = message[len(proxy_prefix):]
|
|
else:
|
|
inner_message = message[len(proxy_prefix):-len(proxy_suffix)]
|
|
|
|
# Add the stripped mentions back if there are any
|
|
inner_message = leading_mentions + inner_message
|
|
return member, inner_message
|
|
|
|
def format_time(self, dt: datetime) -> str:
|
|
"""
|
|
Localizes the given `datetime` to a string based on the system's preferred time zone.
|
|
|
|
Assumes `dt` is a naïve `datetime` instance set to UTC, which is consistent with the rest of PluralKit.
|
|
"""
|
|
tz = pytz.timezone(self.ui_tz)
|
|
|
|
# Set to aware (UTC), convert to tz, set to naive (tz), then format and append name
|
|
return tz.normalize(pytz.utc.localize(dt)).replace(tzinfo=None).isoformat(sep=" ", timespec="seconds") + " " + tz.tzname(dt)
|
|
|
|
async def set_time_zone(self, conn, tz_name: str) -> pytz.tzinfo:
|
|
"""
|
|
Sets the system time zone to the time zone represented by the given string.
|
|
|
|
If `tz_name` is None or an empty string, will default to UTC.
|
|
If `tz_name` does not represent a valid time zone string, will raise InvalidTimeZoneError.
|
|
|
|
:raises: InvalidTimeZoneError
|
|
:returns: The `pytz.tzinfo` instance of the newly set time zone.
|
|
"""
|
|
|
|
try:
|
|
tz = pytz.timezone(tz_name or "UTC")
|
|
except pytz.UnknownTimeZoneError:
|
|
raise errors.InvalidTimeZoneError(tz_name)
|
|
|
|
await db.update_system_field(conn, self.id, "ui_tz", tz.zone)
|
|
return tz
|
|
|
|
async def import_from_tupperbox(self, conn, data: dict):
|
|
"""
|
|
Imports from a Tupperbox JSON data file.
|
|
:raises: TupperboxImportError
|
|
"""
|
|
if not "tuppers" in data:
|
|
raise errors.TupperboxImportError()
|
|
if not isinstance(data["tuppers"], list):
|
|
raise errors.TupperboxImportError()
|
|
|
|
all_tags = set()
|
|
created_members = set()
|
|
updated_members = set()
|
|
for tupper in data["tuppers"]:
|
|
# Sanity check tupper fields
|
|
for field in ["name", "avatar_url", "brackets", "birthday", "description", "tag"]:
|
|
if field not in tupper:
|
|
raise errors.TupperboxImportError()
|
|
|
|
# Find member by name, create if not exists
|
|
member_name = str(tupper["name"])
|
|
member = await Member.get_member_by_name(conn, self.id, member_name)
|
|
if not member:
|
|
# And keep track of created members
|
|
created_members.add(member_name)
|
|
member = await self.create_member(conn, member_name)
|
|
else:
|
|
# Keep track of updated members
|
|
updated_members.add(member_name)
|
|
|
|
# Set avatar
|
|
await member.set_avatar(conn, str(tupper["avatar_url"]))
|
|
|
|
# Set proxy tags
|
|
if not (isinstance(tupper["brackets"], list) and len(tupper["brackets"]) >= 2):
|
|
raise errors.TupperboxImportError()
|
|
await member.set_proxy_tags(conn, str(tupper["brackets"][0]), str(tupper["brackets"][1]))
|
|
|
|
# Set birthdate (input is in ISO-8601, first 10 characters is the date)
|
|
if tupper["birthday"]:
|
|
try:
|
|
await member.set_birthdate(conn, str(tupper["birthday"][:10]))
|
|
except errors.InvalidDateStringError:
|
|
pass
|
|
|
|
# Set description
|
|
await member.set_description(conn, tupper["description"])
|
|
|
|
# Keep track of tag
|
|
all_tags.add(tupper["tag"])
|
|
|
|
# Since Tupperbox does tags on a per-member basis, we only apply a system tag if
|
|
# every member has the same tag (surprisingly common)
|
|
# If not, we just do nothing. (This will be reported in the caller function through the returned result)
|
|
if len(all_tags) == 1:
|
|
tag = list(all_tags)[0]
|
|
await self.set_tag(conn, tag)
|
|
|
|
return TupperboxImportResult(updated=updated_members, created=created_members, tags=all_tags)
|
|
|
|
def to_json(self):
|
|
return {
|
|
"id": self.hid,
|
|
"name": self.name,
|
|
"description": self.description,
|
|
"tag": self.tag,
|
|
"avatar_url": self.avatar_url,
|
|
"tz": self.ui_tz
|
|
}
|