[Quantization] [AMD] Add support for running DeepSeek int8 w8a8 MoE on ROCm (#17558)
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
This commit is contained in:
parent d47b605eca
commit e3d0a1d190
@@ -559,7 +559,6 @@ def cutlass_scaled_mm(a: torch.Tensor,
     scale_a.shape * [1, 128] == a.shape
     scale_b.shape * [128, 128] == b.shape
     """
-    assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
     assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
     assert bias is None or bias.shape[0] == b.shape[
         1] and bias.dtype == out_dtype
@@ -567,7 +566,8 @@ def cutlass_scaled_mm(a: torch.Tensor,
     m = a.shape[0]
     n = b.shape[1]
 
-    if current_platform.is_rocm():
+    cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
+    if current_platform.is_rocm() or not cutlass_compatible_b:
         triton_scaled_mm_module = importlib.import_module(
             "vllm.model_executor.layers.quantization.compressed_tensors."
             "triton_scaled_mm")
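Taken together, the two hunks above stop asserting that `b` is 16-aligned and instead fall back to the Triton scaled-mm kernel whenever CUTLASS cannot be used, either because the platform is ROCm or because the weight shape is not a multiple of 16. A minimal sketch of the resulting dispatch decision, with `pick_scaled_mm_backend` as a hypothetical helper rather than anything in vllm:

import torch


def pick_scaled_mm_backend(b: torch.Tensor, is_rocm: bool) -> str:
    # Illustrative only: mirrors the new condition in cutlass_scaled_mm.
    # The CUTLASS scaled-mm path requires both dimensions of `b` to be
    # multiples of 16; the Triton kernel has no such restriction.
    cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
    if is_rocm or not cutlass_compatible_b:
        return "triton_scaled_mm"    # ROCm, or unaligned weight shapes
    return "cutlass_scaled_mm"       # CUDA with 16-aligned weight shapes


w = torch.empty(512, 1024, dtype=torch.int8)     # 16-aligned in both dims
print(pick_scaled_mm_backend(w, is_rocm=False))  # cutlass_scaled_mm
print(pick_scaled_mm_backend(w, is_rocm=True))   # triton_scaled_mm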
@@ -85,6 +85,32 @@ def block_dequant(
     return x_dq_block
 
 
+if current_platform.is_rocm():
+    from triton.language import core
+
+    # NOTE: This can be removed when hip.libdevice.round() is available.
+    @core.extern
+    def round_f32(arg0, _builder=None):
+        return core.extern_elementwise("",
+                                       "", [arg0], {
+                                           (core.dtype("fp32"), ):
+                                           ("llvm.round", core.dtype("fp32")),
+                                           (core.dtype("fp64"), ):
+                                           ("llvm.round", core.dtype("fp64")),
+                                       },
+                                       is_pure=True,
+                                       _builder=_builder)
+
+    @triton.jit
+    def round_int8(x):
+        return round_f32(x).to(tl.int8)
+else:
+
+    @triton.jit
+    def round_int8(x):
+        return tl.extra.cuda.libdevice.round(x).to(tl.int8)
+
+
 @triton.jit
 def _per_token_quant_int8(
     x_ptr,
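The hunk above gives the int8 quantization kernels a platform-independent `round_int8`: on CUDA it keeps using `tl.extra.cuda.libdevice.round`, while on ROCm, where that libdevice binding is not available yet, it binds `llvm.round` directly through `core.extern_elementwise`. Both round to the nearest integer with ties going away from zero, which differs from `torch.round`'s round-half-to-even. A small plain-PyTorch illustration of that rounding mode (my own sketch, not vllm code):

import torch


def round_half_away_from_zero(x: torch.Tensor) -> torch.Tensor:
    # Same rounding mode as llvm.round / libdevice round, assumed here to be
    # what round_int8 applies before the cast to tl.int8.
    return torch.sign(x) * torch.floor(torch.abs(x) + 0.5)


x = torch.tensor([-2.5, -0.5, 0.5, 1.5, 2.5])
print(round_half_away_from_zero(x))  # tensor([-3., -1.,  1.,  2.,  3.])
print(torch.round(x))                # ties to even: tensor([-2., -0., 0., 2., 2.])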
@@ -106,7 +132,7 @@ def _per_token_quant_int8(
     absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
     scale_x = absmax / 127
     x_q = x * (127 / absmax)
-    x_q = tl.extra.cuda.libdevice.round(x_q).to(tl.int8)
+    x_q = round_int8(x_q)
 
     tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
     tl.store(scale_ptr + row_id, scale_x)
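For reference, the kernel body touched by the last hunk performs symmetric per-token int8 quantization: one scale per row, computed from that row's absolute maximum, with the rounding now routed through `round_int8` on both platforms. A plain-PyTorch sketch of the same math (the helper name and the eager-mode formulation are mine, not vllm's):

import torch


def per_token_quant_int8_ref(x: torch.Tensor):
    # One scale per token (row): absmax / 127, clamped away from zero as in
    # the kernel's tl.maximum(..., 1e-10).
    absmax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10)
    scale = absmax / 127
    scaled = x * (127 / absmax)
    # Half-away-from-zero rounding to mimic llvm.round / libdevice round.
    x_q = torch.sign(scaled) * torch.floor(scaled.abs() + 0.5)
    return x_q.to(torch.int8), scale.squeeze(-1)


x = torch.randn(4, 8)
x_q, scale = per_token_quant_int8_ref(x)
# Dequantizing recovers x up to at most ~scale/2 of error per element.
print((x_q.float() * scale.unsqueeze(-1) - x).abs().max())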