[Model] Qwen2.5 VL SiLU-and-Mul (#22066)
Signed-off-by: kf <kuanfu.liu@embeddedllm.com>
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Co-authored-by: kf <kuanfu.liu@embeddedllm.com>

This commit is contained in:
parent 23322431c8
commit ee2eb6ecd8
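In short, the commit merges the vision MLP's separate gate_proj and up_proj into one projection and applies the activation and element-wise multiply as a single fused step (SiLU-and-Mul). The sketch below shows the before/after pattern in plain PyTorch; it is illustrative only, does not use vLLM's parallel linear layers, and the class names and sizes are placeholders.

# Illustrative before/after of the SiLU-and-Mul pattern (plain PyTorch,
# not vLLM's parallel layers; names and sizes are placeholders).
import torch
import torch.nn as nn
import torch.nn.functional as F


class UnfusedVisionMLP(nn.Module):
    """Old layout: separate gate and up projections."""

    def __init__(self, in_features: int, hidden_features: int):
        super().__init__()
        self.gate_proj = nn.Linear(in_features, hidden_features, bias=False)
        self.up_proj = nn.Linear(in_features, hidden_features, bias=False)
        self.down_proj = nn.Linear(hidden_features, in_features, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


class FusedVisionMLP(nn.Module):
    """New layout: one merged gate_up projection + fused silu-and-mul."""

    def __init__(self, in_features: int, hidden_features: int):
        super().__init__()
        # The first hidden_features output columns are the gate, the rest the
        # up projection, mirroring output_sizes=[hidden_features] * 2.
        self.gate_up_proj = nn.Linear(in_features, 2 * hidden_features,
                                      bias=False)
        self.down_proj = nn.Linear(hidden_features, in_features, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
        return self.down_proj(F.silu(gate) * up)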
@@ -43,9 +43,10 @@ from vllm.distributed import parallel_state
 from vllm.distributed import utils as dist_utils
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
-from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
+from vllm.model_executor.layers.activation import get_act_and_mul_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
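The import swap reflects the new contract for the activation: instead of looking up a plain activation in _ACTIVATION_REGISTRY, the model asks for an activation-and-multiply function keyed by the config's hidden_act name (SiLU in this case, per the commit title). A minimal standalone version of that contract looks roughly like the toy below; vLLM's real get_act_and_mul_fn returns its own registered layer rather than this function.

# Toy stand-in for an act-and-mul activation: the last dimension holds the
# gate and up halves back to back; apply the activation to the gate half and
# multiply by the up half. Illustrative only.
import torch
import torch.nn.functional as F


def silu_and_mul(gate_up: torch.Tensor) -> torch.Tensor:
    gate, up = gate_up.chunk(2, dim=-1)
    return F.silu(gate) * up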
@@ -171,16 +172,12 @@ class Qwen2_5_VisionMLP(nn.Module):
                  quant_config: Optional[QuantizationConfig] = None,
                  prefix: str = ""):
         super().__init__()
-        self.gate_proj = ColumnParallelLinear(in_features,
-                                              hidden_features,
-                                              bias=bias,
-                                              quant_config=quant_config,
-                                              prefix=f"{prefix}.gate_proj")
-        self.up_proj = ColumnParallelLinear(in_features,
-                                            hidden_features,
-                                            bias=bias,
-                                            quant_config=quant_config,
-                                            prefix=f"{prefix}.up_proj")
+        self.gate_up_proj = MergedColumnParallelLinear(
+            input_size=in_features,
+            output_sizes=[hidden_features] * 2,  # [gate_proj, up_proj]
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj")
         self.down_proj = RowParallelLinear(hidden_features,
                                            in_features,
                                            bias=bias,
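One way to see why output_sizes=[hidden_features] * 2 is the right merged shape: stacking the gate and up weight matrices along the output dimension yields a single projection whose output splits back into the two halves. A quick standalone check with plain nn.Linear and arbitrary sizes (not vLLM code) is below.

# Hypothetical check that a merged gate_up projection reproduces the two
# separate projections when its weight is the gate and up weights stacked
# along the output dimension (plain PyTorch; sizes are arbitrary).
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
in_features, hidden_features = 8, 16
gate_proj = nn.Linear(in_features, hidden_features, bias=False)
up_proj = nn.Linear(in_features, hidden_features, bias=False)
gate_up_proj = nn.Linear(in_features, 2 * hidden_features, bias=False)
with torch.no_grad():
    gate_up_proj.weight.copy_(
        torch.cat([gate_proj.weight, up_proj.weight], dim=0))

x = torch.randn(2, in_features)
gate, up = gate_up_proj(x).chunk(2, dim=-1)
assert torch.allclose(F.silu(gate_proj(x)) * up_proj(x),
                      F.silu(gate) * up,
                      atol=1e-6)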
@@ -189,10 +186,9 @@ class Qwen2_5_VisionMLP(nn.Module):
         self.act_fn = act_fn
 
     def forward(self, x: torch.Tensor):
-        x_gate, _ = self.gate_proj(x)
-        x_gate = self.act_fn(x_gate)
-        x_up, _ = self.up_proj(x)
-        x_down, _ = self.down_proj(x_gate * x_up)
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x_down, _ = self.down_proj(x)
         return x_down
 
 
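The new forward hands the full double-width gate_up tensor to act_fn, which returns a tensor of half that width for down_proj. A small shape trace with illustrative sizes, using plain SiLU in place of whatever get_act_and_mul_fn(hidden_act) actually returns:

# Shape trace through the new forward (illustrative sizes only).
import torch
import torch.nn.functional as F

num_tokens, hidden_features = 4, 16
gate_up = torch.randn(num_tokens, 2 * hidden_features)  # gate_up_proj output
gate, up = gate_up.chunk(2, dim=-1)
x = F.silu(gate) * up                                    # act_fn(gate_up)
assert x.shape == (num_tokens, hidden_features)          # input to down_proj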
@@ -540,11 +536,11 @@ class Qwen2_5_VisionTransformer(nn.Module):
         self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
 
         self.blocks = nn.ModuleList([
-            Qwen2_5_VisionBlock(
-                dim=self.hidden_size,
-                num_heads=self.num_heads,
-                mlp_hidden_dim=vision_config.intermediate_size,
-                act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
-                norm_layer=norm_layer,
-                quant_config=quant_config,
-                prefix=f"{prefix}.blocks.{layer_idx}")
+            Qwen2_5_VisionBlock(dim=self.hidden_size,
+                                num_heads=self.num_heads,
+                                mlp_hidden_dim=vision_config.intermediate_size,
+                                act_fn=get_act_and_mul_fn(
+                                    vision_config.hidden_act),
+                                norm_layer=norm_layer,
+                                quant_config=quant_config,
+                                prefix=f"{prefix}.blocks.{layer_idx}")
@@ -752,6 +748,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
             ("attn.qkv.", "attn.q.", "q"),
             ("attn.qkv.", "attn.k.", "k"),
             ("attn.qkv.", "attn.v.", "v"),
+            ("mlp.gate_up_proj.", "mlp.gate_proj.", 0),
+            ("mlp.gate_up_proj.", "mlp.up_proj.", 1),
         ]
         params_dict = dict(self.named_parameters(remove_duplicate=False))
         loaded_params: set[str] = set()
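The two new mapping entries tell the weight loader that checkpoint tensors named mlp.gate_proj.* and mlp.up_proj.* now land in shard 0 and shard 1 of the merged mlp.gate_up_proj parameter. A rough sketch of how such a mapping is typically consumed follows; this toy resolver only rewrites names, whereas the real load_weights also routes the tensor through the parameter's weight_loader together with the shard id.

# Toy resolver mimicking how stacked_params_mapping is usually applied when
# loading weights: rewrite the checkpoint name to the merged parameter name
# and remember which shard the tensor belongs to. Illustrative only.
from typing import Optional

stacked_params_mapping = [
    # (param_name, checkpoint weight_name, shard_id)
    ("mlp.gate_up_proj.", "mlp.gate_proj.", 0),
    ("mlp.gate_up_proj.", "mlp.up_proj.", 1),
]


def resolve(name: str) -> tuple[str, Optional[int]]:
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            return name.replace(weight_name, param_name), shard_id
    return name, None


# "blocks.0.mlp.gate_proj.weight" -> ("blocks.0.mlp.gate_up_proj.weight", 0)
print(resolve("blocks.0.mlp.gate_proj.weight"))
# "blocks.0.mlp.up_proj.weight"   -> ("blocks.0.mlp.gate_up_proj.weight", 1)
print(resolve("blocks.0.mlp.up_proj.weight"))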