Skip to content

Commit d15825e

Browse files
test: Adding tests to the WMS access policy, and some fixes
1 parent f8bc670 commit d15825e

File tree

3 files changed

+142
-7
lines changed

3 files changed

+142
-7
lines changed

diracx-routers/src/diracx/routers/jobs/access_policies.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,6 @@ async def policy(
5959
pilot_db
6060
), "pilot_db is a mandatory parameter when using a pilot action"
6161
assert job_ids, "job_ids has to be defined"
62-
assert (
63-
len(job_ids) == 1
64-
), "a pilot can have only one job_id associated, and it has to be given"
65-
6662
pilot_info = user_info # For semantic
6763

6864
# Syntax to avoid code duplication
@@ -83,11 +79,11 @@ async def policy(
8379
)
8480

8581
# Equivalent of issubset, but cleaner
86-
if set(job_ids) <= pilot_jobs:
82+
if set(job_ids) <= set(pilot_jobs):
8783
return
8884

8985
raise HTTPException(
90-
status.HTTP_403_FORBIDDEN, "this pilot can't modify this job"
86+
status.HTTP_403_FORBIDDEN, "this pilot can't access/modify this job"
9187
)
9288

9389
raise HTTPException(status.HTTP_403_FORBIDDEN, "you are not a pilot")

diracx-routers/tests/jobs/test_wms_access_policy.py

Lines changed: 117 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from fastapi import HTTPException, status
55
from uuid_utils import uuid7
66

7-
from diracx.core.properties import JOB_ADMINISTRATOR, NORMAL_USER
7+
from diracx.core.properties import GENERIC_PILOT, JOB_ADMINISTRATOR, NORMAL_USER
88
from diracx.routers.jobs.access_policies import (
99
ActionType,
1010
SandboxAccessPolicy,
@@ -26,6 +26,11 @@ class FakeJobDB:
2626
async def summary(self, *args): ...
2727

2828

29+
class FakePilotDB:
30+
async def get_pilot_by_reference(self, *args): ...
31+
async def get_pilot_job_ids(self, *args): ...
32+
33+
2934
class FakeSBMetadataDB:
3035
async def get_owner_id(self, *args): ...
3136
async def get_sandbox_owner_id(self, *args): ...
@@ -36,6 +41,11 @@ def job_db():
3641
yield FakeJobDB()
3742

3843

44+
@pytest.fixture
45+
def pilot_db():
46+
yield FakePilotDB()
47+
48+
3949
@pytest.fixture
4050
def sandbox_metadata_db():
4151
yield FakeSBMetadataDB()
@@ -68,6 +78,112 @@ async def test_wms_access_policy_weird_user(job_db):
6878
)
6979

7080

81+
async def test_wms_access_policy_pilot(job_db, pilot_db, monkeypatch):
82+
83+
normal_user = AuthorizedUserInfo(properties=[NORMAL_USER], **base_payload)
84+
pilot = AuthorizedUserInfo(properties=[GENERIC_PILOT], **base_payload)
85+
86+
# ------------------------- Simple User accessing a pilot action -------------------------
87+
# A user cannot create any resource
88+
with pytest.raises(HTTPException, match=f"{status.HTTP_403_FORBIDDEN}") as excinfo:
89+
await WMSAccessPolicy.policy(
90+
WMS_POLICY_NAME,
91+
normal_user,
92+
action=ActionType.PILOT,
93+
job_db=job_db,
94+
pilot_db=pilot_db,
95+
job_ids=[1, 2],
96+
)
97+
98+
# Split to distinguish the generated part ("403 ") from the message part ("you are not a pilot")
99+
assert str(excinfo.value) == "403: " + "you are not a pilot", excinfo
100+
101+
# ------------------------- Lost pilot -------------------------
102+
async def get_pilot_by_reference_patch(*args):
103+
return []
104+
105+
monkeypatch.setattr(
106+
pilot_db, "get_pilot_by_reference", get_pilot_by_reference_patch
107+
)
108+
109+
# A pilot that has expired (removed from db) should not be able to access jobs
110+
with pytest.raises(HTTPException, match=f"{status.HTTP_403_FORBIDDEN}") as excinfo:
111+
await WMSAccessPolicy.policy(
112+
WMS_POLICY_NAME,
113+
pilot,
114+
action=ActionType.PILOT,
115+
pilot_db=pilot_db,
116+
job_db=job_db,
117+
job_ids=[1, 2],
118+
)
119+
120+
assert str(excinfo.value) == "403: " + "this pilot is not registered", excinfo
121+
122+
# ------------------------- Pilot accessing wrong jobs -------------------------
123+
async def get_pilot_by_reference_patch(*args, **kwargs):
124+
return {"PilotID": 1}
125+
126+
async def get_pilot_job_ids_patch(*args, **kwargs):
127+
return []
128+
129+
monkeypatch.setattr(
130+
pilot_db, "get_pilot_by_reference", get_pilot_by_reference_patch
131+
)
132+
monkeypatch.setattr(pilot_db, "get_pilot_job_ids", get_pilot_job_ids_patch)
133+
134+
# A pilot that has is not associated with a job can't access a job
135+
with pytest.raises(HTTPException, match=f"{status.HTTP_403_FORBIDDEN}") as excinfo:
136+
await WMSAccessPolicy.policy(
137+
WMS_POLICY_NAME,
138+
pilot,
139+
action=ActionType.PILOT,
140+
pilot_db=pilot_db,
141+
job_db=job_db,
142+
job_ids=[1, 2],
143+
)
144+
145+
assert (
146+
str(excinfo.value) == "403: " + "this pilot can't access/modify this job"
147+
), excinfo
148+
149+
# ------------------------- Pilot accessing some of his jobs -------------------------
150+
async def get_pilot_job_ids_patch(*args, **kwargs):
151+
return [1, 2, 3, 4]
152+
153+
monkeypatch.setattr(pilot_db, "get_pilot_job_ids", get_pilot_job_ids_patch)
154+
155+
# A pilot that is associated with a job can access a job
156+
await WMSAccessPolicy.policy(
157+
WMS_POLICY_NAME,
158+
pilot,
159+
action=ActionType.PILOT,
160+
pilot_db=pilot_db,
161+
job_db=job_db,
162+
job_ids=[1, 2],
163+
)
164+
165+
# ------------------------- Pilot accessing some of his jobs plus some forbidden -------------------------
166+
async def get_pilot_job_ids_patch(*args, **kwargs):
167+
return [1, 2, 3, 4]
168+
169+
monkeypatch.setattr(pilot_db, "get_pilot_job_ids", get_pilot_job_ids_patch)
170+
171+
# A pilot that fetches few jobs, one where he does not have the rights, and few where he has the rights
172+
with pytest.raises(HTTPException, match=f"{status.HTTP_403_FORBIDDEN}") as excinfo:
173+
await WMSAccessPolicy.policy(
174+
WMS_POLICY_NAME,
175+
pilot,
176+
action=ActionType.PILOT,
177+
pilot_db=pilot_db,
178+
job_db=job_db,
179+
job_ids=[1, 2, 12],
180+
)
181+
182+
assert (
183+
str(excinfo.value) == "403: " + "this pilot can't access/modify this job"
184+
), excinfo
185+
186+
71187
async def test_wms_access_policy_create(job_db):
72188

73189
admin_user = AuthorizedUserInfo(properties=[JOB_ADMINISTRATOR], **base_payload)

diracx-testing/src/diracx/testing/utils.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,29 @@ def unauthenticated(self):
359359
with TestClient(self.app) as client:
360360
yield client
361361

362+
@contextlib.contextmanager
363+
def pilot(self):
364+
from diracx.core.properties import GENERIC_PILOT, LIMITED_DELEGATION
365+
from diracx.routers.auth.token import create_token
366+
367+
with self.unauthenticated() as client:
368+
payload = {
369+
"sub": "testingVO:yellow-sub",
370+
"exp": datetime.now(tz=timezone.utc)
371+
+ timedelta(self.test_auth_settings.access_token_expire_minutes),
372+
"iss": ISSUER,
373+
"dirac_properties": [GENERIC_PILOT, LIMITED_DELEGATION],
374+
"jti": str(uuid4()),
375+
"preferred_username": "preferred_username",
376+
"dirac_group": "test_group",
377+
"vo": "lhcb",
378+
}
379+
token = create_token(payload, self.test_auth_settings)
380+
381+
client.headers["Authorization"] = f"Bearer {token}"
382+
client.dirac_token_payload = payload
383+
yield client
384+
362385
@contextlib.contextmanager
363386
def normal_user(self):
364387
from diracx.core.properties import NORMAL_USER

0 commit comments

Comments
 (0)