Skip to content

Commit

Permalink
Add re-import fixture to dataset unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
Ralph Liu committed Jul 20, 2023
1 parent b2102e8 commit 121c888
Showing 1 changed file with 23 additions and 25 deletions.
48 changes: 23 additions & 25 deletions python/cugraph/cugraph/tests/utils/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import os
from pathlib import Path
from tempfile import TemporaryDirectory
import gc
import sys
import warnings
from pathlib import Path
from tempfile import TemporaryDirectory

import pytest

Expand Down Expand Up @@ -73,6 +73,22 @@ def setup(tmpdir):
gc.collect()


@pytest.fixture()
def setup_deprecation_warning_tests():
"""
Fixture used to set warning filters to 'default' and reload
experimental.datasets module if it has been previously
imported. Tests that import this fixture are expected to
import cugraph.experimental.datasets
"""
warnings.filterwarnings("default")

if "cugraph.experimental.datasets" in sys.modules:
del sys.modules["cugraph.experimental.datasets"]

yield


###############################################################################
# Tests

Expand Down Expand Up @@ -270,19 +286,14 @@ def test_is_directed(dataset):
#
# Test experimental for DeprecationWarnings
#
def test_experimental_dataset_import():
warnings.filterwarnings("default")

with pytest.deprecated_call() as record:
def test_experimental_dataset_import(setup_deprecation_warning_tests):
with pytest.deprecated_call():
from cugraph.experimental.datasets import karate

if not record:
pytest.fail("Expected experimental.datasets to raise DeprecationWarning")

karate.unload()


def test_experimental_method_warnings():
def test_experimental_method_warnings(setup_deprecation_warning_tests):
from cugraph.experimental.datasets import (
load_all,
set_download_dir,
Expand All @@ -292,22 +303,9 @@ def test_experimental_method_warnings():
warnings.filterwarnings("default")
tmpd = TemporaryDirectory()

with pytest.deprecated_call() as record:
with pytest.deprecated_call():
set_download_dir(tmpd.name)

if not record:
pytest.fail("Expected set_download_dir to raise DeprecationWarning")

with pytest.deprecated_call() as record:
get_download_dir()

if not record:
pytest.fail("Expected get_download_dir to raise DeprecationWarning")

with pytest.deprecated_call() as record:
load_all()

if not record:
pytest.fail("Expected load_all to raise DeprecationWarning")

tmpd.cleanup()

0 comments on commit 121c888

Please sign in to comment.