From fa2433d44b64d4c2969fdb3edec2f5eb2f1ba189 Mon Sep 17 00:00:00 2001 From: Ailing Date: Thu, 1 Dec 2022 22:28:59 -0800 Subject: [PATCH] [vulkan] Support rw_texture in aot add_kernel (#6789) --- python/taichi/aot/utils.py | 6 ++++- python/taichi/types/texture_type.py | 15 +++++++++-- tests/python/test_aot.py | 15 +++++++++++ tests/python/test_deprecation.py | 42 +++++++++++++++++++++++++++++ 4 files changed, 75 insertions(+), 3 deletions(-) diff --git a/python/taichi/aot/utils.py b/python/taichi/aot/utils.py index 6fe433bc1c5bd..1920b65ad5f4d 100644 --- a/python/taichi/aot/utils.py +++ b/python/taichi/aot/utils.py @@ -89,7 +89,11 @@ def produce_injected_args(kernel, symbolic_args=None): shape=(2, ) * ndim)) else: raise RuntimeError('') - elif isinstance(anno, (TextureType, RWTextureType)): + elif isinstance(anno, RWTextureType): + texture_shape = (2, ) * anno.num_dimensions + fmt = anno.fmt + injected_args.append(Texture(fmt, texture_shape)) + elif isinstance(anno, TextureType): if symbolic_args is None: raise RuntimeError( 'Texture type annotation doesn\'t have enough information for aot. Please either specify the channel_format, shape and num_channels in the graph arg declaration.' diff --git a/python/taichi/types/texture_type.py b/python/taichi/types/texture_type.py index 77d3cea5768bf..107f7c3ded0df 100644 --- a/python/taichi/types/texture_type.py +++ b/python/taichi/types/texture_type.py @@ -1,4 +1,7 @@ +import warnings + from taichi.lang.enums import Format +from taichi.lang.exception import TaichiCompilationError from taichi.types.primitive_types import f16, f32, i8, i16, i32, u8, u16, u32 FORMAT2TY_CH = { @@ -40,7 +43,9 @@ Format.rgba32i: (i32, 4), Format.rgba32f: (f32, 4), } -import warnings + +# Reverse lookup by (channel_format, num_channels) +TY_CH2FORMAT = {v: k for k, v in FORMAT2TY_CH.items()} class TextureType: @@ -74,10 +79,16 @@ def __init__(self, warnings.warn( "Specifying num_channels and channel_format is deprecated and will be removed in v1.5.0, please specify fmt instead.", DeprecationWarning) + if num_channels is None or channel_format is None: + raise TaichiCompilationError( + "Incomplete type info for rw_texture, please specify its fmt (ti.Format)" + ) self.num_channels = num_channels self.channel_format = channel_format + self.fmt = TY_CH2FORMAT[(self.channel_format, self.num_channels)] else: - self.num_channels, self.channel_format = FORMAT2TY_CH[fmt] + self.channel_format, self.num_channels = FORMAT2TY_CH[fmt] + self.fmt = fmt self.lod = lod diff --git a/tests/python/test_aot.py b/tests/python/test_aot.py index 8fcea000dda21..f25914e8ca4dd 100644 --- a/tests/python/test_aot.py +++ b/tests/python/test_aot.py @@ -650,3 +650,18 @@ def test_module_arch_fallback(): r'AOT compilation to a different arch than the current one is not yet supported, switching' ): m = ti.aot.Module(ti.cpu) + + +@test_utils.test(arch=[ti.vulkan]) +def test_save_kernel_with_rwtexture(): + @ti.kernel + def write(tex: ti.types.rw_texture(num_dimensions=2, + fmt=ti.Format.r32f, + lod=0)): + for i, j in tex: + tex.store(ti.Vector([i, j]), ti.Vector([1.0, 0.0, 0.0, 0.0])) + + m = ti.aot.Module() + m.add_kernel(write) + with tempfile.TemporaryDirectory() as tmpdir: + m.save(tmpdir) diff --git a/tests/python/test_deprecation.py b/tests/python/test_deprecation.py index 2fa79fc32c8da..5071441264116 100644 --- a/tests/python/test_deprecation.py +++ b/tests/python/test_deprecation.py @@ -130,3 +130,45 @@ def ker(tex: ti.types.rw_texture(num_dimensions=2, for i, j in ti.ndrange(n, n): ret = ti.cast(1, ti.f32) tex.store(ti.Vector([i, j]), ti.Vector([ret, 0.0, 0.0, 0.0])) + + +@test_utils.test(arch=ti.vulkan) +def test_incomplete_info_rwtexture(): + n = 128 + + with pytest.raises( + ti.TaichiCompilationError, + match=r"Incomplete type info for rw_texture, please specify its fmt" + ): + + @ti.kernel + def ker(tex: ti.types.rw_texture(num_dimensions=2, + channel_format=ti.f32, + lod=0)): + for i, j in ti.ndrange(n, n): + ret = ti.cast(1, ti.f32) + tex.store(ti.Vector([i, j]), ti.Vector([ret, 0.0, 0.0, 0.0])) + + with pytest.raises( + ti.TaichiCompilationError, + match=r"Incomplete type info for rw_texture, please specify its fmt" + ): + + @ti.kernel + def ker(tex: ti.types.rw_texture(num_dimensions=2, + num_channels=2, + lod=0)): + for i, j in ti.ndrange(n, n): + ret = ti.cast(1, ti.f32) + tex.store(ti.Vector([i, j]), ti.Vector([ret, 0.0, 0.0, 0.0])) + + with pytest.raises( + ti.TaichiCompilationError, + match=r"Incomplete type info for rw_texture, please specify its fmt" + ): + + @ti.kernel + def ker(tex: ti.types.rw_texture(num_dimensions=2, lod=0)): + for i, j in ti.ndrange(n, n): + ret = ti.cast(1, ti.f32) + tex.store(ti.Vector([i, j]), ti.Vector([ret, 0.0, 0.0, 0.0]))