Skip to content

Commit 4466b70

Browse files
authored
Merge pull request #316 from corochann/relgcn_fix
use mlp for relgcn
2 parents e8dadcc + e5a394a commit 4466b70

File tree

4 files changed

+9
-8
lines changed

4 files changed

+9
-8
lines changed

examples/molnet/train_molnet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,9 @@ def set_up_predictor(method, n_unit, conv_layers, class_num):
134134
elif method == 'relgcn':
135135
print('Training an RelGCN predictor...')
136136
num_edge_type = 4
137-
relgcn = RelGCN(out_channels=class_num, num_edge_type=num_edge_type,
137+
relgcn = RelGCN(out_channels=n_unit, num_edge_type=num_edge_type,
138138
scale_adj=True)
139-
return GraphConvPredictor(relgcn, None)
139+
return GraphConvPredictor(relgcn, mlp)
140140
elif method == 'relgat':
141141
print('Train Relational GAT model...')
142142
relgat = RelGAT(out_dim=n_unit, hidden_dim=n_unit,

examples/own_dataset/train_own_dataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,9 @@ def set_up_predictor(method, n_unit, conv_layers, class_num):
169169
elif method == 'relgcn':
170170
print('Training an RelGCN predictor...')
171171
num_edge_type = 4
172-
relgcn = RelGCN(out_channels=class_num, num_edge_type=num_edge_type,
172+
relgcn = RelGCN(out_channels=n_unit, num_edge_type=num_edge_type,
173173
scale_adj=True)
174-
predictor = GraphConvPredictor(relgcn, None)
174+
predictor = GraphConvPredictor(relgcn, mlp)
175175
elif method == 'relgat':
176176
print('Training an RelGAT predictor...')
177177
relgat = RelGAT(out_dim=n_unit, hidden_dim=n_unit,

examples/qm9/train_qm9.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,9 +173,9 @@ def set_up_predictor(method, n_unit, conv_layers, class_num, scaler):
173173
elif method == 'relgcn':
174174
print('Use Relational GCN predictor...')
175175
num_edge_type = 4
176-
relgcn = RelGCN(out_channels=class_num, num_edge_type=num_edge_type,
176+
relgcn = RelGCN(out_channels=n_unit, num_edge_type=num_edge_type,
177177
scale_adj=True)
178-
predictor = GraphConvPredictor(relgcn, None, scaler)
178+
predictor = GraphConvPredictor(relgcn, mlp, scaler)
179179
elif method == 'relgat':
180180
print('Train Relational GAT predictor...')
181181
relgat = RelGAT(out_dim=n_unit, hidden_dim=n_unit,

examples/tox21/predictor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,9 @@ def build_predictor(method, n_unit, conv_layers, class_num):
4646
print('Use Relational GCN predictor...')
4747
num_edge_type = 4
4848
predictor = GraphConvPredictor(
49-
RelGCN(out_channels=class_num, num_edge_type=num_edge_type,
50-
scale_adj=True))
49+
RelGCN(out_channels=n_unit, num_edge_type=num_edge_type,
50+
scale_adj=True),
51+
MLP(out_dim=class_num, hidden_dim=n_unit))
5152
elif method == 'relgat':
5253
print('Use GAT predictor...')
5354
predictor = GraphConvPredictor(

0 commit comments

Comments
 (0)