Skip to content

Commit 10205d6

Browse files
feat: Autodeletion of secrets after full use or expiration.
1 parent 3f013df commit 10205d6

File tree

6 files changed

+118
-46
lines changed

6 files changed

+118
-46
lines changed

diracx-core/src/diracx/core/exceptions.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -157,11 +157,6 @@ class SecretHasExpiredError(GenericError):
157157
tail = "has expired"
158158

159159

160-
class OverusedSecretError(GenericError):
161-
head = "Secret"
162-
tail = "too much used"
163-
164-
165160
class SecretAlreadyExistsError(GenericError):
166161
head = "Secret"
167162
tail = "already exists"

diracx-db/src/diracx/db/sql/pilot_agents/db.py

Lines changed: 50 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55

66
from sqlalchemy import DateTime, bindparam, insert, update
77
from sqlalchemy.exc import IntegrityError, OperationalError
8+
from sqlalchemy.sql import delete
89

910
from diracx.core.exceptions import (
1011
BadPilotCredentialsError,
1112
BadPilotVOError,
1213
CredentialsNotFoundError,
13-
OverusedSecretError,
1414
PilotAlreadyExistsError,
1515
PilotNotFoundError,
1616
SecretAlreadyExistsError,
@@ -47,6 +47,7 @@ async def increment_pilot_local_secret_and_last_time_use(
4747
res = await self.conn.execute(stmt)
4848

4949
if res.rowcount == 0:
50+
await self.conn.rollback()
5051
raise PilotNotFoundError(
5152
data={
5253
"pilot_stamp": pilot_stamp,
@@ -70,7 +71,13 @@ async def increment_global_secret_use(
7071
res = await self.conn.execute(stmt)
7172

7273
if res.rowcount == 0:
74+
await self.conn.rollback()
7375
raise SecretNotFoundError(data={"secret_id": str(secret_id)})
76+
if res.rowcount != 1:
77+
await self.conn.rollback()
78+
raise DBInBadStateError(
79+
detail="This should not happen. Pilot should have a secret, but is not found."
80+
)
7481

7582
async def verify_pilot_secret(
7683
self, pilot_stamp: str, pilot_hashed_secret: str
@@ -90,6 +97,7 @@ async def verify_pilot_secret(
9097

9198
# 3. Compare the secret_id
9299
if not secret["SecretID"] == pilot_credentials["PilotSecretID"]:
100+
93101
raise BadPilotCredentialsError(
94102
data={
95103
"pilot_stamp": pilot_stamp,
@@ -104,6 +112,16 @@ async def verify_pilot_secret(
104112
# Convert the timezone, TODO: Change with #454: https://github.com/DIRACGrid/diracx/pull/454
105113
expiration = secret["SecretExpirationDate"].replace(tzinfo=timezone.utc)
106114
if expiration < now:
115+
116+
try:
117+
await self.delete_secrets_bulk([secret["SecretID"]])
118+
except SecretNotFoundError as e:
119+
await self.conn.rollback()
120+
121+
raise DBInBadStateError(
122+
detail="This should not happen. Pilot should have a secret, but not found."
123+
) from e
124+
107125
raise SecretHasExpiredError(
108126
data={
109127
"pilot_hashed_secret": pilot_hashed_secret,
@@ -112,42 +130,40 @@ async def verify_pilot_secret(
112130
}
113131
)
114132

115-
# 5. Verify the secret counter
116-
# 5.1 Only check if the SecretGlobalUseCountMax is defined
117-
# If not defined, there is an infinite use.
118-
if secret["SecretGlobalUseCountMax"]:
119-
# 5.2 Finite use, we check if we can still login
120-
if secret["SecretGlobalUseCount"] + 1 > secret["SecretGlobalUseCountMax"]:
121-
raise OverusedSecretError(
122-
data={
123-
"pilot_stamp" "pilot_hashed_secret": pilot_hashed_secret,
124-
"secret_global_use_count": secret["SecretGlobalUseCount"],
125-
"secret_global_use_count_max": secret[
126-
"SecretGlobalUseCountMax"
127-
],
128-
}
129-
)
130-
131-
# 6. Now the pilot is authorized, increment the counters (globally and locally).
133+
# 5. Now the pilot is authorized, increment the counters (globally and locally).
132134
try:
133-
# 6.1 Increment the local count
135+
# 5.1 Increment the local count
134136
await self.increment_pilot_local_secret_and_last_time_use(
135137
pilot_secret_id=pilot_credentials["PilotSecretID"],
136138
pilot_stamp=pilot_credentials["PilotStamp"],
137139
)
138140

139-
# 6.2 Increment the global count
141+
# 5.2 Increment the global count
140142
await self.increment_global_secret_use(
141143
secret_id=pilot_credentials["PilotSecretID"]
142144
)
143145
except Exception as e: # Generic, to catch it.
144146
# Should NOT happen
145147
# Wrapped in a try/catch to still catch in case of an error in the counters
146148
# Caught and raised here to avoid raising a 4XX error
149+
await self.conn.rollback()
150+
147151
raise DBInBadStateError(
148152
detail="This should not happen. Pilot has credentials, but has a corrupted secret."
149153
) from e
150154

155+
# 6. Delete all secrets if its count attained the secret_global_use_count_max
156+
if secret["SecretGlobalUseCountMax"]:
157+
if secret["SecretGlobalUseCount"] + 1 == secret["SecretGlobalUseCountMax"]:
158+
try:
159+
await self.delete_secrets_bulk([secret["SecretID"]])
160+
except SecretNotFoundError as e:
161+
# Should NOT happen
162+
await self.conn.rollback()
163+
raise DBInBadStateError(
164+
detail="This should not happen. Pilot has credentials, but has corrupted secret."
165+
) from e
166+
151167
async def add_pilots_bulk(
152168
self,
153169
pilot_stamps: list[str],
@@ -217,6 +233,20 @@ async def add_pilots_credentials_bulk(
217233
# Used later to add an expiration date to the credentials
218234
return secrets
219235

236+
async def delete_secrets_bulk(self, secret_ids: list[int]):
237+
"""Bulk delete secrets."""
238+
stmt = delete(PilotSecrets).where(PilotSecrets.secret_id.in_(secret_ids))
239+
240+
res = await self.conn.execute(stmt)
241+
242+
if res.rowcount != len(secret_ids):
243+
await self.conn.rollback()
244+
245+
raise SecretNotFoundError(data={"secrets": str(secret_ids)})
246+
247+
# To avoid raise condition
248+
await self.conn.commit()
249+
220250
async def insert_unique_secrets_bulk(
221251
self,
222252
hashed_secrets: list[str],

diracx-db/src/diracx/db/sql/pilot_agents/schema.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,10 @@ class PilotToSecretMapping(PilotAgentsDBBase):
9090

9191
# Primary key is (PilotSecretID, PilotStamp) pair
9292
pilot_secret_id = Column(
93-
"PilotSecretID", Integer, ForeignKey("PilotSecrets.SecretID"), primary_key=True
93+
"PilotSecretID",
94+
Integer,
95+
ForeignKey("PilotSecrets.SecretID", ondelete="CASCADE"),
96+
primary_key=True,
9497
)
9598
pilot_stamp = Column("PilotStamp", String(32), primary_key=True)
9699
# Different from global use: only counts how many a specific pilot used a specific secret

diracx-db/tests/pilot_agents/test_pilot_agents_db.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from diracx.core.exceptions import (
1010
BadPilotVOError,
1111
CredentialsNotFoundError,
12-
OverusedSecretError,
1312
PilotNotFoundError,
1413
SecretHasExpiredError,
1514
SecretNotFoundError,
@@ -223,7 +222,8 @@ async def test_create_pilot_and_verify_secret_too_much_secret_use(
223222
)
224223

225224
# Second login, should not work because maxed out at 1 try
226-
with pytest.raises(OverusedSecretError):
225+
# If the foreign key works, we should have "SecretNotFoundError"
226+
with pytest.raises(SecretNotFoundError):
227227
await pilot_agents_db.verify_pilot_secret(
228228
pilot_stamp=pilot_stamp,
229229
pilot_hashed_secret=pilot_hashed_secret,

diracx-routers/src/diracx/routers/auth/token.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
DiracHttpResponseError,
1515
ExpiredFlowError,
1616
InvalidCredentialsError,
17-
OverusedSecretError,
1817
PendingAuthorizationError,
1918
PilotNotFoundError,
2019
SecretHasExpiredError,
@@ -305,11 +304,6 @@ async def pilot_login(
305304
status_code=status.HTTP_401_UNAUTHORIZED,
306305
detail="bad pilot_secret",
307306
) from e
308-
except OverusedSecretError as e:
309-
raise HTTPException(
310-
status_code=status.HTTP_401_UNAUTHORIZED,
311-
detail="secret has been overused",
312-
) from e
313307
except SecretHasExpiredError as e:
314308
raise HTTPException(
315309
status_code=status.HTTP_401_UNAUTHORIZED,

diracx-routers/tests/auth/test_pilot_auth.py

Lines changed: 62 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,19 @@ async def test_create_pilot_and_verify_secret(test_client):
7474
pilot_secret_expiration_dates=[expiration_date],
7575
)
7676

77+
# ----------------- Wrong password -----------------
78+
body = {
79+
"pilot_stamp": pilot_stamp,
80+
"pilot_secret": "My 1ncr3d1bl3 t0k3n",
81+
}
82+
83+
r = test_client.post("/api/auth/pilot-login", json=body)
84+
85+
assert r.status_code == 401, r.json()
86+
assert r.json()["detail"] == "bad pilot_secret"
87+
88+
# ----------------- Good password -----------------
89+
7790
body = {"pilot_stamp": pilot_stamp, "pilot_secret": secret}
7891

7992
r = test_client.post("/api/auth/pilot-login", json=body)
@@ -108,17 +121,6 @@ async def test_create_pilot_and_verify_secret(test_client):
108121
assert r.status_code == 401, r.json()
109122
assert r.json()["detail"] == "Invalid JWT"
110123

111-
# ----------------- Wrong password -----------------
112-
body = {
113-
"pilot_stamp": pilot_stamp,
114-
"pilot_secret": "My 1ncr3d1bl3 t0k3n",
115-
}
116-
117-
r = test_client.post("/api/auth/pilot-login", json=body)
118-
119-
assert r.status_code == 401, r.json()
120-
assert r.json()["detail"] == "bad pilot_secret"
121-
122124
# ----------------- Wrong ID -----------------
123125
body = {"pilot_stamp": "It is a stamp", "pilot_secret": secret}
124126

@@ -178,7 +180,45 @@ async def test_create_pilot_and_verify_secret(test_client):
178180
r = test_client.post("/api/auth/pilot-login", json=body)
179181

180182
assert r.status_code == 401
181-
assert r.json()["detail"] == "secret has been overused"
183+
assert r.json()["detail"] == "bad credentials"
184+
185+
186+
async def test_expired_secret(test_client):
187+
188+
# see https://github.com/DIRACGrid/diracx/blob/78e00aa57f4191034dbf643c7ed2857a93b53f60/diracx-routers/tests/pilots/test_pilot_logger.py#L37
189+
db = test_client.app.dependency_overrides[PilotAgentsDB.transaction].args[0]
190+
191+
async with db as pilot_agents_db:
192+
pilot_stamp = "pilot-stamp"
193+
vo = "lhcb"
194+
# Register a pilot
195+
await pilot_agents_db.add_pilots_bulk(
196+
vo=vo,
197+
pilot_stamps=[pilot_stamp],
198+
grid_type="grid-type",
199+
)
200+
201+
secret = "AW0nd3rfulS3cr3t"
202+
pilot_hashed_secret = hash(secret)
203+
204+
# Add creds
205+
secrets_added = await pilot_agents_db.add_pilots_credentials_bulk(
206+
pilot_stamps=[pilot_stamp],
207+
pilot_hashed_secrets=[pilot_hashed_secret],
208+
pilot_secret_use_count_max=1, # Important later
209+
vo=vo,
210+
)
211+
212+
assert len(secrets_added) == 1
213+
214+
secret_added = secrets_added[0]
215+
216+
expiration_date = secret_added["SecretCreationDate"] + timedelta(seconds=2)
217+
218+
await pilot_agents_db.set_secret_expirations_bulk(
219+
secret_ids=[secret_added["SecretID"]],
220+
pilot_secret_expiration_dates=[expiration_date],
221+
)
182222

183223
# ----------------- Secret that expired -----------------
184224
sleep(2)
@@ -190,6 +230,16 @@ async def test_create_pilot_and_verify_secret(test_client):
190230
assert r.status_code == 401
191231
assert r.json()["detail"] == "secret expired"
192232

233+
# ----------------- Secret that expired, but reused -----------------
234+
# Should be deleted by the verify_pilot_secret, because deleted
235+
236+
body = {"pilot_stamp": pilot_stamp, "pilot_secret": secret}
237+
238+
r = test_client.post("/api/auth/pilot-login", json=body)
239+
240+
assert r.status_code == 401
241+
assert r.json()["detail"] == "bad credentials"
242+
193243

194244
async def test_create_pilots_with_credentials(normal_test_client):
195245
# Lots of request, to validate that it returns the credentials in the same order as the input references

0 commit comments

Comments
 (0)