From b73103c0bcba301e50367ba69daf1bea46ff9399 Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Mon, 17 Jun 2024 14:35:58 +0200 Subject: [PATCH] fixes for new versions of GraphPPL and ReactiveMP --- Project.toml | 6 +++--- codemeta.json | 4 ++-- src/model/plugins/reactivemp_inference.jl | 16 ++++++++++++++-- 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index 817aa72dc..2da610924 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "RxInfer" uuid = "86711068-29c9-4ff7-b620-ae75d7495b3d" authors = ["Bagaev Dmitry and contributors"] -version = "3.3.1" +version = "3.4.0" [deps] BayesBase = "b4ee3484-f114-42fe-b91c-797d54a0c67e" @@ -28,13 +28,13 @@ Distributions = "0.25" DomainSets = "0.5.2, 0.6, 0.7" ExponentialFamily = "1.2" FastCholesky = "1.3.0" -GraphPPL = "~4.2.0" +GraphPPL = "~4.3.0" LinearAlgebra = "1.9" MacroTools = "0.5.6" Optim = "1.0.0" ProgressMeter = "1.0.0" Random = "1.9" -ReactiveMP = "~4.1.0" +ReactiveMP = "~4.2.0" Reexport = "1.2.0" Rocket = "1.8.0" TupleTools = "1.2.0" diff --git a/codemeta.json b/codemeta.json index 7323e9bb4..b815a2fe1 100644 --- a/codemeta.json +++ b/codemeta.json @@ -9,12 +9,12 @@ "downloadUrl": "https://github.com/reactivebayes/RxInfer.jl/releases", "issueTracker": "https://github.com/reactivebayes/RxInfer.jl/issues", "name": "RxInfer.jl", - "version": "3.3.1", + "version": "3.4.0", "description": "Julia package for automated, scalable and efficient Bayesian inference on factor graphs with reactive message passing. ", "applicationCategory": "Statistics", "developmentStatus": "active", "readme": "https://reactivebayes.github.io/RxInfer.jl/stable/", - "softwareVersion": "3.3.1", + "softwareVersion": "3.4.0", "keywords": [ "Bayesian inference", "message passing", diff --git a/src/model/plugins/reactivemp_inference.jl b/src/model/plugins/reactivemp_inference.jl index 2981628dd..9e20a3cba 100644 --- a/src/model/plugins/reactivemp_inference.jl +++ b/src/model/plugins/reactivemp_inference.jl @@ -159,8 +159,12 @@ function activate_rmp_variable!(plugin::ReactiveMPInferencePlugin, model::Model, # By default it is `UnspecifiedFormConstraint` which means that the form of the resulting distribution is not specified in advance # and follows from the computation, but users may override it with other form constraints, e.g. `PointMassFormConstraint`, which # constraints the resulting distribution to be of a point mass form - messages_form_constraint = getextra(nodedata, GraphPPL.VariationalConstraintsMessagesFormConstraintKey, ReactiveMP.UnspecifiedFormConstraint()) - marginal_form_constraint = getextra(nodedata, GraphPPL.VariationalConstraintsMarginalFormConstraintKey, ReactiveMP.UnspecifiedFormConstraint()) + messages_form_constraint = ReactiveMP.preprocess_form_constraints( + plugin, model, getextra(nodedata, GraphPPL.VariationalConstraintsMessagesFormConstraintKey, ReactiveMP.UnspecifiedFormConstraint()) + ) + marginal_form_constraint = ReactiveMP.preprocess_form_constraints( + plugin, model, getextra(nodedata, GraphPPL.VariationalConstraintsMarginalFormConstraintKey, ReactiveMP.UnspecifiedFormConstraint()) + ) # Fetch "prod-constraint" for messages and marginals. The prod-constraint usually defines the constraints for a single product of messages # It can for example preserve a specific parametrization of distribution messages_prod_constraint = getextra(nodedata, :messages_prod_constraint, ReactiveMP.default_prod_constraint(messages_form_constraint)) @@ -301,3 +305,11 @@ ReactiveMP.setmarginals!(collection::AbstractArray{GraphVariableRef}, marginal) ReactiveMP.setmessage!(ref::GraphVariableRef, marginal) = setmessage!(ref.variable, marginal) ReactiveMP.setmessages!(collection::AbstractArray{GraphVariableRef}, marginal) = ReactiveMP.setmessages!(map(ref -> ref.variable, collection), marginal) + +# Form constraint preprocessing + +function ReactiveMP.preprocess_form_constraints(backend::ReactiveMPInferencePlugin, model::Model, constraints) + # It is a simple pass-through for now, but can be extended in the future to preprocess constraints that + # are defined in other packages, e.g. in `Distributions` and to support constraints, such as `q(x) :: Normal` + return ReactiveMP.preprocess_form_constraints(constraints) +end