-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #347 from ReactiveBayes/dev-fix-335
Show variable name and suggestions if the resulting functional form is not supported by the inference backend
- Loading branch information
Showing
6 changed files
with
139 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import ReactiveMP: AbstractFormConstraint | ||
|
||
# This is an internal functional form constraint that only checks that the result | ||
# is of a supported form. Displays a user-friendly error message if the form is not supported. | ||
struct EnsureSupportedFunctionalForm <: AbstractFormConstraint | ||
prefix::Symbol | ||
name::Symbol | ||
index::Any | ||
end | ||
|
||
ReactiveMP.default_form_check_strategy(::EnsureSupportedFunctionalForm) = FormConstraintCheckLast() | ||
|
||
ReactiveMP.default_prod_constraint(::EnsureSupportedFunctionalForm) = GenericProd() | ||
|
||
function ReactiveMP.constrain_form(constraint::EnsureSupportedFunctionalForm, something) | ||
if typeof(something) <: ProductOf || typeof(something) <: LinearizedProductOf | ||
expr = string(constraint.prefix, '(', constraint.name, isnothing(constraint.index) ? "" : string('[', constraint.index, ']'), ')') | ||
expr_noindex = string(constraint.prefix, '(', constraint.name, ')') | ||
error(lazy""" | ||
The expression `$expr` has an undefined functional form of type `$(typeof(something))`. | ||
This is likely because the inference backend does not support the product of these distributions. | ||
As a result, `RxInfer` cannot compute key quantities such as the `mean` or `var` of `$expr`. | ||
Possible solutions: | ||
- Implement the `BayesBase.prod` method (refer to the `BayesBase` documentation for guidance). | ||
- Use a functional form constraint to specify the posterior form with the `@constraints` macro. For example: | ||
```julia | ||
using ExponentialFamilyProjection | ||
@constraints begin | ||
$(expr_noindex) :: ProjectedTo(NormalMeanVariance) | ||
end | ||
``` | ||
Refer to the documentation for more details on functional form constraints. | ||
""") | ||
end | ||
return something | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
@testitem "Tests for `EnsureSupportedFunctionalForm" begin | ||
import RxInfer: EnsureSupportedFunctionalForm | ||
import ReactiveMP: default_form_check_strategy, default_prod_constraint, constrain_form | ||
import BayesBase: PointMass, ProductOf, LinearizedProductOf | ||
|
||
# In principle any object is supported except `ProductOf` and `LinearizedProductOf` from `BayesBase` | ||
# Those are supposed to be passed to the functional form constraint | ||
|
||
for prefix in (:q, :μ), index in (nothing, (1,)), name in (:a, :b) | ||
@test default_form_check_strategy(EnsureSupportedFunctionalForm(prefix, name, index)) === FormConstraintCheckLast() | ||
@test default_prod_constraint(EnsureSupportedFunctionalForm(prefix, name, index)) === GenericProd() | ||
|
||
@testset let constraint = EnsureSupportedFunctionalForm(prefix, name, index) | ||
@test constrain_form(constraint, PointMass(1)) === PointMass(1) | ||
@test_throws Exception constrain_form(constraint, ProductOf(PointMass(1), PointMass(2))) | ||
@test_throws Exception constrain_form(constraint, LinearizedProductOf([PointMass(1), PointMass(2)], 2)) | ||
end | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters