diff --git a/zakat/zakat_tracker.py b/zakat/zakat_tracker.py index e26f134..263b0d1 100644 --- a/zakat/zakat_tracker.py +++ b/zakat/zakat_tracker.py @@ -2493,8 +2493,8 @@ class Account(db.Entity): name = pony.Required(pony.LongStr, unique=True) balance = pony.Optional(int, default=0) count = pony.Optional(int, default=0) - hide = pony.Optional(bool, default='false') - zakatable = pony.Optional(bool, default='true') + hide = pony.Optional(bool, default=False) + zakatable = pony.Optional(bool, default=True) created_at = pony.Required(datetime.datetime, default=lambda: datetime.datetime.now()) updated_at = pony.Optional(datetime.datetime) box = pony.Set('Box') @@ -2698,15 +2698,15 @@ def exchanges(self, account: int) -> dict | None: pass def account(self, name: str, ref: int = None) -> tuple[int, str]: - account = self.db.select( - table_name='account', - columns=['id', 'name'], - where=f'name = "{name}"', - ) - if account: - return account[0]['id'], account[0]['name'] - account = self.db.insert(table_name='account', data={'name': name}) - return account['id'], account['name'] + with pony.db_session: + if name: + account = Account.get(name=name) + if ref: + account = Account[ref] + if account: + return account.id, account.name + account = Account(name=name) + return account.id, account.name def transfer(self, unscaled_amount: float | int | Decimal, from_account: int, to_account: int, desc: str = '', created: int = None, debug: bool = False) -> list[int]: @@ -2781,17 +2781,22 @@ def daily_logs(self, weekday: WeekDay = WeekDay.Friday, debug: bool = False): def export_json(self, path: str = "data.json") -> bool: pass + @pony.db_session() def vault(self, section: Vault = Vault.ALL) -> dict: match section: case Vault.ACCOUNT: - return {} + return Account.select() case Vault.NAME: - return {} + return Account.select() case Vault.HISTORY: - return {} + return History.select() case Vault.REPORT: - return {} - return {} + return Report.select() + return { + 'account': Account.select(), + 'history': History.select(), + 'report': Report.select(), + } def snapshot(self) -> bool: pass @@ -3308,6 +3313,8 @@ def _test_core(self, restore=False, debug=False): if debug: print(f'index = {index + 1}, ref = {ref}') assert index + 1 == ref + print(self.db.vault(Vault.ACCOUNT)) + #raise 123 assert index + 1 in self.db.vault(Vault.ACCOUNT) assert name == self.db.vault(Vault.ACCOUNT)[index + 1]['name'] account_z_ref, account_z_name = self.db.account(name='z')