[Distributed] Make xm.all_gather a single graph in Dynamo #4922
Conversation
torch_xla/core/xla_model.py
Outdated
@@ -78,7 +78,7 @@ def get_xla_supported_devices(devkind=None, max_devices=None):
  if kind_devices:
    return kind_devices[:max_devices] if max_devices else kind_devices


g_xrt_world_size = None
def xrt_world_size(defval=1):
@wconstab This is the python function that I want to use in 'allow_in_graph'.
Hmm, if you are going to manually cache the value of this anyway, then I think just using allow_in_graph without the caching amounts to the same thing.
The issue with allow_in_graph is that if you expect the value to be updated on later iterations, allow_in_graph will prevent that from working. But if you expect the value to be constant for the whole execution, then allow_in_graph will capture the value during compile and reuse it later (i.e. cache it).
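For reference, a minimal sketch of the manual-caching pattern being discussed (a simplification; the environment-variable fallback is an assumption for illustration, not the real implementation):

# Sketch: cache a value that is constant for the whole execution so that
# Dynamo sees a plain Python int instead of tracing a runtime call.
import os

g_xrt_world_size = None

def xrt_world_size(defval=1):
  global g_xrt_world_size
  if g_xrt_world_size is None:
    # Assumed lookup: the real function queries the XRT/PJRT runtime;
    # an env var stands in here purely for illustration.
    g_xrt_world_size = int(os.environ.get('WORLD_SIZE', defval))
  return g_xrt_world_size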
I tried to use allow_in_graph. However, it looks like the function I pass into allow_in_graph needs to return a tensor type? If the function returns a bool or an int, is there a workaround?
Here is how I use allow_in_graph:
ptxla@t1v-n-307ffe96-w-0:/workspaces/work/pytorch/xla$ git diff
diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py
index 6ff4a5a5..a07ff472 100755
--- a/torch_xla/core/xla_model.py
+++ b/torch_xla/core/xla_model.py
@@ -6,6 +6,7 @@ import time
from typing import List, Optional
import torch
import torch.distributed._functional_collectives
+from torch._dynamo import allow_in_graph
import torch.nn.functional as F
import torch_xla
from torch_xla.experimental import pjrt
@@ -1088,3 +1089,6 @@ def optimization_barrier_(tensors):
tensors (List[torch.Tensor]): List of `torch.Tensor` to add barrier to.
"""
torch_xla._XLAC._xla_optimization_barrier_(tensors)
+
+
+allow_in_graph(xrt_world_size)
And here is the error:
root@t1v-n-307ffe96-w-0:/workspaces/work/pytorch/xla# PJRT_DEVICE=TPU python test/test_mp_all_gather.py
concurrent.futures.process._RemoteTraceback:
"""
Traceback (most recent call last):
File "/usr/local/lib/python3.8/concurrent/futures/process.py", line 239, in _process_worker
r = call_item.fn(*call_item.args, **call_item.kwargs)
File "/usr/local/lib/python3.8/concurrent/futures/process.py", line 198, in _process_chunk
return [fn(*args) for args in chunk]
File "/usr/local/lib/python3.8/concurrent/futures/process.py", line 198, in <listcomp>
return [fn(*args) for args in chunk]
File "/workspaces/work/pytorch/xla/torch_xla/experimental/pjrt.py", line 92, in wrapper
return fn(*args, **kwargs)
File "/workspaces/work/pytorch/xla/torch_xla/experimental/pjrt.py", line 245, in _run_thread_per_device
replica_results = list(
File "/usr/local/lib/python3.8/concurrent/futures/_base.py", line 619, in result_iterator
yield fs.pop().result()
File "/usr/local/lib/python3.8/concurrent/futures/_base.py", line 444, in result
return self.__get_result()
File "/usr/local/lib/python3.8/concurrent/futures/_base.py", line 389, in __get_result
raise self._exception
File "/usr/local/lib/python3.8/concurrent/futures/thread.py", line 57, in run
result = self.fn(*self.args, **self.kwargs)
File "/workspaces/work/pytorch/xla/torch_xla/experimental/pjrt.py", line 238, in _thread_fn
return fn()
File "/workspaces/work/pytorch/xla/torch_xla/experimental/pjrt.py", line 341, in __call__
self.fn(global_ordinal(), *self.args, **self.kwargs)
File "/workspaces/work/pytorch/xla/test/test_mp_all_gather.py", line 32, in _mp_fn
result = compiled_all_gather(ordinal_tensor, dim=0)
File "/workspaces/work/pytorch/torch/_dynamo/eval_frame.py", line 252, in _fn
return fn(*args, **kwargs)
File "/workspaces/work/pytorch/torch/_dynamo/eval_frame.py", line 405, in catch_errors
return callback(frame, cache_size, hooks, frame_state)
File "/workspaces/work/pytorch/torch/_dynamo/convert_frame.py", line 122, in _fn
return fn(*args, **kwargs)
File "/workspaces/work/pytorch/torch/_dynamo/convert_frame.py", line 331, in _convert_frame_assert
return _compile(
File "/workspaces/work/pytorch/torch/_dynamo/utils.py", line 169, in time_wrapper
r = func(*args, **kwargs)
File "/workspaces/work/pytorch/torch/_dynamo/convert_frame.py", line 401, in _compile
out_code = transform_code_object(code, transform)
File "/workspaces/work/pytorch/torch/_dynamo/bytecode_transformation.py", line 1000, in transform_code_object
transformations(instructions, code_options)
File "/workspaces/work/pytorch/torch/_dynamo/convert_frame.py", line 386, in transform
tracer.run()
File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 1972, in run
super().run()
File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 670, in run
and self.step()
File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 630, in step
getattr(self, inst.opname)(inst)
File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 355, in wrapper
return inner_fn(self, inst)
File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 1138, in CALL_FUNCTION_KW
self.call_function(fn, args, kwargs)
File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 521, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/workspaces/work/pytorch/torch/_dynamo/variables/functions.py", line 269, in call_function
return super().call_function(tx, args, kwargs)
File "/workspaces/work/pytorch/torch/_dynamo/variables/functions.py", line 102, in call_function
return tx.inline_user_function_return(
File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 557, in inline_user_function_return
result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 2077, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 2155, in inline_call_
tracer.run()
File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 670, in run
and self.step()
File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 630, in step
getattr(self, inst.opname)(inst)
File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 355, in wrapper
return inner_fn(self, inst)
File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 1138, in CALL_FUNCTION_KW
self.call_function(fn, args, kwargs)
File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 521, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/workspaces/work/pytorch/torch/_dynamo/variables/functions.py", line 269, in call_function
return super().call_function(tx, args, kwargs)
File "/workspaces/work/pytorch/torch/_dynamo/variables/functions.py", line 102, in call_function
return tx.inline_user_function_return(
File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 557, in inline_user_function_return
result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 2077, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 2155, in inline_call_
tracer.run()
File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 670, in run
and self.step()
File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 630, in step
getattr(self, inst.opname)(inst)
File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 355, in wrapper
return inner_fn(self, inst)
File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 1086, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 521, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/workspaces/work/pytorch/torch/_dynamo/variables/torch.py", line 603, in call_function
tensor_variable = wrap_fx_proxy(
File "/workspaces/work/pytorch/torch/_dynamo/variables/builder.py", line 923, in wrap_fx_proxy
return wrap_fx_proxy_cls(
File "/workspaces/work/pytorch/torch/_dynamo/variables/builder.py", line 1098, in wrap_fx_proxy_cls
unimplemented(
File "/workspaces/work/pytorch/torch/_dynamo/exc.py", line 107, in unimplemented
raise Unsupported(msg)
torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function <function xrt_world_size at 0x7fb184e94ca0>
from user code:
File "/workspaces/work/pytorch/xla/torch_xla/core/xla_model.py", line 550, in all_gather
return _all_gather_using_all_reduce(
File "/workspaces/work/pytorch/xla/torch_xla/core/xla_model.py", line 511, in _all_gather_using_all_reduce
left, right = ordinal, xrt_world_size() - 1 - ordinal
Set torch._dynamo.config.verbose=True or TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
torch._dynamo.config.suppress_errors = True
"""
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "test/test_mp_all_gather.py", line 66, in <module>
xmp.spawn(_mp_fn, args=())
File "/workspaces/work/pytorch/xla/torch_xla/distributed/xla_multiprocessing.py", line 367, in spawn
return pjrt.spawn(fn, nprocs, start_method, args)
File "/workspaces/work/pytorch/xla/torch_xla/experimental/pjrt.py", line 365, in spawn
_run_multiprocess(spawn_fn, start_method=start_method)
File "/workspaces/work/pytorch/xla/torch_xla/experimental/pjrt.py", line 92, in wrapper
return fn(*args, **kwargs)
File "/workspaces/work/pytorch/xla/torch_xla/experimental/pjrt.py", line 322, in _run_multiprocess
replica_results = list(
File "/workspaces/work/pytorch/xla/torch_xla/experimental/pjrt.py", line 323, in <genexpr>
itertools.chain.from_iterable(
File "/usr/local/lib/python3.8/concurrent/futures/process.py", line 484, in _chain_from_iterable_of_lists
for element in iterable:
File "/usr/local/lib/python3.8/concurrent/futures/_base.py", line 619, in result_iterator
yield fs.pop().result()
File "/usr/local/lib/python3.8/concurrent/futures/_base.py", line 444, in result
return self.__get_result()
File "/usr/local/lib/python3.8/concurrent/futures/_base.py", line 389, in __get_result
raise self._exception
torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function <function xrt_world_size at 0x7fb184e94ca0>
from user code:
File "/workspaces/work/pytorch/xla/torch_xla/core/xla_model.py", line 550, in all_gather
return _all_gather_using_all_reduce(
File "/workspaces/work/pytorch/xla/torch_xla/core/xla_model.py", line 511, in _all_gather_using_all_reduce
left, right = ordinal, xrt_world_size() - 1 - ordinal
Set torch._dynamo.config.verbose=True or TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
torch._dynamo.config.suppress_errors = True
root@t1v-n-307ffe96-w-0:/workspaces/work/pytorch/xla#
torch_xla/core/xla_model.py
Outdated
g_ordinal = None
def get_ordinal(defval=0):
@wconstab This is the python function that I want to use in 'allow_in_graph'.
torch_xla/core/xla_model.py
Outdated
@@ -109,10 +114,15 @@ def get_ordinal(defval=0):
  Returns:
    The replication ordinal of the current thread.
  """
  if pjrt.using_pjrt():
    return pjrt.global_ordinal()
  global g_ordinal
I think this will break PJRT + v3 cases; the implementation we had checks the device in:
m.def("_xla_get_default_device_ordinal", []() {
  std::string device_str = GetCurrentThreadDevice();
  torch::lazy::BackendDevice device =
      bridge::AtenDeviceToXlaDevice(device_str);
  return device.ordinal();
});
Hmm, this is confusing. That call is in the C++ layer, so allow_in_graph won't work here.
But we could work around it by caching a map...
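A minimal sketch of that per-device caching idea (the helper name and keying by device string are hypothetical, not part of this PR):

# Hypothetical sketch: cache ordinals per XLA device string so each device
# thread resolves to a constant that Dynamo can bake into its graph.
g_ordinal_cache = {}

def get_ordinal_cached(device_str, defval=0):
  if device_str not in g_ordinal_cache:
    # Assumed lookup; the real code would go through torch_xla's C++ layer
    # (_xla_get_default_device_ordinal) or pjrt.global_ordinal().
    g_ordinal_cache[device_str] = defval
  return g_ordinal_cache[device_str]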
I am not sure, actually; effectively this function won't return a constant in the v3 cases because there are two devices per process. This is a bit tricky.
I wonder if we can bypass the v3 cases for now. What's going to happen if you add a condition here to skip this cached value when we are on v3 + PJRT?
It will introduce graph breaks in Dynamo.
Should this use thread local storage instead?
That's cool. I was not aware Python has this feature. Let me work on it.
Dynamo doesn't seem to compile in the same thread as the user code, so threading.local doesn't work here.
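For context, a minimal sketch of the threading.local approach that was considered (the setter and its call site are assumptions), and why it falls apart when Dynamo compiles on a different thread:

import threading

_tls = threading.local()

def set_ordinal(ordinal):
  # Assumed setup step: called once per device thread at startup.
  _tls.ordinal = ordinal

def get_ordinal(defval=0):
  # If Dynamo traces this on another thread, _tls has no 'ordinal'
  # attribute, so this silently falls back to defval -- which is why
  # thread-local storage did not work here.
  return getattr(_tls, 'ordinal', defval)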
@@ -533,8 +553,7 @@ def all_gather(value, dim=0, groups=None, output=None, pin_layout=True):
  A tensor which has, in the ``dim`` dimension, all the values from the
  participating replicas.
  """
  if pin_layout and xla_device_hw(
      value.device) in ('TPU', 'GPU', 'XPU') and output == None:
I think we had it because CPU was not supported at some point. Do you need to remove it because it will break dynamo?
Yea.
Thanks!
Thanks Jack for approving.
Summary:
This pull request makes xm.all_gather (the _all_gather_using_all_reduce path) a single graph in Dynamo. To do that, it caches the values of xrt_world_size() and get_ordinal() in module-level globals so Dynamo can treat them as constants, and adjusts the pin_layout device check in all_gather that would otherwise break the graph.
Test Plan:
PJRT_DEVICE=TPU python test/test_mp_all_gather.py