diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 6d9c6f51b34d..f7d230c5d7d6 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -209,7 +209,7 @@ class Attention(nn.Module): if self.use_output: output_shape = (output_shape if output_shape is not None else query.shape) - output = torch.empty(output_shape, + output = torch.zeros(output_shape, dtype=query.dtype, device=query.device) hidden_size = output_shape[-1]