[Model] GritLM supports other attention backends (#18109)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
commit d62a076e84
parent 259127f8b8
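This change removes the XFormers-only restriction on GritLM's embedding path. A minimal usage sketch (not part of the commit): the checkpoint name is taken from the comment added in llama.py below, while `max_model_len` and the prompt template are illustrative assumptions.

    from vllm import LLM

    # Sketch: after this change, GritLM embedding no longer requires forcing
    # VLLM_ATTENTION_BACKEND=XFORMERS; for pooling runs the model is built with
    # bidirectional (ENCODER_ONLY) attention instead.
    llm = LLM(
        model="parasail-ai/GritLM-7B-vllm",  # checkpoint named in the llama.py comment
        task="embed",
        max_model_len=4096,  # illustrative; the tests use MAX_MODEL_LEN
    )

    # GritLM-style embedding prompt: an instruction followed by the text to
    # embed. The exact template is assumed here for illustration.
    instruction = "Given a question, retrieve a passage that answers it"
    prompt = f"<|user|>\n{instruction}\n<|embed|>\nWhat is the capital of France?\n"

    (output, ) = llm.embed(prompt)
    print(len(output.outputs.embedding))  # embedding dimensionality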
@@ -11,7 +11,6 @@ from scipy.spatial.distance import cosine
 
 from vllm import LLM, SamplingParams
 from vllm.config import ModelConfig
-from vllm.utils import STR_BACKEND_ENV_VAR
 
 from ....utils import RemoteOpenAIServer
 
@@ -117,44 +116,37 @@ def validate_embed_output(q_rep: list[list[float]], d_rep: list[list[float]]):
     assert math.isclose(cosine_sim_q1_d1, 0.534, abs_tol=0.001)
 
 
-def test_gritlm_offline_embedding(monkeypatch: pytest.MonkeyPatch,
-                                  vllm_runner):
-    # GritLM embedding implementation is only supported by XFormers backend.
-    with monkeypatch.context() as m:
-        m.setenv(STR_BACKEND_ENV_VAR, "XFORMERS")
-
-        queries, q_instruction, documents, d_instruction = get_test_data()
-
-        with vllm_runner(
-                MODEL_NAME,
-                task="embed",
-                max_model_len=MAX_MODEL_LEN,
-        ) as vllm_model:
-            llm = vllm_model.model
-
-            d_rep = run_llm_encode(
-                llm,
-                documents,
-                d_instruction,
-            )
-            q_rep = run_llm_encode(
-                llm,
-                queries,
-                q_instruction,
-            )
-
-            validate_embed_output(q_rep, d_rep)
+def test_gritlm_offline_embedding(vllm_runner):
+    queries, q_instruction, documents, d_instruction = get_test_data()
+
+    with vllm_runner(
+            MODEL_NAME,
+            task="embed",
+            max_model_len=MAX_MODEL_LEN,
+    ) as vllm_model:
+        llm = vllm_model.model
+
+        d_rep = run_llm_encode(
+            llm,
+            documents,
+            d_instruction,
+        )
+        q_rep = run_llm_encode(
+            llm,
+            queries,
+            q_instruction,
+        )
+
+        validate_embed_output(q_rep, d_rep)
 
 
 @pytest.mark.asyncio
 async def test_gritlm_api_server_embedding():
     queries, q_instruction, documents, d_instruction = get_test_data()
 
-    # GritLM embedding implementation is only supported by XFormers backend.
     args = ["--task", "embed", "--max_model_len", str(MAX_MODEL_LEN)]
-    env_dict = {STR_BACKEND_ENV_VAR: "XFORMERS"}
 
-    with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as server:
+    with RemoteOpenAIServer(MODEL_NAME, args) as server:
         client_embedding = server.get_async_client()
 
         d_rep = await run_client_embeddings(
@@ -172,35 +164,28 @@ async def test_gritlm_api_server_embedding():
 
 
 def test_gritlm_offline_generate(monkeypatch: pytest.MonkeyPatch, vllm_runner):
-    # GritLM embedding implementation is only supported by XFormers backend.
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "0")
-        m.setenv(STR_BACKEND_ENV_VAR, "XFORMERS")
-
-        input = "<|user|>\nWhat is the capital of France?\n<|assistant|>\n"
-
-        with vllm_runner(
-                MODEL_NAME,
-                task="generate",
-                max_model_len=MAX_MODEL_LEN,
-        ) as vllm_model:
-            llm = vllm_model.model
-
-            sampling_params = SamplingParams(temperature=0.0, max_tokens=256)
-            outputs = llm.generate(input, sampling_params=sampling_params)
-
-            assert outputs[0].outputs[0].text == "The capital of France is Paris."
+    input = "<|user|>\nWhat is the capital of France?\n<|assistant|>\n"
+
+    with vllm_runner(
+            MODEL_NAME,
+            task="generate",
+            max_model_len=MAX_MODEL_LEN,
+    ) as vllm_model:
+        llm = vllm_model.model
+
+        sampling_params = SamplingParams(temperature=0.0, max_tokens=256)
+        outputs = llm.generate(input, sampling_params=sampling_params)
+
+        assert outputs[0].outputs[0].text == "The capital of France is Paris."
 
 
 @pytest.mark.asyncio
 async def test_gritlm_api_server_generate():
     input = "<|user|>\nWhat is the capital of France?\n<|assistant|>\n"
 
-    # GritLM embedding implementation is only supported by XFormers backend.
     args = ["--task", "generate", "--max_model_len", str(MAX_MODEL_LEN)]
-    env_dict = {"VLLM_USE_V1": "0", STR_BACKEND_ENV_VAR: "XFORMERS"}
 
-    with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as server:
+    with RemoteOpenAIServer(MODEL_NAME, args) as server:
         client_generate = server.get_async_client()
 
         outputs = await client_generate.completions.create(
@@ -1,22 +1,18 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from array import array
-from typing import Optional, Union
+from typing import Optional
 
 import torch
 import torch.nn as nn
-from xformers.ops.fmha.attn_bias import BlockDiagonalMask
 
-from vllm.attention.backends.xformers import XFormersImpl
 from vllm.config import ModelConfig, VllmConfig
-from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.pooler import PoolerHead
 from vllm.model_executor.models.llama import LlamaForCausalLM
 from vllm.model_executor.pooling_metadata import (PoolingMetadata,
                                                   PoolingTensors)
-from vllm.sequence import (IntermediateTensors, PoolerOutput,
-                           PoolingSequenceGroupOutput)
+from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
 from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
 
 from .interfaces import SupportsV0Only
@@ -204,39 +200,21 @@ class GritLM(LlamaForCausalLM, SupportsV0Only):
         prefix: str = "",
         **kwargs,
     ) -> None:
+        # Use full attention for pooling
+        if vllm_config.model_config.runner_type == "pooling":
+            hf_config = vllm_config.model_config.hf_config
+            hf_config.is_causal = False
+
+            vllm_config.cache_config.sliding_window = None
+
+            for attr in ("sliding_window", "interleaved_sliding_window"):
+                if hasattr(hf_config, attr):
+                    delattr(hf_config, attr)
+
         super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
 
-        self.runner_type = vllm_config.model_config.runner_type
-
         self._pooler = GritLMPooler(vllm_config.model_config)
 
-        for layer in self.model.layers:
-            if self.runner_type == "pooling" and hasattr(layer, "self_attn"):
-                assert isinstance(layer.self_attn.attn.impl, XFormersImpl), (
-                    "GritLM embedding is only supported by XFormers backend, "
-                    "which can be forced by VLLM_ATTENTION_BACKEND=XFORMERS")
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        **kwargs,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-
-        # Change attention to non-causal for pooling tasks.
-        if self.runner_type == "pooling":
-            attn_metadata = get_forward_context().attn_metadata
-            assert attn_metadata.prefill_metadata.attn_bias is None
-            attn_metadata.prefill_metadata.attn_bias = [
-                BlockDiagonalMask.from_seqlens(attn_metadata.seq_lens)
-            ]
-
-        return super().forward(
-            input_ids=input_ids,
-            positions=positions,
-            **kwargs,
-        )
-
     def pooler(
         self,
         hidden_states: torch.Tensor,
@@ -28,7 +28,7 @@ import torch
 from torch import nn
 from transformers import LlamaConfig
 
-from vllm.attention import Attention
+from vllm.attention import Attention, AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -96,19 +96,22 @@ class LlamaMLP(nn.Module):
 
 class LlamaAttention(nn.Module):
 
-    def __init__(self,
-                 config: LlamaConfig,
-                 hidden_size: int,
-                 num_heads: int,
-                 num_kv_heads: int,
-                 rope_theta: float = 10000,
-                 rope_scaling: Optional[Dict[str, Any]] = None,
-                 max_position_embeddings: int = 8192,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 bias: bool = False,
-                 bias_o_proj: bool = False,
-                 cache_config: Optional[CacheConfig] = None,
-                 prefix: str = "") -> None:
+    def __init__(
+        self,
+        config: LlamaConfig,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        quant_config: Optional[QuantizationConfig] = None,
+        bias: bool = False,
+        bias_o_proj: bool = False,
+        cache_config: Optional[CacheConfig] = None,
+        prefix: str = "",
+        attn_type: str = AttentionType.DECODER,
+    ) -> None:
         super().__init__()
         layer_idx = extract_layer_index(prefix)
         self.hidden_size = hidden_size
@@ -194,6 +197,7 @@ class LlamaAttention(nn.Module):
             cache_config=cache_config,
             quant_config=quant_config,
             per_layer_sliding_window=sliding_window,
+            attn_type=attn_type,
             prefix=f"{prefix}.attn",
         )
 
@@ -238,6 +242,15 @@ class LlamaDecoderLayer(nn.Module):
         if hasattr(config, 'qkv_bias'):
             attention_bias = config.qkv_bias
 
+        # By default, Llama uses causal attention as it is a decoder-only model.
+        # You can override the HF config with `is_causal=False` to enable
+        # bidirectional attention, which is used in some embedding models
+        # (e.g. parasail-ai/GritLM-7B-vllm)
+        if getattr(config, "is_causal", True):
+            attn_type = AttentionType.DECODER
+        else:
+            attn_type = AttentionType.ENCODER_ONLY
+
         self.self_attn = LlamaAttention(
             config=config,
             hidden_size=self.hidden_size,
@@ -252,6 +265,7 @@ class LlamaDecoderLayer(nn.Module):
             bias_o_proj=bias_o_proj,
             cache_config=cache_config,
             prefix=f"{prefix}.self_attn",
+            attn_type=attn_type,
         )
         self.mlp = LlamaMLP(
             hidden_size=self.hidden_size,
@@ -100,19 +100,19 @@ class Qwen2MLP(nn.Module):
 class Qwen2Attention(nn.Module):
 
     def __init__(
         self,
         hidden_size: int,
         num_heads: int,
         num_kv_heads: int,
         max_position: int = 4096 * 32,
         rope_theta: float = 10000,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         rope_scaling: Optional[Tuple] = None,
         prefix: str = "",
         attn_type: str = AttentionType.DECODER,
-        dual_chunk_attention_config: Optional[dict[str,
-                                                   Any]] = None) -> None:
+        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
+    ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
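The comment added to LlamaDecoderLayer above notes that `is_causal=False` in the HF config switches the attention layers to `AttentionType.ENCODER_ONLY`. GritLM's `__init__` now sets this automatically for pooling runs; a hedged sketch of exercising the same path by hand via vLLM's `hf_overrides` argument (checkpoint and length are illustrative, not taken from the commit):

    from vllm import LLM

    # Sketch: force bidirectional attention on a Llama-based embedding checkpoint
    # by overriding the HF config, mirroring what GritLM.__init__ now does when
    # runner_type == "pooling". GritLM-style checkpoints already ship with the
    # right config, so this override is only needed for other checkpoints.
    llm = LLM(
        model="parasail-ai/GritLM-7B-vllm",  # illustrative checkpoint
        task="embed",
        hf_overrides={"is_causal": False},   # read by LlamaDecoderLayer via getattr
        max_model_len=4096,
    )

    outputs = llm.embed("<|embed|>\nA document to embed without any instruction.\n")
    print(len(outputs[0].outputs.embedding))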