Skip to content

Commit

Permalink
Merge branch 'dev-2.0.0-beta-arch-update' of https://github.com/Feder…
Browse files Browse the repository at this point in the history
…atedAI/FATE into feature-2.0.0-glm
  • Loading branch information
nemirorox committed Jul 3, 2023
2 parents 1a49e43 + 9afe162 commit 1208a56
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 3 deletions.
38 changes: 36 additions & 2 deletions python/fate/arch/context/_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Iterable, List, Literal, Optional, Tuple, TypeVar
from typing import Iterable, List, Literal, Optional, Tuple, TypeVar, overload

from fate.arch.abc import CSessionABC, FederationEngine, PartyMeta

Expand Down Expand Up @@ -101,11 +101,45 @@ def on_batches(self) -> "Context":
def on_cross_validations(self) -> "Context":
return self.sub_ctx("cross_validations")

@overload
def ctxs_range(self, end: int) -> Iterable[Tuple[int, "Context"]]:
...

@overload
def ctxs_range(self, start: int, end: int) -> Iterable[Tuple[int, "Context"]]:
...

def ctxs_range(self, *args, **kwargs) -> Iterable[Tuple[int, "Context"]]:

"""
create contexes with namespaces indexed from 0 to end(excluded)
"""
for i in range(end):

if "start" in kwargs:
start = kwargs["start"]
if "end" not in kwargs:
raise ValueError("End value must be provided")
end = kwargs["end"]
if len(args) > 0:
raise ValueError("Too many arguments")
else:
if "end" in kwargs:
end = kwargs["end"]
if len(args) > 1:
raise ValueError("Too many arguments")
elif len(args) == 0:
raise ValueError("Start value must be provided")
else:
start = args[0]
else:
if len(args) == 1:
start, end = 0, args[0]
elif len(args) == 2:
start, end = args
else:
raise ValueError("Too few arguments")

for i in range(start, end):
yield i, self.with_namespace(self.namespace.indexed_ns(index=i))

def ctxs_zip(self, iterable: Iterable[T]) -> Iterable[Tuple["Context", T]]:
Expand Down
2 changes: 1 addition & 1 deletion python/fate/arch/tensor/distributed/_ops_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def div(input, other):
def _binary(input, other, op, swap_operad=False, dtype_promote_to=None):
# swap input and output if input is not DTensor
if not isinstance(input, DTensor):
return _binary(op, other, input, swap_operad=not swap_operad, dtype_promote_to=dtype_promote_to)
return _binary(other, input, op, swap_operad=not swap_operad, dtype_promote_to=dtype_promote_to)

if isinstance(other, DTensor):
if swap_operad:
Expand Down

0 comments on commit 1208a56

Please sign in to comment.