refactor standalone: simplify requirements
Signed-off-by: sagewe <wbwmat@gmail.com>
sagewe committed Dec 20, 2023
1 parent df87cb0 commit 279f7dc
Showing 2 changed files with 90 additions and 164 deletions.
14 changes: 0 additions & 14 deletions python/fate/arch/computing/backends/standalone/_csession.py
@@ -36,20 +36,6 @@ def __init__(
session_id = generate_computing_uuid()
if data_dir is None:
raise ValueError("data_dir is None")
# data_dir = os.environ.get(
# "STANDALONE_DATA_PATH",
# os.path.abspath(
# os.path.join(
# os.path.dirname(__file__),
# os.path.pardir,
# os.path.pardir,
# os.path.pardir,
# os.path.pardir,
# os.path.pardir,
# "data",
# )
# ),
# )
if options is None:
options = {}
max_workers = options.get("task_cores", None)
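With the commented-out fallback gone, data_dir is a hard requirement: the session no longer derives a default path, and callers must resolve the data directory themselves before constructing the session. A minimal caller sketch under that contract (the CSession name, keyword names, and the uuid-on-None behavior are assumptions read off the surrounding hunk, not spelled out in it):

# Hypothetical caller: data_dir must now be supplied explicitly;
# no STANDALONE_DATA_PATH fallback remains inside the session.
session = CSession(
    session_id=None,                       # assumed: a computing uuid is generated when omitted
    data_dir="/tmp/fate_standalone_data",  # must be non-None, else ValueError is raised
    options={"task_cores": 4},             # read as max_workers for the worker pool
)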
240 changes: 90 additions & 150 deletions python/fate/arch/computing/backends/standalone/_standalone.py
@@ -36,12 +36,9 @@
import cloudpickle as f_pickle
import lmdb

from fate.arch import trace

PartyMeta = Tuple[Literal["guest", "host", "arbiter", "local"], str]

logger = logging.getLogger(__name__)
tracer = trace.get_tracer(__name__)


def _watch_thread_react_to_parent_die(ppid, logger_config):
@@ -69,14 +66,63 @@ def f():
thread = threading.Thread(target=f, daemon=True)
thread.start()

# initialize tracer
trace.setup_tracing("standalone_computing")

# initialize loggers
if logger_config is not None:
logging.config.dictConfig(logger_config)


class BasicProcessPool:
def __init__(self, pool, log_level):
self._pool = pool
self._exception_tb = {}
self.log_level = log_level

def submit(self, func, process_infos):
features = []
outputs = {}
num_partitions = len(process_infos)

for p, process_info in enumerate(process_infos):
features.append(
self._pool.submit(
BasicProcessPool._process_wrapper,
func,
process_info,
self.log_level,
)
)

from concurrent.futures import wait, FIRST_COMPLETED

not_done = features
while not_done:
done, not_done = wait(not_done, return_when=FIRST_COMPLETED)
for f in done:
partition_id, output, e = f.result()
if e is not None:
logger.error(f"partition {partition_id} exec failed: {e}")
raise RuntimeError(f"Partition {partition_id} exec failed: {e}")
else:
outputs[partition_id] = output

outputs = [outputs[p] for p in range(num_partitions)]
return outputs

@classmethod
def _process_wrapper(cls, do_func, process_info, log_level):
try:
if log_level is not None:
pass
output = do_func(process_info)
return process_info.partition_id, output, None
except Exception as e:
logger.error(f"exception in rank {process_info.partition_id}: {e}")
return process_info.partition_id, None, e

def shutdown(self):
self._pool.shutdown()
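BasicProcessPool is the slimmed-down replacement for the RichProcessPool removed further down: the same fan-out-and-collect contract, minus the rich console panels and tracing spans. A self-contained usage sketch (the _FakeProcessInfo class and _square function are illustrative stand-ins, not part of the commit):

from concurrent.futures import ProcessPoolExecutor
from dataclasses import dataclass

@dataclass
class _FakeProcessInfo:      # hypothetical stand-in for the real per-partition info objects
    partition_id: int
    value: int

def _square(info):           # the per-partition callable a Session would submit
    return info.value * info.value

# Wrap in a __main__ guard when running as a script on spawn-based platforms.
pool = BasicProcessPool(pool=ProcessPoolExecutor(max_workers=2), log_level=None)
print(pool.submit(_square, [_FakeProcessInfo(0, 3), _FakeProcessInfo(1, 4)]))  # [9, 16]
pool.shutdown()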


# noinspection PyPep8Naming
class Table(object):
def __init__(
@@ -187,7 +233,6 @@ def collect(self, **kwargs):
else:
_, _, _, it = heappop(entries)

@trace.auto_trace
def reduce(self, func):
return self._session.submit_reduce(
func,
@@ -197,7 +242,6 @@ def reduce(self, func):
namespace=self._namespace,
)

@trace.auto_trace
def binary_sorted_map_partitions_with_index(
self,
other: "Table",
@@ -243,7 +287,6 @@ def binary_sorted_map_partitions_with_index(
partitioner_type=partitioner_type,
)

@trace.auto_trace
def map_reduce_partitions_with_index(
self,
map_partition_op: Callable[[int, Iterable], Iterable],
@@ -442,21 +485,36 @@ def delete(self, k_bytes: bytes, partitioner: Callable[[bytes, int], int]):

# noinspection PyMethodMayBeStatic
class Session(object):
def __init__(self, session_id, data_dir: str, max_workers=None, logger_config=None):
def __init__(
self,
session_id,
data_dir: str,
max_workers=None,
logger_config=None,
executor_pool_cls=BasicProcessPool,
):
self.session_id = session_id
self._data_dir = data_dir
self._max_workers = max_workers
if self._max_workers is None:
self._max_workers = os.cpu_count()
self._pool = Executor(
max_workers=max_workers,
initializer=_watch_thread_react_to_parent_die,
initargs=(
os.getpid(),
logger_config,

self._enable_process_logger = True
if self._enable_process_logger:
log_level = logging.getLevelName(logger.getEffectiveLevel())
else:
log_level = None
self._pool = executor_pool_cls(
pool=Executor(
max_workers=max_workers,
initializer=_watch_thread_react_to_parent_die,
initargs=(
os.getpid(),
logger_config,
),
),
log_level=log_level,
)
self._enable_process_logger = True
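The executor_pool_cls hook added here means anything exposing the same __init__(pool, log_level) / submit(func, process_infos) / shutdown() surface can stand in for BasicProcessPool. An illustrative sketch of an inline pool for debugging (InlinePool and the session arguments are hypothetical, not part of this commit):

class InlinePool:
    # Hypothetical drop-in: runs every partition in the calling process.
    def __init__(self, pool, log_level):
        self._pool = pool    # wrapped Executor, kept only so shutdown() still works

    def submit(self, func, process_infos):
        return [func(info) for info in process_infos]   # no fan-out; easy to step through

    def shutdown(self):
        self._pool.shutdown()

session = Session("debug-session", data_dir="/tmp/fate_data", executor_pool_cls=InlinePool)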

@property
def data_dir(self):
@@ -502,7 +560,7 @@ def parallelize(
self,
data: Iterable,
partition: int,
partitioner: Callable[[bytes], int],
partitioner: Callable[[bytes, int], int],
key_serdes_type,
value_serdes_type,
partitioner_type,
@@ -543,17 +601,13 @@ def kill(self):
self._pool.shutdown()

def submit_reduce(self, func, data_dir: str, num_partitions: int, name: str, namespace: str):
futures = []
for p in range(num_partitions):
futures.append(
self._pool.submit(
_do_reduce,
_ReduceProcess(
p, _TaskInputInfo(data_dir, namespace, name, num_partitions), _ReduceFunctorInfo(func)
),
)
)
rs = [r.result() for r in futures]
rs = self._pool.submit(
_do_reduce,
[
_ReduceProcess(p, _TaskInputInfo(data_dir, namespace, name, num_partitions), _ReduceFunctorInfo(func))
for p in range(num_partitions)
],
)
rs = [r for r in filter(partial(is_not, None), rs)]
if len(rs) <= 0:
return None
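submit_reduce now hands the whole batch of _ReduceProcess objects to the pool in a single submit call; empty partitions come back as None and are filtered out before the surviving partials are folded across partitions (the fold itself presumably sits just below this hunk). A toy sketch of that two-stage shape, with illustrative values:

from functools import reduce

partials = [5, None, 7]   # per-partition reduce outputs; None marks an empty partition
partials = [p for p in partials if p is not None]
total = reduce(lambda a, b: a + b, partials)   # 12: the cross-partition fold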
@@ -583,15 +637,15 @@ def _submit_map_reduce_partitions_with_index(
)
return self._submit_process(
_do_func,
(
[
_MapReduceProcess(
partition_id=p,
input_info=input_info,
output_info=output_info,
operator_info=_MapReduceFunctorInfo(mapper=mapper, reducer=reducer),
)
for p in range(max(input_num_partitions, output_num_partitions))
),
],
)

def _submit_sorted_binary_map_partitions_with_index(
Expand All @@ -618,7 +672,7 @@ def _submit_sorted_binary_map_partitions_with_index(
output_info = _TaskOutputInfo(output_data_dir, output_namespace, output_name, num_partitions, partitioner=None)
return self._submit_process(
do_func,
(
[
_BinarySortedMapProcess(
partition_id=p,
first_input_info=first_input_info,
@@ -627,125 +681,11 @@
operator_info=_BinarySortedMapFunctorInfo(func),
operator_info=_BinarySortedMapFunctorInfo(func),
)
for p in range(num_partitions)
),
],
)
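The parenthesis-to-bracket edits in these two call sites are not cosmetic: they turn lazy generator expressions into materialized lists, because BasicProcessPool.submit calls len(process_infos) up front and a generator has no length. In isolation:

# gen = (i * i for i in range(3))
# len(gen)                       # TypeError: object of type 'generator' has no len()
squares = [i * i for i in range(3)]
print(len(squares))              # 3 -- why these call sites now build lists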

def _submit_process(self, do_func, process_infos):
if self._enable_process_logger:
log_level = logging.getLevelName(logger.getEffectiveLevel())
else:
log_level = None
rich_process_pool = RichProcessPool(
self._pool,
process_infos,
log_level,
)
return rich_process_pool.submit(do_func)


class RichProcessPool:
def __init__(self, pool, process_infos, log_level):
import rich.console

self._pool = pool
self._exception_tb = {}
self.process_infos = list(process_infos)
self.log_level = log_level
self.console = rich.console.Console()
self.width = self.console.width

def submit(self, func):
features = []
outputs = {}

with tracer.start_as_current_span("submit_process"):
carrier = trace.inject_carrier()
for p in range(len(self.process_infos)):
features.append(
self._pool.submit(
RichProcessPool._process_wrapper,
carrier,
func,
self.process_infos[p],
self.log_level,
self.width,
)
)

from concurrent.futures import wait, FIRST_COMPLETED

not_done = features
while not_done:
done, not_done = wait(not_done, return_when=FIRST_COMPLETED)
for f in done:
partition_id, output, e, exc_traceback = f.result()
if e is not None:
import rich.panel

self._exception_tb[partition_id] = exc_traceback
self.console.print(
rich.panel.Panel(
exc_traceback,
title=f"partition {partition_id} exception",
expand=False,
border_style="red",
)
)
raise RuntimeError(f"Partition {partition_id} exec failed: {e}")
else:
outputs[partition_id] = output

outputs = [outputs[p] for p in range(len(self.process_infos))]
return outputs

@classmethod
def _process_wrapper(cls, carrier, do_func, process_info, log_level, width):
trace_context = trace.extract_carrier(carrier)
with tracer.start_as_current_span(f"partition:{process_info.partition_id}", context=trace_context):
if log_level is not None:
RichProcessPool._set_up_process_logger(process_info.partition_id, log_level)
try:
output = do_func(process_info)
return process_info.partition_id, output, None, None
except Exception as e:
import rich.traceback

logger.error(f"exception in rank {process_info.partition_id}: {e}", stack_info=False)
exc_traceback = rich.traceback.Traceback.from_exception(
type(e), e, traceback=e.__traceback__, width=width, show_locals=True
)
return process_info.partition_id, None, e, exc_traceback

def show_exceptions(self):
import rich.panel

console = rich.console.Console()
for rank, tb in self._exception_tb.items():
console.print(rich.panel.Panel(tb, title=f"rank {rank} exception", expand=False, border_style="red"))

@classmethod
def _set_up_process_logger(cls, rank, log_level="DEBUG"):
message_header = f"[[bold green blink]Rank:{rank}[/]]"

logging.config.dictConfig(
dict(
version=1,
formatters={"with_rank": {"format": f"{message_header} %(message)s", "datefmt": "[%X]"}},
handlers={
"base": {
"class": "rich.logging.RichHandler",
"level": log_level,
"filters": [],
"formatter": "with_rank",
"tracebacks_show_locals": True,
"markup": True,
}
},
loggers={},
root=dict(handlers=["base"], level=log_level),
disable_existing_loggers=False,
)
)
return self._pool.submit(do_func, process_infos)


class Federation(object):
@@ -1101,7 +1041,7 @@ def _open_env(path, write=False):
return env
except lmdb.Error as e:
if "No such file or directory" in e.args[0]:
time.sleep(0.01)
time.sleep(0.001)
t += 1
else:
raise e
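_open_env's retry now sleeps 1 ms instead of 10 ms between attempts when lmdb reports a missing directory, picking up a freshly created environment sooner at the cost of a few more polls. The pattern in isolation (a sketch; the attempt bound and wrapper name are assumptions — only the error check and delay mirror the code above):

import time
import lmdb

def open_with_retry(open_fn, attempts=300, delay=0.001):
    for _ in range(attempts):
        try:
            return open_fn()   # e.g. a partial over lmdb.open(path, ...)
        except lmdb.Error as e:
            if "No such file or directory" not in e.args[0]:
                raise
            time.sleep(delay)  # 0.001s after this commit, was 0.01s
    raise RuntimeError("lmdb environment never appeared")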
@@ -1231,7 +1171,7 @@ def __init__(self, data_dir: str, session_id, party: Tuple[str, str]) -> None:
def wait_status_set(self, key: bytes) -> bytes:
value = self.get_status(key)
while value is None:
time.sleep(0.1)
time.sleep(0.001)
value = self.get_status(key)
return key

