Skip to content

Commit

Permalink
Better operator overloading inference
Browse files Browse the repository at this point in the history
This commit introduces a weak form a bi-directional typing, and
does a two-pass typing of overloading operators arguments.
  • Loading branch information
strub committed Jan 6, 2025
1 parent 13acf05 commit 848bf68
Show file tree
Hide file tree
Showing 7 changed files with 120 additions and 43 deletions.
2 changes: 1 addition & 1 deletion src/ecHiInductive.ml
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ let trans_matchfix
let filter = fun _ op -> EcDecl.is_ctor op in
let PPApp ((cname, tvi), cargs) = pb.pop_pattern in
let tvi = tvi |> omap (TT.transtvi env ue) in
let cts = EcUnify.select_op ~filter tvi env (unloc cname) ue [] in
let cts = EcUnify.select_op ~filter tvi env (unloc cname) ue ([], None) in

match cts with
| [] ->
Expand Down
6 changes: 3 additions & 3 deletions src/ecPrinting.ml
Original file line number Diff line number Diff line change
Expand Up @@ -941,7 +941,7 @@ let pp_opapp
(es : 'a list))
=
let (nm, opname) =
PPEnv.op_symb ppe op (Some (pred, tvi, List.map t_ty es)) in
PPEnv.op_symb ppe op (Some (pred, tvi, (List.map t_ty es, None))) in

let inm = if nm = [] then fst outer else nm in

Expand Down Expand Up @@ -1250,7 +1250,7 @@ let pp_chained_orderings (ppe : PPEnv.t) t_ty pp_sub outer fmt (f, fs) =
ignore (List.fold_left
(fun fe (op, tvi, f) ->
let (nm, opname) =
PPEnv.op_symb ppe op (Some (`Form, tvi, [t_ty fe; t_ty f]))
PPEnv.op_symb ppe op (Some (`Form, tvi, ([t_ty fe; t_ty f], None)))
in
Format.fprintf fmt " %t@ %a"
(fun fmt ->
Expand Down Expand Up @@ -1343,7 +1343,7 @@ let lower_left (ppe : PPEnv.t) (t_ty : form -> EcTypes.ty) (f : form)
else l_l f2 onm e_bin_prio_rop4
| Fapp ({f_node = Fop (op, tys)}, [f1; f2]) ->
(let (inm, opname) =
PPEnv.op_symb ppe op (Some (`Form, tys, List.map t_ty [f1; f2])) in
PPEnv.op_symb ppe op (Some (`Form, tys, (List.map t_ty [f1; f2], None))) in
if inm <> [] && inm <> onm
then None
else match priority_of_binop opname with
Expand Down
2 changes: 1 addition & 1 deletion src/ecScope.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1689,7 +1689,7 @@ module Ty = struct
let tvi = List.map (TT.transty tp_tydecl env ue) tvi in
let selected =
EcUnify.select_op ~filter:(fun _ -> EcDecl.is_oper)
(Some (EcUnify.TVIunamed tvi)) env (unloc op) ue []
(Some (EcUnify.TVIunamed tvi)) env (unloc op) ue ([], None)
in
let op =
match selected with
Expand Down
115 changes: 84 additions & 31 deletions src/ecTyping.ml
Original file line number Diff line number Diff line change
Expand Up @@ -300,15 +300,15 @@ let select_local env (qs,s) =
else None

(* -------------------------------------------------------------------- *)
let select_pv env side name ue tvi psig =
let select_pv env side name ue tvi (psig, retty) =
if tvi <> None
then []
else
try
let pvs = EcEnv.Var.lookup_progvar ?side name env in
let select (pv,ty) =
let subue = UE.copy ue in
let texpected = EcUnify.tfun_expected subue psig in
let texpected = EcUnify.tfun_expected subue ?retty psig in
try
EcUnify.unify env subue ty texpected;
[(pv, ty, subue)]
Expand Down Expand Up @@ -346,7 +346,7 @@ let gen_select_op
(env : EcEnv.env)
(name : EcSymbols.qsymbol)
(ue : EcUnify.unienv)
(psig : EcTypes.dom)
(psig : EcTypes.dom * EcTypes.ty option)

: OpSelect.gopsel list
=
Expand Down Expand Up @@ -432,7 +432,7 @@ let select_form_op env mode ~forcepv opsc name ue tvi psig =
(* -------------------------------------------------------------------- *)
let select_proj env opsc name ue tvi recty =
let filter = (fun _ op -> EcDecl.is_proj op) in
let ops = EcUnify.select_op ~filter tvi env name ue [recty] in
let ops = EcUnify.select_op ~filter tvi env name ue ([recty], None) in
let ops = List.map (fun (p, ty, ue, _) -> (p, ty, ue)) ops in

match ops, opsc with
Expand Down Expand Up @@ -1060,7 +1060,7 @@ let transpattern1 env ue (p : EcParsetree.plpattern) =
let fields =
let for1 (name, v) =
let filter = fun _ op -> EcDecl.is_proj op in
let fds = EcUnify.select_op ~filter None env (unloc name) ue [] in
let fds = EcUnify.select_op ~filter None env (unloc name) ue ([], None) in
match List.ohead fds with
| None ->
let exn = UnknownRecFieldName (unloc name) in
Expand Down Expand Up @@ -1200,7 +1200,7 @@ let trans_record env ue (subtt, proj) (loc, b, fields) =
let for1 rf =
let filter = fun _ op -> EcDecl.is_proj op in
let tvi = rf.rf_tvi |> omap (transtvi env ue) in
let fds = EcUnify.select_op ~filter tvi env (unloc rf.rf_name) ue [] in
let fds = EcUnify.select_op ~filter tvi env (unloc rf.rf_name) ue ([], None) in
match List.ohead fds with
| None ->
let exn = UnknownRecFieldName (unloc rf.rf_name) in
Expand Down Expand Up @@ -1289,7 +1289,7 @@ let trans_branch ~loc env ue gindty ((pb, body) : ppattern * _) =
let filter = fun _ op -> EcDecl.is_ctor op in
let PPApp ((cname, tvi), cargs) = pb in
let tvi = tvi |> omap (transtvi env ue) in
let cts = EcUnify.select_op ~filter tvi env (unloc cname) ue [] in
let cts = EcUnify.select_op ~filter tvi env (unloc cname) ue ([], None) in

match cts with
| [] ->
Expand Down Expand Up @@ -2512,7 +2512,7 @@ and translvalue ue (env : EcEnv.env) lvalue =
let e, ety = e_tuple e, ttuple ety in
let name = ([], EcCoreLib.s_set) in
let esig = [xty; ety; codomty] in
let ops = select_exp_op env `InProc None name ue tvi esig in
let ops = select_exp_op env `InProc None name ue tvi (esig, None) in

match ops with
| [] ->
Expand Down Expand Up @@ -2581,8 +2581,9 @@ and trans_gbinding env ue decl =
and trans_form_or_pattern env mode ?mv ?ps ue pf tt =
let state = PFS.create () in

let rec transf_r opsc env f =
let transf = transf_r opsc in
let rec transf_r_tyinfo opsc env ?tt f =
let transf env ?tt f =
transf_r opsc env ?tt f in

match f.pl_desc with
| PFhole -> begin
Expand Down Expand Up @@ -2814,20 +2815,18 @@ and trans_form_or_pattern env mode ?mv ?ps ue pf tt =
| PFdecimal (n, f) ->
f_decimal (n, f)

| PFtuple args -> begin
let args = List.map (transf env) args in
match args with
| [] -> f_tt
| [f] -> f
| fs -> f_tuple fs
end
| PFtuple pes ->
let esig = List.map (fun _ -> EcUnify.UniEnv.fresh ue) pes in
tt |> oiter (fun tt -> unify_or_fail env ue f.pl_loc ~expct:tt (ttuple esig));
let es = List.map2 (fun tt pe -> transf env ~tt pe) esig pes in
f_tuple es

| PFident ({ pl_desc = name; pl_loc = loc }, tvi) ->
let tvi = tvi |> omap (transtvi env ue) in
let ops =
select_form_op
~forcepv:(PFS.isforced state)
env mode opsc name ue tvi [] in
env mode opsc name ue tvi ([], tt) in
begin match ops with
| [] ->
tyerror loc env (UnknownVarOrOp (name, []))
Expand Down Expand Up @@ -2962,13 +2961,43 @@ and trans_form_or_pattern env mode ?mv ?ps ue pf tt =
check_mem f.pl_loc EcFol.mright;
EcFol.f_ands (List.map (do1 (EcFol.mleft, EcFol.mright)) fs)

| PFapp ({pl_desc = PFident ({ pl_desc = name; pl_loc = loc }, tvi)}, pes) ->
| PFapp ({pl_desc = PFident ({ pl_desc = name; pl_loc = loc }, tvi)}, pes) -> begin
let try_trans ?tt pe =
let ue' = EcUnify.UniEnv.copy ue in
let ps' = Option.map (fun ps -> ref !ps) ps in
match transf env ?tt pe with
| e -> Some e
| exception TyError (_, _, MultipleOpMatch _) ->
Option.iter (fun ps -> ps := !(Option.get ps')) ps;
EcUnify.UniEnv.restore ~dst:ue ~src:ue';
None
in

match
let ue' = EcUnify.UniEnv.copy ue in
let ps' = Option.map (fun ps -> ref !ps) ps in
let es = List.map (fun pe -> try_trans pe) pes in
let tvi = tvi |> omap (transtvi env ue) in
let esig = List.map (fun e ->
match e with Some e -> e.f_ty | None -> EcUnify.UniEnv.fresh ue
) es in
match
select_form_op ~forcepv:(PFS.isforced state)
env mode opsc name ue tvi (esig, tt)
with
| [sel] -> Some (sel, (es, esig, tvi))
| _ ->
Option.iter (fun ps -> ps := !(Option.get ps')) ps;
EcUnify.UniEnv.restore ~dst:ue ~src:ue';
None
with
| None -> begin
let tvi = tvi |> omap (transtvi env ue) in
let es = List.map (transf env) pes in
let esig = List.map EcFol.f_ty es in
let ops =
select_form_op ~forcepv:(PFS.isforced state)
env mode opsc name ue tvi esig in
env mode opsc name ue tvi (esig, tt) in

begin match ops with
| [] ->
Expand All @@ -2986,6 +3015,24 @@ and trans_form_or_pattern env mode ?mv ?ps ue pf tt =
let matches = List.map (fun (_, _, subue, m) -> (m, subue)) ops in
tyerror loc env (MultipleOpMatch (name, esig, matches))
end
end

| Some ((_, _, subue, _) as sel, (es, esig, _tvi)) ->
EcUnify.UniEnv.restore ~dst:ue ~src:subue;
let es =
List.map2 (
fun (e, ty) pe ->
match e with None -> try_trans ~tt:ty pe | Some e -> Some e
) (List.combine es esig) pes in
let es =
List.map2 (
fun (e, ty) pe ->
match e with None -> transf env ~tt:ty pe | Some e -> e
) (List.combine es esig) pes in
let es = List.map2 (fun e l -> mk_loc l.pl_loc e) es pes in
EcUnify.UniEnv.restore ~src:ue ~dst:subue;
form_of_opselect (env, ue) loc sel es
end

| PFapp (e, pes) ->
let es = List.map (transf env) pes in
Expand Down Expand Up @@ -3041,25 +3088,30 @@ and trans_form_or_pattern env mode ?mv ?ps ue pf tt =
let f1 = transf env pf1 in
unify_or_fail env ue pf1.pl_loc ~expct:pty f1.f_ty;
aty |> oiter (fun aty-> unify_or_fail env ue pf1.pl_loc ~expct:pty aty);
let f2 = transf penv f2 in
let f2 = transf penv ?tt f2 in
f_let p f1 f2

| PFforall (xs, pf) ->
let env, xs = trans_gbinding env ue xs in
let f = transf env pf in
unify_or_fail env ue pf.pl_loc ~expct:tbool f.f_ty;
f_forall xs f
unify_or_fail env ue pf.pl_loc ~expct:tbool f.f_ty;
f_forall xs f

| PFexists (xs, f1) ->
let env, xs = trans_gbinding env ue xs in
let f = transf env f1 in
unify_or_fail env ue f1.pl_loc ~expct:tbool f.f_ty;
f_exists xs f
unify_or_fail env ue f1.pl_loc ~expct:tbool f.f_ty;
f_exists xs f

| PFlambda (xs, f1) ->
let env, xs = trans_binding env ue xs in
let f = transf env f1 in
f_lambda (List.map (fun (x,ty) -> (x,GTty ty)) xs) f
let subtt = tt |> Option.map (fun tt ->
let codom = EcUnify.UniEnv.fresh ue in
unify_or_fail env ue (loc f) ~expct:(toarrow (List.snd xs) codom) tt;
codom
) in
let f = transf env ?tt:subtt f1 in
f_lambda (List.map (fun (x, ty) -> (x, GTty ty)) xs) f

| PFrecord (b, fields) ->
let (ctor, fields, (rtvi, reccty)) =
Expand Down Expand Up @@ -3190,11 +3242,12 @@ and trans_form_or_pattern env mode ?mv ?ps ue pf tt =
unify_or_fail qenv ue post.pl_loc ~expct:tbool post'.f_ty;
f_eagerF pre' s1 fpath1 fpath2 s2 post'

in
and transf_r opsc env ?tt pf =
let f = transf_r_tyinfo opsc env ?tt pf in
let () = oiter (fun tt -> unify_or_fail env ue pf.pl_loc ~expct:tt f.f_ty) tt in
f

let f = transf_r None env pf in
tt |> oiter (fun tt -> unify_or_fail env ue pf.pl_loc ~expct:tt f.f_ty);
f
in transf_r None env ?tt pf

(* Type-check a memtype. *)
and trans_memtype env ue (pmemtype : pmemtype) : memtype =
Expand Down
10 changes: 5 additions & 5 deletions src/ecUnify.ml
Original file line number Diff line number Diff line change
Expand Up @@ -396,15 +396,15 @@ let hastc env ue ty tc =
ue := { !ue with ue_uf = uf; }

(* -------------------------------------------------------------------- *)
let tfun_expected ue psig =
let tres = UniEnv.fresh ue in
EcTypes.toarrow psig tres
let tfun_expected ue ?retty psig =
let retty = ofdfl (fun () -> UniEnv.fresh ue) retty in
EcTypes.toarrow psig retty

(* -------------------------------------------------------------------- *)
type sbody = ((EcIdent.t * ty) list * expr) Lazy.t

(* -------------------------------------------------------------------- *)
let select_op ?(hidden = false) ?(filter = fun _ _ -> true) tvi env name ue psig =
let select_op ?(hidden = false) ?(filter = fun _ _ -> true) tvi env name ue (psig, retty) =
ignore hidden; (* FIXME *)
let module D = EcDecl in
Expand Down Expand Up @@ -457,7 +457,7 @@ let select_op ?(hidden = false) ?(filter = fun _ _ -> true) tvi env name ue psig
let (tip, tvs) = UniEnv.openty_r subue op.D.op_tparams tvi in
let top = ty_subst tip op.D.op_ty in
let texpected = tfun_expected subue psig in
let texpected = tfun_expected subue ?retty psig in
(try unify env subue top texpected
with UnificationFailure _ -> raise E.Failure);
Expand Down
4 changes: 2 additions & 2 deletions src/ecUnify.mli
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ end
val unify : EcEnv.env -> unienv -> ty -> ty -> unit
val hastc : EcEnv.env -> unienv -> ty -> Sp.t -> unit

val tfun_expected : unienv -> EcTypes.ty list -> EcTypes.ty
val tfun_expected : unienv -> ?retty:ty -> EcTypes.ty list -> EcTypes.ty

type sbody = ((EcIdent.t * ty) list * expr) Lazy.t

Expand All @@ -48,5 +48,5 @@ val select_op :
-> EcEnv.env
-> qsymbol
-> unienv
-> dom
-> dom * ty option
-> ((EcPath.path * ty list) * ty * unienv * sbody option) list
24 changes: 24 additions & 0 deletions tests/overloading.ec
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
require import AllCore List.

theory T.
op o : int.
op a : int -> int -> int.
end T.

theory U.
op o : bool.
op a : bool -> bool -> bool.
end U.

import T U.

op foo : int -> unit.

op bar = foo o.

op plop1 = foldr a false [].

op plop2 = foldr (fun x => a x) false [].

op plop3 = foldr (fun x y => a x y) false [].

0 comments on commit 848bf68

Please sign in to comment.