[CPU] Support for Whisper (#30062)

Signed-off-by: Aditya Tewari <aditya.tewari@arm.com>
Aditya Tewari 2025-12-10 12:58:42 +00:00 committed by GitHub
parent 53d2420b44
commit cebda2a4af
5 changed files with 49 additions and 24 deletions

View File

@@ -36,6 +36,11 @@ function cpu_tests() {
set -e
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m"
# Run model tests
docker exec cpu-test bash -c "
set -e
pytest -x -v -s tests/models/multimodal/generation/test_whisper.py -m cpu_model"
# Run kernel tests
docker exec cpu-test bash -c "
set -e

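For context on what this new CI step exercises: the added pytest line runs Whisper end to end on the CPU backend. Below is a minimal offline sketch of the same flow, not part of the diff; it assumes the prompt layout used by vLLM's encoder-decoder audio examples, uses the bundled mary_had_lamb audio asset, and matches the dtype the new CPU test parametrizes.

# Minimal sketch (not from this PR): offline Whisper transcription on the CPU backend.
# Assumes the prompt format used by vLLM's encoder-decoder audio examples.
from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset

llm = LLM(
    model="openai/whisper-large-v3-turbo",
    max_model_len=448,   # Whisper's decoder context length, as in the tests
    dtype="half",        # matches the dtype the new CPU test uses
)
outputs = llm.generate(
    {
        "prompt": "<|startoftranscript|>",
        "multi_modal_data": {"audio": AudioAsset("mary_had_lamb").audio_and_sample_rate},
    },
    SamplingParams(temperature=0, max_tokens=200),
)
print(outputs[0].outputs[0].text)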
View File

@@ -117,7 +117,6 @@ torch::Tensor get_scheduler_metadata(
input.causal = causal;
input.isa = isa;
input.enable_kv_split = enable_kv_split;
TORCH_CHECK(causal, "Only supports causal mask for now.");
VLLM_DISPATCH_FLOATING_TYPES(dtype, "get_scheduler_metadata", [&]() {
CPU_ATTN_DISPATCH_CASE_HEADDIM(head_dim, [&] {

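The removed TORCH_CHECK matters because Whisper's decoder cross-attends over the full encoder output, which is inherently non-causal. A small illustration of the difference, using stock PyTorch SDPA rather than the custom CPU kernel (shapes are arbitrary examples):

# Illustration only: causal decoder self-attention vs. non-causal cross-attention.
# Uses plain PyTorch SDPA, not the vLLM CPU kernel.
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 4, 64)          # 4 decoder positions, 8 heads, head_dim 64
k_self = torch.randn(1, 8, 4, 64)
v_self = torch.randn(1, 8, 4, 64)
k_enc = torch.randn(1, 8, 1500, 64)   # Whisper's encoder emits 1500 frames per 30 s clip
v_enc = torch.randn(1, 8, 1500, 64)

# Decoder self-attention: each position may only attend to earlier positions.
self_out = F.scaled_dot_product_attention(q, k_self, v_self, is_causal=True)

# Cross-attention: every decoder position attends to the whole encoder output,
# so the mask must not be causal, hence dropping the causal-only check.
cross_out = F.scaled_dot_product_attention(q, k_enc, v_enc, is_causal=False)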
View File

@@ -92,13 +92,14 @@ def run_test(
*,
tensor_parallel_size: int,
distributed_executor_backend: str | None = None,
dtype: str = "half",
) -> None:
prompt_list = PROMPTS * 10
expected_list = EXPECTED[model] * 10
with vllm_runner(
model,
dtype="half",
dtype=dtype,
max_model_len=448,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
@@ -120,12 +121,28 @@ def run_test(
@pytest.mark.core_model
@pytest.mark.parametrize("model", ["openai/whisper-large-v3-turbo"])
@pytest.mark.parametrize("dtype", ["half"])
@create_new_process_for_each_test()
def test_models(vllm_runner, model) -> None:
def test_models(vllm_runner, model, dtype) -> None:
run_test(
vllm_runner,
model,
tensor_parallel_size=1,
dtype=dtype,
)
@pytest.mark.cpu_model
@pytest.mark.parametrize("model", ["openai/whisper-large-v3-turbo"])
@pytest.mark.parametrize("dtype", ["half"])
def test_models_cpu(vllm_runner, model, dtype) -> None:
# @create_new_process_for_each_test() does not work for some runners
# TODO: fix CPU privilege issues in run-cpu-test-arm.sh
run_test(
vllm_runner,
model,
tensor_parallel_size=1,
dtype=dtype,
)
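Since run_test now takes dtype explicitly, further CPU dtype variants are just another parametrize entry. A hypothetical follow-up, written as if added inside the same test module (so pytest, run_test, and the vllm_runner fixture are already in scope); the test name and bfloat16 support on a given CPU backend are assumptions, not part of this change:

# Hypothetical follow-up, not in this PR: exercising a second dtype on CPU.
# Whether bfloat16 works on a particular CPU backend is an assumption here.
@pytest.mark.cpu_model
@pytest.mark.parametrize("model", ["openai/whisper-large-v3-turbo"])
@pytest.mark.parametrize("dtype", ["half", "bfloat16"])
def test_models_cpu_dtypes(vllm_runner, model, dtype) -> None:
    run_test(
        vllm_runner,
        model,
        tensor_parallel_size=1,
        dtype=dtype,
    )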

View File

@@ -21,7 +21,7 @@ from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.kv_cache_interface import AttentionSpec, CrossAttentionSpec
logger = init_logger(__name__)
@@ -50,11 +50,13 @@ class CPUAttentionBackend(AttentionBackend):
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""CPU attention supports decoder and encoder-only attention."""
"""CPU attention supports decoder,
encoder, encoder-only, and encoder-decoder attention."""
return attn_type in (
AttentionType.DECODER,
AttentionType.ENCODER,
AttentionType.ENCODER_ONLY,
AttentionType.ENCODER_DECODER,
)
@staticmethod
@@ -136,6 +138,7 @@ class CPUAttentionMetadataBuilder(AttentionMetadataBuilder[CPUAttentionMetadata]
self.window_size = -1
self.block_size = vllm_config.cache_config.block_size
self.isa = _get_attn_isa(self.dtype, self.block_size)
self.is_cross_attention = isinstance(kv_cache_spec, CrossAttentionSpec)
def build(
self,
@@ -151,7 +154,7 @@ class CPUAttentionMetadataBuilder(AttentionMetadataBuilder[CPUAttentionMetadata]
seq_lens = common_attn_metadata.seq_lens
block_table_tensor = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
causal = common_attn_metadata.causal
causal = False if self.is_cross_attention else common_attn_metadata.causal
sdpa_start_loc = query_start_loc
num_decode_tokens = 0
@@ -171,22 +174,19 @@ class CPUAttentionMetadataBuilder(AttentionMetadataBuilder[CPUAttentionMetadata]
query_start_loc = query_start_loc[: num_decodes + 1]
block_table_tensor = block_table_tensor[:num_decodes]
scheduler_metadata = None
if causal:
# for decode batch, use the custom kernel
sheduler_metadata = ops.cpu_attn_get_scheduler_metadata(
num_reqs=num_reqs,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
head_dim=self.head_dim,
seq_lens=seq_lens,
dtype=self.dtype,
query_start_loc=query_start_loc,
causal=causal,
sliding_window_size=self.window_size,
isa=self.isa,
enable_kv_split=True,
)
scheduler_metadata = ops.cpu_attn_get_scheduler_metadata(
num_reqs=num_reqs,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
head_dim=self.head_dim,
seq_lens=seq_lens,
dtype=self.dtype,
query_start_loc=query_start_loc,
causal=causal,
sliding_window_size=self.window_size,
isa=self.isa,
enable_kv_split=True,
)
attn_metadata = CPUAttentionMetadata(
isa=self.isa,

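Putting the backend changes together: the metadata builder now identifies cross-attention layers from their KV-cache spec, forces a non-causal mask for them, and builds scheduler metadata for causal and non-causal batches alike. A condensed, self-contained sketch of that decision; only the names shown in the diff are real, and the stand-in classes exist just so the snippet runs on its own:

# Condensed sketch of the builder's causal/cross-attention decision.
# CrossAttentionSpec mirrors vllm.v1.kv_cache_interface; the dataclasses here
# are stand-ins so the snippet is runnable in isolation.
from dataclasses import dataclass

@dataclass
class AttentionSpec:
    pass

@dataclass
class CrossAttentionSpec(AttentionSpec):
    pass

@dataclass
class CommonAttentionMetadata:
    causal: bool = True

def effective_causal(kv_cache_spec: AttentionSpec, common: CommonAttentionMetadata) -> bool:
    # Cross-attention (Whisper decoder attending to encoder states) is never
    # causal, whatever the common metadata says; everything else keeps its flag.
    is_cross_attention = isinstance(kv_cache_spec, CrossAttentionSpec)
    return False if is_cross_attention else common.causal

assert effective_causal(CrossAttentionSpec(), CommonAttentionMetadata()) is False
assert effective_causal(AttentionSpec(), CommonAttentionMetadata(causal=True)) is True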
View File

@@ -313,8 +313,12 @@ def bind_kv_cache(
# TODO - analyze where runner_kv_caches is used and the right
# way to ensure it properly reflects multiple attention layers
# in the same decoder block.
if current_platform.is_cuda_alike() or current_platform.is_xpu():
# We know that the GPU runner is not impacted by this
if (
current_platform.is_cuda_alike()
or current_platform.is_xpu()
or current_platform.is_cpu()
):
# We know that the GPU / CPU runner is not impacted by this
# case. Some test code depends on runner_kv_caches, but
# not in a way that's impacted by ignoring this.
pass