Fix implementation divergence for BLOOM models between vLLM and HuggingFace when using prompt embeds (#24686)

Signed-off-by: Andrew Sansom <andrew@protopia.ai>
Author: Andrew Sansom, 2025-09-11 23:35:48 -05:00 (committed by GitHub)
parent e090b7b45b
commit ddcec289c7
2 changed files with 4 additions and 5 deletions
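
For context on the divergence: HuggingFace's BloomModel applies word_embeddings_layernorm inside forward, to whichever embeddings it ends up using, while vLLM previously applied it inside get_input_embeddings. User-supplied prompt embeds therefore skipped the layernorm entirely in vLLM. The snippet below is a minimal, self-contained sketch of that mismatch with toy layer sizes; the names mirror BLOOM's modules but this is not the actual model code.

import torch
from torch import nn

# Toy stand-ins for BLOOM's embedding path (hypothetical sizes).
word_embeddings = nn.Embedding(10, 4)
word_embeddings_layernorm = nn.LayerNorm(4)

input_ids = torch.tensor([[1, 2, 3]])

# Prompt embeds as a caller would typically build them: the raw word
# embeddings, i.e. what HF's get_input_embeddings()(input_ids) returns.
prompt_embeds = word_embeddings(input_ids)

# HF's BloomModel.forward normalizes whichever embeddings it uses,
# so prompt embeds are layernormed exactly once.
hf_hidden = word_embeddings_layernorm(prompt_embeds)

# vLLM previously used supplied prompt embeds as-is (the layernorm lived in
# get_input_embeddings), so the normalization was skipped on this path.
old_vllm_hidden = prompt_embeds

print(torch.allclose(hf_hidden, old_vllm_hidden))  # False: the outputs diverge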


@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import os
 from typing import Optional
 
 import pytest
@@ -99,9 +98,10 @@ AITER_MODEL_LIST = [
 @pytest.mark.parametrize("num_logprobs", [5])
 @pytest.mark.parametrize(
     "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
+@pytest.mark.parametrize("use_prompt_embeds", [True, False])
 def test_models(hf_runner, vllm_runner, example_prompts, model: str,
                 max_tokens: int, num_logprobs: int, use_rocm_aiter: bool,
-                monkeypatch) -> None:
+                use_prompt_embeds: bool, monkeypatch) -> None:
     model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
     model_info.check_available_online(on_fail="skip")
 
@@ -119,8 +119,6 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
         # in parts of the operators
         pytest.skip(f"Skipping '{model}' model test with AITER kernel.")
 
-    use_prompt_embeds = os.getenv("VLLM_USE_V1") == "0"
-
     with hf_runner(model) as hf_model:
         hf_outputs = hf_model.generate_greedy_logprobs_limit(
             example_prompts, max_tokens, num_logprobs)
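
The test change above replaces the old V0/V1 environment-variable switch with an explicit use_prompt_embeds parametrization, so both input forms are exercised against the HuggingFace reference. As a rough usage sketch of the path that parameter covers, one might feed externally computed BLOOM embeddings into vLLM as below; the enable_prompt_embeds flag and the {"prompt_embeds": ...} prompt form reflect my understanding of the current vLLM API and should be treated as assumptions rather than a verified recipe.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import LLM, SamplingParams

model_id = "bigscience/bloom-560m"

# Build prompt embeds the way a caller usually would: via HF's embedding
# table. For BLOOM this returns raw word embeddings, NOT layernormed ones,
# which is exactly why vLLM must apply word_embeddings_layernorm itself.
tokenizer = AutoTokenizer.from_pretrained(model_id)
hf_model = AutoModelForCausalLM.from_pretrained(model_id)
input_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids
with torch.no_grad():
    prompt_embeds = hf_model.get_input_embeddings()(input_ids).squeeze(0)

# Pass the embeddings to vLLM instead of token ids (API names assumed).
llm = LLM(model=model_id, enable_prompt_embeds=True)
outputs = llm.generate({"prompt_embeds": prompt_embeds},
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)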


@@ -257,7 +257,7 @@ class BloomModel(nn.Module):
                                                     config.hidden_size))
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.word_embeddings_layernorm(self.word_embeddings(input_ids))
+        return self.word_embeddings(input_ids)
 
     def forward(
         self,
@@ -271,6 +271,7 @@ class BloomModel(nn.Module):
                 hidden_states = inputs_embeds
             else:
                 hidden_states = self.get_input_embeddings(input_ids)
+            hidden_states = self.word_embeddings_layernorm(hidden_states)
         else:
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
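
With the layernorm moved out of get_input_embeddings and into forward, token-id prompts and user-supplied prompt embeds both pass through word_embeddings_layernorm exactly once, matching where HuggingFace's BloomModel.forward applies it, so greedy outputs and logprobs for the two input forms should agree again.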