diff --git a/src/auth.rs b/src/auth.rs index 1f41886..2ff4592 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -1,12 +1,14 @@ -pub(crate) mod token; pub mod auth_extractor; +pub(crate) mod token; -use axum::Router; use crate::auth::token::*; use crate::storage::DbPool; +use axum::Router; /// Authentication routes pub fn auth_router(pool: DbPool) -> Router { // The token auth router handles all token-related authentication routes - Router::new().nest("/token", token_auth_router(pool.clone())).with_state(pool) + Router::new() + .nest("/token", token_auth_router(pool.clone())) + .with_state(pool) } diff --git a/src/auth/auth_extractor.rs b/src/auth/auth_extractor.rs index 8b5e8dd..df62c2d 100644 --- a/src/auth/auth_extractor.rs +++ b/src/auth/auth_extractor.rs @@ -27,11 +27,15 @@ impl FromRequestParts for AuthInfo { } } +/// Extracts the headers from request and returns the result from inspect_with_header function. pub async fn inspect_req(state: &DbPool, req: &Request) -> Result { let header = req.headers(); inspect_with_header(state, header).await } +/// Inspects the request headers and extracts authentication information. +/// Returns an `AuthInfo` struct containing the token and roles if successful. +/// If the authorization header is missing or invalid, it returns a `StatusCode::UNAUTHORIZED`. pub async fn inspect_with_header( state: &DbPool, header: &HeaderMap, diff --git a/src/auth/token.rs b/src/auth/token.rs index 53819cd..40db8ce 100644 --- a/src/auth/token.rs +++ b/src/auth/token.rs @@ -1,21 +1,20 @@ -use std::ops::Index; +use crate::storage::DbPool; use axum::extract::{Path, Query, State}; -use axum::{Json, Router}; +use axum::http::StatusCode; use axum::response::{IntoResponse, NoContent, Response}; use axum::routing::post; -use axum::http::StatusCode; +use axum::{Json, Router}; use log::error; +use rand::{Rng, distributions::Alphanumeric}; use serde::{Deserialize, Serialize}; use sqlx::Error; -use rand::{distributions::Alphanumeric, Rng}; +use std::ops::Index; use uuid::Uuid; -use crate::storage::DbPool; - #[derive(Debug, Serialize)] pub struct IdentityDTO { id: String, - name: String + name: String, } #[derive(Debug)] @@ -68,11 +67,14 @@ fn get_random_string(len: usize) -> String { pub async fn create_root_token_if_none_exist(pool: &DbPool) -> bool { // Check if a root token already exists let exists = sqlx::query!( - r#"SELECT service_token.* FROM service_token, service_token_role_membership + r#"SELECT service_token.* FROM service_token, service_token_role_membership WHERE service_token.id = service_token_role_membership.token_id AND service_token_role_membership.role_name = 'root' - LIMIT 1"#).fetch_one(pool).await - .is_ok(); + LIMIT 1"# + ) + .fetch_one(pool) + .await + .is_ok(); if exists { return false; } @@ -109,7 +111,10 @@ async fn create_root_token(pool: &DbPool) -> Result { /// Gets the current time in seconds since unix epoch fn get_time_as_int() -> i64 { - std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_secs() as i64 + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs() as i64 } /// Gets the type of token. (The first character of the key always specifies the type) @@ -122,7 +127,8 @@ fn get_token_type(token: &TokenDTO) -> Result { error!("Unsupported token type"); return Err("Unsupported token type"); } - }.to_string()) + } + .to_string()) } /// Retrieves a token from the database using its key. @@ -138,12 +144,19 @@ pub async fn get_token_from_key(token_key: &str, pool: &DbPool) -> Result Vec { +pub async fn get_roles_from_token(token: &TokenDTO, pool: &DbPool) -> Vec { let result = sqlx::query_as!( - TokenRoleMembershipDTO, - r#"SELECT * FROM 'service_token_role_membership' WHERE token_id = $1"#, - token.id).fetch_all(pool).await; - result.unwrap_or(Vec::new()).iter().map(|r| r.role_name.to_string()).collect() + TokenRoleMembershipDTO, + r#"SELECT * FROM 'service_token_role_membership' WHERE token_id = $1"#, + token.id + ) + .fetch_all(pool) + .await; + result + .unwrap_or(Vec::new()) + .iter() + .map(|r| r.role_name.to_string()) + .collect() } /// Return a router, that may be used to route traffic to the corresponding handlers @@ -153,13 +166,12 @@ pub fn token_auth_router(pool: DbPool) -> Router { .with_state(pool) } - /// Handles the `/auth/token/lookup` endpoint. /// Retrieves the token and its associated roles from the database using the provided token key. /// The output format does not yet match the openBao specification and is for testing only! async fn post_lookup( State(pool): State, - Json(body): Json + Json(body): Json, ) -> Response { let token_str = body.token; // Validate the token string @@ -201,7 +213,6 @@ async fn post_create_role() {} async fn get_lookup() {} - async fn get_lookup_self() {} async fn post_lookup_self() {} @@ -220,9 +231,7 @@ async fn post_revoke_orphan() {} async fn post_revoke_self() {} -async fn get_roles() { - -} +async fn get_roles() {} async fn get_role_by_name() {} diff --git a/src/main.rs b/src/main.rs index ba594a4..3a23e6f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,19 +1,19 @@ #![forbid(unsafe_code)] +use crate::auth::auth_extractor::AuthInfo; +use crate::common::HttpError; use axum::{ extract::Request, http::StatusCode, middleware::{self, Next}, response::{IntoResponse, Response}, routing::get, - Router + Router, }; use log::*; use std::{env, net::SocketAddr, str::FromStr}; use storage::DbPool; use tokio::{net::TcpListener, signal}; -use crate::auth::auth_extractor::AuthInfo; -use crate::common::HttpError; mod auth; mod common;