Skip to content

Commit

Permalink
fix bugs
Browse files Browse the repository at this point in the history
Signed-off-by: sagewe <wbwmat@gmail.com>
  • Loading branch information
sagewe committed Oct 19, 2023
1 parent 31e25eb commit 1439468
Show file tree
Hide file tree
Showing 14 changed files with 190 additions and 219 deletions.
23 changes: 11 additions & 12 deletions python/fate/arch/_standalone.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,10 +201,10 @@ def destroy(self):
txn.drop(db)
_TableMetaManager.destroy_table(self._namespace, self._name)

def take(self, n, **kwargs):
if n <= 0:
raise ValueError(f"{n} <= 0")
return list(itertools.islice(self.collect(**kwargs), n))
def take(self, num, **kwargs):
    """Return the first ``num`` items produced by ``self.collect(**kwargs)``.

    Args:
        num: maximum number of items to return; must be greater than zero.
        **kwargs: forwarded unchanged to ``self.collect``.

    Returns:
        A list with at most ``num`` items, in the order ``collect`` yields
        them. The list is shorter when ``collect`` yields fewer items.

    Raises:
        ValueError: if ``num`` is less than or equal to zero.
    """
    if num <= 0:
        raise ValueError(f"{num} <= 0")
    # islice stops consuming collect() once ``num`` items have been read.
    return list(itertools.islice(self.collect(**kwargs), num))

def count(self):
cnt = 0
Expand Down Expand Up @@ -320,7 +320,7 @@ def map_reduce_partitions_with_index(
session=self._session,
name=result.name,
namespace=result.namespace,
partitions=self.num_partitions,
partitions=output_num_partitions,
need_cleanup=need_cleanup,
key_serdes_type=output_key_serdes_type,
value_serdes_type=output_value_serdes_type,
Expand All @@ -329,7 +329,7 @@ def map_reduce_partitions_with_index(

if reduce_partition_op is None:
# noinspection PyProtectedMember
results = self._session._submit_map_reduce_partitions_with_index(
self._session._submit_map_reduce_partitions_with_index(
_do_mrwi_shuffle_no_reduce,
map_partition_op,
reduce_partition_op,
Expand All @@ -341,13 +341,12 @@ def map_reduce_partitions_with_index(
output_namespace=output_namespace,
output_partitioner=output_partitioner,
)
result = results[0]
# noinspection PyProtectedMember
return _create_table(
session=self._session,
name=result.name,
namespace=result.namespace,
partitions=self.num_partitions,
name=output_name,
namespace=output_namespace,
partitions=output_num_partitions,
need_cleanup=need_cleanup,
key_serdes_type=output_key_serdes_type,
value_serdes_type=output_value_serdes_type,
Expand All @@ -365,7 +364,7 @@ def map_reduce_partitions_with_index(
input_num_partitions=self.num_partitions,
input_name=self._name,
input_namespace=self._namespace,
output_num_partitions=self.num_partitions,
output_num_partitions=output_num_partitions,
output_name=intermediate_name,
output_namespace=intermediate_namespace,
output_partitioner=output_partitioner,
Expand All @@ -387,7 +386,7 @@ def map_reduce_partitions_with_index(
session=self._session,
name=output_name,
namespace=output_namespace,
partitions=self.num_partitions,
partitions=output_num_partitions,
need_cleanup=need_cleanup,
key_serdes_type=output_key_serdes_type,
value_serdes_type=output_value_serdes_type,
Expand Down
2 changes: 1 addition & 1 deletion python/fate/arch/computing/_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def _pretty_table_str(v):
from ..computing import is_table

if is_table(v):
return f"Table(partition={v.partitions})"
return f"Table(partition={v.num_partitions})"
else:
return f"{type(v).__name__}"

Expand Down
7 changes: 1 addition & 6 deletions python/fate/arch/computing/eggroll/_csession.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,7 @@ def _parallelize(
key_serdes_type=key_serdes_type,
value_serdes_type=value_serdes_type,
)
return Table(
rp,
key_serdes_type=key_serdes_type,
value_serdes_type=value_serdes_type,
partitioner_type=partitioner_type,
)
return Table(rp)

def cleanup(self, name, namespace):
self._rpc.cleanup(name=name, namespace=namespace)
Expand Down
100 changes: 24 additions & 76 deletions python/fate/arch/computing/eggroll/_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,20 @@


class Table(KVTable):
def destroy(self):
self._rp.destroy()
def __init__(self, rp: RollPair):
    """Wrap ``rp`` and initialise the base table from the RollPair's store.

    Stores ``rp`` on ``self._rp`` and ``ComputingEngine.EGGROLL`` on
    ``self._engine``, then passes the store's ``key_serdes_type``,
    ``value_serdes_type`` and ``partitioner_type`` together with
    ``rp.get_partitions()`` to the base-class initialiser.

    Args:
        rp: the RollPair this table wraps.
    """
    self._rp = rp
    self._engine = ComputingEngine.EGGROLL

    # Fetch the store once rather than once per attribute.
    store = rp.get_store()
    super().__init__(
        key_serdes_type=store.key_serdes_type,
        value_serdes_type=store.value_serdes_type,
        partitioner_type=store.partitioner_type,
        num_partitions=rp.get_partitions(),
    )

@property
def engine(self):
    """Return the computing engine stored in ``self._engine``."""
    return self._engine

def _map_reduce_partitions_with_index(
self,
Expand Down Expand Up @@ -69,12 +81,7 @@ def _map_reduce_partitions_with_index(
output_partitioner_type=output_partitioner_type,
output_num_partitions=output_num_partitions,
)
return Table(
rp,
key_serdes_type=output_key_serdes_type,
value_serdes_type=output_value_serdes_type,
partitioner_type=output_partitioner_type,
)
return Table(rp)

def _binary_sorted_map_partitions_with_index(
self,
Expand Down Expand Up @@ -105,12 +112,7 @@ def _binary_sorted_map_partitions_with_index(
output_value_serdes=output_value_serdes,
output_value_serdes_type=output_value_serdes_type,
)
return Table(
rp,
key_serdes_type=key_serdes_type,
value_serdes_type=output_value_serdes_type,
partitioner_type=partitioner_type,
)
return Table(rp)

def _take(self, n=1, **kwargs):
    """Forward to ``self._rp.take`` and return its result unchanged.

    Args:
        n: passed to ``self._rp.take`` as the keyword ``n`` (default 1).
        **kwargs: forwarded unchanged to ``self._rp.take``.
    """
    # NOTE(review): ``_drop_num`` in this file calls ``self._rp.take(num=...)``
    # while this method passes ``n=...`` -- confirm which keyword the
    # RollPair ``take`` API actually accepts.
    return self._rp.take(n=n, **kwargs)
Expand All @@ -124,30 +126,7 @@ def _collect(self):
def _reduce(self, func: Callable[[bytes, bytes], bytes]):
    """Forward ``func`` to ``self._rp.reduce`` and return its result.

    Args:
        func: binary function over ``bytes`` values, passed to the
            underlying RollPair as the keyword ``func``.
    """
    return self._rp.reduce(func=func)

def __init__(self, rp: RollPair, key_serdes_type, value_serdes_type, partitioner_type):
self._rp = rp
self._engine = ComputingEngine.EGGROLL

super().__init__(
key_serdes_type=key_serdes_type,
value_serdes_type=value_serdes_type,
partitioner_type=partitioner_type,
num_partitions=rp.get_partitions(),
)

@property
def engine(self):
return self._engine

@property
def partitions(self):
return self._rp.get_partitions()

@computing_profile
def save(self, uri: URI, schema: dict, options: dict = None):
if options is None:
options = {}

def _save(self, uri: URI, schema: dict, options: dict):
from ._type import EggRollStoreType

if uri.scheme != "eggroll":
Expand All @@ -159,47 +138,16 @@ def save(self, uri: URI, schema: dict, options: dict = None):

if "store_type" not in options:
options["store_type"] = EggRollStoreType.ROLLPAIR_LMDB

partitions = options.get("partitions", self.partitions)
self._rp.save_as(
name=name,
namespace=namespace,
partition=partitions,
options=options,
)
schema.update(self.schema)
return

@computing_profile
def sample(
self,
*,
fraction: typing.Optional[float] = None,
num: typing.Optional[int] = None,
seed=None,
):
if fraction is not None:
return Table(self._rp.sample(fraction=fraction, seed=seed))

if num is not None:
total = self._rp.count()
if num > total:
raise ValueError(f"not enough data to sample, own {total} but required {num}")

frac = num / float(total)
while True:
sampled_table = self._rp.sample(fraction=frac, seed=seed)
sampled_count = sampled_table.count()
if sampled_count < num:
frac *= 1.1
else:
break

if sampled_count > num:
drops = sampled_table.take(sampled_count - num)
for k, v in drops:
sampled_table.delete(k)

return Table(sampled_table)

raise ValueError(f"exactly one of `fraction` or `num` required, fraction={fraction}, num={num}")
def _drop_num(self, num: int, partitioner):
    """Delete the entries returned by ``self._rp.take(num=num)``.

    Args:
        num: number of entries requested from ``self._rp.take``.
        partitioner: forwarded to every ``self._rp.delete`` call.

    Returns:
        This table itself, after the deletions were issued on ``self._rp``.
    """
    # NOTE(review): ``_take`` in this file forwards ``n=...`` to
    # ``self._rp.take`` while this call uses ``num=...`` -- confirm which
    # keyword the RollPair ``take`` API actually accepts.
    for key, _ in self._rp.take(num=num):
        self._rp.delete(key, partitioner=partitioner)
    return self

def _destroy(self):
    """Call ``destroy()`` on the wrapped RollPair (``self._rp``)."""
    self._rp.destroy()
55 changes: 8 additions & 47 deletions python/fate/arch/computing/standalone/_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# limitations under the License.

import logging
import typing
from typing import Callable, Iterable, Any, Tuple

from ...unify import URI
Expand All @@ -41,19 +40,17 @@ def __init__(self, table: StandaloneTable):
def table(self):
return self._table

@property
def partitions(self):
return self._table.partitions

@property
def engine(self):
    """Return the computing engine stored in ``self._engine``."""
    return self._engine

def __getstate__(self):
def _destroy(self):
pass

def __reduce__(self):
raise NotImplementedError("Table is not picklable, please don't do this or it may cause unexpected error")
def _drop_num(self, num: int, partitioner):
    """Delete the entries returned by ``self._table.take(num=num)``.

    Args:
        num: number of entries requested from ``self._table.take``.
        partitioner: forwarded to every ``self._table.delete`` call.

    Returns:
        A new ``Table`` wrapping the same underlying ``self._table``.
    """
    for key, _ in self._table.take(num=num):
        self._table.delete(key, partitioner=partitioner)
    return Table(table=self._table)

def _map_reduce_partitions_with_index(
self,
Expand Down Expand Up @@ -115,8 +112,8 @@ def _binary_sorted_map_partitions_with_index(
def _collect(self, **kwargs):
    """Forward to ``self._table.collect(**kwargs)`` and return its result."""
    return self._table.collect(**kwargs)

def _take(self, n=1, **kwargs):
return self._table.take(n=n, **kwargs)
def _take(self, num=1, **kwargs):
    """Forward to ``self._table.take`` and return its result unchanged.

    Args:
        num: passed to ``self._table.take`` as the keyword ``num``
            (default 1).
        **kwargs: forwarded unchanged to ``self._table.take``.
    """
    return self._table.take(num=num, **kwargs)

def _count(self):
    """Return whatever ``self._table.count()`` returns."""
    return self._table.count()
Expand All @@ -125,7 +122,7 @@ def _reduce(self, func, **kwargs):
return self._table.reduce(func)

@computing_profile
def _save(self, uri: URI, schema, options: dict = None):
def _save(self, uri: URI, schema, options: dict):
if uri.scheme != "standalone":
raise ValueError(f"uri scheme `{uri.scheme}` not supported with standalone backend")
try:
Expand All @@ -137,39 +134,3 @@ def _save(self, uri: URI, schema, options: dict = None):
namespace=namespace,
need_cleanup=False,
)
# TODO: self.schema is a bit confusing here, it set by property assignment directly, not by constructor
schema.update(self.schema)

@computing_profile
def sample(
self,
*,
fraction: typing.Optional[float] = None,
num: typing.Optional[int] = None,
seed=None,
):
if fraction is not None:
return Table(self._sample(fraction=fraction, seed=seed))

if num is not None:
total = self._table.count()
if num > total:
raise ValueError(f"not enough data to sample, own {total} but required {num}")

frac = num / float(total)
while True:
sampled_table = self._sample(fraction=frac, seed=seed)
sampled_count = sampled_table.count()
if sampled_count < num:
frac += 0.1
else:
break

if sampled_count > num:
drops = sampled_table.take(sampled_count - num)
for k, v in drops:
sampled_table.delete(k)

return Table(sampled_table)

raise ValueError(f"exactly one of `fraction` or `num` required, fraction={fraction}, num={num}")
Loading

0 comments on commit 1439468

Please sign in to comment.