mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-23 21:44:39 +08:00
[OOT] Support sync_model_loading for OOT (#25126)
Signed-off-by: Chendi Xue <Chendi.Xue@intel.com>
This commit is contained in:
parent
6c8a3c099b
commit
a6149aa587
@ -12,7 +12,6 @@ from torch.nn import Parameter
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.utils import _make_synced_weight_loader
|
||||
|
||||
__all__ = [
|
||||
"BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter",
|
||||
@ -53,8 +52,9 @@ class BasevLLMParameter(Parameter):
|
||||
# This sometimes causes OOM errors during model loading. To avoid this,
|
||||
# we sync the param tensor after its weight loader is called.
|
||||
from vllm.platforms import current_platform
|
||||
if current_platform.is_tpu():
|
||||
weight_loader = _make_synced_weight_loader(weight_loader)
|
||||
if current_platform.use_sync_weight_loader():
|
||||
weight_loader = current_platform.make_synced_weight_loader(
|
||||
weight_loader)
|
||||
|
||||
self._weight_loader = weight_loader
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
@ -44,23 +44,12 @@ def set_weight_attrs(
|
||||
# TODO(woosuk): Remove this hack once we have a better solution.
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.is_tpu() and key == "weight_loader":
|
||||
value = _make_synced_weight_loader(value)
|
||||
if current_platform.use_sync_weight_loader(
|
||||
) and key == "weight_loader":
|
||||
value = current_platform.make_synced_weight_loader(value)
|
||||
setattr(weight, key, value)
|
||||
|
||||
|
||||
def _make_synced_weight_loader(original_weight_loader):
|
||||
|
||||
def _synced_weight_loader(param, *args, **kwargs):
|
||||
out = original_weight_loader(param, *args, **kwargs)
|
||||
# torch._sync doesn't support, is not needed for CPU tensors.
|
||||
if param.device != torch.device("cpu"):
|
||||
torch._sync(param)
|
||||
return out
|
||||
|
||||
return _synced_weight_loader
|
||||
|
||||
|
||||
def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]:
|
||||
parent_map = getattr(model, "packed_modules_mapping", None)
|
||||
parent_map = copy.deepcopy(parent_map) if parent_map is not None else {}
|
||||
|
||||
@ -594,6 +594,29 @@ class Platform:
|
||||
"""
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def use_sync_weight_loader(cls) -> bool:
|
||||
"""
|
||||
Returns if the current platform needs to sync weight loader.
|
||||
"""
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def make_synced_weight_loader(cls, original_weight_loader):
|
||||
"""
|
||||
Wrap the original weight loader to make it synced.
|
||||
"""
|
||||
if not cls.use_sync_weight_loader():
|
||||
return original_weight_loader
|
||||
|
||||
def _synced_weight_loader(param, *args, **kwargs):
|
||||
out = original_weight_loader(param, *args, **kwargs)
|
||||
if param.device != torch.device("cpu"):
|
||||
torch._sync(param)
|
||||
return out
|
||||
|
||||
return _synced_weight_loader
|
||||
|
||||
|
||||
class UnspecifiedPlatform(Platform):
|
||||
_enum = PlatformEnum.UNSPECIFIED
|
||||
|
||||
@ -226,6 +226,10 @@ class TpuPlatform(Platform):
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(src_cache, True)
|
||||
dst_cache[dst_block_indices] = src_cache[src_block_indices].cpu()
|
||||
|
||||
@classmethod
|
||||
def use_sync_weight_loader(cls) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
try:
|
||||
from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user