Skip to content

Commit e50babe

Browse files
committed
Add http_exception_error_handler for 401 errors
1 parent d6ac57f commit e50babe

File tree

1 file changed

+36
-6
lines changed

1 file changed

+36
-6
lines changed

diracx-routers/src/diracx/routers/factory.py

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,31 +2,33 @@
22

33
from __future__ import annotations
44

5+
import base64
56
import inspect
67
import logging
78
import os
9+
import re
810
from collections.abc import AsyncGenerator, Awaitable, Callable, Iterable, Sequence
911
from functools import partial
1012
from http import HTTPStatus
1113
from importlib.metadata import EntryPoint, EntryPoints, entry_points
1214
from logging import Formatter, StreamHandler
13-
from typing import (
14-
Any,
15-
TypeVar,
16-
cast,
17-
)
15+
from typing import Any, TypeVar, cast
1816

1917
import dotenv
2018
from cachetools import TTLCache
2119
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, status
2220
from fastapi.dependencies.models import Dependant
23-
from fastapi.exception_handlers import request_validation_exception_handler
21+
from fastapi.exception_handlers import (
22+
http_exception_handler,
23+
request_validation_exception_handler,
24+
)
2425
from fastapi.exceptions import RequestValidationError
2526
from fastapi.middleware.cors import CORSMiddleware
2627
from fastapi.responses import JSONResponse, Response
2728
from fastapi.routing import APIRoute
2829
from packaging.version import InvalidVersion, parse
2930
from pydantic import TypeAdapter
31+
from starlette.exceptions import HTTPException as StarletteHTTPException
3032
from starlette.middleware.base import BaseHTTPMiddleware
3133
from uvicorn.logging import AccessFormatter, DefaultFormatter
3234

@@ -49,6 +51,7 @@
4951

5052

5153
logger = logging.getLogger(__name__)
54+
logger_401 = logger.getChild("debug.401.errors")
5255
logger_422 = logger.getChild("debug.422.errors")
5356

5457

@@ -299,6 +302,9 @@ def create_app_inner(
299302
app.add_exception_handler(
300303
RequestValidationError, cast(handler_signature, validation_error_handler)
301304
)
305+
app.add_exception_handler(
306+
StarletteHTTPException, cast(handler_signature, http_exception_error_handler)
307+
)
302308

303309
# TODO: remove the CORSMiddleware once we figure out how to launch
304310
# diracx and diracx-web under the same origin
@@ -409,6 +415,30 @@ def route_unavailable_error_hander(request: Request, exc: DBUnavailableError):
409415
)
410416

411417

418+
async def http_exception_error_handler(request: Request, exc: StarletteHTTPException):
419+
if exc.status_code == status.HTTP_401_UNAUTHORIZED:
420+
header_info = "Unknown"
421+
if auth_header := request.headers.get("Authorization"):
422+
if match := re.fullmatch(r"Bearer (.+)", auth_header):
423+
try:
424+
raw_token = match.group(1).split(".")[1]
425+
padding = "=" * (-len(raw_token) % 4)
426+
header_info = base64.urlsafe_b64decode(raw_token + padding).decode(
427+
"utf-8"
428+
)
429+
except Exception as e:
430+
header_info = f"Error decoding token: {e}"
431+
logger_401.warning(
432+
"Got 401 error: %s in %s %s with header_info=%r body %r",
433+
exc.detail,
434+
request.method,
435+
request.url,
436+
header_info,
437+
await request.body(),
438+
)
439+
return await http_exception_handler(request, exc)
440+
441+
412442
async def validation_error_handler(request: Request, exc: RequestValidationError):
413443
logger_422.warning(
414444
"Got validation error: %s in %s %s with body %r",

0 commit comments

Comments
 (0)