mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 03:15:20 +08:00
Fix implementation divergence for BLOOM models between vLLM and HuggingFace when using prompt embeds (#24686)
Signed-off-by: Andrew Sansom <andrew@protopia.ai>
This commit is contained in:
parent
e090b7b45b
commit
ddcec289c7
@ -1,6 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
@ -99,9 +98,10 @@ AITER_MODEL_LIST = [
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
@pytest.mark.parametrize(
|
||||
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
|
||||
@pytest.mark.parametrize("use_prompt_embeds", [True, False])
|
||||
def test_models(hf_runner, vllm_runner, example_prompts, model: str,
|
||||
max_tokens: int, num_logprobs: int, use_rocm_aiter: bool,
|
||||
monkeypatch) -> None:
|
||||
use_prompt_embeds: bool, monkeypatch) -> None:
|
||||
|
||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
|
||||
model_info.check_available_online(on_fail="skip")
|
||||
@ -119,8 +119,6 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
|
||||
# in parts of the operators
|
||||
pytest.skip(f"Skipping '{model}' model test with AITER kernel.")
|
||||
|
||||
use_prompt_embeds = os.getenv("VLLM_USE_V1") == "0"
|
||||
|
||||
with hf_runner(model) as hf_model:
|
||||
hf_outputs = hf_model.generate_greedy_logprobs_limit(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
|
||||
@ -257,7 +257,7 @@ class BloomModel(nn.Module):
|
||||
config.hidden_size))
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.word_embeddings_layernorm(self.word_embeddings(input_ids))
|
||||
return self.word_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -271,6 +271,7 @@ class BloomModel(nn.Module):
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
hidden_states = self.word_embeddings_layernorm(hidden_states)
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user