diff --git a/tests/models/language/pooling/test_gte.py b/tests/models/language/pooling/test_gte.py
index 5a5fdfbb214c..f805a64103c0 100644
--- a/tests/models/language/pooling/test_gte.py
+++ b/tests/models/language/pooling/test_gte.py
@@ -4,10 +4,11 @@ from typing import Any
 
 import pytest
 
-from ...utils import (CLSPoolingEmbedModelInfo, EmbedModelInfo,
-                      LASTPoolingEmbedModelInfo, check_transformers_version)
+from ...utils import (CLSPoolingEmbedModelInfo, CLSPoolingRerankModelInfo,
+                      EmbedModelInfo, LASTPoolingEmbedModelInfo,
+                      RerankModelInfo, check_transformers_version)
 from .embed_utils import correctness_test_embed_models
-from .mteb_utils import mteb_test_embed_models
+from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models
 
 MODELS = [
     ########## BertModel
@@ -58,6 +59,14 @@ MODELS = [
         enable_test=False),
 ]
 
+RERANK_MODELS = [
+    # classifier_pooling: mean
+    CLSPoolingRerankModelInfo(
+        "Alibaba-NLP/gte-reranker-modernbert-base",
+        architecture="ModernBertForSequenceClassification",
+        enable_test=True),
+]
+
 
 @pytest.mark.parametrize("model_info", MODELS)
 def test_embed_models_mteb(hf_runner, vllm_runner,
@@ -88,3 +97,9 @@ def test_embed_models_correctness(hf_runner, vllm_runner,
 
     correctness_test_embed_models(hf_runner, vllm_runner, model_info,
                                   example_prompts, vllm_extra_kwargs)
+
+
+@pytest.mark.parametrize("model_info", RERANK_MODELS)
+def test_rerank_models_mteb(hf_runner, vllm_runner,
+                            model_info: RerankModelInfo) -> None:
+    mteb_test_rerank_models(hf_runner, vllm_runner, model_info)
diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py
index 2c3bdd1c93ae..c6e84e2d4e04 100644
--- a/vllm/model_executor/models/modernbert.py
+++ b/vllm/model_executor/models/modernbert.py
@@ -26,8 +26,7 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.sequence import IntermediateTensors
 from vllm.tasks import PoolingTask
 
-from .interfaces import (SupportsCrossEncoding, SupportsV0Only,
-                         default_pooling_type)
+from .interfaces import SupportsCrossEncoding, default_pooling_type
 from .utils import WeightsMapper, maybe_prefix
 
 
@@ -93,16 +92,14 @@ class ModernBertAttention(nn.Module):
             bias=config.attention_bias,
         )
 
+        sliding_window = None
         if layer_id % config.global_attn_every_n_layers != 0:
-            self.local_attention = (config.local_attention // 2,
-                                    config.local_attention // 2)
+            sliding_window = config.local_attention // 2
+            rope_theta = config.local_rope_theta if config.local_rope_theta \
+                is not None else config.global_rope_theta
         else:
-            self.local_attention = (-1, -1)
+            rope_theta = config.global_rope_theta
 
-        rope_theta = config.global_rope_theta
-        if self.local_attention != (
-                -1, -1) and config.local_rope_theta is not None:
-            rope_theta = config.local_rope_theta
         self.rotary_emb = ModernBertRotaryEmbedding(config=config,
                                                     head_size=self.head_dim,
                                                     dim=self.head_dim,
@@ -111,7 +108,8 @@ class ModernBertAttention(nn.Module):
                               self.head_dim,
                               self.scaling,
                               prefix=f"{layer_id}.attn",
-                              attn_type=AttentionType.ENCODER_ONLY)
+                              attn_type=AttentionType.ENCODER_ONLY,
+                              per_layer_sliding_window=sliding_window)
         self.Wo = RowParallelLinear(config.hidden_size,
                                     config.hidden_size,
                                     bias=config.attention_bias)
@@ -278,6 +276,7 @@ class ModernBertPooler(Pooler):
         return self.pooling.get_pooling_updates(task)
 
     def _head(self, pooled_output: torch.Tensor):
+        pooled_output = pooled_output.to(self.dense.weight.dtype)
         return self.norm(self.act(self.dense(pooled_output)))
 
     def forward(
@@ -296,8 +295,7 @@ class ModernBertPooler(Pooler):
 @default_pooling_type("CLS")
-class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
-                                          SupportsCrossEncoding):
+class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
     is_pooling_model = True
 
@@ -308,6 +306,7 @@ class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
         self.model = ModernBertModel(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "modernbert"))
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+        self.pooling = ModernBertPooler(config)
 
         pooler_config = vllm_config.model_config.pooler_config
         assert pooler_config is not None
@@ -317,14 +316,14 @@ class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
                 Pooler.for_encode(pooler_config),
                 "classify":
                 ClassifierPooler(
-                    pooling=ModernBertPooler(config),
+                    pooling=self.pooling,
                     classifier=self.classifier,
                     act_fn=ClassifierPooler.act_fn_for_seq_cls(
                         vllm_config.model_config),
                 ),
                 "score":
                 ClassifierPooler(
-                    pooling=ModernBertPooler(config),
+                    pooling=self.pooling,
                     classifier=self.classifier,
                     act_fn=ClassifierPooler.act_fn_for_cross_encoder(
                         vllm_config.model_config),
@@ -353,7 +352,7 @@ class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
             if name.startswith("head"):
-                param = params_dict["_pooler.pooler." + name[len("head") + 1:]]
+                param = params_dict["pooling." + name[len("head") + 1:]]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
@@ -368,5 +367,5 @@ class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
         return self.model(
             input_ids=input_ids,
             inputs_embeds=inputs_embeds,
-            position_ids=positions,
+            positions=positions,
         )
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 95ba56b35937..a411477bc3e3 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -384,6 +384,8 @@ class FlashAttentionImpl(AttentionImpl):
         self.alibi_slopes = alibi_slopes
         if sliding_window is None:
             self.sliding_window = (-1, -1)
+        elif attn_type == AttentionType.ENCODER_ONLY:
+            self.sliding_window = (sliding_window - 1, sliding_window - 1)
         else:
             self.sliding_window = (sliding_window - 1, 0)
         self.kv_cache_dtype = kv_cache_dtype
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 045a06d9278d..ed4d6bcb09d4 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -826,7 +826,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Prepare encoder attention metadata separately
         # (encoder layers are not in KV cache groups)
         if self.is_encoder_only_model:
-            common_attn_metadata, encoder_attn_metadata = \
+
+            per_layer_metadata = \
                 self._build_encoder_only_attn_metadata(
                     scheduler_output)
 
@@ -835,6 +836,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 self.vllm_config, Attention)
             for layer_name, attn_module in attention_layers.items():
                 if attn_module.attn_type == AttentionType.ENCODER_ONLY:
+                    common_attn_metadata, encoder_attn_metadata =\
+                        per_layer_metadata[layer_name]
                     attn_metadata[layer_name] = encoder_attn_metadata
 
         # Prepare the attention metadata for each KV cache group and make layers
@@ -2683,30 +2686,41 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Check if model is encoder-only
         block_size = self.vllm_config.cache_config.block_size
         use_mla = self.vllm_config.model_config.use_mla
-        attn_specs = list[AttentionSpec]()
-        for attn_module in attn_layers.values():
+        attn_specs: dict[AttentionSpec, list[str]] = defaultdict(list)
+        for layer_name, attn_module in attn_layers.items():
             if attn_module.attn_type == AttentionType.ENCODER_ONLY:
-                assert attn_module.sliding_window is None, "Sliding "
-                "window attention is not supported for encoder-only models"
+                if attn_module.sliding_window is None:
+                    attn_spec: AttentionSpec = FullAttentionSpec(
+                        block_size=block_size,
+                        num_kv_heads=attn_module.num_kv_heads,
+                        head_size=attn_module.head_size,
+                        dtype=self.kv_cache_dtype,
+                        use_mla=use_mla)
+                else:
+                    attn_spec = SlidingWindowSpec(
+                        block_size=block_size,
+                        num_kv_heads=attn_module.num_kv_heads,
+                        head_size=attn_module.head_size,
+                        dtype=self.kv_cache_dtype,
+                        sliding_window=attn_module.sliding_window,
+                        use_mla=use_mla)
+                attn_specs[attn_spec].append(layer_name)
 
-                attn_specs.append(
-                    FullAttentionSpec(block_size=block_size,
-                                      num_kv_heads=attn_module.num_kv_heads,
-                                      head_size=attn_module.head_size,
-                                      dtype=self.kv_cache_dtype,
-                                      use_mla=use_mla))
             else:
                 raise ValueError("Expected only encoder-only layers")
 
         if len(attn_specs) > 0:
-            assert len(attn_specs) == len(attn_layers), \
+            total_layers = 0
+            for attn_spec, layer_names in attn_specs.items():
+
+                attn_backends = get_attn_backends_for_layers(layer_names)
+                total_layers += len(layer_names)
+
+                self.attn_groups.append(
+                    create_attn_groups(attn_backends, attn_spec))
+            assert total_layers == len(attn_layers), \
                 "All or none of the layers are expected to be encoder-only"
-
-            attn_backends = get_attn_backends_for_layers(attn_layers.keys())
-
-            self.attn_groups.append(
-                create_attn_groups(attn_backends, attn_specs[0]))
 
             self.is_encoder_only_model = True
 
     def calculate_reorder_batch_threshold(self) -> None:
@@ -3071,7 +3085,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
     def _build_encoder_only_attn_metadata(
             self, scheduler_output: "SchedulerOutput") -> \
-            tuple[CommonAttentionMetadata, Any]:
+            dict[str, tuple[CommonAttentionMetadata, Any]]:
         """Prepare encoder attention metadata for encoder-only models.
 
         Args:
@@ -3088,10 +3102,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
         max_num_scheduled_tokens = max(tokens)
 
-        # Use the first attention metadata builder
-        # to create encoder attention metadata
-        builder = self.attn_groups[0][0].metadata_builder
-
         dummy_block_table = torch.zeros((num_reqs, 1),
                                         dtype=torch.int32,
                                         device=self.device)
@@ -3099,22 +3109,38 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                                          dtype=torch.int32,
                                          device=self.device)
 
-        common_metadata = CommonAttentionMetadata(
-            query_start_loc=self.query_start_loc[:num_reqs + 1],
-            query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
-            seq_lens=self.seq_lens[:num_reqs],
-            seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
-            num_computed_tokens_cpu=self.input_batch.
-            num_computed_tokens_cpu_tensor[:num_reqs],
-            num_reqs=num_reqs,
-            num_actual_tokens=total_num_scheduled_tokens,
-            max_query_len=max_num_scheduled_tokens,
-            block_table_tensor=dummy_block_table,
-            slot_mapping=dummy_slot_mapping,
-            causal=False,
-        )
+        group_metadata = dict[str, tuple[CommonAttentionMetadata, Any]]()
 
-        return common_metadata, builder.build(
-            common_prefix_len=0,  # No cascade for encoder
-            common_attn_metadata=common_metadata,
-        )
+        for attn_group_list in self.attn_groups:
+
+            assert len(attn_group_list) == 1
+            attn_group = attn_group_list[0]
+
+            # Use the first attention metadata builder
+            # to create encoder attention metadata
+            builder = attn_group.metadata_builder
+
+            common_metadata = CommonAttentionMetadata(
+                query_start_loc=self.query_start_loc[:num_reqs + 1],
+                query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
+                seq_lens=self.seq_lens[:num_reqs],
+                seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
+                num_computed_tokens_cpu=self.input_batch.
+                num_computed_tokens_cpu_tensor[:num_reqs],
+                num_reqs=num_reqs,
+                num_actual_tokens=total_num_scheduled_tokens,
+                max_query_len=max_num_scheduled_tokens,
+                block_table_tensor=dummy_block_table,
+                slot_mapping=dummy_slot_mapping,
+                causal=False,
+            )
+
+            metadata = builder.build(
+                common_prefix_len=0,  # No cascade for encoder
+                common_attn_metadata=common_metadata,
+            )
+
+            for layer_name in attn_group.layer_names:
+                group_metadata[layer_name] = (common_metadata, metadata)
+
+        return group_metadata
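
Note (reviewer sketch, not part of the patch): the flash_attn.py change above gives
encoder-only layers a symmetric sliding window, (left, right) = (w - 1, w - 1), while
decoder layers keep the causal form (w - 1, 0). A minimal standalone illustration of
the masks these tuples correspond to, assuming a hypothetical window size w and
sequence length seq_len (plain PyTorch, no vLLM internals):

    import torch

    def encoder_only_sliding_window_mask(seq_len: int, w: int) -> torch.Tensor:
        # Token i may attend to token j iff |i - j| <= w - 1: a symmetric span of
        # w - 1 positions on each side, with no causal restriction.
        idx = torch.arange(seq_len)
        return (idx[:, None] - idx[None, :]).abs() <= (w - 1)

    def causal_sliding_window_mask(seq_len: int, w: int) -> torch.Tensor:
        # Decoder form kept by the existing branch: look back w - 1 positions only.
        diff = torch.arange(seq_len)[:, None] - torch.arange(seq_len)[None, :]
        return (diff >= 0) & (diff <= w - 1)

    # ModernBERT's local layers pass w = config.local_attention // 2
    # (see the ModernBertAttention diff above).
    print(encoder_only_sliding_window_mask(6, 3).int())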
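
Note (reviewer sketch, not part of the patch): the gpu_model_runner.py change groups
encoder-only layers by their attention spec (full vs. sliding-window) instead of
assuming a single spec, then builds metadata once per group and fans it out to that
group's layers. A small standalone sketch of the grouping idea, using a hypothetical
frozen Spec class as a stand-in for vLLM's AttentionSpec (frozen dataclasses are
hashable, so identical specs collapse onto one dict key):

    from collections import defaultdict
    from dataclasses import dataclass
    from typing import Optional

    @dataclass(frozen=True)
    class Spec:  # hypothetical stand-in for AttentionSpec
        sliding_window: Optional[int]
        head_size: int

    layers = {
        "layers.0.attn": Spec(sliding_window=64, head_size=64),    # local attention
        "layers.1.attn": Spec(sliding_window=64, head_size=64),    # local attention
        "layers.2.attn": Spec(sliding_window=None, head_size=64),  # global attention
    }

    groups: dict[Spec, list[str]] = defaultdict(list)
    for name, spec in layers.items():
        groups[spec].append(name)

    # Two groups result: one shared sliding-window spec and one full-attention spec,
    # each of which would get its own attention group and metadata builder.
    for spec, names in groups.items():
        print(spec, names)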