Skip to content

Commit 394ac96

Browse files
feat: We can separate pilot refresh tokens and users refresh tokens
1 parent c264017 commit 394ac96

File tree

13 files changed

+124
-41
lines changed

13 files changed

+124
-41
lines changed

diracx-core/src/diracx/core/exceptions.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,9 @@ def __init__(self, job_id, detail: str = ""):
9999
)
100100

101101

102+
class BadTokenError(DiracError): ...
103+
104+
102105
class NotReadyError(DiracError):
103106
"""Tried to access a value which is asynchronously loaded but not yet available."""
104107

diracx-core/src/diracx/core/models.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import uuid as std_uuid
99
from datetime import datetime
10-
from enum import StrEnum
10+
from enum import StrEnum, auto
1111
from typing import Any, Literal, Optional
1212

1313
from pydantic import BaseModel, Field, GetCoreSchemaHandler, GetJsonSchemaHandler
@@ -319,6 +319,7 @@ class PilotFieldsMapping(BaseModel, extra="forbid"):
319319
AccountingSent: Optional[bool] = None
320320
CurrentJobID: Optional[int] = None
321321

322+
322323
class PilotStatus(StrEnum):
323324
#: The pilot has been generated and is transferred to a remote site:
324325
SUBMITTED = "Submitted"
@@ -337,6 +338,7 @@ class PilotStatus(StrEnum):
337338
#: Cannot get information about the pilot status:
338339
UNKNOWN = "Unknown"
339340

341+
340342
class PilotSecretConstraints(TypedDict, total=False):
341343
VOs: list[str] # Authorize only a list of VOs
342344
PilotStamps: list[str] # Authorize only a list of stamps
@@ -345,6 +347,13 @@ class PilotSecretConstraints(TypedDict, total=False):
345347
# We can add constraints here
346348

347349

350+
class TokenType(StrEnum):
351+
# Pilot token
352+
PILOT_TOKEN = auto()
353+
# User token
354+
USER_TOKEN = auto()
355+
356+
348357
class PilotSecretsInfo(BaseModel):
349358
pilot_secret: str
350359
pilot_secret_expires_in: int

diracx-db/src/diracx/db/sql/pilots/db.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
)
1717
from diracx.core.models import (
1818
PilotFieldsMapping,
19-
PilotStatus,
2019
PilotSecretConstraints,
20+
PilotStatus,
2121
SearchSpec,
2222
SortSpec,
2323
)

diracx-db/tests/pilots/test_pilot_management.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
PilotFieldsMapping,
1212
PilotStatus,
1313
)
14-
1514
from diracx.db.sql.pilots.db import PilotAgentsDB
1615

1716

diracx-logic/src/diracx/logic/auth/token.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
BaseTokenPayload,
2424
GrantType,
2525
RefreshTokenPayload,
26+
TokenType,
2627
)
2728
from diracx.core.properties import SecurityProperty
2829
from diracx.core.settings import AuthSettings
@@ -159,12 +160,15 @@ async def get_oidc_token_info_from_authorization_flow(
159160

160161

161162
async def get_token_info_from_refresh_flow(
162-
refresh_token: str, auth_db: AuthDB, settings: AuthSettings
163+
refresh_token: str,
164+
auth_db: AuthDB,
165+
settings: AuthSettings,
166+
token_type: TokenType = TokenType.USER_TOKEN,
163167
) -> tuple[dict, str, bool, float, bool]:
164168
"""Get OIDC token information from the refresh token DB and check few parameters before returning it."""
165169
# Decode the refresh token to get the JWT ID
166170
jti, exp, legacy_exchange = await verify_dirac_refresh_token(
167-
refresh_token, settings
171+
refresh_token, settings, token_type
168172
)
169173

170174
# Get some useful user information from the refresh token entry in the DB

diracx-logic/src/diracx/logic/auth/utils.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,13 @@
1515
from uuid_utils import UUID
1616

1717
from diracx.core.config.schema import Config
18-
from diracx.core.exceptions import AuthorizationError, IAMClientError, IAMServerError
19-
from diracx.core.models import GrantType
18+
from diracx.core.exceptions import (
19+
AuthorizationError,
20+
BadTokenError,
21+
IAMClientError,
22+
IAMServerError,
23+
)
24+
from diracx.core.models import GrantType, TokenType
2025
from diracx.core.properties import SecurityProperty
2126
from diracx.core.settings import AuthSettings
2227

@@ -208,6 +213,7 @@ def read_token(
208213
async def verify_dirac_refresh_token(
209214
refresh_token: str,
210215
settings: AuthSettings,
216+
token_type: TokenType = TokenType.USER_TOKEN,
211217
) -> tuple[UUID, float, bool]:
212218
"""Verify dirac user token and return a UserInfo class
213219
Used for each API endpoint.
@@ -216,6 +222,12 @@ async def verify_dirac_refresh_token(
216222
refresh_token, settings.token_keystore.jwks, settings.token_allowed_algorithms
217223
)
218224

225+
if token_type == TokenType.USER_TOKEN and "dirac_policies" not in claims:
226+
raise BadTokenError("This is not a user token.")
227+
228+
if token_type == TokenType.PILOT_TOKEN and "dirac_policies" in claims:
229+
raise BadTokenError("This is not a pilot token.")
230+
219231
return (
220232
UUID(claims["jti"]),
221233
float(claims["exp"]),

diracx-logic/src/diracx/logic/pilots/auth.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
PilotSecretConstraints,
1717
PilotSecretsInfo,
1818
TokenResponse,
19+
TokenType,
1920
)
2021
from diracx.core.settings import AuthSettings
2122
from diracx.core.utils import extract_timestamp_from_uuid7, recursive_dict_merge
@@ -359,7 +360,10 @@ async def generate_pilot_tokens(
359360
_,
360361
include_refresh_token,
361362
) = await get_token_info_from_refresh_flow(
362-
refresh_token=refresh_token, auth_db=auth_db, settings=settings
363+
refresh_token=refresh_token,
364+
auth_db=auth_db,
365+
settings=settings,
366+
token_type=TokenType.PILOT_TOKEN,
363367
)
364368

365369
sub = f"{vo}:{pilot_info['sub']}"

diracx-routers/src/diracx/routers/auth/token.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from joserfc.errors import JoseError
1111

1212
from diracx.core.exceptions import (
13+
BadTokenError,
1314
DiracHttpResponseError,
1415
InvalidCredentialsError,
1516
PendingAuthorizationError,
@@ -153,7 +154,7 @@ async def get_oidc_token(
153154
detail=str(e),
154155
headers={"WWW-Authenticate": "Bearer"},
155156
) from e
156-
except PermissionError as e:
157+
except (BadTokenError, PermissionError) as e:
157158
raise HTTPException(
158159
status_code=status.HTTP_403_FORBIDDEN,
159160
detail=str(e),

diracx-routers/src/diracx/routers/pilots/auth.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from diracx.core.exceptions import (
1010
BadPilotCredentialsError,
11+
BadTokenError,
1112
InvalidCredentialsError,
1213
PilotNotFoundError,
1314
SecretHasExpiredError,
@@ -28,8 +29,8 @@
2829
router = DiracxRouter(require_auth=False)
2930

3031

31-
@router.post("/token")
32-
async def pilot_login(
32+
@router.post("/secret-exchange")
33+
async def perform_secret_exchange(
3334
pilot_db: PilotAgentsDB,
3435
auth_db: AuthDB,
3536
pilot_stamp: Annotated[str, Body(description="Stamp used by a pilot to login.")],
@@ -72,7 +73,7 @@ async def pilot_login(
7273
) from e
7374

7475

75-
@router.post("/refresh-token")
76+
@router.post("/token")
7677
async def refresh_pilot_tokens(
7778
auth_db: AuthDB,
7879
settings: AuthSettings,
@@ -82,7 +83,8 @@ async def refresh_pilot_tokens(
8283
],
8384
pilot_stamp: Annotated[str, Body(description="Pilot stamp")],
8485
) -> TokenResponse:
85-
"""Endpoint where a pilot can exchange a refresh token for a token."""
86+
"""Endpoint where *only* pilots can exchange a refresh token for a token."""
87+
# Refresh it
8688
try:
8789
return await refresh_pilot_token(
8890
pilot_stamp=pilot_stamp,
@@ -91,7 +93,7 @@ async def refresh_pilot_tokens(
9193
pilot_db=pilot_db,
9294
refresh_token=refresh_token,
9395
)
94-
except (InvalidCredentialsError, PilotNotFoundError) as e:
96+
except (InvalidCredentialsError, PilotNotFoundError, BadTokenError) as e:
9597
raise HTTPException(
9698
status_code=status.HTTP_401_UNAUTHORIZED, detail=str(e)
9799
) from e

diracx-routers/src/diracx/routers/utils/pilots.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ async def verify_dirac_pilot_access_token(
3030
settings: AuthSettings,
3131
authorization: Annotated[str | None, Header()] = None,
3232
) -> AuthorizedPilotInfo:
33-
"""Verify dirac user token and return a UserInfo class
33+
"""Verify dirac pilot token and return a AuthorizedPilotInfo class
3434
Used for each API endpoint.
3535
"""
3636
if not authorization:

0 commit comments

Comments
 (0)