diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 7bb01507ac2c..64f4310151cd 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -559,7 +559,6 @@ def cutlass_scaled_mm(a: torch.Tensor,
         scale_a.shape * [1, 128] == a.shape
         scale_b.shape * [128, 128] == b.shape
     """
-    assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
     assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
     assert bias is None or bias.shape[0] == b.shape[
         1] and bias.dtype == out_dtype
@@ -567,7 +566,8 @@ def cutlass_scaled_mm(a: torch.Tensor,
     m = a.shape[0]
     n = b.shape[1]
 
-    if current_platform.is_rocm():
+    cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
+    if current_platform.is_rocm() or not cutlass_compatible_b:
         triton_scaled_mm_module = importlib.import_module(
             "vllm.model_executor.layers.quantization.compressed_tensors."
             "triton_scaled_mm")
diff --git a/vllm/model_executor/layers/quantization/utils/int8_utils.py b/vllm/model_executor/layers/quantization/utils/int8_utils.py
index 98b06b6c2ae9..aaaf7a9e0a4c 100644
--- a/vllm/model_executor/layers/quantization/utils/int8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/int8_utils.py
@@ -85,6 +85,32 @@ def block_dequant(
     return x_dq_block
 
 
+if current_platform.is_rocm():
+    from triton.language import core
+
+    # NOTE: This can be removed when hip.libdevice.round() is available.
+    @core.extern
+    def round_f32(arg0, _builder=None):
+        return core.extern_elementwise("",
+                                       "", [arg0], {
+                                           (core.dtype("fp32"), ):
+                                           ("llvm.round", core.dtype("fp32")),
+                                           (core.dtype("fp64"), ):
+                                           ("llvm.round", core.dtype("fp64")),
+                                       },
+                                       is_pure=True,
+                                       _builder=_builder)
+
+    @triton.jit
+    def round_int8(x):
+        return round_f32(x).to(tl.int8)
+else:
+
+    @triton.jit
+    def round_int8(x):
+        return tl.extra.cuda.libdevice.round(x).to(tl.int8)
+
+
 @triton.jit
 def _per_token_quant_int8(
     x_ptr,
@@ -106,7 +132,7 @@ def _per_token_quant_int8(
     absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
     scale_x = absmax / 127
     x_q = x * (127 / absmax)
-    x_q = tl.extra.cuda.libdevice.round(x_q).to(tl.int8)
+    x_q = round_int8(x_q)
 
     tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
     tl.store(scale_ptr + row_id, scale_x)
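
For context, a minimal sketch (not part of the patch) of the case this change enables: with the shape assert removed, cutlass_scaled_mm routes weights whose dimensions are not multiples of 16 to the Triton scaled-mm path instead of failing. The shapes, scale layouts, and device below are illustrative assumptions, not values taken from the patch.

    # Illustrative only. K and N are deliberately not multiples of 16, so this
    # call would have hit the removed assert before the change; it now falls
    # back to the Triton scaled-mm kernel (the same path already used on ROCm).
    import torch
    from vllm import _custom_ops as ops

    M, K, N = 4, 24, 24
    a = torch.randint(-128, 127, (M, K), dtype=torch.int8, device="cuda")
    b = torch.randint(-128, 127, (K, N), dtype=torch.int8, device="cuda")
    scale_a = torch.rand(M, 1, dtype=torch.float32, device="cuda")  # per-token
    scale_b = torch.rand(1, N, dtype=torch.float32, device="cuda")  # per-channel

    out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.float16)
    print(out.shape)  # torch.Size([4, 24])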