diff --git a/sdk/python/kfp/compiler/_op_to_template.py b/sdk/python/kfp/compiler/_op_to_template.py index e4bbb4a8248..01af772b4cc 100644 --- a/sdk/python/kfp/compiler/_op_to_template.py +++ b/sdk/python/kfp/compiler/_op_to_template.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import re import warnings import yaml @@ -287,7 +288,10 @@ def _op_to_template(op: BaseOp): template.setdefault('metadata', {}).setdefault('annotations', {})['pipelines.kubeflow.org/task_display_name'] = processed_op.display_name if isinstance(op, dsl.ContainerOp) and op._metadata: - import json template.setdefault('metadata', {}).setdefault('annotations', {})['pipelines.kubeflow.org/component_spec'] = json.dumps(op._metadata.to_dict(), sort_keys=True) + if isinstance(op, dsl.ContainerOp) and op.execution_options: + if op.execution_options.caching_strategy.max_cache_staleness: + template.setdefault('metadata', {}).setdefault('annotations', {})['pipelines.kubeflow.org/max_cache_staleness'] = str(op.execution_options.caching_strategy.max_cache_staleness) + return template diff --git a/sdk/python/kfp/components/_structures.py b/sdk/python/kfp/components/_structures.py index a68b1cc5b3c..aba11fe0bd6 100644 --- a/sdk/python/kfp/components/_structures.py +++ b/sdk/python/kfp/components/_structures.py @@ -46,6 +46,9 @@ 'AndPredicate', 'OrPredicate', + 'RetryStrategySpec', + 'CachingStrategySpec', + 'ExecutionOptionsSpec', 'TaskSpec', 'GraphSpec', @@ -532,6 +535,17 @@ def __init__(self, super().__init__(locals()) +class CachingStrategySpec(ModelBase): + _serialized_names = { + 'max_cache_staleness': 'maxCacheStaleness', + } + + def __init__(self, + max_cache_staleness: Optional[str] = None, # RFC3339 compliant duration: P30DT1H22M3S + ): + super().__init__(locals()) + + class KubernetesExecutionOptionsSpec(ModelBase): _serialized_names = { 'main_container': 'mainContainer', @@ -549,11 +563,13 @@ def __init__(self, class ExecutionOptionsSpec(ModelBase): _serialized_names = { 'retry_strategy': 'retryStrategy', + 'caching_strategy': 'cachingStrategy', 'kubernetes_options': 'kubernetesOptions', } def __init__(self, retry_strategy: Optional[RetryStrategySpec] = None, + caching_strategy: Optional[CachingStrategySpec] = None, kubernetes_options: Optional[KubernetesExecutionOptionsSpec] = None, ): super().__init__(locals()) diff --git a/sdk/python/kfp/components/structures/components.json_schema.json b/sdk/python/kfp/components/structures/components.json_schema.json index 0a7d0b78cc3..6f788cd65cb 100644 --- a/sdk/python/kfp/components/structures/components.json_schema.json +++ b/sdk/python/kfp/components/structures/components.json_schema.json @@ -320,6 +320,15 @@ "additionalProperties": false }, + "CachingStrategySpec": { + "description": "Optional configuration that specifies how the task execution may be skipped if the output data exist in cache.", + "type": "object", + "properties": { + "maxCacheStaleness": {"type": "string", "format": "duration"} + }, + "additionalProperties": false + }, + "KubernetesExecutionOptionsSpec": { "description": "When running on Kubernetes, KubernetesExecutionOptionsSpec describes changes to the configuration of a Kubernetes Pod that will execute the task.", "type": "object", @@ -336,6 +345,7 @@ "type": "object", "properties": { "retryStrategy": {"$ref": "#/definitions/RetryStrategySpec"}, + "cachingStrategy": {"$ref": "#/definitions/CachingStrategySpec"}, "kubernetesOptions": {"$ref": "#/definitions/KubernetesExecutionOptionsSpec"} }, "additionalProperties": false diff --git a/sdk/python/kfp/components/structures/components.json_schema.outline.yaml b/sdk/python/kfp/components/structures/components.json_schema.outline.yaml index f5f92eb08f2..340a68d030f 100644 --- a/sdk/python/kfp/components/structures/components.json_schema.outline.yaml +++ b/sdk/python/kfp/components/structures/components.json_schema.outline.yaml @@ -68,6 +68,8 @@ executionOptions: #ExecutionOptionsSpec retryStrategy: #RetryStrategySpec maxRetries: integer + cachingStrategy: #CachingStrategySpec + maxCacheStaleness: string kubernetesOptions: #KubernetesExecutionOptionsSpec metadata: io.k8s.apimachinery.pkg.apis.meta.v1.ObjectMeta mainContainer: io.k8s.api.core.v1.Container diff --git a/sdk/python/kfp/dsl/_container_op.py b/sdk/python/kfp/dsl/_container_op.py index 26eed8e95d5..816341142d1 100644 --- a/sdk/python/kfp/dsl/_container_op.py +++ b/sdk/python/kfp/dsl/_container_op.py @@ -25,7 +25,7 @@ ) from . import _pipeline_param -from ..components.structures import ComponentSpec +from ..components.structures import ComponentSpec, ExecutionOptionsSpec, CachingStrategySpec # generics T = TypeVar('T') @@ -1089,6 +1089,10 @@ def _decorated(*args, **kwargs): self._metadata = None + self.execution_options = ExecutionOptionsSpec( + caching_strategy=CachingStrategySpec(), + ) + self.outputs = {} if file_outputs: self.outputs = { diff --git a/sdk/python/tests/compiler/compiler_tests.py b/sdk/python/tests/compiler/compiler_tests.py index 5f12f4c18de..a2e04697d84 100644 --- a/sdk/python/tests/compiler/compiler_tests.py +++ b/sdk/python/tests/compiler/compiler_tests.py @@ -801,3 +801,12 @@ def some_name(): template_names = set(template['name'] for template in workflow_dict['spec']['templates']) self.assertGreater(len(template_names), 1) self.assertEqual(template_names, {'some-name', 'some-name-2'}) + + def test_set_execution_options_caching_strategy(self): + def some_pipeline(): + task = some_op() + task.execution_options.caching_strategy.max_cache_staleness = "P30D" + + workflow_dict = kfp.compiler.Compiler()._compile(some_pipeline) + template = workflow_dict['spec']['templates'][0] + self.assertEqual(template['metadata']['annotations']['pipelines.kubeflow.org/max_cache_staleness'], "P30D") diff --git a/sdk/python/tests/components/test_graph_components.py b/sdk/python/tests/components/test_graph_components.py index 48fa8fed8c6..9df214c3d4a 100644 --- a/sdk/python/tests/components/test_graph_components.py +++ b/sdk/python/tests/components/test_graph_components.py @@ -128,6 +128,21 @@ def test_handle_parsing_predicates(self): struct = load_yaml(component_text) ComponentSpec.from_dict(struct) + def test_handle_parsing_task_execution_options_caching_strategy(self): + component_text = '''\ +implementation: + graph: + tasks: + task 1: + componentRef: {name: Comp 1} + executionOptions: + cachingStrategy: + maxCacheStaleness: P30D +''' + struct = load_yaml(component_text) + component_spec = ComponentSpec.from_dict(struct) + self.assertEqual(component_spec.implementation.graph.tasks['task 1'].execution_options.caching_strategy.max_cache_staleness, 'P30D') + def test_handle_parsing_task_container_spec_options(self): component_text = '''\ implementation: