Skip to content

Commit

Permalink
Make typing strict on tests (#93)
Browse files Browse the repository at this point in the history
  • Loading branch information
cthoyt authored Jan 15, 2025
1 parent 55abd65 commit d529a96
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 95 deletions.
45 changes: 25 additions & 20 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import inspect
import unittest
from typing import Callable, TypeVar, cast

import pandas as pd
import rdflib
Expand All @@ -14,37 +15,41 @@

SKIP = {"__init__"}

X = TypeVar("X")

def _df_equal(a: pd.DataFrame, b: pd.DataFrame, msg=None) -> bool:
return a.values.tolist() == b.values.tolist()

def _df_equal(a: pd.DataFrame, b: pd.DataFrame, msg: str | None = None) -> bool:
return bool(a.values.tolist() == b.values.tolist())

def _rdf_equal(a: rdflib.Graph, b: rdflib.Graph, msg=None) -> bool:

def _rdf_equal(a: rdflib.Graph, b: rdflib.Graph, msg: str | None = None) -> bool:
return {tuple(t) for t in a} == {tuple(t) for t in b}


def _etree_equal(a: etree.ElementTree, b: etree.ElementTree, msg=None) -> bool:
return etree.tostring(a) == etree.tostring(b)
def _etree_equal(a: etree.ElementTree, b: etree.ElementTree, msg: str | None = None) -> bool:
return cast(str, etree.tostring(a)) == cast(str, etree.tostring(b))


class TestExposed(unittest.TestCase):
"""Test API exposure."""

def setUp(self) -> None:
"""Set up the test case."""
self.addTypeEqualityFunc(pd.DataFrame, _df_equal)
self.addTypeEqualityFunc(rdflib.Graph, _rdf_equal)
self.addTypeEqualityFunc(type(etree.ElementTree()), _etree_equal)
self.addTypeEqualityFunc(pd.DataFrame, _df_equal) # type:ignore[arg-type]
self.addTypeEqualityFunc(rdflib.Graph, _rdf_equal) # type:ignore[arg-type]
self.addTypeEqualityFunc(type(etree.ElementTree()), _etree_equal) # type:ignore[arg-type]

def assert_io(self, obj, ext: str, dump, load):
def assert_io(
self, obj: X, extension: str, dump: Callable[..., None], load: Callable[..., X]
) -> None:
"""Test an object can be dumped and loaded.
:param obj: The object to dump
:param ext: The extension to use
:param extension: The extension to use
:param dump: The dump function
:param load: The load function
"""
name = f"test.{ext}"
name = f"test.{extension}"
path = pystow.join("test", name=name)
if path.is_file():
path.unlink()
Expand All @@ -54,7 +59,7 @@ def assert_io(self, obj, ext: str, dump, load):
self.assertTrue(path.is_file())
self.assertEqual(obj, load("test", name=name))

def test_exposed(self):
def test_exposed(self) -> None:
"""Test that all module-level functions also have a counterpart in the top-level API."""
for name, func in Module.__dict__.items():
if not inspect.isfunction(func) or name in SKIP:
Expand All @@ -75,24 +80,24 @@ def test_exposed(self):
msg=f"`pystow.api.{name}` should be imported in `pystow.__init__`.",
)

def test_io(self):
def test_io(self) -> None:
"""Test IO functions."""
obj = ["a", "b", "c"]
for ext, dump, load in [
("json", pystow.dump_json, pystow.load_json),
("pkl", pystow.dump_pickle, pystow.load_pickle),
]:
with self.subTest(ext=ext):
self.assert_io(obj, ext=ext, dump=dump, load=load)
self.assert_io(obj, extension=ext, dump=dump, load=load) # type:ignore

def test_pd_io(self):
def test_pd_io(self) -> None:
"""Test pandas IO."""
columns = list("abc")
data = [(1, 2, 3), (4, 5, 6)]
df = pd.DataFrame(data, columns=columns)
self.assert_io(df, ext="tsv", load=pystow.load_df, dump=pystow.dump_df)
self.assert_io(df, extension="tsv", load=pystow.load_df, dump=pystow.dump_df)

def test_rdf_io(self):
def test_rdf_io(self) -> None:
"""Test RDFlib IO."""
graph = rdflib.Graph()
graph.add(
Expand All @@ -103,12 +108,12 @@ def test_rdf_io(self):
)
)
self.assertEqual(1, len(graph))
self.assert_io(graph, ext="ttl", dump=pystow.dump_rdf, load=pystow.load_rdf)
self.assert_io(graph, extension="ttl", dump=pystow.dump_rdf, load=pystow.load_rdf)

def test_xml_io(self):
def test_xml_io(self) -> None:
"""Test XML I/O."""
root = etree.Element("root")
root.set("interesting", "somewhat")
etree.SubElement(root, "test")
my_tree = etree.ElementTree(root)
self.assert_io(my_tree, ext="xml", dump=pystow.dump_xml, load=pystow.load_xml)
self.assert_io(my_tree, extension="xml", dump=pystow.dump_xml, load=pystow.load_xml)
10 changes: 5 additions & 5 deletions tests/test_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@ def tearDown(self) -> None:
"""Tear down the test case's temporary directory."""
self.tmpdir.cleanup()

def test_cache_exception(self):
def test_cache_exception(self) -> None:
"""Test that exceptions aren't swallowed."""
path = self.directory.joinpath("test.pkl")

self.assertFalse(path.is_file())

@CachedPickle(path=path)
def _f1():
def _f1() -> None:
raise NotImplementedError

self.assertFalse(path.is_file(), msg="function has not been called")
Expand All @@ -45,7 +45,7 @@ def _f1():
msg="file should not have been created if an exception was thrown by the function",
)

def test_cache_pickle(self):
def test_cache_pickle(self) -> None:
"""Test caching a pickle."""
path = self.directory.joinpath("test.pkl")
self.assertFalse(
Expand All @@ -56,7 +56,7 @@ def test_cache_pickle(self):
raise_flag = True

@CachedPickle(path=path)
def _f1():
def _f1() -> int:
if raise_flag:
raise ValueError
return EXPECTED
Expand Down Expand Up @@ -87,7 +87,7 @@ def _f1():
_f1()

@CachedPickle(path=path, force=True)
def _f2():
def _f2() -> int:
return EXPECTED_2

self.assertEqual(EXPECTED_2, _f2()) # overwrites the file
Expand Down
19 changes: 13 additions & 6 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

import tempfile
import unittest
from configparser import ConfigParser
from pathlib import Path
from typing import ClassVar

import pystow
from pystow.config_api import CONFIG_HOME_ENVVAR, _get_cfp
Expand All @@ -14,6 +16,11 @@
class TestConfig(unittest.TestCase):
"""Test configuration."""

test_section: ClassVar[str]
test_option: ClassVar[str]
test_value: ClassVar[str]
cfp: ClassVar[ConfigParser]

@classmethod
def setUpClass(cls) -> None:
"""Set up the class for testing."""
Expand All @@ -28,7 +35,7 @@ def setUpClass(cls) -> None:
value=cls.test_value,
)

def test_env_cast(self):
def test_env_cast(self) -> None:
"""Test casting works properly when getting from the environment."""
with mock_envvar("TEST_VAR", "1234"):
self.assertEqual("1234", pystow.get_config("test", "var"))
Expand All @@ -39,7 +46,7 @@ def test_env_cast(self):
with self.assertRaises(TypeError):
pystow.get_config("test", "var", dtype=object)

def test_get_config(self):
def test_get_config(self) -> None:
"""Test lookup not existing."""
self.assertIsNone(pystow.get_config(self.test_section, "key"))
self.assertEqual("1234", pystow.get_config(self.test_section, "key", default="1234"))
Expand Down Expand Up @@ -94,17 +101,17 @@ def test_get_config(self):
True, pystow.get_config(self.test_section, self.test_option, passthrough=1, dtype=bool)
)

def test_subsection(self):
def test_subsection(self) -> None:
"""Test subsections."""
with tempfile.TemporaryDirectory() as directory, mock_envvar(CONFIG_HOME_ENVVAR, directory):
directory = Path(directory)
path = directory.joinpath("test.ini")
directory_ = Path(directory)
path = directory_.joinpath("test.ini")
self.assertFalse(path.is_file(), msg="file should not already exist")

self.assertIsNone(pystow.get_config("test:subtest", "key"))
self.assertFalse(path.is_file(), msg="getting config should not create a file")

pystow.write_config("test:subtest", "key", "value")
self.assertTrue(path.is_file(), msg=f"{list(directory.iterdir())}")
self.assertTrue(path.is_file(), msg=f"{list(directory_.iterdir())}")

self.assertEqual("value", pystow.get_config("test:subtest", "key"))
Loading

0 comments on commit d529a96

Please sign in to comment.