Commit 65abf0c

feat: Fixes and now use of search engine instead of DIY fetch records
1 parent: ce03dc2

15 files changed: +334 additions, -313 deletions
Binary file changed (52 KB); contents not shown.

diracx-db/src/diracx/db/sql/job/db.py

Lines changed: 11 additions & 36 deletions
@@ -13,13 +13,7 @@
 from diracx.core.exceptions import InvalidQueryError
 from diracx.core.models import JobCommand, SearchSpec, SortSpec
 
-from ..utils import (
-    BaseSQLDB,
-    _get_columns,
-    apply_search_filters,
-    apply_sort_constraints,
-    utcnow,
-)
+from ..utils import BaseSQLDB, _get_columns, apply_search_filters, utcnow
 from .schema import (
     HeartBeatLoggingInfo,
     InputData,
@@ -63,7 +57,7 @@ async def summary(self, group_by, search) -> list[dict[str, str | int]]:
             if row.count > 0  # type: ignore
         ]
 
-    async def search(
+    async def search_jobs(
         self,
         parameters: list[str] | None,
         search: list[SearchSpec],
@@ -74,34 +68,15 @@ async def search(
         page: int | None = None,
     ) -> tuple[int, list[dict[Any, Any]]]:
         """Search for jobs in the database."""
-        # Find which columns to select
-        columns = _get_columns(Jobs.__table__, parameters)
-
-        stmt = select(*columns)
-
-        stmt = apply_search_filters(Jobs.__table__.columns.__getitem__, stmt, search)
-        stmt = apply_sort_constraints(Jobs.__table__.columns.__getitem__, stmt, sorts)
-
-        if distinct:
-            stmt = stmt.distinct()
-
-        # Calculate total count before applying pagination
-        total_count_subquery = stmt.alias()
-        total_count_stmt = select(func.count()).select_from(total_count_subquery)
-        total = (await self.conn.execute(total_count_stmt)).scalar_one()
-
-        # Apply pagination
-        if page is not None:
-            if page < 1:
-                raise InvalidQueryError("Page must be a positive integer")
-            if per_page < 1:
-                raise InvalidQueryError("Per page must be a positive integer")
-            stmt = stmt.offset((page - 1) * per_page).limit(per_page)
-
-        # Execute the query
-        return total, [
-            dict(row._mapping) async for row in (await self.conn.stream(stmt))
-        ]
+        return await self.search(
+            model=Jobs,
+            parameters=parameters,
+            search=search,
+            sorts=sorts,
+            distinct=distinct,
+            per_page=per_page,
+            page=page,
+        )
 
     async def create_job(self, compressed_original_jdl: str):
         """Used to insert a new job with original JDL. Returns inserted job id."""

diracx-db/src/diracx/db/sql/pilots/db.py

Lines changed: 35 additions & 96 deletions
@@ -1,16 +1,14 @@
 from __future__ import annotations
 
 from datetime import datetime, timezone
-from typing import Any, Sequence
+from typing import Any
 
-from sqlalchemy import RowMapping, bindparam, func
+from sqlalchemy import bindparam
 from sqlalchemy.exc import IntegrityError
-from sqlalchemy.sql import delete, insert, select, update
+from sqlalchemy.sql import delete, insert, update
 
 from diracx.core.exceptions import (
-    InvalidQueryError,
     PilotAlreadyAssociatedWithJobError,
-    PilotJobsNotFoundError,
     PilotNotFoundError,
 )
 from diracx.core.models import (
@@ -21,10 +19,6 @@
 
 from ..utils import (
     BaseSQLDB,
-    _get_columns,
-    apply_search_filters,
-    apply_sort_constraints,
-    fetch_records_bulk_or_raises,
 )
 from .schema import (
     JobToPilotMapping,
@@ -43,7 +37,7 @@ async def add_pilots_bulk(
         pilot_stamps: list[str],
         vo: str,
         grid_type: str = "DIRAC",
-        pilot_references: dict | None = None,
+        pilot_references: dict[str, str] | None = None,
     ):
         """Bulk add pilots in the DB.
 
@@ -85,7 +79,9 @@ async def delete_pilots_by_stamps_bulk(self, pilot_stamps: list[str]):
         if res.rowcount != len(pilot_stamps):
             raise PilotNotFoundError(data={"pilot_stamps": str(pilot_stamps)})
 
-    async def associate_pilot_with_jobs(self, job_to_pilot_mapping: list[dict]):
+    async def associate_pilot_with_jobs(
+        self, job_to_pilot_mapping: list[dict[str, Any]]
+    ):
         """Associate a pilot with jobs.
 
         job_to_pilot_mapping format:
@@ -182,61 +178,28 @@ async def update_pilot_fields_bulk(
                 data={"mapping": str(pilot_stamps_to_fields_mapping)}
             )
 
-    async def get_pilots_by_stamp_bulk(
-        self, pilot_stamps: list[str]
-    ) -> Sequence[RowMapping]:
-        """Bulk fetch pilots.
-
-        Raises PilotNotFoundError if one of the stamp is not associated with a pilot.
-
-        """
-        results = await fetch_records_bulk_or_raises(
-            self.conn,
-            PilotAgents,
-            PilotNotFoundError,
-            "pilot_stamp",
-            "PilotStamp",
-            pilot_stamps,
-            allow_no_result=True,
-        )
-
-        # Custom handling, to see which pilot_stamp does not exist (if so, say which one)
-        found_keys = {row["PilotStamp"] for row in results}
-        missing = set(pilot_stamps) - found_keys
-
-        if missing:
-            raise PilotNotFoundError(
-                data={"pilot_stamp": str(missing)},
-                detail=str(missing),
-                non_existing_pilots=missing,
-            )
-
-        return results
-
-    async def get_pilot_jobs_ids_by_pilot_id(self, pilot_id: int) -> list[int]:
-        """Fetch pilot jobs."""
-        job_to_pilot_mapping = await fetch_records_bulk_or_raises(
-            self.conn,
-            JobToPilotMapping,
-            PilotJobsNotFoundError,
-            "pilot_id",
-            "PilotID",
-            [pilot_id],
-            allow_more_than_one_result_per_input=True,
-            allow_no_result=True,
+    async def search_pilots(
+        self,
+        parameters: list[str] | None,
+        search: list[SearchSpec],
+        sorts: list[SortSpec],
+        *,
+        distinct: bool = False,
+        per_page: int = 100,
+        page: int | None = None,
+    ) -> tuple[int, list[dict[Any, Any]]]:
+        """Search for pilots in the database."""
+        return await self.search(
+            model=PilotAgents,
+            parameters=parameters,
+            search=search,
+            sorts=sorts,
+            distinct=distinct,
+            per_page=per_page,
+            page=page,
         )
 
-        return [mapping["JobID"] for mapping in job_to_pilot_mapping]
-
-    async def get_pilot_ids_by_stamps(self, pilot_stamps: list[str]) -> list[int]:
-        """Get pilot ids."""
-        # This function is currently needed while we are relying on pilot_ids instead of pilot_stamps
-        # (Ex: JobToPilotMapping)
-        pilots = await self.get_pilots_by_stamp_bulk(pilot_stamps)
-
-        return [pilot["PilotID"] for pilot in pilots]
-
-    async def search(
+    async def search_pilot_to_job_mapping(
         self,
         parameters: list[str] | None,
         search: list[SearchSpec],
@@ -247,39 +210,15 @@ async def search(
         page: int | None = None,
     ) -> tuple[int, list[dict[Any, Any]]]:
         """Search for pilots in the database."""
-        # TODO: Refactorize with the search function for jobs.
-        # Find which columns to select
-        columns = _get_columns(PilotAgents.__table__, parameters)
-
-        stmt = select(*columns)
-
-        stmt = apply_search_filters(
-            PilotAgents.__table__.columns.__getitem__, stmt, search
+        return await self.search(
+            model=JobToPilotMapping,
+            parameters=parameters,
+            search=search,
+            sorts=sorts,
+            distinct=distinct,
+            per_page=per_page,
+            page=page,
         )
-        stmt = apply_sort_constraints(
-            PilotAgents.__table__.columns.__getitem__, stmt, sorts
-        )
-
-        if distinct:
-            stmt = stmt.distinct()
-
-        # Calculate total count before applying pagination
-        total_count_subquery = stmt.alias()
-        total_count_stmt = select(func.count()).select_from(total_count_subquery)
-        total = (await self.conn.execute(total_count_stmt)).scalar_one()
-
-        # Apply pagination
-        if page is not None:
-            if page < 1:
-                raise InvalidQueryError("Page must be a positive integer")
-            if per_page < 1:
-                raise InvalidQueryError("Per page must be a positive integer")
-            stmt = stmt.offset((page - 1) * per_page).limit(per_page)
-
-        # Execute the query
-        return total, [
-            dict(row._mapping) async for row in (await self.conn.stream(stmt))
-        ]
 
     async def clear_pilots_bulk(
         self, cutoff_date: datetime, delete_only_aborted: bool
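
The helpers removed here (get_pilots_by_stamp_bulk, get_pilot_jobs_ids_by_pilot_id, get_pilot_ids_by_stamps) were the main users of fetch_records_bulk_or_raises; their information is now reachable through the two new search methods instead. A hedged sketch of one possible replacement call, assuming the PilotAgentsDB class name from the file path and the same SearchSpec dictionary shape as in the jobs example above:

from diracx.db.sql.pilots.db import PilotAgentsDB  # class name assumed from the file path


async def pilot_job_ids(pilot_db: PilotAgentsDB, pilot_id: int) -> list[int]:
    # One way to recover what get_pilot_jobs_ids_by_pilot_id used to return:
    # query the JobToPilotMapping table through the generic search.
    _total, mappings = await pilot_db.search_pilot_to_job_mapping(
        parameters=["JobID"],
        search=[{"parameter": "PilotID", "operator": "eq", "value": pilot_id}],
        sorts=[],
    )
    return [m["JobID"] for m in mappings]

With the generic search an empty result is simply an empty page, so any not-found semantics now belong to the calling service layer rather than to the DB class.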

diracx-db/src/diracx/db/sql/utils/__init__.py

Lines changed: 1 addition & 3 deletions
@@ -3,12 +3,11 @@
 from .base import (
     BaseSQLDB,
     SQLDBUnavailableError,
+    _get_columns,
     apply_search_filters,
     apply_sort_constraints,
 )
 from .functions import (
-    _get_columns,
-    fetch_records_bulk_or_raises,
     hash,
     substract_date,
     utcnow,
@@ -24,7 +23,6 @@
     "DateNowColumn",
     "EnumBackedBool",
     "EnumColumn",
-    "fetch_records_bulk_or_raises",
     "hash",
     "NullColumn",
     "substract_date",
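
For downstream code the only visible change is the import surface: _get_columns is now defined in base (and still re-exported here), while fetch_records_bulk_or_raises disappears entirely. The job DB diff above, for example, keeps importing everything it needs from this package:

from diracx.db.sql.utils import BaseSQLDB, _get_columns, apply_search_filters, utcnow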

diracx-db/src/diracx/db/sql/utils/base.py

Lines changed: 55 additions & 3 deletions
@@ -8,16 +8,16 @@
 from collections.abc import AsyncIterator
 from contextvars import ContextVar
 from datetime import datetime
-from typing import Self, cast
+from typing import Any, Self, cast
 
 from pydantic import TypeAdapter
-from sqlalchemy import DateTime, MetaData, select
+from sqlalchemy import DateTime, MetaData, func, select
 from sqlalchemy.exc import OperationalError
 from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
 
 from diracx.core.exceptions import InvalidQueryError
 from diracx.core.extensions import select_from_extension
-from diracx.core.models import SortDirection
+from diracx.core.models import SearchSpec, SortDirection, SortSpec
 from diracx.core.settings import SqlalchemyDsn
 from diracx.db.exceptions import DBUnavailableError
 
@@ -227,6 +227,47 @@ async def ping(self):
         except OperationalError as e:
             raise SQLDBUnavailableError("Cannot ping the DB") from e
 
+    async def search(
+        self,
+        model: Any,
+        parameters: list[str] | None,
+        search: list[SearchSpec],
+        sorts: list[SortSpec],
+        *,
+        distinct: bool = False,
+        per_page: int = 100,
+        page: int | None = None,
+    ) -> tuple[int, list[dict[Any, Any]]]:
+        """Search for pilots in the database."""
+        # Find which columns to select
+        columns = _get_columns(model.__table__, parameters)
+
+        stmt = select(*columns)
+
+        stmt = apply_search_filters(model.__table__.columns.__getitem__, stmt, search)
+        stmt = apply_sort_constraints(model.__table__.columns.__getitem__, stmt, sorts)
+
+        if distinct:
+            stmt = stmt.distinct()
+
+        # Calculate total count before applying pagination
+        total_count_subquery = stmt.alias()
+        total_count_stmt = select(func.count()).select_from(total_count_subquery)
+        total = (await self.conn.execute(total_count_stmt)).scalar_one()
+
+        # Apply pagination
+        if page is not None:
+            if page < 1:
+                raise InvalidQueryError("Page must be a positive integer")
+            if per_page < 1:
+                raise InvalidQueryError("Per page must be a positive integer")
+            stmt = stmt.offset((page - 1) * per_page).limit(per_page)
+
+        # Execute the query
+        return total, [
+            dict(row._mapping) async for row in (await self.conn.stream(stmt))
+        ]
+
 
 def find_time_resolution(value):
     if isinstance(value, datetime):
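
The new BaseSQLDB.search above is the single implementation that search_jobs, search_pilots, and search_pilot_to_job_mapping now delegate to. A minimal sketch of how another DB class could reuse it; Widgets, WidgetDB, and their columns are hypothetical and only illustrate the pattern:

from sqlalchemy import Integer, String
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

from diracx.db.sql.utils import BaseSQLDB


class Base(DeclarativeBase):
    pass


class Widgets(Base):
    # Hypothetical table, present only for this sketch.
    __tablename__ = "Widgets"
    WidgetID: Mapped[int] = mapped_column(Integer, primary_key=True)
    Name: Mapped[str] = mapped_column(String(64))


class WidgetDB(BaseSQLDB):
    # BaseSQLDB subclasses in diracx-db expose their SQLAlchemy metadata like this.
    metadata = Base.metadata

    async def search_widgets(self, parameters, search, sorts, **kwargs):
        # Column selection, filtering, sorting, the pre-pagination count and the
        # paginated streaming of rows all happen inside BaseSQLDB.search.
        return await self.search(
            model=Widgets, parameters=parameters, search=search, sorts=sorts, **kwargs
        )

Keeping the count and pagination logic in one place also means the InvalidQueryError behaviour for bad page/per_page values is now identical across every table.
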
@@ -258,6 +299,17 @@ def find_time_resolution(value):
     raise InvalidQueryError(f"Cannot parse {value=}")
 
 
+def _get_columns(table, parameters):
+    columns = [x for x in table.columns]
+    if parameters:
+        if unrecognised_parameters := set(parameters) - set(table.columns.keys()):
+            raise InvalidQueryError(
+                f"Unrecognised parameters requested {unrecognised_parameters}"
+            )
+        columns = [c for c in columns if c.name in parameters]
+    return columns
+
+
 def apply_search_filters(column_mapping, stmt, search):
     for query in search:
         try:
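
_get_columns moves next to the filter and sort helpers it supports. It returns the full column list when parameters is None and otherwise rejects unknown names before any SQL is built. A hypothetical illustration of the error path (the Jobs import path is assumed from the job/db.py diff above):

from diracx.core.exceptions import InvalidQueryError
from diracx.db.sql.job.schema import Jobs  # import path assumed
from diracx.db.sql.utils import _get_columns

try:
    # Requesting a column the Jobs table does not have fails fast,
    # before any statement is sent to the database.
    _get_columns(Jobs.__table__, ["JobID", "NoSuchColumn"])
except InvalidQueryError as exc:
    print(exc)  # e.g. Unrecognised parameters requested {'NoSuchColumn'}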
