Skip to content

Commit

Permalink
[vulkan] Support rw_texture in aot add_kernel (#6789)
Browse files Browse the repository at this point in the history
  • Loading branch information
ailzhang authored Dec 2, 2022
1 parent a202762 commit fa2433d
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 3 deletions.
6 changes: 5 additions & 1 deletion python/taichi/aot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,11 @@ def produce_injected_args(kernel, symbolic_args=None):
shape=(2, ) * ndim))
else:
raise RuntimeError('')
elif isinstance(anno, (TextureType, RWTextureType)):
elif isinstance(anno, RWTextureType):
texture_shape = (2, ) * anno.num_dimensions
fmt = anno.fmt
injected_args.append(Texture(fmt, texture_shape))
elif isinstance(anno, TextureType):
if symbolic_args is None:
raise RuntimeError(
'Texture type annotation doesn\'t have enough information for aot. Please either specify the channel_format, shape and num_channels in the graph arg declaration.'
Expand Down
15 changes: 13 additions & 2 deletions python/taichi/types/texture_type.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import warnings

from taichi.lang.enums import Format
from taichi.lang.exception import TaichiCompilationError
from taichi.types.primitive_types import f16, f32, i8, i16, i32, u8, u16, u32

FORMAT2TY_CH = {
Expand Down Expand Up @@ -40,7 +43,9 @@
Format.rgba32i: (i32, 4),
Format.rgba32f: (f32, 4),
}
import warnings

# Reverse lookup by (channel_format, num_channels)
TY_CH2FORMAT = {v: k for k, v in FORMAT2TY_CH.items()}


class TextureType:
Expand Down Expand Up @@ -74,10 +79,16 @@ def __init__(self,
warnings.warn(
"Specifying num_channels and channel_format is deprecated and will be removed in v1.5.0, please specify fmt instead.",
DeprecationWarning)
if num_channels is None or channel_format is None:
raise TaichiCompilationError(
"Incomplete type info for rw_texture, please specify its fmt (ti.Format)"
)
self.num_channels = num_channels
self.channel_format = channel_format
self.fmt = TY_CH2FORMAT[(self.channel_format, self.num_channels)]
else:
self.num_channels, self.channel_format = FORMAT2TY_CH[fmt]
self.channel_format, self.num_channels = FORMAT2TY_CH[fmt]
self.fmt = fmt
self.lod = lod


Expand Down
15 changes: 15 additions & 0 deletions tests/python/test_aot.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,3 +650,18 @@ def test_module_arch_fallback():
r'AOT compilation to a different arch than the current one is not yet supported, switching'
):
m = ti.aot.Module(ti.cpu)


@test_utils.test(arch=[ti.vulkan])
def test_save_kernel_with_rwtexture():
@ti.kernel
def write(tex: ti.types.rw_texture(num_dimensions=2,
fmt=ti.Format.r32f,
lod=0)):
for i, j in tex:
tex.store(ti.Vector([i, j]), ti.Vector([1.0, 0.0, 0.0, 0.0]))

m = ti.aot.Module()
m.add_kernel(write)
with tempfile.TemporaryDirectory() as tmpdir:
m.save(tmpdir)
42 changes: 42 additions & 0 deletions tests/python/test_deprecation.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,45 @@ def ker(tex: ti.types.rw_texture(num_dimensions=2,
for i, j in ti.ndrange(n, n):
ret = ti.cast(1, ti.f32)
tex.store(ti.Vector([i, j]), ti.Vector([ret, 0.0, 0.0, 0.0]))


@test_utils.test(arch=ti.vulkan)
def test_incomplete_info_rwtexture():
n = 128

with pytest.raises(
ti.TaichiCompilationError,
match=r"Incomplete type info for rw_texture, please specify its fmt"
):

@ti.kernel
def ker(tex: ti.types.rw_texture(num_dimensions=2,
channel_format=ti.f32,
lod=0)):
for i, j in ti.ndrange(n, n):
ret = ti.cast(1, ti.f32)
tex.store(ti.Vector([i, j]), ti.Vector([ret, 0.0, 0.0, 0.0]))

with pytest.raises(
ti.TaichiCompilationError,
match=r"Incomplete type info for rw_texture, please specify its fmt"
):

@ti.kernel
def ker(tex: ti.types.rw_texture(num_dimensions=2,
num_channels=2,
lod=0)):
for i, j in ti.ndrange(n, n):
ret = ti.cast(1, ti.f32)
tex.store(ti.Vector([i, j]), ti.Vector([ret, 0.0, 0.0, 0.0]))

with pytest.raises(
ti.TaichiCompilationError,
match=r"Incomplete type info for rw_texture, please specify its fmt"
):

@ti.kernel
def ker(tex: ti.types.rw_texture(num_dimensions=2, lod=0)):
for i, j in ti.ndrange(n, n):
ret = ti.cast(1, ti.f32)
tex.store(ti.Vector([i, j]), ti.Vector([ret, 0.0, 0.0, 0.0]))

0 comments on commit fa2433d

Please sign in to comment.