Compare commits

..

No commits in common. "623cc2bbaa2f246459c2cfaee20e8323a496a0e0" and "47f8e01210fd2a933b0fd304ea0083db32cf98b0" have entirely different histories.

4 changed files with 44 additions and 115 deletions

View file

@ -1,14 +1,16 @@
pub mod auth_extractor;
pub(crate) mod token;
pub mod auth_extractor;
use axum::Router;
use crate::auth::token::*;
use crate::storage::DbPool;
use axum::Router;
/// Authentication routes
// route prefix: `/auth/token/`
// mod token;
// use self::token::token_auth_router;
pub fn auth_router(pool: DbPool) -> Router<DbPool> {
// The token auth router handles all token-related authentication routes
Router::new()
.nest("/token", token_auth_router(pool.clone()))
.with_state(pool)
Router::new().nest("/token", token_auth_router(pool.clone())).with_state(pool)
// .nest("/token", token_auth_router())
}

View file

@ -6,14 +6,12 @@ use axum::http::request::Parts;
use axum::http::{HeaderMap, Request, StatusCode, header};
use std::fmt::Debug;
/// AuthInfo is an extractor that retrieves authentication information from the request.
#[derive(Debug)]
pub struct AuthInfo {
token: TokenDTO,
roles: Vec<String>,
}
/// Extracts authentication information from the request parts.
impl FromRequestParts<DbPool> for AuthInfo {
type Rejection = StatusCode;
@ -27,15 +25,11 @@ impl FromRequestParts<DbPool> for AuthInfo {
}
}
/// Extracts the headers from request and returns the result from inspect_with_header function.
pub async fn inspect_req(state: &DbPool, req: &Request<Body>) -> Result<AuthInfo, StatusCode> {
let header = req.headers();
inspect_with_header(state, header).await
}
/// Inspects the request headers and extracts authentication information.
/// Returns an `AuthInfo` struct containing the token and roles if successful.
/// If the authorization header is missing or invalid, it returns a `StatusCode::UNAUTHORIZED`.
pub async fn inspect_with_header(
state: &DbPool,
header: &HeaderMap,

View file

@ -1,22 +1,26 @@
use crate::storage::DbPool;
use std::ops::Index;
use axum::extract::{Path, Query, State};
use axum::http::StatusCode;
use axum::{Json, Router};
use axum::response::{IntoResponse, NoContent, Response};
use axum::routing::post;
use axum::{Json, Router};
use log::error;
use rand::{Rng, distributions::Alphanumeric};
use serde::{Deserialize, Serialize};
use serde::Deserialize;
use sqlx::Error;
use std::ops::Index;
use rand::{distributions::Alphanumeric, Rng};
use uuid::Uuid;
use crate::storage::DbPool;
#[derive(Debug, Serialize)]
enum TokenType {
}
#[derive(Debug)]
pub struct IdentityDTO {
id: String,
name: String,
name: String
}
#[derive(Debug)]
pub struct TokenDTO {
key: String,
@ -32,27 +36,11 @@ pub struct TokenRoleMembershipDTO {
token_id: String,
}
/// Represents a request body for the `/auth/token/lookup` endpoint.
#[derive(Deserialize)]
struct RequestBodyPostLookup {
token: String,
}
/// Represents the response body for the `/auth/token/lookup` endpoint.
#[derive(Serialize)]
struct TokenLookupResponse {
id: String,
type_name: String,
roles: Vec<String>,
}
/// Represents an error response for the API.
#[derive(Serialize)]
struct ErrorResponse {
error: String,
}
/// Generates a random string of the specified length using alphanumeric characters.
// TODO: Make string generation secure
fn get_random_string(len: usize) -> String {
rand::thread_rng()
@ -62,62 +50,47 @@ fn get_random_string(len: usize) -> String {
.collect()
}
/// Creates a root token if none exists in the database.
/// Returns true if a new root token was created, false if one already exists.
// Returns if a token was created or not. Prints out the created token to the console.
pub async fn create_root_token_if_none_exist(pool: &DbPool) -> bool {
// Check if a root token already exists
let exists = sqlx::query!(
r#"SELECT service_token.* FROM service_token, service_token_role_membership
WHERE service_token.id = service_token_role_membership.token_id AND
service_token_role_membership.role_name = 'root'
LIMIT 1"#
)
.fetch_one(pool)
.await
LIMIT 1"#).fetch_one(pool).await
.is_ok();
if exists {
return false;
}
// If no root token exists, create one
let result = create_root_token(pool).await;
if result.is_err() {
let error = result.err().unwrap();
// Log the error and panic
error!("create_root_token failed: {:?}", error);
panic!("create_root_token failed: {:?}", error);
}
// If successful, print the root token. This will only happen once.
println!("\n\nYour root token is: {}", result.unwrap());
println!("It will only be displayed once!\n\n");
true
}
/// Creates a root token in the database.
// Return the token key if successful
async fn create_root_token(pool: &DbPool) -> Result<String, Error> {
let id = Uuid::new_v4().to_string();
let key = "s.".to_string() + &get_random_string(24);
// Insert the root token into the database
let result = sqlx::query!(r#"
INSERT INTO service_token (id, key) VALUES ($1, $2);
INSERT INTO service_token_role_membership (token_id, role_name) VALUES ($3, 'root');
"#, id, key, id).execute(pool).await;
// If the insert was successful, return the key
if result.is_ok() {
return Ok(key);
}
// Else, return the error
Err(result.unwrap_err())
}
/// Gets the current time in seconds since unix epoch
// Gets the current time in seconds since unix epoch
fn get_time_as_int() -> i64 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs() as i64
std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_secs() as i64
}
/// Gets the type of token. (The first character of the key always specifies the type)
fn get_token_type(token: &TokenDTO) -> Result<String, &str> {
Ok(match token.key.clone().chars().next().unwrap_or('?') {
's' => "service",
@ -127,13 +100,9 @@ fn get_token_type(token: &TokenDTO) -> Result<String, &str> {
error!("Unsupported token type");
return Err("Unsupported token type");
}
}
.to_string())
}.to_string())
}
/// Retrieves a token from the database using its key.
/// If the token is found and not expired, it returns the token.
/// Else, it returns an error.
pub async fn get_token_from_key(token_key: &str, pool: &DbPool) -> Result<TokenDTO, Error> {
let time = get_time_as_int();
sqlx::query_as!(
@ -142,66 +111,29 @@ pub async fn get_token_from_key(token_key: &str, pool: &DbPool) -> Result<TokenD
token_key, time).fetch_one(pool).await
}
/// Retrieves the roles associated with a given token from the database.
/// If the token does not exist, it returns an empty vector.
pub async fn get_roles_from_token(token: &TokenDTO, pool:&DbPool) -> Vec<String> {
let result = sqlx::query_as!(
TokenRoleMembershipDTO,
r#"SELECT * FROM 'service_token_role_membership' WHERE token_id = $1"#,
token.id
)
.fetch_all(pool)
.await;
result
.unwrap_or(Vec::new())
.iter()
.map(|r| r.role_name.to_string())
.collect()
token.id).fetch_all(pool).await;
result.unwrap_or(Vec::new()).iter().map(|r| r.role_name.to_string()).collect()
}
/// Return a router, that may be used to route traffic to the corresponding handlers
pub fn token_auth_router(pool: DbPool) -> Router<DbPool> {
Router::new()
.route("/lookup", post(post_lookup))
.with_state(pool)
}
/// Handles the `/auth/token/lookup` endpoint.
/// Retrieves the token and its associated roles from the database using the provided token key.
/// The output format does not yet match the openBao specification and is for testing only!
async fn post_lookup(
State(pool): State<DbPool>,
Json(body): Json<RequestBodyPostLookup>,
) -> Response {
let token_str = body.token;
// Validate the token string
match get_token_from_key(&token_str, &pool).await {
// If the token is found, retrieve its type and roles
Ok(token) => {
let type_name = get_token_type(&token).unwrap_or_else(|_| String::from("Unknown"));
let roles = get_roles_from_token(&token, &pool).await;
let resp = TokenLookupResponse {
id: token.id,
type_name,
roles,
};
// Return the token information as a JSON response
(StatusCode::OK, axum::Json(resp)).into_response()
}
// If the token is not found, return a 404 Not Found error
Err(e) => {
error!("Failed to retrieve token: {:?}", e);
let err = ErrorResponse {
error: "Failed to retrieve token".to_string(),
};
(StatusCode::NOT_FOUND, axum::Json(err)).into_response()
}
}
}
Json(body): Json<RequestBodyPostLookup>
) -> Result<Response, ()> {
let token = body.token;
//
// The following functions are placeholders for the various token-related operations.
//
Ok(IntoResponse::into_response(token))
}
async fn get_accessors() {}
@ -213,6 +145,7 @@ async fn post_create_role() {}
async fn get_lookup() {}
async fn get_lookup_self() {}
async fn post_lookup_self() {}

View file

@ -1,19 +1,19 @@
#![forbid(unsafe_code)]
use crate::auth::auth_extractor::AuthInfo;
use crate::common::HttpError;
use axum::{
extract::Request,
http::StatusCode,
middleware::{self, Next},
response::{IntoResponse, Response},
routing::get,
Router,
Router
};
use log::*;
use std::{env, net::SocketAddr, str::FromStr};
use storage::DbPool;
use tokio::{net::TcpListener, signal};
use crate::auth::auth_extractor::AuthInfo;
use crate::common::HttpError;
mod auth;
mod common;