Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/dev-2.0.0-beta-debugging' into d…
Browse files Browse the repository at this point in the history
…ev-2.0.0-beta-debugging

# Conflicts:
#	python/fate/ml/glm/coordinated_linr/arbiter.py
  • Loading branch information
nemirorox committed Jul 11, 2023
2 parents 5dba9ad + 759b762 commit 05fd550
Show file tree
Hide file tree
Showing 31 changed files with 696 additions and 408 deletions.
4 changes: 2 additions & 2 deletions python/fate/arch/computing/eggroll/_csession.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ def session_id(self):
def load(self, uri: URI, schema: dict, options: dict = None) -> Table:
from ._type import EggRollStoreType

if uri.schema != "eggroll":
raise ValueError(f"uri scheme {uri.schema} not supported with eggroll backend")
if uri.scheme != "eggroll":
raise ValueError(f"uri scheme {uri.scheme} not supported with eggroll backend")
try:
_, namespace, name = uri.path_splits()
except Exception as e:
Expand Down
4 changes: 2 additions & 2 deletions python/fate/arch/computing/eggroll/_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def save(self, uri: URI, schema: dict, options: dict = None):

from ._type import EggRollStoreType

if uri.schema != "eggroll":
raise ValueError(f"uri scheme {uri.schema} not supported with eggroll backend")
if uri.scheme != "eggroll":
raise ValueError(f"uri scheme {uri.scheme} not supported with eggroll backend")
try:
_, namespace, name = uri.path_splits()
except Exception as e:
Expand Down
12 changes: 6 additions & 6 deletions python/fate/arch/computing/spark/_csession.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(self, session_id):
def load(self, uri: URI, schema, options: dict = None) -> "Table":
partitions = options.get("partitions", None)

if uri.schema == "hdfs":
if uri.scheme == "hdfs":
in_serialized = (options.get("in_serialized", True),)
id_delimiter = (options.get("id_delimiter", ","),)
table = from_hdfs(
Expand All @@ -47,10 +47,10 @@ def load(self, uri: URI, schema, options: dict = None) -> "Table":
in_serialized=in_serialized,
id_delimiter=id_delimiter,
)
table.schema = schema
table.scheme = schema
return table

if uri.schema == "hive":
if uri.scheme == "hive":
try:
(path,) = uri.path_splits()
database_name, table_name = path.split(".")
Expand All @@ -61,10 +61,10 @@ def load(self, uri: URI, schema, options: dict = None) -> "Table":
db_name=database_name,
partitions=partitions,
)
table.schema = schema
table.scheme = schema
return table

if uri.schema == "file":
if uri.scheme == "file":
in_serialized = (options.get("in_serialized", True),)
id_delimiter = (options.get("id_delimiter", ","),)
table = from_localfs(
Expand All @@ -73,7 +73,7 @@ def load(self, uri: URI, schema, options: dict = None) -> "Table":
in_serialized=in_serialized,
id_delimiter=id_delimiter,
)
table.schema = schema
table.scheme = schema
return table

raise NotImplementedError(f"uri type {uri} not supported with spark backend")
Expand Down
6 changes: 3 additions & 3 deletions python/fate/arch/computing/spark/_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,23 +83,23 @@ def save(self, uri, schema: dict, options: dict = None):
if options is None:
options = {}
partitions = options.get("partitions")
if uri.schema == "hdfs":
if uri.scheme == "hdfs":
table = self._rdd.map(lambda x: hdfs_serialize(x[0], x[1]))
if partitions:
table = table.repartition(partitions)
table.saveAsTextFile(uri.original_uri)
schema.update(self.schema)
return

if uri.schema == "hive":
if uri.scheme == "hive":
table = self._rdd.map(lambda x: hive_to_row(x[0], x[1]))
if partitions:
table = table.repartition(partitions)
table.toDF().write.saveAsTable(uri.original_uri)
schema.update(self.schema)
return

if uri.schema == "file":
if uri.scheme == "file":
table = self._rdd.map(lambda x: hdfs_serialize(x[0], x[1]))
if partitions:
table = table.repartition(partitions)
Expand Down
4 changes: 2 additions & 2 deletions python/fate/arch/computing/standalone/_csession.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ def session_id(self):
return self._session.session_id

def load(self, uri: URI, schema: dict, options: dict = None):
if uri.schema != "standalone":
raise ValueError(f"uri scheme `{uri.schema}` not supported with standalone backend")
if uri.scheme != "standalone":
raise ValueError(f"uri scheme `{uri.scheme}` not supported with standalone backend")
try:
*database, namespace, name = uri.path_splits()
except Exception as e:
Expand Down
4 changes: 2 additions & 2 deletions python/fate/arch/computing/standalone/_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def save(self, uri: URI, schema, options: dict = None):
if options is None:
options = {}

if uri.schema != "standalone":
raise ValueError(f"uri scheme `{uri.schema}` not supported with standalone backend")
if uri.scheme != "standalone":
raise ValueError(f"uri scheme `{uri.scheme}` not supported with standalone backend")
try:
*database, namespace, name = uri.path_splits()
except Exception as e:
Expand Down
3 changes: 2 additions & 1 deletion python/fate/arch/context/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._cipher import CipherKit
from ._context import Context

__all__ = ["Context"]
__all__ = ["Context", "CipherKit"]
32 changes: 25 additions & 7 deletions python/fate/arch/context/_cipher.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,41 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

from ..unify import device

logger = logging.getLogger(__name__)


class CipherKit:
def __init__(self, device: device) -> None:
self.device = device
def __init__(self, device: device, cipher_mapping=None) -> None:
self._device = device
self._cipher_mapping = cipher_mapping

@property
def phe(self):
return PHECipher(self.device)
if self._cipher_mapping is None:
if self._device == device.CPU:
return PHECipher("paillier")
else:
logger.warning(f"no impl exists for device {self._device}, fallback to CPU")
return PHECipher("paillier")

if "phe" not in self._cipher_mapping:
raise ValueError("phe is not set")

if self._device not in self._cipher_mapping["phe"]:
raise ValueError(f"phe is not set for device {self._device}")

return PHECipher(self._cipher_mapping["phe"])


class PHECipher:
def __init__(self, _device) -> None:
self.device = _device
def __init__(self, kind) -> None:
self.kind = kind

def keygen(self, **kwargs):
from fate.arch.tensor import keygen
from fate.arch.tensor import phe_keygen

return keygen(self.device, **kwargs)
return phe_keygen(self.device, **kwargs)
43 changes: 29 additions & 14 deletions python/fate/arch/context/_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from ..unify import device
from ._cipher import CipherKit
from ._federation import Parties, Party
from ._metrics import MetricsWrap, NoopMetricsHandler
from ._metrics import InMemoryMetricsHandler, MetricsWrap
from ._namespace import NS, default_ns

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -48,25 +48,40 @@ def __init__(
self._computing = computing
self._federation = federation
self._metrics_handler = metrics_handler
self.namespace = namespace
self.cipher = cipher
self._namespace = namespace
self._cipher = cipher

if self.namespace is None:
self.namespace = default_ns
if self.cipher is None:
self.cipher: CipherKit = CipherKit(device)
if self._namespace is None:
self._namespace = default_ns
if self._cipher is None:
self._cipher: CipherKit = CipherKit(device)

self._role_to_parties = None
self._is_destroyed = False

def register_metric_handler(self, metrics_handler):
@property
def device(self):
return self._device

@property
def namespace(self):
return self._namespace

@property
def cipher(self):
return self._cipher

def set_cipher(self, cipher: CipherKit):
self._cipher = cipher

def set_metric_handler(self, metrics_handler):
self._metrics_handler = metrics_handler

@property
def metrics(self):
if self._metrics_handler is None:
self._metrics_handler = NoopMetricsHandler()
return MetricsWrap(self._metrics_handler, self.namespace)
self._metrics_handler = InMemoryMetricsHandler()
return MetricsWrap(self._metrics_handler, self._namespace)

def with_namespace(self, namespace: NS):
return Context(
Expand All @@ -87,7 +102,7 @@ def federation(self) -> "FederationEngine":
return self._get_federation()

def sub_ctx(self, name: str, is_special=False) -> "Context":
return self.with_namespace(self.namespace.sub_ns(name=name, is_special=is_special))
return self.with_namespace(self._namespace.sub_ns(name=name, is_special=is_special))

@property
def on_iterations(self) -> "Context":
Expand Down Expand Up @@ -140,14 +155,14 @@ def ctxs_range(self, *args, **kwargs) -> Iterable[Tuple[int, "Context"]]:
raise ValueError("Too few arguments")

for i in range(start, end):
yield i, self.with_namespace(self.namespace.indexed_ns(index=i))
yield i, self.with_namespace(self._namespace.indexed_ns(index=i))

def ctxs_zip(self, iterable: Iterable[T]) -> Iterable[Tuple["Context", T]]:
"""
zip contexts with iterable with namespaces indexed from 0
"""
for i, it in enumerate(iterable):
yield self.with_namespace(self.namespace.indexed_ns(index=i)), it
yield self.with_namespace(self._namespace.indexed_ns(index=i)), it

def set_federation(self, federation: "FederationEngine"):
self._federation = federation
Expand Down Expand Up @@ -208,7 +223,7 @@ def _get_parties(self, role: Optional[Literal["guest", "host", "arbiter"]] = Non
return Parties(
self._get_federation(),
parties,
self.namespace,
self._namespace,
)

def _get_federation(self):
Expand Down
4 changes: 2 additions & 2 deletions python/fate/arch/context/_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def _push(
parties: List[PartyMeta],
value,
):
tag = namespace.get_federation_tag()
tag = namespace.federation_tag
_TableRemotePersistentPickler.push(value, federation, name, tag, parties)


Expand All @@ -138,7 +138,7 @@ def _pull(
namespace: NS,
parties: List[PartyMeta],
):
tag = namespace.get_federation_tag()
tag = namespace.federation_tag
raw_values = federation.pull(
name=name,
tag=tag,
Expand Down
Loading

0 comments on commit 05fd550

Please sign in to comment.