Skip to content

Commit b39eff3

Browse files
fix: Removed garbage code, and improving tests
1 parent b572e22 commit b39eff3

File tree

2 files changed

+32
-24
lines changed

2 files changed

+32
-24
lines changed

diracx-db/src/diracx/db/sql/pilot_agents/db.py

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -54,28 +54,6 @@ async def add_pilot_references(
5454
await self.conn.execute(stmt)
5555
return
5656

57-
async def register_pilot(self, pilot_reference: str, pilot_secret: str):
58-
hashed_secret = hash(pilot_secret)
59-
60-
stmt = (
61-
select(PilotRegistrations)
62-
.join(PilotAgents, PilotRegistrations.pilot_id == PilotAgents.pilot_id)
63-
.where(PilotRegistrations.pilot_hashed_secret == hashed_secret)
64-
.where(PilotAgents.pilot_job_reference == pilot_reference)
65-
)
66-
67-
# Execute the request
68-
res = await self.conn.execute(stmt)
69-
70-
result = res.fetchone()
71-
72-
if result is None:
73-
raise AuthorizationError(detail="bad pilot_reference / pilot_secret")
74-
75-
# Increment the count
76-
await self.increment_pilot_secret_use(pilot_reference=pilot_reference)
77-
return
78-
7957
async def increment_pilot_secret_use(
8058
self,
8159
pilot_reference: str,
@@ -200,9 +178,9 @@ async def get_pilots(self):
200178

201179
async def get_pilot_by_id(self, pilot_id: int):
202180
stmt = (
203-
select(PilotRegistrations)
181+
select(PilotAgents)
204182
.with_for_update()
205-
.where(PilotRegistrations.pilot_id == pilot_id)
183+
.where(PilotAgents.pilot_id == pilot_id)
206184
)
207185

208186
return dict((await self.conn.execute(stmt)).one()._mapping)

diracx-db/tests/pilot_agents/test_pilot_agents_db.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from __future__ import annotations
22

33
import pytest
4+
from sqlalchemy.exc import NoResultFound
45

6+
from diracx.core.exceptions import AuthorizationError
57
from diracx.db.sql.pilot_agents.db import PilotAgentsDB
68

79

@@ -41,6 +43,9 @@ async def test_insert_and_select_single(pilot_agents_db: PilotAgentsDB):
4143

4244
res = await pilot_agents_db.get_pilot_by_id(new_pilot_id)
4345

46+
with pytest.raises(NoResultFound):
47+
await pilot_agents_db.get_pilot_by_id(10)
48+
4449
# Set values
4550
assert res["PilotID"] == new_pilot_id
4651
assert res["VO"] == "pilot-vo"
@@ -50,3 +55,28 @@ async def test_insert_and_select_single(pilot_agents_db: PilotAgentsDB):
5055
assert res["PilotStamp"] == ""
5156
assert res["BenchMark"] == 0.0
5257
assert res["Status"] == "Unknown"
58+
59+
60+
async def test_create_pilot_and_verify_secret(pilot_agents_db: PilotAgentsDB):
61+
62+
async with pilot_agents_db as pilot_agents_db:
63+
# Add a pilot reference
64+
pilot_ref = "pilot-ref"
65+
66+
new_pilot_id = await pilot_agents_db.register_new_pilot(
67+
vo="pilot-vo", pilot_job_reference=pilot_ref
68+
)
69+
70+
# Add creds
71+
secret = await pilot_agents_db.add_pilot_credentials(new_pilot_id)
72+
73+
assert secret is not None
74+
75+
await pilot_agents_db.verify_pilot_secret(
76+
pilot_reference=pilot_ref, pilot_secret=secret
77+
)
78+
79+
with pytest.raises(AuthorizationError):
80+
await pilot_agents_db.verify_pilot_secret(
81+
pilot_reference=pilot_ref, pilot_secret="I love stawberries :)"
82+
)

0 commit comments

Comments
 (0)