Skip to content

Commit 3f7c632

Browse files
feat: Adding duration to secrets
1 parent 8aff5d9 commit 3f7c632

File tree

6 files changed

+121
-11
lines changed

6 files changed

+121
-11
lines changed

diracx-core/src/diracx/core/settings.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ class AuthSettings(ServiceSettingsBase):
161161
allowed_redirects: list[str] = []
162162
device_flow_expiration_seconds: int = 600
163163
authorization_flow_expiration_seconds: int = 300
164+
pilot_secret_expire_seconds: int = 600
164165

165166
# State key is used to encrypt/decrypt the state dict passed to the IAM
166167
state_key: FernetKey

diracx-db/src/diracx/db/sql/pilot_agents/db.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from datetime import datetime, timezone
44

5-
from sqlalchemy import insert, select, update
5+
from sqlalchemy import DateTime, insert, select, update
66
from sqlalchemy.exc import IntegrityError, NoResultFound
77

88
from diracx.core.exceptions import (
@@ -85,8 +85,13 @@ async def verify_pilot_secret(
8585

8686
stmt = (
8787
select(PilotRegistrations)
88+
.with_for_update()
8889
.where(PilotRegistrations.pilot_hashed_secret == pilot_hashed_secret)
8990
.where(PilotRegistrations.pilot_id == pilot_id)
91+
.where(
92+
PilotRegistrations.pilot_secret_expiration_date
93+
> datetime.now(tz=timezone.utc)
94+
)
9095
)
9196

9297
# Execute the request
@@ -95,12 +100,16 @@ async def verify_pilot_secret(
95100
result = res.fetchone()
96101

97102
if result is None:
98-
raise AuthorizationError(detail="bad pilot_id / pilot_secret")
103+
raise AuthorizationError(
104+
detail="bad pilot_id / pilot_secret or secret has expired"
105+
)
99106

100107
# Increment the count
101108
await self.increment_pilot_secret_use(pilot_id=pilot_id)
102109

103-
async def add_pilot_credentials(self, pilot_id: int, pilot_hashed_secret: str):
110+
async def add_pilot_credentials(
111+
self, pilot_id: int, pilot_hashed_secret: str
112+
) -> datetime:
104113

105114
stmt = insert(PilotRegistrations).values(
106115
pilot_id=pilot_id, pilot_hashed_secret=pilot_hashed_secret
@@ -116,6 +125,22 @@ async def add_pilot_credentials(self, pilot_id: int, pilot_hashed_secret: str):
116125
pilot_id=pilot_id, detail="this pilot has already credentials"
117126
) from e
118127

128+
added_creds = await self.get_pilot_creds_by_id(pilot_id)
129+
130+
return added_creds["PilotSecretCreationDate"]
131+
132+
async def set_pilot_credentials_expiration(
133+
self, pilot_id: int, pilot_secret_expiration_date: DateTime
134+
):
135+
# Prepare the update statement
136+
stmt = (
137+
update(PilotRegistrations)
138+
.values(pilot_secret_expiration_date=pilot_secret_expiration_date)
139+
.where(PilotRegistrations.pilot_id == pilot_id)
140+
)
141+
142+
await self.conn.execute(stmt)
143+
119144
async def fetch_all_pilots(self):
120145
stmt = select(PilotRegistrations).with_for_update()
121146
result = await self.conn.execute(stmt)
@@ -134,3 +159,12 @@ async def get_pilot_by_reference(self, pilot_ref: str):
134159

135160
# We assume it is unique...
136161
return dict((await self.conn.execute(stmt)).one()._mapping)
162+
163+
async def get_pilot_creds_by_id(self, pilot_id: int):
164+
stmt = (
165+
select(PilotRegistrations)
166+
.with_for_update()
167+
.where(PilotRegistrations.pilot_id == pilot_id)
168+
)
169+
170+
return dict((await self.conn.execute(stmt)).one()._mapping)

diracx-db/src/diracx/db/sql/pilot_agents/schema.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from sqlalchemy import DateTime, Double, ForeignKey, Index, Integer, String, Text
44
from sqlalchemy.orm import declarative_base
55

6-
from ..utils import Column, EnumBackedBool, NullColumn
6+
from ..utils import Column, DateNowColumn, EnumBackedBool, NullColumn
77

88
PilotAgentsDBBase = declarative_base()
99

@@ -64,3 +64,7 @@ class PilotRegistrations(PilotAgentsDBBase):
6464
)
6565
pilot_hashed_secret = Column("PilotHashedSecret", String(64))
6666
pilot_secret_use_count = Column("PilotSecretUseCount", Integer, default=0)
67+
pilot_secret_creation_time = DateNowColumn("PilotSecretCreationDate")
68+
pilot_secret_expiration_date = NullColumn(
69+
"PilotSecretExpirationDate", DateTime(timezone=True)
70+
)

diracx-db/tests/pilot_agents/test_pilot_agents_db.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
from __future__ import annotations
22

3+
from datetime import timedelta
4+
from time import sleep
5+
36
import pytest
47
from sqlalchemy.exc import NoResultFound
58

@@ -74,10 +77,16 @@ async def test_create_pilot_and_verify_secret(pilot_agents_db: PilotAgentsDB):
7477
pilot_hashed_secret = hash(secret)
7578

7679
# Add creds
77-
await pilot_agents_db.add_pilot_credentials(
80+
date_added = await pilot_agents_db.add_pilot_credentials(
7881
pilot_id=pilot_id, pilot_hashed_secret=pilot_hashed_secret
7982
)
8083

84+
expiration_date = date_added + timedelta(seconds=10)
85+
86+
await pilot_agents_db.set_pilot_credentials_expiration(
87+
pilot_id=pilot_id, pilot_secret_expiration_date=expiration_date
88+
)
89+
8190
assert secret is not None
8291

8392
await pilot_agents_db.verify_pilot_secret(
@@ -94,3 +103,46 @@ async def test_create_pilot_and_verify_secret(pilot_agents_db: PilotAgentsDB):
94103
pilot_job_reference="I am a spider",
95104
pilot_hashed_secret=pilot_hashed_secret,
96105
)
106+
107+
108+
async def test_create_pilot_and_verify_secret_with_delay(
109+
pilot_agents_db: PilotAgentsDB,
110+
):
111+
112+
async with pilot_agents_db as pilot_agents_db:
113+
pilot_reference = "pilot-reference-test"
114+
# Register a pilot
115+
await pilot_agents_db.add_pilot_references(
116+
vo="lhcb",
117+
pilot_ref=[pilot_reference],
118+
grid_type="grid-type",
119+
)
120+
121+
pilot = await pilot_agents_db.get_pilot_by_reference(pilot_reference)
122+
123+
pilot_id = pilot["PilotID"]
124+
125+
secret = "AW0nd3rfulS3cr3t"
126+
pilot_hashed_secret = hash(secret)
127+
128+
# Add creds
129+
date_added = await pilot_agents_db.add_pilot_credentials(
130+
pilot_id=pilot_id, pilot_hashed_secret=pilot_hashed_secret
131+
)
132+
133+
expiration_date = date_added + timedelta(seconds=1)
134+
135+
await pilot_agents_db.set_pilot_credentials_expiration(
136+
pilot_id=pilot_id, pilot_secret_expiration_date=expiration_date
137+
)
138+
139+
assert secret is not None
140+
141+
# So that the secret expires
142+
sleep(3)
143+
144+
with pytest.raises(AuthorizationError):
145+
await pilot_agents_db.verify_pilot_secret(
146+
pilot_job_reference=pilot_reference,
147+
pilot_hashed_secret=pilot_hashed_secret,
148+
)

diracx-logic/src/diracx/logic/auth/pilot.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from __future__ import annotations
22

3+
from datetime import timedelta
34
from secrets import token_hex
45

6+
from diracx.core.settings import AuthSettings
57
from diracx.db.sql import PilotAgentsDB
68

79
# TODO: Move this hash function in diracx-logic, and rename it
@@ -13,23 +15,32 @@ def generate_pilot_secret() -> str:
1315
return token_hex(32)
1416

1517

16-
async def add_pilot_credentials(pilot_id: int, pilot_db: PilotAgentsDB) -> str:
18+
async def add_pilot_credentials(
19+
pilot_id: int, pilot_db: PilotAgentsDB, settings: AuthSettings
20+
) -> str:
1721

1822
# Get a random string
1923
# Can be customized
2024
random_secret = generate_pilot_secret()
2125

2226
hashed_secret = hash(random_secret)
2327

24-
await pilot_db.add_pilot_credentials(
28+
date_added = await pilot_db.add_pilot_credentials(
2529
pilot_id=pilot_id, pilot_hashed_secret=hashed_secret
2630
)
2731

32+
# Helps compatibility between sql engines
33+
await pilot_db.set_pilot_credentials_expiration(
34+
pilot_id=pilot_id,
35+
pilot_secret_expiration_date=date_added # type: ignore
36+
+ timedelta(seconds=settings.pilot_secret_expire_seconds),
37+
)
38+
2839
return random_secret
2940

3041

3142
def generate_pilot_scope(pilot: dict) -> str:
32-
return f"vo:{pilot['VO']}"
43+
return f"vo:{pilot['VO']} property:LimitedDelegation property:GenericPilot"
3344

3445

3546
async def try_login(

diracx-routers/tests/auth/test_pilot_auth.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
from datetime import timedelta
4+
35
import pytest
46

57
from diracx.db.sql.pilot_agents.db import PilotAgentsDB
@@ -54,10 +56,16 @@ async def test_create_pilot_and_verify_secret(test_client):
5456
pilot_id = pilot["PilotID"]
5557

5658
# Add credentials to this pilot
57-
await pilot_agents_db.add_pilot_credentials(
59+
date_added = await pilot_agents_db.add_pilot_credentials(
5860
pilot_id=pilot_id, pilot_hashed_secret=pilot_hashed_secret
5961
)
6062

63+
expiration_date = date_added + timedelta(seconds=2)
64+
65+
await pilot_agents_db.set_pilot_credentials_expiration(
66+
pilot_id=pilot_id, pilot_secret_expiration_date=expiration_date
67+
)
68+
6169
request_data = {"pilot_job_reference": pilot_reference, "pilot_secret": secret}
6270

6371
r = test_client.post(
@@ -66,7 +74,7 @@ async def test_create_pilot_and_verify_secret(test_client):
6674
headers={"Content-Type": "application/json"},
6775
)
6876

69-
assert r.status_code == 200
77+
assert r.status_code == 200, r.json()
7078

7179
access_token = r.json()["access_token"]
7280
refresh_token = r.json()["refresh_token"]
@@ -109,7 +117,7 @@ async def test_create_pilot_and_verify_secret(test_client):
109117
)
110118

111119
assert r.status_code == 401, r.json()
112-
assert r.json()["detail"] == "bad pilot_id / pilot_secret"
120+
assert r.json()["detail"] == "bad pilot_id / pilot_secret or secret has expired"
113121

114122
# ----------------- Wrong ID -----------------
115123
request_data = {"pilot_job_reference": "It is a reference", "pilot_secret": secret}

0 commit comments

Comments
 (0)