Skip to content

Commit 2c01e75

Browse files
committed
mo-ld: Implement GOT.func imports
Fixes #1810
1 parent 7d9b099 commit 2c01e75

File tree

8 files changed

+285
-20
lines changed

8 files changed

+285
-20
lines changed

src/lib/lib.ml

+4
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,10 @@ struct
264264
then (take n xs, drop n xs)
265265
else (xs, [])
266266

267+
let null = function
268+
| [] -> true
269+
| _ :: _ -> false
270+
267271
let hd_opt = function
268272
| x :: _ -> Some x
269273
| _ -> None

src/lib/lib.mli

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ sig
1818
val take : int -> 'a list -> 'a list (* raises Failure *)
1919
val drop : int -> 'a list -> 'a list (* raises Failure *)
2020
val split_at : int -> 'a list -> ('a list * 'a list)
21+
val null : 'a list -> bool
2122

2223
val hd_opt : 'a list -> 'a option
2324
val last : 'a list -> 'a (* raises Failure *)

src/linking/linkModule.ml

+213-16
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,78 @@ through further refactoring before we are happy with it. Things to do:
1818
of functions for each syntactic category.
1919
*)
2020

21+
(*
22+
Resolving GOT.func and GOT.mem imports
23+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
24+
25+
GOT.func and GOT.mem imports arise from function and data pointers,
26+
respectively, in languages with pointers (e.g. C and Rust). The idea is that if
27+
a shared library exposes a function/data pointer the entire process should use
28+
the same pointer for the function/data so that the pointer arithmetic and
29+
comparisons will work. For example, this C code:
30+
31+
__attribute__ ((visibility("default")))
32+
int f0(int x, int y)
33+
{
34+
return x + y;
35+
}
36+
37+
__attribute__ ((visibility("default")))
38+
int (\*f1(void)) (int x, int y)
39+
{
40+
return &f0;
41+
}
42+
43+
generates this GOT.func import:
44+
45+
(import "GOT.func" "f0" (global (;N;) (mut i32)))
46+
47+
The host is responsible of allocating a table index for this function and
48+
resolving the import to the table index for `f0` so that this code in the
49+
importing module would work:
50+
51+
assert(f1() == f0);
52+
53+
Note that the definition of `f1` is in the *imported* module and this assertion
54+
is in the *importing* module.
55+
56+
Similarly exposing a data pointer generates a GOT.mem import. All GOT.mem
57+
imports to a symbol should resolve to the same constant to support equality as
58+
above, and additionally pointer arithmetic.
59+
60+
(Pointer arithmetic on function pointers are undefined behavior is C and is not
61+
supported by clang's wasm backend)
62+
63+
Normally this stuff is for dynamic linking, but we want to link the RTS
64+
statically, so we resolve these imports during linking. Currently we only
65+
support GOT.func imports, but implementing GOT.mem imports would be similar.
66+
Secondly, we only support GOT.func imports in the module that defines the
67+
function that we take the address of. This currently works as moc-generated code
68+
doesn't import function addresses from the RTS.
69+
70+
We resolve GOT.func imports in two steps:
71+
72+
- After loading the RTS module we generate a list of (global index, function
73+
index) pairs of GOT.func imports. In the example above, global index is N and
74+
function index is the index of f0 in the defining module (the RTS).
75+
76+
This is implemented in `collect_got_func_imports`.
77+
78+
- After merging the sections we add the functions to the table and replace
79+
`GOT.func` imports with globals to the functions' table indices.
80+
81+
Note that we don't reuse table entries when a function is already in the
82+
table, to avoid breakage when [ref-types] proposal is implemented, which will
83+
allow mutating table entries.
84+
85+
[ref-types]: https://github.com/WebAssembly/reference-types
86+
87+
This is implemented in `replace_got_func_imports`.
88+
89+
See also the test `test/ld/fun-ptr` for a concrete exaple of GOT.func generation
90+
and resolving.
91+
*)
92+
2193
(* Linking *)
2294

2395
type imports = (int32 * name) list
@@ -64,7 +136,8 @@ let remove_imports is_thing resolved : module_' -> module_' = fun m ->
64136
if List.mem_assoc i resolved
65137
then go (Int32.add i 1l) is
66138
else imp :: go (Int32.add i 1l) is
67-
else imp :: go i is in
139+
else imp :: go i is
140+
in
68141
{ m with imports = go 0l m.imports }
69142

70143
let count_imports is_thing m =
@@ -191,7 +264,7 @@ let remove_non_ic_exports (em : extended_module) : extended_module =
191264
let keep_export exp =
192265
is_ic_export exp ||
193266
match exp.it.edesc.it with
194-
| FuncExport var -> false
267+
| FuncExport _ -> false
195268
| GlobalExport _ -> false
196269
| MemoryExport _ -> true
197270
| TableExport _ -> true in
@@ -211,7 +284,7 @@ let resolve imports exports : (int32 * int32) list =
211284
| None -> []
212285
) imports)
213286

214-
let calculate_renaming n_imports1 n_things1 n_imports2 n_things2 resolved12 resolved21 : (renumbering * renumbering) =
287+
let calculate_renaming n_imports1 n_things1 n_imports2 resolved12 resolved21 : (renumbering * renumbering) =
215288
let open Int32 in
216289

217290
let n_imports1' = sub n_imports1 (Lib.List32.length resolved12) in
@@ -575,6 +648,114 @@ let align p n =
575648
let p = to_int p in
576649
shift_left (shift_right_logical (add n (sub (shift_left 1l p) 1l)) p) p
577650

651+
let find_fun_export (name : name) (exports : export list) : var option =
652+
Lib.List.first_opt (fun (export : export) ->
653+
if export.it.name = name then
654+
match export.it.edesc.it with
655+
| FuncExport var -> Some var
656+
| _ -> raise (LinkError (Format.sprintf "Export %s is not a function" (Wasm.Utf8.encode name)))
657+
else
658+
None
659+
) exports
660+
661+
let remove_got_func_imports (imports : import list) : import list =
662+
let got_func_str = Wasm.Utf8.decode "GOT.func" in
663+
List.filter (fun import -> import.it.module_name <> got_func_str) imports
664+
665+
(* Merge global list of a module with a sorted (on global index) list of (global
666+
index, global) pairs, overriding globals at those indices, and appending
667+
left-overs at the end. *)
668+
let add_globals (globals0 : global list) (insert0 : (int32 * global') list) : global list =
669+
let rec go (current_idx : int32) globals insert =
670+
match insert with
671+
| [] -> globals
672+
| (insert_idx, global) :: rest ->
673+
if current_idx = insert_idx then
674+
(global @@ no_region) :: go (Int32.add current_idx 1l) globals rest
675+
else
676+
match globals with
677+
| [] -> List.map (fun (_, global) -> global @@ no_region) insert
678+
| global :: globals -> global :: go (Int32.add current_idx 1l) globals rest
679+
in
680+
go 0l globals0 insert0
681+
682+
let mk_i32_const (i : int32) =
683+
Const (Wasm.Values.I32 i @@ no_region) @@ no_region
684+
685+
let mk_i32_global (i : int32) =
686+
{ gtype = Wasm.Types.GlobalType (Wasm.Types.I32Type, Wasm.Types.Immutable);
687+
value = [mk_i32_const i] @@ no_region }
688+
689+
(* Generate (global index, function index) pairs for GOT.func imports of a
690+
module. Uses import and export lists of the module so those should be valid. *)
691+
let collect_got_func_imports (m : module_') : (int32 * int32) list =
692+
let got_func_name = Wasm.Utf8.decode "GOT.func" in
693+
694+
let get_got_func_import (global_idx, imports) import : (int32 * (int32 * int32) list) =
695+
if import.it.module_name = got_func_name then
696+
(* Found a GOT.func import, find the exported function for it *)
697+
let name = import.it.item_name in
698+
let fun_idx =
699+
match find_fun_export name m.exports with
700+
| None -> raise (LinkError (Format.sprintf "Can't find export for GOT.func import %s" (Wasm.Utf8.encode name)))
701+
| Some export_idx -> export_idx.it
702+
in
703+
let global_idx =
704+
if is_global_import import.it.idesc.it then
705+
global_idx
706+
else
707+
raise (LinkError "GOT.func import is not global")
708+
in
709+
( Int32.add global_idx (Int32.of_int 1), (global_idx, fun_idx) :: imports )
710+
else
711+
let global_idx =
712+
if is_global_import import.it.idesc.it then
713+
Int32.add global_idx (Int32.of_int 1)
714+
else
715+
global_idx
716+
in
717+
( global_idx, imports )
718+
in
719+
720+
(* (global index, function index) list *)
721+
let (_, got_func_imports) =
722+
List.fold_left get_got_func_import (0l, []) m.imports
723+
in
724+
725+
got_func_imports
726+
727+
(* Add functions imported from GOT.func to the table, replace GOT.func imports
728+
with globals to the table indices.
729+
730+
`tbe_size` is the size of the table in the merged module before adding
731+
GOT.func functions. *)
732+
let replace_got_func_imports (tbl_size : int32) (imports : (int32 * int32) list) (m : module_') : module_' =
733+
(* null check to avoid adding empty elem section *)
734+
if Lib.List.null imports then
735+
m
736+
else
737+
let imports =
738+
List.sort (fun (gbl_idx_1, _) (gbl_idx_2, _) -> compare gbl_idx_1 gbl_idx_2) imports
739+
in
740+
741+
let elems : var list =
742+
List.map (fun (_, fun_idx) -> fun_idx @@ no_region) imports
743+
in
744+
745+
let elem_section =
746+
{ index = 0l @@ no_region; offset = [ mk_i32_const tbl_size ] @@ no_region; init = elems }
747+
in
748+
749+
let globals =
750+
List.mapi (fun idx (global_idx, _) -> (global_idx, mk_i32_global (Int32.add tbl_size (Int32.of_int idx)))) imports
751+
in
752+
753+
{ m with
754+
elems = List.append m.elems [elem_section @@ no_region];
755+
imports = remove_got_func_imports m.imports;
756+
globals = add_globals m.globals globals
757+
}
758+
578759
(* The first argument specifies the global of the first module indicating the
579760
start of free memory *)
580761
let link (em1 : extended_module) libname (em2 : extended_module) =
@@ -595,14 +776,15 @@ let link (em1 : extended_module) libname (em2 : extended_module) =
595776
let lib_heap_start = align dylink.memory_alignment old_heap_start in
596777
let new_heap_start = align 4l (Int32.add lib_heap_start dylink.memory_size) in
597778

598-
let old_elem_size = read_table_size em1.module_ in
599-
let lib_elem_start = align dylink.table_alignment old_elem_size in
600-
let new_elem_size = Int32.add lib_elem_start dylink.table_size in
779+
let old_table_size = read_table_size em1.module_ in
780+
let lib_table_start = align dylink.table_alignment old_table_size in
601781

602-
(* Fill in memory base pointer *)
782+
(* Fill in memory and table base pointers *)
603783
let dm2 = em2.module_
604784
|> fill_memory_base_import lib_heap_start
605-
|> fill_table_base_import lib_elem_start in
785+
|> fill_table_base_import lib_table_start in
786+
787+
let got_func_imports = collect_got_func_imports dm2 in
606788

607789
(* Link functions *)
608790
let fun_required1 = find_imports is_fun_import libname em1.module_ in
@@ -617,7 +799,6 @@ let link (em1 : extended_module) libname (em2 : extended_module) =
617799
(count_imports is_fun_import em1.module_)
618800
(Lib.List32.length em1.module_.funcs)
619801
(count_imports is_fun_import dm2)
620-
(Lib.List32.length dm2.funcs)
621802
fun_resolved12
622803
fun_resolved21 in
623804

@@ -626,7 +807,7 @@ let link (em1 : extended_module) libname (em2 : extended_module) =
626807

627808
(* Link globals *)
628809
let global_required1 = find_imports is_global_import libname em1.module_ in
629-
let global_required2 = find_imports is_global_import "env" dm2 in
810+
let global_required2 = find_imports is_global_import "env" dm2 in
630811
let global_exports2 = find_exports is_global_export dm2 in
631812
(* Resolve imports, to produce a renumbering globalction: *)
632813
let global_resolved12 = resolve global_required1 global_exports2 in
@@ -636,7 +817,6 @@ let link (em1 : extended_module) libname (em2 : extended_module) =
636817
(count_imports is_global_import em1.module_)
637818
(Lib.List32.length em1.module_.globals)
638819
(count_imports is_global_import dm2)
639-
(Lib.List32.length dm2.globals)
640820
global_resolved12
641821
global_resolved21 in
642822
assert (global_required1 = []); (* so far, we do not import globals *)
@@ -673,7 +853,7 @@ let link (em1 : extended_module) libname (em2 : extended_module) =
673853
in
674854

675855
(* Rename types in second module *)
676-
let dm2_tys =
856+
let dm2 =
677857
{ (rename_types (ty_renamer dm2.types) dm2) with types = [] }
678858
in
679859

@@ -692,9 +872,13 @@ let link (em1 : extended_module) libname (em2 : extended_module) =
692872
| Some fi -> prepend_to_start (funs2 fi) (add_or_get_ty (Wasm.Types.FuncType ([], [])))
693873
in
694874

695-
assert (dm2_tys.globals = []);
875+
assert (dm2.globals = []);
876+
877+
let new_table_size =
878+
Int32.add (Int32.add lib_table_start dylink.table_size) (Int32.of_int (List.length got_func_imports))
879+
in
696880

697-
join_modules
881+
let merged = join_modules
698882
( em1_tys
699883
|> map_module (fun m -> { m with types = type_indices_sorted })
700884
|> map_module (remove_imports is_fun_import fun_resolved12)
@@ -704,9 +888,9 @@ let link (em1 : extended_module) libname (em2 : extended_module) =
704888
|> rename_globals_extended globals1
705889
|> map_module (set_global heap_global new_heap_start)
706890
|> map_module (set_memory_size new_heap_start)
707-
|> map_module (set_table_size new_elem_size)
891+
|> map_module (set_table_size new_table_size)
708892
)
709-
( dm2_tys
893+
( dm2
710894
|> remove_imports is_fun_import fun_resolved21
711895
|> remove_imports is_global_import global_resolved21
712896
|> remove_imports is_memory_import [0l, 0l]
@@ -721,3 +905,16 @@ let link (em1 : extended_module) libname (em2 : extended_module) =
721905
)
722906
|> add_call_ctors
723907
|> remove_non_ic_exports (* only sane if no additional files get linked in *)
908+
in
909+
910+
(* Rename global and function indices in GOT.func stuff *)
911+
let got_func_imports =
912+
List.map (fun (global_idx, func_idx) -> (globals2 global_idx, funs2 func_idx)) got_func_imports
913+
in
914+
915+
(* Replace GOT.func imports with globals to function table indices *)
916+
let final =
917+
replace_got_func_imports (Int32.add lib_table_start dylink.table_size) got_func_imports merged.module_
918+
in
919+
920+
{ merged with module_ = final }

test/ld/Makefile

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ WASM_LD?=wasm-ld-10
1414
MO_LD?=../../src/mo-ld
1515

1616
_out/%.lib.o: %.c | _out
17-
$(WASM_CLANG) --compile -fpic --std=c11 --target=wasm32-unknown-unknown-wasm -fvisibility=hidden --optimize=3 \
17+
$(WASM_CLANG) --compile -fPIC --target=wasm32-unknown-emscripten --optimize=3 \
1818
-fno-builtin -Wall \
1919
$< --output $@
2020

@@ -30,7 +30,7 @@ _out/%.linked.wasm: _out/%.base.wasm _out/%.lib.wasm
3030
$(MO_LD) -b _out/$*.base.wasm -l _out/$*.lib.wasm -o _out/$*.linked.wasm
3131

3232
_out/%.wat: _out/%.wasm
33-
wasm2wat --no-check $< -o $@
33+
wasm2wat $< -o $@
3434

3535
_out/%.valid: _out/%.wasm
3636
wasm-validate $< > $@ 2>&1 || true

test/ld/fun-ptr.c

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
__attribute__ ((visibility("default")))
2+
int f0(int x, int y)
3+
{
4+
return x + y;
5+
}
6+
7+
__attribute__ ((visibility("default")))
8+
int (*f1(void)) (int x, int y)
9+
{
10+
return &f0;
11+
}
12+
13+
__attribute__ ((visibility("default")))
14+
void *f2()
15+
{
16+
return &f1;
17+
}

test/ld/fun-ptr.wat

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
(module
2+
(type (;0;) (func (result i32)))
3+
(type (;1;) (func (param i32) (param i32) (result i32)))
4+
(import "rts" "f2" (func $f2 (result i32)))
5+
(table (;0;) 1 1 anyfunc)
6+
(memory (;0;) 2)
7+
(global $heap_base i32 (i32.const 65536))
8+
(export "__heap_base" (global $heap_base))
9+
(func $call_imported (type 0)
10+
call $f2
11+
call_indirect (type 0)
12+
i32.const 3
13+
i32.const 5
14+
call_indirect (type 1)))

0 commit comments

Comments
 (0)