add tests

This commit is contained in:
D. Scott Boggs 2023-07-22 15:46:52 -04:00
parent 1c400e7ffa
commit a8e4e5145b
17 changed files with 283 additions and 66 deletions

14
server/Cargo.lock generated
View file

@ -1177,6 +1177,7 @@ dependencies = [
"serde_json",
"thiserror",
"tokio",
"tokio-test",
]
[[package]]
@ -2637,6 +2638,19 @@ dependencies = [
"tokio",
]
[[package]]
name = "tokio-test"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "53474327ae5e166530d17f2d956afcb4f8a004de581b3cae10f12006bc8163e3"
dependencies = [
"async-stream",
"bytes",
"futures-core",
"tokio",
"tokio-stream",
]
[[package]]
name = "tokio-util"
version = "0.7.8"

View file

@ -21,6 +21,7 @@ log = { version = "0.4.19", features = ["kv_unstable", "kv_unstable_serde"] }
sea-orm-migration = "0.11.3"
serde_json = "1.0.96"
thiserror = "1.0.40"
tokio-test = "0.4.2"
[dependencies.derive_builder]
version = "0.12.0"

View file

@ -1,9 +0,0 @@
# DEVELOPMENT shell environment
{ pkgs ? import <nixpkgs> {} }:
pkgs.mkShell {
nativeBuildInputs = with pkgs.buildPackages; [
clang
];
}

View file

@ -1,8 +1,7 @@
use derive_deref::Deref;
use either::Either::{self, Right};
use log::{as_debug, as_serde, debug};
use rocket::{
http::{Cookie, CookieJar, Status},
outcome::IntoOutcome,
@ -11,7 +10,7 @@ use rocket::{
Request, State,
};
use sea_orm::{prelude::*, DatabaseConnection};
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use crate::{
api::error::ApiResult,
@ -21,10 +20,10 @@ use crate::{
use super::ErrorResponder;
#[derive(Clone, Deserialize)]
pub(super) struct LoginData {
name: String,
password: String,
#[derive(Clone, Deserialize, Serialize)]
pub struct LoginData {
pub name: String,
pub password: String,
}
#[put("/", data = "<user_data>", format = "application/json")]
@ -61,6 +60,7 @@ pub(super) async fn sign_up(
.insert(db as &DatabaseConnection)
.await
.map_err(Error::from)?;
debug!(user = as_serde!(user_data); "user added");
cookies.add_private(Cookie::new(
"user",
serde_json::to_string(&user_data).map_err(Error::from)?,
@ -73,6 +73,23 @@ pub(super) async fn sign_up(
#[derive(Deref)]
pub(super) struct Auth(users::Model);
#[derive(Deserialize)]
struct AuthData {
id: i32,
name: String,
password_hash: String,
}
impl From<AuthData> for Auth {
fn from(value: AuthData) -> Self {
Auth(users::Model {
id: value.id,
name: value.name,
password_hash: value.password_hash,
})
}
}
#[rocket::async_trait]
impl<'r> FromRequest<'r> for Auth {
type Error = ();
@ -81,9 +98,13 @@ impl<'r> FromRequest<'r> for Auth {
let Some(user) = request.cookies().get_private("user") else {
return request::Outcome::Failure(unauthorized);
};
serde_json::from_str(user.value())
let user = user.value();
debug!(user = user; "user retreived from private cookie");
let result = serde_json::from_str(user)
.ok()
.map(Auth)
.into_outcome(unauthorized)
.map(|model: AuthData| model.into())
.into_outcome(unauthorized);
debug!(result = as_debug!(result); "auth FromRequest return value");
result
}
}

View file

@ -28,6 +28,8 @@ use tokio::sync::broadcast::{self, error::RecvError, Sender};
use self::{error::ApiResult, update::Update};
use log::{as_debug, as_serde, debug, trace};
pub use auth::LoginData;
#[get("/status")]
fn status() -> &'static str {
"Ok"
@ -74,7 +76,7 @@ fn get_secret() -> [u8; 32] {
data
}
pub(crate) fn start_server(db: DatabaseConnection) -> Rocket<Build> {
pub fn start_server(db: DatabaseConnection) -> Rocket<Build> {
use groups::*;
use ticks::*;
use tracks::*;

View file

@ -1,11 +1,10 @@
use crate::api::auth::Auth;
use crate::api::{self, error::ApiResult};
use crate::entities::{prelude::*, *};
use crate::error::Error;
use either::Either::{self, Left, Right};
use log::as_debug;
use log::{as_serde, debug, warn};
use rocket::http::Status;
use rocket::{serde::json::Json, State};
use sea_orm::{prelude::*, DatabaseConnection, IntoActiveModel, Statement};
@ -78,6 +77,11 @@ pub(super) async fn insert_track(
track: Json<serde_json::Value>,
auth: Auth,
) -> Result<Json<tracks::Model>, Either<Status, ErrorResponder>> {
debug!(
user=as_serde!(*auth),
track=as_serde!(track.0);
"authenticated user making track insertion request"
);
fn bad() -> Either<Status, ErrorResponder> {
Left(Status::BadRequest)
}
@ -105,17 +109,28 @@ pub(super) async fn insert_track(
user_id, track_id
) select $1, ti.id
from track_insertion ti
join track_insertion using (id);"#,
join track_insertion using (id)
returning id;"#,
[
auth.id.into(),
track.get("name").ok_or_else(bad_value_for("name"))?.as_str().ok_or_else(bad_value_for("name"))?.into(),
track
.get("name")
.ok_or_else(bad_value_for("name"))?
.as_str()
.ok_or_else(bad_value_for("name"))?
.into(),
track
.get("description")
.ok_or_else(bad_value_for("description"))?
.as_str()
.ok_or_else(bad_value_for("description"))?
.into(),
track.get("icon").ok_or_else(bad_value_for("icon"))?.as_str().ok_or_else(bad_value_for("icon"))?.into(),
track
.get("icon")
.ok_or_else(bad_value_for("icon"))?
.as_str()
.ok_or_else(bad_value_for("icon"))?
.into(),
track.get("enabled").and_then(|it| it.as_i64()).into(),
track
.get("multiple_entries_per_day")
@ -126,17 +141,21 @@ pub(super) async fn insert_track(
],
))
.await
.map_err(|err| Right(Error::from(err).into()))? else {
return Err(Right("no value returned from track insertion query".into()));
};
.map_err(|err| Right(Error::from(err).into()))?
else {
return Err(Right("no value returned from track insertion query".into()));
};
trace!("query completed");
let track_id = track_id
.try_get_by_index(0)
.map_err(|err| Right(Error::from(err).into()))?;
trace!(track_id = track_id; "freshly inserted track ID");
let track = auth.authorized_track(track_id, db).await.ok_or_else(|| {
Right(format!("failed to fetch freshly inserted track with id {track_id}").into())
})?;
tx.send(Update::track_added(track.clone()))
.map_err(|err| Right(Error::from(err).into()))?;
if let Err(err) = tx.send(Update::track_added(track.clone())) {
warn!(err = as_debug!(err); "error sending updates to subscribed channels");
}
Ok(Json(track))
}
@ -157,8 +176,9 @@ pub(super) async fn update_track(
.update(db)
.await
.map_err(|err| Right(Error::from(err).into()))?;
tx.send(Update::track_changed(track.clone()))
.map_err(|err| Right(Error::from(err).into()))?;
if let Err(err) = tx.send(Update::track_changed(track.clone())) {
warn!(err = as_debug!(err); "error sending updates to subscribed channels");
}
Ok(Json(track))
}
@ -193,10 +213,10 @@ pub(super) async fn ticked(
let tick = tick
.insert(db as &DatabaseConnection)
.await
.map_err(|err| Right(Error::from(err).into()))?
;
tx.send(Update::tick_added(tick.clone()))
.map_err(|err| Right(Error::from(err).into()))?;
if let Err(err) = tx.send(Update::tick_added(tick.clone())) {
warn!(err = as_debug!(err); "error sending updates to subscribed channels");
}
Ok(Json(tick))
}
@ -221,10 +241,10 @@ pub(super) async fn ticked_on_date(
let tick = tick
.insert(db as &DatabaseConnection)
.await
.map_err(Error::from)?
;
tx.send(Update::tick_added(tick.clone()))
.map_err(Error::from)?;
if let Err(err) = tx.send(Update::tick_added(tick.clone())) {
warn!(err = as_debug!(err); "error sending updates to subscribed channels");
}
Ok(Left(Json(tick)))
}
@ -250,7 +270,9 @@ pub(super) async fn clear_all_ticks(
.map_err(Error::from)?;
for tick in ticks.clone() {
tick.clone().delete(db).await.map_err(Error::from)?;
Update::tick_cancelled(tick).send(tx)?;
if let Err(err) = Update::tick_cancelled(tick).send(tx) {
warn!(err = as_debug!(err); "error sending updates to subscribed channels");
}
}
Ok(Right(Json(ticks)))
}
@ -279,7 +301,9 @@ pub(super) async fn clear_all_ticks_on_day(
.map_err(Error::from)?;
for tick in ticks.clone() {
tick.clone().delete(db).await.map_err(Error::from)?;
Update::tick_cancelled(tick).send(tx)?;
if let Err(err) = Update::tick_cancelled(tick).send(tx) {
warn!(err = as_debug!(err); "error sending updates to subscribed channels");
}
}
Ok(Right(Json(ticks)))
}

View file

@ -91,8 +91,13 @@ impl Update {
}
pub fn send(self, tx: &Sender<Self>) -> Result<()> {
let count = tx.send(self.clone())?;
trace!(sent_to = count, update = as_serde!(self); "sent update to SSE channel");
let receiver_count = tx.receiver_count();
if receiver_count > 0 {
trace!(receiver_count = receiver_count, update = as_serde!(self); "sending update");
let count = tx.send(self.clone())?;
} else {
trace!("no update receivers, skipping message");
}
Ok(())
}
}

View file

@ -1,3 +1,6 @@
use crate::migrator::Migrator;
use sea_orm_migration::MigratorTrait;
use sea_orm_migration::SchemaManager;
use std::{
env,
ffi::{OsStr, OsString},
@ -5,6 +8,8 @@ use std::{
io::Read,
};
use sea_orm::{Database, DatabaseConnection};
// from https://doc.rust-lang.org/std/ffi/struct.OsString.html
fn concat_os_strings(a: &OsStr, b: &OsStr) -> OsString {
let mut ret = OsString::with_capacity(a.len() + b.len()); // This will allocate
@ -57,3 +62,31 @@ pub fn connection_url() -> String {
.unwrap_or(5432_u16);
format!("postgres://{user}:{password}@{host}:{port}/{db}")
}
pub async fn connection() -> DatabaseConnection {
Database::connect(connection_url())
.await
.expect("db connection")
}
pub async fn migrated() -> DatabaseConnection {
let db = connection().await;
let schema_manager = SchemaManager::new(&db);
Migrator::refresh(&db).await.expect("migration");
assert!(schema_manager
.has_table("tracks")
.await
.expect("fetch tracks table"));
assert!(schema_manager
.has_table("ticks")
.await
.expect("fetch ticks table"));
assert!(schema_manager
.has_table("groups")
.await
.expect("fetch groups table"));
assert!(schema_manager
.has_table("track2_groups")
.await
.expect("fetch track2groups table"));
db
}

8
server/src/lib.rs Normal file
View file

@ -0,0 +1,8 @@
#![feature(proc_macro_hygiene, decl_macro, never_type)]
#[macro_use]
extern crate rocket;
pub mod api;
pub mod db;
pub mod entities;
pub mod error;
mod migrator;

View file

@ -6,32 +6,9 @@ mod db;
mod entities;
mod error;
mod migrator;
use crate::migrator::Migrator;
use sea_orm::Database;
use sea_orm_migration::prelude::*;
#[launch]
async fn rocket_defines_the_main_fn() -> _ {
femme::with_level(femme::LevelFilter::Debug);
let url = db::connection_url();
let db = Database::connect(url).await.expect("db connection");
let schema_manager = SchemaManager::new(&db);
Migrator::refresh(&db).await.expect("migration");
assert!(schema_manager
.has_table("tracks")
.await
.expect("fetch tracks table"));
assert!(schema_manager
.has_table("ticks")
.await
.expect("fetch ticks table"));
assert!(schema_manager
.has_table("groups")
.await
.expect("fetch groups table"));
assert!(schema_manager
.has_table("track2_groups")
.await
.expect("fetch track2groups table"));
api::start_server(db)
femme::with_level(femme::LevelFilter::Trace);
api::start_server(db::migrated().await)
}