Skip to content

Commit a678d0f

Browse files
feat: Support diracx pilots for jobs
1 parent d5f4676 commit a678d0f

File tree

6 files changed

+207
-2
lines changed

6 files changed

+207
-2
lines changed

diracx-core/src/diracx/core/exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,10 @@ class PilotJobsNotFoundError(DiracFormattedError):
139139
pattern = "Pilots or Jobs %s not found"
140140

141141

142+
class PilotCantAccessJobError(DiracFormattedError):
143+
pattern = "Pilot %s can't access jobs"
144+
145+
142146
class PilotAlreadyAssociatedWithJobError(DiracFormattedError):
143147
pattern = "Pilot is already associated with a job %s "
144148

diracx-logic/src/diracx/logic/pilots/query.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,19 @@ async def get_pilot_jobs_ids_by_pilot_id(
112112
return [job["JobID"] for job in jobs]
113113

114114

115+
async def get_pilot_jobs_ids_by_stamp(
116+
pilot_db: PilotAgentsDB, pilot_stamp: str
117+
) -> list[int]:
118+
pilot_ids = await get_pilot_ids_by_stamps(
119+
pilot_db=pilot_db,
120+
pilot_stamps=[pilot_stamp],
121+
)
122+
123+
return await get_pilot_jobs_ids_by_pilot_id(
124+
pilot_db=pilot_db, pilot_id=pilot_ids[0]
125+
)
126+
127+
115128
async def get_secrets_by_hashed_secrets_bulk(
116129
pilot_db: PilotAgentsDB, hashed_secrets: list[bytes], parameters: list[str] = []
117130
) -> list[dict[Any, Any]]:

diracx-routers/src/diracx/routers/jobs/access_policies.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ async def policy(
4949

5050
if action == ActionType.PILOT:
5151
# TODO: For now we map this to MANAGE but it should be changed once
52-
# we have pilot credentials
52+
# Pilots will eventually be removed from this access policy
53+
# See #572
5354
action = ActionType.MANAGE
5455

5556
if action == ActionType.CREATE:

diracx-routers/src/diracx/routers/jobs/status.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ async def set_job_statuses(
7474
check_permissions: CheckWMSPolicyCallable,
7575
force: bool = False,
7676
) -> SetJobStatusReturn:
77+
# FIXME: Pilots will eventually be removed from this endpoint
78+
# See #572
7779
await check_permissions(
7880
action=ActionType.MANAGE, job_db=job_db, job_ids=list(job_update)
7981
)
@@ -121,6 +123,8 @@ async def add_heartbeat(
121123
122124
The `data` parameter and return value are mappings keyed by job ID.
123125
"""
126+
# FIXME: Pilots will eventually be removed from this endpoint
127+
# See #572
124128
await check_permissions(action=ActionType.PILOT, job_db=job_db, job_ids=list(data))
125129

126130
await add_heartbeat_bl(
@@ -171,6 +175,8 @@ async def patch_metadata(
171175
job_parameters_db: JobParametersDB,
172176
check_permissions: CheckWMSPolicyCallable,
173177
):
178+
# FIXME: Pilots will eventually be removed from this endpoint
179+
# See #572
174180
await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=updates)
175181
try:
176182
await set_job_parameters_or_attributes_bl(updates, job_db, job_parameters_db)
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
from __future__ import annotations
2+
3+
from datetime import datetime
4+
from http import HTTPStatus
5+
from typing import Annotated, Any
6+
7+
from fastapi import Depends, HTTPException, status
8+
9+
from diracx.core.exceptions import JobNotFoundError, PilotCantAccessJobError
10+
from diracx.core.models import (
11+
HeartbeatData,
12+
JobCommand,
13+
JobStatusUpdate,
14+
SetJobStatusReturn,
15+
)
16+
from diracx.logic.jobs.status import add_heartbeat as add_heartbeat_bl
17+
from diracx.logic.jobs.status import get_job_commands as get_job_commands_bl
18+
from diracx.logic.jobs.status import (
19+
set_job_parameters_or_attributes as set_job_parameters_or_attributes_bl,
20+
)
21+
from diracx.logic.jobs.status import set_job_statuses as set_job_statuses_bl
22+
from diracx.routers.utils.pilots import (
23+
AuthorizedPilotInfo,
24+
verify_dirac_pilot_access_token,
25+
verify_that_pilot_can_access_jobs,
26+
)
27+
28+
from ..dependencies import (
29+
Config,
30+
JobDB,
31+
JobLoggingDB,
32+
JobParametersDB,
33+
PilotAgentsDB,
34+
TaskQueueDB,
35+
)
36+
from ..fastapi_classes import DiracxRouter
37+
38+
router = DiracxRouter()
39+
40+
41+
@router.patch("/status")
42+
async def pilot_set_job_statuses(
43+
job_update: dict[int, dict[datetime, JobStatusUpdate]],
44+
config: Config,
45+
job_db: JobDB,
46+
job_logging_db: JobLoggingDB,
47+
task_queue_db: TaskQueueDB,
48+
job_parameters_db: JobParametersDB,
49+
pilot_agents_db: PilotAgentsDB,
50+
pilot_info: Annotated[
51+
AuthorizedPilotInfo, Depends(verify_dirac_pilot_access_token)
52+
],
53+
force: bool = False,
54+
) -> SetJobStatusReturn:
55+
# Endpoint only for DiracX pilots (with a pilot token)
56+
try:
57+
await verify_that_pilot_can_access_jobs(
58+
pilot_db=pilot_agents_db,
59+
pilot_stamp=pilot_info.pilot_stamp,
60+
job_ids=list(job_update),
61+
)
62+
except (PilotCantAccessJobError, JobNotFoundError) as e:
63+
raise HTTPException(
64+
status_code=status.HTTP_403_FORBIDDEN, detail="Pilot can't access this job."
65+
) from e
66+
67+
try:
68+
result = await set_job_statuses_bl(
69+
status_changes=job_update,
70+
config=config,
71+
job_db=job_db,
72+
job_logging_db=job_logging_db,
73+
task_queue_db=task_queue_db,
74+
job_parameters_db=job_parameters_db,
75+
force=force,
76+
)
77+
except ValueError as e:
78+
raise HTTPException(
79+
status_code=HTTPStatus.BAD_REQUEST,
80+
detail=str(e),
81+
) from e
82+
83+
if not result.success:
84+
raise HTTPException(
85+
status_code=HTTPStatus.NOT_FOUND,
86+
detail=result.model_dump(),
87+
)
88+
89+
return result
90+
91+
92+
@router.patch("/heartbeat")
93+
async def pilot_add_heartbeat(
94+
data: dict[int, HeartbeatData],
95+
config: Config,
96+
job_db: JobDB,
97+
job_logging_db: JobLoggingDB,
98+
task_queue_db: TaskQueueDB,
99+
job_parameters_db: JobParametersDB,
100+
pilot_agents_db: PilotAgentsDB,
101+
pilot_info: Annotated[
102+
AuthorizedPilotInfo, Depends(verify_dirac_pilot_access_token)
103+
],
104+
) -> list[JobCommand]:
105+
"""Register a heartbeat from the job.
106+
107+
This endpoint is used by the JobAgent to send heartbeats to the WMS and to
108+
receive job commands from the WMS. It also results in stalled jobs being
109+
restored to the RUNNING status.
110+
111+
The `data` parameter and return value are mappings keyed by job ID.
112+
"""
113+
# Endpoint only for DiracX pilots (with a pilot token)
114+
try:
115+
await verify_that_pilot_can_access_jobs(
116+
pilot_db=pilot_agents_db,
117+
pilot_stamp=pilot_info.pilot_stamp,
118+
job_ids=list(data),
119+
)
120+
except (PilotCantAccessJobError, JobNotFoundError) as e:
121+
raise HTTPException(
122+
status_code=status.HTTP_403_FORBIDDEN, detail="Pilot can't access this job."
123+
) from e
124+
125+
await add_heartbeat_bl(
126+
data, config, job_db, job_logging_db, task_queue_db, job_parameters_db
127+
)
128+
return await get_job_commands_bl(data, job_db)
129+
130+
131+
@router.patch("/metadata", status_code=HTTPStatus.NO_CONTENT)
132+
async def pilot_patch_metadata(
133+
updates: dict[int, dict[str, Any]],
134+
job_db: JobDB,
135+
job_parameters_db: JobParametersDB,
136+
pilot_agents_db: PilotAgentsDB,
137+
pilot_info: Annotated[
138+
AuthorizedPilotInfo, Depends(verify_dirac_pilot_access_token)
139+
],
140+
):
141+
# Endpoint only for DiracX pilots (with a pilot token)
142+
try:
143+
await verify_that_pilot_can_access_jobs(
144+
pilot_db=pilot_agents_db,
145+
pilot_stamp=pilot_info.pilot_stamp,
146+
job_ids=list(updates),
147+
)
148+
except (PilotCantAccessJobError, JobNotFoundError) as e:
149+
raise HTTPException(
150+
status_code=status.HTTP_403_FORBIDDEN, detail="Pilot can't access this job."
151+
) from e
152+
153+
try:
154+
await set_job_parameters_or_attributes_bl(updates, job_db, job_parameters_db)
155+
except ValueError as e:
156+
raise HTTPException(
157+
status_code=HTTPStatus.BAD_REQUEST,
158+
detail=str(e),
159+
) from e

diracx-routers/src/diracx/routers/utils/pilots.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@
88
from joserfc.jwt import JWTClaimsRegistry
99
from pydantic import BaseModel
1010

11+
from diracx.core.exceptions import PilotCantAccessJobError
1112
from diracx.core.models import UUID, PilotInfo
1213
from diracx.logic.auth.utils import read_token
13-
from diracx.routers.dependencies import AuthSettings
14+
from diracx.logic.pilots.query import get_pilot_jobs_ids_by_stamp
15+
from diracx.routers.dependencies import AuthSettings, PilotAgentsDB
1416

1517

1618
class AuthInfo(BaseModel):
@@ -71,3 +73,23 @@ async def verify_dirac_pilot_access_token(
7173
status_code=status.HTTP_401_UNAUTHORIZED,
7274
detail="Invalid JWT",
7375
) from e
76+
77+
78+
async def verify_that_pilot_can_access_jobs(
79+
pilot_db: PilotAgentsDB, pilot_stamp: str, job_ids: list[int]
80+
):
81+
# Get its jobs
82+
pilot_jobs = await get_pilot_jobs_ids_by_stamp(
83+
pilot_db=pilot_db, pilot_stamp=pilot_stamp
84+
)
85+
86+
# Equivalent of issubset, but cleaner
87+
if set(job_ids) <= set(pilot_jobs):
88+
return
89+
90+
forbidden_jobs_ids = set(job_ids) - set(pilot_jobs)
91+
92+
if forbidden_jobs_ids:
93+
return PilotCantAccessJobError(
94+
data={"forbidden_jobs_ids": str(forbidden_jobs_ids)}
95+
)

0 commit comments

Comments
 (0)