diff --git a/tests/models/language/pooling/test_splade_sparse_pooler.py b/tests/models/language/pooling/test_splade_sparse_pooler.py
new file mode 100644
index 000000000000..636a6f2f9d74
--- /dev/null
+++ b/tests/models/language/pooling/test_splade_sparse_pooler.py
@@ -0,0 +1,122 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import types
+
+import numpy as np
+import pytest
+import torch
+import torch.nn as nn
+
+from vllm.model_executor.models.bert import (
+    BertMLMHead,
+    SPLADESparsePooler,
+)
+
+# ---------------------------------------------------------------------
+# 1) Functional test: SPLADE formula correctness (no HF download needed)
+# ---------------------------------------------------------------------
+
+
+@pytest.mark.parametrize("B,T,H,V", [(2, 3, 5, 7)])
+def test_splade_pooler_matches_reference_formula(B, T, H, V):
+    """Ensure SPLADESparsePooler.forward() matches the mathematical formula:
+    log1p(relu(logits)) -> max over sequence length (after masking)."""
+    torch.manual_seed(0)
+
+    # Prepare [B] sequences of shape [T, H]
+    hs_list = [torch.randn(T, H) for _ in range(B)]
+
+    # Simulate PoolingMetadata (only the fields the pooler reads)
+    prompt_lens = [T, T - 1]
+    token_ids = torch.tensor(
+        [
+            [101, 5, 102],  # Batch 0: [CLS], token, [SEP]
+            [101, 6, 6],  # Batch 1: [CLS], token, token (last token ignored)
+        ],
+        dtype=torch.long,
+    )
+    meta = types.SimpleNamespace(prompt_lens=prompt_lens, prompt_token_ids=token_ids)
+
+    # MLM head (prefer BertMLMHead, fall back to Linear if unavailable)
+    try:
+        mlm_head = BertMLMHead(hidden_size=H, vocab_size=V, layer_norm_eps=1e-12)
+    except Exception:
+        mlm_head = nn.Linear(H, V, bias=True)
+
+    # Forward pass through the SPLADE pooler
+    pooler = SPLADESparsePooler(mlm_head=mlm_head, pooling="max", remove_cls_sep=True)
+    pooled = pooler(hidden_states=hs_list, pooling_metadata=meta)  # [B, V]
+
+    # Basic output checks (the pooler stacks per-prompt vectors into [B, V])
+    assert isinstance(pooled, torch.Tensor) and pooled.shape == (B, V)
+    for vec in pooled:
+        assert vec.shape == (V,)
+        assert torch.isfinite(vec).all()
+        assert (vec >= 0).all(), "SPLADE outputs must be non-negative."
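+
+    # The reference below recomputes the SPLADE score independently as
+    #   score[v] = max_t log(1 + relu(logits[t, v]))
+    # over the non-special tokens of each prompt.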
+
+    # Reference implementation for comparison
+    def ref_one(hs: torch.Tensor, L: int, tid_row: torch.Tensor) -> torch.Tensor:
+        keep = torch.ones(L, dtype=torch.bool)
+        if L > 0 and tid_row[0].item() == 101:  # remove CLS
+            keep[0] = False
+        if L > 0 and tid_row[L - 1].item() == 102:  # remove SEP
+            keep[L - 1] = False
+
+        valid = hs[:L][keep[:L]]
+        if valid.numel() == 0:
+            return torch.zeros(V, dtype=torch.float32)
+
+        logits = mlm_head(valid)  # [L', V]
+        scores = torch.log1p(torch.relu(logits))  # [L', V]
+        return scores.max(dim=0).values.to(torch.float32)
+
+    torch.testing.assert_close(
+        pooled[0],
+        ref_one(hs_list[0], prompt_lens[0], token_ids[0]),
+        rtol=1e-4,
+        atol=1e-4,
+    )
+    torch.testing.assert_close(
+        pooled[1],
+        ref_one(hs_list[1], prompt_lens[1], token_ids[1]),
+        rtol=1e-4,
+        atol=1e-4,
+    )
+
+
+# ---------------------------------------------------------------------
+# 2) Integration smoke test: end-to-end embedding path wiring
+# ---------------------------------------------------------------------
+
+
+@pytest.mark.cpu_model
+def test_bert_splade_sparse_embed_smoke(vllm_runner, monkeypatch):
+    """Ensure BertSpladeSparseEmbeddingModel loads and produces sparse embeddings."""
+    from transformers import AutoTokenizer
+
+    MODEL_ID = "hf-internal-testing/tiny-random-bert"
+    hf_overrides = {"architectures": ["BertSpladeSparseEmbeddingModel"]}
+
+    # Enforce CPU-only execution (optional)
+    monkeypatch.setenv("CUDA_VISIBLE_DEVICES", "")
+    monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")
+
+    tok = AutoTokenizer.from_pretrained(MODEL_ID)
+    vocab_size = tok.vocab_size
+
+    # The embed path should route through SPLADESparsePooler
+    with vllm_runner(
+        MODEL_ID,
+        runner="pooling",
+        max_model_len=64,
+        hf_overrides=hf_overrides,
+    ) as vm:
+        outs = vm.embed(["hello world", "splade sparse test"])
+
+    # Basic sanity checks: vocab-sized, finite, non-negative sparse vectors
+    assert len(outs) == 2
+    emb0, emb1 = np.asarray(outs[0]), np.asarray(outs[1])
+    assert emb0.shape[0] == vocab_size
+    assert emb1.shape[0] == vocab_size
+    assert np.isfinite(emb0).all() and (emb0 >= 0).all()
+    assert np.isfinite(emb1).all() and (emb1 >= 0).all()
diff --git a/tests/models/registry.py b/tests/models/registry.py
index ad90229adf8a..fbc11c2ddfd4 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -486,6 +486,9 @@ _EMBEDDING_EXAMPLE_MODELS = {
     "RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2"),
     "RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1"),
     "XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small"),
+    "BertSpladeSparseEmbeddingModel": _HfExamplesInfo(
+        "naver/splade-v3", is_available_online=False
+    ),
     # [Multimodal]
     "CLIPModel": _HfExamplesInfo("openai/clip-vit-base-patch32"),
     "LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"),
diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py
index e07da3d4d29a..df302aee0bf6 100644
--- a/vllm/model_executor/models/bert.py
+++ b/vllm/model_executor/models/bert.py
@@ -572,6 +572,220 @@ def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor:
     return token_type_ids
 
 
+class BertMLMHead(nn.Module):
+    def __init__(
+        self, hidden_size: int, vocab_size: int, layer_norm_eps: float = 1e-12
+    ):
+        super().__init__()
+        self.dense = nn.Linear(hidden_size, hidden_size)
+        self.activation = nn.GELU()
+        self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
+        self.decoder = nn.Linear(hidden_size, vocab_size, bias=True)
+
+    def tie_weights_with_embeddings(self, embeddings_weight: torch.Tensor):
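+        # Standard BERT MLM weight tying: share the input embedding matrix
+        # with the decoder (pass an nn.Parameter, e.g. the word-embedding weight).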
+        self.decoder.weight = embeddings_weight
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        x = self.dense(hidden_states)
+        x = self.activation(x)
+        x = self.layer_norm(x)
+        logits = self.decoder(x)
+        return logits
+
+
+class SPLADESparsePooler(Pooler):
+    """
+    SPLADE sparse pooling:
+        logits = mlm_head(hidden_states)
+        -> log1p(relu(logits))
+        -> (max | sum over L)
+        -> [V]
+
+    Each prompt is sliced out of the flattened hidden states by its length,
+    [CLS]/[SEP] tokens are optionally dropped, and the remaining token
+    scores are pooled into a single [V] vector.
+    """
+
+    def __init__(
+        self,
+        mlm_head: nn.Module,
+        cls_token_id: Optional[int] = 101,
+        sep_token_id: Optional[int] = 102,
+        pooling: str = "max",
+        remove_cls_sep: bool = True,
+    ):
+        super().__init__()
+        assert pooling in ("max", "sum")
+        self.mlm_head = mlm_head
+        self.cls_token_id = cls_token_id
+        self.sep_token_id = sep_token_id
+        self.pooling = pooling
+        self.remove_cls_sep = remove_cls_sep
+
+    def get_supported_tasks(self) -> Set[PoolingTask]:
+        return {"embed"}
+
+    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
+        return PoolingParamsUpdate(requires_token_ids=True)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> torch.Tensor:
+        # Accept either the flattened [sum(L), H] tensor or a per-prompt list.
+        if isinstance(hidden_states, list):
+            hidden_states = torch.cat(hidden_states, dim=0)
+        assert isinstance(hidden_states, torch.Tensor) and hidden_states.dim() == 2
+
+        # prompt_lens may arrive as a tensor or a plain list of ints.
+        prompt_lens = pooling_metadata.prompt_lens
+        lens: list[int] = (
+            prompt_lens.tolist()
+            if isinstance(prompt_lens, torch.Tensor)
+            else list(prompt_lens)
+        )
+        B: int = len(lens)
+
+        token_ids = pooling_metadata.prompt_token_ids
+        offset = 0
+        pooled_list: list[torch.Tensor] = []
+
+        for i in range(B):
+            L = int(lens[i])
+            hs = hidden_states[offset : offset + L]
+
+            start_idx = 0
+            end_idx = L
+            if self.remove_cls_sep and token_ids is not None:
+                if (
+                    self.cls_token_id is not None
+                    and token_ids[i, 0].item() == self.cls_token_id
+                ):
+                    start_idx = 1
+                if (
+                    self.sep_token_id is not None
+                    and token_ids[i, L - 1].item() == self.sep_token_id
+                ):
+                    end_idx = max(start_idx, L - 1)
+
+            if end_idx <= start_idx:
+                # Degenerate prompt (only special tokens): emit all zeros.
+                V = int(self.mlm_head.decoder.out_features)
+                pooled_list.append(hs.new_zeros((V,)))
+                offset += L
+                continue
+
+            logits_i = self.mlm_head(hs[start_idx:end_idx])
+            scores_i = torch.log1p(torch.relu(logits_i))
+
+            if self.pooling == "sum":
+                pooled_i = scores_i.sum(dim=0)
+            else:  # "max"
+                pooled_i = scores_i.max(dim=0).values
+
+            pooled_list.append(pooled_i.contiguous())
+            offset += L
+
+        return torch.stack(pooled_list, dim=0).contiguous()
+
+
+@default_pooling_type("CLS")
+class BertSpladeSparseEmbeddingModel(BertEmbeddingModel):
+    """
+    BertEmbeddingModel + SPLADE sparse embedding.
+    - Computes MLM logits via self.mlm_head.
+    - Pools them with SPLADESparsePooler(mlm_head...) on the "embed" task.
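+    - load_weights() remaps HF "cls.predictions.*" weights onto mlm_head.*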
+ """ + + def __init__( + self, *, vllm_config: VllmConfig, prefix: str = "", splade_pooling: str = "max" + ): + super().__init__(vllm_config=vllm_config, prefix=prefix) + cfg = vllm_config.model_config.hf_config + + # MLM head + self.mlm_head = BertMLMHead( + hidden_size=cfg.hidden_size, + vocab_size=cfg.vocab_size, + layer_norm_eps=getattr(cfg, "layer_norm_eps", 1e-12), + ) + + self._splade_pooling = splade_pooling + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + self.pooler = self._build_pooler(pooler_config) + + def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler: + cfg = self.model.config + + if not hasattr(self, "mlm_head"): + self.mlm_head = BertMLMHead( + hidden_size=cfg.hidden_size, + vocab_size=cfg.vocab_size, + layer_norm_eps=getattr(cfg, "layer_norm_eps", 1e-12), + ) + + pooling_mode = getattr(self, "_splade_pooling", "max") + + cls_id = getattr(cfg, "cls_token_id", None) + sep_id = getattr(cfg, "sep_token_id", None) + + return DispatchPooler( + { + "encode": Pooler.for_encode(pooler_config), + "embed": SPLADESparsePooler( + mlm_head=self.mlm_head, + cls_token_id=cls_id, + sep_token_id=sep_id, + pooling=pooling_mode, # "max" or "sum" + remove_cls_sep=True, + ), + } + ) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + if not hasattr(self, "mlm_head"): + cfg = self.model.config + self.mlm_head = BertMLMHead( + hidden_size=cfg.hidden_size, + vocab_size=cfg.vocab_size, + layer_norm_eps=getattr(cfg, "layer_norm_eps", 1e-12), + ) + + def _strip(name: str) -> str: + for p in ("model.", "bert."): + if name.startswith(p): + name = name[len(p) :] + return name + + weights_list = list(weights) + model_side: list[tuple[str, torch.Tensor]] = [] + mlm_side: list[tuple[str, torch.Tensor]] = [] + + for k, w in weights_list: + name = _strip(k) + if name.startswith("cls.predictions."): + mlm_side.append((name, w)) + else: + model_side.append((name, w)) + + loaded: set[str] = set() + loaded_model = self.model.load_weights(model_side) + loaded.update({"model." + n for n in loaded_model}) + + if mlm_side: + name_map = { + "cls.predictions.transform.dense.weight": "mlm_head.dense.weight", + "cls.predictions.transform.dense.bias": "mlm_head.dense.bias", + ("cls.predictions.transform.LayerNorm.weight"): ( + "mlm_head.layer_norm.weight" + ), + ("cls.predictions.transform.LayerNorm.bias"): ( + "mlm_head.layer_norm.bias" + ), + "cls.predictions.decoder.weight": "mlm_head.decoder.weight", + "cls.predictions.decoder.bias": "mlm_head.decoder.bias", + } + remapped = [(name_map[n], w) for n, w in mlm_side if n in name_map] + if remapped: + loaded_mlm = AutoWeightsLoader(self).load_weights(remapped) + loaded.update(loaded_mlm) + + return loaded + + @default_pooling_type("CLS") class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQuant): """A model that uses Bert to provide embedding functionalities. diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 194d2593a7fe..92ad19a20e02 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -172,6 +172,7 @@ _TEXT_GENERATION_MODELS = { _EMBEDDING_MODELS = { # [Text-only] "BertModel": ("bert", "BertEmbeddingModel"), + "BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"), "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"), "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"), "Gemma3TextModel": ("gemma3", "Gemma3Model"),