diff --git a/examples/pipeline/test_single_linr.py b/examples/pipeline/test_single_linr.py new file mode 100644 index 0000000000..0828cf1671 --- /dev/null +++ b/examples/pipeline/test_single_linr.py @@ -0,0 +1,72 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import CoordinatedLinR +from fate_client.pipeline.components.fate import Evaluation +from fate_client.pipeline.interface import DataWarehouseChannel + +pipeline = FateFlowPipeline().set_roles(guest="9999", host="9998", arbiter="9998") + +"""feature_scale_0 = FeatureScale(name="feature_scale_0", + method="min_max", + train_data=intersection_0.outputs["output_data"]) + +feature_scale_1 = FeatureScale(name="feature_scale_1", + test_data=intersection_1.outputs["output_data"], + input_model=feature_scale_0.outputs["output_model"])""" + +linr_0 = CoordinatedLinR("linr_0", + max_iter=10, + batch_size=-1, + init_param={"fit_intercept": False}) + +linr_0.guest.component_setting(train_data=DataWarehouseChannel(name="motor_hetero_guest", + namespace="experiment")) +linr_0.hosts[0].component_setting(train_data=DataWarehouseChannel(name="motor_hetero_host", + namespace="experiment")) + +evaluation_0 = Evaluation("evaluation_0", + runtime_roles="guest", + input_data=linr_0.outputs["train_output_data"]) + +# pipeline.add_task(feature_scale_0) +# pipeline.add_task(feature_scale_1) +pipeline.add_task(linr_0) +pipeline.add_task(evaluation_0) +# pipeline.add_task(hetero_feature_binning_0) +pipeline.compile() +print(pipeline.get_dag()) +pipeline.fit() + +# print(pipeline.get_task_info("statistics_0").get_output_model()) +print(pipeline.get_task_info("linr_0").get_output_model()) +print(pipeline.get_task_info("linr_0").get_output_data()) +print(pipeline.get_task_info("evaluation_0").get_output_metrics()) + +pipeline.deploy([linr_0]) + +predict_pipeline = FateFlowPipeline() + +deployed_pipeline = pipeline.get_deployed_pipeline() +linr_0.guest.component_setting(test_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace="experiment")) +linr_0.hosts[0].component_setting(test_data=DataWarehouseChannel(name="breast_hetero_host", + namespace="experiment")) + +predict_pipeline.add_task(deployed_pipeline) +predict_pipeline.compile() +# print("\n\n\n") +# print(predict_pipeline.compile().get_dag()) +predict_pipeline.predict()