Skip to content

Commit 2a45d68

Browse files
authored
Merge pull request #386 from corochann/graph-film-v2
Graph film v2
2 parents e912e2e + f7e4e0f commit 2a45d68

File tree

19 files changed

+455
-20
lines changed

19 files changed

+455
-20
lines changed

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ The following graph convolutional neural networks are currently supported:
111111
- GIN: Graph Isomorphism Networks [17]
112112
- MPNN: Message Passing Neural Networks [3]
113113
- Set2Set [19]
114+
- GNN-FiLM: Graph Neural Networks with Feature-wise Linear Modulation [20]
114115

115116
We test supporting the brand-new Graph Warp Module (GWM) [18]-attached models for:
116117
- NFP ('nfp_gwm')
@@ -202,3 +203,6 @@ papers. Use the library at your own risk.
202203
[18] K. Ishiguro, S. Maeda, and M. Koyama, ``Graph Warp Module: an Auxiliary Module for Boosting the Power of Graph Neural Networks'', arXiv:1902.01020 [cs.LG], 2019.
203204

204205
[19] Oriol Vinyals, Samy Bengio, Manjunath Kudlur. Order Matters: Sequence to sequence for sets. *arXiv preprint arXiv:1511.06391*, 2015.
206+
.
207+
208+
[20] Marc Brockschmidt, ``GNN-FiLM: Graph Neural Networks with Feature-wise Linear Modulation'', arXiv:1906.12192 [cs.ML], 2019.

chainer_chemistry/__init__.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,16 @@
33
from chainer_chemistry import dataset # NOQA
44
try:
55
from chainer_chemistry import datasets # NOQA
6-
except ImportError:
7-
warnings.warn(
8-
'A module chainer_chemistry.datasets was not imported, '
9-
'probably because RDKit is not installed. '
10-
'To install RDKit, please follow instruction in '
11-
'https://github.com/pfnet-research/chainer-chemistry#installation.',
12-
UserWarning)
6+
except ImportError as e:
7+
if 'rdkit' in e.msg:
8+
warnings.warn(
9+
'A module chainer_chemistry.datasets was not imported, '
10+
'probably because RDKit is not installed. '
11+
'To install RDKit, please follow instruction in '
12+
'https://github.com/pfnet-research/chainer-chemistry#installation.',
13+
UserWarning)
14+
else:
15+
raise(e)
1316
from chainer_chemistry import functions # NOQA
1417
from chainer_chemistry import links # NOQA
1518
from chainer_chemistry import models # NOQA

chainer_chemistry/dataset/preprocessors/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from chainer_chemistry.dataset.preprocessors.ecfp_preprocessor import ECFPPreprocessor # NOQA
1010
from chainer_chemistry.dataset.preprocessors.relgat_preprocessor import RelGATPreprocessor # NOQA
1111
from chainer_chemistry.dataset.preprocessors.ggnn_preprocessor import GGNNPreprocessor # NOQA
12+
from chainer_chemistry.dataset.preprocessors.gnnfilm_preprocessor import GNNFiLMPreprocessor # NOQA
1213
from chainer_chemistry.dataset.preprocessors.gin_preprocessor import GINPreprocessor # NOQA
1314
from chainer_chemistry.dataset.preprocessors.gwm_preprocessor import GGNNGWMPreprocessor # NOQA
1415
from chainer_chemistry.dataset.preprocessors.gwm_preprocessor import GINGWMPreprocessor # NOQA
@@ -35,4 +36,5 @@
3536
'rsgcn': RSGCNPreprocessor,
3637
'rsgcn_gwm': RSGCNGWMPreprocessor,
3738
'relgat': RelGATPreprocessor,
39+
'gnnfilm': GNNFiLMPreprocessor,
3840
}

chainer_chemistry/dataset/preprocessors/common.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,8 @@ def construct_adj_matrix(mol, out_size=-1, self_connection=True):
118118
return adj_array
119119

120120

121-
def construct_discrete_edge_matrix(mol, out_size=-1):
121+
def construct_discrete_edge_matrix(mol, out_size=-1,
122+
add_self_connection_channel=False):
122123
"""Returns the edge-type dependent adjacency matrix of the given molecule.
123124
124125
Args:
@@ -129,6 +130,9 @@ def construct_discrete_edge_matrix(mol, out_size=-1):
129130
in the input molecules. In that case, the adjacent
130131
matrix is expanded and zeros are padded to right
131132
columns and bottom rows.
133+
add_self_connection_channel (bool): Add self connection or not.
134+
If True, adjacency matrix whose diagonal element filled with 1
135+
is added to last channel.
132136
133137
Returns:
134138
adj_array (numpy.ndarray): The adjacent matrix of the input molecule.
@@ -150,7 +154,10 @@ def construct_discrete_edge_matrix(mol, out_size=-1):
150154
raise ValueError(
151155
'out_size {} is smaller than number of atoms in mol {}'
152156
.format(out_size, N))
153-
adjs = numpy.zeros((4, size, size), dtype=numpy.float32)
157+
if add_self_connection_channel:
158+
adjs = numpy.zeros((5, size, size), dtype=numpy.float32)
159+
else:
160+
adjs = numpy.zeros((4, size, size), dtype=numpy.float32)
154161

155162
bond_type_to_channel = {
156163
Chem.BondType.SINGLE: 0,
@@ -165,6 +172,8 @@ def construct_discrete_edge_matrix(mol, out_size=-1):
165172
j = bond.GetEndAtomIdx()
166173
adjs[ch, i, j] = 1.0
167174
adjs[ch, j, i] = 1.0
175+
if add_self_connection_channel:
176+
adjs[-1] = numpy.eye(N)
168177
return adjs
169178

170179

chainer_chemistry/dataset/preprocessors/gin_preprocessor.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ def __init__(self, max_atoms=-1, out_size=-1, add_Hs=False):
3232
self.max_atoms = max_atoms
3333
self.out_size = out_size
3434

35-
3635
def get_input_features(self, mol):
3736
"""get input features
3837
@@ -45,4 +44,4 @@ def get_input_features(self, mol):
4544
type_check_num_atoms(mol, self.max_atoms)
4645
atom_array = construct_atomic_number_array(mol, out_size=self.out_size)
4746
adj_array = construct_adj_matrix(mol, out_size=self.out_size)
48-
return atom_array, adj_array
47+
return atom_array, adj_array
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from chainer_chemistry.dataset.preprocessors.common \
2+
import construct_atomic_number_array, construct_discrete_edge_matrix
3+
from chainer_chemistry.dataset.preprocessors.common import type_check_num_atoms
4+
from chainer_chemistry.dataset.preprocessors.mol_preprocessor \
5+
import MolPreprocessor
6+
7+
8+
class GNNFiLMPreprocessor(MolPreprocessor):
9+
"""GNNFiLM Preprocessor
10+
11+
Args:
12+
max_atoms (int): Max number of atoms for each molecule, if the
13+
number of atoms is more than this value, this data is simply
14+
ignored.
15+
Setting negative value indicates no limit for max atoms.
16+
out_size (int): It specifies the size of array returned by
17+
`get_input_features`.
18+
If the number of atoms in the molecule is less than this value,
19+
the returned arrays is padded to have fixed size.
20+
Setting negative value indicates do not pad returned array.
21+
add_Hs (bool): If True, implicit Hs are added.
22+
kekulize (bool): If True, Kekulizes the molecule.
23+
24+
"""
25+
26+
def __init__(self, max_atoms=-1, out_size=-1, add_Hs=False,
27+
kekulize=False):
28+
super(GNNFiLMPreprocessor, self).__init__(
29+
add_Hs=add_Hs, kekulize=kekulize)
30+
if max_atoms >= 0 and out_size >= 0 and max_atoms > out_size:
31+
raise ValueError('max_atoms {} must be less or equal to '
32+
'out_size {}'.format(max_atoms, out_size))
33+
self.max_atoms = max_atoms
34+
self.out_size = out_size
35+
36+
def get_input_features(self, mol):
37+
"""get input features
38+
39+
Args:
40+
mol (Mol): Molecule input
41+
42+
Returns:
43+
44+
"""
45+
type_check_num_atoms(mol, self.max_atoms)
46+
atom_array = construct_atomic_number_array(mol, out_size=self.out_size)
47+
adj_array = construct_discrete_edge_matrix(mol, out_size=self.out_size, add_self_connection_channel=True)
48+
return atom_array, adj_array

chainer_chemistry/dataset/preprocessors/mol_preprocessor.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ def prepare_smiles_and_mol(self, mol):
3939
Chem.Kekulize(mol)
4040
return canonical_smiles, mol
4141

42-
4342
def get_label(self, mol, label_names=None):
4443
"""Extracts label information from a molecule.
4544
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import chainer
2+
from chainer import functions
3+
from chainer import links
4+
5+
from chainer_chemistry.links.connection.graph_linear import GraphLinear
6+
7+
8+
class GNNFiLMUpdate(chainer.Chain):
9+
"""GNNFiLM submodule for update part.
10+
11+
Args:
12+
hidden_channels (int): dimension of feature vector associated to
13+
each atom
14+
n_edge_types (int): number of types of edge
15+
"""
16+
17+
def __init__(self, hidden_channels=16, n_edge_types=5, activation=functions.relu):
18+
super(GNNFiLMUpdate, self).__init__()
19+
self.n_edge_types = n_edge_types
20+
self.activation = activation
21+
with self.init_scope():
22+
self.W_linear = GraphLinear(
23+
in_size=None, out_size=self.n_edge_types * hidden_channels, nobias=True) # W_l in eq. (6)
24+
self.W_g = GraphLinear(
25+
in_size=None, out_size=self.n_edge_types * hidden_channels * 2, nobias=True) # g in eq. (6)
26+
self.norm_layer = links.LayerNormalization() # l in eq. (6)
27+
28+
def forward(self, h, adj):
29+
# --- Message part ---
30+
31+
xp = self.xp
32+
mb, atom, ch = h.shape
33+
newshape = adj.shape + (ch, )
34+
adj = functions.broadcast_to(adj[:, :, :, :, xp.newaxis], newshape)
35+
messages = functions.reshape(self.W_linear(h),
36+
(mb, atom, ch, self.n_edge_types))
37+
messages = functions.transpose(messages, (3, 0, 1, 2))
38+
film_weights = functions.reshape(self.W_g(h),
39+
(mb, atom, 2 * ch, self.n_edge_types))
40+
film_weights = functions.transpose(film_weights, (3, 0, 1, 2))
41+
# (n_edge_types, minibatch, atom, out_ch)
42+
gamma = film_weights[:, :, :, :ch]
43+
# (n_edge_types, minibatch, atom, out_ch)
44+
beta = film_weights[:, :, :, ch:]
45+
46+
# --- Update part ---
47+
48+
messages = functions.expand_dims(gamma, axis=3) * functions.expand_dims(
49+
messages, axis=2) + functions.expand_dims(beta, axis=3)
50+
messages = self.activation(messages)
51+
# (minibatch, n_edge_types, atom, atom, out_ch)
52+
messages = functions.transpose(messages, (1, 0, 2, 3, 4))
53+
messages = adj * messages
54+
messages = functions.sum(messages, axis=3) # sum across atoms
55+
messages = functions.sum(messages, axis=1) # sum across n_edge_types
56+
messages = functions.reshape(messages, (mb * atom, ch))
57+
messages = self.norm_layer(messages)
58+
messages = functions.reshape(messages, (mb, atom, ch))
59+
return messages

chainer_chemistry/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from chainer_chemistry.models.rsgcn import RSGCN # NOQA
2323
from chainer_chemistry.models.schnet import SchNet # NOQA
2424
from chainer_chemistry.models.weavenet import WeaveNet # NOQA
25+
from chainer_chemistry.models.gnn_film import GNNFiLM # NOQA
2526

2627
from chainer_chemistry.models.gwm.gwm_net import GGNN_GWM # NOQA
2728
from chainer_chemistry.models.gwm.gwm_net import GIN_GWM # NOQA

chainer_chemistry/models/gnn_film.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
import chainer
2+
from chainer import cuda
3+
from chainer import functions
4+
5+
from chainer_chemistry.config import MAX_ATOMIC_NUM
6+
from chainer_chemistry.links.connection.embed_atom_id import EmbedAtomID
7+
from chainer_chemistry.links.readout.ggnn_readout import GGNNReadout
8+
from chainer_chemistry.links.update.gnn_film_update import GNNFiLMUpdate
9+
10+
11+
class GNNFiLM(chainer.Chain):
12+
"""Graph Neural Networks with Feature-wise Linear Modulation (GNN_FiLM)
13+
14+
Marc Brockschmidt (2019).\
15+
GNN-FiLM: Graph Neural Networks with Feature-wise Linear Modulation \
16+
`arXiv:1906.12192 <https://arxiv.org/abs/1906.12192>`_
17+
18+
Args:
19+
out_dim (int): dimension of output feature vector
20+
hidden_channels (int): dimension of feature vector
21+
associated to each atom
22+
n_update_layers (int): number of layers
23+
n_atom_types (int): number of types of atoms
24+
concat_hidden (bool): If set to True, readout is executed in each layer
25+
and the result is concatenated
26+
weight_tying (bool): enable weight_tying or not
27+
activation (~chainer.Function or ~chainer.FunctionNode):
28+
activate function
29+
n_edge_types (int): number of edge type.
30+
Defaults to 5 for single, double, triple, aromatic bond
31+
and self-connection.
32+
"""
33+
34+
def __init__(self, out_dim, hidden_channels=16, n_update_layers=4,
35+
n_atom_types=MAX_ATOMIC_NUM, concat_hidden=False,
36+
weight_tying=True, activation=functions.identity,
37+
n_edge_types=5):
38+
super(GNNFiLM, self).__init__()
39+
n_readout_layer = n_update_layers if concat_hidden else 1
40+
n_message_layer = 1 if weight_tying else n_update_layers
41+
with self.init_scope():
42+
# Update
43+
self.embed = EmbedAtomID(out_size=hidden_channels,
44+
in_size=n_atom_types)
45+
self.update_layers = chainer.ChainList(*[GNNFiLMUpdate(
46+
hidden_channels=hidden_channels, n_edge_types=n_edge_types)
47+
for _ in range(n_message_layer)])
48+
# Readout
49+
# self.readout_layers = chainer.ChainList(*[GeneralReadout(
50+
# out_dim=out_dim, hidden_channels=hidden_channels,
51+
# activation=activation, activation_agg=activation)
52+
# for _ in range(n_readout_layer)])
53+
self.readout_layers = chainer.ChainList(*[GGNNReadout(
54+
out_dim=out_dim, in_channels=hidden_channels * 2,
55+
activation=activation, activation_agg=activation)
56+
for _ in range(n_readout_layer)])
57+
self.out_dim = out_dim
58+
self.hidden_channels = hidden_channels
59+
self.n_update_layers = n_update_layers
60+
self.n_edge_types = n_edge_types
61+
self.activation = activation
62+
self.concat_hidden = concat_hidden
63+
self.weight_tying = weight_tying
64+
65+
def __call__(self, atom_array, adj, is_real_node=None):
66+
"""Forward propagation
67+
68+
Args:
69+
atom_array (numpy.ndarray): minibatch of molecular which is
70+
represented with atom IDs (representing C, O, S, ...)
71+
`atom_array[mol_index, atom_index]` represents `mol_index`-th
72+
molecule's `atom_index`-th atomic number
73+
adj (numpy.ndarray): minibatch of adjancency matrix with edge-type
74+
information
75+
is_real_node (numpy.ndarray): 2-dim array (minibatch, num_nodes).
76+
1 for real node, 0 for virtual node.
77+
If `None`, all node is considered as real node.
78+
79+
Returns:
80+
~chainer.Variable: minibatch of fingerprint
81+
"""
82+
# reset state
83+
# self.reset_state()
84+
if atom_array.dtype == self.xp.int32:
85+
h = self.embed(atom_array) # (minibatch, max_num_atoms)
86+
else:
87+
h = atom_array
88+
h0 = functions.copy(h, cuda.get_device_from_array(h.data).id)
89+
g_list = []
90+
for step in range(self.n_update_layers):
91+
message_layer_index = 0 if self.weight_tying else step
92+
h = self.update_layers[message_layer_index](h, adj)
93+
if self.concat_hidden:
94+
g = self.readout_layers[step](h, h0, is_real_node)
95+
g_list.append(g)
96+
97+
if self.concat_hidden:
98+
return functions.concat(g_list, axis=1)
99+
else:
100+
g = self.readout_layers[0](h, h0, is_real_node)
101+
return g
102+
103+
def reset_state(self):
104+
[update_layer.reset_state() for update_layer in self.update_layers]

0 commit comments

Comments
 (0)