Make get_job() retry on transient errors for CAIP/uCAIP training.
Unfortunately, this cannot completely replace the outer retry loop, which addresses an issue for long-running clients caused by
googleapis/google-api-python-client#218.
Updated the comments to explain this instead.

PiperOrigin-RevId: 381029373
zhitaoli authored and tfx-copybara committed Jun 23, 2021
1 parent 44051cd commit 810226a
Showing 20 changed files with 160 additions and 350 deletions.
2 changes: 0 additions & 2 deletions RELEASE.md
@@ -12,8 +12,6 @@

## Bug Fixes and Other Changes

* Fixed issue where passing `analyzer_cache` to `tfx.components.Transform`
before there are any Transform cache artifacts published would fail.
* Depends on `protobuf>=3.13,<4`.

## Documentation Updates
3 changes: 1 addition & 2 deletions tfx/components/transform/executor.py
@@ -463,8 +463,7 @@ def Do(self, input_dict: Dict[Text, List[types.Artifact]],
os.path.join(output_uri, _DEFAULT_TRANSFORMED_EXAMPLES_PREFIX))

def _GetCachePath(label, params_dict):
# Covers the cases: path wasn't provided, or was provided an empty list.
if not params_dict.get(label):
if params_dict.get(label) is None:
return None
else:
return artifact_utils.get_single_uri(params_dict[label])
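The two checks in this hunk differ only in how they treat an empty cache list, which is exactly the case described by the removed release note. A standalone sketch of the difference (illustrative only, not TFX code):

```python
def get_cache_path_is_none(params_dict, label):
  # Only a missing key (or explicit None) means "no cache"; an empty list
  # falls through and would reach artifact_utils.get_single_uri([]).
  if params_dict.get(label) is None:
    return None
  return params_dict[label]


def get_cache_path_falsy(params_dict, label):
  # A missing key, None, or an empty list all mean "no cache".
  if not params_dict.get(label):
    return None
  return params_dict[label]


print(get_cache_path_is_none({'analyzer_cache': []}, 'analyzer_cache'))  # []
print(get_cache_path_falsy({'analyzer_cache': []}, 'analyzer_cache'))    # None
```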
7 changes: 1 addition & 6 deletions tfx/components/transform/executor_test.py
@@ -22,7 +22,6 @@
import json
import os
import tempfile
from absl.testing import parameterized

import tensorflow as tf
import tensorflow_transform as tft
@@ -525,14 +524,10 @@ def test_counters(self):
# Output materialization is enabled.
self.assertMetricsCounterEqual(metrics, 'materialize', 1)

@parameterized.named_parameters([('no_1st_input_cache', False),
('empty_1st_input_cache', True)])
def test_do_with_cache(self, provide_first_input_cache):
def test_do_with_cache(self):
# First run that creates cache.
self._exec_properties[
standard_component_specs.MODULE_FILE_KEY] = self._module_file
if provide_first_input_cache:
self._input_dict[standard_component_specs.ANALYZER_CACHE_KEY] = []
metrics = self._run_pipeline_get_metrics()

# The test data has 9909 instances in the train dataset, and 5091 instances
58 changes: 27 additions & 31 deletions tfx/dsl/placeholder/placeholder.py
@@ -25,12 +25,10 @@

from google.protobuf import message

# To resolve circular dependency caused by type annotations.
# TODO(b/191610358): Reduce the number of circular type-dependencies.
types = Any # tfx.types imports channel.py, which in turn imports this module.
types = Any # Avoid circular dependency between placeholder.py and channel.py.

# TODO(b/190409099): Support RuntimeParameter.
_ValueLikeType = Union[int, float, str, 'ChannelWrappedPlaceholder']
_ValueLikeTypes = Union[int, float, str, 'ChannelWrappedPlaceholder']


class _PlaceholderOperator(json_utils.Jsonable):
@@ -323,9 +321,7 @@ def encode(
return result


# To ensure that ArtifactPlaceholder operations on a ChannelWrappedPlaceholder
# still returns a ChannelWrappedPlaceholder.
_T = TypeVar('_T')
T = TypeVar('T')


class ArtifactPlaceholder(Placeholder):
@@ -335,23 +331,23 @@ class ArtifactPlaceholder(Placeholder):
"""

@property
def uri(self: _T) -> _T:
def uri(self: T) -> T:
self._try_inject_index_operator()
self._operators.append(_ArtifactUriOperator())
return self

def split_uri(self: _T, split: str) -> _T:
def split_uri(self: T, split: str) -> T:
self._try_inject_index_operator()
self._operators.append(_ArtifactUriOperator(split))
return self

@property
def value(self: _T) -> _T:
def value(self: T) -> T:
self._try_inject_index_operator()
self._operators.append(_ArtifactValueOperator())
return self

def __getitem__(self: _T, key: int) -> _T:
def __getitem__(self: T, key: int) -> T:
self._operators.append(_IndexOperator(key))
return self

@@ -452,29 +448,29 @@ def __init__(self, channel: 'types.Channel'):
super().__init__(placeholder_pb2.Placeholder.Type.INPUT_ARTIFACT)
self.channel = channel

def __eq__(self, other: _ValueLikeType) -> 'Predicate':
def __eq__(self, other: _ValueLikeTypes) -> 'Predicate':
return Predicate.from_comparison(_CompareOp.EQUAL, left=self, right=other)

def __ne__(self, other: _ValueLikeType) -> 'Predicate':
def __ne__(self, other: _ValueLikeTypes) -> 'Predicate':
return logical_not(self == other)

def __lt__(self, other: _ValueLikeType) -> 'Predicate':
def __lt__(self, other: _ValueLikeTypes) -> 'Predicate':
return Predicate.from_comparison(
_CompareOp.LESS_THAN, left=self, right=other)

def __le__(self, other: _ValueLikeType) -> 'Predicate':
def __le__(self, other: _ValueLikeTypes) -> 'Predicate':
return logical_not(self > other)

def __gt__(self, other: _ValueLikeType) -> 'Predicate':
def __gt__(self, other: _ValueLikeTypes) -> 'Predicate':
return Predicate.from_comparison(
_CompareOp.GREATER_THAN, left=self, right=other)

def __ge__(self, other: _ValueLikeType) -> 'Predicate':
def __ge__(self, other: _ValueLikeTypes) -> 'Predicate':
return logical_not(self < other)


def _encode_value_like(
x: _ValueLikeType,
x: _ValueLikeTypes,
channel_to_key_fn: Optional[Callable[['types.Channel'], str]] = None
) -> placeholder_pb2.PlaceholderExpression:
"""Encodes x to a placeholder expression proto."""
@@ -503,17 +499,17 @@ def _encode_value_like(
return result


_PredicateSubtype = Union['_Comparison', '_NotExpression',
'_BinaryLogicalExpression']
_PredicateSubtypes = Union['_Comparison', '_NotExpression',
'_BinaryLogicalExpression']


@attr.s
class _Comparison:
"""Represents a comparison between two placeholders."""

compare_op = attr.ib(type=_CompareOp)
left = attr.ib(type=_ValueLikeType)
right = attr.ib(type=_ValueLikeType)
left = attr.ib(type=_ValueLikeTypes)
right = attr.ib(type=_ValueLikeTypes)

def encode_with_keys(
self,
@@ -542,7 +538,7 @@ class _LogicalOp(enum.Enum):
class _NotExpression:
"""Represents a logical negation."""

pred_dataclass = attr.ib(type=_PredicateSubtype)
pred_dataclass = attr.ib(type=_PredicateSubtypes)

def encode_with_keys(
self,
@@ -566,8 +562,8 @@ class _BinaryLogicalExpression:
"""Represents a boolean logical expression with exactly two arguments."""

logical_op = attr.ib(type=_LogicalOp)
left = attr.ib(type=_PredicateSubtype)
right = attr.ib(type=_PredicateSubtype)
left = attr.ib(type=_PredicateSubtypes)
right = attr.ib(type=_PredicateSubtypes)

def encode_with_keys(
self,
@@ -596,7 +592,7 @@ class Predicate(Placeholder):
Prefer to use syntax like `<channel>.future() > 5` to create a Predicate.
"""

def __init__(self, pred_dataclass: _PredicateSubtype):
def __init__(self, pred_dataclass: _PredicateSubtypes):
"""NOT INTENDED TO BE USED DIRECTLY BY PIPELINE AUTHORS."""

super().__init__(placeholder_pb2.Placeholder.Type.INPUT_ARTIFACT)
@@ -605,7 +601,7 @@ def __init__(self, pred_dataclass: _PredicateSubtype):
@classmethod
def from_comparison(cls, compare_op: _CompareOp,
left: ChannelWrappedPlaceholder,
right: _ValueLikeType) -> 'Predicate':
right: _ValueLikeTypes) -> 'Predicate':
"""Creates a Predicate instance.
Note that even though the `left` argument is assumed to be a
@@ -706,14 +702,14 @@ def logical_not(pred: Predicate) -> Predicate:


def logical_and(left: Predicate, right: Predicate) -> Predicate:
"""Applies the AND boolean operation on two Predicates.
"""Applies the AND boolean operation on a Predicate.
Args:
left: The first argument of the AND operation.
right: The second argument of the AND operation.
Returns:
The Predicate resulting from the AND operation.
A Predicate.
"""

return Predicate(
@@ -722,14 +718,14 @@ def logical_and(left: Predicate, right: Predicate) -> Predicate:


def logical_or(left: Predicate, right: Predicate) -> Predicate:
"""Applies the OR boolean operation on two Predicates.
"""Applies the OR boolean operation on a Predicate.
Args:
left: The first argument of the OR operation.
right: The second argument of the OR operation.
Returns:
The Predicate resulting from the OR operation.
A Predicate.
"""

return Predicate(
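For readers skimming this file, the Predicate helpers above are normally driven through channel placeholders; a rough usage sketch (the artifact type and channel here are made up for illustration, not part of this diff):

```python
from tfx.dsl.placeholder import placeholder as ph
from tfx.types import standard_artifacts
from tfx.types.channel import Channel

# Comparing a wrapped channel placeholder with a plain value yields a Predicate.
examples = Channel(type=standard_artifacts.Integer)
pred_a = examples.future().value > 5
pred_b = examples.future().value != 0

# Predicates compose with the module-level helpers, not Python's `and`/`or`/`not`.
combined = ph.logical_and(pred_a, pred_b)
negated = ph.logical_not(combined)
```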
40 changes: 14 additions & 26 deletions tfx/dsl/placeholder/placeholder_test.py
@@ -419,36 +419,24 @@ class PredicateTest(parameterized.TestCase, tf.test.TestCase):
'expected_rhs_value_type': 'string_value',
},
{
'testcase_name':
'right_side_placeholder_left_side_int',
'left':
1,
'right':
Channel(type=_MyType).future().value,
'testcase_name': 'right_side_placeholder_left_side_int',
'left': 1,
'right': Channel(type=_MyType).future().value,
'expected_op':
placeholder_pb2.ComparisonOperator.Operation.GREATER_THAN,
'expected_lhs_field':
'operator',
'expected_rhs_field':
'value',
'expected_rhs_value_type':
'int_value',
(placeholder_pb2.ComparisonOperator.Operation.GREATER_THAN),
'expected_lhs_field': 'operator',
'expected_rhs_field': 'value',
'expected_rhs_value_type': 'int_value',
},
{
'testcase_name':
'right_side_placeholder_left_side_float',
'left':
1.1,
'right':
Channel(type=_MyType).future().value,
'testcase_name': 'right_side_placeholder_left_side_float',
'left': 1.1,
'right': Channel(type=_MyType).future().value,
'expected_op':
placeholder_pb2.ComparisonOperator.Operation.GREATER_THAN,
'expected_lhs_field':
'operator',
'expected_rhs_field':
'value',
'expected_rhs_value_type':
'double_value',
(placeholder_pb2.ComparisonOperator.Operation.GREATER_THAN),
'expected_lhs_field': 'operator',
'expected_rhs_field': 'value',
'expected_rhs_value_type': 'double_value',
},
)
def testComparison(self,
2 changes: 1 addition & 1 deletion tfx/examples/ranking/ranking_pipeline.py
@@ -46,7 +46,7 @@
_data_root = os.path.join(_ranking_root, 'data')
# Python module file to inject customized logic into the TFX components. The
# Transform and Trainer both require user-defined functions to run successfully.
_module_file = os.path.join(_ranking_root, 'taxi_utils_native_keras.py')
_module_file = os.path.join(_ranking_root, 'ranking_utils.py')
# Path which can be listened to by the model server. Pusher will output the
# trained model here.
_serving_model_dir = os.path.join(
12 changes: 6 additions & 6 deletions tfx/extensions/google_cloud_ai_platform/runner.py
@@ -1,4 +1,3 @@
# Lint as: python3
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -94,11 +93,12 @@ def _launch_aip_training(
# The following logic will keep polling the state of the job until the job
# enters a final state.
#
# During the polling, if a connection error was encountered, the GET request
# will be retried by recreating the Python API client to refresh the lifecycle
# of the connection being used. See
# https://github.com/googleapis/google-api-python-client/issues/218
# for a detailed description of the problem. If the error persists for
# Note that even if client.get_job() retries on transient errors, it could
# still raise `ConnectionError` due to
# https://github.com/googleapis/google-api-python-client/issues/218.
# We need to recreate the Python API client to refresh the lifecycle
# of the connection being used, then retry again.
# If the error persists for
# _CONNECTION_ERROR_RETRY_LIMIT consecutive attempts, the function will raise
# ConnectionError.
while client.get_job_state(response) not in client.JOB_STATES_COMPLETED:
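For context, a simplified sketch of the outer polling loop that the rewritten comment describes (illustrative only; the constant values and the `create_client()` refresh follow the pattern of the surrounding module rather than reproducing it exactly):

```python
import time

_POLLING_INTERVAL_IN_SECONDS = 30  # assumed value for the sketch
_CONNECTION_ERROR_RETRY_LIMIT = 5  # assumed value for the sketch


def _wait_for_job(client, response):
  """Polls until the job reaches a terminal state."""
  retry_count = 0
  while client.get_job_state(response) not in client.JOB_STATES_COMPLETED:
    time.sleep(_POLLING_INTERVAL_IN_SECONDS)
    try:
      response = client.get_job()
      retry_count = 0  # Reset the counter after any successful poll.
    except ConnectionError:
      if retry_count < _CONNECTION_ERROR_RETRY_LIMIT:
        retry_count += 1
        client.create_client()  # Rebuild the API client (method name assumed) to refresh the connection.
      else:
        raise
  return response
```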
4 changes: 2 additions & 2 deletions tfx/extensions/google_cloud_ai_platform/runner_test.py
@@ -216,7 +216,7 @@ def testStartAIPTraining_uCAIP(self, mock_gapic):
],
}, body['job_spec'])
self.assertStartsWith(body['display_name'], 'tfx_')
self._mock_get.assert_called_with(name='ucaip_job_study_id')
self._mock_get.assert_called_with(name='ucaip_job_study_id', retry=mock.ANY)

@mock.patch(
'tfx.extensions.google_cloud_ai_platform.training_clients.gapic')
@@ -269,7 +269,7 @@ def testStartAIPTrainingWithUserContainer_uCAIP(self, mock_gapic):
],
}, body['job_spec'])
self.assertEqual(body['display_name'], 'my_jobid')
self._mock_get.assert_called_with(name='ucaip_job_study_id')
self._mock_get.assert_called_with(name='ucaip_job_study_id', retry=mock.ANY)

def _setUpPredictionMocks(self):
self._serving_path = os.path.join(self._output_data_dir, 'serving_path')
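The updated assertions use `mock.ANY` so the tests only require that *some* retry policy is forwarded, without pinning the exact `Retry` object. A minimal standalone illustration:

```python
from unittest import mock

m = mock.Mock()
m(name='ucaip_job_study_id', retry=object())
# mock.ANY compares equal to anything, so this passes regardless of which
# retry policy instance was actually passed.
m.assert_called_with(name='ucaip_job_study_id', retry=mock.ANY)
```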
10 changes: 8 additions & 2 deletions tfx/extensions/google_cloud_ai_platform/training_clients.py
@@ -19,6 +19,7 @@
from typing import Any, Dict, List, Optional, Text, Union

from absl import logging
from google.api_core import retry
from google.cloud.aiplatform import gapic
from google.cloud.aiplatform_v1beta1.types.custom_job import CustomJob
from google.cloud.aiplatform_v1beta1.types.job_state import JobState
@@ -150,6 +151,9 @@ def get_job_name(self) -> Text:
return self._job_name


_get_job_retry = retry.Retry(deadline=150)


class CAIPJobClient(AbstractJobClient):
"""Class for interacting with CAIP CMLE job."""

@@ -308,7 +312,8 @@ def launch_job(self,

def get_job(self) -> Dict[Text, Text]:
"""Gets the long-running job."""
request = self._client.projects().jobs().get(name=self._job_name)
request = self._client.projects().jobs().get(
name=self._job_name, retry=_get_job_retry)
return request.execute()

def get_job_state(self, response) -> Text:
@@ -496,7 +501,8 @@ def launch_job(self,

def get_job(self) -> CustomJob:
"""Gets the long-running job."""
return self._client.get_custom_job(name=self._job_name)
return self._client.get_custom_job(
name=self._job_name, retry=_get_job_retry)

def get_job_state(self, response) -> JobState:
"""Gets the state of the long-running job.
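A note on the new `_get_job_retry` policy: with only `deadline=150` supplied, `google.api_core.retry.Retry` falls back to its defaults for everything else. Roughly the following, spelled out for clarity (the exact default predicate and backoff values depend on the installed `google-api-core` version):

```python
from google.api_core import exceptions, retry

get_job_retry = retry.Retry(
    # The default predicate retries the usual transient API errors.
    predicate=retry.if_exception_type(
        exceptions.TooManyRequests,      # HTTP 429
        exceptions.InternalServerError,  # HTTP 500
        exceptions.ServiceUnavailable,   # HTTP 503
    ),
    initial=1.0,     # first backoff, in seconds
    maximum=60.0,    # cap on any single backoff
    multiplier=2.0,  # exponential growth factor
    deadline=150.0,  # total retry budget, as set in this change
)

# GAPIC methods accept the policy directly, e.g.
#   client.get_custom_job(name=job_name, retry=get_job_retry)
# and a Retry object can also wrap an arbitrary callable:
#   get_with_retry = get_job_retry(some_flaky_callable)
```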
4 changes: 2 additions & 2 deletions tfx/orchestration/kubeflow/kubeflow_dag_runner.py
@@ -176,8 +176,8 @@ def __init__(
kubeflow_pb2.KubeflowMetadataConfig] = None,
# TODO(b/143883035): Figure out the best practice to put the
# SUPPORTED_LAUNCHER_CLASSES
supported_launcher_classes: Optional[List[Type[
base_component_launcher.BaseComponentLauncher]]] = None,
supported_launcher_classes: List[Type[
base_component_launcher.BaseComponentLauncher]] = None,
**kwargs):
"""Creates a KubeflowDagRunnerConfig object.
4 changes: 2 additions & 2 deletions tfx/orchestration/local/local_pipeline_test.py
@@ -54,7 +54,7 @@ def read(self) -> List[Any]:
return json.load(f)

def write(self, data: List[Any]):
with fileio.open(os.path.join(self.uri, 'dataset.json'), 'w+') as f:
with fileio.open(os.path.join(self.uri, 'dataset.json'), 'w') as f:
json.dump(data, f)


Expand Down Expand Up @@ -88,7 +88,7 @@ def read_from(cls, model_uri: Text) -> 'SimpleModel':

def write_to(self, model_uri: Text) -> None:
data = {'prediction': self.always_predict}
with fileio.open(os.path.join(model_uri, 'model_data.json'), 'w+') as f:
with fileio.open(os.path.join(model_uri, 'model_data.json'), 'w') as f:
json.dump(data, f)


Expand Down
Loading

0 comments on commit 810226a

Please sign in to comment.