Remove unneeded ROCm platform import when using CUDA (#22765)

Signed-off-by: mgoin <mgoin64@gmail.com>

parent c6b928798e
commit 4082338a25
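The change moves the module-level "from vllm.platforms.rocm import use_rocm_custom_paged_attention" into the function bodies that call it, so importing these attention modules on a CUDA-only install no longer loads vllm.platforms.rocm as a side effect. A minimal, self-contained sketch of the deferred-import pattern (names are illustrative, not vllm's; the stdlib decimal module stands in for vllm.platforms.rocm, and assumes decimal is not preloaded by site customizations):

# lazy_import_demo.py -- sketch of the deferred-import pattern in this commit.
# "decimal" stands in for vllm.platforms.rocm; the function name is made up.
import sys

def use_custom_paged_attention(head_size: int) -> bool:
    # Imported only when this code path actually runs; after the first call
    # Python resolves it from the sys.modules cache, so repeats are cheap.
    import decimal
    return decimal.Decimal(head_size) % 8 == 0

print("decimal" in sys.modules)        # False: nothing loaded at import time
print(use_custom_paged_attention(64))  # True; the import happened just now
print("decimal" in sys.modules)        # True: cached for later calls

The actual diff follows.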
@@ -22,7 +22,6 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape)
 from vllm.platforms import current_platform
-from vllm.platforms.rocm import use_rocm_custom_paged_attention

 if TYPE_CHECKING:
     from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
@@ -886,6 +885,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
             num_seqs, num_heads, head_size = decode_query.shape
             block_size = value_cache.shape[3]
             gqa_ratio = num_heads // self.num_kv_heads
+            from vllm.platforms.rocm import use_rocm_custom_paged_attention
             use_custom = use_rocm_custom_paged_attention(
                 decode_query.dtype, head_size, block_size, gqa_ratio,
                 decode_meta.max_decode_seq_len, self.sliding_window,
@@ -11,7 +11,6 @@ import torch

 from vllm import _custom_ops as ops
 from vllm.platforms import current_platform
-from vllm.platforms.rocm import use_rocm_custom_paged_attention
 from vllm.triton_utils import tl, triton

 from .prefix_prefill import context_attention_fwd
@@ -296,6 +295,7 @@ def chunked_prefill_paged_decode(
     num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv),
                                     16)

+    from vllm.platforms.rocm import use_rocm_custom_paged_attention
     use_custom = use_rocm_custom_paged_attention(
         query.dtype,
         head_size,
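Both call sites sit on hot decode paths, so it is worth noting that a function-level import is only expensive the first time; every later execution is a dictionary lookup in sys.modules. A quick, self-contained check (timings will vary by machine):

# import_cost_check.py -- repeated deferred imports hit the sys.modules cache.
import importlib
import timeit

importlib.import_module("decimal")  # pay the one-time load up front
# Each iteration re-executes "import decimal"; all of them are cache hits.
per_call = timeit.timeit("import decimal", number=100_000) / 100_000
print(f"~{per_call * 1e9:.0f} ns per cached import")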