diff --git a/python/fate/arch/context/_context.py b/python/fate/arch/context/_context.py index 9f061d5f93..debb2f5a9b 100644 --- a/python/fate/arch/context/_context.py +++ b/python/fate/arch/context/_context.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Iterable, List, Literal, Optional, Tuple, TypeVar +from typing import Iterable, List, Literal, Optional, Tuple, TypeVar, overload from fate.arch.abc import CSessionABC, FederationEngine, PartyMeta @@ -101,11 +101,45 @@ def on_batches(self) -> "Context": def on_cross_validations(self) -> "Context": return self.sub_ctx("cross_validations") + @overload def ctxs_range(self, end: int) -> Iterable[Tuple[int, "Context"]]: + ... + + @overload + def ctxs_range(self, start: int, end: int) -> Iterable[Tuple[int, "Context"]]: + ... + + def ctxs_range(self, *args, **kwargs) -> Iterable[Tuple[int, "Context"]]: + """ create contexes with namespaces indexed from 0 to end(excluded) """ - for i in range(end): + + if "start" in kwargs: + start = kwargs["start"] + if "end" not in kwargs: + raise ValueError("End value must be provided") + end = kwargs["end"] + if len(args) > 0: + raise ValueError("Too many arguments") + else: + if "end" in kwargs: + end = kwargs["end"] + if len(args) > 1: + raise ValueError("Too many arguments") + elif len(args) == 0: + raise ValueError("Start value must be provided") + else: + start = args[0] + else: + if len(args) == 1: + start, end = 0, args[0] + elif len(args) == 2: + start, end = args + else: + raise ValueError("Too few arguments") + + for i in range(start, end): yield i, self.with_namespace(self.namespace.indexed_ns(index=i)) def ctxs_zip(self, iterable: Iterable[T]) -> Iterable[Tuple["Context", T]]: diff --git a/python/fate/arch/tensor/distributed/_ops_binary.py b/python/fate/arch/tensor/distributed/_ops_binary.py index 70141f1ea2..9c61aa1455 100644 --- a/python/fate/arch/tensor/distributed/_ops_binary.py +++ b/python/fate/arch/tensor/distributed/_ops_binary.py @@ -31,7 +31,7 @@ def div(input, other): def _binary(input, other, op, swap_operad=False, dtype_promote_to=None): # swap input and output if input is not DTensor if not isinstance(input, DTensor): - return _binary(op, other, input, swap_operad=not swap_operad, dtype_promote_to=dtype_promote_to) + return _binary(other, input, op, swap_operad=not swap_operad, dtype_promote_to=dtype_promote_to) if isinstance(other, DTensor): if swap_operad: