From b984f048f6a8d71f58e8726a859f54e8c53b5875 Mon Sep 17 00:00:00 2001 From: Patrick Cleavelin Date: Wed, 1 Mar 2023 22:01:36 -0600 Subject: [PATCH] save settings on exit, use env for discord secrets --- src/main.rs | 121 ++++++++++++++++++++++++------------------------ src/routes.rs | 39 +++++++--------- src/settings.rs | 18 ++++++- 3 files changed, 94 insertions(+), 84 deletions(-) diff --git a/src/main.rs b/src/main.rs index 21d2ecc..601e885 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,6 +9,7 @@ use axum::http::{HeaderValue, Method}; use axum::routing::{get, post}; use axum::Router; use futures::StreamExt; +use settings::{ApiState, DiscordSecret}; use songbird::tracks::TrackQueue; use std::collections::HashMap; use std::env; @@ -114,7 +115,15 @@ impl EventHandler for Handler { } } -fn spawn_api(settings: Arc>) -> tokio::task::JoinHandle<()> { +fn spawn_api(settings: Arc>) { + let secrets = DiscordSecret { + client_id: env::var("DISCORD_CLIENT_ID").expect("expected DISCORD_CLIENT_ID env var"), + client_secret: env::var("DISCORD_CLIENT_SECRET") + .expect("expected DISCORD_CLIENT_SECRET env var"), + }; + + let state = ApiState { settings, secrets }; + tokio::spawn(async move { let api = Router::new() .route("/health", get(routes::health)) @@ -135,19 +144,17 @@ fn spawn_api(settings: Arc>) -> tokio::task::JoinHandle<()> { .allow_headers(Any) .allow_methods([Method::GET, Method::POST]), ) - .with_state(settings); + .with_state(Arc::new(state)); let addr = SocketAddr::from(([0, 0, 0, 0], 7756)); info!("socket listening on {addr}"); axum::Server::bind(&addr) .serve(api.into_make_service()) .await .unwrap(); - }) + }); } -async fn spawn_bot(settings: Arc>) -> Vec> { - let mut tasks = vec![]; - +async fn spawn_bot(settings: Arc>) { let token = env::var("DISCORD_TOKEN").expect("expected DISCORD_TOKEN env var"); let songbird = songbird::Songbird::serenity(); @@ -166,13 +173,13 @@ async fn spawn_bot(settings: Arc>) -> Vec { @@ -192,7 +199,7 @@ async fn spawn_bot(settings: Arc>) -> Vec { @@ -240,63 +247,58 @@ async fn spawn_bot(settings: Arc>) -> Vec match songbird::ytdl(&intro.url).await { + Some(Intro::Online(intro)) => match songbird::ytdl(&intro.url).await { + Ok(source) => source, + Err(err) => { + error!("Error starting youtube source from {}: {err:?}", intro.url); + continue; + } + }, + Some(Intro::File(intro)) => { + match songbird::ffmpeg(format!("sounds/{}", &intro.filename)).await { Ok(source) => source, Err(err) => { error!( - "Error starting youtube source from {}: {err:?}", - intro.url - ); + "Error starting file source from {}: {err:?}", + intro.filename + ); continue; } - }, - Some(Intro::File(intro)) => { - match songbird::ffmpeg(format!("sounds/{}", &intro.filename)).await { - Ok(source) => source, - Err(err) => { - error!( - "Error starting file source from {}: {err:?}", - intro.filename - ); - continue; - } - } - }, - None => { - error!( - "Failed to find intro for user {} on guild {} in channel {}, IntroIndex: {}", - member.user.name, - channel.guild_id.as_u64(), - channel.name(), - intro.index - ); - continue; - } - }; - - match songbird.join(member.guild_id, channel_id).await { - (handler_lock, Ok(())) => { - let mut handler = handler_lock.lock().await; - - let _track_handler = handler.enqueue_source(source); - // TODO: set volume - } - - (_, Err(err)) => { - error!("Failed to join voice channel {}: {err:?}", channel.name()); } } + None => { + error!( + "Failed to find intro for user {} on guild {} in channel {}, IntroIndex: {}", + member.user.name, + channel.guild_id.as_u64(), + channel.name(), + intro.index + ); + continue; + } + }; + + match songbird.join(member.guild_id, channel_id).await { + (handler_lock, Ok(())) => { + let mut handler = handler_lock.lock().await; + + let _track_handler = handler.enqueue_source(source); + // TODO: set volume + } + + (_, Err(err)) => { + error!("Failed to join voice channel {}: {err:?}", channel.name()); + } + } } } } - })); - - tasks + }); } #[tokio::main] #[instrument] -async fn main() { +async fn main() -> std::io::Result<()> { tracing_subscriber::fmt::init(); let settings = serde_json::from_str::( @@ -308,20 +310,19 @@ async fn main() { info!("{settings:?}"); - let mut tasks = vec![]; - let settings = Arc::new(Mutex::new(settings)); if run_api { - tasks.push(spawn_api(settings.clone())); + spawn_api(settings.clone()); } if run_bot { - tasks.append(&mut spawn_bot(settings.clone()).await); + spawn_bot(settings.clone()).await; } - let tasks = futures::stream::iter(tasks); - let mut buffered = tasks.buffer_unordered(5); - while buffered.next().await.is_some() {} + info!("spawned background tasks"); let _ = tokio::signal::ctrl_c().await; + settings.lock().await.save()?; info!("Received Ctrl-C, shuttdown down."); + + Ok(()) } diff --git a/src/routes.rs b/src/routes.rs index 4d424d2..551b26a 100644 --- a/src/routes.rs +++ b/src/routes.rs @@ -1,21 +1,19 @@ use std::{collections::HashMap, sync::Arc}; use axum::{ - body::StreamBody, extract::{Path, Query, State}, http::HeaderMap, response::IntoResponse, Json, }; -use futures::Stream; + use reqwest::StatusCode; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; -use tokio::sync::Mutex; use tracing::{error, info}; use uuid::Uuid; -use crate::settings::{Auth, AuthUser, GuildSettings, Intro, IntroIndex, Settings, UserSettings}; +use crate::settings::{ApiState, Auth, AuthUser, Intro, IntroIndex}; #[derive(Serialize)] pub(crate) enum IntroResponse<'a> { @@ -47,10 +45,8 @@ pub(crate) struct MeChannel<'a> { pub(crate) intros: &'a Vec, } -pub(crate) async fn health(State(state): State>>) -> Json { - let settings = state.lock().await; - - Json(json!(*settings)) +pub(crate) async fn health() -> &'static str { + "Hello!" } #[derive(Debug, thiserror::Error)] @@ -78,7 +74,7 @@ struct DiscordUser { } pub(crate) async fn auth( - State(settings): State>>, + State(state): State>, Query(params): Query>, ) -> Result, Error> { let Some(code) = params.get("code") else { @@ -88,8 +84,8 @@ pub(crate) async fn auth( info!("attempting to get access token with code {}", code); let mut data = HashMap::new(); - data.insert("client_id", "577634620728934400"); - data.insert("client_secret", "CLIENT_SECRET_HERE"); + data.insert("client_id", state.secrets.client_id.as_str()); + data.insert("client_secret", state.secrets.client_secret.as_str()); data.insert("grant_type", "authorization_code"); data.insert("code", code); data.insert("redirect_uri", "http://localhost:5173/auth"); @@ -116,7 +112,7 @@ pub(crate) async fn auth( .json() .await?; - let mut settings = settings.lock().await; + let mut settings = state.settings.lock().await; settings.auth_users.insert( token.clone(), AuthUser { @@ -129,11 +125,11 @@ pub(crate) async fn auth( } pub(crate) async fn add_intro_to_user( - State(state): State>>, + State(state): State>, headers: HeaderMap, Path((guild, channel, intro_index)): Path<(u64, String, usize)>, ) { - let mut settings = state.lock().await; + let mut settings = state.settings.lock().await; let Some(token) = headers.get("token").and_then(|v| v.to_str().ok()) else { return; }; let user = match settings.auth_users.get(token) { Some(user) => user.name.clone(), @@ -155,11 +151,11 @@ pub(crate) async fn add_intro_to_user( } pub(crate) async fn remove_intro_to_user( - State(state): State>>, + State(state): State>, headers: HeaderMap, Path((guild, channel, intro_index)): Path<(u64, String, usize)>, ) { - let mut settings = state.lock().await; + let mut settings = state.settings.lock().await; let Some(token) = headers.get("token").and_then(|v| v.to_str().ok()) else { return; }; let user = match settings.auth_users.get(token) { Some(user) => user.name.clone(), @@ -184,20 +180,17 @@ pub(crate) async fn remove_intro_to_user( } pub(crate) async fn intros( - State(state): State>>, + State(state): State>, Path(guild): Path, ) -> Json { - let settings = state.lock().await; + let settings = state.settings.lock().await; let Some(guild) = settings.guilds.get(&guild) else { return Json(json!(IntroResponse::NoGuildFound)); }; Json(json!(IntroResponse::Intros(&guild.intros))) } -pub(crate) async fn me( - State(state): State>>, - headers: HeaderMap, -) -> Json { - let settings = state.lock().await; +pub(crate) async fn me(State(state): State>, headers: HeaderMap) -> Json { + let settings = state.settings.lock().await; let Some(token) = headers.get("token").and_then(|v| v.to_str().ok()) else { return Json(json!(MeResponse::NoUserFound)); }; let user = match settings.auth_users.get(token) { diff --git a/src/settings.rs b/src/settings.rs index 98c46a9..c598041 100644 --- a/src/settings.rs +++ b/src/settings.rs @@ -2,6 +2,7 @@ use std::{collections::HashMap, sync::Arc}; use serde::{Deserialize, Serialize}; use serenity::prelude::TypeMapKey; +use tracing::trace; #[derive(Debug, Clone, Serialize, Deserialize)] pub(crate) struct Auth { @@ -18,6 +19,17 @@ pub(crate) struct AuthUser { pub name: String, } +pub(crate) struct ApiState { + pub settings: Arc>, + pub secrets: DiscordSecret, +} + +#[derive(Clone)] +pub(crate) struct DiscordSecret { + pub client_id: String, + pub client_secret: String, +} + #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub(crate) struct Settings { @@ -27,7 +39,7 @@ pub(crate) struct Settings { pub(crate) run_bot: bool, pub(crate) guilds: HashMap, - #[serde(skip)] + #[serde(default)] pub(crate) auth_users: HashMap, } impl TypeMapKey for Settings { @@ -36,6 +48,7 @@ impl TypeMapKey for Settings { impl Settings { pub(crate) fn save(&self) -> Result<(), std::io::Error> { + trace!("attempting to save config"); let serialized = serde_json::to_string_pretty(&self)?; std::fs::copy( @@ -45,8 +58,11 @@ impl Settings { chrono::Utc::now().naive_utc().format("%Y-%m-%d %H:%M:%S") ), )?; + trace!("created copy of original settings"); + std::fs::write("./config/settings.json", serialized)?; + trace!("saved settings to disk"); Ok(()) } }