Skip to content

Commit

Permalink
Fix test_approx() when one of the arguments is a broadcasted multidim…
Browse files Browse the repository at this point in the history
…ensional array.

`collect()` on broadcasted arrays doesn't preserve its shape, but instead creates a flat array. This breaks tests e.g. for most activation functions in NNlib:

```
test_rrule(Broadcast.broadcasted, NNlib.σ, rand(3, 4))
```
  • Loading branch information
dfdx committed Jan 17, 2022
1 parent 63bbd48 commit 2c080a7
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 3 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRulesTestUtils"
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
version = "1.3.1"
version = "1.3.2"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
4 changes: 2 additions & 2 deletions src/check_result.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,8 @@ function test_approx(actual::A, expected::E, msg=""; kwargs...) where {A,E}
if _can_pass_early(actual, expected)
@test true
else
c_actual = collect(actual)
c_expected = collect(expected)
c_actual = collect(Broadcast.materialize(actual))
c_expected = collect(Broadcast.materialize(expected))
if (c_actual isa A) && (c_expected isa E) # prevent stack-overflow
throw(MethodError, test_approx, (actual, expected))
end
Expand Down
1 change: 1 addition & 0 deletions test/check_result.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ end

test_approx([1.0, 2.0], [1.0, 2.0])
test_approx([[1.0], [2.0]], [[1.0], [2.0]])
test_approx(Broadcast.broadcasted(identity, [1.0 2.0; 3.0 4.0]), [1.0 2.0; 3.0 4.0])

test_approx(@thunk(10 * 0.1 * [[1.0], [2.0]]), [[1.0], [2.0]])

Expand Down

0 comments on commit 2c080a7

Please sign in to comment.