[XPU] Fix the bug of LoRA logits on the XPU platform (#24081)

Signed-off-by: chzhang <chaojun.zhang@intel.com>
Chaojun Zhang authored 2025-09-03 08:21:18 +08:00, committed by GitHub
commit 862f2ef893 (parent 2fd1a40a54)
3 changed files with 15 additions and 5 deletions


@@ -1151,7 +1151,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
         lora_logits = lora_logits.mT
         indices_padded = self.punica_wrapper.sampler_indices_padded

-        if current_platform.is_tpu():
+        if current_platform.is_tpu() or current_platform.is_xpu():
             indices_padded = indices_padded[:logits.size(0)]

         lora_logits = (lora_logits.reshape(
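
A minimal sketch (illustrative shapes and names, not vLLM code) of why the padded sampler indices are sliced to logits.size(0) on TPU, and now on XPU as well: the padded index tensor is sized for the maximum batch, so without the slice the index_select in LogitsProcessorWithLoRA would gather more rows than there are live tokens.

    # Minimal sketch, not vLLM code: sizes and values are illustrative.
    import torch

    num_tokens = 3        # rows actually present in `logits`
    padded_len = 8        # sampler_indices_padded is sized for the max batch

    flat_lora_logits = torch.randn(16, 10)             # flattened (lora, token) rows x vocab
    indices_padded = torch.randint(0, 16, (padded_len,))

    # Without this slice the gather below would return `padded_len` rows,
    # which no longer line up with the `num_tokens` rows of `logits`.
    indices_padded = indices_padded[:num_tokens]

    gathered = flat_lora_logits.index_select(0, indices_padded)
    assert gathered.shape == (num_tokens, 10)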


@@ -225,6 +225,13 @@ class PunicaWrapperXPU(PunicaWrapperBase):
                         add_inputs=True,
                         **kwargs)

+    @property
+    def sampler_indices_padded(self) -> torch.Tensor:
+        """
+        This property provides access to padded sampler indices.
+        """
+        return self._sampler_indices_padded[:]
+
     def add_lora_logits(self,
                         y: torch.Tensor,
                         x: torch.Tensor,
@@ -259,11 +266,11 @@ class PunicaWrapperXPU(PunicaWrapperBase):
             buffer = torch.zeros((x.size(0), r),
                                  dtype=torch.float32,
                                  device=x.device)
-        bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale)
+        sampler_indices = torch.narrow(self._sampler_indices, 0, 0, x.size(0))
+        bgmv_shrink(x, lora_a_stacked, buffer, sampler_indices, scale)
         bgmv_expand(buffer,
                     lora_b_stacked,
                     y,
-                    self.sampler_indices,
+                    sampler_indices,
                     add_inputs=True)
         return y.view_as(y_org)
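
The two PunicaWrapperXPU hunks go together: the new sampler_indices_padded property gives the logits layer above something to slice, and add_lora_logits now narrows the preallocated index buffer to the live token count before calling the bgmv kernels. A rough sketch of the narrowing, with made-up names and sizes (only torch.narrow itself is the real API):

    # Rough sketch; `_sampler_indices` stands in for the wrapper's
    # preallocated per-token LoRA index buffer.
    import torch

    max_num_batched_tokens = 16
    _sampler_indices = torch.zeros(max_num_batched_tokens, dtype=torch.long)
    _sampler_indices[:4] = torch.tensor([0, 0, 1, 1])   # LoRA slot per live token

    x = torch.randn(4, 64)   # hidden states for the 4 tokens in this step

    # torch.narrow(input, dim, start, length) returns a view of the first
    # `length` entries along `dim`, so the kernels see exactly one index
    # per row of `x` instead of the full padded buffer.
    sampler_indices = torch.narrow(_sampler_indices, 0, 0, x.size(0))
    assert sampler_indices.shape[0] == x.size(0)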


@@ -91,7 +91,7 @@ class XPUPlatform(Platform):
             cache_config.block_size = 64

         # lazy import to avoid circular import
-        from vllm.config import CUDAGraphMode
+        from vllm.config import CompilationLevel, CUDAGraphMode
         compilation_config = vllm_config.compilation_config
         if compilation_config.cudagraph_mode is None or \
             compilation_config.cudagraph_mode.max_cudagraph_mode() \
@@ -100,6 +100,9 @@ class XPUPlatform(Platform):
                 "cudagraphs. Fallback to cudagraph_mode=NONE")
             compilation_config.cudagraph_mode = CUDAGraphMode.NONE

+        if vllm_config.lora_config is not None:
+            compilation_config.level = CompilationLevel.NO_COMPILATION
+
         # check and update parallel config
         parallel_config = vllm_config.parallel_config
         parallel_config.worker_cls = "vllm.v1.worker.xpu_worker.XPUWorker"
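
With the platform hook above, any XPU run that configures LoRA drops back to CompilationLevel.NO_COMPILATION. A usage sketch of the path this commit fixes; the model name and adapter path are placeholders, not part of the commit:

    # Usage sketch only; model and adapter are placeholders.
    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest

    llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", enable_lora=True)

    # Because lora_config is set, XPUPlatform now forces
    # compilation_config.level = CompilationLevel.NO_COMPILATION.
    outputs = llm.generate(
        ["Write a haiku about oneAPI."],
        SamplingParams(max_tokens=32),
        lora_request=LoRARequest("my_adapter", 1, "/path/to/lora_adapter"),
    )
    print(outputs[0].outputs[0].text)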