add chainrules for r2r, dct #272
How is the gradient computed for the following?

using LinearAlgebra, FFTW, Zygote
x = rand(4)
C = plan_dct(x)
f(x) = C \ (C * x) |> norm
g(x) = x |> dct |> idct |> norm
h(x) = plan_dct(x) \ (plan_dct(x) * x) |> norm
@show Zygote.gradient(f, x) # ([0.7499995699183157, 0.5170775887690442, 0.3522881598130941, 0.2145331321046639],)
@show Zygote.gradient(g, x) # errors
@show Zygote.gradient(h, x) # errors

Error message:

julia> Zygote.gradient(f, x)
ERROR: Compiling Tuple{Type{FFTW.r2rFFTWPlan{Float64, Any, false, 1}}, Vector{Float64}, FFTW.FakeArray{Float64, 1}, UnitRange{Int64}, Int64, UInt32, Float64}: try/catch is not supported.
Refer to the Zygote documentation for fixes.
https://fluxml.ai/Zygote.jl/latest/limitations
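One way to make `g` differentiable is to give `dct` and `idct` their own chain rules. This is a minimal sketch, assuming ChainRulesCore is installed and that FFTW.jl's `dct`/`idct` use the orthonormal (unitary) DCT-II, as the FFTW.jl documentation states. Under that assumption the Jacobian of `dct` is an orthogonal matrix Q, so the pullback of `dct` is `idct` and vice versa. Defining these rules outside FFTW.jl is type piracy and is shown only for illustration. It also does not help `h`, which still constructs a plan inside the differentiated function.

```julia
using LinearAlgebra, FFTW, Zygote, ChainRulesCore

# Illustration only: rules for FFTW's functions would properly live in FFTW.jl
# (this is what the issue asks for); defining them here is type piracy.

# Assumes `dct` is the orthonormal DCT-II, i.e. multiplication by an orthogonal Q,
# so the pullback is Q' * ȳ == inv(Q) * ȳ == idct(ȳ).
function ChainRulesCore.rrule(::typeof(dct), x::AbstractArray{<:Real})
    dct_pullback(ȳ) = (NoTangent(), idct(unthunk(ȳ)))
    return dct(x), dct_pullback
end

# `idct` applies inv(Q) == Q', so by the same argument its pullback is `dct`.
function ChainRulesCore.rrule(::typeof(idct), x::AbstractArray{<:Real})
    idct_pullback(ȳ) = (NoTangent(), dct(unthunk(ȳ)))
    return idct(x), idct_pullback
end

x = rand(4)
g(x) = x |> dct |> idct |> norm
@show Zygote.gradient(g, x)
```

Because `idct(dct(x)) == x`, the gradient of `g` should come out as `x / norm(x)`, the same as for `f`.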
Cosigning this, having
Here's an extremely rudimentary implementation of the adjoint for FFTW r2r plans (1-D `REDFT10` only):

using AbstractFFTs
using FFTW

struct R2RFFTAdjointStyle <: AbstractFFTs.AdjointStyle end
AbstractFFTs.AdjointStyle(::FFTW.r2rFFTWPlan) = R2RFFTAdjointStyle()

function AbstractFFTs.adjoint_mul(
    p::FFTW.r2rFFTWPlan{T}, x::AbstractVector{T}, ::R2RFFTAdjointStyle
) where {T}
    (length(p.kinds) == 1) || throw(ArgumentError("Multidimensional r2r transforms not yet supported"))
    # kind 5 is FFTW's REDFT10 (the unnormalized DCT-II)
    (only(p.kinds) == 5) || throw(ArgumentError("r2r kinds other than REDFT10 not yet supported"))
    pinv = inv(p)
    # inv(p) is expected to be a ScaledPlan wrapping a REDFT01 plan; drop the 1/(2N) scaling
    unscaled_pinv = (pinv isa AbstractFFTs.ScaledPlan) ? pinv.p : pinv
    y = unscaled_pinv * x
    # Up to a uniform scale, REDFT10 is orthogonal except for its first row, so the
    # unscaled inverse (REDFT01) equals its adjoint except for the first column, which
    # is half as large. To obtain the true adjoint, add the missing DC term.
    y .+= first(x)
    return y
end
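The claim in the code comments (REDFT01 is the unscaled inverse of REDFT10 and differs from its adjoint only in the first column) can be checked numerically without FFTW. This is a minimal sketch, assuming FFTW's documented unnormalized definitions of REDFT10 and REDFT01 with 0-based indices. The helpers `redft10_matrix` and `redft01_matrix` are hypothetical names introduced here, not FFTW.jl functions; they build the transforms as dense matrices.

```julia
using LinearAlgebra

# Hypothetical helper: dense matrix of FFTW's unnormalized REDFT10 (DCT-II),
#   Y[k] = 2 * sum_{n=0}^{N-1} X[n] * cos(pi * k * (n + 1/2) / N)   (0-based k, n)
redft10_matrix(N) = [2cos(pi * k * (n + 0.5) / N) for k in 0:N-1, n in 0:N-1]

# Hypothetical helper: dense matrix of FFTW's unnormalized REDFT01 (DCT-III),
#   Y[k] = X[0] + 2 * sum_{n=1}^{N-1} X[n] * cos(pi * n * (k + 1/2) / N)
redft01_matrix(N) = [(n == 0 ? 1.0 : 2.0) * cos(pi * n * (k + 0.5) / N) for k in 0:N-1, n in 0:N-1]

N = 4
A = redft10_matrix(N)
B = redft01_matrix(N)
x = rand(N)

# REDFT01 undoes REDFT10 up to the factor 2N (FFTW's normalization for this pair).
@assert B * A ≈ 2N * Matrix(1.0I, N, N)

# B matches A' everywhere except the first column (1 instead of 2), so adding
# first(x) to every entry of B * x reproduces the true adjoint A' * x.
@assert A' * x ≈ B * x .+ first(x)
```

This is the same correction that `y .+= first(x)` applies in `adjoint_mul`, just written out with explicit matrices instead of plans.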