4
4
from fastapi import HTTPException , status
5
5
from uuid_utils import uuid7
6
6
7
- from diracx .core .properties import JOB_ADMINISTRATOR , NORMAL_USER
7
+ from diracx .core .properties import GENERIC_PILOT , JOB_ADMINISTRATOR , NORMAL_USER
8
8
from diracx .routers .jobs .access_policies import (
9
9
ActionType ,
10
10
SandboxAccessPolicy ,
@@ -26,6 +26,11 @@ class FakeJobDB:
26
26
async def summary (self , * args ): ...
27
27
28
28
29
+ class FakePilotDB :
30
+ async def get_pilot_by_reference (self , * args ): ...
31
+ async def get_pilot_job_ids (self , * args ): ...
32
+
33
+
29
34
class FakeSBMetadataDB :
30
35
async def get_owner_id (self , * args ): ...
31
36
async def get_sandbox_owner_id (self , * args ): ...
@@ -36,6 +41,11 @@ def job_db():
36
41
yield FakeJobDB ()
37
42
38
43
44
+ @pytest .fixture
45
+ def pilot_db ():
46
+ yield FakePilotDB ()
47
+
48
+
39
49
@pytest .fixture
40
50
def sandbox_metadata_db ():
41
51
yield FakeSBMetadataDB ()
@@ -68,6 +78,112 @@ async def test_wms_access_policy_weird_user(job_db):
68
78
)
69
79
70
80
81
+ async def test_wms_access_policy_pilot (job_db , pilot_db , monkeypatch ):
82
+
83
+ normal_user = AuthorizedUserInfo (properties = [NORMAL_USER ], ** base_payload )
84
+ pilot = AuthorizedUserInfo (properties = [GENERIC_PILOT ], ** base_payload )
85
+
86
+ # ------------------------- Simple User accessing a pilot action -------------------------
87
+ # A user cannot create any resource
88
+ with pytest .raises (HTTPException , match = f"{ status .HTTP_403_FORBIDDEN } " ) as excinfo :
89
+ await WMSAccessPolicy .policy (
90
+ WMS_POLICY_NAME ,
91
+ normal_user ,
92
+ action = ActionType .PILOT ,
93
+ job_db = job_db ,
94
+ pilot_db = pilot_db ,
95
+ job_ids = [1 , 2 ],
96
+ )
97
+
98
+ # Split to distinguish the generated part ("403 ") from the message part ("you are not a pilot")
99
+ assert str (excinfo .value ) == "403: " + "you are not a pilot" , excinfo
100
+
101
+ # ------------------------- Lost pilot -------------------------
102
+ async def get_pilot_by_reference_patch (* args ):
103
+ return []
104
+
105
+ monkeypatch .setattr (
106
+ pilot_db , "get_pilot_by_reference" , get_pilot_by_reference_patch
107
+ )
108
+
109
+ # A pilot that has expired (removed from db) should not be able to access jobs
110
+ with pytest .raises (HTTPException , match = f"{ status .HTTP_403_FORBIDDEN } " ) as excinfo :
111
+ await WMSAccessPolicy .policy (
112
+ WMS_POLICY_NAME ,
113
+ pilot ,
114
+ action = ActionType .PILOT ,
115
+ pilot_db = pilot_db ,
116
+ job_db = job_db ,
117
+ job_ids = [1 , 2 ],
118
+ )
119
+
120
+ assert str (excinfo .value ) == "403: " + "this pilot is not registered" , excinfo
121
+
122
+ # ------------------------- Pilot accessing wrong jobs -------------------------
123
+ async def get_pilot_by_reference_patch (* args , ** kwargs ):
124
+ return {"PilotID" : 1 }
125
+
126
+ async def get_pilot_job_ids_patch (* args , ** kwargs ):
127
+ return []
128
+
129
+ monkeypatch .setattr (
130
+ pilot_db , "get_pilot_by_reference" , get_pilot_by_reference_patch
131
+ )
132
+ monkeypatch .setattr (pilot_db , "get_pilot_job_ids" , get_pilot_job_ids_patch )
133
+
134
+ # A pilot that has is not associated with a job can't access a job
135
+ with pytest .raises (HTTPException , match = f"{ status .HTTP_403_FORBIDDEN } " ) as excinfo :
136
+ await WMSAccessPolicy .policy (
137
+ WMS_POLICY_NAME ,
138
+ pilot ,
139
+ action = ActionType .PILOT ,
140
+ pilot_db = pilot_db ,
141
+ job_db = job_db ,
142
+ job_ids = [1 , 2 ],
143
+ )
144
+
145
+ assert (
146
+ str (excinfo .value ) == "403: " + "this pilot can't access/modify this job"
147
+ ), excinfo
148
+
149
+ # ------------------------- Pilot accessing some of his jobs -------------------------
150
+ async def get_pilot_job_ids_patch (* args , ** kwargs ):
151
+ return [1 , 2 , 3 , 4 ]
152
+
153
+ monkeypatch .setattr (pilot_db , "get_pilot_job_ids" , get_pilot_job_ids_patch )
154
+
155
+ # A pilot that is associated with a job can access a job
156
+ await WMSAccessPolicy .policy (
157
+ WMS_POLICY_NAME ,
158
+ pilot ,
159
+ action = ActionType .PILOT ,
160
+ pilot_db = pilot_db ,
161
+ job_db = job_db ,
162
+ job_ids = [1 , 2 ],
163
+ )
164
+
165
+ # ------------------------- Pilot accessing some of his jobs plus some forbidden -------------------------
166
+ async def get_pilot_job_ids_patch (* args , ** kwargs ):
167
+ return [1 , 2 , 3 , 4 ]
168
+
169
+ monkeypatch .setattr (pilot_db , "get_pilot_job_ids" , get_pilot_job_ids_patch )
170
+
171
+ # A pilot that fetches few jobs, one where he does not have the rights, and few where he has the rights
172
+ with pytest .raises (HTTPException , match = f"{ status .HTTP_403_FORBIDDEN } " ) as excinfo :
173
+ await WMSAccessPolicy .policy (
174
+ WMS_POLICY_NAME ,
175
+ pilot ,
176
+ action = ActionType .PILOT ,
177
+ pilot_db = pilot_db ,
178
+ job_db = job_db ,
179
+ job_ids = [1 , 2 , 12 ],
180
+ )
181
+
182
+ assert (
183
+ str (excinfo .value ) == "403: " + "this pilot can't access/modify this job"
184
+ ), excinfo
185
+
186
+
71
187
async def test_wms_access_policy_create (job_db ):
72
188
73
189
admin_user = AuthorizedUserInfo (properties = [JOB_ADMINISTRATOR ], ** base_payload )
0 commit comments