diff --git a/Cargo.lock b/Cargo.lock index abd6368..ab9526b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -86,6 +86,56 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" +[[package]] +name = "axum" +version = "0.6.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6137c6234afb339e75e764c866e3594900f0211e1315d33779f269bbe2ec6967" +dependencies = [ + "async-trait", + "axum-core", + "bitflags", + "bytes", + "futures-util", + "http", + "http-body", + "hyper", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tower", + "tower-http", + "tower-layer", + "tower-service", +] + +[[package]] +name = "axum-core" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1cae3e661676ffbacb30f1a824089a8c9150e71017f7e1e38f2aa32009188d34" +dependencies = [ + "async-trait", + "bytes", + "futures-util", + "http", + "http-body", + "mime", + "rustversion", + "tower-layer", + "tower-service", +] + [[package]] name = "base64" version = "0.13.1" @@ -466,9 +516,9 @@ dependencies = [ [[package]] name = "http" -version = "0.2.8" +version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75f43d41e26995c17e71ee126451dd3941010b0514a81a9d11f3b341debc2399" +checksum = "bd6effc99afb63425aff9b05836f029929e345a6148a14b7ecd5ab67af944482" dependencies = [ "bytes", "fnv", @@ -486,6 +536,12 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "http-range-header" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bfe8eed0a9285ef776bb792479ea3834e8b94e13d615c2f66d03dd50a435a29" + [[package]] name = "httparse" version = "1.8.0" @@ -631,6 +687,12 @@ dependencies = [ "regex-automata", ] +[[package]] +name = "matchit" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b87248edafb776e59e6ee64a79086f65890d3510f2c656c000bf2a7e8a0aea40" + [[package]] name = "memchr" version = "2.5.0" @@ -641,11 +703,14 @@ checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" name = "memejoin-rs" version = "0.1.1-alpha" dependencies = [ + "axum", + "futures", "serde", "serde_json", "serenity", "songbird", "tokio", + "tower-http", "tracing", "tracing-subscriber", ] @@ -1118,6 +1183,15 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26b04f22b563c91331a10074bda3dd5492e3cc39d56bd557e91c0af42b6c7341" +dependencies = [ + "serde", +] + [[package]] name = "serde_repr" version = "0.1.10" @@ -1327,6 +1401,12 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "sync_wrapper" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" + [[package]] name = "thiserror" version = "1.0.38" @@ -1453,6 +1533,47 @@ dependencies = [ "tracing", ] +[[package]] +name = "tower" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" +dependencies = [ + "futures-core", + "futures-util", + "pin-project", + "pin-project-lite", + "tokio", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tower-http" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d1d42a9b3f3ec46ba828e8d376aec14592ea199f70a06a548587ecd1c4ab658" +dependencies = [ + "bitflags", + "bytes", + "futures-core", + "futures-util", + "http", + "http-body", + "http-range-header", + "pin-project-lite", + "tower", + "tower-layer", + "tower-service", +] + +[[package]] +name = "tower-layer" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0" + [[package]] name = "tower-service" version = "0.3.2" diff --git a/Cargo.toml b/Cargo.toml index d383237..32e6ff3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,10 +6,13 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +axum = "0.6.9" +futures = "0.3.26" serde = "1.0.152" serde_json = "1.0.93" serenity = { version = "0.11.5", default-features = false, features = ["client", "gateway", "rustls_backend", "model", "cache", "voice"] } songbird = { version = "0.3.0", features = [ "builtin-queue" ] } tokio = { version = "1.25.0", features = ["rt-multi-thread", "macros", "signal"] } +tower-http = { version = "0.4.0", features = ["cors"] } tracing = "0.1.37" tracing-subscriber = "0.3.16" diff --git a/src/main.rs b/src/main.rs index c59db1d..086781e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,11 +2,20 @@ #![feature(proc_macro_hygiene)] #![feature(async_closure)] +mod routes; +pub mod settings; + +use axum::http::{HeaderValue, Method}; +use axum::routing::get; +use axum::Router; +use futures::StreamExt; use songbird::tracks::TrackQueue; use std::collections::HashMap; use std::env; +use std::net::SocketAddr; use std::sync::Arc; use tokio::sync::mpsc; +use tower_http::cors::CorsLayer; use serde::Deserialize; use serenity::async_trait; @@ -17,6 +26,8 @@ use serenity::prelude::*; use songbird::SerenityInit; use tracing::*; +use crate::settings::{Intro, Settings}; + enum HandlerMessage { Ready(Context), PlaySound(Context, Member, ChannelId), @@ -52,46 +63,6 @@ impl songbird::EventHandler for TrackEventHandler { } } -#[derive(Debug, Clone, Deserialize)] -struct Settings { - guilds: HashMap, -} -impl TypeMapKey for Settings { - type Value = Arc; -} - -#[derive(Debug, Clone, Deserialize)] -struct GuildSettings { - #[serde(alias = "userEnteredSoundDelay")] - _sound_delay: u64, - channels: HashMap, -} - -#[derive(Debug, Clone, Deserialize)] -struct ChannelSettings { - #[serde(alias = "enterUsers")] - users: HashMap, -} - -#[derive(Debug, Clone, Deserialize)] -struct UserSettings { - #[serde(rename = "type")] - ty: SoundType, - - #[serde(alias = "enterSound")] - sound: String, - #[serde(alias = "youtubeVolume")] - _volume: i32, -} - -#[derive(Debug, Clone, Deserialize)] -enum SoundType { - #[serde(alias = "file")] - File, - #[serde(alias = "youtube")] - Youtube, -} - #[async_trait] impl EventHandler for Handler { async fn ready(&self, ctx: Context, ready: Ready) { @@ -143,20 +114,31 @@ impl EventHandler for Handler { } } -#[tokio::main] -#[instrument] -async fn main() { - tracing_subscriber::fmt::init(); +fn spawn_api(settings: Arc>) -> tokio::task::JoinHandle<()> { + tokio::spawn(async move { + let api = Router::new() + .route("/health", get(routes::health)) + .route("/me/:user", get(routes::me)) + .route("/intros/:guild", get(routes::intros)) + .layer( + CorsLayer::new() + .allow_origin("*".parse::().unwrap()) + .allow_methods([Method::GET]), + ) + .with_state(settings); + let addr = SocketAddr::from(([0, 0, 0, 0], 7756)); + info!("socket listening on {addr}"); + axum::Server::bind(&addr) + .serve(api.into_make_service()) + .await + .unwrap(); + }) +} + +async fn spawn_bot(settings: Arc>) -> Vec> { + let mut tasks = vec![]; let token = env::var("DISCORD_TOKEN").expect("expected DISCORD_TOKEN env var"); - - let settings = serde_json::from_str::( - &std::fs::read_to_string("config/settings.json").expect("no config/settings.json"), - ) - .expect("error parsing settings file"); - - info!("{settings:?}"); - let songbird = songbird::Songbird::serenity(); let (tx, mut rx) = mpsc::channel(10); @@ -174,17 +156,18 @@ async fn main() { .expect("Error creating client"); info!("Starting bot with token '{token}'"); - tokio::spawn(async move { + tasks.push(tokio::spawn(async move { if let Err(err) = client.start().await { error!("An error occurred while running the client: {err:?}"); } - }); + })); - tokio::spawn(async move { + tasks.push(tokio::spawn(async move { while let Some(msg) = rx.recv().await { match msg { HandlerMessage::Ready(ctx) => { info!("Got Ready message"); + let settings = settings.lock().await; let songbird = songbird::get(&ctx).await.expect("no songbird instance"); @@ -199,7 +182,7 @@ async fn main() { tx: tx.clone(), guild_id: GuildId(*guild_id), }, - ); + ); } } HandlerMessage::TrackEnded(guild_id) => { @@ -220,61 +203,102 @@ async fn main() { HandlerMessage::PlaySound(ctx, member, channel_id) => { info!("Got PlaySound message"); + let settings = settings.lock().await; let Some(Channel::Guild(channel)) = channel_id.to_channel_cached(&ctx.cache) else { error!("Failed to get cached channel from member!"); continue; }; - let Some(user) = settings.guilds.get(channel.guild_id.as_u64()) - .and_then(|guild| guild.channels.get(channel.name())) - .and_then(|c| c.users.get(&member.user.name)) - else { - info!("No sound associated for {} in channel {}", member.user.name, channel.name()); - continue; - }; + let Some(guild_settings) = settings.guilds.get(channel.guild_id.as_u64()) else { continue; }; + let Some(channel_settings) = guild_settings.channels.get(channel.name()) else { continue; }; + let Some(user) = channel_settings.users.get(&member.user.name) else { continue; }; - let source = match user.ty { - SoundType::Youtube => match songbird::ytdl(&user.sound).await { - Ok(source) => source, - Err(err) => { - error!( - "Error starting youtube source from {}: {err:?}", - user.sound - ); - continue; - } - }, - SoundType::File => { - match songbird::ffmpeg(format!("sounds/{}", &user.sound)).await { + // TODO: randomly choose a intro to play + let Some(intro) = user.intros.first() else { continue; }; + + let source = match guild_settings.intros.get(intro.0) { + Some(Intro::Online(intro)) => match songbird::ytdl(&intro.url).await { Ok(source) => source, Err(err) => { error!( - "Error starting file source from {}: {err:?}", - user.sound - ); + "Error starting youtube source from {}: {err:?}", + intro.url + ); continue; } + }, + Some(Intro::File(intro)) => { + match songbird::ffmpeg(format!("sounds/{}", &intro.filename)).await { + Ok(source) => source, + Err(err) => { + error!( + "Error starting file source from {}: {err:?}", + intro.filename + ); + continue; + } + } + }, + None => { + error!( + "Failed to find intro for user {} on guild {} in channel {}, IntroIndex: {}", + member.user.name, + channel.guild_id.as_u64(), + channel.name(), + intro.0 + ); + continue; + } + }; + + match songbird.join(member.guild_id, channel_id).await { + (handler_lock, Ok(())) => { + let mut handler = handler_lock.lock().await; + + let _track_handler = handler.enqueue_source(source); + // TODO: set volume + } + + (_, Err(err)) => { + error!("Failed to join voice channel {}: {err:?}", channel.name()); } } - }; - - match songbird.join(member.guild_id, channel_id).await { - (handler_lock, Ok(())) => { - let mut handler = handler_lock.lock().await; - - let _track_handler = handler.enqueue_source(source); - // TODO: set volume - } - - (_, Err(err)) => { - error!("Failed to join voice channel {}: {err:?}", channel.name()); - } - } } } } - }); + })); + + tasks +} + +#[tokio::main] +#[instrument] +async fn main() { + tracing_subscriber::fmt::init(); + + let settings = serde_json::from_str::( + &std::fs::read_to_string("config/settings.json").expect("no config/settings.json"), + ) + .expect("error parsing settings file"); + + let (run_api, run_bot) = (settings.run_api, settings.run_bot); + + info!("{settings:?}"); + + let mut tasks = vec![]; + + let settings = Arc::new(Mutex::new(settings)); + if run_api { + tasks.push(spawn_api(settings.clone())); + } + if run_bot { + tasks.append(&mut spawn_bot(settings.clone()).await); + } + + let tasks = futures::stream::iter(tasks); + let mut buffered = tasks.buffer_unordered(5); + while buffered.next().await.is_some() {} let _ = tokio::signal::ctrl_c().await; info!("Received Ctrl-C, shuttdown down."); diff --git a/src/routes.rs b/src/routes.rs new file mode 100644 index 0000000..b9dea35 --- /dev/null +++ b/src/routes.rs @@ -0,0 +1,60 @@ +use std::sync::Arc; + +use axum::{ + extract::{Path, State}, + Json, +}; +use serde::Serialize; +use serde_json::{json, Value}; +use tokio::sync::Mutex; + +use crate::settings::{Intro, Settings, UserSettings}; + +#[derive(Serialize)] +pub(crate) enum MeResponse<'a> { + Settings(Vec<&'a UserSettings>), + NoUserFound, +} + +#[derive(Serialize)] +pub(crate) enum IntroResponse<'a> { + Intros(&'a Vec), + NoGuildFound, +} + +pub(crate) async fn health(State(state): State>>) -> Json { + let settings = state.lock().await; + + Json(json!(*settings)) +} + +pub(crate) async fn intros( + State(state): State>>, + Path(guild): Path, +) -> Json { + let settings = state.lock().await; + let Some(guild) = settings.guilds.get(&guild) else { return Json(json!(IntroResponse::NoGuildFound)); }; + + Json(json!(IntroResponse::Intros(&guild.intros))) +} + +pub(crate) async fn me( + State(state): State>>, + Path(user): Path, +) -> Json { + let settings = state.lock().await; + + let user_settings = settings + .guilds + .values() + .flat_map(|guild| guild.channels.values().flat_map(|channel| &channel.users)) + .filter(|(name, _)| **name == user) + .map(|(_, settings)| settings) + .collect::>(); + + if user_settings.is_empty() { + Json(json!(MeResponse::NoUserFound)) + } else { + Json(json!(MeResponse::Settings(user_settings))) + } +} diff --git a/src/settings.rs b/src/settings.rs new file mode 100644 index 0000000..862941e --- /dev/null +++ b/src/settings.rs @@ -0,0 +1,63 @@ +use std::{collections::HashMap, sync::Arc}; + +use serde::{Deserialize, Serialize}; +use serenity::prelude::TypeMapKey; + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub(crate) struct Settings { + #[serde(default)] + pub(crate) run_api: bool, + #[serde(default)] + pub(crate) run_bot: bool, + pub(crate) guilds: HashMap, +} +impl TypeMapKey for Settings { + type Value = Arc; +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub(crate) struct GuildSettings { + #[serde(alias = "userEnteredSoundDelay")] + pub(crate) sound_delay: u64, + pub(crate) channels: HashMap, + pub(crate) intros: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) enum Intro { + File(FileIntro), + Online(OnlineIntro), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub(crate) struct FileIntro { + pub(crate) filename: String, + pub(crate) friendly_name: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub(crate) struct OnlineIntro { + pub(crate) url: String, + pub(crate) friendly_name: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub(crate) struct ChannelSettings { + #[serde(alias = "enterUsers")] + pub(crate) users: HashMap, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub(crate) struct IntroIndex(pub usize, pub i32); + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub(crate) struct UserSettings { + pub(crate) intros: Vec, +}