diff --git a/Cargo.toml b/Cargo.toml index 984f717..88e29c7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,6 @@ dotenv = "0.15.0" futures = "0.3.26" iter_tools = "0.1.4" reqwest = "0.11.14" -rusqlite = { version = "0.29.0", features = ["bundled", "chrono"] } serde = "1.0.152" serde_json = "1.0.93" thiserror = "1.0.38" @@ -32,3 +31,9 @@ features = ["client", "gateway", "rustls_backend", "model", "cache", "voice"] [dependencies.songbird] version = "0.3.2" features = [ "builtin-queue", "yt-dlp" ] + +[target.'cfg(unix)'.dependencies] +rusqlite = { version = "0.29.0", features = ["chrono"] } + +[target.'cfg(windows)'.dependencies] +rusqlite = { version = "0.29.0", features = ["bundled", "chrono"] } diff --git a/flake.nix b/flake.nix index a962ccb..c0d424f 100644 --- a/flake.nix +++ b/flake.nix @@ -35,6 +35,7 @@ pkg-config gcc openssl + sqlite pkg-config python3 ffmpeg @@ -50,7 +51,7 @@ name = "memejoin-rs"; src = self; buildInputs = [ openssl.dev ]; - nativeBuildInputs = [ local-rust pkg-config openssl openssl.dev cmake gcc libopus ]; + nativeBuildInputs = [ local-rust pkg-config openssl openssl.dev cmake gcc libopus sqlite ]; cargoLock = { lockFile = ./Cargo.lock; @@ -62,7 +63,7 @@ name = "memejoin-rs"; copyToRoot = buildEnv { name = "image-root"; - paths = [ default cacert openssl openssl.dev ffmpeg libopus youtube-dl yt-dlp ]; + paths = [ default cacert openssl openssl.dev ffmpeg libopus youtube-dl yt-dlp sqlite ]; }; runAsRoot = '' #!${runtimeShell} diff --git a/src/db.rs b/src/db/mod.rs similarity index 91% rename from src/db.rs rename to src/db/mod.rs index 9a60249..3d96aff 100644 --- a/src/db.rs +++ b/src/db/mod.rs @@ -1,10 +1,9 @@ use std::path::Path; use chrono::NaiveDateTime; -use iter_tools::Itertools; use rusqlite::{Connection, OptionalExtension, Result}; use serde::{Deserialize, Serialize}; -use tracing::{error, warn}; +use tracing::warn; use crate::auth; @@ -23,7 +22,7 @@ impl Database { let mut query = self.conn.prepare( " SELECT - id, name, soundDelay + id, name, sound_delay FROM Guild ", )?; @@ -92,7 +91,7 @@ impl Database { let mut query = self.conn.prepare( " SELECT - id, name, soundDelay + id, name, sound_delay FROM Guild LEFT JOIN UserGuild ON UserGuild.guild_id = Guild.id WHERE UserGuild.username = :username @@ -120,7 +119,8 @@ impl Database { " SELECT Intro.id, - Intro.name + Intro.name, + Intro.filename FROM Intro WHERE Intro.guild_id = :guild_id @@ -139,6 +139,7 @@ impl Database { Ok(Intro { id: row.get(0)?, name: row.get(1)?, + filename: row.get(2)?, }) }, )? @@ -154,6 +155,7 @@ impl Database { SELECT Intro.id, Intro.name, + Intro.filename, UI.channel_name, UI.username FROM Intro @@ -177,9 +179,10 @@ impl Database { intro: Intro { id: row.get(0)?, name: row.get(1)?, + filename: row.get(2)?, }, - channel_name: row.get(2)?, - username: row.get(3)?, + channel_name: row.get(3)?, + username: row.get(4)?, }) }, )? @@ -252,7 +255,7 @@ impl Database { Ok(intros) } - pub fn add_user( + pub fn insert_user( &self, username: &str, api_key: &str, @@ -281,6 +284,27 @@ impl Database { Ok(()) } + pub fn insert_intro( + &self, + name: &str, + volume: i32, + guild_id: u64, + filename: &str, + ) -> Result<()> { + let affected = self.conn.execute( + "INSERT INTO + Intro (name, volume, guild_id, filename) + VALUES (?1, ?2, ?3, ?4)", + &[name, &volume.to_string(), &guild_id.to_string(), filename], + )?; + + if affected < 1 { + warn!("no rows affected when attempting to insert intro"); + } + + Ok(()) + } + pub fn insert_user_guild(&self, username: &str, guild_id: u64) -> Result<()> { let affected = self.conn.execute( "INSERT OR IGNORE INTO UserGuild (username, guild_id) VALUES (?1, ?2)", @@ -340,7 +364,7 @@ impl Database { Ok(()) } - pub fn remove_user_intro( + pub fn delete_user_intro( &self, username: &str, guild_id: u64, @@ -372,7 +396,7 @@ impl Database { } pub struct Guild { - pub id: String, + pub id: u64, pub name: String, pub sound_delay: u32, } @@ -389,6 +413,7 @@ pub struct User { pub struct Intro { pub id: i32, pub name: String, + pub filename: String, } pub struct UserIntro { diff --git a/src/db/schema.sql b/src/db/schema.sql new file mode 100644 index 0000000..d38ef1e --- /dev/null +++ b/src/db/schema.sql @@ -0,0 +1,84 @@ +BEGIN; + +create table User +( + username TEXT not null + constraint User_pk + primary key, + api_key TEXT not null, + api_key_expires_at DATETIME not null, + discord_token TEXT not null, + discord_token_expires_at DATETIME not null +); + +create table Intro +( + id integer not null + constraint Intro_pk + primary key autoincrement, + name TEXT not null, + volume integer not null, + guild_id integer not null + constraint Intro_Guild_guild_id_fk + references Guild ("id"), + filename TEXT not null +); + +create table Guild +( + id integer not null + primary key, + name TEXT not null, + sound_delay integer not null +); + +create table Channel +( + name TEXT + primary key, + guild_id integer + constraint Channel_Guild_id_fk + references Guild (id) +); + +create table UserGuild +( + username TEXT not null + constraint UserGuild_User_username_fk + references User, + guild_id integer not null + constraint UserGuild_Guild_id_fk + references Guild (id), + primary key ("username", "guild_id") +); + +create table UserIntro +( + username text not null + constraint UserIntro_User_username_fk + references User, + intro_id integer not null + constraint UserIntro_Intro_id_fk + references Intro, + guild_id integer not null + constraint UserIntro_Guild_guild_id_fk + references Guild ("id"), + channel_name text not null + constraint UserIntro_Channel_channel_name_fk + references Channel ("name"), + primary key ("username", "intro_id", "guild_id", "channel_name") +); + +create table UserPermission +( + username TEXT not null + constraint UserPermission_User_username_fk + references User, + guild_id integer not null + constraint User_Guild_guild_id_fk + references Guild ("id"), + permissions integer not null, + primary key ("username", "guild_id") +); + +COMMIT; diff --git a/src/main.rs b/src/main.rs index bc3c601..6730a07 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,20 +10,16 @@ mod page; mod routes; pub mod settings; -use axum::http::{HeaderValue, Method}; -use axum::routing::{delete, get, post}; +use axum::http::Method; +use axum::routing::{get, post}; use axum::Router; -use futures::StreamExt; use settings::ApiState; -use songbird::tracks::TrackQueue; -use std::collections::HashMap; use std::env; use std::net::SocketAddr; use std::sync::Arc; use tokio::sync::mpsc; use tower_http::cors::{Any, CorsLayer}; -use serde::Deserialize; use serenity::async_trait; use serenity::model::prelude::{Channel, ChannelId, GuildId, Member, Ready}; use serenity::model::voice::VoiceState; @@ -32,7 +28,7 @@ use serenity::prelude::*; use songbird::SerenityInit; use tracing::*; -use crate::settings::{Intro, Settings}; +use crate::settings::Settings; enum HandlerMessage { Ready(Context), @@ -120,7 +116,7 @@ impl EventHandler for Handler { } } -fn spawn_api(settings: Arc>) { +fn spawn_api(db: Arc>) { let secrets = auth::DiscordSecret { client_id: env::var("DISCORD_CLIENT_ID").expect("expected DISCORD_CLIENT_ID env var"), client_secret: env::var("DISCORD_CLIENT_SECRET") @@ -129,10 +125,7 @@ fn spawn_api(settings: Arc>) { let origin = env::var("APP_ORIGIN").expect("expected APP_ORIGIN"); let state = ApiState { - db: Arc::new(tokio::sync::Mutex::new( - db::Database::new("db.sqlite").expect("couldn't open sqlite db"), - )), - settings, + db, secrets, origin: origin.clone(), }; @@ -158,23 +151,8 @@ fn spawn_api(settings: Arc>) { post(routes::v2_upload_guild_intro), ) .route("/health", get(routes::health)) - .route("/me", get(routes::me)) - .route("/intros/:guild", get(routes::intros)) - .route("/intros/:guild/add", get(routes::add_guild_intro)) - .route("/intros/:guild/upload", post(routes::upload_guild_intro)) - .route("/intros/:guild/delete", delete(routes::delete_guild_intro)) - .route( - "/intros/:guild/:channel/:intro", - post(routes::add_intro_to_user), - ) - .route( - "/intros/:guild/:channel/:intro/remove", - post(routes::remove_intro_to_user), - ) - .route("/auth", get(routes::auth)) .layer( CorsLayer::new() - // TODO: move this to env variable .allow_origin([origin.parse().unwrap()]) .allow_headers(Any) .allow_methods([Method::GET, Method::POST, Method::DELETE]), @@ -189,7 +167,7 @@ fn spawn_api(settings: Arc>) { }); } -async fn spawn_bot(settings: Arc>) { +async fn spawn_bot(db: Arc>) { let token = env::var("DISCORD_TOKEN").expect("expected DISCORD_TOKEN env var"); let songbird = songbird::Songbird::serenity(); @@ -219,12 +197,19 @@ async fn spawn_bot(settings: Arc>) { match msg { HandlerMessage::Ready(ctx) => { info!("Got Ready message"); - let settings = settings.lock().await; let songbird = songbird::get(&ctx).await.expect("no songbird instance"); - for guild_id in settings.guilds.keys() { - let handler_lock = songbird.get_or_insert(GuildId(*guild_id)); + let guilds = match db.lock().await.get_guilds() { + Ok(guilds) => guilds, + Err(err) => { + error!(?err, "failed to get guild on bot ready"); + continue; + } + }; + + for guild in guilds { + let handler_lock = songbird.get_or_insert(GuildId(guild.id)); let mut handler = handler_lock.lock().await; @@ -232,7 +217,7 @@ async fn spawn_bot(settings: Arc>) { songbird::Event::Track(songbird::TrackEvent::End), TrackEventHandler { tx: tx.clone(), - guild_id: GuildId(*guild_id), + guild_id: GuildId(guild.id), }, ); } @@ -255,7 +240,6 @@ async fn spawn_bot(settings: Arc>) { HandlerMessage::PlaySound(ctx, member, channel_id) => { info!("Got PlaySound message"); - let settings = settings.lock().await; let Some(Channel::Guild(channel)) = channel_id.to_channel_cached(&ctx.cache) else { @@ -263,60 +247,35 @@ async fn spawn_bot(settings: Arc>) { continue; }; - let Some(guild_settings) = settings.guilds.get(channel.guild_id.as_u64()) - else { - error!("couldn't get guild from id: {}", channel.guild_id.as_u64()); - continue; - }; - let Some(channel_settings) = guild_settings.channels.get(channel.name()) else { - error!( - "couldn't get channel_settings from name: {}", - channel.name() - ); - continue; - }; - let Some(user) = channel_settings.users.get(&member.user.name) else { - error!( - "couldn't get user settings from name: {}", - &member.user.name - ); - continue; + let intros = match db.lock().await.get_user_channel_intros( + &member.user.name, + channel.guild_id.0, + channel.name(), + ) { + Ok(intros) => intros, + Err(err) => { + error!( + ?err, + "failed to get user channel intros when playing sound through bot" + ); + continue; + } }; // TODO: randomly choose a intro to play - let Some(intro) = user.intros.first() else { + let Some(intro) = intros.first() else { error!("couldn't get user intro, none exist"); continue; }; - let source = match guild_settings.intros.get(&intro.index) { - Some(Intro::Online(intro)) => match songbird::ytdl(&intro.url).await { - Ok(source) => source, - Err(err) => { - error!("Error starting youtube source from {}: {err:?}", intro.url); - continue; - } - }, - Some(Intro::File(intro)) => { - match songbird::ffmpeg(format!("sounds/{}", &intro.filename)).await { - Ok(source) => source, - Err(err) => { - error!( - "Error starting file source from {}: {err:?}", - intro.filename - ); - continue; - } - } - } - None => { + let source = match songbird::ffmpeg(format!("sounds/{}", &intro.filename)).await + { + Ok(source) => source, + Err(err) => { error!( - "Failed to find intro for user {} on guild {} in channel {}, IntroIndex: {}", - member.user.name, - channel.guild_id.as_u64(), - channel.name(), - intro.index - ); + "Error starting file source from {}: {err:?}", + intro.filename + ); continue; } }; @@ -350,23 +309,23 @@ async fn main() -> std::io::Result<()> { &std::fs::read_to_string("config/settings.json").expect("no config/settings.json"), ) .expect("error parsing settings file"); - - let (run_api, run_bot) = (settings.run_api, settings.run_bot); - info!("{settings:?}"); - let settings = Arc::new(Mutex::new(settings)); + let (run_api, run_bot) = (settings.run_api, settings.run_bot); + let db = Arc::new(tokio::sync::Mutex::new( + db::Database::new("db.sqlite").expect("couldn't open sqlite db"), + )); + if run_api { - spawn_api(settings.clone()); + spawn_api(db.clone()); } if run_bot { - spawn_bot(settings.clone()).await; + spawn_bot(db).await; } info!("spawned background tasks"); let _ = tokio::signal::ctrl_c().await; - settings.lock().await.save()?; info!("Received Ctrl-C, shuttdown down."); Ok(()) diff --git a/src/page.rs b/src/page.rs index 88315e3..e2785f2 100644 --- a/src/page.rs +++ b/src/page.rs @@ -2,7 +2,7 @@ use crate::{ auth::{self}, db::{self, User}, htmx::{Build, HtmxBuilder, Tag}, - settings::{ApiState, GuildSettings, Intro, IntroFriendlyName}, + settings::ApiState, }; use axum::{ extract::{Path, State}, @@ -34,7 +34,7 @@ pub(crate) async fn home( let user_guilds = db.get_user_guilds(&user.name).map_err(|err| { error!(?err, "failed to get user guilds"); // TODO: change this to returning a error to the client - Redirect::to("/login") + Redirect::to(&format!("{}/login", state.origin)) })?; Ok(Html( @@ -109,17 +109,17 @@ pub(crate) async fn guild_dashboard( let guild_intros = db.get_guild_intros(guild_id).map_err(|err| { error!(?err, %guild_id, "couldn't get guild intros"); // TODO: change to actual error - Redirect::to("/login") + Redirect::to(&format!("{}/login", state.origin)) })?; let guild_channels = db.get_guild_channels(guild_id).map_err(|err| { error!(?err, %guild_id, "couldn't get guild channels"); // TODO: change to actual error - Redirect::to("/login") + Redirect::to(&format!("{}/login", state.origin)) })?; let all_user_intros = db.get_all_user_intros(guild_id).map_err(|err| { error!(?err, %guild_id, "couldn't get user intros"); // TODO: change to actual error - Redirect::to("/login") + Redirect::to(&format!("{}/login", state.origin)) })?; let user_permissions = db .get_user_permissions(&user.name, guild_id) diff --git a/src/routes.rs b/src/routes.rs index dd1dfdf..338822b 100644 --- a/src/routes.rs +++ b/src/routes.rs @@ -1,69 +1,25 @@ -use std::{collections::HashMap, ops::Add, sync::Arc}; +use std::collections::HashMap; use axum::{ - body::Bytes, extract::{Multipart, Path, Query, State}, http::{HeaderMap, HeaderValue}, response::{Html, IntoResponse, Redirect}, - Form, Json, }; use axum_extra::extract::{cookie::Cookie, CookieJar}; -use chrono::{Duration, NaiveDate, Utc}; -use iter_tools::Itertools; -use reqwest::{Proxy, StatusCode, Url}; -use serde::{Deserialize, Serialize}; -use serde_json::{json, Value}; -use tracing::{error, info, log::trace}; +use chrono::{Duration, Utc}; +use reqwest::{StatusCode, Url}; +use serde::{Deserialize, Deserializer}; +use tracing::{error, info}; use uuid::Uuid; use crate::{ - auth::{self, User}, + auth::{self}, db, htmx::Build, page, - settings::FileIntro, }; -use crate::{ - media, - settings::{ApiState, GuildUser, Intro, IntroIndex, UserSettings}, -}; - -#[derive(Serialize)] -pub(crate) enum IntroResponse<'a> { - Intros(&'a HashMap), - NoGuildFound, -} - -#[derive(Serialize)] -pub(crate) enum MeResponse<'a> { - Me(Me<'a>), - NoUserFound, -} - -#[derive(Serialize)] -pub(crate) struct Me<'a> { - pub(crate) username: String, - pub(crate) guilds: Vec>, -} - -#[derive(Serialize)] -pub(crate) struct MeGuild<'a> { - // NOTE(pcleavelin): for some reason this doesn't serialize properly if a u64 - pub(crate) id: String, - pub(crate) name: String, - pub(crate) channels: Vec>, - pub(crate) permissions: auth::Permissions, -} - -#[derive(Serialize)] -pub(crate) struct MeChannel<'a> { - pub(crate) name: String, - pub(crate) intros: &'a Vec, -} - -#[derive(Deserialize)] -pub(crate) struct DeleteIntroRequest(Vec); +use crate::{media, settings::ApiState}; pub(crate) async fn health() -> &'static str { "Hello!" @@ -76,8 +32,6 @@ pub(crate) enum Error { #[error("{0}")] GetUser(#[from] reqwest::Error), - #[error("User doesn't exist")] - NoUserFound, #[error("Guild doesn't exist")] NoGuildFound, #[error("invalid request")] @@ -108,7 +62,6 @@ impl IntoResponse for Error { Self::GetUser(error) => (StatusCode::UNAUTHORIZED, error.to_string()).into_response(), Self::NoGuildFound => (StatusCode::NOT_FOUND, self.to_string()).into_response(), - Self::NoUserFound => (StatusCode::NOT_FOUND, self.to_string()).into_response(), Self::InvalidRequest => (StatusCode::BAD_REQUEST, self.to_string()).into_response(), Self::InvalidPermission => (StatusCode::UNAUTHORIZED, self.to_string()).into_response(), @@ -134,11 +87,22 @@ struct DiscordUser { #[derive(Deserialize)] struct DiscordUserGuild { - pub id: String, - pub name: String, + #[serde(deserialize_with = "serde_string_as_u64")] + pub id: u64, pub owner: bool, } +fn serde_string_as_u64<'de, D>(deserializer: D) -> Result +where + D: Deserializer<'de>, +{ + let value = <&str as Deserialize>::deserialize(deserializer)?; + + value + .parse::() + .map_err(|_| serde::de::Error::invalid_value(serde::de::Unexpected::Str(value), &"u64")) +} + pub(crate) async fn v2_auth( State(state): State, Query(params): Query>, @@ -208,11 +172,8 @@ pub(crate) async fn v2_auth( in_a_guild = true; - // TODO: change this - let guild_id = guild.id.parse::().expect("guild id should be u64"); - let now = Utc::now().naive_utc(); - db.add_user( + db.insert_user( &user.username, &token, now + Duration::weeks(4), @@ -221,20 +182,21 @@ pub(crate) async fn v2_auth( ) .map_err(Error::Database)?; - db.insert_user_guild(&user.username, guild_id) + db.insert_user_guild(&user.username, guild.id) .map_err(Error::Database)?; - // TODO: Don't reset permissions - db.insert_user_permission( - &user.username, - guild_id, - if discord_guild.owner { - auth::Permissions(auth::Permission::all()) - } else { - Default::default() - }, - ) - .map_err(Error::Database)?; + if db.get_user_permissions(&user.username, guild.id).is_err() { + db.insert_user_permission( + &user.username, + guild.id, + if discord_guild.owner { + auth::Permissions(auth::Permission::all()) + } else { + Default::default() + }, + ) + .map_err(Error::Database)?; + } } if !in_a_guild { @@ -252,99 +214,6 @@ pub(crate) async fn v2_auth( Ok((jar.add(cookie), Redirect::to(&format!("{}/", state.origin)))) } -pub(crate) async fn auth( - State(state): State, - Query(params): Query>, -) -> Result, Error> { - let Some(code) = params.get("code") else { - return Err(Error::Auth("no code".to_string())); - }; - - info!("attempting to get access token with code {}", code); - - let mut data = HashMap::new(); - - let redirect_uri = format!("{}/old/auth", state.origin); - data.insert("client_id", state.secrets.client_id.as_str()); - data.insert("client_secret", state.secrets.client_secret.as_str()); - data.insert("grant_type", "authorization_code"); - data.insert("code", code); - data.insert("redirect_uri", &redirect_uri); - - let client = reqwest::Client::new(); - - let auth: auth::Discord = client - .post("https://discord.com/api/oauth2/token") - .form(&data) - .send() - .await - .map_err(|err| Error::Auth(err.to_string()))? - .json() - .await - .map_err(|err| Error::Auth(err.to_string()))?; - let token = Uuid::new_v4().to_string(); - - // Get authorized username - let user: DiscordUser = client - .get("https://discord.com/api/v10/users/@me") - .bearer_auth(&auth.access_token) - .send() - .await? - .json() - .await?; - - // TODO: get bot's guilds so we only save users who are able to use the bot - let discord_guilds: Vec = client - .get("https://discord.com/api/v10/users/@me/guilds") - .bearer_auth(&auth.access_token) - .send() - .await? - .json() - .await - .map_err(|err| Error::Auth(err.to_string()))?; - - let mut settings = state.settings.lock().await; - let mut in_a_guild = false; - for g in settings.guilds.iter_mut() { - let Some(discord_guild) = discord_guilds - .iter() - .find(|discord_guild| discord_guild.id == g.0.to_string()) - else { - continue; - }; - - in_a_guild = true; - - if !g.1.users.contains_key(&user.username) { - g.1.users.insert( - user.username.clone(), - GuildUser { - permissions: if discord_guild.owner { - auth::Permissions(auth::Permission::all()) - } else { - Default::default() - }, - }, - ); - } - } - - if !in_a_guild { - return Err(Error::NoGuildFound); - } - - settings.auth_users.insert( - token.clone(), - auth::User { - auth, - name: user.username.clone(), - }, - ); - // TODO: add permissions based on roles - - Ok(Json(json!({"token": token, "username": user.username}))) -} - pub(crate) async fn v2_add_intro_to_user( State(state): State, Path((guild_id, channel)): Path<(u64, String)>, @@ -361,21 +230,21 @@ pub(crate) async fn v2_add_intro_to_user( let intro_id = intro_id.parse::().map_err(|err| { error!(?err, "invalid intro id"); // TODO: change to actual error - Redirect::to("/login") + Redirect::to(&format!("{}/login", state.origin)) })?; db.insert_user_intro(&user.name, guild_id, &channel, intro_id) .map_err(|err| { error!(?err, "failed to add user intro"); // TODO: change to actual error - Redirect::to("/login") + Redirect::to(&format!("{}/login", state.origin)) })?; } let guild_intros = db.get_guild_intros(guild_id).map_err(|err| { error!(?err, %guild_id, "couldn't get guild intros"); // TODO: change to actual error - Redirect::to("/login") + Redirect::to(&format!("{}/login", state.origin)) })?; let intros = db @@ -383,7 +252,7 @@ pub(crate) async fn v2_add_intro_to_user( .map_err(|err| { error!(?err, user = %user.name, %guild_id, "couldn't get user intros"); // TODO: change to actual error - Redirect::to("/login") + Redirect::to(&format!("{}/login", state.origin)) })?; Ok(Html( @@ -414,21 +283,21 @@ pub(crate) async fn v2_remove_intro_from_user( let intro_id = intro_id.parse::().map_err(|err| { error!(?err, "invalid intro id"); // TODO: change to actual error - Redirect::to("/login") + Redirect::to(&format!("{}/login", state.origin)) })?; - db.remove_user_intro(&user.name, guild_id, &channel, intro_id) + db.delete_user_intro(&user.name, guild_id, &channel, intro_id) .map_err(|err| { error!(?err, "failed to remove user intro"); // TODO: change to actual error - Redirect::to("/login") + Redirect::to(&format!("{}/login", state.origin)) })?; } let guild_intros = db.get_guild_intros(guild_id).map_err(|err| { error!(?err, %guild_id, "couldn't get guild intros"); // TODO: change to actual error - Redirect::to("/login") + Redirect::to(&format!("{}/login", state.origin)) })?; let intros = db @@ -436,7 +305,7 @@ pub(crate) async fn v2_remove_intro_from_user( .map_err(|err| { error!(?err, user = %user.name, %guild_id, "couldn't get user intros"); // TODO: change to actual error - Redirect::to("/login") + Redirect::to(&format!("{}/login", state.origin)) })?; Ok(Html( @@ -451,222 +320,40 @@ pub(crate) async fn v2_remove_intro_from_user( )) } -pub(crate) async fn add_intro_to_user( - State(state): State, - headers: HeaderMap, - Path((guild, channel, intro_index)): Path<(u64, String, String)>, -) { - let mut settings = state.settings.lock().await; - let Some(token) = headers.get("token").and_then(|v| v.to_str().ok()) else { - return; - }; - let user = match settings.auth_users.get(token) { - Some(user) => user.name.clone(), - None => return, - }; - - let Some(guild) = settings.guilds.get_mut(&guild) else { - return; - }; - let Some(channel) = guild.channels.get_mut(&channel) else { - return; - }; - let Some(user) = channel.users.get_mut(&user) else { - return; - }; - - if !user.intros.iter().any(|intro| intro.index == intro_index) { - user.intros.push(IntroIndex { - index: intro_index, - volume: 20, - }); - - // TODO: don't save on every change - if let Err(err) = settings.save() { - error!("Failed to save config: {err:?}"); - } - } -} - -pub(crate) async fn remove_intro_to_user( - State(state): State, - headers: HeaderMap, - Path((guild, channel, intro_index)): Path<(u64, String, String)>, -) { - let mut settings = state.settings.lock().await; - let Some(token) = headers.get("token").and_then(|v| v.to_str().ok()) else { - return; - }; - let user = match settings.auth_users.get(token) { - Some(user) => user.name.clone(), - None => return, - }; - - let Some(guild) = settings.guilds.get_mut(&guild) else { - return; - }; - let Some(channel) = guild.channels.get_mut(&channel) else { - return; - }; - let Some(user) = channel.users.get_mut(&user) else { - return; - }; - - if let Some(index) = user - .intros - .iter() - .position(|intro| intro_index == intro.index) - { - user.intros.remove(index); - } - - // TODO: don't save on every change - if let Err(err) = settings.save() { - error!("Failed to save config: {err:?}"); - } -} - -pub(crate) async fn intros(State(state): State, Path(guild): Path) -> Json { - let settings = state.settings.lock().await; - let Some(guild) = settings.guilds.get(&guild) else { - return Json(json!(IntroResponse::NoGuildFound)); - }; - - Json(json!(IntroResponse::Intros(&guild.intros))) -} - -pub(crate) async fn me( - State(state): State, - headers: HeaderMap, -) -> Result, Error> { - let mut settings = state.settings.lock().await; - let Some(token) = headers.get("token").and_then(|v| v.to_str().ok()) else { - return Err(Error::NoUserFound); - }; - - let (username, access_token) = match settings.auth_users.get(token) { - Some(user) => (user.name.clone(), user.auth.access_token.clone()), - None => return Err(Error::NoUserFound), - }; - - let mut me = Me { - username: username.clone(), - guilds: Vec::new(), - }; - - for g in settings.guilds.iter_mut() { - // TODO: don't do this n^2 lookup - - let guild_user = - g.1.users - // TODO: why must clone - .entry(username.clone()) - // TODO: check if owner for permissions - .or_insert(Default::default()); - - let mut guild = MeGuild { - id: g.0.to_string(), - name: g.1.name.clone(), - channels: Vec::new(), - permissions: guild_user.permissions, - }; - - for channel in g.1.channels.iter_mut() { - let user_settings = channel - .1 - .users - .entry(username.clone()) - .or_insert(UserSettings { intros: Vec::new() }); - - guild.channels.push(MeChannel { - name: channel.0.to_owned(), - intros: &user_settings.intros, - }); - } - - me.guilds.push(guild); - } - - if me.guilds.is_empty() { - Ok(Json(json!(MeResponse::NoUserFound))) - } else { - Ok(Json(json!(MeResponse::Me(me)))) - } -} - -pub(crate) async fn upload_guild_intro( - State(state): State, - Path(guild): Path, - Query(mut params): Query>, - headers: HeaderMap, - file: Bytes, -) -> Result<(), Error> { - let mut settings = state.settings.lock().await; - - let Some(token) = headers.get("token").and_then(|v| v.to_str().ok()) else { - return Err(Error::NoUserFound); - }; - let Some(friendly_name) = params.remove("name") else { - return Err(Error::InvalidRequest); - }; - - { - let Some(guild) = settings.guilds.get(&guild) else { - return Err(Error::NoGuildFound); - }; - let auth_user = match settings.auth_users.get(token) { - Some(user) => user, - None => return Err(Error::NoUserFound), - }; - let Some(guild_user) = guild.users.get(&auth_user.name) else { - return Err(Error::NoUserFound); - }; - - if !guild_user.permissions.can(auth::Permission::UploadSounds) { - return Err(Error::InvalidPermission); - } - } - - let Some(guild) = settings.guilds.get_mut(&guild) else { - return Err(Error::NoGuildFound); - }; - let uuid = Uuid::new_v4().to_string(); - let temp_path = format!("./sounds/temp/{uuid}"); - let dest_path = format!("./sounds/{uuid}.mp3"); - - // Write original file so its ready for codec conversion - std::fs::write(&temp_path, file)?; - media::normalize(&temp_path, &dest_path).await?; - std::fs::remove_file(&temp_path)?; - - guild.intros.insert( - uuid.clone(), - Intro::File(FileIntro { - filename: format!("{uuid}.mp3"), - friendly_name, - }), - ); - - Ok(()) -} - pub(crate) async fn v2_upload_guild_intro( State(state): State, - Path(guild): Path, + Path(guild_id): Path, user: db::User, mut form_data: Multipart, ) -> Result { - let mut settings = state.settings.lock().await; - let mut friendly_name = None; + let db = state.db.lock().await; + let mut name = None; let mut file = None; + if !db + .get_guilds() + .map_err(Error::Database)? + .into_iter() + .any(|guild| guild.id == guild_id) + { + return Err(Error::NoGuildFound); + } + + let user_permissions = db + .get_user_permissions(&user.name, guild_id) + .map_err(Error::Database)?; + + if !user_permissions.can(auth::Permission::UploadSounds) { + return Err(Error::InvalidPermission); + } + while let Ok(Some(field)) = form_data.next_field().await { let Some(field_name) = field.name() else { continue; }; if field_name.eq_ignore_ascii_case("name") { - friendly_name = Some(field.text().await.map_err(|_| Error::InvalidRequest)?); + name = Some(field.text().await.map_err(|_| Error::InvalidRequest)?); continue; } @@ -676,29 +363,13 @@ pub(crate) async fn v2_upload_guild_intro( } } - let Some(friendly_name) = friendly_name else { + let Some(name) = name else { return Err(Error::InvalidRequest); }; let Some(file) = file else { return Err(Error::InvalidRequest); }; - { - let Some(guild) = settings.guilds.get(&guild) else { - return Err(Error::NoGuildFound); - }; - let Some(guild_user) = guild.users.get(&user.name) else { - return Err(Error::NoUserFound); - }; - - if !guild_user.permissions.can(auth::Permission::UploadSounds) { - return Err(Error::InvalidPermission); - } - } - - let Some(guild) = settings.guilds.get_mut(&guild) else { - return Err(Error::NoGuildFound); - }; let uuid = Uuid::new_v4().to_string(); let temp_path = format!("./sounds/temp/{uuid}"); let dest_path = format!("./sounds/{uuid}.mp3"); @@ -708,114 +379,45 @@ pub(crate) async fn v2_upload_guild_intro( media::normalize(&temp_path, &dest_path).await?; std::fs::remove_file(&temp_path)?; - guild.intros.insert( - uuid.clone(), - Intro::File(FileIntro { - filename: format!("{uuid}.mp3"), - friendly_name, - }), - ); + db.insert_intro(&name, 0, guild_id, &format!("{uuid}.mp3")) + .map_err(Error::Database)?; let mut headers = HeaderMap::new(); headers.insert("HX-Refresh", HeaderValue::from_static("true")); + Ok(headers) } -pub(crate) async fn add_guild_intro( - State(state): State, - Path(guild): Path, - Query(mut params): Query>, - headers: HeaderMap, -) -> Result<(), Error> { - let mut settings = state.settings.lock().await; - // TODO: make this an impl on HeaderMap - let Some(token) = headers.get("token").and_then(|v| v.to_str().ok()) else { - return Err(Error::NoUserFound); - }; - let Some(url) = params.remove("url") else { - return Err(Error::InvalidRequest); - }; - let Some(friendly_name) = params.remove("name") else { - return Err(Error::InvalidRequest); - }; - - { - let Some(guild) = settings.guilds.get(&guild) else { - return Err(Error::NoGuildFound); - }; - let auth_user = match settings.auth_users.get(token) { - Some(user) => user, - None => return Err(Error::NoUserFound), - }; - let Some(guild_user) = guild.users.get(&auth_user.name) else { - return Err(Error::NoUserFound); - }; - - if !guild_user.permissions.can(auth::Permission::UploadSounds) { - return Err(Error::InvalidPermission); - } - } - - let Some(guild) = settings.guilds.get_mut(&guild) else { - return Err(Error::NoGuildFound); - }; - - let uuid = Uuid::new_v4().to_string(); - let child = tokio::process::Command::new("yt-dlp") - .arg(&url) - .args(["-o", &format!("sounds/{uuid}")]) - .args(["-x", "--audio-format", "mp3"]) - .spawn() - .map_err(Error::Ytdl)? - .wait() - .await - .map_err(Error::Ytdl)?; - - if !child.success() { - return Err(Error::YtdlTerminated); - } - - guild.intros.insert( - uuid.clone(), - Intro::File(FileIntro { - filename: format!("{uuid}.mp3"), - friendly_name, - }), - ); - - Ok(()) -} - pub(crate) async fn v2_add_guild_intro( State(state): State, - Path(guild): Path, + Path(guild_id): Path, Query(mut params): Query>, user: db::User, ) -> Result { - let mut settings = state.settings.lock().await; + let db = state.db.lock().await; let Some(url) = params.remove("url") else { return Err(Error::InvalidRequest); }; - let Some(friendly_name) = params.remove("name") else { + let Some(name) = params.remove("name") else { return Err(Error::InvalidRequest); }; + if !db + .get_guilds() + .map_err(Error::Database)? + .into_iter() + .any(|guild| guild.id == guild_id) { - let Some(guild) = settings.guilds.get(&guild) else { - return Err(Error::NoGuildFound); - }; - let Some(guild_user) = guild.users.get(&user.name) else { - return Err(Error::NoUserFound); - }; - - if !guild_user.permissions.can(auth::Permission::UploadSounds) { - return Err(Error::InvalidPermission); - } + return Err(Error::NoGuildFound); } - let Some(guild) = settings.guilds.get_mut(&guild) else { - return Err(Error::NoGuildFound); - }; + let user_permissions = db + .get_user_permissions(&user.name, guild_id) + .map_err(Error::Database)?; + + if !user_permissions.can(auth::Permission::UploadSounds) { + return Err(Error::InvalidPermission); + } let uuid = Uuid::new_v4().to_string(); let child = tokio::process::Command::new("yt-dlp") @@ -832,64 +434,11 @@ pub(crate) async fn v2_add_guild_intro( return Err(Error::YtdlTerminated); } - guild.intros.insert( - uuid.clone(), - Intro::File(FileIntro { - filename: format!("{uuid}.mp3"), - friendly_name, - }), - ); + db.insert_intro(&name, 0, guild_id, &format!("{uuid}.mp3")) + .map_err(Error::Database)?; let mut headers = HeaderMap::new(); headers.insert("HX-Refresh", HeaderValue::from_static("true")); + Ok(headers) } - -pub(crate) async fn delete_guild_intro( - State(state): State, - Path(guild): Path, - headers: HeaderMap, - Json(body): Json, -) -> Result<(), Error> { - let mut settings = state.settings.lock().await; - // TODO: make this an impl on HeaderMap - let Some(token) = headers.get("token").and_then(|v| v.to_str().ok()) else { - return Err(Error::NoUserFound); - }; - - { - let Some(guild) = settings.guilds.get(&guild) else { - return Err(Error::NoGuildFound); - }; - let auth_user = match settings.auth_users.get(token) { - Some(user) => user, - None => return Err(Error::NoUserFound), - }; - let Some(guild_user) = guild.users.get(&auth_user.name) else { - return Err(Error::NoUserFound); - }; - - if !guild_user.permissions.can(auth::Permission::DeleteSounds) { - return Err(Error::InvalidPermission); - } - } - - let Some(guild) = settings.guilds.get_mut(&guild) else { - return Err(Error::NoGuildFound); - }; - - // Remove intro from any users - for channel in guild.channels.iter_mut() { - for user in channel.1.users.iter_mut() { - user.1 - .intros - .retain(|user_intro| !body.0.iter().any(|intro| &user_intro.index == intro)); - } - } - - for intro in &body.0 { - guild.intros.remove(intro); - } - - Ok(()) -} diff --git a/src/settings.rs b/src/settings.rs index ef58c30..bbb040a 100644 --- a/src/settings.rs +++ b/src/settings.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, sync::Arc}; +use std::sync::Arc; use crate::{ auth, @@ -6,18 +6,15 @@ use crate::{ }; use axum::{async_trait, extract::FromRequestParts, http::request::Parts, response::Redirect}; use axum_extra::extract::CookieJar; +use chrono::Utc; use serde::{Deserialize, Serialize}; use serenity::prelude::TypeMapKey; -use tracing::{error, trace}; -use uuid::Uuid; - -type UserToken = String; +use tracing::error; // TODO: make this is wrapped type so cloning isn't happening #[derive(Clone)] pub(crate) struct ApiState { pub db: Arc>, - pub settings: Arc>, pub secrets: auth::DiscordSecret, pub origin: String, } @@ -34,16 +31,22 @@ impl FromRequestParts for db::User { if let Some(token) = jar.get("access_token") { match state.db.lock().await.get_user_from_api_key(token.value()) { - // :vomit: - Ok(user) => Ok(user), + Ok(user) => { + let now = Utc::now().naive_utc(); + if user.api_key_expires_at < now || user.discord_token_expires_at < now { + Err(Redirect::to(&format!("{}/login", state.origin))) + } else { + Ok(user) + } + } Err(err) => { error!(?err, "failed to authenticate user"); - Err(Redirect::to("/login")) + Err(Redirect::to(&format!("{}/login", state.origin))) } } } else { - Err(Redirect::to("/login")) + Err(Redirect::to(&format!("{}/login", state.origin))) } } } @@ -55,116 +58,7 @@ pub(crate) struct Settings { pub(crate) run_api: bool, #[serde(default)] pub(crate) run_bot: bool, - pub(crate) guilds: HashMap, - - #[serde(default)] - pub(crate) auth_users: HashMap, } impl TypeMapKey for Settings { type Value = Arc; } - -impl Settings { - pub(crate) fn save(&self) -> Result<(), std::io::Error> { - trace!("attempting to save config"); - let serialized = serde_json::to_string_pretty(&self)?; - - std::fs::copy( - "./config/settings.json", - format!( - "./config/{}-settings.json.old", - chrono::Utc::now().naive_utc().format("%Y-%m-%d %H:%M:%S") - ), - )?; - trace!("created copy of original settings"); - - std::fs::write("./config/settings.json", serialized)?; - - trace!("saved settings to disk"); - Ok(()) - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub(crate) struct GuildSettings { - pub(crate) name: String, - pub(crate) sound_delay: u64, - #[serde(default)] - pub(crate) channels: HashMap, - #[serde(default)] - pub(crate) intros: HashMap, - #[serde(default)] - pub(crate) users: HashMap, -} - -#[derive(Default, Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub(crate) struct GuildUser { - pub(crate) permissions: auth::Permissions, -} - -pub(crate) trait IntroFriendlyName { - fn friendly_name(&self) -> &str; -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub(crate) enum Intro { - File(FileIntro), - Online(OnlineIntro), -} - -impl IntroFriendlyName for Intro { - fn friendly_name(&self) -> &str { - match self { - Self::File(intro) => intro.friendly_name(), - Self::Online(intro) => intro.friendly_name(), - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub(crate) struct FileIntro { - pub(crate) filename: String, - pub(crate) friendly_name: String, -} - -impl IntroFriendlyName for FileIntro { - fn friendly_name(&self) -> &str { - &self.friendly_name - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub(crate) struct OnlineIntro { - pub(crate) url: String, - pub(crate) friendly_name: String, -} - -impl IntroFriendlyName for OnlineIntro { - fn friendly_name(&self) -> &str { - &self.friendly_name - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub(crate) struct ChannelSettings { - #[serde(alias = "enterUsers")] - pub(crate) users: HashMap, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub(crate) struct IntroIndex { - pub(crate) index: String, - pub(crate) volume: i32, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub(crate) struct UserSettings { - pub(crate) intros: Vec, -}