
ERROR when using multi-gpu training #18

Closed
tsWen0309 opened this issue Nov 19, 2023 · 7 comments

@tsWen0309

Hi, thanks for sharing your work. I can't train your model on my GPUs (two RTX 4090s). Is there any solution?

[screenshot of the error attached]

@LYMDLUT

LYMDLUT commented Nov 22, 2023

> Hi, thanks for sharing your work. I can't train your model on my GPUs (two RTX 4090s). Is there any solution?

I have run into the same bug.

@christopher-beckham

christopher-beckham commented Nov 26, 2023

I had a similar issue; make sure you're using PyTorch 1.12, as per the environment.yml file.
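For anyone checking, a quick way to confirm which PyTorch/CUDA/cuDNN build an environment is actually using (a minimal check, not part of the repo):

```python
# Minimal environment check (not part of the repo): print the PyTorch, CUDA, and
# cuDNN versions visible to the current interpreter, plus the number of GPUs.
import torch

print("torch:", torch.__version__)
print("cuda:", torch.version.cuda)
print("cudnn:", torch.backends.cudnn.version())
print("gpus:", torch.cuda.device_count())
```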

@tsWen0309 (Author)

> Hi, thanks for sharing your work. I can't train your model on my GPUs (two RTX 4090s). Is there any solution?
>
> I have run into the same bug.

I solved the problem by using PyTorch 1.13.1 with CUDA 11.6 and cuDNN 8.3.2.0.

@RachelTeamo

I tried this code. When I set --nproc_per_node=1 the code works fine, but as soon as --nproc_per_node>1 (e.g., --nproc_per_node=2) it fails with the same error as in the screenshot above. Is there a solution for this? My torch version is 2.1 because I'm using an H800 GPU.

@marcoamonteiro

@RachelTeamo to run on torch 2.1, replace line 89 in training_loop.py:
ddp = torch.nn.parallel.DistributedDataParallel(net, device_ids=[device], broadcast_buffers=False)
with
ddp = torch.nn.parallel.DistributedDataParallel(net, device_ids=[dist.get_rank()], broadcast_buffers=False)

I couldn't find anything in the PyTorch docs warning about this change in the DDP API, but it solved the issue for me.
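For context, a minimal sketch of what that change amounts to (assumptions: one process per GPU on a single node, so dist.get_rank() doubles as the local GPU index; this is not the repo's exact code):

```python
# Minimal per-process DDP setup sketch (not the repo's code). Assumes one process
# per GPU on a single node, so the global rank doubles as the local GPU index.
import torch
import torch.distributed as dist

def wrap_for_ddp(net: torch.nn.Module) -> torch.nn.parallel.DistributedDataParallel:
    rank = dist.get_rank()          # global rank == local GPU index on a single node
    torch.cuda.set_device(rank)     # bind this process to its own GPU
    net = net.to(rank)              # move parameters and buffers onto that GPU
    # device_ids takes a single-element list naming this process's device; an explicit
    # integer index avoids handing DDP an ambiguous or mismatched torch.device object.
    return torch.nn.parallel.DistributedDataParallel(
        net, device_ids=[rank], broadcast_buffers=False)
```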

@RachelTeamo

Thanks for your suggestion. I replaced the code as you described, but the issue still exists.

@Shiien

Shiien commented Mar 22, 2024

I solved this by commenting out lines 79-84:

# if dist.get_rank() == 0:
#     with torch.no_grad():
#         images = torch.zeros([batch_gpu, net.img_channels, net.img_resolution, net.img_resolution], device=device)
#         sigma = torch.ones([batch_gpu], device=device)
#         labels = torch.zeros([batch_gpu, net.label_dim], device=device)
#         misc.print_module_summary(net, [images, sigma, labels], max_nesting=2)

and by setting
ddp = torch.nn.parallel.DistributedDataParallel(net, device_ids=[dist.get_rank()], broadcast_buffers=False)
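If it helps to separate machine or driver problems from the repo's training code, here is a tiny standalone DDP smoke test (a sketch; the file name ddp_smoke_test.py and the torchrun invocation are my own, not from the repo):

```python
# Minimal DDP smoke test (not from the repo). Save as ddp_smoke_test.py (hypothetical
# name) and launch with, e.g.:
#   torchrun --standalone --nproc_per_node=2 ddp_smoke_test.py
import torch
import torch.distributed as dist

def main():
    dist.init_process_group(backend="nccl")   # torchrun supplies the rendezvous env vars
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    net = torch.nn.Linear(8, 8).to(rank)
    ddp = torch.nn.parallel.DistributedDataParallel(
        net, device_ids=[rank], broadcast_buffers=False)

    x = torch.randn(4, 8, device=rank)
    loss = ddp(x).sum()
    loss.backward()                            # exercises the gradient all-reduce across ranks
    print(f"rank {rank}: ok, loss={loss.item():.4f}")

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```

If this script runs cleanly on both GPUs but the training script still fails, the problem is more likely in how the repo constructs the DDP wrapper than in the CUDA/NCCL setup.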
