diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index cf296a3b534b..35345b1be01c 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1282,10 +1282,11 @@ def scaled_fp8_quant( output, input.contiguous(), scale, scale_ub) else: scale = torch.zeros(1, device=input.device, dtype=torch.float32) - torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) + torch.ops._C.dynamic_scaled_fp8_quant(output, input.contiguous(), + scale) else: assert scale.numel() == 1, f"{scale.shape}" - torch.ops._C.static_scaled_fp8_quant(output, input, scale) + torch.ops._C.static_scaled_fp8_quant(output, input.contiguous(), scale) return output, scale