[ROCm][Misc] Follow-ups for Skinny Gemms on ROCm. (#17011)
Signed-off-by: charlifu <charlifu@amd.com>
commit 54271bb766
parent 9e96f56efb
@@ -155,8 +155,9 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
                                    scale_b: torch.Tensor, bias: torch.Tensor,
                                    input_2d: torch.Tensor,
                                    output_shape: List) -> torch.Tensor:
-    if envs.VLLM_ROCM_USE_SKINNY_GEMM and qinput.shape[
-            0] == 1 and qinput.shape[1] % 16 == 0:
+    from vllm.platforms.rocm import on_mi250_mi300
+    if envs.VLLM_ROCM_USE_SKINNY_GEMM and not on_mi250_mi300(
+    ) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0:
         output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b,
                                current_platform.get_cu_count())
     else:
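Note on the gate above: the wvSplitKQ branch is only meant for genuinely skinny problems, and it now also consults the architecture helper. A minimal, self-contained sketch of the updated eligibility check, assuming the env flag and the arch probe are passed in as plain booleans (the function name wvsplitkq_eligible is made up for illustration; the condition mirrors the diff):

import torch

def wvsplitkq_eligible(qinput: torch.Tensor, skinny_gemm_enabled: bool,
                       is_mi250_mi300: bool) -> bool:
    # Mirrors the new condition: env flag on, arch check, a single activation
    # row, and a reduction dimension that is a multiple of 16.
    return (skinny_gemm_enabled and not is_mi250_mi300
            and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0)

# Example: one row of 4096 features passes the shape checks.
print(wvsplitkq_eligible(torch.empty(1, 4096), True, False))  # True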
@@ -371,7 +372,7 @@ class Fp8LinearOp:

         return w8a8_scaled_mm_func(qinput=qinput,
                                    weight=weight,
-                                   out_dtype=input.dtype,
+                                   out_dtype=out_dtype,
                                    scale_a=x_scale,
                                    scale_b=weight_scale,
                                    bias=bias,
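The only change in this hunk is forwarding the out_dtype the caller already computed instead of re-deriving it from input.dtype; the two need not be equal, and presumably the old keyword silently overrode the requested output dtype when they differed. A tiny illustration with made-up values:

import torch

x = torch.randn(1, 16, dtype=torch.bfloat16)  # activation dtype
out_dtype = torch.float16                     # dtype the caller asked for
# Forwarding x.dtype here instead of out_dtype would ignore the request.
assert out_dtype != x.dtype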
@@ -70,8 +70,9 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
 def rocm_unquantized_gemm(x: torch.Tensor,
                           weight: torch.Tensor,
                           bias: Optional[torch.Tensor] = None):
+    from vllm.platforms.rocm import on_mi250_mi300
     k = weight.shape[1]
-    use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and \
+    use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi250_mi300() and \
                   x.dtype in [torch.float16, torch.bfloat16] \
                   and k % 8 == 0 and bias is None)

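The unquantized skinny-GEMM gate now additionally requires an MI250/MI300 GPU. A self-contained sketch of the updated check, with the env flag and arch probe passed in as booleans (the name skinny_gemm_eligible is made up; the condition matches the diff):

from typing import Optional
import torch

def skinny_gemm_eligible(x: torch.Tensor, weight: torch.Tensor,
                         bias: Optional[torch.Tensor],
                         skinny_gemm_enabled: bool,
                         is_mi250_mi300: bool) -> bool:
    # env flag, arch check, half-precision activations, k divisible by 8, no bias
    k = weight.shape[1]
    return (skinny_gemm_enabled and is_mi250_mi300
            and x.dtype in (torch.float16, torch.bfloat16)
            and k % 8 == 0 and bias is None)

print(skinny_gemm_eligible(torch.randn(1, 4096, dtype=torch.float16),
                           torch.randn(4096, 4096, dtype=torch.float16),
                           None, True, True))  # True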
@@ -83,11 +84,11 @@ def rocm_unquantized_gemm(x: torch.Tensor,
     m = weight.shape[0]
     cu_count = current_platform.get_cu_count()

-    if m > 8 and n < 4:
+    if m > 8 and 0 < n < 4:
         out = ops.wvSplitK(weight, x_view, cu_count)
         return out.view(*x.shape[:-1], weight.shape[0])
     elif m % 4 == 0 and n == 1 and k <= 8192:
-        out = ops.LLMM1(weight, x_view, out, 4)
+        out = ops.LLMM1(weight, x_view, 4)
         return out.view(*x.shape[:-1], weight.shape[0])
     return torch.nn.functional.linear(x, weight, bias)

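Two fixes in this dispatch body: the wvSplitK branch now excludes n == 0 (the old n < 4 also admitted an empty batch), and the stray out argument is dropped from the ops.LLMM1 call. A simplified, stand-alone sketch of the resulting kernel choice, assuming n is the number of rows of the flattened activation (it is defined earlier in the real function, outside this hunk):

def choose_rocm_gemm_kernel(m: int, n: int, k: int) -> str:
    # m: output features (weight.shape[0]), k: reduction dim (weight.shape[1]),
    # n: rows of the flattened activation. Mirrors the branch structure above.
    if m > 8 and 0 < n < 4:
        return "wvSplitK"
    if m % 4 == 0 and n == 1 and k <= 8192:
        return "LLMM1"
    return "torch.nn.functional.linear"

print(choose_rocm_gemm_kernel(m=4096, n=1, k=4096))  # wvSplitK
print(choose_rocm_gemm_kernel(m=8, n=1, k=4096))     # LLMM1
print(choose_rocm_gemm_kernel(m=4096, n=0, k=4096))  # falls through to F.linear now that n == 0 is excluded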
@@ -12,6 +12,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               tensor_model_parallel_all_reduce)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
+from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
 from vllm.model_executor.parameter import BasevLLMParameter
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
@@ -40,7 +41,7 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
               layer: torch.nn.Module,
               x: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        return F.linear(x, layer.weight, bias)
+        return dispatch_unquantized_gemm()(x, layer.weight, bias)

     def embedding(self, layer: torch.nn.Module,
                   input_: torch.Tensor) -> torch.Tensor:
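UnquantizedEmbeddingMethod.apply now goes through dispatch_unquantized_gemm() instead of calling F.linear directly, so ROCm can route this matmul to rocm_unquantized_gemm as well. The dispatcher's body is not part of this diff; the sketch below only illustrates the call contract visible here (called with no arguments, it returns a callable with F.linear's (x, weight, bias) signature), and the platform check and fallback are assumptions:

import torch
import torch.nn.functional as F

def dispatch_unquantized_gemm_sketch():
    # Illustrative stand-in for vllm.model_executor.layers.utils.dispatch_unquantized_gemm.
    # On a ROCm build of PyTorch, vLLM would hand back rocm_unquantized_gemm;
    # everywhere else plain F.linear is the natural fallback.
    if torch.version.hip is not None:
        return F.linear  # placeholder for rocm_unquantized_gemm in this sketch
    return F.linear

x = torch.randn(2, 8)
weight = torch.randn(16, 8)
out = dispatch_unquantized_gemm_sketch()(x, weight, None)  # same call shape as in the diff
print(out.shape)  # torch.Size([2, 16])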
@@ -98,22 +98,22 @@ def device_id_to_physical_device_id(device_id: int) -> int:
         return device_id


+def on_mi250_mi300() -> bool:
+    GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
+    return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942"])
+
+
 @cache
 def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
                                     block_size: int, gqa_ratio: int,
                                     max_seq_len: int,
                                     sliding_window: int) -> bool:

-    GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
-    ON_NAVI = "gfx1" in GPU_ARCH
-    ON_MI250_MI300 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942"])
-
-    # rocm custom page attention not support on navi (gfx1*)
+    # rocm custom page attention not support on gfx1*
     # custom paged attn always supported on V0. On V1, requires sliding window
     # disabled due to observed numerical discrepancy.
-    return (ON_MI250_MI300 and not ON_NAVI
-            and (not envs.VLLM_USE_V1 or sliding_window == 0
-                 or sliding_window == (-1, -1))
+    return (on_mi250_mi300() and (not envs.VLLM_USE_V1 or sliding_window == 0
+                                  or sliding_window == (-1, -1))
             and (qtype == torch.half or qtype == torch.bfloat16)
             and (head_size == 64 or head_size == 128)
             and (block_size == 16 or block_size == 32)
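This hunk replaces the ad-hoc gcnArchName probing inside use_rocm_custom_paged_attention with the new module-level on_mi250_mi300() helper, which is the same helper the earlier hunks import, so all the gates stay in sync. For reference, a standalone version of the probe the helper wraps (this needs a ROCm build of PyTorch with a visible GPU; gfx90a corresponds to MI210/MI250 and gfx942 to MI300-series parts):

import torch

def on_mi250_mi300_sketch() -> bool:
    # gcnArchName is reported by ROCm, e.g. "gfx90a:sramecc+:xnack-" or "gfx942".
    gpu_arch = torch.cuda.get_device_properties("cuda").gcnArchName
    return any(arch in gpu_arch for arch in ("gfx90a", "gfx942"))

if torch.version.hip is not None and torch.cuda.is_available():
    print(on_mi250_mi300_sketch())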