Skip to content

Commit

Permalink
dataframe: support kfold splits
Browse files Browse the repository at this point in the history
Signed-off-by: mgqa34 <mgq3374541@163.com>
  • Loading branch information
mgqa34 committed Jul 11, 2023
1 parent e356046 commit fa70dd1
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 0 deletions.
3 changes: 3 additions & 0 deletions python/fate/arch/dataframe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
)
from .io import build_schema, deserialize, parse_schema, serialize
from .utils import DataLoader
from .utils import KFold

__all__ = [
"PandasReader",
Expand All @@ -34,4 +35,6 @@
"serialize",
"deserialize",
"DataFrame",
"KFold",
"DataLoader",
]
1 change: 1 addition & 0 deletions python/fate/arch/dataframe/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ._dataloader import DataLoader
from ._k_fold import KFold
86 changes: 86 additions & 0 deletions python/fate/arch/dataframe/utils/_k_fold.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from .._dataframe import DataFrame
from sklearn.model_selection import KFold as sk_KFold


class KFold(object):
    """K-fold splitter for a ``DataFrame``, with a federated ("hetero") mode.

    The fold assignment itself is delegated to scikit-learn's ``KFold``.

    * In ``"hetero"`` mode the ``"guest"`` role computes the folds and sends
      the ``(train_indexer, test_indexer)`` pair of every fold to the hosts
      under the name ``"fold_indexes"``; every other role receives that pair
      from the guest and selects its own rows with it.
    * In any other mode the folds are computed locally and nothing is
      exchanged.

    Args:
        ctx: context object; this class uses ``ctx.sub_ctx(...).ctxs_range(...)``
            for per-fold communication and ``ctx.computing.parallelize(...)``
            to build the indexer tables.
        mode: ``"hetero"`` selects the federated path; any other value
            selects the local split.
        role: ``"guest"`` generates and sends the folds in hetero mode; any
            other value is treated as a receiving party.
        n_splits: number of folds, an integer (``int`` or any
            ``numbers.Integral`` such as ``numpy.int64``) that is >= 2.
        shuffle: forwarded to scikit-learn's ``KFold``.
        random_state: forwarded to scikit-learn's ``KFold``.

    Raises:
        ValueError: if ``n_splits`` is not an integer >= 2.
    """

    def __init__(self,
                 ctx,
                 mode="hetero",
                 role="guest",
                 n_splits=5,
                 shuffle=False,
                 random_state=None):
        self._ctx = ctx
        self._mode = mode
        self._role = role
        self._n_splits = n_splits
        self._shuffle = shuffle
        self._random_state = random_state

        self._check_param()
        # Normalise integer-like inputs (e.g. numpy.int64) to a plain int so
        # everything downstream sees the same type regardless of the caller.
        self._n_splits = int(self._n_splits)

    def split(self, df: "DataFrame"):
        """Return a generator of ``(train_frame, test_frame)`` pairs for ``df``.

        The generator is lazy: no fold is computed (and, in hetero mode, no
        message is sent or received) until it is iterated.
        """
        if self._mode == "hetero":
            return self._hetero_split(df)
        return self._homo_split(df, return_indexer=False)

    def _hetero_split(self, df: "DataFrame"):
        """Yield ``(train_frame, test_frame)`` per fold, synchronised via the guest.

        One sub-context of ``"KFold"`` is used per fold, so the exchange of
        ``"fold_indexes"`` for different folds does not share a context.
        """
        fold_ctxs = self._ctx.sub_ctx("KFold").ctxs_range(self._n_splits)

        if self._role == "guest":
            homo_splits = self._homo_split(df, return_indexer=True)
            # Walk contexts and folds in lockstep instead of calling next()
            # by hand: a bare next() inside a generator turns exhaustion into
            # an opaque RuntimeError (PEP 479). scikit-learn's KFold yields
            # exactly n_splits folds, so neither side is cut short.
            for (_, iter_ctx), fold in zip(fold_ctxs, homo_splits):
                train_frame, test_frame, train_indexer, test_indexer = fold

                iter_ctx.hosts.put("fold_indexes", (train_indexer, test_indexer))

                yield train_frame, test_frame
        else:
            for _, iter_ctx in fold_ctxs:
                train_indexer, test_indexer = iter_ctx.guest.get("fold_indexes")
                train_frame = df.loc(train_indexer)
                test_frame = df.loc(test_indexer)

                yield train_frame, test_frame

    def _homo_split(self, df: "DataFrame", return_indexer):
        """Compute the folds locally and yield the resulting frames.

        Args:
            df: frame to split.
            return_indexer: when true, each item is
                ``(train_frame, test_frame, train_indexer, test_indexer)``;
                otherwise it is ``(train_frame, test_frame)``.
        """
        kf = sk_KFold(n_splits=self._n_splits, shuffle=self._shuffle, random_state=self._random_state)
        # The sample-id indexer is collected into a local list so that the
        # positional indices produced by scikit-learn can be mapped back to
        # indexer entries.
        indexer = list(df.get_indexer(target="sample_id").collect())
        # Loop-invariant: every fold's indexer table uses the same partition
        # count as the source frame.
        partitions = df.block_table.partitions

        for train, test in kf.split(indexer):
            train_indexer = self._ctx.computing.parallelize(
                [indexer[idx] for idx in train],
                include_key=True,
                partition=partitions,
            )
            test_indexer = self._ctx.computing.parallelize(
                [indexer[idx] for idx in test],
                include_key=True,
                partition=partitions,
            )

            train_frame = df.loc(train_indexer)
            test_frame = df.loc(test_indexer)

            if return_indexer:
                yield train_frame, test_frame, train_indexer, test_indexer
            else:
                yield train_frame, test_frame

    def _check_param(self):
        """Validate constructor arguments.

        Raises:
            ValueError: if ``n_splits`` is not an integer >= 2.
        """
        # Imported locally so this class does not depend on a module-level
        # import being added alongside it.
        import numbers

        n_splits = self._n_splits
        # numbers.Integral (rather than int) also admits integer scalars from
        # numpy and similar libraries. bool is Integral too, but True/False
        # are < 2 and therefore still rejected.
        if not isinstance(n_splits, numbers.Integral) or n_splits < 2:
            raise ValueError("n_splits should be positive integer >= 2")

0 comments on commit fa70dd1

Please sign in to comment.