[Model] Support is_causal HF config field for Qwen2 model (#10621)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2024-11-25 17:51:20 +08:00 committed by GitHub
parent 05d1f8c9c6
commit ed46f14321
5 changed files with 51 additions and 13 deletions

View File

@@ -342,7 +342,7 @@ Text Embedding
       - ✅︎
   * - :code:`Qwen2Model`, :code:`Qwen2ForCausalLM`
     - Qwen2-based
-    - :code:`ssmits/Qwen2-7B-Instruct-embed-base`, :code:`Alibaba-NLP/gte-Qwen2-1.5B-instruct`, etc.
+    - :code:`ssmits/Qwen2-7B-Instruct-embed-base`, :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc.
     - ✅︎
     - ✅︎
   * - :code:`RobertaModel`, :code:`RobertaForMaskedLM`
@@ -363,6 +363,13 @@ Text Embedding
 .. tip::
     You can override the model's pooling method by passing :code:`--override-pooler-config`.
 
+.. note::
+    Unlike base Qwen2, :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` uses bi-directional attention.
+    You can set `--hf-overrides '{"is_causal": false}'` to change the attention mask accordingly.
+
+    On the other hand, its 1.5B variant (:code:`Alibaba-NLP/gte-Qwen2-1.5B-instruct`) uses causal attention
+    despite being described otherwise on its model card.
+
 Reward Modeling
 ---------------
@@ -606,10 +613,10 @@ Text Generation
 | :sup:`+` Multiple items can be inputted per text prompt for this modality.
 
 .. note::
     vLLM currently only supports adding LoRA to the language backbone of multimodal models.
 
 .. note::
-    For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now.
+    The official :code:`openbmb/MiniCPM-V-2` doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now.
     For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630
 
 Multimodal Embedding
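
The note above documents the serving flag; for the offline API the same override can be passed programmatically. A minimal sketch, assuming the `LLM` entry point forwards `hf_overrides` to the engine the same way the test below passes it to `vllm_runner`:

# Minimal sketch: offline embedding with the is_causal override.
# Assumes `LLM(...)` accepts the same `hf_overrides` mapping that the
# `--hf-overrides` CLI flag and the test's `vllm_runner` kwargs expose.
from vllm import LLM

llm = LLM(
    model="Alibaba-NLP/gte-Qwen2-7B-instruct",
    task="embedding",
    # Switch the Qwen2 layers to bi-directional attention.
    hf_overrides={"is_causal": False},
)

outputs = llm.encode(["What is the capital of France?"])
# `outputs[0].outputs.embedding` holds the pooled vector for the prompt
# (attribute layout assumed from vLLM's EmbeddingRequestOutput).
print(len(outputs[0].outputs.embedding))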

View File

@@ -21,6 +21,7 @@ from ..utils import check_embeddings_close
                      marks=[pytest.mark.core_model]),
         pytest.param("ssmits/Qwen2-7B-Instruct-embed-base"),
         pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct"),
+        pytest.param("Alibaba-NLP/gte-Qwen2-7B-instruct"),
     ],
 )
 @pytest.mark.parametrize("dtype", ["half"])
@@ -31,6 +32,10 @@ def test_models(
     model,
     dtype: str,
 ) -> None:
+    vllm_extra_kwargs = {}
+    if model == "Alibaba-NLP/gte-Qwen2-7B-instruct":
+        vllm_extra_kwargs["hf_overrides"] = {"is_causal": False}
+
     # The example_prompts has ending "\n", for example:
     # "Write a short story about a robot that dreams for the first time.\n"
     # sentence_transformers will strip the input texts, see:
@@ -43,8 +48,11 @@ def test_models(
                    is_sentence_transformer=True) as hf_model:
         hf_outputs = hf_model.encode(example_prompts)
 
-    with vllm_runner(model, task="embedding", dtype=dtype,
-                     max_model_len=None) as vllm_model:
+    with vllm_runner(model,
+                     task="embedding",
+                     dtype=dtype,
+                     max_model_len=None,
+                     **vllm_extra_kwargs) as vllm_model:
         vllm_outputs = vllm_model.encode(example_prompts)
         # This test is for verifying whether the model's extra_repr
         # can be printed correctly.

View File

@@ -24,7 +24,7 @@ def check_embeddings_close(
                                  dim=0)
 
         fail_msg = (f"Test{prompt_idx}:"
-                    f"\n{name_0}:\t{embeddings_0!r}"
-                    f"\n{name_1}:\t{embeddings_1!r}")
+                    f"\n{name_0}:\t{embeddings_0[:16]!r}"
+                    f"\n{name_1}:\t{embeddings_1[:16]!r}")
 
         assert sim >= 1 - tol, fail_msg
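
The hunk only shows the tail of the helper, so as context: the comparison is a cosine-similarity check, and the change merely truncates the embeddings printed in the failure message. A standalone sketch of that kind of check (hypothetical name, torch-based; not the actual helper):

# Standalone sketch of a cosine-similarity closeness check; names and
# exact behaviour are illustrative, not the real check_embeddings_close.
from typing import Sequence

import torch
import torch.nn.functional as F


def embeddings_close_sketch(embeddings_0: Sequence[float],
                            embeddings_1: Sequence[float],
                            tol: float = 1e-2) -> None:
    sim = F.cosine_similarity(torch.tensor(embeddings_0),
                              torch.tensor(embeddings_1),
                              dim=0).item()
    # Truncate the vectors in the message, mirroring the change above.
    fail_msg = (f"cosine similarity {sim:.4f} below {1 - tol}:"
                f"\n{embeddings_0[:16]!r}"
                f"\n{embeddings_1[:16]!r}")
    assert sim >= 1 - tol, fail_msg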

View File

@@ -27,7 +27,7 @@ from vllm.transformers_utils.config import (
     get_hf_text_config, get_pooling_config,
     get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope)
 from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
-                        identity, print_warning_once, resolve_obj_by_qualname)
+                        print_warning_once, resolve_obj_by_qualname)
 
 if TYPE_CHECKING:
     from ray.util.placement_group import PlacementGroup
@@ -183,7 +183,7 @@ class ModelConfig:
             hf_overrides_fn = hf_overrides
         else:
             hf_overrides_kw = hf_overrides
-            hf_overrides_fn = identity
+            hf_overrides_fn = None
 
         if rope_scaling is not None:
             hf_override: Dict[str, Any] = {"rope_scaling": rope_scaling}
@@ -212,8 +212,15 @@ class ModelConfig:
         self.skip_tokenizer_init = skip_tokenizer_init
 
         hf_config = get_config(self.model, trust_remote_code, revision,
-                               code_revision, config_format, **hf_overrides_kw)
-        hf_config = hf_overrides_fn(hf_config)
+                               code_revision, config_format)
+
+        if hf_overrides_kw:
+            logger.info("Overriding HF config with %s", hf_overrides_kw)
+            hf_config.update(hf_overrides_kw)
+        if hf_overrides_fn:
+            logger.info("Overriding HF config with %s", hf_overrides_fn)
+            hf_config = hf_overrides_fn(hf_config)
+
         self.hf_config = hf_config
 
         self.hf_text_config = get_hf_text_config(self.hf_config)
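
The net effect of the config change: a dict-style `hf_overrides` is now merged into the HF config after `get_config()` returns (rather than being forwarded as kwargs), and a callable override is only invoked when one was actually supplied (replacing the old `identity` default). A standalone sketch of that dispatch, using `SimpleNamespace` as a stand-in for a `transformers` config object:

# Standalone sketch of the two override paths in the diff above; the real
# code lives in ModelConfig and operates on a transformers PretrainedConfig.
from types import SimpleNamespace
from typing import Any, Callable, Dict, Optional, Union


def apply_hf_overrides(
    hf_config: Any,
    hf_overrides: Optional[Union[Dict[str, Any], Callable[[Any], Any]]],
) -> Any:
    if hf_overrides is None:
        return hf_config
    if callable(hf_overrides):
        # Callable path: the function may return a new config object.
        return hf_overrides(hf_config)
    # Dict path: mirrors `hf_config.update(hf_overrides_kw)`, which sets
    # each key as an attribute on the config.
    for key, value in hf_overrides.items():
        setattr(hf_config, key, value)
    return hf_config


cfg = apply_hf_overrides(SimpleNamespace(is_causal=True), {"is_causal": False})
assert cfg.is_causal is False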

View File

@@ -27,7 +27,7 @@ import torch
 from torch import nn
 from transformers import Qwen2Config
 
-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention, AttentionMetadata, AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -164,11 +164,17 @@ class Qwen2Attention(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
+        attn_type: str = AttentionType.DECODER,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q,
+                                k,
+                                v,
+                                kv_cache,
+                                attn_metadata,
+                                attn_type=attn_type)
         output, _ = self.o_proj(attn_output)
         return output
@@ -210,6 +216,15 @@ class Qwen2DecoderLayer(nn.Module):
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                 eps=config.rms_norm_eps)
 
+        # By default, Qwen2 uses causal attention as it is a decoder-only model.
+        # You can override the HF config with `is_causal=False` to enable
+        # bidirectional attention, which is used in some embedding models
+        # (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct)
+        if getattr(config, "is_causal", True):
+            self._attn_type = AttentionType.DECODER
+        else:
+            self._attn_type = AttentionType.ENCODER_ONLY
+
     def forward(
         self,
         positions: torch.Tensor,
@@ -230,6 +245,7 @@ class Qwen2DecoderLayer(nn.Module):
             hidden_states=hidden_states,
             kv_cache=kv_cache,
             attn_metadata=attn_metadata,
+            attn_type=self._attn_type,
         )
 
         # Fully Connected
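
Putting the qwen2.py changes together: each decoder layer reads `is_causal` once at construction time, maps it to an attention type, and threads that value through `self_attn` into the attention backend on every call. A condensed, standalone sketch of that selection logic (a stand-in enum replaces `vllm.attention.AttentionType`; kv caches, metadata and the real kernels are omitted):

# Condensed sketch of the is_causal -> attention-type selection added above.
# The enum stands in for vllm.attention.AttentionType; everything else in
# the real layer (qkv projection, rotary embedding, kv cache) is omitted.
from enum import Enum


class AttnTypeSketch(str, Enum):
    DECODER = "decoder"            # causal mask (Qwen2 default)
    ENCODER_ONLY = "encoder_only"  # bi-directional, for embedding variants


def select_attn_type(hf_config: object) -> AttnTypeSketch:
    # Mirrors Qwen2DecoderLayer: a missing `is_causal` field means causal.
    if getattr(hf_config, "is_causal", True):
        return AttnTypeSketch.DECODER
    return AttnTypeSketch.ENCODER_ONLY


class _OverriddenConfig:
    is_causal = False  # e.g. set via --hf-overrides '{"is_causal": false}'


assert select_attn_type(_OverriddenConfig()) is AttnTypeSketch.ENCODER_ONLY
assert select_attn_type(object()) is AttnTypeSketch.DECODER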