From 07e6adc594a953ab2bf57a08f038f7230cb5573e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Severin=20Paul=20H=C3=B6fer?= <> Date: Wed, 10 May 2023 12:14:57 +0200 Subject: [PATCH] Remove unnecessary copies in OneHotEncoder --- .../transformation/_one_hot_encoder.py | 39 +++++++++---------- 1 file changed, 19 insertions(+), 20 deletions(-) diff --git a/src/safeds/data/tabular/transformation/_one_hot_encoder.py b/src/safeds/data/tabular/transformation/_one_hot_encoder.py index ab6b2186d..b0f440009 100644 --- a/src/safeds/data/tabular/transformation/_one_hot_encoder.py +++ b/src/safeds/data/tabular/transformation/_one_hot_encoder.py @@ -142,10 +142,6 @@ def transform(self, table: Table) -> Table: if len(missing_columns) > 0: raise UnknownColumnNameError(list(missing_columns)) - # Make a copy of the table: - # TODO: change to copy method once implemented - new_table = table.remove_columns([]) - encoded_values = {} for new_column_name in self._value_to_column.values(): encoded_values[new_column_name] = [0.0 for _ in range(table.number_of_rows)] @@ -162,13 +158,10 @@ def transform(self, table: Table) -> Table: encoded_values[new_column_name][i] = 1.0 for new_column in self._column_names[old_column_name]: - new_table = new_table.add_column(Column(new_column, encoded_values[new_column])) - - # Drop corresponding old columns: - new_table = new_table.remove_columns(list(self._column_names.keys())) + table = table.add_column(Column(new_column, encoded_values[new_column])) + # New columns may not be sorted: column_names = [] - for name in table.column_names: if name not in self._column_names.keys(): column_names.append(name) @@ -176,7 +169,14 @@ def transform(self, table: Table) -> Table: column_names.extend( [f_name for f_name in self._value_to_column.values() if f_name.startswith(name)], ) - return new_table.sort_columns(lambda col1, col2: column_names.index(col1.name) - column_names.index(col2.name)) + + # Drop old, non-encoded columns: + # (Don't do this earlier - we need the
old column names for sorting, + # plus we need to prevent the table from possibly having 0 columns temporarily.) + table = table.remove_columns(list(self._column_names.keys())) + + # Apply sorting and return: + return table.sort_columns(lambda col1, col2: column_names.index(col1.name) - column_names.index(col2.name)) # noinspection PyProtectedMember def inverse_transform(self, transformed_table: Table) -> Table: @@ -204,10 +204,6 @@ def inverse_transform(self, transformed_table: Table) -> Table: if self._column_names is None or self._value_to_column is None: raise TransformerNotFittedError - # Make a copy of the table: - # TODO: change to copy method once implemented - new_table = transformed_table.remove_columns([]) - original_columns = {} for original_column_name in self._column_names: original_columns[original_column_name] = [None for _ in range(transformed_table.number_of_rows)] @@ -218,11 +214,10 @@ def inverse_transform(self, transformed_table: Table) -> Table: if transformed_table.get_column(constructed_column)[i] == 1.0: original_columns[original_column_name][i] = value - for column_name, encoded_column in original_columns.items(): - new_table = new_table.add_column(Column(column_name, encoded_column)) + table = transformed_table - # Drop old column names: - new_table = new_table.remove_columns(list(self._value_to_column.values())) + for column_name, encoded_column in original_columns.items(): + table = table.add_column(Column(column_name, encoded_column)) column_names = [ ( @@ -236,9 +231,13 @@ def inverse_transform(self, transformed_table: Table) -> Table: ][0] ] ) - for name in transformed_table.column_names + for name in table.column_names ] - return new_table.sort_columns(lambda col1, col2: column_names.index(col1.name) - column_names.index(col2.name)) + + # Drop old column names: + table = table.remove_columns(list(self._value_to_column.values())) + + return table.sort_columns(lambda col1, col2: column_names.index(col1.name) -
column_names.index(col2.name)) def is_fitted(self) -> bool: """