Skip to content

Commit

Permalink
Support serializing scalar tensors as SymInt values (#6070)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #6070

## Context

* Add `SymInt` to serialization schema
* Make the serializer serialize scalar tensors as `SymInt` instead of `VkTensor`
* Add support for `SymInt` in `VulkanBackend.cpp`
ghstack-source-id: 247163958
exported-using-ghexport

Reviewed By: jorgep31415

Differential Revision: D64139868

fbshipit-source-id: 44d225ca6c63b311e4839783787713a38b8b6017
  • Loading branch information
SS-JIA authored and facebook-github-bot committed Oct 10, 2024
1 parent 1a0c2c7 commit 4b6a033
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 0 deletions.
9 changes: 9 additions & 0 deletions backends/vulkan/runtime/VulkanBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,12 @@ class GraphBuilder {
ref_mapping_[fb_id] = ref;
}

void add_symint_to_graph(const uint32_t fb_id, VkValuePtr value) {
const int32_t fb_symint = value->value_as_SymInt()->value();
ValueRef ref = compute_graph_->add_symint(fb_symint);
ref_mapping_[fb_id] = ref;
}

void add_value_to_graph(const uint32_t fb_id, VkValuePtr value) {
ET_CHECK_MSG(
!fb_id_exists(fb_id),
Expand Down Expand Up @@ -300,6 +306,9 @@ class GraphBuilder {
case vkgraph::GraphTypes::String:
add_string_to_graph(fb_id, value);
break;
case vkgraph::GraphTypes::SymInt:
add_symint_to_graph(fb_id, value);
break;
default:
ET_CHECK_MSG(false, "Unsupported value type.");
}
Expand Down
5 changes: 5 additions & 0 deletions backends/vulkan/serialization/schema.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ table ValueList {
items:[int];
}

table SymInt {
value:int;
}

union GraphTypes {
Null,
Int,
Expand All @@ -100,6 +104,7 @@ union GraphTypes {
BoolList,
ValueList,
String,
SymInt,
}

table VkValue {
Expand Down
11 changes: 11 additions & 0 deletions backends/vulkan/serialization/vulkan_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,12 @@ def maybe_add_constant_tensor(self, node: Node) -> int:
return constant_id

def create_node_value(self, node: Node) -> int:
# If the node has been marked as a scalar tensor, create a SymInt instead of a tensor
if node.meta.get("vkdg_is_scalar_tensor", False):
new_id = self.create_symint_value()
self.node_to_value_ids[node] = new_id
return new_id

spec = node.meta.get("spec")
if isinstance(spec, TensorSpec):
constant_id = self.maybe_add_constant_tensor(node)
Expand Down Expand Up @@ -169,6 +175,11 @@ def create_scalar_value(self, scalar: _ScalarType) -> int:
self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Double(scalar)))
return new_id

def create_symint_value(self) -> int:
new_id = len(self.values)
self.values.append(vk_graph_schema.VkValue(vk_graph_schema.SymInt(0)))
return new_id

def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int:
# Negative id indicates that this tensor will have its own dedicated memory.
mem_obj_id = -1
Expand Down
6 changes: 6 additions & 0 deletions backends/vulkan/serialization/vulkan_graph_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,11 @@ class String:
string_val: str


@dataclass
class SymInt:
value: int


GraphTypes = Union[
Null,
Int,
Expand All @@ -111,6 +116,7 @@ class String:
DoubleList,
ValueList,
String,
SymInt,
]


Expand Down

0 comments on commit 4b6a033

Please sign in to comment.