-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
update pipeline: support get_output_model/metrics
Signed-off-by: mgqa34 <mgq3374541@163.com>
- Loading branch information
Showing
19 changed files
with
482 additions
and
96 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,9 @@ | ||
from .dag import DAG | ||
from .task_info import FateFlowTaskInfo, StandaloneTaskInfo | ||
|
||
|
||
__all__ = [ | ||
"DAG" | ||
"DAG", | ||
"FateFlowTaskInfo", | ||
"StandaloneTaskInfo" | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
# | ||
# Copyright 2019 The FATE Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from datetime import datetime | ||
from typing import List | ||
|
||
import pydantic | ||
|
||
|
||
class MLModelComponentSpec(pydantic.BaseModel): | ||
name: str | ||
provider: str | ||
version: str | ||
metadata: dict | ||
|
||
|
||
class MLModelPartiesSpec(pydantic.BaseModel): | ||
guest: List[str] | ||
host: List[str] | ||
arbiter: List[str] | ||
|
||
|
||
class MLModelFederatedSpec(pydantic.BaseModel): | ||
task_id: str | ||
parties: MLModelPartiesSpec | ||
component: MLModelComponentSpec | ||
|
||
|
||
class MLModelModelSpec(pydantic.BaseModel): | ||
name: str | ||
created_time: datetime | ||
file_format: str | ||
metadata: dict | ||
|
||
|
||
class MLModelPartySpec(pydantic.BaseModel): | ||
party_task_id: str | ||
role: str | ||
partyid: str | ||
models: List[MLModelModelSpec] | ||
|
||
|
||
class MLModelSpec(pydantic.BaseModel): | ||
federated: MLModelFederatedSpec | ||
party: MLModelPartySpec |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import abc | ||
import typing | ||
from .model_info import StandaloneModelInfo, FateFlowModelInfo | ||
from ..utils.fateflow.fate_flow_job_invoker import FATEFlowJobInvoker | ||
|
||
|
||
class TaskInfo(object): | ||
def __init__(self, task_name: str, model_info: typing.Union[StandaloneModelInfo, FateFlowModelInfo]): | ||
self._model_info = model_info | ||
self._task_name = task_name | ||
|
||
@abc.abstractmethod | ||
def get_output_data(self, *args, **kwargs): | ||
... | ||
|
||
@abc.abstractmethod | ||
def get_output_model(self, *args, **kwargs): | ||
... | ||
|
||
@abc.abstractmethod | ||
def get_output_metrics(self, *args, **kwargs): | ||
... | ||
|
||
|
||
class StandaloneTaskInfo(TaskInfo): | ||
def get_output_data(self): | ||
... | ||
|
||
def get_output_model(self, role=None, party_id=None): | ||
party_id = party_id if role else self._model_info.local_party_id | ||
role = role if role else self._model_info.local_role | ||
return self._model_info.task_info[self._task_name].get_output_model(role, party_id) | ||
|
||
def get_output_metrics(self, role=None, party_id=None): | ||
party_id = party_id if role else self._model_info.local_party_id | ||
role = role if role else self._model_info.local_role | ||
return self._model_info.task_info[self._task_name].get_output_metrics(role, party_id) | ||
|
||
|
||
class FateFlowTaskInfo(TaskInfo): | ||
def get_output_model(self): | ||
return FATEFlowJobInvoker().get_output_model(job_id=self._model_info.job_id, | ||
role=self._model_info.local_role, | ||
party_id=self._model_info.local_party_id, | ||
task_name=self._task_name) | ||
|
||
def get_output_data(self, limits=None, ): | ||
... | ||
|
||
def get_output_metrics(self): | ||
return FATEFlowJobInvoker().get_output_metrics(job_id=self._model_info.job_id, | ||
role=self._model_info.local_role, | ||
party_id=self._model_info.local_party_id, | ||
task_name=self._task_name) | ||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,30 +1,52 @@ | ||
import json | ||
import os | ||
import tarfile | ||
import tempfile | ||
import yaml | ||
from ..utils.uri_tools import parse_uri, replace_uri_path, get_schema_from_uri | ||
from ..utils.file_utils import construct_local_dir | ||
from ..conf.types import UriTypes | ||
from ..entity.model_structure import MLModelSpec | ||
|
||
|
||
class LocalFSModelManager(object): | ||
@classmethod | ||
def generate_output_model_uri(cls, output_dir_uri: str, job_id: str, task_name: str, | ||
role: str, party_id: str): | ||
model_id = "_".join([job_id, task_name, role, str(party_id)]) | ||
model_version = "v0" | ||
model_version = "0" | ||
uri_obj = parse_uri(output_dir_uri) | ||
local_path = construct_local_dir(uri_obj.path, *[model_id, model_version]) | ||
uri_obj = replace_uri_path(uri_obj, str(local_path)) | ||
return uri_obj.geturl() | ||
|
||
|
||
class LMDBModelManager(object): | ||
@classmethod | ||
def generate_output_model_uri(cls, uri_obj, session_id: str, role: str, party_id: str, namespace: str, name: str): | ||
... | ||
def get_output_model(cls, output_dir_uri): | ||
uri_obj = parse_uri(output_dir_uri) | ||
models = dict() | ||
with tempfile.TemporaryDirectory() as temp_dir: | ||
tar = tarfile.open(uri_obj.path, "r:") | ||
tar.extractall(path=temp_dir) | ||
tar.close() | ||
for file_name in os.listdir(temp_dir): | ||
if file_name.endswith("FMLModel.yaml"): | ||
with open(os.path.join(temp_dir, file_name), "r") as fp: | ||
model_meta = yaml.safe_load(fp) | ||
model_spec = MLModelSpec.parse_obj(model_meta) | ||
|
||
for model in model_spec.party.models: | ||
file_format = model.file_format | ||
model_name = model.name | ||
|
||
if file_format == "json": | ||
with open(os.path.join(temp_dir, model_name), "r") as fp: | ||
models[model_name] = json.loads(fp.read()) | ||
|
||
return models | ||
|
||
|
||
def get_model_manager(model_uri: str): | ||
uri_type = get_schema_from_uri(model_uri) | ||
if uri_type == UriTypes.LOCAL: | ||
return LocalFSModelManager | ||
else: | ||
return LMDBModelManager | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.