From 55270c06375b5e0b7b09767c3a51eef3d218246b Mon Sep 17 00:00:00 2001 From: C0ffeeCode Date: Sun, 5 May 2024 15:02:29 +0200 Subject: [PATCH] feat: Graceful shutdown --- Cargo.lock | 8 ++++---- src/main.rs | 37 ++++++++++++++++++++++++++++++++----- 2 files changed, 36 insertions(+), 9 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3725bc7..b245402 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2217,18 +2217,18 @@ checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" [[package]] name = "zerocopy" -version = "0.7.32" +version = "0.7.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74d4d3961e53fa4c9a25a8637fc2bfaf2595b3d3ae34875568a5cf64787716be" +checksum = "087eca3c1eaf8c47b94d02790dd086cd594b912d2043d4de4bfdd466b3befb7c" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.7.32" +version = "0.7.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" +checksum = "6f4b6c273f496d8fd4eaf18853e6b448760225dc030ff2c485a786859aea6393" dependencies = [ "proc-macro2", "quote", diff --git a/src/main.rs b/src/main.rs index 261a121..46e49eb 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,8 @@ use axum::{extract::Request, http::StatusCode, routing::get, Router}; use log::*; +use storage::DatabaseDriver; use std::{env, net::SocketAddr, str::FromStr}; -use tokio::net::TcpListener; +use tokio::{net::TcpListener, signal}; mod auth; mod common; @@ -15,10 +16,10 @@ async fn main() { // To be configured via environment variables // choose from (highest to lowest): error, warn, info, debug, trace, off env::set_var("RUST_LOG", "trace"); // TODO: Remove to respect user configuration - env::set_var("DATABASE_URL", "sqlite:test.db"); // TODO: move to .env + // env::set_var("DATABASE_URL", "sqlite:test.db"); // TODO: move to .env env_logger::init(); - // Listen on all IPv4 and IPv6 interfaces on port 8200 + // Listen on all IPv4 and IPv6 interfaces on port 8200 by default let listen_addr = env::var("LISTEN_ADDR").unwrap_or("[::]:8200".to_string()); // Do not change let listen_addr = SocketAddr::from_str(&listen_addr).expect("Failed to parse LISTEN_ADDR"); @@ -35,12 +36,38 @@ async fn main() { .nest("/v1", engines::secrets_router(pool.clone())) // mountable secret backends // .route("/v1/kv-v2/data/foo", post(baz)) .fallback(fallback_route_unknown) - .with_state(pool); + .with_state(pool.clone()); warn!("Listening on {}", listen_addr.to_string()); // Start listening let listener = TcpListener::bind(listen_addr).await.unwrap(); - axum::serve(listener, app).await.unwrap(); + axum::serve(listener, app).with_graceful_shutdown(shutdown_signal(pool)).await.unwrap(); +} + +async fn shutdown_signal(pool: DatabaseDriver) { + let ctrl_c = async { + signal::ctrl_c() + .await + .expect("failed to install Ctrl+C handler"); + }; + + #[cfg(unix)] + let terminate = async { + signal::unix::signal(signal::unix::SignalKind::terminate()) + .expect("failed to install signal handler") + .recv() + .await; + }; + + #[cfg(not(unix))] + let terminate = std::future::pending::<()>(); + + tokio::select! { + _ = ctrl_c => {}, + _ = terminate => {}, + } + warn!("Closing database pool"); + pool.close().await; } async fn fallback_route_unknown(req: Request) -> (StatusCode, &'static str) {