diff --git a/src/safeds/data/tabular/containers/_column.py b/src/safeds/data/tabular/containers/_column.py index c35d08bf0..9cfdbe402 100644 --- a/src/safeds/data/tabular/containers/_column.py +++ b/src/safeds/data/tabular/containers/_column.py @@ -76,7 +76,7 @@ def _from_pandas_series(data: pd.Series, type_: ColumnType | None = None) -> Col result._name = data.name result._data = data # noinspection PyProtectedMember - result._type = type_ if type_ is not None else ColumnType._from_numpy_data_type(data.dtype) + result._type = type_ if type_ is not None else ColumnType._data_type(data) return result @@ -106,7 +106,7 @@ def __init__(self, name: str, data: Sequence[T] | None = None) -> None: self._name: str = name self._data: pd.Series = data.rename(name) if isinstance(data, pd.Series) else pd.Series(data, name=name) # noinspection PyProtectedMember - self._type: ColumnType = ColumnType._from_numpy_data_type(self._data.dtype) + self._type: ColumnType = ColumnType._data_type(self._data) def __contains__(self, item: Any) -> bool: return item in self._data diff --git a/src/safeds/data/tabular/containers/_row.py b/src/safeds/data/tabular/containers/_row.py index 141f0503c..5f16ff810 100644 --- a/src/safeds/data/tabular/containers/_row.py +++ b/src/safeds/data/tabular/containers/_row.py @@ -1,7 +1,8 @@ from __future__ import annotations import copy -from collections.abc import Mapping +import functools +from collections.abc import Callable, Mapping from typing import TYPE_CHECKING, Any import pandas as pd @@ -441,6 +442,39 @@ def get_column_type(self, column_name: str) -> ColumnType: """ return self._schema.get_column_type(column_name) + # ------------------------------------------------------------------------------------------------------------------ + # Transformations + # ------------------------------------------------------------------------------------------------------------------ + + def sort_columns( + self, + comparator: Callable[[tuple, tuple], int] = lambda col1, col2: (col1[0] > col2[0]) - (col1[0] < col2[0]), + ) -> Row: + """ + Sort the columns of a `Row` with the given comparator and return a new `Row`. + + The original row is not modified. The comparator is a function that takes two tuples of (ColumnName, Value) `col1` and `col2` and + returns an integer: + + * If `col1` should be ordered before `col2`, the function should return a negative number. + * If `col1` should be ordered after `col2`, the function should return a positive number. + * If the original order of `col1` and `col2` should be kept, the function should return 0. + + If no comparator is given, the columns will be sorted alphabetically by their name. + + Parameters + ---------- + comparator : Callable[[tuple, tuple], int] + The function used to compare two tuples of (ColumnName, Value). + + Returns + ------- + new_row : Row + A new row with sorted columns. + """ + sorted_row_dict = dict(sorted(self.to_dict().items(), key=functools.cmp_to_key(comparator))) + return Row.from_dict(sorted_row_dict) + # ------------------------------------------------------------------------------------------------------------------ # Conversion # ------------------------------------------------------------------------------------------------------------------ diff --git a/src/safeds/data/tabular/containers/_table.py b/src/safeds/data/tabular/containers/_table.py index c7ce1142c..ae26a62bc 100644 --- a/src/safeds/data/tabular/containers/_table.py +++ b/src/safeds/data/tabular/containers/_table.py @@ -24,7 +24,6 @@ DuplicateColumnNameError, IndexOutOfBoundsError, NonNumericColumnError, - SchemaMismatchError, UnknownColumnNameError, WrongFileExtensionError, ) @@ -302,8 +301,8 @@ def from_rows(rows: list[Row]) -> Table: Raises ------ - SchemaMismatchError - If any of the row schemas does not match with the others. + UnknownColumnNameError + If any of the row column names does not match with the first row. Examples -------- @@ -318,17 +317,22 @@ def from_rows(rows: list[Row]) -> Table: if len(rows) == 0: return Table._from_pandas_dataframe(pd.DataFrame()) - schema_compare: Schema = rows[0]._schema + column_names_compare: list = list(rows[0].column_names) + unknown_column_names = set() row_array: list[pd.DataFrame] = [] for row in rows: - if schema_compare != row._schema: - raise SchemaMismatchError + unknown_column_names.update(set(column_names_compare) - set(row.column_names)) row_array.append(row._data) + if len(unknown_column_names) > 0: + raise UnknownColumnNameError(list(unknown_column_names)) dataframe: DataFrame = pd.concat(row_array, ignore_index=True) - dataframe.columns = schema_compare.column_names - return Table._from_pandas_dataframe(dataframe) + dataframe.columns = column_names_compare + + schema = Schema.merge_multiple_schemas([row.schema for row in rows]) + + return Table._from_pandas_dataframe(dataframe, schema) @staticmethod def _from_pandas_dataframe(data: pd.DataFrame, schema: Schema | None = None) -> Table: @@ -906,6 +910,9 @@ def add_row(self, row: Row) -> Table: If the table happens to be empty beforehand, respective columns will be added automatically. + The order of columns of the new row will be adjusted to the order of columns in the table. + The new table will contain the merged schema. + This table is not modified. Parameters @@ -920,8 +927,8 @@ def add_row(self, row: Row) -> Table: Raises ------ - SchemaMismatchError - If the schema of the row does not match the table schema. + UnknownColumnNameError + If the row has different column names than the table. Examples -------- @@ -935,20 +942,18 @@ def add_row(self, row: Row) -> Table: """ int_columns = [] result = self._copy() + if self.number_of_columns == 0: + return Table.from_rows([row]) + if len(set(self.column_names) - set(row.column_names)) > 0: + raise UnknownColumnNameError(list(set(self.column_names) - set(row.column_names))) + if result.number_of_rows == 0: - int_columns = list(filter(lambda name: isinstance(row[name], int | np.int64), row.column_names)) - if result.number_of_columns == 0: - for column in row.column_names: - result._data[column] = Column(column, []) - result._schema = Schema._from_pandas_dataframe(result._data) - elif result.column_names != row.column_names: - raise SchemaMismatchError - elif result._schema != row.schema: - raise SchemaMismatchError + int_columns = list(filter(lambda name: isinstance(row[name], int | np.int64 | np.int32), row.column_names)) new_df = pd.concat([result._data, row._data]).infer_objects() new_df.columns = result.column_names - result = Table._from_pandas_dataframe(new_df) + schema = Schema.merge_multiple_schemas([result.schema, row.schema]) + result = Table._from_pandas_dataframe(new_df, schema) for column in int_columns: result = result.replace_column(column, [result.get_column(column).transform(lambda it: int(it))]) @@ -959,6 +964,9 @@ def add_rows(self, rows: list[Row] | Table) -> Table: """ Add multiple rows to a table. + The order of columns of the new rows will be adjusted to the order of columns in the table. + The new table will contain the merged schema. + This table is not modified. Parameters @@ -973,8 +981,8 @@ def add_rows(self, rows: list[Row] | Table) -> Table: Raises ------ - SchemaMismatchError - If the schema of one of the rows does not match the table schema. + UnknownColumnNameError + If at least one of the rows have different column names than the table. Examples -------- @@ -990,28 +998,21 @@ def add_rows(self, rows: list[Row] | Table) -> Table: """ if isinstance(rows, Table): rows = rows.to_rows() - int_columns = [] result = self._copy() + + if len(rows) == 0: + return self._copy() + + different_column_names = set() for row in rows: - if result.number_of_rows == 0: - int_columns = list(filter(lambda name: isinstance(row[name], int | np.int64), row.column_names)) - if result.number_of_columns == 0: - for column in row.column_names: - result._data[column] = Column(column, []) - result._schema = Schema._from_pandas_dataframe(result._data) - elif result.column_names != row.column_names: - raise SchemaMismatchError - elif result._schema != row.schema: - raise SchemaMismatchError - - row_frames = (row._data for row in rows) - - new_df = pd.concat([result._data, *row_frames]).infer_objects() - new_df.columns = result.column_names - result = Table._from_pandas_dataframe(new_df) + different_column_names.update(set(rows[0].column_names) - set(row.column_names)) + if len(different_column_names) > 0: + raise UnknownColumnNameError(list(different_column_names)) - for column in int_columns: - result = result.replace_column(column, [result.get_column(column).transform(lambda it: int(it))]) + result = self._copy() + + for row in rows: + result = result.add_row(row) return result @@ -1269,7 +1270,7 @@ def remove_rows_with_missing_values(self) -> Table: """ result = self._data.copy(deep=True) result = result.dropna(axis="index") - return Table._from_pandas_dataframe(result, self._schema) + return Table._from_pandas_dataframe(result) def remove_rows_with_outliers(self) -> Table: """ diff --git a/src/safeds/data/tabular/transformation/_label_encoder.py b/src/safeds/data/tabular/transformation/_label_encoder.py index 20fec0436..089cdee4f 100644 --- a/src/safeds/data/tabular/transformation/_label_encoder.py +++ b/src/safeds/data/tabular/transformation/_label_encoder.py @@ -152,6 +152,9 @@ def inverse_transform(self, transformed_table: Table) -> Table: if len(missing_columns) > 0: raise UnknownColumnNameError(missing_columns) + if transformed_table.number_of_rows == 0: + raise ValueError("The LabelEncoder cannot inverse transform the table because it contains 0 rows") + if transformed_table.keep_only_columns( self._column_names, ).remove_columns_with_non_numerical_values().number_of_columns < len(self._column_names): @@ -168,9 +171,6 @@ def inverse_transform(self, transformed_table: Table) -> Table: ), ) - if transformed_table.number_of_rows == 0: - raise ValueError("The LabelEncoder cannot inverse transform the table because it contains 0 rows") - data = transformed_table._data.copy() data.columns = transformed_table.column_names data[self._column_names] = self._wrapped_transformer.inverse_transform(data[self._column_names]) diff --git a/src/safeds/data/tabular/transformation/_one_hot_encoder.py b/src/safeds/data/tabular/transformation/_one_hot_encoder.py index 9a198ea0e..cfad36889 100644 --- a/src/safeds/data/tabular/transformation/_one_hot_encoder.py +++ b/src/safeds/data/tabular/transformation/_one_hot_encoder.py @@ -277,6 +277,9 @@ def inverse_transform(self, transformed_table: Table) -> Table: if len(missing_columns) > 0: raise UnknownColumnNameError(missing_columns) + if transformed_table.number_of_rows == 0: + raise ValueError("The OneHotEncoder cannot inverse transform the table because it contains 0 rows") + if transformed_table._as_table().keep_only_columns( _transformed_column_names, ).remove_columns_with_non_numerical_values().number_of_columns < len(_transformed_column_names): @@ -293,9 +296,6 @@ def inverse_transform(self, transformed_table: Table) -> Table: ), ) - if transformed_table.number_of_rows == 0: - raise ValueError("The OneHotEncoder cannot inverse transform the table because it contains 0 rows") - original_columns = {} for original_column_name in self._column_names: original_columns[original_column_name] = [None for _ in range(transformed_table.number_of_rows)] @@ -306,6 +306,12 @@ def inverse_transform(self, transformed_table: Table) -> Table: if transformed_table.get_column(constructed_column)[i] == 1.0: original_columns[original_column_name][i] = value + for original_column_name in self._value_to_column_nans: + constructed_column = self._value_to_column_nans[original_column_name] + for i in range(transformed_table.number_of_rows): + if transformed_table.get_column(constructed_column)[i] == 1.0: + original_columns[original_column_name][i] = np.nan + table = transformed_table for column_name, encoded_column in original_columns.items(): diff --git a/src/safeds/data/tabular/transformation/_range_scaler.py b/src/safeds/data/tabular/transformation/_range_scaler.py index 91173ea49..d61175296 100644 --- a/src/safeds/data/tabular/transformation/_range_scaler.py +++ b/src/safeds/data/tabular/transformation/_range_scaler.py @@ -66,6 +66,9 @@ def fit(self, table: Table, column_names: list[str] | None) -> RangeScaler: if len(missing_columns) > 0: raise UnknownColumnNameError(missing_columns) + if table.number_of_rows == 0: + raise ValueError("The RangeScaler cannot be fitted because the table contains 0 rows") + if ( table.keep_only_columns(column_names).remove_columns_with_non_numerical_values().number_of_columns < table.keep_only_columns(column_names).number_of_columns @@ -83,9 +86,6 @@ def fit(self, table: Table, column_names: list[str] | None) -> RangeScaler: ), ) - if table.number_of_rows == 0: - raise ValueError("The RangeScaler cannot be fitted because the table contains 0 rows") - wrapped_transformer = sk_MinMaxScaler((self._minimum, self._maximum)) wrapped_transformer.fit(table._data[column_names]) @@ -131,6 +131,9 @@ def transform(self, table: Table) -> Table: if len(missing_columns) > 0: raise UnknownColumnNameError(missing_columns) + if table.number_of_rows == 0: + raise ValueError("The RangeScaler cannot transform the table because it contains 0 rows") + if ( table.keep_only_columns(self._column_names).remove_columns_with_non_numerical_values().number_of_columns < table.keep_only_columns(self._column_names).number_of_columns @@ -148,9 +151,6 @@ def transform(self, table: Table) -> Table: ), ) - if table.number_of_rows == 0: - raise ValueError("The RangeScaler cannot transform the table because it contains 0 rows") - data = table._data.copy() data.columns = table.column_names data[self._column_names] = self._wrapped_transformer.transform(data[self._column_names]) @@ -191,6 +191,9 @@ def inverse_transform(self, transformed_table: Table) -> Table: if len(missing_columns) > 0: raise UnknownColumnNameError(missing_columns) + if transformed_table.number_of_rows == 0: + raise ValueError("The RangeScaler cannot transform the table because it contains 0 rows") + if ( transformed_table.keep_only_columns(self._column_names) .remove_columns_with_non_numerical_values() @@ -210,9 +213,6 @@ def inverse_transform(self, transformed_table: Table) -> Table: ), ) - if transformed_table.number_of_rows == 0: - raise ValueError("The RangeScaler cannot transform the table because it contains 0 rows") - data = transformed_table._data.copy() data.columns = transformed_table.column_names data[self._column_names] = self._wrapped_transformer.inverse_transform(data[self._column_names]) diff --git a/src/safeds/data/tabular/transformation/_standard_scaler.py b/src/safeds/data/tabular/transformation/_standard_scaler.py index 3c190c58f..a3b213a11 100644 --- a/src/safeds/data/tabular/transformation/_standard_scaler.py +++ b/src/safeds/data/tabular/transformation/_standard_scaler.py @@ -48,6 +48,9 @@ def fit(self, table: Table, column_names: list[str] | None) -> StandardScaler: if len(missing_columns) > 0: raise UnknownColumnNameError(missing_columns) + if table.number_of_rows == 0: + raise ValueError("The StandardScaler cannot be fitted because the table contains 0 rows") + if ( table.keep_only_columns(column_names).remove_columns_with_non_numerical_values().number_of_columns < table.keep_only_columns(column_names).number_of_columns @@ -65,9 +68,6 @@ def fit(self, table: Table, column_names: list[str] | None) -> StandardScaler: ), ) - if table.number_of_rows == 0: - raise ValueError("The StandardScaler cannot be fitted because the table contains 0 rows") - wrapped_transformer = sk_StandardScaler() wrapped_transformer.fit(table._data[column_names]) @@ -113,6 +113,9 @@ def transform(self, table: Table) -> Table: if len(missing_columns) > 0: raise UnknownColumnNameError(missing_columns) + if table.number_of_rows == 0: + raise ValueError("The StandardScaler cannot transform the table because it contains 0 rows") + if ( table.keep_only_columns(self._column_names).remove_columns_with_non_numerical_values().number_of_columns < table.keep_only_columns(self._column_names).number_of_columns @@ -130,9 +133,6 @@ def transform(self, table: Table) -> Table: ), ) - if table.number_of_rows == 0: - raise ValueError("The StandardScaler cannot transform the table because it contains 0 rows") - data = table._data.copy() data.columns = table.column_names data[self._column_names] = self._wrapped_transformer.transform(data[self._column_names]) @@ -173,6 +173,9 @@ def inverse_transform(self, transformed_table: Table) -> Table: if len(missing_columns) > 0: raise UnknownColumnNameError(missing_columns) + if transformed_table.number_of_rows == 0: + raise ValueError("The StandardScaler cannot transform the table because it contains 0 rows") + if ( transformed_table.keep_only_columns(self._column_names) .remove_columns_with_non_numerical_values() @@ -192,9 +195,6 @@ def inverse_transform(self, transformed_table: Table) -> Table: ), ) - if transformed_table.number_of_rows == 0: - raise ValueError("The StandardScaler cannot transform the table because it contains 0 rows") - data = transformed_table._data.copy() data.columns = transformed_table.column_names data[self._column_names] = self._wrapped_transformer.inverse_transform(data[self._column_names]) diff --git a/src/safeds/data/tabular/typing/__init__.py b/src/safeds/data/tabular/typing/__init__.py index 9a19c2b5d..8b9b4a849 100644 --- a/src/safeds/data/tabular/typing/__init__.py +++ b/src/safeds/data/tabular/typing/__init__.py @@ -1,6 +1,6 @@ """Types used to define the schema of a tabular dataset.""" -from ._column_type import Anything, Boolean, ColumnType, Integer, RealNumber, String +from ._column_type import Anything, Boolean, ColumnType, Integer, Nothing, RealNumber, String from ._imputer_strategy import ImputerStrategy from ._schema import Schema @@ -10,6 +10,7 @@ "ColumnType", "ImputerStrategy", "Integer", + "Nothing", "RealNumber", "Schema", "String", diff --git a/src/safeds/data/tabular/typing/_column_type.py b/src/safeds/data/tabular/typing/_column_type.py index 82396f334..2c4ca4f57 100644 --- a/src/safeds/data/tabular/typing/_column_type.py +++ b/src/safeds/data/tabular/typing/_column_type.py @@ -2,24 +2,40 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import TYPE_CHECKING +from types import NoneType +from typing import TYPE_CHECKING, Any + +import numpy as np if TYPE_CHECKING: - import numpy as np + import pandas as pd class ColumnType(ABC): """Abstract base class for column types.""" + _is_nullable: bool # This line is just here so the linter doesn't throw an error + + @abstractmethod + def __init__(self, is_nullable: bool = False) -> None: + """ + Abstract initializer for ColumnType. + + Parameters + ---------- + is_nullable + Whether the columntype is nullable. + """ + @staticmethod - def _from_numpy_data_type(data_type: np.dtype) -> ColumnType: + def _data_type(data: pd.Series) -> ColumnType: """ - Return the column type for a given `numpy` data type. + Return the column type for a given `Series` from `pandas`. Parameters ---------- - data_type : numpy.dtype - The `numpy` data type. + data : pd.Series + The data to be checked. Returns ------- @@ -31,17 +47,50 @@ def _from_numpy_data_type(data_type: np.dtype) -> ColumnType: NotImplementedError If the given data type is not supported. """ - if data_type.kind in ("u", "i"): - return Integer() - if data_type.kind == "b": - return Boolean() - if data_type.kind == "f": - return RealNumber() - if data_type.kind in ("S", "U", "O", "M", "m"): - return String() - message = f"Unsupported numpy data type '{data_type}'." - raise NotImplementedError(message) + def column_type_of_type(cell_type: Any) -> ColumnType: + if cell_type == int or cell_type == np.int64 or cell_type == np.int32: + return Integer(is_nullable) + if cell_type == float or cell_type == np.float64 or cell_type == np.float32: + return RealNumber(is_nullable) + if cell_type == bool: + return Boolean(is_nullable) + if cell_type == str: + return String(is_nullable) + if cell_type is NoneType: + return Nothing() + else: + message = f"Unsupported numpy data type '{cell_type}'." + raise NotImplementedError(message) + + result: ColumnType = Nothing() + is_nullable = False + for cell in data: + if result == Nothing(): + result = column_type_of_type(type(cell)) + if type(cell) is NoneType: + is_nullable = True + result._is_nullable = is_nullable + if result != column_type_of_type(type(cell)): + if type(cell) is NoneType: + is_nullable = True + result._is_nullable = is_nullable + elif (isinstance(result, Integer) and isinstance(column_type_of_type(type(cell)), RealNumber)) or ( + isinstance(result, RealNumber) and isinstance(column_type_of_type(type(cell)), Integer) + ): + result = RealNumber(is_nullable) + else: + result = Anything(is_nullable) + if isinstance(cell, float) and np.isnan(cell): + is_nullable = True + result._is_nullable = is_nullable + + if isinstance(result, RealNumber) and all( + data.apply(lambda c: bool(isinstance(c, float) and np.isnan(c) or c == float(int(c)))), + ): + result = Integer(is_nullable) + + return result @abstractmethod def is_nullable(self) -> bool: @@ -79,7 +128,7 @@ class Anything(ColumnType): _is_nullable: bool - def __init__(self, is_nullable: bool = False): + def __init__(self, is_nullable: bool = False) -> None: self._is_nullable = is_nullable def __repr__(self) -> str: @@ -124,7 +173,7 @@ class Boolean(ColumnType): _is_nullable: bool - def __init__(self, is_nullable: bool = False): + def __init__(self, is_nullable: bool = False) -> None: self._is_nullable = is_nullable def __repr__(self) -> str: @@ -169,7 +218,7 @@ class RealNumber(ColumnType): _is_nullable: bool - def __init__(self, is_nullable: bool = False): + def __init__(self, is_nullable: bool = False) -> None: self._is_nullable = is_nullable def __repr__(self) -> str: @@ -214,7 +263,7 @@ class Integer(ColumnType): _is_nullable: bool - def __init__(self, is_nullable: bool = False): + def __init__(self, is_nullable: bool = False) -> None: self._is_nullable = is_nullable def __repr__(self) -> str: @@ -259,7 +308,7 @@ class String(ColumnType): _is_nullable: bool - def __init__(self, is_nullable: bool = False): + def __init__(self, is_nullable: bool = False) -> None: self._is_nullable = is_nullable def __repr__(self) -> str: @@ -289,3 +338,41 @@ def is_numeric(self) -> bool: True if the column is numeric. """ return False + + +@dataclass +class Nothing(ColumnType): + """Type for a column that contains None Values only.""" + + _is_nullable: bool + + def __init__(self) -> None: + self._is_nullable = True + + def __repr__(self) -> str: + result = "Nothing" + if self._is_nullable: + result += "?" + return result + + def is_nullable(self) -> bool: + """ + Return whether the given column type is nullable. + + Returns + ------- + is_nullable : bool + True if the column is nullable. + """ + return True + + def is_numeric(self) -> bool: + """ + Return whether the given column type is numeric. + + Returns + ------- + is_numeric : bool + True if the column is numeric. + """ + return False diff --git a/src/safeds/data/tabular/typing/_schema.py b/src/safeds/data/tabular/typing/_schema.py index a75a87241..32a6bc3af 100644 --- a/src/safeds/data/tabular/typing/_schema.py +++ b/src/safeds/data/tabular/typing/_schema.py @@ -3,6 +3,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING +from safeds.data.tabular.typing import Anything, Integer, Nothing, RealNumber from safeds.data.tabular.typing._column_type import ColumnType from safeds.exceptions import UnknownColumnNameError @@ -49,7 +50,9 @@ def _from_pandas_dataframe(dataframe: pd.DataFrame) -> Schema: """ names = dataframe.columns # noinspection PyProtectedMember - types = (ColumnType._from_numpy_data_type(data_type) for data_type in dataframe.dtypes) + types = [] + for col in dataframe: + types.append(ColumnType._data_type(dataframe[col])) return Schema(dict(zip(names, types, strict=True))) @@ -95,7 +98,7 @@ def __repr__(self) -> str: >>> repr(schema) "Schema({'A': Integer})" """ - return f"Schema({str(self)})" + return f"Schema({self!s})" def __str__(self) -> str: """ @@ -220,6 +223,69 @@ def to_dict(self) -> dict[str, ColumnType]: """ return dict(self._schema) # defensive copy + @staticmethod + def merge_multiple_schemas(schemas: list[Schema]) -> Schema: + """ + Merge multiple schemas into one. + + For each type missmatch the new schema will have the least common supertype. + + The type hierarchy is as follows: + * Anything + * RealNumber + * Integer + * Boolean + * String + + Parameters + ---------- + schemas : list[Schema] + the list of schemas you want to merge + + Returns + ------- + schema : Schema + the new merged schema + + Raises + ------ + UnknownColumnNameError + if not all schemas have the same column names + """ + schema_dict = schemas[0]._schema + missing_col_names = set() + for schema in schemas: + missing_col_names.update(set(schema.column_names) - set(schema_dict.keys())) + if len(missing_col_names) > 0: + raise UnknownColumnNameError(list(missing_col_names)) + for schema in schemas: + if schema_dict != schema._schema: + for col_name in schema_dict: + nullable = False + if schema_dict[col_name].is_nullable() or schema.get_column_type(col_name).is_nullable(): + nullable = True + if isinstance(schema_dict[col_name], type(schema.get_column_type(col_name))): + if schema.get_column_type(col_name).is_nullable() and not schema_dict[col_name].is_nullable(): + schema_dict[col_name] = type(schema.get_column_type(col_name))(nullable) + continue + if ( + isinstance(schema_dict[col_name], RealNumber) + and isinstance(schema.get_column_type(col_name), Integer) + ) or ( + isinstance(schema_dict[col_name], Integer) + and isinstance(schema.get_column_type(col_name), RealNumber) + ): + schema_dict[col_name] = RealNumber(nullable) + continue + if isinstance(schema_dict[col_name], Nothing): + schema_dict[col_name] = type(schema.get_column_type(col_name))(nullable) + continue + if isinstance(schema.get_column_type(col_name), Nothing): + schema_dict[col_name] = type(schema_dict[col_name])(nullable) + continue + schema_dict[col_name] = Anything(nullable) + return Schema(schema_dict) + # ------------------------------------------------------------------------------------------------------------------ # IPython Integration # ------------------------------------------------------------------------------------------------------------------ diff --git a/src/safeds/exceptions/__init__.py b/src/safeds/exceptions/__init__.py index 021736287..352334cf5 100644 --- a/src/safeds/exceptions/__init__.py +++ b/src/safeds/exceptions/__init__.py @@ -9,7 +9,6 @@ IndexOutOfBoundsError, MissingValuesColumnError, NonNumericColumnError, - SchemaMismatchError, TransformerNotFittedError, UnknownColumnNameError, ValueNotPresentWhenFittedError, @@ -43,7 +42,6 @@ "IndexOutOfBoundsError", "MissingValuesColumnError", "NonNumericColumnError", - "SchemaMismatchError", "TransformerNotFittedError", "UnknownColumnNameError", "ValueNotPresentWhenFittedError", diff --git a/src/safeds/exceptions/_data.py b/src/safeds/exceptions/_data.py index f11c7a334..2d2fb7880 100644 --- a/src/safeds/exceptions/_data.py +++ b/src/safeds/exceptions/_data.py @@ -93,13 +93,6 @@ def __init__(self, expected_size: str, actual_size: str): super().__init__(f"Expected a column of size {expected_size} but got column of size {actual_size}.") -class SchemaMismatchError(Exception): - """Exception raised when schemas are unequal.""" - - def __init__(self) -> None: - super().__init__("Failed because at least two schemas didn't match.") - - class ColumnLengthMismatchError(Exception): """Exception raised when the lengths of two or more columns do not match.""" diff --git a/tests/safeds/data/tabular/containers/_column/test_from_pandas_series.py b/tests/safeds/data/tabular/containers/_column/test_from_pandas_series.py index 9946120b2..b62e3d1af 100644 --- a/tests/safeds/data/tabular/containers/_column/test_from_pandas_series.py +++ b/tests/safeds/data/tabular/containers/_column/test_from_pandas_series.py @@ -1,7 +1,7 @@ import pandas as pd import pytest from safeds.data.tabular.containers import Column -from safeds.data.tabular.typing import Boolean, ColumnType, Integer, RealNumber, String +from safeds.data.tabular.typing import Anything, Boolean, ColumnType, Integer, Nothing, RealNumber, String @pytest.mark.parametrize( @@ -35,14 +35,15 @@ def test_should_use_type_if_passed(series: pd.Series, type_: ColumnType) -> None @pytest.mark.parametrize( ("series", "expected"), [ - (pd.Series([]), String()), + (pd.Series([]), Nothing()), (pd.Series([True, False, True]), Boolean()), (pd.Series([1, 2, 3]), Integer()), - (pd.Series([1.0, 2.0, 3.0]), RealNumber()), + (pd.Series([1.0, 2.0, 3.0]), Integer()), + (pd.Series([1.0, 2.5, 3.0]), RealNumber()), (pd.Series(["a", "b", "c"]), String()), - (pd.Series([1, 2.0, "a", True]), String()), + (pd.Series([1, 2.0, "a", True]), Anything(is_nullable=False)), ], - ids=["empty", "boolean", "integer", "real number", "string", "mixed"], + ids=["empty", "boolean", "integer", "real number .0", "real number", "string", "mixed"], ) def test_should_infer_type_if_not_passed(series: pd.Series, expected: ColumnType) -> None: assert Column._from_pandas_series(series).type == expected diff --git a/tests/safeds/data/tabular/containers/_column/test_init.py b/tests/safeds/data/tabular/containers/_column/test_init.py index 2966b2970..d7015b3ff 100644 --- a/tests/safeds/data/tabular/containers/_column/test_init.py +++ b/tests/safeds/data/tabular/containers/_column/test_init.py @@ -3,7 +3,7 @@ import pandas as pd import pytest from safeds.data.tabular.containers import Column -from safeds.data.tabular.typing import Boolean, ColumnType, Integer, RealNumber, String +from safeds.data.tabular.typing import Anything, Boolean, ColumnType, Integer, Nothing, RealNumber, String def test_should_store_the_name() -> None: @@ -43,14 +43,15 @@ def test_should_store_the_data(column: Column, expected: list) -> None: @pytest.mark.parametrize( ("column", "expected"), [ - (Column("A", []), String()), + (Column("A", []), Nothing()), (Column("A", [True, False, True]), Boolean()), (Column("A", [1, 2, 3]), Integer()), - (Column("A", [1.0, 2.0, 3.0]), RealNumber()), + (Column("A", [1.0, 2.0, 3.0]), Integer()), + (Column("A", [1.0, 2.5, 3.0]), RealNumber()), (Column("A", ["a", "b", "c"]), String()), - (Column("A", [1, 2.0, "a", True]), String()), + (Column("A", [1, 2.0, "a", True]), Anything()), ], - ids=["empty", "boolean", "integer", "real number", "string", "mixed"], + ids=["empty", "boolean", "integer", "real number .0", "real number", "string", "mixed"], ) def test_should_infer_type(column: Column, expected: ColumnType) -> None: assert column.type == expected diff --git a/tests/safeds/data/tabular/containers/_table/test_add_row.py b/tests/safeds/data/tabular/containers/_table/test_add_row.py index d99792a43..d64e47f41 100644 --- a/tests/safeds/data/tabular/containers/_table/test_add_row.py +++ b/tests/safeds/data/tabular/containers/_table/test_add_row.py @@ -1,34 +1,70 @@ import pytest from _pytest.python_api import raises from safeds.data.tabular.containers import Row, Table -from safeds.exceptions import SchemaMismatchError +from safeds.data.tabular.typing import Anything, Integer, Schema +from safeds.exceptions import UnknownColumnNameError @pytest.mark.parametrize( - ("table", "row", "expected"), + ("table", "row", "expected", "expected_schema"), [ ( Table({"col1": [1, 2, 1], "col2": [1, 2, 4]}), Row({"col1": 5, "col2": 6}), Table({"col1": [1, 2, 1, 5], "col2": [1, 2, 4, 6]}), + Schema({"col1": Integer(), "col2": Integer()}), ), - (Table({"col2": [], "col4": []}), Row({"col2": 5, "col4": 6}), Table({"col2": [5], "col4": [6]})), - (Table(), Row({"col2": 5, "col4": 6}), Table({"col2": [5], "col4": [6]})), + ( + Table({"col1": [1, 2, 1], "col2": [1, 2, 4]}), + Row({"col1": "5", "col2": 6}), + Table({"col1": [1, 2, 1, "5"], "col2": [1, 2, 4, 6]}), + Schema({"col1": Anything(), "col2": Integer()}), + ), + ( + Table({"col1": [1, 2, 1], "col2": [1, 2, 4]}), + Row({"col1": "5", "col2": None}), + Table({"col1": [1, 2, 1, "5"], "col2": [1, 2, 4, None]}), + Schema({"col1": Anything(), "col2": Integer(is_nullable=True)}), + ), + ( + Table({"col1": [1, 2, 1], "col2": [1, 2, 4]}), + Row({"col1": 5, "col2": 6}), + Table({"col1": [1, 2, 1, 5], "col2": [1, 2, 4, 6]}), + Schema({"col1": Integer(), "col2": Integer()}), + ), + ( + Table({"col1": [], "col2": []}), + Row({"col1": 5, "col2": 6}), + Table({"col1": [5], "col2": [6]}), + Schema({"col1": Integer(), "col2": Integer()}), + ), + ], + ids=[ + "added row", + "different schemas", + "different schemas and nullable", + "add row to rowless table", + "add row to empty table", ], - ids=["add row", "add row to rowless table", "add row to empty table"], ) -def test_should_add_row(table: Table, row: Row, expected: Table) -> None: - table = table.add_row(row) - assert table == expected +def test_should_add_row(table: Table, row: Row, expected: Table, expected_schema: Schema) -> None: + result = table.add_row(row) + assert result.number_of_rows - 1 == table.number_of_rows + assert result.schema == expected_schema + assert result == expected -def test_should_raise_error_if_row_schema_invalid() -> None: - table1 = Table({"col1": [1, 2, 1], "col2": [1, 2, 4]}) - row = Row({"col1": 5, "col2": "Hallo"}) - with raises(SchemaMismatchError, match=r"Failed because at least two schemas didn't match."): - table1.add_row(row) - - -def test_should_raise_schema_mismatch() -> None: - with raises(SchemaMismatchError, match=r"Failed because at least two schemas didn't match."): - Table({"a": [], "b": []}).add_row(Row({"beer": None, "rips": None})) +@pytest.mark.parametrize( + ("table", "row", "expected_error_msg"), + [ + ( + Table({"col1": [1, 2, 1], "col2": [1, 2, 4]}), + Row({"col1": 5, "col3": "Hallo"}), + r"Could not find column\(s\) 'col2'", + ), + ], + ids=["unknown column col2 in row"], +) +def test_should_raise_error_if_row_column_names_invalid(table: Table, row: Row, expected_error_msg: str) -> None: + with raises(UnknownColumnNameError, match=expected_error_msg): + table.add_row(row) diff --git a/tests/safeds/data/tabular/containers/_table/test_add_rows.py b/tests/safeds/data/tabular/containers/_table/test_add_rows.py index 2bdf52624..b5cd8742a 100644 --- a/tests/safeds/data/tabular/containers/_table/test_add_rows.py +++ b/tests/safeds/data/tabular/containers/_table/test_add_rows.py @@ -1,7 +1,6 @@ import pytest -from _pytest.python_api import raises from safeds.data.tabular.containers import Row, Table -from safeds.exceptions import SchemaMismatchError +from safeds.exceptions import UnknownColumnNameError @pytest.mark.parametrize( @@ -12,13 +11,18 @@ [Row({"col1": "d", "col2": 6}), Row({"col1": "e", "col2": 8})], Table({"col1": ["a", "b", "c", "d", "e"], "col2": [1, 2, 4, 6, 8]}), ), + ( + Table({"col1": ["a", "b", "c"], "col2": [1, 2, 4]}), + [Row({"col1": "d", "col2": 6}), Row({"col1": "e", "col2": "f"})], + Table({"col1": ["a", "b", "c", "d", "e"], "col2": [1, 2, 4, 6, "f"]}), + ), ( Table(), [Row({"col1": "d", "col2": 6}), Row({"col1": "e", "col2": 8})], Table({"col1": ["d", "e"], "col2": [6, 8]}), ), ], - ids=["Rows with string and integer values", "empty"], + ids=["Rows with string and integer values", "different schema", "empty"], ) def test_should_add_rows(table1: Table, rows: list[Row], table2: Table) -> None: table1 = table1.add_rows(rows) @@ -49,8 +53,13 @@ def test_should_add_rows(table1: Table, rows: list[Row], table2: Table) -> None: Table({"col1": [], "yikes": []}), Table({"col1": [], "yikes": []}), ), + ( + Table({"col1": [1, 2, 1], "col2": [1, 2, 4]}), + Table({"col1": [5, "7"], "col2": [6, None]}), + Table({"col1": [1, 2, 1, 5, "7"], "col2": [1, 2, 4, 6, None]}), + ), ], - ids=["Rows from table", "add empty to table", "add on empty table", "rowless"], + ids=["Rows from table", "add empty to table", "add on empty table", "rowless", "different schema"], ) def test_should_add_rows_from_table(table1: Table, table2: Table, expected: Table) -> None: table1 = table1.add_rows(table2) @@ -58,15 +67,17 @@ def test_should_add_rows_from_table(table1: Table, table2: Table, expected: Tabl assert table1 == expected -def test_should_raise_error_if_row_schema_invalid() -> None: - table1 = Table({"col1": [1, 2, 1], "col2": [1, 2, 4]}) - row = [Row({"col1": 2, "col2": 4}), Row({"col1": 5, "col2": "Hallo"})] - with pytest.raises(SchemaMismatchError, match=r"Failed because at least two schemas didn't match."): - table1.add_rows(row) - - -def test_should_raise_schema_mismatch() -> None: - with raises(SchemaMismatchError, match=r"Failed because at least two schemas didn't match."): - Table({"a": [], "b": []}).add_rows([Row({"a": None, "b": None}), Row({"beer": None, "rips": None})]) - with raises(SchemaMismatchError, match=r"Failed because at least two schemas didn't match."): - Table({"a": [], "b": []}).add_rows([Row({"beer": None, "rips": None}), Row({"a": None, "b": None})]) +@pytest.mark.parametrize( + ("table", "rows", "expected_error_msg"), + [ + ( + Table({"col1": [1, 2, 1], "col2": [1, 2, 4]}), + [Row({"col1": 2, "col3": 4}), Row({"col1": 5, "col2": "Hallo"})], + r"Could not find column\(s\) 'col3'", + ), + ], + ids=["column names do not match"], +) +def test_should_raise_error_if_row_column_names_invalid(table: Table, rows: list[Row], expected_error_msg: str) -> None: + with pytest.raises(UnknownColumnNameError, match=expected_error_msg): + table.add_rows(rows) diff --git a/tests/safeds/data/tabular/containers/_table/test_from_rows.py b/tests/safeds/data/tabular/containers/_table/test_from_rows.py index 43954d072..af8b459fe 100644 --- a/tests/safeds/data/tabular/containers/_table/test_from_rows.py +++ b/tests/safeds/data/tabular/containers/_table/test_from_rows.py @@ -1,6 +1,6 @@ import pytest from safeds.data.tabular.containers import Row, Table -from safeds.exceptions import SchemaMismatchError +from safeds.exceptions import UnknownColumnNameError @pytest.mark.parametrize( @@ -24,15 +24,33 @@ }, ), ), + ( + [ + Row({"A": 1, "B": 4, "C": "d"}), + Row({"A": 2, "B": 5, "C": "e"}), + Row({"A": 3, "B": "6", "C": "f"}), + ], + Table( + { + "A": [1, 2, 3], + "B": [4, 5, "6"], + "C": ["d", "e", "f"], + }, + ), + ), ], - ids=["empty", "non-empty"], + ids=["empty", "non-empty", "different schemas"], ) def test_should_create_table_from_rows(rows: list[Row], expected: Table) -> None: - assert Table.from_rows(rows).schema == expected.schema - assert Table.from_rows(rows) == expected + table = Table.from_rows(rows) + assert table.schema == expected.schema + assert table == expected -def test_should_raise_error_if_mismatching_schema() -> None: - rows = [Row({"A": 1, "B": 2}), Row({"A": 2, "B": "a"})] - with pytest.raises(SchemaMismatchError, match=r"Failed because at least two schemas didn't match."): +@pytest.mark.parametrize( + ("rows", "expected_error_msg"), + [([Row({"A": 1, "B": 2}), Row({"A": 2, "C": 4})], r"Could not find column\(s\) 'B'")], +) +def test_should_raise_error_if_unknown_column_names(rows: list[Row], expected_error_msg: str) -> None: + with pytest.raises(UnknownColumnNameError, match=expected_error_msg): Table.from_rows(rows) diff --git a/tests/safeds/data/tabular/containers/_table/test_split_rows.py b/tests/safeds/data/tabular/containers/_table/test_split_rows.py index 917970452..e366a65ba 100644 --- a/tests/safeds/data/tabular/containers/_table/test_split_rows.py +++ b/tests/safeds/data/tabular/containers/_table/test_split_rows.py @@ -1,7 +1,7 @@ import pandas as pd import pytest from safeds.data.tabular.containers import Table -from safeds.data.tabular.typing import Integer, Schema +from safeds.data.tabular.typing import Integer, Nothing, Schema @pytest.mark.parametrize( @@ -15,7 +15,7 @@ ), ( Table({"col1": [1, 2, 1], "col2": [1, 2, 4]}), - Table._from_pandas_dataframe(pd.DataFrame(), Schema({"col1": Integer(), "col2": Integer()})), + Table._from_pandas_dataframe(pd.DataFrame(), Schema({"col1": Nothing(), "col2": Nothing()})), Table({"col1": [1, 2, 1], "col2": [1, 2, 4]}), 0, ), diff --git a/tests/safeds/data/tabular/containers/test_row.py b/tests/safeds/data/tabular/containers/test_row.py index 43fe0df93..011553f27 100644 --- a/tests/safeds/data/tabular/containers/test_row.py +++ b/tests/safeds/data/tabular/containers/test_row.py @@ -1,4 +1,5 @@ import re +from collections.abc import Callable from typing import Any import pandas as pd @@ -516,10 +517,50 @@ def test_should_contain_td_element_for_each_value(self, row: Row) -> None: class TestCopy: @pytest.mark.parametrize( "row", - [Row(), Row({"a": [3, 0.1]})], + [Row(), Row({"a": 3, "b": 4})], ids=["empty", "normal"], ) def test_should_copy_table(self, row: Row) -> None: copied = row._copy() assert copied == row assert copied is not row + + +class TestSortColumns: + @pytest.mark.parametrize( + ("row", "comparator", "expected"), + [ + ( + Row({"b": 1, "a": 2}), + lambda col1, col2: (col1[0] > col2[0]) - (col1[0] < col2[0]), + Row({"a": 2, "b": 1}), + ), + ( + Row({"a": 2, "b": 1}), + lambda col1, col2: (col2[0] > col1[0]) - (col2[0] < col1[0]), + Row({"b": 1, "a": 2}), + ), + (Row(), lambda col1, col2: (col1[0] > col2[0]) - (col1[0] < col2[0]), Row()), + ], + ids=[ + "sort descending by first element", + "sort ascending by first element", + "empty rows", + ], + ) + def test_should_sort_columns(self, row: Row, comparator: Callable[[tuple, tuple], int], expected: Row) -> None: + row = row.sort_columns(comparator) + assert row == expected + + @pytest.mark.parametrize( + "row", + [ + (Row({"b": 1, "a": 2})), + ], + ids=[ + "sort descending by first element", + ], + ) + def test_should_sort_table_out_of_place(self, row: Row) -> None: + sorted_row = row.sort_columns() + assert sorted_row != row diff --git a/tests/safeds/data/tabular/typing/test_column_type.py b/tests/safeds/data/tabular/typing/test_column_type.py index 5ad58915b..dfe9c13e6 100644 --- a/tests/safeds/data/tabular/typing/test_column_type.py +++ b/tests/safeds/data/tabular/typing/test_column_type.py @@ -1,53 +1,63 @@ +from collections.abc import Iterable +from typing import Any + import numpy as np +import pandas as pd import pytest from safeds.data.tabular.typing import ( Anything, Boolean, ColumnType, Integer, + Nothing, RealNumber, String, ) -class TestFromNumpyDataType: - # Test cases taken from https://numpy.org/doc/stable/reference/arrays.scalars.html#scalars +class TestDataType: @pytest.mark.parametrize( - ("data_type", "expected"), + ("data", "expected"), [ - # Boolean - (np.dtype(np.bool_), Boolean()), - # Number - (np.dtype(np.half), RealNumber()), - (np.dtype(np.single), RealNumber()), - (np.dtype(np.float_), RealNumber()), - (np.dtype(np.longfloat), RealNumber()), - # Int - (np.dtype(np.byte), Integer()), - (np.dtype(np.short), Integer()), - (np.dtype(np.intc), Integer()), - (np.dtype(np.int_), Integer()), - (np.dtype(np.longlong), Integer()), - (np.dtype(np.ubyte), Integer()), - (np.dtype(np.ushort), Integer()), - (np.dtype(np.uintc), Integer()), - (np.dtype(np.uint), Integer()), - (np.dtype(np.ulonglong), Integer()), - # String - (np.dtype(np.str_), String()), - (np.dtype(np.unicode_), String()), - (np.dtype(np.object_), String()), - (np.dtype(np.datetime64), String()), - (np.dtype(np.timedelta64), String()), + ([1, 2, 3], Integer(is_nullable=False)), + ([1.0, 2.0, 3.0], Integer(is_nullable=False)), + ([1.0, 2.5, 3.0], RealNumber(is_nullable=False)), + ([True, False, True], Boolean(is_nullable=False)), + (["a", "b", "c"], String(is_nullable=False)), + (["a", 1, 2.0], Anything(is_nullable=False)), + ([None, None, None], Nothing()), + ([None, 1, 2], Integer(is_nullable=True)), + ([1.0, 2.0, None], Integer(is_nullable=True)), + ([1.0, 2.5, None], RealNumber(is_nullable=True)), + ([True, False, None], Boolean(is_nullable=True)), + (["a", None, "b"], String(is_nullable=True)), + ], + ids=[ + "Integer", + "Real number .0", + "Real number", + "Boolean", + "String", + "Mixed", + "None", + "Nullable integer", + "Nullable RealNumber .0", + "Nullable RealNumber", + "Nullable Boolean", + "Nullable String", ], - ids=repr, ) - def test_should_create_column_type_from_numpy_data_type(self, data_type: np.dtype, expected: ColumnType) -> None: - assert ColumnType._from_numpy_data_type(data_type) == expected + def test_should_return_the_data_type(self, data: Iterable, expected: ColumnType) -> None: + assert ColumnType._data_type(pd.Series(data)) == expected - def test_should_raise_if_data_type_is_not_supported(self) -> None: - with pytest.raises(NotImplementedError): - ColumnType._from_numpy_data_type(np.dtype(np.void)) + @pytest.mark.parametrize( + ("data", "error_message"), + [(np.array([1, 2, 3], dtype=np.int16), "Unsupported numpy data type ''.")], + ids=["int16 not supported"], + ) + def test_should_throw_not_implemented_error_when_type_is_not_supported(self, data: Any, error_message: str) -> None: + with pytest.raises(NotImplementedError, match=error_message): + ColumnType._data_type(data) class TestRepr: diff --git a/tests/safeds/data/tabular/typing/test_schema.py b/tests/safeds/data/tabular/typing/test_schema.py index f4e827353..f6cd256d5 100644 --- a/tests/safeds/data/tabular/typing/test_schema.py +++ b/tests/safeds/data/tabular/typing/test_schema.py @@ -4,16 +4,17 @@ import pandas as pd import pytest -from safeds.data.tabular.typing import Boolean, ColumnType, Integer, RealNumber, Schema, String +from safeds.data.tabular.typing import Anything, Boolean, ColumnType, Integer, RealNumber, Schema, String from safeds.exceptions import UnknownColumnNameError if TYPE_CHECKING: + from collections.abc import Iterable from typing import Any class TestFromPandasDataFrame: @pytest.mark.parametrize( - ("dataframe", "expected"), + ("columns", "expected"), [ ( pd.DataFrame({"A": [True, False, True]}), @@ -25,7 +26,7 @@ class TestFromPandasDataFrame: ), ( pd.DataFrame({"A": [1.0, 2.0, 3.0]}), - Schema({"A": RealNumber()}), + Schema({"A": Integer()}), ), ( pd.DataFrame({"A": ["a", "b", "c"]}), @@ -33,24 +34,59 @@ class TestFromPandasDataFrame: ), ( pd.DataFrame({"A": [1, 2.0, "a", True]}), - Schema({"A": String()}), + Schema({"A": Anything()}), + ), + ( + pd.DataFrame({"A": [1.0, 2.5, 3.0]}), + Schema({"A": RealNumber()}), ), ( pd.DataFrame({"A": [1, 2, 3], "B": ["a", "b", "c"]}), Schema({"A": Integer(), "B": String()}), ), + ( + pd.DataFrame({"A": [True, False, None]}), + Schema({"A": Boolean(is_nullable=True)}), + ), + ( + pd.DataFrame({"A": [1, None, 3]}), + Schema({"A": Integer(is_nullable=True)}), + ), + ( + pd.DataFrame({"A": [1.0, None, 3.0]}), + Schema({"A": Integer(is_nullable=True)}), + ), + ( + pd.DataFrame({"A": [1.5, None, 3.0]}), + Schema({"A": RealNumber(is_nullable=True)}), + ), + ( + pd.DataFrame({"A": ["a", None, "c"]}), + Schema({"A": String(is_nullable=True)}), + ), + ( + pd.DataFrame({"A": [1, 2.0, None, True]}), + Schema({"A": Anything(is_nullable=True)}), + ), ], ids=[ + "boolean", "integer", - "real number", + "real number .0", "string", - "boolean", "mixed", + "real number", "multiple columns", + "boolean?", + "integer?", + "real number? .0", + "real number?", + "string?", + "Anything?", ], ) - def test_should_create_schema_from_pandas_dataframe(self, dataframe: pd.DataFrame, expected: Schema) -> None: - assert Schema._from_pandas_dataframe(dataframe) == expected + def test_should_create_schema_from_pandas_dataframe(self, columns: Iterable, expected: Schema) -> None: + assert Schema._from_pandas_dataframe(columns) == expected class TestRepr: @@ -235,6 +271,211 @@ def test_should_return_dict_for_schema(self, schema: Schema, expected: str) -> N assert schema.to_dict() == expected +class TestMergeMultipleSchemas: + @pytest.mark.parametrize( + ("schemas", "error_msg_regex"), + [([Schema({"Column1": Anything()}), Schema({"Column2": Anything()})], r"Could not find column\(s\) 'Column2'")], + ids=["different_column_names"], + ) + def test_should_raise_if_column_names_are_different(self, schemas: list[Schema], error_msg_regex: str) -> None: + with pytest.raises(UnknownColumnNameError, match=error_msg_regex): + Schema.merge_multiple_schemas(schemas) + + @pytest.mark.parametrize( + ("schemas", "expected"), + [ + ([Schema({"Column1": Integer()}), Schema({"Column1": Integer()})], Schema({"Column1": Integer()})), + ([Schema({"Column1": RealNumber()}), Schema({"Column1": RealNumber()})], Schema({"Column1": RealNumber()})), + ([Schema({"Column1": Boolean()}), Schema({"Column1": Boolean()})], Schema({"Column1": Boolean()})), + ([Schema({"Column1": String()}), Schema({"Column1": String()})], Schema({"Column1": String()})), + ([Schema({"Column1": Anything()}), Schema({"Column1": Anything()})], Schema({"Column1": Anything()})), + ([Schema({"Column1": Integer()}), Schema({"Column1": RealNumber()})], Schema({"Column1": RealNumber()})), + ([Schema({"Column1": Integer()}), Schema({"Column1": Boolean()})], Schema({"Column1": Anything()})), + ([Schema({"Column1": Integer()}), Schema({"Column1": String()})], Schema({"Column1": Anything()})), + ([Schema({"Column1": Integer()}), Schema({"Column1": Anything()})], Schema({"Column1": Anything()})), + ([Schema({"Column1": RealNumber()}), Schema({"Column1": Boolean()})], Schema({"Column1": Anything()})), + ([Schema({"Column1": RealNumber()}), Schema({"Column1": String()})], Schema({"Column1": Anything()})), + ([Schema({"Column1": RealNumber()}), Schema({"Column1": Anything()})], Schema({"Column1": Anything()})), + ([Schema({"Column1": Boolean()}), Schema({"Column1": String()})], Schema({"Column1": Anything()})), + ([Schema({"Column1": Boolean()}), Schema({"Column1": Anything()})], Schema({"Column1": Anything()})), + ([Schema({"Column1": String()}), Schema({"Column1": Anything()})], Schema({"Column1": Anything()})), + ( + [Schema({"Column1": Integer(is_nullable=True)}), Schema({"Column1": Integer()})], + Schema({"Column1": Integer(is_nullable=True)}), + ), + ( + [Schema({"Column1": RealNumber(is_nullable=True)}), Schema({"Column1": RealNumber()})], + Schema({"Column1": RealNumber(is_nullable=True)}), + ), + ( + [Schema({"Column1": Boolean(is_nullable=True)}), Schema({"Column1": Boolean()})], + Schema({"Column1": Boolean(is_nullable=True)}), + ), + ( + [Schema({"Column1": String(is_nullable=True)}), Schema({"Column1": String()})], + Schema({"Column1": String(is_nullable=True)}), + ), + ( + [Schema({"Column1": Anything(is_nullable=True)}), Schema({"Column1": Anything()})], + Schema({"Column1": Anything(is_nullable=True)}), + ), + ( + [Schema({"Column1": Integer(is_nullable=True)}), Schema({"Column1": RealNumber()})], + Schema({"Column1": RealNumber(is_nullable=True)}), + ), + ( + [Schema({"Column1": Integer(is_nullable=True)}), Schema({"Column1": Boolean()})], + Schema({"Column1": Anything(is_nullable=True)}), + ), + ( + [Schema({"Column1": Integer(is_nullable=True)}), Schema({"Column1": String()})], + Schema({"Column1": Anything(is_nullable=True)}), + ), + ( + [Schema({"Column1": Integer(is_nullable=True)}), Schema({"Column1": Anything()})], + Schema({"Column1": Anything(is_nullable=True)}), + ), + ( + [Schema({"Column1": RealNumber(is_nullable=True)}), Schema({"Column1": Boolean()})], + Schema({"Column1": Anything(is_nullable=True)}), + ), + ( + [Schema({"Column1": RealNumber(is_nullable=True)}), Schema({"Column1": String()})], + Schema({"Column1": Anything(is_nullable=True)}), + ), + ( + [Schema({"Column1": RealNumber(is_nullable=True)}), Schema({"Column1": Anything()})], + Schema({"Column1": Anything(is_nullable=True)}), + ), + ( + [Schema({"Column1": Boolean(is_nullable=True)}), Schema({"Column1": String()})], + Schema({"Column1": Anything(is_nullable=True)}), + ), + ( + [Schema({"Column1": Boolean(is_nullable=True)}), Schema({"Column1": Anything()})], + Schema({"Column1": Anything(is_nullable=True)}), + ), + ( + [Schema({"Column1": String(is_nullable=True)}), Schema({"Column1": Anything()})], + Schema({"Column1": Anything(is_nullable=True)}), + ), + ( + [Schema({"Column1": Integer()}), Schema({"Column1": Integer(is_nullable=True)})], + Schema({"Column1": Integer(is_nullable=True)}), + ), + ( + [Schema({"Column1": RealNumber()}), Schema({"Column1": RealNumber(is_nullable=True)})], + Schema({"Column1": RealNumber(is_nullable=True)}), + ), + ( + [Schema({"Column1": Boolean()}), Schema({"Column1": Boolean(is_nullable=True)})], + Schema({"Column1": Boolean(is_nullable=True)}), + ), + ( + [Schema({"Column1": String()}), Schema({"Column1": String(is_nullable=True)})], + Schema({"Column1": String(is_nullable=True)}), + ), + ( + [Schema({"Column1": Anything()}), Schema({"Column1": Anything(is_nullable=True)})], + Schema({"Column1": Anything(is_nullable=True)}), + ), + ( + [Schema({"Column1": Integer()}), Schema({"Column1": RealNumber(is_nullable=True)})], + Schema({"Column1": RealNumber(is_nullable=True)}), + ), + ( + [Schema({"Column1": Integer()}), Schema({"Column1": Boolean(is_nullable=True)})], + Schema({"Column1": Anything(is_nullable=True)}), + ), + ( + [Schema({"Column1": Integer()}), Schema({"Column1": String(is_nullable=True)})], + Schema({"Column1": Anything(is_nullable=True)}), + ), + ( + [Schema({"Column1": Integer()}), Schema({"Column1": Anything(is_nullable=True)})], + Schema({"Column1": Anything(is_nullable=True)}), + ), + ( + [Schema({"Column1": RealNumber()}), Schema({"Column1": Boolean(is_nullable=True)})], + Schema({"Column1": Anything(is_nullable=True)}), + ), + ( + [Schema({"Column1": RealNumber()}), Schema({"Column1": String(is_nullable=True)})], + Schema({"Column1": Anything(is_nullable=True)}), + ), + ( + [Schema({"Column1": RealNumber()}), Schema({"Column1": Anything(is_nullable=True)})], + Schema({"Column1": Anything(is_nullable=True)}), + ), + ( + [Schema({"Column1": Boolean()}), Schema({"Column1": String(is_nullable=True)})], + Schema({"Column1": Anything(is_nullable=True)}), + ), + ( + [Schema({"Column1": Boolean()}), Schema({"Column1": Anything(is_nullable=True)})], + Schema({"Column1": Anything(is_nullable=True)}), + ), + ( + [Schema({"Column1": String()}), Schema({"Column1": Anything(is_nullable=True)})], + Schema({"Column1": Anything(is_nullable=True)}), + ), + ], + ids=[ + "Integer Integer", + "RealNumber RealNumber", + "Boolean Boolean", + "String String", + "Anything Anything", + "Integer RealNumber", + "Integer Boolean", + "Integer String", + "Integer Anything", + "RealNumber Boolean", + "RealNumber String", + "RealNumber Anything", + "Boolean String", + "Boolean Anything", + "String Anything", + "Integer(null) Integer", + "RealNumber(null) RealNumber", + "Boolean(null) Boolean", + "String(null) String", + "Anything(null) Anything", + "Integer(null) RealNumber", + "Integer(null) Boolean", + "Integer(null) String", + "Integer(null) Anything", + "RealNumber(null) Boolean", + "RealNumber(null) String", + "RealNumber(null) Anything", + "Boolean(null) String", + "Boolean(null) Anything", + "String(null) Anything", + "Integer Integer(null)", + "RealNumber RealNumber(null)", + "Boolean Boolean(null)", + "String String(null)", + "Anything Anything(null)", + "Integer RealNumber(null)", + "Integer Boolean(null)", + "Integer String(null)", + "Integer Anything(null)", + "RealNumber Boolean(null)", + "RealNumber String(null)", + "RealNumber Anything(null)", + "Boolean String(null)", + "Boolean Anything(null)", + "String Anything(null)", + ], + ) + def test_should_return_merged_schema(self, schemas: list[Schema], expected: Schema) -> None: + assert Schema.merge_multiple_schemas(schemas) == expected + schemas.reverse() + assert ( + Schema.merge_multiple_schemas(schemas) == expected + ) # test the reversed list because the first parameter is handled differently + + class TestReprMarkdown: @pytest.mark.parametrize( ("schema", "expected"),