Skip to content

Commit

Permalink
fix union artifact desc
Browse files Browse the repository at this point in the history
Signed-off-by: weiwee <wbwmat@gmail.com>
  • Loading branch information
sagewe committed Jun 25, 2023
1 parent 31ff940 commit 6140e22
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 42 deletions.
2 changes: 2 additions & 0 deletions python/fate/components/components/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
def run(
ctx: "Context",
role: Role,
parameter: cpn.parameter(type=str, default="default", desc="parameter"),
mix_input: cpn.dataframe_input(roles=[GUEST, HOST]) | cpn.data_directory_input(),
dataframe_inputs: cpn.dataframe_inputs(roles=[GUEST, HOST]),
dataframe_input: cpn.dataframe_input(roles=[GUEST, HOST]),
dataset_inputs: cpn.data_directory_inputs(roles=[GUEST, HOST]),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def is_active_for(self, stage: "Stage", role: "Role"):

def get_correct_arti(self, apply_spec) -> T:
for t in self.types:
if t.is_correct_arti(apply_spec):
if apply_spec.type_name is None or t.get_type().type_name == apply_spec.type_name:
return t(
name=self.name,
roles=self.roles,
Expand Down Expand Up @@ -86,6 +86,9 @@ def merge(self, a: "AllowArtifactDescribes"):
is_multi=self.is_multi,
)

def __str__(self):
return f"AllowArtifactDescribes(name={self.name}, types={self.types}, roles={self.roles}, stages={self.stages}, desc={self.desc}, optional={self.optional}, is_multi={self.is_multi})"


class ArtifactDescribeAnnotation:
def __init__(
Expand Down
73 changes: 37 additions & 36 deletions python/fate/components/core/component_desc/_component_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,26 +11,11 @@

from ..spec.artifact import ArtifactInputApplySpec, ArtifactOutputApplySpec
from ..spec.task import TaskConfigSpec
from .artifacts._base_type import (
AT,
ArtifactDescribe,
M,
_ArtifactType,
_ArtifactTypeReader,
)
from .artifacts._base_type import AT, ArtifactDescribe, M

logger = logging.getLogger(__name__)


class ComponentInputDataApplied:
def __init__(
self, artifact_desc: "ArtifactDescribe", artifact_type: "_ArtifactType", reader: "_ArtifactTypeReader"
):
self.artifact_desc = artifact_desc
self.artifact_type = artifact_type
self.reader = reader


class ComponentExecutionIO:
def __init__(self, ctx: "Context", component: Component, role: Role, stage: Stage, config):
self.parameter_artifacts_desc = {}
Expand Down Expand Up @@ -67,41 +52,46 @@ def _handle_input(self, ctx, component, arg, stage, role, config):
data=component.artifacts.data_inputs,
model=component.artifacts.model_inputs,
).items():
if arti := artifacts.get(arg):
if arti.is_active_for(stage, role):
if allow_artifacts := artifacts.get(arg):
if allow_artifacts.is_active_for(stage, role):
apply_spec: Union[
ArtifactInputApplySpec, List[ArtifactInputApplySpec]
] = config.input_artifacts.get(arg)
if apply_spec is not None:
arti = arti.get_correct_arti(apply_spec)
try:
if arti.multi:
if allow_artifacts.is_multi:
readers = []
for c in apply_spec:
uri = URI.from_string(c.uri)
arti = allow_artifacts.get_correct_arti(c)
readers.append(arti.get_reader(ctx, uri, c.metadata))
self.input_artifacts[input_type][arg] = _ArtifactsType([r.artifact for r in readers])
self.input_artifacts_reader[input_type][arg] = readers
else:
uri = URI.from_string(apply_spec.uri)
arti = allow_artifacts.get_correct_arti(apply_spec)
reader = arti.get_reader(ctx, uri, apply_spec.metadata)
self.input_artifacts[input_type][arg] = reader.artifact
self.input_artifacts_reader[input_type][arg] = reader
except Exception as e:
raise ComponentArtifactApplyError(f"load as input artifact({arti}) error: {e}") from e
elif arti.optional:
raise ComponentArtifactApplyError(
f"load as input artifact({allow_artifacts}) error: {e}"
) from e
elif allow_artifacts.optional:
self.input_artifacts_reader[input_type][arg] = None
self.input_artifacts[input_type][arg] = None
else:
raise ComponentArtifactApplyError(
f"load as input artifact({arti}) error: apply_config is None but not optional"
f"load as input artifact({allow_artifacts}) error: `{arg}` is not optional but None got"
)
logging.debug(
f"apply {input_type} artifact `{arti.name}`: {apply_spec} -> {self.input_artifacts_reader[input_type][arg]}"
logger.debug(
f"apply {input_type} artifact `{allow_artifacts.name}`: {apply_spec} -> {self.input_artifacts_reader[input_type][arg]}"
)
return True
else:
logging.debug(f"skip {input_type} artifact `{arti.name}` for stage `{stage}` and role `{role}`")
logger.debug(
f"skip {input_type} artifact `{allow_artifacts.name}` for stage `{stage}` and role `{role}`"
)
return False

def _handle_output(self, ctx, component, arg, stage, role, config):
Expand All @@ -112,17 +102,17 @@ def _handle_output(self, ctx, component, arg, stage, role, config):
model=component.artifacts.model_outputs,
metric=component.artifacts.metric_outputs,
).items():
if arti := artifacts.get(arg):
if arti.is_active_for(stage, role):
if allowed_artifacts := artifacts.get(arg):
if allowed_artifacts.is_active_for(stage, role):
apply_spec: ArtifactOutputApplySpec = config.output_artifacts.get(arg)
if apply_spec is not None:
arti = arti.get_correct_arti(apply_spec)
try:
if arti.multi:
if allowed_artifacts.is_multi:
if not apply_spec.is_template():
raise ComponentArtifactApplyError(
"template uri required for multiple output artifact"
)
arti = allowed_artifacts.get_correct_arti(apply_spec)
writers = WriterGenerator(ctx, arti, apply_spec)
self.output_artifacts[output_type][arg] = writers.recorder
self.output_artifacts_writer[output_type][arg] = writers
Expand All @@ -132,24 +122,29 @@ def _handle_output(self, ctx, component, arg, stage, role, config):
raise ComponentArtifactApplyError(
"template uri is not supported for non-multiple output artifact"
)
arti = allowed_artifacts.get_correct_arti(apply_spec)
writer = arti.get_writer(ctx, URI.from_string(apply_spec.uri))
self.output_artifacts[output_type][arg] = writer.artifact
self.output_artifacts_writer[output_type][arg] = writer
except Exception as e:
raise ComponentArtifactApplyError(f"load as output artifact({arti}) error: {e}") from e
elif arti.optional:
raise ComponentArtifactApplyError(
f"load as output artifact({allowed_artifacts}) error: {e}"
) from e
elif allowed_artifacts.optional:
self.output_artifacts_writer[output_type][arg] = None
self.output_artifacts[output_type][arg] = None
else:
raise ComponentArtifactApplyError(
f"load as output artifact({arti}) error: apply_config is None but not optional"
f"load as output artifact({allowed_artifacts}) error: apply_config is None but not optional"
)
logging.debug(
f"apply {output_type} artifact `{arti.name}`: {apply_spec} -> {self.output_artifacts_writer[output_type][arg]}"
logger.debug(
f"apply {output_type} artifact `{allowed_artifacts.name}`: {apply_spec} -> {self.output_artifacts_writer[output_type][arg]}"
)
return True
else:
logging.debug(f"skip {output_type} artifact `{arti.name}` for stage `{stage}` and role `{role}`")
logger.debug(
f"skip {output_type} artifact `{allowed_artifacts.name}` for stage `{stage}` and role `{role}`"
)
return False

def get_kwargs(self):
Expand Down Expand Up @@ -216,6 +211,12 @@ def __next__(self):
self.current += 1
return writer

def __str__(self):
return f"{self.__class__.__name__}({self.artifact_describe}, index={self.current}>"

def __repr__(self):
return str(self)


class ComponentArtifactApplyError(RuntimeError):
...
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def __init__(self, name: str, roles: List[Role], stages: List[Stage], desc: str,
self.multi = multi

def __str__(self) -> str:
return f"ArtifactDeclare<name={self.name}, type={self.get_type()}, roles={self.roles}, stages={self.stages}, optional={self.optional}>"
return f"{self.__class__.__name__}(name={self.name}, type={self.get_type()}, roles={self.roles}, stages={self.stages}, optional={self.optional})"

def dict(self):
return ArtifactSpec(
Expand All @@ -135,10 +135,6 @@ def dict(self):
def get_type(cls) -> AT:
raise NotImplementedError()

@classmethod
def is_correct_arti(cls, uri: URI):
return True

def get_writer(self, ctx: "Context", uri: "URI") -> _ArtifactTypeWriter[M]:
raise NotImplementedError()

Expand Down
2 changes: 2 additions & 0 deletions python/fate/components/core/spec/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,13 @@ class Config:
class ArtifactInputApplySpec(pydantic.BaseModel):
uri: str
metadata: Metadata
type_name: Optional[str] = None


class ArtifactOutputApplySpec(pydantic.BaseModel):
uri: str
_is_template: Optional[bool] = None
type_name: Optional[str] = None

def is_template(self) -> bool:
return "{index}" in self.uri
Expand Down

0 comments on commit 6140e22

Please sign in to comment.