diff --git a/Cargo.toml b/Cargo.toml index 296ce1e..886db56 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "memejoin-rs" version = "0.2.2-alpha" -edition = "2021" +edition = "2024" [[bin]] name = "memejoin-rs" diff --git a/src/lib/auth.rs b/src/lib/auth.rs index 52051e6..5049a24 100644 --- a/src/lib/auth.rs +++ b/src/lib/auth.rs @@ -1,3 +1,5 @@ +use serde::{Deserialize, Serialize}; + #[derive(Clone)] pub struct DiscordSecret { pub client_id: String, @@ -5,6 +7,15 @@ pub struct DiscordSecret { pub bot_token: String, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) struct Discord { + pub(crate) access_token: String, + pub(crate) token_type: String, + pub(crate) expires_in: usize, + pub(crate) refresh_token: String, + pub(crate) scope: String, +} + /* use std::str::FromStr; diff --git a/src/lib/domain/intro_tool/debug_service.rs b/src/lib/domain/intro_tool/debug_service.rs index 797e627..354a026 100644 --- a/src/lib/domain/intro_tool/debug_service.rs +++ b/src/lib/domain/intro_tool/debug_service.rs @@ -5,6 +5,8 @@ use crate::domain::intro_tool::{ ports::{IntroToolRepository, IntroToolService}, }; +use super::ports::AuthService; + #[derive(Clone)] pub struct DebugService where @@ -34,6 +36,13 @@ where self.wrapped_service.needs_setup().await } + async fn authenticate_user( + &self, + params: A::Params, + ) -> Result> { + self.wrapped_service.authenticate_user(params).await + } + async fn get_guild( &self, guild_id: impl Into + Send, @@ -41,6 +50,12 @@ where self.wrapped_service.get_guild(guild_id).await } + async fn get_guilds( + &self, + ) -> Result, models::guild::GetGuildError> { + self.wrapped_service.get_guilds().await + } + async fn get_guild_users( &self, guild_id: models::guild::GuildId, @@ -88,6 +103,20 @@ where .with_channel_intros(user.intros().clone())) } + async fn set_user_intro( + &self, + req: models::guild::AddIntroToUserRequest, + ) -> Result<(), models::guild::AddIntroToUserError> { + self.wrapped_service.set_user_intro(req).await + } + + async fn refresh_user_token( + &self, + username: &str, + ) -> Result { + self.wrapped_service.refresh_user_token(username).await + } + async fn create_guild( &self, req: models::guild::CreateGuildRequest, @@ -102,6 +131,16 @@ where self.wrapped_service.create_user(req).await } + async fn add_user_to_guild( + &self, + guild_id: models::guild::GuildId, + username: &str, + ) -> Result<(), models::guild::AddUserToGuildError> { + self.wrapped_service + .add_user_to_guild(guild_id, username) + .await + } + async fn create_channel( &self, req: models::guild::CreateChannelRequest, @@ -115,11 +154,4 @@ where ) -> Result { self.wrapped_service.add_intro_to_guild(req).await } - - async fn set_user_intro( - &self, - req: models::guild::AddIntroToUserRequest, - ) -> Result<(), models::guild::AddIntroToUserError> { - self.wrapped_service.set_user_intro(req).await - } } diff --git a/src/lib/domain/intro_tool/models/guild.rs b/src/lib/domain/intro_tool/models/guild.rs index 4278dcd..04d6390 100644 --- a/src/lib/domain/intro_tool/models/guild.rs +++ b/src/lib/domain/intro_tool/models/guild.rs @@ -1,13 +1,18 @@ -use std::collections::HashMap; +use std::{borrow::Cow, collections::HashMap}; use chrono::NaiveDateTime; use thiserror::Error; +use crate::domain::intro_tool::ports::AuthService; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct ApiToken(String); + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct GuildId(u64); #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct ExternalGuildId(u64); +pub struct ExternalGuildId(pub u64); #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct UserName(String); @@ -18,6 +23,18 @@ pub struct ChannelName(String); #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct IntroId(i32); +impl From for Cow<'_, str> { + fn from(value: ApiToken) -> Self { + Cow::Owned(value.0) + } +} + +impl From for ApiToken { + fn from(value: String) -> Self { + Self(value) + } +} + impl From for GuildId { fn from(id: u64) -> Self { Self(id) @@ -108,6 +125,10 @@ impl GuildRef { pub fn name(&self) -> &str { &self.name } + + pub fn external_id(&self) -> ExternalGuildId { + self.external_id + } } impl GuildRef { @@ -198,6 +219,10 @@ impl User { &self.channel_intros } + pub fn api_key(&self) -> &str { + &self.api_key + } + pub fn api_key_expires_at(&self) -> NaiveDateTime { self.api_key_expires_at } @@ -255,18 +280,18 @@ impl Intro { } pub struct CreateGuildRequest { - name: String, - sound_delay: u32, - external_id: ExternalGuildId, + pub name: String, + pub sound_delay: u32, + pub external_id: ExternalGuildId, } pub struct CreateUserRequest { - user: UserName, + pub user: UserName, } pub struct CreateChannelRequest { - guild_id: GuildId, - channel_name: ChannelName, + pub guild_id: GuildId, + pub channel_name: ChannelName, } pub struct AddIntroToGuildRequest { @@ -297,6 +322,9 @@ pub enum CreateGuildError { #[derive(Debug, Error)] pub enum CreateUserError { + #[error("Could not get user")] + CouldNotGetUser(#[from] GetUserError), + #[error(transparent)] Unknown(#[from] anyhow::Error), } @@ -307,6 +335,12 @@ pub enum CreateChannelError { Unknown(#[from] anyhow::Error), } +#[derive(Debug, Error)] +pub enum AddUserToGuildError { + #[error(transparent)] + Unknown(#[from] anyhow::Error), +} + #[derive(Debug, Error)] pub enum AddIntroToGuildError { #[error(transparent)] @@ -366,3 +400,27 @@ pub enum GetIntroError { #[error(transparent)] Unknown(#[from] anyhow::Error), } + +#[derive(Debug, Error)] +pub enum AutheticateUserError { + #[error("Could not fetch guild")] + CouldNotFetchGuild(#[from] GetGuildError), + + #[error("Could not create user")] + CouldNotCreateUser(#[from] CreateUserError), + + #[error("Could not fetch guild user")] + CouldNotFetchUser(#[from] GetUserError), + + #[error("Could not add user to guild")] + CouldNotAddUserToGuild(#[from] AddUserToGuildError), + + #[error("User not part of instance's guilds")] + UserNotPartOfInstanceGuilds, + + #[error("Error authenticating user")] + ExternalError(A::Error), + + #[error(transparent)] + Unknown(#[from] anyhow::Error), +} diff --git a/src/lib/domain/intro_tool/ports.rs b/src/lib/domain/intro_tool/ports.rs index 5879782..74eaeef 100644 --- a/src/lib/domain/intro_tool/ports.rs +++ b/src/lib/domain/intro_tool/ports.rs @@ -1,21 +1,31 @@ use std::{collections::HashMap, future::Future}; -use crate::domain::intro_tool::models::guild::{ChannelName, IntroId}; +use chrono::NaiveDateTime; + +use crate::domain::intro_tool::models::guild::{ + AddUserToGuildError, AutheticateUserError, ExternalGuildId, UserName, +}; use super::models::guild::{ AddIntroToGuildError, AddIntroToGuildRequest, AddIntroToUserError, AddIntroToUserRequest, - Channel, CreateChannelError, CreateChannelRequest, CreateGuildError, CreateGuildRequest, - CreateUserError, CreateUserRequest, GetChannelError, GetGuildError, GetIntroError, - GetUserError, Guild, GuildId, GuildRef, Intro, User, + ApiToken, Channel, ChannelName, CreateChannelError, CreateChannelRequest, CreateGuildError, + CreateGuildRequest, CreateUserError, CreateUserRequest, GetChannelError, GetGuildError, + GetIntroError, GetUserError, Guild, GuildId, GuildRef, Intro, IntroId, User, }; pub trait IntroToolService: Send + Sync + Clone + 'static { fn needs_setup(&self) -> impl Future + Send; + fn authenticate_user( + &self, + params: A::Params, + ) -> impl Future>> + Send; + fn get_guild( &self, guild_id: impl Into + Send, ) -> impl Future> + Send; + fn get_guilds(&self) -> impl Future, GetGuildError>> + Send; fn get_guild_users( &self, guild_id: GuildId, @@ -42,8 +52,24 @@ pub trait IntroToolService: Send + Sync + Clone + 'static { req: AddIntroToUserRequest, ) -> impl Future> + Send; + fn refresh_user_token( + &self, + username: &str, + ) -> impl Future> + Send; + async fn create_guild(&self, req: CreateGuildRequest) -> Result; - async fn create_user(&self, req: CreateUserRequest) -> Result; + + fn create_user( + &self, + req: CreateUserRequest, + ) -> impl Future> + Send; + + fn add_user_to_guild( + &self, + guild_id: GuildId, + username: &str, + ) -> impl Future> + Send; + async fn create_channel( &self, req: CreateChannelRequest, @@ -60,6 +86,7 @@ pub trait IntroToolRepository: Send + Sync + Clone + 'static { &self, guild_id: GuildId, ) -> impl Future> + Send; + fn get_guilds(&self) -> impl Future, GetGuildError>> + Send; fn get_guild_count(&self) -> impl Future> + Send; fn get_guild_users( @@ -97,13 +124,31 @@ pub trait IntroToolRepository: Send + Sync + Clone + 'static { api_key: &str, ) -> impl Future> + Send; + fn set_user_api_key( + &self, + username: &str, + api_key: &str, + expires_at: NaiveDateTime, + ) -> impl Future> + Send; + fn set_user_intro( &self, req: AddIntroToUserRequest, ) -> impl Future> + Send; async fn create_guild(&self, req: CreateGuildRequest) -> Result; - async fn create_user(&self, req: CreateUserRequest) -> Result; + + fn create_user( + &self, + req: CreateUserRequest, + ) -> impl Future> + Send; + + fn add_user_to_guild( + &self, + guild_id: GuildId, + username: &str, + ) -> impl Future> + Send; + async fn create_channel( &self, req: CreateChannelRequest, @@ -117,6 +162,21 @@ pub trait IntroToolRepository: Send + Sync + Clone + 'static { ) -> impl Future> + Send; } +pub trait ExternalUser: Send + Sync + Clone + 'static { + fn external_token(&self) -> &str; + fn username(&self) -> UserName; + fn guilds(&self) -> impl Iterator; +} +pub trait AuthService: Send + Sync + Clone + 'static { + type Params: Send; + type User: ExternalUser + Send; + type Error: std::error::Error + Send; + + fn authenticate_user( + params: Self::Params, + ) -> impl Future> + Send; +} + pub trait RemoteAudioFetcher: Send + Sync + Clone + 'static { fn fetch_remote_audio( &self, diff --git a/src/lib/domain/intro_tool/service.rs b/src/lib/domain/intro_tool/service.rs index 98e3f01..906fb05 100644 --- a/src/lib/domain/intro_tool/service.rs +++ b/src/lib/domain/intro_tool/service.rs @@ -1,10 +1,19 @@ +use chrono::{Duration, Utc}; +use iter_tools::Itertools; use uuid::Uuid; use crate::domain::intro_tool::{ - models::guild::{self, GetUserError, GuildId, IntroId, User}, - ports::{IntroToolRepository, IntroToolService, LocalAudioFetcher, RemoteAudioFetcher}, + models::guild::{ + self, ApiToken, AutheticateUserError, CreateUserRequest, GetUserError, GuildId, IntroId, + User, + }, + ports::{ + ExternalUser, IntroToolRepository, IntroToolService, LocalAudioFetcher, RemoteAudioFetcher, + }, }; +use super::ports::AuthService; + #[derive(Clone)] pub struct Service where @@ -46,6 +55,64 @@ where guild_count == 0 } + async fn authenticate_user( + &self, + params: A::Params, + ) -> Result> { + let external_user = A::authenticate_user(params) + .await + .map_err(AutheticateUserError::ExternalError)?; + + let guilds = self.get_guilds().await?; + let external_user_guilds = guilds + .iter() + .filter(|guild| external_user.guilds().contains(&guild.external_id())) + .collect::>(); + + if external_user_guilds.is_empty() { + return Err(AutheticateUserError::UserNotPartOfInstanceGuilds); + } + + let user = match self.get_user(external_user.username()).await { + Ok(user) => Some(user), + Err(GetUserError::NotFound) => None, + + Err(err) => return Err(AutheticateUserError::CouldNotFetchUser(err)), + }; + + match user { + Some(user) => { + self.refresh_user_token(user.name()).await?; + } + None => { + self.create_user(CreateUserRequest { + user: external_user.username().clone(), + }) + .await?; + } + } + + let user = self.get_user(external_user.username()).await?; + let user_guilds = self.get_user_guilds(user.name()).await?; + + let guilds_to_add_user = + user_guilds + .iter() + .map(|guild| guild.id()) + .filter(|user_guild_id| { + external_user_guilds + .iter() + .map(|external_guild| external_guild.id()) + .contains(user_guild_id) + }); + + for guild in guilds_to_add_user { + self.add_user_to_guild(guild, user.name()).await?; + } + + Ok(user.api_key().to_string().into()) + } + async fn get_guild( &self, guild_id: impl Into, @@ -53,9 +120,14 @@ where self.repo.get_guild(guild_id.into()).await } + async fn get_guilds(&self) -> Result, guild::GetGuildError> { + self.repo.get_guilds().await + } + async fn get_guild_users(&self, guild_id: GuildId) -> Result, GetUserError> { self.repo.get_guild_users(guild_id).await } + async fn get_guild_intros( &self, guild_id: GuildId, @@ -81,6 +153,31 @@ where self.repo.get_user_from_api_key(api_key).await } + async fn set_user_intro( + &self, + req: guild::AddIntroToUserRequest, + ) -> Result<(), guild::AddIntroToUserError> { + self.repo.set_user_intro(req).await + } + + async fn refresh_user_token(&self, username: &str) -> Result { + let user = self.get_user(username).await?; + + let user_token = if user.api_key_expires_at() >= Utc::now().naive_utc() { + user.api_key().to_string() + } else { + Uuid::new_v4().to_string() + }; + + let expires_at = Utc::now().naive_utc() + Duration::weeks(4); + + self.repo + .set_user_api_key(username, &user_token, expires_at) + .await?; + + Ok(user_token) + } + async fn create_guild( &self, req: guild::CreateGuildRequest, @@ -92,7 +189,19 @@ where &self, req: guild::CreateUserRequest, ) -> Result { - self.repo.create_user(req).await + let username = req.user.clone(); + + self.repo.create_user(req).await?; + + Ok(self.get_user(username.as_ref()).await?) + } + + async fn add_user_to_guild( + &self, + guild_id: GuildId, + username: &str, + ) -> Result<(), guild::AddUserToGuildError> { + self.repo.add_user_to_guild(guild_id, username).await } async fn create_channel( @@ -123,11 +232,4 @@ where .add_intro_to_guild(&req.name, req.guild_id, file_name) .await } - - async fn set_user_intro( - &self, - req: guild::AddIntroToUserRequest, - ) -> Result<(), guild::AddIntroToUserError> { - self.repo.set_user_intro(req).await - } } diff --git a/src/lib/htmx.rs b/src/lib/htmx.rs index 39a67c1..22c3ce2 100644 --- a/src/lib/htmx.rs +++ b/src/lib/htmx.rs @@ -63,6 +63,7 @@ pub enum Tag { Header6, Strong, Paragraph, + Blockquote, JustText, } @@ -125,6 +126,7 @@ impl Tag { Self::Header6 => "h6", Self::Strong => "strong", Self::Paragraph => "paragraph", + Self::Blockquote => "blockquote", } } diff --git a/src/lib/inbound/http.rs b/src/lib/inbound/http.rs index 1c63c37..3c71c39 100644 --- a/src/lib/inbound/http.rs +++ b/src/lib/inbound/http.rs @@ -1,5 +1,5 @@ mod handlers; -mod page; +pub(super) mod page; use std::{net::SocketAddr, sync::Arc}; @@ -49,7 +49,9 @@ impl FromRequestParts> for User { { Ok(user) => { let now = Utc::now().naive_utc(); - if user.api_key_expires_at() < now || user.discord_token_expires_at() < now { + if user.api_key_expires_at() < now { + //|| user.discord_token_expires_at() < now { + tracing::error!("user token expired at: {}", user.api_key_expires_at()); Err(Redirect::to(&format!("{}/login", state.origin))) } else { Ok(user) @@ -125,6 +127,7 @@ where "/v2/intros/add/:guild_id/:channel", post(handlers::set_user_intro), ) + .route("/v2/auth", get(page::auth)) // .route("/guild/:guild_id/setup", get(routes::guild_setup)) // .route( @@ -135,7 +138,6 @@ where // "/guild/:guild_id/permissions/update", // post(routes::update_guild_permissions), // ) - // .route("/v2/auth", get(routes::v2_auth)) // .route( // "/v2/intros/remove/:guild_id/:channel", // post(routes::v2_remove_intro_from_user), diff --git a/src/lib/inbound/http/page.rs b/src/lib/inbound/http/page.rs index f74c6dc..7ffd7c2 100644 --- a/src/lib/inbound/http/page.rs +++ b/src/lib/inbound/http/page.rs @@ -1,15 +1,25 @@ +use std::collections::HashMap; + use axum::{ - extract::{Path, State}, + extract::{Path, Query, State}, response::{Html, Redirect}, }; +use axum_extra::extract::{CookieJar, cookie::Cookie}; +use reqwest::Url; +use serde::{Deserialize, Deserializer}; use crate::{ + auth, domain::intro_tool::{ models::guild::{ChannelName, GuildRef, Intro, User}, ports::IntroToolService, }, htmx::{Build, HtmxBuilder, Tag}, - inbound::{http::ApiState, response::ErrorAsRedirect}, + inbound::{ + http::ApiState, + response::{ApiError, ErrorAsRedirect, PageError}, + }, + outbound::discord::{DiscordAuthParams, DiscordService}, }; pub async fn home( @@ -22,7 +32,7 @@ pub async fn home( .intro_tool_service .get_user_guilds(user.name()) .await - .as_redirect(&state.origin, "/login")?; + .as_redirect(&state.origin, "login")?; // TODO: get user app permissions // TODO: check if user can add guilds @@ -78,7 +88,10 @@ pub async fn login( if user.is_some() { Err(Redirect::to(&format!("{}/", state.origin))) } else { - let authorize_uri = format!("https://discord.com/api/oauth2/authorize?client_id={}&redirect_uri={}/v2/auth&response_type=code&scope=guilds.members.read+guilds+identify", state.secrets.client_id, state.origin); + let authorize_uri = format!( + "https://discord.com/api/oauth2/authorize?client_id={}&redirect_uri={}/v2/auth&response_type=code&scope=guilds.members.read+guilds+identify", + state.secrets.client_id, state.origin + ); Ok(Html( HtmxBuilder::new(Tag::Html) @@ -111,17 +124,17 @@ pub async fn guild_dashboard( .intro_tool_service .get_guild(guild_id) .await - .as_redirect(&state.origin, "/login")?; + .as_redirect(&state.origin, "login")?; let user_guilds = state .intro_tool_service .get_user_guilds(user.name()) .await - .as_redirect(&state.origin, "/login")?; + .as_redirect(&state.origin, "login")?; let guild_intros = state .intro_tool_service .get_guild_intros(guild_id.into()) .await - .as_redirect(&state.origin, "/login")?; + .as_redirect(&state.origin, "login")?; // does user have access to this guild if !user_guilds @@ -213,7 +226,38 @@ pub async fn guild_dashboard( )) } -fn page_header(title: &str) -> HtmxBuilder { +pub async fn auth( + State(state): State>, + Query(params): Query>, + jar: CookieJar, +) -> Result<(CookieJar, Redirect), PageError> { + let Some(code) = params.get("code") else { + return Err(ApiError::bad_request("no code").into()); + }; + + tracing::info!("attempting to get access token with code {}", code); + + let token = state + .intro_tool_service + // TODO: decoulple discord from HTTP server + .authenticate_user::(DiscordAuthParams { + origin: state.origin.clone(), + code: code.clone(), + client_id: state.secrets.client_id.clone(), + client_secret: state.secrets.client_secret.clone(), + }) + .await + .map_err(ApiError::from)?; + let uri = Url::parse(&state.origin).expect("should be a valid url"); + + let mut cookie = Cookie::new("access_token", token); + cookie.set_path(uri.path().to_string()); + cookie.set_secure(true); + + Ok((jar.add(cookie), Redirect::to(&format!("{}/", state.origin)))) +} + +pub fn page_header(title: &str) -> HtmxBuilder { HtmxBuilder::new(Tag::Html).head(|b| { b.title(title) .script( diff --git a/src/lib/inbound/response.rs b/src/lib/inbound/response.rs index 2f06088..753b182 100644 --- a/src/lib/inbound/response.rs +++ b/src/lib/inbound/response.rs @@ -1,15 +1,23 @@ use std::fmt::Debug; use axum::{ - response::{IntoResponse, Redirect}, Json, + response::{Html, IntoResponse, Redirect}, }; use reqwest::StatusCode; use serde::Serialize; -use crate::domain::intro_tool::models::guild::{ - AddIntroToGuildError, AddIntroToUserError, GetChannelError, GetGuildError, GetIntroError, - GetUserError, +use crate::{ + domain::intro_tool::{ + models::guild::{ + AddIntroToGuildError, AddIntroToUserError, AddUserToGuildError, AutheticateUserError, + CreateUserError, GetChannelError, GetGuildError, GetIntroError, GetUserError, + }, + ports::AuthService, + }, + htmx::{Build, HtmxBuilder, Tag}, + inbound::http::page::page_header, + outbound::discord::DiscordError, }; pub(super) trait ErrorAsRedirect: Sized { @@ -70,6 +78,32 @@ impl ErrorAsRedirect for Result { } } +pub(super) struct PageError(pub ApiError); + +impl IntoResponse for PageError { + fn into_response(self) -> axum::response::Response { + Html( + page_header("MemeJoin - Error") + .builder(Tag::Div, |b| { + b.attribute("class", "container") + .builder_text( + Tag::Header2, + &format!("Uh oh! - Status Code {}", self.0.status_code()), + ) + .builder(Tag::Blockquote, |b| b.text(self.0.message())) + .builder(Tag::Empty, |b| b.text("
")) + .builder(Tag::Anchor, |b| { + b.attribute("role", "button") + .text("Go Back") + .attribute("href", "/") + }) + }) + .build(), + ) + .into_response() + } +} + pub(super) struct ApiResponse(StatusCode, Json); #[derive(Serialize, Debug)] @@ -100,6 +134,15 @@ impl ApiError { } } + fn message(&self) -> &str { + match self { + ApiError::NotFound { message } => message, + ApiError::BadRequest { message } => message, + ApiError::Forbidden { message } => message, + ApiError::InternalServerError { message } => message, + } + } + pub(super) fn not_found(message: impl ToString) -> Self { Self::NotFound { message: message.to_string(), @@ -133,6 +176,12 @@ impl IntoResponse for ApiError { } } +impl From for PageError { + fn from(value: ApiError) -> Self { + Self(value) + } +} + impl From for ApiError { fn from(value: GetGuildError) -> Self { match value { @@ -215,3 +264,75 @@ impl From for ApiError { } } } + +impl From for ApiError { + fn from(value: CreateUserError) -> Self { + match value { + CreateUserError::CouldNotGetUser(err) => err.into(), + CreateUserError::Unknown(error) => { + tracing::error!(err = ?error, "unknown error"); + + Self::internal(error.to_string()) + } + } + } +} + +impl From for ApiError { + fn from(value: AddUserToGuildError) -> Self { + match value { + AddUserToGuildError::Unknown(error) => { + tracing::error!(err = ?error, "unknown error"); + + Self::internal(error.to_string()) + } + } + } +} + +impl From> for ApiError +where + ::Error: Into, +{ + fn from(value: AutheticateUserError) -> Self { + match value { + AutheticateUserError::CouldNotFetchGuild(err) => err.into(), + AutheticateUserError::CouldNotCreateUser(err) => err.into(), + AutheticateUserError::CouldNotFetchUser(err) => err.into(), + AutheticateUserError::CouldNotAddUserToGuild(err) => err.into(), + AutheticateUserError::UserNotPartOfInstanceGuilds => { + Self::internal("User not part of instance guilds") + } + AutheticateUserError::ExternalError(err) => err.into(), + AutheticateUserError::Unknown(err) => { + tracing::error!(err = ?err, "unknown error"); + + Self::internal(err.to_string()) + } + } + } +} + +impl From for ApiError { + fn from(value: DiscordError) -> Self { + match value { + DiscordError::ApiRequest(error) => { + tracing::error!(err = ?error, "api request error"); + + Self::internal(error.to_string()) + } + } + } +} + +impl From for ApiError { + fn from(value: reqwest::Error) -> Self { + Self::internal(format!("error making request to external service: {value}",)) + } +} + +impl From for ApiError { + fn from(value: anyhow::Error) -> Self { + Self::internal(format!("unknown error: {value}",)) + } +} diff --git a/src/lib/outbound/discord.rs b/src/lib/outbound/discord.rs new file mode 100644 index 0000000..9e03fc7 --- /dev/null +++ b/src/lib/outbound/discord.rs @@ -0,0 +1,130 @@ +use std::collections::HashMap; + +use serde::{Deserialize, Deserializer}; +use thiserror::Error; + +use crate::domain::intro_tool::{ + models::guild::{ExternalGuildId, UserName}, + ports::{AuthService, ExternalUser}, +}; + +#[derive(Clone)] +pub struct DiscordService; + +#[derive(Clone)] +pub struct DiscordUser { + token: String, + username: String, + guilds: Vec, +} + +#[derive(Debug, Error)] +pub enum DiscordError { + #[error(transparent)] + ApiRequest(#[from] reqwest::Error), +} + +pub struct DiscordAuthParams { + pub origin: String, + pub code: String, + pub client_id: String, + pub client_secret: String, +} + +#[derive(Deserialize)] +struct DiscordApiAuth { + access_token: String, + token_type: String, + expires_in: usize, + refresh_token: String, + scope: String, +} + +#[derive(Deserialize)] +struct DiscordApiUser { + pub username: String, +} + +#[derive(Deserialize)] +struct DiscordUserGuild { + #[serde(deserialize_with = "serde_string_as_u64")] + id: u64, + name: String, + owner: bool, +} + +impl AuthService for DiscordService { + type Params = DiscordAuthParams; + type User = DiscordUser; + type Error = DiscordError; + + async fn authenticate_user(params: Self::Params) -> Result { + let mut data = HashMap::new(); + + let redirect_uri = format!("{}/v2/auth", params.origin); + data.insert("client_id", params.client_id.as_str()); + data.insert("client_secret", params.client_secret.as_str()); + data.insert("grant_type", "authorization_code"); + data.insert("code", ¶ms.code); + data.insert("redirect_uri", &redirect_uri); + + let client = reqwest::Client::new(); + + let auth: DiscordApiAuth = client + .post("https://discord.com/api/oauth2/token") + .form(&data) + .send() + .await? + .json() + .await?; + + // Get authorized username + let user: DiscordApiUser = client + .get("https://discord.com/api/v10/users/@me") + .bearer_auth(&auth.access_token) + .send() + .await? + .json() + .await?; + + // TODO: get bot's guilds so we only save users who are able to use the bot + let discord_guilds: Vec = client + .get("https://discord.com/api/v10/users/@me/guilds") + .bearer_auth(&auth.access_token) + .send() + .await? + .json() + .await?; + + Ok(Self::User { + token: auth.access_token, + username: user.username, + guilds: discord_guilds.into_iter().map(|guild| guild.id).collect(), + }) + } +} + +impl ExternalUser for DiscordUser { + fn external_token(&self) -> &str { + &self.token + } + + fn username(&self) -> UserName { + self.username.clone().into() + } + + fn guilds(&self) -> impl Iterator { + self.guilds.iter().map(|id| ExternalGuildId(*id)) + } +} + +fn serde_string_as_u64<'de, D>(deserializer: D) -> Result +where + D: Deserializer<'de>, +{ + let value = <&str as Deserialize>::deserialize(deserializer)?; + + value + .parse::() + .map_err(|_| serde::de::Error::invalid_value(serde::de::Unexpected::Str(value), &"u64")) +} diff --git a/src/lib/outbound/mod.rs b/src/lib/outbound/mod.rs index 659980b..b6769fc 100644 --- a/src/lib/outbound/mod.rs +++ b/src/lib/outbound/mod.rs @@ -1,3 +1,4 @@ +pub mod discord; pub mod ffmpeg; pub mod sqlite; pub mod ytdlp; diff --git a/src/lib/outbound/sqlite.rs b/src/lib/outbound/sqlite.rs index 4c0bc75..75d603f 100644 --- a/src/lib/outbound/sqlite.rs +++ b/src/lib/outbound/sqlite.rs @@ -1,3 +1,4 @@ +use chrono::NaiveDateTime; use iter_tools::Itertools; use std::{collections::HashMap, sync::Arc}; use tokio::sync::Mutex; @@ -63,6 +64,39 @@ impl IntroToolRepository for Sqlite { .with_channels(self.get_guild_channels(guild_id).await?)) } + async fn get_guilds(&self) -> Result, GetGuildError> { + let conn = self.conn.lock().await; + + let mut query = conn + .prepare( + " + SELECT + Guild.id, + Guild.name, + Guild.sound_delay + FROM Guild + LEFT JOIN UserGuild ON Guild.id = UserGuild.guild_id + LEFT JOIN User ON User.username = UserGuild.username + ", + ) + .context("failed to prepare query")?; + + let guilds = query + .query_map([], |row| { + Ok(GuildRef::new( + row.get::<_, u64>(0)?.into(), + row.get(1)?, + row.get(2)?, + row.get::<_, u64>(0)?.into(), + )) + }) + .context("failed to map prepared query")? + .collect::>() + .context("failed to fetch guild user rows")?; + + Ok(guilds) + } + async fn get_guild_count(&self) -> Result { let conn = self.conn.lock().await; @@ -117,43 +151,6 @@ impl IntroToolRepository for Sqlite { Ok(users) } - async fn get_user_guilds( - &self, - username: impl AsRef, - ) -> Result, GetGuildError> { - let conn = self.conn.lock().await; - - let mut query = conn - .prepare( - " - SELECT - Guild.id, - Guild.name, - Guild.sound_delay - FROM Guild - LEFT JOIN UserGuild ON Guild.id = UserGuild.guild_id - LEFT JOIN User ON User.username = UserGuild.username - WHERE User.username = :username - ", - ) - .context("failed to prepare query")?; - - let guilds = query - .query_map(&[(":username", username.as_ref())], |row| { - Ok(GuildRef::new( - row.get::<_, u64>(0)?.into(), - row.get(1)?, - row.get(2)?, - row.get::<_, u64>(0)?.into(), - )) - }) - .context("failed to map prepared query")? - .collect::>() - .context("failed to fetch guild user rows")?; - - Ok(guilds) - } - async fn get_guild_channels(&self, guild_id: GuildId) -> Result, GetChannelError> { let conn = self.conn.lock().await; @@ -305,6 +302,43 @@ impl IntroToolRepository for Sqlite { Ok(intros) } + async fn get_user_guilds( + &self, + username: impl AsRef, + ) -> Result, GetGuildError> { + let conn = self.conn.lock().await; + + let mut query = conn + .prepare( + " + SELECT + Guild.id, + Guild.name, + Guild.sound_delay + FROM Guild + LEFT JOIN UserGuild ON Guild.id = UserGuild.guild_id + LEFT JOIN User ON User.username = UserGuild.username + WHERE User.username = :username + ", + ) + .context("failed to prepare query")?; + + let guilds = query + .query_map(&[(":username", username.as_ref())], |row| { + Ok(GuildRef::new( + row.get::<_, u64>(0)?.into(), + row.get(1)?, + row.get(2)?, + row.get::<_, u64>(0)?.into(), + )) + }) + .context("failed to map prepared query")? + .collect::>() + .context("failed to fetch guild user rows")?; + + Ok(guilds) + } + async fn get_user_from_api_key(&self, api_key: &str) -> Result { let username = { let conn = self.conn.lock().await; @@ -330,12 +364,124 @@ impl IntroToolRepository for Sqlite { self.get_user(username).await } - async fn create_guild(&self, req: CreateGuildRequest) -> Result { - todo!() + async fn set_user_api_key( + &self, + username: &str, + api_key: &str, + expires_at: NaiveDateTime, + ) -> Result<(), GetUserError> { + let conn = self.conn.lock().await; + + conn.execute( + " + UPDATE User + SET api_key = ?1, api_key_expires_at = ?2 + WHERE username = ?3 + ", + [api_key, &expires_at.to_string(), username], + ) + .context("failed to update user api key")?; + + Ok(()) } - async fn create_user(&self, req: CreateUserRequest) -> Result { - todo!() + async fn set_user_intro( + &self, + req: AddIntroToUserRequest, + ) -> Result<(), guild::AddIntroToUserError> { + let conn = self.conn.lock().await; + + conn.execute( + " + DELETE FROM UserIntro + WHERE username = ?1 + AND guild_id = ?2 + AND channel_name = ?3 + ", + [ + &req.user.to_string(), + &req.guild_id.to_string(), + &req.channel_name.to_string(), + ], + ) + .context("failed to delete user intros")?; + + conn.execute( + " + INSERT INTO + UserIntro (username, guild_id, channel_name, intro_id) + VALUES (?1, ?2, ?3, ?4)", + [ + &req.user.to_string(), + &req.guild_id.to_string(), + &req.channel_name.to_string(), + &req.intro_id.to_string(), + ], + ) + .context("failed to insert user intro")?; + + Ok(()) + } + + async fn create_guild(&self, req: CreateGuildRequest) -> Result { + let conn = self.conn.lock().await; + + let guild_id: GuildId = req.external_id.0.into(); + + conn.execute( + " + INSERT INTO + Guild (id, name, sound_delay) + VALUES (?1, ?2, ?3) + ", + [ + &guild_id.to_string(), + &req.name, + &req.sound_delay.to_string(), + ], + ) + .context("failed to insert guild")?; + + Ok(Guild::new( + guild_id, + req.name, + req.sound_delay, + req.external_id, + )) + } + + async fn create_user(&self, req: CreateUserRequest) -> Result<(), CreateUserError> { + let conn = self.conn.lock().await; + + conn.execute( + " + INSERT INTO + User (username) + VALUES (?1) + ", + [req.user.as_ref()], + ) + .context("failed to insert user")?; + + Ok(()) + } + + async fn add_user_to_guild( + &self, + guild_id: GuildId, + username: &str, + ) -> Result<(), guild::AddUserToGuildError> { + let conn = self.conn.lock().await; + + conn.execute( + " + INSERT OR IGNORE INTO UserGuild (username, guild_id) VALUES (?1, ?2) + ", + [username, &guild_id.to_string()], + ) + .context("failed to insert user guild")?; + + Ok(()) } async fn create_channel( @@ -389,42 +535,4 @@ impl IntroToolRepository for Sqlite { Ok(intro_id) } - - async fn set_user_intro( - &self, - req: AddIntroToUserRequest, - ) -> Result<(), guild::AddIntroToUserError> { - let conn = self.conn.lock().await; - - conn.execute( - " - DELETE FROM UserIntro - WHERE username = ?1 - AND guild_id = ?2 - AND channel_name = ?3 - ", - [ - &req.user.to_string(), - &req.guild_id.to_string(), - &req.channel_name.to_string(), - ], - ) - .context("failed to delete user intros")?; - - conn.execute( - " - INSERT INTO - UserIntro (username, guild_id, channel_name, intro_id) - VALUES (?1, ?2, ?3, ?4)", - [ - &req.user.to_string(), - &req.guild_id.to_string(), - &req.channel_name.to_string(), - &req.intro_id.to_string(), - ], - ) - .context("failed to insert user intro")?; - - Ok(()) - } }