mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 23:34:27 +08:00
[AMD][ROCm] Enable DeepSeek model on ROCm (#12662)
Signed-off-by: Hongxia Yang <hongxia.yang@amd.com> Co-authored-by: Matthew Wong <Matthew.Wong2@amd.com>
This commit is contained in:
parent
4896d0c2dd
commit
c36ac98d01
31
tests/kernels/test_rocm_attention_selector.py
Normal file
31
tests/kernels/test_rocm_attention_selector.py
Normal file
@ -0,0 +1,31 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import override_backend_env_variable
|
||||
from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
|
||||
from vllm.platforms.rocm import RocmPlatform
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_cache():
|
||||
"""Clear lru cache to ensure each test case runs without caching.
|
||||
"""
|
||||
_cached_get_attn_backend.cache_clear()
|
||||
|
||||
|
||||
def test_selector(monkeypatch):
|
||||
"""Test that the attention selector for ROCm.
|
||||
"""
|
||||
override_backend_env_variable(monkeypatch, "ROCM_FLASH")
|
||||
|
||||
with patch("vllm.attention.selector.current_platform", RocmPlatform()):
|
||||
backend = get_attn_backend(16, torch.float16, torch.float16, 16, False)
|
||||
assert backend.get_name() == "ROCM_FLASH"
|
||||
# mla test for deepseek related
|
||||
backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False,
|
||||
False, True)
|
||||
assert backend.get_name() == "TRITON_MLA"
|
||||
@ -24,6 +24,15 @@ def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner:
|
||||
return model_runner
|
||||
|
||||
|
||||
def test_deepseek_mla_attn_backend_module():
|
||||
model_runner = _create_model_runner(
|
||||
"deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
|
||||
trust_remote_code=True,
|
||||
enable_chunked_prefill=False,
|
||||
)
|
||||
assert model_runner.attn_backend.__name__ == "TritonMLABackend"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", list(range(1, 257)))
|
||||
def test_prepare_prompt(batch_size):
|
||||
model_runner = _create_model_runner(
|
||||
|
||||
@ -27,7 +27,11 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
scaled_dequantize, scaled_quantize)
|
||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
||||
|
||||
try:
|
||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
||||
except ImportError:
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@ -47,6 +47,16 @@ def apply_w8a8_block_fp8_linear(
|
||||
|
||||
shape_supported_by_cutlass = (weight.shape[0] % 128 == 0
|
||||
and weight.shape[1] % 128 == 0)
|
||||
if current_platform.is_rocm():
|
||||
scale_a_shape = ((input_2d.shape[-1] // block_size[1], ) +
|
||||
input_2d.shape[:-1])[::-1]
|
||||
scale_b_shape = (weight_scale.view(-1, 1)
|
||||
if weight_scale.dim() <= 1 else weight_scale.T).shape
|
||||
ar, ac = scale_a_shape
|
||||
br, bc = scale_b_shape
|
||||
if (ac > 1 or bc > 1 or ar not in (1, input_2d.shape[0])
|
||||
or br not in (1, weight.shape[0])):
|
||||
shape_supported_by_cutlass = False
|
||||
if cutlass_block_fp8_supported and shape_supported_by_cutlass:
|
||||
q_input, x_scale = per_token_group_quant_fp8(input_2d,
|
||||
block_size[1],
|
||||
|
||||
@ -79,6 +79,9 @@ class RocmPlatform(Platform):
|
||||
def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
|
||||
kv_cache_dtype, block_size, use_v1,
|
||||
use_mla) -> str:
|
||||
if use_mla:
|
||||
logger.info("Using Triton MLA backend.")
|
||||
return "vllm.attention.backends.triton_mla.TritonMLABackend"
|
||||
selected_backend = (_Backend.ROCM_FLASH if selected_backend
|
||||
== _Backend.FLASH_ATTN else selected_backend)
|
||||
if selected_backend == _Backend.ROCM_FLASH:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user