From fbf3973fab421a6ad017dbac615d33bc7b40dd32 Mon Sep 17 00:00:00 2001 From: Lars Reimann Date: Tue, 26 Nov 2024 16:43:11 +0100 Subject: [PATCH] fix: move index tensors to default device A `torch.LongTensor` seems to always be created on the CPU. --- src/safeds/data/labeled/containers/_image_dataset.py | 2 ++ .../ml/nn/converters/_input_converter_image_to_column.py | 4 ++-- .../ml/nn/converters/_input_converter_image_to_table.py | 4 ++-- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/safeds/data/labeled/containers/_image_dataset.py b/src/safeds/data/labeled/containers/_image_dataset.py index c852c11a8..a64bff83a 100644 --- a/src/safeds/data/labeled/containers/_image_dataset.py +++ b/src/safeds/data/labeled/containers/_image_dataset.py @@ -356,6 +356,8 @@ def split( upper_bound=_ClosedBound(1), ) + _init_default_device() + first_dataset: ImageDataset[Out_co] = copy.copy(self) second_dataset: ImageDataset[Out_co] = copy.copy(self) diff --git a/src/safeds/ml/nn/converters/_input_converter_image_to_column.py b/src/safeds/ml/nn/converters/_input_converter_image_to_column.py index da03c03a0..284316981 100644 --- a/src/safeds/ml/nn/converters/_input_converter_image_to_column.py +++ b/src/safeds/ml/nn/converters/_input_converter_image_to_column.py @@ -43,9 +43,9 @@ def _data_conversion_output( output = torch.zeros(len(input_data), len(one_hot_encoder._get_names_of_added_columns())) output[torch.arange(len(input_data)), output_data] = 1 - im_dataset: ImageDataset[Column] = ImageDataset[Column].__new__(ImageDataset) + im_dataset: ImageDataset[Column] = object.__new__(ImageDataset) im_dataset._output = _ColumnAsTensor._from_tensor(output, column_name, one_hot_encoder) - im_dataset._shuffle_tensor_indices = torch.LongTensor(list(range(len(input_data)))) + im_dataset._shuffle_tensor_indices = torch.arange(len(input_data)) im_dataset._shuffle_after_epoch = False im_dataset._batch_size = 1 im_dataset._next_batch_index = 0 diff --git a/src/safeds/ml/nn/converters/_input_converter_image_to_table.py b/src/safeds/ml/nn/converters/_input_converter_image_to_table.py index fe5ce6ad4..427e5ad76 100644 --- a/src/safeds/ml/nn/converters/_input_converter_image_to_table.py +++ b/src/safeds/ml/nn/converters/_input_converter_image_to_table.py @@ -33,9 +33,9 @@ def _data_conversion_output(self, input_data: ImageList, output_data: Tensor) -> output = torch.zeros(len(input_data), len(column_names)) output[torch.arange(len(input_data)), output_data] = 1 - im_dataset: ImageDataset[Table] = ImageDataset[Table].__new__(ImageDataset) + im_dataset: ImageDataset[Table] = object.__new__(ImageDataset) im_dataset._output = _TableAsTensor._from_tensor(output, column_names) - im_dataset._shuffle_tensor_indices = torch.LongTensor(list(range(len(input_data)))) + im_dataset._shuffle_tensor_indices = torch.arange(len(input_data)) im_dataset._shuffle_after_epoch = False im_dataset._batch_size = 1 im_dataset._next_batch_index = 0