Skip to content

Commit

Permalink
feat: suggest similar columns if column gets accessed that doesn't exist (#385)
Browse files Browse the repository at this point in the history

Closes #203

### Summary of Changes

* Added a method `_get_similar_columns` to find columns with a name
similar to a given name.
* Updated `UnknownColumnNameError` to allow for passing name
suggestions.
* Used these suggestions in `get_column`, `keep_only_columns`,
`remove_columns`, `rename_column`, `replace_column`, `transform_column`,
`plot_lineplot`, `plot_scatterplot`.

Co-authored-by: jxnior01 <129027012+jxnior01@users.noreply.github.com>
  • Loading branch information
robmeth and jxnior01 authored Jul 13, 2023
1 parent 5cfba79 commit 6a097a4
Show file tree
Hide file tree
Showing 6 changed files with 374 additions and 9 deletions.
222 changes: 221 additions & 1 deletion poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ scikit-learn = "^1.2.0"
seaborn = "^0.12.2"
openpyxl = "^3.1.2"
scikit-image = "^0.21.0"
levenshtein = "^0.21.1"

[tool.poetry.group.dev.dependencies]
pytest = "^7.2.1"
Expand Down
57 changes: 51 additions & 6 deletions src/safeds/data/tabular/containers/_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeVar

import Levenshtein
import matplotlib.pyplot as plt
import numpy as np
import openpyxl
Expand Down Expand Up @@ -597,7 +598,8 @@ def get_column(self, column_name: str) -> Column:
Column('b', [2])
"""
if not self.has_column(column_name):
raise UnknownColumnNameError([column_name])
similar_columns = self._get_similar_columns(column_name)
raise UnknownColumnNameError([column_name], similar_columns)

return Column._from_pandas_series(
self._data[column_name],
Expand Down Expand Up @@ -695,6 +697,34 @@ def get_row(self, index: int) -> Row:

return Row._from_pandas_dataframe(self._data.iloc[[index]], self._schema)

def _get_similar_columns(self, column_name: str) -> list[str]:
    """
    Get all the column names in a Table that are similar to a given name.

    Similarity is measured with the Jaro-Winkler metric. The search starts with a threshold of 0.6. Whenever a
    threshold yields four or more matches, the threshold is raised by 0.1 (up to at most 0.9) and the search is
    repeated, so that only the closest names are suggested.

    Parameters
    ----------
    column_name : str
        The name to compare the Table's column names to.

    Returns
    -------
    similar_columns: list[str]
        A list of all column names in the Table that are similar or equal to the given column name.
    """
    column_names = self.column_names
    max_matches_before_tightening = 4

    similar_columns: list[str] = []
    # The threshold is tracked in integer tenths. Accumulating 0.1 in a float drifts (0.6 + 0.1 + 0.1 + 0.1 is
    # slightly below 0.9), which would permit one extra tightening step to ~1.0 and discard every non-exact match.
    for threshold_tenths in range(6, 10):
        threshold = threshold_tenths / 10
        similar_columns = [
            name for name in column_names if Levenshtein.jaro_winkler(name, column_name) >= threshold
        ]
        # Few enough matches: this threshold is selective enough. At the last threshold (0.9) the result is
        # returned regardless of how many names matched.
        if len(similar_columns) < max_matches_before_tightening:
            break

    return similar_columns

# ------------------------------------------------------------------------------------------------------------------
# Information
# ------------------------------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -1106,11 +1136,13 @@ def keep_only_columns(self, column_names: list[str]) -> Table:
1 4
"""
invalid_columns = []
similar_columns: list[str] = []
for name in column_names:
if not self._schema.has_column(name):
similar_columns = similar_columns + self._get_similar_columns(name)
invalid_columns.append(name)
if len(invalid_columns) != 0:
raise UnknownColumnNameError(invalid_columns)
raise UnknownColumnNameError(invalid_columns, similar_columns)

clone = self._copy()
clone = clone.remove_columns(list(set(self.column_names) - set(column_names)))
Expand Down Expand Up @@ -1151,11 +1183,13 @@ def remove_columns(self, column_names: list[str]) -> Table:
1 3
"""
invalid_columns = []
similar_columns: list[str] = []
for name in column_names:
if not self._schema.has_column(name):
similar_columns = similar_columns + self._get_similar_columns(name)
invalid_columns.append(name)
if len(invalid_columns) != 0:
raise UnknownColumnNameError(invalid_columns)
raise UnknownColumnNameError(invalid_columns, similar_columns)

transformed_data = self._data.drop(labels=column_names, axis="columns")
transformed_data.columns = [name for name in self._schema.column_names if name not in column_names]
Expand Down Expand Up @@ -1349,7 +1383,8 @@ def rename_column(self, old_name: str, new_name: str) -> Table:
0 1 2
"""
if old_name not in self._schema.column_names:
raise UnknownColumnNameError([old_name])
similar_columns = self._get_similar_columns(old_name)
raise UnknownColumnNameError([old_name], similar_columns)
if old_name == new_name:
return self
if new_name in self._schema.column_names:
Expand Down Expand Up @@ -1401,7 +1436,8 @@ def replace_column(self, old_column_name: str, new_columns: list[Column]) -> Tab
0 1 3
"""
if old_column_name not in self._schema.column_names:
raise UnknownColumnNameError([old_column_name])
similar_columns = self._get_similar_columns(old_column_name)
raise UnknownColumnNameError([old_column_name], similar_columns)

columns = list[Column]()
for old_column in self.column_names:
Expand Down Expand Up @@ -1705,7 +1741,8 @@ def transform_column(self, name: str, transformer: Callable[[Row], Any]) -> Tabl
items: list = [transformer(item) for item in self.to_rows()]
result: list[Column] = [Column(name, items)]
return self.replace_column(name, result)
raise UnknownColumnNameError([name])
similar_columns = self._get_similar_columns(name)
raise UnknownColumnNameError([name], similar_columns)

def transform_table(self, transformer: TableTransformer) -> Table:
"""
Expand Down Expand Up @@ -1881,9 +1918,13 @@ def plot_lineplot(self, x_column_name: str, y_column_name: str) -> Image:
>>> image = table.plot_lineplot("temperature", "sales")
"""
if not self.has_column(x_column_name) or not self.has_column(y_column_name):
similar_columns_x = self._get_similar_columns(x_column_name)
similar_columns_y = self._get_similar_columns(y_column_name)
raise UnknownColumnNameError(
([x_column_name] if not self.has_column(x_column_name) else [])
+ ([y_column_name] if not self.has_column(y_column_name) else []),
(similar_columns_x if not self.has_column(x_column_name) else [])
+ (similar_columns_y if not self.has_column(y_column_name) else []),
)

fig = plt.figure()
Expand Down Expand Up @@ -1935,9 +1976,13 @@ def plot_scatterplot(self, x_column_name: str, y_column_name: str) -> Image:
>>> image = table.plot_scatterplot("temperature", "sales")
"""
if not self.has_column(x_column_name) or not self.has_column(y_column_name):
similar_columns_x = self._get_similar_columns(x_column_name)
similar_columns_y = self._get_similar_columns(y_column_name)
raise UnknownColumnNameError(
([x_column_name] if not self.has_column(x_column_name) else [])
+ ([y_column_name] if not self.has_column(y_column_name) else []),
(similar_columns_x if not self.has_column(x_column_name) else [])
+ (similar_columns_y if not self.has_column(y_column_name) else []),
)

fig = plt.figure()
Expand Down
15 changes: 13 additions & 2 deletions src/safeds/exceptions/_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,19 @@ class UnknownColumnNameError(KeyError):
The name of the column that was tried to be accessed.
"""

def __init__(self, column_names: list[str]):
super().__init__(f"Could not find column(s) '{', '.join(column_names)}'")
def __init__(self, column_names: list[str], similar_columns: list[str] | None = None):
    # A KeyError displays the repr() of its argument, which would escape the newline in the message.
    # Wrapping the text in a str subclass whose repr is the plain text keeps the line break intact.
    # See https://stackoverflow.com/a/70114007
    class _UnknownColumnNameErrorMessage(str):
        def __repr__(self) -> str:
            return str(self)

    message = f"Could not find column(s) '{', '.join(column_names)}'."

    # Only append the hint when at least one suggestion was supplied (None and [] both mean "no hint").
    if similar_columns:
        message = message + f"\nDid you mean '{similar_columns}'?"

    super().__init__(_UnknownColumnNameErrorMessage(message))


class NonNumericColumnError(Exception):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import pytest
from safeds.data.tabular.containers import Table
from safeds.exceptions._data import UnknownColumnNameError


@pytest.mark.parametrize(
    ("table", "column_name", "expected"),
    [
        pytest.param(
            Table({"column1": ["col1_1"], "x": ["y"], "cilumn2": ["cil2_1"]}),
            "col1",
            ["column1"],
            id="one similar",
        ),
        pytest.param(
            Table(
                {
                    "column1": ["col1_1"],
                    "col2": ["col2_1"],
                    "col3": ["col2_1"],
                    "col4": ["col2_1"],
                    "cilumn2": ["cil2_1"],
                },
            ),
            "clumn1",
            ["column1", "cilumn2"],
            id="two similar/ dynamic increase",
        ),
        pytest.param(
            Table({"column1": ["a"], "column2": ["b"], "column3": ["c"]}),
            "notexisting",
            [],
            id="no similar",
        ),
        pytest.param(
            Table({"column1": ["col1_1"], "x": ["y"], "cilumn2": ["cil2_1"]}),
            "x",
            ["x"],
            id="exact match",
        ),
        pytest.param(Table({}), "column1", [], id="empty table"),
    ],
)
def test_should_get_similar_column_names(table: Table, column_name: str, expected: list[str]) -> None:
    """Check that `_get_similar_columns` returns exactly the expected suggestions, in table order."""
    suggestions = table._get_similar_columns(column_name)
    assert suggestions == expected


def test_should_raise_error_if_column_name_unknown() -> None:
    """Check that the error message lists the unknown column and the suggested names."""
    # `match` is applied as a regular expression search, so the literal "." and "?" of the message are escaped.
    # Left unescaped, "'?" would merely make the closing quote optional and never check the question mark.
    with pytest.raises(
        UnknownColumnNameError,
        match=r"Could not find column\(s\) 'col3'\.\nDid you mean '\['col1', 'col2'\]'\?",
    ):
        raise UnknownColumnNameError(["col3"], ["col1", "col2"])
42 changes: 42 additions & 0 deletions tests/safeds/exceptions/test_unknown_column_name_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import pytest
from safeds.exceptions import UnknownColumnNameError


@pytest.mark.parametrize(
    ("column_names", "similar_columns", "expected_error_message"),
    [
        pytest.param(
            ["column1"],
            [],
            r"Could not find column\(s\) 'column1'\.",
            id="one_unknown_no_suggestions",
        ),
        pytest.param(
            ["column1", "column2"],
            [],
            r"Could not find column\(s\) 'column1, column2'\.",
            id="two_unknown_no_suggestions",
        ),
        pytest.param(
            ["column1"],
            ["column_a"],
            r"Could not find column\(s\) 'column1'\.\nDid you mean '\['column_a'\]'\?",
            id="one_unknown_one_suggestion",
        ),
        pytest.param(
            ["column1", "column2"],
            ["column_a"],
            r"Could not find column\(s\) 'column1, column2'\.\nDid you mean '\['column_a'\]'\?",
            id="two_unknown_one_suggestion",
        ),
        pytest.param(
            ["column1"],
            ["column_a", "column_b"],
            r"Could not find column\(s\) 'column1'\.\nDid you mean '\['column_a', 'column_b'\]'\?",
            id="one_unknown_two_suggestions",
        ),
        pytest.param(
            ["column1", "column2"],
            ["column_a", "column_b"],
            r"Could not find column\(s\) 'column1, column2'\.\nDid you mean '\['column_a', 'column_b'\]'\?",
            id="two_unknown_two_suggestions",
        ),
    ],
)
def test_empty_similar_columns(
    column_names: list[str],
    similar_columns: list[str],
    expected_error_message: str,
) -> None:
    """Check the error message for each combination of unknown column names and suggestions."""
    with pytest.raises(UnknownColumnNameError, match=expected_error_message):
        raise UnknownColumnNameError(column_names, similar_columns)

0 comments on commit 6a097a4

Please sign in to comment.