Skip to content

Commit 3e358bb

Browse files
fix: Add patch for the pilot client
1 parent 6674d37 commit 3e358bb

File tree

3 files changed

+101
-3
lines changed

3 files changed

+101
-3
lines changed

diracx-client/src/diracx/client/patches/pilots/aio.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,16 @@
1616
from azure.core.tracing.decorator_async import distributed_trace_async
1717

1818
from ..._generated.aio.operations._operations import PilotsOperations as _PilotsOperations
19-
from .common import make_search_body, make_summary_body, SearchKwargs, SummaryKwargs
19+
from .common import (
20+
make_search_body,
21+
make_summary_body,
22+
make_add_pilot_stamps_body,
23+
make_update_pilot_fields_body,
24+
SearchKwargs,
25+
SummaryKwargs,
26+
AddPilotStampsKwargs,
27+
UpdatePilotFieldsKwargs
28+
)
2029

2130
# We're intentionally ignoring overrides here because we want to change the interface.
2231
# mypy: disable-error-code=override
@@ -32,3 +41,13 @@ async def search(self, **kwargs: Unpack[SearchKwargs]) -> list[dict[str, Any]]:
3241
async def summary(self, **kwargs: Unpack[SummaryKwargs]) -> list[dict[str, Any]]:
3342
"""TODO"""
3443
return await super().summary(**make_summary_body(**kwargs))
44+
45+
@distributed_trace_async
46+
async def add_pilot_stamps(self, **kwargs: Unpack[AddPilotStampsKwargs]) -> None:
47+
"""TODO"""
48+
return await super().add_pilot_stamps(**make_add_pilot_stamps_body(**kwargs))
49+
50+
@distributed_trace_async
51+
async def update_pilot_fields(self, **kwargs: Unpack[UpdatePilotFieldsKwargs]) -> None:
52+
"""TODO"""
53+
return await super().update_pilot_fields(**make_update_pilot_fields_body(**kwargs))

diracx-client/src/diracx/client/patches/pilots/common.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,17 @@
77
"SearchKwargs",
88
"make_summary_body",
99
"SummaryKwargs",
10+
"AddPilotStampsKwargs",
11+
"make_add_pilot_stamps_body",
12+
"UpdatePilotFieldsKwargs",
13+
"make_update_pilot_fields_body"
1014
]
1115

1216
import json
1317
from io import BytesIO
1418
from typing import Any, IO, TypedDict, Unpack, cast, Literal
1519

16-
from diracx.core.models import SearchSpec
20+
from diracx.core.models import SearchSpec, PilotStatus, PilotFieldsMapping
1721

1822

1923
class ResponseExtra(TypedDict, total=False):
@@ -23,6 +27,7 @@ class ResponseExtra(TypedDict, total=False):
2327
cls: Any
2428

2529

30+
# ------------------ Search ------------------
2631
class SearchBody(TypedDict, total=False):
2732
parameters: list[str] | None
2833
search: list[SearchSpec] | None
@@ -56,6 +61,7 @@ def make_search_body(**kwargs: Unpack[SearchKwargs]) -> UnderlyingSearchArgs:
5661
result.update(cast(SearchExtra, kwargs))
5762
return result
5863

64+
# ------------------ Summary ------------------
5965

6066
class SummaryBody(TypedDict, total=False):
6167
grouping: list[str]
@@ -83,3 +89,57 @@ def make_summary_body(**kwargs: Unpack[SummaryKwargs]) -> UnderlyingSummaryArgs:
8389
result: UnderlyingSummaryArgs = {"body": BytesIO(json.dumps(body).encode("utf-8"))}
8490
result.update(cast(ResponseExtra, kwargs))
8591
return result
92+
93+
# ------------------ AddPilotStamps ------------------
94+
95+
class AddPilotStampsBody(TypedDict, total=False):
96+
pilot_stamps: list[str]
97+
grid_type: str
98+
grid_site: str
99+
pilot_references: dict[str, str]
100+
pilot_status: PilotStatus
101+
102+
class AddPilotStampsKwargs(AddPilotStampsBody, ResponseExtra): ...
103+
104+
class UnderlyingAddPilotStampsArgs(ResponseExtra, total=False):
105+
# FIXME: The autorest-generated has a bug that it expected IO[bytes] despite
106+
# the code being generated to support IO[bytes] | bytes.
107+
body: IO[bytes]
108+
109+
def make_add_pilot_stamps_body(**kwargs: Unpack[AddPilotStampsKwargs]) -> UnderlyingAddPilotStampsArgs:
110+
body: AddPilotStampsBody = {}
111+
for key in AddPilotStampsBody.__optional_keys__:
112+
if key not in kwargs:
113+
continue
114+
key = cast(Literal["pilot_stamps", "grid_type", "grid_site", "pilot_references", "pilot_status"], key)
115+
value = kwargs.pop(key)
116+
if value is not None:
117+
body[key] = value
118+
result: UnderlyingAddPilotStampsArgs = {"body": BytesIO(json.dumps(body).encode("utf-8"))}
119+
result.update(cast(ResponseExtra, kwargs))
120+
return result
121+
122+
# ------------------ UpdatePilotFields ------------------
123+
124+
class UpdatePilotFieldsBody(TypedDict, total=False):
125+
pilot_stamps_to_fields_mapping: list[PilotFieldsMapping]
126+
127+
class UpdatePilotFieldsKwargs(UpdatePilotFieldsBody, ResponseExtra): ...
128+
129+
class UnderlyingUpdatePilotFields(ResponseExtra, total=False):
130+
# FIXME: The autorest-generated has a bug that it expected IO[bytes] despite
131+
# the code being generated to support IO[bytes] | bytes.
132+
body: IO[bytes]
133+
134+
def make_update_pilot_fields_body(**kwargs: Unpack[UpdatePilotFieldsKwargs]) -> UnderlyingUpdatePilotFields:
135+
body: UpdatePilotFieldsBody = {}
136+
for key in UpdatePilotFieldsBody.__optional_keys__:
137+
if key not in kwargs:
138+
continue
139+
key = cast(Literal["pilot_stamps_to_fields_mapping"], key)
140+
value = kwargs.pop(key)
141+
if value is not None:
142+
body[key] = value
143+
result: UnderlyingUpdatePilotFields = {"body": BytesIO(json.dumps(body).encode("utf-8"))}
144+
result.update(cast(ResponseExtra, kwargs))
145+
return result

diracx-client/src/diracx/client/patches/pilots/sync.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,16 @@
1616
from azure.core.tracing.decorator import distributed_trace
1717

1818
from ..._generated.operations._operations import PilotsOperations as _PilotsOperations
19-
from .common import make_search_body, make_summary_body, SearchKwargs, SummaryKwargs
19+
from .common import (
20+
make_search_body,
21+
make_summary_body,
22+
make_add_pilot_stamps_body,
23+
make_update_pilot_fields_body,
24+
SearchKwargs,
25+
SummaryKwargs,
26+
AddPilotStampsKwargs,
27+
UpdatePilotFieldsKwargs
28+
)
2029

2130
# We're intentionally ignoring overrides here because we want to change the interface.
2231
# mypy: disable-error-code=override
@@ -32,3 +41,13 @@ def search(self, **kwargs: Unpack[SearchKwargs]) -> list[dict[str, Any]]:
3241
def summary(self, **kwargs: Unpack[SummaryKwargs]) -> list[dict[str, Any]]:
3342
"""TODO"""
3443
return super().summary(**make_summary_body(**kwargs))
44+
45+
@distributed_trace
46+
def add_pilot_stamps(self, **kwargs: Unpack[AddPilotStampsKwargs]) -> None:
47+
"""TODO"""
48+
return super().add_pilot_stamps(**make_add_pilot_stamps_body(**kwargs))
49+
50+
@distributed_trace
51+
def update_pilot_fields(self, **kwargs: Unpack[UpdatePilotFieldsKwargs]) -> None:
52+
"""TODO"""
53+
return super().update_pilot_fields(**make_update_pilot_fields_body(**kwargs))

0 commit comments

Comments
 (0)