Skip to content

Commit

Permalink
split output metadata according to type
Browse files Browse the repository at this point in the history
Signed-off-by: weiwee <wbwmat@gmail.com>
  • Loading branch information
sagewe committed Jun 21, 2023
1 parent e77b143 commit 9ba7ab9
Show file tree
Hide file tree
Showing 9 changed files with 106 additions and 30 deletions.
35 changes: 24 additions & 11 deletions python/fate/components/core/component_desc/artifacts/_base_type.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
import typing
from typing import Generic, List, TypeVar
from typing import Generic, List, TypeVar, Union

from fate.arch import URI
from fate.components.core.essential import Role, Stage
from fate.components.core.spec.artifact import Metadata
from fate.components.core.spec.artifact import (
DataOutputMetadata,
Metadata,
MetricOutputMetadata,
ModelOutputMetadata,
)
from fate.components.core.spec.component import ArtifactSpec
from fate.components.core.spec.task import (
ArtifactInputApplySpec,
ArtifactOutputApplySpec,
)

if typing.TYPE_CHECKING:
from fate.arch import URI, Context
from fate.arch import Context


class _ArtifactTypeWriter:
Expand Down Expand Up @@ -38,7 +44,9 @@ def __repr__(self):


class _ArtifactType:
def __init__(self, uri: "URI", metadata: Metadata) -> None:
def __init__(
self, uri: "URI", metadata: Union[Metadata, DataOutputMetadata, ModelOutputMetadata, MetricOutputMetadata]
) -> None:
self.uri = uri
self.metadata = metadata

Expand Down Expand Up @@ -112,7 +120,7 @@ def dict(self):
def get_type(self) -> AT:
raise NotImplementedError()

def get_writer(self, ctx: "Context", artifact_type: _ArtifactType) -> _ArtifactTypeWriter:
def get_writer(self, ctx: "Context", uri: "URI") -> _ArtifactTypeWriter:
raise NotImplementedError()

def get_reader(self, ctx: "Context", artifact_type: _ArtifactType) -> _ArtifactTypeReader:
Expand Down Expand Up @@ -143,8 +151,9 @@ def load_as_output_slot(self, ctx: "Context", apply_config):
if self.multi:
return self._generator_recorder(ctx, output_artifact_iter)
else:
artifact = next(output_artifact_iter)
return artifact, self.get_writer(ctx, artifact)
uri = next(output_artifact_iter)
writer = self.get_writer(ctx, uri)
return writer.artifact, writer
except Exception as e:
raise ComponentArtifactApplyError(f"load as output artifact({self}) slot error: {e}") from e
if not self.optional:
Expand All @@ -164,16 +173,20 @@ def load_output(self, spec: ArtifactOutputApplySpec):
if i != 0:
raise ValueError(f"index should be 0, but got {i}")
uri = URI.from_string(spec.uri)
yield _ArtifactType(uri, Metadata())
yield uri
i += 1

def _generator_recorder(self, ctx: "Context", generator: typing.Generator[_ArtifactType, None, None]):
def create_metadata(self) -> Metadata:
raise NotImplementedError()

def _generator_recorder(self, ctx: "Context", generator: typing.Generator["URI", None, None]):
recorder = []

def _generator():
for item in generator:
recorder.append(item)
yield self.get_writer(ctx, item)
writer = self.get_writer(ctx, item)
recorder.append(writer.artifact)
yield writer

return recorder, _generator()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
)

if typing.TYPE_CHECKING:
from fate.arch import URI
from fate.arch.dataframe import DataFrame

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -39,7 +40,7 @@ def write(self, df: "DataFrame", name=None, namespace=None):
samples = df.data_overview()
from fate.components.core.spec.artifact import DataOverview

self.artifact.metadata.overview = DataOverview(count=count, samples=samples)
self.artifact.metadata.data_overview = DataOverview(count=count, samples=samples)

logger.debug(f"write dataframe to artifact: {self.artifact}")

Expand Down Expand Up @@ -77,8 +78,10 @@ class DataframeArtifactDescribe(ArtifactDescribe):
def get_type(self):
return DataframeArtifactType

def get_writer(self, ctx, artifact_type: _ArtifactType) -> _ArtifactTypeWriter:
return DataframeWriter(ctx, artifact_type)
def get_writer(self, ctx, uri: "URI") -> _ArtifactTypeWriter:
from fate.components.core.spec.artifact import DataOutputMetadata

return DataframeWriter(ctx, _ArtifactType(uri=uri, metadata=DataOutputMetadata()))

def get_reader(self, ctx, artifact_type: _ArtifactType) -> _ArtifactTypeReader:
return DataframeReader(ctx, artifact_type)
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
from fate.components.core.essential import DataDirectoryArtifactType

from .._base_type import (
URI,
ArtifactDescribe,
DataOutputMetadata,
_ArtifactType,
_ArtifactTypeReader,
_ArtifactTypeWriter,
Expand Down Expand Up @@ -37,8 +39,8 @@ class DataDirectoryArtifactDescribe(ArtifactDescribe):
def get_type(self):
return DataDirectoryArtifactType

def get_writer(self, ctx, artifact_type: _ArtifactType) -> DataDirectoryWriter:
return DataDirectoryWriter(ctx, artifact_type)
def get_writer(self, ctx, uri: URI) -> DataDirectoryWriter:
return DataDirectoryWriter(ctx, _ArtifactType(uri=uri, metadata=DataOutputMetadata()))

def get_reader(self, ctx, artifact_type: _ArtifactType) -> DataDirectoryReader:
return DataDirectoryReader(ctx, artifact_type)
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
from fate.components.core.essential import TableArtifactType

from .._base_type import (
URI,
ArtifactDescribe,
DataOutputMetadata,
_ArtifactType,
_ArtifactTypeReader,
_ArtifactTypeWriter,
Expand Down Expand Up @@ -37,8 +39,8 @@ class TableArtifactDescribe(ArtifactDescribe):
def get_type(self):
return TableArtifactType

def get_writer(self, ctx, artifact_type: _ArtifactType) -> TableWriter:
return TableWriter(ctx, artifact_type)
def get_writer(self, ctx, uri: URI) -> TableWriter:
return TableWriter(ctx, _ArtifactType(uri, DataOutputMetadata()))

def get_reader(self, ctx: "Context", artifact_type: _ArtifactType) -> TableReader:
return TableReader(ctx, artifact_type)
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
from fate.components.core.essential import JsonMetricArtifactType

from .._base_type import (
URI,
ArtifactDescribe,
MetricOutputMetadata,
_ArtifactType,
_ArtifactTypeReader,
_ArtifactTypeWriter,
Expand Down Expand Up @@ -40,8 +42,8 @@ class JsonMetricArtifactDescribe(ArtifactDescribe[_ArtifactType]):
def get_type(self):
return JsonMetricArtifactType

def get_writer(self, ctx: "Context", artifact_type: _ArtifactType) -> JsonMetricWriter:
return JsonMetricWriter(ctx, artifact_type)
def get_writer(self, ctx: "Context", uri: URI) -> JsonMetricWriter:
return JsonMetricWriter(ctx, _ArtifactType(uri=uri, metadata=MetricOutputMetadata()))

def get_reader(self, ctx: "Context", artifact_type: _ArtifactType) -> JsonMetricReader:
return JsonMetricReader(ctx, artifact_type)
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from fate.components.core.essential import ModelDirectoryArtifactType

from .._base_type import (
URI,
ArtifactDescribe,
ModelOutputMetadata,
_ArtifactType,
_ArtifactTypeReader,
_ArtifactTypeWriter,
Expand Down Expand Up @@ -37,8 +39,8 @@ class ModelDirectoryArtifactDescribe(ArtifactDescribe[_ArtifactType]):
def get_type(self):
return ModelDirectoryArtifactType

def get_writer(self, ctx: "Context", artifact_type: _ArtifactType) -> ModelDirectoryWriter:
return ModelDirectoryWriter(ctx, artifact_type)
def get_writer(self, ctx: "Context", uri: URI) -> ModelDirectoryWriter:
return ModelDirectoryWriter(ctx, _ArtifactType(uri=uri, metadata=ModelOutputMetadata()))

def get_reader(self, ctx: "Context", artifact_type: _ArtifactType) -> ModelDirectoryReader:
return ModelDirectoryReader(ctx, artifact_type)
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from fate.components.core.essential import JsonModelArtifactType

from .._base_type import (
URI,
ArtifactDescribe,
ModelOutputMetadata,
_ArtifactType,
_ArtifactTypeReader,
_ArtifactTypeWriter,
Expand Down Expand Up @@ -38,8 +40,8 @@ class JsonModelArtifactDescribe(ArtifactDescribe[_ArtifactType]):
def get_type(self):
return JsonModelArtifactType

def get_writer(self, ctx: "Context", artifact_type: _ArtifactType) -> JsonModelWriter:
return JsonModelWriter(ctx, artifact_type)
def get_writer(self, ctx: "Context", uri: URI) -> JsonModelWriter:
return JsonModelWriter(ctx, _ArtifactType(uri=uri, metadata=ModelOutputMetadata()))

def get_reader(self, ctx: "Context", artifact_type: _ArtifactType) -> JsonModelReader:
return JsonModelReader(ctx, artifact_type)
60 changes: 55 additions & 5 deletions python/fate/components/core/spec/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import re
import typing
from typing import List, Optional

if typing.TYPE_CHECKING:
from fate.arch import URI

import pydantic

from .model import (
MLModelComponentSpec,
MLModelFederatedSpec,
MLModelModelSpec,
MLModelPartiesSpec,
MLModelPartySpec,
MLModelSpec,
)

# see https://www.rfc-editor.org/rfc/rfc3986#appendix-B
# scheme = $2
# authority = $4
Expand Down Expand Up @@ -48,7 +54,51 @@ class Metadata(pydantic.BaseModel):
metadata: dict = pydantic.Field(default_factory=dict)
name: Optional[str] = None
namespace: Optional[str] = None
overview: Optional[DataOverview] = None
source: Optional[ArtifactSource] = None


def _default_model_overview() -> MLModelSpec:
    """Build the placeholder ``MLModelSpec`` used when no overview is supplied.

    All string fields are empty, all party lists are empty, and the single
    placeholder model entry is stamped with the ISO-8601 time of the call.
    """
    return MLModelSpec(
        federated=MLModelFederatedSpec(
            task_id="",
            parties=MLModelPartiesSpec(guest=[], host=[], arbiter=[]),
            component=MLModelComponentSpec(name="", provider="", version="", metadata={}),
        ),
        party=MLModelPartySpec(
            party_task_id="",
            role="",
            partyid="",
            models=[
                MLModelModelSpec(
                    name="",
                    created_time=datetime.datetime.now().isoformat(),
                    file_format="",
                    metadata={},
                )
            ],
        ),
    )


class ModelOutputMetadata(pydantic.BaseModel):
    """Metadata attached to a model output artifact.

    Unknown fields are rejected (``extra = "forbid"``).
    """

    # Free-form key/value metadata; a fresh dict per instance.
    metadata: dict = pydantic.Field(default_factory=dict)
    name: Optional[str] = None
    namespace: Optional[str] = None
    source: Optional[ArtifactSource] = None
    # Built per instance rather than once in the class body: a class-level
    # default would evaluate ``datetime.now()`` at import time, so every
    # instance would carry the same stale ``created_time``.
    model_overview: MLModelSpec = pydantic.Field(default_factory=_default_model_overview)

    class Config:
        extra = "forbid"


class DataOutputMetadata(pydantic.BaseModel):
    """Metadata attached to a data output artifact.

    Every field has a default, so the model can be created with no arguments.
    Unknown fields are rejected (``extra = "forbid"``).
    """

    # Free-form key/value metadata; a fresh dict per instance.
    metadata: dict = pydantic.Field(default_factory=dict)
    name: typing.Optional[str] = None
    namespace: typing.Optional[str] = None
    source: typing.Optional[ArtifactSource] = None
    # Summary of the written data (a ``DataOverview``); unset until a writer fills it in.
    data_overview: typing.Optional[DataOverview] = None

    class Config:
        extra = "forbid"


class MetricOutputMetadata(pydantic.BaseModel):
metadata: dict = pydantic.Field(default_factory=dict)
name: Optional[str] = None
namespace: Optional[str] = None
source: Optional[ArtifactSource] = None

class Config:
Expand Down
2 changes: 1 addition & 1 deletion python/fate/components/core/spec/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class MLModelFederatedSpec(pydantic.BaseModel):

class MLModelModelSpec(pydantic.BaseModel):
name: str
created_time: datetime
created_time: str
file_format: str
metadata: dict

Expand Down

0 comments on commit 9ba7ab9

Please sign in to comment.