mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-24 14:56:22 +08:00
[ROCm] gemm_a16w16 upstreaming (#26969)
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com> Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>
This commit is contained in:
parent
1fb4217a05
commit
2d977a7a9e
@ -103,12 +103,41 @@ def default_unquantized_gemm(
|
||||
return torch.nn.functional.linear(x, weight, bias)
|
||||
|
||||
|
||||
def use_aiter_triton_gemm(n, m, k, dtype):
|
||||
if (
|
||||
envs.VLLM_ROCM_USE_AITER == 0
|
||||
# MI300's - fp8nuz=True
|
||||
or current_platform.is_fp8_fnuz()
|
||||
or dtype not in [torch.float16, torch.bfloat16]
|
||||
):
|
||||
return False
|
||||
|
||||
# use hipblaslt for the larger GEMMs
|
||||
if n > 2048 and m > 512:
|
||||
return False
|
||||
return (
|
||||
(m == 5120 and k == 2880)
|
||||
or (m == 2880 and k == 4096)
|
||||
or (m == 128 and k == 2880)
|
||||
or (m == 640 and k == 2880)
|
||||
or (m == 2880 and k == 512)
|
||||
)
|
||||
|
||||
|
||||
def rocm_unquantized_gemm_impl(
|
||||
x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
from vllm.platforms.rocm import on_gfx9
|
||||
|
||||
n = x.numel() / x.size(-1)
|
||||
m = weight.shape[0]
|
||||
k = weight.shape[1]
|
||||
|
||||
if use_aiter_triton_gemm(n, m, k, x.dtype):
|
||||
from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
|
||||
|
||||
return gemm_a16w16(x, weight, bias)
|
||||
|
||||
use_skinny = (
|
||||
envs.VLLM_ROCM_USE_SKINNY_GEMM
|
||||
and on_gfx9()
|
||||
@ -120,11 +149,8 @@ def rocm_unquantized_gemm_impl(
|
||||
return torch.nn.functional.linear(x, weight, bias)
|
||||
|
||||
x_view = x.reshape(-1, x.size(-1))
|
||||
n = x_view.shape[0]
|
||||
m = weight.shape[0]
|
||||
cu_count = current_platform.get_cu_count()
|
||||
|
||||
if m > 8 and 0 < n <= 4:
|
||||
cu_count = current_platform.get_cu_count()
|
||||
out = ops.wvSplitK(weight, x_view, cu_count, bias)
|
||||
return out.reshape(*x.shape[:-1], weight.shape[0])
|
||||
elif m % 4 == 0 and n == 1 and k <= 8192 and bias is None:
|
||||
@ -133,7 +159,7 @@ def rocm_unquantized_gemm_impl(
|
||||
return torch.nn.functional.linear(x, weight, bias)
|
||||
|
||||
|
||||
def rocm_unquantized_gemm_impl_fake(
|
||||
def rocm_unquantized_gemm_fake(
|
||||
x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
return x.new_empty((*x.shape[:-1], weight.shape[0]))
|
||||
@ -145,13 +171,13 @@ def rocm_unquantized_gemm(
|
||||
weight: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return torch.ops.vllm.rocm_unquantized_gemm_impl(x, weight, bias)
|
||||
return torch.ops.vllm.rocm_unquantized_gemm(x, weight, bias)
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_unquantized_gemm_impl",
|
||||
op_name="rocm_unquantized_gemm",
|
||||
op_func=rocm_unquantized_gemm_impl,
|
||||
fake_impl=rocm_unquantized_gemm_impl_fake,
|
||||
fake_impl=rocm_unquantized_gemm_fake,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -25,12 +25,14 @@ from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLine
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.utils import rocm_unquantized_gemm
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.utils import sequence_parallel_chunk
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.math_utils import cdiv
|
||||
|
||||
@ -153,6 +155,7 @@ class MLPBlock(torch.nn.Module):
|
||||
|
||||
self.layer_idx = layer_idx
|
||||
self.num_experts = config.num_local_experts
|
||||
self.hidden_size = config.hidden_size
|
||||
self.experts_per_token = config.num_experts_per_tok
|
||||
self.world_size = dist.get_world_size() if dist.is_initialized() else 1
|
||||
self.router = torch.nn.Linear(config.hidden_size, config.num_local_experts)
|
||||
@ -177,7 +180,12 @@ class MLPBlock(torch.nn.Module):
|
||||
if self.is_sequence_parallel:
|
||||
x = sequence_parallel_chunk(x)
|
||||
|
||||
g = self.router(x)
|
||||
if current_platform.is_rocm():
|
||||
g = rocm_unquantized_gemm(
|
||||
self, x[:, : self.hidden_size], self.router.weight, self.router.bias
|
||||
)
|
||||
else:
|
||||
g = self.router(x)
|
||||
x = self.experts(hidden_states=x, router_logits=g)
|
||||
|
||||
if self.is_sequence_parallel:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user