Skip to content

Commit

Permalink
add label info to data split when using stratified mode (#4779)
Browse files Browse the repository at this point in the history
edit examples

Signed-off-by: Yu Wu <yolandawu131@gmail.com>
  • Loading branch information
nemirorox committed Dec 20, 2023
1 parent d235322 commit 7f91a5e
Show file tree
Hide file tree
Showing 5 changed files with 208 additions and 17 deletions.
6 changes: 3 additions & 3 deletions examples/benchmark_quality/lr/epsilon_5k_sshe_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ data_host: "epsilon_5k_hetero_host"
idx: "id"
label_name: "y"
epochs: 5
batch_size: 80
batch_size: 20
init_param:
fit_intercept: True
method: "zeros"
random_state: 42
random_state: 102
early_stop: "diff"
learning_rate: 0.15
learning_rate: 1.0
55 changes: 55 additions & 0 deletions examples/pipeline/data_split/data_split_lr_testsuite.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,65 @@ data:
table_name: breast_hetero_host
namespace: experiment
role: host_1
- file: examples/data/breast_homo_guest.csv
meta:
delimiter: ","
dtype: float64
input_format: dense
label_type: int64
label_name: y
match_id_name: id
match_id_range: 0
tag_value_delimiter: ":"
tag_with_value: false
weight_type: float64
partitions: 4
head: true
extend_sid: true
table_name: breast_homo_guest
namespace: experiment
role: guest_0
- file: examples/data/breast_hetero_host.csv
meta:
delimiter: ","
dtype: float64
input_format: dense
match_id_name: id
match_id_range: 0
tag_value_delimiter: ":"
tag_with_value: false
weight_type: float64
partitions: 4
head: true
extend_sid: true
table_name: breast_hetero_host
namespace: experiment
role: host_1
- file: examples/data/breast_homo_host.csv
meta:
delimiter: ","
dtype: float64
input_format: dense
label_type: int64
label_name: y
match_id_name: id
match_id_range: 0
tag_value_delimiter: ":"
tag_with_value: false
weight_type: float64
partitions: 4
head: true
extend_sid: true
table_name: breast_homo_host
namespace: experiment
role: host_0
tasks:
data-split:
script: test_data_split.py
data-split-stratified:
script: test_data_split_stratified.py
data-split-multi-host:
script: test_data_split_multi_host.py
data-split-homo:
script: test_data_split_stratified_homo.py

88 changes: 88 additions & 0 deletions examples/pipeline/data_split/test_data_split_stratified_homo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse

from fate_client.pipeline import FateFlowPipeline
from fate_client.pipeline.components.fate import DataSplit, PSI, Reader
from fate_client.pipeline.utils import test_utils


def main(config="../config.yaml", namespace=""):
    """Run a homo data-split demo pipeline.

    Builds two Reader -> PSI -> DataSplit branches: one non-stratified,
    fraction-based split and one stratified, absolute-count split, then
    compiles and fits the pipeline.

    :param config: path to a job config file, or an already-loaded config object
    :param namespace: suffix appended to the data namespace stored in FATE
    """
    if isinstance(config, str):
        config = test_utils.load_job_config(config)
    parties = config.parties
    guest_party = parties.guest[0]
    host_party = parties.host[0]

    pipeline = FateFlowPipeline().set_parties(guest=guest_party, host=host_party)
    if config.task_cores:
        pipeline.conf.set("task_cores", config.task_cores)
    if config.timeout:
        pipeline.conf.set("timeout", config.timeout)

    # both readers load the same homo guest/host tables
    data_namespace = f"experiment{namespace}"
    reader_tasks = []
    for reader_name in ("reader_0", "reader_1"):
        reader = Reader(reader_name)
        reader.guest.task_parameters(namespace=data_namespace, name="breast_homo_guest")
        reader.hosts[0].task_parameters(namespace=data_namespace, name="breast_homo_host")
        reader_tasks.append(reader)
    reader_0, reader_1 = reader_tasks

    psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"])
    psi_1 = PSI("psi_1", input_data=reader_1.outputs["output_data"])

    # fraction-based split, no stratification, no hetero sync (homo data)
    data_split_0 = DataSplit(
        "data_split_0",
        train_size=0.6,
        validate_size=0.0,
        test_size=0.4,
        stratified=False,
        hetero_sync=False,
        input_data=psi_0.outputs["output_data"],
    )

    # absolute-count split with per-label stratification
    data_split_1 = DataSplit(
        "data_split_1",
        train_size=100,
        test_size=30,
        stratified=True,
        hetero_sync=False,
        input_data=psi_1.outputs["output_data"],
    )

    pipeline.add_tasks([reader_0, reader_1, psi_0, psi_1, data_split_0, data_split_1])
    pipeline.compile()
    pipeline.fit()


if __name__ == "__main__":
    # CLI entry point: forward --config / --namespace to main()
    arg_parser = argparse.ArgumentParser("PIPELINE DEMO")
    arg_parser.add_argument("--config", type=str, default="../config.yaml",
                            help="config file")
    arg_parser.add_argument("--namespace", type=str, default="",
                            help="namespace for data stored in FATE")
    cli_args = arg_parser.parse_args()
    main(config=cli_args.config, namespace=cli_args.namespace)
9 changes: 4 additions & 5 deletions python/fate/components/components/data_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,10 @@ def data_split(
train_data_set, validate_data_set, test_data_set = module.fit(sub_ctx, input_data)
# train_data_set, validate_data_set, test_data_set = module.split_data(train_data)
data_split_summary = {'original_count': input_data.shape[0],
'train_count': train_data_set.shape[0] if train_data_set else None,
'validate_count': validate_data_set.shape[0] if validate_data_set else None,
'test_count': test_data_set.shape[0] if test_data_set else None,
'stratified': stratified}
ctx.metrics.log_metrics(data_split_summary, "summary")
'train_count': train_data_set.shape[0] if train_data_set else 0,
'validate_count': validate_data_set.shape[0] if validate_data_set else 0,
'test_count': test_data_set.shape[0] if test_data_set else 0}
ctx.metrics.log_metrics(data_split_summary, name="summary", type='data_split')
if train_data_set:
train_output_data.write(train_data_set)
if validate_data_set:
Expand Down
67 changes: 58 additions & 9 deletions python/fate/ml/model_selection/data_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,12 @@ def fit(self, ctx: Context, train_data, validate_data=None):
self.validate_size,
self.test_size,
data_count)

if self.stratified:
train_data_set = sample_per_label(train_data, sample_count=train_size, random_state=self.random_state)
train_data_set, train_sample_n_per_label, labels = sample_per_label(train_data, sample_count=train_size,
random_state=self.random_state)
if len(train_sample_n_per_label) == 0:
train_sample_n_per_label = {label: 0 for label in labels}
else:
train_data_set = sample_data(df=train_data, n=train_size, random_state=self.random_state)
if train_data_set is not None:
Expand All @@ -58,8 +62,11 @@ def fit(self, ctx: Context, train_data, validate_data=None):
validate_test_data_set = train_data

if self.stratified:
validate_data_set = sample_per_label(validate_test_data_set, sample_count=validate_size,
random_state=self.random_state)
validate_data_set, valid_sample_n_per_label, _ = sample_per_label(validate_test_data_set,
sample_count=validate_size,
random_state=self.random_state)
if len(valid_sample_n_per_label) == 0:
valid_sample_n_per_label = {label: 0 for label in labels}
else:
validate_data_set = sample_data(df=validate_test_data_set, n=validate_size, random_state=self.random_state)
if validate_data_set is not None:
Expand All @@ -84,6 +91,22 @@ def fit(self, ctx: Context, train_data, validate_data=None):
ctx.hosts.put("validate_data_sid", validate_sid)
ctx.hosts.put("test_data_sid", test_sid)

if self.stratified:
if test_data_set:
test_sample_n_per_label = {}
for label in labels:
test_sample_n_per_label[label] = int((test_data_set.label == label).sum().values[0])
else:
test_sample_n_per_label = {label: 0 for label in labels}
for label in labels:
label_summary = {}
label_summary['original_count'] = int((train_data.label == label).sum().values[0])
label_summary['train_count'] = train_sample_n_per_label[label]
label_summary['validate_count'] = valid_sample_n_per_label[label]
label_summary['test_count'] = test_sample_n_per_label[label]

ctx.metrics.log_metrics(label_summary, name=f"{label}_summary", type='data_split')

return train_data_set, validate_data_set, test_data_set


Expand Down Expand Up @@ -124,7 +147,10 @@ def fit(self, ctx: Context, train_data, validate_data=None):
data_count)

if self.stratified:
train_data_set = sample_per_label(train_data, sample_count=train_size, random_state=self.random_state)
train_data_set, train_sample_n_per_label, labels = sample_per_label(train_data, sample_count=train_size,
random_state=self.random_state)
if len(train_sample_n_per_label) == 0:
train_sample_n_per_label = {label: 0 for label in labels}
else:
train_data_set = sample_data(df=train_data, n=train_size, random_state=self.random_state)
if train_data_set is not None:
Expand All @@ -134,8 +160,11 @@ def fit(self, ctx: Context, train_data, validate_data=None):
validate_test_data_set = train_data

if self.stratified:
validate_data_set = sample_per_label(validate_test_data_set, sample_count=validate_size,
random_state=self.random_state)
validate_data_set, valid_sample_n_per_label, _ = sample_per_label(validate_test_data_set,
sample_count=validate_size,
random_state=self.random_state)
if len(valid_sample_n_per_label) == 0:
valid_sample_n_per_label = {label: 0 for label in labels}
else:
validate_data_set = sample_data(df=validate_test_data_set, n=validate_size,
random_state=self.random_state)
Expand All @@ -149,6 +178,21 @@ def fit(self, ctx: Context, train_data, validate_data=None):
test_data_set = None
else:
test_data_set = validate_test_data_set
if self.stratified:
if test_data_set:
test_sample_n_per_label = {}
for label in labels:
test_sample_n_per_label[label] = int((test_data_set.label == label).sum().values[0])
else:
test_sample_n_per_label = {label: 0 for label in labels}
for label in labels:
label_summary = {}
label_summary['original_count'] = int((train_data.label == label).sum().values[0])
label_summary['train_count'] = train_sample_n_per_label[label]
label_summary['validate_count'] = valid_sample_n_per_label[label]
label_summary['test_count'] = test_sample_n_per_label[label]

ctx.metrics.log_metrics(label_summary, name=f"{label}_summary", type='data_split')

return train_data_set, validate_data_set, test_data_set

Expand All @@ -162,25 +206,30 @@ def sample_data(df, n, random_state):

def sample_per_label(train_data, sample_count=None, random_state=None):
    """Stratified sampling: draw ``sample_count`` rows from ``train_data``,
    allocating rows to each label proportionally to its frequency.

    Note: the scraped diff fused pre- and post-change lines here (a shadowed
    ``labels`` assignment, a duplicated ``label_data`` assignment, and an
    unreachable second ``return``); this is the reconstructed post-change
    version.

    Parameters
    ----------
    train_data : DataFrame
        Project DataFrame (not pandas) exposing ``label``, ``iloc`` and
        ``shape`` — TODO confirm exact contract against the DataFrame API.
    sample_count : int, optional
        Total number of rows to sample across all labels.
    random_state : int, optional
        Seed forwarded to ``sample_data`` for reproducibility.

    Returns
    -------
    tuple
        ``(sampled_data, sample_n_per_label, labels)`` where ``sampled_data``
        is the vstacked sample (or ``None`` if nothing was sampled),
        ``sample_n_per_label`` maps each label to its sampled row count, and
        ``labels`` is the list of int label values found in the data.
    """
    train_data_binarized_label = train_data.label.get_dummies()
    # dummy columns are named like "label_<value>"; recover the int label values
    labels = [int(label_name.split("_")[1]) for label_name in train_data_binarized_label.columns]
    sampled_data_df = []
    sampled_n = 0
    data_n = train_data.shape[0]
    sample_n_per_label = {}
    for i, label in enumerate(labels):
        label_data = train_data.iloc(train_data.label == label)
        if i == len(labels) - 1:
            # last label takes the remainder so totals sum exactly to sample_count
            to_sample_n = sample_count - sampled_n
        else:
            # proportional allocation, rounded
            to_sample_n = round(label_data.shape[0] / data_n * sample_count)
        label_sampled_data = sample_data(df=label_data, n=to_sample_n, random_state=random_state)
        # record per-label sampled counts for the stratified summary metrics
        if label_sampled_data:
            sample_n_per_label[label] = label_sampled_data.shape[0]
        else:
            sample_n_per_label[label] = 0
        if label_sampled_data is not None:
            sampled_data_df.append(label_sampled_data)
            sampled_n += label_sampled_data.shape[0]
    sampled_data = None
    if sampled_data_df:
        sampled_data = DataFrame.vstack(sampled_data_df)
    return sampled_data, sample_n_per_label, labels


def get_split_data_size(train_size, validate_size, test_size, data_count):
Expand Down

0 comments on commit 7f91a5e

Please sign in to comment.