From 7b4c4260de22e903e2ef4f8cb467c98ffdd6c303 Mon Sep 17 00:00:00 2001 From: Ian Schweer Date: Fri, 17 Jan 2025 13:47:29 -0800 Subject: [PATCH] Add compiler flags to help hint more to torch --- pytensor/link/pytorch/linker.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pytensor/link/pytorch/linker.py b/pytensor/link/pytorch/linker.py index d47aa43dda..0d2e190ddc 100644 --- a/pytensor/link/pytorch/linker.py +++ b/pytensor/link/pytorch/linker.py @@ -37,6 +37,10 @@ def conversion_func_register(*args, **kwargs): def jit_compile(self, fn): import torch + # two flags that tend to help our graphs + torch._dynamo.config.capture_func_transforms = True + torch._dynamo.config.capture_scalar_outputs = True + from pytensor.link.pytorch.dispatch import pytorch_typify class wrapper: