[Quantization] add BNB for MixtralForCausalLM (#20893)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Author: Jee Jee Li, 2025-07-14 15:34:34 +08:00 (committed by GitHub)
parent c488b928a7
commit a99b9f7dee
7 changed files with 128 additions and 20 deletions
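
With this change, a Mixtral checkpoint can be loaded with in-flight BitsAndBytes quantization instead of falling back to the unfused quantized-Mixtral path. A minimal usage sketch; the model name and sampling settings are illustrative and not part of this commit, and the bitsandbytes package must be installed:

    # Illustrative usage only; requires bitsandbytes to be installed.
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="mistralai/Mixtral-8x7B-Instruct-v0.1",  # example checkpoint
        quantization="bitsandbytes",                   # in-flight BNB quantization
        dtype="bfloat16",
    )
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(temperature=0.0, max_tokens=16))
    print(outputs[0].outputs[0].text)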

vllm/model_executor/model_loader/utils.py

@@ -227,7 +227,12 @@ def get_model_architecture(
     # Special handling for quantized Mixtral.
     # FIXME(woosuk): This is a temporary hack.
     mixtral_supported = [
-        "fp8", "compressed-tensors", "gptq_marlin", "awq_marlin", "quark"
+        "fp8",
+        "compressed-tensors",
+        "gptq_marlin",
+        "awq_marlin",
+        "quark",
+        "bitsandbytes",
     ]
     vllm_supported_archs = ModelRegistry.get_supported_archs()
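
For context, the list above is assumed to gate the fallback in get_model_architecture: when a Mixtral config requests a quantization method that the fused MixtralForCausalLM path supports, the regular architecture is kept; otherwise vLLM substitutes the unfused quantized implementation. A standalone sketch of that selection logic; the helper name and the "QuantMixtralForCausalLM" fallback string are assumptions, since that code is not shown in this hunk:

    # Sketch of the assumed selection logic around mixtral_supported.
    def pick_mixtral_arch(quantization: str | None, architectures: list[str],
                          mixtral_supported: list[str]) -> list[str]:
        if (quantization is not None and quantization not in mixtral_supported
                and "MixtralForCausalLM" in architectures):
            return ["QuantMixtralForCausalLM"]  # assumed fallback architecture
        return architectures

    # After this commit, "bitsandbytes" stays on the fused MixtralForCausalLM path.
    assert pick_mixtral_arch("bitsandbytes", ["MixtralForCausalLM"],
                             ["fp8", "quark", "bitsandbytes"]) == ["MixtralForCausalLM"]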

vllm/model_executor/models/granitemoe.py

@@ -45,12 +45,14 @@ from vllm.model_executor.layers.quantization.base_config import (
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
-from . import mixtral
 from .interfaces import SupportsLoRA, SupportsPP
-from .utils import AutoWeightsLoader, make_layers, maybe_prefix
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_layers,
+                    maybe_prefix)


 class GraniteMoeMoE(nn.Module):
@@ -307,6 +309,103 @@ class GraniteMoeModel(nn.Module):
         hidden_states = self.norm(hidden_states)
         return hidden_states

+    def _load_weights(self,
+                      weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        """
+        This function is copied from `MixtralModel.load_weights`, mainly to
+        decouple from mixtral, avoiding impact on support like BNB
+        quantization.
+        """
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="w1",
+            ckpt_down_proj_name="w2",
+            ckpt_up_proj_name="w3",
+            num_experts=self.config.num_local_experts)
+
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+        for name, loaded_weight in weights:
+            if (self.quant_config is not None and
+                (scale_name := self.quant_config.get_cache_scale(name))):
+                # Loading kv cache quantization scales
+                param = params_dict[scale_name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                                 loaded_weight[0])
+                weight_loader(param, loaded_weight)
+                loaded_params.add(scale_name)
+                continue
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if ((name.endswith(".bias") or name.endswith("_bias"))
+                        and name not in params_dict):
+                    continue
+                # Skip layers on other devices.
+                if is_pp_missing_parameter(name, self):
+                    continue
+                if name.endswith("scale"):
+                    # Remapping the name of FP8 kv-scale.
+                    name = maybe_remap_kv_scale_name(name, params_dict)
+                    if name is None:
+                        continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    # Skip layers on other devices.
+                    if is_pp_missing_parameter(name, self):
+                        continue
+                    if ((name.endswith(".bias") or name.endswith("_bias"))
+                            and name not in params_dict):
+                        continue
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param,
+                                  loaded_weight,
+                                  name,
+                                  shard_id=shard_id,
+                                  expert_id=expert_id)
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if ((name.endswith(".bias") or name.endswith("_bias"))
+                            and name not in params_dict):
+                        continue
+                    # Skip layers on other devices.
+                    if is_pp_missing_parameter(name, self):
+                        continue
+                    # Remapping the name of FP8 kv-scale.
+                    name = maybe_remap_kv_scale_name(name, params_dict)
+                    if name is None:
+                        continue
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         new_weights = {}
@@ -339,7 +438,7 @@ class GraniteMoeModel(nn.Module):
                 new_weights[gate_name] = p
             else:
                 new_weights[n] = p
-        return mixtral.MixtralModel.load_weights(self, new_weights.items())
+        return self._load_weights(new_weights.items())


 class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
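
The copied _load_weights relies on Python's for/else to try three cases in order for every checkpoint tensor: stacked attention shards, fused-MoE expert shards, then plain parameters. A minimal standalone sketch of that dispatch shape; the mapping entries and tensor names are illustrative placeholders, not vLLM data:

    # Sketch of the for/else dispatch used by _load_weights above.
    stacked = [("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k")]
    experts = [("experts.w13_weight", "experts.0.w1.weight", 0, "w1")]

    def classify(name: str) -> str:
        result = None
        for param_name, weight_name, shard_id in stacked:
            if weight_name not in name:
                continue
            result = f"stacked -> {name.replace(weight_name, param_name)} ({shard_id})"
            break
        else:  # no break: not a stacked attention shard
            for param_name, weight_name, expert_id, shard_id in experts:
                if weight_name not in name:
                    continue
                result = f"expert {expert_id} -> {name.replace(weight_name, param_name)}"
                break
            else:  # no break: an ordinary, unfused parameter
                result = f"plain -> {name}"
        return result

    print(classify("model.layers.0.self_attn.q_proj.weight"))
    print(classify("model.layers.1.block_sparse_moe.experts.0.w1.weight"))
    print(classify("model.norm.weight"))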

vllm/model_executor/models/granitemoe_shared.py

@@ -27,8 +27,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
-from . import mixtral
-from .granitemoe import GraniteMoeAttention, GraniteMoeMoE
+from .granitemoe import GraniteMoeAttention, GraniteMoeModel, GraniteMoeMoE
 from .interfaces import SupportsLoRA, SupportsPP
 from .utils import AutoWeightsLoader, make_layers, maybe_prefix
@@ -242,7 +241,7 @@ class GraniteMoeSharedModel(nn.Module):
                 new_weights[gate_name] = p
             else:
                 new_weights[n] = p
-        return mixtral.MixtralModel.load_weights(self, new_weights.items())
+        return GraniteMoeModel._load_weights(self, new_weights.items())


 class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP):

vllm/model_executor/models/mixtral.py

@@ -317,6 +317,15 @@ class MixtralModel(nn.Module):
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states

+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        return FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="w1",
+            ckpt_down_proj_name="w2",
+            ckpt_up_proj_name="w3",
+            num_experts=self.config.num_local_experts)
+
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
@@ -326,16 +335,9 @@ class MixtralModel(nn.Module):
             ("qkv_proj", "v_proj", "v"),
         ]

-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="w1",
-            ckpt_down_proj_name="w2",
-            ckpt_up_proj_name="w3",
-            num_experts=self.config.num_local_experts)
-
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
+        expert_params_mapping = self.get_expert_mapping()
         for name, loaded_weight in weights:
             if (self.quant_config is not None and
                 (scale_name := self.quant_config.get_cache_scale(name))):
@@ -486,3 +488,6 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
                                                    torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)
+
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        return self.model.get_expert_mapping()
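
MixtralForCausalLM now forwards get_expert_mapping() from the inner MixtralModel, so code outside the model class (for example a quantized weight loader) can query the expert layout without reaching into model.model. A hypothetical consumer sketch, not the actual vLLM BitsAndBytes loader; the stand-in class, the helper, and the mapping values are made up for illustration:

    # Hypothetical consumer of the new public hook; not vLLM loader code.
    class FakeMixtral:
        """Stand-in for MixtralForCausalLM returning an illustrative mapping."""

        def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
            # (param_name, weight_name, expert_id, shard_id) -- illustrative values
            return [("experts.w13_weight", "experts.0.w1.weight", 0, "w1"),
                    ("experts.w2_weight", "experts.0.w2.weight", 0, "w2")]

    def index_expert_fragments(model) -> dict[str, tuple[int, str]]:
        """Map each per-expert checkpoint name fragment to (expert_id, shard_id)."""
        return {
            weight_name: (expert_id, shard_id)
            for _, weight_name, expert_id, shard_id in model.get_expert_mapping()
        }

    print(index_expert_fragments(FakeMixtral()))
    # {'experts.0.w1.weight': (0, 'w1'), 'experts.0.w2.weight': (0, 'w2')}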

vllm/model_executor/models/olmoe.py

@@ -352,6 +352,7 @@ class OlmoeModel(nn.Module):
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
+        expert_params_mapping = self.get_expert_mapping()
         for name, loaded_weight in weights:
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 # Skip non-stacked layers and experts (experts handled below).
@@ -380,7 +381,7 @@ class OlmoeModel(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                for mapping in self.get_expert_mapping():
+                for mapping in expert_params_mapping:
                     param_name, weight_name, expert_id, shard_id = mapping
                     if weight_name not in name:
                         continue
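
The OlmoeModel change (and the matching edits to Qwen2MoeModel and Qwen3MoeModel below) simply hoists get_expert_mapping() out of the per-weight loop, so the mapping list is built once per load_weights call rather than once per checkpoint tensor. A generic sketch of the hoist; make_mapping and the weight names are stand-ins, not FusedMoE.make_expert_params_mapping or real checkpoint data:

    # Generic hoisting sketch with stand-in data.
    def make_mapping(num_experts: int) -> list[tuple[str, str, int, str]]:
        return [(f"experts.w{w}_", f"experts.{e}.w{w}.", e, f"w{w}")
                for e in range(num_experts) for w in (1, 2, 3)]

    weight_names = [f"layers.{i}.experts.{i % 8}.w1.weight" for i in range(1000)]

    # Before: the mapping list is rebuilt for every checkpoint tensor.
    for name in weight_names:
        for mapping in make_mapping(8):
            pass

    # After: build it once and reuse it inside the loop.
    expert_params_mapping = make_mapping(8)
    for name in weight_names:
        for mapping in expert_params_mapping:
            pass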

vllm/model_executor/models/qwen2_moe.py

@@ -413,6 +413,7 @@ class Qwen2MoeModel(nn.Module):
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
+        expert_params_mapping = self.get_expert_mapping()
         for name, loaded_weight in weights:
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 # Skip non-stacked layers and experts (experts handled below).
@@ -442,7 +443,7 @@ class Qwen2MoeModel(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                for mapping in self.get_expert_mapping():
+                for mapping in expert_params_mapping:
                     param_name, weight_name, expert_id, shard_id = mapping
                     if weight_name not in name:
                         continue

vllm/model_executor/models/qwen3_moe.py

@@ -400,11 +400,9 @@ class Qwen3MoeModel(nn.Module):
                           ".v_scale", "_v_scale", ".weight_scale",
                           "_weight_scale", ".input_scale", "_input_scale")
-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = self.get_expert_mapping()
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
+        expert_params_mapping = self.get_expert_mapping()
         for name, loaded_weight in weights:
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 # Skip non-stacked layers and experts (experts handled below).