You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
210 lines
5.5 KiB
210 lines
5.5 KiB
use std::{fmt::Display, sync::Arc, time::Duration}; |
|
|
|
use cookie::{Cookie, CookieBuilder, Key, SameSite}; |
|
use dashmap::DashMap; |
|
use sea_orm::{error::DbErr, DatabaseConnection, EntityTrait, QuerySelect}; |
|
use serde::{Deserialize, Serialize}; |
|
use tokio::time::Instant; |
|
|
|
use entity::{link, user}; |
|
use tracing::warn; |
|
use uuid::Uuid; |
|
|
|
use crate::utils; |
|
|
|
#[derive(Clone)] |
|
pub struct AppState { |
|
pub sessions: Arc<DashMap<String, UserSession>>, |
|
pub opts: AppOptions, |
|
pub cache: Arc<DashMap<String, String>>, |
|
pub db: DatabaseConnection, |
|
pub cookie_key: Key, |
|
} |
|
|
|
impl AppState { |
|
pub fn new(opts: AppOptions, db: DatabaseConnection, cookie_key: Option<Key>) -> Self { |
|
AppState { |
|
sessions: Arc::new(DashMap::new()), |
|
opts, |
|
cache: Arc::new(DashMap::new()), |
|
db, |
|
cookie_key: cookie_key.unwrap_or(Key::generate()), |
|
} |
|
} |
|
|
|
pub fn get_jar(&self) -> cookie::CookieJar { |
|
cookie::CookieJar::new() |
|
} |
|
|
|
pub fn encrypt_cookie(&self, cookie: Cookie<'static>) -> Cookie { |
|
let c_name = cookie.name().to_string(); |
|
let mut jar = self.get_jar(); |
|
jar.private_mut(&self.cookie_key).add(cookie); |
|
jar.get(&c_name) |
|
.expect("we have a thief among us! cookie we just added should be in the jar!") |
|
.to_owned() |
|
} |
|
|
|
pub fn decrypt_cookie(&self, cookie: Cookie<'static>) -> Option<Cookie> { |
|
let c_name = cookie.name().to_string(); |
|
let mut jar = self.get_jar(); |
|
jar.add(cookie); |
|
jar.private(&self.cookie_key).get(&c_name).to_owned() |
|
} |
|
|
|
pub async fn seed_cache(&self) -> Result<(), DbErr> { |
|
link::Entity::find() |
|
.all(&self.db) |
|
.await? |
|
.into_iter() |
|
.for_each(|link| { |
|
self.cache.insert(link.source, link.target); |
|
}); |
|
Ok(()) |
|
} |
|
} |
|
|
|
/// Runtime configuration for the application.
#[derive(Clone)]
pub struct AppOptions {
    /// Configured domain, if any. A value starting with `http:` disables secure cookies.
    pub domain: Option<String>,
    /// Character set presumably used when generating ids — TODO confirm against callers.
    pub charset: Vec<char>,
    /// Length of generated ids — presumably link ids; confirm against callers.
    pub id_len: usize,
}

impl AppOptions {
    /// Returns `true` when cookies should be marked secure.
    ///
    /// That is the case when a domain is configured and it does not begin with `http:`.
    pub fn use_secure_cookie(&self) -> bool {
        matches!(self.domain.as_deref(), Some(d) if !d.starts_with("http:"))
    }
}
|
|
|
#[derive(Clone, Copy, Debug)] |
|
pub struct UserSessionCookie { |
|
pub data: Uuid, |
|
} |
|
|
|
impl UserSessionCookie { |
|
pub fn new() -> Self { |
|
UserSessionCookie { |
|
data: Uuid::new_v4(), |
|
} |
|
} |
|
} |
|
|
|
impl From<UserSession> for UserSessionCookie { |
|
fn from(value: UserSession) -> Self { |
|
UserSessionCookie { |
|
data: value.decrypted_session_cookie, |
|
} |
|
} |
|
} |
|
|
|
#[derive(Clone, Copy, PartialEq, Eq, Debug)] |
|
pub struct UserSession { |
|
pub user_id: i32, |
|
created: Instant, |
|
expiry: Duration, |
|
pub decrypted_session_cookie: Uuid, |
|
pub admin: bool, |
|
} |
|
|
|
impl UserSession { |
|
pub fn new(user_id: i32, admin: bool) -> Self { |
|
Self::new_with_expiry(user_id, admin, Duration::from_secs(60 * 60 * 24 * 7)) |
|
} |
|
|
|
pub fn new_with_expiry(user_id: i32, admin: bool, expiry: Duration) -> Self { |
|
UserSession { |
|
user_id, |
|
created: Instant::now(), |
|
expiry, |
|
decrypted_session_cookie: Uuid::new_v4(), |
|
admin, |
|
} |
|
} |
|
|
|
pub fn is_expired(&self) -> bool { |
|
self.created.elapsed() > self.expiry |
|
} |
|
|
|
pub fn get_cookie_value(&self) -> String { |
|
self.decrypted_session_cookie.to_string() |
|
} |
|
|
|
pub fn into_cookie(self, opts: AppOptions) -> Cookie<'static> { |
|
if opts.use_secure_cookie() { |
|
if opts.domain.is_none() { |
|
warn!("configured to use secure cookie, but no domain is set! falling back to insecure cookie"); |
|
return self.into_insecure_cookie().finish(); |
|
} |
|
let domain = opts.domain.expect("math broke"); |
|
self.into_insecure_cookie() |
|
.secure(true) |
|
.domain(domain) |
|
.same_site(SameSite::Strict) |
|
.finish() |
|
} else { |
|
self.into_insecure_cookie().finish() |
|
} |
|
} |
|
|
|
fn into_insecure_cookie(self) -> CookieBuilder<'static> { |
|
Cookie::build("_session", self.get_cookie_value()) |
|
.http_only(true) |
|
.max_age(cookie::time::Duration::days(7)) |
|
} |
|
} |
|
|
|
impl Display for UserSession { |
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { |
|
write!( |
|
f, |
|
"UserSession for id={} ({}) authenticated using cookie={}", |
|
self.user_id, |
|
if self.admin { "admin" } else { "normal user" }, |
|
self.decrypted_session_cookie |
|
)?; |
|
Ok(()) |
|
} |
|
} |
|
|
|
/// Reasons a login attempt does not produce a `UserSession`.
///
/// Derives `Debug` and equality so callers can log it and match or compare on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UserSessionError {
    /// The user record has no stored password.
    MissingPassword,
    /// The supplied password failed verification.
    Invalid,
}
|
impl TryFrom<(user::Model, String)> for UserSession { |
|
type Error = UserSessionError; |
|
|
|
fn try_from(value: (user::Model, String)) -> Result<Self, Self::Error> { |
|
let (user, pw) = value; |
|
match utils::verify_password( |
|
pw.as_bytes(), |
|
user.password.ok_or(Self::Error::MissingPassword)?.as_str(), |
|
) { |
|
true => Ok(UserSession::new(user.id, user.is_admin)), |
|
false => Err(Self::Error::Invalid), |
|
} |
|
} |
|
} |
|
|
|
pub type CurrentUser = user::Model; |
|
|
|
#[derive(Clone, Deserialize, Serialize, Debug)] |
|
pub struct LoginForm { |
|
pub username: String, |
|
pub password: String, |
|
} |
|
|
|
#[derive(Clone, Deserialize, Serialize, Debug)] |
|
pub struct UserForm { |
|
pub username: String, |
|
pub password: String, |
|
pub admin: bool, |
|
} |
|
|
|
#[derive(Clone, Deserialize, Serialize, Debug)] |
|
pub struct LinkForm { |
|
pub source: Option<String>, |
|
pub target: String, |
|
}
|
|
|