Skip to content

Commit

Permalink
First concretize "deep" hashes, then shallow ones.
Browse files Browse the repository at this point in the history
  • Loading branch information
palkeo committed Jul 13, 2019
1 parent 6504fae commit 579de85
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 5 deletions.
27 changes: 25 additions & 2 deletions pakala/claripy_sha3.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import itertools
import operator

from claripy.ast import bv
import claripy
Expand Down Expand Up @@ -49,10 +50,29 @@ def _no_sha3_symbol(ast):
return all(_no_sha3_symbol(child) for child in ast.args)


def _this_sha3_symbol(ast, symbol):
if not isinstance(ast, claripy.ast.base.Base):
return False
if ast is symbol:
return True
return any(_this_sha3_symbol(child, symbol) for child in ast.args)



def _no_sha3_symbols(constraints):
return all(_no_sha3_symbol(ast) for ast in constraints)


def _hash_depth(hashes, hash_symbol):
"""Returns how "deep" this hash symbol is, if it's inside another hash."""
depth = 0
for in1, s1 in hashes.items():
if _this_sha3_symbol(in1, hash_symbol):
assert s1 is not hash_symbol # A hash cannot contain itself.
depth = max(depth, 1 + _hash_depth(hashes, s1))
return depth


def get_claripy_solver():
# TODO: What about SolverComposite? Tried, and seems slower.
return claripy.Solver()
Expand Down Expand Up @@ -168,7 +188,10 @@ def _hash_constraints(self, extra_constraints, hashes, pairs_done=None):

assert self.solver.satisfiable(extra_constraints=extra_constraints)

for in1, s1 in hashes.items():
# We need to first concretize the hashes that are the "deepest", i.e. that
# are serving as input for other hashes.
hash_depth = {symbol: _hash_depth(hashes, symbol) for symbol in hashes.values()}
for in1, s1 in sorted(hashes.items(), key=lambda i: hash_depth[i[1]], reverse=True):
# Next line can raise UnsatError. Handled in the caller if needed.
sol1, = self.solver.eval(in1, 1, extra_constraints=extra_constraints)
extra_constraints.append(in1 == sol1)
Expand All @@ -180,7 +203,7 @@ def _hash_constraints(self, extra_constraints, hashes, pairs_done=None):
)
assert len(sol1_bytes) * 8 == in1.length
extra_constraints.append(s1 == eth_utils.crypto.keccak(sol1_bytes))
logger.debug("Added concrete constraint: %s", extra_constraints[-1])
logger.debug("Added concrete constraint on hash: %s and on input: %s", extra_constraints[-1], extra_constraints[-2])

return tuple(extra_constraints)

Expand Down
40 changes: 37 additions & 3 deletions pakala/test_claripy_sha3.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import unittest
import functools
import logging

import claripy

Expand Down Expand Up @@ -97,9 +98,6 @@ def test_solver_recursive(self):
in2 = claripy.BVS("in2", 256)

self.assertFalse(s.satisfiable(extra_constraints=[Sha3(Sha3(in1)) == 0]))
self.assertFalse(
s.satisfiable(extra_constraints=[Sha3(Sha3(in1)) == Sha3(bvv(0))])
)
self.assertTrue(
s.satisfiable(extra_constraints=[Sha3(Sha3(in1)) == Sha3(Sha3(bvv(0)))])
)
Expand Down Expand Up @@ -133,6 +131,42 @@ def test_solver_recursive(self):
s_copy = s.branch()
self.assertFalse(s_copy.satisfiable())

def test_solver_recursive_unbalanced(self):
s = get_solver()
in1 = claripy.BVS("in1", 256)
in2 = claripy.BVS("in2", 256)

self.assertFalse(
s.satisfiable(extra_constraints=[Sha3(Sha3(in1)) == Sha3(bvv(0))])
)
self.assertTrue(
s.satisfiable(extra_constraints=[Sha3(Sha3(in1)) == Sha3(in2)])
)
logging.debug('here')
self.assertTrue(
s.satisfiable(extra_constraints=[Sha3(in1) == Sha3(Sha3(in2))])
)

self.assertTrue(
s.satisfiable(extra_constraints=[Sha3(Sha3(Sha3(in1))) == Sha3(in2)])
)
self.assertTrue(
s.satisfiable(extra_constraints=[Sha3(in1) == Sha3(Sha3(Sha3(in2)))])
)

def test_solver_three_symbols(self):
s = get_solver()
in1 = claripy.BVS("in1", 256)
in2 = claripy.BVS("in2", 256)
in3 = claripy.BVS("in2", 256)

self.assertFalse(
s.satisfiable(extra_constraints=[Sha3(in1) == Sha3(Sha3(in3)) + Sha3(Sha3(Sha3(in2)))])
)
self.assertTrue(
s.satisfiable(extra_constraints=[in1 == Sha3(Sha3(in3)) + Sha3(Sha3(Sha3(in2)))])
)

def test_solver_copy(self):
s = get_solver()
in1 = claripy.BVS("in1", 256)
Expand Down

0 comments on commit 579de85

Please sign in to comment.