Skip to content

Commit

Permalink
add conditional federation init for deepspeed mode
Browse files Browse the repository at this point in the history
Signed-off-by: sagewe <wbwmat@gmail.com>
  • Loading branch information
sagewe committed Dec 8, 2023
1 parent 1e47e79 commit 5db3b3a
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 10 deletions.
4 changes: 4 additions & 0 deletions python/fate/components/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from ._load_metric_handler import load_metric_handler
from .component_desc import Component, ComponentExecutionIO
from .essential import ARBITER, GUEST, HOST, LOCAL, Label, Role, Stage
from ._cpn_task_mode import is_root_worker, is_deepspeed_mode, TaskMode

__all__ = [
"Component",
Expand All @@ -24,4 +25,7 @@
"HOST",
"LOCAL",
"Label",
"is_root_worker",
"is_deepspeed_mode",
"TaskMode",
]
17 changes: 17 additions & 0 deletions python/fate/components/core/_cpn_task_mode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import enum
import os


class TaskMode(str, enum.Enum):
    """Execution modes a component task can run in.

    Members are real ``str`` instances, so they compare equal to their plain
    string values (e.g. ``TaskMode.DEEPSPEED == "DEEPSPEED"``).

    The ``str`` + ``enum.Enum`` mixin is used instead of ``enum.StrEnum``
    because ``StrEnum`` only exists on Python 3.11+; referencing it on an
    older interpreter raises ``AttributeError`` at import time.
    """

    SIMPLE = "SIMPLE"
    DEEPSPEED = "DEEPSPEED"

    def __str__(self):
        # Mirror StrEnum: str() and format() yield the bare value rather than
        # "TaskMode.<NAME>", keeping log/format output identical across versions.
        return str(self.value)


def is_deepspeed_mode():
    """Return True when the ``FATE_TASK_TYPE`` environment variable selects DeepSpeed.

    The variable is compared case-insensitively (it is upper-cased first);
    an unset variable is treated as the empty string and yields False.
    """
    task_type = os.getenv("FATE_TASK_TYPE", "")
    return task_type.upper() == TaskMode.DEEPSPEED


def is_root_worker():
    """Return True if this process should act as the root worker.

    Outside DeepSpeed mode every process counts as root. In DeepSpeed mode
    only the process whose ``RANK`` environment variable is the string "0"
    (the default when ``RANK`` is unset) is root.
    """
    # Short-circuit: RANK is only consulted when running in DeepSpeed mode.
    return (not is_deepspeed_mode()) or os.getenv("RANK", "0") == "0"
20 changes: 11 additions & 9 deletions python/fate/components/entrypoint/cli/component/cleanup_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,24 @@ def cleanup(process_tag, config, env_name):
load_config_from_env,
load_config_from_file,
)
from fate.components.core import is_root_worker

configs = {}
configs = load_config_from_env(configs, env_name)
load_config_from_file(configs, config)
config = TaskCleanupConfigSpec.parse_obj(configs)

try:
print("start cleanup")
computing = load_computing(config.computing)
federation = load_federation(config.federation, computing)
ctx = Context(
computing=computing,
federation=federation,
)
ctx.destroy()
print("cleanup done")
if is_root_worker():
print("start cleanup")
computing = load_computing(config.computing)
federation = load_federation(config.federation, computing)
ctx = Context(
computing=computing,
federation=federation,
)
ctx.destroy()
print("cleanup done")
except Exception as e:
traceback.print_exc()
raise e
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def execute_component_from_config(config: "TaskConfigSpec", output_path):
load_device,
load_federation,
load_metric_handler,
is_root_worker,
)

logger = logging.getLogger(__name__)
Expand All @@ -106,7 +107,11 @@ def execute_component_from_config(config: "TaskConfigSpec", output_path):
party_task_id = config.party_task_id
device = load_device(config.conf.device)
computing = load_computing(config.conf.computing, config.conf.logger.config)
federation = load_federation(config.conf.federation, computing)
if is_root_worker():
federation = load_federation(config.conf.federation, computing)
else:
federation = None
logger.info("skip federation initialization for non-root worker")
cipher = CipherKit(device=device)
ctx = Context(
device=device,
Expand Down

0 comments on commit 5db3b3a

Please sign in to comment.