Skip to content

Commit

Permalink
fix(mlmd): add is_input properties
Browse files Browse the repository at this point in the history
Signed-off-by: weiwee <wbwmat@gmail.com>
  • Loading branch information
sagewe committed Jan 11, 2023
1 parent a9b9944 commit 616eb38
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 48 deletions.
83 changes: 45 additions & 38 deletions python/fate/arch/context/_mlmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,17 @@ def get_artifacts(self, taskid):
artifacts = self.store.get_artifacts_by_context(context_id)
# parameters
parameters = []
data = []
model = []
metric = []
input_data, output_data = [], []
input_model, output_model = [], []
input_metric, output_metric = [], []

def _to_dict(artifact):
    """Flatten one artifact into a plain dict.

    Reads ``artifact.uri`` plus the ``name`` and ``metadata`` entries of
    ``artifact.properties`` (each via ``.string_value``). The metadata
    string is decoded with ``json.loads``.

    Returns a dict with the keys ``uri``, ``name`` and ``metadata``.
    """
    uri = artifact.uri
    props = artifact.properties
    name = props["name"].string_value
    # Metadata is stored as a JSON string, so decode it before returning.
    metadata = json.loads(props["metadata"].string_value)
    return {"uri": uri, "name": name, "metadata": metadata}

for artifact in artifacts:
if self.parameter_type_id == artifact.type_id:
parameters.append(
Expand All @@ -53,34 +61,31 @@ def get_artifacts(self, taskid):
type=artifact.properties["type"].string_value,
)
)
if self.data_type_id == artifact.type_id:
data.append(
dict(
uri=artifact.uri,
name=artifact.properties["name"].string_value,
metadata=json.loads(artifact.properties["metadata"].string_value),
)
)

if self.model_type_id == artifact.type_id:
model.append(
dict(
uri=artifact.uri,
name=artifact.properties["name"].string_value,
metadata=json.loads(artifact.properties["metadata"].string_value),
)
)

if self.metric_type_id == artifact.type_id:
metric.append(
dict(
uri=artifact.uri,
name=artifact.properties["name"].string_value,
metadata=json.loads(artifact.properties["metadata"].string_value),
)
)

return dict(parameters=parameters, data=data, model=model, metric=metric)
if artifact.type_id in {self.data_type_id, self.model_type_id, self.metric_type_id}:
is_input = artifact.properties["is_input"].bool_value

if self.data_type_id == artifact.type_id:
if is_input:
input_data.append(_to_dict(artifact))
else:
output_data.append(_to_dict(artifact))

if self.model_type_id == artifact.type_id:
if is_input:
input_model.append(_to_dict(artifact))
else:
output_model.append(_to_dict(artifact))

if self.metric_type_id == artifact.type_id:
if is_input:
input_metric.append(_to_dict(artifact))
else:
output_metric.append(_to_dict(artifact))
return dict(
parameters=parameters,
input=dict(data=input_data, model=input_model, metric=input_metric),
output=dict(data=output_data, model=output_model, metric=output_metric),
)

def get_or_create_task_context(self, taskid):
task_context_run = self.store.get_context_by_type_and_name("TaskContext", taskid)
Expand Down Expand Up @@ -155,19 +160,20 @@ def add_parameter(self, name: str, value):
[artifact_id] = self.store.put_artifacts([artifact])
return artifact_id

def add_data_artifact(self, name: str, uri: str, metadata: dict):
return self.add_artifact(self.data_type_id, name, uri, metadata)
def add_data_artifact(self, name: str, uri: str, metadata: dict, is_input):
return self.add_artifact(self.data_type_id, name, uri, metadata, is_input)

def add_model_artifact(self, name: str, uri: str, metadata: dict):
return self.add_artifact(self.model_type_id, name, uri, metadata)
def add_model_artifact(self, name: str, uri: str, metadata: dict, is_input):
return self.add_artifact(self.model_type_id, name, uri, metadata, is_input)

def add_metric_artifact(self, name: str, uri: str, metadata: dict):
return self.add_artifact(self.metric_type_id, name, uri, metadata)
def add_metric_artifact(self, name: str, uri: str, metadata: dict, is_input):
return self.add_artifact(self.metric_type_id, name, uri, metadata, is_input)

def add_artifact(self, type_id: int, name: str, uri: str, metadata: dict):
def add_artifact(self, type_id: int, name: str, uri: str, metadata: dict, is_input):
artifact = metadata_store_pb2.Artifact()
artifact.uri = uri
artifact.properties["name"].string_value = name
artifact.properties["is_input"].bool_value = is_input
artifact.properties["metadata"].string_value = json.dumps(metadata)
artifact.type_id = type_id
[artifact_id] = self.store.put_artifacts([artifact])
Expand Down Expand Up @@ -227,6 +233,7 @@ def create_artifact_type(self, name):
artifact_type.name = name
artifact_type.properties["uri"] = metadata_store_pb2.STRING
artifact_type.properties["name"] = metadata_store_pb2.STRING
artifact_type.properties["is_input"] = metadata_store_pb2.BOOLEAN
artifact_type.properties["metadata"] = metadata_store_pb2.STRING
artifact_type_id = self.store.put_artifact_type(artifact_type)
return artifact_type_id
12 changes: 8 additions & 4 deletions python/fate/components/entrypoint/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,11 @@ def execute_component(config: TaskConfigSpec):
input_metric_artifacts = parse_input_metric(component, stage, role, config.inputs.artifacts)
# log input artifacts
for name, artifact in input_data_artifacts.items():
mlmd.io.log_input_artifact(name, artifact)
if artifact is not None:
mlmd.io.log_input_artifact(name, artifact)
for name, artifact in input_metric_artifacts.items():
mlmd.io.log_input_artifact(name, artifact)
if artifact is not None:
mlmd.io.log_input_artifact(name, artifact)

# wrap model artifact
input_model_artifacts = {
Expand Down Expand Up @@ -119,9 +121,11 @@ def execute_component(config: TaskConfigSpec):

# log output artifacts
for name, artifact in output_data_artifacts.items():
mlmd.io.log_output_data(name, artifact)
if artifact is not None:
mlmd.io.log_output_data(name, artifact)
for name, artifact in output_metric_artifacts.items():
mlmd.io.log_output_metric(name, artifact)
if artifact is not None:
mlmd.io.log_output_metric(name, artifact)

except Exception as e:
tb = traceback.format_exc()
Expand Down
24 changes: 18 additions & 6 deletions python/fate/components/loader/mlmd/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,37 +68,49 @@ def log_input_parameter(self, key, value):
self._mlmd.put_artifact_to_task_context(self._taskid, artifact_id)

def log_input_data(self, key, value):
artifact_id = self._mlmd.add_data_artifact(name=value.name, uri=value.uri, metadata=value.metadata)
artifact_id = self._mlmd.add_data_artifact(
name=value.name, uri=value.uri, metadata=value.metadata, is_input=True
)
execution_id = self._mlmd.get_or_create_task(self._taskid).id
self._mlmd.record_input_event(execution_id=execution_id, artifact_id=artifact_id)
self._mlmd.put_artifact_to_task_context(self._taskid, artifact_id)

def log_input_model(self, key, value):
artifact_id = self._mlmd.add_model_artifact(name=value.name, uri=value.uri, metadata=value.metadata)
artifact_id = self._mlmd.add_model_artifact(
name=value.name, uri=value.uri, metadata=value.metadata, is_input=True
)
execution_id = self._mlmd.get_or_create_task(self._taskid).id
self._mlmd.record_input_event(execution_id=execution_id, artifact_id=artifact_id)
self._mlmd.put_artifact_to_task_context(self._taskid, artifact_id)

def log_input_metric(self, key, value):
artifact_id = self._mlmd.add_metric_artifact(name=value.name, uri=value.uri, metadata=value.metadata)
artifact_id = self._mlmd.add_metric_artifact(
name=value.name, uri=value.uri, metadata=value.metadata, is_input=True
)
execution_id = self._mlmd.get_or_create_task(self._taskid).id
self._mlmd.record_input_event(execution_id=execution_id, artifact_id=artifact_id)
self._mlmd.put_artifact_to_task_context(self._taskid, artifact_id)

def log_output_data(self, key, value):
artifact_id = self._mlmd.add_data_artifact(name=value.name, uri=value.uri, metadata=value.metadata)
artifact_id = self._mlmd.add_data_artifact(
name=value.name, uri=value.uri, metadata=value.metadata, is_input=False
)
execution_id = self._mlmd.get_or_create_task(self._taskid).id
self._mlmd.record_output_event(execution_id=execution_id, artifact_id=artifact_id)
self._mlmd.put_artifact_to_task_context(self._taskid, artifact_id)

def log_output_model(self, key, value, metadata={}):
artifact_id = self._mlmd.add_model_artifact(name=value.name, uri=value.uri, metadata=value.metadata)
artifact_id = self._mlmd.add_model_artifact(
name=value.name, uri=value.uri, metadata=value.metadata, is_input=False
)
execution_id = self._mlmd.get_or_create_task(self._taskid).id
self._mlmd.record_output_event(execution_id=execution_id, artifact_id=artifact_id)
self._mlmd.put_artifact_to_task_context(self._taskid, artifact_id)

def log_output_metric(self, key, value):
artifact_id = self._mlmd.add_metric_artifact(name=value.name, uri=value.uri, metadata=value.metadata)
artifact_id = self._mlmd.add_metric_artifact(
name=value.name, uri=value.uri, metadata=value.metadata, is_input=False
)
execution_id = self._mlmd.get_or_create_task(self._taskid).id
self._mlmd.record_output_event(execution_id=execution_id, artifact_id=artifact_id)
self._mlmd.put_artifact_to_task_context(self._taskid, artifact_id)
Expand Down

0 comments on commit 616eb38

Please sign in to comment.