Skip to content

Scale output labels in the QM9 example and refactor code. #256

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 1, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions examples/qm9/evaluate_models_qm9.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,17 @@ gpu=${2:--1}

for method in ${methods[@]}
do
result_dir=${prefix}${method}
python train_qm9.py --method ${method} --gpu ${gpu} --out ${result_dir} --epoch ${epoch}
python predict_qm9.py --in-dir ${result_dir} --method ${method}
result_dir=${prefix}${method}

python train_qm9.py \
--method ${method} \
--gpu ${gpu} \
--out ${result_dir} \
--epoch ${epoch}

python predict_qm9.py \
--in-dir ${result_dir} \
--method ${method}
done

python plot.py --prefix ${prefix} --methods ${methods[@]}
165 changes: 97 additions & 68 deletions examples/qm9/predict_qm9.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
#!/usr/bin/env python

from __future__ import print_function

import argparse
import json
import numpy
import os
import pandas
import pickle

from chainer.iterators import SerialIterator
from chainer.training.extensions import Evaluator
import pandas

try:
import matplotlib
Expand All @@ -19,7 +20,6 @@
from chainer import cuda
from chainer.datasets import split_dataset_random
from chainer import Variable
import numpy # NOQA

from chainer_chemistry.dataset.converters import concat_mols
from chainer_chemistry.dataset.preprocessors import preprocess_method_dict
Expand All @@ -33,133 +33,162 @@
from train_qm9 import MeanAbsError, RootMeanSqrError # NOQA


def main():
# Supported preprocessing/network list
class ScaledGraphConvPredictor(GraphConvPredictor):
    """Graph convolution predictor that undoes label scaling on its output.

    If an object exposing ``inverse_transform`` is attached to the instance
    as the ``scaler`` attribute, the predictions of the base predictor are
    passed through it. When the attribute is missing or set to ``None``
    (i.e. no label scaling was selected), the base output is returned as is.
    """

    def __init__(self, *args, **kwargs):
        """Initializes the (scaled) graph convolution predictor.

        All positional and keyword arguments are forwarded unchanged to
        ``GraphConvPredictor.__init__``.
        """
        super(ScaledGraphConvPredictor, self).__init__(*args, **kwargs)

    def __call__(self, atoms, adjs):
        """Runs the base predictor and rescales its output when possible.

        Args:
            atoms: forwarded to ``GraphConvPredictor.__call__``.
            adjs: forwarded to ``GraphConvPredictor.__call__``.

        Returns:
            The base predictor's output when no scaler is set; otherwise a
            ``Variable`` wrapping the inverse-transformed predictions, placed
            back on the GPU if the base output was not a NumPy array.
        """
        h = super(ScaledGraphConvPredictor, self).__call__(atoms, adjs)

        # The caller may leave `scaler` unset or explicitly set it to None
        # when scaling is disabled; both mean "nothing to undo".
        scaler = getattr(self, 'scaler', None)
        if scaler is None:
            return h

        # Remember where the data lived before moving it to the CPU, since
        # the scaler is applied to the CPU copy of the array.
        numpy_data = isinstance(h.data, numpy.ndarray)
        rescaled = scaler.inverse_transform(cuda.to_cpu(h.data))
        if not numpy_data:
            rescaled = cuda.to_gpu(rescaled)
        return Variable(rescaled)


def parse_arguments():
    """Parse the command-line options of the QM9 prediction script.

    Returns:
        argparse.Namespace: the parsed options, with attributes ``method``,
        ``label``, ``scale``, ``gpu``, ``seed``, ``train_data_ratio``,
        ``in_dir``, ``model_filename`` and ``num_data``.
    """
    # Lists of supported preprocessing methods/models.
    method_list = ['nfp', 'ggnn', 'schnet', 'weavenet', 'rsgcn']
    label_names = ['A', 'B', 'C', 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2',
                   'zpve', 'U0', 'U', 'H', 'G', 'Cv']
    scale_list = ['standardize', 'none']

    # Set up the argument parser.
    parser = argparse.ArgumentParser(description='Regression on QM9.')
    parser.add_argument('--method', '-m', type=str, choices=method_list,
                        help='method name', default='nfp')
    parser.add_argument('--label', '-l', type=str, choices=label_names,
                        default='',
                        help='target label for regression; empty string means '
                        'predicting all properties at once')
    parser.add_argument('--scale', type=str, choices=scale_list,
                        help='label scaling method', default='standardize')
    # Adjacent string literals are concatenated verbatim, so the first piece
    # must end with a space to keep the words apart in the help output.
    parser.add_argument('--gpu', '-g', type=int, default=-1,
                        help='id of gpu to use; negative value means running '
                        'the code on cpu')
    parser.add_argument('--seed', '-s', type=int, default=777,
                        help='random seed value')
    parser.add_argument('--train-data-ratio', '-r', type=float, default=0.7,
                        help='ratio of training data w.r.t the dataset')
    parser.add_argument('--in-dir', '-i', type=str, default='result',
                        help='directory to load model data from')
    parser.add_argument('--model-filename', type=str, default='regressor.pkl',
                        help='saved model filename')
    parser.add_argument('--num-data', type=int, default=-1,
                        help='amount of data to be parsed; -1 indicates '
                        'parsing all data.')
    return parser.parse_args()



seed = args.seed
train_data_ratio = args.train_data_ratio
def main():
# Parse the arguments.
args = parse_arguments()

# Set up some useful variables that will be used later on.
method = args.method
if args.label:
labels = args.label
cache_dir = os.path.join('input', '{}_{}'.format(method, labels))
# class_num = len(labels) if isinstance(labels, list) else 1
else:
labels = D.get_qm9_label_names()
cache_dir = os.path.join('input', '{}_all'.format(method))
# class_num = len(labels)

# Dataset preparation
dataset = None

# Get the filename corresponding to the cached dataset, based on the amount
# of data samples that need to be parsed from the original dataset.
num_data = args.num_data
if num_data >= 0:
dataset_filename = 'data_{}.npz'.format(num_data)
else:
dataset_filename = 'data.npz'

# Load the cached dataset.
dataset_cache_path = os.path.join(cache_dir, dataset_filename)

dataset = None
if os.path.exists(dataset_cache_path):
print('load from cache {}'.format(dataset_cache_path))
print('Loading cached data from {}.'.format(dataset_cache_path))
dataset = NumpyTupleDataset.load(dataset_cache_path)
if dataset is None:
print('preprocessing dataset...')
print('Preprocessing dataset...')
preprocessor = preprocess_method_dict[method]()
dataset = D.get_qm9(preprocessor, labels=labels)

# Cache the newly preprocessed dataset.
if not os.path.exists(cache_dir):
os.mkdir(cache_dir)
NumpyTupleDataset.save(dataset_cache_path, dataset)

# Load the standard scaler parameters, if necessary.
if args.scale == 'standardize':
# Standard Scaler for labels
with open(os.path.join(args.in_dir, 'ss.pkl'), mode='rb') as f:
ss = pickle.load(f)
scaler_path = os.path.join(args.in_dir, 'scaler.pkl')
print('Loading scaler parameters from {}.'.format(scaler_path))
with open(scaler_path, mode='rb') as f:
scaler = pickle.load(f)
else:
ss = None
print('No standard scaling was selected.')
scaler = None

# Split the dataset into training and testing.
train_data_size = int(len(dataset) * args.train_data_ratio)
_, test = split_dataset_random(dataset, train_data_size, args.seed)

train_data_size = int(len(dataset) * train_data_ratio)
train, test = split_dataset_random(dataset, train_data_size, seed)
# Use a predictor with scaled output labels.
model_path = os.path.join(args.in_dir, args.model_filename)
regressor = Regressor.load_pickle(model_path, device=args.gpu)

regressor = Regressor.load_pickle(
os.path.join(args.in_dir, args.model_filename), device=args.gpu)
# Replace the default predictor with one that scales the output labels.
scaled_predictor = ScaledGraphConvPredictor(regressor.predictor)
scaled_predictor.scaler = scaler
regressor.predictor = scaled_predictor

# We need to feed only input features `x` to `predict`/`predict_proba`.
# This converter extracts only inputs (x1, x2, ...) from the features which
# consist of input `x` and label `t` (x1, x2, ..., t).
# This callback function extracts only the inputs and discards the labels.
def extract_inputs(batch, device=None):
return concat_mols(batch, device=device)[:-1]

def postprocess_fn(x):
if ss is not None:
# Model's output is scaled by StandardScaler,
# so we need to rescale back.
if isinstance(x, Variable):
x = x.data
scaled_x = ss.inverse_transform(cuda.to_cpu(x))
return scaled_x
else:
return x

# Predict the output labels.
print('Predicting...')
y_pred = regressor.predict(test, converter=extract_inputs,
postprocess_fn=postprocess_fn)

print('y_pred.shape = {}, y_pred[:5, 0] = {}'
.format(y_pred.shape, y_pred[:5, 0]))
y_pred = regressor.predict(test, converter=extract_inputs)

# Extract the ground-truth labels.
t = concat_mols(test, device=-1)[-1]
n_eval = 10

# Construct dataframe
# Construct dataframe.
df_dict = {}
for i, l in enumerate(labels):
df_dict.update({
'y_pred_{}'.format(l): y_pred[:, i],
't_{}'.format(l): t[:, i],
})
df_dict.update({'y_pred_{}'.format(l): y_pred[:, i],
't_{}'.format(l): t[:, i],})
df = pandas.DataFrame(df_dict)

# Show random 5 example's prediction/ground truth table
# Show a prediction/ground truth table with 5 random examples.
print(df.sample(5))

for target_label in range(y_pred.shape[1]):
diff = y_pred[:n_eval, target_label] - t[:n_eval, target_label]
print('target_label = {}, y_pred = {}, t = {}, diff = {}'
.format(target_label, y_pred[:n_eval, target_label],
t[:n_eval, target_label], diff))

# --- evaluate ---
# To calc loss/accuracy, we can use `Evaluator`, `ROCAUCEvaluator`
# Run an evaluator on the test dataset.
print('Evaluating...')
test_iterator = SerialIterator(test, 16, repeat=False, shuffle=False)
eval_result = Evaluator(
test_iterator, regressor, converter=concat_mols, device=args.gpu)()
eval_result = Evaluator(test_iterator, regressor, converter=concat_mols,
device=args.gpu)()

# Prevents the loss function from becoming a cupy.core.core.ndarray object
# when using the GPU. This hack will be removed as soon as the cause of
# the issue is found and properly fixed.
loss = numpy.asscalar(cuda.to_cpu(eval_result['main/loss']))
eval_result['main/loss'] = loss
print('Evaluation result: ', eval_result)

# Save the evaluation results.
with open(os.path.join(args.in_dir, 'eval_result.json'), 'w') as f:
json.dump(eval_result, f)

Expand Down
51 changes: 40 additions & 11 deletions examples/qm9/test_qm9.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,47 @@

set -e

# gpu id given from first argument, default value is -1
gpu=${1:--1}
# List of available graph convolution methods.
methods=(nfp ggnn schnet weavenet rsgcn)
# Number of training epochs (default: 1).
epoch=${1:-1}
# GPU identifier; set it to -1 to train on the CPU (default).
gpu=${2:--1}

for method in nfp ggnn schnet weavenet rsgcn
for method in ${methods[@]}
do
# QM9
if [ ! -f "input" ]; then
rm -rf input
fi
# Remove any previously cached datasets (the scripts cache preprocessed data under input/).
[ -d "input" ] && rm -rf input

python train_qm9.py --method ${method} --label A --conv-layers 1 --gpu ${gpu} --epoch 1 --unit-num 10 --batchsize 32 --num-data 100
python predict_qm9.py --method ${method} --label A --gpu ${gpu} --batchsize 32 --num-data 100
python train_qm9.py --method ${method} --conv-layers 1 --gpu ${gpu} --epoch 1 --unit-num 10 --batchsize 32 --num-data 100
python predict_qm9.py --method ${method} --gpu ${gpu} --batchsize 32 --num-data 100
# Train with the current method (one label).
python train_qm9.py \
--method ${method} \
--label A \
--conv-layers 1 \
--gpu ${gpu} \
--epoch ${epoch} \
--unit-num 10 \
--num-data 100

# Predict with the current method (one label).
python predict_qm9.py \
--method ${method} \
--label A \
--gpu ${gpu} \
--num-data 100

# Train with the current method (all labels).
python train_qm9.py \
--method ${method} \
--conv-layers 1 \
--gpu ${gpu} \
--epoch ${epoch} \
--unit-num 10 \
--num-data 100

# Predict with the current method (all labels).
python predict_qm9.py \
--method ${method} \
--gpu ${gpu} \
--num-data 100
done
Loading