From cf8380db3593180103d94cdf1c554ed3173f9b71 Mon Sep 17 00:00:00 2001 From: "D. Scott Boggs" Date: Mon, 26 Jun 2023 15:59:51 -0400 Subject: [PATCH] tracks are now relative to an authenticated user --- server/Cargo.lock | 12 +++ server/Cargo.toml | 1 + server/src/api/auth.rs | 28 ++++-- server/src/api/tracks.rs | 166 ++++++++++++++++++++++++++++-------- server/src/entities/user.rs | 25 ++++++ server/src/error.rs | 2 + server/src/main.rs | 2 +- 7 files changed, 189 insertions(+), 47 deletions(-) diff --git a/server/Cargo.lock b/server/Cargo.lock index 7f9f8e9..8aa8a7d 100644 --- a/server/Cargo.lock +++ b/server/Cargo.lock @@ -546,6 +546,17 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "derive_deref" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dcdbcee2d9941369faba772587a565f4f534e42cb8d17e5295871de730163b2b" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "devise" version = "0.4.1" @@ -1155,6 +1166,7 @@ dependencies = [ "bcrypt", "chrono", "derive_builder", + "derive_deref", "either", "femme", "log", diff --git a/server/Cargo.toml b/server/Cargo.toml index 68e15ca..f34bc6b 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -15,6 +15,7 @@ path = "src/main.rs" [dependencies] bcrypt = "0.14.0" chrono = "0.4.26" +derive_deref = "1.1.1" femme = "2.2.1" log = { version = "0.4.19", features = ["kv_unstable", "kv_unstable_serde"] } sea-orm-migration = "0.11.3" diff --git a/server/src/api/auth.rs b/server/src/api/auth.rs index c9ab4a5..6bb75e9 100644 --- a/server/src/api/auth.rs +++ b/server/src/api/auth.rs @@ -1,3 +1,4 @@ +use derive_deref::Deref; use log::warn; use rocket::{ http::{Cookie, CookieJar, Status}, @@ -38,7 +39,10 @@ pub(super) async fn login( let Some(user) = users.get(0) else { return Ok(Status::Unauthorized); }; - cookies.add_private(Cookie::new("user_id", user.id.to_string())); + cookies.add_private(Cookie::new( + "user", + serde_json::to_string(&user).map_err(Error::from)?, + )); Ok(Status::Ok) } @@ -52,22 +56,28 @@ pub(super) async fn sign_up( .insert(db as &DatabaseConnection) .await .map_err(Error::from)?; - cookies.add_private(Cookie::new("user_id", user_data.id.to_string())); + cookies.add_private(Cookie::new( + "user", + serde_json::to_string(&user_data).map_err(Error::from)?, + )); Ok(()) } /// Authentication guard -struct Auth(i32); +#[derive(Deref)] +pub(super) struct Auth(user::Model); #[rocket::async_trait] impl<'r> FromRequest<'r> for Auth { type Error = (); async fn from_request(request: &'r Request<'_>) -> request::Outcome { - request - .cookies() - .get_private("user_id") - .and_then(|val| val.value().parse().ok()) - .map(|id| Auth(id)) - .into_outcome((Status::Unauthorized, ())) + let unauthorized = (Status::Unauthorized, ()); + let Some(user) = request.cookies().get_private("user") else { + return request::Outcome::Failure(unauthorized); + }; + serde_json::from_str(user.value()) + .ok() + .map(|user| Auth(user)) + .into_outcome(unauthorized) } } diff --git a/server/src/api/tracks.rs b/server/src/api/tracks.rs index edfa616..7c59ca2 100644 --- a/server/src/api/tracks.rs +++ b/server/src/api/tracks.rs @@ -1,51 +1,72 @@ +use std::convert::Infallible; +use std::default::default; + +use crate::api::auth::Auth; use crate::api::{self, error::ApiResult}; use crate::entities::{prelude::*, *}; use crate::error::Error; use either::Either::{self, Left, Right}; use rocket::http::Status; use rocket::{serde::json::Json, State}; -use sea_orm::{prelude::*, DatabaseConnection}; +use sea_orm::{prelude::*, DatabaseConnection, IntoActiveModel, Statement, TryIntoModel}; use tokio::sync::broadcast::Sender; use super::update::Update; +use super::ErrorResponder; #[get("/")] pub(super) async fn all_tracks( db: &State, + authorized_user: Auth, ) -> ApiResult>> { let db = db as &DatabaseConnection; - let tracks = Tracks::find().all(db).await.unwrap(); + let tracks = authorized_user + .find_related(Tracks) + .all(db) + .await + .map_err(Error::from)?; Ok(Json(tracks)) } +async fn get_track_check_user( + db: &DatabaseConnection, + track_id: i32, + user: &user::Model, +) -> Result, Either> { + if let Some(Some(user)) = user + .find_related(Tracks) + .filter(tracks::Column::Id.eq(track_id)) + .one(db) + .await + .transpose() + .map(|it| it.ok()) + { + Ok(Json(user)) + } else { + Err(Left(Status::NotFound)) + } +} + #[get("/")] pub(super) async fn track( db: &State, id: i32, + auth: Auth, ) -> Result, Either> { - let db = db as &DatabaseConnection; - match Tracks::find_by_id(id).one(db).await { - Ok(Some(track)) => Ok(Json(track)), - Ok(None) => Err(Left(Status::NotFound)), - Err(err) => Err(Right(Error::from(err).into())), - } + get_track_check_user(db, id, &*auth).await } #[get("//ticks")] pub(super) async fn ticks_for_track( db: &State, id: i32, + auth: Auth, ) -> Result>, Either> { let db = db as &DatabaseConnection; - match Tracks::find_by_id(id).one(db).await { - Ok(Some(track)) => { - let result = track.find_related(Ticks).all(db).await; - match result { - Ok(ticks) => Ok(Json(ticks)), - Err(err) => Err(Right(Error::from(err).into())), - } - } - Ok(None) => Err(Left(Status::NotFound)), + let track = get_track_check_user(db, id, &*auth).await?; + let result = track.find_related(Ticks).all(db).await; + match result { + Ok(ticks) => Ok(Json(ticks)), Err(err) => Err(Right(Error::from(err).into())), } } @@ -55,13 +76,59 @@ pub(super) async fn insert_track( db: &State, tx: &State>, track: Json, -) -> ApiResult> { - let track = track.0; - let db = db as &DatabaseConnection; - let model = tracks::ActiveModel::from_json(track).map_err(Error::from)?; - let track = model.insert(db).await.map_err(Error::from)?; + auth: Auth, +) -> Result, Either> { + fn bad() -> Either { + Left(Status::BadRequest) + } + let track = track.0.as_object().ok_or_else(bad)?; + let Some(track_id) = db + .query_one(Statement::from_sql_and_values( + sea_orm::DatabaseBackend::Postgres, + "insert into $1 (user_id, track_id) values ( + $2, ( + insert into $3 ( + name, description, icon, enabled, multiple_entries_per_day, + color, order + ) values ( + $4, $5, $6, $7, $8, $9, $10, + ) returning id + ) + ) returning track_id;", + [ + user_tracks::Entity::default().table_name().into(), + auth.id.into(), + tracks::Entity::default().table_name().into(), + track.get("name").ok_or_else(bad)?.as_str().ok_or_else(bad)?.into(), + track + .get("description") + .ok_or_else(bad)? + .as_str() + .ok_or_else(bad)? + .into(), + track.get("icon").ok_or_else(bad)?.as_str().ok_or_else(bad)?.into(), + track.get("enabled").ok_or_else(bad)?.as_i64().into(), + track + .get("multiple_entries_per_day") + .ok_or_else(bad)? + .as_i64() + .into(), + track.get("color").ok_or_else(bad)?.as_i64().into(), + track.get("order").ok_or_else(bad)?.as_i64().into(), + ], + )) + .await + .map_err(|err| Right(Error::from(err).into()))? else { + return Err(Right("no value returned from track insertion query".into())); + }; + let track_id = track_id + .try_get_by_index(0) + .map_err(|err| Right(Error::from(err).into()))?; + let track = auth.authorized_track(track_id, db).await.ok_or_else(|| { + Right(format!("failed to fetch freshly inserted track with id {track_id}").into()) + })?; tx.send(Update::track_added(track.clone())) - .map_err(Error::from)?; + .map_err(|err| Right(Error::from(err).into()))?; Ok(Json(track)) } @@ -69,16 +136,21 @@ pub(super) async fn insert_track( pub(super) async fn update_track( db: &State, tx: &State>, - track: Json, -) -> ApiResult> { + track: Json, + authorized_user: Auth, +) -> Result, Either> { let db = db as &DatabaseConnection; - let track = tracks::ActiveModel::from_json(track.0) - .map_err(Error::from)? + let track = track.0; + if !authorized_user.is_authorized_for(track.id, db).await { + return Err(Left(Status::Forbidden)); + } + let track = track + .into_active_model() .update(db) .await - .map_err(Error::from)?; + .map_err(|err| Right(Error::from(err).into()))?; tx.send(Update::track_changed(track.clone())) - .map_err(Error::from)?; + .map_err(|err| Right(Error::from(err).into()))?; Ok(Json(track)) } @@ -87,11 +159,13 @@ pub(super) async fn delete_track( db: &State, tx: &State>, id: i32, + authorized_user: Auth, ) -> ApiResult { let db = db as &DatabaseConnection; - let Some(track) = Tracks::find_by_id(id).one(db).await.map_err(Error::from)? else { + let Some(track) = authorized_user.authorized_track(id, db).await else { return Ok(Status::NotFound); }; + track.clone().delete(db).await.map_err(Error::from)?; tx.send(Update::track_removed(track)).map_err(Error::from)?; Ok(Status::Ok) } @@ -101,15 +175,20 @@ pub(super) async fn ticked( db: &State, tx: &State>, id: i32, -) -> ApiResult> { + authorized_user: Auth, +) -> Result, Either> { + if !authorized_user.is_authorized_for(id, db).await { + return Err(Left(Status::Forbidden)); + } + let tick = ticks::ActiveModel::now(id); let tick = tick .insert(db as &DatabaseConnection) .await - .map_err(Error::from)? + .map_err(|err| Right(Error::from(err).into()))? .to_owned(); tx.send(Update::tick_added(tick.clone())) - .map_err(Error::from)?; + .map_err(|err| Right(Error::from(err).into()))?; Ok(Json(tick)) } @@ -121,7 +200,12 @@ pub(super) async fn ticked_on_date( year: i32, month: u32, day: u32, + authorized_user: Auth, ) -> ApiResult, Status>> { + if !authorized_user.is_authorized_for(id, db).await { + return Ok(Right(Status::Forbidden)); + } + let Some(date) = Date::from_ymd_opt(year, month, day) else { return Ok(Right(Status::BadRequest)); }; @@ -141,10 +225,14 @@ pub(super) async fn clear_all_ticks( db: &State, tx: &State>, id: i32, + authorized_user: Auth, ) -> ApiResult>>> { let db = db as &DatabaseConnection; - let Some(track) = Tracks::find_by_id(id).one(db).await.map_err(Error::from)? else { - info!(track_id = id; "couldn't drop all ticks for track; track not found"); + let Some(track) = authorized_user.authorized_track(id, db).await else { + info!( + track_id = id, user_id = authorized_user.id; + "couldn't drop all ticks for track; track not found or user not authorized" + ); return Ok(Left(Status::NotFound)); }; let ticks = track @@ -167,8 +255,12 @@ pub(super) async fn clear_all_ticks_on_day( year: i32, month: u32, day: u32, -) -> ApiResult>> { + authorized_user: Auth, +) -> ApiResult>>> { let db = db as &DatabaseConnection; + if !authorized_user.is_authorized_for(id, db).await { + return Ok(Left(Status::Forbidden)); + } let ticks = Ticks::find() .filter(ticks::Column::TrackId.eq(id)) .filter(ticks::Column::Year.eq(year)) @@ -181,5 +273,5 @@ pub(super) async fn clear_all_ticks_on_day( tick.clone().delete(db).await.map_err(Error::from)?; Update::tick_cancelled(tick).send(&tx)?; } - Ok(Json(ticks)) + Ok(Right(Json(ticks))) } diff --git a/server/src/entities/user.rs b/server/src/entities/user.rs index 7ba561a..0eae7c1 100644 --- a/server/src/entities/user.rs +++ b/server/src/entities/user.rs @@ -14,6 +14,8 @@ use crate::{ error::{self, Error}, }; +use super::tracks; + #[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize)] #[sea_orm(table_name = "user")] pub struct Model { @@ -71,4 +73,27 @@ impl Model { Err(err) => Err(Right(Error::from(err).into())), } } + + pub async fn authorized_track( + &self, + track_id: i32, + db: &DatabaseConnection, + ) -> Option { + self.find_related(super::prelude::Tracks) + .filter(tracks::Column::Id.eq(track_id)) + .one(db) + .await + .ok() + .flatten() + } + pub async fn is_authorized_for(&self, track_id: i32, db: &DatabaseConnection) -> bool { + self.authorized_track(track_id, db).await.is_some() + } + + pub async fn authorized_tracks(&self, db: &DatabaseConnection) -> Vec { + self.find_related(super::prelude::Tracks) + .all(db) + .await + .unwrap_or_default() + } } diff --git a/server/src/error.rs b/server/src/error.rs index fa39e1f..a908650 100644 --- a/server/src/error.rs +++ b/server/src/error.rs @@ -21,6 +21,8 @@ pub enum Error { ChannelSendError(#[from] tokio::sync::broadcast::error::SendError), #[error(transparent)] Bcrypt(#[from] BcryptError), + #[error(transparent)] + SerdeJson(#[from] serde_json::Error), } pub type Result = std::result::Result; diff --git a/server/src/main.rs b/server/src/main.rs index 8959895..d6aad4d 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -1,4 +1,4 @@ -#![feature(default_free_fn, proc_macro_hygiene, decl_macro)] +#![feature(default_free_fn, proc_macro_hygiene, decl_macro, never_type)] #[macro_use] extern crate rocket; mod api;