Added basic token auth extractor

This commit is contained in:
Jan Schermer 2025-05-19 13:21:10 +02:00
parent 27dcc5489d
commit 14012b155e
5 changed files with 69 additions and 14 deletions

View file

@ -1,5 +1,5 @@
CREATE TABLE identity ( CREATE TABLE identity (
id TEXT PRIMARY KEY, id TEXT PRIMARY KEY NOT NULL,
name TEXT NOT NULL name TEXT NOT NULL
); );
@ -13,7 +13,7 @@ CREATE TABLE service_token_role_membership (
); );
CREATE TABLE service_token ( CREATE TABLE service_token (
id TEXT PRIMARY KEY, id TEXT PRIMARY KEY NOT NULL,
key TEXT NOT NULL, key TEXT NOT NULL,
expiry INTEGER, expiry INTEGER,
parent_id TEXT NULL REFERENCES service_token(id) parent_id TEXT NULL REFERENCES service_token(id)

View file

@ -1,4 +1,5 @@
pub(crate) mod token; pub(crate) mod token;
pub mod auth_extractor;
use axum::Router; use axum::Router;
use crate::auth::token::*; use crate::auth::token::*;

View file

@ -0,0 +1,37 @@
use std::fmt::Debug;
use axum::extract::FromRequestParts;
use axum::http::request::Parts;
use axum::http::{header, StatusCode};
use crate::auth::token::{get_roles_from_token, get_token_from_key, TokenDTO};
use crate::storage::DbPool;
#[derive(Debug)]
pub struct AuthInfo {
token: TokenDTO,
roles: Vec<String>,
}
impl<> FromRequestParts<DbPool> for AuthInfo
{
type Rejection = StatusCode;
async fn from_request_parts(parts: &mut Parts, state: &DbPool) -> Result<Self, Self::Rejection> {
let auth_header = parts
.headers
.get(header::AUTHORIZATION)
.and_then(|value| value.to_str().ok());
match auth_header {
Some(auth_header) => {
let token = get_token_from_key(auth_header, state).await;
if token.is_err() {
return Err(StatusCode::UNAUTHORIZED);
}
let token = token.unwrap();
let roles = get_roles_from_token(&token, state).await;
Ok(Self {token, roles})
}
_ => Err(StatusCode::UNAUTHORIZED),
}
}
}

View file

@ -14,19 +14,28 @@ enum TokenType {
} }
struct Identity { #[derive(Debug)]
pub struct IdentityDTO {
id: String, id: String,
name: String name: String
} }
struct Token {
key: Option<String>, #[derive(Debug)]
id: Option<String>, pub struct TokenDTO {
key: String,
id: String,
identity_id: Option<String>, identity_id: Option<String>,
parent_id: Option<String>, parent_id: Option<String>,
expiry: Option<i64>, expiry: Option<i64>,
} }
#[derive(Debug)]
pub struct TokenRoleMembershipDTO {
role_name: String,
token_id: String,
}
#[derive(Deserialize)] #[derive(Deserialize)]
struct RequestBodyPostLookup { struct RequestBodyPostLookup {
token: String, token: String,
@ -82,8 +91,8 @@ fn get_time_as_int() -> i64 {
std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_secs() as i64 std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_secs() as i64
} }
fn get_token_type(token: &Token) -> Result<String, &str> { fn get_token_type(token: &TokenDTO) -> Result<String, &str> {
Ok(match token.key.clone().unwrap().chars().next().unwrap_or('?') { Ok(match token.key.clone().chars().next().unwrap_or('?') {
's' => "service", 's' => "service",
'b' => "batch", 'b' => "batch",
'r' => "recovery", 'r' => "recovery",
@ -94,14 +103,22 @@ fn get_token_type(token: &Token) -> Result<String, &str> {
}.to_string()) }.to_string())
} }
async fn get_token_from_key(token_key: &str, pool: &DbPool) -> Result<Token, Error> { pub async fn get_token_from_key(token_key: &str, pool: &DbPool) -> Result<TokenDTO, Error> {
let time = get_time_as_int(); let time = get_time_as_int();
sqlx::query_as!( sqlx::query_as!(
Token, TokenDTO,
r#"SELECT id, key, expiry, parent_id, identity_id FROM 'service_token' WHERE key = $1 AND (expiry = NULL OR expiry > $2) LIMIT 1"#, r#"SELECT * FROM 'service_token' WHERE key = $1 AND (expiry IS NULL OR expiry > $2) LIMIT 1"#,
token_key, time).fetch_one(pool).await token_key, time).fetch_one(pool).await
} }
pub async fn get_roles_from_token(token: &TokenDTO, pool:&DbPool) -> Vec<String> {
let result = sqlx::query_as!(
TokenRoleMembershipDTO,
r#"SELECT * FROM 'service_token_role_membership' WHERE token_id = $1"#,
token.id).fetch_all(pool).await;
result.unwrap_or(Vec::new()).iter().map(|r| r.role_name.to_string()).collect()
}
pub fn token_auth_router(pool: DbPool) -> Router<DbPool> { pub fn token_auth_router(pool: DbPool) -> Router<DbPool> {
Router::new() Router::new()
.route("/lookup", post(post_lookup)) .route("/lookup", post(post_lookup))

View file

@ -12,7 +12,7 @@ use log::*;
use std::{env, net::SocketAddr, str::FromStr}; use std::{env, net::SocketAddr, str::FromStr};
use storage::DbPool; use storage::DbPool;
use tokio::{net::TcpListener, signal}; use tokio::{net::TcpListener, signal};
use crate::auth::auth_extractor::AuthInfo;
use crate::common::HttpError; use crate::common::HttpError;
mod auth; mod auth;
@ -118,7 +118,7 @@ async fn fallback_route_unknown(req: Request) -> Response {
} }
/// basic handler that responds with a static string /// basic handler that responds with a static string
async fn root() -> &'static str { async fn root(test: AuthInfo) -> &'static str {
info!("Hello world"); println!("Hello world, {test:?}");
"Hello, World!" "Hello, World!"
} }