Skip to content

Commit

Permalink
add unresolved model type
Browse files Browse the repository at this point in the history
Signed-off-by: sagewe <wbwmat@gmail.com>
  • Loading branch information
sagewe committed Dec 8, 2023
1 parent 71e5ade commit 1e47e79
Show file tree
Hide file tree
Showing 12 changed files with 77 additions and 11 deletions.
4 changes: 4 additions & 0 deletions python/fate/components/core/_cpn_reexport.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
parameter,
table_input,
table_inputs,
model_unresolved_output,
model_unresolved_outputs,
)
from .essential import Role

Expand Down Expand Up @@ -75,4 +77,6 @@ def wrapper(roles: Optional[List[Role]] = None, desc="", optional=False) -> "Typ
"model_directory_outputs",
"model_directory_output",
"model_directory_input",
"model_unresolved_output",
"model_unresolved_outputs",
]
8 changes: 6 additions & 2 deletions python/fate/components/core/_load_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,18 @@ def load_federation(federation, computing):
if isinstance(federation, (OSXFederationSpec, RollSiteFederationSpec)):
if isinstance(federation, OSXFederationSpec):
mode = FederationMode.from_str(federation.metadata.osx_config.mode)
host = federation.metadata.osx_config.host
port = federation.metadata.osx_config.port
options = dict(max_message_size=federation.metadata.osx_config.max_message_size)
else:
mode = FederationMode.STREAM
host = federation.metadata.rollsite_config.host
port = federation.metadata.rollsite_config.port
options = {}
return builder.build_osx(
computing_session=computing,
host=federation.metadata.osx_config.host,
port=federation.metadata.osx_config.port,
host=host,
port=port,
mode=mode,
options=options,
)
Expand Down
4 changes: 4 additions & 0 deletions python/fate/components/core/component_desc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
model_directory_outputs,
table_input,
table_inputs,
model_unresolved_output,
model_unresolved_outputs,
)

__all__ = [
Expand Down Expand Up @@ -69,4 +71,6 @@
"model_directory_input",
"json_metric_output",
"json_metric_outputs",
"model_unresolved_output",
"model_unresolved_outputs",
]
4 changes: 0 additions & 4 deletions python/fate/components/core/component_desc/_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,6 @@ def dump_yaml(self, stream=None):
def predict(
self, roles: List = None, provider: Optional[str] = None, version: Optional[str] = None, description=None
):

if roles is None:
roles = []

Expand All @@ -241,7 +240,6 @@ def predict(
def train(
self, roles: List = None, provider: Optional[str] = None, version: Optional[str] = None, description=None
):

if roles is None:
roles = []

Expand All @@ -250,7 +248,6 @@ def train(
def cross_validation(
self, roles: List = None, provider: Optional[str] = None, version: Optional[str] = None, description=None
):

if roles is None:
roles = []

Expand Down Expand Up @@ -326,7 +323,6 @@ def component(

def _component(name, roles, provider, version, description, is_subcomponent):
def decorator(f):

cpn_name = name or f.__name__.lower()
if isinstance(f, Component):
raise TypeError("Attempted to convert a callback into a component_desc twice.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,6 @@ def _handle_output(self, ctx, component, arg, stage, role, config):
(self.output_model, component.artifacts.model_outputs),
(self.output_metric, component.artifacts.metric_outputs),
]:

if allowed_artifacts := artifacts.get(arg):
if allowed_artifacts.is_active_for(stage, role):
apply_spec: ArtifactOutputApplySpec = config.output_artifacts.get(arg)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
model_directory_inputs,
model_directory_output,
model_directory_outputs,
model_unresolved_output,
model_unresolved_outputs,
)

__all__ = [
Expand Down Expand Up @@ -56,4 +58,6 @@
"data_unresolved_outputs",
"json_metric_output",
"json_metric_outputs",
"model_unresolved_output",
"model_unresolved_outputs",
]
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,11 @@ def data_directory_outputs(
) -> Type[Iterator[DataDirectoryWriter]]:
return _create_artifact_annotation(False, True, DataDirectoryArtifactDescribe, "data")(roles, desc, optional)


def data_unresolved_output(roles: Optional[List[Role]] = None, desc="", optional=False) -> Type[DataUnresolvedWriter]:
return _create_artifact_annotation(False, False, DataUnresolvedArtifactDescribe, "data")(roles, desc, optional)


def data_unresolved_outputs(
roles: Optional[List[Role]] = None, desc="", optional=False
) -> Type[Iterator[DataUnresolvedWriter]]:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from typing import Iterator, List, Optional, Type

from .._base_type import Role, _create_artifact_annotation
from ._directory import (
ModelDirectoryArtifactDescribe,
ModelDirectoryReader,
ModelDirectoryWriter,
)
from ._json import JsonModelArtifactDescribe, JsonModelReader, JsonModelWriter
from ._unresolved import ModelUnresolvedArtifactDescribe, ModelUnresolvedReader, ModelUnresolvedWriter
from .._base_type import Role, _create_artifact_annotation


def json_model_input(roles: Optional[List[Role]] = None, desc="", optional=False) -> Type[JsonModelReader]:
Expand Down Expand Up @@ -43,3 +44,15 @@ def model_directory_outputs(
roles: Optional[List[Role]] = None, desc="", optional=False
) -> Type[Iterator[ModelDirectoryWriter]]:
return _create_artifact_annotation(False, True, ModelDirectoryArtifactDescribe, "model")(roles, desc, optional)


def model_unresolved_output(
roles: Optional[List[Role]] = None, desc="", optional=False
) -> Type[ModelUnresolvedWriter]:
return _create_artifact_annotation(False, False, ModelUnresolvedArtifactDescribe, "model")(roles, desc, optional)


def model_unresolved_outputs(
roles: Optional[List[Role]] = None, desc="", optional=False
) -> Type[Iterator[ModelUnresolvedWriter]]:
return _create_artifact_annotation(False, True, ModelUnresolvedArtifactDescribe, "model")(roles, desc, optional)
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from fate.components.core.essential import ModelUnresolvedArtifactType
from .._base_type import (
URI,
ArtifactDescribe,
ModelOutputMetadata,
Metadata,
_ArtifactType,
_ArtifactTypeReader,
_ArtifactTypeWriter,
)


class ModelUnresolvedWriter(_ArtifactTypeWriter[ModelUnresolvedArtifactType]):
def write_metadata(self, metadata: dict, name=None, namespace=None):
self.artifact.metadata.metadata.update(metadata)
if name is not None:
self.artifact.metadata.name = name
if namespace is not None:
self.artifact.metadata.namespace = namespace


class ModelUnresolvedReader(_ArtifactTypeReader):
def get_metadata(self):
return self.artifact.metadata.metadata


class ModelUnresolvedArtifactDescribe(ArtifactDescribe[ModelUnresolvedArtifactType, ModelOutputMetadata]):
@classmethod
def get_type(cls):
return ModelUnresolvedArtifactType

def get_writer(self, config, ctx, uri: URI, type_name: str) -> ModelUnresolvedWriter:
return ModelUnresolvedWriter(ctx, _ArtifactType(uri=uri, metadata=ModelOutputMetadata(), type_name=type_name))

def get_reader(self, ctx, uri: "URI", metadata: "Metadata", type_name: str) -> ModelUnresolvedReader:
return ModelUnresolvedReader(ctx, _ArtifactType(uri=uri, metadata=metadata, type_name=type_name))
1 change: 1 addition & 0 deletions python/fate/components/core/essential/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
JsonMetricArtifactType,
JsonModelArtifactType,
ModelDirectoryArtifactType,
ModelUnresolvedArtifactType,
TableArtifactType,
)
from ._label import Label
Expand Down
6 changes: 6 additions & 0 deletions python/fate/components/core/essential/_artifact_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,9 @@ class JsonMetricArtifactType(ArtifactType):
type_name = "json_metric"
path_type = "file"
uri_types = ["file"]


class ModelUnresolvedArtifactType(ArtifactType):
type_name = "model_unresolved"
path_type = "unresolved"
uri_types = ["unresolved"]
3 changes: 0 additions & 3 deletions python/fate/components/core/spec/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ class MLModelPartiesSpec(pydantic.BaseModel):


class MLModelFederatedSpec(pydantic.BaseModel):

task_id: str
parties: MLModelPartiesSpec
component: MLModelComponentSpec
Expand All @@ -46,14 +45,12 @@ class MLModelModelSpec(pydantic.BaseModel):


class MLModelPartySpec(pydantic.BaseModel):

party_task_id: str
role: str
partyid: str
models: List[MLModelModelSpec]


class MLModelSpec(pydantic.BaseModel):

federated: MLModelFederatedSpec
party: MLModelPartySpec

0 comments on commit 1e47e79

Please sign in to comment.