diff --git a/backends/vulkan/runtime/gen_vulkan_spv.py b/backends/vulkan/runtime/gen_vulkan_spv.py index aa32b9ab70..c9e3aaa31e 100644 --- a/backends/vulkan/runtime/gen_vulkan_spv.py +++ b/backends/vulkan/runtime/gen_vulkan_spv.py @@ -231,6 +231,7 @@ def layout_declare_tensor( var_name: str, dtype: str, storage_type: str, + is_scalar_array: bool = False, precision: str = "PRECISION", ) -> str: assert storage_type.lower() in ["buffer", "texture3d", "texture2d"] @@ -242,7 +243,12 @@ def layout_declare_tensor( # Create buffer binding if storage_type.lower() == "buffer": return layout_declare_buffer( - slot, access_type, var_name, dtype, precision, is_scalar_array=False + slot, + access_type, + var_name, + dtype, + precision, + is_scalar_array=is_scalar_array, ) # Create image/sampler binding @@ -533,7 +539,7 @@ def generateVariantCombinations( curr_suffix = ( suffix + "_" + str(i) if suffix else str(i) ) - param_values.append((param_name, curr_suffix, str(i))) + param_values.append((param_name, curr_suffix, i)) else: raise ValueError( f"{value['RANGE']} is not a valid range. Must be in format [start, end] (inclusive)." @@ -595,7 +601,7 @@ def parseTemplateYaml(self, yaml_file: str) -> None: variant_name = variant["NAME"] for param_value in combination: default_params_copy[param_value[0]] = param_value[2] - if len(param_value[1]) > 0: + if len(str(param_value[1])) > 0: variant_name = f"{variant_name}_{param_value[1]}" default_params_copy["NAME"] = variant_name