Add adtype to DynamicPPL.Model #818
base: release-0.35
Codecov Report: all modified and coverable lines are covered by tests ✅

Additional details and impacted files:

    @@             Coverage Diff             @@
    ##           release-0.35     #818   +/-   ##
    =============================================
      Coverage         84.60%   84.61%
    =============================================
      Files                34       34
      Lines              3832     3834    +2
    =============================================
    + Hits               3242     3244    +2
      Misses              590      590

Pull Request Test Coverage Report for Build 13478408388 (Coveralls)
With this PR, is it still necessary for
Hmmm, I didn't think about that at first, but that's a really good point. I originally saw the approach above as a natural extension of how we duplicate the context, but now that makes me wonder whether that's needed at all! Are you thinking of dropping the adtype from the LDF, so that we have something like

    struct LogDensityFunction{...,TAD<:Union{Nothing,AbstractADType}}
        model::Model{...,TAD}  # Then the adtype is contained here
        varinfo::AbstractVarInfo
        ...
    end

and, instead of storing the leaf context of the model in the LDF and then using it here (DynamicPPL.jl/src/logdensityfunction.jl, lines 167 to 185 in 90c7b26), we could just do

    function logdensity_at(
        x::AbstractVector, model::Model, varinfo::AbstractVarInfo
    )
        varinfo_new = unflatten(varinfo, x)
        return getlogp(last(evaluate!!(model, varinfo_new, leafcontext(model.context))))
    end

?
Yeah, but my thought is really first-order: the model is a field of the LDF, so if the model stores the adtype, the LDF doesn't have to. Sorry if this is super obvious, but may I know how does
I guess the annoying thing is that, from a technical point of view, the adtype really should live on the LDF, not the model. But the LDF isn't exposed to the user, so it's awkward for them to set the adtype there. I do think you are right, though: I prefer not having it duplicated. It might be a bit ugly, but it's more correct from a code point of view!
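A minimal, self-contained sketch of the single-source-of-truth arrangement being discussed here: the AD type is stored only on the model, and the log-density wrapper reads it through its `model` field. `ToyADType`, `ToyForwardDiff`, `ToyModel`, `ToyLDF`, `getadtype`, and `logdensity` are hypothetical stand-ins, not DynamicPPL or ADTypes.jl API.

```julia
# Hypothetical stand-ins for ADTypes.jl's AbstractADType and a concrete backend.
abstract type ToyADType end
struct ToyForwardDiff <: ToyADType end

# Hypothetical model: a log-density function plus an optional AD type,
# carried as a type parameter (mirroring Model{...,TAD} above).
struct ToyModel{F<:Function,TAD<:Union{Nothing,ToyADType}}
    logp::F
    adtype::TAD
end

# Hypothetical log-density wrapper with NO adtype field of its own:
# the wrapped model is the only place the AD type is stored.
struct ToyLDF{M<:ToyModel}
    model::M
end

# The AD type is always read through the model, so it cannot go out of sync.
getadtype(ldf::ToyLDF) = ldf.model.adtype

# Evaluate the wrapped model's log density at the parameter vector x.
logdensity(ldf::ToyLDF, x::AbstractVector) = ldf.model.logp(x)

# Usage: an unnormalized standard-normal log density with ForwardDiff selected.
model = ToyModel(x -> -sum(abs2, x) / 2, ToyForwardDiff())
ldf = ToyLDF(model)
println(getadtype(ldf))               # ToyForwardDiff()
println(logdensity(ldf, [1.0, 2.0]))  # -2.5
```

Because `ToyLDF` has no AD-type field, switching AD backends means building a new model. That is the trade-off raised above: no duplication, at the cost of putting an inference-time setting on the model.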
Previously we discussed the idea of associating the AD type with the model, such that we could write something like
rather than*
*(Note: I'm not proposing that the second example be removed; I'm just saying that the first example should be made possible and preferred going forward.)
Most of the hard work for this has been accomplished by #806, which gives us the ability to set an `adtype` on a `LogDensityFunction`.

However, as a typical user doesn't construct a `LogDensityFunction` themselves, it makes sense to also have an `adtype` field on `DynamicPPL.Model` (which users do construct) that is automatically forwarded to the `LogDensityFunction` upon construction.

This PR does this, and also adds `Model(::Model, ::Union{Nothing,AbstractADType})`, which allows a user to set the AD type. Together with some upstream changes in Turing, this will enable the following
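Independently of the Turing-side changes, the forwarding mechanism described in this PR can be illustrated with a toy sketch: a model type with an `adtype` field, a two-argument constructor analogous to `Model(::Model, ::Union{Nothing,AbstractADType})` that returns a copy with the AD type replaced, and a log-density wrapper that picks the AD type up from the model when it is constructed. All names (`ToyADType`, `ToyModel`, `ToyLDFWithAD`, and so on) are hypothetical; this is not DynamicPPL's actual implementation.

```julia
# Hypothetical stand-ins for ADTypes.jl's AbstractADType and two backends.
abstract type ToyADType end
struct ToyForwardDiff <: ToyADType end
struct ToyReverseDiff <: ToyADType end

# Hypothetical model carrying an optional AD type (`nothing` = not set).
struct ToyModel{F<:Function,TAD<:Union{Nothing,ToyADType}}
    logp::F
    adtype::TAD
end

# Models built without an AD type default to `nothing`.
ToyModel(logp::Function) = ToyModel(logp, nothing)

# Analogue of Model(::Model, ::Union{Nothing,AbstractADType}): return a copy
# of `model` with its AD type replaced; the original model is left unchanged.
ToyModel(model::ToyModel, adtype::Union{Nothing,ToyADType}) = ToyModel(model.logp, adtype)

# Hypothetical log-density wrapper that keeps its own copy of the AD type.
struct ToyLDFWithAD{M<:ToyModel,TAD<:Union{Nothing,ToyADType}}
    model::M
    adtype::TAD
end

# Constructing the wrapper from a model forwards the model's AD type.
ToyLDFWithAD(model::ToyModel) = ToyLDFWithAD(model, model.adtype)

# Usage: the user only touches the model; the wrapper inherits the AD type.
model = ToyModel(x -> -sum(abs2, x) / 2)
model_fd = ToyModel(model, ToyForwardDiff())
println(model.adtype)                     # nothing
println(ToyLDFWithAD(model_fd).adtype)    # ToyForwardDiff()
```

Setting the AD type through a copying constructor keeps the model immutable. The cost, as discussed in the conversation, is that the AD type then lives in two places (on the model and on the wrapper).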