Skip to content

Commit

Permalink
dataframe: add local sample interface
Browse files Browse the repository at this point in the history
Signed-off-by: mgqa34 <mgq3374541@163.com>
Signed-off-by: weiwee <wbwmat@gmail.com>
  • Loading branch information
mgqa34 authored and sagewe committed Jul 21, 2023
1 parent 7c57ab1 commit 6a78969
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 4 deletions.
4 changes: 4 additions & 0 deletions python/fate/arch/dataframe/_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,10 @@ def vstack(cls, stacks: List["DataFrame"]) -> "DataFrame":

return vstack(stacks)

def sample(self, n: int=None, frac: float=None, random_state=None) -> "DataFrame":
    """
    Draw a sample of this DataFrame.

    Thin wrapper: forwards this frame together with ``n``, ``frac`` and
    ``random_state`` to ``sample`` in ``ops._dimension_scaling`` and returns
    whatever that function returns.
    """
    # Imported at call time rather than at module level, as in the original.
    # NOTE(review): presumably to avoid a circular import, since
    # ops._dimension_scaling imports DataFrame from this module - confirm.
    from .ops._dimension_scaling import sample as _sample_impl

    sampled = _sample_impl(self, n, frac, random_state)
    return sampled

def __extract_fields(
self,
with_sample_id=True,
Expand Down
33 changes: 29 additions & 4 deletions python/fate/arch/dataframe/ops/_dimension_scaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import List
import pandas as pd
import torch
from sklearn.utils import resample
from .._dataframe import DataFrame
from ..manager.data_manager import DataManager
from ._compress_block import compress_blocks
Expand Down Expand Up @@ -124,7 +125,7 @@ def _align_blocks(blocks, src_fields_loc=None, src_dm: DataManager=None, dst_dm:
)


def drop(df: "DataFrame", index: "DataFrame"=None) -> "DataFrame":
def drop(df: "DataFrame", index: "DataFrame" = None) -> "DataFrame":
data_manager = df.data_manager.duplicate()
l_flatten_func = functools.partial(
_flatten_partition,
Expand Down Expand Up @@ -156,11 +157,35 @@ def drop(df: "DataFrame", index: "DataFrame"=None) -> "DataFrame":
)


def sample(df: "DataFrame", n=None, frac=None, **kwargs):
def sample(df: "DataFrame", n=None, frac: float = None, random_state=None) -> "DataFrame":
    """
    Down-sample a DataFrame without replacement.

    Only down-sampling is supported: ``n`` must be <= ``df.shape[0]``, or
    ``frac`` must be <= 1. Exactly one of ``n`` and ``frac`` must be given.

    Parameters
    ----------
    df:
        Frame to sample from.
    n:
        Number of rows to draw; must satisfy ``1 <= n <= df.shape[0]``.
    frac:
        Fraction of rows to draw, in ``[0, 1]``. It is converted to
        ``max(1, int(df.shape[0] * frac))``, so at least one row is requested.
    random_state:
        Passed through unchanged to ``sklearn.utils.resample``.

    Returns
    -------
    The result of ``df.loc`` applied to the sampled indexer.

    Raises
    ------
    ValueError
        If both or neither of ``n`` and ``frac`` are given, if ``frac`` is
        outside ``[0, 1]``, or if the resulting ``n`` is < 1 or larger than
        the number of rows.
    """
    if n is not None and frac is not None:
        raise ValueError("sample's parameters n and frac should not be set in the same time.")

    # Without this guard the size comparison below would fail with an opaque
    # TypeError (None compared to int).
    if n is None and frac is None:
        raise ValueError("sample's parameters n and frac should not both be None.")

    total = df.shape[0]

    if frac is not None:
        if frac > 1:
            raise ValueError(f"sample's parameter frac={frac} should <= 1.0")
        # A negative fraction would otherwise be silently turned into n=1 by
        # the max() below.
        if frac < 0:
            raise ValueError(f"sample's parameter frac={frac} should >= 0")
        n = max(1, int(total * frac))

    if n > total:
        raise ValueError(f"sample's parameter n={n} > data size={total}")

    # Reject negative values as well as zero, matching the message.
    if n < 1:
        raise ValueError(f"sample's parameter n={n} should >= 1")

    # Collect the sample_id indexer locally and draw n entries without replacement.
    indexer = list(df.get_indexer(target="sample_id").collect())
    sample_indexer = resample(indexer, replace=False, n_samples=n, random_state=random_state)

    # Redistribute the chosen entries using the same partition count as the
    # source block table.
    sample_indexer = df._ctx.computing.parallelize(
        sample_indexer,
        include_key=True,
        partition=df.block_table.partitions,
    )

    return df.loc(sample_indexer)


def _flatten_partition(kvs, block_num=0):
Expand Down

0 comments on commit 6a78969

Please sign in to comment.