Add proxy permission errors, refactor proxy matching, change member name length error handling. Closes #25.
This commit is contained in:
parent
5d3cb7b6bf
commit
c36a054519
@ -100,6 +100,9 @@ class CommandContext:
|
|||||||
async def reply_ok(self, content=None, embed=None):
|
async def reply_ok(self, content=None, embed=None):
|
||||||
return await self.reply(content="\u2705 {}".format(content or ""), embed=embed)
|
return await self.reply(content="\u2705 {}".format(content or ""), embed=embed)
|
||||||
|
|
||||||
|
async def reply_warn(self, content=None, embed=None):
|
||||||
|
return await self.reply(content="\u26a0 {}".format(content or ""), embed=embed)
|
||||||
|
|
||||||
async def confirm_react(self, user: Union[discord.Member, discord.User], message: str):
|
async def confirm_react(self, user: Union[discord.Member, discord.User], message: str):
|
||||||
message = await self.reply(message)
|
message = await self.reply(message)
|
||||||
await message.add_reaction("\u2705") # Checkmark
|
await message.add_reaction("\u2705") # Checkmark
|
||||||
|
@ -39,9 +39,18 @@ async def member_name(ctx: CommandContext):
|
|||||||
member = await ctx.pop_member(CommandError("You must pass a member name.", help=help.edit_member))
|
member = await ctx.pop_member(CommandError("You must pass a member name.", help=help.edit_member))
|
||||||
new_name = ctx.pop_str(CommandError("You must pass a new member name.", help=help.edit_member))
|
new_name = ctx.pop_str(CommandError("You must pass a new member name.", help=help.edit_member))
|
||||||
|
|
||||||
await member.set_name(ctx.conn, system, new_name)
|
await member.set_name(ctx.conn, new_name)
|
||||||
await ctx.reply_ok("Member name updated.")
|
await ctx.reply_ok("Member name updated.")
|
||||||
|
|
||||||
|
if len(new_name) < 2 and not system.tag:
|
||||||
|
await ctx.reply_warn("This member's new name is under 2 characters, and thus cannot be proxied. To prevent this, use a longer member name, or add a system tag.")
|
||||||
|
elif len(new_name) > 32:
|
||||||
|
exceeds_by = len(new_name) - 32
|
||||||
|
await ctx.reply_warn("This member's new name is longer than 32 characters, and thus cannot be proxied. To prevent this, shorten the member name by {} characters.".format(exceeds_by))
|
||||||
|
elif len(new_name) > system.get_member_name_limit():
|
||||||
|
exceeds_by = len(new_name) - system.get_member_name_limit()
|
||||||
|
await ctx.reply_warn("This member's new name, when combined with the system tag `{}`, is longer than 32 characters, and thus cannot be proxied. To prevent this, shorten the name or system tag by at least {} characters.".format(system.tag, exceeds_by))
|
||||||
|
|
||||||
|
|
||||||
async def member_description(ctx: CommandContext):
|
async def member_description(ctx: CommandContext):
|
||||||
await ctx.ensure_system()
|
await ctx.ensure_system()
|
||||||
|
@ -32,7 +32,8 @@ async def new_system(ctx: CommandContext):
|
|||||||
|
|
||||||
|
|
||||||
async def system_set(ctx: CommandContext):
|
async def system_set(ctx: CommandContext):
|
||||||
raise CommandError("`pk;system set` has been retired. Please use the new member modifying commands: `pk;system [name|description|avatar|tag]`.")
|
raise CommandError(
|
||||||
|
"`pk;system set` has been retired. Please use the new member modifying commands: `pk;system [name|description|avatar|tag]`.")
|
||||||
|
|
||||||
|
|
||||||
async def system_name(ctx: CommandContext):
|
async def system_name(ctx: CommandContext):
|
||||||
@ -58,6 +59,28 @@ async def system_tag(ctx: CommandContext):
|
|||||||
await system.set_tag(ctx.conn, new_tag)
|
await system.set_tag(ctx.conn, new_tag)
|
||||||
await ctx.reply_ok("System tag {}.".format("updated" if new_tag else "cleared"))
|
await ctx.reply_ok("System tag {}.".format("updated" if new_tag else "cleared"))
|
||||||
|
|
||||||
|
# System class is immutable, update the tag so get_member_name_limit works
|
||||||
|
system = system._replace(tag=new_tag)
|
||||||
|
members = await system.get_members(ctx.conn)
|
||||||
|
|
||||||
|
# Certain members might not be able to be proxied with this new tag, show a warning for those
|
||||||
|
members_exceeding = [member for member in members if
|
||||||
|
len(member.name) > system.get_member_name_limit()]
|
||||||
|
if members_exceeding:
|
||||||
|
member_names = ", ".join([member.name for member in members_exceeding])
|
||||||
|
await ctx.reply_warn(
|
||||||
|
"Due to the length of this tag, the following members will not be able to be proxied: {}. Please use a shorter tag to prevent this.".format(
|
||||||
|
member_names))
|
||||||
|
|
||||||
|
# Edge case: members with name length 1 and no new tag
|
||||||
|
if not new_tag:
|
||||||
|
one_length_members = [member for member in members if len(member.name) == 1]
|
||||||
|
if one_length_members:
|
||||||
|
member_names = ", ".join([member.name for member in one_length_members])
|
||||||
|
await ctx.reply_warn(
|
||||||
|
"Without a system tag, you will not be able to proxy members with a one-character name: {}. To prevent this, please add a system tag or lengthen their name.".format(
|
||||||
|
member_names))
|
||||||
|
|
||||||
|
|
||||||
async def system_avatar(ctx: CommandContext):
|
async def system_avatar(ctx: CommandContext):
|
||||||
system = await ctx.ensure_system()
|
system = await ctx.ensure_system()
|
||||||
@ -89,7 +112,9 @@ async def system_link(ctx: CommandContext):
|
|||||||
if account_system:
|
if account_system:
|
||||||
raise CommandError(AccountAlreadyLinkedError(account_system).message)
|
raise CommandError(AccountAlreadyLinkedError(account_system).message)
|
||||||
|
|
||||||
if not await ctx.confirm_react(linkee, "{}, please confirm the link by clicking the ✅ reaction on this message.".format(linkee.mention)):
|
if not await ctx.confirm_react(linkee,
|
||||||
|
"{}, please confirm the link by clicking the ✅ reaction on this message.".format(
|
||||||
|
linkee.mention)):
|
||||||
raise CommandError("Account link cancelled.")
|
raise CommandError("Account link cancelled.")
|
||||||
|
|
||||||
await system.link_account(ctx.conn, linkee.id)
|
await system.link_account(ctx.conn, linkee.id)
|
||||||
@ -241,5 +266,6 @@ async def system_frontpercent(ctx: CommandContext):
|
|||||||
embed.add_field(name=member.name if member else "(no fronter)",
|
embed.add_field(name=member.name if member else "(no fronter)",
|
||||||
value="{}% ({})".format(percent, humanize.naturaldelta(front_time)))
|
value="{}% ({})".format(percent, humanize.naturaldelta(front_time)))
|
||||||
|
|
||||||
embed.set_footer(text="Since {} ({})".format(span_start.isoformat(sep=" ", timespec="seconds"), humanize.naturaltime(pluralkit.utils.fix_time(span_start))))
|
embed.set_footer(text="Since {} ({})".format(span_start.isoformat(sep=" ", timespec="seconds"),
|
||||||
|
humanize.naturaltime(pluralkit.utils.fix_time(span_start))))
|
||||||
await ctx.reply(embed=embed)
|
await ctx.reply(embed=embed)
|
||||||
|
@ -1,25 +1,27 @@
|
|||||||
|
import discord
|
||||||
import humanize
|
import humanize
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
import discord
|
|
||||||
|
|
||||||
import pluralkit
|
import pluralkit
|
||||||
from pluralkit import db
|
|
||||||
from pluralkit.bot.utils import escape
|
from pluralkit.bot.utils import escape
|
||||||
from pluralkit.member import Member
|
from pluralkit.member import Member
|
||||||
from pluralkit.switch import Switch
|
from pluralkit.switch import Switch
|
||||||
from pluralkit.system import System
|
from pluralkit.system import System
|
||||||
from pluralkit.utils import get_fronters
|
from pluralkit.utils import get_fronters
|
||||||
|
|
||||||
|
|
||||||
def truncate_field_name(s: str) -> str:
|
def truncate_field_name(s: str) -> str:
|
||||||
return s[:256]
|
return s[:256]
|
||||||
|
|
||||||
|
|
||||||
def truncate_field_body(s: str) -> str:
|
def truncate_field_body(s: str) -> str:
|
||||||
return s[:1024]
|
return s[:1024]
|
||||||
|
|
||||||
|
|
||||||
def truncate_description(s: str) -> str:
|
def truncate_description(s: str) -> str:
|
||||||
return s[:2048]
|
return s[:2048]
|
||||||
|
|
||||||
|
|
||||||
def truncate_title(s: str) -> str:
|
def truncate_title(s: str) -> str:
|
||||||
return s[:256]
|
return s[:256]
|
||||||
|
|
||||||
@ -33,7 +35,7 @@ def success(text: str) -> discord.Embed:
|
|||||||
|
|
||||||
def error(text: str, help: Tuple[str, str] = None) -> discord.Embed:
|
def error(text: str, help: Tuple[str, str] = None) -> discord.Embed:
|
||||||
embed = discord.Embed()
|
embed = discord.Embed()
|
||||||
embed.description = truncate_description(s)
|
embed.description = truncate_description(text)
|
||||||
embed.colour = discord.Colour.dark_red()
|
embed.colour = discord.Colour.dark_red()
|
||||||
|
|
||||||
if help:
|
if help:
|
||||||
|
@ -1,76 +1,26 @@
|
|||||||
from io import BytesIO
|
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
import logging
|
import logging
|
||||||
import re
|
from io import BytesIO
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
from pluralkit import db
|
from pluralkit import db
|
||||||
from pluralkit.bot import utils, channel_logger
|
from pluralkit.bot import utils, channel_logger
|
||||||
from pluralkit.bot.channel_logger import ChannelLogger
|
from pluralkit.bot.channel_logger import ChannelLogger
|
||||||
|
from pluralkit.member import Member
|
||||||
|
from pluralkit.system import System
|
||||||
|
|
||||||
logger = logging.getLogger("pluralkit.bot.proxy")
|
logger = logging.getLogger("pluralkit.bot.proxy")
|
||||||
|
|
||||||
|
class ProxyError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def fix_webhook(webhook: discord.Webhook) -> discord.Webhook:
|
def fix_webhook(webhook: discord.Webhook) -> discord.Webhook:
|
||||||
# Workaround for https://github.com/Rapptz/discord.py/issues/1242 and similar issues (#1150)
|
# Workaround for https://github.com/Rapptz/discord.py/issues/1242 and similar issues (#1150)
|
||||||
webhook._adapter.store_user = webhook._adapter._store_user
|
webhook._adapter.store_user = webhook._adapter._store_user
|
||||||
webhook._adapter.http = None
|
webhook._adapter.http = None
|
||||||
return webhook
|
return webhook
|
||||||
|
|
||||||
def extract_leading_mentions(message_text):
|
|
||||||
# This regex matches one or more mentions at the start of a message, separated by any amount of spaces
|
|
||||||
match = re.match(r"^(<(@|@!|#|@&|a?:\w+:)\d+>\s*)+", message_text)
|
|
||||||
if not match:
|
|
||||||
return message_text, ""
|
|
||||||
|
|
||||||
# Return the text after the mentions, and the mentions themselves
|
|
||||||
return message_text[match.span(0)[1]:].strip(), match.group(0)
|
|
||||||
|
|
||||||
|
|
||||||
def match_member_proxy_tags(member: db.ProxyMember, message_text: str):
|
|
||||||
# Skip members with no defined proxy tags
|
|
||||||
if not member.prefix and not member.suffix:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# DB defines empty prefix/suffixes as None, replace with empty strings to prevent errors
|
|
||||||
prefix = member.prefix or ""
|
|
||||||
suffix = member.suffix or ""
|
|
||||||
|
|
||||||
# Ignore mentions at the very start of the message, and match proxy tags after those
|
|
||||||
message_text, leading_mentions = extract_leading_mentions(message_text)
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
"Matching text '{}' and leading mentions '{}' to proxy tags {}text{}".format(message_text, leading_mentions,
|
|
||||||
prefix, suffix))
|
|
||||||
|
|
||||||
if message_text.startswith(member.prefix or "") and message_text.endswith(member.suffix or ""):
|
|
||||||
prefix_length = len(prefix)
|
|
||||||
suffix_length = len(suffix)
|
|
||||||
|
|
||||||
# If suffix_length is 0, the last bit of the slice will be "-0", and the slice will fail
|
|
||||||
if suffix_length > 0:
|
|
||||||
inner_string = message_text[prefix_length:-suffix_length]
|
|
||||||
else:
|
|
||||||
inner_string = message_text[prefix_length:]
|
|
||||||
|
|
||||||
# Add the mentions we stripped back
|
|
||||||
inner_string = leading_mentions + inner_string
|
|
||||||
return inner_string
|
|
||||||
|
|
||||||
|
|
||||||
def match_proxy_tags(members: List[db.ProxyMember], message_text: str):
|
|
||||||
# Sort by specificity (members with both prefix and suffix go higher)
|
|
||||||
# This will make sure more "precise" proxy tags get tried first
|
|
||||||
members: List[db.ProxyMember] = sorted(members, key=lambda x: int(
|
|
||||||
bool(x.prefix)) + int(bool(x.suffix)), reverse=True)
|
|
||||||
|
|
||||||
for member in members:
|
|
||||||
match = match_member_proxy_tags(member, message_text)
|
|
||||||
if match is not None: # Using "is not None" because an empty string is OK here too
|
|
||||||
logger.debug("Matched member {} with inner text '{}'".format(member.hid, match))
|
|
||||||
return member, match
|
|
||||||
|
|
||||||
|
|
||||||
async def get_or_create_webhook_for_channel(conn, bot_user: discord.User, channel: discord.TextChannel):
|
async def get_or_create_webhook_for_channel(conn, bot_user: discord.User, channel: discord.TextChannel):
|
||||||
# First, check if we have one saved in the DB
|
# First, check if we have one saved in the DB
|
||||||
webhook_from_db = await db.get_webhook(conn, channel.id)
|
webhook_from_db = await db.get_webhook(conn, channel.id)
|
||||||
@ -83,17 +33,20 @@ async def get_or_create_webhook_for_channel(conn, bot_user: discord.User, channe
|
|||||||
hook._adapter.store_user = hook._adapter._store_user
|
hook._adapter.store_user = hook._adapter._store_user
|
||||||
return fix_webhook(hook)
|
return fix_webhook(hook)
|
||||||
|
|
||||||
# If not, we check to see if there already exists one we've missed
|
try:
|
||||||
for existing_hook in await channel.webhooks():
|
# If not, we check to see if there already exists one we've missed
|
||||||
existing_hook_creator = existing_hook.user.id if existing_hook.user else None
|
for existing_hook in await channel.webhooks():
|
||||||
is_mine = existing_hook.name == "PluralKit Proxy Webhook" and existing_hook_creator == bot_user.id
|
existing_hook_creator = existing_hook.user.id if existing_hook.user else None
|
||||||
if is_mine:
|
is_mine = existing_hook.name == "PluralKit Proxy Webhook" and existing_hook_creator == bot_user.id
|
||||||
# We found one we made, let's add that to the DB just to be sure
|
if is_mine:
|
||||||
await db.add_webhook(conn, channel.id, existing_hook.id, existing_hook.token)
|
# We found one we made, let's add that to the DB just to be sure
|
||||||
return fix_webhook(existing_hook)
|
await db.add_webhook(conn, channel.id, existing_hook.id, existing_hook.token)
|
||||||
|
return fix_webhook(existing_hook)
|
||||||
|
|
||||||
# If not, we create one and save it
|
# If not, we create one and save it
|
||||||
created_webhook = await channel.create_webhook(name="PluralKit Proxy Webhook")
|
created_webhook = await channel.create_webhook(name="PluralKit Proxy Webhook")
|
||||||
|
except discord.Forbidden:
|
||||||
|
raise ProxyError("PluralKit does not have the \"Manage Webhooks\" permission, and thus cannot proxy your message. Please contact a server administrator.")
|
||||||
|
|
||||||
await db.add_webhook(conn, channel.id, created_webhook.id, created_webhook.token)
|
await db.add_webhook(conn, channel.id, created_webhook.id, created_webhook.token)
|
||||||
return fix_webhook(created_webhook)
|
return fix_webhook(created_webhook)
|
||||||
@ -113,28 +66,37 @@ async def make_attachment_file(message: discord.Message):
|
|||||||
return discord.File(bio, first_attachment.filename)
|
return discord.File(bio, first_attachment.filename)
|
||||||
|
|
||||||
|
|
||||||
async def do_proxy_message(conn, original_message: discord.Message, proxy_member: db.ProxyMember,
|
async def send_proxy_message(conn, original_message: discord.Message, system: System, member: Member,
|
||||||
inner_text: str, logger: ChannelLogger, bot_user: discord.User):
|
inner_text: str, logger: ChannelLogger, bot_user: discord.User):
|
||||||
# Send the message through the webhook
|
# Send the message through the webhook
|
||||||
webhook = await get_or_create_webhook_for_channel(conn, bot_user, original_message.channel)
|
webhook = await get_or_create_webhook_for_channel(conn, bot_user, original_message.channel)
|
||||||
|
|
||||||
|
# Bounds check the combined name to avoid silent erroring
|
||||||
|
full_username = "{} {}".format(member.name, system.tag or "").strip()
|
||||||
|
if len(full_username) < 2:
|
||||||
|
raise ProxyError("The webhook's name, `{}`, is shorter than two characters, and thus cannot be proxied. Please change the member name or use a longer system tag.".format(full_username))
|
||||||
|
if len(full_username) > 32:
|
||||||
|
raise ProxyError("The webhook's name, `{}`, is longer than 32 characters, and thus cannot be proxied. Please change the member name or use a shorter system tag.".format(full_username))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
sent_message = await webhook.send(
|
sent_message = await webhook.send(
|
||||||
content=inner_text,
|
content=inner_text,
|
||||||
username="{} {}".format(proxy_member.name, proxy_member.tag or "").strip(),
|
username=full_username,
|
||||||
avatar_url=proxy_member.avatar_url,
|
avatar_url=member.avatar_url,
|
||||||
file=await make_attachment_file(original_message),
|
file=await make_attachment_file(original_message),
|
||||||
wait=True
|
wait=True
|
||||||
)
|
)
|
||||||
except discord.NotFound:
|
except discord.NotFound:
|
||||||
# The webhook we got from the DB doesn't actually exist
|
# The webhook we got from the DB doesn't actually exist
|
||||||
|
# This can happen if someone manually deletes it from the server
|
||||||
# If we delete it from the DB then call the function again, it'll re-create one for us
|
# If we delete it from the DB then call the function again, it'll re-create one for us
|
||||||
|
# (lol, lazy)
|
||||||
await db.delete_webhook(conn, original_message.channel.id)
|
await db.delete_webhook(conn, original_message.channel.id)
|
||||||
await do_proxy_message(conn, original_message, proxy_member, inner_text, logger)
|
await send_proxy_message(conn, original_message, system, member, inner_text, logger, bot_user)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Save the proxied message in the database
|
# Save the proxied message in the database
|
||||||
await db.add_message(conn, sent_message.id, original_message.channel.id, proxy_member.id,
|
await db.add_message(conn, sent_message.id, original_message.channel.id, member.id,
|
||||||
original_message.author.id)
|
original_message.author.id)
|
||||||
|
|
||||||
await logger.log_message_proxied(
|
await logger.log_message_proxied(
|
||||||
@ -145,11 +107,11 @@ async def do_proxy_message(conn, original_message: discord.Message, proxy_member
|
|||||||
original_message.author.name,
|
original_message.author.name,
|
||||||
original_message.author.discriminator,
|
original_message.author.discriminator,
|
||||||
original_message.author.id,
|
original_message.author.id,
|
||||||
proxy_member.name,
|
member.name,
|
||||||
proxy_member.hid,
|
member.hid,
|
||||||
proxy_member.avatar_url,
|
member.avatar_url,
|
||||||
proxy_member.system_name,
|
system.name,
|
||||||
proxy_member.system_hid,
|
system.hid,
|
||||||
inner_text,
|
inner_text,
|
||||||
sent_message.attachments[0].url if sent_message.attachments else None,
|
sent_message.attachments[0].url if sent_message.attachments else None,
|
||||||
sent_message.created_at,
|
sent_message.created_at,
|
||||||
@ -157,34 +119,44 @@ async def do_proxy_message(conn, original_message: discord.Message, proxy_member
|
|||||||
)
|
)
|
||||||
|
|
||||||
# And finally, gotta delete the original.
|
# And finally, gotta delete the original.
|
||||||
await original_message.delete()
|
try:
|
||||||
|
await original_message.delete()
|
||||||
|
except discord.Forbidden:
|
||||||
|
raise ProxyError("PluralKit does not have permission to delete user messages. Please contact a server administrator.")
|
||||||
|
|
||||||
|
|
||||||
async def try_proxy_message(conn, message: discord.Message, logger: ChannelLogger, bot_user: discord.User) -> bool:
|
async def try_proxy_message(conn, message: discord.Message, logger: ChannelLogger, bot_user: discord.User) -> bool:
|
||||||
# Don't bother proxying in DMs with the bot
|
# Don't bother proxying in DMs
|
||||||
if isinstance(message.channel, discord.abc.PrivateChannel):
|
if isinstance(message.channel, discord.abc.PrivateChannel):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Return every member associated with the account
|
# Get the system associated with the account, if possible
|
||||||
members = await db.get_members_by_account(conn, message.author.id)
|
system = await System.get_by_account(conn, message.author.id)
|
||||||
proxy_match = match_proxy_tags(members, message.content)
|
if not system:
|
||||||
if not proxy_match:
|
|
||||||
# No proxy tags match here, we done
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
member, inner_text = proxy_match
|
# Match on the members' proxy tags
|
||||||
|
proxy_match = await system.match_proxy(conn, message.content)
|
||||||
|
if not proxy_match:
|
||||||
|
return False
|
||||||
|
|
||||||
# Sanitize inner text for @everyones and such
|
member, inner_message = proxy_match
|
||||||
inner_text = utils.sanitize(inner_text)
|
|
||||||
|
# Make sure no @everyones slip through, etc
|
||||||
|
inner_message = utils.sanitize(inner_message)
|
||||||
|
|
||||||
# If we don't have an inner text OR an attachment, we cancel because the hook can't send that
|
# If we don't have an inner text OR an attachment, we cancel because the hook can't send that
|
||||||
# Strip so it counts a string of solely spaces as blank too
|
# Strip so it counts a string of solely spaces as blank too
|
||||||
if not inner_text.strip() and not message.attachments:
|
if not inner_message.strip() and not message.attachments:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# So, we now have enough information to successfully proxy a message
|
# So, we now have enough information to successfully proxy a message
|
||||||
async with conn.transaction():
|
async with conn.transaction():
|
||||||
await do_proxy_message(conn, message, member, inner_text, logger, bot_user)
|
try:
|
||||||
|
await send_proxy_message(conn, message, system, member, inner_message, logger, bot_user)
|
||||||
|
except ProxyError as e:
|
||||||
|
await message.channel.send("\u274c {}".format(str(e)))
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
@ -54,24 +54,17 @@ class Member(namedtuple("Member",
|
|||||||
|
|
||||||
return member
|
return member
|
||||||
|
|
||||||
async def set_name(self, conn, system: "System", new_name: str):
|
async def set_name(self, conn, new_name: str):
|
||||||
"""
|
"""
|
||||||
Set the name of a member. Requires the system to be passed in order to bounds check with the system tag.
|
Set the name of a member.
|
||||||
:raises: MemberNameTooLongError, CustomEmojiError
|
:raises: CustomEmojiError
|
||||||
"""
|
"""
|
||||||
# Custom emojis can't go in the member name
|
# Custom emojis can't go in the member name
|
||||||
# Technically they *could* but they wouldn't render properly
|
# Technically they *could*, but they wouldn't render properly
|
||||||
# so I'd rather explicitly ban them to in order to avoid confusion
|
# so I'd rather explicitly ban them to in order to avoid confusion
|
||||||
|
|
||||||
# The textual form is longer than the length limit in most cases
|
|
||||||
# so we check this *before* the length check for better errors
|
|
||||||
if contains_custom_emoji(new_name):
|
if contains_custom_emoji(new_name):
|
||||||
raise errors.CustomEmojiError()
|
raise errors.CustomEmojiError()
|
||||||
|
|
||||||
# Explicit name length checking
|
|
||||||
if len(new_name) > system.get_member_name_limit():
|
|
||||||
raise errors.MemberNameTooLongError(tag_present=bool(system.tag))
|
|
||||||
|
|
||||||
await db.update_member_field(conn, self.id, "name", new_name)
|
await db.update_member_field(conn, self.id, "name", new_name)
|
||||||
|
|
||||||
async def set_description(self, conn, new_description: Optional[str]):
|
async def set_description(self, conn, new_description: Optional[str]):
|
||||||
|
@ -1,9 +1,10 @@
|
|||||||
import random
|
import random
|
||||||
|
import re
|
||||||
import string
|
import string
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from collections.__init__ import namedtuple
|
from collections.__init__ import namedtuple
|
||||||
from typing import Optional, List
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
from pluralkit import db, errors
|
from pluralkit import db, errors
|
||||||
from pluralkit.member import Member
|
from pluralkit.member import Member
|
||||||
@ -63,11 +64,6 @@ class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "a
|
|||||||
if contains_custom_emoji(new_tag):
|
if contains_custom_emoji(new_tag):
|
||||||
raise errors.CustomEmojiError()
|
raise errors.CustomEmojiError()
|
||||||
|
|
||||||
# Check name+tag length for all members
|
|
||||||
members_exceeding = await db.get_members_exceeding(conn, system_id=self.id, length=32 - len(new_tag) - 1)
|
|
||||||
if len(members_exceeding) > 0:
|
|
||||||
raise errors.TagTooLongWithMembersError([member.name for member in members_exceeding])
|
|
||||||
|
|
||||||
await db.update_system_field(conn, self.id, "tag", new_tag)
|
await db.update_system_field(conn, self.id, "tag", new_tag)
|
||||||
|
|
||||||
async def set_avatar(self, conn, new_avatar_url: Optional[str]):
|
async def set_avatar(self, conn, new_avatar_url: Optional[str]):
|
||||||
@ -133,12 +129,43 @@ class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "a
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def get_member_name_limit(self) -> int:
|
def get_member_name_limit(self) -> int:
|
||||||
"""Returns the maximum length a member's name or nickname is allowed to be. Depends on the system tag."""
|
"""Returns the maximum length a member's name or nickname is allowed to be in order for the member to be proxied. Depends on the system tag."""
|
||||||
if self.tag:
|
if self.tag:
|
||||||
return 32 - len(self.tag) - 1
|
return 32 - len(self.tag) - 1
|
||||||
else:
|
else:
|
||||||
return 32
|
return 32
|
||||||
|
|
||||||
|
async def match_proxy(self, conn, message: str) -> Optional[Tuple[Member, str]]:
|
||||||
|
"""Tries to find a member with proxy tags matching the given message. Returns the member and the inner contents."""
|
||||||
|
members = await db.get_all_members(conn, self.id)
|
||||||
|
|
||||||
|
# Sort by specificity (members with both prefix and suffix defined go higher)
|
||||||
|
# This will make sure more "precise" proxy tags get tried first and match properly
|
||||||
|
members = sorted(members, key=lambda x: int(bool(x.prefix)) + int(bool(x.suffix)), reverse=True)
|
||||||
|
|
||||||
|
for member in members:
|
||||||
|
proxy_prefix = member.prefix or ""
|
||||||
|
proxy_suffix = member.suffix or ""
|
||||||
|
|
||||||
|
# Check if the message matches these tags
|
||||||
|
if message.startswith(proxy_prefix) and message.endswith(proxy_suffix):
|
||||||
|
# If the message starts with a mention, "separate" that and match the bit after
|
||||||
|
mention_match = re.match(r"^(<(@|@!|#|@&|a?:\w+:)\d+>\s*)+", message)
|
||||||
|
leading_mentions = ""
|
||||||
|
if mention_match:
|
||||||
|
message = message[mention_match.span(0)[1]:].strip()
|
||||||
|
leading_mentions = mention_match.group(0)
|
||||||
|
|
||||||
|
# Extract the inner message (special case because -0 is invalid as an end slice)
|
||||||
|
if len(proxy_suffix) == 0:
|
||||||
|
inner_message = message[len(proxy_prefix):]
|
||||||
|
else:
|
||||||
|
inner_message = message[len(proxy_prefix):-len(proxy_suffix)]
|
||||||
|
|
||||||
|
# Add the stripped mentions back if there are any
|
||||||
|
inner_message = leading_mentions + inner_message
|
||||||
|
return member, inner_message
|
||||||
|
|
||||||
def to_json(self):
|
def to_json(self):
|
||||||
return {
|
return {
|
||||||
"id": self.hid,
|
"id": self.hid,
|
||||||
|
Loading…
Reference in New Issue
Block a user