diff --git a/tests/models/language/pooling/test_gritlm.py b/tests/models/language/pooling/test_gritlm.py
index 3ad6e7190942..7dd3c8a4e79e 100644
--- a/tests/models/language/pooling/test_gritlm.py
+++ b/tests/models/language/pooling/test_gritlm.py
@@ -11,7 +11,6 @@ from scipy.spatial.distance import cosine
 
 from vllm import LLM, SamplingParams
 from vllm.config import ModelConfig
-from vllm.utils import STR_BACKEND_ENV_VAR
 
 from ....utils import RemoteOpenAIServer
 
@@ -117,44 +116,37 @@ def validate_embed_output(q_rep: list[list[float]], d_rep: list[list[float]]):
     assert math.isclose(cosine_sim_q1_d1, 0.534, abs_tol=0.001)
 
 
-def test_gritlm_offline_embedding(monkeypatch: pytest.MonkeyPatch,
-                                  vllm_runner):
-    # GritLM embedding implementation is only supported by XFormers backend.
-    with monkeypatch.context() as m:
-        m.setenv(STR_BACKEND_ENV_VAR, "XFORMERS")
+def test_gritlm_offline_embedding(vllm_runner):
+    queries, q_instruction, documents, d_instruction = get_test_data()
 
-        queries, q_instruction, documents, d_instruction = get_test_data()
+    with vllm_runner(
+        MODEL_NAME,
+        task="embed",
+        max_model_len=MAX_MODEL_LEN,
+    ) as vllm_model:
+        llm = vllm_model.model
 
-        with vllm_runner(
-                MODEL_NAME,
-                task="embed",
-                max_model_len=MAX_MODEL_LEN,
-        ) as vllm_model:
-            llm = vllm_model.model
+        d_rep = run_llm_encode(
+            llm,
+            documents,
+            d_instruction,
+        )
+        q_rep = run_llm_encode(
+            llm,
+            queries,
+            q_instruction,
+        )
 
-            d_rep = run_llm_encode(
-                llm,
-                documents,
-                d_instruction,
-            )
-            q_rep = run_llm_encode(
-                llm,
-                queries,
-                q_instruction,
-            )
-
-            validate_embed_output(q_rep, d_rep)
+    validate_embed_output(q_rep, d_rep)
 
 
 @pytest.mark.asyncio
 async def test_gritlm_api_server_embedding():
     queries, q_instruction, documents, d_instruction = get_test_data()
 
-    # GritLM embedding implementation is only supported by XFormers backend.
     args = ["--task", "embed", "--max_model_len", str(MAX_MODEL_LEN)]
-    env_dict = {STR_BACKEND_ENV_VAR: "XFORMERS"}
 
-    with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as server:
+    with RemoteOpenAIServer(MODEL_NAME, args) as server:
         client_embedding = server.get_async_client()
 
         d_rep = await run_client_embeddings(
@@ -172,35 +164,28 @@ async def test_gritlm_api_server_embedding():
 
 
 def test_gritlm_offline_generate(monkeypatch: pytest.MonkeyPatch, vllm_runner):
-    # GritLM embedding implementation is only supported by XFormers backend.
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "0")
-        m.setenv(STR_BACKEND_ENV_VAR, "XFORMERS")
+    input = "<|user|>\nWhat is the capital of France?\n<|assistant|>\n"
 
-        input = "<|user|>\nWhat is the capital of France?\n<|assistant|>\n"
+    with vllm_runner(
+        MODEL_NAME,
+        task="generate",
+        max_model_len=MAX_MODEL_LEN,
+    ) as vllm_model:
+        llm = vllm_model.model
 
-        with vllm_runner(
-                MODEL_NAME,
-                task="generate",
-                max_model_len=MAX_MODEL_LEN,
-        ) as vllm_model:
-            llm = vllm_model.model
+        sampling_params = SamplingParams(temperature=0.0, max_tokens=256)
+        outputs = llm.generate(input, sampling_params=sampling_params)
 
-            sampling_params = SamplingParams(temperature=0.0, max_tokens=256)
-            outputs = llm.generate(input, sampling_params=sampling_params)
-
-            assert outputs[0].outputs[0].text == "The capital of France is Paris."
+    assert outputs[0].outputs[0].text == "The capital of France is Paris."
 
 
 @pytest.mark.asyncio
 async def test_gritlm_api_server_generate():
     input = "<|user|>\nWhat is the capital of France?\n<|assistant|>\n"
 
-    # GritLM embedding implementation is only supported by XFormers backend.
     args = ["--task", "generate", "--max_model_len", str(MAX_MODEL_LEN)]
-    env_dict = {"VLLM_USE_V1": "0", STR_BACKEND_ENV_VAR: "XFORMERS"}
 
-    with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as server:
+    with RemoteOpenAIServer(MODEL_NAME, args) as server:
         client_generate = server.get_async_client()
 
         outputs = await client_generate.completions.create(
diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py
index e4692c458088..6a444e8d1068 100644
--- a/vllm/model_executor/models/gritlm.py
+++ b/vllm/model_executor/models/gritlm.py
@@ -1,22 +1,18 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from array import array
-from typing import Optional, Union
+from typing import Optional
 
 import torch
 import torch.nn as nn
-from xformers.ops.fmha.attn_bias import BlockDiagonalMask
 
-from vllm.attention.backends.xformers import XFormersImpl
 from vllm.config import ModelConfig, VllmConfig
-from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.pooler import PoolerHead
 from vllm.model_executor.models.llama import LlamaForCausalLM
 from vllm.model_executor.pooling_metadata import (PoolingMetadata,
                                                   PoolingTensors)
-from vllm.sequence import (IntermediateTensors, PoolerOutput,
-                           PoolingSequenceGroupOutput)
+from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
 from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
 
 from .interfaces import SupportsV0Only
@@ -204,39 +200,21 @@ class GritLM(LlamaForCausalLM, SupportsV0Only):
         prefix: str = "",
         **kwargs,
     ) -> None:
+        # Use full attention for pooling
+        if vllm_config.model_config.runner_type == "pooling":
+            hf_config = vllm_config.model_config.hf_config
+            hf_config.is_causal = False
+
+            vllm_config.cache_config.sliding_window = None
+
+            for attr in ("sliding_window", "interleaved_sliding_window"):
+                if hasattr(hf_config, attr):
+                    delattr(hf_config, attr)
+
         super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
 
-        self.runner_type = vllm_config.model_config.runner_type
-
         self._pooler = GritLMPooler(vllm_config.model_config)
 
-        for layer in self.model.layers:
-            if self.runner_type == "pooling" and hasattr(layer, "self_attn"):
-                assert isinstance(layer.self_attn.attn.impl, XFormersImpl), (
-                    "GritLM embedding is only supported by XFormers backend, "
-                    "which can be forced by VLLM_ATTENTION_BACKEND=XFORMERS")
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        **kwargs,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-
-        # Change attention to non-causal for pooling tasks.
-        if self.runner_type == "pooling":
-            attn_metadata = get_forward_context().attn_metadata
-            assert attn_metadata.prefill_metadata.attn_bias is None
-            attn_metadata.prefill_metadata.attn_bias = [
-                BlockDiagonalMask.from_seqlens(attn_metadata.seq_lens)
-            ]
-
-        return super().forward(
-            input_ids=input_ids,
-            positions=positions,
-            **kwargs,
-        )
-
     def pooler(
         self,
         hidden_states: torch.Tensor,
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 7a3ea7a68768..c1593dcbe344 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -28,7 +28,7 @@ import torch
 from torch import nn
 from transformers import LlamaConfig
 
-from vllm.attention import Attention
+from vllm.attention import Attention, AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -96,19 +96,22 @@ class LlamaMLP(nn.Module):
 
 class LlamaAttention(nn.Module):
 
-    def __init__(self,
-                 config: LlamaConfig,
-                 hidden_size: int,
-                 num_heads: int,
-                 num_kv_heads: int,
-                 rope_theta: float = 10000,
-                 rope_scaling: Optional[Dict[str, Any]] = None,
-                 max_position_embeddings: int = 8192,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 bias: bool = False,
-                 bias_o_proj: bool = False,
-                 cache_config: Optional[CacheConfig] = None,
-                 prefix: str = "") -> None:
+    def __init__(
+        self,
+        config: LlamaConfig,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        quant_config: Optional[QuantizationConfig] = None,
+        bias: bool = False,
+        bias_o_proj: bool = False,
+        cache_config: Optional[CacheConfig] = None,
+        prefix: str = "",
+        attn_type: str = AttentionType.DECODER,
+    ) -> None:
         super().__init__()
         layer_idx = extract_layer_index(prefix)
         self.hidden_size = hidden_size
@@ -194,6 +197,7 @@ class LlamaAttention(nn.Module):
             cache_config=cache_config,
             quant_config=quant_config,
             per_layer_sliding_window=sliding_window,
+            attn_type=attn_type,
             prefix=f"{prefix}.attn",
         )
 
@@ -238,6 +242,15 @@ class LlamaDecoderLayer(nn.Module):
         if hasattr(config, 'qkv_bias'):
            attention_bias = config.qkv_bias
 
+        # By default, Llama uses causal attention as it is a decoder-only model.
+        # You can override the HF config with `is_causal=False` to enable
+        # bidirectional attention, which is used in some embedding models
+        # (e.g. parasail-ai/GritLM-7B-vllm)
+        if getattr(config, "is_causal", True):
+            attn_type = AttentionType.DECODER
+        else:
+            attn_type = AttentionType.ENCODER_ONLY
+
         self.self_attn = LlamaAttention(
             config=config,
             hidden_size=self.hidden_size,
@@ -252,6 +265,7 @@ class LlamaDecoderLayer(nn.Module):
             bias_o_proj=bias_o_proj,
             cache_config=cache_config,
             prefix=f"{prefix}.self_attn",
+            attn_type=attn_type,
         )
         self.mlp = LlamaMLP(
             hidden_size=self.hidden_size,
diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py
index b5850011e7fc..60f8a7cd7270 100644
--- a/vllm/model_executor/models/qwen2.py
+++ b/vllm/model_executor/models/qwen2.py
@@ -100,19 +100,19 @@ class Qwen2MLP(nn.Module):
 class Qwen2Attention(nn.Module):
 
     def __init__(
-            self,
-            hidden_size: int,
-            num_heads: int,
-            num_kv_heads: int,
-            max_position: int = 4096 * 32,
-            rope_theta: float = 10000,
-            cache_config: Optional[CacheConfig] = None,
-            quant_config: Optional[QuantizationConfig] = None,
-            rope_scaling: Optional[Tuple] = None,
-            prefix: str = "",
-            attn_type: str = AttentionType.DECODER,
-            dual_chunk_attention_config: Optional[dict[str,
-                                                       Any]] = None) -> None:
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        max_position: int = 4096 * 32,
+        rope_theta: float = 10000,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        rope_scaling: Optional[Tuple] = None,
+        prefix: str = "",
+        attn_type: str = AttentionType.DECODER,
+        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
+    ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()