Skip to content

Commit

Permalink
Merge pull request #4 from ivg/tests-1458
Browse files Browse the repository at this point in the history
 updates to changes in BAP, improving `clz`
  • Loading branch information
DukMastaaa authored Apr 20, 2022
2 parents e5be8cb + 6011aa7 commit 73cee18
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 70 deletions.
2 changes: 1 addition & 1 deletion plugins/arm/semantics/aarch64-helper.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
called with true.
Modified from ARMv8 ISA pseudocode."
(let ((memory-width 64) ; change to 32 if 32-bit system
(len (- 64 (clz64 (concat immN (lnot imms))) 1))
(len (- 64 (clz (cast-unsigned 64 (concat immN (lnot imms)))) 1))
(levels (cast-unsigned 6 (ones len)))
(S (logand imms levels))
(R (logand immr levels))
Expand Down
2 changes: 1 addition & 1 deletion plugins/arm/semantics/arm.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@

(defun CLZ (rd rn pre _)
(when (condition-holds pre)
(set$ rd (clz32 rn))))
(set$ rd (clz rn))))
93 changes: 50 additions & 43 deletions plugins/primus_lisp/primus_lisp_semantic_primitives.ml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ open Bap_core_theory
open Bap_primus.Std
open KB.Syntax
open KB.Let
module Z = Bitvec

let export = Primus.Lisp.Type.Spec.[
"+", all any @-> any,
Expand Down Expand Up @@ -377,18 +378,22 @@ module Primitives(CT : Theory.Core)(T : Target) = struct

let nbitv = KB.List.map ~f:bitv

let join_types s xs =
let join s xs =
List.max_elt xs ~compare:(fun x y ->
let xs = sort x and ys = sort y in
Theory.Bitv.(compare_int (size xs) (size ys))) |> function
| None -> s
| Some v -> sort v

let with_nbitv s xs f = match xs with
let first s = function
| [] -> s
| x::_ -> sort x

let with_nbitv s cast xs f = match xs with
| [] -> f s []
| xs ->
nbitv xs >>= fun xs ->
f (join_types s xs) xs
f (cast s xs) xs

type 'a bitv = 'a Theory.Bitv.t Theory.Value.sort

Expand All @@ -398,15 +403,17 @@ module Primitives(CT : Theory.Core)(T : Target) = struct
| Some x -> const_int s x
| None -> CT.signed s !!x

let monoid s sf df init xs =
with_nbitv s xs @@ fun s xs -> match xs with
| [] -> forget@@const_int s init
let monoid s cast sf df init xs =
with_nbitv s cast xs @@ fun s xs ->
let m = Z.modulus (size s) in
match xs with
| [] -> forget@@const_int s Z.(init mod m)
| x :: xs ->
let* init = coerce s x in
KB.List.fold ~init xs ~f:(fun res x ->
match const res, const x with
| Some res, Some x ->
const_int s@@sf res x
const_int s Z.(sf res x mod m)
| _ ->
let* x = coerce s x in
df !!res !!x) |>
Expand Down Expand Up @@ -438,12 +445,13 @@ module Primitives(CT : Theory.Core)(T : Target) = struct

let order sf df xs = forget@@is_ordered sf df xs

let all sf df xs =
let all s cast sf df xs =
true_ >>= fun init ->
with_nbitv s cast xs @@ fun s xs ->
let m = Z.modulus (size s) in
KB.List.fold ~init xs ~f:(fun r x ->
bitv x >>= fun x ->
let r' = match const x with
| Some x -> const_bool (sf x)
| Some x -> const_bool Z.(sf x m)
| None -> df !!x in
r' >>= fun r' ->
r &&& r') |>
Expand Down Expand Up @@ -473,9 +481,16 @@ module Primitives(CT : Theory.Core)(T : Target) = struct
let stores = memory Theory.Effect.Sort.wmem
let loads = pure

let is_negative x = CT.msb x
let is_positive x =
CT.(and_ (non_zero x) (inv (is_negative x)))

let d_is_negative x = CT.msb x
let d_is_positive x =
CT.(and_ (non_zero x) (inv (d_is_negative x)))

let s_is_negative x m = Z.(msb x mod m)
let s_is_positive x m =
not (Z.(s_is_negative x m) && Z.equal x Z.zero)
let s_is_zero x _ =
Z.equal x Z.zero

let word_width s xs =
nbitv xs >>= fun xs ->
Expand Down Expand Up @@ -588,13 +603,14 @@ module Primitives(CT : Theory.Core)(T : Target) = struct
| Some v -> forget@@const_int (sort x) v

let apply_static s x =
let m = Bitvec.modulus (Theory.Bitv.size s) in
let m = Bitvec.modulus (size s) in
forget@@const_int s Bitvec.(x mod m)

let lnot x =
bitv x >>= fun x -> match const x with
| None -> forget@@CT.not !!x
| Some v -> apply_static (sort x) (Bitvec.lnot v)
| Some v ->
apply_static (sort x) (Bitvec.lnot v)

let one_op_x sop dop x =
bitv x >>= fun x -> match const x with
Expand Down Expand Up @@ -824,8 +840,8 @@ module Primitives(CT : Theory.Core)(T : Target) = struct
bitv x >>= fun x -> match const x with
| None -> forget@@cast s !!x
| Some v ->
let r = Theory.Bitv.size s in
let w = Theory.Bitv.size @@ Theory.Value.sort x in
let r = size s in
let w = size @@ Theory.Value.sort x in
forget@@const_int s@@match t with
| `hi -> Bitvec.extract ~hi:(w-1) ~lo:(w-r) v
| `lo -> Bitvec.extract ~hi:r ~lo:0 v
Expand All @@ -835,7 +851,7 @@ module Primitives(CT : Theory.Core)(T : Target) = struct
let open Bitvec.Make(struct
let modulus = Bitvec.modulus r
end) in
(ones lsl int Int.(r - w)) lor v
(ones lsl int w) lor v
else Bitvec.extract ~hi:r ~lo:0 v

let signed = mk_cast `se CT.signed
Expand Down Expand Up @@ -895,15 +911,6 @@ module Primitives(CT : Theory.Core)(T : Target) = struct
(CT.extract b1 (int s b) (int s b) !!x))

let bits = Theory.Target.bits target
module Z = struct
include Bitvec.Make(struct
let modulus = Bitvec.modulus bits
end)
let is_zero = Bitvec.equal zero
let is_negative = msb
let is_positive x =
not (is_negative x) && not (is_zero x)
end

let s = Theory.Bitv.define bits

Expand All @@ -928,23 +935,23 @@ module Primitives(CT : Theory.Core)(T : Target) = struct
let dispatch lbl name args =
let t = target in
match name,args with
| "+",_-> pure@@monoid s Z.add CT.add Z.zero args
| "+",_-> pure@@monoid s join Z.add CT.add (Z.int 0) args
| "-",[x]|"neg",[x] -> pure@@neg x
| "-",_-> pure@@monoid s Z.sub CT.sub Z.zero args
| "*",_-> pure@@monoid s Z.mul CT.mul Z.one args
| "-",_-> pure@@monoid s join Z.sub CT.sub (Z.int 0) args
| "*",_-> pure@@monoid s join Z.mul CT.mul (Z.int 1) args
| "/",[x]-> pure@@reciprocal x
| "/",_-> pure@@monoid s Z.div CT.div Z.one args
| "/",_-> pure@@monoid s join Z.div CT.div (Z.int 1) args
| "s/",[x]-> pure@@sreciprocal x
| "s/",_-> pure@@monoid s Z.sdiv CT.sdiv Z.one args
| "mod",_-> pure@@monoid s Z.rem CT.modulo Z.one args
| "s/",_-> pure@@monoid s join Z.sdiv CT.sdiv (Z.int 1) args
| "mod",_-> pure@@monoid s join Z.rem CT.modulo (Z.int 1) args
| "lnot",[x] -> pure@@lnot x
| "signed-mod",_-> pure@@monoid s Z.srem CT.smodulo Z.one args
| "lshift",_-> pure@@monoid s Z.lshift CT.lshift Z.one args
| "rshift",_-> pure@@monoid s Z.rshift CT.rshift Z.one args
| "arshift",_-> pure@@monoid s Z.arshift CT.arshift Z.one args
| "logand",_-> pure@@monoid s Z.logand CT.logand Z.ones args
| "logor",_-> pure@@monoid s Z.logor CT.logor Z.zero args
| "logxor",_-> pure@@monoid s Z.logxor CT.logxor Z.zero args
| "signed-mod",_-> pure@@monoid s join Z.srem CT.smodulo (Z.int 1) args
| "lshift",_-> pure@@monoid s first Z.lshift CT.lshift (Z.int 1) args
| "rshift",_-> pure@@monoid s first Z.rshift CT.rshift (Z.int 1) args
| "arshift",_-> pure@@monoid s first Z.arshift CT.arshift (Z.int 1) args
| "logand",_-> pure@@monoid s join Z.logand CT.logand (Z.int 1) args
| "logor",_-> pure@@monoid s join Z.logor CT.logor (Z.int 0) args
| "logxor",_-> pure@@monoid s join Z.logxor CT.logxor (Z.int 0) args
| "=",_-> pure@@order Bitvec.(=) CT.eq args
| "<",_-> pure@@order Bitvec.(<) CT.ult args
| "s<",_ -> pure@@order SBitvec.(<) CT.slt args
Expand All @@ -955,9 +962,9 @@ module Primitives(CT : Theory.Core)(T : Target) = struct
| "s<=",_-> pure@@order SBitvec.(<=) CT.ule args
| "s>=",_-> pure@@order SBitvec.(>=) CT.uge args
| "/=",_| "distinct",_-> pure@@forget@@distinct args
| "is-zero",_| "not",_-> pure@@all Bitvec.(equal zero) CT.is_zero args
| "is-positive",_-> pure@@all Z.is_positive is_positive args
| "is-negative",_-> pure@@all Z.is_negative is_negative args
| "is-zero",_| "not",_-> pure@@all s join s_is_zero CT.is_zero args
| "is-positive",_-> pure@@all s join s_is_positive d_is_positive args
| "is-negative",_-> pure@@all s join s_is_negative d_is_negative args
| "word-width",_-> pure@@word_width s args
| "exec-addr",_-> ctrl@@exec_addr args
| "goto-subinstruction",_ -> ctrl@@goto_subinstruction lbl args
Expand Down
119 changes: 94 additions & 25 deletions plugins/primus_lisp/semantics/bits.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
(defun ones (n)
"(ones n) returns a bitvector of length n with all bits set."
(lnot (zeros n)))

(defun rotate-right (bitv n)
"(rotate-right bitv n) rotates bitv to the right by n positions.
Carry-out is ignored."
Expand All @@ -50,7 +50,7 @@
(m (mod n bitv-length)))
; need to trim the result of logor.
(extract (- bitv-length 1) 0
(logor
(logor
(rshift bitv m)
(lshift bitv (- bitv-length m)))))))

Expand All @@ -62,39 +62,89 @@
(let ((bitv-length (word-width bitv))
(m (mod n bitv-length)))
(extract (- bitv-length 1) 0
(logor
(logor
(lshift bitv m)
(rshift bitv (- bitv-length m)))))))

(defmacro popcount/helper (x sh m1 m2 m4 h01)
(prog
(set x (- x (logand (rshift x 1) m1)))
(set x (+ (logand x m2) (logand (rshift x 2) m2)))
(set x (logand (+ x (rshift x 4)) m4))
(rshift (* x h01) sh)))

(defmacro popcount16 (x)
(defun clz (x)
"(clz X) counts leading zeros in X.
The returned value is the number of consecutive zeros starting
from the most significant bit. Returns 0 for 0 and works for
inputs of any size, including inputs that are not statically
known. In the latter case, the computation is unfolded into
the loopless code with the size proportional to the size of word
divided by 64."
(case (word-width x)
8 (clz8 x)
16 (clz16 x)
32 (clz32 x)
64 (clz64 x)
(if (> (word-width x) 64)
(clz/rec x)
(clz/small x))))

(defun popcount (x)
"(popcount X) computes the total number of 1 bits in X."
(if (> (word-width x) 64)
(+ (popcount64 (cast-high 64 x))
(popcount (cast-low (- (word-width x) 64) x)))
(if (= (word-width x) 64)
(popcount64 x)
(popcount64 (cast-unsigned 64 x)))))

;; private helpers

(defun popcount/helper (x sh m1 m2 m4 h01)
(declare (visibility :private))
(let ((x x))
(set x (- x (logand (rshift x 1) m1)))
(set x (+ (logand x m2) (logand (rshift x 2) m2)))
(set x (logand (+ x (rshift x 4)) m4))
(rshift (* x h01) sh)))

(defun popcount8 (x)
(declare (visibility :private))
(popcount/helper x 0
0x55:8
0x33:8
0x0f:8
0x01:8))

(defun popcount16 (x)
(declare (visibility :private))
(popcount/helper x 8
0x5555
0x3333
0x0f0f
0x0101))
0x5555:16
0x3333:16
0x0f0f:16
0x0101:16))

(defmacro popcount32 (x)
(defun popcount32 (x)
(declare (visibility :private))
(popcount/helper x 24
0x55555555
0x33333333
0x0f0f0f0f
0x01010101))
0x55555555:32
0x33333333:32
0x0f0f0f0f:32
0x01010101:32))

(defmacro popcount64 (x)
(defun popcount64 (x)
(declare (visibility :private))
(popcount/helper x 56
0x5555555555555555
0x3333333333333333
0x0f0f0f0f0f0f0f0f
0x0101010101010101))
0x5555555555555555:64
0x3333333333333333:64
0x0f0f0f0f0f0f0f0f:64
0x0101010101010101:64))

(defun clz8 (r)
(declare (visibility :private))
(let ((x r))
(set x (logor x (rshift x 1)))
(set x (logor x (rshift x 2)))
(set x (logor x (rshift x 4)))
(set x (lnot x))
(popcount8 x)))

(defun clz16 (r)
(declare (visibility :private))
(let ((x r))
(set x (logor x (rshift x 1)))
(set x (logor x (rshift x 2)))
Expand All @@ -104,6 +154,7 @@
(popcount16 x)))

(defun clz32 (x)
(declare (visibility :private))
(let ((x x))
(set x (logor x (rshift x 1)))
(set x (logor x (rshift x 2)))
Expand All @@ -114,6 +165,7 @@
(popcount32 x)))

(defun clz64 (x)
(declare (visibility :private))
(let ((x x))
(set x (logor x (rshift x 1)))
(set x (logor x (rshift x 2)))
Expand All @@ -123,3 +175,20 @@
(set x (logor x (rshift x 32)))
(set x (lnot x))
(popcount64 x)))

(defun clz/rec (x)
(declare (visibility :private))
(if (> (word-width x) 64)
(if (is-zero (cast-high 64 x))
(+ 64 (clz (cast-low (- (word-width x) 64) x)))
(clz64 (cast-high 64 x)))
(clz x)))

(defun clz/small (x)
(declare (visibility :private))
(let ((w (word-width x)))
(if (< w 8) (- (clz8 (cast-unsigned 8 x)) (- 8 w))
(if (< w 16) (- (clz16 (cast-unsigned 16 x)) (- 16 w))
(if (< w 32) (- (clz32 (cast-unsigned 32 x)) (- 32 w))
(if (< w 64) (- (clz64 (cast-unsigned 64 x)) (- 64 w))
(clz x)))))))

0 comments on commit 73cee18

Please sign in to comment.