Skip to content

Commit

Permalink
Implement Numpy fancy indexing (#434)
Browse files Browse the repository at this point in the history
* index_select should use SomeInteger not SomeNumber

* Overload index_select for arrays and sequences

* Masked Selector overload for openarrays

* Add masked overload for regular arrays and sequences

* Initial support of Numpy fancy indexing: index select

* Fix broadcast operators from #429 using deprecated syntax

* Stash dispatcher, working with types in macros is a minefield nim-lang/Nim#14021

* Masked indexing: closes #400, workaround nim-lang/Nim#14021

* Test for full masked fancy indexing

* Add index_fill

* Tensor mutation via fancy indexing

* Add tests for index mutation via fancy indexing

* Fancy indexing: supports broadcasting a value to a masked assignment

* Detect wrong mask or tensor axis length

* masked axis assign value test

* Add masked assign of broadcastable tensor

* Tag for changelog [skip ci]
  • Loading branch information
mratsim authored Apr 19, 2020
1 parent bd05448 commit b2620e6
Show file tree
Hide file tree
Showing 16 changed files with 993 additions and 69 deletions.
3 changes: 3 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
Arraymancer v0.x.x
=====================================================

Changes (TODO):
- Fancy Indexing (#434)

Arraymancer v0.6.0 Jan. 09 2020
=====================================================

Expand Down
2 changes: 1 addition & 1 deletion src/arraymancer.nim
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,6 @@ export tensor,
einsum

when not defined(no_lapack):
# THe ml module also does not export everything is LAPACK is not available
# The ml module also does not export everything if LAPACK is not available
import ./linear_algebra/linear_algebra
export linear_algebra
2 changes: 1 addition & 1 deletion src/nn/layers/embedding.nim
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ proc embedding_cache[TT, Idx](
)


proc embedding*[TT; Idx: byte or char or SomeNumber](
proc embedding*[TT; Idx: byte or char or SomeInteger](
input_vocab_id: Tensor[Idx],
weight: Variable[TT],
padding_idx: Idx = -1,
Expand Down
4 changes: 2 additions & 2 deletions src/nn_primitives/nnp_embedding.nim
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import
proc flatten_idx(t: Tensor): Tensor {.inline.} =
  ## Reshape `t` into a rank-1 tensor holding its `t.size` elements.
  result = t.reshape(t.size)

proc embedding*[T; Idx: byte or char or SomeNumber](
proc embedding*[T; Idx: byte or char or SomeInteger](
vocab_id: Tensor[Idx],
weight: Tensor[T]
): Tensor[T] =
Expand Down Expand Up @@ -54,7 +54,7 @@ proc embedding*[T; Idx: byte or char or SomeNumber](
let shape = vocab_id.shape & weight.shape[1]
result = weight.index_select(0, vocab_id.flatten_idx).reshape(shape)

proc embedding_backward*[T; Idx: byte or char or SomeNumber](
proc embedding_backward*[T; Idx: byte or char or SomeInteger](
dWeight: var Tensor[T],
vocab_id: Tensor[Idx],
dOutput: Tensor[T],
Expand Down
10 changes: 9 additions & 1 deletion src/private/ast_utils.nim
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import macros


proc hasType*(x: NimNode, t: static[string]): bool {. compileTime .} =
## Compile-time type checking
sameType(x, bindSym(t))
Expand All @@ -25,6 +24,15 @@ proc isInt*(x: NimNode): bool {. compileTime .} =
## Compile-time type checking
hasType(x, "int")

proc isBool*(x: NimNode): bool {. compileTime .} =
  ## Compile-time type checking: true when `x` is typed as `bool`.
  result = x.hasType("bool")

proc isOpenarray*(x: NimNode): bool {. compileTime .} =
  ## Compile-time type checking
  ## NOTE(review): deliberately disabled — `hasType`/`sameType` cannot match
  ## generic container types (array/seq/openArray) inside generic
  ## instantiations, see the linked Nim issue. Call sites fall back to a
  ## `when ... is (array or seq)` check emitted into generated code instead.
  doAssert false, "This is broken for generics https://github.com/nim-lang/Nim/issues/14021"
  hasType(x, "array") or hasType(x, "seq") or hasType(x, "openArray")

proc isAllInt*(slice_args: NimNode): bool {. compileTime .} =
## Compile-time type checking
result = true
Expand Down
8 changes: 4 additions & 4 deletions src/tensor/accessors_macros_write.nim
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@ macro `[]=`*[T](t: var Tensor[T], args: varargs[untyped]): untyped =
slice_typed_dispatch_mut(`t`, `new_args`,`val`)


# # Linked to: https://github.com/mratsim/Arraymancer/issues/52
# Linked to: https://github.com/mratsim/Arraymancer/issues/52
# Unfortunately enabling this breaks the test suite
# "Setting a slice from a view of the same Tensor"

#
# macro `[]`*[T](t: var AnyTensor[T], args: varargs[untyped]): untyped =
# ## Slice a Tensor or a CudaTensor
# ## Input:
Expand All @@ -69,6 +69,6 @@ macro `[]=`*[T](t: var Tensor[T], args: varargs[untyped]): untyped =
# ## For CudaTensor only, this is a no-copy operation, data is shared with the input.
# ## This proc does not guarantee that a ``let`` value is immutable.
# let new_args = getAST(desugar(args))

#
# result = quote do:
# slice_typed_dispatch_var(`t`, `new_args`)
# slice_typed_dispatch_var(`t`, `new_args`)
2 changes: 1 addition & 1 deletion src/tensor/init_cpu.nim
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ proc newTensorWith*[T](shape: MetadataArray, value: T): Tensor[T] {.noInit, noSi
{.unroll: 8.}
tval = value

proc toTensor*(s:openarray, dummy_bugfix: static[int] = 0 ): auto {.noSideEffect.} =
proc toTensor*(s:openarray, dummy_bugfix: static[int] = 0): auto {.noSideEffect.} =
## Convert an openarray to a Tensor
## Input:
## - An array or a seq (can be nested)
Expand Down
12 changes: 6 additions & 6 deletions src/tensor/operators_comparison.nim
Original file line number Diff line number Diff line change
Expand Up @@ -98,38 +98,38 @@ template gen_broadcasted_scalar_comparison(op: untyped): untyped {.dirty.} =
result = map_inline(t):
op(x, value)

proc `.==`*[T](t: Tensor[T], value : T): Tensor[bool] {.noInit.} =
proc `==.`*[T](t: Tensor[T], value : T): Tensor[bool] {.noInit.} =
## Tensor element-wise equality with scalar
## Returns:
## - A tensor of boolean
gen_broadcasted_scalar_comparison(`==`)


proc `.!=`*[T](t: Tensor[T], value : T): Tensor[bool] {.noInit.} =
proc `!=.`*[T](t: Tensor[T], value : T): Tensor[bool] {.noInit.} =
## Tensor element-wise inequality with scalar
## Returns:
## - A tensor of boolean
gen_broadcasted_scalar_comparison(`!=`)

proc `.<=`*[T](t: Tensor[T], value : T): Tensor[bool] {.noInit.} =
proc `<=.`*[T](t: Tensor[T], value : T): Tensor[bool] {.noInit.} =
## Tensor element-wise lesser or equal with scalar
## Returns:
## - A tensor of boolean
gen_broadcasted_scalar_comparison(`<=`)

proc `.<`*[T](t: Tensor[T], value : T): Tensor[bool] {.noInit.} =
proc `<.`*[T](t: Tensor[T], value : T): Tensor[bool] {.noInit.} =
## Tensor element-wise lesser than a scalar
## Returns:
## - A tensor of boolean
gen_broadcasted_scalar_comparison(`<`)

proc `.>=`*[T](t: Tensor[T], value : T): Tensor[bool] {.noInit.} =
proc `>=.`*[T](t: Tensor[T], value : T): Tensor[bool] {.noInit.} =
## Tensor element-wise greater or equal than a scalar
## Returns:
## - A tensor of boolean
gen_broadcasted_scalar_comparison(`>=`)

proc `.>`*[T](t: Tensor[T], value : T): Tensor[bool] {.noInit.} =
proc `>.`*[T](t: Tensor[T], value : T): Tensor[bool] {.noInit.} =
## Tensor element-wise greater than a scalar
## Returns:
## - A tensor of boolean
Expand Down
135 changes: 132 additions & 3 deletions src/tensor/private/p_accessors_macros_read.nim
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import ../../private/ast_utils,
./p_checks, ./p_accessors,
sequtils, macros

from ../init_cpu import toTensor

template slicerImpl*[T](result: AnyTensor[T]|var AnyTensor[T], slices: ArrayOfSlices): untyped =
## Slicing routine

Expand Down Expand Up @@ -117,21 +119,148 @@ proc slicer*[T](t: Tensor[T], slices: ArrayOfSlices): Tensor[T] {.noInit,noSideE
# #########################################################################
# Dispatching logic

type FancySelectorKind* = enum
  ## Kind of fancy-indexing selector detected in a `t[...]` expression.
  FancyNone        ## No fancy selector: plain point/slice indexing.
  FancyIndex       ## Bracketed integer-literal selector, e.g. `t[[0, 1, 4], _]`.
  FancyMaskFull    ## Bracketed boolean mask as the sole indexing argument.
  FancyMaskAxis    ## Bracketed boolean mask applied along one axis.
  # Workaround needed for https://github.com/nim-lang/Nim/issues/14021
  FancyUnknownFull ## Selector whose type is undeterminable at macro time, sole argument.
  FancyUnknownAxis ## Same as above, but applied along one axis.

proc getFancySelector*(ast: NimNode, axis: var int, selector: var NimNode): FancySelectorKind =
  ## Detect indexing in the form
  ## - "tensor[_, _, [0, 1, 4], _, _]
  ## - "tensor[_, _, [0, 1, 4], `...`]
  ## or with the index selector being a tensor
  ##
  ## Out-parameters (only meaningful when the result is not FancyNone):
  ## - `axis`: position of the fancy selector among the indexing arguments
  ## - `selector`: the AST node of the selector itself
  result = FancyNone
  var foundNonSpanOrEllipsis = false
  var ellipsisAtStart = false

  # Fancy indexing only combines with full spans `_` or `...` on the other axes.
  template checkNonSpan(): untyped {.dirty.} =
    doAssert not foundNonSpanOrEllipsis,
      "Fancy indexing is only compatible with full spans `_` on non-indexed dimensions" &
      " and/or ellipsis `...`"

  var i = 0
  while i < ast.len:
    let cur = ast[i]

    # Important: sameType doesn't work for generic type like Array, Seq or Tensors ...
    # https://github.com/nim-lang/Nim/issues/14021
    if cur.sameType(bindSym"SteppedSlice") or cur.isInt():
      if cur.eqIdent"Span":
        discard
      else:
        doAssert result == FancyNone
        foundNonSpanOrEllipsis = true
    elif cur.sameType(bindSym"Ellipsis"):
      if i == ast.len - 1: # t[t.sum(axis = 1) >. 0.5, `...`]
        doAssert not ellipsisAtStart, "Cannot deduce the indexed/sliced dimensions due to ellipsis at the start and end of indexing."
        ellipsisAtStart = false
      elif i == 0: # t[`...`, t.sum(axis = 0) >. 0.5]
        ellipsisAtStart = true
      else:
        # t[0 ..< 10, `...`, t.sum(axis = 0) >. 0.5] is unsupported
        # so we tag as "foundNonSpanOrEllipsis"
        foundNonSpanOrEllipsis = true
    elif cur.kind == nnkBracket:
      # Literal bracket selector: `[0, 1, 4]` or `[true, false, true]`.
      checkNonSpan()
      axis = i
      if cur[0].kind == nnkIntLit:
        result = FancyIndex
        selector = cur
      elif cur[0].isBool():
        let full = i == 0 and ast.len == 1
        result = if full: FancyMaskFull else: FancyMaskAxis
        selector = cur
      else:
        # byte, char, enums are all represented by integers in the VM
        error "Fancy indexing is only possible with integers or booleans"
    else:
      # Any other node (e.g. a Tensor or seq expression): type is not
      # resolvable here, defer the decision to generated `when` code.
      checkNonSpan()
      axis = i
      let full = i == 0 and ast.len == 1
      result = if full: FancyUnknownFull else: FancyUnknownAxis
      selector = cur
    inc i

  # Handle ellipsis at the start: with a leading `...` the selector position
  # was counted from the left, but it is actually relative to the last axis.
  if result != FancyNone and ellipsisAtStart:
    axis = ast.len - axis

macro slice_typed_dispatch*(t: typed, args: varargs[typed]): untyped =
  ## Typed macro so that isAllInt has typed context and we can dispatch.
  ## If args are all int, we dispatch to atIndex and return T.
  ## If a fancy selector (index array / boolean mask) is present, we dispatch
  ## to index_select / masked_select / masked_axis_select and return a Tensor.
  ## Else, all ints are converted to SteppedSlices and we return a Tensor.
  ## Note, normal slices and `_` were already converted in the `[]` macro
  ## TODO in total we do 3 passes over the list of arguments :/. It is done only at compile time though

  # Point indexing
  # -----------------------------------------------------------------
  if isAllInt(args):
    result = newCall(bindSym"atIndex", t)
    for slice in args:
      result.add(slice)
    return

  # Fancy indexing
  # -----------------------------------------------------------------
  # Cannot depend/bindSym the "selectors.nim" proc
  # Due to recursive module dependencies
  var selector: NimNode
  var axis: int
  let fancy = args.getFancySelector(axis, selector)
  if fancy == FancyIndex:
    return newCall(
      ident"index_select",
      t, newLit axis, selector
    )
  if fancy == FancyMaskFull:
    return newCall(
      ident"masked_select",
      t, selector
    )
  elif fancy == FancyMaskAxis:
    return newCall(
      ident"masked_axis_select",
      t, selector, newLit axis
    )

  # Slice indexing
  # -----------------------------------------------------------------
  if fancy == FancyNone:
    result = newCall(bindSym"slicer", t)
    for slice in args:
      if isInt(slice):
        # Convert [10, 1..10|1] to [10..10|1, 1..10|1]
        result.add(infix(slice, "..", infix(slice, "|", newIntLitNode(1))))
      else:
        result.add(slice)
    return

  # Fancy bug in Nim compiler
  # -----------------------------------------------------------------
  # We need to drop down to "when a is T" to infer what selector to call
  # as `getType`/`getTypeInst`/`getTypeImpl`/`sameType`
  # are buggy with generics
  # due to https://github.com/nim-lang/Nim/issues/14021
  let lateBind_masked_select = ident"masked_select"
  let lateBind_masked_axis_select = ident"masked_axis_select"
  let lateBind_index_select = ident"index_select"

  # Emit a `when`-dispatch so the *instantiated* generic context decides
  # whether the selector is a boolean mask or an integer index collection.
  result = quote do:
    type FancyType = typeof(`selector`)
    when FancyType is (array or seq):
      type FancyTensorType = typeof(toTensor(`selector`))
    else:
      type FancyTensorType = FancyType
    when FancyTensorType is Tensor[bool]:
      when FancySelectorKind(`fancy`) == FancyUnknownFull:
        `lateBind_masked_select`(`t`, `selector`)
      elif FancySelectorKind(`fancy`) == FancyUnknownAxis:
        `lateBind_masked_axis_select`(`t`, `selector`, `axis`)
      else:
        {.error: "Unreachable".}
    else:
      `lateBind_index_select`(`t`, `axis`, `selector`)
68 changes: 65 additions & 3 deletions src/tensor/private/p_accessors_macros_write.nim
Original file line number Diff line number Diff line change
Expand Up @@ -183,24 +183,86 @@ proc slicerMut*[T](t: var Tensor[T],

macro slice_typed_dispatch_mut*(t: typed, args: varargs[typed], val: typed): untyped =
  ## Assign `val` to Tensor T at slice/position `args`.
  ## Mirrors `slice_typed_dispatch`:
  ## - all-int args dispatch to atIndexMut (point assignment),
  ## - fancy selectors dispatch to index_fill / masked_fill / masked_axis_fill,
  ## - otherwise ints are converted to SteppedSlices and we use slicerMut.

  # Point indexing
  # -----------------------------------------------------------------
  if isAllInt(args):
    result = newCall(bindSym"atIndexMut", t)
    for slice in args:
      result.add(slice)
    result.add(val)
    return

  # Fancy indexing
  # -----------------------------------------------------------------
  # Cannot depend/bindSym the "selectors.nim" proc
  # Due to recursive module dependencies
  var selector: NimNode
  var axis: int
  let fancy = args.getFancySelector(axis, selector)
  if fancy == FancyIndex:
    return newCall(
      ident"index_fill",
      t, newLit axis, selector,
      val
    )
  if fancy == FancyMaskFull:
    return newCall(
      ident"masked_fill",
      t, selector,
      val
    )
  elif fancy == FancyMaskAxis:
    return newCall(
      ident"masked_axis_fill",
      t, selector, newLit axis,
      val
    )

  # Slice indexing
  # -----------------------------------------------------------------
  if fancy == FancyNone:
    result = newCall(bindSym"slicerMut", t)
    for slice in args:
      if isInt(slice):
        # Convert [10, 1..10|1] to [10..10|1, 1..10|1]
        result.add(infix(slice, "..", infix(slice, "|", newIntLitNode(1))))
      else:
        result.add(slice)
    result.add(val)
    return

  # Fancy bug in Nim compiler
  # -----------------------------------------------------------------
  # We need to drop down to "when a is T" to infer what selector to call
  # as `getType`/`getTypeInst`/`getTypeImpl`/`sameType`
  # are buggy with generics
  # due to https://github.com/nim-lang/Nim/issues/14021
  let lateBind_masked_fill = ident"masked_fill"
  let lateBind_masked_axis_fill = ident"masked_axis_fill"
  let lateBind_index_fill = ident"index_fill"

  # Emit a `when`-dispatch so the *instantiated* generic context decides
  # whether the selector is a boolean mask or an integer index collection.
  result = quote do:
    type FancyType = typeof(`selector`)
    when FancyType is (array or seq):
      type FancyTensorType = typeof(toTensor(`selector`))
    else:
      type FancyTensorType = FancyType
    when FancyTensorType is Tensor[bool]:
      when FancySelectorKind(`fancy`) == FancyUnknownFull:
        `lateBind_masked_fill`(`t`, `selector`, `val`)
      elif FancySelectorKind(`fancy`) == FancyUnknownAxis:
        `lateBind_masked_axis_fill`(`t`, `selector`, `axis`, `val`)
      else:
        {.error: "Unreachable".}
    else:
      `lateBind_index_fill`(`t`, `axis`, `selector`, `val`)

# ############################################################################
# Slicing a var returns a var (for Result[_] += support)
# And apply2(result[_], foo) support
#
# Unused: Nim support for var return types is problematic

proc slicer_var[T](t: var AnyTensor[T], slices: varargs[SteppedSlice]): var AnyTensor[T] {.noInit,noSideEffect.}=
## Take a Tensor and SteppedSlices
Expand Down
Loading

0 comments on commit b2620e6

Please sign in to comment.