mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-27 16:35:17 +08:00
[Model] Pipeline parallel support for Qwen2 (#6924)
This commit is contained in:
parent
7ecee34321
commit
1d2e7fb73f
@ -40,6 +40,8 @@ _PP_SUPPORTED_MODELS = [
|
||||
"GPT2LMHeadModel",
|
||||
"MixtralForCausalLM",
|
||||
"NemotronForCausalLM",
|
||||
"Qwen2ForCausalLM",
|
||||
"Qwen2MoeForCausalLM",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -30,7 +30,7 @@ from transformers import Qwen2Config
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.config import CacheConfig, LoRAConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
@ -49,6 +49,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
|
||||
from .interfaces import SupportsLoRA
|
||||
from .utils import is_pp_missing_parameter, make_layers
|
||||
|
||||
|
||||
class Qwen2MLP(nn.Module):
|
||||
@ -227,6 +228,7 @@ class Qwen2Model(nn.Module):
|
||||
config: Qwen2Config,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@ -237,10 +239,14 @@ class Qwen2Model(nn.Module):
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
Qwen2DecoderLayer(config, cache_config, quant_config)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: Qwen2DecoderLayer(config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config),
|
||||
prefix=f"{prefix}.layers",
|
||||
)
|
||||
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
@ -255,20 +261,30 @@ class Qwen2Model(nn.Module):
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
residual = None
|
||||
for i in range(len(self.layers)):
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i],
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata,
|
||||
residual,
|
||||
)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
"residual": residual
|
||||
})
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
@ -351,6 +367,20 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def make_empty_intermediate_tensors(
|
||||
self, batch_size: int, dtype: torch.dtype,
|
||||
device: torch.device) -> IntermediateTensors:
|
||||
return IntermediateTensors({
|
||||
"hidden_states":
|
||||
torch.zeros((batch_size, self.config.hidden_size),
|
||||
dtype=dtype,
|
||||
device=device),
|
||||
"residual":
|
||||
torch.zeros((batch_size, self.config.hidden_size),
|
||||
dtype=dtype,
|
||||
device=device),
|
||||
})
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
@ -381,6 +411,8 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
@ -393,7 +425,8 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
|
||||
@ -31,7 +31,8 @@ from transformers import PretrainedConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.distributed import (get_tensor_model_parallel_world_size,
|
||||
from vllm.distributed import (get_pp_group,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
@ -52,6 +53,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
from vllm.utils import print_warning_once
|
||||
|
||||
from .utils import is_pp_missing_parameter, make_layers
|
||||
|
||||
|
||||
class Qwen2MoeMLP(nn.Module):
|
||||
|
||||
@ -315,6 +318,7 @@ class Qwen2MoeModel(nn.Module):
|
||||
config: PretrainedConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.padding_idx = config.pad_token_id
|
||||
@ -324,13 +328,15 @@ class Qwen2MoeModel(nn.Module):
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
Qwen2MoeDecoderLayer(config,
|
||||
layer_idx,
|
||||
cache_config,
|
||||
quant_config=quant_config)
|
||||
for layer_idx in range(config.num_hidden_layers)
|
||||
])
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: Qwen2MoeDecoderLayer(config=config,
|
||||
layer_idx=int(
|
||||
prefix.split(".")[-1]),
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config),
|
||||
prefix=f"{prefix}.layers",
|
||||
)
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
@ -339,14 +345,25 @@ class Qwen2MoeModel(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
residual = None
|
||||
for i in range(len(self.layers)):
|
||||
if get_pp_group().is_first_rank:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(positions, hidden_states,
|
||||
kv_caches[i], attn_metadata,
|
||||
residual)
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata, residual)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
"residual": residual
|
||||
})
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
@ -380,7 +397,7 @@ class Qwen2MoeForCausalLM(nn.Module):
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata)
|
||||
attn_metadata, intermediate_tensors)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
@ -389,6 +406,20 @@ class Qwen2MoeForCausalLM(nn.Module):
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def make_empty_intermediate_tensors(
|
||||
self, batch_size: int, dtype: torch.dtype,
|
||||
device: torch.device) -> IntermediateTensors:
|
||||
return IntermediateTensors({
|
||||
"hidden_states":
|
||||
torch.zeros((batch_size, self.config.hidden_size),
|
||||
dtype=dtype,
|
||||
device=device),
|
||||
"residual":
|
||||
torch.zeros((batch_size, self.config.hidden_size),
|
||||
dtype=dtype,
|
||||
device=device),
|
||||
})
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: Optional[torch.Tensor],
|
||||
@ -435,6 +466,9 @@ class Qwen2MoeForCausalLM(nn.Module):
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
# Skip layers on other devices.
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
if name not in params_dict:
|
||||
continue
|
||||
|
||||
@ -448,6 +482,9 @@ class Qwen2MoeForCausalLM(nn.Module):
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip layers on other devices.
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param,
|
||||
@ -460,6 +497,9 @@ class Qwen2MoeForCausalLM(nn.Module):
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
# Skip layers on other devices.
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
if name.endswith("kv_scale"):
|
||||
remapped_kv_scale_name = name.replace(
|
||||
@ -474,7 +514,6 @@ class Qwen2MoeForCausalLM(nn.Module):
|
||||
continue
|
||||
else:
|
||||
name = remapped_kv_scale_name
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user