[Bugfix] Correct behavior of GraniteMoeHybrid for TensorParallel execution (#20137)

Signed-off-by: Stanislaw Wozniak <stw@zurich.ibm.com>
Stan Wozniak 2025-06-28 17:16:41 +02:00 committed by GitHub
parent daceac57c7
commit daec9dea6e
3 changed files with 73 additions and 78 deletions
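
The fix targets runs where the model is sharded across multiple GPUs with tensor parallelism. As a minimal sketch of that execution mode (the checkpoint name is the one used in the tests below; whether it actually loads also depends on the installed transformers version, per the notes in the test changes):

from vllm import LLM, SamplingParams

# Shard the model over two GPUs with tensor parallelism: the setting whose
# attention projections and weight loading this commit corrects.
llm = LLM(model="ibm-granite/granite-4.0-tiny-preview", tensor_parallel_size=2)
out = llm.generate(["The capital of Switzerland is"],
                   SamplingParams(temperature=0.0, max_tokens=32))
print(out[0].outputs[0].text)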


@@ -1,42 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import pytest
-
-from ...utils import check_logprobs_close
-
-# Path of the checkpoints
-MODELS = [
-    "ibm-granite/granite-4.0-tiny-preview",
-]
-
-
-@pytest.mark.skip(
-    reason="Granite 4.0 is not yet available in huggingface transformers")
-@pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
-@pytest.mark.parametrize("max_tokens", [64])
-@pytest.mark.parametrize("num_logprobs", [5])
-def test_model_equivalence_to_hf_greedy(
-    hf_runner,
-    vllm_runner,
-    example_prompts,
-    model: str,
-    dtype: str,
-    max_tokens: int,
-    num_logprobs: int,
-):
-    with vllm_runner(model, dtype=dtype) as vllm_model:
-        vllm_outputs = vllm_model.generate_greedy_logprobs(
-            example_prompts, max_tokens, num_logprobs)
-
-    with hf_runner(model, dtype=dtype) as hf_model:
-        hf_outputs = hf_model.generate_greedy_logprobs_limit(
-            example_prompts, max_tokens, num_logprobs)
-
-    check_logprobs_close(
-        outputs_0_lst=hf_outputs,
-        outputs_1_lst=vllm_outputs,
-        name_0="hf",
-        name_1="vllm",
-    )


@@ -28,8 +28,9 @@ SSM_MODELS = [
 HYBRID_MODELS = [
     "ai21labs/Jamba-tiny-dev",
-    # NOTE: ibm-granite/granite-4.0-tiny-preview are skipped currently as
-    # it is not yet available in huggingface transformers
+    # NOTE: Currently the test fails due to an HF transformers issue fixed in:
+    # https://github.com/huggingface/transformers/pull/39033
+    # We will enable the vLLM test for Granite after the next HF transformers release.
     # "ibm-granite/granite-4.0-tiny-preview",
     # NOTE: Running Plamo2 in transformers implementation requires to install
     # causal-conv1d package, which is not listed as a test dependency as it's


@@ -15,7 +15,8 @@ from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import ReplicatedLinear
+from vllm.model_executor.layers.linear import (QKVParallelLinear,
+                                               RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.mamba2_metadata import (
     Mamba2Metadata, prepare_mamba2_metadata)
@@ -36,8 +37,9 @@ from .granitemoe import GraniteMoeMoE
 from .granitemoeshared import GraniteMoeSharedMLP
 from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
                          SupportsQuant, SupportsV0Only)
-from .utils import (AutoWeightsLoader, make_empty_intermediate_tensors_factory,
-                    make_layers, maybe_prefix)
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
+                    make_empty_intermediate_tensors_factory, make_layers,
+                    maybe_prefix)


 class GraniteMoeHybridMambaDecoderLayer(nn.Module):
@@ -220,35 +222,37 @@ class GraniteMoeHybridAttention(nn.Module):
         self.hidden_size = config.hidden_size
         self.attention_bias = config.attention_bias
         self.attention_multiplier = config.attention_multiplier
-        self.num_heads = config.num_attention_heads
-        self.head_dim = self.hidden_size // self.num_heads
-        self.num_key_value_heads = config.num_key_value_heads
-
-        self.q_proj = ReplicatedLinear(self.hidden_size,
-                                       self.num_heads * self.head_dim,
-                                       bias=self.attention_bias,
-                                       quant_config=quant_config,
-                                       prefix=f"{prefix}.q_proj")
-
-        self.k_proj = ReplicatedLinear(self.hidden_size,
-                                       self.num_key_value_heads *
-                                       self.head_dim,
-                                       bias=self.attention_bias,
-                                       quant_config=quant_config,
-                                       prefix=f"{prefix}.k_proj")
-
-        self.v_proj = ReplicatedLinear(self.hidden_size,
-                                       self.num_key_value_heads *
-                                       self.head_dim,
-                                       bias=self.attention_bias,
-                                       quant_config=quant_config,
-                                       prefix=f"{prefix}.v_proj")
-
-        self.o_proj = ReplicatedLinear(self.hidden_size,
-                                       self.hidden_size,
-                                       bias=self.attention_bias,
-                                       quant_config=quant_config,
-                                       prefix=f"{prefix}.o_proj")
+        self.total_num_heads = config.num_attention_heads
+        self.head_dim = self.hidden_size // self.total_num_heads
+        self.total_num_kv_heads = config.num_key_value_heads
+
+        # TensorParallel logic
+        tp_size = get_tensor_model_parallel_world_size()
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_key_value_heads = max(1, self.total_num_kv_heads // tp_size)
+
+        self.qkv_proj = QKVParallelLinear(self.hidden_size,
+                                          self.head_dim,
+                                          self.total_num_heads,
+                                          self.total_num_kv_heads,
+                                          bias=self.attention_bias,
+                                          quant_config=quant_config,
+                                          prefix=f"{prefix}.qkv_proj")
+
+        self.o_proj = RowParallelLinear(self.hidden_size,
+                                        self.hidden_size,
+                                        bias=self.attention_bias,
+                                        quant_config=quant_config,
+                                        prefix=f"{prefix}.o_proj")

         if config.position_embedding_type == "rope":
             self.rotary_emb = get_rope(
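
To make the partition-versus-replicate branch above concrete, here is a small standalone sketch of the same arithmetic; the head counts are illustrative and not taken from the Granite config.

def heads_per_rank(total_num_heads: int, total_num_kv_heads: int, tp_size: int):
    # Query heads must divide evenly across tensor-parallel ranks.
    assert total_num_heads % tp_size == 0
    num_heads = total_num_heads // tp_size
    if total_num_kv_heads >= tp_size:
        # Enough KV heads: partition them across ranks.
        assert total_num_kv_heads % tp_size == 0
    else:
        # Fewer KV heads than ranks: each rank keeps a replicated KV head.
        assert tp_size % total_num_kv_heads == 0
    num_kv_heads = max(1, total_num_kv_heads // tp_size)
    return num_heads, num_kv_heads

print(heads_per_rank(32, 8, tp_size=4))   # (8, 2): KV heads are partitioned
print(heads_per_rank(32, 8, tp_size=16))  # (2, 1): KV heads are replicated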
@@ -278,9 +282,12 @@ class GraniteMoeHybridAttention(nn.Module):
         hidden_states: torch.Tensor,
     ) -> torch.Tensor:
-        query = self.q_proj(hidden_states)[0]
-        key = self.k_proj(hidden_states)[0]
-        value = self.v_proj(hidden_states)[0]
+        qkv, _ = self.qkv_proj(hidden_states)
+        query, key, value = qkv.split([
+            self.num_heads * self.head_dim, self.num_key_value_heads *
+            self.head_dim, self.num_key_value_heads * self.head_dim
+        ],
+                                      dim=-1)

         if self.rotary_emb is not None:
             query, key = self.rotary_emb(positions, query, key)
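
The fused projection returns one tensor holding the per-rank query, key and value slices, so the split widths follow directly from the per-rank head counts above. A toy reproduction of the split, with arbitrary illustrative shapes:

import torch

num_heads, num_kv_heads, head_dim = 8, 2, 64
# Fused per-rank QKV output: (tokens, (num_heads + 2 * num_kv_heads) * head_dim)
qkv = torch.randn(4, (num_heads + 2 * num_kv_heads) * head_dim)
query, key, value = qkv.split(
    [num_heads * head_dim, num_kv_heads * head_dim, num_kv_heads * head_dim],
    dim=-1)
print(query.shape, key.shape, value.shape)  # (4, 512) (4, 128) (4, 128)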
@@ -401,6 +408,12 @@ class GraniteMoeHybridModel(nn.Module):

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+        ]
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
@@ -411,6 +424,15 @@ class GraniteMoeHybridModel(nn.Module):
                 weight_loader(param, p)
                 loaded_params.add(n)

+        def _load_shard(n, p, shard_id):
+            # Skip layers on other devices.
+            if not is_pp_missing_parameter(n, self):
+                param = params_dict[n]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, p, shard_id)
+                loaded_params.add(n)
+
         def _load_expert(n, p, name, shard_id, expert_id):
             param = params_dict[n]
             weight_loader = getattr(param, "weight_loader",
@@ -465,7 +487,15 @@ class GraniteMoeHybridModel(nn.Module):
                                       ".block_sparse_moe.gate.weight")
                 _load(gate_name, p)
             else:
-                _load(n, p)
+                loaded = False
+                for param_name, weight_name, shard_id in stacked_params_mapping:
+                    if weight_name in n:
+                        _load_shard(n.replace(weight_name, param_name),
+                                    p,
+                                    shard_id=shard_id)
+                        loaded = True
+                if not loaded:
+                    _load(n, p)

         return loaded_params
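
The stacked mapping lets checkpoints that still store separate q/k/v tensors load into the fused qkv_proj parameter: the weight name is rewritten and the shard id tells the parameter's weight loader which slice of qkv_proj to fill. A minimal dry run of just the renaming step (the weight name is a hypothetical example, not taken from the Granite checkpoint):

stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    (".qkv_proj", ".q_proj", "q"),
    (".qkv_proj", ".k_proj", "k"),
    (".qkv_proj", ".v_proj", "v"),
]

name = "layers.0.self_attn.k_proj.weight"  # hypothetical checkpoint entry
for param_name, weight_name, shard_id in stacked_params_mapping:
    if weight_name in name:
        # Prints: layers.0.self_attn.qkv_proj.weight k
        print(name.replace(weight_name, param_name), shard_id)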
@@ -473,7 +503,13 @@ class GraniteMoeHybridModel(nn.Module):
 class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
                                   SupportsPP, IsHybrid, SupportsV0Only,
                                   SupportsQuant):
-    packed_modules_mapping = {}
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+    }
     embedding_modules = {
         "embed_tokens": "input_embeddings",
         "lm_head": "output_embeddings",