Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-24 20:55:01 +08:00)
[Bugfix] Fix torch.compile x LoRA for PyTorch 2.8 (#20823)
Signed-off-by: rzou <zou3519@gmail.com>
commit a3a5a47e48 (parent fb25e95688)
@@ -240,17 +240,19 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         added_tokens_mask = torch.where(x > self.base_layer.org_vocab_size - 1,
                                         1, 0)
-        embeddings_indices = torch.narrow(
-            self.punica_wrapper._embeddings_indices, 1, 0, x.size(0))
-
-        indices = embeddings_indices[1]
+        # NB: Don't use torch.narrow here. torch.narrow triggers some
+        # Dynamic Shape specialization in torch.compile
+        num_tokens = x.shape[0]
+        indices_1 = self.punica_wrapper._embeddings_indices[1][:num_tokens]
+        indices_0 = self.punica_wrapper._embeddings_indices[0][:num_tokens]
+
         full_lora_a_embeddings = F.embedding(
-            x + indices,
+            x + indices_1,
             self.lora_a_stacked_2d,
         )
-        indices = embeddings_indices[0]
         full_output = self.base_layer.forward(x +
-                                              (indices * added_tokens_mask))
+                                              (indices_0 * added_tokens_mask))

         full_output_org = full_output
         if full_output.ndim == 3:
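
For readers who want the pattern in isolation: below is a small, self-contained sketch (not part of the commit) contrasting the torch.narrow form the fix removes with the plain-slicing form it adds. The table tensor and function names are hypothetical stand-ins for punica_wrapper._embeddings_indices; only the narrow-vs-slice contrast mirrors the change above, and the claim that narrow can specialize comes from the commit's own comment.

import torch

# Hypothetical stand-in for punica_wrapper._embeddings_indices, shape (2, max_tokens).
table = torch.zeros(2, 4096, dtype=torch.long)

def narrow_version(x: torch.Tensor) -> torch.Tensor:
    # Pre-fix form: torch.narrow with a length taken from x.size(0); per the
    # commit comment, this can trigger dynamic-shape specialization in torch.compile.
    indices = torch.narrow(table, 1, 0, x.size(0))[1]
    return x + indices

def slice_version(x: torch.Tensor) -> torch.Tensor:
    # Post-fix form: plain slicing on the token count keeps the dimension dynamic.
    num_tokens = x.shape[0]
    return x + table[1][:num_tokens]

compiled = torch.compile(slice_version, dynamic=True)
for n in (8, 16, 32):  # different token counts should reuse one compiled graph
    out = compiled(torch.arange(n))
    assert out.shape == (n,)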