You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

196 lines
4.9 KiB

use std::{fmt::Display, sync::Arc, time::Duration};
use cookie::{Cookie, Key};
use dashmap::DashMap;
use sea_orm::{error::DbErr, DatabaseConnection, EntityTrait, QuerySelect};
use serde::{Deserialize, Serialize};
use tokio::time::Instant;
use entity::{link, user};
use uuid::Uuid;
use crate::utils;
/// State shared by every request handler.
///
/// `sessions` and `cache` sit behind `Arc`, so clones of `AppState` share the
/// same maps.
#[derive(Clone)]
pub struct AppState {
    /// Live sessions. Presumably keyed by the session cookie value — TODO confirm
    /// against the callers that insert into it.
    pub sessions: Arc<DashMap<String, UserSession>>,
    /// Runtime options; see [`AppOptions`].
    pub opts: AppOptions,
    /// In-memory copy of the `link` table, mapping `source` to `target`
    /// (populated by [`AppState::seed_cache`]).
    pub cache: Arc<DashMap<String, String>>,
    /// Database connection used by the queries below.
    pub db: DatabaseConnection,
    /// Key used to encrypt and decrypt private cookies.
    pub cookie_key: Key,
}

impl AppState {
    /// Builds the state with empty session and link caches.
    ///
    /// When `cookie_key` is `None`, a fresh key is produced with
    /// [`Key::generate`], so cookies encrypted under any earlier key will no
    /// longer decrypt.
    pub fn new(opts: AppOptions, db: DatabaseConnection, cookie_key: Option<Key>) -> Self {
        AppState {
            sessions: Arc::new(DashMap::new()),
            opts,
            cache: Arc::new(DashMap::new()),
            db,
            // Lazily generate: `unwrap_or(Key::generate())` would draw from the
            // RNG even when a key was supplied.
            cookie_key: cookie_key.unwrap_or_else(Key::generate),
        }
    }

    /// Returns a new, empty cookie jar.
    pub fn get_jar(&self) -> cookie::CookieJar {
        cookie::CookieJar::new()
    }

    /// Encrypts `cookie`'s value with `cookie_key` and returns the resulting
    /// cookie (same name, encrypted value).
    ///
    /// # Panics
    /// Only if the jar fails to return a cookie that was just added to it,
    /// which would be a bug in the jar itself.
    pub fn encrypt_cookie(&self, cookie: Cookie<'static>) -> Cookie {
        let name = cookie.name().to_string();
        let mut jar = self.get_jar();
        // Adding through the private child jar stores the encrypted form in
        // the parent jar under the same name.
        jar.private_mut(&self.cookie_key).add(cookie);
        jar.get(&name)
            .expect("we have a thief among us! cookie we just added should be in the jar!")
            .to_owned()
    }

    /// Decrypts `cookie` with `cookie_key`.
    ///
    /// Returns `None` if the cookie fails to authenticate or decrypt.
    pub fn decrypt_cookie(&self, cookie: Cookie<'static>) -> Option<Cookie> {
        let name = cookie.name().to_string();
        let mut jar = self.get_jar();
        jar.add(cookie);
        // `PrivateJar::get` already returns an owned cookie; no clone needed.
        jar.private(&self.cookie_key).get(&name)
    }

    /// Loads every row of the `link` table into `cache` (`source` -> `target`).
    ///
    /// Existing entries with the same source are overwritten. Entries not
    /// present in the table are left untouched (the cache is not cleared).
    ///
    /// # Errors
    /// Propagates any [`DbErr`] from the query.
    pub async fn seed_cache(&self) -> Result<(), DbErr> {
        let links = link::Entity::find().all(&self.db).await?;
        for link in links {
            self.cache.insert(link.source, link.target);
        }
        Ok(())
    }

    /// Reports whether the user with primary key `id` has `is_admin` set.
    ///
    /// An unknown `id` yields `Ok(false)`.
    ///
    /// # Errors
    /// Propagates any [`DbErr`] from the query.
    pub async fn is_user_admin(&self, id: i32) -> Result<bool, DbErr> {
        // Only `is_admin` is selected, so decode the row as a bare `bool`.
        // Calling `.one()` directly would try to build a full `user::Model`
        // and fail on the columns that were not selected.
        let is_admin = user::Entity::find_by_id(id)
            .select_only()
            .column(user::Column::IsAdmin)
            .into_tuple::<bool>()
            .one(&self.db)
            .await?;
        Ok(is_admin.unwrap_or(false))
    }
}
/// Runtime configuration carried in the application state.
#[derive(Clone)]
pub struct AppOptions {
    /// Configured public domain, if any. It may carry a scheme such as
    /// `http:`; see [`AppOptions::use_secure_cookie`].
    pub domain: Option<String>,
    /// Character set for generated ids (presumably link short-codes — TODO confirm).
    pub charset: Vec<char>,
    /// Length of generated ids.
    pub id_len: usize,
}

impl AppOptions {
    /// Decides whether cookies should be marked secure.
    ///
    /// Returns `false` when no domain is configured, or when the domain starts
    /// with the plain `http:` scheme. The scheme match is case-insensitive, as
    /// URL schemes are. Any other value, including a bare host name, counts as
    /// secure.
    pub fn use_secure_cookie(&self) -> bool {
        self.domain.as_deref().is_some_and(|domain| {
            let plain_http = domain
                .as_bytes()
                .get(..5)
                .is_some_and(|scheme| scheme.eq_ignore_ascii_case(b"http:"));
            !plain_http
        })
    }
}
/// Value carried in a session cookie: a UUID identifying the session.
#[derive(Clone, Copy, Debug)]
pub struct UserSessionCookie {
    /// Session identifier.
    pub data: Uuid,
}

impl UserSessionCookie {
    /// Creates a cookie value holding a freshly generated random (v4) UUID.
    pub fn new() -> Self {
        UserSessionCookie {
            data: Uuid::new_v4(),
        }
    }
}

impl Default for UserSessionCookie {
    /// Same as [`UserSessionCookie::new`]: a fresh random id, not the nil UUID.
    fn default() -> Self {
        Self::new()
    }
}

impl From<UserSession> for UserSessionCookie {
    /// Copies the session's `decrypted_session_cookie` into the cookie value.
    fn from(value: UserSession) -> Self {
        UserSessionCookie {
            data: value.decrypted_session_cookie,
        }
    }
}
/// A logged-in user's session, valid for a fixed duration after creation.
///
/// NOTE(review): both `Debug` and `Display` include `decrypted_session_cookie`;
/// if that id is what authenticates a request, avoid logging either.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub struct UserSession {
    pub user_id: i32,
    created: Instant,
    expiry: Duration,
    pub decrypted_session_cookie: Uuid,
    pub admin: bool,
}

impl UserSession {
    /// Lifetime applied by [`UserSession::new`]: seven days.
    const DEFAULT_EXPIRY: Duration = Duration::from_secs(7 * 24 * 60 * 60);

    /// Starts a session for `user_id` that expires seven days from now.
    pub fn new(user_id: i32, admin: bool) -> Self {
        Self::new_with_expiry(user_id, admin, Self::DEFAULT_EXPIRY)
    }

    /// Starts a session for `user_id` that expires `expiry` from now, with a
    /// newly generated random (v4) session id.
    pub fn new_with_expiry(user_id: i32, admin: bool, expiry: Duration) -> Self {
        let created = Instant::now();
        let decrypted_session_cookie = Uuid::new_v4();
        Self {
            user_id,
            created,
            expiry,
            decrypted_session_cookie,
            admin,
        }
    }

    /// `true` once strictly more than `expiry` has elapsed since creation.
    pub fn is_expired(&self) -> bool {
        let age = self.created.elapsed();
        age > self.expiry
    }

    /// The session id rendered as a string.
    pub fn get_decrypted_cookie(&self) -> String {
        self.decrypted_session_cookie.to_string()
    }
}

impl Display for UserSession {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let role = if self.admin { "admin" } else { "normal user" };
        write!(
            f,
            "UserSession for id={} ({}) authenticated using cookie={}",
            self.user_id, role, self.decrypted_session_cookie
        )
    }
}
/// Reasons a `(user::Model, password)` pair cannot become a `UserSession`.
///
/// Derives `Debug` and `PartialEq` so results carrying this error can be
/// unwrapped, logged, and compared in tests.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserSessionError {
    /// The user record has no stored password.
    MissingPassword,
    /// The supplied password did not verify against the stored one.
    Invalid,
}
impl TryFrom<(user::Model, String)> for UserSession {
type Error = UserSessionError;
fn try_from(value: (user::Model, String)) -> Result<Self, Self::Error> {
let (user, pw) = value;
match utils::verify_password(
pw.as_bytes(),
user.password.ok_or(Self::Error::MissingPassword)?.as_str(),
) {
true => Ok(UserSession::new(user.id, user.is_admin)),
false => Err(Self::Error::Invalid),
}
}
}
/// The authenticated user's database record.
pub type CurrentUser = user::Model;

/// Credentials submitted by a login form.
///
/// `Debug` is implemented by hand so the plaintext password is never printed.
#[derive(Clone, Deserialize, Serialize)]
pub struct LoginForm {
    pub username: String,
    pub password: String,
}

impl std::fmt::Debug for LoginForm {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LoginForm")
            .field("username", &self.username)
            .field("password", &format_args!("<redacted>"))
            .finish()
    }
}

/// Fields submitted when creating or editing a user.
///
/// `Debug` is implemented by hand so the plaintext password is never printed.
#[derive(Clone, Deserialize, Serialize)]
pub struct UserForm {
    pub username: String,
    pub password: String,
    pub admin: bool,
}

impl std::fmt::Debug for UserForm {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("UserForm")
            .field("username", &self.username)
            .field("password", &format_args!("<redacted>"))
            .field("admin", &self.admin)
            .finish()
    }
}

/// Fields submitted when creating a link.
///
/// `source` is optional; presumably an id is generated when it is absent —
/// TODO confirm in the handler.
#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct LinkForm {
    pub source: Option<String>,
    pub target: String,
}