Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(beta) enhance params #4630

Merged
merged 1 commit into from
Mar 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions python/fate/components/components/hetero_lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ def hetero_lr(ctx, role):
@hetero_lr.train()
@cpn.artifact("train_data", type=Input[DatasetArtifact], roles=[GUEST, HOST], desc="training data")
@cpn.artifact("validate_data", type=Input[DatasetArtifact], optional=True, roles=[GUEST, HOST], desc="validation data")
@cpn.parameter("learning_rate", type=params.ConFloat(gt=0.0), default=0.1, desc="learning rate")
@cpn.parameter("max_iter", type=params.ConInt(gt=0), default=100, desc="max iteration num")
@cpn.parameter("learning_rate", type=params.learning_rate_param(), default=0.1, desc="learning rate")
@cpn.parameter("max_iter", type=params.conint(gt=0), default=100, desc="max iteration num")
@cpn.parameter(
"batch_size", type=params.ConInt(), default=100, desc="batch size, value less or equals to 0 means full batch"
"batch_size", type=params.conint(gt=0), default=100, desc="batch size, value less or equals to 0 means full batch"
)
@cpn.artifact("train_output_data", type=Output[DatasetArtifact], roles=[GUEST, HOST])
@cpn.artifact("train_output_metric", type=Output[LossMetrics], roles=[ARBITER])
Expand Down
4 changes: 3 additions & 1 deletion python/fate/components/entrypoint/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,9 @@ def parse_input_parameters(mlmd: MLMD, cpn: _Component, input_parameters: Dict[s
try:
value = params.parse(parameter.type, parameter_apply)
except Exception as e:
raise ComponentApplyError(f"apply value `{parameter_apply}` to parameter `{arg}` failed:\n{e}")
raise ComponentApplyError(
f"apply value `{parameter_apply}` to parameter `{arg}` failed:\n{e}"
) from e
execute_parameters[parameter.name] = value
mlmd.io.log_input_parameter(parameter.name, parameter_apply)
return execute_parameters
Expand Down
68 changes: 7 additions & 61 deletions python/fate/components/params/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,65 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pydantic


class Parameter:
    """Base interface for component parameter types.

    Subclasses implement ``parse`` (validate/coerce a raw applied value)
    and ``dict`` (export constraint metadata, e.g. for schema rendering).
    """

    def parse(self, obj):
        raise NotImplementedError()

    def dict(self):
        raise NotImplementedError()


class ConInt(Parameter):
    """Integer parameter with optional open/closed bounds (gt/ge/lt/le)."""

    def __init__(self, gt: int = None, ge: int = None, lt: int = None, le: int = None) -> None:
        self.gt = gt
        self.ge = ge
        self.lt = lt
        self.le = le

    def parse(self, obj):
        """Validate *obj* as an int satisfying the configured bounds."""
        constrained = pydantic.conint(gt=self.gt, ge=self.ge, lt=self.lt, le=self.le)
        return pydantic.parse_obj_as(constrained, obj)

    def dict(self):
        """Export only the bounds that were actually set."""
        bounds = dict(gt=self.gt, ge=self.ge, lt=self.lt, le=self.le)
        return {name: value for name, value in bounds.items() if value is not None}


class ConFloat(Parameter):
    """Float parameter with optional open/closed bounds (gt/ge/lt/le)."""

    def __init__(self, gt: float = None, ge: float = None, lt: float = None, le: float = None) -> None:
        self.gt = gt
        self.ge = ge
        self.lt = lt
        self.le = le

    def parse(self, obj):
        """Validate *obj* as a float satisfying the configured bounds."""
        constrained = pydantic.confloat(gt=self.gt, ge=self.ge, lt=self.lt, le=self.le)
        return pydantic.parse_obj_as(constrained, obj)

    def dict(self):
        """Export only the bounds that were actually set."""
        bounds = dict(gt=self.gt, ge=self.ge, lt=self.lt, le=self.le)
        return {name: value for name, value in bounds.items() if value is not None}

#
from pydantic import validate_arguments

def parse(parameter_type, obj):
    """Parse *obj* against *parameter_type*.

    Custom ``Parameter`` instances know how to parse themselves; any other
    type is delegated to pydantic's generic object parser.
    """
    if not isinstance(parameter_type, Parameter):
        return pydantic.parse_obj_as(parameter_type, obj)
    return parameter_type.parse(obj)
from ._cipher import CipherParamType, PaillierCipherParam
from ._fields import confloat, conint, jsonschema, parse, string_choice
from ._learning_rate import learning_rate_param
from ._optimizer import optimizer_param
from ._penalty import penalty_param
15 changes: 15 additions & 0 deletions python/fate/components/params/_cipher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from typing import Literal, Union

import pydantic


class PaillierCipherParam(pydantic.BaseModel):
    """Parameters for the Paillier homomorphic cipher.

    key_length: bit length of the Paillier key pair; 1024 is both the
    minimum accepted and the default.
    """

    method: Literal["paillier"] = "paillier"
    # ge (not gt): the default 1024 must satisfy its own constraint.
    # With gt=1024, pydantic v1 never validates the implicit default, so
    # PaillierCipherParam() silently holds an invalid value while an
    # explicitly supplied key_length=1024 is rejected.
    key_length: pydantic.conint(ge=1024) = 1024


class NoopCipher(pydantic.BaseModel):
    """Cipher configuration meaning "no encryption"."""

    method: Literal[None]


# Discriminated on `method`: "paillier" -> PaillierCipherParam, None -> NoopCipher.
CipherParamType = Union[PaillierCipherParam, NoopCipher]
96 changes: 96 additions & 0 deletions python/fate/components/params/_fields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import typing
from typing import Any, Optional, Type, TypeVar

import pydantic


class Parameter:
    """Mixin marking a class as a self-validating component parameter type."""

    @classmethod
    def parse(cls, obj: Any):
        """Validate and coerce *obj* into this parameter type via pydantic."""
        return pydantic.parse_obj_as(cls, obj)

    @classmethod
    def dict(cls):
        """Export metadata describing this parameter; subclasses override."""
        raise NotImplementedError()


T = TypeVar("T")


def parse(type_: Type[T], obj: Any) -> T:
    """Parse *obj* as *type_*.

    Classes that opt into custom parsing (``Parameter`` subclasses) handle
    it themselves; everything else is delegated to pydantic.
    """
    # isinstance(type_, type) replaces the original check against the
    # private typing._GenericAlias: generic aliases (Union[...], List[int],
    # and builtin forms like list[int], which are types.GenericAlias and
    # slipped past the old check) are not classes, and issubclass() would
    # raise TypeError on them.
    if isinstance(type_, type) and issubclass(type_, Parameter):
        return type_.parse(obj)
    return pydantic.parse_obj_as(type_, obj)


def jsonschema(type_: Type[T]):
    """Render the JSON schema of *type_* as a 2-space-indented JSON string."""
    return pydantic.schema_json_of(type_, indent=2)


class ConstrainedInt(pydantic.ConstrainedInt, Parameter):
    """pydantic constrained int that also satisfies the Parameter protocol."""


def conint(
    *,
    strict: bool = False,
    gt: int = None,
    ge: int = None,
    lt: int = None,
    le: int = None,
    multiple_of: int = None,
) -> Type[int]:
    """Create a ConstrainedInt subclass carrying the given constraints.

    Mirrors pydantic.conint but bases the generated type on our
    ConstrainedInt so it participates in component parameter parsing.
    """
    constraints = {
        "strict": strict,
        "gt": gt,
        "ge": ge,
        "lt": lt,
        "le": le,
        "multiple_of": multiple_of,
    }
    return type("ConstrainedIntValue", (ConstrainedInt,), constraints)


class ConstrainedFloat(pydantic.ConstrainedFloat, Parameter):
    """pydantic constrained float that also satisfies the Parameter protocol."""


def confloat(
    *,
    strict: bool = False,
    gt: float = None,
    ge: float = None,
    lt: float = None,
    le: float = None,
    multiple_of: float = None,
    allow_inf_nan: Optional[bool] = None,
) -> Type[float]:
    """Create a ConstrainedFloat subclass carrying the given constraints.

    Mirrors pydantic.confloat but bases the generated type on our
    ConstrainedFloat so it participates in component parameter parsing.
    """
    constraints = {
        "strict": strict,
        "gt": gt,
        "ge": ge,
        "lt": lt,
        "le": le,
        "multiple_of": multiple_of,
        "allow_inf_nan": allow_inf_nan,
    }
    return type("ConstrainedFloatValue", (ConstrainedFloat,), constraints)


class StringChoice(str, Parameter):
    """String parameter restricted to a fixed set of choices.

    Subclasses (usually generated by string_choice()) set `choice` to the
    allowed values; `lower` makes matching case-insensitive and normalizes
    the accepted value to lower case.
    """

    choice = set()
    lower = True

    @classmethod
    def __get_validators__(cls) -> typing.Iterator:
        # pydantic v1 hook: yield the validator callables for this type.
        # (The original annotated this "CallableGenerator", a name never
        # imported; typing.Iterator is resolvable and equally accurate.)
        yield cls.string_choice_validator

    @classmethod
    def string_choice_validator(cls, v):
        """Return the (possibly lower-cased) value if allowed, else raise.

        Raises TypeError for non-string input: previously v.lower() raised
        AttributeError, which pydantic v1 does not translate into a
        validation error.
        """
        if not isinstance(v, str):
            raise TypeError(f"string required, got `{type(v).__name__}`")
        allowed = {c.lower() for c in cls.choice} if cls.lower else cls.choice
        provided = v.lower() if cls.lower else v
        if provided in allowed:
            return provided
        raise ValueError(f"provided `{provided}` not in `{allowed}`")


def string_choice(choice, lower=True) -> Type[str]:
    """Build a StringChoice subclass allowing *choice* (case-insensitive if *lower*)."""
    return type("StringChoice", (StringChoice,), {"choice": choice, "lower": lower})
13 changes: 13 additions & 0 deletions python/fate/components/params/_learning_rate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from ._fields import ConstrainedFloat


class LearningRate(ConstrainedFloat):
    """Learning-rate parameter: a float constrained to be strictly positive."""

    gt = 0.0

    @classmethod
    def dict(cls):
        """Export metadata identifying this parameter type by its class name."""
        return {"name": cls.__name__}


def learning_rate_param():
    """Return the learning-rate parameter type (float > 0)."""
    return LearningRate
13 changes: 13 additions & 0 deletions python/fate/components/params/_optimizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import enum
from typing import Type


class Optimizer(str, enum.Enum):
    """Base str-enum for optimizer choices; members are generated per call."""

    @classmethod
    def __modify_schema__(cls, field_schema: dict):
        # pydantic hook: attach a human-readable description to the JSON schema.
        field_schema["description"] = "optimizer params"


def optimizer_param(rmsprop=True, sgd=True, adam=True, nesterov_momentum_sgd=True, adagrad=True) -> Type[str]:
    """Build an enum containing the enabled optimizer names (all on by default)."""
    flags = {
        "rmsprop": rmsprop,
        "sgd": sgd,
        "adam": adam,
        "nesterov_momentum_sgd": nesterov_momentum_sgd,
        "adagrad": adagrad,
    }
    members = {name: name for name, enabled in flags.items() if enabled}
    # enum functional API: derive a new member-carrying enum from Optimizer.
    return Optimizer("OptimizerParam", members)
15 changes: 15 additions & 0 deletions python/fate/components/params/_penalty.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from typing import Type

from ._fields import StringChoice


class Penalty(StringChoice):
    """String parameter restricted to the enabled penalty names."""

    # Must be spelled `choice` — StringChoice.string_choice_validator reads
    # cls.choice. The previous misspelled `chooice` attribute was never
    # consulted, leaving the allowed set empty so every value was rejected.
    choice = set()


def penalty_param(l1=True, l2=True) -> Type[str]:
    """Build a penalty parameter type allowing the enabled penalties.

    l1/l2: enable the corresponding penalty; values are matched
    case-insensitively and normalized to lower case by StringChoice.
    """
    enabled = {"L1": l1, "L2": l2}
    namespace = dict(
        choice={name for name, on in enabled.items() if on},
    )
    return type("PenaltyValue", (Penalty,), namespace)