Don't use Pyro jit by default (#1474)
Summary:
X-link: pytorch/botorch#1474

Pull Request resolved: #1241

We have observed that using jit with Pyro can result in increased memory usage and even memory leaks, so we are disabling it for now. This makes model fitting with NUTS roughly 2x slower.
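
For reference, a minimal sketch (not Ax's actual model) of what the NUTS setup looks like after this change; toy_pyro_model, the data tensors, and the sample counts below are illustrative stand-ins for Ax's single_task_pyro_model and its defaults:

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS

def toy_pyro_model(X, Y):
    # Hypothetical stand-in for single_task_pyro_model: fit a single Gaussian mean.
    mean = pyro.sample("mean", dist.Normal(0.0, 1.0))
    with pyro.plate("data", Y.shape[0]):
        pyro.sample("obs", dist.Normal(mean, 1.0), obs=Y)

X, Y = torch.rand(10, 2), torch.randn(10)
kernel = NUTS(
    toy_pyro_model,
    jit_compile=False,  # new default: trade ~2x speed for stable memory usage
    full_mass=True,
    ignore_jit_warnings=True,
    max_tree_depth=6,
)
mcmc = MCMC(kernel, num_samples=64, warmup_steps=64, disable_progbar=True)
mcmc.run(X, Y)
samples = mcmc.get_samples()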

Reviewed By: esantorella

Differential Revision: D40949763

fbshipit-source-id: a5c6dc772c3000d4ee788a541fe2f7009fe3aafa
David Eriksson authored and facebook-github-bot committed Nov 3, 2022
1 parent bc9ec01 commit f512de0
Showing 2 changed files with 3 additions and 2 deletions.
ax/models/tests/test_fully_bayesian.py (1 addition, 1 deletion)
@@ -907,7 +907,7 @@ def test_FullyBayesianBotorchModelPyro(self, dtype=torch.double, cuda=False):
         # check NUTS.__init__ arguments
         _mock_nuts.assert_called_with(
             single_task_pyro_model,
-            jit_compile=True,
+            jit_compile=False,
             full_mass=True,
             ignore_jit_warnings=True,
             max_tree_depth=1,
ax/models/torch/fully_bayesian.py (2 additions, 1 deletion)
@@ -388,6 +388,7 @@ def run_inference(
     verbose: bool = False,
     task_feature: Optional[int] = None,
     rank: Optional[int] = None,
+    jit_compile: bool = False,
 ) -> Dict[str, Tensor]:
     start = time.time()
     try:
@@ -400,7 +401,7 @@ def run_inference(
         raise RuntimeError("Cannot call run_inference without pyro installed!")
     kernel = NUTS(
         pyro_model,
-        jit_compile=True,
+        jit_compile=jit_compile,
         full_mass=True,
         ignore_jit_warnings=True,
         max_tree_depth=max_tree_depth,
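
With the new keyword argument, callers can still opt back into jit compilation. A hedged usage sketch, assuming run_inference takes data arguments X, Y, Yvar and sampling arguments num_samples and warmup_steps as defined in ax/models/torch/fully_bayesian.py; only jit_compile is new in this commit, and the values shown are illustrative:

from ax.models.torch.fully_bayesian import run_inference, single_task_pyro_model

samples = run_inference(
    single_task_pyro_model,
    X=X,                  # training inputs (torch.Tensor)
    Y=Y,                  # training targets (torch.Tensor)
    Yvar=Yvar,            # observed noise variances (torch.Tensor)
    num_samples=256,
    warmup_steps=512,
    jit_compile=True,     # restore the old (faster, but leak-prone) behavior
)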
