mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-08 04:35:47 +08:00
[Bugfix] Correct behavior of GraniteMoeHybrid for TensorParallel execution (#20137)
Signed-off-by: Stanislaw Wozniak <stw@zurich.ibm.com>
This commit is contained in:
parent
daceac57c7
commit
daec9dea6e
@ -1,42 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from ...utils import check_logprobs_close
|
|
||||||
|
|
||||||
# Path of the checkpoints
|
|
||||||
MODELS = [
|
|
||||||
"ibm-granite/granite-4.0-tiny-preview",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(
|
|
||||||
reason="Granite 4.0 is not yet available in huggingface transformers")
|
|
||||||
@pytest.mark.parametrize("model", MODELS)
|
|
||||||
@pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
|
|
||||||
@pytest.mark.parametrize("max_tokens", [64])
|
|
||||||
@pytest.mark.parametrize("num_logprobs", [5])
|
|
||||||
def test_model_equivalence_to_hf_greedy(
|
|
||||||
hf_runner,
|
|
||||||
vllm_runner,
|
|
||||||
example_prompts,
|
|
||||||
model: str,
|
|
||||||
dtype: str,
|
|
||||||
max_tokens: int,
|
|
||||||
num_logprobs: int,
|
|
||||||
):
|
|
||||||
with vllm_runner(model, dtype=dtype) as vllm_model:
|
|
||||||
vllm_outputs = vllm_model.generate_greedy_logprobs(
|
|
||||||
example_prompts, max_tokens, num_logprobs)
|
|
||||||
|
|
||||||
with hf_runner(model, dtype=dtype) as hf_model:
|
|
||||||
hf_outputs = hf_model.generate_greedy_logprobs_limit(
|
|
||||||
example_prompts, max_tokens, num_logprobs)
|
|
||||||
|
|
||||||
check_logprobs_close(
|
|
||||||
outputs_0_lst=hf_outputs,
|
|
||||||
outputs_1_lst=vllm_outputs,
|
|
||||||
name_0="hf",
|
|
||||||
name_1="vllm",
|
|
||||||
)
|
|
||||||
@ -28,8 +28,9 @@ SSM_MODELS = [
|
|||||||
|
|
||||||
HYBRID_MODELS = [
|
HYBRID_MODELS = [
|
||||||
"ai21labs/Jamba-tiny-dev",
|
"ai21labs/Jamba-tiny-dev",
|
||||||
# NOTE: ibm-granite/granite-4.0-tiny-preview are skipped currently as
|
# NOTE: Currently the test fails due to an HF transformers issue fixed in:
|
||||||
# it is not yet available in huggingface transformers
|
# https://github.com/huggingface/transformers/pull/39033
|
||||||
|
# We will enable vLLM test for Granite after next HF transformers release.
|
||||||
# "ibm-granite/granite-4.0-tiny-preview",
|
# "ibm-granite/granite-4.0-tiny-preview",
|
||||||
# NOTE: Running Plamo2 with the transformers implementation requires installing the
|
# NOTE: Running Plamo2 with the transformers implementation requires installing the
|
||||||
# causal-conv1d package, which is not listed as a test dependency as it's
|
# causal-conv1d package, which is not listed as a test dependency as it's
|
||||||
|
|||||||
@ -15,7 +15,8 @@ from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
|||||||
from vllm.distributed.parallel_state import get_pp_group
|
from vllm.distributed.parallel_state import get_pp_group
|
||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
||||||
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.mamba.mamba2_metadata import (
|
from vllm.model_executor.layers.mamba.mamba2_metadata import (
|
||||||
Mamba2Metadata, prepare_mamba2_metadata)
|
Mamba2Metadata, prepare_mamba2_metadata)
|
||||||
@ -36,8 +37,9 @@ from .granitemoe import GraniteMoeMoE
|
|||||||
from .granitemoeshared import GraniteMoeSharedMLP
|
from .granitemoeshared import GraniteMoeSharedMLP
|
||||||
from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
|
from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
|
||||||
SupportsQuant, SupportsV0Only)
|
SupportsQuant, SupportsV0Only)
|
||||||
from .utils import (AutoWeightsLoader, make_empty_intermediate_tensors_factory,
|
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
|
||||||
make_layers, maybe_prefix)
|
make_empty_intermediate_tensors_factory, make_layers,
|
||||||
|
maybe_prefix)
|
||||||
|
|
||||||
|
|
||||||
class GraniteMoeHybridMambaDecoderLayer(nn.Module):
|
class GraniteMoeHybridMambaDecoderLayer(nn.Module):
|
||||||
@ -220,35 +222,37 @@ class GraniteMoeHybridAttention(nn.Module):
|
|||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
self.attention_bias = config.attention_bias
|
self.attention_bias = config.attention_bias
|
||||||
self.attention_multiplier = config.attention_multiplier
|
self.attention_multiplier = config.attention_multiplier
|
||||||
self.num_heads = config.num_attention_heads
|
self.total_num_heads = config.num_attention_heads
|
||||||
self.head_dim = self.hidden_size // self.num_heads
|
self.head_dim = self.hidden_size // self.total_num_heads
|
||||||
self.num_key_value_heads = config.num_key_value_heads
|
self.total_num_kv_heads = config.num_key_value_heads
|
||||||
|
|
||||||
self.q_proj = ReplicatedLinear(self.hidden_size,
|
# TensorParallel logic
|
||||||
self.num_heads * self.head_dim,
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
bias=self.attention_bias,
|
assert self.total_num_heads % tp_size == 0
|
||||||
quant_config=quant_config,
|
self.num_heads = self.total_num_heads // tp_size
|
||||||
prefix=f"{prefix}.q_proj")
|
if self.total_num_kv_heads >= tp_size:
|
||||||
|
# Number of KV heads is greater than or equal to TP size, so we partition
|
||||||
|
# the KV heads across multiple tensor parallel GPUs.
|
||||||
|
assert self.total_num_kv_heads % tp_size == 0
|
||||||
|
else:
|
||||||
|
# Number of KV heads is less than TP size, so we replicate
|
||||||
|
# the KV heads across multiple tensor parallel GPUs.
|
||||||
|
assert tp_size % self.total_num_kv_heads == 0
|
||||||
|
self.num_key_value_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||||
|
|
||||||
self.k_proj = ReplicatedLinear(self.hidden_size,
|
self.qkv_proj = QKVParallelLinear(self.hidden_size,
|
||||||
self.num_key_value_heads *
|
self.head_dim,
|
||||||
self.head_dim,
|
self.total_num_heads,
|
||||||
bias=self.attention_bias,
|
self.total_num_kv_heads,
|
||||||
quant_config=quant_config,
|
bias=self.attention_bias,
|
||||||
prefix=f"{prefix}.k_proj")
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.qkv_proj")
|
||||||
|
|
||||||
self.v_proj = ReplicatedLinear(self.hidden_size,
|
self.o_proj = RowParallelLinear(self.hidden_size,
|
||||||
self.num_key_value_heads *
|
self.hidden_size,
|
||||||
self.head_dim,
|
bias=self.attention_bias,
|
||||||
bias=self.attention_bias,
|
quant_config=quant_config,
|
||||||
quant_config=quant_config,
|
prefix=f"{prefix}.o_proj")
|
||||||
prefix=f"{prefix}.v_proj")
|
|
||||||
|
|
||||||
self.o_proj = ReplicatedLinear(self.hidden_size,
|
|
||||||
self.hidden_size,
|
|
||||||
bias=self.attention_bias,
|
|
||||||
quant_config=quant_config,
|
|
||||||
prefix=f"{prefix}.o_proj")
|
|
||||||
|
|
||||||
if config.position_embedding_type == "rope":
|
if config.position_embedding_type == "rope":
|
||||||
self.rotary_emb = get_rope(
|
self.rotary_emb = get_rope(
|
||||||
@ -278,9 +282,12 @@ class GraniteMoeHybridAttention(nn.Module):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
query = self.q_proj(hidden_states)[0]
|
qkv, _ = self.qkv_proj(hidden_states)
|
||||||
key = self.k_proj(hidden_states)[0]
|
query, key, value = qkv.split([
|
||||||
value = self.v_proj(hidden_states)[0]
|
self.num_heads * self.head_dim, self.num_key_value_heads *
|
||||||
|
self.head_dim, self.num_key_value_heads * self.head_dim
|
||||||
|
],
|
||||||
|
dim=-1)
|
||||||
|
|
||||||
if self.rotary_emb is not None:
|
if self.rotary_emb is not None:
|
||||||
query, key = self.rotary_emb(positions, query, key)
|
query, key = self.rotary_emb(positions, query, key)
|
||||||
@ -401,6 +408,12 @@ class GraniteMoeHybridModel(nn.Module):
|
|||||||
|
|
||||||
def load_weights(self, weights: Iterable[tuple[str,
|
def load_weights(self, weights: Iterable[tuple[str,
|
||||||
torch.Tensor]]) -> set[str]:
|
torch.Tensor]]) -> set[str]:
|
||||||
|
stacked_params_mapping = [
|
||||||
|
# (param_name, shard_name, shard_id)
|
||||||
|
(".qkv_proj", ".q_proj", "q"),
|
||||||
|
(".qkv_proj", ".k_proj", "k"),
|
||||||
|
(".qkv_proj", ".v_proj", "v"),
|
||||||
|
]
|
||||||
params_dict = dict(self.named_parameters())
|
params_dict = dict(self.named_parameters())
|
||||||
loaded_params: set[str] = set()
|
loaded_params: set[str] = set()
|
||||||
|
|
||||||
@ -411,6 +424,15 @@ class GraniteMoeHybridModel(nn.Module):
|
|||||||
weight_loader(param, p)
|
weight_loader(param, p)
|
||||||
loaded_params.add(n)
|
loaded_params.add(n)
|
||||||
|
|
||||||
|
def _load_shard(n, p, shard_id):
|
||||||
|
# Skip layers on other devices.
|
||||||
|
if not is_pp_missing_parameter(n, self):
|
||||||
|
param = params_dict[n]
|
||||||
|
weight_loader = getattr(param, "weight_loader",
|
||||||
|
default_weight_loader)
|
||||||
|
weight_loader(param, p, shard_id)
|
||||||
|
loaded_params.add(n)
|
||||||
|
|
||||||
def _load_expert(n, p, name, shard_id, expert_id):
|
def _load_expert(n, p, name, shard_id, expert_id):
|
||||||
param = params_dict[n]
|
param = params_dict[n]
|
||||||
weight_loader = getattr(param, "weight_loader",
|
weight_loader = getattr(param, "weight_loader",
|
||||||
@ -465,7 +487,15 @@ class GraniteMoeHybridModel(nn.Module):
|
|||||||
".block_sparse_moe.gate.weight")
|
".block_sparse_moe.gate.weight")
|
||||||
_load(gate_name, p)
|
_load(gate_name, p)
|
||||||
else:
|
else:
|
||||||
_load(n, p)
|
loaded = False
|
||||||
|
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||||
|
if weight_name in n:
|
||||||
|
_load_shard(n.replace(weight_name, param_name),
|
||||||
|
p,
|
||||||
|
shard_id=shard_id)
|
||||||
|
loaded = True
|
||||||
|
if not loaded:
|
||||||
|
_load(n, p)
|
||||||
|
|
||||||
return loaded_params
|
return loaded_params
|
||||||
|
|
||||||
@ -473,7 +503,13 @@ class GraniteMoeHybridModel(nn.Module):
|
|||||||
class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
|
class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
|
||||||
SupportsPP, IsHybrid, SupportsV0Only,
|
SupportsPP, IsHybrid, SupportsV0Only,
|
||||||
SupportsQuant):
|
SupportsQuant):
|
||||||
packed_modules_mapping = {}
|
packed_modules_mapping = {
|
||||||
|
"qkv_proj": [
|
||||||
|
"q_proj",
|
||||||
|
"k_proj",
|
||||||
|
"v_proj",
|
||||||
|
],
|
||||||
|
}
|
||||||
embedding_modules = {
|
embedding_modules = {
|
||||||
"embed_tokens": "input_embeddings",
|
"embed_tokens": "input_embeddings",
|
||||||
"lm_head": "output_embeddings",
|
"lm_head": "output_embeddings",
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user