Skip to content

Commit

Permalink
update pipeline: support get_output_model/metrics
Browse files Browse the repository at this point in the history
Signed-off-by: mgqa34 <mgq3374541@163.com>
  • Loading branch information
mgqa34 committed Jan 11, 2023
1 parent f58cb80 commit 3f8f0c4
Show file tree
Hide file tree
Showing 19 changed files with 482 additions and 96 deletions.
4 changes: 2 additions & 2 deletions python/fate_client/pipeline/components/component_base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import copy
from ..conf.types import SupportRole, PlaceHolder, ArtifactSourceType
from ..conf.job_configuration import TaskConf
from python.fate_client.pipeline.utils.standalone.id_gen import get_uuid
from pipeline.entity.component_structures import load_component_spec
from ..utils.standalone.id_gen import get_uuid
from ..entity.component_structures import load_component_spec
from ..interface import ArtifactChannel
from ..entity.dag_structures import RuntimeTaskOutputChannelSpec, ModelWarehouseChannelSpec

Expand Down
5 changes: 4 additions & 1 deletion python/fate_client/pipeline/entity/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from .dag import DAG
from .task_info import FateFlowTaskInfo, StandaloneTaskInfo


__all__ = [
"DAG"
"DAG",
"FateFlowTaskInfo",
"StandaloneTaskInfo"
]
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from typing import Union, Dict
from ..scheduler.runtime_constructor import RuntimeConstructor
from typing import Dict


class StandaloneModelInfo(object):
def __init__(self, job_id: str, task_info: Dict[str, RuntimeConstructor],
def __init__(self, job_id: str, task_info, local_role: str, local_party_id: str,
model_id: str = None, model_version: int = None):
self._job_id = job_id
self._task_info = task_info
self._model_id = model_id
self._model_version = model_version
self._local_role = local_role
self._local_party_id = local_party_id

@property
def job_id(self):
Expand All @@ -26,6 +27,14 @@ def model_id(self):
def model_version(self):
return self._model_version

@property
def local_role(self):
return self._local_role

@property
def local_party_id(self):
return self._local_party_id


class FateFlowModelInfo(object):
def __init__(self, job_id: str, local_role: str, local_party_id: str,
Expand Down
56 changes: 56 additions & 0 deletions python/fate_client/pipeline/entity/model_structure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from datetime import datetime
from typing import List

import pydantic


class MLModelComponentSpec(pydantic.BaseModel):
name: str
provider: str
version: str
metadata: dict


class MLModelPartiesSpec(pydantic.BaseModel):
guest: List[str]
host: List[str]
arbiter: List[str]


class MLModelFederatedSpec(pydantic.BaseModel):
task_id: str
parties: MLModelPartiesSpec
component: MLModelComponentSpec


class MLModelModelSpec(pydantic.BaseModel):
name: str
created_time: datetime
file_format: str
metadata: dict


class MLModelPartySpec(pydantic.BaseModel):
party_task_id: str
role: str
partyid: str
models: List[MLModelModelSpec]


class MLModelSpec(pydantic.BaseModel):
federated: MLModelFederatedSpec
party: MLModelPartySpec
57 changes: 57 additions & 0 deletions python/fate_client/pipeline/entity/task_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import abc
import typing
from .model_info import StandaloneModelInfo, FateFlowModelInfo
from ..utils.fateflow.fate_flow_job_invoker import FATEFlowJobInvoker


class TaskInfo(object):
def __init__(self, task_name: str, model_info: typing.Union[StandaloneModelInfo, FateFlowModelInfo]):
self._model_info = model_info
self._task_name = task_name

@abc.abstractmethod
def get_output_data(self, *args, **kwargs):
...

@abc.abstractmethod
def get_output_model(self, *args, **kwargs):
...

@abc.abstractmethod
def get_output_metrics(self, *args, **kwargs):
...


class StandaloneTaskInfo(TaskInfo):
def get_output_data(self):
...

def get_output_model(self, role=None, party_id=None):
party_id = party_id if role else self._model_info.local_party_id
role = role if role else self._model_info.local_role
return self._model_info.task_info[self._task_name].get_output_model(role, party_id)

def get_output_metrics(self, role=None, party_id=None):
party_id = party_id if role else self._model_info.local_party_id
role = role if role else self._model_info.local_role
return self._model_info.task_info[self._task_name].get_output_metrics(role, party_id)


class FateFlowTaskInfo(TaskInfo):
def get_output_model(self):
return FATEFlowJobInvoker().get_output_model(job_id=self._model_info.job_id,
role=self._model_info.local_role,
party_id=self._model_info.local_party_id,
task_name=self._task_name)

def get_output_data(self, limits=None, ):
...

def get_output_metrics(self):
return FATEFlowJobInvoker().get_output_metrics(job_id=self._model_info.job_id,
role=self._model_info.local_role,
party_id=self._model_info.local_party_id,
task_name=self._task_name)



26 changes: 22 additions & 4 deletions python/fate_client/pipeline/executor/task_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ..scheduler.dag_parser import DagParser
from ..scheduler.runtime_constructor import RuntimeConstructor
from ..utils.fateflow.fate_flow_job_invoker import FATEFlowJobInvoker
from .model_info import StandaloneModelInfo, FateFlowModelInfo
from python.fate_client.pipeline.entity.model_info import StandaloneModelInfo, FateFlowModelInfo


class StandaloneExecutor(object):
Expand All @@ -23,9 +23,12 @@ def fit(self, dag_schema: DAGSchema, component_specs: Dict[str, ComponentSpec],
self._dag_parser.parse_dag(dag_schema, component_specs)
self._run()

local_party_id = self.get_site_party_id(dag_schema, local_role, local_party_id)
return StandaloneModelInfo(
job_id=self._job_id,
task_info=self._runtime_constructor_dict,
local_role=local_role,
local_party_id=local_party_id,
model_id=self._job_id,
model_version=0
)
Expand All @@ -38,7 +41,9 @@ def predict(self,
self._run(fit_model_info)
return StandaloneModelInfo(
job_id=self._job_id,
task_info=self._runtime_constructor_dict
task_info=self._runtime_constructor_dict,
local_role=fit_model_info.local_role,
local_party_id=fit_model_info.local_party_id
)

def _run(self, fit_model_info: StandaloneModelInfo = None):
Expand All @@ -62,12 +67,12 @@ def _run(self, fit_model_info: StandaloneModelInfo = None):
job_id=self._job_id,
task_name=task_name,
component_ref=task_node.component_ref,
component_spec=task_node.component_spec,
stage=stage,
runtime_parameters=runtime_parameters,
log_dir=log_dir)
runtime_constructor.construct_input_artifacts(upstream_inputs,
runtime_constructor_dict,
component_spec,
fit_model_info)
runtime_constructor.construct_outputs()
# runtime_constructor.construct_output_artifacts(output_definitions)
Expand Down Expand Up @@ -103,6 +108,18 @@ def _exec_task(task_type, task_name, runtime_constructor):

return ret_msg

@staticmethod
def get_site_party_id(dag_schema, role, party_id):
if party_id:
return party_id

if party_id is None:
for party in dag_schema.dag.parties:
if role == party.role:
return party.party_id[0]

raise ValueError(f"Can not retrieval site's party_id from site's role {role}")


class FateFlowExecutor(object):
def __init__(self):
Expand Down Expand Up @@ -163,7 +180,8 @@ def get_site_party_id(flow_job_invoker, dag_schema, role, party_id):

raise ValueError(f"Can not retrieval site's party_id from site's role {role}")

def upload(self, file: str, head: int,
@staticmethod
def upload(file: str, head: int,
namespace: str, name: str, meta: dict,
partitions=4, storage_engine=None, **kwargs):
flow_job_invoker = FATEFlowJobInvoker()
Expand Down
12 changes: 5 additions & 7 deletions python/fate_client/pipeline/manager/metric_manager.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from typing import Union
from ..utils.uri_tools import parse_uri, replace_uri_path, get_schema_from_uri
from ..utils.file_utils import construct_local_dir
Expand All @@ -14,17 +15,14 @@ def generate_output_metric_uri(cls, output_dir_uri: str, job_id: str, task_name:
uri_obj = replace_uri_path(uri_obj, str(local_path))
return uri_obj.geturl()


class LMDBMetricManager(object):
@classmethod
def generate_output_metric_uri(cls, output_dir_uri: str, job_id: str, task_name: str,
role: str, party_id: Union[str, int]):
...
def get_output_metrics(cls, uri):
uri_obj = parse_uri(uri)
with open(uri_obj.path, "r") as fin:
return json.loads(fin.read())


def get_metric_manager(metric_uri: str):
uri_type = get_schema_from_uri(metric_uri)
if uri_type == UriTypes.LOCAL:
return LocalFSMetricManager
else:
return LMDBMetricManager
36 changes: 29 additions & 7 deletions python/fate_client/pipeline/manager/model_manager.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,52 @@
import json
import os
import tarfile
import tempfile
import yaml
from ..utils.uri_tools import parse_uri, replace_uri_path, get_schema_from_uri
from ..utils.file_utils import construct_local_dir
from ..conf.types import UriTypes
from ..entity.model_structure import MLModelSpec


class LocalFSModelManager(object):
@classmethod
def generate_output_model_uri(cls, output_dir_uri: str, job_id: str, task_name: str,
role: str, party_id: str):
model_id = "_".join([job_id, task_name, role, str(party_id)])
model_version = "v0"
model_version = "0"
uri_obj = parse_uri(output_dir_uri)
local_path = construct_local_dir(uri_obj.path, *[model_id, model_version])
uri_obj = replace_uri_path(uri_obj, str(local_path))
return uri_obj.geturl()


class LMDBModelManager(object):
@classmethod
def generate_output_model_uri(cls, uri_obj, session_id: str, role: str, party_id: str, namespace: str, name: str):
...
def get_output_model(cls, output_dir_uri):
uri_obj = parse_uri(output_dir_uri)
models = dict()
with tempfile.TemporaryDirectory() as temp_dir:
tar = tarfile.open(uri_obj.path, "r:")
tar.extractall(path=temp_dir)
tar.close()
for file_name in os.listdir(temp_dir):
if file_name.endswith("FMLModel.yaml"):
with open(os.path.join(temp_dir, file_name), "r") as fp:
model_meta = yaml.safe_load(fp)
model_spec = MLModelSpec.parse_obj(model_meta)

for model in model_spec.party.models:
file_format = model.file_format
model_name = model.name

if file_format == "json":
with open(os.path.join(temp_dir, model_name), "r") as fp:
models[model_name] = json.loads(fp.read())

return models


def get_model_manager(model_uri: str):
uri_type = get_schema_from_uri(model_uri)
if uri_type == UriTypes.LOCAL:
return LocalFSModelManager
else:
return LMDBModelManager

8 changes: 7 additions & 1 deletion python/fate_client/pipeline/manager/resource_manager.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from python.fate_client.pipeline.utils.standalone.id_gen import get_uuid
from ..utils.standalone.id_gen import get_uuid
from ..utils.file_utils import construct_local_dir
from ..conf.env_config import StandaloneConfig
from ..entity.task_structure import OutputArtifact
Expand Down Expand Up @@ -93,6 +93,12 @@ def generate_output_terminate_status_uri(self, job_id, task_name, role, party_id
role,
party_id)

def get_output_model(self, uri):
return self._model_manager.get_output_model(uri)

def get_output_metrics(self, uri):
return self._metric_manager.get_output_metrics(uri)

@staticmethod
def generate_log_uri(log_dir_prefix, role, party_id):
return str(construct_local_dir(log_dir_prefix, *[role, str(party_id)]))
Expand Down
Loading

0 comments on commit 3f8f0c4

Please sign in to comment.