Skip to content

Commit a171d01

Browse files
feat: Added pilot login to the cli
1 parent 54951a6 commit a171d01

File tree

5 files changed

+54
-28
lines changed

5 files changed

+54
-28
lines changed

diracx-cli/src/diracx/cli/auth.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,3 +144,28 @@ async def logout():
144144
def callback(output_format: Optional[str] = None):
145145
if output_format is not None:
146146
os.environ["DIRACX_OUTPUT_FORMAT"] = output_format
147+
148+
149+
@app.async_command()
150+
async def pilot_login(
151+
pilot_reference: Optional[str] = typer.Argument(None, help="Pilot job reference."),
152+
pilot_secret: Optional[str] = typer.Argument(
153+
None, help="Pilot secret given by DiracX."
154+
),
155+
):
156+
"""Login to the DIRAC system using a pilot exchange (a [reference,secret] pair)."""
157+
async with AsyncDiracClient() as api:
158+
159+
try:
160+
response = await api.auth.pilot_login(
161+
pilot_job_reference=pilot_reference, pilot_secret=pilot_secret
162+
)
163+
except Exception as e:
164+
print(f"Error signing in DiracX {e!r}")
165+
return
166+
167+
# Save credentials
168+
write_credentials(response)
169+
credentials_path = get_diracx_preferences().credentials_path
170+
print(f"Saved credentials to {credentials_path}")
171+
print("\nLogin successful!")

diracx-db/src/diracx/db/sql/pilot_agents/db.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
from datetime import datetime, timezone
4-
from typing import Sequence
54

65
from sqlalchemy import insert, select, update
76
from sqlalchemy.exc import IntegrityError, NoResultFound
@@ -27,7 +26,7 @@ async def add_pilot_references(
2726
vo: str,
2827
grid_type: str = "DIRAC",
2928
pilot_stamps: dict | None = None,
30-
) -> Sequence: # Return a list of primary keys
29+
) -> None:
3130

3231
if pilot_stamps is None:
3332
pilot_stamps = {}
@@ -49,17 +48,9 @@ async def add_pilot_references(
4948
]
5049

5150
# Insert multiple rows in a single execute call and use 'returning' to get primary keys
52-
stmt = (
53-
insert(PilotAgents).values(values).returning(PilotAgents.pilot_id)
54-
) # Assuming 'id' is the primary key
55-
result = await self.conn.execute(stmt)
56-
57-
# Use .scalars() and .all() to get the primary keys directly in a list
58-
primary_keys = (
59-
result.scalars().all()
60-
) # This returns a flat list of primary keys
51+
stmt = insert(PilotAgents).values(values) # Assuming 'id' is the primary key
6152

62-
return primary_keys
53+
await self.conn.execute(stmt)
6354

6455
async def increment_pilot_secret_use(
6556
self,

diracx-db/tests/pilot_agents/test_pilot_agents_db.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,16 @@ async def test_create_pilot_and_verify_secret(pilot_agents_db: PilotAgentsDB):
5959

6060
async with pilot_agents_db as pilot_agents_db:
6161
pilot_reference = "pilot-reference-test"
62-
pilot_ids = await pilot_agents_db.add_pilot_references(
62+
# Register a pilot
63+
await pilot_agents_db.add_pilot_references(
6364
vo="lhcb",
6465
pilot_ref=[pilot_reference],
6566
grid_type="grid-type",
6667
)
6768

68-
assert len(pilot_ids) == 1
69+
pilot = await pilot_agents_db.get_pilot_by_reference(pilot_reference)
6970

70-
# Only one element
71-
pilot_id = pilot_ids[0]
71+
pilot_id = pilot["PilotID"]
7272

7373
secret = "AW0nd3rfulS3cr3t"
7474
pilot_hashed_secret = hash(secret)

diracx-routers/src/diracx/routers/auth/pilot.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
InvalidCredentialsError,
1414
PilotNotFoundError,
1515
)
16+
from diracx.core.models import TokenResponse
1617
from diracx.logic.auth.pilot import try_login
1718
from diracx.logic.auth.token import create_token, generate_pilot_tokens
1819
from diracx.routers.pilots.access_policies import RegisteredPilotAccessPolicyCallable
@@ -71,10 +72,15 @@ async def pilot_login(
7172
status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)
7273
) from e
7374

74-
return {
75-
"access_token": create_token(access_token, settings),
76-
"refresh_token": create_token(refresh_token, settings),
77-
}
75+
serialized_access_token = create_token(access_token, settings=settings)
76+
77+
serialized_refresh_token = create_token(refresh_token, settings=settings)
78+
79+
return TokenResponse(
80+
access_token=serialized_access_token,
81+
expires_in=settings.access_token_expire_minutes * 60,
82+
refresh_token=serialized_refresh_token,
83+
)
7884

7985

8086
@router.post("/pilot-refresh-token")
@@ -110,7 +116,12 @@ async def refresh_pilot_tokens(
110116
status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)
111117
) from e
112118

113-
return {
114-
"access_token": create_token(new_access_token, settings),
115-
"refresh_token": create_token(new_refresh_token, settings),
116-
}
119+
serialized_access_token = create_token(new_access_token, settings=settings)
120+
121+
serialized_refresh_token = create_token(new_access_token, settings=settings)
122+
123+
return TokenResponse(
124+
access_token=serialized_access_token,
125+
expires_in=settings.access_token_expire_minutes * 60,
126+
refresh_token=serialized_refresh_token,
127+
)

diracx-routers/tests/auth/test_pilot_auth.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,16 +43,15 @@ async def test_create_pilot_and_verify_secret(test_client):
4343

4444
async with db as pilot_agents_db:
4545
# Register a pilot
46-
pilot_ids = await pilot_agents_db.add_pilot_references(
46+
await pilot_agents_db.add_pilot_references(
4747
vo=pilot_vo,
4848
pilot_ref=[pilot_reference],
4949
grid_type="grid-type",
5050
)
5151

52-
assert len(pilot_ids) == 1
52+
pilot = await pilot_agents_db.get_pilot_by_reference(pilot_reference)
5353

54-
# Only one element
55-
pilot_id = pilot_ids[0]
54+
pilot_id = pilot["PilotID"]
5655

5756
# Add credentials to this pilot
5857
await pilot_agents_db.add_pilot_credentials(

0 commit comments

Comments
 (0)