Skip to content

Commit 9c213b9

Browse files
authored
Merge pull request #125 from corochann/classifier_example
Classifier example
2 parents 07e57ef + 3841c7c commit 9c213b9

File tree

1 file changed

+18
-36
lines changed

1 file changed

+18
-36
lines changed

examples/tox21/predict_tox21.py

Lines changed: 18 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,20 @@
99
from chainer.iterators import SerialIterator
1010
from chainer.training.extensions import Evaluator
1111
import chainer.functions as F
12-
from chainer_chemistry.training.extensions.roc_auc_evaluator import \
13-
ROCAUCEvaluator
1412
from rdkit import RDLogger
1513
import six
1614

1715
from chainer_chemistry import datasets as D
16+
from chainer_chemistry.dataset.converters import concat_mols
17+
from chainer_chemistry.models.prediction import Classifier
18+
from chainer_chemistry.training.extensions.roc_auc_evaluator import \
19+
ROCAUCEvaluator
1820

1921
import data
2022
import predictor
2123

2224

2325
# Disable errors by RDKit occurred in preprocessing Tox21 dataset.
24-
from chainer_chemistry.models.prediction import Classifier
2526

2627
lg = RDLogger.logger()
2728
lg.setLevel(RDLogger.CRITICAL)
@@ -49,7 +50,7 @@ def main():
4950
label_names = D.get_tox21_label_names()
5051

5152
parser = argparse.ArgumentParser(
52-
description='Inference with a trained model.')
53+
description='Predict with a trained model.')
5354
parser.add_argument('--in-dir', '-i', type=str, default='result',
5455
help='Path to the result directory of the training '
5556
'script.')
@@ -81,8 +82,6 @@ def main():
8182
class_num = len(label_names)
8283

8384
_, test, _ = data.load_dataset(method, labels)
84-
# test = test.get_datasets()
85-
X_test = D.NumpyTupleDataset(*test.get_datasets()[:-1])
8685
y_test = test.get_datasets()[-1]
8786

8887
# Load pretrained model
@@ -95,30 +94,30 @@ def main():
9594
chainer.serializers.load_npz(snapshot_file,
9695
predictor_, 'updater/model:main/predictor/')
9796

98-
# if args.gpu >= 0:
99-
# chainer.cuda.get_device_from_id(args.gpu).use()
100-
# predictor_.to_gpu()
101-
102-
# inference_loop = predictor.InferenceLoop(predictor_)
103-
# y_pred = inference_loop.inference(X_test)
104-
105-
from chainer_chemistry.dataset.converters import concat_mols
10697
clf = Classifier(predictor=predictor_, device=args.gpu,
10798
lossfun=F.sigmoid_cross_entropy,
108-
metrics_fun={'binary_accuracy': F.binary_accuracy})
99+
metrics_fun=F.binary_accuracy)
109100

110101
# ---- predict ---
111102
print('Predicting...')
112103

104+
def extract_inputs(batch, device=None):
105+
return concat_mols(batch, device=device)[:-1]
106+
113107
def postprocess_pred(x):
114108
x_array = cuda.to_cpu(x.data)
115109
return numpy.where(x_array > 0, 1, 0)
116-
y_pred = clf.predict(X_test, converter=concat_mols,
110+
y_pred = clf.predict(test, converter=extract_inputs,
117111
postprocess_fn=postprocess_pred)
118-
y_proba = clf.predict_proba(X_test, converter=concat_mols,
112+
y_proba = clf.predict_proba(test, converter=extract_inputs,
119113
postprocess_fn=F.sigmoid)
120-
print('y_pred[:5]', y_pred[:5, 0])
121-
print('y_proba[:5]', y_proba[:5, 0])
114+
115+
# `predict` method returns the prediction label (0: non-toxic, 1:toxic)
116+
print('y_pread.shape = {}, y_pred[:5, 0] = {}'
117+
.format(y_pred.shape, y_pred[:5, 0]))
118+
# `predict_proba` method returns the probability to be toxic
119+
print('y_proba.shape = {}, y_proba[:5, 0] = {}'
120+
.format(y_proba.shape, y_proba[:5, 0]))
122121
# --- predict end ---
123122

124123
if y_pred.ndim == 1:
@@ -164,20 +163,3 @@ def postprocess_pred(x):
164163

165164
if __name__ == '__main__':
166165
main()
167-
168-
"""
169-
python train_tox21.py --method nfp --label NR-AR --conv-layers 3 --gpu 0 --epoch 10 --unit-num 32 --out nfp_predict
170-
python inference_tox21.py --in-dir nfp_predict --gpu 0
171-
172-
python predict_tox21.py --in-dir nfp_predict --gpu 0
173-
174-
Predicting...
175-
y_pred[:5] [0 0 0 0 0]
176-
y_proba[:5] [0.01780733 0.18086143 0.05626253 0.02249673 0.01841126]
177-
TaskID Correct Total Accuracy
178-
task 0 284 291 0.9759
179-
Save prediction result to prediction.npz
180-
Evaluating...
181-
Evaluation result: {'main/loss': array(0.12492516, dtype=float32), 'main/binary_accuracy': array(0.9725251, dtype=float32)}
182-
ROCAUC Evaluation result: {'test/main/roc_auc': 0.33796296296296297}
183-
"""

0 commit comments

Comments
 (0)