From 616eb38f457341fd4d0e9a745ddbff5b43a5f84d Mon Sep 17 00:00:00 2001 From: weiwee Date: Wed, 11 Jan 2023 00:00:59 -0800 Subject: [PATCH] fix(mlmd): add is_input properties Signed-off-by: weiwee --- python/fate/arch/context/_mlmd.py | 83 ++++++++++--------- .../fate/components/entrypoint/component.py | 12 ++- .../fate/components/loader/mlmd/pipeline.py | 24 ++++-- 3 files changed, 71 insertions(+), 48 deletions(-) diff --git a/python/fate/arch/context/_mlmd.py b/python/fate/arch/context/_mlmd.py index f59ba0a39d..aa05eed8a4 100644 --- a/python/fate/arch/context/_mlmd.py +++ b/python/fate/arch/context/_mlmd.py @@ -41,9 +41,17 @@ def get_artifacts(self, taskid): artifacts = self.store.get_artifacts_by_context(context_id) # parameters parameters = [] - data = [] - model = [] - metric = [] + input_data, output_data = [], [] + input_model, output_model = [], [] + input_metric, output_metric = [], [] + + def _to_dict(artifact): + return dict( + uri=artifact.uri, + name=artifact.properties["name"].string_value, + metadata=json.loads(artifact.properties["metadata"].string_value), + ) + for artifact in artifacts: if self.parameter_type_id == artifact.type_id: parameters.append( @@ -53,34 +61,31 @@ def get_artifacts(self, taskid): type=artifact.properties["type"].string_value, ) ) - if self.data_type_id == artifact.type_id: - data.append( - dict( - uri=artifact.uri, - name=artifact.properties["name"].string_value, - metadata=json.loads(artifact.properties["metadata"].string_value), - ) - ) - - if self.model_type_id == artifact.type_id: - model.append( - dict( - uri=artifact.uri, - name=artifact.properties["name"].string_value, - metadata=json.loads(artifact.properties["metadata"].string_value), - ) - ) - - if self.metric_type_id == artifact.type_id: - metric.append( - dict( - uri=artifact.uri, - name=artifact.properties["name"].string_value, - metadata=json.loads(artifact.properties["metadata"].string_value), - ) - ) - - return dict(parameters=parameters, data=data, model=model, metric=metric) + if artifact.type_id in {self.data_type_id, self.model_type_id, self.metric_type_id}: + is_input = artifact.properties["is_input"].bool_value + + if self.data_type_id == artifact.type_id: + if is_input: + input_data.append(_to_dict(artifact)) + else: + output_data.append(_to_dict(artifact)) + + if self.model_type_id == artifact.type_id: + if is_input: + input_model.append(_to_dict(artifact)) + else: + output_model.append(_to_dict(artifact)) + + if self.metric_type_id == artifact.type_id: + if is_input: + input_metric.append(_to_dict(artifact)) + else: + output_metric.append(_to_dict(artifact)) + return dict( + parameters=parameters, + input=dict(data=input_data, model=input_model, metric=input_metric), + output=dict(data=output_data, model=output_model, metric=output_metric), + ) def get_or_create_task_context(self, taskid): task_context_run = self.store.get_context_by_type_and_name("TaskContext", taskid) @@ -155,19 +160,20 @@ def add_parameter(self, name: str, value): [artifact_id] = self.store.put_artifacts([artifact]) return artifact_id - def add_data_artifact(self, name: str, uri: str, metadata: dict): - return self.add_artifact(self.data_type_id, name, uri, metadata) + def add_data_artifact(self, name: str, uri: str, metadata: dict, is_input): + return self.add_artifact(self.data_type_id, name, uri, metadata, is_input) - def add_model_artifact(self, name: str, uri: str, metadata: dict): - return self.add_artifact(self.model_type_id, name, uri, metadata) + def add_model_artifact(self, name: str, uri: str, metadata: dict, is_input): + return self.add_artifact(self.model_type_id, name, uri, metadata, is_input) - def add_metric_artifact(self, name: str, uri: str, metadata: dict): - return self.add_artifact(self.metric_type_id, name, uri, metadata) + def add_metric_artifact(self, name: str, uri: str, metadata: dict, is_input): + return self.add_artifact(self.metric_type_id, name, uri, metadata, is_input) - def add_artifact(self, type_id: int, name: str, uri: str, metadata: dict): + def add_artifact(self, type_id: int, name: str, uri: str, metadata: dict, is_input): artifact = metadata_store_pb2.Artifact() artifact.uri = uri artifact.properties["name"].string_value = name + artifact.properties["is_input"].bool_value = is_input artifact.properties["metadata"].string_value = json.dumps(metadata) artifact.type_id = type_id [artifact_id] = self.store.put_artifacts([artifact]) @@ -227,6 +233,7 @@ def create_artifact_type(self, name): artifact_type.name = name artifact_type.properties["uri"] = metadata_store_pb2.STRING artifact_type.properties["name"] = metadata_store_pb2.STRING + artifact_type.properties["is_input"] = metadata_store_pb2.BOOLEAN artifact_type.properties["metadata"] = metadata_store_pb2.STRING artifact_type_id = self.store.put_artifact_type(artifact_type) return artifact_type_id diff --git a/python/fate/components/entrypoint/component.py b/python/fate/components/entrypoint/component.py index bbeb79dac6..bb97d9d95e 100644 --- a/python/fate/components/entrypoint/component.py +++ b/python/fate/components/entrypoint/component.py @@ -83,9 +83,11 @@ def execute_component(config: TaskConfigSpec): input_metric_artifacts = parse_input_metric(component, stage, role, config.inputs.artifacts) # log output artifacts for name, artifact in input_data_artifacts.items(): - mlmd.io.log_input_artifact(name, artifact) + if artifact is not None: + mlmd.io.log_input_artifact(name, artifact) for name, artifact in input_metric_artifacts.items(): - mlmd.io.log_input_artifact(name, artifact) + if artifact is not None: + mlmd.io.log_input_artifact(name, artifact) # wrap model artifact input_model_artifacts = { @@ -119,9 +121,11 @@ def execute_component(config: TaskConfigSpec): # log output artifacts for name, artifact in output_data_artifacts.items(): - mlmd.io.log_output_data(name, artifact) + if artifact is not None: + mlmd.io.log_output_data(name, artifact) for name, artifact in output_metric_artifacts.items(): - mlmd.io.log_output_metric(name, artifact) + if artifact is not None: + mlmd.io.log_output_metric(name, artifact) except Exception as e: tb = traceback.format_exc() diff --git a/python/fate/components/loader/mlmd/pipeline.py b/python/fate/components/loader/mlmd/pipeline.py index be5382a91e..982e2aee58 100644 --- a/python/fate/components/loader/mlmd/pipeline.py +++ b/python/fate/components/loader/mlmd/pipeline.py @@ -68,37 +68,49 @@ def log_input_parameter(self, key, value): self._mlmd.put_artifact_to_task_context(self._taskid, artifact_id) def log_input_data(self, key, value): - artifact_id = self._mlmd.add_data_artifact(name=value.name, uri=value.uri, metadata=value.metadata) + artifact_id = self._mlmd.add_data_artifact( + name=value.name, uri=value.uri, metadata=value.metadata, is_input=True + ) execution_id = self._mlmd.get_or_create_task(self._taskid).id self._mlmd.record_input_event(execution_id=execution_id, artifact_id=artifact_id) self._mlmd.put_artifact_to_task_context(self._taskid, artifact_id) def log_input_model(self, key, value): - artifact_id = self._mlmd.add_model_artifact(name=value.name, uri=value.uri, metadata=value.metadata) + artifact_id = self._mlmd.add_model_artifact( + name=value.name, uri=value.uri, metadata=value.metadata, is_input=True + ) execution_id = self._mlmd.get_or_create_task(self._taskid).id self._mlmd.record_input_event(execution_id=execution_id, artifact_id=artifact_id) self._mlmd.put_artifact_to_task_context(self._taskid, artifact_id) def log_input_metric(self, key, value): - artifact_id = self._mlmd.add_metric_artifact(name=value.name, uri=value.uri, metadata=value.metadata) + artifact_id = self._mlmd.add_metric_artifact( + name=value.name, uri=value.uri, metadata=value.metadata, is_input=True + ) execution_id = self._mlmd.get_or_create_task(self._taskid).id self._mlmd.record_input_event(execution_id=execution_id, artifact_id=artifact_id) self._mlmd.put_artifact_to_task_context(self._taskid, artifact_id) def log_output_data(self, key, value): - artifact_id = self._mlmd.add_data_artifact(name=value.name, uri=value.uri, metadata=value.metadata) + artifact_id = self._mlmd.add_data_artifact( + name=value.name, uri=value.uri, metadata=value.metadata, is_input=False + ) execution_id = self._mlmd.get_or_create_task(self._taskid).id self._mlmd.record_output_event(execution_id=execution_id, artifact_id=artifact_id) self._mlmd.put_artifact_to_task_context(self._taskid, artifact_id) def log_output_model(self, key, value, metadata={}): - artifact_id = self._mlmd.add_model_artifact(name=value.name, uri=value.uri, metadata=value.metadata) + artifact_id = self._mlmd.add_model_artifact( + name=value.name, uri=value.uri, metadata=value.metadata, is_input=False + ) execution_id = self._mlmd.get_or_create_task(self._taskid).id self._mlmd.record_output_event(execution_id=execution_id, artifact_id=artifact_id) self._mlmd.put_artifact_to_task_context(self._taskid, artifact_id) def log_output_metric(self, key, value): - artifact_id = self._mlmd.add_metric_artifact(name=value.name, uri=value.uri, metadata=value.metadata) + artifact_id = self._mlmd.add_metric_artifact( + name=value.name, uri=value.uri, metadata=value.metadata, is_input=False + ) execution_id = self._mlmd.get_or_create_task(self._taskid).id self._mlmd.record_output_event(execution_id=execution_id, artifact_id=artifact_id) self._mlmd.put_artifact_to_task_context(self._taskid, artifact_id)