Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix test_approx() when one of the arguments is a broadcasted multidimensional array #230

Merged
merged 2 commits into from
Jun 9, 2022

Conversation

dfdx
Copy link
Contributor

@dfdx dfdx commented Jan 17, 2022

collect() on broadcasted arrays doesn't preserve its shape, but instead creates a flat array. This breaks tests e.g. for most activation functions in NNlib:

test_rrule(Broadcast.broadcasted, NNlib.σ, rand(3, 4))

which produces:

Test threw exception
  Expression: isapprox(actual, expected; kwargs...)
  DimensionMismatch("dimensions must match: a has dims (Base.OneTo(2), Base.OneTo(2)), b has dims (Base.OneTo(4),), mismatch at 1")

@codecov-commenter
Copy link

Codecov Report

Merging #230 (2c080a7) into main (63bbd48) will not change coverage.
The diff coverage is 100.00%.

Impacted file tree graph

@@           Coverage Diff           @@
##             main     #230   +/-   ##
=======================================
  Coverage   90.84%   90.84%           
=======================================
  Files          11       11           
  Lines         295      295           
=======================================
  Hits          268      268           
  Misses         27       27           
Impacted Files Coverage Δ
src/check_result.jl 89.70% <100.00%> (ø)

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 63bbd48...2c080a7. Read the comment docs.

@oxinabox
Copy link
Member

Should this be fixed up stream in JuliaLang/Julia
JuliaLang/julia#43847 ?

@dfdx
Copy link
Contributor Author

dfdx commented Jan 17, 2022

Amazing analysis! Let's see how it turns out in the Base and whether the fix gets to all popular Julia versions. If it's not applied to any of the versions supported in ChainRules, we can still apply this PR to be on the safe side.

@dfdx dfdx force-pushed the dfdx/fix-test-broadcasted branch from 2c080a7 to 350555d Compare January 24, 2022 22:31
@oxinabox
Copy link
Member

oxinabox commented Jun 6, 2022

Cool so this has now been addressed in julia v1.9.0
Unclear as to if this will be backported,
but we can add a comment to the source saying:
#Works around https://github.com/JuliaLang/julia/issues/43847 on pre-Julia v1.9
to this PR and then merge this if we still think this is needed?

dfdx and others added 2 commits June 8, 2022 01:04
…ensional array.

`collect()` on broadcasted arrays doesn't preserve its shape, but instead creates a flat array. This breaks tests e.g. for most activation functions in NNlib:

```
test_rrule(Broadcast.broadcasted, NNlib.σ, rand(3, 4))
```
@dfdx dfdx force-pushed the dfdx/fix-test-broadcasted branch from 350555d to 5e5acf8 Compare June 7, 2022 22:08
@dfdx
Copy link
Contributor Author

dfdx commented Jun 7, 2022

Yeah, since we support older versions of Julia, I believe it's still worth to have this workaround.
I added the suggested comment to the PR.

@oxinabox oxinabox merged commit 8cbf82d into JuliaDiff:main Jun 9, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants