
Train with Multi GPU #11

Closed
kevinchow1993 opened this issue May 7, 2021 · 5 comments
Labels: bug, enhancement, in-depth


@kevinchow1993

When training with multiple GPUs, the current code raises a StopIteration error. The immediate cause is that one GPU is assigned less data than the others; the root cause is that in create_X_L_file() and create_X_U_file() in active_datasets.py, multiple GPUs write the same txt file at the same time, so the GPU that finishes writing first reads an incomplete txt when it builds its dataloader.
Solution (a combined sketch of both changes appears after the list):

  1. In these two functions, sleep for a short random interval before writing the file, so that the writes are staggered:
    time.sleep(random.uniform(0, 3))
    if not osp.exists(save_path):
        mmcv.mkdir_or_exist(save_folder)
        np.savetxt(save_path, ann[X_L_single], fmt='%s')
  2. In tools/train.py, synchronize the processes on all GPUs after each create_xx_file call by adding:
          if dist.is_initialized():
              torch.distributed.barrier()
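
A combined sketch of both changes (not verbatim from the repository; the helper name write_X_L_split is hypothetical, while save_folder, ann, X_L_single and year follow the snippets in this thread):

# active_datasets.py -- sketch of the staggered write.
import os.path as osp
import random
import time

import mmcv
import numpy as np

def write_X_L_split(save_folder, ann, X_L_single, year):  # hypothetical helper
    # Sleep a random amount so the processes do not write the file at the same time.
    time.sleep(random.uniform(0, 3))
    save_path = save_folder + '/trainval_X_L_' + year + '.txt'
    if not osp.exists(save_path):
        # Only the process that arrives first actually creates and writes the file.
        mmcv.mkdir_or_exist(save_folder)
        np.savetxt(save_path, ann[X_L_single], fmt='%s')
    return save_path

# tools/train.py -- after each create_X_L_file / create_X_U_file call:
import torch.distributed as dist

if dist.is_initialized():
    dist.barrier()  # wait until every process has finished writing/reading the split files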
yuantn added the bug, enhancement, and in-depth labels on May 7, 2021
yuantn (Owner) commented May 7, 2021

Many thanks! This would be very useful for training on multiple GPUs.

@chufengt

@kevinchow1993
Hello, I modified the code based on your description, but I still run into the StopIteration error. Do you have any suggestions?

  1. Modify the following part in create_X_U_file and create_X_L_file:
time.sleep(random.uniform(0,3))  
save_path = save_folder + '/trainval_X_U_' + year + '.txt'  
if not osp.exists(save_path):  
    mmcv.mkdir_or_exist(save_folder)  
    np.savetxt(save_path, ann[X_U_single], fmt='%s')  
X_U_path.append(save_path)  
  2. In train.py, add the thread synchronization code after every create_xx_file call.

yuantn (Owner) commented Aug 23, 2021

If it does not work, I think you can also try this:
Add a condition before create_X_L_file and create_X_U_file in tools/train.py:

if torch.cuda.current_device() == 0:

Add thread synchronization after create_X_L_file and create_X_U_file:

if dist.is_initialized():
    torch.distributed.barrier()
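
A minimal sketch of how this might look at the call site in tools/train.py (the create_X_L_file signature follows the snippet later in this thread; the analogous create_X_U_file call is an assumption):

import torch
import torch.distributed as dist

# Only rank 0 (re)writes the split files, so there is no concurrent write at all.
if torch.cuda.current_device() == 0:
    cfg = create_X_L_file(cfg, X_L, all_anns, cycle)
    cfg = create_X_U_file(cfg, X_U, all_anns, cycle)  # hypothetical analogous call for the unlabeled split
# The other GPUs wait here until rank 0 has finished writing.
if dist.is_initialized():
    dist.barrier()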


chufengt commented Aug 23, 2021

@yuantn
I made the following modification:

if torch.cuda.current_device() == 0:
    cfg = create_X_L_file(cfg, X_L, all_anns, cycle)
if dist.is_initialized():
    torch.distributed.barrier()

It gets stuck the first time a checkpoint is saved.

yuantn (Owner) commented Aug 23, 2021

Is it also necessary to distribute the returned cfg to each GPU? The code is as follows:

if torch.cuda.current_device() == 0:
    cfg_save = create_X_L_file(cfg, X_L, all_anns, cycle)
    joblib.dump(cfg_save, 'cfg_save.tmp')
if dist.is_initialized():
    torch.distributed.barrier()
cfg = joblib.load("cfg_save.tmp")
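
An alternative sketch (not from the repository): share the updated cfg through torch.distributed.broadcast_object_list instead of a temporary file. This needs PyTorch >= 1.8 and a picklable cfg; the create_X_L_file call follows the snippets above.

import torch
import torch.distributed as dist

if torch.cuda.current_device() == 0:
    cfg = create_X_L_file(cfg, X_L, all_anns, cycle)
if dist.is_initialized():
    # broadcast_object_list pickles the object on rank 0 and fills the list
    # in place on the other ranks, so no temporary file is needed.
    obj = [cfg]
    dist.broadcast_object_list(obj, src=0)
    cfg = obj[0]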
