9
9
from chainer .iterators import SerialIterator
10
10
from chainer .training .extensions import Evaluator
11
11
import chainer .functions as F
12
- from chainer_chemistry .training .extensions .roc_auc_evaluator import \
13
- ROCAUCEvaluator
14
12
from rdkit import RDLogger
15
13
import six
16
14
17
15
from chainer_chemistry import datasets as D
16
+ from chainer_chemistry .dataset .converters import concat_mols
17
+ from chainer_chemistry .models .prediction import Classifier
18
+ from chainer_chemistry .training .extensions .roc_auc_evaluator import \
19
+ ROCAUCEvaluator
18
20
19
21
import data
20
22
import predictor
21
23
22
24
23
25
# Suppress RDKit log messages below CRITICAL emitted while preprocessing the Tox21 dataset.
24
- from chainer_chemistry .models .prediction import Classifier
25
26
26
27
lg = RDLogger .logger ()
27
28
lg .setLevel (RDLogger .CRITICAL )
@@ -49,7 +50,7 @@ def main():
49
50
label_names = D .get_tox21_label_names ()
50
51
51
52
parser = argparse .ArgumentParser (
52
- description = 'Inference with a trained model.' )
53
+ description = 'Predict with a trained model.' )
53
54
parser .add_argument ('--in-dir' , '-i' , type = str , default = 'result' ,
54
55
help = 'Path to the result directory of the training '
55
56
'script.' )
@@ -81,8 +82,6 @@ def main():
81
82
class_num = len (label_names )
82
83
83
84
_ , test , _ = data .load_dataset (method , labels )
84
- # test = test.get_datasets()
85
- X_test = D .NumpyTupleDataset (* test .get_datasets ()[:- 1 ])
86
85
y_test = test .get_datasets ()[- 1 ]
87
86
88
87
# Load pretrained model
@@ -95,30 +94,30 @@ def main():
95
94
chainer .serializers .load_npz (snapshot_file ,
96
95
predictor_ , 'updater/model:main/predictor/' )
97
96
98
- # if args.gpu >= 0:
99
- # chainer.cuda.get_device_from_id(args.gpu).use()
100
- # predictor_.to_gpu()
101
-
102
- # inference_loop = predictor.InferenceLoop(predictor_)
103
- # y_pred = inference_loop.inference(X_test)
104
-
105
- from chainer_chemistry .dataset .converters import concat_mols
106
97
clf = Classifier (predictor = predictor_ , device = args .gpu ,
107
98
lossfun = F .sigmoid_cross_entropy ,
108
- metrics_fun = { 'binary_accuracy' : F .binary_accuracy } )
99
+ metrics_fun = F .binary_accuracy )
109
100
110
101
# ---- predict ---
111
102
print ('Predicting...' )
112
103
104
+ def extract_inputs (batch , device = None ):
105
+ return concat_mols (batch , device = device )[:- 1 ]
106
+
113
107
def postprocess_pred (x ):
114
108
x_array = cuda .to_cpu (x .data )
115
109
return numpy .where (x_array > 0 , 1 , 0 )
116
- y_pred = clf .predict (X_test , converter = concat_mols ,
110
+ y_pred = clf .predict (test , converter = extract_inputs ,
117
111
postprocess_fn = postprocess_pred )
118
- y_proba = clf .predict_proba (X_test , converter = concat_mols ,
112
+ y_proba = clf .predict_proba (test , converter = extract_inputs ,
119
113
postprocess_fn = F .sigmoid )
120
- print ('y_pred[:5]' , y_pred [:5 , 0 ])
121
- print ('y_proba[:5]' , y_proba [:5 , 0 ])
114
+
115
+ # The `predict` method returns the predicted label (0: non-toxic, 1: toxic).
116
+ print ('y_pread.shape = {}, y_pred[:5, 0] = {}'
117
+ .format (y_pred .shape , y_pred [:5 , 0 ]))
118
+ # The `predict_proba` method returns the probability of being toxic.
119
+ print ('y_proba.shape = {}, y_proba[:5, 0] = {}'
120
+ .format (y_proba .shape , y_proba [:5 , 0 ]))
122
121
# --- predict end ---
123
122
124
123
if y_pred .ndim == 1 :
@@ -164,20 +163,3 @@ def postprocess_pred(x):
164
163
165
164
if __name__ == '__main__' :
166
165
main ()
167
-
168
- """
169
- python train_tox21.py --method nfp --label NR-AR --conv-layers 3 --gpu 0 --epoch 10 --unit-num 32 --out nfp_predict
170
- python inference_tox21.py --in-dir nfp_predict --gpu 0
171
-
172
- python predict_tox21.py --in-dir nfp_predict --gpu 0
173
-
174
- Predicting...
175
- y_pred[:5] [0 0 0 0 0]
176
- y_proba[:5] [0.01780733 0.18086143 0.05626253 0.02249673 0.01841126]
177
- TaskID Correct Total Accuracy
178
- task 0 284 291 0.9759
179
- Save prediction result to prediction.npz
180
- Evaluating...
181
- Evaluation result: {'main/loss': array(0.12492516, dtype=float32), 'main/binary_accuracy': array(0.9725251, dtype=float32)}
182
- ROCAUC Evaluation result: {'test/main/roc_auc': 0.33796296296296297}
183
- """
0 commit comments