mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 07:04:53 +08:00
[ROCm][V1] Add intial ROCm support to V1 (#12790)
This commit is contained in:
parent
cbc40128eb
commit
ba59b78a9c
16
requirements-rocm-build.txt
Normal file
16
requirements-rocm-build.txt
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
# Common dependencies
|
||||||
|
-r requirements-common.txt
|
||||||
|
|
||||||
|
--extra-index-url https://download.pytorch.org/whl/rocm6.2
|
||||||
|
torch==2.5.1
|
||||||
|
torchvision==0.20.1
|
||||||
|
torchaudio==2.5.1
|
||||||
|
|
||||||
|
cmake>=3.26
|
||||||
|
ninja
|
||||||
|
packaging
|
||||||
|
setuptools>=61
|
||||||
|
setuptools-scm>=8
|
||||||
|
wheel
|
||||||
|
jinja2
|
||||||
|
amdsmi==6.2.4
|
||||||
@ -718,7 +718,8 @@ if triton.__version__ >= "2.1.0":
|
|||||||
k_scale: torch.Tensor,
|
k_scale: torch.Tensor,
|
||||||
v_scale: torch.Tensor,
|
v_scale: torch.Tensor,
|
||||||
alibi_slopes=None,
|
alibi_slopes=None,
|
||||||
sliding_window=None):
|
sliding_window=None,
|
||||||
|
sm_scale=None):
|
||||||
|
|
||||||
q_dtype_is_f32 = q.dtype is torch.float32
|
q_dtype_is_f32 = q.dtype is torch.float32
|
||||||
# need to reduce num. blocks when using fp32
|
# need to reduce num. blocks when using fp32
|
||||||
@ -759,7 +760,8 @@ if triton.__version__ >= "2.1.0":
|
|||||||
# round up Lk to a power of 2 - this is required for Triton block size
|
# round up Lk to a power of 2 - this is required for Triton block size
|
||||||
Lk_padded = triton.next_power_of_2(Lk)
|
Lk_padded = triton.next_power_of_2(Lk)
|
||||||
|
|
||||||
sm_scale = 1.0 / (Lq**0.5)
|
if sm_scale is None:
|
||||||
|
sm_scale = 1.0 / (Lq**0.5)
|
||||||
batch, head = b_seq_len.shape[0], q.shape[1]
|
batch, head = b_seq_len.shape[0], q.shape[1]
|
||||||
num_queries_per_kv = q.shape[1] // k.shape[1]
|
num_queries_per_kv = q.shape[1] // k.shape[1]
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
import os
|
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import TYPE_CHECKING, Dict, List, Optional
|
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||||
|
|
||||||
@ -29,12 +28,6 @@ try:
|
|||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
logger.warning("Failed to import from vllm._rocm_C with %r", e)
|
logger.warning("Failed to import from vllm._rocm_C with %r", e)
|
||||||
|
|
||||||
if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) in ["fork", None]:
|
|
||||||
logger.warning("`fork` method is not supported by ROCm. "
|
|
||||||
"VLLM_WORKER_MULTIPROC_METHOD is overridden to"
|
|
||||||
" `spawn` instead.")
|
|
||||||
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
|
||||||
|
|
||||||
# Models not supported by ROCm.
|
# Models not supported by ROCm.
|
||||||
_ROCM_UNSUPPORTED_MODELS: List[str] = []
|
_ROCM_UNSUPPORTED_MODELS: List[str] = []
|
||||||
|
|
||||||
@ -84,6 +77,9 @@ class RocmPlatform(Platform):
|
|||||||
return "vllm.attention.backends.triton_mla.TritonMLABackend"
|
return "vllm.attention.backends.triton_mla.TritonMLABackend"
|
||||||
selected_backend = (_Backend.ROCM_FLASH if selected_backend
|
selected_backend = (_Backend.ROCM_FLASH if selected_backend
|
||||||
== _Backend.FLASH_ATTN else selected_backend)
|
== _Backend.FLASH_ATTN else selected_backend)
|
||||||
|
if envs.VLLM_USE_V1:
|
||||||
|
logger.info("Using ROCm Attention backend on V1 engine.")
|
||||||
|
return "vllm.v1.attention.backends.rocm_attn.ROCmAttentionBackend"
|
||||||
if selected_backend == _Backend.ROCM_FLASH:
|
if selected_backend == _Backend.ROCM_FLASH:
|
||||||
if not cls.has_device_capability(90):
|
if not cls.has_device_capability(90):
|
||||||
# not Instinct series GPUs.
|
# not Instinct series GPUs.
|
||||||
@ -102,7 +98,11 @@ class RocmPlatform(Platform):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@lru_cache(maxsize=8)
|
@lru_cache(maxsize=8)
|
||||||
def get_device_name(cls, device_id: int = 0) -> str:
|
def get_device_name(cls, device_id: int = 0) -> str:
|
||||||
return torch.cuda.get_device_name(device_id)
|
# NOTE: When using V1 this function is called when overriding the
|
||||||
|
# engine args. Calling torch.cuda.get_device_name(device_id) here
|
||||||
|
# will result in the ROCm context being initialized before other
|
||||||
|
# processes can be created.
|
||||||
|
return "AMD"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_device_total_memory(cls, device_id: int = 0) -> int:
|
def get_device_total_memory(cls, device_id: int = 0) -> int:
|
||||||
@ -129,15 +129,30 @@ class RocmPlatform(Platform):
|
|||||||
scheduler_config = vllm_config.scheduler_config
|
scheduler_config = vllm_config.scheduler_config
|
||||||
if parallel_config.worker_cls == "auto":
|
if parallel_config.worker_cls == "auto":
|
||||||
if scheduler_config.is_multi_step:
|
if scheduler_config.is_multi_step:
|
||||||
parallel_config.worker_cls = \
|
if envs.VLLM_USE_V1:
|
||||||
"vllm.worker.multi_step_worker.MultiStepWorker"
|
raise NotImplementedError(
|
||||||
|
"Multi-step scheduling is not supported (and not "
|
||||||
|
"needed) on VLLM V1. Please launch without "
|
||||||
|
"--num-scheduler-steps.")
|
||||||
|
else:
|
||||||
|
parallel_config.worker_cls = \
|
||||||
|
"vllm.worker.multi_step_worker.MultiStepWorker"
|
||||||
elif vllm_config.speculative_config:
|
elif vllm_config.speculative_config:
|
||||||
parallel_config.worker_cls = \
|
if envs.VLLM_USE_V1:
|
||||||
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
|
raise NotImplementedError(
|
||||||
parallel_config.sd_worker_cls = \
|
"Speculative decoding is not yet supported on VLLM V1."
|
||||||
"vllm.worker.worker.Worker"
|
)
|
||||||
|
else:
|
||||||
|
parallel_config.worker_cls = \
|
||||||
|
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
|
||||||
|
parallel_config.sd_worker_cls = \
|
||||||
|
"vllm.worker.worker.Worker"
|
||||||
else:
|
else:
|
||||||
parallel_config.worker_cls = "vllm.worker.worker.Worker"
|
if envs.VLLM_USE_V1:
|
||||||
|
parallel_config.worker_cls = \
|
||||||
|
"vllm.v1.worker.gpu_worker.Worker"
|
||||||
|
else:
|
||||||
|
parallel_config.worker_cls = "vllm.worker.worker.Worker"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def verify_model_arch(cls, model_arch: str) -> None:
|
def verify_model_arch(cls, model_arch: str) -> None:
|
||||||
|
|||||||
@ -12,8 +12,11 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
|||||||
AttentionMetadata, AttentionType)
|
AttentionMetadata, AttentionType)
|
||||||
from vllm.attention.backends.utils import get_flash_attn_version
|
from vllm.attention.backends.utils import get_flash_attn_version
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import cdiv
|
from vllm.utils import cdiv
|
||||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
|
||||||
|
if current_platform.is_cuda():
|
||||||
|
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
182
vllm/v1/attention/backends/rocm_attn.py
Normal file
182
vllm/v1/attention/backends/rocm_attn.py
Normal file
@ -0,0 +1,182 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
"""Attention layer with PagedAttention on rocm"""
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||||
|
AttentionMetadata, AttentionType)
|
||||||
|
from vllm.attention.ops.paged_attn import PagedAttention
|
||||||
|
from vllm.attention.ops.prefix_prefill import context_attention_fwd
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ROCmAttentionBackend(AttentionBackend):
|
||||||
|
|
||||||
|
accept_output_buffer: bool = True
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_supported_head_sizes() -> List[int]:
|
||||||
|
return [32, 64, 96, 128, 160, 192, 224, 256]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_name() -> str:
|
||||||
|
return "ROCM_ATTN_VLLM_V1"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_impl_cls() -> Type["ROCmAttentionImpl"]:
|
||||||
|
return ROCmAttentionImpl
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_metadata_cls() -> Type["AttentionMetadata"]:
|
||||||
|
return FlashAttentionMetadata
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_kv_cache_shape(
|
||||||
|
num_blocks: int,
|
||||||
|
block_size: int,
|
||||||
|
num_kv_heads: int,
|
||||||
|
head_size: int,
|
||||||
|
) -> Tuple[int, ...]:
|
||||||
|
if block_size % 16 != 0:
|
||||||
|
raise ValueError("Block size must be a multiple of 16.")
|
||||||
|
return (2, num_blocks, block_size, num_kv_heads, head_size)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def use_cascade_attention(*args, **kwargs) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class ROCmAttentionImpl(AttentionImpl):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_heads: int,
|
||||||
|
head_size: int,
|
||||||
|
scale: float,
|
||||||
|
num_kv_heads: int,
|
||||||
|
alibi_slopes: Optional[List[float]],
|
||||||
|
sliding_window: Optional[int],
|
||||||
|
kv_cache_dtype: str,
|
||||||
|
blocksparse_params: Optional[Dict[str, Any]] = None,
|
||||||
|
logits_soft_cap: Optional[float] = None,
|
||||||
|
attn_type: AttentionType = AttentionType.DECODER,
|
||||||
|
) -> None:
|
||||||
|
if blocksparse_params is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"ROCmAttention does not support block-sparse attention.")
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_size = head_size
|
||||||
|
self.scale = float(scale)
|
||||||
|
self.num_kv_heads = num_kv_heads
|
||||||
|
if alibi_slopes is not None:
|
||||||
|
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
||||||
|
self.alibi_slopes = alibi_slopes
|
||||||
|
if sliding_window is None:
|
||||||
|
self.sliding_window = (-1, -1)
|
||||||
|
else:
|
||||||
|
self.sliding_window = (sliding_window - 1, 0)
|
||||||
|
self.kv_cache_dtype = kv_cache_dtype
|
||||||
|
|
||||||
|
assert self.num_heads % self.num_kv_heads == 0
|
||||||
|
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||||
|
|
||||||
|
support_head_sizes = ROCmAttentionBackend.get_supported_head_sizes()
|
||||||
|
if head_size not in support_head_sizes:
|
||||||
|
raise ValueError(
|
||||||
|
f"Head size {head_size} is not supported by ROCmAttention. "
|
||||||
|
f"Supported head sizes are: {support_head_sizes}.")
|
||||||
|
|
||||||
|
if attn_type != AttentionType.DECODER:
|
||||||
|
raise NotImplementedError("Encoder self-attention and "
|
||||||
|
"encoder/decoder cross-attention "
|
||||||
|
"are not implemented for "
|
||||||
|
"ROCmAttentionImpl")
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
kv_cache: torch.Tensor,
|
||||||
|
attn_metadata: FlashAttentionMetadata,
|
||||||
|
output: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Forward pass with FlashAttention.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: shape = [num_tokens, num_heads, head_size]
|
||||||
|
key: shape = [num_tokens, num_kv_heads, head_size]
|
||||||
|
value: shape = [num_tokens, num_kv_heads, head_size]
|
||||||
|
kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
|
||||||
|
attn_metadata: Metadata for attention.
|
||||||
|
Returns:
|
||||||
|
shape = [num_tokens, num_heads * head_size]
|
||||||
|
"""
|
||||||
|
assert output is not None, "Output tensor must be provided."
|
||||||
|
|
||||||
|
if attn_metadata is None:
|
||||||
|
# Profiling run.
|
||||||
|
return output
|
||||||
|
|
||||||
|
assert attn_metadata.use_cascade is False
|
||||||
|
|
||||||
|
# IMPORTANT!
|
||||||
|
# NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
|
||||||
|
# eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
|
||||||
|
# in this method. For example, `view` and `slice` (or `[:n]`) operations
|
||||||
|
# are surprisingly slow even in the case they do not invoke any GPU ops.
|
||||||
|
# Minimize the PyTorch ops in this method as much as possible.
|
||||||
|
# Whenever making a change in this method, please benchmark the
|
||||||
|
# performance to make sure it does not introduce any overhead.
|
||||||
|
|
||||||
|
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||||
|
key_cache, value_cache = PagedAttention.split_kv_cache(
|
||||||
|
kv_cache, self.num_kv_heads, self.head_size)
|
||||||
|
|
||||||
|
# Reshape the input keys and values and store them in the cache.
|
||||||
|
PagedAttention.write_to_paged_cache(
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
attn_metadata.slot_mapping,
|
||||||
|
self.kv_cache_dtype,
|
||||||
|
layer._k_scale,
|
||||||
|
layer._v_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO(sage): Refactor the context_attention_fwd kernel so that this
|
||||||
|
# overhead can be removed
|
||||||
|
context_lens = torch.empty_like(attn_metadata.seq_lens)
|
||||||
|
batch_size = len(attn_metadata.query_start_loc) - 1
|
||||||
|
assert len(context_lens) == batch_size
|
||||||
|
for i in range(batch_size):
|
||||||
|
query_start = attn_metadata.query_start_loc[i]
|
||||||
|
query_end = attn_metadata.query_start_loc[i + 1]
|
||||||
|
context_lens[i] = attn_metadata.seq_lens[i] - (query_end -
|
||||||
|
query_start)
|
||||||
|
|
||||||
|
# Compute attention and update output up to `num_actual_tokens`.
|
||||||
|
context_attention_fwd(q=query[:num_actual_tokens],
|
||||||
|
k=key[:num_actual_tokens],
|
||||||
|
v=value[:num_actual_tokens],
|
||||||
|
o=output[:num_actual_tokens],
|
||||||
|
kv_cache_dtype=self.kv_cache_dtype,
|
||||||
|
k_cache=key_cache,
|
||||||
|
v_cache=value_cache,
|
||||||
|
b_loc=attn_metadata.block_table,
|
||||||
|
b_start_loc=attn_metadata.query_start_loc,
|
||||||
|
b_seq_len=attn_metadata.seq_lens,
|
||||||
|
b_ctx_len=context_lens,
|
||||||
|
max_input_len=attn_metadata.max_query_len,
|
||||||
|
k_scale=layer._k_scale,
|
||||||
|
v_scale=layer._v_scale,
|
||||||
|
alibi_slopes=self.alibi_slopes,
|
||||||
|
sliding_window=self.sliding_window[0],
|
||||||
|
sm_scale=self.scale)
|
||||||
|
return output
|
||||||
Loading…
x
Reference in New Issue
Block a user