Skip to content

Commit

Permalink
fix: move index tensors to default device
Browse files Browse the repository at this point in the history
A `torch.LongTensor` seems to always be created on the CPU.
  • Loading branch information
lars-reimann committed Nov 26, 2024
1 parent c24b6d5 commit fbf3973
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 4 deletions.
2 changes: 2 additions & 0 deletions src/safeds/data/labeled/containers/_image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,8 @@ def split(
upper_bound=_ClosedBound(1),
)

_init_default_device()

first_dataset: ImageDataset[Out_co] = copy.copy(self)
second_dataset: ImageDataset[Out_co] = copy.copy(self)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ def _data_conversion_output(
output = torch.zeros(len(input_data), len(one_hot_encoder._get_names_of_added_columns()))
output[torch.arange(len(input_data)), output_data] = 1

im_dataset: ImageDataset[Column] = ImageDataset[Column].__new__(ImageDataset)
im_dataset: ImageDataset[Column] = object.__new__(ImageDataset)
im_dataset._output = _ColumnAsTensor._from_tensor(output, column_name, one_hot_encoder)
im_dataset._shuffle_tensor_indices = torch.LongTensor(list(range(len(input_data))))
im_dataset._shuffle_tensor_indices = torch.arange(len(input_data))
im_dataset._shuffle_after_epoch = False
im_dataset._batch_size = 1
im_dataset._next_batch_index = 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ def _data_conversion_output(self, input_data: ImageList, output_data: Tensor) ->
output = torch.zeros(len(input_data), len(column_names))
output[torch.arange(len(input_data)), output_data] = 1

im_dataset: ImageDataset[Table] = ImageDataset[Table].__new__(ImageDataset)
im_dataset: ImageDataset[Table] = object.__new__(ImageDataset)
im_dataset._output = _TableAsTensor._from_tensor(output, column_names)
im_dataset._shuffle_tensor_indices = torch.LongTensor(list(range(len(input_data))))
im_dataset._shuffle_tensor_indices = torch.arange(len(input_data))
im_dataset._shuffle_after_epoch = False
im_dataset._batch_size = 1
im_dataset._next_batch_index = 0
Expand Down

0 comments on commit fbf3973

Please sign in to comment.