Add adtype field to DynamicPPL.Model
penelopeysm committed Feb 23, 2025
1 parent 90c7b26 commit 10fe743
Showing 7 changed files with 98 additions and 16 deletions.
19 changes: 18 additions & 1 deletion HISTORY.md
Original file line number Diff line number Diff line change
@@ -128,6 +128,22 @@ This release removes the feature of `VarInfo` where it kept track of which varia
**Other changes**
### Models now store AD backend types
In `DynamicPPL.Model`, an extra field `adtype::Union{Nothing,ADTypes.AbstractADType}` has been added. This field is used to store the AD backend which should be used when calculating gradients of the log density.
The field can be set by passing an extra argument to the `Model` constructor; in practice, though, users are more likely to want to set the `adtype` field on an existing model:
```julia
@model f() = ...
model = f()
model_with_adtype = setadtype(model, AutoForwardDiff())
```
As far as `DynamicPPL.Model` is concerned, this field does not actually have any effect.
However, when a `LogDensityFunction` is constructed from said model, it will inherit the `adtype` field from the model.
See below for more information on `LogDensityFunction`.
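For example (a sketch, not taken from the release notes — the model definition and backend choice here are illustrative):

```julia
using DynamicPPL, ADTypes, Distributions

@model demo() = x ~ Normal()

# Attach an AD backend to the model. The model itself evaluates exactly as before.
model = setadtype(demo(), AutoForwardDiff())

# A LogDensityFunction built from this model inherits its adtype, so gradients
# become available without specifying the backend a second time.
ldf = DynamicPPL.LogDensityFunction(model)
ldf.adtype  # AutoForwardDiff()
```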
### `LogDensityProblems` interface
LogDensityProblemsAD is now removed as a dependency.
@@ -136,7 +152,8 @@ Instead of constructing a `LogDensityProblemAD.ADgradient` object, we now direct
Note that if you wish, you can still construct an `ADgradient` out of a `LogDensityFunction` object (there is nothing preventing this).
However, in this version, `LogDensityFunction` now takes an extra AD type argument.
-If this argument is not provided, the behaviour is exactly the same as before, i.e. you can calculate `logdensity` but not its gradient.
+By default, this AD type is inherited from the model that the `LogDensityFunction` is constructed from.
+If the model does not have an AD type, or if the argument is explicitly set to `nothing`, the behaviour is exactly the same as before, i.e. you can calculate `logdensity` but not its gradient.
However, if you do pass an AD type, that will allow you to calculate the gradient as well.
You may thus find that it is easier to instead do this:
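The snippet itself is collapsed in this view; a plausible sketch of the pattern being described, based on the `adtype` keyword in `src/logdensityfunction.jl` (here `model` and `params` are assumed to be defined), is:

```julia
using DynamicPPL, ADTypes, LogDensityProblems

# Pass the AD type when constructing the LogDensityFunction directly,
# rather than wrapping the result in LogDensityProblemsAD.ADgradient afterwards.
ldf = DynamicPPL.LogDensityFunction(model; adtype=AutoForwardDiff())
LogDensityProblems.logdensity_and_gradient(ldf, params)
```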
7 changes: 7 additions & 0 deletions docs/src/api.md
@@ -36,6 +36,13 @@ getargnames
getmissings
```

The context and AD type of a model can be changed with [`contextualize`](@ref) and [`setadtype`](@ref) respectively.

```@docs
contextualize
setadtype
```
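A brief usage sketch of these two functions (the model `f` is hypothetical; the calls mirror the signatures documented above):

```julia
model = f()

# Replace the evaluation context, keeping everything else (including adtype).
model = contextualize(model, DynamicPPL.DefaultContext())

# Attach an AD backend, or remove it again by passing `nothing`.
model = setadtype(model, AutoForwardDiff())
model = setadtype(model, nothing)
```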

## Evaluation

With [`rand`](@ref) one can draw samples from the prior distribution of a [`Model`](@ref).
1 change: 1 addition & 0 deletions src/DynamicPPL.jl
@@ -84,6 +84,7 @@ export AbstractVarInfo,
getargnames,
extract_priors,
values_as_in_model,
setadtype,
# Samplers
Sampler,
SampleFromPrior,
2 changes: 1 addition & 1 deletion src/logdensityfunction.jl
@@ -111,7 +111,7 @@ struct LogDensityFunction{
model::Model,
varinfo::AbstractVarInfo=VarInfo(model),
context::AbstractContext=leafcontext(model.context);
-    adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
+    adtype::Union{ADTypes.AbstractADType,Nothing}=model.adtype,
)
if adtype === nothing
prep = nothing
61 changes: 47 additions & 14 deletions src/model.jl
@@ -1,17 +1,23 @@
"""
-    struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext}
+    struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext,TAD<:Union{Nothing,ADTypes.AbstractADType}}
        f::F
        args::NamedTuple{argnames,Targs}
        defaults::NamedTuple{defaultnames,Tdefaults}
        context::Ctx=DefaultContext()
+       adtype::TAD=nothing
    end

-A `Model` struct with model evaluation function of type `F`, arguments of names `argnames`
-types `Targs`, default arguments of names `defaultnames` with types `Tdefaults`, missing
-arguments `missings`, and evaluation context of type `Ctx`.
+A `Model` struct contains the following fields:
+- `f`, a model evaluation function of type `F`
+- `args`, arguments of names `argnames` with types `Targs`
+- `defaults`, default arguments of names `defaultnames` with types `Tdefaults`
+- `context`, an evaluation context of type `Ctx`
+- `adtype`, which can be `nothing`, or an automatic differentiation backend of type `TAD`
+Its missing arguments are also stored as a type parameter `missings`.

Here `argnames`, `defaultnames`, and `missings` are tuples of symbols, e.g. `(:a, :b)`.

-`context` is by default `DefaultContext()`.
+`context` is by default `DefaultContext()`, and `adtype` is by default `nothing`.
An argument with a type of `Missing` will be in `missings` by default. However, in
non-traditional use-cases `missings` can be defined differently. All variables in `missings`
@@ -33,12 +39,21 @@ julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition
Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,))
```
"""
-struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext} <:
-       AbstractProbabilisticProgram
+struct Model{
+    F,
+    argnames,
+    defaultnames,
+    missings,
+    Targs,
+    Tdefaults,
+    Ctx<:AbstractContext,
+    TAD<:Union{Nothing,ADTypes.AbstractADType},
+} <: AbstractProbabilisticProgram
    f::F
    args::NamedTuple{argnames,Targs}
    defaults::NamedTuple{defaultnames,Tdefaults}
    context::Ctx
+   adtype::TAD

@doc """
Model{missings}(f, args::NamedTuple, defaults::NamedTuple)
@@ -51,9 +66,10 @@ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractConte
args::NamedTuple{argnames,Targs},
defaults::NamedTuple{defaultnames,Tdefaults},
context::Ctx=DefaultContext(),
-) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Ctx}
-    return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx}(
-        f, args, defaults, context
+    adtype::TAD=nothing,
+) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Ctx,TAD}
+    return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,TAD}(
+        f, args, defaults, context, adtype
)
end
end
@@ -71,22 +87,39 @@ model with different arguments.
args::NamedTuple{argnames,Targs},
defaults::NamedTuple{kwargnames,Tkwargs},
context::AbstractContext=DefaultContext(),
-) where {F,argnames,Targs,kwargnames,Tkwargs}
+    adtype::TAD=nothing,
+) where {F,argnames,Targs,kwargnames,Tkwargs,TAD}
missing_args = Tuple(
name for (name, typ) in zip(argnames, Targs.types) if typ <: Missing
)
missing_kwargs = Tuple(
name for (name, typ) in zip(kwargnames, Tkwargs.types) if typ <: Missing
)
-    return :(Model{$(missing_args..., missing_kwargs...)}(f, args, defaults, context))
+    return :(Model{$(missing_args..., missing_kwargs...)}(
+        f, args, defaults, context, adtype
+    ))
end

function Model(f, args::NamedTuple, context::AbstractContext=DefaultContext(); kwargs...)
-    return Model(f, args, NamedTuple(kwargs), context)
+    return Model(f, args, NamedTuple(kwargs), context, nothing)
end

"""
contextualize(model::Model, context::AbstractContext)
Set the context of `model` to `context`.
"""
function contextualize(model::Model, context::AbstractContext)
-    return Model(model.f, model.args, model.defaults, context)
+    return Model(model.f, model.args, model.defaults, context, model.adtype)
end

"""
setadtype(model::Model, adtype::Union{Nothing,ADTypes.AbstractADType})
Set the automatic differentiation backend of `model` to `adtype`.
"""
function setadtype(model::Model, adtype::Union{Nothing,ADTypes.AbstractADType})
return Model(model.f, model.args, model.defaults, model.context, adtype)
end

"""
14 changes: 14 additions & 0 deletions test/logdensityfunction.jl
@@ -9,6 +9,20 @@ using Test, DynamicPPL, ADTypes, LogDensityProblems, ForwardDiff
end
end

@testset "AD type forwarding from model" begin
@model demo_simple() = x ~ Normal()
adtype = AutoForwardDiff()
model = setadtype(demo_simple(), adtype)
ldf = DynamicPPL.LogDensityFunction(model)
# Check that the model's AD type is forwarded to the LDF
@test ldf.adtype == adtype
# Check that the gradient can be evaluated on the resulting LDF
@test LogDensityProblems.capabilities(typeof(ldf)) ==
LogDensityProblems.LogDensityOrder{1}()
@test LogDensityProblems.logdensity(ldf, [1.0]) isa Any
@test LogDensityProblems.logdensity_and_gradient(ldf, [1.0]) isa Any
end

@testset "LogDensityFunction" begin
@testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS
example_values = DynamicPPL.TestUtils.rand_prior_true(model)
10 changes: 10 additions & 0 deletions test/model.jl
@@ -100,6 +100,16 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
end
end

@testset "model adtype" begin
# Check that adtype can be set and unset
@model demo_adtype() = x ~ Normal()
adtype = AutoForwardDiff()
model = setadtype(demo_adtype(), adtype)
@test model.adtype == adtype
model = setadtype(model, nothing)
@test model.adtype === nothing
end

@testset "model de/conditioning" begin
@model function demo_condition()
x ~ Normal()
