
Commit ee2363b: enable extern lib offload for nvptx
masahi committed Jan 29, 2021
1 parent ef032b3 commit ee2363b
Showing 3 changed files with 6 additions and 6 deletions.
8 changes: 4 additions & 4 deletions python/tvm/relay/op/strategy/cuda.py
@@ -263,7 +263,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
else:
raise RuntimeError("Unsupported conv2d layout {} for CUDA".format(layout))
# add cudnn implementation
if target.kind.name == "cuda" and "cudnn" in target.libs:
if target.kind.name in ["cuda", "nvptx"] and "cudnn" in target.libs:
if layout in ["NCHW", "NHWC"] and padding[0] == padding[2] and padding[1] == padding[3]:
strategy.add_implementation(
wrap_compute_conv2d(
@@ -705,7 +705,7 @@ def dense_strategy_cuda(attrs, inputs, out_type, target):
name="dense_tensorcore.cuda",
plevel=20,
)
if target.kind.name == "cuda" and "cublas" in target.libs:
if target.kind.name in ["cuda", "nvptx"] and "cublas" in target.libs:
strategy.add_implementation(
wrap_compute_dense(topi.cuda.dense_cublas),
wrap_topi_schedule(topi.cuda.schedule_dense_cublas),
@@ -858,7 +858,7 @@ def argsort_strategy_cuda(attrs, inputs, out_type, target):
wrap_topi_schedule(topi.cuda.schedule_argsort),
name="argsort.cuda",
)
if target.kind.name == "cuda" and get_global_func(
if target.kind.name in ["cuda", "nvptx"] and get_global_func(
"tvm.contrib.thrust.sort", allow_missing=True
):
strategy.add_implementation(
@@ -879,7 +879,7 @@ def topk_strategy_cuda(attrs, inputs, out_type, target):
wrap_topi_schedule(topi.cuda.schedule_topk),
name="topk.cuda",
)
if target.kind.name == "cuda" and get_global_func(
if target.kind.name in ["cuda", "nvptx"] and get_global_func(
"tvm.contrib.thrust.sort", allow_missing=True
):
strategy.add_implementation(
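With these strategy checks relaxed, an nvptx target can request the same external libraries as the cuda target. Below is a minimal sketch of the gate the strategies now apply; the target string (and the assumption that the usual `-libs=` syntax applies to the nvptx kind) is illustrative, not part of this commit.

```python
import tvm

# Illustrative target string (assumption): nvptx accepting -libs the same way cuda does.
target = tvm.target.Target("nvptx -libs=cudnn,cublas")

# Mirrors the condition added in conv2d_strategy_cuda / dense_strategy_cuda above.
if target.kind.name in ["cuda", "nvptx"] and "cudnn" in target.libs:
    print("conv2d may be offloaded to cuDNN")
if target.kind.name in ["cuda", "nvptx"] and "cublas" in target.libs:
    print("dense may be offloaded to cuBLAS")
```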
2 changes: 1 addition & 1 deletion python/tvm/topi/cuda/nms.py
@@ -610,7 +610,7 @@ def _get_sorted_indices(data, data_buf, score_index, score_shape):
)

target = tvm.target.Target.current()
if target and target.kind.name == "cuda" and is_thrust_available():
if target and target.kind.name in ["cuda", "nvptx"] and is_thrust_available():
sort_tensor = argsort_thrust(score_tensor, axis=1, is_ascend=False, dtype="int32")
else:
sort_tensor = argsort(score_tensor, axis=1, is_ascend=False, dtype="int32")
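The NMS kernel uses the same pattern: it queries the active target via `Target.current()` and only takes the Thrust path when the target kind matches and Thrust support is present. A rough sketch of that gate, reusing the `get_global_func` registry check from the strategy file above (the `with` block and printed messages are illustrative):

```python
import tvm

# Enter a target scope so Target.current() returns the active target (illustrative).
with tvm.target.Target("nvptx"):
    target = tvm.target.Target.current()
    # Thrust kernels are only registered when TVM is built with Thrust support.
    thrust_sort = tvm.get_global_func("tvm.contrib.thrust.sort", allow_missing=True)
    if target and target.kind.name in ["cuda", "nvptx"] and thrust_sort is not None:
        print("argsort will use the Thrust-backed kernel")
    else:
        print("argsort falls back to the TIR implementation")
```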
2 changes: 1 addition & 1 deletion python/tvm/topi/cuda/scan.py
@@ -352,7 +352,7 @@ def exclusive_scan(

def do_scan(data, output_dtype):
target = tvm.target.Target.current()
if target and target.kind.name == "cuda" and is_thrust_available():
if target and target.kind.name in ["cuda", "nvptx"] and is_thrust_available():
return scan_thrust(
data, output_dtype, exclusive=True, return_reduction=return_reduction, binop=binop
)
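Taken together, the three files let an nvptx build reach the cuDNN, cuBLAS, and Thrust paths. A rough end-to-end sketch, assuming a toy conv2d workload and the `-libs=cudnn` target string (both are illustrative, not part of this commit):

```python
import tvm
from tvm import relay

# Toy conv2d workload (assumption, for illustration only); symmetric padding
# satisfies the padding[0] == padding[2] and padding[1] == padding[3] check
# required by the cuDNN strategy.
data = relay.var("data", shape=(1, 3, 224, 224), dtype="float32")
weight = relay.var("weight", shape=(64, 3, 7, 7), dtype="float32")
out = relay.nn.conv2d(
    data, weight, strides=(2, 2), padding=(3, 3, 3, 3), channels=64, kernel_size=(7, 7)
)
mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))

# Building for nvptx with -libs=cudnn should now dispatch conv2d to cuDNN.
target = tvm.target.Target("nvptx -libs=cudnn")
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target)
```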
