Skip to content

Commit

Permalink
Add tag to ExternalPipelineChannel so we can get artifacts by tags.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 630465771
  • Loading branch information
tfx-copybara committed May 23, 2024
1 parent a4d4cbe commit 40c3a82
Show file tree
Hide file tree
Showing 4 changed files with 308 additions and 4 deletions.
76 changes: 73 additions & 3 deletions tfx/dsl/compiler/node_inputs_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
# limitations under the License.
"""Compiler submodule specialized for NodeInputs."""

from collections.abc import Iterable
from typing import Type, cast
from collections.abc import Iterable, Sequence
from typing import Optional, Type, cast

from tfx import types
from tfx.dsl.compiler import compiler_context
Expand All @@ -41,6 +41,8 @@

from ml_metadata.proto import metadata_store_pb2

PropertyPredicate = pipeline_pb2.PropertyPredicate


def _get_tfx_value(value: str) -> pipeline_pb2.Value:
"""Returns a TFX Value containing the provided string."""
Expand Down Expand Up @@ -137,12 +139,16 @@ def compile_op_node(op_node: resolver_op.OpNode):
def _compile_channel_pb_contexts(
    context_types_and_names: Iterable[tuple[str, pipeline_pb2.Value]],
    result: pipeline_pb2.InputSpec.Channel,
    property_predicate: Optional[pipeline_pb2.PropertyPredicate] = None,
):
  """Appends one context query per (type, name) pair to the channel.

  Args:
    context_types_and_names: Pairs of (context type name, context name value)
      to add as context queries on the channel.
    result: The InputSpec.Channel proto that receives the context queries.
    property_predicate: If given, copied onto every context query added by
      this call.
  """
  for type_name, name_value in context_types_and_names:
    context_query = result.context_queries.add()
    context_query.type.name = type_name
    # An empty/default Value proto is falsy; in that case the name query is
    # deliberately left unset so the query matches on type (and predicate)
    # alone.
    if name_value:
      context_query.name.CopyFrom(name_value)
    if property_predicate:
      context_query.property_predicate.CopyFrom(property_predicate)


def _compile_channel_pb(
Expand Down Expand Up @@ -175,6 +181,65 @@ def _compile_channel_pb(
result.output_key = output_key


def _complie_run_context_predicate(
    run_context_predicates: Sequence[tuple[str, metadata_store_pb2.Value]],
    result_input_channel: pipeline_pb2.InputSpec.Channel,
):
  """Compiles run context property predicates into InputSpec.Channel.

  Each (property_name, value) pair becomes an EQ value_comparator on a custom
  property; multiple comparators are AND-ed together into a left-leaning
  binary tree and attached to a single "pipeline_run" context query whose
  name is left empty.

  NOTE(review): the function name is misspelled ("complie") but is kept
  as-is because callers in this module use the same spelling.

  Args:
    run_context_predicates: (property name, target value) pairs to match
      against pipeline run context custom properties. May be empty, in which
      case this function is a no-op.
    result_input_channel: The InputSpec.Channel proto to attach the
      pipeline_run context query (with predicate) to.
  """
  if not run_context_predicates:
    return

  comparators = [
      PropertyPredicate(
          value_comparator=PropertyPredicate.ValueComparator(
              property_name=property_name,
              op=PropertyPredicate.ValueComparator.Op.EQ,
              target_value=pipeline_pb2.Value(field_value=target_value),
              is_custom_property=True,
          )
      )
      for property_name, target_value in run_context_predicates
  ]

  # Left-fold the comparators into one predicate joined by AND. With a single
  # comparator this is just that comparator; with N comparators the result is
  # AND(AND(...AND(c0, c1)..., cN-1), identical to chaining pairwise.
  combined_predicate = comparators[0]
  for comparator in comparators[1:]:
    combined_predicate = PropertyPredicate(
        binary_logical_operator=PropertyPredicate.BinaryLogicalOperator(
            op=PropertyPredicate.BinaryLogicalOperator.LogicalOp.AND,
            lhs=combined_predicate,
            rhs=comparator,
        )
    )

  _compile_channel_pb_contexts(
      [(
          constants.PIPELINE_RUN_CONTEXT_TYPE_NAME,
          _get_tfx_value(''),
      )],
      result_input_channel,
      combined_predicate,
  )


def _compile_input_spec(
*,
pipeline_ctx: compiler_context.PipelineContext,
Expand Down Expand Up @@ -249,6 +314,11 @@ def _compile_input_spec(
result_input_channel,
)

if channel.run_context_predicates:
_complie_run_context_predicate(
channel.run_context_predicates, result_input_channel
)

if pipeline_ctx.pipeline.platform_config:
project_config = (
pipeline_ctx.pipeline.platform_config.project_platform_config
Expand Down
209 changes: 209 additions & 0 deletions tfx/dsl/compiler/node_inputs_compiler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from tfx.types import standard_artifacts

from google.protobuf import text_format
from ml_metadata.proto import metadata_store_pb2


class DummyArtifact(types.Artifact):
Expand Down Expand Up @@ -292,6 +293,214 @@ def testCompileInputGraph(self):
ctx, node, channel, result)
self.assertEqual(input_graph_id, second_input_graph_id)

def testCompilePropertyPredicateForTags(self):
  """Tests compiling ExternalPipelineChannel run_context_predicates.

  Covers three cases:
    - zero tags: no "pipeline_run" context query is added (only the
      "pipeline" and "node" queries are present);
    - one tag: a single value_comparator PropertyPredicate is attached to a
      "pipeline_run" context query with an empty name;
    - three tags: the comparators are AND-ed into a left-leaning binary tree
      of binary_logical_operator predicates.
  """
  with self.subTest('zero tag'):
    consumer = DummyNode(
        'MyConsumer',
        inputs={
            'input_key': channel_types.ExternalPipelineChannel(
                artifact_type=DummyArtifact,
                owner='MyProducer',
                pipeline_name='pipeline_name',
                producer_component_id='producer_component_id',
                output_key='z',
                run_context_predicates=[],
            )
        },
    )
    result = self._compile_node_inputs(consumer, components=[consumer])
    self.assertLen(result.inputs['input_key'].channels, 1)
    context_queries = result.inputs['input_key'].channels[0].context_queries
    # With no predicates, only the "pipeline" and "node" context queries are
    # compiled; no "pipeline_run" query appears.
    self.assertLen(context_queries, 2)
    self.assertProtoEquals(
        """
        type {
          name: "pipeline"
        }
        name {
          field_value {
            string_value: "pipeline_name"
          }
        }
        """,
        context_queries[0],
    )

    self.assertProtoEquals(
        """
        type {
          name: "node"
        }
        name {
          field_value {
            string_value: "pipeline_name.producer_component_id"
          }
        }
        """,
        context_queries[1],
    )

  with self.subTest('one tag'):
    consumer = DummyNode(
        'MyConsumer',
        inputs={
            'input_key': channel_types.ExternalPipelineChannel(
                artifact_type=DummyArtifact,
                owner='MyProducer',
                pipeline_name='pipeline_name',
                producer_component_id='producer_component_id',
                output_key='z',
                run_context_predicates=[
                    ('tag_1', metadata_store_pb2.Value(bool_value=True))
                ],
            )
        },
    )

    result = self._compile_node_inputs(consumer, components=[consumer])

    self.assertLen(result.inputs['input_key'].channels, 1)
    context_queries = result.inputs['input_key'].channels[0].context_queries
    # One predicate adds a third, "pipeline_run" context query.
    self.assertLen(context_queries, 3)
    self.assertProtoEquals(
        """
        type {
          name: "pipeline"
        }
        name {
          field_value {
            string_value: "pipeline_name"
          }
        }
        """,
        context_queries[0],
    )

    self.assertProtoEquals(
        """
        type {
          name: "node"
        }
        name {
          field_value {
            string_value: "pipeline_name.producer_component_id"
          }
        }
        """,
        context_queries[1],
    )

    # A single tag compiles to one value_comparator (no binary operator);
    # the pipeline_run context name is intentionally empty.
    self.assertProtoEquals(
        """
        type {
          name: "pipeline_run"
        }
        name {
          field_value {
            string_value: ""
          }
        }
        property_predicate {
          value_comparator {
            property_name: "tag_1"
            target_value {
              field_value {
                bool_value: true
              }
            }
            op: EQ
            is_custom_property: true
          }
        }
        """,
        context_queries[2],
    )

  with self.subTest('three tags'):
    consumer = DummyNode(
        'MyConsumer',
        inputs={
            'input_key': channel_types.ExternalPipelineChannel(
                artifact_type=DummyArtifact,
                owner='MyProducer',
                pipeline_name='pipeline_name',
                producer_component_id='producer_component_id',
                output_key='z',
                run_context_predicates=[
                    ('tag_1', metadata_store_pb2.Value(bool_value=True)),
                    ('tag_2', metadata_store_pb2.Value(bool_value=True)),
                    ('tag_3', metadata_store_pb2.Value(bool_value=True)),
                ],
            )
        },
    )

    result = self._compile_node_inputs(consumer, components=[consumer])

    self.assertLen(result.inputs['input_key'].channels, 1)
    context_queries = result.inputs['input_key'].channels[0].context_queries
    self.assertLen(context_queries, 3)

    # Three tags are combined left-associatively:
    # AND(AND(tag_1, tag_2), tag_3).
    self.assertProtoEquals(
        """
        type {
          name: "pipeline_run"
        }
        name {
          field_value {
            string_value: ""
          }
        }
        property_predicate {
          binary_logical_operator {
            op: AND
            lhs {
              binary_logical_operator {
                op: AND
                lhs {
                  value_comparator {
                    property_name: "tag_1"
                    target_value {
                      field_value {
                        bool_value: true
                      }
                    }
                    op: EQ
                    is_custom_property: true
                  }
                }
                rhs {
                  value_comparator {
                    property_name: "tag_2"
                    target_value {
                      field_value {
                        bool_value: true
                      }
                    }
                    op: EQ
                    is_custom_property: true
                  }
                }
              }
            }
            rhs {
              value_comparator {
                property_name: "tag_3"
                target_value {
                  field_value {
                    bool_value: true
                  }
                }
                op: EQ
                is_custom_property: true
              }
            }
          }
        }
        """,
        context_queries[2],
    )

def testCompileInputGraphRef(self):
with dummy_artifact_list.given_output_type(DummyArtifact):
x1 = dummy_artifact_list()
Expand Down
9 changes: 8 additions & 1 deletion tfx/types/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,6 +722,9 @@ def __init__(
producer_component_id: str,
output_key: str,
pipeline_run_id: str = '',
run_context_predicates: Sequence[
tuple[str, metadata_store_pb2.Value]
] = (),
):
"""Initialization of ExternalPipelineChannel.
Expand All @@ -733,13 +736,16 @@ def __init__(
output_key: The output key when producer component produces the artifacts
in this Channel.
pipeline_run_id: (Optional) Pipeline run id the artifacts belong to.
run_context_predicates: (Optional) A list of run context property
predicates to filter run contexts.
"""
super().__init__(type=artifact_type)
self.owner = owner
self.pipeline_name = pipeline_name
self.producer_component_id = producer_component_id
self.output_key = output_key
self.pipeline_run_id = pipeline_run_id
self.run_context_predicates = run_context_predicates

def get_data_dependent_node_ids(self) -> Set[str]:
return set()
Expand All @@ -751,7 +757,8 @@ def __repr__(self) -> str:
f'pipeline_name={self.pipeline_name}, '
f'producer_component_id={self.producer_component_id}, '
f'output_key={self.output_key}, '
f'pipeline_run_id={self.pipeline_run_id})'
f'pipeline_run_id={self.pipeline_run_id}), '
f'run_context_predicates={self.run_context_predicates}'
)


Expand Down
Loading

0 comments on commit 40c3a82

Please sign in to comment.