refactor standalone: simplify requirements #5356

Merged · 1 commit · Dec 20, 2023
14 changes: 0 additions & 14 deletions python/fate/arch/computing/backends/standalone/_csession.py
@@ -36,20 +36,6 @@ def __init__(
         session_id = generate_computing_uuid()
         if data_dir is None:
             raise ValueError("data_dir is None")
-        # data_dir = os.environ.get(
-        #     "STANDALONE_DATA_PATH",
-        #     os.path.abspath(
-        #         os.path.join(
-        #             os.path.dirname(__file__),
-        #             os.path.pardir,
-        #             os.path.pardir,
-        #             os.path.pardir,
-        #             os.path.pardir,
-        #             os.path.pardir,
-        #             "data",
-        #         )
-        #     ),
-        # )
         if options is None:
             options = {}
         max_workers = options.get("task_cores", None)
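Reviewer note: the net effect of this file's change is that the long-dead STANDALONE_DATA_PATH fallback is deleted rather than resurrected, so data_dir must now be passed explicitly. Illustrative only, since the class name and full constructor signature are not visible in this hunk:

    # hypothetical usage of the session class in _csession.py
    CSession(session_id=None, data_dir=None)              # raises ValueError("data_dir is None")
    CSession(session_id=None, data_dir="/tmp/fate_data")  # caller-supplied path used as-is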
240 changes: 90 additions & 150 deletions python/fate/arch/computing/backends/standalone/_standalone.py
@@ -36,12 +36,9 @@
 import cloudpickle as f_pickle
 import lmdb
 
-from fate.arch import trace
-
 PartyMeta = Tuple[Literal["guest", "host", "arbiter", "local"], str]
 
 logger = logging.getLogger(__name__)
-tracer = trace.get_tracer(__name__)


def _watch_thread_react_to_parent_die(ppid, logger_config):
@@ -69,14 +66,63 @@ def f():
     thread = threading.Thread(target=f, daemon=True)
     thread.start()
 
-    # initialize tracer
-    trace.setup_tracing("standalone_computing")
-
     # initialize loggers
     if logger_config is not None:
         logging.config.dictConfig(logger_config)
 
 
+class BasicProcessPool:
+    def __init__(self, pool, log_level):
+        self._pool = pool
+        self._exception_tb = {}
+        self.log_level = log_level
+
+    def submit(self, func, process_infos):
+        features = []
+        outputs = {}
+        num_partitions = len(process_infos)
+
+        for p, process_info in enumerate(process_infos):
+            features.append(
+                self._pool.submit(
+                    BasicProcessPool._process_wrapper,
+                    func,
+                    process_info,
+                    self.log_level,
+                )
+            )
+
+        from concurrent.futures import wait, FIRST_COMPLETED
+
+        not_done = features
+        while not_done:
+            done, not_done = wait(not_done, return_when=FIRST_COMPLETED)
+            for f in done:
+                partition_id, output, e = f.result()
+                if e is not None:
+                    logger.error(f"partition {partition_id} exec failed: {e}")
+                    raise RuntimeError(f"Partition {partition_id} exec failed: {e}")
+                else:
+                    outputs[partition_id] = output
+
+        outputs = [outputs[p] for p in range(num_partitions)]
+        return outputs
+
+    @classmethod
+    def _process_wrapper(cls, do_func, process_info, log_level):
+        try:
+            if log_level is not None:
+                pass
+            output = do_func(process_info)
+            return process_info.partition_id, output, None
+        except Exception as e:
+            logger.error(f"exception in rank {process_info.partition_id}: {e}")
+            return process_info.partition_id, None, e
+
+    def shutdown(self):
+        self._pool.shutdown()
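
Reviewer note: a minimal sketch of how the new BasicProcessPool is driven. It assumes only what the class above shows: each info object exposes partition_id, and outputs come back ordered by partition, not by completion. DemoInfo and double are hypothetical stand-ins for the real _MapReduceProcess/_do_* pairs:

    from concurrent.futures import ProcessPoolExecutor
    from dataclasses import dataclass

    @dataclass
    class DemoInfo:
        partition_id: int
        value: int

    def double(info):  # runs in a worker process, so it must be picklable
        return info.value * 2

    if __name__ == "__main__":
        pool = BasicProcessPool(pool=ProcessPoolExecutor(max_workers=2), log_level=None)
        assert pool.submit(double, [DemoInfo(0, 1), DemoInfo(1, 2)]) == [2, 4]
        pool.shutdown()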


 # noinspection PyPep8Naming
 class Table(object):
     def __init__(
@@ -187,7 +233,6 @@ def collect(self, **kwargs):
             else:
                 _, _, _, it = heappop(entries)
 
-    @trace.auto_trace
     def reduce(self, func):
         return self._session.submit_reduce(
             func,
@@ -197,7 +242,6 @@ def reduce(self, func):
             namespace=self._namespace,
         )
 
-    @trace.auto_trace
     def binary_sorted_map_partitions_with_index(
         self,
         other: "Table",
@@ -243,7 +287,6 @@ def binary_sorted_map_partitions_with_index(
             partitioner_type=partitioner_type,
         )
 
-    @trace.auto_trace
     def map_reduce_partitions_with_index(
         self,
         map_partition_op: Callable[[int, Iterable], Iterable],
@@ -442,21 +485,36 @@ def delete(self, k_bytes: bytes, partitioner: Callable[[bytes, int], int]):

 # noinspection PyMethodMayBeStatic
 class Session(object):
-    def __init__(self, session_id, data_dir: str, max_workers=None, logger_config=None):
+    def __init__(
+        self,
+        session_id,
+        data_dir: str,
+        max_workers=None,
+        logger_config=None,
+        executor_pool_cls=BasicProcessPool,
+    ):
         self.session_id = session_id
         self._data_dir = data_dir
         self._max_workers = max_workers
         if self._max_workers is None:
             self._max_workers = os.cpu_count()
-        self._pool = Executor(
-            max_workers=max_workers,
-            initializer=_watch_thread_react_to_parent_die,
-            initargs=(
-                os.getpid(),
-                logger_config,
-            ),
-        )
-        self._enable_process_logger = True
+
+        self._enable_process_logger = True
+        if self._enable_process_logger:
+            log_level = logging.getLevelName(logger.getEffectiveLevel())
+        else:
+            log_level = None
+        self._pool = executor_pool_cls(
+            pool=Executor(
+                max_workers=max_workers,
+                initializer=_watch_thread_react_to_parent_die,
+                initargs=(
+                    os.getpid(),
+                    logger_config,
+                ),
+            ),
+            log_level=log_level,
+        )
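
Reviewer note: executor_pool_cls makes the worker pool an injection point; anything exposing submit(func, process_infos) that returns per-partition outputs in order, plus shutdown(), fits. A hypothetical in-process variant for debugging (not part of this PR):

    class InlinePool:
        def __init__(self, pool, log_level):
            self._pool = pool  # accepted only for signature compatibility

        def submit(self, func, process_infos):
            return [func(info) for info in process_infos]  # sequential, in the parent process

        def shutdown(self):
            self._pool.shutdown()

    # session = Session("demo-session", data_dir="/tmp/data", executor_pool_cls=InlinePool)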

     @property
     def data_dir(self):
@@ -502,7 +560,7 @@ def parallelize(
         self,
         data: Iterable,
         partition: int,
-        partitioner: Callable[[bytes], int],
+        partitioner: Callable[[bytes, int], int],
         key_serdes_type,
         value_serdes_type,
         partitioner_type,
@@ -543,17 +601,13 @@ def kill(self):
         self._pool.shutdown()
 
     def submit_reduce(self, func, data_dir: str, num_partitions: int, name: str, namespace: str):
-        futures = []
-        for p in range(num_partitions):
-            futures.append(
-                self._pool.submit(
-                    _do_reduce,
-                    _ReduceProcess(
-                        p, _TaskInputInfo(data_dir, namespace, name, num_partitions), _ReduceFunctorInfo(func)
-                    ),
-                )
-            )
-        rs = [r.result() for r in futures]
+        rs = self._pool.submit(
+            _do_reduce,
+            [
+                _ReduceProcess(p, _TaskInputInfo(data_dir, namespace, name, num_partitions), _ReduceFunctorInfo(func))
+                for p in range(num_partitions)
+            ],
+        )
         rs = [r for r in filter(partial(is_not, None), rs)]
         if len(rs) <= 0:
             return None
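
Reviewer aside: the None-filter kept below the new submit call presumably relies on functools.partial and operator.is_not imported at module top (not shown in this diff). Empty partitions reduce to None and are dropped before the cross-partition combine:

    from functools import partial
    from operator import is_not

    rs = [3, None, 7, None]  # per-partition reduce results
    rs = [r for r in filter(partial(is_not, None), rs)]
    assert rs == [3, 7]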
@@ -583,15 +637,15 @@ def _submit_map_reduce_partitions_with_index(
         )
         return self._submit_process(
             _do_func,
-            (
+            [
                 _MapReduceProcess(
                     partition_id=p,
                     input_info=input_info,
                     output_info=output_info,
                     operator_info=_MapReduceFunctorInfo(mapper=mapper, reducer=reducer),
                 )
                 for p in range(max(input_num_partitions, output_num_partitions))
-            ),
+            ],
         )

     def _submit_sorted_binary_map_partitions_with_index(
@@ -618,7 +672,7 @@ def _submit_sorted_binary_map_partitions_with_index(
         output_info = _TaskOutputInfo(output_data_dir, output_namespace, output_name, num_partitions, partitioner=None)
         return self._submit_process(
             do_func,
-            (
+            [
                 _BinarySortedMapProcess(
                     partition_id=p,
                     first_input_info=first_input_info,
@@ -627,125 +681,11 @@ def _submit_sorted_binary_map_partitions_with_index(
                     operator_info=_BinarySortedMapFunctorInfo(func),
                 )
                 for p in range(num_partitions)
-            ),
+            ],
         )
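
Reviewer note: the switch from a parenthesized generator to a list literal here and in _submit_map_reduce_partitions_with_index above is load-bearing, not cosmetic: BasicProcessPool.submit calls len(process_infos) and enumerates it directly, which a generator cannot satisfy, whereas the old RichProcessPool materialized it with list(process_infos) in its constructor.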

     def _submit_process(self, do_func, process_infos):
-        if self._enable_process_logger:
-            log_level = logging.getLevelName(logger.getEffectiveLevel())
-        else:
-            log_level = None
-        rich_process_pool = RichProcessPool(
-            self._pool,
-            process_infos,
-            log_level,
-        )
-        return rich_process_pool.submit(do_func)
-
-
-class RichProcessPool:
-    def __init__(self, pool, process_infos, log_level):
-        import rich.console
-
-        self._pool = pool
-        self._exception_tb = {}
-        self.process_infos = list(process_infos)
-        self.log_level = log_level
-        self.console = rich.console.Console()
-        self.width = self.console.width
-
-    def submit(self, func):
-        features = []
-        outputs = {}
-
-        with tracer.start_as_current_span("submit_process"):
-            carrier = trace.inject_carrier()
-            for p in range(len(self.process_infos)):
-                features.append(
-                    self._pool.submit(
-                        RichProcessPool._process_wrapper,
-                        carrier,
-                        func,
-                        self.process_infos[p],
-                        self.log_level,
-                        self.width,
-                    )
-                )
-
-            from concurrent.futures import wait, FIRST_COMPLETED
-
-            not_done = features
-            while not_done:
-                done, not_done = wait(not_done, return_when=FIRST_COMPLETED)
-                for f in done:
-                    partition_id, output, e, exc_traceback = f.result()
-                    if e is not None:
-                        import rich.panel
-
-                        self._exception_tb[partition_id] = exc_traceback
-                        self.console.print(
-                            rich.panel.Panel(
-                                exc_traceback,
-                                title=f"partition {partition_id} exception",
-                                expand=False,
-                                border_style="red",
-                            )
-                        )
-                        raise RuntimeError(f"Partition {partition_id} exec failed: {e}")
-                    else:
-                        outputs[partition_id] = output
-
-            outputs = [outputs[p] for p in range(len(self.process_infos))]
-            return outputs
-
-    @classmethod
-    def _process_wrapper(cls, carrier, do_func, process_info, log_level, width):
-        trace_context = trace.extract_carrier(carrier)
-        with tracer.start_as_current_span(f"partition:{process_info.partition_id}", context=trace_context):
-            if log_level is not None:
-                RichProcessPool._set_up_process_logger(process_info.partition_id, log_level)
-            try:
-                output = do_func(process_info)
-                return process_info.partition_id, output, None, None
-            except Exception as e:
-                import rich.traceback
-
-                logger.error(f"exception in rank {process_info.partition_id}: {e}", stack_info=False)
-                exc_traceback = rich.traceback.Traceback.from_exception(
-                    type(e), e, traceback=e.__traceback__, width=width, show_locals=True
-                )
-                return process_info.partition_id, None, e, exc_traceback
-
-    def show_exceptions(self):
-        import rich.panel
-
-        console = rich.console.Console()
-        for rank, tb in self._exception_tb.items():
-            console.print(rich.panel.Panel(tb, title=f"rank {rank} exception", expand=False, border_style="red"))
-
-    @classmethod
-    def _set_up_process_logger(cls, rank, log_level="DEBUG"):
-        message_header = f"[[bold green blink]Rank:{rank}[/]]"
-
-        logging.config.dictConfig(
-            dict(
-                version=1,
-                formatters={"with_rank": {"format": f"{message_header} %(message)s", "datefmt": "[%X]"}},
-                handlers={
-                    "base": {
-                        "class": "rich.logging.RichHandler",
-                        "level": log_level,
-                        "filters": [],
-                        "formatter": "with_rank",
-                        "tracebacks_show_locals": True,
-                        "markup": True,
-                    }
-                },
-                loggers={},
-                root=dict(handlers=["base"], level=log_level),
-                disable_existing_loggers=False,
-            )
-        )
+        return self._pool.submit(do_func, process_infos)
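
Reviewer note: with RichProcessPool gone, _submit_process no longer opens tracing spans or renders rich tracebacks; a failing partition now surfaces as a plain RuntimeError from BasicProcessPool.submit. Sketch, reusing the hypothetical DemoInfo from the earlier note:

    def boom(info):
        raise ValueError("broken partition")

    # pool.submit(boom, [DemoInfo(0, 1)])
    # logs "partition 0 exec failed: broken partition" and then raises
    # RuntimeError("Partition 0 exec failed: broken partition")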


 class Federation(object):
@@ -1101,7 +1041,7 @@ def _open_env(path, write=False):
             return env
         except lmdb.Error as e:
             if "No such file or directory" in e.args[0]:
-                time.sleep(0.01)
+                time.sleep(0.001)
                 t += 1
             else:
                 raise e
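
Reviewer aside: this drops the retry sleep from 10 ms to 1 ms while waiting for another process to create the LMDB environment. The pattern generalizes roughly as follows (a sketch; max_tries and the exception type are placeholders, the real loop keys on the lmdb.Error message as shown above):

    import time

    def open_with_retry(open_fn, max_tries=300, delay=0.001):
        for _ in range(max_tries):
            try:
                return open_fn()
            except FileNotFoundError:  # stand-in for the lmdb.Error check above
                time.sleep(delay)
        raise TimeoutError("environment directory never appeared")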
@@ -1231,7 +1171,7 @@ def __init__(self, data_dir: str, session_id, party: Tuple[str, str]) -> None:
     def wait_status_set(self, key: bytes) -> bytes:
         value = self.get_status(key)
         while value is None:
-            time.sleep(0.1)
+            time.sleep(0.001)
             value = self.get_status(key)
         return key
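
Reviewer aside: cutting the status poll from 100 ms to 1 ms trims up to ~99 ms of latency per wait, at the price of up to 1000 get_status calls per second while blocked. If that chatter ever matters, exponential backoff would be a possible middle ground (sketch, not in this PR):

    import time

    def wait_with_backoff(poll, initial=0.001, cap=0.1):
        delay = initial
        while (value := poll()) is None:
            time.sleep(delay)
            delay = min(delay * 2, cap)
        return value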
