Skip to content

Commit

Permalink
support mlmd in pipeline
Browse files Browse the repository at this point in the history
Signed-off-by: mgqa34 <mgq3374541@163.com>
  • Loading branch information
mgqa34 committed Nov 23, 2022
1 parent 79ce2c4 commit c94822a
Show file tree
Hide file tree
Showing 9 changed files with 150 additions and 136 deletions.
4 changes: 4 additions & 0 deletions python/fate_client/pipeline/conf/env_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ class StandaloneConfig(object):
COMPUTING_ENGINE = conf.get("computing_engine")
FEDERATION_ENGINE = conf.get("federation_engine")

SQLITE_DB = conf.get("sqlite_db")
if not SQLITE_DB:
SQLITE_DB = default_path.joinpath("pipeline_sqlite.db").as_uri()


class LogPath(object):
@classmethod
Expand Down
10 changes: 7 additions & 3 deletions python/fate_client/pipeline/entity/task_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ class TaskRuntimeInputSpec(BaseModel):
artifacts: Optional[Dict[str, IOArtifact]]


class TaskRuntimeOutputSpec(BaseModel):
    """Outputs produced by a task run: artifacts keyed by output name."""
    # Unlike TaskRuntimeInputSpec.artifacts, this mapping is required (not Optional).
    artifacts: Dict[str, IOArtifact]


class MLMDSpec(BaseModel):
type: str
metadata: Dict[str, Any]
Expand Down Expand Up @@ -48,8 +52,8 @@ class RuntimeEnvSpec(BaseModel):
mlmd: MLMDSpec
logger: LOGGERSpec
device: str
distributed_computing_backend: ComputingBackendSpec
federation_backend: FederationBackendSpec
computing: ComputingBackendSpec
federation: FederationBackendSpec


class TaskScheduleSpec(BaseModel):
Expand All @@ -58,6 +62,6 @@ class TaskScheduleSpec(BaseModel):
role: str
stage: str
party_id: Optional[Union[str, int]]
inputs: Optional[Dict[str, IOArtifact]]
inputs: Optional[TaskRuntimeInputSpec]
outputs: Optional[Dict[str, IOArtifact]]
env: RuntimeEnvSpec
4 changes: 2 additions & 2 deletions python/fate_client/pipeline/executor/task_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def _run(self):
status = self._exec_task("run_component",
task_name,
runtime_constructor=runtime_constructor)
if status["summary_status"] != "SUCCESS":
raise ValueError(f"run task {task_name} is failed, status is {status}")
# if status["summary_status"] != "SUCCESS":
# raise ValueError(f"run task {task_name} is failed, status is {status}")

self._runtime_constructor_dict = runtime_constructor_dict
print("Job Finish Successfully!!!")
Expand Down
3 changes: 2 additions & 1 deletion python/fate_client/pipeline/manager/model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ def generate_output_model_uri(cls, output_dir_uri: str, job_id: str, task_name:
role: str, party_id: str, model_suffix: str):
model_id = "_".join([job_id, task_name, role, str(party_id), model_suffix])
model_version = "v0"
suffix = "model.json"
uri_obj = parse_uri(output_dir_uri)
local_path = construct_local_dir(uri_obj.path, *[model_id, model_version])
local_path = construct_local_dir(uri_obj.path, *[model_id, model_version, suffix])
uri_obj = replace_uri_path(uri_obj, str(local_path))
return uri_obj.geturl()

Expand Down
15 changes: 5 additions & 10 deletions python/fate_client/pipeline/manager/resource_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def __init__(self, conf: 'StandaloneConfig'):
self._data_manager = get_data_manager(conf.OUTPUT_DATA_DIR)
self._model_manager = get_model_manager(conf.OUTPUT_MODEL_DIR)
self._metric_manager = get_metric_manager(conf.OUTPUT_METRIC_DIR)
self._status_manager = get_status_manager(conf.OUTPUT_STATUS_DIR)
self._status_manager = get_status_manager().create_status_manager(conf.SQLITE_DB)
self._task_conf_manager = get_task_conf_manager(conf.JOB_DIR)

def generate_output_artifact(self, job_id, task_name, role, party_id, output_key, artifact_type):
Expand All @@ -28,7 +28,8 @@ def generate_output_artifact(self, job_id, task_name, role, party_id, output_key

return IOArtifact(
name=output_key,
uri=model_uri
uri=model_uri,
metadata=dict(format="json")
)
elif artifact_type in ["dataset", "datasets"]:
data_uri = self._generate_output_data_uri(
Expand All @@ -41,7 +42,8 @@ def generate_output_artifact(self, job_id, task_name, role, party_id, output_key

return IOArtifact(
name=output_key,
uri=data_uri
uri=data_uri,
metadata=dict(format="json")
)

def _generate_output_data_uri(self, job_id, task_name, role, party_id, output_key):
Expand Down Expand Up @@ -73,13 +75,6 @@ def _generate_output_metric_uri(self, job_id, role, party):
role,
party)

def generate_output_status_uri(self, job_id, task_name, role, party_id):
return self._status_manager.generate_output_status_uri(self._conf.OUTPUT_STATUS_DIR,
job_id,
task_name,
role,
party_id)

def generate_output_terminate_status_uri(self, job_id, task_name, role, party_id):
return self._status_manager.generate_output_terminate_status_uri(self._conf.OUTPUT_STATUS_DIR,
job_id,
Expand Down
174 changes: 96 additions & 78 deletions python/fate_client/pipeline/manager/status_manager.py
Original file line number Diff line number Diff line change
@@ -1,97 +1,115 @@
import json
import os
from pathlib import Path
from ..utils.uri_tools import parse_uri, replace_uri_path, get_schema_from_uri
from ..utils.file_utils import construct_local_dir, write_json_file
from ..conf.types import UriTypes
from ml_metadata import metadata_store
from ml_metadata.proto import metadata_store_pb2


class LocalFSStatusManager(object):
@classmethod
def generate_output_status_uri(cls, output_dir_uri: str, job_id: str, task_name: str,
role: str, party_id: str):
uri_obj = parse_uri(output_dir_uri)
local_path = construct_local_dir(uri_obj.path, *[job_id, task_name, role, party_id, "status.log"])
uri_obj = replace_uri_path(uri_obj, str(local_path))
return uri_obj.geturl()

@classmethod
def generate_output_terminate_status_uri(cls, output_dir_uri: str, job_id: str, task_name: str,
role: str, party_id: str):
uri_obj = parse_uri(output_dir_uri)
local_path = construct_local_dir(uri_obj.path, *[job_id, task_name, role, party_id, "terminate_status.log"])
uri_obj = replace_uri_path(uri_obj, str(local_path))
return uri_obj.geturl()
class SQLiteStatusManager(object):
def __init__(self, status_uri: str):
    """Bind this status manager to the MLMD store located at ``status_uri``.

    :param status_uri: filename URI of the SQLite database backing MLMD.
    """
    self._meta_manager = MachineLearningMetadata(
        metadata={"filename_uri": status_uri}
    )

@classmethod
def monitor_status(cls, status_uris):
for status_uri in status_uris:
uri_obj = parse_uri(status_uri.status_uri)
if not os.path.exists(uri_obj.path):
def create_status_manager(cls, status_uri):
    # Factory (invoked as a classmethod): build a SQLiteStatusManager bound
    # to the SQLite database at ``status_uri``.
    return SQLiteStatusManager(status_uri)

def monitor_finish_status(self, execution_ids: list):
    """Return True once no tracked execution is still in the "running" state.

    :param execution_ids: MLMD execution names to poll.
    :return: False if any execution's state is "running", else True.
    """
    return all(
        self._meta_manager.get_or_create_task(eid).properties["state"].string_value
        != "running"
        for eid in execution_ids
    )

@classmethod
def record_finish_status(cls, status_uri):
uri_obj = parse_uri(status_uri)
path = Path(uri_obj.path).parent.joinpath("done")
buf = dict(job_status="done")

write_json_file(str(path), buf)
def record_terminate_status(self, execution_ids):
    """Flag every given execution as safe to terminate in the MLMD store.

    :param execution_ids: MLMD execution names to update.
    """
    for eid in execution_ids:
        run = self._meta_manager.get_or_create_task(eid)
        self._meta_manager.set_task_safe_terminate_flag(run)

@classmethod
def get_tasks_status(cls, task_status_uris):
def get_task_results(self, tasks_info):
"""
running/finish/exception
"""
summary_msg = dict()
summary_status = "SUCCESS"
for obj in task_status_uris:
try:
path = parse_uri(obj.task_terminate_status_uri).path
with open(path, "r") as fin:
party_status = json.loads(fin.read())

if party_status["status"]["status"] != "SUCCESS":
summary_status = "FAIL"
except FileNotFoundError:
party_status = dict(
status=dict(
status="FAIL",
extras="can not start task"
)
)
summary_status = "FAIL"

if obj.role not in summary_msg:
summary_msg[obj.role] = dict()
summary_msg[obj.role][obj.party_id] = party_status
summary_status = "success"

for task_info in tasks_info:
role = task_info.role
party_id = task_info.party_id
if role not in summary_msg:
summary_msg[role] = dict()

task_run = self._meta_manager.get_or_create_task(task_info.execution_id)
status = task_run.properties["state"].string_value

summary_msg[role][party_id] = status
if status != "finish":
summary_status = "fail"

ret = dict(summary_status=summary_status,
retmsg=summary_msg)

return ret


class LMDBStatusManager(object):
@classmethod
def generate_output_status_uri(cls, uri_obj, session_id: str, role: str, party_id: str):
...

@classmethod
def record_finish_status(cls, status_uri):
...

@classmethod
def get_task_status(cls, status_uris):
...
class MachineLearningMetadata:
def __init__(self, backend="sqlite", metadata=None) -> None:
    """Wrap an ml-metadata MetadataStore.

    :param backend: storage backend; only "sqlite" is handled by create_store.
    :param metadata: backend connection settings, e.g. {"filename_uri": ...}.
        Defaults to an empty dict.

    Fix: the original signature used the mutable default ``metadata={}``,
    which is shared across all calls; use None and create a fresh dict.
    """
    metadata = {} if metadata is None else metadata
    self.store = self.create_store(backend, metadata)
    self._job_type_id = None  # lazily-registered context type id ("Job")
    self._task_type_id = None  # lazily-registered execution type id ("Task")

def update_task_state(self, task_run, state, exception=None):
    """Set the execution's state (and optional exception text), then persist it.

    :param task_run: MLMD Execution proto to update.
    :param state: new value for the "state" string property.
    :param exception: optional exception message stored alongside the state.
    """
    props = task_run.properties
    props["state"].string_value = state
    if exception is not None:
        props["exception"].string_value = exception
    self.store.put_executions([task_run])

def get_task_safe_terminate_flag(self, task_run):
    """Return the execution's safe_terminate flag, re-read from the store.

    The proto is re-fetched by name so the caller sees the persisted value
    rather than a possibly stale in-memory copy.
    """
    fresh = self.get_or_create_task(task_run.name)
    return fresh.properties["safe_terminate"].bool_value

def set_task_safe_terminate_flag(self, task_run):
    # Mark the execution as safe to terminate and persist the change to MLMD.
    task_run.properties["safe_terminate"].bool_value = True
    self.store.put_executions([task_run])

def get_or_create_task(self, taskid):
    """Fetch the "Task" execution named ``taskid``, creating it if absent.

    Newly created executions start with state "INIT" and safe_terminate
    False; the id assigned by the store is copied back onto the proto so
    later put_executions calls update the same record.
    """
    task_run = self.store.get_execution_by_type_and_name("Task", taskid)
    if task_run is None:
        task_run = metadata_store_pb2.Execution()
        task_run.type_id = self.task_type_id
        task_run.name = taskid
        task_run.properties["state"].string_value = "INIT"
        task_run.properties["safe_terminate"].bool_value = False
        # put_executions returns one id per execution; unpack the single id.
        [task_run_id] = self.store.put_executions([task_run])
        task_run.id = task_run_id
    return task_run

@classmethod
def monitor_status(cls, ):
...


def get_status_manager(model_uri: str):
uri_type = get_schema_from_uri(model_uri)
if uri_type == UriTypes.LOCAL:
return LocalFSStatusManager
else:
return LMDBStatusManager
def create_store(cls, backend, metadata):
connection_config = metadata_store_pb2.ConnectionConfig()
if backend == "sqlite":
connection_config.sqlite.filename_uri = metadata["filename_uri"]
connection_config.sqlite.connection_mode = metadata.get("connection_mode", 3)
return metadata_store.MetadataStore(connection_config)

@property
def job_type_id(self):
    """Lazily register the "Job" ContextType in MLMD and cache its type id."""
    if self._job_type_id is None:
        job_type = metadata_store_pb2.ContextType()
        job_type.name = "Job"
        job_type.properties["jobid"] = metadata_store_pb2.STRING
        # put_context_type returns the id of the (possibly pre-existing) type.
        self._job_type_id = self.store.put_context_type(job_type)
    return self._job_type_id

@property
def task_type_id(self):
    """Lazily register the "Task" ExecutionType in MLMD and cache its type id."""
    if self._task_type_id is None:
        task_type = metadata_store_pb2.ExecutionType()
        task_type.name = "Task"
        # Properties read/written elsewhere in this module: state, exception,
        # safe_terminate (see update_task_state / get_or_create_task).
        task_type.properties["state"] = metadata_store_pb2.STRING
        task_type.properties["exception"] = metadata_store_pb2.STRING
        task_type.properties["safe_terminate"] = metadata_store_pb2.BOOLEAN
        self._task_type_id = self.store.put_execution_type(task_type)
    return self._task_type_id


def get_status_manager():
    """Return the status-manager class used by the pipeline (SQLite/MLMD-backed)."""
    manager_cls = SQLiteStatusManager
    return manager_cls
2 changes: 1 addition & 1 deletion python/fate_client/pipeline/pipeline_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ log_directory:
console_display_log:

standalone:
sqlite_db:
job_directory:
output_data_dir:
output_model_dir:
output_metric_dir:
output_status_dir:
computing_engine: "standalone"
federation_engine: "standalone"
logger:
Expand Down
28 changes: 17 additions & 11 deletions python/fate_client/pipeline/scheduler/runtime_constructor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from ..conf.env_config import StandaloneConfig
from ..entity.component_structures import OutputDefinitionsSpec
from ..entity.task_structure import TaskScheduleSpec, MLMDSpec, LOGGERSpec, \
RuntimeEnvSpec, ComputingBackendSpec, FederationBackendSpec, FederationPartySpec
from ..entity.task_structure import TaskScheduleSpec, LOGGERSpec, TaskRuntimeInputSpec, TaskRuntimeOutputSpec, \
MLMDSpec, RuntimeEnvSpec, ComputingBackendSpec, FederationBackendSpec, FederationPartySpec
from ..manager.resource_manager import StandaloneResourceManager
from ..utils.id_gen import gen_computing_id, gen_federation_id, gen_execution_id

Expand Down Expand Up @@ -112,12 +112,8 @@ def construct_input_artifacts(self, upstream_inputs, runtime_constructor_dict, c
self._input_artifacts[party.role][party.party_id].update({input_key: output_artifacts})

def _construct_mlmd(self, role, party_id):
status_path = self._resource_manager.generate_output_status_uri(self._job_id, self._task_name, role, party_id)
terminate_status_path = self._resource_manager.generate_output_terminate_status_uri(
self._job_id, self._task_name, role, party_id)
metadata = {
"state_path": status_path,
"terminate_state_path": terminate_status_path
"db": self._conf.SQLITE_DB
}
return MLMDSpec(type="pipeline",
metadata=metadata)
Expand Down Expand Up @@ -160,8 +156,8 @@ def _construct_runtime_env(self, role, party_id):
mlmd=mlmd,
logger=logger,
device=self._conf.DEVICE,
distributed_computing_backend=computing_backend,
federation_backend=federation_backend
computing=computing_backend,
federation=federation_backend
)

def construct_task_schedule_spec(self):
Expand All @@ -176,12 +172,19 @@ def construct_task_schedule_spec(self):
)

input_artifact = self._input_artifacts[party.role][party.party_id]
task_input_spec = TaskRuntimeInputSpec()
if input_artifact:
party_task_spec.inputs = input_artifact
task_input_spec.artifacts = input_artifact
parameters = self._runtime_parameters.get(party.role, {}).get(party.party_id, {})
if parameters:
task_input_spec.parameters = parameters

if task_input_spec.dict(exclude_defaults=True):
party_task_spec.inputs = task_input_spec

output_artifact = self._output_artifacts[party.role][party.party_id]
if output_artifact:
party_task_spec.outputs = output_artifact
party_task_spec.outputs = TaskRuntimeOutputSpec(artifacts=output_artifact)

self._task_schedule_spec[party.role][party.party_id] = party_task_spec
conf_uri = self._resource_manager.write_out_task_conf(self._job_id,
Expand All @@ -208,3 +211,6 @@ def execution_id(self, role, party_id):
@property
def status_manager(self):
return self._resource_manager.status_manager

def log_path(self, role, party_id):
    """Return the configured base log directory for the given role/party task."""
    party_spec = self._task_schedule_spec[role][party_id]
    return party_spec.env.logger.metadata["base_path"]
Loading

0 comments on commit c94822a

Please sign in to comment.