Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve type class syntax #420

Merged
merged 7 commits into from
Jan 5, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/chol.dx
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ https://en.wikipedia.org/wiki/Cholesky_decomposition

' ## Cholesky Algorithm

def chol (_:Eq n) ?=> (x:n=>n=>Float) : (n=>n=>Float) =
def chol [Eq n] (x:n=>n=>Float) : (n=>n=>Float) =
yieldState zero \buf.
for_ i. for j':(..i).
j = %inject(j')
Expand Down Expand Up @@ -31,7 +31,7 @@ def trisolveU (mat:n=>n=>Float) (b:n=>Float) : n=>Float =
xPrev = for j:(i..). get (buf!%inject j)
buf!i := (b.i - vdot row xPrev) / mat.i.i

def psdsolve (_:Eq n) ?=> (mat:n=>n=>Float) (b:n=>Float) : n=>Float =
def psdsolve [Eq n] (mat:n=>n=>Float) (b:n=>Float) : n=>Float =
l = chol mat
trisolveU (transpose l) $ trisolveL l b

Expand Down
7 changes: 5 additions & 2 deletions examples/ctc.dx
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,11 @@ def logaddexp (x:Float) (y:Float) : Float =
m = max x y
m + ( log ( (exp (x - m) + exp (y - m))))

def ctc (dict: Eq vocab) ?=> (dict2: Eq position) ?=> (dict3: Eq time) ?=> (blank: vocab)
(logits: time=>vocab=>Float) (labels: position=>vocab) : Float =
def ctc [Eq vocab, Eq position, Eq time]
(blank: vocab)
(logits: time=>vocab=>Float)
(labels: position=>vocab)
: Float =
-- Computes log p(labels | logits), marginalizing over possible alignments.
-- Todo: remove unnecessary implicit type annotations once
-- Dex starts putting implicit types in scope.
Expand Down
20 changes: 10 additions & 10 deletions examples/fluidsim.dx
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ def incwrap (i:n) : n = -- Increment index, wrapping around at ends.
def decwrap (i:n) : n = -- Decrement index, wrapping around at ends.
asidx $ mod ((ordinal i) - 1) $ size n

def finite_difference_neighbours (_:Add a) ?=> (x:n=>a) : n=>a =
def finite_difference_neighbours [Add a] (x:n=>a) : n=>a =
for i. x.(incwrap i) - x.(decwrap i)

def add_neighbours (_:Add a) ?=> (x:n=>a) : n=>a =
def add_neighbours [Add a] (x:n=>a) : n=>a =
for i. x.(incwrap i) + x.(decwrap i)

def apply_along_axis1 (f:b=>a -> b=>a) (x:b=>c=>a) : b=>c=>a =
Expand All @@ -26,21 +26,21 @@ def apply_along_axis1 (f:b=>a -> b=>a) (x:b=>c=>a) : b=>c=>a =
def apply_along_axis2 (f:c=>a -> c=>a) (x:b=>c=>a) : b=>c=>a =
for i. f x.i

def fdx (_:Add a) ?=> (x:n=>m=>a) : (n=>m=>a) =
def fdx [Add a] (x:n=>m=>a) : (n=>m=>a) =
apply_along_axis1 finite_difference_neighbours x

def fdy (_:Add a) ?=> (x:n=>m=>a) : (n=>m=>a) =
def fdy [Add a] (x:n=>m=>a) : (n=>m=>a) =
apply_along_axis2 finite_difference_neighbours x

def divergence (_:Add a) ?=> (vx:n=>m=>a) (vy:n=>m=>a) : (n=>m=>a) =
def divergence [Add a] (vx:n=>m=>a) (vy:n=>m=>a) : (n=>m=>a) =
fdx vx + fdy vy

def add_neighbours_2d (_:Add a) ?=> (x:n=>m=>a) : (n=>m=>a) =
def add_neighbours_2d [Add a] (x:n=>m=>a) : (n=>m=>a) =
ax1 = apply_along_axis1 add_neighbours x
ax2 = apply_along_axis2 add_neighbours x
ax1 + ax2

def project (_:VSpace a) ?=> (v: n=>m=>(Fin 2)=>a) : n=>m=>(Fin 2)=>a =
def project [VSpace a] (v: n=>m=>(Fin 2)=>a) : n=>m=>(Fin 2)=>a =
-- Project the velocity field to be approximately mass-conserving,
-- using a few iterations of Gauss-Seidel.
h = 1.0 / IToF (size n)
Expand All @@ -60,13 +60,13 @@ def project (_:VSpace a) ?=> (v: n=>m=>(Fin 2)=>a) : n=>m=>(Fin 2)=>a =

for i j. [vx.i.j, vy.i.j] -- pack back into a table.

def bilinear_interp (_:VSpace a) ?=> (right_weight:Float) (bottom_weight:Float)
def bilinear_interp [VSpace a] (right_weight:Float) (bottom_weight:Float)
(topleft: a) (bottomleft: a) (topright: a) (bottomright: a) : a =
left = (1.0 - right_weight) .* ((1.0 - bottom_weight) .* topleft + bottom_weight .* bottomleft)
right = right_weight .* ((1.0 - bottom_weight) .* topright + bottom_weight .* bottomright)
left + right

def advect (_:VSpace a) ?=> (f: n=>m=>a) (v: n=>m=>(Fin 2)=>Float) : n=>m=>a =
def advect [VSpace a] (f: n=>m=>a) (v: n=>m=>(Fin 2)=>Float) : n=>m=>a =
-- Move field f according to x and y velocities (u and v)
-- using an implicit Euler integrator.

Expand Down Expand Up @@ -95,7 +95,7 @@ def advect (_:VSpace a) ?=> (f: n=>m=>a) (v: n=>m=>(Fin 2)=>Float) : n=>m=>a =
-- A convex weighting of the 4 surrounding cells.
bilinear_interp right_weight bottom_weight f.l.t f.l.b f.r.t f.r.b

def fluidsim (_: VSpace a) ?=> (num_steps: Int) (color_init: n=>m=>a)
def fluidsim [ VSpace a] (num_steps: Int) (color_init: n=>m=>a)
(v: n=>m=>(Fin 2)=>Float) : (Fin num_steps)=>n=>m=>a =
withState (color_init, v) \state.
for i:(Fin num_steps).
Expand Down
26 changes: 13 additions & 13 deletions examples/linear_algebra.dx
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
'## LU Decomposition and Matrix Inversion

def identity_matrix (_:Eq n) ?=> (_:Add a) ?=> (_:Mul a) ?=> : n=>n=>a =
def identity_matrix [Eq n, Add a, Mul a] : n=>n=>a =
for i j. select (i == j) one zero

'### Triangular matrices
Expand All @@ -11,15 +11,15 @@ def UpperTriMat (n:Type) (v:Type) : Type = i:n=>(i..)=>v
def upperTriDiag (u:UpperTriMat n v) : n=>v = for i. u.i.(0@_)
def lowerTriDiag (l:LowerTriMat n v) : n=>v = for i. l.i.((ordinal i)@_)

def forward_substitute (_:VSpace v) ?=> (a:LowerTriMat n Float) (b:n=>v) : n=>v =
def forward_substitute [VSpace v] (a:LowerTriMat n Float) (b:n=>v) : n=>v =
-- Solves lower triangular linear system (inverse a) **. b
yieldState zero \sRef.
for i:n.
s = sum for k:(..<i). -- dot product
a.i.((ordinal k)@_) .* get sRef!(%inject k)
sRef!i := (b.i - s) / a.i.((ordinal i)@_)

def backward_substitute (_:VSpace v) ?=> (a:UpperTriMat n Float) (b:n=>v) : n=>v =
def backward_substitute [VSpace v] (a:UpperTriMat n Float) (b:n=>v) : n=>v =
-- Solves upper triangular linear system (inverse a) **. b
yieldState zero \sRef.
rof i:n.
Expand Down Expand Up @@ -61,7 +61,7 @@ def permSign ((_, sign):Permutation n) : PermutationSign = sign

'### LU decomposition functions

def pivotize (_:Eq n) ?=> (a:n=>n=>Float) : Permutation n =
def pivotize [Eq n] (a:n=>n=>Float) : Permutation n =
-- Gives a row permutation that makes Gaussian elimination more stable.
yieldState identity_permutation \permRef.
for j:n.
Expand All @@ -71,7 +71,7 @@ def pivotize (_:Eq n) ?=> (a:n=>n=>Float) : Permutation n =
True -> ()
False -> swapInPlace permRef j row_with_largest

def lu (_:Eq n) ?=> (a: n=>n=>Float) :
def lu [Eq n] (a: n=>n=>Float) :
(LowerTriMat n Float & UpperTriMat n Float & Permutation n) =
-- Computes lower, upper, and permuntation matrices from a square matrix,
-- such that apply_permutation permutation a == lower ** upper.
Expand Down Expand Up @@ -113,10 +113,10 @@ def lu (_:Eq n) ?=> (a: n=>n=>Float) :
ukj = get (upperTriIndex uRef k')!(((ordinal j) - (ordinal k))@_)
lik = get (lowerTriIndex lRef i')!((ordinal k)@_)
ukj * lik

uijRef = (upperTriIndex uRef i')!(((ordinal j) - (ordinal i))@_)
uijRef := a.(%inject i).j - s

for i:(j<..).
i' = %inject i
s = sum for k:(..j).
Expand All @@ -125,7 +125,7 @@ def lu (_:Eq n) ?=> (a: n=>n=>Float) :
ukj = get (upperTriIndex uRef k')!i''
lik = get (lowerTriIndex lRef i')!((ordinal k)@_)
ukj * lik

i'' = ((ordinal i) + (ordinal j) + 1)@_
ujj = get (upperTriIndex uRef j)!(0@_)
lijRef = (lowerTriIndex lRef i'')!((ordinal j)@_)
Expand All @@ -135,7 +135,7 @@ def lu (_:Eq n) ?=> (a: n=>n=>Float) :

'### General linear algebra functions.

def solve (_:Eq n) ?=> (_:VSpace v) ?=> (a:n=>n=>Float) (b:n=>v) : n=>v =
def solve [Eq n, VSpace v] (a:n=>n=>Float) (b:n=>v) : n=>v =
-- There's a small speedup possible by exploiting the fact
-- that l always has ones on the diagonal. It would just require a
-- custom forward_substitute routine that doesn't divide
Expand All @@ -145,18 +145,18 @@ def solve (_:Eq n) ?=> (_:VSpace v) ?=> (a:n=>n=>Float) (b:n=>v) : n=>v =
y = forward_substitute l b'
backward_substitute u y

def invert (_:Eq n) ?=> (a:n=>n=>Float) : n=>n=>Float =
def invert [Eq n] (a:n=>n=>Float) : n=>n=>Float =
solve a identity_matrix

def determinant (_:Eq n) ?=> (a:n=>n=>Float) : Float =
def determinant [Eq n] (a:n=>n=>Float) : Float =
(l, u, perm) = lu a
prod (for i. (upperTriDiag u).i * (lowerTriDiag l).i) * permSign perm

def sign_and_log_determinant (_:Eq n) ?=> (a:n=>n=>Float) : (Float & Float) =
def sign_and_log_determinant [Eq n] (a:n=>n=>Float) : (Float & Float) =
(l, u, perm) = lu a
diags = for i. (upperTriDiag u).i * (lowerTriDiag l).i
sign = (permSign perm) * prod for i. sign diags.i
sum_of_log_abs = sum for i. log (abs diags.i)
sum_of_log_abs = sum for i. log (abs diags.i)
(sign, sum_of_log_abs)


Expand Down
2 changes: 1 addition & 1 deletion examples/mcmc.dx
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def mhStep
HMCParams : Type = (Int & Float) -- leapfrog steps, step size

def leapfrogIntegrate
(_:VSpace a) ?=>
[VSpace a]
((nsteps, dt): HMCParams)
(logProb: a -> LogProb)
((x, p): (a & a))
Expand Down
6 changes: 3 additions & 3 deletions examples/ode-integrator.dx
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Time = Float
def length (x: d=>Float) : Float = sqrt $ sum for i. sq x.i
def (./) (x: d=>Float) (y: d=>Float) : d=>Float = for i. x.i / y.i

def fit_4th_order_polynomial (_:VSpace v) ?=>
def fit_4th_order_polynomial [VSpace v]
(z0:v) (z1:v) (z_mid:v) (dz0:v) (dz1:v) (dt:Time) : (Fin 5)=>v =
-- dz0 and dz1 are gradient evaluations.
a = -2. * dt .* dz0 + 2. * dt .* dz1 - 8. .* z0 - 8. .* z1 + 16. .* z_mid
Expand All @@ -26,7 +26,7 @@ dps_c_mid = [6025192743. /30085553152. /2., 0., 51252292925. /65400821598. /2.,
-2691868925. /45128329728. /2., 187940372067. /1594534317056. /2.,
-1776094331. /19743644256. /2., 11237099. /235043384. /2.]

def interp_fit_dopri (_:VSpace v) ?=>
def interp_fit_dopri [VSpace v]
(z0:v) (z1:v) (k:(Fin 7)=>v) (dt:Time) : (Fin 5)=>v =
-- Fit a polynomial to the results of a Runge-Kutta step.
z_mid = z0 + dt .* (dot dps_c_mid k)
Expand Down Expand Up @@ -64,7 +64,7 @@ c_error = [35. / 384. - 1951. / 21600., 0., 500. / 1113. - 22642. / 50085.,
125. / 192. - 451. / 720., -2187. / 6784. + 12231. / 42400.,
11. / 84. - 649. / 6300., -1. / 60.]

def runge_kutta_step (_:VSpace v) ?=> (func:v->Time->v)
def runge_kutta_step [VSpace v] (func:v->Time->v)
(z0:v) (f0:v) (t0:Time) (dt:Time) : (v & v & v & (Fin 7)=>v) =

evals_init = yieldState zero \r.
Expand Down
2 changes: 1 addition & 1 deletion examples/raytrace.dx
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def directionAndLength (x: d=>Float) : (d=>Float & Float) =
def randuniform (lower:Float) (upper:Float) (k:Key) : Float =
lower + (rand k) * (upper - lower)

def sampleAveraged (_:VSpace a) ?=> (sample:Key -> a) (n:Int) (k:Key) : a =
def sampleAveraged [VSpace a] (sample:Key -> a) (n:Int) (k:Key) : a =
yieldState zero \total.
for i:(Fin n).
total := get total + sample (ixkey k i) / IToF n
Expand Down
4 changes: 2 additions & 2 deletions examples/sgd.dx
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@

'## Stochastic Gradient Descent with Momentum

def sgd_step (dict: VSpace a) ?=> (step_size: Float) (decay: Float) (gradfunc: a -> Int -> a) (x: a) (m: a) (iter:Int) : (a & a) =
def sgd_step [VSpace a] (step_size: Float) (decay: Float) (gradfunc: a -> Int -> a) (x: a) (m: a) (iter:Int) : (a & a) =
g = gradfunc x iter
new_m = decay .* m + g
new_x = x - step_size .* new_m
(new_x, new_m)

-- In-place optimization loop.
def sgd (dict: VSpace a) ?=> (step_size:Float) (decay:Float) (num_steps:Int) (gradient: a -> Int -> a) (x0: a) : a =
def sgd [VSpace a] (step_size:Float) (decay:Float) (num_steps:Int) (gradient: a -> Int -> a) (x0: a) : a =
m0 = zero
(x_final, m_final) = yieldState (x0, m0) \state.
for i:(Fin num_steps).
Expand Down
4 changes: 2 additions & 2 deletions lib/diagram.dx
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def quote (s:String) : String = "\"" <.> s <.> "\""
def strSpaceCatUncurried ((s1,s2):(String & String)) : String =
s1 <.> " " <.> s2

def (<+>) (_:Show a) ?=> (_:Show b) ?=> (s1:a) (s2:b) : String =
def (<+>) [Show a, Show b] (s1:a) (s2:b) : String =
strSpaceCatUncurried ((show s1), (show s2))

def selfClosingBrackets (s:String) : String = "<" <.> s <.> "/>"
Expand All @@ -127,7 +127,7 @@ def tagBracketsAttrUncurried ((tag, attr, s):(String & String & String)) : Strin
def tagBracketsAttr (tag:String) (attr:String) (s:String) : String =
tagBracketsAttrUncurried (tag, attr, s)

def (<=>) (_:Show b) ?=> (attr:String) (val:b) : String =
def (<=>) [Show b] (attr:String) (val:b) : String =
attr <.> "=" <.> quote (show val)

def htmlColor(cs:HtmlColor) : String =
Expand Down
2 changes: 1 addition & 1 deletion lib/plot.dx
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def getScaled (sd:ScaledData n a) (i:n) : Maybe Float =
lowColor = [1.0, 0.5, 0.0]
highColor = [0.0, 0.5, 1.0]

def interpolate (_:VSpace a) ?=> (low:a) (high:a) (x:Float) : a =
def interpolate [VSpace a] (low:a) (high:a) (x:Float) : a =
x' = clip (0.0, 1.0) x
(x' .* low) + ((1.0 - x') .* high)

Expand Down
2 changes: 1 addition & 1 deletion lib/png.dx
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def decodeChunk (chunk : Fin 4 => Char) : Maybe (Fin 3 => Char) =
Just base64s -> Just $ base64sToBytes base64s

-- TODO: put this in prelude?
def replace (_:Eq a) ?=> ((old,new):(a&a)) (x:a) : a =
def replace [Eq a] ((old,new):(a&a)) (x:a) : a =
case x == old of
True -> new
False -> x
Expand Down
Loading