Skip to content

Commit

Permalink
feat: implement violin plots (#900)
Browse files Browse the repository at this point in the history
Closes #867 

### Summary of Changes

This pull request applies the changes, which implement violin plots for
tables and columns, while also adding corresponding tests as well as
adding a section in data_visualization.ipynb.

---------

Co-authored-by: megalinter-bot <129584137+megalinter-bot@users.noreply.github.com>
Co-authored-by: Lars Reimann <mail@larsreimann.com>
  • Loading branch information
3 people committed Jul 18, 2024
1 parent 5a0cdb3 commit 9f5992a
Show file tree
Hide file tree
Showing 23 changed files with 320 additions and 19 deletions.
70 changes: 64 additions & 6 deletions docs/tutorials/data_visualization.ipynb

Large diffs are not rendered by default.

68 changes: 67 additions & 1 deletion src/safeds/data/tabular/plotting/_column_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def box_plot(self, *, theme: Literal["dark", "light"] = "light") -> Image:
"""
if self._column.row_count > 0:
_check_column_is_numeric(self._column, operation="create a box plot")

import matplotlib.pyplot as plt

def _set_boxplot_colors(box: dict, theme: str) -> None:
Expand Down Expand Up @@ -127,6 +126,73 @@ def _set_boxplot_colors(box: dict, theme: str) -> None:

return _figure_to_image(fig)

def violin_plot(self, *, theme: Literal["dark", "light"] = "light") -> Image:
"""
Create a violin plot for the values in the column. This is only possible for numeric columns.
Parameters
----------
theme:
The color theme of the plot. Default is "light".
Returns
-------
plot:
The violin plot as an image.
Raises
------
TypeError
If the column is not numeric.
Examples
--------
>>> from safeds.data.tabular.containers import Column
>>> column = Column("test", [1, 2, 3])
>>> violinplot = column.plot.violin_plot()
"""
if self._column.row_count > 0:
_check_column_is_numeric(self._column, operation="create a violin plot")
from math import nan

import matplotlib.pyplot as plt

style = "dark_background" if theme == "dark" else "default"
with plt.style.context(style):
if theme == "dark":
plt.rcParams.update(
{
"text.color": "white",
"axes.labelcolor": "white",
"axes.edgecolor": "white",
"xtick.color": "white",
"ytick.color": "white",
"grid.color": "gray",
"grid.linewidth": 0.5,
},
)
else:
plt.rcParams.update(
{
"grid.linewidth": 0.5,
},
)

fig, ax = plt.subplots()
data = self._column._series.drop_nulls()
if len(data) == 0:
data = [nan, nan]
ax.violinplot(
data,
)

ax.set(title=self._column.name)

ax.yaxis.grid(visible=True)
fig.tight_layout()

return _figure_to_image(fig)

def histogram(self, *, max_bin_count: int = 10, theme: Literal["dark", "light"] = "light") -> Image:
"""
Create a histogram for the values in the column.
Expand Down
106 changes: 94 additions & 12 deletions src/safeds/data/tabular/plotting/_table_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,20 +119,102 @@ def box_plots(self, *, theme: Literal["dark", "light"] = "light") -> Image:
fig.delaxes(axs[number_of_rows - 1, i])

fig.tight_layout()
return _figure_to_image(fig)

def violin_plots(self, *, theme: Literal["dark", "light"] = "light") -> Image:
"""
Create a violin plot for every numerical column.
Parameters
----------
theme:
The color theme of the plot. Default is "light".
Returns
-------
plot:
The violin plot(s) as an image.
Raises
------
NonNumericColumnError
If the table contains only non-numerical columns.
Examples
--------
>>> from safeds.data.tabular.containers import Table
>>> table = Table({"a": [1, 2], "b": [3, 42]})
>>> image = table.plot.violin_plots()
"""
numerical_table = self._table.remove_non_numeric_columns()
if numerical_table.column_count == 0:
raise NonNumericColumnError("This table contains only non-numerical columns.")
from math import ceil

import matplotlib.pyplot as plt

style = "dark_background" if theme == "dark" else "default"
with plt.style.context(style):
if theme == "dark":
plt.rcParams.update(
{
"text.color": "white",
"axes.labelcolor": "white",
"axes.edgecolor": "white",
"xtick.color": "white",
"ytick.color": "white",
"grid.color": "gray",
"grid.linewidth": 0.5,
},
)
else:
plt.rcParams.update(
{
"grid.linewidth": 0.5,
},
)

columns = numerical_table.to_columns()
columns = [column._series.drop_nulls() for column in columns]
max_width = 3
number_of_columns = len(columns) if len(columns) <= max_width else max_width
number_of_rows = ceil(len(columns) / number_of_columns)

fig, axs = plt.subplots(nrows=number_of_rows, ncols=number_of_columns)
line = 0
for i, column in enumerate(columns):
data = column.to_list()

if i % number_of_columns == 0 and i != 0:
line += 1

if number_of_columns == 1:
axs.violinplot(
data,
)
axs.set_title(numerical_table.column_names[i])
break

style = "dark_background" if theme == "dark" else "default"
with plt.style.context(style):
if theme == "dark":
plt.rcParams.update(
{
"text.color": "white",
"axes.labelcolor": "white",
"axes.edgecolor": "white",
"xtick.color": "white",
"ytick.color": "white",
},
if number_of_rows == 1:
axs[i].violinplot(
data,
)
axs[i].set_title(numerical_table.column_names[i])

else:
axs[line, i % number_of_columns].violinplot(
data,
)
return _figure_to_image(fig)
axs[line, i % number_of_columns].set_title(numerical_table.column_names[i])

# removes unused ax indices, so there wont be empty plots
last_filled_ax_index = len(columns) % number_of_columns
for i in range(last_filled_ax_index, number_of_columns):
if number_of_rows != 1 and last_filled_ax_index != 0:
fig.delaxes(axs[number_of_rows - 1, i])

fig.tight_layout()
return _figure_to_image(fig)

def correlation_heatmap(self, *, theme: Literal["dark", "light"] = "light") -> Image:
"""
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import pytest
from safeds.data.tabular.containers import Column
from safeds.exceptions import ColumnTypeError
from syrupy import SnapshotAssertion


@pytest.mark.parametrize(
"column",
[
Column("a", []),
Column("a", [0]),
Column("a", [0, 1]),
],
ids=[
"empty",
"one row",
"multiple rows",
],
)
def test_should_match_snapshot(column: Column, snapshot_png_image: SnapshotAssertion) -> None:
violin_plot = column.plot.violin_plot()
assert violin_plot == snapshot_png_image


@pytest.mark.parametrize(
"column",
[
Column("a", []),
Column("a", [0]),
Column("a", [0, 1]),
],
ids=[
"empty",
"one row",
"multiple rows",
],
)
def test_should_match_dark_snapshot(column: Column, snapshot_png_image: SnapshotAssertion) -> None:
violin_plot = column.plot.violin_plot(theme="dark")
assert violin_plot == snapshot_png_image


def test_should_raise_if_column_contains_non_numerical_values() -> None:
column = Column("a", ["A", "B", "C"])
with pytest.raises(ColumnTypeError):
column.plot.violin_plot()
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
49 changes: 49 additions & 0 deletions tests/safeds/data/tabular/plotting/test_plot_violin_plots.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import pytest
from safeds.data.tabular.containers import Table
from safeds.exceptions import NonNumericColumnError
from syrupy import SnapshotAssertion


@pytest.mark.parametrize(
"table",
[
Table({"A": [1, 2, 3]}),
Table({"A": [1, 2, 3], "B": ["A", "A", "Bla"], "C": [True, True, False], "D": [1.0, 2.1, 4.5]}),
Table({"A": [1, 2, 3], "B": [1.0, 2.1, 4.5], "C": [1, 2, 3], "D": [1.0, 2.1, 4.5]}),
],
ids=["one column", "four columns (some non-numeric)", "four columns (all numeric)"],
)
def test_should_match_snapshot(table: Table, snapshot_png_image: SnapshotAssertion) -> None:
violinplots = table.plot.violin_plots()
assert violinplots == snapshot_png_image


@pytest.mark.parametrize(
"table",
[
Table({"A": [1, 2, 3]}),
Table({"A": [1, 2, 3], "B": ["A", "A", "Bla"], "C": [True, True, False], "D": [1.0, 2.1, 4.5]}),
Table({"A": [1, 2, 3], "B": [1.0, 2.1, 4.5], "C": [1, 2, 3], "D": [1.0, 2.1, 4.5]}),
],
ids=["one column", "four columns (some non-numeric)", "four columns (all numeric)"],
)
def test_should_match_dark_snapshot(table: Table, snapshot_png_image: SnapshotAssertion) -> None:
violinplots = table.plot.violin_plots(theme="dark")
assert violinplots == snapshot_png_image


def test_should_raise_if_column_contains_non_numerical_values() -> None:
table = Table.from_dict({"A": ["1", "2", "3.5"], "B": ["0.2", "4", "77"]})
with pytest.raises(
NonNumericColumnError,
match=(
r"Tried to do a numerical operation on one or multiple non-numerical columns: \nThis table contains only"
r" non-numerical columns."
),
):
table.plot.violin_plots()


def test_should_fail_on_empty_table() -> None:
with pytest.raises(NonNumericColumnError):
Table().plot.violin_plots()

0 comments on commit 9f5992a

Please sign in to comment.