feat: Graceful shutdown

This commit is contained in:
Laurenz 2024-05-05 15:02:29 +02:00
parent 53fe085e2e
commit 55270c0637
2 changed files with 36 additions and 9 deletions

8
Cargo.lock generated
View file

@ -2217,18 +2217,18 @@ checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0"
[[package]] [[package]]
name = "zerocopy" name = "zerocopy"
version = "0.7.32" version = "0.7.33"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "74d4d3961e53fa4c9a25a8637fc2bfaf2595b3d3ae34875568a5cf64787716be" checksum = "087eca3c1eaf8c47b94d02790dd086cd594b912d2043d4de4bfdd466b3befb7c"
dependencies = [ dependencies = [
"zerocopy-derive", "zerocopy-derive",
] ]
[[package]] [[package]]
name = "zerocopy-derive" name = "zerocopy-derive"
version = "0.7.32" version = "0.7.33"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" checksum = "6f4b6c273f496d8fd4eaf18853e6b448760225dc030ff2c485a786859aea6393"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",

View file

@ -1,7 +1,8 @@
use axum::{extract::Request, http::StatusCode, routing::get, Router}; use axum::{extract::Request, http::StatusCode, routing::get, Router};
use log::*; use log::*;
use storage::DatabaseDriver;
use std::{env, net::SocketAddr, str::FromStr}; use std::{env, net::SocketAddr, str::FromStr};
use tokio::net::TcpListener; use tokio::{net::TcpListener, signal};
mod auth; mod auth;
mod common; mod common;
@ -15,10 +16,10 @@ async fn main() {
// To be configured via environment variables // To be configured via environment variables
// choose from (highest to lowest): error, warn, info, debug, trace, off // choose from (highest to lowest): error, warn, info, debug, trace, off
env::set_var("RUST_LOG", "trace"); // TODO: Remove to respect user configuration env::set_var("RUST_LOG", "trace"); // TODO: Remove to respect user configuration
env::set_var("DATABASE_URL", "sqlite:test.db"); // TODO: move to .env // env::set_var("DATABASE_URL", "sqlite:test.db"); // TODO: move to .env
env_logger::init(); env_logger::init();
// Listen on all IPv4 and IPv6 interfaces on port 8200 // Listen on all IPv4 and IPv6 interfaces on port 8200 by default
let listen_addr = env::var("LISTEN_ADDR").unwrap_or("[::]:8200".to_string()); // Do not change let listen_addr = env::var("LISTEN_ADDR").unwrap_or("[::]:8200".to_string()); // Do not change
let listen_addr = SocketAddr::from_str(&listen_addr).expect("Failed to parse LISTEN_ADDR"); let listen_addr = SocketAddr::from_str(&listen_addr).expect("Failed to parse LISTEN_ADDR");
@ -35,12 +36,38 @@ async fn main() {
.nest("/v1", engines::secrets_router(pool.clone())) // mountable secret backends .nest("/v1", engines::secrets_router(pool.clone())) // mountable secret backends
// .route("/v1/kv-v2/data/foo", post(baz)) // .route("/v1/kv-v2/data/foo", post(baz))
.fallback(fallback_route_unknown) .fallback(fallback_route_unknown)
.with_state(pool); .with_state(pool.clone());
warn!("Listening on {}", listen_addr.to_string()); warn!("Listening on {}", listen_addr.to_string());
// Start listening // Start listening
let listener = TcpListener::bind(listen_addr).await.unwrap(); let listener = TcpListener::bind(listen_addr).await.unwrap();
axum::serve(listener, app).await.unwrap(); axum::serve(listener, app).with_graceful_shutdown(shutdown_signal(pool)).await.unwrap();
}
async fn shutdown_signal(pool: DatabaseDriver) {
let ctrl_c = async {
signal::ctrl_c()
.await
.expect("failed to install Ctrl+C handler");
};
#[cfg(unix)]
let terminate = async {
signal::unix::signal(signal::unix::SignalKind::terminate())
.expect("failed to install signal handler")
.recv()
.await;
};
#[cfg(not(unix))]
let terminate = std::future::pending::<()>();
tokio::select! {
_ = ctrl_c => {},
_ = terminate => {},
}
warn!("Closing database pool");
pool.close().await;
} }
async fn fallback_route_unknown(req: Request) -> (StatusCode, &'static str) { async fn fallback_route_unknown(req: Request) -> (StatusCode, &'static str) {