Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support GOT.func imports in mo-ld #1811

Merged
merged 2 commits into from
Aug 18, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
231 changes: 214 additions & 17 deletions src/linking/linkModule.ml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,78 @@ through further refactoring before we are happy with it. Things to do:
of functions for each syntactic category.
*)

(*
Resolving GOT.func and GOT.mem imports
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

GOT.func and GOT.mem imports arise from function and data pointers,
respectively, in languages with pointers (e.g. C and Rust). The idea is that if
a shared library exposes a function/data pointer the entire process should use
the same pointer for the function/data so that the pointer arithmetic and
comparisons will work. For example, this C code:

__attribute__ ((visibility("default")))
int f0(int x, int y)
{
return x + y;
}

__attribute__ ((visibility("default")))
int (\*f1(void)) (int x, int y)
{
return &f0;
}

generates this GOT.func import:

(import "GOT.func" "f0" (global (;N;) (mut i32)))

The host is responsible for allocating a table index for this function and
resolving the import to the table index for `f0` so that this code in the
importing module would work:

assert(f1() == f0);

Note that the definition of `f1` is in the *imported* module and this assertion
is in the *importing* module.

Similarly exposing a data pointer generates a GOT.mem import. All GOT.mem
imports to a symbol should resolve to the same constant to support equality as
above, and additionally pointer arithmetic.

(Pointer arithmetic on function pointers is undefined behavior in C and is not
supported by clang's wasm backend)

Normally this stuff is for dynamic linking, but we want to link the RTS
statically, so we resolve these imports during linking. Currently we only
support GOT.func imports, but implementing GOT.mem imports would be similar.
Secondly, we only support GOT.func imports in the module that defines the
function that we take the address of. This currently works as moc-generated code
doesn't import function addresses from the RTS.

We resolve GOT.func imports in two steps:

- After loading the RTS module we generate a list of (global index, function
index) pairs of GOT.func imports. In the example above, global index is N and
function index is the index of f0 in the defining module (the RTS).

This is implemented in `collect_got_func_imports`.

- After merging the sections we add the functions to the table and replace
`GOT.func` imports with globals to the functions' table indices.

Note that we don't reuse table entries when a function is already in the
table, to avoid breakage when [ref-types] proposal is implemented, which will
allow mutating table entries.

[ref-types]: https://github.com/WebAssembly/reference-types

This is implemented in `replace_got_func_imports`.

See also the test `test/ld/fun-ptr` for a concrete example of GOT.func generation
and resolving.
*)

(* Linking *)

(* Association list pairing an int32 index with a name.
   NOTE(review): presumably the index is an import's position within its index
   space and the name is the imported item's name -- confirm against the
   functions that build values of this type. *)
type imports = (int32 * name) list
Expand Down Expand Up @@ -64,7 +136,8 @@ let remove_imports is_thing resolved : module_' -> module_' = fun m ->
if List.mem_assoc i resolved
then go (Int32.add i 1l) is
else imp :: go (Int32.add i 1l) is
else imp :: go i is in
else imp :: go i is
in
{ m with imports = go 0l m.imports }

let count_imports is_thing m =
Expand Down Expand Up @@ -191,9 +264,9 @@ let remove_non_ic_exports (em : extended_module) : extended_module =
let keep_export exp =
is_ic_export exp ||
match exp.it.edesc.it with
| FuncExport var -> false
| FuncExport _
| GlobalExport _ -> false
| MemoryExport _ -> true
| MemoryExport _
| TableExport _ -> true in

map_module (fun m -> { m with exports = List.filter keep_export m.exports }) em
Expand All @@ -211,7 +284,7 @@ let resolve imports exports : (int32 * int32) list =
| None -> []
) imports)

let calculate_renaming n_imports1 n_things1 n_imports2 n_things2 resolved12 resolved21 : (renumbering * renumbering) =
let calculate_renaming n_imports1 n_things1 n_imports2 resolved12 resolved21 : (renumbering * renumbering) =
let open Int32 in

let n_imports1' = sub n_imports1 (Lib.List32.length resolved12) in
Expand Down Expand Up @@ -575,6 +648,114 @@ let align p n =
let p = to_int p in
shift_left (shift_right_logical (add n (sub (shift_left 1l p) 1l)) p) p

(* Look up the export called [name] in [exports] and return the function it
   refers to.

   Exports with a different name are skipped. An export with a matching name
   must be a function export; anything else raises [LinkError]. The search
   itself is delegated to [Lib.List.first_opt] (presumably: the first [Some]
   produced by the callback, or [None] when there is none -- confirm in Lib). *)
let find_fun_export (name : name) (exports : export list) : var option =
  let as_function (candidate : export) =
    if candidate.it.name <> name then
      None
    else
      begin match candidate.it.edesc.it with
      | FuncExport fn -> Some fn
      | _ -> raise (LinkError (Format.sprintf "Export %s is not a function" (Wasm.Utf8.encode name)))
      end
  in
  Lib.List.first_opt as_function exports

(* Drop every import whose module name is "GOT.func"; all other imports are
   kept in their original order. *)
let remove_got_func_imports (imports : import list) : import list =
  let got_func_module = Wasm.Utf8.decode "GOT.func" in
  let is_not_got_func (imp : import) = imp.it.module_name <> got_func_module in
  List.filter is_not_got_func imports

(* Merge the global list of a module with a list of (global index, global)
   pairs that is sorted on global index.

   Walking the result positions from 0 upwards: when the current position
   equals the index of the next pending pair, that pair's global is emitted at
   this position (existing globals are not consumed, so they shift up);
   otherwise the next existing global is emitted and the pair stays pending.
   Once the existing globals run out, all still-pending pairs are appended at
   the end in order, regardless of their index. No pair is ever dropped.

   Example: globals [g0; g1] with inserts [(1, x)] yields [g0; x; g1];
   globals [g0] with inserts [(5, x)] yields [g0; x]. *)
let add_globals (globals0 : global list) (insert0 : (int32 * global') list) : global list =
  let rec go (current_idx : int32) globals insert =
    match insert with
    | [] -> globals
    | (insert_idx, global) :: rest ->
      if current_idx = insert_idx then
        (global @@ no_region) :: go (Int32.add current_idx 1l) globals rest
      else
        begin match globals with
        | [] -> List.map (fun (_, global) -> global @@ no_region) insert
        | existing :: globals ->
          (* The head of [insert] has not been placed yet: carry the whole
             [insert] list forward while emitting the existing global. *)
          existing :: go (Int32.add current_idx 1l) globals insert
        end
  in
  go 0l globals0 insert0

(* Build an i32 [Const] instruction carrying [i]; both the literal and the
   instruction are tagged with [no_region]. *)
let mk_i32_const (i : int32) =
  let literal = Wasm.Values.I32 i @@ no_region in
  Const literal @@ no_region

(* Build an immutable i32 global whose initialiser is the single instruction
   [mk_i32_const i]. The result is not wrapped in a region. *)
let mk_i32_global (i : int32) =
  let gtype = Wasm.Types.GlobalType (Wasm.Types.I32Type, Wasm.Types.Immutable) in
  let value = [mk_i32_const i] @@ no_region in
  { gtype; value }

(* Generate (global index, function index) pairs for GOT.func imports of a
module. Uses import and export lists of the module so those should be valid. *)
let collect_got_func_imports (m : module_') : (int32 * int32) list =
let got_func_name = Wasm.Utf8.decode "GOT.func" in

let get_got_func_import (global_idx, imports) import : (int32 * (int32 * int32) list) =
if import.it.module_name = got_func_name then
(* Found a GOT.func import, find the exported function for it *)
let name = import.it.item_name in
let fun_idx =
match find_fun_export name m.exports with
| None -> raise (LinkError (Format.sprintf "Can't find export for GOT.func import %s" (Wasm.Utf8.encode name)))
| Some export_idx -> export_idx.it
in
let global_idx =
if is_global_import import.it.idesc.it then
global_idx
else
raise (LinkError "GOT.func import is not global")
in
( Int32.add global_idx (Int32.of_int 1), (global_idx, fun_idx) :: imports )
else
let global_idx =
if is_global_import import.it.idesc.it then
Int32.add global_idx (Int32.of_int 1)
else
global_idx
in
( global_idx, imports )
Comment on lines +711 to +717
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
let global_idx =
if is_global_import import.it.idesc.it then
Int32.add global_idx (Int32.of_int 1)
else
global_idx
in
( global_idx, imports )
if is_global_import import.it.idesc.it
then (Int32.add global_idx (Int32.of_int 1), imports)
else (global_idx, imports)

seems simpler, at the expense of symmetry with the previous. Up to you.

in

(* (global index, function index) list *)
let (_, got_func_imports) =
List.fold_left get_got_func_import (0l, []) m.imports
in

got_func_imports

(* Add the functions referenced by GOT.func imports to the table and replace
   the GOT.func imports with globals holding the functions' table indices.

   `tbl_size` is the size of the table in the merged module before adding the
   GOT.func functions. `imports` is a list of (global index, function index)
   pairs. The pairs are sorted on global index; the i-th pair (0-based, after
   sorting) gets table slot `tbl_size + i`. One element segment, at offset
   `tbl_size`, is appended to the module for all the functions.

   When `imports` is empty the module is returned unchanged, so that no empty
   element segment is added. *)
let replace_got_func_imports (tbl_size : int32) (imports : (int32 * int32) list) (m : module_') : module_' =
  match imports with
  | [] -> m
  | _ :: _ ->
    let sorted =
      List.sort (fun (g1, _) (g2, _) -> compare g1 g2) imports
    in

    (* Functions to place in the table, in slot order *)
    let table_funs : var list =
      List.map (fun (_, fun_idx) -> fun_idx @@ no_region) sorted
    in

    let segment =
      { index = 0l @@ no_region;
        offset = [ mk_i32_const tbl_size ] @@ no_region;
        init = table_funs }
    in

    (* (global index, global holding the table slot) pairs *)
    let slot_globals =
      List.mapi
        (fun slot (global_idx, _) ->
          let table_idx = Int32.add tbl_size (Int32.of_int slot) in
          (global_idx, mk_i32_global table_idx))
        sorted
    in

    { m with
      elems = List.append m.elems [segment @@ no_region];
      imports = remove_got_func_imports m.imports;
      globals = add_globals m.globals slot_globals
    }

(* The first argument specifies the global of the first module indicating the
start of free memory *)
let link (em1 : extended_module) libname (em2 : extended_module) =
Expand All @@ -595,14 +776,15 @@ let link (em1 : extended_module) libname (em2 : extended_module) =
let lib_heap_start = align dylink.memory_alignment old_heap_start in
let new_heap_start = align 4l (Int32.add lib_heap_start dylink.memory_size) in

let old_elem_size = read_table_size em1.module_ in
let lib_elem_start = align dylink.table_alignment old_elem_size in
let new_elem_size = Int32.add lib_elem_start dylink.table_size in
let old_table_size = read_table_size em1.module_ in
let lib_table_start = align dylink.table_alignment old_table_size in

(* Fill in memory base pointer *)
(* Fill in memory and table base pointers *)
let dm2 = em2.module_
|> fill_memory_base_import lib_heap_start
|> fill_table_base_import lib_elem_start in
|> fill_table_base_import lib_table_start in

let got_func_imports = collect_got_func_imports dm2 in

(* Link functions *)
let fun_required1 = find_imports is_fun_import libname em1.module_ in
Expand All @@ -617,7 +799,6 @@ let link (em1 : extended_module) libname (em2 : extended_module) =
(count_imports is_fun_import em1.module_)
(Lib.List32.length em1.module_.funcs)
(count_imports is_fun_import dm2)
(Lib.List32.length dm2.funcs)
fun_resolved12
fun_resolved21 in

Expand All @@ -626,7 +807,7 @@ let link (em1 : extended_module) libname (em2 : extended_module) =

(* Link globals *)
let global_required1 = find_imports is_global_import libname em1.module_ in
let global_required2 = find_imports is_global_import "env" dm2 in
let global_required2 = find_imports is_global_import "env" dm2 in
let global_exports2 = find_exports is_global_export dm2 in
(* Resolve imports, to produce a renumbering function: *)
let global_resolved12 = resolve global_required1 global_exports2 in
Expand All @@ -636,7 +817,6 @@ let link (em1 : extended_module) libname (em2 : extended_module) =
(count_imports is_global_import em1.module_)
(Lib.List32.length em1.module_.globals)
(count_imports is_global_import dm2)
(Lib.List32.length dm2.globals)
global_resolved12
global_resolved21 in
assert (global_required1 = []); (* so far, we do not import globals *)
Expand Down Expand Up @@ -673,7 +853,7 @@ let link (em1 : extended_module) libname (em2 : extended_module) =
in

(* Rename types in second module *)
let dm2_tys =
let dm2 =
{ (rename_types (ty_renamer dm2.types) dm2) with types = [] }
in

Expand All @@ -692,9 +872,13 @@ let link (em1 : extended_module) libname (em2 : extended_module) =
| Some fi -> prepend_to_start (funs2 fi) (add_or_get_ty (Wasm.Types.FuncType ([], [])))
in

assert (dm2_tys.globals = []);
assert (dm2.globals = []);

let new_table_size =
Int32.add (Int32.add lib_table_start dylink.table_size) (Int32.of_int (List.length got_func_imports))
in

join_modules
let merged = join_modules
( em1_tys
|> map_module (fun m -> { m with types = type_indices_sorted })
|> map_module (remove_imports is_fun_import fun_resolved12)
Expand All @@ -704,9 +888,9 @@ let link (em1 : extended_module) libname (em2 : extended_module) =
|> rename_globals_extended globals1
|> map_module (set_global heap_global new_heap_start)
|> map_module (set_memory_size new_heap_start)
|> map_module (set_table_size new_elem_size)
|> map_module (set_table_size new_table_size)
)
( dm2_tys
( dm2
|> remove_imports is_fun_import fun_resolved21
|> remove_imports is_global_import global_resolved21
|> remove_imports is_memory_import [0l, 0l]
Expand All @@ -721,3 +905,16 @@ let link (em1 : extended_module) libname (em2 : extended_module) =
)
|> add_call_ctors
|> remove_non_ic_exports (* only sane if no additional files get linked in *)
in

(* Rename global and function indices in GOT.func stuff *)
let got_func_imports =
List.map (fun (global_idx, func_idx) -> (globals2 global_idx, funs2 func_idx)) got_func_imports
in

(* Replace GOT.func imports with globals to function table indices *)
let final =
replace_got_func_imports (Int32.add lib_table_start dylink.table_size) got_func_imports merged.module_
in

{ merged with module_ = final }
4 changes: 2 additions & 2 deletions test/ld/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ WASM_LD?=wasm-ld-10
MO_LD?=../../src/mo-ld

_out/%.lib.o: %.c | _out
$(WASM_CLANG) --compile -fpic --std=c11 --target=wasm32-unknown-unknown-wasm -fvisibility=hidden --optimize=3 \
$(WASM_CLANG) --compile -fPIC --target=wasm32-emscripten --optimize=3 \
-fno-builtin -Wall \
$< --output $@

Expand All @@ -30,7 +30,7 @@ _out/%.linked.wasm: _out/%.base.wasm _out/%.lib.wasm
$(MO_LD) -b _out/$*.base.wasm -l _out/$*.lib.wasm -o _out/$*.linked.wasm

_out/%.wat: _out/%.wasm
wasm2wat --no-check $< -o $@
wasm2wat $< -o $@

_out/%.valid: _out/%.wasm
wasm-validate $< > $@ 2>&1 || true
Expand Down
17 changes: 17 additions & 0 deletions test/ld/fun-ptr.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
/* Returns the sum of x and y. Marked with default visibility; f1 below takes
   this function's address. */
__attribute__ ((visibility("default")))
int f0(int x, int y)
{
return x + y;
}

/* Returns a pointer to f0 (a function taking two ints and returning int).
   Per the notes in src/linking/linkModule.ml, taking f0's address like this is
   what makes the compiler emit a GOT.func import for f0. */
__attribute__ ((visibility("default")))
int (*f1(void)) (int x, int y)
{
return &f0;
}

/* Returns the address of f1 as a void pointer. fun-ptr.wat imports this
   function under the module name "rts". */
__attribute__ ((visibility("default")))
void *f2()
{
return &f1;
}
14 changes: 14 additions & 0 deletions test/ld/fun-ptr.wat
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
;; Base module for the fun-ptr linking test: imports f2 from "rts" and calls
;; through the function pointers it hands back.
(module
(type (;0;) (func (result i32)))
(type (;1;) (func (param i32) (param i32) (result i32)))
(import "rts" "f2" (func $f2 (result i32)))
(table (;0;) 1 1 anyfunc)
(memory (;0;) 2)
(global $heap_base i32 (i32.const 65536))
(export "__heap_base" (global $heap_base))
;; Calls f2, then indirectly calls (with type 0) the table entry whose index
;; f2 returned, then pushes 3 and 5 and performs an indirect call with type 1.
;; NOTE(review): call_indirect pops its table index from the top of the stack,
;; so as written the last call uses 5 as the table index and passes
;; (result of the previous call, 3) as arguments -- confirm whether calling
;; the returned function with (3, 5) was intended.
(func $call_imported (type 0)
call $f2
call_indirect (type 0)
i32.const 3
i32.const 5
call_indirect (type 1)))
Loading