Skip to content

Commit e74fe72

Browse files
test: Improving pilot auth tests
1 parent 6dc1ad0 commit e74fe72

File tree

2 files changed

+135
-12
lines changed

2 files changed

+135
-12
lines changed

diracx-routers/src/diracx/routers/auth/pilot_auth.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,26 @@ async def pilot_login(
4242
"sub": pilot["PilotJobReference"],
4343
}
4444

45-
# return pilot_info
46-
access_token, refresh_token = await exchange_token(
47-
auth_db=auth_db,
48-
scope="vo:diracAdmin",
49-
oidc_token_info=pilot_info,
50-
config=config,
51-
settings=settings,
52-
available_properties=available_properties,
53-
pilot_exchange=True,
54-
)
55-
56-
return [create_token(access_token, settings), create_token(refresh_token, settings)]
45+
try:
46+
access_token, refresh_token = await exchange_token(
47+
auth_db=auth_db,
48+
scope=generate_pilot_scope(pilot),
49+
oidc_token_info=pilot_info,
50+
config=config,
51+
settings=settings,
52+
available_properties=available_properties,
53+
pilot_exchange=True,
54+
)
55+
except ValueError as e:
56+
raise HTTPException(
57+
status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)
58+
) from e
59+
60+
return {
61+
"access_token": create_token(access_token, settings),
62+
"refresh_token": create_token(refresh_token, settings),
63+
}
64+
65+
66+
def generate_pilot_scope(pilot: dict) -> str:
67+
return f"vo:{pilot['VO']}"
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
from __future__ import annotations
2+
3+
import pytest
4+
5+
from diracx.db.sql.pilot_agents.db import PilotAgentsDB
6+
7+
pytestmark = pytest.mark.enabled_dependencies(
8+
[
9+
"DevelopmentSettings",
10+
"AuthDB",
11+
"AuthSettings",
12+
"ConfigSource",
13+
"BaseAccessPolicy",
14+
"PilotAgentsDB",
15+
"RegisteredPilotAccessPolicy",
16+
]
17+
)
18+
19+
20+
@pytest.fixture
21+
def test_client(client_factory):
22+
with client_factory.unauthenticated() as client:
23+
yield client
24+
25+
26+
@pytest.fixture
27+
def non_mocked_hosts(test_client) -> list[str]:
28+
return [test_client.base_url.host]
29+
30+
31+
async def test_create_pilot_and_verify_secret(test_client):
32+
33+
# see https://github.com/DIRACGrid/diracx/blob/78e00aa57f4191034dbf643c7ed2857a93b53f60/diracx-routers/tests/pilots/test_pilot_logger.py#L37
34+
db = test_client.app.dependency_overrides[PilotAgentsDB.transaction].args[0]
35+
36+
# Add a pilot vo
37+
pilot_vo = "lhcb"
38+
# Add a pilot reference
39+
pilot_ref = "pilot-ref"
40+
41+
async with db as pilot_agents_db:
42+
# Register a pilot
43+
pilot_id = await pilot_agents_db.register_new_pilot(
44+
vo=pilot_vo, pilot_job_reference=pilot_ref
45+
)
46+
47+
# Add credentials to this pilot
48+
secret = await pilot_agents_db.add_pilot_credentials(pilot_id=pilot_id)
49+
50+
assert secret is not None
51+
52+
request_data = {"pilot_id": pilot_id, "pilot_secret": secret}
53+
54+
r = test_client.post(
55+
"/api/auth/pilot-login",
56+
params=request_data,
57+
headers={"Content-Type": "application/json"},
58+
)
59+
60+
assert r.status_code == 200
61+
62+
access_token = r.json()["access_token"]
63+
refresh_token = r.json()["refresh_token"]
64+
65+
assert access_token is not None
66+
assert refresh_token is not None
67+
68+
# ----------------- Get pilot info without permissions -----------------
69+
r = test_client.get(
70+
"/api/pilots/info",
71+
)
72+
73+
assert r.status_code == 401
74+
75+
# ----------------- Get pilot info with access_token -----------------
76+
r = test_client.get(
77+
"/api/pilots/info", headers={"Authorization": f"Bearer {access_token}"}
78+
)
79+
80+
assert r.status_code == 200
81+
82+
# ----------------- Get pilot info with wrong access_token -----------------
83+
r = test_client.get(
84+
"/api/pilots/info", headers={"Authorization": "Bearer 4dm1n B34r3r"}
85+
)
86+
87+
assert r.status_code == 401, r.json()
88+
assert r.json()["detail"] == "Invalid JWT"
89+
90+
# ----------------- Wrong password -----------------
91+
request_data = {"pilot_id": pilot_id, "pilot_secret": "My 1ncr3d1bl3 t0k3n"}
92+
93+
r = test_client.post(
94+
"/api/auth/pilot-login",
95+
params=request_data,
96+
headers={"Content-Type": "application/json"},
97+
)
98+
99+
assert r.status_code == 401, r.json()
100+
assert r.json()["detail"] == "bad pilot_id / pilot_secret"
101+
102+
# ----------------- Wrong ID -----------------
103+
request_data = {"pilot_id": 63000, "pilot_secret": secret}
104+
105+
r = test_client.post(
106+
"/api/auth/pilot-login",
107+
params=request_data,
108+
headers={"Content-Type": "application/json"},
109+
)
110+
111+
assert r.status_code == 401
112+
assert r.json()["detail"] == "bad pilot_id / pilot_secret"

0 commit comments

Comments
 (0)