Added documentation and reformatted files

This commit is contained in:
Jan Schermer 2025-06-10 19:09:39 -07:00
parent 2b47bb113e
commit 623cc2bbaa
4 changed files with 44 additions and 29 deletions

View file

@ -1,12 +1,14 @@
pub(crate) mod token;
pub mod auth_extractor; pub mod auth_extractor;
pub(crate) mod token;
use axum::Router;
use crate::auth::token::*; use crate::auth::token::*;
use crate::storage::DbPool; use crate::storage::DbPool;
use axum::Router;
/// Authentication routes /// Authentication routes
pub fn auth_router(pool: DbPool) -> Router<DbPool> { pub fn auth_router(pool: DbPool) -> Router<DbPool> {
// The token auth router handles all token-related authentication routes // The token auth router handles all token-related authentication routes
Router::new().nest("/token", token_auth_router(pool.clone())).with_state(pool) Router::new()
.nest("/token", token_auth_router(pool.clone()))
.with_state(pool)
} }

View file

@ -27,11 +27,15 @@ impl FromRequestParts<DbPool> for AuthInfo {
} }
} }
/// Extracts the headers from request and returns the result from inspect_with_header function.
pub async fn inspect_req(state: &DbPool, req: &Request<Body>) -> Result<AuthInfo, StatusCode> { pub async fn inspect_req(state: &DbPool, req: &Request<Body>) -> Result<AuthInfo, StatusCode> {
let header = req.headers(); let header = req.headers();
inspect_with_header(state, header).await inspect_with_header(state, header).await
} }
/// Inspects the request headers and extracts authentication information.
/// Returns an `AuthInfo` struct containing the token and roles if successful.
/// If the authorization header is missing or invalid, it returns a `StatusCode::UNAUTHORIZED`.
pub async fn inspect_with_header( pub async fn inspect_with_header(
state: &DbPool, state: &DbPool,
header: &HeaderMap, header: &HeaderMap,

View file

@ -1,21 +1,20 @@
use std::ops::Index; use crate::storage::DbPool;
use axum::extract::{Path, Query, State}; use axum::extract::{Path, Query, State};
use axum::{Json, Router}; use axum::http::StatusCode;
use axum::response::{IntoResponse, NoContent, Response}; use axum::response::{IntoResponse, NoContent, Response};
use axum::routing::post; use axum::routing::post;
use axum::http::StatusCode; use axum::{Json, Router};
use log::error; use log::error;
use rand::{Rng, distributions::Alphanumeric};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sqlx::Error; use sqlx::Error;
use rand::{distributions::Alphanumeric, Rng}; use std::ops::Index;
use uuid::Uuid; use uuid::Uuid;
use crate::storage::DbPool;
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
pub struct IdentityDTO { pub struct IdentityDTO {
id: String, id: String,
name: String name: String,
} }
#[derive(Debug)] #[derive(Debug)]
@ -71,7 +70,10 @@ pub async fn create_root_token_if_none_exist(pool: &DbPool) -> bool {
r#"SELECT service_token.* FROM service_token, service_token_role_membership r#"SELECT service_token.* FROM service_token, service_token_role_membership
WHERE service_token.id = service_token_role_membership.token_id AND WHERE service_token.id = service_token_role_membership.token_id AND
service_token_role_membership.role_name = 'root' service_token_role_membership.role_name = 'root'
LIMIT 1"#).fetch_one(pool).await LIMIT 1"#
)
.fetch_one(pool)
.await
.is_ok(); .is_ok();
if exists { if exists {
return false; return false;
@ -109,7 +111,10 @@ async fn create_root_token(pool: &DbPool) -> Result<String, Error> {
/// Gets the current time in seconds since unix epoch /// Gets the current time in seconds since unix epoch
fn get_time_as_int() -> i64 { fn get_time_as_int() -> i64 {
std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_secs() as i64 std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs() as i64
} }
/// Gets the type of token. (The first character of the key always specifies the type) /// Gets the type of token. (The first character of the key always specifies the type)
@ -122,7 +127,8 @@ fn get_token_type(token: &TokenDTO) -> Result<String, &str> {
error!("Unsupported token type"); error!("Unsupported token type");
return Err("Unsupported token type"); return Err("Unsupported token type");
} }
}.to_string()) }
.to_string())
} }
/// Retrieves a token from the database using its key. /// Retrieves a token from the database using its key.
@ -142,8 +148,15 @@ pub async fn get_roles_from_token(token: &TokenDTO, pool:&DbPool) -> Vec<String>
let result = sqlx::query_as!( let result = sqlx::query_as!(
TokenRoleMembershipDTO, TokenRoleMembershipDTO,
r#"SELECT * FROM 'service_token_role_membership' WHERE token_id = $1"#, r#"SELECT * FROM 'service_token_role_membership' WHERE token_id = $1"#,
token.id).fetch_all(pool).await; token.id
result.unwrap_or(Vec::new()).iter().map(|r| r.role_name.to_string()).collect() )
.fetch_all(pool)
.await;
result
.unwrap_or(Vec::new())
.iter()
.map(|r| r.role_name.to_string())
.collect()
} }
/// Return a router, that may be used to route traffic to the corresponding handlers /// Return a router, that may be used to route traffic to the corresponding handlers
@ -153,13 +166,12 @@ pub fn token_auth_router(pool: DbPool) -> Router<DbPool> {
.with_state(pool) .with_state(pool)
} }
/// Handles the `/auth/token/lookup` endpoint. /// Handles the `/auth/token/lookup` endpoint.
/// Retrieves the token and its associated roles from the database using the provided token key. /// Retrieves the token and its associated roles from the database using the provided token key.
/// The output format does not yet match the openBao specification and is for testing only! /// The output format does not yet match the openBao specification and is for testing only!
async fn post_lookup( async fn post_lookup(
State(pool): State<DbPool>, State(pool): State<DbPool>,
Json(body): Json<RequestBodyPostLookup> Json(body): Json<RequestBodyPostLookup>,
) -> Response { ) -> Response {
let token_str = body.token; let token_str = body.token;
// Validate the token string // Validate the token string
@ -201,7 +213,6 @@ async fn post_create_role() {}
async fn get_lookup() {} async fn get_lookup() {}
async fn get_lookup_self() {} async fn get_lookup_self() {}
async fn post_lookup_self() {} async fn post_lookup_self() {}
@ -220,9 +231,7 @@ async fn post_revoke_orphan() {}
async fn post_revoke_self() {} async fn post_revoke_self() {}
async fn get_roles() { async fn get_roles() {}
}
async fn get_role_by_name() {} async fn get_role_by_name() {}

View file

@ -1,19 +1,19 @@
#![forbid(unsafe_code)] #![forbid(unsafe_code)]
use crate::auth::auth_extractor::AuthInfo;
use crate::common::HttpError;
use axum::{ use axum::{
extract::Request, extract::Request,
http::StatusCode, http::StatusCode,
middleware::{self, Next}, middleware::{self, Next},
response::{IntoResponse, Response}, response::{IntoResponse, Response},
routing::get, routing::get,
Router Router,
}; };
use log::*; use log::*;
use std::{env, net::SocketAddr, str::FromStr}; use std::{env, net::SocketAddr, str::FromStr};
use storage::DbPool; use storage::DbPool;
use tokio::{net::TcpListener, signal}; use tokio::{net::TcpListener, signal};
use crate::auth::auth_extractor::AuthInfo;
use crate::common::HttpError;
mod auth; mod auth;
mod common; mod common;