-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
WIP: indexing with broadcasting #1473
Changes from 2 commits
105bd64
23b4fe0
726ba5d
df7011f
f9232cb
17b6465
33c51d3
03a336f
d5af395
866de91
08e7444
84afc98
7b33269
50ea56e
bac0089
1206c28
c2747be
0671f39
ffccff1
becf539
1ae4b4c
1967bf5
c2eeff3
df12c04
5ba367d
d25c1f1
0115994
1b4e854
36d052f
9bd53ca
563cafa
1712060
884423a
c2e6f42
002eafa
eedfb3f
bb2e515
5983a67
7a5ff79
0b559bc
bad828e
a821a2b
6550880
464e711
7dd171d
e8f006b
a8ec82b
32749d4
31401d4
f42ddfd
8d96ad3
19f7204
a8f60ba
69f8570
5eb00b7
d133766
631f6e9
72587de
3231445
dd325c5
434a004
d518f7a
ba3cc88
f63f3d5
aa10635
fd73e82
6202aff
f9746fd
f78c932
1c027cd
d11829f
0777128
f580c99
4ebe852
20f5cb9
ab08af8
92dded6
a624424
a4cd724
f66c9b6
24309c4
fd698de
9cbaff9
a5c7766
f242166
bff18f0
21c11c4
1975f66
73ad94e
b49f813
1fd6b3a
46dd7c7
11f3e4f
1d3eddc
bcb25f1
7104964
173968b
addb91a
7ad7d36
91dd833
118a5d8
dc9f8a6
5726c89
523ecaa
a3a83db
765ae45
3deaf5c
c8c8a12
a16a04b
1b34cd4
24599a7
d5d967b
d0d6a6f
4f08e2e
fbbe35c
db23c93
031be9a
969f9cf
8608451
b4e5b36
9523039
dc60348
8a62ad9
6b96960
cb84154
9726531
caa79fe
170abc5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,6 +28,7 @@ | |
pass | ||
|
||
|
||
|
||
def as_variable(obj, name=None): | ||
"""Convert an object into a Variable. | ||
|
||
|
@@ -406,6 +407,96 @@ def __getitem__(self, key): | |
return type(self)(dims, values, self._attrs, self._encoding, | ||
fastpath=True) | ||
|
||
def _broadcast_indexes(self, key): | ||
""" | ||
Parameters | ||
----------- | ||
key: One of | ||
array | ||
a mapping of dimension names to index. | ||
|
||
Returns | ||
------- | ||
dims: Tuple of strings. | ||
Dimension of the resultant variable. | ||
indexers: list of integer, array-like, or slice. This is aligned | ||
along self.dims. | ||
""" | ||
if not utils.is_dict_like(key): | ||
key = {self.dims[0]: key} | ||
example_v = None | ||
indexes = OrderedDict() | ||
for k, v in key.items(): | ||
if not isinstance(v, (integer_types, slice, Variable)): | ||
if not hasattr(key, 'ndim'): # convert list or tuple | ||
v = np.array(v) | ||
if example_v is None and isinstance(v, Variable): | ||
example_v = v | ||
indexes[k] = v | ||
|
||
# When all the keys are array or integer, slice | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we actually need two totally different code paths for basic vs. advanced indexing:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here's an alternative version (basically what I described above) that I think will work, at least if the variable's data is a NumPy array (we may need to jump through a few more hoops for dask arrays): def _broadcast_indexes(self, key):
key = self._item_key_to_tuple(key) # key is a tuple
key = indexing.expanded_indexer(key, self.ndim) # key is a tuple of full size
basic_indexing_types = integer_types + (slice,)
if all(isinstance(k, basic_indexing_types) for k in key):
return self._broadcast_indexes_basic(key)
else:
return self._broadcast_indexes_advanced(key)
def _broadcast_indexes_basic(self, key):
dims = tuple(dim for k, dim in zip(key, self.dims)
if not isinstance(k, integer_types))
return dims, key
def _broadcast_indexes_advanced(self, key):
variables = []
for dim, value in zip(self.dims, key):
if isinstance(value, slice):
value = np.arange(*value.indices(self.sizes[dim])
# NOTE: this is close to but not quite correct, since we want to
# handle tuples differently than as_variable and want a different
# error message (not referencing tuples)
variable = as_variable(value, name=dim)
variables.append(variable)
variables = _broadcast_compat_variables(*variables)
dims = variables[0].dims # all variables have the same dims
key = tuple(variable.data for variable in variables)
return dims, key There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Thanks for the suggestion! Please let me make sure one thing. What would be expected by the following? v = Variable(['x', 'y'], [[0, 1, 2], [3, 4, 5]])
ind_x = Variable(['a'], [0, 1])
v.getitem2(dict(x=ind_x, y=[1, 0]))) It should be understood as There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, I think the result here would look like
|
||
if example_v is None: | ||
# more than one (unlabelled) arrays | ||
if len([v for k, v in indexes.items() | ||
if not isinstance(v, (integer_types, slice))]) > 1: | ||
raise IndexError("Indexing with multiple unlabelled arrays " | ||
"is not allowed.") | ||
# multi-dimensional unlabelled array | ||
if any([v.ndim > 1 for k, v in indexes.items() | ||
if not isinstance(v, integer_types)]): | ||
raise IndexError("Indexing with a multi-dimensional unlabelled" | ||
"array is not allowed.") | ||
# convert the array into Variable | ||
for k, v in indexes.items(): | ||
if not hasattr(v, 'dims'): | ||
indexes[k] = type(self)([k], v) | ||
example_v = v | ||
|
||
for k, v in indexes.items(): | ||
# Found unlabelled array | ||
if not isinstance(v, (Variable, integer_types, slice)): | ||
if (v.ndim > example_v.ndim or | ||
any([example_v.ndim != v.ndim for k, v | ||
in indexes.items() if isinstance(v, Variable)])): | ||
raise IndexError("Broadcasting failed because dimensions " | ||
"does not match.") | ||
else: | ||
_, indexes[k], _ = _broadcast_compat_data(example_v, v) | ||
|
||
index_tuple = tuple(indexes.get(d, slice(None)) for d in self.dims) | ||
index_tuple = indexing.expanded_indexer(index_tuple, self.ndim) | ||
|
||
# comput dims | ||
dims = [] | ||
for i, d in enumerate(self.dims): | ||
if d in indexes.keys(): | ||
if isinstance(v, Variable): | ||
for d in v.dims: | ||
if d not in dims: | ||
dims.append(d) | ||
else: | ||
dims.append(d) | ||
|
||
return dims, index_tuple | ||
|
||
def getitem2(self, key): | ||
"""Return a new Array object whose contents are consistent with | ||
getting the provided key from the underlying data. | ||
|
||
NB. __getitem__ and __setitem__ implement "diagonal indexing" like | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure I like the name "diagonal indexing". |
||
np.ndarray. | ||
|
||
This method will replace __getitem__ after we make sure its stability. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. delete |
||
""" | ||
dims, index_tuple = self._broadcast_indexes(key) | ||
values = self._data[index_tuple] | ||
if hasattr(values, 'ndim'): | ||
assert values.ndim == len(dims), (values.ndim, len(dims)) | ||
else: | ||
assert len(dims) == 0, len(dims) | ||
return type(self)(dims, values, self._attrs, self._encoding, | ||
fastpath=True) | ||
|
||
def __setitem__(self, key, value): | ||
"""__setitem__ is overloaded to access the underlying numpy values with | ||
orthogonal indexing. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note that if
key
is a tuple, it should be paired with multiple dimensions.