[Bugfix] Correct behavior of GraniteMoeHybrid for TensorParallel execution (#20137)

Signed-off-by: Stanislaw Wozniak <stw@zurich.ibm.com>
Authored by Stan Wozniak on 2025-06-28 17:16:41 +02:00; committed by GitHub
parent daceac57c7
commit daec9dea6e
3 changed files with 73 additions and 78 deletions

@@ -1,42 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import pytest
-
-from ...utils import check_logprobs_close
-
-# Path of the checkpoints
-MODELS = [
-    "ibm-granite/granite-4.0-tiny-preview",
-]
-
-
-@pytest.mark.skip(
-    reason="Granite 4.0 is not yet available in huggingface transformers")
-@pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
-@pytest.mark.parametrize("max_tokens", [64])
-@pytest.mark.parametrize("num_logprobs", [5])
-def test_model_equivalence_to_hf_greedy(
-    hf_runner,
-    vllm_runner,
-    example_prompts,
-    model: str,
-    dtype: str,
-    max_tokens: int,
-    num_logprobs: int,
-):
-    with vllm_runner(model, dtype=dtype) as vllm_model:
-        vllm_outputs = vllm_model.generate_greedy_logprobs(
-            example_prompts, max_tokens, num_logprobs)
-
-    with hf_runner(model, dtype=dtype) as hf_model:
-        hf_outputs = hf_model.generate_greedy_logprobs_limit(
-            example_prompts, max_tokens, num_logprobs)
-
-    check_logprobs_close(
-        outputs_0_lst=hf_outputs,
-        outputs_1_lst=vllm_outputs,
-        name_0="hf",
-        name_1="vllm",
-    )
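The deleted test only ever exercised single-GPU execution, while the bug this commit fixes shows up under tensor parallelism. Below is a hedged sketch of how the same HF-vs-vLLM logprob comparison could be run across two GPUs; it assumes the standard hf_runner/vllm_runner test fixtures and that vllm_runner forwards tensor_parallel_size to the engine, and it is not part of this commit.

import pytest
import torch

from ...utils import check_logprobs_close


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Needs at least 2 GPUs for tensor parallelism")
@pytest.mark.parametrize("model", ["ibm-granite/granite-4.0-tiny-preview"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_model_equivalence_to_hf_greedy_tp2(hf_runner, vllm_runner,
                                            example_prompts, model: str,
                                            max_tokens: int,
                                            num_logprobs: int):
    # vLLM side: shard the model across two tensor-parallel ranks.
    with vllm_runner(model, dtype="bfloat16",
                     tensor_parallel_size=2) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    # HF reference side: unsharded greedy decoding.
    with hf_runner(model, dtype="bfloat16") as hf_model:
        hf_outputs = hf_model.generate_greedy_logprobs_limit(
            example_prompts, max_tokens, num_logprobs)

    check_logprobs_close(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm_tp2",
    )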

@@ -28,8 +28,9 @@ SSM_MODELS = [
 HYBRID_MODELS = [
     "ai21labs/Jamba-tiny-dev",
-    # NOTE: ibm-granite/granite-4.0-tiny-preview are skipped currently as
-    # it is not yet available in huggingface transformers
+    # NOTE: Currently the test fails due to HF transformers issue fixed in:
+    # https://github.com/huggingface/transformers/pull/39033
+    # We will enable vLLM test for Granite after next HF transformers release.
     # "ibm-granite/granite-4.0-tiny-preview",
     # NOTE: Running Plamo2 in transformers implementation requires to install
     # causal-conv1d package, which is not listed as a test dependency as it's
@@ -15,7 +15,8 @@ from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import ReplicatedLinear
+from vllm.model_executor.layers.linear import (QKVParallelLinear,
+                                               RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.mamba2_metadata import (
     Mamba2Metadata, prepare_mamba2_metadata)
@@ -36,8 +37,9 @@ from .granitemoe import GraniteMoeMoE
 from .granitemoeshared import GraniteMoeSharedMLP
 from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
                          SupportsQuant, SupportsV0Only)
-from .utils import (AutoWeightsLoader, make_empty_intermediate_tensors_factory,
-                    make_layers, maybe_prefix)
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
+                    make_empty_intermediate_tensors_factory, make_layers,
+                    maybe_prefix)


 class GraniteMoeHybridMambaDecoderLayer(nn.Module):
@@ -220,35 +222,37 @@ class GraniteMoeHybridAttention(nn.Module):
         self.hidden_size = config.hidden_size
         self.attention_bias = config.attention_bias
         self.attention_multiplier = config.attention_multiplier
-        self.num_heads = config.num_attention_heads
-        self.head_dim = self.hidden_size // self.num_heads
-        self.num_key_value_heads = config.num_key_value_heads
+        self.total_num_heads = config.num_attention_heads
+        self.head_dim = self.hidden_size // self.total_num_heads
+        self.total_num_kv_heads = config.num_key_value_heads

-        self.q_proj = ReplicatedLinear(self.hidden_size,
-                                       self.num_heads * self.head_dim,
-                                       bias=self.attention_bias,
-                                       quant_config=quant_config,
-                                       prefix=f"{prefix}.q_proj")
+        # TensorParallel logic
+        tp_size = get_tensor_model_parallel_world_size()
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_key_value_heads = max(1, self.total_num_kv_heads // tp_size)

-        self.k_proj = ReplicatedLinear(self.hidden_size,
-                                       self.num_key_value_heads *
-                                       self.head_dim,
-                                       bias=self.attention_bias,
-                                       quant_config=quant_config,
-                                       prefix=f"{prefix}.k_proj")
+        self.qkv_proj = QKVParallelLinear(self.hidden_size,
+                                          self.head_dim,
+                                          self.total_num_heads,
+                                          self.total_num_kv_heads,
+                                          bias=self.attention_bias,
+                                          quant_config=quant_config,
+                                          prefix=f"{prefix}.qkv_proj")

-        self.v_proj = ReplicatedLinear(self.hidden_size,
-                                       self.num_key_value_heads *
-                                       self.head_dim,
-                                       bias=self.attention_bias,
-                                       quant_config=quant_config,
-                                       prefix=f"{prefix}.v_proj")
-
-        self.o_proj = ReplicatedLinear(self.hidden_size,
-                                       self.hidden_size,
-                                       bias=self.attention_bias,
-                                       quant_config=quant_config,
-                                       prefix=f"{prefix}.o_proj")
+        self.o_proj = RowParallelLinear(self.hidden_size,
+                                        self.hidden_size,
+                                        bias=self.attention_bias,
+                                        quant_config=quant_config,
+                                        prefix=f"{prefix}.o_proj")

         if config.position_embedding_type == "rope":
             self.rotary_emb = get_rope(
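For readers tracing the new TensorParallel logic, the per-rank head arithmetic is shown below in isolation. This is a standalone sketch only; the head counts in the example calls are made up for illustration and are not read from the Granite config.

def per_rank_heads(total_num_heads: int, total_num_kv_heads: int,
                   tp_size: int) -> tuple[int, int]:
    """Return (query heads, KV heads) handled by each tensor-parallel rank."""
    assert total_num_heads % tp_size == 0
    num_heads = total_num_heads // tp_size
    if total_num_kv_heads >= tp_size:
        # Enough KV heads to partition them across ranks.
        assert total_num_kv_heads % tp_size == 0
    else:
        # Fewer KV heads than ranks: each rank keeps a full replica.
        assert tp_size % total_num_kv_heads == 0
    num_kv_heads = max(1, total_num_kv_heads // tp_size)
    return num_heads, num_kv_heads


# e.g. 32 query heads and 8 KV heads on 4 GPUs -> 8 query + 2 KV heads per rank
print(per_rank_heads(32, 8, tp_size=4))   # (8, 2)
# 32 query heads and only 2 KV heads on 4 GPUs -> KV heads get replicated
print(per_rank_heads(32, 2, tp_size=4))   # (8, 1)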
@@ -278,9 +282,12 @@ class GraniteMoeHybridAttention(nn.Module):
         hidden_states: torch.Tensor,
     ) -> torch.Tensor:
-        query = self.q_proj(hidden_states)[0]
-        key = self.k_proj(hidden_states)[0]
-        value = self.v_proj(hidden_states)[0]
+        qkv, _ = self.qkv_proj(hidden_states)
+        query, key, value = qkv.split([
+            self.num_heads * self.head_dim, self.num_key_value_heads *
+            self.head_dim, self.num_key_value_heads * self.head_dim
+        ],
+                                      dim=-1)

         if self.rotary_emb is not None:
             query, key = self.rotary_emb(positions, query, key)
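The split sizes in the new forward pass must use the per-rank head counts, because each rank only sees its own slice of the fused projection output. A small shape-check sketch (all numbers are illustrative, not taken from the model config):

import torch

# Per-rank values on one hypothetical TP rank.
head_dim, num_heads, num_key_value_heads = 64, 8, 2
num_tokens = 4

# QKVParallelLinear emits this rank's query, key and value slices back to back.
qkv = torch.randn(num_tokens, (num_heads + 2 * num_key_value_heads) * head_dim)
query, key, value = qkv.split([
    num_heads * head_dim,
    num_key_value_heads * head_dim,
    num_key_value_heads * head_dim,
], dim=-1)

assert query.shape == (num_tokens, num_heads * head_dim)              # (4, 512)
assert key.shape == (num_tokens, num_key_value_heads * head_dim)      # (4, 128)
assert value.shape == key.shape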
@@ -401,6 +408,12 @@ class GraniteMoeHybridModel(nn.Module):
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+        ]
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
@@ -411,6 +424,15 @@ class GraniteMoeHybridModel(nn.Module):
             weight_loader(param, p)
             loaded_params.add(n)

+        def _load_shard(n, p, shard_id):
+            # Skip layers on other devices.
+            if not is_pp_missing_parameter(n, self):
+                param = params_dict[n]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, p, shard_id)
+                loaded_params.add(n)
+
         def _load_expert(n, p, name, shard_id, expert_id):
             param = params_dict[n]
             weight_loader = getattr(param, "weight_loader",
@@ -465,7 +487,15 @@ class GraniteMoeHybridModel(nn.Module):
                                       ".block_sparse_moe.gate.weight")
                 _load(gate_name, p)
             else:
-                _load(n, p)
+                loaded = False
+                for param_name, weight_name, shard_id in stacked_params_mapping:
+                    if weight_name in n:
+                        _load_shard(n.replace(weight_name, param_name),
+                                    p,
+                                    shard_id=shard_id)
+                        loaded = True
+                if not loaded:
+                    _load(n, p)

         return loaded_params
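Checkpoints still ship separate q_proj/k_proj/v_proj tensors, so load_weights now rewrites each of those names onto the fused qkv_proj parameter and passes the matching shard id to that parameter's weight_loader. A stripped-down sketch of the remapping (the remap helper and the example weight name are illustrative, not vLLM API):

# Checkpoint tensors are still stored as separate q/k/v projections, so the
# loader rewrites each name onto the fused parameter and tags the shard.
stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    (".qkv_proj", ".q_proj", "q"),
    (".qkv_proj", ".k_proj", "k"),
    (".qkv_proj", ".v_proj", "v"),
]


def remap(checkpoint_name: str):
    """Map a checkpoint weight name to (model param name, shard id)."""
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in checkpoint_name:
            return checkpoint_name.replace(weight_name, param_name), shard_id
    return checkpoint_name, None  # not part of a stacked parameter


print(remap("model.layers.3.self_attn.k_proj.weight"))
# -> ('model.layers.3.self_attn.qkv_proj.weight', 'k')
print(remap("model.layers.3.self_attn.o_proj.weight"))
# -> ('model.layers.3.self_attn.o_proj.weight', None)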
@@ -473,7 +503,13 @@
 class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
                                   SupportsPP, IsHybrid, SupportsV0Only,
                                   SupportsQuant):
-    packed_modules_mapping = {}
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+    }
     embedding_modules = {
         "embed_tokens": "input_embeddings",
         "lm_head": "output_embeddings",