[Model] GritLM supports other attention backends (#18109)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Authored by Cyrus Leung on 2025-05-14 18:33:19 +08:00; committed by GitHub
parent 259127f8b8
commit d62a076e84
4 changed files with 85 additions and 108 deletions
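
At a glance, GritLM's pooling path no longer depends on the XFormers backend: instead of patching the attention bias at forward time, it flips `is_causal` on the HF config so the Llama layers are built with bidirectional (encoder-only) attention, which any supported backend can serve. A minimal offline sketch of the resulting usage follows; the prompt template and the `LLM.embed` call are assumptions based on general GritLM/vLLM usage, not part of this diff:

```python
# Sketch only: offline GritLM embeddings after this change, with no
# VLLM_ATTENTION_BACKEND=XFORMERS override. The gritlm_prompt helper is
# hypothetical and mirrors the GritLM instruction template.
from vllm import LLM


def gritlm_prompt(instruction: str, text: str) -> str:
    # GritLM-style embedding prompt: instruction followed by the <|embed|> marker.
    return f"<|user|>\n{instruction}\n<|embed|>\n{text}"


llm = LLM(model="parasail-ai/GritLM-7B-vllm", task="embed")
outputs = llm.embed([
    gritlm_prompt("Given a query, retrieve relevant passages.",
                  "Paris is the capital of France."),
])
print(len(outputs[0].outputs.embedding))  # embedding dimensionality
```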


@@ -11,7 +11,6 @@
 from scipy.spatial.distance import cosine

 from vllm import LLM, SamplingParams
 from vllm.config import ModelConfig
-from vllm.utils import STR_BACKEND_ENV_VAR

 from ....utils import RemoteOpenAIServer
@@ -117,44 +116,37 @@ def validate_embed_output(q_rep: list[list[float]], d_rep: list[list[float]]):
     assert math.isclose(cosine_sim_q1_d1, 0.534, abs_tol=0.001)


-def test_gritlm_offline_embedding(monkeypatch: pytest.MonkeyPatch,
-                                  vllm_runner):
-    # GritLM embedding implementation is only supported by XFormers backend.
-    with monkeypatch.context() as m:
-        m.setenv(STR_BACKEND_ENV_VAR, "XFORMERS")
-
-        queries, q_instruction, documents, d_instruction = get_test_data()
-
-        with vllm_runner(
-                MODEL_NAME,
-                task="embed",
-                max_model_len=MAX_MODEL_LEN,
-        ) as vllm_model:
-            llm = vllm_model.model
-
-            d_rep = run_llm_encode(
-                llm,
-                documents,
-                d_instruction,
-            )
-            q_rep = run_llm_encode(
-                llm,
-                queries,
-                q_instruction,
-            )
-
-        validate_embed_output(q_rep, d_rep)
+def test_gritlm_offline_embedding(vllm_runner):
+    queries, q_instruction, documents, d_instruction = get_test_data()
+
+    with vllm_runner(
+            MODEL_NAME,
+            task="embed",
+            max_model_len=MAX_MODEL_LEN,
+    ) as vllm_model:
+        llm = vllm_model.model
+
+        d_rep = run_llm_encode(
+            llm,
+            documents,
+            d_instruction,
+        )
+        q_rep = run_llm_encode(
+            llm,
+            queries,
+            q_instruction,
+        )
+
+    validate_embed_output(q_rep, d_rep)


 @pytest.mark.asyncio
 async def test_gritlm_api_server_embedding():
     queries, q_instruction, documents, d_instruction = get_test_data()

-    # GritLM embedding implementation is only supported by XFormers backend.
     args = ["--task", "embed", "--max_model_len", str(MAX_MODEL_LEN)]
-    env_dict = {STR_BACKEND_ENV_VAR: "XFORMERS"}

-    with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as server:
+    with RemoteOpenAIServer(MODEL_NAME, args) as server:
         client_embedding = server.get_async_client()

         d_rep = await run_client_embeddings(
@@ -172,35 +164,28 @@ async def test_gritlm_api_server_embedding():


 def test_gritlm_offline_generate(monkeypatch: pytest.MonkeyPatch, vllm_runner):
-    # GritLM embedding implementation is only supported by XFormers backend.
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "0")
-        m.setenv(STR_BACKEND_ENV_VAR, "XFORMERS")
-
-        input = "<|user|>\nWhat is the capital of France?\n<|assistant|>\n"
-
-        with vllm_runner(
-                MODEL_NAME,
-                task="generate",
-                max_model_len=MAX_MODEL_LEN,
-        ) as vllm_model:
-            llm = vllm_model.model
-
-            sampling_params = SamplingParams(temperature=0.0, max_tokens=256)
-            outputs = llm.generate(input, sampling_params=sampling_params)
-
-        assert outputs[0].outputs[0].text == "The capital of France is Paris."
+    input = "<|user|>\nWhat is the capital of France?\n<|assistant|>\n"
+
+    with vllm_runner(
+            MODEL_NAME,
+            task="generate",
+            max_model_len=MAX_MODEL_LEN,
+    ) as vllm_model:
+        llm = vllm_model.model
+
+        sampling_params = SamplingParams(temperature=0.0, max_tokens=256)
+        outputs = llm.generate(input, sampling_params=sampling_params)
+
+    assert outputs[0].outputs[0].text == "The capital of France is Paris."


 @pytest.mark.asyncio
 async def test_gritlm_api_server_generate():
     input = "<|user|>\nWhat is the capital of France?\n<|assistant|>\n"

-    # GritLM embedding implementation is only supported by XFormers backend.
     args = ["--task", "generate", "--max_model_len", str(MAX_MODEL_LEN)]
-    env_dict = {"VLLM_USE_V1": "0", STR_BACKEND_ENV_VAR: "XFORMERS"}

-    with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as server:
+    with RemoteOpenAIServer(MODEL_NAME, args) as server:
         client_generate = server.get_async_client()

         outputs = await client_generate.completions.create(
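
The server tests above now launch the OpenAI-compatible server without any `env_dict` overrides. A hedged sketch of an equivalent client call; the server address, API key, and exact CLI invocation are assumptions rather than part of this diff:

```python
# Sketch only: assumes a server started roughly as
#   vllm serve parasail-ai/GritLM-7B-vllm --task embed
# with no VLLM_ATTENTION_BACKEND or VLLM_USE_V1 environment overrides.
import asyncio

from openai import AsyncOpenAI


async def main() -> None:
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    resp = await client.embeddings.create(
        model="parasail-ai/GritLM-7B-vllm",
        input=["<|embed|>\nParis is the capital of France."],
    )
    print(len(resp.data[0].embedding))  # embedding dimensionality


asyncio.run(main())
```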


@@ -1,22 +1,18 @@
 # SPDX-License-Identifier: Apache-2.0

 from array import array
-from typing import Optional, Union
+from typing import Optional

 import torch
 import torch.nn as nn
-from xformers.ops.fmha.attn_bias import BlockDiagonalMask

-from vllm.attention.backends.xformers import XFormersImpl
 from vllm.config import ModelConfig, VllmConfig
-from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.pooler import PoolerHead
 from vllm.model_executor.models.llama import LlamaForCausalLM
 from vllm.model_executor.pooling_metadata import (PoolingMetadata,
                                                   PoolingTensors)
-from vllm.sequence import (IntermediateTensors, PoolerOutput,
-                           PoolingSequenceGroupOutput)
+from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
 from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config

 from .interfaces import SupportsV0Only
@@ -204,39 +200,21 @@ class GritLM(LlamaForCausalLM, SupportsV0Only):
         prefix: str = "",
         **kwargs,
     ) -> None:
+        # Use full attention for pooling
+        if vllm_config.model_config.runner_type == "pooling":
+            hf_config = vllm_config.model_config.hf_config
+            hf_config.is_causal = False
+
+            vllm_config.cache_config.sliding_window = None
+
+            for attr in ("sliding_window", "interleaved_sliding_window"):
+                if hasattr(hf_config, attr):
+                    delattr(hf_config, attr)
+
         super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)

-        self.runner_type = vllm_config.model_config.runner_type
-
         self._pooler = GritLMPooler(vllm_config.model_config)

-        for layer in self.model.layers:
-            if self.runner_type == "pooling" and hasattr(layer, "self_attn"):
-                assert isinstance(layer.self_attn.attn.impl, XFormersImpl), (
-                    "GritLM embedding is only supported by XFormers backend, "
-                    "which can be forced by VLLM_ATTENTION_BACKEND=XFORMERS")
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        **kwargs,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        # Change attention to non-causal for pooling tasks.
-        if self.runner_type == "pooling":
-            attn_metadata = get_forward_context().attn_metadata
-            assert attn_metadata.prefill_metadata.attn_bias is None
-            attn_metadata.prefill_metadata.attn_bias = [
-                BlockDiagonalMask.from_seqlens(attn_metadata.seq_lens)
-            ]
-
-        return super().forward(
-            input_ids=input_ids,
-            positions=positions,
-            **kwargs,
-        )
-
     def pooler(
         self,
         hidden_states: torch.Tensor,
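
For context, the forward-time `BlockDiagonalMask` patch removed above is replaced by the config flip in `__init__`: setting `is_causal = False` before the backbone is built makes `LlamaDecoderLayer` (next file) construct encoder-only attention. A small sketch of that selection logic, using a plain `LlamaConfig` as a stand-in for the real GritLM checkpoint config:

```python
# Sketch of the is_causal -> attention-type mapping this commit introduces.
# LlamaConfig() stands in for vllm_config.model_config.hf_config.
from transformers import LlamaConfig

from vllm.attention import AttentionType

hf_config = LlamaConfig()
hf_config.is_causal = False  # what GritLM.__init__ sets when runner_type == "pooling"

attn_type = (AttentionType.DECODER
             if getattr(hf_config, "is_causal", True)
             else AttentionType.ENCODER_ONLY)
assert attn_type == AttentionType.ENCODER_ONLY
```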


@@ -28,7 +28,7 @@ import torch
 from torch import nn
 from transformers import LlamaConfig

-from vllm.attention import Attention
+from vllm.attention import Attention, AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -96,19 +96,22 @@ class LlamaMLP(nn.Module):

 class LlamaAttention(nn.Module):

-    def __init__(self,
-                 config: LlamaConfig,
-                 hidden_size: int,
-                 num_heads: int,
-                 num_kv_heads: int,
-                 rope_theta: float = 10000,
-                 rope_scaling: Optional[Dict[str, Any]] = None,
-                 max_position_embeddings: int = 8192,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 bias: bool = False,
-                 bias_o_proj: bool = False,
-                 cache_config: Optional[CacheConfig] = None,
-                 prefix: str = "") -> None:
+    def __init__(
+        self,
+        config: LlamaConfig,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        quant_config: Optional[QuantizationConfig] = None,
+        bias: bool = False,
+        bias_o_proj: bool = False,
+        cache_config: Optional[CacheConfig] = None,
+        prefix: str = "",
+        attn_type: str = AttentionType.DECODER,
+    ) -> None:
         super().__init__()
         layer_idx = extract_layer_index(prefix)
         self.hidden_size = hidden_size
@@ -194,6 +197,7 @@ class LlamaAttention(nn.Module):
             cache_config=cache_config,
             quant_config=quant_config,
             per_layer_sliding_window=sliding_window,
+            attn_type=attn_type,
             prefix=f"{prefix}.attn",
         )
@@ -238,6 +242,15 @@ class LlamaDecoderLayer(nn.Module):
         if hasattr(config, 'qkv_bias'):
             attention_bias = config.qkv_bias

+        # By default, Llama uses causal attention as it is a decoder-only model.
+        # You can override the HF config with `is_causal=False` to enable
+        # bidirectional attention, which is used in some embedding models
+        # (e.g. parasail-ai/GritLM-7B-vllm)
+        if getattr(config, "is_causal", True):
+            attn_type = AttentionType.DECODER
+        else:
+            attn_type = AttentionType.ENCODER_ONLY
+
         self.self_attn = LlamaAttention(
             config=config,
             hidden_size=self.hidden_size,
@@ -252,6 +265,7 @@
             bias_o_proj=bias_o_proj,
             cache_config=cache_config,
             prefix=f"{prefix}.self_attn",
+            attn_type=attn_type,
         )
         self.mlp = LlamaMLP(
             hidden_size=self.hidden_size,


@@ -100,19 +100,19 @@ class Qwen2MLP(nn.Module):
 class Qwen2Attention(nn.Module):

     def __init__(
         self,
         hidden_size: int,
         num_heads: int,
         num_kv_heads: int,
         max_position: int = 4096 * 32,
         rope_theta: float = 10000,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         rope_scaling: Optional[Tuple] = None,
         prefix: str = "",
         attn_type: str = AttentionType.DECODER,
-        dual_chunk_attention_config: Optional[dict[str,
-                                                   Any]] = None) -> None:
+        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
+    ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()