Fix authenticated track insertion
This commit is contained in:
parent
205a3b165e
commit
d7285a84bb
|
@ -1,5 +1,8 @@
|
||||||
|
|
||||||
|
|
||||||
use derive_deref::Deref;
|
use derive_deref::Deref;
|
||||||
use log::warn;
|
use either::Either::{self, Right};
|
||||||
|
|
||||||
use rocket::{
|
use rocket::{
|
||||||
http::{Cookie, CookieJar, Status},
|
http::{Cookie, CookieJar, Status},
|
||||||
outcome::IntoOutcome,
|
outcome::IntoOutcome,
|
||||||
|
@ -16,6 +19,8 @@ use crate::{
|
||||||
error::Error,
|
error::Error,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use super::ErrorResponder;
|
||||||
|
|
||||||
#[derive(Clone, Deserialize)]
|
#[derive(Clone, Deserialize)]
|
||||||
pub(super) struct LoginData {
|
pub(super) struct LoginData {
|
||||||
name: String,
|
name: String,
|
||||||
|
@ -27,22 +32,22 @@ pub(super) async fn login(
|
||||||
db: &State<DatabaseConnection>,
|
db: &State<DatabaseConnection>,
|
||||||
user_data: Json<LoginData>,
|
user_data: Json<LoginData>,
|
||||||
cookies: &CookieJar<'_>,
|
cookies: &CookieJar<'_>,
|
||||||
) -> ApiResult<Status> {
|
) -> Result<Status, Either<Status, ErrorResponder>> {
|
||||||
let users = User::find()
|
let user = Users::find()
|
||||||
.filter(user::Column::Name.eq(&user_data.name))
|
.filter(users::Column::Name.eq(&user_data.name))
|
||||||
.all(db as &DatabaseConnection)
|
.one(db as &DatabaseConnection)
|
||||||
.await
|
.await
|
||||||
.map_err(Error::from)?;
|
.map_err(|err| Right(Error::from(err).into()))?;
|
||||||
if users.len() > 1 {
|
let Some(user) = user else {
|
||||||
warn!(count = users.len(), name = &user_data.name; "multiple entries found in database for user");
|
info!(name = user_data.name; "no user found with the given name");
|
||||||
}
|
|
||||||
let Some(user) = users.get(0) else {
|
|
||||||
return Ok(Status::Unauthorized);
|
return Ok(Status::Unauthorized);
|
||||||
};
|
};
|
||||||
|
let user = user.check_password(&user_data.password)?;
|
||||||
cookies.add_private(Cookie::new(
|
cookies.add_private(Cookie::new(
|
||||||
"user",
|
"user",
|
||||||
serde_json::to_string(&user).map_err(Error::from)?,
|
serde_json::to_string(&user).map_err(|err| Right(Error::from(err).into()))?,
|
||||||
));
|
));
|
||||||
|
cookies.add(Cookie::new("name", user.name));
|
||||||
Ok(Status::Ok)
|
Ok(Status::Ok)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -60,12 +65,13 @@ pub(super) async fn sign_up(
|
||||||
"user",
|
"user",
|
||||||
serde_json::to_string(&user_data).map_err(Error::from)?,
|
serde_json::to_string(&user_data).map_err(Error::from)?,
|
||||||
));
|
));
|
||||||
|
cookies.add(Cookie::new("name", user_data.name));
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Authentication guard
|
/// Authentication guard
|
||||||
#[derive(Deref)]
|
#[derive(Deref)]
|
||||||
pub(super) struct Auth(user::Model);
|
pub(super) struct Auth(users::Model);
|
||||||
|
|
||||||
#[rocket::async_trait]
|
#[rocket::async_trait]
|
||||||
impl<'r> FromRequest<'r> for Auth {
|
impl<'r> FromRequest<'r> for Auth {
|
||||||
|
@ -77,7 +83,7 @@ impl<'r> FromRequest<'r> for Auth {
|
||||||
};
|
};
|
||||||
serde_json::from_str(user.value())
|
serde_json::from_str(user.value())
|
||||||
.ok()
|
.ok()
|
||||||
.map(|user| Auth(user))
|
.map(Auth)
|
||||||
.into_outcome(unauthorized)
|
.into_outcome(unauthorized)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
use std::convert::Infallible;
|
|
||||||
use std::default::default;
|
|
||||||
|
|
||||||
use crate::api::auth::Auth;
|
use crate::api::auth::Auth;
|
||||||
use crate::api::{self, error::ApiResult};
|
use crate::api::{self, error::ApiResult};
|
||||||
|
@ -8,7 +8,7 @@ use crate::error::Error;
|
||||||
use either::Either::{self, Left, Right};
|
use either::Either::{self, Left, Right};
|
||||||
use rocket::http::Status;
|
use rocket::http::Status;
|
||||||
use rocket::{serde::json::Json, State};
|
use rocket::{serde::json::Json, State};
|
||||||
use sea_orm::{prelude::*, DatabaseConnection, IntoActiveModel, Statement, TryIntoModel};
|
use sea_orm::{prelude::*, DatabaseConnection, IntoActiveModel, Statement};
|
||||||
use tokio::sync::broadcast::Sender;
|
use tokio::sync::broadcast::Sender;
|
||||||
|
|
||||||
use super::update::Update;
|
use super::update::Update;
|
||||||
|
@ -53,7 +53,7 @@ pub(super) async fn track(
|
||||||
id: i32,
|
id: i32,
|
||||||
auth: Auth,
|
auth: Auth,
|
||||||
) -> Result<Json<tracks::Model>, Either<Status, api::ErrorResponder>> {
|
) -> Result<Json<tracks::Model>, Either<Status, api::ErrorResponder>> {
|
||||||
get_track_check_user(db, id, &*auth).await
|
get_track_check_user(db, id, &auth).await
|
||||||
}
|
}
|
||||||
|
|
||||||
#[get("/<id>/ticks")]
|
#[get("/<id>/ticks")]
|
||||||
|
@ -63,7 +63,7 @@ pub(super) async fn ticks_for_track(
|
||||||
auth: Auth,
|
auth: Auth,
|
||||||
) -> Result<Json<Vec<ticks::Model>>, Either<Status, api::ErrorResponder>> {
|
) -> Result<Json<Vec<ticks::Model>>, Either<Status, api::ErrorResponder>> {
|
||||||
let db = db as &DatabaseConnection;
|
let db = db as &DatabaseConnection;
|
||||||
let track = get_track_check_user(db, id, &*auth).await?;
|
let track = get_track_check_user(db, id, &auth).await?;
|
||||||
let result = track.find_related(Ticks).all(db).await;
|
let result = track.find_related(Ticks).all(db).await;
|
||||||
match result {
|
match result {
|
||||||
Ok(ticks) => Ok(Json(ticks)),
|
Ok(ticks) => Ok(Json(ticks)),
|
||||||
|
@ -81,40 +81,48 @@ pub(super) async fn insert_track(
|
||||||
fn bad() -> Either<Status, ErrorResponder> {
|
fn bad() -> Either<Status, ErrorResponder> {
|
||||||
Left(Status::BadRequest)
|
Left(Status::BadRequest)
|
||||||
}
|
}
|
||||||
let track = track.0.as_object().ok_or_else(bad)?;
|
fn bad_value_for(key: &'static str) -> impl Fn() -> Either<Status, ErrorResponder> {
|
||||||
|
move || {
|
||||||
|
warn!(key = key; "bad value");
|
||||||
|
bad()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let track = track.0.as_object().ok_or_else(|| {
|
||||||
|
warn!("received value was not an object");
|
||||||
|
bad()
|
||||||
|
})?;
|
||||||
let Some(track_id) = db
|
let Some(track_id) = db
|
||||||
.query_one(Statement::from_sql_and_values(
|
.query_one(Statement::from_sql_and_values(
|
||||||
sea_orm::DatabaseBackend::Postgres,
|
sea_orm::DatabaseBackend::Postgres,
|
||||||
"insert into $1 (user_id, track_id) values (
|
r#"with track_insertion as (
|
||||||
$2, (
|
insert into tracks (name, description, icon, enabled,
|
||||||
insert into $3 (
|
multiple_entries_per_day, color, "order"
|
||||||
name, description, icon, enabled, multiple_entries_per_day,
|
|
||||||
color, order
|
|
||||||
) values (
|
) values (
|
||||||
$4, $5, $6, $7, $8, $9, $10,
|
$2, $3, $4, $5, $6, $7, $8
|
||||||
) returning id
|
) returning id
|
||||||
)
|
)
|
||||||
) returning track_id;",
|
insert into user_tracks (
|
||||||
|
user_id, track_id
|
||||||
|
) select $1, ti.id
|
||||||
|
from track_insertion ti
|
||||||
|
join track_insertion using (id);"#,
|
||||||
[
|
[
|
||||||
user_tracks::Entity::default().table_name().into(),
|
|
||||||
auth.id.into(),
|
auth.id.into(),
|
||||||
tracks::Entity::default().table_name().into(),
|
track.get("name").ok_or_else(bad_value_for("name"))?.as_str().ok_or_else(bad_value_for("name"))?.into(),
|
||||||
track.get("name").ok_or_else(bad)?.as_str().ok_or_else(bad)?.into(),
|
|
||||||
track
|
track
|
||||||
.get("description")
|
.get("description")
|
||||||
.ok_or_else(bad)?
|
.ok_or_else(bad_value_for("description"))?
|
||||||
.as_str()
|
.as_str()
|
||||||
.ok_or_else(bad)?
|
.ok_or_else(bad_value_for("description"))?
|
||||||
.into(),
|
.into(),
|
||||||
track.get("icon").ok_or_else(bad)?.as_str().ok_or_else(bad)?.into(),
|
track.get("icon").ok_or_else(bad_value_for("icon"))?.as_str().ok_or_else(bad_value_for("icon"))?.into(),
|
||||||
track.get("enabled").ok_or_else(bad)?.as_i64().into(),
|
track.get("enabled").and_then(|it| it.as_i64()).into(),
|
||||||
track
|
track
|
||||||
.get("multiple_entries_per_day")
|
.get("multiple_entries_per_day")
|
||||||
.ok_or_else(bad)?
|
.and_then(|it| it.as_i64())
|
||||||
.as_i64()
|
|
||||||
.into(),
|
.into(),
|
||||||
track.get("color").ok_or_else(bad)?.as_i64().into(),
|
track.get("color").and_then(|it| it.as_i64()).into(),
|
||||||
track.get("order").ok_or_else(bad)?.as_i64().into(),
|
track.get("order").and_then(|it| it.as_i64()).into(),
|
||||||
],
|
],
|
||||||
))
|
))
|
||||||
.await
|
.await
|
||||||
|
@ -186,7 +194,7 @@ pub(super) async fn ticked(
|
||||||
.insert(db as &DatabaseConnection)
|
.insert(db as &DatabaseConnection)
|
||||||
.await
|
.await
|
||||||
.map_err(|err| Right(Error::from(err).into()))?
|
.map_err(|err| Right(Error::from(err).into()))?
|
||||||
.to_owned();
|
;
|
||||||
tx.send(Update::tick_added(tick.clone()))
|
tx.send(Update::tick_added(tick.clone()))
|
||||||
.map_err(|err| Right(Error::from(err).into()))?;
|
.map_err(|err| Right(Error::from(err).into()))?;
|
||||||
Ok(Json(tick))
|
Ok(Json(tick))
|
||||||
|
@ -214,7 +222,7 @@ pub(super) async fn ticked_on_date(
|
||||||
.insert(db as &DatabaseConnection)
|
.insert(db as &DatabaseConnection)
|
||||||
.await
|
.await
|
||||||
.map_err(Error::from)?
|
.map_err(Error::from)?
|
||||||
.to_owned();
|
;
|
||||||
tx.send(Update::tick_added(tick.clone()))
|
tx.send(Update::tick_added(tick.clone()))
|
||||||
.map_err(Error::from)?;
|
.map_err(Error::from)?;
|
||||||
Ok(Left(Json(tick)))
|
Ok(Left(Json(tick)))
|
||||||
|
@ -242,7 +250,7 @@ pub(super) async fn clear_all_ticks(
|
||||||
.map_err(Error::from)?;
|
.map_err(Error::from)?;
|
||||||
for tick in ticks.clone() {
|
for tick in ticks.clone() {
|
||||||
tick.clone().delete(db).await.map_err(Error::from)?;
|
tick.clone().delete(db).await.map_err(Error::from)?;
|
||||||
Update::tick_cancelled(tick).send(&tx)?;
|
Update::tick_cancelled(tick).send(tx)?;
|
||||||
}
|
}
|
||||||
Ok(Right(Json(ticks)))
|
Ok(Right(Json(ticks)))
|
||||||
}
|
}
|
||||||
|
@ -271,7 +279,7 @@ pub(super) async fn clear_all_ticks_on_day(
|
||||||
.map_err(Error::from)?;
|
.map_err(Error::from)?;
|
||||||
for tick in ticks.clone() {
|
for tick in ticks.clone() {
|
||||||
tick.clone().delete(db).await.map_err(Error::from)?;
|
tick.clone().delete(db).await.map_err(Error::from)?;
|
||||||
Update::tick_cancelled(tick).send(&tx)?;
|
Update::tick_cancelled(tick).send(tx)?;
|
||||||
}
|
}
|
||||||
Ok(Right(Json(ticks)))
|
Ok(Right(Json(ticks)))
|
||||||
}
|
}
|
||||||
|
|
|
@ -65,8 +65,11 @@ impl ActiveModel {
|
||||||
impl Model {
|
impl Model {
|
||||||
pub fn check_password(
|
pub fn check_password(
|
||||||
self,
|
self,
|
||||||
|
password: impl AsRef<[u8]>,
|
||||||
|
) -> std::result::Result<Self, Either<Status, ErrorResponder>> {
|
||||||
match verify(password, &self.password_hash) {
|
match verify(password, &self.password_hash) {
|
||||||
Ok(true) => Ok(self),
|
Ok(true) => Ok(self),
|
||||||
|
Ok(false) => Err(Left(Status::Unauthorized)),
|
||||||
Err(err) => Err(Right(Error::from(err).into())),
|
Err(err) => Err(Right(Error::from(err).into())),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue