
Commit

consolidating the two validators and removing assertion check from evaluator
apbose committed Apr 4, 2024
1 parent 95fb7fb commit 123f8b6
Showing 1 changed file with 3 additions and 17 deletions.
py/torch_tensorrt/dynamo/conversion/ops_evaluators.py: 20 changes (3 additions, 17 deletions)
@@ -60,6 +60,7 @@ def rand_validator(rand_node: Node) -> bool:
     if layout is not None:
         _LOGGER.debug(f"Currently we don't support specifying layout, got {layout}.")
         return False
+    return True
 
 
 @dynamo_tensorrt_converter(
@@ -76,21 +77,8 @@ def aten_ops_rand(
     return np.random.rand(*args)
 
 
-def randn_validator(randn_node: Node) -> bool:
-    dtype = randn_node.kwargs.get("dtype", None)
-    layout = randn_node.kwargs.get("layout", None)
-    if dtype is not None:
-        _LOGGER.debug(
-            f"Currently we don't support specifying output dtype, got {dtype}."
-        )
-        return False
-    if layout is not None:
-        _LOGGER.debug(f"Currently we don't support specifying layout, got {layout}.")
-        return False
-
-
 @dynamo_tensorrt_converter(
-    torch.ops.aten.randn.default, capability_validator=randn_validator
+    torch.ops.aten.randn.default, capability_validator=rand_validator
 )
 def aten_ops_randn(
     ctx: ConversionContext,
@@ -118,6 +106,7 @@ def randperm_validator(randperm_node: Node) -> bool:
     if layout is not None:
         _LOGGER.debug(f"Currently we don't support specifying layout, got {layout}.")
         return False
+    return True
 
 
 @dynamo_tensorrt_converter(
@@ -131,7 +120,4 @@ def aten_ops_randperm(
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
     device = kwargs.get("device", None)
-    input = args[0]
-    if not isinstance(input, int):
-        raise RuntimeError(f"The input must be an integer")
     return np.random.permutation(*args)
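
For context, below is a sketch (not part of the diff) of how the consolidated validator reads after this commit, reconstructed from the hunk context above. Names such as Node, _LOGGER, and dynamo_tensorrt_converter come from the module's existing imports; the dtype branch is assumed to mirror the removed randn_validator, since the commit's purpose is to consolidate the two validators.

# Sketch only: consolidated validator after this commit, reconstructed from the
# diff context. The dtype check is assumed identical to the removed randn_validator.
def rand_validator(rand_node: Node) -> bool:
    dtype = rand_node.kwargs.get("dtype", None)
    layout = rand_node.kwargs.get("layout", None)
    if dtype is not None:
        _LOGGER.debug(
            f"Currently we don't support specifying output dtype, got {dtype}."
        )
        return False
    if layout is not None:
        _LOGGER.debug(f"Currently we don't support specifying layout, got {layout}.")
        return False
    return True

With the duplicate randn_validator gone, the aten.randn converter registers capability_validator=rand_validator (confirmed by the diff), and the explicit isinstance check in aten_ops_randperm is dropped, matching the commit message's "removing assertion check from evaluator".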
