Skip to content

Commit 5c5dba7

Browse files
fix: Moved secrets table to the AuthDB
1 parent f02aca1 commit 5c5dba7

File tree

18 files changed

+230
-560
lines changed

18 files changed

+230
-560
lines changed

diracx-db/src/diracx/db/sql/auth/db.py

Lines changed: 156 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,26 @@
11
from __future__ import annotations
22

33
import secrets
4+
from typing import Any
45

5-
from sqlalchemy import insert, select, update
6+
from sqlalchemy import DateTime, bindparam, delete, insert, select, update
67
from sqlalchemy.exc import IntegrityError, NoResultFound
78
from uuid_utils import UUID, uuid7
89

910
from diracx.core.exceptions import (
1011
AuthorizationError,
12+
SecretNotFoundError,
1113
TokenNotFoundError,
1214
)
15+
from diracx.core.models import PilotSecretConstraints, SearchSpec, SortSpec
1316
from diracx.db.sql.utils import BaseSQLDB, hash, substract_date
17+
from diracx.db.sql.utils.functions import utcnow
1418

1519
from .schema import (
1620
AuthorizationFlows,
1721
DeviceFlows,
1822
FlowStatus,
23+
PilotSecrets,
1924
RefreshTokens,
2025
RefreshTokenStatus,
2126
)
@@ -264,3 +269,153 @@ async def revoke_user_refresh_tokens(self, subject):
264269
.where(RefreshTokens.sub == subject)
265270
.values(status=RefreshTokenStatus.REVOKED)
266271
)
272+
273+
# ------------- Pilot secrets mechanism -------------
274+
275+
async def insert_unique_secrets(
276+
self,
277+
hashed_secrets: list[bytes],
278+
secret_global_use_count_max: int | None = 1,
279+
secret_constraints: dict[bytes, PilotSecretConstraints] = {},
280+
):
281+
"""Bulk insert secrets.
282+
283+
Raises:
284+
- NotImplementedError if we have an IntegrityError not caught
285+
286+
"""
287+
values = [
288+
{
289+
"SecretUUID": str(uuid7()),
290+
"SecretRemainingUseCount": secret_global_use_count_max,
291+
"HashedSecret": hashed_secret,
292+
"SecretConstraints": secret_constraints.get(hashed_secret, {}),
293+
}
294+
for hashed_secret in hashed_secrets
295+
]
296+
297+
stmt = insert(PilotSecrets).values(values)
298+
await self.conn.execute(stmt)
299+
300+
async def delete_secrets(self, secret_uuids: list[str]):
301+
"""Bulk delete secrets.
302+
303+
Raises SecretNotFoundError if one of the secret was not found.
304+
"""
305+
stmt = delete(PilotSecrets).where(PilotSecrets.secret_uuid.in_(secret_uuids))
306+
307+
res = await self.conn.execute(stmt)
308+
309+
if res.rowcount != len(secret_uuids):
310+
raise SecretNotFoundError(
311+
"At least one of the secret has not been deleted."
312+
)
313+
314+
# We NEED to commit here, because we will raise an error after this function
315+
await self.conn.commit()
316+
317+
async def update_pilot_secret_use_time(self, secret_uuid: str) -> None:
318+
"""Updates when a pilot uses a secret.
319+
320+
Raises PilotNotFoundError if the pilot does not exist
321+
322+
"""
323+
# Prepare the update statement
324+
stmt = (
325+
update(PilotSecrets)
326+
.values(
327+
pilot_secret_use_date=utcnow(),
328+
secret_remaining_use_count=PilotSecrets.secret_remaining_use_count - 1,
329+
)
330+
.where(PilotSecrets.secret_uuid == secret_uuid)
331+
)
332+
333+
# Execute the update using the connection
334+
res = await self.conn.execute(stmt)
335+
336+
if res.rowcount == 0:
337+
raise SecretNotFoundError("Unknown secret")
338+
339+
async def update_pilot_secrets_constraints(
340+
self, hashed_secrets_to_pilot_stamps_mapping: list[dict[str, Any]]
341+
):
342+
"""Bulk associate pilots with secrets by updating theirs constraints.
343+
344+
Important: We have to provide the updated constraints.
345+
346+
Raises:
347+
- PilotNotFoundError if one of the pilot does not exist
348+
- NotImplementedError if at least of the pilot
349+
350+
"""
351+
# Better to give as a parameter pilot to secret associations, rather than associating here.
352+
353+
stmt = (
354+
update(PilotSecrets)
355+
.where(PilotSecrets.hashed_secret == bindparam("PilotHashedSecret"))
356+
.values({"SecretConstraints": bindparam("PilotSecretConstraints")})
357+
)
358+
359+
try:
360+
await self.conn.execute(stmt, hashed_secrets_to_pilot_stamps_mapping)
361+
except IntegrityError as e:
362+
if "foreign key" in str(e.orig).lower():
363+
raise SecretNotFoundError(
364+
detail="at least one of these secrets does not exist",
365+
) from e
366+
raise NotImplementedError(f"This error is not caught: {str(e.orig)}") from e
367+
368+
async def set_secret_expirations(
369+
self, secret_uuids: list[str], pilot_secret_expiration_dates: list[DateTime]
370+
):
371+
"""Bulk set expiration dates to secrets.
372+
373+
Raises:
374+
- SecretNotFoundError if one of the secret_uuid is not associated with a secret.
375+
- NotImplementedError if a integrity error is not caught.
376+
-
377+
378+
"""
379+
values = [
380+
{"b_SecretUUID": secret_uuid, "SecretExpirationDate": pilot_secret}
381+
for secret_uuid, pilot_secret in zip(
382+
secret_uuids, pilot_secret_expiration_dates
383+
)
384+
]
385+
386+
# Prepare the update statement
387+
stmt = (
388+
update(PilotSecrets)
389+
.where(PilotSecrets.secret_uuid == bindparam("b_SecretUUID"))
390+
.values({"SecretExpirationDate": bindparam("SecretExpirationDate")})
391+
)
392+
393+
try:
394+
await self.conn.execute(stmt, values)
395+
except IntegrityError as e:
396+
if "foreign key" in str(e.orig).lower():
397+
raise SecretNotFoundError(
398+
detail="at least one of these secrets does not exist",
399+
) from e
400+
raise NotImplementedError(f"This error is not caught: {str(e.orig)}") from e
401+
402+
async def search_secrets(
403+
self,
404+
parameters: list[str] | None,
405+
search: list[SearchSpec],
406+
sorts: list[SortSpec],
407+
*,
408+
distinct: bool = False,
409+
per_page: int = 100,
410+
page: int | None = None,
411+
) -> tuple[int, list[dict[Any, Any]]]:
412+
"""Search for secrets in the database."""
413+
return await self._search(
414+
table=PilotSecrets,
415+
parameters=parameters,
416+
search=search,
417+
sorts=sorts,
418+
distinct=distinct,
419+
per_page=per_page,
420+
page=page,
421+
)

diracx-db/src/diracx/db/sql/auth/schema.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,13 @@
33
from enum import Enum, auto
44

55
from sqlalchemy import (
6+
BINARY,
67
JSON,
8+
DateTime,
79
Index,
10+
SmallInteger,
811
String,
12+
UniqueConstraint,
913
Uuid,
1014
)
1115
from sqlalchemy.orm import declarative_base
@@ -99,3 +103,28 @@ class RefreshTokens(Base):
99103
sub = Column("Sub", String(256), index=True)
100104

101105
__table_args__ = (Index("index_status_sub", status, sub),)
106+
107+
108+
class PilotSecrets(Base):
109+
__tablename__ = "PilotSecrets"
110+
111+
secret_uuid = Column("SecretUUID", Uuid(as_uuid=False), primary_key=True)
112+
113+
hashed_secret = Column("HashedSecret", BINARY(32))
114+
# Global count
115+
# Null: Infinite use
116+
secret_remaining_use_count = NullColumn(
117+
"SecretRemainingUseCount", SmallInteger, default=1
118+
)
119+
secret_expiration_date = NullColumn("SecretExpirationDate", DateTime(timezone=True))
120+
# To authorize only specific pilots to access a secret
121+
# The constraint format follows diracx.code.models.PilotSecretConstraints
122+
secret_constraints = NullColumn("SecretConstraints", JSON)
123+
124+
# If a date is set, then it used a secret (acts also like a "PilotUsedSecret" field)
125+
pilot_secret_use_date = NullColumn("PilotSecretUseDate", DateTime(timezone=True))
126+
127+
__table_args__ = (
128+
UniqueConstraint("HashedSecret", name="uq_hashed_secret"),
129+
Index("HashedSecret", "HashedSecret"),
130+
)

0 commit comments

Comments
 (0)