Error when running infer due to splatting #240

Open
albertpod opened this issue Jun 13, 2024 · 3 comments
Labels
bug Something isn't working

Comments

@albertpod
Member

I know the issue with splatting was addressed here, but I am still having trouble with it when running the `infer` function with the PCA model below. The following error is encountered:

```
ERROR: MethodError: no method matching iterate(::GraphPPL.VariableRef{…})

Closest candidates are:
  iterate(::Revise.LineSkippingIterator)
   @ Revise ~/.julia/packages/Revise/bAgL0/src/relocatable_exprs.jl:70
  iterate(::Revise.LineSkippingIterator, ::Any)
   @ Revise ~/.julia/packages/Revise/bAgL0/src/relocatable_exprs.jl:70
  iterate(::Base.MethodSpecializations)
   @ Base reflection.jl:1148
  ...

Stacktrace:
  [1] macro expansion
    @ ~/.julia/dev/GraphPPL/src/model_macro.jl:543 [inlined]
  [2] macro expansion
```

The error occurs at the line:

```julia
generate_lhs_proxylabel(var, index::Nothing) = quote
```

A minimal example that reproduces the error:

```julia
using RxInfer

# Observation block: stack the component vectors column-wise and map x through them
PCA_block(x, w...) = hcat(w...)*x
# PCA_block(x, w1, w2) = [w1 w2]*x

@model function pca_mode(y, components, obs_dim, lat_dim)
    local w
    for j in 1:components
        w[j] ~ MvNormal(μ=zeros(obs_dim), Λ=diageye(obs_dim))
    end

    for i in eachindex(y)
        x[i] ~ MvNormalMeanPrecision(ones(lat_dim), diageye(lat_dim))
        # Splatting `w...` here appears to be what triggers the MethodError
        y[i] ~ MvNormal(μ=PCA_block(x[i], w...), Λ=diageye(obs_dim))
        # y[i] ~ MvNormal(μ=PCA_block(x[i], w[1], w[2]), Λ=diageye(obs_dim))
    end
end

components = 2
n_samples = 100
obs_dim = 4
lat_dim = 2

w1 = [2.0, -1.0, 0.5, -0.2]
w2 = [0.8, 1.5, -0.3, 0.1]

latent_x = [rand(MvNormal(zeros(lat_dim), diageye(lat_dim))) for i in 1:n_samples]

y = [PCA_block(latent_x[i], w1, w2) + rand(MvNormal(zeros(obs_dim), 0.1diageye(obs_dim))) for i in 1:n_samples]

pca_meta = @meta begin
    PCA_block() -> Linearization()
end

initialization = @initialization begin
    μ(w) = MvNormalMeanPrecision(zeros(obs_dim), diageye(obs_dim))
end

result = infer(
    model=pca_mode(components=components, obs_dim=obs_dim, lat_dim=lat_dim),
    initialization=initialization,
    data=(y=y, ),
    meta=pca_meta,
    free_energy=true,
    iterations=5,
    showprogress=true,
    returnvars=KeepLast(),
)
```

The issue appears to be related to splatting the `w` array (`w...`) in the call to `PCA_block`.
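For context, in plain Julia a call such as `f(v...)` makes the runtime call `iterate(v)` on the splatted value. The sketch below reproduces the same kind of `MethodError` without GraphPPL, using a hypothetical `FakeVarRef` struct as a stand-in for `GraphPPL.VariableRef`. It is only an illustration under that assumption, not GraphPPL's actual type.

```julia
# Hypothetical stand-in for GraphPPL.VariableRef: a plain struct with no
# `iterate` method (this is NOT GraphPPL's real type).
struct FakeVarRef
    name::Symbol
end

# Any function that receives splatted arguments.
count_args(xs...) = length(xs)

w = FakeVarRef(:w)

# Splatting `w...` asks Julia to call `iterate(w)`. With no such method the call
# throws `MethodError: no method matching iterate(::FakeVarRef)`.
err = try
    count_args(w...)
    nothing
catch e
    e
end

println(err isa MethodError)
```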

albertpod added the bug (Something isn't working) label on Jun 13, 2024
@bvdmitri
Member

bvdmitri commented Jun 13, 2024

Can you check whether this works as a workaround?

```julia
μ=PCA_block(in = [ x[i], w... ])
```

EDIT: ah, sorry, this probably won't work either.

@bvdmitri
Member

bvdmitri commented Jun 13, 2024

I think it's a real issue, but it should be fixable. We need to define `iterate`, which should probably reuse some code from `Base.broadcastable(ref::VariableRef)` (it may even call `broadcastable`?). We may have had some justification for not defining `iterate` on `VariableRef`, but I cannot recall it.
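A minimal sketch of the idea, assuming a hypothetical `VecRef` type that wraps a vector of variables (this is not GraphPPL's real `VariableRef`, and it does not reuse `Base.broadcastable`): forwarding `iterate` and `length` to the wrapped collection is enough for `ref...` to splat into a function call.

```julia
# Hypothetical reference type wrapping a collection of variables; an
# illustration only, NOT GraphPPL's actual VariableRef.
struct VecRef
    name::Symbol
    elements::Vector{Any}
end

# Delegate iteration (and length) to the wrapped collection so that `ref...` works.
Base.iterate(r::VecRef, state...) = iterate(r.elements, state...)
Base.length(r::VecRef) = length(r.elements)

# A function that simply collects its splatted arguments into a vector.
collect_args(xs...) = collect(xs)

w = VecRef(:w, Any[:w1, :w2])
splatted = collect_args(w...)  # the two wrapped elements, passed as separate arguments
```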

@wouterwln
Member

@bvdmitri the problem when I try to define `iterate` for `VariableRef` is that, under the hood, `proxylabel(name, v..., nothing, False())` is called, which splats out the very variable we want to splat inside the `proxylabel` function. Do you have an idea how to fix this? I was thinking we could change the argument order of `proxylabel` to allow dispatching on `Vararg`, but that is an ugly solution I do not like. What do you think? Is there a viable alternative strategy?
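The collision can be sketched in plain Julia. Assuming a hypothetical `VecRef` type with `iterate` defined, and a hypothetical `label_sketch` helper standing in for a `proxylabel(name, var, index, flag)`-style signature (neither is GraphPPL code), splatting at the call site changes how many arguments the helper receives, so it no longer sees the reference as a single value:

```julia
# Hypothetical reference type with iteration defined (NOT GraphPPL's VariableRef).
struct VecRef
    name::Symbol
    elements::Vector{Any}
end
Base.iterate(r::VecRef, state...) = iterate(r.elements, state...)
Base.length(r::VecRef) = length(r.elements)

# Hypothetical helper expecting (name, variable, index, flag) ...
label_sketch(name, var, index, flag) = (name, var, index, flag)
# ... plus a catch-all Vararg method that reports how many trailing arguments arrived.
label_sketch(name, args...) = (:unexpected_arity, length(args))

w = VecRef(:w, Any[:w1, :w2])

# Passed as a single argument, the reference reaches the 4-argument method intact.
whole = label_sketch(:w, w, nothing, false)

# Splatted at the call site, `w` expands into its two elements: the call now has
# 5 arguments and falls through to the Vararg method instead.
expanded = label_sketch(:w, w..., nothing, false)
```

In this sketch the catch-all `Vararg` method is what ends up handling the splatted call, which loosely mirrors the dispatch-on-`Vararg` option mentioned above.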
