Skip to content

Commit

Permalink
[AutoDiff] Support differentiation of conditionals. (#25057)
Browse files Browse the repository at this point in the history
- Support control flow in adjoint generation.
  - Make adjoint value/buffer mappings be per basic block.
  - Change `AdjointValue` to not be move-only. Original values from
    different basic blocks may share the same `AdjointValue`.
  - Propagate adjoint values from active bb arguments to predecessor
    terminator operands.
  - Propagate adjoint values/buffers of dominated active values/buffers
    to predecessor blocks.
    - For active values: propagate adjoint values as adjoint bb arguments.
    - For active buffers: propagate adjoint buffers via `copy_addr`.
- Revamp `AdjointEmitter` handling of `begin_access` and `end_access`.
  - `getAdjointValue` of `begin_access` now returns the adjoint base buffer.
    Previously, it returned a `begin_access` of the adjoint base buffer
    without generating a corresponding `end_access`.
  - `AdjointEmitter::visitBeginAccessInst` now generates no code.
  - `AdjointEmitter::visitEndAccessInst` now does nothing.
- Add various control flow differentiation tests.
  - Test differentiation of conditionals (nested), recursion,
    `var` allocations (tuples, structs).
  - Add negative leakchecking tests.

Todos:
- Fix adjoint value/buffer propagation memory leaks.
- Add more tests (adjoint SIL, leakchecking).
- Support differentiation of enum-related instructions and loops.
  • Loading branch information
dan-zheng authored and rxwei committed Jun 4, 2019
1 parent fc896ca commit 045e192
Show file tree
Hide file tree
Showing 8 changed files with 1,115 additions and 396 deletions.
774 changes: 451 additions & 323 deletions lib/SILOptimizer/Mandatory/Differentiation.cpp

Large diffs are not rendered by default.

424 changes: 357 additions & 67 deletions test/AutoDiff/control_flow.swift

Large diffs are not rendered by default.

14 changes: 13 additions & 1 deletion test/AutoDiff/control_flow_diagnostics.swift
Original file line number Diff line number Diff line change
@@ -1,4 +1,16 @@
// RUN: %target-swift-frontend -emit-sil -verify %s
// RUN: %target-swift-frontend -emit-sil -verify -Xllvm -differentiation-enable-control-flow %s

// Test supported `br` and `cond_br` terminators.

@differentiable
func branch(_ x: Float) -> Float {
if x > 0 {
return x
} else if x < 10 {
return x
}
return x
}

// Test currently unsupported `switch_enum` terminator.

Expand Down
91 changes: 91 additions & 0 deletions test/AutoDiff/control_flow_sil.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
// RUN: %target-swift-frontend -emit-sil -verify -Xllvm -differentiation-enable-control-flow %s | %FileCheck %s -check-prefix=CHECK-DATA-STRUCTURES
// RUN: %target-swift-frontend -emit-sil -verify -Xllvm -differentiation-enable-control-flow %s | %FileCheck %s -check-prefix=CHECK-SIL

// TODO: Add adjoint SIL FileCheck tests.

// Test conditional: a simple if-diamond.

@differentiable
@_silgen_name("cond")
func cond(_ x: Float) -> Float {
if x > 0 {
return x + x
}
return x - x
}

// CHECK-DATA-STRUCTURES: enum _AD__cond_bb0__Pred__src_0_wrt_0 {
// CHECK-DATA-STRUCTURES: }
// CHECK-DATA-STRUCTURES: struct _AD__cond_bb0__PB__src_0_wrt_0 {
// CHECK-DATA-STRUCTURES: }
// CHECK-DATA-STRUCTURES: enum _AD__cond_bb1__Pred__src_0_wrt_0 {
// CHECK-DATA-STRUCTURES: case bb0(_AD__cond_bb0__PB__src_0_wrt_0)
// CHECK-DATA-STRUCTURES: }
// CHECK-DATA-STRUCTURES: struct _AD__cond_bb1__PB__src_0_wrt_0 {
// CHECK-DATA-STRUCTURES: @_hasStorage var predecessor: _AD__cond_bb1__Pred__src_0_wrt_0 { get set }
// CHECK-DATA-STRUCTURES: @_hasStorage var pullback_0: (Float) -> (Float, Float) { get set }
// CHECK-DATA-STRUCTURES: }
// CHECK-DATA-STRUCTURES: enum _AD__cond_bb2__Pred__src_0_wrt_0 {
// CHECK-DATA-STRUCTURES: case bb0(_AD__cond_bb0__PB__src_0_wrt_0)
// CHECK-DATA-STRUCTURES: }
// CHECK-DATA-STRUCTURES: struct _AD__cond_bb2__PB__src_0_wrt_0 {
// CHECK-DATA-STRUCTURES: @_hasStorage var predecessor: _AD__cond_bb2__Pred__src_0_wrt_0 { get set }
// CHECK-DATA-STRUCTURES: @_hasStorage var pullback_1: (Float) -> (Float, Float) { get set }
// CHECK-DATA-STRUCTURES: }
// CHECK-DATA-STRUCTURES: enum _AD__cond_bb3__Pred__src_0_wrt_0 {
// CHECK-DATA-STRUCTURES: case bb2(_AD__cond_bb2__PB__src_0_wrt_0)
// CHECK-DATA-STRUCTURES: case bb1(_AD__cond_bb1__PB__src_0_wrt_0)
// CHECK-DATA-STRUCTURES: }
// CHECK-DATA-STRUCTURES: struct _AD__cond_bb3__PB__src_0_wrt_0 {
// CHECK-DATA-STRUCTURES: @_hasStorage var predecessor: _AD__cond_bb3__Pred__src_0_wrt_0 { get set }
// CHECK-DATA-STRUCTURES: }

// CHECK-SIL-LABEL: sil hidden @AD__cond__vjp_src_0_wrt_0
// CHECK-SIL: bb0([[INPUT_ARG:%.*]] : $Float):
// CHECK-SIL: [[BB0_PB_STRUCT:%.*]] = struct $_AD__cond_bb0__PB__src_0_wrt_0 ()
// CHECK-SIL: [[BB1_PRED:%.*]] = enum $_AD__cond_bb1__Pred__src_0_wrt_0, #_AD__cond_bb1__Pred__src_0_wrt_0.bb0!enumelt.1, [[BB0_PB_STRUCT]]
// CHECK-SIL: [[BB2_PRED:%.*]] = enum $_AD__cond_bb2__Pred__src_0_wrt_0, #_AD__cond_bb2__Pred__src_0_wrt_0.bb0!enumelt.1, [[BB0_PB_STRUCT]]
// CHECK-SIL: cond_br {{%.*}}, bb1([[BB1_PRED]] : $_AD__cond_bb1__Pred__src_0_wrt_0), bb2([[BB2_PRED]] : $_AD__cond_bb2__Pred__src_0_wrt_

// CHECK-SIL: bb1([[BB1_PRED_ARG:%.*]] : $_AD__cond_bb1__Pred__src_0_wrt_0)
// CHECK-SIL: [[BB1_PB_STRUCT:%.*]] = struct $_AD__cond_bb1__PB__src_0_wrt_0
// CHECK-SIL: [[BB3_PRED_PRED1:%.*]] = enum $_AD__cond_bb3__Pred__src_0_wrt_0, #_AD__cond_bb3__Pred__src_0_wrt_0.bb1!enumelt.1, [[BB1_PB_STRUCT]]
// CHECK-SIL: br bb3({{.*}} : $Float, [[BB3_PRED_PRED1]] : $_AD__cond_bb3__Pred__src_0_wrt_0)

// CHECK-SIL: bb2([[BB2_PRED_ARG:%.*]] : $_AD__cond_bb2__Pred__src_0_wrt_0)
// CHECK-SIL: [[BB2_PB_STRUCT:%.*]] = struct $_AD__cond_bb2__PB__src_0_wrt_0
// CHECK-SIL: [[BB3_PRED_PRED2:%.*]] = enum $_AD__cond_bb3__Pred__src_0_wrt_0, #_AD__cond_bb3__Pred__src_0_wrt_0.bb2!enumelt.1, [[BB2_PB_STRUCT]]
// CHECK-SIL: br bb3({{.*}} : $Float, [[BB3_PRED_PRED2]] : $_AD__cond_bb3__Pred__src_0_wrt_0)

// CHECK-SIL: bb3([[ORIG_RES:%.*]] : $Float, [[BB3_PRED_ARG:%.*]] : $_AD__cond_bb3__Pred__src_0_wrt_0)
// CHECK-SIL: [[BB3_PB_STRUCT:%.*]] = struct $_AD__cond_bb3__PB__src_0_wrt_0
// CHECK-SIL: [[ADJOINT_REF:%.*]] = function_ref @AD__cond__adjoint_src_0_wrt_0
// CHECK-SIL: [[PB:%.*]] = partial_apply [callee_guaranteed] [[ADJOINT_REF]]([[BB3_PB_STRUCT]])
// CHECK-SIL: [[VJP_RESULT:%.*]] = tuple ([[ORIG_RES]] : $Float, [[PB]] : $@callee_guaranteed (Float) -> Float)
// CHECK-SIL: return [[VJP_RESULT]]

@differentiable
@_silgen_name("nested_cond")
func nested_cond(_ x: Float, _ y: Float) -> Float {
if x > 0 {
if y > 10 {
return x * y
} else {
return x + y
}
}
return y - x
}

@differentiable
@_silgen_name("nested_cond_generic")
func nested_cond_generic<T : Differentiable & FloatingPoint>(_ x: T, _ y: T) -> T {
if x > 0 {
if y > 10 {
return y
} else {
return x
}
}
return y
}
49 changes: 47 additions & 2 deletions test/AutoDiff/leakchecking.swift
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %target-run-simple-swift
// RUN: %target-run-simple-swift-control-flow-differentiation
// REQUIRES: executable_test

// A test that we can properly differentiate types that require refcounting.
Expand All @@ -8,6 +8,13 @@ import DifferentiationUnittest

var LeakCheckingTests = TestSuite("LeakChecking")

/// Execute body, check expected leak count, and reset global leak count.
func testWithLeakChecking(expectedLeakCount: Int = 0, _ body: () -> Void) {
body()
expectEqual(expectedLeakCount, _GlobalLeakCount.count, "Leak detected.")
_GlobalLeakCount.count = 0
}

struct ExampleLeakModel : Differentiable {
var bias: Tracked<Float> = 2.0
func applied(to input: Tracked<Float>) -> Tracked<Float> {
Expand All @@ -22,7 +29,45 @@ LeakCheckingTests.test("BasicVarLeakChecking") {
let x: Tracked<Float> = 1.0
let _ = model.gradient(at: x) { m, x in m.applied(to: x) }
}
expectEqual(0, _GlobalLeakCount.count, "Leak Detected.")
expectEqual(0, _GlobalLeakCount.count, "Leak detected.")
}

LeakCheckingTests.test("ControlFlow") {
// TODO: Add more `var` + control flow tests.
// Porting tests from test/AutoDiff/control_flow.swift requires more support
// for `Tracked<Float>`.

// FIXME: Fix control flow AD memory leaks.
// See related FIXME comments in adjoint value/buffer propagation in
// lib/SILOptimizer/Mandatory/Differentiation.cpp.
testWithLeakChecking(expectedLeakCount: 9) {
var model = ExampleLeakModel()
let x: Tracked<Float> = 1.0
let _ = model.gradient(at: x) { m, x in
let result: Tracked<Float>
if x > 0 {
result = m.applied(to: x)
} else {
result = x
}
return result
}
}

// FIXME: Fix control flow AD memory leaks.
// See related FIXME comments in adjoint value/buffer propagation in
// lib/SILOptimizer/Mandatory/Differentiation.cpp.
testWithLeakChecking(expectedLeakCount: 14) {
var model = ExampleLeakModel()
let x: Tracked<Float> = 1.0
let _ = model.gradient(at: x) { m, x in
var result: Tracked<Float> = x
if x > 0 {
result = result + m.applied(to: x)
}
return result
}
}
}

runAllTests()
6 changes: 3 additions & 3 deletions test/AutoDiff/refcounting.swift
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,16 @@ _ = pullback(at: Vector.zero, in: testOwnedVector)
// CHECK: return [[NEEDED_COTAN1]] : $Vector

// CHECK-LABEL: sil hidden @{{.*}}side_effect_release_zero{{.*}}__adjoint_src_0_wrt_0
// CHECK: bb0([[X:%.*]] : $Vector, %1 : ${{.*}}side_effect_release_zero{{.*}}_bb0__PB__src_0_wrt_0):
// CHECK: retain_value [[SEED:%.*]] : $Vector
// CHECK: bb0([[SEED:%.*]] : $Vector, %1 : ${{.*}}side_effect_release_zero{{.*}}_bb0__PB__src_0_wrt_0):
// CHECK: [[BUF:%.*]] = alloc_stack $Vector
// CHECK: [[BUF_ACCESS:%.*]] = begin_access [init] [static] [no_nested_conflict] [[BUF]] : $*Vector
// CHECK: [[ZERO_GETTER:%.*]] = function_ref @$s11refcounting6VectorV4zeroACvgZ
// CHECK: [[ZERO:%.*]] = apply [[ZERO_GETTER]]({{%.*}}) : $@convention(method) (@thin Vector.Type) -> @owned Vector
// CHECK: store [[ZERO]] to [[BUF_ACCESS]] : $*Vector
// CHECK: retain_value [[SEED:%.*]] : $Vector
// CHECK: release_value [[SEED:%.*]] : $Vector
// CHECK: destroy_addr [[BUF]] : $*Vector
// CHECK: dealloc_stack [[BUF]] : $*Vector
// CHECK: release_value [[SEED:%.*]] : $Vector
// CHECK: }

// The vjp should not release pullback values.
Expand Down
142 changes: 142 additions & 0 deletions test/TensorFlowRuntime/tensor_autodiff_control_flow.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
// RUN: %target-run-simple-swift-control-flow-differentiation
// REQUIRES: executable_test
//
// FIXME(TF-326): Re-enable `-O` after deserialization failure fix.
// UNSUPPORTED: swift_test_mode_optimize
//
// Tensor control flow AD runtime tests.
// TODO: Move TensorFlow-specific AD tests into test/AutoDiff.

import TensorFlow
import StdlibUnittest
import TensorFlowUnittest

var TensorADTests = TestSuite("TensorControlFlowAD")

TensorADTests.testAllBackends("Conditionals") {
func cond_nestedtuple_var(_ x: Tensor<Float>) -> Tensor<Float> {
// Convoluted function returning `x + x`.
var y: (Tensor<Float>, Tensor<Float>) = (x + x, x - x)
var z: ((Tensor<Float>, Tensor<Float>), Tensor<Float>) = (y, x)
if x > 0 {
var w = (x, x)
y.0 = w.1
y.1 = w.0
z.0.0 = z.0.0 - y.0
z.0.1 = z.0.1 + y.0
} else {
z = ((y.0 - x, y.1 + x), x)
}
return y.0 + y.1 - z.0.0 + z.0.1
}
expectEqual((Tensor(8), Tensor(2)),
valueWithGradient(at: Tensor(4), in: cond_nestedtuple_var))
expectEqual((Tensor(-20), Tensor(2)),
valueWithGradient(at: Tensor(-10), in: cond_nestedtuple_var))
expectEqual((Tensor(-2674), Tensor(2)),
valueWithGradient(at: Tensor(-1337), in: cond_nestedtuple_var))

func guard2_var(_ x: Tensor<Float>, _ y: Tensor<Float>) -> Tensor<Float> {
var z = y
guard x > 0 else {
if y > 0 {
z = z * x
} else if x == Tensor(-1337) {
z = x
z = z * z
} else {
z = Tensor(0)
}
return z
}
return z * y
}
expectEqual((Tensor(0), Tensor(10)),
gradient(at: Tensor(4), Tensor(5), in: guard2_var))
expectEqual((Tensor(5), Tensor(-1337)),
gradient(at: Tensor(-1337), Tensor(5), in: guard2_var))
expectEqual((Tensor(-2674), Tensor(0)),
gradient(at: Tensor(-1337), Tensor(-5), in: guard2_var))
expectEqual((Tensor(2), Tensor(-3)),
gradient(at: Tensor(-3), Tensor(2), in: guard2_var))
}

TensorADTests.testAllBackends("NestedConditionals") {
// Test tensor-tensor ops.
func cond_nested1(_ x: Tensor<Float>, _ y: Tensor<Float>) -> Tensor<Float> {
if x > 0 {
if y > 10 {
let z = x * y
if z > 100 {
return x + z
} else if y == Tensor(20) {
return z + z
}
} else {
return x + y
}
}
return -y
}

expectEqual((Tensor(40), Tensor(8)),
gradient(at: Tensor(4), Tensor(20), in: cond_nested1))
expectEqual((Tensor(0), Tensor(-1)),
gradient(at: Tensor(4), Tensor(21), in: cond_nested1))
expectEqual((Tensor(1), Tensor(1)),
gradient(at: Tensor(4), Tensor(5), in: cond_nested1))
expectEqual((Tensor(0), Tensor(-1)),
gradient(at: Tensor(-3), Tensor(-2), in: cond_nested1))

// Test tensor-scalar ops.
func cond_nested2(_ x: Tensor<Float>, _ y: Float) -> Tensor<Float> {
if x > 0 {
if y > 10 {
let z = x * y
if z > 100 {
return x + z
} else if y == 20 {
return z + z
}
} else {
return x + y
}
}
return Tensor(-y)
}

expectEqual((Tensor(40), 8), gradient(at: Tensor(4), 20, in: cond_nested2))
expectEqual((Tensor(0), -1), gradient(at: Tensor(4), 21, in: cond_nested2))
expectEqual((Tensor(1), 1), gradient(at: Tensor(4), 5, in: cond_nested2))
expectEqual((Tensor(0), -1), gradient(at: Tensor(-3), -2, in: cond_nested2))
}

TensorADTests.testAllBackends("Recursion") {
func factorial(_ x: Tensor<Float>) -> Tensor<Float> {
if x == Tensor(1) {
return Tensor(1)
}
return x * factorial(x - 1)
}
expectEqual(Tensor(0), gradient(at: Tensor(1), in: factorial))
expectEqual(Tensor(1), gradient(at: Tensor(2), in: factorial))
expectEqual(Tensor(5), gradient(at: Tensor(3), in: factorial))
expectEqual(Tensor(26), gradient(at: Tensor(4), in: factorial))
expectEqual(Tensor(154), gradient(at: Tensor(5), in: factorial))

func product(_ x: Tensor<Float>, count: Int) -> Tensor<Float> {
precondition(count > 0)
if count == 1 {
return x
}
return x * product(x, count: count - 1)
}
expectEqual(Tensor(300),
gradient(at: Tensor(10), in: { x in product(x, count: 3) }))
expectEqual(Tensor(-20),
gradient(at: Tensor(-10), in: { x in product(x, count: 2) }))
expectEqual(Tensor(1),
gradient(at: Tensor(100), in: { x in product(x, count: 1) }))
}

runAllTests()
11 changes: 11 additions & 0 deletions test/lit.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -1419,6 +1419,14 @@ if not getattr(config, 'target_run_simple_swift', None):
'%s %%t/a.out' % (config.target_build_swift,
mcp_opt, config.target_codesign,
config.target_run)))
# SWIFT_ENABLE_TENSORFLOW
# TODO: Remove when differentiation control flow support is robust.
config.target_run_simple_swift_control_flow_differentiation = (
'%%empty-directory(%%t) && '
'%s %s %%s -Xllvm -differentiation-enable-control-flow -o %%t/a.out %s -module-name main && '
'%s %%t/a.out &&'
'%s %%t/a.out'
% (config.target_build_swift, mcp_opt, swift_tensorflow_extra_options, config.target_codesign, config.target_run))
config.target_run_simple_swift = (
'%%empty-directory(%%t) && '
'%s %s %%s -o %%t/a.out %s -module-name main && '
Expand Down Expand Up @@ -1476,6 +1484,9 @@ config.substitutions.append(('%target-swift-frontend', config.target_swift_front


config.substitutions.append(('%target-run-simple-swiftgyb', config.target_run_simple_swiftgyb))
# SWIFT_ENABLE_TENSORFLOW
# TODO: Remove when differentiation control flow support is robust.
config.substitutions.append(('%target-run-simple-swift-control-flow-differentiation', config.target_run_simple_swift_control_flow_differentiation))
config.substitutions.append(('%target-run-simple-swift\(([^)]+)\)', config.target_run_simple_swift_parameterized))
config.substitutions.append(('%target-run-simple-swift', config.target_run_simple_swift))
config.substitutions.append(('%target-run-stdlib-swiftgyb', config.target_run_stdlib_swiftgyb))
Expand Down

0 comments on commit 045e192

Please sign in to comment.