[XPU] Fix the bug of LoRA logits on the XPU platform (#24081)
Signed-off-by: chzhang <chaojun.zhang@intel.com>
This commit is contained in:
parent 2fd1a40a54
commit 862f2ef893
@@ -1151,7 +1151,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
         lora_logits = lora_logits.mT
         indices_padded = self.punica_wrapper.sampler_indices_padded
 
-        if current_platform.is_tpu():
+        if current_platform.is_tpu() or current_platform.is_xpu():
             indices_padded = indices_padded[:logits.size(0)]
 
         lora_logits = (lora_logits.reshape(
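For orientation, a minimal sketch of what the truncation in this hunk does, using made-up shapes; `logits` and `indices_padded` below are illustrative stand-ins for the real tensors in LogitsProcessorWithLoRA:

import torch

# Illustrative only: 4 tokens in the batch, sampler indices padded to length 8.
logits = torch.randn(4, 32000)
indices_padded = torch.tensor([0, 1, 2, 3, 4, 4, 4, 4])

# On TPU, and with this fix also on XPU, the padded indices are cut down to
# the number of rows actually present in `logits` before they are used to
# index into the LoRA logits.
indices_padded = indices_padded[:logits.size(0)]
assert indices_padded.shape[0] == logits.size(0)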
@@ -225,6 +225,13 @@ class PunicaWrapperXPU(PunicaWrapperBase):
                         add_inputs=True,
                         **kwargs)
 
+    @property
+    def sampler_indices_padded(self) -> torch.Tensor:
+        """
+        This property provides access to padded sampler indices.
+        """
+        return self._sampler_indices_padded[:]
+
     def add_lora_logits(self,
                         y: torch.Tensor,
                         x: torch.Tensor,
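A small sketch of how the new property lines up with the call site in the first hunk (`self.punica_wrapper.sampler_indices_padded`); the wrapper class and its buffer contents below are illustrative stand-ins, not the real PunicaWrapperXPU:

import torch

class PunicaWrapperSketch:
    """Stand-in exposing only the newly added property."""

    def __init__(self) -> None:
        # The real wrapper maintains this buffer internally.
        self._sampler_indices_padded = torch.arange(8)

    @property
    def sampler_indices_padded(self) -> torch.Tensor:
        # A full slice returns a view over the internal buffer.
        return self._sampler_indices_padded[:]

wrapper = PunicaWrapperSketch()
# Mirrors the access pattern in LogitsProcessorWithLoRA above.
indices_padded = wrapper.sampler_indices_padded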
@@ -259,11 +266,11 @@ class PunicaWrapperXPU(PunicaWrapperBase):
             buffer = torch.zeros((x.size(0), r),
                                  dtype=torch.float32,
                                  device=x.device)
-        bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale)
+        sampler_indices = torch.narrow(self._sampler_indices, 0, 0, x.size(0))
+        bgmv_shrink(x, lora_a_stacked, buffer, sampler_indices, scale)
         bgmv_expand(buffer,
                     lora_b_stacked,
                     y,
-                    self.sampler_indices,
+                    sampler_indices,
                     add_inputs=True)
         return y.view_as(y_org)
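As a sketch of the narrowing step, with illustrative values: `torch.narrow(t, 0, 0, n)` takes the first `n` entries along dim 0, i.e. the same elements as `t[:n]`, so the indices handed to bgmv_shrink/bgmv_expand match the current batch size rather than the preallocated maximum:

import torch

# Illustrative stand-in for the wrapper's internal sampler indices, sized for
# the maximum number of tokens rather than the current batch.
_sampler_indices = torch.tensor([0, 1, 2, 3, 0, 0, 0, 0])

x = torch.randn(4, 16)  # hypothetical hidden states for a 4-token batch

# Narrow to the rows that correspond to tokens actually present in `x`.
sampler_indices = torch.narrow(_sampler_indices, 0, 0, x.size(0))
assert torch.equal(sampler_indices, _sampler_indices[:x.size(0)])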
@@ -91,7 +91,7 @@ class XPUPlatform(Platform):
             cache_config.block_size = 64
 
         # lazy import to avoid circular import
-        from vllm.config import CUDAGraphMode
+        from vllm.config import CompilationLevel, CUDAGraphMode
         compilation_config = vllm_config.compilation_config
         if compilation_config.cudagraph_mode is None or \
             compilation_config.cudagraph_mode.max_cudagraph_mode() \
@@ -100,6 +100,9 @@ class XPUPlatform(Platform):
                         "cudagraphs. Fallback to cudagraph_mode=NONE")
             compilation_config.cudagraph_mode = CUDAGraphMode.NONE
 
+        if vllm_config.lora_config is not None:
+            compilation_config.level = CompilationLevel.NO_COMPILATION
+
         # check and update parallel config
         parallel_config = vllm_config.parallel_config
         parallel_config.worker_cls = "vllm.v1.worker.xpu_worker.XPUWorker"
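A hedged sketch of the new LoRA branch in isolation; the config classes below are simplified stand-ins for the vllm.config types, used only to show the control flow:

from dataclasses import dataclass, field
from enum import Enum
from typing import Optional

class CompilationLevelStub(Enum):
    # Only the value used here; the real constants live in vllm.config.
    NO_COMPILATION = 0

@dataclass
class CompilationConfigStub:
    level: Optional[CompilationLevelStub] = None

@dataclass
class VllmConfigStub:
    lora_config: Optional[object] = None
    compilation_config: CompilationConfigStub = field(
        default_factory=CompilationConfigStub)

def check_and_update_config(vllm_config: VllmConfigStub) -> None:
    # Mirrors the new branch: when LoRA is configured on XPU, force the
    # compilation level down to NO_COMPILATION.
    if vllm_config.lora_config is not None:
        vllm_config.compilation_config.level = (
            CompilationLevelStub.NO_COMPILATION)

cfg = VllmConfigStub(lora_config=object())
check_and_update_config(cfg)
assert cfg.compilation_config.level is CompilationLevelStub.NO_COMPILATION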