[CustomDevice] add recompute support #53044

Merged
merged 2 commits on Apr 19, 2023
3 changes: 3 additions & 0 deletions paddle/fluid/platform/device_context.cc
@@ -111,6 +111,9 @@ inline std::unique_ptr<DeviceContext> CreateDeviceContext(
dev_ctx->SetAllocator(instance.GetAllocator(p).get());
dev_ctx->SetGenerator(phi::DefaultXPUGenerator(p.GetDeviceId()).get());
#endif
} else if (p.GetType() == phi::AllocationType::CUSTOM) {
dev_ctx->SetAllocator(instance.GetAllocator(p).get());
dev_ctx->SetGenerator(phi::DefaultCustomDeviceGenerator(p).get());
} else {
dev_ctx->SetAllocator(instance.GetAllocator(p).get());
dev_ctx->SetGenerator(phi::DefaultCPUGenerator().get());
1 change: 1 addition & 0 deletions paddle/fluid/pybind/generator_py.cc
@@ -88,6 +88,7 @@ void BindGenerator(py::module* m_ptr) {
m.def("default_cpu_generator", &phi::DefaultCPUGenerator);
m.def("default_cuda_generator", &phi::DefaultCUDAGenerator);
m.def("default_xpu_generator", &phi::DefaultXPUGenerator);
m.def("default_custom_device_generator", &phi::DefaultCustomDeviceGenerator);
m.def("set_random_seed_generator", &phi::SetRandomSeedGenerator);
m.def("get_random_seed_generator", &phi::GetRandomSeedGenerator);
}
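
A minimal usage sketch (not part of the diff) of the binding added above; the plugin/device type "custom_cpu" is an assumption, any registered custom device type works the same way:

import paddle
from paddle.fluid import core

# "custom_cpu" is an assumed device type; substitute any registered type.
place = paddle.CustomPlace("custom_cpu", 0)
gen = core.default_custom_device_generator(place)  # binding added in this PR
gen.manual_seed(2023)        # seed only this device's generator
state = gen.get_state()      # snapshot the generator state
gen.set_state(state)         # restore it later for reproducible re-runs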
11 changes: 11 additions & 0 deletions paddle/phi/core/generator.cc
@@ -99,6 +99,17 @@ const std::shared_ptr<Generator>& DefaultCPUGenerator() {
return default_cpu_generator;
}

const std::shared_ptr<Generator>& DefaultCustomDeviceGenerator(
    const phi::CustomPlace& place) {
  static std::
      unordered_map<phi::Place, std::shared_ptr<Generator>, phi::Place::Hash>
          generators;
  if (generators.find(place) == generators.end()) {
    generators.insert({place, std::make_shared<Generator>(GetRandomSeed())});
  }
  return generators[place];
}

using RNGMap = std::unordered_map<std::string, std::shared_ptr<Generator>>;

static RNGMap& GetRandomSeedGeneratorMap() {
5 changes: 5 additions & 0 deletions paddle/phi/core/generator.h
@@ -25,6 +25,8 @@ limitations under the License. */
#include <typeinfo>
#include <utility>

#include "paddle/phi/common/place.h"

namespace phi {

class Generator {
@@ -80,6 +82,9 @@ const std::shared_ptr<Generator>& DefaultCUDAGenerator(int64_t device_id = -1);

const std::shared_ptr<Generator>& DefaultXPUGenerator(int64_t device_id = -1);

const std::shared_ptr<Generator>& DefaultCustomDeviceGenerator(
const phi::CustomPlace& place);

std::shared_ptr<std::mt19937_64> GetCPURandomEngine(uint64_t);

const std::shared_ptr<Generator>& SetRandomSeedGenerator(
@@ -205,7 +205,12 @@ def _dygraph_clip(self, params_grads):
clip_var_fp16 = paddle.cast(clip_var, paddle.float16)

# bf16 is not supported on XPU now
if not paddle.is_compiled_with_xpu():
if not (
paddle.is_compiled_with_xpu()
or isinstance(
paddle.framework._current_expected_place(), paddle.CustomPlace
)
):
clip_var_bf16 = paddle.cast(clip_var, paddle.bfloat16)
for p, g in params_grads:
if g is None:
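
The guard above now skips the bfloat16 cast both for XPU builds and when the current expected place is a CustomPlace, presumably because bf16 kernels are not guaranteed on custom devices. A hedged sketch of the same check in isolation:

import paddle

# Mirrors the new guard: bf16 clipping is skipped on XPU builds and on custom devices.
on_custom_device = isinstance(
    paddle.framework._current_expected_place(), paddle.CustomPlace
)
use_bf16_clip = not (paddle.is_compiled_with_xpu() or on_custom_device)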
10 changes: 8 additions & 2 deletions python/paddle/distributed/fleet/recompute/recompute.py
@@ -222,13 +222,19 @@ def _recompute_without_reentrant(

if preserve_rng_state:
cur_device = paddle.get_device()
if 'gpu:' not in cur_device:
if 'gpu:' in cur_device:
fw_cuda_rng_state = paddle.get_cuda_rng_state()
elif (
cur_device.split(':')[0]
in paddle.device.get_all_custom_device_type()
):
fw_cuda_rng_state = paddle.get_rng_state(cur_device)
else:
raise RuntimeError(
"Recompute with RNG perserve is not support current device: {}.".format(
cur_device
)
)
fw_cuda_rng_state = paddle.get_cuda_rng_state()
fwd_cuda_rng_state_tracker = (
get_rng_state_tracker().get_states_tracker()
)
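
A minimal end-to-end sketch (not part of the diff) of the path this hunk enables: recompute without reentrant autograd on a custom device, with the RNG state preserved via paddle.get_rng_state(cur_device). The device type "npu" and its plugin being installed are assumptions:

import paddle
from paddle.distributed.fleet.utils import recompute

paddle.set_device("npu")  # assumed custom device type; any registered type applies

block = paddle.nn.Sequential(
    paddle.nn.Linear(16, 16),
    paddle.nn.Dropout(0.5),  # RNG-dependent op: original and recomputed forward must match
    paddle.nn.Linear(16, 1),
)
x = paddle.randn([4, 16])
x.stop_gradient = False

# use_reentrant=False routes through _recompute_without_reentrant shown above.
y = recompute(block, x, use_reentrant=False, preserve_rng_state=True)
y.sum().backward()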
15 changes: 10 additions & 5 deletions python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
@@ -154,14 +154,19 @@ def _broadcast_object_list_help(object_list, hcg):
def broadcast_input_data(hcg, *inputs, **kwargs):
cur_device = paddle.get_device()
dev = cur_device.split(":")[0]
assert dev in [
"xpu",
"gpu",
"npu",
], f"Only support xpu, gpu and npu now, but this is {dev}"
assert (
dev
in [
"xpu",
"gpu",
]
or dev in paddle.device.get_all_custom_device_type()
), f"Only support xpu, gpu and custom_device now, but this is {dev}"
dev_idx = int(cur_device.split(':')[1])
if dev == "gpu":
place = paddle.CUDAPlace(dev_idx)
elif dev in paddle.device.get_all_custom_device_type():
place = paddle.CustomPlace(dev, dev_idx)
else:
place = eval(f"paddle.{dev.upper()}Place")(dev_idx)

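The eval(f"paddle.{dev.upper()}Place") fallback only works for built-in place classes such as XPUPlace; a custom device has no dedicated place class, so the new branch constructs paddle.CustomPlace from the device-type string and index instead. A hedged sketch of that resolution, assuming paddle.get_device() reports an accelerator such as "npu:0" for an installed plugin:

import paddle

cur_device = paddle.get_device()              # e.g. "npu:0" on a custom device
dev, idx = cur_device.split(":")[0], int(cur_device.split(":")[1])
if dev == "gpu":
    place = paddle.CUDAPlace(idx)
elif dev in paddle.device.get_all_custom_device_type():
    place = paddle.CustomPlace(dev, idx)      # device-type string + device index
else:
    place = eval(f"paddle.{dev.upper()}Place")(idx)  # built-in places such as XPUPlace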
44 changes: 42 additions & 2 deletions python/paddle/framework/random.py
@@ -13,6 +13,7 @@
# limitations under the License.

# TODO: define random api
import paddle
from paddle import fluid
from paddle.fluid import core

@@ -48,7 +49,18 @@ def seed(seed):
elif core.is_compiled_with_xpu():
for i in range(core.get_xpu_device_count()):
core.default_xpu_generator(i).manual_seed(seed)

place = fluid.framework._current_expected_place()
if isinstance(place, core.CustomPlace):
dev_cnt = sum(
[
place.get_device_type() == s.split(':')[0]
for s in core.get_available_custom_device()
]
)
for i in range(dev_cnt):
core.default_custom_device_generator(
core.CustomPlace(place.get_device_type(), i)
).manual_seed(seed)
return core.default_cpu_generator().manual_seed(seed)
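
With the addition above, paddle.seed now also reseeds one generator per visible device of the current custom device type. A small sketch (not part of the diff), assuming a plugin registering "custom_cpu" with two visible devices:

import paddle
from paddle.fluid import core

paddle.set_device("custom_cpu:0")  # assumed custom device type and index
paddle.seed(2023)                  # seeds the CPU generator and every custom_cpu generator

# The per-device generators can also be inspected directly:
g0 = core.default_custom_device_generator(paddle.CustomPlace("custom_cpu", 0))
g1 = core.default_custom_device_generator(paddle.CustomPlace("custom_cpu", 1))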


@@ -70,7 +82,7 @@ def get_rng_state(device=None):
if device is None:
place = fluid.framework._current_expected_place()
else:
place = device._convert_to_place(device)
place = paddle.device._convert_to_place(device)

if isinstance(place, core.CPUPlace):
state_list.append(core.default_cpu_generator().get_state())
@@ -80,6 +92,19 @@ def get_rng_state(device=None):
elif isinstance(place, core.XPUPlace):
for i in range(core.get_xpu_device_count()):
state_list.append(core.default_xpu_generator(i).get_state())
elif isinstance(place, core.CustomPlace):
dev_cnt = sum(
[
place.get_device_type() == s.split(':')[0]
for s in core.get_available_custom_device()
]
)
for i in range(dev_cnt):
state_list.append(
core.default_custom_device_generator(
core.CustomPlace(place.get_device_type(), i)
).get_state()
)
else:
raise ValueError(
"get_rng_state is not implemented for current device: {}".format(
@@ -157,6 +182,21 @@ def set_rng_state(state_list, device=None):
)
for i in range(core.get_xpu_device_count()):
core.default_xpu_generator(i).set_state(state_list[i])
elif isinstance(place, core.CustomPlace):
dev_cnt = sum(
[
place.get_device_type() == s.split(':')[0]
for s in core.get_available_custom_device()
]
)
if not len(state_list) == dev_cnt:
raise ValueError(
f"Length of custom device state list shoule be equal to the {place.get_dtype_type()} device count"
)
for i in range(dev_cnt):
core.default_custom_device_generator(
core.CustomPlace(place.get_device_type(), i)
).set_state(state_list[i])
elif isinstance(place, core.CPUPlace):
if not len(state_list) == 1:
raise ValueError("Length of cpu state list shoule be equal to 1")
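
A hedged round-trip sketch for the two new CustomPlace branches: snapshot the per-device states, run an RNG-dependent op, then restore the snapshot to reproduce it. The device string "custom_cpu:0" is an assumption, _convert_to_place accepting custom device strings is assumed, and set_rng_state is assumed to be exported alongside get_rng_state (which the recompute hunk above already calls):

import paddle

paddle.set_device("custom_cpu:0")              # assumed custom device
states = paddle.get_rng_state("custom_cpu:0")  # one state per visible custom_cpu device

x1 = paddle.rand([2, 3])                       # consumes the device RNG
paddle.set_rng_state(states, "custom_cpu:0")   # restore the snapshot
x2 = paddle.rand([2, 3])                       # drawn from the same generator state as x1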