Fix some Transformers nightly tests (#29802)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
commit f5b0846ba0
parent 13ea39bc09
@@ -29,7 +29,7 @@ logger = init_logger(__name__)
 class JinaVLScorer(nn.Module):
     def __init__(self, model_config: "ModelConfig"):
         super().__init__()
-        config = model_config.hf_config
+        config = model_config.hf_config.get_text_config()
         head_dtype = model_config.head_dtype
         self.dense = ColumnParallelLinear(
             config.hidden_size, config.hidden_size, params_dtype=head_dtype, bias=True
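Note (not part of the diff): get_text_config() is a Transformers PretrainedConfig method. On a plain text config it simply returns the config itself; on a composite (e.g. vision-language) config it returns the nested text config, so fields such as hidden_size resolve either way. A minimal sketch, using Qwen2Config as a hypothetical stand-in for the checkpoint's actual config:

# Illustrative sketch only; Qwen2Config is a stand-in, not the Jina model's real config.
from transformers import Qwen2Config

cfg = Qwen2Config()                  # a plain, text-only config
text_cfg = cfg.get_text_config()     # for text-only configs this is expected to return cfg itself
assert text_cfg.hidden_size == cfg.hidden_size
# For composite multimodal configs, the same call returns the nested text config,
# which is why the scorer now reads hidden_size through get_text_config().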
@@ -20,7 +20,7 @@ from vllm.model_executor.layers.pooler import (
     PoolingParamsUpdate,
     PoolingType,
 )
-from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.sequence import IntermediateTensors
@@ -62,19 +62,6 @@ class ModernBertEmbeddings(nn.Module):
         return embeddings


-class ModernBertRotaryEmbedding(RotaryEmbedding):
-    def __init__(self, config: ModernBertConfig, head_size: int, dim: int, base: float):
-        super().__init__(
-            head_size=head_size,
-            rotary_dim=dim,
-            max_position_embeddings=config.max_position_embeddings,
-            base=base,
-            is_neox_style=True,
-            dtype=torch.float16,
-        )
-        self.config = config
-
-
 class ModernBertAttention(nn.Module):
     def __init__(self, config: ModernBertConfig, layer_id: int | None = None):
         super().__init__()
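Note (not part of the diff): the removed ModernBertRotaryEmbedding only forwarded its arguments to RotaryEmbedding, so once the attention layer builds its rotary embedding through the get_rope factory (next hunk) the subclass has no remaining callers. A hedged sketch of an equivalent get_rope call, with placeholder sizes and assuming the signature shown in this commit:

import torch
from vllm.model_executor.layers.rotary_embedding import get_rope

# Placeholder dimensions; the real values come from the ModernBERT config.
rotary_emb = get_rope(
    head_size=64,
    rotary_dim=64,
    max_position=8192,
    rope_parameters={"rope_type": "default", "rope_theta": 10000.0},
    dtype=torch.float16,
)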
@@ -95,19 +82,33 @@ class ModernBertAttention(nn.Module):
             bias=config.attention_bias,
         )

-        sliding_window = None
-        if layer_id % config.global_attn_every_n_layers != 0:
-            sliding_window = config.local_attention // 2
-            rope_theta = (
-                config.local_rope_theta
-                if config.local_rope_theta is not None
-                else config.global_rope_theta
-            )
-        else:
-            rope_theta = config.global_rope_theta
+        if layer_types := getattr(config, "layer_types", None):
+            # Transformers v5
+            layer_type = layer_types[layer_id]
+            rope_parameters = config.rope_parameters[layer_type]
+            sliding_window: int | None = None
+            if layer_type == "sliding_attention":
+                sliding_window = config.local_attention // 2
+        else:
+            # Transformers v4
+            sliding_window = None
+            if layer_id % config.global_attn_every_n_layers != 0:
+                sliding_window = config.local_attention // 2
+                rope_theta = (
+                    config.local_rope_theta
+                    if config.local_rope_theta is not None
+                    else config.global_rope_theta
+                )
+            else:
+                rope_theta = config.global_rope_theta
+            rope_parameters = {"rope_type": "default", "rope_theta": rope_theta}

-        self.rotary_emb = ModernBertRotaryEmbedding(
-            config=config, head_size=self.head_dim, dim=self.head_dim, base=rope_theta
+        self.rotary_emb = get_rope(
+            head_size=self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=config.max_position_embeddings,
+            rope_parameters=rope_parameters,
+            dtype=torch.float16,
         )
         self.attn = EncoderOnlyAttention(
             self.num_heads,
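Note (not part of the diff): Transformers v5 ModernBERT configs carry layer_types plus per-layer-type rope_parameters, while v4 configs only expose global/local rope_theta, so the v4 branch rebuilds an equivalent rope_parameters dict before calling get_rope. A condensed sketch of that normalization (resolve_rope_parameters is a name chosen here for illustration; the attributes follow the diff):

def resolve_rope_parameters(config, layer_id):
    """Return (rope_parameters, sliding_window) for one ModernBERT layer."""
    if layer_types := getattr(config, "layer_types", None):
        # Transformers v5: per-layer-type rope settings already exist.
        layer_type = layer_types[layer_id]
        rope_parameters = config.rope_parameters[layer_type]
        sliding_window = (
            config.local_attention // 2 if layer_type == "sliding_attention" else None
        )
    else:
        # Transformers v4: rebuild an equivalent dict from the theta values.
        sliding_window = None
        if layer_id % config.global_attn_every_n_layers != 0:
            sliding_window = config.local_attention // 2
            rope_theta = (
                config.local_rope_theta
                if config.local_rope_theta is not None
                else config.global_rope_theta
            )
        else:
            rope_theta = config.global_rope_theta
        rope_parameters = {"rope_type": "default", "rope_theta": rope_theta}
    return rope_parameters, sliding_window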
@@ -503,7 +503,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
-        config = vllm_config.model_config.hf_config
+        config = vllm_config.model_config.hf_config.get_text_config()
         quant_config = vllm_config.quant_config

         self.config = config
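Note (not part of the diff): the same get_text_config() change lets Qwen2ForCausalLM be constructed when hf_config is a composite multimodal config whose text fields live one level down. A toy illustration with hypothetical stand-in classes (not vLLM or Transformers code):

class TextConfig:
    hidden_size = 896

class CompositeConfig:
    text_config = TextConfig()

    def get_text_config(self):
        # Mirrors the PretrainedConfig behaviour of preferring the nested text config.
        return self.text_config

hf_config = CompositeConfig()
config = hf_config.get_text_config()
print(config.hidden_size)  # 896; hf_config.hidden_size would raise AttributeError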