
Commit dcda27f

Merge pull request #137 from corochann/base_forward_model
[refactor] Introduce BaseForwardModel
2 parents 8a0ff2a + aa783eb commit dcda27f

File tree: 4 files changed, +149 -170 lines changed

chainer_chemistry/models/prediction/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -1,5 +1,7 @@
+from chainer_chemistry.models.prediction import base  # NOQA
 from chainer_chemistry.models.prediction import classifier  # NOQA
 from chainer_chemistry.models.prediction import regressor  # NOQA
 
+from chainer_chemistry.models.prediction.base import BaseForwardModel  # NOQA
 from chainer_chemistry.models.prediction.classifier import Classifier  # NOQA
 from chainer_chemistry.models.prediction.regressor import Regressor  # NOQA
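With this change, `BaseForwardModel` is also exported from the `chainer_chemistry.models.prediction` package namespace. A quick, illustrative sanity check (not part of the commit):

# Both import paths should resolve to the same class after this commit.
from chainer_chemistry.models.prediction import BaseForwardModel
from chainer_chemistry.models.prediction.base import BaseForwardModel as _Base

assert BaseForwardModel is _Base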
chainer_chemistry/models/prediction/base.py

Lines changed: 129 additions & 0 deletions

@@ -0,0 +1,129 @@
+import chainer
+from chainer.dataset.convert import concat_examples
+from chainer import link, cuda
+from chainer.iterators import SerialIterator
+import numpy
+
+
+def _to_tuple(x):
+    if not isinstance(x, tuple):
+        x = (x,)
+    return x
+
+
+def _extract_numpy(x):
+    if isinstance(x, chainer.Variable):
+        x = x.data
+    return cuda.to_cpu(x)
+
+
+class BaseForwardModel(link.Chain):
+
+    """A base model which supports _forward functionality.
+
+    It also supports `device` id management.
+
+    Args:
+        device (int): GPU device id of this model to be used.
+            -1 indicates to use in CPU.
+
+    Attributes:
+        _device (int): Model's current device id.
+
+    """
+
+    def __init__(self):
+        super(BaseForwardModel, self).__init__()
+
+        self.inputs = None
+        self._device = None
+
+    def get_device(self):
+        return self._device
+
+    def initialize(self, device=-1):
+        """Initialization of the model.
+
+        It must be executed **after** the link registration
+        (often done by `with self.init_scope()`) has finished.
+
+        Args:
+            device (int): GPU device id of this model to be used.
+                -1 indicates to use in CPU.
+
+        """
+        self.update_device(device=device)
+
+    def update_device(self, device=-1):
+        if self._device is None or self._device != device:
+            # reset current state
+            self.to_cpu()
+
+            # update the model to specified device id
+            self._device = device
+            if device >= 0:
+                chainer.cuda.get_device_from_id(device).use()
+                self.to_gpu()  # Copy the model to the GPU
+
+    def _forward(self, data, fn, batchsize=16,
+                 converter=concat_examples, retain_inputs=False,
+                 preprocess_fn=None, postprocess_fn=None):
+        """Forward data by iterating with batch
+
+        Args:
+            data: "train_x array" or "chainer dataset"
+            fn (Callable): Main function to forward. Its input argument is
+                either Variable, cupy.ndarray or numpy.ndarray, and it
+                returns Variable.
+            batchsize (int): batch size
+            converter (Callable): convert from `data` to `inputs`
+            retain_inputs (bool): If True, this instance keeps inputs in
+                `self.inputs`.
+            preprocess_fn (Callable): Its input is numpy.ndarray or
+                cupy.ndarray, and it may return either Variable,
+                cupy.ndarray or numpy.ndarray.
+            postprocess_fn (Callable): Its input argument is Variable,
+                but this method may return either Variable, cupy.ndarray or
+                numpy.ndarray.
+
+        Returns (tuple or numpy.ndarray): forward result
+
+        """
+        input_list = None
+        output_list = None
+        it = SerialIterator(data, batch_size=batchsize, repeat=False,
+                            shuffle=False)
+        for batch in it:
+            inputs = converter(batch, self._device)
+            inputs = _to_tuple(inputs)
+
+            if preprocess_fn:
+                inputs = preprocess_fn(*inputs)
+                inputs = _to_tuple(inputs)
+
+            outputs = fn(*inputs)
+            outputs = _to_tuple(outputs)
+
+            # Initialize the accumulator lists on the first batch
+            if retain_inputs:
+                if input_list is None:
+                    input_list = [[] for _ in range(len(inputs))]
+                for j, input in enumerate(inputs):
+                    input_list[j].append(cuda.to_cpu(input))
+            if output_list is None:
+                output_list = [[] for _ in range(len(outputs))]
+
+            if postprocess_fn:
+                outputs = postprocess_fn(*outputs)
+                outputs = _to_tuple(outputs)
+            for j, output in enumerate(outputs):
+                output_list[j].append(_extract_numpy(output))
+
+        if retain_inputs:
+            self.inputs = [
+                numpy.concatenate(in_array) for in_array in input_list]
+
+        result = [numpy.concatenate(output) for output in output_list]
+        if len(result) == 1:
+            return result[0]
+        else:
+            return result
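The sketch below is not part of the commit; it only illustrates how a subclass is expected to wire `initialize` and `_forward` together. The `MyPredictor` class, its single linear layer, and the softmax postprocessing are made-up examples.

import numpy
import chainer
import chainer.functions as F
import chainer.links as L

from chainer_chemistry.models.prediction import BaseForwardModel


class MyPredictor(BaseForwardModel):
    """Hypothetical subclass reusing the batched `_forward` loop."""

    def __init__(self, n_out=2, device=-1):
        super(MyPredictor, self).__init__()
        with self.init_scope():
            self.l1 = L.Linear(None, n_out)
        # As the docstring above requires, call initialize() only after
        # all links have been registered inside init_scope().
        self.initialize(device)

    def __call__(self, x):
        return self.l1(x)

    def predict(self, data, batchsize=16):
        def fn(x):
            # _forward itself does not disable backprop or switch to test
            # mode, so do it in the per-batch function.
            with chainer.no_backprop_mode(), chainer.using_config('train', False):
                return self(x)
        return self._forward(data, fn=fn, batchsize=batchsize,
                             postprocess_fn=F.softmax)


model = MyPredictor(n_out=3, device=-1)
x = numpy.random.rand(100, 8).astype(numpy.float32)
proba = model.predict(x)
print(proba.shape)  # (100, 3) numpy array of class probabilities

Because `_forward` concatenates the per-batch outputs with `numpy.concatenate`, the per-batch function must return arrays with a consistent trailing shape.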

chainer_chemistry/models/prediction/classifier.py

Lines changed: 13 additions & 82 deletions

@@ -4,30 +4,17 @@
 from chainer.dataset.convert import concat_examples
 from chainer.functions.evaluation import accuracy
 from chainer.functions.loss import softmax_cross_entropy
-from chainer import link, cuda
 from chainer import reporter
-from chainer.iterators import SerialIterator
-import numpy
 
-
-def _to_tuple(x):
-    if not isinstance(x, tuple):
-        x = (x,)
-    return x
-
-
-def _extract_numpy(x):
-    if isinstance(x, chainer.Variable):
-        x = x.data
-    return cuda.to_cpu(x)
+from chainer_chemistry.models.prediction.base import BaseForwardModel
 
 
 def _argmax(*args):
     x = args[0]
     return chainer.functions.argmax(x, axis=1)
 
 
-class Classifier(link.Chain):
+class Classifier(BaseForwardModel):
 
     """A simple classifier model.
 
@@ -55,6 +42,15 @@ class Classifier(link.Chain):
         compute_metrics (bool): If ``True``, compute metrics on the forward
            computation. The default value is ``True``.
 
+    .. note::
+        The differences between original `Classifier` class in chainer and
+        chainer chemistry are as follows.
+        1. `predict` and `predict_proba` methods are supported.
+        2. `device` can be managed internally by the `Classifier`
+        3. `accfun` is deprecated, `metrics_fun` is used instead.
+        4. `metrics_fun` can be `dict` which specifies the metrics name as key
+           and function as value.
+
     .. note::
        This link uses :func:`chainer.softmax_cross_entropy` with
        default arguments as a loss function (specified by ``lossfun``),
@@ -111,10 +107,8 @@ def __init__(self, predictor,
         with self.init_scope():
             self.predictor = predictor
 
-        self.device = device
-        if device >= 0:
-            chainer.cuda.get_device_from_id(device).use()
-            self.to_gpu()  # Copy the model to the GPU
+        # `initialize` must be called after `init_scope`.
+        self.initialize(device)
 
     def __call__(self, *args, **kwargs):
         """Computes the loss value for an input and label pair.
@@ -172,69 +166,6 @@ def __call__(self, *args, **kwargs):
         reporter.report(self.metrics, self)
         return self.loss
 
-    def _forward(self, data, fn, batchsize=16,
-                 converter=concat_examples, retain_inputs=False,
-                 preprocess_fn=None, postprocess_fn=None):
    [... remaining 60 deleted lines omitted: identical to the `_forward`
     implementation moved into BaseForwardModel (base.py above), except that
     it read `self.device` instead of `self._device` ...]
 
     def predict_proba(
             self, data, batchsize=16, converter=concat_examples,
             retain_inputs=False, preprocess_fn=None,
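A short usage sketch for the refactored Classifier (illustrative, not from the diff; the dict form of `metrics_fun` follows point 4 of the note above, and the default behaviour of `predict`/`predict_proba` is assumed):

import numpy
import chainer.functions as F
import chainer.links as L

from chainer_chemistry.models.prediction import Classifier

# Any Link producing class scores can serve as the predictor here.
clf = Classifier(L.Linear(None, 3),
                 metrics_fun={'accuracy': F.accuracy},  # dict form, see note 4
                 device=-1)  # device is handled via BaseForwardModel.initialize

x = numpy.random.rand(20, 5).astype(numpy.float32)
proba = clf.predict_proba(x)   # class probabilities, shape (20, 3)
labels = clf.predict(x)        # argmax class ids, shape (20,)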
chainer_chemistry/models/prediction/regressor.py

Lines changed: 5 additions & 88 deletions

@@ -1,29 +1,11 @@
 import chainer
 from chainer.dataset.convert import concat_examples
-from chainer import link, cuda
 from chainer import reporter
-from chainer.iterators import SerialIterator
-import numpy
 
+from chainer_chemistry.models.prediction.base import BaseForwardModel
 
-def _to_tuple(x):
-    if not isinstance(x, tuple):
-        x = (x,)
-    return x
 
-
-def _extract_numpy(x):
-    if isinstance(x, chainer.Variable):
-        x = x.data
-    return cuda.to_cpu(x)
-
-
-def _argmax(*args):
-    x = args[0]
-    return chainer.functions.argmax(x, axis=1)
-
-
-class Regressor(link.Chain):
+class Regressor(BaseForwardModel):
 
     """A simple regressor model.
 
@@ -37,7 +19,7 @@ class Regressor(link.Chain):
         label_key (int or str): Key to specify label variable from arguments.
             When it is ``int``, a variable in positional arguments is used.
             And when it is ``str``, a variable in keyword arguments is used.
-        device (int): GPU device id of this Classifier to be used.
+        device (int): GPU device id of this Regressor to be used.
            -1 indicates to use in CPU.
 
     Attributes:
@@ -80,10 +62,8 @@ def __init__(self, predictor,
         with self.init_scope():
             self.predictor = predictor
 
-        self.device = device
-        if device >= 0:
-            chainer.cuda.get_device_from_id(device).use()
-            self.to_gpu()  # Copy the model to the GPU
+        # `initialize` must be called after `init_scope`.
+        self.initialize(device)
 
     def __call__(self, *args, **kwargs):
         """Computes the loss value for an input and label pair.
@@ -142,69 +122,6 @@ def __call__(self, *args, **kwargs):
         reporter.report(self.metrics, self)
         return self.loss
 
-    def _forward(self, data, fn, batchsize=16,
-                 converter=concat_examples, retain_inputs=False,
-                 preprocess_fn=None, postprocess_fn=None):
    [... remaining 60 deleted lines omitted: identical to the `_forward`
     implementation moved into BaseForwardModel (base.py above), except that
     it read `self.device` instead of `self._device` ...]
 
     def predict(
             self, data, batchsize=16, converter=concat_examples,
             retain_inputs=False, preprocess_fn=None, postprocess_fn=None):
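A corresponding sketch for the refactored Regressor (again illustrative, not from the commit; keyword arguments other than `device` are assumed to keep their defaults):

import numpy
import chainer.links as L

from chainer_chemistry.models.prediction import Regressor

# Toy predictor for a 1-dimensional regression target.
reg = Regressor(L.Linear(None, 1), device=-1)

# Device bookkeeping now lives in BaseForwardModel.
assert reg.get_device() == -1

x = numpy.random.rand(50, 4).astype(numpy.float32)
y_pred = reg.predict(x)  # numpy array of shape (50, 1)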
