[Bugfix] Fix torch.compile x LoRA for PyTorch 2.8 (#20823)

Signed-off-by: rzou <zou3519@gmail.com>
Richard Zou 2025-07-12 02:06:04 -04:00 committed by GitHub
parent fb25e95688
commit a3a5a47e48


@@ -240,17 +240,19 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         added_tokens_mask = torch.where(x > self.base_layer.org_vocab_size - 1,
                                         1, 0)
-        embeddings_indices = torch.narrow(
-            self.punica_wrapper._embeddings_indices, 1, 0, x.size(0))
-        indices = embeddings_indices[1]
+        # NB: Don't use torch.narrow here. torch.narrow triggers some
+        # Dynamic Shape specialization in torch.compile
+        num_tokens = x.shape[0]
+        indices_1 = self.punica_wrapper._embeddings_indices[1][:num_tokens]
+        indices_0 = self.punica_wrapper._embeddings_indices[0][:num_tokens]
         full_lora_a_embeddings = F.embedding(
-            x + indices,
+            x + indices_1,
             self.lora_a_stacked_2d,
         )
-        indices = embeddings_indices[0]
         full_output = self.base_layer.forward(x +
-                                              (indices * added_tokens_mask))
+                                              (indices_0 * added_tokens_mask))
         full_output_org = full_output
         if full_output.ndim == 3:
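
For context, a minimal standalone sketch (not part of the diff above) of why the change swaps torch.narrow for plain slicing: under torch.compile, narrowing to a length of x.size(0) can specialize the compiled graph on the concrete batch size, while a [:num_tokens] slice keeps that dimension symbolic. The buffer name and shapes below are illustrative assumptions, not vLLM internals.

    import torch

    # Illustrative stand-in for punica_wrapper._embeddings_indices:
    # a (2, max_num_tokens) long tensor of per-token index offsets.
    embeddings_indices = torch.zeros(2, 1024, dtype=torch.long)

    @torch.compile(dynamic=True)
    def gather_offsets(x: torch.Tensor) -> torch.Tensor:
        # Plain slicing keeps the token count symbolic under torch.compile;
        # torch.narrow(embeddings_indices, 1, 0, x.size(0)) may instead
        # specialize on the concrete batch size and trigger recompiles.
        num_tokens = x.shape[0]
        return x + embeddings_indices[1][:num_tokens]

    # The same compiled function serves different batch sizes.
    print(gather_offsets(torch.zeros(8, dtype=torch.long)).shape)   # torch.Size([8])
    print(gather_offsets(torch.zeros(16, dtype=torch.long)).shape)  # torch.Size([16])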