-
Notifications
You must be signed in to change notification settings - Fork 6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
A Flux model #10
A Flux model #10
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Did not review everything in details but here are some comments.
Additionally, should we have some internal train
and predict
functions? We could add a dependency on Flux via Requires.jl
?
return m, A | ||
end | ||
|
||
function _init_variational_params( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we add a keyword argument to enforce the eltype and type of \mu and \Sigma?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added an eltype parameter, but I can't think of a reason why you wouldn't want to use Vector
/ Matrix
, so maybe a type parameter isn't needed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well in the crazy case where someone would like to use GPUs, one would need pass CuArrays but I guess they would just have to make the effort of passing their own variational parameters.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since it's a functor, isn't it easiest to just use Flux.gpu
for that after it's initialised?
Co-authored-by: Théo Galy-Fajou <theo.galyfajou@gmail.com>
Training works with Flux's default training loop, so I'm not sure an internal |
Pretty much everything useful in this PR is in the regression/classification examples already, so closing |
Adds a high level API for a model which can easily be used with Flux (although doesn't add a dependency on Flux itself).
Examples have been updated to use the model.
It should also be straightforward to adapt the model to support something like ParameterHandling.jl for use with general optimisers - this will probably be left to a future PR though.