I just noticed that torch>=1.9.0 seems to have introduced a change that causes an exception where none was raised before.
Setup
Small-text: 1.0.0a4
Torch: 1.9.1
Description
The following error occurs when executing examples/pytorch_multiclass_classification.py:
  File "/my/path/.pyenv/versions/3.8.2/lib/python3.8/runpy.py", line 193, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/my/path/.pyenv/versions/3.8.2/lib/python3.8/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/my/path/small-text/examples/pytorch_multiclass_classification.py", line 102, in <module>
    main()
  File "/my/path/small-text/examples/pytorch_multiclass_classification.py", line 52, in main
    active_learner.initialize_data(labeled_indices, y_initial)
  File "/my/path/small-text/small_text/active_learner.py", line 141, in initialize_data
    self._retrain(x_indices_validation=x_indices_validation)
  File "/my/path/small-text/small_text/active_learner.py", line 384, in _retrain
    self._clf.fit(x)
  File "/my/path/small-text/small_text/integrations/pytorch/classifiers/kimcnn.py", line 176, in fit
    return self._fit_main(sub_train, sub_valid)
  File "/my/path/small-text/small_text/integrations/pytorch/classifiers/kimcnn.py", line 198, in _fit_main
    res = self._train(sub_train, sub_valid, tmp_dir)
  File "/my/path/small-text/small_text/integrations/pytorch/classifiers/kimcnn.py", line 218, in _train
    train_loss, train_acc = self._train_func(sub_train)
  File "/my/path/small-text/small_text/integrations/pytorch/classifiers/kimcnn.py", line 263, in _train_func
    for i, (text, cls) in enumerate(train_iter):
  File "/my/path/.local/share/virtualenvs/myvenv-123/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 521, in __next__
    data = self._next_data()
  File "/my/path/.local/share/virtualenvs/myvenv-123/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 560, in _next_data
    index = self._next_index()  # may raise StopIteration
  File "/my/path/.local/share/virtualenvs/myvenv-123/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 512, in _next_index
    return next(self._sampler_iter)  # may raise StopIteration
  File "/my/path/.local/share/virtualenvs/myvenv-123/lib/python3.8/site-packages/torch/utils/data/sampler.py", line 226, in __iter__
    for idx in self.sampler:
  File "/my/path/.local/share/virtualenvs/myvenv-123/lib/python3.8/site-packages/torch/utils/data/sampler.py", line 124, in __iter__
    yield from torch.randperm(n, generator=generator).tolist()
RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'
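One plausible explanation, not confirmed in this report: the last frame shows RandomSampler (used by DataLoader(..., shuffle=True)) calling torch.randperm(n, generator=generator) with a generator it created itself, which lives on the CPU. If the script has made CUDA the default tensor type (for example via torch.set_default_tensor_type('torch.cuda.FloatTensor'); this is an assumption, not something shown above), randperm would allocate on CUDA while the generator is on the CPU, which matches this error message. Under that assumption, a minimal sketch of a possible fix is to pass DataLoader a generator created on the default device. The helpers default_device and make_shuffling_loader are hypothetical names, not part of small-text or PyTorch:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def default_device() -> torch.device:
    # Device used by factory functions (e.g. torch.empty, torch.randperm)
    # when no device is given; it becomes CUDA if a CUDA tensor type was
    # set as the default tensor type.
    return torch.empty(0).device


def make_shuffling_loader(dataset, batch_size: int) -> DataLoader:
    # Hypothetical helper: create the shuffling generator on the same
    # device as the default tensor type, so that the sampler's
    # torch.randperm(n, generator=...) call sees matching devices.
    generator = torch.Generator(device=default_device())
    generator.manual_seed(0)  # fixed seed for reproducible shuffling
    return DataLoader(dataset, batch_size=batch_size, shuffle=True,
                      generator=generator)


if __name__ == "__main__":
    # Toy dataset: 10 one-dimensional feature vectors with integer labels.
    data = TensorDataset(torch.arange(10).float().unsqueeze(1),
                         torch.arange(10))
    for features, labels in make_shuffling_loader(data, batch_size=4):
        print(features.shape, labels.tolist())
```

On a CPU-only setup this behaves like an ordinary shuffled DataLoader; whether it removes the error in kimcnn.py depends on the assumed cause above being the actual one.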
Solution
I have yet to investigate what would be the optimal solution for this.
A workaround is to downgrade PyTorch (and torchtext):
pip install torch==1.8.1 torchtext==0.9.1
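To detect an affected environment before running the example, a small version guard may help. This is a hypothetical helper (not part of small-text or PyTorch); it assumes the affected range is torch>=1.9.0 as stated above, and it compares numeric tuples rather than strings so that, for instance, 1.10 is treated as newer than 1.9:

```python
import re
from typing import Tuple


def torch_version_tuple(version: str) -> Tuple[int, int, int]:
    # Keep only the leading "major.minor[.patch]" part, ignoring local
    # suffixes such as "+cu111" and pre-release tags such as "a0".
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    major, minor, patch = match.groups(default="0")
    return int(major), int(minor), int(patch)


def is_affected(version: str) -> bool:
    # Range taken from this report: errors with torch>=1.9.0, not with 1.8.1.
    return torch_version_tuple(version) >= (1, 9, 0)
```

For example, is_affected("1.9.1") returns True for the version reported here, while is_affected("1.8.1") returns False; in practice you would pass torch.__version__.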
More information
pytorch/pytorch/issues/44714