
Commit 2626f03

jemrobinson and IFenton authored

Linting fixes (#66)
* 🚨 Add list of disabled ruff checks
* 🚨 Fix builtin-variable-shadowing
* 🚨 Fix unnecessary-generator-set
* 🚨 Fix unnecessary-comprehension
* 🚨 Fix commented-out-code
* 🚨 Fix manual-list-comprehension
* 🚨 Fix manual-from-import
* 🚨 Fix missing-type-function-argument
* 🚨 Fix missing-return-type-undocumented-public-function
* 🚨 Fix missing-return-type-special-method
* 🚨 Fix unused-method-argument
* 🚨 Fix call-datetime-without-tzinfo
* 🚨 Fix call-datetime-now-without-tzinfo
* 🚨 Fix call-datetime-strptime-without-zone
* 🚨 Fix f-string-in-exception
* 🚨 Fix boolean-type-hint-positional-argument
* 🚨 Fix logging-f-string
* 🚨 Fix implicit-namespace-package
* 🚨 Fix magic-value-comparison
* 🚨 Fix pytest-raises-too-broad
* 🚨 Fix os-getcwd
* 🚨 Fix unnecessary-assign
* 🚨 Fix mutable-class-default
* 🚨 Fix unsorted-dunder-all
* 🚨 Fix in-dict-keys
* 🚨 Fix typing-only-first-party-import
* 🚨 Fix typing-only-third-party-import
* 🚨 Fix non-pep585-annotation
* 🚨 Fix missing-terminal-punctuation
* 🚨 Fix multi-line-summary-first-line
* 🚨 Fix too-many-arguments
* 🚨 Fix new-line-after-last-paragraph
* 🚨 Fix missing-blank-line-after-last-section
* 🚨 Fix undocumented-public-function
* 🚨 Fix undocumented-magic-method
* 🚨 Fix undocumented-public-init
* 👽 Allow unlocalised timezones in tests to avoid pydata/xarray#8653
* 🏷️ Accept generic callable in hydra_adaptor
* 🏷️ Accept Sequence in CLI output check tests
* 🎨 Tidying the common models __init__.py

---------

Co-authored-by: Isabel Fenton <[email protected]>
1 parent a9d632d commit 2626f03

36 files changed: +295 -174 lines

ice_station_zebra/callbacks/metric_summary_callback.py
Lines changed: 12 additions & 7 deletions

@@ -15,24 +15,25 @@
 class MetricSummaryCallback(Callback):
     """A callback to summarise metrics during evaluation."""
 
-    def __init__(self, average_loss: bool = True) -> None:
+    def __init__(self, *, average_loss: bool = True) -> None:
         """Summarise metrics during evaluation.
 
         Args:
             average_loss: Whether to log average loss
+
         """
         self.metrics: dict[str, list[float]] = {}
         if average_loss:
             self.metrics["average_loss"] = []
 
     def on_test_batch_end(
         self,
-        trainer: Trainer,
-        module: LightningModule,
+        _trainer: Trainer,
+        _module: LightningModule,
         outputs: Tensor | Mapping[str, Any] | None,
-        batch: Any,
-        batch_idx: int,
-        dataloader_idx: int = 0,
+        _batch: Any,  # noqa: ANN401
+        _batch_idx: int,
+        _dataloader_idx: int = 0,
     ) -> None:
         """Called when the test batch ends."""
         if not isinstance(outputs, ModelTestOutput):
@@ -43,7 +44,11 @@ def on_test_batch_end(
         if "average_loss" in self.metrics:
             self.metrics["average_loss"].append(outputs.loss.item())
 
-    def on_test_epoch_end(self, trainer: Trainer, module: LightningModule) -> None:
+    def on_test_epoch_end(
+        self,
+        trainer: Trainer,
+        _module: LightningModule,
+    ) -> None:
         """Called at the end of the test epoch."""
         # Post-process accumulated metrics into a single value
         metrics_: dict[str, float] = {}
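
Editor's note: two ruff patterns recur in this file. FBT002 makes boolean arguments keyword-only so call sites stay readable, and ARG002 marks arguments that exist only to satisfy a hook signature with a leading underscore. A minimal sketch of both (class and method names are hypothetical, not from this repository):

# Hedged sketch of the FBT002 / ARG002 fixes above.
class Summary:
    def __init__(self, *, average_loss: bool = True) -> None:
        # The bare `*` forces callers to write Summary(average_loss=False);
        # the opaque positional Summary(False) becomes a TypeError.
        self.metrics: dict[str, list[float]] = {}
        if average_loss:
            self.metrics["average_loss"] = []

    def on_batch_end(self, loss: float, _batch_idx: int) -> None:
        # _batch_idx is required by the caller's hook signature but unused;
        # the underscore marks that as deliberate for readers and the linter.
        if "average_loss" in self.metrics:
            self.metrics["average_loss"].append(loss)

summary = Summary(average_loss=True)
summary.on_batch_end(0.25, 0)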

ice_station_zebra/callbacks/plotting_callback.py
Lines changed: 10 additions & 6 deletions

@@ -1,30 +1,33 @@
 import logging
 from collections.abc import Mapping, Sequence
-from typing import Any
+from typing import TYPE_CHECKING, Any
 
 from lightning import LightningModule, Trainer
 from lightning.pytorch import Callback
 from torch import Tensor
-from torch.utils.data import DataLoader
 
 from ice_station_zebra.data_loaders import CombinedDataset
 from ice_station_zebra.types import ModelTestOutput
 from ice_station_zebra.visualisations import plot_sic_comparison
 
+if TYPE_CHECKING:
+    from torch.utils.data import DataLoader
+
 logger = logging.getLogger(__name__)
 
 
 class PlottingCallback(Callback):
     """A callback to create plots during evaluation."""
 
     def __init__(
-        self, frequency: int = 10, plot_sea_ice_concentration: bool = True
+        self, *, frequency: int = 10, plot_sea_ice_concentration: bool = True
     ) -> None:
         """Create plots during evaluation.
 
         Args:
             frequency: Create a new plot every `frequency` batches.
             plot_sea_ice_concentration: Whether to plot sea ice concentration.
+
         """
         super().__init__()
         self.frequency = frequency
@@ -35,9 +38,9 @@ def __init__(
     def on_test_batch_end(
         self,
         trainer: Trainer,
-        module: LightningModule,
+        _module: LightningModule,
         outputs: Tensor | Mapping[str, Any] | None,
-        batch: Any,
+        _batch: Any,  # noqa: ANN401
         batch_idx: int,
         dataloader_idx: int = 0,
     ) -> None:
@@ -82,5 +85,6 @@ def on_test_batch_end(
                 lightning_logger.log_image(key=key, images=image_list)
             else:
                 logger.debug(
-                    f"Logger {lightning_logger.name} does not support logging images."
+                    "Logger %s does not support logging images.",
+                    lightning_logger.name,
                )
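
The import move above is ruff's typing-only-third-party-import fix (TC002): DataLoader is referenced only in annotations, so it can be imported solely for the type checker. A standalone sketch of the pattern, assuming `from __future__ import annotations` to defer annotation evaluation (the repository may rely on a different mechanism):

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Resolved by mypy/pyright only; torch is never imported at runtime.
    from torch.utils.data import DataLoader

def describe(loader: DataLoader) -> str:
    # With deferred evaluation the annotation stays a string at runtime,
    # so defining this function does not require torch to be installed.
    return f"DataLoader with batch_size={loader.batch_size}"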

ice_station_zebra/callbacks/unconditional_checkpoint.py
Lines changed: 2 additions & 1 deletion

@@ -7,11 +7,12 @@
 class UnconditionalCheckpoint(Callback):
     """A callback to summarise metrics during evaluation."""
 
-    def __init__(self, on_train_end: bool = False) -> None:
+    def __init__(self, *, on_train_end: bool = False) -> None:
         """Save a checkpoint unconditionally.
 
         Args:
             on_train_end: Whether to save a checkpoint at the end of training
+
         """
         super().__init__()
         self.impl = ModelCheckpoint()

ice_station_zebra/cli/hydra.py
Lines changed: 3 additions & 2 deletions

@@ -11,14 +11,15 @@
 RetType = TypeVar("RetType")
 
 
-def hydra_adaptor(function) -> Callable[Param, RetType]:
-    """Replace a function that takes a Hydra config with one that takes string arguments
+def hydra_adaptor(function: Callable) -> Callable[Param, RetType]:
+    """Replace a function that takes a Hydra config with one that takes string arguments.
 
     Args:
         function: Callable(*args, config: DictConfig, **kwargs)
 
     Returns:
         Callable(*args, config_name: str, **kwargs, overrides: list[str])
+
     """
 
     def wrapper(
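
hydra_adaptor now annotates its argument as a generic Callable, and the file already defines Param and RetType, which points at the fully generic decorator shape. A self-contained sketch of that shape using ParamSpec (the `logged` decorator is hypothetical, not the repository's wrapper):

import functools
from collections.abc import Callable
from typing import ParamSpec, TypeVar

Param = ParamSpec("Param")
RetType = TypeVar("RetType")

def logged(function: Callable[Param, RetType]) -> Callable[Param, RetType]:
    """Wrap a function while preserving its exact signature for type checkers."""

    @functools.wraps(function)
    def wrapper(*args: Param.args, **kwargs: Param.kwargs) -> RetType:
        print(f"calling {function.__name__}")
        return function(*args, **kwargs)

    return wrapper

@logged
def add(x: int, y: int) -> int:
    return x + y

assert add(2, 3) == 5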

ice_station_zebra/data_loaders/combined_dataset.py
Lines changed: 16 additions & 10 deletions

@@ -1,5 +1,5 @@
 from collections.abc import Sequence
-from datetime import datetime
+from datetime import UTC, datetime
 
 import numpy as np
 from torch.utils.data import Dataset
@@ -18,7 +18,12 @@ def __init__(
         n_forecast_steps: int = 1,
         n_history_steps: int = 1,
     ) -> None:
-        """Constructor"""
+        """Initialise a combined dataset from a sequence of ZebraDatasets.
+
+        One of the datasets must be the target and all must have the same frequency. The
+        number of forecast and history steps can be set, which will determine the shape
+        of the NTCHW tensors returned by __getitem__.
+        """
         super().__init__()
 
         # Store the number of forecast and history steps
@@ -27,10 +32,10 @@ def __init__(
 
         # Define target and input datasets
         self.target = next(ds for ds in datasets if ds.name == target)
-        self.inputs = [ds for ds in datasets]
+        self.inputs = list(datasets)
 
         # Require that all datasets have the same frequency
-        frequencies = sorted(set(ds.dataset.frequency for ds in datasets))
+        frequencies = sorted({ds.dataset.frequency for ds in datasets})
         if len(frequencies) != 1:
             msg = f"Cannot combine datasets with different frequencies: {frequencies}."
             raise ValueError(msg)
@@ -57,17 +62,18 @@ def __init__(
         ]
 
     def __len__(self) -> int:
-        """Return the total length of the dataset"""
+        """Return the total length of the dataset."""
         return len(self.available_dates)
 
     def __getitem__(self, idx: int) -> dict[str, ArrayTCHW]:
-        """Return the data for a single timestep as a dictionary
+        """Return the data for a single timestep as a dictionary.
 
         Returns:
             A dictionary with dataset names as keys and a numpy array as the value.
             The shape of each array is:
             - input datasets: [n_history_steps, C_input_k, H_input_k, W_input_k]
             - target dataset: [n_forecast_steps, C_target, H_target, W_target]
+
         """
         return {
             ds.name: ds.get_tchw(self.get_history_steps(self.available_dates[idx]))
@@ -79,9 +85,9 @@ def __getitem__(self, idx: int) -> dict[str, ArrayTCHW]:
         }
 
     def date_from_index(self, idx: int) -> datetime:
-        """Return the date of the timestep"""
+        """Return the date of the timestep."""
         np_datetime = self.available_dates[idx]
-        return datetime.strptime(str(np_datetime), r"%Y-%m-%dT%H:%M:%S")
+        return datetime.strptime(str(np_datetime), r"%Y-%m-%dT%H:%M:%S").astimezone(UTC)
 
     def get_forecast_steps(self, start_date: np.datetime64) -> list[np.datetime64]:
         """Return list of consecutive forecast dates for a given start date."""
@@ -99,7 +105,7 @@ def get_history_steps(self, start_date: np.datetime64) -> list[np.datetime64]:
     @property
     def end_date(self) -> np.datetime64:
         """Return the end date of the dataset."""
-        end_date = set(dataset.end_date for dataset in self.inputs)
+        end_date = {dataset.end_date for dataset in self.inputs}
         if len(end_date) != 1:
             msg = f"Datasets have {len(end_date)} different end dates"
             raise ValueError(msg)
@@ -108,7 +114,7 @@ def end_date(self) -> np.datetime64:
     @property
     def start_date(self) -> np.datetime64:
         """Return the start date of the dataset."""
-        start_date = set(dataset.start_date for dataset in self.inputs)
+        start_date = {dataset.start_date for dataset in self.inputs}
         if len(start_date) != 1:
             msg = f"Datasets have {len(start_date)} different start dates"
             raise ValueError(msg)
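
A hedged caveat on the strptime change above (standalone sketch, not repository code): .astimezone(UTC) treats a naive datetime as *local* time and converts it, whereas .replace(tzinfo=UTC) stamps the parsed wall-clock time as UTC. Both satisfy ruff's DTZ007, but they give different instants on machines whose local zone is not UTC:

from datetime import UTC, datetime

naive = datetime.strptime("2020-01-02T03:04:05", "%Y-%m-%dT%H:%M:%S")  # noqa: DTZ007

labelled = naive.replace(tzinfo=UTC)   # wall-clock time declared to be UTC
converted = naive.astimezone(UTC)      # assumed local time, then shifted to UTC

print(labelled.isoformat())   # always 2020-01-02T03:04:05+00:00
print(converted.isoformat())  # differs unless the machine's zone is UTC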

ice_station_zebra/data_loaders/zebra_data_module.py
Lines changed: 18 additions & 12 deletions

@@ -17,6 +17,11 @@
 
 class ZebraDataModule(LightningDataModule):
     def __init__(self, config: DictConfig) -> None:
+        """Initialise a ZebraDataModule from a config.
+
+        The config specifies all datasets used and how to group them. Data splits are
+        also determined from the config, and the appropriate data loaders are created.
+        """
         super().__init__()
 
         # Load paths
@@ -30,14 +35,15 @@ def __init__(self, config: DictConfig) -> None:
                     self.base_path / "data" / "anemoi" / f"{dataset['name']}.zarr"
                 ).resolve()
             )
-        logger.info(f"Found {len(self.dataset_groups)} dataset_groups")
-        for dataset_group in self.dataset_groups.keys():
-            logger.debug(f"... {dataset_group}")
+        logger.info("Found %d dataset_groups.", len(self.dataset_groups))
+        for dataset_group in self.dataset_groups:
+            logger.debug("... %s.", dataset_group)
 
         # Check prediction target
         self.predict_target = config["predict"]["dataset_group"]
         if self.predict_target not in self.dataset_groups:
-            raise ValueError(f"Could not find prediction target {self.predict_target}")
+            msg = f"Could not find prediction target {self.predict_target}."
+            raise ValueError(msg)
 
         # Set periods for train, validation, and test
         self.batch_size = int(config["split"]["batch_size"])
@@ -67,7 +73,7 @@ def __init__(self, config: DictConfig) -> None:
 
     @cached_property
     def input_spaces(self) -> list[DataSpace]:
-        """Return the data space for each input"""
+        """Return the data space for each input."""
         return [
             ZebraDataset(name, paths).space
             for name, paths in self.dataset_groups.items()
@@ -76,7 +82,7 @@ def input_spaces(self) -> list[DataSpace]:
 
     @cached_property
     def output_space(self) -> DataSpace:
-        """Return the data space of the desired output"""
+        """Return the data space of the desired output."""
         return next(
             ZebraDataset(name, paths).space
             for name, paths in self.dataset_groups.items()
@@ -86,7 +92,7 @@ def output_space(self) -> DataSpace:
     def train_dataloader(
         self,
     ) -> DataLoader[dict[str, ArrayTCHW]]:
-        """Construct train dataloader"""
+        """Construct train dataloader."""
         dataset = CombinedDataset(
             [
                 ZebraDataset(
@@ -102,7 +108,7 @@ def train_dataloader(
             target=self.predict_target,
         )
         logger.info(
-            "Loaded training dataset with %d samples between %s and %s",
+            "Loaded training dataset with %d samples between %s and %s.",
             len(dataset),
             dataset.start_date,
             dataset.end_date,
@@ -112,7 +118,7 @@
     def val_dataloader(
         self,
     ) -> DataLoader[dict[str, ArrayTCHW]]:
-        """Construct validation dataloader"""
+        """Construct validation dataloader."""
        dataset = CombinedDataset(
             [
                 ZebraDataset(
@@ -128,7 +134,7 @@ def val_dataloader(
             target=self.predict_target,
         )
         logger.info(
-            "Loaded validation dataset with %d samples between %s and %s",
+            "Loaded validation dataset with %d samples between %s and %s.",
             len(dataset),
             dataset.start_date,
             dataset.end_date,
@@ -138,7 +144,7 @@
     def test_dataloader(
         self,
     ) -> DataLoader[dict[str, ArrayTCHW]]:
-        """Construct test dataloader"""
+        """Construct test dataloader."""
         dataset = CombinedDataset(
             [
                 ZebraDataset(
@@ -154,7 +160,7 @@ def test_dataloader(
             target=self.predict_target,
         )
         logger.info(
-            "Loaded test dataset with %d samples between %s and %s",
+            "Loaded test dataset with %d samples between %s and %s.",
             len(dataset),
             dataset.start_date,
             dataset.end_date,
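
The msg-then-raise shape introduced above is ruff's f-string-in-exception fix (EM102): keeping the f-string out of the raise statement means the traceback line reads `raise ValueError(msg)` rather than echoing the whole expression. A minimal sketch (function name hypothetical):

def require_target(target: str, groups: dict[str, object]) -> None:
    # EM102: build the message first, then raise with the variable.
    if target not in groups:
        msg = f"Could not find prediction target {target}."
        raise ValueError(msg)

require_target("osisaf", {"osisaf": object()})  # no error raised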

ice_station_zebra/data_loaders/zebra_dataset.py
Lines changed: 6 additions & 6 deletions

@@ -19,10 +19,10 @@ def __init__(
         start: str | None = None,
         end: str | None = None,
     ) -> None:
-        """A dataset for use by Zebra
+        """A dataset for use by Zebra.
 
-        Dataset shape is: time; variables; ensembles; position
-        We reshape each time point to: variables; pos_x; pos_y
+        The underlying Anemoi dataset has shape [T; C; ensembles; position].
+        We reshape this to CHW before returning.
         """
         super().__init__()
         self._cache: LRUCache = LRUCache(maxsize=128)
@@ -67,15 +67,15 @@ def start_date(self) -> np.datetime64:
         return self.dataset.start_date
 
     def __len__(self) -> int:
-        """Return the total length of the dataset"""
+        """Return the total length of the dataset."""
         return len(self.dataset)
 
     def __getitem__(self, idx: int) -> ArrayCHW:
-        """Return the data for a single timestep in [C, H, W] format"""
+        """Return the data for a single timestep in [C, H, W] format."""
         return self.dataset[idx].reshape(self.space.chw)
 
     def get_tchw(self, dates: Sequence[np.datetime64]) -> ArrayTCHW:
-        """Return the data for a series of timesteps in [T, C, H, W] format"""
+        """Return the data for a series of timesteps in [T, C, H, W] format."""
         return np.stack(
             [self[self.index_from_date(target_date)] for target_date in dates], axis=0
         )
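
For orientation, get_tchw stacks one CHW array per requested date along a new leading time axis. A minimal numpy sketch with made-up shapes:

import numpy as np

chw_frames = [np.zeros((3, 4, 5)) for _ in range(7)]  # 7 dates, each [C, H, W]
tchw = np.stack(chw_frames, axis=0)                   # -> [T, C, H, W]
assert tchw.shape == (7, 3, 4, 5)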

ice_station_zebra/data_processors/cli.py
Lines changed: 5 additions & 5 deletions

@@ -10,26 +10,26 @@
 # Create the typer app
 datasets_cli = typer.Typer(help="Manage datasets")
 
-log = logging.getLogger(__name__)
+logger = logging.getLogger(__name__)
 
 
 @datasets_cli.command("create")
 @hydra_adaptor
 def create(config: DictConfig) -> None:
-    """Create all datasets"""
+    """Create all datasets."""
     factory = ZebraDataProcessorFactory(config)
     for dataset in factory.datasets:
-        log.info(f"Working on {dataset.name}")
+        logger.info("Working on %s.", dataset.name)
         dataset.create()
 
 
 @datasets_cli.command("inspect")
 @hydra_adaptor
 def inspect(config: DictConfig) -> None:
-    """Inspect all datasets"""
+    """Inspect all datasets."""
     factory = ZebraDataProcessorFactory(config)
     for dataset in factory.datasets:
-        log.info(f"Working on {dataset.name}")
+        logger.info("Working on %s.", dataset.name)
         dataset.inspect()
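
The logging change here (and throughout the files above) is ruff's logging-f-string fix (G004): pass format arguments to the logger rather than pre-building an f-string, so interpolation only happens when the record is actually emitted. Standalone sketch:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

name = "era5"  # hypothetical dataset name
logger.info("Working on %s.", name)   # formatted only if INFO is enabled
# logger.info(f"Working on {name}.")  # G004: always builds the string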

ice_station_zebra/data_processors/preprocessors/__init__.py
Lines changed: 1 addition & 1 deletion

@@ -2,7 +2,7 @@
 from .icenet_sic import IceNetSICPreprocessor
 
 __all__ = [
-    "IceNetSICPreprocessor",
     "IPreprocessor",
+    "IceNetSICPreprocessor",
     "NullPreprocessor",
 ]
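
The reordering above is the unsorted-dunder-all fix; the sort is case-sensitive, so the ASCII uppercase "P" in "IPreprocessor" sorts before the lowercase "c" in "IceNetSICPreprocessor". A one-line check:

names = ["IceNetSICPreprocessor", "IPreprocessor", "NullPreprocessor"]
print(sorted(names))  # ['IPreprocessor', 'IceNetSICPreprocessor', 'NullPreprocessor']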
