From bb7c74146d9ef44fb73121f2fc00befe8f40450e Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Wed, 16 Oct 2024 23:57:04 +0000
Subject: [PATCH 1/3] [TPU] Ensure torch._sync(param) is called after
 param.data.copy_()

---
 vllm/model_executor/utils.py | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py
index d7eec818cbba..0e66e29f5d37 100644
--- a/vllm/model_executor/utils.py
+++ b/vllm/model_executor/utils.py
@@ -4,6 +4,7 @@
 import torch
 
 from vllm.utils import seed_everything
+from vllm.platforms import current_platform
 
 
 def set_random_seed(seed: int) -> None:
@@ -28,4 +29,19 @@ def set_weight_attrs(
     for key, value in weight_attrs.items():
         assert not hasattr(
             weight, key), (f"Overwriting existing tensor attribute: {key}")
+
+        # NOTE(woosuk): For TPU, param.data.copy_(weight) happens lazily,
+        # which means that the param and weight tensors co-exist until the param
+        # tensor is used by other operations. This causes excessive memory usage
+        # during model loading. To avoid this, we sync the param tensor after
+        # its weight loader is called.
+        # TODO(woosuk): Remove this hack once we have a better solution.
+        if current_platform.is_tpu() and key == "weight_loader":
+            original_weight_loader = value
+
+            def _synced_weight_loader(param, *args, **kwargs):
+                original_weight_loader(param, *args, **kwargs)
+                torch._sync(param)
+
+            value = _synced_weight_loader
         setattr(weight, key, value)

From cf842bdb033a341bc45ac5bdc6d34a00438a7c8a Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Thu, 17 Oct 2024 00:02:57 +0000
Subject: [PATCH 2/3] yapf

---
 vllm/model_executor/utils.py | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py
index 0e66e29f5d37..eaa89be11038 100644
--- a/vllm/model_executor/utils.py
+++ b/vllm/model_executor/utils.py
@@ -3,8 +3,8 @@
 
 import torch
 
-from vllm.utils import seed_everything
 from vllm.platforms import current_platform
+from vllm.utils import seed_everything
 
 
 def set_random_seed(seed: int) -> None:
@@ -37,11 +37,14 @@ def set_weight_attrs(
         # its weight loader is called.
         # TODO(woosuk): Remove this hack once we have a better solution.
         if current_platform.is_tpu() and key == "weight_loader":
-            original_weight_loader = value
+            value = _make_synced_weight_loader(value)
+        setattr(weight, key, value)
 
-            def _synced_weight_loader(param, *args, **kwargs):
-                original_weight_loader(param, *args, **kwargs)
-                torch._sync(param)
 
-            value = _synced_weight_loader
-        setattr(weight, key, value)
+def _make_synced_weight_loader(original_weight_loader):
+
+    def _synced_weight_loader(param, *args, **kwargs):
+        original_weight_loader(param, *args, **kwargs)
+        torch._sync(param)
+
+    return _synced_weight_loader

From f5d8d91b142ca7038249399262f5828f779239d4 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Thu, 17 Oct 2024 03:52:16 +0000
Subject: [PATCH 3/3] Update comment

---
 vllm/model_executor/utils.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py
index eaa89be11038..c27b1cf6ac7b 100644
--- a/vllm/model_executor/utils.py
+++ b/vllm/model_executor/utils.py
@@ -30,11 +30,14 @@ def set_weight_attrs(
         assert not hasattr(
             weight, key), (f"Overwriting existing tensor attribute: {key}")
 
-        # NOTE(woosuk): For TPU, param.data.copy_(weight) happens lazily,
-        # which means that the param and weight tensors co-exist until the param
-        # tensor is used by other operations. This causes excessive memory usage
-        # during model loading. To avoid this, we sync the param tensor after
-        # its weight loader is called.
+        # NOTE(woosuk): During weight loading, we often do something like:
+        # narrowed_tensor = param.data.narrow(0, offset, len)
+        # narrowed_tensor.copy_(real_weight)
+        # expecting narrowed_tensor and param.data to share the same storage.
+        # However, on TPUs, narrowed_tensor will lazily propagate to the base
+        # tensor, which is param.data, leading to the redundant memory usage.
+        # This sometimes causes OOM errors during model loading. To avoid this,
+        # we sync the param tensor after its weight loader is called.
         # TODO(woosuk): Remove this hack once we have a better solution.
         if current_platform.is_tpu() and key == "weight_loader":
             value = _make_synced_weight_loader(value)
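
Below is a small, self-contained sketch (not part of the patches above) of the wrapper pattern these commits introduce: a weight loader is wrapped so the parameter is synchronized immediately after loading. The default_weight_loader here is a simplified stand-in for vLLM's real loaders, and the torch._sync(param) call is guarded so it is only attempted for XLA-backed (TPU) tensors; on CPU or GPU the copy is eager and no sync is needed.

import torch


def _make_synced_weight_loader(original_weight_loader):
    """Wrap a weight loader so the param is materialized right after loading."""

    def _synced_weight_loader(param, *args, **kwargs):
        original_weight_loader(param, *args, **kwargs)
        # On TPU, copy_() into (a view of) param.data is recorded lazily, so
        # the source weight tensor would otherwise stay alive alongside the
        # param until some later op forces evaluation. Syncing here lets it
        # be freed. torch._sync only applies to lazily evaluated tensors, so
        # the call is guarded by the device type in this sketch.
        if param.device.type == "xla":
            torch._sync(param)

    return _synced_weight_loader


def default_weight_loader(param: torch.nn.Parameter,
                          loaded_weight: torch.Tensor) -> None:
    # Simplified stand-in for a vLLM weight loader: copy into the existing
    # parameter storage (real loaders may narrow() into a shard first).
    param.data.copy_(loaded_weight)


if __name__ == "__main__":
    param = torch.nn.Parameter(torch.empty(4, 4))
    loader = _make_synced_weight_loader(default_weight_loader)
    loader(param, torch.ones(4, 4))
    print(param.data.sum())  # tensor(16.); the sync branch only runs on TPU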