switch to db for auth

pull/11/head
Patrick Cleavelin 2023-08-06 19:36:08 -05:00
parent 52d7cc7ded
commit 2e1d41b2cd
6 changed files with 206 additions and 40 deletions

2
Cargo.lock generated
View File

@ -330,6 +330,7 @@ dependencies = [
"iana-time-zone",
"js-sys",
"num-traits 0.2.16",
"serde",
"time 0.1.45",
"wasm-bindgen",
"winapi",
@ -1602,6 +1603,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "549b9d036d571d42e6e85d1c1425e2ac83491075078ca9a15be021c56b1641f2"
dependencies = [
"bitflags 2.3.3",
"chrono",
"fallible-iterator",
"fallible-streaming-iterator",
"hashlink",

View File

@ -9,12 +9,12 @@ edition = "2021"
async-trait = "0.1.72"
axum = { version = "0.6.9", features = ["headers", "multipart"] }
axum-extra = { version = "0.7.5", features = ["cookie-private", "cookie"] }
chrono = "0.4.23"
chrono = { version = "0.4.23", features = ["serde"] }
dotenv = "0.15.0"
futures = "0.3.26"
iter_tools = "0.1.4"
reqwest = "0.11.14"
rusqlite = { version = "0.29.0", features = ["bundled"] }
rusqlite = { version = "0.29.0", features = ["bundled", "chrono"] }
serde = "1.0.152"
serde_json = "1.0.93"
thiserror = "1.0.38"

149
src/db.rs
View File

@ -1,7 +1,9 @@
use std::path::Path;
use chrono::NaiveDateTime;
use iter_tools::Itertools;
use rusqlite::{Connection, Result};
use rusqlite::{Connection, OptionalExtension, Result};
use serde::{Deserialize, Serialize};
use tracing::{error, warn};
use crate::auth;
@ -17,6 +19,75 @@ impl Database {
})
}
pub(crate) fn get_guilds(&self) -> Result<Vec<Guild>> {
let mut query = self.conn.prepare(
"
SELECT
id, name, soundDelay
FROM Guild
",
)?;
// NOTE(pcleavelin): for some reason this needs to be a let-binding or else
// the compiler complains about it being dropped too early (maybe I should update the compiler version)
let guilds = query
.query_map([], |row| {
Ok(Guild {
id: row.get(0)?,
name: row.get(1)?,
sound_delay: row.get(2)?,
})
})?
.into_iter()
.collect::<Result<Vec<Guild>>>();
guilds
}
pub(crate) fn get_user_from_api_key(&self, api_key: &str) -> Result<User> {
self.conn.query_row(
"
SELECT
username AS name, api_key, api_key_expires_at, discord_token, discord_token_expires_at
FROM User
WHERE api_key = ?1
",
[api_key],
|row| {
Ok(User {
name: row.get(0)?,
api_key: row.get(1)?,
api_key_expires_at: row.get(2)?,
discord_token: row.get(3)?,
discord_token_expires_at: row.get(4)?,
})
},
)
}
pub(crate) fn get_user(&self, username: &str) -> Result<Option<User>> {
self.conn
.query_row(
"
SELECT
username AS name, api_key, api_key_expires_at, discord_token, discord_token_expires_at
FROM User
WHERE name = ?1
",
[username],
|row| {
Ok(User {
name: row.get(0)?,
api_key: row.get(1)?,
api_key_expires_at: row.get(2)?,
discord_token: row.get(3)?,
discord_token_expires_at: row.get(4)?,
})
},
)
.optional()
}
pub fn get_user_guilds(&self, username: &str) -> Result<Vec<Guild>> {
let mut query = self.conn.prepare(
"
@ -130,8 +201,9 @@ impl Database {
FROM UserPermission
WHERE
username = ?1
AND guild_id = ?2
",
[username],
[username, &guild_id.to_string()],
|row| Ok(auth::Permissions(row.get(0)?)),
)
}
@ -180,6 +252,48 @@ impl Database {
Ok(intros)
}
pub fn add_user(
&self,
username: &str,
api_key: &str,
api_key_expires_at: NaiveDateTime,
discord_token: &str,
discord_token_expires_at: NaiveDateTime,
) -> Result<()> {
let affected = self.conn.execute(
"INSERT INTO
User (username, api_key, api_key_expires_at, discord_token, discord_token_expires_at)
VALUES (?1, ?2, ?3, ?4, ?5)
ON CONFLICT(username) DO UPDATE SET api_key = ?2, api_key_expires_at = ?3, discord_token = ?4, discord_token_expires_at = ?5",
&[
username,
api_key,
&api_key_expires_at.to_string(),
discord_token,
&discord_token_expires_at.to_string(),
],
)?;
if affected < 1 {
warn!("no rows affected when attempting to insert new user");
}
Ok(())
}
pub fn insert_user_guild(&self, username: &str, guild_id: u64) -> Result<()> {
let affected = self.conn.execute(
"INSERT OR IGNORE INTO UserGuild (username, guild_id) VALUES (?1, ?2)",
&[username, &guild_id.to_string()],
)?;
if affected < 1 {
warn!("no rows affected when attempting to insert user guild");
}
Ok(())
}
pub fn insert_user_intro(
&self,
username: &str,
@ -204,6 +318,28 @@ impl Database {
Ok(())
}
pub(crate) fn insert_user_permission(
&self,
username: &str,
guild_id: u64,
permissions: auth::Permissions,
) -> Result<()> {
let affected = self.conn.execute(
"
INSERT INTO
UserPermission (username, guild_id, permissions)
VALUES (?1, ?2, ?3)
ON CONFLICT(username, guild_id) DO UPDATE SET permissions = ?3",
&[username, &guild_id.to_string(), &permissions.0.to_string()],
)?;
if affected < 1 {
warn!("no rows affected when attempting to insert user permissions");
}
Ok(())
}
pub fn remove_user_intro(
&self,
username: &str,
@ -241,6 +377,15 @@ pub struct Guild {
pub sound_delay: u32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct User {
pub name: String,
pub api_key: String,
pub api_key_expires_at: NaiveDateTime,
pub discord_token: String,
pub discord_token_expires_at: NaiveDateTime,
}
pub struct Intro {
pub id: i32,
pub name: String,

View File

@ -1,6 +1,6 @@
use crate::{
auth::{self, User},
db,
auth::{self},
db::{self, User},
htmx::{Build, HtmxBuilder, Tag},
settings::{ApiState, GuildSettings, Intro, IntroFriendlyName},
};

View File

@ -1,4 +1,4 @@
use std::{collections::HashMap, sync::Arc};
use std::{collections::HashMap, ops::Add, sync::Arc};
use axum::{
body::Bytes,
@ -9,15 +9,17 @@ use axum::{
};
use axum_extra::extract::{cookie::Cookie, CookieJar};
use chrono::{Duration, NaiveDate, Utc};
use iter_tools::Itertools;
use reqwest::{Proxy, StatusCode, Url};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use tracing::{error, info};
use tracing::{error, info, log::trace};
use uuid::Uuid;
use crate::{
auth::{self, User},
db,
htmx::Build,
page,
settings::FileIntro,
@ -192,43 +194,53 @@ pub(crate) async fn v2_auth(
.await
.map_err(|err| Error::Auth(err.to_string()))?;
let mut settings = state.settings.lock().await;
let db = state.db.lock().await;
let guilds = db.get_guilds().map_err(Error::Database)?;
let mut in_a_guild = false;
for g in settings.guilds.iter_mut() {
for guild in guilds {
let Some(discord_guild) = discord_guilds
.iter()
.find(|discord_guild| discord_guild.id == g.0.to_string())
.find(|discord_guild| discord_guild.id == guild.id)
else {
continue;
};
in_a_guild = true;
if !g.1.users.contains_key(&user.username) {
g.1.users.insert(
user.username.clone(),
GuildUser {
permissions: if discord_guild.owner {
auth::Permissions(auth::Permission::all())
} else {
Default::default()
},
},
);
}
// TODO: change this
let guild_id = guild.id.parse::<u64>().expect("guild id should be u64");
let now = Utc::now().naive_utc();
db.add_user(
&user.username,
&token,
now + Duration::weeks(4),
&auth.access_token,
now + Duration::seconds(auth.expires_in as i64),
)
.map_err(Error::Database)?;
db.insert_user_guild(&user.username, guild_id)
.map_err(Error::Database)?;
// TODO: Don't reset permissions
db.insert_user_permission(
&user.username,
guild_id,
if discord_guild.owner {
auth::Permissions(auth::Permission::all())
} else {
Default::default()
},
)
.map_err(Error::Database)?;
}
if !in_a_guild {
return Err(Error::NoGuildFound);
}
settings.auth_users.insert(
token.clone(),
auth::User {
auth,
name: user.username.clone(),
},
);
// TODO: add permissions based on roles
let uri = Url::parse(&state.origin).expect("should be a valid url");
@ -336,7 +348,7 @@ pub(crate) async fn auth(
pub(crate) async fn v2_add_intro_to_user(
State(state): State<ApiState>,
Path((guild_id, channel)): Path<(u64, String)>,
user: User,
user: db::User,
mut form_data: Multipart,
) -> Result<Html<String>, Redirect> {
let db = state.db.lock().await;
@ -389,7 +401,7 @@ pub(crate) async fn v2_add_intro_to_user(
pub(crate) async fn v2_remove_intro_from_user(
State(state): State<ApiState>,
Path((guild_id, channel)): Path<(u64, String)>,
user: User,
user: db::User,
mut form_data: Multipart,
) -> Result<Html<String>, Redirect> {
let db = state.db.lock().await;
@ -641,7 +653,7 @@ pub(crate) async fn upload_guild_intro(
pub(crate) async fn v2_upload_guild_intro(
State(state): State<ApiState>,
Path(guild): Path<u64>,
user: User,
user: db::User,
mut form_data: Multipart,
) -> Result<HeaderMap, Error> {
let mut settings = state.settings.lock().await;
@ -778,7 +790,7 @@ pub(crate) async fn v2_add_guild_intro(
State(state): State<ApiState>,
Path(guild): Path<u64>,
Query(mut params): Query<HashMap<String, String>>,
user: User,
user: db::User,
) -> Result<HeaderMap, Error> {
let mut settings = state.settings.lock().await;
let Some(url) = params.remove("url") else {

View File

@ -1,11 +1,14 @@
use std::{collections::HashMap, sync::Arc};
use crate::{auth, db::Database};
use crate::{
auth,
db::{self, Database},
};
use axum::{async_trait, extract::FromRequestParts, http::request::Parts, response::Redirect};
use axum_extra::extract::CookieJar;
use serde::{Deserialize, Serialize};
use serenity::prelude::TypeMapKey;
use tracing::trace;
use tracing::{error, trace};
use uuid::Uuid;
type UserToken = String;
@ -20,7 +23,7 @@ pub(crate) struct ApiState {
}
#[async_trait]
impl FromRequestParts<ApiState> for crate::auth::User {
impl FromRequestParts<ApiState> for db::User {
type Rejection = Redirect;
async fn from_request_parts(
@ -30,10 +33,14 @@ impl FromRequestParts<ApiState> for crate::auth::User {
let jar = CookieJar::from_headers(&headers);
if let Some(token) = jar.get("access_token") {
match state.settings.lock().await.auth_users.get(token.value()) {
match state.db.lock().await.get_user_from_api_key(token.value()) {
// :vomit:
Some(user) => Ok(user.clone()),
None => Err(Redirect::to("/login")),
Ok(user) => Ok(user),
Err(err) => {
error!(?err, "failed to authenticate user");
Err(Redirect::to("/login"))
}
}
} else {
Err(Redirect::to("/login"))