diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1bfad81..3510c5d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -24,7 +24,7 @@ repos: - id: ruff-format - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.11.2 + rev: v1.14.1 hooks: - id: mypy args: [--install-types, --non-interactive, --ignore-missing-imports] diff --git a/pyproject.toml b/pyproject.toml index c3aae10..ddbcd31 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,6 +5,14 @@ build-backend = "setuptools.build_meta" [tool.black] line-length = 79 +[[tool.mypy.overrides]] +module = ["beancount.*"] +follow_untyped_imports = true + +[[tool.mypy.overrides]] +module = ["beangulp.*"] +follow_untyped_imports = true + [tool.ruff] target-version = "py38" line-length = 79 diff --git a/smart_importer/hooks.py b/smart_importer/hooks.py index f2ca2e8..1e3517a 100644 --- a/smart_importer/hooks.py +++ b/smart_importer/hooks.py @@ -1,7 +1,13 @@ """Importer decorators.""" +from __future__ import annotations + import logging from functools import wraps +from typing import Callable, Sequence + +from beancount.core import data +from beangulp import Adapter, Importer, ImporterProtocol logger = logging.getLogger(__name__) # pylint: disable=invalid-name @@ -9,7 +15,13 @@ class ImporterHook: """Interface for an importer hook.""" - def __call__(self, importer, file, imported_entries, existing): + def __call__( + self, + importer: Importer, + file: str, + imported_entries: data.Directives, + existing: data.Directives, + ) -> data.Directives: """Apply the hook and modify the imported entries. Args: @@ -25,7 +37,14 @@ def __call__(self, importer, file, imported_entries, existing): raise NotImplementedError -def apply_hooks(importer, hooks): +def apply_hooks( + importer: Importer | ImporterProtocol, + hooks: Sequence[ + Callable[ + [Importer, str, data.Directives, data.Directives], data.Directives + ] + ], +) -> Importer: """Apply a list of importer hooks to an importer. Args: @@ -33,12 +52,16 @@ def apply_hooks(importer, hooks): hooks: A list of hooks, each a callable object. """ + if not isinstance(importer, Importer): + importer = Adapter(importer) unpatched_extract = importer.extract @wraps(unpatched_extract) - def patched_extract_method(filepath, existing=None): + def patched_extract_method( + filepath: str, existing: data.Directives + ) -> data.Directives: logger.debug("Calling the importer's extract method.") - imported_entries = unpatched_extract(filepath, existing=existing) + imported_entries = unpatched_extract(filepath, existing) for hook in hooks: imported_entries = hook( diff --git a/smart_importer/predictor.py b/smart_importer/predictor.py index eeec34b..c3999d1 100644 --- a/smart_importer/predictor.py +++ b/smart_importer/predictor.py @@ -6,10 +6,10 @@ import logging import threading -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING, Any, Callable +from beancount.core import data from beancount.core.data import ( - ALL_DIRECTIVES, Close, Open, Transaction, @@ -27,6 +27,7 @@ from smart_importer.pipelines import get_pipeline if TYPE_CHECKING: + from beangulp import Importer from sklearn import Pipeline logger = logging.getLogger(__name__) # pylint: disable=invalid-name @@ -53,11 +54,11 @@ class EntryPredictor(ImporterHook): def __init__( self, - predict=True, - overwrite=False, + predict: bool = True, + overwrite: bool = False, string_tokenizer: Callable[[str], list] | None = None, denylist_accounts: list[str] | None = None, - ): + ) -> None: super().__init__() self.training_data = None self.open_accounts: dict[str, str] = {} @@ -65,13 +66,19 @@ def __init__( self.pipeline: Pipeline | None = None self.is_fitted = False self.lock = threading.Lock() - self.account = None + self.account: str | None = None self.predict = predict self.overwrite = overwrite self.string_tokenizer = string_tokenizer - def __call__(self, importer, file, imported_entries, existing_entries): + def __call__( + self, + importer: Importer, + file: str, + imported_entries: data.Directives, + existing_entries: data.Directives, + ) -> data.Directives: """Predict attributes for imported transactions. Args: @@ -157,7 +164,7 @@ def targets(self): for entry in self.training_data ] - def define_pipeline(self): + def define_pipeline(self) -> None: """Defines the machine learning pipeline based on given weights.""" transformers = [ @@ -172,7 +179,7 @@ def define_pipeline(self): SVC(kernel="linear"), ) - def train_pipeline(self): + def train_pipeline(self) -> None: """Train the machine learning pipeline.""" self.is_fitted = False @@ -187,11 +194,14 @@ def train_pipeline(self): self.is_fitted = True logger.debug("Only one target possible.") else: + assert self.pipeline is not None self.pipeline.fit(self.training_data, self.targets) self.is_fitted = True logger.debug("Trained the machine learning model.") - def process_entries(self, imported_entries) -> list[ALL_DIRECTIVES]: + def process_entries( + self, imported_entries: data.Directives + ) -> data.Directives: """Process imported entries. Transactions might be modified, all other entries are left as is. @@ -206,7 +216,9 @@ def process_entries(self, imported_entries) -> list[ALL_DIRECTIVES]: imported_entries, enhanced_transactions ) - def apply_prediction(self, entry, prediction): + def apply_prediction( + self, entry: data.Transaction, prediction: Any + ) -> data.Transaction: """Apply a single prediction to an entry. Args: diff --git a/tests/data_test.py b/tests/data_test.py index f219f04..f21705b 100644 --- a/tests/data_test.py +++ b/tests/data_test.py @@ -1,11 +1,15 @@ """Tests for the `PredictPostings` decorator""" +from __future__ import annotations + # pylint: disable=missing-docstring import os import pprint import re +from typing import Callable import pytest +from beancount.core import data from beancount.core.compare import stable_hash_namedtuple from beancount.parser import parser from beangulp import Importer @@ -13,17 +17,19 @@ from smart_importer import PredictPostings, apply_hooks -def chinese_string_tokenizer(pre_tokenizer_string): +def chinese_string_tokenizer(pre_tokenizer_string: str) -> list[str]: jieba = pytest.importorskip("jieba") jieba.initialize() return list(jieba.cut(pre_tokenizer_string)) -def _hash(entry): +def _hash(entry: data.Directive) -> str: return stable_hash_namedtuple(entry, ignore={"meta", "units"}) -def _load_testset(testset): +def _load_testset( + testset: str, +) -> tuple[data.Directives, data.Directives, data.Directives]: path = os.path.join( os.path.dirname(__file__), "data", testset + ".beancount" ) @@ -35,7 +41,7 @@ def _load_testset(testset): assert not errors parsed_sections.append(entries) assert len(parsed_sections) == 3 - return parsed_sections + return tuple(parsed_sections) @pytest.mark.parametrize( @@ -47,25 +53,27 @@ def _load_testset(testset): ("chinese", chinese_string_tokenizer), ], ) -def test_testset(testset, string_tokenizer): +def test_testset( + testset: str, string_tokenizer: Callable[[str], list[str]] +) -> None: # pylint: disable=unbalanced-tuple-unpacking imported, training_data, expected = _load_testset(testset) class DummyImporter(Importer): - def extract(self, filepath, existing=None): + def extract( + self, filepath: str, existing: data.Directives + ) -> data.Directives: return imported - def account(self, filepath): + def account(self, filepath: str) -> str: return "" - def identify(self, filepath): + def identify(self, filepath: str) -> bool: return True importer = DummyImporter() apply_hooks(importer, [PredictPostings(string_tokenizer=string_tokenizer)]) - imported_transactions = importer.extract( - "dummy-data", existing=training_data - ) + imported_transactions = importer.extract("dummy-data", training_data) for txn1, txn2 in zip(imported_transactions, expected): if _hash(txn1) != _hash(txn2): diff --git a/tests/predictors_test.py b/tests/predictors_test.py index dd3de51..4162d0a 100644 --- a/tests/predictors_test.py +++ b/tests/predictors_test.py @@ -1,6 +1,7 @@ """Tests for the `PredictPayees` and the `PredictPostings` decorator""" # pylint: disable=missing-docstring +from beancount.core import data from beancount.parser import parser from beangulp import Importer @@ -133,7 +134,9 @@ class BasicTestImporter(Importer): - def extract(self, filepath, existing=None): + def extract( + self, filepath: str, existing: data.Directives + ) -> data.Directives: if filepath == "dummy-data": return TEST_DATA if filepath == "empty": @@ -141,10 +144,10 @@ def extract(self, filepath, existing=None): assert False return [] - def account(self, filepath): + def account(self, filepath: str) -> str: return "Assets:US:BofA:Checking" - def identify(self, filepath): + def identify(self, filepath: str) -> bool: return True @@ -155,39 +158,38 @@ def identify(self, filepath): ) -def test_empty_training_data(): +def test_empty_training_data() -> None: """ Verifies that the decorator leaves the narration intact. """ - assert POSTING_IMPORTER.extract("dummy-data") == TEST_DATA - assert PAYEE_IMPORTER.extract("dummy-data") == TEST_DATA + assert POSTING_IMPORTER.extract("dummy-data", []) == TEST_DATA + assert PAYEE_IMPORTER.extract("dummy-data", []) == TEST_DATA -def test_no_transactions(): +def test_no_transactions() -> None: """ Should not crash when passed empty list of transactions. """ - POSTING_IMPORTER.extract("empty") - PAYEE_IMPORTER.extract("empty") - POSTING_IMPORTER.extract("empty", existing=TRAINING_DATA) - PAYEE_IMPORTER.extract("empty", existing=TRAINING_DATA) + POSTING_IMPORTER.extract("empty", []) + PAYEE_IMPORTER.extract("empty", []) + POSTING_IMPORTER.extract("empty", TRAINING_DATA) + PAYEE_IMPORTER.extract("empty", TRAINING_DATA) -def test_unchanged_narrations(): +def test_unchanged_narrations() -> None: """ Verifies that the decorator leaves the narration intact """ correct_narrations = [transaction.narration for transaction in TEST_DATA] extracted_narrations = [ transaction.narration - for transaction in PAYEE_IMPORTER.extract( - "dummy-data", existing=TRAINING_DATA - ) + for transaction in PAYEE_IMPORTER.extract("dummy-data", TRAINING_DATA) + if isinstance(transaction, data.Transaction) ] assert extracted_narrations == correct_narrations -def test_unchanged_first_posting(): +def test_unchanged_first_posting() -> None: """ Verifies that the decorator leaves the first posting intact """ @@ -196,30 +198,32 @@ def test_unchanged_first_posting(): ] extracted_first_postings = [ transaction.postings[0] - for transaction in PAYEE_IMPORTER.extract( - "dummy-data", existing=TRAINING_DATA - ) + for transaction in PAYEE_IMPORTER.extract("dummy-data", TRAINING_DATA) + if isinstance(transaction, data.Transaction) ] assert extracted_first_postings == correct_first_postings -def test_payee_predictions(): +def test_payee_predictions() -> None: """ Verifies that the decorator adds predicted postings. """ - transactions = PAYEE_IMPORTER.extract("dummy-data", existing=TRAINING_DATA) - predicted_payees = [transaction.payee for transaction in transactions] + transactions = PAYEE_IMPORTER.extract("dummy-data", TRAINING_DATA) + predicted_payees = [ + transaction.payee + for transaction in transactions + if isinstance(transaction, data.Transaction) + ] assert predicted_payees == PAYEE_PREDICTIONS -def test_account_predictions(): +def test_account_predictions() -> None: """ Verifies that the decorator adds predicted postings. """ predicted_accounts = [ entry.postings[-1].account - for entry in POSTING_IMPORTER.extract( - "dummy-data", existing=TRAINING_DATA - ) + for entry in POSTING_IMPORTER.extract("dummy-data", TRAINING_DATA) + if isinstance(entry, data.Transaction) ] assert predicted_accounts == ACCOUNT_PREDICTIONS