[Quantization] [AMD] Add support for running DeepSeek int8 w8a8 MoE on ROCm (#17558)

Signed-off-by: Randall Smith <Randall.Smith@amd.com>
rasmith 2025-05-02 23:41:10 -05:00 committed by GitHub
parent d47b605eca
commit e3d0a1d190
2 changed files with 29 additions and 3 deletions

vllm/_custom_ops.py

@@ -559,7 +559,6 @@ def cutlass_scaled_mm(a: torch.Tensor,
     scale_a.shape * [1, 128] == a.shape
     scale_b.shape * [128, 128] == b.shape
     """
-    assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
     assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
     assert bias is None or bias.shape[0] == b.shape[
         1] and bias.dtype == out_dtype
@@ -567,7 +566,8 @@ def cutlass_scaled_mm(a: torch.Tensor,
     m = a.shape[0]
     n = b.shape[1]
 
-    if current_platform.is_rocm():
+    cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
+    if current_platform.is_rocm() or not cutlass_compatible_b:
         triton_scaled_mm_module = importlib.import_module(
             "vllm.model_executor.layers.quantization.compressed_tensors."
             "triton_scaled_mm")

vllm/model_executor/layers/quantization/utils/int8_utils.py

@@ -85,6 +85,32 @@ def block_dequant(
     return x_dq_block
 
 
+if current_platform.is_rocm():
+    from triton.language import core
+
+    # NOTE: This can be removed when hip.libdevice.round() is available.
+    @core.extern
+    def round_f32(arg0, _builder=None):
+        return core.extern_elementwise("",
+                                       "", [arg0], {
+                                           (core.dtype("fp32"), ):
+                                           ("llvm.round", core.dtype("fp32")),
+                                           (core.dtype("fp64"), ):
+                                           ("llvm.round", core.dtype("fp64")),
+                                       },
+                                       is_pure=True,
+                                       _builder=_builder)
+
+    @triton.jit
+    def round_int8(x):
+        return round_f32(x).to(tl.int8)
+
+else:
+
+    @triton.jit
+    def round_int8(x):
+        return tl.extra.cuda.libdevice.round(x).to(tl.int8)
+
 @triton.jit
 def _per_token_quant_int8(
     x_ptr,
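
Why the explicit round matters: a bare `.to(tl.int8)` truncates toward zero, which systematically biases quantized values toward zero (e.g. 126.7 would become 126 instead of 127). ROCm's Triton does not yet expose a libdevice `round`, hence the `llvm.round` extern above. A plain-PyTorch illustration of the contract `round_int8` provides (note `torch.round` rounds half to even while `llvm.round` rounds half away from zero, so the example avoids exact halves):

```python
import torch

x = torch.tensor([126.7, -126.7, 0.3])
print(x.to(torch.int8).tolist())               # [126, -126, 0]  truncation bias
print(torch.round(x).to(torch.int8).tolist())  # [127, -127, 0]  round-to-nearest
```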
@@ -106,7 +132,7 @@ def _per_token_quant_int8(
     absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
     scale_x = absmax / 127
     x_q = x * (127 / absmax)
-    x_q = tl.extra.cuda.libdevice.round(x_q).to(tl.int8)
+    x_q = round_int8(x_q)
 
     tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
     tl.store(scale_ptr + row_id, scale_x)
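
For reference, the per-token quantization this kernel performs, written as a host-side PyTorch sketch (the function name and shapes are illustrative, not part of the diff): each row is scaled by its absolute maximum so that the largest value maps to ±127, and the per-row scale is kept for dequantization.

```python
import torch

def per_token_quant_int8_ref(x: torch.Tensor):
    """Mirror of _per_token_quant_int8: per-row absmax scaling to int8."""
    absmax = x.abs().amax(dim=-1, keepdim=True).clamp_min(1e-10)
    scale_x = absmax / 127
    x_q = torch.round(x * (127 / absmax)).to(torch.int8)
    return x_q, scale_x.squeeze(-1)

x = torch.randn(4, 16)
x_q, scale_x = per_token_quant_int8_ref(x)
x_dq = x_q.float() * scale_x.unsqueeze(-1)      # dequantize
# max quantization error per row is half a quantization step
assert (x - x_dq).abs().max() <= scale_x.max() / 2 + 1e-6
```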