Skip to content

Saliency #283

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 29 commits into from
Dec 11, 2018
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions chainer_chemistry/link_hooks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
try:
from chainer_chemistry.link_hooks import variable_monitor_link_hook # NOQA

from chainer_chemistry.link_hooks.variable_monitor_link_hook import VariableMonitorLinkHook # NOQA
is_link_hooks_available = True
except ImportError:
import warnings
warnings.warn('link_hooks failed to import, you need to upgrade chainer '
'version to use link_hooks feature')
is_link_hooks_available = False
168 changes: 168 additions & 0 deletions chainer_chemistry/link_hooks/variable_monitor_link_hook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
from collections import OrderedDict
from logging import getLogger

import chainer
from chainer.link_hook import _ForwardPostprocessCallbackArgs, _ForwardPreprocessCallbackArgs # NOQA


def _default_extract_pre(hook, args):
"""Default extract_fn when `timing='pre`

Args:
hook (VariableMonitorLinkHook):
args (_ForwardPreprocessCallbackArgs):

Returns (chainer.Variable): First input variable to the link.
"""
return args.args[0]


def _default_extract_post(hook, args):
"""Default extract_fn when `timing='post`

Args:
hook (VariableMonitorLinkHook):
args (_ForwardPostprocessCallbackArgs):

Returns (chainer.Variable): Output variable to the link.
"""
return args.out


class VariableMonitorLinkHook(chainer.LinkHook):
"""Monitor Variable of specific link input/output

Args:
target_link (chainer.Link): target link to monitor variable.
name (str): name of this link hook
timing (str): timing of this link hook to monitor. 'pre' or 'post'.
If 'pre', the input of `target_link` is monitored.
If 'post', the output of `target_link` is monitored.
extract_fn (callable): Specify custom method to extract target variable
Default behavior is to extract first input when `timing='pre'`,
or extract output when `timing='post'`.
It takes `hook, args` as argument.
logger:

.. admonition:: Example

>>> import numpy
>>> from chainer import cuda, links, functions # NOQA
>>> from chainer_chemistry.link_hooks.variable_monitor_link_hook import VariableMonitorLinkHook # NOQA

>>> class DummyModel(chainer.Chain):
>>> def __init__(self):
>>> super(DummyModel, self).__init__()
>>> with self.init_scope():
>>> self.l1 = links.Linear(None, 1)
>>> self.h = None
>>>
>>> def forward(self, x):
>>> h = self.l1(x)
>>> out = functions.sigmoid(h)
>>> return out

>>> model = DummyModel()
>>> hook = VariableMonitorLinkHook(model.l1, timing='post')
>>> x = numpy.array([1, 2, 3])

>>> # Example 1. `get_variable` of `target_link`.
>>> with hook:
>>> out = model(x)
>>> # You can extract `h`, which is output of `model.l1` as follows.
>>> var_h = hook.get_variable()

>>> # Example 2. `add_process` to override value of target variable.
>>> def _process_zeros(hook, args, target_var):
>>> xp = cuda.get_array_module(target_var.array)
>>> target_var.array = xp.zeros(target_var.array.shape)
>>> hook.add_process('_process_zeros', _process_zeros)
>>> with hook:
>>> # During the forward, `h` is overriden to value 0.
>>> out = model(x)
>>> # Remove _process_zeros method
>>> hook.delete_process('_process_zeros')
"""

def __init__(self, target_link, name='VariableMonitorLinkHook',
timing='post', extract_fn=None, logger=None):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about using input and output, instead of pre and post?
It is clear what we would like to get by this class.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So maybe fetch_target is better than timing? How do you feel?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically, it is actually possible to fetch "input" with timing=post, but this use case is quite limited.
So it is also ok to change name as you suggested.

if not isinstance(target_link, chainer.Link):
raise TypeError('target_link must be instance of chainer.Link!'
'actual {}'.format(type(target_link)))
if timing not in ['pre', 'post']:
raise ValueError(
"Unexpected value timing={}, "
"must be either pre or post"
.format(timing))
super(VariableMonitorLinkHook, self).__init__()
self.target_link = target_link

# This LinkHook maybe instantiated multiple times.
# So it is allowed to change name by argument.
self.name = name
self.logger = logger or getLogger(__name__)

if extract_fn is None:
if timing == 'pre':
extract_fn = _default_extract_pre
elif timing == 'post':
extract_fn = _default_extract_post
else:
raise ValueError("Unexpected value timing={}"
.format(timing))
self.extract_fn = extract_fn
self.process_fns = OrderedDict() # Additional process, if necessary

self.timing = timing
self.result = None

def add_process(self, key, fn):
"""Add additional process for target variable

Args:
key (str): id for this process, you may remove added process by
`delete_process` with this key.
fn (callable): function which takes `hook, args, target_var` as
arguments.
"""
if not isinstance(key, str):
raise TypeError('key must be str, actual {}'.format(type(key)))
if not callable(fn):
raise TypeError('fn must be callable')
self.process_fns[key] = fn

def delete_process(self, key):
"""Delete process added at `add_process`

Args:
key (str): id for the process, named at `add_process`.
"""
if not isinstance(key, str):
raise TypeError('key must be str, actual {}'.format(type(key)))
if key in self.process_fns.keys():
del self.process_fns[key]
else:
# Nothing to delete
self.logger.warning('{} is not in process_fns, skip delete_process'
.format(key))

def get_variable(self):
"""Get target variable, which is input or output of `target_link`.

Returns (chainer.Variable): target variable
"""
return self.result

def forward_preprocess(self, args):
if self.timing == 'pre' and args.link is self.target_link:
self.result = self.extract_fn(self, args)
if self.process_fns is not None:
for key, fn in self.process_fns.items():
fn(self, args, self.result)

def forward_postprocess(self, args):
if self.timing == 'post' and args.link is self.target_link:
self.result = self.extract_fn(self, args)
if self.process_fns is not None:
for key, fn in self.process_fns.items():
fn(self, args, self.result)
69 changes: 69 additions & 0 deletions chainer_chemistry/links/graph_linear_conv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import chainer
from chainer import links
from chainer import functions


class GraphLinear(chainer.Chain):
"""Graph Linear layer.

This function assumes its input is 3-dimensional.
Differently from :class:`chainer.functions.linear`, it applies an affine
transformation to the third axis of input `x`.

.. seealso:: :class:`chainer.links.Linear`
"""
def __init__(self, in_channels, out_channels, nobias=False, stride=1,
initialW=None, initial_bias=None, pad=0, **kwargs):
super(GraphLinear, self).__init__()
self.out_channels = out_channels
with self.init_scope():
self.conv = links.Convolution2D(
in_channels, out_channels, ksize=1, stride=stride, pad=pad,
nobias=nobias, initialW=initialW, initial_bias=initial_bias,
**kwargs)

def __call__(self, x):
"""Forward propagation.

Args:
x (:class:`chainer.Variable`, or :class:`numpy.ndarray`\
or :class:`cupy.ndarray`):
Input array that should be a float array whose ``ndim`` is 3.

It represents a minibatch of atoms, each of which consists
of a sequence of molecules. Each molecule is represented
by integer IDs. The first axis is an index of atoms
(i.e. minibatch dimension) and the second one an index
of molecules.

Returns:
:class:`chainer.Variable`:
A 3-dimeisional array.

"""
h = x
# (minibatch, atom, ch)
s0, s1, s2 = h.shape
# (minibatch, ch, atom)
h = functions.transpose(h, (0, 2, 1))
# (minibatch, ch, atom, 1)
h = functions.reshape(h, (s0, s2, s1, 1))
# (minibatch, out_ch, atom, 1)
h = self.conv(h)
# (minibatch, atom, out_ch, 1)
h = functions.transpose(h, (0, 2, 1, 3))
# (minibatch, atom, out_ch)
h = functions.reshape(h, (s0, s1, self.out_channels))
return h


if __name__ == '__main__':
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please delete debug codes.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry i should remove this file

import numpy as np
bs = 5
ch = 4
out_ch = 7
atom = 3
x = np.random.rand(bs, atom, ch).astype(np.float32)
gl = GraphLinear(ch, out_ch)
y = gl(x)
print('x', x.shape, 'y', y.shape) # x (5, 3, 4) y (5, 3, 7)
2 changes: 2 additions & 0 deletions chainer_chemistry/saliency/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from chainer_chemistry.saliency import calculator # NOQA
from chainer_chemistry.saliency import visualizer # NOQA
12 changes: 12 additions & 0 deletions chainer_chemistry/saliency/calculator/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from chainer_chemistry.saliency.calculator import base_calculator # NOQA
from chainer_chemistry.saliency.calculator import common # NOQA
from chainer_chemistry.saliency.calculator import gradient_calculator # NOQA
from chainer_chemistry.saliency.calculator import integrated_gradients_calculator # NOQA
from chainer_chemistry.saliency.calculator import occlusion_calculator # NOQA

from chainer_chemistry.saliency.calculator.base_calculator import BaseCalculator # NOQA
from chainer_chemistry.saliency.calculator.gradient_calculator import GradientCalculator # NOQA
from chainer_chemistry.saliency.calculator.integrated_gradients_calculator import IntegratedGradientsCalculator # NOQA
from chainer_chemistry.saliency.calculator.occlusion_calculator import OcclusionCalculator # NOQA

from chainer_chemistry.saliency.calculator.common import GaussianNoiseSampler # NOQA
Loading