diff --git a/python/fate/components/core/__init__.py b/python/fate/components/core/__init__.py
index d94395e763..5acf26ddd4 100644
--- a/python/fate/components/core/__init__.py
+++ b/python/fate/components/core/__init__.py
@@ -6,6 +6,7 @@
 from ._load_metric_handler import load_metric_handler
 from .component_desc import Component, ComponentExecutionIO
 from .essential import ARBITER, GUEST, HOST, LOCAL, Label, Role, Stage
+from ._cpn_task_mode import is_root_worker, is_deepspeed_mode, TaskMode
 
 __all__ = [
     "Component",
@@ -24,4 +25,7 @@
     "HOST",
     "LOCAL",
     "Label",
+    "is_root_worker",
+    "is_deepspeed_mode",
+    "TaskMode",
 ]
diff --git a/python/fate/components/core/_cpn_task_mode.py b/python/fate/components/core/_cpn_task_mode.py
new file mode 100644
index 0000000000..b9856c50a0
--- /dev/null
+++ b/python/fate/components/core/_cpn_task_mode.py
@@ -0,0 +1,32 @@
+"""Environment-variable helpers for detecting the task mode and root worker."""
+
+import enum
+import os
+
+
+class TaskMode(str, enum.Enum):
+    """Task modes recognised in the ``FATE_TASK_TYPE`` environment variable."""
+
+    # ``enum.StrEnum`` only exists on Python >= 3.11; mixing in ``str`` keeps
+    # this module importable on older interpreters with the same equality.
+    SIMPLE = "SIMPLE"
+    DEEPSPEED = "DEEPSPEED"
+
+    def __str__(self) -> str:
+        # Render as the bare value, matching what ``StrEnum`` would give.
+        return self.value
+
+
+def is_deepspeed_mode() -> bool:
+    """Return True when ``FATE_TASK_TYPE`` equals "DEEPSPEED", ignoring case."""
+    return os.getenv("FATE_TASK_TYPE", "").upper() == TaskMode.DEEPSPEED.value
+
+
+def is_root_worker() -> bool:
+    """Return True unless running in deepspeed mode with ``RANK`` other than "0".
+
+    ``RANK`` defaults to "0" when unset; outside deepspeed mode it is ignored.
+    """
+    if is_deepspeed_mode():
+        return os.getenv("RANK", "0") == "0"
+    return True
diff --git a/python/fate/components/entrypoint/cli/component/cleanup_cli.py b/python/fate/components/entrypoint/cli/component/cleanup_cli.py
index 0a15bd890c..6c4e5792b3 100644
--- a/python/fate/components/entrypoint/cli/component/cleanup_cli.py
+++ b/python/fate/components/entrypoint/cli/component/cleanup_cli.py
@@ -16,6 +16,7 @@ def cleanup(process_tag, config, env_name):
         load_config_from_env,
         load_config_from_file,
     )
+    from fate.components.core import is_root_worker
 
     configs = {}
     configs = load_config_from_env(configs, env_name)
@@ -23,15 +24,16 @@ def cleanup(process_tag, config, env_name):
     config = TaskCleanupConfigSpec.parse_obj(configs)
 
     try:
-        print("start cleanup")
-        computing = load_computing(config.computing)
-        federation = load_federation(config.federation, computing)
-        ctx = Context(
-            computing=computing,
-            federation=federation,
-        )
-        ctx.destroy()
-        print("cleanup done")
+        if is_root_worker():
+            print("start cleanup")
+            computing = load_computing(config.computing)
+            federation = load_federation(config.federation, computing)
+            ctx = Context(
+                computing=computing,
+                federation=federation,
+            )
+            ctx.destroy()
+            print("cleanup done")
     except Exception as e:
         traceback.print_exc()
         raise e
diff --git a/python/fate/components/entrypoint/cli/component/execute_cli.py b/python/fate/components/entrypoint/cli/component/execute_cli.py
index fa7e9a37b2..bbeb619977 100644
--- a/python/fate/components/entrypoint/cli/component/execute_cli.py
+++ b/python/fate/components/entrypoint/cli/component/execute_cli.py
@@ -98,6 +98,7 @@ def execute_component_from_config(config: "TaskConfigSpec", output_path):
         load_device,
         load_federation,
         load_metric_handler,
+        is_root_worker,
     )
 
     logger = logging.getLogger(__name__)
@@ -106,7 +107,11 @@ def execute_component_from_config(config: "TaskConfigSpec", output_path):
         party_task_id = config.party_task_id
         device = load_device(config.conf.device)
         computing = load_computing(config.conf.computing, config.conf.logger.config)
-        federation = load_federation(config.conf.federation, computing)
+        if is_root_worker():
+            federation = load_federation(config.conf.federation, computing)
+        else:
+            federation = None
+            logger.info("skip federation initialization for non-root worker")
         cipher = CipherKit(device=device)
         ctx = Context(
             device=device,