Adding Thread Worker Option to ThreadDataLoader #4252
Conversation
Signed-off-by: Eric Kerfoot <eric.kerfoot@kcl.ac.uk>
for more information, see https://pre-commit.ci
I'm still dealing with an error that's shown up as the failures. For those transforms that are thread-safe this would be an enhancement; we have those that aren't safe as inheriting from
Signed-off-by: Eric Kerfoot <eric.kerfoot@kcl.ac.uk>
for more information, see https://pre-commit.ci
@ericspod @Nic-Ma I tried this PR with the fast training test; it works fine and it's slightly faster on my desktop (34s vs 36s). This is how I run it:

```diff
diff --git a/tests/test_integration_fast_train.py b/tests/test_integration_fast_train.py
index 4dbb70b..8271a47 100644
--- a/tests/test_integration_fast_train.py
+++ b/tests/test_integration_fast_train.py
@@ -151,8 +151,8 @@ class IntegrationFastTrain(DistTestCase):
     train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=8)
     val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=5)
     # disable multi-workers because `ThreadDataLoader` works with multi-threads
-    train_loader = ThreadDataLoader(train_ds, num_workers=0, batch_size=4, shuffle=True)
-    val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1)
+    train_loader = ThreadDataLoader(train_ds, num_workers=2, use_thread_workers=True, batch_size=4, shuffle=True)
+    val_loader = ThreadDataLoader(val_ds, num_workers=2, use_thread_workers=True, batch_size=1)
     loss_function = DiceCELoss(to_onehot_y=True, softmax=True, squared_pred=True, batch=True)
     model = UNet(
```
I think we should merge this one...
We do know there is a thread-safety issue with many transforms which use the random state: things can sometimes train faster, but there will be race conditions which may preclude reproducibility. I need time to consider possible solutions to this. The purpose of this addition was to permit faster operation in some cases, but it also allows us to debug transform sequences in one process with a single worker thread.
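To illustrate the reproducibility concern, here is a small sketch (not MONAI code; `per_thread_draws` is a hypothetical helper): when worker threads share one RNG, their draws interleave nondeterministically, but giving each thread its own deterministically seeded generator keeps every worker's stream reproducible from run to run.

```python
import random
import threading

def per_thread_draws(seed, n=5, num_workers=2):
    """Each worker thread uses its own deterministically seeded RNG,
    so the values each worker produces are reproducible run to run."""
    results = {}

    def worker(idx):
        # independent stream per thread, derived from the base seed
        rng = random.Random(seed * 1000 + idx)
        results[idx] = [rng.random() for _ in range(n)]

    threads = [threading.Thread(target=worker, args=(i,)) for i in range(num_workers)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results
```

With a shared `random.Random` instance instead, the per-worker lists would depend on thread scheduling, which is the race the comment above describes.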
Signed-off-by: Eric Kerfoot <eric.kerfoot@kcl.ac.uk>
Signed-off-by: Eric Kerfoot <eric.kerfoot@kcl.ac.uk>
for more information, see https://pre-commit.ci
Overall it looks good to me.
I put some comments inline.
Thanks.
/build
Signed-off-by: Eric Kerfoot eric.kerfoot@kcl.ac.uk
Description
Adds the ability to run workers in ThreadDataLoader as threads instead of processes. This is a fix for Windows, where we have issues with its process-spawning semantics.
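As a rough conceptual sketch of the thread-worker idea (assumptions: this is not the actual ThreadDataLoader implementation, and `threaded_loader` is a hypothetical name), workers become daemon threads feeding a shared queue, so no worker processes need to be spawned at all:

```python
import queue
import threading

def threaded_loader(dataset, num_workers=2):
    """Yield every item of `dataset`, fetched by worker threads
    instead of worker processes (item order is not guaranteed)."""
    q = queue.Queue()
    indices = list(range(len(dataset)))

    def worker(shard):
        for i in shard:
            q.put(dataset[i])

    # round-robin shard of the indices, one shard per worker thread
    threads = [
        threading.Thread(target=worker, args=(indices[w::num_workers],), daemon=True)
        for w in range(num_workers)
    ]
    for t in threads:
        t.start()
    for _ in indices:
        yield q.get()
    for t in threads:
        t.join()
```

Because threads share the parent process's memory, this pattern sidesteps spawn-based worker start-up on Windows, at the cost of the GIL and the thread-safety caveats raised in the review comments.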
Status
Work in progress
Types of changes
- Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- Documentation updated, tested the `make html` command in the `docs/` folder.