Commit ef5ad38

Make activation function configurable as argument (#60)
* 🔧 Make activation function configurable as argument
* 🔨 Fixing format and lint code
* 🔨 Fixing unsorted imports
* 🔨 Remove extra whitespace
* 🔨 Fixing format and lint code
* 📦 Support string-based activations for all blocks
* ➖ Remove unused package
* 🔨 Remove extra whitespace
* 🔨 Replace lambda assignments with defs in blocks to satisfy E731
* 🔨 Fixing ruff format
* ♻️ Use class returned from dictionary directly so that we can pass arguments when we instantiate it. This is necessary to reproduce previous behaviour.

---------

Co-authored-by: James Robinson <[email protected]>
1 parent 2626f03 commit ef5ad38

File tree

4 files changed: +37 -7 lines changed

ice_station_zebra/models/common/activations.py

Lines changed: 11 additions & 0 deletions

@@ -0,0 +1,11 @@
+from torch import nn
+
+ACTIVATION_FROM_NAME: dict[str, type[nn.Module]] = {
+    "ReLU": nn.ReLU,
+    "LeakyReLU": nn.LeakyReLU,
+    "ELU": nn.ELU,
+    "GELU": nn.GELU,
+    "SiLU": nn.SiLU,
+    "Sigmoid": nn.Sigmoid,
+    "Tanh": nn.Tanh,
+}
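The last commit bullet above is the key design point: the dictionary maps a name to the activation class itself, so the class can still be instantiated with whatever arguments are needed. A minimal sketch of that lookup pattern, assuming the new file is importable as ice_station_zebra.models.common.activations (a path inferred from the relative imports in the block modules below):

from torch import nn

# Assumed import path, inferred from the relative imports in the block modules.
from ice_station_zebra.models.common.activations import ACTIVATION_FROM_NAME

# Look up the class by name, then instantiate it with arguments.
# Passing inplace=True reproduces the previous hard-coded nn.ReLU(inplace=True).
activation_layer = ACTIVATION_FROM_NAME["ReLU"]
layer = activation_layer(inplace=True)
assert isinstance(layer, nn.ReLU)

# Other activations can receive their own constructor arguments,
# e.g. the slope of LeakyReLU.
leaky = ACTIVATION_FROM_NAME["LeakyReLU"](negative_slope=0.1, inplace=True)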

ice_station_zebra/models/common/bottleneckblock.py

Lines changed: 7 additions & 2 deletions

@@ -1,5 +1,7 @@
 from torch import Tensor, nn
 
+from .activations import ACTIVATION_FROM_NAME
+
 
 class BottleneckBlock(nn.Module):
     def __init__(
@@ -8,19 +10,22 @@ def __init__(
         out_channels: int,
         *,
         filter_size: int,
+        activation: str = "ReLU",
     ) -> None:
         """Initialise a BottleneckBlock."""
         super().__init__()
 
+        activation_layer = ACTIVATION_FROM_NAME[activation]
+
         self.model = nn.Sequential(
             nn.Conv2d(
                 in_channels, out_channels, kernel_size=filter_size, padding="same"
             ),
-            nn.ReLU(inplace=True),
+            activation_layer(inplace=True),
             nn.Conv2d(
                 out_channels, out_channels, kernel_size=filter_size, padding="same"
             ),
-            nn.ReLU(inplace=True),
+            activation_layer(inplace=True),
             nn.BatchNorm2d(num_features=out_channels),
         )
 
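A hedged usage sketch for the updated BottleneckBlock. The import path, tensor sizes, and the assumption that the block's forward (not shown in this diff) simply applies self.model are illustrative, not taken from the repository:

import torch

# Assumed import path for the block shown above.
from ice_station_zebra.models.common.bottleneckblock import BottleneckBlock

# Omitting the argument keeps the previous behaviour (ReLU); a string picks another activation.
default_block = BottleneckBlock(16, 32, filter_size=3)
elu_block = BottleneckBlock(16, 32, filter_size=3, activation="ELU")

x = torch.randn(1, 16, 64, 64)  # illustrative batch of 64x64 feature maps
y = elu_block(x)
print(y.shape)  # padding="same" preserves spatial size: torch.Size([1, 32, 64, 64])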

ice_station_zebra/models/common/convblock.py

Lines changed: 8 additions & 3 deletions

@@ -1,5 +1,7 @@
 from torch import Tensor, nn
 
+from .activations import ACTIVATION_FROM_NAME
+
 
 class ConvBlock(nn.Module):
     def __init__(
@@ -9,19 +11,22 @@ def __init__(
         *,
         filter_size: int,
         final: bool = False,
+        activation: str = "ReLU",
     ) -> None:
         """Initialise a ConvBlock."""
         super().__init__()
 
+        activation_layer = ACTIVATION_FROM_NAME[activation]
+
         layers = [
             nn.Conv2d(
                 in_channels, out_channels, kernel_size=filter_size, padding="same"
             ),
-            nn.ReLU(inplace=True),
+            activation_layer(inplace=True),
             nn.Conv2d(
                 out_channels, out_channels, kernel_size=filter_size, padding="same"
             ),
-            nn.ReLU(inplace=True),
+            activation_layer(inplace=True),
         ]
         if final:
             layers += [
@@ -31,7 +36,7 @@ def __init__(
                     kernel_size=filter_size,
                     padding="same",
                 ),
-                nn.ReLU(inplace=True),
+                activation_layer(inplace=True),
             ]
 
         else:
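A short sketch of how the new argument behaves on ConvBlock: the default reproduces the old ReLU behaviour, any key from ACTIVATION_FROM_NAME is accepted, and an unknown name fails at construction time because the lookup is a plain dict access. The import path is an assumption:

# Assumed import path for the block shown above.
from ice_station_zebra.models.common.convblock import ConvBlock

# Omitting the argument keeps the previous behaviour (ReLU after each convolution).
default_block = ConvBlock(8, 16, filter_size=3)

# Any key from ACTIVATION_FROM_NAME can be passed; the same activation is reused
# after every convolution, including the extra one added when final=True.
leaky_block = ConvBlock(8, 16, filter_size=3, final=True, activation="LeakyReLU")

# A name that is not in the dictionary fails at construction time with a KeyError.
try:
    ConvBlock(8, 16, filter_size=3, activation="Swish")
except KeyError as err:
    print(f"Unknown activation: {err}")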

ice_station_zebra/models/common/upconvblock.py

Lines changed: 11 additions & 2 deletions

@@ -1,15 +1,24 @@
 from torch import Tensor, nn
 
+from .activations import ACTIVATION_FROM_NAME
+
 
 class UpconvBlock(nn.Module):
-    def __init__(self, in_channels: int, out_channels: int) -> None:
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        activation: str = "ReLU",
+    ) -> None:
         """Initialise an UpconvBlock."""
         super().__init__()
 
+        activation_layer = ACTIVATION_FROM_NAME[activation]
+
         self.model = nn.Sequential(
             nn.Upsample(scale_factor=2, mode="nearest"),
             nn.Conv2d(in_channels, out_channels, kernel_size=2, padding="same"),
-            nn.ReLU(inplace=True),
+            activation_layer(inplace=True),
         )
 
     def forward(self, x: Tensor) -> Tensor:
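A hedged sketch for UpconvBlock, assuming the same import convention as above and that forward (whose body is truncated here) applies self.model; sizes are illustrative:

import torch

# Assumed import path for the block shown above.
from ice_station_zebra.models.common.upconvblock import UpconvBlock

# Existing call sites keep working (activation defaults to "ReLU");
# new code can select a different activation by name.
up = UpconvBlock(32, 16, activation="SiLU")

x = torch.randn(1, 32, 8, 8)  # illustrative input
y = up(x)
# nn.Upsample(scale_factor=2) doubles the spatial size, the "same"-padded
# conv keeps it: torch.Size([1, 16, 16, 16])
print(y.shape)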
