Skip to content

Commit 5c9d046

Browse files
committed
fix: test and logic
1 parent 230fba9 commit 5c9d046

File tree

6 files changed

+95
-35
lines changed

6 files changed

+95
-35
lines changed

diracx-core/src/diracx/core/models.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,71 @@ class JobSearchParams(BaseModel):
7373
# TODO: Add more validation
7474

7575

76+
class JobParameters(BaseModel):
77+
"""All the parameters that can be set for a job."""
78+
79+
timestamp: datetime | None = None
80+
CPUNormalizationFactor: int | None = None
81+
NormCPUTime_s: int | None = Field(None, alias="NormCPUTime(s)")
82+
TotalCPUTime_s: int | None = Field(None, alias="TotalCPUTime(s)")
83+
HostName: str | None = None
84+
GridCE: str | None = None
85+
ModelName: str | None = None
86+
PilotAgent: str | None = None
87+
Pilot_Reference: str | None = None
88+
Memory_MB: int | None = Field(None, alias="Memory(MB)")
89+
LocalAccount: str | None = None
90+
PayloadPID: int | None = None
91+
CEQueue: str | None = None
92+
BatchSystem: str | None = None
93+
JobType: str | None = None
94+
JobStatus: str | None = None
95+
96+
class Config:
97+
"""Configuration for the JobParameters model."""
98+
99+
extra = "forbid" # Disallow additional fields
100+
101+
102+
class JobAttributes(BaseModel):
103+
"""All the attributes that can be set for a job."""
104+
105+
JobType: str | None = None
106+
JobGroup: str | None = None
107+
Site: str | None = None
108+
JobName: str | None = None
109+
Owner: str | None = None
110+
OwnerGroup: str | None = None
111+
VO: str | None = None
112+
SubmissionTime: datetime | None = None
113+
RescheduleTime: datetime | None = None
114+
LastUpdateTime: datetime | None = None
115+
StartExecTime: datetime | None = None
116+
HeartBeatTime: datetime | None = None
117+
EndExecTime: datetime | None = None
118+
Status: str | None = None
119+
MinorStatus: str | None = None
120+
ApplicationStatus: str | None = None
121+
UserPriority: int | None = None
122+
RescheduleCounter: int | None = None
123+
VerifiedFlag: bool | None = None
124+
AccountedFlag: bool | str | None = None
125+
126+
class Config:
127+
"""Configuration for the JobAttributes model."""
128+
129+
extra = "forbid" # Disallow additional fields
130+
131+
132+
class JobMetaData(JobAttributes, JobParameters):
133+
"""A model that combines both JobAttributes and JobParameters."""
134+
135+
class Config:
136+
"""Configuration for the JobMetaData model."""
137+
138+
extra = "forbid" # Disallow additional fields
139+
140+
76141
class JobStatus(StrEnum):
77142
SUBMITTING = "Submitting"
78143
RECEIVED = "Received"

diracx-logic/src/diracx/logic/jobs/status.py

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -31,18 +31,20 @@
3131
from diracx.core.config.schema import Config
3232
from diracx.core.models import (
3333
HeartbeatData,
34+
JobAttributes,
3435
JobCommand,
3536
JobLoggingRecord,
37+
JobMetaData,
3638
JobMinorStatus,
39+
JobParameters,
3740
JobStatus,
3841
JobStatusUpdate,
3942
SetJobStatusReturn,
4043
VectorSearchOperator,
4144
VectorSearchSpec,
4245
)
4346
from diracx.db.os.job_parameters import JobParametersDB
44-
from diracx.db.sql.job.db import JobDB, _get_columns
45-
from diracx.db.sql.job.schema import Jobs
47+
from diracx.db.sql.job.db import JobDB
4648
from diracx.db.sql.job_logging.db import JobLoggingDB
4749
from diracx.db.sql.sandbox_metadata.db import SandboxMetadataDB
4850
from diracx.db.sql.task_queue.db import TaskQueueDB
@@ -52,6 +54,10 @@
5254

5355
logger = logging.getLogger(__name__)
5456

57+
# Precalculate valid field sets for performance
58+
VALID_JOB_ATTRIBUTES = set(JobAttributes.model_fields.keys())
59+
VALID_JOB_PARAMETERS = set(JobParameters.model_fields.keys())
60+
5561

5662
async def remove_jobs(
5763
job_ids: list[int],
@@ -502,39 +508,27 @@ async def remove_jobs_from_task_queue(
502508

503509

504510
async def set_job_parameters_or_attributes(
505-
updates: dict[int, dict[str, Any]],
511+
updates: dict[int, JobMetaData],
506512
job_db: JobDB,
507513
job_parameters_db: JobParametersDB,
508514
):
509515
"""Set job parameters or attributes for a list of jobs."""
510-
attribute_columns: list[str] = [
511-
col.name for col in _get_columns(Jobs.__table__, None)
512-
]
513-
attribute_columns_lower: list[str] = [col.lower() for col in attribute_columns]
514-
516+
# Those dicts create a mapping of job_id -> {attribute_name: value}
515517
attr_updates: dict[int, dict[str, Any]] = {}
516518
param_updates: dict[int, dict[str, Any]] = {}
517519

518520
for job_id, metadata in updates.items():
519521
attr_updates[job_id] = {}
520522
param_updates[job_id] = {}
521-
for pname, pvalue in metadata.items():
523+
for pname, pvalue in metadata.model_dump(exclude_none=True).items():
522524
# An argument can be a job attribute and/or a job parameter
523525

524-
# If the attribute exactly matches one of the allowed columns, treat it as an attribute.
525-
if pname in attribute_columns:
526+
# Check if the argument is a valid job attribute
527+
if pname in VALID_JOB_ATTRIBUTES:
526528
attr_updates[job_id][pname] = pvalue
527-
# Otherwise, if the lower-case version is valid, the user likely mis-cased the key.
528-
elif pname.lower() in attribute_columns_lower:
529-
correct_name = attribute_columns[
530-
attribute_columns_lower.index(pname.lower())
531-
]
532-
raise ValueError(
533-
f"Attribute column '{pname}' is mis-cased. Did you mean '{correct_name}'?"
534-
)
535529

536530
# Check if the argument is a valid job parameter
537-
if pname in job_parameters_db.fields:
531+
if pname in VALID_JOB_PARAMETERS:
538532
param_updates[job_id][pname] = pvalue
539533

540534
# Bulk set job attributes if required

diracx-routers/src/diracx/routers/jobs/sandboxes.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -166,11 +166,11 @@ async def unassign_job_sandboxes(
166166

167167
@router.delete("/sandbox")
168168
async def unassign_bulk_jobs_sandboxes(
169-
jobs_ids: Annotated[list[int], Query()],
169+
job_ids: Annotated[list[int], Query()],
170170
sandbox_metadata_db: SandboxMetadataDB,
171171
job_db: JobDB,
172172
check_permissions: CheckWMSPolicyCallable,
173173
):
174174
"""Delete bulk jobs sandbox mapping."""
175-
await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=jobs_ids)
176-
await unassign_jobs_sandboxes_bl(jobs_ids, sandbox_metadata_db)
175+
await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=job_ids)
176+
await unassign_jobs_sandboxes_bl(job_ids, sandbox_metadata_db)

diracx-routers/src/diracx/routers/jobs/status.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from diracx.core.models import (
1010
HeartbeatData,
1111
JobCommand,
12+
JobMetaData,
1213
JobStatusUpdate,
1314
SetJobStatusReturn,
1415
)
@@ -291,15 +292,13 @@ async def reschedule_jobs(
291292

292293
@router.patch("/metadata", status_code=HTTPStatus.NO_CONTENT)
293294
async def patch_metadata(
294-
updates: Annotated[
295-
dict[int, dict[str, Any]], Body(openapi_examples=EXAMPLE_METADATA)
296-
],
295+
updates: Annotated[dict[int, JobMetaData], Body(openapi_examples=EXAMPLE_METADATA)],
297296
job_db: JobDB,
298297
job_parameters_db: JobParametersDB,
299298
check_permissions: CheckWMSPolicyCallable,
300299
):
301300
"""Update job metadata such as UserPriority, HeartBeatTime, JobType, etc.
302-
The parameters are all the attributes of a job (except the ID).
301+
The argument are all the attributes/parameters of a job (except the ID).
303302
"""
304303
await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=updates)
305304
try:

diracx-routers/tests/jobs/test_sandboxes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def test_assign_then_unassign_sandboxes_to_jobs(normal_user_client: TestClient):
163163

164164
# Unassign sb to job:
165165
job_ids = [job_id]
166-
r = normal_user_client.delete("/api/jobs/sandbox", params={"jobs_ids": job_ids})
166+
r = normal_user_client.delete("/api/jobs/sandbox", params={"job_ids": job_ids})
167167
assert r.status_code == 200
168168

169169
# Get the sb again, it should'nt be there anymore:

diracx-routers/tests/jobs/test_status.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,7 @@ def test_reschedule_job_attr_update(normal_user_client: TestClient):
445445
assert successful_results[str(jid)]["RescheduleCounter"] == i + 1
446446
for jid in fail_resched_ids:
447447
assert str(jid) in failed_results, result
448+
# FIXME
448449
# assert successful_results[jid]["Status"] == JobStatus.RECEIVED
449450
# assert successful_results[jid]["MinorStatus"] == "Job Rescheduled"
450451
# assert successful_results[jid]["RescheduleCounter"] == i + 1
@@ -851,7 +852,7 @@ def test_patch_metadata(normal_user_client: TestClient, valid_job_id: int):
851852
assert j["ApplicationStatus"] == "Unknown"
852853

853854
# Act
854-
hbt = str(datetime.now(timezone.utc))
855+
hbt = datetime.now(timezone.utc).isoformat()
855856
r = normal_user_client.patch(
856857
"/api/jobs/metadata",
857858
json={
@@ -883,11 +884,14 @@ def test_patch_metadata(normal_user_client: TestClient, valid_job_id: int):
883884
)
884885
assert r.status_code == 200, r.json()
885886

887+
# TODO: This should be timezone aware
888+
hbt1 = datetime.fromisoformat(r.json()[0]["HeartBeatTime"])
889+
assert hbt1.tzinfo is None
890+
hbt1 = hbt1.replace(tzinfo=timezone.utc)
891+
886892
assert r.json()[0]["JobID"] == valid_job_id
887893
assert r.json()[0]["JobType"] == "VerySpecialIndeed"
888-
assert datetime.fromisoformat(
889-
r.json()[0]["HeartBeatTime"]
890-
) == datetime.fromisoformat(hbt)
894+
assert hbt1 == datetime.fromisoformat(hbt)
891895
assert r.json()[0]["UserPriority"] == 2
892896

893897

@@ -929,9 +933,7 @@ def test_bad_patch_metadata(normal_user_client: TestClient, valid_job_id: int):
929933
)
930934

931935
# Assert
932-
assert r.status_code == 400, (
933-
"PATCH metadata should 400 Bad Request if an attribute column's case is incorrect"
934-
)
936+
assert r.status_code == 422, r.text
935937

936938

937939
def test_diracx_476(normal_user_client: TestClient, valid_job_id: int):

0 commit comments

Comments
 (0)