[Bugfix] Fix ModernBert cuda graph capturing in v1 (#21901)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Isotr0py <2037008807@qq.com>
Author: Isotr0py, 2025-08-09 13:17:22 +08:00 (committed by GitHub)
parent 35afe1b30b
commit 429e4e2d42
5 changed files with 39 additions and 42 deletions

View File

@@ -162,7 +162,8 @@ def mteb_test_embed_models(hf_runner,
vllm_runner,
model_info: EmbedModelInfo,
vllm_extra_kwargs=None,
hf_model_callback=None):
hf_model_callback=None,
atol=MTEB_RERANK_TOL):
if not model_info.enable_test:
# A model family has many models with the same architecture,
# and we don't need to test each one.
@@ -198,7 +199,7 @@ def mteb_test_embed_models(hf_runner,
print("SentenceTransformers:", st_dtype, st_main_score)
print("Difference:", st_main_score - vllm_main_score)
assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_EMBED_TOL)
assert st_main_score == pytest.approx(vllm_main_score, abs=atol)
def run_mteb_rerank(cross_encoder, tasks, languages):
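Note: the added atol parameter lets a per-model test override the tolerance used in the final score comparison instead of hard-coding it. A minimal sketch of how a caller might use it (the model name and EmbedModelInfo fields below are illustrative, not part of this commit):

MODEL_INFO = EmbedModelInfo("nomic-ai/modernbert-embed-base",
                            architecture="ModernBertModel",
                            enable_test=True)

def test_embed_model_mteb(hf_runner, vllm_runner) -> None:
    # Compare vLLM and SentenceTransformers MTEB scores with a
    # model-specific absolute tolerance.
    mteb_test_embed_models(hf_runner, vllm_runner, MODEL_INFO, atol=1e-3)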

View File

@@ -466,7 +466,7 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
def forward(
self,
input_ids: Optional[torch.Tensor],
input_ids: torch.Tensor,
positions: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,

View File

@@ -8,13 +8,15 @@ from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import (get_act_and_mul_fn,
get_act_fn)
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, torch_vllm_outplace_fused_experts)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
@@ -284,15 +286,22 @@ class NomicMoE(nn.Module):
hidden_states = hidden_states.view(-1, self.hidden_size)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.router(hidden_states)
final_hidden_states = fused_moe(hidden_states,
self.w1,
self.w2,
router_logits,
self.top_k,
renormalize=False,
inplace=False,
activation=self.hidden_act,
is_act_and_mul=False)
# FIXME(Isotr0py): This implementation is too tricky,
# we should use FusedMoE instead in the future
# after supporting ungated activation for it.
topk_weights, topk_ids, _ = fused_topk(hidden_states,
router_logits,
self.top_k,
renormalize=False)
final_hidden_states = torch_vllm_outplace_fused_experts(
hidden_states=hidden_states,
w1=self.w1,
w2=self.w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=self.hidden_act,
is_act_and_mul=False,
)
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
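For context on the replacement above: fused_topk picks each token's top-k experts from the router logits, and the out-of-place fused-experts kernel then applies those experts with the ungated activation (is_act_and_mul=False). A rough pure-torch equivalent of the routing step, shown only as a sketch of the math rather than the actual kernel:

import torch

def reference_topk(router_logits: torch.Tensor, top_k: int,
                   renormalize: bool = False) -> tuple[torch.Tensor, torch.Tensor]:
    # (num_tokens, n_experts) logits -> (num_tokens, top_k) weights and ids
    scores = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(scores, top_k, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids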
@@ -391,6 +400,7 @@ class BertWithRopeEncoder(nn.Module):
return hidden_states
@support_torch_compile
class BertWithRope(nn.Module, SupportsQuant):
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
@@ -407,7 +417,7 @@ class BertWithRope(nn.Module, SupportsQuant):
def forward(
self,
input_ids: Optional[torch.Tensor],
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
@@ -554,20 +564,6 @@ class JinaRobertaModel(BertWithRope):
"norm2": "mlp_ln",
})
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return super().forward(input_ids=input_ids,
positions=position_ids,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
token_type_ids=token_type_ids)
@torch.inference_mode()
def jina_merge_lora_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]):

View File

@@ -8,6 +8,7 @@ from torch import nn
from transformers import ModernBertConfig
from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import (QKVParallelLinear,
@@ -46,7 +47,7 @@ class ModernBertEmbeddings(nn.Module):
input_ids: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if inputs_embeds:
if inputs_embeds is not None:
return self.norm(inputs_embeds)
else:
inputs_embeds = self.tok_embeddings(input_ids)
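The change from "if inputs_embeds:" to "if inputs_embeds is not None:" fixes a genuine bug: truth-testing a tensor with more than one element raises at runtime, so the presence check must be explicit. A minimal illustration:

import torch

inputs_embeds = torch.zeros(2, 768)
# bool(inputs_embeds)          # RuntimeError: Boolean value of Tensor with
#                              # more than one element is ambiguous
if inputs_embeds is not None:  # correct: only tests whether it was provided
    pass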
@@ -117,7 +118,7 @@ class ModernBertAttention(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
position_ids: torch.Tensor,
) -> torch.Tensor:
qkv, _ = self.Wqkv(hidden_states)
q, k, v = qkv.split([self.all_head_size] * 3, dim=-1)
@@ -169,9 +170,9 @@ class ModernBertLayer(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
):
attn_outputs = self.attn(self.attn_norm(hidden_states),
position_ids: torch.Tensor,
) -> torch.Tensor:
attn_outputs = self.attn(hidden_states=self.attn_norm(hidden_states),
position_ids=position_ids)
hidden_states = hidden_states + attn_outputs
mlp_output = self.mlp(self.mlp_norm(hidden_states))
@@ -192,13 +193,14 @@ class ModernBertEncoderLayer(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
position_ids: torch.Tensor,
) -> torch.Tensor:
for i, layer in enumerate(self.layers):
hidden_states = layer(hidden_states, position_ids)
return hidden_states
@support_torch_compile
class ModernBertModel(nn.Module):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={"layers.": "encoder_layer.layers."})
@@ -234,13 +236,11 @@ class ModernBertModel(nn.Module):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
positions: Optional[torch.Tensor] = None,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> torch.Tensor:
position_ids = positions if positions is not None else position_ids
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
@@ -249,7 +249,7 @@ class ModernBertModel(nn.Module):
outputs = self.encoder_layer(
hidden_states=hidden_states,
position_ids=position_ids,
position_ids=positions,
)
norm_outputs = self.final_norm(outputs)
return norm_outputs
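Taken together, the modernbert.py changes (the support_torch_compile decorator plus the plain input_ids/positions forward signature) are the heart of the fix: they let the v1 engine compile the encoder and capture CUDA graphs for it. A minimal sketch of the pattern, assuming the usual vllm_config constructor convention the decorator expects (TinyEncoder is a made-up example, not code from this commit):

import torch
from torch import nn
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig

@support_torch_compile
class TinyEncoder(nn.Module):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.norm = nn.LayerNorm(16)

    def forward(self, input_ids: torch.Tensor,
                positions: torch.Tensor) -> torch.Tensor:
        # Tensor-typed, non-Optional arguments give the compiler stable
        # inputs to trace during CUDA graph capture.
        ...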

View File

@@ -105,7 +105,7 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
def forward(
self,
input_ids: Optional[torch.Tensor],
input_ids: torch.Tensor,
positions: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
@@ -119,8 +119,8 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
position_ids=positions,
padding_idx=self.padding_idx)
return self.model(input_ids=input_ids,
position_ids=positions,
return self.model(input_ids,
positions,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors)
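End to end, the point of the commit is that ModernBERT embedding models can be served by the v1 engine with CUDA graph capture working. A hedged usage sketch (the model name is only an example and the exact API surface may differ across vLLM versions):

from vllm import LLM

llm = LLM(model="Alibaba-NLP/gte-modernbert-base", task="embed")
outputs = llm.embed(["ModernBERT CUDA graph capture works in v1 now."])
print(len(outputs[0].outputs.embedding))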