Skip to content

Commit

Permalink
Merge pull request #356 from ReactiveBayes/add_unfactorizeddata
Browse files Browse the repository at this point in the history
Add UnfactorizedData
  • Loading branch information
wouterwln authored Sep 30, 2024
2 parents fc028f2 + 05f2b1a commit c8c97f0
Show file tree
Hide file tree
Showing 7 changed files with 1,725 additions and 1,613 deletions.
22 changes: 22 additions & 0 deletions docs/src/manuals/constraints-specification.md
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,28 @@ end
```
More information can be found in the [GraphPPL documentation](https://reactivebayes.github.io/GraphPPL.jl/stable/plugins/constraint_specification/#Default-constraints).

## Constraints on the data

By default, `RxInfer` assumes that, since the data comes into the model as observed, the posterior marginal distribution of the data is independent from other marginals and is a Dirac-delta distribution. However, this assumption breaks when we pass missing data into our model. When the data is missing, we might have a joint dependency between the data and latent variables, as the missing data essentially behaves as a latent variable. In such cases, we can wrap the data in an `UnfactorizedData`. This will notify the inference engine that the data should not be factorized out and we can specify a custom factorization constraint on these variables using the `@constraints` macro.

```@docs
UnfactorizedData
```

```@example constraints-specification
# Custom factorization constraints for a data vector `y` that is passed as `UnfactorizedData`.
unfactorized_example_constraints = @constraints begin
# Entries 1:1000 of `y` are factorized out from `μ` and `τ`.
q(y[1:1000], μ, τ) = q(y[1:1000])q(μ)q(τ)
# Entries 1001:1100 of `y` are kept joint with `μ`; only `τ` is factorized out.
q(y[1001:1100], μ, τ) = q(y[1001:1100], μ)q(τ)
end
result = infer(
model = iid_normal(),
# 1000 sampled values followed by 100 `missing` entries, wrapped in `UnfactorizedData`.
data = (y = UnfactorizedData(vcat(rand(NormalMeanPrecision(3.1415, 2.7182), 1000), [missing for _ in 1:100])),),
constraints = unfactorized_example_constraints,
initialization = init,
iterations = 25
)
```

## Prespecified constraints
`GraphPPL` exports some [prespecified constraints](https://reactivebayes.github.io/GraphPPL.jl/stable/plugins/constraint_specification/#Prespecified-constraints) that can be used in the `@constraints` macro, but these constraints can also be passed as top-level constraints in the `infer` function. For example, to specify a mean-field assumption on all variables in the model, we can use the `MeanField` constraint:

Expand Down
3,224 changes: 1,623 additions & 1,601 deletions examples/basic_examples/Predicting Bike Rental Demand.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions src/inference/batch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ function batch_inference(;
# But only if the data has missing values in it
elseif isnothing(predictvars) && !isnothing(data)
predictoption = iterations isa Number ? KeepEach() : KeepLast()
predictvars = Dict(variable => predictoption for (variable, value) in pairs(data) if inference_check_dataismissing(value))
predictvars = Dict(variable => predictoption for (variable, value) in pairs(data) if inference_check_dataismissing(get_data(value)))
# If both `predictvar` and `data` are specified we double check if there are some entries in the `predictvars`
# which are not specified in the `data` and inject them
# We do the same the other way around for the `data` entries which are not specified in the `predictvars`
Expand Down Expand Up @@ -301,7 +301,7 @@ function batch_inference(;
end
inference_invoke_callback(callbacks, :before_data_update, fmodel, data)
for (key, value) in fdata
update!(cacheddatavars[key], value)
update!(cacheddatavars[key], get_data(value))
end
inference_invoke_callback(callbacks, :after_data_update, fmodel, data)

Expand Down
2 changes: 1 addition & 1 deletion src/inference/streaming.jl
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ function streaming_inference(;
error(lazy"`$(_autoupdate_data_handler_key)` is present both in the `data` and in the `autoupdates`.")
end
end
_condition_on = merge_data_handlers(create_deffered_data_handlers(datavarnames), autoupdates_data_handlers(autoupdates))
_condition_on = merge_data_handlers(create_deferred_data_handlers(datavarnames), autoupdates_data_handlers(autoupdates))

inference_invoke_callback(callbacks, :before_model_creation)
fmodel = create_model(_model | _condition_on)
Expand Down
29 changes: 26 additions & 3 deletions src/model/model.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,30 @@

export ProbabilisticModel
export ProbabilisticModel, UnfactorizedData
export getmodel, getreturnval, getvardict, getrandomvars, getconstantvars, getdatavars, getfactornodes

import Base: push!, show, getindex, haskey, firstindex, lastindex
import ReactiveMP: getaddons, AbstractFactorNode
import GraphPPL: ModelGenerator, getmodel, getkwargs, create_model
import Rocket: getscheduler

"""
    UnfactorizedData{D}

A wrapper around data that should not be factorized out by default during inference.

When performing Bayesian inference with message passing, every factor node contains a local
factorization constraint on the variational posterior distribution. Data is usually regarded
as an independent component of that posterior. In some cases, for example when predicting
data, the data should not be factorized out. Wrap the data in `UnfactorizedData` to prevent
the default factorization and specify a custom node-local factorization with the
`@constraints` macro.
"""
struct UnfactorizedData{D}
    data::D
end

"""
    get_data(x)

Return the data stored inside `x` if `x` is an `UnfactorizedData`, otherwise return `x` itself.
"""
function get_data(wrapped::UnfactorizedData)
    return wrapped.data
end

function get_data(value)
    return value
end

"A structure that holds the factor graph representation of a probabilistic model."
struct ProbabilisticModel{M}
model::M
Expand Down Expand Up @@ -145,6 +163,11 @@ function __infer_create_data_interface(model, context, key::Symbol, ::DeferredDa
return GraphPPL.datalabel(model, context, GraphPPL.NodeCreationOptions(kind = :data, factorized = true), key, GraphPPL.MissingCollection())
end

# Data wrapped in `UnfactorizedData` is unwrapped with `get_data` and passed to `datalabel`
# with `factorized = false` in the node creation options (the other methods use `factorized = true`)
function __infer_create_data_interface(model, context, key::Symbol, data::UnfactorizedData)
    options = GraphPPL.NodeCreationOptions(kind = :data, factorized = false)
    return GraphPPL.datalabel(model, context, options, key, get_data(data))
end

# In all other cases we use the `datalabel` to instantiate the data interface for the model and the data is known at the time of the model creation
function __infer_create_data_interface(model, context, key::Symbol, data)
return GraphPPL.datalabel(model, context, GraphPPL.NodeCreationOptions(kind = :data, factorized = true), key, data)
Expand All @@ -156,11 +179,11 @@ merge_data_handlers(data::NamedTuple, newdata::Dict) = merge(convert(Dict, data)
merge_data_handlers(data::NamedTuple, newdata::NamedTuple) = merge(data, newdata)

# This function creates a named tuple of `DeferredDataHandler` objects from a tuple of symbols
function create_deffered_data_handlers(symbols::NTuple{N, Symbol}) where {N}
function create_deferred_data_handlers(symbols::NTuple{N, Symbol}) where {N}
    # One `DeferredDataHandler` per symbol, keyed by the symbols themselves
    handlers = ntuple(_ -> DeferredDataHandler(), Val(N))
    return NamedTuple{symbols}(handlers)
end

# This function creates a dictionary of `DeferredDataHandler` objects from an array of symbols
function create_deffered_data_handlers(symbols::AbstractVector{Symbol})
function create_deferred_data_handlers(symbols::AbstractVector{Symbol})
    # Map every symbol to its own `DeferredDataHandler`
    handlers = Dict{Symbol, DeferredDataHandler}()
    for name in symbols
        handlers[name] = DeferredDataHandler()
    end
    return handlers
end
45 changes: 45 additions & 0 deletions test/inference/inference_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -972,3 +972,48 @@ end
Refer to the documentation for more details on functional form constraints.
""" result = infer(model = invalid_product_message(), data = (out = 1.0,), returnvars == KeepEach(),))
end

@testitem "`infer` with UnfactorizedData" begin
    using RxInfer

    # Three-step discrete model: states `s` evolve through the unknown transition matrix `B`,
    # observations `y` are emitted through `A`, and the last state is tied to `goal`
    @model function pred_model(p_s_t, y, goal, p_B, A)
        s[1] ~ p_s_t
        B ~ p_B
        y[1] ~ Transition(s[1], A)
        for i in 2:3
            s[i] ~ Transition(s[i - 1], B)
            y[i] ~ Transition(s[i], A)
        end
        s[3] ~ Categorical(goal)
    end

    pred_model_constraints = @constraints begin
        q(s, B) = q(s)q(B)
    end

    @initialization function pred_model_init(q_B)
        q(B) = q_B
    end

    # First run: the data (with `missing` entries) is passed as is, without the wrapper
    result = infer(
        model = pred_model(A = diageye(4), goal = [0, 1, 0, 0], p_B = MatrixDirichlet(ones(4, 4)), p_s_t = Categorical([0.7, 0.3, 0, 0])),
        data = (y = [[1, 0, 0, 0], missing, missing],),
        initialization = pred_model_init(MatrixDirichlet(ones(4, 4))),
        constraints = pred_model_constraints,
        iterations = 10
    )
    @test last(result.predictions[:y])[1] == Categorical([0.25, 0.25, 0.25, 0.25])

    # Second run: the same data is wrapped in `UnfactorizedData`, and only `y[1]` is
    # explicitly factorized from `s` in the constraints
    pred_model_constraints = @constraints begin
        q(s, B) = q(s)q(B)
        q(y[1], s) = q(y[1])q(s)
    end
    result = infer(
        model = pred_model(A = diageye(4), goal = [0, 0, 1, 0], p_B = MatrixDirichlet(ones(4, 4)), p_s_t = Categorical([0.7, 0.3, 0, 0])),
        data = (y = UnfactorizedData([[1, 0, 0, 0], missing, missing]),),
        initialization = pred_model_init(MatrixDirichlet(ones(4, 4))),
        constraints = pred_model_constraints,
        iterations = 10
    )
    # The comparison operator is required: `@test a b` is not a valid test expression
    @test probvec(last(last(result.predictions[:y]))) ≈ [0, 0, 1, 0]
end
12 changes: 6 additions & 6 deletions test/model/model_construction_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,14 @@ end
@test create_model(conditioned) isa ProbabilisticModel
end

@testitem "create_deffered_data_handlers" begin
import RxInfer: create_deffered_data_handlers, DeferredDataHandler
@testitem "create_deferred_data_handlers" begin
import RxInfer: create_deferred_data_handlers, DeferredDataHandler

@testset "Creating deffered labels from tuple of symbols" begin
@test create_deffered_data_handlers((:x, :y)) === (x = DeferredDataHandler(), y = DeferredDataHandler())
@testset "Creating deferred labels from tuple of symbols" begin
@test create_deferred_data_handlers((:x, :y)) === (x = DeferredDataHandler(), y = DeferredDataHandler())
end

@testset "Creating deffered labels from array of symbols" begin
@test create_deffered_data_handlers([:x, :y]) == Dict(:x => DeferredDataHandler(), :y => DeferredDataHandler())
@testset "Creating deferred labels from array of symbols" begin
@test create_deferred_data_handlers([:x, :y]) == Dict(:x => DeferredDataHandler(), :y => DeferredDataHandler())
end
end

0 comments on commit c8c97f0

Please sign in to comment.