Skip to content

Commit 02a5e31

Browse files
feat: A user can create secrets, and associate them to a pilot
1 parent 8072227 commit 02a5e31

File tree

13 files changed

+639
-249
lines changed

13 files changed

+639
-249
lines changed

diracx-core/src/diracx/core/exceptions.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,11 @@ class CredentialsNotFoundError(GenericError):
152152
tail = "not found"
153153

154154

155+
class CredentialsAlreadyExistError(GenericError):
156+
head = "Credentials"
157+
tail = "already exist"
158+
159+
155160
class SecretHasExpiredError(GenericError):
156161
head = "Secret"
157162
tail = "has expired"

diracx-core/src/diracx/core/models.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -225,14 +225,17 @@ class TokenResponse(BaseModel):
225225
refresh_token: str | None = None
226226

227227

228-
class PilotCredentialsInfo(BaseModel):
229-
pilot_stamp: str | None
228+
class PilotSecretsInfo(BaseModel):
230229
pilot_secret: str
231230
pilot_secret_expires_in: int
232231

233232

234-
class PilotCredentialsResponse(BaseModel):
235-
pilot_credentials: list[PilotCredentialsInfo]
233+
class PilotStampInfo(BaseModel):
234+
pilot_stamp: str
235+
236+
237+
class PilotCredentialsInfo(PilotSecretsInfo, PilotStampInfo):
238+
pass
236239

237240

238241
class AccessTokenPayload(TokenPayload):

diracx-db/src/diracx/db/sql/pilot_agents/db.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
from diracx.core.exceptions import (
1111
BadPilotCredentialsError,
1212
BadPilotVOError,
13+
CredentialsAlreadyExistError,
1314
CredentialsNotFoundError,
14-
PilotAlreadyExistsError,
1515
PilotNotFoundError,
1616
SecretAlreadyExistsError,
1717
SecretHasExpiredError,
@@ -87,25 +87,41 @@ async def verify_pilot_secret(
8787
pilots_credentials = await self.get_pilots_credentials_by_stamps_bulk(
8888
[pilot_stamp]
8989
)
90-
pilot_credentials = pilots_credentials[
91-
0
92-
] # Semantic, assured by fetch_records_bulk_or_raises
9390

9491
# 2. Get the pilot secret itself
9592
secrets = await self.get_secrets_by_hashed_secrets_bulk([pilot_hashed_secret])
9693
secret = secrets[0] # Semantic, assured by fetch_records_bulk_or_raises
9794

95+
matches = [
96+
pilot_credential
97+
for pilot_credential in pilots_credentials
98+
if secret["SecretID"] == pilot_credential["PilotSecretID"]
99+
]
100+
98101
# 3. Compare the secret_id
99-
if not secret["SecretID"] == pilot_credentials["PilotSecretID"]:
102+
if len(matches) == 0:
100103

101104
raise BadPilotCredentialsError(
102105
data={
103106
"pilot_stamp": pilot_stamp,
104107
"pilot_hashed_secret": pilot_hashed_secret,
105108
"real_hashed_secret": secret["HashedSecret"],
106-
str(secret["SecretID"]): str(pilot_credentials["PilotSecretID"]),
109+
"pilot_secret_id[]": str(
110+
[
111+
pilot_credential["PilotSecretID"]
112+
for pilot_credential in pilots_credentials
113+
]
114+
),
115+
"secret_id": secret["SecretID"],
116+
"test": str(pilots_credentials),
107117
}
108118
)
119+
elif len(matches) > 1:
120+
121+
raise DBInBadStateError(
122+
detail="This should not happen. Duplicates in the database."
123+
)
124+
pilot_credentials = matches[0] # Semantic
109125

110126
# 4. Check if the secret is expired
111127
now = datetime.now(tz=timezone.utc)
@@ -260,20 +276,24 @@ async def associate_pilots_with_secrets_bulk(
260276
try:
261277
await self.conn.execute(stmt)
262278
await self.conn.commit()
279+
263280
except (IntegrityError, OperationalError) as e:
264281
# Undo changes
265282
await self.conn.rollback()
266-
267283
if "foreign key" in str(e.orig).lower():
268284
raise PilotNotFoundError(
269285
data={"pilot_stamps": str(pilot_to_secret_id_mapping_values)},
270286
detail="at least one of these pilots or secrets does not exist",
271287
) from e
272-
if "duplicate entry" in str(e.orig).lower():
273-
raise PilotAlreadyExistsError(
288+
if any(
289+
el in str(e.orig).lower()
290+
for el in ["duplicate entry", "unique constraint"]
291+
):
292+
raise CredentialsAlreadyExistError(
274293
data={"pilot_stamps": str(pilot_to_secret_id_mapping_values)},
275294
detail="at least one of these pilots already have a secret",
276295
) from e
296+
raise NotImplementedError(f"This error is not caught: {str(e.orig)}") from e
277297

278298
async def verify_that_pilot_can_access_secret_bulk(
279299
self, pilot_to_secret_id_mapping_values: list[dict[str, Any]]
@@ -363,6 +383,7 @@ async def get_pilots_credentials_by_stamps_bulk(
363383
"pilot_stamp",
364384
"PilotStamp",
365385
pilot_stamps,
386+
allow_more_than_one_result_per_input=True,
366387
)
367388

368389
async def get_secrets_by_hashed_secrets_bulk(self, hashed_secrets: list[str]):

diracx-db/src/diracx/db/sql/utils/functions.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ async def fetch_records_bulk_or_raises(
158158
column_name: str,
159159
elements_to_fetch: list,
160160
order_by: tuple[str, str] | None = None,
161+
allow_more_than_one_result_per_input: bool = False,
161162
) -> list[dict]:
162163
"""Fetches a list of elements in a table, returns a list of elements.
163164
All elements fro the `element_to_fetch` **must** be present.
@@ -196,8 +197,9 @@ async def fetch_records_bulk_or_raises(
196197
results = rows_to_dicts(await conn.execute(stmt))
197198

198199
# Detects duplicates
199-
if len(results) > len(elements_to_fetch):
200-
raise DBInBadStateError(detail="Seems to have duplicates in the database.")
200+
if not allow_more_than_one_result_per_input:
201+
if len(results) > len(elements_to_fetch):
202+
raise DBInBadStateError(detail="Seems to have duplicates in the database.")
201203

202204
# Checks if we have every elements we wanted
203205
found_keys = {row[column_name] for row in results}

diracx-logic/src/diracx/logic/auth/pilot.py

Lines changed: 63 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77
from diracx.core.config import Config
88
from diracx.core.exceptions import (
99
ConfigurationError,
10+
CredentialsAlreadyExistError,
1011
PilotAlreadyExistsError,
1112
PilotNotFoundError,
1213
)
13-
from diracx.core.models import PilotCredentialsInfo, PilotCredentialsResponse
14+
from diracx.core.models import PilotCredentialsInfo, PilotSecretsInfo, PilotStampInfo
1415
from diracx.core.settings import AuthSettings
1516
from diracx.db.sql import PilotAgentsDB
1617

@@ -52,7 +53,7 @@ async def create_credentials(
5253
secret["SecretCreationDate"]
5354
+ timedelta(
5455
seconds=(
55-
expiration_minutes
56+
expiration_minutes * 60
5657
if expiration_minutes
5758
else settings.pilot_secret_expire_seconds
5859
)
@@ -74,6 +75,36 @@ async def create_credentials(
7475
return random_secrets, hashed_secrets, expiration_dates_timestamps
7576

7677

78+
async def associate_pilots_with_secrets(
79+
pilot_db: PilotAgentsDB,
80+
pilot_stamps: list[str],
81+
secrets: list[str] | None = None,
82+
hashed_secrets: list[str] | None = None,
83+
):
84+
85+
if not hashed_secrets:
86+
assert secrets
87+
hashed_secrets = [hash(secret) for secret in secrets]
88+
89+
# Get the secret ids to later associate them with pilots
90+
secrets_obj = await pilot_db.get_secrets_by_hashed_secrets_bulk(hashed_secrets)
91+
secret_ids = [secret["SecretID"] for secret in secrets_obj]
92+
93+
if len(secret_ids) == 1:
94+
secret_ids = secret_ids * len(pilot_stamps)
95+
96+
# Associates pilots with their secrets
97+
pilot_to_secret_id_mapping_values = [
98+
{
99+
"PilotSecretID": secret_id,
100+
"PilotStamp": pilot_stamp,
101+
}
102+
for pilot_stamp, secret_id in zip(pilot_stamps, secret_ids)
103+
]
104+
105+
await pilot_db.associate_pilots_with_secrets_bulk(pilot_to_secret_id_mapping_values)
106+
107+
77108
async def add_pilot_credentials(
78109
pilot_stamps: list[str],
79110
pilot_db: PilotAgentsDB,
@@ -91,29 +122,25 @@ async def add_pilot_credentials(
91122
)
92123
)
93124

94-
# Get the secret ids to later associate them with pilots
95-
secrets = await pilot_db.get_secrets_by_hashed_secrets_bulk(hashed_secrets)
96-
secret_ids = [secret["SecretID"] for secret in secrets]
97-
98-
# Associates pilots with their secrets
99-
pilot_to_secret_id_mapping_values = [
100-
{
101-
"PilotSecretID": secret_id,
102-
"PilotStamp": pilot_stamp,
103-
}
104-
for pilot_stamp, secret_id in zip(pilot_stamps, secret_ids)
105-
]
106-
await pilot_db.associate_pilots_with_secrets_bulk(pilot_to_secret_id_mapping_values)
125+
try:
126+
await associate_pilots_with_secrets(
127+
pilot_db=pilot_db, hashed_secrets=hashed_secrets, pilot_stamps=pilot_stamps
128+
)
129+
except CredentialsAlreadyExistError as e:
130+
# Undo everything in case of an error.
131+
# TODO: Validate in PR
132+
await pilot_db.conn.rollback()
133+
raise e
107134

108135
return random_secrets, expiration_dates_timestamps
109136

110137

111138
def create_pilot_credentials_response(
112-
pilot_stamps: list[str | None],
139+
pilot_stamps: list[str],
113140
pilot_secrets: list[str],
114141
pilot_expiration_dates: list[int],
115-
) -> PilotCredentialsResponse:
116-
credentials_list = [
142+
) -> list[PilotCredentialsInfo]:
143+
return [
117144
PilotCredentialsInfo(
118145
pilot_stamp=pilot_stamp,
119146
pilot_secret=secret,
@@ -124,7 +151,22 @@ def create_pilot_credentials_response(
124151
)
125152
]
126153

127-
return PilotCredentialsResponse(pilot_credentials=credentials_list)
154+
155+
def create_secrets_response(
156+
pilot_secrets: list[str],
157+
pilot_expiration_dates: list[int],
158+
) -> list[PilotSecretsInfo]:
159+
return [
160+
PilotSecretsInfo(
161+
pilot_secret=secret,
162+
pilot_secret_expires_in=expires_in,
163+
)
164+
for secret, expires_in in zip(pilot_secrets, pilot_expiration_dates)
165+
]
166+
167+
168+
def create_stamp_response(pilot_stamps: list[str]) -> list[PilotStampInfo]:
169+
return [PilotStampInfo(pilot_stamp=stamp) for stamp in pilot_stamps]
128170

129171

130172
def get_registry_and_group_configuration(config: Config, vo: str):
@@ -212,6 +254,8 @@ async def register_new_pilots(
212254
pilots_that_already_exist = set(pilot_stamps) - set(
213255
literal_eval(e.detail)
214256
)
257+
else:
258+
raise ValueError("Bad internal error.")
215259
except AttributeError as e2:
216260
raise ValueError("Must be defined and a set string representation") from e2
217261

diracx-routers/pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,12 @@ auth = "diracx.routers.auth:router"
4646
config = "diracx.routers.configuration:router"
4747
health = "diracx.routers.health:router"
4848
jobs = "diracx.routers.jobs:router"
49+
pilots = "diracx.routers.pilots:router"
4950

5051
[project.entry-points."diracx.access_policies"]
5152
WMSAccessPolicy = "diracx.routers.jobs.access_policies:WMSAccessPolicy"
5253
SandboxAccessPolicy = "diracx.routers.jobs.access_policies:SandboxAccessPolicy"
54+
PilotCredentialsAccessPolicy = "diracx.routers.pilots.access_policies:PilotCredentialsAccessPolicy"
5355

5456
# Minimum version of the client supported
5557
[project.entry-points."diracx.min_client_version"]

diracx-routers/src/diracx/routers/auth/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from .authorize_code_flow import router as authorize_code_flow_router
66
from .device_flow import router as device_flow_router
77
from .management import router as management_router
8-
from .pilot import router as pilot_router
98
from .token import router as token_router
109
from .utils import has_properties
1110

@@ -14,6 +13,5 @@
1413
router.include_router(management_router)
1514
router.include_router(authorize_code_flow_router)
1615
router.include_router(token_router)
17-
router.include_router(pilot_router)
1816

1917
__all__ = ["has_properties", "verify_dirac_access_token"]

0 commit comments

Comments
 (0)