Skip to content

Commit

Permalink
feat(component): impl basic mlmd with ml-metadata
Browse files Browse the repository at this point in the history
Signed-off-by: weiwee <wbwmat@gmail.com>
  • Loading branch information
sagewe committed Nov 22, 2022
1 parent 229cb29 commit 1bf0d1b
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 22 deletions.
62 changes: 62 additions & 0 deletions python/fate/arch/context/_mlmd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from ml_metadata import metadata_store
from ml_metadata.proto import metadata_store_pb2


class MachineLearningMetadata:
    """Thin wrapper around an ml-metadata ``MetadataStore`` for task bookkeeping.

    Tasks are stored as executions of type ``"Task"`` carrying ``state``,
    ``exception`` and ``safe_terminate`` properties. A ``"Job"`` context type
    with a ``jobid`` property is declared as well. Both types are registered
    with the store lazily, on first access of ``task_type_id`` /
    ``job_type_id``.
    """

    def __init__(self, backend="sqlite", metadata=None) -> None:
        """Open the underlying store.

        Args:
            backend: storage backend name; only ``"sqlite"`` is configured by
                ``create_store``.
            metadata: backend settings. For ``"sqlite"`` the key
                ``filename_uri`` is required and ``connection_mode`` is
                optional. ``None`` is treated as an empty mapping.
        """
        # A ``None`` sentinel replaces the former ``metadata={}`` default so
        # that no single dict instance is shared between calls.
        if metadata is None:
            metadata = {}
        self.store = self.create_store(backend, metadata)
        self._job_type_id = None  # context type
        self._task_type_id = None  # execution type

    def update_task_state(self, task_run, state, exception=None):
        """Set ``state`` (and optionally ``exception``) on a task and persist it.

        Args:
            task_run: execution object as returned by ``get_or_create_task``.
            state: new value for the ``state`` string property.
            exception: optional value for the ``exception`` string property;
                left untouched when ``None``.
        """
        task_run.properties["state"].string_value = state
        if exception is not None:
            task_run.properties["exception"].string_value = exception
        self.store.put_executions([task_run])

    def get_task_safe_terminate_flag(self, task_run):
        """Return the ``safe_terminate`` boolean held by ``task_run``.

        The value is read from the object passed in; the store is not queried.
        """
        # NOTE(review): updates written to the store by another process would
        # not be visible through a previously fetched ``task_run`` - confirm
        # whether callers need a re-fetch here.
        return task_run.properties["safe_terminate"].bool_value

    def set_task_safe_terminate_flag(self, task_run):
        """Set ``safe_terminate`` to ``True`` on ``task_run`` and persist it."""
        task_run.properties["safe_terminate"].bool_value = True
        self.store.put_executions([task_run])

    def get_or_create_task(self, taskid):
        """Fetch the ``"Task"`` execution named ``taskid``, creating it if absent.

        A newly created task starts with ``state="INIT"`` and
        ``safe_terminate=False`` and is written to the store immediately.

        Args:
            taskid: execution name used for lookup and creation.

        Returns:
            The existing or newly stored execution object.
        """
        task_run = self.store.get_execution_by_type_and_name("Task", taskid)
        if task_run is None:
            task_run = metadata_store_pb2.Execution()
            task_run.type_id = self.task_type_id
            task_run.name = taskid
            task_run.properties["state"].string_value = "INIT"
            task_run.properties["safe_terminate"].bool_value = False
            # put_executions yields one id per execution; keep it on the
            # local object so later updates address the stored row.
            [task_run_id] = self.store.put_executions([task_run])
            task_run.id = task_run_id
        return task_run

    @classmethod
    def create_store(cls, backend, metadata):
        """Build a ``MetadataStore`` from a backend name and its settings.

        Args:
            backend: backend name; only ``"sqlite"`` fills in the config.
            metadata: mapping with ``filename_uri`` (required for sqlite) and
                optionally ``connection_mode``.

        Returns:
            A ``metadata_store.MetadataStore`` built from the connection config.

        Raises:
            KeyError: if ``backend`` is ``"sqlite"`` and ``filename_uri`` is
                missing from ``metadata``.
        """
        connection_config = metadata_store_pb2.ConnectionConfig()
        if backend == "sqlite":
            connection_config.sqlite.filename_uri = metadata["filename_uri"]
            # 3 is presumably READWRITE_OPENCREATE in the sqlite config enum -
            # TODO confirm against the ml-metadata proto.
            connection_config.sqlite.connection_mode = metadata.get("connection_mode", 3)
        # NOTE(review): any other backend name leaves the config empty and is
        # handed to MetadataStore as is - confirm whether that should be
        # rejected explicitly.
        return metadata_store.MetadataStore(connection_config)

    @property
    def job_type_id(self):
        """Id of the ``"Job"`` context type, registered on first access."""
        if self._job_type_id is None:
            job_type = metadata_store_pb2.ContextType()
            job_type.name = "Job"
            job_type.properties["jobid"] = metadata_store_pb2.STRING
            self._job_type_id = self.store.put_context_type(job_type)
        return self._job_type_id

    @property
    def task_type_id(self):
        """Id of the ``"Task"`` execution type, registered on first access."""
        if self._task_type_id is None:
            task_type = metadata_store_pb2.ExecutionType()
            task_type.name = "Task"
            task_type.properties["state"] = metadata_store_pb2.STRING
            task_type.properties["exception"] = metadata_store_pb2.STRING
            task_type.properties["safe_terminate"] = metadata_store_pb2.BOOLEAN
            self._task_type_id = self.store.put_execution_type(task_type)
        return self._task_type_id
3 changes: 2 additions & 1 deletion python/fate/components/entrypoint/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from fate.arch.context import Context
from fate.components.loader import load_component
from fate.components.spec.mlmd import get_mlmd
from fate.components.spec.task import TaskConfigSpec

logger = logging.getLogger(__name__)
Expand All @@ -18,8 +19,8 @@ class ParamsValidateFailed(ComponentExecException):


def execute_component(config: TaskConfigSpec):
mlmd = config.conf.mlmd
context_name = config.execution_id
mlmd = get_mlmd(config.conf.mlmd, context_name)
computing = get_computing(config)
federation = get_federation(config, computing)
device = config.conf.get_device()
Expand Down
3 changes: 0 additions & 3 deletions python/fate/components/entrypoint/component_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,6 @@ def execute(process_tag, config, config_entrypoint, properties):
logger.debug("logger installed")
logger.debug(f"task config: {task_config}")

# init mlmd
task_config.conf.mlmd.init(task_config.execution_id)

from fate.components.entrypoint.component import execute_component

execute_component(task_config)
Expand Down
33 changes: 19 additions & 14 deletions python/fate/components/spec/mlmd.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import json
from typing import Literal, Protocol

import pydantic
from pydantic import BaseModel

"""
Expand All @@ -24,16 +22,20 @@ def safe_terminate(self):
...


class PipelineMLMD(BaseModel):
class PipelineMLMDDesc(BaseModel):
class PipelineMLMDMetaData(BaseModel):
state_path: str
terminate_state_path: str
db: str

type: Literal["pipeline"]
metadata: PipelineMLMDMetaData

def init(self, execution_id: str):
...

class PipelineMLMD:
def __init__(self, mlmd: PipelineMLMDDesc, taskid) -> None:
from fate.arch.context._mlmd import MachineLearningMetadata

self._mlmd = MachineLearningMetadata(metadata=dict(filename_uri=mlmd.metadata.db))
self._task = self._mlmd.get_or_create_task(taskid)

def log_excution_start(self):
return self._log_state("running")
Expand All @@ -42,17 +44,15 @@ def log_excution_end(self):
return self._log_state("finish")

def log_excution_exception(self, message: dict):
return self._log_state("exception", message)
import json

self._log_state("exception", json.dumps(message))

def _log_state(self, state, message=None):
data = dict(state=state)
if message is not None:
data["message"] = message
with open(self.metadata.state_path, "w") as f:
json.dump(data, f)
self._mlmd.update_task_state(self._task, state, message)

def safe_terminate(self):
return True
return self._mlmd.get_task_safe_terminate_flag(self._task)


class FlowMLMD(BaseModel):
Expand Down Expand Up @@ -84,3 +84,8 @@ def _log_state(self, state, message=None):

def safe_terminate(self):
...


def get_mlmd(mlmd, taskid):
    """Build the runtime MLMD handle matching a parsed ``mlmd`` config.

    Args:
        mlmd: the ``mlmd`` section of the task configuration.
        taskid: name of the task the handle should track.

    Returns:
        A ``PipelineMLMD`` when ``mlmd`` is a ``PipelineMLMDDesc``; ``None``
        for any other descriptor type.
    """
    # NOTE(review): other descriptor types (e.g. FlowMLMD) yield None -
    # confirm that callers cope with that.
    if not isinstance(mlmd, PipelineMLMDDesc):
        return None
    return PipelineMLMD(mlmd, taskid)
4 changes: 2 additions & 2 deletions python/fate/components/spec/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Any, Dict, List, Literal, Optional, Union

import pydantic
from fate.components.spec.mlmd import FlowMLMD, PipelineMLMD
from fate.components.spec.mlmd import FlowMLMD, PipelineMLMDDesc

from .logger import CustomLogger, FlowLogger, PipelineLogger

Expand Down Expand Up @@ -42,7 +42,7 @@ class TaskConfSpec(pydantic.BaseModel):
computing: TaskDistributedComputingBackendSpec
federation: TaskFederationBackendSpec
logger: Union[PipelineLogger, FlowLogger, CustomLogger]
mlmd: Union[PipelineMLMD, FlowMLMD]
mlmd: Union[PipelineMLMDDesc, FlowMLMD]

def get_device(self):
from fate.arch.unify import device
Expand Down
3 changes: 1 addition & 2 deletions schemas/tasks/lr.train.guest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ conf:
mlmd:
type: pipeline
metadata:
state_path: /Users/sage/fate_tmp/state
terminate_state_path: /Users/sage/fate_tmp/terminate_state
db: /Users/sage/mlmd.db
logger:
type: pipeline
metadata:
Expand Down

0 comments on commit 1bf0d1b

Please sign in to comment.