Diffractor + Tapir for computing hessian #157

Open
yebai opened this issue May 15, 2024 · 15 comments
Labels
enhancement New feature or request low priority

Comments

@yebai
Contributor

yebai commented May 15, 2024

This example currently fails:

julia> import ForwardDiff, Tapir

julia> using DifferentiationInterface

julia> b = SecondOrder(AutoForwardDiff(), AutoTapir());

julia> hessian(sum, b, [1.0, 2.0])
[ Info: Compiling rule for Tuple{typeof(sum), Vector{ForwardDiff.Dual{ForwardDiff.Tag{DifferentiationInterface.var"#inner_gradient_closure#27"{typeof(sum), SecondOrder{AutoForwardDiff{nothing, Nothing}, AutoTapir}}, Float64}, Float64, 1}}} in safe mode. Disable for best performance.
ERROR: MethodError: Cannot `convert` an object of type ForwardDiff.Dual{ForwardDiff.Tag{DifferentiationInterface.var"#inner_gradient_closure#27"{typeof(sum), SecondOrder{AutoForwardDiff{nothing, Nothing}, AutoTapir}}, Float64}, Float64, 1} to an object of type Tapir.Tangent{@NamedTuple{value::Float64, partials::Tapir.Tangent{@NamedTuple{values::Tuple{Float64}}}}}

Closest candidates are:
  convert(::Type{T}, ::T) where T
   @ Base Base.jl:84
  (::Type{Tapir.Tangent{Tfields}} where Tfields<:NamedTuple)(::Any)
   @ Tapir ~/.julia/packages/Tapir/O4V78/src/tangents.jl:53

It's a pity, since Tapir would be a great tool for computing second-order derivatives in conjunction with ForwardDiff. Could this be improved?

Package environment:

(@v1.10) pkg> st Tapir DifferentiationInterface ForwardDiff
Status `~/.julia/environments/v1.10/Project.toml`
  [a0c0ee7d] DifferentiationInterface v0.4.0
  [f6369f11] ForwardDiff v0.10.36
  [07d77754] Tapir v0.2.12
@yebai yebai changed the title ForwarDiff + Tapir for hession ForwardDiff + Tapir for hession May 15, 2024
@willtebbutt willtebbutt changed the title ForwardDiff + Tapir for hession ForwardDiff + Tapir for hessian May 15, 2024
@willtebbutt
Member

Could you please provide the entire stack trace, or is this it?

In general, I would say that the chances are slim that we'll be able to do ForwardDiff-over-Tapir without a really substantial effort -- to be honest I'm not sure that it is feasible at all because Tapir does a lot of concrete typing, meaning that Duals cannot propagate. It might be possible to make something like Diffractor work, but it would require a large time investment to figure out how to make this all work given that we're using OpaqueClosures for everything (the standard code lookup that Diffractor has to do almost certainly won't work out-of-the-box).

Tapir-over-ForwardDiff (which I think is what you're doing here) might be able to work, as Tapir ought to be able to differentiate things that ForwardDiff does. I'm not sure that this is really the interesting way round to do things though (I think you normally want to do forward-over-reverse). In any case, I'd have to see the stack trace to know more.

@yebai
Contributor Author

yebai commented May 16, 2024

given that we're using OpaqueClosures for everything

Out of curiosity, what mechanism is Diffractor using, and what are the differences between Diffractor and Tapir?

Tapir-over-ForwardDiff (which I think is what you're doing here) might be able to work

It is ForwardDiff-over-Tapir IIUC, see here for more details.

@willtebbutt
Member

willtebbutt commented May 16, 2024

Out of curiosity, what mechanism is Diffractor using, and what are the differences between Diffractor and Tapir?

Ah, sorry, it's more that I mean it seems more likely to me that Diffractor-over-Tapir can be made to work, rather than ForwardDiff-over-Tapir. Diffractor also makes use of OpaqueClosures, but it doesn't make use of Dual numbers.

To be clear, I think it would be a substantial piece of work to make Diffractor-over-Tapir work, because you'd have to figure out how to get Diffractor to differentiate through OpaqueClosures (we'd have to find a way to give the IR used to generate the OpaqueClosure to Diffractor). Maybe @oxinabox has thoughts on this? In principle it ought to work nicely though, because both frameworks (if I've understood what Diffractor is doing properly) place few restrictions on the Julia IR that they can work with.

It is ForwardDiff-over-Tapir IIUC, see here for more details.

Cool, thanks -- I'll take a look at this at some point.

@yebai
Contributor Author

yebai commented May 16, 2024

To be clear, I think it would be a substantial piece of work to make Diffractor-over-Tapir work,

It sounds like a good opportunity for collaboration and a use case for Diffractor. Although Diffractor might implement its own reverse mode in the future, in its current form, making Diffractor interoperable with Tapir would benefit both packages.

@oxinabox

oxinabox commented May 16, 2024

In general I have spent a fair amount of time in the last year making sure that Diffractor-over-Diffractor and Diffractor-over-ForwardDiff work and are fast.
As such, we should be able to get Diffractor-over-Tapir to work (or the reverse).
Diffractor is good at compiling itself out of existence, so hopefully in many cases you can't even tell the code was run through Diffractor.

It should be just a matter of stashing the IR somewhere and teaching the other package what to do with it.
I suggest something like replacing all OpaqueClosures with a struct:

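using Core: OpaqueClosure
using Core.Compiler: IRCode

# Pairs an OpaqueClosure with the IR it was built from, so AD tools can retrieve that IR.
struct MistyClosure{OC}<:Function
    oc::OC
    ir::IRCode
end
MistyClosure(ir) = MistyClosure(OpaqueClosure(ir), ir)
# Calling the wrapper forwards to the underlying OpaqueClosure.
(this::MistyClosure)(args...; kwargs...) = this.oc(args...; kwargs...)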

Then, with something more or less like an frule or rrule!!, instruct the AD package that what to do to AD a call to (this::MistyClosure)(args...; kwargs...) is:

  1. go get the IR from the this.ir
  2. run the IR transform based AD pass
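As a rough illustration of the proposal above, here is a minimal sketch of building such a wrapper from the IR of `sin` and reading the IR back out, which is the first step a differentiation rule would need. It defines the struct locally and is not the API of any released package. `Base.code_ircode` and `Core.OpaqueClosure(::IRCode)` are internal, experimental parts of Julia and may behave differently between versions; resetting `ir.argtypes[1]` to `Tuple{}` is a precaution I am assuming is needed on some versions, not something stated in this thread.

```julia
using Core: OpaqueClosure
using Core.Compiler: IRCode

# Local sketch of the wrapper proposed in this thread.
struct MistyClosure{OC} <: Function
    oc::OC
    ir::IRCode
end
MistyClosure(ir::IRCode) = MistyClosure(OpaqueClosure(ir), ir)
(this::MistyClosure)(args...) = this.oc(args...)

# Fetch optimised IR for sin(::Float64). code_ircode returns `IRCode => return type` pairs.
ir = first(only(Base.code_ircode(sin, (Float64,))))

# Assumption: the first argtype should describe the (empty) captured environment.
ir.argtypes[1] = Tuple{}

mc = MistyClosure(ir)

y = mc(1.0)          # behaves like the original function
stashed_ir = mc.ir   # the IR an AD tool would transform instead of doing a normal code lookup
```

A rule for calls to `MistyClosure` would start from `mc.ir` and hand it to the package's IR-transform pass; that pass is not shown here.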

@willtebbutt
Member

willtebbutt commented May 16, 2024

Ahh excellent -- I like this idea.

Maybe the way forward would be to:

  1. create a package for misty closures,
  2. replace OpaqueClosures in Diffractor and Tapir with MistyClosures,
  3. ensure that Diffractor and Tapir know how to differentiate MistyClosures, which I agree ought to be straightforward, and
  4. apply Diffractor over Tapir, and get second order stuff.

Is this something you would be interested in collaborating on to make happen?

@oxinabox

Diffractor's public API doesn't use OpaqueClosures, IIRC.
They're only used for our forward&demand stuff, though that is what I would be using to AD Tapir if we used MistyClosures.
So I think as long as Tapir emits MistyClosures then for now we are good.

@willtebbutt
Member

So I think as long as Tapir emits MistyClosures then for now we are good.

Cool. I propose we do the following sequence of things:

  1. I create a package called MistyClosures.jl (you can do this if you like @oxinabox , but I'm happy to take care of it)
  2. @oxinabox you add a small test to Diffractor.jl to make sure that it can differentiate e.g. a MistyClosure which is equivalent to the identity function
  3. I'll change over the OpaqueClosures in this package to use MistyClosures
  4. We work through whatever issues remain in getting Diffractor to work over Tapir (🤞 they are only minor)

Do you think this plan makes sense @oxinabox ?

@oxinabox

I think this makes sense.

@yebai yebai changed the title ForwardDiff + Tapir for hessian Diffractor + Tapir for computing hessian May 22, 2024
@gdalle

gdalle commented May 28, 2024

What about forward Enzyme over Tapir?

@willtebbutt
Member

Conceptually the same idea, but I don't know whether Enzyme is able to deal with OpaqueClosures / whether we can use the same MistyClosure trick, because Enzyme needs the LLVM code that Julia code gets lowered into.

@gdalle

gdalle commented May 28, 2024

Well, if you ever need second order and are okay with approximations, something that will surely work is FiniteDifferences over Tapir.
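To make the idea concrete, here is a minimal sketch of finite differences applied over a gradient, using only Base Julia. The hand-written `grad_f` is a stand-in for a gradient that a reverse-mode tool such as Tapir would produce; `fd_hessian` and its step size `h` are hypothetical names and values chosen for illustration, and this does not use the FiniteDifferences.jl API.

```julia
# Example objective and its hand-written gradient (stand-in for a reverse-mode gradient).
f(x) = sum(x .^ 3)
grad_f(x) = 3 .* x .^ 2

# Approximate the Hessian by central differences of the gradient, one column per coordinate.
function fd_hessian(grad, x::AbstractVector{<:Real}; h::Real = 1e-5)
    n = length(x)
    T = float(eltype(x))
    H = zeros(T, n, n)
    for j in 1:n
        e = zeros(T, n)
        e[j] = h
        # Column j is the directional derivative of the gradient along coordinate j.
        H[:, j] = (grad(x .+ e) .- grad(x .- e)) ./ (2h)
    end
    # Symmetrise to damp the asymmetry introduced by rounding error.
    return (H .+ H') ./ 2
end

H = fd_hessian(grad_f, [1.0, 2.0])
```

For this objective the exact Hessian is diagonal with entries `6 .* x`, so `H` should be close to `[6 0; 0 12]`. Each column costs two gradient evaluations, and the accuracy depends on the choice of `h`.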

@willtebbutt
Member

That's true. I do think we'll be able to make Diffractor-over-Tapir work nicely though.

@gdalle

gdalle commented May 28, 2024

Beware that if you want to test with DifferentiationInterface, you'll be stuck with an old version of Diffractor until the following issue is resolved:

@willtebbutt
Member

Update: MistyClosures.jl now exists.

@willtebbutt willtebbutt added enhancement New feature or request low priority labels Jul 29, 2024

4 participants