diff --git a/vllm/model_executor/models/deepseek_eagle.py b/vllm/model_executor/models/deepseek_eagle.py
index 3fb04c3b70dd1..4d7a37292cb02 100644
--- a/vllm/model_executor/models/deepseek_eagle.py
+++ b/vllm/model_executor/models/deepseek_eagle.py
@@ -8,7 +8,6 @@ import torch.nn as nn
 
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
-from vllm.distributed.parallel_state import get_pp_group
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -172,10 +171,6 @@ class DeepseekV2Model(nn.Module):
                     )
                     break
                 else:
-                    # if PP disabled then draft will share embed with target
-                    if get_pp_group().world_size == 1 and "embed_tokens." in name:
-                        continue
-
                     # Skip loading extra bias for GPTQ models.
                     if name.endswith(".bias") and name not in params_dict:
                         continue
diff --git a/vllm/model_executor/models/llama4_eagle.py b/vllm/model_executor/models/llama4_eagle.py
index 660c8f1bb5226..0146b30579287 100644
--- a/vllm/model_executor/models/llama4_eagle.py
+++ b/vllm/model_executor/models/llama4_eagle.py
@@ -23,7 +23,6 @@ import torch.nn as nn
 
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
-from vllm.distributed.parallel_state import get_pp_group
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -127,17 +126,11 @@ class LlamaModel(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                # if PP disabled then draft will share embed with target
-                if get_pp_group().world_size == 1 and "embed_tokens." in name:
-                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)
 
         for name in params_dict:
-            # if PP disabled then draft will share embed with target
-            if get_pp_group().world_size == 1 and "embed_tokens." in name:
-                continue
             assert name in loaded_params, f"{name} is not loaded!"
         return loaded_params
diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py
index 90ab5c50361b6..05cb456e7776e 100644
--- a/vllm/model_executor/models/llama_eagle.py
+++ b/vllm/model_executor/models/llama_eagle.py
@@ -9,7 +9,6 @@ from transformers import LlamaConfig
 
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
-from vllm.distributed.parallel_state import get_pp_group
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -155,10 +154,6 @@ class LlamaModel(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                # if PP disabled then draft will share embed with target
-                if get_pp_group().world_size == 1 and "embed_tokens." in name:
-                    continue
-
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 406bb696bd4cf..ba37bc81607fe 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -1028,8 +1028,11 @@ class EagleProposer:
         elif (
             isinstance(target_embed_tokens.weight, torch.Tensor)
             and isinstance(self.model.model.embed_tokens.weight, torch.Tensor)
-            and torch.equal(
-                target_embed_tokens.weight, self.model.model.embed_tokens.weight
+            and torch.allclose(
+                target_embed_tokens.weight.cpu(),
+                self.model.model.embed_tokens.weight.cpu(),
+                rtol=1e-5,
+                atol=1e-7,
             )
         ):
             share_embeddings = True