mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-24 02:24:29 +08:00
[Bugfix] Correct behavior of GraniteMoeHybrid for TensorParallel execution (#20137)
Signed-off-by: Stanislaw Wozniak <stw@zurich.ibm.com>
This commit is contained in:
parent
daceac57c7
commit
daec9dea6e
@ -1,42 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
from ...utils import check_logprobs_close
|
||||
|
||||
# Path of the checkpoints
|
||||
MODELS = [
|
||||
"ibm-granite/granite-4.0-tiny-preview",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Granite 4.0 is not yet available in huggingface transformers")
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
|
||||
@pytest.mark.parametrize("max_tokens", [64])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
def test_model_equivalence_to_hf_greedy(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
):
|
||||
with vllm_runner(model, dtype=dtype) as vllm_model:
|
||||
vllm_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
|
||||
with hf_runner(model, dtype=dtype) as hf_model:
|
||||
hf_outputs = hf_model.generate_greedy_logprobs_limit(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=vllm_outputs,
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
)
|
||||
@ -28,8 +28,9 @@ SSM_MODELS = [
|
||||
|
||||
HYBRID_MODELS = [
|
||||
"ai21labs/Jamba-tiny-dev",
|
||||
# NOTE: ibm-granite/granite-4.0-tiny-preview are skipped currently as
|
||||
# it is not yet available in huggingface transformers
|
||||
# NOTE: Currently the test failes due to HF transformers issue fixed in:
|
||||
# https://github.com/huggingface/transformers/pull/39033
|
||||
# We will enable vLLM test for Granite after next HF transformers release.
|
||||
# "ibm-granite/granite-4.0-tiny-preview",
|
||||
# NOTE: Running Plamo2 in transformers implementation requires to install
|
||||
# causal-conv1d package, which is not listed as a test dependency as it's
|
||||
|
||||
@ -15,7 +15,8 @@ from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||
from vllm.distributed.parallel_state import get_pp_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.mamba.mamba2_metadata import (
|
||||
Mamba2Metadata, prepare_mamba2_metadata)
|
||||
@ -36,8 +37,9 @@ from .granitemoe import GraniteMoeMoE
|
||||
from .granitemoeshared import GraniteMoeSharedMLP
|
||||
from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
|
||||
SupportsQuant, SupportsV0Only)
|
||||
from .utils import (AutoWeightsLoader, make_empty_intermediate_tensors_factory,
|
||||
make_layers, maybe_prefix)
|
||||
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
|
||||
class GraniteMoeHybridMambaDecoderLayer(nn.Module):
|
||||
@ -220,35 +222,37 @@ class GraniteMoeHybridAttention(nn.Module):
|
||||
self.hidden_size = config.hidden_size
|
||||
self.attention_bias = config.attention_bias
|
||||
self.attention_multiplier = config.attention_multiplier
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.hidden_size // self.num_heads
|
||||
self.num_key_value_heads = config.num_key_value_heads
|
||||
self.total_num_heads = config.num_attention_heads
|
||||
self.head_dim = self.hidden_size // self.total_num_heads
|
||||
self.total_num_kv_heads = config.num_key_value_heads
|
||||
|
||||
self.q_proj = ReplicatedLinear(self.hidden_size,
|
||||
self.num_heads * self.head_dim,
|
||||
bias=self.attention_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.q_proj")
|
||||
# TensorParallel logic
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
if self.total_num_kv_heads >= tp_size:
|
||||
# Number of KV heads is greater than TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert self.total_num_kv_heads % tp_size == 0
|
||||
else:
|
||||
# Number of KV heads is less than TP size, so we replicate
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_size % self.total_num_kv_heads == 0
|
||||
self.num_key_value_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
|
||||
self.k_proj = ReplicatedLinear(self.hidden_size,
|
||||
self.num_key_value_heads *
|
||||
self.head_dim,
|
||||
bias=self.attention_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.k_proj")
|
||||
self.qkv_proj = QKVParallelLinear(self.hidden_size,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=self.attention_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj")
|
||||
|
||||
self.v_proj = ReplicatedLinear(self.hidden_size,
|
||||
self.num_key_value_heads *
|
||||
self.head_dim,
|
||||
bias=self.attention_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.v_proj")
|
||||
|
||||
self.o_proj = ReplicatedLinear(self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=self.attention_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj")
|
||||
self.o_proj = RowParallelLinear(self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=self.attention_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj")
|
||||
|
||||
if config.position_embedding_type == "rope":
|
||||
self.rotary_emb = get_rope(
|
||||
@ -278,9 +282,12 @@ class GraniteMoeHybridAttention(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
|
||||
query = self.q_proj(hidden_states)[0]
|
||||
key = self.k_proj(hidden_states)[0]
|
||||
value = self.v_proj(hidden_states)[0]
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
query, key, value = qkv.split([
|
||||
self.num_heads * self.head_dim, self.num_key_value_heads *
|
||||
self.head_dim, self.num_key_value_heads * self.head_dim
|
||||
],
|
||||
dim=-1)
|
||||
|
||||
if self.rotary_emb is not None:
|
||||
query, key = self.rotary_emb(positions, query, key)
|
||||
@ -401,6 +408,12 @@ class GraniteMoeHybridModel(nn.Module):
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
(".qkv_proj", ".q_proj", "q"),
|
||||
(".qkv_proj", ".k_proj", "k"),
|
||||
(".qkv_proj", ".v_proj", "v"),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
|
||||
@ -411,6 +424,15 @@ class GraniteMoeHybridModel(nn.Module):
|
||||
weight_loader(param, p)
|
||||
loaded_params.add(n)
|
||||
|
||||
def _load_shard(n, p, shard_id):
|
||||
# Skip layers on other devices.
|
||||
if not is_pp_missing_parameter(n, self):
|
||||
param = params_dict[n]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, p, shard_id)
|
||||
loaded_params.add(n)
|
||||
|
||||
def _load_expert(n, p, name, shard_id, expert_id):
|
||||
param = params_dict[n]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
@ -465,7 +487,15 @@ class GraniteMoeHybridModel(nn.Module):
|
||||
".block_sparse_moe.gate.weight")
|
||||
_load(gate_name, p)
|
||||
else:
|
||||
_load(n, p)
|
||||
loaded = False
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name in n:
|
||||
_load_shard(n.replace(weight_name, param_name),
|
||||
p,
|
||||
shard_id=shard_id)
|
||||
loaded = True
|
||||
if not loaded:
|
||||
_load(n, p)
|
||||
|
||||
return loaded_params
|
||||
|
||||
@ -473,7 +503,13 @@ class GraniteMoeHybridModel(nn.Module):
|
||||
class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
|
||||
SupportsPP, IsHybrid, SupportsV0Only,
|
||||
SupportsQuant):
|
||||
packed_modules_mapping = {}
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
}
|
||||
embedding_modules = {
|
||||
"embed_tokens": "input_embeddings",
|
||||
"lm_head": "output_embeddings",
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user