Skip to content

Commit dc99ce7

Browse files
authored
Merge pull request #246 from corochann/features_indexer_with_length0
support features indexer with length 0
2 parents 98e2089 + 886528c commit dc99ce7

File tree

3 files changed

+35
-14
lines changed

3 files changed

+35
-14
lines changed

chainer_chemistry/dataset/indexer.py

Lines changed: 17 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -159,16 +159,24 @@ def _extract_feature(self, data_index, j):
159159
res = [self.extract_feature(i, j) for i in
160160
six.moves.range(current, stop, step)]
161161
elif isinstance(data_index, (list, numpy.ndarray)):
162-
if isinstance(data_index[0], (bool, numpy.bool, numpy.bool_)):
163-
# Access by bool flag list
164-
if len(data_index) != self.dataset_length:
165-
raise ValueError('Feature index wrong length {} instead of'
166-
' {}'.format(len(data_index),
167-
self.dataset_length))
168-
data_index = numpy.argwhere(data_index).ravel()
169-
170-
res = [self.extract_feature(i, j) for i in data_index]
162+
if len(data_index) == 0:
163+
try:
164+
# HACKING
165+
return self.extract_feature_by_slice(slice(0, 0, 1), j)
166+
except ExtractBySliceNotSupportedError:
167+
res = []
168+
else:
169+
if isinstance(data_index[0], (bool, numpy.bool, numpy.bool_)):
170+
# Access by bool flag list
171+
if len(data_index) != self.dataset_length:
172+
raise ValueError('Feature index wrong length {} instead of'
173+
' {}'.format(len(data_index),
174+
self.dataset_length))
175+
data_index = numpy.argwhere(data_index).ravel()
176+
177+
res = [self.extract_feature(i, j) for i in data_index]
171178
else:
179+
# `data_index` is expected to be `int`
172180
return self.extract_feature(data_index, j)
173181
try:
174182
feature = numpy.asarray(res)

tests/dataset_tests/test_numpy_tuple_feature_indexer.py

Lines changed: 15 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -26,20 +26,32 @@ class TestNumpyTupleDatasetFeatureIndexer(object):
2626
def test_feature_length(self, indexer):
2727
assert indexer.features_length() == 3
2828

29-
@pytest.mark.parametrize('slice_index', [0, 1, slice(0, 2, None)])
29+
@pytest.mark.parametrize('slice_index', [
30+
0, 1, slice(0, 2, None), slice(0, 0, None)])
3031
@pytest.mark.parametrize('j', [0, 1])
3132
def test_extract_feature_by_slice(self, indexer, data, slice_index, j):
3233
numpy.testing.assert_array_equal(
3334
indexer.extract_feature_by_slice(slice_index, j),
3435
data[j][slice_index])
36+
# indexer's __getitem__ should call `extract_feature_by_slice` method,
37+
# result should be same with above.
38+
numpy.testing.assert_array_equal(
39+
indexer[slice_index, j],
40+
data[j][slice_index])
3541

36-
@pytest.mark.parametrize('ndarray_index', [numpy.asarray([0, 1]),
37-
numpy.asarray([1])])
42+
@pytest.mark.parametrize('ndarray_index', [
43+
numpy.asarray([0, 1]), numpy.asarray([1]),
44+
numpy.asarray([], dtype=numpy.int32)])
3845
@pytest.mark.parametrize('j', [0, 1])
3946
def test_extract_feature_by_ndarray(self, indexer, data, ndarray_index, j):
4047
numpy.testing.assert_array_equal(
4148
indexer.extract_feature_by_slice(ndarray_index, j),
4249
data[j][ndarray_index])
50+
# indexer's __getitem__ should call `extract_feature_by_slice` method,
51+
# result should be same with above.
52+
numpy.testing.assert_array_equal(
53+
indexer[ndarray_index, j],
54+
data[j][ndarray_index])
4355

4456

4557
if __name__ == '__main__':

tests/datasets_tests/test_numpy_tuple_dataset.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -55,8 +55,9 @@ def test_get_item_slice_index(self, data, index):
5555
for a, e in six.moves.zip(tuple_a, tuple_e):
5656
numpy.testing.assert_array_equal(a, e)
5757

58-
@pytest.mark.parametrize('index', [numpy.asarray([2, 0]),
59-
numpy.asarray([1])])
58+
@pytest.mark.parametrize('index', [
59+
numpy.asarray([2, 0]), numpy.asarray([1]),
60+
numpy.asarray([], dtype=numpy.int32)])
6061
def test_get_item_ndarray_index(self, long_data, index):
6162
dataset = NumpyTupleDataset(*long_data)
6263
actual = dataset[index]

0 commit comments

Comments
 (0)