From 2e1d41b2cd41667b6f2381c2d382468b773248bc Mon Sep 17 00:00:00 2001 From: Patrick Cleavelin Date: Sun, 6 Aug 2023 19:36:08 -0500 Subject: [PATCH] switch to db for auth --- Cargo.lock | 2 + Cargo.toml | 4 +- src/db.rs | 149 +++++++++++++++++++++++++++++++++++++++++++++++- src/page.rs | 4 +- src/routes.rs | 68 +++++++++++++--------- src/settings.rs | 19 ++++-- 6 files changed, 206 insertions(+), 40 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 04daab9..1a1f2cc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -330,6 +330,7 @@ dependencies = [ "iana-time-zone", "js-sys", "num-traits 0.2.16", + "serde", "time 0.1.45", "wasm-bindgen", "winapi", @@ -1602,6 +1603,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "549b9d036d571d42e6e85d1c1425e2ac83491075078ca9a15be021c56b1641f2" dependencies = [ "bitflags 2.3.3", + "chrono", "fallible-iterator", "fallible-streaming-iterator", "hashlink", diff --git a/Cargo.toml b/Cargo.toml index 2276373..984f717 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,12 +9,12 @@ edition = "2021" async-trait = "0.1.72" axum = { version = "0.6.9", features = ["headers", "multipart"] } axum-extra = { version = "0.7.5", features = ["cookie-private", "cookie"] } -chrono = "0.4.23" +chrono = { version = "0.4.23", features = ["serde"] } dotenv = "0.15.0" futures = "0.3.26" iter_tools = "0.1.4" reqwest = "0.11.14" -rusqlite = { version = "0.29.0", features = ["bundled"] } +rusqlite = { version = "0.29.0", features = ["bundled", "chrono"] } serde = "1.0.152" serde_json = "1.0.93" thiserror = "1.0.38" diff --git a/src/db.rs b/src/db.rs index a11bfed..9a60249 100644 --- a/src/db.rs +++ b/src/db.rs @@ -1,7 +1,9 @@ use std::path::Path; +use chrono::NaiveDateTime; use iter_tools::Itertools; -use rusqlite::{Connection, Result}; +use rusqlite::{Connection, OptionalExtension, Result}; +use serde::{Deserialize, Serialize}; use tracing::{error, warn}; use crate::auth; @@ -17,6 +19,75 @@ impl Database { }) } + pub(crate) fn get_guilds(&self) -> Result> { + let mut query = self.conn.prepare( + " + SELECT + id, name, soundDelay + FROM Guild + ", + )?; + + // NOTE(pcleavelin): for some reason this needs to be a let-binding or else + // the compiler complains about it being dropped too early (maybe I should update the compiler version) + let guilds = query + .query_map([], |row| { + Ok(Guild { + id: row.get(0)?, + name: row.get(1)?, + sound_delay: row.get(2)?, + }) + })? + .into_iter() + .collect::>>(); + + guilds + } + + pub(crate) fn get_user_from_api_key(&self, api_key: &str) -> Result { + self.conn.query_row( + " + SELECT + username AS name, api_key, api_key_expires_at, discord_token, discord_token_expires_at + FROM User + WHERE api_key = ?1 + ", + [api_key], + |row| { + Ok(User { + name: row.get(0)?, + api_key: row.get(1)?, + api_key_expires_at: row.get(2)?, + discord_token: row.get(3)?, + discord_token_expires_at: row.get(4)?, + }) + }, + ) + } + + pub(crate) fn get_user(&self, username: &str) -> Result> { + self.conn + .query_row( + " + SELECT + username AS name, api_key, api_key_expires_at, discord_token, discord_token_expires_at + FROM User + WHERE name = ?1 + ", + [username], + |row| { + Ok(User { + name: row.get(0)?, + api_key: row.get(1)?, + api_key_expires_at: row.get(2)?, + discord_token: row.get(3)?, + discord_token_expires_at: row.get(4)?, + }) + }, + ) + .optional() + } + pub fn get_user_guilds(&self, username: &str) -> Result> { let mut query = self.conn.prepare( " @@ -130,8 +201,9 @@ impl Database { FROM UserPermission WHERE username = ?1 + AND guild_id = ?2 ", - [username], + [username, &guild_id.to_string()], |row| Ok(auth::Permissions(row.get(0)?)), ) } @@ -180,6 +252,48 @@ impl Database { Ok(intros) } + pub fn add_user( + &self, + username: &str, + api_key: &str, + api_key_expires_at: NaiveDateTime, + discord_token: &str, + discord_token_expires_at: NaiveDateTime, + ) -> Result<()> { + let affected = self.conn.execute( + "INSERT INTO + User (username, api_key, api_key_expires_at, discord_token, discord_token_expires_at) + VALUES (?1, ?2, ?3, ?4, ?5) + ON CONFLICT(username) DO UPDATE SET api_key = ?2, api_key_expires_at = ?3, discord_token = ?4, discord_token_expires_at = ?5", + &[ + username, + api_key, + &api_key_expires_at.to_string(), + discord_token, + &discord_token_expires_at.to_string(), + ], + )?; + + if affected < 1 { + warn!("no rows affected when attempting to insert new user"); + } + + Ok(()) + } + + pub fn insert_user_guild(&self, username: &str, guild_id: u64) -> Result<()> { + let affected = self.conn.execute( + "INSERT OR IGNORE INTO UserGuild (username, guild_id) VALUES (?1, ?2)", + &[username, &guild_id.to_string()], + )?; + + if affected < 1 { + warn!("no rows affected when attempting to insert user guild"); + } + + Ok(()) + } + pub fn insert_user_intro( &self, username: &str, @@ -204,6 +318,28 @@ impl Database { Ok(()) } + pub(crate) fn insert_user_permission( + &self, + username: &str, + guild_id: u64, + permissions: auth::Permissions, + ) -> Result<()> { + let affected = self.conn.execute( + " + INSERT INTO + UserPermission (username, guild_id, permissions) + VALUES (?1, ?2, ?3) + ON CONFLICT(username, guild_id) DO UPDATE SET permissions = ?3", + &[username, &guild_id.to_string(), &permissions.0.to_string()], + )?; + + if affected < 1 { + warn!("no rows affected when attempting to insert user permissions"); + } + + Ok(()) + } + pub fn remove_user_intro( &self, username: &str, @@ -241,6 +377,15 @@ pub struct Guild { pub sound_delay: u32, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct User { + pub name: String, + pub api_key: String, + pub api_key_expires_at: NaiveDateTime, + pub discord_token: String, + pub discord_token_expires_at: NaiveDateTime, +} + pub struct Intro { pub id: i32, pub name: String, diff --git a/src/page.rs b/src/page.rs index 6ce9166..88315e3 100644 --- a/src/page.rs +++ b/src/page.rs @@ -1,6 +1,6 @@ use crate::{ - auth::{self, User}, - db, + auth::{self}, + db::{self, User}, htmx::{Build, HtmxBuilder, Tag}, settings::{ApiState, GuildSettings, Intro, IntroFriendlyName}, }; diff --git a/src/routes.rs b/src/routes.rs index a30bf8e..dd1dfdf 100644 --- a/src/routes.rs +++ b/src/routes.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, sync::Arc}; +use std::{collections::HashMap, ops::Add, sync::Arc}; use axum::{ body::Bytes, @@ -9,15 +9,17 @@ use axum::{ }; use axum_extra::extract::{cookie::Cookie, CookieJar}; +use chrono::{Duration, NaiveDate, Utc}; use iter_tools::Itertools; use reqwest::{Proxy, StatusCode, Url}; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; -use tracing::{error, info}; +use tracing::{error, info, log::trace}; use uuid::Uuid; use crate::{ auth::{self, User}, + db, htmx::Build, page, settings::FileIntro, @@ -192,43 +194,53 @@ pub(crate) async fn v2_auth( .await .map_err(|err| Error::Auth(err.to_string()))?; - let mut settings = state.settings.lock().await; + let db = state.db.lock().await; + + let guilds = db.get_guilds().map_err(Error::Database)?; let mut in_a_guild = false; - for g in settings.guilds.iter_mut() { + for guild in guilds { let Some(discord_guild) = discord_guilds .iter() - .find(|discord_guild| discord_guild.id == g.0.to_string()) + .find(|discord_guild| discord_guild.id == guild.id) else { continue; }; in_a_guild = true; - if !g.1.users.contains_key(&user.username) { - g.1.users.insert( - user.username.clone(), - GuildUser { - permissions: if discord_guild.owner { - auth::Permissions(auth::Permission::all()) - } else { - Default::default() - }, - }, - ); - } + // TODO: change this + let guild_id = guild.id.parse::().expect("guild id should be u64"); + + let now = Utc::now().naive_utc(); + db.add_user( + &user.username, + &token, + now + Duration::weeks(4), + &auth.access_token, + now + Duration::seconds(auth.expires_in as i64), + ) + .map_err(Error::Database)?; + + db.insert_user_guild(&user.username, guild_id) + .map_err(Error::Database)?; + + // TODO: Don't reset permissions + db.insert_user_permission( + &user.username, + guild_id, + if discord_guild.owner { + auth::Permissions(auth::Permission::all()) + } else { + Default::default() + }, + ) + .map_err(Error::Database)?; } if !in_a_guild { return Err(Error::NoGuildFound); } - settings.auth_users.insert( - token.clone(), - auth::User { - auth, - name: user.username.clone(), - }, - ); // TODO: add permissions based on roles let uri = Url::parse(&state.origin).expect("should be a valid url"); @@ -336,7 +348,7 @@ pub(crate) async fn auth( pub(crate) async fn v2_add_intro_to_user( State(state): State, Path((guild_id, channel)): Path<(u64, String)>, - user: User, + user: db::User, mut form_data: Multipart, ) -> Result, Redirect> { let db = state.db.lock().await; @@ -389,7 +401,7 @@ pub(crate) async fn v2_add_intro_to_user( pub(crate) async fn v2_remove_intro_from_user( State(state): State, Path((guild_id, channel)): Path<(u64, String)>, - user: User, + user: db::User, mut form_data: Multipart, ) -> Result, Redirect> { let db = state.db.lock().await; @@ -641,7 +653,7 @@ pub(crate) async fn upload_guild_intro( pub(crate) async fn v2_upload_guild_intro( State(state): State, Path(guild): Path, - user: User, + user: db::User, mut form_data: Multipart, ) -> Result { let mut settings = state.settings.lock().await; @@ -778,7 +790,7 @@ pub(crate) async fn v2_add_guild_intro( State(state): State, Path(guild): Path, Query(mut params): Query>, - user: User, + user: db::User, ) -> Result { let mut settings = state.settings.lock().await; let Some(url) = params.remove("url") else { diff --git a/src/settings.rs b/src/settings.rs index 58ecdea..ef58c30 100644 --- a/src/settings.rs +++ b/src/settings.rs @@ -1,11 +1,14 @@ use std::{collections::HashMap, sync::Arc}; -use crate::{auth, db::Database}; +use crate::{ + auth, + db::{self, Database}, +}; use axum::{async_trait, extract::FromRequestParts, http::request::Parts, response::Redirect}; use axum_extra::extract::CookieJar; use serde::{Deserialize, Serialize}; use serenity::prelude::TypeMapKey; -use tracing::trace; +use tracing::{error, trace}; use uuid::Uuid; type UserToken = String; @@ -20,7 +23,7 @@ pub(crate) struct ApiState { } #[async_trait] -impl FromRequestParts for crate::auth::User { +impl FromRequestParts for db::User { type Rejection = Redirect; async fn from_request_parts( @@ -30,10 +33,14 @@ impl FromRequestParts for crate::auth::User { let jar = CookieJar::from_headers(&headers); if let Some(token) = jar.get("access_token") { - match state.settings.lock().await.auth_users.get(token.value()) { + match state.db.lock().await.get_user_from_api_key(token.value()) { // :vomit: - Some(user) => Ok(user.clone()), - None => Err(Redirect::to("/login")), + Ok(user) => Ok(user), + Err(err) => { + error!(?err, "failed to authenticate user"); + + Err(Redirect::to("/login")) + } } } else { Err(Redirect::to("/login"))