
Commit

Merge pull request #5027 from FederatedAI/feature-2.0.0-beta-dataframe_update

dataloader: ensure order
mgqa34 authored Aug 1, 2023
2 parents 10fec3c + c2ba4c3 commit be96985
Showing 2 changed files with 20 additions and 10 deletions.
14 changes: 7 additions & 7 deletions python/fate/arch/dataframe/utils/_dataloader.py
@@ -27,7 +27,7 @@ def __init__(
batch_size=-1,
shuffle=False,
batch_strategy="full",
random_seed=None,
random_state=None,
):
self._ctx = ctx
self._dataset = dataset
@@ -39,7 +39,7 @@ def __init__(
self._batch_size = min(batch_size, len(dataset))
self._shuffle = shuffle
self._batch_strategy = batch_strategy
self._random_seed = random_seed
self._random_state = random_state
self._need_align = need_align
self._mode = mode
self._role = role
@@ -56,7 +56,7 @@ def _init_settings(self):
role=self._role,
batch_size=self._batch_size,
shuffle=self._shuffle,
random_seed=self._random_seed,
random_state=self._random_state,
need_align=self._need_align,
sync_arbiter=self._sync_arbiter,
)
@@ -77,7 +77,7 @@ def __iter__(self):


class FullBatchDataLoader(object):
def __init__(self, dataset, ctx, mode, role, batch_size, shuffle, random_seed, need_align, sync_arbiter):
def __init__(self, dataset, ctx, mode, role, batch_size, shuffle, random_state, need_align, sync_arbiter):
self._dataset = dataset
self._ctx = ctx
self._mode = mode
@@ -86,7 +86,7 @@ def __init__(self, dataset, ctx, mode, role, batch_size, shuffle, random_seed, n
if self._batch_size is None and self._role != "arbiter":
self._batch_size = len(self._dataset)
self._shuffle = shuffle
self._random_seed = random_seed
self._random_state = random_state
self._need_align = need_align
self._sync_arbiter = sync_arbiter

@@ -121,9 +121,9 @@ def _prepare(self):
self._batch_splits.append(BatchEncoding(self._dataset, batch_id=0))
else:
if self._mode in ["homo", "local"] or self._role == "guest":
indexer = list(self._dataset.get_indexer(target="sample_id").collect())
indexer = sorted(list(self._dataset.get_indexer(target="sample_id").collect()))
if self._shuffle:
random.seed = self._random_seed
random.seed = self._random_state
random.shuffle(indexer)

for i, iter_ctx in self._ctx.sub_ctx("dataloader_batch").ctxs_range(self._batch_num):
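The dataloader change above sorts the collected indexer before any shuffle, so batch membership and order no longer depend on the order in which the distributed collect() happens to return items. A minimal standalone sketch of the idea, using illustrative names and a locally seeded random.Random rather than FATE's APIs:

import random

def make_batches(sample_ids, batch_size, shuffle=False, random_state=None):
    # Sort first: a distributed collect() may return items in any order,
    # so the canonical sort is what makes the result reproducible.
    ordered = sorted(sample_ids)
    if shuffle:
        # A locally seeded Random keeps the shuffle deterministic without
        # touching the global random module state.
        random.Random(random_state).shuffle(ordered)
    return [ordered[i:i + batch_size] for i in range(0, len(ordered), batch_size)]

# Same ids + same random_state => identical batch splits on every run.
print(make_batches(["id3", "id1", "id4", "id2"], batch_size=2, shuffle=True, random_state=42))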
16 changes: 13 additions & 3 deletions python/fate/ml/intersection/raw_intersection.py
@@ -28,15 +28,19 @@ def __init__(self):
def fit(self, ctx: Context, train_data, validate_data=None):
# ctx.hosts.put("raw_index", train_data.index.tolist())
ctx.hosts.put("raw_index", train_data.get_indexer(target="sample_id"))
intersect_indexes = ctx.hosts.get("intersect_index")
intersect_indexes = ctx.hosts.get("host_intersect_index")
intersect_data = train_data
for intersect_index in intersect_indexes:
intersect_data = intersect_data.loc(intersect_index, preserve_order=True)
intersect_data = intersect_data.loc(intersect_index)

ctx.hosts.put("final_intersect_index", intersect_data.get_indexer(target="sample_id"))

intersect_count = intersect_data.count()
ctx.hosts.put("intersect_count", intersect_count)

logger.info(f"intersect count={intersect_count}")
data = sorted(intersect_data.block_table.collect())
logger.info(f"mgq-debug, data={data}")
return intersect_data


@@ -48,8 +52,14 @@ def fit(self, ctx: Context, train_data, validate_data=None):
guest_index = ctx.guest.get("raw_index")
intersect_data = train_data.loc(guest_index)
# ctx.guest.put("intersect_index", intersect_data.index.tolist())
ctx.guest.put("intersect_index", intersect_data.get_indexer(target="sample_id"))
ctx.guest.put("host_intersect_index", intersect_data.get_indexer(target="sample_id"))

final_intersect_index = ctx.guest.get("final_intersect_index")
intersect_data = intersect_data.loc(final_intersect_index, preserve_order=True)

intersect_count = ctx.guest.get("intersect_count")

logger.info(f"intersect count={intersect_count}")
data = sorted(intersect_data.block_table.collect())
logger.info(f"mgq-debug, data={data}")
return intersect_data
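Read together, the two fit methods in this file form a three-step handshake: the side that talks to ctx.hosts (the guest) broadcasts its sample-id index as "raw_index", each host answers with the ids it also holds as "host_intersect_index", and the guest then broadcasts the final intersection as "final_intersect_index", which hosts apply with preserve_order=True so their rows line up with the guest's. A rough plain-Python sketch of that flow, with lists and sets standing in for FATE dataframes and a dict standing in for the context channel (illustrative only, not FATE APIs):

guest_ids = ["id4", "id2", "id3", "id1"]   # guest sample ids, in guest order
host_ids = {"id2", "id3", "id5"}           # one host's sample ids
channel = {}                               # stands in for ctx put/get

# Guest: ctx.hosts.put("raw_index", ...)
channel["raw_index"] = guest_ids

# Host: keep the guest ids it also holds and reply
# (ctx.guest.put("host_intersect_index", ...)).
channel["host_intersect_index"] = [i for i in channel["raw_index"] if i in host_ids]

# Guest: apply every host's reply, then publish the final, ordered index
# (ctx.hosts.put("final_intersect_index", ...)).
final_index = [i for i in guest_ids if i in set(channel["host_intersect_index"])]
channel["final_intersect_index"] = final_index

# Host: filter by the final index, preserving the guest's order
# (intersect_data.loc(final_intersect_index, preserve_order=True)).
host_rows = [i for i in channel["final_intersect_index"] if i in host_ids]

assert host_rows == final_index == ["id2", "id3"]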
