[Attention] Update tests to remove deprecated env vars (#30563)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>

parent 9ca8cb38fd
commit 7eb6cb6c18
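This commit replaces the deprecated VLLM_ATTENTION_BACKEND environment variable (together with the related VLLM_FLASH_ATTN_* and VLLM_V1_USE_PREFILL_DECODE_ATTENTION variables) throughout the test suite with the explicit attention_config engine argument and the --attention-backend CLI shorthand. A minimal sketch of the before/after pattern, using only the keyword and flag exercised in the hunks below (not a complete script):

# Old style (deprecated): choose the backend through an environment variable.
#   VLLM_ATTENTION_BACKEND=TRITON_ATTN python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m
#
# New style: pass the backend explicitly, either on the command line
#   python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --attention-backend=TRITON_ATTN
# or in Python via attention_config.
from vllm import LLM

llm = LLM(
    model="facebook/opt-125m",
    enforce_eager=True,
    attention_config={"backend": "TRITON_ATTN"},
)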
@@ -39,7 +39,7 @@ docker run \
 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray
 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp
 python3 examples/offline_inference/basic/generate.py --model Intel/Qwen2.5-0.5B-W4A16-G128-AutoRound-LLMC-TEST-ONLY --enforce-eager
-VLLM_ATTENTION_BACKEND=TRITON_ATTN python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
+python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager --attention-backend=TRITON_ATTN
 cd tests
 pytest -v -s v1/core
 pytest -v -s v1/engine
@@ -67,7 +67,6 @@ def _fix_prompt_embed_outputs(
 @pytest.mark.parametrize("model_executor", ["uni", "mp"])
 @pytest.mark.parametrize("enable_prompt_embeds", [True, False])
 def test_models(
-    monkeypatch: pytest.MonkeyPatch,
     hf_runner,
     model: str,
     backend: str,
@@ -77,48 +76,46 @@ def test_models(
     model_executor: str,
     enable_prompt_embeds: bool,
 ) -> None:
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_ATTENTION_BACKEND", backend)
-
-        # 5042 tokens for gemma2
-        # gemma2 has alternating sliding window size of 4096
-        # we need a prompt with more than 4096 tokens to test the sliding window
-        prompt = (
-            "The following numbers of the sequence "
-            + ", ".join(str(i) for i in range(1024))
-            + " are:"
-        )
-        example_prompts = [prompt]
-
-        with hf_runner(model) as hf_model:
-            hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
-            if enable_prompt_embeds:
-                with torch.no_grad():
-                    prompt_embeds = hf_model.get_prompt_embeddings(example_prompts)
-
-        with VllmRunner(
-            model,
-            max_model_len=8192,
-            enforce_eager=enforce_eager,
-            enable_prompt_embeds=enable_prompt_embeds,
-            gpu_memory_utilization=0.7,
-            async_scheduling=async_scheduling,
-            distributed_executor_backend=model_executor,
-        ) as vllm_model:
-            if enable_prompt_embeds:
-                vllm_outputs = vllm_model.generate_greedy(prompt_embeds, max_tokens)
-                vllm_outputs = _fix_prompt_embed_outputs(
-                    vllm_outputs, hf_model, example_prompts
-                )
-            else:
-                vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
-
-        check_outputs_equal(
-            outputs_0_lst=hf_outputs,
-            outputs_1_lst=vllm_outputs,
-            name_0="hf",
-            name_1="vllm",
-        )
+    # 5042 tokens for gemma2
+    # gemma2 has alternating sliding window size of 4096
+    # we need a prompt with more than 4096 tokens to test the sliding window
+    prompt = (
+        "The following numbers of the sequence "
+        + ", ".join(str(i) for i in range(1024))
+        + " are:"
+    )
+    example_prompts = [prompt]
+
+    with hf_runner(model) as hf_model:
+        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+        if enable_prompt_embeds:
+            with torch.no_grad():
+                prompt_embeds = hf_model.get_prompt_embeddings(example_prompts)
+
+    with VllmRunner(
+        model,
+        max_model_len=8192,
+        enforce_eager=enforce_eager,
+        enable_prompt_embeds=enable_prompt_embeds,
+        gpu_memory_utilization=0.7,
+        async_scheduling=async_scheduling,
+        distributed_executor_backend=model_executor,
+        attention_config={"backend": backend},
+    ) as vllm_model:
+        if enable_prompt_embeds:
+            vllm_outputs = vllm_model.generate_greedy(prompt_embeds, max_tokens)
+            vllm_outputs = _fix_prompt_embed_outputs(
+                vllm_outputs, hf_model, example_prompts
+            )
+        else:
+            vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
+
+    check_outputs_equal(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )
 
 
 @multi_gpu_test(num_gpus=2)
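In the updated test above, the parametrized backend flows into the runner constructor instead of a monkeypatched environment variable. A condensed sketch of that call pattern, assuming the VllmRunner helper used by these tests and a placeholder model name:

def generate_with_backend(backend: str, prompts: list[str], max_tokens: int):
    # The backend name (e.g. "TRITON_ATTN") is passed straight to the engine;
    # nothing is written to os.environ.
    with VllmRunner(
        "facebook/opt-125m",  # placeholder model for illustration
        enforce_eager=True,
        attention_config={"backend": backend},
    ) as vllm_model:
        return vllm_model.generate_greedy(prompts, max_tokens)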
@@ -161,12 +158,6 @@ def test_models_distributed(
     ):  # noqa
         pytest.skip("enable_prompt_embeds does not work with ray compiled dag.")
 
-    if attention_backend:
-        monkeypatch_context.setenv(
-            "VLLM_ATTENTION_BACKEND",
-            attention_backend,
-        )
-
     for k, v in extra_env.items():
         monkeypatch_context.setenv(k, v)
 
@@ -178,6 +169,7 @@ def test_models_distributed(
     # if we run HF first, the cuda initialization will be done and it
     # will hurt multiprocessing backend with fork method
     # (the default method).
+    attention_config = {"backend": attention_backend} if attention_backend else None
     with vllm_runner(
         model,
         dtype=dtype,
@@ -185,6 +177,7 @@ def test_models_distributed(
         distributed_executor_backend=distributed_executor_backend,
         enable_prompt_embeds=enable_prompt_embeds,
         gpu_memory_utilization=0.7,
+        attention_config=attention_config,
     ) as vllm_model:
         if enable_prompt_embeds:
             with hf_runner(model, dtype=dtype) as hf_model:
@@ -208,7 +208,8 @@ def test_attn_quant(
     # To capture subprocess logs, we need to know whether spawn or fork is used.
     # Force spawn as it is more general.
     monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
-    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
+
+    model_kwargs["attention_config"] = {"backend": backend.name}
 
     compilation_config = CompilationConfig(
         # Testing properties
@@ -297,7 +298,8 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
     # To capture subprocess logs, we need to know whether spawn or fork is used.
     # Force spawn as it is more general.
     monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
-    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
+
+    model_kwargs["attention_config"] = {"backend": backend.name}
 
     compilation_config = CompilationConfig(
         # Testing properties
@@ -409,7 +411,8 @@ def test_tp2_attn_quant_async_tp(
     # To capture subprocess logs, we need to know whether spawn or fork is used.
     # Force spawn as it is more general.
     monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
-    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
+
+    model_kwargs["attention_config"] = {"backend": backend.name}
 
     compilation_config = CompilationConfig(
         # Testing properties
@@ -89,7 +89,6 @@ class TestSetting:
     ],
 )
 def test_compile_correctness(
-    monkeypatch: pytest.MonkeyPatch,
     test_setting: TestSetting,
 ):
     # this test is run under multiple suits, with different GPUs.
@@ -107,49 +106,48 @@ def test_compile_correctness(
         f"{cuda_device_count_stateless()}"
     )
 
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
-        final_args = [
-            *model_args,
-            "-pp",
-            str(pp_size),
-            "-tp",
-            str(tp_size),
-            "-cc.cudagraph_mode=none",
-        ]
-
-        all_args: list[list[str]] = []
-        all_envs: list[dict[str, str] | None] = []
-
-        for comp_mode in [
-            CompilationMode.STOCK_TORCH_COMPILE,
-            CompilationMode.DYNAMO_TRACE_ONCE,
-            CompilationMode.VLLM_COMPILE,
-        ]:
-            for mode in [CompilationMode.NONE, comp_mode]:
-                all_args.append(
-                    final_args + [f"-cc.mode={mode.name}", "-cc.backend=inductor"]
-                )
-
-            # inductor will change the output, so we only compare if the output
-            # is close, not exactly the same.
-            compare_all_settings(
-                model,
-                all_args,
-                all_envs,
-                method=method if method != "generate" else "generate_close",
-            )
-            all_envs.clear()
-            all_args.clear()
-
-        for mode in [
-            CompilationMode.NONE,
-            CompilationMode.STOCK_TORCH_COMPILE,
-            CompilationMode.DYNAMO_TRACE_ONCE,
-            CompilationMode.VLLM_COMPILE,
-        ]:
-            all_args.append(final_args + [f"-cc.mode={mode.name}", "-cc.backend=eager"])
-            all_envs.append({})
-        all_envs.append({})
-
-        compare_all_settings(model, all_args * 3, all_envs, method=method)
+    final_args = [
+        *model_args,
+        "-pp",
+        str(pp_size),
+        "-tp",
+        str(tp_size),
+        "-cc.cudagraph_mode=none",
+        f"--attention-backend={attn_backend}",
+    ]
+
+    all_args: list[list[str]] = []
+    all_envs: list[dict[str, str] | None] = []
+
+    for comp_mode in [
+        CompilationMode.STOCK_TORCH_COMPILE,
+        CompilationMode.DYNAMO_TRACE_ONCE,
+        CompilationMode.VLLM_COMPILE,
+    ]:
+        for mode in [CompilationMode.NONE, comp_mode]:
+            all_args.append(
+                final_args + [f"-cc.mode={mode.name}", "-cc.backend=inductor"]
+            )
+
+        # inductor will change the output, so we only compare if the output
+        # is close, not exactly the same.
+        compare_all_settings(
+            model,
+            all_args,
+            all_envs,
+            method=method if method != "generate" else "generate_close",
+        )
+        all_envs.clear()
+        all_args.clear()
+
+    for mode in [
+        CompilationMode.NONE,
+        CompilationMode.STOCK_TORCH_COMPILE,
+        CompilationMode.DYNAMO_TRACE_ONCE,
+        CompilationMode.VLLM_COMPILE,
+    ]:
+        all_args.append(final_args + [f"-cc.mode={mode.name}", "-cc.backend=eager"])
+        all_envs.append({})
+    all_envs.append({})
+
+    compare_all_settings(model, all_args * 3, all_envs, method=method)
@@ -74,7 +74,6 @@ def llm_pair(request):
         # Force native sampler to avoid potential nondeterminism in FlashInfer
         # when per-request generators are not used in V1.
         "VLLM_USE_FLASHINFER_SAMPLER": "0",
-        **backend_config.env_vars,
     }
     with temporary_environ(env_vars):
         full = LLM(
@@ -170,16 +169,10 @@ class TestFullCUDAGraph:
 
 @pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
 def test_full_cudagraph_with_invalid_backend():
-    with (
-        temporary_environ(
-            {
-                "VLLM_ATTENTION_BACKEND": "FLEX_ATTENTION",
-                # Flex_Attention is not supported with full cuda graph
-            }
-        ),
-        pytest.raises(RuntimeError),
-    ):
+    # Flex_Attention is not supported with full cuda graph
+    with pytest.raises(RuntimeError):
         LLM(
             model="Qwen/Qwen2-1.5B-Instruct",
             compilation_config=CompilationConfig(cudagraph_mode="FULL"),
+            attention_config={"backend": "FLEX_ATTENTION"},
         )
@@ -197,20 +197,19 @@ def test_custom_compile_config(
     ],
 )
 def test_fp8_kv_scale_compile(
-    monkeypatch: pytest.MonkeyPatch,
     compilation_mode: int,
     model: str,
     backend: AttentionBackendEnum | None,
 ):
-    if backend:
-        monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
-
     model_kwargs = {
         "quantization": "fp8",
        "kv_cache_dtype": "fp8_e4m3",
         "calculate_kv_scales": True,
         "max_model_len": 512,
     }
+    if backend:
+        model_kwargs["attention_config"] = {"backend": backend.name}
+
     run_model(compilation_mode, model, **model_kwargs)
 
 
@@ -219,14 +219,12 @@ def _test_cp_gsm8k(
         ]
     )
 
-    server_env = {}
     if attn_backend:
-        server_env["VLLM_ATTENTION_BACKEND"] = attn_backend
+        server_args.append(f"--attention-backend={attn_backend}")
 
     with RemoteOpenAIServer(
         model_id,
         server_args,
-        env_dict=server_env,
         max_wait_seconds=720,
     ) as remote_server:
         host = f"http://{remote_server.host}"
@@ -20,23 +20,21 @@ from ..utils import compare_two_settings, create_new_process_for_each_test
 )
 @create_new_process_for_each_test()
 def test_pp_cudagraph(
-    monkeypatch: pytest.MonkeyPatch,
     PP_SIZE: int,
     MODEL_NAME: str,
     ATTN_BACKEND: LiteralString,
 ):
-    with monkeypatch.context() as m:
-        cudagraph_args = [
-            # use half precision for speed and memory savings in CI environment
-            "--dtype",
-            "float16",
-            "--pipeline-parallel-size",
-            str(PP_SIZE),
-            "--distributed-executor-backend",
-            "mp",
-        ]
-        m.setenv("VLLM_ATTENTION_BACKEND", ATTN_BACKEND)
-
-        eager_args = cudagraph_args + ["--enforce-eager"]
-
-        compare_two_settings(MODEL_NAME, eager_args, cudagraph_args)
+    cudagraph_args = [
+        # use half precision for speed and memory savings in CI environment
+        "--dtype",
+        "float16",
+        "--pipeline-parallel-size",
+        str(PP_SIZE),
+        "--distributed-executor-backend",
+        "mp",
+        f"--attention-backend={ATTN_BACKEND}",
+    ]
+
+    eager_args = cudagraph_args + ["--enforce-eager"]
+
+    compare_two_settings(MODEL_NAME, eager_args, cudagraph_args)
@@ -9,7 +9,7 @@ from typing import Annotated, Literal
 
 import pytest
 
-from vllm.config import CompilationConfig, config
+from vllm.config import AttentionConfig, CompilationConfig, config
 from vllm.engine.arg_utils import (
     EngineArgs,
     contains_type,
@@ -298,6 +298,139 @@ def test_compilation_config():
     )
 
 
+def test_attention_config():
+    from vllm.attention.backends.registry import AttentionBackendEnum
+
+    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
+
+    # default value
+    args = parser.parse_args([])
+    assert args is not None
+    engine_args = EngineArgs.from_cli_args(args)
+    assert engine_args.attention_config == AttentionConfig()
+
+    # set backend via dot notation
+    args = parser.parse_args(["--attention-config.backend", "FLASH_ATTN"])
+    assert args is not None
+    engine_args = EngineArgs.from_cli_args(args)
+    assert engine_args.attention_config.backend is not None
+    assert engine_args.attention_config.backend.name == "FLASH_ATTN"
+
+    # set backend via --attention-backend shorthand
+    args = parser.parse_args(["--attention-backend", "FLASHINFER"])
+    assert args is not None
+    engine_args = EngineArgs.from_cli_args(args)
+    assert engine_args.attention_backend is not None
+    assert engine_args.attention_backend == "FLASHINFER"
+
+    # set all fields via dot notation
+    args = parser.parse_args(
+        [
+            "--attention-config.backend",
+            "FLASH_ATTN",
+            "--attention-config.flash_attn_version",
+            "3",
+            "--attention-config.use_prefill_decode_attention",
+            "true",
+            "--attention-config.flash_attn_max_num_splits_for_cuda_graph",
+            "16",
+            "--attention-config.use_cudnn_prefill",
+            "true",
+            "--attention-config.use_trtllm_ragged_deepseek_prefill",
+            "true",
+            "--attention-config.use_trtllm_attention",
+            "true",
+            "--attention-config.disable_flashinfer_prefill",
+            "true",
+            "--attention-config.disable_flashinfer_q_quantization",
+            "true",
+        ]
+    )
+    assert args is not None
+    engine_args = EngineArgs.from_cli_args(args)
+    assert engine_args.attention_config.backend is not None
+    assert engine_args.attention_config.backend.name == "FLASH_ATTN"
+    assert engine_args.attention_config.flash_attn_version == 3
+    assert engine_args.attention_config.use_prefill_decode_attention is True
+    assert engine_args.attention_config.flash_attn_max_num_splits_for_cuda_graph == 16
+    assert engine_args.attention_config.use_cudnn_prefill is True
+    assert engine_args.attention_config.use_trtllm_ragged_deepseek_prefill is True
+    assert engine_args.attention_config.use_trtllm_attention is True
+    assert engine_args.attention_config.disable_flashinfer_prefill is True
+    assert engine_args.attention_config.disable_flashinfer_q_quantization is True
+
+    # set to string form of a dict with all fields
+    args = parser.parse_args(
+        [
+            "--attention-config="
+            '{"backend": "FLASHINFER", "flash_attn_version": 2, '
+            '"use_prefill_decode_attention": false, '
+            '"flash_attn_max_num_splits_for_cuda_graph": 8, '
+            '"use_cudnn_prefill": false, '
+            '"use_trtllm_ragged_deepseek_prefill": false, '
+            '"use_trtllm_attention": false, '
+            '"disable_flashinfer_prefill": false, '
+            '"disable_flashinfer_q_quantization": false}',
+        ]
+    )
+    assert args is not None
+    engine_args = EngineArgs.from_cli_args(args)
+    assert engine_args.attention_config.backend is not None
+    assert engine_args.attention_config.backend.name == "FLASHINFER"
+    assert engine_args.attention_config.flash_attn_version == 2
+    assert engine_args.attention_config.use_prefill_decode_attention is False
+    assert engine_args.attention_config.flash_attn_max_num_splits_for_cuda_graph == 8
+    assert engine_args.attention_config.use_cudnn_prefill is False
+    assert engine_args.attention_config.use_trtllm_ragged_deepseek_prefill is False
+    assert engine_args.attention_config.use_trtllm_attention is False
+    assert engine_args.attention_config.disable_flashinfer_prefill is False
+    assert engine_args.attention_config.disable_flashinfer_q_quantization is False
+
+    # test --attention-backend flows into VllmConfig.attention_config
+    args = parser.parse_args(
+        [
+            "--model",
+            "facebook/opt-125m",
+            "--attention-backend",
+            "FLASH_ATTN",
+        ]
+    )
+    assert args is not None
+    engine_args = EngineArgs.from_cli_args(args)
+    vllm_config = engine_args.create_engine_config()
+    assert vllm_config.attention_config.backend == AttentionBackendEnum.FLASH_ATTN
+
+    # test --attention-config.backend flows into VllmConfig.attention_config
+    args = parser.parse_args(
+        [
+            "--model",
+            "facebook/opt-125m",
+            "--attention-config.backend",
+            "FLASHINFER",
+        ]
+    )
+    assert args is not None
+    engine_args = EngineArgs.from_cli_args(args)
+    vllm_config = engine_args.create_engine_config()
+    assert vllm_config.attention_config.backend == AttentionBackendEnum.FLASHINFER
+
+    # test --attention-backend and --attention-config.backend are mutually exclusive
+    args = parser.parse_args(
+        [
+            "--model",
+            "facebook/opt-125m",
+            "--attention-backend",
+            "FLASH_ATTN",
+            "--attention-config.backend",
+            "FLASHINFER",
+        ]
+    )
+    assert args is not None
+    engine_args = EngineArgs.from_cli_args(args)
+    with pytest.raises(ValueError, match="mutually exclusive"):
+        engine_args.create_engine_config()
+
+
 def test_prefix_cache_default():
     parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
     args = parser.parse_args([])
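The new test_attention_config test above also documents the CLI surface: --attention-backend is a shorthand, the --attention-config.backend dot notation (or a JSON string) sets the same AttentionConfig, and the shorthand and dot-notation backend are mutually exclusive. A trimmed sketch of that flow, restricted to calls that appear in the test (the FlexibleArgumentParser import path is assumed):

from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser  # import path assumed

parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
args = parser.parse_args(
    ["--model", "facebook/opt-125m", "--attention-backend", "FLASH_ATTN"]
)
engine_args = EngineArgs.from_cli_args(args)
vllm_config = engine_args.create_engine_config()
# vllm_config.attention_config.backend == AttentionBackendEnum.FLASH_ATTN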
@@ -76,15 +76,10 @@ def default_server_args(with_tool_parser: bool):
 
 
 @pytest.fixture(scope="module")
-def gptoss_server(
-    monkeypatch_module: pytest.MonkeyPatch, default_server_args: list[str]
-):
-    with monkeypatch_module.context() as m:
-        m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
-        with RemoteOpenAIServer(
-            GPT_OSS_MODEL_NAME, default_server_args
-        ) as remote_server:
-            yield remote_server
+def gptoss_server(default_server_args: list[str]):
+    server_args = default_server_args + ["--attention-backend=TRITON_ATTN"]
+    with RemoteOpenAIServer(GPT_OSS_MODEL_NAME, server_args) as remote_server:
+        yield remote_server
 
 
 @pytest_asyncio.fixture
@@ -6,7 +6,9 @@ from unittest.mock import patch
 import pytest
 import torch
 
+from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
+from vllm.config import AttentionConfig, VllmConfig, set_current_vllm_config
 from vllm.platforms import current_platform
 from vllm.platforms.cpu import CpuPlatform
 from vllm.platforms.cuda import CudaPlatform
@@ -73,18 +75,18 @@ def generate_params():
 
 
 @pytest.mark.parametrize("device, name, use_mla, block_size", generate_params())
-def test_env(
+def test_backend_selection(
     device: str,
     name: str,
     use_mla: bool,
     block_size: int,
-    monkeypatch: pytest.MonkeyPatch,
 ):
     """Test attention backend selection with valid device-backend pairs."""
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_ATTENTION_BACKEND", name)
-        m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0")
+    # Create AttentionConfig with the specified backend
+    attention_config = AttentionConfig(backend=AttentionBackendEnum[name])
+    vllm_config = VllmConfig(attention_config=attention_config)
 
+    with set_current_vllm_config(vllm_config):
         if device == "cpu":
             with patch("vllm.platforms.current_platform", CpuPlatform()):
                 backend = get_attn_backend(16, torch.float16, None, block_size)
@@ -217,27 +219,32 @@ def test_env(
 @pytest.mark.parametrize("device", ["cpu", "cuda"])
 def test_fp32_fallback(device: str):
     """Test attention backend selection with fp32."""
-    if device == "cpu":
-        with patch("vllm.platforms.current_platform", CpuPlatform()):
-            backend = get_attn_backend(16, torch.float32, None, 16)
-            assert backend.get_name() == "CPU_ATTN"
-
-    elif device == "cuda":
-        with patch("vllm.platforms.current_platform", CudaPlatform()):
-            backend = get_attn_backend(16, torch.float32, None, 16)
-            assert backend.get_name() == "FLEX_ATTENTION"
+    # Use default config (no backend specified)
+    vllm_config = VllmConfig()
+
+    with set_current_vllm_config(vllm_config):
+        if device == "cpu":
+            with patch("vllm.platforms.current_platform", CpuPlatform()):
+                backend = get_attn_backend(16, torch.float32, None, 16)
+                assert backend.get_name() == "CPU_ATTN"
+
+        elif device == "cuda":
+            with patch("vllm.platforms.current_platform", CudaPlatform()):
+                backend = get_attn_backend(16, torch.float32, None, 16)
+                assert backend.get_name() == "FLEX_ATTENTION"
 
 
 def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
     """Test FlashAttn validation."""
     pytest.skip(
         "Skipping as current backend selector does not "
-        "handle fallbacks when a backend is set via env var."
+        "handle fallbacks when a backend is explicitly set."
     )
 
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_ATTENTION_BACKEND", "FLASH_ATTN")
+    attention_config = AttentionConfig(backend=AttentionBackendEnum.FLASH_ATTN)
+    vllm_config = VllmConfig(attention_config=attention_config)
+
+    with set_current_vllm_config(vllm_config):
         # Unsupported CUDA arch
         monkeypatch.setattr(torch.cuda, "get_device_capability", lambda _=None: (7, 5))
         backend = get_attn_backend(16, torch.float16, None, 16)
@@ -277,15 +284,10 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
     assert backend.get_name() != "FLASH_ATTN"
 
 
-def test_invalid_env(monkeypatch: pytest.MonkeyPatch):
+def test_invalid_backend():
     """Test that invalid attention backend names raise ValueError."""
     with (
-        monkeypatch.context() as m,
-        patch("vllm.platforms.current_platform", CudaPlatform()),
+        pytest.raises(ValueError),
     ):
-        m.setenv("VLLM_ATTENTION_BACKEND", "INVALID")
-
-        # Should raise ValueError for invalid backend
-        with pytest.raises(ValueError) as exc_info:
-            get_attn_backend(32, torch.float16, None, 16)
-        assert "Invalid value 'INVALID'" in str(exc_info.value)
+        # Invalid backend name should raise ValueError when creating enum
+        AttentionConfig(backend=AttentionBackendEnum["INVALID"])
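For unit tests that call get_attn_backend() directly, the backend is now injected through the active VllmConfig rather than an environment variable. A minimal sketch of the pattern, using the config and selector APIs imported in the hunks above:

import torch

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.selector import get_attn_backend
from vllm.config import AttentionConfig, VllmConfig, set_current_vllm_config

attention_config = AttentionConfig(backend=AttentionBackendEnum.FLASH_ATTN)
vllm_config = VllmConfig(attention_config=attention_config)

with set_current_vllm_config(vllm_config):
    # The selector reads the backend from the current config.
    backend = get_attn_backend(16, torch.float16, None, 16)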
@@ -4,7 +4,9 @@
 import pytest
 import torch
 
+from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
+from vllm.config import AttentionConfig, VllmConfig, set_current_vllm_config
 from vllm.platforms.rocm import RocmPlatform
 
 
@@ -16,40 +18,56 @@ def clear_cache():
 
 @pytest.mark.skip(reason="Skipped for now. Should be revisited.")
 def test_selector(monkeypatch: pytest.MonkeyPatch):
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_ATTN")
-
-        # Set the current platform to ROCm using monkeypatch
-        monkeypatch.setattr("vllm.attention.selector.current_platform", RocmPlatform())
-
-        # Test standard ROCm attention
-        backend = get_attn_backend(16, torch.float16, torch.float16, 16, False)
-        assert backend.get_name() == "ROCM_FLASH" or backend.get_name() == "TRITON_ATTN"
-
-        # MLA test for deepseek related
-        # change the attention backend to triton MLA
-        m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_MLA")
-        backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, use_mla=True)
-        assert backend.get_name() == "TRITON_MLA"
-
-        # If attention backend is None
-        # If use_mla is true
-        # The selected backend is triton MLA
-        m.setenv("VLLM_ATTENTION_BACKEND", "")
-        backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, use_mla=True)
-        assert backend.get_name() == "TRITON_MLA"
-
-        # change the attention backend to AITER MLA
-        m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_AITER_MLA")
-        backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, use_mla=True)
-        assert backend.get_name() == "ROCM_AITER_MLA"
-
-        # If attention backend is None
-        # If use_mla is true
-        # If VLLM_ROCM_USE_AITER is enabled
-        # The selected backend is ROCM_AITER_MLA
-        m.setenv("VLLM_ATTENTION_BACKEND", "")
-        m.setenv("VLLM_ROCM_USE_AITER", "1")
-        backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, use_mla=True)
-        assert backend.get_name() == "ROCM_AITER_MLA"
+    # Set the current platform to ROCm using monkeypatch
+    monkeypatch.setattr("vllm.attention.selector.current_platform", RocmPlatform())
+
+    # Test standard ROCm attention
+    attention_config = AttentionConfig(backend=AttentionBackendEnum.ROCM_ATTN)
+    vllm_config = VllmConfig(attention_config=attention_config)
+
+    with set_current_vllm_config(vllm_config):
+        backend = get_attn_backend(16, torch.float16, torch.float16, 16, False)
+        assert backend.get_name() == "ROCM_FLASH" or backend.get_name() == "TRITON_ATTN"
+
+    # MLA test for deepseek related
+    # Change the attention backend to triton MLA
+    attention_config = AttentionConfig(backend=AttentionBackendEnum.TRITON_MLA)
+    vllm_config = VllmConfig(attention_config=attention_config)
+
+    with set_current_vllm_config(vllm_config):
+        backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, use_mla=True)
+        assert backend.get_name() == "TRITON_MLA"
+
+    # If attention backend is None
+    # If use_mla is true
+    # The selected backend is triton MLA
+    attention_config = AttentionConfig(backend=None)
+    vllm_config = VllmConfig(attention_config=attention_config)
+
+    with set_current_vllm_config(vllm_config):
+        backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, use_mla=True)
+        assert backend.get_name() == "TRITON_MLA"
+
+    # Change the attention backend to AITER MLA
+    attention_config = AttentionConfig(backend=AttentionBackendEnum.ROCM_AITER_MLA)
+    vllm_config = VllmConfig(attention_config=attention_config)
+
+    with set_current_vllm_config(vllm_config):
+        backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, use_mla=True)
+        assert backend.get_name() == "ROCM_AITER_MLA"
+
+    # If attention backend is None
+    # If use_mla is true
+    # If VLLM_ROCM_USE_AITER is enabled
+    # The selected backend is ROCM_AITER_MLA
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_ROCM_USE_AITER", "1")
+
+        attention_config = AttentionConfig(backend=None)
+        vllm_config = VllmConfig(attention_config=attention_config)
+
+        with set_current_vllm_config(vllm_config):
+            backend = get_attn_backend(
+                576, torch.bfloat16, "auto", 1, False, use_mla=True
+            )
+            assert backend.get_name() == "ROCM_AITER_MLA"
@@ -37,7 +37,7 @@ def set_seed(seed):
     not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
     reason="CUDA not available or PyTorch version < 2.7",
 )
-def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
+def test_flex_attention_vs_default_backend(vllm_runner):
     """Test that FlexAttention produces the same outputs as the default backend.
 
     This test compares the outputs from the FlexAttention backend with
@@ -54,35 +54,32 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
     ]
 
     # Run with flex attention
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
-
-        set_seed(seed)
-        with vllm_runner(
-            model_name,
-            runner="generate",
-            tensor_parallel_size=1,
-            num_gpu_blocks_override=128,
-            enforce_eager=True,
-        ) as llm_flex:
-            output_flex = llm_flex.generate_greedy_logprobs(
-                prompts, max_tokens, num_logprobs
-            )
+    set_seed(seed)
+    with vllm_runner(
+        model_name,
+        runner="generate",
+        tensor_parallel_size=1,
+        num_gpu_blocks_override=128,
+        enforce_eager=True,
+        attention_config={"backend": "FLEX_ATTENTION"},
+    ) as llm_flex:
+        output_flex = llm_flex.generate_greedy_logprobs(
+            prompts, max_tokens, num_logprobs
+        )
 
     # Run with default backend
-    with monkeypatch.context() as m:
-        set_seed(seed)
-        with vllm_runner(
-            model_name,
-            runner="generate",
-            tensor_parallel_size=1,
-            num_gpu_blocks_override=128,
-            enforce_eager=True,
-            gpu_memory_utilization=0.85,
-        ) as llm_default:
-            output_default = llm_default.generate_greedy_logprobs(
-                prompts, max_tokens, num_logprobs
-            )
+    set_seed(seed)
+    with vllm_runner(
+        model_name,
+        runner="generate",
+        tensor_parallel_size=1,
+        num_gpu_blocks_override=128,
+        enforce_eager=True,
+        gpu_memory_utilization=0.85,
+    ) as llm_default:
+        output_default = llm_default.generate_greedy_logprobs(
+            prompts, max_tokens, num_logprobs
+        )
 
     check_logprobs_close(
         outputs_0_lst=output_flex,
@@ -96,7 +93,7 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
     not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
     reason="CUDA not available or PyTorch version < 2.7",
 )
-def test_encoder_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
+def test_encoder_flex_attention_vs_default_backend(vllm_runner):
     """Test that FlexAttention produces the same outputs as the default backend.
 
     This test compares the outputs from the FlexAttention backend with
@@ -110,30 +107,26 @@ def test_encoder_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
     ]
 
     # Run with flex attention
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
-        with vllm_runner(
-            model_name,
-            runner="pooling",
-            dtype=torch.bfloat16,
-            tensor_parallel_size=1,
-            max_model_len=100,
-            enforce_eager=True,
-        ) as llm_flex:
-            flex_outputs = llm_flex.embed(prompts)
+    with vllm_runner(
+        model_name,
+        runner="pooling",
+        dtype=torch.bfloat16,
+        tensor_parallel_size=1,
+        max_model_len=100,
+        enforce_eager=True,
+        attention_config={"backend": "FLEX_ATTENTION"},
+    ) as llm_flex:
+        flex_outputs = llm_flex.embed(prompts)
 
     # Run with default backend
-    with (
-        monkeypatch.context() as m,
-        vllm_runner(
-            model_name,
-            runner="pooling",
-            dtype=torch.bfloat16,
-            tensor_parallel_size=1,
-            max_model_len=100,
-            enforce_eager=True,
-        ) as llm_default,
-    ):
+    with vllm_runner(
+        model_name,
+        runner="pooling",
+        dtype=torch.bfloat16,
+        tensor_parallel_size=1,
+        max_model_len=100,
+        enforce_eager=True,
+    ) as llm_default:
         default_outputs = llm_default.embed(prompts)
 
     check_embeddings_close(
@@ -35,10 +35,12 @@ audio_lora_path = MODEL_NAME
 models = [MODEL_NAME]
 
 
-@pytest.fixture(autouse=True)
-def set_attention_backend_for_rocm(monkeypatch):
+@pytest.fixture
+def granite_speech_attention_config():
+    """Return attention config for Granite Speech tests on ROCm."""
     if current_platform.is_rocm():
-        monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
+        return {"backend": "TRITON_ATTN"}
+    return None
 
 
 def run_test(
@@ -53,6 +55,7 @@ def run_test(
     num_logprobs: int,
     tensor_parallel_size: int,
     distributed_executor_backend: str | None = None,
+    attention_config: dict | None = None,
 ):
     """Inference result should be the same between hf and vllm.
 
@@ -80,6 +83,7 @@ def run_test(
         enable_lora=True,
         max_lora_rank=64,
         enforce_eager=True,
+        attention_config=attention_config,
     ) as vllm_model:
         lora_request = LoRARequest("audio", 1, audio_lora_path)
         vllm_outputs_per_case = [
@@ -131,6 +135,7 @@ def test_models(
     vllm_runner,
     model: str,
     audio_assets: AudioTestAssets,
+    granite_speech_attention_config,
     dtype: str,
     max_model_len: int,
     max_tokens: int,
@@ -157,4 +162,5 @@ def test_models(
         max_tokens=max_tokens,
         num_logprobs=num_logprobs,
         tensor_parallel_size=1,
+        attention_config=granite_speech_attention_config,
     )
@@ -2,23 +2,17 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Pytest configuration for vLLM pooling tests."""
 
-import os
-import warnings
+import pytest
 
 from vllm.platforms import current_platform
 
 
-def pytest_collection_modifyitems(config, items):
-    """Set FLEX_ATTENTION backend for SigLIP tests on ROCm."""
-    if not current_platform.is_rocm():
-        return
-
-    siglip_tests = [item for item in items if "test_siglip" in item.nodeid]
-
-    if siglip_tests:
-        os.environ["VLLM_ATTENTION_BACKEND"] = "FLEX_ATTENTION"
-        warnings.warn(
-            "ROCm: Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION for SigLIP tests",
-            UserWarning,
-            stacklevel=1,
-        )
+@pytest.fixture
+def siglip_attention_config():
+    """Return attention config for SigLIP tests on ROCm.
+
+    On ROCm, SigLIP tests require FLEX_ATTENTION backend.
+    """
+    if current_platform.is_rocm():
+        return {"backend": "FLEX_ATTENTION"}
+    return None
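The ROCm-only override is now a pytest fixture that returns a plain dict (or None) and is threaded into the runner, replacing a collection hook that mutated os.environ. A sketch of how a test might consume such a fixture; the model id is a placeholder:

def test_siglip_embeds(vllm_runner, siglip_attention_config):
    # siglip_attention_config is None off ROCm, so the default backend applies there.
    with vllm_runner(
        "google/siglip-base-patch16-224",  # placeholder model id
        runner="pooling",
        enforce_eager=True,
        attention_config=siglip_attention_config,
    ) as vllm_model:
        vllm_model.embed(["a photo of a cat"])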
@@ -38,6 +38,7 @@ def _run_test(
     *,
     dtype: str,
     tokenization_kwargs: dict[str, Any] | None = None,
+    attention_config: dict[str, Any] | None = None,
 ) -> None:
     if tokenization_kwargs is None:
         tokenization_kwargs = {}
@@ -49,6 +50,7 @@ def _run_test(
         enforce_eager=True,
         max_model_len=64,
         gpu_memory_utilization=0.7,
+        attention_config=attention_config,
     ) as vllm_model:
         vllm_outputs = vllm_model.embed(
             input_texts, images=input_images, tokenization_kwargs=tokenization_kwargs
@@ -90,6 +92,7 @@ def test_models_text(
     hf_runner,
     vllm_runner,
     image_assets,
+    siglip_attention_config,
     model: str,
     dtype: str,
 ) -> None:
@@ -108,6 +111,7 @@ def test_models_text(
             "padding": "max_length",
             "max_length": 64,
         },  # siglip2 was trained with this padding setting.
+        attention_config=siglip_attention_config,
     )
 
 
@@ -117,6 +121,7 @@ def test_models_image(
     hf_runner,
     vllm_runner,
     image_assets,
+    siglip_attention_config,
     model: str,
     dtype: str,
 ) -> None:
@@ -133,6 +138,7 @@ def test_models_image(
         input_images,
         model,
         dtype=dtype,
+        attention_config=siglip_attention_config,
     )
 
 
@@ -141,6 +147,7 @@ def test_models_image(
 def test_models_text_image_no_crash(
     vllm_runner,
     image_assets,
+    siglip_attention_config,
     model: str,
     dtype: str,
 ) -> None:
@@ -154,6 +161,7 @@ def test_models_text_image_no_crash(
         enforce_eager=True,
         max_model_len=64,
         gpu_memory_utilization=0.7,
+        attention_config=siglip_attention_config,
     ) as vllm_model:
         with pytest.raises(ValueError, match="not both"):
             vllm_model.embed(texts, images=images)
@@ -75,7 +75,6 @@ def test_models(
 
     with monkeypatch.context() as m:
         m.setenv("TOKENIZERS_PARALLELISM", "true")
-        m.setenv("VLLM_ATTENTION_BACKEND", backend)
 
         MAX_MODEL_LEN = 1024
         NUM_LOG_PROBS = 8
@@ -86,6 +85,7 @@ def test_models(
             tensor_parallel_size=tensor_parallel_size,
             enforce_eager=enforce_eager,
             kv_cache_dtype="auto",
+            attention_config={"backend": backend},
         ) as vllm_model:
             baseline_outputs = vllm_model.generate_greedy_logprobs(
                 example_prompts, max_tokens, NUM_LOG_PROBS
@@ -97,6 +97,7 @@ def test_models(
             tensor_parallel_size=tensor_parallel_size,
             enforce_eager=enforce_eager,
             kv_cache_dtype=kv_cache_dtype,
+            attention_config={"backend": backend},
         ) as vllm_model:
             test_outputs = vllm_model.generate_greedy_logprobs(
                 example_prompts, max_tokens, NUM_LOG_PROBS
@@ -108,11 +108,12 @@ def can_initialize(
         patch.object(V1EngineCore, "_initialize_kv_caches", _initialize_kv_caches_v1),
         monkeypatch.context() as m,
     ):
-        if model_arch == "GptOssForCausalLM":
-            # FIXME: A hack to bypass FA3 assertion because our CI's L4 GPU
-            # has cc==8.9 which hasn't supported FA3 yet. Remove this hack when
-            # L4 supports FA3.
-            m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
+        # FIXME: A hack to bypass FA3 assertion because our CI's L4 GPU
+        # has cc==8.9 which hasn't supported FA3 yet. Remove this hack when
+        # L4 supports FA3.
+        attention_config = (
+            {"backend": "TRITON_ATTN"} if model_arch == "GptOssForCausalLM" else None
+        )
         if model_arch == "WhisperForConditionalGeneration":
             m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
 
@@ -143,6 +144,7 @@ def can_initialize(
             else "vllm",
             hf_overrides=hf_overrides_fn,
             max_num_seqs=model_info.max_num_seqs,
+            attention_config=attention_config,
         )
 
 
@@ -94,26 +94,20 @@ def mock_on_gfx9():
         None,
         AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path(),
     ),
-    # Test Case 9: VLLM_V1_USE_PREFILL_DECODE_ATTENTION=1
-    (
-        {"VLLM_V1_USE_PREFILL_DECODE_ATTENTION": "1"},
-        None,
-        AttentionBackendEnum.ROCM_ATTN.get_path(),
-    ),
-    # Test Case 10: VLLM_ROCM_USE_AITER=1 + explicit TRITON_ATTN
+    # Test Case 9: VLLM_ROCM_USE_AITER=1 + explicit TRITON_ATTN
     (
         {"VLLM_ROCM_USE_AITER": "1"},
         "TRITON_ATTN",
         AttentionBackendEnum.TRITON_ATTN.get_path(),
     ),
-    # Test Case 11: VLLM_ROCM_USE_AITER=1 + VLLM_ROCM_USE_AITER_MHA=0
+    # Test Case 10: VLLM_ROCM_USE_AITER=1 + VLLM_ROCM_USE_AITER_MHA=0
     # (explicitly disabled)
     (
         {"VLLM_ROCM_USE_AITER": "1", "VLLM_ROCM_USE_AITER_MHA": "0"},
         None,
         AttentionBackendEnum.TRITON_ATTN.get_path(),
     ),
-    # Test Case 12: VLLM_ROCM_USE_AITER=1 + explicit ROCM_ATTN
+    # Test Case 11: VLLM_ROCM_USE_AITER=1 + explicit ROCM_ATTN
     (
         {"VLLM_ROCM_USE_AITER": "1"},
         "ROCM_ATTN",
@ -249,8 +249,8 @@ def create_dummy_kv_cache(
|
|||||||
@dataclass
|
@dataclass
|
||||||
class BackendConfig:
|
class BackendConfig:
|
||||||
name: str
|
name: str
|
||||||
env_vars: dict
|
attention_config: dict
|
||||||
comp_config: dict # compilation config
|
comp_config: dict
|
||||||
specific_gpu_arch: tuple | None = None
|
specific_gpu_arch: tuple | None = None
|
||||||
|
|
||||||
|
|
||||||
@ -259,10 +259,10 @@ full_cg_backend_configs = {
|
|||||||
# FA3 on Hopper
|
# FA3 on Hopper
|
||||||
"FA3": BackendConfig(
|
"FA3": BackendConfig(
|
||||||
name="FA3",
|
name="FA3",
|
||||||
env_vars={
|
attention_config={
|
||||||
"VLLM_ATTENTION_BACKEND": "FLASH_ATTN",
|
"backend": "FLASH_ATTN",
|
||||||
"VLLM_FLASH_ATTN_VERSION": "3",
|
"flash_attn_version": 3,
|
||||||
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
|
"flash_attn_max_num_splits_for_cuda_graph": 16,
|
||||||
},
|
},
|
||||||
comp_config={
|
comp_config={
|
||||||
"cudagraph_mode": "FULL",
|
"cudagraph_mode": "FULL",
|
||||||
@ -272,9 +272,7 @@ full_cg_backend_configs = {
|
|||||||
# FlashMLA on Hopper
|
     # FlashMLA on Hopper
     "FlashMLA": BackendConfig(
         name="FlashMLA",
-        env_vars={
-            "VLLM_ATTENTION_BACKEND": "FLASHMLA",
-        },
+        attention_config={"backend": "FLASHMLA"},
         comp_config={
             "cudagraph_mode": "FULL_AND_PIECEWISE",
         },
@@ -283,9 +281,7 @@ full_cg_backend_configs = {
     # Cutlass MLA on Blackwell
     "CutlassMLA": BackendConfig(
         name="CutlassMLA",
-        env_vars={
-            "VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
-        },
+        attention_config={"backend": "CUTLASS_MLA"},
         comp_config={
             "cudagraph_mode": "FULL_AND_PIECEWISE",
         },
@@ -294,9 +290,7 @@ full_cg_backend_configs = {
     # FlashInfer MLA on Blackwell
     "FlashInferMLA": BackendConfig(
         name="FlashInferMLA",
-        env_vars={
-            "VLLM_ATTENTION_BACKEND": "FLASHINFER_MLA",
-        },
+        attention_config={"backend": "FLASHINFER_MLA"},
         comp_config={
             "cudagraph_mode": "FULL_AND_PIECEWISE",
         },
@@ -305,9 +299,9 @@ full_cg_backend_configs = {
     # FlashAttention MLA on Hopper
     "FlashAttentionMLA": BackendConfig(
         name="FlashAttentionMLA",
-        env_vars={
-            "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
-            "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
+        attention_config={
+            "backend": "FLASH_ATTN_MLA",
+            "flash_attn_max_num_splits_for_cuda_graph": 16,
         },
         comp_config={
             "cudagraph_mode": "FULL_DECODE_ONLY",
@@ -317,10 +311,10 @@ full_cg_backend_configs = {
     # FA2
     "FA2": BackendConfig(
         name="FA2",
-        env_vars={
-            "VLLM_ATTENTION_BACKEND": "FLASH_ATTN",
-            "VLLM_FLASH_ATTN_VERSION": "2",
-            "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
+        attention_config={
+            "backend": "FLASH_ATTN",
+            "flash_attn_version": 2,
+            "flash_attn_max_num_splits_for_cuda_graph": 16,
         },
         comp_config={
             "cudagraph_mode": "FULL_AND_PIECEWISE",
@@ -329,7 +323,7 @@ full_cg_backend_configs = {
     # Triton Attention
     "TritonAttn": BackendConfig(
         name="TritonAttn",
-        env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN"},
+        attention_config={"backend": "TRITON_ATTN"},
         comp_config={
             "cudagraph_mode": "FULL_AND_PIECEWISE",
         },
@@ -337,14 +331,17 @@ full_cg_backend_configs = {
     # FlashInfer
     "FlashInfer": BackendConfig(
         name="FlashInfer",
-        env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
+        attention_config={"backend": "FLASHINFER"},
         comp_config={
             "cudagraph_mode": "FULL_AND_PIECEWISE",
         },
     ),
     "RocmAttn": BackendConfig(
         name="RocmAttn",
-        env_vars={"VLLM_V1_USE_PREFILL_DECODE_ATTENTION": "1"},
+        attention_config={
+            "backend": "ROCM_ATTN",
+            "use_prefill_decode_attention": True,
+        },
         comp_config={
             "cudagraph_mode": "FULL",
         },
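Note (sketch, not part of the diff): each entry above is meant to be forwarded to the engine as-is. Assuming the full_cg_backend_configs mapping and the BackendConfig.attention_config attribute shown above, a consumer looks roughly like:

    from vllm import LLM

    # Pick one of the configs defined above and pass its attention settings
    # straight to the constructor; no environment variable is involved anymore.
    cfg = full_cg_backend_configs["TritonAttn"]
    llm = LLM(
        model="Qwen/Qwen2-1.5B-Instruct",  # illustrative model, not from the diff
        attention_config=cfg.attention_config,
    )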
@@ -1,7 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import contextlib
-import os
 import weakref
 from contextlib import ExitStack
 
@@ -13,26 +11,6 @@ from vllm import LLM
 from vllm.config import CompilationConfig, CompilationMode
 from vllm.platforms import current_platform
 
 
-@contextlib.contextmanager
-def temporary_environ(env_vars):
-    """
-    Temporarily set environment variables and restore them afterward.
-    We have to do this vs monkeypatch because monkeypatch doesn't work
-    with "module" scoped fixtures.
-    """
-    original_env = {k: os.environ.get(k) for k in env_vars}
-    try:
-        os.environ.update(env_vars)
-        yield
-    finally:
-        for k, v in original_env.items():
-            if v is None:
-                os.environ.pop(k, None)
-            else:
-                os.environ[k] = v
-
-
 # test attention backend and cudagraph_mode combo
 # (backend_name, cudagraph_mode, supported)
 if current_platform.is_rocm():
@@ -68,9 +46,9 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte
     ):
         pytest.skip("Only Hopper GPUs support FA3 and FlashMLA")
 
-    env_vars = backend_configs[backend_name].env_vars
+    attention_config = backend_config.attention_config
 
-    with temporary_environ(env_vars), ExitStack() as stack:
+    with ExitStack() as stack:
         if not supported:
             stack.enter_context(pytest.raises(Exception))
 
@@ -80,6 +58,7 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte
             trust_remote_code=True,
             gpu_memory_utilization=0.45,
             max_model_len=1024,
+            attention_config=attention_config,
             compilation_config=CompilationConfig(
                 mode=CompilationMode.VLLM_COMPILE, cudagraph_mode=cudagraph_mode
             ),
@@ -122,9 +101,10 @@ combo_cases_2 = [
 def test_cudagraph_compilation_combo(
     backend_name, cudagraph_mode, compilation_mode, supported
 ):
-    env_vars = backend_configs[backend_name].env_vars
+    backend_config = backend_configs[backend_name]
+    attention_config = backend_config.attention_config
 
-    with temporary_environ(env_vars), ExitStack() as stack:
+    with ExitStack() as stack:
         if not supported:
             stack.enter_context(pytest.raises(Exception))
 
@@ -134,6 +114,7 @@ def test_cudagraph_compilation_combo(
             trust_remote_code=True,
             gpu_memory_utilization=0.45,
             max_model_len=1024,
+            attention_config=attention_config,
             compilation_config=CompilationConfig(
                 mode=compilation_mode, cudagraph_mode=cudagraph_mode
             ),
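Note (sketch, not from the diff): temporary_environ() existed only to scope VLLM_ATTENTION_BACKEND around each module-scoped test. With the backend passed as a constructor argument there is no process-global state to save and restore, so the helper and its os/contextlib imports can go. Roughly:

    from vllm import LLM

    # old: with temporary_environ({"VLLM_ATTENTION_BACKEND": "FLASH_ATTN"}): llm = LLM(...)
    # new: the override travels with the engine instance itself.
    llm = LLM(
        model="Qwen/Qwen2-1.5B-Instruct",  # illustrative
        attention_config={"backend": "FLASH_ATTN"},
    )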
@@ -28,7 +28,7 @@ IS_DEVICE_CAPABILITY_BELOW_90 = is_device_capability_below_90()
     BACKENDS,
 )
 def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
-    backend, monkeypatch: pytest.MonkeyPatch
+    backend,
 ):
     """
     Ensures that the same request (the 'needle' prompt) yields identical output
@@ -54,7 +54,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
     seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
     random.seed(seed)
 
-    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
+    attention_config = {"backend": backend}
     # Allow overrides from environment (useful for CI tuning)
     # "facebook/opt-125m" is too small, doesn't reliably test determinism
     model = resolve_model_name(backend)
@@ -92,6 +92,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
         max_num_seqs=max_batch_size,
         gpu_memory_utilization=gpu_mem_util,
         max_model_len=max_model_len,
+        attention_config=attention_config,
     )
 
     # Baseline generation for the needle prompt alone.
@@ -106,6 +107,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
         max_num_seqs=max_batch_size,
         gpu_memory_utilization=gpu_mem_util,
         max_model_len=max_model_len,
+        attention_config=attention_config,
     )
 
     mismatches = 0
@@ -163,10 +165,8 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
     BACKENDS,
 )
 def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
-    backend, monkeypatch: pytest.MonkeyPatch
+    backend,
 ):
-    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
-
     seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
     random.seed(seed)
     model_name = resolve_model_name(backend)
@@ -193,6 +193,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
         dtype="bfloat16",  # not everything is supported
         gpu_memory_utilization=0.9,
         enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
+        attention_config={"backend": backend},
     )
 
     # Use more realistic prompts for better token generation
@@ -381,12 +382,11 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
     "backend",
     BACKENDS,
 )
-def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
+def test_simple_generation(backend):
     """
     Simple test that runs the model with a basic prompt and prints the output.
     Useful for quick smoke testing and debugging.
     """
-    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
     model = resolve_model_name(backend)
 
     llm = LLM(
@@ -398,6 +398,7 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
         dtype="bfloat16",
         enable_prefix_caching=False,
         enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
+        attention_config={"backend": backend},
     )
 
     prompt = "the capital of france is"
@@ -444,8 +445,6 @@ def test_logprobs_without_batch_invariance_should_fail(
     The test will PASS if we detect differences (proving batch invariance matters).
     The test will FAIL if everything matches (suggesting batch invariance isn't needed).
     """
-    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
-
     # CRITICAL: Disable batch invariance for this test
     monkeypatch.setenv("VLLM_BATCH_INVARIANT", "0")
     monkeypatch.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", False)
@@ -465,6 +464,7 @@ def test_logprobs_without_batch_invariance_should_fail(
         max_model_len=8192,
         dtype="bfloat16",
         enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
+        attention_config={"backend": backend},
     )
 
     # build ragged prompts to change shapes significantly across BS=1 vs BS=N
@@ -649,7 +649,7 @@ def test_logprobs_without_batch_invariance_should_fail(
 @skip_unsupported
 @pytest.mark.parametrize("backend", ["FLASH_ATTN"])
 def test_decode_logprobs_match_prefill_logprobs(
-    backend, monkeypatch: pytest.MonkeyPatch
+    backend,
 ):
     """
     Test that verifies decode logprobs match prefill logprobs.
@@ -664,8 +664,6 @@ def test_decode_logprobs_match_prefill_logprobs(
     This ensures that the logprobs from decode are consistent with what
     we would get if we ran prefill on each prefix.
     """
-    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
-
     seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
     random.seed(seed)
     model_name = resolve_model_name(backend)
@@ -689,6 +687,7 @@ def test_decode_logprobs_match_prefill_logprobs(
         max_model_len=8192,
         dtype="bfloat16",
         enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
+        attention_config={"backend": backend},
     )
 
     # Use a few test prompts
@@ -920,6 +919,7 @@ def LLM_with_max_seqs(
     max_num_seqs: int,
     gpu_memory_utilization: float,
     max_model_len: int,
+    attention_config: dict | None = None,
 ) -> LLM:
     """
     Helper to construct an LLM with a specific max_num_seqs (batch-size limit)
@@ -934,6 +934,7 @@ def LLM_with_max_seqs(
         tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
         enable_prefix_caching=False,
         enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
+        attention_config=attention_config,
         # Enable for MOE models
         # enable_expert_parallel=True,
     )
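Note (sketch, not from the diff): the determinism tests above now build the engine with the backend as an explicit keyword. Combining the arguments that appear in the hunks, a typical construction reads roughly:

    from vllm import LLM

    backend = "FLASH_ATTN"  # supplied by @pytest.mark.parametrize in the real tests
    llm = LLM(
        model="Qwen/Qwen2-1.5B-Instruct",  # the tests use resolve_model_name(backend)
        max_num_seqs=32,
        gpu_memory_utilization=0.9,
        max_model_len=8192,
        attention_config={"backend": backend},
    )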
@@ -136,11 +136,9 @@ def _compare_bs1_vs_bsn_single_process(
 @skip_unsupported
 @pytest.mark.parametrize("backend", BACKENDS)
 def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
-    backend: str, monkeypatch: pytest.MonkeyPatch
+    backend: str,
 ) -> None:
     random.seed(int(os.getenv("VLLM_TEST_SEED", "12345")))
-    # Override backend for this test (and the RemoteOpenAIServer child process).
-    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
     model_name = resolve_model_name(backend)
     prompts_all = [_random_prompt(10, 50) for _ in range(32)]
 
@@ -156,6 +154,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
     server_args: list[str] = [
         "--max-model-len=8192",
         "--max-num-seqs=32",
+        f"--attention-backend={backend}",
     ]
     if tp_size:
         server_args += ["-tp", tp_size]
@@ -142,16 +142,17 @@ def run_tests(
     """Test consistency of combos of async scheduling, preemption,
     uni/multiproc executor with spec decoding."""
 
-    with monkeypatch.context() as m:
-        # avoid precision errors
-        if current_platform.is_rocm():
-            if is_testing_with_spec_decoding:
-                # Use TRITON_ATTN for spec decoding test for consistency
-                m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
-            else:
-                m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_AITER_FA")
-        else:
-            m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
+    # Determine attention config based on platform
+    if current_platform.is_rocm():
+        if is_testing_with_spec_decoding:
+            # Use TRITON_ATTN for spec decoding test for consistency
+            attention_config = {"backend": "TRITON_ATTN"}
+        else:
+            attention_config = {"backend": "ROCM_AITER_FA"}
+    else:
+        attention_config = {"backend": "FLEX_ATTENTION"}
+
+    with monkeypatch.context() as m:
         # lock matmul precision to full FP32 (IEEE)
         m.setenv("VLLM_FLOAT32_MATMUL_PRECISION", "ieee")
         # m.setenv("VLLM_BATCH_INVARIANT", "1")
@@ -174,6 +175,7 @@ def run_tests(
                 spec_config,
                 test_prefill_chunking=test_prefill_chunking,
                 is_testing_with_spec_decoding=is_testing_with_spec_decoding,
+                attention_config=attention_config,
             )
             outputs.append(test_results)
 
@@ -262,6 +264,7 @@ def run_test(
     spec_config: dict[str, Any] | None,
     test_prefill_chunking: bool,
     is_testing_with_spec_decoding: bool = False,
+    attention_config: dict[str, Any] | None = None,
 ):
     spec_decoding = spec_config is not None
     cache_arg: dict[str, Any] = (
@@ -301,6 +304,7 @@ def run_test(
         dtype=dtype,
         speculative_config=spec_config,
         disable_log_stats=False,
+        attention_config=attention_config,
         **cache_arg,
     ) as vllm_model:
         results = []
@@ -10,7 +10,7 @@ from ...utils import create_new_process_for_each_test
 
 @create_new_process_for_each_test()
 @pytest.mark.parametrize("attn_backend", ["FLASH_ATTN", "FLASHINFER"])
-def test_cascade_attention(example_system_message, monkeypatch, attn_backend):
+def test_cascade_attention(example_system_message, attn_backend):
     prompt = "\n<User>: Implement fibonacci sequence in Python.\n<Claude>:"
 
     if attn_backend == "FLASHINFER":
@@ -19,19 +19,18 @@ def test_cascade_attention(example_system_message, monkeypatch, attn_backend):
             "needs investigation. See issue #25679."
         )
 
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
-
-        llm = LLM(model="Qwen/Qwen2-1.5B-Instruct")
-        sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
-
-        # No cascade attention.
-        single_prompt = [example_system_message + prompt]
-        responses = llm.generate(single_prompt, sampling_params)
-        ref_output = responses[0].outputs[0].text
-
-        # (Probably) Use cascade attention.
-        prompts = [example_system_message + prompt] * 64
-        responses = llm.generate(prompts, sampling_params)
-        for response in responses:
-            assert response.outputs[0].text == ref_output
+    llm = LLM(
+        model="Qwen/Qwen2-1.5B-Instruct", attention_config={"backend": attn_backend}
+    )
+    sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
+
+    # No cascade attention.
+    single_prompt = [example_system_message + prompt]
+    responses = llm.generate(single_prompt, sampling_params)
+    ref_output = responses[0].outputs[0].text
+
+    # (Probably) Use cascade attention.
+    prompts = [example_system_message + prompt] * 64
+    responses = llm.generate(prompts, sampling_params)
+    for response in responses:
+        assert response.outputs[0].text == ref_output
@@ -438,25 +438,26 @@ def test_eagle_correctness(
     should be the same when using eagle speculative decoding.
     model_setup: (method, model_name, eagle_model_name, tp_size)
     """
+    # Determine attention config
+    # Scout requires default backend selection because vision encoder has
+    # head_dim 88 being incompatible with FLASH_ATTN and needs to fall back
+    # to Flex Attn
+    if "Llama-4-Scout" in model_setup[1] and attn_backend == "FLASH_ATTN":
+        if current_platform.is_rocm():
+            # TODO: Enable Flex Attn for spec_decode on ROCm
+            pytest.skip("Flex Attn for spec_decode not supported on ROCm currently")
+        attention_config = None  # Let it fall back to default
+    else:
+        attention_config = {"backend": attn_backend}
+
+    if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
+        pytest.skip(
+            "TRITON_ATTN does not support "
+            "multi-token eagle spec decode on current platform"
+        )
+
     with monkeypatch.context() as m:
-        if "Llama-4-Scout" in model_setup[1] and attn_backend == "FLASH_ATTN":
-            # Scout requires default backend selection
-            # because vision encoder has head_dim 88 being incompatible
-            # with FLASH_ATTN and needs to fall back to Flex Attn
-
-            # pass if not ROCm
-            if current_platform.is_rocm():
-                # TODO: Enable Flex Attn for spec_decode on ROCm
-                pytest.skip("Flex Attn for spec_decode not supported on ROCm currently")
-            else:
-                m.setenv("VLLM_MLA_DISABLE", "1")
-            m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
-
-        if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
-            pytest.skip(
-                "TRITON_ATTN does not support "
-                "multi-token eagle spec decode on current platform"
-            )
+        m.setenv("VLLM_MLA_DISABLE", "1")
 
         if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
             if "deepseek" in model_setup[1].lower():
@@ -471,7 +472,10 @@ def test_eagle_correctness(
         max_num_batched_tokens = 128 if enable_chunked_prefill else max_model_len
 
         ref_llm = LLM(
-            model=model_name, max_model_len=max_model_len, tensor_parallel_size=tp_size
+            model=model_name,
+            max_model_len=max_model_len,
+            tensor_parallel_size=tp_size,
+            attention_config=attention_config,
         )
         ref_outputs = ref_llm.chat(test_prompts, sampling_config)
         del ref_llm
@@ -492,6 +496,7 @@ def test_eagle_correctness(
             max_num_batched_tokens=max_num_batched_tokens,
             enable_chunked_prefill=enable_chunked_prefill,
             model_impl=model_impl,
+            attention_config=attention_config,
         )
         spec_outputs = spec_llm.chat(test_prompts, sampling_config)
         matches = 0
@@ -3,21 +3,29 @@ set -xe
 
 # Parse command line arguments
 KV_BUFFER_DEVICE="cuda"  # Default to cuda
+ATTENTION_BACKEND=""  # Default to empty (use vllm default)
 while [[ $# -gt 0 ]]; do
     case $1 in
         --kv_buffer_device)
             KV_BUFFER_DEVICE="$2"
             shift 2
             ;;
+        --attention-backend)
+            ATTENTION_BACKEND="$2"
+            shift 2
+            ;;
         *)
             echo "Unknown option $1"
-            echo "Usage: $0 [--kv_buffer_device <cuda|cpu>]"
+            echo "Usage: $0 [--kv_buffer_device <cuda|cpu>] [--attention-backend <backend>]"
            exit 1
             ;;
     esac
 done
 
 echo "Running accuracy tests with kv_buffer_device=$KV_BUFFER_DEVICE"
+if [[ -n "$ATTENTION_BACKEND" ]]; then
+    echo "Using attention backend: $ATTENTION_BACKEND"
+fi
 
 DECODER_KV_LAYOUT=${DECODER_KV_LAYOUT:-"HND"}  # Default to HND, optional NHD
 if [[ "$DECODER_KV_LAYOUT" == "NHD" ]]; then
@@ -148,6 +156,11 @@ run_tests_for_model() {
         --tensor-parallel-size $PREFILLER_TP_SIZE \
         --kv-transfer-config '$KV_CONFIG'"
 
+    # Add attention backend config if specified
+    if [[ -n "$ATTENTION_BACKEND" ]]; then
+        BASE_CMD="${BASE_CMD} --attention-backend=$ATTENTION_BACKEND"
+    fi
+
     if [ -n "$model_args" ]; then
         FULL_CMD="$BASE_CMD $model_args"
     else
@@ -188,7 +201,12 @@ run_tests_for_model() {
         --block-size ${DECODE_BLOCK_SIZE} \
         --gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
         --kv-transfer-config '$KV_CONFIG'"
 
+    # Add attention backend config if specified
+    if [[ -n "$ATTENTION_BACKEND" ]]; then
+        BASE_CMD="${BASE_CMD} --attention-backend=$ATTENTION_BACKEND"
+    fi
+
     # DP-EP attention mode
     if [[ -z "$DP_EP" ]]; then
         BASE_CMD="${BASE_CMD} --tensor-parallel-size $DECODER_TP_SIZE"
@@ -15,14 +15,14 @@ configs=(
 
 run_tests() {
     local label=$1
-    local extra_env=$2
+    local extra_args=$2
 
     echo "=== Running tests (${label}) ==="
     for cfg in "${configs[@]}"; do
-        echo "-> Running with ${cfg} ${extra_env:+and ${extra_env}}"
+        echo "-> Running with ${cfg} ${extra_args:+and ${extra_args}}"
         # Use 'env' to safely set variables without eval
-        if ! env ${extra_env} ${cfg} bash "${SCRIPT}"; then
-            echo "❌ Test failed for config: ${cfg} ${extra_env:+(${extra_env})}"
+        if ! env ${cfg} bash "${SCRIPT}" ${extra_args}; then
+            echo "❌ Test failed for config: ${cfg} ${extra_args:+(${extra_args})}"
             exit 1
         fi
     done
@@ -34,8 +34,8 @@ run_tests "default backend" ""
 
 # Check if FLASHINFER is set (non-empty)
 if [[ -n "${FLASHINFER:-}" ]]; then
-    echo "FLASHINFER is set, rerunning with VLLM_ATTENTION_BACKEND=FLASHINFER"
-    run_tests "FLASHINFER backend" "VLLM_ATTENTION_BACKEND=FLASHINFER"
+    echo "FLASHINFER is set, rerunning with --attention-backend FLASHINFER"
+    run_tests "FLASHINFER backend" "--attention-backend FLASHINFER"
 else
     echo "FLASHINFER not set, skipping FLASHINFER runs."
 fi
@@ -1132,7 +1132,7 @@ def _run_abort_timeout_test(llm: LLM, timeout: int):
         "TRITON_ATTN",
     ],
 )
-def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
+def test_register_kv_caches(dist_init, attn_backend):
     """
     Test that register_kv_caches() properly calls nixl_wrapper methods with
     correct data.
@@ -1144,9 +1144,7 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
     block layout info
     """
 
-    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
-
-    vllm_config = create_vllm_config()
+    vllm_config = create_vllm_config(attention_backend=attn_backend)
 
     # Import the appropriate backend based on the parameter
     if attn_backend == "FLASH_ATTN":
@@ -11,6 +11,7 @@ import torch
 
 from vllm import SamplingParams
 from vllm.config import (
+    AttentionConfig,
     CacheConfig,
     DeviceConfig,
     KVTransferConfig,
@@ -94,6 +95,7 @@ def create_vllm_config(
     dtype: str = "float16",
     cache_dtype: str = "auto",
     hf_overrides: dict[str, Any] | None = None,
+    attention_backend: str | None = None,
 ) -> VllmConfig:
     """Initialize VllmConfig For Testing."""
     model_config = ModelConfig(
@@ -124,12 +126,14 @@ def create_vllm_config(
         enable_permute_local_kv=enable_permute_local_kv,
         kv_connector_extra_config=kv_connector_extra_config or {},
     )
+    attention_config = AttentionConfig(backend=attention_backend)
     return VllmConfig(
         scheduler_config=scheduler_config,
         model_config=model_config,
         cache_config=cache_config,
         kv_transfer_config=kv_transfer_config,
         device_config=DeviceConfig("cpu"),
+        attention_config=attention_config,
     )
 
 
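Note (sketch, not from the diff): with the new attention_backend parameter, a test can request a specific backend when building the config instead of patching the environment first, e.g.

    # Within the same test-utils module; leaving attention_backend at its default
    # of None presumably keeps vLLM's automatic backend selection.
    vllm_config = create_vllm_config(attention_backend="FLASH_ATTN")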
@@ -13,7 +13,6 @@ from vllm import LLM, SamplingParams, TokensPrompt
 from vllm.config import KVEventsConfig, KVTransferConfig
 from vllm.distributed.kv_events import BlockStored, KVEventBatch
 from vllm.platforms import current_platform
-from vllm.utils.system_utils import set_env_var
 
 CPU_BLOCK_SIZES = [48]
 ATTN_BACKENDS = ["FLASH_ATTN"]
@@ -180,13 +179,13 @@ def test_cpu_offloading(cpu_block_size: int, attn_backend: str) -> None:
         topic="test",
     )
 
-    with set_env_var("VLLM_ATTENTION_BACKEND", attn_backend):
-        llm = LLM(
-            model="meta-llama/Llama-3.2-1B-Instruct",
-            gpu_memory_utilization=0.5,
-            kv_events_config=kv_events_config,
-            kv_transfer_config=kv_transfer_config,
-        )
+    llm = LLM(
+        model="meta-llama/Llama-3.2-1B-Instruct",
+        gpu_memory_utilization=0.5,
+        kv_events_config=kv_events_config,
+        kv_transfer_config=kv_transfer_config,
+        attention_config={"backend": attn_backend},
+    )
 
     events_endpoint = events_endpoint.replace("*", "127.0.0.1")
     subscriber = MockSubscriber(events_endpoint, topic=kv_events_config.topic)
@@ -15,6 +15,7 @@ from tests.v1.attention.utils import (
 )
 from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.config import (
+    AttentionConfig,
     CacheConfig,
     DeviceConfig,
     ModelConfig,
@@ -38,6 +39,7 @@ eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
 def _create_proposer(
     method: str,
     num_speculative_tokens: int,
+    attention_backend: str | None = None,
     speculative_token_tree: list[tuple[int, ...]] | None = None,
 ) -> EagleProposer:
     model_config = ModelConfig(model=model_dir, runner="generate", max_model_len=100)
@@ -70,6 +72,7 @@ def _create_proposer(
             max_model_len=model_config.max_model_len,
             is_encoder_decoder=model_config.is_encoder_decoder,
         ),
+        attention_config=AttentionConfig(backend=attention_backend),
     )
 
     return EagleProposer(vllm_config=vllm_config, device=current_platform.device_type)
@@ -331,8 +334,6 @@ def test_load_model(
     use_distinct_lm_head,
     monkeypatch,
 ):
-    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
-
     if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
         pytest.skip(
             "TRITON_ATTN does not support "
@@ -394,7 +395,9 @@ def test_load_model(
     assert not isinstance(target_model, SupportsMultiModal)
 
     # Create proposer using the helper function
-    proposer = _create_proposer(method, num_speculative_tokens=8)
+    proposer = _create_proposer(
+        method, num_speculative_tokens=8, attention_backend=attn_backend
+    )
 
     # Call the method under test
     proposer.load_model(target_model)
@@ -420,8 +423,6 @@ def test_load_model(
 @pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
 @pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
 def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
-    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
-
     if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
         pytest.skip(
             "TRITON_ATTN does not support "
@@ -449,7 +450,9 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
     seq_lens = [seq_len_1, seq_len_2]
 
     # Create proposer first so we can use its actual hidden_size
-    proposer = _create_proposer("eagle", num_speculative_tokens)
+    proposer = _create_proposer(
+        "eagle", num_speculative_tokens, attention_backend=attn_backend
+    )
     # Get the hidden_size from the proposer to ensure consistency
     hidden_size = proposer.hidden_size
 
@@ -622,7 +625,9 @@ def test_propose_tree(spec_token_tree):
 
     # Create proposer first so we can use its actual hidden_size.
     proposer = _create_proposer(
-        "eagle", num_speculative_tokens, speculative_token_tree=spec_token_tree
+        "eagle",
+        num_speculative_tokens,
+        speculative_token_tree=spec_token_tree,
     )
     # Get the hidden_size from the proposer to ensure consistency.
     hidden_size = proposer.hidden_size
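Note (sketch, not from the diff): the updated helper lets a test pin the drafter's attention backend explicitly, mirroring the call sites above:

    proposer = _create_proposer(
        "eagle",
        num_speculative_tokens=8,
        attention_backend="FLASH_ATTN",  # illustrative; tests pass the parametrized attn_backend
    )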
@@ -38,53 +38,48 @@ def test_ngram_max_len(num_speculative_tokens: int):
 def test_eagle_max_len(
     monkeypatch: pytest.MonkeyPatch, num_speculative_tokens: int, attn_backend: str
 ):
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
-
-        if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
-            pytest.skip(
-                "TRITON_ATTN does not support "
-                "multi-token eagle spec decode on current platform"
-            )
-
-        if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
-            m.setenv("VLLM_ROCM_USE_AITER", "1")
-
-        llm = LLM(
-            model="meta-llama/Meta-Llama-3-8B-Instruct",
-            enforce_eager=True,  # For faster initialization.
-            speculative_config={
-                "method": "eagle",
-                "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
-                "num_speculative_tokens": num_speculative_tokens,
-                "max_model_len": 80,
-            },
-            max_model_len=200,
-        )
-        sampling_params = SamplingParams(max_tokens=200, ignore_eos=True)
-        outputs = llm.generate(_PROMPTS, sampling_params)
-        for o in outputs:
-            assert o.outputs[0].finish_reason == "length", (
-                "This test is only meaningful if the output "
-                "is truncated due to max length"
-            )
-
-        sampling_params = SamplingParams(
-            max_tokens=200,
-            structured_outputs=StructuredOutputsParams(
-                regex="^" + "a b c d e " * 15 + "$"
-            ),
-        )
-        output = llm.generate(_PROMPTS, sampling_params)
-        for o in output:
-            assert o.prompt_token_ids is not None
-            assert (
-                len(o.prompt_token_ids)
-                < 80
-                < len(o.prompt_token_ids) + len(o.outputs[0].token_ids)
-                <= 200
-            ), (
-                "This test is only meaningful if the output "
-                "is longer than the eagle max length"
-            )
-            assert o.outputs[0].text == "a b c d e " * 15
+    if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
+        pytest.skip(
+            "TRITON_ATTN does not support "
+            "multi-token eagle spec decode on current platform"
+        )
+
+    if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
+        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
+
+    llm = LLM(
+        model="meta-llama/Meta-Llama-3-8B-Instruct",
+        enforce_eager=True,  # For faster initialization.
+        speculative_config={
+            "method": "eagle",
+            "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
+            "num_speculative_tokens": num_speculative_tokens,
+            "max_model_len": 80,
+        },
+        max_model_len=200,
+        attention_config={"backend": attn_backend},
+    )
+    sampling_params = SamplingParams(max_tokens=200, ignore_eos=True)
+    outputs = llm.generate(_PROMPTS, sampling_params)
+    for o in outputs:
+        assert o.outputs[0].finish_reason == "length", (
+            "This test is only meaningful if the output is truncated due to max length"
+        )
+
+    sampling_params = SamplingParams(
+        max_tokens=200,
+        structured_outputs=StructuredOutputsParams(regex="^" + "a b c d e " * 15 + "$"),
+    )
+    output = llm.generate(_PROMPTS, sampling_params)
+    for o in output:
+        assert o.prompt_token_ids is not None
+        assert (
+            len(o.prompt_token_ids)
+            < 80
+            < len(o.prompt_token_ids) + len(o.outputs[0].token_ids)
+            <= 200
+        ), (
+            "This test is only meaningful if the output "
+            "is longer than the eagle max length"
+        )
+        assert o.outputs[0].text == "a b c d e " * 15
@@ -165,7 +165,7 @@ class RocmAttentionBackend(AttentionBackend):
             raise ValueError(
                 f"Head size {head_size} is not supported by {attn_type}. "
                 f"Supported head sizes are: {cls.get_supported_head_sizes()}. "
-                "Set --attention-config.backend=FLEX_ATTENTION to use "
+                "Set --attention-backend=FLEX_ATTENTION to use "
                 "FlexAttention backend which supports all head sizes."
             )
 
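Note (sketch, not from the diff): the corrected message points users at the CLI flag; the offline-API equivalent used throughout the tests in this change would be roughly

    from vllm import LLM

    llm = LLM(
        model="Qwen/Qwen2-1.5B-Instruct",  # illustrative
        attention_config={"backend": "FLEX_ATTENTION"},
    )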