Skip to content

Commit 2f37df5

Browse files
authored
Merge pull request #8278 from fstagni/90_refactoring_WMS
[9.0] more WMS refactoring
2 parents 0d9a164 + d235a78 commit 2f37df5

File tree

9 files changed

+71
-186
lines changed

9 files changed

+71
-186
lines changed

src/DIRAC/Interfaces/API/Dirac.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from DIRAC.WorkloadManagementSystem.Client.JobMonitoringClient import JobMonitoringClient
4747
from DIRAC.WorkloadManagementSystem.Client.SandboxStoreClient import SandboxStoreClient
4848
from DIRAC.WorkloadManagementSystem.Client.WMSClient import WMSClient
49+
from DIRAC.WorkloadManagementSystem.Utilities.jobAdministration import _filterJobStateTransition
4950

5051

5152
def parseArguments(args):
@@ -1450,10 +1451,13 @@ def deleteJob(self, jobID):
14501451
# Remove any job IDs that can't change to the Killed or Deleted states
14511452
filteredJobs = set()
14521453
for filterState in (JobStatus.KILLED, JobStatus.DELETED):
1453-
filterRes = JobStatus.filterJobStateTransition(jobIDs, filterState)
1454-
if not filterRes["OK"]:
1455-
return filterRes
1456-
filteredJobs.update(filterRes["Value"])
1454+
# get a dictionary of jobID:status
1455+
res = JobMonitoringClient().getJobsStatus(jobIDs)
1456+
if not res["OK"]:
1457+
return res
1458+
js = {k: v["Status"] for k, v in res["Value"].items()}
1459+
# then filter
1460+
filteredJobs.update(_filterJobStateTransition(js, filterState))
14571461

14581462
return WMSClient(useCertificates=self.useCertificates).deleteJob(list(filteredJobs))
14591463

@@ -1480,11 +1484,13 @@ def rescheduleJob(self, jobID):
14801484
return ret
14811485
jobIDs = ret["Value"]
14821486

1483-
# Remove any job IDs that can't change to the rescheduled state
1484-
filterRes = JobStatus.filterJobStateTransition(jobIDs, JobStatus.RESCHEDULED)
1485-
if not filterRes["OK"]:
1486-
return filterRes
1487-
jobIDsToReschedule = filterRes["Value"]
1487+
# get a dictionary of jobID:status
1488+
res = JobMonitoringClient().getJobsStatus(jobIDs)
1489+
if not res["OK"]:
1490+
return res
1491+
js = {k: v["Status"] for k, v in res["Value"].items()}
1492+
# then filter
1493+
jobIDsToReschedule = _filterJobStateTransition(js, JobStatus.RESCHEDULED)
14881494

14891495
return WMSClient(useCertificates=self.useCertificates).rescheduleJob(jobIDsToReschedule)
14901496

@@ -1510,10 +1516,13 @@ def killJob(self, jobID):
15101516
# Remove any job IDs that can't change to the Killed or Deleted states
15111517
filteredJobs = set()
15121518
for filterState in (JobStatus.KILLED, JobStatus.DELETED):
1513-
filterRes = JobStatus.filterJobStateTransition(jobIDs, filterState)
1514-
if not filterRes["OK"]:
1515-
return filterRes
1516-
filteredJobs.update(filterRes["Value"])
1519+
# get a dictionary of jobID:status
1520+
res = JobMonitoringClient().getJobsStatus(jobIDs)
1521+
if not res["OK"]:
1522+
return res
1523+
js = {k: v["Status"] for k, v in res["Value"].items()}
1524+
# then filter
1525+
filteredJobs.update(_filterJobStateTransition(js, filterState))
15171526

15181527
return WMSClient(useCertificates=self.useCertificates).killJob(list(filteredJobs))
15191528

src/DIRAC/WorkloadManagementSystem/Agent/test/Test_Agent_StalledJobAgent.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def sja(mocker):
2828
mocker.patch("DIRAC.WorkloadManagementSystem.Agent.StalledJobAgent.rescheduleJobs", return_value=MagicMock())
2929
mocker.patch("DIRAC.WorkloadManagementSystem.Agent.StalledJobAgent.PilotAgentsDB", return_value=MagicMock())
3030
mocker.patch("DIRAC.WorkloadManagementSystem.Agent.StalledJobAgent.getJobParameters", return_value=MagicMock())
31+
mocker.patch("DIRAC.WorkloadManagementSystem.Agent.StalledJobAgent.kill_delete_jobs", return_value=MagicMock())
3132

3233
stalledJobAgent = StalledJobAgent()
3334
stalledJobAgent._AgentModule__configDefaults = mockAM

src/DIRAC/WorkloadManagementSystem/Client/JobStatus.py

Lines changed: 0 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,7 @@
22
This module contains constants and lists for the possible job states.
33
"""
44

5-
from DIRAC import gLogger, S_OK, S_ERROR
65
from DIRAC.Core.Utilities.StateMachine import State, StateMachine
7-
from DIRAC.Core.Utilities.Decorators import deprecated
8-
9-
from DIRAC.WorkloadManagementSystem.Client.JobMonitoringClient import JobMonitoringClient
10-
116

127
#:
138
SUBMITTING = "Submitting"
@@ -98,57 +93,3 @@ def __init__(self, state):
9893
RECEIVED: State(1, [SCOUTING, CHECKING, STAGING, WAITING, FAILED, DELETED, KILLED], defState=RECEIVED),
9994
SUBMITTING: State(0, [RECEIVED, CHECKING, DELETED, KILLED], defState=SUBMITTING), # initial state
10095
}
101-
102-
103-
@deprecated("Use filterJobStateTransition instead")
104-
def checkJobStateTransition(jobID, candidateState, currentStatus=None, jobMonitoringClient=None):
105-
"""Utility to check if a job state transition is allowed"""
106-
if not currentStatus:
107-
if not jobMonitoringClient:
108-
from DIRAC.WorkloadManagementSystem.Client.JobMonitoringClient import JobMonitoringClient
109-
110-
jobMonitoringClient = JobMonitoringClient()
111-
112-
res = jobMonitoringClient.getJobsStatus(jobID)
113-
if not res["OK"]:
114-
return res
115-
try:
116-
currentStatus = res["Value"][jobID]["Status"]
117-
except KeyError:
118-
return S_ERROR("Job does not exist")
119-
120-
res = JobsStateMachine(currentStatus).getNextState(candidateState)
121-
if not res["OK"]:
122-
return res
123-
124-
# If the JobsStateMachine does not accept the candidate, return an ERROR
125-
if candidateState != res["Value"]:
126-
gLogger.error(
127-
"Job Status Error",
128-
f"{jobID} can't move from {currentStatus} to {candidateState}",
129-
)
130-
return S_ERROR("Job state transition not allowed")
131-
return S_OK()
132-
133-
134-
def filterJobStateTransition(jobIDs, candidateState):
135-
"""Given a list of jobIDs, return a list that are allowed to transition
136-
to the given candidate state.
137-
"""
138-
allowedJobs = []
139-
140-
if not isinstance(jobIDs, list):
141-
jobIDs = [jobIDs]
142-
143-
res = JobMonitoringClient().getJobsStatus(jobIDs)
144-
if not res["OK"]:
145-
return res
146-
147-
for jobID in jobIDs:
148-
if jobID in res["Value"]:
149-
curState = res["Value"][jobID]["Status"]
150-
stateRes = JobsStateMachine(curState).getNextState(candidateState)
151-
if stateRes["OK"]:
152-
if stateRes["Value"] == candidateState:
153-
allowedJobs.append(jobID)
154-
return S_OK(allowedJobs)

src/DIRAC/WorkloadManagementSystem/FutureClient/JobStateUpdateClient.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -83,12 +83,6 @@ def setJobAttribute(self, jobID: str | int, attribute: str, value: str):
8383
else:
8484
return api.jobs.patch_metadata({jobID: {attribute: value}})
8585

86-
@stripValueIfOK
87-
@convertToReturnValue
88-
def setJobFlag(self, jobID: str | int, flag: str):
89-
with DiracXClient() as api:
90-
api.jobs.patch_metadata({jobID: {flag: True}})
91-
9286
@stripValueIfOK
9387
@convertToReturnValue
9488
def setJobParameter(self, jobID: str | int, name: str, value: str):
@@ -151,12 +145,6 @@ def setJobsParameter(self, jobsParameterDict: dict):
151145
updates = {job_id: {k: v} for job_id, (k, v) in jobsParameterDict.items()}
152146
api.jobs.patch_metadata(updates)
153147

154-
@stripValueIfOK
155-
@convertToReturnValue
156-
def unsetJobFlag(self, jobID: str | int, flag: str):
157-
with DiracXClient() as api:
158-
api.jobs.patch_metadata({jobID: {flag: False}})
159-
160148
@stripValueIfOK
161149
@convertToReturnValue
162150
def updateJobFromStager(self, jobID: str | int, status: str):

src/DIRAC/WorkloadManagementSystem/Service/JobStateUpdateHandler.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -128,22 +128,6 @@ def export_setJobSite(cls, jobID, site):
128128
"""Allows the site attribute to be set for a job specified by its jobID."""
129129
return cls.jobDB.setJobAttribute(int(jobID), "Site", site)
130130

131-
###########################################################################
132-
types_setJobFlag = [[str, int], str]
133-
134-
@classmethod
135-
def export_setJobFlag(cls, jobID, flag):
136-
"""Set job flag for job with jobID"""
137-
return cls.jobDB.setJobAttribute(int(jobID), flag, "True")
138-
139-
###########################################################################
140-
types_unsetJobFlag = [[str, int], str]
141-
142-
@classmethod
143-
def export_unsetJobFlag(cls, jobID, flag):
144-
"""Unset job flag for job with jobID"""
145-
return cls.jobDB.setJobAttribute(int(jobID), flag, "False")
146-
147131
###########################################################################
148132
types_setJobApplicationStatus = [[str, int], str, str]
149133

src/DIRAC/WorkloadManagementSystem/Utilities/jobAdministration.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
from DIRAC import S_ERROR, S_OK, gLogger
22
from DIRAC.StorageManagementSystem.DB.StorageManagementDB import StorageManagementDB
33
from DIRAC.WorkloadManagementSystem.Client import JobStatus
4-
from DIRAC.WorkloadManagementSystem.Client.JobStatus import filterJobStateTransition
54
from DIRAC.WorkloadManagementSystem.DB.JobDB import JobDB
65
from DIRAC.WorkloadManagementSystem.DB.PilotAgentsDB import PilotAgentsDB
76
from DIRAC.WorkloadManagementSystem.DB.TaskQueueDB import TaskQueueDB
8-
from DIRAC.WorkloadManagementSystem.Service.JobPolicy import RIGHT_KILL, RIGHT_DELETE
7+
from DIRAC.WorkloadManagementSystem.Service.JobPolicy import RIGHT_DELETE, RIGHT_KILL
98

109

1110
def _deleteJob(jobID, force=False):
@@ -79,18 +78,17 @@ def kill_delete_jobs(right, validJobList, nonauthJobList=[], force=False):
7978
killJobList = []
8079
deleteJobList = []
8180
if validJobList:
81+
result = JobDB().getJobsAttributes(killJobList, ["Status"])
82+
if not result["OK"]:
83+
return result
84+
jobStates = result["Value"]
85+
8286
# Get the jobs allowed to transition to the Killed state
83-
filterRes = filterJobStateTransition(validJobList, JobStatus.KILLED)
84-
if not filterRes["OK"]:
85-
return filterRes
86-
killJobList.extend(filterRes["Value"])
87+
killJobList.extend(_filterJobStateTransition(jobStates, JobStatus.KILLED))
8788

8889
if right == RIGHT_DELETE:
8990
# Get the jobs allowed to transition to the Deleted state
90-
filterRes = filterJobStateTransition(validJobList, JobStatus.DELETED)
91-
if not filterRes["OK"]:
92-
return filterRes
93-
deleteJobList.extend(filterRes["Value"])
91+
deleteJobList.extend(_filterJobStateTransition(jobStates, JobStatus.DELETED))
9492

9593
for jobID in killJobList:
9694
result = _killJob(jobID, force=force)
@@ -103,10 +101,7 @@ def kill_delete_jobs(right, validJobList, nonauthJobList=[], force=False):
103101
badIDs.append(jobID)
104102

105103
# Look for jobs that are in the Staging state to send kill signal to the stager
106-
result = JobDB().getJobsAttributes(killJobList, ["Status"])
107-
if not result["OK"]:
108-
return result
109-
stagingJobList = [jobID for jobID, sDict in result["Value"].items() if sDict["Status"] == JobStatus.STAGING]
104+
stagingJobList = [jobID for jobID, sDict in jobStates.items() if sDict["Status"] == JobStatus.STAGING]
110105

111106
if stagingJobList:
112107
stagerDB = StorageManagementDB()
@@ -127,3 +122,17 @@ def kill_delete_jobs(right, validJobList, nonauthJobList=[], force=False):
127122

128123
jobsList = killJobList if right == RIGHT_KILL else deleteJobList
129124
return S_OK(jobsList)
125+
126+
127+
def _filterJobStateTransition(jobStates, candidateState):
    """Given a dictionary of jobs states,
    return a list of jobs that are allowed to transition to the given candidate state.

    :param dict jobStates: current state of each job, keyed by jobID. Each value is either
        a dictionary holding the state under the "Status" key, or the state itself.
    :param str candidateState: state the jobs are asked to move to
    :return: list of the jobIDs for which JobsStateMachine gives candidateState as the next state
    """
    allowedJobs = []

    for jobID, jobState in jobStates.items():
        # Callers pass either {jobID: {"Status": status}} or the flattened {jobID: status};
        # indexing a plain status string with "Status" would raise a TypeError.
        currentState = jobState["Status"] if isinstance(jobState, dict) else jobState
        stateRes = JobStatus.JobsStateMachine(currentState).getNextState(candidateState)
        # Keep the job only when the state machine returns the candidate state unchanged
        if stateRes["OK"] and stateRes["Value"] == candidateState:
            allowedJobs.append(jobID)
    return allowedJobs

src/DIRAC/WorkloadManagementSystem/Utilities/test/Test_JobAdministration.py

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,31 +10,19 @@
1010

1111

1212
@pytest.mark.parametrize(
13-
"jobIDs_list, right, filtered_jobs, expected_res, expected_value",
13+
"jobIDs_list, right",
1414
[
15-
([], "Kill", [], True, []),
16-
([], "Delete", [], True, []),
17-
(1, "Kill", [], True, []),
18-
(1, "Kill", [1], True, [1]),
19-
([1, 2], "Kill", [], True, []),
20-
([1, 2], "Kill", [1], True, [1]),
21-
(1, "Kill", [1], True, [1]),
22-
([1, 2], "Kill", [1], True, [1]),
23-
([1, 2], "Kill", [2], True, [2]),
24-
([1, 2], "Kill", [], True, []),
25-
([1, 2], "Kill", [1, 2], True, [1, 2]),
15+
([], "Kill"),
16+
([], "Delete"),
17+
(1, "Kill"),
18+
([1, 2], "Kill"),
2619
],
2720
)
28-
def test___kill_delete_jobs(mocker, jobIDs_list, right, filtered_jobs, expected_res, expected_value):
21+
def test___kill_delete_jobs(mocker, jobIDs_list, right):
2922
mocker.patch("DIRAC.WorkloadManagementSystem.Utilities.jobAdministration.JobDB", MagicMock())
3023
mocker.patch("DIRAC.WorkloadManagementSystem.Utilities.jobAdministration.TaskQueueDB", MagicMock())
3124
mocker.patch("DIRAC.WorkloadManagementSystem.Utilities.jobAdministration.PilotAgentsDB", MagicMock())
3225
mocker.patch("DIRAC.WorkloadManagementSystem.Utilities.jobAdministration.StorageManagementDB", MagicMock())
33-
mocker.patch(
34-
"DIRAC.WorkloadManagementSystem.Utilities.jobAdministration.filterJobStateTransition",
35-
return_value={"OK": True, "Value": filtered_jobs},
36-
)
3726

3827
res = kill_delete_jobs(right, jobIDs_list)
39-
assert res["OK"] == expected_res
40-
assert res["Value"] == expected_value
28+
assert res["OK"]

tests/Integration/FutureClient/WorkloadManagement/Test_JobStateUpdate.py

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -142,42 +142,6 @@ def test_setJobAttribute(monkeypatch, example_jobids, args):
142142
assert result[example_jobids[0]] == result[example_jobids[1]]
143143

144144

145-
def test_setJobFlag(monkeypatch, example_jobids):
146-
# JobStateUpdateClient().setJobFlag(jobID: str | int, flag: str)
147-
assert returnValueOrRaise(JobMonitoringClient().getJobAttribute(example_jobids[0], "AccountedFlag")) == "False"
148-
assert returnValueOrRaise(JobMonitoringClient().getJobAttribute(example_jobids[1], "AccountedFlag")) == "False"
149-
150-
method = JobStateUpdateClient().setJobFlag
151-
test_func1 = partial(method, example_jobids[0], "AccountedFlag")
152-
test_func2 = partial(method, example_jobids[1], "AccountedFlag")
153-
compare_results2(monkeypatch, test_func1, test_func2)
154-
155-
assert returnValueOrRaise(JobMonitoringClient().getJobAttribute(example_jobids[0], "AccountedFlag")) == "True"
156-
assert returnValueOrRaise(JobMonitoringClient().getJobAttribute(example_jobids[1], "AccountedFlag")) == "True"
157-
158-
159-
def test_unsetJobFlag(monkeypatch, example_jobids):
160-
# JobStateUpdateClient().unsetJobFlag(jobID: str | int, flag: str)
161-
assert returnValueOrRaise(JobMonitoringClient().getJobAttribute(example_jobids[0], "AccountedFlag")) == "False"
162-
assert returnValueOrRaise(JobMonitoringClient().getJobAttribute(example_jobids[1], "AccountedFlag")) == "False"
163-
164-
method = JobStateUpdateClient().setJobFlag
165-
test_func1 = partial(method, example_jobids[0], "AccountedFlag")
166-
test_func2 = partial(method, example_jobids[1], "AccountedFlag")
167-
compare_results2(monkeypatch, test_func1, test_func2)
168-
169-
assert returnValueOrRaise(JobMonitoringClient().getJobAttribute(example_jobids[0], "AccountedFlag")) == "True"
170-
assert returnValueOrRaise(JobMonitoringClient().getJobAttribute(example_jobids[1], "AccountedFlag")) == "True"
171-
172-
method = JobStateUpdateClient().unsetJobFlag
173-
test_func1 = partial(method, example_jobids[0], "AccountedFlag")
174-
test_func2 = partial(method, example_jobids[1], "AccountedFlag")
175-
compare_results2(monkeypatch, test_func1, test_func2)
176-
177-
assert returnValueOrRaise(JobMonitoringClient().getJobAttribute(example_jobids[0], "AccountedFlag")) == "False"
178-
assert returnValueOrRaise(JobMonitoringClient().getJobAttribute(example_jobids[1], "AccountedFlag")) == "False"
179-
180-
181145
@pytest.mark.parametrize(
182146
"args",
183147
[

0 commit comments

Comments
 (0)