Skip to content

Commit

Permalink
adjust pipeline
Browse files Browse the repository at this point in the history
Signed-off-by: mgqa34 <mgq3374541@163.com>
  • Loading branch information
mgqa34 committed Jan 5, 2023
1 parent 92290b7 commit e275ce9
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 28 deletions.
2 changes: 1 addition & 1 deletion python/fate/arch/context/io/data/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def read_dataframe(self):
data = json.load(fin)

table = self.ctx.computing.parallelize(data, include_key=True, partition=1)
data.schema = schema
table.schema = schema
df = dataframe.deserialize(self.ctx, table)

return Dataframe(df, df.shape[1], df.shape[0])
Expand Down
3 changes: 2 additions & 1 deletion python/fate_client/pipeline/entity/task_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ class RuntimeConfSpec(BaseModel):


class TaskScheduleSpec(BaseModel):
taskid: str
task_id: str
party_task_id: str
component: str
role: str
stage: str
Expand Down
18 changes: 9 additions & 9 deletions python/fate_client/pipeline/manager/status_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,22 @@ def __init__(self, status_uri: str):
def create_status_manager(cls, status_uri):
return SQLiteStatusManager(status_uri)

def monitor_finish_status(self, task_ids: list):
for task_id in task_ids:
task_run = self._meta_manager.get_or_create_task(task_id)
def monitor_finish_status(self, party_task_ids: list):
for party_task_id in party_task_ids:
task_run = self._meta_manager.get_or_create_task(party_task_id)
state = task_run.properties["state"].string_value
if state in ["INIT", "running"]:
return False

return True

def record_task_status(self, task_id, status):
self._meta_manager.update_task_state(task_id, status)
def record_task_status(self, party_task_id, status):
self._meta_manager.update_task_state(party_task_id, status)

def record_terminate_status(self, task_ids):
for task_id in task_ids:
def record_terminate_status(self, party_task_ids):
for party_task_id in party_task_ids:
# task_run = self._meta_manager.get_or_create_task(execution_id)
self._meta_manager.set_task_safe_terminate_flag(task_id)
self._meta_manager.set_task_safe_terminate_flag(party_task_id)

def get_task_results(self, tasks_info):
"""
Expand All @@ -41,7 +41,7 @@ def get_task_results(self, tasks_info):
if role not in summary_msg:
summary_msg[role] = dict()

task_run = self._meta_manager.get_or_create_task(task_info.task_id)
task_run = self._meta_manager.get_or_create_task(task_info.party_task_id)
status = task_run.properties["state"].string_value

summary_msg[role][party_id] = status
Expand Down
12 changes: 7 additions & 5 deletions python/fate_client/pipeline/scheduler/runtime_constructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,11 @@ def construct_task_schedule_spec(self):
for party in self._runtime_parties:
conf = self._construct_runtime_conf(party.role, party.party_id)
party_task_spec = TaskScheduleSpec(
taskid=gen_task_id(self._job_id, self._task_name, party.role, party.party_id),
task_id=self._federation_id,
party_task_id=gen_task_id(self._job_id, self._task_name, party.role, party.party_id),
component=self._component_ref,
role=party.role,
party_id=party.party_id,
stage=self._stage,
conf=conf
)
Expand Down Expand Up @@ -212,8 +214,8 @@ def mlmd(self, role, party_id):
def task_conf_path(self, role, party_id):
return self._task_conf_path[role][party_id]

def task_id(self, role, party_id):
return self._task_schedule_spec[role][party_id].taskid
def party_task_id(self, role, party_id):
return self._task_schedule_spec[role][party_id].party_task_id

@property
def status_manager(self):
Expand All @@ -224,8 +226,8 @@ def log_path(self, role, party_id):

def retrieval_task_outputs(self):
for party in self._runtime_parties:
task_id = self._task_schedule_spec[party.role][party.party_id].taskid
outputs = self._resource_manager.status_manager.get_task_outputs(task_id)
party_task_id = self._task_schedule_spec[party.role][party.party_id].party_task_id
outputs = self._resource_manager.status_manager.get_task_outputs(party_task_id)

for output_key in self.OUTPUT_KEYS:
output_list = outputs.get(output_key)
Expand Down
24 changes: 12 additions & 12 deletions python/fate_client/pipeline/utils/job_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,27 @@ def run_subprocess(exec_cmd, std_log_fd):
return process


def run_task_in_party(exec_cmd, std_log_fd, status_manager, task_id):
def run_task_in_party(exec_cmd, std_log_fd, status_manager, party_task_id):
process = run_subprocess(exec_cmd, std_log_fd)
process.communicate()
process.terminate()
if process.returncode != 0:
"""
subprocess fail, record the fail status to MLMD
"""
status_manager.record_task_status(task_id, "exception")
status_manager.record_task_status(party_task_id, "exception")

try:
os.kill(process.pid, 0)
except ProcessLookupError:
pass


def run_detect_task(status_manager, task_ids):
def run_detect_task(status_manager, party_task_ids):
while True:
is_finish = status_manager.monitor_finish_status(task_ids)
is_finish = status_manager.monitor_finish_status(party_task_ids)
if is_finish:
status_manager.record_terminate_status(task_ids)
status_manager.record_terminate_status(party_task_ids)
break

time.sleep(0.5)
Expand All @@ -51,16 +51,16 @@ def process_task(task_type: str, task_name: str, exec_cmd_prefix: list, runtime_
# task_done_tag_paths = list()
mp_ctx = multiprocessing.get_context("fork")
std_log_fds = []
task_ids = []
party_task_ids = []
task_infos = []
for party in parties:
role = party.role
party_id = party.party_id

conf_path = runtime_constructor.task_conf_path(role, party_id)
task_id = runtime_constructor.task_id(role, party_id)
task_ids.append(task_id)
task_infos.append(SimpleNamespace(task_id=task_id, role=role, party_id=party_id))
party_task_id = runtime_constructor.party_task_id(role, party_id)
party_task_ids.append(party_task_id)
task_infos.append(SimpleNamespace(party_task_id=party_task_id, role=role, party_id=party_id))

log_path = runtime_constructor.log_path(role, party_id)
std_log_path = Path(log_path).joinpath("std.log").resolve()
Expand All @@ -72,7 +72,7 @@ def process_task(task_type: str, task_name: str, exec_cmd_prefix: list, runtime_
exec_cmd.extend(
[
"--process-tag",
task_id,
party_task_id,
"--config",
conf_path
]
Expand All @@ -81,14 +81,14 @@ def process_task(task_type: str, task_name: str, exec_cmd_prefix: list, runtime_
exec_cmd=exec_cmd,
std_log_fd=std_log_fd,
status_manager=status_manager,
task_id=task_id
party_task_id=party_task_id
)))

task_pools[-1].start()

detect_task = mp_ctx.Process(target=run_detect_task,
kwargs=dict(status_manager=status_manager,
task_ids=task_ids))
party_task_ids=party_task_ids))

detect_task.start()

Expand Down

0 comments on commit e275ce9

Please sign in to comment.