Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-11 04:05:16 +08:00)
[Hotfix][Pixtral] Fix multiple images bugs (#8415)
parent b61bd98f90
commit d31174a4e1
@@ -658,8 +658,8 @@ class VllmRunner:
             outputs.append((req_sample_output_ids, req_sample_output_strs))
         return outputs
 
+    @staticmethod
     def _final_steps_generate_w_logprobs(
-        self,
         req_outputs: List[RequestOutput],
     ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
         outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = []
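This first hunk drops the `self` parameter and turns the helper into a `@staticmethod`, which appears to be what lets the new Pixtral tests call `vllm_runner._final_steps_generate_w_logprobs(outputs)` on the runner class itself without constructing an instance. An illustrative sketch of that pattern (names here are hypothetical, not vLLM's code):

# Illustrative sketch only: a static post-processing helper callable on the class.
from typing import List, Optional, Tuple


class RunnerSketch:

    @staticmethod
    def final_steps_w_logprobs(req_outputs) -> List[Tuple[List[int], str, Optional[dict]]]:
        results = []
        for req_output in req_outputs:
            sample = req_output.outputs[0]  # assume one completion per request
            results.append((list(sample.token_ids), sample.text, sample.logprobs))
        return results


# e.g. RunnerSketch.final_steps_w_logprobs(finished_outputs) -- no instance needed.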
BIN  tests/models/fixtures/pixtral_chat.pickle (new file, binary file not shown)
BIN  tests/models/fixtures/pixtral_chat_engine.pickle (new file, binary file not shown)
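The two new pickle files hold reference logprobs that the tests below compare against (they are labelled `h100_ref` in the assertions). A minimal sketch of how such a fixture might be written and read back, assuming it is nothing more than a pickled Python object; `dump_logprobs` is hypothetical, while `load_logprobs` mirrors the helper added in the test file:

import pickle

# Hypothetical helper for regenerating a fixture after a trusted local run.
def dump_logprobs(obj, filename: str) -> None:
    with open(filename, "wb") as f:
        pickle.dump(obj, f)

# Mirrors the test's load_logprobs(): the fixture is just a pickled object.
def load_logprobs(filename: str):
    with open(filename, "rb") as f:
        return pickle.load(f)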
@@ -2,13 +2,128 @@
 
 Run `pytest tests/models/test_mistral.py`.
 """
-import pytest
+import pickle
+import uuid
+from typing import Any, Dict, List
 
-from vllm.sampling_params import SamplingParams
+import pytest
+from mistral_common.protocol.instruct.messages import ImageURLChunk
+from mistral_common.protocol.instruct.request import ChatCompletionRequest
+from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
+from mistral_common.tokens.tokenizers.multimodal import image_from_chunk
+
+from vllm import EngineArgs, LLMEngine, SamplingParams, TokensPrompt
+from vllm.multimodal import MultiModalDataBuiltins
+
+from .utils import check_logprobs_close
 
 pytestmark = pytest.mark.vlm
 
 MODELS = ["mistralai/Pixtral-12B-2409"]
+IMG_URLS = [
+    "https://picsum.photos/id/237/400/300",
+    "https://picsum.photos/id/231/200/300",
+    "https://picsum.photos/id/27/500/500",
+    "https://picsum.photos/id/17/150/600",
+]
+PROMPT = "Describe each image in one short sentence."
+
+
+def _create_msg_format(urls: List[str]) -> List[Dict[str, Any]]:
+    return [{
+        "role":
+        "user",
+        "content": [{
+            "type": "text",
+            "text": PROMPT,
+        }] + [{
+            "type": "image_url",
+            "image_url": {
+                "url": url
+            }
+        } for url in urls],
+    }]
+
+
+def _create_engine_inputs(urls: List[str]) -> TokensPrompt:
+    msg = _create_msg_format(urls)
+
+    tokenizer = MistralTokenizer.from_model("pixtral")
+
+    request = ChatCompletionRequest(messages=msg)  # type: ignore[type-var]
+    tokenized = tokenizer.encode_chat_completion(request)
+
+    engine_inputs = TokensPrompt(prompt_token_ids=tokenized.tokens)
+
+    images = []
+    for chunk in request.messages[0].content:
+        if isinstance(chunk, ImageURLChunk):
+            images.append(image_from_chunk(chunk))
+
+    mm_data = MultiModalDataBuiltins(image=images)
+    engine_inputs["multi_modal_data"] = mm_data
+
+    return engine_inputs
+
+
+MSGS = [
+    _create_msg_format(IMG_URLS[:1]),
+    _create_msg_format(IMG_URLS[:2]),
+    _create_msg_format(IMG_URLS),
+]
+ENGINE_INPUTS = [
+    _create_engine_inputs(IMG_URLS[:1]),
+    _create_engine_inputs(IMG_URLS[:2]),
+    _create_engine_inputs(IMG_URLS),
+]
+
+SAMPLING_PARAMS = SamplingParams(max_tokens=512, temperature=0.0, logprobs=5)
+LIMIT_MM_PER_PROMPT = dict(image=4)
+
+MAX_MODEL_LEN = [8192, 65536]
+FIXTURE_LOGPROBS_CHAT = "tests/models/fixtures/pixtral_chat.pickle"
+FIXTURE_LOGPROBS_ENGINE = "tests/models/fixtures/pixtral_chat_engine.pickle"
+
+
+def load_logprobs(filename: str) -> Any:
+    with open(filename, 'rb') as f:
+        return pickle.load(f)
+
+
+@pytest.mark.skip(
+    reason=
+    "Model is too big, test passed on A100 locally but will OOM on CI machine."
+)
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("max_model_len", MAX_MODEL_LEN)
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+def test_chat(
+    vllm_runner,
+    max_model_len: int,
+    model: str,
+    dtype: str,
+) -> None:
+    EXPECTED_CHAT_LOGPROBS = load_logprobs(FIXTURE_LOGPROBS_CHAT)
+    with vllm_runner(
+            model,
+            dtype=dtype,
+            tokenizer_mode="mistral",
+            enable_chunked_prefill=False,
+            max_model_len=max_model_len,
+            limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
+    ) as vllm_model:
+        outputs = []
+        for msg in MSGS:
+            output = vllm_model.model.chat(msg,
+                                           sampling_params=SAMPLING_PARAMS)
+
+            outputs.extend(output)
+
+    logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)
+    check_logprobs_close(outputs_0_lst=logprobs,
+                         outputs_1_lst=EXPECTED_CHAT_LOGPROBS,
+                         name_0="output",
+                         name_1="h100_ref")
+
+
 @pytest.mark.skip(
@@ -17,48 +132,37 @@ MODELS = ["mistralai/Pixtral-12B-2409"]
 )
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["bfloat16"])
-@pytest.mark.parametrize("max_tokens", [64])
-@pytest.mark.parametrize("num_logprobs", [5])
-def test_models(
-    vllm_runner,
-    example_prompts,
-    model: str,
-    dtype: str,
-    max_tokens: int,
-    num_logprobs: int,
-) -> None:
-    image_urls = [
-        "https://picsum.photos/id/237/200/300",
-        "https://picsum.photos/seed/picsum/200/300"
-    ]
-    expected = [
-        "The image depicts a black dog lying on a wooden surface, looking directly at the camera with a calm expression.",  # noqa
-        "The image depicts a serene landscape with a snow-covered mountain under a pastel-colored sky during sunset."  # noqa
-    ]
-    prompt = "Describe the image in one short sentence."
-
-    sampling_params = SamplingParams(max_tokens=512, temperature=0.0)
-
-    with vllm_runner(model, dtype=dtype,
-                     tokenizer_mode="mistral") as vllm_model:
-
-        for i, image_url in enumerate(image_urls):
-            messages = [
-                {
-                    "role":
-                    "user",
-                    "content": [{
-                        "type": "text",
-                        "text": prompt
-                    }, {
-                        "type": "image_url",
-                        "image_url": {
-                            "url": image_url
-                        }
-                    }]
-                },
-            ]
-
-            outputs = vllm_model.model.chat(messages,
-                                            sampling_params=sampling_params)
-            assert outputs[0].outputs[0].text == expected[i]
+def test_model_engine(vllm_runner, model: str, dtype: str) -> None:
+    EXPECTED_ENGINE_LOGPROBS = load_logprobs(FIXTURE_LOGPROBS_ENGINE)
+    args = EngineArgs(
+        model=model,
+        tokenizer_mode="mistral",
+        enable_chunked_prefill=False,
+        limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
+        dtype=dtype,
+    )
+    engine = LLMEngine.from_engine_args(args)
+
+    engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[0], SAMPLING_PARAMS)
+    engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[1], SAMPLING_PARAMS)
+
+    outputs = []
+    count = 0
+    while True:
+        out = engine.step()
+        count += 1
+        for request_output in out:
+            if request_output.finished:
+                outputs.append(request_output)
+
+        if count == 2:
+            engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[2],
+                               SAMPLING_PARAMS)
+        if not engine.has_unfinished_requests():
+            break
+
+    logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)
+    check_logprobs_close(outputs_0_lst=logprobs,
+                         outputs_1_lst=EXPECTED_ENGINE_LOGPROBS,
+                         name_0="output",
+                         name_1="h100_ref")
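Taken together, the rewritten tests cover the two paths the hotfix touches: `LLM.chat` with one, two, and four images in a single prompt, and `LLMEngine` with several multi-image requests in flight at once. A standalone sketch of the same multi-image chat usage (model name, prompt, and URLs are taken from the test above; everything else is an illustrative assumption, not part of the diff):

from vllm import LLM, SamplingParams

# Illustrative sketch: the multi-image chat path that test_chat drives
# through the vllm_runner fixture.
llm = LLM(model="mistralai/Pixtral-12B-2409",
          tokenizer_mode="mistral",
          limit_mm_per_prompt={"image": 4})

messages = [{
    "role": "user",
    "content": [{"type": "text",
                 "text": "Describe each image in one short sentence."}] +
               [{"type": "image_url", "image_url": {"url": url}} for url in [
                   "https://picsum.photos/id/237/400/300",
                   "https://picsum.photos/id/231/200/300",
               ]],
}]

outputs = llm.chat(messages,
                   sampling_params=SamplingParams(max_tokens=64, temperature=0.0))
print(outputs[0].outputs[0].text)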
@@ -1,4 +1,3 @@
-import math
 from array import array
 from dataclasses import dataclass, fields
 from itertools import tee
@@ -15,11 +14,12 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalMask
 
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
-from vllm.inputs import INPUT_REGISTRY, InputContext
+from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.utils import merge_multimodal_embeddings
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.base import MultiModalInputs
@@ -48,23 +48,29 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
     tokenizer = cached_get_tokenizer(
         ctx.model_config.tokenizer,
         tokenizer_mode=ctx.model_config.tokenizer_mode)
-    mm_encoder = tokenizer.instruct.mm_encoder
+
+    mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
+    patch_size = mm_encoder.mm_config.image_patch_size
+    image_token_id = mm_encoder.special_ids.img
 
     mm_config = ctx.model_config.multimodal_config
-    max_num_images_per_request = mm_config.limit_per_prompt.get("image", 1)
+    num_images = mm_config.limit_per_prompt.get("image", 1)
 
-    # approximate image size
-    size = int(math.sqrt(seq_len) * mm_encoder.mm_config.image_patch_size)
-
+    # dummy size
+    size = 256
     image = Image.new("RGB", (size, size), color=0)
-    img_chunk = ImageChunk(image=image)
 
-    tokens = mm_encoder(img_chunk).tokens
-    token_ids = max_num_images_per_request * array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                                                   tokens)
+    image_feature_size = (size**2) // (patch_size**2)
+
+    num_image_tokens = image_feature_size * num_images
+
+    token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
+                      [image_token_id]) * num_image_tokens
+    token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
+                       [0]) * (seq_len - num_image_tokens)
 
     seq_data = SequenceData(token_ids)
-    mm_data = {"image": max_num_images_per_request * [image]}
+    mm_data = {"image": num_images * [image]}
     return seq_data, mm_data
 
 
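The dummy-data change sizes the placeholder sequence from the patch grid instead of running the multimodal encoder on a seq-len-derived image: each dummy image contributes (size^2) / (patch_size^2) image tokens, and the rest of the sequence is padded with zeros. A small self-contained check of that arithmetic (the patch size of 16 is an assumption here, it is not stated in the diff):

# Sketch of the dummy-data sizing above; patch_size=16 is assumed for Pixtral.
def dummy_image_token_count(size: int = 256, patch_size: int = 16,
                            num_images: int = 1) -> int:
    image_feature_size = (size ** 2) // (patch_size ** 2)
    return image_feature_size * num_images


assert dummy_image_token_count() == 256
assert dummy_image_token_count(num_images=4) == 1024  # with limit_mm_per_prompt image=4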
@@ -99,32 +105,31 @@ def input_mapper_for_pixtral(ctx: InputContext,
     return MultiModalInputs({"images": images})
 
 
-def merge_multimodal_embeddings(input_ids: torch.Tensor,
-                                inputs_embeds: torch.Tensor,
-                                image_features: Optional[List[torch.Tensor]],
-                                image_id: int) -> torch.Tensor:
-    text_locations = input_ids != image_id
-    image_locations = input_ids == image_id
-
-    seq_len = input_ids.shape[0]
-
-    N_txt = text_locations.sum().item()
-    _, D_txt = inputs_embeds.shape
-    N_img, D_img = image_features.shape
-
-    assert (D_txt == D_img), (f"Text features dim {D_txt} should be equal "
-                              "to image features dim {D_img}")
-    assert (seq_len == N_txt +
-            N_img), (f"seq_len {seq_len} should be equal to N_txt + N_img "
-                     f"{(N_txt, N_img, image_locations.sum().item())}")
-
-    inputs_embeds[image_locations, :] = image_features
-    return inputs_embeds
+def input_processor_for_pixtral(ctx: InputContext, llm_inputs: LLMInputs):
+    multi_modal_data = llm_inputs.get("multi_modal_data")
+    if multi_modal_data is not None and "image" in multi_modal_data:
+        tokenizer = cached_get_tokenizer(
+            ctx.model_config.tokenizer,
+            tokenizer_mode=ctx.model_config.tokenizer_mode)
+
+        mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
+        image_token_id = mm_encoder.special_ids.img
+
+        if image_token_id not in llm_inputs['prompt_token_ids']:
+            raise ValueError(
+                (f"You've passed {llm_inputs=} without {image_token_id=}"
+                 " Make sure to process your input via mistral_common's"
+                 " tokenizer or pass a chat completion request. For more"
+                 " For more info, see: "
+                 "https://github.com/vllm-project/vllm/issues/8411."))
+
+    return llm_inputs
 
 
 @MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_pixtral)
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_pixtral_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_pixtral)
+@INPUT_REGISTRY.register_input_processor(input_processor_for_pixtral)
 class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal):
 
     def __init__(self,
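The new `input_processor_for_pixtral` does not insert image tokens itself; it only checks that an image-bearing prompt already contains the Pixtral image token (i.e. that it was tokenized with mistral_common, as `_create_engine_inputs` above does) and raises otherwise, pointing at issue #8411. A toy sketch of that validation pattern, with a made-up token id:

# Illustrative only: mimic the "image token must already be in the prompt" check.
IMAGE_TOKEN_ID = 10  # hypothetical; the real id comes from mm_encoder.special_ids.img

def validate_prompt(prompt_token_ids: list, has_image: bool) -> list:
    if has_image and IMAGE_TOKEN_ID not in prompt_token_ids:
        raise ValueError("multi-modal prompt is missing the image token; "
                         "tokenize the request with mistral_common first")
    return prompt_token_ids

validate_prompt([1, 10, 10, 42], has_image=True)   # passes
# validate_prompt([1, 42], has_image=True)         # would raise ValueError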
@@ -201,11 +206,21 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal):
             return None
 
         if isinstance(images, torch.Tensor):
-            # always take last images
-            images = [images[-1][i] for i in range(images.size(1))]
+            # if passed as batch take all images
+            N, B, C, W, H = images.shape
+            images = images.reshape(N * B, C, W, H)
+            images = [images[i] for i in range(images.size(0))]
         elif isinstance(images, list):
-            # always take last images
-            images = [images[-1][i] for i in range(len(images[0]))]
+            # if passed as list flatten lists of tensors
+            flatten_images = []
+            for imgs_per_req in images:
+                imgs_per_req = [
+                    imgs_per_req[i] for i in range(imgs_per_req.size(0))
+                ] if isinstance(imgs_per_req, torch.Tensor) else imgs_per_req
+
+                flatten_images.extend(imgs_per_req)
+
+            images = flatten_images
 
         return images
 
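This last hunk is the core of the multiple-images fix: instead of always keeping only the last request's images, the model now flattens every image from every request before running the vision encoder. A standalone sketch of the same reshaping on dummy tensors (shapes are illustrative):

import torch

# Batched path: a stacked tensor of N requests with B images each.
batch = torch.zeros(2, 3, 3, 64, 64)           # (N, B, C, W, H)
N, B, C, W, H = batch.shape
flat = batch.reshape(N * B, C, W, H)
images = [flat[i] for i in range(flat.size(0))]
assert len(images) == 6                         # all six images kept, not just the last request's

# List path: a variable number of images per request.
per_request = [torch.zeros(1, 3, 64, 64), torch.zeros(4, 3, 64, 64)]
flatten_images = []
for imgs in per_request:
    flatten_images.extend(imgs[i] for i in range(imgs.size(0)))
assert len(flatten_images) == 5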