diff --git a/tests/models/language/generation/test_granitemoehybrid.py b/tests/models/language/generation/test_granitemoehybrid.py
deleted file mode 100644
index 952449f284159..0000000000000
--- a/tests/models/language/generation/test_granitemoehybrid.py
+++ /dev/null
@@ -1,42 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import pytest
-
-from ...utils import check_logprobs_close
-
-# Path of the checkpoints
-MODELS = [
-    "ibm-granite/granite-4.0-tiny-preview",
-]
-
-
-@pytest.mark.skip(
-    reason="Granite 4.0 is not yet available in huggingface transformers")
-@pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
-@pytest.mark.parametrize("max_tokens", [64])
-@pytest.mark.parametrize("num_logprobs", [5])
-def test_model_equivalence_to_hf_greedy(
-    hf_runner,
-    vllm_runner,
-    example_prompts,
-    model: str,
-    dtype: str,
-    max_tokens: int,
-    num_logprobs: int,
-):
-    with vllm_runner(model, dtype=dtype) as vllm_model:
-        vllm_outputs = vllm_model.generate_greedy_logprobs(
-            example_prompts, max_tokens, num_logprobs)
-
-    with hf_runner(model, dtype=dtype) as hf_model:
-        hf_outputs = hf_model.generate_greedy_logprobs_limit(
-            example_prompts, max_tokens, num_logprobs)
-
-    check_logprobs_close(
-        outputs_0_lst=hf_outputs,
-        outputs_1_lst=vllm_outputs,
-        name_0="hf",
-        name_1="vllm",
-    )
diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py
index b2348e6449339..e6dd6c35e64d6 100644
--- a/tests/models/language/generation/test_hybrid.py
+++ b/tests/models/language/generation/test_hybrid.py
@@ -28,8 +28,9 @@ SSM_MODELS = [
 
 HYBRID_MODELS = [
     "ai21labs/Jamba-tiny-dev",
-    # NOTE: ibm-granite/granite-4.0-tiny-preview are skipped currently as
-    # it is not yet available in huggingface transformers
+    # NOTE: Currently the test fails due to HF transformers issue fixed in:
+    # https://github.com/huggingface/transformers/pull/39033
+    # We will enable vLLM test for Granite after next HF transformers release.
     # "ibm-granite/granite-4.0-tiny-preview",
     # NOTE: Running Plamo2 in transformers implementation requires to install
     # causal-conv1d package, which is not listed as a test dependency as it's
diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py
index 26b5b3ac15345..33e8626209d50 100644
--- a/vllm/model_executor/models/granitemoehybrid.py
+++ b/vllm/model_executor/models/granitemoehybrid.py
@@ -15,7 +15,8 @@ from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import ReplicatedLinear
+from vllm.model_executor.layers.linear import (QKVParallelLinear,
+                                               RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.mamba2_metadata import (
     Mamba2Metadata, prepare_mamba2_metadata)
@@ -36,8 +37,9 @@ from .granitemoe import GraniteMoeMoE
 from .granitemoeshared import GraniteMoeSharedMLP
 from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
                          SupportsQuant, SupportsV0Only)
-from .utils import (AutoWeightsLoader, make_empty_intermediate_tensors_factory,
-                    make_layers, maybe_prefix)
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
+                    make_empty_intermediate_tensors_factory, make_layers,
+                    maybe_prefix)
 
 
 class GraniteMoeHybridMambaDecoderLayer(nn.Module):
@@ -220,35 +222,37 @@ class GraniteMoeHybridAttention(nn.Module):
         self.hidden_size = config.hidden_size
         self.attention_bias = config.attention_bias
         self.attention_multiplier = config.attention_multiplier
-        self.num_heads = config.num_attention_heads
-        self.head_dim = self.hidden_size // self.num_heads
-        self.num_key_value_heads = config.num_key_value_heads
+        self.total_num_heads = config.num_attention_heads
+        self.head_dim = self.hidden_size // self.total_num_heads
+        self.total_num_kv_heads = config.num_key_value_heads
 
-        self.q_proj = ReplicatedLinear(self.hidden_size,
-                                       self.num_heads * self.head_dim,
-                                       bias=self.attention_bias,
-                                       quant_config=quant_config,
-                                       prefix=f"{prefix}.q_proj")
+        # TensorParallel logic
+        tp_size = get_tensor_model_parallel_world_size()
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_key_value_heads = max(1, self.total_num_kv_heads // tp_size)
 
-        self.k_proj = ReplicatedLinear(self.hidden_size,
-                                       self.num_key_value_heads *
-                                       self.head_dim,
-                                       bias=self.attention_bias,
-                                       quant_config=quant_config,
-                                       prefix=f"{prefix}.k_proj")
+        self.qkv_proj = QKVParallelLinear(self.hidden_size,
+                                          self.head_dim,
+                                          self.total_num_heads,
+                                          self.total_num_kv_heads,
+                                          bias=self.attention_bias,
+                                          quant_config=quant_config,
+                                          prefix=f"{prefix}.qkv_proj")
 
-        self.v_proj = ReplicatedLinear(self.hidden_size,
-                                       self.num_key_value_heads *
-                                       self.head_dim,
-                                       bias=self.attention_bias,
-                                       quant_config=quant_config,
-                                       prefix=f"{prefix}.v_proj")
-
-        self.o_proj = ReplicatedLinear(self.hidden_size,
-                                       self.hidden_size,
-                                       bias=self.attention_bias,
-                                       quant_config=quant_config,
-                                       prefix=f"{prefix}.o_proj")
+        self.o_proj = RowParallelLinear(self.hidden_size,
+                                        self.hidden_size,
+                                        bias=self.attention_bias,
+                                        quant_config=quant_config,
+                                        prefix=f"{prefix}.o_proj")
 
         if config.position_embedding_type == "rope":
             self.rotary_emb = get_rope(
@@ -278,9 +282,12 @@ class GraniteMoeHybridAttention(nn.Module):
         hidden_states: torch.Tensor,
     ) -> torch.Tensor:
-        query = self.q_proj(hidden_states)[0]
-        key = self.k_proj(hidden_states)[0]
-        value = self.v_proj(hidden_states)[0]
+        qkv, _ = self.qkv_proj(hidden_states)
+        query, key, value = qkv.split([
+            self.num_heads * self.head_dim, self.num_key_value_heads *
+            self.head_dim, self.num_key_value_heads * self.head_dim
+        ],
+                                      dim=-1)
 
         if self.rotary_emb is not None:
             query, key = self.rotary_emb(positions, query, key)
 
@@ -401,6 +408,12 @@ class GraniteMoeHybridModel(nn.Module):
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+        ]
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
 
@@ -411,6 +424,15 @@ class GraniteMoeHybridModel(nn.Module):
             weight_loader(param, p)
             loaded_params.add(n)
 
+        def _load_shard(n, p, shard_id):
+            # Skip layers on other devices.
+            if not is_pp_missing_parameter(n, self):
+                param = params_dict[n]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, p, shard_id)
+                loaded_params.add(n)
+
         def _load_expert(n, p, name, shard_id, expert_id):
             param = params_dict[n]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
@@ -465,7 +487,15 @@ class GraniteMoeHybridModel(nn.Module):
                                       ".block_sparse_moe.gate.weight")
                 _load(gate_name, p)
             else:
-                _load(n, p)
+                loaded = False
+                for param_name, weight_name, shard_id in stacked_params_mapping:
+                    if weight_name in n:
+                        _load_shard(n.replace(weight_name, param_name),
+                                    p,
+                                    shard_id=shard_id)
+                        loaded = True
+                if not loaded:
+                    _load(n, p)
 
         return loaded_params
 
@@ -473,7 +503,13 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
                                   SupportsPP, IsHybrid, SupportsV0Only,
                                   SupportsQuant):
 
-    packed_modules_mapping = {}
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+    }
     embedding_modules = {
         "embed_tokens": "input_embeddings",
         "lm_head": "output_embeddings",
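The tensor-parallel sizing introduced above can be sanity-checked outside vLLM. The sketch below mirrors the head-count arithmetic from `GraniteMoeHybridAttention.__init__` and the fused-QKV split from its `forward`; the head counts, `head_dim`, and `tp_size` values are illustrative assumptions, not the actual ibm-granite/granite-4.0-tiny-preview configuration.

```python
# Standalone sketch (not vLLM code): per-rank head counts and fused-QKV split
# sizes. The example numbers are assumptions, not the Granite 4.0 tiny config.
import torch


def per_rank_qkv_sizes(total_num_heads: int, total_num_kv_heads: int,
                       head_dim: int, tp_size: int):
    """Mirror the TP sizing logic added to GraniteMoeHybridAttention."""
    assert total_num_heads % tp_size == 0
    num_heads = total_num_heads // tp_size
    if total_num_kv_heads >= tp_size:
        # Enough KV heads to partition them across TP ranks.
        assert total_num_kv_heads % tp_size == 0
    else:
        # Fewer KV heads than ranks: each rank keeps a replicated KV head.
        assert tp_size % total_num_kv_heads == 0
    num_kv_heads = max(1, total_num_kv_heads // tp_size)
    # Split sizes applied to the fused projection output in forward(): Q, K, V.
    return num_heads, num_kv_heads, [
        num_heads * head_dim, num_kv_heads * head_dim, num_kv_heads * head_dim
    ]


if __name__ == "__main__":
    num_heads, num_kv_heads, sizes = per_rank_qkv_sizes(total_num_heads=32,
                                                        total_num_kv_heads=8,
                                                        head_dim=64,
                                                        tp_size=4)
    print(num_heads, num_kv_heads, sizes)  # -> 8 2 [512, 128, 128]
    qkv = torch.randn(2, sum(sizes))  # stand-in for the qkv_proj output
    query, key, value = qkv.split(sizes, dim=-1)
    print(query.shape, key.shape, value.shape)
```

With the assumed 32 query heads, 8 KV heads, and `tp_size=4`, each rank keeps 8 query heads and 2 KV heads, so the per-rank fused projection output splits into chunks of 512, 128, and 128 columns.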
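Similarly, the weight-routing change in `load_weights` can be illustrated in isolation: checkpoint tensors named `*.q_proj`, `*.k_proj`, and `*.v_proj` are redirected to the fused `qkv_proj` parameter together with a shard id, while everything else keeps its original name. The sketch below reproduces only that mapping step; the example weight names are hypothetical.

```python
# Standalone sketch (not vLLM code) of the stacked_params_mapping routing used
# in load_weights(). The example checkpoint names are hypothetical.
STACKED_PARAMS_MAPPING = [
    # (param_name, shard_name, shard_id)
    (".qkv_proj", ".q_proj", "q"),
    (".qkv_proj", ".k_proj", "k"),
    (".qkv_proj", ".v_proj", "v"),
]


def route(name: str):
    """Return (destination parameter name, shard id or None)."""
    for param_name, shard_name, shard_id in STACKED_PARAMS_MAPPING:
        if shard_name in name:
            # q/k/v checkpoint tensors land in the fused qkv_proj parameter;
            # the shard id tells its weight_loader which slice to fill.
            return name.replace(shard_name, param_name), shard_id
    # Anything else (o_proj, Mamba, MoE, norms, ...) keeps its own name.
    return name, None


if __name__ == "__main__":
    for name in (
            "model.layers.1.self_attn.q_proj.weight",
            "model.layers.1.self_attn.v_proj.weight",
            "model.layers.1.self_attn.o_proj.weight",
    ):
        print(name, "->", route(name))
```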