From 1a322b02a83823e5ab7ef03aeff3869c779aecc3 Mon Sep 17 00:00:00 2001 From: Elizabeth Santorella Date: Thu, 29 Aug 2024 07:23:18 -0700 Subject: [PATCH] Mock out downloading data from internet in torchvision unit test (#2725) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/2725 Reviewed By: Balandat Differential Revision: D61970253 fbshipit-source-id: 44a38f705f379f56da9ab2a3b12fe7c7ecc666ac --- ax/storage/json_store/tests/test_json_store.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/ax/storage/json_store/tests/test_json_store.py b/ax/storage/json_store/tests/test_json_store.py index f5886141483..d6538ced2d0 100644 --- a/ax/storage/json_store/tests/test_json_store.py +++ b/ax/storage/json_store/tests/test_json_store.py @@ -10,6 +10,7 @@ import os import tempfile from functools import partial +from unittest.mock import patch import numpy as np import torch @@ -51,6 +52,7 @@ get_multi_objective_benchmark_problem, get_single_objective_benchmark_problem, get_sobol_gpei_benchmark_method, + TestDataset, ) from ax.utils.testing.core_stubs import ( get_abandoned_arm, @@ -404,12 +406,20 @@ def __post_init__(self, doesnt_serialize: None) -> None: self.assertEqual(obj, recovered) def test_EncodeDecode_torchvision_problem(self) -> None: - test_problem = PyTorchCNNTorchvisionParamBasedProblem(name="MNIST") + registry_path = "ax.benchmark.problems.hpo.torchvision._REGISTRY" + mock_registry = {"MNIST": TestDataset} + with patch.dict(registry_path, mock_registry): + test_problem = PyTorchCNNTorchvisionParamBasedProblem(name="MNIST") + self.assertIsNotNone(test_problem.train_loader) self.assertIsNotNone(test_problem.test_loader) + as_json = object_to_json(obj=test_problem) self.assertNotIn("train_loader", as_json) - recovered = object_from_json(as_json) + + with patch.dict(registry_path, mock_registry): + recovered = object_from_json(as_json) + self.assertIsNotNone(recovered.train_loader) self.assertEqual(test_problem, recovered)