diff --git a/api/src/search/backend/load_opportunities_to_index.py b/api/src/search/backend/load_opportunities_to_index.py index a1e7d0c9c..178cc6f32 100644 --- a/api/src/search/backend/load_opportunities_to_index.py +++ b/api/src/search/backend/load_opportunities_to_index.py @@ -3,7 +3,6 @@ from enum import StrEnum from typing import Iterator, Sequence -import smart_open from opensearchpy.exceptions import ConnectionTimeout, TransportError from pydantic import Field from pydantic_settings import SettingsConfigDict @@ -22,11 +21,16 @@ OpportunitySearchIndexQueue, ) from src.task.task import Task +from src.util import file_util from src.util.datetime_util import get_now_us_eastern_datetime from src.util.env_config import PydanticBaseEnvConfig logger = logging.getLogger(__name__) +ALLOWED_ATTACHMENT_SUFFIXES = set( + ["txt", "pdf", "docx", "doc", "xlsx", "xlsm", "html", "htm", "pptx", "ppt", "rtf"] +) + class LoadOpportunitiesToIndexConfig(PydanticBaseEnvConfig): model_config = SettingsConfigDict(env_prefix="LOAD_OPP_SEARCH_") @@ -275,10 +279,9 @@ def fetch_existing_opportunity_ids_in_index(self) -> set[int]: return opportunity_ids - def filter_attachments( - self, attachments: list[OpportunityAttachment] - ) -> list[OpportunityAttachment]: - return [attachment for attachment in attachments] + def filter_attachment(self, attachment: OpportunityAttachment) -> bool: + file_suffix = attachment.file_name.lower().split(".")[-1] + return file_suffix in ALLOWED_ATTACHMENT_SUFFIXES def get_attachment_json_for_opportunity( self, opp_attachments: list[OpportunityAttachment] @@ -286,17 +289,18 @@ def get_attachment_json_for_opportunity( attachments = [] for att in opp_attachments: - with smart_open.open( - att.file_location, - "rb", - ) as file: - file_content = file.read() - attachments.append( - { - "filename": att.file_name, - "data": base64.b64encode(file_content).decode("utf-8"), - } - ) + if self.filter_attachment(att): + with file_util.open_stream( + att.file_location, + "rb", + ) as file: + file_content = file.read() + attachments.append( + { + "filename": att.file_name, + "data": base64.b64encode(file_content).decode("utf-8"), + } + ) return attachments diff --git a/api/tests/src/search/backend/test_load_opportunities_to_index.py b/api/tests/src/search/backend/test_load_opportunities_to_index.py index 64a5190f1..36296baaa 100644 --- a/api/tests/src/search/backend/test_load_opportunities_to_index.py +++ b/api/tests/src/search/backend/test_load_opportunities_to_index.py @@ -149,18 +149,27 @@ def test_opportunity_attachment_pipeline( opportunity_index_alias, search_client, ): - filename = "test_file_1.txt" - file_path = f"s3://{mock_s3_bucket}/{filename}" + filename_1 = "test_file_1.txt" + file_path_1 = f"s3://{mock_s3_bucket}/{filename_1}" content = "I am a file" - with file_util.open_stream(file_path, "w") as outfile: + + with file_util.open_stream(file_path_1, "w") as outfile: outfile.write(content) + filename_2 = "test_file_2.css" + file_path_2 = f"s3://{mock_s3_bucket}/{filename_2}" + opportunity = OpportunityFactory.create(opportunity_attachments=[]) OpportunityAttachmentFactory.create( - mime_type="text/plain", opportunity=opportunity, - file_location=file_path, - file_name=filename, + file_location=file_path_1, + file_name=filename_1, + ) + + OpportunityAttachmentFactory.create( + opportunity=opportunity, + file_location=file_path_2, + file_name=filename_2, ) load_opportunities_to_index.index_name = ( @@ -172,11 +181,14 @@ def test_opportunity_attachment_pipeline( resp = search_client.search(opportunity_index_alias, {"size": 100}) record = [d for d in resp.records if d.get("opportunity_id") == opportunity.opportunity_id] - attachment = record[0]["attachments"][0] + attachments = record[0]["attachments"] + + # assert only one (allowed) opportunity attachment was uploaded + assert len(attachments) == 1 # assert correct attachment was uploaded - assert attachment["filename"] == filename + assert attachments[0]["filename"] == filename_1 # assert data was b64encoded - assert attachment["attachment"]["content"] == content # decoded b64encoded attachment + assert attachments[0]["attachment"]["content"] == content # decoded b64encoded attachment class TestLoadOpportunitiesToIndexPartialRefresh(BaseTestClass):