From c22241008ea13741616cdb6bc3907ddde43ba227 Mon Sep 17 00:00:00 2001 From: "D. Scott Boggs" Date: Sun, 18 Jun 2023 08:11:11 -0400 Subject: [PATCH] Add update subscription endpoint --- server/src/api/mod.rs | 43 ++++++++++++++++++++++++------ server/src/api/ticks.rs | 23 ++++++++++++----- server/src/api/tracks.rs | 18 ++++++++----- server/src/api/update.rs | 56 ++++++++++++++++++++++++++++++++++++++++ server/src/error.rs | 6 +++-- 5 files changed, 123 insertions(+), 23 deletions(-) create mode 100644 server/src/api/update.rs diff --git a/server/src/api/mod.rs b/server/src/api/mod.rs index 24123a2..4d60f42 100644 --- a/server/src/api/mod.rs +++ b/server/src/api/mod.rs @@ -4,25 +4,48 @@ mod groups; mod import; mod ticks; mod tracks; +pub(crate) mod update; -use std::default::default; -use std::net::{IpAddr, Ipv4Addr}; +use std::{ + default::default, + net::{IpAddr, Ipv4Addr}, + sync::Arc, +}; use crate::error::Error; -use crate::rocket::{Build, Rocket}; -use rocket::fs::{FileServer, NamedFile}; -use rocket::{routes, Config}; +use rocket::{ + fs::{FileServer, NamedFile}, + response::stream::{Event, EventStream}, + routes, Build, Config, Rocket, State, +}; use sea_orm::DatabaseConnection; pub(crate) use error::ErrorResponder; +use tokio::sync::{ + broadcast::{self, error::RecvError, Receiver}, + RwLock, +}; -use self::error::ApiResult; +use self::{error::ApiResult, update::Update}; #[get("/status")] fn status() -> &'static str { "Ok" } +#[get("/updates")] +async fn stream_updates(rx: &State>>>) -> EventStream![Event + '_] { + let rx: Arc>> = (rx as &Arc>>).clone(); + EventStream![loop { + let mut rx = rx.write().await; + match rx.recv().await { + Ok(update) => yield update.to_event(), + Err(RecvError::Closed) => break, + Err(RecvError::Lagged(count)) => yield Update::lagged(count).to_event(), + } + }] +} + #[catch(404)] async fn spa_index_redirect() -> ApiResult { Ok(NamedFile::open("/src/public/index.html") @@ -34,6 +57,7 @@ pub(crate) fn start_server(db: DatabaseConnection) -> Rocket { use groups::*; use ticks::*; use tracks::*; + let (tx, rx) = broadcast::channel::(8); let it = rocket::build() .configure(Config { address: IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), @@ -41,7 +65,9 @@ pub(crate) fn start_server(db: DatabaseConnection) -> Rocket { }) .register("/", catchers![spa_index_redirect]) .manage(db) - .mount("/api/v1", routes![status]) + .manage(tx) + .manage(rx) + .mount("/api/v1", routes![status, stream_updates]) .mount( "/api/v1/tracks", routes![ @@ -50,7 +76,8 @@ pub(crate) fn start_server(db: DatabaseConnection) -> Rocket { ticks_for_track, insert_track, update_track, - delete_track + delete_track, + ticked, ], ) .mount( diff --git a/server/src/api/ticks.rs b/server/src/api/ticks.rs index 4b746e5..57905e7 100644 --- a/server/src/api/ticks.rs +++ b/server/src/api/ticks.rs @@ -1,13 +1,14 @@ use either::{Either, Left, Right}; use rocket::{http::Status, serde::json::Json, State}; use sea_orm::{prelude::*, DatabaseConnection}; +use tokio::sync::broadcast::Sender; use crate::{ entities::{prelude::*, *}, error::Error, }; -use super::error::ApiResult; +use super::{error::ApiResult, update::Update}; #[get("/")] pub(super) async fn all_ticks( @@ -59,10 +60,18 @@ pub(super) async fn update_tick( } #[delete("/")] -pub(super) async fn delete_tick(db: &State, id: i32) -> ApiResult { - Ticks::delete_by_id(id) - .exec(db as &DatabaseConnection) - .await - .map_err(Error::from)?; - Ok(Status::Ok) +pub(super) async fn delete_tick( + db: &State, + tx: &State>, + id: i32, +) -> ApiResult { + let db = db as &DatabaseConnection; + let tick = Ticks::find_by_id(id).one(db).await.map_err(Error::from)?; + if let Some(tick) = tick { + tick.clone().delete(db).await.map_err(Error::from)?; + tx.send(Update::tick_cancelled(tick)).map_err(Error::from)?; + Ok(Status::Ok) + } else { + Ok(Status::NotFound) + } } diff --git a/server/src/api/tracks.rs b/server/src/api/tracks.rs index f8bfc4f..ba3793d 100644 --- a/server/src/api/tracks.rs +++ b/server/src/api/tracks.rs @@ -6,6 +6,9 @@ use rocket::http::Status; use rocket::{serde::json::Json, State}; use sea_orm::{prelude::*, DatabaseConnection}; use std::default::default; +use tokio::sync::broadcast::Sender; + +use super::update::Update; #[get("/")] pub(super) async fn all_tracks( @@ -90,13 +93,16 @@ pub(super) async fn delete_track(db: &State, id: i32) -> Api #[patch("//ticked")] pub(super) async fn ticked( db: &State, + tx: &State>, id: i32, ) -> ApiResult> { let tick = ticks::ActiveModel::now(id); - Ok(Json( - tick.insert(db as &DatabaseConnection) - .await - .map_err(Error::from)? - .to_owned(), - )) + let tick = tick + .insert(db as &DatabaseConnection) + .await + .map_err(Error::from)? + .to_owned(); + tx.send(Update::tick_added(tick.clone())) + .map_err(Error::from)?; + Ok(Json(tick)) } diff --git a/server/src/api/update.rs b/server/src/api/update.rs new file mode 100644 index 0000000..74f21e2 --- /dev/null +++ b/server/src/api/update.rs @@ -0,0 +1,56 @@ +use rocket::response::stream::Event; +use serde::{Deserialize, Serialize}; +use serde_json::json; + +use crate::entities::ticks; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) enum Update { + TickChanged { + kind: UpdateType, + tick: ticks::Model, + }, + Lagged { + kind: UpdateType, + count: u64, + }, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "kebab-case")] +pub(crate) enum UpdateType { + TickAdded, + TickDropped, + Error, +} + +impl Update { + pub(crate) fn lagged(count: u64) -> Update { + Update::Lagged { + kind: UpdateType::Error, + count, + } + } + pub(crate) fn tick_added(tick: ticks::Model) -> Self { + Self::TickChanged { + kind: UpdateType::TickAdded, + tick, + } + } + pub(crate) fn tick_cancelled(tick: ticks::Model) -> Self { + Self::TickChanged { + kind: UpdateType::TickDropped, + tick, + } + } + pub(crate) fn to_event(&self) -> Event { + use Update::*; + match self { + TickChanged { kind, tick } => Event::json(tick).event(format!("{kind:?}")), + Lagged { kind, count } => { + Event::json(&json! {{"message": "error: lagged", "count": count}}) + .event(format!("{kind:?}")) + } + } + } +} diff --git a/server/src/error.rs b/server/src/error.rs index 19c4f1a..2a0e389 100644 --- a/server/src/error.rs +++ b/server/src/error.rs @@ -3,7 +3,7 @@ use std::string; use derive_builder::UninitializedFieldError; #[derive(Debug, thiserror::Error)] -pub enum Error { +pub(crate) enum Error { #[error(transparent)] Builder(#[from] UninitializedFieldError), #[error(transparent)] @@ -16,6 +16,8 @@ pub enum Error { Unreachable, #[error(transparent)] Utf8(#[from] string::FromUtf8Error), + #[error(transparent)] + ChannelSendError(#[from] tokio::sync::broadcast::error::SendError), } -pub type Result = std::result::Result; +pub(crate) type Result = std::result::Result;