Skip to content

Commit

Permalink
fix: everything
Browse files Browse the repository at this point in the history
  • Loading branch information
arthur-adjedj committed Aug 14, 2024
1 parent 29b96a2 commit 4004db4
Show file tree
Hide file tree
Showing 7 changed files with 86 additions and 64 deletions.
15 changes: 10 additions & 5 deletions src/Lean/Elab/Deriving/BEq.lean
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,18 @@ where
def mkAuxFunction (ctx : Context) (i : Nat) : TermElabM Command := do
let auxFunName := ctx.auxFunNames[i]!
let nestedOcc := ctx.typeInfos[i]!
let argNames := ctx.typeArgNames[i]!
let argNames := ctx.typeArgNames[i]!
let header ← mkBEqHeader argNames nestedOcc
let binders := header.binders
Term.elabBinders binders fun xs => do
let type ← Term.elabTerm header.targetType none
let body ← mkMatch ctx header type xs
`(private def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Bool := $body:term)
let mut body ← mkMatch ctx header type xs
if ctx.usePartial then
let letDecls ← mkLocalInstanceLetDecls ctx `BEq header.argNames
body ← mkLet letDecls body
`(private partial def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Bool := $body:term)
else
`(private def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Bool := $body:term)

def mkMutualBlock (ctx : Context) : TermElabM Syntax := do
let mut auxDefs := #[]
Expand All @@ -120,7 +125,7 @@ def mkMutualBlock (ctx : Context) : TermElabM Syntax := do
end)

private def mkBEqInstanceCmds (declName : Name) : TermElabM (Array Syntax) := do
let ctx ← mkContext "beq" declName
let ctx ← mkContext "beq" declName false
let cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `BEq)
trace[Elab.Deriving.beq] "\n{cmds}"
return cmds
Expand All @@ -130,7 +135,7 @@ private def mkBEqEnumFun (ctx : Context) (name : Name) : TermElabM Syntax := do
`(private def $(mkIdent auxFunName):ident (x y : $(mkIdent name)) : Bool := x.toCtorIdx == y.toCtorIdx)

private def mkBEqEnumCmd (name : Name): TermElabM (Array Syntax) := do
let ctx ← mkContext "beq" name
let ctx ← mkContext "beq" name false
let cmds := #[← mkBEqEnumFun ctx name] ++ (← mkInstanceCmds ctx `BEq)
trace[Elab.Deriving.beq] "\n{cmds}"
return cmds
Expand Down
9 changes: 4 additions & 5 deletions src/Lean/Elab/Deriving/DecEq.lean
Original file line number Diff line number Diff line change
Expand Up @@ -98,15 +98,14 @@ def mkAuxFunction (ctx : Context) (auxFunName : Name) (argNames : Array Name) (n
let binders := header.binders
let target₁ := mkIdent header.targetNames[0]!
let target₂ := mkIdent header.targetNames[1]!
let termSuffix ← if ctx.auxFunNames.size > 1 || nestedOcc.getIndVal.isRec
then `(Parser.Termination.suffix|termination_by structural $target₁)
else `(Parser.Termination.suffix|)
-- let termSuffix ← if ctx.auxFunNames.size > 1 || nestedOcc.getIndVal.isRec
-- then `(Parser.Termination.suffix|termination_by structural $target₁)
-- else `(Parser.Termination.suffix|)
Term.elabBinders binders fun xs => do
let type ← Term.elabTerm header.targetType none
let body ← mkMatch ctx header type xs
let type ← `(Decidable ($target₁ = $target₂))
`(private def $(mkIdent auxFunName):ident $binders:bracketedBinder* : $type:term := $body:term
$termSuffix:suffix)
`(private def $(mkIdent auxFunName):ident $binders:bracketedBinder* : $type:term := $body:term)

def mkAuxFunctions (ctx : Context) : TermElabM (TSyntax `command) := do
let mut res : Array (TSyntax `command) := #[]
Expand Down
9 changes: 5 additions & 4 deletions src/Lean/Elab/Deriving/FromToJson.lean
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,6 @@ def mkToJsonAuxFunction (ctx : Context) (i : Nat) : TermElabM Command := do
if ctx.usePartial then
let letDecls ← mkLocalInstanceLetDecls ctx ``ToJson header.argNames
body ← mkLet letDecls body
if ctx.usePartial then
`(private partial def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Json := $body:term)
else
`(private def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Json := $body:term)
Expand All @@ -206,11 +205,13 @@ def mkFromJsonAuxFunction (ctx : Context) (i : Nat) : TermElabM Command := do
let binders := header.binders
Term.elabBinders binders fun xs => do
let type ← Term.elabTerm header.targetType none
let mut body mkFromJsonBody ctx header type xs
if ctx.usePartial then
let mut body mkFromJsonBody ctx header type xs
if ctx.usePartial || nestedOcc.getIndVal.isRec then
let letDecls ← mkLocalInstanceLetDecls ctx ``FromJson header.argNames
body ← mkLet letDecls body
`(private partial def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Except String $(← ctx.typeInfos[i]!.mkAppTerm header.argNames) := $body:term)
`(private partial def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Except String $(← ctx.typeInfos[i]!.mkAppTerm header.argNames) := $body:term)
else
`(private def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Except String $(← ctx.typeInfos[i]!.mkAppTerm header.argNames) := $body:term)


def mkToJsonMutualBlock (ctx : Context) : TermElabM Command := do
Expand Down
3 changes: 1 addition & 2 deletions src/Lean/Elab/Deriving/Hashable.lean
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,9 @@ def mkAuxFunction (ctx : Context) (i : Nat) : TermElabM Command := do
let type ← Term.elabTerm header.targetType none
let mut body ← mkMatch ctx header type xs
if ctx.usePartial then
-- TODO(Dany): Get rid of this code branch altogether once we have well-founded recursion
let letDecls ← mkLocalInstanceLetDecls ctx `Hashable header.argNames
body ← mkLet letDecls body
if ctx.usePartial then
-- TODO(Dany): Get rid of this code branch altogether once we have well-founded recursion
`(private partial def $(mkIdent auxFunName):ident $binders:bracketedBinder* : UInt64 := $body:term)
else
`(private def $(mkIdent auxFunName):ident $binders:bracketedBinder* : UInt64 := $body:term)
Expand Down
4 changes: 3 additions & 1 deletion src/Lean/Elab/Deriving/Repr.lean
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,9 @@ def mkAuxFunction (ctx : Context) (i : Nat) : TermElabM Command := do
if ctx.usePartial then
let letDecls ← mkLocalInstanceLetDecls ctx `Repr header.argNames
body ← mkLet letDecls body
`(private partial def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Format := $body:term)
`(private partial def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Format := $body:term)
else
`(private def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Format := $body:term)

def mkMutualBlock (ctx : Context) : TermElabM Syntax := do
let mut auxDefs := #[]
Expand Down
90 changes: 53 additions & 37 deletions src/Lean/Elab/Deriving/Util.lean
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def mkImplicitBinders (argNames : Array Name) : TermElabM (Array (TSyntax ``Pars
argNames.mapM fun argName =>
`(implicitBinderF| { $(mkIdent argName) })

inductive NestedOccurence : Type :=
/--
Represents `ind params1 ... paramsn`, where `paramsi` is either a nested occurence, a constant name, or a free variable
Expand All @@ -56,7 +55,7 @@ inductive NestedOccurence : Type :=
inductive Bar
| foo : A -> B -> Foo A (Option Bar) B Nat → Bar
```,
```
this nested occurence `Foo A (Option Bar) B Nat` is encoded as
```lean
node Foo #[
Expand All @@ -72,15 +71,16 @@ inductive NestedOccurence : Type :=
Remark 2 : Free variables are abstracted away in nested occurences.
This is useful when trying to delete duplicate occurences, since the check is now purely syntactical.
-/
inductive NestedOccurence : Type :=
| node (ind : InductiveVal) (params : Array (NestedOccurence ⊕ Expr))
| leaf (ind : InductiveVal)
| leaf (ind : InductiveVal) (fvars : Array Expr)

namespace NestedOccurence

instance : Inhabited NestedOccurence := ⟨leaf default⟩
instance : Inhabited NestedOccurence := ⟨leaf default #[]
partial instance : BEq NestedOccurence := ⟨go⟩
where go
| leaf ind₁,.leaf ind₂ => ind₁.name == ind₂.name
| leaf ind₁ _,.leaf ind₂ _ => ind₁.name == ind₂.name
| node ind₁ arr₁,.node ind₂ arr₂ => Id.run do
unless ind₁.name == ind₂.name && arr₁.size == arr₂.size do
return false
Expand All @@ -92,34 +92,35 @@ where go

partial instance : ToString NestedOccurence := ⟨go⟩
where go
| leaf ind => s!"leaf {ind.name}"
| node ind arr =>
let s := arr.map (@instToStringSum _ _ ⟨go⟩ ⟨reprStr⟩ |>.toString)
s!"node {ind.name} {s}"
| leaf ind e=> s!"leaf {ind.name} {e}"
| node ind arr =>
let s := arr.map (@instToStringSum _ _ ⟨go⟩ inferInstance |>.toString)
s!"node {ind.name} {s}"

@[inline]
def getIndVal : NestedOccurence → InductiveVal
| leaf indVal | node indVal _ => indVal
| leaf indVal _| node indVal _ => indVal

@[inline]
def getArr : NestedOccurence → Array (NestedOccurence ⊕ Expr)
| leaf _ => #[]
| leaf .. => #[]
| node _ arr => arr

@[inline]
def isLeaf : NestedOccurence → Bool
| leaf _ => true
| leaf ..=> true
| node .. => false

@[inline]
def isNode : NestedOccurence → Bool := not ∘ isLeaf

partial def containsFVar (fvarId : FVarId) : NestedOccurence → Bool
| leaf _ => false
| leaf _ e => e.any (Expr.containsFVar · fvarId)
| node _ arr => arr.any (Sum.lift (containsFVar fvarId) (Expr.containsFVar · fvarId))

partial def toListofNests (e : NestedOccurence) : List NestedOccurence :=
match e with
| .leaf _ => []
| .leaf _ _ => []
| .node _ arr =>
let l := flip arr.foldr [] fun occ l =>
if let .inl occ := occ then
Expand All @@ -133,11 +134,11 @@ partial def mkAppTerm (nestedOcc : NestedOccurence) (argNames : Array Name) : Te
where
go (nestedOcc : NestedOccurence) (argNames : Array Name) : TermElabM Term := do
match nestedOcc with
| leaf indVal => do
| leaf indVal _ => do
let f := mkCIdent indVal.name
let numArgs := indVal.numParams + indVal.numIndices
-- unless argNames.size >= numArgs do
-- throwError s!"Expected {numArgs} arguments for {indVal.name}, got {argNames}"
unless argNames.size >= numArgs do
throwError s!"Expected {numArgs} arguments for {indVal.name}, got {argNames}"
let mut args := Array.mkArray numArgs default
for i in [:numArgs] do
let arg := mkIdent argNames[i]!
Expand All @@ -162,16 +163,16 @@ where
`(@$f $args*)

/-- Return the inductive declaration's type applied to the arguments in `argNames`. -/
partial def mkAppExpr (nestedOcc : NestedOccurence) (argNames : Subarray Expr) : TermElabM Expr := do
partial def mkAppExpr (nestedOcc : NestedOccurence) (argNames : Array Expr) : TermElabM Expr := do
let res ← go nestedOcc argNames
return res
where
go (nestedOcc : NestedOccurence) (argNames : Array Expr): TermElabM Expr := do
match nestedOcc with
| leaf indVal => do
| leaf indVal _ => do
let numArgs := indVal.numParams + indVal.numIndices
-- unless argNames.size >= numArgs do
-- throwError s!"Expected {numArgs} arguments for {indVal.name}, got {argNames}"
unless argNames.size >= numArgs do
throwError s!"Expected {numArgs} arguments for {indVal.name}, got {argNames}"
let mut args := Array.mkArray numArgs default
for i in [:numArgs] do
let arg := argNames[i]!
Expand All @@ -197,9 +198,12 @@ where

structure Result where
occ : NestedOccurence
args : Array Expr
args : Subarray Expr
argNames : Array Name

instance : ToString Result where
toString res := s!"⟨{res.occ},{res.args},{res.argNames}⟩"

instance : BEq Result := ⟨(·.occ == ·.occ)⟩

structure Context where
Expand All @@ -222,22 +226,27 @@ def add_name (n : Name) : NestedOccM Unit := do
let ⟨names,res⟩ ← get
set (⟨n::names,res⟩ : NestedOccurence.Context)

partial def getNestedOccurencesOf (inds : List Name) (e: Expr) (fvars : Subarray Expr): MetaM (Option NestedOccurence) := do
partial def getNestedOccurencesOf (inds : List Name) (e: Expr) (fvars : Array Expr): MetaM (Option NestedOccurence) := do
let .inl occs ← go e | return none
trace[Elab.Deriving] s!"getNestedOccurencesOf {inds} {e} {fvars} =\n{occs}"
return occs
where
go (e : Expr) : MetaM (NestedOccurence ⊕ Expr) := do
trace[Elab.Deriving] s!"go {inds} {e} {fvars}"
let hd := e.getAppFn
let args := e.getAppArgs
trace[Elab.Deriving] s!"args : {args}"
let fallback _ := return .inr <| e.abstract fvars
let .const name _ := hd | fallback ()
if let some indName := inds.find? (· = name) then
let indVal ← getConstInfoInduct indName
return .inl <| .leaf indVal
let args := args.map (Expr.instantiateRev · fvars)
return .inl <| .leaf indVal args
else
try
let indVal ← getConstInfoInduct name
let args := e.getAppArgs
let args := args.map (·.abstract fvars)
let args := args.map (Expr.abstract · fvars)
trace[Elab.Deriving] s!"abstracted args : {args}"
let nestedOccsArgs ← args.mapM go
if nestedOccsArgs.any Sum.isLeft then
return .inl <| .node indVal nestedOccsArgs
Expand All @@ -248,9 +257,12 @@ partial def getNestedOccurences (indNames : List Name) : TermElabM (List NestedO
let ⟨_,l⟩ ← withIndNames indNames do
for name in indNames do
go name #[] #[]
let l := l.eraseDups
trace[Elab.Deriving] s!"getNestedOccurences {indNames} =\n{l}"
return l.eraseDups
where
go (indName : Name) (args : Array Expr) (fvars : Array Expr): NestedOccM Unit := do
trace[Elab.Deriving] s!"go2 {indNames} {indName} {args} {fvars}"
let indVal ← getConstInfoInduct indName
if !indVal.isNested && args.size == 0 then
return
Expand All @@ -277,18 +289,22 @@ where
let occs ← getNestedOccurencesOf indNames ty xs[:i]
let l' := if let .some x := occs then x.toListofNests else []
for occ in l' do
let relevantLocalArgs := localArgs.filter (occ.containsFVar ⟨·⟩)
let fvars := fvars ++ xs[:i]
let new_args := paramArgs ++ relevantLocalArgs
trace[Elab.Deriving] s!"paramArgs : {paramArgs}"
let new_args := paramArgs ++ localArgs.filter (occ.containsFVar ⟨·⟩)
trace[Elab.Deriving] s!"localArgs : {localArgs}"
trace[Elab.Deriving] s!"occ : {occ}"
-- let new_args := new_args.filter (occ.containsFVar ⟨·⟩)
trace[Elab.Deriving] s!"filtered vars : {new_args}"
if (← get).res.all (occ != ·.occ) then
add_res ⟨occ,xs[:i].toArray,new_args⟩
let app ← occ.mkAppExpr fvars.toSubarray
let hd := app.getAppFn.constName!
let args := app.getAppArgs
add_res ⟨occ,xs[:i],new_args⟩
let fvars := fvars ++ xs[:i]
let app ← occ.mkAppExpr fvars
let hd := app.getAppFn.constName!
let args := app.getAppArgs
add_name hd
go hd args fvars
else
add_res ⟨occ,xs[:i].toArray,new_args⟩
add_res ⟨occ,xs[:i],new_args⟩
l := l ++ l'
localArgs := localArgs.push paramName

Expand All @@ -298,7 +314,7 @@ def indNameToFunName (indName : Name) : String :=
| _ => "instFn"

partial def mkInstName: NestedOccurence → String
| .leaf ind => indNameToFunName ind.name
| .leaf ind _ => indNameToFunName ind.name
| .node ind arr => Id.run do
let mut res ← indNameToFunName ind.name
for nestedOcc in arr do
Expand Down Expand Up @@ -344,14 +360,14 @@ structure Context : Type where
auxFunNames : Array Name
usePartial : Bool

def mkContext (fnPrefix : String) (typeName : Name) (withNested : Bool := true): TermElabM Context := do
def mkContext (fnPrefix : String) (typeName : Name) (withNested : Bool := true): TermElabM Context := do
let indVal ← getConstInfoInduct typeName
let indNames := indVal.all
let mut typeInfos' : List NestedOccurence.Result := []
for indName in indNames do
let indVal ← getConstInfoInduct indName
let args ← mkInductArgNames indVal
typeInfos' := ⟨.leaf indVal,#[],args⟩::typeInfos'
typeInfos' := ⟨.leaf indVal #[],#[].toSubarray,args⟩::typeInfos'
if withNested then
typeInfos' := (← getNestedOccurences indVal.all) ++ typeInfos'
let typeArgNames := typeInfos'.map (·.argNames) |>.toArray
Expand Down
20 changes: 10 additions & 10 deletions tests/lean/3057.lean.expected.out
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
instReprTree
instReprListTree
instDecidableEqTree
instDecidableEqListTree
instBEqTree
instBEqListTree
instHashableTree
instHashableListTree
instOrdTree
instOrdListTree
instReprTree_1
instReprListTree_1
instDecidableEqTree_1
instDecidableEqListTree_1
instBEqTree_1
instBEqListTree_1
instHashableTree_1
instHashableListTree_1
instOrdTree_1
instOrdListTree_1

0 comments on commit 4004db4

Please sign in to comment.