-
Notifications
You must be signed in to change notification settings - Fork 14
/
probit.jl
45 lines (30 loc) · 1.17 KB
/
probit.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
using StatsFuns: normcdf, normlogcdf
export Probit, ProbitMeta
# Tag type for the probit node. It has no fields; in this file it serves only
# as the node name passed to `@node` / `@average_energy` and as the type the
# `default_*` methods dispatch on.
struct Probit end
"""
    ProbitMeta(p)

Metadata attached to the `Probit` node.

# Fields
- `p::Int`: the number of Gauss-Hermite cubature points used when the average
  energy of the node is evaluated.
"""
struct ProbitMeta
    p::Int
end
"""
    getp(meta::ProbitMeta)

Return the `p` field stored in `meta` (the number of cubature points).
"""
function getp(meta::ProbitMeta)
    return meta.p
end
"""
    ProbitMeta(; p = 32)

Keyword constructor for `ProbitMeta`; `p` defaults to `32` when omitted.
"""
function ProbitMeta(; p = 32)
    return ProbitMeta(p)
end
# Declare `Probit` through the `@node` macro as a `Stochastic` node with two
# interfaces, listed in the order `out`, `in`.
@node Probit Stochastic [ out, in ]
"""
    default_meta(::Type{Probit})

Return the metadata used for a `Probit` node when none is supplied:
a `ProbitMeta` with `p = 32`.
"""
function default_meta(::Type{Probit})
    return ProbitMeta(32)
end
"""
    default_functional_dependencies_pipeline(::Type{<:Probit})

Return the default functional-dependencies pipeline for the `Probit` node: a
`RequireInboundFunctionalDependencies` built from the index tuple `(2,)` and a
vague `NormalMeanPrecision`.

NOTE(review): index `2` presumably refers to the `in` interface (second in the
`@node` declaration) and the vague distribution to its initial inbound
message; confirm against `RequireInboundFunctionalDependencies`.
"""
function default_functional_dependencies_pipeline(::Type{<:Probit})
    required_indices = (2,)
    initial_messages = (vague(NormalMeanPrecision),)
    return RequireInboundFunctionalDependencies(required_indices, initial_messages)
end
"""
    default_interface_local_constraint(::Type{<:Probit}, edge::Val{:in})

Return the default local constraint for the `in` interface of the `Probit`
node: `MomentMatching()`.
"""
function default_interface_local_constraint(::Type{<:Probit}, edge::Val{:in})
    return MomentMatching()
end
"""
    default_interface_local_constraint(::Type{<:Probit}, edge::Val{:out})

Return the default local constraint for the `out` interface of the `Probit`
node: `Marginalisation()`.
"""
function default_interface_local_constraint(::Type{<:Probit}, edge::Val{:out})
    return Marginalisation()
end
@average_energy Probit (q_out::Union{PointMass, Bernoulli}, q_in::UnivariateNormalDistributionsFamily, meta::ProbitMeta) = begin
    # Average energy of the Probit node: the expectation under `q_in` of
    #     h(x) = -p * log(Phi(x)) - (1 - p) * log(Phi(-x)),
    # where `p = mean(q_out)` and `Phi` is the standard normal CDF. The
    # expectation is approximated with `getp(meta)` Gauss-Hermite points
    # (32 by default).

    # extract parameters
    p = mean(q_out)
    m, v = mean_var(q_in)

    # specify function
    # `normlogcdf` evaluates log(Phi(x)) directly. The previous form
    # `log(normcdf(x))` underflows to -Inf once `normcdf` rounds to zero in
    # the tail, which turns the result into Inf, or into NaN when the
    # multiplying weight (`p` or `1 - p`) is exactly zero.
    h = (x) -> -p * normlogcdf(x) - (1 - p) * normlogcdf(-x)

    # calculate average energy
    gh_cubature = GaussHermiteCubature(getp(meta))

    # Substituting x = m + sqrt(2v) * t turns the Gaussian expectation into an
    # integral against exp(-t^2); the 1/sqrt(pi) factor below is the matching
    # normalisation.
    scale = sqrt(2 * v)
    U = 0.0
    for k in 1:getp(meta)
        U += gh_cubature.witer[k] * h(gh_cubature.piter[k] * scale + m)
    end
    U /= sqrt(pi)

    # return average energy
    return U
end