diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index e26c90bf70cbe..9dbd0663eeff5 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1270,7 +1270,7 @@ def scaled_fp8_quant( device=input.device, dtype=torch.float32) torch.ops._C.dynamic_per_token_scaled_fp8_quant( - output, input, scale, scale_ub) + output, input.contiguous(), scale, scale_ub) else: scale = torch.zeros(1, device=input.device, dtype=torch.float32) torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) @@ -1379,8 +1379,8 @@ def scaled_int8_quant( dtype=torch.float32) input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32) - torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales, - input_azp) + torch.ops._C.dynamic_scaled_int8_quant(output, input.contiguous(), + input_scales, input_azp) return output, input_scales, input_azp