diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py
index d8503b20459f..6e4b69c30325 100644
--- a/vllm/lora/layers.py
+++ b/vllm/lora/layers.py
@@ -1151,7 +1151,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
         lora_logits = lora_logits.mT
         indices_padded = self.punica_wrapper.sampler_indices_padded
 
-        if current_platform.is_tpu():
+        if current_platform.is_tpu() or current_platform.is_xpu():
             indices_padded = indices_padded[:logits.size(0)]
 
         lora_logits = (lora_logits.reshape(
diff --git a/vllm/lora/punica_wrapper/punica_xpu.py b/vllm/lora/punica_wrapper/punica_xpu.py
index 572e39e0eced..163bb412235c 100644
--- a/vllm/lora/punica_wrapper/punica_xpu.py
+++ b/vllm/lora/punica_wrapper/punica_xpu.py
@@ -225,6 +225,13 @@ class PunicaWrapperXPU(PunicaWrapperBase):
                         add_inputs=True,
                         **kwargs)
 
+    @property
+    def sampler_indices_padded(self) -> torch.Tensor:
+        """
+        This property provides access to padded sampler indices.
+        """
+        return self._sampler_indices_padded[:]
+
     def add_lora_logits(self,
                         y: torch.Tensor,
                         x: torch.Tensor,
@@ -259,11 +266,11 @@ class PunicaWrapperXPU(PunicaWrapperBase):
         buffer = torch.zeros((x.size(0), r),
                              dtype=torch.float32,
                              device=x.device)
-
-        bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale)
+        sampler_indices = torch.narrow(self._sampler_indices, 0, 0, x.size(0))
+        bgmv_shrink(x, lora_a_stacked, buffer, sampler_indices, scale)
         bgmv_expand(buffer,
                     lora_b_stacked,
                     y,
-                    self.sampler_indices,
+                    sampler_indices,
                     add_inputs=True)
         return y.view_as(y_org)
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index d61b921e19cf..645a9e63a4e5 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -91,7 +91,7 @@ class XPUPlatform(Platform):
             cache_config.block_size = 64
 
         # lazy import to avoid circular import
-        from vllm.config import CUDAGraphMode
+        from vllm.config import CompilationLevel, CUDAGraphMode
         compilation_config = vllm_config.compilation_config
         if compilation_config.cudagraph_mode is None or \
             compilation_config.cudagraph_mode.max_cudagraph_mode() \
@@ -100,6 +100,9 @@
                         "cudagraphs. Fallback to cudagraph_mode=NONE")
             compilation_config.cudagraph_mode = CUDAGraphMode.NONE
 
+        if vllm_config.lora_config is not None:
+            compilation_config.level = CompilationLevel.NO_COMPILATION
+
         # check and update parallel config
         parallel_config = vllm_config.parallel_config
         parallel_config.worker_cls = "vllm.v1.worker.xpu_worker.XPUWorker"
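
For context, a minimal self-contained sketch (illustrative only; `num_tokens`, `padded_len`, and the stand-in `sampler_indices_padded` tensor below are hypothetical, not vLLM's internal buffers) of the slicing that both the layers.py and punica_xpu.py hunks perform: on a platform that pads batch shapes, the stored sampler index buffer can be longer than the live batch, so it is cut down to `logits.size(0)` / `x.size(0)` before it is used for indexing.

# Illustrative sketch only -- not part of the patch above.
import torch

num_tokens = 3    # rows actually present in logits / x (hypothetical)
padded_len = 8    # hypothetical padded length of the wrapper's index buffer

sampler_indices_padded = torch.arange(padded_len)  # stand-in index buffer

# Mirrors `indices_padded[:logits.size(0)]` from the layers.py hunk ...
indices_for_logits = sampler_indices_padded[:num_tokens]

# ... and `torch.narrow(self._sampler_indices, 0, 0, x.size(0))` from punica_xpu.py.
indices_for_bgmv = torch.narrow(sampler_indices_padded, 0, 0, num_tokens)

assert torch.equal(indices_for_logits, indices_for_bgmv)  # both give tensor([0, 1, 2])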