mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 07:04:53 +08:00
[ROCm][Misc] Follow-ups for Skinny Gemms on ROCm. (#17011)
Signed-off-by: charlifu <charlifu@amd.com>
This commit is contained in:
parent
9e96f56efb
commit
54271bb766
@ -155,8 +155,9 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
|
||||
scale_b: torch.Tensor, bias: torch.Tensor,
|
||||
input_2d: torch.Tensor,
|
||||
output_shape: List) -> torch.Tensor:
|
||||
if envs.VLLM_ROCM_USE_SKINNY_GEMM and qinput.shape[
|
||||
0] == 1 and qinput.shape[1] % 16 == 0:
|
||||
from vllm.platforms.rocm import on_mi250_mi300
|
||||
if envs.VLLM_ROCM_USE_SKINNY_GEMM and not on_mi250_mi300(
|
||||
) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0:
|
||||
output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b,
|
||||
current_platform.get_cu_count())
|
||||
else:
|
||||
@ -371,7 +372,7 @@ class Fp8LinearOp:
|
||||
|
||||
return w8a8_scaled_mm_func(qinput=qinput,
|
||||
weight=weight,
|
||||
out_dtype=input.dtype,
|
||||
out_dtype=out_dtype,
|
||||
scale_a=x_scale,
|
||||
scale_b=weight_scale,
|
||||
bias=bias,
|
||||
|
||||
@ -70,8 +70,9 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
|
||||
def rocm_unquantized_gemm(x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None):
|
||||
from vllm.platforms.rocm import on_mi250_mi300
|
||||
k = weight.shape[1]
|
||||
use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and \
|
||||
use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi250_mi300() and \
|
||||
x.dtype in [torch.float16, torch.bfloat16] \
|
||||
and k % 8 == 0 and bias is None)
|
||||
|
||||
@ -83,11 +84,11 @@ def rocm_unquantized_gemm(x: torch.Tensor,
|
||||
m = weight.shape[0]
|
||||
cu_count = current_platform.get_cu_count()
|
||||
|
||||
if m > 8 and n < 4:
|
||||
if m > 8 and 0 < n < 4:
|
||||
out = ops.wvSplitK(weight, x_view, cu_count)
|
||||
return out.view(*x.shape[:-1], weight.shape[0])
|
||||
elif m % 4 == 0 and n == 1 and k <= 8192:
|
||||
out = ops.LLMM1(weight, x_view, out, 4)
|
||||
out = ops.LLMM1(weight, x_view, 4)
|
||||
return out.view(*x.shape[:-1], weight.shape[0])
|
||||
return torch.nn.functional.linear(x, weight, bias)
|
||||
|
||||
|
||||
@ -12,6 +12,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
|
||||
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
|
||||
from vllm.model_executor.parameter import BasevLLMParameter
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
@ -40,7 +41,7 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
return F.linear(x, layer.weight, bias)
|
||||
return dispatch_unquantized_gemm()(x, layer.weight, bias)
|
||||
|
||||
def embedding(self, layer: torch.nn.Module,
|
||||
input_: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
@ -98,22 +98,22 @@ def device_id_to_physical_device_id(device_id: int) -> int:
|
||||
return device_id
|
||||
|
||||
|
||||
def on_mi250_mi300() -> bool:
|
||||
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
|
||||
return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942"])
|
||||
|
||||
|
||||
@cache
|
||||
def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
|
||||
block_size: int, gqa_ratio: int,
|
||||
max_seq_len: int,
|
||||
sliding_window: int) -> bool:
|
||||
|
||||
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
|
||||
ON_NAVI = "gfx1" in GPU_ARCH
|
||||
ON_MI250_MI300 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942"])
|
||||
|
||||
# rocm custom page attention not support on navi (gfx1*)
|
||||
# rocm custom page attention not support on gfx1*
|
||||
# custom paged attn always supported on V0. On V1, requires sliding window
|
||||
# disabled due to observed numerical discrepancy.
|
||||
return (ON_MI250_MI300 and not ON_NAVI
|
||||
and (not envs.VLLM_USE_V1 or sliding_window == 0
|
||||
or sliding_window == (-1, -1))
|
||||
return (on_mi250_mi300() and (not envs.VLLM_USE_V1 or sliding_window == 0
|
||||
or sliding_window == (-1, -1))
|
||||
and (qtype == torch.half or qtype == torch.bfloat16)
|
||||
and (head_size == 64 or head_size == 128)
|
||||
and (block_size == 16 or block_size == 32)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user