Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TF] Added support for advanced indexing and slicing #23684

Merged
merged 31 commits into from
Apr 17, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
96a894a
Initial attempt at advanced indexing and slicing.
eaplatanios Mar 30, 2019
0afbbea
Added '@inlinable'.
eaplatanios Mar 30, 2019
5610318
Removed the commented out implementations of 'Tensor.subscript'.
eaplatanios Mar 30, 2019
40f6021
Addressed Richard's comments.
eaplatanios Mar 30, 2019
e2cdf31
Minor edits.
eaplatanios Apr 9, 2019
2ae2c2c
Resolved merge conflict.
eaplatanios Apr 9, 2019
d0f43b8
Updated some of the dependencies.
eaplatanios Apr 9, 2019
4237986
Fixed the build script to work with the latest TensorFlow updates.
eaplatanios Apr 10, 2019
0f37daa
Additional fix to the build script.
eaplatanios Apr 10, 2019
6d55664
Added support for tensor advanced indexing subscript setters.
eaplatanios Apr 10, 2019
7e6d040
Additional fix to the build script.
eaplatanios Apr 10, 2019
c22a0e3
Updated dependencies.
eaplatanios Apr 11, 2019
d739d4e
Merged upstream changes.
eaplatanios Apr 12, 2019
44d22a8
Merge remote-tracking branch 'upstream/tensorflow' into advanced-inde…
eaplatanios Apr 12, 2019
b6c9156
Merge remote-tracking branch 'upstream/tensorflow' into advanced-inde…
eaplatanios Apr 13, 2019
89f2801
Added a couple tests for the tensor subscript setter.
eaplatanios Apr 14, 2019
7c57019
Added subscript getter VJP.
eaplatanios Apr 14, 2019
216aa79
Made the tensor subscript infinitely differentiable.
eaplatanios Apr 14, 2019
7626d13
Added VJP for slice.
eaplatanios Apr 14, 2019
9a3e1a7
Merged upstream changes.
eaplatanios Apr 14, 2019
548f98c
Updated the TensorFlow dependency.
eaplatanios Apr 15, 2019
f745756
Addressed some of Richard's comments.
eaplatanios Apr 15, 2019
e291758
Did some refactoring.
eaplatanios Apr 15, 2019
08a2950
Added some convenient helpers.
eaplatanios Apr 15, 2019
f4cbf03
Addressed Richard's comments.
eaplatanios Apr 15, 2019
ae2f4a0
Addressed some of Richard's comments.
eaplatanios Apr 16, 2019
72abd58
Made some modifications to the tensor indexing helpers and added a fe…
eaplatanios Apr 16, 2019
f26fad8
Added a stride operator and addressed Richard's comments.
eaplatanios Apr 16, 2019
d42ed49
Merged upstream changes.
eaplatanios Apr 17, 2019
00af1a4
Merged upstream changes.
eaplatanios Apr 17, 2019
01b0f2f
Switched from Int32 to Int for the tensor advanced indexing ops.
eaplatanios Apr 17, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
319 changes: 280 additions & 39 deletions stdlib/public/TensorFlow/Ops.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1719,56 +1719,297 @@ public extension Tensor where Scalar : Numeric {
// Indexing and slicing
//===----------------------------------------------------------------------===//

// TODO: Negative indexing and strides syntax.

public extension Tensor {
/// Access the element tensor specified by an index in the leading dimension.
/// - Parameter index: Index of the element tensor.
/// Extracts a slice from the tensor defined by lower and upper bounds for
/// each dimension.
///
/// - Parameter lowerBounds: The lower bounds at each dimension.
/// - Parameter upperBounds: The upper bounds at each dimension.
@inlinable
@differentiable(wrt: self)
func slice(lowerBounds: [Int], upperBounds: [Int]) -> Tensor {
// TODO: Precondition `lowerBounds.count == upperBounds.count`,
// preferably in graph.
// TODO: Differentiating control flow is not supported yet, thus the thunks.
let lowerBoundsTensor = Tensor<Int32>({lowerBounds.map(Int32.init)}())
let upperBoundsTensor = Tensor<Int32>({upperBounds.map(Int32.init)}())
return slice(
lowerBounds: lowerBoundsTensor,
sizes: upperBoundsTensor - lowerBoundsTensor)
}

@inlinable
subscript(index: Int) -> Tensor {
@differentiable(wrt: self, vjp: _vjpSlice)
func slice(lowerBounds: Tensor<Int32>, sizes: Tensor<Int32>) -> Tensor {
return Raw.slice(self, begin: lowerBounds, size: sizes)
}

@inlinable
internal func _vjpSlice(
lowerBounds: Tensor<Int32>,
sizes: Tensor<Int32>
) -> (Tensor, (Tensor) -> Tensor) {
let value = slice(lowerBounds: lowerBounds, sizes: sizes)
let afterPaddings = shapeTensor - value.shapeTensor - lowerBounds
return (value, { [after = afterPaddings] v in
let beforePaddings = lowerBounds.expandingShape(at: 1)
let afterPaddings = after.expandingShape(at: 1)
let paddings = Tensor<Int32>(
concatenating: [beforePaddings, afterPaddings], alongAxis: 1)
return Raw.pad(v, paddings: paddings)
})
}
}

public enum TensorRange : TensorRangeExpression {
case ellipsis
case newAxis
case squeezeAxis
case index(Int)
case range(Range<Int>, stride: Int)
case closedRange(ClosedRange<Int>, stride: Int)
case partialRangeFrom(PartialRangeFrom<Int>, stride: Int)
case partialRangeUpTo(PartialRangeUpTo<Int>, stride: Int)
case partialRangeThrough(PartialRangeThrough<Int>, stride: Int)

public var tensorRange: TensorRange { return self }
}

extension TensorRange : Equatable {
public static func == (lhs: TensorRange, rhs: TensorRange) -> Bool {
switch (lhs, rhs) {
case (.ellipsis, .ellipsis),
(.newAxis, .newAxis),
(.squeezeAxis, .squeezeAxis):
return true
case (let .index(i1), let .index(i2)): return i1 == i2
case (let .range(r1, s1), let .range(r2, s2)): return r1 == r2 && s1 == s2
case (let .closedRange(r1, s1), let .closedRange(r2, s2)):
return r1 == r2 && s1 == s2
case (let .partialRangeFrom(r1, s1), let .partialRangeFrom(r2, s2)):
return r1.lowerBound == r2.lowerBound && s1 == s2
case (let .partialRangeUpTo(r1, s1), let .partialRangeUpTo(r2, s2)):
return r1.upperBound == r2.upperBound && s1 == s2
case (let .partialRangeThrough(r1, s1), let .partialRangeThrough(r2, s2)):
return r1.upperBound == r2.upperBound && s1 == s2
default: return false
}
}
}

public protocol TensorRangeExpression {
var tensorRange: TensorRange { get }
}

// TODO: Cannot extend non-nominal type 'UnboundedRange'.
// extension UnboundedRange : TensorRangeExpression {
eaplatanios marked this conversation as resolved.
Show resolved Hide resolved
// public var tensorRange: TensorRange { return .ellipsis }
// }

extension Int : TensorRangeExpression {
public var tensorRange: TensorRange { return .index(self) }
}

extension Range : TensorRangeExpression where Bound == Int {
public var tensorRange: TensorRange {
return .range(self, stride: 1)
}
}

extension ClosedRange : TensorRangeExpression where Bound == Int {
public var tensorRange: TensorRange {
return .closedRange(self, stride: 1)
}
}

extension PartialRangeFrom : TensorRangeExpression where Bound == Int {
public var tensorRange: TensorRange {
return .partialRangeFrom(self, stride: 1)
}
}

extension PartialRangeUpTo : TensorRangeExpression where Bound == Int {
public var tensorRange: TensorRange {
return .partialRangeUpTo(self, stride: 1)
}
}

extension PartialRangeThrough : TensorRangeExpression where Bound == Int {
public var tensorRange: TensorRange {
return .partialRangeThrough(self, stride: 1)
}
}

infix operator .. : StridedRangeFormationPrecedence
precedencegroup StridedRangeFormationPrecedence {
associativity: left
higherThan: CastingPrecedence
lowerThan: RangeFormationPrecedence
}

public extension Range where Bound == Int {
static func .. (range: Range, stride: Int) -> TensorRange {
return .range(range, stride: stride)
}
}

public extension ClosedRange where Bound == Int {
static func .. (range: ClosedRange, stride: Int) -> TensorRange {
return .closedRange(range, stride: stride)
}
}

public extension PartialRangeFrom where Bound == Int {
static func .. (range: PartialRangeFrom, stride: Int) -> TensorRange {
return .partialRangeFrom(range, stride: stride)
}
}

public extension PartialRangeUpTo where Bound == Int {
static func .. (range: PartialRangeUpTo, stride: Int) -> TensorRange {
return .partialRangeUpTo(range, stride: stride)
}
}

public extension PartialRangeThrough where Bound == Int {
static func .. (range: PartialRangeThrough, stride: Int) -> TensorRange {
return .partialRangeThrough(range, stride: stride)
}
}

public extension Tensor {
@_fixed_layout @usableFromInline
internal struct IndexPath {
@usableFromInline
let begin, end, strides: Tensor<Int32>

@usableFromInline
let beginMask, endMask, ellipsisMask, newAxisMask, squeezeAxisMask: Int64

@inlinable
public init(
begin: Tensor<Int32>, end: Tensor<Int32>, strides: Tensor<Int32>,
beginMask: Int64, endMask: Int64, ellipsisMask: Int64, newAxisMask: Int64,
squeezeAxisMask: Int64
) {
self.begin = begin
self.end = end
self.strides = strides
self.beginMask = beginMask
self.endMask = endMask
self.ellipsisMask = ellipsisMask
self.newAxisMask = newAxisMask
self.squeezeAxisMask = squeezeAxisMask
}
}

@inlinable
@differentiable(wrt: self, vjp: _vjpSubscript)
internal subscript(_ indexPath: IndexPath) -> Tensor {
get {
let index = Int32(index)
let slice = Raw.stridedSlice(
self, begin: Tensor<Int32>([index]), end: Tensor<Int32>([index + 1]),
strides: Tensor<Int32>([1]))
return slice.squeezingShape(at: 0)
return Raw.stridedSlice(
self, begin: indexPath.begin, end: indexPath.end,
strides: indexPath.strides, beginMask: indexPath.beginMask,
endMask: indexPath.endMask, ellipsisMask: indexPath.ellipsisMask,
newAxisMask: indexPath.newAxisMask,
shrinkAxisMask: indexPath.squeezeAxisMask)
}
set {
let leftElements = self[0..<index]
let rightElements = self[index+1..<shape[0]]
self = Raw.concatV2([leftElements, newValue.rankLifted(), rightElements],
axis: Tensor<Int32>(0))
self = Raw.tensorStridedSliceUpdate(
self, begin: indexPath.begin, end: indexPath.end,
strides: indexPath.strides, value: newValue,
beginMask: indexPath.beginMask, endMask: indexPath.endMask,
ellipsisMask: indexPath.ellipsisMask,
newAxisMask: indexPath.newAxisMask,
shrinkAxisMask: indexPath.squeezeAxisMask)
}
}

/// Access the subtensor specified by a contiguous range of indices.
/// - Parameter bounds: Contiguous range of indices.
@inlinable
subscript(bounds: Range<Int>) -> Tensor {
return Raw.stridedSlice(
self,
begin: Tensor<Int32>([Int32(bounds.lowerBound)]),
end: Tensor<Int32>([Int32(bounds.upperBound)]),
strides: Tensor<Int32>([1]))
// TODO: @differentiable(wrt: self)
subscript(_ ranges: TensorRangeExpression...) -> Tensor {
get {
return self[IndexPath(ranges.map { $0.tensorRange })]
}
set {
self[IndexPath(ranges.map { $0.tensorRange })] = newValue
}
}

@usableFromInline
eaplatanios marked this conversation as resolved.
Show resolved Hide resolved
internal func _vjpSubscript(
_ indexPath: IndexPath
) -> (Tensor, (Tensor) -> Tensor) {
return (self[indexPath], { [shape = shapeTensor] v in
Raw.stridedSliceGrad(
shape: shape, begin: indexPath.begin, end: indexPath.end,
strides: indexPath.strides, dy: v, beginMask: indexPath.beginMask,
endMask: indexPath.endMask, ellipsisMask: indexPath.ellipsisMask,
newAxisMask: indexPath.newAxisMask,
shrinkAxisMask: indexPath.squeezeAxisMask)
})
}
}

// TODO(danielzheng): Add strided slices? (increment by something different
// than 1)
// Ideas for strided slice API: it could be another subscript method, or it
// be a top level `stride` function like Swift's `stride(from:to:by:)`.
internal extension Tensor.IndexPath {
@inlinable
init(_ ranges: [TensorRange]) {
precondition(!ranges.isEmpty, "The tensor range collection cannot be empty.")
precondition(ranges.count { $0 == TensorRange.ellipsis } < 2,
"Only one ellipsis is allowed per tensor range collection.")

var begin = [Int32](repeating: 0, count: ranges.count)
var end = [Int32](repeating: 0, count: ranges.count)
var strides = [Int32](repeating: 1, count: ranges.count)
var beginMask: Int64 = 0
var endMask: Int64 = 0
var ellipsisMask: Int64 = 0
var newAxisMask: Int64 = 0
var squeezeAxisMask: Int64 = 0
for (i, index) in ranges.enumerated() {
switch index {
case .ellipsis: ellipsisMask |= 1 << i
case .newAxis: newAxisMask |= 1 << i
case .squeezeAxis: squeezeAxisMask |= 1 << i
case .index(let index):
begin[i] = Int32(index)
end[i] = Int32(index) + 1
squeezeAxisMask |= 1 << i
case .range(let range, let stride):
begin[i] = Int32(range.lowerBound)
end[i] = Int32(range.upperBound)
strides[i] = Int32(stride)
case .closedRange(let range, let stride):
begin[i] = Int32(range.lowerBound)
switch Int32(range.upperBound) {
case -1: endMask |= 1 << i
case let u: end[i] = u + 1
}
strides[i] = Int32(stride)
case .partialRangeFrom(let range, let stride):
begin[i] = Int32(range.lowerBound)
strides[i] = Int32(stride)
endMask |= 1 << i
case .partialRangeUpTo(let range, let stride):
end[i] = Int32(range.upperBound)
strides[i] = Int32(stride)
beginMask |= 1 << i
case .partialRangeThrough(let range, let stride):
end[i] = Int32(range.upperBound) + 1
strides[i] = Int32(stride)
beginMask |= 1 << i
}
}

/// Extracts a slice from the tensor defined by lower and upper bounds for
/// each dimension.
///
/// - Parameter lowerBounds: The lower bounds at each dimension.
/// - Parameter upperBounds: The upper bounds at each dimension.
@inlinable @inline(__always)
func slice(lowerBounds: [Int], upperBounds: [Int]) -> Tensor {
/// TODO: Precondition `lowerBounds.count == upperBounds.count`,
/// preferably in graph.
let lowerBoundsTensor = Tensor<Int32>(lowerBounds.map(Int32.init))
let upperBoundsTensor = Tensor<Int32>(upperBounds.map(Int32.init))
return Raw.slice(
self,
begin: lowerBoundsTensor,
size: upperBoundsTensor - lowerBoundsTensor)
self.begin = Tensor<Int32>(begin)
self.end = Tensor<Int32>(end)
self.strides = Tensor<Int32>(strides)
self.beginMask = beginMask
self.endMask = endMask
self.ellipsisMask = ellipsisMask
self.newAxisMask = newAxisMask
self.squeezeAxisMask = squeezeAxisMask
}
}
Loading