[ROCm][Misc] Follow-ups for Skinny Gemms on ROCm. (#17011)

Signed-off-by: charlifu <charlifu@amd.com>
This commit is contained in:
Charlie Fu 2025-04-26 00:05:10 -05:00 committed by GitHub
parent 9e96f56efb
commit 54271bb766
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 18 additions and 15 deletions

View File

@ -155,8 +155,9 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor, input_2d: torch.Tensor,
output_shape: List) -> torch.Tensor: output_shape: List) -> torch.Tensor:
if envs.VLLM_ROCM_USE_SKINNY_GEMM and qinput.shape[ from vllm.platforms.rocm import on_mi250_mi300
0] == 1 and qinput.shape[1] % 16 == 0: if envs.VLLM_ROCM_USE_SKINNY_GEMM and not on_mi250_mi300(
) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0:
output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b, output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b,
current_platform.get_cu_count()) current_platform.get_cu_count())
else: else:
@ -371,7 +372,7 @@ class Fp8LinearOp:
return w8a8_scaled_mm_func(qinput=qinput, return w8a8_scaled_mm_func(qinput=qinput,
weight=weight, weight=weight,
out_dtype=input.dtype, out_dtype=out_dtype,
scale_a=x_scale, scale_a=x_scale,
scale_b=weight_scale, scale_b=weight_scale,
bias=bias, bias=bias,

View File

@ -70,8 +70,9 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
def rocm_unquantized_gemm(x: torch.Tensor, def rocm_unquantized_gemm(x: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
bias: Optional[torch.Tensor] = None): bias: Optional[torch.Tensor] = None):
from vllm.platforms.rocm import on_mi250_mi300
k = weight.shape[1] k = weight.shape[1]
use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and \ use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi250_mi300() and \
x.dtype in [torch.float16, torch.bfloat16] \ x.dtype in [torch.float16, torch.bfloat16] \
and k % 8 == 0 and bias is None) and k % 8 == 0 and bias is None)
@ -83,11 +84,11 @@ def rocm_unquantized_gemm(x: torch.Tensor,
m = weight.shape[0] m = weight.shape[0]
cu_count = current_platform.get_cu_count() cu_count = current_platform.get_cu_count()
if m > 8 and n < 4: if m > 8 and 0 < n < 4:
out = ops.wvSplitK(weight, x_view, cu_count) out = ops.wvSplitK(weight, x_view, cu_count)
return out.view(*x.shape[:-1], weight.shape[0]) return out.view(*x.shape[:-1], weight.shape[0])
elif m % 4 == 0 and n == 1 and k <= 8192: elif m % 4 == 0 and n == 1 and k <= 8192:
out = ops.LLMM1(weight, x_view, out, 4) out = ops.LLMM1(weight, x_view, 4)
return out.view(*x.shape[:-1], weight.shape[0]) return out.view(*x.shape[:-1], weight.shape[0])
return torch.nn.functional.linear(x, weight, bias) return torch.nn.functional.linear(x, weight, bias)

View File

@ -12,6 +12,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding) QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
from vllm.model_executor.parameter import BasevLLMParameter from vllm.model_executor.parameter import BasevLLMParameter
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
@ -40,7 +41,7 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return F.linear(x, layer.weight, bias) return dispatch_unquantized_gemm()(x, layer.weight, bias)
def embedding(self, layer: torch.nn.Module, def embedding(self, layer: torch.nn.Module,
input_: torch.Tensor) -> torch.Tensor: input_: torch.Tensor) -> torch.Tensor:

View File

@ -98,22 +98,22 @@ def device_id_to_physical_device_id(device_id: int) -> int:
return device_id return device_id
def on_mi250_mi300() -> bool:
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942"])
@cache @cache
def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int, def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
block_size: int, gqa_ratio: int, block_size: int, gqa_ratio: int,
max_seq_len: int, max_seq_len: int,
sliding_window: int) -> bool: sliding_window: int) -> bool:
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName # rocm custom page attention not support on gfx1*
ON_NAVI = "gfx1" in GPU_ARCH
ON_MI250_MI300 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942"])
# rocm custom page attention not support on navi (gfx1*)
# custom paged attn always supported on V0. On V1, requires sliding window # custom paged attn always supported on V0. On V1, requires sliding window
# disabled due to observed numerical discrepancy. # disabled due to observed numerical discrepancy.
return (ON_MI250_MI300 and not ON_NAVI return (on_mi250_mi300() and (not envs.VLLM_USE_V1 or sliding_window == 0
and (not envs.VLLM_USE_V1 or sliding_window == 0 or sliding_window == (-1, -1))
or sliding_window == (-1, -1))
and (qtype == torch.half or qtype == torch.bfloat16) and (qtype == torch.half or qtype == torch.bfloat16)
and (head_size == 64 or head_size == 128) and (head_size == 64 or head_size == 128)
and (block_size == 16 or block_size == 32) and (block_size == 16 or block_size == 32)