8
8
from sqlalchemy .sql import delete , insert , select , update
9
9
10
10
from diracx .core .exceptions import (
11
- BadPilotCredentialsError ,
12
- BadPilotVOError ,
13
11
CredentialsAlreadyExistError ,
14
12
CredentialsNotFoundError ,
15
13
InvalidQueryError ,
16
14
PilotAlreadyAssociatedWithJobError ,
17
15
PilotJobsNotFoundError ,
18
16
PilotNotFoundError ,
19
17
SecretAlreadyExistsError ,
20
- SecretHasExpiredError ,
21
18
SecretNotFoundError ,
22
19
)
23
20
from diracx .core .models import PilotFieldsMapping , SearchSpec , SortSpec
@@ -96,105 +93,6 @@ async def increment_global_secret_use(
96
93
detail = "This should not happen. Pilot should have a secret, but is not found."
97
94
)
98
95
99
- async def verify_pilot_secret (
100
- self , pilot_stamp : str , pilot_hashed_secret : str
101
- ) -> None :
102
- """Verify that a pilot can login with the given credentials."""
103
- # 1. Get the pilot to secret association
104
- pilots_credentials = await self .get_pilot_credentials_by_stamp ([pilot_stamp ])
105
-
106
- # 2. Get the pilot secret itself
107
- secrets = await self .get_secrets_by_hashed_secrets_bulk ([pilot_hashed_secret ])
108
- secret = secrets [0 ] # Semantic, assured by fetch_records_bulk_or_raises
109
-
110
- matches = [
111
- pilot_credential
112
- for pilot_credential in pilots_credentials
113
- if secret ["SecretID" ] == pilot_credential ["PilotSecretID" ]
114
- ]
115
-
116
- # 3. Compare the secret_id
117
- if len (matches ) == 0 :
118
-
119
- raise BadPilotCredentialsError (
120
- data = {
121
- "pilot_stamp" : pilot_stamp ,
122
- "pilot_hashed_secret" : pilot_hashed_secret ,
123
- "real_hashed_secret" : secret ["HashedSecret" ],
124
- "pilot_secret_id[]" : str (
125
- [
126
- pilot_credential ["PilotSecretID" ]
127
- for pilot_credential in pilots_credentials
128
- ]
129
- ),
130
- "secret_id" : secret ["SecretID" ],
131
- "test" : str (pilots_credentials ),
132
- }
133
- )
134
- elif len (matches ) > 1 :
135
-
136
- raise DBInBadStateError (
137
- detail = "This should not happen. Duplicates in the database."
138
- )
139
- pilot_credentials = matches [0 ] # Semantic
140
-
141
- # 4. Check if the secret is expired
142
- now = datetime .now (tz = timezone .utc )
143
- # Convert the timezone, TODO: Change with #454: https://github.com/DIRACGrid/diracx/pull/454
144
- expiration = secret ["SecretExpirationDate" ].replace (tzinfo = timezone .utc )
145
- if expiration < now :
146
-
147
- try :
148
- await self .delete_secrets_bulk ([secret ["SecretID" ]])
149
- except SecretNotFoundError as e :
150
- await self .conn .rollback ()
151
-
152
- raise DBInBadStateError (
153
- detail = "This should not happen. Pilot should have a secret, but not found."
154
- ) from e
155
-
156
- raise SecretHasExpiredError (
157
- data = {
158
- "pilot_hashed_secret" : pilot_hashed_secret ,
159
- "now" : str (now ),
160
- "expiration_date" : secret ["SecretExpirationDate" ],
161
- }
162
- )
163
-
164
- # 5. Now the pilot is authorized, increment the counters (globally and locally).
165
- try :
166
- # 5.1 Increment the local count
167
- await self .increment_pilot_local_secret_and_last_time_use (
168
- pilot_secret_id = pilot_credentials ["PilotSecretID" ],
169
- pilot_stamp = pilot_credentials ["PilotStamp" ],
170
- )
171
-
172
- # 5.2 Increment the global count
173
- await self .increment_global_secret_use (
174
- secret_id = pilot_credentials ["PilotSecretID" ]
175
- )
176
- except Exception as e : # Generic, to catch it.
177
- # Should NOT happen
178
- # Wrapped in a try/catch to still catch in case of an error in the counters
179
- # Caught and raised here to avoid raising a 4XX error
180
- await self .conn .rollback ()
181
-
182
- raise DBInBadStateError (
183
- detail = "This should not happen. Pilot has credentials, but has a corrupted secret."
184
- ) from e
185
-
186
- # 6. Delete all secrets if its count attained the secret_global_use_count_max
187
- if secret ["SecretGlobalUseCountMax" ]:
188
- if secret ["SecretGlobalUseCount" ] + 1 == secret ["SecretGlobalUseCountMax" ]:
189
- try :
190
- await self .delete_secrets_bulk ([secret ["SecretID" ]])
191
- except SecretNotFoundError as e :
192
- # Should NOT happen
193
- await self .conn .rollback ()
194
- raise DBInBadStateError (
195
- detail = "This should not happen. Pilot has credentials, but has corrupted secret."
196
- ) from e
197
-
198
96
async def add_pilots_bulk (
199
97
self ,
200
98
pilot_stamps : list [str ],
@@ -295,11 +193,6 @@ async def associate_pilots_with_secrets_bulk(
295
193
"""Bulk associate pilots with secrets. Raises an error in case of a Integrity violation."""
296
194
# Better to give as a parameter pilot to secret associations, rather than associating here.
297
195
298
- # First verify that pilots can access a certain secret
299
- await self .verify_that_pilot_can_access_secret_bulk (
300
- pilot_to_secret_id_mapping_values
301
- )
302
-
303
196
stmt = insert (PilotToSecretMapping ).values (pilot_to_secret_id_mapping_values )
304
197
305
198
try :
@@ -324,48 +217,28 @@ async def associate_pilots_with_secrets_bulk(
324
217
) from e
325
218
raise NotImplementedError (f"This error is not caught: { str (e .orig )} " ) from e
326
219
327
- async def associate_pilot_with_jobs (self , pilot_stamp : str , job_ids : list [int ]):
220
+ async def associate_pilot_with_jobs (self , job_to_pilot_mapping : list [dict ]):
328
221
"""Associate a pilot with jobs. Raises an error if the pilot does not exist and in case of a IntegrityError.
329
222
330
223
**Important note**: We don't verify if a job exists in the JobDB
331
224
"""
332
- pilot_ids = await self .get_pilot_ids_by_stamps ([pilot_stamp ])
333
- # Semantic assured by fetch_records_bulk_or_raises
334
- pilot_id = pilot_ids [0 ]
335
-
336
- now = datetime .now (tz = timezone .utc )
337
-
338
- # Prepare the list of dictionaries for bulk insertion
339
- values = [
340
- {"PilotID" : pilot_id , "JobID" : job_id , "StartTime" : now }
341
- for job_id in job_ids
342
- ]
343
-
344
225
# Insert multiple rows in a single execute call
345
- stmt = insert (JobToPilotMapping ).values (values )
226
+ stmt = insert (JobToPilotMapping ).values (job_to_pilot_mapping )
346
227
347
228
try :
348
229
res = await self .conn .execute (stmt )
349
230
except IntegrityError as e :
350
231
raise PilotAlreadyAssociatedWithJobError (
351
- data = {"pilot_stamp " : pilot_stamp , "job_ids" : str (job_ids )}
232
+ data = {"job_to_pilot_mapping " : str (job_to_pilot_mapping )}
352
233
) from e
353
234
354
- if res .rowcount != len (job_ids ):
235
+ if res .rowcount != len (job_to_pilot_mapping ):
355
236
# If doubles
356
237
await self .conn .rollback ()
357
238
raise PilotJobsNotFoundError (
358
- data = {"pilot_stamp " : pilot_stamp , "job_ids" : str (job_ids )}
239
+ data = {"job_to_pilot_mapping " : str (job_to_pilot_mapping )}
359
240
)
360
241
361
- async def get_pilot_jobs_ids_by_stamp (self , pilot_stamp : str ) -> list [int ]:
362
- """Fetch pilot jobs by stamp."""
363
- pilot_ids = await self .get_pilot_ids_by_stamps ([pilot_stamp ])
364
- # Semantic assured by fetch_records_bulk_or_raises
365
- pilot_id = pilot_ids [0 ]
366
-
367
- return await self .get_pilot_jobs_ids_by_pilot_id (pilot_id )
368
-
369
242
async def update_pilot_fields_bulk (
370
243
self , pilot_stamps_to_fields_mapping : list [PilotFieldsMapping ]
371
244
):
@@ -406,52 +279,6 @@ async def update_pilot_fields_bulk(
406
279
407
280
await self .conn .commit ()
408
281
409
- async def verify_that_pilot_can_access_secret_bulk (
410
- self , pilot_to_secret_id_mapping_values : list [dict [str , Any ]]
411
- ):
412
- # 1. Extract unique pilot_stamps and secret_ids
413
- pilot_stamps = [
414
- entry ["PilotStamp" ] for entry in pilot_to_secret_id_mapping_values
415
- ]
416
- secret_ids = [
417
- entry ["PilotSecretID" ] for entry in pilot_to_secret_id_mapping_values
418
- ]
419
-
420
- # 2. Bulk fetch pilot and secret info
421
- pilots = await self .get_pilots_by_stamp_bulk (pilot_stamps )
422
- secrets = await self .get_secrets_by_secret_ids_bulk (secret_ids )
423
-
424
- # 3. Build lookup maps
425
- pilot_vo_map = {pilot ["PilotStamp" ]: pilot ["VO" ] for pilot in pilots }
426
- secret_vo_map = {secret ["SecretID" ]: secret ["SecretVO" ] for secret in secrets }
427
-
428
- # 4. Validate access
429
- bad_mapping = []
430
-
431
- for mapping in pilot_to_secret_id_mapping_values :
432
- pilot_stamp = mapping ["PilotStamp" ]
433
- secret_id = mapping ["PilotSecretID" ]
434
-
435
- pilot_vo = pilot_vo_map [pilot_stamp ]
436
- secret_vo = secret_vo_map [secret_id ]
437
-
438
- # If secret_vo is set to NULL, everybody can access it
439
- if not secret_vo :
440
- continue
441
-
442
- # Access allowed only if VOs match or secret_vo is open (None)
443
- if secret_vo is not None and pilot_vo != secret_vo :
444
- bad_mapping .append (
445
- {
446
- "pilot_stamp" : pilot_stamp ,
447
- "given_vo" : pilot_vo ,
448
- "expected_vo" : secret_vo ,
449
- }
450
- )
451
-
452
- if bad_mapping :
453
- raise BadPilotVOError (data = {"bad_mapping" : str (bad_mapping )})
454
-
455
282
async def set_secret_expirations_bulk (
456
283
self , secret_ids : list [int ], pilot_secret_expiration_dates : list [DateTime ]
457
284
):
0 commit comments