Skip to content

Commit

Permalink
include axes as optional argument (#41)
Browse files Browse the repository at this point in the history
* include axes as optional argument
  • Loading branch information
pacanada authored Nov 21, 2022
1 parent d19dd16 commit e83165a
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 13 deletions.
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ repos:
hooks:
- id: black

- repo: https://gitlab.com/pycqa/flake8
rev: 3.9.2
- repo: https://github.com/PyCQA/flake8
rev: 5.0.4
hooks:
- id: flake8
additional_dependencies:
- id: flake8
additional_dependencies:
- flake8-unused-arguments
- flake8-bugbear
- pep8-naming
Expand Down
41 changes: 33 additions & 8 deletions dfds_ds_toolbox/analysis/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,20 @@
)


def plot_classification_proba_histogram(y_true: Sequence[int], y_pred: Sequence[float]) -> Figure:
def plot_classification_proba_histogram(
y_true: Sequence[int], y_pred: Sequence[float], ax: Axes = None
) -> Figure:
"""Plot histogram of predictions for binary classifiers.
Args:
y_true: 1D array of binary target values, 0 or 1.
y_pred: 1D array of predicted target values, probability of class 1.
ax: Optional pre-existing matplotlib Axes to plot on. If None, a new figure and Axes are created.
"""
fig, ax = plt.subplots()
if ax is None:
fig, ax = plt.subplots()
else:
fig = ax.get_figure()
bins = np.linspace(0, 1, 11)
df = pd.DataFrame()
df["Actual class"] = y_true
Expand Down Expand Up @@ -144,19 +150,25 @@ def get_trend_stats(


def plot_regression_predicted_vs_actual(
y_true: Sequence[float], y_pred: Sequence[float], alpha: float = 0.2
y_true: Sequence[float], y_pred: Sequence[float], alpha: float = 0.2, ax: Axes = None
) -> Figure:
"""Scatter plot of the predicted vs true targets for regression problems.
Args:
y_true: array with observed values
y_pred: array with predicted values
alpha: transparency of the dots on the scatter plot
ax: Optional pre-existing matplotlib Axes to plot on. If None, a new figure and Axes are created.
Returns:
Figure
"""
fig, ax = plt.subplots()

if ax is None:
fig, ax = plt.subplots()
else:
fig = ax.get_figure()
min_val = min(min(y_true), min(y_pred))
max_val = max(max(y_true), max(y_pred))
ax.plot([min_val, max_val], [min_val, max_val])
Expand Down Expand Up @@ -198,7 +210,9 @@ def plot_roc_curve(
return fig


def plot_lift_curve(y_true: Sequence[int], y_pred: Sequence[float], n_bins: int = 10) -> Figure:
def plot_lift_curve(
y_true: Sequence[int], y_pred: Sequence[float], n_bins: int = 10, ax: Axes = None
) -> Figure:
"""Plot lift curve, i.e. how much better than baserate is the model at different thresholds.
Lift of 1 corresponds to predicting the baserate for the whole sample.
Expand All @@ -207,11 +221,16 @@ def plot_lift_curve(y_true: Sequence[int], y_pred: Sequence[float], n_bins: int
y_true: array with observed values, either 0 or 1.
y_pred: array with predicted probabilities, float between 0 and 1.
n_bins: number of bins to use
ax: Optional pre-existing matplotlib Axes to plot on. If None, a new figure and Axes are created.
Returns:
matplotlib Figure
"""
fig, ax = plt.subplots()
if ax is None:
fig, ax = plt.subplots()
else:
fig = ax.get_figure()
# Ensure numpy arrays. Assign to new variables instead of rebinding the arguments, since mypy rejects changing a variable's declared type.
y_true_array = np.array(y_true)
y_pred_array = np.array(y_pred)
Expand All @@ -233,19 +252,25 @@ def plot_lift_curve(y_true: Sequence[int], y_pred: Sequence[float], n_bins: int
return fig


def plot_gain_chart(y_true: Sequence[int], y_pred: Sequence[float], n_bins: int = 10) -> Figure:
def plot_gain_chart(
y_true: Sequence[int], y_pred: Sequence[float], n_bins: int = 10, ax: Axes = None
) -> Figure:
"""The cumulative gains chart shows the percentage of the overall number of cases in a given
category "gained" by targeting a percentage of the total number of cases.
Args:
y_true: array with observed values, either 0 or 1.
y_pred: array with predicted probabilities, float between 0 and 1.
n_bins: number of bins to use
ax: Optional pre-existing matplotlib Axes to plot on. If None, a new figure and Axes are created.
Returns:
matplotlib Figure
"""
fig, ax = plt.subplots()
if ax is None:
fig, ax = plt.subplots()
else:
fig = ax.get_figure()
# Ensure numpy arrays. Assign to new variables instead of rebinding the arguments, since mypy rejects changing a variable's declared type.
y_true_array = np.array(y_true)
y_pred_array = np.array(y_pred)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "dfds_ds_toolbox"
version = "0.10.1"
version = "0.10.2"
description = "A toolbox for data science"
license = "MIT"
authors = ["Data Science Chapter at DFDS <urcha@dfds.com>"]
Expand Down

0 comments on commit e83165a

Please sign in to comment.