mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-29 03:27:04 +08:00
[xpu]feat: support multi-lora on xpu (#20616)
Signed-off-by: yan <yan.ma@intel.com>
This commit is contained in:
parent
b942c094e3
commit
a4c23314c0
@ -13,6 +13,7 @@ import triton.language as tl
|
|||||||
|
|
||||||
from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel
|
from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel
|
||||||
from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr
|
from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr
|
||||||
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import direct_register_custom_op
|
from vllm.utils import direct_register_custom_op
|
||||||
|
|
||||||
|
|
||||||
@ -283,6 +284,7 @@ try:
|
|||||||
op_func=_lora_expand,
|
op_func=_lora_expand,
|
||||||
mutates_args=["output_tensor"],
|
mutates_args=["output_tensor"],
|
||||||
fake_impl=_lora_expand_fake,
|
fake_impl=_lora_expand_fake,
|
||||||
|
dispatch_key=current_platform.dispatch_key,
|
||||||
)
|
)
|
||||||
lora_expand = torch.ops.vllm.lora_expand
|
lora_expand = torch.ops.vllm.lora_expand
|
||||||
|
|
||||||
|
|||||||
@ -13,6 +13,7 @@ import triton.language as tl
|
|||||||
|
|
||||||
from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel
|
from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel
|
||||||
from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr
|
from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr
|
||||||
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import direct_register_custom_op
|
from vllm.utils import direct_register_custom_op
|
||||||
|
|
||||||
|
|
||||||
@ -237,6 +238,7 @@ try:
|
|||||||
op_func=_lora_shrink,
|
op_func=_lora_shrink,
|
||||||
mutates_args=["output_tensor"],
|
mutates_args=["output_tensor"],
|
||||||
fake_impl=_lora_shrink_fake,
|
fake_impl=_lora_shrink_fake,
|
||||||
|
dispatch_key=current_platform.dispatch_key,
|
||||||
)
|
)
|
||||||
lora_shrink = torch.ops.vllm.lora_shrink
|
lora_shrink = torch.ops.vllm.lora_shrink
|
||||||
|
|
||||||
|
|||||||
@ -35,7 +35,9 @@ def _get_lora_a_ptr(lora_a_weights: list[torch.Tensor], device: torch.device):
|
|||||||
lora_strides_d1.append(lora_a_weight.stride(1))
|
lora_strides_d1.append(lora_a_weight.stride(1))
|
||||||
lora_strides_d2.append(lora_a_weight.stride(2))
|
lora_strides_d2.append(lora_a_weight.stride(2))
|
||||||
if len(lora_a_weights) > 1:
|
if len(lora_a_weights) > 1:
|
||||||
lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device)
|
lora_ptr_tensor = torch.tensor(tensor_ptrs,
|
||||||
|
device=device,
|
||||||
|
dtype=torch.uint64)
|
||||||
else:
|
else:
|
||||||
lora_ptr_tensor = lora_a_weights[0]
|
lora_ptr_tensor = lora_a_weights[0]
|
||||||
|
|
||||||
@ -89,8 +91,12 @@ def _get_lora_b_ptr(lora_weights: list[torch.Tensor], offset_start: int,
|
|||||||
|
|
||||||
if len(lora_weights) > 1:
|
if len(lora_weights) > 1:
|
||||||
# note these are device tensors
|
# note these are device tensors
|
||||||
lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device)
|
lora_ptr_tensor = torch.tensor(tensor_ptrs,
|
||||||
slice_start_tensor = torch.tensor(slice_offset_lst, device=device)
|
device=device,
|
||||||
|
dtype=torch.uint64)
|
||||||
|
slice_start_tensor = torch.tensor(slice_offset_lst,
|
||||||
|
device=device,
|
||||||
|
dtype=torch.uint64)
|
||||||
else:
|
else:
|
||||||
slice_start_tensor = slice_offset_lst[0]
|
slice_start_tensor = slice_offset_lst[0]
|
||||||
lora_ptr_tensor = lora_b_weight[0]
|
lora_ptr_tensor = lora_b_weight[0]
|
||||||
|
|||||||
@ -27,6 +27,7 @@ from vllm.config import (ModelConfig, ParallelConfig, VllmConfig,
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
VocabParallelEmbedding)
|
VocabParallelEmbedding)
|
||||||
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import FlexibleArgumentParser, PlaceholderModule
|
from vllm.utils import FlexibleArgumentParser, PlaceholderModule
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -513,7 +514,9 @@ def deserialize_tensorizer_model(model: nn.Module,
|
|||||||
**tensorizer_args.stream_kwargs) as stream, TensorDeserializer(
|
**tensorizer_args.stream_kwargs) as stream, TensorDeserializer(
|
||||||
stream,
|
stream,
|
||||||
dtype=tensorizer_config.dtype,
|
dtype=tensorizer_config.dtype,
|
||||||
device=torch.device("cuda", torch.cuda.current_device()),
|
device=f'xpu:{torch.xpu.current_device()}'
|
||||||
|
if current_platform.is_xpu() else
|
||||||
|
f'cuda:{torch.cuda.current_device()}',
|
||||||
**tensorizer_args.deserialization_kwargs) as deserializer:
|
**tensorizer_args.deserialization_kwargs) as deserializer:
|
||||||
deserializer.load_into_module(model)
|
deserializer.load_into_module(model)
|
||||||
end = time.perf_counter()
|
end = time.perf_counter()
|
||||||
|
|||||||
@ -58,6 +58,10 @@ class XPUPlatform(Platform):
|
|||||||
def get_device_name(cls, device_id: int = 0) -> str:
|
def get_device_name(cls, device_id: int = 0) -> str:
|
||||||
return torch.xpu.get_device_name(device_id)
|
return torch.xpu.get_device_name(device_id)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_punica_wrapper(cls) -> str:
|
||||||
|
return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_device_total_memory(cls, device_id: int = 0) -> int:
|
def get_device_total_memory(cls, device_id: int = 0) -> int:
|
||||||
device_props = torch.xpu.get_device_properties(device_id)
|
device_props = torch.xpu.get_device_properties(device_id)
|
||||||
@ -78,6 +82,13 @@ class XPUPlatform(Platform):
|
|||||||
if cache_config and cache_config.block_size is None:
|
if cache_config and cache_config.block_size is None:
|
||||||
cache_config.block_size = 64
|
cache_config.block_size = 64
|
||||||
|
|
||||||
|
# FIXME: Temporarily forcing eager mode
|
||||||
|
# remove after t.compile support stabilizes.
|
||||||
|
if (envs.VLLM_USE_V1 and vllm_config.model_config is not None
|
||||||
|
and not vllm_config.model_config.enforce_eager):
|
||||||
|
from vllm.config import CompilationLevel
|
||||||
|
vllm_config.compilation_config.level = CompilationLevel.NO_COMPILATION # noqa: E501
|
||||||
|
|
||||||
# Instances created using VllmConfig() typically have model_config as
|
# Instances created using VllmConfig() typically have model_config as
|
||||||
# None by default. The modification involves adding a check to prevent
|
# None by default. The modification involves adding a check to prevent
|
||||||
# potential null exceptions check and update model config.
|
# potential null exceptions check and update model config.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user