[CI] Enable all hf transformers baselines in test_hybrid (#23936)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Thomas Parnell authored on 2025-09-02 22:15:06 +02:00; committed by GitHub
parent 98aee612aa
commit d328f7894f
2 changed files with 32 additions and 57 deletions


@@ -34,17 +34,6 @@ HYBRID_MODELS = [
     "LiquidAI/LFM2-1.2B",
 ]
 
-HF_UNSUPPORTED_MODELS = [
-    # The HF transformers implementation of
-    # Mamba2 is buggy for Codestral as it doesn't handle n_groups, so the test
-    # doesn't compare vLLM output with HF output.
-    # See https://github.com/huggingface/transformers/pull/35943
-    "yujiepan/mamba2-codestral-v0.1-tiny-random",
-    # transformers 4.55 is still producing garbage for this model
-    # TODO(tdoublep): follow-up on transformers side
-    "ibm-granite/granite-4.0-tiny-preview"
-]
-
 V1_SUPPORTED_MODELS = [
     "state-spaces/mamba-130m-hf",
     "ai21labs/Jamba-tiny-dev",
@@ -90,20 +79,13 @@ def test_models(
     try:
         model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
         model_info.check_available_online(on_fail="skip")
-        hf_version_check = model_info.check_transformers_version(
-            on_fail="return")
+        model_info.check_transformers_version(on_fail="skip")
     except ValueError:
-        hf_version_check = None
-
-    if hf_version_check is not None:
-        print(f"Skipping transformers comparison because: {hf_version_check}")
+        pass
 
     with hf_runner(model) as hf_model:
-        if model not in HF_UNSUPPORTED_MODELS and hf_version_check is None:
-            hf_outputs = hf_model.generate_greedy_logprobs_limit(
-                example_prompts, max_tokens, num_logprobs)
-        else:
-            hf_outputs = None
+        hf_outputs = hf_model.generate_greedy_logprobs_limit(
+            example_prompts, max_tokens, num_logprobs)
 
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "0")
@@ -121,7 +103,7 @@ def test_models(
     else:
         vllm_v1_outputs = None
 
-    if hf_outputs is not None and vllm_v0_outputs is not None:
+    if vllm_v0_outputs is not None:
         check_logprobs_close(
             outputs_0_lst=hf_outputs,
             outputs_1_lst=vllm_v0_outputs,
@@ -130,12 +112,10 @@ def test_models(
         )
 
     if model in V1_SUPPORTED_MODELS:
-        ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs
-        assert ref_outputs is not None
         check_logprobs_close(
-            outputs_0_lst=ref_outputs,
+            outputs_0_lst=hf_outputs,
             outputs_1_lst=vllm_v1_outputs,
-            name_0="hf" if hf_outputs is not None else "vllm-v0",
+            name_0="hf",
             name_1="vllm-v1",
         )
 
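For context on these assertions: check_logprobs_close tolerates small numerical divergence between two runners rather than demanding identical text. A minimal sketch of the core idea follows; it assumes each output is a (token_ids, text, top_logprobs) tuple with one {token_id: logprob} dict per generated step, which is a simplification of vLLM's test utilities, not their exact interface.

# Simplified illustration of a "logprobs close" check; vLLM's actual
# check_logprobs_close is more thorough, this only shows the core idea.
from typing import Dict, List, Sequence, Tuple

Output = Tuple[Sequence[int], str, List[Dict[int, float]]]


def logprobs_close(out_a: Output, out_b: Output,
                   name_a: str, name_b: str) -> None:
    ids_a, _, top_a = out_a
    ids_b, _, top_b = out_b
    for step, (tok_a, tok_b) in enumerate(zip(ids_a, ids_b)):
        if tok_a == tok_b:
            continue  # same greedy token, nothing to compare
        # A divergence is acceptable only if each runner's token still
        # appears in the other runner's reported top-k for that step.
        assert tok_a in top_b[step], (
            f"step {step}: {name_a} token {tok_a} missing from {name_b} top-k")
        assert tok_b in top_a[step], (
            f"step {step}: {name_b} token {tok_b} missing from {name_a} top-k")
        break  # sequences may legitimately differ after the first divergence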
@@ -402,11 +382,8 @@ def test_full_cuda_graph(
         pass
 
     with hf_runner(model) as hf_model:
-        if model not in HF_UNSUPPORTED_MODELS:
-            hf_outputs = hf_model.generate_greedy_logprobs_limit(
-                example_prompts, max_tokens, num_logprobs)
-        else:
-            hf_outputs = None
+        hf_outputs = hf_model.generate_greedy_logprobs_limit(
+            example_prompts, max_tokens, num_logprobs)
 
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "0")
@@ -421,7 +398,7 @@ def test_full_cuda_graph(
             vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
                 example_prompts, max_tokens, num_logprobs)
 
-    if hf_outputs is not None and vllm_v0_outputs is not None:
+    if vllm_v0_outputs is not None:
         check_logprobs_close(
             outputs_0_lst=hf_outputs,
             outputs_1_lst=vllm_v0_outputs,
@@ -429,12 +406,10 @@ def test_full_cuda_graph(
             name_1="vllm-v0",
         )
 
-    ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs
-    assert ref_outputs is not None
     check_logprobs_close(
-        outputs_0_lst=ref_outputs,
+        outputs_0_lst=hf_outputs,
         outputs_1_lst=vllm_v1_outputs,
-        name_0="hf" if hf_outputs is not None else "vllm-v0",
+        name_0="hf",
         name_1="vllm-v1",
     )
 
@@ -460,11 +435,8 @@ def test_fp32_state(
         pass
 
     with hf_runner(model) as hf_model:
-        if model not in HF_UNSUPPORTED_MODELS:
-            hf_outputs = hf_model.generate_greedy_logprobs_limit(
-                example_prompts, max_tokens, num_logprobs)
-        else:
-            hf_outputs = None
+        hf_outputs = hf_model.generate_greedy_logprobs_limit(
+            example_prompts, max_tokens, num_logprobs)
 
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "0")
@@ -480,18 +452,16 @@ def test_fp32_state(
             vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
                 example_prompts, max_tokens, num_logprobs)
 
-    if hf_outputs is not None:
-        check_logprobs_close(
-            outputs_0_lst=hf_outputs,
-            outputs_1_lst=vllm_v0_outputs,
-            name_0="hf",
-            name_1="vllm-v0",
-        )
-
-    ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs
     check_logprobs_close(
-        outputs_0_lst=ref_outputs,
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_v0_outputs,
+        name_0="hf",
+        name_1="vllm-v0",
+    )
+
+    check_logprobs_close(
+        outputs_0_lst=hf_outputs,
        outputs_1_lst=vllm_v1_outputs,
-        name_0="hf" if hf_outputs is not None else "vllm-v0",
+        name_0="hf",
         name_1="vllm-v1",
     )


@@ -154,7 +154,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
     "BailingMoeForCausalLM": _HfExamplesInfo("inclusionAI/Ling-lite-1.5",
                                              trust_remote_code=True),
     "BambaForCausalLM": _HfExamplesInfo("ibm-ai-platform/Bamba-9B-v1",
-                                        min_transformers_version="4.56.0",
+                                        min_transformers_version="4.55.3",
                                         extras={"tiny": "hmellor/tiny-random-BambaForCausalLM"}),  # noqa: E501
     "BloomForCausalLM": _HfExamplesInfo("bigscience/bloom-560m",
                                         {"1b": "bigscience/bloomz-1b1"}),
@@ -208,7 +208,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
     "GptOssForCausalLM": _HfExamplesInfo("lmsys/gpt-oss-20b-bf16"),
     "GraniteForCausalLM": _HfExamplesInfo("ibm/PowerLM-3b"),
     "GraniteMoeForCausalLM": _HfExamplesInfo("ibm/PowerMoE-3b"),
-    "GraniteMoeHybridForCausalLM": _HfExamplesInfo("ibm-granite/granite-4.0-tiny-preview"),  # noqa: E501
+    "GraniteMoeHybridForCausalLM": _HfExamplesInfo("ibm-granite/granite-4.0-tiny-preview",  # noqa: E501
+                                                   min_transformers_version="4.55.3"),
     "GraniteMoeSharedForCausalLM": _HfExamplesInfo("ibm-research/moe-7b-1b-active-shared-experts"),  # noqa: E501
     "Grok1ModelForCausalLM": _HfExamplesInfo("hpcai-tech/grok-1",
                                              trust_remote_code=True),
@@ -228,7 +229,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
                                            trust_remote_code=True),
     "JAISLMHeadModel": _HfExamplesInfo("inceptionai/jais-13b-chat"),
     "JambaForCausalLM": _HfExamplesInfo("ai21labs/AI21-Jamba-1.5-Mini",
-                                        min_transformers_version="4.56.0",
+                                        min_transformers_version="4.55.3",
                                         extras={
                                             "tiny": "ai21labs/Jamba-tiny-dev",
                                             "random": "ai21labs/Jamba-tiny-random",  # noqa: E501
@@ -244,7 +245,11 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
     "Llama4ForCausalLM": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct",  # noqa: E501
                                          is_available_online=False),
     "MambaForCausalLM": _HfExamplesInfo("state-spaces/mamba-130m-hf"),
-    "Mamba2ForCausalLM": _HfExamplesInfo("mistralai/Mamba-Codestral-7B-v0.1"),
+    "Mamba2ForCausalLM": _HfExamplesInfo("mistralai/Mamba-Codestral-7B-v0.1",
+                                         min_transformers_version="4.55.3",
+                                         extras={
+                                             "random": "yujiepan/mamba2-codestral-v0.1-tiny-random",  # noqa: E501
+                                         }),
     "FalconMambaForCausalLM": _HfExamplesInfo("tiiuae/falcon-mamba-7b-instruct"),  # noqa: E501
     "MiniCPMForCausalLM": _HfExamplesInfo("openbmb/MiniCPM-2B-sft-bf16",
                                           trust_remote_code=True),
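These registry changes move the skip logic into the example-model entries themselves: each hybrid baseline now declares the transformers release it needs, and a tiny checkpoint is exposed via extras for CI. As a rough, self-contained sketch of what such an entry encodes (the class and field names below are assumed for illustration; see _HfExamplesInfo in vLLM's test registry for the real definition):

# Assumed field names for illustration only; not vLLM's exact dataclass.
from dataclasses import dataclass, field
from typing import Dict, Optional


@dataclass
class ExampleModelInfo:
    default: str
    min_transformers_version: Optional[str] = None
    extras: Dict[str, str] = field(default_factory=dict)

    def resolve(self, variant: Optional[str] = None) -> str:
        # Fall back to the default checkpoint when no (or an unknown)
        # extra variant is requested.
        return self.extras.get(variant, self.default) if variant else self.default


mamba2 = ExampleModelInfo(
    default="mistralai/Mamba-Codestral-7B-v0.1",
    min_transformers_version="4.55.3",
    extras={"random": "yujiepan/mamba2-codestral-v0.1-tiny-random"},
)

print(mamba2.resolve("random"))  # -> the tiny random checkpoint used in CI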