discord auth

hexagon
Patrick Cleavelin 2025-10-18 16:16:20 -05:00
parent c9a91d3d36
commit 569d87aec4
13 changed files with 797 additions and 126 deletions

View File

@ -1,7 +1,7 @@
[package] [package]
name = "memejoin-rs" name = "memejoin-rs"
version = "0.2.2-alpha" version = "0.2.2-alpha"
edition = "2021" edition = "2024"
[[bin]] [[bin]]
name = "memejoin-rs" name = "memejoin-rs"

View File

@ -1,3 +1,5 @@
use serde::{Deserialize, Serialize};
#[derive(Clone)] #[derive(Clone)]
pub struct DiscordSecret { pub struct DiscordSecret {
pub client_id: String, pub client_id: String,
@ -5,6 +7,15 @@ pub struct DiscordSecret {
pub bot_token: String, pub bot_token: String,
} }
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct Discord {
pub(crate) access_token: String,
pub(crate) token_type: String,
pub(crate) expires_in: usize,
pub(crate) refresh_token: String,
pub(crate) scope: String,
}
/* /*
use std::str::FromStr; use std::str::FromStr;

View File

@ -5,6 +5,8 @@ use crate::domain::intro_tool::{
ports::{IntroToolRepository, IntroToolService}, ports::{IntroToolRepository, IntroToolService},
}; };
use super::ports::AuthService;
#[derive(Clone)] #[derive(Clone)]
pub struct DebugService<S> pub struct DebugService<S>
where where
@ -34,6 +36,13 @@ where
self.wrapped_service.needs_setup().await self.wrapped_service.needs_setup().await
} }
async fn authenticate_user<A: AuthService>(
&self,
params: A::Params,
) -> Result<models::guild::ApiToken, models::guild::AutheticateUserError<A>> {
self.wrapped_service.authenticate_user(params).await
}
async fn get_guild( async fn get_guild(
&self, &self,
guild_id: impl Into<models::guild::GuildId> + Send, guild_id: impl Into<models::guild::GuildId> + Send,
@ -41,6 +50,12 @@ where
self.wrapped_service.get_guild(guild_id).await self.wrapped_service.get_guild(guild_id).await
} }
async fn get_guilds(
&self,
) -> Result<Vec<models::guild::GuildRef>, models::guild::GetGuildError> {
self.wrapped_service.get_guilds().await
}
async fn get_guild_users( async fn get_guild_users(
&self, &self,
guild_id: models::guild::GuildId, guild_id: models::guild::GuildId,
@ -88,6 +103,20 @@ where
.with_channel_intros(user.intros().clone())) .with_channel_intros(user.intros().clone()))
} }
async fn set_user_intro(
&self,
req: models::guild::AddIntroToUserRequest,
) -> Result<(), models::guild::AddIntroToUserError> {
self.wrapped_service.set_user_intro(req).await
}
async fn refresh_user_token(
&self,
username: &str,
) -> Result<String, models::guild::GetUserError> {
self.wrapped_service.refresh_user_token(username).await
}
async fn create_guild( async fn create_guild(
&self, &self,
req: models::guild::CreateGuildRequest, req: models::guild::CreateGuildRequest,
@ -102,6 +131,16 @@ where
self.wrapped_service.create_user(req).await self.wrapped_service.create_user(req).await
} }
async fn add_user_to_guild(
&self,
guild_id: models::guild::GuildId,
username: &str,
) -> Result<(), models::guild::AddUserToGuildError> {
self.wrapped_service
.add_user_to_guild(guild_id, username)
.await
}
async fn create_channel( async fn create_channel(
&self, &self,
req: models::guild::CreateChannelRequest, req: models::guild::CreateChannelRequest,
@ -115,11 +154,4 @@ where
) -> Result<IntroId, models::guild::AddIntroToGuildError> { ) -> Result<IntroId, models::guild::AddIntroToGuildError> {
self.wrapped_service.add_intro_to_guild(req).await self.wrapped_service.add_intro_to_guild(req).await
} }
async fn set_user_intro(
&self,
req: models::guild::AddIntroToUserRequest,
) -> Result<(), models::guild::AddIntroToUserError> {
self.wrapped_service.set_user_intro(req).await
}
} }

View File

@ -1,13 +1,18 @@
use std::collections::HashMap; use std::{borrow::Cow, collections::HashMap};
use chrono::NaiveDateTime; use chrono::NaiveDateTime;
use thiserror::Error; use thiserror::Error;
use crate::domain::intro_tool::ports::AuthService;
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ApiToken(String);
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct GuildId(u64); pub struct GuildId(u64);
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ExternalGuildId(u64); pub struct ExternalGuildId(pub u64);
#[derive(Debug, Clone, PartialEq, Eq, Hash)] #[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct UserName(String); pub struct UserName(String);
@ -18,6 +23,18 @@ pub struct ChannelName(String);
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct IntroId(i32); pub struct IntroId(i32);
impl From<ApiToken> for Cow<'_, str> {
fn from(value: ApiToken) -> Self {
Cow::Owned(value.0)
}
}
impl From<String> for ApiToken {
fn from(value: String) -> Self {
Self(value)
}
}
impl From<u64> for GuildId { impl From<u64> for GuildId {
fn from(id: u64) -> Self { fn from(id: u64) -> Self {
Self(id) Self(id)
@ -108,6 +125,10 @@ impl GuildRef {
pub fn name(&self) -> &str { pub fn name(&self) -> &str {
&self.name &self.name
} }
pub fn external_id(&self) -> ExternalGuildId {
self.external_id
}
} }
impl GuildRef { impl GuildRef {
@ -198,6 +219,10 @@ impl User {
&self.channel_intros &self.channel_intros
} }
pub fn api_key(&self) -> &str {
&self.api_key
}
pub fn api_key_expires_at(&self) -> NaiveDateTime { pub fn api_key_expires_at(&self) -> NaiveDateTime {
self.api_key_expires_at self.api_key_expires_at
} }
@ -255,18 +280,18 @@ impl Intro {
} }
pub struct CreateGuildRequest { pub struct CreateGuildRequest {
name: String, pub name: String,
sound_delay: u32, pub sound_delay: u32,
external_id: ExternalGuildId, pub external_id: ExternalGuildId,
} }
pub struct CreateUserRequest { pub struct CreateUserRequest {
user: UserName, pub user: UserName,
} }
pub struct CreateChannelRequest { pub struct CreateChannelRequest {
guild_id: GuildId, pub guild_id: GuildId,
channel_name: ChannelName, pub channel_name: ChannelName,
} }
pub struct AddIntroToGuildRequest { pub struct AddIntroToGuildRequest {
@ -297,6 +322,9 @@ pub enum CreateGuildError {
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum CreateUserError { pub enum CreateUserError {
#[error("Could not get user")]
CouldNotGetUser(#[from] GetUserError),
#[error(transparent)] #[error(transparent)]
Unknown(#[from] anyhow::Error), Unknown(#[from] anyhow::Error),
} }
@ -307,6 +335,12 @@ pub enum CreateChannelError {
Unknown(#[from] anyhow::Error), Unknown(#[from] anyhow::Error),
} }
#[derive(Debug, Error)]
pub enum AddUserToGuildError {
#[error(transparent)]
Unknown(#[from] anyhow::Error),
}
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum AddIntroToGuildError { pub enum AddIntroToGuildError {
#[error(transparent)] #[error(transparent)]
@ -366,3 +400,27 @@ pub enum GetIntroError {
#[error(transparent)] #[error(transparent)]
Unknown(#[from] anyhow::Error), Unknown(#[from] anyhow::Error),
} }
#[derive(Debug, Error)]
pub enum AutheticateUserError<A: AuthService> {
#[error("Could not fetch guild")]
CouldNotFetchGuild(#[from] GetGuildError),
#[error("Could not create user")]
CouldNotCreateUser(#[from] CreateUserError),
#[error("Could not fetch guild user")]
CouldNotFetchUser(#[from] GetUserError),
#[error("Could not add user to guild")]
CouldNotAddUserToGuild(#[from] AddUserToGuildError),
#[error("User not part of instance's guilds")]
UserNotPartOfInstanceGuilds,
#[error("Error authenticating user")]
ExternalError(A::Error),
#[error(transparent)]
Unknown(#[from] anyhow::Error),
}

View File

@ -1,21 +1,31 @@
use std::{collections::HashMap, future::Future}; use std::{collections::HashMap, future::Future};
use crate::domain::intro_tool::models::guild::{ChannelName, IntroId}; use chrono::NaiveDateTime;
use crate::domain::intro_tool::models::guild::{
AddUserToGuildError, AutheticateUserError, ExternalGuildId, UserName,
};
use super::models::guild::{ use super::models::guild::{
AddIntroToGuildError, AddIntroToGuildRequest, AddIntroToUserError, AddIntroToUserRequest, AddIntroToGuildError, AddIntroToGuildRequest, AddIntroToUserError, AddIntroToUserRequest,
Channel, CreateChannelError, CreateChannelRequest, CreateGuildError, CreateGuildRequest, ApiToken, Channel, ChannelName, CreateChannelError, CreateChannelRequest, CreateGuildError,
CreateUserError, CreateUserRequest, GetChannelError, GetGuildError, GetIntroError, CreateGuildRequest, CreateUserError, CreateUserRequest, GetChannelError, GetGuildError,
GetUserError, Guild, GuildId, GuildRef, Intro, User, GetIntroError, GetUserError, Guild, GuildId, GuildRef, Intro, IntroId, User,
}; };
pub trait IntroToolService: Send + Sync + Clone + 'static { pub trait IntroToolService: Send + Sync + Clone + 'static {
fn needs_setup(&self) -> impl Future<Output = bool> + Send; fn needs_setup(&self) -> impl Future<Output = bool> + Send;
fn authenticate_user<A: AuthService>(
&self,
params: A::Params,
) -> impl Future<Output = Result<ApiToken, AutheticateUserError<A>>> + Send;
fn get_guild( fn get_guild(
&self, &self,
guild_id: impl Into<GuildId> + Send, guild_id: impl Into<GuildId> + Send,
) -> impl Future<Output = Result<Guild, GetGuildError>> + Send; ) -> impl Future<Output = Result<Guild, GetGuildError>> + Send;
fn get_guilds(&self) -> impl Future<Output = Result<Vec<GuildRef>, GetGuildError>> + Send;
fn get_guild_users( fn get_guild_users(
&self, &self,
guild_id: GuildId, guild_id: GuildId,
@ -42,8 +52,24 @@ pub trait IntroToolService: Send + Sync + Clone + 'static {
req: AddIntroToUserRequest, req: AddIntroToUserRequest,
) -> impl Future<Output = Result<(), AddIntroToUserError>> + Send; ) -> impl Future<Output = Result<(), AddIntroToUserError>> + Send;
fn refresh_user_token(
&self,
username: &str,
) -> impl Future<Output = Result<String, GetUserError>> + Send;
async fn create_guild(&self, req: CreateGuildRequest) -> Result<Guild, CreateGuildError>; async fn create_guild(&self, req: CreateGuildRequest) -> Result<Guild, CreateGuildError>;
async fn create_user(&self, req: CreateUserRequest) -> Result<User, CreateUserError>;
fn create_user(
&self,
req: CreateUserRequest,
) -> impl Future<Output = Result<User, CreateUserError>> + Send;
fn add_user_to_guild(
&self,
guild_id: GuildId,
username: &str,
) -> impl Future<Output = Result<(), AddUserToGuildError>> + Send;
async fn create_channel( async fn create_channel(
&self, &self,
req: CreateChannelRequest, req: CreateChannelRequest,
@ -60,6 +86,7 @@ pub trait IntroToolRepository: Send + Sync + Clone + 'static {
&self, &self,
guild_id: GuildId, guild_id: GuildId,
) -> impl Future<Output = Result<Guild, GetGuildError>> + Send; ) -> impl Future<Output = Result<Guild, GetGuildError>> + Send;
fn get_guilds(&self) -> impl Future<Output = Result<Vec<GuildRef>, GetGuildError>> + Send;
fn get_guild_count(&self) -> impl Future<Output = Result<usize, GetGuildError>> + Send; fn get_guild_count(&self) -> impl Future<Output = Result<usize, GetGuildError>> + Send;
fn get_guild_users( fn get_guild_users(
@ -97,13 +124,31 @@ pub trait IntroToolRepository: Send + Sync + Clone + 'static {
api_key: &str, api_key: &str,
) -> impl Future<Output = Result<User, GetUserError>> + Send; ) -> impl Future<Output = Result<User, GetUserError>> + Send;
fn set_user_api_key(
&self,
username: &str,
api_key: &str,
expires_at: NaiveDateTime,
) -> impl Future<Output = Result<(), GetUserError>> + Send;
fn set_user_intro( fn set_user_intro(
&self, &self,
req: AddIntroToUserRequest, req: AddIntroToUserRequest,
) -> impl Future<Output = Result<(), AddIntroToUserError>> + Send; ) -> impl Future<Output = Result<(), AddIntroToUserError>> + Send;
async fn create_guild(&self, req: CreateGuildRequest) -> Result<Guild, CreateGuildError>; async fn create_guild(&self, req: CreateGuildRequest) -> Result<Guild, CreateGuildError>;
async fn create_user(&self, req: CreateUserRequest) -> Result<User, CreateUserError>;
fn create_user(
&self,
req: CreateUserRequest,
) -> impl Future<Output = Result<(), CreateUserError>> + Send;
fn add_user_to_guild(
&self,
guild_id: GuildId,
username: &str,
) -> impl Future<Output = Result<(), AddUserToGuildError>> + Send;
async fn create_channel( async fn create_channel(
&self, &self,
req: CreateChannelRequest, req: CreateChannelRequest,
@ -117,6 +162,21 @@ pub trait IntroToolRepository: Send + Sync + Clone + 'static {
) -> impl Future<Output = Result<IntroId, AddIntroToGuildError>> + Send; ) -> impl Future<Output = Result<IntroId, AddIntroToGuildError>> + Send;
} }
pub trait ExternalUser: Send + Sync + Clone + 'static {
fn external_token(&self) -> &str;
fn username(&self) -> UserName;
fn guilds(&self) -> impl Iterator<Item = ExternalGuildId>;
}
pub trait AuthService: Send + Sync + Clone + 'static {
type Params: Send;
type User: ExternalUser + Send;
type Error: std::error::Error + Send;
fn authenticate_user(
params: Self::Params,
) -> impl Future<Output = Result<Self::User, Self::Error>> + Send;
}
pub trait RemoteAudioFetcher: Send + Sync + Clone + 'static { pub trait RemoteAudioFetcher: Send + Sync + Clone + 'static {
fn fetch_remote_audio( fn fetch_remote_audio(
&self, &self,

View File

@ -1,10 +1,19 @@
use chrono::{Duration, Utc};
use iter_tools::Itertools;
use uuid::Uuid; use uuid::Uuid;
use crate::domain::intro_tool::{ use crate::domain::intro_tool::{
models::guild::{self, GetUserError, GuildId, IntroId, User}, models::guild::{
ports::{IntroToolRepository, IntroToolService, LocalAudioFetcher, RemoteAudioFetcher}, self, ApiToken, AutheticateUserError, CreateUserRequest, GetUserError, GuildId, IntroId,
User,
},
ports::{
ExternalUser, IntroToolRepository, IntroToolService, LocalAudioFetcher, RemoteAudioFetcher,
},
}; };
use super::ports::AuthService;
#[derive(Clone)] #[derive(Clone)]
pub struct Service<R, RA, LA> pub struct Service<R, RA, LA>
where where
@ -46,6 +55,64 @@ where
guild_count == 0 guild_count == 0
} }
async fn authenticate_user<A: AuthService>(
&self,
params: A::Params,
) -> Result<ApiToken, AutheticateUserError<A>> {
let external_user = A::authenticate_user(params)
.await
.map_err(AutheticateUserError::ExternalError)?;
let guilds = self.get_guilds().await?;
let external_user_guilds = guilds
.iter()
.filter(|guild| external_user.guilds().contains(&guild.external_id()))
.collect::<Vec<_>>();
if external_user_guilds.is_empty() {
return Err(AutheticateUserError::UserNotPartOfInstanceGuilds);
}
let user = match self.get_user(external_user.username()).await {
Ok(user) => Some(user),
Err(GetUserError::NotFound) => None,
Err(err) => return Err(AutheticateUserError::CouldNotFetchUser(err)),
};
match user {
Some(user) => {
self.refresh_user_token(user.name()).await?;
}
None => {
self.create_user(CreateUserRequest {
user: external_user.username().clone(),
})
.await?;
}
}
let user = self.get_user(external_user.username()).await?;
let user_guilds = self.get_user_guilds(user.name()).await?;
let guilds_to_add_user =
user_guilds
.iter()
.map(|guild| guild.id())
.filter(|user_guild_id| {
external_user_guilds
.iter()
.map(|external_guild| external_guild.id())
.contains(user_guild_id)
});
for guild in guilds_to_add_user {
self.add_user_to_guild(guild, user.name()).await?;
}
Ok(user.api_key().to_string().into())
}
async fn get_guild( async fn get_guild(
&self, &self,
guild_id: impl Into<GuildId>, guild_id: impl Into<GuildId>,
@ -53,9 +120,14 @@ where
self.repo.get_guild(guild_id.into()).await self.repo.get_guild(guild_id.into()).await
} }
async fn get_guilds(&self) -> Result<Vec<guild::GuildRef>, guild::GetGuildError> {
self.repo.get_guilds().await
}
async fn get_guild_users(&self, guild_id: GuildId) -> Result<Vec<User>, GetUserError> { async fn get_guild_users(&self, guild_id: GuildId) -> Result<Vec<User>, GetUserError> {
self.repo.get_guild_users(guild_id).await self.repo.get_guild_users(guild_id).await
} }
async fn get_guild_intros( async fn get_guild_intros(
&self, &self,
guild_id: GuildId, guild_id: GuildId,
@ -81,6 +153,31 @@ where
self.repo.get_user_from_api_key(api_key).await self.repo.get_user_from_api_key(api_key).await
} }
async fn set_user_intro(
&self,
req: guild::AddIntroToUserRequest,
) -> Result<(), guild::AddIntroToUserError> {
self.repo.set_user_intro(req).await
}
async fn refresh_user_token(&self, username: &str) -> Result<String, GetUserError> {
let user = self.get_user(username).await?;
let user_token = if user.api_key_expires_at() >= Utc::now().naive_utc() {
user.api_key().to_string()
} else {
Uuid::new_v4().to_string()
};
let expires_at = Utc::now().naive_utc() + Duration::weeks(4);
self.repo
.set_user_api_key(username, &user_token, expires_at)
.await?;
Ok(user_token)
}
async fn create_guild( async fn create_guild(
&self, &self,
req: guild::CreateGuildRequest, req: guild::CreateGuildRequest,
@ -92,7 +189,19 @@ where
&self, &self,
req: guild::CreateUserRequest, req: guild::CreateUserRequest,
) -> Result<guild::User, guild::CreateUserError> { ) -> Result<guild::User, guild::CreateUserError> {
self.repo.create_user(req).await let username = req.user.clone();
self.repo.create_user(req).await?;
Ok(self.get_user(username.as_ref()).await?)
}
async fn add_user_to_guild(
&self,
guild_id: GuildId,
username: &str,
) -> Result<(), guild::AddUserToGuildError> {
self.repo.add_user_to_guild(guild_id, username).await
} }
async fn create_channel( async fn create_channel(
@ -123,11 +232,4 @@ where
.add_intro_to_guild(&req.name, req.guild_id, file_name) .add_intro_to_guild(&req.name, req.guild_id, file_name)
.await .await
} }
async fn set_user_intro(
&self,
req: guild::AddIntroToUserRequest,
) -> Result<(), guild::AddIntroToUserError> {
self.repo.set_user_intro(req).await
}
} }

View File

@ -63,6 +63,7 @@ pub enum Tag {
Header6, Header6,
Strong, Strong,
Paragraph, Paragraph,
Blockquote,
JustText, JustText,
} }
@ -125,6 +126,7 @@ impl Tag {
Self::Header6 => "h6", Self::Header6 => "h6",
Self::Strong => "strong", Self::Strong => "strong",
Self::Paragraph => "paragraph", Self::Paragraph => "paragraph",
Self::Blockquote => "blockquote",
} }
} }

View File

@ -1,5 +1,5 @@
mod handlers; mod handlers;
mod page; pub(super) mod page;
use std::{net::SocketAddr, sync::Arc}; use std::{net::SocketAddr, sync::Arc};
@ -49,7 +49,9 @@ impl<S: IntroToolService> FromRequestParts<ApiState<S>> for User {
{ {
Ok(user) => { Ok(user) => {
let now = Utc::now().naive_utc(); let now = Utc::now().naive_utc();
if user.api_key_expires_at() < now || user.discord_token_expires_at() < now { if user.api_key_expires_at() < now {
//|| user.discord_token_expires_at() < now {
tracing::error!("user token expired at: {}", user.api_key_expires_at());
Err(Redirect::to(&format!("{}/login", state.origin))) Err(Redirect::to(&format!("{}/login", state.origin)))
} else { } else {
Ok(user) Ok(user)
@ -125,6 +127,7 @@ where
"/v2/intros/add/:guild_id/:channel", "/v2/intros/add/:guild_id/:channel",
post(handlers::set_user_intro), post(handlers::set_user_intro),
) )
.route("/v2/auth", get(page::auth))
// .route("/guild/:guild_id/setup", get(routes::guild_setup)) // .route("/guild/:guild_id/setup", get(routes::guild_setup))
// .route( // .route(
@ -135,7 +138,6 @@ where
// "/guild/:guild_id/permissions/update", // "/guild/:guild_id/permissions/update",
// post(routes::update_guild_permissions), // post(routes::update_guild_permissions),
// ) // )
// .route("/v2/auth", get(routes::v2_auth))
// .route( // .route(
// "/v2/intros/remove/:guild_id/:channel", // "/v2/intros/remove/:guild_id/:channel",
// post(routes::v2_remove_intro_from_user), // post(routes::v2_remove_intro_from_user),

View File

@ -1,15 +1,25 @@
use std::collections::HashMap;
use axum::{ use axum::{
extract::{Path, State}, extract::{Path, Query, State},
response::{Html, Redirect}, response::{Html, Redirect},
}; };
use axum_extra::extract::{CookieJar, cookie::Cookie};
use reqwest::Url;
use serde::{Deserialize, Deserializer};
use crate::{ use crate::{
auth,
domain::intro_tool::{ domain::intro_tool::{
models::guild::{ChannelName, GuildRef, Intro, User}, models::guild::{ChannelName, GuildRef, Intro, User},
ports::IntroToolService, ports::IntroToolService,
}, },
htmx::{Build, HtmxBuilder, Tag}, htmx::{Build, HtmxBuilder, Tag},
inbound::{http::ApiState, response::ErrorAsRedirect}, inbound::{
http::ApiState,
response::{ApiError, ErrorAsRedirect, PageError},
},
outbound::discord::{DiscordAuthParams, DiscordService},
}; };
pub async fn home<S: IntroToolService>( pub async fn home<S: IntroToolService>(
@ -22,7 +32,7 @@ pub async fn home<S: IntroToolService>(
.intro_tool_service .intro_tool_service
.get_user_guilds(user.name()) .get_user_guilds(user.name())
.await .await
.as_redirect(&state.origin, "/login")?; .as_redirect(&state.origin, "login")?;
// TODO: get user app permissions // TODO: get user app permissions
// TODO: check if user can add guilds // TODO: check if user can add guilds
@ -78,7 +88,10 @@ pub async fn login<S: IntroToolService>(
if user.is_some() { if user.is_some() {
Err(Redirect::to(&format!("{}/", state.origin))) Err(Redirect::to(&format!("{}/", state.origin)))
} else { } else {
let authorize_uri = format!("https://discord.com/api/oauth2/authorize?client_id={}&redirect_uri={}/v2/auth&response_type=code&scope=guilds.members.read+guilds+identify", state.secrets.client_id, state.origin); let authorize_uri = format!(
"https://discord.com/api/oauth2/authorize?client_id={}&redirect_uri={}/v2/auth&response_type=code&scope=guilds.members.read+guilds+identify",
state.secrets.client_id, state.origin
);
Ok(Html( Ok(Html(
HtmxBuilder::new(Tag::Html) HtmxBuilder::new(Tag::Html)
@ -111,17 +124,17 @@ pub async fn guild_dashboard<S: IntroToolService>(
.intro_tool_service .intro_tool_service
.get_guild(guild_id) .get_guild(guild_id)
.await .await
.as_redirect(&state.origin, "/login")?; .as_redirect(&state.origin, "login")?;
let user_guilds = state let user_guilds = state
.intro_tool_service .intro_tool_service
.get_user_guilds(user.name()) .get_user_guilds(user.name())
.await .await
.as_redirect(&state.origin, "/login")?; .as_redirect(&state.origin, "login")?;
let guild_intros = state let guild_intros = state
.intro_tool_service .intro_tool_service
.get_guild_intros(guild_id.into()) .get_guild_intros(guild_id.into())
.await .await
.as_redirect(&state.origin, "/login")?; .as_redirect(&state.origin, "login")?;
// does user have access to this guild // does user have access to this guild
if !user_guilds if !user_guilds
@ -213,7 +226,38 @@ pub async fn guild_dashboard<S: IntroToolService>(
)) ))
} }
fn page_header(title: &str) -> HtmxBuilder { pub async fn auth<S: IntroToolService>(
State(state): State<ApiState<S>>,
Query(params): Query<HashMap<String, String>>,
jar: CookieJar,
) -> Result<(CookieJar, Redirect), PageError> {
let Some(code) = params.get("code") else {
return Err(ApiError::bad_request("no code").into());
};
tracing::info!("attempting to get access token with code {}", code);
let token = state
.intro_tool_service
// TODO: decoulple discord from HTTP server
.authenticate_user::<DiscordService>(DiscordAuthParams {
origin: state.origin.clone(),
code: code.clone(),
client_id: state.secrets.client_id.clone(),
client_secret: state.secrets.client_secret.clone(),
})
.await
.map_err(ApiError::from)?;
let uri = Url::parse(&state.origin).expect("should be a valid url");
let mut cookie = Cookie::new("access_token", token);
cookie.set_path(uri.path().to_string());
cookie.set_secure(true);
Ok((jar.add(cookie), Redirect::to(&format!("{}/", state.origin))))
}
pub fn page_header(title: &str) -> HtmxBuilder {
HtmxBuilder::new(Tag::Html).head(|b| { HtmxBuilder::new(Tag::Html).head(|b| {
b.title(title) b.title(title)
.script( .script(

View File

@ -1,15 +1,23 @@
use std::fmt::Debug; use std::fmt::Debug;
use axum::{ use axum::{
response::{IntoResponse, Redirect},
Json, Json,
response::{Html, IntoResponse, Redirect},
}; };
use reqwest::StatusCode; use reqwest::StatusCode;
use serde::Serialize; use serde::Serialize;
use crate::domain::intro_tool::models::guild::{ use crate::{
AddIntroToGuildError, AddIntroToUserError, GetChannelError, GetGuildError, GetIntroError, domain::intro_tool::{
GetUserError, models::guild::{
AddIntroToGuildError, AddIntroToUserError, AddUserToGuildError, AutheticateUserError,
CreateUserError, GetChannelError, GetGuildError, GetIntroError, GetUserError,
},
ports::AuthService,
},
htmx::{Build, HtmxBuilder, Tag},
inbound::http::page::page_header,
outbound::discord::DiscordError,
}; };
pub(super) trait ErrorAsRedirect<T>: Sized { pub(super) trait ErrorAsRedirect<T>: Sized {
@ -70,6 +78,32 @@ impl<T: Debug> ErrorAsRedirect<T> for Result<T, GetIntroError> {
} }
} }
pub(super) struct PageError(pub ApiError);
impl IntoResponse for PageError {
fn into_response(self) -> axum::response::Response {
Html(
page_header("MemeJoin - Error")
.builder(Tag::Div, |b| {
b.attribute("class", "container")
.builder_text(
Tag::Header2,
&format!("Uh oh! - Status Code {}", self.0.status_code()),
)
.builder(Tag::Blockquote, |b| b.text(self.0.message()))
.builder(Tag::Empty, |b| b.text("<br/>"))
.builder(Tag::Anchor, |b| {
b.attribute("role", "button")
.text("Go Back")
.attribute("href", "/")
})
})
.build(),
)
.into_response()
}
}
pub(super) struct ApiResponse<T: Serialize>(StatusCode, Json<T>); pub(super) struct ApiResponse<T: Serialize>(StatusCode, Json<T>);
#[derive(Serialize, Debug)] #[derive(Serialize, Debug)]
@ -100,6 +134,15 @@ impl ApiError {
} }
} }
fn message(&self) -> &str {
match self {
ApiError::NotFound { message } => message,
ApiError::BadRequest { message } => message,
ApiError::Forbidden { message } => message,
ApiError::InternalServerError { message } => message,
}
}
pub(super) fn not_found(message: impl ToString) -> Self { pub(super) fn not_found(message: impl ToString) -> Self {
Self::NotFound { Self::NotFound {
message: message.to_string(), message: message.to_string(),
@ -133,6 +176,12 @@ impl IntoResponse for ApiError {
} }
} }
impl From<ApiError> for PageError {
fn from(value: ApiError) -> Self {
Self(value)
}
}
impl From<GetGuildError> for ApiError { impl From<GetGuildError> for ApiError {
fn from(value: GetGuildError) -> Self { fn from(value: GetGuildError) -> Self {
match value { match value {
@ -215,3 +264,75 @@ impl From<GetIntroError> for ApiError {
} }
} }
} }
impl From<CreateUserError> for ApiError {
fn from(value: CreateUserError) -> Self {
match value {
CreateUserError::CouldNotGetUser(err) => err.into(),
CreateUserError::Unknown(error) => {
tracing::error!(err = ?error, "unknown error");
Self::internal(error.to_string())
}
}
}
}
impl From<AddUserToGuildError> for ApiError {
fn from(value: AddUserToGuildError) -> Self {
match value {
AddUserToGuildError::Unknown(error) => {
tracing::error!(err = ?error, "unknown error");
Self::internal(error.to_string())
}
}
}
}
impl<A: AuthService> From<AutheticateUserError<A>> for ApiError
where
<A as AuthService>::Error: Into<ApiError>,
{
fn from(value: AutheticateUserError<A>) -> Self {
match value {
AutheticateUserError::CouldNotFetchGuild(err) => err.into(),
AutheticateUserError::CouldNotCreateUser(err) => err.into(),
AutheticateUserError::CouldNotFetchUser(err) => err.into(),
AutheticateUserError::CouldNotAddUserToGuild(err) => err.into(),
AutheticateUserError::UserNotPartOfInstanceGuilds => {
Self::internal("User not part of instance guilds")
}
AutheticateUserError::ExternalError(err) => err.into(),
AutheticateUserError::Unknown(err) => {
tracing::error!(err = ?err, "unknown error");
Self::internal(err.to_string())
}
}
}
}
impl From<DiscordError> for ApiError {
fn from(value: DiscordError) -> Self {
match value {
DiscordError::ApiRequest(error) => {
tracing::error!(err = ?error, "api request error");
Self::internal(error.to_string())
}
}
}
}
impl From<reqwest::Error> for ApiError {
fn from(value: reqwest::Error) -> Self {
Self::internal(format!("error making request to external service: {value}",))
}
}
impl From<anyhow::Error> for ApiError {
fn from(value: anyhow::Error) -> Self {
Self::internal(format!("unknown error: {value}",))
}
}

130
src/lib/outbound/discord.rs Normal file
View File

@ -0,0 +1,130 @@
use std::collections::HashMap;
use serde::{Deserialize, Deserializer};
use thiserror::Error;
use crate::domain::intro_tool::{
models::guild::{ExternalGuildId, UserName},
ports::{AuthService, ExternalUser},
};
#[derive(Clone)]
pub struct DiscordService;
#[derive(Clone)]
pub struct DiscordUser {
token: String,
username: String,
guilds: Vec<u64>,
}
#[derive(Debug, Error)]
pub enum DiscordError {
#[error(transparent)]
ApiRequest(#[from] reqwest::Error),
}
pub struct DiscordAuthParams {
pub origin: String,
pub code: String,
pub client_id: String,
pub client_secret: String,
}
#[derive(Deserialize)]
struct DiscordApiAuth {
access_token: String,
token_type: String,
expires_in: usize,
refresh_token: String,
scope: String,
}
#[derive(Deserialize)]
struct DiscordApiUser {
pub username: String,
}
#[derive(Deserialize)]
struct DiscordUserGuild {
#[serde(deserialize_with = "serde_string_as_u64")]
id: u64,
name: String,
owner: bool,
}
impl AuthService for DiscordService {
type Params = DiscordAuthParams;
type User = DiscordUser;
type Error = DiscordError;
async fn authenticate_user(params: Self::Params) -> Result<Self::User, Self::Error> {
let mut data = HashMap::new();
let redirect_uri = format!("{}/v2/auth", params.origin);
data.insert("client_id", params.client_id.as_str());
data.insert("client_secret", params.client_secret.as_str());
data.insert("grant_type", "authorization_code");
data.insert("code", &params.code);
data.insert("redirect_uri", &redirect_uri);
let client = reqwest::Client::new();
let auth: DiscordApiAuth = client
.post("https://discord.com/api/oauth2/token")
.form(&data)
.send()
.await?
.json()
.await?;
// Get authorized username
let user: DiscordApiUser = client
.get("https://discord.com/api/v10/users/@me")
.bearer_auth(&auth.access_token)
.send()
.await?
.json()
.await?;
// TODO: get bot's guilds so we only save users who are able to use the bot
let discord_guilds: Vec<DiscordUserGuild> = client
.get("https://discord.com/api/v10/users/@me/guilds")
.bearer_auth(&auth.access_token)
.send()
.await?
.json()
.await?;
Ok(Self::User {
token: auth.access_token,
username: user.username,
guilds: discord_guilds.into_iter().map(|guild| guild.id).collect(),
})
}
}
impl ExternalUser for DiscordUser {
fn external_token(&self) -> &str {
&self.token
}
fn username(&self) -> UserName {
self.username.clone().into()
}
fn guilds(&self) -> impl Iterator<Item = ExternalGuildId> {
self.guilds.iter().map(|id| ExternalGuildId(*id))
}
}
fn serde_string_as_u64<'de, D>(deserializer: D) -> Result<u64, D::Error>
where
D: Deserializer<'de>,
{
let value = <&str as Deserialize>::deserialize(deserializer)?;
value
.parse::<u64>()
.map_err(|_| serde::de::Error::invalid_value(serde::de::Unexpected::Str(value), &"u64"))
}

View File

@ -1,3 +1,4 @@
pub mod discord;
pub mod ffmpeg; pub mod ffmpeg;
pub mod sqlite; pub mod sqlite;
pub mod ytdlp; pub mod ytdlp;

View File

@ -1,3 +1,4 @@
use chrono::NaiveDateTime;
use iter_tools::Itertools; use iter_tools::Itertools;
use std::{collections::HashMap, sync::Arc}; use std::{collections::HashMap, sync::Arc};
use tokio::sync::Mutex; use tokio::sync::Mutex;
@ -63,6 +64,39 @@ impl IntroToolRepository for Sqlite {
.with_channels(self.get_guild_channels(guild_id).await?)) .with_channels(self.get_guild_channels(guild_id).await?))
} }
async fn get_guilds(&self) -> Result<Vec<GuildRef>, GetGuildError> {
let conn = self.conn.lock().await;
let mut query = conn
.prepare(
"
SELECT
Guild.id,
Guild.name,
Guild.sound_delay
FROM Guild
LEFT JOIN UserGuild ON Guild.id = UserGuild.guild_id
LEFT JOIN User ON User.username = UserGuild.username
",
)
.context("failed to prepare query")?;
let guilds = query
.query_map([], |row| {
Ok(GuildRef::new(
row.get::<_, u64>(0)?.into(),
row.get(1)?,
row.get(2)?,
row.get::<_, u64>(0)?.into(),
))
})
.context("failed to map prepared query")?
.collect::<Result<_, _>>()
.context("failed to fetch guild user rows")?;
Ok(guilds)
}
async fn get_guild_count(&self) -> Result<usize, GetGuildError> { async fn get_guild_count(&self) -> Result<usize, GetGuildError> {
let conn = self.conn.lock().await; let conn = self.conn.lock().await;
@ -117,43 +151,6 @@ impl IntroToolRepository for Sqlite {
Ok(users) Ok(users)
} }
async fn get_user_guilds(
&self,
username: impl AsRef<str>,
) -> Result<Vec<GuildRef>, GetGuildError> {
let conn = self.conn.lock().await;
let mut query = conn
.prepare(
"
SELECT
Guild.id,
Guild.name,
Guild.sound_delay
FROM Guild
LEFT JOIN UserGuild ON Guild.id = UserGuild.guild_id
LEFT JOIN User ON User.username = UserGuild.username
WHERE User.username = :username
",
)
.context("failed to prepare query")?;
let guilds = query
.query_map(&[(":username", username.as_ref())], |row| {
Ok(GuildRef::new(
row.get::<_, u64>(0)?.into(),
row.get(1)?,
row.get(2)?,
row.get::<_, u64>(0)?.into(),
))
})
.context("failed to map prepared query")?
.collect::<Result<_, _>>()
.context("failed to fetch guild user rows")?;
Ok(guilds)
}
async fn get_guild_channels(&self, guild_id: GuildId) -> Result<Vec<Channel>, GetChannelError> { async fn get_guild_channels(&self, guild_id: GuildId) -> Result<Vec<Channel>, GetChannelError> {
let conn = self.conn.lock().await; let conn = self.conn.lock().await;
@ -305,6 +302,43 @@ impl IntroToolRepository for Sqlite {
Ok(intros) Ok(intros)
} }
async fn get_user_guilds(
&self,
username: impl AsRef<str>,
) -> Result<Vec<GuildRef>, GetGuildError> {
let conn = self.conn.lock().await;
let mut query = conn
.prepare(
"
SELECT
Guild.id,
Guild.name,
Guild.sound_delay
FROM Guild
LEFT JOIN UserGuild ON Guild.id = UserGuild.guild_id
LEFT JOIN User ON User.username = UserGuild.username
WHERE User.username = :username
",
)
.context("failed to prepare query")?;
let guilds = query
.query_map(&[(":username", username.as_ref())], |row| {
Ok(GuildRef::new(
row.get::<_, u64>(0)?.into(),
row.get(1)?,
row.get(2)?,
row.get::<_, u64>(0)?.into(),
))
})
.context("failed to map prepared query")?
.collect::<Result<_, _>>()
.context("failed to fetch guild user rows")?;
Ok(guilds)
}
async fn get_user_from_api_key(&self, api_key: &str) -> Result<User, GetUserError> { async fn get_user_from_api_key(&self, api_key: &str) -> Result<User, GetUserError> {
let username = { let username = {
let conn = self.conn.lock().await; let conn = self.conn.lock().await;
@ -330,12 +364,124 @@ impl IntroToolRepository for Sqlite {
self.get_user(username).await self.get_user(username).await
} }
async fn create_guild(&self, req: CreateGuildRequest) -> Result<Guild, CreateGuildError> { async fn set_user_api_key(
todo!() &self,
username: &str,
api_key: &str,
expires_at: NaiveDateTime,
) -> Result<(), GetUserError> {
let conn = self.conn.lock().await;
conn.execute(
"
UPDATE User
SET api_key = ?1, api_key_expires_at = ?2
WHERE username = ?3
",
[api_key, &expires_at.to_string(), username],
)
.context("failed to update user api key")?;
Ok(())
} }
async fn create_user(&self, req: CreateUserRequest) -> Result<User, CreateUserError> { async fn set_user_intro(
todo!() &self,
req: AddIntroToUserRequest,
) -> Result<(), guild::AddIntroToUserError> {
let conn = self.conn.lock().await;
conn.execute(
"
DELETE FROM UserIntro
WHERE username = ?1
AND guild_id = ?2
AND channel_name = ?3
",
[
&req.user.to_string(),
&req.guild_id.to_string(),
&req.channel_name.to_string(),
],
)
.context("failed to delete user intros")?;
conn.execute(
"
INSERT INTO
UserIntro (username, guild_id, channel_name, intro_id)
VALUES (?1, ?2, ?3, ?4)",
[
&req.user.to_string(),
&req.guild_id.to_string(),
&req.channel_name.to_string(),
&req.intro_id.to_string(),
],
)
.context("failed to insert user intro")?;
Ok(())
}
async fn create_guild(&self, req: CreateGuildRequest) -> Result<Guild, CreateGuildError> {
let conn = self.conn.lock().await;
let guild_id: GuildId = req.external_id.0.into();
conn.execute(
"
INSERT INTO
Guild (id, name, sound_delay)
VALUES (?1, ?2, ?3)
",
[
&guild_id.to_string(),
&req.name,
&req.sound_delay.to_string(),
],
)
.context("failed to insert guild")?;
Ok(Guild::new(
guild_id,
req.name,
req.sound_delay,
req.external_id,
))
}
async fn create_user(&self, req: CreateUserRequest) -> Result<(), CreateUserError> {
let conn = self.conn.lock().await;
conn.execute(
"
INSERT INTO
User (username)
VALUES (?1)
",
[req.user.as_ref()],
)
.context("failed to insert user")?;
Ok(())
}
async fn add_user_to_guild(
&self,
guild_id: GuildId,
username: &str,
) -> Result<(), guild::AddUserToGuildError> {
let conn = self.conn.lock().await;
conn.execute(
"
INSERT OR IGNORE INTO UserGuild (username, guild_id) VALUES (?1, ?2)
",
[username, &guild_id.to_string()],
)
.context("failed to insert user guild")?;
Ok(())
} }
async fn create_channel( async fn create_channel(
@ -389,42 +535,4 @@ impl IntroToolRepository for Sqlite {
Ok(intro_id) Ok(intro_id)
} }
async fn set_user_intro(
&self,
req: AddIntroToUserRequest,
) -> Result<(), guild::AddIntroToUserError> {
let conn = self.conn.lock().await;
conn.execute(
"
DELETE FROM UserIntro
WHERE username = ?1
AND guild_id = ?2
AND channel_name = ?3
",
[
&req.user.to_string(),
&req.guild_id.to_string(),
&req.channel_name.to_string(),
],
)
.context("failed to delete user intros")?;
conn.execute(
"
INSERT INTO
UserIntro (username, guild_id, channel_name, intro_id)
VALUES (?1, ?2, ?3, ?4)",
[
&req.user.to_string(),
&req.guild_id.to_string(),
&req.channel_name.to_string(),
&req.intro_id.to_string(),
],
)
.context("failed to insert user intro")?;
Ok(())
}
} }