From 610036738e2edb8ace98b393d86366ba316884ea Mon Sep 17 00:00:00 2001 From: elmartinj Date: Thu, 18 Jun 2026 10:45:32 -0600 Subject: [PATCH] Add clean cache option to forecaster --- tests/test_forecaster.py | 17 +++++++++++++++++ timecopilot/forecaster.py | 21 +++++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/tests/test_forecaster.py b/tests/test_forecaster.py index a1a1aba1..fae035f0 100644 --- a/tests/test_forecaster.py +++ b/tests/test_forecaster.py @@ -173,3 +173,20 @@ def test_mixed_models_unique_aliases(): # This should not raise an error forecaster = TimeCopilotForecaster(models=[model1, model2, model3]) assert len(forecaster.models) == 3 + + +def test_clean_cache_runs_after_each_model(monkeypatch, models): + calls = [] + + monkeypatch.setattr( + TimeCopilotForecaster, + "_clean_model_cache", + staticmethod(lambda: calls.append("cleaned")), + ) + + df = generate_series(n_series=1, freq="D", min_length=10) + forecaster = TimeCopilotForecaster(models=models, clean_cache=True) + + forecaster.forecast(df=df, h=2, freq="D") + + assert calls == ["cleaned"] * len(models) diff --git a/timecopilot/forecaster.py b/timecopilot/forecaster.py index 1bf8eeca..93159af8 100644 --- a/timecopilot/forecaster.py +++ b/timecopilot/forecaster.py @@ -47,6 +47,7 @@ def __init__( self, models: list[Forecaster], fallback_model: Forecaster | None = None, + clean_cache: bool = False, ): """ Initialize the TimeCopilotForecaster with a list of models. @@ -59,6 +60,9 @@ def __init__( compatible signatures. fallback_model (Forecaster, optional): Model to use as a fallback when a model fails. + clean_cache (bool): + If True, run Python garbage collection and clear the CUDA cache + after each model call. Useful for memory-heavy foundation models. Raises: ValueError: If duplicate model aliases are found in the models list. @@ -66,6 +70,7 @@ def __init__( self._validate_unique_aliases(models) self.models = models self.fallback_model = fallback_model + self.clean_cache = clean_cache def _validate_unique_aliases(self, models: list[Forecaster]) -> None: """ @@ -88,6 +93,20 @@ def _validate_unique_aliases(self, models: list[Forecaster]) -> None: f"same class." ) + @staticmethod + def _clean_model_cache() -> None: + """Release temporary Python and CUDA memory between model calls.""" + import gc + + gc.collect() + try: + import torch + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + except ImportError: + pass + @staticmethod def _is_distributed_df(df: AnyDataFrame) -> bool: """ @@ -155,6 +174,8 @@ def _call_models( # (the initial model) res_df_model = res_df_model.drop(columns=["y"]) res_df = res_df.merge(res_df_model, on=merge_on, how="left") + if self.clean_cache: + self._clean_model_cache() return res_df def _forecast_pandas(