Skip to content

Commit

Permalink
Allow hole filling to deal with recursion (#472)
Browse files Browse the repository at this point in the history
This PR enhances the "attempt to fill hole" code action, allowing it to implement self-recursive functions. The generated code ensures recursion occurs only on structurally-smaller values, and preserves the positional ordering of homomorphically destructed arguments.

It's clever enough to implement foldr and nontrivial functor instances.

Co-authored-by: TOTBWF <reed.mullanix@calabrio.com>
  • Loading branch information
isovector and TOTBWF authored Oct 19, 2020
1 parent cdf50a6 commit 2533574
Show file tree
Hide file tree
Showing 35 changed files with 785 additions and 244 deletions.
2 changes: 1 addition & 1 deletion cabal.project
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,6 @@ package ghcide

write-ghc-environment-files: never

index-state: 2020-10-08T12:51:21Z
index-state: 2020-10-16T04:00:00Z

allow-newer: data-tree-print:base
5 changes: 4 additions & 1 deletion haskell-language-server.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,13 @@ executable haskell-language-server
Ide.Plugin.Retrie
Ide.Plugin.StylishHaskell
Ide.Plugin.Tactic
Ide.Plugin.Tactic.Auto
Ide.Plugin.Tactic.CodeGen
Ide.Plugin.Tactic.Context
Ide.Plugin.Tactic.Debug
Ide.Plugin.Tactic.GHC
Ide.Plugin.Tactic.Judgements
Ide.Plugin.Tactic.KnownStrategies
Ide.Plugin.Tactic.Machinery
Ide.Plugin.Tactic.Naming
Ide.Plugin.Tactic.Range
Expand Down Expand Up @@ -156,9 +158,10 @@ executable haskell-language-server
, transformers
, unordered-containers
, ghc-source-gen
, refinery ^>=0.2
, refinery ^>=0.3
, ghc-exactprint
, fingertree
, generic-lens

if flag(agpl)
build-depends: brittany
Expand Down
88 changes: 73 additions & 15 deletions plugins/tactics/src/Ide/Plugin/Tactic.hs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NumDecimals #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
Expand All @@ -20,8 +22,12 @@ import Control.Monad.Trans
import Control.Monad.Trans.Maybe
import Data.Aeson
import Data.Coerce
import Data.Generics.Aliases (mkQ)
import Data.Generics.Schemes (everything)
import Data.List
import qualified Data.Map as M
import Data.Maybe
import Data.Monoid
import qualified Data.Set as S
import qualified Data.Text as T
import Data.Traversable
Expand All @@ -38,6 +44,7 @@ import qualified FastString
import GHC.Generics (Generic)
import GHC.LanguageExtensions.Type (Extension (LambdaCase))
import Ide.Plugin (mkLspCommand)
import Ide.Plugin.Tactic.Auto
import Ide.Plugin.Tactic.Context
import Ide.Plugin.Tactic.GHC
import Ide.Plugin.Tactic.Judgements
Expand All @@ -50,6 +57,8 @@ import Ide.Types
import Language.Haskell.LSP.Core (clientCapabilities)
import Language.Haskell.LSP.Types
import OccName
import SrcLoc (containsSpan)
import System.Timeout


descriptor :: PluginId -> PluginDescriptor
Expand Down Expand Up @@ -250,12 +259,24 @@ judgementForHole state nfp range = do
resulting_range <- liftMaybe $ toCurrentRange amapping $ realSrcSpanToRange rss
(tcmod, _) <- MaybeT $ runIde state $ useWithStale TypeCheck nfp
let tcg = fst $ tm_internals_ $ tmrModule tcmod
tcs = tm_typechecked_source $ tmrModule tcmod
ctx = mkContext
(mapMaybe (sequenceA . (occName *** coerce))
$ getDefiningBindings binds rss)
tcg
hyps = hypothesisFromBindings rss binds
pure (resulting_range, mkFirstJudgement hyps goal, ctx, dflags)
pure ( resulting_range
, mkFirstJudgement
hyps
(isRhsHole rss tcs)
(maybe
mempty
(uncurry M.singleton . fmap pure)
$ getRhsPosVals rss tcs)
goal
, ctx
, dflags
)



Expand All @@ -266,20 +287,26 @@ tacticCmd tac lf state (TacticParams uri range var_name)
(range', jdg, ctx, dflags) <- judgementForHole state nfp range
let span = rangeToRealSrcSpan (fromNormalizedFilePath nfp) range'
pm <- MaybeT $ useAnnotatedSource "tacticsCmd" state nfp
case runTactic ctx jdg
$ tac
$ mkVarOcc
$ T.unpack var_name of
Left err ->
pure $ (, Nothing)
$ Left
$ ResponseError InvalidRequest (T.pack $ show err) Nothing
Right res -> do
let g = graft (RealSrcSpan span) res
response = transform dflags (clientCapabilities lf) uri g pm
pure $ case response of
Right res -> (Right Null , Just (WorkspaceApplyEdit, ApplyWorkspaceEditParams res))
Left err -> (Left $ ResponseError InternalError (T.pack err) Nothing, Nothing)
x <- lift $ timeout 2e8 $
case runTactic ctx jdg
$ tac
$ mkVarOcc
$ T.unpack var_name of
Left err ->
pure $ (, Nothing)
$ Left
$ ResponseError InvalidRequest (T.pack $ show err) Nothing
Right (_, ext) -> do
let g = graft (RealSrcSpan span) ext
response = transform dflags (clientCapabilities lf) uri g pm
pure $ case response of
Right res -> (Right Null , Just (WorkspaceApplyEdit, ApplyWorkspaceEditParams res))
Left err -> (Left $ ResponseError InternalError (T.pack err) Nothing, Nothing)
pure $ case x of
Just y -> y
Nothing -> (, Nothing)
$ Left
$ ResponseError InvalidRequest "timed out" Nothing
tacticCmd _ _ _ _ =
pure ( Left $ ResponseError InvalidRequest (T.pack "Bad URI") Nothing
, Nothing
Expand All @@ -292,3 +319,34 @@ fromMaybeT def = fmap (fromMaybe def) . runMaybeT
liftMaybe :: Monad m => Maybe a -> MaybeT m a
liftMaybe a = MaybeT $ pure a


------------------------------------------------------------------------------
-- | Is this hole immediately to the right of an equals sign?
isRhsHole :: RealSrcSpan -> TypecheckedSource -> Bool
isRhsHole rss tcs = everything (||) (mkQ False $ \case
TopLevelRHS _ _ (L (RealSrcSpan span) _) -> containsSpan rss span
_ -> False
) tcs


------------------------------------------------------------------------------
-- | Compute top-level position vals of a function
getRhsPosVals :: RealSrcSpan -> TypecheckedSource -> Maybe (OccName, [OccName])
getRhsPosVals rss tcs = getFirst $ everything (<>) (mkQ mempty $ \case
TopLevelRHS name ps
(L (RealSrcSpan span) -- body with no guards and a single defn
(HsVar _ (L _ hole)))
| containsSpan rss span -- which contains our span
, isHole $ occName hole -- and the span is a hole
-> First $ do
patnames <- traverse getPatName ps
pure (occName name, patnames)
_ -> mempty
) tcs



-- TODO(sandy): Make this more robust
isHole :: OccName -> Bool
isHole = isPrefixOf "_" . occNameString

25 changes: 25 additions & 0 deletions plugins/tactics/src/Ide/Plugin/Tactic/Auto.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
module Ide.Plugin.Tactic.Auto where

import Ide.Plugin.Tactic.Context
import Ide.Plugin.Tactic.Judgements
import Ide.Plugin.Tactic.KnownStrategies
import Ide.Plugin.Tactic.Tactics
import Ide.Plugin.Tactic.Types
import Refinery.Tactic
import Ide.Plugin.Tactic.Machinery (tracing)


------------------------------------------------------------------------------
-- | Automatically solve a goal.
auto :: TacticsM ()
auto = do
jdg <- goal
current <- getCurrentDefinitions
traceMX "goal" jdg
traceMX "ctx" current
commit knownStrategies
. tracing "auto"
. localTactic (auto' 4)
. disallowing
$ fmap fst current

110 changes: 79 additions & 31 deletions plugins/tactics/src/Ide/Plugin/Tactic/CodeGen.hs
Original file line number Diff line number Diff line change
@@ -1,33 +1,45 @@
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE FlexibleContexts #-}
module Ide.Plugin.Tactic.CodeGen where

import Control.Monad.Except
import Data.List
import Data.Traversable
import DataCon
import Development.IDE.GHC.Compat
import GHC.Exts
import GHC.SourceGen.Binds
import GHC.SourceGen.Expr
import GHC.SourceGen.Overloaded
import GHC.SourceGen.Pat
import Ide.Plugin.Tactic.Judgements
import Ide.Plugin.Tactic.Machinery
import Ide.Plugin.Tactic.Naming
import Ide.Plugin.Tactic.Types
import Name
import Type hiding (Var)
import Control.Monad.Except
import Control.Monad.State (MonadState)
import Control.Monad.State.Class (modify)
import Data.List
import qualified Data.Map as M
import qualified Data.Set as S
import Data.Traversable
import DataCon
import Development.IDE.GHC.Compat
import GHC.Exts
import GHC.SourceGen.Binds
import GHC.SourceGen.Expr
import GHC.SourceGen.Overloaded
import GHC.SourceGen.Pat
import Ide.Plugin.Tactic.Judgements
import Ide.Plugin.Tactic.Machinery
import Ide.Plugin.Tactic.Naming
import Ide.Plugin.Tactic.Types
import Name
import Type hiding (Var)


useOccName :: MonadState TacticState m => Judgement -> OccName -> m ()
useOccName jdg name =
case M.lookup name $ jHypothesis jdg of
Just{} -> modify $ withUsedVals $ S.insert name
Nothing -> pure ()


destructMatches
:: (DataCon -> Judgement -> Rule)
-- ^ How to construct each match
-> (Judgement -> Judgement)
-> ([(OccName, CType)] -> Judgement -> Judgement)
-- ^ How to derive each match judgement
-> CType
-- ^ Type being destructed
-> Judgement
-> RuleM [RawMatch]
-> RuleM (Trace, [RawMatch])
destructMatches f f2 t jdg = do
let hy = jHypothesis jdg
g = jGoal jdg
Expand All @@ -37,18 +49,32 @@ destructMatches f f2 t jdg = do
let dcs = tyConDataCons tc
case dcs of
[] -> throwError $ GoalMismatch "destruct" g
_ -> for dcs $ \dc -> do
_ -> fmap unzipTrace $ for dcs $ \dc -> do
let args = dataConInstOrigArgTys' dc apps
names <- mkManyGoodNames hy args
let hy' = zip names $ coerce args
dcon_name = nameOccName $ dataConName dc

let pat :: Pat GhcPs
pat = conP (fromString $ occNameString $ nameOccName $ dataConName dc)
pat = conP (fromString $ occNameString dcon_name)
$ fmap bvar' names
j = f2
$ introducingPat (zip names $ coerce args)
j = f2 hy'
$ withPositionMapping dcon_name names
$ introducingPat hy'
$ withNewGoal g jdg
sg <- f dc j
pure $ match [pat] $ unLoc sg
(tr, sg) <- f dc j
modify $ withIntroducedVals $ mappend $ S.fromList names
pure ( rose ("match " <> show dc <> " {" <>
intercalate ", " (fmap show names) <> "}")
$ pure tr
, match [pat] $ unLoc sg
)


unzipTrace :: [(Trace, a)] -> (Trace, [a])
unzipTrace l =
let (trs, as) = unzip l
in (rose mempty trs, as)


-- | Essentially same as 'dataConInstOrigArgTys' in GHC,
Expand All @@ -66,24 +92,34 @@ dataConInstOrigArgTys' con ty =

destruct' :: (DataCon -> Judgement -> Rule) -> OccName -> Judgement -> Rule
destruct' f term jdg = do
when (isDestructBlacklisted jdg) $ throwError NoApplicableTactic
let hy = jHypothesis jdg
case find ((== term) . fst) $ toList hy of
Nothing -> throwError $ UndefinedHypothesis term
Just (_, t) ->
fmap noLoc $ case' (var' term) <$>
destructMatches f (destructing term) t jdg
Just (_, t) -> do
useOccName jdg term
(tr, ms)
<- destructMatches
f
(\cs -> setParents term (fmap fst cs) . destructing term)
t
jdg
pure ( rose ("destruct " <> show term) $ pure tr
, noLoc $ case' (var' term) ms
)


------------------------------------------------------------------------------
-- | Combinator for performign case splitting, and running sub-rules on the
-- resulting matches.
destructLambdaCase' :: (DataCon -> Judgement -> Rule) -> Judgement -> Rule
destructLambdaCase' f jdg = do
when (isDestructBlacklisted jdg) $ throwError NoApplicableTactic
let g = jGoal jdg
case splitFunTy_maybe (unCType g) of
Just (arg, _) | isAlgType arg ->
fmap noLoc $ lambdaCase <$>
destructMatches f id (CType arg) jdg
fmap (fmap noLoc $ lambdaCase) <$>
destructMatches f (const id) (CType arg) jdg
_ -> throwError $ GoalMismatch "destructLambdaCase'" g


Expand All @@ -93,11 +129,21 @@ buildDataCon
:: Judgement
-> DataCon -- ^ The data con to build
-> [Type] -- ^ Type arguments for the data con
-> RuleM (LHsExpr GhcPs)
-> RuleM (Trace, LHsExpr GhcPs)
buildDataCon jdg dc apps = do
let args = dataConInstOrigArgTys' dc apps
sgs <- traverse (newSubgoal . flip withNewGoal jdg . CType) args
dcon_name = nameOccName $ dataConName dc
(tr, sgs)
<- fmap unzipTrace
$ traverse ( \(arg, n) ->
newSubgoal
. filterSameTypeFromOtherPositions dcon_name n
. blacklistingDestruct
. flip withNewGoal jdg
$ CType arg
) $ zip args [0..]
pure
. (rose (show dc) $ pure tr,)
. noLoc
. foldl' (@@)
(HsVar noExtField $ noLoc $ Unqual $ nameOccName $ dataConName dc)
Expand All @@ -109,7 +155,9 @@ buildDataCon jdg dc apps = do
var' :: Var a => OccName -> a
var' = var . fromString . occNameString


------------------------------------------------------------------------------
-- | Like 'bvar', but works over standard GHC 'OccName's.
bvar' :: BVar a => OccName -> a
bvar' = bvar . fromString . occNameString

4 changes: 2 additions & 2 deletions plugins/tactics/src/Ide/Plugin/Tactic/Context.hs
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ getFunBindId (AbsBinds _ _ _ abes _ _ _)
getFunBindId _ = []


getCurrentDefinitions :: MonadReader Context m => m [OccName]
getCurrentDefinitions = asks $ fmap fst . ctxDefiningFuncs
getCurrentDefinitions :: MonadReader Context m => m [(OccName, CType)]
getCurrentDefinitions = asks $ ctxDefiningFuncs

getModuleHypothesis :: MonadReader Context m => m [(OccName, CType)]
getModuleHypothesis = asks ctxModuleFuncs
Expand Down
Loading

0 comments on commit 2533574

Please sign in to comment.