From a6149aa587d6582545b7878a2dffed3a2419605d Mon Sep 17 00:00:00 2001
From: "Chendi.Xue"
Date: Fri, 19 Sep 2025 00:41:53 -0500
Subject: [PATCH] [OOT] Support sync_model_loading for OOT (#25126)

Signed-off-by: Chendi Xue
---
 vllm/model_executor/parameter.py |  6 +++---
 vllm/model_executor/utils.py     | 17 +++--------------
 vllm/platforms/interface.py      | 23 +++++++++++++++++++++++
 vllm/platforms/tpu.py            |  4 ++++
 4 files changed, 33 insertions(+), 17 deletions(-)

diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py
index 221712ba9a338..03e5e5809b678 100644
--- a/vllm/model_executor/parameter.py
+++ b/vllm/model_executor/parameter.py
@@ -12,7 +12,6 @@ from torch.nn import Parameter
 from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.logger import init_logger
-from vllm.model_executor.utils import _make_synced_weight_loader
 
 __all__ = [
     "BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter",
@@ -53,8 +52,9 @@ class BasevLLMParameter(Parameter):
         # This sometimes causes OOM errors during model loading. To avoid this,
         # we sync the param tensor after its weight loader is called.
         from vllm.platforms import current_platform
-        if current_platform.is_tpu():
-            weight_loader = _make_synced_weight_loader(weight_loader)
+        if current_platform.use_sync_weight_loader():
+            weight_loader = current_platform.make_synced_weight_loader(
+                weight_loader)
 
         self._weight_loader = weight_loader
         self.tp_rank = get_tensor_model_parallel_rank()
diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py
index 65436786f82ac..543918418953b 100644
--- a/vllm/model_executor/utils.py
+++ b/vllm/model_executor/utils.py
@@ -44,23 +44,12 @@ def set_weight_attrs(
         # we sync the param tensor after its weight loader is called.
         # TODO(woosuk): Remove this hack once we have a better solution.
         from vllm.platforms import current_platform
-        if current_platform.is_tpu() and key == "weight_loader":
-            value = _make_synced_weight_loader(value)
+        if current_platform.use_sync_weight_loader(
+        ) and key == "weight_loader":
+            value = current_platform.make_synced_weight_loader(value)
         setattr(weight, key, value)
 
 
-def _make_synced_weight_loader(original_weight_loader):
-
-    def _synced_weight_loader(param, *args, **kwargs):
-        out = original_weight_loader(param, *args, **kwargs)
-        # torch._sync doesn't support, is not needed for CPU tensors.
-        if param.device != torch.device("cpu"):
-            torch._sync(param)
-        return out
-
-    return _synced_weight_loader
-
-
 def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]:
     parent_map = getattr(model, "packed_modules_mapping", None)
     parent_map = copy.deepcopy(parent_map) if parent_map is not None else {}
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index 054d08c3a85be..53fc762dce540 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -594,6 +594,29 @@ class Platform:
         """
         return False
 
+    @classmethod
+    def use_sync_weight_loader(cls) -> bool:
+        """
+        Returns if the current platform needs to sync weight loader.
+        """
+        return False
+
+    @classmethod
+    def make_synced_weight_loader(cls, original_weight_loader):
+        """
+        Wrap the original weight loader to make it synced.
+        """
+        if not cls.use_sync_weight_loader():
+            return original_weight_loader
+
+        def _synced_weight_loader(param, *args, **kwargs):
+            out = original_weight_loader(param, *args, **kwargs)
+            if param.device != torch.device("cpu"):
+                torch._sync(param)
+            return out
+
+        return _synced_weight_loader
+
 
 class UnspecifiedPlatform(Platform):
     _enum = PlatformEnum.UNSPECIFIED
diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py
index 6a061956d8141..4e4db116abca0 100644
--- a/vllm/platforms/tpu.py
+++ b/vllm/platforms/tpu.py
@@ -226,6 +226,10 @@ class TpuPlatform(Platform):
         torch.ops.xla.dynamo_set_buffer_donor_(src_cache, True)
         dst_cache[dst_block_indices] = src_cache[src_block_indices].cpu()
 
+    @classmethod
+    def use_sync_weight_loader(cls) -> bool:
+        return True
+
 
 try:
     from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform
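
Note for out-of-tree platform authors: the sketch below shows how a platform plugin could opt into the synced weight loading introduced by this patch. Only use_sync_weight_loader() and make_synced_weight_loader() come from the patched vllm/platforms/interface.py; the class name MyOOTPlatform, the PlatformEnum.OOT registration, and the idea of swapping in a backend-specific sync call are illustrative assumptions that mirror the TpuPlatform override above.

    # Hypothetical out-of-tree platform opting into synced weight loading.
    import torch

    from vllm.platforms.interface import Platform, PlatformEnum


    class MyOOTPlatform(Platform):
        # Assumption: an out-of-tree plugin registers under PlatformEnum.OOT.
        _enum = PlatformEnum.OOT

        @classmethod
        def use_sync_weight_loader(cls) -> bool:
            # Opt in: BasevLLMParameter and set_weight_attrs will wrap every
            # weight_loader with make_synced_weight_loader().
            return True

        @classmethod
        def make_synced_weight_loader(cls, original_weight_loader):
            # Optional override, shown only to mark where a backend-specific
            # sync primitive would replace torch._sync if needed.
            def _synced_weight_loader(param, *args, **kwargs):
                out = original_weight_loader(param, *args, **kwargs)
                if param.device != torch.device("cpu"):
                    torch._sync(param)
                return out

            return _synced_weight_loader

Because the base Platform keeps use_sync_weight_loader() returning False, in-tree platforms other than TPU see no behavior change, while an out-of-tree backend with lazily materialized tensors only needs the one-line use_sync_weight_loader override.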