Skip to content

Commit

Permalink
Merge branch 'dev-2.0.0-beta-arch-update' of https://github.com/Feder…
Browse files Browse the repository at this point in the history
…atedAI/FATE into feature-2.0.0-glm
  • Loading branch information
nemirorox committed Jul 3, 2023
2 parents ee1bb53 + 3ce1676 commit e5ad936
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 3 deletions.
7 changes: 6 additions & 1 deletion python/fate/arch/tensor/distributed/_ops_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,12 @@ def _binary(input, other, op, swap_operad=False, dtype_promote_to=None):

# other is local tensor, broadcast to partitions
else:
shapes = input.shardings.shapes.bc_shapes(other.shape)
if isinstance(other, torch.Tensor):
shapes = input.shardings.shapes.bc_shapes(other.shape)
else:
# other is scalar
shapes = input.shardings.shapes.bc_shapes(torch.Size([]))

if swap_operad:
return DTensor(
input.shardings.map_shard(
Expand Down
2 changes: 1 addition & 1 deletion python/fate/arch/tensor/distributed/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ def bc_shapes(self, other: "_ShardingShapes") -> "_ShardingShapes":
assert other[other_align_axis] == 1, f"shape in distributed axis should be 1: {self} vs {other}"
self_axis = len(_bc_shapes[0]) - len(self.shapes[0]) + self.axis

return _ShardingShapes(_bc_shapes, self.axis)
return _ShardingShapes(_bc_shapes, self_axis)
else:
raise NotImplementedError(f"type `{other}`")

Expand Down
3 changes: 2 additions & 1 deletion python/fate/components/core/component_desc/_parameter.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import typing
from typing import Dict, TypeVar

import pydantic
Expand Down Expand Up @@ -34,7 +35,7 @@ def get_parameter_spec(self):
from fate.components.core.spec.component import ParameterSpec

default = self.default if self.default is not ... else None
if issubclass(self.type, Parameter): # recommended
if not typing.get_origin(self.type) and issubclass(self.type, Parameter): # recommended
type_name = type(self.type).__name__
if (schema := self.type.schema()) != NotImplemented:
type_meta = schema
Expand Down

0 comments on commit e5ad936

Please sign in to comment.