diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py
index 8ccd062afcc90..c35cd8a6e900e 100644
--- a/vllm/model_executor/layers/attention.py
+++ b/vllm/model_executor/layers/attention.py
@@ -73,7 +73,12 @@ class PagedAttention(nn.Module):
             raise ValueError(f"head_size ({self.head_size}) is not supported. "
                              f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")
 
-    def set_attn_bias(self, input_metadata: InputMetadata) -> None:
+    def set_attn_bias(
+        self,
+        input_metadata: InputMetadata,
+        dtype: torch.dtype,
+    ) -> None:
+        del dtype  # Unused.
         if input_metadata.attn_bias:
             # Already set by a previous layer.
             return
@@ -196,7 +201,7 @@ class PagedAttention(nn.Module):
         if num_prompt_tokens > 0:
             # Prompt run.
             assert input_metadata.num_generation_tokens == 0
-            self.set_attn_bias(input_metadata)
+            self.set_attn_bias(input_metadata, dtype=query.dtype)
             self.multi_query_kv_attention(
                 output[:num_prompt_tokens],
                 query[:num_prompt_tokens],
@@ -340,13 +345,14 @@ class PagedAttentionWithALiBi(PagedAttention):
         slopes = torch.tensor(slopes, dtype=torch.float32)
         self.register_buffer("alibi_slopes", slopes, persistent=False)
 
-    def set_attn_bias(self, input_metadata: InputMetadata) -> None:
+    def set_attn_bias(self, input_metadata: InputMetadata,
+                      dtype: torch.dtype) -> None:
         if input_metadata.attn_bias:
             # Already set by a previous layer.
             return
         # Generates ALiBi mask for each prompt.
         for prompt_len in input_metadata.prompt_lens:
-            bias = torch.arange(prompt_len)
+            bias = torch.arange(prompt_len, dtype=dtype)
             # Note(zhuohan): HF uses
             #     `bias = bias[None, :].repeat(prompt_len, 1)`
             # here. We find that both biases give the same results, but
@@ -364,6 +370,7 @@ class PagedAttentionWithALiBi(PagedAttention):
                 prompt_len,
                 padded_len,
                 device=self.alibi_slopes.device,
+                dtype=dtype,
             )[:, :, :, :prompt_len].copy_(bias)
             bias.mul_(self.alibi_slopes[:, None, None])
             attn_bias = LowerTriangularMaskWithTensorBias(bias)
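
A minimal standalone sketch (not part of the diff) of the dtype mismatch this change appears to address: without an explicit dtype, torch.arange returns int64 and torch.empty uses the default float32, so the ALiBi bias ends up float32 even when the query is float16; threading the query dtype through keeps them consistent. The names num_heads, prompt_len, and query_dtype below are illustrative placeholders, not vLLM identifiers.

    import torch

    num_heads, prompt_len = 4, 5
    padded_len = (prompt_len + 7) // 8 * 8  # same multiple-of-8 padding as the patched code

    # Old behaviour: arange defaults to int64, empty defaults to float32,
    # so the final bias stays float32 regardless of the model dtype.
    bias = torch.arange(prompt_len)
    bias = bias[None, :] - bias[:, None]
    old_bias = torch.empty(num_heads, prompt_len, padded_len)[:, :, :prompt_len]
    old_bias.copy_(bias)
    print(old_bias.dtype)  # torch.float32

    # New behaviour: pass the query dtype through, as the diff does.
    query_dtype = torch.float16
    bias = torch.arange(prompt_len, dtype=query_dtype)
    bias = bias[None, :] - bias[:, None]
    new_bias = torch.empty(num_heads, prompt_len, padded_len,
                           dtype=query_dtype)[:, :, :prompt_len]
    new_bias.copy_(bias)
    print(new_bias.dtype)  # torch.float16, matches the query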