fix: jobparameters and jobattributes pydantic models #626

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 3 commits into main
9 changes: 9 additions & 0 deletions diracx-client/src/diracx/client/_generated/models/_models.py
@@ -413,6 +413,9 @@ def __init__(self, *, job_id: int, command: str, arguments: Optional[str] = None
class JobMetaData(_serialization.Model):
"""A model that combines both JobAttributes and JobParameters.

:ivar additional_properties: Unmatched properties from the message are deserialized to this
collection.
:vartype additional_properties: dict[str, any]
:ivar timestamp: Timestamp.
:vartype timestamp: ~datetime.datetime
:ivar cpu_normalization_factor: Cpunormalizationfactor.
@@ -486,6 +489,7 @@ class JobMetaData(_serialization.Model):
"""

_attribute_map = {
"additional_properties": {"key": "", "type": "{object}"},
"timestamp": {"key": "timestamp", "type": "iso-8601"},
"cpu_normalization_factor": {"key": "CPUNormalizationFactor", "type": "int"},
"norm_cpu_time_s": {"key": "NormCPUTime(s)", "type": "int"},
@@ -526,6 +530,7 @@ class JobMetaData(_serialization.Model):
def __init__( # pylint: disable=too-many-locals
self,
*,
additional_properties: Optional[Dict[str, Any]] = None,
timestamp: Optional[datetime.datetime] = None,
cpu_normalization_factor: Optional[int] = None,
norm_cpu_time_s: Optional[int] = None,
@@ -564,6 +569,9 @@ def __init__( # pylint: disable=too-many-locals
**kwargs: Any
) -> None:
"""
:keyword additional_properties: Unmatched properties from the message are deserialized to this
collection.
:paramtype additional_properties: dict[str, any]
:keyword timestamp: Timestamp.
:paramtype timestamp: ~datetime.datetime
:keyword cpu_normalization_factor: Cpunormalizationfactor.
@@ -636,6 +644,7 @@ def __init__( # pylint: disable=too-many-locals
:paramtype accounted_flag: ~_generated.models.JobMetaDataAccountedFlag
"""
super().__init__(**kwargs)
self.additional_properties = additional_properties
self.timestamp = timestamp
self.cpu_normalization_factor = cpu_normalization_factor
self.norm_cpu_time_s = norm_cpu_time_s
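For reference, a minimal sketch of how the regenerated client model handles unmatched keys (assuming diracx-client is installed and JobMetaData is re-exported from the generated models package):

```python
from diracx.client._generated.models import JobMetaData

# The '{"key": "", "type": "{object}"}' mapping tells the autorest
# serializer to collect payload keys with no declared attribute here
# instead of silently dropping them.
meta = JobMetaData(
    cpu_normalization_factor=10,
    additional_properties={"SomeCustomKey": "some-value"},  # hypothetical key
)
assert meta.additional_properties["SomeCustomKey"] == "some-value"

# On serialization the extra keys are merged back into the top level,
# alongside the declared wire names such as CPUNormalizationFactor.
body = meta.serialize()
assert body["SomeCustomKey"] == "some-value"
assert body["CPUNormalizationFactor"] == 10
```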
29 changes: 23 additions & 6 deletions diracx-core/src/diracx/core/models.py
@@ -9,7 +9,7 @@
from enum import StrEnum
from typing import Literal

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from typing_extensions import TypedDict


@@ -75,8 +75,8 @@ class SearchParams(BaseModel):
# TODO: Add more validation


class JobParameters(BaseModel, extra="forbid"):
"""All the parameters that can be set for a job."""
class JobParameters(BaseModel, populate_by_name=True, extra="allow"):
"""Some of the most important parameters that can be set for a job."""

timestamp: datetime | None = None
cpu_normalization_factor: int | None = Field(None, alias="CPUNormalizationFactor")
@@ -95,8 +95,25 @@ class JobParameters(BaseModel, extra="forbid"):
job_type: str | None = Field(None, alias="JobType")
job_status: str | None = Field(None, alias="JobStatus")


class JobAttributes(BaseModel, extra="forbid"):
@field_validator(
"cpu_normalization_factor", "norm_cpu_time_s", "total_cpu_time_s", mode="before"
)
@classmethod
def convert_cpu_fields_to_int(cls, v):
"""Convert string representation of float to integer for CPU-related fields."""
if v is None:
return v
if isinstance(v, str):
try:
return int(float(v))
except (ValueError, TypeError) as e:
raise ValueError(f"Cannot convert '{v}' to integer") from e
if isinstance(v, (int, float)):
return int(v)
return v


class JobAttributes(BaseModel, populate_by_name=True, extra="forbid"):
"""All the attributes that can be set for a job."""

job_type: str | None = Field(None, alias="JobType")
@@ -121,7 +138,7 @@ class JobAttributes(BaseModel, extra="forbid"):
accounted_flag: bool | str | None = Field(None, alias="AccountedFlag")


class JobMetaData(JobAttributes, JobParameters, extra="forbid"):
class JobMetaData(JobAttributes, JobParameters, extra="allow"):
"""A model that combines both JobAttributes and JobParameters."""


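A minimal sketch of the new validation behaviour, derived from the model changes above (assuming diracx-core is importable):

```python
from diracx.core.models import JobMetaData, JobParameters

# populate_by_name=True lets callers use either the Python field name or
# the DIRAC alias, and the "before" validator coerces string floats
# such as "10.5" to int for the CPU-related fields.
params = JobParameters.model_validate({"CPUNormalizationFactor": "10.5"})
assert params.cpu_normalization_factor == 10

# extra="allow" (previously extra="forbid") keeps free-form keys instead
# of rejecting the whole payload.
meta = JobMetaData.model_validate({"JobType": "User", "does_not_exist": "unknown"})
assert meta.model_extra["does_not_exist"] == "unknown"
```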
5 changes: 5 additions & 0 deletions diracx-logic/src/diracx/logic/jobs/status.py
@@ -541,6 +541,11 @@ async def set_job_parameters_or_attributes(
if pname in JOB_PARAMETERS_ALIASES:
param_updates[job_id][pname] = pvalue

# If the field is not in either set of known aliases, default to treating it
# as a parameter; this allows more flexible metadata handling
elif pname not in JOB_ATTRIBUTES_ALIASES:
param_updates[job_id][pname] = pvalue

# Bulk set job attributes if required
attr_updates = {k: v for k, v in attr_updates.items() if v}
if attr_updates:
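The routing rule can be summarised with an illustrative sketch (the function name and set arguments are hypothetical; the real code checks the JOB_PARAMETERS_ALIASES and JOB_ATTRIBUTES_ALIASES mappings):

```python
def classify(pname: str, attribute_aliases: set[str], parameter_aliases: set[str]) -> set[str]:
    """Decide which backend(s) a metadata key should be written to."""
    targets: set[str] = set()
    if pname in attribute_aliases:
        targets.add("attributes")  # SQL JobDB columns
    if pname in parameter_aliases:
        targets.add("parameters")  # job parameters store
    if not targets:
        # Unknown keys now fall through to the parameters store rather
        # than causing the whole request to be rejected.
        targets.add("parameters")
    return targets
```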
160 changes: 160 additions & 0 deletions diracx-logic/tests/jobs/test_status.py
@@ -0,0 +1,160 @@
from __future__ import annotations

from collections.abc import AsyncGenerator
from datetime import datetime, timezone

import pytest
import sqlalchemy

from diracx.core.models import JobMetaData
from diracx.db.os.job_parameters import JobParametersDB as RealJobParametersDB
from diracx.db.sql.job.db import JobDB
from diracx.logic.jobs.status import set_job_parameters_or_attributes
from diracx.testing.mock_osdb import MockOSDBMixin
from diracx.testing.time import mock_sqlite_time


# Reuse the generic MockOSDBMixin to build a mock JobParameters DB implementation
class _MockJobParametersDB(MockOSDBMixin, RealJobParametersDB):
def __init__(self): # type: ignore[override]
super().__init__({"sqlalchemy_dsn": "sqlite+aiosqlite:///:memory:"})

def upsert(self, vo, doc_id, document):
"""Override to add JobID to the document."""
# Add JobID to the document, which is required by the base class
document["JobID"] = doc_id
return super().upsert(vo, doc_id, document)


# --------------------------------------------------------------------------------------
# Test setup fixtures
# --------------------------------------------------------------------------------------


@pytest.fixture
async def job_db() -> AsyncGenerator[JobDB, None]:
"""Create a fake sandbox metadata database."""
db = JobDB(db_url="sqlite+aiosqlite:///:memory:")
async with db.engine_context():
sqlalchemy.event.listen(db.engine.sync_engine, "connect", mock_sqlite_time)

async with db.engine.begin() as conn:
await conn.run_sync(db.metadata.create_all)

yield db


@pytest.fixture
async def job_parameters_db() -> AsyncGenerator[_MockJobParametersDB, None]:
db = _MockJobParametersDB()
# The client context must be entered before creating the index template
async with db.client_context():
await db.create_index_template()
yield db


TEST_JDL = """
Arguments = "jobDescription.xml -o LogLevel=INFO";
Executable = "dirac-jobexec";
JobGroup = jobGroup;
JobName = jobName;
JobType = User;
LogLevel = INFO;
OutputSandbox =
{
Script1_CodeOutput.log,
std.err,
std.out
};
Priority = 1;
Site = ANY;
StdError = std.err;
StdOutput = std.out;
"""


@pytest.fixture
async def valid_job_id(job_db: JobDB) -> int:
"""Create a minimal job record and return its JobID."""
async with job_db:
job_id = await job_db.create_job("") # original JDL unused in these tests
# Insert initial attributes (mimic job submission)
await job_db.insert_job_attributes(
{
job_id: {
"Status": "Received",
"MinorStatus": "Job accepted",
"ApplicationStatus": "Unknown",
"VO": "lhcb",
"Owner": "tester",
"OwnerGroup": "lhcb_user",
"VerifiedFlag": True,
"JobType": "User",
}
}
)
return job_id


# --------------------------------------------------------------------------------------
# Tests
# --------------------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_patch_metadata_updates_attributes_and_parameters(
job_db: JobDB, job_parameters_db: _MockJobParametersDB, valid_job_id: int
):
"""Patch metadata mixing:
- Attribute only (UserPriority)
- Attribute + parameter (JobType)
- Parameter only (CPUNormalizationFactor)
- Attribute (HeartBeatTime)
- Unrecognized metadata (does_not_exist)
and verify correct persistence in the two backends.
"""
hbt = datetime.now(timezone.utc)

metadata = {
"UserPriority": "2", # attr
"JobType": "VerySpecialIndeed", # attr + param
"CPUNormalizationFactor": "10", # param only
"HeartBeatTime": hbt.isoformat(), # attr
"does_not_exist": "unknown", # Does not exist should be treated as a job param
}

updates = {valid_job_id: JobMetaData.model_validate(metadata)}

# Act
async with job_db: # ensure open connection for updates
await set_job_parameters_or_attributes(updates, job_db, job_parameters_db)

# Assert job attributes (SQL)
async with job_db:
_, rows = await job_db.search(
parameters=None,
search=[{"parameter": "JobID", "operator": "eq", "value": valid_job_id}],
sorts=[],
)
assert len(rows) == 1
row = rows[0]
assert int(row["JobID"]) == valid_job_id
assert row["UserPriority"] == 2
assert row["JobType"] == "VerySpecialIndeed"
# HeartBeatTime is stored as an ISO string (without tz) by the DB helper; just check it is present
assert row["HeartBeatTime"] is not None
assert "CPUNormalizationFactor" not in row
assert "does_not_exist" not in row

# Assert job parameters (mocked OS / sqlite)
params_rows = await job_parameters_db.search(
parameters=None,
search=[{"parameter": "JobID", "operator": "eq", "value": valid_job_id}],
sorts=[],
)
prow = params_rows[0]
assert prow["JobType"] == "VerySpecialIndeed"
assert prow["CPUNormalizationFactor"] == 10
assert prow["does_not_exist"] == "unknown"
assert "UserPriority" not in prow
assert "HeartBeatTime" not in prow
4 changes: 2 additions & 2 deletions diracx-routers/tests/jobs/test_query.py
@@ -880,7 +880,7 @@ def test_get_job_status_history_in_bulk(
assert r.json()[str(valid_job_id)][0]["Source"] == "JobManager"


def test_patch_summary(normal_user_client: TestClient, valid_job_id: int):
def test_summary(normal_user_client: TestClient, valid_job_id: int):
"""Test that the summary endpoint works as expected."""
r = normal_user_client.post(
"/api/jobs/summary",
@@ -906,7 +906,7 @@ def test_patch_summary(normal_user_client: TestClient, valid_job_id: int):
assert r.json() == [{"Owner": "preferred_username", "count": 1}]


def test_patch_summary_doc_example(normal_user_client: TestClient, valid_job_id: int):
def test_summary_doc_example(normal_user_client: TestClient, valid_job_id: int):
"""Test that the summary doc example is correct."""
payload = EXAMPLE_SUMMARY["Group by JobGroup"]["value"]
r = normal_user_client.post("/api/jobs/summary", json=payload)
41 changes: 0 additions & 41 deletions diracx-routers/tests/jobs/test_status.py
@@ -898,47 +898,6 @@ def test_patch_metadata(normal_user_client: TestClient, valid_job_id: int):
assert r.json()[0]["UserPriority"] == 2


def test_bad_patch_metadata(normal_user_client: TestClient, valid_job_id: int):
# Arrange
r = normal_user_client.post(
"/api/jobs/search",
json={
"search": [
{
"parameter": "JobID",
"operator": "eq",
"value": valid_job_id,
}
],
"parameters": ["LoggingInfo"],
},
)

assert r.status_code == 200, r.json()
for j in r.json():
assert j["JobID"] == valid_job_id
assert j["Status"] == JobStatus.RECEIVED.value
assert j["MinorStatus"] == "Job accepted"
assert j["ApplicationStatus"] == "Unknown"

# Act
hbt = str(datetime.now(timezone.utc))
r = normal_user_client.patch(
"/api/jobs/metadata",
json={
valid_job_id: {
"UserPriority": 2,
"Heartbeattime": hbt,
# set a parameter
"JobType": "VerySpecialIndeed",
}
},
)

# Assert
assert r.status_code == 422, r.text

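Under the relaxed models the payload above now validates, which is why this test is deleted; a sketch of the new expectation (the exact success status code is assumed):

```python
# Sketch: the previously "bad" payload (note the miscased "Heartbeattime")
# now validates, and the unknown key is routed to the job parameters store.
r = normal_user_client.patch(
    "/api/jobs/metadata",
    json={
        valid_job_id: {
            "UserPriority": 2,
            "Heartbeattime": hbt,  # unrecognized key, kept as a parameter
            "JobType": "VerySpecialIndeed",
        }
    },
)
assert r.is_success, r.text  # assumed 2xx success code
```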

def test_diracx_476(normal_user_client: TestClient, valid_job_id: int):
"""Test fix for https://github.com/DIRACGrid/diracx/issues/476."""
inner_payload = {"Status": JobStatus.FAILED.value, "MinorStatus": "Payload failed"}
@@ -460,6 +460,9 @@ def __init__(self, *, job_id: int, command: str, arguments: Optional[str] = None
class JobMetaData(_serialization.Model):
"""A model that combines both JobAttributes and JobParameters.

:ivar additional_properties: Unmatched properties from the message are deserialized to this
collection.
:vartype additional_properties: dict[str, any]
:ivar timestamp: Timestamp.
:vartype timestamp: ~datetime.datetime
:ivar cpu_normalization_factor: Cpunormalizationfactor.
@@ -533,6 +536,7 @@ class JobMetaData(_serialization.Model):
"""

_attribute_map = {
"additional_properties": {"key": "", "type": "{object}"},
"timestamp": {"key": "timestamp", "type": "iso-8601"},
"cpu_normalization_factor": {"key": "CPUNormalizationFactor", "type": "int"},
"norm_cpu_time_s": {"key": "NormCPUTime(s)", "type": "int"},
@@ -573,6 +577,7 @@ class JobMetaData(_serialization.Model):
def __init__( # pylint: disable=too-many-locals
self,
*,
additional_properties: Optional[Dict[str, Any]] = None,
timestamp: Optional[datetime.datetime] = None,
cpu_normalization_factor: Optional[int] = None,
norm_cpu_time_s: Optional[int] = None,
@@ -611,6 +616,9 @@ def __init__( # pylint: disable=too-many-locals
**kwargs: Any
) -> None:
"""
:keyword additional_properties: Unmatched properties from the message are deserialized to this
collection.
:paramtype additional_properties: dict[str, any]
:keyword timestamp: Timestamp.
:paramtype timestamp: ~datetime.datetime
:keyword cpu_normalization_factor: Cpunormalizationfactor.
@@ -683,6 +691,7 @@ def __init__( # pylint: disable=too-many-locals
:paramtype accounted_flag: ~_generated.models.JobMetaDataAccountedFlag
"""
super().__init__(**kwargs)
self.additional_properties = additional_properties
self.timestamp = timestamp
self.cpu_normalization_factor = cpu_normalization_factor
self.norm_cpu_time_s = norm_cpu_time_s