discord auth

hexagon
Patrick Cleavelin 2025-10-18 16:16:20 -05:00
parent c9a91d3d36
commit 569d87aec4
13 changed files with 797 additions and 126 deletions

View File

@ -1,7 +1,7 @@
[package]
name = "memejoin-rs"
version = "0.2.2-alpha"
edition = "2021"
edition = "2024"
[[bin]]
name = "memejoin-rs"

View File

@ -1,3 +1,5 @@
use serde::{Deserialize, Serialize};
#[derive(Clone)]
pub struct DiscordSecret {
pub client_id: String,
@ -5,6 +7,15 @@ pub struct DiscordSecret {
pub bot_token: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct Discord {
pub(crate) access_token: String,
pub(crate) token_type: String,
pub(crate) expires_in: usize,
pub(crate) refresh_token: String,
pub(crate) scope: String,
}
/*
use std::str::FromStr;

View File

@ -5,6 +5,8 @@ use crate::domain::intro_tool::{
ports::{IntroToolRepository, IntroToolService},
};
use super::ports::AuthService;
#[derive(Clone)]
pub struct DebugService<S>
where
@ -34,6 +36,13 @@ where
self.wrapped_service.needs_setup().await
}
async fn authenticate_user<A: AuthService>(
&self,
params: A::Params,
) -> Result<models::guild::ApiToken, models::guild::AutheticateUserError<A>> {
self.wrapped_service.authenticate_user(params).await
}
async fn get_guild(
&self,
guild_id: impl Into<models::guild::GuildId> + Send,
@ -41,6 +50,12 @@ where
self.wrapped_service.get_guild(guild_id).await
}
async fn get_guilds(
&self,
) -> Result<Vec<models::guild::GuildRef>, models::guild::GetGuildError> {
self.wrapped_service.get_guilds().await
}
async fn get_guild_users(
&self,
guild_id: models::guild::GuildId,
@ -88,6 +103,20 @@ where
.with_channel_intros(user.intros().clone()))
}
async fn set_user_intro(
&self,
req: models::guild::AddIntroToUserRequest,
) -> Result<(), models::guild::AddIntroToUserError> {
self.wrapped_service.set_user_intro(req).await
}
async fn refresh_user_token(
&self,
username: &str,
) -> Result<String, models::guild::GetUserError> {
self.wrapped_service.refresh_user_token(username).await
}
async fn create_guild(
&self,
req: models::guild::CreateGuildRequest,
@ -102,6 +131,16 @@ where
self.wrapped_service.create_user(req).await
}
async fn add_user_to_guild(
&self,
guild_id: models::guild::GuildId,
username: &str,
) -> Result<(), models::guild::AddUserToGuildError> {
self.wrapped_service
.add_user_to_guild(guild_id, username)
.await
}
async fn create_channel(
&self,
req: models::guild::CreateChannelRequest,
@ -115,11 +154,4 @@ where
) -> Result<IntroId, models::guild::AddIntroToGuildError> {
self.wrapped_service.add_intro_to_guild(req).await
}
async fn set_user_intro(
&self,
req: models::guild::AddIntroToUserRequest,
) -> Result<(), models::guild::AddIntroToUserError> {
self.wrapped_service.set_user_intro(req).await
}
}

View File

@ -1,13 +1,18 @@
use std::collections::HashMap;
use std::{borrow::Cow, collections::HashMap};
use chrono::NaiveDateTime;
use thiserror::Error;
use crate::domain::intro_tool::ports::AuthService;
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ApiToken(String);
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct GuildId(u64);
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ExternalGuildId(u64);
pub struct ExternalGuildId(pub u64);
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct UserName(String);
@ -18,6 +23,18 @@ pub struct ChannelName(String);
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct IntroId(i32);
impl From<ApiToken> for Cow<'_, str> {
fn from(value: ApiToken) -> Self {
Cow::Owned(value.0)
}
}
impl From<String> for ApiToken {
fn from(value: String) -> Self {
Self(value)
}
}
impl From<u64> for GuildId {
fn from(id: u64) -> Self {
Self(id)
@ -108,6 +125,10 @@ impl GuildRef {
pub fn name(&self) -> &str {
&self.name
}
pub fn external_id(&self) -> ExternalGuildId {
self.external_id
}
}
impl GuildRef {
@ -198,6 +219,10 @@ impl User {
&self.channel_intros
}
pub fn api_key(&self) -> &str {
&self.api_key
}
pub fn api_key_expires_at(&self) -> NaiveDateTime {
self.api_key_expires_at
}
@ -255,18 +280,18 @@ impl Intro {
}
pub struct CreateGuildRequest {
name: String,
sound_delay: u32,
external_id: ExternalGuildId,
pub name: String,
pub sound_delay: u32,
pub external_id: ExternalGuildId,
}
pub struct CreateUserRequest {
user: UserName,
pub user: UserName,
}
pub struct CreateChannelRequest {
guild_id: GuildId,
channel_name: ChannelName,
pub guild_id: GuildId,
pub channel_name: ChannelName,
}
pub struct AddIntroToGuildRequest {
@ -297,6 +322,9 @@ pub enum CreateGuildError {
#[derive(Debug, Error)]
pub enum CreateUserError {
#[error("Could not get user")]
CouldNotGetUser(#[from] GetUserError),
#[error(transparent)]
Unknown(#[from] anyhow::Error),
}
@ -307,6 +335,12 @@ pub enum CreateChannelError {
Unknown(#[from] anyhow::Error),
}
#[derive(Debug, Error)]
pub enum AddUserToGuildError {
#[error(transparent)]
Unknown(#[from] anyhow::Error),
}
#[derive(Debug, Error)]
pub enum AddIntroToGuildError {
#[error(transparent)]
@ -366,3 +400,27 @@ pub enum GetIntroError {
#[error(transparent)]
Unknown(#[from] anyhow::Error),
}
#[derive(Debug, Error)]
pub enum AutheticateUserError<A: AuthService> {
#[error("Could not fetch guild")]
CouldNotFetchGuild(#[from] GetGuildError),
#[error("Could not create user")]
CouldNotCreateUser(#[from] CreateUserError),
#[error("Could not fetch guild user")]
CouldNotFetchUser(#[from] GetUserError),
#[error("Could not add user to guild")]
CouldNotAddUserToGuild(#[from] AddUserToGuildError),
#[error("User not part of instance's guilds")]
UserNotPartOfInstanceGuilds,
#[error("Error authenticating user")]
ExternalError(A::Error),
#[error(transparent)]
Unknown(#[from] anyhow::Error),
}

View File

@ -1,21 +1,31 @@
use std::{collections::HashMap, future::Future};
use crate::domain::intro_tool::models::guild::{ChannelName, IntroId};
use chrono::NaiveDateTime;
use crate::domain::intro_tool::models::guild::{
AddUserToGuildError, AutheticateUserError, ExternalGuildId, UserName,
};
use super::models::guild::{
AddIntroToGuildError, AddIntroToGuildRequest, AddIntroToUserError, AddIntroToUserRequest,
Channel, CreateChannelError, CreateChannelRequest, CreateGuildError, CreateGuildRequest,
CreateUserError, CreateUserRequest, GetChannelError, GetGuildError, GetIntroError,
GetUserError, Guild, GuildId, GuildRef, Intro, User,
ApiToken, Channel, ChannelName, CreateChannelError, CreateChannelRequest, CreateGuildError,
CreateGuildRequest, CreateUserError, CreateUserRequest, GetChannelError, GetGuildError,
GetIntroError, GetUserError, Guild, GuildId, GuildRef, Intro, IntroId, User,
};
pub trait IntroToolService: Send + Sync + Clone + 'static {
fn needs_setup(&self) -> impl Future<Output = bool> + Send;
fn authenticate_user<A: AuthService>(
&self,
params: A::Params,
) -> impl Future<Output = Result<ApiToken, AutheticateUserError<A>>> + Send;
fn get_guild(
&self,
guild_id: impl Into<GuildId> + Send,
) -> impl Future<Output = Result<Guild, GetGuildError>> + Send;
fn get_guilds(&self) -> impl Future<Output = Result<Vec<GuildRef>, GetGuildError>> + Send;
fn get_guild_users(
&self,
guild_id: GuildId,
@ -42,8 +52,24 @@ pub trait IntroToolService: Send + Sync + Clone + 'static {
req: AddIntroToUserRequest,
) -> impl Future<Output = Result<(), AddIntroToUserError>> + Send;
fn refresh_user_token(
&self,
username: &str,
) -> impl Future<Output = Result<String, GetUserError>> + Send;
async fn create_guild(&self, req: CreateGuildRequest) -> Result<Guild, CreateGuildError>;
async fn create_user(&self, req: CreateUserRequest) -> Result<User, CreateUserError>;
fn create_user(
&self,
req: CreateUserRequest,
) -> impl Future<Output = Result<User, CreateUserError>> + Send;
fn add_user_to_guild(
&self,
guild_id: GuildId,
username: &str,
) -> impl Future<Output = Result<(), AddUserToGuildError>> + Send;
async fn create_channel(
&self,
req: CreateChannelRequest,
@ -60,6 +86,7 @@ pub trait IntroToolRepository: Send + Sync + Clone + 'static {
&self,
guild_id: GuildId,
) -> impl Future<Output = Result<Guild, GetGuildError>> + Send;
fn get_guilds(&self) -> impl Future<Output = Result<Vec<GuildRef>, GetGuildError>> + Send;
fn get_guild_count(&self) -> impl Future<Output = Result<usize, GetGuildError>> + Send;
fn get_guild_users(
@ -97,13 +124,31 @@ pub trait IntroToolRepository: Send + Sync + Clone + 'static {
api_key: &str,
) -> impl Future<Output = Result<User, GetUserError>> + Send;
fn set_user_api_key(
&self,
username: &str,
api_key: &str,
expires_at: NaiveDateTime,
) -> impl Future<Output = Result<(), GetUserError>> + Send;
fn set_user_intro(
&self,
req: AddIntroToUserRequest,
) -> impl Future<Output = Result<(), AddIntroToUserError>> + Send;
async fn create_guild(&self, req: CreateGuildRequest) -> Result<Guild, CreateGuildError>;
async fn create_user(&self, req: CreateUserRequest) -> Result<User, CreateUserError>;
fn create_user(
&self,
req: CreateUserRequest,
) -> impl Future<Output = Result<(), CreateUserError>> + Send;
fn add_user_to_guild(
&self,
guild_id: GuildId,
username: &str,
) -> impl Future<Output = Result<(), AddUserToGuildError>> + Send;
async fn create_channel(
&self,
req: CreateChannelRequest,
@ -117,6 +162,21 @@ pub trait IntroToolRepository: Send + Sync + Clone + 'static {
) -> impl Future<Output = Result<IntroId, AddIntroToGuildError>> + Send;
}
pub trait ExternalUser: Send + Sync + Clone + 'static {
fn external_token(&self) -> &str;
fn username(&self) -> UserName;
fn guilds(&self) -> impl Iterator<Item = ExternalGuildId>;
}
pub trait AuthService: Send + Sync + Clone + 'static {
type Params: Send;
type User: ExternalUser + Send;
type Error: std::error::Error + Send;
fn authenticate_user(
params: Self::Params,
) -> impl Future<Output = Result<Self::User, Self::Error>> + Send;
}
pub trait RemoteAudioFetcher: Send + Sync + Clone + 'static {
fn fetch_remote_audio(
&self,

View File

@ -1,10 +1,19 @@
use chrono::{Duration, Utc};
use iter_tools::Itertools;
use uuid::Uuid;
use crate::domain::intro_tool::{
models::guild::{self, GetUserError, GuildId, IntroId, User},
ports::{IntroToolRepository, IntroToolService, LocalAudioFetcher, RemoteAudioFetcher},
models::guild::{
self, ApiToken, AutheticateUserError, CreateUserRequest, GetUserError, GuildId, IntroId,
User,
},
ports::{
ExternalUser, IntroToolRepository, IntroToolService, LocalAudioFetcher, RemoteAudioFetcher,
},
};
use super::ports::AuthService;
#[derive(Clone)]
pub struct Service<R, RA, LA>
where
@ -46,6 +55,64 @@ where
guild_count == 0
}
async fn authenticate_user<A: AuthService>(
&self,
params: A::Params,
) -> Result<ApiToken, AutheticateUserError<A>> {
let external_user = A::authenticate_user(params)
.await
.map_err(AutheticateUserError::ExternalError)?;
let guilds = self.get_guilds().await?;
let external_user_guilds = guilds
.iter()
.filter(|guild| external_user.guilds().contains(&guild.external_id()))
.collect::<Vec<_>>();
if external_user_guilds.is_empty() {
return Err(AutheticateUserError::UserNotPartOfInstanceGuilds);
}
let user = match self.get_user(external_user.username()).await {
Ok(user) => Some(user),
Err(GetUserError::NotFound) => None,
Err(err) => return Err(AutheticateUserError::CouldNotFetchUser(err)),
};
match user {
Some(user) => {
self.refresh_user_token(user.name()).await?;
}
None => {
self.create_user(CreateUserRequest {
user: external_user.username().clone(),
})
.await?;
}
}
let user = self.get_user(external_user.username()).await?;
let user_guilds = self.get_user_guilds(user.name()).await?;
let guilds_to_add_user =
user_guilds
.iter()
.map(|guild| guild.id())
.filter(|user_guild_id| {
external_user_guilds
.iter()
.map(|external_guild| external_guild.id())
.contains(user_guild_id)
});
for guild in guilds_to_add_user {
self.add_user_to_guild(guild, user.name()).await?;
}
Ok(user.api_key().to_string().into())
}
async fn get_guild(
&self,
guild_id: impl Into<GuildId>,
@ -53,9 +120,14 @@ where
self.repo.get_guild(guild_id.into()).await
}
async fn get_guilds(&self) -> Result<Vec<guild::GuildRef>, guild::GetGuildError> {
self.repo.get_guilds().await
}
async fn get_guild_users(&self, guild_id: GuildId) -> Result<Vec<User>, GetUserError> {
self.repo.get_guild_users(guild_id).await
}
async fn get_guild_intros(
&self,
guild_id: GuildId,
@ -81,6 +153,31 @@ where
self.repo.get_user_from_api_key(api_key).await
}
async fn set_user_intro(
&self,
req: guild::AddIntroToUserRequest,
) -> Result<(), guild::AddIntroToUserError> {
self.repo.set_user_intro(req).await
}
async fn refresh_user_token(&self, username: &str) -> Result<String, GetUserError> {
let user = self.get_user(username).await?;
let user_token = if user.api_key_expires_at() >= Utc::now().naive_utc() {
user.api_key().to_string()
} else {
Uuid::new_v4().to_string()
};
let expires_at = Utc::now().naive_utc() + Duration::weeks(4);
self.repo
.set_user_api_key(username, &user_token, expires_at)
.await?;
Ok(user_token)
}
async fn create_guild(
&self,
req: guild::CreateGuildRequest,
@ -92,7 +189,19 @@ where
&self,
req: guild::CreateUserRequest,
) -> Result<guild::User, guild::CreateUserError> {
self.repo.create_user(req).await
let username = req.user.clone();
self.repo.create_user(req).await?;
Ok(self.get_user(username.as_ref()).await?)
}
async fn add_user_to_guild(
&self,
guild_id: GuildId,
username: &str,
) -> Result<(), guild::AddUserToGuildError> {
self.repo.add_user_to_guild(guild_id, username).await
}
async fn create_channel(
@ -123,11 +232,4 @@ where
.add_intro_to_guild(&req.name, req.guild_id, file_name)
.await
}
async fn set_user_intro(
&self,
req: guild::AddIntroToUserRequest,
) -> Result<(), guild::AddIntroToUserError> {
self.repo.set_user_intro(req).await
}
}

View File

@ -63,6 +63,7 @@ pub enum Tag {
Header6,
Strong,
Paragraph,
Blockquote,
JustText,
}
@ -125,6 +126,7 @@ impl Tag {
Self::Header6 => "h6",
Self::Strong => "strong",
Self::Paragraph => "paragraph",
Self::Blockquote => "blockquote",
}
}

View File

@ -1,5 +1,5 @@
mod handlers;
mod page;
pub(super) mod page;
use std::{net::SocketAddr, sync::Arc};
@ -49,7 +49,9 @@ impl<S: IntroToolService> FromRequestParts<ApiState<S>> for User {
{
Ok(user) => {
let now = Utc::now().naive_utc();
if user.api_key_expires_at() < now || user.discord_token_expires_at() < now {
if user.api_key_expires_at() < now {
//|| user.discord_token_expires_at() < now {
tracing::error!("user token expired at: {}", user.api_key_expires_at());
Err(Redirect::to(&format!("{}/login", state.origin)))
} else {
Ok(user)
@ -125,6 +127,7 @@ where
"/v2/intros/add/:guild_id/:channel",
post(handlers::set_user_intro),
)
.route("/v2/auth", get(page::auth))
// .route("/guild/:guild_id/setup", get(routes::guild_setup))
// .route(
@ -135,7 +138,6 @@ where
// "/guild/:guild_id/permissions/update",
// post(routes::update_guild_permissions),
// )
// .route("/v2/auth", get(routes::v2_auth))
// .route(
// "/v2/intros/remove/:guild_id/:channel",
// post(routes::v2_remove_intro_from_user),

View File

@ -1,15 +1,25 @@
use std::collections::HashMap;
use axum::{
extract::{Path, State},
extract::{Path, Query, State},
response::{Html, Redirect},
};
use axum_extra::extract::{CookieJar, cookie::Cookie};
use reqwest::Url;
use serde::{Deserialize, Deserializer};
use crate::{
auth,
domain::intro_tool::{
models::guild::{ChannelName, GuildRef, Intro, User},
ports::IntroToolService,
},
htmx::{Build, HtmxBuilder, Tag},
inbound::{http::ApiState, response::ErrorAsRedirect},
inbound::{
http::ApiState,
response::{ApiError, ErrorAsRedirect, PageError},
},
outbound::discord::{DiscordAuthParams, DiscordService},
};
pub async fn home<S: IntroToolService>(
@ -22,7 +32,7 @@ pub async fn home<S: IntroToolService>(
.intro_tool_service
.get_user_guilds(user.name())
.await
.as_redirect(&state.origin, "/login")?;
.as_redirect(&state.origin, "login")?;
// TODO: get user app permissions
// TODO: check if user can add guilds
@ -78,7 +88,10 @@ pub async fn login<S: IntroToolService>(
if user.is_some() {
Err(Redirect::to(&format!("{}/", state.origin)))
} else {
let authorize_uri = format!("https://discord.com/api/oauth2/authorize?client_id={}&redirect_uri={}/v2/auth&response_type=code&scope=guilds.members.read+guilds+identify", state.secrets.client_id, state.origin);
let authorize_uri = format!(
"https://discord.com/api/oauth2/authorize?client_id={}&redirect_uri={}/v2/auth&response_type=code&scope=guilds.members.read+guilds+identify",
state.secrets.client_id, state.origin
);
Ok(Html(
HtmxBuilder::new(Tag::Html)
@ -111,17 +124,17 @@ pub async fn guild_dashboard<S: IntroToolService>(
.intro_tool_service
.get_guild(guild_id)
.await
.as_redirect(&state.origin, "/login")?;
.as_redirect(&state.origin, "login")?;
let user_guilds = state
.intro_tool_service
.get_user_guilds(user.name())
.await
.as_redirect(&state.origin, "/login")?;
.as_redirect(&state.origin, "login")?;
let guild_intros = state
.intro_tool_service
.get_guild_intros(guild_id.into())
.await
.as_redirect(&state.origin, "/login")?;
.as_redirect(&state.origin, "login")?;
// does user have access to this guild
if !user_guilds
@ -213,7 +226,38 @@ pub async fn guild_dashboard<S: IntroToolService>(
))
}
fn page_header(title: &str) -> HtmxBuilder {
pub async fn auth<S: IntroToolService>(
State(state): State<ApiState<S>>,
Query(params): Query<HashMap<String, String>>,
jar: CookieJar,
) -> Result<(CookieJar, Redirect), PageError> {
let Some(code) = params.get("code") else {
return Err(ApiError::bad_request("no code").into());
};
tracing::info!("attempting to get access token with code {}", code);
let token = state
.intro_tool_service
// TODO: decoulple discord from HTTP server
.authenticate_user::<DiscordService>(DiscordAuthParams {
origin: state.origin.clone(),
code: code.clone(),
client_id: state.secrets.client_id.clone(),
client_secret: state.secrets.client_secret.clone(),
})
.await
.map_err(ApiError::from)?;
let uri = Url::parse(&state.origin).expect("should be a valid url");
let mut cookie = Cookie::new("access_token", token);
cookie.set_path(uri.path().to_string());
cookie.set_secure(true);
Ok((jar.add(cookie), Redirect::to(&format!("{}/", state.origin))))
}
pub fn page_header(title: &str) -> HtmxBuilder {
HtmxBuilder::new(Tag::Html).head(|b| {
b.title(title)
.script(

View File

@ -1,15 +1,23 @@
use std::fmt::Debug;
use axum::{
response::{IntoResponse, Redirect},
Json,
response::{Html, IntoResponse, Redirect},
};
use reqwest::StatusCode;
use serde::Serialize;
use crate::domain::intro_tool::models::guild::{
AddIntroToGuildError, AddIntroToUserError, GetChannelError, GetGuildError, GetIntroError,
GetUserError,
use crate::{
domain::intro_tool::{
models::guild::{
AddIntroToGuildError, AddIntroToUserError, AddUserToGuildError, AutheticateUserError,
CreateUserError, GetChannelError, GetGuildError, GetIntroError, GetUserError,
},
ports::AuthService,
},
htmx::{Build, HtmxBuilder, Tag},
inbound::http::page::page_header,
outbound::discord::DiscordError,
};
pub(super) trait ErrorAsRedirect<T>: Sized {
@ -70,6 +78,32 @@ impl<T: Debug> ErrorAsRedirect<T> for Result<T, GetIntroError> {
}
}
pub(super) struct PageError(pub ApiError);
impl IntoResponse for PageError {
fn into_response(self) -> axum::response::Response {
Html(
page_header("MemeJoin - Error")
.builder(Tag::Div, |b| {
b.attribute("class", "container")
.builder_text(
Tag::Header2,
&format!("Uh oh! - Status Code {}", self.0.status_code()),
)
.builder(Tag::Blockquote, |b| b.text(self.0.message()))
.builder(Tag::Empty, |b| b.text("<br/>"))
.builder(Tag::Anchor, |b| {
b.attribute("role", "button")
.text("Go Back")
.attribute("href", "/")
})
})
.build(),
)
.into_response()
}
}
pub(super) struct ApiResponse<T: Serialize>(StatusCode, Json<T>);
#[derive(Serialize, Debug)]
@ -100,6 +134,15 @@ impl ApiError {
}
}
fn message(&self) -> &str {
match self {
ApiError::NotFound { message } => message,
ApiError::BadRequest { message } => message,
ApiError::Forbidden { message } => message,
ApiError::InternalServerError { message } => message,
}
}
pub(super) fn not_found(message: impl ToString) -> Self {
Self::NotFound {
message: message.to_string(),
@ -133,6 +176,12 @@ impl IntoResponse for ApiError {
}
}
impl From<ApiError> for PageError {
fn from(value: ApiError) -> Self {
Self(value)
}
}
impl From<GetGuildError> for ApiError {
fn from(value: GetGuildError) -> Self {
match value {
@ -215,3 +264,75 @@ impl From<GetIntroError> for ApiError {
}
}
}
impl From<CreateUserError> for ApiError {
fn from(value: CreateUserError) -> Self {
match value {
CreateUserError::CouldNotGetUser(err) => err.into(),
CreateUserError::Unknown(error) => {
tracing::error!(err = ?error, "unknown error");
Self::internal(error.to_string())
}
}
}
}
impl From<AddUserToGuildError> for ApiError {
fn from(value: AddUserToGuildError) -> Self {
match value {
AddUserToGuildError::Unknown(error) => {
tracing::error!(err = ?error, "unknown error");
Self::internal(error.to_string())
}
}
}
}
impl<A: AuthService> From<AutheticateUserError<A>> for ApiError
where
<A as AuthService>::Error: Into<ApiError>,
{
fn from(value: AutheticateUserError<A>) -> Self {
match value {
AutheticateUserError::CouldNotFetchGuild(err) => err.into(),
AutheticateUserError::CouldNotCreateUser(err) => err.into(),
AutheticateUserError::CouldNotFetchUser(err) => err.into(),
AutheticateUserError::CouldNotAddUserToGuild(err) => err.into(),
AutheticateUserError::UserNotPartOfInstanceGuilds => {
Self::internal("User not part of instance guilds")
}
AutheticateUserError::ExternalError(err) => err.into(),
AutheticateUserError::Unknown(err) => {
tracing::error!(err = ?err, "unknown error");
Self::internal(err.to_string())
}
}
}
}
impl From<DiscordError> for ApiError {
fn from(value: DiscordError) -> Self {
match value {
DiscordError::ApiRequest(error) => {
tracing::error!(err = ?error, "api request error");
Self::internal(error.to_string())
}
}
}
}
impl From<reqwest::Error> for ApiError {
fn from(value: reqwest::Error) -> Self {
Self::internal(format!("error making request to external service: {value}",))
}
}
impl From<anyhow::Error> for ApiError {
fn from(value: anyhow::Error) -> Self {
Self::internal(format!("unknown error: {value}",))
}
}

130
src/lib/outbound/discord.rs Normal file
View File

@ -0,0 +1,130 @@
use std::collections::HashMap;
use serde::{Deserialize, Deserializer};
use thiserror::Error;
use crate::domain::intro_tool::{
models::guild::{ExternalGuildId, UserName},
ports::{AuthService, ExternalUser},
};
#[derive(Clone)]
pub struct DiscordService;
#[derive(Clone)]
pub struct DiscordUser {
token: String,
username: String,
guilds: Vec<u64>,
}
#[derive(Debug, Error)]
pub enum DiscordError {
#[error(transparent)]
ApiRequest(#[from] reqwest::Error),
}
pub struct DiscordAuthParams {
pub origin: String,
pub code: String,
pub client_id: String,
pub client_secret: String,
}
#[derive(Deserialize)]
struct DiscordApiAuth {
access_token: String,
token_type: String,
expires_in: usize,
refresh_token: String,
scope: String,
}
#[derive(Deserialize)]
struct DiscordApiUser {
pub username: String,
}
#[derive(Deserialize)]
struct DiscordUserGuild {
#[serde(deserialize_with = "serde_string_as_u64")]
id: u64,
name: String,
owner: bool,
}
impl AuthService for DiscordService {
type Params = DiscordAuthParams;
type User = DiscordUser;
type Error = DiscordError;
async fn authenticate_user(params: Self::Params) -> Result<Self::User, Self::Error> {
let mut data = HashMap::new();
let redirect_uri = format!("{}/v2/auth", params.origin);
data.insert("client_id", params.client_id.as_str());
data.insert("client_secret", params.client_secret.as_str());
data.insert("grant_type", "authorization_code");
data.insert("code", &params.code);
data.insert("redirect_uri", &redirect_uri);
let client = reqwest::Client::new();
let auth: DiscordApiAuth = client
.post("https://discord.com/api/oauth2/token")
.form(&data)
.send()
.await?
.json()
.await?;
// Get authorized username
let user: DiscordApiUser = client
.get("https://discord.com/api/v10/users/@me")
.bearer_auth(&auth.access_token)
.send()
.await?
.json()
.await?;
// TODO: get bot's guilds so we only save users who are able to use the bot
let discord_guilds: Vec<DiscordUserGuild> = client
.get("https://discord.com/api/v10/users/@me/guilds")
.bearer_auth(&auth.access_token)
.send()
.await?
.json()
.await?;
Ok(Self::User {
token: auth.access_token,
username: user.username,
guilds: discord_guilds.into_iter().map(|guild| guild.id).collect(),
})
}
}
impl ExternalUser for DiscordUser {
fn external_token(&self) -> &str {
&self.token
}
fn username(&self) -> UserName {
self.username.clone().into()
}
fn guilds(&self) -> impl Iterator<Item = ExternalGuildId> {
self.guilds.iter().map(|id| ExternalGuildId(*id))
}
}
fn serde_string_as_u64<'de, D>(deserializer: D) -> Result<u64, D::Error>
where
D: Deserializer<'de>,
{
let value = <&str as Deserialize>::deserialize(deserializer)?;
value
.parse::<u64>()
.map_err(|_| serde::de::Error::invalid_value(serde::de::Unexpected::Str(value), &"u64"))
}

View File

@ -1,3 +1,4 @@
pub mod discord;
pub mod ffmpeg;
pub mod sqlite;
pub mod ytdlp;

View File

@ -1,3 +1,4 @@
use chrono::NaiveDateTime;
use iter_tools::Itertools;
use std::{collections::HashMap, sync::Arc};
use tokio::sync::Mutex;
@ -63,6 +64,39 @@ impl IntroToolRepository for Sqlite {
.with_channels(self.get_guild_channels(guild_id).await?))
}
async fn get_guilds(&self) -> Result<Vec<GuildRef>, GetGuildError> {
let conn = self.conn.lock().await;
let mut query = conn
.prepare(
"
SELECT
Guild.id,
Guild.name,
Guild.sound_delay
FROM Guild
LEFT JOIN UserGuild ON Guild.id = UserGuild.guild_id
LEFT JOIN User ON User.username = UserGuild.username
",
)
.context("failed to prepare query")?;
let guilds = query
.query_map([], |row| {
Ok(GuildRef::new(
row.get::<_, u64>(0)?.into(),
row.get(1)?,
row.get(2)?,
row.get::<_, u64>(0)?.into(),
))
})
.context("failed to map prepared query")?
.collect::<Result<_, _>>()
.context("failed to fetch guild user rows")?;
Ok(guilds)
}
async fn get_guild_count(&self) -> Result<usize, GetGuildError> {
let conn = self.conn.lock().await;
@ -117,43 +151,6 @@ impl IntroToolRepository for Sqlite {
Ok(users)
}
async fn get_user_guilds(
&self,
username: impl AsRef<str>,
) -> Result<Vec<GuildRef>, GetGuildError> {
let conn = self.conn.lock().await;
let mut query = conn
.prepare(
"
SELECT
Guild.id,
Guild.name,
Guild.sound_delay
FROM Guild
LEFT JOIN UserGuild ON Guild.id = UserGuild.guild_id
LEFT JOIN User ON User.username = UserGuild.username
WHERE User.username = :username
",
)
.context("failed to prepare query")?;
let guilds = query
.query_map(&[(":username", username.as_ref())], |row| {
Ok(GuildRef::new(
row.get::<_, u64>(0)?.into(),
row.get(1)?,
row.get(2)?,
row.get::<_, u64>(0)?.into(),
))
})
.context("failed to map prepared query")?
.collect::<Result<_, _>>()
.context("failed to fetch guild user rows")?;
Ok(guilds)
}
async fn get_guild_channels(&self, guild_id: GuildId) -> Result<Vec<Channel>, GetChannelError> {
let conn = self.conn.lock().await;
@ -305,6 +302,43 @@ impl IntroToolRepository for Sqlite {
Ok(intros)
}
async fn get_user_guilds(
&self,
username: impl AsRef<str>,
) -> Result<Vec<GuildRef>, GetGuildError> {
let conn = self.conn.lock().await;
let mut query = conn
.prepare(
"
SELECT
Guild.id,
Guild.name,
Guild.sound_delay
FROM Guild
LEFT JOIN UserGuild ON Guild.id = UserGuild.guild_id
LEFT JOIN User ON User.username = UserGuild.username
WHERE User.username = :username
",
)
.context("failed to prepare query")?;
let guilds = query
.query_map(&[(":username", username.as_ref())], |row| {
Ok(GuildRef::new(
row.get::<_, u64>(0)?.into(),
row.get(1)?,
row.get(2)?,
row.get::<_, u64>(0)?.into(),
))
})
.context("failed to map prepared query")?
.collect::<Result<_, _>>()
.context("failed to fetch guild user rows")?;
Ok(guilds)
}
async fn get_user_from_api_key(&self, api_key: &str) -> Result<User, GetUserError> {
let username = {
let conn = self.conn.lock().await;
@ -330,12 +364,124 @@ impl IntroToolRepository for Sqlite {
self.get_user(username).await
}
async fn create_guild(&self, req: CreateGuildRequest) -> Result<Guild, CreateGuildError> {
todo!()
async fn set_user_api_key(
&self,
username: &str,
api_key: &str,
expires_at: NaiveDateTime,
) -> Result<(), GetUserError> {
let conn = self.conn.lock().await;
conn.execute(
"
UPDATE User
SET api_key = ?1, api_key_expires_at = ?2
WHERE username = ?3
",
[api_key, &expires_at.to_string(), username],
)
.context("failed to update user api key")?;
Ok(())
}
async fn create_user(&self, req: CreateUserRequest) -> Result<User, CreateUserError> {
todo!()
async fn set_user_intro(
&self,
req: AddIntroToUserRequest,
) -> Result<(), guild::AddIntroToUserError> {
let conn = self.conn.lock().await;
conn.execute(
"
DELETE FROM UserIntro
WHERE username = ?1
AND guild_id = ?2
AND channel_name = ?3
",
[
&req.user.to_string(),
&req.guild_id.to_string(),
&req.channel_name.to_string(),
],
)
.context("failed to delete user intros")?;
conn.execute(
"
INSERT INTO
UserIntro (username, guild_id, channel_name, intro_id)
VALUES (?1, ?2, ?3, ?4)",
[
&req.user.to_string(),
&req.guild_id.to_string(),
&req.channel_name.to_string(),
&req.intro_id.to_string(),
],
)
.context("failed to insert user intro")?;
Ok(())
}
async fn create_guild(&self, req: CreateGuildRequest) -> Result<Guild, CreateGuildError> {
let conn = self.conn.lock().await;
let guild_id: GuildId = req.external_id.0.into();
conn.execute(
"
INSERT INTO
Guild (id, name, sound_delay)
VALUES (?1, ?2, ?3)
",
[
&guild_id.to_string(),
&req.name,
&req.sound_delay.to_string(),
],
)
.context("failed to insert guild")?;
Ok(Guild::new(
guild_id,
req.name,
req.sound_delay,
req.external_id,
))
}
async fn create_user(&self, req: CreateUserRequest) -> Result<(), CreateUserError> {
let conn = self.conn.lock().await;
conn.execute(
"
INSERT INTO
User (username)
VALUES (?1)
",
[req.user.as_ref()],
)
.context("failed to insert user")?;
Ok(())
}
async fn add_user_to_guild(
&self,
guild_id: GuildId,
username: &str,
) -> Result<(), guild::AddUserToGuildError> {
let conn = self.conn.lock().await;
conn.execute(
"
INSERT OR IGNORE INTO UserGuild (username, guild_id) VALUES (?1, ?2)
",
[username, &guild_id.to_string()],
)
.context("failed to insert user guild")?;
Ok(())
}
async fn create_channel(
@ -389,42 +535,4 @@ impl IntroToolRepository for Sqlite {
Ok(intro_id)
}
async fn set_user_intro(
&self,
req: AddIntroToUserRequest,
) -> Result<(), guild::AddIntroToUserError> {
let conn = self.conn.lock().await;
conn.execute(
"
DELETE FROM UserIntro
WHERE username = ?1
AND guild_id = ?2
AND channel_name = ?3
",
[
&req.user.to_string(),
&req.guild_id.to_string(),
&req.channel_name.to_string(),
],
)
.context("failed to delete user intros")?;
conn.execute(
"
INSERT INTO
UserIntro (username, guild_id, channel_name, intro_id)
VALUES (?1, ?2, ?3, ?4)",
[
&req.user.to_string(),
&req.guild_id.to_string(),
&req.channel_name.to_string(),
&req.intro_id.to_string(),
],
)
.context("failed to insert user intro")?;
Ok(())
}
}