switch to db for auth
parent
52d7cc7ded
commit
2e1d41b2cd
|
@ -330,6 +330,7 @@ dependencies = [
|
|||
"iana-time-zone",
|
||||
"js-sys",
|
||||
"num-traits 0.2.16",
|
||||
"serde",
|
||||
"time 0.1.45",
|
||||
"wasm-bindgen",
|
||||
"winapi",
|
||||
|
@ -1602,6 +1603,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "549b9d036d571d42e6e85d1c1425e2ac83491075078ca9a15be021c56b1641f2"
|
||||
dependencies = [
|
||||
"bitflags 2.3.3",
|
||||
"chrono",
|
||||
"fallible-iterator",
|
||||
"fallible-streaming-iterator",
|
||||
"hashlink",
|
||||
|
|
|
@ -9,12 +9,12 @@ edition = "2021"
|
|||
async-trait = "0.1.72"
|
||||
axum = { version = "0.6.9", features = ["headers", "multipart"] }
|
||||
axum-extra = { version = "0.7.5", features = ["cookie-private", "cookie"] }
|
||||
chrono = "0.4.23"
|
||||
chrono = { version = "0.4.23", features = ["serde"] }
|
||||
dotenv = "0.15.0"
|
||||
futures = "0.3.26"
|
||||
iter_tools = "0.1.4"
|
||||
reqwest = "0.11.14"
|
||||
rusqlite = { version = "0.29.0", features = ["bundled"] }
|
||||
rusqlite = { version = "0.29.0", features = ["bundled", "chrono"] }
|
||||
serde = "1.0.152"
|
||||
serde_json = "1.0.93"
|
||||
thiserror = "1.0.38"
|
||||
|
|
149
src/db.rs
149
src/db.rs
|
@ -1,7 +1,9 @@
|
|||
use std::path::Path;
|
||||
|
||||
use chrono::NaiveDateTime;
|
||||
use iter_tools::Itertools;
|
||||
use rusqlite::{Connection, Result};
|
||||
use rusqlite::{Connection, OptionalExtension, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::{error, warn};
|
||||
|
||||
use crate::auth;
|
||||
|
@ -17,6 +19,75 @@ impl Database {
|
|||
})
|
||||
}
|
||||
|
||||
pub(crate) fn get_guilds(&self) -> Result<Vec<Guild>> {
|
||||
let mut query = self.conn.prepare(
|
||||
"
|
||||
SELECT
|
||||
id, name, soundDelay
|
||||
FROM Guild
|
||||
",
|
||||
)?;
|
||||
|
||||
// NOTE(pcleavelin): for some reason this needs to be a let-binding or else
|
||||
// the compiler complains about it being dropped too early (maybe I should update the compiler version)
|
||||
let guilds = query
|
||||
.query_map([], |row| {
|
||||
Ok(Guild {
|
||||
id: row.get(0)?,
|
||||
name: row.get(1)?,
|
||||
sound_delay: row.get(2)?,
|
||||
})
|
||||
})?
|
||||
.into_iter()
|
||||
.collect::<Result<Vec<Guild>>>();
|
||||
|
||||
guilds
|
||||
}
|
||||
|
||||
pub(crate) fn get_user_from_api_key(&self, api_key: &str) -> Result<User> {
|
||||
self.conn.query_row(
|
||||
"
|
||||
SELECT
|
||||
username AS name, api_key, api_key_expires_at, discord_token, discord_token_expires_at
|
||||
FROM User
|
||||
WHERE api_key = ?1
|
||||
",
|
||||
[api_key],
|
||||
|row| {
|
||||
Ok(User {
|
||||
name: row.get(0)?,
|
||||
api_key: row.get(1)?,
|
||||
api_key_expires_at: row.get(2)?,
|
||||
discord_token: row.get(3)?,
|
||||
discord_token_expires_at: row.get(4)?,
|
||||
})
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn get_user(&self, username: &str) -> Result<Option<User>> {
|
||||
self.conn
|
||||
.query_row(
|
||||
"
|
||||
SELECT
|
||||
username AS name, api_key, api_key_expires_at, discord_token, discord_token_expires_at
|
||||
FROM User
|
||||
WHERE name = ?1
|
||||
",
|
||||
[username],
|
||||
|row| {
|
||||
Ok(User {
|
||||
name: row.get(0)?,
|
||||
api_key: row.get(1)?,
|
||||
api_key_expires_at: row.get(2)?,
|
||||
discord_token: row.get(3)?,
|
||||
discord_token_expires_at: row.get(4)?,
|
||||
})
|
||||
},
|
||||
)
|
||||
.optional()
|
||||
}
|
||||
|
||||
pub fn get_user_guilds(&self, username: &str) -> Result<Vec<Guild>> {
|
||||
let mut query = self.conn.prepare(
|
||||
"
|
||||
|
@ -130,8 +201,9 @@ impl Database {
|
|||
FROM UserPermission
|
||||
WHERE
|
||||
username = ?1
|
||||
AND guild_id = ?2
|
||||
",
|
||||
[username],
|
||||
[username, &guild_id.to_string()],
|
||||
|row| Ok(auth::Permissions(row.get(0)?)),
|
||||
)
|
||||
}
|
||||
|
@ -180,6 +252,48 @@ impl Database {
|
|||
Ok(intros)
|
||||
}
|
||||
|
||||
pub fn add_user(
|
||||
&self,
|
||||
username: &str,
|
||||
api_key: &str,
|
||||
api_key_expires_at: NaiveDateTime,
|
||||
discord_token: &str,
|
||||
discord_token_expires_at: NaiveDateTime,
|
||||
) -> Result<()> {
|
||||
let affected = self.conn.execute(
|
||||
"INSERT INTO
|
||||
User (username, api_key, api_key_expires_at, discord_token, discord_token_expires_at)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5)
|
||||
ON CONFLICT(username) DO UPDATE SET api_key = ?2, api_key_expires_at = ?3, discord_token = ?4, discord_token_expires_at = ?5",
|
||||
&[
|
||||
username,
|
||||
api_key,
|
||||
&api_key_expires_at.to_string(),
|
||||
discord_token,
|
||||
&discord_token_expires_at.to_string(),
|
||||
],
|
||||
)?;
|
||||
|
||||
if affected < 1 {
|
||||
warn!("no rows affected when attempting to insert new user");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn insert_user_guild(&self, username: &str, guild_id: u64) -> Result<()> {
|
||||
let affected = self.conn.execute(
|
||||
"INSERT OR IGNORE INTO UserGuild (username, guild_id) VALUES (?1, ?2)",
|
||||
&[username, &guild_id.to_string()],
|
||||
)?;
|
||||
|
||||
if affected < 1 {
|
||||
warn!("no rows affected when attempting to insert user guild");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn insert_user_intro(
|
||||
&self,
|
||||
username: &str,
|
||||
|
@ -204,6 +318,28 @@ impl Database {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn insert_user_permission(
|
||||
&self,
|
||||
username: &str,
|
||||
guild_id: u64,
|
||||
permissions: auth::Permissions,
|
||||
) -> Result<()> {
|
||||
let affected = self.conn.execute(
|
||||
"
|
||||
INSERT INTO
|
||||
UserPermission (username, guild_id, permissions)
|
||||
VALUES (?1, ?2, ?3)
|
||||
ON CONFLICT(username, guild_id) DO UPDATE SET permissions = ?3",
|
||||
&[username, &guild_id.to_string(), &permissions.0.to_string()],
|
||||
)?;
|
||||
|
||||
if affected < 1 {
|
||||
warn!("no rows affected when attempting to insert user permissions");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn remove_user_intro(
|
||||
&self,
|
||||
username: &str,
|
||||
|
@ -241,6 +377,15 @@ pub struct Guild {
|
|||
pub sound_delay: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct User {
|
||||
pub name: String,
|
||||
pub api_key: String,
|
||||
pub api_key_expires_at: NaiveDateTime,
|
||||
pub discord_token: String,
|
||||
pub discord_token_expires_at: NaiveDateTime,
|
||||
}
|
||||
|
||||
pub struct Intro {
|
||||
pub id: i32,
|
||||
pub name: String,
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use crate::{
|
||||
auth::{self, User},
|
||||
db,
|
||||
auth::{self},
|
||||
db::{self, User},
|
||||
htmx::{Build, HtmxBuilder, Tag},
|
||||
settings::{ApiState, GuildSettings, Intro, IntroFriendlyName},
|
||||
};
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use std::{collections::HashMap, sync::Arc};
|
||||
use std::{collections::HashMap, ops::Add, sync::Arc};
|
||||
|
||||
use axum::{
|
||||
body::Bytes,
|
||||
|
@ -9,15 +9,17 @@ use axum::{
|
|||
};
|
||||
|
||||
use axum_extra::extract::{cookie::Cookie, CookieJar};
|
||||
use chrono::{Duration, NaiveDate, Utc};
|
||||
use iter_tools::Itertools;
|
||||
use reqwest::{Proxy, StatusCode, Url};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use tracing::{error, info};
|
||||
use tracing::{error, info, log::trace};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
auth::{self, User},
|
||||
db,
|
||||
htmx::Build,
|
||||
page,
|
||||
settings::FileIntro,
|
||||
|
@ -192,43 +194,53 @@ pub(crate) async fn v2_auth(
|
|||
.await
|
||||
.map_err(|err| Error::Auth(err.to_string()))?;
|
||||
|
||||
let mut settings = state.settings.lock().await;
|
||||
let db = state.db.lock().await;
|
||||
|
||||
let guilds = db.get_guilds().map_err(Error::Database)?;
|
||||
let mut in_a_guild = false;
|
||||
for g in settings.guilds.iter_mut() {
|
||||
for guild in guilds {
|
||||
let Some(discord_guild) = discord_guilds
|
||||
.iter()
|
||||
.find(|discord_guild| discord_guild.id == g.0.to_string())
|
||||
.find(|discord_guild| discord_guild.id == guild.id)
|
||||
else {
|
||||
continue;
|
||||
};
|
||||
|
||||
in_a_guild = true;
|
||||
|
||||
if !g.1.users.contains_key(&user.username) {
|
||||
g.1.users.insert(
|
||||
user.username.clone(),
|
||||
GuildUser {
|
||||
permissions: if discord_guild.owner {
|
||||
// TODO: change this
|
||||
let guild_id = guild.id.parse::<u64>().expect("guild id should be u64");
|
||||
|
||||
let now = Utc::now().naive_utc();
|
||||
db.add_user(
|
||||
&user.username,
|
||||
&token,
|
||||
now + Duration::weeks(4),
|
||||
&auth.access_token,
|
||||
now + Duration::seconds(auth.expires_in as i64),
|
||||
)
|
||||
.map_err(Error::Database)?;
|
||||
|
||||
db.insert_user_guild(&user.username, guild_id)
|
||||
.map_err(Error::Database)?;
|
||||
|
||||
// TODO: Don't reset permissions
|
||||
db.insert_user_permission(
|
||||
&user.username,
|
||||
guild_id,
|
||||
if discord_guild.owner {
|
||||
auth::Permissions(auth::Permission::all())
|
||||
} else {
|
||||
Default::default()
|
||||
},
|
||||
},
|
||||
);
|
||||
}
|
||||
)
|
||||
.map_err(Error::Database)?;
|
||||
}
|
||||
|
||||
if !in_a_guild {
|
||||
return Err(Error::NoGuildFound);
|
||||
}
|
||||
|
||||
settings.auth_users.insert(
|
||||
token.clone(),
|
||||
auth::User {
|
||||
auth,
|
||||
name: user.username.clone(),
|
||||
},
|
||||
);
|
||||
// TODO: add permissions based on roles
|
||||
|
||||
let uri = Url::parse(&state.origin).expect("should be a valid url");
|
||||
|
@ -336,7 +348,7 @@ pub(crate) async fn auth(
|
|||
pub(crate) async fn v2_add_intro_to_user(
|
||||
State(state): State<ApiState>,
|
||||
Path((guild_id, channel)): Path<(u64, String)>,
|
||||
user: User,
|
||||
user: db::User,
|
||||
mut form_data: Multipart,
|
||||
) -> Result<Html<String>, Redirect> {
|
||||
let db = state.db.lock().await;
|
||||
|
@ -389,7 +401,7 @@ pub(crate) async fn v2_add_intro_to_user(
|
|||
pub(crate) async fn v2_remove_intro_from_user(
|
||||
State(state): State<ApiState>,
|
||||
Path((guild_id, channel)): Path<(u64, String)>,
|
||||
user: User,
|
||||
user: db::User,
|
||||
mut form_data: Multipart,
|
||||
) -> Result<Html<String>, Redirect> {
|
||||
let db = state.db.lock().await;
|
||||
|
@ -641,7 +653,7 @@ pub(crate) async fn upload_guild_intro(
|
|||
pub(crate) async fn v2_upload_guild_intro(
|
||||
State(state): State<ApiState>,
|
||||
Path(guild): Path<u64>,
|
||||
user: User,
|
||||
user: db::User,
|
||||
mut form_data: Multipart,
|
||||
) -> Result<HeaderMap, Error> {
|
||||
let mut settings = state.settings.lock().await;
|
||||
|
@ -778,7 +790,7 @@ pub(crate) async fn v2_add_guild_intro(
|
|||
State(state): State<ApiState>,
|
||||
Path(guild): Path<u64>,
|
||||
Query(mut params): Query<HashMap<String, String>>,
|
||||
user: User,
|
||||
user: db::User,
|
||||
) -> Result<HeaderMap, Error> {
|
||||
let mut settings = state.settings.lock().await;
|
||||
let Some(url) = params.remove("url") else {
|
||||
|
|
|
@ -1,11 +1,14 @@
|
|||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use crate::{auth, db::Database};
|
||||
use crate::{
|
||||
auth,
|
||||
db::{self, Database},
|
||||
};
|
||||
use axum::{async_trait, extract::FromRequestParts, http::request::Parts, response::Redirect};
|
||||
use axum_extra::extract::CookieJar;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serenity::prelude::TypeMapKey;
|
||||
use tracing::trace;
|
||||
use tracing::{error, trace};
|
||||
use uuid::Uuid;
|
||||
|
||||
type UserToken = String;
|
||||
|
@ -20,7 +23,7 @@ pub(crate) struct ApiState {
|
|||
}
|
||||
|
||||
#[async_trait]
|
||||
impl FromRequestParts<ApiState> for crate::auth::User {
|
||||
impl FromRequestParts<ApiState> for db::User {
|
||||
type Rejection = Redirect;
|
||||
|
||||
async fn from_request_parts(
|
||||
|
@ -30,10 +33,14 @@ impl FromRequestParts<ApiState> for crate::auth::User {
|
|||
let jar = CookieJar::from_headers(&headers);
|
||||
|
||||
if let Some(token) = jar.get("access_token") {
|
||||
match state.settings.lock().await.auth_users.get(token.value()) {
|
||||
match state.db.lock().await.get_user_from_api_key(token.value()) {
|
||||
// :vomit:
|
||||
Some(user) => Ok(user.clone()),
|
||||
None => Err(Redirect::to("/login")),
|
||||
Ok(user) => Ok(user),
|
||||
Err(err) => {
|
||||
error!(?err, "failed to authenticate user");
|
||||
|
||||
Err(Redirect::to("/login"))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Err(Redirect::to("/login"))
|
||||
|
|
Loading…
Reference in New Issue