Skip to content

Commit

Permalink
compiler: Generate fminf/fmaxf where necessary
Browse files Browse the repository at this point in the history
  • Loading branch information
FabioLuporini committed Dec 18, 2024
1 parent 4b2b94c commit 3da0080
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 4 deletions.
13 changes: 9 additions & 4 deletions devito/symbolics/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,20 +101,25 @@ def _print_Rational(self, expr):
def _print_math_func(self, expr, nest=False, known=None):
cls = type(expr)
name = cls.__name__
if name not in self._prec_funcs:
return super()._print_math_func(expr, nest=nest, known=known)

try:
cname = self.known_functions[name]
except KeyError:
return super()._print_math_func(expr, nest=nest, known=known)

if cname not in self._prec_funcs:
return super()._print_math_func(expr, nest=nest, known=known)

if self.single_prec(expr):
cname = '%sf' % cname

args = ', '.join((self._print(arg) for arg in expr.args))
if nest and len(expr.args) > 2:
args = ', '.join([self._print(expr.args[0]),
self._print_math_func(cls(*expr.args[1:]))])
else:
args = ', '.join([self._print(arg) for arg in expr.args])

return '%s(%s)' % (cname, args)
return f'{cname}({args})'

def _print_Pow(self, expr):
# Need to override because of issue #1627
Expand Down
25 changes: 25 additions & 0 deletions tests/test_symbolics.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,31 @@ def test_minmax():
assert np.all(f.data == 4)


@pytest.mark.parametrize('dtype,expected', [
(np.float32, ("fmaxf", "fminf")),
(np.float64, ("fmax", "fmin")),
])
def test_minmax_precision(dtype, expected):
grid = Grid(shape=(5, 5), dtype=dtype)

f = Function(name="f", grid=grid)
g = Function(name="g", grid=grid)

eqn = Eq(f, Min(g, 4.0) + Max(g, 2.0))

op = Operator(eqn)

g.data[:] = 3.0

op.apply()

# Check generated code -- ensure it's using the fp64 versions of min/max,
# that is fminf/fmaxf
assert all(i in str(op) for i in expected)

assert np.all(f.data == 6.0)


class TestRelationsWithAssumptions:

def test_multibounds_op(self):
Expand Down

0 comments on commit 3da0080

Please sign in to comment.