
Add adtype to DynamicPPL.Model #818

Open · wants to merge 1 commit into base: release-0.35
Conversation

@penelopeysm (Member) commented on Feb 23, 2025

Previously we discussed the idea of associating the AD type with the model, such that we could write something like

sample(model_with_adtype, NUTS(), 1000)

rather than

sample(model, NUTS(; adtype=adtype), 1000)

(Note I'm not proposing that the second example be removed, I'm just saying that the first example should be made possible & preferred going forwards.)


Most of the hard work for this has been accomplished by #806, which gives us the ability to set an adtype on a LogDensityFunction.

However, as a typical user doesn't construct a LogDensityFunction themselves, it makes sense to also have an adtype field on DynamicPPL.Model (which users do construct) that is just automatically forwarded to the LogDensityFunction upon construction.

This PR does exactly that, and also adds Model(::Model, ::Union{Nothing,AbstractADType}), which allows a user to set the AD type on an existing model. Together with some upstream changes in Turing, this will enable the following:

@model f() = x ~ Normal()

sample(Model(f(), adtype), NUTS(), 1000)
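To make the pattern concrete, here is a minimal standalone sketch of the idea: a model type carrying an optional AD type, plus the copy-constructor this PR proposes. These are simplified stand-in definitions, not DynamicPPL's actual ones (the real AbstractADType comes from ADTypes.jl, and the real Model has many more fields).

```julia
# Stand-ins for ADTypes.AbstractADType and one concrete backend type.
abstract type AbstractADType end
struct AutoForwardDiff <: AbstractADType end

# Simplified Model: just an evaluation function plus an optional adtype.
struct Model{F,TAD<:Union{Nothing,AbstractADType}}
    f::F          # the model's evaluation function
    adtype::TAD   # nothing means "no AD type chosen yet"
end

# Default: no adtype attached.
Model(f) = Model(f, nothing)

# The constructor this PR adds: rebuild an existing model with a new adtype.
Model(m::Model, adtype::Union{Nothing,AbstractADType}) = Model(m.f, adtype)

m = Model(x -> -x^2 / 2)        # plain model, m.adtype === nothing
m_ad = Model(m, AutoForwardDiff())  # same model, now with an adtype attached
```

Because the second constructor only swaps the adtype field, attaching an AD type is cheap and non-destructive: the original model is left untouched.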

@penelopeysm penelopeysm changed the base branch from master to release-0.35 February 23, 2025 00:36
@penelopeysm penelopeysm force-pushed the py/model-adtype branch 3 times, most recently from 10fe743 to fec68da Compare February 23, 2025 00:45
@codecov bot commented on Feb 23, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 84.61%. Comparing base (90c7b26) to head (9ca718b).

Additional details and impacted files
@@              Coverage Diff              @@
##           release-0.35     #818   +/-   ##
=============================================
  Coverage         84.60%   84.61%           
=============================================
  Files                34       34           
  Lines              3832     3834    +2     
=============================================
+ Hits               3242     3244    +2     
  Misses              590      590           


@coveralls commented on Feb 23, 2025

Pull Request Test Coverage Report for Build 13478408388

Details

  • 6 of 6 (100.0%) changed or added relevant lines in 1 file are covered.
  • 22 unchanged lines in 2 files lost coverage.
  • Overall coverage increased (+0.008%) to 84.7%

Files with coverage reduction:
  • src/model.jl: 5 new missed lines (80.33%)
  • src/threadsafe.jl: 17 new missed lines (51.43%)

Totals:
  • Change from base Build 13419888128: +0.008%
  • Covered Lines: 3244
  • Relevant Lines: 3830


@penelopeysm penelopeysm force-pushed the py/model-adtype branch 3 times, most recently from 0034d58 to 9e998f8 Compare February 23, 2025 01:29
@penelopeysm penelopeysm marked this pull request as ready for review February 23, 2025 08:07
@sunxd3 (Member) commented on Feb 23, 2025

With this PR, is it still necessary for LogDensityFunction to keep adtype?
I am thinking it might create an "adtype hierarchy" that can be confusing.

@penelopeysm (Member, Author) replied

Hmmm, I didn't think about that at first, but that's a really good point. I initially thought that the approach above was a natural extension of how we duplicate the context, but now that makes me wonder whether even that is needed?!

Are you thinking of dropping the adtype in the LDF so that we have something like

struct LogDensityFunction{...,TAD<:Union{Nothing,AbstractADType}}
   model::Model{...,TAD} # Then the adtype is contained here
   varinfo::AbstractVarInfo
   ...
end

And instead of storing the leaf context of the model in the LDF and then using it here

"""
logdensity_at(
x::AbstractVector,
model::Model,
varinfo::AbstractVarInfo,
context::AbstractContext
)
Evaluate the log density of the given `model` at the given parameter values `x`,
using the given `varinfo` and `context`. Note that the `varinfo` argument is provided
only for its structure, in the sense that the parameters from the vector `x` are inserted into
it, and its own parameters are discarded.
"""
function logdensity_at(
x::AbstractVector, model::Model, varinfo::AbstractVarInfo, context::AbstractContext
)
varinfo_new = unflatten(varinfo, x)
return getlogp(last(evaluate!!(model, varinfo_new, context)))
end

we could just do

function logdensity_at(
    x::AbstractVector, model::Model, varinfo::AbstractVarInfo
)
    varinfo_new = unflatten(varinfo, x)
    return getlogp(last(evaluate!!(model, varinfo_new, leafcontext(model.context))))
end

?

@sunxd3 (Member) commented on Feb 24, 2025

Re "dropping the adtype in the LDF": yeah, but my thought is really first-order: the model is a field of the LDF, so if the model stores the adtype, then the LDF doesn't have to.

Sorry if this is super obvious, but may I know how logdensity_at connects to adtype?

@penelopeysm (Member, Author) commented on Feb 25, 2025

I guess the annoying thing is that, from a purely technical point of view, the adtype really should live on the LDF, not the model ... but the LDF isn't exposed to the user, so it's awkward for them to set the adtype there. I do think you're right, though: I'd prefer not to have it duplicated. It might be a bit ugly, but it's more correct from a code point of view!

logdensity_at is just the function that evaluates logp for the given LDF. If you want the gradient of logp, the code takes logdensity_at and feeds it into the AD package.
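To illustrate that last step, here is a self-contained sketch of the pipeline: a logp function (playing the role of logdensity_at closed over a model and varinfo) is handed to a differentiation backend. A central finite difference stands in for the real AD package (ForwardDiff, Mooncake, etc.) purely so the sketch runs without any dependencies; the names logp and grad_logp are illustrative, not DynamicPPL API.

```julia
# Toy stand-in for logdensity_at(x, model, varinfo, context):
# standard-normal log density, up to an additive constant.
logp(x::AbstractVector) = -sum(abs2, x) / 2

# Stand-in for "feed the logp function into the AD package":
# central finite differences, one coordinate at a time.
function grad_logp(f, x::AbstractVector; h=1e-6)
    map(eachindex(x)) do i
        e = zeros(length(x))
        e[i] = h
        (f(x + e) - f(x - e)) / (2h)
    end
end

x = [1.0, -2.0]
g = grad_logp(logp, x)   # ≈ -x for this particular logp
```

In the real code path, grad_logp would be replaced by the chosen AD backend (the adtype), which is exactly the piece of configuration this PR lets the user attach to the model.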
