Skip to content

Commit b45a02d

Browse files
feat: We can separate pilot refresh tokens and users refresh tokens
1 parent 89b88f6 commit b45a02d

File tree

21 files changed

+134
-282
lines changed

21 files changed

+134
-282
lines changed

diracx-client/src/diracx/client/_generated/aio/operations/_operations.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2504,10 +2504,6 @@ async def clear_pilots(self, *, age_in_days: int, delete_only_aborted: bool = Fa
25042504
return cls(pipeline_response, None, {}) # type: ignore
25052505

25062506
@overload
2507-
<<<<<<< HEAD
2508-
async def add_jobs_to_pilot(
2509-
self, body: _models.BodyPilotsAddJobsToPilot, *, content_type: str = "application/json", **kwargs: Any
2510-
=======
25112507
async def create_pilot_secrets(
25122508
self, body: _models.BodyPilotsCreatePilotSecrets, *, content_type: str = "application/json", **kwargs: Any
25132509
) -> List[_models.PilotSecretsInfo]:
@@ -2702,9 +2698,8 @@ async def update_secrets_constraints(
27022698
return cls(pipeline_response, None, {}) # type: ignore
27032699

27042700
@overload
2705-
async def associate_pilot_with_jobs(
2706-
self, body: _models.BodyPilotsAssociatePilotWithJobs, *, content_type: str = "application/json", **kwargs: Any
2707-
>>>>>>> 42f29a35 (feat: Add pilot auth)
2701+
async def add_jobs_to_pilot(
2702+
self, body: _models.BodyPilotsAddJobsToPilot, *, content_type: str = "application/json", **kwargs: Any
27082703
) -> None:
27092704
"""Add Jobs To Pilot.
27102705

diracx-client/src/diracx/client/_generated/models/__init__.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,9 @@
1616
BodyAuthGetOidcTokenGrantType,
1717
BodyPilotsAddJobsToPilot,
1818
BodyPilotsAddPilotStamps,
19-
<<<<<<< HEAD
20-
=======
21-
BodyPilotsAssociatePilotWithJobs,
2219
BodyPilotsCreatePilotSecrets,
2320
BodyPilotsPilotLogin,
2421
BodyPilotsRefreshPilotTokens,
25-
>>>>>>> 42f29a35 (feat: Add pilot auth)
2622
BodyPilotsUpdatePilotFields,
2723
GroupInfo,
2824
HTTPValidationError,
@@ -79,13 +75,9 @@
7975
"BodyAuthGetOidcTokenGrantType",
8076
"BodyPilotsAddJobsToPilot",
8177
"BodyPilotsAddPilotStamps",
82-
<<<<<<< HEAD
83-
=======
84-
"BodyPilotsAssociatePilotWithJobs",
8578
"BodyPilotsCreatePilotSecrets",
8679
"BodyPilotsPilotLogin",
8780
"BodyPilotsRefreshPilotTokens",
88-
>>>>>>> 42f29a35 (feat: Add pilot auth)
8981
"BodyPilotsUpdatePilotFields",
9082
"GroupInfo",
9183
"HTTPValidationError",

diracx-client/src/diracx/client/_generated/models/_models.py

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -194,41 +194,6 @@ def __init__(
194194
self.pilot_secret_use_count_max = pilot_secret_use_count_max
195195

196196

197-
<<<<<<< HEAD
198-
=======
199-
class BodyPilotsAssociatePilotWithJobs(_serialization.Model):
200-
"""Body_pilots_associate_pilot_with_jobs.
201-
202-
All required parameters must be populated in order to send to server.
203-
204-
:ivar pilot_stamp: The stamp of the pilot. Required.
205-
:vartype pilot_stamp: str
206-
:ivar pilot_jobs_ids: The jobs we want to add to the pilot. Required.
207-
:vartype pilot_jobs_ids: list[int]
208-
"""
209-
210-
_validation = {
211-
"pilot_stamp": {"required": True},
212-
"pilot_jobs_ids": {"required": True},
213-
}
214-
215-
_attribute_map = {
216-
"pilot_stamp": {"key": "pilot_stamp", "type": "str"},
217-
"pilot_jobs_ids": {"key": "pilot_jobs_ids", "type": "[int]"},
218-
}
219-
220-
def __init__(self, *, pilot_stamp: str, pilot_jobs_ids: List[int], **kwargs: Any) -> None:
221-
"""
222-
:keyword pilot_stamp: The stamp of the pilot. Required.
223-
:paramtype pilot_stamp: str
224-
:keyword pilot_jobs_ids: The jobs we want to add to the pilot. Required.
225-
:paramtype pilot_jobs_ids: list[int]
226-
"""
227-
super().__init__(**kwargs)
228-
self.pilot_stamp = pilot_stamp
229-
self.pilot_jobs_ids = pilot_jobs_ids
230-
231-
232197
class BodyPilotsCreatePilotSecrets(_serialization.Model):
233198
"""Body_pilots_create_pilot_secrets.
234199
@@ -335,7 +300,6 @@ def __init__(self, *, refresh_token: str, pilot_stamp: str, **kwargs: Any) -> No
335300
self.pilot_stamp = pilot_stamp
336301

337302

338-
>>>>>>> 42f29a35 (feat: Add pilot auth)
339303
class BodyPilotsUpdatePilotFields(_serialization.Model):
340304
"""Body_pilots_update_pilot_fields.
341305

diracx-client/src/diracx/client/_generated/operations/_operations.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -653,9 +653,6 @@ def build_pilots_clear_pilots_request(
653653
return HttpRequest(method="DELETE", url=_url, params=_params, **kwargs)
654654

655655

656-
<<<<<<< HEAD
657-
def build_pilots_add_jobs_to_pilot_request(**kwargs: Any) -> HttpRequest:
658-
=======
659656
def build_pilots_create_pilot_secrets_request(**kwargs: Any) -> HttpRequest: # pylint: disable=name-too-long
660657
_headers = case_insensitive_dict(kwargs.pop("headers", {}) or {})
661658

@@ -687,8 +684,7 @@ def build_pilots_update_secrets_constraints_request(**kwargs: Any) -> HttpReques
687684
return HttpRequest(method="PATCH", url=_url, headers=_headers, **kwargs)
688685

689686

690-
def build_pilots_associate_pilot_with_jobs_request(**kwargs: Any) -> HttpRequest: # pylint: disable=name-too-long
691-
>>>>>>> 42f29a35 (feat: Add pilot auth)
687+
def build_pilots_add_jobs_to_pilot_request(**kwargs: Any) -> HttpRequest:
692688
_headers = case_insensitive_dict(kwargs.pop("headers", {}) or {})
693689

694690
content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None))
@@ -3204,10 +3200,6 @@ def clear_pilots( # pylint: disable=inconsistent-return-statements
32043200
return cls(pipeline_response, None, {}) # type: ignore
32053201

32063202
@overload
3207-
<<<<<<< HEAD
3208-
def add_jobs_to_pilot(
3209-
self, body: _models.BodyPilotsAddJobsToPilot, *, content_type: str = "application/json", **kwargs: Any
3210-
=======
32113203
def create_pilot_secrets(
32123204
self, body: _models.BodyPilotsCreatePilotSecrets, *, content_type: str = "application/json", **kwargs: Any
32133205
) -> List[_models.PilotSecretsInfo]:
@@ -3402,9 +3394,8 @@ def update_secrets_constraints( # pylint: disable=inconsistent-return-statement
34023394
return cls(pipeline_response, None, {}) # type: ignore
34033395

34043396
@overload
3405-
def associate_pilot_with_jobs(
3406-
self, body: _models.BodyPilotsAssociatePilotWithJobs, *, content_type: str = "application/json", **kwargs: Any
3407-
>>>>>>> 42f29a35 (feat: Add pilot auth)
3397+
def add_jobs_to_pilot(
3398+
self, body: _models.BodyPilotsAddJobsToPilot, *, content_type: str = "application/json", **kwargs: Any
34083399
) -> None:
34093400
"""Add Jobs To Pilot.
34103401

diracx-core/src/diracx/core/exceptions.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,9 @@ def __init__(self, job_id, detail: str = ""):
9999
)
100100

101101

102+
class BadTokenError(DiracError): ...
103+
104+
102105
class NotReadyError(DiracError):
103106
"""Tried to access a value which is asynchronously loaded but not yet available."""
104107

diracx-core/src/diracx/core/models.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import uuid as std_uuid
99
from datetime import datetime
10-
from enum import StrEnum
10+
from enum import StrEnum, auto
1111
from typing import Any, Literal, Optional
1212

1313
from pydantic import BaseModel, Field, GetCoreSchemaHandler, GetJsonSchemaHandler
@@ -319,6 +319,7 @@ class PilotFieldsMapping(BaseModel, extra="forbid"):
319319
AccountingSent: Optional[bool] = None
320320
CurrentJobID: Optional[int] = None
321321

322+
322323
class PilotStatus(StrEnum):
323324
#: The pilot has been generated and is transferred to a remote site:
324325
SUBMITTED = "Submitted"
@@ -337,6 +338,7 @@ class PilotStatus(StrEnum):
337338
#: Cannot get information about the pilot status:
338339
UNKNOWN = "Unknown"
339340

341+
340342
class PilotSecretConstraints(TypedDict, total=False):
341343
VOs: list[str] # Authorize only a list of VOs
342344
PilotStamps: list[str] # Authorize only a list of stamps
@@ -345,6 +347,13 @@ class PilotSecretConstraints(TypedDict, total=False):
345347
# We can add constraints here
346348

347349

350+
class TokenType(StrEnum):
351+
# Pilot token
352+
PILOT_TOKEN = auto()
353+
# User token
354+
USER_TOKEN = auto()
355+
356+
348357
class PilotSecretsInfo(BaseModel):
349358
pilot_secret: str
350359
pilot_secret_expires_in: int

diracx-db/src/diracx/db/sql/pilots/db.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
)
1717
from diracx.core.models import (
1818
PilotFieldsMapping,
19-
PilotStatus,
2019
PilotSecretConstraints,
20+
PilotStatus,
2121
SearchSpec,
2222
SortSpec,
2323
)

diracx-db/tests/pilots/test_pilot_management.py

Lines changed: 0 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,7 @@
1111
from diracx.core.models import (
1212
PilotFieldsMapping,
1313
PilotStatus,
14-
ScalarSearchOperator,
15-
ScalarSearchSpec,
16-
VectorSearchOperator,
17-
VectorSearchSpec,
1814
)
19-
2015
from diracx.db.sql.pilots.db import PilotAgentsDB
2116

2217
from .util import (
@@ -40,117 +35,6 @@ async def pilot_db(tmp_path):
4035
yield agents_db
4136

4237

43-
async def get_pilot_jobs_ids_by_pilot_id(
44-
pilot_db: PilotAgentsDB, pilot_id: int
45-
) -> list[int]:
46-
_, jobs = await pilot_db.search_pilot_to_job_mapping(
47-
parameters=["JobID"],
48-
search=[
49-
ScalarSearchSpec(
50-
parameter="PilotID",
51-
operator=ScalarSearchOperator.EQUAL,
52-
value=pilot_id,
53-
)
54-
],
55-
sorts=[],
56-
distinct=True,
57-
per_page=10000,
58-
)
59-
60-
return [job["JobID"] for job in jobs]
61-
62-
63-
async def get_pilots_by_stamp_bulk(
64-
pilot_db: PilotAgentsDB, pilot_stamps: list[str], parameters: list[str] = []
65-
) -> list[dict[Any, Any]]:
66-
_, pilots = await pilot_db.search_pilots(
67-
parameters=parameters,
68-
search=[
69-
VectorSearchSpec(
70-
parameter="PilotStamp",
71-
operator=VectorSearchOperator.IN,
72-
values=pilot_stamps,
73-
)
74-
],
75-
sorts=[],
76-
distinct=True,
77-
per_page=1000,
78-
)
79-
80-
# Custom handling, to see which pilot_stamp does not exist (if so, say which one)
81-
found_keys = {row["PilotStamp"] for row in pilots}
82-
missing = set(pilot_stamps) - found_keys
83-
84-
if missing:
85-
raise PilotNotFoundError(
86-
data={"pilot_stamp": str(missing)},
87-
detail=str(missing),
88-
non_existing_pilots=missing,
89-
)
90-
91-
return pilots
92-
93-
94-
@pytest.fixture
95-
async def add_stamps(pilot_db):
96-
async def _add_stamps(start_n=0):
97-
async with pilot_db as db:
98-
# Add pilots
99-
refs = [f"ref_{i}" for i in range(start_n, start_n + N)]
100-
stamps = [f"stamp_{i}" for i in range(start_n, start_n + N)]
101-
pilot_references = dict(zip(stamps, refs))
102-
103-
vo = MAIN_VO
104-
105-
await db.add_pilots_bulk(
106-
stamps, vo, grid_type="DIRAC", pilot_references=pilot_references
107-
)
108-
109-
pilots = await get_pilots_by_stamp_bulk(db, stamps)
110-
111-
return pilots
112-
113-
return _add_stamps
114-
115-
116-
@pytest.fixture
117-
async def create_old_pilots_environment(pilot_db, create_timed_pilots):
118-
non_aborted_recent = await create_timed_pilots(
119-
datetime(2025, 1, 1, tzinfo=timezone.utc), False, N
120-
)
121-
aborted_recent = await create_timed_pilots(
122-
datetime(2025, 1, 1, tzinfo=timezone.utc), True, 2 * N
123-
)
124-
125-
aborted_very_old = await create_timed_pilots(
126-
datetime(2003, 3, 10, tzinfo=timezone.utc), True, 3 * N
127-
)
128-
non_aborted_very_old = await create_timed_pilots(
129-
datetime(2003, 3, 10, tzinfo=timezone.utc), False, 4 * N
130-
)
131-
132-
pilot_number = 4 * N
133-
134-
assert pilot_number == (
135-
len(non_aborted_recent)
136-
+ len(aborted_recent)
137-
+ len(aborted_very_old)
138-
+ len(non_aborted_very_old)
139-
)
140-
141-
# Phase 0. Verify that we have the right environment
142-
async with pilot_db as pilot_db:
143-
# Ensure that we can get every pilot (only get first of each group)
144-
await get_pilots_by_stamp_bulk(pilot_db, [non_aborted_recent[0]["PilotStamp"]])
145-
await get_pilots_by_stamp_bulk(pilot_db, [aborted_recent[0]["PilotStamp"]])
146-
await get_pilots_by_stamp_bulk(pilot_db, [aborted_very_old[0]["PilotStamp"]])
147-
await get_pilots_by_stamp_bulk(
148-
pilot_db, [non_aborted_very_old[0]["PilotStamp"]]
149-
)
150-
151-
return non_aborted_recent, aborted_recent, non_aborted_very_old, aborted_very_old
152-
153-
15438
@pytest.mark.asyncio
15539
async def test_insert_and_select(pilot_db: PilotAgentsDB):
15640
async with pilot_db as pilot_db:

diracx-logic/src/diracx/logic/auth/token.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
BaseTokenPayload,
2424
GrantType,
2525
RefreshTokenPayload,
26+
TokenType,
2627
)
2728
from diracx.core.properties import SecurityProperty
2829
from diracx.core.settings import AuthSettings
@@ -159,12 +160,15 @@ async def get_oidc_token_info_from_authorization_flow(
159160

160161

161162
async def get_token_info_from_refresh_flow(
162-
refresh_token: str, auth_db: AuthDB, settings: AuthSettings
163+
refresh_token: str,
164+
auth_db: AuthDB,
165+
settings: AuthSettings,
166+
token_type: TokenType = TokenType.USER_TOKEN,
163167
) -> tuple[dict, str, bool, float, bool]:
164168
"""Get OIDC token information from the refresh token DB and check few parameters before returning it."""
165169
# Decode the refresh token to get the JWT ID
166170
jti, exp, legacy_exchange = await verify_dirac_refresh_token(
167-
refresh_token, settings
171+
refresh_token, settings, token_type
168172
)
169173

170174
# Get some useful user information from the refresh token entry in the DB

diracx-logic/src/diracx/logic/auth/utils.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,13 @@
1515
from uuid_utils import UUID
1616

1717
from diracx.core.config.schema import Config
18-
from diracx.core.exceptions import AuthorizationError, IAMClientError, IAMServerError
19-
from diracx.core.models import GrantType
18+
from diracx.core.exceptions import (
19+
AuthorizationError,
20+
BadTokenError,
21+
IAMClientError,
22+
IAMServerError,
23+
)
24+
from diracx.core.models import GrantType, TokenType
2025
from diracx.core.properties import SecurityProperty
2126
from diracx.core.settings import AuthSettings
2227

@@ -208,6 +213,7 @@ def read_token(
208213
async def verify_dirac_refresh_token(
209214
refresh_token: str,
210215
settings: AuthSettings,
216+
token_type: TokenType = TokenType.USER_TOKEN,
211217
) -> tuple[UUID, float, bool]:
212218
"""Verify dirac user token and return a UserInfo class
213219
Used for each API endpoint.
@@ -216,6 +222,12 @@ async def verify_dirac_refresh_token(
216222
refresh_token, settings.token_keystore.jwks, settings.token_allowed_algorithms
217223
)
218224

225+
if token_type == TokenType.USER_TOKEN and "dirac_policies" not in claims:
226+
raise BadTokenError("This is not a user token.")
227+
228+
if token_type == TokenType.PILOT_TOKEN and "dirac_policies" in claims:
229+
raise BadTokenError("This is not a pilot token.")
230+
219231
return (
220232
UUID(claims["jti"]),
221233
float(claims["exp"]),

0 commit comments

Comments
 (0)