mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-12 07:07:12 +08:00
[Quantization] add BNB for MixtralForCausalLM (#20893)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent
c488b928a7
commit
a99b9f7dee
@ -227,7 +227,12 @@ def get_model_architecture(
|
|||||||
# Special handling for quantized Mixtral.
|
# Special handling for quantized Mixtral.
|
||||||
# FIXME(woosuk): This is a temporary hack.
|
# FIXME(woosuk): This is a temporary hack.
|
||||||
mixtral_supported = [
|
mixtral_supported = [
|
||||||
"fp8", "compressed-tensors", "gptq_marlin", "awq_marlin", "quark"
|
"fp8",
|
||||||
|
"compressed-tensors",
|
||||||
|
"gptq_marlin",
|
||||||
|
"awq_marlin",
|
||||||
|
"quark",
|
||||||
|
"bitsandbytes",
|
||||||
]
|
]
|
||||||
|
|
||||||
vllm_supported_archs = ModelRegistry.get_supported_archs()
|
vllm_supported_archs = ModelRegistry.get_supported_archs()
|
||||||
|
|||||||
@ -45,12 +45,14 @@ from vllm.model_executor.layers.quantization.base_config import (
|
|||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||||
|
from vllm.model_executor.model_loader.weight_utils import (
|
||||||
|
default_weight_loader, maybe_remap_kv_scale_name)
|
||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
from . import mixtral
|
|
||||||
from .interfaces import SupportsLoRA, SupportsPP
|
from .interfaces import SupportsLoRA, SupportsPP
|
||||||
from .utils import AutoWeightsLoader, make_layers, maybe_prefix
|
from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_layers,
|
||||||
|
maybe_prefix)
|
||||||
|
|
||||||
|
|
||||||
class GraniteMoeMoE(nn.Module):
|
class GraniteMoeMoE(nn.Module):
|
||||||
@ -307,6 +309,103 @@ class GraniteMoeModel(nn.Module):
|
|||||||
hidden_states = self.norm(hidden_states)
|
hidden_states = self.norm(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
def _load_weights(self,
|
||||||
|
weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||||
|
"""
|
||||||
|
This function is copied from `MixtralModel.load_weights`, mainly to
|
||||||
|
decouple from mixtral, avoiding impact on support like BNB
|
||||||
|
quantization.
|
||||||
|
"""
|
||||||
|
stacked_params_mapping = [
|
||||||
|
# (param_name, shard_name, shard_id)
|
||||||
|
("qkv_proj", "q_proj", "q"),
|
||||||
|
("qkv_proj", "k_proj", "k"),
|
||||||
|
("qkv_proj", "v_proj", "v"),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||||
|
# (param_name, weight_name, expert_id, shard_id)
|
||||||
|
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
||||||
|
ckpt_gate_proj_name="w1",
|
||||||
|
ckpt_down_proj_name="w2",
|
||||||
|
ckpt_up_proj_name="w3",
|
||||||
|
num_experts=self.config.num_local_experts)
|
||||||
|
|
||||||
|
params_dict = dict(self.named_parameters())
|
||||||
|
loaded_params: set[str] = set()
|
||||||
|
for name, loaded_weight in weights:
|
||||||
|
if (self.quant_config is not None and
|
||||||
|
(scale_name := self.quant_config.get_cache_scale(name))):
|
||||||
|
# Loading kv cache quantization scales
|
||||||
|
param = params_dict[scale_name]
|
||||||
|
weight_loader = getattr(param, "weight_loader",
|
||||||
|
default_weight_loader)
|
||||||
|
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
|
||||||
|
loaded_weight[0])
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
loaded_params.add(scale_name)
|
||||||
|
continue
|
||||||
|
|
||||||
|
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||||
|
if weight_name not in name:
|
||||||
|
continue
|
||||||
|
name = name.replace(weight_name, param_name)
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if ((name.endswith(".bias") or name.endswith("_bias"))
|
||||||
|
and name not in params_dict):
|
||||||
|
continue
|
||||||
|
# Skip layers on other devices.
|
||||||
|
if is_pp_missing_parameter(name, self):
|
||||||
|
continue
|
||||||
|
if name.endswith("scale"):
|
||||||
|
# Remapping the name of FP8 kv-scale.
|
||||||
|
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||||
|
if name is None:
|
||||||
|
continue
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = param.weight_loader
|
||||||
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
for mapping in expert_params_mapping:
|
||||||
|
param_name, weight_name, expert_id, shard_id = mapping
|
||||||
|
if weight_name not in name:
|
||||||
|
continue
|
||||||
|
name = name.replace(weight_name, param_name)
|
||||||
|
# Skip layers on other devices.
|
||||||
|
if is_pp_missing_parameter(name, self):
|
||||||
|
continue
|
||||||
|
if ((name.endswith(".bias") or name.endswith("_bias"))
|
||||||
|
and name not in params_dict):
|
||||||
|
continue
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = param.weight_loader
|
||||||
|
weight_loader(param,
|
||||||
|
loaded_weight,
|
||||||
|
name,
|
||||||
|
shard_id=shard_id,
|
||||||
|
expert_id=expert_id)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if ((name.endswith(".bias") or name.endswith("_bias"))
|
||||||
|
and name not in params_dict):
|
||||||
|
continue
|
||||||
|
# Skip layers on other devices.
|
||||||
|
if is_pp_missing_parameter(name, self):
|
||||||
|
continue
|
||||||
|
# Remapping the name of FP8 kv-scale.
|
||||||
|
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||||
|
if name is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = getattr(param, "weight_loader",
|
||||||
|
default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
loaded_params.add(name)
|
||||||
|
return loaded_params
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[tuple[str,
|
def load_weights(self, weights: Iterable[tuple[str,
|
||||||
torch.Tensor]]) -> set[str]:
|
torch.Tensor]]) -> set[str]:
|
||||||
new_weights = {}
|
new_weights = {}
|
||||||
@ -339,7 +438,7 @@ class GraniteMoeModel(nn.Module):
|
|||||||
new_weights[gate_name] = p
|
new_weights[gate_name] = p
|
||||||
else:
|
else:
|
||||||
new_weights[n] = p
|
new_weights[n] = p
|
||||||
return mixtral.MixtralModel.load_weights(self, new_weights.items())
|
return self._load_weights(new_weights.items())
|
||||||
|
|
||||||
|
|
||||||
class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||||
|
|||||||
@ -27,8 +27,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
from . import mixtral
|
from .granitemoe import GraniteMoeAttention, GraniteMoeModel, GraniteMoeMoE
|
||||||
from .granitemoe import GraniteMoeAttention, GraniteMoeMoE
|
|
||||||
from .interfaces import SupportsLoRA, SupportsPP
|
from .interfaces import SupportsLoRA, SupportsPP
|
||||||
from .utils import AutoWeightsLoader, make_layers, maybe_prefix
|
from .utils import AutoWeightsLoader, make_layers, maybe_prefix
|
||||||
|
|
||||||
@ -242,7 +241,7 @@ class GraniteMoeSharedModel(nn.Module):
|
|||||||
new_weights[gate_name] = p
|
new_weights[gate_name] = p
|
||||||
else:
|
else:
|
||||||
new_weights[n] = p
|
new_weights[n] = p
|
||||||
return mixtral.MixtralModel.load_weights(self, new_weights.items())
|
return GraniteMoeModel._load_weights(self, new_weights.items())
|
||||||
|
|
||||||
|
|
||||||
class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||||
|
|||||||
@ -317,6 +317,15 @@ class MixtralModel(nn.Module):
|
|||||||
hidden_states, _ = self.norm(hidden_states, residual)
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
|
||||||
|
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||||
|
# (param_name, weight_name, expert_id, shard_id)
|
||||||
|
return FusedMoE.make_expert_params_mapping(
|
||||||
|
ckpt_gate_proj_name="w1",
|
||||||
|
ckpt_down_proj_name="w2",
|
||||||
|
ckpt_up_proj_name="w3",
|
||||||
|
num_experts=self.config.num_local_experts)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[tuple[str,
|
def load_weights(self, weights: Iterable[tuple[str,
|
||||||
torch.Tensor]]) -> set[str]:
|
torch.Tensor]]) -> set[str]:
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
@ -326,16 +335,9 @@ class MixtralModel(nn.Module):
|
|||||||
("qkv_proj", "v_proj", "v"),
|
("qkv_proj", "v_proj", "v"),
|
||||||
]
|
]
|
||||||
|
|
||||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
|
||||||
# (param_name, weight_name, expert_id, shard_id)
|
|
||||||
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
|
||||||
ckpt_gate_proj_name="w1",
|
|
||||||
ckpt_down_proj_name="w2",
|
|
||||||
ckpt_up_proj_name="w3",
|
|
||||||
num_experts=self.config.num_local_experts)
|
|
||||||
|
|
||||||
params_dict = dict(self.named_parameters())
|
params_dict = dict(self.named_parameters())
|
||||||
loaded_params: set[str] = set()
|
loaded_params: set[str] = set()
|
||||||
|
expert_params_mapping = self.get_expert_mapping()
|
||||||
for name, loaded_weight in weights:
|
for name, loaded_weight in weights:
|
||||||
if (self.quant_config is not None and
|
if (self.quant_config is not None and
|
||||||
(scale_name := self.quant_config.get_cache_scale(name))):
|
(scale_name := self.quant_config.get_cache_scale(name))):
|
||||||
@ -486,3 +488,6 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
|||||||
torch.Tensor]]) -> set[str]:
|
torch.Tensor]]) -> set[str]:
|
||||||
loader = AutoWeightsLoader(self)
|
loader = AutoWeightsLoader(self)
|
||||||
return loader.load_weights(weights)
|
return loader.load_weights(weights)
|
||||||
|
|
||||||
|
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
|
||||||
|
return self.model.get_expert_mapping()
|
||||||
|
|||||||
@ -352,6 +352,7 @@ class OlmoeModel(nn.Module):
|
|||||||
|
|
||||||
params_dict = dict(self.named_parameters())
|
params_dict = dict(self.named_parameters())
|
||||||
loaded_params: set[str] = set()
|
loaded_params: set[str] = set()
|
||||||
|
expert_params_mapping = self.get_expert_mapping()
|
||||||
for name, loaded_weight in weights:
|
for name, loaded_weight in weights:
|
||||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||||
# Skip non-stacked layers and experts (experts handled below).
|
# Skip non-stacked layers and experts (experts handled below).
|
||||||
@ -380,7 +381,7 @@ class OlmoeModel(nn.Module):
|
|||||||
weight_loader(param, loaded_weight, shard_id)
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
for mapping in self.get_expert_mapping():
|
for mapping in expert_params_mapping:
|
||||||
param_name, weight_name, expert_id, shard_id = mapping
|
param_name, weight_name, expert_id, shard_id = mapping
|
||||||
if weight_name not in name:
|
if weight_name not in name:
|
||||||
continue
|
continue
|
||||||
|
|||||||
@ -413,6 +413,7 @@ class Qwen2MoeModel(nn.Module):
|
|||||||
|
|
||||||
params_dict = dict(self.named_parameters())
|
params_dict = dict(self.named_parameters())
|
||||||
loaded_params: set[str] = set()
|
loaded_params: set[str] = set()
|
||||||
|
expert_params_mapping = self.get_expert_mapping()
|
||||||
for name, loaded_weight in weights:
|
for name, loaded_weight in weights:
|
||||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||||
# Skip non-stacked layers and experts (experts handled below).
|
# Skip non-stacked layers and experts (experts handled below).
|
||||||
@ -442,7 +443,7 @@ class Qwen2MoeModel(nn.Module):
|
|||||||
weight_loader(param, loaded_weight, shard_id)
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
for mapping in self.get_expert_mapping():
|
for mapping in expert_params_mapping:
|
||||||
param_name, weight_name, expert_id, shard_id = mapping
|
param_name, weight_name, expert_id, shard_id = mapping
|
||||||
if weight_name not in name:
|
if weight_name not in name:
|
||||||
continue
|
continue
|
||||||
|
|||||||
@ -400,11 +400,9 @@ class Qwen3MoeModel(nn.Module):
|
|||||||
".v_scale", "_v_scale", ".weight_scale",
|
".v_scale", "_v_scale", ".weight_scale",
|
||||||
"_weight_scale", ".input_scale", "_input_scale")
|
"_weight_scale", ".input_scale", "_input_scale")
|
||||||
|
|
||||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
|
||||||
# (param_name, weight_name, expert_id, shard_id)
|
|
||||||
expert_params_mapping = self.get_expert_mapping()
|
|
||||||
params_dict = dict(self.named_parameters())
|
params_dict = dict(self.named_parameters())
|
||||||
loaded_params: set[str] = set()
|
loaded_params: set[str] = set()
|
||||||
|
expert_params_mapping = self.get_expert_mapping()
|
||||||
for name, loaded_weight in weights:
|
for name, loaded_weight in weights:
|
||||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||||
# Skip non-stacked layers and experts (experts handled below).
|
# Skip non-stacked layers and experts (experts handled below).
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user