Skip to content

Commit 8e1a578

Browse files
authored
Merge pull request #283 from corochann/saliency
Saliency
2 parents 084d0e8 + 06300bf commit 8e1a578

26 files changed

+1927
-0
lines changed
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
try:
2+
from chainer_chemistry.link_hooks import variable_monitor_link_hook # NOQA
3+
4+
from chainer_chemistry.link_hooks.variable_monitor_link_hook import VariableMonitorLinkHook # NOQA
5+
is_link_hooks_available = True
6+
except ImportError:
7+
import warnings
8+
warnings.warn('link_hooks failed to import, you need to upgrade chainer '
9+
'version to use link_hooks feature')
10+
is_link_hooks_available = False
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
from collections import OrderedDict
2+
from logging import getLogger
3+
4+
import chainer
5+
from chainer.link_hook import _ForwardPostprocessCallbackArgs, _ForwardPreprocessCallbackArgs # NOQA
6+
7+
8+
def _default_extract_pre(hook, args):
9+
"""Default extract_fn when `timing='pre`
10+
11+
Args:
12+
hook (VariableMonitorLinkHook):
13+
args (_ForwardPreprocessCallbackArgs):
14+
15+
Returns (chainer.Variable): First input variable to the link.
16+
"""
17+
return args.args[0]
18+
19+
20+
def _default_extract_post(hook, args):
21+
"""Default extract_fn when `timing='post`
22+
23+
Args:
24+
hook (VariableMonitorLinkHook):
25+
args (_ForwardPostprocessCallbackArgs):
26+
27+
Returns (chainer.Variable): Output variable to the link.
28+
"""
29+
return args.out
30+
31+
32+
class VariableMonitorLinkHook(chainer.LinkHook):
33+
"""Monitor Variable of specific link input/output
34+
35+
Args:
36+
target_link (chainer.Link): target link to monitor variable.
37+
name (str): name of this link hook
38+
timing (str): timing of this link hook to monitor. 'pre' or 'post'.
39+
If 'pre', the input of `target_link` is monitored.
40+
If 'post', the output of `target_link` is monitored.
41+
extract_fn (callable): Specify custom method to extract target variable
42+
Default behavior is to extract first input when `timing='pre'`,
43+
or extract output when `timing='post'`.
44+
It takes `hook, args` as argument.
45+
logger:
46+
47+
.. admonition:: Example
48+
49+
>>> import numpy
50+
>>> from chainer import cuda, links, functions # NOQA
51+
>>> from chainer_chemistry.link_hooks.variable_monitor_link_hook import VariableMonitorLinkHook # NOQA
52+
53+
>>> class DummyModel(chainer.Chain):
54+
>>> def __init__(self):
55+
>>> super(DummyModel, self).__init__()
56+
>>> with self.init_scope():
57+
>>> self.l1 = links.Linear(None, 1)
58+
>>> self.h = None
59+
>>>
60+
>>> def forward(self, x):
61+
>>> h = self.l1(x)
62+
>>> out = functions.sigmoid(h)
63+
>>> return out
64+
65+
>>> model = DummyModel()
66+
>>> hook = VariableMonitorLinkHook(model.l1, timing='post')
67+
>>> x = numpy.array([1, 2, 3])
68+
69+
>>> # Example 1. `get_variable` of `target_link`.
70+
>>> with hook:
71+
>>> out = model(x)
72+
>>> # You can extract `h`, which is output of `model.l1` as follows.
73+
>>> var_h = hook.get_variable()
74+
75+
>>> # Example 2. `add_process` to override value of target variable.
76+
>>> def _process_zeros(hook, args, target_var):
77+
>>> xp = cuda.get_array_module(target_var.array)
78+
>>> target_var.array = xp.zeros(target_var.array.shape)
79+
>>> hook.add_process('_process_zeros', _process_zeros)
80+
>>> with hook:
81+
>>> # During the forward, `h` is overriden to value 0.
82+
>>> out = model(x)
83+
>>> # Remove _process_zeros method
84+
>>> hook.delete_process('_process_zeros')
85+
"""
86+
87+
def __init__(self, target_link, name='VariableMonitorLinkHook',
88+
timing='post', extract_fn=None, logger=None):
89+
if not isinstance(target_link, chainer.Link):
90+
raise TypeError('target_link must be instance of chainer.Link!'
91+
'actual {}'.format(type(target_link)))
92+
if timing not in ['pre', 'post']:
93+
raise ValueError(
94+
"Unexpected value timing={}, "
95+
"must be either pre or post"
96+
.format(timing))
97+
super(VariableMonitorLinkHook, self).__init__()
98+
self.target_link = target_link
99+
100+
# This LinkHook maybe instantiated multiple times.
101+
# So it is allowed to change name by argument.
102+
self.name = name
103+
self.logger = logger or getLogger(__name__)
104+
105+
if extract_fn is None:
106+
if timing == 'pre':
107+
extract_fn = _default_extract_pre
108+
elif timing == 'post':
109+
extract_fn = _default_extract_post
110+
else:
111+
raise ValueError("Unexpected value timing={}"
112+
.format(timing))
113+
self.extract_fn = extract_fn
114+
self.process_fns = OrderedDict() # Additional process, if necessary
115+
116+
self.timing = timing
117+
self.result = None
118+
119+
def add_process(self, key, fn):
120+
"""Add additional process for target variable
121+
122+
Args:
123+
key (str): id for this process, you may remove added process by
124+
`delete_process` with this key.
125+
fn (callable): function which takes `hook, args, target_var` as
126+
arguments.
127+
"""
128+
if not isinstance(key, str):
129+
raise TypeError('key must be str, actual {}'.format(type(key)))
130+
if not callable(fn):
131+
raise TypeError('fn must be callable')
132+
self.process_fns[key] = fn
133+
134+
def delete_process(self, key):
135+
"""Delete process added at `add_process`
136+
137+
Args:
138+
key (str): id for the process, named at `add_process`.
139+
"""
140+
if not isinstance(key, str):
141+
raise TypeError('key must be str, actual {}'.format(type(key)))
142+
if key in self.process_fns.keys():
143+
del self.process_fns[key]
144+
else:
145+
# Nothing to delete
146+
self.logger.warning('{} is not in process_fns, skip delete_process'
147+
.format(key))
148+
149+
def get_variable(self):
150+
"""Get target variable, which is input or output of `target_link`.
151+
152+
Returns (chainer.Variable): target variable
153+
"""
154+
return self.result
155+
156+
def forward_preprocess(self, args):
157+
if self.timing == 'pre' and args.link is self.target_link:
158+
self.result = self.extract_fn(self, args)
159+
if self.process_fns is not None:
160+
for key, fn in self.process_fns.items():
161+
fn(self, args, self.result)
162+
163+
def forward_postprocess(self, args):
164+
if self.timing == 'post' and args.link is self.target_link:
165+
self.result = self.extract_fn(self, args)
166+
if self.process_fns is not None:
167+
for key, fn in self.process_fns.items():
168+
fn(self, args, self.result)
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from chainer_chemistry.saliency import calculator # NOQA
2+
from chainer_chemistry.saliency import visualizer # NOQA
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from chainer_chemistry.saliency.calculator import base_calculator # NOQA
2+
from chainer_chemistry.saliency.calculator import calculator_utils # NOQA
3+
from chainer_chemistry.saliency.calculator import gradient_calculator # NOQA
4+
from chainer_chemistry.saliency.calculator import integrated_gradients_calculator # NOQA
5+
from chainer_chemistry.saliency.calculator import occlusion_calculator # NOQA
6+
7+
from chainer_chemistry.saliency.calculator.base_calculator import BaseCalculator # NOQA
8+
from chainer_chemistry.saliency.calculator.gradient_calculator import GradientCalculator # NOQA
9+
from chainer_chemistry.saliency.calculator.integrated_gradients_calculator import IntegratedGradientsCalculator # NOQA
10+
from chainer_chemistry.saliency.calculator.occlusion_calculator import OcclusionCalculator # NOQA
11+
12+
from chainer_chemistry.saliency.calculator.calculator_utils import GaussianNoiseSampler # NOQA

0 commit comments

Comments
 (0)