From c9e24bfac4ef0778b96b97adf57b581193d52c07 Mon Sep 17 00:00:00 2001 From: ThomasMeissnerDS Date: Thu, 11 Jan 2024 06:10:39 +0100 Subject: [PATCH] Fix univariate plots --- README.md | 1 - bluecast/eda/analyse.py | 12 +++--------- bluecast/tests/test_analyse.py | 3 +-- docs/source/index.md | 1 - 4 files changed, 4 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 64f2ff4f..138538f3 100644 --- a/README.md +++ b/README.md @@ -148,7 +148,6 @@ train_data = feat_type_detector.fit_transform_feature_types(train_data) # show univariate plots univariate_plots( train_data.loc[:, feat_type_detector.num_columns], # here the target column EC1 is already included - "EC1", ) # show bi-variate plots diff --git a/bluecast/eda/analyse.py b/bluecast/eda/analyse.py index 197b6761..68739cc6 100644 --- a/bluecast/eda/analyse.py +++ b/bluecast/eda/analyse.py @@ -12,20 +12,14 @@ from sklearn.manifold import TSNE -def univariate_plots(df: pd.DataFrame, target: str) -> None: +def univariate_plots(df: pd.DataFrame) -> None: """ - Plots univariate plots for all the columns in the dataframe. - The target column must be part of the provided DataFrame. + Plots univariate plots for all the columns in the dataframe. Only numerical columns are expected. + The target column does not need to be part of the provided DataFrame.
""" - if target not in df.columns.to_list(): - raise ValueError("Target column must be part of the provided DataFrame") for col in df.columns: - # Check if the col is the target column (EC1 or EC2) - if col == target: - continue # Skip target columns in univariate analysis - plt.figure(figsize=(8, 4)) # Histogram diff --git a/bluecast/tests/test_analyse.py b/bluecast/tests/test_analyse.py index b223197d..1fae9ac9 100644 --- a/bluecast/tests/test_analyse.py +++ b/bluecast/tests/test_analyse.py @@ -80,8 +80,7 @@ def test_univariate_plots(synthetic_train_test_data): univariate_plots( synthetic_train_test_data[0].loc[ :, ["numerical_feature_1", "numerical_feature_2", "numerical_feature_3"] - ], - "target", + ] ) assert True diff --git a/docs/source/index.md b/docs/source/index.md index 64f2ff4f..138538f3 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -148,7 +148,6 @@ train_data = feat_type_detector.fit_transform_feature_types(train_data) # show univariate plots univariate_plots( train_data.loc[:, feat_type_detector.num_columns], # here the target column EC1 is already included - "EC1", ) # show bi-variate plots