Skip to content

Commit

Permalink
refactor underlying data for computing and federation backend
Browse files Browse the repository at this point in the history
- Allow complete control over data serialization and partitioning from the higher layers.
- Complete basic validation in standalone mode.

Signed-off-by: sagewe <wbwmat@gmail.com>
  • Loading branch information
sagewe committed Oct 12, 2023
1 parent 37a8a63 commit 20b5506
Show file tree
Hide file tree
Showing 9 changed files with 1,293 additions and 525 deletions.
960 changes: 540 additions & 420 deletions python/fate/arch/_standalone.py

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion python/fate/arch/computing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@

def is_table(v):
from fate.arch.abc import CTableABC
from fate.arch.computing.table import KVTable

return isinstance(v, CTableABC)
return isinstance(v, CTableABC) or isinstance(v, KVTable)


__all__ = ["is_table", "ComputingEngine", "profile_start", "profile_ends"]
35 changes: 28 additions & 7 deletions python/fate/arch/computing/standalone/_csession.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@
# limitations under the License.

import logging
from collections.abc import Iterable
from typing import Optional

from fate.arch.abc import CSessionABC
from ..table import KVTableContext

from ..._standalone import Session
from ...unify import URI, generate_computing_uuid, uuid
Expand All @@ -26,7 +25,7 @@
LOGGER = logging.getLogger(__name__)


class CSession(CSessionABC):
class CSession(KVTableContext):
def __init__(
self, session_id: Optional[str] = None, logger_config: Optional[dict] = None, options: Optional[dict] = None
):
Expand All @@ -44,7 +43,12 @@ def get_standalone_session(self):
def session_id(self):
return self._session.session_id

def load(self, uri: URI, schema: dict, options: dict = None):
def _load(
self,
uri: URI,
schema: dict,
options: dict,
):
if uri.scheme != "standalone":
raise ValueError(f"uri scheme `{uri.scheme}` not supported with standalone backend")
try:
Expand All @@ -57,15 +61,32 @@ def load(self, uri: URI, schema: dict, options: dict = None):
raw_table = raw_table.save_as(
name=f"{name}_{uuid()}",
namespace=namespace,
partition=partitions,
partitions=partitions,
need_cleanup=True,
)
table = Table(raw_table)
table.schema = schema
return table

def parallelize(self, data: Iterable, partition: int, include_key: bool, **kwargs):
table = self._session.parallelize(data=data, partition=partition, include_key=include_key, **kwargs)
def _parallelize(
self,
data,
total_partitions,
key_serdes,
key_serdes_type,
value_serdes,
value_serdes_type,
partitioner,
partitioner_type,
):
table = self._session.parallelize(
data=data,
partition=total_partitions,
partitioner=partitioner,
key_serdes_type=key_serdes_type,
value_serdes_type=value_serdes_type,
partitioner_type=partitioner_type,
)
return Table(table)

def cleanup(self, name, namespace):
Expand Down
190 changes: 97 additions & 93 deletions python/fate/arch/computing/standalone/_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import logging
import typing

from fate.arch.abc import CTableABC
from typing import Callable, Iterable, Any

from ...unify import URI
from .._profile import computing_profile
from .._type import ComputingEngine
from ..table import KVTable, V
from ..._standalone import Table as StandaloneTable

LOGGER = logging.getLogger(__name__)


class Table(CTableABC):
def __init__(self, table):
class Table(KVTable):
def __init__(self, table: StandaloneTable):
self._table = table
self._engine = ComputingEngine.STANDALONE

self._count = None
super().__init__(
key_serdes_type=table.key_serdes_type,
value_serdes_type=table.value_serdes_type,
partitioner_type=table.partitioner_type,
)

@property
def engine(self):
Expand All @@ -40,13 +43,93 @@ def engine(self):
def __getstate__(self):
pass

def __reduce__(self):
raise NotImplementedError("Table is not picklable, please don't do this or it may cause unexpected error")

def _map_reduce_partitions_with_index(
self,
map_partition_op: Callable[[int, Iterable], Iterable],
reduce_partition_op: Callable[[Any, Any], Any],
shuffle,
output_key_serdes,
output_key_serdes_type,
output_value_serdes,
output_value_serdes_type,
output_partitioner,
output_partitioner_type,
):
return Table(
table=self._table.map_reduce_partitions_with_index(
map_partition_op=map_partition_op,
reduce_partition_op=reduce_partition_op,
output_partitioner=output_partitioner,
shuffle=shuffle,
output_key_serdes_type=output_key_serdes_type,
output_value_serdes_type=output_value_serdes_type,
output_partitioner_type=output_partitioner_type,
),
)

def _collect(self, **kwargs):
return self._table.collect(**kwargs)

def _take(self, n=1, **kwargs):
return self._table.take(n=n, **kwargs)

def _count(self):
return self._table.count()

def _join(
self,
other: "Table",
merge_op: Callable[[V, V], V],
key_serdes,
key_serdes_type,
value_serdes,
value_serdes_type,
partitioner,
partitioner_type,
):
return Table(
table=self._table.join(other._table, merge_op=merge_op),
)

def _union(
self,
other: "Table",
merge_op: Callable[[V, V], V],
key_serdes,
key_serdes_type,
value_serdes,
value_serdes_type,
partitioner,
partitioner_type,
):
return Table(
table=self._table.union(other._table, merge_op=merge_op),
)

def _subtract_by_key(
self,
other: "Table",
key_serdes,
key_serdes_type,
value_serdes,
value_serdes_type,
partitioner,
partitioner_type,
):
return Table(
table=self._table.subtract_by_key(other._table),
)

def _reduce(self, func, **kwargs):
return self._table.reduce(func)

@property
def partitions(self):
return self._table.partitions

def copy(self):
return Table(self._table.mapValues(lambda x: x))

@computing_profile
def save(self, uri: URI, schema, options: dict = None):
if options is None:
Expand All @@ -61,75 +144,12 @@ def save(self, uri: URI, schema, options: dict = None):
self._table.save_as(
name=name,
namespace=namespace,
partition=options.get("partitions", self.partitions),
partitions=options.get("partitions", self.partitions),
need_cleanup=False,
)
# TODO: self.schema is a bit confusing here, it set by property assignment directly, not by constructor
schema.update(self.schema)

@computing_profile
def count(self) -> int:
if self._count is None:
self._count = self._table.count()
return self._count

@computing_profile
def collect(self, **kwargs):
return self._table.collect(**kwargs)

@computing_profile
def take(self, n=1, **kwargs):
return self._table.take(n=n, **kwargs)

@computing_profile
def first(self, **kwargs):
resp = list(itertools.islice(self._table.collect(**kwargs), 1))
if len(resp) < 1:
raise RuntimeError("table is empty")
return resp[0]

@computing_profile
def reduce(self, func, **kwargs):
return self._table.reduce(func)

@computing_profile
def map(self, func):
return Table(self._table.map(func))

@computing_profile
def mapValues(self, func):
return Table(self._table.mapValues(func))

@computing_profile
def flatMap(self, func):
return Table(self._table.flatMap(func))

@computing_profile
def applyPartitions(self, func):
return Table(self._table.applyPartitions(func))

@computing_profile
def mapPartitions(self, func, use_previous_behavior=True, preserves_partitioning=False):
if use_previous_behavior is True:
LOGGER.warning(
"please use `applyPartitions` instead of `mapPartitions` "
"if the previous behavior was expected. "
"The previous behavior will not work in future"
)
return Table(self._table.applyPartitions(func))
return Table(self._table.mapPartitions(func, preserves_partitioning=preserves_partitioning))

@computing_profile
def mapReducePartitions(self, mapper, reducer, **kwargs):
return Table(self._table.mapReducePartitions(mapper, reducer))

@computing_profile
def mapPartitionsWithIndex(self, func, preserves_partitioning=False, **kwargs):
return Table(self._table.mapPartitionsWithIndex(func, preserves_partitioning=preserves_partitioning))

@computing_profile
def glom(self):
return Table(self._table.glom())

@computing_profile
def sample(
self,
Expand All @@ -139,7 +159,7 @@ def sample(
seed=None,
):
if fraction is not None:
return Table(self._table.sample(fraction=fraction, seed=seed))
return Table(self._sample(fraction=fraction, seed=seed))

if num is not None:
total = self._table.count()
Expand All @@ -148,7 +168,7 @@ def sample(

frac = num / float(total)
while True:
sampled_table = self._table.sample(fraction=frac, seed=seed)
sampled_table = self._sample(fraction=frac, seed=seed)
sampled_count = sampled_table.count()
if sampled_count < num:
frac += 0.1
Expand All @@ -163,19 +183,3 @@ def sample(
return Table(sampled_table)

raise ValueError(f"exactly one of `fraction` or `num` required, fraction={fraction}, num={num}")

@computing_profile
def filter(self, func):
return Table(self._table.filter(func))

@computing_profile
def join(self, other: "Table", func):
return Table(self._table.join(other._table, func))

@computing_profile
def subtractByKey(self, other: "Table"):
return Table(self._table.subtractByKey(other._table))

@computing_profile
def union(self, other: "Table", func=lambda v1, v2: v1):
return Table(self._table.union(other._table, func))
Loading

0 comments on commit 20b5506

Please sign in to comment.