add route to import the database file

This commit is contained in:
D. Scott Boggs 2023-06-13 12:26:39 -04:00
parent eff89e7100
commit 526990704b
3 changed files with 46 additions and 8 deletions

View file

@ -1,3 +1,5 @@
//! This is behind a feature gate for a reason: it's wildly unsafe and
//! insecure. It absolutely enables arbitrary sql injection.
use rocket::{http::Status, State};
use sea_orm::{ConnectionTrait, DatabaseBackend, DatabaseConnection, Statement};
@ -5,11 +7,8 @@ use crate::error::Error;
use super::error::ApiResult;
/// This is behind a feature gate for a reason: it's wildly unsafe and
/// insecure. It absolutely enables arbitrary sql injection.
#[cfg(feature = "unsafe_import")]
#[post("/import", data = "<sql_dump>")]
pub(crate) async fn import_sql(
#[post("/dump", data = "<sql_dump>")]
pub(crate) async fn sql_dump(
db: &State<DatabaseConnection>,
sql_dump: &str,
) -> ApiResult<Status> {
@ -25,3 +24,35 @@ pub(crate) async fn import_sql(
}
Ok(Status::Ok)
}
#[post("/", data="<sqlite_db>")]
pub(crate) async fn db_file(
db: &State<DatabaseConnection>,
sqlite_db: &[u8],
) -> ApiResult<Status> {
use std::{
io::Write,
process::{Command, Stdio},
};
let mut proc = Command::new("sqlite3")
.args(["-", ".dump"])
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()
.map_err(Error::from)?;
proc.stdin
.take()
.ok_or(Error::Unreachable)?
.write_all(sqlite_db)
.map_err(Error::from)?;
let result = proc.wait_with_output().map_err(Error::from)?;
if result.status.success() {
sql_dump(db, &String::from_utf8(result.stdout).map_err(Error::from)?).await
} else {
Err(Error::SqliteCommandError(String::from_utf8_lossy(
&result.stderr,
).to_string()).into())
}
}

View file

@ -1,5 +1,6 @@
mod error;
mod groups;
#[cfg(feature = "unsafe_import")]
mod import;
mod ticks;
mod tracks;
@ -10,8 +11,6 @@ use std::net::{IpAddr, Ipv4Addr};
use rocket::fs::{FileServer, NamedFile};
use rocket::{routes, Config};
use sea_orm::DatabaseConnection;
use crate::api::import::import_sql;
use crate::error::Error;
use crate::rocket::{Build, Rocket};
@ -58,7 +57,7 @@ pub(crate) fn start_server(db: DatabaseConnection) -> Rocket<Build> {
.mount("/", FileServer::from("/src/public"));
#[cfg(feature = "unsafe_import")]
let it = it.mount("/api/v1", routes![import_sql]);
let it = it.mount("/api/v1/import", routes![import::sql_dump, import::db_file]);
it
}

View file

@ -1,3 +1,5 @@
use std::string;
use derive_builder::UninitializedFieldError;
#[derive(Debug, thiserror::Error)]
@ -8,6 +10,12 @@ pub enum Error {
SeaOrm(#[from] sea_orm::DbErr),
#[error(transparent)]
Io(#[from] std::io::Error),
#[error("error running sqlite command: {0}")]
SqliteCommandError(String),
#[error("BUG: this case should have been unreachable")]
Unreachable,
#[error(transparent)]
Utf8(#[from] string::FromUtf8Error),
}
pub type Result<T> = std::result::Result<T, Error>;