[Bugfix] Fix ModernBert load & Enable sliding window attention for bidirectional attention. (#22637)

Signed-off-by: wang.yuqi <noooop@126.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Co-authored-by: Max de Bayser <mbayser@br.ibm.com>
commit 6d729c43fb (parent 2f4657952b)
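
At a high level, the change teaches vLLM's encoder-only (bidirectional) attention path about sliding windows: ModernBERT alternates global layers with local ones that should only see a window of neighbors on both sides. The hunks below touch the pooling tests, modernbert.py, FlashAttentionImpl, and GPUModelRunner. As a minimal conceptual sketch (not vLLM code), the bidirectional window mask differs from the causal one only in keeping right context:

import torch

def bidirectional_window_mask(seq_len: int, window: int) -> torch.Tensor:
    # Encoder-only attention is non-causal: token i may attend token j
    # whenever |i - j| <= window - 1, i.e. context on BOTH sides.
    idx = torch.arange(seq_len)
    return (idx[:, None] - idx[None, :]).abs() <= window - 1

def causal_window_mask(seq_len: int, window: int) -> torch.Tensor:
    # Decoder variant: only left context, 0 <= i - j <= window - 1.
    idx = torch.arange(seq_len)
    rel = idx[:, None] - idx[None, :]
    return (rel >= 0) & (rel <= window - 1)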
@@ -4,10 +4,11 @@ from typing import Any
 
 import pytest
 
-from ...utils import (CLSPoolingEmbedModelInfo, EmbedModelInfo,
-                      LASTPoolingEmbedModelInfo, check_transformers_version)
+from ...utils import (CLSPoolingEmbedModelInfo, CLSPoolingRerankModelInfo,
+                      EmbedModelInfo, LASTPoolingEmbedModelInfo,
+                      RerankModelInfo, check_transformers_version)
 from .embed_utils import correctness_test_embed_models
-from .mteb_utils import mteb_test_embed_models
+from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models
 
 MODELS = [
     ########## BertModel
@@ -58,6 +59,14 @@ MODELS = [
         enable_test=False),
 ]
 
+RERANK_MODELS = [
+    # classifier_pooling: mean
+    CLSPoolingRerankModelInfo(
+        "Alibaba-NLP/gte-reranker-modernbert-base",
+        architecture="ModernBertForSequenceClassification",
+        enable_test=True),
+]
+
 
 @pytest.mark.parametrize("model_info", MODELS)
 def test_embed_models_mteb(hf_runner, vllm_runner,
@@ -88,3 +97,9 @@ def test_embed_models_correctness(hf_runner, vllm_runner,
 
     correctness_test_embed_models(hf_runner, vllm_runner, model_info,
                                   example_prompts, vllm_extra_kwargs)
+
+
+@pytest.mark.parametrize("model_info", RERANK_MODELS)
+def test_rerank_models_mteb(hf_runner, vllm_runner,
+                            model_info: RerankModelInfo) -> None:
+    mteb_test_rerank_models(hf_runner, vllm_runner, model_info)
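
With the load fix and the new sliding-window support in place, the reranker above becomes usable end to end. A hedged usage sketch (not part of the diff; the task flag and output shape follow vLLM's cross-encoder API as of this commit, but may differ by version):

from vllm import LLM

# The model name is taken from RERANK_MODELS above; everything else here
# is illustrative usage, not code from this commit.
llm = LLM(model="Alibaba-NLP/gte-reranker-modernbert-base", task="score")
outputs = llm.score("what is vllm?",
                    ["vLLM is a fast LLM inference engine.",
                     "Bananas are rich in potassium."])
for output in outputs:
    print(output.outputs.score)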
@@ -26,8 +26,7 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.sequence import IntermediateTensors
 from vllm.tasks import PoolingTask
 
-from .interfaces import (SupportsCrossEncoding, SupportsV0Only,
-                         default_pooling_type)
+from .interfaces import SupportsCrossEncoding, default_pooling_type
 from .utils import WeightsMapper, maybe_prefix
 
 
@@ -93,16 +92,14 @@ class ModernBertAttention(nn.Module):
             bias=config.attention_bias,
         )
 
+        sliding_window = None
         if layer_id % config.global_attn_every_n_layers != 0:
-            self.local_attention = (config.local_attention // 2,
-                                    config.local_attention // 2)
+            sliding_window = config.local_attention // 2
+            rope_theta = config.local_rope_theta if config.local_rope_theta \
+                is not None else config.global_rope_theta
         else:
-            self.local_attention = (-1, -1)
+            rope_theta = config.global_rope_theta
 
-        rope_theta = config.global_rope_theta
-        if self.local_attention != (
-                -1, -1) and config.local_rope_theta is not None:
-            rope_theta = config.local_rope_theta
         self.rotary_emb = ModernBertRotaryEmbedding(config=config,
                                                     head_size=self.head_dim,
                                                     dim=self.head_dim,
@@ -111,7 +108,8 @@ class ModernBertAttention(nn.Module):
                               self.head_dim,
                               self.scaling,
                               prefix=f"{layer_id}.attn",
-                              attn_type=AttentionType.ENCODER_ONLY)
+                              attn_type=AttentionType.ENCODER_ONLY,
+                              per_layer_sliding_window=sliding_window)
         self.Wo = RowParallelLinear(config.hidden_size,
                                     config.hidden_size,
                                     bias=config.attention_bias)
@@ -278,6 +276,7 @@ class ModernBertPooler(Pooler):
         return self.pooling.get_pooling_updates(task)
 
     def _head(self, pooled_output: torch.Tensor):
+        pooled_output = pooled_output.to(self.dense.weight.dtype)
         return self.norm(self.act(self.dense(pooled_output)))
 
     def forward(
@@ -296,8 +295,7 @@ class ModernBertPooler(Pooler):
 
 
 @default_pooling_type("CLS")
-class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
-                                          SupportsCrossEncoding):
+class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
 
     is_pooling_model = True
 
@@ -308,6 +306,7 @@ class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
         self.model = ModernBertModel(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "modernbert"))
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+        self.pooling = ModernBertPooler(config)
 
         pooler_config = vllm_config.model_config.pooler_config
         assert pooler_config is not None
@@ -317,14 +316,14 @@ class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
             Pooler.for_encode(pooler_config),
             "classify":
             ClassifierPooler(
-                pooling=ModernBertPooler(config),
+                pooling=self.pooling,
                 classifier=self.classifier,
                 act_fn=ClassifierPooler.act_fn_for_seq_cls(
                     vllm_config.model_config),
             ),
             "score":
             ClassifierPooler(
-                pooling=ModernBertPooler(config),
+                pooling=self.pooling,
                 classifier=self.classifier,
                 act_fn=ClassifierPooler.act_fn_for_cross_encoder(
                     vllm_config.model_config),
@@ -353,7 +352,7 @@ class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
             if name.startswith("head"):
-                param = params_dict["_pooler.pooler." + name[len("head") + 1:]]
+                param = params_dict["pooling." + name[len("head") + 1:]]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
@@ -368,5 +367,5 @@ class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
         return self.model(
             input_ids=input_ids,
             inputs_embeds=inputs_embeds,
-            position_ids=positions,
+            positions=positions,
         )
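
The central refactor in modernbert.py is that ModernBertAttention no longer keeps a local_attention tuple and no longer applies the window itself; it computes a scalar half-width and hands it to vLLM's Attention layer via per_layer_sliding_window, choosing rope_theta in the same branch. A standalone sketch of that selection logic, with a hypothetical config stand-in (_Cfg) for the four ModernBERT fields involved:

from dataclasses import dataclass
from typing import Optional

@dataclass
class _Cfg:  # hypothetical stand-in for the HF ModernBertConfig fields used
    global_attn_every_n_layers: int = 3
    local_attention: int = 128  # full window width in tokens
    global_rope_theta: float = 160_000.0
    local_rope_theta: Optional[float] = 10_000.0

def window_and_theta(cfg: _Cfg, layer_id: int):
    # Every n-th layer stays global (no window); the rest become local.
    if layer_id % cfg.global_attn_every_n_layers != 0:
        sliding_window = cfg.local_attention // 2  # half-width per side
        rope_theta = (cfg.local_rope_theta if cfg.local_rope_theta is not None
                      else cfg.global_rope_theta)
    else:
        sliding_window, rope_theta = None, cfg.global_rope_theta
    return sliding_window, rope_theta

# Layers 0, 3, 6, ... -> (None, 160000.0); layers 1, 2, 4, ... -> (64, 10000.0)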
@@ -384,6 +384,8 @@ class FlashAttentionImpl(AttentionImpl):
         self.alibi_slopes = alibi_slopes
         if sliding_window is None:
             self.sliding_window = (-1, -1)
+        elif attn_type == AttentionType.ENCODER_ONLY:
+            self.sliding_window = (sliding_window - 1, sliding_window - 1)
         else:
             self.sliding_window = (sliding_window - 1, 0)
         self.kv_cache_dtype = kv_cache_dtype
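
FlashAttentionImpl encodes the window as a (left, right) tuple, where -1 means unbounded on that side: (-1, -1) is full attention, (w - 1, 0) is the causal window, and the new (w - 1, w - 1) branch gives ENCODER_ONLY layers symmetric context. A hedged sketch of those semantics using PyTorch SDPA in place of the FlashAttention kernel:

import torch
import torch.nn.functional as F

def sdpa_with_window(q, k, v, left: int, right: int):
    # True in the mask means "may attend". Query i sees key j iff
    # -right <= i - j <= left, with -1 meaning unbounded on that side.
    L = q.shape[-2]
    rel = torch.arange(L)[:, None] - torch.arange(L)[None, :]
    allowed = torch.ones(L, L, dtype=torch.bool)
    if left >= 0:
        allowed &= rel <= left
    if right >= 0:
        allowed &= rel >= -right
    return F.scaled_dot_product_attention(q, k, v, attn_mask=allowed)

q = k = v = torch.randn(1, 2, 16, 8)      # (batch, heads, seq, head_dim)
enc = sdpa_with_window(q, k, v, 3, 3)     # ENCODER_ONLY: both sides
dec = sdpa_with_window(q, k, v, 3, 0)     # decoder: left context only
full = sdpa_with_window(q, k, v, -1, -1)  # no window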
@@ -826,7 +826,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Prepare encoder attention metadata separately
         # (encoder layers are not in KV cache groups)
         if self.is_encoder_only_model:
-            common_attn_metadata, encoder_attn_metadata = \
+            per_layer_metadata = \
                 self._build_encoder_only_attn_metadata(
                     scheduler_output)
 
@@ -835,6 +836,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 self.vllm_config, Attention)
             for layer_name, attn_module in attention_layers.items():
                 if attn_module.attn_type == AttentionType.ENCODER_ONLY:
+                    common_attn_metadata, encoder_attn_metadata = \
+                        per_layer_metadata[layer_name]
                     attn_metadata[layer_name] = encoder_attn_metadata
 
         # Prepare the attention metadata for each KV cache group and make layers
@@ -2683,30 +2686,41 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Check if model is encoder-only
         block_size = self.vllm_config.cache_config.block_size
         use_mla = self.vllm_config.model_config.use_mla
-        attn_specs = list[AttentionSpec]()
-        for attn_module in attn_layers.values():
+        attn_specs: dict[AttentionSpec, list[str]] = defaultdict(list)
+        for layer_name, attn_module in attn_layers.items():
+
             if attn_module.attn_type == AttentionType.ENCODER_ONLY:
-                assert attn_module.sliding_window is None, "Sliding " \
-                    "window attention is not supported for encoder-only models"
-
-                attn_specs.append(
-                    FullAttentionSpec(block_size=block_size,
-                                      num_kv_heads=attn_module.num_kv_heads,
-                                      head_size=attn_module.head_size,
-                                      dtype=self.kv_cache_dtype,
-                                      use_mla=use_mla))
+                if attn_module.sliding_window is None:
+                    attn_spec: AttentionSpec = FullAttentionSpec(
+                        block_size=block_size,
+                        num_kv_heads=attn_module.num_kv_heads,
+                        head_size=attn_module.head_size,
+                        dtype=self.kv_cache_dtype,
+                        use_mla=use_mla)
+                else:
+                    attn_spec = SlidingWindowSpec(
+                        block_size=block_size,
+                        num_kv_heads=attn_module.num_kv_heads,
+                        head_size=attn_module.head_size,
+                        dtype=self.kv_cache_dtype,
+                        sliding_window=attn_module.sliding_window,
+                        use_mla=use_mla)
+                attn_specs[attn_spec].append(layer_name)
             else:
                 raise ValueError("Expected only encoder-only layers")
 
         if len(attn_specs) > 0:
-            assert len(attn_specs) == len(attn_layers), \
+            total_layers = 0
+            for attn_spec, layer_names in attn_specs.items():
+
+                attn_backends = get_attn_backends_for_layers(layer_names)
+                total_layers += len(layer_names)
+
+                self.attn_groups.append(
+                    create_attn_groups(attn_backends, attn_spec))
+            assert total_layers == len(attn_layers), \
                 "All or none of the layers are expected to be encoder-only"
-
-            attn_backends = get_attn_backends_for_layers(attn_layers.keys())
-
-            self.attn_groups.append(
-                create_attn_groups(attn_backends, attn_specs[0]))
             self.is_encoder_only_model = True
 
     def calculate_reorder_batch_threshold(self) -> None:
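
Instead of asserting that encoder-only layers never use sliding windows, the runner now buckets layers by their AttentionSpec, so full-attention and sliding-window layers each get their own KV-cache spec and backend group. The grouping relies on equal specs hashing together (the spec types are hashable dataclasses). A minimal sketch with a hypothetical Spec stand-in:

from collections import defaultdict
from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class Spec:  # hypothetical stand-in for vLLM's AttentionSpec
    sliding_window: Optional[int]

layers = {
    "layers.0.attn": Spec(None),  # global layer -> FullAttentionSpec
    "layers.1.attn": Spec(64),    # local layer  -> SlidingWindowSpec
    "layers.2.attn": Spec(64),
    "layers.3.attn": Spec(None),
}

groups: dict[Spec, list[str]] = defaultdict(list)
for name, spec in layers.items():
    groups[spec].append(name)  # frozen dataclasses hash by value

assert sum(len(v) for v in groups.values()) == len(layers)
# groups: {Spec(None): ['layers.0.attn', 'layers.3.attn'],
#          Spec(64):   ['layers.1.attn', 'layers.2.attn']}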
@@ -3071,7 +3085,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
     def _build_encoder_only_attn_metadata(
             self, scheduler_output: "SchedulerOutput") -> \
-            tuple[CommonAttentionMetadata, Any]:
+            dict[str, tuple[CommonAttentionMetadata, Any]]:
         """Prepare encoder attention metadata for encoder-only models.
 
         Args:
@@ -3088,10 +3102,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
         max_num_scheduled_tokens = max(tokens)
 
-        # Use the first attention metadata builder
-        # to create encoder attention metadata
-        builder = self.attn_groups[0][0].metadata_builder
-
         dummy_block_table = torch.zeros((num_reqs, 1),
                                         dtype=torch.int32,
                                         device=self.device)
@@ -3099,22 +3109,38 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                                          dtype=torch.int32,
                                          device=self.device)
 
-        common_metadata = CommonAttentionMetadata(
-            query_start_loc=self.query_start_loc[:num_reqs + 1],
-            query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
-            seq_lens=self.seq_lens[:num_reqs],
-            seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
-            num_computed_tokens_cpu=self.input_batch.
-            num_computed_tokens_cpu_tensor[:num_reqs],
-            num_reqs=num_reqs,
-            num_actual_tokens=total_num_scheduled_tokens,
-            max_query_len=max_num_scheduled_tokens,
-            block_table_tensor=dummy_block_table,
-            slot_mapping=dummy_slot_mapping,
-            causal=False,
-        )
+        group_metadata = dict[str, tuple[CommonAttentionMetadata, Any]]()
 
-        return common_metadata, builder.build(
-            common_prefix_len=0,  # No cascade for encoder
-            common_attn_metadata=common_metadata,
-        )
+        for attn_group_list in self.attn_groups:
+
+            assert len(attn_group_list) == 1
+            attn_group = attn_group_list[0]
+
+            # Use the first attention metadata builder
+            # to create encoder attention metadata
+            builder = attn_group.metadata_builder
+
+            common_metadata = CommonAttentionMetadata(
+                query_start_loc=self.query_start_loc[:num_reqs + 1],
+                query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
+                seq_lens=self.seq_lens[:num_reqs],
+                seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
+                num_computed_tokens_cpu=self.input_batch.
+                num_computed_tokens_cpu_tensor[:num_reqs],
+                num_reqs=num_reqs,
+                num_actual_tokens=total_num_scheduled_tokens,
+                max_query_len=max_num_scheduled_tokens,
+                block_table_tensor=dummy_block_table,
+                slot_mapping=dummy_slot_mapping,
+                causal=False,
+            )
+
+            metadata = builder.build(
+                common_prefix_len=0,  # No cascade for encoder
+                common_attn_metadata=common_metadata,
+            )
+
+            for layer_name in attn_group.layer_names:
+                group_metadata[layer_name] = (common_metadata, metadata)
+
+        return group_metadata
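
Design note: _build_encoder_only_attn_metadata now returns a per-layer dict rather than a single (common_metadata, metadata) pair, because each attention group (full attention vs. sliding window) builds its own metadata from its own builder. The caller in the @@ -835,6 +836,8 hunk above simply indexes it per layer: common_attn_metadata, encoder_attn_metadata = per_layer_metadata[layer_name].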