Skip to content

Commit

Permalink
Reland 132308, 132314, 132318, 132334 - Make builtin nn modules attri…
Browse files Browse the repository at this point in the history
…butes static (#132539)

Summary:
Relanding 4 PRs ending at pytorch/pytorch#132334

X-link: pytorch/pytorch#132539
Approved by: https://github.com/Skylion007, https://github.com/yanboliang, https://github.com/mlazos

Reviewed By: yanboliang, PaliC

Differential Revision: D60684569

Pulled By: anijain2305

fbshipit-source-id: 9daa6aca81a0baeace3e6da22630cf17b34c5cf1
  • Loading branch information
anijain2305 authored and facebook-github-bot committed Aug 6, 2024
1 parent f597aa4 commit a9b8a9a
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions userbenchmark/dynamo/dynamobench/_dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
import torch.utils._pytree as pytree
from torch import fx
from torch._dispatch.python import enable_python_dispatcher
from torch._guards import TracingContext
from torch._guards import Source, TracingContext
from torch._subclasses.meta_utils import is_sparse_compressed
from torch._utils_internal import log_compilation_event
from torch.fx._utils import _format_graph_code, lazy_format_graph_code
Expand Down Expand Up @@ -2127,7 +2127,7 @@ def tensor_static_reason_to_message(reason: TensorStaticReason):
def tensor_always_has_static_shape(
tensor: Union[torch.Tensor, Any],
is_tensor: bool,
guard_source: "torch._guards.GuardSource",
tensor_source: Source,
) -> Tuple[bool, Optional[TensorStaticReason]]:
"""
Given a tensor, source, and is_tensor flag, determine if a shape should be static.
Expand All @@ -2140,12 +2140,18 @@ def tensor_always_has_static_shape(
Returns a tuple, where the first element is the bool of whether or not this tensor should have a static shape.
The second element is a TensorStaticReason, useful for passing to tensor_static_reason_to_message if needed.
"""
from .source import is_from_unspecialized_param_buffer_source

if (
guard_source.is_specialized_nn_module()
and config.force_nn_module_property_static_shapes
):
tensor_source.guard_source().is_specialized_nn_module()
or tensor_source.guard_source().is_unspecialized_builtin_nn_module()
) and config.force_nn_module_property_static_shapes:
return True, TensorStaticReason.NN_MODULE_PROPERTY
if type(tensor) is torch.nn.Parameter and config.force_parameter_static_shapes:

if (
type(tensor) is torch.nn.Parameter
or is_from_unspecialized_param_buffer_source(tensor_source)
) and config.force_parameter_static_shapes:
return True, TensorStaticReason.PARAMETER
if not is_tensor:
return True, TensorStaticReason.NOT_TENSOR
Expand Down

0 comments on commit a9b8a9a

Please sign in to comment.