From 4004db402f420c2c5b7b2e9dca72735b85e46441 Mon Sep 17 00:00:00 2001 From: arthur-adjedj Date: Wed, 14 Aug 2024 15:07:07 +0200 Subject: [PATCH] fix: everything --- src/Lean/Elab/Deriving/BEq.lean | 15 +++-- src/Lean/Elab/Deriving/DecEq.lean | 9 ++- src/Lean/Elab/Deriving/FromToJson.lean | 9 +-- src/Lean/Elab/Deriving/Hashable.lean | 3 +- src/Lean/Elab/Deriving/Repr.lean | 4 +- src/Lean/Elab/Deriving/Util.lean | 90 +++++++++++++++----------- tests/lean/3057.lean.expected.out | 20 +++--- 7 files changed, 86 insertions(+), 64 deletions(-) diff --git a/src/Lean/Elab/Deriving/BEq.lean b/src/Lean/Elab/Deriving/BEq.lean index 5ec72fc55fe7..f064ab8f2713 100644 --- a/src/Lean/Elab/Deriving/BEq.lean +++ b/src/Lean/Elab/Deriving/BEq.lean @@ -102,13 +102,18 @@ where def mkAuxFunction (ctx : Context) (i : Nat) : TermElabM Command := do let auxFunName := ctx.auxFunNames[i]! let nestedOcc := ctx.typeInfos[i]! - let argNames := ctx.typeArgNames[i]! + let argNames := ctx.typeArgNames[i]! let header ← mkBEqHeader argNames nestedOcc let binders := header.binders Term.elabBinders binders fun xs => do let type ← Term.elabTerm header.targetType none - let body ← mkMatch ctx header type xs - `(private def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Bool := $body:term) + let mut body ← mkMatch ctx header type xs + if ctx.usePartial then + let letDecls ← mkLocalInstanceLetDecls ctx `BEq header.argNames + body ← mkLet letDecls body + `(private partial def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Bool := $body:term) + else + `(private def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Bool := $body:term) def mkMutualBlock (ctx : Context) : TermElabM Syntax := do let mut auxDefs := #[] @@ -120,7 +125,7 @@ def mkMutualBlock (ctx : Context) : TermElabM Syntax := do end) private def mkBEqInstanceCmds (declName : Name) : TermElabM (Array Syntax) := do - let ctx ← mkContext "beq" declName + let ctx ← mkContext "beq" declName false let cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `BEq) trace[Elab.Deriving.beq] "\n{cmds}" return cmds @@ -130,7 +135,7 @@ private def mkBEqEnumFun (ctx : Context) (name : Name) : TermElabM Syntax := do `(private def $(mkIdent auxFunName):ident (x y : $(mkIdent name)) : Bool := x.toCtorIdx == y.toCtorIdx) private def mkBEqEnumCmd (name : Name): TermElabM (Array Syntax) := do - let ctx ← mkContext "beq" name + let ctx ← mkContext "beq" name false let cmds := #[← mkBEqEnumFun ctx name] ++ (← mkInstanceCmds ctx `BEq) trace[Elab.Deriving.beq] "\n{cmds}" return cmds diff --git a/src/Lean/Elab/Deriving/DecEq.lean b/src/Lean/Elab/Deriving/DecEq.lean index 6cfd25d3b389..955fe2837cd9 100644 --- a/src/Lean/Elab/Deriving/DecEq.lean +++ b/src/Lean/Elab/Deriving/DecEq.lean @@ -98,15 +98,14 @@ def mkAuxFunction (ctx : Context) (auxFunName : Name) (argNames : Array Name) (n let binders := header.binders let target₁ := mkIdent header.targetNames[0]! let target₂ := mkIdent header.targetNames[1]! - let termSuffix ← if ctx.auxFunNames.size > 1 || nestedOcc.getIndVal.isRec - then `(Parser.Termination.suffix|termination_by structural $target₁) - else `(Parser.Termination.suffix|) + -- let termSuffix ← if ctx.auxFunNames.size > 1 || nestedOcc.getIndVal.isRec + -- then `(Parser.Termination.suffix|termination_by structural $target₁) + -- else `(Parser.Termination.suffix|) Term.elabBinders binders fun xs => do let type ← Term.elabTerm header.targetType none let body ← mkMatch ctx header type xs let type ← `(Decidable ($target₁ = $target₂)) - `(private def $(mkIdent auxFunName):ident $binders:bracketedBinder* : $type:term := $body:term - $termSuffix:suffix) + `(private def $(mkIdent auxFunName):ident $binders:bracketedBinder* : $type:term := $body:term) def mkAuxFunctions (ctx : Context) : TermElabM (TSyntax `command) := do let mut res : Array (TSyntax `command) := #[] diff --git a/src/Lean/Elab/Deriving/FromToJson.lean b/src/Lean/Elab/Deriving/FromToJson.lean index 37dc220ff0f2..6e517a4ce47c 100644 --- a/src/Lean/Elab/Deriving/FromToJson.lean +++ b/src/Lean/Elab/Deriving/FromToJson.lean @@ -187,7 +187,6 @@ def mkToJsonAuxFunction (ctx : Context) (i : Nat) : TermElabM Command := do if ctx.usePartial then let letDecls ← mkLocalInstanceLetDecls ctx ``ToJson header.argNames body ← mkLet letDecls body - if ctx.usePartial then `(private partial def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Json := $body:term) else `(private def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Json := $body:term) @@ -206,11 +205,13 @@ def mkFromJsonAuxFunction (ctx : Context) (i : Nat) : TermElabM Command := do let binders := header.binders Term.elabBinders binders fun xs => do let type ← Term.elabTerm header.targetType none - let mut body ← mkFromJsonBody ctx header type xs - if ctx.usePartial then + let mut body ← mkFromJsonBody ctx header type xs + if ctx.usePartial || nestedOcc.getIndVal.isRec then let letDecls ← mkLocalInstanceLetDecls ctx ``FromJson header.argNames body ← mkLet letDecls body - `(private partial def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Except String $(← ctx.typeInfos[i]!.mkAppTerm header.argNames) := $body:term) + `(private partial def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Except String $(← ctx.typeInfos[i]!.mkAppTerm header.argNames) := $body:term) + else + `(private def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Except String $(← ctx.typeInfos[i]!.mkAppTerm header.argNames) := $body:term) def mkToJsonMutualBlock (ctx : Context) : TermElabM Command := do diff --git a/src/Lean/Elab/Deriving/Hashable.lean b/src/Lean/Elab/Deriving/Hashable.lean index 0d014226424a..f2b4f25f1841 100644 --- a/src/Lean/Elab/Deriving/Hashable.lean +++ b/src/Lean/Elab/Deriving/Hashable.lean @@ -69,10 +69,9 @@ def mkAuxFunction (ctx : Context) (i : Nat) : TermElabM Command := do let type ← Term.elabTerm header.targetType none let mut body ← mkMatch ctx header type xs if ctx.usePartial then + -- TODO(Dany): Get rid of this code branch altogether once we have well-founded recursion let letDecls ← mkLocalInstanceLetDecls ctx `Hashable header.argNames body ← mkLet letDecls body - if ctx.usePartial then - -- TODO(Dany): Get rid of this code branch altogether once we have well-founded recursion `(private partial def $(mkIdent auxFunName):ident $binders:bracketedBinder* : UInt64 := $body:term) else `(private def $(mkIdent auxFunName):ident $binders:bracketedBinder* : UInt64 := $body:term) diff --git a/src/Lean/Elab/Deriving/Repr.lean b/src/Lean/Elab/Deriving/Repr.lean index 126aacbdd74b..269603efa572 100644 --- a/src/Lean/Elab/Deriving/Repr.lean +++ b/src/Lean/Elab/Deriving/Repr.lean @@ -111,7 +111,9 @@ def mkAuxFunction (ctx : Context) (i : Nat) : TermElabM Command := do if ctx.usePartial then let letDecls ← mkLocalInstanceLetDecls ctx `Repr header.argNames body ← mkLet letDecls body - `(private partial def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Format := $body:term) + `(private partial def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Format := $body:term) + else + `(private def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Format := $body:term) def mkMutualBlock (ctx : Context) : TermElabM Syntax := do let mut auxDefs := #[] diff --git a/src/Lean/Elab/Deriving/Util.lean b/src/Lean/Elab/Deriving/Util.lean index d71b9ec0b1ff..bfce5b201907 100644 --- a/src/Lean/Elab/Deriving/Util.lean +++ b/src/Lean/Elab/Deriving/Util.lean @@ -45,7 +45,6 @@ def mkImplicitBinders (argNames : Array Name) : TermElabM (Array (TSyntax ``Pars argNames.mapM fun argName => `(implicitBinderF| { $(mkIdent argName) }) -inductive NestedOccurence : Type := /-- Represents `ind params1 ... paramsn`, where `paramsi` is either a nested occurence, a constant name, or a free variable @@ -56,7 +55,7 @@ inductive NestedOccurence : Type := inductive Bar | foo : A -> B -> Foo A (Option Bar) B Nat → Bar - ```, + ``` this nested occurence `Foo A (Option Bar) B Nat` is encoded as ```lean node Foo #[ @@ -72,15 +71,16 @@ inductive NestedOccurence : Type := Remark 2 : Free variables are abstracted away in nested occurences. This is useful when trying to delete duplicate occurences, since the check is now purely syntactical. -/ +inductive NestedOccurence : Type := | node (ind : InductiveVal) (params : Array (NestedOccurence ⊕ Expr)) - | leaf (ind : InductiveVal) + | leaf (ind : InductiveVal) (fvars : Array Expr) namespace NestedOccurence -instance : Inhabited NestedOccurence := ⟨leaf default⟩ +instance : Inhabited NestedOccurence := ⟨leaf default #[]⟩ partial instance : BEq NestedOccurence := ⟨go⟩ where go - | leaf ind₁,.leaf ind₂ => ind₁.name == ind₂.name + | leaf ind₁ _,.leaf ind₂ _ => ind₁.name == ind₂.name | node ind₁ arr₁,.node ind₂ arr₂ => Id.run do unless ind₁.name == ind₂.name && arr₁.size == arr₂.size do return false @@ -92,34 +92,35 @@ where go partial instance : ToString NestedOccurence := ⟨go⟩ where go - | leaf ind => s!"leaf {ind.name}" - | node ind arr => - let s := arr.map (@instToStringSum _ _ ⟨go⟩ ⟨reprStr⟩ |>.toString) - s!"node {ind.name} {s}" + | leaf ind e=> s!"leaf {ind.name} {e}" + | node ind arr => + let s := arr.map (@instToStringSum _ _ ⟨go⟩ inferInstance |>.toString) + s!"node {ind.name} {s}" @[inline] def getIndVal : NestedOccurence → InductiveVal - | leaf indVal | node indVal _ => indVal + | leaf indVal _| node indVal _ => indVal +@[inline] def getArr : NestedOccurence → Array (NestedOccurence ⊕ Expr) - | leaf _ => #[] + | leaf .. => #[] | node _ arr => arr @[inline] def isLeaf : NestedOccurence → Bool - | leaf _ => true + | leaf ..=> true | node .. => false @[inline] def isNode : NestedOccurence → Bool := not ∘ isLeaf partial def containsFVar (fvarId : FVarId) : NestedOccurence → Bool - | leaf _ => false + | leaf _ e => e.any (Expr.containsFVar · fvarId) | node _ arr => arr.any (Sum.lift (containsFVar fvarId) (Expr.containsFVar · fvarId)) partial def toListofNests (e : NestedOccurence) : List NestedOccurence := match e with - | .leaf _ => [] + | .leaf _ _ => [] | .node _ arr => let l := flip arr.foldr [] fun occ l => if let .inl occ := occ then @@ -133,11 +134,11 @@ partial def mkAppTerm (nestedOcc : NestedOccurence) (argNames : Array Name) : Te where go (nestedOcc : NestedOccurence) (argNames : Array Name) : TermElabM Term := do match nestedOcc with - | leaf indVal => do + | leaf indVal _ => do let f := mkCIdent indVal.name let numArgs := indVal.numParams + indVal.numIndices - -- unless argNames.size >= numArgs do - -- throwError s!"Expected {numArgs} arguments for {indVal.name}, got {argNames}" + unless argNames.size >= numArgs do + throwError s!"Expected {numArgs} arguments for {indVal.name}, got {argNames}" let mut args := Array.mkArray numArgs default for i in [:numArgs] do let arg := mkIdent argNames[i]! @@ -162,16 +163,16 @@ where `(@$f $args*) /-- Return the inductive declaration's type applied to the arguments in `argNames`. -/ -partial def mkAppExpr (nestedOcc : NestedOccurence) (argNames : Subarray Expr) : TermElabM Expr := do +partial def mkAppExpr (nestedOcc : NestedOccurence) (argNames : Array Expr) : TermElabM Expr := do let res ← go nestedOcc argNames return res where go (nestedOcc : NestedOccurence) (argNames : Array Expr): TermElabM Expr := do match nestedOcc with - | leaf indVal => do + | leaf indVal _ => do let numArgs := indVal.numParams + indVal.numIndices - -- unless argNames.size >= numArgs do - -- throwError s!"Expected {numArgs} arguments for {indVal.name}, got {argNames}" + unless argNames.size >= numArgs do + throwError s!"Expected {numArgs} arguments for {indVal.name}, got {argNames}" let mut args := Array.mkArray numArgs default for i in [:numArgs] do let arg := argNames[i]! @@ -197,9 +198,12 @@ where structure Result where occ : NestedOccurence - args : Array Expr + args : Subarray Expr argNames : Array Name +instance : ToString Result where + toString res := s!"⟨{res.occ},{res.args},{res.argNames}⟩" + instance : BEq Result := ⟨(·.occ == ·.occ)⟩ structure Context where @@ -222,22 +226,27 @@ def add_name (n : Name) : NestedOccM Unit := do let ⟨names,res⟩ ← get set (⟨n::names,res⟩ : NestedOccurence.Context) -partial def getNestedOccurencesOf (inds : List Name) (e: Expr) (fvars : Subarray Expr): MetaM (Option NestedOccurence) := do +partial def getNestedOccurencesOf (inds : List Name) (e: Expr) (fvars : Array Expr): MetaM (Option NestedOccurence) := do let .inl occs ← go e | return none + trace[Elab.Deriving] s!"getNestedOccurencesOf {inds} {e} {fvars} =\n{occs}" return occs where go (e : Expr) : MetaM (NestedOccurence ⊕ Expr) := do + trace[Elab.Deriving] s!"go {inds} {e} {fvars}" let hd := e.getAppFn + let args := e.getAppArgs + trace[Elab.Deriving] s!"args : {args}" let fallback _ := return .inr <| e.abstract fvars let .const name _ := hd | fallback () if let some indName := inds.find? (· = name) then let indVal ← getConstInfoInduct indName - return .inl <| .leaf indVal + let args := args.map (Expr.instantiateRev · fvars) + return .inl <| .leaf indVal args else try let indVal ← getConstInfoInduct name - let args := e.getAppArgs - let args := args.map (·.abstract fvars) + let args := args.map (Expr.abstract · fvars) + trace[Elab.Deriving] s!"abstracted args : {args}" let nestedOccsArgs ← args.mapM go if nestedOccsArgs.any Sum.isLeft then return .inl <| .node indVal nestedOccsArgs @@ -248,9 +257,12 @@ partial def getNestedOccurences (indNames : List Name) : TermElabM (List NestedO let ⟨_,l⟩ ← withIndNames indNames do for name in indNames do go name #[] #[] + let l := l.eraseDups + trace[Elab.Deriving] s!"getNestedOccurences {indNames} =\n{l}" return l.eraseDups where go (indName : Name) (args : Array Expr) (fvars : Array Expr): NestedOccM Unit := do + trace[Elab.Deriving] s!"go2 {indNames} {indName} {args} {fvars}" let indVal ← getConstInfoInduct indName if !indVal.isNested && args.size == 0 then return @@ -277,18 +289,22 @@ where let occs ← getNestedOccurencesOf indNames ty xs[:i] let l' := if let .some x := occs then x.toListofNests else [] for occ in l' do - let relevantLocalArgs := localArgs.filter (occ.containsFVar ⟨·⟩) - let fvars := fvars ++ xs[:i] - let new_args := paramArgs ++ relevantLocalArgs + trace[Elab.Deriving] s!"paramArgs : {paramArgs}" + let new_args := paramArgs ++ localArgs.filter (occ.containsFVar ⟨·⟩) + trace[Elab.Deriving] s!"localArgs : {localArgs}" + trace[Elab.Deriving] s!"occ : {occ}" + -- let new_args := new_args.filter (occ.containsFVar ⟨·⟩) + trace[Elab.Deriving] s!"filtered vars : {new_args}" if (← get).res.all (occ != ·.occ) then - add_res ⟨occ,xs[:i].toArray,new_args⟩ - let app ← occ.mkAppExpr fvars.toSubarray - let hd := app.getAppFn.constName! - let args := app.getAppArgs + add_res ⟨occ,xs[:i],new_args⟩ + let fvars := fvars ++ xs[:i] + let app ← occ.mkAppExpr fvars + let hd := app.getAppFn.constName! + let args := app.getAppArgs add_name hd go hd args fvars else - add_res ⟨occ,xs[:i].toArray,new_args⟩ + add_res ⟨occ,xs[:i],new_args⟩ l := l ++ l' localArgs := localArgs.push paramName @@ -298,7 +314,7 @@ def indNameToFunName (indName : Name) : String := | _ => "instFn" partial def mkInstName: NestedOccurence → String - | .leaf ind => indNameToFunName ind.name + | .leaf ind _ => indNameToFunName ind.name | .node ind arr => Id.run do let mut res ← indNameToFunName ind.name for nestedOcc in arr do @@ -344,14 +360,14 @@ structure Context : Type where auxFunNames : Array Name usePartial : Bool -def mkContext (fnPrefix : String) (typeName : Name) (withNested : Bool := true): TermElabM Context := do +def mkContext (fnPrefix : String) (typeName : Name) (withNested : Bool := true): TermElabM Context := do let indVal ← getConstInfoInduct typeName let indNames := indVal.all let mut typeInfos' : List NestedOccurence.Result := [] for indName in indNames do let indVal ← getConstInfoInduct indName let args ← mkInductArgNames indVal - typeInfos' := ⟨.leaf indVal,#[],args⟩::typeInfos' + typeInfos' := ⟨.leaf indVal #[],#[].toSubarray,args⟩::typeInfos' if withNested then typeInfos' := (← getNestedOccurences indVal.all) ++ typeInfos' let typeArgNames := typeInfos'.map (·.argNames) |>.toArray diff --git a/tests/lean/3057.lean.expected.out b/tests/lean/3057.lean.expected.out index 34efeb5f83b6..dcb3e5ddfc90 100644 --- a/tests/lean/3057.lean.expected.out +++ b/tests/lean/3057.lean.expected.out @@ -1,10 +1,10 @@ -instReprTree -instReprListTree -instDecidableEqTree -instDecidableEqListTree -instBEqTree -instBEqListTree -instHashableTree -instHashableListTree -instOrdTree -instOrdListTree +instReprTree_1 +instReprListTree_1 +instDecidableEqTree_1 +instDecidableEqListTree_1 +instBEqTree_1 +instBEqListTree_1 +instHashableTree_1 +instHashableListTree_1 +instOrdTree_1 +instOrdListTree_1