mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-06 14:35:40 +08:00
[ROCm] gemm_a16w16 upstreaming (#26969)
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com> Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>
This commit is contained in:
parent
1fb4217a05
commit
2d977a7a9e
@ -103,12 +103,41 @@ def default_unquantized_gemm(
|
|||||||
return torch.nn.functional.linear(x, weight, bias)
|
return torch.nn.functional.linear(x, weight, bias)
|
||||||
|
|
||||||
|
|
||||||
|
def use_aiter_triton_gemm(n, m, k, dtype):
|
||||||
|
if (
|
||||||
|
envs.VLLM_ROCM_USE_AITER == 0
|
||||||
|
# MI300's - fp8nuz=True
|
||||||
|
or current_platform.is_fp8_fnuz()
|
||||||
|
or dtype not in [torch.float16, torch.bfloat16]
|
||||||
|
):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# use hipblaslt for the larger GEMMs
|
||||||
|
if n > 2048 and m > 512:
|
||||||
|
return False
|
||||||
|
return (
|
||||||
|
(m == 5120 and k == 2880)
|
||||||
|
or (m == 2880 and k == 4096)
|
||||||
|
or (m == 128 and k == 2880)
|
||||||
|
or (m == 640 and k == 2880)
|
||||||
|
or (m == 2880 and k == 512)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def rocm_unquantized_gemm_impl(
|
def rocm_unquantized_gemm_impl(
|
||||||
x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None
|
x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
from vllm.platforms.rocm import on_gfx9
|
from vllm.platforms.rocm import on_gfx9
|
||||||
|
|
||||||
|
n = x.numel() / x.size(-1)
|
||||||
|
m = weight.shape[0]
|
||||||
k = weight.shape[1]
|
k = weight.shape[1]
|
||||||
|
|
||||||
|
if use_aiter_triton_gemm(n, m, k, x.dtype):
|
||||||
|
from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
|
||||||
|
|
||||||
|
return gemm_a16w16(x, weight, bias)
|
||||||
|
|
||||||
use_skinny = (
|
use_skinny = (
|
||||||
envs.VLLM_ROCM_USE_SKINNY_GEMM
|
envs.VLLM_ROCM_USE_SKINNY_GEMM
|
||||||
and on_gfx9()
|
and on_gfx9()
|
||||||
@ -120,11 +149,8 @@ def rocm_unquantized_gemm_impl(
|
|||||||
return torch.nn.functional.linear(x, weight, bias)
|
return torch.nn.functional.linear(x, weight, bias)
|
||||||
|
|
||||||
x_view = x.reshape(-1, x.size(-1))
|
x_view = x.reshape(-1, x.size(-1))
|
||||||
n = x_view.shape[0]
|
|
||||||
m = weight.shape[0]
|
|
||||||
cu_count = current_platform.get_cu_count()
|
|
||||||
|
|
||||||
if m > 8 and 0 < n <= 4:
|
if m > 8 and 0 < n <= 4:
|
||||||
|
cu_count = current_platform.get_cu_count()
|
||||||
out = ops.wvSplitK(weight, x_view, cu_count, bias)
|
out = ops.wvSplitK(weight, x_view, cu_count, bias)
|
||||||
return out.reshape(*x.shape[:-1], weight.shape[0])
|
return out.reshape(*x.shape[:-1], weight.shape[0])
|
||||||
elif m % 4 == 0 and n == 1 and k <= 8192 and bias is None:
|
elif m % 4 == 0 and n == 1 and k <= 8192 and bias is None:
|
||||||
@ -133,7 +159,7 @@ def rocm_unquantized_gemm_impl(
|
|||||||
return torch.nn.functional.linear(x, weight, bias)
|
return torch.nn.functional.linear(x, weight, bias)
|
||||||
|
|
||||||
|
|
||||||
def rocm_unquantized_gemm_impl_fake(
|
def rocm_unquantized_gemm_fake(
|
||||||
x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None
|
x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
return x.new_empty((*x.shape[:-1], weight.shape[0]))
|
return x.new_empty((*x.shape[:-1], weight.shape[0]))
|
||||||
@ -145,13 +171,13 @@ def rocm_unquantized_gemm(
|
|||||||
weight: torch.Tensor,
|
weight: torch.Tensor,
|
||||||
bias: torch.Tensor | None = None,
|
bias: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
return torch.ops.vllm.rocm_unquantized_gemm_impl(x, weight, bias)
|
return torch.ops.vllm.rocm_unquantized_gemm(x, weight, bias)
|
||||||
|
|
||||||
|
|
||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
op_name="rocm_unquantized_gemm_impl",
|
op_name="rocm_unquantized_gemm",
|
||||||
op_func=rocm_unquantized_gemm_impl,
|
op_func=rocm_unquantized_gemm_impl,
|
||||||
fake_impl=rocm_unquantized_gemm_impl_fake,
|
fake_impl=rocm_unquantized_gemm_fake,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -25,12 +25,14 @@ from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLine
|
|||||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
from vllm.model_executor.layers.utils import rocm_unquantized_gemm
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
ParallelLMHead,
|
ParallelLMHead,
|
||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.model_executor.models.utils import sequence_parallel_chunk
|
from vllm.model_executor.models.utils import sequence_parallel_chunk
|
||||||
|
from vllm.platforms import current_platform
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.utils.math_utils import cdiv
|
from vllm.utils.math_utils import cdiv
|
||||||
|
|
||||||
@ -153,6 +155,7 @@ class MLPBlock(torch.nn.Module):
|
|||||||
|
|
||||||
self.layer_idx = layer_idx
|
self.layer_idx = layer_idx
|
||||||
self.num_experts = config.num_local_experts
|
self.num_experts = config.num_local_experts
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
self.experts_per_token = config.num_experts_per_tok
|
self.experts_per_token = config.num_experts_per_tok
|
||||||
self.world_size = dist.get_world_size() if dist.is_initialized() else 1
|
self.world_size = dist.get_world_size() if dist.is_initialized() else 1
|
||||||
self.router = torch.nn.Linear(config.hidden_size, config.num_local_experts)
|
self.router = torch.nn.Linear(config.hidden_size, config.num_local_experts)
|
||||||
@ -177,7 +180,12 @@ class MLPBlock(torch.nn.Module):
|
|||||||
if self.is_sequence_parallel:
|
if self.is_sequence_parallel:
|
||||||
x = sequence_parallel_chunk(x)
|
x = sequence_parallel_chunk(x)
|
||||||
|
|
||||||
g = self.router(x)
|
if current_platform.is_rocm():
|
||||||
|
g = rocm_unquantized_gemm(
|
||||||
|
self, x[:, : self.hidden_size], self.router.weight, self.router.bias
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
g = self.router(x)
|
||||||
x = self.experts(hidden_states=x, router_logits=g)
|
x = self.experts(hidden_states=x, router_logits=g)
|
||||||
|
|
||||||
if self.is_sequence_parallel:
|
if self.is_sequence_parallel:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user