tracks are now relative to an authenticated user

This commit is contained in:
D. Scott Boggs 2023-06-26 15:59:51 -04:00
parent 46a9374571
commit f44f15d2b6
7 changed files with 189 additions and 47 deletions

12
server/Cargo.lock generated
View file

@ -546,6 +546,17 @@ dependencies = [
"syn 1.0.109", "syn 1.0.109",
] ]
[[package]]
name = "derive_deref"
version = "1.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dcdbcee2d9941369faba772587a565f4f534e42cb8d17e5295871de730163b2b"
dependencies = [
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]] [[package]]
name = "devise" name = "devise"
version = "0.4.1" version = "0.4.1"
@ -1155,6 +1166,7 @@ dependencies = [
"bcrypt", "bcrypt",
"chrono", "chrono",
"derive_builder", "derive_builder",
"derive_deref",
"either", "either",
"femme", "femme",
"log", "log",

View file

@ -15,6 +15,7 @@ path = "src/main.rs"
[dependencies] [dependencies]
bcrypt = "0.14.0" bcrypt = "0.14.0"
chrono = "0.4.26" chrono = "0.4.26"
derive_deref = "1.1.1"
femme = "2.2.1" femme = "2.2.1"
log = { version = "0.4.19", features = ["kv_unstable", "kv_unstable_serde"] } log = { version = "0.4.19", features = ["kv_unstable", "kv_unstable_serde"] }
sea-orm-migration = "0.11.3" sea-orm-migration = "0.11.3"

View file

@ -1,3 +1,4 @@
use derive_deref::Deref;
use log::warn; use log::warn;
use rocket::{ use rocket::{
http::{Cookie, CookieJar, Status}, http::{Cookie, CookieJar, Status},
@ -38,7 +39,10 @@ pub(super) async fn login(
let Some(user) = users.get(0) else { let Some(user) = users.get(0) else {
return Ok(Status::Unauthorized); return Ok(Status::Unauthorized);
}; };
cookies.add_private(Cookie::new("user_id", user.id.to_string())); cookies.add_private(Cookie::new(
"user",
serde_json::to_string(&user).map_err(Error::from)?,
));
Ok(Status::Ok) Ok(Status::Ok)
} }
@ -52,22 +56,28 @@ pub(super) async fn sign_up(
.insert(db as &DatabaseConnection) .insert(db as &DatabaseConnection)
.await .await
.map_err(Error::from)?; .map_err(Error::from)?;
cookies.add_private(Cookie::new("user_id", user_data.id.to_string())); cookies.add_private(Cookie::new(
"user",
serde_json::to_string(&user_data).map_err(Error::from)?,
));
Ok(()) Ok(())
} }
/// Authentication guard /// Authentication guard
struct Auth(i32); #[derive(Deref)]
pub(super) struct Auth(user::Model);
#[rocket::async_trait] #[rocket::async_trait]
impl<'r> FromRequest<'r> for Auth { impl<'r> FromRequest<'r> for Auth {
type Error = (); type Error = ();
async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, Self::Error> { async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
request let unauthorized = (Status::Unauthorized, ());
.cookies() let Some(user) = request.cookies().get_private("user") else {
.get_private("user_id") return request::Outcome::Failure(unauthorized);
.and_then(|val| val.value().parse().ok()) };
.map(|id| Auth(id)) serde_json::from_str(user.value())
.into_outcome((Status::Unauthorized, ())) .ok()
.map(|user| Auth(user))
.into_outcome(unauthorized)
} }
} }

View file

@ -1,51 +1,72 @@
use std::convert::Infallible;
use std::default::default;
use crate::api::auth::Auth;
use crate::api::{self, error::ApiResult}; use crate::api::{self, error::ApiResult};
use crate::entities::{prelude::*, *}; use crate::entities::{prelude::*, *};
use crate::error::Error; use crate::error::Error;
use either::Either::{self, Left, Right}; use either::Either::{self, Left, Right};
use rocket::http::Status; use rocket::http::Status;
use rocket::{serde::json::Json, State}; use rocket::{serde::json::Json, State};
use sea_orm::{prelude::*, DatabaseConnection}; use sea_orm::{prelude::*, DatabaseConnection, IntoActiveModel, Statement, TryIntoModel};
use tokio::sync::broadcast::Sender; use tokio::sync::broadcast::Sender;
use super::update::Update; use super::update::Update;
use super::ErrorResponder;
#[get("/")] #[get("/")]
pub(super) async fn all_tracks( pub(super) async fn all_tracks(
db: &State<DatabaseConnection>, db: &State<DatabaseConnection>,
authorized_user: Auth,
) -> ApiResult<Json<Vec<tracks::Model>>> { ) -> ApiResult<Json<Vec<tracks::Model>>> {
let db = db as &DatabaseConnection; let db = db as &DatabaseConnection;
let tracks = Tracks::find().all(db).await.unwrap(); let tracks = authorized_user
.find_related(Tracks)
.all(db)
.await
.map_err(Error::from)?;
Ok(Json(tracks)) Ok(Json(tracks))
} }
async fn get_track_check_user(
db: &DatabaseConnection,
track_id: i32,
user: &user::Model,
) -> Result<Json<tracks::Model>, Either<Status, api::ErrorResponder>> {
if let Some(Some(user)) = user
.find_related(Tracks)
.filter(tracks::Column::Id.eq(track_id))
.one(db)
.await
.transpose()
.map(|it| it.ok())
{
Ok(Json(user))
} else {
Err(Left(Status::NotFound))
}
}
#[get("/<id>")] #[get("/<id>")]
pub(super) async fn track( pub(super) async fn track(
db: &State<DatabaseConnection>, db: &State<DatabaseConnection>,
id: i32, id: i32,
auth: Auth,
) -> Result<Json<tracks::Model>, Either<Status, api::ErrorResponder>> { ) -> Result<Json<tracks::Model>, Either<Status, api::ErrorResponder>> {
let db = db as &DatabaseConnection; get_track_check_user(db, id, &*auth).await
match Tracks::find_by_id(id).one(db).await {
Ok(Some(track)) => Ok(Json(track)),
Ok(None) => Err(Left(Status::NotFound)),
Err(err) => Err(Right(Error::from(err).into())),
}
} }
#[get("/<id>/ticks")] #[get("/<id>/ticks")]
pub(super) async fn ticks_for_track( pub(super) async fn ticks_for_track(
db: &State<DatabaseConnection>, db: &State<DatabaseConnection>,
id: i32, id: i32,
auth: Auth,
) -> Result<Json<Vec<ticks::Model>>, Either<Status, api::ErrorResponder>> { ) -> Result<Json<Vec<ticks::Model>>, Either<Status, api::ErrorResponder>> {
let db = db as &DatabaseConnection; let db = db as &DatabaseConnection;
match Tracks::find_by_id(id).one(db).await { let track = get_track_check_user(db, id, &*auth).await?;
Ok(Some(track)) => { let result = track.find_related(Ticks).all(db).await;
let result = track.find_related(Ticks).all(db).await; match result {
match result { Ok(ticks) => Ok(Json(ticks)),
Ok(ticks) => Ok(Json(ticks)),
Err(err) => Err(Right(Error::from(err).into())),
}
}
Ok(None) => Err(Left(Status::NotFound)),
Err(err) => Err(Right(Error::from(err).into())), Err(err) => Err(Right(Error::from(err).into())),
} }
} }
@ -55,13 +76,59 @@ pub(super) async fn insert_track(
db: &State<DatabaseConnection>, db: &State<DatabaseConnection>,
tx: &State<Sender<Update>>, tx: &State<Sender<Update>>,
track: Json<serde_json::Value>, track: Json<serde_json::Value>,
) -> ApiResult<Json<tracks::Model>> { auth: Auth,
let track = track.0; ) -> Result<Json<tracks::Model>, Either<Status, ErrorResponder>> {
let db = db as &DatabaseConnection; fn bad() -> Either<Status, ErrorResponder> {
let model = tracks::ActiveModel::from_json(track).map_err(Error::from)?; Left(Status::BadRequest)
let track = model.insert(db).await.map_err(Error::from)?; }
let track = track.0.as_object().ok_or_else(bad)?;
let Some(track_id) = db
.query_one(Statement::from_sql_and_values(
sea_orm::DatabaseBackend::Postgres,
"insert into $1 (user_id, track_id) values (
$2, (
insert into $3 (
name, description, icon, enabled, multiple_entries_per_day,
color, order
) values (
$4, $5, $6, $7, $8, $9, $10,
) returning id
)
) returning track_id;",
[
user_tracks::Entity::default().table_name().into(),
auth.id.into(),
tracks::Entity::default().table_name().into(),
track.get("name").ok_or_else(bad)?.as_str().ok_or_else(bad)?.into(),
track
.get("description")
.ok_or_else(bad)?
.as_str()
.ok_or_else(bad)?
.into(),
track.get("icon").ok_or_else(bad)?.as_str().ok_or_else(bad)?.into(),
track.get("enabled").ok_or_else(bad)?.as_i64().into(),
track
.get("multiple_entries_per_day")
.ok_or_else(bad)?
.as_i64()
.into(),
track.get("color").ok_or_else(bad)?.as_i64().into(),
track.get("order").ok_or_else(bad)?.as_i64().into(),
],
))
.await
.map_err(|err| Right(Error::from(err).into()))? else {
return Err(Right("no value returned from track insertion query".into()));
};
let track_id = track_id
.try_get_by_index(0)
.map_err(|err| Right(Error::from(err).into()))?;
let track = auth.authorized_track(track_id, db).await.ok_or_else(|| {
Right(format!("failed to fetch freshly inserted track with id {track_id}").into())
})?;
tx.send(Update::track_added(track.clone())) tx.send(Update::track_added(track.clone()))
.map_err(Error::from)?; .map_err(|err| Right(Error::from(err).into()))?;
Ok(Json(track)) Ok(Json(track))
} }
@ -69,16 +136,21 @@ pub(super) async fn insert_track(
pub(super) async fn update_track( pub(super) async fn update_track(
db: &State<DatabaseConnection>, db: &State<DatabaseConnection>,
tx: &State<Sender<Update>>, tx: &State<Sender<Update>>,
track: Json<serde_json::Value>, track: Json<tracks::Model>,
) -> ApiResult<Json<tracks::Model>> { authorized_user: Auth,
) -> Result<Json<tracks::Model>, Either<Status, api::ErrorResponder>> {
let db = db as &DatabaseConnection; let db = db as &DatabaseConnection;
let track = tracks::ActiveModel::from_json(track.0) let track = track.0;
.map_err(Error::from)? if !authorized_user.is_authorized_for(track.id, db).await {
return Err(Left(Status::Forbidden));
}
let track = track
.into_active_model()
.update(db) .update(db)
.await .await
.map_err(Error::from)?; .map_err(|err| Right(Error::from(err).into()))?;
tx.send(Update::track_changed(track.clone())) tx.send(Update::track_changed(track.clone()))
.map_err(Error::from)?; .map_err(|err| Right(Error::from(err).into()))?;
Ok(Json(track)) Ok(Json(track))
} }
@ -87,11 +159,13 @@ pub(super) async fn delete_track(
db: &State<DatabaseConnection>, db: &State<DatabaseConnection>,
tx: &State<Sender<Update>>, tx: &State<Sender<Update>>,
id: i32, id: i32,
authorized_user: Auth,
) -> ApiResult<Status> { ) -> ApiResult<Status> {
let db = db as &DatabaseConnection; let db = db as &DatabaseConnection;
let Some(track) = Tracks::find_by_id(id).one(db).await.map_err(Error::from)? else { let Some(track) = authorized_user.authorized_track(id, db).await else {
return Ok(Status::NotFound); return Ok(Status::NotFound);
}; };
track.clone().delete(db).await.map_err(Error::from)?;
tx.send(Update::track_removed(track)).map_err(Error::from)?; tx.send(Update::track_removed(track)).map_err(Error::from)?;
Ok(Status::Ok) Ok(Status::Ok)
} }
@ -101,15 +175,20 @@ pub(super) async fn ticked(
db: &State<DatabaseConnection>, db: &State<DatabaseConnection>,
tx: &State<Sender<Update>>, tx: &State<Sender<Update>>,
id: i32, id: i32,
) -> ApiResult<Json<ticks::Model>> { authorized_user: Auth,
) -> Result<Json<ticks::Model>, Either<Status, api::ErrorResponder>> {
if !authorized_user.is_authorized_for(id, db).await {
return Err(Left(Status::Forbidden));
}
let tick = ticks::ActiveModel::now(id); let tick = ticks::ActiveModel::now(id);
let tick = tick let tick = tick
.insert(db as &DatabaseConnection) .insert(db as &DatabaseConnection)
.await .await
.map_err(Error::from)? .map_err(|err| Right(Error::from(err).into()))?
.to_owned(); .to_owned();
tx.send(Update::tick_added(tick.clone())) tx.send(Update::tick_added(tick.clone()))
.map_err(Error::from)?; .map_err(|err| Right(Error::from(err).into()))?;
Ok(Json(tick)) Ok(Json(tick))
} }
@ -121,7 +200,12 @@ pub(super) async fn ticked_on_date(
year: i32, year: i32,
month: u32, month: u32,
day: u32, day: u32,
authorized_user: Auth,
) -> ApiResult<Either<Json<ticks::Model>, Status>> { ) -> ApiResult<Either<Json<ticks::Model>, Status>> {
if !authorized_user.is_authorized_for(id, db).await {
return Ok(Right(Status::Forbidden));
}
let Some(date) = Date::from_ymd_opt(year, month, day) else { let Some(date) = Date::from_ymd_opt(year, month, day) else {
return Ok(Right(Status::BadRequest)); return Ok(Right(Status::BadRequest));
}; };
@ -141,10 +225,14 @@ pub(super) async fn clear_all_ticks(
db: &State<DatabaseConnection>, db: &State<DatabaseConnection>,
tx: &State<Sender<Update>>, tx: &State<Sender<Update>>,
id: i32, id: i32,
authorized_user: Auth,
) -> ApiResult<Either<Status, Json<Vec<ticks::Model>>>> { ) -> ApiResult<Either<Status, Json<Vec<ticks::Model>>>> {
let db = db as &DatabaseConnection; let db = db as &DatabaseConnection;
let Some(track) = Tracks::find_by_id(id).one(db).await.map_err(Error::from)? else { let Some(track) = authorized_user.authorized_track(id, db).await else {
info!(track_id = id; "couldn't drop all ticks for track; track not found"); info!(
track_id = id, user_id = authorized_user.id;
"couldn't drop all ticks for track; track not found or user not authorized"
);
return Ok(Left(Status::NotFound)); return Ok(Left(Status::NotFound));
}; };
let ticks = track let ticks = track
@ -167,8 +255,12 @@ pub(super) async fn clear_all_ticks_on_day(
year: i32, year: i32,
month: u32, month: u32,
day: u32, day: u32,
) -> ApiResult<Json<Vec<ticks::Model>>> { authorized_user: Auth,
) -> ApiResult<Either<Status, Json<Vec<ticks::Model>>>> {
let db = db as &DatabaseConnection; let db = db as &DatabaseConnection;
if !authorized_user.is_authorized_for(id, db).await {
return Ok(Left(Status::Forbidden));
}
let ticks = Ticks::find() let ticks = Ticks::find()
.filter(ticks::Column::TrackId.eq(id)) .filter(ticks::Column::TrackId.eq(id))
.filter(ticks::Column::Year.eq(year)) .filter(ticks::Column::Year.eq(year))
@ -181,5 +273,5 @@ pub(super) async fn clear_all_ticks_on_day(
tick.clone().delete(db).await.map_err(Error::from)?; tick.clone().delete(db).await.map_err(Error::from)?;
Update::tick_cancelled(tick).send(&tx)?; Update::tick_cancelled(tick).send(&tx)?;
} }
Ok(Json(ticks)) Ok(Right(Json(ticks)))
} }

View file

@ -14,6 +14,8 @@ use crate::{
error::{self, Error}, error::{self, Error},
}; };
use super::tracks;
#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize)] #[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize)]
#[sea_orm(table_name = "user")] #[sea_orm(table_name = "user")]
pub struct Model { pub struct Model {
@ -71,4 +73,27 @@ impl Model {
Err(err) => Err(Right(Error::from(err).into())), Err(err) => Err(Right(Error::from(err).into())),
} }
} }
pub async fn authorized_track(
&self,
track_id: i32,
db: &DatabaseConnection,
) -> Option<tracks::Model> {
self.find_related(super::prelude::Tracks)
.filter(tracks::Column::Id.eq(track_id))
.one(db)
.await
.ok()
.flatten()
}
pub async fn is_authorized_for(&self, track_id: i32, db: &DatabaseConnection) -> bool {
self.authorized_track(track_id, db).await.is_some()
}
pub async fn authorized_tracks(&self, db: &DatabaseConnection) -> Vec<tracks::Model> {
self.find_related(super::prelude::Tracks)
.all(db)
.await
.unwrap_or_default()
}
} }

View file

@ -21,6 +21,8 @@ pub enum Error {
ChannelSendError(#[from] tokio::sync::broadcast::error::SendError<crate::api::update::Update>), ChannelSendError(#[from] tokio::sync::broadcast::error::SendError<crate::api::update::Update>),
#[error(transparent)] #[error(transparent)]
Bcrypt(#[from] BcryptError), Bcrypt(#[from] BcryptError),
#[error(transparent)]
SerdeJson(#[from] serde_json::Error),
} }
pub type Result<T> = std::result::Result<T, Error>; pub type Result<T> = std::result::Result<T, Error>;

View file

@ -1,4 +1,4 @@
#![feature(default_free_fn, proc_macro_hygiene, decl_macro)] #![feature(default_free_fn, proc_macro_hygiene, decl_macro, never_type)]
#[macro_use] #[macro_use]
extern crate rocket; extern crate rocket;
mod api; mod api;