Skip to content

Commit

Permalink
[TOPI][SPIRV] Cast to float32 not float64 before log2 in sort/scan (#…
Browse files Browse the repository at this point in the history
…7669)

* [TOPI] Cast to float32 before log2 in sort/scan

* revert sort change since this seems unnecessary

* only does cast to float32 on vk + dynamic input case

* check against IntImm instead of Var

* revert change

* use clz for ceil_log2 when compiling for vk

* add doc on ceil_log2

* fix pylint

Co-authored-by: Masahiro Masuda <masahi@129@gmail.com>
  • Loading branch information
masahi and Masahiro Masuda authored Apr 17, 2021
1 parent 899bc06 commit e082ef5
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 11 deletions.
7 changes: 3 additions & 4 deletions python/tvm/topi/cuda/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from tvm.contrib.thrust import can_use_rocthrust, can_use_thrust

from .. import tag
from ..math import cast
from ..math import cast, ceil_log2
from ..transform import expand_dims, reshape, squeeze, transpose
from ..utils import ceil_div, get_const_int, prod, swap
from .injective import schedule_injective_from_existing
Expand Down Expand Up @@ -103,9 +103,8 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i

# The following algorithm performs parallel exclusive scan
# Up Sweep of exclusive scan
lim = tvm.tir.generic.cast(
tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(scan_axis_size, "float64"))), "int64"
)
lim = ceil_log2(scan_axis_size)

with ib.for_range(0, lim, dtype="int64") as l2_width:
width = 2 << l2_width

Expand Down
10 changes: 3 additions & 7 deletions python/tvm/topi/cuda/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from ..transform import strided_slice, transpose
from .. import tag
from ..utils import ceil_div, swap
from ..math import cast
from ..math import cast, ceil_log2


def _schedule_sort(outs):
Expand Down Expand Up @@ -238,9 +238,7 @@ def compare(a, b):
return out

# Sort the lower levels of the merge using odd-even sort, it's fast for small inputs
lower_lim = tvm.tir.generic.cast(
tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(block_size, "float64"))), "int64"
)
lower_lim = ceil_log2(block_size)

_odd_even_sort(
ib,
Expand All @@ -254,9 +252,7 @@ def compare(a, b):
values_swap,
)

upper_lim = tvm.tir.generic.cast(
tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(size, "float64"))), "int64"
)
upper_lim = ceil_log2(size)

def get_merge_begin(source, base_idx, aCount, bCount, aStart, bStart, diag, step_count):
first = ib.allocate("int64", (1,), name="first", scope="local")
Expand Down
34 changes: 34 additions & 0 deletions python/tvm/topi/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,3 +742,37 @@ def fast_erf(x):
The result.
"""
return cpp.fast_erf(x, x.dtype, tag.ELEMWISE)


def ceil_log2(x):
"""Compute integer ceil log2 with a special code path for vulkan
SPIR-V does not support log2 on fp64. Instead, we compute integer ceil_log2 via clz
intrinsic when the target is vulkan.
Parameters
----------
x : tvm.te.Tensor
Input argument.
Returns
-------
y : tvm.te.Tensor
The result.
"""
if not isinstance(x, tvm.tir.PrimExpr):
x = tvm.tir.const(x)

if "float" in x.dtype:
return tvm.tir.ceil(tvm.tir.log2(x))

if "vulkan" in tvm.target.Target.current().kind.name:
clz = tvm.tir.clz(x)
bits = int(x.dtype[-2:])
res = tvm.tir.if_then_else(x & (x - 1) == 0, bits - clz - 1, bits - clz)

if res.dtype != x.dtype:
return cast(res, x.dtype)

return res

return cast(tvm.tir.ceil(tvm.tir.log2(cast(x, "float64"))), x.dtype)
10 changes: 10 additions & 0 deletions tests/python/unittest/test_target_codegen_spirv.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,16 @@ def test_pushconstants():

check_mod(mod, x_np, res_np)

# One 64 bit and one 32 bit constants
dtype = "int32"
x = relay.var("x", shape=(relay.Any(),), dtype=dtype)
mod = tvm.IRModule()
mod["main"] = relay.Function([x], relay.cumsum(x))
x_np = np.random.randint(0, high=10, size=(10,)).astype(dtype)
res_np = np.cumsum(x_np)

check_mod(mod, x_np, res_np)


def test_unique():
if not tvm.testing.device_enabled("vulkan"):
Expand Down

0 comments on commit e082ef5

Please sign in to comment.