Skip to content

Commit

Permalink
Merge pull request #4630 from FederatedAI/feature/2.0.0-beta/params
Browse files Browse the repository at this point in the history
(beta) enhance params
  • Loading branch information
nemirorox authored Mar 10, 2023
2 parents 9734331 + c369035 commit 7629740
Show file tree
Hide file tree
Showing 8 changed files with 165 additions and 65 deletions.
6 changes: 3 additions & 3 deletions python/fate/components/components/hetero_lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ def hetero_lr(ctx, role):
@hetero_lr.train()
@cpn.artifact("train_data", type=Input[DatasetArtifact], roles=[GUEST, HOST], desc="training data")
@cpn.artifact("validate_data", type=Input[DatasetArtifact], optional=True, roles=[GUEST, HOST], desc="validation data")
@cpn.parameter("learning_rate", type=params.ConFloat(gt=0.0), default=0.1, desc="learning rate")
@cpn.parameter("max_iter", type=params.ConInt(gt=0), default=100, desc="max iteration num")
@cpn.parameter("learning_rate", type=params.learning_rate_param(), default=0.1, desc="learning rate")
@cpn.parameter("max_iter", type=params.conint(gt=0), default=100, desc="max iteration num")
@cpn.parameter(
"batch_size", type=params.ConInt(), default=100, desc="batch size, value less or equals to 0 means full batch"
"batch_size", type=params.conint(gt=0), default=100, desc="batch size, value less or equals to 0 means full batch"
)
@cpn.artifact("train_output_data", type=Output[DatasetArtifact], roles=[GUEST, HOST])
@cpn.artifact("train_output_metric", type=Output[LossMetrics], roles=[ARBITER])
Expand Down
4 changes: 3 additions & 1 deletion python/fate/components/entrypoint/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,9 @@ def parse_input_parameters(mlmd: MLMD, cpn: _Component, input_parameters: Dict[s
try:
value = params.parse(parameter.type, parameter_apply)
except Exception as e:
raise ComponentApplyError(f"apply value `{parameter_apply}` to parameter `{arg}` failed:\n{e}")
raise ComponentApplyError(
f"apply value `{parameter_apply}` to parameter `{arg}` failed:\n{e}"
) from e
execute_parameters[parameter.name] = value
mlmd.io.log_input_parameter(parameter.name, parameter_apply)
return execute_parameters
Expand Down
68 changes: 7 additions & 61 deletions python/fate/components/params/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,65 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pydantic


class Parameter:
def parse(self, obj):
raise NotImplementedError()

def dict(self):
raise NotImplementedError()


class ConInt(Parameter):
def __init__(self, gt: int = None, ge: int = None, lt: int = None, le: int = None) -> None:
self.gt = gt
self.ge = ge
self.lt = lt
self.le = le

def parse(self, obj):
return pydantic.parse_obj_as(pydantic.conint(gt=self.gt, ge=self.ge, lt=self.lt, le=self.le), obj)

def dict(self):
meta = {}
if self.gt is not None:
meta["gt"] = self.gt
if self.ge is not None:
meta["ge"] = self.ge
if self.lt is not None:
meta["lt"] = self.lt
if self.le is not None:
meta["le"] = self.le
return meta


class ConFloat(Parameter):
def __init__(self, gt: float = None, ge: float = None, lt: float = None, le: float = None) -> None:
self.gt = gt
self.ge = ge
self.lt = lt
self.le = le

def parse(self, obj):
return pydantic.parse_obj_as(pydantic.confloat(gt=self.gt, ge=self.ge, lt=self.lt, le=self.le), obj)

def dict(self):
meta = {}
if self.gt is not None:
meta["gt"] = self.gt
if self.ge is not None:
meta["ge"] = self.ge
if self.lt is not None:
meta["lt"] = self.lt
if self.le is not None:
meta["le"] = self.le
return meta

#
from pydantic import validate_arguments

def parse(parameter_type, obj):
if isinstance(parameter_type, Parameter):
return parameter_type.parse(obj)
else:
return pydantic.parse_obj_as(parameter_type, obj)
from ._cipher import CipherParamType, PaillierCipherParam
from ._fields import confloat, conint, jsonschema, parse, string_choice
from ._learning_rate import learning_rate_param
from ._optimizer import optimizer_param
from ._penalty import penalty_param
15 changes: 15 additions & 0 deletions python/fate/components/params/_cipher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from typing import Literal, Union

import pydantic


class PaillierCipherParam(pydantic.BaseModel):
method: Literal["paillier"] = "paillier"
key_length: pydantic.conint(gt=1024) = 1024


class NoopCipher(pydantic.BaseModel):
method: Literal[None]


CipherParamType = Union[PaillierCipherParam, NoopCipher]
96 changes: 96 additions & 0 deletions python/fate/components/params/_fields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import typing
from typing import Any, Optional, Type, TypeVar

import pydantic


class Parameter:
@classmethod
def parse(cls, obj: Any):
return pydantic.parse_obj_as(cls, obj)

@classmethod
def dict(cls):
raise NotImplementedError()


T = TypeVar("T")


def parse(type_: Type[T], obj: Any) -> T:
if not isinstance(type_, typing._GenericAlias) and issubclass(type_, Parameter):
return type_.parse(obj)
else:
return pydantic.parse_obj_as(type_, obj)


def jsonschema(type_: Type[T]):
return pydantic.schema_json_of(type_, indent=2)


class ConstrainedInt(pydantic.ConstrainedInt, Parameter):
...


def conint(
*,
strict: bool = False,
gt: int = None,
ge: int = None,
lt: int = None,
le: int = None,
multiple_of: int = None,
) -> Type[int]:
namespace = dict(strict=strict, gt=gt, ge=ge, lt=lt, le=le, multiple_of=multiple_of)
return type("ConstrainedIntValue", (ConstrainedInt,), namespace)


class ConstrainedFloat(pydantic.ConstrainedFloat, Parameter):
...


def confloat(
*,
strict: bool = False,
gt: float = None,
ge: float = None,
lt: float = None,
le: float = None,
multiple_of: float = None,
allow_inf_nan: Optional[bool] = None,
) -> Type[float]:
namespace = dict(
strict=strict,
gt=gt,
ge=ge,
lt=lt,
le=le,
multiple_of=multiple_of,
allow_inf_nan=allow_inf_nan,
)
return type("ConstrainedFloatValue", (ConstrainedFloat,), namespace)


class StringChoice(str, Parameter):
choice = set()
lower = True

@classmethod
def __get_validators__(cls) -> "CallableGenerator":
yield cls.string_choice_validator

@classmethod
def string_choice_validator(cls, v):
allowed = {c.lower() for c in cls.choice} if cls.lower else cls.choice
provided = v.lower() if cls.lower else v
if provided in allowed:
return provided
raise ValueError(f"provided `{provided}` not in `{allowed}`")


def string_choice(choice, lower=True) -> Type[str]:
namespace = dict(
choice=choice,
lower=lower,
)
return type("StringChoice", (StringChoice,), namespace)
13 changes: 13 additions & 0 deletions python/fate/components/params/_learning_rate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from ._fields import ConstrainedFloat


class LearningRate(ConstrainedFloat):
gt = 0.0

@classmethod
def dict(cls):
return {"name": cls.__name__}


def learning_rate_param():
return LearningRate
13 changes: 13 additions & 0 deletions python/fate/components/params/_optimizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import enum
from typing import Type


class Optimizer(str, enum.Enum):
@classmethod
def __modify_schema__(cls, field_schema: dict):
field_schema["description"] = "optimizer params"


def optimizer_param(rmsprop=True, sgd=True, adam=True, nesterov_momentum_sgd=True, adagrad=True) -> Type[str]:
choice = dict(rmsprop=rmsprop, sgd=sgd, adam=adam, nesterov_momentum_sgd=nesterov_momentum_sgd, adagrad=adagrad)
return Optimizer("OptimizerParam", {k: k for k, v in choice.items() if v})
15 changes: 15 additions & 0 deletions python/fate/components/params/_penalty.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from typing import Type

from ._fields import StringChoice


class Penalty(StringChoice):
chooice = {}


def penalty_param(l1=True, l2=True) -> Type[str]:
choice = {"L1": l1, "L2": l2}
namespace = dict(
chooice={k for k, v in choice.items() if v},
)
return type("PenaltyValue", (Penalty,), namespace)

0 comments on commit 7629740

Please sign in to comment.