memejoin-rs/src/routes.rs

610 lines
17 KiB
Rust

use std::collections::HashMap;
use axum::{
extract::{Multipart, Path, Query, State},
http::{HeaderMap, HeaderValue},
response::{Html, IntoResponse, Redirect},
};
use axum_extra::extract::{cookie::Cookie, CookieJar};
use chrono::{Duration, Utc};
use reqwest::{StatusCode, Url};
use serde::{Deserialize, Deserializer};
use std::str::FromStr;
use tracing::{error, info};
use uuid::Uuid;
use crate::{
auth::{self},
db,
htmx::Build,
page,
};
use crate::{media, settings::ApiState};
pub(crate) async fn health() -> &'static str {
"Hello!"
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum Error {
#[error("{0}")]
Auth(String),
#[error("{0}")]
GetUser(#[from] reqwest::Error),
#[error("Guild doesn't exist")]
NoGuildFound,
#[error("invalid request")]
InvalidRequest,
#[error("Invalid permissions for request")]
InvalidPermission,
#[error("{0}")]
Ytdl(#[from] std::io::Error),
#[error("{0}")]
Ffmpeg(String),
#[error("ytdl terminated unsuccessfully")]
YtdlTerminated,
#[error("ffmpeg terminated unsuccessfully")]
FfmpegTerminated,
#[error("database error: {0}")]
Database(#[from] rusqlite::Error),
}
impl IntoResponse for Error {
fn into_response(self) -> axum::response::Response {
error!("{self}");
match self {
Self::Auth(msg) => (StatusCode::INTERNAL_SERVER_ERROR, msg).into_response(),
Self::GetUser(error) => (StatusCode::UNAUTHORIZED, error.to_string()).into_response(),
Self::NoGuildFound => (StatusCode::NOT_FOUND, self.to_string()).into_response(),
Self::InvalidRequest => (StatusCode::BAD_REQUEST, self.to_string()).into_response(),
Self::InvalidPermission => (StatusCode::UNAUTHORIZED, self.to_string()).into_response(),
Self::Ytdl(error) => {
(StatusCode::INTERNAL_SERVER_ERROR, error.to_string()).into_response()
}
Self::Ffmpeg(error) => (StatusCode::INTERNAL_SERVER_ERROR, error).into_response(),
Self::YtdlTerminated | Self::FfmpegTerminated => {
(StatusCode::INTERNAL_SERVER_ERROR, self.to_string()).into_response()
}
Self::Database(error) => {
(StatusCode::INTERNAL_SERVER_ERROR, error.to_string()).into_response()
}
}
}
}
#[derive(Deserialize)]
struct DiscordUser {
pub username: String,
}
#[derive(Deserialize)]
pub(crate) struct DiscordUserGuild {
#[serde(deserialize_with = "serde_string_as_u64")]
pub id: u64,
pub name: String,
pub owner: bool,
}
#[derive(Deserialize)]
pub(crate) struct DiscordChannel {
pub name: Option<String>,
#[serde(rename = "type")]
pub ty: u32,
}
#[derive(Deserialize, PartialEq, Eq)]
#[repr(u32)]
pub(crate) enum ChannelType {
GuildText = 0,
GuildVoice = 2,
}
fn serde_string_as_u64<'de, D>(deserializer: D) -> Result<u64, D::Error>
where
D: Deserializer<'de>,
{
let value = <&str as Deserialize>::deserialize(deserializer)?;
value
.parse::<u64>()
.map_err(|_| serde::de::Error::invalid_value(serde::de::Unexpected::Str(value), &"u64"))
}
pub(crate) async fn v2_auth(
State(state): State<ApiState>,
Query(params): Query<HashMap<String, String>>,
jar: CookieJar,
) -> Result<(CookieJar, Redirect), Error> {
let Some(code) = params.get("code") else {
return Err(Error::Auth("no code".to_string()));
};
info!("attempting to get access token with code {}", code);
let mut data = HashMap::new();
let redirect_uri = format!("{}/v2/auth", state.origin);
data.insert("client_id", state.secrets.client_id.as_str());
data.insert("client_secret", state.secrets.client_secret.as_str());
data.insert("grant_type", "authorization_code");
data.insert("code", code);
data.insert("redirect_uri", &redirect_uri);
let client = reqwest::Client::new();
let auth: auth::Discord = client
.post("https://discord.com/api/oauth2/token")
.form(&data)
.send()
.await
.map_err(|err| Error::Auth(err.to_string()))?
.json()
.await
.map_err(|err| {
error!(?err, "auth error");
Error::Auth(err.to_string())
})?;
// Get authorized username
let user: DiscordUser = client
.get("https://discord.com/api/v10/users/@me")
.bearer_auth(&auth.access_token)
.send()
.await?
.json()
.await?;
// TODO: get bot's guilds so we only save users who are able to use the bot
let discord_guilds: Vec<DiscordUserGuild> = client
.get("https://discord.com/api/v10/users/@me/guilds")
.bearer_auth(&auth.access_token)
.send()
.await?
.json()
.await
.map_err(|err| Error::Auth(err.to_string()))?;
let db = state.db.lock().await;
let needs_setup = db.get_user_count().map_err(Error::Database)? == 0;
let token = if let Some(user) = db
.get_user(&user.username)
.map_err(Error::Database)?
.filter(|user| user.api_key_expires_at >= Utc::now().naive_utc())
{
user.api_key
} else {
Uuid::new_v4().to_string()
};
if needs_setup {
let now = Utc::now().naive_utc();
db.insert_user(
&user.username,
&token,
now + Duration::weeks(4),
&auth.access_token,
now + Duration::seconds(auth.expires_in as i64),
)
.map_err(Error::Database)?;
db.insert_user_app_permission(
&user.username,
auth::AppPermissions(auth::AppPermission::all()),
)
.map_err(Error::Database)?;
}
let guilds = db.get_guilds().map_err(Error::Database)?;
let mut in_a_guild = false;
for guild in guilds {
let Some(discord_guild) = discord_guilds
.iter()
.find(|discord_guild| discord_guild.id == guild.id)
else {
continue;
};
in_a_guild = true;
if !needs_setup {
let now = Utc::now().naive_utc();
db.insert_user(
&user.username,
&token,
now + Duration::weeks(4),
&auth.access_token,
now + Duration::seconds(auth.expires_in as i64),
)
.map_err(Error::Database)?;
}
db.insert_user_guild(&user.username, guild.id)
.map_err(Error::Database)?;
if db.get_user_permissions(&user.username, guild.id).is_err() {
db.insert_user_permission(
&user.username,
guild.id,
if discord_guild.owner {
auth::Permissions(auth::Permission::all())
} else {
Default::default()
},
)
.map_err(Error::Database)?;
}
}
if !in_a_guild {
return Err(Error::NoGuildFound);
}
// TODO: add permissions based on roles
let uri = Url::parse(&state.origin).expect("should be a valid url");
let mut cookie = Cookie::new("access_token", token);
cookie.set_path(uri.path().to_string());
cookie.set_secure(true);
Ok((jar.add(cookie), Redirect::to(&format!("{}/", state.origin))))
}
pub(crate) async fn v2_add_intro_to_user(
State(state): State<ApiState>,
Path((guild_id, channel)): Path<(u64, String)>,
user: db::User,
mut form_data: Multipart,
) -> Result<Html<String>, Redirect> {
let db = state.db.lock().await;
while let Ok(Some(field)) = form_data.next_field().await {
let Some(intro_id) = field.name() else {
continue;
};
let intro_id = intro_id.parse::<i32>().map_err(|err| {
error!(?err, "invalid intro id");
// TODO: change to actual error
Redirect::to(&format!("{}/login", state.origin))
})?;
db.insert_user_intro(&user.name, guild_id, &channel, intro_id)
.map_err(|err| {
error!(?err, "failed to add user intro");
// TODO: change to actual error
Redirect::to(&format!("{}/login", state.origin))
})?;
}
let guild_intros = db.get_guild_intros(guild_id).map_err(|err| {
error!(?err, %guild_id, "couldn't get guild intros");
// TODO: change to actual error
Redirect::to(&format!("{}/login", state.origin))
})?;
let intros = db
.get_user_channel_intros(&user.name, guild_id, &channel)
.map_err(|err| {
error!(?err, user = %user.name, %guild_id, "couldn't get user intros");
// TODO: change to actual error
Redirect::to(&format!("{}/login", state.origin))
})?;
Ok(Html(
page::channel_intro_selector(
&state.origin,
guild_id,
&channel,
intros.iter(),
guild_intros.iter(),
)
.build(),
))
}
pub(crate) async fn v2_remove_intro_from_user(
State(state): State<ApiState>,
Path((guild_id, channel)): Path<(u64, String)>,
user: db::User,
mut form_data: Multipart,
) -> Result<Html<String>, Redirect> {
let db = state.db.lock().await;
while let Ok(Some(field)) = form_data.next_field().await {
let Some(intro_id) = field.name() else {
continue;
};
let intro_id = intro_id.parse::<i32>().map_err(|err| {
error!(?err, "invalid intro id");
// TODO: change to actual error
Redirect::to(&format!("{}/login", state.origin))
})?;
db.delete_user_intro(&user.name, guild_id, &channel, intro_id)
.map_err(|err| {
error!(?err, "failed to remove user intro");
// TODO: change to actual error
Redirect::to(&format!("{}/login", state.origin))
})?;
}
let guild_intros = db.get_guild_intros(guild_id).map_err(|err| {
error!(?err, %guild_id, "couldn't get guild intros");
// TODO: change to actual error
Redirect::to(&format!("{}/login", state.origin))
})?;
let intros = db
.get_user_channel_intros(&user.name, guild_id, &channel)
.map_err(|err| {
error!(?err, user = %user.name, %guild_id, "couldn't get user intros");
// TODO: change to actual error
Redirect::to(&format!("{}/login", state.origin))
})?;
Ok(Html(
page::channel_intro_selector(
&state.origin,
guild_id,
&channel,
intros.iter(),
guild_intros.iter(),
)
.build(),
))
}
pub(crate) async fn v2_upload_guild_intro(
State(state): State<ApiState>,
Path(guild_id): Path<u64>,
user: db::User,
mut form_data: Multipart,
) -> Result<HeaderMap, Error> {
let db = state.db.lock().await;
let mut name = None;
let mut file = None;
if !db
.get_guilds()
.map_err(Error::Database)?
.into_iter()
.any(|guild| guild.id == guild_id)
{
return Err(Error::NoGuildFound);
}
let user_permissions = db
.get_user_permissions(&user.name, guild_id)
.map_err(Error::Database)?;
if !user_permissions.can(auth::Permission::UploadSounds) {
return Err(Error::InvalidPermission);
}
while let Ok(Some(field)) = form_data.next_field().await {
let Some(field_name) = field.name() else {
continue;
};
if field_name.eq_ignore_ascii_case("name") {
name = Some(field.text().await.map_err(|_| Error::InvalidRequest)?);
continue;
}
if field_name.eq_ignore_ascii_case("file") {
file = Some(field.bytes().await.map_err(|_| Error::InvalidRequest)?);
continue;
}
}
let Some(name) = name else {
return Err(Error::InvalidRequest);
};
let Some(file) = file else {
return Err(Error::InvalidRequest);
};
let uuid = Uuid::new_v4().to_string();
let temp_path = format!("./sounds/temp/{uuid}");
let dest_path = format!("./sounds/{uuid}.mp3");
// Write original file so its ready for codec conversion
std::fs::write(&temp_path, file)?;
media::normalize(&temp_path, &dest_path).await?;
std::fs::remove_file(&temp_path)?;
db.insert_intro(&name, 0, guild_id, &format!("{uuid}.mp3"))
.map_err(Error::Database)?;
let mut headers = HeaderMap::new();
headers.insert("HX-Refresh", HeaderValue::from_static("true"));
Ok(headers)
}
pub(crate) async fn v2_add_guild_intro(
State(state): State<ApiState>,
Path(guild_id): Path<u64>,
Query(mut params): Query<HashMap<String, String>>,
user: db::User,
) -> Result<HeaderMap, Error> {
let db = state.db.lock().await;
let Some(url) = params.remove("url") else {
return Err(Error::InvalidRequest);
};
let Some(name) = params.remove("name") else {
return Err(Error::InvalidRequest);
};
if !db
.get_guilds()
.map_err(Error::Database)?
.into_iter()
.any(|guild| guild.id == guild_id)
{
return Err(Error::NoGuildFound);
}
let user_permissions = db
.get_user_permissions(&user.name, guild_id)
.map_err(Error::Database)?;
if !user_permissions.can(auth::Permission::UploadSounds) {
return Err(Error::InvalidPermission);
}
let uuid = Uuid::new_v4().to_string();
let child = tokio::process::Command::new("yt-dlp")
.arg(&url)
.args(["-o", &format!("sounds/{uuid}")])
.args(["-x", "--audio-format", "mp3"])
.spawn()
.map_err(Error::Ytdl)?
.wait()
.await
.map_err(Error::Ytdl)?;
if !child.success() {
return Err(Error::YtdlTerminated);
}
db.insert_intro(&name, 0, guild_id, &format!("{uuid}.mp3"))
.map_err(Error::Database)?;
let mut headers = HeaderMap::new();
headers.insert("HX-Refresh", HeaderValue::from_static("true"));
Ok(headers)
}
#[derive(Debug, Deserialize)]
pub(crate) struct GuildSetupParams {
name: String,
}
pub(crate) async fn guild_setup(
State(state): State<ApiState>,
user: db::User,
Path(guild_id): Path<u64>,
Query(GuildSetupParams { name }): Query<GuildSetupParams>,
) -> Result<Redirect, Error> {
let db = state.db.lock().await;
let user_permissions = db.get_user_app_permissions(&user.name).unwrap_or_default();
if !user_permissions.can(auth::AppPermission::AddGuild) {
return Err(Error::InvalidPermission);
}
db.insert_guild(&guild_id, &name, 0)?;
db.insert_user_guild(&user.name, guild_id)?;
db.insert_user_permission(
&user.name,
guild_id,
auth::Permissions(auth::Permission::all()),
)?;
Ok(Redirect::to(&format!(
"{}/guild/{}",
state.origin, guild_id
)))
}
pub(crate) async fn guild_add_channel(
State(state): State<ApiState>,
user: db::User,
Path(guild_id): Path<u64>,
mut form_data: Multipart,
) -> Result<HeaderMap, Error> {
let db = state.db.lock().await;
let user_permissions = db
.get_user_permissions(&user.name, guild_id)
.unwrap_or_default();
if !user_permissions.can(auth::Permission::AddChannel) {
return Err(Error::InvalidPermission);
}
while let Ok(Some(field)) = form_data.next_field().await {
let Some(channel_name) = field.name() else {
continue;
};
db.insert_guild_channel(&guild_id, channel_name)?;
}
let mut headers = HeaderMap::new();
headers.insert("HX-Refresh", HeaderValue::from_static("true"));
Ok(headers)
}
pub(crate) async fn update_guild_permissions(
State(state): State<ApiState>,
Path(guild_id): Path<u64>,
user: db::User,
mut form_data: Multipart,
) -> Result<HeaderMap, Error> {
let db = state.db.lock().await;
let this_user_permissions = db
.get_user_permissions(&user.name, guild_id)
.unwrap_or_default();
if !this_user_permissions.can(auth::Permission::Moderator) {
return Err(Error::InvalidPermission);
}
let mut users_to_update: HashMap<String, auth::Permissions> = db
.get_guild_users(guild_id)?
.into_iter()
.map(|user| (user, Default::default()))
.collect();
while let Ok(Some(field)) = form_data.next_field().await {
let Some(field_name) = field.name() else {
continue;
};
if let Some((username, permission)) = field_name.split_once('#') {
let permission = auth::Permission::from_str(permission)?;
let username = username.to_string();
if field.text().await.map_err(|_| Error::InvalidRequest)? == "on" {
users_to_update
.entry(username)
.and_modify(|value| {
value.add(permission);
})
.or_insert_with(|| {
let mut perm = auth::Permissions::default();
perm.add(permission);
perm
});
}
}
}
for (user, permissions) in users_to_update {
let user_permissions = db.get_user_permissions(&user, guild_id).unwrap_or_default();
if !user_permissions.can(auth::Permission::Moderator) {
db.insert_user_permission(&user, guild_id, permissions)?;
}
}
let mut headers = HeaderMap::new();
headers.insert("HX-Refresh", HeaderValue::from_static("true"));
Ok(headers)
}