[OOT] Support sync_model_loading for OOT (#25126)

Signed-off-by: Chendi Xue <Chendi.Xue@intel.com>
Authored by Chendi.Xue on 2025-09-19 00:41:53 -05:00, committed by GitHub
parent 6c8a3c099b
commit a6149aa587
4 changed files with 33 additions and 17 deletions

vllm/model_executor/parameter.py

@@ -12,7 +12,6 @@ from torch.nn import Parameter
 from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.logger import init_logger
-from vllm.model_executor.utils import _make_synced_weight_loader
 
 __all__ = [
     "BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter",
@@ -53,8 +52,9 @@ class BasevLLMParameter(Parameter):
         # This sometimes causes OOM errors during model loading. To avoid this,
         # we sync the param tensor after its weight loader is called.
         from vllm.platforms import current_platform
-        if current_platform.is_tpu():
-            weight_loader = _make_synced_weight_loader(weight_loader)
+        if current_platform.use_sync_weight_loader():
+            weight_loader = current_platform.make_synced_weight_loader(
+                weight_loader)
         self._weight_loader = weight_loader
         self.tp_rank = get_tensor_model_parallel_rank()
 
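
A minimal standalone sketch of what the rewritten __init__ path now does with a parameter's weight loader; copy_loader is a toy stand-in, not code from this commit:

import torch

from vllm.platforms import current_platform


def copy_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
    # Toy stand-in for a real per-parameter shard loader.
    param.data.copy_(loaded_weight)


# On platforms that do not opt in, make_synced_weight_loader() returns the
# loader unchanged; on platforms that do (TPU today), it wraps the loader so
# the param tensor is synced right after each load, which is what avoids the
# OOMs during model loading mentioned in the comment above.
loader = current_platform.make_synced_weight_loader(copy_loader)

param = torch.nn.Parameter(torch.empty(4, 4), requires_grad=False)
loader(param, torch.ones(4, 4))  # same calling convention either way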

vllm/model_executor/utils.py

@@ -44,23 +44,12 @@ def set_weight_attrs(
         # TODO(woosuk): Remove this hack once we have a better solution.
         from vllm.platforms import current_platform
-        if current_platform.is_tpu() and key == "weight_loader":
-            value = _make_synced_weight_loader(value)
+        if current_platform.use_sync_weight_loader(
+        ) and key == "weight_loader":
+            value = current_platform.make_synced_weight_loader(value)
         setattr(weight, key, value)
 
 
-def _make_synced_weight_loader(original_weight_loader):
-
-    def _synced_weight_loader(param, *args, **kwargs):
-        out = original_weight_loader(param, *args, **kwargs)
-        # torch._sync doesn't support, is not needed for CPU tensors.
-        if param.device != torch.device("cpu"):
-            torch._sync(param)
-        return out
-
-    return _synced_weight_loader
-
-
 def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]:
     parent_map = getattr(model, "packed_modules_mapping", None)
     parent_map = copy.deepcopy(parent_map) if parent_map is not None else {}
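
The same hook is reached through set_weight_attrs(), which is how layers attach loaders to plain tensors; the stored attribute is either the original loader or its synced wrapper, decided by the platform instead of a hard-coded is_tpu() check. A small sketch, assuming a hypothetical shard_loader:

import torch

from vllm.model_executor.utils import set_weight_attrs


def shard_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
    # Hypothetical loader; real layers shard/transform before copying.
    param.data.copy_(loaded_weight)


weight = torch.nn.Parameter(torch.empty(32, 32), requires_grad=False)
set_weight_attrs(weight, {"weight_loader": shard_loader})

# weight.weight_loader is shard_loader itself where use_sync_weight_loader()
# is False, or its synced wrapper where it is True.
weight.weight_loader(weight, torch.randn(32, 32))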

vllm/platforms/interface.py

@@ -594,6 +594,29 @@ class Platform:
         """
         return False
+
+    @classmethod
+    def use_sync_weight_loader(cls) -> bool:
+        """
+        Returns if the current platform needs to sync weight loader.
+        """
+        return False
+
+    @classmethod
+    def make_synced_weight_loader(cls, original_weight_loader):
+        """
+        Wrap the original weight loader to make it synced.
+        """
+        if not cls.use_sync_weight_loader():
+            return original_weight_loader
+
+        def _synced_weight_loader(param, *args, **kwargs):
+            out = original_weight_loader(param, *args, **kwargs)
+            if param.device != torch.device("cpu"):
+                torch._sync(param)
+            return out
+
+        return _synced_weight_loader
 
 
 class UnspecifiedPlatform(Platform):
     _enum = PlatformEnum.UNSPECIFIED
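
These two hooks are the OOT-facing part of the change: a platform opts in by overriding use_sync_weight_loader(), and can also override make_synced_weight_loader() if torch._sync is not the right sync primitive for it. A sketch with hypothetical names; nothing below is from the diff:

from vllm.platforms.interface import Platform, PlatformEnum


class MyAcceleratorPlatform(Platform):
    """Hypothetical out-of-tree platform that wants synced weight loading."""

    _enum = PlatformEnum.OOT
    device_name: str = "my_accel"
    device_type: str = "my_accel"

    @classmethod
    def use_sync_weight_loader(cls) -> bool:
        # Opting in is enough: BasevLLMParameter and set_weight_attrs() will
        # now wrap every weight_loader via make_synced_weight_loader().
        return True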

vllm/platforms/tpu.py

@@ -226,6 +226,10 @@ class TpuPlatform(Platform):
         torch.ops.xla.dynamo_set_buffer_donor_(src_cache, True)
         dst_cache[dst_block_indices] = src_cache[src_block_indices].cpu()
 
+    @classmethod
+    def use_sync_weight_loader(cls) -> bool:
+        return True
+
 
 try:
     from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform
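
TPU opts in in-tree; an out-of-tree backend would do the same from its own package and be picked up through vLLM's platform plugin entry point. A sketch of that wiring with hypothetical package and module names, assuming the vllm.platform_plugins entry-point contract (a register() function returning the platform's fully qualified class name, or None):

# pyproject.toml of the hypothetical OOT package:
#
#   [project.entry-points."vllm.platform_plugins"]
#   my_accel = "my_accel_plugin:register"

# my_accel_plugin/__init__.py
from typing import Optional


def register() -> Optional[str]:
    # Return the fully qualified name of the Platform subclass when this
    # plugin can serve the current host, or None to stay out of the way.
    # vLLM then resolves current_platform to that class, so the
    # use_sync_weight_loader() override takes effect during model loading.
    return "my_accel_plugin.platform.MyAcceleratorPlatform"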