Skip to content

Commit

Permalink
more strictness for strict tanon_identification.unify
Browse files Browse the repository at this point in the history
  • Loading branch information
kLabz committed Dec 29, 2023
1 parent 21d1a99 commit 0c2b618
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 66 deletions.
111 changes: 56 additions & 55 deletions src/core/tUnification.ml
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,13 @@ type eq_kind =
| EqDoNotFollowNull (* like EqStrict, but does not follow Null<T> *)

type unification_context = {
allow_transitive_cast : bool;
allow_abstract_cast : bool; (* allows a non-transitive abstract cast (from,to,@:from,@:to) *)
allow_dynamic_to_cast : bool; (* allows a cast from dynamic to non-dynamic *)
equality_kind : eq_kind;
equality_underlying : bool;
allow_transitive_cast : bool;
allow_abstract_cast : bool; (* allows a non-transitive abstract cast (from,to,@:from,@:to) *)
allow_dynamic_to_cast : bool; (* allows a cast from dynamic to non-dynamic *)
allow_arg_name_mismatch : bool;
equality_kind : eq_kind;
equality_underlying : bool;
strict_field_kind : bool;
}

type unify_min_result =
Expand All @@ -53,12 +55,50 @@ let check_constraint name f =
let unify_ref : (unification_context -> t -> t -> unit) ref = ref (fun _ _ _ -> ())
let unify_min_ref : (unification_context -> t -> t list -> unify_min_result) ref = ref (fun _ _ _ -> assert false)

(*
we can restrict access as soon as both are runtime-compatible
*)
let unify_access a1 a2 =
a1 = a2 || match a1, a2 with
| _, AccNo | _, AccNever -> true
| AccInline, AccNormal -> true
| _ -> false

let direct_access = function
| AccNo | AccNever | AccNormal | AccInline | AccRequire _ | AccCtor -> true
| AccCall -> false

let unify_kind ?(strict:bool = false) k1 k2 =
k1 = k2 || match k1, k2 with
| Var v1, Var v2 -> unify_access v1.v_read v2.v_read && unify_access v1.v_write v2.v_write
| Method m1, Method m2 ->
(match m1,m2 with
| MethInline, MethNormal
| MethDynamic, MethNormal -> true
| _ -> false)
| Var v, Method m when not strict ->
(match v.v_read, v.v_write, m with
| AccNormal, _, MethNormal -> true
| AccNormal, AccNormal, MethDynamic -> true
| _ -> false)
| Method m, Var v when not strict ->
(match m with
| MethDynamic -> direct_access v.v_read && direct_access v.v_write
| MethMacro -> false
| MethNormal | MethInline ->
match v.v_read,v.v_write with
| AccNormal,(AccNo | AccNever) -> true
| _ -> false)
| _ -> false

let default_unification_context = {
allow_transitive_cast = true;
allow_abstract_cast = true;
allow_dynamic_to_cast = true;
equality_kind = EqStrict;
equality_underlying = false;
allow_transitive_cast = true;
allow_abstract_cast = true;
allow_dynamic_to_cast = true;
allow_arg_name_mismatch = true;
equality_kind = EqStrict;
equality_underlying = false;
strict_field_kind = false;
}

module Monomorph = struct
Expand Down Expand Up @@ -410,45 +450,6 @@ let invalid_visibility n = Invalid_visibility n
let has_no_field t n = Has_no_field (t,n)
let has_extra_field t n = Has_extra_field (t,n)

(*
we can restrict access as soon as both are runtime-compatible
*)
let unify_access a1 a2 =
a1 = a2 || match a1, a2 with
| _, AccNo | _, AccNever -> true
| AccInline, AccNormal -> true
| _ -> false

let direct_access = function
| AccNo | AccNever | AccNormal | AccInline | AccRequire _ | AccCtor -> true
| AccCall -> false

let unify_kind k1 k2 =
k1 = k2 || match k1, k2 with
| Var v1, Var v2 -> unify_access v1.v_read v2.v_read && unify_access v1.v_write v2.v_write
| Var v, Method m ->
(match v.v_read, v.v_write, m with
| AccNormal, _, MethNormal -> true
| AccNormal, AccNormal, MethDynamic -> true
| _ -> false)
| Method m, Var v ->
(match m with
| MethDynamic -> direct_access v.v_read && direct_access v.v_write
| MethMacro -> false
| MethNormal | MethInline ->
match v.v_read,v.v_write with
| AccNormal,(AccNo | AccNever) -> true
| _ -> false)
| Method m1, Method m2 ->
match m1,m2 with
| MethInline, MethNormal
| MethDynamic, MethNormal -> true
| _ -> false

let unify_kind_strict cfk1 cfk2 = cfk1 = cfk2 || match cfk1, cfk2 with
| Var _, Var _ | Method _, Method _ -> unify_kind cfk1 cfk2
| _ -> false

type 'a rec_stack = {
mutable rec_stack : 'a list;
}
Expand Down Expand Up @@ -545,9 +546,10 @@ let rec type_eq uctx a b =
let i = ref 0 in
(try
type_eq uctx r1 r2;
List.iter2 (fun (n,o1,t1) (_,o2,t2) ->
List.iter2 (fun (n1,o1,t1) (n2,o2,t2) ->
incr i;
if o1 <> o2 then error [Not_matching_optional n];
if not uctx.allow_arg_name_mismatch && n1 <> n2 then error [Unify_custom (Printf.sprintf "Arg name mismatch: %s should be %s" n2 n1)];
if o1 <> o2 then error [Not_matching_optional n1];
type_eq uctx t1 t2
) l1 l2
with
Expand All @@ -572,8 +574,7 @@ let rec type_eq uctx a b =
PMap.iter (fun n f1 ->
try
let f2 = PMap.find n a2.a_fields in
(* if f1.cf_kind <> f2.cf_kind && (param = EqStrict || param = EqCoreType || not (unify_kind f1.cf_kind f2.cf_kind)) then error [invalid_kind n f1.cf_kind f2.cf_kind]; *)
if f1.cf_kind <> f2.cf_kind && (param = EqStrict || param = EqCoreType || param = EqDoNotFollowNull || not (unify_kind f1.cf_kind f2.cf_kind)) then error [invalid_kind n f1.cf_kind f2.cf_kind];
if f1.cf_kind <> f2.cf_kind && (param = EqStrict || param = EqCoreType || param = EqDoNotFollowNull || not (unify_kind ~strict:uctx.strict_field_kind f1.cf_kind f2.cf_kind)) then error [invalid_kind n f1.cf_kind f2.cf_kind];
let a = f1.cf_type and b = f2.cf_type in
(try type_eq uctx a b with Unify_error l -> error (invalid_field n :: l));
if (has_class_field_flag f1 CfPublic) != (has_class_field_flag f2 CfPublic) then error [invalid_visibility n];
Expand Down Expand Up @@ -750,7 +751,7 @@ let rec unify (uctx : unification_context) a b =
in
let _, ft, f1 = (try raw_class_field make_type c tl n with Not_found -> error [has_no_field a n]) in
let ft = apply_params c.cl_params tl ft in
if not (unify_kind f1.cf_kind f2.cf_kind) then error [invalid_kind n f1.cf_kind f2.cf_kind];
if not (unify_kind ~strict:uctx.strict_field_kind f1.cf_kind f2.cf_kind) then error [invalid_kind n f1.cf_kind f2.cf_kind];
if (has_class_field_flag f2 CfPublic) && not (has_class_field_flag f1 CfPublic) then error [invalid_visibility n];

(match f2.cf_kind with
Expand Down Expand Up @@ -906,7 +907,7 @@ and unify_anons uctx a b a1 a2 =
PMap.iter (fun n f2 ->
try
let f1 = PMap.find n a1.a_fields in
if not (unify_kind f1.cf_kind f2.cf_kind) then
if not (unify_kind ~strict:uctx.strict_field_kind f1.cf_kind f2.cf_kind) then
error [invalid_kind n f1.cf_kind f2.cf_kind];
if (has_class_field_flag f2 CfPublic) && not (has_class_field_flag f1 CfPublic) then error [invalid_visibility n];
try
Expand Down
26 changes: 15 additions & 11 deletions src/typing/tanon_identification.ml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,17 @@ object(self)
DynArray.add (DynArray.get pfm_by_arity pfm.pfm_arity) pfm;
Hashtbl.replace pfms path pfm

method unify ?(unify_kind = TUnification.unify_kind) ?(strict:bool = false) (tc : Type.t) (pfm : 'a path_field_mapping) =
method unify ?(strict:bool = false) (tc : Type.t) (pfm : 'a path_field_mapping) =
let uctx = if strict then {
allow_transitive_cast = false;
allow_abstract_cast = false;
allow_dynamic_to_cast = false;
allow_arg_name_mismatch = false;
equality_kind = EqDoNotFollowNull;
equality_underlying = true;
strict_field_kind = true;
} else {default_unification_context with equality_kind = EqDoNotFollowNull} in

let check () =
let pair_up fields =
PMap.fold (fun cf acc ->
Expand All @@ -73,7 +83,7 @@ object(self)
let monos = List.map (fun _ -> mk_mono()) pfm.pfm_params in
let map = apply_params pfm.pfm_params monos in
List.iter (fun (cf,cf') ->
if not (unify_kind cf'.cf_kind cf.cf_kind) then raise (Unify_error [Unify_custom "kind mismatch"]);
if not (unify_kind ~strict:uctx.strict_field_kind cf'.cf_kind cf.cf_kind) then raise (Unify_error [Unify_custom "kind mismatch"]);
Type.unify (apply_params c.cl_params tl (monomorphs cf'.cf_params cf'.cf_type)) (map (monomorphs cf.cf_params cf.cf_type))
) pairs;
monos
Expand All @@ -83,15 +93,9 @@ object(self)
let monos = List.map (fun _ -> mk_mono()) pfm.pfm_params in
let map = apply_params pfm.pfm_params monos in
List.iter (fun (cf,cf') ->
if not (unify_kind cf'.cf_kind cf.cf_kind) then raise (Unify_error [Unify_custom "kind mismatch"]);
if strict && (Meta.has Meta.Optional cf.cf_meta) != (Meta.has Meta.Optional cf'.cf_meta) then raise (Unify_error [Unify_custom "optional mismatch"]);
if not (unify_kind ~strict:uctx.strict_field_kind cf'.cf_kind cf.cf_kind) then raise (Unify_error [Unify_custom "kind mismatch"]);
fields := PMap.remove cf.cf_name !fields;
let uctx = if strict then {
allow_transitive_cast = false;
allow_abstract_cast = false;
allow_dynamic_to_cast = false;
equality_kind = EqDoNotFollowNull;
equality_underlying = true;
} else {default_unification_context with equality_kind = EqDoNotFollowNull} in
type_eq_custom uctx cf'.cf_type (map (monomorphs cf.cf_params cf.cf_type))
) pairs;
if not (PMap.is_empty !fields) then raise (Unify_error [Unify_custom "not enough fields"]);
Expand Down Expand Up @@ -123,7 +127,7 @@ object(self)
raise Not_found;
let pfm = DynArray.unsafe_get d i in
try
if strict then self#unify ~unify_kind:unify_kind_strict ~strict tc pfm else self#unify tc pfm;
if strict then self#unify ~strict tc pfm else self#unify tc pfm;
pfm
with Unify_error _ ->
loop (i + 1)
Expand Down

0 comments on commit 0c2b618

Please sign in to comment.