Skip to content

Commit

Permalink
dataframe: adjust for anonymous sitename
Browse files: browse the repository at this point in the history
Signed-off-by: mgqa34 <mgq3374541@163.com>
  • Loading branch information
mgqa34 committed Sep 8, 2023
1 parent 3cd9d37 commit e2c0251
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 68 deletions.
10 changes: 4 additions & 6 deletions python/fate/arch/dataframe/_frame_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ def __init__(
header: str = None,
delimiter: str = ",",
dtype: Union[str, dict] = "float32",
anonymous_role: str = None,
anonymous_party_id: str = None,
anonymous_site_name: str = None,
na_values: Union[str, list, dict] = None,
input_format: str = "dense",
tag_with_value: bool = False,
Expand All @@ -56,8 +55,7 @@ def __init__(
self._delimiter = delimiter
self._header = header
self._dtype = dtype
self._anonymous_role = anonymous_role
self._anonymous_party_id = anonymous_party_id
self._anonymous_site_name = anonymous_site_name
self._na_values = na_values
self._input_format = input_format
self._tag_with_value = tag_with_value
Expand Down Expand Up @@ -232,11 +230,11 @@ def to_frame(self, ctx, df: "pd.DataFrame"):
label_type=self._label_type, weight_type=self._weight_type,
dtype=self._dtype, default_type=types.DEFAULT_DATA_TYPE)

site_name = ctx.local.name
local_role = ctx.local.party[0]
local_party_id = ctx.local.party[1]

if local_role != "local":
data_manager.fill_anonymous_role_and_party_id(role=local_role, party_id=local_party_id)
data_manager.fill_anonymous_site_name(site_name=site_name)

buf = zip(df.index.tolist(), df.values.tolist())
table = ctx.computing.parallelize(
Expand Down
5 changes: 2 additions & 3 deletions python/fate/arch/dataframe/io/_json_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,8 @@ def deserialize(ctx, data):

data_manager = DataManager.deserialize(schema_meta)

role = ctx.local.party[0]
party_id = ctx.local.party[1]
data_manager.fill_anonymous_role_and_party_id(role, party_id)
site_name = ctx.local.name
data_manager.fill_anonymous_site_name(site_name)

from ..ops._transformer import transform_list_to_block_table

Expand Down
9 changes: 4 additions & 5 deletions python/fate/arch/dataframe/manager/data_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,16 +83,15 @@ def duplicate(self) -> "DataManager":

def init_from_local_file(self, sample_id_name, columns, match_id_list, match_id_name, label_name, weight_name,
label_type, weight_type, dtype, default_type=types.DEFAULT_DATA_TYPE,
anonymous_role=None, anonymous_party_id=None):
anonymous_site_name=None):
schema_manager = SchemaManager()
retrieval_index_dict = schema_manager.parse_local_file_schema(sample_id_name,
columns,
match_id_list,
match_id_name,
label_name,
weight_name,
anonymous_role=anonymous_role,
anonymous_party_id=anonymous_party_id)
anonymous_site_name=anonymous_site_name)
schema_manager.init_field_types(label_type, weight_type, dtype,
default_type=default_type)
block_manager = BlockManager()
Expand Down Expand Up @@ -136,8 +135,8 @@ def loc_block(self, name: Union[str, List[str]], with_offset=True):

return loc_ret

def fill_anonymous_role_and_party_id(self, role, party_id):
self._schema_manager.fill_anonymous_role_and_party_id(role, party_id)
# Fill the given site name into the anonymous naming information.
# This is a thin wrapper: all work is delegated to
# self._schema_manager.fill_anonymous_site_name(site_name).
def fill_anonymous_site_name(self, site_name):
self._schema_manager.fill_anonymous_site_name(site_name)

def get_fields_loc(self, with_sample_id=True, with_match_id=True, with_label=True, with_weight=True):
field_block_mapping = self._block_manager.field_block_mapping
Expand Down
27 changes: 12 additions & 15 deletions python/fate/arch/dataframe/manager/schema_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,7 @@ def weight_name(self, weight_name: str):
self._weight_name = weight_name

if self.anonymous_weight_name is None:
anonymous_generator = AnonymousGenerator(role=self._anonymous_summary["role"],
party_id=self._anonymous_summary["party_id"])
anonymous_generator = AnonymousGenerator(site_name=self._anonymous_summary["site_name"])

self._anonymous_weight_name = anonymous_generator.add_anonymous_weight()

Expand All @@ -93,8 +92,7 @@ def label_name(self, label_name: str):
self._label_name = label_name

if self._anonymous_label_name is None:
anonymous_generator = AnonymousGenerator(role=self._anonymous_summary["role"],
party_id=self._anonymous_summary["party_id"])
anonymous_generator = AnonymousGenerator(site_name=self._anonymous_summary["site_name"])
self._anonymous_label_name = anonymous_generator.add_anonymous_label()

@property
Expand Down Expand Up @@ -132,21 +130,20 @@ def anonymous_summary(self, anonymous_summary):
def append_columns(self, names):
self._columns = self._columns.append(pd.Index(names))
# TODO: extend anonymous column
anonymous_generator = AnonymousGenerator(role=self._anonymous_summary["role"],
party_id=self._anonymous_summary["party_id"])
anonymous_generator = AnonymousGenerator(site_name=self._anonymous_summary["site_name"])

anonymous_columns, anonymous_summary = anonymous_generator.add_anonymous_columns(names, self._anonymous_summary)
self._anonymous_columns = self._anonymous_columns.append(pd.Index(anonymous_columns))
self._anonymous_summary = anonymous_summary

def init_anonymous_names(self, anonymous_role, anonymous_party_id):
anonymous_generator = AnonymousGenerator(anonymous_role, anonymous_party_id)
# Generate the initial anonymous names for this schema.
# Builds an AnonymousGenerator for anonymous_site_name, asks it to
# generate anonymous names from this schema object (self), and stores the
# returned dict through _set_anonymous_info_by_dict.
def init_anonymous_names(self, anonymous_site_name):
anonymous_generator = AnonymousGenerator(anonymous_site_name)
anonymous_ret_dict = anonymous_generator.generate_anonymous_names(self)
self._set_anonymous_info_by_dict(anonymous_ret_dict)

def fill_anonymous_role_and_party_id(self, anonymous_role, anonymous_party_id):
anonymous_generator = AnonymousGenerator(anonymous_role, anonymous_party_id)
anonymous_ret_dict = anonymous_generator.fill_role_and_party_id(
def fill_anonymous_site_name(self, anonymous_site_name):
anonymous_generator = AnonymousGenerator(anonymous_site_name)
anonymous_ret_dict = anonymous_generator.fill_anonymous_site_name(
anonymous_label_name=self.anonymous_label_name,
anonymous_weight_name=self._anonymous_weight_name,
anonymous_columns=self._anonymous_columns,
Expand Down Expand Up @@ -368,7 +365,7 @@ def get_all_keys(self):
return list(self._name_offset_mapping.keys())

def parse_local_file_schema(self, sample_id_name, columns, match_id_list, match_id_name, label_name, weight_name,
anonymous_role=None, anonymous_party_id=None):
anonymous_site_name=None):
column_indexes = list(range(len(columns)))

match_id_index, label_index, weight_index = None, None, None
Expand Down Expand Up @@ -403,7 +400,7 @@ def parse_local_file_schema(self, sample_id_name, columns, match_id_list, match_
columns=columns
)

self._schema.init_anonymous_names(anonymous_role, anonymous_party_id)
self._schema.init_anonymous_names(anonymous_site_name)
self.init_name_mapping()

return dict(
Expand All @@ -413,8 +410,8 @@ def parse_local_file_schema(self, sample_id_name, columns, match_id_list, match_
column_indexes=column_indexes
)

def fill_anonymous_role_and_party_id(self, anonymous_role, anonymous_party_id):
self._schema.fill_anonymous_role_and_party_id(anonymous_role, anonymous_party_id)
# Fill the given site name into the anonymous names held by the schema.
# Pure delegation to self._schema.fill_anonymous_site_name; no state is
# touched here directly.
def fill_anonymous_site_name(self, anonymous_site_name):
self._schema.fill_anonymous_site_name(anonymous_site_name)

@staticmethod
def extract_column_index_by_name(columns, column_indexes, name, drop=True):
Expand Down
67 changes: 32 additions & 35 deletions python/fate/arch/dataframe/manager/utils/_anonymous_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,18 @@
ANONYMOUS_LABEL = "y"
ANONYMOUS_WEIGHT = "weight"
SPLICES = "_"
ANONYMOUS_ROLE = "AnonymousRole"
ANONYMOUS_PARTY_ID = "AnonymousPartyId"
DEFAULT_SITE_NAME = "AnonymousRole_AnonymousPartyId"


class AnonymousGenerator(object):
def __init__(self, role=None, party_id=None):
self._role = role
self._party_id = party_id
# Create a generator bound to an optional site name.
# site_name: prefix used when building anonymous names. When it is falsy
# (e.g. the default None), _generate_anonymous_column falls back to
# DEFAULT_SITE_NAME as the prefix.
def __init__(self, site_name=None):
self._site_name = site_name

def _generate_anonymous_column(self, suf):
if self._role and self._party_id:
return SPLICES.join([self._role, self._party_id, suf])
if self._site_name:
return SPLICES.join([self._site_name, suf])
else:
return SPLICES.join([ANONYMOUS_ROLE, ANONYMOUS_PARTY_ID, suf])
return SPLICES.join([DEFAULT_SITE_NAME, suf])

def generate_anonymous_names(self, schema):
column_len = len(schema.columns.tolist())
Expand All @@ -55,19 +53,15 @@ def generate_anonymous_names(self, schema):
anonymous_weight_name=anonymous_weight_name,
anonymous_columns=anonymous_columns,
anonymous_summary=dict(column_len=column_len,
role=self._role,
party_id=self._party_id)
site_name=self._site_name
)
)

def _check_role_party_id_consistency(self, anonymous_summary):
anonymous_role = anonymous_summary["role"]
anonymous_party_id = anonymous_summary["party_id"]
def _check_site_name_consistency(self, anonymous_summary):
anonymous_site_name = anonymous_summary["site_name"]

if anonymous_role and self._role is not None and anonymous_role != self._role:
raise ValueError(f"previous_role={anonymous_role} != current_role={self._role}")

if anonymous_party_id and self._party_id is not None and anonymous_party_id != self._party_id:
raise ValueError(f"previous_party_id={anonymous_party_id} != current_role={self._party_id}")
if anonymous_site_name and self._site_name is not None and anonymous_site_name != self._site_name:
raise ValueError(f"previous_site_name={anonymous_site_name} != current_site_name={self._site_name}")

# Return the anonymous name for the label column: the result of
# _generate_anonymous_column called with the ANONYMOUS_LABEL suffix ("y").
def add_anonymous_label(self):
return self._generate_anonymous_column(ANONYMOUS_LABEL)
Expand All @@ -76,7 +70,7 @@ def add_anonymous_weight(self):
return self._generate_anonymous_column(ANONYMOUS_WEIGHT)

def add_anonymous_columns(self, columns, anonymous_summary: dict):
self._check_role_party_id_consistency(anonymous_summary)
self._check_site_name_consistency(anonymous_summary)
anonymous_summary = copy.deepcopy(anonymous_summary)

column_len = anonymous_summary["column_len"]
Expand All @@ -86,18 +80,17 @@ def add_anonymous_columns(self, columns, anonymous_summary: dict):
anonymous_summary["column_len"] = column_len + len(columns)
return anonymous_columns, anonymous_summary

def fill_role_and_party_id(self, anonymous_label_name, anonymous_weight_name,
anonymous_columns, anonymous_summary):
def fill_anonymous_site_name(self, anonymous_label_name, anonymous_weight_name,
anonymous_columns, anonymous_summary):
anonymous_summary = copy.deepcopy(anonymous_summary)

self._check_role_party_id_consistency(anonymous_summary)
self._check_site_name_consistency(anonymous_summary)

if anonymous_summary["role"] is None and anonymous_summary["party_id"] is None:
anonymous_label_name = self._fill_role_and_party_id(anonymous_label_name)
anonymous_weight_name = self._fill_role_and_party_id(anonymous_weight_name)
anonymous_columns = self._fill_role_and_party_id(anonymous_columns)
anonymous_summary["role"] = self._role
anonymous_summary["party_id"] = self._party_id
if anonymous_summary["site_name"] is None:
anonymous_label_name = self._fill_site_name(anonymous_label_name)
anonymous_weight_name = self._fill_site_name(anonymous_weight_name)
anonymous_columns = self._fill_site_name(anonymous_columns)
anonymous_summary["site_name"] = self._site_name

return dict(
anonymous_label_name=anonymous_label_name,
Expand All @@ -106,22 +99,26 @@ def fill_role_and_party_id(self, anonymous_label_name, anonymous_weight_name,
anonymous_summary=anonymous_summary
)

def _fill_role_and_party_id(self, name):
def _fill_site_name(self, name):
if name is None:
return name

if isinstance(name, str):
role, party_id, suf = name.split(SPLICES, 2)
if role != ANONYMOUS_ROLE or party_id != ANONYMOUS_PARTY_ID:
raise ValueError(f"To fill anonymous names with role and party_id, it shouldn't be fill before")
site_name_pre, site_name_suf, suf = name.split(SPLICES, 2)
site_name = SPLICES.join([site_name_pre, site_name_suf])

if site_name != DEFAULT_SITE_NAME:
raise ValueError(f"To fill anonymous names with site_name, it shouldn't be fill before")
return self._generate_anonymous_column(suf)
else:
name = list(name)
ret = []
for _name in name:
role, party_id, suf = _name.split(SPLICES, 2)
if role != ANONYMOUS_ROLE or party_id != ANONYMOUS_PARTY_ID:
raise ValueError(f"To fill anonymous names with role and party_id, it shouldn't be fill before")
site_name_pre, site_name_suf, suf = _name.split(SPLICES, 2)
site_name = SPLICES.join([site_name_pre, site_name_suf])

if site_name != DEFAULT_SITE_NAME:
raise ValueError(f"To fill anonymous names with site_name, it shouldn't be fill before")

ret.append(self._generate_anonymous_column(suf))

Expand Down
6 changes: 2 additions & 4 deletions python/fate/components/components/dataframe_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ def dataframe_transformer(
dataframe_output: cpn.dataframe_output(roles=[LOCAL]),
namespace: cpn.parameter(type=str, default=None, optional=True),
name: cpn.parameter(type=str, default=None, optional=True),
anonymous_role: cpn.parameter(type=str, default=None, optional=True),
anonymous_party_id: cpn.parameter(type=str, default=None, optional=True),
anonymous_site_name: cpn.parameter(type=str, default=None, optional=True),
):
from fate.arch.dataframe import TableReader

Expand All @@ -43,8 +42,7 @@ def dataframe_transformer(
header=metadata.get("header", None),
na_values=metadata.get("na_values", None),
dtype=metadata.get("dtype", "float32"),
anonymous_role=anonymous_role,
anonymous_party_id=anonymous_party_id,
anonymous_site_name=anonymous_site_name,
delimiter=metadata.get("delimiter", ","),
input_format=metadata.get("input_format", "dense"),
tag_with_value=metadata.get("tag_with_value", False),
Expand Down

0 comments on commit e2c0251

Please sign in to comment.