refactor settings to allow for a pool of intros, setup api

pull/5/head
Patrick Cleavelin 2023-02-28 02:28:41 -06:00
parent 541d0fd70e
commit 7f7a6472be
5 changed files with 367 additions and 96 deletions

125
Cargo.lock generated
View File

@ -86,6 +86,56 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
[[package]]
name = "axum"
version = "0.6.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6137c6234afb339e75e764c866e3594900f0211e1315d33779f269bbe2ec6967"
dependencies = [
"async-trait",
"axum-core",
"bitflags",
"bytes",
"futures-util",
"http",
"http-body",
"hyper",
"itoa",
"matchit",
"memchr",
"mime",
"percent-encoding",
"pin-project-lite",
"rustversion",
"serde",
"serde_json",
"serde_path_to_error",
"serde_urlencoded",
"sync_wrapper",
"tokio",
"tower",
"tower-http",
"tower-layer",
"tower-service",
]
[[package]]
name = "axum-core"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1cae3e661676ffbacb30f1a824089a8c9150e71017f7e1e38f2aa32009188d34"
dependencies = [
"async-trait",
"bytes",
"futures-util",
"http",
"http-body",
"mime",
"rustversion",
"tower-layer",
"tower-service",
]
[[package]]
name = "base64"
version = "0.13.1"
@ -466,9 +516,9 @@ dependencies = [
[[package]]
name = "http"
version = "0.2.8"
version = "0.2.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75f43d41e26995c17e71ee126451dd3941010b0514a81a9d11f3b341debc2399"
checksum = "bd6effc99afb63425aff9b05836f029929e345a6148a14b7ecd5ab67af944482"
dependencies = [
"bytes",
"fnv",
@ -486,6 +536,12 @@ dependencies = [
"pin-project-lite",
]
[[package]]
name = "http-range-header"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0bfe8eed0a9285ef776bb792479ea3834e8b94e13d615c2f66d03dd50a435a29"
[[package]]
name = "httparse"
version = "1.8.0"
@ -631,6 +687,12 @@ dependencies = [
"regex-automata",
]
[[package]]
name = "matchit"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b87248edafb776e59e6ee64a79086f65890d3510f2c656c000bf2a7e8a0aea40"
[[package]]
name = "memchr"
version = "2.5.0"
@ -641,11 +703,14 @@ checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d"
name = "memejoin-rs"
version = "0.1.1-alpha"
dependencies = [
"axum",
"futures",
"serde",
"serde_json",
"serenity",
"songbird",
"tokio",
"tower-http",
"tracing",
"tracing-subscriber",
]
@ -1118,6 +1183,15 @@ dependencies = [
"serde",
]
[[package]]
name = "serde_path_to_error"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26b04f22b563c91331a10074bda3dd5492e3cc39d56bd557e91c0af42b6c7341"
dependencies = [
"serde",
]
[[package]]
name = "serde_repr"
version = "0.1.10"
@ -1327,6 +1401,12 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "sync_wrapper"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160"
[[package]]
name = "thiserror"
version = "1.0.38"
@ -1453,6 +1533,47 @@ dependencies = [
"tracing",
]
[[package]]
name = "tower"
version = "0.4.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c"
dependencies = [
"futures-core",
"futures-util",
"pin-project",
"pin-project-lite",
"tokio",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
name = "tower-http"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d1d42a9b3f3ec46ba828e8d376aec14592ea199f70a06a548587ecd1c4ab658"
dependencies = [
"bitflags",
"bytes",
"futures-core",
"futures-util",
"http",
"http-body",
"http-range-header",
"pin-project-lite",
"tower",
"tower-layer",
"tower-service",
]
[[package]]
name = "tower-layer"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0"
[[package]]
name = "tower-service"
version = "0.3.2"

View File

@ -6,10 +6,13 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
axum = "0.6.9"
futures = "0.3.26"
serde = "1.0.152"
serde_json = "1.0.93"
serenity = { version = "0.11.5", default-features = false, features = ["client", "gateway", "rustls_backend", "model", "cache", "voice"] }
songbird = { version = "0.3.0", features = [ "builtin-queue" ] }
tokio = { version = "1.25.0", features = ["rt-multi-thread", "macros", "signal"] }
tower-http = { version = "0.4.0", features = ["cors"] }
tracing = "0.1.37"
tracing-subscriber = "0.3.16"

View File

@ -2,11 +2,20 @@
#![feature(proc_macro_hygiene)]
#![feature(async_closure)]
mod routes;
pub mod settings;
use axum::http::{HeaderValue, Method};
use axum::routing::get;
use axum::Router;
use futures::StreamExt;
use songbird::tracks::TrackQueue;
use std::collections::HashMap;
use std::env;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::sync::mpsc;
use tower_http::cors::CorsLayer;
use serde::Deserialize;
use serenity::async_trait;
@ -17,6 +26,8 @@ use serenity::prelude::*;
use songbird::SerenityInit;
use tracing::*;
use crate::settings::{Intro, Settings};
enum HandlerMessage {
Ready(Context),
PlaySound(Context, Member, ChannelId),
@ -52,46 +63,6 @@ impl songbird::EventHandler for TrackEventHandler {
}
}
#[derive(Debug, Clone, Deserialize)]
struct Settings {
guilds: HashMap<u64, GuildSettings>,
}
impl TypeMapKey for Settings {
type Value = Arc<Settings>;
}
#[derive(Debug, Clone, Deserialize)]
struct GuildSettings {
#[serde(alias = "userEnteredSoundDelay")]
_sound_delay: u64,
channels: HashMap<String, ChannelSettings>,
}
#[derive(Debug, Clone, Deserialize)]
struct ChannelSettings {
#[serde(alias = "enterUsers")]
users: HashMap<String, UserSettings>,
}
#[derive(Debug, Clone, Deserialize)]
struct UserSettings {
#[serde(rename = "type")]
ty: SoundType,
#[serde(alias = "enterSound")]
sound: String,
#[serde(alias = "youtubeVolume")]
_volume: i32,
}
#[derive(Debug, Clone, Deserialize)]
enum SoundType {
#[serde(alias = "file")]
File,
#[serde(alias = "youtube")]
Youtube,
}
#[async_trait]
impl EventHandler for Handler {
async fn ready(&self, ctx: Context, ready: Ready) {
@ -143,20 +114,31 @@ impl EventHandler for Handler {
}
}
#[tokio::main]
#[instrument]
async fn main() {
tracing_subscriber::fmt::init();
fn spawn_api(settings: Arc<Mutex<Settings>>) -> tokio::task::JoinHandle<()> {
tokio::spawn(async move {
let api = Router::new()
.route("/health", get(routes::health))
.route("/me/:user", get(routes::me))
.route("/intros/:guild", get(routes::intros))
.layer(
CorsLayer::new()
.allow_origin("*".parse::<HeaderValue>().unwrap())
.allow_methods([Method::GET]),
)
.with_state(settings);
let addr = SocketAddr::from(([0, 0, 0, 0], 7756));
info!("socket listening on {addr}");
axum::Server::bind(&addr)
.serve(api.into_make_service())
.await
.unwrap();
})
}
async fn spawn_bot(settings: Arc<Mutex<Settings>>) -> Vec<tokio::task::JoinHandle<()>> {
let mut tasks = vec![];
let token = env::var("DISCORD_TOKEN").expect("expected DISCORD_TOKEN env var");
let settings = serde_json::from_str::<Settings>(
&std::fs::read_to_string("config/settings.json").expect("no config/settings.json"),
)
.expect("error parsing settings file");
info!("{settings:?}");
let songbird = songbird::Songbird::serenity();
let (tx, mut rx) = mpsc::channel(10);
@ -174,17 +156,18 @@ async fn main() {
.expect("Error creating client");
info!("Starting bot with token '{token}'");
tokio::spawn(async move {
tasks.push(tokio::spawn(async move {
if let Err(err) = client.start().await {
error!("An error occurred while running the client: {err:?}");
}
});
}));
tokio::spawn(async move {
tasks.push(tokio::spawn(async move {
while let Some(msg) = rx.recv().await {
match msg {
HandlerMessage::Ready(ctx) => {
info!("Got Ready message");
let settings = settings.lock().await;
let songbird = songbird::get(&ctx).await.expect("no songbird instance");
@ -199,7 +182,7 @@ async fn main() {
tx: tx.clone(),
guild_id: GuildId(*guild_id),
},
);
);
}
}
HandlerMessage::TrackEnded(guild_id) => {
@ -220,61 +203,102 @@ async fn main() {
HandlerMessage::PlaySound(ctx, member, channel_id) => {
info!("Got PlaySound message");
let settings = settings.lock().await;
let Some(Channel::Guild(channel)) = channel_id.to_channel_cached(&ctx.cache) else {
error!("Failed to get cached channel from member!");
continue;
};
let Some(user) = settings.guilds.get(channel.guild_id.as_u64())
.and_then(|guild| guild.channels.get(channel.name()))
.and_then(|c| c.users.get(&member.user.name))
else {
info!("No sound associated for {} in channel {}", member.user.name, channel.name());
continue;
};
let Some(guild_settings) = settings.guilds.get(channel.guild_id.as_u64()) else { continue; };
let Some(channel_settings) = guild_settings.channels.get(channel.name()) else { continue; };
let Some(user) = channel_settings.users.get(&member.user.name) else { continue; };
let source = match user.ty {
SoundType::Youtube => match songbird::ytdl(&user.sound).await {
Ok(source) => source,
Err(err) => {
error!(
"Error starting youtube source from {}: {err:?}",
user.sound
);
continue;
}
},
SoundType::File => {
match songbird::ffmpeg(format!("sounds/{}", &user.sound)).await {
// TODO: randomly choose a intro to play
let Some(intro) = user.intros.first() else { continue; };
let source = match guild_settings.intros.get(intro.0) {
Some(Intro::Online(intro)) => match songbird::ytdl(&intro.url).await {
Ok(source) => source,
Err(err) => {
error!(
"Error starting file source from {}: {err:?}",
user.sound
);
"Error starting youtube source from {}: {err:?}",
intro.url
);
continue;
}
},
Some(Intro::File(intro)) => {
match songbird::ffmpeg(format!("sounds/{}", &intro.filename)).await {
Ok(source) => source,
Err(err) => {
error!(
"Error starting file source from {}: {err:?}",
intro.filename
);
continue;
}
}
},
None => {
error!(
"Failed to find intro for user {} on guild {} in channel {}, IntroIndex: {}",
member.user.name,
channel.guild_id.as_u64(),
channel.name(),
intro.0
);
continue;
}
};
match songbird.join(member.guild_id, channel_id).await {
(handler_lock, Ok(())) => {
let mut handler = handler_lock.lock().await;
let _track_handler = handler.enqueue_source(source);
// TODO: set volume
}
(_, Err(err)) => {
error!("Failed to join voice channel {}: {err:?}", channel.name());
}
}
};
match songbird.join(member.guild_id, channel_id).await {
(handler_lock, Ok(())) => {
let mut handler = handler_lock.lock().await;
let _track_handler = handler.enqueue_source(source);
// TODO: set volume
}
(_, Err(err)) => {
error!("Failed to join voice channel {}: {err:?}", channel.name());
}
}
}
}
}
});
}));
tasks
}
#[tokio::main]
#[instrument]
async fn main() {
tracing_subscriber::fmt::init();
let settings = serde_json::from_str::<Settings>(
&std::fs::read_to_string("config/settings.json").expect("no config/settings.json"),
)
.expect("error parsing settings file");
let (run_api, run_bot) = (settings.run_api, settings.run_bot);
info!("{settings:?}");
let mut tasks = vec![];
let settings = Arc::new(Mutex::new(settings));
if run_api {
tasks.push(spawn_api(settings.clone()));
}
if run_bot {
tasks.append(&mut spawn_bot(settings.clone()).await);
}
let tasks = futures::stream::iter(tasks);
let mut buffered = tasks.buffer_unordered(5);
while buffered.next().await.is_some() {}
let _ = tokio::signal::ctrl_c().await;
info!("Received Ctrl-C, shuttdown down.");

60
src/routes.rs Normal file
View File

@ -0,0 +1,60 @@
use std::sync::Arc;
use axum::{
extract::{Path, State},
Json,
};
use serde::Serialize;
use serde_json::{json, Value};
use tokio::sync::Mutex;
use crate::settings::{Intro, Settings, UserSettings};
#[derive(Serialize)]
pub(crate) enum MeResponse<'a> {
Settings(Vec<&'a UserSettings>),
NoUserFound,
}
#[derive(Serialize)]
pub(crate) enum IntroResponse<'a> {
Intros(&'a Vec<Intro>),
NoGuildFound,
}
pub(crate) async fn health(State(state): State<Arc<Mutex<Settings>>>) -> Json<Value> {
let settings = state.lock().await;
Json(json!(*settings))
}
pub(crate) async fn intros(
State(state): State<Arc<Mutex<Settings>>>,
Path(guild): Path<u64>,
) -> Json<Value> {
let settings = state.lock().await;
let Some(guild) = settings.guilds.get(&guild) else { return Json(json!(IntroResponse::NoGuildFound)); };
Json(json!(IntroResponse::Intros(&guild.intros)))
}
pub(crate) async fn me(
State(state): State<Arc<Mutex<Settings>>>,
Path(user): Path<String>,
) -> Json<Value> {
let settings = state.lock().await;
let user_settings = settings
.guilds
.values()
.flat_map(|guild| guild.channels.values().flat_map(|channel| &channel.users))
.filter(|(name, _)| **name == user)
.map(|(_, settings)| settings)
.collect::<Vec<_>>();
if user_settings.is_empty() {
Json(json!(MeResponse::NoUserFound))
} else {
Json(json!(MeResponse::Settings(user_settings)))
}
}

63
src/settings.rs Normal file
View File

@ -0,0 +1,63 @@
use std::{collections::HashMap, sync::Arc};
use serde::{Deserialize, Serialize};
use serenity::prelude::TypeMapKey;
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct Settings {
#[serde(default)]
pub(crate) run_api: bool,
#[serde(default)]
pub(crate) run_bot: bool,
pub(crate) guilds: HashMap<u64, GuildSettings>,
}
impl TypeMapKey for Settings {
type Value = Arc<Settings>;
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct GuildSettings {
#[serde(alias = "userEnteredSoundDelay")]
pub(crate) sound_delay: u64,
pub(crate) channels: HashMap<String, ChannelSettings>,
pub(crate) intros: Vec<Intro>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) enum Intro {
File(FileIntro),
Online(OnlineIntro),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct FileIntro {
pub(crate) filename: String,
pub(crate) friendly_name: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct OnlineIntro {
pub(crate) url: String,
pub(crate) friendly_name: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct ChannelSettings {
#[serde(alias = "enterUsers")]
pub(crate) users: HashMap<String, UserSettings>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct IntroIndex(pub usize, pub i32);
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct UserSettings {
pub(crate) intros: Vec<IntroIndex>,
}