Added basic token auth extractor
This commit is contained in:
parent
27dcc5489d
commit
14012b155e
5 changed files with 69 additions and 14 deletions
|
|
@ -1,5 +1,5 @@
|
||||||
CREATE TABLE identity (
|
CREATE TABLE identity (
|
||||||
id TEXT PRIMARY KEY,
|
id TEXT PRIMARY KEY NOT NULL,
|
||||||
name TEXT NOT NULL
|
name TEXT NOT NULL
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
@ -13,7 +13,7 @@ CREATE TABLE service_token_role_membership (
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE TABLE service_token (
|
CREATE TABLE service_token (
|
||||||
id TEXT PRIMARY KEY,
|
id TEXT PRIMARY KEY NOT NULL,
|
||||||
key TEXT NOT NULL,
|
key TEXT NOT NULL,
|
||||||
expiry INTEGER,
|
expiry INTEGER,
|
||||||
parent_id TEXT NULL REFERENCES service_token(id)
|
parent_id TEXT NULL REFERENCES service_token(id)
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
pub(crate) mod token;
|
pub(crate) mod token;
|
||||||
|
pub mod auth_extractor;
|
||||||
|
|
||||||
use axum::Router;
|
use axum::Router;
|
||||||
use crate::auth::token::*;
|
use crate::auth::token::*;
|
||||||
|
|
|
||||||
37
src/auth/auth_extractor.rs
Normal file
37
src/auth/auth_extractor.rs
Normal file
|
|
@ -0,0 +1,37 @@
|
||||||
|
use std::fmt::Debug;
|
||||||
|
use axum::extract::FromRequestParts;
|
||||||
|
use axum::http::request::Parts;
|
||||||
|
use axum::http::{header, StatusCode};
|
||||||
|
use crate::auth::token::{get_roles_from_token, get_token_from_key, TokenDTO};
|
||||||
|
use crate::storage::DbPool;
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct AuthInfo {
|
||||||
|
token: TokenDTO,
|
||||||
|
roles: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<> FromRequestParts<DbPool> for AuthInfo
|
||||||
|
{
|
||||||
|
type Rejection = StatusCode;
|
||||||
|
|
||||||
|
async fn from_request_parts(parts: &mut Parts, state: &DbPool) -> Result<Self, Self::Rejection> {
|
||||||
|
let auth_header = parts
|
||||||
|
.headers
|
||||||
|
.get(header::AUTHORIZATION)
|
||||||
|
.and_then(|value| value.to_str().ok());
|
||||||
|
|
||||||
|
match auth_header {
|
||||||
|
Some(auth_header) => {
|
||||||
|
let token = get_token_from_key(auth_header, state).await;
|
||||||
|
if token.is_err() {
|
||||||
|
return Err(StatusCode::UNAUTHORIZED);
|
||||||
|
}
|
||||||
|
let token = token.unwrap();
|
||||||
|
let roles = get_roles_from_token(&token, state).await;
|
||||||
|
Ok(Self {token, roles})
|
||||||
|
}
|
||||||
|
_ => Err(StatusCode::UNAUTHORIZED),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -14,19 +14,28 @@ enum TokenType {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
struct Identity {
|
#[derive(Debug)]
|
||||||
|
pub struct IdentityDTO {
|
||||||
id: String,
|
id: String,
|
||||||
name: String
|
name: String
|
||||||
}
|
}
|
||||||
|
|
||||||
struct Token {
|
|
||||||
key: Option<String>,
|
#[derive(Debug)]
|
||||||
id: Option<String>,
|
pub struct TokenDTO {
|
||||||
|
key: String,
|
||||||
|
id: String,
|
||||||
identity_id: Option<String>,
|
identity_id: Option<String>,
|
||||||
parent_id: Option<String>,
|
parent_id: Option<String>,
|
||||||
expiry: Option<i64>,
|
expiry: Option<i64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct TokenRoleMembershipDTO {
|
||||||
|
role_name: String,
|
||||||
|
token_id: String,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
struct RequestBodyPostLookup {
|
struct RequestBodyPostLookup {
|
||||||
token: String,
|
token: String,
|
||||||
|
|
@ -82,8 +91,8 @@ fn get_time_as_int() -> i64 {
|
||||||
std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_secs() as i64
|
std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_secs() as i64
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_token_type(token: &Token) -> Result<String, &str> {
|
fn get_token_type(token: &TokenDTO) -> Result<String, &str> {
|
||||||
Ok(match token.key.clone().unwrap().chars().next().unwrap_or('?') {
|
Ok(match token.key.clone().chars().next().unwrap_or('?') {
|
||||||
's' => "service",
|
's' => "service",
|
||||||
'b' => "batch",
|
'b' => "batch",
|
||||||
'r' => "recovery",
|
'r' => "recovery",
|
||||||
|
|
@ -94,14 +103,22 @@ fn get_token_type(token: &Token) -> Result<String, &str> {
|
||||||
}.to_string())
|
}.to_string())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_token_from_key(token_key: &str, pool: &DbPool) -> Result<Token, Error> {
|
pub async fn get_token_from_key(token_key: &str, pool: &DbPool) -> Result<TokenDTO, Error> {
|
||||||
let time = get_time_as_int();
|
let time = get_time_as_int();
|
||||||
sqlx::query_as!(
|
sqlx::query_as!(
|
||||||
Token,
|
TokenDTO,
|
||||||
r#"SELECT id, key, expiry, parent_id, identity_id FROM 'service_token' WHERE key = $1 AND (expiry = NULL OR expiry > $2) LIMIT 1"#,
|
r#"SELECT * FROM 'service_token' WHERE key = $1 AND (expiry IS NULL OR expiry > $2) LIMIT 1"#,
|
||||||
token_key, time).fetch_one(pool).await
|
token_key, time).fetch_one(pool).await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn get_roles_from_token(token: &TokenDTO, pool:&DbPool) -> Vec<String> {
|
||||||
|
let result = sqlx::query_as!(
|
||||||
|
TokenRoleMembershipDTO,
|
||||||
|
r#"SELECT * FROM 'service_token_role_membership' WHERE token_id = $1"#,
|
||||||
|
token.id).fetch_all(pool).await;
|
||||||
|
result.unwrap_or(Vec::new()).iter().map(|r| r.role_name.to_string()).collect()
|
||||||
|
}
|
||||||
|
|
||||||
pub fn token_auth_router(pool: DbPool) -> Router<DbPool> {
|
pub fn token_auth_router(pool: DbPool) -> Router<DbPool> {
|
||||||
Router::new()
|
Router::new()
|
||||||
.route("/lookup", post(post_lookup))
|
.route("/lookup", post(post_lookup))
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ use log::*;
|
||||||
use std::{env, net::SocketAddr, str::FromStr};
|
use std::{env, net::SocketAddr, str::FromStr};
|
||||||
use storage::DbPool;
|
use storage::DbPool;
|
||||||
use tokio::{net::TcpListener, signal};
|
use tokio::{net::TcpListener, signal};
|
||||||
|
use crate::auth::auth_extractor::AuthInfo;
|
||||||
use crate::common::HttpError;
|
use crate::common::HttpError;
|
||||||
|
|
||||||
mod auth;
|
mod auth;
|
||||||
|
|
@ -118,7 +118,7 @@ async fn fallback_route_unknown(req: Request) -> Response {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// basic handler that responds with a static string
|
/// basic handler that responds with a static string
|
||||||
async fn root() -> &'static str {
|
async fn root(test: AuthInfo) -> &'static str {
|
||||||
info!("Hello world");
|
println!("Hello world, {test:?}");
|
||||||
"Hello, World!"
|
"Hello, World!"
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue