Fix authenticated track insertion

This commit is contained in:
D. Scott Boggs 2023-06-27 14:20:43 -04:00
parent 205a3b165e
commit d7285a84bb
3 changed files with 58 additions and 41 deletions

View file

@ -1,5 +1,8 @@
use derive_deref::Deref; use derive_deref::Deref;
use log::warn; use either::Either::{self, Right};
use rocket::{ use rocket::{
http::{Cookie, CookieJar, Status}, http::{Cookie, CookieJar, Status},
outcome::IntoOutcome, outcome::IntoOutcome,
@ -16,6 +19,8 @@ use crate::{
error::Error, error::Error,
}; };
use super::ErrorResponder;
#[derive(Clone, Deserialize)] #[derive(Clone, Deserialize)]
pub(super) struct LoginData { pub(super) struct LoginData {
name: String, name: String,
@ -27,22 +32,22 @@ pub(super) async fn login(
db: &State<DatabaseConnection>, db: &State<DatabaseConnection>,
user_data: Json<LoginData>, user_data: Json<LoginData>,
cookies: &CookieJar<'_>, cookies: &CookieJar<'_>,
) -> ApiResult<Status> { ) -> Result<Status, Either<Status, ErrorResponder>> {
let users = User::find() let user = Users::find()
.filter(user::Column::Name.eq(&user_data.name)) .filter(users::Column::Name.eq(&user_data.name))
.all(db as &DatabaseConnection) .one(db as &DatabaseConnection)
.await .await
.map_err(Error::from)?; .map_err(|err| Right(Error::from(err).into()))?;
if users.len() > 1 { let Some(user) = user else {
warn!(count = users.len(), name = &user_data.name; "multiple entries found in database for user"); info!(name = user_data.name; "no user found with the given name");
}
let Some(user) = users.get(0) else {
return Ok(Status::Unauthorized); return Ok(Status::Unauthorized);
}; };
let user = user.check_password(&user_data.password)?;
cookies.add_private(Cookie::new( cookies.add_private(Cookie::new(
"user", "user",
serde_json::to_string(&user).map_err(Error::from)?, serde_json::to_string(&user).map_err(|err| Right(Error::from(err).into()))?,
)); ));
cookies.add(Cookie::new("name", user.name));
Ok(Status::Ok) Ok(Status::Ok)
} }
@ -60,12 +65,13 @@ pub(super) async fn sign_up(
"user", "user",
serde_json::to_string(&user_data).map_err(Error::from)?, serde_json::to_string(&user_data).map_err(Error::from)?,
)); ));
cookies.add(Cookie::new("name", user_data.name));
Ok(()) Ok(())
} }
/// Authentication guard /// Authentication guard
#[derive(Deref)] #[derive(Deref)]
pub(super) struct Auth(user::Model); pub(super) struct Auth(users::Model);
#[rocket::async_trait] #[rocket::async_trait]
impl<'r> FromRequest<'r> for Auth { impl<'r> FromRequest<'r> for Auth {
@ -77,7 +83,7 @@ impl<'r> FromRequest<'r> for Auth {
}; };
serde_json::from_str(user.value()) serde_json::from_str(user.value())
.ok() .ok()
.map(|user| Auth(user)) .map(Auth)
.into_outcome(unauthorized) .into_outcome(unauthorized)
} }
} }

View file

@ -1,5 +1,5 @@
use std::convert::Infallible;
use std::default::default;
use crate::api::auth::Auth; use crate::api::auth::Auth;
use crate::api::{self, error::ApiResult}; use crate::api::{self, error::ApiResult};
@ -8,7 +8,7 @@ use crate::error::Error;
use either::Either::{self, Left, Right}; use either::Either::{self, Left, Right};
use rocket::http::Status; use rocket::http::Status;
use rocket::{serde::json::Json, State}; use rocket::{serde::json::Json, State};
use sea_orm::{prelude::*, DatabaseConnection, IntoActiveModel, Statement, TryIntoModel}; use sea_orm::{prelude::*, DatabaseConnection, IntoActiveModel, Statement};
use tokio::sync::broadcast::Sender; use tokio::sync::broadcast::Sender;
use super::update::Update; use super::update::Update;
@ -53,7 +53,7 @@ pub(super) async fn track(
id: i32, id: i32,
auth: Auth, auth: Auth,
) -> Result<Json<tracks::Model>, Either<Status, api::ErrorResponder>> { ) -> Result<Json<tracks::Model>, Either<Status, api::ErrorResponder>> {
get_track_check_user(db, id, &*auth).await get_track_check_user(db, id, &auth).await
} }
#[get("/<id>/ticks")] #[get("/<id>/ticks")]
@ -63,7 +63,7 @@ pub(super) async fn ticks_for_track(
auth: Auth, auth: Auth,
) -> Result<Json<Vec<ticks::Model>>, Either<Status, api::ErrorResponder>> { ) -> Result<Json<Vec<ticks::Model>>, Either<Status, api::ErrorResponder>> {
let db = db as &DatabaseConnection; let db = db as &DatabaseConnection;
let track = get_track_check_user(db, id, &*auth).await?; let track = get_track_check_user(db, id, &auth).await?;
let result = track.find_related(Ticks).all(db).await; let result = track.find_related(Ticks).all(db).await;
match result { match result {
Ok(ticks) => Ok(Json(ticks)), Ok(ticks) => Ok(Json(ticks)),
@ -81,40 +81,48 @@ pub(super) async fn insert_track(
fn bad() -> Either<Status, ErrorResponder> { fn bad() -> Either<Status, ErrorResponder> {
Left(Status::BadRequest) Left(Status::BadRequest)
} }
let track = track.0.as_object().ok_or_else(bad)?; fn bad_value_for(key: &'static str) -> impl Fn() -> Either<Status, ErrorResponder> {
move || {
warn!(key = key; "bad value");
bad()
}
}
let track = track.0.as_object().ok_or_else(|| {
warn!("received value was not an object");
bad()
})?;
let Some(track_id) = db let Some(track_id) = db
.query_one(Statement::from_sql_and_values( .query_one(Statement::from_sql_and_values(
sea_orm::DatabaseBackend::Postgres, sea_orm::DatabaseBackend::Postgres,
"insert into $1 (user_id, track_id) values ( r#"with track_insertion as (
$2, ( insert into tracks (name, description, icon, enabled,
insert into $3 ( multiple_entries_per_day, color, "order"
name, description, icon, enabled, multiple_entries_per_day,
color, order
) values ( ) values (
$4, $5, $6, $7, $8, $9, $10, $2, $3, $4, $5, $6, $7, $8
) returning id ) returning id
) )
) returning track_id;", insert into user_tracks (
user_id, track_id
) select $1, ti.id
from track_insertion ti
join track_insertion using (id);"#,
[ [
user_tracks::Entity::default().table_name().into(),
auth.id.into(), auth.id.into(),
tracks::Entity::default().table_name().into(), track.get("name").ok_or_else(bad_value_for("name"))?.as_str().ok_or_else(bad_value_for("name"))?.into(),
track.get("name").ok_or_else(bad)?.as_str().ok_or_else(bad)?.into(),
track track
.get("description") .get("description")
.ok_or_else(bad)? .ok_or_else(bad_value_for("description"))?
.as_str() .as_str()
.ok_or_else(bad)? .ok_or_else(bad_value_for("description"))?
.into(), .into(),
track.get("icon").ok_or_else(bad)?.as_str().ok_or_else(bad)?.into(), track.get("icon").ok_or_else(bad_value_for("icon"))?.as_str().ok_or_else(bad_value_for("icon"))?.into(),
track.get("enabled").ok_or_else(bad)?.as_i64().into(), track.get("enabled").and_then(|it| it.as_i64()).into(),
track track
.get("multiple_entries_per_day") .get("multiple_entries_per_day")
.ok_or_else(bad)? .and_then(|it| it.as_i64())
.as_i64()
.into(), .into(),
track.get("color").ok_or_else(bad)?.as_i64().into(), track.get("color").and_then(|it| it.as_i64()).into(),
track.get("order").ok_or_else(bad)?.as_i64().into(), track.get("order").and_then(|it| it.as_i64()).into(),
], ],
)) ))
.await .await
@ -186,7 +194,7 @@ pub(super) async fn ticked(
.insert(db as &DatabaseConnection) .insert(db as &DatabaseConnection)
.await .await
.map_err(|err| Right(Error::from(err).into()))? .map_err(|err| Right(Error::from(err).into()))?
.to_owned(); ;
tx.send(Update::tick_added(tick.clone())) tx.send(Update::tick_added(tick.clone()))
.map_err(|err| Right(Error::from(err).into()))?; .map_err(|err| Right(Error::from(err).into()))?;
Ok(Json(tick)) Ok(Json(tick))
@ -214,7 +222,7 @@ pub(super) async fn ticked_on_date(
.insert(db as &DatabaseConnection) .insert(db as &DatabaseConnection)
.await .await
.map_err(Error::from)? .map_err(Error::from)?
.to_owned(); ;
tx.send(Update::tick_added(tick.clone())) tx.send(Update::tick_added(tick.clone()))
.map_err(Error::from)?; .map_err(Error::from)?;
Ok(Left(Json(tick))) Ok(Left(Json(tick)))
@ -242,7 +250,7 @@ pub(super) async fn clear_all_ticks(
.map_err(Error::from)?; .map_err(Error::from)?;
for tick in ticks.clone() { for tick in ticks.clone() {
tick.clone().delete(db).await.map_err(Error::from)?; tick.clone().delete(db).await.map_err(Error::from)?;
Update::tick_cancelled(tick).send(&tx)?; Update::tick_cancelled(tick).send(tx)?;
} }
Ok(Right(Json(ticks))) Ok(Right(Json(ticks)))
} }
@ -271,7 +279,7 @@ pub(super) async fn clear_all_ticks_on_day(
.map_err(Error::from)?; .map_err(Error::from)?;
for tick in ticks.clone() { for tick in ticks.clone() {
tick.clone().delete(db).await.map_err(Error::from)?; tick.clone().delete(db).await.map_err(Error::from)?;
Update::tick_cancelled(tick).send(&tx)?; Update::tick_cancelled(tick).send(tx)?;
} }
Ok(Right(Json(ticks))) Ok(Right(Json(ticks)))
} }

View file

@ -65,8 +65,11 @@ impl ActiveModel {
impl Model { impl Model {
pub fn check_password( pub fn check_password(
self, self,
password: impl AsRef<[u8]>,
) -> std::result::Result<Self, Either<Status, ErrorResponder>> {
match verify(password, &self.password_hash) { match verify(password, &self.password_hash) {
Ok(true) => Ok(self), Ok(true) => Ok(self),
Ok(false) => Err(Left(Status::Unauthorized)),
Err(err) => Err(Right(Error::from(err).into())), Err(err) => Err(Right(Error::from(err).into())),
} }
} }