-
Notifications
You must be signed in to change notification settings - Fork 10.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[AutoDiff] Support differentiation of conditionals. (#25057)
- Support control flow in adjoint generation. - Make adjoint value/buffer mappings be per basic block. - Change `AdjointValue` to not be move-only. Original values from different basic blocks may share the same `AdjointValue`. - Propagate adjoint values from active bb arguments to predecessor terminator operands. - Propagate adjoint values/buffers of dominated active values/buffers to predecessor blocks. - For active values: propagate adjoint values as adjoint bb arguments. - For active buffers: propagate adjoint buffers via `copy_addr`. - Revamp `AdjointEmitter` handling of `begin_access` and `end_access`. - `getAdjointValue` of `begin_access` now returns the adjoint base buffer. Previously, it returned a `begin_access` of the adjoint base buffer without generating a corresponding `end_access`. - `AdjointEmitter::visitBeginAccessInst` now generates no code. - `AdjointEmitter::visitEndAccessInst` now does nothing. - Add various control flow differentiation tests. - Test differentiation of conditionals (nested), recursion, `var` allocations (tuples, structs). - Add negative leakchecking tests. Todos: - Fix adjoint value/buffer propagation memory leaks. - Add more tests (adjoint SIL, leakchecking). - Support differentiation of enum-related instructions and loops.
- Loading branch information
Showing
8 changed files
with
1,115 additions
and
396 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
// RUN: %target-swift-frontend -emit-sil -verify -Xllvm -differentiation-enable-control-flow %s | %FileCheck %s -check-prefix=CHECK-DATA-STRUCTURES | ||
// RUN: %target-swift-frontend -emit-sil -verify -Xllvm -differentiation-enable-control-flow %s | %FileCheck %s -check-prefix=CHECK-SIL | ||
|
||
// TODO: Add adjoint SIL FileCheck tests. | ||
|
||
// Test conditional: a simple if-diamond. | ||
|
||
@differentiable | ||
@_silgen_name("cond") | ||
func cond(_ x: Float) -> Float { | ||
if x > 0 { | ||
return x + x | ||
} | ||
return x - x | ||
} | ||
|
||
// CHECK-DATA-STRUCTURES: enum _AD__cond_bb0__Pred__src_0_wrt_0 { | ||
// CHECK-DATA-STRUCTURES: } | ||
// CHECK-DATA-STRUCTURES: struct _AD__cond_bb0__PB__src_0_wrt_0 { | ||
// CHECK-DATA-STRUCTURES: } | ||
// CHECK-DATA-STRUCTURES: enum _AD__cond_bb1__Pred__src_0_wrt_0 { | ||
// CHECK-DATA-STRUCTURES: case bb0(_AD__cond_bb0__PB__src_0_wrt_0) | ||
// CHECK-DATA-STRUCTURES: } | ||
// CHECK-DATA-STRUCTURES: struct _AD__cond_bb1__PB__src_0_wrt_0 { | ||
// CHECK-DATA-STRUCTURES: @_hasStorage var predecessor: _AD__cond_bb1__Pred__src_0_wrt_0 { get set } | ||
// CHECK-DATA-STRUCTURES: @_hasStorage var pullback_0: (Float) -> (Float, Float) { get set } | ||
// CHECK-DATA-STRUCTURES: } | ||
// CHECK-DATA-STRUCTURES: enum _AD__cond_bb2__Pred__src_0_wrt_0 { | ||
// CHECK-DATA-STRUCTURES: case bb0(_AD__cond_bb0__PB__src_0_wrt_0) | ||
// CHECK-DATA-STRUCTURES: } | ||
// CHECK-DATA-STRUCTURES: struct _AD__cond_bb2__PB__src_0_wrt_0 { | ||
// CHECK-DATA-STRUCTURES: @_hasStorage var predecessor: _AD__cond_bb2__Pred__src_0_wrt_0 { get set } | ||
// CHECK-DATA-STRUCTURES: @_hasStorage var pullback_1: (Float) -> (Float, Float) { get set } | ||
// CHECK-DATA-STRUCTURES: } | ||
// CHECK-DATA-STRUCTURES: enum _AD__cond_bb3__Pred__src_0_wrt_0 { | ||
// CHECK-DATA-STRUCTURES: case bb2(_AD__cond_bb2__PB__src_0_wrt_0) | ||
// CHECK-DATA-STRUCTURES: case bb1(_AD__cond_bb1__PB__src_0_wrt_0) | ||
// CHECK-DATA-STRUCTURES: } | ||
// CHECK-DATA-STRUCTURES: struct _AD__cond_bb3__PB__src_0_wrt_0 { | ||
// CHECK-DATA-STRUCTURES: @_hasStorage var predecessor: _AD__cond_bb3__Pred__src_0_wrt_0 { get set } | ||
// CHECK-DATA-STRUCTURES: } | ||
|
||
// CHECK-SIL-LABEL: sil hidden @AD__cond__vjp_src_0_wrt_0 | ||
// CHECK-SIL: bb0([[INPUT_ARG:%.*]] : $Float): | ||
// CHECK-SIL: [[BB0_PB_STRUCT:%.*]] = struct $_AD__cond_bb0__PB__src_0_wrt_0 () | ||
// CHECK-SIL: [[BB1_PRED:%.*]] = enum $_AD__cond_bb1__Pred__src_0_wrt_0, #_AD__cond_bb1__Pred__src_0_wrt_0.bb0!enumelt.1, [[BB0_PB_STRUCT]] | ||
// CHECK-SIL: [[BB2_PRED:%.*]] = enum $_AD__cond_bb2__Pred__src_0_wrt_0, #_AD__cond_bb2__Pred__src_0_wrt_0.bb0!enumelt.1, [[BB0_PB_STRUCT]] | ||
// CHECK-SIL: cond_br {{%.*}}, bb1([[BB1_PRED]] : $_AD__cond_bb1__Pred__src_0_wrt_0), bb2([[BB2_PRED]] : $_AD__cond_bb2__Pred__src_0_wrt_ | ||
|
||
// CHECK-SIL: bb1([[BB1_PRED_ARG:%.*]] : $_AD__cond_bb1__Pred__src_0_wrt_0) | ||
// CHECK-SIL: [[BB1_PB_STRUCT:%.*]] = struct $_AD__cond_bb1__PB__src_0_wrt_0 | ||
// CHECK-SIL: [[BB3_PRED_PRED1:%.*]] = enum $_AD__cond_bb3__Pred__src_0_wrt_0, #_AD__cond_bb3__Pred__src_0_wrt_0.bb1!enumelt.1, [[BB1_PB_STRUCT]] | ||
// CHECK-SIL: br bb3({{.*}} : $Float, [[BB3_PRED_PRED1]] : $_AD__cond_bb3__Pred__src_0_wrt_0) | ||
|
||
// CHECK-SIL: bb2([[BB2_PRED_ARG:%.*]] : $_AD__cond_bb2__Pred__src_0_wrt_0) | ||
// CHECK-SIL: [[BB2_PB_STRUCT:%.*]] = struct $_AD__cond_bb2__PB__src_0_wrt_0 | ||
// CHECK-SIL: [[BB3_PRED_PRED2:%.*]] = enum $_AD__cond_bb3__Pred__src_0_wrt_0, #_AD__cond_bb3__Pred__src_0_wrt_0.bb2!enumelt.1, [[BB2_PB_STRUCT]] | ||
// CHECK-SIL: br bb3({{.*}} : $Float, [[BB3_PRED_PRED2]] : $_AD__cond_bb3__Pred__src_0_wrt_0) | ||
|
||
// CHECK-SIL: bb3([[ORIG_RES:%.*]] : $Float, [[BB3_PRED_ARG:%.*]] : $_AD__cond_bb3__Pred__src_0_wrt_0) | ||
// CHECK-SIL: [[BB3_PB_STRUCT:%.*]] = struct $_AD__cond_bb3__PB__src_0_wrt_0 | ||
// CHECK-SIL: [[ADJOINT_REF:%.*]] = function_ref @AD__cond__adjoint_src_0_wrt_0 | ||
// CHECK-SIL: [[PB:%.*]] = partial_apply [callee_guaranteed] [[ADJOINT_REF]]([[BB3_PB_STRUCT]]) | ||
// CHECK-SIL: [[VJP_RESULT:%.*]] = tuple ([[ORIG_RES]] : $Float, [[PB]] : $@callee_guaranteed (Float) -> Float) | ||
// CHECK-SIL: return [[VJP_RESULT]] | ||
|
||
@differentiable | ||
@_silgen_name("nested_cond") | ||
func nested_cond(_ x: Float, _ y: Float) -> Float { | ||
if x > 0 { | ||
if y > 10 { | ||
return x * y | ||
} else { | ||
return x + y | ||
} | ||
} | ||
return y - x | ||
} | ||
|
||
@differentiable | ||
@_silgen_name("nested_cond_generic") | ||
func nested_cond_generic<T : Differentiable & FloatingPoint>(_ x: T, _ y: T) -> T { | ||
if x > 0 { | ||
if y > 10 { | ||
return y | ||
} else { | ||
return x | ||
} | ||
} | ||
return y | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
142 changes: 142 additions & 0 deletions
142
test/TensorFlowRuntime/tensor_autodiff_control_flow.swift
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
// RUN: %target-run-simple-swift-control-flow-differentiation | ||
// REQUIRES: executable_test | ||
// | ||
// FIXME(TF-326): Re-enable `-O` after deserialization failure fix. | ||
// UNSUPPORTED: swift_test_mode_optimize | ||
// | ||
// Tensor control flow AD runtime tests. | ||
// TODO: Move TensorFlow-specific AD tests into test/AutoDiff. | ||
|
||
import TensorFlow | ||
import StdlibUnittest | ||
import TensorFlowUnittest | ||
|
||
var TensorADTests = TestSuite("TensorControlFlowAD") | ||
|
||
TensorADTests.testAllBackends("Conditionals") { | ||
func cond_nestedtuple_var(_ x: Tensor<Float>) -> Tensor<Float> { | ||
// Convoluted function returning `x + x`. | ||
var y: (Tensor<Float>, Tensor<Float>) = (x + x, x - x) | ||
var z: ((Tensor<Float>, Tensor<Float>), Tensor<Float>) = (y, x) | ||
if x > 0 { | ||
var w = (x, x) | ||
y.0 = w.1 | ||
y.1 = w.0 | ||
z.0.0 = z.0.0 - y.0 | ||
z.0.1 = z.0.1 + y.0 | ||
} else { | ||
z = ((y.0 - x, y.1 + x), x) | ||
} | ||
return y.0 + y.1 - z.0.0 + z.0.1 | ||
} | ||
expectEqual((Tensor(8), Tensor(2)), | ||
valueWithGradient(at: Tensor(4), in: cond_nestedtuple_var)) | ||
expectEqual((Tensor(-20), Tensor(2)), | ||
valueWithGradient(at: Tensor(-10), in: cond_nestedtuple_var)) | ||
expectEqual((Tensor(-2674), Tensor(2)), | ||
valueWithGradient(at: Tensor(-1337), in: cond_nestedtuple_var)) | ||
|
||
func guard2_var(_ x: Tensor<Float>, _ y: Tensor<Float>) -> Tensor<Float> { | ||
var z = y | ||
guard x > 0 else { | ||
if y > 0 { | ||
z = z * x | ||
} else if x == Tensor(-1337) { | ||
z = x | ||
z = z * z | ||
} else { | ||
z = Tensor(0) | ||
} | ||
return z | ||
} | ||
return z * y | ||
} | ||
expectEqual((Tensor(0), Tensor(10)), | ||
gradient(at: Tensor(4), Tensor(5), in: guard2_var)) | ||
expectEqual((Tensor(5), Tensor(-1337)), | ||
gradient(at: Tensor(-1337), Tensor(5), in: guard2_var)) | ||
expectEqual((Tensor(-2674), Tensor(0)), | ||
gradient(at: Tensor(-1337), Tensor(-5), in: guard2_var)) | ||
expectEqual((Tensor(2), Tensor(-3)), | ||
gradient(at: Tensor(-3), Tensor(2), in: guard2_var)) | ||
} | ||
|
||
TensorADTests.testAllBackends("NestedConditionals") { | ||
// Test tensor-tensor ops. | ||
func cond_nested1(_ x: Tensor<Float>, _ y: Tensor<Float>) -> Tensor<Float> { | ||
if x > 0 { | ||
if y > 10 { | ||
let z = x * y | ||
if z > 100 { | ||
return x + z | ||
} else if y == Tensor(20) { | ||
return z + z | ||
} | ||
} else { | ||
return x + y | ||
} | ||
} | ||
return -y | ||
} | ||
|
||
expectEqual((Tensor(40), Tensor(8)), | ||
gradient(at: Tensor(4), Tensor(20), in: cond_nested1)) | ||
expectEqual((Tensor(0), Tensor(-1)), | ||
gradient(at: Tensor(4), Tensor(21), in: cond_nested1)) | ||
expectEqual((Tensor(1), Tensor(1)), | ||
gradient(at: Tensor(4), Tensor(5), in: cond_nested1)) | ||
expectEqual((Tensor(0), Tensor(-1)), | ||
gradient(at: Tensor(-3), Tensor(-2), in: cond_nested1)) | ||
|
||
// Test tensor-scalar ops. | ||
func cond_nested2(_ x: Tensor<Float>, _ y: Float) -> Tensor<Float> { | ||
if x > 0 { | ||
if y > 10 { | ||
let z = x * y | ||
if z > 100 { | ||
return x + z | ||
} else if y == 20 { | ||
return z + z | ||
} | ||
} else { | ||
return x + y | ||
} | ||
} | ||
return Tensor(-y) | ||
} | ||
|
||
expectEqual((Tensor(40), 8), gradient(at: Tensor(4), 20, in: cond_nested2)) | ||
expectEqual((Tensor(0), -1), gradient(at: Tensor(4), 21, in: cond_nested2)) | ||
expectEqual((Tensor(1), 1), gradient(at: Tensor(4), 5, in: cond_nested2)) | ||
expectEqual((Tensor(0), -1), gradient(at: Tensor(-3), -2, in: cond_nested2)) | ||
} | ||
|
||
TensorADTests.testAllBackends("Recursion") { | ||
func factorial(_ x: Tensor<Float>) -> Tensor<Float> { | ||
if x == Tensor(1) { | ||
return Tensor(1) | ||
} | ||
return x * factorial(x - 1) | ||
} | ||
expectEqual(Tensor(0), gradient(at: Tensor(1), in: factorial)) | ||
expectEqual(Tensor(1), gradient(at: Tensor(2), in: factorial)) | ||
expectEqual(Tensor(5), gradient(at: Tensor(3), in: factorial)) | ||
expectEqual(Tensor(26), gradient(at: Tensor(4), in: factorial)) | ||
expectEqual(Tensor(154), gradient(at: Tensor(5), in: factorial)) | ||
|
||
func product(_ x: Tensor<Float>, count: Int) -> Tensor<Float> { | ||
precondition(count > 0) | ||
if count == 1 { | ||
return x | ||
} | ||
return x * product(x, count: count - 1) | ||
} | ||
expectEqual(Tensor(300), | ||
gradient(at: Tensor(10), in: { x in product(x, count: 3) })) | ||
expectEqual(Tensor(-20), | ||
gradient(at: Tensor(-10), in: { x in product(x, count: 2) })) | ||
expectEqual(Tensor(1), | ||
gradient(at: Tensor(100), in: { x in product(x, count: 1) })) | ||
} | ||
|
||
runAllTests() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters