116 lines
3.4 KiB
Rust
116 lines
3.4 KiB
Rust
use axum::{
|
|
extract::Request,
|
|
http::StatusCode,
|
|
middleware::{self, Next},
|
|
response::{IntoResponse, Response},
|
|
routing::get,
|
|
Router,
|
|
};
|
|
use log::*;
|
|
use std::{env, net::SocketAddr, str::FromStr};
|
|
use storage::DatabaseDriver;
|
|
use tokio::{net::TcpListener, signal};
|
|
|
|
use crate::common::HttpError;
|
|
|
|
mod auth;
|
|
mod common;
|
|
mod engines;
|
|
mod identity;
|
|
mod storage;
|
|
mod sys;
|
|
|
|
#[tokio::main]
|
|
async fn main() {
|
|
let _ = dotenvy::dotenv();
|
|
// To be configured via environment variables
|
|
// choose from (highest to lowest): error, warn, info, debug, trace, off
|
|
// env::set_var("RUST_LOG", "trace"); // TODO: Remove to respect user configuration
|
|
// env::set_var("DATABASE_URL", "sqlite:test.db"); // TODO: move to .env
|
|
env_logger::init();
|
|
|
|
// Listen on all IPv4 and IPv6 interfaces on port 8200 by default
|
|
let listen_addr = env::var("LISTEN_ADDR").unwrap_or("[::]:8200".to_string());
|
|
let listen_addr = SocketAddr::from_str(&listen_addr).expect("Failed to parse LISTEN_ADDR");
|
|
|
|
let db_url = env::var("DATABASE_URL").expect("DATABASE_URL must be set");
|
|
|
|
let pool = storage::create_pool(db_url).await;
|
|
|
|
// build our application with routes
|
|
let app = Router::new()
|
|
.route("/", get(root))
|
|
.nest("/v1/auth", auth::auth_router(pool.clone()))
|
|
.nest("/v1/identity", identity::identity_router(pool.clone()))
|
|
.nest("/v1/sys", sys::sys_router(pool.clone()))
|
|
.nest("/v1", engines::secrets_router(pool.clone())) // mountable secret backends
|
|
.fallback(fallback_route_unknown)
|
|
.layer(middleware::from_fn(set_default_content_type_json))
|
|
.with_state(pool.clone());
|
|
|
|
warn!("Listening on {}", listen_addr.to_string());
|
|
// Start listening
|
|
let listener = TcpListener::bind(listen_addr).await.unwrap();
|
|
axum::serve(listener, app)
|
|
.with_graceful_shutdown(shutdown_signal(pool))
|
|
.await
|
|
.unwrap();
|
|
}
|
|
|
|
async fn set_default_content_type_json(
|
|
mut req: Request,
|
|
next: Next,
|
|
) -> Result<impl IntoResponse, Response> {
|
|
if req.headers().get("content-type").is_none() {
|
|
let headers = req.headers_mut();
|
|
// debug!("Request header: \n{:?}", headers);
|
|
headers.insert("content-type", "application/json".parse().unwrap());
|
|
}
|
|
|
|
Ok(next.run(req).await)
|
|
}
|
|
|
|
async fn shutdown_signal(pool: DatabaseDriver) {
|
|
let ctrl_c = async {
|
|
signal::ctrl_c()
|
|
.await
|
|
.expect("failed to install Ctrl+C handler");
|
|
};
|
|
|
|
#[cfg(unix)]
|
|
let terminate = async {
|
|
signal::unix::signal(signal::unix::SignalKind::terminate())
|
|
.expect("failed to install signal handler")
|
|
.recv()
|
|
.await;
|
|
};
|
|
|
|
#[cfg(not(unix))]
|
|
let terminate = std::future::pending::<()>();
|
|
|
|
tokio::select! {
|
|
_ = ctrl_c => {},
|
|
_ = terminate => {},
|
|
}
|
|
warn!("Closing database pool");
|
|
pool.close().await;
|
|
}
|
|
|
|
/// Fallback route for unknown routes
|
|
/// Note: `/v1/*` is handled by [`engines::secrets_router`]
|
|
async fn fallback_route_unknown(req: Request) -> Response {
|
|
log::error!(
|
|
"Route not found: {} {}, payload {:?}",
|
|
req.method(),
|
|
req.uri(),
|
|
req.body()
|
|
);
|
|
|
|
HttpError::simple(StatusCode::NOT_FOUND, "Route not implemented")
|
|
}
|
|
|
|
/// basic handler that responds with a static string
|
|
async fn root() -> &'static str {
|
|
info!("Hello world");
|
|
"Hello, World!"
|
|
}
|