[Model] Support quantization of PixtralHFTransformer for PixtralHF (#9921)
Signed-off-by: mgoin <michael@neuralmagic.com>
This commit is contained in:
parent 731aec5be7
commit a53046b16f
vllm/model_executor/layers/activation.py

@@ -299,3 +299,33 @@ def get_act_fn(
         return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
                                 params_dtype)
     return act_fn
+
+
+_ACTIVATION_AND_MUL_REGISTRY = LazyDict({
+    "gelu": lambda: GeluAndMul(),
+    "silu": lambda: SiluAndMul(),
+})
+
+
+def get_act_and_mul_fn(
+    act_fn_name: str,
+    quant_config: Optional[QuantizationConfig] = None,
+    intermediate_size: Optional[int] = None,
+    input_is_parallel: bool = True,
+    params_dtype: Optional[torch.dtype] = None,
+) -> nn.Module:
+    """Get an activation-and-mul (i.e. SiluAndMul) function by name."""
+    act_fn_name = act_fn_name.lower()
+    if act_fn_name not in _ACTIVATION_AND_MUL_REGISTRY:
+        raise ValueError(
+            f"Activation function {act_fn_name!r} is not supported.")
+
+    act_fn = _ACTIVATION_AND_MUL_REGISTRY[act_fn_name]
+    if (quant_config is not None
+            and act_fn_name in quant_config.get_scaled_act_names()):
+        if intermediate_size is None:
+            raise ValueError("intermediate_size must be specified for scaled "
+                             "activation functions.")
+        return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
+                                params_dtype)
+    return act_fn
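For reference, the act-and-mul contract that get_act_and_mul_fn("silu") resolves to: the module takes the gate and up projections concatenated on the last dimension and returns act(gate) * up, halving that dimension. A standalone sketch of that behavior (the class below is an illustrative stand-in, not vLLM's SiluAndMul):

import torch
import torch.nn.functional as F


class SiluAndMulSketch(torch.nn.Module):
    """Stand-in mirroring the silu act-and-mul contract: the input is
    [gate | up] on the last dim, the output is silu(gate) * up."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]


gate = torch.randn(2, 16)
up = torch.randn(2, 16)
fused = SiluAndMulSketch()(torch.cat([gate, up], dim=-1))
assert torch.allclose(fused, F.silu(gate) * up)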
vllm/model_executor/models/pixtral.py

@@ -19,8 +19,11 @@ from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, ModelConfig, MultiModalConfig
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                          InputContext, token_inputs)
-from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.activation import get_act_and_mul_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
+                                               QKVParallelLinear,
+                                               RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
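With quant_config now threaded into these parallel-linear layers, the HF-format Pixtral vision tower can run through vLLM's quantized linear path. A hedged usage sketch (the model name and quantization choice are illustrative, not taken from this PR):

from vllm import LLM

# Illustrative only: load an HF-format Pixtral checkpoint with on-the-fly
# FP8 quantization; with this change the vision encoder's projections go
# through the quantized parallel-linear layers instead of plain nn.Linear.
llm = LLM(model="mistral-community/pixtral-12b", quantization="fp8")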
@@ -798,20 +801,24 @@ class PixtralHFMLP(nn.Module):
         super().__init__()
 
         assert config.intermediate_size is not None
-        # TODO: Use quant_config and prefix after optimizing this
-        self.gate_proj = nn.Linear(config.hidden_size,
-                                   config.intermediate_size,
-                                   bias=False)
-        self.up_proj = nn.Linear(config.hidden_size,
-                                 config.intermediate_size,
-                                 bias=False)
-        self.down_proj = nn.Linear(config.intermediate_size,
-                                   config.hidden_size,
-                                   bias=False)
-        self.act = get_act_fn(config.hidden_act)
+        self.gate_up_proj = MergedColumnParallelLinear(
+            input_size=config.hidden_size,
+            output_sizes=[config.intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj")
+        self.down_proj = RowParallelLinear(input_size=config.intermediate_size,
+                                           output_size=config.hidden_size,
+                                           bias=False,
+                                           quant_config=quant_config,
+                                           prefix=f"{prefix}.down_proj")
+        self.act_and_mul = get_act_and_mul_fn(config.hidden_act)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_and_mul(gate_up)
+        x, _ = self.down_proj(x)
+        return x
 
 
 class PixtralHFAttention(nn.Module):
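A standalone sketch (plain nn.Linear, no tensor parallelism or quantization, names illustrative) of why the fused gate_up_proj plus act-and-mul path computes the same function as the old separate gate_proj/up_proj formulation:

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden, intermediate = 8, 32
gate_proj = nn.Linear(hidden, intermediate, bias=False)
up_proj = nn.Linear(hidden, intermediate, bias=False)
down_proj = nn.Linear(intermediate, hidden, bias=False)

# Old path: down(silu(gate(x)) * up(x))
x = torch.randn(2, hidden)
old = down_proj(F.silu(gate_proj(x)) * up_proj(x))

# New path: one fused projection producing [gate | up], then act-and-mul.
gate_up_proj = nn.Linear(hidden, 2 * intermediate, bias=False)
with torch.no_grad():
    gate_up_proj.weight.copy_(
        torch.cat([gate_proj.weight, up_proj.weight], dim=0))
g, u = gate_up_proj(x).chunk(2, dim=-1)
new = down_proj(F.silu(g) * u)

assert torch.allclose(old, new, atol=1e-6)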
@@ -830,21 +837,21 @@ class PixtralHFAttention(nn.Module):
         self.n_heads = config.num_attention_heads
         self.head_dim = config.hidden_size // config.num_attention_heads
 
         self.scale = self.head_dim**-0.5
 
-        # TODO: Use quant_config and prefix after optimizing this
-        self.q_proj = nn.Linear(config.hidden_size,
-                                config.hidden_size,
-                                bias=False)
-        self.k_proj = nn.Linear(config.hidden_size,
-                                config.hidden_size,
-                                bias=False)
-        self.v_proj = nn.Linear(config.hidden_size,
-                                config.hidden_size,
-                                bias=False)
-        self.o_proj = nn.Linear(config.hidden_size,
-                                config.hidden_size,
-                                bias=False)
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size=config.hidden_size,
+            head_size=self.head_dim,
+            total_num_heads=self.n_heads,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
+        )
+        self.o_proj = RowParallelLinear(
+            input_size=config.hidden_size,
+            output_size=config.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
+        )
 
     def forward(
         self,
@@ -854,13 +861,13 @@ class PixtralHFAttention(nn.Module):
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         batch, patches, _ = hidden_states.size()
 
-        q = self.q_proj(hidden_states)
-        k = self.k_proj(hidden_states)
-        v = self.v_proj(hidden_states)
+        qkv_states, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv_states.chunk(3, dim=-1)
 
         # Transpose q and k to apply HF's Rotary Position Embedding
         q = q.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2)
         k = k.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2)
-        v = v.view(batch, patches, self.n_heads, self.head_dim)
         cos, sin = position_embeddings
         q, k = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=0)
 
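A standalone sketch of the QKV fusion this relies on (plain nn.Linear, ignoring tensor parallelism and quantization): when the fused weight is laid out as [Wq; Wk; Wv], splitting the fused output with chunk(3, dim=-1) recovers the separate q/k/v projections.

import torch
import torch.nn as nn

hidden = 16
q_proj = nn.Linear(hidden, hidden, bias=False)
k_proj = nn.Linear(hidden, hidden, bias=False)
v_proj = nn.Linear(hidden, hidden, bias=False)

qkv_proj = nn.Linear(hidden, 3 * hidden, bias=False)
with torch.no_grad():
    qkv_proj.weight.copy_(
        torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0))

x = torch.randn(4, hidden)
q, k, v = qkv_proj(x).chunk(3, dim=-1)
assert torch.allclose(q, q_proj(x), atol=1e-6)
assert torch.allclose(k, k_proj(x), atol=1e-6)
assert torch.allclose(v, v_proj(x), atol=1e-6)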
@@ -868,22 +875,21 @@ class PixtralHFAttention(nn.Module):
             # Transpose q and k back for attention
             q = q.transpose(1, 2).contiguous()
             k = k.transpose(1, 2).contiguous()
+            v = v.reshape(batch, patches, self.n_heads, self.head_dim)
 
             out = xops.memory_efficient_attention(q,
                                                   k,
                                                   v,
                                                   attn_bias=attention_mask)
         else:
-            v = v.transpose(1, 2)
+            v = v.reshape(batch, patches, self.n_heads,
+                          self.head_dim).transpose(1, 2)
             out = nn.functional.scaled_dot_product_attention(
                 q, k, v, attn_mask=attention_mask)
             out = out.transpose(1, 2)
 
-        out = out.reshape(batch, patches, self.n_heads * self.head_dim)
+        out = out.view(batch, patches, self.n_heads * self.head_dim)
+        attn_output, _ = self.o_proj(out)
 
-        return self.o_proj(out)
+        return attn_output, None
 
 
 class PixtralHFTransformerBlock(nn.Module):
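For orientation, a standalone shape walk-through of the scaled_dot_product_attention branch: after the rotary step q and k are (batch, heads, patches, head_dim), v is reshaped from the flat chunk output to match, and the result is flattened back to (batch, patches, heads * head_dim) before the output projection. Sizes below are illustrative.

import torch
import torch.nn.functional as F

batch, patches, n_heads, head_dim = 2, 5, 4, 8

q = torch.randn(batch, n_heads, patches, head_dim)
k = torch.randn(batch, n_heads, patches, head_dim)
v_flat = torch.randn(batch, patches, n_heads * head_dim)

v = v_flat.reshape(batch, patches, n_heads, head_dim).transpose(1, 2)
out = F.scaled_dot_product_attention(q, k, v)
out = out.transpose(1, 2)                              # (batch, patches, heads, head_dim)
out = out.reshape(batch, patches, n_heads * head_dim)  # flatten for the output projection
assert out.shape == (batch, patches, n_heads * head_dim)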
@@ -912,9 +918,9 @@ class PixtralHFTransformerBlock(nn.Module):
         attention_mask: torch.Tensor,
         position_embeddings: torch.Tensor,
     ) -> torch.Tensor:
-        r = self.attention.forward(self.attention_norm(hidden_states),
-                                   attention_mask=attention_mask,
-                                   position_embeddings=position_embeddings)
+        r, _ = self.attention.forward(self.attention_norm(hidden_states),
+                                      attention_mask=attention_mask,
+                                      position_embeddings=position_embeddings)
         h = hidden_states + r
         r = self.feed_forward.forward(self.ffn_norm(h))
         out = h + r
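A minimal stand-in for this pre-norm residual wiring: the attention module now returns an (output, attn_weights) tuple, so the block unpacks it and keeps only the output before the residual add. Norms and sublayers here are placeholders, not the vLLM modules.

import torch
import torch.nn as nn


class BlockSketch(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.attention_norm = nn.LayerNorm(dim)
        self.ffn_norm = nn.LayerNorm(dim)
        self.feed_forward = nn.Linear(dim, dim)

    def attention(self, x: torch.Tensor):
        # Placeholder returning (output, attn_weights), like the new forward.
        return x, None

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        r, _ = self.attention(self.attention_norm(hidden_states))
        h = hidden_states + r
        return h + self.feed_forward(self.ffn_norm(h))


out = BlockSketch(8)(torch.randn(2, 3, 8))
assert out.shape == (2, 3, 8)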
@@ -1053,10 +1059,24 @@ class PixtralHFVisionModel(nn.Module):
     # (TODO) Add prefix argument for filtering out weights to be loaded
     # ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        stacked_params_mapping = []
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+        ]
        params_dict = dict(self.named_parameters())
        layer_count = len(self.transformer.layers)

        for name, loaded_weight in weights:
            # omit layers when num_hidden_layers_override is set
            if name.startswith("transformer.layers"):
                layer_idx = int(name.split(".")[2])
                if layer_idx >= layer_count:
                    continue

+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
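A standalone sketch of what the (param_name, shard_name, shard_id) mapping drives: a checkpoint still stores q_proj/k_proj/v_proj and gate_proj/up_proj tensors, and each one must be routed into the matching slice of the fused qkv_proj or gate_up_proj parameter. The route helper below mirrors that intent only and is not vLLM's actual weight_loader machinery.

from typing import List, Tuple, Union

stacked_params_mapping: List[Tuple[str, str, Union[str, int]]] = [
    (".qkv_proj", ".q_proj", "q"),
    (".qkv_proj", ".k_proj", "k"),
    (".qkv_proj", ".v_proj", "v"),
    (".gate_up_proj", ".gate_proj", 0),
    (".gate_up_proj", ".up_proj", 1),
]


def route(checkpoint_name: str) -> Tuple[str, Union[str, int, None]]:
    """Map a checkpoint weight name to (fused param name, shard id)."""
    for param_name, shard_name, shard_id in stacked_params_mapping:
        if shard_name in checkpoint_name:
            return checkpoint_name.replace(shard_name, param_name), shard_id
    return checkpoint_name, None  # not stacked, loaded as-is


assert route("transformer.layers.0.attention.q_proj.weight") == (
    "transformer.layers.0.attention.qkv_proj.weight", "q")
assert route("transformer.layers.0.feed_forward.up_proj.weight") == (
    "transformer.layers.0.feed_forward.gate_up_proj.weight", 1)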