diff --git a/Makefile b/Makefile index 7ac7b9a..d84f997 100644 --- a/Makefile +++ b/Makefile @@ -15,12 +15,12 @@ test: coverage xml lint: - flake8 src tests - isort --check-only src tests - pydocstyle src tests - black --check src tests - mypy src tests - bandit -r src + flake8 src tests || exit 1 + isort --check-only src tests || exit 1 + pydocstyle src tests || exit 1 + black --check src tests || exit 1 + mypy src tests || exit 1 + bandit -r src --skip B101 || exit 11 format: isort src tests diff --git a/pyproject.toml b/pyproject.toml index 677ccf7..fbda10e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "pytest-alembic" -version = "0.10.1" +version = "0.10.2" description = "A pytest plugin for verifying alembic migrations." authors = [ "Dan Cardin ", diff --git a/src/pytest_alembic/plugin/error.py b/src/pytest_alembic/plugin/error.py index 29c84ac..c1727d2 100644 --- a/src/pytest_alembic/plugin/error.py +++ b/src/pytest_alembic/plugin/error.py @@ -1,33 +1,25 @@ import textwrap - -from _pytest._code.code import FormattedExcinfo +from typing import List class AlembicTestFailure(AssertionError): def __init__(self, message, context=None): super().__init__(message) self.context = context + self.exce = self + self.item = None - -class AlembicReprError: - def __init__(self, exce, item): - self.exce = exce - self.item = item - - def toterminal(self, tw): + def format_context(self) -> List[str]: """Print out a custom error message to the terminal.""" - exc = self.exce.value - context = exc.context - - if context: - for title, item in context: - tw.line(title + ":", white=True, bold=True) - tw.line(textwrap.indent(item, " "), red=True) - tw.line("") - - e = FormattedExcinfo() - lines = e.get_exconly(self.exce) - - tw.line("Errors:", white=True, bold=True) - for line in lines: - tw.line(line, red=True, bold=True) + result = [] + if not self.context: + return [] + + for title, item in self.context: + result.extend(["", f"{title}:", textwrap.indent(item, " ")]) + return result + + def __str__(self): + content = self.format_context() + segments = [super().__str__(), *content] + return "\n".join(segments) diff --git a/src/pytest_alembic/plugin/hooks.py b/src/pytest_alembic/plugin/hooks.py index e5612ab..6bf2871 100644 --- a/src/pytest_alembic/plugin/hooks.py +++ b/src/pytest_alembic/plugin/hooks.py @@ -10,6 +10,11 @@ def pytest_addoption(parser): ) experimental_tests = ", ".join(t.name for t in experimental_collector.available_tests.values()) + parser.addini( + "pytest_alembic_enabled", + "Whether to enable/disable the plugin's behavior entirely. Defaults to true.", + default=True, + ) parser.addini( "pytest_alembic_include", "List of built-in tests to include. If specified, 'pytest_alembic_exclude' is ignored. " @@ -40,7 +45,7 @@ def pytest_addoption(parser): action="store_true", default=False, help="Enable pytest-alembic built-in tests", - dest="pytest_alembic_enabled", + dest="pytest_alembic_registration_enabled", ) group.addoption( "--alembic-exclude", @@ -61,6 +66,6 @@ def pytest_configure(config): def pytest_sessionstart(session): - if session.config.option.pytest_alembic_enabled: + if session.config.getini("pytest_alembic_enabled"): plugin = PytestAlembicPlugin(session.config) session.config.pluginmanager.register(plugin, "pytest-alembic") diff --git a/src/pytest_alembic/plugin/plugin.py b/src/pytest_alembic/plugin/plugin.py index 791038a..4c86aa2 100644 --- a/src/pytest_alembic/plugin/plugin.py +++ b/src/pytest_alembic/plugin/plugin.py @@ -6,8 +6,6 @@ import pytest from _pytest import config -from pytest_alembic.plugin.error import AlembicReprError, AlembicTestFailure - pytest_version_tuple = getattr(pytest, "version_tuple", None) @@ -60,11 +58,12 @@ def __init__(self, **kwargs): self.add_marker("alembic") def collect(self): + assert self.parent config = self.parent.config - cli_enabled = config.option.pytest_alembic_enabled + cli_enabled = config.option.pytest_alembic_registration_enabled if not cli_enabled: - return None + return [] option = config.option @@ -101,11 +100,6 @@ class PytestAlembicItem(pytest.Function): def reportinfo(self): return (self.fspath, 0, f"[pytest-alembic] {self.name}") - def repr_failure(self, excinfo): - if isinstance(excinfo.value, AlembicTestFailure): - return AlembicReprError(excinfo, self) - return super().repr_failure(excinfo) - @dataclass(frozen=True) class PytestAlembicTest: