1
1
from __future__ import annotations
2
2
3
3
from datetime import datetime , timezone
4
- from typing import Any , Sequence
4
+ from typing import Any
5
5
6
- from sqlalchemy import RowMapping , bindparam , func
6
+ from sqlalchemy import bindparam
7
7
from sqlalchemy .exc import IntegrityError
8
- from sqlalchemy .sql import delete , insert , select , update
8
+ from sqlalchemy .sql import delete , insert , update
9
9
10
10
from diracx .core .exceptions import (
11
- InvalidQueryError ,
12
11
PilotAlreadyAssociatedWithJobError ,
13
- PilotJobsNotFoundError ,
14
12
PilotNotFoundError ,
15
13
)
16
14
from diracx .core .models import (
21
19
22
20
from ..utils import (
23
21
BaseSQLDB ,
24
- _get_columns ,
25
- apply_search_filters ,
26
- apply_sort_constraints ,
27
- fetch_records_bulk_or_raises ,
28
22
)
29
23
from .schema import (
30
24
JobToPilotMapping ,
@@ -43,7 +37,7 @@ async def add_pilots_bulk(
43
37
pilot_stamps : list [str ],
44
38
vo : str ,
45
39
grid_type : str = "DIRAC" ,
46
- pilot_references : dict | None = None ,
40
+ pilot_references : dict [ str , str ] | None = None ,
47
41
):
48
42
"""Bulk add pilots in the DB.
49
43
@@ -85,7 +79,9 @@ async def delete_pilots_by_stamps_bulk(self, pilot_stamps: list[str]):
85
79
if res .rowcount != len (pilot_stamps ):
86
80
raise PilotNotFoundError (data = {"pilot_stamps" : str (pilot_stamps )})
87
81
88
- async def associate_pilot_with_jobs (self , job_to_pilot_mapping : list [dict ]):
82
+ async def associate_pilot_with_jobs (
83
+ self , job_to_pilot_mapping : list [dict [str , Any ]]
84
+ ):
89
85
"""Associate a pilot with jobs.
90
86
91
87
job_to_pilot_mapping format:
@@ -182,61 +178,28 @@ async def update_pilot_fields_bulk(
182
178
data = {"mapping" : str (pilot_stamps_to_fields_mapping )}
183
179
)
184
180
185
- async def get_pilots_by_stamp_bulk (
186
- self , pilot_stamps : list [str ]
187
- ) -> Sequence [RowMapping ]:
188
- """Bulk fetch pilots.
189
-
190
- Raises PilotNotFoundError if one of the stamp is not associated with a pilot.
191
-
192
- """
193
- results = await fetch_records_bulk_or_raises (
194
- self .conn ,
195
- PilotAgents ,
196
- PilotNotFoundError ,
197
- "pilot_stamp" ,
198
- "PilotStamp" ,
199
- pilot_stamps ,
200
- allow_no_result = True ,
201
- )
202
-
203
- # Custom handling, to see which pilot_stamp does not exist (if so, say which one)
204
- found_keys = {row ["PilotStamp" ] for row in results }
205
- missing = set (pilot_stamps ) - found_keys
206
-
207
- if missing :
208
- raise PilotNotFoundError (
209
- data = {"pilot_stamp" : str (missing )},
210
- detail = str (missing ),
211
- non_existing_pilots = missing ,
212
- )
213
-
214
- return results
215
-
216
- async def get_pilot_jobs_ids_by_pilot_id (self , pilot_id : int ) -> list [int ]:
217
- """Fetch pilot jobs."""
218
- job_to_pilot_mapping = await fetch_records_bulk_or_raises (
219
- self .conn ,
220
- JobToPilotMapping ,
221
- PilotJobsNotFoundError ,
222
- "pilot_id" ,
223
- "PilotID" ,
224
- [pilot_id ],
225
- allow_more_than_one_result_per_input = True ,
226
- allow_no_result = True ,
181
+ async def search_pilots (
182
+ self ,
183
+ parameters : list [str ] | None ,
184
+ search : list [SearchSpec ],
185
+ sorts : list [SortSpec ],
186
+ * ,
187
+ distinct : bool = False ,
188
+ per_page : int = 100 ,
189
+ page : int | None = None ,
190
+ ) -> tuple [int , list [dict [Any , Any ]]]:
191
+ """Search for pilots in the database."""
192
+ return await self .search (
193
+ model = PilotAgents ,
194
+ parameters = parameters ,
195
+ search = search ,
196
+ sorts = sorts ,
197
+ distinct = distinct ,
198
+ per_page = per_page ,
199
+ page = page ,
227
200
)
228
201
229
- return [mapping ["JobID" ] for mapping in job_to_pilot_mapping ]
230
-
231
- async def get_pilot_ids_by_stamps (self , pilot_stamps : list [str ]) -> list [int ]:
232
- """Get pilot ids."""
233
- # This function is currently needed while we are relying on pilot_ids instead of pilot_stamps
234
- # (Ex: JobToPilotMapping)
235
- pilots = await self .get_pilots_by_stamp_bulk (pilot_stamps )
236
-
237
- return [pilot ["PilotID" ] for pilot in pilots ]
238
-
239
- async def search (
202
+ async def search_pilot_to_job_mapping (
240
203
self ,
241
204
parameters : list [str ] | None ,
242
205
search : list [SearchSpec ],
@@ -247,39 +210,15 @@ async def search(
247
210
page : int | None = None ,
248
211
) -> tuple [int , list [dict [Any , Any ]]]:
249
212
"""Search for pilots in the database."""
250
- # TODO: Refactorize with the search function for jobs.
251
- # Find which columns to select
252
- columns = _get_columns ( PilotAgents . __table__ , parameters )
253
-
254
- stmt = select ( * columns )
255
-
256
- stmt = apply_search_filters (
257
- PilotAgents . __table__ . columns . __getitem__ , stmt , search
213
+ return await self . search (
214
+ model = JobToPilotMapping ,
215
+ parameters = parameters ,
216
+ search = search ,
217
+ sorts = sorts ,
218
+ distinct = distinct ,
219
+ per_page = per_page ,
220
+ page = page ,
258
221
)
259
- stmt = apply_sort_constraints (
260
- PilotAgents .__table__ .columns .__getitem__ , stmt , sorts
261
- )
262
-
263
- if distinct :
264
- stmt = stmt .distinct ()
265
-
266
- # Calculate total count before applying pagination
267
- total_count_subquery = stmt .alias ()
268
- total_count_stmt = select (func .count ()).select_from (total_count_subquery )
269
- total = (await self .conn .execute (total_count_stmt )).scalar_one ()
270
-
271
- # Apply pagination
272
- if page is not None :
273
- if page < 1 :
274
- raise InvalidQueryError ("Page must be a positive integer" )
275
- if per_page < 1 :
276
- raise InvalidQueryError ("Per page must be a positive integer" )
277
- stmt = stmt .offset ((page - 1 ) * per_page ).limit (per_page )
278
-
279
- # Execute the query
280
- return total , [
281
- dict (row ._mapping ) async for row in (await self .conn .stream (stmt ))
282
- ]
283
222
284
223
async def clear_pilots_bulk (
285
224
self , cutoff_date : datetime , delete_only_aborted : bool
0 commit comments