Skip to content

Commit

Permalink
test: SQLAlchemy 2.0 Style
Browse files Browse the repository at this point in the history
  • Loading branch information
purplesmoke05 committed Aug 7, 2023
1 parent 2aabcc0 commit 9857429
Show file tree
Hide file tree
Showing 66 changed files with 2,479 additions and 2,001 deletions.
10 changes: 5 additions & 5 deletions app/routers/share.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,20 +500,20 @@ def list_share_history(
total = db.scalar(select(func.count()).select_from(stmt.subquery()))

if request_query.trigger:
stmt = stmt.filter(UpdateToken.trigger == request_query.trigger)
stmt = stmt.where(UpdateToken.trigger == request_query.trigger)
if request_query.modified_contents:
stmt = stmt.filter(
stmt = stmt.where(
cast(UpdateToken.arguments, String).like(
"%" + request_query.modified_contents + "%"
)
)
if request_query.created_from:
stmt = stmt.filter(
stmt = stmt.where(
UpdateToken.created
>= local_tz.localize(request_query.created_from).astimezone(utc_tz)
)
if request_query.created_to:
stmt = stmt.filter(
stmt = stmt.where(
UpdateToken.created
<= local_tz.localize(request_query.created_to).astimezone(utc_tz)
)
Expand Down Expand Up @@ -1618,7 +1618,7 @@ def list_all_holders(
IDXLockedPosition.account_address == IDXPosition.account_address,
),
)
.filter(IDXPosition.token_address == token_address)
.where(IDXPosition.token_address == token_address)
.group_by(
IDXPosition.id,
IDXLockedPosition.token_address,
Expand Down
15 changes: 9 additions & 6 deletions tests/model/blockchain/test_token_IbetShare.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import pytest
from eth_keyfile import decode_keyfile_json
from pydantic.error_wrappers import ValidationError
from sqlalchemy import select
from web3.exceptions import (
ContractLogicError,
InvalidAddress,
Expand Down Expand Up @@ -603,7 +604,7 @@ def test_normal_1(self, db):
assert share_contract.is_canceled is False
assert share_contract.memo == ""

_token_attr_update = db.query(TokenAttrUpdate).first()
_token_attr_update = db.scalars(select(TokenAttrUpdate).limit(1)).first()
assert _token_attr_update is None

# <Normal_2>
Expand Down Expand Up @@ -680,7 +681,7 @@ def test_normal_2(self, db):
assert share_contract.principal_value == 9000
assert share_contract.is_canceled is True
assert share_contract.memo == "memo_test"
_token_attr_update = db.query(TokenAttrUpdate).first()
_token_attr_update = db.scalars(select(TokenAttrUpdate).limit(1)).first()
assert _token_attr_update.id == 1
assert _token_attr_update.token_address == contract_address
assert _token_attr_update.updated_datetime > pre_datetime
Expand Down Expand Up @@ -1321,7 +1322,7 @@ def test_normal_1(self, db):
balance = share_contract.get_account_balance(issuer_address)
assert balance == arguments[3] + 10

_token_attr_update = db.query(TokenAttrUpdate).first()
_token_attr_update = db.scalars(select(TokenAttrUpdate).limit(1)).first()
assert _token_attr_update.id == 1
assert _token_attr_update.token_address == contract_address
assert _token_attr_update.updated_datetime > pre_datetime
Expand Down Expand Up @@ -1637,7 +1638,7 @@ def test_normal_1(self, db):
balance = share_contract.get_account_balance(issuer_address)
assert balance == arguments[3] - 10

_token_attr_update = db.query(TokenAttrUpdate).first()
_token_attr_update = db.scalars(select(TokenAttrUpdate).limit(1)).first()
assert _token_attr_update.id == 1
assert _token_attr_update.token_address == contract_address
assert _token_attr_update.updated_datetime > pre_datetime
Expand Down Expand Up @@ -2070,7 +2071,7 @@ def test_normal_1(self, db):
share_contract.record_attr_update(db)

# assertion
_update = db.query(TokenAttrUpdate).first()
_update = db.scalars(select(TokenAttrUpdate).limit(1)).first()
assert _update.id == 1
assert _update.token_address == self.token_address
assert _update.updated_datetime == datetime(2021, 4, 27, 12, 34, 56)
Expand All @@ -2093,7 +2094,9 @@ def test_normal_2(self, db, freezer):
share_contract.record_attr_update(db)

# assertion
_update = db.query(TokenAttrUpdate).filter(TokenAttrUpdate.id == 2).first()
_update = db.scalars(
select(TokenAttrUpdate).where(TokenAttrUpdate.id == 2).limit(1)
).first()
assert _update.id == 2
assert _update.token_address == self.token_address
assert _update.updated_datetime == datetime(2021, 4, 27, 12, 34, 56)
Expand Down
15 changes: 9 additions & 6 deletions tests/model/blockchain/test_token_IbetStraightBond.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import pytest
from eth_keyfile import decode_keyfile_json
from pydantic.error_wrappers import ValidationError
from sqlalchemy import select
from web3.exceptions import (
ContractLogicError,
InvalidAddress,
Expand Down Expand Up @@ -681,7 +682,7 @@ def test_normal_1(self, db):
assert bond_contract.transfer_approval_required is False
assert bond_contract.memo == ""

_token_attr_update = db.query(TokenAttrUpdate).first()
_token_attr_update = db.scalars(select(TokenAttrUpdate).limit(1)).first()
assert _token_attr_update is None

# <Normal_2>
Expand Down Expand Up @@ -770,7 +771,7 @@ def test_normal_2(self, db):
assert bond_contract.transfer_approval_required is True
assert bond_contract.memo == "memo test"

_token_attr_update = db.query(TokenAttrUpdate).first()
_token_attr_update = db.scalars(select(TokenAttrUpdate).limit(1)).first()
assert _token_attr_update.id == 1
assert _token_attr_update.token_address == contract_address
assert _token_attr_update.updated_datetime > pre_datetime
Expand Down Expand Up @@ -1431,7 +1432,7 @@ def test_normal_1(self, db):
balance = bond_contract.get_account_balance(issuer_address)
assert balance == arguments[2] + 10

_token_attr_update = db.query(TokenAttrUpdate).first()
_token_attr_update = db.scalars(select(TokenAttrUpdate).limit(1)).first()
assert _token_attr_update.id == 1
assert _token_attr_update.token_address == contract_address
assert _token_attr_update.updated_datetime > pre_datetime
Expand Down Expand Up @@ -1747,7 +1748,7 @@ def test_normal_1(self, db):
balance = bond_contract.get_account_balance(issuer_address)
assert balance == arguments[2] - 10

_token_attr_update = db.query(TokenAttrUpdate).first()
_token_attr_update = db.scalars(select(TokenAttrUpdate).limit(1)).first()
assert _token_attr_update.id == 1
assert _token_attr_update.token_address == contract_address
assert _token_attr_update.updated_datetime > pre_datetime
Expand Down Expand Up @@ -2180,7 +2181,7 @@ def test_normal_1(self, db):
bond_contract.record_attr_update(db)

# assertion
_update = db.query(TokenAttrUpdate).first()
_update = db.scalars(select(TokenAttrUpdate).limit(1)).first()
assert _update.id == 1
assert _update.token_address == self.token_address
assert _update.updated_datetime == datetime(2021, 4, 27, 12, 34, 56)
Expand All @@ -2203,7 +2204,9 @@ def test_normal_2(self, db, freezer):
bond_contract.record_attr_update(db)

# assertion
_update = db.query(TokenAttrUpdate).filter(TokenAttrUpdate.id == 2).first()
_update = db.scalars(
select(TokenAttrUpdate).where(TokenAttrUpdate.id == 2).limit(1)
).first()
assert _update.id == 2
assert _update.token_address == self.token_address
assert _update.updated_datetime == datetime(2021, 4, 27, 12, 34, 56)
Expand Down
10 changes: 6 additions & 4 deletions tests/test_app_routers_accounts_POST.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import base64
from unittest import mock

from sqlalchemy import select

from app.model.db import Account, AccountRsaStatus
from app.utils.e2ee_utils import E2EEUtils
from config import EOA_PASSWORD_PATTERN_MSG
Expand All @@ -37,7 +39,7 @@ class TestAppRoutersAccountsPOST:

# <Normal_1>
def test_normal_1(self, client, db):
accounts_before = db.query(Account).all()
accounts_before = db.scalars(select(Account)).all()

password = self.valid_password
req_param = {"eoa_password": E2EEUtils.encrypt(password)}
Expand All @@ -51,7 +53,7 @@ def test_normal_1(self, client, db):
assert resp.json()["rsa_status"] == AccountRsaStatus.UNSET.value
assert resp.json()["is_deleted"] is False

accounts_after = db.query(Account).all()
accounts_after = db.scalars(select(Account)).all()

assert 0 == len(accounts_before)
assert 1 == len(accounts_after)
Expand All @@ -70,7 +72,7 @@ def test_normal_1(self, client, db):
@mock.patch("app.routers.account.AWS_KMS_GENERATE_RANDOM_ENABLED", True)
@mock.patch("boto3.client")
def test_normal_2(self, boto3_mock, client, db):
accounts_before = db.query(Account).all()
accounts_before = db.scalars(select(Account)).all()

password = self.valid_password
req_param = {"eoa_password": E2EEUtils.encrypt(password)}
Expand All @@ -92,7 +94,7 @@ def generate_random(self, NumberOfBytes):
assert resp.json()["rsa_status"] == AccountRsaStatus.UNSET.value
assert resp.json()["is_deleted"] is False

accounts_after = db.query(Account).all()
accounts_after = db.scalars(select(Account)).all()

assert 0 == len(accounts_before)
assert 1 == len(accounts_after)
Expand Down
4 changes: 3 additions & 1 deletion tests/test_app_routers_accounts_{issuer_address}_DELETE.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
SPDX-License-Identifier: Apache-2.0
"""
from sqlalchemy import select

from app.model.db import Account, AccountRsaStatus
from tests.account_config import config_eth_account

Expand Down Expand Up @@ -53,7 +55,7 @@ def test_normal_1(self, client, db):
"rsa_status": AccountRsaStatus.UNSET.value,
"is_deleted": True,
}
_account_after = db.query(Account).first()
_account_after = db.scalars(select(Account).limit(1)).first()
assert _account_after.issuer_address == _admin_account["address"]
assert _account_after.is_deleted == True

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import hashlib
from datetime import datetime

from sqlalchemy import select

from app.model.db import Account, AuthToken
from app.utils.e2ee_utils import E2EEUtils
from tests.account_config import config_eth_account
Expand Down Expand Up @@ -63,11 +65,11 @@ def test_normal_1(self, client, db):
# assertion
assert resp.status_code == 200

auth_token: AuthToken = (
db.query(AuthToken)
.filter(AuthToken.issuer_address == test_account["address"])
.first()
)
auth_token: AuthToken = db.scalars(
select(AuthToken)
.where(AuthToken.issuer_address == test_account["address"])
.limit(1)
).first()
assert auth_token is None

# Normal_2
Expand Down Expand Up @@ -97,11 +99,11 @@ def test_normal_2(self, client, db):
# assertion
assert resp.status_code == 200

auth_token: AuthToken = (
db.query(AuthToken)
.filter(AuthToken.issuer_address == test_account["address"])
.first()
)
auth_token: AuthToken = db.scalars(
select(AuthToken)
.where(AuthToken.issuer_address == test_account["address"])
.limit(1)
).first()
assert auth_token is None

###########################################################################
Expand Down Expand Up @@ -147,11 +149,11 @@ def test_error_1(self, client, db):
],
}

auth_token: AuthToken = (
db.query(AuthToken)
.filter(AuthToken.issuer_address == test_account["address"])
.first()
)
auth_token: AuthToken = db.scalars(
select(AuthToken)
.where(AuthToken.issuer_address == test_account["address"])
.limit(1)
).first()
assert auth_token is not None

# Error_2
Expand Down Expand Up @@ -193,11 +195,11 @@ def test_error_2(self, client, db):
],
}

auth_token: AuthToken = (
db.query(AuthToken)
.filter(AuthToken.issuer_address == test_account["address"])
.first()
)
auth_token: AuthToken = db.scalars(
select(AuthToken)
.where(AuthToken.issuer_address == test_account["address"])
.limit(1)
).first()
assert auth_token is not None

# Error_3_1
Expand Down Expand Up @@ -232,11 +234,11 @@ def test_error_3_1(self, client, db):
"detail": "issuer does not exist, or password mismatch",
}

auth_token: AuthToken = (
db.query(AuthToken)
.filter(AuthToken.issuer_address == test_account["address"])
.first()
)
auth_token: AuthToken = db.scalars(
select(AuthToken)
.where(AuthToken.issuer_address == test_account["address"])
.limit(1)
).first()
assert auth_token is not None

# Error_3_2
Expand Down Expand Up @@ -274,11 +276,11 @@ def test_error_3_2(self, client, db):
"detail": "issuer does not exist, or password mismatch",
}

auth_token: AuthToken = (
db.query(AuthToken)
.filter(AuthToken.issuer_address == test_account["address"])
.first()
)
auth_token: AuthToken = db.scalars(
select(AuthToken)
.where(AuthToken.issuer_address == test_account["address"])
.limit(1)
).first()
assert auth_token is not None

# Error_4
Expand Down
Loading

0 comments on commit 9857429

Please sign in to comment.