[CustomDevice] add recompute support (#53044)
* [CustomDevice] add recompute support

* update
ronny1996 authored Apr 19, 2023
1 parent 7e19d16 commit 3206fa8
Showing 8 changed files with 86 additions and 10 deletions.
3 changes: 3 additions & 0 deletions paddle/fluid/platform/device_context.cc
@@ -109,6 +109,9 @@ inline std::unique_ptr<DeviceContext> CreateDeviceContext(
dev_ctx->SetAllocator(instance.GetAllocator(p).get());
dev_ctx->SetGenerator(phi::DefaultXPUGenerator(p.GetDeviceId()).get());
#endif
} else if (p.GetType() == phi::AllocationType::CUSTOM) {
dev_ctx->SetAllocator(instance.GetAllocator(p).get());
dev_ctx->SetGenerator(phi::DefaultCustomDeviceGenerator(p).get());
} else {
dev_ctx->SetAllocator(instance.GetAllocator(p).get());
dev_ctx->SetGenerator(phi::DefaultCPUGenerator().get());
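A hedged sketch, not part of this diff, of what the CUSTOM branch above enables from Python once combined with the Python-side changes later in this commit: seeding should make random ops on a custom device reproducible, because the device context now hands kernels the per-place generator. The device type string below is a placeholder for whichever plugin is installed.

import paddle

custom_types = paddle.device.get_all_custom_device_type()
if custom_types:
    paddle.set_device(f"{custom_types[0]}:0")  # e.g. "npu:0"; plugin-dependent

    paddle.seed(2023)
    a = paddle.rand([2, 3])
    paddle.seed(2023)
    b = paddle.rand([2, 3])
    # Both draws should come from the same reseeded per-place generator.
    print(bool(paddle.allclose(a, b)))
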
1 change: 1 addition & 0 deletions paddle/fluid/pybind/generator_py.cc
@@ -88,6 +88,7 @@ void BindGenerator(py::module* m_ptr) {
m.def("default_cpu_generator", &phi::DefaultCPUGenerator);
m.def("default_cuda_generator", &phi::DefaultCUDAGenerator);
m.def("default_xpu_generator", &phi::DefaultXPUGenerator);
m.def("default_custom_device_generator", &phi::DefaultCustomDeviceGenerator);
m.def("set_random_seed_generator", &phi::SetRandomSeedGenerator);
m.def("get_random_seed_generator", &phi::GetRandomSeedGenerator);
}
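The new pybind entry makes the per-place generator reachable from Python. A minimal usage sketch, assuming some custom device plugin is registered (the device type name is plugin-dependent):

import paddle
from paddle.fluid import core

custom_types = paddle.device.get_all_custom_device_type()
if custom_types:
    place = core.CustomPlace(custom_types[0], 0)
    gen = core.default_custom_device_generator(place)  # binding added above
    gen.manual_seed(42)      # same Generator interface as the cpu/cuda/xpu defaults
    state = gen.get_state()  # snapshot the RNG state ...
    gen.set_state(state)     # ... and restore it later
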
11 changes: 11 additions & 0 deletions paddle/phi/core/generator.cc
@@ -99,6 +99,17 @@ const std::shared_ptr<Generator>& DefaultCPUGenerator() {
return default_cpu_generator;
}

const std::shared_ptr<Generator>& DefaultCustomDeviceGenerator(
const phi::CustomPlace& place) {
static std::
unordered_map<phi::Place, std::shared_ptr<Generator>, phi::Place::Hash>
generators;
if (generators.find(place) == generators.end()) {
generators.insert({place, std::make_shared<Generator>(GetRandomSeed())});
}
return generators[place];
}

using RNGMap = std::unordered_map<std::string, std::shared_ptr<Generator>>;

static RNGMap& GetRandomSeedGeneratorMap() {
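DefaultCustomDeviceGenerator lazily creates one Generator per place and caches it in a static map keyed by phi::Place, so each custom device index keeps its own RNG stream. A rough Python analogue of that caching pattern, purely illustrative and not Paddle code:

import random

_generators = {}  # stands in for the static unordered_map<Place, shared_ptr<Generator>>

def default_custom_device_generator(place_key):
    # Create the per-place RNG on first use, then keep returning the same object.
    if place_key not in _generators:
        _generators[place_key] = random.Random()  # stands in for Generator(GetRandomSeed())
    return _generators[place_key]

assert default_custom_device_generator(("custom_cpu", 0)) is default_custom_device_generator(("custom_cpu", 0))
assert default_custom_device_generator(("custom_cpu", 0)) is not default_custom_device_generator(("custom_cpu", 1))
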
5 changes: 5 additions & 0 deletions paddle/phi/core/generator.h
@@ -25,6 +25,8 @@ limitations under the License. */
#include <typeinfo>
#include <utility>

#include "paddle/phi/common/place.h"

namespace phi {

class Generator {
@@ -80,6 +82,9 @@ const std::shared_ptr<Generator>& DefaultCUDAGenerator(int64_t device_id = -1);

const std::shared_ptr<Generator>& DefaultXPUGenerator(int64_t device_id = -1);

const std::shared_ptr<Generator>& DefaultCustomDeviceGenerator(
const phi::CustomPlace& place);

std::shared_ptr<std::mt19937_64> GetCPURandomEngine(uint64_t);

const std::shared_ptr<Generator>& SetRandomSeedGenerator(
@@ -205,7 +205,12 @@ def _dygraph_clip(self, params_grads):
clip_var_fp16 = paddle.cast(clip_var, paddle.float16)

# bf16 is not supported on XPU now
if not paddle.is_compiled_with_xpu():
if not (
paddle.is_compiled_with_xpu()
or isinstance(
paddle.framework._current_expected_place(), paddle.CustomPlace
)
):
clip_var_bf16 = paddle.cast(clip_var, paddle.bfloat16)
for p, g in params_grads:
if g is None:
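The bf16 cast of clip_var is now skipped both on XPU builds and when the current expected place is a CustomPlace. A short sketch of the same condition, assuming dygraph mode:

import paddle

place = paddle.framework._current_expected_place()
skip_bf16 = paddle.is_compiled_with_xpu() or isinstance(place, paddle.CustomPlace)
print(f"current place: {place}; cast clip_var to bfloat16: {not skip_bf16}")
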
10 changes: 8 additions & 2 deletions python/paddle/distributed/fleet/recompute/recompute.py
@@ -222,13 +222,19 @@ def _recompute_without_reentrant(

if preserve_rng_state:
cur_device = paddle.get_device()
if 'gpu:' not in cur_device:
if 'gpu:' in cur_device:
fw_cuda_rng_state = paddle.get_cuda_rng_state()
elif (
cur_device.split(':')[0]
in paddle.device.get_all_custom_device_type()
):
fw_cuda_rng_state = paddle.get_rng_state(cur_device)
else:
raise RuntimeError(
"Recompute with RNG perserve is not support current device: {}.".format(
cur_device
)
)
fw_cuda_rng_state = paddle.get_cuda_rng_state()
fwd_cuda_rng_state_tracker = (
get_rng_state_tracker().get_states_tracker()
)
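With this branch, RNG-preserving recompute no longer requires a CUDA device: on a custom device it snapshots the state via paddle.get_rng_state(cur_device). A hedged end-to-end sketch follows; the custom device type depends on the installed plugin, and use_reentrant=False is passed only to route through _recompute_without_reentrant shown above.

import paddle
from paddle.distributed.fleet.utils import recompute

class Block(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.linear = paddle.nn.Linear(16, 16)
        self.dropout = paddle.nn.Dropout(0.5)  # consumes RNG state inside the recomputed block

    def forward(self, x):
        return self.dropout(self.linear(x))

custom_types = paddle.device.get_all_custom_device_type()
if custom_types:
    paddle.set_device(f"{custom_types[0]}:0")  # e.g. "npu:0"; plugin-dependent
elif paddle.is_compiled_with_cuda():
    paddle.set_device("gpu:0")
else:
    raise SystemExit("this sketch needs a gpu or a custom device for preserve_rng_state")

block = Block()
x = paddle.randn([4, 16])
x.stop_gradient = False
# preserve_rng_state saves and restores the device RNG so the dropout mask
# regenerated during the backward recompute matches the one used in forward.
y = recompute(block, x, use_reentrant=False, preserve_rng_state=True)
y.sum().backward()
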
15 changes: 10 additions & 5 deletions python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
@@ -154,14 +154,19 @@ def _broadcast_object_list_help(object_list, hcg):
def broadcast_input_data(hcg, *inputs, **kwargs):
cur_device = paddle.get_device()
dev = cur_device.split(":")[0]
assert dev in [
"xpu",
"gpu",
"npu",
], f"Only support xpu, gpu and npu now, but this is {dev}"
assert (
dev
in [
"xpu",
"gpu",
]
or dev in paddle.device.get_all_custom_device_type()
), f"Only support xpu, gpu and custom_device now, but this is {dev}"
dev_idx = int(cur_device.split(':')[1])
if dev == "gpu":
place = paddle.CUDAPlace(dev_idx)
elif dev in paddle.device.get_all_custom_device_type():
place = paddle.CustomPlace(dev, dev_idx)
else:
place = eval(f"paddle.{dev.upper()}Place")(dev_idx)

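broadcast_input_data now also accepts devices whose type is registered through the custom device mechanism, mapping them onto paddle.CustomPlace. A small stand-alone sketch of that device-string-to-place mapping (device strings are illustrative; the final branch mirrors the eval-based fallback above):

import paddle

def place_from_device_str(cur_device):
    # "gpu:0" -> CUDAPlace(0), "<custom_type>:0" -> CustomPlace(type, 0),
    # otherwise e.g. "xpu:0" -> XPUPlace(0).
    dev, idx = cur_device.split(":")[0], int(cur_device.split(":")[1])
    if dev == "gpu":
        return paddle.CUDAPlace(idx)
    if dev in paddle.device.get_all_custom_device_type():
        return paddle.CustomPlace(dev, idx)
    return getattr(paddle, f"{dev.upper()}Place")(idx)

for dev_type in paddle.device.get_all_custom_device_type():
    print(place_from_device_str(f"{dev_type}:0"))
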
44 changes: 42 additions & 2 deletions python/paddle/framework/random.py
@@ -13,6 +13,7 @@
# limitations under the License.

# TODO: define random api
import paddle
from paddle import fluid
from paddle.fluid import core

@@ -48,7 +49,18 @@ def seed(seed):
elif core.is_compiled_with_xpu():
for i in range(core.get_xpu_device_count()):
core.default_xpu_generator(i).manual_seed(seed)

place = fluid.framework._current_expected_place()
if isinstance(place, core.CustomPlace):
dev_cnt = sum(
[
place.get_device_type() == s.split(':')[0]
for s in core.get_available_custom_device()
]
)
for i in range(dev_cnt):
core.default_custom_device_generator(
core.CustomPlace(place.get_device_type(), i)
).manual_seed(seed)
return core.default_cpu_generator().manual_seed(seed)


@@ -70,7 +82,7 @@ def get_rng_state(device=None):
if device is None:
place = fluid.framework._current_expected_place()
else:
place = device._convert_to_place(device)
place = paddle.device._convert_to_place(device)

if isinstance(place, core.CPUPlace):
state_list.append(core.default_cpu_generator().get_state())
@@ -80,6 +92,19 @@
elif isinstance(place, core.XPUPlace):
for i in range(core.get_xpu_device_count()):
state_list.append(core.default_xpu_generator(i).get_state())
elif isinstance(place, core.CustomPlace):
dev_cnt = sum(
[
place.get_device_type() == s.split(':')[0]
for s in core.get_available_custom_device()
]
)
for i in range(dev_cnt):
state_list.append(
core.default_custom_device_generator(
core.CustomPlace(place.get_device_type(), i)
).get_state()
)
else:
raise ValueError(
"get_rng_state is not implemented for current device: {}".format(
@@ -157,6 +182,21 @@ def set_rng_state(state_list, device=None):
)
for i in range(core.get_xpu_device_count()):
core.default_xpu_generator(i).set_state(state_list[i])
elif isinstance(place, core.CustomPlace):
dev_cnt = sum(
[
place.get_device_type() == s.split(':')[0]
for s in core.get_available_custom_device()
]
)
if not len(state_list) == dev_cnt:
raise ValueError(
f"Length of custom device state list shoule be equal to the {place.get_dtype_type()} device count"
)
for i in range(dev_cnt):
core.default_custom_device_generator(
core.CustomPlace(place.get_device_type(), i)
).set_state(state_list[i])
elif isinstance(place, core.CPUPlace):
if not len(state_list) == 1:
raise ValueError("Length of cpu state list shoule be equal to 1")
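Taken together, seed, get_rng_state and set_rng_state now cover custom devices. A hedged round-trip sketch; the device string is plugin-dependent, and whether a given kernel actually consumes the generator state is up to the plugin:

import paddle

custom_types = paddle.device.get_all_custom_device_type()
if custom_types:
    dev = f"{custom_types[0]}:0"
    paddle.set_device(dev)

    paddle.seed(1234)                  # seeds every generator of this custom device type
    saved = paddle.get_rng_state(dev)  # one state entry per visible device of this type
    a = paddle.rand([2, 2])

    paddle.set_rng_state(saved, dev)   # restore the snapshot taken before drawing `a`
    b = paddle.rand([2, 2])
    print(bool(paddle.allclose(a, b)))  # expected True when the kernel uses the generator
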
