Skip to content

Commit

Permalink
Fix univariate plots
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasMeissnerDS committed Jan 11, 2024
1 parent 2f8f705 commit c9e24bf
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 13 deletions.
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,6 @@ train_data = feat_type_detector.fit_transform_feature_types(train_data)
# show univariate plots
univariate_plots(
train_data.loc[:, feat_type_detector.num_columns], # here the target column EC1 is already included
"EC1",
)

# show bi-variate plots
Expand Down
12 changes: 3 additions & 9 deletions bluecast/eda/analyse.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,14 @@
from sklearn.manifold import TSNE


def univariate_plots(df: pd.DataFrame, target: str) -> None:
def univariate_plots(df: pd.DataFrame) -> None:
"""
Plots univariate plots for all the columns in the dataframe.
The target column must be part of the provided DataFrame.
Draws a univariate plot for every column in the DataFrame.
The target column does not need to be part of the provided DataFrame.
Expects numeric columns only.
"""
if target not in df.columns.to_list():
raise ValueError("Target column must be part of the provided DataFrame")
for col in df.columns:
# Check if the col is the target column (EC1 or EC2)
if col == target:
continue # Skip target columns in univariate analysis

plt.figure(figsize=(8, 4))

# Histogram
Expand Down
3 changes: 1 addition & 2 deletions bluecast/tests/test_analyse.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,7 @@ def test_univariate_plots(synthetic_train_test_data):
univariate_plots(
synthetic_train_test_data[0].loc[
:, ["numerical_feature_1", "numerical_feature_2", "numerical_feature_3"]
],
"target",
]
)
assert True

Expand Down
1 change: 0 additions & 1 deletion docs/source/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,6 @@ train_data = feat_type_detector.fit_transform_feature_types(train_data)
# show univariate plots
univariate_plots(
train_data.loc[:, feat_type_detector.num_columns], # here the target column EC1 is already included
"EC1",
)

# show bi-variate plots
Expand Down

0 comments on commit c9e24bf

Please sign in to comment.