[XPU] Fix the bug of LoRA logits on the XPU platform (#24081)
Signed-off-by: chzhang <chaojun.zhang@intel.com>
commit 862f2ef893 (parent 2fd1a40a54)
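In short: the XPU LoRA logits path indexed with sampler indices sized for the padded batch, while `logits` only holds rows for the live tokens, presumably causing an out-of-range or misaligned gather. The hunks below truncate the padded indices to `logits.size(0)` on XPU, as the TPU path already did, and narrow the wrapper's own indices inside `add_lora_logits`. A minimal sketch of the shape mismatch, with illustrative sizes that are not taken from the commit:

    import torch

    num_tokens = 4        # rows actually present in `logits`
    padded_len = 8        # index buffer sized for the padded batch
    logits = torch.zeros(num_tokens, 32000)
    indices_padded = torch.arange(padded_len)

    # Indexing a num_tokens-row tensor with all padded_len indices is the bug;
    # truncating to the live row count keeps the gather in bounds:
    indices_padded = indices_padded[:logits.size(0)]
    assert indices_padded.numel() == logits.size(0)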
@@ -1151,7 +1151,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
         lora_logits = lora_logits.mT
         indices_padded = self.punica_wrapper.sampler_indices_padded

-        if current_platform.is_tpu():
+        if current_platform.is_tpu() or current_platform.is_xpu():
             indices_padded = indices_padded[:logits.size(0)]

         lora_logits = (lora_logits.reshape(
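The statement truncated above at `lora_logits.reshape(` flattens the per-LoRA logits and presumably gathers rows with the (now truncated) `indices_padded`. A rough sketch of that shape logic, with made-up sizes rather than vLLM's exact code:

    import torch

    num_loras, vocab_size, num_tokens = 2, 16, 4
    # .mT swaps the last two dims: (loras, vocab, tokens) -> (loras, tokens, vocab)
    lora_logits = torch.randn(num_loras, vocab_size, num_tokens).mT
    flat = lora_logits.reshape(num_loras * num_tokens, vocab_size)

    # After truncation the indices stay within flat's row count:
    indices_padded = torch.arange(num_tokens)
    gathered = flat.index_select(0, indices_padded)   # (num_tokens, vocab_size)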
@@ -225,6 +225,13 @@ class PunicaWrapperXPU(PunicaWrapperBase):
                         add_inputs=True,
                         **kwargs)

+    @property
+    def sampler_indices_padded(self) -> torch.Tensor:
+        """
+        This property provides access to padded sampler indices.
+        """
+        return self._sampler_indices_padded[:]
+
     def add_lora_logits(self,
                         y: torch.Tensor,
                         x: torch.Tensor,
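The new property gives the XPU wrapper the same accessor the other Punica wrappers expose. Note that `[:]` returns a view over the same storage, so callers always observe the wrapper's in-place index updates rather than a stale copy. A reduced sketch of that behavior (the class here is a stand-in, not vLLM's):

    import torch

    class WrapperSketch:
        def __init__(self, max_batch: int):
            self._sampler_indices_padded = torch.zeros(max_batch, dtype=torch.long)

        @property
        def sampler_indices_padded(self) -> torch.Tensor:
            return self._sampler_indices_padded[:]

    w = WrapperSketch(8)
    view = w.sampler_indices_padded
    w._sampler_indices_padded.fill_(3)   # in-place update by the wrapper
    assert view[0].item() == 3           # the returned view sees it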
@@ -259,11 +266,11 @@ class PunicaWrapperXPU(PunicaWrapperBase):
         buffer = torch.zeros((x.size(0), r),
                              dtype=torch.float32,
                              device=x.device)

-        bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale)
+        sampler_indices = torch.narrow(self._sampler_indices, 0, 0, x.size(0))
+        bgmv_shrink(x, lora_a_stacked, buffer, sampler_indices, scale)
         bgmv_expand(buffer,
                     lora_b_stacked,
                     y,
-                    self.sampler_indices,
+                    sampler_indices,
                     add_inputs=True)
         return y.view_as(y_org)
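`torch.narrow(t, dim, start, length)` returns a view, equivalent to `t[start:start+length]` along `dim`, so trimming the indices to the live batch size costs no copy before the shrink/expand kernels run. A small sketch (sizes invented):

    import torch

    x = torch.randn(4, 128)                          # 4 live token rows
    all_indices = torch.zeros(8, dtype=torch.long)   # sized for the padded batch
    sampler_indices = torch.narrow(all_indices, 0, 0, x.size(0))
    assert sampler_indices.shape == (4,)
    # Starting at offset 0, the view shares storage with the full buffer:
    assert sampler_indices.data_ptr() == all_indices.data_ptr()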
@@ -91,7 +91,7 @@ class XPUPlatform(Platform):
         cache_config.block_size = 64

         # lazy import to avoid circular import
-        from vllm.config import CUDAGraphMode
+        from vllm.config import CompilationLevel, CUDAGraphMode
         compilation_config = vllm_config.compilation_config
         if compilation_config.cudagraph_mode is None or \
             compilation_config.cudagraph_mode.max_cudagraph_mode() \
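The in-tree comment explains the import placement: pulling `CompilationLevel` and `CUDAGraphMode` from `vllm.config` inside the method avoids a circular import at module load time. A stripped-down sketch of the pattern (class and body reduced to the import alone; assumes vLLM is installed):

    class XPUPlatformSketch:
        @classmethod
        def check_and_update_config(cls, vllm_config) -> None:
            # Deferred import: vllm.platforms and vllm.config can reference
            # each other only because neither imports the other at top level.
            from vllm.config import CompilationLevel, CUDAGraphMode
            ...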
@@ -100,6 +100,9 @@ class XPUPlatform(Platform):
                 "cudagraphs. Fallback to cudagraph_mode=NONE")
             compilation_config.cudagraph_mode = CUDAGraphMode.NONE

+        if vllm_config.lora_config is not None:
+            compilation_config.level = CompilationLevel.NO_COMPILATION
+
         # check and update parallel config
         parallel_config = vllm_config.parallel_config
         parallel_config.worker_cls = "vllm.v1.worker.xpu_worker.XPUWorker"
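Finally, the platform hook now forces eager execution whenever a LoRA config is present on XPU; the diff states only the fallback, not the underlying incompatibility. A reduced sketch of the gate (the helper name is ours, not vLLM's; assumes vLLM is installed):

    from vllm.config import CompilationLevel

    def disable_compile_for_lora(vllm_config) -> None:
        # With LoRA enabled on XPU, switch off compilation so the
        # model runs eagerly instead of through torch.compile.
        if vllm_config.lora_config is not None:
            vllm_config.compilation_config.level = CompilationLevel.NO_COMPILATION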