diff --git a/sdks/python/apache_beam/io/gcp/gcsio.py b/sdks/python/apache_beam/io/gcp/gcsio.py index 6b0470b82361e..22a33fa13c63b 100644 --- a/sdks/python/apache_beam/io/gcp/gcsio.py +++ b/sdks/python/apache_beam/io/gcp/gcsio.py @@ -43,10 +43,10 @@ from apache_beam import version as beam_version from apache_beam.internal.gcp import auth +from apache_beam.io.gcp import gcsio_retry from apache_beam.metrics.metric import Metrics from apache_beam.options.pipeline_options import GoogleCloudOptions from apache_beam.options.pipeline_options import PipelineOptions -from apache_beam.utils import retry from apache_beam.utils.annotations import deprecated __all__ = ['GcsIO', 'create_storage_client'] @@ -155,6 +155,9 @@ def __init__(self, storage_client=None, pipeline_options=None): self.client = storage_client self._rewrite_cb = None self.bucket_to_project_number = {} + self._storage_client_retry = gcsio_retry.get_retry(pipeline_options) + self._use_blob_generation = getattr( + google_cloud_options, 'enable_gcsio_blob_generation', False) def get_project_number(self, bucket): if bucket not in self.bucket_to_project_number: @@ -167,7 +170,8 @@ def get_project_number(self, bucket): def get_bucket(self, bucket_name, **kwargs): """Returns an object bucket from its name, or None if it does not exist.""" try: - return self.client.lookup_bucket(bucket_name, **kwargs) + return self.client.lookup_bucket( + bucket_name, retry=self._storage_client_retry, **kwargs) except NotFound: return None @@ -188,7 +192,7 @@ def create_bucket( bucket_or_name=bucket, project=project, location=location, - ) + retry=self._storage_client_retry) if kms_key: bucket.default_kms_key_name(kms_key) bucket.patch() @@ -224,18 +228,18 @@ def open( return BeamBlobReader( blob, chunk_size=read_buffer_size, - enable_read_bucket_metric=self.enable_read_bucket_metric) + enable_read_bucket_metric=self.enable_read_bucket_metric, + retry=self._storage_client_retry) elif mode == 'w' or mode == 'wb': blob = bucket.blob(blob_name) return BeamBlobWriter( blob, mime_type, - enable_write_bucket_metric=self.enable_write_bucket_metric) + enable_write_bucket_metric=self.enable_write_bucket_metric, + retry=self._storage_client_retry) else: raise ValueError('Invalid file open mode: %s.' % mode) - @retry.with_exponential_backoff( - retry_filter=retry.retry_on_server_errors_and_timeout_filter) def delete(self, path): """Deletes the object at the given GCS path. @@ -243,14 +247,24 @@ def delete(self, path): path: GCS file path pattern in the form gs:///. """ bucket_name, blob_name = parse_gcs_path(path) + bucket = self.client.bucket(bucket_name) + if self._use_blob_generation: + # blob can be None if not found + blob = bucket.get_blob(blob_name, retry=self._storage_client_retry) + generation = getattr(blob, "generation", None) + else: + generation = None try: - bucket = self.client.bucket(bucket_name) - bucket.delete_blob(blob_name) + bucket.delete_blob( + blob_name, + if_generation_match=generation, + retry=self._storage_client_retry) except NotFound: return def delete_batch(self, paths): """Deletes the objects at the given GCS paths. + Warning: any exception during batch delete will NOT be retried. Args: paths: List of GCS file path patterns or Dict with GCS file path patterns @@ -287,8 +301,6 @@ def delete_batch(self, paths): return final_results - @retry.with_exponential_backoff( - retry_filter=retry.retry_on_server_errors_and_timeout_filter) def copy(self, src, dest): """Copies the given GCS object from src to dest. @@ -297,19 +309,32 @@ def copy(self, src, dest): dest: GCS file path pattern in the form gs:///. Raises: - TimeoutError: on timeout. + Any exceptions during copying """ src_bucket_name, src_blob_name = parse_gcs_path(src) dest_bucket_name, dest_blob_name= parse_gcs_path(dest, object_optional=True) src_bucket = self.client.bucket(src_bucket_name) - src_blob = src_bucket.blob(src_blob_name) + if self._use_blob_generation: + src_blob = src_bucket.get_blob(src_blob_name) + if src_blob is None: + raise NotFound("source blob %s not found during copying" % src) + src_generation = src_blob.generation + else: + src_blob = src_bucket.blob(src_blob_name) + src_generation = None dest_bucket = self.client.bucket(dest_bucket_name) if not dest_blob_name: dest_blob_name = None - src_bucket.copy_blob(src_blob, dest_bucket, new_name=dest_blob_name) + src_bucket.copy_blob( + src_blob, + dest_bucket, + new_name=dest_blob_name, + source_generation=src_generation, + retry=self._storage_client_retry) def copy_batch(self, src_dest_pairs): """Copies the given GCS objects from src to dest. + Warning: any exception during batch copy will NOT be retried. Args: src_dest_pairs: list of (src, dest) tuples of gs:/// files @@ -450,8 +475,6 @@ def _status(self, path): file_status['size'] = gcs_object.size return file_status - @retry.with_exponential_backoff( - retry_filter=retry.retry_on_server_errors_and_timeout_filter) def _gcs_object(self, path): """Returns a gcs object for the given path @@ -462,7 +485,7 @@ def _gcs_object(self, path): """ bucket_name, blob_name = parse_gcs_path(path) bucket = self.client.bucket(bucket_name) - blob = bucket.get_blob(blob_name) + blob = bucket.get_blob(blob_name, retry=self._storage_client_retry) if blob: return blob else: @@ -510,7 +533,8 @@ def list_files(self, path, with_metadata=False): else: _LOGGER.debug("Starting the size estimation of the input") bucket = self.client.bucket(bucket_name) - response = self.client.list_blobs(bucket, prefix=prefix) + response = self.client.list_blobs( + bucket, prefix=prefix, retry=self._storage_client_retry) for item in response: file_name = 'gs://%s/%s' % (item.bucket.name, item.name) if file_name not in file_info: @@ -546,8 +570,7 @@ def _updated_to_seconds(updated): def is_soft_delete_enabled(self, gcs_path): try: bucket_name, _ = parse_gcs_path(gcs_path) - # set retry timeout to 5 seconds when checking soft delete policy - bucket = self.get_bucket(bucket_name, retry=DEFAULT_RETRY.with_timeout(5)) + bucket = self.get_bucket(bucket_name) if (bucket.soft_delete_policy is not None and bucket.soft_delete_policy.retention_duration_seconds > 0): return True @@ -563,8 +586,9 @@ def __init__( self, blob, chunk_size=DEFAULT_READ_BUFFER_SIZE, - enable_read_bucket_metric=False): - super().__init__(blob, chunk_size=chunk_size) + enable_read_bucket_metric=False, + retry=DEFAULT_RETRY): + super().__init__(blob, chunk_size=chunk_size, retry=retry) self.enable_read_bucket_metric = enable_read_bucket_metric self.mode = "r" @@ -585,13 +609,14 @@ def __init__( content_type, chunk_size=16 * 1024 * 1024, ignore_flush=True, - enable_write_bucket_metric=False): + enable_write_bucket_metric=False, + retry=DEFAULT_RETRY): super().__init__( blob, content_type=content_type, chunk_size=chunk_size, ignore_flush=ignore_flush, - retry=DEFAULT_RETRY) + retry=retry) self.mode = "w" self.enable_write_bucket_metric = enable_write_bucket_metric diff --git a/sdks/python/apache_beam/io/gcp/gcsio_integration_test.py b/sdks/python/apache_beam/io/gcp/gcsio_integration_test.py index fad638136804e..07a5fb5df5535 100644 --- a/sdks/python/apache_beam/io/gcp/gcsio_integration_test.py +++ b/sdks/python/apache_beam/io/gcp/gcsio_integration_test.py @@ -34,6 +34,7 @@ import mock import pytest +from parameterized import parameterized_class from apache_beam.io.filesystems import FileSystems from apache_beam.options.pipeline_options import GoogleCloudOptions @@ -51,6 +52,9 @@ @unittest.skipIf(gcsio is None, 'GCP dependencies are not installed') +@parameterized_class( + ('no_gcsio_throttling_counter', 'enable_gcsio_blob_generation'), + [(False, False), (False, True), (True, False), (True, True)]) class GcsIOIntegrationTest(unittest.TestCase): INPUT_FILE = 'gs://dataflow-samples/shakespeare/kinglear.txt' @@ -67,7 +71,6 @@ def setUp(self): self.gcs_tempdir = ( self.test_pipeline.get_option('temp_location') + '/gcs_it-' + str(uuid.uuid4())) - self.gcsio = gcsio.GcsIO() def tearDown(self): FileSystems.delete([self.gcs_tempdir + '/']) @@ -92,14 +95,47 @@ def _verify_copy(self, src, dest, dest_kms_key_name=None): @pytest.mark.it_postcommit def test_copy(self): + self.gcsio = gcsio.GcsIO( + pipeline_options={ + "no_gcsio_throttling_counter": self.no_gcsio_throttling_counter, + "enable_gcsio_blob_generation": self.enable_gcsio_blob_generation + }) + src = self.INPUT_FILE + dest = self.gcs_tempdir + '/test_copy' + + self.gcsio.copy(src, dest) + self._verify_copy(src, dest) + + unknown_src = self.test_pipeline.get_option('temp_location') + \ + '/gcs_it-' + str(uuid.uuid4()) + with self.assertRaises(NotFound): + self.gcsio.copy(unknown_src, dest) + + @pytest.mark.it_postcommit + def test_copy_and_delete(self): + self.gcsio = gcsio.GcsIO( + pipeline_options={ + "no_gcsio_throttling_counter": self.no_gcsio_throttling_counter, + "enable_gcsio_blob_generation": self.enable_gcsio_blob_generation + }) src = self.INPUT_FILE dest = self.gcs_tempdir + '/test_copy' self.gcsio.copy(src, dest) self._verify_copy(src, dest) + self.gcsio.delete(dest) + + # no exception if we delete an nonexistent file. + self.gcsio.delete(dest) + @pytest.mark.it_postcommit def test_batch_copy_and_delete(self): + self.gcsio = gcsio.GcsIO( + pipeline_options={ + "no_gcsio_throttling_counter": self.no_gcsio_throttling_counter, + "enable_gcsio_blob_generation": self.enable_gcsio_blob_generation + }) num_copies = 10 srcs = [self.INPUT_FILE] * num_copies dests = [ @@ -152,6 +188,7 @@ def test_batch_copy_and_delete(self): @mock.patch('apache_beam.io.gcp.gcsio.default_gcs_bucket_name') @unittest.skipIf(NotFound is None, 'GCP dependencies are not installed') def test_create_default_bucket(self, mock_default_gcs_bucket_name): + self.gcsio = gcsio.GcsIO() google_cloud_options = self.test_pipeline.options.view_as( GoogleCloudOptions) # overwrite kms option here, because get_or_create_default_gcs_bucket() diff --git a/sdks/python/apache_beam/io/gcp/gcsio_retry.py b/sdks/python/apache_beam/io/gcp/gcsio_retry.py new file mode 100644 index 0000000000000..29fd71c5195b4 --- /dev/null +++ b/sdks/python/apache_beam/io/gcp/gcsio_retry.py @@ -0,0 +1,71 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Throttling Handler for GCSIO +""" + +import inspect +import logging +import math + +from google.api_core import exceptions as api_exceptions +from google.api_core import retry +from google.cloud.storage.retry import DEFAULT_RETRY +from google.cloud.storage.retry import _should_retry # pylint: disable=protected-access + +from apache_beam.metrics.metric import Metrics +from apache_beam.options.pipeline_options import GoogleCloudOptions + +_LOGGER = logging.getLogger(__name__) + +__all__ = ['DEFAULT_RETRY_WITH_THROTTLING_COUNTER'] + + +class ThrottlingHandler(object): + _THROTTLED_SECS = Metrics.counter('gcsio', "cumulativeThrottlingSeconds") + + def __call__(self, exc): + if isinstance(exc, api_exceptions.TooManyRequests): + _LOGGER.debug('Caught GCS quota error (%s), retrying.', exc.reason) + # TODO: revisit the logic here when gcs client library supports error + # callbacks + frame = inspect.currentframe() + if frame is None: + _LOGGER.warning('cannot inspect the current stack frame') + return + + prev_frame = frame.f_back + if prev_frame is None: + _LOGGER.warning('cannot inspect the caller stack frame') + return + + # next_sleep is one of the arguments in the caller + # i.e. _retry_error_helper() in google/api_core/retry/retry_base.py + sleep_seconds = prev_frame.f_locals.get("next_sleep", 0) + ThrottlingHandler._THROTTLED_SECS.inc(math.ceil(sleep_seconds)) + + +DEFAULT_RETRY_WITH_THROTTLING_COUNTER = retry.Retry( + predicate=_should_retry, on_error=ThrottlingHandler()) + + +def get_retry(pipeline_options): + if pipeline_options.view_as(GoogleCloudOptions).no_gcsio_throttling_counter: + return DEFAULT_RETRY + else: + return DEFAULT_RETRY_WITH_THROTTLING_COUNTER diff --git a/sdks/python/apache_beam/io/gcp/gcsio_retry_test.py b/sdks/python/apache_beam/io/gcp/gcsio_retry_test.py new file mode 100644 index 0000000000000..750879ae0284c --- /dev/null +++ b/sdks/python/apache_beam/io/gcp/gcsio_retry_test.py @@ -0,0 +1,84 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Tests for Throttling Handler of GCSIO.""" + +import unittest +from unittest.mock import Mock + +from apache_beam.metrics.execution import MetricsContainer +from apache_beam.metrics.execution import MetricsEnvironment +from apache_beam.metrics.metricbase import MetricName +from apache_beam.runners.worker import statesampler +from apache_beam.utils import counters + +try: + from apache_beam.io.gcp import gcsio_retry + from google.api_core import exceptions as api_exceptions +except ImportError: + gcsio_retry = None + api_exceptions = None + + +@unittest.skipIf((gcsio_retry is None or api_exceptions is None), + 'GCP dependencies are not installed') +class TestGCSIORetry(unittest.TestCase): + def test_retry_on_non_retriable(self): + mock = Mock(side_effect=[ + Exception('Something wrong!'), + ]) + retry = gcsio_retry.DEFAULT_RETRY_WITH_THROTTLING_COUNTER + with self.assertRaises(Exception): + retry(mock)() + + def test_retry_on_throttling(self): + mock = Mock( + side_effect=[ + api_exceptions.TooManyRequests("Slow down!"), + api_exceptions.TooManyRequests("Slow down again!"), + 12345 + ]) + retry = gcsio_retry.DEFAULT_RETRY_WITH_THROTTLING_COUNTER + + sampler = statesampler.StateSampler('', counters.CounterFactory()) + statesampler.set_current_tracker(sampler) + state = sampler.scoped_state( + 'my_step', 'my_state', metrics_container=MetricsContainer('my_step')) + try: + sampler.start() + with state: + container = MetricsEnvironment.current_container() + + self.assertEqual( + container.get_counter( + MetricName('gcsio', + "cumulativeThrottlingSeconds")).get_cumulative(), + 0) + + self.assertEqual(12345, retry(mock)()) + + self.assertGreater( + container.get_counter( + MetricName('gcsio', + "cumulativeThrottlingSeconds")).get_cumulative(), + 1) + finally: + sampler.stop() + + +if __name__ == '__main__': + unittest.main() diff --git a/sdks/python/apache_beam/io/gcp/gcsio_test.py b/sdks/python/apache_beam/io/gcp/gcsio_test.py index 407295f2fb301..19df15dcf7fab 100644 --- a/sdks/python/apache_beam/io/gcp/gcsio_test.py +++ b/sdks/python/apache_beam/io/gcp/gcsio_test.py @@ -20,6 +20,7 @@ import logging import os +import random import unittest from datetime import datetime @@ -36,6 +37,7 @@ try: from apache_beam.io.gcp import gcsio + from apache_beam.io.gcp.gcsio_retry import DEFAULT_RETRY_WITH_THROTTLING_COUNTER from google.cloud.exceptions import BadRequest, NotFound except ImportError: NotFound = None @@ -85,7 +87,7 @@ def get_file(self, bucket, blob): holder = folder.get_blob(blob.name) return holder - def list_blobs(self, bucket_or_path, prefix=None): + def list_blobs(self, bucket_or_path, prefix=None, **unused_kwargs): bucket = self.get_bucket(bucket_or_path.name) if not prefix: return list(bucket.blobs.values()) @@ -120,7 +122,7 @@ def add_blob(self, blob): def blob(self, name): return self._create_blob(name) - def copy_blob(self, blob, dest, new_name=None): + def copy_blob(self, blob, dest, new_name=None, **kwargs): if self.get_blob(blob.name) is None: raise NotFound("source blob not found") if not new_name: @@ -129,7 +131,7 @@ def copy_blob(self, blob, dest, new_name=None): dest.add_blob(new_blob) return new_blob - def get_blob(self, blob_name): + def get_blob(self, blob_name, **unused_kwargs): bucket = self._get_canonical_bucket() if blob_name in bucket.blobs: return bucket.blobs[blob_name] @@ -146,7 +148,7 @@ def lookup_blob(self, name): def set_default_kms_key_name(self, name): self.default_kms_key_name = name - def delete_blob(self, name): + def delete_blob(self, name, **kwargs): bucket = self._get_canonical_bucket() if name in bucket.blobs: del bucket.blobs[name] @@ -175,6 +177,7 @@ def __init__( self.updated = updated self._fail_when_getting_metadata = fail_when_getting_metadata self._fail_when_reading = fail_when_reading + self.generation = random.randint(0, (1 << 63) - 1) def delete(self): self.bucket.delete_blob(self.name) @@ -532,7 +535,10 @@ def test_file_buffered_read_call(self): with mock.patch('apache_beam.io.gcp.gcsio.BeamBlobReader') as reader: self.gcs.open(file_name, read_buffer_size=read_buffer_size) reader.assert_called_with( - blob, chunk_size=read_buffer_size, enable_read_bucket_metric=False) + blob, + chunk_size=read_buffer_size, + enable_read_bucket_metric=False, + retry=DEFAULT_RETRY_WITH_THROTTLING_COUNTER) def test_file_write_call(self): file_name = 'gs://gcsio-test/write_file' diff --git a/sdks/python/apache_beam/options/pipeline_options.py b/sdks/python/apache_beam/options/pipeline_options.py index e20e0e8ca046e..0c5a2f961a468 100644 --- a/sdks/python/apache_beam/options/pipeline_options.py +++ b/sdks/python/apache_beam/options/pipeline_options.py @@ -939,6 +939,18 @@ def _add_argparse_args(cls, parser): help= 'Create metrics reporting the approximate number of bytes written per ' 'bucket.') + parser.add_argument( + '--no_gcsio_throttling_counter', + default=False, + action='store_true', + help='Throttling counter in GcsIO is enabled by default. Set ' + '--no_gcsio_throttling_counter to avoid it.') + parser.add_argument( + '--enable_gcsio_blob_generation', + default=False, + action='store_true', + help='Use blob generation when mutating blobs in GCSIO to ' + 'mitigate race conditions at the cost of more HTTP requests.') def _create_default_gcs_bucket(self): try: