[Bugfix] Fix spec decode memory regression after #28549 (#28819)

Signed-off-by: zhewenli <zhewenli@meta.com>
This commit is contained in:
Zhewen Li 2025-11-20 03:05:50 -08:00 committed by GitHub
parent 371b1d4c61
commit 93c8672ceb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 5 additions and 19 deletions

View File

@ -8,7 +8,6 @@ import torch.nn as nn
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@ -172,10 +171,6 @@ class DeepseekV2Model(nn.Module):
)
break
else:
# if PP disabled then draft will share embed with target
if get_pp_group().world_size == 1 and "embed_tokens." in name:
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue

View File

@ -23,7 +23,6 @@ import torch.nn as nn
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@ -127,17 +126,11 @@ class LlamaModel(nn.Module):
weight_loader(param, loaded_weight, shard_id)
break
else:
# if PP disabled then draft will share embed with target
if get_pp_group().world_size == 1 and "embed_tokens." in name:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
for name in params_dict:
# if PP disabled then draft will share embed with target
if get_pp_group().world_size == 1 and "embed_tokens." in name:
continue
assert name in loaded_params, f"{name} is not loaded!"
return loaded_params

View File

@ -9,7 +9,6 @@ from transformers import LlamaConfig
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@ -155,10 +154,6 @@ class LlamaModel(nn.Module):
weight_loader(param, loaded_weight, shard_id)
break
else:
# if PP disabled then draft will share embed with target
if get_pp_group().world_size == 1 and "embed_tokens." in name:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)

View File

@ -1028,8 +1028,11 @@ class EagleProposer:
elif (
isinstance(target_embed_tokens.weight, torch.Tensor)
and isinstance(self.model.model.embed_tokens.weight, torch.Tensor)
and torch.equal(
target_embed_tokens.weight, self.model.model.embed_tokens.weight
and torch.allclose(
target_embed_tokens.weight.cpu(),
self.model.model.embed_tokens.weight.cpu(),
rtol=1e-5,
atol=1e-7,
)
):
share_embeddings = True