Fix Process Group for tensors shared across processes #21449

Closed · wants to merge 13 commits
@@ -8,8 +8,9 @@ Allocator* get() {
return &allocator;
}

void recordStreamMasqueradingAsCUDA(void *ptr, HIPStreamMasqueradingAsCUDA stream) {
HIPCachingAllocator::recordStream(ptr, stream.hip_stream());
void recordStreamMasqueradingAsCUDA(
void *ptr, HIPStreamMasqueradingAsCUDA stream, bool suppressError) {
HIPCachingAllocator::recordStream(ptr, stream.hip_stream(), suppressError);
}

} // namespace HIPCachingAllocatorMasqueradingAsCUDA
@@ -8,7 +8,8 @@ namespace c10 { namespace hip {
namespace HIPCachingAllocatorMasqueradingAsCUDA {

Allocator* get();
C10_HIP_API void recordStreamMasqueradingAsCUDA(void *ptr, HIPStreamMasqueradingAsCUDA stream);
C10_HIP_API void recordStreamMasqueradingAsCUDA(
void *ptr, HIPStreamMasqueradingAsCUDA stream, bool suppressError=false);

} // namespace HIPCachingAllocatorMasqueradingAsCUDA
}} // namespace c10::hip
31 changes: 21 additions & 10 deletions c10/cuda/CUDACachingAllocator.cpp
@@ -376,22 +376,33 @@ struct THCCachingAllocator
cacheInfoAux(small_blocks, dev_id, total, largest);
}

void recordStream(void* ptr, cuda::CUDAStream stream)
void recordStream(void* ptr, cuda::CUDAStream stream, bool suppressError=false)
{
// Empty tensor's storage().data() might be a null ptr. As there is no
// blocks associated with those tensors, it is fine to do nothing here.
if (ptr) {
std::lock_guard<std::recursive_mutex> lock(mutex);
Block* block = find_allocated_block(ptr);
if (!block) {
AT_ERROR("invalid device pointer: ", ptr);
}
if (stream.stream() == block->stream) {
// ignore uses on the allocation stream, since those don't require any
// special synchronization
return;
// In some cases (e.g., a tensor loaded from a blob, or shared by another
// process), this CUDACachingAllocator does not know about the ptr, and
// the caller of this function might not have enough context to check
// where the tensor originated. One option is to expose a new API from
// CUDACachingAllocator to check whether it knows about the ptr, but that
// would force other use cases to unnecessarily do two map lookups (one
// check + one recordStream). Hence, we provide a suppressError argument
// to avoid both the error and the extra lookup.
if (!suppressError) {
AT_ERROR("invalid device pointer: ", ptr);
}
} else {
if (stream.stream() == block->stream) {
// ignore uses on the allocation stream, since those don't require any
// special synchronization
return;
}
block->stream_uses.insert(stream);
}
block->stream_uses.insert(stream);
}
}

@@ -651,9 +662,9 @@ void* getBaseAllocation(void *ptr, size_t *size)
return caching_allocator.getBaseAllocation(ptr, size);
}

void recordStream(void *ptr, cuda::CUDAStream stream)
void recordStream(void *ptr, cuda::CUDAStream stream, bool suppressError)
{
caching_allocator.recordStream(ptr, stream);
caching_allocator.recordStream(ptr, stream, suppressError);
}

std::mutex* getFreeMutex()
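
To make the "shared by another process" case in the comment above concrete, here is a minimal Python sketch (an illustration, not part of this PR) of how a CUDA tensor created in one process arrives in a spawned child backed by IPC-mapped memory that the child's caching allocator never allocated, which is exactly the kind of pointer recordStream used to reject:

import torch
import torch.multiprocessing as mp

def child(rank, shared):
    # `shared` lives in cudaIpc-mapped memory opened by this process; the
    # local CUDACachingAllocator has no Block for its data pointer, so a
    # recordStream() on it would previously raise "invalid device pointer".
    print(shared.device, hex(shared.storage().data_ptr()))

if __name__ == "__main__":
    t = torch.ones(2, 2, device="cuda:0")  # allocated by this process's caching allocator
    mp.spawn(child, args=(t,), nprocs=1, join=True)
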
2 changes: 1 addition & 1 deletion c10/cuda/CUDACachingAllocator.h
@@ -46,7 +46,7 @@ C10_CUDA_API Allocator* get();
C10_CUDA_API void emptyCache();
C10_CUDA_API void cacheInfo(int dev_id, size_t* cachedAndFree, size_t* largestBlock);
C10_CUDA_API void* getBaseAllocation(void *ptr, size_t *size);
C10_CUDA_API void recordStream(void *ptr, CUDAStream stream);
C10_CUDA_API void recordStream(void *ptr, CUDAStream stream, bool suppressError=false);
C10_CUDA_API uint64_t currentMemoryAllocated(int device);
C10_CUDA_API uint64_t maxMemoryAllocated(int device);
C10_CUDA_API void resetMaxMemoryAllocated(int device);
189 changes: 189 additions & 0 deletions test/test_c10d_spawn.py
@@ -0,0 +1,189 @@
import sys
import tempfile
import unittest

import torch
import torch.distributed as c10d
import torch.multiprocessing as mp

from common_cuda import TEST_MULTIGPU
from common_utils import TestCase, load_tests, run_tests
from common_utils import NO_MULTIPROCESSING_SPAWN

# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests

if not c10d.is_available():
    print('c10d not available, skipping tests')
    sys.exit(0)


if NO_MULTIPROCESSING_SPAWN:
    print('spawn not available, skipping tests')
    sys.exit(0)


NO_NCCL = not hasattr(c10d, "ProcessGroupNCCL")


class ProcessGroupShareTensorTest(TestCase):

    @property
    def world_size(self):
        return 2
Contributor: Wouldn't world_size = 2 be a much more idiomatic way to write this?

Contributor Author: Yes, let me make the change.
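
For reference, a minimal sketch of the class-attribute form the reviewer is suggesting (a hypothetical simplification of the test class above):

class ProcessGroupShareTensorTest(TestCase):
    # plain class attribute instead of a read-only property
    world_size = 2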


    @classmethod
    def opts(cls, threads=2):
        opts = c10d.ProcessGroupGloo.Options()
        opts.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")]
        opts.timeout = 5.0
        opts.threads = threads
        return opts

    @classmethod
    def _init_pg_gloo(cls, rank, filename, world_size):
        store = c10d.FileStore(filename, world_size)
        return c10d.ProcessGroupGloo(
            store, rank, world_size, ProcessGroupShareTensorTest.opts())

    @classmethod
    def _init_pg_nccl(cls, rank, filename, world_size):
        store = c10d.FileStore(filename, world_size)
        return c10d.ProcessGroupNCCL(store, rank, world_size)

    @classmethod
    def assert_equal(cls, expected, value):
        assert (expected == value).all().item() == 1, (
            "Expecting tensor value {} but got {}."
        ).format(expected, value)
Contributor: This is a bit suboptimal, because the way this will look to the test runner is that the multiprocessing-spawned subprocess died unceremoniously, and you have to read the tea leaves to see, "Ah, it failed due to an assert." The way test_multiprocessing goes about arranging this is to have the parent process (i.e., the test runner) always be responsible for the actual asserts (where you can do a normal self.assertEqual), and just have the child process pass back the tensors for the parent process to check (or, if the child process must do the test, pass back a boolean saying whether the result worked).

What does the test suite output look like when this assert fails?

Contributor: Hmm, I guess because you are using multiprocessing.spawn, the exceptions will get propagated backwards.

Contributor Author: I agree. I don't like reinventing the wheel either. Will make the change.

> What does the test suite output look like when this assert fails?

I have only tried error cases in local pytest runs, where the error messages are clear. Let me add a deliberate error to see what the CI test shows.
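
A sketch of the parent-asserts pattern the reviewer describes: children only run the collective and ship results back, while the parent does the asserting with the usual unittest machinery. The helper names and the Manager-based result passing are illustrative assumptions, not the code in this PR; it relies on the imports already at the top of this file.

# Child: run the collective and ship the result back; no asserts here.
def _broadcast_child(rank, filename, shared_tensors, world_size, init_pg, results):
    pg = init_pg(rank, filename, world_size)
    xs = [shared_tensors[rank]]
    pg.broadcast(xs).wait()
    results[rank] = xs[0].to("cpu")

# Parent (inside the TestCase): spawn children, then assert, so failures
# surface as ordinary test failures rather than dead subprocesses.
def test_shared_broadcast_gloo(self):
    with tempfile.NamedTemporaryFile(delete=False) as file:
        shared_tensors = [torch.ones(2, 2).to(i) * i for i in range(2)]
        manager = mp.Manager()  # torch.multiprocessing re-exports the stdlib multiprocessing API
        results = manager.list([None] * self.world_size)
        mp.spawn(_broadcast_child,
                 args=(file.name,
                       shared_tensors,
                       self.world_size,
                       ProcessGroupShareTensorTest._init_pg_gloo,
                       results),
                 nprocs=self.world_size,
                 join=True)
        for rank in range(self.world_size):
            self.assertEqual(torch.zeros(2, 2), results[rank])

Keeping the child as a top-level function also sidesteps the pickling limitation noted in the comment below.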


    # Why classmethod? multiprocessing cannot pickle TestCase subclass when in
    # spawn mode. See https://bugs.python.org/issue33884.
Contributor (@ezyang, Jun 7, 2019): FWIW, test_multiprocessing.py does this by just having the test runner methods as honest-to-goodness top-level functions. This is just an informational comment, since a class method is just as good.

    @classmethod
    def _test_broadcast_process(
            cls, rank, filename, shared_tensors, world_size, init_pg):
        pg = init_pg(rank, filename, world_size)
        xs = [shared_tensors[rank]]
        pg.broadcast(xs).wait()
        cls.assert_equal(torch.zeros(2, 2), xs[0].to("cpu"))

    @unittest.skipIf(not TEST_MULTIGPU, "At least 2 CUDA GPUS needed")
    def test_shared_broadcast_gloo(self):
        with tempfile.NamedTemporaryFile(delete=False) as file:
            shared_tensors = [torch.ones(2, 2).to(i) * i for i in range(2)]
            mp.spawn(ProcessGroupShareTensorTest._test_broadcast_process,
                     args=(file.name,
                           shared_tensors,
                           self.world_size,
                           ProcessGroupShareTensorTest._init_pg_gloo),
                     nprocs=self.world_size,
                     join=True)

    @unittest.skipIf(not TEST_MULTIGPU, "At least 2 CUDA GPUS needed")
    @unittest.skipIf(NO_NCCL, "NCCL needed")
    def test_shared_broadcast_nccl(self):
        with tempfile.NamedTemporaryFile(delete=False) as file:
            shared_tensors = [torch.ones(2, 2).to(i) * i for i in range(2)]
            mp.spawn(ProcessGroupShareTensorTest._test_broadcast_process,
                     args=(file.name,
                           shared_tensors,
                           self.world_size,
                           ProcessGroupShareTensorTest._init_pg_nccl),
                     nprocs=self.world_size,
                     join=True)

    @classmethod
    def _test_allreduce_process(
            cls, rank, filename, shared_tensors, world_size, init_pg):
        pg = init_pg(rank, filename, world_size)
        xs = [shared_tensors[rank]]
        pg.allreduce(xs, op=c10d.ReduceOp.SUM).wait()
        cls.assert_equal(torch.ones(2, 2) * 2, xs[0].to("cpu"))

    @unittest.skipIf(not TEST_MULTIGPU, "At least 2 CUDA GPUS needed")
    def test_shared_allreduce_gloo(self):
        with tempfile.NamedTemporaryFile(delete=False) as file:
            shared_tensors = [torch.ones(2, 2).to(i) for i in range(2)]
            mp.spawn(ProcessGroupShareTensorTest._test_allreduce_process,
                     args=(file.name,
                           shared_tensors,
                           self.world_size,
                           ProcessGroupShareTensorTest._init_pg_gloo),
                     nprocs=self.world_size,
                     join=True)

    @unittest.skipIf(not TEST_MULTIGPU, "At least 2 CUDA GPUS needed")
    @unittest.skipIf(NO_NCCL, "NCCL needed")
    def test_shared_allreduce_nccl(self):
        with tempfile.NamedTemporaryFile(delete=False) as file:
            shared_tensors = [torch.ones(2, 2).to(i) for i in range(2)]
            mp.spawn(ProcessGroupShareTensorTest._test_allreduce_process,
                     args=(file.name,
                           shared_tensors,
                           self.world_size,
                           ProcessGroupShareTensorTest._init_pg_nccl),
                     nprocs=self.world_size,
                     join=True)

    @classmethod
    def _test_reduce_process(
            cls, rank, filename, shared_tensors, world_size, init_pg):
        pg = init_pg(rank, filename, world_size)
        x = shared_tensors[rank]
        pg.reduce(x, root=0, op=c10d.ReduceOp.SUM).wait()
        if rank == 0:
            cls.assert_equal(torch.ones(2, 2) * 2, x.to("cpu"))
        else:
            cls.assert_equal(torch.ones(2, 2), x.to("cpu"))

    @unittest.skipIf(not TEST_MULTIGPU, "At least 2 CUDA GPUS needed")
    @unittest.skipIf(NO_NCCL, "NCCL needed")
    def test_shared_reduce_nccl(self):
        with tempfile.NamedTemporaryFile(delete=False) as file:
            shared_tensors = [torch.ones(2, 2).to(i) for i in range(2)]
            mp.spawn(ProcessGroupShareTensorTest._test_reduce_process,
                     args=(file.name,
                           shared_tensors,
                           self.world_size,
                           ProcessGroupShareTensorTest._init_pg_nccl),
                     nprocs=self.world_size,
                     join=True)

    @classmethod
    def _test_allgather_process(
            cls, rank, filename, shared_tensors, world_size, init_pg):
        pg = init_pg(rank, filename, world_size)
        xs = [shared_tensors[rank]]
        ys = [[torch.zeros_like(xs[0]) for i in range(world_size)]]
        pg.allgather(ys, xs).wait()
        for i in range(world_size):
            cls.assert_equal(torch.ones(2, 2) * i, ys[0][i].to("cpu"))

    @unittest.skipIf(not TEST_MULTIGPU, "At least 2 CUDA GPUS needed")
    def test_shared_allgather_gloo(self):
        with tempfile.NamedTemporaryFile(delete=False) as file:
            shared_tensors = [torch.ones(2, 2).to(i) * i for i in range(2)]
            mp.spawn(ProcessGroupShareTensorTest._test_allgather_process,
                     args=(file.name,
                           shared_tensors,
                           self.world_size,
                           ProcessGroupShareTensorTest._init_pg_gloo),
                     nprocs=self.world_size,
                     join=True)

    @unittest.skipIf(not TEST_MULTIGPU, "At least 2 CUDA GPUS needed")
    @unittest.skipIf(NO_NCCL, "NCCL needed")
    def test_shared_allgather_nccl(self):
        with tempfile.NamedTemporaryFile(delete=False) as file:
Contributor: Generally speaking, it's better not to leak temporary files (which is what delete=False does, unless you explicitly delete the file later). What is the reasoning for using delete=False?

Contributor Author: I might be wrong, but if I don't add delete=False, it complains that it cannot find the tmp file on delete. I was thinking that maybe the with context and tempfile both try to delete the file when exiting the context, so one of them hits the error. Let me double check whether that is the case, and I will add a comment if so.

Contributor: I would be extremely surprised if that were the case. Here's a simple test:

macbook-pro-116:~ ezyang$ cat test.py
import tempfile

with tempfile.NamedTemporaryFile() as f:
    pass
macbook-pro-116:~ ezyang$ python test.py
macbook-pro-116:~ ezyang$

            shared_tensors = [torch.ones(2, 2).to(i) * i for i in range(2)]
            mp.spawn(ProcessGroupShareTensorTest._test_allgather_process,
                     args=(file.name,
                           shared_tensors,
                           self.world_size,
                           ProcessGroupShareTensorTest._init_pg_nccl),
                     nprocs=self.world_size,
                     join=True)
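
On the delete=False discussion above: one way to avoid leaking the temporary file while still keeping it alive for the spawned children is to remove it explicitly once mp.spawn has joined. A sketch of a hypothetical variant of one of the tests above (not the fix adopted in this PR; it assumes `import os` at the top of the file and lives inside ProcessGroupShareTensorTest):

    @unittest.skipIf(not TEST_MULTIGPU, "At least 2 CUDA GPUS needed")
    def test_shared_allgather_gloo_no_leak(self):
        file = tempfile.NamedTemporaryFile(delete=False)
        try:
            shared_tensors = [torch.ones(2, 2).to(i) * i for i in range(2)]
            mp.spawn(ProcessGroupShareTensorTest._test_allgather_process,
                     args=(file.name,
                           shared_tensors,
                           self.world_size,
                           ProcessGroupShareTensorTest._init_pg_gloo),
                     nprocs=self.world_size,
                     join=True)
        finally:
            file.close()
            os.unlink(file.name)  # explicit cleanup instead of leaking the temp file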


if __name__ == '__main__':
    run_tests()
4 changes: 2 additions & 2 deletions torch/lib/c10d/ProcessGroupGloo.cpp
@@ -161,7 +161,7 @@ void initializeStreamsEvents(
// `tensors` are created on a different stream. Hence, they must record
// new streams in this Work to prevent being freed before the Work finishes.
c10::cuda::CUDACachingAllocator::recordStream(
tensors[i].storage().data(), streams[i]);
tensors[i].storage().data(), streams[i], true);
}
}

@@ -205,7 +205,7 @@ void initializeStreamsEvents(
// new streams in this Work to prevent being freed before the Work
// finishes.
c10::cuda::CUDACachingAllocator::recordStream(
tensor.storage().data(), streams[i]);
tensor.storage().data(), streams[i], true);
}
}
}
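
The recordStream calls above implement the same contract that recent PyTorch builds expose to Python as Tensor.record_stream: when memory allocated on one stream is consumed on another, the caching allocator must be told, so it does not recycle that memory for a new allocation while the side-stream work is still pending. A minimal user-level sketch of that contract (public API illustration only, not the c10d internals touched by this PR):

import torch

# Assumes a CUDA device and a PyTorch build where Tensor.record_stream exists.
side_stream = torch.cuda.Stream()
x = torch.ones(1024, 1024, device="cuda")    # allocated on the current (default) stream

with torch.cuda.stream(side_stream):
    y = x * 2                                # x is now consumed on side_stream

# Tell the caching allocator that x is in use on side_stream; otherwise,
# freeing x could let its memory be reused by a later allocation on the
# default stream while side_stream is still reading it.
x.record_stream(side_stream)
del x
torch.cuda.current_stream().wait_stream(side_stream)
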
10 changes: 5 additions & 5 deletions torch/lib/c10d/ProcessGroupNCCL.cpp
@@ -414,7 +414,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::collective(
//
// See [Sync Streams].
c10::cuda::CUDACachingAllocator::recordStream(
inputs[i].storage().data(), ncclStream);
inputs[i].storage().data(), ncclStream, true);

C10D_NCCL_CHECK(fn(
inputs[i],
@@ -529,7 +529,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::allgather(
[&] (at::Tensor& input, at::Tensor& output,
ncclComm_t comm, at::cuda::CUDAStream& stream) {
c10::cuda::CUDACachingAllocator::recordStream(
output.storage().data(), stream
output.storage().data(), stream, true
);
return ncclAllGather(
input.data_ptr(),
@@ -548,7 +548,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::allgather(
for (size_t j = 0; j < outputTensors[0].size(); ++j) {
// See [Sync Streams].
c10::cuda::CUDACachingAllocator::recordStream(
outputTensors[i][j].storage().data(), ncclStreams[i]);
outputTensors[i][j].storage().data(), ncclStreams[i], true);

outputTensors[i][j].copy_(outputFlattened[i][j], true);
}
@@ -572,7 +572,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::reduce_scatter(
[&] (at::Tensor& input, at::Tensor& output,
ncclComm_t comm, at::cuda::CUDAStream& stream) {
c10::cuda::CUDACachingAllocator::recordStream(
output.storage().data(), stream
output.storage().data(), stream, true
);
return ncclReduceScatter(
input.data_ptr(),
@@ -591,7 +591,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::reduce_scatter(
for (size_t j = 0; j < inputTensors[0].size(); ++j) {
// See [Sync Streams].
c10::cuda::CUDACachingAllocator::recordStream(
inputTensors[i][j].storage().data(), ncclStreams[i]);
inputTensors[i][j].storage().data(), ncclStreams[i], true);

inputFlattened[i][j].copy_(inputTensors[i][j], true);
}