-
Notifications
You must be signed in to change notification settings - Fork 131
Saliency #283
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Saliency #283
Conversation
TODO:
|
Codecov Report
@@ Coverage Diff @@
## master #283 +/- ##
==========================================
- Coverage 81.73% 75.12% -6.61%
==========================================
Files 122 137 +15
Lines 6116 6654 +538
==========================================
Hits 4999 4999
- Misses 1117 1655 +538 |
Codecov Report
@@ Coverage Diff @@
## master #283 +/- ##
==========================================
+ Coverage 81.73% 83.24% +1.51%
==========================================
Files 122 147 +25
Lines 6116 7092 +976
==========================================
+ Hits 4999 5904 +905
- Misses 1117 1188 +71 |
""" | ||
|
||
def __init__(self, target_link, name='VariableMonitorLinkHook', | ||
timing='post', extract_fn=None, logger=None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about using input
and output
, instead of pre
and post
?
It is clear what we would like to get by this class.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So maybe fetch_target
is better than timing
? How do you feel?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Technically, it is actually possible to fetch "input" with timing=post, but this use case is quite limited.
So it is also ok to change name as you suggested.
return h | ||
|
||
|
||
if __name__ == '__main__': |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please delete debug codes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sorry i should remove this file
maxv = xp.max(saliency) | ||
minv = xp.min(saliency) | ||
if maxv == minv: | ||
saliency = xp.zeros_like(saliency) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
raise Warning?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
actually this case happened several times with bayesgrad, so may be just debug logging or info logging is enough.
""" | ||
xp = cuda.get_array_module(saliency) | ||
maxv = xp.max(xp.abs(saliency)) | ||
if maxv <= 0: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
raise warning?
link_hooks = chainer._get_link_hooks() | ||
name = prefix + linkhook.name | ||
if name in link_hooks: | ||
print('[WARNING] hook {} already exists, overwrite.'.format(name)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
logging?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
name = prefix + linkhook.name | ||
link_hooks = chainer._get_link_hooks() | ||
if name not in link_hooks.keys(): | ||
print('[WARNING] linkhook {} is not registered'.format(name)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
""" | ||
xp = cuda.get_array_module(saliency) | ||
vsum = xp.sum(xp.abs(saliency), axis=axis, keepdims=True) | ||
if vsum <= 0: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto
|
||
|
||
def normalize_scaler(saliency, axis=None): | ||
"""Normalize saliency to be sum=1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
only for saliency which all value are > 0?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
modified to show warning message
num_atoms = mol.GetNumAtoms() | ||
|
||
# --- type check --- | ||
if not saliency.ndim == 1: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
saliency.ndmi != 1
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
raise ValueError("Unexpected value saliency.shape={}" | ||
.format(saliency.shape)) | ||
|
||
# Cut saliency array for unnecessary tail part |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this expected behavior? I feel when len(saliency) == num_atoms
, raising Warning is better.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it is expected. atom_array is often 0 padded with concat_mols, so we need to truncate this padded length.
in_array) for in_array in input_list] | ||
|
||
result = [_concat(output) for output in output_list] | ||
if len(result) == 1: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This specification may be confusing. And I think test is necessary.
|
||
# 2. test with `save_filepath=None` runs without error | ||
image = numpy.random.uniform(0, 1, (ch, h, w)) | ||
visualizer.visualize( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Visualilze method stops unit testing.
Can you insert plt.ion()
and plt.close()
? like this.
plt.ion()
visualizer.visualize(
saliency, save_filepath=None, feature_names=['hoge', 'huga', 'piyo'],
num_visualize=2)
plt.close()
visualizer.visualize(saliency, save_filepath=save_filepath) | ||
assert os.path.exists(save_filepath) | ||
# 2. test with `save_filepath=None` runs without error | ||
visualizer.visualize( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As same as image_visualizer
Updated based on comment! |
saliency modules
[Calculator]
[Visualizer]
[Addtional utils]