Skip to content

Commit

Permalink
use more types; use Adapter for compat with old importers in hooks
Browse files Browse the repository at this point in the history
  • Loading branch information
yagebu committed Jan 3, 2025
1 parent 8828038 commit 100bcae
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 53 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ repos:
- id: ruff-format

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.11.2
rev: v1.14.1
hooks:
- id: mypy
args: [--install-types, --non-interactive, --ignore-missing-imports]
8 changes: 8 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@ build-backend = "setuptools.build_meta"
[tool.black]
line-length = 79

[[tool.mypy.overrides]]
module = ["beancount.*"]
follow_untyped_imports = true

[[tool.mypy.overrides]]
module = ["beangulp.*"]
follow_untyped_imports = true

[tool.ruff]
target-version = "py38"
line-length = 79
Expand Down
31 changes: 27 additions & 4 deletions smart_importer/hooks.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,27 @@
"""Importer decorators."""

from __future__ import annotations

import logging
from functools import wraps
from typing import Callable, Sequence

from beancount.core import data
from beangulp import Adapter, Importer, ImporterProtocol

logger = logging.getLogger(__name__) # pylint: disable=invalid-name


class ImporterHook:
"""Interface for an importer hook."""

def __call__(self, importer, file, imported_entries, existing):
def __call__(
self,
importer: Importer,
file: str,
imported_entries: data.Directives,
existing: data.Directives,
) -> data.Directives:
"""Apply the hook and modify the imported entries.
Args:
Expand All @@ -25,20 +37,31 @@ def __call__(self, importer, file, imported_entries, existing):
raise NotImplementedError


def apply_hooks(importer, hooks):
def apply_hooks(
importer: Importer | ImporterProtocol,
hooks: Sequence[
Callable[
[Importer, str, data.Directives, data.Directives], data.Directives
]
],
) -> Importer:
"""Apply a list of importer hooks to an importer.
Args:
importer: An importer instance.
hooks: A list of hooks, each a callable object.
"""

if not isinstance(importer, Importer):
importer = Adapter(importer)
unpatched_extract = importer.extract

@wraps(unpatched_extract)
def patched_extract_method(filepath, existing=None):
def patched_extract_method(
filepath: str, existing: data.Directives
) -> data.Directives:
logger.debug("Calling the importer's extract method.")
imported_entries = unpatched_extract(filepath, existing=existing)
imported_entries = unpatched_extract(filepath, existing)

for hook in hooks:
imported_entries = hook(
Expand Down
34 changes: 23 additions & 11 deletions smart_importer/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@

import logging
import threading
from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING, Any, Callable

from beancount.core import data
from beancount.core.data import (
ALL_DIRECTIVES,
Close,
Open,
Transaction,
Expand All @@ -27,6 +27,7 @@
from smart_importer.pipelines import get_pipeline

if TYPE_CHECKING:
from beangulp import Importer
from sklearn import Pipeline

logger = logging.getLogger(__name__) # pylint: disable=invalid-name
Expand All @@ -53,25 +54,31 @@ class EntryPredictor(ImporterHook):

def __init__(
self,
predict=True,
overwrite=False,
predict: bool = True,
overwrite: bool = False,
string_tokenizer: Callable[[str], list] | None = None,
denylist_accounts: list[str] | None = None,
):
) -> None:
super().__init__()
self.training_data = None
self.open_accounts: dict[str, str] = {}
self.denylist_accounts = set(denylist_accounts or [])
self.pipeline: Pipeline | None = None
self.is_fitted = False
self.lock = threading.Lock()
self.account = None
self.account: str | None = None

self.predict = predict
self.overwrite = overwrite
self.string_tokenizer = string_tokenizer

def __call__(self, importer, file, imported_entries, existing_entries):
def __call__(
self,
importer: Importer,
file: str,
imported_entries: data.Directives,
existing_entries: data.Directives,
) -> data.Directives:
"""Predict attributes for imported transactions.
Args:
Expand Down Expand Up @@ -157,7 +164,7 @@ def targets(self):
for entry in self.training_data
]

def define_pipeline(self):
def define_pipeline(self) -> None:
"""Defines the machine learning pipeline based on given weights."""

transformers = [
Expand All @@ -172,7 +179,7 @@ def define_pipeline(self):
SVC(kernel="linear"),
)

def train_pipeline(self):
def train_pipeline(self) -> None:
"""Train the machine learning pipeline."""

self.is_fitted = False
Expand All @@ -187,11 +194,14 @@ def train_pipeline(self):
self.is_fitted = True
logger.debug("Only one target possible.")
else:
assert self.pipeline is not None
self.pipeline.fit(self.training_data, self.targets)
self.is_fitted = True
logger.debug("Trained the machine learning model.")

def process_entries(self, imported_entries) -> list[ALL_DIRECTIVES]:
def process_entries(
self, imported_entries: data.Directives
) -> data.Directives:
"""Process imported entries.
Transactions might be modified, all other entries are left as is.
Expand All @@ -206,7 +216,9 @@ def process_entries(self, imported_entries) -> list[ALL_DIRECTIVES]:
imported_entries, enhanced_transactions
)

def apply_prediction(self, entry, prediction):
def apply_prediction(
self, entry: data.Transaction, prediction: Any
) -> data.Transaction:
"""Apply a single prediction to an entry.
Args:
Expand Down
30 changes: 19 additions & 11 deletions tests/data_test.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,35 @@
"""Tests for the `PredictPostings` decorator"""

from __future__ import annotations

# pylint: disable=missing-docstring
import os
import pprint
import re
from typing import Callable

import pytest
from beancount.core import data
from beancount.core.compare import stable_hash_namedtuple
from beancount.parser import parser
from beangulp import Importer

from smart_importer import PredictPostings, apply_hooks


def chinese_string_tokenizer(pre_tokenizer_string):
def chinese_string_tokenizer(pre_tokenizer_string: str) -> list[str]:
jieba = pytest.importorskip("jieba")
jieba.initialize()
return list(jieba.cut(pre_tokenizer_string))


def _hash(entry):
def _hash(entry: data.Directive) -> str:
return stable_hash_namedtuple(entry, ignore={"meta", "units"})


def _load_testset(testset):
def _load_testset(
testset: str,
) -> tuple[data.Directives, data.Directives, data.Directives]:
path = os.path.join(
os.path.dirname(__file__), "data", testset + ".beancount"
)
Expand All @@ -35,7 +41,7 @@ def _load_testset(testset):
assert not errors
parsed_sections.append(entries)
assert len(parsed_sections) == 3
return parsed_sections
return tuple(parsed_sections)


@pytest.mark.parametrize(
Expand All @@ -47,25 +53,27 @@ def _load_testset(testset):
("chinese", chinese_string_tokenizer),
],
)
def test_testset(testset, string_tokenizer):
def test_testset(
testset: str, string_tokenizer: Callable[[str], list[str]]
) -> None:
# pylint: disable=unbalanced-tuple-unpacking
imported, training_data, expected = _load_testset(testset)

class DummyImporter(Importer):
def extract(self, filepath, existing=None):
def extract(
self, filepath: str, existing: data.Directives
) -> data.Directives:
return imported

def account(self, filepath):
def account(self, filepath: str) -> str:
return ""

def identify(self, filepath):
def identify(self, filepath: str) -> bool:
return True

importer = DummyImporter()
apply_hooks(importer, [PredictPostings(string_tokenizer=string_tokenizer)])
imported_transactions = importer.extract(
"dummy-data", existing=training_data
)
imported_transactions = importer.extract("dummy-data", training_data)

for txn1, txn2 in zip(imported_transactions, expected):
if _hash(txn1) != _hash(txn2):
Expand Down
Loading

0 comments on commit 100bcae

Please sign in to comment.