[Model] Pipeline parallel support for Qwen2 (#6924)

This commit is contained in:
xuyi 2024-08-01 09:49:51 +08:00 committed by GitHub
parent 7ecee34321
commit 1d2e7fb73f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 101 additions and 27 deletions

View File

@ -40,6 +40,8 @@ _PP_SUPPORTED_MODELS = [
"GPT2LMHeadModel", "GPT2LMHeadModel",
"MixtralForCausalLM", "MixtralForCausalLM",
"NemotronForCausalLM", "NemotronForCausalLM",
"Qwen2ForCausalLM",
"Qwen2MoeForCausalLM",
] ]

View File

@ -30,7 +30,7 @@ from transformers import Qwen2Config
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@ -49,6 +49,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA
from .utils import is_pp_missing_parameter, make_layers
class Qwen2MLP(nn.Module): class Qwen2MLP(nn.Module):
@ -227,6 +228,7 @@ class Qwen2Model(nn.Module):
config: Qwen2Config, config: Qwen2Config,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
@ -237,10 +239,14 @@ class Qwen2Model(nn.Module):
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
) )
self.layers = nn.ModuleList([ self.start_layer, self.end_layer, self.layers = make_layers(
Qwen2DecoderLayer(config, cache_config, quant_config) config.num_hidden_layers,
for _ in range(config.num_hidden_layers) lambda prefix: Qwen2DecoderLayer(config=config,
]) cache_config=cache_config,
quant_config=quant_config),
prefix=f"{prefix}.layers",
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
@ -255,20 +261,30 @@ class Qwen2Model(nn.Module):
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if inputs_embeds is not None: if get_pp_group().is_first_rank:
hidden_states = inputs_embeds if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.embed_tokens(input_ids)
residual = None
else: else:
hidden_states = self.embed_tokens(input_ids) assert intermediate_tensors is not None
residual = None hidden_states = intermediate_tensors["hidden_states"]
for i in range(len(self.layers)): residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i] layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i - self.start_layer],
attn_metadata, attn_metadata,
residual, residual,
) )
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
@ -351,6 +367,20 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
sampling_metadata) sampling_metadata)
return logits return logits
def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
"hidden_states":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
"residual":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
})
def sample( def sample(
self, self,
logits: torch.Tensor, logits: torch.Tensor,
@ -381,6 +411,8 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
@ -393,7 +425,8 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
name = maybe_remap_kv_scale_name(name, params_dict) name = maybe_remap_kv_scale_name(name, params_dict)
if name is None: if name is None:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)

View File

@ -31,7 +31,8 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_world_size, from vllm.distributed import (get_pp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
@ -52,6 +53,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.utils import print_warning_once from vllm.utils import print_warning_once
from .utils import is_pp_missing_parameter, make_layers
class Qwen2MoeMLP(nn.Module): class Qwen2MoeMLP(nn.Module):
@ -315,6 +318,7 @@ class Qwen2MoeModel(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
@ -324,13 +328,15 @@ class Qwen2MoeModel(nn.Module):
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
) )
self.layers = nn.ModuleList([ self.start_layer, self.end_layer, self.layers = make_layers(
Qwen2MoeDecoderLayer(config, config.num_hidden_layers,
layer_idx, lambda prefix: Qwen2MoeDecoderLayer(config=config,
cache_config, layer_idx=int(
quant_config=quant_config) prefix.split(".")[-1]),
for layer_idx in range(config.num_hidden_layers) cache_config=cache_config,
]) quant_config=quant_config),
prefix=f"{prefix}.layers",
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward( def forward(
@ -339,14 +345,25 @@ class Qwen2MoeModel(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids) if get_pp_group().is_first_rank:
residual = None hidden_states = self.embed_tokens(input_ids)
for i in range(len(self.layers)): residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i] layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states, hidden_states, residual = layer(positions, hidden_states,
kv_caches[i], attn_metadata, kv_caches[i - self.start_layer],
residual) attn_metadata, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
@ -380,7 +397,7 @@ class Qwen2MoeForCausalLM(nn.Module):
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
@ -389,6 +406,20 @@ class Qwen2MoeForCausalLM(nn.Module):
sampling_metadata) sampling_metadata)
return logits return logits
def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
"hidden_states":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
"residual":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
})
def sample( def sample(
self, self,
logits: Optional[torch.Tensor], logits: Optional[torch.Tensor],
@ -435,6 +466,9 @@ class Qwen2MoeForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
if name not in params_dict: if name not in params_dict:
continue continue
@ -448,6 +482,9 @@ class Qwen2MoeForCausalLM(nn.Module):
if weight_name not in name: if weight_name not in name:
continue continue
name = name.replace(weight_name, param_name) name = name.replace(weight_name, param_name)
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, weight_loader(param,
@ -460,6 +497,9 @@ class Qwen2MoeForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
# Remapping the name of FP8 kv-scale. # Remapping the name of FP8 kv-scale.
if name.endswith("kv_scale"): if name.endswith("kv_scale"):
remapped_kv_scale_name = name.replace( remapped_kv_scale_name = name.replace(
@ -474,7 +514,6 @@ class Qwen2MoeForCausalLM(nn.Module):
continue continue
else: else:
name = remapped_kv_scale_name name = remapped_kv_scale_name
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)