mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-17 12:35:01 +08:00
[Bugfix] Allow skipping MoE in NVFP4 (fix for MTP) (#25987)
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
This commit is contained in:
parent
f23b4c04fd
commit
2161efe978
@ -1194,6 +1194,8 @@ class FusedMoE(CustomOp):
|
|||||||
if quant_config is None
|
if quant_config is None
|
||||||
else quant_config.get_quant_method(self, prefix)
|
else quant_config.get_quant_method(self, prefix)
|
||||||
)
|
)
|
||||||
|
if quant_method is None:
|
||||||
|
quant_method = UnquantizedFusedMoEMethod(moe)
|
||||||
|
|
||||||
assert quant_method is not None
|
assert quant_method is not None
|
||||||
assert isinstance(quant_method, FusedMoEMethodBase)
|
assert isinstance(quant_method, FusedMoEMethodBase)
|
||||||
|
|||||||
@ -884,8 +884,9 @@ class ModelOptNvFp4Config(QuantizationConfig):
|
|||||||
) -> Optional["QuantizeMethodBase"]:
|
) -> Optional["QuantizeMethodBase"]:
|
||||||
from vllm.attention.layer import Attention # Avoid circular import
|
from vllm.attention.layer import Attention # Avoid circular import
|
||||||
|
|
||||||
|
skip_layer = self.is_layer_excluded(prefix)
|
||||||
if isinstance(layer, LinearBase):
|
if isinstance(layer, LinearBase):
|
||||||
if self.is_layer_excluded(prefix):
|
if skip_layer:
|
||||||
return UnquantizedLinearMethod()
|
return UnquantizedLinearMethod()
|
||||||
# Check if this is a vision model layer that should not be quantized
|
# Check if this is a vision model layer that should not be quantized
|
||||||
if "vision_tower" in prefix or "vision_model" in prefix:
|
if "vision_tower" in prefix or "vision_model" in prefix:
|
||||||
@ -894,6 +895,8 @@ class ModelOptNvFp4Config(QuantizationConfig):
|
|||||||
elif isinstance(layer, Attention):
|
elif isinstance(layer, Attention):
|
||||||
return ModelOptFp8KVCacheMethod(self)
|
return ModelOptFp8KVCacheMethod(self)
|
||||||
elif isinstance(layer, FusedMoE):
|
elif isinstance(layer, FusedMoE):
|
||||||
|
if skip_layer:
|
||||||
|
return None
|
||||||
return ModelOptNvFp4FusedMoE(self, layer.moe_config, layer)
|
return ModelOptNvFp4FusedMoE(self, layer.moe_config, layer)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@ -55,6 +55,7 @@ class DeepseekV2Model(nn.Module):
|
|||||||
DeepseekV2DecoderLayer(
|
DeepseekV2DecoderLayer(
|
||||||
vllm_config,
|
vllm_config,
|
||||||
prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
|
prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
|
||||||
|
config=self.config,
|
||||||
)
|
)
|
||||||
for i in range(self.config.num_hidden_layers)
|
for i in range(self.config.num_hidden_layers)
|
||||||
]
|
]
|
||||||
|
|||||||
@ -48,7 +48,8 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
|
|||||||
def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
|
def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
config = vllm_config.model_config.hf_config
|
config = vllm_config.speculative_config.draft_model_config.hf_config
|
||||||
|
self.config = config
|
||||||
quant_config = vllm_config.quant_config
|
quant_config = vllm_config.quant_config
|
||||||
|
|
||||||
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
@ -66,11 +67,15 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
topk_indices_buffer = None
|
topk_indices_buffer = None
|
||||||
|
|
||||||
self.shared_head = SharedHead(
|
self.shared_head = SharedHead(
|
||||||
config=config, prefix=prefix, quant_config=quant_config
|
config=config, prefix=prefix, quant_config=quant_config
|
||||||
)
|
)
|
||||||
self.mtp_block = DeepseekV2DecoderLayer(
|
self.mtp_block = DeepseekV2DecoderLayer(
|
||||||
vllm_config, prefix, topk_indices_buffer
|
vllm_config,
|
||||||
|
prefix,
|
||||||
|
config=self.config,
|
||||||
|
topk_indices_buffer=topk_indices_buffer,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
|||||||
@ -1055,10 +1055,12 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
self,
|
self,
|
||||||
vllm_config: VllmConfig,
|
vllm_config: VllmConfig,
|
||||||
prefix: str,
|
prefix: str,
|
||||||
|
config: Optional[DeepseekV2Config] = None,
|
||||||
topk_indices_buffer: Optional[torch.Tensor] = None,
|
topk_indices_buffer: Optional[torch.Tensor] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
if config is None:
|
||||||
config = vllm_config.model_config.hf_config
|
config = vllm_config.model_config.hf_config
|
||||||
model_config = vllm_config.model_config
|
model_config = vllm_config.model_config
|
||||||
cache_config = vllm_config.cache_config
|
cache_config = vllm_config.cache_config
|
||||||
@ -1200,7 +1202,7 @@ class DeepseekV2Model(nn.Module):
|
|||||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||||
config.num_hidden_layers,
|
config.num_hidden_layers,
|
||||||
lambda prefix: DeepseekV2DecoderLayer(
|
lambda prefix: DeepseekV2DecoderLayer(
|
||||||
vllm_config, prefix, topk_indices_buffer
|
vllm_config, prefix, topk_indices_buffer=topk_indices_buffer
|
||||||
),
|
),
|
||||||
prefix=f"{prefix}.layers",
|
prefix=f"{prefix}.layers",
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user