From 63d8eabed05a632679f0e9e929d52afb08d2ec7f Mon Sep 17 00:00:00 2001
From: Alexey Kiryushin
Date: Tue, 1 Apr 2025 05:57:59 +0000
Subject: [PATCH] [Bugfix]: Fix is_embedding_layer condition in
 VocabParallelEmbedding (#15824)

Signed-off-by: alexwl
---
 vllm/model_executor/layers/vocab_parallel_embedding.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py
index f65dfc3cb3294..1eb0c8c2ef4e1 100644
--- a/vllm/model_executor/layers/vocab_parallel_embedding.py
+++ b/vllm/model_executor/layers/vocab_parallel_embedding.py
@@ -235,7 +235,7 @@ class VocabParallelEmbedding(torch.nn.Module):
         # If we are making an embedding layer, then our quantization linear
         # method must implement the embedding operation. If we are another
         # layer type like ParallelLMHead, this is not important.
-        is_embedding_layer = type(self.__class__) is VocabParallelEmbedding
+        is_embedding_layer = type(self) is VocabParallelEmbedding
         quant_method_implements_embedding = method_has_implemented_embedding(
             type(quant_method))
         if is_embedding_layer and not quant_method_implements_embedding:
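
Note on the bug: `self.__class__` already yields the instance's class, so
`type(self.__class__)` yields that class's metaclass (normally `type`),
never `VocabParallelEmbedding`. The original condition was therefore always
False and the embedding-method check was silently skipped. Below is a
minimal standalone sketch (stub classes, not the actual vLLM layers)
illustrating the difference; the class names mirror the patched file.

    # Stubs standing in for the real vLLM layers, for illustration only.
    class VocabParallelEmbedding:
        pass

    class ParallelLMHead(VocabParallelEmbedding):
        pass

    emb = VocabParallelEmbedding()
    head = ParallelLMHead()

    # Buggy check: type(self.__class__) is the metaclass, so the
    # comparison against VocabParallelEmbedding is always False.
    assert type(emb.__class__) is type
    assert (type(emb.__class__) is VocabParallelEmbedding) is False

    # Fixed check: True only for the exact embedding class. Subclasses
    # such as ParallelLMHead intentionally fail this test, matching the
    # comment in the patched code ("If we are another layer type like
    # ParallelLMHead, this is not important").
    assert type(emb) is VocabParallelEmbedding
    assert type(head) is not VocabParallelEmbedding

Using `type(self) is ...` rather than `isinstance(self, ...)` is deliberate
here: `isinstance` would also match `ParallelLMHead`, which must not require
that the quant method implements the embedding operation.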