bot: .net rewrite skeleton

This commit is contained in:
Ske 2019-04-19 20:48:37 +02:00
parent b5d1b87a72
commit e7fa5625b6
40 changed files with 770 additions and 3980 deletions

2
.gitignore vendored
View File

@ -1,3 +1,5 @@
bin/
obj/
.env .env
.vscode/ .vscode/
.idea/ .idea/

6
Dockerfile Normal file
View File

@ -0,0 +1,6 @@
FROM mcr.microsoft.com/dotnet/core/sdk:2.2-alpine
WORKDIR /app
COPY PluralKit/ PluralKit.csproj /app/
RUN dotnet build
ENTRYPOINT ["dotnet", "run"]

16
PluralKit.csproj Normal file
View File

@ -0,0 +1,16 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<OutputType>Exe</OutputType>
<TargetFramework>netcoreapp2.2</TargetFramework>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Dapper" Version="1.60.1" />
<PackageReference Include="Dapper.Contrib" Version="1.60.1" />
<PackageReference Include="Discord.Net" Version="2.0.1" />
<PackageReference Include="Npgsql" Version="4.0.4" />
<PackageReference Include="Npgsql.Json.NET" Version="4.0.4" />
</ItemGroup>
</Project>

125
PluralKit/Bot.cs Normal file
View File

@ -0,0 +1,125 @@
using System;
using System.Data;
using System.Linq;
using System.Reflection;
using System.Threading.Tasks;
using Dapper;
using Discord;
using Discord.Commands;
using Discord.WebSocket;
using Microsoft.Extensions.DependencyInjection;
using Npgsql;
using Npgsql.BackendMessages;
using Npgsql.PostgresTypes;
using Npgsql.TypeHandling;
using Npgsql.TypeMapping;
using NpgsqlTypes;
namespace PluralKit
{
class Initialize
{
static void Main() => new Initialize().MainAsync().GetAwaiter().GetResult();
private async Task MainAsync()
{
// Dapper by default tries to pass ulongs to Npgsql, which rejects them since PostgreSQL technically
// doesn't support unsigned types on its own.
// Instead we add a custom mapper to encode them as signed integers instead, converting them back and forth.
SqlMapper.RemoveTypeMap(typeof(ulong));
SqlMapper.AddTypeHandler<ulong>(new UlongEncodeAsLongHandler());
Dapper.DefaultTypeMap.MatchNamesWithUnderscores = true;
using (var services = BuildServiceProvider())
{
var connection = services.GetRequiredService<IDbConnection>() as NpgsqlConnection;
connection.ConnectionString = Environment.GetEnvironmentVariable("PK_DATABASE_URI");
await connection.OpenAsync();
var client = services.GetRequiredService<IDiscordClient>() as DiscordSocketClient;
await client.LoginAsync(TokenType.Bot, Environment.GetEnvironmentVariable("PK_TOKEN"));
await client.StartAsync();
await services.GetRequiredService<Bot>().Init();
await Task.Delay(-1);
}
}
public ServiceProvider BuildServiceProvider() => new ServiceCollection()
.AddSingleton<IDiscordClient, DiscordSocketClient>()
.AddSingleton<IDbConnection, NpgsqlConnection>()
.AddSingleton<Bot>()
.AddSingleton<CommandService>()
.AddSingleton<LogChannelService>()
.AddSingleton<ProxyService>()
.AddSingleton<SystemStore>()
.AddSingleton<MemberStore>()
.AddSingleton<MessageStore>()
.BuildServiceProvider();
}
class Bot
{
private IServiceProvider _services;
private DiscordSocketClient _client;
private CommandService _commands;
private IDbConnection _connection;
private ProxyService _proxy;
public Bot(IServiceProvider services, IDiscordClient client, CommandService commands, IDbConnection connection, ProxyService proxy)
{
this._services = services;
this._client = client as DiscordSocketClient;
this._commands = commands;
this._connection = connection;
this._proxy = proxy;
}
public async Task Init()
{
_commands.AddTypeReader<PKSystem>(new PKSystemTypeReader());
_commands.AddTypeReader<PKMember>(new PKMemberTypeReader());
_commands.CommandExecuted += CommandExecuted;
await _commands.AddModulesAsync(Assembly.GetEntryAssembly(), _services);
_client.MessageReceived += MessageReceived;
_client.ReactionAdded += _proxy.HandleReactionAddedAsync;
_client.MessageDeleted += _proxy.HandleMessageDeletedAsync;
}
private async Task CommandExecuted(Optional<CommandInfo> cmd, ICommandContext ctx, IResult _result)
{
if (!_result.IsSuccess) {
await ctx.Message.Channel.SendMessageAsync("\u274C " + _result.ErrorReason);
}
}
private async Task MessageReceived(SocketMessage _arg)
{
// Ignore system messages (member joined, message pinned, etc)
var arg = _arg as SocketUserMessage;
if (arg == null) return;
// Ignore bot messages
if (arg.Author.IsBot || arg.Author.IsWebhook) return;
int argPos = 0;
// Check if message starts with the command prefix
if (arg.HasStringPrefix("pk;", ref argPos) || arg.HasStringPrefix("pk!", ref argPos) || arg.HasMentionPrefix(_client.CurrentUser, ref argPos))
{
// If it does, fetch the sender's system (because most commands need that) into the context,
// and start command execution
var system = await _connection.QueryFirstAsync<PKSystem>("select systems.* from systems, accounts where accounts.uid = @Id and systems.id = accounts.system", new { Id = arg.Author.Id });
await _commands.ExecuteAsync(new PKCommandContext(_client, arg as SocketUserMessage, _connection, system), argPos, _services);
}
else
{
// If not, try proxying anyway
await _proxy.HandleMessageAsync(arg);
}
}
}
}

View File

@ -0,0 +1,73 @@
using System;
using System.Threading.Tasks;
using Dapper;
using Discord.Commands;
namespace PluralKit.Commands
{
[Group("system")]
public class SystemCommands : ContextParameterModuleBase<PKSystem>
{
public override string Prefix => "system";
public SystemStore Systems {get; set;}
public MemberStore Members {get; set;}
private RuntimeResult NO_SYSTEM_ERROR => PKResult.Error($"You do not have a system registered with PluralKit. To create one, type `pk;system new`. If you already have a system registered on another account, type `pk;link {Context.User.Mention}` from that account to link it here.");
private RuntimeResult OTHER_SYSTEM_CONTEXT_ERROR => PKResult.Error("You can only run this command on your own system.");
[Command("new")]
public async Task<RuntimeResult> New([Remainder] string systemName = null)
{
if (Context.ContextEntity != null) return OTHER_SYSTEM_CONTEXT_ERROR;
if (Context.SenderSystem != null) return PKResult.Error("You already have a system registered with PluralKit. To view it, type `pk;system`. If you'd like to delete your system and start anew, type `pk;system delete`, or if you'd like to unlink this account from it, type `pk;unlink.");
var system = await Systems.Create(systemName);
await ReplyAsync("Your system has been created. Type `pk;system` to view it, and type `pk;help` for more information about commands you can use now.");
return PKResult.Success();
}
[Command("name")]
public async Task<RuntimeResult> Name([Remainder] string newSystemName = null) {
if (Context.ContextEntity != null) return OTHER_SYSTEM_CONTEXT_ERROR;
if (Context.SenderSystem == null) return NO_SYSTEM_ERROR;
if (newSystemName != null && newSystemName.Length > 250) return PKResult.Error($"Your chosen system name is too long. ({newSystemName.Length} > 250 characters)");
Context.SenderSystem.Name = newSystemName;
await Systems.Save(Context.SenderSystem);
return PKResult.Success();
}
[Command("description")]
public async Task<RuntimeResult> Description([Remainder] string newDescription = null) {
if (Context.ContextEntity != null) return OTHER_SYSTEM_CONTEXT_ERROR;
if (Context.SenderSystem == null) return NO_SYSTEM_ERROR;
if (newDescription != null && newDescription.Length > 1000) return PKResult.Error($"Your chosen description is too long. ({newDescription.Length} > 250 characters)");
Context.SenderSystem.Description = newDescription;
await Systems.Save(Context.SenderSystem);
return PKResult.Success("uwu");
}
[Command("tag")]
public async Task<RuntimeResult> Tag([Remainder] string newTag = null) {
if (Context.ContextEntity != null) return OTHER_SYSTEM_CONTEXT_ERROR;
if (Context.SenderSystem == null) return NO_SYSTEM_ERROR;
Context.SenderSystem.Tag = newTag;
var unproxyableMembers = await Members.GetUnproxyableMembers(Context.SenderSystem);
//if (unproxyableMembers.Count > 0) {
throw new Exception("sdjsdflsdf");
//}
await Systems.Save(Context.SenderSystem);
return PKResult.Success("uwu");
}
public override async Task<PKSystem> ReadContextParameterAsync(string value)
{
var res = await new PKSystemTypeReader().ReadAsync(Context, value, _services);
return res.IsSuccess ? res.BestMatch as PKSystem : null;
}
}
}

34
PluralKit/Models.cs Normal file
View File

@ -0,0 +1,34 @@
using System;
using Dapper.Contrib.Extensions;
namespace PluralKit {
[Table("systems")]
public class PKSystem {
[Key]
public int Id { get; set; }
public string Hid { get; set; }
public string Name { get; set; }
public string Description { get; set; }
public string Tag { get; set; }
public string AvatarUrl { get; set; }
public string Token { get; set; }
public DateTime Created { get; set; }
public string UiTz { get; set; }
}
[Table("members")]
public class PKMember {
public int Id { get; set; }
public string Hid { get; set; }
public int System { get; set; }
public string Color { get; set; }
public string AvatarUrl { get; set; }
public string Name { get; set; }
public DateTime Date { get; set; }
public string Pronouns { get; set; }
public string Description { get; set; }
public string Prefix { get; set; }
public string Suffix { get; set; }
public DateTime Created { get; set; }
}
}

View File

@ -0,0 +1,50 @@
using System.Data;
using System.Threading.Tasks;
using Dapper;
using Discord;
namespace PluralKit {
class ServerDefinition {
public ulong Id;
public ulong LogChannel;
}
class LogChannelService {
private IDiscordClient _client;
private IDbConnection _connection;
public LogChannelService(IDiscordClient client, IDbConnection connection)
{
this._client = client;
this._connection = connection;
}
public async Task LogMessage(PKSystem system, PKMember member, IMessage message, IUser sender) {
var channel = await GetLogChannel((message.Channel as IGuildChannel).Guild);
if (channel == null) return;
var embed = new EmbedBuilder()
.WithAuthor($"#{message.Channel.Name}: {member.Name}", member.AvatarUrl)
.WithDescription(message.Content)
.WithFooter($"System ID: {system.Hid} | Member ID: {member.Hid} | Sender: ${sender.Username}#{sender.Discriminator} ({sender.Id}) | Message ID: ${message.Id}")
.WithTimestamp(message.Timestamp)
.Build();
await channel.SendMessageAsync(text: message.GetJumpUrl(), embed: embed);
}
public async Task<ITextChannel> GetLogChannel(IGuild guild) {
var server = await _connection.QueryFirstAsync<ServerDefinition>("select * from servers where id = @Id", new { Id = guild.Id });
if (server == null) return null;
return await _client.GetChannelAsync(server.LogChannel) as ITextChannel;
}
public async Task SetLogChannel(IGuild guild, ITextChannel newLogChannel) {
var def = new ServerDefinition {
Id = guild.Id,
LogChannel = newLogChannel.Id
};
await _connection.ExecuteAsync("insert into servers(id, log_channel) values (@Id, @LogChannel) on conflict (id) do update set log_channel = @LogChannel", def);
}
}
}

View File

@ -0,0 +1,151 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Data;
using System.Linq;
using System.Net;
using System.Threading.Tasks;
using Dapper;
using Discord;
using Discord.Rest;
using Discord.Webhook;
using Discord.WebSocket;
namespace PluralKit
{
class ProxyDatabaseResult
{
public PKSystem System;
public PKMember Member;
}
class ProxyMatch {
public PKMember Member;
public PKSystem System;
public string InnerText;
public string ProxyName => Member.Name + (System.Tag.Length > 0 ? " " + System.Tag : "");
}
class ProxyService {
private IDiscordClient _client;
private IDbConnection _connection;
private LogChannelService _logger;
private MessageStore _messageStorage;
private ConcurrentDictionary<ulong, Lazy<Task<IWebhook>>> _webhooks;
public ProxyService(IDiscordClient client, IDbConnection connection, LogChannelService logger, MessageStore messageStorage)
{
this._client = client;
this._connection = connection;
this._logger = logger;
this._messageStorage = messageStorage;
_webhooks = new ConcurrentDictionary<ulong, Lazy<Task<IWebhook>>>();
}
private ProxyMatch GetProxyTagMatch(string message, IEnumerable<ProxyDatabaseResult> potentials) {
// TODO: add detection of leading @mention
// Sort by specificity (prefix+suffix first, prefix/suffix second)
var ordered = potentials.OrderByDescending((p) => (p.Member.Prefix != null ? 0 : 1) + (p.Member.Suffix != null ? 0 : 1));
foreach (var potential in ordered) {
var prefix = potential.Member.Prefix ?? "";
var suffix = potential.Member.Suffix ?? "";
if (message.StartsWith(prefix) && message.EndsWith(suffix)) {
var inner = message.Substring(prefix.Length, message.Length - prefix.Length - suffix.Length);
return new ProxyMatch { Member = potential.Member, System = potential.System, InnerText = inner };
}
}
return null;
}
public async Task HandleMessageAsync(IMessage message) {
var results = await _connection.QueryAsync<PKMember, PKSystem, ProxyDatabaseResult>("select members.*, systems.* from members, systems, accounts where members.system = systems.id and accounts.system = systems.id and accounts.uid = @Uid", (member, system) => new ProxyDatabaseResult { Member = member, System = system }, new { Uid = message.Author.Id });
// Find a member with proxy tags matching the message
var match = GetProxyTagMatch(message.Content, results);
if (match == null) return;
// Fetch a webhook for this channel, and send the proxied message
var webhook = await GetWebhookByChannelCaching(message.Channel as ITextChannel);
var hookMessage = await ExecuteWebhook(webhook, match.InnerText, match.ProxyName, match.Member.AvatarUrl, message.Attachments.FirstOrDefault());
// Store the message in the database, and log it in the log channel (if applicable)
await _messageStorage.Store(message.Author.Id, hookMessage.Id, hookMessage.Channel.Id, match.Member);
await _logger.LogMessage(match.System, match.Member, hookMessage, message.Author);
// Wait a second or so before deleting the original message
await Task.Delay(1000);
await message.DeleteAsync();
}
private async Task<IMessage> ExecuteWebhook(IWebhook webhook, string text, string username, string avatarUrl, IAttachment attachment) {
var client = new DiscordWebhookClient(webhook);
ulong messageId;
if (attachment != null) {
using (var stream = await WebRequest.CreateHttp(attachment.Url).GetRequestStreamAsync()) {
messageId = await client.SendFileAsync(stream, filename: attachment.Filename, text: text, username: username, avatarUrl: avatarUrl);
}
} else {
messageId = await client.SendMessageAsync(text, username: username, avatarUrl: avatarUrl);
}
return await webhook.Channel.GetMessageAsync(messageId);
}
private async Task<IWebhook> GetWebhookByChannelCaching(ITextChannel channel) {
// We cache the webhook through a Lazy<Task<T>>, this way we make sure to only create one webhook per channel
// TODO: make sure this is sharding-safe. Intuition says yes, since one channel is guaranteed to only be handled by one shard, but best to make sure
var webhookFactory = _webhooks.GetOrAdd(channel.Id, new Lazy<Task<IWebhook>>(() => FindWebhookByChannel(channel)));
return await webhookFactory.Value;
}
private async Task<IWebhook> FindWebhookByChannel(ITextChannel channel) {
IWebhook webhook;
webhook = (await channel.GetWebhooksAsync()).FirstOrDefault(IsWebhookMine);
if (webhook != null) return webhook;
webhook = await channel.CreateWebhookAsync("PluralKit Proxy Webhook");
return webhook;
}
private bool IsWebhookMine(IWebhook arg)
{
return arg.Creator.Id == this._client.CurrentUser.Id && arg.Name == "PluralKit Proxy Webhook";
}
public async Task HandleReactionAddedAsync(Cacheable<IUserMessage, ulong> message, ISocketMessageChannel channel, SocketReaction reaction)
{
// Make sure it's the right emoji (red X)
if (reaction.Emote.Name != "\u274C") return;
// Find the message in the database
var storedMessage = await _messageStorage.Get(message.Id);
if (storedMessage == null) return; // (if we can't, that's ok, no worries)
// Make sure it's the actual sender of that message deleting the message
if (storedMessage.SenderId != reaction.UserId) return;
try {
// Then, fetch the Discord message and delete that
// TODO: this could be faster if we didn't bother fetching it and just deleted it directly
// somehow through REST?
await (await message.GetOrDownloadAsync()).DeleteAsync();
} catch (NullReferenceException) {
// Message was deleted before we got to it... cool, no problem, lmao
}
// Finally, delete it from our database.
await _messageStorage.Delete(message.Id);
}
public async Task HandleMessageDeletedAsync(Cacheable<IMessage, ulong> message, ISocketMessageChannel channel)
{
await _messageStorage.Delete(message.Id);
}
}
}

130
PluralKit/Stores.cs Normal file
View File

@ -0,0 +1,130 @@
using System;
using System.Collections.Generic;
using System.Data;
using System.Linq;
using System.Threading.Tasks;
using Dapper;
using Dapper.Contrib.Extensions;
namespace PluralKit {
public class SystemStore {
private IDbConnection conn;
public SystemStore(IDbConnection conn) {
this.conn = conn;
}
public async Task<PKSystem> Create(string systemName = null) {
// TODO: handle HID collision case
var hid = HidUtils.GenerateHid();
return await conn.QuerySingleAsync<PKSystem>("insert into systems (hid, name) values (@Hid, @Name) returning *", new { Hid = hid, Name = systemName });
}
public async Task<PKSystem> GetByAccount(ulong accountId) {
return await conn.QuerySingleAsync<PKSystem>("select systems.* from systems, accounts where accounts.system = system.id and accounts.uid = @Id", new { Id = accountId });
}
public async Task<PKSystem> GetByHid(string hid) {
return await conn.QuerySingleAsync<PKSystem>("select * from systems where systems.hid = @Hid", new { Hid = hid.ToLower() });
}
public async Task<PKSystem> GetByToken(string token) {
return await conn.QuerySingleAsync<PKSystem>("select * from systems where token = @Token", new { Token = token });
}
public async Task Save(PKSystem system) {
await conn.UpdateAsync(system);
}
public async Task Delete(PKSystem system) {
await conn.DeleteAsync(system);
}
}
public class MemberStore {
private IDbConnection conn;
public MemberStore(IDbConnection conn) {
this.conn = conn;
}
public async Task<PKMember> Create(PKSystem system, string name) {
// TODO: handle collision
var hid = HidUtils.GenerateHid();
return await conn.QuerySingleAsync("insert into members (hid, system, name) values (@Hid, @SystemId, @Name) returning *", new {
Hid = hid,
SystemID = system.Id,
Name = name
});
}
public async Task<PKMember> GetByHid(string hid) {
return await conn.QuerySingleAsync("select * from members where hid = @Hid", new { Hid = hid.ToLower() });
}
public async Task<PKMember> GetByName(string name) {
return await conn.QuerySingleAsync("select * from members where lower(name) = lower(@Name)", new { Name = name });
}
public async Task<PKMember> GetByNameConstrained(PKSystem system, string name) {
return await conn.QuerySingleAsync("select * from members where lower(name) = @Name and system = @SystemID", new { Name = name, SystemID = system.Id });
}
public async Task<ICollection<PKMember>> GetUnproxyableMembers(PKSystem system) {
return (await GetBySystem(system))
.Where((m) => {
var proxiedName = $"{m.Name} {system.Tag}";
return proxiedName.Length > 32 || proxiedName.Length < 2;
}).ToList();
}
public async Task<IEnumerable<PKMember>> GetBySystem(PKSystem system) {
return await conn.QueryAsync<PKMember>("select * from members where system = @SystemID", new { SystemID = system.Id });
}
public async Task Save(PKMember member) {
await conn.UpdateAsync(member);
}
public async Task Delete(PKMember member) {
await conn.DeleteAsync(member);
}
}
public class MessageStore {
public class StoredMessage {
public ulong Mid;
public ulong ChannelId;
public ulong SenderId;
public PKMember Member;
public PKSystem System;
}
private IDbConnection _connection;
public MessageStore(IDbConnection connection) {
this._connection = connection;
}
public async Task Store(ulong senderId, ulong messageId, ulong channelId, PKMember member) {
await _connection.ExecuteAsync("insert into messages(mid, channel, member, sender) values(@MessageId, @ChannelId, @MemberId, @SenderId)", new {
MessageId = messageId,
ChannelId = channelId,
MemberId = member.Id,
SenderId = senderId
});
}
public async Task<StoredMessage> Get(ulong id) {
return (await _connection.QueryAsync<StoredMessage, PKMember, PKSystem, StoredMessage>("select * from messages, members, systems where mid = @Id and messages.member = members.id and systems.id = members.system", (msg, member, system) => {
msg.System = system;
msg.Member = member;
return msg;
}, new { Id = id })).First();
}
public async Task Delete(ulong id) {
await _connection.ExecuteAsync("delete from messages where mid = @Id", new { Id = id });
}
}
}

176
PluralKit/Utils.cs Normal file
View File

@ -0,0 +1,176 @@
using System;
using System.Data;
using System.Threading.Tasks;
using Dapper;
using Discord;
using Discord.Commands;
using Discord.Commands.Builders;
using Discord.WebSocket;
using Microsoft.Extensions.DependencyInjection;
namespace PluralKit
{
class UlongEncodeAsLongHandler : SqlMapper.TypeHandler<ulong>
{
public override ulong Parse(object value)
{
// Cast to long to unbox, then to ulong (???)
return (ulong)(long)value;
}
public override void SetValue(IDbDataParameter parameter, ulong value)
{
parameter.Value = (long)value;
}
}
class PKSystemTypeReader : TypeReader
{
public override async Task<TypeReaderResult> ReadAsync(ICommandContext context, string input, IServiceProvider services)
{
var client = services.GetService<IDiscordClient>();
var conn = services.GetService<IDbConnection>();
// System references can take three forms:
// - The direct user ID of an account connected to the system
// - A @mention of an account connected to the system (<@uid>)
// - A system hid
// First, try direct user ID parsing
if (ulong.TryParse(input, out var idFromNumber)) return await FindSystemByAccountHelper(idFromNumber, client, conn);
// Then, try mention parsing.
if (MentionUtils.TryParseUser(input, out var idFromMention)) return await FindSystemByAccountHelper(idFromMention, client, conn);
// Finally, try HID parsing
var res = await conn.QuerySingleOrDefaultAsync<PKSystem>("select * from systems where hid = @Hid", new { Hid = input });
if (res != null) return TypeReaderResult.FromSuccess(res);
return TypeReaderResult.FromError(CommandError.ObjectNotFound, $"System with ID `${input}` not found.");
}
async Task<TypeReaderResult> FindSystemByAccountHelper(ulong id, IDiscordClient client, IDbConnection conn)
{
var foundByAccountId = await conn.QuerySingleOrDefaultAsync<PKSystem>("select * from accounts, systems where accounts.system = system.id and accounts.id = @Id", new { Id = id });
if (foundByAccountId != null) return TypeReaderResult.FromSuccess(foundByAccountId);
// We didn't find any, so we try to resolve the user ID to find the associated account,
// so we can print their username.
var user = await client.GetUserAsync(id);
// Return descriptive errors based on whether we found the user or not.
if (user == null) return TypeReaderResult.FromError(CommandError.ObjectNotFound, $"System or account with ID `${id}` not found.");
return TypeReaderResult.FromError(CommandError.ObjectNotFound, $"Account **${user.Username}#${user.Discriminator}** not found.");
}
}
class PKMemberTypeReader : TypeReader
{
public override async Task<TypeReaderResult> ReadAsync(ICommandContext context, string input, IServiceProvider services)
{
var conn = services.GetService(typeof(IDbConnection)) as IDbConnection;
// If the sender of the command is in a system themselves,
// then try searching by the member's name
if (context is PKCommandContext ctx && ctx.SenderSystem != null)
{
var foundByName = await conn.QuerySingleOrDefaultAsync<PKMember>("select * from members where system = @System and lower(name) = lower(@Name)", new { System = ctx.SenderSystem.Id, Name = input });
if (foundByName != null) return TypeReaderResult.FromSuccess(foundByName);
}
// Otherwise, if sender isn't in a system, or no member found by that name,
// do a standard by-hid search.
var foundByHid = await conn.QuerySingleOrDefaultAsync<PKMember>("select * from members where hid = @Hid", new { Hid = input });
if (foundByHid != null) return TypeReaderResult.FromSuccess(foundByHid);
return TypeReaderResult.FromError(CommandError.ObjectNotFound, "Member not found.");
}
}
/// Subclass of ICommandContext with PK-specific additional fields and functionality
public class PKCommandContext : SocketCommandContext, ICommandContext
{
public IDbConnection Connection { get; }
public PKSystem SenderSystem { get; }
public PKCommandContext(DiscordSocketClient client, SocketUserMessage msg, IDbConnection connection, PKSystem system) : base(client, msg)
{
Connection = connection;
SenderSystem = system;
}
}
public class ContextualContext<T> : PKCommandContext
{
public T ContextEntity { get; internal set; }
public ContextualContext(PKCommandContext ctx, T contextEntity): base(ctx.Client, ctx.Message, ctx.Connection, ctx.SenderSystem)
{
this.ContextEntity = contextEntity;
}
}
public abstract class ContextParameterModuleBase<T> : ModuleBase<ContextualContext<T>>
{
public IServiceProvider _services { get; set; }
public CommandService _commands { get; set; }
public abstract string Prefix { get; }
public abstract Task<T> ReadContextParameterAsync(string value);
protected override void OnModuleBuilding(CommandService commandService, ModuleBuilder builder) {
// We create a catch-all command that intercepts the first argument, tries to parse it as
// the context parameter, then runs the command service AGAIN with that given in a wrapped
// context, with the context argument removed so it delegates to the subcommand executor
builder.AddCommand("", async (ctx, param, services, info) => {
var pkCtx = ctx as PKCommandContext;
var res = await ReadContextParameterAsync(param[0] as string);
await commandService.ExecuteAsync(new ContextualContext<T>(pkCtx, res), Prefix + " " + param[1] as string, services);
}, (cb) => {
cb.WithPriority(-9999);
cb.AddPrecondition(new ContextParameterFallbackPreconditionAttribute());
cb.AddParameter<string>("contextValue", (pb) => pb.WithDefault(""));
cb.AddParameter<string>("rest", (pb) => pb.WithDefault("").WithIsRemainder(true));
});
}
}
public class ContextParameterFallbackPreconditionAttribute : PreconditionAttribute
{
public ContextParameterFallbackPreconditionAttribute()
{
}
public override async Task<PreconditionResult> CheckPermissionsAsync(ICommandContext context, CommandInfo command, IServiceProvider services)
{
if (context.GetType().Name != "ContextualContext`1") {
return PreconditionResult.FromSuccess();
} else {
return PreconditionResult.FromError("");
}
}
}
public class HidUtils
{
public static string GenerateHid()
{
var rnd = new Random();
var charset = "abcdefghijklmnopqrstuvwxyz";
string hid = "";
for (int i = 0; i < 5; i++)
{
hid += charset[rnd.Next(charset.Length)];
}
return hid;
}
}
public class PKResult : RuntimeResult
{
public PKResult(CommandError? error, string reason) : base(error, reason)
{
}
public static RuntimeResult Error(string reason) => new PKResult(CommandError.Unsuccessful, reason);
public static RuntimeResult Success(string reason = null) => new PKResult(null, reason);
}
}

View File

@ -1,38 +1,11 @@
version: '3' version: "3"
services: services:
bot: bot:
build: src/ build: .
entrypoint:
- python
- bot_main.py
volumes:
- "./pluralkit.conf:/app/pluralkit.conf:ro"
environment: environment:
- "DATABASE_URI=postgres://postgres:postgres@db:5432/postgres" - PK_TOKEN
depends_on: - "PK_DATABASE_URI=Host=db;Username=postgres;Password=postgres;Database=postgres"
- db links:
restart: always - db
api:
build: src/
entrypoint:
- python
- api_main.py
depends_on:
- db
restart: always
ports:
- "2939:8080"
environment:
- "DATABASE_URI=postgres://postgres:postgres@db:5432/postgres"
- "CLIENT_ID"
- "INVITE_CLIENT_ID_OVERRIDE"
- "CLIENT_SECRET"
- "REDIRECT_URI"
db: db:
image: postgres:alpine image: postgres:alpine
volumes:
- "db_data:/var/lib/postgresql/data"
restart: always
volumes:
db_data:

View File

@ -1,5 +0,0 @@
{
"database_uri": "postgres://username:password@hostname:port/database_name",
"token": "BOT_TOKEN_GOES_HERE",
"log_channel": null
}

View File

@ -1,10 +0,0 @@
FROM python:3.6-alpine
RUN apk --no-cache add build-base libffi-dev
WORKDIR /app
ADD requirements.txt /app
RUN pip install --trusted-host pypi.python.org -r requirements.txt
ADD . /app

View File

@ -1,254 +0,0 @@
import json
import logging
import os
from aiohttp import web, ClientSession
from pluralkit import db, utils
from pluralkit.errors import PluralKitError
from pluralkit.member import Member
from pluralkit.system import System
logging.basicConfig(level=logging.INFO, format="[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s")
logger = logging.getLogger("pluralkit.api")
def require_system(f):
async def inner(request):
if "system" not in request:
raise web.HTTPUnauthorized()
return await f(request)
return inner
@web.middleware
async def error_middleware(request, handler):
try:
return await handler(request)
except json.JSONDecodeError:
raise web.HTTPBadRequest()
except PluralKitError as e:
return web.json_response({"error": e.message}, status=400)
@web.middleware
async def db_middleware(request, handler):
async with request.app["pool"].acquire() as conn:
request["conn"] = conn
return await handler(request)
@web.middleware
async def auth_middleware(request, handler):
token = request.headers.get("X-Token") or request.query.get("token")
if token:
system = await System.get_by_token(request["conn"], token)
if system:
request["system"] = system
return await handler(request)
@web.middleware
async def cors_middleware(request, handler):
try:
resp = await handler(request)
except web.HTTPException as r:
resp = r
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Methods"] = "GET, POST, PATCH"
resp.headers["Access-Control-Allow-Headers"] = "X-Token"
return resp
class Handlers:
@require_system
async def get_system(request):
return web.json_response(request["system"].to_json())
async def get_other_system(request):
system_id = request.match_info.get("system")
system = await System.get_by_hid(request["conn"], system_id)
if not system:
raise web.HTTPNotFound(body="null")
return web.json_response(system.to_json())
async def get_system_members(request):
system_id = request.match_info.get("system")
system = await System.get_by_hid(request["conn"], system_id)
if not system:
raise web.HTTPNotFound(body="null")
members = await system.get_members(request["conn"])
return web.json_response([m.to_json() for m in members])
async def get_system_switches(request):
system_id = request.match_info.get("system")
system = await System.get_by_hid(request["conn"], system_id)
if not system:
raise web.HTTPNotFound(body="null")
switches = await system.get_switches(request["conn"], 9999)
cache = {}
async def hid_getter(member_id):
if not member_id in cache:
cache[member_id] = await Member.get_member_by_id(request["conn"], member_id)
return cache[member_id].hid
return web.json_response([await s.to_json(hid_getter) for s in switches])
async def get_system_fronters(request):
system_id = request.match_info.get("system")
system = await System.get_by_hid(request["conn"], system_id)
if not system:
raise web.HTTPNotFound(body="null")
members, stamp = await utils.get_fronters(request["conn"], system.id)
if not stamp:
# No switch has been registered at all
raise web.HTTPNotFound(body="null")
data = {
"timestamp": stamp.isoformat(),
"members": [member.to_json() for member in members]
}
return web.json_response(data)
@require_system
async def patch_system(request):
req = await request.json()
if "name" in req:
await request["system"].set_name(request["conn"], req["name"])
if "description" in req:
await request["system"].set_description(request["conn"], req["description"])
if "tag" in req:
await request["system"].set_tag(request["conn"], req["tag"])
if "avatar_url" in req:
await request["system"].set_avatar(request["conn"], req["name"])
if "tz" in req:
await request["system"].set_time_zone(request["conn"], req["tz"])
return web.json_response((await System.get_by_id(request["conn"], request["system"].id)).to_json())
async def get_member(request):
member_id = request.match_info.get("member")
member = await Member.get_member_by_hid(request["conn"], None, member_id)
if not member:
raise web.HTTPNotFound(body="{}")
system = await System.get_by_id(request["conn"], member.system)
member_json = member.to_json()
member_json["system"] = system.to_json()
return web.json_response(member_json)
@require_system
async def post_member(request):
req = await request.json()
member = await request["system"].create_member(request["conn"], req["name"])
return web.json_response(member.to_json())
@require_system
async def patch_member(request):
member_id = request.match_info.get("member")
member = await Member.get_member_by_hid(request["conn"], None, member_id)
if not member:
raise web.HTTPNotFound()
if member.system != request["system"].id:
raise web.HTTPUnauthorized()
req = await request.json()
if "name" in req:
await member.set_name(request["conn"], req["name"])
if "description" in req:
await member.set_description(request["conn"], req["description"])
if "avatar_url" in req:
await member.set_avatar_url(request["conn"], req["avatar_url"])
if "color" in req:
await member.set_color(request["conn"], req["color"])
if "birthday" in req:
await member.set_birthdate(request["conn"], req["birthday"])
if "pronouns" in req:
await member.set_pronouns(request["conn"], req["pronouns"])
if "prefix" in req or "suffix" in req:
await member.set_proxy_tags(request["conn"], req.get("prefix", member.prefix), req.get("suffix", member.suffix))
return web.json_response((await Member.get_member_by_id(request["conn"], member.id)).to_json())
@require_system
async def delete_member(request):
member_id = request.match_info.get("member")
member = await Member.get_member_by_hid(request["conn"], None, member_id)
if not member:
raise web.HTTPNotFound()
if member.system != request["system"].id:
raise web.HTTPUnauthorized()
await member.delete(request["conn"])
@require_system
async def post_switch(request):
req = await request.json()
if isinstance(req, str):
req = [req]
if req is None:
req = []
if not isinstance(req, list):
raise web.HTTPBadRequest()
members = [await Member.get_member_by_hid(request["conn"], request["system"].id, hid) for hid in req]
if not all(members):
raise web.HTTPNotFound(body=json.dumps({"error": "One or more members not found."}))
switch = await request["system"].add_switch(request["conn"], members)
hids = {member.id: member.hid for member in members}
async def hid_getter(mid):
return hids[mid]
return web.json_response(await switch.to_json(hid_getter))
async def discord_oauth(request):
code = await request.text()
async with ClientSession() as sess:
data = {
'client_id': os.environ["CLIENT_ID"],
'client_secret': os.environ["CLIENT_SECRET"],
'grant_type': 'authorization_code',
'code': code,
'redirect_uri': os.environ["REDIRECT_URI"],
'scope': 'identify'
}
headers = {
'Content-Type': 'application/x-www-form-urlencoded'
}
res = await sess.post("https://discordapp.com/api/v6/oauth2/token", data=data, headers=headers)
if res.status != 200:
raise web.HTTPBadRequest()
access_token = (await res.json())["access_token"]
res = await sess.get("https://discordapp.com/api/v6/users/@me", headers={"Authorization": "Bearer " + access_token})
user_id = int((await res.json())["id"])
system = await System.get_by_account(request["conn"], user_id)
if not system:
raise web.HTTPUnauthorized()
return web.Response(text=await system.get_token(request["conn"]))
async def run():
app = web.Application(middlewares=[cors_middleware, db_middleware, auth_middleware, error_middleware])
def cors_fallback(req):
return web.Response(headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "x-token", "Access-Control-Allow-Methods": "GET, POST, PATCH"}, status=404 if req.method != "OPTIONS" else 200)
app.add_routes([
web.get("/s", Handlers.get_system),
web.post("/s/switches", Handlers.post_switch),
web.get("/s/{system}", Handlers.get_other_system),
web.get("/s/{system}/members", Handlers.get_system_members),
web.get("/s/{system}/switches", Handlers.get_system_switches),
web.get("/s/{system}/fronters", Handlers.get_system_fronters),
web.patch("/s", Handlers.patch_system),
web.get("/m/{member}", Handlers.get_member),
web.post("/m", Handlers.post_member),
web.patch("/m/{member}", Handlers.patch_member),
web.delete("/m/{member}", Handlers.delete_member),
web.post("/discord_oauth", Handlers.discord_oauth),
web.route("*", "/{tail:.*}", cors_fallback)
])
app["pool"] = await db.connect(
os.environ["DATABASE_URI"]
)
return app
web.run_app(run())

View File

@ -1,12 +0,0 @@
import asyncio
import sys
try:
# uvloop doesn't work on Windows, therefore an optional dependency
import uvloop
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
except ImportError:
pass
from pluralkit import bot
bot.run(bot.Config.from_file_and_env(sys.argv[1] if len(sys.argv) > 1 else "pluralkit.conf"))

View File

@ -1,148 +0,0 @@
import asyncio
import sys
import asyncpg
from collections import namedtuple
import discord
import logging
import json
import os
import traceback
from pluralkit import db
from pluralkit.bot import commands, proxy, channel_logger, embeds
logging.basicConfig(level=logging.INFO, format="[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s")
class Config:
required_fields = ["database_uri", "token"]
fields = ["database_uri", "token", "log_channel"]
database_uri: str
token: str
log_channel: str
def __init__(self, database_uri: str, token: str, log_channel: str = None):
self.database_uri = database_uri
self.token = token
self.log_channel = log_channel
@staticmethod
def from_file_and_env(filename: str) -> "Config":
try:
with open(filename, "r") as f:
config = json.load(f)
except IOError as e:
# If all the required fields are specified as environment variables, it's OK to
# not raise the IOError, we can just construct the dict from these
if all([rf.upper() in os.environ for rf in Config.required_fields]):
config = {}
else:
# If they aren't, though, then rethrow
raise e
# Override with environment variables
for f in Config.fields:
if f.upper() in os.environ:
config[f] = os.environ[f.upper()]
# If we currently don't have all the required fields, then raise
if not all([rf in config for rf in Config.required_fields]):
raise RuntimeError("Some required config fields were missing: " + ", ".join(filter(lambda rf: rf not in config, Config.required_fields)))
return Config(**config)
def connect_to_database(uri: str) -> asyncpg.pool.Pool:
return asyncio.get_event_loop().run_until_complete(db.connect(uri))
def run(config: Config):
pool = connect_to_database(config.database_uri)
async def create_tables():
async with pool.acquire() as conn:
await db.create_tables(conn)
asyncio.get_event_loop().run_until_complete(create_tables())
client = discord.AutoShardedClient()
logger = channel_logger.ChannelLogger(client)
@client.event
async def on_ready():
print("PluralKit started.")
print("User: {}#{} (ID: {})".format(client.user.name, client.user.discriminator, client.user.id))
print("{} servers".format(len(client.guilds)))
print("{} shards".format(client.shard_count or 1))
await client.change_presence(activity=discord.Game(name="pk;help \u2014 in {} servers".format(len(client.guilds))))
@client.event
async def on_message(message: discord.Message):
# Ignore messages from bots
if message.author.bot:
return
# Grab a database connection from the pool
async with pool.acquire() as conn:
# First pass: do command handling
did_run_command = await commands.command_dispatch(client, message, conn)
if did_run_command:
return
# Second pass: do proxy matching
await proxy.try_proxy_message(conn, message, logger, client.user)
@client.event
async def on_raw_message_delete(payload: discord.RawMessageDeleteEvent):
async with pool.acquire() as conn:
await proxy.handle_deleted_message(conn, client, payload.message_id, None, logger)
@client.event
async def on_raw_bulk_message_delete(payload: discord.RawBulkMessageDeleteEvent):
async with pool.acquire() as conn:
for message_id in payload.message_ids:
await proxy.handle_deleted_message(conn, client, message_id, None, logger)
@client.event
async def on_raw_reaction_add(payload: discord.RawReactionActionEvent):
if payload.emoji.name == "\u274c": # Red X
async with pool.acquire() as conn:
await proxy.try_delete_by_reaction(conn, client, payload.message_id, payload.user_id, logger)
if payload.emoji.name in "\u2753\u2754": # Question mark
async with pool.acquire() as conn:
await proxy.do_query_message(conn, client, payload.user_id, payload.message_id)
@client.event
async def on_error(event_name, *args, **kwargs):
# Print it to stderr
logging.getLogger("pluralkit").exception("Exception while handling event {}".format(event_name))
# Then log it to the given log channel
# TODO: replace this with Sentry or something
if not config.log_channel:
return
log_channel = client.get_channel(int(config.log_channel))
# If this is a message event, we can attach additional information in an event
# ie. username, channel, content, etc
if args and isinstance(args[0], discord.Message):
message: discord.Message = args[0]
embed = embeds.exception_log(
message.content,
message.author.name,
message.author.discriminator,
message.author.id,
message.guild.id if message.guild else None,
message.channel.id
)
else:
# If not, just post the string itself
embed = None
traceback_str = "```python\n{}```".format(traceback.format_exc())
if len(traceback.format_exc()) >= (2000 - len("```python\n```")):
traceback_str = "```python\n...{}```".format(traceback.format_exc()[- (2000 - len("```python\n...```")):])
await log_channel.send(content=traceback_str, embed=embed)
client.run(config.token)

View File

@ -1,104 +0,0 @@
import discord
import logging
from datetime import datetime
from pluralkit import db
def embed_set_author_name(embed: discord.Embed, channel_name: str, member_name: str, system_name: str, avatar_url: str):
name = "#{}: {}".format(channel_name, member_name)
if system_name:
name += " ({})".format(system_name)
embed.set_author(name=name, icon_url=avatar_url or discord.Embed.Empty)
class ChannelLogger:
def __init__(self, client: discord.Client):
self.logger = logging.getLogger("pluralkit.bot.channel_logger")
self.client = client
async def get_log_channel(self, conn, server_id: int):
server_info = await db.get_server_info(conn, server_id)
if not server_info:
return None
log_channel = server_info["log_channel"]
if not log_channel:
return None
return self.client.get_channel(log_channel)
async def send_to_log_channel(self, log_channel: discord.TextChannel, embed: discord.Embed, text: str = None):
try:
await log_channel.send(content=text, embed=embed)
except discord.Forbidden:
# TODO: spew big error
self.logger.warning(
"Did not have permission to send message to logging channel (server={}, channel={})".format(
log_channel.guild.id, log_channel.id))
async def log_message_proxied(self, conn,
server_id: int,
channel_name: str,
channel_id: int,
sender_name: str,
sender_disc: int,
sender_id: int,
member_name: str,
member_hid: str,
member_avatar_url: str,
system_name: str,
system_hid: str,
message_text: str,
message_image: str,
message_timestamp: datetime,
message_id: int):
log_channel = await self.get_log_channel(conn, server_id)
if not log_channel:
return
message_link = "https://discordapp.com/channels/{}/{}/{}".format(server_id, channel_id, message_id)
embed = discord.Embed()
embed.colour = discord.Colour.blue()
embed.description = message_text
embed.timestamp = message_timestamp
embed_set_author_name(embed, channel_name, member_name, system_name, member_avatar_url)
embed.set_footer(
text="System ID: {} | Member ID: {} | Sender: {}#{} ({}) | Message ID: {}".format(system_hid, member_hid,
sender_name, sender_disc,
sender_id, message_id))
if message_image:
embed.set_thumbnail(url=message_image)
await self.send_to_log_channel(log_channel, embed, message_link)
async def log_message_deleted(self, conn,
server_id: int,
channel_name: str,
member_name: str,
member_hid: str,
member_avatar_url: str,
system_name: str,
system_hid: str,
message_text: str,
message_id: int):
log_channel = await self.get_log_channel(conn, server_id)
if not log_channel:
return
embed = discord.Embed()
embed.colour = discord.Colour.dark_red()
embed.description = message_text or "*(unknown, message deleted by moderator)*"
embed.timestamp = datetime.utcnow()
embed_set_author_name(embed, channel_name, member_name, system_name, member_avatar_url)
embed.set_footer(
text="System ID: {} | Member ID: {} | Message ID: {}".format(system_hid, member_hid, message_id))
await self.send_to_log_channel(log_channel, embed)

View File

@ -1,243 +0,0 @@
import asyncio
from datetime import datetime
import discord
import re
from typing import Tuple, Optional, Union
from pluralkit import db
from pluralkit.bot import embeds, utils
from pluralkit.errors import PluralKitError
from pluralkit.member import Member
from pluralkit.system import System
def next_arg(arg_string: str) -> Tuple[str, Optional[str]]:
# A basic quoted-arg parser
for quote in "“‟”":
arg_string = arg_string.replace(quote, "\"")
if arg_string.startswith("\""):
end_quote = arg_string[1:].find("\"") + 1
if end_quote > 0:
return arg_string[1:end_quote], arg_string[end_quote + 1:].strip()
else:
return arg_string[1:], None
next_space = arg_string.find(" ")
if next_space >= 0:
return arg_string[:next_space].strip(), arg_string[next_space:].strip()
else:
return arg_string.strip(), None
class CommandError(Exception):
def __init__(self, text: str, help: Tuple[str, str] = None):
self.text = text
self.help = help
def format(self):
return "\u274c " + self.text, embeds.error("", self.help) if self.help else None
class CommandContext:
client: discord.Client
message: discord.Message
def __init__(self, client: discord.Client, message: discord.Message, conn, args: str, system: Optional[System]):
self.client = client
self.message = message
self.conn = conn
self.args = args
self._system = system
async def get_system(self) -> Optional[System]:
return self._system
async def ensure_system(self) -> System:
system = await self.get_system()
if not system:
raise CommandError("No system registered to this account. Use `pk;system new` to register one.")
return system
def has_next(self) -> bool:
return bool(self.args)
def format_time(self, dt: datetime):
if self._system:
return self._system.format_time(dt)
return dt.isoformat(sep=" ", timespec="seconds") + " UTC"
def pop_str(self, error: CommandError = None) -> Optional[str]:
if not self.args:
if error:
raise error
return None
popped, self.args = next_arg(self.args)
return popped
def peek_str(self) -> Optional[str]:
if not self.args:
return None
popped, _ = next_arg(self.args)
return popped
def match(self, next) -> bool:
peeked = self.peek_str()
if peeked and peeked.lower() == next.lower():
self.pop_str()
return True
return False
async def pop_system(self, error: CommandError = None) -> System:
name = self.pop_str(error)
system = await utils.get_system_fuzzy(self.conn, self.client, name)
if not system:
raise CommandError("Unable to find system '{}'.".format(name))
return system
async def pop_member(self, error: CommandError = None, system_only: bool = True) -> Member:
name = self.pop_str(error)
if system_only:
system = await self.ensure_system()
else:
system = await self.get_system()
member = await utils.get_member_fuzzy(self.conn, system.id if system else None, name, system_only)
if not member:
raise CommandError("Unable to find member '{}'{}.".format(name, " in your system" if system_only else ""))
return member
def remaining(self):
return self.args
async def reply(self, content=None, embed=None):
return await self.message.channel.send(content=content, embed=embed)
async def reply_ok(self, content=None, embed=None):
return await self.reply(content="\u2705 {}".format(content or ""), embed=embed)
async def reply_warn(self, content=None, embed=None):
return await self.reply(content="\u26a0 {}".format(content or ""), embed=embed)
async def reply_ok_dm(self, content: str):
if isinstance(self.message.channel, discord.DMChannel):
await self.reply_ok(content="\u2705 {}".format(content or ""))
else:
await self.message.author.send(content="\u2705 {}".format(content or ""))
await self.reply_ok("DM'd!")
async def confirm_react(self, user: Union[discord.Member, discord.User], message: discord.Message):
await message.add_reaction("\u2705") # Checkmark
await message.add_reaction("\u274c") # Red X
try:
reaction, _ = await self.client.wait_for("reaction_add",
check=lambda r, u: u.id == user.id and r.emoji in ["\u2705",
"\u274c"],
timeout=60.0 * 5)
return reaction.emoji == "\u2705"
except asyncio.TimeoutError:
raise CommandError("Timed out - try again.")
async def confirm_text(self, user: discord.Member, channel: discord.TextChannel, confirm_text: str, message: str):
await self.reply(message)
try:
message = await self.client.wait_for("message",
check=lambda m: m.channel.id == channel.id and m.author.id == user.id,
timeout=60.0 * 5)
return message.content.lower() == confirm_text.lower()
except asyncio.TimeoutError:
raise CommandError("Timed out - try again.")
import pluralkit.bot.commands.api_commands
import pluralkit.bot.commands.import_commands
import pluralkit.bot.commands.member_commands
import pluralkit.bot.commands.message_commands
import pluralkit.bot.commands.misc_commands
import pluralkit.bot.commands.mod_commands
import pluralkit.bot.commands.switch_commands
import pluralkit.bot.commands.system_commands
async def command_root(ctx: CommandContext):
if ctx.match("system") or ctx.match("s"):
await system_commands.system_root(ctx)
elif ctx.match("member") or ctx.match("m"):
await member_commands.member_root(ctx)
elif ctx.match("link"):
await system_commands.account_link(ctx)
elif ctx.match("unlink"):
await system_commands.account_unlink(ctx)
elif ctx.match("message"):
await message_commands.message_info(ctx)
elif ctx.match("log"):
await mod_commands.set_log(ctx)
elif ctx.match("invite"):
await misc_commands.invite_link(ctx)
elif ctx.match("export"):
await misc_commands.export(ctx)
elif ctx.match("switch") or ctx.match("sw"):
await switch_commands.switch_root(ctx)
elif ctx.match("token"):
await api_commands.token_root(ctx)
elif ctx.match("import"):
await import_commands.import_root(ctx)
elif ctx.match("help"):
await misc_commands.help_root(ctx)
elif ctx.match("tell"):
await misc_commands.tell(ctx)
elif ctx.match("fire"):
await misc_commands.pkfire(ctx)
elif ctx.match("thunder"):
await misc_commands.pkthunder(ctx)
elif ctx.match("freeze"):
await misc_commands.pkfreeze(ctx)
elif ctx.match("starstorm"):
await misc_commands.pkstarstorm(ctx)
elif ctx.match("commands"):
await misc_commands.command_list(ctx)
else:
raise CommandError("Unknown command {}. For a list of commands, type `pk;commands`.".format(ctx.pop_str()))
async def run_command(ctx: CommandContext, func):
# lol nested try
try:
try:
await func(ctx)
except PluralKitError as e:
raise CommandError(e.message, e.help_page)
except CommandError as e:
content, embed = e.format()
await ctx.reply(content=content, embed=embed)
async def command_dispatch(client: discord.Client, message: discord.Message, conn) -> bool:
prefix = "^(pk(;|!)|<@{}> )".format(client.user.id)
regex = re.compile(prefix, re.IGNORECASE)
cmd = message.content
match = regex.match(cmd)
if match:
remaining_string = cmd[match.span()[1]:].strip()
ctx = CommandContext(
client=client,
message=message,
conn=conn,
args=remaining_string,
system=await System.get_by_account(conn, message.author.id)
)
await run_command(ctx, command_root)
return True
return False

View File

@ -1,35 +0,0 @@
from pluralkit.bot.commands import CommandContext
disclaimer = "\u26A0 Please note that this grants access to modify (and delete!) all your system data, so keep it safe and secure. If it leaks or you need a new one, you can invalidate this one with `pk;token refresh`."
async def token_root(ctx: CommandContext):
if ctx.match("refresh") or ctx.match("expire") or ctx.match("invalidate") or ctx.match("update"):
await token_refresh(ctx)
else:
await token_get(ctx)
async def token_get(ctx: CommandContext):
system = await ctx.ensure_system()
if system.token:
token = system.token
else:
token = await system.refresh_token(ctx.conn)
token_message = "{}\n\u2705 Here's your API token:".format(disclaimer)
if token:
await ctx.reply_ok("DM'd!")
await ctx.message.author.send(token_message)
await ctx.message.author.send(token)
return
async def token_refresh(ctx: CommandContext):
system = await ctx.ensure_system()
token = await system.refresh_token(ctx.conn)
token_message = "Your previous API token has been invalidated. You will need to change it anywhere it's currently used.\n{}\n\u2705 Here's your new API token:".format(disclaimer)
if token:
await ctx.message.author.send(token_message)
await ctx.message.author.send(token)

View File

@ -1,49 +0,0 @@
import aiohttp
import asyncio
import io
import json
import os
from datetime import datetime
from pluralkit.errors import TupperboxImportError
from pluralkit.bot.commands import *
async def import_root(ctx: CommandContext):
# Only one import method rn, so why not default to Tupperbox?
await import_tupperbox(ctx)
async def import_tupperbox(ctx: CommandContext):
await ctx.reply("To import from Tupperbox, reply to this message with a `tuppers.json` file imported from Tupperbox.\n\nTo obtain such a file, type `tul!export` (or your server's equivalent).")
def predicate(msg):
if msg.author.id != ctx.message.author.id:
return False
if msg.attachments:
if msg.attachments[0].filename.endswith(".json"):
return True
return False
try:
message = await ctx.client.wait_for("message", check=predicate, timeout=60*5)
except asyncio.TimeoutError:
raise CommandError("Timed out. Try running `pk;import` again.")
s = io.BytesIO()
await message.attachments[0].save(s, use_cached=False)
data = json.load(s)
system = await ctx.get_system()
if not system:
system = await System.create_system(ctx.conn, account_id=ctx.message.author.id)
result = await system.import_from_tupperbox(ctx.conn, data)
tag_note = ""
if len(result.tags) > 1:
tag_note = "\n\nPluralKit's tags work on a per-system basis. Since your Tupperbox members have more than one unique tag, PluralKit has not imported the tags. Set your system tag manually with `pk;system tag <tag>`."
await ctx.reply_ok("Updated {} member{}, created {} member{}. Type `pk;system list` to check!{}".format(
len(result.updated), "s" if len(result.updated) != 1 else "",
len(result.created), "s" if len(result.created) != 1 else "",
tag_note
))

View File

@ -1,192 +0,0 @@
import pluralkit.bot.embeds
from pluralkit.bot import help
from pluralkit.bot.commands import *
from pluralkit.errors import PluralKitError
async def member_root(ctx: CommandContext):
if ctx.match("new") or ctx.match("create") or ctx.match("add") or ctx.match("register"):
await new_member(ctx)
elif ctx.match("set"):
await member_set(ctx)
# TODO "pk;member list"
elif not ctx.has_next():
raise CommandError("Must pass a subcommand. For a list of subcommands, type `pk;help member`.")
else:
await specific_member_root(ctx)
async def specific_member_root(ctx: CommandContext):
member = await ctx.pop_member(system_only=False)
if ctx.has_next():
# Following commands operate on members only in the caller's own system
# error if not, to make sure you can't destructively edit someone else's member
system = await ctx.ensure_system()
if not member.system == system.id:
raise CommandError("Member must be in your own system.")
if ctx.match("name") or ctx.match("rename"):
await member_name(ctx, member)
elif ctx.match("description") or ctx.match("desc"):
await member_description(ctx, member)
elif ctx.match("avatar") or ctx.match("icon"):
await member_avatar(ctx, member)
elif ctx.match("proxy") or ctx.match("tags"):
await member_proxy(ctx, member)
elif ctx.match("pronouns") or ctx.match("pronoun"):
await member_pronouns(ctx, member)
elif ctx.match("color") or ctx.match("colour"):
await member_color(ctx, member)
elif ctx.match("birthday") or ctx.match("birthdate") or ctx.match("bday"):
await member_birthdate(ctx, member)
elif ctx.match("delete") or ctx.match("remove") or ctx.match("destroy") or ctx.match("erase"):
await member_delete(ctx, member)
else:
raise CommandError(
"Unknown subcommand {}. For a list of all commands, type `pk;help member`".format(ctx.pop_str()))
else:
# Basic lookup
await member_info(ctx, member)
async def member_info(ctx: CommandContext, member: Member):
await ctx.reply(embed=await pluralkit.bot.embeds.member_card(ctx.conn, member))
async def new_member(ctx: CommandContext):
system = await ctx.ensure_system()
if not ctx.has_next():
raise CommandError("You must pass a name for the new member.")
new_name = ctx.remaining()
existing_member = await Member.get_member_by_name(ctx.conn, system.id, new_name)
if existing_member:
msg = await ctx.reply_warn(
"There is already a member with this name, with the ID `{}`. Do you want to create a duplicate member anyway?".format(
existing_member.hid))
if not await ctx.confirm_react(ctx.message.author, msg):
raise CommandError("Member creation cancelled.")
try:
member = await system.create_member(ctx.conn, new_name)
except PluralKitError as e:
raise CommandError(e.message)
await ctx.reply_ok(
"Member \"{}\" (`{}`) registered! Type `pk;help member` for a list of commands to edit this member.".format(new_name, member.hid))
async def member_set(ctx: CommandContext):
raise CommandError(
"`pk;member set` has been retired. Please use the new member modifying commands. Type `pk;help member` for a list.")
async def member_name(ctx: CommandContext, member: Member):
system = await ctx.ensure_system()
new_name = ctx.pop_str(CommandError("You must pass a new member name."))
# Warn if there's a member by the same name already
existing_member = await Member.get_member_by_name(ctx.conn, system.id, new_name)
if existing_member and existing_member.id != member.id:
msg = await ctx.reply_warn(
"There is already another member with this name, with the ID `{}`. Do you want to rename this member anyway? This will result in two members with the same name.".format(
existing_member.hid))
if not await ctx.confirm_react(ctx.message.author, msg):
raise CommandError("Member renaming cancelled.")
await member.set_name(ctx.conn, new_name)
await ctx.reply_ok("Member name updated.")
if len(new_name) < 2 and not system.tag:
await ctx.reply_warn(
"This member's new name is under 2 characters, and thus cannot be proxied. To prevent this, use a longer member name, or add a system tag.")
elif len(new_name) > 32:
exceeds_by = len(new_name) - 32
await ctx.reply_warn(
"This member's new name is longer than 32 characters, and thus cannot be proxied. To prevent this, shorten the member name by {} characters.".format(
exceeds_by))
elif len(new_name) > system.get_member_name_limit():
exceeds_by = len(new_name) - system.get_member_name_limit()
await ctx.reply_warn(
"This member's new name, when combined with the system tag `{}`, is longer than 32 characters, and thus cannot be proxied. To prevent this, shorten the name or system tag by at least {} characters.".format(
system.tag, exceeds_by))
async def member_description(ctx: CommandContext, member: Member):
new_description = ctx.remaining() or None
await member.set_description(ctx.conn, new_description)
await ctx.reply_ok("Member description {}.".format("updated" if new_description else "cleared"))
async def member_avatar(ctx: CommandContext, member: Member):
new_avatar_url = ctx.remaining() or None
if new_avatar_url:
user = await utils.parse_mention(ctx.client, new_avatar_url)
if user:
new_avatar_url = user.avatar_url_as(format="png")
await member.set_avatar(ctx.conn, new_avatar_url)
await ctx.reply_ok("Member avatar {}.".format("updated" if new_avatar_url else "cleared"))
async def member_color(ctx: CommandContext, member: Member):
new_color = ctx.remaining() or None
await member.set_color(ctx.conn, new_color)
await ctx.reply_ok("Member color {}.".format("updated" if new_color else "cleared"))
async def member_pronouns(ctx: CommandContext, member: Member):
new_pronouns = ctx.remaining() or None
await member.set_pronouns(ctx.conn, new_pronouns)
await ctx.reply_ok("Member pronouns {}.".format("updated" if new_pronouns else "cleared"))
async def member_birthdate(ctx: CommandContext, member: Member):
new_birthdate = ctx.remaining() or None
await member.set_birthdate(ctx.conn, new_birthdate)
await ctx.reply_ok("Member birthdate {}.".format("updated" if new_birthdate else "cleared"))
async def member_proxy(ctx: CommandContext, member: Member):
if not ctx.has_next():
prefix, suffix = None, None
else:
# Sanity checking
example = ctx.remaining()
if "text" not in example:
raise CommandError("Example proxy message must contain the string 'text'. For help, type `pk;help proxy`.")
if example.count("text") != 1:
raise CommandError("Example proxy message must contain the string 'text' exactly once. For help, type `pk;help proxy`.")
# Extract prefix and suffix
prefix = example[:example.index("text")].strip()
suffix = example[example.index("text") + 4:].strip()
# DB stores empty strings as None, make that work
if not prefix:
prefix = None
if not suffix:
suffix = None
async with ctx.conn.transaction():
await member.set_proxy_tags(ctx.conn, prefix, suffix)
await ctx.reply_ok(
"Proxy settings updated." if prefix or suffix else "Proxy settings cleared. If you meant to set your proxy tags, type `pk;help proxy` for help.")
async def member_delete(ctx: CommandContext, member: Member):
delete_confirm_msg = "Are you sure you want to delete {}? If so, reply to this message with the member's ID (`{}`).".format(
member.name, member.hid)
if not await ctx.confirm_text(ctx.message.author, ctx.message.channel, member.hid, delete_confirm_msg):
raise CommandError("Member deletion cancelled.")
await member.delete(ctx.conn)
await ctx.reply_ok("Member deleted.")

View File

@ -1,18 +0,0 @@
from pluralkit.bot.commands import *
async def message_info(ctx: CommandContext):
mid_str = ctx.pop_str(CommandError("You must pass a message ID."))
try:
mid = int(mid_str)
except ValueError:
raise CommandError("You must pass a valid number as a message ID.")
# Find the message in the DB
message = await db.get_message(ctx.conn, mid)
if not message:
raise CommandError(
"Message with ID '{}' not found. Are you sure it's a message proxied by PluralKit?".format(mid))
await ctx.reply(embed=await embeds.message_card(ctx.client, message))

View File

@ -1,191 +0,0 @@
import io
import json
import os
from discord.utils import oauth_url
from pluralkit.bot import help
from pluralkit.bot.commands import *
from pluralkit.bot.embeds import help_footer_embed
prefix = "pk;" # TODO: configurable
def make_footer_embed():
embed = discord.Embed()
embed.set_footer(text=help.helpfile["footer"])
return embed
def make_command_embed(command):
embed = make_footer_embed()
embed.title = prefix + command["usage"]
embed.description = (command["description"] + "\n" + command.get("longdesc", "")).strip()
if "aliases" in command:
embed.add_field(name="Aliases" if len(command["aliases"]) > 1 else "Alias", value="\n".join([prefix + cmd for cmd in command["aliases"]]), inline=False)
embed.add_field(name="Usage", value=prefix + command["usage"], inline=False)
if "examples" in command:
embed.add_field(name="Examples" if len(command["examples"]) > 1 else "Example", value="\n".join([prefix + cmd for cmd in command["examples"]]), inline=False)
if "subcommands" in command:
embed.add_field(name="Subcommands", value="\n".join([command["name"] + " " + sc["name"] for sc in command["subcommands"]]), inline=False)
return embed
def find_command(command_list, name):
for command in command_list:
if command["name"].lower().strip() == name.lower().strip():
return command
async def help_root(ctx: CommandContext):
for page_name, page_content in help.helpfile["pages"].items():
if ctx.match(page_name):
return await help_page(ctx, page_content)
if not ctx.has_next():
return await help_page(ctx, help.helpfile["pages"]["root"])
return await help_command(ctx, ctx.remaining())
async def help_page(ctx, sections):
msg = ""
for section in sections:
msg += "__**{}**__\n{}\n\n".format(section["name"], section["content"])
return await ctx.reply(content=msg, embed=make_footer_embed())
async def help_command(ctx, command_name):
name_parts = command_name.replace(prefix, "").split(" ")
command = find_command(help.helpfile["commands"], name_parts[0])
name_parts = name_parts[1:]
if not command:
raise CommandError("Could not find command '{}'.".format(command_name))
while len(name_parts) > 0:
found_command = find_command(command["subcommands"], name_parts[0])
if not found_command:
break
command = found_command
name_parts = name_parts[1:]
return await ctx.reply(embed=make_command_embed(command))
async def command_list(ctx):
cmds = []
categories = {}
def make_command_list(lst):
for cmd in lst:
if not cmd["category"] in categories:
categories[cmd["category"]] = []
categories[cmd["category"]].append("**{}{}** - {}".format(prefix, cmd["usage"], cmd["description"]))
if "subcommands" in cmd:
make_command_list(cmd["subcommands"])
make_command_list(help.helpfile["commands"])
embed = discord.Embed()
embed.title = "PluralKit Commands"
embed.description = "Type `pk;help <command>` for more information."
for cat_name, cat_cmds in categories.items():
embed.add_field(name=cat_name, value="\n".join(cat_cmds))
await ctx.reply(embed=embed)
async def invite_link(ctx: CommandContext):
client_id = (await ctx.client.application_info()).id
permissions = discord.Permissions()
# So the bot can actually add the webhooks it needs to do the proxy functionality
permissions.manage_webhooks = True
# So the bot can respond with status, error, and success messages
permissions.send_messages = True
# So the bot can delete channels
permissions.manage_messages = True
# So the bot can respond with extended embeds, ex. member cards
permissions.embed_links = True
# So the bot can send images too
permissions.attach_files = True
# (unsure if it needs this, actually, might be necessary for message lookup)
permissions.read_message_history = True
# So the bot can add reactions for confirm/deny prompts
permissions.add_reactions = True
url = oauth_url(client_id, permissions)
await ctx.reply_ok("Use this link to add PluralKit to your server: {}".format(url))
async def export(ctx: CommandContext):
working_msg = await ctx.message.channel.send("Working...")
system = await ctx.ensure_system()
members = await system.get_members(ctx.conn)
accounts = await system.get_linked_account_ids(ctx.conn)
switches = await system.get_switches(ctx.conn, 999999)
data = {
"name": system.name,
"id": system.hid,
"description": system.description,
"tag": system.tag,
"avatar_url": system.avatar_url,
"created": system.created.isoformat(),
"members": [
{
"name": member.name,
"id": member.hid,
"color": member.color,
"avatar_url": member.avatar_url,
"birthday": member.birthday.isoformat() if member.birthday else None,
"pronouns": member.pronouns,
"description": member.description,
"prefix": member.prefix,
"suffix": member.suffix,
"created": member.created.isoformat(),
"message_count": await member.message_count(ctx.conn)
} for member in members
],
"accounts": [str(uid) for uid in accounts],
"switches": [
{
"timestamp": switch.timestamp.isoformat(),
"members": [member.hid for member in await switch.fetch_members(ctx.conn)]
} for switch in switches
] # TODO: messages
}
await working_msg.delete()
f = io.BytesIO(json.dumps(data).encode("utf-8"))
await ctx.reply_ok("DM'd!")
await ctx.message.author.send(content="Here you go!", file=discord.File(fp=f, filename="pluralkit_system.json"))
async def tell(ctx: CommandContext):
# Dev command only
# This is used to tell members of servers I'm not in when something is broken so they can contact me with debug info
if ctx.message.author.id != 102083498529026048:
# Just silently fail, not really a public use command
return
channel = ctx.pop_str()
message = ctx.remaining()
# lol error handling
await ctx.client.get_channel(int(channel)).send(content="[dev message] " + message)
await ctx.reply_ok("Sent!")
# Easter eggs lmao because why not
async def pkfire(ctx: CommandContext):
await ctx.message.channel.send("*A giant lightning bolt promptly erupts into a pillar of fire as it hits your opponent.*")
async def pkthunder(ctx: CommandContext):
await ctx.message.channel.send("*A giant ball of lightning is conjured and fired directly at your opponent, vanquishing them.*")
async def pkfreeze(ctx: CommandContext):
await ctx.message.channel.send("*A giant crystal ball of ice is charged and hurled toward your opponent, bursting open and freezing them solid on contact.*")
async def pkstarstorm(ctx: CommandContext):
await ctx.message.channel.send("*Vibrant colours burst forth from the sky as meteors rain down upon your opponent.*")

View File

@ -1,21 +0,0 @@
from pluralkit.bot.commands import *
async def set_log(ctx: CommandContext):
if not ctx.message.author.guild_permissions.administrator:
raise CommandError("You must be a server administrator to use this command.")
server = ctx.message.guild
if not server:
raise CommandError("This command can not be run in a DM.")
if not ctx.has_next():
channel_id = None
else:
channel = utils.parse_channel_mention(ctx.pop_str(), server=server)
if not channel:
raise CommandError("Channel not found.")
channel_id = channel.id
await db.update_server(ctx.conn, server.id, logging_channel_id=channel_id)
await ctx.reply_ok("Updated logging channel." if channel_id else "Cleared logging channel.")

View File

@ -1,156 +0,0 @@
from datetime import datetime
from typing import List
import dateparser
import pytz
from pluralkit.bot.commands import *
from pluralkit.member import Member
from pluralkit.utils import display_relative
async def switch_root(ctx: CommandContext):
if not ctx.has_next():
raise CommandError("You must use a subcommand. For a list of subcommands, type `pk;help member`.")
if ctx.match("out"):
await switch_out(ctx)
elif ctx.match("move"):
await switch_move(ctx)
elif ctx.match("delete") or ctx.match("remove") or ctx.match("erase") or ctx.match("cancel"):
await switch_delete(ctx)
else:
await switch_member(ctx)
async def switch_member(ctx: CommandContext):
system = await ctx.ensure_system()
if not ctx.has_next():
raise CommandError("You must pass at least one member name or ID to register a switch to.")
members: List[Member] = []
while ctx.has_next():
members.append(await ctx.pop_member())
# Log the switch
await system.add_switch(ctx.conn, members)
if len(members) == 1:
await ctx.reply_ok("Switch registered. Current fronter is now {}.".format(members[0].name))
else:
await ctx.reply_ok(
"Switch registered. Current fronters are now {}.".format(", ".join([m.name for m in members])))
async def switch_out(ctx: CommandContext):
system = await ctx.ensure_system()
switch = await system.get_latest_switch(ctx.conn)
if switch and not switch.members:
raise CommandError("There's already no one in front.")
# Log it, and don't log any members
await system.add_switch(ctx.conn, [])
await ctx.reply_ok("Switch-out registered.")
async def switch_delete(ctx: CommandContext):
system = await ctx.ensure_system()
last_two_switches = await system.get_switches(ctx.conn, 2)
if not last_two_switches:
raise CommandError("You do not have a logged switch to delete.")
last_switch = last_two_switches[0]
next_last_switch = last_two_switches[1] if len(last_two_switches) > 1 else None
last_switch_members = ", ".join([member.name for member in await last_switch.fetch_members(ctx.conn)])
last_switch_time = display_relative(last_switch.timestamp)
if next_last_switch:
next_last_switch_members = ", ".join([member.name for member in await next_last_switch.fetch_members(ctx.conn)])
next_last_switch_time = display_relative(next_last_switch.timestamp)
msg = await ctx.reply_warn("This will delete the latest switch ({}, {} ago). The next latest switch is {} ({} ago). Is this okay?".format(last_switch_members, last_switch_time, next_last_switch_members, next_last_switch_time))
else:
msg = await ctx.reply_warn("This will delete the latest switch ({}, {} ago). You have no other switches logged. Is this okay?".format(last_switch_members, last_switch_time))
if not await ctx.confirm_react(ctx.message.author, msg):
raise CommandError("Switch deletion cancelled.")
await last_switch.delete(ctx.conn)
if next_last_switch:
# lol block scope amirite
# but yeah this is fine
await ctx.reply_ok("Switch deleted. Next latest switch is now {} ({} ago).".format(next_last_switch_members, next_last_switch_time))
else:
await ctx.reply_ok("Switch deleted. You now have no logged switches.")
async def switch_move(ctx: CommandContext):
system = await ctx.ensure_system()
if not ctx.has_next():
raise CommandError("You must pass a time to move the switch to.")
# Parse the time to move to
new_time = dateparser.parse(ctx.remaining(), languages=["en"], settings={
# Tell it to default to the system's given time zone
# If no time zone was given *explicitly in the string* it'll return as naive
"TIMEZONE": system.ui_tz
})
if not new_time:
raise CommandError("'{}' can't be parsed as a valid time.".format(ctx.remaining()))
tz = pytz.timezone(system.ui_tz)
# So we default to putting the system's time zone in the tzinfo
if not new_time.tzinfo:
new_time = tz.localize(new_time)
# Now that we have a system-time datetime, convert this to UTC and make it naive since that's what we deal with
new_time = pytz.utc.normalize(new_time).replace(tzinfo=None)
# Make sure the time isn't in the future
if new_time > datetime.utcnow():
raise CommandError("Can't move switch to a time in the future.")
# Make sure it all runs in a big transaction for atomicity
async with ctx.conn.transaction():
# Get the last two switches to make sure the switch to move isn't before the second-last switch
last_two_switches = await system.get_switches(ctx.conn, 2)
if len(last_two_switches) == 0:
raise CommandError("There are no registered switches for this system.")
last_switch = last_two_switches[0]
if len(last_two_switches) > 1:
second_last_switch = last_two_switches[1]
if new_time < second_last_switch.timestamp:
time_str = display_relative(second_last_switch.timestamp)
raise CommandError(
"Can't move switch to before last switch time ({} ago), as it would cause conflicts.".format(time_str))
# Display the confirmation message w/ humanized times
last_fronters = await last_switch.fetch_members(ctx.conn)
members = ", ".join([member.name for member in last_fronters]) or "nobody"
last_absolute = ctx.format_time(last_switch.timestamp)
last_relative = display_relative(last_switch.timestamp)
new_absolute = ctx.format_time(new_time)
new_relative = display_relative(new_time)
# Confirm with user
switch_confirm_message = await ctx.reply(
"This will move the latest switch ({}) from {} ({} ago) to {} ({} ago). Is this OK?".format(members,
last_absolute,
last_relative,
new_absolute,
new_relative))
if not await ctx.confirm_react(ctx.message.author, switch_confirm_message):
raise CommandError("Switch move cancelled.")
# Actually move the switch
await last_switch.move(ctx.conn, new_time)
await ctx.reply_ok("Switch moved.")

View File

@ -1,443 +0,0 @@
from datetime import datetime, timedelta
import aiohttp
import dateparser
import humanize
import math
import timezonefinder
import pytz
import pluralkit.bot.embeds
from pluralkit.bot.commands import *
from pluralkit.errors import ExistingSystemError, UnlinkingLastAccountError, AccountAlreadyLinkedError
from pluralkit.utils import display_relative
# This needs to load from the timezone file so we're preloading this so we
# don't have to do it on every invocation
tzf = timezonefinder.TimezoneFinder()
async def system_root(ctx: CommandContext):
# Commands that operate without a specified system (usually defaults to the executor's own system)
if ctx.match("name") or ctx.match("rename"):
await system_name(ctx)
elif ctx.match("description") or ctx.match("desc"):
await system_description(ctx)
elif ctx.match("avatar") or ctx.match("icon"):
await system_avatar(ctx)
elif ctx.match("tag"):
await system_tag(ctx)
elif ctx.match("new") or ctx.match("register") or ctx.match("create") or ctx.match("init"):
await system_new(ctx)
elif ctx.match("delete") or ctx.match("remove") or ctx.match("destroy") or ctx.match("erase"):
await system_delete(ctx)
elif ctx.match("front") or ctx.match("fronter") or ctx.match("fronters"):
await system_fronter(ctx, await ctx.ensure_system())
elif ctx.match("fronthistory"):
await system_fronthistory(ctx, await ctx.ensure_system())
elif ctx.match("frontpercent") or ctx.match("frontbreakdown") or ctx.match("frontpercentage"):
await system_frontpercent(ctx, await ctx.ensure_system())
elif ctx.match("timezone") or ctx.match("tz"):
await system_timezone(ctx)
elif ctx.match("set"):
await system_set(ctx)
elif ctx.match("list") or ctx.match("members"):
await system_list(ctx, await ctx.ensure_system())
elif not ctx.has_next():
# (no argument, command ends here, default to showing own system)
await system_info(ctx, await ctx.ensure_system())
else:
# If nothing matches, the next argument is likely a system name/ID, so delegate
# to the specific system root
await specified_system_root(ctx)
async def specified_system_root(ctx: CommandContext):
# Commands that operate on a specified system (ie. not necessarily the command executor's)
system_name = ctx.pop_str()
system = await utils.get_system_fuzzy(ctx.conn, ctx.client, system_name)
if not system:
raise CommandError(
"Unable to find system `{}`. If you meant to run a command, type `pk;help system` for a list of system commands.".format(
system_name))
if ctx.match("front") or ctx.match("fronter"):
await system_fronter(ctx, system)
elif ctx.match("fronthistory"):
await system_fronthistory(ctx, system)
elif ctx.match("frontpercent") or ctx.match("frontbreakdown") or ctx.match("frontpercentage"):
await system_frontpercent(ctx, system)
elif ctx.match("list") or ctx.match("members"):
await system_list(ctx, system)
else:
await system_info(ctx, system)
async def system_info(ctx: CommandContext, system: System):
this_system = await ctx.get_system()
await ctx.reply(embed=await pluralkit.bot.embeds.system_card(ctx.conn, ctx.client, system, this_system and this_system.id == system.id))
async def system_new(ctx: CommandContext):
new_name = ctx.remaining() or None
try:
await System.create_system(ctx.conn, ctx.message.author.id, new_name)
except ExistingSystemError as e:
raise CommandError(e.message)
await ctx.reply_ok("System registered! To begin adding members, use `pk;member new <name>`.")
async def system_set(ctx: CommandContext):
raise CommandError(
"`pk;system set` has been retired. Please use the new system modifying commands. Type `pk;help system` for a list.")
async def system_name(ctx: CommandContext):
system = await ctx.ensure_system()
new_name = ctx.remaining() or None
await system.set_name(ctx.conn, new_name)
await ctx.reply_ok("System name {}.".format("updated" if new_name else "cleared"))
async def system_description(ctx: CommandContext):
system = await ctx.ensure_system()
new_description = ctx.remaining() or None
await system.set_description(ctx.conn, new_description)
await ctx.reply_ok("System description {}.".format("updated" if new_description else "cleared"))
async def system_timezone(ctx: CommandContext):
system = await ctx.ensure_system()
city_query = ctx.remaining() or None
msg = await ctx.reply("\U0001F50D Searching '{}' (may take a while)...".format(city_query))
# Look up the city on Overpass (OpenStreetMap)
async with aiohttp.ClientSession() as sess:
# OverpassQL is weird, but this basically searches for every node of type city with name [input].
async with sess.get("https://nominatim.openstreetmap.org/search?city=novosibirsk&format=json&limit=1", params={"city": city_query, "format": "json", "limit": "1"}) as r:
if r.status != 200:
raise CommandError("OSM Nominatim API returned error. Try again.")
data = await r.json()
# If we didn't find a city, complain
if not data:
raise CommandError("City '{}' not found.".format(city_query))
# Take the lat/long given by Overpass and put it into timezonefinder
lat, lng = (float(data[0]["lat"]), float(data[0]["lon"]))
timezone_name = tzf.timezone_at(lng=lng, lat=lat)
# Also delete the original searching message
await msg.delete()
if not timezone_name:
raise CommandError("Time zone for city '{}' not found. This should never happen.".format(data[0]["display_name"]))
# This should hopefully result in a valid time zone name
# (if not, something went wrong)
tz = await system.set_time_zone(ctx.conn, timezone_name)
offset = tz.utcoffset(datetime.utcnow())
offset_str = "UTC{:+02d}:{:02d}".format(int(offset.total_seconds() // 3600), int(offset.total_seconds() // 60 % 60))
await ctx.reply_ok("System time zone set to {} ({}, {}).\n*Data from OpenStreetMap, queried using Nominatim.*".format(tz.tzname(datetime.utcnow()), offset_str, tz.zone))
async def system_tag(ctx: CommandContext):
system = await ctx.ensure_system()
new_tag = ctx.remaining() or None
await system.set_tag(ctx.conn, new_tag)
await ctx.reply_ok("System tag {}.".format("updated" if new_tag else "cleared"))
# System class is immutable, update the tag so get_member_name_limit works
system = system._replace(tag=new_tag)
members = await system.get_members(ctx.conn)
# Certain members might not be able to be proxied with this new tag, show a warning for those
members_exceeding = [member for member in members if
len(member.name) > system.get_member_name_limit()]
if members_exceeding:
member_names = ", ".join([member.name for member in members_exceeding])
await ctx.reply_warn(
"Due to the length of this tag, the following members will not be able to be proxied: {}. Please use a shorter tag to prevent this.".format(
member_names))
# Edge case: members with name length 1 and no new tag
if not new_tag:
one_length_members = [member for member in members if len(member.name) == 1]
if one_length_members:
member_names = ", ".join([member.name for member in one_length_members])
await ctx.reply_warn(
"Without a system tag, you will not be able to proxy members with a one-character name: {}. To prevent this, please add a system tag or lengthen their name.".format(
member_names))
async def system_avatar(ctx: CommandContext):
system = await ctx.ensure_system()
new_avatar_url = ctx.remaining() or None
if new_avatar_url:
user = await utils.parse_mention(ctx.client, new_avatar_url)
if user:
new_avatar_url = user.avatar_url_as(format="png")
await system.set_avatar(ctx.conn, new_avatar_url)
await ctx.reply_ok("System avatar {}.".format("updated" if new_avatar_url else "cleared"))
async def account_link(ctx: CommandContext):
system = await ctx.ensure_system()
account_name = ctx.pop_str(CommandError(
"You must pass an account to link this system to. You can either use a \\@mention, or a raw account ID."))
# Do the sanity checking here too (despite it being done in System.link_account)
# Because we want it to be done before the confirmation dialog is shown
# Find account to link
linkee = await utils.parse_mention(ctx.client, account_name)
if not linkee:
raise CommandError("Account `{}` not found.".format(account_name))
# Make sure account doesn't already have a system
account_system = await System.get_by_account(ctx.conn, linkee.id)
if account_system:
raise CommandError(AccountAlreadyLinkedError(account_system).message)
msg = await ctx.reply(
"{}, please confirm the link by clicking the \u2705 reaction on this message.".format(linkee.mention))
if not await ctx.confirm_react(linkee, msg):
raise CommandError("Account link cancelled.")
await system.link_account(ctx.conn, linkee.id)
await ctx.reply_ok("Account linked to system.")
async def account_unlink(ctx: CommandContext):
system = await ctx.ensure_system()
msg = await ctx.reply("Are you sure you want to unlink this account from your system?")
if not await ctx.confirm_react(ctx.message.author, msg):
raise CommandError("Account unlink cancelled.")
try:
await system.unlink_account(ctx.conn, ctx.message.author.id)
except UnlinkingLastAccountError as e:
raise CommandError(e.message)
await ctx.reply_ok("Account unlinked.")
async def system_fronter(ctx: CommandContext, system: System):
embed = await embeds.front_status(ctx, await system.get_latest_switch(ctx.conn))
await ctx.reply(embed=embed)
async def system_fronthistory(ctx: CommandContext, system: System):
lines = []
front_history = await pluralkit.utils.get_front_history(ctx.conn, system.id, count=10)
if not front_history:
raise CommandError("You have no logged switches. Use `pk;switch´ to start logging.")
for i, (timestamp, members) in enumerate(front_history):
# Special case when no one's fronting
if len(members) == 0:
name = "(no fronter)"
else:
name = ", ".join([member.name for member in members])
# Make proper date string
time_text = ctx.format_time(timestamp)
rel_text = display_relative(timestamp)
delta_text = ""
if i > 0:
last_switch_time = front_history[i - 1][0]
delta_text = ", for {}".format(display_relative(timestamp - last_switch_time))
lines.append("**{}** ({}, {} ago{})".format(name, time_text, rel_text, delta_text))
embed = embeds.status("\n".join(lines) or "(none)")
embed.title = "Past switches"
await ctx.reply(embed=embed)
async def system_delete(ctx: CommandContext):
system = await ctx.ensure_system()
delete_confirm_msg = "Are you sure you want to delete your system? If so, reply to this message with the system's ID (`{}`).".format(
system.hid)
if not await ctx.confirm_text(ctx.message.author, ctx.message.channel, system.hid, delete_confirm_msg):
raise CommandError("System deletion cancelled.")
await system.delete(ctx.conn)
await ctx.reply_ok("System deleted.")
async def system_frontpercent(ctx: CommandContext, system: System):
# Parse the time limit (will go this far back)
if ctx.remaining():
before = dateparser.parse(ctx.remaining(), languages=["en"], settings={
"TO_TIMEZONE": "UTC",
"RETURN_AS_TIMEZONE_AWARE": False
})
if not before:
raise CommandError("Could not parse '{}' as a valid time.".format(ctx.remaining()))
# If time is in the future, just kinda discard
if before and before > datetime.utcnow():
before = None
else:
before = datetime.utcnow() - timedelta(days=30)
# Fetch list of switches
all_switches = await pluralkit.utils.get_front_history(ctx.conn, system.id, 99999)
if not all_switches:
raise CommandError("No switches registered to this system.")
# Cull the switches *ending* before the limit, if given
# We'll need to find the first switch starting before the limit, then cut off every switch *before* that
if before:
for last_stamp, _ in all_switches:
if last_stamp < before:
break
all_switches = [(stamp, members) for stamp, members in all_switches if stamp >= last_stamp]
start_times = [stamp for stamp, _ in all_switches]
end_times = [datetime.utcnow()] + start_times
switch_members = [members for _, members in all_switches]
# Gonna save a list of members by ID for future lookup too
members_by_id = {}
# Using the ID as a key here because it's a simple number that can be hashed and used as a key
member_times = {}
for start_time, end_time, members in zip(start_times, end_times, switch_members):
# Cut off parts of the switch that occurs before the time limit (will only happen if this is the last switch)
if before and start_time < before:
start_time = before
# Calculate length of the switch
switch_length = end_time - start_time
def add_switch(id, length):
if id not in member_times:
member_times[id] = length
else:
member_times[id] += length
for member in members:
# Add the switch length to the currently registered time for that member
add_switch(member.id, switch_length)
# Also save the member in the ID map for future reference
members_by_id[member.id] = member
# Also register a no-fronter switch with the key None
if not members:
add_switch(None, switch_length)
# Find the total timespan of the range
span_start = max(start_times[-1], before) if before else start_times[-1]
total_time = datetime.utcnow() - span_start
embed = embeds.status("")
for member_id, front_time in sorted(member_times.items(), key=lambda x: x[1], reverse=True):
member = members_by_id[member_id] if member_id else None
# Calculate percent
fraction = front_time / total_time
percent = round(fraction * 100)
embed.add_field(name=member.name if member else "(no fronter)",
value="{}% ({})".format(percent, humanize.naturaldelta(front_time)))
embed.set_footer(text="Since {} ({} ago)".format(ctx.format_time(span_start),
display_relative(span_start)))
await ctx.reply(embed=embed)
async def system_list(ctx: CommandContext, system: System):
# TODO: refactor this
all_members = sorted(await system.get_members(ctx.conn), key=lambda m: m.name.lower())
if ctx.match("full"):
page_size = 8
if len(all_members) <= page_size:
# If we have less than 8 members, don't bother paginating
await ctx.reply(embed=embeds.member_list_full(system, all_members, 0, page_size))
else:
current_page = 0
msg: discord.Message = None
while True:
page_count = math.ceil(len(all_members) / page_size)
embed = embeds.member_list_full(system, all_members, current_page, page_size)
# Add reactions for moving back and forth
if not msg:
msg = await ctx.reply(embed=embed)
await msg.add_reaction("\u2B05")
await msg.add_reaction("\u27A1")
else:
await msg.edit(embed=embed)
def check(reaction, user):
return user.id == ctx.message.author.id and reaction.emoji in ["\u2B05", "\u27A1"]
try:
reaction, _ = await ctx.client.wait_for("reaction_add", timeout=5*60, check=check)
except asyncio.TimeoutError:
return
if reaction.emoji == "\u2B05":
current_page = (current_page - 1) % page_count
elif reaction.emoji == "\u27A1":
current_page = (current_page + 1) % page_count
# If we can, remove the original reaction from the member
# Don't bother checking permission if we're in DMs (wouldn't work anyway)
if ctx.message.guild:
if ctx.message.channel.permissions_for(ctx.message.guild.get_member(ctx.client.user.id)).manage_messages:
await reaction.remove(ctx.message.author)
else:
#Basically same code as above
#25 members at a time seems handy
page_size = 25
if len(all_members) <= page_size:
# If we have less than 25 members, don't bother paginating
await ctx.reply(embed=embeds.member_list_short(system, all_members, 0, page_size))
else:
current_page = 0
msg: discord.Message = None
while True:
page_count = math.ceil(len(all_members) / page_size)
embed = embeds.member_list_short(system, all_members, current_page, page_size)
if not msg:
msg = await ctx.reply(embed=embed)
await msg.add_reaction("\u2B05")
await msg.add_reaction("\u27A1")
else:
await msg.edit(embed=embed)
def check(reaction, user):
return user.id == ctx.message.author.id and reaction.emoji in ["\u2B05", "\u27A1"]
try:
reaction, _ = await ctx.client.wait_for("reaction_add", timeout=5*60, check=check)
except asyncio.TimeoutError:
return
if reaction.emoji == "\u2B05":
current_page = (current_page - 1) % page_count
elif reaction.emoji == "\u27A1":
current_page = (current_page + 1) % page_count
if ctx.message.guild:
if ctx.message.channel.permissions_for(ctx.message.guild.get_member(ctx.client.user.id)).manage_messages:
await reaction.remove(ctx.message.author)

View File

@ -1,285 +0,0 @@
import discord
import math
import humanize
from typing import Tuple, List
from pluralkit import db
from pluralkit.bot.utils import escape
from pluralkit.member import Member
from pluralkit.switch import Switch
from pluralkit.system import System
from pluralkit.utils import get_fronters, display_relative
def truncate_field_name(s: str) -> str:
return s[:256]
def truncate_field_body(s: str) -> str:
if len(s) > 1024:
return s[:1024-3] + "..."
return s
def truncate_description(s: str) -> str:
return s[:2048]
def truncate_description_list(s: str) -> str:
if len(s) > 512:
return s[:512-45] + "..."
return s
def truncate_title(s: str) -> str:
return s[:256]
def success(text: str) -> discord.Embed:
embed = discord.Embed()
embed.description = truncate_description(text)
embed.colour = discord.Colour.green()
return embed
def error(text: str, help: Tuple[str, str] = None) -> discord.Embed:
embed = discord.Embed()
embed.description = truncate_description(text)
embed.colour = discord.Colour.dark_red()
if help:
help_title, help_text = help
embed.add_field(name=truncate_field_name(help_title), value=truncate_field_body(help_text))
return embed
def status(text: str) -> discord.Embed:
embed = discord.Embed()
embed.description = truncate_description(text)
embed.colour = discord.Colour.blue()
return embed
def exception_log(message_content, author_name, author_discriminator, author_id, server_id,
channel_id) -> discord.Embed:
embed = discord.Embed()
embed.colour = discord.Colour.dark_red()
embed.title = truncate_title(message_content)
embed.set_footer(text="Sender: {}#{} ({}) | Server: {} | Channel: {}".format(
author_name, author_discriminator, author_id,
server_id if server_id else "(DMs)",
channel_id
))
return embed
async def system_card(conn, client: discord.Client, system: System, is_own_system: bool = True) -> discord.Embed:
card = discord.Embed()
card.colour = discord.Colour.blue()
if system.name:
card.title = truncate_title(system.name)
if system.avatar_url:
card.set_thumbnail(url=system.avatar_url)
if system.tag:
card.add_field(name="Tag", value=truncate_field_body(system.tag))
fronters, switch_time = await get_fronters(conn, system.id)
if fronters:
names = ", ".join([member.name for member in fronters])
fronter_val = "{} (for {})".format(names, humanize.naturaldelta(switch_time))
card.add_field(name="Current fronter" if len(fronters) == 1 else "Current fronters",
value=truncate_field_body(fronter_val))
account_names = []
for account_id in await system.get_linked_account_ids(conn):
try:
account = await client.get_user_info(account_id)
account_names.append("<@{}> ({}#{})".format(account_id, account.name, account.discriminator))
except discord.NotFound:
account_names.append("(deleted account {})".format(account_id))
card.add_field(name="Linked accounts", value=truncate_field_body("\n".join(account_names)))
if system.description:
card.add_field(name="Description",
value=truncate_field_body(system.description), inline=False)
card.add_field(name="Members", value="*See `pk;system {0} list`for the short list, or `pk;system {0} list full` for the detailed list*".format(system.hid) if not is_own_system else "*See `pk;system list` for the short list, or `pk;system list full` for the detailed list*")
card.set_footer(text="System ID: {}".format(system.hid))
return card
async def member_card(conn, member: Member) -> discord.Embed:
system = await member.fetch_system(conn)
card = discord.Embed()
card.colour = discord.Colour.blue()
name_and_system = member.name
if system.name:
name_and_system += " ({})".format(system.name)
card.set_author(name=truncate_field_name(name_and_system), icon_url=member.avatar_url or discord.Embed.Empty)
if member.avatar_url:
card.set_thumbnail(url=member.avatar_url)
if member.color:
card.colour = int(member.color, 16)
if member.birthday:
card.add_field(name="Birthdate", value=member.birthday_string())
if member.pronouns:
card.add_field(name="Pronouns", value=truncate_field_body(member.pronouns))
message_count = await member.message_count(conn)
if message_count > 0:
card.add_field(name="Message Count", value=str(message_count), inline=True)
if member.prefix or member.suffix:
prefix = member.prefix or ""
suffix = member.suffix or ""
card.add_field(name="Proxy Tags",
value=truncate_field_body("{}text{}".format(prefix, suffix)))
if member.description:
card.add_field(name="Description",
value=truncate_field_body(member.description), inline=False)
card.set_footer(text="System ID: {} | Member ID: {}".format(system.hid, member.hid))
return card
async def front_status(ctx: "CommandContext", switch: Switch) -> discord.Embed:
if switch:
embed = status("")
fronter_names = [member.name for member in await switch.fetch_members(ctx.conn)]
if len(fronter_names) == 0:
embed.add_field(name="Current fronter", value="(no fronter)")
elif len(fronter_names) == 1:
embed.add_field(name="Current fronter", value=truncate_field_body(fronter_names[0]))
else:
embed.add_field(name="Current fronters", value=truncate_field_body(", ".join(fronter_names)))
if switch.timestamp:
embed.add_field(name="Since",
value="{} ({})".format(ctx.format_time(switch.timestamp),
display_relative(switch.timestamp)))
else:
embed = error("No switches logged.")
return embed
async def get_message_contents(client: discord.Client, channel_id: int, message_id: int):
channel = client.get_channel(channel_id)
if channel:
try:
original_message = await channel.get_message(message_id)
return original_message.content or None
except (discord.errors.Forbidden, discord.errors.NotFound):
pass
return None
async def message_card(client: discord.Client, message: db.MessageInfo, include_pronouns: bool = False):
# Get the original sender of the messages
try:
original_sender = await client.get_user_info(message.sender)
except discord.NotFound:
# Account was since deleted - rare but we're handling it anyway
original_sender = None
embed = discord.Embed()
embed.timestamp = discord.utils.snowflake_time(message.mid)
embed.colour = discord.Colour.blue()
if message.system_name:
system_value = "{} (`{}`)".format(message.system_name, message.system_hid)
else:
system_value = "`{}`".format(message.system_hid)
embed.add_field(name="System", value=system_value)
if include_pronouns and message.pronouns:
embed.add_field(name="Member", value="{} (`{}`)\n*(pronouns: **{}**)*".format(message.name, message.hid, message.pronouns))
else:
embed.add_field(name="Member", value="{} (`{}`)".format(message.name, message.hid))
if original_sender:
sender_name = "<@{}> ({}#{})".format(message.sender, original_sender.name, original_sender.discriminator)
else:
sender_name = "(deleted account {})".format(message.sender)
embed.add_field(name="Sent by", value=sender_name)
message_content = await get_message_contents(client, message.channel, message.mid)
embed.description = message_content or "(unknown, message deleted)"
embed.set_author(name=message.name, icon_url=message.avatar_url or discord.Embed.Empty)
return embed
def help_footer_embed() -> discord.Embed:
embed = discord.Embed()
embed.set_footer(text="By @Ske#6201 | GitHub: https://github.com/xSke/PluralKit/")
return embed
# TODO: merge these somehow, they're very similar
def member_list_short(system: System, all_members: List[Member], current_page: int, page_size: int):
page_count = int(math.ceil(len(all_members) / page_size))
title = ""
if len(all_members) > page_size:
title += "[{}/{}] ".format(current_page + 1, page_count)
if system.name:
title += "Members of {} (`{}`)".format(system.name, system.hid)
else:
title += "Members of `{}`".format(system.hid)
embed = discord.Embed()
embed.title = title
desc = ""
for member in all_members[current_page*page_size:current_page*page_size+page_size]:
if member.prefix or member.suffix:
desc += "[`{}`] **{}** *({}text{})*\n".format(member.hid, member.name, member.prefix or "", member.suffix or "")
else:
desc += "[`{}`] **{}**\n".format(member.hid, member.name)
embed.description = desc
return embed
def member_list_full(system: System, all_members: List[Member], current_page: int, page_size: int):
page_count = int(math.ceil(len(all_members) / page_size))
title = ""
if len(all_members) > page_size:
title += "[{}/{}] ".format(current_page + 1, page_count)
if system.name:
title += "Members of {} (`{}`)".format(system.name, system.hid)
else:
title += "Members of `{}`".format(system.hid)
embed = discord.Embed()
embed.title = title
for member in all_members[current_page*page_size:current_page*page_size+page_size]:
member_description = "**ID**: {}\n".format(member.hid)
if member.birthday:
member_description += "**Birthday:** {}\n".format(member.birthday_string())
if member.pronouns:
member_description += "**Pronouns:** {}\n".format(member.pronouns)
if member.description:
if len(member.description) > 512:
member_description += "\n" + truncate_description_list(member.description) + "\n" + "Type `pk;member {}` for full description.".format(member.hid)
else:
member_description += "\n" + member.description
embed.add_field(name=member.name, value=truncate_field_body(member_description) or "\u200B", inline=False)
return embed

View File

@ -1,336 +0,0 @@
{
"commands": [
{
"name": "system",
"aliases": ["s"],
"usage": "system [id]",
"description": "Shows information about a system.",
"longdesc": "The given ID can either be a 5-character ID, a Discord account @mention, or a Discord account ID. Leave blank to show your own system.",
"examples": ["system", "system abcde", "system @Foo#1234", "system 102083498529026048"],
"category": "System",
"subcommands": [
{
"name": "new",
"aliases": ["system register", "system create", "system init"],
"usage": "system new [name]",
"category": "System",
"description": "Creates a new system registered to your account."
},
{
"name": "name",
"alises": ["system rename"],
"usage": "system name [name]",
"category": "System",
"description": "Changes the name of your system."
},
{
"name": "description",
"aliases": ["system desc"],
"usage": "system description [description]",
"category": "System",
"description": "Changes the description of your system."
},
{
"name": "avatar",
"aliases": ["system icon"],
"usage": "system avatar [avatar url]",
"category": "System",
"description": "Changes the avatar of your system.",
"longdesc": "**NB:** Avatar URLs must be a *direct* link to an image (ending in .jpg, .gif or .png), AND must be under the size of 1000x1000 (in both dimensions), AND must be smaller than 1 MB. If the avatar doesn't show up properly, it is likely one or more of these rules aren't followed. If you need somewhere to host an image, you can upload it to Discord or Imgur and copy the *direct* link from there.",
"examples": ["system avatar https://i.imgur.com/HmK2Wgo.png"]
},
{
"name": "tag",
"usage": "system tag [tag]",
"category": "System",
"description": "Changes the system tag of your system.",
"longdesc": "The system tag is a snippet of text added to the end of your member's names when proxying. Many servers require the use of a system tag for identification. Leave blank to clear.\n\n**NB:** You may use standard Discord emojis, but server/Nitro emojis won't work.",
"examples": ["system tag |ABC", "system tag 💮", "system tag"]
},
{
"name": "timezone",
"usage": "system timezone [location]",
"category": "System",
"description": "Changes the time zone of your system.",
"longdesc": "This affects all dates or times displayed in PluralKit. Leave blank to clear.\n\n**NB:** You need to specify a location (eg. the nearest major city to you). This allows PluralKit to dynamically adjust for time zone or DST changes.",
"examples": ["system timezone New York", "system timezone Wichita Falls", "system timezone"]
},
{
"name": "delete",
"aliases": ["system remove", "system destroy", "system erase"],
"usage": "system delete",
"category": "System",
"description": "Deletes your system.",
"longdesc": "The command will ask for confirmation.\n\n**This is irreversible, and will delete all information associated with your system, members, proxied messages, and accounts.**"
},
{
"name": "fronter",
"aliases": ["system front", "system fronters"],
"usage": "system [id] fronter",
"category": "System",
"description": "Shows the current fronter of a system."
},
{
"name": "fronthistory",
"usage": "system [id] fronthistory",
"category": "System",
"description": "Shows the last 10 switches of a system."
},
{
"name": "frontpercent",
"aliases": ["system frontbreakdown", "system frontpercentage"],
"usage": "system [id] fronthistory [timeframe]",
"category": "System",
"description": "Shows the aggregated front history of a system within a given time frame.",
"longdesc": "Percentages may add up to over 100% when multiple members cofront. Time frame will default to 1 month.",
"examples": ["system fronthistory 1 month", "system fronthistory 2 weeks", "system @Foo#1234 fronthistory 4 days"]
},
{
"name": "list",
"aliases": ["system members"],
"usage": "system [id] list [full]",
"category": "System",
"description": "Shows a paginated list of a system's members. Add 'full' for more details.",
"examples": ["system list", "system list full", "system 102083498529026048 list"]
}
]
},
{
"name": "link",
"usage": "link <account>",
"category": "System",
"description": "Links this system to a different account.",
"longdesc": "This means you can manage the system from both accounts. The other account will need to verify the link by reacting to a message.",
"examples": ["link @Foo#1234", "link 102083498529026048"]
},
{
"name": "unlink",
"usage": "unlink",
"category": "System",
"description": "Unlinks this account from its system.",
"longdesc": "You can't unlink the only account in a system."
},
{
"name": "member",
"aliases": ["m"],
"usage": "member <name>",
"category": "Member",
"description": "Shows information about a member.",
"longdesc": "The given member name can either be the name of a member in your own system or a 5-character member ID (in any system).",
"examples": ["member John", "member abcde"],
"subcommands": [
{
"name": "new",
"aliases": ["member add", "member create", "member register"],
"usage": "member new <name>",
"category": "Member",
"description": "Creates a new system member.",
"exmaples": ["member new Jack"]
},
{
"name": "rename",
"usage": "member <name> rename <name>",
"category": "Member",
"description": "Changes the name of a member.",
"examples": ["member Jack rename Jill"]
},
{
"name": "description",
"aliases": ["member desc"],
"usage": "member <name> description [description]",
"category": "Member",
"description": "Changes the description of a member.",
"examples": ["member Jack description Very cool guy."]
},
{
"name": "avatar",
"aliases": ["member icon"],
"usage": "member <name> avatar [avatarurl]",
"category": "Member",
"description": "Changes the avatar of a member.",
"longdesc": "**NB:** Avatar URLs must be a *direct* link to an image (ending in .jpg, .gif or .png), AND must be under the size of 1000x1000 (in both dimensions), AND must be smaller than 1 MB. If the avatar doesn't show up properly, it is likely one or more of these rules aren't followed. If you need somewhere to host an image, you can upload it to Discord or Imgur and copy the *direct* link from there.",
"examples": ["member Jack avatar https://i.imgur.com/HmK2Wgo.png"]
},
{
"name": "proxy",
"aliases": ["member tags"],
"usage": "member <name> proxy [tags]",
"category": "Member",
"description": "Changes the proxy tags of a member.",
"longdesc": "The proxy tags describe how to proxy this member through Discord. You must pass an \"example proxy\" of the word \"text\", ie. how you'd proxy the word \"text\". For example, if you want square brackets for this member, pass `[text]`. Emojis are allowed.",
"examples": ["member Jack proxy [text]", "member Jill proxy J:text", "member Jones proxy 🍒text"]
},
{
"name": "pronouns",
"aliases": ["member pronoun"],
"usage": "member <name> pronouns [pronouns]",
"category": "Member",
"description": "Changes the pronouns of a member.",
"longdesc": "These will be displayed on their profile. This is a free text field, put whatever you'd like :)",
"examples": ["member Jack pronouns he/him", "member Jill pronouns she/her or they/them", "member Jones pronouns use whatever lol"]
},
{
"name": "color",
"aliases": ["member colour"],
"usage": "member <name> color [color]",
"category": "Member",
"description": "Changes the color of a member.",
"longdesc": "This will displayed on their profile. Colors must be in hex format (eg. #ff0000).\n\n**NB:** Due to a Discord limitation, the colors don't affect proxied message names.",
"examples": ["member Jack color #ff0000", "member Jill color #abcdef"]
},
{
"name": "birthday",
"aliases": ["member bday", "member birthdate"],
"usage": "member <name> birthday [birthday]",
"category": "Member",
"description": "Changes the birthday of a member.",
"longdesc": "This must be in YYYY-MM-DD format, or just MM-DD if you don't want to specify a year.",
"examples": ["member Jack birthday 1997-03-27", "member Jill birthday 2018-01-03", "member Jones birthday 12-21"]
},
{
"name": "delete",
"aliases": ["member remove", "member destroy", "member erase"],
"usage": "member <name> delete",
"category": "Member",
"description": "Deletes a member.",
"longdesc": "This command will ask for confirmation.\n\n**This is irreversible, and will delete all data associated with this member.**"
}
]
},
{
"name": "switch",
"aliases": ["sw"],
"usage": "switch <member> [member...]",
"category": "Switching",
"description": "Registers a switch with the given members.",
"longdesc": "You may specify multiple members to indicate cofronting.",
"examples": ["switch Jack", "switch Jack Jill"],
"subcommands": [
{
"name": "move",
"usage": "switch move <time>",
"category": "Switching",
"description": "Moves the latest switch back or forwards in time.",
"longdesc": "You can't move a switch into the future, and you can't move a switch further back than the second-latest switch (which would reorder front history).",
"examples": ["switch move 1 day ago", "switch move 4:30 pm"]
},
{
"name": "delete",
"usage": "switch delete",
"category": "Switching",
"description": "Deletes the latest switch. Will ask for confirmation."
},
{
"name": "out",
"usage": "switch out",
"category": "Switching",
"description": "Will register a 'switch-out' - a switch with no associated members."
}
]
},
{
"name": "log",
"usage": "log <channel>",
"category": "Utility",
"description": "Sets a channel to log all proxied messages.",
"longdesc": "This command is restricted to the server administrators (ie. users with the Administrator role).",
"examples": "log #pluralkit-log"
},
{
"name": "message",
"usage": "message <messageid>",
"category": "Utility",
"description": "Looks up information about a message by its message ID.",
"longdesc": " You can obtain a message ID by turning on Developer Mode in Discord's settings, and rightclicking/longpressing on a message.\n\n**Tip:** Reacting to a message with ❓ will DM you this information too.",
"examples": "message 561614629802082304"
},
{
"name": "invite",
"usage": "invite",
"category": "Utility",
"description": "Sends the bot invite link for PluralKit."
},
{
"name": "import",
"usage": "import",
"category": "Utility",
"description": "Imports a .json file from Tupperbox.",
"longdesc": "You will need to type the command, *then* send a new message containing the .json file as an attachment."
},
{
"name": "export",
"usage": "export",
"category": "Utility",
"description": "Exports your system to a .json file.",
"longdesc": "This will respond with a .json file containing your system and member data, useful for importing elsewhere."
},
{
"name": "token",
"usage": "token",
"category": "API",
"description": "DMs you a token for using the PluralKit API.",
"subcommands": [
{
"name": "refresh",
"usage": "token refresh",
"category": "API",
"description": "Refreshes your API token.",
"longdesc": "This will invalide the old token and DM you a new one. Do this if your token leaks in any way."
}
]
},
{
"name": "help",
"usage": "help [command]",
"category": "Help",
"description": "Displays help for a given command.",
"examples": ["help", "help system", "help member avatar", "help switch move"],
"subcommands": [
{
"name": "proxy",
"usage": "help proxy",
"category": "Help",
"description": "Displays a short guide to the proxy functionality."
}
]
},
{
"name": "commands",
"usage": "commands",
"category": "Help",
"description": "Displays a paginated list of commands",
"examples": ["commands", "commands"]
}
],
"pages": {
"root": [
{
"name": "PluralKit",
"content": "PluralKit is a bot designed for plural communities on Discord. It allows you to register systems, maintain system information, set up message proxying, log switches, and more.\n\n**Who's this for? What are systems?**\nPut simply, a system is a person that shares their body with at least 1 other sentient \"self\". This may be a result of having a dissociative disorder like DID/OSDD or a practice known as Tulpamancy, but people that aren't tulpamancers or undiagnosed and have headmates are also systems.\n\n**Why are people's names saying [BOT] next to them? What's going on?**\nThese people are not actually bots, this is simply a caveat to the message proxying feature of PluralKit.\nType `pk;help proxy` for an in-depth explanation."
},
{
"name": "Getting started",
"content": "To get started using the bot, try running the following commands.\n**1**. `pk;system new` - Create a system if you haven't already\n**2**. `pk;member add John` - Add a new member to your system\n**3**. `pk;member John proxy [text]` - Set up square brackets as proxy tags\n**4**. You're done!\n**5**. Optionally, you may set an avatar from the URL of an image with:\n`pk;member John avatar [link to image]`\n\nType `pk;help member` for more information."
},
{
"name": "Useful tips",
"content": "React with ❌ on a proxied message to delete it (if you sent it!).\nReact with ❓ on a proxied message to look up information about it, like who sent it."
},
{
"name": "More information",
"content": "For a full list of commands, type `pk;commands`.\nFor a more in-depth explanation of message proxying, type `pk;help proxy`.\nIf you're an existing user of the Tupperbox proxy bot, type `pk;import` to import your data from there."
},
{
"name": "Support server",
"content": "We also have a Discord server for support, discussion, suggestions, announcements, etc: <https://discord.gg/PczBt78>"
}
],
"proxy": [
{
"name": "Proxying",
"content": "Proxying through PluralKit lets system members have their own faux-account with their name and avatar.\nYou'll type a message from your account in *proxy tags*, and PluralKit will recognize those tags and repost the message with the proper details, with the minor caveat of having the **[BOT]** icon next to the name (this is a Discord limitation and cannot be circumvented).\n\nTo set up a member's proxy tag, use the `pk;member <name> proxy [example match]` command.\n\nYou'll need to give the bot an \"example match\" containing the word `text`. Imagine you're proxying the word \"text\", and add that to the end of the command. For example: `pk;member John proxy [text]`. That will set the member John up to use square brackets as proxy tags. Now saying something like `[hello world]` will proxy the text \"hello world\" with John's name and avatar. You can also use other symbols, letters, numbers, et cetera, as prefixes, suffixes, or both. `J:text`, `$text` and `text]` are also examples of valid example matches."
}
]
},
"footer": "By @Ske#6201 | GitHub: https://github.com/xSke/PluralKit/"
}

View File

@ -1,6 +0,0 @@
import json
import os.path
helpfile = None
with open(os.path.dirname(__file__) + "/help.json", "r") as f:
helpfile = json.load(f)

View File

@ -1,254 +0,0 @@
import asyncio
import re
import discord
from io import BytesIO
from typing import Optional
from pluralkit import db
from pluralkit.bot import utils, channel_logger, embeds
from pluralkit.bot.channel_logger import ChannelLogger
from pluralkit.member import Member
from pluralkit.system import System
class ProxyError(Exception):
pass
async def get_or_create_webhook_for_channel(conn, bot_user: discord.User, channel: discord.TextChannel):
# First, check if we have one saved in the DB
webhook_from_db = await db.get_webhook(conn, channel.id)
if webhook_from_db:
webhook_id, webhook_token = webhook_from_db
session = channel._state.http._session
hook = discord.Webhook.partial(webhook_id, webhook_token, adapter=discord.AsyncWebhookAdapter(session))
return hook
try:
# If not, we check to see if there already exists one we've missed
for existing_hook in await channel.webhooks():
existing_hook_creator = existing_hook.user.id if existing_hook.user else None
is_mine = existing_hook.name == "PluralKit Proxy Webhook" and existing_hook_creator == bot_user.id
if is_mine:
# We found one we made, let's add that to the DB just to be sure
await db.add_webhook(conn, channel.id, existing_hook.id, existing_hook.token)
return existing_hook
# If not, we create one and save it
created_webhook = await channel.create_webhook(name="PluralKit Proxy Webhook")
except discord.Forbidden:
raise ProxyError(
"PluralKit does not have the \"Manage Webhooks\" permission, and thus cannot proxy your message. Please contact a server administrator.")
await db.add_webhook(conn, channel.id, created_webhook.id, created_webhook.token)
return created_webhook
async def make_attachment_file(message: discord.Message):
if not message.attachments:
return None
first_attachment = message.attachments[0]
# Copy the file data to the buffer
# TODO: do this without buffering... somehow
bio = BytesIO()
await first_attachment.save(bio)
return discord.File(bio, first_attachment.filename)
def fix_clyde(name: str) -> str:
# Discord doesn't allow any webhook username to contain the word "Clyde"
# So replace "Clyde" with "C lyde" (except with a hair space, hence \u200A)
# Zero-width spacers are ignored by Discord and will still trigger the error
return re.sub("(c)(lyde)", "\\1\u200A\\2", name, flags=re.IGNORECASE)
async def send_proxy_message(conn, original_message: discord.Message, system: System, member: Member,
inner_text: str, logger: ChannelLogger, bot_user: discord.User):
# Send the message through the webhook
webhook = await get_or_create_webhook_for_channel(conn, bot_user, original_message.channel)
# Bounds check the combined name to avoid silent erroring
full_username = "{} {}".format(member.name, system.tag or "").strip()
full_username = fix_clyde(full_username)
if len(full_username) < 2:
raise ProxyError(
"The webhook's name, `{}`, is shorter than two characters, and thus cannot be proxied. Please change the member name or use a longer system tag.".format(
full_username))
if len(full_username) > 32:
raise ProxyError(
"The webhook's name, `{}`, is longer than 32 characters, and thus cannot be proxied. Please change the member name or use a shorter system tag.".format(
full_username))
try:
sent_message = await webhook.send(
content=inner_text,
username=full_username,
avatar_url=member.avatar_url,
file=await make_attachment_file(original_message),
wait=True
)
except discord.NotFound:
# The webhook we got from the DB doesn't actually exist
# This can happen if someone manually deletes it from the server
# If we delete it from the DB then call the function again, it'll re-create one for us
# (lol, lazy)
await db.delete_webhook(conn, original_message.channel.id)
await send_proxy_message(conn, original_message, system, member, inner_text, logger, bot_user)
return
# Save the proxied message in the database
await db.add_message(conn, sent_message.id, original_message.channel.id, member.id,
original_message.author.id)
# Log it in the log channel if possible
await logger.log_message_proxied(
conn,
original_message.channel.guild.id,
original_message.channel.name,
original_message.channel.id,
original_message.author.name,
original_message.author.discriminator,
original_message.author.id,
member.name,
member.hid,
member.avatar_url,
system.name,
system.hid,
inner_text,
sent_message.attachments[0].url if sent_message.attachments else None,
sent_message.created_at,
sent_message.id
)
# And finally, gotta delete the original.
# We wait half a second or so because if the client receives the message deletion
# event before the message actually gets confirmed sent on their end, the message
# doesn't properly get deleted for them, leading to duplication
try:
await asyncio.sleep(0.5)
await original_message.delete()
except discord.Forbidden:
raise ProxyError(
"PluralKit does not have permission to delete user messages. Please contact a server administrator.")
except discord.NotFound:
# Sometimes some other thing will delete the original message before PK gets to it
# This is not a problem - message gets deleted anyway :)
# Usually happens when Tupperware and PK conflict
pass
async def try_proxy_message(conn, message: discord.Message, logger: ChannelLogger, bot_user: discord.User) -> bool:
# Don't bother proxying in DMs
if isinstance(message.channel, discord.abc.PrivateChannel):
return False
# Get the system associated with the account, if possible
system = await System.get_by_account(conn, message.author.id)
if not system:
return False
# Match on the members' proxy tags
proxy_match = await system.match_proxy(conn, message.content)
if not proxy_match:
return False
member, inner_message = proxy_match
# Make sure no @everyones slip through
# Webhooks implicitly have permission to mention @everyone so we have to enforce that manually
inner_message = utils.sanitize(inner_message)
# If we don't have an inner text OR an attachment, we cancel because the hook can't send that
# Strip so it counts a string of solely spaces as blank too
if not inner_message.strip() and not message.attachments:
return False
# So, we now have enough information to successfully proxy a message
async with conn.transaction():
try:
await send_proxy_message(conn, message, system, member, inner_message, logger, bot_user)
except ProxyError as e:
# First, try to send the error in the channel it was triggered in
# Failing that, send the error in a DM.
# Failing *that*... give up, I guess.
try:
await message.channel.send("\u274c {}".format(str(e)))
except discord.Forbidden:
try:
await message.author.send("\u274c {}".format(str(e)))
except discord.Forbidden:
pass
return True
async def handle_deleted_message(conn, client: discord.Client, message_id: int,
message_content: Optional[str], logger: channel_logger.ChannelLogger) -> bool:
msg = await db.get_message(conn, message_id)
if not msg:
return False
channel = client.get_channel(msg.channel)
if not channel:
# Weird edge case, but channel *could* be deleted at this point (can't think of any scenarios it would be tho)
return False
await db.delete_message(conn, message_id)
await logger.log_message_deleted(
conn,
channel.guild.id,
channel.name,
msg.name,
msg.hid,
msg.avatar_url,
msg.system_name,
msg.system_hid,
message_content,
message_id
)
return True
async def try_delete_by_reaction(conn, client: discord.Client, message_id: int, reaction_user: int,
logger: channel_logger.ChannelLogger) -> bool:
# Find the message by the given message id or reaction user
msg = await db.get_message_by_sender_and_id(conn, message_id, reaction_user)
if not msg:
# Either the wrong user reacted or the message isn't a proxy message
# In either case - not our problem
return False
# Find the original message
original_message = await client.get_channel(msg.channel).get_message(message_id)
if not original_message:
# Message got deleted, possibly race condition, eh
return False
# Then delete the original message
await original_message.delete()
await handle_deleted_message(conn, client, message_id, original_message.content, logger)
async def do_query_message(conn, client: discord.Client, queryer_id: int, message_id: int) -> bool:
# Find the message that was queried
msg = await db.get_message(conn, message_id)
if not msg:
return False
# Then DM the queryer the message embed
card = await embeds.message_card(client, msg, include_pronouns=True)
user = client.get_user(queryer_id)
if not user:
# We couldn't find this user in the cache - bail
return False
# Send the card to the user
try:
await user.send(embed=card)
except discord.Forbidden:
# User doesn't have DMs enabled, not much we can do about that
pass

View File

@ -1,87 +0,0 @@
import discord
import logging
import re
from typing import Optional
from pluralkit import db
from pluralkit.member import Member
from pluralkit.system import System
logger = logging.getLogger("pluralkit.utils")
def escape(s):
return s.replace("`", "\\`")
def bounds_check_member_name(new_name, system_tag):
if len(new_name) > 32:
return "Name cannot be longer than 32 characters."
if system_tag:
if len("{} {}".format(new_name, system_tag)) > 32:
return "This name, combined with the system tag ({}), would exceed the maximum length of 32 characters. Please reduce the length of the tag, or use a shorter name.".format(
system_tag)
async def parse_mention(client: discord.Client, mention: str) -> Optional[discord.User]:
# First try matching mention format
match = re.fullmatch("<@!?(\\d+)>", mention)
if match:
try:
return await client.get_user_info(int(match.group(1)))
except discord.NotFound:
return None
# Then try with just ID
try:
return await client.get_user_info(int(mention))
except (ValueError, discord.NotFound):
return None
def parse_channel_mention(mention: str, server: discord.Guild) -> Optional[discord.TextChannel]:
match = re.fullmatch("<#(\\d+)>", mention)
if match:
return server.get_channel(int(match.group(1)))
try:
return server.get_channel(int(mention))
except ValueError:
return None
async def get_system_fuzzy(conn, client: discord.Client, key) -> Optional[System]:
if isinstance(key, discord.User):
return await db.get_system_by_account(conn, account_id=key.id)
if isinstance(key, str) and len(key) == 5:
return await db.get_system_by_hid(conn, system_hid=key)
account = await parse_mention(client, key)
if account:
system = await db.get_system_by_account(conn, account_id=account.id)
if system:
return system
return None
async def get_member_fuzzy(conn, system_id: int, key: str, system_only=True) -> Member:
# First search by hid
if system_only:
member = await db.get_member_by_hid_in_system(conn, system_id=system_id, member_hid=key)
else:
member = await db.get_member_by_hid(conn, member_hid=key)
if member is not None:
return member
# Then search by name, if we have a system
if system_id:
member = await db.get_member_by_name(conn, system_id=system_id, member_name=key)
if member is not None:
return member
def sanitize(text):
# Insert a zero-width space in @everyone so it doesn't trigger
return text.replace("@everyone", "@\u200beveryone").replace("@here", "@\u200bhere")

View File

@ -1,383 +0,0 @@
from collections import namedtuple
from datetime import datetime
import logging
from typing import List, Optional
import time
import asyncpg
import asyncpg.exceptions
from discord.utils import snowflake_time
from pluralkit.system import System
from pluralkit.member import Member
logger = logging.getLogger("pluralkit.db")
async def connect(uri):
while True:
try:
return await asyncpg.create_pool(uri)
except (ConnectionError, asyncpg.exceptions.CannotConnectNowError):
logger.exception("Failed to connect to database, retrying in 5 seconds...")
time.sleep(5)
def db_wrap(func):
async def inner(*args, **kwargs):
before = time.perf_counter()
try:
res = await func(*args, **kwargs)
after = time.perf_counter()
logger.debug(" - DB call {} took {:.2f} ms".format(func.__name__, (after - before) * 1000))
return res
except asyncpg.exceptions.PostgresError:
logger.exception("Error from database query {}".format(func.__name__))
return inner
@db_wrap
async def create_system(conn, system_name: str, system_hid: str) -> System:
logger.debug("Creating system (name={}, hid={})".format(
system_name, system_hid))
row = await conn.fetchrow("insert into systems (name, hid) values ($1, $2) returning *", system_name, system_hid)
return System(**row) if row else None
@db_wrap
async def remove_system(conn, system_id: int):
logger.debug("Deleting system (id={})".format(system_id))
await conn.execute("delete from systems where id = $1", system_id)
@db_wrap
async def create_member(conn, system_id: int, member_name: str, member_hid: str) -> Member:
logger.debug("Creating member (system={}, name={}, hid={})".format(
system_id, member_name, member_hid))
row = await conn.fetchrow("insert into members (name, system, hid) values ($1, $2, $3) returning *", member_name, system_id, member_hid)
return Member(**row) if row else None
@db_wrap
async def delete_member(conn, member_id: int):
logger.debug("Deleting member (id={})".format(member_id))
await conn.execute("delete from members where id = $1", member_id)
@db_wrap
async def link_account(conn, system_id: int, account_id: int):
logger.debug("Linking account (account_id={}, system_id={})".format(
account_id, system_id))
await conn.execute("insert into accounts (uid, system) values ($1, $2)", account_id, system_id)
@db_wrap
async def unlink_account(conn, system_id: int, account_id: int):
logger.debug("Unlinking account (account_id={}, system_id={})".format(
account_id, system_id))
await conn.execute("delete from accounts where uid = $1 and system = $2", account_id, system_id)
@db_wrap
async def get_linked_accounts(conn, system_id: int) -> List[int]:
return [row["uid"] for row in await conn.fetch("select uid from accounts where system = $1", system_id)]
@db_wrap
async def get_system_by_account(conn, account_id: int) -> System:
row = await conn.fetchrow("select systems.* from systems, accounts where accounts.uid = $1 and accounts.system = systems.id", account_id)
return System(**row) if row else None
@db_wrap
async def get_system_by_token(conn, token: str) -> Optional[System]:
row = await conn.fetchrow("select * from systems where token = $1", token)
return System(**row) if row else None
@db_wrap
async def get_system_by_hid(conn, system_hid: str) -> System:
row = await conn.fetchrow("select * from systems where hid = $1", system_hid)
return System(**row) if row else None
@db_wrap
async def get_system(conn, system_id: int) -> System:
row = await conn.fetchrow("select * from systems where id = $1", system_id)
return System(**row) if row else None
@db_wrap
async def get_member_by_name(conn, system_id: int, member_name: str) -> Member:
row = await conn.fetchrow("select * from members where system = $1 and lower(name) = lower($2)", system_id, member_name)
return Member(**row) if row else None
@db_wrap
async def get_member_by_hid_in_system(conn, system_id: int, member_hid: str) -> Member:
row = await conn.fetchrow("select * from members where system = $1 and hid = $2", system_id, member_hid)
return Member(**row) if row else None
@db_wrap
async def get_member_by_hid(conn, member_hid: str) -> Member:
row = await conn.fetchrow("select * from members where hid = $1", member_hid)
return Member(**row) if row else None
@db_wrap
async def get_member(conn, member_id: int) -> Member:
row = await conn.fetchrow("select * from members where id = $1", member_id)
return Member(**row) if row else None
@db_wrap
async def get_members(conn, members: list) -> List[Member]:
rows = await conn.fetch("select * from members where id = any($1)", members)
return [Member(**row) for row in rows]
@db_wrap
async def update_system_field(conn, system_id: int, field: str, value):
logger.debug("Updating system field (id={}, {}={})".format(
system_id, field, value))
await conn.execute("update systems set {} = $1 where id = $2".format(field), value, system_id)
@db_wrap
async def update_member_field(conn, member_id: int, field: str, value):
logger.debug("Updating member field (id={}, {}={})".format(
member_id, field, value))
await conn.execute("update members set {} = $1 where id = $2".format(field), value, member_id)
@db_wrap
async def get_all_members(conn, system_id: int) -> List[Member]:
rows = await conn.fetch("select * from members where system = $1", system_id)
return [Member(**row) for row in rows]
@db_wrap
async def get_members_exceeding(conn, system_id: int, length: int) -> List[Member]:
rows = await conn.fetch("select * from members where system = $1 and length(name) > $2", system_id, length)
return [Member(**row) for row in rows]
@db_wrap
async def get_webhook(conn, channel_id: int) -> (str, str):
row = await conn.fetchrow("select webhook, token from webhooks where channel = $1", channel_id)
return (str(row["webhook"]), row["token"]) if row else None
@db_wrap
async def add_webhook(conn, channel_id: int, webhook_id: int, webhook_token: str):
logger.debug("Adding new webhook (channel={}, webhook={}, token={})".format(
channel_id, webhook_id, webhook_token))
await conn.execute("insert into webhooks (channel, webhook, token) values ($1, $2, $3)", channel_id, webhook_id, webhook_token)
@db_wrap
async def delete_webhook(conn, channel_id: int):
await conn.execute("delete from webhooks where channel = $1", channel_id)
@db_wrap
async def add_message(conn, message_id: int, channel_id: int, member_id: int, sender_id: int):
logger.debug("Adding new message (id={}, channel={}, member={}, sender={})".format(
message_id, channel_id, member_id, sender_id))
await conn.execute("insert into messages (mid, channel, member, sender) values ($1, $2, $3, $4)", message_id, channel_id, member_id, sender_id)
class ProxyMember(namedtuple("ProxyMember", ["id", "hid", "prefix", "suffix", "color", "name", "avatar_url", "tag", "system_name", "system_hid"])):
id: int
hid: str
prefix: str
suffix: str
color: str
name: str
avatar_url: str
tag: str
system_name: str
system_hid: str
@db_wrap
async def get_members_by_account(conn, account_id: int) -> List[ProxyMember]:
# Returns a "chimera" object
rows = await conn.fetch("""select
members.id, members.hid, members.prefix, members.suffix, members.color, members.name, members.avatar_url,
systems.tag, systems.name as system_name, systems.hid as system_hid
from
systems, members, accounts
where
accounts.uid = $1
and systems.id = accounts.system
and members.system = systems.id""", account_id)
return [ProxyMember(**row) for row in rows]
class MessageInfo(namedtuple("MemberInfo", ["mid", "channel", "member", "sender", "name", "hid", "avatar_url", "system_name", "system_hid", "pronouns"])):
mid: int
channel: int
member: int
sender: int
name: str
hid: str
avatar_url: str
system_name: str
system_hid: str
pronouns: str
def to_json(self):
return {
"id": str(self.mid),
"channel": str(self.channel),
"member": self.hid,
"system": self.system_hid,
"message_sender": str(self.sender),
"timestamp": snowflake_time(self.mid).isoformat()
}
@db_wrap
async def get_message_by_sender_and_id(conn, message_id: int, sender_id: int) -> MessageInfo:
row = await conn.fetchrow("""select
messages.*,
members.name, members.hid, members.avatar_url, members.pronouns,
systems.name as system_name, systems.hid as system_hid
from
messages, members, systems
where
messages.member = members.id
and members.system = systems.id
and mid = $1
and sender = $2""", message_id, sender_id)
return MessageInfo(**row) if row else None
@db_wrap
async def get_message(conn, message_id: int) -> MessageInfo:
row = await conn.fetchrow("""select
messages.*,
members.name, members.hid, members.avatar_url, members.pronouns,
systems.name as system_name, systems.hid as system_hid
from
messages, members, systems
where
messages.member = members.id
and members.system = systems.id
and mid = $1""", message_id)
return MessageInfo(**row) if row else None
@db_wrap
async def delete_message(conn, message_id: int):
logger.debug("Deleting message (id={})".format(message_id))
await conn.execute("delete from messages where mid = $1", message_id)
@db_wrap
async def get_member_message_count(conn, member_id: int) -> int:
return await conn.fetchval("select count(*) from messages where member = $1", member_id)
@db_wrap
async def front_history(conn, system_id: int, count: int):
return await conn.fetch("""select
switches.*,
array(
select member from switch_members
where switch_members.switch = switches.id
order by switch_members.id asc
) as members
from switches
where switches.system = $1
order by switches.timestamp desc
limit $2""", system_id, count)
@db_wrap
async def add_switch(conn, system_id: int):
logger.debug("Adding switch (system={})".format(system_id))
res = await conn.fetchrow("insert into switches (system) values ($1) returning *", system_id)
return res["id"]
@db_wrap
async def move_switch(conn, system_id: int, switch_id: int, new_time: datetime):
logger.debug("Moving latest switch (system={}, id={}, new_time={})".format(system_id, switch_id, new_time))
await conn.execute("update switches set timestamp = $1 where system = $2 and id = $3", new_time, system_id, switch_id)
@db_wrap
async def add_switch_member(conn, switch_id: int, member_id: int):
logger.debug("Adding switch member (switch={}, member={})".format(switch_id, member_id))
await conn.execute("insert into switch_members (switch, member) values ($1, $2)", switch_id, member_id)
@db_wrap
async def delete_switch(conn, switch_id: int):
logger.debug("Deleting switch (id={})".format(switch_id))
await conn.execute("delete from switches where id = $1", switch_id)
@db_wrap
async def get_server_info(conn, server_id: int):
return await conn.fetchrow("select * from servers where id = $1", server_id)
@db_wrap
async def update_server(conn, server_id: int, logging_channel_id: int):
logging_channel_id = logging_channel_id if logging_channel_id else None
logger.debug("Updating server settings (id={}, log_channel={})".format(server_id, logging_channel_id))
await conn.execute("insert into servers (id, log_channel) values ($1, $2) on conflict (id) do update set log_channel = $2", server_id, logging_channel_id)
@db_wrap
async def member_count(conn) -> int:
return await conn.fetchval("select count(*) from members")
@db_wrap
async def system_count(conn) -> int:
return await conn.fetchval("select count(*) from systems")
@db_wrap
async def message_count(conn) -> int:
return await conn.fetchval("select count(*) from messages")
@db_wrap
async def account_count(conn) -> int:
return await conn.fetchval("select count(*) from accounts")
async def create_tables(conn):
await conn.execute("""create table if not exists systems (
id serial primary key,
hid char(5) unique not null,
name text,
description text,
tag text,
avatar_url text,
token text,
created timestamp not null default (current_timestamp at time zone 'utc'),
ui_tz text not null default 'UTC'
)""")
await conn.execute("""create table if not exists members (
id serial primary key,
hid char(5) unique not null,
system serial not null references systems(id) on delete cascade,
color char(6),
avatar_url text,
name text not null,
birthday date,
pronouns text,
description text,
prefix text,
suffix text,
created timestamp not null default (current_timestamp at time zone 'utc')
)""")
await conn.execute("""create table if not exists accounts (
uid bigint primary key,
system serial not null references systems(id) on delete cascade
)""")
await conn.execute("""create table if not exists messages (
mid bigint primary key,
channel bigint not null,
member serial not null references members(id) on delete cascade,
sender bigint not null
)""")
await conn.execute("""create table if not exists switches (
id serial primary key,
system serial not null references systems(id) on delete cascade,
timestamp timestamp not null default (current_timestamp at time zone 'utc')
)""")
await conn.execute("""create table if not exists switch_members (
id serial primary key,
switch serial not null references switches(id) on delete cascade,
member serial not null references members(id) on delete cascade
)""")
await conn.execute("""create table if not exists webhooks (
channel bigint primary key,
webhook bigint not null,
token text not null
)""")
await conn.execute("""create table if not exists servers (
id bigint primary key,
log_channel bigint
)""")

View File

@ -1,104 +0,0 @@
from typing import Tuple
class PluralKitError(Exception):
def __init__(self, message):
self.message = message
self.help_page = None
def with_help(self, help_page: Tuple[str, str]):
self.help_page = help_page
class ExistingSystemError(PluralKitError):
def __init__(self):
super().__init__(
"You already have a system registered. To delete your system, use `pk;system delete`, or to unlink your system from this account, use `pk;system unlink`.")
class DescriptionTooLongError(PluralKitError):
def __init__(self):
super().__init__("You can't have a description longer than 1024 characters.")
class TagTooLongError(PluralKitError):
def __init__(self):
super().__init__("You can't have a system tag longer than 32 characters.")
class TagTooLongWithMembersError(PluralKitError):
def __init__(self, member_names):
super().__init__(
"The maximum length of a name plus the system tag is 32 characters. The following members would exceed the limit: {}. Please reduce the length of the tag, or rename the members.".format(
", ".join(member_names)))
self.member_names = member_names
class CustomEmojiError(PluralKitError):
def __init__(self):
super().__init__(
"Due to a Discord limitation, custom emojis aren't supported. Please use a standard emoji instead.")
class InvalidAvatarURLError(PluralKitError):
def __init__(self):
super().__init__("Invalid image URL.")
class AccountInOwnSystemError(PluralKitError):
def __init__(self):
super().__init__("That account is already linked to your own system.")
class AccountAlreadyLinkedError(PluralKitError):
def __init__(self, existing_system):
super().__init__("The mentioned account is already linked to a system (`{}`)".format(existing_system.hid))
self.existing_system = existing_system
class UnlinkingLastAccountError(PluralKitError):
def __init__(self):
super().__init__("This is the only account on your system, so you can't unlink it.")
class MemberNameTooLongError(PluralKitError):
def __init__(self, tag_present: bool):
if tag_present:
super().__init__(
"The maximum length of a name plus the system tag is 32 characters. Please reduce the length of the tag, or choose a shorter member name.")
else:
super().__init__("The maximum length of a member name is 32 characters.")
class InvalidColorError(PluralKitError):
def __init__(self):
super().__init__("Color must be a valid hex color. (eg. #ff0000)")
class InvalidDateStringError(PluralKitError):
def __init__(self):
super().__init__("Invalid date string. Date must be in ISO-8601 format (YYYY-MM-DD, eg. 1999-07-25).")
class MembersAlreadyFrontingError(PluralKitError):
def __init__(self, members: "List[Member]"):
if len(members) == 0:
super().__init__("There are already no members fronting.")
elif len(members) == 1:
super().__init__("Member {} is already fronting.".format(members[0].name))
else:
super().__init__("Members {} are already fronting.".format(", ".join([member.name for member in members])))
class DuplicateSwitchMembersError(PluralKitError):
def __init__(self):
super().__init__("Duplicate members in member list.")
class InvalidTimeZoneError(PluralKitError):
def __init__(self, tz_name: str):
super().__init__("Invalid time zone designation \"{}\".\n\nFor a list of valid time zone designations, see the `TZ database name` column here: <https://en.wikipedia.org/wiki/List_of_tz_database_time_zones#List>.".format(tz_name))
class TupperboxImportError(PluralKitError):
def __init__(self):
super().__init__("Invalid Tupperbox file.")

View File

@ -1,177 +0,0 @@
import re
from datetime import date, datetime
from collections.__init__ import namedtuple
from typing import Optional, Union
from pluralkit import db, errors
from pluralkit.utils import validate_avatar_url_or_raise, contains_custom_emoji
class Member(namedtuple("Member",
["id", "hid", "system", "color", "avatar_url", "name", "birthday", "pronouns", "description",
"prefix", "suffix", "created"])):
"""An immutable representation of a system member fetched from the database."""
id: int
hid: str
system: int
color: str
avatar_url: str
name: str
birthday: date
pronouns: str
description: str
prefix: str
suffix: str
created: datetime
def to_json(self):
return {
"id": self.hid,
"name": self.name,
"color": self.color,
"avatar_url": self.avatar_url,
"birthday": self.birthday.isoformat() if self.birthday else None,
"pronouns": self.pronouns,
"description": self.description,
"prefix": self.prefix,
"suffix": self.suffix
}
@staticmethod
async def get_member_by_id(conn, member_id: int) -> Optional["Member"]:
"""Fetch a member with the given internal member ID from the database."""
return await db.get_member(conn, member_id)
@staticmethod
async def get_member_by_name(conn, system_id: int, member_name: str) -> "Optional[Member]":
"""Fetch a member by the given name in the given system from the database."""
member = await db.get_member_by_name(conn, system_id, member_name)
return member
@staticmethod
async def get_member_by_hid(conn, system_id: Optional[int], member_hid: str) -> "Optional[Member]":
"""Fetch a member by the given hid from the database. If @`system_id` is present, will only return members from that system."""
if system_id:
member = await db.get_member_by_hid_in_system(conn, system_id, member_hid)
else:
member = await db.get_member_by_hid(conn, member_hid)
return member
@staticmethod
async def get_member_fuzzy(conn, system_id: int, name: str) -> "Optional[Member]":
by_hid = await Member.get_member_by_hid(conn, system_id, name)
if by_hid:
return by_hid
by_name = await Member.get_member_by_name(conn, system_id, name)
return by_name
async def set_name(self, conn, new_name: str):
"""
Set the name of a member.
:raises: CustomEmojiError
"""
# Custom emojis can't go in the member name
# Technically they *could*, but they wouldn't render properly
# so I'd rather explicitly ban them to in order to avoid confusion
if contains_custom_emoji(new_name):
raise errors.CustomEmojiError()
await db.update_member_field(conn, self.id, "name", new_name)
async def set_description(self, conn, new_description: Optional[str]):
"""
Set or clear the description of a member.
:raises: DescriptionTooLongError
"""
# Explicit length checking
if new_description and len(new_description) > 1024:
raise errors.DescriptionTooLongError()
await db.update_member_field(conn, self.id, "description", new_description)
async def set_avatar(self, conn, new_avatar_url: Optional[str]):
"""
Set or clear the avatar of a member.
:raises: InvalidAvatarURLError
"""
if new_avatar_url:
validate_avatar_url_or_raise(new_avatar_url)
await db.update_member_field(conn, self.id, "avatar_url", new_avatar_url)
async def set_color(self, conn, new_color: Optional[str]):
"""
Set or clear the associated color of a member.
:raises: InvalidColorError
"""
cleaned_color = None
if new_color:
match = re.fullmatch("#?([0-9A-Fa-f]{6})", new_color)
if not match:
raise errors.InvalidColorError()
cleaned_color = match.group(1).lower()
await db.update_member_field(conn, self.id, "color", cleaned_color)
async def set_birthdate(self, conn, new_date: Union[date, str]):
"""
Set or clear the birthdate of a member. To hide the birth year, pass a year of 0001.
If passed a string, will attempt to parse the string as a date.
:raises: InvalidDateStringError
"""
if isinstance(new_date, str):
date_str = new_date
try:
new_date = datetime.strptime(date_str, "%Y-%m-%d").date()
except ValueError:
try:
# Try again, adding 0001 as a placeholder year
# This is considered a "null year" and will be omitted from the info card
# Useful if you want your birthday to be displayed yearless.
new_date = datetime.strptime("0001-" + date_str, "%Y-%m-%d").date()
except ValueError:
raise errors.InvalidDateStringError()
await db.update_member_field(conn, self.id, "birthday", new_date)
async def set_pronouns(self, conn, new_pronouns: str):
"""Set or clear the associated pronouns with a member."""
await db.update_member_field(conn, self.id, "pronouns", new_pronouns)
async def set_proxy_tags(self, conn, prefix: Optional[str], suffix: Optional[str]):
"""
Set the proxy tags for a member. Having no prefix *and* no suffix will disable proxying.
"""
# Make sure empty strings or other falsey values are actually None
prefix = prefix or None
suffix = suffix or None
async with conn.transaction():
await db.update_member_field(conn, member_id=self.id, field="prefix", value=prefix)
await db.update_member_field(conn, member_id=self.id, field="suffix", value=suffix)
async def delete(self, conn):
"""Delete this member from the database."""
await db.delete_member(conn, self.id)
async def fetch_system(self, conn) -> "System":
"""Fetch the member's system from the database"""
return await db.get_system(conn, self.system)
async def message_count(self, conn) -> int:
return await db.get_member_message_count(conn, self.id)
def birthday_string(self) -> Optional[str]:
if not self.birthday:
return None
if self.birthday.year == 1:
return self.birthday.strftime("%b %d")
return self.birthday.strftime("%b %d, %Y")

View File

@ -1,28 +0,0 @@
from collections import namedtuple
from datetime import datetime
from typing import List
from pluralkit import db
from pluralkit.member import Member
class Switch(namedtuple("Switch", ["id", "system", "timestamp", "members"])):
id: int
system: int
timestamp: datetime
members: List[int]
async def fetch_members(self, conn) -> List[Member]:
return await db.get_members(conn, self.members)
async def delete(self, conn):
await db.delete_switch(conn, self.id)
async def move(self, conn, new_timestamp):
await db.move_switch(conn, self.system, self.id, new_timestamp)
async def to_json(self, hid_getter):
return {
"timestamp": self.timestamp.isoformat(),
"members": [await hid_getter(m) for m in self.members]
}

View File

@ -1,322 +0,0 @@
import random
import re
import string
from collections.__init__ import namedtuple
from datetime import datetime
from typing import Optional, List, Tuple
import pytz
from pluralkit import db, errors
from pluralkit.member import Member
from pluralkit.switch import Switch
from pluralkit.utils import generate_hid, contains_custom_emoji, validate_avatar_url_or_raise
class TupperboxImportResult(namedtuple("TupperboxImportResult", ["updated", "created", "tags"])):
pass
class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "avatar_url", "token", "created", "ui_tz"])):
id: int
hid: str
name: str
description: str
tag: str
avatar_url: str
token: str
created: datetime
# pytz-compatible time zone name, usually Olson-style (eg. Europe/Amsterdam)
ui_tz: str
@staticmethod
async def get_by_id(conn, system_id: int) -> Optional["System"]:
return await db.get_system(conn, system_id)
@staticmethod
async def get_by_account(conn, account_id: int) -> Optional["System"]:
return await db.get_system_by_account(conn, account_id)
@staticmethod
async def get_by_token(conn, token: str) -> Optional["System"]:
return await db.get_system_by_token(conn, token)
@staticmethod
async def get_by_hid(conn, hid: str) -> Optional["System"]:
return await db.get_system_by_hid(conn, hid)
@staticmethod
async def create_system(conn, account_id: int, system_name: Optional[str] = None) -> "System":
async with conn.transaction():
existing_system = await System.get_by_account(conn, account_id)
if existing_system:
raise errors.ExistingSystemError()
new_hid = generate_hid()
async with conn.transaction():
new_system = await db.create_system(conn, system_name, new_hid)
await db.link_account(conn, new_system.id, account_id)
return new_system
async def set_name(self, conn, new_name: Optional[str]):
await db.update_system_field(conn, self.id, "name", new_name)
async def set_description(self, conn, new_description: Optional[str]):
# Explicit length error
if new_description and len(new_description) > 1024:
raise errors.DescriptionTooLongError()
await db.update_system_field(conn, self.id, "description", new_description)
async def set_tag(self, conn, new_tag: Optional[str]):
if new_tag:
# Explicit length error
if len(new_tag) > 32:
raise errors.TagTooLongError()
if contains_custom_emoji(new_tag):
raise errors.CustomEmojiError()
await db.update_system_field(conn, self.id, "tag", new_tag)
async def set_avatar(self, conn, new_avatar_url: Optional[str]):
if new_avatar_url:
validate_avatar_url_or_raise(new_avatar_url)
await db.update_system_field(conn, self.id, "avatar_url", new_avatar_url)
async def link_account(self, conn, new_account_id: int):
async with conn.transaction():
existing_system = await System.get_by_account(conn, new_account_id)
if existing_system:
if existing_system.id == self.id:
raise errors.AccountInOwnSystemError()
raise errors.AccountAlreadyLinkedError(existing_system)
await db.link_account(conn, self.id, new_account_id)
async def unlink_account(self, conn, account_id: int):
async with conn.transaction():
linked_accounts = await db.get_linked_accounts(conn, self.id)
if len(linked_accounts) == 1:
raise errors.UnlinkingLastAccountError()
await db.unlink_account(conn, self.id, account_id)
async def get_linked_account_ids(self, conn) -> List[int]:
return await db.get_linked_accounts(conn, self.id)
async def delete(self, conn):
await db.remove_system(conn, self.id)
async def refresh_token(self, conn) -> str:
new_token = "".join(random.choices(string.ascii_letters + string.digits, k=64))
await db.update_system_field(conn, self.id, "token", new_token)
return new_token
async def get_token(self, conn) -> str:
if self.token:
return self.token
return await self.refresh_token(conn)
async def create_member(self, conn, member_name: str) -> Member:
# TODO: figure out what to do if this errors out on collision on generate_hid
new_hid = generate_hid()
if len(member_name) > self.get_member_name_limit():
raise errors.MemberNameTooLongError(tag_present=bool(self.tag))
member = await db.create_member(conn, self.id, member_name, new_hid)
return member
async def get_members(self, conn) -> List[Member]:
return await db.get_all_members(conn, self.id)
async def get_switches(self, conn, count) -> List[Switch]:
"""Returns the latest `count` switches logged for this system, ordered latest to earliest."""
return [Switch(**s) for s in await db.front_history(conn, self.id, count)]
async def get_latest_switch(self, conn) -> Optional[Switch]:
"""Returns the latest switch logged for this system, or None if no switches have been logged"""
switches = await self.get_switches(conn, 1)
if switches:
return switches[0]
else:
return None
async def add_switch(self, conn, members: List[Member]) -> Switch:
"""
Logs a new switch for a system.
:raises: MembersAlreadyFrontingError, DuplicateSwitchMembersError
"""
new_ids = [member.id for member in members]
last_switch = await self.get_latest_switch(conn)
# If we have a switch logged before, make sure this isn't a dupe switch
if last_switch:
last_switch_members = await last_switch.fetch_members(conn)
last_ids = [member.id for member in last_switch_members]
# We don't compare by set() here because swapping multiple is a valid operation
if last_ids == new_ids:
raise errors.MembersAlreadyFrontingError(members)
# Check for dupes
if len(set(new_ids)) != len(new_ids):
raise errors.DuplicateSwitchMembersError()
async with conn.transaction():
switch_id = await db.add_switch(conn, self.id)
# TODO: batch query here
for member in members:
await db.add_switch_member(conn, switch_id, member.id)
return await self.get_latest_switch(conn)
def get_member_name_limit(self) -> int:
"""Returns the maximum length a member's name or nickname is allowed to be in order for the member to be proxied. Depends on the system tag."""
if self.tag:
return 32 - len(self.tag) - 1
else:
return 32
async def match_proxy(self, conn, message: str) -> Optional[Tuple[Member, str]]:
"""Tries to find a member with proxy tags matching the given message. Returns the member and the inner contents."""
members = await db.get_all_members(conn, self.id)
# Sort by specificity (members with both prefix and suffix defined go higher)
# This will make sure more "precise" proxy tags get tried first and match properly
members = sorted(members, key=lambda x: int(bool(x.prefix)) + int(bool(x.suffix)), reverse=True)
for member in members:
proxy_prefix = member.prefix or ""
proxy_suffix = member.suffix or ""
if not proxy_prefix and not proxy_suffix:
# If the member has neither a prefix or a suffix, cancel early
# Otherwise it'd match any message no matter what
continue
# Check if the message matches these tags
if message.startswith(proxy_prefix) and message.endswith(proxy_suffix):
# If the message starts with a mention, "separate" that and match the bit after
mention_match = re.match(r"^(<(@|@!|#|@&|a?:\w+:)\d+>\s*)+", message)
leading_mentions = ""
if mention_match:
message = message[mention_match.span(0)[1]:].strip()
leading_mentions = mention_match.group(0)
# Extract the inner message (special case because -0 is invalid as an end slice)
if len(proxy_suffix) == 0:
inner_message = message[len(proxy_prefix):]
else:
inner_message = message[len(proxy_prefix):-len(proxy_suffix)]
# Add the stripped mentions back if there are any
inner_message = leading_mentions + inner_message
return member, inner_message
def format_time(self, dt: datetime) -> str:
"""
Localizes the given `datetime` to a string based on the system's preferred time zone.
Assumes `dt` is a naïve `datetime` instance set to UTC, which is consistent with the rest of PluralKit.
"""
tz = pytz.timezone(self.ui_tz)
# Set to aware (UTC), convert to tz, set to naive (tz), then format and append name
return tz.normalize(pytz.utc.localize(dt)).replace(tzinfo=None).isoformat(sep=" ", timespec="seconds") + " " + tz.tzname(dt)
async def set_time_zone(self, conn, tz_name: str) -> pytz.tzinfo:
"""
Sets the system time zone to the time zone represented by the given string.
If `tz_name` is None or an empty string, will default to UTC.
If `tz_name` does not represent a valid time zone string, will raise InvalidTimeZoneError.
:raises: InvalidTimeZoneError
:returns: The `pytz.tzinfo` instance of the newly set time zone.
"""
try:
tz = pytz.timezone(tz_name or "UTC")
except pytz.UnknownTimeZoneError:
raise errors.InvalidTimeZoneError(tz_name)
await db.update_system_field(conn, self.id, "ui_tz", tz.zone)
return tz
async def import_from_tupperbox(self, conn, data: dict):
"""
Imports from a Tupperbox JSON data file.
:raises: TupperboxImportError
"""
if not "tuppers" in data:
raise errors.TupperboxImportError()
if not isinstance(data["tuppers"], list):
raise errors.TupperboxImportError()
all_tags = set()
created_members = set()
updated_members = set()
for tupper in data["tuppers"]:
# Sanity check tupper fields
for field in ["name", "avatar_url", "brackets", "birthday", "description", "tag"]:
if field not in tupper:
raise errors.TupperboxImportError()
# Find member by name, create if not exists
member_name = str(tupper["name"])
member = await Member.get_member_by_name(conn, self.id, member_name)
if not member:
# And keep track of created members
created_members.add(member_name)
member = await self.create_member(conn, member_name)
else:
# Keep track of updated members
updated_members.add(member_name)
# Set avatar
await member.set_avatar(conn, str(tupper["avatar_url"]))
# Set proxy tags
if not (isinstance(tupper["brackets"], list) and len(tupper["brackets"]) >= 2):
raise errors.TupperboxImportError()
await member.set_proxy_tags(conn, str(tupper["brackets"][0]), str(tupper["brackets"][1]))
# Set birthdate (input is in ISO-8601, first 10 characters is the date)
if tupper["birthday"]:
try:
await member.set_birthdate(conn, str(tupper["birthday"][:10]))
except errors.InvalidDateStringError:
pass
# Set description
await member.set_description(conn, tupper["description"])
# Keep track of tag
all_tags.add(tupper["tag"])
# Since Tupperbox does tags on a per-member basis, we only apply a system tag if
# every member has the same tag (surprisingly common)
# If not, we just do nothing. (This will be reported in the caller function through the returned result)
if len(all_tags) == 1:
tag = list(all_tags)[0]
await self.set_tag(conn, tag)
return TupperboxImportResult(updated=updated_members, created=created_members, tags=all_tags)
def to_json(self):
return {
"id": self.hid,
"name": self.name,
"description": self.description,
"tag": self.tag,
"avatar_url": self.avatar_url,
"tz": self.ui_tz
}

View File

@ -1,73 +0,0 @@
import humanize
import re
import random
import string
from datetime import datetime, timezone, timedelta
from typing import List, Tuple, Union
from urllib.parse import urlparse
from pluralkit import db
from pluralkit.errors import InvalidAvatarURLError
def display_relative(time: Union[datetime, timedelta]) -> str:
if isinstance(time, datetime):
time = datetime.utcnow() - time
return humanize.naturaldelta(time)
async def get_fronter_ids(conn, system_id) -> (List[int], datetime):
switches = await db.front_history(conn, system_id=system_id, count=1)
if not switches:
return [], None
if not switches[0]["members"]:
return [], switches[0]["timestamp"]
return switches[0]["members"], switches[0]["timestamp"]
async def get_fronters(conn, system_id) -> (List["Member"], datetime):
member_ids, timestamp = await get_fronter_ids(conn, system_id)
# Collect in dict and then look up as list, to preserve return order
members = {member.id: member for member in await db.get_members(conn, member_ids)}
return [members[member_id] for member_id in member_ids], timestamp
async def get_front_history(conn, system_id, count) -> List[Tuple[datetime, List["pluMember"]]]:
# Get history from DB
switches = await db.front_history(conn, system_id=system_id, count=count)
if not switches:
return []
# Get all unique IDs referenced
all_member_ids = {id for switch in switches for id in switch["members"]}
# And look them up in the database into a dict
all_members = {member.id: member for member in await db.get_members(conn, list(all_member_ids))}
# Collect in array and return
out = []
for switch in switches:
timestamp = switch["timestamp"]
members = [all_members[id] for id in switch["members"]]
out.append((timestamp, members))
return out
def generate_hid() -> str:
return "".join(random.choices(string.ascii_lowercase, k=5))
def contains_custom_emoji(value):
return bool(re.search("<a?:\w+:\d+>", value))
def validate_avatar_url_or_raise(url):
u = urlparse(url)
if not (u.scheme in ["http", "https"] and u.netloc and u.path):
raise InvalidAvatarURLError()
# TODO: check file type and size of image

View File

@ -1,10 +0,0 @@
aiodns
aiohttp==3.3.0
asyncpg
dateparser
https://github.com/Rapptz/discord.py/archive/aceec2009a7c819d2236884fa9ccc5ce58a92bea.zip#egg=discord.py
humanize
uvloop; sys.platform != 'win32' and sys.platform != 'cygwin' and sys.platform != 'cli'
ciso8601
pytz
timezonefinder