From 88b58bb19d426a012f34a48cde5631d3850572bc Mon Sep 17 00:00:00 2001 From: "D. Scott Boggs" Date: Mon, 26 Jun 2023 10:59:56 -0400 Subject: [PATCH] Add login+sign_up routes and auth guard --- server/src/api/auth.rs | 73 +++++++++++++++++++++++++++++++++++++ server/src/api/mod.rs | 2 + server/src/entities/user.rs | 11 ++++-- 3 files changed, 82 insertions(+), 4 deletions(-) create mode 100644 server/src/api/auth.rs diff --git a/server/src/api/auth.rs b/server/src/api/auth.rs new file mode 100644 index 0000000..c9ab4a5 --- /dev/null +++ b/server/src/api/auth.rs @@ -0,0 +1,73 @@ +use log::warn; +use rocket::{ + http::{Cookie, CookieJar, Status}, + outcome::IntoOutcome, + request::{self, FromRequest}, + serde::json::Json, + Request, State, +}; +use sea_orm::{prelude::*, DatabaseConnection}; +use serde::Deserialize; + +use crate::{ + api::error::ApiResult, + entities::{prelude::*, *}, + error::Error, +}; + +#[derive(Clone, Deserialize)] +pub(super) struct LoginData { + name: String, + password: String, +} + +#[put("/", data = "", format = "application/json")] +pub(super) async fn login( + db: &State, + user_data: Json, + cookies: &CookieJar<'_>, +) -> ApiResult { + let users = User::find() + .filter(user::Column::Name.eq(&user_data.name)) + .all(db as &DatabaseConnection) + .await + .map_err(Error::from)?; + if users.len() > 1 { + warn!(count = users.len(), name = &user_data.name; "multiple entries found in database for user"); + } + let Some(user) = users.get(0) else { + return Ok(Status::Unauthorized); + }; + cookies.add_private(Cookie::new("user_id", user.id.to_string())); + Ok(Status::Ok) +} + +#[post("/", data = "", format = "application/json")] +pub(super) async fn sign_up( + db: &State, + user_data: Json, + cookies: &CookieJar<'_>, +) -> ApiResult<()> { + let user_data = user::ActiveModel::new(&user_data.name, &user_data.password)? + .insert(db as &DatabaseConnection) + .await + .map_err(Error::from)?; + cookies.add_private(Cookie::new("user_id", user_data.id.to_string())); + Ok(()) +} + +/// Authentication guard +struct Auth(i32); + +#[rocket::async_trait] +impl<'r> FromRequest<'r> for Auth { + type Error = (); + async fn from_request(request: &'r Request<'_>) -> request::Outcome { + request + .cookies() + .get_private("user_id") + .and_then(|val| val.value().parse().ok()) + .map(|id| Auth(id)) + .into_outcome((Status::Unauthorized, ())) + } +} diff --git a/server/src/api/mod.rs b/server/src/api/mod.rs index 40fb19b..f4386a8 100644 --- a/server/src/api/mod.rs +++ b/server/src/api/mod.rs @@ -1,3 +1,4 @@ +mod auth; mod error; mod groups; #[cfg(feature = "unsafe_import")] @@ -111,6 +112,7 @@ pub(crate) fn start_server(db: DatabaseConnection) -> Rocket { "/api/v1/groups", routes![all_groups, group, insert_group, update_group, delete_group], ) + .mount("/api/v1/auth", routes![auth::login, auth::sign_up]) .mount("/", FileServer::from("/src/public")); #[cfg(feature = "unsafe_import")] diff --git a/server/src/entities/user.rs b/server/src/entities/user.rs index d8ba976..0ddc95a 100644 --- a/server/src/entities/user.rs +++ b/server/src/entities/user.rs @@ -3,19 +3,22 @@ use std::default::default; use bcrypt::*; +// TODO Add option for argon2 https://docs.rs/argon2/latest/argon2/ use either::Either::{self, Left, Right}; use rocket::response::status::Unauthorized; use sea_orm::entity::prelude::*; +use serde::{Deserialize, Serialize}; use crate::{ api::ErrorResponder, error::{self, Error}, }; -#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq)] +#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize)] #[sea_orm(table_name = "user")] pub struct Model { #[sea_orm(primary_key)] + #[serde(skip_deserializing)] pub id: i32, pub name: String, pub password_hash: String, @@ -27,10 +30,10 @@ pub enum Relation {} impl ActiveModelBehavior for ActiveModel {} impl ActiveModel { - pub fn new(name: String, password: String) -> error::Result { + pub fn new(name: impl AsRef, password: impl AsRef) -> error::Result { use sea_orm::ActiveValue::Set; - let name = Set(name); - let password_hash = Set(hash(password, DEFAULT_COST + 2)?); + let name = Set(name.as_ref().to_string()); + let password_hash = Set(hash(password.as_ref(), DEFAULT_COST + 2)?); Ok(Self { name, password_hash,