Skip to content

Commit

Permalink
Remove unneccessary copied in OneHotEncoder
Browse files Browse the repository at this point in the history
  • Loading branch information
Severin Paul Höfer committed May 10, 2023
1 parent f2e347d commit 07e6adc
Showing 1 changed file with 19 additions and 20 deletions.
39 changes: 19 additions & 20 deletions src/safeds/data/tabular/transformation/_one_hot_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,6 @@ def transform(self, table: Table) -> Table:
if len(missing_columns) > 0:
raise UnknownColumnNameError(list(missing_columns))

# Make a copy of the table:
# TODO: change to copy method once implemented
new_table = table.remove_columns([])

encoded_values = {}
for new_column_name in self._value_to_column.values():
encoded_values[new_column_name] = [0.0 for _ in range(table.number_of_rows)]
Expand All @@ -162,21 +158,25 @@ def transform(self, table: Table) -> Table:
encoded_values[new_column_name][i] = 1.0

for new_column in self._column_names[old_column_name]:
new_table = new_table.add_column(Column(new_column, encoded_values[new_column]))

# Drop corresponding old columns:
new_table = new_table.remove_columns(list(self._column_names.keys()))
table = table.add_column(Column(new_column, encoded_values[new_column]))

# New columns may not be sorted:
column_names = []

for name in table.column_names:
if name not in self._column_names.keys():
column_names.append(name)
else:
column_names.extend(
[f_name for f_name in self._value_to_column.values() if f_name.startswith(name)],
)
return new_table.sort_columns(lambda col1, col2: column_names.index(col1.name) - column_names.index(col2.name))

# Drop old, non-encoded columns:
# (Don't do this earlier - we need the old column nams for sorting,
# plus we need to prevent the table from possibly having 0 columns temporarily.)
table = table.remove_columns(list(self._column_names.keys()))

# Apply sorting and return:
return table.sort_columns(lambda col1, col2: column_names.index(col1.name) - column_names.index(col2.name))

# noinspection PyProtectedMember
def inverse_transform(self, transformed_table: Table) -> Table:
Expand Down Expand Up @@ -204,10 +204,6 @@ def inverse_transform(self, transformed_table: Table) -> Table:
if self._column_names is None or self._value_to_column is None:
raise TransformerNotFittedError

# Make a copy of the table:
# TODO: change to copy method once implemented
new_table = transformed_table.remove_columns([])

original_columns = {}
for original_column_name in self._column_names:
original_columns[original_column_name] = [None for _ in range(transformed_table.number_of_rows)]
Expand All @@ -218,11 +214,10 @@ def inverse_transform(self, transformed_table: Table) -> Table:
if transformed_table.get_column(constructed_column)[i] == 1.0:
original_columns[original_column_name][i] = value

for column_name, encoded_column in original_columns.items():
new_table = new_table.add_column(Column(column_name, encoded_column))
table = transformed_table

# Drop old column names:
new_table = new_table.remove_columns(list(self._value_to_column.values()))
for column_name, encoded_column in original_columns.items():
table = table.add_column(Column(column_name, encoded_column))

column_names = [
(
Expand All @@ -236,9 +231,13 @@ def inverse_transform(self, transformed_table: Table) -> Table:
][0]
]
)
for name in transformed_table.column_names
for name in table.column_names
]
return new_table.sort_columns(lambda col1, col2: column_names.index(col1.name) - column_names.index(col2.name))

# Drop old column names:
table = table.remove_columns(list(self._value_to_column.values()))

return table.sort_columns(lambda col1, col2: column_names.index(col1.name) - column_names.index(col2.name))

def is_fitted(self) -> bool:
"""
Expand Down

0 comments on commit 07e6adc

Please sign in to comment.