mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-21 03:15:29 +08:00
[CPU] Support for Whisper (#30062)
Signed-off-by: Aditya Tewari <aditya.tewari@arm.com>
This commit is contained in:
parent
53d2420b44
commit
cebda2a4af
@ -36,6 +36,11 @@ function cpu_tests() {
|
||||
set -e
|
||||
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m"
|
||||
|
||||
# Run model tests
|
||||
docker exec cpu-test bash -c "
|
||||
set -e
|
||||
pytest -x -v -s tests/models/multimodal/generation/test_whisper.py -m cpu_model"
|
||||
|
||||
# Run kernel tests
|
||||
docker exec cpu-test bash -c "
|
||||
set -e
|
||||
|
||||
@ -117,7 +117,6 @@ torch::Tensor get_scheduler_metadata(
|
||||
input.casual = casual;
|
||||
input.isa = isa;
|
||||
input.enable_kv_split = enable_kv_split;
|
||||
TORCH_CHECK(casual, "Only supports casual mask for now.");
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(dtype, "get_scheduler_metadata", [&]() {
|
||||
CPU_ATTN_DISPATCH_CASE_HEADDIM(head_dim, [&] {
|
||||
|
||||
@ -92,13 +92,14 @@ def run_test(
|
||||
*,
|
||||
tensor_parallel_size: int,
|
||||
distributed_executor_backend: str | None = None,
|
||||
dtype: str = "half",
|
||||
) -> None:
|
||||
prompt_list = PROMPTS * 10
|
||||
expected_list = EXPECTED[model] * 10
|
||||
|
||||
with vllm_runner(
|
||||
model,
|
||||
dtype="half",
|
||||
dtype=dtype,
|
||||
max_model_len=448,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
@ -120,12 +121,28 @@ def run_test(
|
||||
|
||||
@pytest.mark.core_model
|
||||
@pytest.mark.parametrize("model", ["openai/whisper-large-v3-turbo"])
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@create_new_process_for_each_test()
|
||||
def test_models(vllm_runner, model) -> None:
|
||||
def test_models(vllm_runner, model, dtype) -> None:
|
||||
run_test(
|
||||
vllm_runner,
|
||||
model,
|
||||
tensor_parallel_size=1,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.cpu_model
|
||||
@pytest.mark.parametrize("model", ["openai/whisper-large-v3-turbo"])
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
def test_models_cpu(vllm_runner, model, dtype) -> None:
|
||||
# @create_new_process_for_each_test() does not work for some runners
|
||||
# TODO: to fix cpu privilege issues in run-cpu-test-arm.sh
|
||||
run_test(
|
||||
vllm_runner,
|
||||
model,
|
||||
tensor_parallel_size=1,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -21,7 +21,7 @@ from vllm.v1.attention.backends.utils import (
|
||||
CommonAttentionMetadata,
|
||||
split_decodes_and_prefills,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec, CrossAttentionSpec
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -50,11 +50,13 @@ class CPUAttentionBackend(AttentionBackend):
|
||||
|
||||
@classmethod
|
||||
def supports_attn_type(cls, attn_type: str) -> bool:
|
||||
"""CPU attention supports decoder and encoder-only attention."""
|
||||
"""CPU attention supports decoder,
|
||||
encoder-only and encoder-decoder attention."""
|
||||
return attn_type in (
|
||||
AttentionType.DECODER,
|
||||
AttentionType.ENCODER,
|
||||
AttentionType.ENCODER_ONLY,
|
||||
AttentionType.ENCODER_DECODER,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@ -136,6 +138,7 @@ class CPUAttentionMetadataBuilder(AttentionMetadataBuilder[CPUAttentionMetadata]
|
||||
self.window_size = -1
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
self.isa = _get_attn_isa(self.dtype, self.block_size)
|
||||
self.is_cross_attention = isinstance(kv_cache_spec, CrossAttentionSpec)
|
||||
|
||||
def build(
|
||||
self,
|
||||
@ -151,7 +154,7 @@ class CPUAttentionMetadataBuilder(AttentionMetadataBuilder[CPUAttentionMetadata]
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
block_table_tensor = common_attn_metadata.block_table_tensor
|
||||
slot_mapping = common_attn_metadata.slot_mapping
|
||||
causal = common_attn_metadata.causal
|
||||
causal = False if self.is_cross_attention else common_attn_metadata.causal
|
||||
|
||||
sdpa_start_loc = query_start_loc
|
||||
num_decode_tokens = 0
|
||||
@ -171,22 +174,19 @@ class CPUAttentionMetadataBuilder(AttentionMetadataBuilder[CPUAttentionMetadata]
|
||||
query_start_loc = query_start_loc[: num_decodes + 1]
|
||||
block_table_tensor = block_table_tensor[:num_decodes]
|
||||
|
||||
sheduler_metadata = None
|
||||
if causal:
|
||||
# for decode batch, use the custom kernel
|
||||
sheduler_metadata = ops.cpu_attn_get_scheduler_metadata(
|
||||
num_reqs=num_reqs,
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
head_dim=self.head_dim,
|
||||
seq_lens=seq_lens,
|
||||
dtype=self.dtype,
|
||||
query_start_loc=query_start_loc,
|
||||
causal=causal,
|
||||
sliding_window_size=self.window_size,
|
||||
isa=self.isa,
|
||||
enable_kv_split=True,
|
||||
)
|
||||
sheduler_metadata = ops.cpu_attn_get_scheduler_metadata(
|
||||
num_reqs=num_reqs,
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
head_dim=self.head_dim,
|
||||
seq_lens=seq_lens,
|
||||
dtype=self.dtype,
|
||||
query_start_loc=query_start_loc,
|
||||
causal=causal,
|
||||
sliding_window_size=self.window_size,
|
||||
isa=self.isa,
|
||||
enable_kv_split=True,
|
||||
)
|
||||
|
||||
attn_metadata = CPUAttentionMetadata(
|
||||
isa=self.isa,
|
||||
|
||||
@ -313,8 +313,12 @@ def bind_kv_cache(
|
||||
# TODO - analyze where runner_kv_caches is used and the right
|
||||
# way to ensure it properly reflects multiple attention layers
|
||||
# in the same decoder block.
|
||||
if current_platform.is_cuda_alike() or current_platform.is_xpu():
|
||||
# We know that the GPU runner is not impacted by this
|
||||
if (
|
||||
current_platform.is_cuda_alike()
|
||||
or current_platform.is_xpu()
|
||||
or current_platform.is_cpu()
|
||||
):
|
||||
# We know that the GPU / CPU runner is not impacted by this
|
||||
# case. Some test code depends on runner_kv_caches, but
|
||||
# not in a way that's impacted by ignoring this.
|
||||
pass
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user