Commit
Merge pull request #5002 from FederatedAI/feature-2.0.0-sample
Feature 2.0.0 sample
mgqa34 authored Jul 24, 2023
2 parents 8b0a723 + e7ccc3a commit a8fb563
Showing 6 changed files with 226 additions and 116 deletions.
68 changes: 68 additions & 0 deletions examples/pipeline/test_data_split.py
@@ -0,0 +1,68 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from fate_client.pipeline import FateFlowPipeline
from fate_client.pipeline.components.fate import DataSplit
from fate_client.pipeline.components.fate import Intersection
from fate_client.pipeline.interface import DataWarehouseChannel

pipeline = FateFlowPipeline().set_roles(guest="9999", host="9998", arbiter="9998")

intersection_0 = Intersection("intersection_0",
method="raw")
intersection_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest_sid",
namespace="experiment"))
intersection_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host_sid",
namespace="experiment"))

intersection_1 = Intersection("intersection_1",
method="raw")
intersection_1.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest_sid",
namespace="experiment"))
intersection_1.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host_sid",
namespace="experiment"))

data_split_0 = DataSplit("data_split_0",
train_size=0.6,
validate_size=0.1,
test_size=None,
input_data=intersection_0.outputs["output_data"])

data_split_1 = DataSplit("data_split_1",
train_size=200,
test_size=50,
input_data=intersection_0.outputs["output_data"]
)

pipeline.add_task(intersection_0)
pipeline.add_task(intersection_1)
pipeline.add_task(data_split_0)
pipeline.add_task(data_split_1)

# pipeline.add_task(hetero_feature_binning_0)
pipeline.compile()
print(pipeline.get_dag())
pipeline.fit()

# print(pipeline.get_task_info("data_split_0").get_output_data())
output_data = pipeline.get_task_info("data_split_0").get_output_data()
import pandas as pd

print(f"data split 0 train size: {pd.DataFrame(output_data['train_output_data']).shape};"
f"validate size: {pd.DataFrame(output_data['validate_output_data']).shape}"
f"test size: {pd.DataFrame(output_data['test_output_data']).shape}")
output_data = pipeline.get_task_info("data_split_1").get_output_data()
print(f"data split 1train size: {pd.DataFrame(output_data['train_output_data']).shape};"
f"validate size: {pd.DataFrame(output_data['validate_output_data']).shape}"
f"test size: {pd.DataFrame(output_data['test_output_data']).shape}")
62 changes: 62 additions & 0 deletions examples/pipeline/test_sample.py
@@ -0,0 +1,62 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from fate_client.pipeline import FateFlowPipeline
from fate_client.pipeline.components.fate import Intersection
from fate_client.pipeline.components.fate import Sample
from fate_client.pipeline.interface import DataWarehouseChannel

pipeline = FateFlowPipeline().set_roles(guest="9999", host="9998")

intersection_0 = Intersection("intersection_0",
method="raw")
intersection_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest_sid",
namespace="experiment"))
intersection_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host_sid",
namespace="experiment"))

intersection_1 = Intersection("intersection_1",
method="raw")
intersection_1.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest_sid",
namespace="experiment"))
intersection_1.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host_sid",
namespace="experiment"))

sample_0 = Sample("sample_0",
frac=0.5,
replace=False,
federated_sample=False,
input_data=intersection_0.outputs["output_data"])

sample_1 = Sample("sample_1",
runtime_roles=["guest"],
n=1000,
replace=True,
federated_sample=False,
input_data=intersection_0.outputs["output_data"]
)

pipeline.add_task(intersection_0)
pipeline.add_task(intersection_1)
pipeline.add_task(sample_0)
pipeline.add_task(sample_1)

# pipeline.add_task(hetero_feature_binning_0)
pipeline.compile()
print(pipeline.get_dag())
pipeline.fit()
output_data_0 = pipeline.get_task_info("sample_0").get_output_data()
output_data_1 = pipeline.get_task_info("sample_1").get_output_data()
print(f"sample 0: {output_data_0};"
f"sample 1: {output_data_1}")
35 changes: 22 additions & 13 deletions python/fate/components/components/data_split.py
@@ -13,47 +13,56 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Union

from fate.arch import Context
from fate.components.core import GUEST, HOST, Role, cpn, params
from fate.ml.model_selection.data_split import DataSplitModuleGuest, DataSplitModuleHost

logger = logging.getLogger(__name__)


@cpn.component(roles=[GUEST, HOST], provider="fate")
def data_split(
ctx: Context,
role: Role,
input_data: cpn.dataframe_input(roles=[GUEST, HOST]),
train_size: cpn.parameter(type=Union[params.conint(ge=0), params.confloat(ge=0.0)], default=None,
desc="size of output training data, should be either int for exact sample size or float for fraction"),
validate_size: cpn.parameter(type=Union[params.conint(ge=0), params.confloat(ge=0.0)], default=None,
desc="size of output validation data, should be either int for exact sample size or float for fraction"),
test_size: cpn.parameter(type=Union[params.conint(ge=0), params.confloat(ge=0.0)], default=None,
desc="size of output test data, should be either int for exact sample size or float for fraction"),
train_size: cpn.parameter(type=Union[params.confloat(ge=0.0, le=1.0), params.conint(ge=0)], default=None,
desc="size of output training data, "
"should be either int for exact sample size or float for fraction"),
validate_size: cpn.parameter(type=Union[params.confloat(ge=0.0, le=1.0), params.conint(ge=0)], default=None,
desc="size of output validation data, "
"should be either int for exact sample size or float for fraction"),
test_size: cpn.parameter(type=Union[params.confloat(ge=0.0, le=1.0), params.conint(ge=0)], default=None,
desc="size of output test data, "
"should be either int for exact sample size or float for fraction"),
stratified: cpn.parameter(type=bool, default=False,
desc="whether sample with stratification, "
"should not use this for data with continuous label values"),
random_state: cpn.parameter(type=params.conint(ge=0), default=None, desc="random state"),
ctx_mode: cpn.parameter(type=params.string_choice(["hetero", "homo", "local"]), default="hetero",
desc="sampling mode, 'homo' & 'local' will both sample locally"),
federated_sample: cpn.parameter(type=bool, default=True,
desc="sampling mode, 'homo' & 'local' scenario should sample locally, "
"default True for hetero scenario"),
train_output_data: cpn.dataframe_output(roles=[GUEST, HOST], optional=True),
validate_output_data: cpn.dataframe_output(roles=[GUEST, HOST], optional=True),
test_output_data: cpn.dataframe_output(roles=[GUEST, HOST], optional=True),
):
if isinstance(train_size, float) or isinstance(validate_size, float) or isinstance(test_size, float):
if train_size + validate_size + test_size > 1:
raise ValueError("(train_size + validate_size + test_size) should be less than or equal to 1.0")
if train_size is None and validate_size is None and test_size is None:
train_size = 0.8
validate_size = 0.2
test_size = 0.0

# logger.info(f"in cpn received train_size: {train_size}, validate_size: {validate_size}, test_size: {test_size}")
# check if local but federated sample
if federated_sample and len(ctx.parties.ranks) < 2:
raise ValueError(f"federated sample can only be called when both 'guest' and 'host' present. Please check")

sub_ctx = ctx.sub_ctx("train")
if role.is_guest:
module = DataSplitModuleGuest(train_size, validate_size, test_size, stratified, random_state, ctx_mode)
module = DataSplitModuleGuest(train_size, validate_size, test_size, stratified, random_state, federated_sample)
elif role.is_host:
module = DataSplitModuleHost(train_size, validate_size, test_size, stratified, random_state, ctx_mode)
module = DataSplitModuleHost(train_size, validate_size, test_size, stratified, random_state, federated_sample)
input_data = input_data.read()

train_data_set, validate_data_set, test_data_set = module.fit(sub_ctx, input_data)
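
As a reading aid for the defaulting and validation logic above: when none of the three sizes is supplied, the component falls back to 0.8 train / 0.2 validate / 0.0 test, and fractional sizes may not sum to more than 1.0. Below is a standalone paraphrase of that check, a sketch only and not the component's code (the helper name is made up).

def resolve_split_sizes(train_size=None, validate_size=None, test_size=None):
    # Component defaults when nothing is specified: 0.8 train, 0.2 validate, 0.0 test.
    if train_size is None and validate_size is None and test_size is None:
        return 0.8, 0.2, 0.0
    sizes = [0 if s is None else s for s in (train_size, validate_size, test_size)]
    # Fractions cannot allocate more than the whole input data set.
    if any(isinstance(s, float) for s in sizes) and sum(sizes) > 1:
        raise ValueError("(train_size + validate_size + test_size) should be less than or equal to 1.0")
    return tuple(sizes)

# resolve_split_sizes() -> (0.8, 0.2, 0.0); resolve_split_sizes(0.6, 0.1, 0.5) raises ValueError
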
18 changes: 9 additions & 9 deletions python/fate/components/components/sample.py
@@ -38,16 +38,14 @@ def sample(
"otherwise a dict of pairs like [label_i, sample_rate_i],"
"e.g. {0: 0.5, 1: 0.8, 2: 0.3}, any label unspecified in dict will not be sampled,"
"default: 1.0, cannot be used with n"),
n: cpn.parameter(type=Union[params.conint(gt=0),
Mapping[Union[params.conint(), params.confloat()], params.conint(gt=0)]], default=None, optional=True,
n: cpn.parameter(type=params.conint(gt=0), default=None, optional=True,
desc="exact sample size, it should be an int greater than 0, "
"otherwise a dict of pairs like [label_i, sample_count_i],"
"e.g. {0: 50, 1: 20, 2: 30}, any label unspecified in dict will not be sampled,"
"default: None, cannot be used with frac"),
random_state: cpn.parameter(type=params.conint(ge=0), default=None,
desc="random state"),
ctx_mode: cpn.parameter(type=params.string_choice(["hetero", "homo", "local"]), default="hetero",
desc="sampling mode, 'homo' & 'local' will both sample locally"),
federated_sample: cpn.parameter(type=bool, default=True,
desc="sampling mode, 'homo' & 'local' scenario should sample locally,"
"default True for 'hetero' federation scenario"),
output_data: cpn.dataframe_output(roles=[GUEST, HOST])
):
if frac is not None and n is not None:
@@ -58,14 +56,16 @@ def sample(
raise ValueError(f"replace has to be set to True when sampling frac greater than 1.")
if n is None and frac is None:
frac = 1.0

# check if local but federated sample
if federated_sample and len(ctx.parties.ranks) < 2:
raise ValueError(f"federated sample can only be called when both 'guest' and 'host' present. Please check")
sub_ctx = ctx.sub_ctx("train")
if role.is_guest:
module = SampleModuleGuest(mode=mode, replace=replace, frac=frac, n=n,
random_state=random_state, ctx_mode=ctx_mode)
random_state=random_state, federated_sample=federated_sample)
elif role.is_host:
module = SampleModuleHost(mode=mode, replace=replace, frac=frac, n=n,
random_state=random_state, ctx_mode=ctx_mode)
random_state=random_state, federated_sample=federated_sample)
else:
raise ValueError(f"unknown role")
input_data = input_data.read()
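
A condensed restatement of the argument checks in this component: frac and n are mutually exclusive ways to size the sample, oversampling (frac > 1) requires replace=True, and when neither is given the full data set is kept (frac defaults to 1.0). The helper below is a sketch of those rules only, not the component's code, and its name is made up.

def validate_sample_args(frac=None, n=None, replace=False):
    # frac (fraction of rows) and n (exact row count) cannot both be specified.
    if frac is not None and n is not None:
        raise ValueError("n and frac cannot both be specified")
    if frac is not None and frac > 1 and not replace:
        raise ValueError("replace has to be set to True when sampling frac greater than 1.")
    if frac is None and n is None:
        frac = 1.0  # default: keep all rows
    return frac, n
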