Skip to content

Commit 1c0db02

Browse files
test: Better testing: Tested search engine
1 parent cd1caf4 commit 1c0db02

File tree

2 files changed

+385
-0
lines changed

2 files changed

+385
-0
lines changed

diracx-routers/tests/pilots/test_pilot_creation.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,19 @@ async def test_create_secrets_and_login(normal_test_client):
172172

173173
assert r.status_code == 200, r.json()
174174

175+
# -------------- Associate pilot with bad secrets --------------
176+
177+
body = {"pilot_stamps": pilot_stamps, "pilot_secrets": ["bad_secret"]}
178+
179+
r = normal_test_client.patch(
180+
"/api/pilots/fields/secrets",
181+
json=body,
182+
headers={"Content-Type": "application/json"},
183+
)
184+
185+
assert r.status_code == 400
186+
assert r.json()["detail"] == "one of the secrets does not exist"
187+
175188
# -------------- Associate pilot with secrets --------------
176189

177190
body = {"pilot_stamps": pilot_stamps, "pilot_secrets": secrets}
Lines changed: 372 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,372 @@
1+
"""Inspired by pilots and jobs db search tests."""
2+
3+
from __future__ import annotations
4+
5+
import pytest
6+
7+
from diracx.core.exceptions import InvalidQueryError
8+
from diracx.core.models import (
9+
PilotFieldsMapping,
10+
PilotStatus,
11+
ScalarSearchOperator,
12+
ScalarSearchSpec,
13+
SortDirection,
14+
SortSpec,
15+
VectorSearchOperator,
16+
VectorSearchSpec,
17+
)
18+
19+
pytestmark = pytest.mark.enabled_dependencies(
20+
[
21+
"AuthSettings",
22+
"ConfigSource",
23+
"DevelopmentSettings",
24+
"PilotAgentsDB",
25+
"PilotManagementAccessPolicy",
26+
]
27+
)
28+
29+
30+
@pytest.fixture
31+
def normal_test_client(client_factory):
32+
with client_factory.normal_user() as client:
33+
yield client
34+
35+
36+
MAIN_VO = "lhcb"
37+
N = 100
38+
39+
PILOT_REASONS = [
40+
"I was sick",
41+
"I can't, I have a pony.",
42+
"I was shopping",
43+
"I was sleeping",
44+
]
45+
46+
PILOT_STATUSES = list(PilotStatus)
47+
48+
49+
@pytest.fixture
50+
async def populated_pilot_client(normal_test_client):
51+
pilot_stamps = [f"stamp_{i}" for i in range(1, N + 1)]
52+
53+
# -------------- Bulk insert --------------
54+
body = {"vo": MAIN_VO, "pilot_stamps": pilot_stamps}
55+
56+
r = normal_test_client.post(
57+
"/api/pilots/",
58+
json=body,
59+
)
60+
61+
assert r.status_code == 200, r.json()
62+
63+
body = {
64+
"pilot_stamps_to_fields_mapping": [
65+
PilotFieldsMapping(
66+
PilotStamp=pilot_stamp,
67+
BenchMark=i**2,
68+
StatusReason=PILOT_REASONS[i % len(PILOT_REASONS)],
69+
AccountingSent=True,
70+
Status=PILOT_STATUSES[i % len(PILOT_STATUSES)],
71+
CurrentJobID=i,
72+
Queue=f"queue_{i}",
73+
).model_dump(exclude_unset=True)
74+
for i, pilot_stamp in enumerate(pilot_stamps)
75+
]
76+
}
77+
78+
r = normal_test_client.patch("/api/pilots/fields", json=body)
79+
80+
assert r.status_code == 204
81+
82+
yield normal_test_client
83+
84+
85+
@pytest.fixture
86+
async def search(populated_pilot_client):
87+
async def _search(
88+
parameters, conditions, sorts, distinct=False, page=1, per_page=100
89+
):
90+
91+
body = {
92+
"parameters": parameters,
93+
"search": conditions,
94+
"sort": sorts,
95+
"distinct": distinct,
96+
}
97+
98+
params = {"per_page": per_page, "page": page}
99+
100+
r = populated_pilot_client.post("/api/pilots/search", json=body, params=params)
101+
102+
if r.status_code == 400:
103+
# If we have a status_code 400, that means that the query failed
104+
raise InvalidQueryError()
105+
106+
return r.json(), r.headers
107+
108+
return _search
109+
110+
111+
async def test_search_parameters(search):
112+
"""Test that we can search specific parameters for pilots."""
113+
# Search a specific parameter: PilotID
114+
result, headers = await search(["PilotID"], [], [])
115+
assert len(result) == N
116+
assert result
117+
for r in result:
118+
assert r.keys() == {"PilotID"}
119+
assert "Content-Range" not in headers
120+
121+
# Search a specific parameter: Status
122+
result, headers = await search(["Status"], [], [])
123+
assert len(result) == N
124+
assert result
125+
for r in result:
126+
assert r.keys() == {"Status"}
127+
assert "Content-Range" not in headers
128+
129+
# Search for multiple parameters: PilotID, Status
130+
result, headers = await search(["PilotID", "Status"], [], [])
131+
assert len(result) == N
132+
assert result
133+
for r in result:
134+
assert r.keys() == {"PilotID", "Status"}
135+
assert "Content-Range" not in headers
136+
137+
# Search for a specific parameter but use distinct: Status
138+
result, headers = await search(["Status"], [], [], distinct=True)
139+
assert len(result) == len(PILOT_STATUSES)
140+
assert result
141+
assert "Content-Range" not in headers
142+
143+
# Search for a non-existent parameter: Dummy
144+
with pytest.raises(InvalidQueryError):
145+
result, headers = await search(["Dummy"], [], [])
146+
147+
148+
async def test_search_conditions(search):
149+
"""Test that we can search for specific pilots."""
150+
# Search a specific scalar condition: PilotID eq 3
151+
condition = ScalarSearchSpec(
152+
parameter="PilotID", operator=ScalarSearchOperator.EQUAL, value=3
153+
)
154+
result, headers = await search([], [condition], [])
155+
assert len(result) == 1
156+
assert result
157+
assert len(result) == 1
158+
assert result[0]["PilotID"] == 3
159+
assert "Content-Range" not in headers
160+
161+
# Search a specific scalar condition: PilotID lt 3
162+
condition = ScalarSearchSpec(
163+
parameter="PilotID", operator=ScalarSearchOperator.LESS_THAN, value=3
164+
)
165+
result, headers = await search([], [condition], [])
166+
assert len(result) == 2
167+
assert result
168+
assert len(result) == 2
169+
assert result[0]["PilotID"] == 1
170+
assert result[1]["PilotID"] == 2
171+
assert "Content-Range" not in headers
172+
173+
# Search a specific scalar condition: PilotID neq 3
174+
condition = ScalarSearchSpec(
175+
parameter="PilotID", operator=ScalarSearchOperator.NOT_EQUAL, value=3
176+
)
177+
result, headers = await search([], [condition], [])
178+
assert len(result) == 99
179+
assert result
180+
assert len(result) == 99
181+
assert all(r["PilotID"] != 3 for r in result)
182+
assert "Content-Range" not in headers
183+
184+
# Search a specific scalar condition: PilotID eq 5873 (does not exist)
185+
condition = ScalarSearchSpec(
186+
parameter="PilotID", operator=ScalarSearchOperator.EQUAL, value=5873
187+
)
188+
result, headers = await search([], [condition], [])
189+
assert not result
190+
assert "Content-Range" not in headers
191+
192+
# Search a specific vector condition: PilotID in 1,2,3
193+
condition = VectorSearchSpec(
194+
parameter="PilotID", operator=VectorSearchOperator.IN, values=[1, 2, 3]
195+
)
196+
result, headers = await search([], [condition], [])
197+
assert len(result) == 3
198+
assert result
199+
assert len(result) == 3
200+
assert all(r["PilotID"] in [1, 2, 3] for r in result)
201+
assert "Content-Range" not in headers
202+
203+
# Search a specific vector condition: PilotID in 1,2,5873 (one of them does not exist)
204+
condition = VectorSearchSpec(
205+
parameter="PilotID", operator=VectorSearchOperator.IN, values=[1, 2, 5873]
206+
)
207+
result, headers = await search([], [condition], [])
208+
assert len(result) == 2
209+
assert result
210+
assert len(result) == 2
211+
assert all(r["PilotID"] in [1, 2] for r in result)
212+
assert "Content-Range" not in headers
213+
214+
# Search a specific vector condition: PilotID not in 1,2,3
215+
condition = VectorSearchSpec(
216+
parameter="PilotID", operator=VectorSearchOperator.NOT_IN, values=[1, 2, 3]
217+
)
218+
result, headers = await search([], [condition], [])
219+
assert len(result) == 97
220+
assert result
221+
assert len(result) == 97
222+
assert all(r["PilotID"] not in [1, 2, 3] for r in result)
223+
assert "Content-Range" not in headers
224+
225+
# Search a specific vector condition: PilotID not in 1,2,5873 (one of them does not exist)
226+
condition = VectorSearchSpec(
227+
parameter="PilotID",
228+
operator=VectorSearchOperator.NOT_IN,
229+
values=[1, 2, 5873],
230+
)
231+
result, headers = await search([], [condition], [])
232+
assert len(result) == 98
233+
assert result
234+
assert len(result) == 98
235+
assert all(r["PilotID"] not in [1, 2] for r in result)
236+
assert "Content-Range" not in headers
237+
238+
# Search for multiple conditions based on different parameters: PilotID eq 70, PilotID in 4,5,6
239+
condition1 = ScalarSearchSpec(
240+
parameter="PilotStamp", operator=ScalarSearchOperator.EQUAL, value="stamp_5"
241+
)
242+
condition2 = VectorSearchSpec(
243+
parameter="PilotID", operator=VectorSearchOperator.IN, values=[4, 5, 6]
244+
)
245+
result, headers = await search([], [condition1, condition2], [])
246+
247+
assert result
248+
assert len(result) == 1
249+
assert result[0]["PilotID"] == 5
250+
assert result[0]["PilotStamp"] == "stamp_5"
251+
assert "Content-Range" not in headers
252+
253+
# Search for multiple conditions based on the same parameter: PilotID eq 70, PilotID in 4,5,6
254+
condition1 = ScalarSearchSpec(
255+
parameter="PilotID", operator=ScalarSearchOperator.EQUAL, value=70
256+
)
257+
condition2 = VectorSearchSpec(
258+
parameter="PilotID", operator=VectorSearchOperator.IN, values=[4, 5, 6]
259+
)
260+
result, headers = await search([], [condition1, condition2], [])
261+
assert len(result) == 0
262+
assert not result
263+
assert "Content-Range" not in headers
264+
265+
266+
async def test_search_sorts(search):
267+
"""Test that we can search for pilots and sort the results."""
268+
# Search and sort by PilotID in ascending order
269+
sort = SortSpec(parameter="PilotID", direction=SortDirection.ASC)
270+
result, headers = await search([], [], [sort])
271+
assert len(result) == N
272+
assert result
273+
for i, r in enumerate(result):
274+
assert r["PilotID"] == i + 1
275+
assert "Content-Range" not in headers
276+
277+
# Search and sort by PilotID in descending order
278+
sort = SortSpec(parameter="PilotID", direction=SortDirection.DESC)
279+
result, headers = await search([], [], [sort])
280+
assert len(result) == N
281+
assert result
282+
for i, r in enumerate(result):
283+
assert r["PilotID"] == N - i
284+
assert "Content-Range" not in headers
285+
286+
# Search and sort by PilotStamp in ascending order
287+
sort = SortSpec(parameter="PilotStamp", direction=SortDirection.ASC)
288+
result, headers = await search([], [], [sort])
289+
assert len(result) == N
290+
assert result
291+
# Assert that stamp_10 is before stamp_2 because of the lexicographical order
292+
assert result[2]["PilotStamp"] == "stamp_100"
293+
assert result[12]["PilotStamp"] == "stamp_2"
294+
assert "Content-Range" not in headers
295+
296+
# Search and sort by PilotStamp in descending order
297+
sort = SortSpec(parameter="PilotStamp", direction=SortDirection.DESC)
298+
result, headers = await search([], [], [sort])
299+
assert len(result) == N
300+
assert result
301+
# Assert that stamp_10 is before stamp_2 because of the lexicographical order
302+
assert result[97]["PilotStamp"] == "stamp_100"
303+
assert result[87]["PilotStamp"] == "stamp_2"
304+
assert "Content-Range" not in headers
305+
306+
# Search and sort by PilotStamp in ascending order and PilotID in descending order
307+
sort1 = SortSpec(parameter="PilotStamp", direction=SortDirection.ASC)
308+
sort2 = SortSpec(parameter="PilotID", direction=SortDirection.DESC)
309+
result, headers = await search([], [], [sort1, sort2])
310+
assert len(result) == N
311+
assert result
312+
assert result[0]["PilotStamp"] == "stamp_1"
313+
assert result[0]["PilotID"] == 1
314+
assert result[99]["PilotStamp"] == "stamp_99"
315+
assert result[99]["PilotID"] == 99
316+
assert "Content-Range" not in headers
317+
318+
319+
async def test_search_pagination(search):
320+
"""Test that we can search for pilots."""
321+
# Search for the first 10 pilots
322+
result, headers = await search([], [], [], per_page=10, page=1)
323+
assert "Content-Range" in headers
324+
# Because Content-Range = f"pilots {first_idx}-{last_idx}/{total}"
325+
total = int(headers["Content-Range"].split("/")[1])
326+
assert total == N
327+
assert result
328+
assert len(result) == 10
329+
assert result[0]["PilotID"] == 1
330+
331+
# Search for the second 10 pilots
332+
result, headers = await search([], [], [], per_page=10, page=2)
333+
assert "Content-Range" in headers
334+
# Because Content-Range = f"pilots {first_idx}-{last_idx}/{total}"
335+
total = int(headers["Content-Range"].split("/")[1])
336+
assert total == N
337+
assert result
338+
assert len(result) == 10
339+
assert result[0]["PilotID"] == 11
340+
341+
# Search for the last 10 pilots
342+
result, headers = await search([], [], [], per_page=10, page=10)
343+
assert "Content-Range" in headers
344+
# Because Content-Range = f"pilots {first_idx}-{last_idx}/{total}"
345+
total = int(headers["Content-Range"].split("/")[1])
346+
assert result
347+
assert len(result) == 10
348+
assert result[0]["PilotID"] == 91
349+
350+
# Search for the second 50 pilots
351+
result, headers = await search([], [], [], per_page=50, page=2)
352+
assert "Content-Range" in headers
353+
# Because Content-Range = f"pilots {first_idx}-{last_idx}/{total}"
354+
total = int(headers["Content-Range"].split("/")[1])
355+
assert result
356+
assert len(result) == 50
357+
assert result[0]["PilotID"] == 51
358+
359+
# Invalid page number
360+
result, headers = await search([], [], [], per_page=10, page=11)
361+
assert "Content-Range" in headers
362+
# Because Content-Range = f"pilots {first_idx}-{last_idx}/{total}"
363+
total = int(headers["Content-Range"].split("/")[1])
364+
assert not result
365+
366+
# Invalid page number
367+
with pytest.raises(InvalidQueryError):
368+
result = await search([], [], [], per_page=10, page=0)
369+
370+
# Invalid per_page number
371+
with pytest.raises(InvalidQueryError):
372+
result = await search([], [], [], per_page=0, page=1)

0 commit comments

Comments
 (0)