//! Admin connection extractor with RLS context. //! //! Provides database connections that set RLS context when a user is authenticated. //! Used by admin API handlers to ensure write operations respect RLS policies. //! //! - In owner app: No session, uses plain connection (RLS bypassed by chattyness_owner role) //! - In user app: Session exists, sets current_user_id() for RLS enforcement use axum::{ Json, extract::FromRequestParts, http::{Request, StatusCode, request::Parts}, response::{IntoResponse, Response}, }; use sqlx::{PgPool, Postgres, pool::PoolConnection, postgres::PgConnection}; use std::{ future::Future, ops::{Deref, DerefMut}, pin::Pin, sync::Arc, task::{Context, Poll}, }; use tokio::sync::{Mutex, MutexGuard}; use tower::{Layer, Service}; use tower_sessions::Session; use uuid::Uuid; use super::{ADMIN_SESSION_STAFF_ID_KEY, SESSION_USER_ID_KEY}; use chattyness_error::ErrorResponse; // ============================================================================= // Admin Connection Wrapper // ============================================================================= struct AdminConnectionInner { conn: Option>, pool: PgPool, had_user_context: bool, } impl Drop for AdminConnectionInner { fn drop(&mut self) { if let Some(mut conn) = self.conn.take() { // Only clear context if we set it if self.had_user_context { let pool = self.pool.clone(); tokio::spawn(async move { let _ = sqlx::query("SELECT public.set_current_user_id(NULL)") .execute(&mut *conn) .await; drop(conn); drop(pool); }); } } } } /// A database connection with optional RLS user context set. #[derive(Clone)] pub struct AdminConnection { inner: Arc>, } impl AdminConnection { fn new(conn: PoolConnection, pool: PgPool, had_user_context: bool) -> Self { Self { inner: Arc::new(Mutex::new(AdminConnectionInner { conn: Some(conn), pool, had_user_context, })), } } /// Acquire exclusive access to the admin connection. pub async fn acquire(&self) -> AdminConnGuard<'_> { AdminConnGuard { guard: self.inner.lock().await, } } } /// A guard providing mutable access to the admin database connection. pub struct AdminConnGuard<'a> { guard: MutexGuard<'a, AdminConnectionInner>, } impl Deref for AdminConnGuard<'_> { type Target = PgConnection; fn deref(&self) -> &Self::Target { self.guard .conn .as_ref() .expect("AdminConnection already consumed") .deref() } } impl DerefMut for AdminConnGuard<'_> { fn deref_mut(&mut self) -> &mut Self::Target { self.guard .conn .as_mut() .expect("AdminConnection already consumed") .deref_mut() } } // ============================================================================= // Admin Connection Extractor // ============================================================================= /// Extractor for an admin database connection with RLS context. /// /// Usage in handlers: /// ```ignore /// pub async fn create_scene( /// admin_conn: AdminConn, /// Json(req): Json, /// ) -> Result, AppError> { /// let mut conn = admin_conn.0; /// let mut guard = conn.acquire().await; /// scenes::create_scene(&mut *guard, ...).await /// } /// ``` pub struct AdminConn(pub AdminConnection); impl FromRequestParts for AdminConn where S: Send + Sync, { type Rejection = AdminConnError; async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { parts .extensions .remove::() .map(AdminConn) .ok_or(AdminConnError::NoConnection) } } impl Deref for AdminConn { type Target = AdminConnection; fn deref(&self) -> &Self::Target { &self.0 } } impl DerefMut for AdminConn { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.0 } } /// Errors related to admin connection handling. #[derive(Debug)] pub enum AdminConnError { NoConnection, DatabaseError(String), } impl IntoResponse for AdminConnError { fn into_response(self) -> Response { let (status, message) = match self { AdminConnError::NoConnection => ( StatusCode::INTERNAL_SERVER_ERROR, "Admin connection not available - is AdminConnLayer middleware configured?", ), AdminConnError::DatabaseError(msg) => ( StatusCode::INTERNAL_SERVER_ERROR, msg.leak() as &'static str, ), }; let body = ErrorResponse { error: message.to_string(), code: Some("ADMIN_CONN_ERROR".to_string()), }; (status, Json(body)).into_response() } } // ============================================================================= // Admin Connection Middleware Layer // ============================================================================= /// Layer that provides admin database connections with RLS context per request. /// /// This middleware: /// 1. Checks for user_id in session (staff_id or user_id) /// 2. Acquires a connection from the pool /// 3. If user_id exists, calls `set_current_user_id($1)` for RLS /// 4. Inserts the connection into request extensions /// /// Usage: /// ```ignore /// let app = Router::new() /// .nest("/api/admin", admin_api_router()) /// .layer(AdminConnLayer::new(pool.clone())) /// .layer(session_layer); /// ``` #[derive(Clone)] pub struct AdminConnLayer { pool: PgPool, } impl AdminConnLayer { pub fn new(pool: PgPool) -> Self { Self { pool } } } impl Layer for AdminConnLayer { type Service = AdminConnMiddleware; fn layer(&self, inner: S) -> Self::Service { AdminConnMiddleware { inner, pool: self.pool.clone(), } } } /// Middleware that sets up admin connections with RLS context per request. #[derive(Clone)] pub struct AdminConnMiddleware { inner: S, pool: PgPool, } impl Service> for AdminConnMiddleware where S: Service, Response = Response> + Clone + Send + 'static, S::Future: Send, B: Send + 'static, { type Response = Response; type Error = S::Error; type Future = Pin> + Send>>; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, mut request: Request) -> Self::Future { let pool = self.pool.clone(); let mut inner = self.inner.clone(); let session = request.extensions().get::().cloned(); Box::pin(async move { let user_id = get_admin_user_id(session).await; match acquire_admin_connection(&pool, user_id).await { Ok(admin_conn) => { request.extensions_mut().insert(admin_conn); inner.call(request).await } Err(e) => { tracing::error!("Failed to acquire admin connection: {}", e); Ok(AdminConnError::DatabaseError(e.to_string()).into_response()) } } }) } } /// Get user ID from session for RLS context. /// /// Checks in order: /// 1. staff_id - server staff member /// 2. user_id - realm admin user /// /// Returns None if no session or no user ID (owner app context). async fn get_admin_user_id(session: Option) -> Option { let Some(session) = session else { return None; }; // Try staff_id first (server staff) if let Ok(Some(staff_id)) = session.get::(ADMIN_SESSION_STAFF_ID_KEY).await { return Some(staff_id); } // Try user_id (realm admin) if let Ok(Some(user_id)) = session.get::(SESSION_USER_ID_KEY).await { return Some(user_id); } None } /// Acquire a database connection and set RLS context if user_id is provided. async fn acquire_admin_connection( pool: &PgPool, user_id: Option, ) -> Result { let mut conn = pool.acquire().await?; let had_user_context = user_id.is_some(); if let Some(user_id) = user_id { sqlx::query("SELECT public.set_current_user_id($1)") .bind(user_id) .execute(&mut *conn) .await?; } Ok(AdminConnection::new(conn, pool.clone(), had_user_context)) }