Skip to content

Commit c2e2b8f

Browse files
author
Robin VAN DE MERGHEL
committed
feat: Split users and pilots
1 parent 4507b79 commit c2e2b8f

File tree

4 files changed

+194
-1
lines changed

4 files changed

+194
-1
lines changed

diracx-core/src/diracx/core/exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,3 +160,7 @@ class SecretHasExpiredError(DiracFormattedError):
160160

161161
class SecretAlreadyExistsError(DiracFormattedError):
    """Error whose message reports that a secret already exists."""

    # Message template; the single `%s` placeholder is filled in by the
    # DiracFormattedError base class.
    pattern = "Secret %s already exists"
163+
164+
165+
class PilotCantAccessJobError(DiracFormattedError):
    """Error whose message reports that a pilot cannot access some jobs."""

    # Message template; the single `%s` placeholder is filled in by the
    # DiracFormattedError base class.
    pattern = "Pilot can't access jobs. %s"

diracx-routers/src/diracx/routers/pilot_resources/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
)
1010

1111
from ..fastapi_classes import DiracxRouter
12+
from .jobs import router as jobs_router
1213
from .util import router as util_router
1314

1415
logger = logging.getLogger(__name__)
@@ -21,3 +22,4 @@
2122
)
2223

2324
router.include_router(util_router)
25+
router.include_router(jobs_router)
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
"""Duplicates from /jobs/status.
2+
3+
In the future, pilots will only have DiracX tokens, so they will eventually stop using /jobs/status.
4+
"""
5+
6+
from __future__ import annotations
7+
8+
from datetime import datetime
9+
from http import HTTPStatus
10+
from typing import Annotated, Any
11+
12+
from fastapi import Depends, HTTPException, status
13+
14+
from diracx.core.exceptions import JobNotFoundError, PilotCantAccessJobError
15+
from diracx.core.models import (
16+
HeartbeatData,
17+
JobCommand,
18+
JobStatusUpdate,
19+
SetJobStatusReturn,
20+
)
21+
from diracx.logic.jobs.status import add_heartbeat as add_heartbeat_bl
22+
from diracx.logic.jobs.status import get_job_commands as get_job_commands_bl
23+
from diracx.logic.jobs.status import (
24+
set_job_parameters_or_attributes as set_job_parameters_or_attributes_bl,
25+
)
26+
from diracx.logic.jobs.status import set_job_statuses as set_job_statuses_bl
27+
from diracx.routers.utils.pilots import (
28+
AuthorizedPilotInfo,
29+
verify_dirac_pilot_access_token,
30+
verify_that_pilot_can_access_jobs,
31+
)
32+
33+
from ..dependencies import (
34+
Config,
35+
JobDB,
36+
JobLoggingDB,
37+
JobParametersDB,
38+
PilotAgentsDB,
39+
TaskQueueDB,
40+
)
41+
from ..fastapi_classes import DiracxRouter
42+
43+
router = DiracxRouter()
44+
45+
46+
# Pilot-facing counterpart of the job status update endpoint.
#
# Flow: check that the authenticated pilot may act on every job ID present in
# `job_update`, then forward the update to the business logic.
#   - 403 if the access check raises PilotCantAccessJobError / JobNotFoundError
#   - 400 if the business logic raises ValueError
#   - 404 (with the dumped result as detail) if the result is not successful
#
# Documentation is deliberately kept in comments rather than a docstring so
# the endpoint keeps having no docstring.
@router.patch("/status")
async def pilot_set_job_statuses(
    job_update: dict[int, dict[datetime, JobStatusUpdate]],
    config: Config,
    job_db: JobDB,
    job_logging_db: JobLoggingDB,
    task_queue_db: TaskQueueDB,
    job_parameters_db: JobParametersDB,
    pilot_db: PilotAgentsDB,
    pilot_info: Annotated[
        AuthorizedPilotInfo, Depends(verify_dirac_pilot_access_token)
    ],
    force: bool = False,
) -> SetJobStatusReturn:
    # Endpoint only for DiracX pilots (with a pilot token)
    requested_job_ids = list(job_update)
    try:
        await verify_that_pilot_can_access_jobs(
            pilot_db=pilot_db,
            pilot_stamp=pilot_info.pilot_stamp,
            job_ids=requested_job_ids,
        )
    except (PilotCantAccessJobError, JobNotFoundError) as access_error:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Pilot can't access this job.",
        ) from access_error

    try:
        outcome = await set_job_statuses_bl(
            status_changes=job_update,
            config=config,
            job_db=job_db,
            job_logging_db=job_logging_db,
            task_queue_db=task_queue_db,
            job_parameters_db=job_parameters_db,
            force=force,
        )
    except ValueError as invalid_update:
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST,
            detail=str(invalid_update),
        ) from invalid_update

    if outcome.success:
        return outcome

    # Unsuccessful update: expose the full result to the client as the detail.
    raise HTTPException(
        status_code=HTTPStatus.NOT_FOUND,
        detail=outcome.model_dump(),
    )
95+
96+
97+
# NOTE(review): the docstring below is kept verbatim. It describes the return
# value as a mapping keyed by job ID, while the annotation is
# `list[JobCommand]` -- confirm which one is right.
@router.patch("/heartbeat")
async def pilot_add_heartbeat(
    data: dict[int, HeartbeatData],
    config: Config,
    job_db: JobDB,
    job_logging_db: JobLoggingDB,
    task_queue_db: TaskQueueDB,
    job_parameters_db: JobParametersDB,
    pilot_db: PilotAgentsDB,
    pilot_info: Annotated[
        AuthorizedPilotInfo, Depends(verify_dirac_pilot_access_token)
    ],
) -> list[JobCommand]:
    """Register a heartbeat from the job.

    This endpoint is used by the JobAgent to send heartbeats to the WMS and to
    receive job commands from the WMS. It also results in stalled jobs being
    restored to the RUNNING status.

    The `data` parameter and return value are mappings keyed by job ID.
    """
    # Endpoint only for DiracX pilots (with a pilot token)
    heartbeat_job_ids = list(data)
    try:
        await verify_that_pilot_can_access_jobs(
            pilot_db=pilot_db,
            pilot_stamp=pilot_info.pilot_stamp,
            job_ids=heartbeat_job_ids,
        )
    except (PilotCantAccessJobError, JobNotFoundError) as access_error:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Pilot can't access this job.",
        ) from access_error

    # Record the heartbeats first, then collect the commands for the same jobs.
    await add_heartbeat_bl(
        data, config, job_db, job_logging_db, task_queue_db, job_parameters_db
    )
    pending_commands = await get_job_commands_bl(data, job_db)
    return pending_commands
134+
135+
136+
# Pilot-facing counterpart of the job metadata update endpoint.
#
# Flow: check that the authenticated pilot may act on every job ID present in
# `updates`, then forward the updates to the business logic.
#   - 403 if the access check raises PilotCantAccessJobError / JobNotFoundError
#   - 400 if the business logic raises ValueError
#   - 204 (no content) on success
#
# Documentation is deliberately kept in comments rather than a docstring so
# the endpoint keeps having no docstring.
@router.patch("/metadata", status_code=HTTPStatus.NO_CONTENT)
async def pilot_patch_metadata(
    updates: dict[int, dict[str, Any]],
    job_db: JobDB,
    job_parameters_db: JobParametersDB,
    pilot_db: PilotAgentsDB,
    pilot_info: Annotated[
        AuthorizedPilotInfo, Depends(verify_dirac_pilot_access_token)
    ],
):
    # Endpoint only for DiracX pilots (with a pilot token)
    updated_job_ids = list(updates)
    try:
        await verify_that_pilot_can_access_jobs(
            pilot_db=pilot_db,
            pilot_stamp=pilot_info.pilot_stamp,
            job_ids=updated_job_ids,
        )
    except (PilotCantAccessJobError, JobNotFoundError) as access_error:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Pilot can't access this job.",
        ) from access_error

    try:
        await set_job_parameters_or_attributes_bl(updates, job_db, job_parameters_db)
    except ValueError as invalid_update:
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST,
            detail=str(invalid_update),
        ) from invalid_update

diracx-routers/src/diracx/routers/utils/pilots.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,12 @@
88
from joserfc.jwt import JWTClaimsRegistry
99
from pydantic import BaseModel
1010

11+
from diracx.core.exceptions import PilotCantAccessJobError
1112
from diracx.core.models import UUID, PilotInfo
1213
from diracx.logic.auth.utils import read_token
13-
from diracx.routers.dependencies import AuthSettings
14+
from diracx.logic.pilots.management import get_pilot_jobs_ids_by_stamp
15+
16+
from ..dependencies import AuthSettings, PilotAgentsDB
1417

1518

1619
class AuthInfo(BaseModel):
@@ -71,3 +74,23 @@ async def verify_dirac_pilot_access_token(
7174
status_code=status.HTTP_401_UNAUTHORIZED,
7275
detail="Invalid JWT",
7376
) from e
77+
78+
79+
async def verify_that_pilot_can_access_jobs(
    pilot_db: PilotAgentsDB, pilot_stamp: str, job_ids: list[int]
):
    """Check that every requested job belongs to the given pilot.

    The job IDs associated with `pilot_stamp` are fetched through
    `get_pilot_jobs_ids_by_stamp` and compared with `job_ids`. The function
    returns normally (with `None`) when all requested jobs are among them.

    Args:
        pilot_db: Pilot database handle forwarded to the lookup.
        pilot_stamp: Stamp identifying the pilot making the request.
        job_ids: IDs of the jobs the pilot wants to act on.

    Raises:
        PilotCantAccessJobError: If at least one ID in `job_ids` is not among
            the pilot's jobs. The offending IDs are reported under the
            `forbidden_jobs_ids` key of the error data.
    """
    pilot_jobs = await get_pilot_jobs_ids_by_stamp(
        pilot_db=pilot_db, pilot_stamp=pilot_stamp
    )

    # An empty difference means `job_ids` is a subset of the pilot's jobs, so a
    # single set operation covers both the check and the error payload.
    forbidden_jobs_ids = set(job_ids) - set(pilot_jobs)

    if forbidden_jobs_ids:
        # This must be raised, not returned: callers ignore the return value
        # and rely on the exception to reject the request with a 403.
        raise PilotCantAccessJobError(
            data={"forbidden_jobs_ids": str(forbidden_jobs_ids)}
        )

0 commit comments

Comments
 (0)