Skip to content

Commit

Permalink
Allow expression of scalar tensor buffers, non string values in varia…
Browse files Browse the repository at this point in the history
…nts (#4292)

Summary:
Pull Request resolved: #4292

Some simple improvements to the SPIR-V compilation script:

1. Allow `layout_declare_tensor` to create a scalar buffer instead of always creating a vectorized buffer
2. Allow handling of non-string (i.e. int) values in shader codegen YAML configurations.

Differential Revision: D59877805
  • Loading branch information
SS-JIA authored and facebook-github-bot committed Jul 18, 2024
1 parent e5687a4 commit 1f9cce4
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions backends/vulkan/runtime/gen_vulkan_spv.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ def layout_declare_tensor(
var_name: str,
dtype: str,
storage_type: str,
is_scalar_array: bool = False,
precision: str = "PRECISION",
) -> str:
assert storage_type.lower() in ["buffer", "texture3d", "texture2d"]
Expand All @@ -242,7 +243,12 @@ def layout_declare_tensor(
# Create buffer binding
if storage_type.lower() == "buffer":
return layout_declare_buffer(
slot, access_type, var_name, dtype, precision, is_scalar_array=False
slot,
access_type,
var_name,
dtype,
precision,
is_scalar_array=is_scalar_array,
)

# Create image/sampler binding
Expand Down Expand Up @@ -533,7 +539,7 @@ def generateVariantCombinations(
curr_suffix = (
suffix + "_" + str(i) if suffix else str(i)
)
param_values.append((param_name, curr_suffix, str(i)))
param_values.append((param_name, curr_suffix, i))
else:
raise ValueError(
f"{value['RANGE']} is not a valid range. Must be in format [start, end] (inclusive)."
Expand Down Expand Up @@ -595,7 +601,7 @@ def parseTemplateYaml(self, yaml_file: str) -> None:
variant_name = variant["NAME"]
for param_value in combination:
default_params_copy[param_value[0]] = param_value[2]
if len(param_value[1]) > 0:
if len(str(param_value[1])) > 0:
variant_name = f"{variant_name}_{param_value[1]}"

default_params_copy["NAME"] = variant_name
Expand Down

0 comments on commit 1f9cce4

Please sign in to comment.