From 14dbd5a7674e5de2862c18adb711d9feecd35063 Mon Sep 17 00:00:00 2001
From: Joe
Date: Fri, 26 Jul 2024 20:47:50 -0700
Subject: [PATCH] [Model] H2O Danube3-4b (#6451)

---
 .buildkite/run-cpu-test.sh                    |  2 +-
 .../kernels/benchmark_paged_attention.py      |  2 +-
 benchmarks/kernels/benchmark_rope.py          |  2 +-
 csrc/attention/attention_kernels.cu           |  6 +++
 tests/kernels/test_attention.py               |  4 +-
 tests/kernels/test_cache.py                   |  8 ++-
 tests/kernels/test_pos_encoding.py            |  2 +-
 tests/models/test_danube3_4b.py               | 52 +++++++++++++++++++
 vllm/attention/ops/paged_attn.py              |  2 +-
 vllm/utils.py                                 |  6 +++
 10 files changed, 79 insertions(+), 7 deletions(-)
 create mode 100644 tests/models/test_danube3_4b.py

diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh
index 21deec2bba97..45bc8eb2f847 100644
--- a/.buildkite/run-cpu-test.sh
+++ b/.buildkite/run-cpu-test.sh
@@ -23,7 +23,7 @@ docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py"
 # Run basic model test
 docker exec cpu-test bash -c "
   pip install pytest Pillow protobuf
-  pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py" # Mamba on CPU is not supported
+  pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py --ignore=tests/models/test_danube3_4b.py" # Mamba and Danube3-4B are not supported on CPU

 # online inference
 docker exec cpu-test bash -c "
diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py
index 78cac8a555d1..a04433142da4 100644
--- a/benchmarks/kernels/benchmark_paged_attention.py
+++ b/benchmarks/kernels/benchmark_paged_attention.py
@@ -175,7 +175,7 @@ if __name__ == '__main__':
     parser.add_argument("--num-kv-heads", type=int, default=8)
     parser.add_argument("--head-size",
                         type=int,
-                        choices=[64, 80, 96, 112, 128, 192, 256],
+                        choices=[64, 80, 96, 112, 120, 128, 192, 256],
                         default=128)
     parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
     parser.add_argument("--use-alibi", action="store_true")
diff --git a/benchmarks/kernels/benchmark_rope.py b/benchmarks/kernels/benchmark_rope.py
index 78736c7a7ba6..f542684a9a2a 100644
--- a/benchmarks/kernels/benchmark_rope.py
+++ b/benchmarks/kernels/benchmark_rope.py
@@ -94,7 +94,7 @@ if __name__ == '__main__':
     parser.add_argument("--num-heads", type=int, default=8)
     parser.add_argument("--head-size",
                         type=int,
-                        choices=[64, 80, 96, 112, 128, 192, 256],
+                        choices=[64, 80, 96, 112, 120, 128, 192, 256],
                         default=128)
     parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32)
     parser.add_argument("--dtype",
diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu
index 350dbce1d7ba..875570a1e894 100644
--- a/csrc/attention/attention_kernels.cu
+++ b/csrc/attention/attention_kernels.cu
@@ -751,6 +751,9 @@ void paged_attention_v1_launcher(
     case 112:
       LAUNCH_PAGED_ATTENTION_V1(112);
       break;
+    case 120:
+      LAUNCH_PAGED_ATTENTION_V1(120);
+      break;
     case 128:
       LAUNCH_PAGED_ATTENTION_V1(128);
       break;
@@ -912,6 +915,9 @@ void paged_attention_v2_launcher(
     case 112:
       LAUNCH_PAGED_ATTENTION_V2(112);
       break;
+    case 120:
+      LAUNCH_PAGED_ATTENTION_V2(120);
+      break;
     case 128:
       LAUNCH_PAGED_ATTENTION_V2(128);
       break;
diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py
index 2e6412c28958..c7c6707461c3 100644
--- a/tests/kernels/test_attention.py
+++ b/tests/kernels/test_attention.py
@@ -28,7 +28,7 @@ NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing

 # FlashAttention forward only supports head dimension at most 128
 # https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
-HEAD_SIZES = [64, 80, 96, 112, 128, 192, 256
+HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256
               ] if not is_hip() else [64, 80, 96, 112, 128]

 BLOCK_SIZES = [16, 32]
@@ -134,6 +134,8 @@ def test_paged_attention(
     seed: int,
     device: str,
 ) -> None:
+    if kv_cache_dtype == "fp8" and head_size % 16:
+        pytest.skip("fp8 KV cache requires head_size to be a multiple of 16")
     random.seed(seed)
     torch.random.manual_seed(seed)
     if torch.cuda.is_available():
diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py
index f9a609464abf..3fb9b59be170 100644
--- a/tests/kernels/test_cache.py
+++ b/tests/kernels/test_cache.py
@@ -11,7 +11,7 @@ DTYPES = [torch.half, torch.bfloat16, torch.float]
 NUM_TOKENS = [42]  # Arbitrary values for testing
 NUM_LAYERS = [1]  # Arbitrary values for testing
 NUM_HEADS = [8]  # Arbitrary values for testing
-HEAD_SIZES = [64, 80, 96, 112, 128, 192, 256]
+HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256]
 BLOCK_SIZES = [8, 16, 32]  # Arbitrary values for testing


@@ -52,6 +52,8 @@ def test_copy_blocks(
     kv_cache_dtype: str,
     device: str,
 ) -> None:
+    if kv_cache_dtype == "fp8" and head_size % 16:
+        pytest.skip("fp8 KV cache requires head_size to be a multiple of 16")
     random.seed(seed)
     torch.random.manual_seed(seed)
     if torch.cuda.is_available():
@@ -124,6 +126,8 @@ def test_reshape_and_cache(
     device: str,
     kv_cache_dtype: str,
 ) -> None:
+    if kv_cache_dtype == "fp8" and head_size % 16:
+        pytest.skip("fp8 KV cache requires head_size to be a multiple of 16")
     random.seed(seed)
     torch.random.manual_seed(seed)
     if torch.cuda.is_available():
@@ -325,6 +329,8 @@ def test_swap_blocks(
 ) -> None:
     if kv_cache_dtype == "fp8" and "cpu" in direction:
         pytest.skip()
+    if kv_cache_dtype == "fp8" and head_size % 16:
+        pytest.skip("fp8 KV cache requires head_size to be a multiple of 16")
     random.seed(seed)
     torch.random.manual_seed(seed)
     if torch.cuda.is_available():
diff --git a/tests/kernels/test_pos_encoding.py b/tests/kernels/test_pos_encoding.py
index 4c83659929d4..4a7ad6e0fa21 100644
--- a/tests/kernels/test_pos_encoding.py
+++ b/tests/kernels/test_pos_encoding.py
@@ -10,7 +10,7 @@ from .allclose_default import get_default_atol, get_default_rtol

 IS_NEOX_STYLE = [True, False]
 DTYPES = [torch.half, torch.bfloat16, torch.float]
-HEAD_SIZES = [64, 80, 96, 112, 128, 192, 256]
+HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256]
 ROTARY_DIMS = [None, 32]  # None means rotary dim == head size
 NUM_HEADS = [7, 17]  # Arbitrary values for testing
 BATCH_SIZES = [1, 5]  # Arbitrary values for testing
diff --git a/tests/models/test_danube3_4b.py b/tests/models/test_danube3_4b.py
new file mode 100644
index 000000000000..bfaa275f73c1
--- /dev/null
+++ b/tests/models/test_danube3_4b.py
@@ -0,0 +1,52 @@
+"""Compare the outputs of HF and vLLM when using greedy sampling.
+
+This tests Danube3 separately because its head size (120) isn't supported on CPU yet.
+
+Run `pytest tests/models/test_danube3_4b.py`.
+""" +import pytest + +from .utils import check_outputs_equal + +MODELS = ["h2oai/h2o-danube3-4b-base"] + +target_dtype = "half" + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", [target_dtype]) +@pytest.mark.parametrize("max_tokens", [32]) +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + with hf_runner(model, dtype=dtype) as hf_model: + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + + with vllm_runner(model, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + + check_outputs_equal( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", [target_dtype]) +def test_model_print( + vllm_runner, + model: str, + dtype: str, +) -> None: + with vllm_runner(model, dtype=dtype) as vllm_model: + # This test is for verifying whether the model's extra_repr + # can be printed correctly. + print(vllm_model.model.llm_engine.model_executor.driver_worker. + model_runner.model) diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index ce7b4d129779..0f6d2f2d1ab3 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -31,7 +31,7 @@ class PagedAttention: @staticmethod def get_supported_head_sizes() -> List[int]: - return [64, 80, 96, 112, 128, 192, 256] + return [64, 80, 96, 112, 120, 128, 192, 256] @staticmethod def get_kv_cache_shape( diff --git a/vllm/utils.py b/vllm/utils.py index 90be09fc7b96..1448316e66ed 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -508,6 +508,12 @@ def create_kv_caches_with_random( seed: int = 0, device: Optional[str] = "cuda", ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + + if cache_dtype == "fp8" and head_size % 16: + raise ValueError( + f"Does not support key cache of type fp8 with head_size {head_size}" + ) + torch.random.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed(seed)