Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Got rid of Obj.magic use in memory module #84

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 41 additions & 38 deletions ml-proto/src/spec/memory.ml
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,6 @@ type memory' = (int, int8_unsigned_elt, c_layout) Array1.t
type memory = memory' ref
type t = memory

type char_view = (char, int8_unsigned_elt, c_layout) Array1.t
type sint8_view = (int, int8_signed_elt, c_layout) Array1.t
type sint16_view = (int, int16_signed_elt, c_layout) Array1.t
type sint32_view = (int32, int32_elt, c_layout) Array1.t
type sint64_view = (int64, int64_elt, c_layout) Array1.t
type uint8_view = (int, int8_unsigned_elt, c_layout) Array1.t
type uint16_view = (int, int16_unsigned_elt, c_layout) Array1.t
type uint32_view = (int32, int32_elt, c_layout) Array1.t
type uint64_view = (int64, int64_elt, c_layout) Array1.t
type float32_view = (int32, int32_elt, c_layout) Array1.t
type float64_view = (int64, int64_elt, c_layout) Array1.t

let view : memory' -> ('c, 'd, c_layout) Array1.t = Obj.magic


(* Queries *)

let mem_size = function
Expand All @@ -65,7 +50,7 @@ let create n =
let init_seg mem seg =
(* There currently is no way to blit from a string. *)
for i = 0 to String.length seg.data - 1 do
(view !mem : char_view).{seg.addr + i} <- seg.data.[i]
!mem.{seg.addr + i} <- Char.code seg.data.[i]
done

let init mem segs =
Expand All @@ -91,38 +76,56 @@ let address_of_value = function

(* Load and store *)

let int32_mask = Int64.shift_right_logical (Int64.of_int (-1)) 32
let int64_of_int32_u i = Int64.logand (Int64.of_int32 i) int32_mask
let load8 mem a ext =
(match ext with
| SX -> Int32.shift_right (Int32.shift_left (Int32.of_int !mem.{a}) 24) 24
| _ -> Int32.of_int !mem.{a})

let load16 mem a ext =
Int32.logor (load8 mem a NX) (Int32.shift_left (load8 mem (a+1) ext) 8)

let load32 mem a =
Int32.logor (load16 mem a NX) (Int32.shift_left (load16 mem (a+2) NX) 16)

let load64 mem a =
Int64.logor (Int64.of_int32 (load32 mem a)) (Int64.shift_left (Int64.of_int32 (load32 mem (a+4))) 32)

let store8 mem a bits =
!mem.{a} <- Int32.to_int (Int32.logand bits (Int32.of_int 255))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this &255 necessary? I guess it doesn't hurt, but I was initially confused about what this code is doing and wasn't clear whether it's doing something that needs the high bits to be cleared.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AFAIK it's not necessary. I think there's some value in making it clear there's a truncation happening there, but I'm happy to remove it if you want.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The syntax was indeed unfamiliar to me at first, so I do appreciate clarifying it. I'd suggest a comment, something like "store the least significant byte of bits at memory index a". With the &255, I was wondering whether it was perhaps doing a bitwise or of bits with memory such that it actually needed the high bits to be zero or something.


let store16 mem a bits =
store8 mem (a+0) bits;
store8 mem (a+1) (Int32.shift_right_logical bits 8)

let store32 mem a bits =
store16 mem (a+0) bits;
store16 mem (a+2) (Int32.shift_right_logical bits 16)

let buf = create' 8
let store64 mem a bits =
store32 mem (a+0) (Int64.to_int32 bits);
store32 mem (a+4) (Int64.to_int32 (Int64.shift_right_logical bits 32))

let load mem a memty ext =
let sz = mem_size memty in
let open Types in
try
Array1.blit (Array1.sub !mem a sz) (Array1.sub buf 0 sz);
match memty, ext with
| Int8Mem, SX -> Int32 (Int32.of_int (view buf : sint8_view).{0})
| Int8Mem, ZX -> Int32 (Int32.of_int (view buf : uint8_view).{0})
| Int16Mem, SX -> Int32 (Int32.of_int (view buf : sint16_view).{0})
| Int16Mem, ZX -> Int32 (Int32.of_int (view buf : uint16_view).{0})
| Int32Mem, NX -> Int32 (view buf : sint32_view).{0}
| Int64Mem, NX -> Int64 (view buf : sint64_view).{0}
| Float32Mem, NX -> Float32 (Float32.of_bits (view buf : float32_view).{0})
| Float64Mem, NX -> Float64 (Float64.of_bits (view buf : float64_view).{0})
| Int8Mem, _ -> Int32 (load8 mem a ext)
| Int16Mem, _ -> Int32 (load16 mem a ext)
| Int32Mem, NX -> Int32 (load32 mem a)
| Int64Mem, NX -> Int64 (load64 mem a)
| Float32Mem, NX -> Float32 (Float32.of_bits (load32 mem a))
| Float64Mem, NX -> Float64 (Float64.of_bits (load64 mem a))
| _ -> raise Type
with Invalid_argument _ -> raise Bounds

let store mem a memty v =
let sz = mem_size memty in
try
(match memty, v with
| Int8Mem, Int32 x -> (view buf : sint8_view).{0} <- Int32.to_int x
| Int16Mem, Int32 x -> (view buf : sint16_view).{0} <- Int32.to_int x
| Int32Mem, Int32 x -> (view buf : sint32_view).{0} <- x
| Int64Mem, Int64 x -> (view buf : sint64_view).{0} <- x
| Float32Mem, Float32 x -> (view buf : float32_view).{0} <- Float32.to_bits x
| Float64Mem, Float64 x -> (view buf : float64_view).{0} <- Float64.to_bits x
| _ -> raise Type);
Array1.blit (Array1.sub buf 0 sz) (Array1.sub !mem a sz)
| Int8Mem, Int32 x -> store8 mem a x
| Int16Mem, Int32 x -> store16 mem a x
| Int32Mem, Int32 x -> store32 mem a x
| Int64Mem, Int64 x -> store64 mem a x
| Float32Mem, Float32 x -> store32 mem a (Float32.to_bits x)
| Float64Mem, Float64 x -> store64 mem a (Float64.to_bits x)
| _ -> raise Type)
with Invalid_argument _ -> raise Bounds
32 changes: 32 additions & 0 deletions ml-proto/test/memory.wasm
Original file line number Diff line number Diff line change
Expand Up @@ -139,13 +139,45 @@
(return (f64.load/1 (i32.const 9)))
)

;; Sign and zero extending memory loads
(func $load8_s (param $i i32) (result i32)
(i32.store8 (i32.const 8) (get_local $i))
(return (i32.load8_s (i32.const 8)))
)
(func $load8_u (param $i i32) (result i32)
(i32.store8 (i32.const 8) (get_local $i))
(return (i32.load8_u (i32.const 8)))
)
(func $load16_s (param $i i32) (result i32)
(i32.store16 (i32.const 8) (get_local $i))
(return (i32.load16_s (i32.const 8)))
)
(func $load16_u (param $i i32) (result i32)
(i32.store16 (i32.const 8) (get_local $i))
(return (i32.load16_u (i32.const 8)))
)

(export "data" $data)
(export "aligned" $aligned)
(export "unaligned" $unaligned)
(export "cast" $cast)
(export "load8_s" $load8_s)
(export "load8_u" $load8_u)
(export "load16_s" $load16_s)
(export "load16_u" $load16_u)
)

(assert_eq (invoke "data") (i32.const 1))
(assert_eq (invoke "aligned") (i32.const 1))
(assert_eq (invoke "unaligned") (i32.const 1))
(assert_eq (invoke "cast") (f64.const 42.0))

(assert_eq (invoke "load8_s" (i32.const -1)) (i32.const -1))
(assert_eq (invoke "load8_u" (i32.const -1)) (i32.const 255))
(assert_eq (invoke "load16_s" (i32.const -1)) (i32.const -1))
(assert_eq (invoke "load16_u" (i32.const -1)) (i32.const 65535))

(assert_eq (invoke "load8_s" (i32.const 100)) (i32.const 100))
(assert_eq (invoke "load8_u" (i32.const 200)) (i32.const 200))
(assert_eq (invoke "load16_s" (i32.const 20000)) (i32.const 20000))
(assert_eq (invoke "load16_u" (i32.const 40000)) (i32.const 40000))