From 1f9b2ef54a451cc29b9c665250d3bcbfa47a633a Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Thu, 31 Dec 2020 17:28:10 -0500 Subject: [PATCH 1/6] Add new syntax for type class constraints on `def` decls. Fixes #370. --- examples/chol.dx | 4 +- examples/ctc.dx | 7 +- examples/fluidsim.dx | 20 +++--- examples/linear_algebra.dx | 26 ++++---- examples/mcmc.dx | 2 +- examples/ode-integrator.dx | 6 +- examples/raytrace.dx | 2 +- examples/sgd.dx | 4 +- lib/diagram.dx | 4 +- lib/plot.dx | 2 +- lib/png.dx | 2 +- lib/prelude.dx | 131 ++++++++++++++++++------------------- src/lib/Parser.hs | 30 ++++++--- 13 files changed, 126 insertions(+), 114 deletions(-) diff --git a/examples/chol.dx b/examples/chol.dx index 1d8c53d11..63473ba91 100644 --- a/examples/chol.dx +++ b/examples/chol.dx @@ -3,7 +3,7 @@ https://en.wikipedia.org/wiki/Cholesky_decomposition ' ## Cholesky Algorithm -def chol (_:Eq n) ?=> (x:n=>n=>Float) : (n=>n=>Float) = +def chol [Eq n] (x:n=>n=>Float) : (n=>n=>Float) = yieldState zero \buf. for_ i. for j':(..i). j = %inject(j') @@ -31,7 +31,7 @@ def trisolveU (mat:n=>n=>Float) (b:n=>Float) : n=>Float = xPrev = for j:(i..). get (buf!%inject j) buf!i := (b.i - vdot row xPrev) / mat.i.i -def psdsolve (_:Eq n) ?=> (mat:n=>n=>Float) (b:n=>Float) : n=>Float = +def psdsolve [Eq n] (mat:n=>n=>Float) (b:n=>Float) : n=>Float = l = chol mat trisolveU (transpose l) $ trisolveL l b diff --git a/examples/ctc.dx b/examples/ctc.dx index d0dc7979a..aa7b5fc77 100644 --- a/examples/ctc.dx +++ b/examples/ctc.dx @@ -48,8 +48,11 @@ def logaddexp (x:Float) (y:Float) : Float = m = max x y m + ( log ( (exp (x - m) + exp (y - m)))) -def ctc (dict: Eq vocab) ?=> (dict2: Eq position) ?=> (dict3: Eq time) ?=> (blank: vocab) - (logits: time=>vocab=>Float) (labels: position=>vocab) : Float = +def ctc [Eq vocab, Eq position, Eq time] + (blank: vocab) + (logits: time=>vocab=>Float) + (labels: position=>vocab) + : Float = -- Computes log p(labels | logits), marginalizing over possible alignments. -- Todo: remove unnecessary implicit type annotations once -- Dex starts putting implicit types in scope. diff --git a/examples/fluidsim.dx b/examples/fluidsim.dx index 227817517..e1ea7e9ac 100644 --- a/examples/fluidsim.dx +++ b/examples/fluidsim.dx @@ -14,10 +14,10 @@ def incwrap (i:n) : n = -- Increment index, wrapping around at ends. def decwrap (i:n) : n = -- Decrement index, wrapping around at ends. asidx $ mod ((ordinal i) - 1) $ size n -def finite_difference_neighbours (_:Add a) ?=> (x:n=>a) : n=>a = +def finite_difference_neighbours [Add a] (x:n=>a) : n=>a = for i. x.(incwrap i) - x.(decwrap i) -def add_neighbours (_:Add a) ?=> (x:n=>a) : n=>a = +def add_neighbours [Add a] (x:n=>a) : n=>a = for i. x.(incwrap i) + x.(decwrap i) def apply_along_axis1 (f:b=>a -> b=>a) (x:b=>c=>a) : b=>c=>a = @@ -26,21 +26,21 @@ def apply_along_axis1 (f:b=>a -> b=>a) (x:b=>c=>a) : b=>c=>a = def apply_along_axis2 (f:c=>a -> c=>a) (x:b=>c=>a) : b=>c=>a = for i. f x.i -def fdx (_:Add a) ?=> (x:n=>m=>a) : (n=>m=>a) = +def fdx [Add a] (x:n=>m=>a) : (n=>m=>a) = apply_along_axis1 finite_difference_neighbours x -def fdy (_:Add a) ?=> (x:n=>m=>a) : (n=>m=>a) = +def fdy [Add a] (x:n=>m=>a) : (n=>m=>a) = apply_along_axis2 finite_difference_neighbours x -def divergence (_:Add a) ?=> (vx:n=>m=>a) (vy:n=>m=>a) : (n=>m=>a) = +def divergence [Add a] (vx:n=>m=>a) (vy:n=>m=>a) : (n=>m=>a) = fdx vx + fdy vy -def add_neighbours_2d (_:Add a) ?=> (x:n=>m=>a) : (n=>m=>a) = +def add_neighbours_2d [Add a] (x:n=>m=>a) : (n=>m=>a) = ax1 = apply_along_axis1 add_neighbours x ax2 = apply_along_axis2 add_neighbours x ax1 + ax2 -def project (_:VSpace a) ?=> (v: n=>m=>(Fin 2)=>a) : n=>m=>(Fin 2)=>a = +def project [VSpace a] (v: n=>m=>(Fin 2)=>a) : n=>m=>(Fin 2)=>a = -- Project the velocity field to be approximately mass-conserving, -- using a few iterations of Gauss-Seidel. h = 1.0 / IToF (size n) @@ -60,13 +60,13 @@ def project (_:VSpace a) ?=> (v: n=>m=>(Fin 2)=>a) : n=>m=>(Fin 2)=>a = for i j. [vx.i.j, vy.i.j] -- pack back into a table. -def bilinear_interp (_:VSpace a) ?=> (right_weight:Float) (bottom_weight:Float) +def bilinear_interp [VSpace a] (right_weight:Float) (bottom_weight:Float) (topleft: a) (bottomleft: a) (topright: a) (bottomright: a) : a = left = (1.0 - right_weight) .* ((1.0 - bottom_weight) .* topleft + bottom_weight .* bottomleft) right = right_weight .* ((1.0 - bottom_weight) .* topright + bottom_weight .* bottomright) left + right -def advect (_:VSpace a) ?=> (f: n=>m=>a) (v: n=>m=>(Fin 2)=>Float) : n=>m=>a = +def advect [VSpace a] (f: n=>m=>a) (v: n=>m=>(Fin 2)=>Float) : n=>m=>a = -- Move field f according to x and y velocities (u and v) -- using an implicit Euler integrator. @@ -95,7 +95,7 @@ def advect (_:VSpace a) ?=> (f: n=>m=>a) (v: n=>m=>(Fin 2)=>Float) : n=>m=>a = -- A convex weighting of the 4 surrounding cells. bilinear_interp right_weight bottom_weight f.l.t f.l.b f.r.t f.r.b -def fluidsim (_: VSpace a) ?=> (num_steps: Int) (color_init: n=>m=>a) +def fluidsim [ VSpace a] (num_steps: Int) (color_init: n=>m=>a) (v: n=>m=>(Fin 2)=>Float) : (Fin num_steps)=>n=>m=>a = withState (color_init, v) \state. for i:(Fin num_steps). diff --git a/examples/linear_algebra.dx b/examples/linear_algebra.dx index ec6a6e5e9..2d0cffd14 100644 --- a/examples/linear_algebra.dx +++ b/examples/linear_algebra.dx @@ -1,6 +1,6 @@ '## LU Decomposition and Matrix Inversion -def identity_matrix (_:Eq n) ?=> (_:Add a) ?=> (_:Mul a) ?=> : n=>n=>a = +def identity_matrix [Eq n, Add a, Mul a] : n=>n=>a = for i j. select (i == j) one zero '### Triangular matrices @@ -11,7 +11,7 @@ def UpperTriMat (n:Type) (v:Type) : Type = i:n=>(i..)=>v def upperTriDiag (u:UpperTriMat n v) : n=>v = for i. u.i.(0@_) def lowerTriDiag (l:LowerTriMat n v) : n=>v = for i. l.i.((ordinal i)@_) -def forward_substitute (_:VSpace v) ?=> (a:LowerTriMat n Float) (b:n=>v) : n=>v = +def forward_substitute [VSpace v] (a:LowerTriMat n Float) (b:n=>v) : n=>v = -- Solves lower triangular linear system (inverse a) **. b yieldState zero \sRef. for i:n. @@ -19,7 +19,7 @@ def forward_substitute (_:VSpace v) ?=> (a:LowerTriMat n Float) (b:n=>v) : n=>v a.i.((ordinal k)@_) .* get sRef!(%inject k) sRef!i := (b.i - s) / a.i.((ordinal i)@_) -def backward_substitute (_:VSpace v) ?=> (a:UpperTriMat n Float) (b:n=>v) : n=>v = +def backward_substitute [VSpace v] (a:UpperTriMat n Float) (b:n=>v) : n=>v = -- Solves upper triangular linear system (inverse a) **. b yieldState zero \sRef. rof i:n. @@ -61,7 +61,7 @@ def permSign ((_, sign):Permutation n) : PermutationSign = sign '### LU decomposition functions -def pivotize (_:Eq n) ?=> (a:n=>n=>Float) : Permutation n = +def pivotize [Eq n] (a:n=>n=>Float) : Permutation n = -- Gives a row permutation that makes Gaussian elimination more stable. yieldState identity_permutation \permRef. for j:n. @@ -71,7 +71,7 @@ def pivotize (_:Eq n) ?=> (a:n=>n=>Float) : Permutation n = True -> () False -> swapInPlace permRef j row_with_largest -def lu (_:Eq n) ?=> (a: n=>n=>Float) : +def lu [Eq n] (a: n=>n=>Float) : (LowerTriMat n Float & UpperTriMat n Float & Permutation n) = -- Computes lower, upper, and permuntation matrices from a square matrix, -- such that apply_permutation permutation a == lower ** upper. @@ -113,10 +113,10 @@ def lu (_:Eq n) ?=> (a: n=>n=>Float) : ukj = get (upperTriIndex uRef k')!(((ordinal j) - (ordinal k))@_) lik = get (lowerTriIndex lRef i')!((ordinal k)@_) ukj * lik - + uijRef = (upperTriIndex uRef i')!(((ordinal j) - (ordinal i))@_) uijRef := a.(%inject i).j - s - + for i:(j<..). i' = %inject i s = sum for k:(..j). @@ -125,7 +125,7 @@ def lu (_:Eq n) ?=> (a: n=>n=>Float) : ukj = get (upperTriIndex uRef k')!i'' lik = get (lowerTriIndex lRef i')!((ordinal k)@_) ukj * lik - + i'' = ((ordinal i) + (ordinal j) + 1)@_ ujj = get (upperTriIndex uRef j)!(0@_) lijRef = (lowerTriIndex lRef i'')!((ordinal j)@_) @@ -135,7 +135,7 @@ def lu (_:Eq n) ?=> (a: n=>n=>Float) : '### General linear algebra functions. -def solve (_:Eq n) ?=> (_:VSpace v) ?=> (a:n=>n=>Float) (b:n=>v) : n=>v = +def solve [Eq n, VSpace v] (a:n=>n=>Float) (b:n=>v) : n=>v = -- There's a small speedup possible by exploiting the fact -- that l always has ones on the diagonal. It would just require a -- custom forward_substitute routine that doesn't divide @@ -145,18 +145,18 @@ def solve (_:Eq n) ?=> (_:VSpace v) ?=> (a:n=>n=>Float) (b:n=>v) : n=>v = y = forward_substitute l b' backward_substitute u y -def invert (_:Eq n) ?=> (a:n=>n=>Float) : n=>n=>Float = +def invert [Eq n] (a:n=>n=>Float) : n=>n=>Float = solve a identity_matrix -def determinant (_:Eq n) ?=> (a:n=>n=>Float) : Float = +def determinant [Eq n] (a:n=>n=>Float) : Float = (l, u, perm) = lu a prod (for i. (upperTriDiag u).i * (lowerTriDiag l).i) * permSign perm -def sign_and_log_determinant (_:Eq n) ?=> (a:n=>n=>Float) : (Float & Float) = +def sign_and_log_determinant [Eq n] (a:n=>n=>Float) : (Float & Float) = (l, u, perm) = lu a diags = for i. (upperTriDiag u).i * (lowerTriDiag l).i sign = (permSign perm) * prod for i. sign diags.i - sum_of_log_abs = sum for i. log (abs diags.i) + sum_of_log_abs = sum for i. log (abs diags.i) (sign, sum_of_log_abs) diff --git a/examples/mcmc.dx b/examples/mcmc.dx index 1ba161c85..a3bcbd314 100644 --- a/examples/mcmc.dx +++ b/examples/mcmc.dx @@ -55,7 +55,7 @@ def mhStep HMCParams : Type = (Int & Float) -- leapfrog steps, step size def leapfrogIntegrate - (_:VSpace a) ?=> + [VSpace a] ((nsteps, dt): HMCParams) (logProb: a -> LogProb) ((x, p): (a & a)) diff --git a/examples/ode-integrator.dx b/examples/ode-integrator.dx index f553fca91..53e568d5a 100644 --- a/examples/ode-integrator.dx +++ b/examples/ode-integrator.dx @@ -12,7 +12,7 @@ Time = Float def length (x: d=>Float) : Float = sqrt $ sum for i. sq x.i def (./) (x: d=>Float) (y: d=>Float) : d=>Float = for i. x.i / y.i -def fit_4th_order_polynomial (_:VSpace v) ?=> +def fit_4th_order_polynomial [VSpace v] (z0:v) (z1:v) (z_mid:v) (dz0:v) (dz1:v) (dt:Time) : (Fin 5)=>v = -- dz0 and dz1 are gradient evaluations. a = -2. * dt .* dz0 + 2. * dt .* dz1 - 8. .* z0 - 8. .* z1 + 16. .* z_mid @@ -26,7 +26,7 @@ dps_c_mid = [6025192743. /30085553152. /2., 0., 51252292925. /65400821598. /2., -2691868925. /45128329728. /2., 187940372067. /1594534317056. /2., -1776094331. /19743644256. /2., 11237099. /235043384. /2.] -def interp_fit_dopri (_:VSpace v) ?=> +def interp_fit_dopri [VSpace v] (z0:v) (z1:v) (k:(Fin 7)=>v) (dt:Time) : (Fin 5)=>v = -- Fit a polynomial to the results of a Runge-Kutta step. z_mid = z0 + dt .* (dot dps_c_mid k) @@ -64,7 +64,7 @@ c_error = [35. / 384. - 1951. / 21600., 0., 500. / 1113. - 22642. / 50085., 125. / 192. - 451. / 720., -2187. / 6784. + 12231. / 42400., 11. / 84. - 649. / 6300., -1. / 60.] -def runge_kutta_step (_:VSpace v) ?=> (func:v->Time->v) +def runge_kutta_step [VSpace v] (func:v->Time->v) (z0:v) (f0:v) (t0:Time) (dt:Time) : (v & v & v & (Fin 7)=>v) = evals_init = yieldState zero \r. diff --git a/examples/raytrace.dx b/examples/raytrace.dx index fc44b7054..051722f56 100644 --- a/examples/raytrace.dx +++ b/examples/raytrace.dx @@ -24,7 +24,7 @@ def directionAndLength (x: d=>Float) : (d=>Float & Float) = def randuniform (lower:Float) (upper:Float) (k:Key) : Float = lower + (rand k) * (upper - lower) -def sampleAveraged (_:VSpace a) ?=> (sample:Key -> a) (n:Int) (k:Key) : a = +def sampleAveraged [VSpace a] (sample:Key -> a) (n:Int) (k:Key) : a = yieldState zero \total. for i:(Fin n). total := get total + sample (ixkey k i) / IToF n diff --git a/examples/sgd.dx b/examples/sgd.dx index 3e5a5575a..bc1a0cb29 100644 --- a/examples/sgd.dx +++ b/examples/sgd.dx @@ -1,14 +1,14 @@ '## Stochastic Gradient Descent with Momentum -def sgd_step (dict: VSpace a) ?=> (step_size: Float) (decay: Float) (gradfunc: a -> Int -> a) (x: a) (m: a) (iter:Int) : (a & a) = +def sgd_step [VSpace a] (step_size: Float) (decay: Float) (gradfunc: a -> Int -> a) (x: a) (m: a) (iter:Int) : (a & a) = g = gradfunc x iter new_m = decay .* m + g new_x = x - step_size .* new_m (new_x, new_m) -- In-place optimization loop. -def sgd (dict: VSpace a) ?=> (step_size:Float) (decay:Float) (num_steps:Int) (gradient: a -> Int -> a) (x0: a) : a = +def sgd [VSpace a] (step_size:Float) (decay:Float) (num_steps:Int) (gradient: a -> Int -> a) (x0: a) : a = m0 = zero (x_final, m_final) = yieldState (x0, m0) \state. for i:(Fin num_steps). diff --git a/lib/diagram.dx b/lib/diagram.dx index 4e91ddb9d..98ae4e60e 100644 --- a/lib/diagram.dx +++ b/lib/diagram.dx @@ -112,7 +112,7 @@ def quote (s:String) : String = "\"" <.> s <.> "\"" def strSpaceCatUncurried ((s1,s2):(String & String)) : String = s1 <.> " " <.> s2 -def (<+>) (_:Show a) ?=> (_:Show b) ?=> (s1:a) (s2:b) : String = +def (<+>) [Show a, Show b] (s1:a) (s2:b) : String = strSpaceCatUncurried ((show s1), (show s2)) def selfClosingBrackets (s:String) : String = "<" <.> s <.> "/>" @@ -127,7 +127,7 @@ def tagBracketsAttrUncurried ((tag, attr, s):(String & String & String)) : Strin def tagBracketsAttr (tag:String) (attr:String) (s:String) : String = tagBracketsAttrUncurried (tag, attr, s) -def (<=>) (_:Show b) ?=> (attr:String) (val:b) : String = +def (<=>) [Show b] (attr:String) (val:b) : String = attr <.> "=" <.> quote (show val) def htmlColor(cs:HtmlColor) : String = diff --git a/lib/plot.dx b/lib/plot.dx index 0212ad537..4529435fb 100644 --- a/lib/plot.dx +++ b/lib/plot.dx @@ -49,7 +49,7 @@ def getScaled (sd:ScaledData n a) (i:n) : Maybe Float = lowColor = [1.0, 0.5, 0.0] highColor = [0.0, 0.5, 1.0] -def interpolate (_:VSpace a) ?=> (low:a) (high:a) (x:Float) : a = +def interpolate [VSpace a] (low:a) (high:a) (x:Float) : a = x' = clip (0.0, 1.0) x (x' .* low) + ((1.0 - x') .* high) diff --git a/lib/png.dx b/lib/png.dx index 131f7c609..b542449d5 100644 --- a/lib/png.dx +++ b/lib/png.dx @@ -72,7 +72,7 @@ def decodeChunk (chunk : Fin 4 => Char) : Maybe (Fin 3 => Char) = Just base64s -> Just $ base64sToBytes base64s -- TODO: put this in prelude? -def replace (_:Eq a) ?=> ((old,new):(a&a)) (x:a) : a = +def replace [Eq a] ((old,new):(a&a)) (x:a) : a = case x == old of True -> new False -> x diff --git a/lib/prelude.dx b/lib/prelude.dx index 3e00ac2e0..54f7aefef 100644 --- a/lib/prelude.dx +++ b/lib/prelude.dx @@ -45,8 +45,8 @@ interface Add a:Type where sub : a -> a -> a zero : a -def (+) (d:Add a) ?=> : a -> a -> a = add -def (-) (d:Add a) ?=> : a -> a -> a = sub +def (+) [Add a] : a -> a -> a = add +def (-) [Add a] : a -> a -> a = sub instance float64Add : Add Float64 where add = \x:Float64 y:Float64. %fadd x y @@ -87,7 +87,7 @@ interface Mul a:Type where mul : a -> a -> a one : a -def (*) (d:Mul a) ?=> : a -> a -> a = mul +def (*) [Mul a] : a -> a -> a = mul instance float64Mul : Mul Float64 where mul = \x:Float64 y:Float64. %fmul x y @@ -162,10 +162,10 @@ data VSpace a:Type = MkVSpace (Add a) (Float -> a -> a) @superclass def addFromVSpace (d:VSpace a) : Add a = case d of MkVSpace addDict _ -> addDict -def (.*) (d:VSpace a) ?=> : Float -> a -> a = case d of MkVSpace _ scale -> scale -(*.) : VSpace a ?=> a -> Float -> a = flip (.*) -def (/) (_:VSpace a) ?=> (v:a) (s:Float) : a = (divide 1.0 s) .* v -def neg (_:VSpace a) ?=> (v:a) : a = (-1.0) .* v +def (.*) (d:VSpace a) ?=> : Float -> a -> a = case d of MkVSpace _ scale -> scale +def (*.) [VSpace a] : a -> Float -> a = flip (.*) +def (/) [VSpace a] (v:a) (s:Float) : a = divide 1.0 s .* v +def neg [VSpace a] (v:a) : a = (-1.0) .* v @instance floatVS : VSpace Float = MkVSpace float32Add (*) @instance tabVS : VSpace a ?=> VSpace (n=>a) = MkVSpace tabAdd \s xs. for i. s .* xs.i @@ -292,12 +292,12 @@ data Ord a:Type = MkOrd (Eq a) (a -> a -> Bool) (a -> a -> Bool) -- eq, gt, lt def eqFromOrd (d:Ord a) : Eq a = case d of MkOrd eq _ _ -> eq def (==) (d:Eq a) ?=> (x:a) (y:a) : Bool = case d of MkEq eq -> eq x y -def (/=) (d:Eq a) ?=> (x:a) (y:a) : Bool = not $ x == y +def (/=) [Eq a] (x:a) (y:a) : Bool = not $ x == y def (>) (d:Ord a) ?=> (x:a) (y:a) : Bool = case d of MkOrd _ gt _ -> gt x y def (<) (d:Ord a) ?=> (x:a) (y:a) : Bool = case d of MkOrd _ _ lt -> lt x y -def (<=) (d:Ord a) ?=> (x:a) (y:a) : Bool = x=) (d:Ord a) ?=> (x:a) (y:a) : Bool = x>y || x==y +def (<=) [Ord a] (x:a) (y:a) : Bool = x=) [Ord a] (x:a) (y:a) : Bool = x>y || x==y @instance float64Eq : Eq Float64 = MkEq \x:Float64 y:Float64. W8ToB $ %feq x y @instance float32Eq : Eq Float32 = MkEq \x:Float32 y:Float32. W8ToB $ %feq x y @@ -321,19 +321,18 @@ def (>=) (d:Ord a) ?=> (x:a) (y:a) : Bool = x>y || x==y @instance unitOrd : Ord Unit = (MkOrd unitEq (\x y. False) (\x y. False)) @instance -def pairEq (eqA: Eq a)?=> (eqB: Eq b)?=> : Eq (a & b) = MkEq $ +def pairEq [Eq a, Eq b] : Eq (a & b) = MkEq $ \(x1,x2) (y1,y2). x1 == y1 && x2 == y2 @instance -def pairOrd (ordA: Ord a)?=> (ordB: Ord b)?=> : Ord (a & b) = +def pairOrd [Ord a, Ord b] : Ord (a & b) = pairGt = \(x1,x2) (y1,y2). x1 > y1 || (x1 == y1 && x2 > y2) pairLt = \(x1,x2) (y1,y2). x1 < y1 || (x1 == y1 && x2 < y2) MkOrd pairEq pairGt pairLt - -- TODO: accumulate using the True/&& monoid @instance -def tabEq (n:Type) ?-> (eqA: Eq a) ?=> : Eq (n=>a) = MkEq $ +def tabEq [Eq a] : Eq (n=>a) = MkEq $ \xs ys. numDifferent : Float = yieldAccum \ref. for i. @@ -362,7 +361,7 @@ interface Floating a:Type where pow : a -> a -> a lgamma : a -> a -def lbeta (_ : Add a) ?=> (_ : Floating a) ?=> : a -> a -> a = \x y. lgamma x + lgamma y - lgamma (x + y) +def lbeta [Add a, Floating a] : a -> a -> a = \x y. lgamma x + lgamma y - lgamma (x + y) -- Todo: better numerics for very large and small values. -- Using %exp here to avoid circular definition problems. @@ -468,28 +467,28 @@ instance int32Storable : Storable Int32 where load = int32Load storageSize = const 4 -def unpackPairPtr (_:Storable a) ?=> (_:Storable b) ?=> +def unpackPairPtr [Storable a, Storable b] (pairPtr: Ptr (a & b)) : (Ptr a & Ptr b) = (MkPtr rawPtrX) = pairPtr rawPtrY = %ptrOffset rawPtrX (storageSize (typeVehicle a)) (MkPtr rawPtrX, MkPtr rawPtrY) -def pairStore (_:Storable a) ?=> (_:Storable b) ?=> +def pairStore [Storable a, Storable b] (pairPtr:Ptr (a & b)) ((x, y):(a & b)) : {State World} Unit = (xPtr, yPtr) = unpackPairPtr pairPtr store xPtr x store yPtr y -def pairLoad (_:Storable a) ?=> (_:Storable b) ?=> +def pairLoad [Storable a, Storable b] (pairPtr:Ptr (a & b)) : {State World} (a & b) = (xPtr, yPtr) = unpackPairPtr pairPtr (load xPtr, load yPtr) -def pairStorageSize (_:Storable a) ?=> (_:Storable b) ?=> +def pairStorageSize [Storable a, Storable b] (_:TypeVehicle (a & b)) : Int = storageSize (typeVehicle a) + storageSize (typeVehicle b) -instance pairStorable : Storable a ?=> Storable b ?=> Storable (a & b) where +instance pairStorable : (Storable a) ?=> (Storable b) ?=> Storable (a & b) where store = pairStore load = pairLoad storageSize = pairStorageSize @@ -508,7 +507,7 @@ instance ptrStorable : Storable (Ptr a) where -- TODO: Storable instances for other types -def malloc (_:Storable a) ?=> (n:Int) : {State World} (Ptr a) = +def malloc [Storable a] (n:Int) : {State World} (Ptr a) = numBytes = storageSize (typeVehicle a) * n MkPtr $ %alloc numBytes @@ -516,7 +515,7 @@ def free (ptr:Ptr a) : {State World} Unit = (MkPtr ptr') = ptr %free ptr' -def (+>>) (_:Storable a) ?=> (ptr:Ptr a) (i:Int) : Ptr a = +def (+>>) [Storable a] (ptr:Ptr a) (i:Int) : Ptr a = (MkPtr ptr') = ptr i' = i * storageSize (typeVehicle a) MkPtr $ %ptrOffset ptr' i' @@ -524,28 +523,28 @@ def (+>>) (_:Storable a) ?=> (ptr:Ptr a) (i:Int) : Ptr a = -- TODO: generalize these brackets to allow other effects -- TODO: consider making a Storable instance for tables instead -def storeTab (_:Storable a) ?=> (ptr: Ptr a) (tab:n=>a) : {State World} Unit = +def storeTab [Storable a] (ptr: Ptr a) (tab:n=>a) : {State World} Unit = for_ i. store (ptr +>> ordinal i) tab.i -def memcpy (_:Storable a) ?=> (dest:Ptr a) (src:Ptr a) (n:Int) : {State World} Unit = +def memcpy [Storable a] (dest:Ptr a) (src:Ptr a) (n:Int) : {State World} Unit = for_ i:(Fin n). i' = ordinal i store (dest +>> i') (load $ src +>> i') -def withAlloc (_:Storable a) ?=> +def withAlloc [Storable a] (n:Int) (action: Ptr a -> {State World} b) : {State World} b = ptr = malloc n result = action ptr free ptr result -def withTabPtr (_:Storable a) ?=> +def withTabPtr [Storable a] (xs:n=>a) (action : Ptr a -> {State World} b) : {State World} b = withAlloc (size n) \ptr. for i. store (ptr +>> ordinal i) xs.i action ptr -def tabFromPtr (_:Storable a) ?=> (n:Type) -> (ptr:Ptr a) : {State World} n=>a = +def tabFromPtr [Storable a] (n:Type) -> (ptr:Ptr a) : {State World} n=>a = for i. load $ ptr +>> ordinal i '## Miscellaneous common utilities @@ -558,8 +557,8 @@ def map (f:a->{|eff} b) (xs: n=>a) : {|eff} (n=>b) = for i. f xs.i def zip (xs:n=>a) (ys:n=>b) : (n=>(a&b)) = view i. (xs.i, ys.i) def unzip (xys:n=>(a&b)) : (n=>a & n=>b) = (map fst xys, map snd xys) def fanout (n:Type) (x:a) : n=>a = view i. x -def sq (d:Mul a) ?=> (x:a) : a = x * x -def abs (_:Add a) ?=> (_:Ord a) ?=> (x:a) : a = select (x > zero) x (zero - x) +def sq [Mul a] (x:a) : a = x * x +def abs [Add a, Ord a] (x:a) : a = select (x > zero) x (zero - x) def mod (x:Int) (y:Int) : Int = rem (y + rem x y) y def reindex (ixr: b -> a) (tab: a=>v) : b=>v = for i. tab.(ixr i) @@ -582,9 +581,9 @@ def reduce (identity:a) (combine:(a->a->a)) (xs:n=>a) : a = def scan' (init:a) (body:n->a->a) : n=>a = snd $ scan init \i x. dup (body i x) -- TODO: allow tables-via-lambda and get rid of this def fsum (xs:n=>Float) : Float = yieldAccum \ref. for i. ref += xs i -def sum (_: Add v) ?=> (xs:n=>v) : v = reduce zero (+) xs -def prod (_: Mul v) ?=> (xs:n=>v) : v = reduce one (*) xs -def mean (n:Type) ?-> (xs:n=>Float) : Float = sum xs / IToF (size n) +def sum [Add v] (xs:n=>v) : v = reduce zero (+) xs +def prod [Mul v] (xs:n=>v) : v = reduce one (*) xs +def mean (xs:n=>Float) : Float = sum xs / IToF (size n) def std (xs:n=>Float) : Float = sqrt $ mean (map sq xs) - sq (mean xs) def any (xs:n=>Bool) : Bool = reduce False (||) xs def all (xs:n=>Bool) : Bool = reduce True (&&) xs @@ -599,7 +598,7 @@ def linspace (n:Type) (low:Float) (high:Float) : n=>Float = def transpose (x:n=>m=>a) : m=>n=>a = view i j. x.j.i def vdot (x:n=>Float) (y:n=>Float) : Float = fsum view i. x.i * y.i -def dot (_:VSpace v) ?=> (s:n=>Float) (vs:n=>v) : v = sum for j. s.j .* vs.j +def dot [VSpace v] (s:n=>Float) (vs:n=>v) : v = sum for j. s.j .* vs.j -- matmul. Better symbol to use? `@`? (**) : (l=>m=>Float) -> (m=>n=>Float) -> (l=>n=>Float) = \x y. @@ -611,7 +610,7 @@ def dot (_:VSpace v) ?=> (s:n=>Float) (vs:n=>v) : v = sum for j. s.j .* vs.j def inner (x:n=>Float) (mat:n=>m=>Float) (y:m=>Float) : Float = fsum view (i,j). x.i * mat.i.j * y.j -def eye (_:Eq n) ?=> : n=>n=>Float = +def eye [Eq n] : n=>n=>Float = for i j. select (i == j) 1.0 0.0 '## Pseudorandom number generator utilities @@ -645,7 +644,7 @@ def randInt (k:Key) : Int = (I64ToI k) `mod` 2147483647 def bern (p:Float) (k:Key) : Bool = rand k < p -def randnVec (n:Type) ?-> (k:Key) : n=>Float = +def randnVec (k:Key) : n=>Float = for i. randn (ixkey k i) def cumSum (xs: n=>Float) : n=>Float = @@ -679,7 +678,7 @@ interface HasDefaultTolerance a:Type where atol : a rtol : a -def (~~) (_:HasAllClose a) ?=> (d:HasDefaultTolerance a) ?=> : a -> a -> Bool = allclose atol rtol +def (~~) [HasAllClose a, HasDefaultTolerance a] : a -> a -> Bool = allclose atol rtol instance allCloseF32 : HasAllClose Float32 where allclose = \atol rtol x y. abs (x - y) <= (atol + rtol * abs y) @@ -758,14 +757,12 @@ def Tile (n : Type) (m : Type) : Type = %IndexSlice n m -- elements of n. In this view (+>) is just function application, while ++> -- is currying followed by function application. We cannot represent currying -- in isolation, because `Tile n (Tile u v)` does not make sense, unlike `Tile n (u & v)`. -def (+>) (l : Type) ?-> (t:Tile n l) (i : l) : n = %sliceOffset t i +def (+>) (t:Tile n l) (i : l) : n = %sliceOffset t i def (++>) (t : Tile n (u & v)) (i : u) : Tile n v = %sliceCurry t i -def tile (l : Type) ?-> - (fTile : (t:(Tile n l) -> {|eff} l=>a)) +def tile (fTile : (t:(Tile n l) -> {|eff} l=>a)) (fScalar : n -> {|eff} a) : {|eff} n=>a = %tiled fTile fScalar -def tile1 (n : Type) ?-> (l : Type) ?-> (m : Type) ?-> - (fTile : (t:(Tile n l) -> {|eff} m=>l=>a)) +def tile1 (fTile : (t:(Tile n l) -> {|eff} m=>l=>a)) (fScalar : n -> {|eff} m=>a) : {|eff} m=>n=>a = %tiledd fTile fScalar -- TODO: This should become just `loadVector $ for i. arr.(t +> i)` @@ -783,7 +780,7 @@ interface Monoid a:Type where mempty : a mcombine : a -> a -> a -- can't use `<>` just for parser reasons? -(<>) : Monoid a ?=> a -> a -> a = mcombine +def (<>) [Monoid a] : a -> a -> a = mcombine '## Length-erased lists @@ -793,7 +790,7 @@ data List a:Type = def unsafeCastTable (m:Type) (xs:n=>a) : m=>a = for i. xs.(unsafeFromOrdinal _ (ordinal i)) -def toList (n:Type) ?-> (xs:n=>a) : List a = +def toList (xs:n=>a) : List a = n' = size n AsList _ $ unsafeCastTable (Fin n') xs @@ -895,7 +892,7 @@ def sliceFields (iso: Iso ({|} | a) (b | c)) (tab: a=>v) : b=>v = -- TODO: would be nice to be able to use records here data DynBuffer a:Type = MkDynBuffer (Ptr (Int & Int & Ptr a)) -- size, max size, buf ptr -def withDynamicBuffer (_:Storable a) ?=> +def withDynamicBuffer [Storable a] (action: DynBuffer a -> {State World} b) : {State World} b = initMaxSize = 256 withAlloc 1 \dbPtr. @@ -906,7 +903,7 @@ def withDynamicBuffer (_:Storable a) ?=> free bufPtr' result -def maybeIncreaseBufferSize (_:Storable a) ?=> +def maybeIncreaseBufferSize [Storable a] (buf: DynBuffer a) (sizeDelta:Int) : {State World} Unit = (MkDynBuffer dbPtr) = buf (size, maxSize, bufPtr) = load dbPtr @@ -918,7 +915,7 @@ def maybeIncreaseBufferSize (_:Storable a) ?=> memcpy newBufPtr bufPtr size store dbPtr (size, newMaxSize, newBufPtr) -def extendDynBuffer (_:Storable a) ?=> +def extendDynBuffer [Storable a] (buf: DynBuffer a) (new:List a) : {State World} Unit = (AsList n xs) = new maybeIncreaseBufferSize buf n @@ -928,13 +925,13 @@ def extendDynBuffer (_:Storable a) ?=> storeTab (bufPtr +>> size) xs store dbPtr (newSize, maxSize, bufPtr) -def loadDynBuffer (_:Storable a) ?=> +def loadDynBuffer [Storable a] (buf: DynBuffer a) : {State World} (List a) = (MkDynBuffer dbPtr) = buf (size, _, bufPtr) = load dbPtr AsList size $ tabFromPtr _ bufPtr -def pushDynBuffer (_:Storable a) ?=> +def pushDynBuffer [Storable a] (buf: DynBuffer a) (x:a) : {State World} Unit = extendDynBuffer buf $ AsList _ [x] @@ -1194,7 +1191,7 @@ def error (s:String) : a = unsafeIO do print s %throwError a -def todo (a:Type) ?-> : a = error "TODO: implement it!" +def todo : a = error "TODO: implement it!" def fromOrdinal (n:Type) (i:Int) : n = case (0 <= i) && (i < size n) of @@ -1210,7 +1207,7 @@ def castTable (m:Type) (xs:n=>a) : m=>a = False -> error $ "Table size mismatch in cast: " <> show (size m) <> " vs " <> show (size n) -def asidx (n:Type) ?-> (i:Int) : n = fromOrdinal n i +def asidx (i:Int) : n = fromOrdinal n i def (@) (i:Int) (n:Type) : n = fromOrdinal n i def slice (xs:n=>a) (start:Int) (m:Type) : m=>a = @@ -1218,11 +1215,11 @@ def slice (xs:n=>a) (start:Int) (m:Type) : m=>a = def head (xs:n=>a) : a = xs.(0@_) -def tail (n:Type) ?-> (xs:n=>a) (start:Int) : List a = +def tail (xs:n=>a) (start:Int) : List a = numElts = size n - start toList $ slice xs start (Fin numElts) -def randIdx (n:Type) ?-> (k:Key) : n = +def randIdx (k:Key) : n = unif = rand k fromOrdinal n $ FToI $ floor $ unif * IToF (size n) @@ -1246,7 +1243,7 @@ instance finArb : n:Int ?-> Arbitrary (Fin n) where 'Control flow -- returns the highest index `i` such that `xs.i <= x` -def searchSorted (_:Ord a) ?=> (xs:n=>a) (x:a) : Maybe n = +def searchSorted [Ord a] (xs:n=>a) (x:a) : Maybe n = if size n == 0 then Nothing else if x < xs.(fromOrdinal _ 0) @@ -1264,28 +1261,28 @@ def searchSorted (_:Ord a) ?=> (xs:n=>a) (x:a) : Maybe n = 'min / max etc -def minBy (_:Ord o) ?=> (f:a->o) (x:a) (y:a) : a = select (f x < f y) x y -def maxBy (_:Ord o) ?=> (f:a->o) (x:a) (y:a) : a = select (f x > f y) x y +def minBy [Ord o] (f:a->o) (x:a) (y:a) : a = select (f x < f y) x y +def maxBy [Ord o] (f:a->o) (x:a) (y:a) : a = select (f x > f y) x y -def min (_:Ord o) ?=> (x1: o) -> (x2: o) : o = minBy id x1 x2 -def max (_:Ord o) ?=> (x1: o) -> (x2: o) : o = maxBy id x1 x2 +def min [Ord o] (x1: o) -> (x2: o) : o = minBy id x1 x2 +def max [Ord o] (x1: o) -> (x2: o) : o = maxBy id x1 x2 -def minimumBy (_:Ord o) ?=> (f:a->o) (xs:n=>a) : a = +def minimumBy [Ord o] (f:a->o) (xs:n=>a) : a = reduce xs.(0@_) (minBy f) xs -def maximumBy (_:Ord o) ?=> (f:a->o) (xs:n=>a) : a = +def maximumBy [Ord o] (f:a->o) (xs:n=>a) : a = reduce xs.(0@_) (maxBy f) xs -def minimum (_:Ord o) ?=> (xs:n=>o) : o = minimumBy id xs -def maximum (_:Ord o) ?=> (xs:n=>o) : o = maximumBy id xs +def minimum [Ord o] (xs:n=>o) : o = minimumBy id xs +def maximum [Ord o] (xs:n=>o) : o = maximumBy id xs -def argmin (_:Ord o) ?=> (xs:n=>o) : n = +def argmin [Ord o] (xs:n=>o) : n = zeroth = (0@_, xs.(0@_)) compare = \(idx1, x1) (idx2, x2). select (x1 < x2) (idx1, x1) (idx2, x2) zipped = for i. (i, xs.i) fst $ reduce zeroth compare zipped -def clip (_:Ord a) ?=> ((low,high):(a&a)) (x:a) : a = +def clip [Ord a] ((low,high):(a&a)) (x:a) : a = min high $ max low x '## Trigonometric functions @@ -1307,7 +1304,7 @@ def atan_inner (x:Float) : Float = r = r * s r * x + x -def min_and_max (_: Ord a) ?=> (x:a) (y:a) : (a & a) = +def min_and_max [Ord a] (x:a) (y:a) : (a & a) = select (x < y) (x, y) (y, x) -- get both with one comparison. def atan2 (y:Float) (x:Float) : Float = @@ -1461,7 +1458,7 @@ def reverse (x:n=>a) : n=>a = s = size n for i. x.((s - 1 - ordinal i)@_) -def padTo (n:Type) ?-> (m:Type) (x:a) (xs:n=>a) : (m=>a) = +def padTo (m:Type) (x:a) (xs:n=>a) : (m=>a) = n' = size n for i. i' = ordinal i @@ -1483,7 +1480,7 @@ def seqMaybes (n:Type) ?-> (a:Type) ?-> (xs : n=>Maybe a) : Maybe (n => a) = True -> Nothing False -> Just $ map fromJust xs -def linearSearch (_:Eq a) ?=> (xs:n=>a) (query:a) : Maybe n = +def linearSearch [Eq a] (xs:n=>a) (query:a) : Maybe n = yieldState Nothing \ref. for i. case xs.i == query of True -> ref := Just i @@ -1555,7 +1552,7 @@ def softmax (x: n=>Float) : n=>Float = s = sum e for i. e.i / s -def evalpoly (_:VSpace v) ?=> (coefficients:n=>v) (x:Float) : v = +def evalpoly [VSpace v] (coefficients:n=>v) (x:Float) : v = -- Evaluate a polynomial at x. Same as Numpy's polyval. fold zero \i c. coefficients.i + x .* c diff --git a/src/lib/Parser.hs b/src/lib/Parser.hs index 3129b4c0e..6eab3f678 100644 --- a/src/lib/Parser.hs +++ b/src/lib/Parser.hs @@ -258,7 +258,7 @@ findImplicitImplicitArgNames typ = filter isLowerCaseName $ envNames $ UFor _ _ _ -> error "Unexpected for in type annotation" UHole -> mempty UTypeAnn v ty -> findVarsInAppLHS v <> findVarsInAppLHS ty - UTabCon _ -> error "Unexpected table in type annotation" + UTabCon _ -> mempty UIndexRange low high -> foldMap findVarsInAppLHS low <> foldMap findVarsInAppLHS high UPrimExpr prim -> foldMap findVarsInAppLHS prim @@ -329,7 +329,7 @@ interfaceDef = do ns $ UApp (PlainArrow ()) func (var typeVarName) recordStr = "recordVar" recordPat = ns $ UPatRecord $ Ext (labeledSingleton fLabel (patb - fLabel)) $ Just (ns (UPatBinder (Ignore ()))) + fLabel)) $ Just underscorePat conPat = ns $ UPatCon (mkInterfaceConsName interfaceName) $ toNest [patb recordStr] @@ -430,20 +430,31 @@ funDefLet :: Parser (UExpr -> UDecl) funDefLet = label "function definition" $ mayBreak $ do keyWord DefKW v <- letPat - bs <- many arg + cs <- defClassConstraints + argBinders <- many arg (eff, ty) <- label "result type annotation" $ annot effectiveType - when (null bs && eff /= Pure) $ fail "Nullary def can't have effects" + when (null argBinders && eff /= Pure) $ fail "Nullary def can't have effects" + let bs = map classAsBinder cs ++ argBinders let funTy = buildPiType bs eff ty let letBinder = (v, Just funTy) let lamBinders = flip map bs $ \(p,_, arr) -> ((p,Nothing), arr) return $ \body -> ULet PlainLet letBinder (buildLam lamBinders body) where + classAsBinder :: UType -> (UPat, UType, UArrow) + classAsBinder ty = (underscorePat, ty, ClassArrow) + arg :: Parser (UPat, UType, UArrow) arg = label "def arg" $ do (p, ty) <-parens ((,) <$> pat <*> annot uType) arr <- arrow (return ()) <|> return (PlainArrow ()) return (p, ty, arr) +defClassConstraints :: Parser [UType] +defClassConstraints = + (brackets $ mayNotPair $ uType `sepBy` sym ",") + <|> return [] + "class constraints" + nameAsPat :: Parser Name -> Parser UPat nameAsPat p = withSrc $ (UPatBinder . Bind . (:>())) <$> p @@ -520,11 +531,12 @@ uForExpr = do <|> (keyWord Rof_KW $> (Rev, True )) e <- buildFor pos dir <$> (some patAnn <* argTerm) <*> blockOrExpr if trailingUnit - then return $ noSrc $ UDecl (ULet PlainLet underscorePat e) $ noSrc unitExpr + then return $ noSrc $ UDecl (ULet PlainLet (underscorePat, Nothing) e) $ + noSrc unitExpr else return e - where - underscorePat :: UPatAnn - underscorePat = (noSrc $ UPatBinder $ Ignore (), Nothing) + +underscorePat :: UPat +underscorePat = noSrc $ UPatBinder $ Ignore () unitExpr :: UExpr' unitExpr = UPrimExpr $ ConExpr UnitCon @@ -558,7 +570,7 @@ wrapUStatements statements = case statements of (s, pos):rest -> WithSrc (Just pos) $ case s of Left d -> UDecl d $ wrapUStatements rest Right e -> UDecl d $ wrapUStatements rest - where d = ULet PlainLet (WithSrc (Just pos) (UPatBinder (Ignore ())), Nothing) e + where d = ULet PlainLet (underscorePat, Nothing) e [] -> error "Shouldn't be reachable" uStatement :: Parser UStatement From 46222ea934d35a6692c2fc863e1265ba544df725 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Mon, 4 Jan 2021 14:00:37 -0500 Subject: [PATCH 2/6] Overhaul new-style interface/instance decls. * Handle superclasses * Remove the need to name instances explicitly * Push down types from interface definitions into instance methods * Improve error messages for missing/duplicated methods As par of this change, I moved the lowering (turning interface/instance decls into data defs and method/super class getters) from the parser to type inference where we have much more context about existing definitions. Fixes #370. --- lib/diagram.dx | 2 +- lib/prelude.dx | 476 +++++++++++++++++++-------------------- src/lib/Embed.hs | 28 ++- src/lib/Env.hs | 2 + src/lib/Imp.hs | 1 + src/lib/Inference.hs | 157 +++++++++++-- src/lib/PPrint.hs | 47 +--- src/lib/Parser.hs | 286 ++++++++++------------- src/lib/Syntax.hs | 57 ++++- src/lib/Type.hs | 9 +- tests/adt-tests.dx | 14 +- tests/io-tests.dx | 2 +- tests/typeclass-tests.dx | 61 ++--- 13 files changed, 619 insertions(+), 523 deletions(-) diff --git a/lib/diagram.dx b/lib/diagram.dx index 98ae4e60e..a05fc1cb3 100644 --- a/lib/diagram.dx +++ b/lib/diagram.dx @@ -36,7 +36,7 @@ defaultGeomStyle : GeomStyle = -- TODO: consider sharing attributes among a set of objects for efficiency data Diagram = MkDiagram (List (GeomStyle & Point & Geom)) -instance monoidDiagram : Monoid Diagram where +instance Monoid Diagram mempty = MkDiagram mempty mcombine = \(MkDiagram d1) (MkDiagram d2). MkDiagram $ d1 <> d2 diff --git a/lib/prelude.dx b/lib/prelude.dx index 54f7aefef..1831ef83e 100644 --- a/lib/prelude.dx +++ b/lib/prelude.dx @@ -40,7 +40,7 @@ def FToI (x:Float) : Int = internalCast _ x def I64ToRawPtr (x:Int64 ) : RawPtr = internalCast _ x def RawPtrToI64 (x:RawPtr) : Int64 = internalCast _ x -interface Add a:Type where +interface Add a add : a -> a -> a sub : a -> a -> a zero : a @@ -48,97 +48,97 @@ interface Add a:Type where def (+) [Add a] : a -> a -> a = add def (-) [Add a] : a -> a -> a = sub -instance float64Add : Add Float64 where - add = \x:Float64 y:Float64. %fadd x y - sub = \x:Float64 y:Float64. %fsub x y +instance Add Float64 + add = \x y. %fadd x y + sub = \x y. %fsub x y zero = FToF64 0.0 -instance float32Add : Add Float32 where - add = \x:Float32 y:Float32. %fadd x y - sub = \x:Float32 y:Float32. %fsub x y +instance Add Float32 + add = \x y. %fadd x y + sub = \x y. %fsub x y zero = FToF32 0.0 -instance int64Add : Add Int64 where - add = \x:Int64 y:Int64. %iadd x y - sub = \x:Int64 y:Int64. %isub x y +instance Add Int64 + add = \x y. %iadd x y + sub = \x y. %isub x y zero = IToI64 0 -instance int32Add : Add Int32 where - add = \x:Int32 y:Int32. %iadd x y - sub = \x:Int32 y:Int32. %isub x y +instance Add Int32 + add = \x y. %iadd x y + sub = \x y. %isub x y zero = IToI32 0 -instance word8Add : Add Word8 where - add = \x:Word8 y:Word8. %iadd x y - sub = \x:Word8 y:Word8. %isub x y +instance Add Word8 + add = \x y. %iadd x y + sub = \x y. %isub x y zero = IToW8 0 -instance unitAdd : Add Unit where +instance Add Unit add = \x y. () sub = \x y. () zero = () -instance tabAdd : Add a ?=> Add (n=>a) where +instance [Add a] Add (n=>a) add = \xs ys. for i. xs.i + ys.i sub = \xs ys. for i. xs.i - ys.i zero = for _. zero -interface Mul a:Type where +interface Mul a mul : a -> a -> a one : a def (*) [Mul a] : a -> a -> a = mul -instance float64Mul : Mul Float64 where - mul = \x:Float64 y:Float64. %fmul x y +instance Mul Float64 + mul = \x y. %fmul x y one = FToF64 1.0 -instance float32Mul : Mul Float32 where - mul = \x:Float32 y:Float32. %fmul x y +instance Mul Float32 + mul = \x y. %fmul x y one = FToF32 1.0 -instance int64Mul : Mul Int64 where - mul = \x:Int64 y:Int64. %imul x y +instance Mul Int64 + mul = \x y. %imul x y one = IToI64 1 -instance int32Mul : Mul Int32 where - mul = \x:Int32 y:Int32. %imul x y +instance Mul Int32 + mul = \x y. %imul x y one = IToI32 1 -instance word8Mul : Mul Word8 where - mul = \x:Word8 y:Word8. %imul x y +instance Mul Word8 + mul = \x y. %imul x y one = IToW8 1 -instance unitMul : Mul Unit where +instance Mul Unit mul = \x y. () one = () -interface Integral a:Type where - idiv: a->a->a - rem: a->a->a +interface Integral a + idiv : a->a->a + rem : a->a->a -instance int64Integral : Integral Int64 where - idiv = \x:Int64 y:Int64. %idiv x y - rem = \x:Int64 y:Int64. %irem x y +instance Integral Int64 + idiv = \x y. %idiv x y + rem = \x y. %irem x y -instance int32Integral : Integral Int32 where - idiv = \x:Int32 y:Int32. %idiv x y - rem = \x:Int32 y:Int32. %irem x y +instance Integral Int32 + idiv = \x y. %idiv x y + rem = \x y. %irem x y -instance word8Integral : Integral Word8 where - idiv = \x:Word8 y:Word8. %idiv x y - rem = \x:Word8 y:Word8. %irem x y +instance Integral Word8 + idiv = \x y. %idiv x y + rem = \x y. %irem x y -interface Fractional a:Type where +interface Fractional a divide : a -> a -> a -instance float64Fractional : Fractional Float64 where - divide = \x:Float64 y:Float64. %fdiv x y +instance Fractional Float64 + divide = \x y. %fdiv x y -instance float32Fractional : Fractional Float32 where - divide = \x:Float32 y:Float32. %fdiv x y +instance Fractional Float32 + divide = \x y. %fdiv x y '## Basic polymorphic functions and types @@ -157,19 +157,22 @@ const : a -> b -> a = \x _. x '## Vector spaces -data VSpace a:Type = MkVSpace (Add a) (Float -> a -> a) +interface [Add a] VSpace a + scaleVec : Float -> a -> a -@superclass -def addFromVSpace (d:VSpace a) : Add a = case d of MkVSpace addDict _ -> addDict - -def (.*) (d:VSpace a) ?=> : Float -> a -> a = case d of MkVSpace _ scale -> scale -def (*.) [VSpace a] : a -> Float -> a = flip (.*) +def (.*) [VSpace a] : Float -> a -> a = scaleVec +def (*.) [VSpace a] : a -> Float -> a = flip scaleVec def (/) [VSpace a] (v:a) (s:Float) : a = divide 1.0 s .* v def neg [VSpace a] (v:a) : a = (-1.0) .* v -@instance floatVS : VSpace Float = MkVSpace float32Add (*) -@instance tabVS : VSpace a ?=> VSpace (n=>a) = MkVSpace tabAdd \s xs. for i. s .* xs.i -@instance unitVS : VSpace Unit = MkVSpace unitAdd \s u. () +instance VSpace Float + scaleVec = \x y. x * y + +instance [VSpace a] VSpace (n=>a) + scaleVec = \s xs. for i. s .* xs.i + +instance VSpace Unit + scaleVec = \_ _. () '## Boolean type @@ -197,7 +200,7 @@ def not (x:Bool) : Bool = '## Sum types -data Maybe a:Type = +data Maybe a = Nothing Just a @@ -207,7 +210,7 @@ def isNothing (x:Maybe a) : Bool = case x of def isJust (x:Maybe a) : Bool = not $ isNothing x -data (|) a:Type b:Type = +data (|) a b = Left a Right b @@ -285,55 +288,76 @@ def unreachable (():Unit) : a = unsafeIO do '## Type classes -data Eq a:Type = MkEq (a -> a -> Bool) -data Ord a:Type = MkOrd (Eq a) (a -> a -> Bool) (a -> a -> Bool) -- eq, gt, lt - -@superclass -def eqFromOrd (d:Ord a) : Eq a = case d of MkOrd eq _ _ -> eq +interface Eq a + (==) : a -> a -> Bool -def (==) (d:Eq a) ?=> (x:a) (y:a) : Bool = case d of MkEq eq -> eq x y def (/=) [Eq a] (x:a) (y:a) : Bool = not $ x == y -def (>) (d:Ord a) ?=> (x:a) (y:a) : Bool = case d of MkOrd _ gt _ -> gt x y -def (<) (d:Ord a) ?=> (x:a) (y:a) : Bool = case d of MkOrd _ _ lt -> lt x y +interface [Eq a] Ord a + (>) : a -> a -> Bool + (<) : a -> a -> Bool + def (<=) [Ord a] (x:a) (y:a) : Bool = x=) [Ord a] (x:a) (y:a) : Bool = x>y || x==y -@instance float64Eq : Eq Float64 = MkEq \x:Float64 y:Float64. W8ToB $ %feq x y -@instance float32Eq : Eq Float32 = MkEq \x:Float32 y:Float32. W8ToB $ %feq x y -@instance int64Eq : Eq Int64 = MkEq \x:Int64 y:Int64. W8ToB $ %ieq x y -@instance int32Eq : Eq Int32 = MkEq \x:Int32 y:Int32. W8ToB $ %ieq x y -@instance word8Eq : Eq Word8 = MkEq \x:Word8 y:Word8. W8ToB $ %ieq x y -@instance boolEq : Eq Bool = MkEq \x y. BToW8 x == BToW8 y -@instance unitEq : Eq Unit = MkEq \x y. True -@instance rawPtrEq : Eq RawPtr = MkEq \x y. RawPtrToI64 x == RawPtrToI64 y - -@instance float64Ord : Ord Float64 = (MkOrd float64Eq (\x y. W8ToB $ %fgt x y) - (\x y. W8ToB $ %flt x y)) -@instance float32Ord : Ord Float32 = (MkOrd float32Eq (\x y. W8ToB $ %fgt x y) - (\x y. W8ToB $ %flt x y)) -@instance int64Ord : Ord Int64 = (MkOrd int64Eq (\x y. W8ToB $ %igt x y) - (\x y. W8ToB $ %ilt x y)) -@instance int32Ord : Ord Int32 = (MkOrd int32Eq (\x y. W8ToB $ %igt x y) - (\x y. W8ToB $ %ilt x y)) -@instance word8Ord : Ord Word8 = (MkOrd word8Eq (\x y. W8ToB $ %igt x y) - (\x y. W8ToB $ %ilt x y)) -@instance unitOrd : Ord Unit = (MkOrd unitEq (\x y. False) (\x y. False)) - -@instance -def pairEq [Eq a, Eq b] : Eq (a & b) = MkEq $ - \(x1,x2) (y1,y2). x1 == y1 && x2 == y2 - -@instance -def pairOrd [Ord a, Ord b] : Ord (a & b) = - pairGt = \(x1,x2) (y1,y2). x1 > y1 || (x1 == y1 && x2 > y2) - pairLt = \(x1,x2) (y1,y2). x1 < y1 || (x1 == y1 && x2 < y2) - MkOrd pairEq pairGt pairLt +instance Eq Float64 + (==) = \x y. W8ToB $ %feq x y + +instance Eq Float32 + (==) = \x y. W8ToB $ %feq x y + +instance Eq Int64 + (==) = \x y. W8ToB $ %ieq x y + +instance Eq Int32 + (==) = \x y. W8ToB $ %ieq x y + +instance Eq Word8 + (==) = \x y. W8ToB $ %ieq x y + +instance Eq Bool + (==) = \x y. BToW8 x == BToW8 y + +instance Eq Unit + (==) = \x y. True + +instance Eq RawPtr + (==) = \x y. RawPtrToI64 x == RawPtrToI64 y + +instance Ord Float64 + (>) = \x y. W8ToB $ %fgt x y + (<) = \x y. W8ToB $ %flt x y + +instance Ord Float32 + (>) = \x y. W8ToB $ %fgt x y + (<) = \x y. W8ToB $ %flt x y + +instance Ord Int64 + (>) = \x y. W8ToB $ %igt x y + (<) = \x y. W8ToB $ %ilt x y + +instance Ord Int32 + (>) = \x y. W8ToB $ %igt x y + (<) = \x y. W8ToB $ %ilt x y + +instance Ord Word8 + (>) = \x y. W8ToB $ %igt x y + (<) = \x y. W8ToB $ %ilt x y + +instance Ord Unit + (>) = \x y. False + (<) = \x y. False + +instance [Eq a, Eq b] Eq (a & b) + (==) = \(x1,x2) (y1,y2). x1 == y1 && x2 == y2 + +instance [Ord a, Ord b] Ord (a & b) + (>) = \(x1,x2) (y1,y2). x1 > y1 || (x1 == y1 && x2 > y2) + (<) = \(x1,x2) (y1,y2). x1 < y1 || (x1 == y1 && x2 < y2) -- TODO: accumulate using the True/&& monoid -@instance -def tabEq [Eq a] : Eq (n=>a) = MkEq $ - \xs ys. +instance [Eq a] Eq (n=>a) + (==) = \xs ys. numDifferent : Float = yieldAccum \ref. for i. ref += (IToF (BToI (xs.i /= ys.i))) @@ -341,7 +365,7 @@ def tabEq [Eq a] : Eq (n=>a) = MkEq $ '## Transcencendental functions -interface Floating a:Type where +interface Floating a exp : a -> a exp2 : a -> a log : a -> a @@ -375,45 +399,45 @@ def float64_cosh (x:Float64) : Float64 = %fdiv ((%exp x) + (%exp (%fsub (FToF64 def float64_tanh (x:Float64) : Float64 = %fdiv (%fsub (%exp x) (%exp (%fsub (FToF64 0.0) x))) ((%exp x) + (%exp (%fsub (FToF64 0.0) x))) -instance float64Floating : Floating Float64 where - exp = \x:Float64. %exp x - exp2 = \x:Float64. %exp2 x - log = \x:Float64. %log x - log2 = \x:Float64. %log2 x - log10 = \x:Float64. %log10 x - log1p = \x:Float64. %log1p x - sin = \x:Float64. %sin x - cos = \x:Float64. %cos x - tan = \x:Float64. %tan x +instance Floating Float64 + exp = \x. %exp x + exp2 = \x. %exp2 x + log = \x. %log x + log2 = \x. %log2 x + log10 = \x. %log10 x + log1p = \x. %log1p x + sin = \x. %sin x + cos = \x. %cos x + tan = \x. %tan x sinh = float64_sinh cosh = float64_cosh tanh = float64_tanh - floor = \x:Float64. %floor x - ceil = \x:Float64. %ceil x - round = \x:Float64. %round x - sqrt = \x:Float64. %sqrt x - pow = \x:Float64 y:Float64. %fpow x y - lgamma = \x:Float64. %lgamma x - -instance float32Floating : Floating Float32 where - exp = \x:Float32. %exp x - exp2 = \x:Float32. %exp2 x - log = \x:Float32. %log x - log2 = \x:Float32. %log2 x - log10 = \x:Float32. %log10 x - log1p = \x:Float32. %log1p x - sin = \x:Float32. %sin x - cos = \x:Float32. %cos x - tan = \x:Float32. %tan x + floor = \x. %floor x + ceil = \x. %ceil x + round = \x. %round x + sqrt = \x. %sqrt x + pow = \x y. %fpow x y + lgamma = \x. %lgamma x + +instance Floating Float32 + exp = \x. %exp x + exp2 = \x. %exp2 x + log = \x. %log x + log2 = \x. %log2 x + log10 = \x. %log10 x + log1p = \x. %log1p x + sin = \x. %sin x + cos = \x. %cos x + tan = \x. %tan x sinh = float32_sinh cosh = float32_cosh tanh = float32_tanh - floor = \x:Float32. %floor x - ceil = \x:Float32. %ceil x - round = \x:Float32. %round x - sqrt = \x:Float32. %sqrt x - pow = \x:Float32 y:Float32. %fpow x y - lgamma = \x:Float32. %lgamma x + floor = \x. %floor x + ceil = \x. %ceil x + round = \x. %round x + sqrt = \x. %sqrt x + pow = \x y. %fpow x y + lgamma = \x. %lgamma x '## Index set utilities @@ -425,90 +449,66 @@ def unsafeFromOrdinal (n : Type) (i : Int) : n = %unsafeFromOrdinal n i def iota (n:Type) : n=>Int = view i. ordinal i -- TODO: we want Eq and Ord for all index sets, not just `Fin n` -@instance -def finEq (n:Int) ?-> : Eq (Fin n) = MkEq \x y. ordinal x == ordinal y +instance (n:Int) ?-> Eq (Fin n) + (==) = \x y. ordinal x == ordinal y -@instance -def finOrd (n:Int) ?-> : Ord (Fin n) = - MkOrd finEq (\x y. ordinal x > ordinal y) (\x y. ordinal x < ordinal y) +instance (n:Int) ?-> Ord (Fin n) + (>) = \x y. ordinal x > ordinal y + (<) = \x y. ordinal x < ordinal y '## Raw pointer operations -data Ptr a:Type = MkPtr RawPtr +data Ptr a = MkPtr RawPtr -- Is there a better way to select the right instance for `storageSize`?? -data TypeVehicle a:Type = MkTypeVehicle +data TypeVehicle a = MkTypeVehicle def typeVehicle (a:Type) : TypeVehicle a = MkTypeVehicle -interface Storable a:Type where +interface Storable a store : Ptr a -> a -> {State World} Unit load : Ptr a -> {State World} a - storageSize : TypeVehicle a -> Int - --- TODO: we can't inline these into the instance definitions until we change --- type inference to push types down into record constructors or allow `def` in --- instance definitions. -def word8Store ((MkPtr ptr): Ptr Word8) (x:Word8) : {State World} Unit = %ptrStore ptr x -def word8Load ((MkPtr ptr): Ptr Word8) : {State World} Word8 = %ptrLoad ptr - -instance word8Storable : Storable Word8 where - store = word8Store - load = word8Load - storageSize = const 1 - --- TODO: there's a bug preventing us inlining these definitions into the instance -def int32Store ((MkPtr ptr): Ptr Int32) (x:Int32) : {State World} Unit = - %ptrStore (internalCast %Int32Ptr ptr) x -def int32Load ((MkPtr ptr): Ptr Int32) : {State World} Int32 = - %ptrLoad (internalCast %Int32Ptr ptr) - -instance int32Storable : Storable Int32 where - store = int32Store - load = int32Load - storageSize = const 4 + storageSize_ : TypeVehicle a -> Int + +def storageSize (a:Type) -> (d:Storable a) ?=> : Int = + tv : TypeVehicle a = MkTypeVehicle + storageSize_ tv + +instance Storable Word8 + store = \(MkPtr ptr) x. %ptrStore ptr x + load = \(MkPtr ptr) . %ptrLoad ptr + storageSize_ = const 1 + +instance Storable Int32 + store = \(MkPtr ptr) x. %ptrStore (internalCast %Int32Ptr ptr) x + load = \(MkPtr ptr) . %ptrLoad (internalCast %Int32Ptr ptr) + storageSize_ = const 4 def unpackPairPtr [Storable a, Storable b] (pairPtr: Ptr (a & b)) : (Ptr a & Ptr b) = (MkPtr rawPtrX) = pairPtr - rawPtrY = %ptrOffset rawPtrX (storageSize (typeVehicle a)) + rawPtrY = %ptrOffset rawPtrX (storageSize a) (MkPtr rawPtrX, MkPtr rawPtrY) -def pairStore [Storable a, Storable b] - (pairPtr:Ptr (a & b)) ((x, y):(a & b)) : {State World} Unit = - (xPtr, yPtr) = unpackPairPtr pairPtr - store xPtr x - store yPtr y - -def pairLoad [Storable a, Storable b] - (pairPtr:Ptr (a & b)) : {State World} (a & b) = - (xPtr, yPtr) = unpackPairPtr pairPtr - (load xPtr, load yPtr) - -def pairStorageSize [Storable a, Storable b] - (_:TypeVehicle (a & b)) : Int = - storageSize (typeVehicle a) + storageSize (typeVehicle b) - -instance pairStorable : (Storable a) ?=> (Storable b) ?=> Storable (a & b) where - store = pairStore - load = pairLoad - storageSize = pairStorageSize - -def ptrPtrStore ((MkPtr ptr): Ptr (Ptr a)) (x:(Ptr a)) : {State World} Unit = - (MkPtr x') = x - %ptrStore (internalCast %PtrPtr ptr) x' - -def ptrPtrLoad ((MkPtr ptr): Ptr (Ptr a)) : {State World} (Ptr a) = - MkPtr $ %ptrLoad (internalCast %PtrPtr ptr) - -instance ptrStorable : Storable (Ptr a) where - store = ptrPtrStore - load = ptrPtrLoad - storageSize = const 8 -- TODO: something more portable? +instance [Storable a, Storable b] Storable (a & b) + store = \pairPtr (x, y). + (xPtr, yPtr) = unpackPairPtr pairPtr + store xPtr x + store yPtr y + load = \pairPtr. + (xPtr, yPtr) = unpackPairPtr pairPtr + (load xPtr, load yPtr) + storageSize_ = \_. + storageSize a + storageSize b + +instance Storable (Ptr a) + store = \(MkPtr ptr) (MkPtr x). %ptrStore (internalCast %PtrPtr ptr) x + load = \(MkPtr ptr) . MkPtr $ %ptrLoad (internalCast %PtrPtr ptr) + storageSize_ = const 8 -- TODO: something more portable? -- TODO: Storable instances for other types def malloc [Storable a] (n:Int) : {State World} (Ptr a) = - numBytes = storageSize (typeVehicle a) * n + numBytes = storageSize a * n MkPtr $ %alloc numBytes def free (ptr:Ptr a) : {State World} Unit = @@ -517,7 +517,7 @@ def free (ptr:Ptr a) : {State World} Unit = def (+>>) [Storable a] (ptr:Ptr a) (i:Int) : Ptr a = (MkPtr ptr') = ptr - i' = i * storageSize (typeVehicle a) + i' = i * storageSize a MkPtr $ %ptrOffset ptr' i' -- TODO: generalize these brackets to allow other effects @@ -601,7 +601,7 @@ def vdot (x:n=>Float) (y:n=>Float) : Float = fsum view i. x.i * y.i def dot [VSpace v] (s:n=>Float) (vs:n=>v) : v = sum for j. s.j .* vs.j -- matmul. Better symbol to use? `@`? -(**) : (l=>m=>Float) -> (m=>n=>Float) -> (l=>n=>Float) = \x y. +(**) : (l=>m=>Float) -> (m=>n=>Float) -> (l=>n=>Float) = \x y. for i k. fsum view j. x.i.j * y.j.k (**.) : (n=>m=>Float) -> (m=>Float) -> (n=>Float) = \mat v. for i. vdot mat.i v @@ -671,33 +671,33 @@ def deriv (f:Float->Float) (x:Float) : Float = jvp f x 1.0 def derivRev (f:Float->Float) (x:Float) : Float = snd (vjp f x) 1.0 -interface HasAllClose a:Type where +interface HasAllClose a allclose : a -> a -> a -> a -> Bool -interface HasDefaultTolerance a:Type where +interface HasDefaultTolerance a atol : a rtol : a def (~~) [HasAllClose a, HasDefaultTolerance a] : a -> a -> Bool = allclose atol rtol -instance allCloseF32 : HasAllClose Float32 where +instance HasAllClose Float32 allclose = \atol rtol x y. abs (x - y) <= (atol + rtol * abs y) -instance allCloseF64 : HasAllClose Float64 where +instance HasAllClose Float64 allclose = \atol rtol x y. abs (x - y) <= (atol + rtol * abs y) -instance defaultToleranceF32 : HasDefaultTolerance Float32 where +instance HasDefaultTolerance Float32 atol = FToF32 0.00001 rtol = FToF32 0.0001 -instance defaultToleranceF64 : HasDefaultTolerance Float64 where +instance HasDefaultTolerance Float64 atol = FToF64 0.00000001 rtol = FToF64 0.00001 -instance allCloseTable : HasAllClose t ?=> HasDefaultTolerance t ?=> HasAllClose (n=>t) where +instance [HasAllClose t, HasDefaultTolerance t] HasAllClose (n=>t) allclose = \atol rtol a b. all for i:n. (a.i ~~ b.i) -instance defaultToleranceTable : (HasDefaultTolerance t) ?=> HasDefaultTolerance (n=>t) where +instance [HasDefaultTolerance t] HasDefaultTolerance (n=>t) atol = for i. atol rtol = for i. rtol @@ -776,7 +776,7 @@ def tile1 (fTile : (t:(Tile n l) -> {|eff} m=>l=>a)) '## Monoid typeclass -interface Monoid a:Type where +interface Monoid a mempty : a mcombine : a -> a -> a -- can't use `<>` just for parser reasons? @@ -784,7 +784,7 @@ def (<>) [Monoid a] : a -> a -> a = mcombine '## Length-erased lists -data List a:Type = +data List a = AsList n:Int foo:(Fin n => a) def unsafeCastTable (m:Type) (xs:n=>a) : m=>a = @@ -794,7 +794,7 @@ def toList (xs:n=>a) : List a = n' = size n AsList _ $ unsafeCastTable (Fin n') xs -instance monoidList : Monoid (List a) where +instance Monoid (List a) mempty = AsList _ [] mcombine = \x y. (AsList nx xs) = x @@ -808,7 +808,7 @@ instance monoidList : Monoid (List a) where '## Isomorphisms -data Iso a:Type b:Type = MkIso { fwd: a -> b & bwd: b -> a } +data Iso a b = MkIso { fwd: a -> b & bwd: b -> a } def appIso (iso: Iso a b) (x:a) : b = (MkIso {fwd, bwd}) = iso @@ -890,7 +890,7 @@ def sliceFields (iso: Iso ({|} | a) (b | c)) (tab: a=>v) : b=>v = -- TODO: should we be able to use `Ref World Int` instead of `Ptr Int`? -- TODO: would be nice to be able to use records here -data DynBuffer a:Type = MkDynBuffer (Ptr (Int & Int & Ptr a)) -- size, max size, buf ptr +data DynBuffer a = MkDynBuffer (Ptr (Int & Int & Ptr a)) -- size, max size, buf ptr def withDynamicBuffer [Storable a] (action: DynBuffer a -> {State World} b) : {State World} b = @@ -945,29 +945,29 @@ def stringFromCharPtr (n:Int) (ptr:Ptr Char) : {State World} String = -- TODO. This is ASCII code point. It really should be Int32 for Unicode codepoint def codepoint (c:Char) : Int = W8ToI c -interface Show a:Type where +interface Show a show : a -> String -instance showString : Show String where +instance Show String show = id -instance showInt32 : Show Int32 where - show = \x: Int32. unsafeIO do +instance Show Int32 + show = \x. unsafeIO do (n, ptr) = %ffi showInt32 (Int32 & RawPtr) x stringFromCharPtr n $ MkPtr ptr -instance showInt64 : Show Int64 where - show = \x: Int64. unsafeIO do +instance Show Int64 + show = \x. unsafeIO do (n, ptr) = %ffi showInt64 (Int32 & RawPtr) x stringFromCharPtr n $ MkPtr ptr -instance showFloat32 : Show Float32 where - show = \x: Float32.unsafeIO do +instance Show Float32 + show = \x. unsafeIO do (n, ptr) = %ffi showFloat32 (Int32 & RawPtr) x stringFromCharPtr n $ MkPtr ptr -instance showFloat64 : Show Float64 where - show = \x: Float64.unsafeIO do +instance Show Float64 + show = \x. unsafeIO do (n, ptr) = %ffi showFloat64 (Int32 & RawPtr) x stringFromCharPtr n $ MkPtr ptr @@ -1053,7 +1053,7 @@ def while (eff:Effects) ?-> (body: Unit -> {|eff} Bool) : {|eff} Unit = body' : Unit -> {|eff} Word8 = \_. BToW8 $ body () %while body' -data IterResult a:Type = +data IterResult a = Continue Done a @@ -1225,19 +1225,19 @@ def randIdx (k:Key) : n = 'Type class for generating example values -interface Arbitrary a:Type where +interface Arbitrary a arb : Key -> a -instance float32Arb : Arbitrary Float32 where +instance Arbitrary Float32 arb = randn -instance in32Arb : Arbitrary Int32 where +instance Arbitrary Int32 arb = \key. FToI $ randn key * 5.0 -instance tabArb : Arbitrary a ?=> Arbitrary (n=>a) where +instance [Arbitrary a] Arbitrary (n=>a) arb = \key. for i. arb $ ixkey key i -instance finArb : n:Int ?-> Arbitrary (Fin n) where +instance (n:Int) ?-> Arbitrary (Fin n) arb = randIdx 'Control flow @@ -1331,28 +1331,28 @@ def atan (x:Float) : Float = atan2 x 1.0 data Complex = MkComplex Float Float -- real, imaginary -instance allCloseComplex : HasAllClose Complex where +instance HasAllClose Complex allclose = \atol rtol (MkComplex a b) (MkComplex c d). (a ~~ c) && (b ~~ d) -instance defaultToleranceComplex : HasDefaultTolerance Complex where +instance HasDefaultTolerance Complex atol = MkComplex atol atol rtol = MkComplex rtol rtol -@instance ComplexEq : Eq Complex = - MkEq \(MkComplex a b) (MkComplex c d). (a == c) && (b == d) +instance Eq Complex + (==) = \(MkComplex a b) (MkComplex c d). (a == c) && (b == d) -instance ComplexAdd : Add Complex where +instance Add Complex add = \(MkComplex a b) (MkComplex c d). MkComplex (a + c) (b + d) sub = \(MkComplex a b) (MkComplex c d). MkComplex (a - c) (b - d) zero = MkComplex 0.0 0.0 -instance ComplexMul : Mul Complex where +instance Mul Complex mul = \(MkComplex a b) (MkComplex c d). MkComplex (a * c - b * d) (a * d + b * c) one = MkComplex 1.0 0.0 -@instance complexVS : VSpace Complex = - MkVSpace ComplexAdd \a:Float (MkComplex c d):Complex. MkComplex (a * c) (a * d) +instance VSpace Complex + scaleVec = \a:Float (MkComplex c d):Complex. MkComplex (a * c) (a * d) -- Todo: Hook up to (/) operator. Might require two-parameter VSpace. def complex_division (MkComplex a b:Complex) (MkComplex c d:Complex): Complex = @@ -1391,7 +1391,7 @@ def complex_tanh (MkComplex a b:Complex) : Complex = den = MkComplex (cosh a * cos b) (sinh a * sin b) complex_division num den -instance ComplexFractional : Fractional Complex where +instance Fractional Complex divide = complex_division def complex_floor (MkComplex re im:Complex) : Complex = @@ -1424,7 +1424,7 @@ def complex_log1p (x:Complex) : Complex = True -> complex_log u False -> divide ((complex_log u) * x) x -instance complexFloating : Floating Complex where +instance Floating Complex exp = complex_exp exp2 = complex_exp2 log = complex_log diff --git a/src/lib/Embed.hs b/src/lib/Embed.hs index 731cfce66..705d1c50a 100644 --- a/src/lib/Embed.hs +++ b/src/lib/Embed.hs @@ -17,7 +17,7 @@ module Embed (emit, emitTo, emitAnn, emitOp, buildDepEffLam, buildLamAux, buildP app, add, mul, sub, neg, div', iadd, imul, isub, idiv, ilt, ieq, - fpow, flog, fLitLike, + fpow, flog, fLitLike, recGet, buildImplicitNaryLam, select, substEmbed, substEmbedR, emitUnpack, getUnpacked, fromPair, getFst, getSnd, getFstRef, getSndRef, naryApp, appReduce, appTryReduce, buildAbs, @@ -25,7 +25,7 @@ module Embed (emit, emitTo, emitAnn, emitOp, buildDepEffLam, buildLamAux, buildP emitBlock, unzipTab, isSingletonType, emitDecl, withNameHint, singletonTypeVal, scopedDecls, embedScoped, extendScope, checkEmbed, embedExtend, unpackConsList, emitRunWriter, applyPreludeFunction, - emitRunState, emitMaybeCase, emitWhile, + emitRunState, emitMaybeCase, emitWhile, buildDataDef, emitRunReader, tabGet, SubstEmbedT, SubstEmbed, runSubstEmbedT, traverseAtom, ptrOffset, ptrLoad, unsafePtrLoad, evalBlockE, substTraversalDef, @@ -42,6 +42,8 @@ import Control.Monad.Writer hiding (Alt) import Control.Monad.Identity import Control.Monad.State.Strict import Data.Foldable (toList) +import Data.List (elemIndex) +import Data.Maybe (fromJust) import Data.String (fromString) import Data.Tuple (swap) import GHC.Stack @@ -188,6 +190,28 @@ buildNAbsAux bs body = do return (fmap Bind vs, result) return (Abs bs' $ wrapDecls decls ans, aux) +buildDataDef :: MonadEmbed m + => Name -> Nest Binder -> ([Atom] -> m [DataConDef]) -> m DataDef +buildDataDef tyConName paramBinders body = do + ((paramBinders', dataDefs), _) <- scopedDecls $ do + vs <- freshNestedBinders paramBinders + result <- body $ map Var $ toList vs + return (fmap Bind vs, result) + return $ DataDef tyConName paramBinders' dataDefs + +buildImplicitNaryLam :: MonadEmbed m => (Nest Binder) -> ([Atom] -> m Atom) -> m Atom +buildImplicitNaryLam Empty body = body [] +buildImplicitNaryLam (Nest b bs) body = + buildLam b ImplicitArrow $ \x -> do + bs' <- substEmbed (b@>x) bs + buildImplicitNaryLam bs' $ \xs -> body $ x:xs + +recGet :: Label -> Atom -> Atom +recGet l x = do + let (RecordTy (Ext r _)) = getType x + let i = fromJust $ elemIndex l $ map fst $ toList $ reflectLabels r + getProjection [i] x + buildScoped :: MonadEmbed m => m Atom -> m Block buildScoped m = do (ans, decls) <- scopedDecls m diff --git a/src/lib/Env.hs b/src/lib/Env.hs index bfb2dd93e..456c613ab 100644 --- a/src/lib/Env.hs +++ b/src/lib/Env.hs @@ -39,6 +39,7 @@ data NameSpace = | InferenceName | SumName | FFIName + | TypeClassGenName -- names generated for type class dictionaries | AbstractedPtrName -- used in `abstractPtrLiterals` in Imp lowering | TopFunctionName -- top-level Imp functions | AllocPtrName -- used for constructing dests in Imp lowering @@ -163,6 +164,7 @@ env ! v = case envLookup env v of isGlobal :: VarP ann -> Bool isGlobal (GlobalName _ :> _) = True isGlobal (GlobalArrayName _ :> _) = True +isGlobal (Name TypeClassGenName _ _ :> _) = True isGlobal _ = False isGlobalBinder :: BinderP ann -> Bool diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index 4ad9aa0f8..aa3c94663 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -517,6 +517,7 @@ toImpHof env (maybeDest, hof) = do translateBlock env (maybeDest, body) Linearize _ -> error "Unexpected Linearize" Transpose _ -> error "Unexpected Transpose" + CatchException _ -> error "Unexpected CatchException" data LaunchInfo = LaunchInfo { numWorkgroups :: IExpr, workgroupSize :: IExpr } data ThreadInfo = ThreadInfo { tid :: IExpr, wid :: IExpr, threadRange :: Type } diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index ecfd2e962..5a74edbfd 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -321,8 +321,7 @@ unpackTopPat letAnn pat expr = do bindings <- bindPat pat atom void $ flip traverseNames bindings $ \name val -> do let name' = asGlobal name - scope <- getScope - when (name' `isin` scope) $ throw RepeatedVarErr $ pprint $ name' + checkNotInScope name' emitTo name' letAnn $ Atom val inferUDecl :: Bool -> UDecl -> UInferM SubstEnv @@ -343,29 +342,116 @@ inferUDecl topLevel (ULet letAnn (p, ann) rhs) = do else bindPat p val inferUDecl True (UData tc dcs) = do (tc', paramBs) <- inferUConDef tc - scope <- getScope - when (tc' `isin` scope) $ throw RepeatedVarErr $ pprint $ getName tc' - let paramVars = map (\(Bind v) -> v) $ toList paramBs -- TODO: refresh things properly - (dcs', _) <- embedScoped $ - extendR (newEnv paramBs (map Var paramVars)) $ do - extendScope (foldMap boundVars paramBs) - mapM inferUConDef dcs - let dataDef = DataDef tc' paramBs $ map (uncurry DataConDef) dcs' - let tyConTy = getType $ TypeCon dataDef [] - extendScope $ tc' @> (tyConTy, DataBoundTypeCon dataDef) - forM_ (zip [0..] dcs') $ \(i, (dc,_)) -> do - -- Retrieving scope at every step to avoid duplicate constructor names - scope' <- getScope - when (dc `isin` scope') $ throw RepeatedVarErr $ pprint $ getName dc - let ty = getType $ DataCon dataDef [] i [] - extendScope $ dc @> (ty, DataBoundDataCon dataDef i) + dataDef <- buildDataDef tc' paramBs $ \params -> do + extendR (newEnv paramBs params) $ forM dcs $ \dc -> + uncurry DataConDef <$> inferUConDef dc + checkDataDefShadows dataDef + emitConstructors dataDef + return mempty +inferUDecl True (UInterface superclasses tc methods) = do + (tc', paramBs) <- inferUConDef tc + dataDef <- buildDataDef tc' paramBs $ \params -> do + extendR (newEnv paramBs params) $ do + conName <- freshClassGenName + superclasses' <- mkLabeledItems <$> mapM mkSuperclass superclasses + methods' <- mkLabeledItems <$> mapM mkMethod methods + return $ ClassDictDef conName superclasses' methods' + checkDataDefShadows dataDef + emitConstructors dataDef + emitSuperclassGetters dataDef + emitMethodGetters dataDef return mempty -inferUDecl False (UData _ _) = error "data definitions should be top-level" +inferUDecl True (UInstance instanceTy methods) = do + ty <- checkUType instanceTy + instanceDict <- checkInstance ty methods + let instanceName = Name TypeClassGenName "instance" 0 + void $ emitTo instanceName InstanceLet $ Atom instanceDict + return mempty +inferUDecl False (UData _ _ ) = error "data definitions should be top-level" +inferUDecl False (UInterface _ _ _) = error "interface definitions should be top-level" +inferUDecl False (UInstance _ _ ) = error "instance definitions should be top-level" + +freshClassGenName :: MonadEmbed m => m Name +freshClassGenName = do + scope <- getScope + let v' = genFresh (Name TypeClassGenName "classgen" 0) scope + embedExtend $ asFst $ v' @> (UnitTy, UnknownBinder) + return v' + +mkMethod :: UAnnBinder -> UInferM (Label, Type) +mkMethod (Ignore _) = error "Methods must have names" +mkMethod (Bind (v:>ty)) = do + ty' <- checkUType ty + return (nameToLabel v, ty') + +mkSuperclass :: UType -> UInferM (Label, Type) +mkSuperclass ty = do + ty' <- checkUType ty + -- TODO: think about the scope of these names + l <- freshClassGenName + return (nameToLabel l, ty') + +-- TODO: just make Name and Label the same thing +nameToLabel :: Name -> Label +nameToLabel = pprint + +mkLabeledItems :: [(Label, a)] -> LabeledItems a +mkLabeledItems items = foldMap (uncurry labeledSingleton) items + +emitConstructors :: DataDef -> UInferM () +emitConstructors def@(DataDef tyConName _ dataConDefs) = do + let tyConTy = getType $ TypeCon def [] + checkNotInScope tyConName + extendScope $ tyConName @> (tyConTy, DataBoundTypeCon def) + forM_ (zip [0..] dataConDefs) $ \(i, DataConDef dataConName _) -> do + let dataConTy = getType $ DataCon def [] i [] + checkNotInScope dataConName + extendScope $ dataConName @> (dataConTy, DataBoundDataCon def i) + +emitMethodGetters :: DataDef -> UInferM () +emitMethodGetters def@(DataDef _ paramBs (ClassDictDef _ _ methodTys)) = do + forM_ (getLabels methodTys) $ \l -> do + f <- buildImplicitNaryLam paramBs $ \params -> do + buildLam (Bind ("d":> TypeCon def params)) ClassArrow $ \dict -> do + return $ recGet l $ getProjection [1] dict + let methodName = GlobalName $ fromString l + checkNotInScope methodName + emitTo methodName PlainLet $ Atom f +emitMethodGetters (DataDef _ _ _) = error "Not a class dictionary" + +emitSuperclassGetters :: MonadEmbed m => DataDef -> m () +emitSuperclassGetters def@(DataDef _ paramBs (ClassDictDef _ superclassTys _)) = do + forM_ (getLabels superclassTys) $ \l -> do + f <- buildImplicitNaryLam paramBs $ \params -> do + buildLam (Bind ("d":> TypeCon def params)) PureArrow $ \dict -> do + return $ recGet l $ getProjection [0] dict + getterName <- freshClassGenName + emitTo getterName SuperclassLet $ Atom f +emitSuperclassGetter (DataDef _ _ _) = error "Not a class dictionary" + +checkNotInScope :: Name -> UInferM () +checkNotInScope v = do + scope <- getScope + when (v `isin` scope) $ throw RepeatedVarErr $ pprint v + +checkDataDefShadows :: DataDef -> UInferM () +checkDataDefShadows (DataDef tc _ dataCons) = do + checkShadows $ tc:dcs + where dcs = [dc | DataConDef dc _ <- dataCons] + +checkShadows :: [Name] -> UInferM () +checkShadows vs = do + mapM_ checkNotInScope vs + case repeated vs of + [] -> return () + (v:_) -> throw RepeatedVarErr $ pprint v inferUConDef :: UConDef -> UInferM (Name, Nest Binder) inferUConDef (UConDef v bs) = do (bs', _) <- embedScoped $ checkNestedBinders bs - return (asGlobal v, bs') + let v' = asGlobal v + checkNotInScope v' + return (v', bs') checkNestedBinders :: Nest UAnnBinder -> UInferM (Nest Binder) checkNestedBinders Empty = return Empty @@ -393,6 +479,37 @@ checkULam (p, ann) body piTy = do $ \x@(Var v) -> checkLeaks [v] $ withBindPat p x $ checkSigma body Suggest $ snd $ applyAbs piTy x +checkInstance :: Type -> [(UVar, UExpr)] -> UInferM Atom +checkInstance ty methods = case ty of + TypeCon def@(DataDef className _ _) params -> do + case applyDataDefParams def params of + ClassDictDef _ superclassTys methodTys -> do + methods' <- liftM mkLabeledItems $ forM methods $ \((v:>()), rhs) -> do + let v' = nameToLabel v + case lookupLabel methodTys v' of + Nothing -> throw TypeErr (pprint v ++ " is not a method of " ++ pprint className) + Just methodTy -> do + rhs' <- checkSigma rhs Suggest methodTy + return (v', rhs') + let superclassHoles = fmap (Con . ClassDictHole Nothing) superclassTys + forM_ (reflectLabels methods') $ \(l,i) -> + when (i > 0) $ throw TypeErr $ "Duplicate method: " ++ pprint l + forM_ (reflectLabels methodTys) $ \(l,_) -> + case lookupLabel methods' l of + Nothing -> throw TypeErr $ "Missing method: " ++ pprint l + Just _ -> return () + return $ ClassDictCon def params superclassHoles methods' + _ -> throw TypeErr $ "Not a valid instance: " ++ pprint ty + Pi (Abs b (arrow, bodyTy)) -> do + case arrow of + ImplicitArrow -> return () + ClassArrow -> return () + _ -> throw TypeErr $ "Not a valid arrow for an instance: " ++ pprint arrow + buildLam b arrow $ \x@(Var v) -> do + bodyTy' <- substEmbed (b@>x) bodyTy + checkLeaks [v] $ extendR (b@>x) $ checkInstance bodyTy' methods + _ -> throw TypeErr $ "Not a valid instance type: " ++ pprint ty + checkUEffRow :: EffectRow -> UInferM EffectRow checkUEffRow (EffectRow effs t) = do effs' <- liftM S.fromList $ mapM checkUEff $ toList effs diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index 976d86dba..9757b9fea 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -21,7 +21,6 @@ import Data.Foldable (toList) import qualified Data.List.NonEmpty as NE import qualified Data.Map.Strict as M import qualified Data.ByteString.Lazy.Char8 as B -import Data.Maybe (fromMaybe) import Data.String (fromString) import Data.Text.Prettyprint.Doc.Render.Text import Data.Text.Prettyprint.Doc @@ -33,7 +32,6 @@ import Numeric import Env import Syntax -import Util (enumerate) -- Specifies what kinds of operations are allowed to be printed at this point. -- Printing at AppPrec level means that applications can be printed @@ -364,7 +362,7 @@ instance PrettyPrec Atom where "DataConRef" <+> p params <+> p args BoxedRef b ptr size body -> atPrec AppPrec $ "Box" <+> p b <+> "<-" <+> p ptr <+> "[" <> p size <> "]" <+> hardline <> "in" <+> p body - ProjectElt idxs x -> prettyProjection idxs x + ProjectElt idxs x -> atPrec LowestPrec $ "project" <+> p idxs <+> p x instance Pretty DataConRefBinding where pretty = prettyFromPrettyPrec instance PrettyPrec DataConRefBinding where @@ -376,45 +374,6 @@ fromInfix t = do (t'', ')') <- unsnoc t' return t'' -prettyProjection :: NE.NonEmpty Int -> Var -> DocPrec ann -prettyProjection idxs (name :> ty) = prettyPrec uproj where - -- Builds a source expression that performs the given projection. - uproj = UApp (PlainArrow ()) (nosrc ulam) (nosrc uvar) - ulam = ULam (upat, Nothing) (PlainArrow ()) (nosrc $ UVar $ target :> ()) - uvar = UVar $ name :> () - (_, upat, target) = buildProj idxs - - buildProj :: NE.NonEmpty Int -> (Type, UPat, Name) - buildProj (i NE.:| is) = let - -- Lazy Haskell trick: refer to `target` even though this function is - -- responsible for setting it! - (ty', pat', eltName) = case NE.nonEmpty is of - Just is' -> let (x, y, z) = buildProj is' in (x, y, Just z) - Nothing -> (ty, nosrc $ UPatBinder $ Bind $ target :> (), Nothing) - in case ty' of - TypeCon def params -> let - [DataConDef conName bs] = applyDataDefParams def params - b = toList bs !! i - pats = (\(j,_)-> if i == j then pat' else uignore) <$> enumerate bs - hint = case b of - Bind (n :> _) -> n - Ignore _ -> Name SourceName "elt" 0 - in ( binderAnn b, nosrc $ UPatCon conName pats, fromMaybe hint eltName) - RecordTy (NoExt types) -> let - ty'' = toList types !! i - pats = (\(j,_)-> if i == j then pat' else uignore) <$> enumerate types - (fieldName, _) = toList (reflectLabels types) !! i - hint = Name SourceName (fromString fieldName) 0 - in (ty'', nosrc $ UPatRecord $ NoExt pats, fromMaybe hint eltName) - PairTy x _ | i == 0 -> - (x, nosrc $ UPatPair pat' uignore, fromMaybe "a" eltName) - PairTy _ y | i == 1 -> - (y, nosrc $ UPatPair uignore pat', fromMaybe "b" eltName) - _ -> error "Bad projection" - - nosrc = WithSrc Nothing - uignore = nosrc $ UPatBinder $ Ignore () - prettyExtLabeledItems :: (PrettyPrec a, PrettyPrec b) => ExtLabeledItems a b -> Doc ann -> Doc ann -> DocPrec ann prettyExtLabeledItems (Ext (LabeledItems row) rest) separator bindwith = @@ -629,6 +588,10 @@ instance Pretty UDecl where align $ prettyUBinder b <+> "=" <> (nest 2 $ group $ line <> pLowest rhs) pretty (UData tyCon dataCons) = "data" <+> p tyCon <+> "where" <> nest 2 (hardline <> prettyLines dataCons) + pretty (UInterface cs def methods) = + "interface" <+> p cs <+> p def <> hardline <> prettyLines methods + pretty (UInstance ty methods) = + "instance" <+> p ty <> hardline <> prettyLines methods instance Pretty UConDef where pretty (UConDef con bs) = p con <+> spaced bs diff --git a/src/lib/Parser.hs b/src/lib/Parser.hs index 6eab3f678..62dc532f7 100644 --- a/src/lib/Parser.hs +++ b/src/lib/Parser.hs @@ -12,11 +12,11 @@ import Control.Monad import Control.Monad.Combinators.Expr import Control.Monad.Reader import Text.Megaparsec hiding (Label, State) -import Text.Megaparsec.Char hiding (space) +import Text.Megaparsec.Char hiding (space, eol) +import qualified Text.Megaparsec.Char as MC import Data.Char (isLower) import Data.Functor import Data.List.NonEmpty (NonEmpty (..)) -import qualified Data.Map.Strict as M import Data.Void import qualified Data.Set as S import Data.String (fromString) @@ -88,7 +88,7 @@ logLevel :: Parser LogLevel logLevel = do void $ try $ lexeme $ char '%' >> string "passes" passes <- many passName - void eol + eol case passes of [] -> return $ LogAll _ -> return $ LogPasses passes @@ -96,14 +96,14 @@ logLevel = do logTime :: Parser LogLevel logTime = do void $ try $ lexeme $ char '%' >> string "time" - void eol + eol return PrintEvalTime logBench :: Parser LogLevel logBench = do void $ try $ lexeme $ char '%' >> string "bench" benchName <- stringLiteral - void eol + eol return $ PrintBench benchName passName :: Parser PassName @@ -116,13 +116,15 @@ sourceBlock' :: Parser SourceBlock' sourceBlock' = proseBlock <|> topLevelCommand - <|> fmap (declsToModule . (:[])) (topDecl <* eolf) - <|> fmap (declsToModule . (:[])) (interfaceInstance <* eolf) - <|> fmap declsToModule (interfaceDef <* eolf) - <|> fmap (Command (EvalExpr Printed) . exprAsModule) (expr <* eol) + <|> liftM declToModule (topDecl <* eolf) + <|> liftM declToModule (instanceDef <* eolf) + <|> liftM declToModule (interfaceDef <* eolf) + <|> liftM (Command (EvalExpr Printed) . exprAsModule) (expr <* eol) <|> hidden (some eol >> return EmptyLines) <|> hidden (sc >> eol >> return CommentLine) - where declsToModule = RunModule . UModule . toNest + where + declsToModule = RunModule . UModule . toNest + declToModule = declsToModule . (:[]) proseBlock :: Parser SourceBlock' proseBlock = label "prose block" $ char '\'' >> fmap (ProseBlock . fst) (withSource consumeTillBreak) @@ -151,7 +153,7 @@ exprAsModule :: UExpr -> (Name, UModule) exprAsModule e = (asGlobal v, UModule (toNest [d])) where v = mkName "_ans_" - d = ULet PlainLet (WithSrc (srcPos e) (UPatBinder (Bind (v:>()))), Nothing) e + d = ULet PlainLet (WithSrc (srcPos e) (nameToPat v), Nothing) e -- === uexpr === @@ -206,8 +208,7 @@ charExpr :: Char -> UExpr' charExpr c = UPrimExpr $ ConExpr $ Lit $ Word8Lit $ fromIntegral $ fromEnum c uVarOcc :: Parser UExpr -uVarOcc = withSrc $ try $ (UVar . (:>())) <$> (occName <* notFollowedBy (sym ":")) - where occName = upperName <|> lowerName <|> symName +uVarOcc = withSrc $ try $ (UVar . (:>())) <$> (anyName <* notFollowedBy (sym ":")) uHole :: Parser UExpr uHole = withSrc $ underscore $> UHole @@ -222,9 +223,9 @@ topDecl = dataDef <|> topLet topLet :: Parser UDecl topLet = do - lAnn <- (char '@' >> letAnnStr <* (void eol <|> sc)) <|> return PlainLet - ~(ULet _ (p, ann) rhs, pos) <- withPos decl - let (ann', rhs') = addImplicitImplicitArgs pos ann rhs + lAnn <- (char '@' >> letAnnStr <* (eol <|> sc)) <|> return PlainLet + ~(ULet _ (p, ann) rhs) <- decl + let (ann', rhs') = addImplicitImplicitArgs ann rhs return $ ULet lAnn (p, ann') rhs' -- Given a type signature, find all "implicit implicit args": lower-case @@ -273,77 +274,36 @@ findImplicitImplicitArgNames typ = filter isLowerCaseName $ envNames $ UIntLit _ -> mempty UFloatLit _ -> mempty -addImplicitImplicitArgs :: SrcPos -> Maybe UType -> UExpr -> (Maybe UType, UExpr) -addImplicitImplicitArgs _ Nothing e = (Nothing, e) -addImplicitImplicitArgs sourcePos (Just typ) ex = - let (ty', e') = foldr (addImplicitArg sourcePos) (typ, ex) implicitVars +addImplicitImplicitArgs :: Maybe UType -> UExpr -> (Maybe UType, UExpr) +addImplicitImplicitArgs Nothing e = (Nothing, e) +addImplicitImplicitArgs (Just typ) ex = + let (ty', e') = foldr addImplicitArg (typ, ex) implicitVars in (Just ty', e') where implicitVars = findImplicitImplicitArgNames typ - addImplicitArg :: SrcPos -> Name -> (UType, UExpr) -> (UType, UExpr) - addImplicitArg pos v (ty, e) = - ( WithSrc (Just pos) $ UPi (Just uPat, uTyKind) ImplicitArrow ty - , WithSrc (Just pos) $ ULam (uPat, Just uTyKind) ImplicitArrow e) + addImplicitArg :: Name -> (UType, UExpr) -> (UType, UExpr) + addImplicitArg v (ty, e) = + ( ns $ UPi (Just uPat, uTyKind) ImplicitArrow ty + , ns $ ULam (uPat, Just uTyKind) ImplicitArrow e) where - uPat = WithSrc (Just pos) $ UPatBinder $ Bind $ v:>() + uPat = ns $ nameToPat v k = if v == mkName "eff" then EffectRowKind else TypeKind - uTyKind = WithSrc (Just pos) $ UPrimExpr $ TCExpr k + uTyKind = ns $ UPrimExpr $ TCExpr k + +superclassConstraints :: Parser [UType] +superclassConstraints = optionalMonoid $ brackets $ uType `sepBy` sym "," -interfaceDef :: Parser [UDecl] +interfaceDef :: Parser UDecl interfaceDef = do keyWord InterfaceKW - (tyCon, pos) <- withPos tyConDef - keyWord WhereKW - recordFieldsWithSrc <- withSrc $ interfaceRecordFields ":" - let (UConDef interfaceName uAnnBinderNest) = tyCon - record = URecordTy . NoExt <$> recordFieldsWithSrc - consName = mkInterfaceConsName interfaceName - varNames = fmap (\(Bind v) -> varName v) uAnnBinderNest - (WithSrc _ recordFields) = recordFieldsWithSrc - funDefs = mkFunDefs (pos, varNames, interfaceName) recordFields - return $ UData tyCon [UConDef consName (toNest [Ignore record])] : funDefs - where - -- From an interface - -- interface I a:Type b:Type where - -- f : a -> b - -- mkFunDefs generates the equivalent of the following function definition: - -- def f (instance# : I a b) ?=> : a -> b = - -- (I# {f=f,...}) = instance# - -- f - -- where I# is an automatically generated constructor of I. - mkFunDefs - :: (SrcPos, Nest Name, Name) -> LabeledItems UExpr -> [UDecl] - mkFunDefs meta (LabeledItems items) = - fmap (\(name, ty :| []) -> mkOneFunDef meta (name, ty)) $ M.toList items - mkOneFunDef :: (SrcPos, Nest Name, Name) -> (Label, UExpr) -> UDecl - mkOneFunDef (pos, typeVarNames, interfaceName) (fLabel, fType) = - ULet PlainLet (p, ann') rhs' - where - uAnnPat = ( Just $ WithSrc (Just pos) $ UPatBinder $ Bind $ instanceName :> () - , foldl mkUApp (var interfaceName) typeVarNames) - p = patb fLabel - ann = Just $ ns $ UPi uAnnPat ClassArrow fType - - mkUApp func typeVarName = - ns $ UApp (PlainArrow ()) func (var typeVarName) - recordStr = "recordVar" - recordPat = ns $ UPatRecord $ Ext (labeledSingleton fLabel (patb - fLabel)) $ Just underscorePat - conPat = ns $ UPatCon (mkInterfaceConsName interfaceName) - $ toNest [patb recordStr] - - let1 = ULet PlainLet (conPat, Nothing) $ var instanceName - let2 = ULet PlainLet (recordPat, Nothing) $ var $ mkName recordStr - body = ns $ UDecl let1 (ns $ UDecl let2 (var (mkName fLabel))) - rhs = ns $ ULam (patb instanceStr, Nothing) ClassArrow body - (ann', rhs') = addImplicitImplicitArgs pos ann rhs - - ns = WithSrc Nothing - patb s = ns $ UPatBinder $ Bind $ mkName s :> () - instanceStr = mkNoShadowingStr "instance" - instanceName = mkName instanceStr - var name = ns $ UVar $ name :> () + superclasses <- superclassConstraints + tyCon <- tyConDef + methods <- onePerLine $ do + v <- anyName + ty <- annot uType + return $ Bind $ v:>ty + return $ UInterface superclasses tyCon methods dataDef :: Parser UDecl dataDef = do @@ -353,9 +313,15 @@ dataDef = do dataCons <- onePerLine dataConDef return $ UData tyCon dataCons --- TODO: default to `Type` if unannoted tyConDef :: Parser UConDef -tyConDef = UConDef <$> (upperName <|> symName) <*> manyNested namedBinder +tyConDef = do + con <- upperName <|> symName + bs <- manyNested $ label "type constructor parameter" $ do + v <- lowerName + ty <- annot containedExpr <|> return tyKind + return $ Bind $ v :> ty + return $ UConDef con bs + where tyKind = ns $ UPrimExpr $ TCExpr TypeKind -- TODO: dependent types dataConDef :: Parser UConDef @@ -370,52 +336,32 @@ decl = do rhs <- sym "=" >> blockOrExpr return $ lhs rhs -interfaceInstance :: Parser UDecl -interfaceInstance = do +instanceDef :: Parser UDecl +instanceDef = do keyWord InstanceKW - (p, pos) <- withPos letPat - ann <- annot uType - case mkConstructorNameVar ann of - Left err -> fail err - Right constructorNameVar -> do - keyWord WhereKW - record <- withSrc $ (URecord . NoExt) <$> interfaceRecordFields "=" - let constructorCall = constructorNameVar `mkApp` record - (ann', rhs') = addImplicitImplicitArgs pos (Just ann) constructorCall - return $ ULet InstanceLet (p, ann') rhs' + explicitArgs <- many defArg + constraints <- classConstraints + classTy <- uType + let ty = buildPiType explicitArgs Pure $ + foldr addClassConstraint classTy constraints + let ty' = foldr addImplicitArg ty $ findImplicitImplicitArgNames ty + methods <- onePerLine instanceMethod + return $ UInstance ty' methods where - -- Here, we are traversing the type annotation to retrieve the name of - -- the interface and generate its corresponding constructor. A valid type - -- annotation for an instance is composed of: - -- 1) implicit/class arguments - -- 2) a function whose name is the name of the interface applied to 0 or - -- more arguments - mkConstructorNameVar ann = - stripArrows ann >>= stripAppliedArgs >>= buildConstructor - - stripArrows (WithSrc _ (UPi _ arr typ)) - | arr `elem` [ClassArrow, ImplicitArrow] = stripArrows typ - | otherwise = Left ("Met invalid arrow '" ++ pprint arr ++ "' in type " ++ - "annotation of instance. Only class arrows and " ++ - "implicit arrows are allowed.") - stripArrows ann = Right ann - - stripAppliedArgs ann - | (WithSrc _ (UApp _ func _)) <- ann = stripAppliedArgs func - | otherwise = Right ann - - buildConstructor (WithSrc _ (UVar v)) = - Right $ (var . nameToStr . mkInterfaceConsName . varName) v - buildConstructor _ = Left ("Could not extract interface name from type " ++ - "annotation.") - var s = noSrc $ UVar $ mkName s :> () - -interfaceRecordFields :: String -> Parser (LabeledItems UExpr) -interfaceRecordFields bindwith = - fuse <$> onePerLine (do l <- fieldLabel - e <- symbol bindwith *> expr - return $ labeledSingleton l e) - where fuse = foldr (<>) NoLabeledItems + addClassConstraint :: UType -> UType -> UType + addClassConstraint c ty = ns $ UPi (Nothing, c) ClassArrow ty + + addImplicitArg :: Name -> UType -> UType + addImplicitArg v ty = + ns $ UPi (Just (ns $ nameToPat v), uTyKind) ImplicitArrow ty + where uTyKind = ns $ UPrimExpr $ TCExpr TypeKind + +instanceMethod :: Parser (UVar, UExpr) +instanceMethod = do + v <- anyName + sym "=" + rhs <- blockOrExpr + return (v:>(), rhs) simpleLet :: Parser (UExpr -> UDecl) simpleLet = label "let binding" $ do @@ -424,14 +370,14 @@ simpleLet = label "let binding" $ do return $ ULet PlainLet (p, ann) letPat :: Parser UPat -letPat = nameAsPat $ upperName <|> lowerName <|> symName +letPat = withSrc $ nameToPat <$> anyName funDefLet :: Parser (UExpr -> UDecl) funDefLet = label "function definition" $ mayBreak $ do keyWord DefKW v <- letPat - cs <- defClassConstraints - argBinders <- many arg + cs <- classConstraints + argBinders <- many defArg (eff, ty) <- label "result type annotation" $ annot effectiveType when (null argBinders && eff /= Pure) $ fail "Nullary def can't have effects" let bs = map classAsBinder cs ++ argBinders @@ -441,22 +387,17 @@ funDefLet = label "function definition" $ mayBreak $ do return $ \body -> ULet PlainLet letBinder (buildLam lamBinders body) where classAsBinder :: UType -> (UPat, UType, UArrow) - classAsBinder ty = (underscorePat, ty, ClassArrow) + classAsBinder ty = (ns underscorePat, ty, ClassArrow) - arg :: Parser (UPat, UType, UArrow) - arg = label "def arg" $ do - (p, ty) <-parens ((,) <$> pat <*> annot uType) - arr <- arrow (return ()) <|> return (PlainArrow ()) - return (p, ty, arr) +defArg :: Parser (UPat, UType, UArrow) +defArg = label "def arg" $ do + (p, ty) <-parens ((,) <$> pat <*> annot uType) + arr <- arrow (return ()) <|> return (PlainArrow ()) + return (p, ty, arr) -defClassConstraints :: Parser [UType] -defClassConstraints = - (brackets $ mayNotPair $ uType `sepBy` sym ",") - <|> return [] - "class constraints" - -nameAsPat :: Parser Name -> Parser UPat -nameAsPat p = withSrc $ (UPatBinder . Bind . (:>())) <$> p +classConstraints :: Parser [UType] +classConstraints = label "class constraints" $ + optionalMonoid $ brackets $ mayNotPair $ uType `sepBy` sym "," buildPiType :: [(UPat, UType, UArrow)] -> EffectRow -> UType -> UType buildPiType [] Pure ty = ty @@ -531,18 +472,21 @@ uForExpr = do <|> (keyWord Rof_KW $> (Rev, True )) e <- buildFor pos dir <$> (some patAnn <* argTerm) <*> blockOrExpr if trailingUnit - then return $ noSrc $ UDecl (ULet PlainLet (underscorePat, Nothing) e) $ - noSrc unitExpr + then return $ ns $ UDecl (ULet PlainLet (ns underscorePat, Nothing) e) $ + ns unitExpr else return e -underscorePat :: UPat -underscorePat = noSrc $ UPatBinder $ Ignore () +underscorePat :: UPat' +underscorePat = UPatBinder $ Ignore () + +nameToPat :: Name -> UPat' +nameToPat v = UPatBinder (Bind (v:>())) unitExpr :: UExpr' unitExpr = UPrimExpr $ ConExpr UnitCon -noSrc :: a -> WithSrc a -noSrc = WithSrc Nothing +ns :: a -> WithSrc a +ns = WithSrc Nothing blockOrExpr :: Parser UExpr blockOrExpr = block <|> expr @@ -570,7 +514,7 @@ wrapUStatements statements = case statements of (s, pos):rest -> WithSrc (Just pos) $ case s of Left d -> UDecl d $ wrapUStatements rest Right e -> UDecl d $ wrapUStatements rest - where d = ULet PlainLet (underscorePat, Nothing) e + where d = ULet PlainLet (ns underscorePat, Nothing) e [] -> error "Shouldn't be reachable" uStatement :: Parser UStatement @@ -584,16 +528,17 @@ uPiType = withSrc $ UPi <$> piBinderPat <*> arrow effects <*> uType b <- annBinder return $ case b of Bind (n:>a@(WithSrc pos _)) -> - (Just $ WithSrc pos $ UPatBinder $ Bind $ n:>(), a) + (Just $ WithSrc pos $ nameToPat n, a) Ignore a -> (Nothing, a) annBinder :: Parser UAnnBinder annBinder = try $ namedBinder <|> anonBinder namedBinder :: Parser UAnnBinder -namedBinder = label "named annoted binder" $ lowerName - >>= \v -> sym ":" >> containedExpr - >>= \ty -> return $ Bind (v:>ty) +namedBinder = label "named annoted binder" $ do + v <- lowerName + ty <- annot containedExpr + return $ Bind (v:>ty) anonBinder :: Parser UAnnBinder anonBinder = @@ -622,7 +567,7 @@ ifExpr = withSrc $ do e <- expr (alt1, maybeAlt2) <- oneLineThenElse <|> blockThenElse let alt2 = case maybeAlt2 of - Nothing -> noSrc unitExpr + Nothing -> ns unitExpr Just alt -> alt return $ UCase e [ UAlt (globalEnumPat "True" ) alt1 @@ -647,7 +592,7 @@ blockThenElse = withIndent $ mayNotBreak $ do return (alt1, alt2) globalEnumPat :: Tag -> UPat -globalEnumPat s = noSrc $ UPatCon (GlobalName s) Empty +globalEnumPat s = ns $ UPatCon (GlobalName s) Empty onePerLine :: Parser a -> Parser [a] onePerLine p = liftM (:[]) p @@ -667,8 +612,8 @@ leafPat = <|> (variantPat `fallBackTo` recordPat) <|> brackets (UPatTable <$> leafPat `sepBy` sym ",") ) - where pun pos l = WithSrc (Just pos) $ UPatBinder $ Bind (mkName l:>()) - def pos = WithSrc (Just pos) $ UPatBinder $ Ignore () + where pun pos l = WithSrc (Just pos) $ nameToPat $ mkName l + def pos = WithSrc (Just pos) $ underscorePat variantPat = parseVariant leafPat UPatVariant UPatVariantLift recordPat = UPatRecord <$> parseLabeledItems "," "=" leafPat (Just pun) (Just def) @@ -741,9 +686,8 @@ uIsoSugar = withSrc (char '#' *> options) where <|> char '?' *> (variantFieldIso <$> fieldLabel) <|> char '&' *> (recordZipIso <$> fieldLabel) <|> char '|' *> (variantZipIso <$> fieldLabel) - ns = WithSrc Nothing var s = ns $ UVar $ mkName s :> () - patb s = ns $ UPatBinder $ Bind $ mkName s :> () + patb s = ns $ nameToPat $ mkName s plain = PlainArrow () lam p b = ns $ ULam (p, Nothing) plain b recordFieldIso field = @@ -1020,19 +964,6 @@ inpostfix' p op = Postfix $ do mkName :: String -> Name mkName s = Name SourceName (fromString s) 0 -nameToStr :: Name -> String -nameToStr = tagToStr . nameTag - --- This function is used to generate a string that is guaranteed to never shadow --- any user-defined name, as "#" is an invalid identifier character in normal --- source code. -mkNoShadowingStr :: String -> String -mkNoShadowingStr = (++ "#") - -mkInterfaceConsName :: Name -> Name -mkInterfaceConsName = - GlobalName . fromString . mkNoShadowingStr . nameToStr - -- === lexemes === -- These `Lexer` actions must be non-overlapping and never consume input on failure @@ -1054,6 +985,9 @@ lowerName = liftM mkName $ label "lower-case name" $ lexeme $ anyCaseName :: Lexer Name anyCaseName = lowerName <|> upperName +anyName :: Lexer Name +anyName = lowerName <|> upperName <|> symName + checkNotKeyword :: Parser String -> Parser String checkNotKeyword p = try $ do s <- p @@ -1202,6 +1136,9 @@ mayPair p = local (\ctx -> ctx { canPair = True }) p mayNotPair :: Parser a -> Parser a mayNotPair p = local (\ctx -> ctx { canPair = False }) p +optionalMonoid :: Monoid a => Parser a -> Parser a +optionalMonoid p = p <|> return mempty + nameString :: Parser String nameString = lexeme . try $ (:) <$> lowerChar <*> many alphaNumChar @@ -1244,7 +1181,7 @@ withPos p = do nextLine :: Parser () nextLine = do - void eol + eol n <- asks curIndent void $ mayNotBreak $ many $ try (sc >> eol) void $ replicateM n (char ' ') @@ -1261,8 +1198,11 @@ withIndent p = do indent <- liftM length $ some (char ' ') local (\ctx -> ctx { curIndent = curIndent ctx + indent }) $ p +eol :: Parser () +eol = void MC.eol + eolf :: Parser () -eolf = void eol <|> eof +eolf = eol <|> eof failIf :: Bool -> String -> Parser () failIf True s = fail s diff --git a/src/lib/Syntax.hs b/src/lib/Syntax.hs index 277d94693..98a303617 100644 --- a/src/lib/Syntax.hs +++ b/src/lib/Syntax.hs @@ -13,6 +13,7 @@ {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ViewPatterns #-} {-# LANGUAGE Rank2Types #-} +{-# LANGUAGE LambdaCase #-} module Syntax ( Type, Kind, BaseType (..), ScalarBaseType (..), @@ -29,8 +30,9 @@ module Syntax ( IExpr (..), IVal, ImpInstr (..), Backend (..), Device (..), IPrimOp, IVar, IBinder, IType, SetVal (..), MonMap (..), LitProg, IFunType (..), IFunVar, CallingConvention (..), IsCUDARequired (..), - UAlt (..), AltP, Alt, Label, LabeledItems (..), labeledSingleton, - reflectLabels, withLabels, ExtLabeledItems (..), prefixExtLabeledItems, + UAlt (..), AltP, Alt, Label, LabeledItems (..), labeledSingleton, lookupLabel, + reflectLabels, withLabels, ExtLabeledItems (..), + prefixExtLabeledItems, getLabels, IScope, BinderInfo (..), Bindings, CUDAKernel (..), BenchStats, SrcCtx, Result (..), Output (..), OutFormat (..), Err (..), ErrType (..), Except, throw, throwIf, modifyErr, addContext, @@ -61,7 +63,9 @@ module Syntax ( pattern TabTy, pattern TabTyAbs, pattern TabVal, pattern TabValA, pattern Pure, pattern BinaryFunTy, pattern BinaryFunVal, pattern Unlabeled, pattern NoExt, pattern LabeledRowKind, - pattern NoLabeledItems, pattern InternalSingletonLabel, pattern EffKind) + pattern NoLabeledItems, pattern InternalSingletonLabel, pattern EffKind, + pattern NestOne, pattern NewTypeCon, pattern BinderAnn, + pattern ClassDictDef, pattern ClassDictCon) where import qualified Data.Map.Strict as M @@ -187,10 +191,18 @@ reflectLabels :: LabeledItems a -> LabeledItems (Label, Int) reflectLabels (LabeledItems items) = LabeledItems $ flip M.mapWithKey items $ \k xs -> fmap (\(i,_) -> (k,i)) (enumerate xs) +getLabels :: LabeledItems a -> [Label] +getLabels labeledItems = map fst $ toList $ reflectLabels labeledItems + withLabels :: LabeledItems a -> LabeledItems (Label, Int, a) withLabels (LabeledItems items) = LabeledItems $ flip M.mapWithKey items $ \k xs -> fmap (\(i,a) -> (k,i,a)) (enumerate xs) +lookupLabel :: LabeledItems a -> Label -> Maybe a +lookupLabel (LabeledItems items) l = case M.lookup l items of + Nothing -> Nothing + Just (x NE.:| _) -> Just x + instance Semigroup (LabeledItems a) where LabeledItems items <> LabeledItems items' = LabeledItems $ M.unionWith (<>) items items' @@ -237,6 +249,8 @@ data UExpr' = UVar UVar data UConDef = UConDef Name (Nest UAnnBinder) deriving (Show, Generic) data UDecl = ULet LetAnn UPatAnn UExpr | UData UConDef [UConDef] + | UInterface [UType] UConDef [UAnnBinder] + | UInstance UType [(UVar, UExpr)] deriving (Show, Generic) type UType = UExpr @@ -784,11 +798,15 @@ instance BindsUVars UPat' where instance HasUVars UDecl where freeUVars (ULet _ p expr) = freeUVars p <> freeUVars expr freeUVars (UData (UConDef _ bs) dataCons) = freeUVars $ Abs bs dataCons + freeUVars (UInterface _ _ _) = mempty -- TODO + freeUVars (UInstance _ _) = mempty -- TODO instance BindsUVars UDecl where boundUVars decl = case decl of - ULet _ (p,_) _ -> boundUVars p - UData tyCon dataCons -> boundUVars tyCon <> foldMap boundUVars dataCons + ULet _ (p,_) _ -> boundUVars p + UData tyCon dataCons -> boundUVars tyCon <> foldMap boundUVars dataCons + UInterface _ _ _ -> mempty + UInstance _ _ -> mempty instance HasUVars UModule where freeUVars (UModule decls) = freeUVars decls @@ -1005,8 +1023,9 @@ applyNaryAbs (Abs (Nest b bs) body) (x:xs) = applyNaryAbs ab xs applyNaryAbs _ _ = error "wrong number of arguments" applyDataDefParams :: DataDef -> [Type] -> [DataConDef] -applyDataDefParams (DataDef _ paramBs cons) params = - applyNaryAbs (Abs paramBs cons) params +applyDataDefParams (DataDef _ bs cons) params + | length params == length (toList bs) = applyNaryAbs (Abs bs cons) params + | otherwise = error $ "Wrong number of parameters: " ++ show (length params) makeAbs :: HasVars a => Binder -> a -> Abs Binder a makeAbs b body | b `isin` freeVars body = Abs b body @@ -1510,6 +1529,30 @@ pattern NothingAtom ty = DataCon MaybeDataDef [ty] 0 [] pattern JustAtom :: Type -> Atom -> Atom pattern JustAtom ty x = DataCon MaybeDataDef [ty] 1 [x] +pattern NestOne :: a -> Nest a +pattern NestOne x = Nest x Empty + +pattern BinderAnn :: a -> BinderP a +pattern BinderAnn x <- ((\case Ignore ann -> ann + Bind (_:>ann) -> ann) -> x) + where BinderAnn x = Ignore x + +pattern NewTypeCon :: Name -> Type -> [DataConDef] +pattern NewTypeCon con ty <- [DataConDef con (NestOne (BinderAnn ty))] + where NewTypeCon con ty = [DataConDef con (NestOne (Ignore ty))] + +pattern ClassDictDef :: Name + -> LabeledItems Type -> LabeledItems Type -> [DataConDef] +pattern ClassDictDef conName superclasses methods = + [DataConDef conName + (Nest (Ignore (RecordTy (NoExt superclasses))) + (Nest (Ignore (RecordTy (NoExt methods))) Empty))] + +pattern ClassDictCon :: DataDef -> [Type] + -> LabeledItems Atom -> LabeledItems Atom -> Atom +pattern ClassDictCon def params superclasses methods = + DataCon def params 0 [Record superclasses, Record methods] + -- TODO: Enable once https://gitlab.haskell.org//ghc/ghc/issues/13363 is fixed... -- {-# COMPLETE TypeVar, ArrowType, TabTy, Forall, TypeAlias, Effect, NoAnn, TC #-} diff --git a/src/lib/Type.hs b/src/lib/Type.hs index 4dd84a817..39af39287 100644 --- a/src/lib/Type.hs +++ b/src/lib/Type.hs @@ -92,7 +92,7 @@ checkBindings env ir bs = void $ runTypeCheck (CheckWith (env <> bs, Pure)) $ mapM_ (checkBinding ir) $ envPairs bs checkBinding :: IRVariant -> (Name, (Type, BinderInfo)) -> TypeM () -checkBinding ir (GlobalName v, b@(ty, info)) = +checkBinding ir (v, b@(ty, info)) | isGlobal (v:>()) = addContext ("binding: " ++ pprint (v, b)) $ do ty |: TyKind when (ir >= Evaluated && not (all isGlobal (envAsVars $ freeVars b))) $ @@ -165,8 +165,8 @@ instance HasType Atom where withBinder b $ typeCheck body ProjectElt (i NE.:| is) v -> do ty <- typeCheck $ case NE.nonEmpty is of - Nothing -> Var v - Just is' -> ProjectElt is' v + Nothing -> Var v + Just is' -> ProjectElt is' v case ty of TypeCon def params -> do [DataConDef _ bs'] <- return $ applyDataDefParams def params @@ -184,7 +184,8 @@ instance HasType Atom where PairTy x _ | i == 0 -> return x PairTy _ y | i == 1 -> return y Var _ -> throw CompilerErr $ "Tried to project value of unreduced type " <> pprint ty - _ -> throw TypeErr $ "Only single-member ADTs and record types can be projected. Got " <> pprint ty + _ -> throw TypeErr $ + "Only single-member ADTs and record types can be projected. Got " <> pprint ty <> " " <> pprint v checkDataConRefBindings :: Nest Binder -> Nest DataConRefBinding -> TypeM () diff --git a/tests/adt-tests.dx b/tests/adt-tests.dx index 1d2d2306e..97ad29dc7 100644 --- a/tests/adt-tests.dx +++ b/tests/adt-tests.dx @@ -216,7 +216,7 @@ def catLists (xs:List a) (ys:List a) : List a = def listToTable ((AsList n xs): List a) : (Fin n)=>a = xs :t listToTable -> ((a:Type) ?-> (pat:(List a)) -> (Fin ((\((AsList n _)). n) pat)) => a) +> ((a:Type) ?-> (pat:(List a)) -> (Fin (project [0] pat:(List a))) => a) :p l = AsList _ [1, 2, 3] @@ -228,7 +228,7 @@ def listToTable2 (l: List a) : (Fin (listLength l))=>a = xs :t listToTable2 -> ((a:Type) ?-> (l:(List a)) -> (Fin ((\((AsList n _)). n) l)) => a) +> ((a:Type) ?-> (l:(List a)) -> (Fin (project [0] l:(List a))) => a) :p l = AsList _ [1, 2, 3] @@ -258,7 +258,7 @@ def graphToAdjacencyMatrix ((MkGraph n nodes m edges):Graph a) : n=>n=>Bool = :t graphToAdjacencyMatrix > ((a:Type) > ?-> (pat:(Graph a)) -> -> ((\((MkGraph n _ _ _)). n) pat) => ((\((MkGraph n _ _ _)). n) pat) => Bool) +> -> (project [0] pat:(Graph a)) => (project [0] pat:(Graph a)) => Bool) :p g : Graph Int = MkGraph (Fin 3) [5, 6, 7] (Fin 4) [(0@_, 1@_), (0@_, 2@_), (2@_, 0@_), (1@_, 1@_)] @@ -269,15 +269,15 @@ def graphToAdjacencyMatrix ((MkGraph n nodes m edges):Graph a) : n=>n=>Bool = def pairUnpack ((v, _):(Int & Float)) : Int = v :p pairUnpack -> \pat:(Int32 & Float32). (\(a, _). a) pat +> \pat:(Int32 & Float32). project [0] pat:(Int32 & Float32) def adtUnpack ((MkMyPair v _):MyPair Int Float) : Int = v :p adtUnpack -> \pat:(MyPair Int32 Float32). (\((MkMyPair elt _)). elt) pat +> \pat:(MyPair Int32 Float32). project [0] pat:(MyPair Int32 Float32) def recordUnpack ({a=v, b=_}:{a:Int & b:Float}) : Int = v :p recordUnpack -> \pat:{a: Int32 & b: Float32}. (\{a = a, b = _}. a) pat +> \pat:{a: Int32 & b: Float32}. project [0] pat:{a: Int32 & b: Float32} def nestedUnpack (x:MyPair Int (MyPair (MyIntish & Int) Int)) : Int = (MkMyPair _ (MkMyPair (MkIntish y, _) _)) = x @@ -285,7 +285,7 @@ def nestedUnpack (x:MyPair Int (MyPair (MyIntish & Int) Int)) : Int = :p nestedUnpack > \x:(MyPair Int32 (MyPair (MyIntish & Int32) Int32)). -> (\((MkIntish (((MkMyPair ((MkMyPair _ elt)) _)), _))). elt) x +> project [0, 0, 0, 1] x:(MyPair Int32 (MyPair (MyIntish & Int32) Int32)) :p nestedUnpack (MkMyPair 3 (MkMyPair (MkIntish 4, 5) 6)) > 4 diff --git a/tests/io-tests.dx b/tests/io-tests.dx index 736c07ff4..853ee0027 100644 --- a/tests/io-tests.dx +++ b/tests/io-tests.dx @@ -38,7 +38,7 @@ unsafeIO \(). > 9 is odd > [(), (), (), (), (), (), (), (), (), ()] -:p storageSize (typeVehicle Int) +:p storageSize Int > 4 :p unsafeIO \(). diff --git a/tests/typeclass-tests.dx b/tests/typeclass-tests.dx index 5061d6ece..7970fbfcd 100644 --- a/tests/typeclass-tests.dx +++ b/tests/typeclass-tests.dx @@ -1,38 +1,43 @@ -interface InterfaceTest1 a:Type where + + +interface InterfaceTest1 a InterfaceTest1 : a > Error: variable already defined: InterfaceTest1 -interface InterfaceTest2 typeName:Type where - typeName : typeName -> typeName - -interface InterfaceTest3 _:Type where - foo : Int32 +interface InterfaceTest3 a + foo : a -> Int + foo : a -> Int +> Error: variable already defined: foo -> Parse error:8:26: -> | -> 8 | interface InterfaceTest3 _:Type where -> | ^^^^^ -> unexpected "_:Typ" -> expecting "where" or named annoted binder -interface InterfaceTest4 where +interface InterfaceTest4 a foo : Int + bar : a -> Int + +instance InterfaceTest4 Float + foo = 1 + bar = \_. 1 + foo = 1 +> Type error:Duplicate method: foo + +instance InterfaceTest4 Float + foo = 1 +> Type error:Missing method: bar + +instance InterfaceTest4 Float + baz = 1 +> Type error:baz is not a method of InterfaceTest4 -instance instanceTest4 : InterfaceTest4 where +instance InterfaceTest4 Float foo = 1 + bar = \_. 'x' +> Type error: +> Expected: Int32 +> Actual: Word8 +> +> bar = \_. 'x' +> ^^^ -instance instanceTest4 : InterfaceTest4 x -> InterfaceTest4 (n=>a) where +instance InterfaceTest4 Float foo = 1 + bar = \_. 1 -> Parse error:23:68: -> | -> 23 | instance instanceTest4 : InterfaceTest4 x -> InterfaceTest4 (n=>a) where -> | ^ -> Met invalid arrow '->' in type annotation of instance. Only class arrows and implicit arrows are allowed. -instance instanceTest5 : (..i) where - bar = bar - -> Parse error:31:32: -> | -> 31 | instance instanceTest5 : (..i) where -> | ^ -> Could not extract interface name from type annotation. From b1a65969cebe11ba7a2dc0cad6474fb9e646e9c2 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Mon, 4 Jan 2021 14:17:43 -0500 Subject: [PATCH 3/6] Enable BlockArguments by default (we use the bracketing pattern a lot) --- dex.cabal | 8 +++-- src/dex.hs | 4 +-- src/lib/Autodiff.hs | 50 ++++++++++++++--------------- src/lib/Cat.hs | 6 ++-- src/lib/Embed.hs | 34 ++++++++++---------- src/lib/Imp.hs | 60 +++++++++++++++++------------------ src/lib/Inference.hs | 56 ++++++++++++++++---------------- src/lib/JIT.hs | 16 +++++----- src/lib/LLVM/JIT.hs | 6 ++-- src/lib/LLVM/Shims.hs | 10 +++--- src/lib/LLVMExec.hs | 72 +++++++++++++++++++++--------------------- src/lib/Logging.hs | 6 ++-- src/lib/Parallelize.hs | 18 +++++------ src/lib/Parser.hs | 10 +++--- src/lib/Serialize.hs | 4 +-- src/lib/Simplify.hs | 46 +++++++++++++-------------- src/lib/Syntax.hs | 16 ++++------ src/lib/TopLevel.hs | 8 ++--- src/lib/Type.hs | 44 +++++++++++++------------- 19 files changed, 237 insertions(+), 237 deletions(-) diff --git a/dex.cabal b/dex.cabal index a4452d2cb..926c5314e 100644 --- a/dex.cabal +++ b/dex.cabal @@ -61,7 +61,8 @@ library cxx-sources: src/lib/dexrt.cpp cxx-options: -std=c++11 -fPIC default-extensions: CPP, DeriveTraversable, TypeApplications, OverloadedStrings, - TupleSections, ScopedTypeVariables, LambdaCase, PatternSynonyms + TupleSections, ScopedTypeVariables, LambdaCase, PatternSynonyms, + BlockArguments pkgconfig-depends: libpng if flag(cuda) include-dirs: /usr/local/cuda/include @@ -82,7 +83,7 @@ executable dex build-depends: dex-resources default-language: Haskell2010 hs-source-dirs: src - default-extensions: CPP, LambdaCase + default-extensions: CPP, LambdaCase, BlockArguments ghc-options: -threaded if flag(optimized) ghc-options: -O3 @@ -101,7 +102,8 @@ foreign-library Dex cc-options: -std=c11 -fPIC ghc-options: -Wall -fPIC -optP-Wno-nonportable-include-path default-language: Haskell2010 - default-extensions: TypeApplications, ScopedTypeVariables, LambdaCase + default-extensions: TypeApplications, ScopedTypeVariables, LambdaCase, + BlockArguments if flag(optimized) ghc-options: -O3 else diff --git a/src/dex.hs b/src/dex.hs index f08f56c25..50b84a98f 100644 --- a/src/dex.hs +++ b/src/dex.hs @@ -109,7 +109,7 @@ printLitProg TextDoc prog = do isatty <- queryTerminal stdOutput putStr $ foldMap (uncurry (printLitBlock isatty)) prog printLitProg JSONDoc prog = - forM_ prog $ \(_, result) -> case toJSONStr result of + forM_ prog \(_, result) -> case toJSONStr result of "{}" -> return () s -> putStrLn s @@ -146,7 +146,7 @@ parseMode = subparser $ objectFileInfo = argument str (metavar "OBJFILE" <> help "Output path (.o file)") optionList :: [(String, a)] -> ReadM a -optionList opts = eitherReader $ \s -> case lookup s opts of +optionList opts = eitherReader \s -> case lookup s opts of Just x -> Right x Nothing -> Left $ "Bad option. Expected one of: " ++ show (map fst opts) diff --git a/src/lib/Autodiff.hs b/src/lib/Autodiff.hs index d48837dfe..8a6dbd964 100644 --- a/src/lib/Autodiff.hs +++ b/src/lib/Autodiff.hs @@ -44,10 +44,10 @@ newtype LinA a = LinA { runLinA :: PrimalM (a, TangentM a) } linearize :: Scope -> Atom -> Atom linearize scope ~(Lam (Abs b (_, block))) = fst $ flip runEmbed scope $ do - buildLam b PureArrow $ \x@(Var v) -> do + buildLam b PureArrow \x@(Var v) -> do (y, yt) <- flip runReaderT (DerivWrt (varAsEnv v) [] mempty) $ runLinA $ linearizeBlock (b@>x) block -- TODO: check linearity - fLin <- buildLam (fmap tangentType b) LinArrow $ \xt -> runReaderT yt $ TangentEnv (v @> xt) [] mempty + fLin <- buildLam (fmap tangentType b) LinArrow \xt -> runReaderT yt $ TangentEnv (v @> xt) [] mempty fLinChecked <- checkEmbed fLin return $ PairVal y fLinChecked @@ -109,7 +109,7 @@ linearizeExpr env expr = case expr of return (ans, applyLinToTangents linLam) where linearizeInactiveAlt (Abs bs body) = do - buildNAbs bs $ \xs -> tangentFunAsLambda $ linearizeBlock (env <> newEnv bs xs) body + buildNAbs bs \xs -> tangentFunAsLambda $ linearizeBlock (env <> newEnv bs xs) body _ -> LinA $ do expr' <- substEmbed env expr runLinA $ case expr' of @@ -255,10 +255,10 @@ linearizeHof :: SubstEnv -> Hof -> LinA Atom linearizeHof env hof = case hof of For ~(RegularFor d) ~(LamVal i body) -> LinA $ do i' <- mapM (substEmbed env) i - (ansWithLinTab, vi'') <- buildForAux d i' $ \i''@(Var vi'') -> + (ansWithLinTab, vi'') <- buildForAux d i' \i''@(Var vi'') -> (,vi'') <$> (willRemat vi'' $ tangentFunAsLambda $ linearizeBlock (env <> i@>i'') body) (ans, linTab) <- unzipTab ansWithLinTab - return (ans, buildFor d i' $ \i'' -> provideRemat vi'' i'' $ app linTab i'' >>= applyLinToTangents) + return (ans, buildFor d i' \i'' -> provideRemat vi'' i'' $ app linTab i'' >>= applyLinToTangents) Tile _ _ _ -> notImplemented RunWriter lam -> linearizeEff Nothing lam True (const RunWriter) (emitRunWriter "r") Writer RunReader val lam -> linearizeEff (Just val) lam False RunReader (emitRunReader "r") Reader @@ -266,7 +266,7 @@ linearizeHof env hof = case hof of RunIO ~(Lam (Abs _ (arrow, body))) -> LinA $ do arrow' <- substEmbed env arrow -- TODO: consider the possibility of other effects here besides IO - lam <- buildLam (Ignore UnitTy) arrow' $ \_ -> + lam <- buildLam (Ignore UnitTy) arrow' \_ -> tangentFunAsLambda $ linearizeBlock env body result <- emit $ Hof $ RunIO lam (ans, linLam) <- fromPair result @@ -299,18 +299,18 @@ linearizeHof env hof = case hof of let (BinaryFunTy _ b _ _) = getType lam' let RefTy _ wTy = binderType b return $ emitter $ tangentType wTy - valEmitter $ \ref'@(Var (_:> RefTy (Var (h:>_)) _)) -> do + valEmitter \ref'@(Var (_:> RefTy (Var (h:>_)) _)) -> do extendTangentEnv (ref @> ref') [h] $ applyLinToTangents linBody return (ans, lin) linearizeEffectFun :: RWS -> Atom -> PrimalM (Atom, Var) linearizeEffectFun rws ~(BinaryFunVal h ref eff body) = do h' <- mapM (substEmbed env) h - buildLamAux h' (const $ return PureArrow) $ \h''@(Var hVar) -> do + buildLamAux h' (const $ return PureArrow) \h''@(Var hVar) -> do let env' = env <> h@>h'' eff' <- substEmbed env' eff ref' <- mapM (substEmbed env') ref - buildLamAux ref' (const $ return $ PlainArrow eff') $ \ref''@(Var refVar) -> + buildLamAux ref' (const $ return $ PlainArrow eff') \ref''@(Var refVar) -> extendWrt [refVar] [RWSEffect rws (varName hVar)] $ (,refVar) <$> (tangentFunAsLambda $ linearizeBlock (env' <> ref@>ref'') body) @@ -341,7 +341,7 @@ linearizeAtom atom = case atom of Con con -> linearizePrimCon con Lam (Abs i (TabArrow, body)) -> LinA $ do wrt <- ask - return (atom, buildLam i TabArrow $ \i' -> + return (atom, buildLam i TabArrow \i' -> rematPrimal wrt $ linearizeBlock (i@>i') body) DataCon _ _ _ _ -> notImplemented -- Need to synthesize or look up a tangent ADT Record elems -> Record <$> traverse linearizeAtom elems @@ -394,7 +394,7 @@ addTangent x y = case getType x of RecordTy (NoExt tys) -> do elems <- bindM2 (zipWithT addTangent) (getUnpacked x) (getUnpacked y) return $ Record $ restructure elems tys - TabTy b _ -> buildFor Fwd b $ \i -> bindM2 addTangent (tabGet x i) (tabGet y i) + TabTy b _ -> buildFor Fwd b \i -> bindM2 addTangent (tabGet x i) (tabGet y i) TC con -> case con of BaseType (Scalar _) -> emitOp $ ScalarBinOp FAdd x y BaseType (Vector _) -> emitOp $ VectorBinOp FAdd x y @@ -422,8 +422,8 @@ tangentFunAsLambda m = do let hs = map (Bind . (:>TyKind) . effectRegion) effs let rematList = envAsVars remats liftM (PairVal ans) $ lift $ do - tanLam <- makeLambdas rematList $ \rematArgs -> - buildNestedLam PureArrow hs $ \hVals -> do + tanLam <- makeLambdas rematList \rematArgs -> + buildNestedLam PureArrow hs \hVals -> do let hVarNames = map (\(Var (v:>_)) -> v) hVals -- TODO: handle exception effect too let effs' = zipWith (\(RWSEffect rws _) v -> RWSEffect rws v) effs hVarNames @@ -431,8 +431,8 @@ tangentFunAsLambda m = do let regionMap = newEnv (map ((:>()) . effectRegion) effs) hVals -- TODO: Only bind tangents for free variables? let activeVarBinders = map (Bind . fmap (tangentRefRegion regionMap)) $ envAsVars activeVars - buildNestedLam PureArrow activeVarBinders $ \activeVarArgs -> - buildLam (Ignore UnitTy) (PlainArrow $ EffectRow (S.fromList effs') Nothing) $ \_ -> + buildNestedLam PureArrow activeVarBinders \activeVarArgs -> + buildLam (Ignore UnitTy) (PlainArrow $ EffectRow (S.fromList effs') Nothing) \_ -> runReaderT tanFun $ TangentEnv (newEnv (envNames activeVars) activeVarArgs) hVarNames (newEnv rematList $ fmap Var rematArgs) @@ -448,7 +448,7 @@ tangentFunAsLambda m = do return $ Lam $ makeAbs (Bind v) (PureArrow, block) makeLambdas [] f = f [] - makeLambdas (v:vs) f = makeLambda v $ \x -> makeLambdas vs $ \xs -> f (x:xs) + makeLambdas (v:vs) f = makeLambda v \x -> makeLambdas vs \xs -> f (x:xs) -- This doesn't work if we have references inside pairs, tables etc. -- TODO: something more general and cleaner. @@ -544,7 +544,7 @@ type TransposeM a = ReaderT TransposeEnv Embed a transpose :: Scope -> Atom -> Atom transpose scope ~(Lam (Abs b (_, block))) = fst $ flip runEmbed scope $ do - buildLam (Bind $ "ct" :> getType block) LinArrow $ \ct -> do + buildLam (Bind $ "ct" :> getType block) LinArrow \ct -> do snd <$> (flip runReaderT mempty $ withLinVar b $ transposeBlock block ct) transposeBlock :: Block -> Atom -> TransposeM () @@ -590,7 +590,7 @@ transposeExpr expr ct = case expr of void $ emit $ Case e' alts' UnitTy where transposeNonlinAlt (Abs bs body) = - buildNAbs bs $ \xs -> do + buildNAbs bs \xs -> do localNonlinSubst (newEnv bs xs) $ transposeBlock body ct return UnitVal @@ -619,7 +619,7 @@ transposeOp op ct = case op of MPut x -> do transposeAtom x =<< emitEff MGet void $ emitEff $ MPut $ zeroAt $ getType x - TabCon ~(TabTy b _) es -> forM_ (enumerate es) $ \(i, e) -> do + TabCon ~(TabTy b _) es -> forM_ (enumerate es) \(i, e) -> do transposeAtom e =<< tabGet ct =<< intToIndexE (binderType b) (IdxRepVal $ fromIntegral i) IndexRef _ _ -> notImplemented FstRef _ -> notImplemented @@ -675,24 +675,24 @@ linAtomRef a = error $ "Not a linear var: " ++ pprint a transposeHof :: Hof -> Atom -> TransposeM () transposeHof hof ct = case hof of For ~(RegularFor d) ~(Lam (Abs b (_, body))) -> - void $ buildFor (flipDir d) b $ \i -> do + void $ buildFor (flipDir d) b \i -> do ct' <- tabGet ct i localNonlinSubst (b@>i) $ transposeBlock body ct' return UnitVal where flipDir dir = case dir of Fwd -> Rev; Rev -> Fwd RunReader r ~(BinaryFunVal (Bind (h:>_)) b _ body) -> do - (_, ctr) <- (fromPair =<<) $ emitRunWriter "w" (getType r) $ \ref -> do + (_, ctr) <- (fromPair =<<) $ emitRunWriter "w" (getType r) \ref -> do localLinRegion h $ localLinRefSubst (b@>ref) $ transposeBlock body ct return UnitVal transposeAtom r ctr RunWriter ~(BinaryFunVal (Bind (h:>_)) b _ body) -> do (ctBody, ctEff) <- fromPair ct - void $ emitRunReader "r" ctEff $ \ref -> do + void $ emitRunReader "r" ctEff \ref -> do localLinRegion h $ localLinRefSubst (b@>ref) $ transposeBlock body ctBody return UnitVal RunState s ~(BinaryFunVal (Bind (h:>_)) b _ body) -> do (ctBody, ctState) <- fromPair ct - (_, cts) <- (fromPair =<<) $ emitRunState "s" ctState $ \ref -> do + (_, cts) <- (fromPair =<<) $ emitRunState "s" ctState \ref -> do localLinRegion h $ localLinRefSubst (b@>ref) $ transposeBlock body ctBody return UnitVal transposeAtom s cts @@ -715,7 +715,7 @@ transposeAtom atom ct = case atom of DataCon _ _ _ e -> void $ zipWithT transposeAtom e =<< getUnpacked ct Variant _ _ _ _ -> notImplemented TabVal b body -> - void $ buildFor Fwd b $ \i -> do + void $ buildFor Fwd b \i -> do ct' <- tabGet ct i localNonlinSubst (b@>i) $ transposeBlock body ct' return UnitVal @@ -787,7 +787,7 @@ withLinVar :: Binder -> TransposeM a -> TransposeM (a, Atom) withLinVar b body = case singletonTypeVal (binderType b) of Nothing -> flip evalStateT Nothing $ do - ans <- emitRunWriter "ref" (binderType b) $ \ref -> do + ans <- emitRunWriter "ref" (binderType b) \ref -> do lift (localLinRef (b@>Just ref) body) >>= put . Just >> return UnitVal (,) <$> (fromJust <$> get) <*> (getSnd ans) Just x -> (,x) <$> (localLinRef (b@>Nothing) body) -- optimization to avoid accumulating into unit diff --git a/src/lib/Cat.hs b/src/lib/Cat.hs index aa6d703fa..f120df661 100644 --- a/src/lib/Cat.hs +++ b/src/lib/Cat.hs @@ -50,7 +50,7 @@ instance (Monoid env, Monad m) => MonadCat env (CatT env m) where instance MonadCat env m => MonadCat env (StateT s m) where look = lift look extend x = lift $ extend x - scoped m = StateT $ \s -> do + scoped m = StateT \s -> do ((ans, s'), env) <- scoped $ runStateT m s return $ ((ans, env), s') @@ -145,7 +145,7 @@ catTraverse f inj xs env = runCatT (traverse (asCat f inj) xs) env catFoldM :: (Monoid env, Traversable t, Monad m) => (env -> a -> m env) -> env -> t a -> m env -catFoldM f env xs = liftM snd $ flip runCatT env $ forM_ xs $ \x -> do +catFoldM f env xs = liftM snd $ flip runCatT env $ forM_ xs \x -> do cur <- look new <- lift $ f cur x extend new @@ -156,7 +156,7 @@ catFold f env xs = runIdentity $ catFoldM (\e x -> Identity $ f e x) env xs catMapM :: (Monoid env, Traversable t, Monad m) => (env -> a -> m (b, env)) -> env -> t a -> m (t b, env) -catMapM f env xs = flip runCatT env $ forM xs $ \x -> do +catMapM f env xs = flip runCatT env $ forM xs \x -> do cur <- look (y, new) <- lift $ f cur x extend new diff --git a/src/lib/Embed.hs b/src/lib/Embed.hs index 705d1c50a..c46397d64 100644 --- a/src/lib/Embed.hs +++ b/src/lib/Embed.hs @@ -164,7 +164,7 @@ buildLam b arr body = buildDepEffLam b (const (return arr)) body buildDepEffLam :: MonadEmbed m => Binder -> (Atom -> m Arrow) -> (Atom -> m Atom) -> m Atom -buildDepEffLam b fArr fBody = liftM fst $ buildLamAux b fArr $ \x -> (,()) <$> fBody x +buildDepEffLam b fArr fBody = liftM fst $ buildLamAux b fArr \x -> (,()) <$> fBody x buildLamAux :: MonadEmbed m => Binder -> (Atom -> m Arrow) -> (Atom -> m (Atom, a)) -> m (Atom, a) @@ -180,7 +180,7 @@ buildLamAux b fArr fBody = do return (Lam $ makeAbs b' (arr, wrapDecls decls ans), aux) buildNAbs :: MonadEmbed m => Nest Binder -> ([Atom] -> m Atom) -> m Alt -buildNAbs bs body = liftM fst $ buildNAbsAux bs $ \xs -> (,()) <$> body xs +buildNAbs bs body = liftM fst $ buildNAbsAux bs \xs -> (,()) <$> body xs buildNAbsAux :: MonadEmbed m => Nest Binder -> ([Atom] -> m (Atom, a)) -> m (Alt, a) buildNAbsAux bs body = do @@ -202,9 +202,9 @@ buildDataDef tyConName paramBinders body = do buildImplicitNaryLam :: MonadEmbed m => (Nest Binder) -> ([Atom] -> m Atom) -> m Atom buildImplicitNaryLam Empty body = body [] buildImplicitNaryLam (Nest b bs) body = - buildLam b ImplicitArrow $ \x -> do + buildLam b ImplicitArrow \x -> do bs' <- substEmbed (b@>x) bs - buildImplicitNaryLam bs' $ \xs -> body $ x:xs + buildImplicitNaryLam bs' \xs -> body $ x:xs recGet :: Label -> Atom -> Atom recGet l x = do @@ -383,14 +383,14 @@ unpackConsList xs = case getType xs of emitWhile :: MonadEmbed m => m Atom -> m () emitWhile body = do eff <- getAllowedEffects - lam <- buildLam (Ignore UnitTy) (PlainArrow eff) $ \_ -> body + lam <- buildLam (Ignore UnitTy) (PlainArrow eff) \_ -> body void $ emit $ Hof $ While lam emitMaybeCase :: MonadEmbed m => Atom -> m Atom -> (Atom -> m Atom) -> m Atom emitMaybeCase scrut nothingCase justCase = do let (MaybeTy a) = getType scrut - nothingAlt <- buildNAbs Empty $ \[] -> nothingCase - justAlt <- buildNAbs (Nest (Bind ("x":>a)) Empty) $ \[x] -> justCase x + nothingAlt <- buildNAbs Empty \[] -> nothingCase + justAlt <- buildNAbs (Nest (Bind ("x":>a)) Empty) \[x] -> justCase x let (Abs _ nothingBody) = nothingAlt let resultTy = getType nothingBody emit $ Case scrut [nothingAlt, justAlt] resultTy @@ -410,7 +410,7 @@ emitRunState v x0 body = do mkBinaryEffFun :: MonadEmbed m => RWS -> Name -> Type -> (Atom -> m Atom) -> m Atom mkBinaryEffFun rws v ty body = do eff <- getAllowedEffects - buildLam (Bind ("h":>TyKind)) PureArrow $ \r@(Var (rName:>_)) -> do + buildLam (Bind ("h":>TyKind)) PureArrow \r@(Var (rName:>_)) -> do let arr = PlainArrow $ extendEffect (RWSEffect rws rName) eff buildLam (Bind (v:> RefTy r ty)) arr body @@ -434,16 +434,16 @@ buildFor = buildForAnn . RegularFor buildNestedLam :: MonadEmbed m => Arrow -> [Binder] -> ([Atom] -> m Atom) -> m Atom buildNestedLam _ [] f = f [] buildNestedLam arr (b:bs) f = - buildLam b arr $ \x -> buildNestedLam arr bs $ \xs -> f (x:xs) + buildLam b arr \x -> buildNestedLam arr bs \xs -> f (x:xs) tabGet :: MonadEmbed m => Atom -> Atom -> m Atom tabGet x i = emit $ App x i unzipTab :: MonadEmbed m => Atom -> m (Atom, Atom) unzipTab tab = do - fsts <- buildLam (Bind ("i":>binderType v)) TabArrow $ \i -> + fsts <- buildLam (Bind ("i":>binderType v)) TabArrow \i -> liftM fst $ app tab i >>= fromPair - snds <- buildLam (Bind ("i":>binderType v)) TabArrow $ \i -> + snds <- buildLam (Bind ("i":>binderType v)) TabArrow \i -> liftM snd $ app tab i >>= fromPair return (fsts, snds) where TabTy v _ = getType tab @@ -509,9 +509,9 @@ instance Monad m => MonadEmbed (EmbedT m) where instance MonadEmbed m => MonadEmbed (ReaderT r m) where embedLook = lift embedLook embedExtend x = lift $ embedExtend x - embedScoped m = ReaderT $ \r -> embedScoped $ runReaderT m r + embedScoped m = ReaderT \r -> embedScoped $ runReaderT m r embedAsk = lift embedAsk - embedLocal v m = ReaderT $ \r -> embedLocal v $ runReaderT m r + embedLocal v m = ReaderT \r -> embedLocal v $ runReaderT m r instance MonadEmbed m => MonadEmbed (StateT s m) where embedLook = lift embedLook @@ -710,7 +710,7 @@ traverseExpr def@(_, _, fAtom) expr = case expr of where traverseAlt (Abs bs body) = do bs' <- mapM (mapM fAtom) bs - buildNAbs bs' $ \xs -> extendR (newEnv bs' xs) $ evalBlockE def body + buildNAbs bs' \xs -> extendR (newEnv bs' xs) $ evalBlockE def body traverseAtom :: forall m . (MonadEmbed m, MonadReader SubstEnv m) => TraversalDef m -> Atom -> m Atom @@ -747,7 +747,7 @@ traverseAtom def@(_, _, fAtom) atom = case atom of BoxedRef b ptr size body -> do ptr' <- fAtom ptr size' <- buildScoped $ evalBlockE def size - Abs b' (decls, body') <- buildAbs b $ \x -> + Abs b' (decls, body') <- buildAbs b \x -> extendR (b@>x) $ evalBlockE def (Block Empty $ Atom body) case decls of Empty -> return $ BoxedRef b' ptr' size' body' @@ -765,7 +765,7 @@ traverseAtom def@(_, _, fAtom) atom = case atom of traverseAAlt (Abs bs a) = do bs' <- mapM (mapM fAtom) bs - (Abs bs'' b) <- buildNAbs bs' $ \xs -> extendR (newEnv bs' xs) $ fAtom a + (Abs bs'' b) <- buildNAbs bs' \xs -> extendR (newEnv bs' xs) $ fAtom a case b of Block Empty (Atom r) -> return $ Abs bs'' r _ -> error "ACase alternative traversal has emitted decls or exprs!" @@ -842,7 +842,7 @@ indexToIntE idx = case getType idx of (offsets, _) <- scanM (\sz prev -> (prev,) <$> iadd sz prev) sizes (IdxRepVal 0) -- Build and apply a case expression alts <- flip mapM (zip (toList offsets) (toList types)) $ - \(offset, subty) -> buildNAbs (toNest [Ignore subty]) $ \[subix] -> do + \(offset, subty) -> buildNAbs (toNest [Ignore subty]) \[subix] -> do i <- indexToIntE subix iadd offset i emit $ Case idx alts IdxRepTy diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index aa3c94663..deae7ab50 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -80,7 +80,7 @@ toImpModule :: TopEnv -> Backend -> CallingConvention -> Name -> (ImpFunction, ImpModule, AtomRecon) toImpModule env backend cc entryName argBinders maybeDest block = do let standaloneFunctions = - for (requiredFunctions env block) $ \(v, f) -> + for (requiredFunctions env block) \(v, f) -> runImpM initCtx inVarScope $ toImpStandalone v f runImpM initCtx inVarScope $ do (reconAtom, impBlock) <- scopedBlock $ translateTopLevel env (maybeDest, block) @@ -98,7 +98,7 @@ toImpModule env backend cc entryName argBinders maybeDest block = do requiredFunctions :: HasVars a => Scope -> a -> [(Name, Atom)] requiredFunctions scope expr = - flip foldMap (transitiveClosure getParents immediateParents) $ \fname -> + flip foldMap (transitiveClosure getParents immediateParents) \fname -> case scope ! fname of (_, LetBound _ (Atom f)) -> [(fname, f)] (_, LamBound _) -> [] @@ -142,7 +142,7 @@ toImpStandalone fname ~(LamVal b body) = do impBlock <- scopedErrBlock $ do arg <- destToAtom argDest void $ translateBlock (b@>arg) (Just outDest, body) - let bs = for ptrSizes $ \(Bind (v:>BaseTy ty), _) -> Bind $ v:>ty + let bs = for ptrSizes \(Bind (v:>BaseTy ty), _) -> Bind $ v:>ty let fTy = IFunType CEntryFun (map binderAnn bs) (impBlockType impBlock) return $ ImpFunction (fname:>fTy) bs impBlock @@ -150,7 +150,7 @@ translateBlock :: SubstEnv -> WithDest Block -> ImpM Atom translateBlock env destBlock = do let (decls, result, copies) = splitDest destBlock env' <- (env<>) <$> catFoldM translateDecl env decls - forM_ copies $ \(dest, atom) -> copyAtom dest =<< impSubst env' atom + forM_ copies \(dest, atom) -> copyAtom dest =<< impSubst env' atom translateExpr env' result translateDecl :: SubstEnv -> WithDest Decl -> ImpM SubstEnv @@ -239,7 +239,7 @@ toImpOp :: WithDest (PrimOp Atom) -> ImpM Atom toImpOp (maybeDest, op) = case op of TabCon (TabTy b _) rows -> do dest <- allocDest maybeDest resultTy - forM_ (zip [0..] rows) $ \(i, row) -> do + forM_ (zip [0..] rows) \(i, row) -> do ithDest <- destGet dest =<< intToIndex (binderType b) (IIdxRepVal i) copyAtom ithDest row destToAtom dest @@ -358,7 +358,7 @@ toImpHof env (maybeDest, hof) = do Select (toScalarAtom isLast) (toScalarAtom elemsUntilEnd) (toScalarAtom usualChunkSize)) - emitLoop "li" Fwd (fromScalarAtom threadChunkSize) $ \li -> do + emitLoop "li" Fwd (fromScalarAtom threadChunkSize) \li -> do i <- li `iaddI` chunkStart let idx = Con $ ParIndexCon idxTy $ toScalarAtom i ithDest <- destGet dest idx @@ -381,7 +381,7 @@ toImpHof env (maybeDest, hof) = do _ -> do n <- indexSetSize idxTy dest <- allocDest maybeDest resultTy - emitLoop (binderNameHint b) d n $ \i -> do + emitLoop (binderNameHint b) d n \i -> do idx <- intToIndex idxTy i ithDest <- destGet dest idx void $ translateBlock (env <> b @> idx) (Just ithDest, body) @@ -389,13 +389,13 @@ toImpHof env (maybeDest, hof) = do For ParallelFor ~fbody@(LamVal b _) -> do idxTy <- impSubst env $ binderType b dest <- allocDest maybeDest resultTy - buildKernel idxTy $ \LaunchInfo{..} buildBody -> do - liftM (,()) $ buildBody $ \ThreadInfo{..} -> do + buildKernel idxTy \LaunchInfo{..} buildBody -> do + liftM (,()) $ buildBody \ThreadInfo{..} -> do let threadBody = fst $ flip runSubstEmbed (freeVars fbody) $ - buildLam (Bind $ "hwidx" :> threadRange) PureArrow $ \hwidx -> + buildLam (Bind $ "hwidx" :> threadRange) PureArrow \hwidx -> appReduce fbody =<< (emitOp $ Inject hwidx) let threadDest = Con $ TabRef $ fst $ flip runSubstEmbed (freeVars dest) $ - buildLam (Bind $ "hwidx" :> threadRange) TabArrow $ \hwidx -> + buildLam (Bind $ "hwidx" :> threadRange) TabArrow \hwidx -> indexDest dest =<< (emitOp $ Inject hwidx) void $ toImpHof env (Just threadDest, For (RegularFor Fwd) threadBody) destToAtom dest @@ -407,12 +407,12 @@ toImpHof env (maybeDest, hof) = do nTiles <- n `idivI` tileLen epilogueOff <- nTiles `imulI` tileLen nEpilogue <- n `isubI` epilogueOff - emitLoop (binderNameHint tb) Fwd nTiles $ \iTile -> do + emitLoop (binderNameHint tb) Fwd nTiles \iTile -> do tileOffset <- toScalarAtom <$> iTile `imulI` tileLen let tileAtom = Con $ IndexSliceVal idxTy tileIdxTy tileOffset tileDest <- fromEmbed $ sliceDestDim d dest tileOffset tileIdxTy void $ translateBlock (env <> tb @> tileAtom) (Just tileDest, tBody) - emitLoop (binderNameHint sb) Fwd nEpilogue $ \iEpi -> do + emitLoop (binderNameHint sb) Fwd nEpilogue \iEpi -> do i <- iEpi `iaddI` epilogueOff idx <- intToIndex idxTy i sDest <- fromEmbed $ indexDestDim d dest idx @@ -422,16 +422,16 @@ toImpHof env (maybeDest, hof) = do idxTy <- impSubst env idxTy' (mappingDest, finalAccDest) <- destPairUnpack <$> allocDest maybeDest resultTy let PairTy _ accType = resultTy - (numTileWorkgroups, wgResArr, widIdxTy) <- buildKernel idxTy $ \LaunchInfo{..} buildBody -> do + (numTileWorkgroups, wgResArr, widIdxTy) <- buildKernel idxTy \LaunchInfo{..} buildBody -> do let widIdxTy = Fin $ toScalarAtom numWorkgroups let tidIdxTy = Fin $ toScalarAtom workgroupSize wgResArr <- alloc $ TabTy (Ignore widIdxTy) accType thrAccArr <- alloc $ TabTy (Ignore widIdxTy) $ TabTy (Ignore tidIdxTy) accType - mappingKernelBody <- buildBody $ \ThreadInfo{..} -> do + mappingKernelBody <- buildBody \ThreadInfo{..} -> do let TC (ParIndexRange _ gtid nthr) = threadRange let scope = freeVars mappingDest let tileDest = Con $ TabRef $ fst $ flip runSubstEmbed scope $ do - buildLam (Bind $ "hwidx":>threadRange) TabArrow $ \hwidx -> do + buildLam (Bind $ "hwidx":>threadRange) TabArrow \hwidx -> do indexDest mappingDest =<< (emitOp $ Inject hwidx) wgAccs <- destGet thrAccArr =<< intToIndex widIdxTy wid thrAcc <- destGet wgAccs =<< intToIndex tidIdxTy tid @@ -443,12 +443,12 @@ toImpHof env (maybeDest, hof) = do -- TODO: Skip the reduction kernel if unnecessary? -- TODO: Reduce sequentially in the CPU backend? -- TODO: Actually we only need the previous-power-of-2 many threads - buildKernel widIdxTy $ \LaunchInfo{..} buildBody -> do + buildKernel widIdxTy \LaunchInfo{..} buildBody -> do -- We only do a one-level reduciton in the workgroup, so it is correct -- only if the end up scheduling a single workgroup. moreThanOneGroup <- (IIdxRepVal 1) `iltI` numWorkgroups guardBlock moreThanOneGroup $ emitStatement IThrowError - redKernelBody <- buildBody $ \ThreadInfo{..} -> + redKernelBody <- buildBody \ThreadInfo{..} -> workgroupReduce tid finalAccDest wgResArr numTileWorkgroups return (redKernelBody, ()) PairVal <$> destToAtom mappingDest <*> destToAtom finalAccDest @@ -548,7 +548,7 @@ buildKernel idxTy f = do LLVMCUDA -> (CUDAKernelLaunch, GPU) LLVMMC -> (MCThreadLaunch , CPU) backend -> error $ "Shouldn't be launching kernels from " ++ show backend - ((kernelBody, aux), env) <- scoped $ f LaunchInfo{..} $ \mkBody -> + ((kernelBody, aux), env) <- scoped $ f LaunchInfo{..} \mkBody -> withDevice dev $ withLevel ThreadLevel $ scopedErrBlock $ do gtid <- iaddI tid =<< imulI wid wsz let threadRange = TC $ ParIndexRange idxTy (toScalarAtom gtid) (toScalarAtom nthr) @@ -581,7 +581,7 @@ type DestM = ReaderT DestEnv (CatT (Env (Type, Block)) Embed) makeDest :: AllocInfo -> Type -> Embed ([(Binder, Atom)], Dest) makeDest allocInfo ty = do (dest, ptrs) <- flip runCatT mempty $ flip runReaderT env $ makeDestRec ty - ptrs' <- forM (envPairs ptrs) $ \(v, (ptrTy, numel)) -> do + ptrs' <- forM (envPairs ptrs) \(v, (ptrTy, numel)) -> do numel' <- emitBlock numel return (Bind (v:>ptrTy), numel') return (ptrs', dest) @@ -598,7 +598,7 @@ makeDestRec ty = case ty of makeDestRec ty makeBoxes (envPairs ptrs) dest else do - lam <- buildLam (Bind ("i":> binderAnn b)) TabArrow $ \(Var i) -> do + lam <- buildLam (Bind ("i":> binderAnn b)) TabArrow \(Var i) -> do bodyTy' <- substEmbed (b@>Var i) bodyTy withEnclosingIdxs (Bind i) $ makeDestRec bodyTy' return $ Con $ TabRef lam @@ -614,7 +614,7 @@ makeDestRec ty = case ty of "Dependent data constructors only allowed for single-constructor types" tag <- rec TagRepTy let dcs' = applyDataDefParams def params - contents <- forM dcs' $ \(DataConDef _ bs) -> forM (toList bs) (rec . binderType) + contents <- forM dcs' \(DataConDef _ bs) -> forM (toList bs) (rec . binderType) return $ Con $ ConRef $ SumAsProd ty tag contents RecordTy (NoExt types) -> (Con . RecordRef) <$> forM types rec VariantTy (NoExt types) -> do @@ -720,7 +720,7 @@ loadDest (DataConRef def params bs) = do loadDest (Con dest) = do case dest of BaseTypeRef ptr -> unsafePtrLoad ptr - TabRef (TabVal b body) -> buildLam b TabArrow $ \i -> do + TabRef (TabVal b body) -> buildLam b TabArrow \i -> do body' <- substEmbed (b@>i) body result <- emitBlock body' loadDest result @@ -744,7 +744,7 @@ loadDataConArgs (Nest (DataConRefBinding b ref) rest) = do indexDestDim :: MonadEmbed m => Int->Dest -> Atom -> m Dest indexDestDim 0 dest i = indexDest dest i -indexDestDim d dest i = buildFor Fwd (Bind ("i":>idxTy)) $ \j -> do +indexDestDim d dest i = buildFor Fwd (Bind ("i":>idxTy)) \j -> do dest' <- indexDest dest j indexDestDim (d-1) dest' i where @@ -757,7 +757,7 @@ indexDest dest _ = error $ pprint dest sliceDestDim :: MonadEmbed m => Int -> Dest -> Atom -> Type -> m Dest sliceDestDim 0 dest i sliceIdxTy = sliceDest dest i sliceIdxTy -sliceDestDim d dest i sliceIdxTy = buildFor Fwd (Bind ("i":>idxTy)) $ \j -> do +sliceDestDim d dest i sliceIdxTy = buildFor Fwd (Bind ("i":>idxTy)) \j -> do dest' <- indexDest dest j sliceDestDim (d-1) dest' i sliceIdxTy where @@ -766,7 +766,7 @@ sliceDestDim d dest i sliceIdxTy = buildFor Fwd (Bind ("i":>idxTy)) $ \j -> do sliceDest :: MonadEmbed m => Dest -> Atom -> Type -> m Dest sliceDest ~(Con (TabRef tab@(TabVal b _))) i sliceIdxTy = (Con . TabRef) <$> do - buildFor Fwd (Bind ("j" :> sliceIdxTy)) $ \j -> do + buildFor Fwd (Bind ("j" :> sliceIdxTy)) \j -> do j' <- indexToIntE j ioff <- iadd j' i vidx <- intToIndexE (binderType b) ioff @@ -790,7 +790,7 @@ makeAllocDestWithPtrs allocTy ty = do backend <- asks impBackend curDev <- asks curDevice (ptrsSizes, dest) <- fromEmbed $ makeDest (backend, curDev, allocTy) ty - (env, ptrs) <- flip foldMapM ptrsSizes $ \(Bind (ptr:>PtrTy ptrTy), size) -> do + (env, ptrs) <- flip foldMapM ptrsSizes \(Bind (ptr:>PtrTy ptrTy), size) -> do ptr' <- emitAlloc ptrTy $ fromScalarAtom size case ptrTy of (Heap _, _) | allocTy == Managed -> extendAlloc ptr' @@ -811,7 +811,7 @@ splitDest (maybeDest, (Block decls ans)) = do let closureCopies = fmap (\(n, (d, t)) -> (d, Var $ n :> t)) (envPairs $ varDests `envDiff` foldMap letBoundVars decls) - let destDecls = flip fmap (toList decls) $ \d -> case d of + let destDecls = flip fmap (toList decls) \d -> case d of Let _ b _ -> (fst <$> varDests `envLookup` b, d) (destDecls, (Nothing, ans), gatherCopies ++ closureCopies) _ -> (fmap (Nothing,) $ toList decls, (maybeDest, ans), []) @@ -939,7 +939,7 @@ zipTabDestAtom f ~dest@(Con (TabRef (TabVal b _))) ~src@(TabVal b' _) = do error $ "Mismatched dimensions: " <> pprint b <> " != " <> pprint b' let idxTy = binderType b n <- indexSetSize idxTy - emitLoop "i" Fwd n $ \i -> do + emitLoop "i" Fwd n \i -> do idx <- intToIndex idxTy i destIndexed <- destGet dest idx srcIndexed <- translateExpr mempty (Nothing, App src idx) @@ -1094,7 +1094,7 @@ restructureScalarOrPairTypeRec ty _ = error $ "Not a scalar or pair: " ++ pprint emitMultiReturnInstr :: ImpInstr -> ImpM [IExpr] emitMultiReturnInstr instr = do - vs <- forM (impInstrTypes instr) $ \ty -> freshVar ("v":>ty) + vs <- forM (impInstrTypes instr) \ty -> freshVar ("v":>ty) emitImpDecl $ ImpLet (map Bind vs) instr return (map IVar vs) diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index 5a74edbfd..73d673554 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -50,7 +50,7 @@ inferModule :: TopEnv -> UModule -> Except Module inferModule scope (UModule decls) = do ((), (bindings, decls')) <- runUInferM mempty scope $ mapM_ (inferUDecl True) decls - let bindings' = envFilter bindings $ \(_, b) -> case b of + let bindings' = envFilter bindings \(_, b) -> case b of DataBoundTypeCon _ -> True DataBoundDataCon _ _ -> True _ -> False @@ -68,7 +68,7 @@ checkSigma expr reqCon sTy = case sTy of WithSrc _ (ULam b arrow' body) | arrow' == void arrow -> checkULam b body piTy _ -> do - buildLam (Bind ("a":> absArgType piTy)) arrow $ \x@(Var v) -> + buildLam (Bind ("a":> absArgType piTy)) arrow \x@(Var v) -> checkLeaks [v] $ checkSigma expr reqCon $ snd $ applyAbs piTy x _ -> checkOrInferRho expr (reqCon sTy) @@ -157,7 +157,7 @@ checkOrInferRho (WithSrc pos expr) reqTy = do -- TODO: check leaks kind' <- checkUType kind piTy <- case pat of - Just pat' -> withNameHint ("pat" :: Name) $ buildPi b $ \x -> + Just pat' -> withNameHint ("pat" :: Name) $ buildPi b \x -> withBindPat pat' x $ (,) <$> mapM checkUEffRow arr <*> checkUType ty where b = case pat' of -- Note: The binder name becomes part of the type, so we @@ -182,7 +182,7 @@ checkOrInferRho (WithSrc pos expr) reqTy = do case scrutTy' of TypeCon def params -> do let conDefs = applyDataDefParams def params - altsSorted <- forM (enumerate conDefs) $ \(i, DataConDef _ bs) -> do + altsSorted <- forM (enumerate conDefs) \(i, DataConDef _ bs) -> do case lookup (ConAlt i) alts' of Nothing -> return $ Abs (fmap (Ignore . binderType) bs) $ Block Empty $ Op $ ThrowError reqTy' @@ -256,7 +256,7 @@ checkOrInferRho (WithSrc pos expr) reqTy = do val' <- checkSigma val reqCon ty' matchRequirement val' UPrimExpr prim -> do - prim' <- forM prim $ \e -> do + prim' <- forM prim \e -> do e' <- inferRho e scope <- getScope return $ typeReduceAtom scope e' @@ -319,7 +319,7 @@ unpackTopPat :: LetAnn -> UPat -> Expr -> UInferM () unpackTopPat letAnn pat expr = do atom <- emit expr bindings <- bindPat pat atom - void $ flip traverseNames bindings $ \name val -> do + void $ flip traverseNames bindings \name val -> do let name' = asGlobal name checkNotInScope name' emitTo name' letAnn $ Atom val @@ -342,15 +342,15 @@ inferUDecl topLevel (ULet letAnn (p, ann) rhs) = do else bindPat p val inferUDecl True (UData tc dcs) = do (tc', paramBs) <- inferUConDef tc - dataDef <- buildDataDef tc' paramBs $ \params -> do - extendR (newEnv paramBs params) $ forM dcs $ \dc -> + dataDef <- buildDataDef tc' paramBs \params -> do + extendR (newEnv paramBs params) $ forM dcs \dc -> uncurry DataConDef <$> inferUConDef dc checkDataDefShadows dataDef emitConstructors dataDef return mempty inferUDecl True (UInterface superclasses tc methods) = do (tc', paramBs) <- inferUConDef tc - dataDef <- buildDataDef tc' paramBs $ \params -> do + dataDef <- buildDataDef tc' paramBs \params -> do extendR (newEnv paramBs params) $ do conName <- freshClassGenName superclasses' <- mkLabeledItems <$> mapM mkSuperclass superclasses @@ -403,16 +403,16 @@ emitConstructors def@(DataDef tyConName _ dataConDefs) = do let tyConTy = getType $ TypeCon def [] checkNotInScope tyConName extendScope $ tyConName @> (tyConTy, DataBoundTypeCon def) - forM_ (zip [0..] dataConDefs) $ \(i, DataConDef dataConName _) -> do + forM_ (zip [0..] dataConDefs) \(i, DataConDef dataConName _) -> do let dataConTy = getType $ DataCon def [] i [] checkNotInScope dataConName extendScope $ dataConName @> (dataConTy, DataBoundDataCon def i) emitMethodGetters :: DataDef -> UInferM () emitMethodGetters def@(DataDef _ paramBs (ClassDictDef _ _ methodTys)) = do - forM_ (getLabels methodTys) $ \l -> do - f <- buildImplicitNaryLam paramBs $ \params -> do - buildLam (Bind ("d":> TypeCon def params)) ClassArrow $ \dict -> do + forM_ (getLabels methodTys) \l -> do + f <- buildImplicitNaryLam paramBs \params -> do + buildLam (Bind ("d":> TypeCon def params)) ClassArrow \dict -> do return $ recGet l $ getProjection [1] dict let methodName = GlobalName $ fromString l checkNotInScope methodName @@ -421,9 +421,9 @@ emitMethodGetters (DataDef _ _ _) = error "Not a class dictionary" emitSuperclassGetters :: MonadEmbed m => DataDef -> m () emitSuperclassGetters def@(DataDef _ paramBs (ClassDictDef _ superclassTys _)) = do - forM_ (getLabels superclassTys) $ \l -> do - f <- buildImplicitNaryLam paramBs $ \params -> do - buildLam (Bind ("d":> TypeCon def params)) PureArrow $ \dict -> do + forM_ (getLabels superclassTys) \l -> do + f <- buildImplicitNaryLam paramBs \params -> do + buildLam (Bind ("d":> TypeCon def params)) PureArrow \dict -> do return $ recGet l $ getProjection [0] dict getterName <- freshClassGenName emitTo getterName SuperclassLet $ Atom f @@ -468,7 +468,7 @@ inferULam (p, ann) arr body = do argTy <- checkAnn ann -- TODO: worry about binder appearing in arrow? buildLam (Bind $ patNameHint p :> argTy) arr - $ \x@(Var v) -> checkLeaks [v] $ withBindPat p x $ inferSigma body + \x@(Var v) -> checkLeaks [v] $ withBindPat p x $ inferSigma body checkULam :: UPatAnn -> UExpr -> PiType -> UInferM Atom checkULam (p, ann) body piTy = do @@ -476,7 +476,7 @@ checkULam (p, ann) body piTy = do checkAnn ann >>= constrainEq argTy buildDepEffLam (Bind $ patNameHint p :> argTy) ( \x -> return $ fst $ applyAbs piTy x) - $ \x@(Var v) -> checkLeaks [v] $ withBindPat p x $ + \x@(Var v) -> checkLeaks [v] $ withBindPat p x $ checkSigma body Suggest $ snd $ applyAbs piTy x checkInstance :: Type -> [(UVar, UExpr)] -> UInferM Atom @@ -484,7 +484,7 @@ checkInstance ty methods = case ty of TypeCon def@(DataDef className _ _) params -> do case applyDataDefParams def params of ClassDictDef _ superclassTys methodTys -> do - methods' <- liftM mkLabeledItems $ forM methods $ \((v:>()), rhs) -> do + methods' <- liftM mkLabeledItems $ forM methods \((v:>()), rhs) -> do let v' = nameToLabel v case lookupLabel methodTys v' of Nothing -> throw TypeErr (pprint v ++ " is not a method of " ++ pprint className) @@ -492,9 +492,9 @@ checkInstance ty methods = case ty of rhs' <- checkSigma rhs Suggest methodTy return (v', rhs') let superclassHoles = fmap (Con . ClassDictHole Nothing) superclassTys - forM_ (reflectLabels methods') $ \(l,i) -> + forM_ (reflectLabels methods') \(l,i) -> when (i > 0) $ throw TypeErr $ "Duplicate method: " ++ pprint l - forM_ (reflectLabels methodTys) $ \(l,_) -> + forM_ (reflectLabels methodTys) \(l,_) -> case lookupLabel methods' l of Nothing -> throw TypeErr $ "Missing method: " ++ pprint l Just _ -> return () @@ -505,7 +505,7 @@ checkInstance ty methods = case ty of ImplicitArrow -> return () ClassArrow -> return () _ -> throw TypeErr $ "Not a valid arrow for an instance: " ++ pprint arrow - buildLam b arrow $ \x@(Var v) -> do + buildLam b arrow \x@(Var v) -> do bodyTy' <- substEmbed (b@>x) bodyTy checkLeaks [v] $ extendR (b@>x) $ checkInstance bodyTy' methods _ -> throw TypeErr $ "Not a valid instance type: " ++ pprint ty @@ -513,7 +513,7 @@ checkInstance ty methods = case ty of checkUEffRow :: EffectRow -> UInferM EffectRow checkUEffRow (EffectRow effs t) = do effs' <- liftM S.fromList $ mapM checkUEff $ toList effs - t' <- forM t $ \tv -> lookupVarName EffKind tv + t' <- forM t \tv -> lookupVarName EffKind tv return $ EffectRow effs' t' where lookupVarName :: Type -> Name -> UInferM Name @@ -540,7 +540,7 @@ checkCaseAlt reqTy scrutineeTy (UAlt pat body) = do (conIdx, patTys) <- checkCasePat pat scrutineeTy let (subPats, subPatTys) = unzip patTys let bs = zipWith (\p ty -> Bind $ patNameHint p :> ty) subPats subPatTys - alt <- buildNAbs (toNest bs) $ \xs -> + alt <- buildNAbs (toNest bs) \xs -> withBindPats (zip subPats xs) $ checkRho body reqTy return (conIdx, alt) @@ -658,7 +658,7 @@ bindPat' (WithSrc pos pat) val = addSrcContext pos $ case pat of throw TypeErr $ "Incorrect length of table pattern: table index set has " <> pprint (length idxs) <> " elements but there are " <> pprint (length ps) <> " patterns." - flip foldMapM (zip ps idxs) $ \(p, i) -> do + flip foldMapM (zip ps idxs) \(p, i) -> do v <- lift $ emitZonked $ App val i bindPat' p v @@ -883,7 +883,7 @@ runSolverT m = liftM fst $ flip runCatT mempty $ do applyDefaults :: MonadCat SolverEnv m => m () applyDefaults = do vs <- looks unsolved - forM_ (envPairs vs) $ \(v, k) -> case k of + forM_ (envPairs vs) \(v, k) -> case k of EffKind -> addSub v $ Eff Pure _ -> return () where addSub v ty = extend $ SolverEnv mempty (v@>ty) @@ -907,8 +907,8 @@ checkLeaks tvs m = do unless (null $ resultTypeLeaks) $ throw TypeErr $ "Leaked local variable `" ++ pprint (head resultTypeLeaks) ++ "` in result type " ++ pprint (getType ans) - forM_ (solverSub env) $ \ty -> - forM_ tvs $ \tv -> + forM_ (solverSub env) \ty -> + forM_ tvs \tv -> throwIf (tv `occursIn` ty) TypeErr $ "Leaked type variable: " ++ pprint tv extend env return ans diff --git a/src/lib/JIT.hs b/src/lib/JIT.hs index a9d374269..d26ad3a4a 100644 --- a/src/lib/JIT.hs +++ b/src/lib/JIT.hs @@ -94,11 +94,11 @@ compileFunction logger fun@(ImpFunction f bs body) = case cc of (argPtrParam , argPtrOperand ) <- freshParamOpPair attrs $ hostPtrTy i64 (resultPtrParam, resultPtrOperand) <- freshParamOpPair attrs $ hostPtrTy i64 initializeOutputStream streamFDOperand - argOperands <- forM (zip [0..] argTys) $ \(i, ty) -> + argOperands <- forM (zip [0..] argTys) \(i, ty) -> gep argPtrOperand (i64Lit i) >>= castLPtr (scalarTy ty) >>= load when (toBool requiresCUDA) ensureHasCUDAContext results <- extendOperands (newEnv bs argOperands) $ compileBlock body - forM_ (zip [0..] results) $ \(i, x) -> + forM_ (zip [0..] results) \(i, x) -> gep resultPtrOperand (i64Lit i) >>= castLPtr (L.typeOf x) >>= flip store x mainFun <- makeFunction (asLLVMName name) [streamFDParam, argPtrParam, resultPtrParam] (Just $ i64Lit 0) @@ -607,7 +607,7 @@ compileExpr expr = case expr of packArgs :: [Operand] -> Compile Operand packArgs elems = do arr <- alloca (length elems) hostVoidp - forM_ (zip [0..] elems) $ \(i, e) -> do + forM_ (zip [0..] elems) \(i, e) -> do eptr <- alloca 1 $ L.typeOf e store eptr e earr <- gep arr $ i32Lit i @@ -616,7 +616,7 @@ packArgs elems = do unpackArgs :: Operand -> [L.Type] -> Compile [Operand] unpackArgs argArrayPtr types = - forM (zip [0..] types) $ \(i, ty) -> do + forM (zip [0..] types) \(i, ty) -> do argVoidPtr <- gep argArrayPtr $ i64Lit i argPtr <- castLPtr (hostPtrTy ty) argVoidPtr load =<< load argPtr @@ -624,7 +624,7 @@ unpackArgs argArrayPtr types = makeMultiResultAlloc :: [L.Type] -> Compile Operand makeMultiResultAlloc tys = do resultsPtr <- alloca (length tys) hostVoidp - forM_ (zip [0..] tys) $ \(i, ty) -> do + forM_ (zip [0..] tys) \(i, ty) -> do ptr <- alloca 1 ty >>= castVoidPtr resultsPtrOffset <- gep resultsPtr $ i32Lit i store resultsPtrOffset ptr @@ -632,7 +632,7 @@ makeMultiResultAlloc tys = do loadMultiResultAlloc :: [L.Type] -> Operand -> Compile [Operand] loadMultiResultAlloc tys ptr = - forM (zip [0..] tys) $ \(i, ty) -> + forM (zip [0..] tys) \(i, ty) -> gep ptr (i32Lit i) >>= load >>= castLPtr ty >>= load runMCKernel :: ExternFunSpec @@ -894,7 +894,7 @@ runCompile dev m = evalState (runReaderT m env) initState initState = CompileState [] [] [] "start_block" mempty mempty mempty extendOperands :: OperandEnv -> Compile a -> Compile a -extendOperands openv = local $ \env -> env { operandEnv = (operandEnv env) <> openv } +extendOperands openv = local \env -> env { operandEnv = (operandEnv env) <> openv } lookupImpVar :: IVar -> Compile Operand lookupImpVar v = asks ((! v) . operandEnv) @@ -912,7 +912,7 @@ freshName :: Name -> Compile L.Name freshName v = do used <- gets usedNames let v' = genFresh v used - modify $ \s -> s { usedNames = used <> v' @> () } + modify \s -> s { usedNames = used <> v' @> () } return $ nameToLName v' where nameToLName :: Name -> L.Name diff --git a/src/lib/LLVM/JIT.hs b/src/lib/LLVM/JIT.hs index e10228a4c..c73b396be 100644 --- a/src/lib/LLVM/JIT.hs +++ b/src/lib/LLVM/JIT.hs @@ -88,14 +88,14 @@ compileModule moduleJIT@JIT{..} ast compilationPipeline = do resolver <- newSymbolResolver execSession (makeResolver compileLayer) modifyIORef resolvers (M.insert moduleKey resolver) OrcJIT.addModule compileLayer moduleKey llvmModule - moduleDtors <- forM dtorNames $ \dtorName -> do + moduleDtors <- forM dtorNames \dtorName -> do dtorSymbol <- OrcJIT.mangleSymbol compileLayer (fromString dtorName) Right (OrcJIT.JITSymbol dtorAddr _) <- OrcJIT.findSymbol compileLayer dtorSymbol False return $ castPtrToFunPtr $ wordPtrToPtr dtorAddr return NativeModule{..} where makeResolver :: OrcJIT.IRCompileLayer OrcJIT.ObjectLinkingLayer -> OrcJIT.SymbolResolver - makeResolver cl = OrcJIT.SymbolResolver $ \sym -> do + makeResolver cl = OrcJIT.SymbolResolver \sym -> do rsym <- OrcJIT.findSymbol cl sym False -- We look up functions like malloc in the current process -- TODO: Use JITDylibs to avoid inlining addresses as constants: @@ -116,7 +116,7 @@ compileModule moduleJIT@JIT{..} ast compilationPipeline = do -- Unfortunately the JIT layers we use here don't handle the destructors properly, -- so we have to find and call them ourselves. dtorNames = do - let dtorStructs = flip foldMap (LLVM.AST.moduleDefinitions ast) $ \case + let dtorStructs = flip foldMap (LLVM.AST.moduleDefinitions ast) \case LLVM.AST.GlobalDefinition LLVM.AST.GlobalVariable{ name="llvm.global_dtors", diff --git a/src/lib/LLVM/Shims.hs b/src/lib/LLVM/Shims.hs index 860b5540a..e509ac12d 100644 --- a/src/lib/LLVM/Shims.hs +++ b/src/lib/LLVM/Shims.hs @@ -35,7 +35,7 @@ data SymbolResolver = SymbolResolver (FunPtr FFIResolver) (Ptr OrcJIT.FFI.Symbol -- | Create a `FFI.SymbolResolver` that can be used with the JIT. newSymbolResolver :: OrcJIT.ExecutionSession -> OrcJIT.SymbolResolver -> IO SymbolResolver newSymbolResolver (OrcJIT.ExecutionSession session) (OrcJIT.SymbolResolver resolverFn) = do - ffiResolverPtr <- wrapFFIResolver $ \sym res -> do + ffiResolverPtr <- wrapFFIResolver \sym res -> do f <- encodeM =<< resolverFn =<< decodeM sym f res lambdaResolver <- OrcJIT.FFI.createLambdaResolver session ffiResolverPtr @@ -60,10 +60,10 @@ newTargetMachine :: Target.Target newTargetMachine (Target.Target targetFFI) triple cpu features (Target.TargetOptions targetOptFFI) relocModel codeModel cgoLevel = do - SBS.useAsCString triple $ \tripleFFI -> do - BS.useAsCString cpu $ \cpuFFI -> do + SBS.useAsCString triple \tripleFFI -> do + BS.useAsCString cpu \cpuFFI -> do let featuresStr = BS.intercalate "," $ fmap encodeFeature $ M.toList features - BS.useAsCString featuresStr $ \featuresFFI -> do + BS.useAsCString featuresStr \featuresFFI -> do relocModelFFI <- encodeM relocModel codeModelFFI <- encodeM codeModel cgoLevelFFI <- encodeM cgoLevel @@ -79,7 +79,7 @@ newHostTargetMachine relocModel codeModel cgoLevel = do (target, _) <- Target.lookupTarget Nothing triple cpu <- Target.getHostCPUName features <- Target.getHostCPUFeatures - Target.withTargetOptions $ \targetOptions -> + Target.withTargetOptions \targetOptions -> newTargetMachine target triple cpu features targetOptions relocModel codeModel cgoLevel disposeTargetMachine :: Target.TargetMachine -> IO () diff --git a/src/lib/LLVMExec.hs b/src/lib/LLVMExec.hs index b2957cf5b..f4435dd19 100644 --- a/src/lib/LLVMExec.hs +++ b/src/lib/LLVMExec.hs @@ -72,9 +72,9 @@ type DexExitCode = Int compileAndEval :: Logger [Output] -> L.Module -> String -> [LitVal] -> [BaseType] -> IO [LitVal] compileAndEval logger ast fname args resultTypes = do - withPipeToLogger logger $ \fd -> - allocaBytes (length args * cellSize) $ \argsPtr -> - allocaBytes (length resultTypes * cellSize) $ \resultPtr -> do + withPipeToLogger logger \fd -> + allocaBytes (length args * cellSize) \argsPtr -> + allocaBytes (length resultTypes * cellSize) \resultPtr -> do storeLitVals argsPtr args evalTime <- compileOneOff logger ast fname $ checkedCallFunPtr fd argsPtr resultPtr @@ -84,11 +84,11 @@ compileAndEval logger ast fname args resultTypes = do compileAndBench :: Bool -> Logger [Output] -> L.Module -> String -> [LitVal] -> [BaseType] -> IO [LitVal] compileAndBench shouldSyncCUDA logger ast fname args resultTypes = do - withPipeToLogger logger $ \fd -> - allocaBytes (length args * cellSize) $ \argsPtr -> - allocaBytes (length resultTypes * cellSize) $ \resultPtr -> do + withPipeToLogger logger \fd -> + allocaBytes (length args * cellSize) \argsPtr -> + allocaBytes (length resultTypes * cellSize) \resultPtr -> do storeLitVals argsPtr args - compileOneOff logger ast fname $ \fPtr -> do + compileOneOff logger ast fname \fPtr -> do ((avgTime, benchRuns, results), totalTime) <- measureSeconds $ do -- First warmup iteration, which we also use to get the results void $ checkedCallFunPtr fd argsPtr resultPtr fPtr @@ -112,7 +112,7 @@ compileAndBench shouldSyncCUDA logger ast fname args resultTypes = do withPipeToLogger :: Logger [Output] -> (FD -> IO a) -> IO a withPipeToLogger logger writeAction = do result <- snd <$> withPipe - (\h -> readStream h $ \s -> logThis logger [TextOut s]) + (\h -> readStream h \s -> logThis logger [TextOut s]) (\h -> handleToFd h >>= writeAction) case result of Left e -> E.throw e @@ -129,9 +129,9 @@ checkedCallFunPtr fd argsPtr resultPtr fPtr = do compileOneOff :: Logger [Output] -> L.Module -> String -> (DexExecutable -> IO a) -> IO a compileOneOff logger ast name f = do - withHostTargetMachine $ \tm -> - withJIT tm $ \jit -> - withNativeModule jit ast (standardCompilationPipeline logger [name] tm) $ \compiled -> + withHostTargetMachine \tm -> + withJIT tm \jit -> + withNativeModule jit ast (standardCompilationPipeline logger [name] tm) \compiled -> f =<< getFunctionPtr compiled name standardCompilationPipeline :: Logger [Output] -> [String] -> T.TargetMachine -> Mod.Module -> IO () @@ -151,12 +151,12 @@ standardCompilationPipeline logger exports tm m = do -- Each module comes with a list of exported functions exportObjectFile :: FilePath -> [(L.Module, [String])] -> IO () exportObjectFile objFile modules = do - withContext $ \c -> do - withHostTargetMachine $ \tm -> - withBrackets (fmap (toLLVM c) modules) $ \mods -> do - Mod.withModuleFromAST c L.defaultModule $ \exportMod -> do + withContext \c -> do + withHostTargetMachine \tm -> + withBrackets (fmap (toLLVM c) modules) \mods -> do + Mod.withModuleFromAST c L.defaultModule \exportMod -> do void $ foldM linkModules exportMod mods - execLogger Nothing $ \logger -> + execLogger Nothing \logger -> standardCompilationPipeline logger allExports tm exportMod Mod.writeObjectToFile tm (Mod.File objFile) exportMod where @@ -164,14 +164,14 @@ exportObjectFile objFile modules = do toLLVM :: Context -> (L.Module, [String]) -> (Mod.Module -> IO a) -> IO a toLLVM c (ast, exports) cont = do - Mod.withModuleFromAST c ast $ \m -> internalize exports m >> cont m + Mod.withModuleFromAST c ast \m -> internalize exports m >> cont m linkModules a b = a <$ Mod.linkModules a b withBrackets :: [(a -> IO b) -> IO b] -> ([a] -> IO b) -> IO b withBrackets brackets f = go brackets [] where - go (h:t) args = h $ \arg -> go t (arg:args) + go (h:t) args = h \arg -> go t (arg:args) go [] args = f args @@ -179,12 +179,12 @@ exportObjectFile objFile modules = do runDefaultPasses :: T.TargetMachine -> Mod.Module -> IO () runDefaultPasses t m = do - P.withPassManager defaultPasses $ \pm -> void $ P.runPassManager pm m + P.withPassManager defaultPasses \pm -> void $ P.runPassManager pm m -- We are highly dependent on LLVM when it comes to some optimizations such as -- turning a sequence of scalar stores into a vector store, so we execute some -- extra passes to make sure they get simplified correctly. runPasses extraPasses (Just t) m - P.withPassManager defaultPasses $ \pm -> void $ P.runPassManager pm m + P.withPassManager defaultPasses \pm -> void $ P.runPassManager pm m where defaultPasses = P.defaultCuratedPassSetSpec {P.optLevel = Just 3} extraPasses = [ P.SuperwordLevelParallelismVectorize @@ -196,7 +196,7 @@ runPasses passes mt m = do Just t -> Just <$> T.getTargetMachineDataLayout t Nothing -> return Nothing let passSpec = P.PassSetSpec passes dl Nothing mt - P.withPassManager passSpec $ \pm -> void $ P.runPassManager pm m + P.withPassManager passSpec \pm -> void $ P.runPassManager pm m internalize :: [String] -> Mod.Module -> IO () internalize names m = runPasses [P.InternalizeFunctions names, P.GlobalDeadCodeElimination] Nothing m @@ -219,7 +219,7 @@ withHostTargetMachine f = withGPUTargetMachine :: B.ByteString -> (T.TargetMachine -> IO a) -> IO a withGPUTargetMachine computeCapability next = do (tripleTarget, _) <- T.lookupTarget Nothing ptxTargetTriple - T.withTargetOptions $ \topt -> + T.withTargetOptions \topt -> T.withTargetMachine tripleTarget ptxTargetTriple @@ -241,8 +241,8 @@ showAsm :: T.TargetMachine -> Mod.Module -> IO String showAsm t m' = do ctx <- Mod.moduleContext m' -- Uncomment this to dump assembly to a file that can be linked to a C benchmark suite: - -- withModuleClone ctx m' $ \m -> Mod.writeObjectToFile t (Mod.File "asm.o") m - withModuleClone ctx m' $ \m -> unpack <$> Mod.moduleTargetAssembly t m + -- withModuleClone ctx m' \m -> Mod.writeObjectToFile t (Mod.File "asm.o") m + withModuleClone ctx m' \m -> unpack <$> Mod.moduleTargetAssembly t m withModuleClone :: Context -> Mod.Module -> (Mod.Module -> IO a) -> IO a withModuleClone ctx m f = do @@ -291,8 +291,8 @@ ptrArray p = map (\i -> p `plusPtr` (i * cellSize)) [0..] {-# NOINLINE dexrtAST #-} dexrtAST :: L.Module dexrtAST = unsafePerformIO $ do - withContext $ \ctx -> do - Mod.withModuleFromBitcode ctx (("dexrt.c" :: String), dexrtBC) $ \m -> + withContext \ctx -> do + Mod.withModuleFromBitcode ctx (("dexrt.c" :: String), dexrtBC) \m -> stripFunctionAnnotations <$> Mod.moduleAST m where -- We strip the function annotations for dexrt functions, because clang @@ -313,7 +313,7 @@ linkDexrt m = do targetTriple <- Mod.getTargetTriple =<< Mod.readModule m let dexrtTargetAST = dexrtAST { L.moduleDataLayout = dataLayout , L.moduleTargetTriple = targetTriple } - Mod.withModuleFromAST ctx dexrtTargetAST $ \dexrtm -> do + Mod.withModuleFromAST ctx dexrtTargetAST \dexrtm -> do Mod.linkModules m dexrtm runPasses [P.AlwaysInline True] Nothing m @@ -325,21 +325,21 @@ data LLVMKernel = LLVMKernel L.Module compileCUDAKernel :: Logger [Output] -> LLVMKernel -> IO CUDAKernel compileCUDAKernel logger (LLVMKernel ast) = do T.initializeAllTargets - withContext $ \ctx -> - Mod.withModuleFromAST ctx ast $ \m -> do - withGPUTargetMachine (pack arch) $ \tm -> do + withContext \ctx -> + Mod.withModuleFromAST ctx ast \m -> do + withGPUTargetMachine (pack arch) \tm -> do linkLibdevice m standardCompilationPipeline logger ["kernel"] tm m ptx <- Mod.moduleTargetAssembly tm m usePTXAS <- maybe False (=="1") <$> lookupEnv "DEX_USE_PTXAS" if usePTXAS then do - withSystemTempFile "kernel.ptx" $ \ptxPath ptxH -> do + withSystemTempFile "kernel.ptx" \ptxPath ptxH -> do B.hPut ptxH ptx hClose ptxH - withSystemTempFile "kernel.sass" $ \sassPath sassH -> do + withSystemTempFile "kernel.sass" \sassPath sassH -> do let cmd = proc ptxasPath [ptxPath, "-o", sassPath, "-arch=" ++ arch, "-O3"] - withCreateProcess cmd $ \_ _ _ ptxas -> do + withCreateProcess cmd \_ _ _ ptxas -> do code <- waitForProcess ptxas case code of ExitSuccess -> return () @@ -354,7 +354,7 @@ compileCUDAKernel logger (LLVMKernel ast) = do {-# NOINLINE libdevice #-} libdevice :: L.Module libdevice = unsafePerformIO $ do - withContext $ \ctx -> do + withContext \ctx -> do let libdeviceDirectory = "/usr/local/cuda/nvvm/libdevice" [libdeviceFileName] <- listDirectory libdeviceDirectory let libdevicePath = libdeviceDirectory ++ "/" ++ libdeviceFileName @@ -367,8 +367,8 @@ libdevice = unsafePerformIO $ do linkLibdevice :: Mod.Module -> IO () linkLibdevice m = do ctx <- Mod.moduleContext m - Mod.withModuleFromAST ctx zeroNVVMReflect $ \reflectm -> - Mod.withModuleFromAST ctx libdevice $ \ldm -> do + Mod.withModuleFromAST ctx zeroNVVMReflect \reflectm -> + Mod.withModuleFromAST ctx libdevice \ldm -> do Mod.linkModules m ldm Mod.linkModules m reflectm runPasses [P.AlwaysInline True] Nothing m diff --git a/src/lib/Logging.hs b/src/lib/Logging.hs index 1ee82ccdc..37d40fd8a 100644 --- a/src/lib/Logging.hs +++ b/src/lib/Logging.hs @@ -20,7 +20,7 @@ data Logger l = Logger (MVar l) (Maybe Handle) runLogger :: (Monoid l, MonadIO m) => Maybe FilePath -> (Logger l -> m a) -> m (a, l) runLogger maybePath m = do log <- liftIO $ newMVar mempty - logFile <- liftIO $ forM maybePath $ \path -> openFile path WriteMode + logFile <- liftIO $ forM maybePath \path -> openFile path WriteMode ans <- m $ Logger log logFile logged <- liftIO $ readMVar log return (ans, logged) @@ -30,10 +30,10 @@ execLogger maybePath m = fst <$> runLogger maybePath m logThis :: (Pretty l, Monoid l, MonadIO m) => Logger l -> l -> m () logThis (Logger log maybeLogHandle) x = liftIO $ do - forM_ maybeLogHandle $ \h -> do + forM_ maybeLogHandle \h -> do hPutStrLn h $ pprint x hFlush h - modifyMVar_ log $ \cur -> return (cur <> x) + modifyMVar_ log \cur -> return (cur <> x) readLog :: MonadIO m => Logger l -> m l readLog (Logger log _) = liftIO $ readMVar log diff --git a/src/lib/Parallelize.hs b/src/lib/Parallelize.hs index 86aef81d3..e11842020 100644 --- a/src/lib/Parallelize.hs +++ b/src/lib/Parallelize.hs @@ -79,9 +79,9 @@ parallelTraverseExpr expr = case expr of False -> nothingSpecial Hof (RunWriter (BinaryFunVal h b _ body)) -> do ~(RefTy _ accTy) <- traverseAtom substTraversalDef $ binderType b - liftM Atom $ emitRunWriter (binderNameHint b) accTy $ \ref@(Var refVar) -> do + liftM Atom $ emitRunWriter (binderNameHint b) accTy \ref@(Var refVar) -> do let RefTy h' _ = varType refVar - modify $ \accEnv -> accEnv { activeAccs = activeAccs accEnv <> b @> refVar } + modify \accEnv -> accEnv { activeAccs = activeAccs accEnv <> b @> refVar } extendR (h @> h' <> b @> ref) $ evalBlockE parallelTrav body -- TODO: Do some alias analysis. This is not fundamentally hard, but it is a little annoying. -- We would have to track not only the base references, but also all the aliases, along @@ -95,7 +95,7 @@ parallelTraverseExpr expr = case expr of where nothingSpecial = traverseExpr parallelTrav expr disallowRef ~(Var refVar) = - modify $ \accEnv -> accEnv { activeAccs = activeAccs accEnv `envDiff` (refVar @> ()) } + modify \accEnv -> accEnv { activeAccs = activeAccs accEnv `envDiff` (refVar @> ()) } parallelizableEffect :: Env () -> Effect -> Bool parallelizableEffect allowedRegions effect = case effect of @@ -203,7 +203,7 @@ emitLoops buildPureLoop (ABlock decls result) = do let buildBody pari = do is <- unpackConsList pari extendR (newEnv lbs is) $ do - ctxEnv <- flip traverseNames dapps $ \_ (arr, idx) -> + ctxEnv <- flip traverseNames dapps \_ (arr, idx) -> -- XXX: arr is namespaced in the new program foldM appTryReduce arr =<< substEmbedR idx extendR ctxEnv $ evalBlockE appReduceTraversalDef $ Block decls $ Atom result @@ -211,18 +211,18 @@ emitLoops buildPureLoop (ABlock decls result) = do True -> buildPureLoop (Bind $ "pari" :> iterTy) buildBody False -> do body <- do - buildLam (Bind $ "gtid" :> IdxRepTy) PureArrow $ \gtid -> do - buildLam (Bind $ "nthr" :> IdxRepTy) PureArrow $ \nthr -> do + buildLam (Bind $ "gtid" :> IdxRepTy) PureArrow \gtid -> do + buildLam (Bind $ "nthr" :> IdxRepTy) PureArrow \nthr -> do let threadRange = TC $ ParIndexRange iterTy gtid nthr let accTys = mkConsListTy $ fmap (derefType . varType) newRefs - emitRunWriter "refsList" accTys $ \localRefsList -> do + emitRunWriter "refsList" accTys \localRefsList -> do localRefs <- unpackRefConsList localRefsList - buildFor Fwd (Bind $ "tidx" :> threadRange) $ \tidx -> do + buildFor Fwd (Bind $ "tidx" :> threadRange) \tidx -> do pari <- emitOp $ Inject tidx extendR (newEnv oldRefNames localRefs) $ buildBody pari (ans, updateList) <- fromPair =<< (emit $ Hof $ PTileReduce iterTy body) updates <- unpackConsList updateList - forM_ (zip newRefs updates) $ \(ref, update) -> + forM_ (zip newRefs updates) \(ref, update) -> emitOp $ PrimEffect (Var ref) $ MTell update return ans where diff --git a/src/lib/Parser.hs b/src/lib/Parser.hs index 62dc532f7..a167af975 100644 --- a/src/lib/Parser.hs +++ b/src/lib/Parser.hs @@ -383,8 +383,8 @@ funDefLet = label "function definition" $ mayBreak $ do let bs = map classAsBinder cs ++ argBinders let funTy = buildPiType bs eff ty let letBinder = (v, Just funTy) - let lamBinders = flip map bs $ \(p,_, arr) -> ((p,Nothing), arr) - return $ \body -> ULet PlainLet letBinder (buildLam lamBinders body) + let lamBinders = flip map bs \(p,_, arr) -> ((p,Nothing), arr) + return \body -> ULet PlainLet letBinder (buildLam lamBinders body) where classAsBinder :: UType -> (UPat, UType, UArrow) classAsBinder ty = (ns underscorePat, ty, ClassArrow) @@ -892,7 +892,7 @@ prefixNegOp :: Operator Parser UExpr prefixNegOp = Prefix $ label "negation" $ do ((), pos) <- withPos $ sym "-" let f = WithSrc (Just pos) "neg" - return $ \case + return \case -- Special case: negate literals directly WithSrc litpos (IntLitExpr i) -> WithSrc (joinPos (Just pos) litpos) (IntLitExpr (-i)) @@ -914,7 +914,7 @@ infixArrow :: Parser (UType -> UType -> UType) infixArrow = do notFollowedBy (sym "=>") -- table arrows have special fixity (arr, pos) <- withPos $ arrow effects - return $ \a b -> WithSrc (Just pos) $ UPi (Nothing, a) arr b + return \a b -> WithSrc (Just pos) $ UPi (Nothing, a) arr b mkArrow :: Arrow -> UExpr -> UExpr -> UExpr mkArrow arr a b = joinSrc a b $ UPi (Nothing, a) arr b @@ -959,7 +959,7 @@ inpostfix' :: Parser a -> Parser (a -> Maybe a -> a) -> Operator Parser a inpostfix' p op = Postfix $ do f <- op rest <- optional p - return $ \x -> f x rest + return \x -> f x rest mkName :: String -> Name mkName s = Name SourceName (fromString s) 0 diff --git a/src/lib/Serialize.hs b/src/lib/Serialize.hs index 11b552be6..602fdeed2 100644 --- a/src/lib/Serialize.hs +++ b/src/lib/Serialize.hs @@ -31,7 +31,7 @@ getDexString :: Val -> IO String getDexString (DataCon _ _ 0 [_, xs]) = do let (TabTy b _) = getType xs idxs <- indices $ getType b - forM idxs $ \i -> do + forM idxs \i -> do ~(Con (Lit (Word8Lit c))) <- evalBlock mempty (Block Empty (App xs i)) return $ toEnum $ fromIntegral c getDexString x = error $ "Not a string: " ++ pprint x @@ -49,7 +49,7 @@ prettyVal val = case val of _ -> "@" <> pretty idxSet -- Otherwise, show explicit index set -- Pretty-print elements. idxs <- indices idxSet - elems <- forM idxs $ \idx -> do + elems <- forM idxs \idx -> do atom <- evalBlock mempty $ snd $ applyAbs abs idx case atom of Con (Lit (Word8Lit c)) -> diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index 9f3b256da..d1c03e4fd 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -61,7 +61,7 @@ hoistDepDataCons scope (Module Simp decls bindings) = where (bindings', (_, decls')) = flip runEmbed scope $ do mapM_ emitDecl decls - forM bindings $ \(ty, info) -> case info of + forM bindings \(ty, info) -> case info of LetBound ann x | isData ty -> do x' <- emit x return (ty, LetBound ann $ Atom x') _ -> return (ty, info) @@ -89,7 +89,7 @@ simplifyDecl (Let ann b expr) = do simplifyStandalone :: Expr -> SimplifyM Atom simplifyStandalone (Atom (LamVal b body)) = do b' <- mapM substEmbedR b - buildLam b' PureArrow $ \x -> + buildLam b' PureArrow \x -> extendR (b@>x) $ simplifyBlock body simplifyStandalone block = error $ "@noinline decorator applied to non-function" ++ pprint block @@ -139,9 +139,9 @@ simplifyAtom atom = case atom of case simplifyCase e' alts of Just (env, result) -> extendR env $ simplifyAtom result Nothing -> do - alts' <- forM alts $ \(Abs bs a) -> do + alts' <- forM alts \(Abs bs a) -> do bs' <- mapM (mapM substEmbedR) bs - (Abs bs'' b) <- buildNAbs bs' $ \xs -> extendR (newEnv bs' xs) $ simplifyAtom a + (Abs bs'' b) <- buildNAbs bs' \xs -> extendR (newEnv bs' xs) $ simplifyAtom a case b of Block Empty (Atom r) -> return $ Abs bs'' r _ -> error $ "Nontrivial block in ACase simplification" @@ -192,7 +192,7 @@ simplifyLams numArgs lam = do Left res -> (res, Nothing) Right (dat, (ctx, recon), atomf) -> ( mkConsList $ (toList dat) ++ (toList ctx) - , Just $ \vals -> do + , Just \vals -> do (datEls', ctxEls') <- splitAt (length dat) <$> unpackConsList vals let dat' = restructure datEls' dat let ctx' = restructure ctxEls' ctx @@ -200,7 +200,7 @@ simplifyLams numArgs lam = do ) go n scope ~(Block Empty (Atom (Lam (Abs b (arr, body))))) = do b' <- mapM substEmbedR b - buildLamAux b' (\x -> extendR (b@>x) $ substEmbedR arr) $ \x@(Var v) -> do + buildLamAux b' (\x -> extendR (b@>x) $ substEmbedR arr) \x@(Var v) -> do let scope' = scope <> v @> (varType v, LamBound (void arr)) extendR (b@>x) $ go (n-1) scope' body @@ -278,7 +278,7 @@ separateDataComponent localVars v = do True -> nubCtx t False -> h : (nubCtx t) result = nubCtx $ toList ll - inv ctx' result' = for ll $ \x -> case elemIndex x (toList ctx) of + inv ctx' result' = for ll \x -> case elemIndex x (toList ctx) of Just i -> (toList ctx') !! i Nothing -> result' !! (fromJust $ elemIndex x result) @@ -299,7 +299,7 @@ simplifyExpr expr = case expr of case all isCurriedFun alts of True -> return $ ACase e (fmap appAlt alts) rty' False -> do - let alts' = for alts $ \(Abs bs a) -> Abs bs $ Block Empty (App a x') + let alts' = for alts \(Abs bs a) -> Abs bs $ Block Empty (App a x') dropSub $ simplifyExpr $ Case e alts' rty' where isCurriedFun alt = case alt of @@ -321,16 +321,16 @@ simplifyExpr expr = case expr of Nothing -> do if isData resultTy' then do - alts' <- forM alts $ \(Abs bs body) -> do + alts' <- forM alts \(Abs bs body) -> do bs' <- mapM (mapM substEmbedR) bs - buildNAbs bs' $ \xs -> extendR (newEnv bs' xs) $ simplifyBlock body + buildNAbs bs' \xs -> extendR (newEnv bs' xs) $ simplifyBlock body emit $ Case e' alts' resultTy' else do -- Construct the blocks of new cases. The results will only get replaced -- later, once we learn the closures of the non-data component of each case. - (alts', facs) <- liftM unzip $ forM alts $ \(Abs bs body) -> do + (alts', facs) <- liftM unzip $ forM alts \(Abs bs body) -> do bs' <- mapM (mapM substEmbedR) bs - buildNAbsAux bs' $ \xs -> do + buildNAbsAux bs' \xs -> do ~(Right fac@(dat, (ctx, _), _)) <- extendR (newEnv bs' xs) $ defunBlock (boundVars bs') body -- NB: The return value here doesn't really matter as we're going to replace it afterwards. return (mkConsList $ toList dat ++ toList ctx, fac) @@ -361,9 +361,9 @@ simplifyExpr expr = case expr of -- a single output. This can probably be made quite a bit faster. -- NB: All the non-data trees have the same structure, so we pick an arbitrary one. nondatTree <- (\(_, (ctx, rec), _) -> rec dat ctx) $ head facs - nondat <- forM (enumerate nondatTree) $ \(i, _) -> do - aalts <- forM facs $ \(_, (ctx, rec), _) -> do - Abs bs' b <- buildNAbs (toNest $ toList $ fmap (Ignore . getType) ctx) $ \ctxVals -> + nondat <- forM (enumerate nondatTree) \(i, _) -> do + aalts <- forM facs \(_, (ctx, rec), _) -> do + Abs bs' b <- buildNAbs (toNest $ toList $ fmap (Ignore . getType) ctx) \ctxVals -> ((!! i) . toList) <$> rec dat (restructure ctxVals ctx) case b of Block Empty (Atom r) -> return $ Abs bs' r @@ -441,7 +441,7 @@ simplifyHof hof = case hof of ans <- emit $ Hof $ For d lam' case recon of Nothing -> return ans - Just f -> buildLam i TabArrow $ \i' -> app ans i' >>= f + Just f -> buildLam i TabArrow \i' -> app ans i' >>= f Tile d fT fS -> do ~(fT', Nothing) <- simplifyLam fT ~(fS', Nothing) <- simplifyLam fS @@ -495,7 +495,7 @@ exceptToMaybeBlock (Block (Nest (Let _ b expr) decls) result) = do JustAtom _ x -> extendR (b@>x) $ exceptToMaybeBlock $ Block decls result NothingAtom _ -> return $ NothingAtom a _ -> do - emitMaybeCase maybeResult (return $ NothingAtom a) $ \x -> do + emitMaybeCase maybeResult (return $ NothingAtom a) \x -> do extendR (b@>x) $ exceptToMaybeBlock $ Block decls result exceptToMaybeExpr :: Expr -> SubstEmbed Atom @@ -505,27 +505,27 @@ exceptToMaybeExpr expr = do Case e alts resultTy -> do e' <- substEmbedR e resultTy' <- substEmbedR $ MaybeTy resultTy - alts' <- forM alts $ \(Abs bs body) -> do + alts' <- forM alts \(Abs bs body) -> do bs' <- substEmbedR bs - buildNAbs bs' $ \xs -> extendR (newEnv bs' xs) $ exceptToMaybeBlock body + buildNAbs bs' \xs -> extendR (newEnv bs' xs) $ exceptToMaybeBlock body emit $ Case e' alts' resultTy' Atom x -> substEmbedR $ JustAtom (getType x) x Op (ThrowException _) -> return $ NothingAtom a Hof (For ann ~(Lam (Abs b (_, body)))) -> do b' <- substEmbedR b - maybes <- buildForAnn ann b' $ \i -> extendR (b@>i) $ exceptToMaybeBlock body + maybes <- buildForAnn ann b' \i -> extendR (b@>i) $ exceptToMaybeBlock body catMaybesE maybes Hof (RunState s lam) -> do s' <- substEmbedR s let BinaryFunVal _ b _ body = lam - result <- emitRunState "ref" s' $ \ref -> + result <- emitRunState "ref" s' \ref -> extendR (b@>ref) $ exceptToMaybeBlock body (maybeAns, newState) <- fromPair result - emitMaybeCase maybeAns (return $ NothingAtom a) $ \ans -> + emitMaybeCase maybeAns (return $ NothingAtom a) \ans -> return $ JustAtom a $ PairVal ans newState Hof (While ~(Lam (Abs _ (_, body)))) -> do eff <- getAllowedEffects - lam <- buildLam (Ignore UnitTy) (PlainArrow eff) $ \_ -> + lam <- buildLam (Ignore UnitTy) (PlainArrow eff) \_ -> exceptToMaybeBlock body runMaybeWhile lam _ | not (hasExceptions expr) -> do diff --git a/src/lib/Syntax.hs b/src/lib/Syntax.hs index 98a303617..0fc83b9a2 100644 --- a/src/lib/Syntax.hs +++ b/src/lib/Syntax.hs @@ -10,10 +10,8 @@ {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE StrictData #-} {-# LANGUAGE DeriveFunctor #-} -{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ViewPatterns #-} {-# LANGUAGE Rank2Types #-} -{-# LANGUAGE LambdaCase #-} module Syntax ( Type, Kind, BaseType (..), ScalarBaseType (..), @@ -189,14 +187,14 @@ labeledSingleton label value = LabeledItems $ M.singleton label (value NE.:|[]) reflectLabels :: LabeledItems a -> LabeledItems (Label, Int) reflectLabels (LabeledItems items) = LabeledItems $ - flip M.mapWithKey items $ \k xs -> fmap (\(i,_) -> (k,i)) (enumerate xs) + flip M.mapWithKey items \k xs -> fmap (\(i,_) -> (k,i)) (enumerate xs) getLabels :: LabeledItems a -> [Label] getLabels labeledItems = map fst $ toList $ reflectLabels labeledItems withLabels :: LabeledItems a -> LabeledItems (Label, Int, a) withLabels (LabeledItems items) = LabeledItems $ - flip M.mapWithKey items $ \k xs -> fmap (\(i,a) -> (k,i,a)) (enumerate xs) + flip M.mapWithKey items \k xs -> fmap (\(i,a) -> (k,i,a)) (enumerate xs) lookupLabel :: LabeledItems a -> Label -> Maybe a lookupLabel (LabeledItems items) l = case M.lookup l items of @@ -684,10 +682,10 @@ throwIf True e s = throw e s throwIf False _ _ = return () modifyErr :: MonadError e m => m a -> (e -> e) -> m a -modifyErr m f = catchError m $ \e -> throwError (f e) +modifyErr m f = catchError m \e -> throwError (f e) addContext :: MonadError Err m => String -> m a -> m a -addContext s m = modifyErr m $ \(Err e p s') -> Err e p (s' ++ "\n" ++ s) +addContext s m = modifyErr m \(Err e p s') -> Err e p (s' ++ "\n" ++ s) addSrcContext :: MonadError Err m => SrcCtx -> m a -> m a addSrcContext ctx m = modifyErr m updateErr @@ -698,9 +696,9 @@ addSrcContext ctx m = modifyErr m updateErr catchIOExcept :: (MonadIO m , MonadError Err m) => IO a -> m a catchIOExcept m = (liftIO >=> liftEither) $ (liftM Right m) `catches` - [ Handler $ \(e::Err) -> return $ Left e - , Handler $ \(e::IOError) -> return $ Left $ Err DataIOErr Nothing $ show e - , Handler $ \(e::SomeException) -> return $ Left $ Err CompilerErr Nothing $ show e + [ Handler \(e::Err) -> return $ Left e + , Handler \(e::IOError) -> return $ Left $ Err DataIOErr Nothing $ show e + , Handler \(e::SomeException) -> return $ Left $ Err CompilerErr Nothing $ show e ] liftEitherIO :: (Exception e, MonadIO m) => Either e a -> m a diff --git a/src/lib/TopLevel.hs b/src/lib/TopLevel.hs index d7269312d..9399ecfdf 100644 --- a/src/lib/TopLevel.hs +++ b/src/lib/TopLevel.hs @@ -78,7 +78,7 @@ evalSourceBlock opts env block = do Right env' -> return (env' , Result outs' (Right ())) runTopPassM :: Bool -> EvalConfig -> TopPassM a -> IO (Except a, [Output]) -runTopPassM bench opts m = runLogger (logFile opts) $ \logger -> +runTopPassM bench opts m = runLogger (logFile opts) \logger -> runExceptT $ catchIOExcept $ runReaderT m $ TopPassEnv logger bench opts evalSourceBlockM :: TopEnv -> SourceBlock -> TopPassM TopEnv @@ -97,7 +97,7 @@ evalSourceBlockM env block = case sbContents block of logTop $ HtmlOut s ExportFun name -> do f <- evalUModuleVal env v m - void $ traverseLiterals f $ \val -> case val of + void $ traverseLiterals f \val -> case val of PtrLit _ _ -> liftEitherIO $ throw CompilerErr $ "Can't export functions with captured pointers (not implemented)." _ -> return $ Con $ Lit val @@ -119,7 +119,7 @@ processLogs :: LogLevel -> [Output] -> [Output] processLogs logLevel logs = case logLevel of LogAll -> logs LogNothing -> [] - LogPasses passes -> flip filter logs $ \l -> case l of + LogPasses passes -> flip filter logs \l -> case l of PassInfo pass _ | pass `elem` passes -> True | otherwise -> False _ -> False @@ -249,7 +249,7 @@ logTop x = do abstractPtrLiterals :: Block -> ([IBinder], [LitVal], Block) abstractPtrLiterals block = flip evalState mempty $ do - block' <- traverseLiterals block $ \val -> case val of + block' <- traverseLiterals block \val -> case val of PtrLit ty ptr -> do ptrName <- gets $ M.lookup (ty, ptr) . fst case ptrName of diff --git a/src/lib/Type.hs b/src/lib/Type.hs index 39af39287..0e700202d 100644 --- a/src/lib/Type.hs +++ b/src/lib/Type.hs @@ -74,7 +74,7 @@ instance Checkable Module where checkValid m@(Module ir decls bindings) = addContext ("Checking module:\n" ++ pprint m) $ asCompilerErr $ do let env = freeVars m - forM_ (envNames env) $ \v -> when (not $ isGlobal $ v:>()) $ + forM_ (envNames env) \v -> when (not $ isGlobal $ v:>()) $ throw CompilerErr $ "Non-global free variable in module: " ++ pprint v addContext "Checking IR variant" $ checkModuleVariant m addContext "Checking body types" $ do @@ -152,7 +152,7 @@ instance HasType Atom where ACase e alts resultTy -> checkCase e alts resultTy DataConRef ~def@(DataDef _ paramBs [DataConDef _ argBs]) params args -> do checkEq (length paramBs) (length params) - forM_ (zip (toList paramBs) (toList params)) $ \(b, param) -> + forM_ (zip (toList paramBs) (toList params)) \(b, param) -> param |: binderAnn b let argBs' = applyNaryAbs (Abs paramBs argBs) params checkDataConRefBindings argBs' args @@ -203,7 +203,7 @@ typeCheckVar v@(name:>annTy) = do annTy |: TyKind when (annTy == EffKind) $ throw CompilerErr "Effect variables should only occur in effect rows" - checkWithEnv $ \(env, _) -> case envLookup env v of + checkWithEnv \(env, _) -> case envLookup env v of Nothing -> throw CompilerErr $ "Lookup failed: " ++ pprint v Just (ty, _) -> assertEq annTy ty $ "Annotation on var: " ++ pprint name return annTy @@ -227,19 +227,19 @@ instance HasType Expr where checkCase :: HasType b => Atom -> [AltP b] -> Type -> TypeM Type checkCase e alts resultTy = do - checkWithEnv $ \_ -> do + checkWithEnv \_ -> do ety <- typeCheck e case ety of TypeCon def params -> do let cons = applyDataDefParams def params checkEq (length cons) (length alts) - forM_ (zip cons alts) $ \((DataConDef _ bs'), (Abs bs body)) -> do + forM_ (zip cons alts) \((DataConDef _ bs'), (Abs bs body)) -> do checkEq bs' bs resultTy' <- flip (foldr withBinder) bs $ typeCheck body checkEq resultTy resultTy' VariantTy (NoExt types) -> do checkEq (length types) (length alts) - forM_ (zip (toList types) alts) $ \(ty, (Abs bs body)) -> do + forM_ (zip (toList types) alts) \(ty, (Abs bs body)) -> do [b] <- pure $ toList bs checkEq (getType b) ty resultTy' <- flip (foldr withBinder) bs $ typeCheck body @@ -319,7 +319,7 @@ instance HasType Block where instance HasType Binder where typeCheck b = do - checkWithEnv $ \(env, _) -> checkNoShadow env b + checkWithEnv \(env, _) -> checkNoShadow env b let ty = binderType b ty |: TyKind return ty @@ -344,7 +344,7 @@ infixr 7 |: checkEq reqTy ty checkEq :: (Show a, Pretty a, Eq a) => a -> a -> TypeM () -checkEq reqTy ty = checkWithEnv $ \_ -> assertEq reqTy ty "" +checkEq reqTy ty = checkWithEnv \_ -> assertEq reqTy ty "" withBinder :: Binder -> TypeM a -> TypeM a withBinder b m = typeCheck b >> extendTypeEnv (boundVars b) m @@ -407,7 +407,7 @@ instance CoreVariant Expr where Hof e -> checkVariant e >> forM_ e checkVariant Case e alts _ -> do checkVariant e - forM_ alts $ \(Abs _ body) -> checkVariant body + forM_ alts \(Abs _ body) -> checkVariant body instance CoreVariant Decl where -- let annotation restrictions? @@ -470,7 +470,7 @@ goneBy ir = do when (curIR >= ir) $ throw IRVariantErr $ "shouldn't appear after " ++ show ir addExpr :: (Pretty e, MonadError Err m) => e -> m a -> m a -addExpr x m = modifyErr m $ \e -> case e of +addExpr x m = modifyErr m \e -> case e of Err IRVariantErr ctx s -> Err CompilerErr ctx (s ++ ": " ++ pprint x) _ -> e @@ -478,11 +478,11 @@ addExpr x m = modifyErr m $ \e -> case e of checkEffRow :: EffectRow -> TypeM () checkEffRow (EffectRow effs effTail) = do - forM_ effs $ \eff -> case eff of + forM_ effs \eff -> case eff of RWSEffect _ v -> Var (v:>TyKind) |: TyKind ExceptionEffect -> return () - forM_ effTail $ \v -> do - checkWithEnv $ \(env, _) -> case envLookup env (v:>()) of + forM_ effTail \v -> do + checkWithEnv \(env, _) -> case envLookup env (v:>()) of Nothing -> throw CompilerErr $ "Lookup failed: " ++ pprint v Just (ty, _) -> assertEq EffKind ty "Effect var" @@ -490,7 +490,7 @@ declareEff :: Effect -> TypeM () declareEff eff = declareEffs $ oneEffect eff declareEffs :: EffectRow -> TypeM () -declareEffs effs = checkWithEnv $ \(_, allowedEffects) -> +declareEffs effs = checkWithEnv \(_, allowedEffects) -> checkExtends allowedEffects effs checkExtends :: MonadError Err m => EffectRow -> EffectRow -> m () @@ -499,7 +499,7 @@ checkExtends allowed (EffectRow effs effTail) = do case effTail of Just _ -> assertEq allowedEffTail effTail "" Nothing -> return () - forM_ effs $ \eff -> unless (eff `elem` allowedEffs) $ + forM_ effs \eff -> unless (eff `elem` allowedEffs) $ throw CompilerErr $ "Unexpected effect: " ++ pprint eff ++ "\nAllowed: " ++ pprint allowed @@ -517,8 +517,8 @@ ioEffect = RWSEffect State theWorld checkLabeledRow :: ExtLabeledItems Type Name -> TypeM () checkLabeledRow (Ext items rest) = do mapM_ (|: TyKind) items - forM_ rest $ \v -> do - checkWithEnv $ \(env, _) -> case envLookup env (v:>()) of + forM_ rest \v -> do + checkWithEnv \(env, _) -> case envLookup env (v:>()) of Nothing -> throw CompilerErr $ "Lookup failed: " ++ pprint v Just (ty, _) -> assertEq LabeledRowKind ty "Labeled row var" @@ -528,7 +528,7 @@ labeledRowDifference :: ExtLabeledItems Type Name labeledRowDifference (Ext (LabeledItems items) rest) (Ext (LabeledItems subitems) subrest) = do -- Check types in the right. - _ <- flip M.traverseWithKey subitems $ \label subtypes -> + _ <- flip M.traverseWithKey subitems \label subtypes -> case M.lookup label items of Just types -> assertEq subtypes (NE.fromList $ NE.take (length subtypes) types) $ @@ -556,7 +556,7 @@ checkWithEnv check = do CheckWith env -> check env updateTypeEnv :: (TypeEnv -> TypeEnv) -> TypeM a -> TypeM a -updateTypeEnv f m = flip local m $ fmap $ \(env, eff) -> (f env, eff) +updateTypeEnv f m = flip local m $ fmap \(env, eff) -> (f env, eff) extendTypeEnv :: TypeEnv -> TypeM a -> TypeM a extendTypeEnv new m = updateTypeEnv (<> new) m @@ -568,7 +568,7 @@ extendAllowedEffect :: Effect -> TypeM () -> TypeM () extendAllowedEffect eff m = updateAllowedEff (extendEffect eff) m updateAllowedEff :: (EffectRow -> EffectRow) -> TypeM a -> TypeM a -updateAllowedEff f m = flip local m $ fmap $ \(env, eff) -> (env, f eff) +updateAllowedEff f m = flip local m $ fmap \(env, eff) -> (env, f eff) withAllowedEff :: EffectRow -> TypeM a -> TypeM a withAllowedEff eff m = updateAllowedEff (const eff) m @@ -687,7 +687,7 @@ typeCheckOp op = case op of ToOrdinal i -> typeCheck i $> IdxRepTy IdxSetSize i -> typeCheck i $> IdxRepTy FFICall _ ansTy args -> do - forM_ args $ \arg -> do + forM_ args \arg -> do argTy <- typeCheck arg case argTy of BaseTy _ -> return () @@ -815,7 +815,7 @@ typeCheckOp op = case op of t |: TyKind x |: Word8Ty (TypeCon (DataDef _ _ dataConDefs) _) <- return t - forM_ dataConDefs $ \(DataConDef _ binders) -> + forM_ dataConDefs \(DataConDef _ binders) -> assertEq binders Empty "Not an enum" return t From e6855c711e6e7a4d3953a6c8c1b4a71d75206379 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Mon, 4 Jan 2021 15:06:12 -0500 Subject: [PATCH 4/6] Infer types of implicit implicit arguments. I guess that makes them implicitly typed implicit implicit arguments. --- examples/mcmc.dx | 3 +- examples/particle-swarm-optimizer.dx | 1 - lib/prelude.dx | 19 ++++------- src/lib/Inference.hs | 21 ++++++------ src/lib/PPrint.hs | 8 +---- src/lib/Parser.hs | 50 ++++++++++++---------------- src/lib/Syntax.hs | 10 +++--- tests/type-tests.dx | 6 ++-- 8 files changed, 49 insertions(+), 69 deletions(-) diff --git a/examples/mcmc.dx b/examples/mcmc.dx index a3bcbd314..d205cdbd7 100644 --- a/examples/mcmc.dx +++ b/examples/mcmc.dx @@ -28,8 +28,7 @@ def propose accept = logDensity proposal > (logDensity cur + log (rand k)) select accept proposal cur -def meanAndCovariance (n:Type) ?-> (d:Type) ?-> - (xs:n=>d=>Float) : (d=>Float & d=>d=>Float) = +def meanAndCovariance (xs:n=>d=>Float) : (d=>Float & d=>d=>Float) = xsMean : d=>Float = (for i. sum for j. xs.j.i) / IToF (size n) xsCov : d=>d=>Float = (for i i'. sum for j. (xs.j.i' - xsMean.i') * diff --git a/examples/particle-swarm-optimizer.dx b/examples/particle-swarm-optimizer.dx index 21b0cab5d..58227779e 100644 --- a/examples/particle-swarm-optimizer.dx +++ b/examples/particle-swarm-optimizer.dx @@ -57,7 +57,6 @@ We have **arguments**: ' **Returns**: the optimal point found with-in the bounds on the input domain of `f`. def optimize - (d:Type) ?-> (np':Int) -- number of particles (niter:Int) -- number of iterations (key:Key) -- random seed diff --git a/lib/prelude.dx b/lib/prelude.dx index 1831ef83e..4918585c0 100644 --- a/lib/prelude.dx +++ b/lib/prelude.dx @@ -233,7 +233,6 @@ def fstRef (ref: Ref h (a & b)) : Ref h a = %fstRef ref def sndRef (ref: Ref h (a & b)) : Ref h b = %sndRef ref def runReader - (eff:Effects) ?-> (init:r) (action: (h:Type ?-> Ref h r -> {Read h|eff} a)) : {|eff} a = @@ -241,27 +240,23 @@ def runReader %runReader init explicitAction def withReader - (eff:Effects) ?-> (init:r) (action: (h:Type ?-> Ref h r -> {Read h|eff} a)) : {|eff} a = runReader init action def runAccum - (eff:Effects) ?-> (action: (h:Type ?-> Ref h w -> {Accum h|eff} a)) : {|eff} (a & w) = def explicitAction (h':Type) (ref:Ref h' w) : {Accum h'|eff} a = action ref %runWriter explicitAction def yieldAccum - (eff:Effects) ?-> (action: (h:Type ?-> Ref h w -> {Accum h|eff} a)) : {|eff} w = snd $ runAccum action def runState - (eff:Effects) ?-> (init:s) (action: h:Type ?-> Ref h s -> {State h |eff} a) : {|eff} (a & s) = @@ -269,13 +264,11 @@ def runState %runState init explicitAction def withState - (eff:Effects) ?-> (init:s) (action: h:Type ?-> Ref h s -> {State h |eff} a) : {|eff} a = fst $ runState init action def yieldState - (eff:Effects) ?-> (init:s) (action: h:Type ?-> Ref h s -> {State h |eff} a) : {|eff} s = snd $ runState init action @@ -449,10 +442,10 @@ def unsafeFromOrdinal (n : Type) (i : Int) : n = %unsafeFromOrdinal n i def iota (n:Type) : n=>Int = view i. ordinal i -- TODO: we want Eq and Ord for all index sets, not just `Fin n` -instance (n:Int) ?-> Eq (Fin n) +instance Eq (Fin n) (==) = \x y. ordinal x == ordinal y -instance (n:Int) ?-> Ord (Fin n) +instance Ord (Fin n) (>) = \x y. ordinal x > ordinal y (<) = \x y. ordinal x < ordinal y @@ -625,7 +618,7 @@ def newKey (x:Int) : Key = hash (IToI64 0) x def many (f:Key->a) (k:Key) (i:n) : a = f (hash k (ordinal i)) def ixkey (k:Key) (i:n) : Key = hash k (ordinal i) def ixkey2 (k:Key) (i:n) (j:m) : Key = hash (hash k (ordinal i)) (ordinal j) -def splitKey (n:Int) ?-> (k:Key) : Fin n => Key = for i. ixkey k i +def splitKey (k:Key) : Fin n => Key = for i. ixkey k i def rand (k:Key) : Float = unsafeIO do F64ToF $ %ffi randunif Float64 k def randVec (n:Int) (f: Key -> a) (k: Key) : Fin n => a = for i:(Fin n). f (ixkey k i) @@ -1036,7 +1029,7 @@ def fopen (path:String) (mode:StreamMode) : {State World} (Stream mode) = withCString modeStr \(MkCString modePtr). MkStream $ %ffi fopen RawPtr pathPtr modePtr -def fclose (mode:StreamMode) ?-> (stream:Stream mode) : {State World} Unit = +def fclose (stream:Stream mode) : {State World} Unit = (MkStream stream') = stream %ffi fclose Int64 stream' () @@ -1049,7 +1042,7 @@ def fwrite (stream:Stream WriteMode) (s:String) : {State World} Unit = %ffi fflush Int64 stream' () -def while (eff:Effects) ?-> (body: Unit -> {|eff} Bool) : {|eff} Unit = +def while (body: Unit -> {|eff} Bool) : {|eff} Unit = body' : Unit -> {|eff} Word8 = \_. BToW8 $ body () %while body' @@ -1237,7 +1230,7 @@ instance Arbitrary Int32 instance [Arbitrary a] Arbitrary (n=>a) arb = \key. for i. arb $ ixkey key i -instance (n:Int) ?-> Arbitrary (Fin n) +instance Arbitrary (Fin n) arb = randIdx 'Control flow diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index 73d673554..263955d12 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -152,20 +152,20 @@ checkOrInferRho (WithSrc pos expr) reqTy = do addEffects $ arrowEff arr' appVal <- emitZonked $ App fVal xVal' instantiateSigma appVal >>= matchRequirement - UPi (pat, kind) arr ty -> do + UPi (pat, ann) arr ty -> do -- TODO: make sure there's no effect if it's an implicit or table arrow -- TODO: check leaks - kind' <- checkUType kind + ann' <- checkAnn ann piTy <- case pat of - Just pat' -> withNameHint ("pat" :: Name) $ buildPi b \x -> - withBindPat pat' x $ (,) <$> mapM checkUEffRow arr <*> checkUType ty - where b = case pat' of + UnderscoreUPat -> buildPi (Ignore ann') $ const $ + (,) <$> mapM checkUEffRow arr <*> checkUType ty + _ -> withNameHint ("pat" :: Name) $ buildPi b \x -> + withBindPat pat x $ (,) <$> mapM checkUEffRow arr <*> checkUType ty + where b = case pat of -- Note: The binder name becomes part of the type, so we -- need to keep the same name used in the pattern. - WithSrc _ (UPatBinder (Bind (v:>()))) -> Bind (v:>kind') - _ -> Ignore kind' - Nothing -> buildPi (Ignore kind') $ const $ - (,) <$> mapM checkUEffRow arr <*> checkUType ty + WithSrc _ (UPatBinder (Bind (v:>()))) -> Bind (v:>ann') + _ -> Ignore ann' matchRequirement piTy UDecl decl body -> do env <- inferUDecl False decl @@ -526,7 +526,8 @@ checkUEffRow (EffectRow effs t) = do checkUEff :: Effect -> UInferM Effect checkUEff eff = case eff of RWSEffect rws region -> do - (Var (v:>TyKind)) <- lookupSourceVar (region:>()) + (Var (v:>ty)) <- lookupSourceVar (region:>()) + constrainEq TyKind ty return $ RWSEffect rws v ExceptionEffect -> return ExceptionEffect diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index 9757b9fea..aa4883ffb 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -549,7 +549,7 @@ instance PrettyPrec UExpr' where where kw = case dir of Fwd -> "for" Rev -> "rof" UPi binder arr ty -> atPrec LowestPrec $ - prettyUPiBinder binder <+> pretty arr <+> pLowest ty + prettyUBinder binder <+> pretty arr <+> pLowest ty UDecl decl body -> atPrec LowestPrec $ align $ p decl <> hardline <> pLowest body UHole -> atPrec ArgPrec "_" @@ -614,12 +614,6 @@ prettyUBinder (pat, ann) = p pat <> annDoc where Just ty -> ":" <> pApp ty Nothing -> mempty -prettyUPiBinder :: UPiPatAnn -> Doc ann -prettyUPiBinder (pat, ann) = patDoc <> p ann where - patDoc = case pat of - Just pat' -> pApp pat' <> ":" - Nothing -> mempty - spaced :: (Foldable f, Pretty a) => f a -> Doc ann spaced xs = hsep $ map p $ toList xs diff --git a/src/lib/Parser.hs b/src/lib/Parser.hs index a167af975..8472d5d59 100644 --- a/src/lib/Parser.hs +++ b/src/lib/Parser.hs @@ -251,7 +251,7 @@ findImplicitImplicitArgNames typ = filter isLowerCaseName $ envNames $ -- recursive steps UVar _ -> mempty UPi (p, ann) _ ty -> - findVarsInAppLHS ann <> (findVarsInAppLHS ty `envDiff` boundUVars p) + foldMap findVarsInAppLHS ann <> (findVarsInAppLHS ty `envDiff` boundUVars p) UApp _ f x -> findVarsInAppLHS f <> findVarsInAppLHS x ULam (p, ann) _ x -> foldMap findVarsInAppLHS ann <> (findVarsInAppLHS x `envDiff` boundUVars p) @@ -284,12 +284,9 @@ addImplicitImplicitArgs (Just typ) ex = addImplicitArg :: Name -> (UType, UExpr) -> (UType, UExpr) addImplicitArg v (ty, e) = - ( ns $ UPi (Just uPat, uTyKind) ImplicitArrow ty - , ns $ ULam (uPat, Just uTyKind) ImplicitArrow e) - where - uPat = ns $ nameToPat v - k = if v == mkName "eff" then EffectRowKind else TypeKind - uTyKind = ns $ UPrimExpr $ TCExpr k + ( ns $ UPi (uPat, Nothing) ImplicitArrow ty + , ns $ ULam (uPat, Nothing) ImplicitArrow e) + where uPat = ns $ nameToPat v superclassConstraints :: Parser [UType] superclassConstraints = optionalMonoid $ brackets $ uType `sepBy` sym "," @@ -349,12 +346,11 @@ instanceDef = do return $ UInstance ty' methods where addClassConstraint :: UType -> UType -> UType - addClassConstraint c ty = ns $ UPi (Nothing, c) ClassArrow ty + addClassConstraint c ty = ns $ UPi (UnderscoreUPat, Just c) ClassArrow ty addImplicitArg :: Name -> UType -> UType addImplicitArg v ty = - ns $ UPi (Just (ns $ nameToPat v), uTyKind) ImplicitArrow ty - where uTyKind = ns $ UPrimExpr $ TCExpr TypeKind + ns $ UPi (ns $ nameToPat v, Nothing) ImplicitArrow ty instanceMethod :: Parser (UVar, UExpr) instanceMethod = do @@ -386,26 +382,25 @@ funDefLet = label "function definition" $ mayBreak $ do let lamBinders = flip map bs \(p,_, arr) -> ((p,Nothing), arr) return \body -> ULet PlainLet letBinder (buildLam lamBinders body) where - classAsBinder :: UType -> (UPat, UType, UArrow) - classAsBinder ty = (ns underscorePat, ty, ClassArrow) + classAsBinder :: UType -> (UPat, Maybe UType, UArrow) + classAsBinder ty = (UnderscoreUPat, Just ty, ClassArrow) -defArg :: Parser (UPat, UType, UArrow) +defArg :: Parser (UPat, Maybe UType, UArrow) defArg = label "def arg" $ do (p, ty) <-parens ((,) <$> pat <*> annot uType) arr <- arrow (return ()) <|> return (PlainArrow ()) - return (p, ty, arr) + return (p, Just ty, arr) classConstraints :: Parser [UType] classConstraints = label "class constraints" $ optionalMonoid $ brackets $ mayNotPair $ uType `sepBy` sym "," -buildPiType :: [(UPat, UType, UArrow)] -> EffectRow -> UType -> UType +buildPiType :: [(UPat, Maybe UType, UArrow)] -> EffectRow -> UType -> UType buildPiType [] Pure ty = ty buildPiType [] _ _ = error "shouldn't be possible" -buildPiType ((p, patTy, arr):bs) eff resTy = WithSrc pos $ case bs of - [] -> UPi (Just p, patTy) (fmap (const eff ) arr) resTy - _ -> UPi (Just p, patTy) (fmap (const Pure) arr) $ buildPiType bs eff resTy - where WithSrc pos _ = patTy +buildPiType ((p, patTy, arr):bs) eff resTy = ns case bs of + [] -> UPi (p, patTy) (fmap (const eff ) arr) resTy + _ -> UPi (p, patTy) (fmap (const Pure) arr) $ buildPiType bs eff resTy effectiveType :: Parser (EffectRow, UType) effectiveType = (,) <$> effects <*> uType @@ -472,13 +467,10 @@ uForExpr = do <|> (keyWord Rof_KW $> (Rev, True )) e <- buildFor pos dir <$> (some patAnn <* argTerm) <*> blockOrExpr if trailingUnit - then return $ ns $ UDecl (ULet PlainLet (ns underscorePat, Nothing) e) $ + then return $ ns $ UDecl (ULet PlainLet (UnderscoreUPat, Nothing) e) $ ns unitExpr else return e -underscorePat :: UPat' -underscorePat = UPatBinder $ Ignore () - nameToPat :: Name -> UPat' nameToPat v = UPatBinder (Bind (v:>())) @@ -514,7 +506,7 @@ wrapUStatements statements = case statements of (s, pos):rest -> WithSrc (Just pos) $ case s of Left d -> UDecl d $ wrapUStatements rest Right e -> UDecl d $ wrapUStatements rest - where d = ULet PlainLet (ns underscorePat, Nothing) e + where d = ULet PlainLet (UnderscoreUPat, Nothing) e [] -> error "Shouldn't be reachable" uStatement :: Parser UStatement @@ -528,8 +520,8 @@ uPiType = withSrc $ UPi <$> piBinderPat <*> arrow effects <*> uType b <- annBinder return $ case b of Bind (n:>a@(WithSrc pos _)) -> - (Just $ WithSrc pos $ nameToPat n, a) - Ignore a -> (Nothing, a) + (WithSrc pos $ nameToPat n, Just a) + Ignore a -> (UnderscoreUPat, Just a) annBinder :: Parser UAnnBinder annBinder = try $ namedBinder <|> anonBinder @@ -613,7 +605,7 @@ leafPat = <|> brackets (UPatTable <$> leafPat `sepBy` sym ",") ) where pun pos l = WithSrc (Just pos) $ nameToPat $ mkName l - def pos = WithSrc (Just pos) $ underscorePat + def pos = WithSrc (Just pos) $ UPatBinder (Ignore ()) variantPat = parseVariant leafPat UPatVariant UPatVariantLift recordPat = UPatRecord <$> parseLabeledItems "," "=" leafPat (Just pun) (Just def) @@ -914,10 +906,10 @@ infixArrow :: Parser (UType -> UType -> UType) infixArrow = do notFollowedBy (sym "=>") -- table arrows have special fixity (arr, pos) <- withPos $ arrow effects - return \a b -> WithSrc (Just pos) $ UPi (Nothing, a) arr b + return \a b -> WithSrc (Just pos) $ UPi (UnderscoreUPat, Just a) arr b mkArrow :: Arrow -> UExpr -> UExpr -> UExpr -mkArrow arr a b = joinSrc a b $ UPi (Nothing, a) arr b +mkArrow arr a b = joinSrc a b $ UPi (UnderscoreUPat, Just a) arr b withSrc :: Parser a -> Parser (WithSrc a) withSrc p = do diff --git a/src/lib/Syntax.hs b/src/lib/Syntax.hs index 0fc83b9a2..ee0978d15 100644 --- a/src/lib/Syntax.hs +++ b/src/lib/Syntax.hs @@ -39,7 +39,7 @@ module Syntax ( freeVars, freeUVars, Subst, HasVars, BindsVars, Ptr, PtrType, AddressSpace (..), showPrimName, strToPrimName, primNameToStr, monMapSingle, monMapLookup, Direction (..), Limit (..), - UExpr, UExpr' (..), UType, UPatAnn, UPiPatAnn, UAnnBinder, UVar, + UExpr, UExpr' (..), UType, UPatAnn, UAnnBinder, UVar, UPat, UPat' (..), UModule (..), UDecl (..), UArrow, arrowEff, DataDef (..), DataConDef (..), UConDef (..), Nest (..), toNest, subst, deShadow, scopelessSubst, absArgType, applyAbs, makeAbs, @@ -63,7 +63,7 @@ module Syntax ( pattern Unlabeled, pattern NoExt, pattern LabeledRowKind, pattern NoLabeledItems, pattern InternalSingletonLabel, pattern EffKind, pattern NestOne, pattern NewTypeCon, pattern BinderAnn, - pattern ClassDictDef, pattern ClassDictCon) + pattern ClassDictDef, pattern ClassDictCon, pattern UnderscoreUPat) where import qualified Data.Map.Strict as M @@ -225,7 +225,7 @@ prefixExtLabeledItems items (Ext items' rest) = Ext (items <> items') rest type UExpr = WithSrc UExpr' data UExpr' = UVar UVar | ULam UPatAnn UArrow UExpr - | UPi UPiPatAnn Arrow UType + | UPi UPatAnn Arrow UType | UApp UArrow UExpr UExpr | UDecl UDecl UExpr | UFor Direction UPatAnn UExpr @@ -257,7 +257,6 @@ type UVar = VarP () type UBinder = BinderP () type UPatAnn = (UPat, Maybe UType) -type UPiPatAnn = (Maybe UPat, UType) type UAnnBinder = BinderP UType data UAlt = UAlt UPat UExpr deriving (Show, Generic) @@ -285,6 +284,9 @@ srcPos (WithSrc pos _) = pos instance IsString UExpr' where fromString s = UVar $ Name SourceName (fromString s) 0 :> () +pattern UnderscoreUPat :: UPat +pattern UnderscoreUPat = WithSrc Nothing (UPatBinder (Ignore ())) + -- === primitive constructors and operators === data PrimExpr e = diff --git a/tests/type-tests.dx b/tests/type-tests.dx index b9f1b1f0f..2261e6878 100644 --- a/tests/type-tests.dx +++ b/tests/type-tests.dx @@ -158,11 +158,11 @@ MyPair : Type -> Type = -- TODO: put source annotation on effect for a better message here fEff : Unit -> {| a} a = todo > Type error: -> Expected: EffKind -> Actual: Type +> Expected: Type +> Actual: EffKind > > fEff : Unit -> {| a} a = todo -> ^^^^^^^^^ +> ^^ :p for i:(Fin 7). sum for j:(Fin unboundName). 1.0 From beabafbb4db8df5c0ded2354bc2dd182496c3d57 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Tue, 5 Jan 2021 14:47:27 -0500 Subject: [PATCH 5/6] Make fixes suggested in review. --- src/lib/Embed.hs | 6 +++--- src/lib/Inference.hs | 10 +++++----- src/lib/PPrint.hs | 43 ++++++++++++++++++++++++++++++++++++++++++- src/lib/Syntax.hs | 20 ++++++++++---------- tests/adt-tests.dx | 14 +++++++------- 5 files changed, 67 insertions(+), 26 deletions(-) diff --git a/src/lib/Embed.hs b/src/lib/Embed.hs index c46397d64..330daa991 100644 --- a/src/lib/Embed.hs +++ b/src/lib/Embed.hs @@ -17,7 +17,7 @@ module Embed (emit, emitTo, emitAnn, emitOp, buildDepEffLam, buildLamAux, buildP app, add, mul, sub, neg, div', iadd, imul, isub, idiv, ilt, ieq, - fpow, flog, fLitLike, recGet, buildImplicitNaryLam, + fpow, flog, fLitLike, recGetHead, buildImplicitNaryLam, select, substEmbed, substEmbedR, emitUnpack, getUnpacked, fromPair, getFst, getSnd, getFstRef, getSndRef, naryApp, appReduce, appTryReduce, buildAbs, @@ -206,8 +206,8 @@ buildImplicitNaryLam (Nest b bs) body = bs' <- substEmbed (b@>x) bs buildImplicitNaryLam bs' \xs -> body $ x:xs -recGet :: Label -> Atom -> Atom -recGet l x = do +recGetHead :: Label -> Atom -> Atom +recGetHead l x = do let (RecordTy (Ext r _)) = getType x let i = fromJust $ elemIndex l $ map fst $ toList $ reflectLabels r getProjection [i] x diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index 263955d12..06cb409af 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -413,7 +413,7 @@ emitMethodGetters def@(DataDef _ paramBs (ClassDictDef _ _ methodTys)) = do forM_ (getLabels methodTys) \l -> do f <- buildImplicitNaryLam paramBs \params -> do buildLam (Bind ("d":> TypeCon def params)) ClassArrow \dict -> do - return $ recGet l $ getProjection [1] dict + return $ recGetHead l $ getProjection [1] dict let methodName = GlobalName $ fromString l checkNotInScope methodName emitTo methodName PlainLet $ Atom f @@ -424,10 +424,10 @@ emitSuperclassGetters def@(DataDef _ paramBs (ClassDictDef _ superclassTys _)) = forM_ (getLabels superclassTys) \l -> do f <- buildImplicitNaryLam paramBs \params -> do buildLam (Bind ("d":> TypeCon def params)) PureArrow \dict -> do - return $ recGet l $ getProjection [0] dict + return $ recGetHead l $ getProjection [0] dict getterName <- freshClassGenName emitTo getterName SuperclassLet $ Atom f -emitSuperclassGetter (DataDef _ _ _) = error "Not a class dictionary" +emitSuperclassGetters (DataDef _ _ _) = error "Not a class dictionary" checkNotInScope :: Name -> UInferM () checkNotInScope v = do @@ -486,7 +486,7 @@ checkInstance ty methods = case ty of ClassDictDef _ superclassTys methodTys -> do methods' <- liftM mkLabeledItems $ forM methods \((v:>()), rhs) -> do let v' = nameToLabel v - case lookupLabel methodTys v' of + case lookupLabelHead methodTys v' of Nothing -> throw TypeErr (pprint v ++ " is not a method of " ++ pprint className) Just methodTy -> do rhs' <- checkSigma rhs Suggest methodTy @@ -495,7 +495,7 @@ checkInstance ty methods = case ty of forM_ (reflectLabels methods') \(l,i) -> when (i > 0) $ throw TypeErr $ "Duplicate method: " ++ pprint l forM_ (reflectLabels methodTys) \(l,_) -> - case lookupLabel methods' l of + case lookupLabelHead methods' l of Nothing -> throw TypeErr $ "Missing method: " ++ pprint l Just _ -> return () return $ ClassDictCon def params superclassHoles methods' diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index aa4883ffb..600d42b38 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -21,6 +21,7 @@ import Data.Foldable (toList) import qualified Data.List.NonEmpty as NE import qualified Data.Map.Strict as M import qualified Data.ByteString.Lazy.Char8 as B +import Data.Maybe (fromMaybe) import Data.String (fromString) import Data.Text.Prettyprint.Doc.Render.Text import Data.Text.Prettyprint.Doc @@ -32,6 +33,7 @@ import Numeric import Env import Syntax +import Util (enumerate) -- Specifies what kinds of operations are allowed to be printed at this point. -- Printing at AppPrec level means that applications can be printed @@ -362,7 +364,7 @@ instance PrettyPrec Atom where "DataConRef" <+> p params <+> p args BoxedRef b ptr size body -> atPrec AppPrec $ "Box" <+> p b <+> "<-" <+> p ptr <+> "[" <> p size <> "]" <+> hardline <> "in" <+> p body - ProjectElt idxs x -> atPrec LowestPrec $ "project" <+> p idxs <+> p x + ProjectElt idxs x -> prettyProjection idxs x instance Pretty DataConRefBinding where pretty = prettyFromPrettyPrec instance PrettyPrec DataConRefBinding where @@ -374,6 +376,45 @@ fromInfix t = do (t'', ')') <- unsnoc t' return t'' +prettyProjection :: NE.NonEmpty Int -> Var -> DocPrec ann +prettyProjection idxs (name :> ty) = prettyPrec uproj where + -- Builds a source expression that performs the given projection. + uproj = UApp (PlainArrow ()) (nosrc ulam) (nosrc uvar) + ulam = ULam (upat, Nothing) (PlainArrow ()) (nosrc $ UVar $ target :> ()) + uvar = UVar $ name :> () + (_, upat, target) = buildProj idxs + + buildProj :: NE.NonEmpty Int -> (Type, UPat, Name) + buildProj (i NE.:| is) = let + -- Lazy Haskell trick: refer to `target` even though this function is + -- responsible for setting it! + (ty', pat', eltName) = case NE.nonEmpty is of + Just is' -> let (x, y, z) = buildProj is' in (x, y, Just z) + Nothing -> (ty, nosrc $ UPatBinder $ Bind $ target :> (), Nothing) + in case ty' of + TypeCon def params -> let + [DataConDef conName bs] = applyDataDefParams def params + b = toList bs !! i + pats = (\(j,_)-> if i == j then pat' else uignore) <$> enumerate bs + hint = case b of + Bind (n :> _) -> n + Ignore _ -> Name SourceName "elt" 0 + in ( binderAnn b, nosrc $ UPatCon conName pats, fromMaybe hint eltName) + RecordTy (NoExt types) -> let + ty'' = toList types !! i + pats = (\(j,_)-> if i == j then pat' else uignore) <$> enumerate types + (fieldName, _) = toList (reflectLabels types) !! i + hint = Name SourceName (fromString fieldName) 0 + in (ty'', nosrc $ UPatRecord $ NoExt pats, fromMaybe hint eltName) + PairTy x _ | i == 0 -> + (x, nosrc $ UPatPair pat' uignore, fromMaybe "a" eltName) + PairTy _ y | i == 1 -> + (y, nosrc $ UPatPair uignore pat', fromMaybe "b" eltName) + _ -> error "Bad projection" + + nosrc = WithSrc Nothing + uignore = nosrc $ UPatBinder $ Ignore () + prettyExtLabeledItems :: (PrettyPrec a, PrettyPrec b) => ExtLabeledItems a b -> Doc ann -> Doc ann -> DocPrec ann prettyExtLabeledItems (Ext (LabeledItems row) rest) separator bindwith = diff --git a/src/lib/Syntax.hs b/src/lib/Syntax.hs index ee0978d15..1f99e550f 100644 --- a/src/lib/Syntax.hs +++ b/src/lib/Syntax.hs @@ -28,8 +28,8 @@ module Syntax ( IExpr (..), IVal, ImpInstr (..), Backend (..), Device (..), IPrimOp, IVar, IBinder, IType, SetVal (..), MonMap (..), LitProg, IFunType (..), IFunVar, CallingConvention (..), IsCUDARequired (..), - UAlt (..), AltP, Alt, Label, LabeledItems (..), labeledSingleton, lookupLabel, - reflectLabels, withLabels, ExtLabeledItems (..), + UAlt (..), AltP, Alt, Label, LabeledItems (..), labeledSingleton, + lookupLabelHead, reflectLabels, withLabels, ExtLabeledItems (..), prefixExtLabeledItems, getLabels, IScope, BinderInfo (..), Bindings, CUDAKernel (..), BenchStats, SrcCtx, Result (..), Output (..), OutFormat (..), @@ -196,8 +196,8 @@ withLabels :: LabeledItems a -> LabeledItems (Label, Int, a) withLabels (LabeledItems items) = LabeledItems $ flip M.mapWithKey items \k xs -> fmap (\(i,a) -> (k,i,a)) (enumerate xs) -lookupLabel :: LabeledItems a -> Label -> Maybe a -lookupLabel (LabeledItems items) l = case M.lookup l items of +lookupLabelHead :: LabeledItems a -> Label -> Maybe a +lookupLabelHead (LabeledItems items) l = case M.lookup l items of Nothing -> Nothing Just (x NE.:| _) -> Just x @@ -798,8 +798,9 @@ instance BindsUVars UPat' where instance HasUVars UDecl where freeUVars (ULet _ p expr) = freeUVars p <> freeUVars expr freeUVars (UData (UConDef _ bs) dataCons) = freeUVars $ Abs bs dataCons - freeUVars (UInterface _ _ _) = mempty -- TODO - freeUVars (UInstance _ _) = mempty -- TODO + freeUVars (UInterface superclasses tc methods) = + freeUVars $ Abs tc (superclasses, methods) + freeUVars (UInstance ty methods) = mempty -- TODO instance BindsUVars UDecl where boundUVars decl = case decl of @@ -1538,15 +1539,14 @@ pattern BinderAnn x <- ((\case Ignore ann -> ann where BinderAnn x = Ignore x pattern NewTypeCon :: Name -> Type -> [DataConDef] -pattern NewTypeCon con ty <- [DataConDef con (NestOne (BinderAnn ty))] - where NewTypeCon con ty = [DataConDef con (NestOne (Ignore ty))] +pattern NewTypeCon con ty = [DataConDef con (NestOne (BinderAnn ty))] pattern ClassDictDef :: Name -> LabeledItems Type -> LabeledItems Type -> [DataConDef] pattern ClassDictDef conName superclasses methods = [DataConDef conName - (Nest (Ignore (RecordTy (NoExt superclasses))) - (Nest (Ignore (RecordTy (NoExt methods))) Empty))] + (Nest (BinderAnn (RecordTy (NoExt superclasses))) + (Nest (BinderAnn (RecordTy (NoExt methods))) Empty))] pattern ClassDictCon :: DataDef -> [Type] -> LabeledItems Atom -> LabeledItems Atom -> Atom diff --git a/tests/adt-tests.dx b/tests/adt-tests.dx index 97ad29dc7..1d2d2306e 100644 --- a/tests/adt-tests.dx +++ b/tests/adt-tests.dx @@ -216,7 +216,7 @@ def catLists (xs:List a) (ys:List a) : List a = def listToTable ((AsList n xs): List a) : (Fin n)=>a = xs :t listToTable -> ((a:Type) ?-> (pat:(List a)) -> (Fin (project [0] pat:(List a))) => a) +> ((a:Type) ?-> (pat:(List a)) -> (Fin ((\((AsList n _)). n) pat)) => a) :p l = AsList _ [1, 2, 3] @@ -228,7 +228,7 @@ def listToTable2 (l: List a) : (Fin (listLength l))=>a = xs :t listToTable2 -> ((a:Type) ?-> (l:(List a)) -> (Fin (project [0] l:(List a))) => a) +> ((a:Type) ?-> (l:(List a)) -> (Fin ((\((AsList n _)). n) l)) => a) :p l = AsList _ [1, 2, 3] @@ -258,7 +258,7 @@ def graphToAdjacencyMatrix ((MkGraph n nodes m edges):Graph a) : n=>n=>Bool = :t graphToAdjacencyMatrix > ((a:Type) > ?-> (pat:(Graph a)) -> -> (project [0] pat:(Graph a)) => (project [0] pat:(Graph a)) => Bool) +> -> ((\((MkGraph n _ _ _)). n) pat) => ((\((MkGraph n _ _ _)). n) pat) => Bool) :p g : Graph Int = MkGraph (Fin 3) [5, 6, 7] (Fin 4) [(0@_, 1@_), (0@_, 2@_), (2@_, 0@_), (1@_, 1@_)] @@ -269,15 +269,15 @@ def graphToAdjacencyMatrix ((MkGraph n nodes m edges):Graph a) : n=>n=>Bool = def pairUnpack ((v, _):(Int & Float)) : Int = v :p pairUnpack -> \pat:(Int32 & Float32). project [0] pat:(Int32 & Float32) +> \pat:(Int32 & Float32). (\(a, _). a) pat def adtUnpack ((MkMyPair v _):MyPair Int Float) : Int = v :p adtUnpack -> \pat:(MyPair Int32 Float32). project [0] pat:(MyPair Int32 Float32) +> \pat:(MyPair Int32 Float32). (\((MkMyPair elt _)). elt) pat def recordUnpack ({a=v, b=_}:{a:Int & b:Float}) : Int = v :p recordUnpack -> \pat:{a: Int32 & b: Float32}. project [0] pat:{a: Int32 & b: Float32} +> \pat:{a: Int32 & b: Float32}. (\{a = a, b = _}. a) pat def nestedUnpack (x:MyPair Int (MyPair (MyIntish & Int) Int)) : Int = (MkMyPair _ (MkMyPair (MkIntish y, _) _)) = x @@ -285,7 +285,7 @@ def nestedUnpack (x:MyPair Int (MyPair (MyIntish & Int) Int)) : Int = :p nestedUnpack > \x:(MyPair Int32 (MyPair (MyIntish & Int32) Int32)). -> project [0, 0, 0, 1] x:(MyPair Int32 (MyPair (MyIntish & Int32) Int32)) +> (\((MkIntish (((MkMyPair ((MkMyPair _ elt)) _)), _))). elt) x :p nestedUnpack (MkMyPair 3 (MkMyPair (MkIntish 4, 5) 6)) > 4 From ff91d2dd25ba082cf2277e40729dc7f429786000 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Tue, 5 Jan 2021 16:50:08 -0500 Subject: [PATCH 6/6] Make scoping of binders in instance declarations more explicit. And fix type inference to handle them properly. --- src/lib/Inference.hs | 82 +++++++++++++++++++++++++------------------- src/lib/PPrint.hs | 7 ++-- src/lib/Parser.hs | 36 +++++++++---------- src/lib/Syntax.hs | 27 ++++++++++----- 4 files changed, 89 insertions(+), 63 deletions(-) diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index 06cb409af..84c11c9d2 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -361,15 +361,14 @@ inferUDecl True (UInterface superclasses tc methods) = do emitSuperclassGetters dataDef emitMethodGetters dataDef return mempty -inferUDecl True (UInstance instanceTy methods) = do - ty <- checkUType instanceTy - instanceDict <- checkInstance ty methods +inferUDecl True (UInstance argBinders instanceTy methods) = do + instanceDict <- checkInstance argBinders instanceTy methods let instanceName = Name TypeClassGenName "instance" 0 void $ emitTo instanceName InstanceLet $ Atom instanceDict return mempty inferUDecl False (UData _ _ ) = error "data definitions should be top-level" inferUDecl False (UInterface _ _ _) = error "interface definitions should be top-level" -inferUDecl False (UInstance _ _ ) = error "instance definitions should be top-level" +inferUDecl False (UInstance _ _ _) = error "instance definitions should be top-level" freshClassGenName :: MonadEmbed m => m Name freshClassGenName = do @@ -479,36 +478,46 @@ checkULam (p, ann) body piTy = do \x@(Var v) -> checkLeaks [v] $ withBindPat p x $ checkSigma body Suggest $ snd $ applyAbs piTy x -checkInstance :: Type -> [(UVar, UExpr)] -> UInferM Atom -checkInstance ty methods = case ty of - TypeCon def@(DataDef className _ _) params -> do - case applyDataDefParams def params of - ClassDictDef _ superclassTys methodTys -> do - methods' <- liftM mkLabeledItems $ forM methods \((v:>()), rhs) -> do - let v' = nameToLabel v - case lookupLabelHead methodTys v' of - Nothing -> throw TypeErr (pprint v ++ " is not a method of " ++ pprint className) - Just methodTy -> do - rhs' <- checkSigma rhs Suggest methodTy - return (v', rhs') - let superclassHoles = fmap (Con . ClassDictHole Nothing) superclassTys - forM_ (reflectLabels methods') \(l,i) -> - when (i > 0) $ throw TypeErr $ "Duplicate method: " ++ pprint l - forM_ (reflectLabels methodTys) \(l,_) -> - case lookupLabelHead methods' l of - Nothing -> throw TypeErr $ "Missing method: " ++ pprint l - Just _ -> return () - return $ ClassDictCon def params superclassHoles methods' - _ -> throw TypeErr $ "Not a valid instance: " ++ pprint ty - Pi (Abs b (arrow, bodyTy)) -> do - case arrow of - ImplicitArrow -> return () - ClassArrow -> return () - _ -> throw TypeErr $ "Not a valid arrow for an instance: " ++ pprint arrow - buildLam b arrow \x@(Var v) -> do - bodyTy' <- substEmbed (b@>x) bodyTy - checkLeaks [v] $ extendR (b@>x) $ checkInstance bodyTy' methods - _ -> throw TypeErr $ "Not a valid instance type: " ++ pprint ty +checkInstance :: Nest UPatAnnArrow -> UType -> [UMethodDef] -> UInferM Atom +checkInstance Empty ty methods = do + ty' <- checkUType ty + case ty' of + TypeCon def@(DataDef className _ _) params -> + case applyDataDefParams def params of + ClassDictDef _ superclassTys methodTys -> do + let superclassHoles = fmap (Con . ClassDictHole Nothing) superclassTys + methods' <- checkMethodDefs className methodTys methods + return $ ClassDictCon def params superclassHoles methods' + _ -> throw TypeErr $ "Not a valid instance type: " ++ pprint ty + _ -> throw TypeErr $ "Not a valid instance type: " ++ pprint ty +checkInstance (Nest ((p, ann), arrow) rest) ty methods = do + case arrow of + ImplicitArrow -> return () + ClassArrow -> return () + _ -> throw TypeErr $ "Not a valid arrow for an instance: " ++ pprint arrow + argTy <- checkAnn ann + buildLam (Bind $ patNameHint p :> argTy) (fromUArrow arrow) \x@(Var v) -> + checkLeaks [v] $ withBindPat p x $ checkInstance rest ty methods + + +checkMethodDefs :: Name -> LabeledItems Type -> [UMethodDef] + -> UInferM (LabeledItems Atom) +checkMethodDefs className methodTys methods = do + methods' <- liftM mkLabeledItems $ forM methods \(UMethodDef (v:>()) rhs) -> do + let v' = nameToLabel v + case lookupLabelHead methodTys v' of + Nothing -> throw TypeErr $ + pprint v ++ " is not a method of " ++ pprint className + Just methodTy -> do + rhs' <- checkSigma rhs Suggest methodTy + return (v', rhs') + forM_ (reflectLabels methods') \(l,i) -> + when (i > 0) $ throw TypeErr $ "Duplicate method: " ++ pprint l + forM_ (reflectLabels methodTys) \(l,_) -> + case lookupLabelHead methods' l of + Nothing -> throw TypeErr $ "Missing method: " ++ pprint l + Just _ -> return () + return methods' checkUEffRow :: EffectRow -> UInferM EffectRow checkUEffRow (EffectRow effs t) = do @@ -732,13 +741,16 @@ inferTabCon xs reqTy = do return (tabTy, xs') emitZonked $ Op $ TabCon tabTy xs' +fromUArrow :: UArrow -> Arrow +fromUArrow arr = fmap (const Pure) arr + -- Bool flag is just to tweak the reported error message fromPiType :: Bool -> UArrow -> Type -> UInferM PiType fromPiType _ _ (Pi piTy) = return piTy -- TODO: check arrow fromPiType expectPi arr ty = do a <- freshType TyKind b <- freshType TyKind - let piTy = Abs (Ignore a) (fmap (const Pure) arr, b) + let piTy = Abs (Ignore a) (fromUArrow arr, b) if expectPi then constrainEq (Pi piTy) ty else constrainEq ty (Pi piTy) return piTy diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index 600d42b38..eecfe641e 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -631,8 +631,11 @@ instance Pretty UDecl where "data" <+> p tyCon <+> "where" <> nest 2 (hardline <> prettyLines dataCons) pretty (UInterface cs def methods) = "interface" <+> p cs <+> p def <> hardline <> prettyLines methods - pretty (UInstance ty methods) = - "instance" <+> p ty <> hardline <> prettyLines methods + pretty (UInstance bs ty methods) = + "instance" <+> p bs <+> p ty <> hardline <> prettyLines methods + +instance Pretty UMethodDef where + pretty (UMethodDef b rhs) = p b <+> "=" <+> p rhs instance Pretty UConDef where pretty (UConDef con bs) = p con <+> spaced bs diff --git a/src/lib/Parser.hs b/src/lib/Parser.hs index 8472d5d59..b6f9d3fd0 100644 --- a/src/lib/Parser.hs +++ b/src/lib/Parser.hs @@ -259,7 +259,7 @@ findImplicitImplicitArgNames typ = filter isLowerCaseName $ envNames $ UFor _ _ _ -> error "Unexpected for in type annotation" UHole -> mempty UTypeAnn v ty -> findVarsInAppLHS v <> findVarsInAppLHS ty - UTabCon _ -> mempty + UTabCon _ -> error "Unexpected table constructor in type annotation" UIndexRange low high -> foldMap findVarsInAppLHS low <> foldMap findVarsInAppLHS high UPrimExpr prim -> foldMap findVarsInAppLHS prim @@ -339,25 +339,25 @@ instanceDef = do explicitArgs <- many defArg constraints <- classConstraints classTy <- uType - let ty = buildPiType explicitArgs Pure $ - foldr addClassConstraint classTy constraints - let ty' = foldr addImplicitArg ty $ findImplicitImplicitArgNames ty + let implicitArgs = findImplicitImplicitArgNames $ + buildPiType explicitArgs Pure $ + foldr addClassConstraint classTy constraints + let argBinders = + [((ns (nameToPat v), Nothing), ImplicitArrow) | v <- implicitArgs] ++ + explicitArgs ++ + [((UnderscoreUPat, Just c) , ClassArrow ) | c <- constraints] methods <- onePerLine instanceMethod - return $ UInstance ty' methods + return $ UInstance (toNest argBinders) classTy methods where addClassConstraint :: UType -> UType -> UType addClassConstraint c ty = ns $ UPi (UnderscoreUPat, Just c) ClassArrow ty - addImplicitArg :: Name -> UType -> UType - addImplicitArg v ty = - ns $ UPi (ns $ nameToPat v, Nothing) ImplicitArrow ty - -instanceMethod :: Parser (UVar, UExpr) +instanceMethod :: Parser UMethodDef instanceMethod = do v <- anyName sym "=" rhs <- blockOrExpr - return (v:>(), rhs) + return $ UMethodDef (v:>()) rhs simpleLet :: Parser (UExpr -> UDecl) simpleLet = label "let binding" $ do @@ -379,26 +379,26 @@ funDefLet = label "function definition" $ mayBreak $ do let bs = map classAsBinder cs ++ argBinders let funTy = buildPiType bs eff ty let letBinder = (v, Just funTy) - let lamBinders = flip map bs \(p,_, arr) -> ((p,Nothing), arr) + let lamBinders = flip map bs \((p,_), arr) -> ((p,Nothing), arr) return \body -> ULet PlainLet letBinder (buildLam lamBinders body) where - classAsBinder :: UType -> (UPat, Maybe UType, UArrow) - classAsBinder ty = (UnderscoreUPat, Just ty, ClassArrow) + classAsBinder :: UType -> UPatAnnArrow + classAsBinder ty = ((UnderscoreUPat, Just ty), ClassArrow) -defArg :: Parser (UPat, Maybe UType, UArrow) +defArg :: Parser UPatAnnArrow defArg = label "def arg" $ do (p, ty) <-parens ((,) <$> pat <*> annot uType) arr <- arrow (return ()) <|> return (PlainArrow ()) - return (p, Just ty, arr) + return ((p, Just ty), arr) classConstraints :: Parser [UType] classConstraints = label "class constraints" $ optionalMonoid $ brackets $ mayNotPair $ uType `sepBy` sym "," -buildPiType :: [(UPat, Maybe UType, UArrow)] -> EffectRow -> UType -> UType +buildPiType :: [UPatAnnArrow] -> EffectRow -> UType -> UType buildPiType [] Pure ty = ty buildPiType [] _ _ = error "shouldn't be possible" -buildPiType ((p, patTy, arr):bs) eff resTy = ns case bs of +buildPiType (((p, patTy), arr):bs) eff resTy = ns case bs of [] -> UPi (p, patTy) (fmap (const eff ) arr) resTy _ -> UPi (p, patTy) (fmap (const Pure) arr) $ buildPiType bs eff resTy diff --git a/src/lib/Syntax.hs b/src/lib/Syntax.hs index 1f99e550f..1d26b02c2 100644 --- a/src/lib/Syntax.hs +++ b/src/lib/Syntax.hs @@ -40,6 +40,7 @@ module Syntax ( AddressSpace (..), showPrimName, strToPrimName, primNameToStr, monMapSingle, monMapLookup, Direction (..), Limit (..), UExpr, UExpr' (..), UType, UPatAnn, UAnnBinder, UVar, + UMethodDef (..), UPatAnnArrow, UPat, UPat' (..), UModule (..), UDecl (..), UArrow, arrowEff, DataDef (..), DataConDef (..), UConDef (..), Nest (..), toNest, subst, deShadow, scopelessSubst, absArgType, applyAbs, makeAbs, @@ -245,18 +246,21 @@ data UExpr' = UVar UVar deriving (Show, Generic) data UConDef = UConDef Name (Nest UAnnBinder) deriving (Show, Generic) -data UDecl = ULet LetAnn UPatAnn UExpr - | UData UConDef [UConDef] - | UInterface [UType] UConDef [UAnnBinder] - | UInstance UType [(UVar, UExpr)] - deriving (Show, Generic) +data UDecl = + ULet LetAnn UPatAnn UExpr + | UData UConDef [UConDef] + | UInterface [UType] UConDef [UAnnBinder] -- superclasses, constructor, methods + | UInstance (Nest UPatAnnArrow) UType [UMethodDef] -- args, type, methods + deriving (Show, Generic) type UType = UExpr type UArrow = ArrowP () type UVar = VarP () type UBinder = BinderP () +data UMethodDef = UMethodDef UVar UExpr deriving (Show, Generic) -type UPatAnn = (UPat, Maybe UType) +type UPatAnn = (UPat, Maybe UType) +type UPatAnnArrow = (UPatAnn, UArrow) type UAnnBinder = BinderP UType data UAlt = UAlt UPat UExpr deriving (Show, Generic) @@ -800,14 +804,21 @@ instance HasUVars UDecl where freeUVars (UData (UConDef _ bs) dataCons) = freeUVars $ Abs bs dataCons freeUVars (UInterface superclasses tc methods) = freeUVars $ Abs tc (superclasses, methods) - freeUVars (UInstance ty methods) = mempty -- TODO + freeUVars (UInstance bsArrows ty methods) = freeUVars $ Abs bs (ty, methods) + where bs = fmap fst bsArrows + +instance HasUVars UMethodDef where + freeUVars (UMethodDef _ def) = freeUVars def + +instance BindsUVars UPatAnn where + boundUVars (p, _) = boundUVars p instance BindsUVars UDecl where boundUVars decl = case decl of ULet _ (p,_) _ -> boundUVars p UData tyCon dataCons -> boundUVars tyCon <> foldMap boundUVars dataCons UInterface _ _ _ -> mempty - UInstance _ _ -> mempty + UInstance _ _ _ -> mempty instance HasUVars UModule where freeUVars (UModule decls) = freeUVars decls