Skip to content

Commit 34a78d2

Browse files
author
corochann
committed
test for set_up_predictor
1 parent 85468da commit 34a78d2

File tree

3 files changed

+6
-3
lines changed

3 files changed

+6
-3
lines changed

chainer_chemistry/models/prediction/set_up_predictor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def set_up_predictor(
135135
print('Training a GNN_FiLM predictor...')
136136
conv = GNNFiLM(
137137
out_dim=n_unit,
138-
hidden_dim=n_unit,
138+
hidden_channels=n_unit,
139139
n_update_layers=conv_layers,
140140
n_edge_types=5,
141141
**conv_kwargs)

tests/functions_tests/loss/test_mean_absolute_error.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,8 @@ def check_backward_ignore_nan(inputs):
118118
def func(x0, x1):
119119
return chainer_chemistry.functions.mean_absolute_error(x0, x1,
120120
ignore_nan=True)
121-
gradient_check.check_backward(func, (x0_data, x2_data), None, eps=1e-2)
121+
gradient_check.check_backward(func, (x0_data, x2_data), None, eps=1e-2,
122+
atol=1e-3, rtol=1e-3)
122123

123124

124125
def check_backward_ignore_nan_with_nonnan_value(inputs):

tests/models_tests/prediction_tests/test_set_up_predictor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from chainer_chemistry.models.ggnn import GGNN
77
from chainer_chemistry.models.gin import GIN
8+
from chainer_chemistry.models.gnn_film import GNNFiLM
89
from chainer_chemistry.models.nfp import NFP
910
from chainer_chemistry.models.prediction.graph_conv_predictor import GraphConvPredictor # NOQA
1011
from chainer_chemistry.models.prediction.set_up_predictor import set_up_predictor # NOQA
@@ -39,7 +40,8 @@ def models_dict():
3940
'nfp_gwm': NFP_GWM,
4041
'ggnn_gwm': GGNN_GWM,
4142
'rsgcn_gwm': RSGCN_GWM,
42-
'gin_gwm': GIN_GWM
43+
'gin_gwm': GIN_GWM,
44+
'gnnfilm': GNNFiLM
4345
}
4446

4547

0 commit comments

Comments
 (0)