Skip to content

Commit

Permalink
Merge pull request #4913 from FederatedAI/feature-2.0.0-beta_dataframe_update
Browse files Browse the repository at this point in the history

dataframe api update
  • Loading branch information
mgqa34 authored Jun 13, 2023
2 parents 064b14b + ddb1b88 commit 139ac1e
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 6 deletions.
12 changes: 8 additions & 4 deletions python/fate/arch/dataframe/_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,10 +185,10 @@ def apply_row(self, func, columns=None, with_label=False,

def create_frame(self, with_label=False, with_weight=False, columns: list = None) -> "DataFrame":
    """Derive a new DataFrame from this one.

    The sample-id and match-id fields are always carried over; the label,
    weight and data columns are included only when requested.

    :param with_label: keep the label field in the derived frame
    :param with_weight: keep the weight field in the derived frame
    :param columns: subset of data columns to keep; None keeps the default set
    :return: the derived DataFrame
    """
    return self.__extract_fields(
        with_sample_id=True,
        with_match_id=True,
        with_label=with_label,
        with_weight=with_weight,
        columns=columns,
    )

def drop(self, index) -> "DataFrame":
from .ops._dimension_scaling import drop
Expand All @@ -198,6 +198,10 @@ def fillna(self, value):
from .ops._fillna import fillna
return fillna(self, value)

def get_dummies(self, dtype="int32"):
    """One-hot encode the frame's single operable column.

    :param dtype: block type of the produced indicator columns
    :return: a new DataFrame holding the dummy/indicator columns
    """
    from .ops._encoder import get_dummies as _get_dummies

    return _get_dummies(self, dtype=dtype)

def isna(self):
    """Return a same-shaped DataFrame of booleans marking missing values."""
    from .ops._missing import isna as _isna

    return _isna(self)
Expand Down
25 changes: 23 additions & 2 deletions python/fate/arch/dataframe/manager/block_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import pandas as pd
import torch
from enum import Enum
from typing import Union, Tuple, List
from typing import Union, Tuple, List, Dict
from .schema_manager import SchemaManager


Expand Down Expand Up @@ -125,6 +125,10 @@ def is_single(self):
def get_field_offset(self, idx):
    """Return the offset of field ``idx`` inside this block.

    ``idx`` is a field index in the manager-wide numbering; raises
    ``KeyError`` when the field does not belong to this block.
    """
    return self._field_index_mapping[idx]

def reset_field_indexes(self, dst_field_indexes):
    """Renumber this block's fields after fields were removed elsewhere.

    :param dst_field_indexes: mapping of old field index -> new field index
                              (must cover every index in self._field_indexes)

    Bug fix: the original only rebuilt ``_field_index_mapping`` and left
    ``_field_indexes`` holding the stale (pre-pop) indexes, so callers that
    read ``field_indexes`` afterwards (e.g. ``BlockManager.pop_blocks`` when
    rebuilding ``_field_block_mapping``) keyed their results by the old
    numbering. Both attributes are now updated consistently.
    """
    field_indexes = [dst_field_indexes[src_field_index] for src_field_index in self._field_indexes]
    # keep the index list and the index->offset mapping in sync
    self._field_indexes = field_indexes
    self._field_index_mapping = dict(zip(field_indexes, range(len(field_indexes))))

def derive_block(self, field_indexes) -> Tuple["Block", bool, list]:
"""
assume that sub field indexes always in self._field_indexes
Expand Down Expand Up @@ -245,7 +249,7 @@ def __init__(self, *args, **kwargs):

@staticmethod
def convert_block(block):
    """Materialize raw block data as a float64 torch tensor."""
    return torch.tensor(block, dtype=torch.float64)


class BoolBlock(Block):
Expand Down Expand Up @@ -345,6 +349,23 @@ def append_fields(self, field_indexes, block_types, should_compress=True):

return block_ids

def pop_blocks(self, block_indexes: List[int], field_index_changes: Dict[int, int]):
    """Drop the blocks listed in ``block_indexes`` and renumber survivors.

    :param block_indexes: indexes (into self._blocks) of blocks to remove
    :param field_index_changes: mapping old field index -> new field index,
                                applied to every surviving block
    """
    removed = set(block_indexes)

    kept_blocks = []
    for bid, block in enumerate(self._blocks):
        if bid in removed:
            continue
        # survivors must adopt the post-pop field numbering
        block.reset_field_indexes(field_index_changes)
        kept_blocks.append(block)

    # rebuild field -> (block id, offset) lookup over the surviving blocks
    field_block_mapping = {
        field_index: (new_bid, offset)
        for new_bid, block in enumerate(kept_blocks)
        for offset, field_index in enumerate(block.field_indexes)
    }

    self._blocks = kept_blocks
    self._field_block_mapping = field_block_mapping

def split_fields(self, field_indexes, block_types):
field_sets = set(field_indexes)
block_field_maps = dict()
Expand Down
10 changes: 10 additions & 0 deletions python/fate/arch/dataframe/manager/data_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,16 @@ def append_columns(self, columns: List[str], block_types: Union["BlockType", Lis

return block_indexes

def pop_blocks(self, block_indexes: List[int]):
    """Remove whole blocks (and their fields) from the managed frame.

    Collects every field carried by the doomed blocks, lets the schema
    manager drop those fields (which yields the old->new field renumbering),
    then forwards the renumbering to the block manager.

    :param block_indexes: indexes of the blocks to remove
    """
    popped_fields = [
        field_index
        for block_index in block_indexes
        for field_index in self._block_manager.blocks[block_index].field_indexes
    ]

    field_index_changes = self._schema_manager.pop_fields(popped_fields)
    self._block_manager.pop_blocks(block_indexes, field_index_changes)

def split_columns(self, columns: List[str], block_types: Union["BlockType", List["BlockType"]]):
field_indexes = self._schema_manager.split_columns(columns, block_types)
narrow_blocks, dst_blocks = self._block_manager.split_fields(field_indexes, block_types)
Expand Down
42 changes: 42 additions & 0 deletions python/fate/arch/dataframe/manager/schema_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,23 @@ def append_columns(self, names):
self._columns = self._columns.append(pd.Index(names))
# TODO: extend anonymous column

def pop_columns(self, names):
    """Remove the named fields from this schema.

    If the label or weight field is among ``names`` its name attribute is
    cleared; everything else is dropped from the column index. Names not
    present in the schema are silently ignored.

    :param names: iterable of field names to remove
    """
    to_remove = set(names)

    if self._label_name in to_remove:
        to_remove.discard(self._label_name)
        self._label_name = None
    if self._weight_name in to_remove:
        to_remove.discard(self._weight_name)
        self._weight_name = None

    self._columns = pd.Index([name for name in self._columns if name not in to_remove])

    # TODO: pop anonymous columns

def __eq__(self, other: "Schema"):
return self.label_name == other.label_name and self.weight_name == other.weight_name \
and self.sample_id_name == other.sample_id_name and self.match_id_name == other.match_id_name \
Expand Down Expand Up @@ -161,6 +178,31 @@ def append_columns(self, names, block_types):

return [field_index + offset for offset in range(len(names))]

def pop_fields(self, field_indexes):
    """Drop the given fields and compact the remaining field numbering.

    :param field_indexes: field indexes to remove
    :return: dict mapping each surviving old field index to its new index
    """
    popped_names = [self._offset_name_mapping[field_id] for field_id in field_indexes]
    # copy-on-write: callers may still hold a reference to the old schema
    self._schema = copy.deepcopy(self._schema)
    self._schema.pop_columns(popped_names)

    popped = set(field_indexes)
    survivors = [fid for fid in range(len(self._offset_name_mapping)) if fid not in popped]

    name_offset_mapping = dict()
    offset_name_mapping = dict()
    field_index_changes = dict()
    # survivors keep their relative order but are renumbered densely from 0
    for dst_field_id, src_field_id in enumerate(survivors):
        name = self._offset_name_mapping[src_field_id]
        name_offset_mapping[name] = dst_field_id
        offset_name_mapping[dst_field_id] = name
        field_index_changes[src_field_id] = dst_field_id

    self._name_offset_mapping = name_offset_mapping
    self._offset_name_mapping = offset_name_mapping

    return field_index_changes

def split_columns(self, names, block_types):
field_indexes = [self._name_offset_mapping[name] for name in names]
for offset, name in enumerate(names):
Expand Down
100 changes: 100 additions & 0 deletions python/fate/arch/dataframe/ops/_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import numpy as np
from sklearn.preprocessing import OneHotEncoder
from .._dataframe import DataFrame
from ..manager import BlockType


def get_dummies(df: "DataFrame", dtype="int32"):
    """One-hot encode the single operable column of ``df``.

    :param df: source DataFrame; must expose exactly one operable column
    :param dtype: block type of the generated indicator columns
    :return: a new DataFrame whose operable column is replaced by one
             indicator column per distinct category
    :raises ValueError: when more than one operable column is present
    """
    dm = df.data_manager
    operable_blocks = dm.infer_operable_blocks()
    operable_fields = dm.infer_operable_field_names()

    if len(operable_fields) != 1:
        raise ValueError(f"get_dummies only support single column, but {len(operable_fields)} columns are found.")

    # global category list of the (single) encoded column, sorted
    categories = _get_categories(df.block_table, operable_blocks)[0][0]
    # new column names: "<field>_<category>"
    dummy_names = ["_".join(map(str, [operable_fields[0], c])) for c in categories]

    new_dm = dm.duplicate()
    new_dm.pop_blocks(operable_blocks)
    new_dm.append_columns(dummy_names, block_types=BlockType.get_block_type(dtype))

    encoded_table = _one_hot_encode(df.block_table, operable_blocks, new_dm, [[categories]], dtype=dtype)

    return DataFrame(
        df._ctx,
        encoded_table,
        partition_order_mappings=df.partition_order_mappings,
        data_manager=new_dm
    )


def _get_categories(block_table, block_indexes):
    """Compute the global category sets of the blocks to be encoded.

    Each partition fits a OneHotEncoder on its local data, partitions are
    reduced by set-union per field, and the result is returned sorted.

    :param block_table: distributed table of block lists (presumably one
                        entry per partition — mapValues/reduce interface)
    :param block_indexes: block ids whose categories are wanted
    :return: ``[[sorted category list, ...] per encoded block]``
    """
    encode_block_ids = set(block_indexes)

    def _collect(blocks):
        # per-partition: one entry per encoded block, each a list of
        # per-field category sets
        return [
            [set(values) for values in OneHotEncoder().fit(block).categories_]
            for bid, block in enumerate(blocks)
            if bid in encode_block_ids
        ]

    def _merge(left, right):
        # union the per-field category sets of two partitions
        return [
            [cate_l | cate_r for cate_l, cate_r in zip(block_l, block_r)]
            for block_l, block_r in zip(left, right)
        ]

    merged = block_table.mapValues(_collect).reduce(_merge)

    return [[sorted(values) for values in block_cates] for block_cates in merged]


def _one_hot_encode(block_table, block_indexes, data_manager, categories, dtype):
    """Apply one-hot encoding to the selected blocks of a block table.

    Non-encoded blocks are passed through unchanged; all encoded blocks are
    transformed, horizontally stacked into a single dense block, converted
    via the last block of ``data_manager`` and appended at the end of the
    partition's block list (matching the append_columns layout the caller
    prepared).

    :param block_table: distributed table of block lists (mapValues interface)
    :param block_indexes: block ids to encode
    :param data_manager: destination manager; its last block converts output
    :param categories: per-encoded-block category lists, in block order —
                       must align with the iteration order of block_indexes
    :param dtype: sklearn output dtype for the indicator matrix
    """
    categories = [np.array(category) for category in categories]
    block_index_set = set(block_indexes)

    def _encode(blocks):
        ret_blocks = []
        enc_blocks = []
        idx = 0
        for bid, block in enumerate(blocks):
            if bid not in block_index_set:
                # pass-through block, keeps its relative position
                ret_blocks.append(block)
                continue

            enc = OneHotEncoder(dtype=dtype)
            enc.fit([[1]])  # one hot encoder need to fit first.
            # NOTE(review): overriding the fitted categories_ relies on
            # sklearn internals so every partition encodes against the
            # same global category list — fragile across sklearn versions.
            enc.categories_ = categories[idx]
            idx += 1
            enc_blocks.append(enc.transform(block).toarray())

        # all encoded output becomes one combined block, appended last
        ret_blocks.append(data_manager.blocks[-1].convert_block(np.hstack(enc_blocks)))

        return ret_blocks

    return block_table.mapValues(_encode)

0 comments on commit 139ac1e

Please sign in to comment.