mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 04:15:01 +08:00
[Quantizaton] [AMD] Add support for running DeepSeek int8 w8a8 MoE on ROCm (#17558)
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
This commit is contained in:
parent
d47b605eca
commit
e3d0a1d190
@ -559,7 +559,6 @@ def cutlass_scaled_mm(a: torch.Tensor,
|
||||
scale_a.shape * [1, 128] == a.shape
|
||||
scale_b.shape * [128, 128] == b.shape
|
||||
"""
|
||||
assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
|
||||
assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
|
||||
assert bias is None or bias.shape[0] == b.shape[
|
||||
1] and bias.dtype == out_dtype
|
||||
@ -567,7 +566,8 @@ def cutlass_scaled_mm(a: torch.Tensor,
|
||||
m = a.shape[0]
|
||||
n = b.shape[1]
|
||||
|
||||
if current_platform.is_rocm():
|
||||
cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
|
||||
if current_platform.is_rocm() or not cutlass_compatible_b:
|
||||
triton_scaled_mm_module = importlib.import_module(
|
||||
"vllm.model_executor.layers.quantization.compressed_tensors."
|
||||
"triton_scaled_mm")
|
||||
|
||||
@ -85,6 +85,32 @@ def block_dequant(
|
||||
return x_dq_block
|
||||
|
||||
|
||||
if current_platform.is_rocm():
|
||||
from triton.language import core
|
||||
|
||||
# NOTE: This can be removed when hip.libdevice.round() is available.
|
||||
@core.extern
|
||||
def round_f32(arg0, _builder=None):
|
||||
return core.extern_elementwise("",
|
||||
"", [arg0], {
|
||||
(core.dtype("fp32"), ):
|
||||
("llvm.round", core.dtype("fp32")),
|
||||
(core.dtype("fp64"), ):
|
||||
("llvm.round", core.dtype("fp64")),
|
||||
},
|
||||
is_pure=True,
|
||||
_builder=_builder)
|
||||
|
||||
@triton.jit
|
||||
def round_int8(x):
|
||||
return round_f32(x).to(tl.int8)
|
||||
else:
|
||||
|
||||
@triton.jit
|
||||
def round_int8(x):
|
||||
return tl.extra.cuda.libdevice.round(x).to(tl.int8)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _per_token_quant_int8(
|
||||
x_ptr,
|
||||
@ -106,7 +132,7 @@ def _per_token_quant_int8(
|
||||
absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
|
||||
scale_x = absmax / 127
|
||||
x_q = x * (127 / absmax)
|
||||
x_q = tl.extra.cuda.libdevice.round(x_q).to(tl.int8)
|
||||
x_q = round_int8(x_q)
|
||||
|
||||
tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
|
||||
tl.store(scale_ptr + row_id, scale_x)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user