Add proxy permission errors, refactor proxy matching, change member name length error handling. Closes #25.

This commit is contained in:
Ske 2018-11-30 21:42:01 +01:00
parent 5d3cb7b6bf
commit c36a054519
7 changed files with 149 additions and 117 deletions

View File

@ -100,6 +100,9 @@ class CommandContext:
async def reply_ok(self, content=None, embed=None): async def reply_ok(self, content=None, embed=None):
return await self.reply(content="\u2705 {}".format(content or ""), embed=embed) return await self.reply(content="\u2705 {}".format(content or ""), embed=embed)
async def reply_warn(self, content=None, embed=None):
return await self.reply(content="\u26a0 {}".format(content or ""), embed=embed)
async def confirm_react(self, user: Union[discord.Member, discord.User], message: str): async def confirm_react(self, user: Union[discord.Member, discord.User], message: str):
message = await self.reply(message) message = await self.reply(message)
await message.add_reaction("\u2705") # Checkmark await message.add_reaction("\u2705") # Checkmark

View File

@ -39,9 +39,18 @@ async def member_name(ctx: CommandContext):
member = await ctx.pop_member(CommandError("You must pass a member name.", help=help.edit_member)) member = await ctx.pop_member(CommandError("You must pass a member name.", help=help.edit_member))
new_name = ctx.pop_str(CommandError("You must pass a new member name.", help=help.edit_member)) new_name = ctx.pop_str(CommandError("You must pass a new member name.", help=help.edit_member))
await member.set_name(ctx.conn, system, new_name) await member.set_name(ctx.conn, new_name)
await ctx.reply_ok("Member name updated.") await ctx.reply_ok("Member name updated.")
if len(new_name) < 2 and not system.tag:
await ctx.reply_warn("This member's new name is under 2 characters, and thus cannot be proxied. To prevent this, use a longer member name, or add a system tag.")
elif len(new_name) > 32:
exceeds_by = len(new_name) - 32
await ctx.reply_warn("This member's new name is longer than 32 characters, and thus cannot be proxied. To prevent this, shorten the member name by {} characters.".format(exceeds_by))
elif len(new_name) > system.get_member_name_limit():
exceeds_by = len(new_name) - system.get_member_name_limit()
await ctx.reply_warn("This member's new name, when combined with the system tag `{}`, is longer than 32 characters, and thus cannot be proxied. To prevent this, shorten the name or system tag by at least {} characters.".format(system.tag, exceeds_by))
async def member_description(ctx: CommandContext): async def member_description(ctx: CommandContext):
await ctx.ensure_system() await ctx.ensure_system()

View File

@ -32,7 +32,8 @@ async def new_system(ctx: CommandContext):
async def system_set(ctx: CommandContext): async def system_set(ctx: CommandContext):
raise CommandError("`pk;system set` has been retired. Please use the new member modifying commands: `pk;system [name|description|avatar|tag]`.") raise CommandError(
"`pk;system set` has been retired. Please use the new member modifying commands: `pk;system [name|description|avatar|tag]`.")
async def system_name(ctx: CommandContext): async def system_name(ctx: CommandContext):
@ -58,6 +59,28 @@ async def system_tag(ctx: CommandContext):
await system.set_tag(ctx.conn, new_tag) await system.set_tag(ctx.conn, new_tag)
await ctx.reply_ok("System tag {}.".format("updated" if new_tag else "cleared")) await ctx.reply_ok("System tag {}.".format("updated" if new_tag else "cleared"))
# System class is immutable, update the tag so get_member_name_limit works
system = system._replace(tag=new_tag)
members = await system.get_members(ctx.conn)
# Certain members might not be able to be proxied with this new tag, show a warning for those
members_exceeding = [member for member in members if
len(member.name) > system.get_member_name_limit()]
if members_exceeding:
member_names = ", ".join([member.name for member in members_exceeding])
await ctx.reply_warn(
"Due to the length of this tag, the following members will not be able to be proxied: {}. Please use a shorter tag to prevent this.".format(
member_names))
# Edge case: members with name length 1 and no new tag
if not new_tag:
one_length_members = [member for member in members if len(member.name) == 1]
if one_length_members:
member_names = ", ".join([member.name for member in one_length_members])
await ctx.reply_warn(
"Without a system tag, you will not be able to proxy members with a one-character name: {}. To prevent this, please add a system tag or lengthen their name.".format(
member_names))
async def system_avatar(ctx: CommandContext): async def system_avatar(ctx: CommandContext):
system = await ctx.ensure_system() system = await ctx.ensure_system()
@ -89,7 +112,9 @@ async def system_link(ctx: CommandContext):
if account_system: if account_system:
raise CommandError(AccountAlreadyLinkedError(account_system).message) raise CommandError(AccountAlreadyLinkedError(account_system).message)
if not await ctx.confirm_react(linkee, "{}, please confirm the link by clicking the ✅ reaction on this message.".format(linkee.mention)): if not await ctx.confirm_react(linkee,
"{}, please confirm the link by clicking the ✅ reaction on this message.".format(
linkee.mention)):
raise CommandError("Account link cancelled.") raise CommandError("Account link cancelled.")
await system.link_account(ctx.conn, linkee.id) await system.link_account(ctx.conn, linkee.id)
@ -241,5 +266,6 @@ async def system_frontpercent(ctx: CommandContext):
embed.add_field(name=member.name if member else "(no fronter)", embed.add_field(name=member.name if member else "(no fronter)",
value="{}% ({})".format(percent, humanize.naturaldelta(front_time))) value="{}% ({})".format(percent, humanize.naturaldelta(front_time)))
embed.set_footer(text="Since {} ({})".format(span_start.isoformat(sep=" ", timespec="seconds"), humanize.naturaltime(pluralkit.utils.fix_time(span_start)))) embed.set_footer(text="Since {} ({})".format(span_start.isoformat(sep=" ", timespec="seconds"),
humanize.naturaltime(pluralkit.utils.fix_time(span_start))))
await ctx.reply(embed=embed) await ctx.reply(embed=embed)

View File

@ -1,25 +1,27 @@
import discord
import humanize import humanize
from typing import Tuple from typing import Tuple
import discord
import pluralkit import pluralkit
from pluralkit import db
from pluralkit.bot.utils import escape from pluralkit.bot.utils import escape
from pluralkit.member import Member from pluralkit.member import Member
from pluralkit.switch import Switch from pluralkit.switch import Switch
from pluralkit.system import System from pluralkit.system import System
from pluralkit.utils import get_fronters from pluralkit.utils import get_fronters
def truncate_field_name(s: str) -> str: def truncate_field_name(s: str) -> str:
return s[:256] return s[:256]
def truncate_field_body(s: str) -> str: def truncate_field_body(s: str) -> str:
return s[:1024] return s[:1024]
def truncate_description(s: str) -> str: def truncate_description(s: str) -> str:
return s[:2048] return s[:2048]
def truncate_title(s: str) -> str: def truncate_title(s: str) -> str:
return s[:256] return s[:256]
@ -33,7 +35,7 @@ def success(text: str) -> discord.Embed:
def error(text: str, help: Tuple[str, str] = None) -> discord.Embed: def error(text: str, help: Tuple[str, str] = None) -> discord.Embed:
embed = discord.Embed() embed = discord.Embed()
embed.description = truncate_description(s) embed.description = truncate_description(text)
embed.colour = discord.Colour.dark_red() embed.colour = discord.Colour.dark_red()
if help: if help:

View File

@ -1,76 +1,26 @@
from io import BytesIO
import discord import discord
import logging import logging
import re from io import BytesIO
from typing import List, Optional from typing import Optional
from pluralkit import db from pluralkit import db
from pluralkit.bot import utils, channel_logger from pluralkit.bot import utils, channel_logger
from pluralkit.bot.channel_logger import ChannelLogger from pluralkit.bot.channel_logger import ChannelLogger
from pluralkit.member import Member
from pluralkit.system import System
logger = logging.getLogger("pluralkit.bot.proxy") logger = logging.getLogger("pluralkit.bot.proxy")
class ProxyError(Exception):
pass
def fix_webhook(webhook: discord.Webhook) -> discord.Webhook: def fix_webhook(webhook: discord.Webhook) -> discord.Webhook:
# Workaround for https://github.com/Rapptz/discord.py/issues/1242 and similar issues (#1150) # Workaround for https://github.com/Rapptz/discord.py/issues/1242 and similar issues (#1150)
webhook._adapter.store_user = webhook._adapter._store_user webhook._adapter.store_user = webhook._adapter._store_user
webhook._adapter.http = None webhook._adapter.http = None
return webhook return webhook
def extract_leading_mentions(message_text):
# This regex matches one or more mentions at the start of a message, separated by any amount of spaces
match = re.match(r"^(<(@|@!|#|@&|a?:\w+:)\d+>\s*)+", message_text)
if not match:
return message_text, ""
# Return the text after the mentions, and the mentions themselves
return message_text[match.span(0)[1]:].strip(), match.group(0)
def match_member_proxy_tags(member: db.ProxyMember, message_text: str):
# Skip members with no defined proxy tags
if not member.prefix and not member.suffix:
return None
# DB defines empty prefix/suffixes as None, replace with empty strings to prevent errors
prefix = member.prefix or ""
suffix = member.suffix or ""
# Ignore mentions at the very start of the message, and match proxy tags after those
message_text, leading_mentions = extract_leading_mentions(message_text)
logger.debug(
"Matching text '{}' and leading mentions '{}' to proxy tags {}text{}".format(message_text, leading_mentions,
prefix, suffix))
if message_text.startswith(member.prefix or "") and message_text.endswith(member.suffix or ""):
prefix_length = len(prefix)
suffix_length = len(suffix)
# If suffix_length is 0, the last bit of the slice will be "-0", and the slice will fail
if suffix_length > 0:
inner_string = message_text[prefix_length:-suffix_length]
else:
inner_string = message_text[prefix_length:]
# Add the mentions we stripped back
inner_string = leading_mentions + inner_string
return inner_string
def match_proxy_tags(members: List[db.ProxyMember], message_text: str):
# Sort by specificity (members with both prefix and suffix go higher)
# This will make sure more "precise" proxy tags get tried first
members: List[db.ProxyMember] = sorted(members, key=lambda x: int(
bool(x.prefix)) + int(bool(x.suffix)), reverse=True)
for member in members:
match = match_member_proxy_tags(member, message_text)
if match is not None: # Using "is not None" because an empty string is OK here too
logger.debug("Matched member {} with inner text '{}'".format(member.hid, match))
return member, match
async def get_or_create_webhook_for_channel(conn, bot_user: discord.User, channel: discord.TextChannel): async def get_or_create_webhook_for_channel(conn, bot_user: discord.User, channel: discord.TextChannel):
# First, check if we have one saved in the DB # First, check if we have one saved in the DB
webhook_from_db = await db.get_webhook(conn, channel.id) webhook_from_db = await db.get_webhook(conn, channel.id)
@ -83,17 +33,20 @@ async def get_or_create_webhook_for_channel(conn, bot_user: discord.User, channe
hook._adapter.store_user = hook._adapter._store_user hook._adapter.store_user = hook._adapter._store_user
return fix_webhook(hook) return fix_webhook(hook)
# If not, we check to see if there already exists one we've missed try:
for existing_hook in await channel.webhooks(): # If not, we check to see if there already exists one we've missed
existing_hook_creator = existing_hook.user.id if existing_hook.user else None for existing_hook in await channel.webhooks():
is_mine = existing_hook.name == "PluralKit Proxy Webhook" and existing_hook_creator == bot_user.id existing_hook_creator = existing_hook.user.id if existing_hook.user else None
if is_mine: is_mine = existing_hook.name == "PluralKit Proxy Webhook" and existing_hook_creator == bot_user.id
# We found one we made, let's add that to the DB just to be sure if is_mine:
await db.add_webhook(conn, channel.id, existing_hook.id, existing_hook.token) # We found one we made, let's add that to the DB just to be sure
return fix_webhook(existing_hook) await db.add_webhook(conn, channel.id, existing_hook.id, existing_hook.token)
return fix_webhook(existing_hook)
# If not, we create one and save it # If not, we create one and save it
created_webhook = await channel.create_webhook(name="PluralKit Proxy Webhook") created_webhook = await channel.create_webhook(name="PluralKit Proxy Webhook")
except discord.Forbidden:
raise ProxyError("PluralKit does not have the \"Manage Webhooks\" permission, and thus cannot proxy your message. Please contact a server administrator.")
await db.add_webhook(conn, channel.id, created_webhook.id, created_webhook.token) await db.add_webhook(conn, channel.id, created_webhook.id, created_webhook.token)
return fix_webhook(created_webhook) return fix_webhook(created_webhook)
@ -113,28 +66,37 @@ async def make_attachment_file(message: discord.Message):
return discord.File(bio, first_attachment.filename) return discord.File(bio, first_attachment.filename)
async def do_proxy_message(conn, original_message: discord.Message, proxy_member: db.ProxyMember, async def send_proxy_message(conn, original_message: discord.Message, system: System, member: Member,
inner_text: str, logger: ChannelLogger, bot_user: discord.User): inner_text: str, logger: ChannelLogger, bot_user: discord.User):
# Send the message through the webhook # Send the message through the webhook
webhook = await get_or_create_webhook_for_channel(conn, bot_user, original_message.channel) webhook = await get_or_create_webhook_for_channel(conn, bot_user, original_message.channel)
# Bounds check the combined name to avoid silent erroring
full_username = "{} {}".format(member.name, system.tag or "").strip()
if len(full_username) < 2:
raise ProxyError("The webhook's name, `{}`, is shorter than two characters, and thus cannot be proxied. Please change the member name or use a longer system tag.".format(full_username))
if len(full_username) > 32:
raise ProxyError("The webhook's name, `{}`, is longer than 32 characters, and thus cannot be proxied. Please change the member name or use a shorter system tag.".format(full_username))
try: try:
sent_message = await webhook.send( sent_message = await webhook.send(
content=inner_text, content=inner_text,
username="{} {}".format(proxy_member.name, proxy_member.tag or "").strip(), username=full_username,
avatar_url=proxy_member.avatar_url, avatar_url=member.avatar_url,
file=await make_attachment_file(original_message), file=await make_attachment_file(original_message),
wait=True wait=True
) )
except discord.NotFound: except discord.NotFound:
# The webhook we got from the DB doesn't actually exist # The webhook we got from the DB doesn't actually exist
# This can happen if someone manually deletes it from the server
# If we delete it from the DB then call the function again, it'll re-create one for us # If we delete it from the DB then call the function again, it'll re-create one for us
# (lol, lazy)
await db.delete_webhook(conn, original_message.channel.id) await db.delete_webhook(conn, original_message.channel.id)
await do_proxy_message(conn, original_message, proxy_member, inner_text, logger) await send_proxy_message(conn, original_message, system, member, inner_text, logger, bot_user)
return return
# Save the proxied message in the database # Save the proxied message in the database
await db.add_message(conn, sent_message.id, original_message.channel.id, proxy_member.id, await db.add_message(conn, sent_message.id, original_message.channel.id, member.id,
original_message.author.id) original_message.author.id)
await logger.log_message_proxied( await logger.log_message_proxied(
@ -145,11 +107,11 @@ async def do_proxy_message(conn, original_message: discord.Message, proxy_member
original_message.author.name, original_message.author.name,
original_message.author.discriminator, original_message.author.discriminator,
original_message.author.id, original_message.author.id,
proxy_member.name, member.name,
proxy_member.hid, member.hid,
proxy_member.avatar_url, member.avatar_url,
proxy_member.system_name, system.name,
proxy_member.system_hid, system.hid,
inner_text, inner_text,
sent_message.attachments[0].url if sent_message.attachments else None, sent_message.attachments[0].url if sent_message.attachments else None,
sent_message.created_at, sent_message.created_at,
@ -157,34 +119,44 @@ async def do_proxy_message(conn, original_message: discord.Message, proxy_member
) )
# And finally, gotta delete the original. # And finally, gotta delete the original.
await original_message.delete() try:
await original_message.delete()
except discord.Forbidden:
raise ProxyError("PluralKit does not have permission to delete user messages. Please contact a server administrator.")
async def try_proxy_message(conn, message: discord.Message, logger: ChannelLogger, bot_user: discord.User) -> bool: async def try_proxy_message(conn, message: discord.Message, logger: ChannelLogger, bot_user: discord.User) -> bool:
# Don't bother proxying in DMs with the bot # Don't bother proxying in DMs
if isinstance(message.channel, discord.abc.PrivateChannel): if isinstance(message.channel, discord.abc.PrivateChannel):
return False return False
# Return every member associated with the account # Get the system associated with the account, if possible
members = await db.get_members_by_account(conn, message.author.id) system = await System.get_by_account(conn, message.author.id)
proxy_match = match_proxy_tags(members, message.content) if not system:
if not proxy_match:
# No proxy tags match here, we done
return False return False
member, inner_text = proxy_match # Match on the members' proxy tags
proxy_match = await system.match_proxy(conn, message.content)
if not proxy_match:
return False
# Sanitize inner text for @everyones and such member, inner_message = proxy_match
inner_text = utils.sanitize(inner_text)
# Make sure no @everyones slip through, etc
inner_message = utils.sanitize(inner_message)
# If we don't have an inner text OR an attachment, we cancel because the hook can't send that # If we don't have an inner text OR an attachment, we cancel because the hook can't send that
# Strip so it counts a string of solely spaces as blank too # Strip so it counts a string of solely spaces as blank too
if not inner_text.strip() and not message.attachments: if not inner_message.strip() and not message.attachments:
return False return False
# So, we now have enough information to successfully proxy a message # So, we now have enough information to successfully proxy a message
async with conn.transaction(): async with conn.transaction():
await do_proxy_message(conn, message, member, inner_text, logger, bot_user) try:
await send_proxy_message(conn, message, system, member, inner_message, logger, bot_user)
except ProxyError as e:
await message.channel.send("\u274c {}".format(str(e)))
return True return True

View File

@ -54,24 +54,17 @@ class Member(namedtuple("Member",
return member return member
async def set_name(self, conn, system: "System", new_name: str): async def set_name(self, conn, new_name: str):
""" """
Set the name of a member. Requires the system to be passed in order to bounds check with the system tag. Set the name of a member.
:raises: MemberNameTooLongError, CustomEmojiError :raises: CustomEmojiError
""" """
# Custom emojis can't go in the member name # Custom emojis can't go in the member name
# Technically they *could* but they wouldn't render properly # Technically they *could*, but they wouldn't render properly
# so I'd rather explicitly ban them to in order to avoid confusion # so I'd rather explicitly ban them to in order to avoid confusion
# The textual form is longer than the length limit in most cases
# so we check this *before* the length check for better errors
if contains_custom_emoji(new_name): if contains_custom_emoji(new_name):
raise errors.CustomEmojiError() raise errors.CustomEmojiError()
# Explicit name length checking
if len(new_name) > system.get_member_name_limit():
raise errors.MemberNameTooLongError(tag_present=bool(system.tag))
await db.update_member_field(conn, self.id, "name", new_name) await db.update_member_field(conn, self.id, "name", new_name)
async def set_description(self, conn, new_description: Optional[str]): async def set_description(self, conn, new_description: Optional[str]):

View File

@ -1,9 +1,10 @@
import random import random
import re
import string import string
from datetime import datetime from datetime import datetime
from collections.__init__ import namedtuple from collections.__init__ import namedtuple
from typing import Optional, List from typing import Optional, List, Tuple
from pluralkit import db, errors from pluralkit import db, errors
from pluralkit.member import Member from pluralkit.member import Member
@ -63,11 +64,6 @@ class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "a
if contains_custom_emoji(new_tag): if contains_custom_emoji(new_tag):
raise errors.CustomEmojiError() raise errors.CustomEmojiError()
# Check name+tag length for all members
members_exceeding = await db.get_members_exceeding(conn, system_id=self.id, length=32 - len(new_tag) - 1)
if len(members_exceeding) > 0:
raise errors.TagTooLongWithMembersError([member.name for member in members_exceeding])
await db.update_system_field(conn, self.id, "tag", new_tag) await db.update_system_field(conn, self.id, "tag", new_tag)
async def set_avatar(self, conn, new_avatar_url: Optional[str]): async def set_avatar(self, conn, new_avatar_url: Optional[str]):
@ -133,12 +129,43 @@ class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "a
return None return None
def get_member_name_limit(self) -> int: def get_member_name_limit(self) -> int:
"""Returns the maximum length a member's name or nickname is allowed to be. Depends on the system tag.""" """Returns the maximum length a member's name or nickname is allowed to be in order for the member to be proxied. Depends on the system tag."""
if self.tag: if self.tag:
return 32 - len(self.tag) - 1 return 32 - len(self.tag) - 1
else: else:
return 32 return 32
async def match_proxy(self, conn, message: str) -> Optional[Tuple[Member, str]]:
"""Tries to find a member with proxy tags matching the given message. Returns the member and the inner contents."""
members = await db.get_all_members(conn, self.id)
# Sort by specificity (members with both prefix and suffix defined go higher)
# This will make sure more "precise" proxy tags get tried first and match properly
members = sorted(members, key=lambda x: int(bool(x.prefix)) + int(bool(x.suffix)), reverse=True)
for member in members:
proxy_prefix = member.prefix or ""
proxy_suffix = member.suffix or ""
# Check if the message matches these tags
if message.startswith(proxy_prefix) and message.endswith(proxy_suffix):
# If the message starts with a mention, "separate" that and match the bit after
mention_match = re.match(r"^(<(@|@!|#|@&|a?:\w+:)\d+>\s*)+", message)
leading_mentions = ""
if mention_match:
message = message[mention_match.span(0)[1]:].strip()
leading_mentions = mention_match.group(0)
# Extract the inner message (special case because -0 is invalid as an end slice)
if len(proxy_suffix) == 0:
inner_message = message[len(proxy_prefix):]
else:
inner_message = message[len(proxy_prefix):-len(proxy_suffix)]
# Add the stripped mentions back if there are any
inner_message = leading_mentions + inner_message
return member, inner_message
def to_json(self): def to_json(self):
return { return {
"id": self.hid, "id": self.hid,