diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 0202ba5a6097f..6b0ceaf21951f 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -401,7 +401,7 @@ Specified using `--task embed`.
 
 !!! note
     `ssmits/Qwen2-7B-Instruct-embed-base` has an improperly defined Sentence Transformers config.
-    You should manually set mean pooling by passing `--override-pooler-config '{"pooling_type": "MEAN"}'`.
+    You need to manually set mean pooling by passing `--override-pooler-config '{"pooling_type": "MEAN"}'`.
 
 !!! note
     For `Alibaba-NLP/gte-Qwen2-*`, you need to enable `--trust-remote-code` for the correct tokenizer to be loaded.
diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py
index 0d0d98c59dbc7..a664864ff898f 100644
--- a/vllm/model_executor/models/qwen2.py
+++ b/vllm/model_executor/models/qwen2.py
@@ -34,32 +34,27 @@
 from vllm.attention import Attention, AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
-from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.pooler import Pooler, PoolingType
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
-from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import IntermediateTensors, PoolerOutput
+from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
-                    extract_layer_index, is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
+                    is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
-logger = init_logger(__name__)
-
 
 class Qwen2MLP(nn.Module):
@@ -499,69 +494,3 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
                            if self.config.tie_word_embeddings else None),
         )
         return loader.load_weights(weights)
-
-
-class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
-    packed_modules_mapping = {
-        "qkv_proj": [
-            "q_proj",
-            "k_proj",
-            "v_proj",
-        ],
-        "gate_up_proj": [
-            "gate_proj",
-            "up_proj",
-        ],
-    }
-
-    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
-
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__()
-        config = vllm_config.model_config.hf_config
-        quant_config = vllm_config.quant_config
-        lora_config = vllm_config.lora_config
-        pooler_config = vllm_config.model_config.pooler_config
-
-        self.config = config
-        self.lora_config = lora_config
-
-        self.quant_config = quant_config
-        self.model = Qwen2Model(vllm_config=vllm_config,
-                                prefix=maybe_prefix(prefix, "model"))
-
-        # TODO: Replace this model class with as_embedding_model(
-        # Qwen2ForCausalLM) after changing the default pooling method
-        if pooler_config.pooling_type is None:
-            logger.warning(
-                "This embedding model will default to last-token pooling in "
-                "an upcoming version. To avoid breaking changes, you should "
-                "pass `--override-pooler-config '{\"pooling_type\": \"MEAN\"}'`"
-                " explicitly.")
-
-        self._pooler = Pooler.from_config_with_defaults(
-            pooler_config,
-            pooling_type=PoolingType.MEAN,
-            normalize=True,
-            softmax=False)
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> torch.Tensor:
-        return self.model(input_ids, positions, intermediate_tensors)
-
-    def pooler(
-        self,
-        hidden_states: torch.Tensor,
-        pooling_metadata: PoolingMetadata,
-    ) -> Optional[PoolerOutput]:
-        return self._pooler(hidden_states, pooling_metadata)
-
-    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
-        weights = self.hf_to_vllm_mapper.apply(weights)
-        weights = ((name, data) for name, data in weights
-                   if not name.startswith("lm_head."))
-        self.model.load_weights(weights)
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 97ea12de65373..5a6a12fccdd0d 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -142,7 +142,7 @@ _EMBEDDING_MODELS = {
     "ModernBertModel": ("modernbert", "ModernBertModel"),
     "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
     "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
-    "Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"),
+    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
     "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
     "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
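Below is a minimal, hypothetical sketch (not part of this patch) of how the documented `--override-pooler-config` flag maps onto the offline `LLM` API once `Qwen2EmbeddingModel` is gone; it assumes a vLLM build where `task="embed"`, `PoolerConfig`, `override_pooler_config`, and `LLM.embed()` are all available.

```python
# Hypothetical usage sketch, not taken from this diff: request mean pooling
# explicitly, since ssmits/Qwen2-7B-Instruct-embed-base ships an improperly
# defined Sentence Transformers config (see the docs note above).
from vllm import LLM
from vllm.config import PoolerConfig

llm = LLM(
    model="ssmits/Qwen2-7B-Instruct-embed-base",
    task="embed",
    # Programmatic equivalent of:
    #   --override-pooler-config '{"pooling_type": "MEAN"}'
    override_pooler_config=PoolerConfig(pooling_type="MEAN"),
)

(output,) = llm.embed(["vLLM supports pooling models such as text embedders."])
print(len(output.outputs.embedding))  # embedding dimensionality
```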