Skip to content

Commit ab6f5aa

Browse files
authored
Merge pull request #406 from mj-will/fix-reparam-resuming
Fix reparam resuming
2 parents 4d53599 + 96a7a64 commit ab6f5aa

File tree

8 files changed

+113
-67
lines changed

8 files changed

+113
-67
lines changed

nessai/flowmodel/base.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ class FlowModel:
4141
not specified, the current working directory is used.
4242
"""
4343

44-
model_config = None
4544
noise_scale = None
4645
noise_type = None
4746
model: BaseFlow = None

nessai/proposal/flowproposal.py

Lines changed: 40 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,8 @@ def __init__(
170170
super(FlowProposal, self).__init__(model)
171171
logger.debug("Initialising FlowProposal")
172172

173-
self._x_dtype = False
174-
self._x_prime_dtype = False
173+
self._x_dtype = None
174+
self._x_prime_dtype = None
175175
self._draw_func = None
176176
self._populate_dist = None
177177

@@ -193,7 +193,7 @@ def __init__(
193193
self._reparameterisation = None
194194
self.rescaling_set = False
195195
self.use_x_prime_prior = False
196-
self.update_bounds = False
196+
self.should_update_reparameterisations = False
197197
self.accumulate_weights = accumulate_weights
198198

199199
self.reparameterisations = reparameterisations
@@ -279,7 +279,7 @@ def rescaled_dims(self):
279279
@property
280280
def x_dtype(self):
281281
"""Return the dtype for the x space"""
282-
if not self._x_dtype:
282+
if self._x_dtype is None:
283283
self._x_dtype = get_dtype(
284284
self.parameters, config.livepoints.default_float_dtype
285285
)
@@ -288,7 +288,7 @@ def x_dtype(self):
288288
@property
289289
def x_prime_dtype(self):
290290
"""Return the dtype for the x prime space"""
291-
if not self._x_prime_dtype:
291+
if self._x_prime_dtype is None:
292292
self._x_prime_dtype = get_dtype(
293293
self.prime_parameters, config.livepoints.default_float_dtype
294294
)
@@ -484,31 +484,37 @@ def update_flow_config(self):
484484
"""Update the flow configuration dictionary."""
485485
self.flow_config["n_inputs"] = self.rescaled_dims
486486

487-
def initialise(self):
487+
def initialise(self, resumed: bool = False) -> None:
488488
"""
489489
Initialise the proposal class.
490490
491491
This includes:
492492
* Setting up the rescaling
493493
* Verifying the rescaling is invertible
494494
* Initialising the FlowModel
495+
496+
Parameters
497+
----------
498+
resumed : bool
499+
Indicates if the proposal is being initialised after being resumed
500+
or not. When true, the reparameterisations will not be
501+
reinitialised.
495502
"""
496503
if not os.path.exists(self.output):
497504
os.makedirs(self.output, exist_ok=True)
498505

499-
self._x_dtype = False
500-
self._x_prime_dtype = False
501-
502-
self.set_rescaling()
503-
self.verify_rescaling()
504-
if self.expansion_fraction and self.expansion_fraction is not None:
505-
logger.info("Overwriting fuzz factor with expansion fraction")
506-
self.fuzz = (1 + self.expansion_fraction) ** (
507-
1 / self.rescaled_dims
508-
)
509-
logger.info(f"New fuzz factor: {self.fuzz}")
506+
# Initialise if not resuming or resuming but initialised is False
507+
if not resumed or not self.initialised:
508+
self.set_rescaling()
509+
self.verify_rescaling()
510+
if self.expansion_fraction and self.expansion_fraction is not None:
511+
logger.info("Overwriting fuzz factor with expansion fraction")
512+
self.fuzz = (1 + self.expansion_fraction) ** (
513+
1 / self.rescaled_dims
514+
)
515+
logger.info(f"New fuzz factor: {self.fuzz}")
510516

511-
self.configure_constant_volume()
517+
self.configure_constant_volume()
512518
self.update_flow_config()
513519
self.flow = self._FlowModelClass(
514520
flow_config=self.flow_config,
@@ -697,10 +703,10 @@ def configure_reparameterisations(self, reparameterisations):
697703
r = FallbackClass(parameters=other_params, **fallback_kwargs)
698704
self._reparameterisation.add_reparameterisations(r)
699705

700-
if any(r._update_bounds for r in self._reparameterisation.values()):
701-
self.update_bounds = True
706+
if any(r._update for r in self._reparameterisation.values()):
707+
self.should_update_reparameterisations = True
702708
else:
703-
self.update_bounds = False
709+
self.should_update_reparameterisations = False
704710

705711
if self._reparameterisation.has_prime_prior:
706712
self.use_x_prime_prior = True
@@ -739,6 +745,17 @@ def rescaled_names(self):
739745
)
740746
return self.prime_parameters
741747

748+
@property
749+
def update_bounds(self):
750+
warn(
751+
(
752+
"`update_bounds` is deprecated, use "
753+
"`should_update_reparameterisations` instead."
754+
),
755+
FutureWarning,
756+
)
757+
return self.should_update_reparameterisations
758+
742759
def set_rescaling(self):
743760
"""
744761
Set function and parameter names for rescaling
@@ -1633,14 +1650,13 @@ def resume(self, model, flow_config, weights_file=None):
16331650
"""
16341651
super().resume(model)
16351652
self.flow_config = flow_config
1636-
self._reparameterisation = None
16371653

16381654
if self.mask is not None:
16391655
if isinstance(self.mask, list):
16401656
m = np.array(self.mask)
16411657
self.flow_config["mask"] = m
16421658

1643-
self.initialise()
1659+
self.initialise(resumed=True)
16441660

16451661
if weights_file is None:
16461662
weights_file = self.weights_file
@@ -1652,12 +1668,6 @@ def resume(self, model, flow_config, weights_file=None):
16521668
else:
16531669
logger.warning("Could not reload weights for flow")
16541670

1655-
if self.update_bounds:
1656-
if self.training_data is not None:
1657-
self.check_state(self.training_data)
1658-
elif self.training_data is None and self.training_count:
1659-
raise RuntimeError("Could not resume! Missing training data!")
1660-
16611671
def reset(self):
16621672
"""Reset the proposal"""
16631673
self.indices = []
@@ -1673,6 +1683,7 @@ def reset(self):
16731683
self.acceptance = []
16741684
self._draw_func = None
16751685
self._populate_dist = None
1686+
self._reparameterisation.reset()
16761687

16771688
def __getstate__(self):
16781689
state = self.__dict__.copy()
@@ -1696,7 +1707,6 @@ def __getstate__(self):
16961707

16971708
# user provides model and config for resume
16981709
# flow can be reconstructed from resume
1699-
del state["_reparameterisation"]
17001710
del state["model"]
17011711
del state["_flow_config"]
17021712
del state["flow"]

nessai/reparameterisations/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class Reparameterisation:
2121
Prior bounds for the parameter(s).
2222
"""
2323

24-
_update_bounds = False
24+
_update = False
2525
has_prior = False
2626
has_prime_prior = False
2727
requires_prime_prior = False

nessai/reparameterisations/rescale.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ def __init__(
294294
for p in self.boundary_inversion:
295295
self.rescale_bounds[p] = [0, 1]
296296

297-
self._update_bounds = update_bounds if not detect_edges else True
297+
self._update = update_bounds if not detect_edges else True
298298
self.detect_edges = detect_edges
299299
if self.boundary_inversion:
300300
self._edges = {n: None for n in self.parameters}
@@ -404,7 +404,7 @@ def configure_post_rescaling(self, post_rescaling):
404404
self.has_prime_prior = False
405405

406406
if post_rescaling in ["logit", "log"]:
407-
if self._update_bounds:
407+
if self._update:
408408
raise RuntimeError(
409409
"Cannot use log or logit with update bounds"
410410
)
@@ -608,7 +608,7 @@ def set_bounds(self, prior_bounds):
608608

609609
def update_bounds(self, x):
610610
"""Update the bounds used for the reparameterisation"""
611-
if self._update_bounds:
611+
if self._update:
612612
self.bounds = {
613613
p: [
614614
self.pre_rescaling(np.min(x[p]))[0] - self.offsets[p],

tests/test_deprecation_warnings.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,12 @@ def test_flowproposal_rescaled_names_warning():
4646
proposal.prime_parameters = ["x"]
4747
with pytest.warns(FutureWarning, match=r"`rescaled_names` is deprecated"):
4848
assert FlowProposal.rescaled_names.__get__(proposal) == ["x"]
49+
50+
51+
def test_flowproposal_update_bounds_warning():
52+
from nessai.proposal import FlowProposal
53+
54+
proposal = create_autospec(FlowProposal)
55+
proposal.should_update_reparameterisations = True
56+
with pytest.warns(FutureWarning, match=r"`update_bounds` is deprecated"):
57+
assert FlowProposal.update_bounds.__get__(proposal) is True

tests/test_proposal/test_flowproposal/test_flowproposal_init_resume.py

Lines changed: 5 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def test_init_use_default_reparams(model, proposal, value, expected):
3030
def test_initialise(tmpdir, proposal, ef, fuzz):
3131
"""Test the initialise method"""
3232
p = tmpdir.mkdir("test")
33+
proposal.initialised = False
3334
proposal.output = os.path.join(p, "output")
3435
proposal.rescaled_dims = 2
3536
proposal.expansion_fraction = ef
@@ -44,7 +45,7 @@ def test_initialise(tmpdir, proposal, ef, fuzz):
4445
fm.initialise = MagicMock()
4546
proposal._FlowModelClass = MagicMock(new=fm)
4647

47-
FlowProposal.initialise(proposal)
48+
FlowProposal.initialise(proposal, resumed=False)
4849

4950
proposal.set_rescaling.assert_called_once()
5051
proposal.verify_rescaling.assert_called_once()
@@ -95,33 +96,6 @@ def test_resume_w_weights(osexist, proposal):
9596
proposal.flow.reload_weights.assert_called_once_with("weights.pt")
9697

9798

98-
@pytest.mark.parametrize("data", [[1], None])
99-
@pytest.mark.parametrize("count", [0, 1])
100-
def test_resume_w_update_bounds(proposal, data, count):
101-
"""Test the resume method with update bounds"""
102-
proposal.initialise = MagicMock()
103-
proposal.flow = MagicMock()
104-
proposal.mask = None
105-
proposal.update_bounds = True
106-
proposal.weights_file = None
107-
proposal.training_data = data
108-
proposal.training_count = count
109-
proposal.check_state = MagicMock()
110-
model = MagicMock()
111-
if count and data is None:
112-
with pytest.raises(RuntimeError) as excinfo, patch(
113-
"nessai.proposal.base.Proposal.resume"
114-
) as mock:
115-
FlowProposal.resume(proposal, model, {})
116-
assert "Could not resume" in str(excinfo.value)
117-
else:
118-
with patch("nessai.proposal.base.Proposal.resume") as mock:
119-
FlowProposal.resume(proposal, model, {})
120-
if data:
121-
proposal.check_state.assert_called_once_with(data)
122-
mock.assert_called_once_with(model)
123-
124-
12599
@pytest.mark.parametrize("populated", [False, True])
126100
@pytest.mark.parametrize("mask", [None, [1, 0]])
127101
def test_get_state(proposal, populated, mask):
@@ -149,7 +123,6 @@ def test_get_state(proposal, populated, mask):
149123
assert state["initialised"] is False
150124
assert state["weights_file"] == "file"
151125
assert state["mask"] is mask
152-
assert "_reparameterisation" not in state
153126
assert "model" not in state
154127
assert "flow" not in state
155128
assert "_flow_config" not in state
@@ -214,6 +187,7 @@ def test_reset(proposal):
214187
proposal.samples = 2
215188
proposal.populated = True
216189
proposal.populated_count = 10
190+
proposal._reparameterisation = MagicMock()
217191
FlowProposal.reset(proposal)
218192
assert proposal.x is None
219193
assert proposal.samples is None
@@ -222,6 +196,7 @@ def test_reset(proposal):
222196
assert proposal.r is np.nan
223197
assert proposal.alt_dist is None
224198
assert proposal._checked_population
199+
proposal._reparameterisation.reset.assert_called_once()
225200

226201

227202
@pytest.mark.timeout(60)
@@ -280,6 +255,7 @@ def test_reset_integration(tmpdir, model, latent_prior):
280255
# attributes that should be different
281256
ignore = [
282257
"population_time",
258+
"_reparameterisation",
283259
]
284260

285261
d1 = proposal.__getstate__()

tests/test_reparameterisations/test_rescale_to_bounds.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ def test_post_rescaling_with_str(reparam, rescaling):
331331
332332
Also test the config for the logit
333333
"""
334-
reparam._update_bounds = False
334+
reparam._update = False
335335
reparam.parameters = ["x"]
336336
from nessai.utils.rescaling import rescaling_functions
337337

@@ -346,7 +346,7 @@ def test_post_rescaling_with_str(reparam, rescaling):
346346
@pytest.mark.parametrize("rescaling", ["log", "logit"])
347347
def test_post_rescaling_with_logit_update_bounds(reparam, rescaling):
348348
"""Assert an error is raised if using logit and update bounds"""
349-
reparam._update_bounds = True
349+
reparam._update = True
350350
with pytest.raises(
351351
RuntimeError, match=r"Cannot use log or logit with update bounds"
352352
):
@@ -370,7 +370,7 @@ def test_post_rescaling_invalid_input(reparam):
370370
def test_update_bounds_disabled(reparam, caplog):
371371
"""Assert nothing happens in _update_bounds is False"""
372372
caplog.set_level("DEBUG")
373-
reparam._update_bounds = False
373+
reparam._update = False
374374
RescaleToBounds.update_bounds(reparam, [0, 1])
375375
assert "Update bounds not enabled" in str(caplog.text)
376376

tests/test_sampling/test_standard_sampling.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,58 @@ def test_resume_fallback_reparameterisation(tmpdir, model, flow_config):
379379
)
380380

381381

382+
@pytest.mark.slow_integration_test
383+
def test_resume_reparameterisation_values(tmpdir, model, flow_config):
384+
"""
385+
Assert the scale and shift values are correct.
386+
"""
387+
from nessai.reparameterisations.rescale import ScaleAndShift
388+
389+
output = str(tmpdir.mkdir("resume"))
390+
fp = FlowSampler(
391+
model,
392+
output=output,
393+
resume=True,
394+
nlive=100,
395+
plot=False,
396+
flow_config=flow_config,
397+
training_frequency=10,
398+
maximum_uninformed=9,
399+
reparameterisations="z-score",
400+
checkpoint_on_iteration=True,
401+
checkpoint_interval=5,
402+
seed=1234,
403+
max_iteration=11,
404+
poolsize=10,
405+
)
406+
fp.run()
407+
408+
reparam = fp.ns._flow_proposal._reparameterisation
409+
reparam = next(iter(fp.ns._flow_proposal._reparameterisation.values()))
410+
assert isinstance(reparam, ScaleAndShift)
411+
original_scale = reparam.scale
412+
original_shift = reparam.shift
413+
assert os.path.exists(os.path.join(output, "nested_sampler_resume.pkl"))
414+
415+
fp = FlowSampler(
416+
model,
417+
output=output,
418+
resume=True,
419+
flow_config=flow_config,
420+
)
421+
assert fp.ns.iteration == 11
422+
fp.ns.max_iteration = 21
423+
reparam = next(iter(fp.ns._flow_proposal._reparameterisation.values()))
424+
assert isinstance(reparam, ScaleAndShift)
425+
assert reparam.scale == original_scale
426+
assert reparam.shift == original_shift
427+
fp.run()
428+
assert fp.ns.iteration == 21
429+
assert os.path.exists(
430+
os.path.join(output, "nested_sampler_resume.pkl.old")
431+
)
432+
433+
382434
@pytest.mark.slow_integration_test
383435
def test_sampling_with_infinite_prior_bounds(tmpdir):
384436
"""

0 commit comments

Comments
 (0)