[V1] Initial support of multimodal models for V1 re-arch (#10699)

Signed-off-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
Roger Wang 2024-12-08 04:50:51 -08:00 committed by GitHub
parent fd57d2b534
commit a11f326528
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 283 additions and 68 deletions

View File

@ -1050,9 +1050,12 @@ class EngineArgs:
# long context (> 32K) models. This is to avoid OOM errors in the # long context (> 32K) models. This is to avoid OOM errors in the
# initial memory profiling phase. # initial memory profiling phase.
# Chunked prefill is currently disabled for multimodal models by # For multimodal models, chunked prefill is disabled by default in
# default. # V0, but enabled by design in V1
if use_long_context and not model_config.is_multimodal_model: if model_config.is_multimodal_model:
self.enable_chunked_prefill = bool(envs.VLLM_USE_V1)
elif use_long_context:
is_gpu = device_config.device_type == "cuda" is_gpu = device_config.device_type == "cuda"
use_sliding_window = (model_config.get_sliding_window() use_sliding_window = (model_config.get_sliding_window()
is not None) is not None)
@ -1241,12 +1244,9 @@ class EngineArgs:
Override the EngineConfig's configs based on the usage context for V1. Override the EngineConfig's configs based on the usage context for V1.
""" """
assert envs.VLLM_USE_V1, "V1 is not enabled" assert envs.VLLM_USE_V1, "V1 is not enabled"
# TODO (ywang96): Enable APC by default when VLM supports it.
if engine_config.model_config.is_multimodal_model: if engine_config.model_config.is_multimodal_model:
logger.warning( # TODO (ywang96): Enable APC by default when VLM supports it.
"Prefix caching is currently not supported for multimodal " assert not engine_config.cache_config.enable_prefix_caching
"models and has been disabled.")
engine_config.cache_config.enable_prefix_caching = False
@dataclass @dataclass

View File

@ -36,6 +36,11 @@ class SupportsMultiModal(Protocol):
""" """
Returns multimodal embeddings generated from multimodal kwargs Returns multimodal embeddings generated from multimodal kwargs
to be merged with text embeddings. to be merged with text embeddings.
The output embeddings must be one of the following formats:
- A list or tuple of 2D tensors, where each tensor corresponds to
each input image.
- A single 3D tensor, with the batch dimension grouping the 2D tensors.
""" """
... ...

View File

@ -26,7 +26,7 @@ from vllm.model_executor.models.intern_vit import (InternVisionModel,
InternVisionPatchModel) InternVisionPatchModel)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import NestedTensors from vllm.multimodal.inputs import NestedTensors, PlaceholderRange
from vllm.multimodal.utils import cached_get_tokenizer from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of from vllm.utils import is_list_of
@ -52,12 +52,18 @@ class InternVLImagePixelInputs(TypedDict):
Shape: Shape:
`(batch_size * num_images * (1 + num_patches), num_channels, height, width)` `(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
""" """
patches_per_image: List[int]
"""
List of number of total patches for each image in the batch.
"""
class InternVLImageEmbeddingInputs(TypedDict): class InternVLImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"] type: Literal["image_embeds"]
data: torch.Tensor data: NestedTensors
"""Shape: `(batch_size * num_images, image_feature_size, hidden_size)` """
A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
or a list of tensors of shape `(total_image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone. `hidden_size` must match the hidden size of language model backbone.
""" """
@ -349,10 +355,32 @@ class InternVLInputPipeline:
new_prompt = self._expand_image_prompt(prompt, image_feature_sizes, new_prompt = self._expand_image_prompt(prompt, image_feature_sizes,
num_patches) num_patches)
new_prompt_token_ids = tokenizer.encode(new_prompt) new_prompt_token_ids = tokenizer.encode(new_prompt)
img_context_token_id = tokenizer.encode(self.img_context_token,
add_special_tokens=False)
assert len(img_context_token_id) == 1, \
(f"Invalid image token '{self.img_context_token}': A valid image "
f"token encodes to a single token ID, got {img_context_token_id}.")
img_context_token_id = img_context_token_id[0]
return token_inputs(prompt=prompt, # Get precise tracking of placeholder positions
prompt_token_ids=new_prompt_token_ids, token_idx = image_idx = 0
multi_modal_data=multi_modal_data) placeholder_ranges = []
while token_idx < len(new_prompt_token_ids):
if new_prompt_token_ids[token_idx] == img_context_token_id:
curr_image_featue_size = image_feature_sizes[image_idx]
placeholder_ranges.append(
PlaceholderRange(offset=token_idx,
length=curr_image_featue_size))
image_idx += 1
token_idx += curr_image_featue_size
else:
token_idx += 1
return token_inputs(
prompt=prompt,
prompt_token_ids=new_prompt_token_ids,
multi_modal_data=multi_modal_data,
multi_modal_placeholders={"image": placeholder_ranges})
def input_mapper( def input_mapper(
self, self,
@ -614,26 +642,46 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
if not isinstance(pixel_values, (torch.Tensor, list)): if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. " raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}") f"Got type: {type(pixel_values)}")
patches_per_image = []
for request_pixel_values in pixel_values:
for image_pixel_values in request_pixel_values:
patches_per_image.append(image_pixel_values.shape[0])
# We need to flatten (B, N, P) to (B*N*P), # We need to flatten (B, N, P) to (B*N*P),
# so we call flatten_bn twice. # so we call flatten_bn twice.
return InternVLImagePixelInputs( return InternVLImagePixelInputs(
type="pixel_values", type="pixel_values",
data=self._validate_pixel_values( data=self._validate_pixel_values(
flatten_bn(flatten_bn(pixel_values), concat=True)), flatten_bn(flatten_bn(pixel_values), concat=True)),
) patches_per_image=patches_per_image)
raise AssertionError("This line should be unreachable.") raise AssertionError("This line should be unreachable.")
def _process_image_input( def _process_image_input(
self, self,
image_input: InternVLImageInputs, image_input: InternVLImageInputs,
) -> torch.Tensor: ) -> Tuple[torch.Tensor]:
if image_input["type"] == "image_embeds": if image_input["type"] == "image_embeds":
return image_input["data"] return image_input["data"]
assert self.vision_model is not None assert self.vision_model is not None
image_embeds = self.extract_feature(image_input["data"]) image_embeds = self.extract_feature(image_input["data"])
patches_per_image = image_input["patches_per_image"]
if len(patches_per_image) == 1:
image_embeds = image_embeds.unsqueeze(0)
return image_embeds
# NOTE: Image embeddings are split into separate tensors for each image
# by the size of each embedding.
feature_size = image_embeds.shape[1]
image_embeds = image_embeds.view(-1,
self.config.text_config.hidden_size)
image_feature_sizes = [
num_patches * feature_size for num_patches in patches_per_image
]
image_embeds = image_embeds.split(image_feature_sizes)
return image_embeds return image_embeds
def _set_visual_token_mask(self, input_ids: torch.Tensor) -> torch.Tensor: def _set_visual_token_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
@ -696,13 +744,11 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
"inputs_embeds": inputs_embeds, "inputs_embeds": inputs_embeds,
} }
# Only required if the model is mono-architecture
if self.visual_token_mask is not None: if self.visual_token_mask is not None:
# overwrite visual_token_mask and img_context_token_id back to None,
# so that this doesn't need to depend on encoder output
forward_kwargs.update( forward_kwargs.update(
{"visual_token_mask": self.visual_token_mask}) {"visual_token_mask": self.visual_token_mask})
self.visual_token_mask = None self.visual_token_mask = None
self.img_context_token_id = None
hidden_states = self.language_model.model(**forward_kwargs) hidden_states = self.language_model.model(**forward_kwargs)
return hidden_states return hidden_states

View File

@ -37,7 +37,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import NestedTensors from vllm.multimodal.inputs import NestedTensors, PlaceholderRange
from vllm.multimodal.utils import cached_get_tokenizer from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData) SequenceData)
@ -46,12 +46,16 @@ from vllm.transformers_utils.processor import get_processor
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter, from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix) maybe_prefix, merge_multimodal_embeddings)
# TODO: hard-coded for now. Consider making it configurable. # TODO: hard-coded for now. Consider making it configurable.
VIT_LAYERS = [-2, -9] VIT_LAYERS = [-2, -9]
NUM_PREFIX_TOKENS = 1 NUM_PREFIX_TOKENS = 1
ADDITIONAL_VOCAB_SIZE = 128 ADDITIONAL_VOCAB_SIZE = 128
DEFAULT_IMAGE_PATCH_TOKEN_ID = 152066
DEFAULT_IM_START_TOKEN_ID = 152067
DEFAULT_IM_END_TOKEN_ID = 152064
DEFAULT_IM_COL_TOKEN_ID = 152065
class MolmoImageInputs(TypedDict): class MolmoImageInputs(TypedDict):
@ -75,6 +79,11 @@ class MolmoImageInputs(TypedDict):
`(batch_size, num_crops, num_patch)` `(batch_size, num_crops, num_patch)`
""" """
image_start_end: Tuple[int, int]
"""Starting and ending index of placeholder
tokens
"""
@dataclass @dataclass
class VisionBackboneConfig: class VisionBackboneConfig:
@ -918,6 +927,8 @@ def image_input_mapper_for_molmo(
ctx: InputContext, ctx: InputContext,
data: object, data: object,
): ):
if isinstance(data, list):
data = data[0]
return MultiModalKwargs(data) return MultiModalKwargs(data)
@ -967,7 +978,22 @@ def dummy_data_for_molmo(ctx: InputContext, seq_len: int,
if "image_masks" in out: if "image_masks" in out:
dummy_imgdata["image_masks"] = out["image_masks"] dummy_imgdata["image_masks"] = out["image_masks"]
dummy_imgdata["seq_len"] = torch.tensor(seq_len, dtype=torch.long) dummy_imgdata["seq_len"] = torch.tensor(seq_len, dtype=torch.long)
return DummyData(dummy_seqdata, {"image": dummy_imgdata}) size = 0
offset = -1
for i in range(len(token_ids)):
if token_ids[i] in (DEFAULT_IMAGE_PATCH_TOKEN_ID,
DEFAULT_IM_START_TOKEN_ID, DEFAULT_IM_END_TOKEN_ID,
DEFAULT_IM_COL_TOKEN_ID):
if offset < 0:
offset = i
size += 1
dummy_imgdata["image_start_end"] = (offset, offset + size)
return DummyData(seq_data=dummy_seqdata,
multi_modal_data={"image": dummy_imgdata},
multi_modal_placeholders={
"image":
[PlaceholderRange(offset=offset, length=size)]
})
def pad_images( def pad_images(
@ -1055,19 +1081,34 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
if image_masks is not None: if image_masks is not None:
image_data["image_masks"] = image_masks image_data["image_masks"] = image_masks
image_data["seq_len"] = torch.tensor(len(out["input_ids"]), new_prompt_token_ids = out["input_ids"].tolist()
image_data["seq_len"] = torch.tensor(len(new_prompt_token_ids),
dtype=torch.long) dtype=torch.long)
multi_modal_data = dict(image=image_data) multi_modal_data = dict(image=image_data)
size = 0
offset = -1
for i in range(len(new_prompt_token_ids)):
if new_prompt_token_ids[i] in (DEFAULT_IMAGE_PATCH_TOKEN_ID,
DEFAULT_IM_START_TOKEN_ID,
DEFAULT_IM_END_TOKEN_ID,
DEFAULT_IM_COL_TOKEN_ID):
if offset < 0:
offset = i
size += 1
image_data["image_start_end"] = (offset, offset + size)
prompt = inputs.get("prompt") prompt = inputs.get("prompt")
if prompt is None: if prompt is None:
prompt = tokenizer.decode(out["input_ids"]) prompt = tokenizer.decode(new_prompt_token_ids)
return token_inputs( return token_inputs(
prompt_token_ids=out["input_ids"], prompt_token_ids=new_prompt_token_ids,
prompt=prompt, prompt=prompt,
multi_modal_data=multi_modal_data, multi_modal_data=multi_modal_data,
multi_modal_placeholders={
"image": [PlaceholderRange(offset=offset, length=size)]
},
) )
@ -1113,6 +1154,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
) -> Optional[MolmoImageInputs]: ) -> Optional[MolmoImageInputs]:
images = kwargs.pop("images", None) images = kwargs.pop("images", None)
image_masks = kwargs.pop("image_masks", None) image_masks = kwargs.pop("image_masks", None)
image_start_end = kwargs.pop("image_start_end", None)
if images is None: if images is None:
return None return None
@ -1130,6 +1172,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
image_input_idx=image_input_idx, image_input_idx=image_input_idx,
seq_len=seq_len, seq_len=seq_len,
image_masks=image_masks, image_masks=image_masks,
image_start_end=image_start_end,
) )
def _process_image_input( def _process_image_input(
@ -1178,9 +1221,16 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
# Note: In this original implementation from AI2, the final # Note: In this original implementation from AI2, the final
# vision_embeddings will be always be the same length # vision_embeddings will be always be the same length
# of input embedddings, which is not very efficient. # of input embeddings.
# TODO(ywang96): see if this can be optimized.
vision_embeddings = torch.einsum('nd,nm->md', image_features, mat) vision_embeddings = torch.einsum('nd,nm->md', image_features, mat)
# Split by the sizes of the input sequences. For each full embedding,
# extract the actual vision embeddings to be merged.
vision_embeddings = list(vision_embeddings.split(seq_len.tolist()))
for i in range(len(vision_embeddings)):
start, end = image_input['image_start_end'][i]
vision_embeddings[i] = vision_embeddings[i][start:end]
return vision_embeddings return vision_embeddings
def get_input_embeddings( def get_input_embeddings(
@ -1190,7 +1240,11 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings(input_ids) inputs_embeds = self.model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None:
inputs_embeds = inputs_embeds + multimodal_embeddings inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings, [
DEFAULT_IMAGE_PATCH_TOKEN_ID, DEFAULT_IM_START_TOKEN_ID,
DEFAULT_IM_END_TOKEN_ID, DEFAULT_IM_COL_TOKEN_ID
])
return inputs_embeds return inputs_embeds
def forward( def forward(

View File

@ -48,6 +48,9 @@ try:
except ImportError: except ImportError:
USE_XFORMERS_OPS = False USE_XFORMERS_OPS = False
PIXTRAL_IMAGE_BREAK_ID = 12
PIXTRAL_IMAGE_END_ID = 13
def get_max_pixtral_image_tokens(ctx: InputContext): def get_max_pixtral_image_tokens(ctx: InputContext):
tokenizer = cached_get_tokenizer( tokenizer = cached_get_tokenizer(
@ -68,7 +71,6 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
tokenizer_mode=ctx.model_config.tokenizer_mode) tokenizer_mode=ctx.model_config.tokenizer_mode)
mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
patch_size = mm_encoder.mm_config.image_patch_size
image_token_id = mm_encoder.special_ids.img image_token_id = mm_encoder.special_ids.img
mm_config = ctx.model_config.multimodal_config mm_config = ctx.model_config.multimodal_config
@ -78,8 +80,8 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
size = 256 size = 256
image = Image.new("RGB", (size, size), color=0) image = Image.new("RGB", (size, size), color=0)
image_feature_size = (size**2) // (patch_size**2) encoding = tokenizer.instruct.mm_encoder(ImageChunk(image=image))
image_feature_size = len(encoding.tokens)
num_image_tokens = image_feature_size * num_images num_image_tokens = image_feature_size * num_images
seq_data = SequenceData.from_prompt_token_counts( seq_data = SequenceData.from_prompt_token_counts(
(image_token_id, num_image_tokens), (image_token_id, num_image_tokens),
@ -101,14 +103,13 @@ def input_mapper_for_pixtral(ctx: InputContext,
Args: Args:
ctx: Context of the loaded model. ctx: Context of the loaded model.
data: data potentially containing image/image embeddings to be mapped data: data potentially containing PIL images to be processed
to pixel_values in .forward() for a visual QWenLMHeadModel model. and mapped to `images`.
Returns: Returns:
MultiModalKwargs containing the stacked normalized images tensor or MultiModalKwargs containing the stacked normalized images tensor or
image embeddings. image embeddings.
""" """
# Early exit if we have provided an image to a language only Qwen model
model_config = ctx.model_config model_config = ctx.model_config
tokenizer = cached_get_tokenizer( tokenizer = cached_get_tokenizer(
model_config.tokenizer, tokenizer_mode=model_config.tokenizer_mode) model_config.tokenizer, tokenizer_mode=model_config.tokenizer_mode)
@ -116,35 +117,67 @@ def input_mapper_for_pixtral(ctx: InputContext,
data_list = data if isinstance(data, list) else [data] data_list = data if isinstance(data, list) else [data]
images = [] images = []
image_tokens_list = []
for image_data in data_list: for image_data in data_list:
image = ImageChunk(image=image_data) image = ImageChunk(image=image_data)
encoding = tokenizer.instruct.mm_encoder(image) encoding = tokenizer.instruct.mm_encoder(image)
image = torch.from_numpy(encoding.image).to(device="cuda", image = torch.from_numpy(encoding.image).to(device="cuda",
dtype=torch.float16) dtype=torch.float16)
images.append(image) images.append(image)
image_tokens_list.append(encoding.tokens)
return MultiModalKwargs({"images": images}) image_tokens = torch.tensor([
token_id for image_tokens in image_tokens_list
for token_id in image_tokens
])
return MultiModalKwargs({"images": images, "image_tokens": image_tokens})
def input_processor_for_pixtral(ctx: InputContext, inputs: DecoderOnlyInputs): def input_processor_for_pixtral(ctx: InputContext, inputs: DecoderOnlyInputs):
multi_modal_data = inputs.get("multi_modal_data") multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is not None and "image" in multi_modal_data: if multi_modal_data is None or "image" not in multi_modal_data:
tokenizer = cached_get_tokenizer( return inputs
ctx.model_config.tokenizer,
tokenizer_mode=ctx.model_config.tokenizer_mode)
mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder prompt_token_ids = inputs.get("prompt_token_ids")
image_token_id = mm_encoder.special_ids.img prompt = inputs.get("prompt")
tokenizer = cached_get_tokenizer(
ctx.model_config.tokenizer,
tokenizer_mode=ctx.model_config.tokenizer_mode)
if image_token_id not in inputs['prompt_token_ids']: mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
raise ValueError( image_token_id = mm_encoder.special_ids.img
f"You've passed {inputs=} without {image_token_id=}" image_break_id = mm_encoder.special_ids.img_break
" Make sure to process your input via mistral_common's" image_end_id = mm_encoder.special_ids.img_end
" tokenizer or pass a chat completion request. For more"
" For more info, see: "
"https://github.com/vllm-project/vllm/issues/8411.")
return inputs if image_token_id not in inputs['prompt_token_ids']:
raise ValueError(
f"You've passed {inputs=} without {image_token_id=}"
" Make sure to process your input via mistral_common's"
" tokenizer or pass a chat completion request. For more"
" For more info, see: "
"https://github.com/vllm-project/vllm/issues/8411.")
# Get precise tracking of placeholder positions
placeholder_ranges = []
curr_offset = -1
curr_length = 0
for i in range(len(prompt_token_ids)):
if prompt_token_ids[i] in (image_token_id, image_break_id):
if curr_offset < 0:
curr_offset = i
curr_length += 1
elif prompt_token_ids[i] == image_end_id:
curr_length += 1
placeholder_ranges.append(
PlaceholderRange(offset=curr_offset, length=curr_length))
curr_offset = -1
curr_length = 0
else:
pass
return token_inputs(prompt=prompt,
prompt_token_ids=prompt_token_ids,
multi_modal_data=multi_modal_data,
multi_modal_placeholders={"image": placeholder_ranges})
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_pixtral) @MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_pixtral)
@ -192,11 +225,29 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
return get_sampler() return get_sampler()
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input, image_tokens = self._parse_and_validate_image_input(
**kwargs)
if image_input is None: if image_input is None:
return None return None
vision_embeddings = self._process_image_input(image_input) vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
# NOTE: We patch the outputs of the vision encoder with embeddings
# from `[IMG_BREAK]` and `[IMG_END]` tokens.
image_embeds = self.language_model.get_input_embeddings(image_tokens)
image_token_mask = image_tokens == self.vision_args.image_token_id
image_embeds[image_token_mask] = vision_embeddings
# NOTE: Image embeddings are split into separate tensors for each image
# by the indices of `[IMG_END]` token.
split_indices = torch.where(
image_tokens == PIXTRAL_IMAGE_END_ID)[0] + 1
if len(split_indices) <= 1:
# Do not split, return as tensor of shape [1, fs, hs]
return image_embeds.unsqueeze(0)
image_embeds = image_embeds.tensor_split(split_indices.cpu())
return image_embeds
def get_input_embeddings( def get_input_embeddings(
self, self,
@ -206,8 +257,10 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings, input_ids, inputs_embeds, multimodal_embeddings, [
self.vision_args.image_token_id) self.vision_args.image_token_id, PIXTRAL_IMAGE_END_ID,
PIXTRAL_IMAGE_BREAK_ID
])
return inputs_embeds return inputs_embeds
def forward( def forward(
@ -245,10 +298,11 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, self,
images: Optional[Union[List[List[torch.Tensor]], List[torch.Tensor], images: Optional[Union[List[List[torch.Tensor]], List[torch.Tensor],
torch.Tensor]] = None torch.Tensor]] = None,
image_tokens: Optional[torch.Tensor] = None,
) -> Optional[List[torch.Tensor]]: ) -> Optional[List[torch.Tensor]]:
if images is None: if images is None:
return None return None, None
if isinstance(images, torch.Tensor): if isinstance(images, torch.Tensor):
# if passed as batch take all images # if passed as batch take all images
@ -267,7 +321,16 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
images = flatten_images images = flatten_images
return images if isinstance(image_tokens, torch.Tensor):
# image_tokens are batched
image_tokens = image_tokens.flatten()
elif isinstance(image_tokens, list):
# image_tokens are of different lengths thus passed as a list
image_tokens = torch.cat(image_tokens)
assert image_tokens.dim() == 1
return images, image_tokens
def _process_image_input(self, def _process_image_input(self,
image_input: List[torch.Tensor]) -> torch.Tensor: image_input: List[torch.Tensor]) -> torch.Tensor:

View File

@ -409,16 +409,42 @@ def merge_multimodal_embeddings(
input_ids: torch.Tensor, input_ids: torch.Tensor,
inputs_embeds: torch.Tensor, inputs_embeds: torch.Tensor,
multimodal_embeddings: NestedTensors, multimodal_embeddings: NestedTensors,
placeholder_token_id: int, placeholder_token_id: Union[int, List[int]],
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
positions in ``inputs_embeds`` corresponding to placeholder tokens in positions in ``inputs_embeds`` corresponding to placeholder tokens in
``input_ids``. ``input_ids``.
``placeholder_token_id`` can be a list of token ids (e.g, token ids
of img_start, img_break, and img_end tokens) when needed: This means
the order of these tokens in the ``input_ids`` MUST MATCH the order of
their embeddings in ``multimodal_embeddings`` since we need to
slice-merge instead of individually scattering.
For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where
- T is text token
- S is image start token
- I is image embedding token
- B is image break token
- E is image end token.
Then the image embeddings (that correspond to I's) from vision encoder
must be padded with embeddings of S, B, and E in the same order of
input_ids for a correct embedding merge.
Note: Note:
This updates ``inputs_embeds`` in place. This updates ``inputs_embeds`` in place.
""" """
if isinstance(placeholder_token_id, list):
placeholder_token_id = torch.tensor(placeholder_token_id,
device=input_ids.device)
return _merge_multimodal_embeddings(
inputs_embeds,
torch.isin(input_ids, placeholder_token_id),
multimodal_embeddings,
)
return _merge_multimodal_embeddings( return _merge_multimodal_embeddings(
inputs_embeds, inputs_embeds,
(input_ids == placeholder_token_id), (input_ids == placeholder_token_id),

View File

@ -96,7 +96,8 @@ class PlaceholderRange(TypedDict):
"""The length of the placeholder.""" """The length of the placeholder."""
NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor] NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor,
Tuple[torch.Tensor, ...]]
""" """
Uses a list instead of a tensor if the dimensions of each element do not match. Uses a list instead of a tensor if the dimensions of each element do not match.
""" """

View File

@ -535,11 +535,13 @@ def repeat_and_pad_placeholder_tokens(
return new_prompt, new_token_ids, placeholder_ranges return new_prompt, new_token_ids, placeholder_ranges
def consecutive_placeholder_ranges(num_items: int, def consecutive_placeholder_ranges(
item_size: int) -> List[PlaceholderRange]: num_items: int,
item_size: int,
initial_offset: int = 0) -> List[PlaceholderRange]:
"""Returns a list of consecutive PlaceholderRanges of a fixed size""" """Returns a list of consecutive PlaceholderRanges of a fixed size"""
return [ return [
PlaceholderRange(offset=i * item_size, length=item_size) PlaceholderRange(offset=initial_offset + i * item_size,
for i in range(num_items) length=item_size) for i in range(num_items)
] ]

View File

@ -73,12 +73,12 @@ class Scheduler:
# has the Transformer architecture (e.g., ViT). # has the Transformer architecture (e.g., ViT).
# FIXME(woosuk): Below are placeholder values. We need to calculate the # FIXME(woosuk): Below are placeholder values. We need to calculate the
# actual values from the configurations. # actual values from the configurations.
self.max_num_encoder_input_tokens = 2048 self.max_num_encoder_input_tokens = 16384
# NOTE(woosuk): For the models without encoder (e.g., text-only models), # NOTE(woosuk): For the models without encoder (e.g., text-only models),
# the encoder cache will not be initialized and used, regardless of # the encoder cache will not be initialized and used, regardless of
# the cache size. This is because the memory space for the encoder cache # the cache size. This is because the memory space for the encoder cache
# is preallocated in the profiling run. # is preallocated in the profiling run.
self.encoder_cache_manager = EncoderCacheManager(cache_size=2048) self.encoder_cache_manager = EncoderCacheManager(cache_size=16384)
def schedule(self) -> "SchedulerOutput": def schedule(self) -> "SchedulerOutput":
# NOTE(woosuk) on the scheduling algorithm: # NOTE(woosuk) on the scheduling algorithm:

View File

@ -1,5 +1,7 @@
from typing import Dict, List, Mapping, Optional, Type, Union from typing import Dict, List, Mapping, Optional, Type, Union
from typing_extensions import TypeVar
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics_types import StatLoggerBase from vllm.engine.metrics_types import StatLoggerBase
@ -12,7 +14,8 @@ from vllm.outputs import RequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.transformers_utils.tokenizer_group import (
BaseTokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.detokenizer import Detokenizer from vllm.v1.engine.detokenizer import Detokenizer
@ -21,6 +24,8 @@ from vllm.v1.executor.gpu_executor import GPUExecutor
logger = init_logger(__name__) logger = init_logger(__name__)
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
class LLMEngine: class LLMEngine:
"""Legacy LLMEngine for backwards compatibility.""" """Legacy LLMEngine for backwards compatibility."""
@ -169,5 +174,18 @@ class LLMEngine:
def stop_profile(self): def stop_profile(self):
self.engine_core.profile(False) self.engine_core.profile(False)
def get_tokenizer_group(self, group_type): def get_tokenizer_group(
pass self,
group_type: Type[_G] = BaseTokenizerGroup,
) -> _G:
tokenizer_group = self.tokenizer
if tokenizer_group is None:
raise ValueError("Unable to get tokenizer because "
"skip_tokenizer_init is True")
if not isinstance(tokenizer_group, group_type):
raise TypeError("Invalid type of tokenizer group. "
f"Expected type: {group_type}, but "
f"found type: {type(tokenizer_group)}")
return tokenizer_group

View File

@ -33,7 +33,7 @@ class MMInputMapper:
num_images = len(image_inputs) num_images = len(image_inputs)
for i in range(num_images): for i in range(num_images):
mm_input = self.multi_modal_input_mapper( mm_input = self.multi_modal_input_mapper(
{"image": [image_inputs[i]]}, {"image": image_inputs[i]},
mm_processor_kwargs=mm_processor_kwargs, mm_processor_kwargs=mm_processor_kwargs,
) )
mm_inputs.append(mm_input) mm_inputs.append(mm_input)