[ROCm][Misc] Follow-ups for Skinny Gemms on ROCm. (#17011)

Signed-off-by: charlifu <charlifu@amd.com>
Charlie Fu 2025-04-26 00:05:10 -05:00 committed by GitHub
parent 9e96f56efb
commit 54271bb766
4 changed files with 18 additions and 15 deletions

@@ -155,8 +155,9 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor,
output_shape: List) -> torch.Tensor:
if envs.VLLM_ROCM_USE_SKINNY_GEMM and qinput.shape[
0] == 1 and qinput.shape[1] % 16 == 0:
from vllm.platforms.rocm import on_mi250_mi300
if envs.VLLM_ROCM_USE_SKINNY_GEMM and not on_mi250_mi300(
) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0:
output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b,
current_platform.get_cu_count())
else:
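The hunk above gates the per-tensor FP8 path onto the wvSplitKQ skinny-GEMM kernel only when the skinny-GEMM env flag is set, the GPU is not an MI250/MI300, the quantized input has a single (flattened) row, and the reduction dimension is a multiple of 16. A minimal standalone sketch of that predicate follows; the helper name and arguments are illustrative, not part of vLLM.

import torch

def should_use_wvsplitkq(qinput: torch.Tensor,
                         use_skinny_gemm: bool,
                         is_mi250_or_mi300: bool) -> bool:
    # Mirrors the gating in the hunk above: env flag on, not MI250/MI300,
    # a single input row, and K aligned to 16.
    return (use_skinny_gemm and not is_mi250_or_mi300
            and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0)

# Example: a single-token activation with K = 4096 qualifies.
q = torch.empty(1, 4096)
print(should_use_wvsplitkq(q, use_skinny_gemm=True, is_mi250_or_mi300=False))  # True
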
@@ -371,7 +372,7 @@ class Fp8LinearOp:
return w8a8_scaled_mm_func(qinput=qinput,
weight=weight,
out_dtype=input.dtype,
out_dtype=out_dtype,
scale_a=x_scale,
scale_b=weight_scale,
bias=bias,
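The second hunk fixes the call inside Fp8LinearOp to forward the caller-supplied out_dtype instead of rederiving it from the activation's dtype. A hedged illustration of why that keyword matters, using stand-in names rather than the real kernel:

import torch

def scaled_mm_stub(qinput, weight, out_dtype, scale_a, scale_b, bias=None):
    # Stand-in for the w8a8 scaled-mm call: the result must be produced in
    # the dtype the caller requested, which can differ from qinput.dtype.
    out = (qinput.float() @ weight.t().float()) * scale_a * scale_b
    if bias is not None:
        out = out + bias
    return out.to(out_dtype)

x = torch.randn(1, 64)     # activations (float32 here for simplicity)
w = torch.randn(32, 64)    # weight of shape (out_features, in_features)
y = scaled_mm_stub(x, w, out_dtype=torch.bfloat16, scale_a=1.0, scale_b=1.0)
assert y.dtype == torch.bfloat16  # deriving the dtype from x would give float32 here
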

@@ -70,8 +70,9 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
def rocm_unquantized_gemm(x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None):
from vllm.platforms.rocm import on_mi250_mi300
k = weight.shape[1]
use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and \
use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi250_mi300() and \
x.dtype in [torch.float16, torch.bfloat16] \
and k % 8 == 0 and bias is None)
@@ -83,11 +84,11 @@ def rocm_unquantized_gemm(x: torch.Tensor,
m = weight.shape[0]
cu_count = current_platform.get_cu_count()
if m > 8 and n < 4:
if m > 8 and 0 < n < 4:
out = ops.wvSplitK(weight, x_view, cu_count)
return out.view(*x.shape[:-1], weight.shape[0])
elif m % 4 == 0 and n == 1 and k <= 8192:
out = ops.LLMM1(weight, x_view, out, 4)
out = ops.LLMM1(weight, x_view, 4)
return out.view(*x.shape[:-1], weight.shape[0])
return torch.nn.functional.linear(x, weight, bias)
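These two hunks add the MI250/MI300 check to the skinny-GEMM predicate, tighten the wvSplitK shape check to also require a non-empty batch (0 < n < 4), and drop the stale out argument from the LLMM1 call. The resulting selection logic can be summarized with the standalone sketch below; the helper itself is hypothetical, only the kernel names come from the diff.

from typing import Optional
import torch

def pick_rocm_unquantized_kernel(x: torch.Tensor,
                                 weight: torch.Tensor,
                                 bias: Optional[torch.Tensor],
                                 use_skinny_gemm_flag: bool,
                                 is_mi250_or_mi300: bool) -> str:
    k = weight.shape[1]                   # reduction dimension
    m = weight.shape[0]                   # output features
    n = x.view(-1, x.shape[-1]).shape[0]  # flattened token count
    use_skinny = (use_skinny_gemm_flag and is_mi250_or_mi300
                  and x.dtype in (torch.float16, torch.bfloat16)
                  and k % 8 == 0 and bias is None)
    if use_skinny and m > 8 and 0 < n < 4:
        return "ops.wvSplitK"
    if use_skinny and m % 4 == 0 and n == 1 and k <= 8192:
        return "ops.LLMM1"
    return "torch.nn.functional.linear"

x = torch.randn(1, 4096, dtype=torch.float16)
w = torch.randn(14336, 4096, dtype=torch.float16)
print(pick_rocm_unquantized_kernel(x, w, None, True, True))  # ops.wvSplitK
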

@@ -12,6 +12,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
from vllm.model_executor.parameter import BasevLLMParameter
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
@@ -40,7 +41,7 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return F.linear(x, layer.weight, bias)
return dispatch_unquantized_gemm()(x, layer.weight, bias)
def embedding(self, layer: torch.nn.Module,
input_: torch.Tensor) -> torch.Tensor:
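With this change, UnquantizedEmbeddingMethod.apply no longer calls F.linear directly but goes through dispatch_unquantized_gemm, so ROCm builds can pick up the skinny-GEMM path above. A rough sketch of such a dispatcher follows; its body is an assumption about the intended behavior, not the actual vLLM implementation.

import torch
import torch.nn.functional as F

def dispatch_unquantized_gemm_sketch(is_rocm: bool):
    # Assumed behavior: on ROCm return a wrapper that can route small-N
    # GEMMs to wvSplitK/LLMM1 (rocm_unquantized_gemm); otherwise plain linear.
    if is_rocm:
        def rocm_gemm(x, weight, bias=None):
            # Placeholder for rocm_unquantized_gemm; falls back to F.linear here.
            return F.linear(x, weight, bias)
        return rocm_gemm
    return F.linear

linear_fn = dispatch_unquantized_gemm_sketch(is_rocm=False)
print(linear_fn(torch.randn(2, 8), torch.randn(4, 8)).shape)  # torch.Size([2, 4])
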

@@ -98,22 +98,22 @@ def device_id_to_physical_device_id(device_id: int) -> int:
return device_id
def on_mi250_mi300() -> bool:
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942"])
@cache
def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
block_size: int, gqa_ratio: int,
max_seq_len: int,
sliding_window: int) -> bool:
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
ON_NAVI = "gfx1" in GPU_ARCH
ON_MI250_MI300 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942"])
# rocm custom page attention not support on navi (gfx1*)
# rocm custom page attention not support on gfx1*
# custom paged attn always supported on V0. On V1, requires sliding window
# disabled due to observed numerical discrepancy.
return (ON_MI250_MI300 and not ON_NAVI
and (not envs.VLLM_USE_V1 or sliding_window == 0
or sliding_window == (-1, -1))
return (on_mi250_mi300() and (not envs.VLLM_USE_V1 or sliding_window == 0
or sliding_window == (-1, -1))
and (qtype == torch.half or qtype == torch.bfloat16)
and (head_size == 64 or head_size == 128)
and (block_size == 16 or block_size == 32)
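
Finally, the duplicated gcnArchName probing in the ROCm platform module is folded into the shared on_mi250_mi300() helper, which use_rocm_custom_paged_attention now calls directly. An illustrative, self-contained version of the substring-based arch check is below; the sample strings are made up and this helper is not vLLM's actual code.

from functools import cache

@cache
def is_mi250_or_mi300(gcn_arch_name: str) -> bool:
    # torch.cuda.get_device_properties("cuda").gcnArchName returns values such
    # as "gfx90a:sramecc+:xnack-"; MI250 reports gfx90a and MI300 reports gfx942.
    return any(arch in gcn_arch_name for arch in ("gfx90a", "gfx942"))

print(is_mi250_or_mi300("gfx90a:sramecc+:xnack-"))  # True  (MI250)
print(is_mi250_or_mi300("gfx1100"))                 # False (Navi, gfx1*)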