[MODEL] Add support for Zamba2 models (#13185)
Signed-off-by: Yury Tokpanov <yury@zyphra.com>
Signed-off-by: Quentin Anthony <qganthony@yahoo.com>
Co-authored-by: Quentin Anthony <qganthony@yahoo.com>
Co-authored-by: Tyler Michael Smith <tysmith@redhat.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>

parent 8b793f7ec6
commit 452e8fd968
@@ -477,6 +477,11 @@ See [this page](#generative-models) for more information on how to use generative models.
   * `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc.
   * ✅︎
   * ✅︎
+- * `Zamba2ForCausalLM`
+  * Zamba2
+  * `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc.
+  *
+  *
 :::

 :::{note}
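With the architecture registered, the newly listed Zamba2 checkpoints load through the usual offline-inference entry point. A minimal sketch (the prompt and sampling settings below are illustrative only, not part of this commit):

from vllm import LLM, SamplingParams

# Any of the Zamba2 checkpoints listed in the table above should work here.
llm = LLM(model="Zyphra/Zamba2-1.2B-instruct")

params = SamplingParams(temperature=0.0, max_tokens=64)
outputs = llm.generate(["What is a hybrid attention/Mamba model?"], params)
print(outputs[0].outputs[0].text)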
@@ -9,7 +9,7 @@ from vllm.sampling_params import SamplingParams
 from ...utils import check_outputs_equal

 # This test is for the hybrid models
-MODELS = ["ai21labs/Jamba-tiny-dev"]
+MODELS = ["ai21labs/Jamba-tiny-dev", "Zyphra/Zamba2-1.2B-instruct"]
 # Bamba at Fp32 is too big for the CI (L4 GPU).
 # MODELS = ["ai21labs/Jamba-tiny-dev", "ibm-ai-platform/Bamba-9B"]

@@ -27,17 +27,19 @@ def test_models(
 ) -> None:

     # numeric error produces different generation
-    if 'Bamba' in model:
+    if "Bamba" in model:
         example_prompts.pop(3)

-    with hf_runner(
-            model,
-            dtype=dtype,
-            model_kwargs={
-                "use_mamba_kernels":
-                False,  # mamba kernels are not installed so HF
-                # don't use them
-            }) as hf_model:
+    model_kwargs = {
+        "use_mamba_kernels": False,  # mamba kernels are not installed so HF
+        # don't use them
+    }
+    if "Zamba2" in model:
+        # Zamba2 HF implementation automatically checks if mamba kernels are
+        # installed
+        model_kwargs = {}
+
+    with hf_runner(model, dtype=dtype, model_kwargs=model_kwargs) as hf_model:
         hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

     with vllm_runner(model, dtype=dtype) as vllm_model:
@@ -112,26 +114,31 @@ def test_mamba_prefill_chunking_with_parallel_sampling(
 def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts,
                                 model: str, dtype: str,
                                 max_tokens: int) -> None:
-    # numeric error during prefill chucking produces different generation
+    # numeric error during prefill chunking produces different generation
     # compared to w/o prefill chunking for those examples, removed them for now
-    if 'Jamba' in model:
+    if "Jamba" in model:
         example_prompts.pop(7)
         example_prompts.pop(2)
         example_prompts.pop(1)
-    elif 'Bamba' in model:
+    elif "Bamba" in model:
         example_prompts.pop(6)
         example_prompts.pop(3)
         example_prompts.pop(2)
         dtype = "half"  # use a different dtype for Bamba
+    elif "Zamba2" in model:
+        example_prompts.pop(7)
+        dtype = "half"

-    with hf_runner(
-            model,
-            dtype=dtype,
-            model_kwargs={
-                "use_mamba_kernels":
-                False,  # mamba kernels are not installed so HF
-                # don't use them
-            }) as hf_model:
+    model_kwargs = {
+        "use_mamba_kernels": False,  # mamba kernels are not installed so HF
+        # don't use them
+    }
+    if "Zamba2" in model:
+        # Zamba2 HF implementation automatically checks if mamba kernels are
+        # installed
+        model_kwargs = {}
+
+    with hf_runner(model, dtype=dtype, model_kwargs=model_kwargs) as hf_model:
         non_chunked = hf_model.generate_greedy(example_prompts, max_tokens)

     with vllm_runner(model,
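The same HF-runner kwargs block now appears in both tests above. As a possible follow-up (not part of this commit), the selection could live in one small helper; the name `hf_model_kwargs_for` is hypothetical:

def hf_model_kwargs_for(model: str) -> dict:
    # Zamba2's HF implementation checks for mamba kernels itself, so it
    # needs no overrides; other hybrid models must be told not to use them.
    if "Zamba2" in model:
        return {}
    return {
        # mamba kernels are not installed in CI, so HF should not use them
        "use_mamba_kernels": False,
    }

Each test would then call hf_runner(model, dtype=dtype, model_kwargs=hf_model_kwargs_for(model)), keeping the two call sites in sync.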
@@ -195,6 +195,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
     "XverseForCausalLM": _HfExamplesInfo("xverse/XVERSE-7B-Chat",
                                          is_available_online=False,
                                          trust_remote_code=True),
+    "Zamba2ForCausalLM": _HfExamplesInfo("Zyphra/Zamba2-7B-instruct",
+                                         min_transformers_version="4.49"),
     # [Encoder-decoder]
     "BartModel": _HfExamplesInfo("facebook/bart-base"),
     "BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"),
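The min_transformers_version="4.49" field means the Zamba2 example is skipped on older transformers releases. A rough standalone equivalent of that guard, for illustration only (in the suite the skipping is driven by the registry field, not a hand-written check):

import pytest
import transformers
from packaging.version import Version

# Skip Zamba2 tests when the installed transformers predates the version
# given by min_transformers_version above.
if Version(transformers.__version__) < Version("4.49"):
    pytest.skip("Zamba2 requires transformers >= 4.49",
                allow_module_level=True)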
@@ -821,6 +821,11 @@ class ModelConfig:
         if qk_rope_head_dim and qk_nope_head_dim:
             return qk_rope_head_dim + qk_nope_head_dim

+        if hasattr(self.hf_text_config,
+                   "model_type") and (self.hf_text_config.model_type
+                                      == "zamba2"):
+            return self.hf_text_config.attention_head_dim
+
         if self.is_attention_free:
             return 0

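Zamba2 configs expose the attention head size as attention_head_dim rather than the usual head_dim, hence the dedicated branch in get_head_size. A minimal sketch of the resolution order this hunk introduces (the cfg object here is a stand-in for hf_text_config, not vLLM's actual class):

from types import SimpleNamespace

def head_size_for(cfg) -> int:
    # MLA-style configs split the head into rope and nope parts.
    rope = getattr(cfg, "qk_rope_head_dim", 0)
    nope = getattr(cfg, "qk_nope_head_dim", 0)
    if rope and nope:
        return rope + nope
    # Zamba2 names its attention head size differently.
    if getattr(cfg, "model_type", None) == "zamba2":
        return cfg.attention_head_dim
    # Common fallback: hidden size divided evenly across attention heads.
    return cfg.hidden_size // cfg.num_attention_heads

# A Zamba2-like config resolves through the dedicated branch:
print(head_size_for(SimpleNamespace(model_type="zamba2", attention_head_dim=128)))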
@@ -944,6 +949,15 @@ class ModelConfig:
                                  "cannot determine the num of "
                                  f"{block_type.value} layers")

+        if hasattr(self.hf_text_config,
+                   "model_type") and (self.hf_text_config.model_type
+                                      == "zamba2"):
+            if attn_block_type:
+                return sum(t == "hybrid"
+                           for t in layers_block_type_value[start:end])
+            else:
+                return self.get_num_layers(parallel_config)
+
         return sum(t == block_type.value
                    for t in layers_block_type_value[start:end])

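Zamba2's layers_block_type labels its attention-carrying layers "hybrid" rather than "attention", and in this accounting every layer is treated as containing a Mamba block, which is why the generic t == block_type.value count is bypassed above. A small worked example of that counting rule (the layer pattern below is made up for illustration, not taken from a real Zamba2 config):

# Hypothetical per-layer block types for a tiny Zamba2-like model.
layers_block_type = ["mamba", "hybrid", "mamba", "mamba", "hybrid", "mamba"]

# Attention layers: only the shared-attention ("hybrid") layers count.
num_attention_layers = sum(t == "hybrid" for t in layers_block_type)

# Mamba layers: every layer counts, mirroring the get_num_layers() fallback.
num_mamba_layers = len(layers_block_type)

print(num_attention_layers, num_mamba_layers)  # 2 6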
@@ -245,7 +245,6 @@ class MambaMixer2(CustomOp):
         assert num_heads % self.tp_size == 0, \
             "Tensor parallel world size must divide num heads."

-
         assert (n_groups % self.tp_size) == 0 or n_groups == 1, \
             (
                 "If tensor parallel world size does not divide num_heads, "
@@ -38,8 +38,6 @@ from .utils import (is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)

-KVCache = Tuple[torch.Tensor, torch.Tensor]
-

 class BambaMLP(nn.Module):

@@ -36,8 +36,6 @@ from .utils import (is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)

-KVCache = Tuple[torch.Tensor, torch.Tensor]
-

 class JambaMoE(nn.Module):

@@ -105,6 +105,7 @@ _TEXT_GENERATION_MODELS = {
     "SolarForCausalLM": ("solar", "SolarForCausalLM"),
     "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
     "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
+    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
     # [Encoder-decoder]
     "BartModel": ("bart", "BartForConditionalGeneration"),
     "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
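This entry maps the Zamba2ForCausalLM architecture string from a checkpoint's config.json to the new zamba2 module. A quick registration check, sketched under the assumption that ModelRegistry.get_supported_archs() reports this entry the same way it does the others:

from vllm import ModelRegistry

# The architecture advertised by Zamba2 checkpoints should now be among
# the architectures vLLM knows how to load.
assert "Zamba2ForCausalLM" in ModelRegistry.get_supported_archs()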
vllm/model_executor/models/zamba2.py — 1031 lines (new file)
File diff suppressed because it is too large.