Skip to content

Commit

Permalink
Merge pull request #347 from ReactiveBayes/dev-fix-335
Browse files Browse the repository at this point in the history
Show variable name and suggestions if the resulting functional form is not supported by the inference backend
  • Loading branch information
bvdmitri authored Aug 27, 2024
2 parents 244592f + 08b93ab commit 1dbd503
Show file tree
Hide file tree
Showing 6 changed files with 139 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ MacroTools = "0.5.6"
Optim = "1.0.0"
ProgressMeter = "1.0.0"
Random = "1.9"
ReactiveMP = "~4.3.0"
ReactiveMP = "~4.4.0"
Reexport = "1.2.0"
Rocket = "1.8.0"
TupleTools = "1.2.0"
Expand Down
1 change: 1 addition & 0 deletions src/RxInfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ include("model/plugins/reactivemp_free_energy.jl")
include("model/plugins/initialization_plugin.jl")
include("model/graphppl.jl")

include("constraints/form/form_ensure_supported.jl")
include("constraints/form/form_fixed_marginal.jl")
include("constraints/form/form_point_mass.jl")
include("constraints/form/form_sample_list.jl")
Expand Down
38 changes: 38 additions & 0 deletions src/constraints/form/form_ensure_supported.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import ReactiveMP: AbstractFormConstraint

# This is an internal functional form constraint that only checks that the result
# is of a supported form. Displays a user-friendly error message if the form is not supported.
struct EnsureSupportedFunctionalForm <: AbstractFormConstraint
prefix::Symbol
name::Symbol
index::Any
end

ReactiveMP.default_form_check_strategy(::EnsureSupportedFunctionalForm) = FormConstraintCheckLast()

ReactiveMP.default_prod_constraint(::EnsureSupportedFunctionalForm) = GenericProd()

function ReactiveMP.constrain_form(constraint::EnsureSupportedFunctionalForm, something)
if typeof(something) <: ProductOf || typeof(something) <: LinearizedProductOf
expr = string(constraint.prefix, '(', constraint.name, isnothing(constraint.index) ? "" : string('[', constraint.index, ']'), ')')
expr_noindex = string(constraint.prefix, '(', constraint.name, ')')
error(lazy"""
The expression `$expr` has an undefined functional form of type `$(typeof(something))`.
This is likely because the inference backend does not support the product of these distributions.
As a result, `RxInfer` cannot compute key quantities such as the `mean` or `var` of `$expr`.
Possible solutions:
- Implement the `BayesBase.prod` method (refer to the `BayesBase` documentation for guidance).
- Use a functional form constraint to specify the posterior form with the `@constraints` macro. For example:
```julia
using ExponentialFamilyProjection
@constraints begin
$(expr_noindex) :: ProjectedTo(NormalMeanVariance)
end
```
Refer to the documentation for more details on functional form constraints.
""")
end
return something
end
14 changes: 8 additions & 6 deletions src/model/plugins/reactivemp_inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,12 +167,14 @@ function activate_rmp_variable!(plugin::ReactiveMPInferencePlugin, model::Model,
# By default it is `UnspecifiedFormConstraint` which means that the form of the resulting distribution is not specified in advance
# and follows from the computation, but users may override it with other form constraints, e.g. `PointMassFormConstraint`, which
# constraints the resulting distribution to be of a point mass form
messages_form_constraint = ReactiveMP.preprocess_form_constraints(
plugin, model, getextra(nodedata, GraphPPL.VariationalConstraintsMessagesFormConstraintKey, ReactiveMP.UnspecifiedFormConstraint())
)
marginal_form_constraint = ReactiveMP.preprocess_form_constraints(
plugin, model, getextra(nodedata, GraphPPL.VariationalConstraintsMarginalFormConstraintKey, ReactiveMP.UnspecifiedFormConstraint())
)
messages_form_constraint =
ReactiveMP.preprocess_form_constraints(
plugin, model, getextra(nodedata, GraphPPL.VariationalConstraintsMessagesFormConstraintKey, ReactiveMP.UnspecifiedFormConstraint())
) + EnsureSupportedFunctionalForm(, GraphPPL.getname(nodeproperties), GraphPPL.index(nodeproperties))
marginal_form_constraint =
ReactiveMP.preprocess_form_constraints(
plugin, model, getextra(nodedata, GraphPPL.VariationalConstraintsMarginalFormConstraintKey, ReactiveMP.UnspecifiedFormConstraint())
) + EnsureSupportedFunctionalForm(:q, GraphPPL.getname(nodeproperties), GraphPPL.index(nodeproperties))
# Fetch "prod-constraint" for messages and marginals. The prod-constraint usually defines the constraints for a single product of messages
# It can for example preserve a specific parametrization of distribution
messages_prod_constraint = getextra(nodedata, :messages_prod_constraint, ReactiveMP.default_prod_constraint(messages_form_constraint))
Expand Down
19 changes: 19 additions & 0 deletions test/constraints/form/form_ensure_supported_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
@testitem "Tests for `EnsureSupportedFunctionalForm" begin
import RxInfer: EnsureSupportedFunctionalForm
import ReactiveMP: default_form_check_strategy, default_prod_constraint, constrain_form
import BayesBase: PointMass, ProductOf, LinearizedProductOf

# In principle any object is supported except `ProductOf` and `LinearizedProductOf` from `BayesBase`
# Those are supposed to be passed to the functional form constraint

for prefix in (:q, ), index in (nothing, (1,)), name in (:a, :b)
@test default_form_check_strategy(EnsureSupportedFunctionalForm(prefix, name, index)) === FormConstraintCheckLast()
@test default_prod_constraint(EnsureSupportedFunctionalForm(prefix, name, index)) === GenericProd()

@testset let constraint = EnsureSupportedFunctionalForm(prefix, name, index)
@test constrain_form(constraint, PointMass(1)) === PointMass(1)
@test_throws Exception constrain_form(constraint, ProductOf(PointMass(1), PointMass(2)))
@test_throws Exception constrain_form(constraint, LinearizedProductOf([PointMass(1), PointMass(2)], 2))
end
end
end
72 changes: 72 additions & 0 deletions test/inference/inference_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -782,3 +782,75 @@ end
model = beta_bernoulli(), data = (y = 1,), initmessages = (t = Normal(0.0, 1.0))
)
end

@testitem "Unsupported functional forms (e.g. `ProductOf`) should display the name of the variable and suggestions" begin
struct DistributionA
a
end
struct DistributionB
b
end
struct LikelihoodDistribution
input
end

@node DistributionA Stochastic [out, a]
@node DistributionB Stochastic [out, b]
@node LikelihoodDistribution Stochastic [out, input]

@rule DistributionA(:out, Marginalisation) (q_a::Any,) = DistributionA(mean(q_a))
@rule DistributionB(:out, Marginalisation) (q_b::Any,) = DistributionB(mean(q_b))
@rule LikelihoodDistribution(:input, Marginalisation) (q_out::Any,) = LikelihoodDistribution(mean(q_out))

@model function invalid_product_posterior(out)
θ ~ DistributionA(1.0)
out ~ LikelihoodDistribution(θ)
end

# Product of `DistributionA` & `LikelihoodDistribution` in the posterior
P = typeof(prod(GenericProd(), DistributionA(1.0), LikelihoodDistribution(1.0)))
@test_throws """
The expression `q(θ)` has an undefined functional form of type `$(P)`.
This is likely because the inference backend does not support the product of these distributions.
As a result, `RxInfer` cannot compute key quantities such as the `mean` or `var` of `q(θ)`.
Possible solutions:
- Implement the `BayesBase.prod` method (refer to the `BayesBase` documentation for guidance).
- Use a functional form constraint to specify the posterior form with the `@constraints` macro. For example:
```julia
using ExponentialFamilyProjection
@constraints begin
q(θ) :: ProjectedTo(NormalMeanVariance)
end
```
Refer to the documentation for more details on functional form constraints.
""" result = infer(model = invalid_product_posterior(), data = (out = 1.0,))

# Product of `DistributionA` & `DistributionB` in the message
@model function invalid_product_message(out)
input[1] ~ DistributionA(1.0)
input[1] ~ DistributionB(1.0)
θ ~ DistributionA(input[1])
out ~ LikelihoodDistribution(θ)
end

T = typeof(prod(GenericProd(), DistributionA(1.0), DistributionB(1.0)))
@test_throws """
The expression `μ(input[1])` has an undefined functional form of type `$(T)`.
This is likely because the inference backend does not support the product of these distributions.
As a result, `RxInfer` cannot compute key quantities such as the `mean` or `var` of `μ(input[1])`.
Possible solutions:
- Implement the `BayesBase.prod` method (refer to the `BayesBase` documentation for guidance).
- Use a functional form constraint to specify the posterior form with the `@constraints` macro. For example:
```julia
using ExponentialFamilyProjection
@constraints begin
μ(input) :: ProjectedTo(NormalMeanVariance)
end
```
Refer to the documentation for more details on functional form constraints.
""" result = infer(model = invalid_product_message(), data = (out = 1.0,), returnvars == KeepEach(),))
end

0 comments on commit 1dbd503

Please sign in to comment.