[Model] Support is_causal HF config field for Qwen2 model (#10621)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 05d1f8c9c6
commit ed46f14321
@@ -342,7 +342,7 @@ Text Embedding
      - ✅︎
    * - :code:`Qwen2Model`, :code:`Qwen2ForCausalLM`
      - Qwen2-based
-     - :code:`ssmits/Qwen2-7B-Instruct-embed-base`, :code:`Alibaba-NLP/gte-Qwen2-1.5B-instruct`, etc.
+     - :code:`ssmits/Qwen2-7B-Instruct-embed-base`, :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc.
      - ✅︎
      - ✅︎
    * - :code:`RobertaModel`, :code:`RobertaForMaskedLM`
@@ -363,6 +363,13 @@ Text Embedding
 .. tip::
   You can override the model's pooling method by passing :code:`--override-pooler-config`.
 
+.. note::
+  Unlike base Qwen2, :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` uses bi-directional attention.
+  You can set `--hf-overrides '{"is_causal": false}'` to change the attention mask accordingly.
+
+  On the other hand, its 1.5B variant (:code:`Alibaba-NLP/gte-Qwen2-1.5B-instruct`) uses causal attention
+  despite being described otherwise on its model card.
+
 Reward Modeling
 ---------------
 
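As an illustration of the note above, here is a minimal offline sketch of applying the override. It assumes the `LLM` entrypoint accepts `task` and `hf_overrides` and forwards them to the model config, mirroring what `vllm_runner` does in the test change further down; the exact output attributes shown are an assumption as well.

```python
# Minimal sketch (assumption: LLM(...) accepts task/hf_overrides like the
# test's vllm_runner does). The server-side equivalent from the note would be:
#   vllm serve Alibaba-NLP/gte-Qwen2-7B-instruct --task embedding \
#       --hf-overrides '{"is_causal": false}'
from vllm import LLM

llm = LLM(
    model="Alibaba-NLP/gte-Qwen2-7B-instruct",
    task="embedding",
    # Disable the causal mask so the model runs with bidirectional attention,
    # as its model card expects.
    hf_overrides={"is_causal": False},
)

outputs = llm.encode(["What is the capital of France?"])
print(len(outputs[0].outputs.embedding))  # embedding dimensionality
```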
@@ -606,10 +613,10 @@ Text Generation
 | :sup:`+` Multiple items can be inputted per text prompt for this modality.
 
 .. note::
   vLLM currently only supports adding LoRA to the language backbone of multimodal models.
 
 .. note::
-  For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now.
+  The official :code:`openbmb/MiniCPM-V-2` doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now.
   For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630
 
 Multimodal Embedding
@@ -21,6 +21,7 @@ from ..utils import check_embeddings_close
                      marks=[pytest.mark.core_model]),
         pytest.param("ssmits/Qwen2-7B-Instruct-embed-base"),
         pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct"),
+        pytest.param("Alibaba-NLP/gte-Qwen2-7B-instruct"),
     ],
 )
 @pytest.mark.parametrize("dtype", ["half"])
@@ -31,6 +32,10 @@ def test_models(
     model,
     dtype: str,
 ) -> None:
+    vllm_extra_kwargs = {}
+    if model == "Alibaba-NLP/gte-Qwen2-7B-instruct":
+        vllm_extra_kwargs["hf_overrides"] = {"is_causal": False}
+
     # The example_prompts has ending "\n", for example:
     # "Write a short story about a robot that dreams for the first time.\n"
     # sentence_transformers will strip the input texts, see:
@@ -43,8 +48,11 @@ def test_models(
                      is_sentence_transformer=True) as hf_model:
         hf_outputs = hf_model.encode(example_prompts)
 
-    with vllm_runner(model, task="embedding", dtype=dtype,
-                     max_model_len=None) as vllm_model:
+    with vllm_runner(model,
+                     task="embedding",
+                     dtype=dtype,
+                     max_model_len=None,
+                     **vllm_extra_kwargs) as vllm_model:
         vllm_outputs = vllm_model.encode(example_prompts)
     # This test is for verifying whether the model's extra_repr
     # can be printed correctly.
@@ -24,7 +24,7 @@ def check_embeddings_close(
             dim=0)
 
         fail_msg = (f"Test{prompt_idx}:"
-                    f"\n{name_0}:\t{embeddings_0!r}"
-                    f"\n{name_1}:\t{embeddings_1!r}")
+                    f"\n{name_0}:\t{embeddings_0[:16]!r}"
+                    f"\n{name_1}:\t{embeddings_1[:16]!r}")
 
         assert sim >= 1 - tol, fail_msg
@@ -27,7 +27,7 @@ from vllm.transformers_utils.config import (
     get_hf_text_config, get_pooling_config,
     get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope)
 from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
-                        identity, print_warning_once, resolve_obj_by_qualname)
+                        print_warning_once, resolve_obj_by_qualname)
 
 if TYPE_CHECKING:
     from ray.util.placement_group import PlacementGroup
@@ -183,7 +183,7 @@ class ModelConfig:
             hf_overrides_fn = hf_overrides
         else:
             hf_overrides_kw = hf_overrides
-            hf_overrides_fn = identity
+            hf_overrides_fn = None
 
         if rope_scaling is not None:
             hf_override: Dict[str, Any] = {"rope_scaling": rope_scaling}
@@ -212,8 +212,15 @@ class ModelConfig:
         self.skip_tokenizer_init = skip_tokenizer_init
 
         hf_config = get_config(self.model, trust_remote_code, revision,
-                               code_revision, config_format, **hf_overrides_kw)
-        hf_config = hf_overrides_fn(hf_config)
+                               code_revision, config_format)
+
+        if hf_overrides_kw:
+            logger.info("Overriding HF config with %s", hf_overrides_kw)
+            hf_config.update(hf_overrides_kw)
+        if hf_overrides_fn:
+            logger.info("Overriding HF config with %s", hf_overrides_fn)
+            hf_config = hf_overrides_fn(hf_config)
+
         self.hf_config = hf_config
 
         self.hf_text_config = get_hf_text_config(self.hf_config)
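To make the new branching explicit, here is a standalone sketch of the two forms `hf_overrides` can take after this change: a plain dict that is merged into the loaded config, or a callable that may return a replacement config. The helper name `apply_hf_overrides` is hypothetical and only for illustration; it is not part of vLLM's API.

```python
# Illustrative sketch of the override dispatch introduced above.
from typing import Any, Callable, Dict, Optional, Union

from transformers import PretrainedConfig

HfOverrides = Union[Dict[str, Any],
                    Callable[[PretrainedConfig], PretrainedConfig]]


def apply_hf_overrides(hf_config: PretrainedConfig,
                       hf_overrides: Optional[HfOverrides]) -> PretrainedConfig:
    hf_overrides_kw: Dict[str, Any] = {}
    hf_overrides_fn: Optional[Callable[..., PretrainedConfig]] = None
    if callable(hf_overrides):
        hf_overrides_fn = hf_overrides
    elif hf_overrides is not None:
        hf_overrides_kw = hf_overrides

    # A keyword dict is merged into the existing config in place...
    if hf_overrides_kw:
        hf_config.update(hf_overrides_kw)
    # ...while a callable may return an entirely new config object.
    if hf_overrides_fn is not None:
        hf_config = hf_overrides_fn(hf_config)
    return hf_config
```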
@@ -27,7 +27,7 @@ import torch
 from torch import nn
 from transformers import Qwen2Config
 
-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention, AttentionMetadata, AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -164,11 +164,17 @@ class Qwen2Attention(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
+        attn_type: str = AttentionType.DECODER,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q,
+                                k,
+                                v,
+                                kv_cache,
+                                attn_metadata,
+                                attn_type=attn_type)
         output, _ = self.o_proj(attn_output)
         return output
 
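The `attn_type` argument ultimately controls whether the backend applies a causal mask. Below is a plain-PyTorch toy (not vLLM's attention backend) of what `DECODER` versus `ENCODER_ONLY` means for the mask; the function name and shapes are purely illustrative.

```python
# Conceptual sketch: the only difference between the two attention types used
# in this PR is whether a causal mask is applied before the softmax.
import torch


def toy_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                  causal: bool) -> torch.Tensor:
    scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
    if causal:
        # DECODER: each token attends only to itself and earlier tokens.
        seq_len = q.shape[-2]
        mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool),
                          diagonal=1)
        scores = scores.masked_fill(mask, float("-inf"))
    # ENCODER_ONLY: no mask, every token attends to the full sequence.
    return torch.softmax(scores, dim=-1) @ v


q = k = v = torch.randn(1, 4, 8)
causal_out = toy_attention(q, k, v, causal=True)
bidirectional_out = toy_attention(q, k, v, causal=False)
```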
@@ -210,6 +216,15 @@ class Qwen2DecoderLayer(nn.Module):
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                 eps=config.rms_norm_eps)
 
+        # By default, Qwen2 uses causal attention as it is a decoder-only model.
+        # You can override the HF config with `is_causal=False` to enable
+        # bidirectional attention, which is used in some embedding models
+        # (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct)
+        if getattr(config, "is_causal", True):
+            self._attn_type = AttentionType.DECODER
+        else:
+            self._attn_type = AttentionType.ENCODER_ONLY
+
     def forward(
         self,
         positions: torch.Tensor,
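A small self-contained check of the selection logic above (illustrative only; `select_attn_type` is not part of the PR): because `getattr` falls back to `True`, configs that simply lack `is_causal` keep the causal, decoder-only behaviour, so existing Qwen2 checkpoints are unaffected.

```python
# Sketch of the is_causal -> AttentionType mapping added in this commit.
from transformers import Qwen2Config

from vllm.attention import AttentionType


def select_attn_type(config: Qwen2Config) -> str:
    # Missing field defaults to True, i.e. causal (decoder) attention.
    if getattr(config, "is_causal", True):
        return AttentionType.DECODER
    return AttentionType.ENCODER_ONLY


assert select_attn_type(Qwen2Config()) == AttentionType.DECODER
assert select_attn_type(Qwen2Config(is_causal=False)) == AttentionType.ENCODER_ONLY
```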
@@ -230,6 +245,7 @@ class Qwen2DecoderLayer(nn.Module):
             hidden_states=hidden_states,
             kv_cache=kv_cache,
             attn_metadata=attn_metadata,
+            attn_type=self._attn_type,
         )
 
         # Fully Connected