-
Notifications
You must be signed in to change notification settings - Fork 131
Saliency #283
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Saliency #283
Changes from 17 commits
Commits
Show all changes
29 commits
Select commit
Hold shift + click to select a range
03ab394
copy from chainer-saliency
024c91a
rename
7ffa8ec
__init__ import
6f86870
refactoring
cc2f427
test: visualizer.common
7160031
test visualizer & docstring
b38efd0
tests for calculator & bug fix.
9a781fc
test variable monitor link hook
3674f6b
Always set target_extractor
13f63d0
add docstring
f44faaa
skip test when linkhook is not avaiable
0dd27e9
fix for tests
af116e2
update
3460170
skip ipython import with showing error message
ee9e93e
str(tmpdir)
1e06437
citations
c3d7c73
cairosvg error skip with message
1edc955
save as svg
9a9b91d
update
5073975
update
121b130
update
77d443b
skip tests with matplotlib dependency when python version=2
e81c0fb
update for comments
170ef17
show_progress option, show warning when target_var is None.
99cc954
plt.show() to not stop during test
4567d84
rename common to xxx_utils
b9e0a13
Show logger log for saliency_array == 0
7a6ee77
revert to support inputs as target_var
06300bf
show error msg
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
try: | ||
from chainer_chemistry.link_hooks import variable_monitor_link_hook # NOQA | ||
|
||
from chainer_chemistry.link_hooks.variable_monitor_link_hook import VariableMonitorLinkHook # NOQA | ||
is_link_hooks_available = True | ||
except ImportError: | ||
import warnings | ||
warnings.warn('link_hooks failed to import, you need to upgrade chainer ' | ||
'version to use link_hooks feature') | ||
is_link_hooks_available = False |
168 changes: 168 additions & 0 deletions
168
chainer_chemistry/link_hooks/variable_monitor_link_hook.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,168 @@ | ||
from collections import OrderedDict | ||
from logging import getLogger | ||
|
||
import chainer | ||
from chainer.link_hook import _ForwardPostprocessCallbackArgs, _ForwardPreprocessCallbackArgs # NOQA | ||
|
||
|
||
def _default_extract_pre(hook, args): | ||
"""Default extract_fn when `timing='pre` | ||
|
||
Args: | ||
hook (VariableMonitorLinkHook): | ||
args (_ForwardPreprocessCallbackArgs): | ||
|
||
Returns (chainer.Variable): First input variable to the link. | ||
""" | ||
return args.args[0] | ||
|
||
|
||
def _default_extract_post(hook, args): | ||
"""Default extract_fn when `timing='post` | ||
|
||
Args: | ||
hook (VariableMonitorLinkHook): | ||
args (_ForwardPostprocessCallbackArgs): | ||
|
||
Returns (chainer.Variable): Output variable to the link. | ||
""" | ||
return args.out | ||
|
||
|
||
class VariableMonitorLinkHook(chainer.LinkHook): | ||
"""Monitor Variable of specific link input/output | ||
|
||
Args: | ||
target_link (chainer.Link): target link to monitor variable. | ||
name (str): name of this link hook | ||
timing (str): timing of this link hook to monitor. 'pre' or 'post'. | ||
If 'pre', the input of `target_link` is monitored. | ||
If 'post', the output of `target_link` is monitored. | ||
extract_fn (callable): Specify custom method to extract target variable | ||
Default behavior is to extract first input when `timing='pre'`, | ||
or extract output when `timing='post'`. | ||
It takes `hook, args` as argument. | ||
logger: | ||
|
||
.. admonition:: Example | ||
|
||
>>> import numpy | ||
>>> from chainer import cuda, links, functions # NOQA | ||
>>> from chainer_chemistry.link_hooks.variable_monitor_link_hook import VariableMonitorLinkHook # NOQA | ||
|
||
>>> class DummyModel(chainer.Chain): | ||
>>> def __init__(self): | ||
>>> super(DummyModel, self).__init__() | ||
>>> with self.init_scope(): | ||
>>> self.l1 = links.Linear(None, 1) | ||
>>> self.h = None | ||
>>> | ||
>>> def forward(self, x): | ||
>>> h = self.l1(x) | ||
>>> out = functions.sigmoid(h) | ||
>>> return out | ||
|
||
>>> model = DummyModel() | ||
>>> hook = VariableMonitorLinkHook(model.l1, timing='post') | ||
>>> x = numpy.array([1, 2, 3]) | ||
|
||
>>> # Example 1. `get_variable` of `target_link`. | ||
>>> with hook: | ||
>>> out = model(x) | ||
>>> # You can extract `h`, which is output of `model.l1` as follows. | ||
>>> var_h = hook.get_variable() | ||
|
||
>>> # Example 2. `add_process` to override value of target variable. | ||
>>> def _process_zeros(hook, args, target_var): | ||
>>> xp = cuda.get_array_module(target_var.array) | ||
>>> target_var.array = xp.zeros(target_var.array.shape) | ||
>>> hook.add_process('_process_zeros', _process_zeros) | ||
>>> with hook: | ||
>>> # During the forward, `h` is overriden to value 0. | ||
>>> out = model(x) | ||
>>> # Remove _process_zeros method | ||
>>> hook.delete_process('_process_zeros') | ||
""" | ||
|
||
def __init__(self, target_link, name='VariableMonitorLinkHook', | ||
timing='post', extract_fn=None, logger=None): | ||
if not isinstance(target_link, chainer.Link): | ||
raise TypeError('target_link must be instance of chainer.Link!' | ||
'actual {}'.format(type(target_link))) | ||
if timing not in ['pre', 'post']: | ||
raise ValueError( | ||
"Unexpected value timing={}, " | ||
"must be either pre or post" | ||
.format(timing)) | ||
super(VariableMonitorLinkHook, self).__init__() | ||
self.target_link = target_link | ||
|
||
# This LinkHook maybe instantiated multiple times. | ||
# So it is allowed to change name by argument. | ||
self.name = name | ||
self.logger = logger or getLogger(__name__) | ||
|
||
if extract_fn is None: | ||
if timing == 'pre': | ||
extract_fn = _default_extract_pre | ||
elif timing == 'post': | ||
extract_fn = _default_extract_post | ||
else: | ||
raise ValueError("Unexpected value timing={}" | ||
.format(timing)) | ||
self.extract_fn = extract_fn | ||
self.process_fns = OrderedDict() # Additional process, if necessary | ||
|
||
self.timing = timing | ||
self.result = None | ||
|
||
def add_process(self, key, fn): | ||
"""Add additional process for target variable | ||
|
||
Args: | ||
key (str): id for this process, you may remove added process by | ||
`delete_process` with this key. | ||
fn (callable): function which takes `hook, args, target_var` as | ||
arguments. | ||
""" | ||
if not isinstance(key, str): | ||
raise TypeError('key must be str, actual {}'.format(type(key))) | ||
if not callable(fn): | ||
raise TypeError('fn must be callable') | ||
self.process_fns[key] = fn | ||
|
||
def delete_process(self, key): | ||
"""Delete process added at `add_process` | ||
|
||
Args: | ||
key (str): id for the process, named at `add_process`. | ||
""" | ||
if not isinstance(key, str): | ||
raise TypeError('key must be str, actual {}'.format(type(key))) | ||
if key in self.process_fns.keys(): | ||
del self.process_fns[key] | ||
else: | ||
# Nothing to delete | ||
self.logger.warning('{} is not in process_fns, skip delete_process' | ||
.format(key)) | ||
|
||
def get_variable(self): | ||
"""Get target variable, which is input or output of `target_link`. | ||
|
||
Returns (chainer.Variable): target variable | ||
""" | ||
return self.result | ||
|
||
def forward_preprocess(self, args): | ||
if self.timing == 'pre' and args.link is self.target_link: | ||
self.result = self.extract_fn(self, args) | ||
if self.process_fns is not None: | ||
for key, fn in self.process_fns.items(): | ||
fn(self, args, self.result) | ||
|
||
def forward_postprocess(self, args): | ||
if self.timing == 'post' and args.link is self.target_link: | ||
self.result = self.extract_fn(self, args) | ||
if self.process_fns is not None: | ||
for key, fn in self.process_fns.items(): | ||
fn(self, args, self.result) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import chainer | ||
from chainer import links | ||
from chainer import functions | ||
|
||
|
||
class GraphLinear(chainer.Chain): | ||
"""Graph Linear layer. | ||
|
||
This function assumes its input is 3-dimensional. | ||
Differently from :class:`chainer.functions.linear`, it applies an affine | ||
transformation to the third axis of input `x`. | ||
|
||
.. seealso:: :class:`chainer.links.Linear` | ||
""" | ||
def __init__(self, in_channels, out_channels, nobias=False, stride=1, | ||
initialW=None, initial_bias=None, pad=0, **kwargs): | ||
super(GraphLinear, self).__init__() | ||
self.out_channels = out_channels | ||
with self.init_scope(): | ||
self.conv = links.Convolution2D( | ||
in_channels, out_channels, ksize=1, stride=stride, pad=pad, | ||
nobias=nobias, initialW=initialW, initial_bias=initial_bias, | ||
**kwargs) | ||
|
||
def __call__(self, x): | ||
"""Forward propagation. | ||
|
||
Args: | ||
x (:class:`chainer.Variable`, or :class:`numpy.ndarray`\ | ||
or :class:`cupy.ndarray`): | ||
Input array that should be a float array whose ``ndim`` is 3. | ||
|
||
It represents a minibatch of atoms, each of which consists | ||
of a sequence of molecules. Each molecule is represented | ||
by integer IDs. The first axis is an index of atoms | ||
(i.e. minibatch dimension) and the second one an index | ||
of molecules. | ||
|
||
Returns: | ||
:class:`chainer.Variable`: | ||
A 3-dimeisional array. | ||
|
||
""" | ||
h = x | ||
# (minibatch, atom, ch) | ||
s0, s1, s2 = h.shape | ||
# (minibatch, ch, atom) | ||
h = functions.transpose(h, (0, 2, 1)) | ||
# (minibatch, ch, atom, 1) | ||
h = functions.reshape(h, (s0, s2, s1, 1)) | ||
# (minibatch, out_ch, atom, 1) | ||
h = self.conv(h) | ||
# (minibatch, atom, out_ch, 1) | ||
h = functions.transpose(h, (0, 2, 1, 3)) | ||
# (minibatch, atom, out_ch) | ||
h = functions.reshape(h, (s0, s1, self.out_channels)) | ||
return h | ||
|
||
|
||
if __name__ == '__main__': | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please delete debug codes. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sorry i should remove this file |
||
import numpy as np | ||
bs = 5 | ||
ch = 4 | ||
out_ch = 7 | ||
atom = 3 | ||
x = np.random.rand(bs, atom, ch).astype(np.float32) | ||
gl = GraphLinear(ch, out_ch) | ||
y = gl(x) | ||
print('x', x.shape, 'y', y.shape) # x (5, 3, 4) y (5, 3, 7) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from chainer_chemistry.saliency import calculator # NOQA | ||
from chainer_chemistry.saliency import visualizer # NOQA |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from chainer_chemistry.saliency.calculator import base_calculator # NOQA | ||
from chainer_chemistry.saliency.calculator import common # NOQA | ||
from chainer_chemistry.saliency.calculator import gradient_calculator # NOQA | ||
from chainer_chemistry.saliency.calculator import integrated_gradients_calculator # NOQA | ||
from chainer_chemistry.saliency.calculator import occlusion_calculator # NOQA | ||
|
||
from chainer_chemistry.saliency.calculator.base_calculator import BaseCalculator # NOQA | ||
from chainer_chemistry.saliency.calculator.gradient_calculator import GradientCalculator # NOQA | ||
from chainer_chemistry.saliency.calculator.integrated_gradients_calculator import IntegratedGradientsCalculator # NOQA | ||
from chainer_chemistry.saliency.calculator.occlusion_calculator import OcclusionCalculator # NOQA | ||
|
||
from chainer_chemistry.saliency.calculator.common import GaussianNoiseSampler # NOQA |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about using
input
andoutput
, instead ofpre
andpost
?It is clear what we would like to get by this class.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So maybe
fetch_target
is better thantiming
? How do you feel?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Technically, it is actually possible to fetch "input" with timing=post, but this use case is quite limited.
So it is also ok to change name as you suggested.