Compare commits

...

8 Commits

Author SHA1 Message Date
patrick daa57cae92 Merge pull request 'patrick/sqlite' (#11) from patrick/sqlite into master
ci/woodpecker/push/woodpecker Pipeline was successful Details
Reviewed-on: #11
2023-08-08 20:07:40 -07:00
Patrick Cleavelin 132b7b99cc fully use database remove awful json, also properly compile sqlite in CI
ci/woodpecker/push/woodpecker Pipeline was successful Details
ci/woodpecker/tag/woodpecker Pipeline was successful Details
ci/woodpecker/pr/woodpecker Pipeline was successful Details
2023-08-08 19:55:46 -07:00
Patrick Cleavelin 2e1d41b2cd switch to db for auth 2023-08-06 19:36:08 -05:00
Patrick Cleavelin 52d7cc7ded support adding and removing intros
ci/woodpecker/push/woodpecker Pipeline failed Details
2023-08-06 17:12:26 -05:00
Patrick Cleavelin 9f426407a9 make db call simpler 2023-08-06 16:28:44 -05:00
Patrick Cleavelin 5da57545e2 don't refresh on adding intro, swap widget 2023-08-05 15:21:15 -05:00
Patrick Cleavelin ff7e608f9a replace intro selector with database info (can't select yet) 2023-08-05 12:47:55 -05:00
Patrick Cleavelin 969a97cab7 get started with an actual database
ci/woodpecker/push/woodpecker Pipeline failed Details
2023-08-04 15:16:32 -05:00
9 changed files with 958 additions and 882 deletions

95
Cargo.lock generated
View File

@ -62,6 +62,17 @@ dependencies = [
"subtle",
]
[[package]]
name = "ahash"
version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2c99f64d1e06488f620f932677e24bc6e2897582980441ae90a671415bd7ec2f"
dependencies = [
"cfg-if",
"once_cell",
"version_check",
]
[[package]]
name = "aho-corasick"
version = "1.0.2"
@ -71,6 +82,12 @@ dependencies = [
"memchr",
]
[[package]]
name = "allocator-api2"
version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5"
[[package]]
name = "android-tzdata"
version = "0.1.1"
@ -313,6 +330,7 @@ dependencies = [
"iana-time-zone",
"js-sys",
"num-traits 0.2.16",
"serde",
"time 0.1.45",
"wasm-bindgen",
"winapi",
@ -475,6 +493,12 @@ version = "0.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77c90badedccf4105eca100756a0b1289e191f6fcbdadd3cee1d2f614f97da8f"
[[package]]
name = "either"
version = "1.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07"
[[package]]
name = "encoding_rs"
version = "0.8.32"
@ -514,6 +538,18 @@ dependencies = [
"libc",
]
[[package]]
name = "fallible-iterator"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7"
[[package]]
name = "fallible-streaming-iterator"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a"
[[package]]
name = "fastrand"
version = "2.0.0"
@ -744,6 +780,19 @@ name = "hashbrown"
version = "0.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a"
dependencies = [
"ahash",
"allocator-api2",
]
[[package]]
name = "hashlink"
version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "312f66718a2d7789ffef4f4b7b213138ed9f1eb3aa1d0d82fc99f88fb3ffd26f"
dependencies = [
"hashbrown 0.14.0",
]
[[package]]
name = "headers"
@ -925,6 +974,24 @@ version = "2.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "28b29a3cd74f0f4598934efe3aeba42bae0eb4680554128851ebbecb02af14e6"
[[package]]
name = "iter_tools"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "531cafdc99b3b3252bb32f5620e61d56b19415efc19900b12d1b2e7483854897"
dependencies = [
"itertools",
]
[[package]]
name = "itertools"
version = "0.10.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473"
dependencies = [
"either",
]
[[package]]
name = "itoa"
version = "1.0.9"
@ -952,6 +1019,17 @@ version = "0.2.147"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3"
[[package]]
name = "libsqlite3-sys"
version = "0.26.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "afc22eff61b133b115c6e8c74e818c628d6d5e7a502afea6f64dee076dd94326"
dependencies = [
"cc",
"pkg-config",
"vcpkg",
]
[[package]]
name = "linux-raw-sys"
version = "0.4.3"
@ -1020,7 +1098,9 @@ dependencies = [
"chrono",
"dotenv",
"futures",
"iter_tools",
"reqwest",
"rusqlite",
"serde",
"serde_json",
"serenity",
@ -1516,6 +1596,21 @@ dependencies = [
"winapi",
]
[[package]]
name = "rusqlite"
version = "0.29.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "549b9d036d571d42e6e85d1c1425e2ac83491075078ca9a15be021c56b1641f2"
dependencies = [
"bitflags 2.3.3",
"chrono",
"fallible-iterator",
"fallible-streaming-iterator",
"hashlink",
"libsqlite3-sys",
"smallvec",
]
[[package]]
name = "rustc-demangle"
version = "0.1.23"

View File

@ -9,9 +9,10 @@ edition = "2021"
async-trait = "0.1.72"
axum = { version = "0.6.9", features = ["headers", "multipart"] }
axum-extra = { version = "0.7.5", features = ["cookie-private", "cookie"] }
chrono = "0.4.23"
chrono = { version = "0.4.23", features = ["serde"] }
dotenv = "0.15.0"
futures = "0.3.26"
iter_tools = "0.1.4"
reqwest = "0.11.14"
serde = "1.0.152"
serde_json = "1.0.93"
@ -30,3 +31,9 @@ features = ["client", "gateway", "rustls_backend", "model", "cache", "voice"]
[dependencies.songbird]
version = "0.3.2"
features = [ "builtin-queue", "yt-dlp" ]
[target.'cfg(unix)'.dependencies]
rusqlite = { version = "0.29.0", features = ["chrono"] }
[target.'cfg(windows)'.dependencies]
rusqlite = { version = "0.29.0", features = ["bundled", "chrono"] }

View File

@ -8,7 +8,7 @@
outputs = { self, nixpkgs, rust-overlay, flake-utils, ... }:
flake-utils.lib.eachDefaultSystem (system:
let
tag = "v0.1.5_2-alpha";
tag = "v0.2.0-alpha";
overlays = [ (import rust-overlay) ];
pkgs = import nixpkgs {
inherit system overlays;
@ -35,6 +35,7 @@
pkg-config
gcc
openssl
sqlite
pkg-config
python3
ffmpeg
@ -50,7 +51,7 @@
name = "memejoin-rs";
src = self;
buildInputs = [ openssl.dev ];
nativeBuildInputs = [ local-rust pkg-config openssl openssl.dev cmake gcc libopus ];
nativeBuildInputs = [ local-rust pkg-config openssl openssl.dev cmake gcc libopus sqlite ];
cargoLock = {
lockFile = ./Cargo.lock;
@ -62,7 +63,7 @@
name = "memejoin-rs";
copyToRoot = buildEnv {
name = "image-root";
paths = [ default cacert openssl openssl.dev ffmpeg libopus youtube-dl yt-dlp ];
paths = [ default cacert openssl openssl.dev ffmpeg libopus youtube-dl yt-dlp sqlite ];
};
runAsRoot = ''
#!${runtimeShell}

423
src/db/mod.rs Normal file
View File

@ -0,0 +1,423 @@
use std::path::Path;
use chrono::NaiveDateTime;
use rusqlite::{Connection, OptionalExtension, Result};
use serde::{Deserialize, Serialize};
use tracing::warn;
use crate::auth;
pub struct Database {
conn: Connection,
}
impl Database {
pub fn new(path: impl AsRef<Path>) -> Result<Self> {
Ok(Self {
conn: Connection::open(path)?,
})
}
pub(crate) fn get_guilds(&self) -> Result<Vec<Guild>> {
let mut query = self.conn.prepare(
"
SELECT
id, name, sound_delay
FROM Guild
",
)?;
// NOTE(pcleavelin): for some reason this needs to be a let-binding or else
// the compiler complains about it being dropped too early (maybe I should update the compiler version)
let guilds = query
.query_map([], |row| {
Ok(Guild {
id: row.get(0)?,
name: row.get(1)?,
sound_delay: row.get(2)?,
})
})?
.into_iter()
.collect::<Result<Vec<Guild>>>();
guilds
}
pub(crate) fn get_user_from_api_key(&self, api_key: &str) -> Result<User> {
self.conn.query_row(
"
SELECT
username AS name, api_key, api_key_expires_at, discord_token, discord_token_expires_at
FROM User
WHERE api_key = ?1
",
[api_key],
|row| {
Ok(User {
name: row.get(0)?,
api_key: row.get(1)?,
api_key_expires_at: row.get(2)?,
discord_token: row.get(3)?,
discord_token_expires_at: row.get(4)?,
})
},
)
}
pub(crate) fn get_user(&self, username: &str) -> Result<Option<User>> {
self.conn
.query_row(
"
SELECT
username AS name, api_key, api_key_expires_at, discord_token, discord_token_expires_at
FROM User
WHERE name = ?1
",
[username],
|row| {
Ok(User {
name: row.get(0)?,
api_key: row.get(1)?,
api_key_expires_at: row.get(2)?,
discord_token: row.get(3)?,
discord_token_expires_at: row.get(4)?,
})
},
)
.optional()
}
pub fn get_user_guilds(&self, username: &str) -> Result<Vec<Guild>> {
let mut query = self.conn.prepare(
"
SELECT
id, name, sound_delay
FROM Guild
LEFT JOIN UserGuild ON UserGuild.guild_id = Guild.id
WHERE UserGuild.username = :username
",
)?;
// NOTE(pcleavelin): for some reason this needs to be a let-binding or else
// the compiler complains about it being dropped too early (maybe I should update the compiler version)
let guilds = query
.query_map(&[(":username", username)], |row| {
Ok(Guild {
id: row.get(0)?,
name: row.get(1)?,
sound_delay: row.get(2)?,
})
})?
.into_iter()
.collect::<Result<Vec<Guild>>>();
guilds
}
pub fn get_guild_intros(&self, guild_id: u64) -> Result<Vec<Intro>> {
let mut query = self.conn.prepare(
"
SELECT
Intro.id,
Intro.name,
Intro.filename
FROM Intro
WHERE
Intro.guild_id = :guild_id
",
)?;
// NOTE(pcleavelin): for some reason this needs to be a let-binding or else
// the compiler complains about it being dropped too early (maybe I should update the compiler version)
let intros = query
.query_map(
&[
// :vomit:
(":guild_id", &guild_id.to_string()),
],
|row| {
Ok(Intro {
id: row.get(0)?,
name: row.get(1)?,
filename: row.get(2)?,
})
},
)?
.into_iter()
.collect::<Result<Vec<Intro>>>();
intros
}
pub fn get_all_user_intros(&self, guild_id: u64) -> Result<Vec<UserIntro>> {
let mut query = self.conn.prepare(
"
SELECT
Intro.id,
Intro.name,
Intro.filename,
UI.channel_name,
UI.username
FROM Intro
LEFT JOIN UserIntro UI ON UI.intro_id = Intro.id
WHERE
UI.guild_id = :guild_id
ORDER BY UI.username DESC, UI.channel_name DESC, UI.intro_id;
",
)?;
// NOTE(pcleavelin): for some reason this needs to be a let-binding or else
// the compiler complains about it being dropped too early (maybe I should update the compiler version)
let intros = query
.query_map(
&[
// :vomit:
(":guild_id", &guild_id.to_string()),
],
|row| {
Ok(UserIntro {
intro: Intro {
id: row.get(0)?,
name: row.get(1)?,
filename: row.get(2)?,
},
channel_name: row.get(3)?,
username: row.get(4)?,
})
},
)?
.into_iter()
.collect::<Result<Vec<UserIntro>>>();
intros
}
pub(crate) fn get_user_permissions(
&self,
username: &str,
guild_id: u64,
) -> Result<auth::Permissions> {
self.conn.query_row(
"
SELECT
permissions
FROM UserPermission
WHERE
username = ?1
AND guild_id = ?2
",
[username, &guild_id.to_string()],
|row| Ok(auth::Permissions(row.get(0)?)),
)
}
pub(crate) fn get_guild_channels(&self, guild_id: u64) -> Result<Vec<String>> {
let mut query = self.conn.prepare(
"
SELECT
Channel.name
FROM Channel
WHERE
Channel.guild_id = :guild_id
ORDER BY Channel.name DESC
",
)?;
// NOTE(pcleavelin): for some reason this needs to be a let-binding or else
// the compiler complains about it being dropped too early (maybe I should update the compiler version)
let intros = query
.query_map(
&[
// :vomit:
(":guild_id", &guild_id.to_string()),
],
|row| Ok(row.get(0)?),
)?
.into_iter()
.collect::<Result<Vec<String>>>();
intros
}
pub(crate) fn get_user_channel_intros(
&self,
username: &str,
guild_id: u64,
channel_name: &str,
) -> Result<Vec<Intro>> {
let all_user_intros = self.get_all_user_intros(guild_id)?.into_iter();
let intros = all_user_intros
.filter(|intro| &intro.username == &username && &intro.channel_name == channel_name)
.map(|intro| intro.intro)
.collect();
Ok(intros)
}
pub fn insert_user(
&self,
username: &str,
api_key: &str,
api_key_expires_at: NaiveDateTime,
discord_token: &str,
discord_token_expires_at: NaiveDateTime,
) -> Result<()> {
let affected = self.conn.execute(
"INSERT INTO
User (username, api_key, api_key_expires_at, discord_token, discord_token_expires_at)
VALUES (?1, ?2, ?3, ?4, ?5)
ON CONFLICT(username) DO UPDATE SET api_key = ?2, api_key_expires_at = ?3, discord_token = ?4, discord_token_expires_at = ?5",
&[
username,
api_key,
&api_key_expires_at.to_string(),
discord_token,
&discord_token_expires_at.to_string(),
],
)?;
if affected < 1 {
warn!("no rows affected when attempting to insert new user");
}
Ok(())
}
pub fn insert_intro(
&self,
name: &str,
volume: i32,
guild_id: u64,
filename: &str,
) -> Result<()> {
let affected = self.conn.execute(
"INSERT INTO
Intro (name, volume, guild_id, filename)
VALUES (?1, ?2, ?3, ?4)",
&[name, &volume.to_string(), &guild_id.to_string(), filename],
)?;
if affected < 1 {
warn!("no rows affected when attempting to insert intro");
}
Ok(())
}
pub fn insert_user_guild(&self, username: &str, guild_id: u64) -> Result<()> {
let affected = self.conn.execute(
"INSERT OR IGNORE INTO UserGuild (username, guild_id) VALUES (?1, ?2)",
&[username, &guild_id.to_string()],
)?;
if affected < 1 {
warn!("no rows affected when attempting to insert user guild");
}
Ok(())
}
pub fn insert_user_intro(
&self,
username: &str,
guild_id: u64,
channel_name: &str,
intro_id: i32,
) -> Result<()> {
let affected = self.conn.execute(
"INSERT INTO UserIntro (username, guild_id, channel_name, intro_id) VALUES (?1, ?2, ?3, ?4)",
&[
username,
&guild_id.to_string(),
channel_name,
&intro_id.to_string(),
],
)?;
if affected < 1 {
warn!("no rows affected when attempting to insert user intro");
}
Ok(())
}
pub(crate) fn insert_user_permission(
&self,
username: &str,
guild_id: u64,
permissions: auth::Permissions,
) -> Result<()> {
let affected = self.conn.execute(
"
INSERT INTO
UserPermission (username, guild_id, permissions)
VALUES (?1, ?2, ?3)
ON CONFLICT(username, guild_id) DO UPDATE SET permissions = ?3",
&[username, &guild_id.to_string(), &permissions.0.to_string()],
)?;
if affected < 1 {
warn!("no rows affected when attempting to insert user permissions");
}
Ok(())
}
pub fn delete_user_intro(
&self,
username: &str,
guild_id: u64,
channel_name: &str,
intro_id: i32,
) -> Result<()> {
let affected = self.conn.execute(
"DELETE FROM
UserIntro
WHERE
username = ?1
AND guild_id = ?2
AND channel_name = ?3
AND intro_id = ?4",
&[
username,
&guild_id.to_string(),
channel_name,
&intro_id.to_string(),
],
)?;
if affected < 1 {
warn!("no rows affected when attempting to delete user intro");
}
Ok(())
}
}
pub struct Guild {
pub id: u64,
pub name: String,
pub sound_delay: u32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct User {
pub name: String,
pub api_key: String,
pub api_key_expires_at: NaiveDateTime,
pub discord_token: String,
pub discord_token_expires_at: NaiveDateTime,
}
pub struct Intro {
pub id: i32,
pub name: String,
pub filename: String,
}
pub struct UserIntro {
pub intro: Intro,
pub channel_name: String,
pub username: String,
}

84
src/db/schema.sql Normal file
View File

@ -0,0 +1,84 @@
BEGIN;
create table User
(
username TEXT not null
constraint User_pk
primary key,
api_key TEXT not null,
api_key_expires_at DATETIME not null,
discord_token TEXT not null,
discord_token_expires_at DATETIME not null
);
create table Intro
(
id integer not null
constraint Intro_pk
primary key autoincrement,
name TEXT not null,
volume integer not null,
guild_id integer not null
constraint Intro_Guild_guild_id_fk
references Guild ("id"),
filename TEXT not null
);
create table Guild
(
id integer not null
primary key,
name TEXT not null,
sound_delay integer not null
);
create table Channel
(
name TEXT
primary key,
guild_id integer
constraint Channel_Guild_id_fk
references Guild (id)
);
create table UserGuild
(
username TEXT not null
constraint UserGuild_User_username_fk
references User,
guild_id integer not null
constraint UserGuild_Guild_id_fk
references Guild (id),
primary key ("username", "guild_id")
);
create table UserIntro
(
username text not null
constraint UserIntro_User_username_fk
references User,
intro_id integer not null
constraint UserIntro_Intro_id_fk
references Intro,
guild_id integer not null
constraint UserIntro_Guild_guild_id_fk
references Guild ("id"),
channel_name text not null
constraint UserIntro_Channel_channel_name_fk
references Channel ("name"),
primary key ("username", "intro_id", "guild_id", "channel_name")
);
create table UserPermission
(
username TEXT not null
constraint UserPermission_User_username_fk
references User,
guild_id integer not null
constraint User_Guild_guild_id_fk
references Guild ("id"),
permissions integer not null,
primary key ("username", "guild_id")
);
COMMIT;

View File

@ -3,26 +3,23 @@
#![feature(async_closure)]
mod auth;
mod db;
mod htmx;
mod media;
mod page;
mod routes;
pub mod settings;
use axum::http::{HeaderValue, Method};
use axum::routing::{delete, get, post};
use axum::http::Method;
use axum::routing::{get, post};
use axum::Router;
use futures::StreamExt;
use settings::ApiState;
use songbird::tracks::TrackQueue;
use std::collections::HashMap;
use std::env;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::sync::mpsc;
use tower_http::cors::{Any, CorsLayer};
use serde::Deserialize;
use serenity::async_trait;
use serenity::model::prelude::{Channel, ChannelId, GuildId, Member, Ready};
use serenity::model::voice::VoiceState;
@ -31,7 +28,7 @@ use serenity::prelude::*;
use songbird::SerenityInit;
use tracing::*;
use crate::settings::{Intro, Settings};
use crate::settings::Settings;
enum HandlerMessage {
Ready(Context),
@ -119,7 +116,7 @@ impl EventHandler for Handler {
}
}
fn spawn_api(settings: Arc<Mutex<Settings>>) {
fn spawn_api(db: Arc<tokio::sync::Mutex<db::Database>>) {
let secrets = auth::DiscordSecret {
client_id: env::var("DISCORD_CLIENT_ID").expect("expected DISCORD_CLIENT_ID env var"),
client_secret: env::var("DISCORD_CLIENT_SECRET")
@ -128,7 +125,7 @@ fn spawn_api(settings: Arc<Mutex<Settings>>) {
let origin = env::var("APP_ORIGIN").expect("expected APP_ORIGIN");
let state = ApiState {
settings,
db,
secrets,
origin: origin.clone(),
};
@ -154,23 +151,8 @@ fn spawn_api(settings: Arc<Mutex<Settings>>) {
post(routes::v2_upload_guild_intro),
)
.route("/health", get(routes::health))
.route("/me", get(routes::me))
.route("/intros/:guild", get(routes::intros))
.route("/intros/:guild/add", get(routes::add_guild_intro))
.route("/intros/:guild/upload", post(routes::upload_guild_intro))
.route("/intros/:guild/delete", delete(routes::delete_guild_intro))
.route(
"/intros/:guild/:channel/:intro",
post(routes::add_intro_to_user),
)
.route(
"/intros/:guild/:channel/:intro/remove",
post(routes::remove_intro_to_user),
)
.route("/auth", get(routes::auth))
.layer(
CorsLayer::new()
// TODO: move this to env variable
.allow_origin([origin.parse().unwrap()])
.allow_headers(Any)
.allow_methods([Method::GET, Method::POST, Method::DELETE]),
@ -185,7 +167,7 @@ fn spawn_api(settings: Arc<Mutex<Settings>>) {
});
}
async fn spawn_bot(settings: Arc<Mutex<Settings>>) {
async fn spawn_bot(db: Arc<tokio::sync::Mutex<db::Database>>) {
let token = env::var("DISCORD_TOKEN").expect("expected DISCORD_TOKEN env var");
let songbird = songbird::Songbird::serenity();
@ -215,12 +197,19 @@ async fn spawn_bot(settings: Arc<Mutex<Settings>>) {
match msg {
HandlerMessage::Ready(ctx) => {
info!("Got Ready message");
let settings = settings.lock().await;
let songbird = songbird::get(&ctx).await.expect("no songbird instance");
for guild_id in settings.guilds.keys() {
let handler_lock = songbird.get_or_insert(GuildId(*guild_id));
let guilds = match db.lock().await.get_guilds() {
Ok(guilds) => guilds,
Err(err) => {
error!(?err, "failed to get guild on bot ready");
continue;
}
};
for guild in guilds {
let handler_lock = songbird.get_or_insert(GuildId(guild.id));
let mut handler = handler_lock.lock().await;
@ -228,7 +217,7 @@ async fn spawn_bot(settings: Arc<Mutex<Settings>>) {
songbird::Event::Track(songbird::TrackEvent::End),
TrackEventHandler {
tx: tx.clone(),
guild_id: GuildId(*guild_id),
guild_id: GuildId(guild.id),
},
);
}
@ -251,7 +240,6 @@ async fn spawn_bot(settings: Arc<Mutex<Settings>>) {
HandlerMessage::PlaySound(ctx, member, channel_id) => {
info!("Got PlaySound message");
let settings = settings.lock().await;
let Some(Channel::Guild(channel)) = channel_id.to_channel_cached(&ctx.cache)
else {
@ -259,60 +247,35 @@ async fn spawn_bot(settings: Arc<Mutex<Settings>>) {
continue;
};
let Some(guild_settings) = settings.guilds.get(channel.guild_id.as_u64())
else {
error!("couldn't get guild from id: {}", channel.guild_id.as_u64());
continue;
};
let Some(channel_settings) = guild_settings.channels.get(channel.name()) else {
error!(
"couldn't get channel_settings from name: {}",
channel.name()
);
continue;
};
let Some(user) = channel_settings.users.get(&member.user.name) else {
error!(
"couldn't get user settings from name: {}",
&member.user.name
);
continue;
let intros = match db.lock().await.get_user_channel_intros(
&member.user.name,
channel.guild_id.0,
channel.name(),
) {
Ok(intros) => intros,
Err(err) => {
error!(
?err,
"failed to get user channel intros when playing sound through bot"
);
continue;
}
};
// TODO: randomly choose a intro to play
let Some(intro) = user.intros.first() else {
let Some(intro) = intros.first() else {
error!("couldn't get user intro, none exist");
continue;
};
let source = match guild_settings.intros.get(&intro.index) {
Some(Intro::Online(intro)) => match songbird::ytdl(&intro.url).await {
Ok(source) => source,
Err(err) => {
error!("Error starting youtube source from {}: {err:?}", intro.url);
continue;
}
},
Some(Intro::File(intro)) => {
match songbird::ffmpeg(format!("sounds/{}", &intro.filename)).await {
Ok(source) => source,
Err(err) => {
error!(
"Error starting file source from {}: {err:?}",
intro.filename
);
continue;
}
}
}
None => {
let source = match songbird::ffmpeg(format!("sounds/{}", &intro.filename)).await
{
Ok(source) => source,
Err(err) => {
error!(
"Failed to find intro for user {} on guild {} in channel {}, IntroIndex: {}",
member.user.name,
channel.guild_id.as_u64(),
channel.name(),
intro.index
);
"Error starting file source from {}: {err:?}",
intro.filename
);
continue;
}
};
@ -346,23 +309,23 @@ async fn main() -> std::io::Result<()> {
&std::fs::read_to_string("config/settings.json").expect("no config/settings.json"),
)
.expect("error parsing settings file");
let (run_api, run_bot) = (settings.run_api, settings.run_bot);
info!("{settings:?}");
let settings = Arc::new(Mutex::new(settings));
let (run_api, run_bot) = (settings.run_api, settings.run_bot);
let db = Arc::new(tokio::sync::Mutex::new(
db::Database::new("./config/db.sqlite").expect("couldn't open sqlite db"),
));
if run_api {
spawn_api(settings.clone());
spawn_api(db.clone());
}
if run_bot {
spawn_bot(settings.clone()).await;
spawn_bot(db).await;
}
info!("spawned background tasks");
let _ = tokio::signal::ctrl_c().await;
settings.lock().await.save()?;
info!("Received Ctrl-C, shuttdown down.");
Ok(())

View File

@ -1,12 +1,14 @@
use crate::{
auth::{self, User},
auth::{self},
db::{self, User},
htmx::{Build, HtmxBuilder, Tag},
settings::{ApiState, GuildSettings, Intro, IntroFriendlyName},
settings::ApiState,
};
use axum::{
extract::{Path, State},
response::{Html, Redirect},
};
use iter_tools::Itertools;
use tracing::error;
fn page_header(title: &str) -> HtmxBuilder {
@ -27,19 +29,20 @@ pub(crate) async fn home(
user: Option<User>,
) -> Result<Html<String>, Redirect> {
if let Some(user) = user {
let settings = state.settings.lock().await;
let db = state.db.lock().await;
let guild = settings
.guilds
.iter()
.filter(|(_, guild_settings)| guild_settings.users.contains_key(&user.name));
let user_guilds = db.get_user_guilds(&user.name).map_err(|err| {
error!(?err, "failed to get user guilds");
// TODO: change this to returning a error to the client
Redirect::to(&format!("{}/login", state.origin))
})?;
Ok(Html(
page_header("MemeJoin - Home")
.builder(Tag::Div, |b| {
b.attribute("class", "container")
.builder_text(Tag::Header2, "Choose a Guild")
.push_builder(guild_list(&state.origin, guild))
.push_builder(guild_list(&state.origin, user_guilds.iter()))
})
.build(),
))
@ -48,22 +51,14 @@ pub(crate) async fn home(
}
}
fn guild_list<'a>(
origin: &str,
guilds: impl Iterator<Item = (&'a u64, &'a GuildSettings)>,
) -> HtmxBuilder {
fn guild_list<'a>(origin: &str, guilds: impl Iterator<Item = &'a db::Guild>) -> HtmxBuilder {
HtmxBuilder::new(Tag::Empty).ul(|b| {
let mut b = b;
let mut in_any_guilds = false;
for (guild_id, guild_settings) in guilds {
for guild in guilds {
in_any_guilds = true;
b = b.li(|b| {
b.link(
&guild_settings.name,
&format!("{}/guild/{}", origin, guild_id),
)
});
b = b.li(|b| b.link(&guild.name, &format!("{}/guild/{}", origin, guild.id)));
}
if !in_any_guilds {
@ -75,13 +70,14 @@ fn guild_list<'a>(
}
fn intro_list<'a>(
intros: impl Iterator<Item = (&'a String, &'a Intro)>,
intros: impl Iterator<Item = &'a db::Intro>,
label: &str,
post: &str,
) -> HtmxBuilder {
HtmxBuilder::new(Tag::Empty).form(|b| {
b.attribute("class", "container")
.hx_post(post)
.hx_target("closest #channel-intro-selector")
.attribute("hx-encoding", "multipart/form-data")
.builder(Tag::FieldSet, |b| {
let mut b = b
@ -90,9 +86,10 @@ fn intro_list<'a>(
for intro in intros {
b = b.builder(Tag::Label, |b| {
b.builder(Tag::Input, |b| {
b.attribute("type", "checkbox").attribute("name", &intro.0)
b.attribute("type", "checkbox")
.attribute("name", &intro.id.to_string())
})
.builder_text(Tag::Paragraph, intro.1.friendly_name())
.builder_text(Tag::Paragraph, &intro.name)
});
}
@ -107,19 +104,34 @@ pub(crate) async fn guild_dashboard(
user: User,
Path(guild_id): Path<u64>,
) -> Result<Html<String>, Redirect> {
let settings = state.settings.lock().await;
let db = state.db.lock().await;
let Some(guild) = settings.guilds.get(&guild_id) else {
error!(%guild_id, "no such guild");
return Err(Redirect::to(&format!("{}/", state.origin)));
};
let Some(guild_user) = guild.users.get(&user.name) else {
error!(%guild_id, %user.name, "no user in guild");
return Err(Redirect::to(&format!("{}/", state.origin)));
};
let guild_intros = db.get_guild_intros(guild_id).map_err(|err| {
error!(?err, %guild_id, "couldn't get guild intros");
// TODO: change to actual error
Redirect::to(&format!("{}/login", state.origin))
})?;
let guild_channels = db.get_guild_channels(guild_id).map_err(|err| {
error!(?err, %guild_id, "couldn't get guild channels");
// TODO: change to actual error
Redirect::to(&format!("{}/login", state.origin))
})?;
let all_user_intros = db.get_all_user_intros(guild_id).map_err(|err| {
error!(?err, %guild_id, "couldn't get user intros");
// TODO: change to actual error
Redirect::to(&format!("{}/login", state.origin))
})?;
let user_permissions = db
.get_user_permissions(&user.name, guild_id)
.unwrap_or_default();
let can_upload = guild_user.permissions.can(auth::Permission::UploadSounds);
let is_moderator = guild_user.permissions.can(auth::Permission::DeleteSounds);
let user_intros = all_user_intros
.iter()
.filter(|intro| &intro.username == &user.name)
.group_by(|intro| &intro.channel_name);
let can_upload = user_permissions.can(auth::Permission::UploadSounds);
let is_moderator = user_permissions.can(auth::Permission::DeleteSounds);
Ok(Html(
HtmxBuilder::new(Tag::Html)
@ -168,53 +180,32 @@ pub(crate) async fn guild_dashboard(
.builder(Tag::Article, |b| {
let mut b = b.builder_text(Tag::Header, "Guild Intros");
for (channel_name, channel_settings) in &guild.channels {
if let Some(channel_user) = channel_settings.users.get(&user.name) {
let current_intros =
channel_user.intros.iter().filter_map(|intro_index| {
Some((
&intro_index.index,
guild.intros.get(&intro_index.index)?,
))
});
let available_intros =
guild.intros.iter().filter_map(|intro| {
if !channel_user
.intros
.iter()
.any(|intro_index| intro.0 == &intro_index.index)
{
Some((intro.0, intro.1))
} else {
None
}
});
b = b.builder(Tag::Article, |b| {
b.builder_text(Tag::Header, channel_name).builder(
Tag::Div,
|b| {
b.builder_text(Tag::Strong, "Your Current Intros")
.push_builder(intro_list(
current_intros,
"Remove Intro",
&format!(
"{}/v2/intros/remove/{}/{}",
state.origin, guild_id, channel_name
),
))
.builder_text(Tag::Strong, "Select Intros")
.push_builder(intro_list(
available_intros,
"Add Intro",
&format!(
"{}/v2/intros/add/{}/{}",
state.origin, guild_id, channel_name
),
))
},
)
});
}
let mut user_intros = user_intros.into_iter().peekable();
for guild_channel_name in guild_channels {
// Get user intros for this channel
let intros = user_intros
.peeking_take_while(|(channel_name, _)| {
channel_name == &&guild_channel_name
})
.map(|(_, intros)| intros.map(|intro| &intro.intro))
.flatten();
b = b.builder(Tag::Article, |b| {
b.builder_text(Tag::Header, &guild_channel_name).builder(
Tag::Div,
|b| {
b.attribute("id", "channel-intro-selector")
.push_builder(channel_intro_selector(
&state.origin,
guild_id,
&guild_channel_name,
intros,
guild_intros.iter(),
))
},
)
});
}
b
@ -225,6 +216,28 @@ pub(crate) async fn guild_dashboard(
))
}
pub fn channel_intro_selector<'a>(
origin: &str,
guild_id: u64,
channel_name: &String,
intros: impl Iterator<Item = &'a db::Intro>,
guild_intros: impl Iterator<Item = &'a db::Intro>,
) -> HtmxBuilder {
HtmxBuilder::new(Tag::Empty)
.builder_text(Tag::Strong, "Your Current Intros")
.push_builder(intro_list(
intros,
"Remove Intro",
&format!("{}/v2/intros/remove/{}/{}", origin, guild_id, &channel_name),
))
.builder_text(Tag::Strong, "Select Intros")
.push_builder(intro_list(
guild_intros,
"Add Intro",
&format!("{}/v2/intros/add/{}/{}", origin, guild_id, channel_name),
))
}
fn upload_form(origin: &str, guild_id: u64) -> HtmxBuilder {
HtmxBuilder::new(Tag::Empty).form(|b| {
b.attribute("class", "container")

View File

@ -1,64 +1,25 @@
use std::{collections::HashMap, sync::Arc};
use std::collections::HashMap;
use axum::{
body::Bytes,
extract::{Multipart, Path, Query, State},
http::{HeaderMap, HeaderValue},
response::{IntoResponse, Redirect},
Form, Json,
response::{Html, IntoResponse, Redirect},
};
use axum_extra::extract::{cookie::Cookie, CookieJar};
use reqwest::{Proxy, StatusCode, Url};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use chrono::{Duration, Utc};
use reqwest::{StatusCode, Url};
use serde::{Deserialize, Deserializer};
use tracing::{error, info};
use uuid::Uuid;
use crate::{
auth::{self, User},
settings::FileIntro,
auth::{self},
db,
htmx::Build,
page,
};
use crate::{
media,
settings::{ApiState, GuildUser, Intro, IntroIndex, UserSettings},
};
#[derive(Serialize)]
pub(crate) enum IntroResponse<'a> {
Intros(&'a HashMap<String, Intro>),
NoGuildFound,
}
#[derive(Serialize)]
pub(crate) enum MeResponse<'a> {
Me(Me<'a>),
NoUserFound,
}
#[derive(Serialize)]
pub(crate) struct Me<'a> {
pub(crate) username: String,
pub(crate) guilds: Vec<MeGuild<'a>>,
}
#[derive(Serialize)]
pub(crate) struct MeGuild<'a> {
// NOTE(pcleavelin): for some reason this doesn't serialize properly if a u64
pub(crate) id: String,
pub(crate) name: String,
pub(crate) channels: Vec<MeChannel<'a>>,
pub(crate) permissions: auth::Permissions,
}
#[derive(Serialize)]
pub(crate) struct MeChannel<'a> {
pub(crate) name: String,
pub(crate) intros: &'a Vec<IntroIndex>,
}
#[derive(Deserialize)]
pub(crate) struct DeleteIntroRequest(Vec<String>);
use crate::{media, settings::ApiState};
pub(crate) async fn health() -> &'static str {
"Hello!"
@ -71,8 +32,6 @@ pub(crate) enum Error {
#[error("{0}")]
GetUser(#[from] reqwest::Error),
#[error("User doesn't exist")]
NoUserFound,
#[error("Guild doesn't exist")]
NoGuildFound,
#[error("invalid request")]
@ -89,6 +48,9 @@ pub(crate) enum Error {
YtdlTerminated,
#[error("ffmpeg terminated unsuccessfully")]
FfmpegTerminated,
#[error("database error: {0}")]
Database(#[from] rusqlite::Error),
}
impl IntoResponse for Error {
@ -100,7 +62,6 @@ impl IntoResponse for Error {
Self::GetUser(error) => (StatusCode::UNAUTHORIZED, error.to_string()).into_response(),
Self::NoGuildFound => (StatusCode::NOT_FOUND, self.to_string()).into_response(),
Self::NoUserFound => (StatusCode::NOT_FOUND, self.to_string()).into_response(),
Self::InvalidRequest => (StatusCode::BAD_REQUEST, self.to_string()).into_response(),
Self::InvalidPermission => (StatusCode::UNAUTHORIZED, self.to_string()).into_response(),
@ -111,6 +72,10 @@ impl IntoResponse for Error {
Self::YtdlTerminated | Self::FfmpegTerminated => {
(StatusCode::INTERNAL_SERVER_ERROR, self.to_string()).into_response()
}
Self::Database(error) => {
(StatusCode::INTERNAL_SERVER_ERROR, error.to_string()).into_response()
}
}
}
}
@ -122,11 +87,22 @@ struct DiscordUser {
#[derive(Deserialize)]
struct DiscordUserGuild {
pub id: String,
pub name: String,
#[serde(deserialize_with = "serde_string_as_u64")]
pub id: u64,
pub owner: bool,
}
fn serde_string_as_u64<'de, D>(deserializer: D) -> Result<u64, D::Error>
where
D: Deserializer<'de>,
{
let value = <&str as Deserialize>::deserialize(deserializer)?;
value
.parse::<u64>()
.map_err(|_| serde::de::Error::invalid_value(serde::de::Unexpected::Str(value), &"u64"))
}
pub(crate) async fn v2_auth(
State(state): State<ApiState>,
Query(params): Query<HashMap<String, String>>,
@ -182,29 +158,44 @@ pub(crate) async fn v2_auth(
.await
.map_err(|err| Error::Auth(err.to_string()))?;
let mut settings = state.settings.lock().await;
let db = state.db.lock().await;
let guilds = db.get_guilds().map_err(Error::Database)?;
let mut in_a_guild = false;
for g in settings.guilds.iter_mut() {
for guild in guilds {
let Some(discord_guild) = discord_guilds
.iter()
.find(|discord_guild| discord_guild.id == g.0.to_string())
.find(|discord_guild| discord_guild.id == guild.id)
else {
continue;
};
in_a_guild = true;
if !g.1.users.contains_key(&user.username) {
g.1.users.insert(
user.username.clone(),
GuildUser {
permissions: if discord_guild.owner {
auth::Permissions(auth::Permission::all())
} else {
Default::default()
},
let now = Utc::now().naive_utc();
db.insert_user(
&user.username,
&token,
now + Duration::weeks(4),
&auth.access_token,
now + Duration::seconds(auth.expires_in as i64),
)
.map_err(Error::Database)?;
db.insert_user_guild(&user.username, guild.id)
.map_err(Error::Database)?;
if db.get_user_permissions(&user.username, guild.id).is_err() {
db.insert_user_permission(
&user.username,
guild.id,
if discord_guild.owner {
auth::Permissions(auth::Permission::all())
} else {
Default::default()
},
);
)
.map_err(Error::Database)?;
}
}
@ -212,13 +203,6 @@ pub(crate) async fn v2_auth(
return Err(Error::NoGuildFound);
}
settings.auth_users.insert(
token.clone(),
auth::User {
auth,
name: user.username.clone(),
},
);
// TODO: add permissions based on roles
let uri = Url::parse(&state.origin).expect("should be a valid url");
@ -230,404 +214,146 @@ pub(crate) async fn v2_auth(
Ok((jar.add(cookie), Redirect::to(&format!("{}/", state.origin))))
}
pub(crate) async fn auth(
State(state): State<ApiState>,
Query(params): Query<HashMap<String, String>>,
) -> Result<Json<Value>, Error> {
let Some(code) = params.get("code") else {
return Err(Error::Auth("no code".to_string()));
};
info!("attempting to get access token with code {}", code);
let mut data = HashMap::new();
let redirect_uri = format!("{}/old/auth", state.origin);
data.insert("client_id", state.secrets.client_id.as_str());
data.insert("client_secret", state.secrets.client_secret.as_str());
data.insert("grant_type", "authorization_code");
data.insert("code", code);
data.insert("redirect_uri", &redirect_uri);
let client = reqwest::Client::new();
let auth: auth::Discord = client
.post("https://discord.com/api/oauth2/token")
.form(&data)
.send()
.await
.map_err(|err| Error::Auth(err.to_string()))?
.json()
.await
.map_err(|err| Error::Auth(err.to_string()))?;
let token = Uuid::new_v4().to_string();
// Get authorized username
let user: DiscordUser = client
.get("https://discord.com/api/v10/users/@me")
.bearer_auth(&auth.access_token)
.send()
.await?
.json()
.await?;
// TODO: get bot's guilds so we only save users who are able to use the bot
let discord_guilds: Vec<DiscordUserGuild> = client
.get("https://discord.com/api/v10/users/@me/guilds")
.bearer_auth(&auth.access_token)
.send()
.await?
.json()
.await
.map_err(|err| Error::Auth(err.to_string()))?;
let mut settings = state.settings.lock().await;
let mut in_a_guild = false;
for g in settings.guilds.iter_mut() {
let Some(discord_guild) = discord_guilds
.iter()
.find(|discord_guild| discord_guild.id == g.0.to_string())
else {
continue;
};
in_a_guild = true;
if !g.1.users.contains_key(&user.username) {
g.1.users.insert(
user.username.clone(),
GuildUser {
permissions: if discord_guild.owner {
auth::Permissions(auth::Permission::all())
} else {
Default::default()
},
},
);
}
}
if !in_a_guild {
return Err(Error::NoGuildFound);
}
settings.auth_users.insert(
token.clone(),
auth::User {
auth,
name: user.username.clone(),
},
);
// TODO: add permissions based on roles
Ok(Json(json!({"token": token, "username": user.username})))
}
pub(crate) async fn v2_add_intro_to_user(
State(state): State<ApiState>,
Path((guild_id, channel)): Path<(u64, String)>,
user: User,
user: db::User,
mut form_data: Multipart,
) -> HeaderMap {
let mut headers = HeaderMap::new();
headers.insert("HX-Refresh", HeaderValue::from_static("true"));
let mut settings = state.settings.lock().await;
let Some(guild) = settings.guilds.get_mut(&guild_id) else {
return headers;
};
let Some(channel) = guild.channels.get_mut(&channel) else {
return headers;
};
let Some(channel_user) = channel.users.get_mut(&user.name) else {
return headers;
};
) -> Result<Html<String>, Redirect> {
let db = state.db.lock().await;
while let Ok(Some(field)) = form_data.next_field().await {
let Some(field_name) = field.name() else {
let Some(intro_id) = field.name() else {
continue;
};
if !channel_user
.intros
.iter()
.any(|intro| intro.index == field_name)
{
channel_user.intros.push(IntroIndex {
index: field_name.to_string(),
volume: 20,
});
}
let intro_id = intro_id.parse::<i32>().map_err(|err| {
error!(?err, "invalid intro id");
// TODO: change to actual error
Redirect::to(&format!("{}/login", state.origin))
})?;
db.insert_user_intro(&user.name, guild_id, &channel, intro_id)
.map_err(|err| {
error!(?err, "failed to add user intro");
// TODO: change to actual error
Redirect::to(&format!("{}/login", state.origin))
})?;
}
// TODO: don't save on every change
if let Err(err) = settings.save() {
error!("Failed to save config: {err:?}");
}
let guild_intros = db.get_guild_intros(guild_id).map_err(|err| {
error!(?err, %guild_id, "couldn't get guild intros");
// TODO: change to actual error
Redirect::to(&format!("{}/login", state.origin))
})?;
headers
let intros = db
.get_user_channel_intros(&user.name, guild_id, &channel)
.map_err(|err| {
error!(?err, user = %user.name, %guild_id, "couldn't get user intros");
// TODO: change to actual error
Redirect::to(&format!("{}/login", state.origin))
})?;
Ok(Html(
page::channel_intro_selector(
&state.origin,
guild_id,
&channel,
intros.iter(),
guild_intros.iter(),
)
.build(),
))
}
pub(crate) async fn v2_remove_intro_from_user(
State(state): State<ApiState>,
Path((guild_id, channel)): Path<(u64, String)>,
user: User,
user: db::User,
mut form_data: Multipart,
) -> HeaderMap {
let mut headers = HeaderMap::new();
headers.insert("HX-Refresh", HeaderValue::from_static("true"));
let mut settings = state.settings.lock().await;
let Some(guild) = settings.guilds.get_mut(&guild_id) else {
return headers;
};
let Some(channel) = guild.channels.get_mut(&channel) else {
return headers;
};
let Some(channel_user) = channel.users.get_mut(&user.name) else {
return headers;
};
) -> Result<Html<String>, Redirect> {
let db = state.db.lock().await;
while let Ok(Some(field)) = form_data.next_field().await {
let Some(field_name) = field.name() else {
let Some(intro_id) = field.name() else {
continue;
};
if let Some(index) = channel_user
.intros
.iter()
.position(|intro| intro.index == field_name)
{
channel_user.intros.remove(index);
}
let intro_id = intro_id.parse::<i32>().map_err(|err| {
error!(?err, "invalid intro id");
// TODO: change to actual error
Redirect::to(&format!("{}/login", state.origin))
})?;
db.delete_user_intro(&user.name, guild_id, &channel, intro_id)
.map_err(|err| {
error!(?err, "failed to remove user intro");
// TODO: change to actual error
Redirect::to(&format!("{}/login", state.origin))
})?;
}
// TODO: don't save on every change
if let Err(err) = settings.save() {
error!("Failed to save config: {err:?}");
}
let guild_intros = db.get_guild_intros(guild_id).map_err(|err| {
error!(?err, %guild_id, "couldn't get guild intros");
// TODO: change to actual error
Redirect::to(&format!("{}/login", state.origin))
})?;
headers
}
let intros = db
.get_user_channel_intros(&user.name, guild_id, &channel)
.map_err(|err| {
error!(?err, user = %user.name, %guild_id, "couldn't get user intros");
// TODO: change to actual error
Redirect::to(&format!("{}/login", state.origin))
})?;
pub(crate) async fn add_intro_to_user(
State(state): State<ApiState>,
headers: HeaderMap,
Path((guild, channel, intro_index)): Path<(u64, String, String)>,
) {
let mut settings = state.settings.lock().await;
let Some(token) = headers.get("token").and_then(|v| v.to_str().ok()) else {
return;
};
let user = match settings.auth_users.get(token) {
Some(user) => user.name.clone(),
None => return,
};
let Some(guild) = settings.guilds.get_mut(&guild) else {
return;
};
let Some(channel) = guild.channels.get_mut(&channel) else {
return;
};
let Some(user) = channel.users.get_mut(&user) else {
return;
};
if !user.intros.iter().any(|intro| intro.index == intro_index) {
user.intros.push(IntroIndex {
index: intro_index,
volume: 20,
});
// TODO: don't save on every change
if let Err(err) = settings.save() {
error!("Failed to save config: {err:?}");
}
}
}
pub(crate) async fn remove_intro_to_user(
State(state): State<ApiState>,
headers: HeaderMap,
Path((guild, channel, intro_index)): Path<(u64, String, String)>,
) {
let mut settings = state.settings.lock().await;
let Some(token) = headers.get("token").and_then(|v| v.to_str().ok()) else {
return;
};
let user = match settings.auth_users.get(token) {
Some(user) => user.name.clone(),
None => return,
};
let Some(guild) = settings.guilds.get_mut(&guild) else {
return;
};
let Some(channel) = guild.channels.get_mut(&channel) else {
return;
};
let Some(user) = channel.users.get_mut(&user) else {
return;
};
if let Some(index) = user
.intros
.iter()
.position(|intro| intro_index == intro.index)
{
user.intros.remove(index);
}
// TODO: don't save on every change
if let Err(err) = settings.save() {
error!("Failed to save config: {err:?}");
}
}
pub(crate) async fn intros(State(state): State<ApiState>, Path(guild): Path<u64>) -> Json<Value> {
let settings = state.settings.lock().await;
let Some(guild) = settings.guilds.get(&guild) else {
return Json(json!(IntroResponse::NoGuildFound));
};
Json(json!(IntroResponse::Intros(&guild.intros)))
}
pub(crate) async fn me(
State(state): State<ApiState>,
headers: HeaderMap,
) -> Result<Json<Value>, Error> {
let mut settings = state.settings.lock().await;
let Some(token) = headers.get("token").and_then(|v| v.to_str().ok()) else {
return Err(Error::NoUserFound);
};
let (username, access_token) = match settings.auth_users.get(token) {
Some(user) => (user.name.clone(), user.auth.access_token.clone()),
None => return Err(Error::NoUserFound),
};
let mut me = Me {
username: username.clone(),
guilds: Vec::new(),
};
for g in settings.guilds.iter_mut() {
// TODO: don't do this n^2 lookup
let guild_user =
g.1.users
// TODO: why must clone
.entry(username.clone())
// TODO: check if owner for permissions
.or_insert(Default::default());
let mut guild = MeGuild {
id: g.0.to_string(),
name: g.1.name.clone(),
channels: Vec::new(),
permissions: guild_user.permissions,
};
for channel in g.1.channels.iter_mut() {
let user_settings = channel
.1
.users
.entry(username.clone())
.or_insert(UserSettings { intros: Vec::new() });
guild.channels.push(MeChannel {
name: channel.0.to_owned(),
intros: &user_settings.intros,
});
}
me.guilds.push(guild);
}
if me.guilds.is_empty() {
Ok(Json(json!(MeResponse::NoUserFound)))
} else {
Ok(Json(json!(MeResponse::Me(me))))
}
}
pub(crate) async fn upload_guild_intro(
State(state): State<ApiState>,
Path(guild): Path<u64>,
Query(mut params): Query<HashMap<String, String>>,
headers: HeaderMap,
file: Bytes,
) -> Result<(), Error> {
let mut settings = state.settings.lock().await;
let Some(token) = headers.get("token").and_then(|v| v.to_str().ok()) else {
return Err(Error::NoUserFound);
};
let Some(friendly_name) = params.remove("name") else {
return Err(Error::InvalidRequest);
};
{
let Some(guild) = settings.guilds.get(&guild) else {
return Err(Error::NoGuildFound);
};
let auth_user = match settings.auth_users.get(token) {
Some(user) => user,
None => return Err(Error::NoUserFound),
};
let Some(guild_user) = guild.users.get(&auth_user.name) else {
return Err(Error::NoUserFound);
};
if !guild_user.permissions.can(auth::Permission::UploadSounds) {
return Err(Error::InvalidPermission);
}
}
let Some(guild) = settings.guilds.get_mut(&guild) else {
return Err(Error::NoGuildFound);
};
let uuid = Uuid::new_v4().to_string();
let temp_path = format!("./sounds/temp/{uuid}");
let dest_path = format!("./sounds/{uuid}.mp3");
// Write original file so its ready for codec conversion
std::fs::write(&temp_path, file)?;
media::normalize(&temp_path, &dest_path).await?;
std::fs::remove_file(&temp_path)?;
guild.intros.insert(
uuid.clone(),
Intro::File(FileIntro {
filename: format!("{uuid}.mp3"),
friendly_name,
}),
);
Ok(())
Ok(Html(
page::channel_intro_selector(
&state.origin,
guild_id,
&channel,
intros.iter(),
guild_intros.iter(),
)
.build(),
))
}
pub(crate) async fn v2_upload_guild_intro(
State(state): State<ApiState>,
Path(guild): Path<u64>,
user: User,
Path(guild_id): Path<u64>,
user: db::User,
mut form_data: Multipart,
) -> Result<HeaderMap, Error> {
let mut settings = state.settings.lock().await;
let mut friendly_name = None;
let db = state.db.lock().await;
let mut name = None;
let mut file = None;
if !db
.get_guilds()
.map_err(Error::Database)?
.into_iter()
.any(|guild| guild.id == guild_id)
{
return Err(Error::NoGuildFound);
}
let user_permissions = db
.get_user_permissions(&user.name, guild_id)
.map_err(Error::Database)?;
if !user_permissions.can(auth::Permission::UploadSounds) {
return Err(Error::InvalidPermission);
}
while let Ok(Some(field)) = form_data.next_field().await {
let Some(field_name) = field.name() else {
continue;
};
if field_name.eq_ignore_ascii_case("name") {
friendly_name = Some(field.text().await.map_err(|_| Error::InvalidRequest)?);
name = Some(field.text().await.map_err(|_| Error::InvalidRequest)?);
continue;
}
@ -637,29 +363,13 @@ pub(crate) async fn v2_upload_guild_intro(
}
}
let Some(friendly_name) = friendly_name else {
let Some(name) = name else {
return Err(Error::InvalidRequest);
};
let Some(file) = file else {
return Err(Error::InvalidRequest);
};
{
let Some(guild) = settings.guilds.get(&guild) else {
return Err(Error::NoGuildFound);
};
let Some(guild_user) = guild.users.get(&user.name) else {
return Err(Error::NoUserFound);
};
if !guild_user.permissions.can(auth::Permission::UploadSounds) {
return Err(Error::InvalidPermission);
}
}
let Some(guild) = settings.guilds.get_mut(&guild) else {
return Err(Error::NoGuildFound);
};
let uuid = Uuid::new_v4().to_string();
let temp_path = format!("./sounds/temp/{uuid}");
let dest_path = format!("./sounds/{uuid}.mp3");
@ -669,114 +379,45 @@ pub(crate) async fn v2_upload_guild_intro(
media::normalize(&temp_path, &dest_path).await?;
std::fs::remove_file(&temp_path)?;
guild.intros.insert(
uuid.clone(),
Intro::File(FileIntro {
filename: format!("{uuid}.mp3"),
friendly_name,
}),
);
db.insert_intro(&name, 0, guild_id, &format!("{uuid}.mp3"))
.map_err(Error::Database)?;
let mut headers = HeaderMap::new();
headers.insert("HX-Refresh", HeaderValue::from_static("true"));
Ok(headers)
}
pub(crate) async fn add_guild_intro(
State(state): State<ApiState>,
Path(guild): Path<u64>,
Query(mut params): Query<HashMap<String, String>>,
headers: HeaderMap,
) -> Result<(), Error> {
let mut settings = state.settings.lock().await;
// TODO: make this an impl on HeaderMap
let Some(token) = headers.get("token").and_then(|v| v.to_str().ok()) else {
return Err(Error::NoUserFound);
};
let Some(url) = params.remove("url") else {
return Err(Error::InvalidRequest);
};
let Some(friendly_name) = params.remove("name") else {
return Err(Error::InvalidRequest);
};
{
let Some(guild) = settings.guilds.get(&guild) else {
return Err(Error::NoGuildFound);
};
let auth_user = match settings.auth_users.get(token) {
Some(user) => user,
None => return Err(Error::NoUserFound),
};
let Some(guild_user) = guild.users.get(&auth_user.name) else {
return Err(Error::NoUserFound);
};
if !guild_user.permissions.can(auth::Permission::UploadSounds) {
return Err(Error::InvalidPermission);
}
}
let Some(guild) = settings.guilds.get_mut(&guild) else {
return Err(Error::NoGuildFound);
};
let uuid = Uuid::new_v4().to_string();
let child = tokio::process::Command::new("yt-dlp")
.arg(&url)
.args(["-o", &format!("sounds/{uuid}")])
.args(["-x", "--audio-format", "mp3"])
.spawn()
.map_err(Error::Ytdl)?
.wait()
.await
.map_err(Error::Ytdl)?;
if !child.success() {
return Err(Error::YtdlTerminated);
}
guild.intros.insert(
uuid.clone(),
Intro::File(FileIntro {
filename: format!("{uuid}.mp3"),
friendly_name,
}),
);
Ok(())
}
pub(crate) async fn v2_add_guild_intro(
State(state): State<ApiState>,
Path(guild): Path<u64>,
Path(guild_id): Path<u64>,
Query(mut params): Query<HashMap<String, String>>,
user: User,
user: db::User,
) -> Result<HeaderMap, Error> {
let mut settings = state.settings.lock().await;
let db = state.db.lock().await;
let Some(url) = params.remove("url") else {
return Err(Error::InvalidRequest);
};
let Some(friendly_name) = params.remove("name") else {
let Some(name) = params.remove("name") else {
return Err(Error::InvalidRequest);
};
if !db
.get_guilds()
.map_err(Error::Database)?
.into_iter()
.any(|guild| guild.id == guild_id)
{
let Some(guild) = settings.guilds.get(&guild) else {
return Err(Error::NoGuildFound);
};
let Some(guild_user) = guild.users.get(&user.name) else {
return Err(Error::NoUserFound);
};
if !guild_user.permissions.can(auth::Permission::UploadSounds) {
return Err(Error::InvalidPermission);
}
return Err(Error::NoGuildFound);
}
let Some(guild) = settings.guilds.get_mut(&guild) else {
return Err(Error::NoGuildFound);
};
let user_permissions = db
.get_user_permissions(&user.name, guild_id)
.map_err(Error::Database)?;
if !user_permissions.can(auth::Permission::UploadSounds) {
return Err(Error::InvalidPermission);
}
let uuid = Uuid::new_v4().to_string();
let child = tokio::process::Command::new("yt-dlp")
@ -793,64 +434,11 @@ pub(crate) async fn v2_add_guild_intro(
return Err(Error::YtdlTerminated);
}
guild.intros.insert(
uuid.clone(),
Intro::File(FileIntro {
filename: format!("{uuid}.mp3"),
friendly_name,
}),
);
db.insert_intro(&name, 0, guild_id, &format!("{uuid}.mp3"))
.map_err(Error::Database)?;
let mut headers = HeaderMap::new();
headers.insert("HX-Refresh", HeaderValue::from_static("true"));
Ok(headers)
}
pub(crate) async fn delete_guild_intro(
State(state): State<ApiState>,
Path(guild): Path<u64>,
headers: HeaderMap,
Json(body): Json<DeleteIntroRequest>,
) -> Result<(), Error> {
let mut settings = state.settings.lock().await;
// TODO: make this an impl on HeaderMap
let Some(token) = headers.get("token").and_then(|v| v.to_str().ok()) else {
return Err(Error::NoUserFound);
};
{
let Some(guild) = settings.guilds.get(&guild) else {
return Err(Error::NoGuildFound);
};
let auth_user = match settings.auth_users.get(token) {
Some(user) => user,
None => return Err(Error::NoUserFound),
};
let Some(guild_user) = guild.users.get(&auth_user.name) else {
return Err(Error::NoUserFound);
};
if !guild_user.permissions.can(auth::Permission::DeleteSounds) {
return Err(Error::InvalidPermission);
}
}
let Some(guild) = settings.guilds.get_mut(&guild) else {
return Err(Error::NoGuildFound);
};
// Remove intro from any users
for channel in guild.channels.iter_mut() {
for user in channel.1.users.iter_mut() {
user.1
.intros
.retain(|user_intro| !body.0.iter().any(|intro| &user_intro.index == intro));
}
}
for intro in &body.0 {
guild.intros.remove(intro);
}
Ok(())
}

View File

@ -1,25 +1,26 @@
use std::{collections::HashMap, sync::Arc};
use std::sync::Arc;
use crate::auth;
use crate::{
auth,
db::{self, Database},
};
use axum::{async_trait, extract::FromRequestParts, http::request::Parts, response::Redirect};
use axum_extra::extract::CookieJar;
use chrono::Utc;
use serde::{Deserialize, Serialize};
use serenity::prelude::TypeMapKey;
use tracing::trace;
use uuid::Uuid;
type UserToken = String;
use tracing::error;
// TODO: make this is wrapped type so cloning isn't happening
#[derive(Clone)]
pub(crate) struct ApiState {
pub settings: Arc<tokio::sync::Mutex<Settings>>,
pub db: Arc<tokio::sync::Mutex<Database>>,
pub secrets: auth::DiscordSecret,
pub origin: String,
}
#[async_trait]
impl FromRequestParts<ApiState> for crate::auth::User {
impl FromRequestParts<ApiState> for db::User {
type Rejection = Redirect;
async fn from_request_parts(
@ -29,13 +30,23 @@ impl FromRequestParts<ApiState> for crate::auth::User {
let jar = CookieJar::from_headers(&headers);
if let Some(token) = jar.get("access_token") {
match state.settings.lock().await.auth_users.get(token.value()) {
// :vomit:
Some(user) => Ok(user.clone()),
None => Err(Redirect::to("/login")),
match state.db.lock().await.get_user_from_api_key(token.value()) {
Ok(user) => {
let now = Utc::now().naive_utc();
if user.api_key_expires_at < now || user.discord_token_expires_at < now {
Err(Redirect::to(&format!("{}/login", state.origin)))
} else {
Ok(user)
}
}
Err(err) => {
error!(?err, "failed to authenticate user");
Err(Redirect::to(&format!("{}/login", state.origin)))
}
}
} else {
Err(Redirect::to("/login"))
Err(Redirect::to(&format!("{}/login", state.origin)))
}
}
}
@ -47,116 +58,7 @@ pub(crate) struct Settings {
pub(crate) run_api: bool,
#[serde(default)]
pub(crate) run_bot: bool,
pub(crate) guilds: HashMap<u64, GuildSettings>,
#[serde(default)]
pub(crate) auth_users: HashMap<UserToken, auth::User>,
}
impl TypeMapKey for Settings {
type Value = Arc<Settings>;
}
impl Settings {
pub(crate) fn save(&self) -> Result<(), std::io::Error> {
trace!("attempting to save config");
let serialized = serde_json::to_string_pretty(&self)?;
std::fs::copy(
"./config/settings.json",
format!(
"./config/{}-settings.json.old",
chrono::Utc::now().naive_utc().format("%Y-%m-%d %H:%M:%S")
),
)?;
trace!("created copy of original settings");
std::fs::write("./config/settings.json", serialized)?;
trace!("saved settings to disk");
Ok(())
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct GuildSettings {
pub(crate) name: String,
pub(crate) sound_delay: u64,
#[serde(default)]
pub(crate) channels: HashMap<String, ChannelSettings>,
#[serde(default)]
pub(crate) intros: HashMap<String, Intro>,
#[serde(default)]
pub(crate) users: HashMap<String, GuildUser>,
}
#[derive(Default, Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct GuildUser {
pub(crate) permissions: auth::Permissions,
}
pub(crate) trait IntroFriendlyName {
fn friendly_name(&self) -> &str;
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) enum Intro {
File(FileIntro),
Online(OnlineIntro),
}
impl IntroFriendlyName for Intro {
fn friendly_name(&self) -> &str {
match self {
Self::File(intro) => intro.friendly_name(),
Self::Online(intro) => intro.friendly_name(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct FileIntro {
pub(crate) filename: String,
pub(crate) friendly_name: String,
}
impl IntroFriendlyName for FileIntro {
fn friendly_name(&self) -> &str {
&self.friendly_name
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct OnlineIntro {
pub(crate) url: String,
pub(crate) friendly_name: String,
}
impl IntroFriendlyName for OnlineIntro {
fn friendly_name(&self) -> &str {
&self.friendly_name
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct ChannelSettings {
#[serde(alias = "enterUsers")]
pub(crate) users: HashMap<String, UserSettings>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct IntroIndex {
pub(crate) index: String,
pub(crate) volume: i32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct UserSettings {
pub(crate) intros: Vec<IntroIndex>,
}