Skip to content

Commit

Permalink
Make dtype to/from str function public
Browse files Browse the repository at this point in the history
  • Loading branch information
sogartar committed Oct 31, 2024
1 parent 6ff055e commit db17e7b
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def _get_json_tensor(
if qp.get("input_zp_dtype") is not None
else "torch.int8"
)
quantization_dtype = tensors._serialized_name_to_dtype(
quantization_dtype = tensors.serialized_name_to_dtype(
quantization_type.split(".")[-1]
)
if output_scale is not None:
Expand Down
8 changes: 4 additions & 4 deletions sharktank/sharktank/types/layouts.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
register_quantized_layout,
MetaDataValueType,
QuantizedLayout,
_dtype_to_serialized_name,
_serialized_name_to_dtype,
dtype_to_serialized_name,
serialized_name_to_dtype,
)

from .layout_utils import (
Expand Down Expand Up @@ -96,7 +96,7 @@ def create(
m = planes.get("m")
dtype_str = metadata.get("dtype")
if dtype_str is not None:
dtype = _serialized_name_to_dtype(dtype_str)
dtype = serialized_name_to_dtype(dtype_str)
else:
# Backwards compat with old serialized. Emulate original behavior
# before mixed precision.
Expand All @@ -106,7 +106,7 @@ def create(
@property
def metadata(self) -> Optional[dict[str, MetaDataValueType]]:
"""Additional metadata needed to reconstruct a layout."""
return {"dtype": _dtype_to_serialized_name(self._dtype)}
return {"dtype": dtype_to_serialized_name(self._dtype)}

@property
def planes(self) -> dict[str, torch.Tensor]:
Expand Down
12 changes: 6 additions & 6 deletions sharktank/sharktank/types/quantizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@
QuantizedTensor,
UnnamedTensorName,
register_inference_tensor,
_serialized_name_to_dtype,
_dtype_to_serialized_name,
serialized_name_to_dtype,
dtype_to_serialized_name,
)

__all__ = [
Expand Down Expand Up @@ -246,7 +246,7 @@ def create(
raise IOError("Missing property") from e
axis = int(extra_properties["axis"]) if "axis" in extra_properties else None
disable_saturate = bool(extra_properties.get("disable_saturate"))
dtype = _serialized_name_to_dtype(dtype_name)
dtype = serialized_name_to_dtype(dtype_name)
return cls(
name=name,
scale=scale,
Expand All @@ -272,7 +272,7 @@ def add_to_archive(self, builder: ShardedArchiveBuilder) -> InferenceTensorMetad
scale_name = f"{self.name}:scale"
rscale_name = f"{self.name}:rscale"
offset_name = f"{self.name}:offset"
extra_properties = {"dtype": _dtype_to_serialized_name(self._dtype)}
extra_properties = {"dtype": dtype_to_serialized_name(self._dtype)}
if self._axis is not None:
extra_properties["axis"] = self._axis
if self._disable_saturate:
Expand Down Expand Up @@ -388,7 +388,7 @@ def create(
dtype_name = extra_properties["dtype"]
except KeyError as e:
raise IOError("Missing property") from e
dtype = _serialized_name_to_dtype(dtype_name)
dtype = serialized_name_to_dtype(dtype_name)
return cls(
name=name,
dtype=dtype,
Expand All @@ -400,7 +400,7 @@ def globals(self) -> dict[str, torch.Tensor]:

def add_to_archive(self, builder: ShardedArchiveBuilder) -> InferenceTensorMetadata:
"""Adds this tensor to the global archive."""
extra_properties = {"dtype": _dtype_to_serialized_name(self._dtype)}
extra_properties = {"dtype": dtype_to_serialized_name(self._dtype)}
raw_tensors = {}
return InferenceTensorMetadata(
self.serialized_name(),
Expand Down
6 changes: 4 additions & 2 deletions sharktank/sharktank/types/tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
__all__ = [
"AnyTensor",
"DefaultPrimitiveTensor",
"dtype_to_serialized_name",
"flatten_tensor_tree",
"InferenceTensor",
"MetaDataValueType",
Expand All @@ -48,6 +49,7 @@
"QuantizedTensor",
"register_quantized_layout",
"ReplicatedTensor",
"serialized_name_to_dtype",
"ShardedTensor",
"SplitPrimitiveTensor",
"torch_tree_flatten",
Expand Down Expand Up @@ -1235,7 +1237,7 @@ def unbox_tensor(t: Any) -> Tensor:
########################################################################################


def _dtype_to_serialized_name(dtype: torch.dtype) -> str:
def dtype_to_serialized_name(dtype: torch.dtype) -> str:
try:
return _DTYPE_TO_NAME[dtype]
except KeyError as e:
Expand All @@ -1244,7 +1246,7 @@ def _dtype_to_serialized_name(dtype: torch.dtype) -> str:
) from e


def _serialized_name_to_dtype(dtype_name: str) -> torch.dtype:
def serialized_name_to_dtype(dtype_name: str) -> torch.dtype:
try:
return _NAME_TO_DTYPE[dtype_name]
except KeyError as e:
Expand Down

0 comments on commit db17e7b

Please sign in to comment.