diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py
index e33b305322043..9b0ac38db5e3c 100644
--- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py
+++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py
@@ -22,6 +22,7 @@ def flashinfer_w8a8_scaled_mm(
     As: torch.Tensor,
     Bs: torch.Tensor,
     bias: torch.Tensor,
+    output_shape: list,
 ) -> torch.Tensor:
     return flashinfer_scaled_fp8_mm(
         A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias
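
The hunk only changes the signature: `output_shape` is accepted but never used in the body, since `flashinfer_scaled_fp8_mm` presumably produces the correct shape on its own. A plausible reading, sketched below, is that vLLM invokes all of its w8a8 scaled-mm kernels through one shared signature, so every backend must accept `output_shape` even if it ignores it. This is a minimal illustration, not vLLM's actual dispatch code; `reference_w8a8_scaled_mm` and `run_backend` are hypothetical names.

```python
# Sketch (assumption: NOT vLLM's real dispatch path) of why a backend may
# accept `output_shape` without using it: all backends share one signature.
from typing import Callable, Optional

import torch

# Uniform signature shared by all hypothetical w8a8 scaled-mm backends.
ScaledMMFn = Callable[
    [torch.Tensor, torch.Tensor, torch.dtype, torch.Tensor, torch.Tensor,
     Optional[torch.Tensor], list],
    torch.Tensor,
]


def reference_w8a8_scaled_mm(
    A: torch.Tensor,
    B: torch.Tensor,
    out_dtype: torch.dtype,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: Optional[torch.Tensor],
    output_shape: list,
) -> torch.Tensor:
    # Hypothetical reference backend: dequantizes with the per-tensor scales,
    # does the matmul in fp32, then uses output_shape to restore the
    # caller's leading dimensions.
    out = (A.to(torch.float32) * As) @ (B.to(torch.float32) * Bs)
    if bias is not None:
        out = out + bias
    return out.to(out_dtype).view(*output_shape)


def run_backend(backend: ScaledMMFn, A, B, out_dtype, As, Bs, bias,
                output_shape) -> torch.Tensor:
    # Every backend is invoked identically; one that does not need
    # output_shape (as flashinfer_w8a8_scaled_mm appears not to) simply
    # ignores the argument.
    return backend(A, B, out_dtype, As, Bs, bias, output_shape)


# Example call through the uniform signature:
A, B = torch.randn(4, 8), torch.randn(8, 16)
out = run_backend(reference_w8a8_scaled_mm, A, B, torch.float16,
                  torch.tensor(1.0), torch.tensor(1.0), None, [4, 16])
```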