Skip to content

Commit 56daac5

Browse files
refactor: Moving db parts to the logic, and some fixes
1 parent 79a11cf commit 56daac5

File tree

6 files changed

+321
-279
lines changed

6 files changed

+321
-279
lines changed

diracx-cli/src/diracx/cli/internal/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def add_user(
182182

183183

184184
@app.command()
185-
def set_user_as_pilot(
185+
def set_user_as_pilot_user(
186186
config_repo: str,
187187
*,
188188
vo: Annotated[str, typer.Option()],

diracx-db/src/diracx/db/sql/pilot_agents/db.py

Lines changed: 5 additions & 178 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,13 @@
88
from sqlalchemy.sql import delete, insert, select, update
99

1010
from diracx.core.exceptions import (
11-
BadPilotCredentialsError,
12-
BadPilotVOError,
1311
CredentialsAlreadyExistError,
1412
CredentialsNotFoundError,
1513
InvalidQueryError,
1614
PilotAlreadyAssociatedWithJobError,
1715
PilotJobsNotFoundError,
1816
PilotNotFoundError,
1917
SecretAlreadyExistsError,
20-
SecretHasExpiredError,
2118
SecretNotFoundError,
2219
)
2320
from diracx.core.models import PilotFieldsMapping, SearchSpec, SortSpec
@@ -96,105 +93,6 @@ async def increment_global_secret_use(
9693
detail="This should not happen. Pilot should have a secret, but is not found."
9794
)
9895

99-
async def verify_pilot_secret(
100-
self, pilot_stamp: str, pilot_hashed_secret: str
101-
) -> None:
102-
"""Verify that a pilot can login with the given credentials."""
103-
# 1. Get the pilot to secret association
104-
pilots_credentials = await self.get_pilot_credentials_by_stamp([pilot_stamp])
105-
106-
# 2. Get the pilot secret itself
107-
secrets = await self.get_secrets_by_hashed_secrets_bulk([pilot_hashed_secret])
108-
secret = secrets[0] # Semantic, assured by fetch_records_bulk_or_raises
109-
110-
matches = [
111-
pilot_credential
112-
for pilot_credential in pilots_credentials
113-
if secret["SecretID"] == pilot_credential["PilotSecretID"]
114-
]
115-
116-
# 3. Compare the secret_id
117-
if len(matches) == 0:
118-
119-
raise BadPilotCredentialsError(
120-
data={
121-
"pilot_stamp": pilot_stamp,
122-
"pilot_hashed_secret": pilot_hashed_secret,
123-
"real_hashed_secret": secret["HashedSecret"],
124-
"pilot_secret_id[]": str(
125-
[
126-
pilot_credential["PilotSecretID"]
127-
for pilot_credential in pilots_credentials
128-
]
129-
),
130-
"secret_id": secret["SecretID"],
131-
"test": str(pilots_credentials),
132-
}
133-
)
134-
elif len(matches) > 1:
135-
136-
raise DBInBadStateError(
137-
detail="This should not happen. Duplicates in the database."
138-
)
139-
pilot_credentials = matches[0] # Semantic
140-
141-
# 4. Check if the secret is expired
142-
now = datetime.now(tz=timezone.utc)
143-
# Convert the timezone, TODO: Change with #454: https://github.com/DIRACGrid/diracx/pull/454
144-
expiration = secret["SecretExpirationDate"].replace(tzinfo=timezone.utc)
145-
if expiration < now:
146-
147-
try:
148-
await self.delete_secrets_bulk([secret["SecretID"]])
149-
except SecretNotFoundError as e:
150-
await self.conn.rollback()
151-
152-
raise DBInBadStateError(
153-
detail="This should not happen. Pilot should have a secret, but not found."
154-
) from e
155-
156-
raise SecretHasExpiredError(
157-
data={
158-
"pilot_hashed_secret": pilot_hashed_secret,
159-
"now": str(now),
160-
"expiration_date": secret["SecretExpirationDate"],
161-
}
162-
)
163-
164-
# 5. Now the pilot is authorized, increment the counters (globally and locally).
165-
try:
166-
# 5.1 Increment the local count
167-
await self.increment_pilot_local_secret_and_last_time_use(
168-
pilot_secret_id=pilot_credentials["PilotSecretID"],
169-
pilot_stamp=pilot_credentials["PilotStamp"],
170-
)
171-
172-
# 5.2 Increment the global count
173-
await self.increment_global_secret_use(
174-
secret_id=pilot_credentials["PilotSecretID"]
175-
)
176-
except Exception as e: # Generic, to catch it.
177-
# Should NOT happen
178-
# Wrapped in a try/catch to still catch in case of an error in the counters
179-
# Caught and raised here to avoid raising a 4XX error
180-
await self.conn.rollback()
181-
182-
raise DBInBadStateError(
183-
detail="This should not happen. Pilot has credentials, but has a corrupted secret."
184-
) from e
185-
186-
# 6. Delete all secrets if its count attained the secret_global_use_count_max
187-
if secret["SecretGlobalUseCountMax"]:
188-
if secret["SecretGlobalUseCount"] + 1 == secret["SecretGlobalUseCountMax"]:
189-
try:
190-
await self.delete_secrets_bulk([secret["SecretID"]])
191-
except SecretNotFoundError as e:
192-
# Should NOT happen
193-
await self.conn.rollback()
194-
raise DBInBadStateError(
195-
detail="This should not happen. Pilot has credentials, but has corrupted secret."
196-
) from e
197-
19896
async def add_pilots_bulk(
19997
self,
20098
pilot_stamps: list[str],
@@ -295,11 +193,6 @@ async def associate_pilots_with_secrets_bulk(
295193
"""Bulk associate pilots with secrets. Raises an error in case of a Integrity violation."""
296194
# Better to give as a parameter pilot to secret associations, rather than associating here.
297195

298-
# First verify that pilots can access a certain secret
299-
await self.verify_that_pilot_can_access_secret_bulk(
300-
pilot_to_secret_id_mapping_values
301-
)
302-
303196
stmt = insert(PilotToSecretMapping).values(pilot_to_secret_id_mapping_values)
304197

305198
try:
@@ -324,48 +217,28 @@ async def associate_pilots_with_secrets_bulk(
324217
) from e
325218
raise NotImplementedError(f"This error is not caught: {str(e.orig)}") from e
326219

327-
async def associate_pilot_with_jobs(self, pilot_stamp: str, job_ids: list[int]):
220+
async def associate_pilot_with_jobs(self, job_to_pilot_mapping: list[dict]):
328221
"""Associate a pilot with jobs. Raises an error if the pilot does not exist and in case of a IntegrityError.
329222
330223
**Important note**: We don't verify if a job exists in the JobDB
331224
"""
332-
pilot_ids = await self.get_pilot_ids_by_stamps([pilot_stamp])
333-
# Semantic assured by fetch_records_bulk_or_raises
334-
pilot_id = pilot_ids[0]
335-
336-
now = datetime.now(tz=timezone.utc)
337-
338-
# Prepare the list of dictionaries for bulk insertion
339-
values = [
340-
{"PilotID": pilot_id, "JobID": job_id, "StartTime": now}
341-
for job_id in job_ids
342-
]
343-
344225
# Insert multiple rows in a single execute call
345-
stmt = insert(JobToPilotMapping).values(values)
226+
stmt = insert(JobToPilotMapping).values(job_to_pilot_mapping)
346227

347228
try:
348229
res = await self.conn.execute(stmt)
349230
except IntegrityError as e:
350231
raise PilotAlreadyAssociatedWithJobError(
351-
data={"pilot_stamp": pilot_stamp, "job_ids": str(job_ids)}
232+
data={"job_to_pilot_mapping": str(job_to_pilot_mapping)}
352233
) from e
353234

354-
if res.rowcount != len(job_ids):
235+
if res.rowcount != len(job_to_pilot_mapping):
355236
# If doubles
356237
await self.conn.rollback()
357238
raise PilotJobsNotFoundError(
358-
data={"pilot_stamp": pilot_stamp, "job_ids": str(job_ids)}
239+
data={"job_to_pilot_mapping": str(job_to_pilot_mapping)}
359240
)
360241

361-
async def get_pilot_jobs_ids_by_stamp(self, pilot_stamp: str) -> list[int]:
362-
"""Fetch pilot jobs by stamp."""
363-
pilot_ids = await self.get_pilot_ids_by_stamps([pilot_stamp])
364-
# Semantic assured by fetch_records_bulk_or_raises
365-
pilot_id = pilot_ids[0]
366-
367-
return await self.get_pilot_jobs_ids_by_pilot_id(pilot_id)
368-
369242
async def update_pilot_fields_bulk(
370243
self, pilot_stamps_to_fields_mapping: list[PilotFieldsMapping]
371244
):
@@ -406,52 +279,6 @@ async def update_pilot_fields_bulk(
406279

407280
await self.conn.commit()
408281

409-
async def verify_that_pilot_can_access_secret_bulk(
410-
self, pilot_to_secret_id_mapping_values: list[dict[str, Any]]
411-
):
412-
# 1. Extract unique pilot_stamps and secret_ids
413-
pilot_stamps = [
414-
entry["PilotStamp"] for entry in pilot_to_secret_id_mapping_values
415-
]
416-
secret_ids = [
417-
entry["PilotSecretID"] for entry in pilot_to_secret_id_mapping_values
418-
]
419-
420-
# 2. Bulk fetch pilot and secret info
421-
pilots = await self.get_pilots_by_stamp_bulk(pilot_stamps)
422-
secrets = await self.get_secrets_by_secret_ids_bulk(secret_ids)
423-
424-
# 3. Build lookup maps
425-
pilot_vo_map = {pilot["PilotStamp"]: pilot["VO"] for pilot in pilots}
426-
secret_vo_map = {secret["SecretID"]: secret["SecretVO"] for secret in secrets}
427-
428-
# 4. Validate access
429-
bad_mapping = []
430-
431-
for mapping in pilot_to_secret_id_mapping_values:
432-
pilot_stamp = mapping["PilotStamp"]
433-
secret_id = mapping["PilotSecretID"]
434-
435-
pilot_vo = pilot_vo_map[pilot_stamp]
436-
secret_vo = secret_vo_map[secret_id]
437-
438-
# If secret_vo is set to NULL, everybody can access it
439-
if not secret_vo:
440-
continue
441-
442-
# Access allowed only if VOs match or secret_vo is open (None)
443-
if secret_vo is not None and pilot_vo != secret_vo:
444-
bad_mapping.append(
445-
{
446-
"pilot_stamp": pilot_stamp,
447-
"given_vo": pilot_vo,
448-
"expected_vo": secret_vo,
449-
}
450-
)
451-
452-
if bad_mapping:
453-
raise BadPilotVOError(data={"bad_mapping": str(bad_mapping)})
454-
455282
async def set_secret_expirations_bulk(
456283
self, secret_ids: list[int], pilot_secret_expiration_dates: list[DateTime]
457284
):

0 commit comments

Comments
 (0)