Skip to content

Commit

Permalink
add stat (#4663)
Browse files Browse the repository at this point in the history
Signed-off-by: Yu Wu <yolandawu131@gmail.com>
  • Loading branch information
nemirorox committed Jun 13, 2023
1 parent 139ac1e commit 46f9b14
Show file tree
Hide file tree
Showing 9 changed files with 414 additions and 18 deletions.
118 changes: 118 additions & 0 deletions examples/pipeline/test_statistics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json

from fate_client.pipeline import StandalonePipeline, FateFlowPipeline
from fate_client.pipeline.components.fate import FeatureScale
from fate_client.pipeline.components.fate import Intersection
from fate_client.pipeline.components.fate import Reader
from fate_client.pipeline.components.fate import Statistics
from fate_client.pipeline.utils import test_utils


def main(config="./config.yaml", namespace=""):
    """Build and run a statistics demo pipeline on the hetero breast dataset.

    The pipeline reads guest/host data, runs two raw intersections over the
    reader output, fits a standard scaler on the first, applies the fitted
    scaler to the second, and computes feature statistics on the scaled
    result. The compiled DAG and the statistics output model are printed.

    Args:
        config: Path to a job config file, or an already loaded config object
            exposing ``parties``, ``work_mode`` and ``data_base_dir``.
        namespace: Prefix put in front of ``experiment`` to form the table
            namespace used when ``work_mode`` is non-zero.
    """
    if isinstance(config, str):
        config = test_utils.load_job_config(config)

    parties = config.parties
    role_ids = dict(guest=parties.guest[0], host=parties.host[0], arbiter=parties.arbiter[0])

    # work_mode 0 uses the standalone pipeline; any other value uses FATE Flow.
    pipeline_cls = StandalonePipeline if config.work_mode == 0 else FateFlowPipeline
    pipeline = pipeline_cls().set_roles(**role_ids)

    reader_0 = Reader(name="reader_0")
    if config.work_mode:
        # Non-zero work mode: address the data by table name and namespace.
        table_namespace = f"{namespace}experiment"
        guest_reader_args = dict(
            table_name="breast_hetero_guest",
            namespace=table_namespace,
            label_name="y",
            label_type="float32",
            dtype="float32",
        )
        host_reader_args = dict(
            table_name="breast_hetero_host",
            namespace=table_namespace,
            label_name=None,
            dtype="float32",
        )
    else:
        # Work mode 0: read the csv files directly from the data directory.
        data_dir = config.data_base_dir
        csv_args = dict(format="csv", match_id_name="id", delimiter=",")
        guest_reader_args = dict(
            path=f"file://{data_dir}/examples/data/breast_hetero_guest.csv",
            **csv_args,
            label_name="y",
            label_type="float32",
            dtype="float32",
        )
        host_reader_args = dict(
            path=f"file://{data_dir}/examples/data/breast_hetero_host.csv",
            **csv_args,
            label_name=None,
            dtype="float32",
        )
    reader_0.guest.component_param(**guest_reader_args)
    reader_0.hosts[0].component_param(**host_reader_args)

    reader_output = reader_0.outputs["output_data"]
    intersect_fit = Intersection(name="intersection_0", method="raw", input_data=reader_output)
    intersect_apply = Intersection(name="intersection_1", method="raw", input_data=reader_output)

    scaler_fit = FeatureScale(
        name="feature_scale_0",
        method="standard",
        train_data=intersect_fit.outputs["output_data"],
    )
    scaler_apply = FeatureScale(
        name="feature_scale_1",
        test_data=intersect_apply.outputs["output_data"],
        input_model=scaler_fit.outputs["output_model"],
    )

    stat_task = Statistics(
        name="statistics_0",
        train_data=scaler_apply.outputs["test_output_data"],
        metrics=["mean", "max", "std", "var", "kurtosis", "skewness"],
    )

    # Registration order is kept exactly as in the original script.
    for task in (reader_0, scaler_fit, scaler_apply, intersect_fit, intersect_apply, stat_task):
        pipeline.add_task(task)

    pipeline.compile()
    print(pipeline.get_dag())
    pipeline.fit()

    stat_model = pipeline.get_task_info("statistics_0").get_output_model()
    print(json.dumps(stat_model, indent=4))


if __name__ == "__main__":
    # Command-line entry point: parse the config path and data namespace,
    # then run the demo pipeline.
    parser = argparse.ArgumentParser("PIPELINE DEMO")
    # Default to a path relative to the working directory (the same default
    # main() declares) instead of an absolute path that only exists on one
    # developer's machine.
    parser.add_argument("-config", type=str, default="./config.yaml",
                        help="config file")
    parser.add_argument("-namespace", type=str, default="",
                        help="namespace for data stored in FATE")
    args = parser.parse_args()
    main(config=args.config, namespace=args.namespace)
14 changes: 6 additions & 8 deletions python/fate/arch/dataframe/_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import numpy as np
import operator
import pandas as pd

from typing import List, Union

import numpy as np
import pandas as pd

from .manager import DataManager, Schema
from .ops import (
aggregate_indexer,
transform_to_table,
get_partition_order_mappings
)
from .manager import DataManager, Schema


class DataFrame(object):
Expand Down Expand Up @@ -256,11 +256,9 @@ def sigmoid(self) -> "DataFrame":
def count(self) -> "int":
    """Return the size of the first dimension of ``self.shape``."""
    first_dim = self.shape[0]
    return first_dim

def describe(self, metric_kwargs=None):
def describe(self, ddof=1, unbiased=False):
from .ops._stat import describe
if metric_kwargs is None:
metric_kwargs = dict()
return describe(self, metric_kwargs)
return describe(self, ddof=ddof, unbiased=unbiased)

def quantile(self, q, axis=0, method="quantile", ):
...
Expand Down
15 changes: 7 additions & 8 deletions python/fate/arch/dataframe/ops/_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
import numpy as np
import pandas as pd
import torch

from .._dataframe import DataFrame
from ..manager import DataManager


FLOATING_POINT_ZERO = 1e-14


Expand Down Expand Up @@ -229,17 +229,17 @@ def variation(df: "DataFrame", ddof=1):
return std(df, ddof==ddof) / mean(df)


def describe(df: "DataFrame", metric_kwargs):
def describe(df: "DataFrame", ddof=1, unbiased=False):
stat_metrics = dict()
stat_metrics["sum"] = sum(df)
stat_metrics["min"] = min(df)
stat_metrics["max"] = max(df)
stat_metrics["mean"] = mean(df)
stat_metrics["std"] = std(df) if "std" not in metric_kwargs else std(df, ddof=metric_kwargs["std"])
stat_metrics["var"] = var(df) if "var" not in metric_kwargs else var(df, ddof=metric_kwargs["var"])
stat_metrics["variation"] = variation(df) if "variation" not in metric_kwargs else variation(df, ddof=metric_kwargs["variation"])
stat_metrics["skew"] = skew(df) if "skew" not in metric_kwargs else skew(df, unbiased=metric_kwargs["unbiased"])
stat_metrics["kurt"] = kurt(df) if "kurt" not in metric_kwargs else kurt(df, unbiased=metric_kwargs["unbiased"])
stat_metrics["std"] = std(df, ddof=ddof)
stat_metrics["var"] = var(df, ddof=ddof)
stat_metrics["variation"] = variation(df, ddof=ddof)
stat_metrics["skew"] = skew(df, unbiased=unbiased)
stat_metrics["kurt"] = kurt(df, unbiased=unbiased)
stat_metrics["na_count"] = df.isna().sum()

return pd.DataFrame(stat_metrics)
Expand All @@ -259,4 +259,3 @@ def _post_process(reduce_ret, operable_blocks, data_manager: "DataManager") -> "
ret[loc] = reduce_ret[idx][offset]

return pd.Series(ret, index=field_names)

3 changes: 2 additions & 1 deletion python/fate/components/components/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,6 @@
from .hetero_lr import hetero_lr
from .intersection import intersection
from .reader import reader
from .statistics import statistics

BUILDIN_COMPONENTS = [hetero_lr, reader, feature_scale, intersection, evaluation]
BUILDIN_COMPONENTS = [hetero_lr, reader, feature_scale, intersection, evaluation, statistics]
94 changes: 94 additions & 0 deletions python/fate/components/components/statistics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
#
# Copyright 2023 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union, List

from fate.components import (
GUEST,
HOST,
DatasetArtifact,
Input,
ModelArtifact,
Output,
Role,
cpn,
params
)


@cpn.component(roles=[GUEST, HOST])
def statistics(ctx, role):
    """Declare the statistics component for the guest and host roles.

    The function body is intentionally empty; the train stage is attached
    separately through ``statistics.train()``.
    """


@statistics.train()
@cpn.artifact("train_data", type=Input[DatasetArtifact], roles=[GUEST, HOST])
@cpn.parameter("metrics",
               type=Union[List[params.statistic_metrics_param()], params.statistic_metrics_param()],
               default=['mean', 'std', 'min', 'max'],
               desc="metrics to be computed, default ['mean', 'std', 'min', 'max']")
@cpn.parameter("ddof", type=params.conint(ge=0), default=1,
               desc="Delta Degrees of Freedom for std and var, default 1")
@cpn.parameter("bias", type=bool, default=True,
               desc="If False, the calculations of skewness and kurtosis are corrected for statistical bias.")
@cpn.parameter("skip_col", type=List[str], default=None, optional=True,
               desc="columns to be skipped, default None; if None, statistics will be computed over all columns")
@cpn.parameter("use_anonymous", type=bool, default=False,
               desc="bool, whether interpret `skip_col` as anonymous column names")
@cpn.artifact("output_model", type=Output[ModelArtifact], roles=[GUEST, HOST])
def statistics_train(
    ctx,
    role: Role,
    train_data,
    metrics,
    ddof,
    bias,
    skip_col,
    use_anonymous,
    output_model,
):
    """Train-stage entry of the statistics component.

    Declares the stage's artifacts and parameters through the decorators
    above and forwards everything except ``role`` to ``train``.

    Args:
        ctx: Component context, passed through to ``train``.
        role: Role of the running party; not used by this function.
        train_data: Input dataset artifact to compute statistics on.
        metrics: One metric name or a list of metric names.
        ddof: Delta degrees of freedom for std and var.
        bias: If False, skewness and kurtosis are bias-corrected.
        skip_col: Column names to leave out, or None to use every column.
        use_anonymous: Whether ``skip_col`` holds anonymous column names.
        output_model: Output model artifact the statistics are written to.
    """
    train(ctx, train_data, output_model, metrics, ddof, bias, skip_col, use_anonymous)


def train(ctx, train_data, output_model, metrics, ddof, bias, skip_col, use_anonymous):
    """Compute feature statistics on the input data and write them as a model.

    Args:
        ctx: Component context; a ``"train"`` sub-context is opened from it.
        train_data: Input dataset artifact, read into a dataframe.
        output_model: Output model artifact used as a writer context.
        metrics: A metric name or a list of metric names.
        ddof: Passed to ``FeatureStatistics`` as the second argument.
        bias: Passed to ``FeatureStatistics`` as the third argument.
        skip_col: Column names to leave out, or None to use every column.
        use_anonymous: Whether ``skip_col`` holds anonymous column names.

    Raises:
        ValueError: If ``"describe"`` is requested together with other metrics.
    """
    from fate.ml.statistics.statistics import FeatureStatistics

    with ctx.sub_ctx("train") as sub_ctx:
        train_data = sub_ctx.reader(train_data).read_dataframe().data
        select_cols = get_to_compute_cols(train_data.schema.columns, train_data.schema.anonymous_columns,
                                          skip_col, use_anonymous)
        if isinstance(metrics, str):
            metrics = [metrics]
        if len(metrics) > 1 and "describe" in metrics:
            raise ValueError("'describe' should not be combined with additional metric names.")
        # De-duplicate while keeping the requested order. A plain set() would
        # yield a hash-dependent order that can differ between processes.
        unique_metrics = list(dict.fromkeys(metrics))
        stat_computer = FeatureStatistics(unique_metrics, ddof, bias)
        train_data = train_data[select_cols]
        stat_computer.fit(sub_ctx, train_data)

        model = stat_computer.to_model()
        with output_model as model_writer:
            model_writer.write_model("statistics", model, metadata={"model_type": "statistic"})


def get_to_compute_cols(columns, anonymous_columns, skip_columns, use_anonymous):
    """Return the columns to compute statistics on, in their original order.

    Args:
        columns: Real column names.
        anonymous_columns: Anonymous column names, aligned position by
            position with ``columns``.
        skip_columns: Names of columns to exclude, or None to keep them all.
        use_anonymous: If True, ``skip_columns`` holds anonymous names that
            are translated to real names before filtering.

    Returns:
        A new list with every entry of ``columns`` that is not skipped.

    Raises:
        ValueError: If ``use_anonymous`` is True and a name in
            ``skip_columns`` is not one of ``anonymous_columns``.
    """
    if skip_columns is None:
        skip_columns = []
    if use_anonymous:
        # The filter below compares against real names, so anonymous names
        # must be mapped anonymous -> real (not the other way round).
        anonymous_to_name = dict(zip(anonymous_columns, columns))
        unknown = [col for col in skip_columns if col not in anonymous_to_name]
        if unknown:
            raise ValueError(f"anonymous column(s) {unknown} not found in data schema")
        skip_columns = [anonymous_to_name[col] for col in skip_columns]
    skip_col_set = set(skip_columns)
    return [col for col in columns if col not in skip_col_set]
2 changes: 1 addition & 1 deletion python/fate/components/params/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from pydantic import validate_arguments

from ._cipher import CipherParamType, PaillierCipherParam
from ._fields import Parameter, confloat, conint, jsonschema, parse, string_choice
from ._learning_rate import learning_rate_param
from ._metrics import metrics_param, statistic_metrics_param
from ._optimizer import optimizer_param
from ._penalty import penalty_param
50 changes: 50 additions & 0 deletions python/fate/components/params/_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#
# Copyright 2023 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Type

from ._fields import StringChoice


class Metrics(StringChoice):
    """String-choice base type for evaluation metric names.

    ``metrics_param`` creates subclasses whose ``choice`` holds the enabled
    metric names; the base class itself starts with no choices.
    """

    choice = dict()


class StatisticMetrics(StringChoice):
    """String-choice base type for statistic metric names.

    ``statistic_metrics_param`` creates subclasses whose ``choice`` holds the
    enabled metric names; the base class itself starts with no choices.
    """

    choice = dict()


def statistic_metrics_param(count=True, sum=True, min=True, max=True, mean=True, median=True,
                            std=True, var=True, coe=True,
                            missing_count=True, missing_ratio=True,
                            skewness=True, kurtosis=True,
                            describe=True) -> Type[str]:
    """Build a ``StatisticMetrics`` subclass restricted to the enabled metrics.

    Each boolean flag switches one metric name on or off; ``coe`` controls the
    ``"coefficient_of_variation"`` name. The parameter names shadow builtins
    (``sum``, ``min``, ``max``) because they are part of the public signature.

    Returns:
        A new type named ``"StatisticMetrics"`` whose ``choice`` attribute is
        the set of enabled metric names.
    """
    choice = {"count": count, "sum": sum, "max": max, "min": min, "mean": mean, "median": median,
              "std": std, "var": var, "coefficient_of_variation": coe,
              "missing_count": missing_count, "missing_ratio": missing_ratio,
              "skewness": skewness, "kurtosis": kurtosis,
              # The flag used to be accepted but silently ignored, so
              # "describe" could never be a valid choice.
              "describe": describe}
    namespace = dict(
        choice={k for k, v in choice.items() if v},
    )
    return type("StatisticMetrics", (StatisticMetrics,), namespace)


def metrics_param(auc=True, ks=True, accuracy=True, mse=True) -> Type[str]:
    """Build a ``Metrics`` subclass restricted to the enabled metric names.

    Each boolean flag switches the metric of the same name on or off.

    Returns:
        A new type named ``"Metrics"`` whose ``choice`` attribute is the set
        of enabled metric names.
    """
    switches = (("auc", auc), ("ks", ks), ("accuracy", accuracy), ("mse", mse))
    enabled = {metric_name for metric_name, is_on in switches if is_on}
    return type("Metrics", (Metrics,), {"choice": enabled})
16 changes: 16 additions & 0 deletions python/fate/ml/statistics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#
# Copyright 2023 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .statistics import FeatureStatistics
Loading

0 comments on commit 46f9b14

Please sign in to comment.