diff --git a/.buildkite/run-tpu-v1-test.sh b/.buildkite/run-tpu-v1-test.sh
index 5b7ce9a7677e3..87f74277cf900 100755
--- a/.buildkite/run-tpu-v1-test.sh
+++ b/.buildkite/run-tpu-v1-test.sh
@@ -1,6 +1,6 @@
 #!/bin/bash
 
-set -e
+set -xue
 
 # Build the docker image.
 docker build -f docker/Dockerfile.tpu -t vllm-tpu .
@@ -38,7 +38,9 @@ docker run --privileged --net host --shm-size=16G -it \
     && echo TEST_7 \
     && pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py \
     && echo TEST_8 \
-    && pytest -s -v /workspace/vllm/tests/v1/tpu/test_topk_topp_sampler.py" \
+    && pytest -s -v /workspace/vllm/tests/v1/tpu/test_topk_topp_sampler.py \
+    && echo TEST_9 \
+    && pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py" \
 
 
 # TODO: This test fails because it uses RANDOM_SEED sampling
diff --git a/tests/entrypoints/llm/test_accuracy.py b/tests/entrypoints/llm/test_accuracy.py
index 2bc32ace0a59d..95657455bd7bb 100644
--- a/tests/entrypoints/llm/test_accuracy.py
+++ b/tests/entrypoints/llm/test_accuracy.py
@@ -13,18 +13,24 @@
 import pytest
 
 from vllm.platforms import current_platform
 
-MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct"
+MODEL_NAMES = [
+    "Qwen/Qwen2-1.5B-Instruct",
+    "google/gemma-3-1b-it",
+]
 NUM_CONCURRENT = 500
 TASK = "gsm8k"
 FILTER = "exact_match,strict-match"
 RTOL = 0.03
-EXPECTED_VALUE = 0.58
+EXPECTED_VALUES = {
+    "Qwen/Qwen2-1.5B-Instruct": 0.58,
+    "google/gemma-3-1b-it": 0.25,
+}
 
 
-def run_test(more_args=None):
+def run_test(model_name, more_args=None):
     """Run the end to end accuracy test."""
-    model_args = f"pretrained={MODEL_NAME},max_model_len=4096"
+    model_args = f"pretrained={model_name},max_model_len=4096"
 
     if more_args is not None:
         model_args = "{},{}".format(model_args, more_args)
@@ -37,9 +43,12 @@ def run_test(more_args=None):
     )
 
     measured_value = results["results"][TASK][FILTER]
-    assert (measured_value - RTOL < EXPECTED_VALUE
-            and measured_value + RTOL > EXPECTED_VALUE
-            ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
+    assert model_name in EXPECTED_VALUES, (
+        f"Cannot find the expected value for the model {model_name=}")
+    expected_value = EXPECTED_VALUES[model_name]
+    assert (measured_value - RTOL < expected_value
+            and measured_value + RTOL > expected_value
+            ), f"Expected: {expected_value} | Measured: {measured_value}"
 
 
 # TODO: [AlexM] Fix it with new CI/CD tests
@@ -49,7 +58,8 @@ TPU_TP_TEST_STR = "" #"tensor_parallel_size=4"
 @pytest.mark.skipif(not current_platform.is_cuda()
                     and not current_platform.is_tpu(),
                     reason="V1 is currently only supported on CUDA and TPU")
-def test_lm_eval_accuracy_v1_engine(monkeypatch: pytest.MonkeyPatch):
+@pytest.mark.parametrize("model", MODEL_NAMES)
+def test_lm_eval_accuracy_v1_engine(model, monkeypatch: pytest.MonkeyPatch):
     """Run with the V1 Engine."""
 
     with monkeypatch.context() as m:
@@ -64,7 +74,7 @@ def test_lm_eval_accuracy_v1_engine(monkeypatch: pytest.MonkeyPatch):
         if TPU_TP_TEST_STR:
             more_args += ",{}".format(TPU_TP_TEST_STR)
 
-        run_test(more_args)
+        run_test(model, more_args)
 
 
 def test_lm_eval_accuracy_v0_engine(monkeypatch: pytest.MonkeyPatch):
@@ -72,4 +82,4 @@ def test_lm_eval_accuracy_v0_engine(monkeypatch: pytest.MonkeyPatch):
 
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "0")
-        run_test()
+        run_test("Qwen/Qwen2-1.5B-Instruct")
diff --git a/tests/v1/tpu/test_pallas.py b/tests/v1/tpu/test_pallas.py
new file mode 100644
index 0000000000000..54eab145efb47
--- /dev/null
+++ b/tests/v1/tpu/test_pallas.py
@@ -0,0 +1,98 @@
+# SPDX-License-Identifier: Apache-2.0
+from unittest.mock import ANY, patch
+
+import torch
+
+from vllm.attention.backends.abstract import AttentionType
+from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
+                                               NUM_QUERIES_PER_BLOCK,
+                                               PallasAttentionBackendImpl,
+                                               PallasMetadata)
+
+
+def test_ragged_paged_attention():
+    # We verify that the kernel inputs such as sliding_window, etc. are passed
+    # in from the model correctly.
+    # The correctness of the paged attention kernel is tested in the kernel
+    # library.
+    num_heads = 4
+    head_size = 128
+    scale = 1.0
+    num_kv_heads = 4
+    sliding_window = 128
+    logits_soft_cap = 50.0
+    attn_impl = PallasAttentionBackendImpl(
+        num_heads=num_heads,
+        head_size=head_size,
+        scale=scale,
+        num_kv_heads=num_kv_heads,
+        alibi_slopes=None,
+        sliding_window=sliding_window,
+        kv_cache_dtype="auto",
+        logits_soft_cap=logits_soft_cap,
+        attn_type=AttentionType.DECODER,
+    )
+    mock_vmem_limit_bytes = 1024
+    attn_impl.vmem_limit_bytes = mock_vmem_limit_bytes
+
+    class FakeAttentionLayer:
+        _k_scale_float: float
+        _v_scale_float: float
+
+    layer = FakeAttentionLayer()
+    layer._k_scale_float = 1.0
+    layer._v_scale_float = 1.0
+
+    num_tokens = 16
+    num_blocks = 1024
+    block_size = 16
+    query = torch.zeros(num_tokens, num_heads * head_size)
+    key = torch.zeros(num_tokens, num_kv_heads * head_size)
+    value = torch.zeros(num_tokens, num_kv_heads * head_size)
+    kv_cache = torch.zeros(num_blocks, block_size, num_kv_heads * 2, head_size)
+    slot_mapping = torch.zeros(num_tokens, dtype=torch.int64)
+    max_num_reqs = 8
+    max_num_blocks_per_req = 8
+    block_tables = torch.zeros((max_num_reqs, max_num_blocks_per_req),
+                               dtype=torch.int32)
+    context_lens = torch.ones((max_num_reqs, ), dtype=torch.int32)
+    query_lens = [1] * max_num_reqs
+    query_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
+                                                dtype=torch.int32),
+                                   dim=0,
+                                   dtype=torch.int32)
+    num_seqs = torch.tensor([max_num_reqs], dtype=torch.int32)
+    attn_metadata = PallasMetadata(
+        slot_mapping=slot_mapping,
+        block_tables=block_tables,
+        context_lens=context_lens,
+        query_start_loc=query_start_loc,
+        num_seqs=num_seqs,
+    )
+
+    with patch("torch.ops.xla.ragged_paged_attention"
+               ) as mock_ragged_paged_attention:
+        attn_impl.forward(
+            layer=layer,
+            query=query,
+            key=key,
+            value=value,
+            kv_cache=kv_cache,
+            attn_metadata=attn_metadata,
+        )
+
+        mock_ragged_paged_attention.assert_called_once_with(
+            ANY,  # query
+            ANY,  # kv_cache
+            ANY,  # context_lens
+            ANY,  # block_tables
+            ANY,  # query_start_loc
+            ANY,  # num_seqs
+            num_kv_pages_per_block=NUM_KV_PAGES_PER_BLOCK,
+            num_queries_per_block=NUM_QUERIES_PER_BLOCK,
+            vmem_limit_bytes=mock_vmem_limit_bytes,
+            use_kernel=True,
+            sm_scale=scale,
+            sliding_window=sliding_window,
+            soft_cap=logits_soft_cap,
+        )
diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py
index 2f86920e2773a..2789863298027 100644
--- a/vllm/v1/attention/backends/pallas.py
+++ b/vllm/v1/attention/backends/pallas.py
@@ -92,6 +92,8 @@ class PallasAttentionBackendImpl(AttentionImpl):
         self.head_size = head_size
         self.scale = float(scale)
         self.num_kv_heads = num_kv_heads
+        self.sliding_window = sliding_window
+        self.logits_soft_cap = logits_soft_cap
 
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
@@ -99,15 +101,10 @@ class PallasAttentionBackendImpl(AttentionImpl):
             raise NotImplementedError("Head size must be a multiple of 128.")
         if alibi_slopes is not None:
             raise NotImplementedError("Alibi slopes is not supported.")
-        if sliding_window is not None:
-            raise NotImplementedError("Sliding window is not supported.")
         if kv_cache_dtype != "auto":
             raise NotImplementedError("FP8 KV cache dtype is not supported.")
         if blocksparse_params is not None:
             raise NotImplementedError("Blocksparse is not supported.")
-        if logits_soft_cap is not None:
-            raise NotImplementedError(
-                "Attention logits soft-capping is not supported.")
 
         if attn_type != AttentionType.DECODER:
             raise NotImplementedError("Encoder self-attention and "
@@ -172,7 +169,10 @@ class PallasAttentionBackendImpl(AttentionImpl):
             num_queries_per_block=NUM_QUERIES_PER_BLOCK,
             vmem_limit_bytes=self.vmem_limit_bytes,
             use_kernel=True,
-            sm_scale=self.scale)
+            sm_scale=self.scale,
+            sliding_window=self.sliding_window,
+            soft_cap=self.logits_soft_cap,
+        )
 
         return output.reshape(num_tokens, hidden_size)
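As a quick illustration of the behavior change, a minimal sketch (assuming a host where vLLM's V1 Pallas backend and its torch_xla dependency import cleanly; the values below are illustrative, not taken from the diff): constructing the backend with a sliding window and a logits soft cap no longer raises NotImplementedError, and forward() now forwards both settings to torch.ops.xla.ragged_paged_attention as sliding_window= and soft_cap=, which is what tests/v1/tpu/test_pallas.py asserts against the mocked kernel.

# Sketch only: illustrative values, not part of the diff above.
from vllm.attention.backends.abstract import AttentionType
from vllm.v1.attention.backends.pallas import PallasAttentionBackendImpl

attn_impl = PallasAttentionBackendImpl(
    num_heads=8,
    head_size=128,
    scale=1.0,
    num_kv_heads=8,
    alibi_slopes=None,
    sliding_window=512,      # previously rejected with NotImplementedError
    kv_cache_dtype="auto",
    logits_soft_cap=50.0,    # previously rejected with NotImplementedError
    attn_type=AttentionType.DECODER,
)
# Both values are now stored on the impl and passed through to the kernel call.
assert attn_impl.sliding_window == 512
assert attn_impl.logits_soft_cap == 50.0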