rvault/src/main.rs

116 lines
3.4 KiB
Rust

use axum::{
    extract::Request,
    http::{header, HeaderValue, StatusCode},
    middleware::{self, Next},
    response::{IntoResponse, Response},
    routing::get,
    Router,
};
use log::*;
use std::{env, net::SocketAddr, str::FromStr};
use storage::DatabaseDriver;
use tokio::{net::TcpListener, signal};
use crate::common::HttpError;
mod auth;
mod common;
mod engines;
mod identity;
mod storage;
mod sys;
#[tokio::main]
async fn main() {
let _ = dotenvy::dotenv();
// To be configured via environment variables
// choose from (highest to lowest): error, warn, info, debug, trace, off
// env::set_var("RUST_LOG", "trace"); // TODO: Remove to respect user configuration
// env::set_var("DATABASE_URL", "sqlite:test.db"); // TODO: move to .env
env_logger::init();
// Listen on all IPv4 and IPv6 interfaces on port 8200 by default
let listen_addr = env::var("LISTEN_ADDR").unwrap_or("[::]:8200".to_string());
let listen_addr = SocketAddr::from_str(&listen_addr).expect("Failed to parse LISTEN_ADDR");
let db_url = env::var("DATABASE_URL").expect("DATABASE_URL must be set");
let pool = storage::create_pool(db_url).await;
// build our application with routes
let app = Router::new()
.route("/", get(root))
.nest("/v1/auth", auth::auth_router(pool.clone()))
.nest("/v1/identity", identity::identity_router(pool.clone()))
.nest("/v1/sys", sys::sys_router(pool.clone()))
.nest("/v1", engines::secrets_router(pool.clone())) // mountable secret backends
.fallback(fallback_route_unknown)
.layer(middleware::from_fn(set_default_content_type_json))
.with_state(pool.clone());
warn!("Listening on {}", listen_addr.to_string());
// Start listening
let listener = TcpListener::bind(listen_addr).await.unwrap();
axum::serve(listener, app)
.with_graceful_shutdown(shutdown_signal(pool))
.await
.unwrap();
}
async fn set_default_content_type_json(
mut req: Request,
next: Next,
) -> Result<impl IntoResponse, Response> {
if req.headers().get("content-type").is_none() {
let headers = req.headers_mut();
// debug!("Request header: \n{:?}", headers);
headers.insert("content-type", "application/json".parse().unwrap());
}
Ok(next.run(req).await)
}
async fn shutdown_signal(pool: DatabaseDriver) {
let ctrl_c = async {
signal::ctrl_c()
.await
.expect("failed to install Ctrl+C handler");
};
#[cfg(unix)]
let terminate = async {
signal::unix::signal(signal::unix::SignalKind::terminate())
.expect("failed to install signal handler")
.recv()
.await;
};
#[cfg(not(unix))]
let terminate = std::future::pending::<()>();
tokio::select! {
_ = ctrl_c => {},
_ = terminate => {},
}
warn!("Closing database pool");
pool.close().await;
}
/// Fallback route for unknown routes
/// Note: `/v1/*` is handled by [`engines::secrets_router`]
async fn fallback_route_unknown(req: Request) -> Response {
log::error!(
"Route not found: {} {}, payload {:?}",
req.method(),
req.uri(),
req.body()
);
HttpError::simple(StatusCode::NOT_FOUND, "Route not implemented")
}
/// Handler for `GET /`.
///
/// Emits an info-level log line and answers with a fixed plain-text greeting.
async fn root() -> &'static str {
    const GREETING: &str = "Hello, World!";
    info!("Hello world");
    GREETING
}