switch to db for auth
parent
52d7cc7ded
commit
2e1d41b2cd
|
@ -330,6 +330,7 @@ dependencies = [
|
||||||
"iana-time-zone",
|
"iana-time-zone",
|
||||||
"js-sys",
|
"js-sys",
|
||||||
"num-traits 0.2.16",
|
"num-traits 0.2.16",
|
||||||
|
"serde",
|
||||||
"time 0.1.45",
|
"time 0.1.45",
|
||||||
"wasm-bindgen",
|
"wasm-bindgen",
|
||||||
"winapi",
|
"winapi",
|
||||||
|
@ -1602,6 +1603,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "549b9d036d571d42e6e85d1c1425e2ac83491075078ca9a15be021c56b1641f2"
|
checksum = "549b9d036d571d42e6e85d1c1425e2ac83491075078ca9a15be021c56b1641f2"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"bitflags 2.3.3",
|
"bitflags 2.3.3",
|
||||||
|
"chrono",
|
||||||
"fallible-iterator",
|
"fallible-iterator",
|
||||||
"fallible-streaming-iterator",
|
"fallible-streaming-iterator",
|
||||||
"hashlink",
|
"hashlink",
|
||||||
|
|
|
@ -9,12 +9,12 @@ edition = "2021"
|
||||||
async-trait = "0.1.72"
|
async-trait = "0.1.72"
|
||||||
axum = { version = "0.6.9", features = ["headers", "multipart"] }
|
axum = { version = "0.6.9", features = ["headers", "multipart"] }
|
||||||
axum-extra = { version = "0.7.5", features = ["cookie-private", "cookie"] }
|
axum-extra = { version = "0.7.5", features = ["cookie-private", "cookie"] }
|
||||||
chrono = "0.4.23"
|
chrono = { version = "0.4.23", features = ["serde"] }
|
||||||
dotenv = "0.15.0"
|
dotenv = "0.15.0"
|
||||||
futures = "0.3.26"
|
futures = "0.3.26"
|
||||||
iter_tools = "0.1.4"
|
iter_tools = "0.1.4"
|
||||||
reqwest = "0.11.14"
|
reqwest = "0.11.14"
|
||||||
rusqlite = { version = "0.29.0", features = ["bundled"] }
|
rusqlite = { version = "0.29.0", features = ["bundled", "chrono"] }
|
||||||
serde = "1.0.152"
|
serde = "1.0.152"
|
||||||
serde_json = "1.0.93"
|
serde_json = "1.0.93"
|
||||||
thiserror = "1.0.38"
|
thiserror = "1.0.38"
|
||||||
|
|
149
src/db.rs
149
src/db.rs
|
@ -1,7 +1,9 @@
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
|
||||||
|
use chrono::NaiveDateTime;
|
||||||
use iter_tools::Itertools;
|
use iter_tools::Itertools;
|
||||||
use rusqlite::{Connection, Result};
|
use rusqlite::{Connection, OptionalExtension, Result};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
use tracing::{error, warn};
|
use tracing::{error, warn};
|
||||||
|
|
||||||
use crate::auth;
|
use crate::auth;
|
||||||
|
@ -17,6 +19,75 @@ impl Database {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn get_guilds(&self) -> Result<Vec<Guild>> {
|
||||||
|
let mut query = self.conn.prepare(
|
||||||
|
"
|
||||||
|
SELECT
|
||||||
|
id, name, soundDelay
|
||||||
|
FROM Guild
|
||||||
|
",
|
||||||
|
)?;
|
||||||
|
|
||||||
|
// NOTE(pcleavelin): for some reason this needs to be a let-binding or else
|
||||||
|
// the compiler complains about it being dropped too early (maybe I should update the compiler version)
|
||||||
|
let guilds = query
|
||||||
|
.query_map([], |row| {
|
||||||
|
Ok(Guild {
|
||||||
|
id: row.get(0)?,
|
||||||
|
name: row.get(1)?,
|
||||||
|
sound_delay: row.get(2)?,
|
||||||
|
})
|
||||||
|
})?
|
||||||
|
.into_iter()
|
||||||
|
.collect::<Result<Vec<Guild>>>();
|
||||||
|
|
||||||
|
guilds
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn get_user_from_api_key(&self, api_key: &str) -> Result<User> {
|
||||||
|
self.conn.query_row(
|
||||||
|
"
|
||||||
|
SELECT
|
||||||
|
username AS name, api_key, api_key_expires_at, discord_token, discord_token_expires_at
|
||||||
|
FROM User
|
||||||
|
WHERE api_key = ?1
|
||||||
|
",
|
||||||
|
[api_key],
|
||||||
|
|row| {
|
||||||
|
Ok(User {
|
||||||
|
name: row.get(0)?,
|
||||||
|
api_key: row.get(1)?,
|
||||||
|
api_key_expires_at: row.get(2)?,
|
||||||
|
discord_token: row.get(3)?,
|
||||||
|
discord_token_expires_at: row.get(4)?,
|
||||||
|
})
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn get_user(&self, username: &str) -> Result<Option<User>> {
|
||||||
|
self.conn
|
||||||
|
.query_row(
|
||||||
|
"
|
||||||
|
SELECT
|
||||||
|
username AS name, api_key, api_key_expires_at, discord_token, discord_token_expires_at
|
||||||
|
FROM User
|
||||||
|
WHERE name = ?1
|
||||||
|
",
|
||||||
|
[username],
|
||||||
|
|row| {
|
||||||
|
Ok(User {
|
||||||
|
name: row.get(0)?,
|
||||||
|
api_key: row.get(1)?,
|
||||||
|
api_key_expires_at: row.get(2)?,
|
||||||
|
discord_token: row.get(3)?,
|
||||||
|
discord_token_expires_at: row.get(4)?,
|
||||||
|
})
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.optional()
|
||||||
|
}
|
||||||
|
|
||||||
pub fn get_user_guilds(&self, username: &str) -> Result<Vec<Guild>> {
|
pub fn get_user_guilds(&self, username: &str) -> Result<Vec<Guild>> {
|
||||||
let mut query = self.conn.prepare(
|
let mut query = self.conn.prepare(
|
||||||
"
|
"
|
||||||
|
@ -130,8 +201,9 @@ impl Database {
|
||||||
FROM UserPermission
|
FROM UserPermission
|
||||||
WHERE
|
WHERE
|
||||||
username = ?1
|
username = ?1
|
||||||
|
AND guild_id = ?2
|
||||||
",
|
",
|
||||||
[username],
|
[username, &guild_id.to_string()],
|
||||||
|row| Ok(auth::Permissions(row.get(0)?)),
|
|row| Ok(auth::Permissions(row.get(0)?)),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -180,6 +252,48 @@ impl Database {
|
||||||
Ok(intros)
|
Ok(intros)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn add_user(
|
||||||
|
&self,
|
||||||
|
username: &str,
|
||||||
|
api_key: &str,
|
||||||
|
api_key_expires_at: NaiveDateTime,
|
||||||
|
discord_token: &str,
|
||||||
|
discord_token_expires_at: NaiveDateTime,
|
||||||
|
) -> Result<()> {
|
||||||
|
let affected = self.conn.execute(
|
||||||
|
"INSERT INTO
|
||||||
|
User (username, api_key, api_key_expires_at, discord_token, discord_token_expires_at)
|
||||||
|
VALUES (?1, ?2, ?3, ?4, ?5)
|
||||||
|
ON CONFLICT(username) DO UPDATE SET api_key = ?2, api_key_expires_at = ?3, discord_token = ?4, discord_token_expires_at = ?5",
|
||||||
|
&[
|
||||||
|
username,
|
||||||
|
api_key,
|
||||||
|
&api_key_expires_at.to_string(),
|
||||||
|
discord_token,
|
||||||
|
&discord_token_expires_at.to_string(),
|
||||||
|
],
|
||||||
|
)?;
|
||||||
|
|
||||||
|
if affected < 1 {
|
||||||
|
warn!("no rows affected when attempting to insert new user");
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn insert_user_guild(&self, username: &str, guild_id: u64) -> Result<()> {
|
||||||
|
let affected = self.conn.execute(
|
||||||
|
"INSERT OR IGNORE INTO UserGuild (username, guild_id) VALUES (?1, ?2)",
|
||||||
|
&[username, &guild_id.to_string()],
|
||||||
|
)?;
|
||||||
|
|
||||||
|
if affected < 1 {
|
||||||
|
warn!("no rows affected when attempting to insert user guild");
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
pub fn insert_user_intro(
|
pub fn insert_user_intro(
|
||||||
&self,
|
&self,
|
||||||
username: &str,
|
username: &str,
|
||||||
|
@ -204,6 +318,28 @@ impl Database {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn insert_user_permission(
|
||||||
|
&self,
|
||||||
|
username: &str,
|
||||||
|
guild_id: u64,
|
||||||
|
permissions: auth::Permissions,
|
||||||
|
) -> Result<()> {
|
||||||
|
let affected = self.conn.execute(
|
||||||
|
"
|
||||||
|
INSERT INTO
|
||||||
|
UserPermission (username, guild_id, permissions)
|
||||||
|
VALUES (?1, ?2, ?3)
|
||||||
|
ON CONFLICT(username, guild_id) DO UPDATE SET permissions = ?3",
|
||||||
|
&[username, &guild_id.to_string(), &permissions.0.to_string()],
|
||||||
|
)?;
|
||||||
|
|
||||||
|
if affected < 1 {
|
||||||
|
warn!("no rows affected when attempting to insert user permissions");
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
pub fn remove_user_intro(
|
pub fn remove_user_intro(
|
||||||
&self,
|
&self,
|
||||||
username: &str,
|
username: &str,
|
||||||
|
@ -241,6 +377,15 @@ pub struct Guild {
|
||||||
pub sound_delay: u32,
|
pub sound_delay: u32,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct User {
|
||||||
|
pub name: String,
|
||||||
|
pub api_key: String,
|
||||||
|
pub api_key_expires_at: NaiveDateTime,
|
||||||
|
pub discord_token: String,
|
||||||
|
pub discord_token_expires_at: NaiveDateTime,
|
||||||
|
}
|
||||||
|
|
||||||
pub struct Intro {
|
pub struct Intro {
|
||||||
pub id: i32,
|
pub id: i32,
|
||||||
pub name: String,
|
pub name: String,
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
use crate::{
|
use crate::{
|
||||||
auth::{self, User},
|
auth::{self},
|
||||||
db,
|
db::{self, User},
|
||||||
htmx::{Build, HtmxBuilder, Tag},
|
htmx::{Build, HtmxBuilder, Tag},
|
||||||
settings::{ApiState, GuildSettings, Intro, IntroFriendlyName},
|
settings::{ApiState, GuildSettings, Intro, IntroFriendlyName},
|
||||||
};
|
};
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
use std::{collections::HashMap, sync::Arc};
|
use std::{collections::HashMap, ops::Add, sync::Arc};
|
||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
body::Bytes,
|
body::Bytes,
|
||||||
|
@ -9,15 +9,17 @@ use axum::{
|
||||||
};
|
};
|
||||||
|
|
||||||
use axum_extra::extract::{cookie::Cookie, CookieJar};
|
use axum_extra::extract::{cookie::Cookie, CookieJar};
|
||||||
|
use chrono::{Duration, NaiveDate, Utc};
|
||||||
use iter_tools::Itertools;
|
use iter_tools::Itertools;
|
||||||
use reqwest::{Proxy, StatusCode, Url};
|
use reqwest::{Proxy, StatusCode, Url};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::{json, Value};
|
use serde_json::{json, Value};
|
||||||
use tracing::{error, info};
|
use tracing::{error, info, log::trace};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
auth::{self, User},
|
auth::{self, User},
|
||||||
|
db,
|
||||||
htmx::Build,
|
htmx::Build,
|
||||||
page,
|
page,
|
||||||
settings::FileIntro,
|
settings::FileIntro,
|
||||||
|
@ -192,43 +194,53 @@ pub(crate) async fn v2_auth(
|
||||||
.await
|
.await
|
||||||
.map_err(|err| Error::Auth(err.to_string()))?;
|
.map_err(|err| Error::Auth(err.to_string()))?;
|
||||||
|
|
||||||
let mut settings = state.settings.lock().await;
|
let db = state.db.lock().await;
|
||||||
|
|
||||||
|
let guilds = db.get_guilds().map_err(Error::Database)?;
|
||||||
let mut in_a_guild = false;
|
let mut in_a_guild = false;
|
||||||
for g in settings.guilds.iter_mut() {
|
for guild in guilds {
|
||||||
let Some(discord_guild) = discord_guilds
|
let Some(discord_guild) = discord_guilds
|
||||||
.iter()
|
.iter()
|
||||||
.find(|discord_guild| discord_guild.id == g.0.to_string())
|
.find(|discord_guild| discord_guild.id == guild.id)
|
||||||
else {
|
else {
|
||||||
continue;
|
continue;
|
||||||
};
|
};
|
||||||
|
|
||||||
in_a_guild = true;
|
in_a_guild = true;
|
||||||
|
|
||||||
if !g.1.users.contains_key(&user.username) {
|
// TODO: change this
|
||||||
g.1.users.insert(
|
let guild_id = guild.id.parse::<u64>().expect("guild id should be u64");
|
||||||
user.username.clone(),
|
|
||||||
GuildUser {
|
let now = Utc::now().naive_utc();
|
||||||
permissions: if discord_guild.owner {
|
db.add_user(
|
||||||
auth::Permissions(auth::Permission::all())
|
&user.username,
|
||||||
} else {
|
&token,
|
||||||
Default::default()
|
now + Duration::weeks(4),
|
||||||
},
|
&auth.access_token,
|
||||||
},
|
now + Duration::seconds(auth.expires_in as i64),
|
||||||
);
|
)
|
||||||
}
|
.map_err(Error::Database)?;
|
||||||
|
|
||||||
|
db.insert_user_guild(&user.username, guild_id)
|
||||||
|
.map_err(Error::Database)?;
|
||||||
|
|
||||||
|
// TODO: Don't reset permissions
|
||||||
|
db.insert_user_permission(
|
||||||
|
&user.username,
|
||||||
|
guild_id,
|
||||||
|
if discord_guild.owner {
|
||||||
|
auth::Permissions(auth::Permission::all())
|
||||||
|
} else {
|
||||||
|
Default::default()
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.map_err(Error::Database)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
if !in_a_guild {
|
if !in_a_guild {
|
||||||
return Err(Error::NoGuildFound);
|
return Err(Error::NoGuildFound);
|
||||||
}
|
}
|
||||||
|
|
||||||
settings.auth_users.insert(
|
|
||||||
token.clone(),
|
|
||||||
auth::User {
|
|
||||||
auth,
|
|
||||||
name: user.username.clone(),
|
|
||||||
},
|
|
||||||
);
|
|
||||||
// TODO: add permissions based on roles
|
// TODO: add permissions based on roles
|
||||||
|
|
||||||
let uri = Url::parse(&state.origin).expect("should be a valid url");
|
let uri = Url::parse(&state.origin).expect("should be a valid url");
|
||||||
|
@ -336,7 +348,7 @@ pub(crate) async fn auth(
|
||||||
pub(crate) async fn v2_add_intro_to_user(
|
pub(crate) async fn v2_add_intro_to_user(
|
||||||
State(state): State<ApiState>,
|
State(state): State<ApiState>,
|
||||||
Path((guild_id, channel)): Path<(u64, String)>,
|
Path((guild_id, channel)): Path<(u64, String)>,
|
||||||
user: User,
|
user: db::User,
|
||||||
mut form_data: Multipart,
|
mut form_data: Multipart,
|
||||||
) -> Result<Html<String>, Redirect> {
|
) -> Result<Html<String>, Redirect> {
|
||||||
let db = state.db.lock().await;
|
let db = state.db.lock().await;
|
||||||
|
@ -389,7 +401,7 @@ pub(crate) async fn v2_add_intro_to_user(
|
||||||
pub(crate) async fn v2_remove_intro_from_user(
|
pub(crate) async fn v2_remove_intro_from_user(
|
||||||
State(state): State<ApiState>,
|
State(state): State<ApiState>,
|
||||||
Path((guild_id, channel)): Path<(u64, String)>,
|
Path((guild_id, channel)): Path<(u64, String)>,
|
||||||
user: User,
|
user: db::User,
|
||||||
mut form_data: Multipart,
|
mut form_data: Multipart,
|
||||||
) -> Result<Html<String>, Redirect> {
|
) -> Result<Html<String>, Redirect> {
|
||||||
let db = state.db.lock().await;
|
let db = state.db.lock().await;
|
||||||
|
@ -641,7 +653,7 @@ pub(crate) async fn upload_guild_intro(
|
||||||
pub(crate) async fn v2_upload_guild_intro(
|
pub(crate) async fn v2_upload_guild_intro(
|
||||||
State(state): State<ApiState>,
|
State(state): State<ApiState>,
|
||||||
Path(guild): Path<u64>,
|
Path(guild): Path<u64>,
|
||||||
user: User,
|
user: db::User,
|
||||||
mut form_data: Multipart,
|
mut form_data: Multipart,
|
||||||
) -> Result<HeaderMap, Error> {
|
) -> Result<HeaderMap, Error> {
|
||||||
let mut settings = state.settings.lock().await;
|
let mut settings = state.settings.lock().await;
|
||||||
|
@ -778,7 +790,7 @@ pub(crate) async fn v2_add_guild_intro(
|
||||||
State(state): State<ApiState>,
|
State(state): State<ApiState>,
|
||||||
Path(guild): Path<u64>,
|
Path(guild): Path<u64>,
|
||||||
Query(mut params): Query<HashMap<String, String>>,
|
Query(mut params): Query<HashMap<String, String>>,
|
||||||
user: User,
|
user: db::User,
|
||||||
) -> Result<HeaderMap, Error> {
|
) -> Result<HeaderMap, Error> {
|
||||||
let mut settings = state.settings.lock().await;
|
let mut settings = state.settings.lock().await;
|
||||||
let Some(url) = params.remove("url") else {
|
let Some(url) = params.remove("url") else {
|
||||||
|
|
|
@ -1,11 +1,14 @@
|
||||||
use std::{collections::HashMap, sync::Arc};
|
use std::{collections::HashMap, sync::Arc};
|
||||||
|
|
||||||
use crate::{auth, db::Database};
|
use crate::{
|
||||||
|
auth,
|
||||||
|
db::{self, Database},
|
||||||
|
};
|
||||||
use axum::{async_trait, extract::FromRequestParts, http::request::Parts, response::Redirect};
|
use axum::{async_trait, extract::FromRequestParts, http::request::Parts, response::Redirect};
|
||||||
use axum_extra::extract::CookieJar;
|
use axum_extra::extract::CookieJar;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serenity::prelude::TypeMapKey;
|
use serenity::prelude::TypeMapKey;
|
||||||
use tracing::trace;
|
use tracing::{error, trace};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
type UserToken = String;
|
type UserToken = String;
|
||||||
|
@ -20,7 +23,7 @@ pub(crate) struct ApiState {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl FromRequestParts<ApiState> for crate::auth::User {
|
impl FromRequestParts<ApiState> for db::User {
|
||||||
type Rejection = Redirect;
|
type Rejection = Redirect;
|
||||||
|
|
||||||
async fn from_request_parts(
|
async fn from_request_parts(
|
||||||
|
@ -30,10 +33,14 @@ impl FromRequestParts<ApiState> for crate::auth::User {
|
||||||
let jar = CookieJar::from_headers(&headers);
|
let jar = CookieJar::from_headers(&headers);
|
||||||
|
|
||||||
if let Some(token) = jar.get("access_token") {
|
if let Some(token) = jar.get("access_token") {
|
||||||
match state.settings.lock().await.auth_users.get(token.value()) {
|
match state.db.lock().await.get_user_from_api_key(token.value()) {
|
||||||
// :vomit:
|
// :vomit:
|
||||||
Some(user) => Ok(user.clone()),
|
Ok(user) => Ok(user),
|
||||||
None => Err(Redirect::to("/login")),
|
Err(err) => {
|
||||||
|
error!(?err, "failed to authenticate user");
|
||||||
|
|
||||||
|
Err(Redirect::to("/login"))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
Err(Redirect::to("/login"))
|
Err(Redirect::to("/login"))
|
||||||
|
|
Loading…
Reference in New Issue