Skip to content

Commit 0256c34

Browse files
refactor: Splitting the job section to have only pilots or only users
1 parent d00e224 commit 0256c34

File tree

7 files changed

+181
-56
lines changed

7 files changed

+181
-56
lines changed

diracx-routers/pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,12 @@ auth = "diracx.routers.auth:router"
4646
config = "diracx.routers.configuration:router"
4747
health = "diracx.routers.health:router"
4848
jobs = "diracx.routers.jobs:router"
49+
pilot = "diracx.routers.pilots:router"
4950

5051
[project.entry-points."diracx.access_policies"]
5152
WMSAccessPolicy = "diracx.routers.jobs.access_policies:WMSAccessPolicy"
5253
SandboxAccessPolicy = "diracx.routers.jobs.access_policies:SandboxAccessPolicy"
54+
PilotWMSAccessPolicy = "diracx.routers.jobs.access_policies:PilotWMSAccessPolicy"
5355

5456
# Minimum version of the client supported
5557
[project.entry-points."diracx.min_client_version"]

diracx-routers/src/diracx/routers/jobs/access_policies.py

Lines changed: 0 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,7 @@
77
from fastapi import Depends, HTTPException, status
88

99
from diracx.core.properties import (
10-
GENERIC_PILOT,
1110
JOB_ADMINISTRATOR,
12-
LIMITED_DELEGATION,
1311
NORMAL_USER,
1412
)
1513
from diracx.db.sql import JobDB, PilotAgentsDB, SandboxMetadataDB
@@ -53,44 +51,6 @@ async def policy(
5351
assert action, "action is a mandatory parameter"
5452
assert job_db, "job_db is a mandatory parameter"
5553

56-
if action == ActionType.PILOT:
57-
58-
assert (
59-
pilot_db
60-
), "pilot_db is a mandatory parameter when using a pilot action"
61-
assert job_ids, "job_ids has to be defined"
62-
pilot_info = user_info # For semantic
63-
64-
# Syntax to avoid code duplication
65-
if {GENERIC_PILOT, LIMITED_DELEGATION} & set(pilot_info.properties):
66-
# Get his informations
67-
pilot_data = await pilot_db.get_pilot_by_reference(
68-
pilot_info.preferred_username
69-
)
70-
71-
if not pilot_data:
72-
raise HTTPException(
73-
status.HTTP_403_FORBIDDEN, "this pilot is not registered"
74-
)
75-
76-
# Get his jobs
77-
pilot_jobs = await pilot_db.get_pilot_job_ids(
78-
pilot_id=pilot_data["PilotID"]
79-
)
80-
81-
# Equivalent of issubset, but cleaner
82-
if set(job_ids) <= set(pilot_jobs):
83-
return
84-
85-
forbidden_jobs_ids = set(job_ids) - set(pilot_jobs)
86-
87-
raise HTTPException(
88-
status.HTTP_403_FORBIDDEN,
89-
f"this pilot can't access/modify some jobs: ids={forbidden_jobs_ids}",
90-
)
91-
92-
raise HTTPException(status.HTTP_403_FORBIDDEN, "you are not a pilot")
93-
9454
if action == ActionType.CREATE:
9555
if job_ids is not None:
9656
raise NotImplementedError(
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from __future__ import annotations
2+
3+
import logging
4+
5+
from ..fastapi_classes import DiracxRouter
6+
from .jobs import router as jobs_router
7+
8+
logger = logging.getLogger(__name__)
9+
10+
router = DiracxRouter()
11+
router.include_router(jobs_router)
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
from __future__ import annotations
2+
3+
from collections.abc import Callable
4+
from typing import Annotated
5+
6+
from fastapi import Depends, HTTPException, status
7+
8+
from diracx.core.properties import (
9+
GENERIC_PILOT,
10+
LIMITED_DELEGATION,
11+
)
12+
from diracx.db.sql import JobDB, PilotAgentsDB
13+
from diracx.routers.access_policies import BaseAccessPolicy
14+
from diracx.routers.utils.users import AuthorizedUserInfo
15+
16+
17+
class PilotWMSAccessPolicy(BaseAccessPolicy):
18+
"""Rules:
19+
* You need either NORMAL_USER or JOB_ADMINISTRATOR in your properties
20+
* An admin cannot create any resource but can read everything and modify everything
21+
* A NORMAL_USER can create
22+
* a NORMAL_USER can query and read only his own jobs.
23+
"""
24+
25+
@staticmethod
26+
async def policy(
27+
policy_name: str,
28+
user_info: AuthorizedUserInfo,
29+
/,
30+
*,
31+
pilot_db: PilotAgentsDB | None = None,
32+
job_db: JobDB | None = None,
33+
job_ids: list[int] | None = None,
34+
):
35+
assert job_db, "job_db is a mandatory parameter"
36+
assert pilot_db, "pilot_db is a mandatory parameter when using a pilot action"
37+
assert job_ids, "job_ids has to be defined"
38+
pilot_info = user_info # For semantic
39+
40+
# Syntax to avoid code duplication
41+
if {GENERIC_PILOT, LIMITED_DELEGATION} & set(pilot_info.properties):
42+
# Get his informations
43+
pilot_data = await pilot_db.get_pilot_by_reference(
44+
pilot_info.preferred_username
45+
)
46+
47+
if not pilot_data:
48+
raise HTTPException(
49+
status.HTTP_403_FORBIDDEN, "this pilot is not registered"
50+
)
51+
52+
# Get his jobs
53+
pilot_jobs = await pilot_db.get_pilot_job_ids(
54+
pilot_id=pilot_data["PilotID"]
55+
)
56+
57+
# Equivalent of issubset, but cleaner
58+
if set(job_ids) <= set(pilot_jobs):
59+
return
60+
61+
forbidden_jobs_ids = set(job_ids) - set(pilot_jobs)
62+
63+
raise HTTPException(
64+
status.HTTP_403_FORBIDDEN,
65+
f"this pilot can't access/modify some jobs: ids={forbidden_jobs_ids}",
66+
)
67+
68+
raise HTTPException(status.HTTP_403_FORBIDDEN, "you are not a pilot")
69+
70+
71+
CheckPilotWMSPolicyCallable = Annotated[Callable, Depends(PilotWMSAccessPolicy.check)]
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
from __future__ import annotations
2+
3+
from datetime import datetime
4+
from http import HTTPStatus
5+
from typing import Any
6+
7+
from fastapi import HTTPException
8+
9+
from diracx.core.models import (
10+
JobStatusUpdate,
11+
SetJobStatusReturn,
12+
)
13+
from diracx.logic.jobs.status import (
14+
set_job_parameters_or_attributes as set_job_parameters_or_attributes_bl,
15+
)
16+
from diracx.logic.jobs.status import set_job_statuses as set_job_statuses_bl
17+
18+
from ..dependencies import (
19+
Config,
20+
JobDB,
21+
JobLoggingDB,
22+
JobParametersDB,
23+
TaskQueueDB,
24+
)
25+
from ..fastapi_classes import DiracxRouter
26+
from .access_policies import CheckPilotWMSPolicyCallable
27+
28+
router = DiracxRouter()
29+
30+
31+
@router.patch("/status")
32+
async def set_job_statuses(
33+
job_update: dict[int, dict[datetime, JobStatusUpdate]],
34+
config: Config,
35+
job_db: JobDB,
36+
job_logging_db: JobLoggingDB,
37+
task_queue_db: TaskQueueDB,
38+
job_parameters_db: JobParametersDB,
39+
check_permissions: CheckPilotWMSPolicyCallable,
40+
force: bool = False,
41+
) -> SetJobStatusReturn:
42+
await check_permissions(job_db=job_db, job_ids=list(job_update))
43+
44+
try:
45+
result = await set_job_statuses_bl(
46+
status_changes=job_update,
47+
config=config,
48+
job_db=job_db,
49+
job_logging_db=job_logging_db,
50+
task_queue_db=task_queue_db,
51+
job_parameters_db=job_parameters_db,
52+
force=force,
53+
)
54+
except ValueError as e:
55+
raise HTTPException(
56+
status_code=HTTPStatus.BAD_REQUEST,
57+
detail=str(e),
58+
) from e
59+
60+
if not result.success:
61+
raise HTTPException(
62+
status_code=HTTPStatus.NOT_FOUND,
63+
detail=result.model_dump(),
64+
)
65+
66+
return result
67+
68+
69+
@router.patch("/metadata", status_code=HTTPStatus.NO_CONTENT)
70+
async def patch_metadata(
71+
updates: dict[int, dict[str, Any]],
72+
job_db: JobDB,
73+
job_parameters_db: JobParametersDB,
74+
check_permissions: CheckPilotWMSPolicyCallable,
75+
):
76+
await check_permissions(job_db=job_db, job_ids=updates)
77+
try:
78+
await set_job_parameters_or_attributes_bl(updates, job_db, job_parameters_db)
79+
except ValueError as e:
80+
raise HTTPException(
81+
status_code=HTTPStatus.BAD_REQUEST,
82+
detail=str(e),
83+
) from e

diracx-routers/tests/jobs/test_status.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ def test_set_job_status_cannot_make_impossible_transitions(
153153
]
154154
},
155155
)
156+
156157
assert r.status_code == 200, r.json()
157158
assert r.json()[0]["JobID"] == valid_job_id
158159
assert r.json()[0]["Status"] == JobStatus.RECEIVED.value

diracx-routers/tests/jobs/test_wms_access_policy.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
SandboxAccessPolicy,
1111
WMSAccessPolicy,
1212
)
13+
from diracx.routers.pilots.access_policies import PilotWMSAccessPolicy
1314
from diracx.routers.utils.users import AuthorizedUserInfo
1415

1516
base_payload = {
@@ -53,6 +54,7 @@ def sandbox_metadata_db():
5354

5455
WMS_POLICY_NAME = "WMSAccessPolicy_AlthoughItDoesNotMatter"
5556
SANDBOX_POLICY_NAME = "SandboxAccessPolicy_AlthoughItDoesNotMatter"
57+
PILOT_WMS_POLICY_NAME = "Pilot_WMSAccessPolicy_AlthoughItDoesNotMatter"
5658

5759

5860
async def test_wms_access_policy_weird_user(job_db):
@@ -78,18 +80,17 @@ async def test_wms_access_policy_weird_user(job_db):
7880
)
7981

8082

81-
async def test_wms_access_policy_pilot(job_db, pilot_db, monkeypatch):
83+
async def test_pilot_wms_access_policy_pilot(job_db, pilot_db, monkeypatch):
8284

8385
normal_user = AuthorizedUserInfo(properties=[NORMAL_USER], **base_payload)
8486
pilot = AuthorizedUserInfo(properties=[GENERIC_PILOT], **base_payload)
8587

8688
# ------------------------- Simple User accessing a pilot action -------------------------
8789
# A user cannot create any resource
8890
with pytest.raises(HTTPException, match=f"{status.HTTP_403_FORBIDDEN}") as excinfo:
89-
await WMSAccessPolicy.policy(
90-
WMS_POLICY_NAME,
91+
await PilotWMSAccessPolicy.policy(
92+
PILOT_WMS_POLICY_NAME,
9193
normal_user,
92-
action=ActionType.PILOT,
9394
job_db=job_db,
9495
pilot_db=pilot_db,
9596
job_ids=[1, 2],
@@ -108,10 +109,9 @@ async def get_pilot_by_reference_patch(*args):
108109

109110
# A pilot that has expired (removed from db) should not be able to access jobs
110111
with pytest.raises(HTTPException, match=f"{status.HTTP_403_FORBIDDEN}") as excinfo:
111-
await WMSAccessPolicy.policy(
112-
WMS_POLICY_NAME,
112+
await PilotWMSAccessPolicy.policy(
113+
PILOT_WMS_POLICY_NAME,
113114
pilot,
114-
action=ActionType.PILOT,
115115
pilot_db=pilot_db,
116116
job_db=job_db,
117117
job_ids=[1, 2],
@@ -133,10 +133,9 @@ async def get_pilot_job_ids_patch(*args, **kwargs):
133133

134134
# A pilot that has is not associated with a job can't access a job
135135
with pytest.raises(HTTPException, match=f"{status.HTTP_403_FORBIDDEN}") as excinfo:
136-
await WMSAccessPolicy.policy(
137-
WMS_POLICY_NAME,
136+
await PilotWMSAccessPolicy.policy(
137+
PILOT_WMS_POLICY_NAME,
138138
pilot,
139-
action=ActionType.PILOT,
140139
pilot_db=pilot_db,
141140
job_db=job_db,
142141
job_ids=[1, 2],
@@ -154,10 +153,9 @@ async def get_pilot_job_ids_patch(*args, **kwargs):
154153
monkeypatch.setattr(pilot_db, "get_pilot_job_ids", get_pilot_job_ids_patch)
155154

156155
# A pilot that is associated with a job can access a job
157-
await WMSAccessPolicy.policy(
158-
WMS_POLICY_NAME,
156+
await PilotWMSAccessPolicy.policy(
157+
PILOT_WMS_POLICY_NAME,
159158
pilot,
160-
action=ActionType.PILOT,
161159
pilot_db=pilot_db,
162160
job_db=job_db,
163161
job_ids=[1, 2],
@@ -171,10 +169,9 @@ async def get_pilot_job_ids_patch(*args, **kwargs):
171169

172170
# A pilot that fetches few jobs, one where he does not have the rights, and few where he has the rights
173171
with pytest.raises(HTTPException, match=f"{status.HTTP_403_FORBIDDEN}") as excinfo:
174-
await WMSAccessPolicy.policy(
175-
WMS_POLICY_NAME,
172+
await PilotWMSAccessPolicy.policy(
173+
PILOT_WMS_POLICY_NAME,
176174
pilot,
177-
action=ActionType.PILOT,
178175
pilot_db=pilot_db,
179176
job_db=job_db,
180177
job_ids=[1, 2, 12],

0 commit comments

Comments
 (0)