Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add throttling counter in gcsio and refactor retrying #32428

Merged
merged 20 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 47 additions & 23 deletions sdks/python/apache_beam/io/gcp/gcsio.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@

from apache_beam import version as beam_version
from apache_beam.internal.gcp import auth
from apache_beam.io.gcp import gcsio_retry
from apache_beam.metrics.metric import Metrics
from apache_beam.options.pipeline_options import GoogleCloudOptions
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.utils import retry
from apache_beam.utils.annotations import deprecated

__all__ = ['GcsIO', 'create_storage_client']
Expand Down Expand Up @@ -155,6 +155,9 @@ def __init__(self, storage_client=None, pipeline_options=None):
self.client = storage_client
self._rewrite_cb = None
self.bucket_to_project_number = {}
self._storage_client_retry = gcsio_retry.get_retry(pipeline_options)
self._use_blob_generation = getattr(
google_cloud_options, 'enable_gcsio_blob_generation', False)

def get_project_number(self, bucket):
if bucket not in self.bucket_to_project_number:
Expand All @@ -167,7 +170,8 @@ def get_project_number(self, bucket):
def get_bucket(self, bucket_name, **kwargs):
"""Returns an object bucket from its name, or None if it does not exist."""
try:
return self.client.lookup_bucket(bucket_name, **kwargs)
return self.client.lookup_bucket(
bucket_name, retry=self._storage_client_retry, **kwargs)
except NotFound:
return None

Expand All @@ -188,7 +192,7 @@ def create_bucket(
bucket_or_name=bucket,
project=project,
location=location,
)
retry=self._storage_client_retry)
if kms_key:
bucket.default_kms_key_name(kms_key)
bucket.patch()
Expand Down Expand Up @@ -224,18 +228,18 @@ def open(
return BeamBlobReader(
blob,
chunk_size=read_buffer_size,
enable_read_bucket_metric=self.enable_read_bucket_metric)
enable_read_bucket_metric=self.enable_read_bucket_metric,
retry=self._storage_client_retry)
elif mode == 'w' or mode == 'wb':
blob = bucket.blob(blob_name)
return BeamBlobWriter(
blob,
mime_type,
enable_write_bucket_metric=self.enable_write_bucket_metric)
enable_write_bucket_metric=self.enable_write_bucket_metric,
retry=self._storage_client_retry)
else:
raise ValueError('Invalid file open mode: %s.' % mode)

@retry.with_exponential_backoff(
retry_filter=retry.retry_on_server_errors_and_timeout_filter)
def delete(self, path):
"""Deletes the object at the given GCS path.

Expand All @@ -245,12 +249,21 @@ def delete(self, path):
bucket_name, blob_name = parse_gcs_path(path)
try:
bucket = self.client.bucket(bucket_name)
bucket.delete_blob(blob_name)
if self._use_blob_generation:
shunping marked this conversation as resolved.
Show resolved Hide resolved
blob = bucket.get_blob(blob_name, retry=self._storage_client_retry)
generation = getattr(blob, "generation", None)
else:
generation = None
bucket.delete_blob(
blob_name,
if_generation_match=generation,
retry=self._storage_client_retry)
except NotFound:
return

def delete_batch(self, paths):
"""Deletes the objects at the given GCS paths.
Warning: any exception during batch delete will NOT be retried.

Args:
paths: List of GCS file path patterns or Dict with GCS file path patterns
Expand Down Expand Up @@ -287,8 +300,6 @@ def delete_batch(self, paths):

return final_results

@retry.with_exponential_backoff(
retry_filter=retry.retry_on_server_errors_and_timeout_filter)
def copy(self, src, dest):
"""Copies the given GCS object from src to dest.

Expand All @@ -297,19 +308,32 @@ def copy(self, src, dest):
dest: GCS file path pattern in the form gs://<bucket>/<name>.

Raises:
TimeoutError: on timeout.
Any exceptions during copying
"""
src_bucket_name, src_blob_name = parse_gcs_path(src)
dest_bucket_name, dest_blob_name= parse_gcs_path(dest, object_optional=True)
src_bucket = self.client.bucket(src_bucket_name)
src_blob = src_bucket.blob(src_blob_name)
if self._use_blob_generation:
shunping marked this conversation as resolved.
Show resolved Hide resolved
src_blob = src_bucket.get_blob(src_blob_name)
if src_blob is None:
raise NotFound("source blob %s not found during copying" % src)
src_generation = getattr(src_blob, "generation", None)
shunping marked this conversation as resolved.
Show resolved Hide resolved
else:
src_blob = src_bucket.blob(src_blob_name)
src_generation = None
dest_bucket = self.client.bucket(dest_bucket_name)
if not dest_blob_name:
dest_blob_name = None
src_bucket.copy_blob(src_blob, dest_bucket, new_name=dest_blob_name)
src_bucket.copy_blob(
src_blob,
dest_bucket,
new_name=dest_blob_name,
source_generation=src_generation,
retry=self._storage_client_retry)

def copy_batch(self, src_dest_pairs):
"""Copies the given GCS objects from src to dest.
Warning: any exception during batch copy will NOT be retried.

Args:
src_dest_pairs: list of (src, dest) tuples of gs://<bucket>/<name> files
Expand Down Expand Up @@ -450,8 +474,6 @@ def _status(self, path):
file_status['size'] = gcs_object.size
return file_status

@retry.with_exponential_backoff(
retry_filter=retry.retry_on_server_errors_and_timeout_filter)
def _gcs_object(self, path):
"""Returns a gcs object for the given path

Expand All @@ -462,7 +484,7 @@ def _gcs_object(self, path):
"""
bucket_name, blob_name = parse_gcs_path(path)
bucket = self.client.bucket(bucket_name)
blob = bucket.get_blob(blob_name)
blob = bucket.get_blob(blob_name, retry=self._storage_client_retry)
if blob:
return blob
else:
Expand Down Expand Up @@ -510,7 +532,8 @@ def list_files(self, path, with_metadata=False):
else:
_LOGGER.debug("Starting the size estimation of the input")
bucket = self.client.bucket(bucket_name)
response = self.client.list_blobs(bucket, prefix=prefix)
response = self.client.list_blobs(
bucket, prefix=prefix, retry=self._storage_client_retry)
for item in response:
file_name = 'gs://%s/%s' % (item.bucket.name, item.name)
if file_name not in file_info:
Expand Down Expand Up @@ -546,8 +569,7 @@ def _updated_to_seconds(updated):
def is_soft_delete_enabled(self, gcs_path):
try:
bucket_name, _ = parse_gcs_path(gcs_path)
# set retry timeout to 5 seconds when checking soft delete policy
bucket = self.get_bucket(bucket_name, retry=DEFAULT_RETRY.with_timeout(5))
bucket = self.get_bucket(bucket_name)
if (bucket.soft_delete_policy is not None and
bucket.soft_delete_policy.retention_duration_seconds > 0):
return True
Expand All @@ -563,8 +585,9 @@ def __init__(
self,
blob,
chunk_size=DEFAULT_READ_BUFFER_SIZE,
enable_read_bucket_metric=False):
super().__init__(blob, chunk_size=chunk_size)
enable_read_bucket_metric=False,
retry=DEFAULT_RETRY):
super().__init__(blob, chunk_size=chunk_size, retry=retry)
self.enable_read_bucket_metric = enable_read_bucket_metric
self.mode = "r"

Expand All @@ -585,13 +608,14 @@ def __init__(
content_type,
chunk_size=16 * 1024 * 1024,
ignore_flush=True,
enable_write_bucket_metric=False):
enable_write_bucket_metric=False,
retry=DEFAULT_RETRY):
super().__init__(
blob,
content_type=content_type,
chunk_size=chunk_size,
ignore_flush=ignore_flush,
retry=DEFAULT_RETRY)
retry=retry)
self.mode = "w"
self.enable_write_bucket_metric = enable_write_bucket_metric

Expand Down
39 changes: 38 additions & 1 deletion sdks/python/apache_beam/io/gcp/gcsio_integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

import mock
import pytest
from parameterized import parameterized_class

from apache_beam.io.filesystems import FileSystems
from apache_beam.options.pipeline_options import GoogleCloudOptions
Expand All @@ -51,6 +52,9 @@


@unittest.skipIf(gcsio is None, 'GCP dependencies are not installed')
@parameterized_class(
('no_gcsio_throttling_counter', 'enable_gcsio_blob_generation'),
[(False, False), (False, True), (True, False), (True, True)])
class GcsIOIntegrationTest(unittest.TestCase):

INPUT_FILE = 'gs://dataflow-samples/shakespeare/kinglear.txt'
Expand All @@ -67,7 +71,6 @@ def setUp(self):
self.gcs_tempdir = (
self.test_pipeline.get_option('temp_location') + '/gcs_it-' +
str(uuid.uuid4()))
self.gcsio = gcsio.GcsIO()

def tearDown(self):
FileSystems.delete([self.gcs_tempdir + '/'])
Expand All @@ -92,14 +95,47 @@ def _verify_copy(self, src, dest, dest_kms_key_name=None):

@pytest.mark.it_postcommit
def test_copy(self):
self.gcsio = gcsio.GcsIO(
pipeline_options={
"no_gcsio_throttling_counter": self.no_gcsio_throttling_counter,
"enable_gcsio_blob_generation": self.enable_gcsio_blob_generation
})
src = self.INPUT_FILE
dest = self.gcs_tempdir + '/test_copy'

self.gcsio.copy(src, dest)
self._verify_copy(src, dest)

unknown_src = self.test_pipeline.get_option('temp_location') + \
'/gcs_it-' + str(uuid.uuid4())
with self.assertRaises(NotFound):
self.gcsio.copy(unknown_src, dest)

@pytest.mark.it_postcommit
def test_copy_and_delete(self):
self.gcsio = gcsio.GcsIO(
pipeline_options={
"no_gcsio_throttling_counter": self.no_gcsio_throttling_counter,
"enable_gcsio_blob_generation": self.enable_gcsio_blob_generation
})
src = self.INPUT_FILE
dest = self.gcs_tempdir + '/test_copy'

self.gcsio.copy(src, dest)
self._verify_copy(src, dest)

self.gcsio.delete(dest)

# no exception if we delete an nonexistent file.
self.gcsio.delete(dest)

@pytest.mark.it_postcommit
def test_batch_copy_and_delete(self):
self.gcsio = gcsio.GcsIO(
pipeline_options={
"no_gcsio_throttling_counter": self.no_gcsio_throttling_counter,
"enable_gcsio_blob_generation": self.enable_gcsio_blob_generation
})
num_copies = 10
srcs = [self.INPUT_FILE] * num_copies
dests = [
Expand Down Expand Up @@ -152,6 +188,7 @@ def test_batch_copy_and_delete(self):
@mock.patch('apache_beam.io.gcp.gcsio.default_gcs_bucket_name')
@unittest.skipIf(NotFound is None, 'GCP dependencies are not installed')
def test_create_default_bucket(self, mock_default_gcs_bucket_name):
self.gcsio = gcsio.GcsIO()
google_cloud_options = self.test_pipeline.options.view_as(
GoogleCloudOptions)
# overwrite kms option here, because get_or_create_default_gcs_bucket()
Expand Down
71 changes: 71 additions & 0 deletions sdks/python/apache_beam/io/gcp/gcsio_retry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
Throttling Handler for GCSIO
"""

import inspect
import logging
import math

from google.api_core import exceptions as api_exceptions
from google.api_core import retry
from google.cloud.storage.retry import DEFAULT_RETRY
from google.cloud.storage.retry import _should_retry # pylint: disable=protected-access

from apache_beam.metrics.metric import Metrics
from apache_beam.options.pipeline_options import GoogleCloudOptions

_LOGGER = logging.getLogger(__name__)

__all__ = ['DEFAULT_RETRY_WITH_THROTTLING_COUNTER']


class ThrottlingHandler(object):
_THROTTLED_SECS = Metrics.counter('gcsio', "cumulativeThrottlingSeconds")

def __call__(self, exc):
if isinstance(exc, api_exceptions.TooManyRequests):
_LOGGER.debug('Caught GCS quota error (%s), retrying.', exc.reason)
# TODO: revist the logic here when gcs client library supports error
# callbacks
frame = inspect.currentframe()
if frame is None:
_LOGGER.warning('cannot inspect the current stack frame')
return

prev_frame = frame.f_back
if prev_frame is None:
_LOGGER.warning('cannot inspect the caller stack frame')
return

# next_sleep is one of the arguments in the caller
# i.e. _retry_error_helper() in google/api_core/retry/retry_base.py
sleep_seconds = prev_frame.f_locals.get("next_sleep", 0)
ThrottlingHandler._THROTTLED_SECS.inc(math.ceil(sleep_seconds))
shunping marked this conversation as resolved.
Show resolved Hide resolved


DEFAULT_RETRY_WITH_THROTTLING_COUNTER = retry.Retry(
predicate=_should_retry, on_error=ThrottlingHandler())


def get_retry(pipeline_options):
if pipeline_options.view_as(GoogleCloudOptions).no_gcsio_throttling_counter:
return DEFAULT_RETRY
else:
return DEFAULT_RETRY_WITH_THROTTLING_COUNTER
Loading
Loading