Skip to content

Commit cb10ac1

Browse files
feat: Pilots and Users are split for jobs
1 parent f17e4db commit cb10ac1

File tree

6 files changed

+179
-48
lines changed

6 files changed

+179
-48
lines changed

diracx-core/src/diracx/core/exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,3 +165,7 @@ class PilotJobsNotFoundError(DiracFormattedError):
165165

166166
class PilotAlreadyAssociatedWithJobError(DiracFormattedError):
    """Error reporting that a pilot is already associated with a job."""

    # Message template; the single ``%s`` placeholder is kept exactly as is
    # (including the trailing space).
    pattern = "Pilot is already associated with a job %s "
168+
169+
170+
class PilotCantAccessJobError(DiracFormattedError):
    """Error reporting that a pilot tried to act on jobs it cannot access."""

    # Message template; the single ``%s`` placeholder is kept exactly as is
    # (including the trailing space).
    pattern = "Pilot can't access some jobs %s "
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from __future__ import annotations
2+
3+
from diracx.core.exceptions import PilotCantAccessJobError
4+
from diracx.db.sql import PilotAgentsDB
5+
6+
7+
async def get_pilot_jobs_ids_by_stamp(
    pilot_db: PilotAgentsDB, pilot_stamp: str
) -> list[int]:
    """Return the job IDs recorded for the pilot identified by ``pilot_stamp``.

    The stamp is first resolved to a pilot ID through ``pilot_db``; that ID is
    then used to look up the pilot's job IDs.
    """
    matching_pilot_ids = await pilot_db.get_pilot_ids_by_stamps([pilot_stamp])

    # A single stamp was requested, so the first entry is the pilot we want.
    # The original author notes that fetch_records_bulk_or_raises guarantees
    # a record here; that helper is not visible from this module -- TODO
    # confirm.
    return await pilot_db.get_pilot_jobs_ids_by_pilot_id(matching_pilot_ids[0])
16+
17+
18+
async def verify_that_pilot_can_access_jobs(
    pilot_db: PilotAgentsDB, pilot_stamp: str, job_ids: list[int]
) -> None:
    """Check that the pilot identified by ``pilot_stamp`` may act on ``job_ids``.

    The pilot's own job IDs are fetched with ``get_pilot_jobs_ids_by_stamp``
    and compared with the requested ones. The function returns normally only
    when every requested job ID is one of the pilot's jobs.

    Args:
        pilot_db: Database handle used to resolve the pilot and its jobs.
        pilot_stamp: Stamp identifying the calling pilot.
        job_ids: Job IDs the pilot wants to act on.

    Raises:
        PilotCantAccessJobError: If at least one requested job ID is not among
            the pilot's jobs. The offending IDs are passed in ``data`` under
            the ``forbidden_jobs_ids`` key.
    """
    pilot_jobs = await get_pilot_jobs_ids_by_stamp(
        pilot_db=pilot_db, pilot_stamp=pilot_stamp
    )

    # An empty difference means ``job_ids`` is a subset of the pilot's jobs,
    # so one set operation covers both the "allowed" and "forbidden" cases.
    forbidden_jobs_ids = set(job_ids) - set(pilot_jobs)

    if forbidden_jobs_ids:
        # The error must be *raised*: callers await this function and ignore
        # its return value, so returning the exception object would let the
        # request through without any check.
        raise PilotCantAccessJobError(
            data={"forbidden_jobs_ids": str(forbidden_jobs_ids)}
        )

diracx-logic/src/diracx/logic/pilots/management.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,17 +41,6 @@ async def associate_pilot_with_jobs(
4141
)
4242

4343

44-
async def get_pilot_jobs_ids_by_stamp(
45-
pilot_db: PilotAgentsDB, pilot_stamp: str
46-
) -> list[int]:
47-
"""Fetch pilot jobs by stamp."""
48-
pilot_ids = await pilot_db.get_pilot_ids_by_stamps([pilot_stamp])
49-
# Semantic assured by fetch_records_bulk_or_raises
50-
pilot_id = pilot_ids[0]
51-
52-
return await pilot_db.get_pilot_jobs_ids_by_pilot_id(pilot_id)
53-
54-
5544
async def search(
5645
pilot_db: PilotAgentsDB,
5746
user_vo: str,

diracx-routers/src/diracx/routers/jobs/access_policies.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@ class ActionType(StrEnum):
2222
MANAGE = auto()
2323
# Search
2424
QUERY = auto()
25-
# Actions from a pilot (e.g. heartbeat)
26-
PILOT = auto()
2725

2826

2927
class WMSAccessPolicy(BaseAccessPolicy):
@@ -47,11 +45,6 @@ async def policy(
4745
assert action, "action is a mandatory parameter"
4846
assert job_db, "job_db is a mandatory parameter"
4947

50-
if action == ActionType.PILOT:
51-
# TODO: For now we map this to MANAGE but it should be changed once
52-
# we have pilot credentials
53-
action = ActionType.MANAGE
54-
5548
if action == ActionType.CREATE:
5649
if job_ids is not None:
5750
raise NotImplementedError(

diracx-routers/src/diracx/routers/jobs/status.py

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,9 @@
77
from fastapi import HTTPException, Query
88

99
from diracx.core.models import (
10-
HeartbeatData,
11-
JobCommand,
1210
JobStatusUpdate,
1311
SetJobStatusReturn,
1412
)
15-
from diracx.logic.jobs.status import add_heartbeat as add_heartbeat_bl
16-
from diracx.logic.jobs.status import get_job_commands as get_job_commands_bl
1713
from diracx.logic.jobs.status import remove_jobs as remove_jobs_bl
1814
from diracx.logic.jobs.status import reschedule_jobs as reschedule_jobs_bl
1915
from diracx.logic.jobs.status import (
@@ -103,32 +99,6 @@ async def set_job_statuses(
10399
return result
104100

105101

106-
@router.patch("/heartbeat")
107-
async def add_heartbeat(
108-
data: dict[int, HeartbeatData],
109-
config: Config,
110-
job_db: JobDB,
111-
job_logging_db: JobLoggingDB,
112-
task_queue_db: TaskQueueDB,
113-
job_parameters_db: JobParametersDB,
114-
check_permissions: CheckWMSPolicyCallable,
115-
) -> list[JobCommand]:
116-
"""Register a heartbeat from the job.
117-
118-
This endpoint is used by the JobAgent to send heartbeats to the WMS and to
119-
receive job commands from the WMS. It also results in stalled jobs being
120-
restored to the RUNNING status.
121-
122-
The `data` parameter and return value are mappings keyed by job ID.
123-
"""
124-
await check_permissions(action=ActionType.PILOT, job_db=job_db, job_ids=list(data))
125-
126-
await add_heartbeat_bl(
127-
data, config, job_db, job_logging_db, task_queue_db, job_parameters_db
128-
)
129-
return await get_job_commands_bl(data, job_db)
130-
131-
132102
@router.post("/reschedule")
133103
async def reschedule_jobs(
134104
job_ids: Annotated[list[int], Query()],
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
from __future__ import annotations
2+
3+
from datetime import datetime
4+
from http import HTTPStatus
5+
from typing import Annotated, Any
6+
7+
from fastapi import Depends, HTTPException
8+
9+
from diracx.core.models import (
10+
HeartbeatData,
11+
JobCommand,
12+
JobStatusUpdate,
13+
SetJobStatusReturn,
14+
)
15+
from diracx.db.sql import PilotAgentsDB
16+
from diracx.logic.jobs.status import add_heartbeat as add_heartbeat_bl
17+
from diracx.logic.jobs.status import get_job_commands as get_job_commands_bl
18+
from diracx.logic.jobs.status import (
19+
set_job_parameters_or_attributes as set_job_parameters_or_attributes_bl,
20+
)
21+
from diracx.logic.jobs.status import set_job_statuses as set_job_statuses_bl
22+
from diracx.logic.pilots.jobs import verify_that_pilot_can_access_jobs
23+
from diracx.routers.utils.pilots import (
24+
AuthorizedPilotInfo,
25+
verify_dirac_pilot_access_token,
26+
)
27+
28+
from ..dependencies import (
29+
Config,
30+
JobDB,
31+
JobLoggingDB,
32+
JobParametersDB,
33+
TaskQueueDB,
34+
)
35+
from ..fastapi_classes import DiracxRouter
36+
37+
router = DiracxRouter()
38+
39+
40+
@router.patch("/status")
async def pilot_set_job_statuses(
    job_update: dict[int, dict[datetime, JobStatusUpdate]],
    config: Config,
    job_db: JobDB,
    job_logging_db: JobLoggingDB,
    task_queue_db: TaskQueueDB,
    job_parameters_db: JobParametersDB,
    pilot_agents_db: PilotAgentsDB,
    pilot_info: Annotated[
        AuthorizedPilotInfo, Depends(verify_dirac_pilot_access_token)
    ],
    force: bool = False,
) -> SetJobStatusReturn:
    """Apply job status updates on behalf of an authenticated pilot.

    ``job_update`` is keyed by job ID. The pilot's access to those jobs is
    verified first; the updates are then handed to the job status logic.

    Raises:
        HTTPException: ``400 Bad Request`` when the status logic raises
            ``ValueError``; ``404 Not Found`` (with the dumped result as
            detail) when the returned result is not flagged as successful.
    """
    # NOTE(review): ``PilotAgentsDB`` is imported from ``diracx.db.sql`` while
    # the other databases come from ``..dependencies`` -- confirm FastAPI can
    # inject it in this form.
    requested_job_ids = list(job_update)
    await verify_that_pilot_can_access_jobs(
        pilot_db=pilot_agents_db,
        pilot_stamp=pilot_info.pilot_stamp,
        job_ids=requested_job_ids,
    )

    try:
        outcome = await set_job_statuses_bl(
            status_changes=job_update,
            config=config,
            job_db=job_db,
            job_logging_db=job_logging_db,
            task_queue_db=task_queue_db,
            job_parameters_db=job_parameters_db,
            force=force,
        )
    except ValueError as exc:
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST,
            detail=str(exc),
        ) from exc

    if outcome.success:
        return outcome

    raise HTTPException(
        status_code=HTTPStatus.NOT_FOUND,
        detail=outcome.model_dump(),
    )
83+
84+
85+
@router.patch("/heartbeat")
async def pilot_add_heartbeat(
    data: dict[int, HeartbeatData],
    config: Config,
    job_db: JobDB,
    job_logging_db: JobLoggingDB,
    task_queue_db: TaskQueueDB,
    job_parameters_db: JobParametersDB,
    pilot_agents_db: PilotAgentsDB,
    pilot_info: Annotated[
        AuthorizedPilotInfo, Depends(verify_dirac_pilot_access_token)
    ],
) -> list[JobCommand]:
    """Record job heartbeats sent by an authenticated pilot.

    The JobAgent uses this endpoint both to send heartbeats to the WMS and to
    collect the job commands the WMS has for it. Registering a heartbeat also
    results in stalled jobs being restored to the RUNNING status.

    Both the `data` parameter and the return value are mappings keyed by
    job ID.
    """
    # NOTE(review): ``PilotAgentsDB`` is imported from ``diracx.db.sql`` while
    # the other databases come from ``..dependencies`` -- confirm FastAPI can
    # inject it in this form.
    heartbeat_job_ids = list(data)
    await verify_that_pilot_can_access_jobs(
        pilot_db=pilot_agents_db,
        pilot_stamp=pilot_info.pilot_stamp,
        job_ids=heartbeat_job_ids,
    )

    # Store the heartbeats first, then fetch the commands for the same jobs.
    await add_heartbeat_bl(
        data,
        config,
        job_db,
        job_logging_db,
        task_queue_db,
        job_parameters_db,
    )
    pending_commands = await get_job_commands_bl(data, job_db)
    return pending_commands
116+
117+
118+
@router.patch("/metadata", status_code=HTTPStatus.NO_CONTENT)
async def pilot_patch_metadata(
    updates: dict[int, dict[str, Any]],
    job_db: JobDB,
    job_parameters_db: JobParametersDB,
    pilot_agents_db: PilotAgentsDB,
    pilot_info: Annotated[
        AuthorizedPilotInfo, Depends(verify_dirac_pilot_access_token)
    ],
):
    """Update job parameters or attributes on behalf of an authenticated pilot.

    ``updates`` is keyed by job ID. The pilot's access to those jobs is
    verified first; the updates are then handed to the job status logic.
    Nothing is returned (the route is declared with ``204 No Content``).

    Raises:
        HTTPException: ``400 Bad Request`` when the underlying logic raises
            ``ValueError``.
    """
    # NOTE(review): ``PilotAgentsDB`` is imported from ``diracx.db.sql`` while
    # the other databases come from ``..dependencies`` -- confirm FastAPI can
    # inject it in this form.
    target_job_ids = list(updates)
    await verify_that_pilot_can_access_jobs(
        pilot_db=pilot_agents_db,
        pilot_stamp=pilot_info.pilot_stamp,
        job_ids=target_job_ids,
    )

    try:
        await set_job_parameters_or_attributes_bl(
            updates, job_db, job_parameters_db
        )
    except ValueError as exc:
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST,
            detail=str(exc),
        ) from exc

0 commit comments

Comments
 (0)