
Lux / No rrule!! available for gemm foreigncall #423

Closed
penelopeysm opened this issue Dec 14, 2024 · 2 comments
Labels: bug (Something isn't working)

Comments

@penelopeysm
Contributor

This is a simplified example from the Bayesian neural net Turing.jl tutorial:

import Mooncake
using DifferentiationInterface: AutoMooncake, prepare_gradient
using Lux: Lux
using Random: Xoshiro

"""Generates a NamedTuple of parameters in the format that Lux expects"""
function reconstruct_nn_params(ps_new::AbstractVector)
    # Works
    # (layer_1 = (weight = reshape(ps_new[1:2], 1, 2), bias = ps_new[3:3]),)
    # Fails
    (layer_1 = (weight = reshape(view(ps_new, 1:2), 1, 2), bias = ps_new[3:3]),)
end

nn = Lux.Chain(Lux.Dense(2 => 1, Lux.σ))
_, st = Lux.setup(Xoshiro(468), nn)
xs = rand(Xoshiro(468), Float32, 2, 3) # training data
params = randn(Xoshiro(468), 3) # parameters

predict(params) = Lux.apply(nn, xs, Lux.f32(reconstruct_nn_params(params)), st)[1]
predict(params)  # the forward pass itself runs fine


prep = prepare_gradient(predict, AutoMooncake(; config=nothing), params)
# errors with: No rrule!! available for gemm foreigncall

It's just the view that seems to be problematic; without it (the "Works" line above), there's no error.
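
For context, a minimal sketch (not from the original report) of what the view changes: reshaping a plain Vector yields a Matrix, whereas reshaping a view (a SubArray) yields a Base.ReshapedArray, the lazy wrapper type that, per the discussion below, gemm!'s rrule did not handle:

v = randn(3)
typeof(reshape(v[1:2], 1, 2))        # Matrix{Float64}: the slice copies into a plain array
typeof(reshape(view(v, 1:2), 1, 2))  # Base.ReshapedArray{Float64, 2, SubArray{...}}: the case that errors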

Bisects to 0.4.57, i.e. #410

penelopeysm added a commit to TuringLang/docs that referenced this issue Dec 14, 2024
penelopeysm added a commit to TuringLang/docs that referenced this issue Dec 15, 2024
* Build with Julia 1.11

* Change pull_request_target -> pull_request

* Bump to Turing 0.35.3

* Remove view() in Bayesian NN doc, see compintell/Mooncake.jl#423
@willtebbutt added the bug label Dec 15, 2024
@willtebbutt
Member

Thanks for finding + minimising this -- #424 will resolve it. Your bisection makes total sense -- I changed the way that gemm!'s rrule is implemented in that PR, and the result was very slightly less generic than I had hoped for. It looks like I mustn't have had any test cases covering Base.ReshapedArrays being passed into BLAS calls. In any case, #424 permanently adds inputs containing Base.ReshapedArrays to the test suite of all BLAS / LAPACK functions, so there should be no risk of this bug reappearing for gemm! or for any other BLAS / LAPACK functionality.
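
To illustrate the kind of check described (a hedged sketch, not the actual #424 test code, assuming Mooncake's Mooncake.TestUtils.test_rule debugging helper):

import Mooncake
using Random: Xoshiro

# A Base.ReshapedArray input of the kind that slipped through the old coverage.
A = reshape(view(randn(6), 1:6), 2, 3)
B = randn(3, 4)

# Matrix multiplication on strided Float64 arrays lowers to BLAS gemm!,
# so this exercises the rule with a ReshapedArray input.
# is_primitive=false because * itself is not a Mooncake primitive.
Mooncake.TestUtils.test_rule(Xoshiro(123), *, A, B; is_primitive=false)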

@willtebbutt
Member

0.6.32 is merged, should be available shortly, and should resolve this. Please let me know if it doesn't!
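
A quick way to verify once the release is available (a sketch reusing the reproducer above; gradient is DifferentiationInterface's companion to prepare_gradient):

using DifferentiationInterface: gradient
backend = AutoMooncake(; config=nothing)
prep = prepare_gradient(predict, backend, params)  # previously threw the gemm foreigncall error
grad = gradient(predict, prep, backend, params)    # should now succeed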
