Skip to content

Commit 88ac74e

Browse files
fix: Moving pilot auth to auth endpoint, and fixes
1 parent b39eff3 commit 88ac74e

File tree

8 files changed

+134
-62
lines changed

8 files changed

+134
-62
lines changed

diracx-db/src/diracx/db/sql/pilot_agents/db.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ async def add_pilot_references(
5656

5757
async def increment_pilot_secret_use(
5858
self,
59-
pilot_reference: str,
59+
pilot_id: int,
6060
) -> None:
6161

6262
# Prepare the update statement
@@ -65,26 +65,22 @@ async def increment_pilot_secret_use(
6565
.values(
6666
pilot_secret_use_count=PilotRegistrations.pilot_secret_use_count + 1
6767
)
68-
.where(PilotRegistrations.pilot_id == PilotAgents.pilot_id)
69-
.where(PilotAgents.pilot_job_reference == pilot_reference)
68+
.where(PilotRegistrations.pilot_id == pilot_id)
7069
)
7170

7271
# Execute the update using the connection
7372
res = await self.conn.execute(stmt)
7473

7574
if res.rowcount == 0:
76-
raise PilotNotFoundError(pilot_ref=pilot_reference)
75+
raise PilotNotFoundError(pilot_id=pilot_id)
7776

78-
async def verify_pilot_secret(
79-
self, pilot_reference: str, pilot_secret: str
80-
) -> None:
77+
async def verify_pilot_secret(self, pilot_id: int, pilot_secret: str) -> None:
8178
hashed_secret = hash(pilot_secret)
8279

8380
stmt = (
8481
select(PilotRegistrations)
85-
.join(PilotAgents, PilotRegistrations.pilot_id == PilotAgents.pilot_id)
8682
.where(PilotRegistrations.pilot_hashed_secret == hashed_secret)
87-
.where(PilotAgents.pilot_job_reference == pilot_reference)
83+
.where(PilotRegistrations.pilot_id == pilot_id)
8884
)
8985

9086
# Execute the request
@@ -96,7 +92,7 @@ async def verify_pilot_secret(
9692
raise AuthorizationError(detail="bad pilot_reference / pilot_secret")
9793

9894
# Increment the count
99-
await self.increment_pilot_secret_use(pilot_reference=pilot_reference)
95+
await self.increment_pilot_secret_use(pilot_id=pilot_id)
10096
return
10197

10298
async def register_new_pilot(

diracx-db/tests/pilot_agents/test_pilot_agents_db.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,14 @@ async def test_create_pilot_and_verify_secret(pilot_agents_db: PilotAgentsDB):
7373
assert secret is not None
7474

7575
await pilot_agents_db.verify_pilot_secret(
76-
pilot_reference=pilot_ref, pilot_secret=secret
76+
pilot_id=new_pilot_id, pilot_secret=secret
7777
)
7878

7979
with pytest.raises(AuthorizationError):
8080
await pilot_agents_db.verify_pilot_secret(
81-
pilot_reference=pilot_ref, pilot_secret="I love stawberries :)"
81+
pilot_id=new_pilot_id, pilot_secret="I love stawberries :)"
82+
)
83+
84+
await pilot_agents_db.verify_pilot_secret(
85+
pilot_id=63000, pilot_secret=secret
8286
)

diracx-logic/src/diracx/logic/auth/token.py

Lines changed: 42 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ async def get_oidc_token_info_from_device_flow(
106106
# TODO: use HTTPException while still respecting the standard format
107107
# required by the RFC
108108
if info["Status"] != FlowStatus.READY:
109-
# That should never ever happen
109+
# That should exnever ever happen
110110
raise NotImplementedError(f"Unexpected flow status {info['status']!r}")
111111
return (oidc_token_info, scope)
112112

@@ -245,6 +245,39 @@ async def perform_legacy_exchange(
245245
)
246246

247247

248+
def get_verified_preferred_username(
249+
config: Config,
250+
oidc_token_info: dict,
251+
dirac_group: str,
252+
properties: set[str],
253+
sub: str,
254+
vo: str,
255+
):
256+
if user_info := config.Registry[vo].Users.get(sub):
257+
preferred_username = user_info.PreferedUsername
258+
else:
259+
preferred_username = oidc_token_info.get("preferred_username", sub)
260+
raise NotImplementedError(
261+
"Dynamic registration of users is not yet implemented"
262+
)
263+
264+
# Check that the subject is part of the dirac users
265+
if sub not in config.Registry[vo].Groups[dirac_group].Users:
266+
raise PermissionError(
267+
f"User is not a member of the requested group ({preferred_username}, {dirac_group})"
268+
)
269+
270+
# Check that the user properties are valid
271+
allowed_user_properties = get_allowed_user_properties(config, sub, vo)
272+
if not properties.issubset(allowed_user_properties):
273+
raise PermissionError(
274+
f"{' '.join(properties - allowed_user_properties)} are not valid properties "
275+
f"for user {preferred_username}, available values: {' '.join(allowed_user_properties)}"
276+
)
277+
278+
return preferred_username
279+
280+
248281
async def exchange_token(
249282
auth_db: AuthDB,
250283
scope: str,
@@ -255,6 +288,7 @@ async def exchange_token(
255288
*,
256289
refresh_token_expire_minutes: int | None = None,
257290
legacy_exchange: bool = False,
291+
pilot_exchange: bool = False,
258292
) -> tuple[AccessTokenPayload, RefreshTokenPayload]:
259293
"""Method called to exchange the OIDC token for a DIRAC generated access token."""
260294
# Extract dirac attributes from the OIDC scope
@@ -265,28 +299,17 @@ async def exchange_token(
265299

266300
# Extract attributes from the OIDC token details
267301
sub = oidc_token_info["sub"]
268-
if user_info := config.Registry[vo].Users.get(sub):
269-
preferred_username = user_info.PreferedUsername
270-
else:
271-
preferred_username = oidc_token_info.get("preferred_username", sub)
272-
raise NotImplementedError(
273-
"Dynamic registration of users is not yet implemented"
274-
)
275302

276-
# Check that the subject is part of the dirac users
277-
if sub not in config.Registry[vo].Groups[dirac_group].Users:
278-
raise PermissionError(
279-
f"User is not a member of the requested group ({preferred_username}, {dirac_group})"
280-
)
303+
preferred_username = None
281304

282-
# Check that the user properties are valid
283-
allowed_user_properties = get_allowed_user_properties(config, sub, vo)
284-
if not properties.issubset(allowed_user_properties):
285-
raise PermissionError(
286-
f"{' '.join(properties - allowed_user_properties)} are not valid properties "
287-
f"for user {preferred_username}, available values: {' '.join(allowed_user_properties)}"
305+
if not pilot_exchange:
306+
preferred_username = get_verified_preferred_username(
307+
config, oidc_token_info, dirac_group, properties, sub, vo
288308
)
289309

310+
else:
311+
preferred_username = oidc_token_info["pilot_reference"]
312+
290313
# Merge the VO with the subject to get a unique DIRAC sub
291314
sub = f"{vo}:{sub}"
292315

diracx-routers/src/diracx/routers/auth/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from .authorize_code_flow import router as authorize_code_flow_router
66
from .device_flow import router as device_flow_router
77
from .management import router as management_router
8+
from .pilot_auth import router as pilot_auth_router
89
from .token import router as token_router
910
from .utils import has_properties
1011

@@ -13,5 +14,6 @@
1314
router.include_router(management_router)
1415
router.include_router(authorize_code_flow_router)
1516
router.include_router(token_router)
17+
router.include_router(pilot_auth_router)
1618

1719
__all__ = ["has_properties", "verify_dirac_access_token"]
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
from __future__ import annotations
2+
3+
from fastapi import HTTPException, status
4+
5+
from diracx.core.exceptions import AuthorizationError
6+
from diracx.logic.auth.token import create_token, exchange_token
7+
8+
from ..dependencies import (
9+
AuthDB,
10+
AuthSettings,
11+
AvailableSecurityProperties,
12+
Config,
13+
PilotAgentsDB,
14+
)
15+
from ..fastapi_classes import DiracxRouter
16+
17+
router = DiracxRouter(require_auth=False)
18+
19+
20+
@router.post("/pilot-login")
21+
async def pilot_login(
22+
pilot_db: PilotAgentsDB,
23+
auth_db: AuthDB,
24+
pilot_id: int,
25+
pilot_secret: str,
26+
config: Config,
27+
settings: AuthSettings,
28+
available_properties: AvailableSecurityProperties,
29+
):
30+
"""Endpoint without policy, the pilot uses only its secret."""
31+
try:
32+
await pilot_db.verify_pilot_secret(pilot_id=pilot_id, pilot_secret=pilot_secret)
33+
except AuthorizationError as e:
34+
raise HTTPException(
35+
status_code=status.HTTP_401_UNAUTHORIZED, detail=e.detail
36+
) from e
37+
38+
pilot = await pilot_db.get_pilot_by_id(pilot_id=pilot_id)
39+
40+
pilot_info = {
41+
"pilot_reference": pilot["PilotJobReference"],
42+
"sub": pilot["PilotJobReference"],
43+
}
44+
45+
# return pilot_info
46+
access_token, refresh_token = await exchange_token(
47+
auth_db=auth_db,
48+
scope="vo:diracAdmin",
49+
oidc_token_info=pilot_info,
50+
config=config,
51+
settings=settings,
52+
available_properties=available_properties,
53+
pilot_exchange=True,
54+
)
55+
56+
return [create_token(access_token, settings), create_token(refresh_token, settings)]

diracx-routers/src/diracx/routers/pilots/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
import logging
44

55
from ..fastapi_classes import DiracxRouter
6-
from .auth import router as auth_router
6+
from .debug import router as debug_router
77

88
logger = logging.getLogger(__name__)
99

10-
router = DiracxRouter(require_auth=False)
11-
router.include_router(auth_router)
10+
router = DiracxRouter()
11+
router.include_router(debug_router)

diracx-routers/src/diracx/routers/pilots/access_policies.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from fastapi import Depends, HTTPException, status
88

99
# TODO: DEBUG
10-
from diracx.core.properties import GENERIC_PILOT, LIMITED_DELEGATION, NORMAL_USER
10+
from diracx.core.properties import GENERIC_PILOT, LIMITED_DELEGATION
1111
from diracx.db.sql import PilotAgentsDB
1212
from diracx.routers.access_policies import BaseAccessPolicy
1313
from diracx.routers.utils.users import AuthorizedUserInfo
@@ -33,22 +33,15 @@ async def policy(
3333
pilot_info: AuthorizedUserInfo,
3434
/,
3535
*,
36-
action: ActionType | None = None,
3736
pilot_db: PilotAgentsDB | None = None,
3837
):
3938

40-
assert pilot_db, "pilot_db is a mandatory parameter"
41-
4239
if GENERIC_PILOT in pilot_info.properties:
4340
return
4441

4542
if LIMITED_DELEGATION in pilot_info.properties:
4643
return
4744

48-
# TODO: DEBUG
49-
if NORMAL_USER in pilot_info.properties:
50-
return
51-
5245
raise HTTPException(
5346
status.HTTP_401_UNAUTHORIZED, "you don't have the right properties"
5447
)

diracx-routers/src/diracx/routers/pilots/auth.py renamed to diracx-routers/src/diracx/routers/pilots/debug.py

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
from __future__ import annotations
22

33
from os import getenv
4+
from typing import Annotated
45

5-
from fastapi import HTTPException, status
6+
from fastapi import (
7+
Depends,
8+
HTTPException,
9+
status,
10+
)
611

712
from diracx.core.exceptions import (
8-
AuthorizationError,
913
PilotAlreadyExistsError,
1014
PilotNotFoundError,
1115
)
@@ -14,31 +18,25 @@
1418
PilotAgentsDB,
1519
)
1620
from ..fastapi_classes import DiracxRouter
21+
from ..utils.users import AuthorizedUserInfo, verify_dirac_access_token
22+
from .access_policies import RegisteredPilotAccessPolicyCallable
1723

18-
router = DiracxRouter(require_auth=False)
19-
20-
21-
@router.post("/pilot-login")
22-
async def pilot_login(pilot_db: PilotAgentsDB, pilot_reference: str, pilot_secret: str):
23-
"""Endpoint without policy, the pilot uses only its secret."""
24-
try:
25-
await pilot_db.verify_pilot_secret(
26-
pilot_reference=pilot_reference, pilot_secret=pilot_secret
27-
)
28-
except AuthorizationError as e:
29-
raise HTTPException(
30-
status_code=status.HTTP_401_UNAUTHORIZED, detail=e.detail
31-
) from e
32-
33-
# TODO: Returns a JWT
34-
return {"ref": pilot_reference}
35-
24+
router = DiracxRouter()
3625

3726
# Debug only route
3827
# Keep it?
3928
# TODO: Add to the env?
4029
if getenv("production") is None:
4130

31+
@router.get("/info")
32+
async def get_pilot_info(
33+
check_permissions: RegisteredPilotAccessPolicyCallable,
34+
user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)],
35+
):
36+
await check_permissions()
37+
38+
return user_info
39+
4240
@router.post("/register-pilot")
4341
async def add_new_pilot(
4442
pilot_db: PilotAgentsDB,

0 commit comments

Comments
 (0)