mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-02 18:04:27 +08:00
[OOT] Support sync_model_loading for OOT (#25126)
Signed-off-by: Chendi Xue <Chendi.Xue@intel.com>
This commit is contained in:
parent
6c8a3c099b
commit
a6149aa587
@ -12,7 +12,6 @@ from torch.nn import Parameter
|
|||||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size)
|
get_tensor_model_parallel_world_size)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.utils import _make_synced_weight_loader
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter",
|
"BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter",
|
||||||
@ -53,8 +52,9 @@ class BasevLLMParameter(Parameter):
|
|||||||
# This sometimes causes OOM errors during model loading. To avoid this,
|
# This sometimes causes OOM errors during model loading. To avoid this,
|
||||||
# we sync the param tensor after its weight loader is called.
|
# we sync the param tensor after its weight loader is called.
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
if current_platform.is_tpu():
|
if current_platform.use_sync_weight_loader():
|
||||||
weight_loader = _make_synced_weight_loader(weight_loader)
|
weight_loader = current_platform.make_synced_weight_loader(
|
||||||
|
weight_loader)
|
||||||
|
|
||||||
self._weight_loader = weight_loader
|
self._weight_loader = weight_loader
|
||||||
self.tp_rank = get_tensor_model_parallel_rank()
|
self.tp_rank = get_tensor_model_parallel_rank()
|
||||||
|
|||||||
@ -44,23 +44,12 @@ def set_weight_attrs(
|
|||||||
# TODO(woosuk): Remove this hack once we have a better solution.
|
# TODO(woosuk): Remove this hack once we have a better solution.
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
if current_platform.is_tpu() and key == "weight_loader":
|
if current_platform.use_sync_weight_loader(
|
||||||
value = _make_synced_weight_loader(value)
|
) and key == "weight_loader":
|
||||||
|
value = current_platform.make_synced_weight_loader(value)
|
||||||
setattr(weight, key, value)
|
setattr(weight, key, value)
|
||||||
|
|
||||||
|
|
||||||
def _make_synced_weight_loader(original_weight_loader):
    """Wrap a weight loader so the parameter is synced after each load.

    Args:
        original_weight_loader: Callable invoked as
            ``original_weight_loader(param, *args, **kwargs)``.

    Returns:
        A callable with the same calling convention. It forwards every
        argument to ``original_weight_loader``, calls ``torch._sync(param)``
        when ``param`` is not on the CPU, and returns the loader's result.
    """
    import functools

    cpu_device = torch.device("cpu")

    # functools.wraps keeps the wrapped loader's name/docstring and exposes
    # it as ``__wrapped__`` so the wrapper stays identifiable when debugging.
    @functools.wraps(original_weight_loader)
    def _synced_weight_loader(param, *args, **kwargs):
        out = original_weight_loader(param, *args, **kwargs)
        # torch._sync does not support CPU tensors, and they do not need it.
        if param.device != cpu_device:
            torch._sync(param)
        return out

    return _synced_weight_loader
|
|
||||||
|
|
||||||
|
|
||||||
def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]:
|
def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]:
|
||||||
parent_map = getattr(model, "packed_modules_mapping", None)
|
parent_map = getattr(model, "packed_modules_mapping", None)
|
||||||
parent_map = copy.deepcopy(parent_map) if parent_map is not None else {}
|
parent_map = copy.deepcopy(parent_map) if parent_map is not None else {}
|
||||||
|
|||||||
@ -594,6 +594,29 @@ class Platform:
|
|||||||
"""
|
"""
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@classmethod
def use_sync_weight_loader(cls) -> bool:
    """Return whether weight loaders should be wrapped with a sync step.

    The base implementation returns ``False``. A platform subclass
    overrides this to return ``True`` when its weight loaders should be
    wrapped by ``make_synced_weight_loader``.
    """
    return False
|
||||||
|
|
||||||
|
@classmethod
def make_synced_weight_loader(cls, original_weight_loader):
    """Wrap a weight loader so the parameter is synced after each load.

    Args:
        original_weight_loader: Callable invoked as
            ``original_weight_loader(param, *args, **kwargs)``.

    Returns:
        ``original_weight_loader`` itself when
        ``cls.use_sync_weight_loader()`` is false. Otherwise a wrapper
        that forwards every argument to the original loader, calls
        ``torch._sync(param)`` when ``param`` is not on the CPU, and
        returns the loader's result.
    """
    import functools

    if not cls.use_sync_weight_loader():
        return original_weight_loader

    cpu_device = torch.device("cpu")

    # functools.wraps keeps the wrapped loader's name/docstring and exposes
    # it as ``__wrapped__`` so the wrapper stays identifiable when debugging.
    @functools.wraps(original_weight_loader)
    def _synced_weight_loader(param, *args, **kwargs):
        out = original_weight_loader(param, *args, **kwargs)
        # torch._sync does not support CPU tensors, and they do not need it.
        if param.device != cpu_device:
            torch._sync(param)
        return out

    return _synced_weight_loader
|
||||||
|
|
||||||
|
|
||||||
class UnspecifiedPlatform(Platform):
|
class UnspecifiedPlatform(Platform):
|
||||||
_enum = PlatformEnum.UNSPECIFIED
|
_enum = PlatformEnum.UNSPECIFIED
|
||||||
|
|||||||
@ -226,6 +226,10 @@ class TpuPlatform(Platform):
|
|||||||
torch.ops.xla.dynamo_set_buffer_donor_(src_cache, True)
|
torch.ops.xla.dynamo_set_buffer_donor_(src_cache, True)
|
||||||
dst_cache[dst_block_indices] = src_cache[src_block_indices].cpu()
|
dst_cache[dst_block_indices] = src_cache[src_block_indices].cpu()
|
||||||
|
|
||||||
|
@classmethod
def use_sync_weight_loader(cls) -> bool:
    """Return ``True``: this platform opts in to synced weight loaders."""
    return True
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform
|
from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user