Skip to content

Commit

Permalink
Mock out downloading data from internet in torchvision unit test (#2725)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #2725

Reviewed By: Balandat

Differential Revision: D61970253

fbshipit-source-id: 44a38f705f379f56da9ab2a3b12fe7c7ecc666ac
  • Loading branch information
esantorella authored and facebook-github-bot committed Aug 29, 2024
1 parent 373fe81 commit 1a322b0
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions ax/storage/json_store/tests/test_json_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import os
import tempfile
from functools import partial
from unittest.mock import patch

import numpy as np
import torch
Expand Down Expand Up @@ -51,6 +52,7 @@
get_multi_objective_benchmark_problem,
get_single_objective_benchmark_problem,
get_sobol_gpei_benchmark_method,
TestDataset,
)
from ax.utils.testing.core_stubs import (
get_abandoned_arm,
Expand Down Expand Up @@ -404,12 +406,20 @@ def __post_init__(self, doesnt_serialize: None) -> None:
self.assertEqual(obj, recovered)

def test_EncodeDecode_torchvision_problem(self) -> None:
test_problem = PyTorchCNNTorchvisionParamBasedProblem(name="MNIST")
registry_path = "ax.benchmark.problems.hpo.torchvision._REGISTRY"
mock_registry = {"MNIST": TestDataset}
with patch.dict(registry_path, mock_registry):
test_problem = PyTorchCNNTorchvisionParamBasedProblem(name="MNIST")

self.assertIsNotNone(test_problem.train_loader)
self.assertIsNotNone(test_problem.test_loader)

as_json = object_to_json(obj=test_problem)
self.assertNotIn("train_loader", as_json)
recovered = object_from_json(as_json)

with patch.dict(registry_path, mock_registry):
recovered = object_from_json(as_json)

self.assertIsNotNone(recovered.train_loader)
self.assertEqual(test_problem, recovered)

Expand Down

0 comments on commit 1a322b0

Please sign in to comment.