From 1ed1e55db4346997d396542f476a2c89a78d4f98 Mon Sep 17 00:00:00 2001 From: Patrick Cleavelin Date: Fri, 3 Mar 2023 01:13:21 -0600 Subject: [PATCH] permissions (still need to get discord roles) intros can now be added to a guild (given proper permissions) --- Cargo.toml | 11 ++++- src/auth.rs | 49 ++++++++++++++++++++ src/main.rs | 12 +++-- src/routes.rs | 119 +++++++++++++++++++++++++++++++++++++++--------- src/settings.rs | 31 +++---------- 5 files changed, 169 insertions(+), 53 deletions(-) create mode 100644 src/auth.rs diff --git a/Cargo.toml b/Cargo.toml index 112f26b..85e8573 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,11 +12,18 @@ futures = "0.3.26" reqwest = "0.11.14" serde = "1.0.152" serde_json = "1.0.93" -serenity = { version = "0.11.5", default-features = false, features = ["client", "gateway", "rustls_backend", "model", "cache", "voice"] } -songbird = { version = "0.3.0", features = [ "builtin-queue" ] } thiserror = "1.0.38" tokio = { version = "1.25.0", features = ["rt-multi-thread", "macros", "signal"] } tower-http = { version = "0.4.0", features = ["cors"] } tracing = "0.1.37" tracing-subscriber = "0.3.16" uuid = { version = "1.3.0", features = ["v4"] } + +[dependencies.serenity] +version = "0.11.5" +default-features = false +features = ["client", "gateway", "rustls_backend", "model", "cache", "voice"] + +[dependencies.songbird] +version = "0.3.0" +features = [ "builtin-queue", "yt-dlp" ] diff --git a/src/auth.rs b/src/auth.rs new file mode 100644 index 0000000..e54cf91 --- /dev/null +++ b/src/auth.rs @@ -0,0 +1,49 @@ +use std::{collections::HashMap, sync::Arc}; + +use serde::{Deserialize, Serialize}; +use serenity::prelude::TypeMapKey; +use tracing::trace; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) struct Discord { + pub(crate) access_token: String, + pub(crate) token_type: String, + pub(crate) expires_in: usize, + pub(crate) refresh_token: String, + pub(crate) scope: String, +} + +#[derive(Clone)] +pub(crate) struct DiscordSecret { + pub(crate) client_id: String, + pub(crate) client_secret: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) struct User { + pub(crate) auth: Discord, + #[serde(default)] + pub(crate) permissions: Permissions, + pub(crate) name: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) struct Permissions(u8); +impl Default for Permissions { + fn default() -> Permissions { + Permissions(0) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[repr(u8)] +pub enum Permission { + None, + DownloadSounds, +} + +impl Permissions { + pub(crate) fn can(&self, perm: Permission) -> bool { + self.0 & (perm as u8) > 0 + } +} diff --git a/src/main.rs b/src/main.rs index 601e885..981ed13 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,6 +2,7 @@ #![feature(proc_macro_hygiene)] #![feature(async_closure)] +mod auth; mod routes; pub mod settings; @@ -9,7 +10,7 @@ use axum::http::{HeaderValue, Method}; use axum::routing::{get, post}; use axum::Router; use futures::StreamExt; -use settings::{ApiState, DiscordSecret}; +use settings::ApiState; use songbird::tracks::TrackQueue; use std::collections::HashMap; use std::env; @@ -116,7 +117,7 @@ impl EventHandler for Handler { } fn spawn_api(settings: Arc>) { - let secrets = DiscordSecret { + let secrets = auth::DiscordSecret { client_id: env::var("DISCORD_CLIENT_ID").expect("expected DISCORD_CLIENT_ID env var"), client_secret: env::var("DISCORD_CLIENT_SECRET") .expect("expected DISCORD_CLIENT_SECRET env var"), @@ -128,13 +129,14 @@ fn spawn_api(settings: Arc>) { let api = Router::new() .route("/health", get(routes::health)) .route("/me", get(routes::me)) + .route("/intros/:guild/add/:url", get(routes::add_guild_intro)) .route("/intros/:guild", get(routes::intros)) .route( - "/intros/:guild/:channel/:user/:intro", + "/intros/:guild/:channel/:intro", post(routes::add_intro_to_user), ) .route( - "/intros/:guild/:channel/:user/:intro/remove", + "/intros/:guild/:channel/:intro/remove", post(routes::remove_intro_to_user), ) .route("/auth", get(routes::auth)) @@ -246,7 +248,7 @@ async fn spawn_bot(settings: Arc>) { continue; }; - let source = match guild_settings.intros.get(intro.index) { + let source = match guild_settings.intros.get(&intro.index) { Some(Intro::Online(intro)) => match songbird::ytdl(&intro.url).await { Ok(source) => source, Err(err) => { diff --git a/src/routes.rs b/src/routes.rs index 551b26a..e3a9baf 100644 --- a/src/routes.rs +++ b/src/routes.rs @@ -13,11 +13,12 @@ use serde_json::{json, Value}; use tracing::{error, info}; use uuid::Uuid; -use crate::settings::{ApiState, Auth, AuthUser, Intro, IntroIndex}; +use crate::settings::{ApiState, Intro, IntroIndex}; +use crate::{auth, settings::FileIntro}; #[derive(Serialize)] pub(crate) enum IntroResponse<'a> { - Intros(&'a Vec), + Intros(&'a HashMap), NoGuildFound, } @@ -52,19 +53,44 @@ pub(crate) async fn health() -> &'static str { #[derive(Debug, thiserror::Error)] pub(crate) enum Error { #[error("{0}")] - AuthError(String), + Auth(String), #[error("{0}")] GetUser(#[from] reqwest::Error), + + #[error("User doesn't exist")] + NoUserFound, + #[error("Guild doesn't exist")] + NoGuildFound, + #[error("invalid request")] + InvalidRequest, + + #[error("Invalid permissions for request")] + InvalidPermission, + #[error("{0}")] + Ytdl(#[from] std::io::Error), + + #[error("ytdl terminated unsuccessfully")] + YtdlTerminated, } impl IntoResponse for Error { fn into_response(self) -> axum::response::Response { - let body = match self { - Self::AuthError(msg) => msg, - Self::GetUser(error) => error.to_string(), - }; + match self { + Self::Auth(msg) => (StatusCode::INTERNAL_SERVER_ERROR, msg).into_response(), + Self::GetUser(error) => (StatusCode::UNAUTHORIZED, error.to_string()).into_response(), - (StatusCode::INTERNAL_SERVER_ERROR, body).into_response() + Self::NoGuildFound => (StatusCode::NOT_FOUND, self.to_string()).into_response(), + Self::NoUserFound => (StatusCode::NOT_FOUND, self.to_string()).into_response(), + Self::InvalidRequest => (StatusCode::BAD_REQUEST, self.to_string()).into_response(), + + Self::InvalidPermission => (StatusCode::UNAUTHORIZED, self.to_string()).into_response(), + Self::Ytdl(error) => { + (StatusCode::INTERNAL_SERVER_ERROR, error.to_string()).into_response() + } + Self::YtdlTerminated => { + (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()).into_response() + } + } } } @@ -78,7 +104,7 @@ pub(crate) async fn auth( Query(params): Query>, ) -> Result, Error> { let Some(code) = params.get("code") else { - return Err(Error::AuthError("no code".to_string())); + return Err(Error::Auth("no code".to_string())); }; info!("attempting to get access token with code {}", code); @@ -92,15 +118,15 @@ pub(crate) async fn auth( let client = reqwest::Client::new(); - let auth: Auth = client + let auth: auth::Discord = client .post("https://discord.com/api/oauth2/token") .form(&data) .send() .await - .map_err(|err| Error::AuthError(err.to_string()))? + .map_err(|err| Error::Auth(err.to_string()))? .json() .await - .map_err(|err| Error::AuthError(err.to_string()))?; + .map_err(|err| Error::Auth(err.to_string()))?; let token = Uuid::new_v4().to_string(); // Get authorized username @@ -115,8 +141,10 @@ pub(crate) async fn auth( let mut settings = state.settings.lock().await; settings.auth_users.insert( token.clone(), - AuthUser { + auth::User { auth, + // TODO: replace with roles + permissions: auth::Permissions::default(), name: user.username.clone(), }, ); @@ -127,7 +155,7 @@ pub(crate) async fn auth( pub(crate) async fn add_intro_to_user( State(state): State>, headers: HeaderMap, - Path((guild, channel, intro_index)): Path<(u64, String, usize)>, + Path((guild, channel, intro_index)): Path<(u64, String, String)>, ) { let mut settings = state.settings.lock().await; let Some(token) = headers.get("token").and_then(|v| v.to_str().ok()) else { return; }; @@ -140,20 +168,22 @@ pub(crate) async fn add_intro_to_user( let Some(channel) = guild.channels.get_mut(&channel) else { return; }; let Some(user) = channel.users.get_mut(&user) else { return; }; - user.intros.push(IntroIndex { - index: intro_index, - volume: 20, - }); + if !user.intros.iter().any(|intro| intro.index == intro_index) { + user.intros.push(IntroIndex { + index: intro_index, + volume: 20, + }); - if let Err(err) = settings.save() { - error!("Failed to save config: {err:?}"); + if let Err(err) = settings.save() { + error!("Failed to save config: {err:?}"); + } } } pub(crate) async fn remove_intro_to_user( State(state): State>, headers: HeaderMap, - Path((guild, channel, intro_index)): Path<(u64, String, usize)>, + Path((guild, channel, intro_index)): Path<(u64, String, String)>, ) { let mut settings = state.settings.lock().await; let Some(token) = headers.get("token").and_then(|v| v.to_str().ok()) else { return; }; @@ -229,3 +259,50 @@ pub(crate) async fn me(State(state): State>, headers: HeaderMap) - Json(json!(MeResponse::Me(me))) } } + +pub(crate) async fn add_guild_intro( + State(state): State>, + Path((guild, url)): Path<(u64, String)>, + Query(mut params): Query>, + headers: HeaderMap, +) -> Result<(), Error> { + let mut settings = state.settings.lock().await; + // TODO: make this an impl on HeaderMap + let Some(token) = headers.get("token").and_then(|v| v.to_str().ok()) else { return Err(Error::NoUserFound); }; + let Some(friendly_name) = params.remove("name") else { return Err(Error::InvalidRequest); }; + let user = match settings.auth_users.get(token) { + Some(user) => user, + None => return Err(Error::NoUserFound), + }; + + if !user.permissions.can(auth::Permission::DownloadSounds) { + return Err(Error::InvalidPermission); + } + + let Some(guild) = settings.guilds.get_mut(&guild) else { return Err(Error::NoGuildFound); }; + + let uuid = Uuid::new_v4().to_string(); + let child = tokio::process::Command::new("yt-dlp") + .arg(&url) + .args(["-o", &format!("./sounds/{uuid}")]) + .args(["-x", "--audio-format", "mp3"]) + .spawn() + .map_err(Error::Ytdl)? + .wait() + .await + .map_err(Error::Ytdl)?; + + if !child.success() { + return Err(Error::YtdlTerminated); + } + + guild.intros.insert( + uuid.clone(), + Intro::File(FileIntro { + filename: format!("{uuid}.mp3"), + friendly_name, + }), + ); + + Ok(()) +} diff --git a/src/settings.rs b/src/settings.rs index c598041..9ef21b3 100644 --- a/src/settings.rs +++ b/src/settings.rs @@ -1,33 +1,14 @@ use std::{collections::HashMap, sync::Arc}; +use crate::auth; use serde::{Deserialize, Serialize}; use serenity::prelude::TypeMapKey; use tracing::trace; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub(crate) struct Auth { - pub(crate) access_token: String, - pub(crate) token_type: String, - pub(crate) expires_in: usize, - pub(crate) refresh_token: String, - pub(crate) scope: String, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub(crate) struct AuthUser { - pub auth: Auth, - pub name: String, -} +use uuid::Uuid; pub(crate) struct ApiState { pub settings: Arc>, - pub secrets: DiscordSecret, -} - -#[derive(Clone)] -pub(crate) struct DiscordSecret { - pub client_id: String, - pub client_secret: String, + pub secrets: auth::DiscordSecret, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -40,7 +21,7 @@ pub(crate) struct Settings { pub(crate) guilds: HashMap, #[serde(default)] - pub(crate) auth_users: HashMap, + pub(crate) auth_users: HashMap, } impl TypeMapKey for Settings { type Value = Arc; @@ -73,7 +54,7 @@ pub(crate) struct GuildSettings { #[serde(alias = "userEnteredSoundDelay")] pub(crate) sound_delay: u64, pub(crate) channels: HashMap, - pub(crate) intros: Vec, + pub(crate) intros: HashMap, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -106,7 +87,7 @@ pub(crate) struct ChannelSettings { #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub(crate) struct IntroIndex { - pub(crate) index: usize, + pub(crate) index: String, pub(crate) volume: i32, }