Remove unneeded ROCm platform import when using CUDA (#22765)

Signed-off-by: mgoin <mgoin64@gmail.com>
Author: Michael Goin, 2025-08-13 00:26:38 -04:00 (committed by GitHub)
parent c6b928798e
commit 4082338a25
2 changed files with 2 additions and 2 deletions
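
In each file, the module-level `from vllm.platforms.rocm import ...` becomes a function-level import, so importing these modules on a CUDA-only install no longer loads vllm.platforms.rocm as a side effect. A minimal, runnable sketch of the deferred-import pattern, using the stdlib decimal module as a stand-in for the ROCm platform module (nothing below is vLLM code):

import sys

def decode_path(x: str) -> str:
    # Deferred import: decimal is loaded on the first call to this
    # function, not when the enclosing module is imported.
    from decimal import Decimal
    return str(Decimal(x) * 2)

print("decimal" in sys.modules)  # False: importing this file cost nothing
print(decode_path("1.5"))        # "3.0": the first call triggers the import
print("decimal" in sys.modules)  # True: cached in sys.modules from now on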

vllm/attention/backends/rocm_flash_attn.py

@@ -22,7 +22,6 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape)
 from vllm.platforms import current_platform
-from vllm.platforms.rocm import use_rocm_custom_paged_attention
 
 if TYPE_CHECKING:
     from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
@@ -886,6 +885,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
             num_seqs, num_heads, head_size = decode_query.shape
             block_size = value_cache.shape[3]
             gqa_ratio = num_heads // self.num_kv_heads
+            from vllm.platforms.rocm import use_rocm_custom_paged_attention
             use_custom = use_rocm_custom_paged_attention(
                 decode_query.dtype, head_size, block_size, gqa_ratio,
                 decode_meta.max_decode_seq_len, self.sliding_window,
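
Deferring the import into the decode path is effectively free after the first call: Python caches imported modules in sys.modules, so the function-level import amounts to a dictionary lookup on every call after the first. The same pattern is applied to the Triton chunked-prefill path in the next file.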

vllm/attention/ops/chunked_prefill_paged_decode.py

@@ -11,7 +11,6 @@ import torch
 
 from vllm import _custom_ops as ops
 from vllm.platforms import current_platform
-from vllm.platforms.rocm import use_rocm_custom_paged_attention
 from vllm.triton_utils import tl, triton
 
 from .prefix_prefill import context_attention_fwd
@@ -296,6 +295,7 @@ def chunked_prefill_paged_decode(
     num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv),
                                     16)
 
+    from vllm.platforms.rocm import use_rocm_custom_paged_attention
     use_custom = use_rocm_custom_paged_attention(
         query.dtype,
         head_size,
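
One way to sanity-check the change on a CUDA-only install (a sketch: the module path vllm.attention.ops.chunked_prefill_paged_decode is inferred from the function name in the diff, and other imports elsewhere may still pull the ROCm module in):

import sys
import vllm.attention.ops.chunked_prefill_paged_decode  # noqa: F401
# With the module-level import gone, the ROCm platform module should not
# be loaded merely by importing this code path:
print("vllm.platforms.rocm" in sys.modules)  # expected: False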