Skip to content

Commit

Permalink
SCP-2638: simplify datatypes which are used only at the type-level (#…
Browse files Browse the repository at this point in the history
…4289)

* SCP-2638: simplify datatypes which are used only at the type-level

1. Change dependency analysis to account for the fact that the
term-level parts can be removed (see note).
2. Simplify datatype bindings into trivial type bindings if all their
term-level parts are dead.

Had to do a bit of test rearrangement since a lot of the
`plutus-tx-plugin` tests for a type T just used a lambda with an unused
argument of type T... which gets simplified with this PR!

Fixes #4147.
Fixes #3702.

* Comments
  • Loading branch information
michaelpj authored Dec 17, 2021
1 parent 6cc82fd commit bd8def3
Show file tree
Hide file tree
Showing 20 changed files with 212 additions and 71 deletions.
51 changes: 42 additions & 9 deletions plutus-core/plutus-ir/src/PlutusIR/Analysis/Dependencies.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeOperators #-}
-- | Functions for computing the dependency graph of variables within a term or type. A "dependency" between
-- two nodes "A depends on B" means that B cannot be removed from the program without also removing A.
module PlutusIR.Analysis.Dependencies (Node (..), DepGraph, StrictnessMap, runTermDeps, runTypeDeps) where
Expand Down Expand Up @@ -115,6 +114,34 @@ From the point of view of our algorithm, we handle the dependency by treating it
reference to the newly bound variable alongside the binding, but only in the cases where it matters.
-}

{- Note [Dependencies for datatype bindings, and pruning them]
At face value, all the names introduced by datatype bindings should depend on each other.
Given our meaning of "A depends on B", since we cannot remove any part of the datatype binding without
removing the whole thing, they all depend on each other
However, there are some circumstances in which we *can* prune datatype bindings.
In particular, if the datatype is only used at the type-level (i.e. all the term-level parts
(constructors and destructor) are dead), then we are free to completely replace the binding
with one for a trivial type with the same kind.
This is because there are *no* term-level effects, and types are erased in the end, so
in this case rest of the datatype binding really is superfluous.
But how do we represent this in the dependency graph? We still need to have proper dependencies
so that we don't make the wrong decisions wrt transitively used values, e.g.
let U :: * = ...
let datatype T = T1 | T2 U
in T1
Here we need to not delete U, even though T2 is "dead"!
The solution is to focus on the meaning of "dependency": with the pruning that we can do, we *can*
remove all the term level bits en masse, but only en-mass. So we need to make *them* into a clique,
so that this is visible to the dependency analysis.
-}

bindingDeps
:: (DepGraph g, MonadReader (DepCtx term) m, MonadState DepState m, PLC.HasUnique tyname PLC.TypeUnique, PLC.HasUnique name PLC.TermUnique,
PLC.ToBuiltinMeaning uni fun)
Expand All @@ -137,16 +164,22 @@ bindingDeps b = case b of
tDeps <- withCurrent n $ typeDeps rhs
pure $ G.overlay vDeps tDeps
DatatypeBind _ (Datatype _ d tvs destr constrs) -> do
-- See Note [Dependencies for datatype bindings, and pruning them]
vDeps <- tyVarDeclDeps d
tvDeps <- traverse tyVarDeclDeps tvs
cstrDeps <- traverse varDeclDeps constrs
-- All the datatype bindings depend on each other since they can't be used separately. Consider
-- the identity function on a datatype type - it only uses the type variable, but the whole definition
-- will therefore be kept, and so we must consider any uses in e.g. the constructors as live.
let tyus = fmap (view PLC.theUnique) $ _tyVarDeclName d : fmap _tyVarDeclName tvs
let tus = fmap (view PLC.theUnique) $ destr : fmap _varDeclName constrs
let localDeps = G.clique (fmap Variable $ tyus ++ tus)
pure $ G.overlays $ [vDeps] ++ tvDeps ++ cstrDeps ++ [localDeps]
-- Destructors depend on the datatype and the argument types of all the constructors, because e.g. a destructor for Maybe looks like:
-- forall a . Maybe a -> (a -> r) -> r -> r
-- i.e. the argument type of the Just constructor appears as the argument to the branch.
--
-- We can get the effect of that by having it depend on all the constructor types (which also include the datatype).
-- This is more diligent than currently necessary since we're going to make all the term-level
-- parts depend on each other later, but it's good practice and will be useful if we ever stop doing that.
destrDeps <- G.overlays <$> (withCurrent destr $ traverse (typeDeps . _varDeclType) constrs)
let tus = fmap (view PLC.theUnique) (destr : fmap _varDeclName constrs)
-- See Note [Dependencies for datatype bindings, and pruning them]
let nonDatatypeClique = G.clique (fmap Variable tus)
pure $ G.overlays $ [vDeps] ++ tvDeps ++ cstrDeps ++ [destrDeps] ++ [nonDatatypeClique]

bindingStrictness
:: (MonadState DepState m, PLC.HasUnique name PLC.TermUnique)
Expand Down Expand Up @@ -194,7 +227,7 @@ termDeps = \case
modify (Map.insert (n ^. PLC.theUnique) Strict)
tds <- termDeps t
tyds <- typeDeps ty
pure $ G.overlays $ [tds, tyds]
pure $ G.overlays [tds, tyds]
x -> do
tds <- traverse termDeps (x ^.. termSubterms)
tyds <- traverse typeDeps (x ^.. termSubtypes)
Expand Down
67 changes: 48 additions & 19 deletions plutus-core/plutus-ir/src/PlutusIR/Transform/DeadCode.hs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
-- | Optimization passes for removing dead code, mainly dead let bindings.
module PlutusIR.Transform.DeadCode (removeDeadBindings) where

Expand All @@ -15,7 +15,6 @@ import PlutusCore.Constant qualified as PLC
import PlutusCore.Name qualified as PLC

import Control.Lens
import Control.Monad
import Control.Monad.Reader

import Data.Coerce
Expand All @@ -24,16 +23,19 @@ import Data.Set qualified as Set
import Algebra.Graph qualified as G
import Algebra.Graph.ToGraph qualified as T
import Data.List.NonEmpty qualified as NE
import PlutusCore.Quote (MonadQuote, freshTyName, liftQuote)
import PlutusCore.StdLib.Data.ScottUnit qualified as Unit
import Witherable (Witherable (wither))

-- | Remove all the dead let bindings in a term.
removeDeadBindings
:: (PLC.HasUnique name PLC.TermUnique, PLC.HasUnique tyname PLC.TypeUnique,
:: (PLC.HasUnique name PLC.TermUnique,
PLC.ToBuiltinMeaning uni fun, PLC.MonadQuote m)
=> Term tyname name uni fun a
-> m (Term tyname name uni fun a)
=> Term TyName name uni fun a
-> m (Term TyName name uni fun a)
removeDeadBindings t = do
tRen <- PLC.rename t
runReaderT (transformMOf termSubterms processTerm tRen) (calculateLiveness tRen)
liftQuote $ runReaderT (transformMOf termSubterms processTerm tRen) (calculateLiveness tRen)

type Liveness = Set.Set Deps.Node

Expand All @@ -55,24 +57,51 @@ live n =
in asks $ Set.member (Deps.Variable u)

liveBinding
:: (MonadReader Liveness m, PLC.HasUnique name PLC.TermUnique, PLC.HasUnique tyname PLC.TypeUnique)
=> Binding tyname name uni fun a
-> m Bool
:: (MonadReader Liveness m, PLC.HasUnique name PLC.TermUnique, MonadQuote m)
=> Binding TyName name uni fun a
-> m (Maybe (Binding TyName name uni fun a))
liveBinding =
let
-- TODO: HasUnique instances for VarDecl and TyVarDecl?
liveVarDecl (VarDecl _ n _) = live n
liveTyVarDecl (TyVarDecl _ n _) = live n
in \case
TermBind _ _ d _ -> liveVarDecl d
TypeBind _ d _ -> liveTyVarDecl d
DatatypeBind _ (Datatype _ d _ destr constrs) -> or <$> (sequence $ [liveTyVarDecl d, live destr] ++ fmap liveVarDecl constrs)
b@(TermBind _ _ d _) -> do
l <- liveVarDecl d
pure $ if l then Just b else Nothing
b@(TypeBind _ d _) -> do
l <- liveTyVarDecl d
pure $ if l then Just b else Nothing
b@(DatatypeBind x (Datatype _ d _ destr constrs)) -> do
dtypeLive <- liveTyVarDecl d
destrLive <- live destr
constrsLive <- traverse liveVarDecl constrs
let termLive = or (destrLive : constrsLive)
case (dtypeLive, termLive) of
-- At least one term-level part is live, keep the whole thing
(_, True) -> pure $ Just b
-- Nothing is live, remove the whole thing
(False, False) -> pure Nothing
-- See Note [Dependencies for datatype bindings, and pruning them]
-- Datatype is live but no term-level parts are, replace with a trivial type binding
(True, False) -> Just . TypeBind x d <$> mkTypeOfKind (_tyVarDeclKind d)

-- | Given a kind, make a type (any type!) of that kind.
-- Generates things of the form 'unit -> unit -> ... -> unit'
mkTypeOfKind :: MonadQuote m => Kind a -> m (Type TyName uni a)
mkTypeOfKind = \case
-- The scott-encoded unit here is a little bulky but it continues to be the easiest
-- way to get a type of kind Type without relying on builtins.
Type a -> pure $ a <$ Unit.unit
KindArrow a ki ki' -> do
n <- freshTyName "a"
TyLam a n ki <$> mkTypeOfKind ki'

processTerm
:: (MonadReader Liveness m, PLC.HasUnique name PLC.TermUnique, PLC.HasUnique tyname PLC.TypeUnique)
=> Term tyname name uni fun a
-> m (Term tyname name uni fun a)
:: (MonadReader Liveness m, PLC.HasUnique name PLC.TermUnique, MonadQuote m)
=> Term TyName name uni fun a
-> m (Term TyName name uni fun a)
processTerm = \case
-- throw away dead bindings
Let x r bs t -> mkLet x r <$> filterM liveBinding (NE.toList bs) <*> pure t
Let x r bs t -> mkLet x r <$> wither liveBinding (NE.toList bs) <*> pure t
x -> pure x
1 change: 1 addition & 0 deletions plutus-core/plutus-ir/test/TransformSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ deadCode =
, "nestedBindingsIndirect"
, "recBindingSimple"
, "recBindingComplex"
, "pruneDatatype"
]

retainedSize :: TestNested
Expand Down
6 changes: 5 additions & 1 deletion plutus-core/plutus-ir/test/errors/mutuallyRecursiveTypes
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
(let
(nonrec)
(typebind (tyvardecl unit (type)) (all a (type) (fun a a)))
(let
(rec)
(datatypebind
Expand All @@ -17,5 +20,6 @@
(vardecl Cons (fun [Tree a] (fun [Forest a] [Forest a])))
)
)
{ Nil (all a (type) (fun a a)) }
[ { Node unit } (error unit) { Nil unit } ]
)
)
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
E003: Unsupported construct: Mutually recursive datatypes ((recursive) let binding; from mutuallyRecursiveTypes:1:2)
E003: Unsupported construct: Mutually recursive datatypes ((recursive) let binding; from [ mutuallyRecursiveTypes:1:2
, mutuallyRecursiveTypes:4:2 ])
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
(let
(nonrec)
(datatypebind
(datatype
(tyvardecl Maybe (fun (type) (type)))
(tyvardecl a (type))
match_Maybe
(vardecl Nothing [ Maybe a ]) (vardecl Just (fun a [ Maybe a ]))
)
(typebind
(tyvardecl Maybe (fun (type) (type)))
(lam a (type) (all a (type) (fun a a)))
)
(error Maybe)
)
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@
(vardecl Constr (fun unit SomeType))
)
)
(lam arg SomeType arg)
[Constr (error unit)]
)
)
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@
(vardecl Constr (fun unit SomeType))
)
)
(lam arg SomeType arg)
[ Constr (error unit) ]
)
)
16 changes: 16 additions & 0 deletions plutus-core/plutus-ir/test/transform/deadCode/pruneDatatype
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
(let
(nonrec)
(typebind (tyvardecl unit (type)) (all a (type) (fun a a)))
(let
(nonrec)
(datatypebind
(datatype
(tyvardecl SomeType (type))

match_SomeType
(vardecl Constr (fun unit SomeType))
)
)
(lam arg SomeType (error unit))
)
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
(let
(nonrec)
(typebind (tyvardecl unit (type)) (all a (type) (fun a a)))
(let
(nonrec)
(typebind (tyvardecl SomeType (type)) (all a (type) (fun a a)))
(lam arg SomeType (error unit))
)
)
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
(nonrec)
(datatypebind
(datatype
$1$
$0$
(tyvardecl $4$ Maybe (fun (type) (type)))
(tyvardecl $2$ a (type))
match_Maybe
(vardecl $17$ Nothing [ Maybe a ]) (vardecl $6$ Just (fun a [ Maybe a ]))
(vardecl $16$ Nothing [ Maybe a ]) (vardecl $6$ Just (fun a [ Maybe a ]))
)
)
Nothing
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
(nonrec)
(datatypebind
(datatype
$1$
(tyvardecl $17$ Maybe (fun (type) (type)))
(tyvardecl $2$ a (type))
$0$
(tyvardecl $4$ Maybe (fun (type) (type)))
(tyvardecl $0$ a (type))
match_Maybe
(vardecl $4$ Nothing [ Maybe a ]) (vardecl $6$ Just (fun a [ Maybe a ]))
(vardecl $0$ Nothing [ Maybe a ]) (vardecl $0$ Just (fun a [ Maybe a ]))
)
)
(error Maybe)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
(let
(nonrec)
(typebind (tyvardecl $7$ unit (type)) (all a (type) (fun a a)))
(typebind (tyvardecl $0$ unit (type)) (all a (type) (fun a a)))
(let
(nonrec)
(datatypebind
(datatype
$1$
(tyvardecl $14$ SomeType (type))
$0$
(tyvardecl $2$ SomeType (type))

match_SomeType
(vardecl $11$ Constr (fun unit SomeType))
(vardecl $0$ Constr (fun unit SomeType))
)
)
(lam arg SomeType arg)
Expand Down
25 changes: 15 additions & 10 deletions plutus-tx-plugin/test/Plugin/Data/Spec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,9 @@ instance P.Eq MyMonoData where
(Mono3 i1) == (Mono3 i2) = i1 P.== i2
_ == _ = False

monoDataType :: CompiledCode (MyMonoData -> MyMonoData)
monoDataType = plc (Proxy @"monoDataType") (\(x :: MyMonoData) -> x)
-- pattern match to avoid type getting simplified away
monoDataType :: CompiledCode (MyMonoData -> Integer)
monoDataType = plc (Proxy @"monoDataType") (\(x :: MyMonoData) -> case x of { Mono2 i -> i; _ -> 1; })

monoConstructor :: CompiledCode (Integer -> Integer -> MyMonoData)
monoConstructor = plc (Proxy @"monConstructor") Mono1
Expand Down Expand Up @@ -101,13 +102,15 @@ instance P.Eq MyMonoRecord where
{-# INLINABLE (==) #-}
(MyMonoRecord i1 j1) == (MyMonoRecord i2 j2) = i1 P.== i2 && j1 P.== j2

monoRecord :: CompiledCode (MyMonoRecord -> MyMonoRecord)
monoRecord = plc (Proxy @"monoRecord") (\(x :: MyMonoRecord) -> x)
-- pattern match to avoid type getting simplified away
monoRecord :: CompiledCode (MyMonoRecord -> Integer)
monoRecord = plc (Proxy @"monoRecord") (\(x :: MyMonoRecord) -> case x of { MyMonoRecord i _ -> i; })

data RecordNewtype = RecordNewtype { newtypeField :: MyNewtype }

recordNewtype :: CompiledCode (RecordNewtype -> RecordNewtype)
recordNewtype = plc (Proxy @"recordNewtype") (\(x :: RecordNewtype) -> x)
-- pattern match to avoid type getting simplified away
recordNewtype :: CompiledCode (RecordNewtype -> Integer)
recordNewtype = plc (Proxy @"recordNewtype") (\(x :: RecordNewtype) -> case x of { RecordNewtype (MyNewtype i) -> i; })

-- must be compiled with a lazy case
nonValueCase :: CompiledCode (MyEnum -> Integer)
Expand Down Expand Up @@ -139,8 +142,9 @@ instance (P.Eq a, P.Eq b) => P.Eq (MyPolyData a b) where
(Poly2 a1) == (Poly2 a2) = a1 P.== a2
_ == _ = False

polyDataType :: CompiledCode (MyPolyData Integer Integer -> MyPolyData Integer Integer)
polyDataType = plc (Proxy @"polyDataType") (\(x:: MyPolyData Integer Integer) -> x)
-- pattern match to avoid type getting simplified away
polyDataType :: CompiledCode (MyPolyData Integer Integer -> Integer)
polyDataType = plc (Proxy @"polyDataType") (\(x:: MyPolyData Integer Integer) -> case x of { Poly2 i -> i; _ -> 1; })

polyConstructed :: CompiledCode (MyPolyData Integer Integer)
polyConstructed = plc (Proxy @"polyConstructed") (Poly1 (1::Integer) (2::Integer))
Expand Down Expand Up @@ -185,8 +189,9 @@ nestedNewtypeMatch = plc (Proxy @"nestedNewtypeMatch") (\(MyNewtype2 (MyNewtype

newtype ParamNewtype a = ParamNewtype (Maybe a)

paramNewtype :: CompiledCode (ParamNewtype Integer -> ParamNewtype Integer)
paramNewtype = plc (Proxy @"paramNewtype") (\(x ::ParamNewtype Integer) -> x)
-- pattern match to avoid type getting simplified away
paramNewtype :: CompiledCode (ParamNewtype Integer -> Integer)
paramNewtype = plc (Proxy @"paramNewtype") (\(x ::ParamNewtype Integer) -> case x of { ParamNewtype (Just i) -> i; _ -> 1 })

recursiveTypes :: TestNested
recursiveTypes = testNested "recursive" [
Expand Down
Loading

0 comments on commit bd8def3

Please sign in to comment.