Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

68 update institutions table and daos institution type fields #75

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""create fi_to_type_mapping table

Revision ID: ada681e1877f
Revises: 383ab402c8c2
Create Date: 2023-12-29 12:33:11.031470

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
from db_revisions.utils import table_exists


# revision identifiers, used by Alembic.
revision: str = "ada681e1877f"
down_revision: Union[str, None] = "383ab402c8c2"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
if not table_exists("fi_to_type_mapping"):
op.create_table(
"fi_to_type_mapping",
sa.Column("fi_id", sa.String(), sa.ForeignKey("financial_institutions.lei"), primary_key=True),
sa.Column("type_id", sa.String(), sa.ForeignKey("sbl_institution_type.id"), primary_key=True),
)
with op.batch_alter_table("financial_institutions") as batch_op:
batch_op.drop_constraint("fk_sbl_institution_type_financial_institutions", type_="foreignkey")
batch_op.drop_index(op.f("ix_financial_institutions_sbl_institution_type_id"))
batch_op.drop_column("sbl_institution_type_id")


def downgrade() -> None:
op.drop_table("fi_to_type_mapping")
with op.batch_alter_table("financial_institutions") as batch_op:
batch_op.add_column(sa.Column("sbl_institution_type_id", sa.String(), nullable=True))
batch_op.create_foreign_key(
"fk_sbl_institution_type_financial_institutions",
"sbl_institution_type",
["sbl_institution_type_id"],
["id"],
)
batch_op.create_index(
op.f("ix_financial_institutions_sbl_institution_type_id"),
["sbl_institution_type_id"],
unique=False,
)
17 changes: 14 additions & 3 deletions src/entities/models/dao.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from datetime import datetime
from typing import List
from sqlalchemy import ForeignKey, func, String
from sqlalchemy import ForeignKey, func, String, Table, Column
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.ext.associationproxy import association_proxy, AssociationProxy


class Base(AsyncAttrs, DeclarativeBase):
Expand All @@ -14,6 +15,14 @@ class AuditMixin(object):
event_time: Mapped[datetime] = mapped_column(server_default=func.now())


fi_to_type_mapping = Table(
"fi_to_type_mapping",
Base.metadata,
Column("fi_id", ForeignKey("financial_institutions.lei"), primary_key=True),
Column("type_id", ForeignKey("sbl_institution_type.id"), primary_key=True),
)


class FinancialInstitutionDao(AuditMixin, Base):
__tablename__ = "financial_institutions"
lei: Mapped[str] = mapped_column(unique=True, index=True, primary_key=True)
Expand All @@ -28,8 +37,10 @@ class FinancialInstitutionDao(AuditMixin, Base):
primary_federal_regulator: Mapped["FederalRegulatorDao"] = relationship(lazy="selectin")
hmda_institution_type_id: Mapped[str] = mapped_column(ForeignKey("hmda_institution_type.id"), nullable=True)
hmda_institution_type: Mapped["HMDAInstitutionTypeDao"] = relationship(lazy="selectin")
sbl_institution_type_id: Mapped[str] = mapped_column(ForeignKey("sbl_institution_type.id"), nullable=True)
sbl_institution_type: Mapped["SBLInstitutionTypeDao"] = relationship(lazy="selectin")
sbl_institution_types: Mapped[List["SBLInstitutionTypeDao"]] = relationship(
lazy="selectin", secondary=fi_to_type_mapping
)
sbl_institution_type_ids: AssociationProxy[List[str]] = association_proxy("sbl_institution_types", "id")
hq_address_street_1: Mapped[str]
hq_address_street_2: Mapped[str] = mapped_column(nullable=True)
hq_address_city: Mapped[str]
Expand Down
4 changes: 2 additions & 2 deletions src/entities/models/dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class FinancialInstitutionDto(FinancialInstitutionBase):
rssd_id: int | None = None
primary_federal_regulator_id: str | None = None
hmda_institution_type_id: str | None = None
sbl_institution_type_id: str | None = None
sbl_institution_type_ids: List[str] = []
hq_address_street_1: str
hq_address_street_2: str | None = None
hq_address_city: str
Expand Down Expand Up @@ -95,7 +95,7 @@ class Config:
class FinancialInstitutionWithRelationsDto(FinancialInstitutionDto):
primary_federal_regulator: FederalRegulatorDto | None = None
hmda_institution_type: InstitutionTypeDto | None = None
sbl_institution_type: InstitutionTypeDto | None = None
sbl_institution_types: List[InstitutionTypeDto] = []
hq_address_state: AddressStateDto
domains: List[FinancialInsitutionDomainDto] = []

Expand Down
11 changes: 11 additions & 0 deletions src/entities/repos/institutions_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,17 @@ async def upsert_institution(session: AsyncSession, fi: FinancialInstitutionDto)
async with session.begin():
fi_data = fi.__dict__.copy()
fi_data.pop("_sa_instance_state", None)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With the update to the upsert using merge, this is causing the merge conflicts.
I have mine like:

    async with session.begin():
        fi_data = fi.__dict__.copy()
        fi_data.pop("_sa_instance_state", None)
        sbl_types = []
        if len(fi.sbl_institution_type_ids):
            sbl_type_stmt = select(SBLInstitutionTypeDao).filter(
                SBLInstitutionTypeDao.id.in_(fi.sbl_institution_type_ids)
            )
            sbl_types = (await session.scalars(sbl_type_stmt)).all()
            del fi_data["sbl_institution_type_ids"]
        new_fi = FinancialInstitutionDao(**fi_data)
        new_fi.sbl_institution_types.extend(sbl_types)
        db_fi = await session.merge(new_fi)
        await session.flush([db_fi])
        await session.refresh(db_fi)
        return db_fi

# Populate with model objects from SBLInstitutionTypeDao and clear out
# the id field since it's just a view
if "sbl_institution_type_ids" in fi_data:
sbl_type_stmt = select(SBLInstitutionTypeDao).filter(
SBLInstitutionTypeDao.id.in_(fi_data["sbl_institution_type_ids"])
)
sbl_types = await session.scalars(sbl_type_stmt)
fi_data["sbl_institution_types"] = sbl_types.all()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One thing to keep in mind with not initializing this field out of the if check is any previous relationships would remain the same if the field if not specified. Since this function will only be used by internal processes, I'm ok with this behavior, but it is worth noting this is different from the rest of the fields.
I think we're good to merge now once the commented out code are cleaned up.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tested wrong, not an issue; we're good to go after clean up.

del fi_data["sbl_institution_type_ids"]

db_fi = await session.merge(FinancialInstitutionDao(**fi_data))
await session.flush([db_fi])
await session.refresh(db_fi)
Expand Down
3 changes: 1 addition & 2 deletions tests/api/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,7 @@ def get_institutions_mock(mocker: MockerFixture) -> Mock:
primary_federal_regulator=FederalRegulatorDao(id="FRI1", name="FRI1"),
hmda_institution_type_id="HIT1",
hmda_institution_type=HMDAInstitutionTypeDao(id="HIT1", name="HIT1"),
sbl_institution_type_id="SIT1",
sbl_institution_type=SBLInstitutionTypeDao(id="SIT1", name="SIT1"),
sbl_institution_types=[SBLInstitutionTypeDao(id="SIT1", name="SIT1")],
hq_address_street_1="Test Address Street 1",
hq_address_street_2="",
hq_address_city="Test City 1",
Expand Down
16 changes: 6 additions & 10 deletions tests/api/routers/test_institutions_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,7 @@ def test_create_institution_authed(self, mocker: MockerFixture, app_fixture: Fas
primary_federal_regulator=FederalRegulatorDao(id="FRI2", name="FRI2"),
hmda_institution_type_id="HIT2",
hmda_institution_type=HMDAInstitutionTypeDao(id="HIT2", name="HIT2"),
sbl_institution_type_id="SIT2",
sbl_institution_type=SBLInstitutionTypeDao(id="SIT2", name="SIT2"),
sbl_institution_types=[SBLInstitutionTypeDao(id="SIT2", name="SIT2")],
hq_address_street_1="Test Address Street 1",
hq_address_street_2="",
hq_address_city="Test City 1",
Expand All @@ -75,7 +74,7 @@ def test_create_institution_authed(self, mocker: MockerFixture, app_fixture: Fas
"rssd_id": 12344,
"primary_federal_regulator_id": "FRI2",
"hmda_institution_type_id": "HIT2",
"sbl_institution_type_id": "SIT2",
"sbl_institution_type_ids": ["SIT2"],
"hq_address_street_1": "Test Address Street 1",
"hq_address_street_2": "",
"hq_address_city": "Test City 1",
Expand Down Expand Up @@ -163,7 +162,7 @@ def test_create_institution_authed_no_permission(self, app_fixture: FastAPI, aut
"rssd_id": 12344,
"primary_federal_regulator_id": "FIR2",
"hmda_institution_type_id": "HIT2",
"sbl_institution_type_id": "SIT2",
"sbl_institution_type_ids": ["SIT2"],
"hq_address_street_1": "Test Address Street 1",
"hq_address_street_2": "",
"hq_address_city": "Test City 1",
Expand Down Expand Up @@ -198,8 +197,7 @@ def test_get_institution_authed(self, mocker: MockerFixture, app_fixture: FastAP
primary_federal_regulator=FederalRegulatorDao(id="FRI1", name="FRI1"),
hmda_institution_type_id="HIT1",
hmda_institution_type=HMDAInstitutionTypeDao(id="HIT1", name="HIT1"),
sbl_institution_type_id="SIT1",
sbl_institution_type=SBLInstitutionTypeDao(id="SIT1", name="SIT1"),
sbl_institution_types=[SBLInstitutionTypeDao(id="SIT1", name="SIT1")],
hq_address_street_1="Test Address Street 1",
hq_address_street_2="",
hq_address_city="Test City 1",
Expand Down Expand Up @@ -296,8 +294,7 @@ def test_get_associated_institutions(
primary_federal_regulator=FederalRegulatorDao(id="FRI1", name="FRI1"),
hmda_institution_type_id="HIT1",
hmda_institution_type=HMDAInstitutionTypeDao(id="HIT1", name="HIT1"),
sbl_institution_type_id="SIT1",
sbl_institution_type=SBLInstitutionTypeDao(id="SIT1", name="SIT1"),
sbl_institution_types=[SBLInstitutionTypeDao(id="SIT1", name="SIT1")],
hq_address_street_1="Test Address Street 1",
hq_address_street_2="",
hq_address_city="Test City 1",
Expand All @@ -322,8 +319,7 @@ def test_get_associated_institutions(
primary_federal_regulator=FederalRegulatorDao(id="FRI1", name="FRI1"),
hmda_institution_type_id="HIT1",
hmda_institution_type=HMDAInstitutionTypeDao(id="HIT1", name="HIT1"),
sbl_institution_type_id="SIT1",
sbl_institution_type=SBLInstitutionTypeDao(id="SIT1", name="SIT1"),
sbl_institution_types=[SBLInstitutionTypeDao(id="SIT1", name="SIT1")],
hq_address_street_1="Test Address Street 2",
hq_address_street_2="",
hq_address_city="Test City 2",
Expand Down
18 changes: 11 additions & 7 deletions tests/entities/repos/test_institutions_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from entities.models import (
FinancialInstitutionDao,
FinancialInstitutionDto,
FinancialInstitutionDomainDao,
FinancialInsitutionDomainCreate,
)
Expand Down Expand Up @@ -52,7 +53,7 @@ async def setup(
rssd_id=1234,
primary_federal_regulator_id="FRI1",
hmda_institution_type_id="HIT1",
sbl_institution_type_id="SIT1",
sbl_institution_types=[sbl_it_dao_sit1],
hq_address_street_1="Test Address Street 1",
hq_address_street_2="",
hq_address_city="Test City 1",
Expand All @@ -74,7 +75,7 @@ async def setup(
rssd_id=4321,
primary_federal_regulator_id="FRI2",
hmda_institution_type_id="HIT2",
sbl_institution_type_id="SIT2",
sbl_institution_types=[sbl_it_dao_sit2],
hq_address_street_1="Test Address Street 2",
hq_address_street_2="",
hq_address_city="Test City 2",
Expand All @@ -96,7 +97,7 @@ async def setup(
rssd_id=2134,
primary_federal_regulator_id="FRI3",
hmda_institution_type_id="HIT3",
sbl_institution_type_id="SIT3",
sbl_institution_types=[sbl_it_dao_sit3],
hq_address_street_1="Test Address Street 3",
hq_address_street_2="",
hq_address_city="Test City 3",
Expand Down Expand Up @@ -190,15 +191,15 @@ async def test_get_institutions_by_lei_list_item_not_existing(self, query_sessio
async def test_add_institution(self, transaction_session: AsyncSession):
db_fi = await repo.upsert_institution(
transaction_session,
FinancialInstitutionDao(
FinancialInstitutionDto(
name="New Bank 123",
lei="NEWBANK123",
is_active=True,
tax_id="654321987",
rssd_id=6543,
primary_federal_regulator_id="FRI3",
hmda_institution_type_id="HIT3",
sbl_institution_type_id="SIT3",
sbl_institution_type_ids=["SIT3"],
hq_address_street_1="Test Address Street 3",
hq_address_street_2="",
hq_address_city="Test City 3",
Expand All @@ -215,6 +216,9 @@ async def test_add_institution(self, transaction_session: AsyncSession):
assert db_fi.domains == []
res = await repo.get_institutions(transaction_session)
assert len(res) == 4
new_sbl_types = next(iter([fi for fi in res if fi.lei == "NEWBANK123"])).sbl_institution_types
assert len(new_sbl_types) == 1
assert next(iter(new_sbl_types)).name == "Test SBL Instituion ID 3"

async def test_add_institution_only_required_fields(
self, transaction_session: AsyncSession, query_session: AsyncSession
Expand Down Expand Up @@ -309,8 +313,8 @@ async def test_institution_mapped_to_hmda_it_invalid(self, query_session: AsyncS

async def test_institution_mapped_to_sbl_it_valid(self, query_session: AsyncSession):
res = await repo.get_institutions(query_session, leis=["TESTBANK123"])
assert res[0].sbl_institution_type.name == "Test SBL Instituion ID 1"
assert res[0].sbl_institution_types[0].name == "Test SBL Instituion ID 1"

async def test_institution_mapped_to_sbl_it_invalid(self, query_session: AsyncSession):
res = await repo.get_institutions(query_session, leis=["TESTBANK456"])
assert res[0].sbl_institution_type.name != "Test SBL Instituion ID 1"
assert res[0].sbl_institution_types[0].name != "Test SBL Instituion ID 1"
Loading