diff --git a/src/auth/auth_extractor.rs b/src/auth/auth_extractor.rs index 20de686..73b4f43 100644 --- a/src/auth/auth_extractor.rs +++ b/src/auth/auth_extractor.rs @@ -1,9 +1,10 @@ -use std::fmt::Debug; +use crate::auth::token::{TokenDTO, get_roles_from_token, get_token_from_key}; +use crate::storage::DbPool; +use axum::body::Body; use axum::extract::FromRequestParts; use axum::http::request::Parts; -use axum::http::{header, StatusCode}; -use crate::auth::token::{get_roles_from_token, get_token_from_key, TokenDTO}; -use crate::storage::DbPool; +use axum::http::{HeaderMap, Request, StatusCode, header}; +use std::fmt::Debug; #[derive(Debug)] pub struct AuthInfo { @@ -11,27 +12,42 @@ pub struct AuthInfo { roles: Vec, } -impl<> FromRequestParts for AuthInfo -{ +impl FromRequestParts for AuthInfo { type Rejection = StatusCode; - async fn from_request_parts(parts: &mut Parts, state: &DbPool) -> Result { - let auth_header = parts - .headers - .get(header::AUTHORIZATION) - .and_then(|value| value.to_str().ok()); + async fn from_request_parts( + parts: &mut Parts, + state: &DbPool, + ) -> Result { + let header = &parts.headers; - match auth_header { - Some(auth_header) => { - let token = get_token_from_key(auth_header, state).await; - if token.is_err() { - return Err(StatusCode::UNAUTHORIZED); - } - let token = token.unwrap(); - let roles = get_roles_from_token(&token, state).await; - Ok(Self {token, roles}) - } - _ => Err(StatusCode::UNAUTHORIZED), - } + inspect_with_header(state, &header).await + } +} + +pub async fn inspect_req(state: &DbPool, req: &Request) -> Result { + let header = req.headers(); + inspect_with_header(state, header).await +} + +pub async fn inspect_with_header( + state: &DbPool, + header: &HeaderMap, +) -> Result { + let auth_header = header + .get(header::AUTHORIZATION) + .and_then(|value| value.to_str().ok()); + + match auth_header { + Some(auth_value) => { + let token = get_token_from_key(auth_value, state).await; + if token.is_err() { + return Err(StatusCode::UNAUTHORIZED); + } + let token = token.unwrap(); + let roles = get_roles_from_token(&token, state).await; + Ok(AuthInfo { token, roles }) + } + None => Err(StatusCode::UNAUTHORIZED), } }