diff --git a/migrations/20250407112735_BasicIdentity.sql b/migrations/20250407112735_BasicIdentity.sql index 9228778..1f6d687 100644 --- a/migrations/20250407112735_BasicIdentity.sql +++ b/migrations/20250407112735_BasicIdentity.sql @@ -1,5 +1,5 @@ CREATE TABLE identity ( - id TEXT PRIMARY KEY, + id TEXT PRIMARY KEY NOT NULL, name TEXT NOT NULL ); @@ -13,7 +13,7 @@ CREATE TABLE service_token_role_membership ( ); CREATE TABLE service_token ( - id TEXT PRIMARY KEY, + id TEXT PRIMARY KEY NOT NULL, key TEXT NOT NULL, expiry INTEGER, parent_id TEXT NULL REFERENCES service_token(id) diff --git a/src/auth.rs b/src/auth.rs index 4c57870..ea850c5 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -1,4 +1,5 @@ pub(crate) mod token; +pub mod auth_extractor; use axum::Router; use crate::auth::token::*; diff --git a/src/auth/auth_extractor.rs b/src/auth/auth_extractor.rs new file mode 100644 index 0000000..20de686 --- /dev/null +++ b/src/auth/auth_extractor.rs @@ -0,0 +1,37 @@ +use std::fmt::Debug; +use axum::extract::FromRequestParts; +use axum::http::request::Parts; +use axum::http::{header, StatusCode}; +use crate::auth::token::{get_roles_from_token, get_token_from_key, TokenDTO}; +use crate::storage::DbPool; + +#[derive(Debug)] +pub struct AuthInfo { + token: TokenDTO, + roles: Vec, +} + +impl<> FromRequestParts for AuthInfo +{ + type Rejection = StatusCode; + + async fn from_request_parts(parts: &mut Parts, state: &DbPool) -> Result { + let auth_header = parts + .headers + .get(header::AUTHORIZATION) + .and_then(|value| value.to_str().ok()); + + match auth_header { + Some(auth_header) => { + let token = get_token_from_key(auth_header, state).await; + if token.is_err() { + return Err(StatusCode::UNAUTHORIZED); + } + let token = token.unwrap(); + let roles = get_roles_from_token(&token, state).await; + Ok(Self {token, roles}) + } + _ => Err(StatusCode::UNAUTHORIZED), + } + } +} diff --git a/src/auth/token.rs b/src/auth/token.rs index 03711b5..2d73318 100644 --- a/src/auth/token.rs +++ b/src/auth/token.rs @@ -14,19 +14,28 @@ enum TokenType { } -struct Identity { +#[derive(Debug)] +pub struct IdentityDTO { id: String, name: String } -struct Token { - key: Option, - id: Option, + +#[derive(Debug)] +pub struct TokenDTO { + key: String, + id: String, identity_id: Option, parent_id: Option, expiry: Option, } +#[derive(Debug)] +pub struct TokenRoleMembershipDTO { + role_name: String, + token_id: String, +} + #[derive(Deserialize)] struct RequestBodyPostLookup { token: String, @@ -82,8 +91,8 @@ fn get_time_as_int() -> i64 { std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_secs() as i64 } -fn get_token_type(token: &Token) -> Result { - Ok(match token.key.clone().unwrap().chars().next().unwrap_or('?') { +fn get_token_type(token: &TokenDTO) -> Result { + Ok(match token.key.clone().chars().next().unwrap_or('?') { 's' => "service", 'b' => "batch", 'r' => "recovery", @@ -94,14 +103,22 @@ fn get_token_type(token: &Token) -> Result { }.to_string()) } -async fn get_token_from_key(token_key: &str, pool: &DbPool) -> Result { +pub async fn get_token_from_key(token_key: &str, pool: &DbPool) -> Result { let time = get_time_as_int(); sqlx::query_as!( - Token, - r#"SELECT id, key, expiry, parent_id, identity_id FROM 'service_token' WHERE key = $1 AND (expiry = NULL OR expiry > $2) LIMIT 1"#, + TokenDTO, + r#"SELECT * FROM 'service_token' WHERE key = $1 AND (expiry IS NULL OR expiry > $2) LIMIT 1"#, token_key, time).fetch_one(pool).await } +pub async fn get_roles_from_token(token: &TokenDTO, pool:&DbPool) -> Vec { + let result = sqlx::query_as!( + TokenRoleMembershipDTO, + r#"SELECT * FROM 'service_token_role_membership' WHERE token_id = $1"#, + token.id).fetch_all(pool).await; + result.unwrap_or(Vec::new()).iter().map(|r| r.role_name.to_string()).collect() +} + pub fn token_auth_router(pool: DbPool) -> Router { Router::new() .route("/lookup", post(post_lookup)) diff --git a/src/main.rs b/src/main.rs index cd95cc4..a74ea22 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,7 +12,7 @@ use log::*; use std::{env, net::SocketAddr, str::FromStr}; use storage::DbPool; use tokio::{net::TcpListener, signal}; - +use crate::auth::auth_extractor::AuthInfo; use crate::common::HttpError; mod auth; @@ -118,7 +118,7 @@ async fn fallback_route_unknown(req: Request) -> Response { } /// basic handler that responds with a static string -async fn root() -> &'static str { - info!("Hello world"); +async fn root(test: AuthInfo) -> &'static str { + println!("Hello world, {test:?}"); "Hello, World!" }