Skip to content

Commit 90acb3b

Browse files
feat: Adding pilot secret creation
1 parent 2404324 commit 90acb3b

File tree

5 files changed

+243
-21
lines changed

5 files changed

+243
-21
lines changed

diracx-core/src/diracx/core/exceptions.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,38 @@ def __init__(self, job_id, detail: str | None = None):
104104

105105

106106
class PilotNotFoundError(Exception):
107-
def __init__(self, pilot_ref: str, detail: str | None = None):
107+
def __init__(
108+
self,
109+
pilot_ref: str | None = None,
110+
pilot_id: int | None = None,
111+
detail: str | None = None,
112+
):
108113
self.pilot_ref = pilot_ref
114+
self.pilot_id = pilot_id
109115
self.detail = detail
110116
super().__init__(
111-
f"Pilot {pilot_ref} not found" + (": {detail} " if detail else "")
117+
"Pilot "
118+
+ (f"(Ref: {pilot_ref})" if pilot_ref else "")
119+
+ (f" (ID: {str(pilot_id)})" if pilot_id is not None else "")
120+
+ " not found"
121+
+ (f": {detail}" if detail else "")
122+
)
123+
124+
125+
class PilotAlreadyExistsError(Exception):
126+
def __init__(
127+
self,
128+
pilot_ref: str | None = None, # Changed to str based on the format
129+
pilot_id: int | None = None,
130+
detail: str | None = None,
131+
):
132+
self.pilot_ref = pilot_ref
133+
self.pilot_id = pilot_id
134+
self.detail = detail
135+
super().__init__(
136+
"Pilot "
137+
+ (f"(Ref: {pilot_ref})" if pilot_ref else "")
138+
+ (f" (ID: {str(pilot_id)})" if pilot_id is not None else "")
139+
+ " already exists"
140+
+ (f": {detail}" if detail else "")
112141
)

diracx-db/src/diracx/db/sql/pilot_agents/db.py

Lines changed: 112 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
from __future__ import annotations
22

3-
import hashlib
43
from datetime import datetime, timezone
4+
from os import urandom
55

66
from sqlalchemy import insert, select, update
7+
from sqlalchemy.exc import IntegrityError
78

89
from diracx.core.exceptions import (
910
AuthorizationError,
11+
PilotAlreadyExistsError,
1012
PilotNotFoundError,
1113
)
14+
from diracx.db.sql.utils.functions import hash
1215

1316
from ..utils import BaseSQLDB
1417
from .schema import PilotAgents, PilotAgentsDBBase, PilotRegistrations
@@ -51,6 +54,28 @@ async def add_pilot_references(
5154
await self.conn.execute(stmt)
5255
return
5356

57+
async def register_pilot(self, pilot_reference: str, pilot_secret: str):
58+
hashed_secret = hash(pilot_secret)
59+
60+
stmt = (
61+
select(PilotRegistrations)
62+
.join(PilotAgents, PilotRegistrations.pilot_id == PilotAgents.pilot_id)
63+
.where(PilotRegistrations.pilot_hashed_secret == hashed_secret)
64+
.where(PilotAgents.pilot_job_reference == pilot_reference)
65+
)
66+
67+
# Execute the request
68+
res = await self.conn.execute(stmt)
69+
70+
result = res.fetchone()
71+
72+
if result is None:
73+
raise AuthorizationError(detail="bad pilot_reference / pilot_secret")
74+
75+
# Increment the count
76+
await self.increment_pilot_secret_use(pilot_reference=pilot_reference)
77+
return
78+
5479
async def increment_pilot_secret_use(
5580
self,
5681
pilot_reference: str,
@@ -72,10 +97,10 @@ async def increment_pilot_secret_use(
7297
if res.rowcount == 0:
7398
raise PilotNotFoundError(pilot_ref=pilot_reference)
7499

75-
async def register_pilot(self, pilot_reference: str, pilot_secret: str):
76-
hashed_secret = hashlib.sha256(pilot_secret.encode()).hexdigest()
77-
78-
print("hash", hashed_secret)
100+
async def verify_pilot_secret(
101+
self, pilot_reference: str, pilot_secret: str
102+
) -> None:
103+
hashed_secret = hash(pilot_secret)
79104

80105
stmt = (
81106
select(PilotRegistrations)
@@ -95,3 +120,85 @@ async def register_pilot(self, pilot_reference: str, pilot_secret: str):
95120
# Increment the count
96121
await self.increment_pilot_secret_use(pilot_reference=pilot_reference)
97122
return
123+
124+
async def register_new_pilot(
125+
self,
126+
vo: str,
127+
initial_job_id: int = 0,
128+
current_job_id: int = 0,
129+
benchmark: float = 0.0,
130+
pilot_job_reference: str = "Unknown",
131+
pilot_stamp: str = "",
132+
status: str = "Unknown",
133+
status_reason: str = "Unknown",
134+
queue: str = "Unknown",
135+
grid_site: str = "Unknown",
136+
destination_site: str = "NotAssigned",
137+
grid_type: str = "LCG",
138+
submission_time: str | None = None, # ?
139+
last_update_time: str | None = None, # = now?
140+
accounting_sent: bool = False,
141+
) -> int | None:
142+
stmt = insert(PilotAgents).values(
143+
initial_job_id=initial_job_id,
144+
current_job_id=current_job_id,
145+
pilot_job_reference=pilot_job_reference,
146+
pilot_stamp=pilot_stamp,
147+
destination_site=destination_site,
148+
queue=queue,
149+
grid_site=grid_site,
150+
vo=vo,
151+
grid_type=grid_type,
152+
benchmark=benchmark,
153+
submission_time=submission_time,
154+
last_update_time=last_update_time,
155+
status=status,
156+
status_reason=status_reason,
157+
accounting_sent=accounting_sent,
158+
)
159+
160+
# Execute the request
161+
res = await self.conn.execute(stmt)
162+
163+
new_pilot_id = res.inserted_primary_key
164+
165+
# Returns the new pilot ID
166+
return int(new_pilot_id[0]) if new_pilot_id else None
167+
168+
async def add_pilot_credentials(self, pilot_id: int) -> str:
169+
170+
# Get a random string
171+
# Can be customized
172+
random_secret = urandom(30).hex()
173+
174+
hashed_random_secret = hash(random_secret)
175+
176+
stmt = insert(PilotRegistrations).values(
177+
pilot_id=pilot_id, pilot_hashed_secret=hashed_random_secret
178+
)
179+
180+
try:
181+
await self.conn.execute(stmt)
182+
except IntegrityError as e:
183+
if "foreign key" in str(e.orig).lower():
184+
raise PilotNotFoundError(pilot_id=pilot_id) from e
185+
if "duplicate entry" in str(e.orig).lower():
186+
raise PilotAlreadyExistsError(
187+
pilot_id=pilot_id, detail="this pilot has already credentials"
188+
) from e
189+
190+
return random_secret
191+
192+
async def get_pilots(self):
193+
"""Récupère tous les pilotes et les retourne sous forme de dictionnaires.
194+
195+
:raises: NoResultFound
196+
"""
197+
# La clause with_for_update empêche que le jeton soit récupéré plusieurs fois simultanément
198+
stmt = select(PilotRegistrations).with_for_update()
199+
result = await self.conn.execute(stmt)
200+
201+
# Convertir les résultats en dictionnaires
202+
pilots = [dict(row._mapping) for row in result]
203+
204+
return pilots

diracx-db/src/diracx/db/sql/pilot_agents/schema.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,6 @@
11
from __future__ import annotations
22

3-
from sqlalchemy import (
4-
DateTime,
5-
Double,
6-
Index,
7-
Integer,
8-
String,
9-
Text,
10-
)
3+
from sqlalchemy import DateTime, Double, ForeignKey, Index, Integer, String, Text
114
from sqlalchemy.orm import declarative_base
125

136
from ..utils import Column, EnumBackedBool, NullColumn
@@ -63,6 +56,11 @@ class PilotOutput(PilotAgentsDBBase):
6356
class PilotRegistrations(PilotAgentsDBBase):
6457
__tablename__ = "PilotRegistrations"
6558

66-
pilot_id = Column("PilotID", Integer, primary_key=True)
59+
pilot_id = Column(
60+
"PilotID",
61+
Integer,
62+
ForeignKey("PilotAgents.PilotID", ondelete="CASCADE"),
63+
primary_key=True,
64+
)
6765
pilot_hashed_secret = Column("PilotHashedSecret", String(64))
6866
pilot_secret_use_count = Column("PilotSecretUseCount", Integer, default=0)

diracx-db/tests/pilot_agents/test_pilot_agents_db.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77

88
@pytest.fixture
9-
async def pilot_agents_db(tmp_path) -> PilotAgentsDB:
9+
async def pilot_agents_db(tmp_path):
1010
agents_db = PilotAgentsDB("sqlite+aiosqlite:///:memory:")
1111
async with agents_db.engine_context():
1212
async with agents_db.engine.begin() as conn:
@@ -29,3 +29,24 @@ async def test_insert_and_select(pilot_agents_db: PilotAgentsDB):
2929
await pilot_agents_db.add_pilot_references(
3030
refs, "test_vo", grid_type="DIRAC", pilot_stamps=None
3131
)
32+
33+
34+
async def test_insert_and_select_single(pilot_agents_db: PilotAgentsDB):
35+
36+
async with pilot_agents_db as pilot_agents_db:
37+
# Add a pilot reference
38+
new_pilot_id = await pilot_agents_db.register_new_pilot(
39+
vo="pilot-vo", pilot_job_reference="pilot-ref"
40+
)
41+
42+
res = await pilot_agents_db.get_pilot_by_id(new_pilot_id)
43+
44+
# Set values
45+
assert res["PilotID"] == new_pilot_id
46+
assert res["VO"] == "pilot-vo"
47+
assert res["PilotJobReference"] == "pilot-ref"
48+
49+
# Default values
50+
assert res["PilotStamp"] == ""
51+
assert res["BenchMark"] == 0.0
52+
assert res["Status"] == "Unknown"
Lines changed: 71 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
from __future__ import annotations
22

3+
from os import getenv
4+
35
from fastapi import HTTPException, status
46

57
from diracx.core.exceptions import (
68
AuthorizationError,
9+
PilotAlreadyExistsError,
10+
PilotNotFoundError,
711
)
812

913
from ..dependencies import (
@@ -14,16 +18,79 @@
1418
router = DiracxRouter(require_auth=False)
1519

1620

17-
@router.post("/register")
18-
async def search(pilot_db: PilotAgentsDB, pilot_reference: str, pilot_secret: str):
21+
@router.post("/pilot-login")
22+
async def pilot_login(pilot_db: PilotAgentsDB, pilot_reference: str, pilot_secret: str):
1923
"""Endpoint without policy, the pilot uses only its secret."""
2024
try:
21-
await pilot_db.register_pilot(
25+
await pilot_db.verify_pilot_secret(
2226
pilot_reference=pilot_reference, pilot_secret=pilot_secret
2327
)
2428
except AuthorizationError as e:
2529
raise HTTPException(
2630
status_code=status.HTTP_401_UNAUTHORIZED, detail=e.detail
2731
) from e
2832

29-
return pilot_reference
33+
# TODO: Returns a JWT
34+
return {"ref": pilot_reference}
35+
36+
37+
# Debug only route
38+
# Keep it?
39+
# TODO: Add to the env?
40+
if getenv("production") is None:
41+
42+
@router.post("/register-pilot")
43+
async def add_new_pilot(
44+
pilot_db: PilotAgentsDB,
45+
vo: str = "default-vo",
46+
initial_job_id: int = 0,
47+
current_job_id: int = 0,
48+
benchmark: float = 0.0,
49+
pilot_job_reference: str = "Unknown",
50+
pilot_stamp: str = "",
51+
status: str = "Unknown",
52+
status_reason: str = "Unknown",
53+
queue: str = "Unknown",
54+
grid_site: str = "Unknown",
55+
destination_site: str = "NotAssigned",
56+
grid_type: str = "LCG",
57+
submission_time: str | None = None, # ?
58+
last_update_time: str | None = None, # = now?
59+
accounting_sent: bool = False,
60+
):
61+
return {
62+
"id": await pilot_db.register_new_pilot(
63+
initial_job_id=initial_job_id,
64+
current_job_id=current_job_id,
65+
pilot_job_reference=pilot_job_reference,
66+
pilot_stamp=pilot_stamp,
67+
destination_site=destination_site,
68+
queue=queue,
69+
grid_site=grid_site,
70+
vo=vo,
71+
grid_type=grid_type,
72+
benchmark=benchmark,
73+
submission_time=submission_time,
74+
last_update_time=last_update_time,
75+
status=status,
76+
status_reason=status_reason,
77+
accounting_sent=accounting_sent,
78+
)
79+
}
80+
81+
@router.post("/add-credentials")
82+
async def add_credentials(pilot_db: PilotAgentsDB, pilot_id: int):
83+
try:
84+
return {"secret": await pilot_db.add_pilot_credentials(pilot_id=pilot_id)}
85+
except PilotNotFoundError as e:
86+
raise HTTPException(
87+
status_code=status.HTTP_404_NOT_FOUND, detail=e.detail
88+
) from e
89+
except PilotAlreadyExistsError as e:
90+
raise HTTPException(
91+
status_code=status.HTTP_409_CONFLICT, detail=e.detail
92+
) from e
93+
94+
@router.post("/get-pilots")
95+
async def get_pilots(pilot_db: PilotAgentsDB):
96+
return await pilot_db.get_pilots()

0 commit comments

Comments
 (0)