|
2 | 2 |
|
3 | 3 | from __future__ import annotations
|
4 | 4 |
|
| 5 | +import base64 |
5 | 6 | import inspect
|
6 | 7 | import logging
|
7 | 8 | import os
|
| 9 | +import re |
8 | 10 | from collections.abc import AsyncGenerator, Awaitable, Callable, Iterable, Sequence
|
9 | 11 | from functools import partial
|
10 | 12 | from http import HTTPStatus
|
11 | 13 | from importlib.metadata import EntryPoint, EntryPoints, entry_points
|
12 | 14 | from logging import Formatter, StreamHandler
|
13 |
| -from typing import ( |
14 |
| - Any, |
15 |
| - TypeVar, |
16 |
| - cast, |
17 |
| -) |
| 15 | +from typing import Any, TypeVar, cast |
18 | 16 |
|
19 | 17 | import dotenv
|
20 | 18 | from cachetools import TTLCache
|
21 | 19 | from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, status
|
22 | 20 | from fastapi.dependencies.models import Dependant
|
23 |
| -from fastapi.exception_handlers import request_validation_exception_handler |
| 21 | +from fastapi.exception_handlers import ( |
| 22 | + http_exception_handler, |
| 23 | + request_validation_exception_handler, |
| 24 | +) |
24 | 25 | from fastapi.exceptions import RequestValidationError
|
25 | 26 | from fastapi.middleware.cors import CORSMiddleware
|
26 | 27 | from fastapi.responses import JSONResponse, Response
|
27 | 28 | from fastapi.routing import APIRoute
|
28 | 29 | from packaging.version import InvalidVersion, parse
|
29 | 30 | from pydantic import TypeAdapter
|
| 31 | +from starlette.exceptions import HTTPException as StarletteHTTPException |
30 | 32 | from starlette.middleware.base import BaseHTTPMiddleware
|
31 | 33 | from uvicorn.logging import AccessFormatter, DefaultFormatter
|
32 | 34 |
|
|
49 | 51 |
|
50 | 52 |
|
51 | 53 | logger = logging.getLogger(__name__)
|
| 54 | +logger_401 = logger.getChild("debug.401.errors") |
52 | 55 | logger_422 = logger.getChild("debug.422.errors")
|
53 | 56 |
|
54 | 57 |
|
@@ -299,6 +302,9 @@ def create_app_inner(
|
299 | 302 | app.add_exception_handler(
|
300 | 303 | RequestValidationError, cast(handler_signature, validation_error_handler)
|
301 | 304 | )
|
| 305 | + app.add_exception_handler( |
| 306 | + StarletteHTTPException, cast(handler_signature, http_exception_error_handler) |
| 307 | + ) |
302 | 308 |
|
303 | 309 | # TODO: remove the CORSMiddleware once we figure out how to launch
|
304 | 310 | # diracx and diracx-web under the same origin
|
@@ -409,6 +415,30 @@ def route_unavailable_error_hander(request: Request, exc: DBUnavailableError):
|
409 | 415 | )
|
410 | 416 |
|
411 | 417 |
|
| 418 | +async def http_exception_error_handler(request: Request, exc: StarletteHTTPException): |
| 419 | + if exc.status_code == status.HTTP_401_UNAUTHORIZED: |
| 420 | + header_info = "Unknown" |
| 421 | + if auth_header := request.headers.get("Authorization"): |
| 422 | + if match := re.fullmatch(r"Bearer (.+)", auth_header): |
| 423 | + try: |
| 424 | + raw_token = match.group(1).split(".")[1] |
| 425 | + padding = "=" * (-len(raw_token) % 4) |
| 426 | + header_info = base64.urlsafe_b64decode(raw_token + padding).decode( |
| 427 | + "utf-8" |
| 428 | + ) |
| 429 | + except Exception as e: |
| 430 | + header_info = f"Error decoding token: {e}" |
| 431 | + logger_401.warning( |
| 432 | + "Got 401 error: %s in %s %s with header_info=%r body %r", |
| 433 | + exc.detail, |
| 434 | + request.method, |
| 435 | + request.url, |
| 436 | + header_info, |
| 437 | + await request.body(), |
| 438 | + ) |
| 439 | + return await http_exception_handler(request, exc) |
| 440 | + |
| 441 | + |
412 | 442 | async def validation_error_handler(request: Request, exc: RequestValidationError):
|
413 | 443 | logger_422.warning(
|
414 | 444 | "Got validation error: %s in %s %s with body %r",
|
|
0 commit comments