From a3a5a47e48d3c6610686a489af2bd987062e74df Mon Sep 17 00:00:00 2001
From: Richard Zou
Date: Sat, 12 Jul 2025 02:06:04 -0400
Subject: [PATCH] [Bugfix] Fix torch.compile x LoRA for PyTorch 2.8 (#20823)

Signed-off-by: rzou
---
 vllm/lora/layers.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py
index 3d0c583175021..39b45027bd54d 100644
--- a/vllm/lora/layers.py
+++ b/vllm/lora/layers.py
@@ -240,17 +240,19 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         added_tokens_mask = torch.where(x > self.base_layer.org_vocab_size - 1,
                                         1, 0)
-        embeddings_indices = torch.narrow(
-            self.punica_wrapper._embeddings_indices, 1, 0, x.size(0))
-        indices = embeddings_indices[1]
+        # NB: Don't use torch.narrow here. torch.narrow triggers some
+        # Dynamic Shape specialization in torch.compile
+        num_tokens = x.shape[0]
+        indices_1 = self.punica_wrapper._embeddings_indices[1][:num_tokens]
+        indices_0 = self.punica_wrapper._embeddings_indices[0][:num_tokens]
+
         full_lora_a_embeddings = F.embedding(
-            x + indices,
+            x + indices_1,
             self.lora_a_stacked_2d,
         )
-        indices = embeddings_indices[0]
         full_output = self.base_layer.forward(x +
-                                              (indices * added_tokens_mask))
+                                              (indices_0 * added_tokens_mask))
         full_output_org = full_output
         if full_output.ndim == 3:
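
Illustrative sketch (not part of the patch): a minimal, standalone check of the equivalence the change relies on, i.e. that slicing a row of the index table with a plain prefix slice yields the same values as the torch.narrow call it replaces. The tensor name embeddings_indices and the shapes below are placeholders standing in for self.punica_wrapper._embeddings_indices; per the NB comment in the diff, the slicing form is preferred because torch.narrow triggers dynamic-shape specialization under torch.compile.

import torch

# Stand-in for the punica wrapper's (2, max_num_tokens) index table;
# the real table lives on self.punica_wrapper._embeddings_indices.
embeddings_indices = torch.arange(2 * 8).reshape(2, 8)
num_tokens = 5  # would be x.shape[0] inside the LoRA layer

# Old pattern: narrow along dim 1 to the first num_tokens, then pick a row.
narrowed = torch.narrow(embeddings_indices, 1, 0, num_tokens)
old_indices_1 = narrowed[1]

# New pattern (what the patch switches to): pick the row, then slice a prefix.
new_indices_1 = embeddings_indices[1][:num_tokens]

assert torch.equal(old_indices_1, new_indices_1)  # identical values either way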