Skip to content

Commit

Permalink
Update pipeline
Browse files Browse the repository at this point in the history
Signed-off-by: weijingchen <talkingwallace@sohu.com>

Signed-off-by: cwj <talkingwallace@sohu.com>
  • Loading branch information
talkingwallace committed Oct 26, 2023
1 parent 62d8b0f commit 0f42e47
Showing 1 changed file with 83 additions and 0 deletions.
83 changes: 83 additions & 0 deletions examples/pipeline/hetero_nn/test_nn_binary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse
from fate_test.utils import parse_summary_result
from fate_client.pipeline import FateFlowPipeline
from fate_client.pipeline.interface import DataWarehouseChannel
from fate_client.pipeline.utils import test_utils
from fate_client.pipeline.components.fate.evaluation import Evaluation
from fate_client.pipeline import FateFlowPipeline
from fate_client.pipeline.interface import DataWarehouseChannel
from fate_client.pipeline.components.fate.nn.torch import nn, optim
from fate_client.pipeline.components.fate.nn.torch.base import Sequential
from fate_client.pipeline.components.fate.hetero_nn import HeteroNN, get_config_of_default_runner
from fate_client.pipeline.components.fate.nn.algo_params import TrainingArguments


def main(config="../../config.yaml", namespace=""):
# obtain config
if isinstance(config, str):
config = test_utils.load_job_config(config)
parties = config.parties
guest = parties.guest[0]
host = parties.host[0]
arbiter = parties.arbiter[0]

epochs = 10
batch_size = 64
in_feat = 30
out_feat = 16
lr = 0.01

guest_train_data = {"name": "breast_homo_guest", "namespace": f"experiment{namespace}"}
host_train_data = {"name": "breast_homo_host", "namespace": f"experiment{namespace}"}
pipeline = FateFlowPipeline().set_roles(guest=guest, host=host, arbiter=arbiter)

conf = get_config_of_default_runner(
model=Sequential(
nn.Linear(in_feat, out_feat),
nn.ReLU(),
nn.Linear(out_feat ,1),
nn.Sigmoid()
),
loss=nn.BCELoss(),
optimizer=optim.Adam(lr=lr),
training_args=TrainingArguments(num_train_epochs=epochs, per_device_train_batch_size=batch_size, seed=114514, logging_strategy='steps'),
task_type='binary'
)


hetero_nn_0 = HeteroNN(
'hetero_nn_0'
)

hetero_nn_0.guest.component_setting(train_data=DataWarehouseChannel(name=guest_train_data["name"], namespace=guest_train_data["namespace"]))
hetero_nn_0.hosts[0].component_setting(train_data=DataWarehouseChannel(name=host_train_data["name"], namespace=host_train_data["namespace"]))

pipeline.add_task(hetero_nn_0)
pipeline.compile()
pipeline.fit()


if __name__ == "__main__":
parser = argparse.ArgumentParser("PIPELINE DEMO")
parser.add_argument("--config", type=str, default="../config.yaml",
help="config file")
parser.add_argument("--namespace", type=str, default="",
help="namespace for data stored in FATE")
args = parser.parse_args()
main(config=args.config, namespace=args.namespace)

0 comments on commit 0f42e47

Please sign in to comment.