From 80959d2b281eea99a1c974c941b33d48c972427b Mon Sep 17 00:00:00 2001 From: sagewe Date: Wed, 27 Dec 2023 15:23:20 +0800 Subject: [PATCH] feat: add context create helper for ml use Signed-off-by: sagewe --- python/fate/arch/computing/__init__.py | 2 +- python/fate/arch/computing/_builder.py | 2 +- python/fate/arch/context/__init__.py | 3 +- python/fate/arch/context/_context_helper.py | 35 +++++++++++++++++++++ python/fate/arch/federation/__init__.py | 4 +-- 5 files changed, 41 insertions(+), 5 deletions(-) create mode 100644 python/fate/arch/context/_context_helper.py diff --git a/python/fate/arch/computing/__init__.py b/python/fate/arch/computing/__init__.py index 4424dae7e2..034fcc25d5 100644 --- a/python/fate/arch/computing/__init__.py +++ b/python/fate/arch/computing/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._builder import ComputingBuilder +from ._builder import ComputingBuilder, ComputingEngine diff --git a/python/fate/arch/computing/_builder.py b/python/fate/arch/computing/_builder.py index b6ad567141..7f7a18d817 100644 --- a/python/fate/arch/computing/_builder.py +++ b/python/fate/arch/computing/_builder.py @@ -28,7 +28,7 @@ def __init__( def build(self, t: ComputingEngine, conf: dict): if t == ComputingEngine.STANDALONE: data_dir = cfg.get_option(conf, "computing.standalone.data_dir") - options = cfg.get_option(conf, "computing.standalone.options") + options = cfg.get_option(conf, "computing.standalone.options", None) return self.build_standalone(data_dir=data_dir, options=options) elif t == ComputingEngine.EGGROLL: host = cfg.get_option(conf, "computing.eggroll.host") diff --git a/python/fate/arch/context/__init__.py b/python/fate/arch/context/__init__.py index 27a9e0f346..5415a04abd 100644 --- a/python/fate/arch/context/__init__.py +++ b/python/fate/arch/context/__init__.py @@ -16,5 +16,6 @@ from ._context import Context from ._namespace import NS from ._parties import Parties +from ._context_helper import create_context -__all__ = ["Context", "CipherKit", "PHECipher", "PHECipherPublic", "NS", "Parties"] +__all__ = ["Context", "CipherKit", "PHECipher", "PHECipherPublic", "NS", "Parties", "create_context"] diff --git a/python/fate/arch/context/_context_helper.py b/python/fate/arch/context/_context_helper.py new file mode 100644 index 0000000000..6fde2abbeb --- /dev/null +++ b/python/fate/arch/context/_context_helper.py @@ -0,0 +1,35 @@ +from typing import Tuple, List + +from fate.arch.computing import ComputingBuilder, ComputingEngine +from fate.arch.context import Context +from fate.arch.federation import FederationBuilder, FederationType + + +def create_context( + local_party: Tuple[str, str], + parties: List[Tuple[str, str]], + federation_session_id, + federation_engine=FederationType.STANDALONE, + federation_conf: dict = None, + computing_session_id=None, + computing_engine=ComputingEngine.STANDALONE, + computing_conf=None, +): + if federation_conf is None: + federation_conf = {} + if computing_conf is None: + computing_conf = {} + if ComputingEngine.STANDALONE == computing_engine: + if "computing.standalone.data_dir" not in computing_conf: + computing_conf["computing.standalone.data_dir"] = "/tmp" + if computing_session_id is None: + computing_session_id = f"{federation_session_id}_{local_party[0]}_{local_party[1]}" + computing_session = ComputingBuilder(computing_session_id=computing_session_id).build( + computing_engine, computing_conf + ) + federation_session = FederationBuilder( + federation_id=federation_session_id, + party=local_party, + parties=parties, + ).build(computing_session, federation_engine, federation_conf) + return Context(computing=computing_session, federation=federation_session) diff --git a/python/fate/arch/federation/__init__.py b/python/fate/arch/federation/__init__.py index dbfeb41f63..11457c8740 100644 --- a/python/fate/arch/federation/__init__.py +++ b/python/fate/arch/federation/__init__.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from ._builder import FederationBuilder, FederationMode +from ._builder import FederationBuilder, FederationMode, FederationType from .api import Federation, FederationDataType, TableMeta -__all__ = ["Federation", "FederationDataType", "FederationBuilder", "FederationMode"] +__all__ = ["Federation", "FederationDataType", "FederationBuilder", "FederationMode", "FederationType", "TableMeta"]