Skip to content

Commit 2404324

Browse files
feat: Pilot can exchange against nothing a secret
1 parent 6b9f48f commit 2404324

File tree

7 files changed

+172
-2
lines changed

7 files changed

+172
-2
lines changed

diracx-core/src/diracx/core/exceptions.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,3 +101,12 @@ def __init__(self, job_id, detail: str | None = None):
101101
super().__init__(
102102
f"Error concerning job {job_id}" + (": {detail} " if detail else "")
103103
)
104+
105+
106+
class PilotNotFoundError(Exception):
107+
def __init__(self, pilot_ref: str, detail: str | None = None):
108+
self.pilot_ref = pilot_ref
109+
self.detail = detail
110+
super().__init__(
111+
f"Pilot {pilot_ref} not found" + (": {detail} " if detail else "")
112+
)

diracx-db/src/diracx/db/sql/pilot_agents/db.py

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
11
from __future__ import annotations
22

3+
import hashlib
34
from datetime import datetime, timezone
45

5-
from sqlalchemy import insert
6+
from sqlalchemy import insert, select, update
7+
8+
from diracx.core.exceptions import (
9+
AuthorizationError,
10+
PilotNotFoundError,
11+
)
612

713
from ..utils import BaseSQLDB
8-
from .schema import PilotAgents, PilotAgentsDBBase
14+
from .schema import PilotAgents, PilotAgentsDBBase, PilotRegistrations
915

1016

1117
class PilotAgentsDB(BaseSQLDB):
@@ -44,3 +50,48 @@ async def add_pilot_references(
4450
stmt = insert(PilotAgents).values(values)
4551
await self.conn.execute(stmt)
4652
return
53+
54+
async def increment_pilot_secret_use(
55+
self,
56+
pilot_reference: str,
57+
) -> None:
58+
59+
# Prepare the update statement
60+
stmt = (
61+
update(PilotRegistrations)
62+
.values(
63+
pilot_secret_use_count=PilotRegistrations.pilot_secret_use_count + 1
64+
)
65+
.where(PilotRegistrations.pilot_id == PilotAgents.pilot_id)
66+
.where(PilotAgents.pilot_job_reference == pilot_reference)
67+
)
68+
69+
# Execute the update using the connection
70+
res = await self.conn.execute(stmt)
71+
72+
if res.rowcount == 0:
73+
raise PilotNotFoundError(pilot_ref=pilot_reference)
74+
75+
async def register_pilot(self, pilot_reference: str, pilot_secret: str):
76+
hashed_secret = hashlib.sha256(pilot_secret.encode()).hexdigest()
77+
78+
print("hash", hashed_secret)
79+
80+
stmt = (
81+
select(PilotRegistrations)
82+
.join(PilotAgents, PilotRegistrations.pilot_id == PilotAgents.pilot_id)
83+
.where(PilotRegistrations.pilot_hashed_secret == hashed_secret)
84+
.where(PilotAgents.pilot_job_reference == pilot_reference)
85+
)
86+
87+
# Execute the request
88+
res = await self.conn.execute(stmt)
89+
90+
result = res.fetchone()
91+
92+
if result is None:
93+
raise AuthorizationError(detail="bad pilot_reference / pilot_secret")
94+
95+
# Increment the count
96+
await self.increment_pilot_secret_use(pilot_reference=pilot_reference)
97+
return

diracx-db/src/diracx/db/sql/pilot_agents/schema.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,11 @@ class PilotOutput(PilotAgentsDBBase):
5858
pilot_id = Column("PilotID", Integer, primary_key=True)
5959
std_output = Column("StdOutput", Text)
6060
std_error = Column("StdError", Text)
61+
62+
63+
class PilotRegistrations(PilotAgentsDBBase):
64+
__tablename__ = "PilotRegistrations"
65+
66+
pilot_id = Column("PilotID", Integer, primary_key=True)
67+
pilot_hashed_secret = Column("PilotHashedSecret", String(64))
68+
pilot_secret_use_count = Column("PilotSecretUseCount", Integer, default=0)

diracx-routers/pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,13 @@ types = [
4444
jobs = "diracx.routers.jobs:router"
4545
config = "diracx.routers.configuration:router"
4646
auth = "diracx.routers.auth:router"
47+
pilots = "diracx.routers.pilots:router"
4748
".well-known" = "diracx.routers.auth.well_known:router"
4849

4950
[project.entry-points."diracx.access_policies"]
5051
WMSAccessPolicy = "diracx.routers.jobs.access_policies:WMSAccessPolicy"
5152
SandboxAccessPolicy = "diracx.routers.jobs.access_policies:SandboxAccessPolicy"
53+
RegisteredPilotAccessPolicy = "diracx.routers.pilots.access_policies:RegisteredPilotAccessPolicy"
5254

5355
# Minimum version of the client supported
5456
[project.entry-points."diracx.min_client_version"]
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from __future__ import annotations
2+
3+
import logging
4+
5+
from ..fastapi_classes import DiracxRouter
6+
from .auth import router as auth_router
7+
8+
logger = logging.getLogger(__name__)
9+
10+
router = DiracxRouter(require_auth=False)
11+
router.include_router(auth_router)
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
from __future__ import annotations
2+
3+
from collections.abc import Callable
4+
from enum import StrEnum, auto
5+
from typing import Annotated
6+
7+
from fastapi import Depends, HTTPException, status
8+
9+
# TODO: DEBUG
10+
from diracx.core.properties import GENERIC_PILOT, LIMITED_DELEGATION, NORMAL_USER
11+
from diracx.db.sql import PilotAgentsDB
12+
from diracx.routers.access_policies import BaseAccessPolicy
13+
from diracx.routers.utils.users import AuthorizedUserInfo
14+
15+
16+
class ActionType(StrEnum):
17+
#: Create a job or a sandbox
18+
CREATE = auto()
19+
#: Check job status, download a sandbox
20+
READ = auto()
21+
#: delete, kill, remove, set status, etc of a job
22+
#: delete or assign a sandbox
23+
MANAGE = auto()
24+
#: Search
25+
QUERY = auto()
26+
27+
28+
class RegisteredPilotAccessPolicy(BaseAccessPolicy):
29+
30+
@staticmethod
31+
async def policy(
32+
policy_name: str,
33+
pilot_info: AuthorizedUserInfo,
34+
/,
35+
*,
36+
action: ActionType | None = None,
37+
pilot_db: PilotAgentsDB | None = None,
38+
):
39+
40+
assert pilot_db, "pilot_db is a mandatory parameter"
41+
42+
if GENERIC_PILOT in pilot_info.properties:
43+
return
44+
45+
if LIMITED_DELEGATION in pilot_info.properties:
46+
return
47+
48+
# TODO: DEBUG
49+
if NORMAL_USER in pilot_info.properties:
50+
return
51+
52+
raise HTTPException(
53+
status.HTTP_401_UNAUTHORIZED, "you don't have the right properties"
54+
)
55+
return
56+
57+
58+
RegisteredPilotAccessPolicyCallable = Annotated[
59+
Callable, Depends(RegisteredPilotAccessPolicy.check)
60+
]
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from __future__ import annotations
2+
3+
from fastapi import HTTPException, status
4+
5+
from diracx.core.exceptions import (
6+
AuthorizationError,
7+
)
8+
9+
from ..dependencies import (
10+
PilotAgentsDB,
11+
)
12+
from ..fastapi_classes import DiracxRouter
13+
14+
router = DiracxRouter(require_auth=False)
15+
16+
17+
@router.post("/register")
18+
async def search(pilot_db: PilotAgentsDB, pilot_reference: str, pilot_secret: str):
19+
"""Endpoint without policy, the pilot uses only its secret."""
20+
try:
21+
await pilot_db.register_pilot(
22+
pilot_reference=pilot_reference, pilot_secret=pilot_secret
23+
)
24+
except AuthorizationError as e:
25+
raise HTTPException(
26+
status_code=status.HTTP_401_UNAUTHORIZED, detail=e.detail
27+
) from e
28+
29+
return pilot_reference

0 commit comments

Comments
 (0)