This repository has been archived by the owner on Mar 12, 2021. It is now read-only.

Gradient of einsum not working with CuArrays when its output is assigned to a field of a mutable struct #586

Closed
AzamatB opened this issue Feb 10, 2020 · 5 comments

AzamatB commented Feb 10, 2020

I'm getting a strange error when using OMEinsum, CuArrays and Zygote together. Here is an MRE (minimal reproducible example):

using CuArrays
using Zygote
using OMEinsum

mutable struct S{T}
   c::T
end

function f(α, s, H)
   @ein s.c[d,b] := α[t,b] * H[d,t,b]
   sum(s.c)
end

α = rand(3, 2)
H = rand(5, 3, 2)
s = S(rand(5, 2))

julia> gradient(x -> f(x, s, H), α) # works
([1.8856141333787124 3.177915606920301; 1.7898891125983922 3.2721594554280564; 1.959869250310523 2.247906876515321],)
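In index notation, the @ein line computes the contraction below, and f returns the sum of its entries. Because that sum is linear in α, the gradient with respect to α depends only on H, which is a quick way to sanity-check the CPU result above:

```latex
c_{db} = \sum_{t} \alpha_{tb}\, H_{dtb}, \qquad
f(\alpha) = \sum_{d,b} c_{db}, \qquad
\frac{\partial f}{\partial \alpha_{tb}} = \sum_{d} H_{dtb}
```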

Now, if I call f with cu-ified inputs, the gradient errors:

julia> gradient(x -> f(x, S(cu(s.c)), cu(H)), cu(α)) # errors
ERROR: MethodError: no method matching +(::Base.RefValue{Any}, ::CuArray{Float32,2,Nothing})

Yet the forward pass works:

julia> f(cu(α), S(cu(s.c)), cu(H))
8.397266f0

Any explanation for what is happening? What would be a workaround here?

AzamatB added the bug label Feb 10, 2020
maleadt (Member) commented Feb 10, 2020

Please don't double post. You should reduce this further before painting it as a CuArrays issue. I don't have experience with those packages, so I can't help with that.

maleadt closed this as completed Feb 10, 2020
AzamatB (Author) commented Feb 10, 2020

Here is a reduction without the OMEinsum dependency:

using CuArrays
using Zygote

mutable struct S{T}
   c::T
end

function f(α, s, H)
   s.c = dropdims(sum(reshape(α, 1, :, size(α,2)) .* H; dims=2); dims=2)
   sum(s.c)
end

α = rand(3, 2)
H = rand(5, 3, 2)
s = S(rand(5, 2))

julia> gradient(x -> f(x, s, H), α) # works
([0.9977189000464359 2.6728663891887647; 2.3515342113507853 2.0810362850597954; 3.4727942154123843 2.0741737657385504],)

julia> gradient(x -> f(x, S(cu(s.c)), cu(H)), cu(α)) # errors
ERROR: MethodError: no method matching +(::Base.RefValue{Any}, ::CuArray{Float32,2,Nothing})
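To confirm that this broadcast form computes the same thing as the einsum version, here is a CPU-only sketch (no CuArrays, Zygote, or OMEinsum needed). It compares the broadcast reduction against explicit loops for c[d,b] = Σ_t α[t,b] * H[d,t,b] and computes the analytic gradient of sum(c) with respect to α. The helper names contract_loops, contract_broadcast and grad_sum_wrt_α are hypothetical, introduced only for this illustration.

```julia
# Hypothetical helper: explicit loops for c[d,b] = Σ_t α[t,b] * H[d,t,b]
function contract_loops(α::AbstractMatrix, H::AbstractArray{<:Any,3})
    D, T, B = size(H)
    @assert size(α) == (T, B) "α must be (T, B) when H is (D, T, B)"
    c = zeros(promote_type(eltype(α), eltype(H)), D, B)
    for b in 1:B, t in 1:T, d in 1:D
        c[d, b] += α[t, b] * H[d, t, b]
    end
    return c
end

# Hypothetical helper: the broadcast-and-reduce form used in the OMEinsum-free f
contract_broadcast(α, H) =
    dropdims(sum(reshape(α, 1, :, size(α, 2)) .* H; dims=2); dims=2)

# Hypothetical helper: analytic gradient of sum(c) w.r.t. α,
# i.e. ∂/∂α[t,b] = Σ_d H[d,t,b] (independent of α, since f is linear in α)
grad_sum_wrt_α(H) = dropdims(sum(H; dims=1); dims=1)

α = rand(3, 2)
H = rand(5, 3, 2)

c_loop = contract_loops(α, H)
c_bcast = contract_broadcast(α, H)
println(size(c_bcast))            # (5, 2)
println(c_loop ≈ c_bcast)         # true (equal up to floating-point rounding)
println(size(grad_sum_wrt_α(H)))  # (3, 2), the same shape as α
```

Since f is linear in α, the gradient Zygote returns on the CPU should match grad_sum_wrt_α(H) for the same H, whatever the value of α.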

maleadt (Member) commented Feb 10, 2020

Executing the function without taking the gradient does work, so this still doesn't show it to be a CuArrays issue. Also note that cu(s) doesn't put your S{Array} on the GPU; you need to have Adapt rules in place to do that automatically, or reconstruct your wrapper yourself.
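A rough sketch of such an Adapt rule, assuming Adapt.jl's adapt_structure and adapt_storage extension points: the method for S adapts the wrapped field and re-wraps it. The ToFloat32 adaptor is hypothetical and only stands in for a GPU array type so the example runs on the CPU. It is an assumption, not verified here, that cu in CuArrays routes through adapt; if it does, the same rule would let cu(s) rebuild S around a CuArray.

```julia
using Adapt

mutable struct S{T}
    c::T
end

# Tell Adapt how to rebuild S: adapt the wrapped field, then re-wrap it.
Adapt.adapt_structure(to, s::S) = S(adapt(to, s.c))

# Hypothetical CPU-only adaptor standing in for a GPU array type:
# every array it reaches is converted to Float32.
struct ToFloat32 end
Adapt.adapt_storage(::ToFloat32, xs::AbstractArray) = Float32.(xs)

s = S(rand(5, 2))
s32 = adapt(ToFloat32(), s)   # a new S whose field is a Float32 array
println(eltype(s32.c))        # Float32
```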

AzamatB (Author) commented Feb 11, 2020

> note that cu(s) doesn't put your S{Array} on the GPU

But I'm not doing cu(s); I'm doing S(cu(s.c)) instead. That would put the S{Array} on the GPU, no?

maleadt (Member) commented Feb 11, 2020

Oh sorry, I misread your code. Yes, that's fine with respect to uploading to the GPU.
