diff --git a/connectors/sources/abs.py b/connectors/sources/abs.py index c505bf09b..b4e7197f6 100644 --- a/connectors/sources/abs.py +++ b/connectors/sources/abs.py @@ -26,6 +26,7 @@ DEFAULT_CONTENT_EXTRACTION = True DEFAULT_FILE_SIZE_LIMIT = 10485760 DEFAULT_RETRY_COUNT = 3 +MAX_CONCURRENT_DOWNLOADS = 100 # Max concurrent download supported by abs class AzureBlobStorageDataSource(BaseDataSource): @@ -35,16 +36,28 @@ def __init__(self, configuration): """Setup the connection to the azure base client Args: - connector (BYOConnector): Object of the BYOConnector class + configuration (DataSourceConfiguration): Object of DataSourceConfiguration class. """ super().__init__(configuration=configuration) self.connection_string = None - self.enable_content_extraction = self.configuration.get( - "enable_content_extraction", DEFAULT_CONTENT_EXTRACTION - ) - self.retry_count = int( - self.configuration.get("retry_count", DEFAULT_RETRY_COUNT) - ) + self.enable_content_extraction = self.configuration["enable_content_extraction"] + self.retry_count = self.configuration["retry_count"] + self.concurrent_downloads = self.configuration["concurrent_downloads"] + + def tweak_bulk_options(self, options): + """Tweak bulk options as per concurrent downloads support by azure blob storage + + Args: + options (dictionary): Config bulker options + + Raises: + Exception: Invalid configured concurrent_downloads + """ + if self.concurrent_downloads > MAX_CONCURRENT_DOWNLOADS: + raise Exception( + f"Configured concurrent downloads can't be set more than {MAX_CONCURRENT_DOWNLOADS}." + ) + options["concurrent_downloads"] = self.concurrent_downloads @classmethod def get_default_configuration(cls): @@ -84,6 +97,11 @@ def get_default_configuration(cls): "label": "How many retry count for fetching rows on each call", "type": "int", }, + "concurrent_downloads": { + "value": MAX_CONCURRENT_DOWNLOADS, + "label": "How many concurrent downloads for fetching blob content", + "type": "int", + }, } def _configure_connection_string(self): @@ -164,6 +182,10 @@ async def get_content(self, blob, timestamp=None, doit=None): logger.warning(f"{blob_name} can't be extracted") return + if blob["tier"] == "Archive": + logger.warning(f"{blob_name} can't be downloaded as blob tier is archive") + return + if blob_size > DEFAULT_FILE_SIZE_LIMIT: logger.warning( f"File size {blob_size} of file {blob_name} is larger than {DEFAULT_FILE_SIZE_LIMIT} bytes. Discarding the file content" diff --git a/connectors/sources/tests/test_abs.py b/connectors/sources/tests/test_abs.py index 859337471..382d0d733 100644 --- a/connectors/sources/tests/test_abs.py +++ b/connectors/sources/tests/test_abs.py @@ -439,3 +439,56 @@ def test_configure_connection_string(): with pytest.raises(Exception): # Execute source._configure_connection_string() + + +def test_tweak_bulk_options(): + """Test tweak_bulk_options method of BaseDataSource class""" + + # Setup + source = create_source(AzureBlobStorageDataSource) + options = {} + options["concurrent_downloads"] = 10 + + # Execute + source.tweak_bulk_options(options) + + +def test_tweak_bulk_options_with_invalid(): + """Test tweak_bulk_options method of BaseDataSource class with invalid concurrent downloads""" + + # Setup + source = create_source(AzureBlobStorageDataSource) + options = {} + source.concurrent_downloads = 1000 + + with pytest.raises(Exception): + # Execute + source.tweak_bulk_options(options) + + +@pytest.mark.asyncio +async def test_get_content_when_blob_tier_archive(): + """Test get_content method when the blob tier is archive""" + + # Setup + source = create_source(AzureBlobStorageDataSource) + mock_response = { + "type": "blob", + "id": "container1/blob1", + "timestamp": "2022-04-21T12:12:30", + "created at": "2022-04-21T12:12:30", + "content type": "plain/text", + "container metadata": "{'key1': 'value1'}", + "metadata": "{'key1': 'value1', 'key2': 'value2'}", + "leasedata": "{'status': 'Locked', 'state': 'Leased', 'duration': 'Infinite'}", + "title": "blob1.pdf", + "tier": "Archive", + "size": 10, + "container": "container1", + } + + # Execute + actual_response = await source.get_content(mock_response, doit=True) + + # Assert + assert actual_response is None