Skip to content

Commit

Permalink
dataframe update: support variation/isna/na_count
Browse files Browse the repository at this point in the history
Signed-off-by: mgqa34 <mgq3374541@163.com>
  • Loading branch information
mgqa34 committed Jun 12, 2023
1 parent ea4c4f1 commit 704c5e7
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 8 deletions.
11 changes: 11 additions & 0 deletions python/fate/arch/dataframe/_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,17 @@ def fillna(self, value):
from .ops._fillna import fillna
return fillna(self, value)

def isna(self):
    """Return the result of the ``ops._missing.isna`` operator applied to this frame.

    The operator is imported at call time rather than at module import time.
    """
    from .ops._missing import isna as _isna_op

    return _isna_op(self)

def isin(self, values):
    """Apply the ``ops._isin.isin`` operator to this frame with ``values``.

    The operator is imported at call time rather than at module import time.
    """
    from .ops._isin import isin as _isin_op

    return _isin_op(self, values)

def na_count(self):
    """Count missing entries by summing the mask returned from ``self.isna()``.

    Returns whatever ``sum()`` of that mask yields.
    """
    missing_mask = self.isna()
    return missing_mask.sum()

def max(self) -> "pd.Series":
    """Apply the ``ops._stat.max`` reduction to this frame.

    The operator is imported under an alias so the builtin ``max`` is not
    shadowed inside this method.
    """
    from .ops._stat import max as _max_op

    return _max_op(self)
Expand All @@ -226,6 +233,10 @@ def var(self, ddof=1, **kwargs):
from .ops._stat import var
return var(self, ddof=ddof)

def variation(self, ddof=1):
    """Apply the ``ops._stat.variation`` statistic to this frame.

    Args:
        ddof: Forwarded by keyword to the underlying operator.
    """
    from .ops._stat import variation as _variation_op

    return _variation_op(self, ddof=ddof)

def skew(self, unbiased=False):
    """Apply the ``ops._stat.skew`` statistic to this frame.

    Args:
        unbiased: Forwarded by keyword to the underlying operator.
    """
    from .ops._stat import skew as _skew_op

    return _skew_op(self, unbiased=unbiased)
Expand Down
58 changes: 58 additions & 0 deletions python/fate/arch/dataframe/ops/_missing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import numpy as np
import torch

from ._compress_block import compress_blocks
from .._dataframe import DataFrame
from ..manager import BlockType


def isna(df: "DataFrame"):
    """Build a new DataFrame whose operable blocks hold NaN indicator values.

    Steps, in order:
      1. ask the data manager which blocks are operable,
      2. replace those blocks with their element-wise NaN test (``_isna``),
      3. on a duplicate of the data manager, request the bool block type for
         each of those blocks via ``promote_types``,
      4. run ``compress_blocks`` on the resulting table/manager pair.

    Args:
        df: Source DataFrame; it is read but not modified here.

    Returns:
        A DataFrame sharing ``df``'s context and partition order mappings.
    """
    src_manager = df.data_manager
    operable_bids = src_manager.infer_operable_blocks()

    nan_flag_table = _isna(df.block_table, operable_bids)

    # Work on a copy so the source frame's manager is left untouched.
    flag_manager = src_manager.duplicate()
    flag_manager.promote_types(
        [(bid, BlockType.get_block_type(torch.bool)) for bid in operable_bids]
    )

    compressed_table, compressed_manager = compress_blocks(nan_flag_table, flag_manager)

    return DataFrame(
        df._ctx,
        compressed_table,
        df.partition_order_mappings,
        compressed_manager,
    )


def _isna(block_table, block_indexes):
block_index_set = set(block_indexes)

def _isna_judgement(blocks):
ret_blocks = []
for bid, block in enumerate(blocks):
if bid not in block_index_set:
ret_blocks.append(block)
else:
ret_blocks.append(torch.isnan(block) if isinstance(block, torch.Tensor) else np.isnan(block))

return ret_blocks

return block_table.mapValues(
_isna_judgement
)
13 changes: 5 additions & 8 deletions python/fate/arch/dataframe/ops/_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,24 +225,21 @@ def kurt(df: "DataFrame", unbiased=False):
return m4 / m2 ** 4 - 3


def variation(df: "DataFrame", ddof=1):
    """Compute the coefficient of variation: standard deviation divided by mean.

    Args:
        df: DataFrame passed to both ``std`` and ``mean``.
        ddof: Delta degrees of freedom forwarded to ``std``.

    Returns:
        ``std(df, ddof=ddof) / mean(df)``.
    """
    # ``ddof`` must go by keyword. Writing ``ddof==ddof`` is a comparison that
    # evaluates to True and would be passed positionally, so a caller's
    # ``ddof=0`` would silently behave like ``ddof=1``.
    return std(df, ddof=ddof) / mean(df)


def _post_process(reduce_ret, operable_blocks, data_manager: "DataManager") -> "pd.Series":
field_names = data_manager.infer_operable_field_names()
field_indexes = [data_manager.get_field_offset(name) for name in field_names]
field_indexes_loc = dict(zip(field_indexes, range(len(field_indexes))))
ret = [[] for i in range(len(field_indexes))]

block_type = None

reduce_ret = [r.reshape(-1).tolist() for r in reduce_ret]
for idx, bid in enumerate(operable_blocks):
field_indexes = data_manager.blocks[bid].field_indexes
for offset, field_index in enumerate(field_indexes):
loc = field_indexes_loc[field_index]
ret[loc] = reduce_ret[idx][offset]

if block_type is None:
block_type = data_manager.blocks[bid].block_type
elif block_type < data_manager.blocks[bid].block_type:
block_type = data_manager.blocks[bid].block_type

return pd.Series(ret, index=field_names, dtype=block_type.value)
return pd.Series(ret, index=field_names)

0 comments on commit 704c5e7

Please sign in to comment.