diff --git a/db_revisions/versions/ada681e1877f_create_fi_to_type_mapping_table.py b/db_revisions/versions/ada681e1877f_create_fi_to_type_mapping_table.py new file mode 100644 index 0000000..f5b7687 --- /dev/null +++ b/db_revisions/versions/ada681e1877f_create_fi_to_type_mapping_table.py @@ -0,0 +1,49 @@ +"""create fi_to_type_mapping table + +Revision ID: ada681e1877f +Revises: 383ab402c8c2 +Create Date: 2023-12-29 12:33:11.031470 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from db_revisions.utils import table_exists + + +# revision identifiers, used by Alembic. +revision: str = "ada681e1877f" +down_revision: Union[str, None] = "383ab402c8c2" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + if not table_exists("fi_to_type_mapping"): + op.create_table( + "fi_to_type_mapping", + sa.Column("fi_id", sa.String(), sa.ForeignKey("financial_institutions.lei"), primary_key=True), + sa.Column("type_id", sa.String(), sa.ForeignKey("sbl_institution_type.id"), primary_key=True), + ) + with op.batch_alter_table("financial_institutions") as batch_op: + batch_op.drop_constraint("fk_sbl_institution_type_financial_institutions", type_="foreignkey") + batch_op.drop_index(op.f("ix_financial_institutions_sbl_institution_type_id")) + batch_op.drop_column("sbl_institution_type_id") + + +def downgrade() -> None: + op.drop_table("fi_to_type_mapping") + with op.batch_alter_table("financial_institutions") as batch_op: + batch_op.add_column(sa.Column("sbl_institution_type_id", sa.String(), nullable=True)) + batch_op.create_foreign_key( + "fk_sbl_institution_type_financial_institutions", + "sbl_institution_type", + ["sbl_institution_type_id"], + ["id"], + ) + batch_op.create_index( + op.f("ix_financial_institutions_sbl_institution_type_id"), + ["sbl_institution_type_id"], + unique=False, + ) diff --git a/src/entities/models/dao.py b/src/entities/models/dao.py index f8c1eba..3cf3bb7 100644 --- a/src/entities/models/dao.py +++ b/src/entities/models/dao.py @@ -1,9 +1,10 @@ from datetime import datetime from typing import List -from sqlalchemy import ForeignKey, func, String +from sqlalchemy import ForeignKey, func, String, Table, Column from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import DeclarativeBase from sqlalchemy.ext.asyncio import AsyncAttrs +from sqlalchemy.ext.associationproxy import association_proxy, AssociationProxy class Base(AsyncAttrs, DeclarativeBase): @@ -14,6 +15,14 @@ class AuditMixin(object): event_time: Mapped[datetime] = mapped_column(server_default=func.now()) +fi_to_type_mapping = Table( + "fi_to_type_mapping", + Base.metadata, + Column("fi_id", ForeignKey("financial_institutions.lei"), primary_key=True), + Column("type_id", ForeignKey("sbl_institution_type.id"), primary_key=True), +) + + class FinancialInstitutionDao(AuditMixin, Base): __tablename__ = "financial_institutions" lei: Mapped[str] = mapped_column(unique=True, index=True, primary_key=True) @@ -28,8 +37,10 @@ class FinancialInstitutionDao(AuditMixin, Base): primary_federal_regulator: Mapped["FederalRegulatorDao"] = relationship(lazy="selectin") hmda_institution_type_id: Mapped[str] = mapped_column(ForeignKey("hmda_institution_type.id"), nullable=True) hmda_institution_type: Mapped["HMDAInstitutionTypeDao"] = relationship(lazy="selectin") - sbl_institution_type_id: Mapped[str] = mapped_column(ForeignKey("sbl_institution_type.id"), nullable=True) - sbl_institution_type: Mapped["SBLInstitutionTypeDao"] = relationship(lazy="selectin") + sbl_institution_types: Mapped[List["SBLInstitutionTypeDao"]] = relationship( + lazy="selectin", secondary=fi_to_type_mapping + ) + sbl_institution_type_ids: AssociationProxy[List[str]] = association_proxy("sbl_institution_types", "id") hq_address_street_1: Mapped[str] hq_address_street_2: Mapped[str] = mapped_column(nullable=True) hq_address_city: Mapped[str] diff --git a/src/entities/models/dto.py b/src/entities/models/dto.py index 192fc9f..275cbef 100644 --- a/src/entities/models/dto.py +++ b/src/entities/models/dto.py @@ -29,7 +29,7 @@ class FinancialInstitutionDto(FinancialInstitutionBase): rssd_id: int | None = None primary_federal_regulator_id: str | None = None hmda_institution_type_id: str | None = None - sbl_institution_type_id: str | None = None + sbl_institution_type_ids: List[str] = [] hq_address_street_1: str hq_address_street_2: str | None = None hq_address_city: str @@ -95,7 +95,7 @@ class Config: class FinancialInstitutionWithRelationsDto(FinancialInstitutionDto): primary_federal_regulator: FederalRegulatorDto | None = None hmda_institution_type: InstitutionTypeDto | None = None - sbl_institution_type: InstitutionTypeDto | None = None + sbl_institution_types: List[InstitutionTypeDto] = [] hq_address_state: AddressStateDto domains: List[FinancialInsitutionDomainDto] = [] diff --git a/src/entities/repos/institutions_repo.py b/src/entities/repos/institutions_repo.py index 3bb1340..679c131 100644 --- a/src/entities/repos/institutions_repo.py +++ b/src/entities/repos/institutions_repo.py @@ -71,6 +71,17 @@ async def upsert_institution(session: AsyncSession, fi: FinancialInstitutionDto) async with session.begin(): fi_data = fi.__dict__.copy() fi_data.pop("_sa_instance_state", None) + + # Populate with model objects from SBLInstitutionTypeDao and clear out + # the id field since it's just a view + if "sbl_institution_type_ids" in fi_data: + sbl_type_stmt = select(SBLInstitutionTypeDao).filter( + SBLInstitutionTypeDao.id.in_(fi_data["sbl_institution_type_ids"]) + ) + sbl_types = await session.scalars(sbl_type_stmt) + fi_data["sbl_institution_types"] = sbl_types.all() + del fi_data["sbl_institution_type_ids"] + db_fi = await session.merge(FinancialInstitutionDao(**fi_data)) await session.flush([db_fi]) await session.refresh(db_fi) diff --git a/tests/api/conftest.py b/tests/api/conftest.py index 1c95982..80a525d 100644 --- a/tests/api/conftest.py +++ b/tests/api/conftest.py @@ -70,8 +70,7 @@ def get_institutions_mock(mocker: MockerFixture) -> Mock: primary_federal_regulator=FederalRegulatorDao(id="FRI1", name="FRI1"), hmda_institution_type_id="HIT1", hmda_institution_type=HMDAInstitutionTypeDao(id="HIT1", name="HIT1"), - sbl_institution_type_id="SIT1", - sbl_institution_type=SBLInstitutionTypeDao(id="SIT1", name="SIT1"), + sbl_institution_types=[SBLInstitutionTypeDao(id="SIT1", name="SIT1")], hq_address_street_1="Test Address Street 1", hq_address_street_2="", hq_address_city="Test City 1", diff --git a/tests/api/routers/test_institutions_api.py b/tests/api/routers/test_institutions_api.py index d5f0abd..9af99dd 100644 --- a/tests/api/routers/test_institutions_api.py +++ b/tests/api/routers/test_institutions_api.py @@ -47,8 +47,7 @@ def test_create_institution_authed(self, mocker: MockerFixture, app_fixture: Fas primary_federal_regulator=FederalRegulatorDao(id="FRI2", name="FRI2"), hmda_institution_type_id="HIT2", hmda_institution_type=HMDAInstitutionTypeDao(id="HIT2", name="HIT2"), - sbl_institution_type_id="SIT2", - sbl_institution_type=SBLInstitutionTypeDao(id="SIT2", name="SIT2"), + sbl_institution_types=[SBLInstitutionTypeDao(id="SIT2", name="SIT2")], hq_address_street_1="Test Address Street 1", hq_address_street_2="", hq_address_city="Test City 1", @@ -75,7 +74,7 @@ def test_create_institution_authed(self, mocker: MockerFixture, app_fixture: Fas "rssd_id": 12344, "primary_federal_regulator_id": "FRI2", "hmda_institution_type_id": "HIT2", - "sbl_institution_type_id": "SIT2", + "sbl_institution_type_ids": ["SIT2"], "hq_address_street_1": "Test Address Street 1", "hq_address_street_2": "", "hq_address_city": "Test City 1", @@ -163,7 +162,7 @@ def test_create_institution_authed_no_permission(self, app_fixture: FastAPI, aut "rssd_id": 12344, "primary_federal_regulator_id": "FIR2", "hmda_institution_type_id": "HIT2", - "sbl_institution_type_id": "SIT2", + "sbl_institution_type_ids": ["SIT2"], "hq_address_street_1": "Test Address Street 1", "hq_address_street_2": "", "hq_address_city": "Test City 1", @@ -198,8 +197,7 @@ def test_get_institution_authed(self, mocker: MockerFixture, app_fixture: FastAP primary_federal_regulator=FederalRegulatorDao(id="FRI1", name="FRI1"), hmda_institution_type_id="HIT1", hmda_institution_type=HMDAInstitutionTypeDao(id="HIT1", name="HIT1"), - sbl_institution_type_id="SIT1", - sbl_institution_type=SBLInstitutionTypeDao(id="SIT1", name="SIT1"), + sbl_institution_types=[SBLInstitutionTypeDao(id="SIT1", name="SIT1")], hq_address_street_1="Test Address Street 1", hq_address_street_2="", hq_address_city="Test City 1", @@ -296,8 +294,7 @@ def test_get_associated_institutions( primary_federal_regulator=FederalRegulatorDao(id="FRI1", name="FRI1"), hmda_institution_type_id="HIT1", hmda_institution_type=HMDAInstitutionTypeDao(id="HIT1", name="HIT1"), - sbl_institution_type_id="SIT1", - sbl_institution_type=SBLInstitutionTypeDao(id="SIT1", name="SIT1"), + sbl_institution_types=[SBLInstitutionTypeDao(id="SIT1", name="SIT1")], hq_address_street_1="Test Address Street 1", hq_address_street_2="", hq_address_city="Test City 1", @@ -322,8 +319,7 @@ def test_get_associated_institutions( primary_federal_regulator=FederalRegulatorDao(id="FRI1", name="FRI1"), hmda_institution_type_id="HIT1", hmda_institution_type=HMDAInstitutionTypeDao(id="HIT1", name="HIT1"), - sbl_institution_type_id="SIT1", - sbl_institution_type=SBLInstitutionTypeDao(id="SIT1", name="SIT1"), + sbl_institution_types=[SBLInstitutionTypeDao(id="SIT1", name="SIT1")], hq_address_street_1="Test Address Street 2", hq_address_street_2="", hq_address_city="Test City 2", diff --git a/tests/entities/repos/test_institutions_repo.py b/tests/entities/repos/test_institutions_repo.py index 74718ed..29cc3f0 100644 --- a/tests/entities/repos/test_institutions_repo.py +++ b/tests/entities/repos/test_institutions_repo.py @@ -3,6 +3,7 @@ from entities.models import ( FinancialInstitutionDao, + FinancialInstitutionDto, FinancialInstitutionDomainDao, FinancialInsitutionDomainCreate, ) @@ -52,7 +53,7 @@ async def setup( rssd_id=1234, primary_federal_regulator_id="FRI1", hmda_institution_type_id="HIT1", - sbl_institution_type_id="SIT1", + sbl_institution_types=[sbl_it_dao_sit1], hq_address_street_1="Test Address Street 1", hq_address_street_2="", hq_address_city="Test City 1", @@ -74,7 +75,7 @@ async def setup( rssd_id=4321, primary_federal_regulator_id="FRI2", hmda_institution_type_id="HIT2", - sbl_institution_type_id="SIT2", + sbl_institution_types=[sbl_it_dao_sit2], hq_address_street_1="Test Address Street 2", hq_address_street_2="", hq_address_city="Test City 2", @@ -96,7 +97,7 @@ async def setup( rssd_id=2134, primary_federal_regulator_id="FRI3", hmda_institution_type_id="HIT3", - sbl_institution_type_id="SIT3", + sbl_institution_types=[sbl_it_dao_sit3], hq_address_street_1="Test Address Street 3", hq_address_street_2="", hq_address_city="Test City 3", @@ -190,7 +191,7 @@ async def test_get_institutions_by_lei_list_item_not_existing(self, query_sessio async def test_add_institution(self, transaction_session: AsyncSession): db_fi = await repo.upsert_institution( transaction_session, - FinancialInstitutionDao( + FinancialInstitutionDto( name="New Bank 123", lei="NEWBANK123", is_active=True, @@ -198,7 +199,7 @@ async def test_add_institution(self, transaction_session: AsyncSession): rssd_id=6543, primary_federal_regulator_id="FRI3", hmda_institution_type_id="HIT3", - sbl_institution_type_id="SIT3", + sbl_institution_type_ids=["SIT3"], hq_address_street_1="Test Address Street 3", hq_address_street_2="", hq_address_city="Test City 3", @@ -215,6 +216,9 @@ async def test_add_institution(self, transaction_session: AsyncSession): assert db_fi.domains == [] res = await repo.get_institutions(transaction_session) assert len(res) == 4 + new_sbl_types = next(iter([fi for fi in res if fi.lei == "NEWBANK123"])).sbl_institution_types + assert len(new_sbl_types) == 1 + assert next(iter(new_sbl_types)).name == "Test SBL Instituion ID 3" async def test_add_institution_only_required_fields( self, transaction_session: AsyncSession, query_session: AsyncSession @@ -309,8 +313,8 @@ async def test_institution_mapped_to_hmda_it_invalid(self, query_session: AsyncS async def test_institution_mapped_to_sbl_it_valid(self, query_session: AsyncSession): res = await repo.get_institutions(query_session, leis=["TESTBANK123"]) - assert res[0].sbl_institution_type.name == "Test SBL Instituion ID 1" + assert res[0].sbl_institution_types[0].name == "Test SBL Instituion ID 1" async def test_institution_mapped_to_sbl_it_invalid(self, query_session: AsyncSession): res = await repo.get_institutions(query_session, leis=["TESTBANK456"]) - assert res[0].sbl_institution_type.name != "Test SBL Instituion ID 1" + assert res[0].sbl_institution_types[0].name != "Test SBL Instituion ID 1"