Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-15 03:45:01 +08:00)

commit 5c3f2628d5 (parent 7311f74468)

    [Quantization] Enable BNB support for more MoE models (#21370)

    Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
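For context, a minimal sketch of what this change enables, assuming a BitsAndBytes-compatible checkpoint for one of the affected MoE models; the model ID below is a placeholder, not taken from this commit, and older vLLM versions may also require load_format="bitsandbytes".

# Hedged sketch: loading an MoE checkpoint with in-flight BitsAndBytes quantization in vLLM.
from vllm import LLM, SamplingParams

llm = LLM(
    model="org/some-moe-model",        # placeholder HF ID or local path
    quantization="bitsandbytes",       # enable BNB quantization
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)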
@@ -54,8 +54,8 @@ from vllm.model_executor.model_loader.weight_utils import (
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

-from .interfaces import SupportsPP
-from .utils import (PPMissingLayer, is_pp_missing_parameter,
+from .interfaces import SupportsLoRA, SupportsPP
+from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)

@@ -327,6 +327,7 @@ class Dots1DecoderLayer(nn.Module):
         return hidden_states, residual


+@support_torch_compile
 class Dots1Model(nn.Module):

     fall_back_to_pt_during_load = False
@@ -404,68 +405,12 @@ class Dots1Model(nn.Module):
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states

-
-@support_torch_compile
-class Dots1ForCausalLM(nn.Module, SupportsPP):
-
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__()
-        config = vllm_config.model_config.hf_config
-        quant_config = vllm_config.quant_config
-        self.config = config
-        self.quant_config = quant_config
-        self.model = Dots1Model(vllm_config=vllm_config,
-                                prefix=maybe_prefix(prefix, "model"))
-        if get_pp_group().is_last_rank:
-            self.lm_head = ParallelLMHead(config.vocab_size,
-                                          config.hidden_size,
-                                          quant_config=quant_config)
-        else:
-            self.lm_head = PPMissingLayer()
-        self.logits_processor = LogitsProcessor(config.vocab_size)
-        self.make_empty_intermediate_tensors = (
-            self.model.make_empty_intermediate_tensors)
-
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(
-            input_ids,
-            positions,
-            intermediate_tensors,
-            inputs_embeds,
-        )
-        return hidden_states
-
-    def compute_logits(
-        self,
-        hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
-        return logits
-
-    def make_empty_intermediate_tensors(
-            self, batch_size: int, dtype: torch.dtype,
-            device: torch.device) -> IntermediateTensors:
-        return IntermediateTensors({
-            "hidden_states":
-            torch.zeros((batch_size, self.config.hidden_size),
-                        dtype=dtype,
-                        device=device),
-            "residual":
-            torch.zeros((batch_size, self.config.hidden_size),
-                        dtype=dtype,
-                        device=device),
-        })
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        return FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.n_routed_experts)

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
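For context (illustrative, not part of the diff): get_expert_mapping() returns entries of the form (param_name, weight_name, expert_id, shard_id), and load_weights() matches checkpoint tensor names against the weight_name substring. A minimal, self-contained sketch with hypothetical entries:

# Illustrative only: how an expert-params mapping is typically consumed.
# The entries below are hypothetical, not copied from FusedMoE.make_expert_params_mapping.
expert_params_mapping = [
    ("experts.w13_weight", "experts.0.gate_proj.", 0, "w1"),
    ("experts.w2_weight", "experts.0.down_proj.", 0, "w2"),
    ("experts.w13_weight", "experts.0.up_proj.", 0, "w3"),
]

def match_expert_weight(ckpt_name: str):
    """Return (param_name, expert_id, shard_id) for the first matching entry, else None."""
    for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
        if weight_name in ckpt_name:
            return param_name, expert_id, shard_id
    return None

print(match_expert_weight("model.layers.3.mlp.experts.0.down_proj.weight"))
# -> ('experts.w2_weight', 0, 'w2')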
@@ -477,14 +422,9 @@ class Dots1ForCausalLM(nn.Module, SupportsPP):
             ("gate_up_proj", "up_proj", 1),
         ]

-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="gate_proj",
-            ckpt_down_proj_name="down_proj",
-            ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts)
-
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
+        expert_params_mapping = self.get_expert_mapping()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -534,3 +474,71 @@ class Dots1ForCausalLM(nn.Module, SupportsPP):
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)
         return loaded_params
+
+
+class Dots1ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
+
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        self.config = config
+        self.quant_config = quant_config
+        self.model = Dots1Model(vllm_config=vllm_config,
+                                prefix=maybe_prefix(prefix, "model"))
+        if get_pp_group().is_last_rank:
+            self.lm_head = ParallelLMHead(config.vocab_size,
+                                          config.hidden_size,
+                                          quant_config=quant_config)
+        else:
+            self.lm_head = PPMissingLayer()
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            intermediate_tensors,
+            inputs_embeds,
+        )
+        return hidden_states
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights)
+
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        return self.model.get_expert_mapping()
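For context (illustrative): the packed_modules_mapping above tells the LoRA and BitsAndBytes machinery which per-projection checkpoint weights are fused into each packed vLLM parameter. A small, self-contained sketch of how such a mapping expands a fused module name; the helper below is hypothetical, not vLLM code:

# Illustrative only: expanding a fused module name into its source projections.
packed_modules_mapping = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}

def expand_fused(name: str) -> list[str]:
    """Map a fused vLLM parameter name back to the checkpoint names it packs."""
    for fused, parts in packed_modules_mapping.items():
        if fused in name:
            return [name.replace(fused, p) for p in parts]
    return [name]

print(expand_fused("model.layers.0.self_attn.qkv_proj"))
# -> ['model.layers.0.self_attn.q_proj', 'model.layers.0.self_attn.k_proj', 'model.layers.0.self_attn.v_proj']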
@@ -53,7 +53,7 @@ from vllm.model_executor.model_loader.weight_utils import (
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

-from .interfaces import SupportsPP
+from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -461,6 +461,15 @@ class Glm4MoeModel(nn.Module):
                         device=device),
         })

+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        return FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.n_routed_experts)
+
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
@@ -472,16 +481,9 @@ class Glm4MoeModel(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]

-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="gate_proj",
-            ckpt_down_proj_name="down_proj",
-            ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts)
-
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
+        expert_params_mapping = self.get_expert_mapping()
         for name, loaded_weight in weights:
             spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
             if spec_layer is not None:
@@ -570,7 +572,7 @@ class Glm4MoeModel(nn.Module):
         return loaded_params


-class Glm4MoeForCausalLM(nn.Module, SupportsPP):
+class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -677,6 +679,9 @@ class Glm4MoeForCausalLM(nn.Module, SupportsPP):
         loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)

+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        return self.model.get_expert_mapping()
+

 def get_spec_layer_idx_from_weight_name(config: PretrainedConfig,
                                         weight_name: str) -> Optional[int]:
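With SupportsLoRA now declared on both model classes, a hedged usage sketch of serving a LoRA adapter on top of one of them; the base model ID and adapter path are placeholders, and the exact flags may differ by vLLM version:

# Hedged sketch: pairing a (possibly BNB-quantized) MoE base model with a LoRA adapter.
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(
    model="org/some-moe-model",   # placeholder base checkpoint
    enable_lora=True,             # allowed once the model class declares SupportsLoRA
)
out = llm.generate(
    ["Write a haiku about quantization."],
    SamplingParams(max_tokens=48),
    lora_request=LoRARequest("my_adapter", 1, "/path/to/lora_adapter"),  # placeholder adapter
)
print(out[0].outputs[0].text)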