[V1] Add V1 support of Qwen2-VL (#12128)

Signed-off-by: Roger Wang <ywang@roblox.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: imkero <kerorek@outlook.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Roger Wang 2025-01-19 03:52:13 -08:00 committed by GitHub
parent edaae198e7
commit 81763c58a0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 291 additions and 84 deletions

View File

@ -754,7 +754,7 @@ See [this page](#generative-models) for more information on how to use generativ
- `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc.
- ✅︎
- ✅︎
-
- ✅︎
* - `UltravoxModel`
- Ultravox
- T + A<sup>E+</sup>

View File

@ -105,7 +105,7 @@ def batch_make_image_embeddings(
pixel_values = preprocess_result["pixel_values"]
image_grid_thw = preprocess_result["image_grid_thw"]
# pixel values to embeddinds & grid_thws
# pixel values to embeddings & grid_thws
with torch.no_grad():
visual = llm.llm_engine.model_executor.driver_worker. \
model_runner.model.visual
@ -124,11 +124,10 @@ def batch_make_image_embeddings(
for image_batch in image_batches_:
cur_batch_image_count = len(image_batch)
merge_size = image_processor.merge_size
cur_batch_embed_len = sum([
grid_thw.prod() // merge_size // merge_size
cur_batch_embed_len = sum(
grid_thw.prod(-1) // merge_size // merge_size
for grid_thw in image_grid_thw[image_counter:image_counter +
cur_batch_image_count]
])
cur_batch_image_count])
result.append({
"image_embeds":
@ -187,7 +186,7 @@ def batch_make_video_embeddings(
pixel_values = preprocess_result["pixel_values_videos"]
video_grid_thw = preprocess_result["video_grid_thw"]
# pixel values to embeddinds & grid_thws
# pixel values to embeddings & grid_thws
with torch.no_grad():
visual = llm.llm_engine.model_executor.driver_worker.\
model_runner.model.visual
@ -206,11 +205,10 @@ def batch_make_video_embeddings(
for video_batch in video_batches_:
cur_batch_video_count = len(video_batch)
merge_size = image_processor.merge_size
cur_batch_embed_len = sum([
grid_thw.prod() // merge_size // merge_size
cur_batch_embed_len = sum(
grid_thw.prod(-1) // merge_size // merge_size
for grid_thw in video_grid_thw[video_counter:video_counter +
cur_batch_video_count]
])
cur_batch_video_count])
result.append({
"video_embeds":

View File

@ -76,8 +76,8 @@ def support_torch_compile(
During runtime, when we actually mark dimensions of tensors,
it depends on the value of arguments:
- if it is a single integer, the corresponding dimension of the argument
will be marked as dynamic.
- if it is a single integer (can be negative), the corresponding dimension
of the argument will be marked as dynamic.
- if it is `None`, ignored.
- if it is `IntermediateTensors`, all the tensors in the intermediate
tensors will be marked as dynamic.
@ -177,10 +177,20 @@ def _support_torch_compile(
for k, dims in dynamic_arg_dims.items():
arg = bound_args.arguments.get(k)
if arg is not None:
dims = [dims] if isinstance(dims, int) else dims
if isinstance(arg, torch.Tensor):
# In case dims is specified with negative indexing
dims = [
arg.ndim + dim if dim < 0 else dim for dim in dims
]
torch._dynamo.mark_dynamic(arg, dims)
elif isinstance(arg, IntermediateTensors):
for tensor in arg.tensors.values():
# In case dims is specified with negative indexing
dims = [
tensor.ndim + dim if dim < 0 else dim
for dim in dims
]
torch._dynamo.mark_dynamic(tensor, dims)
else:
raise ValueError(

View File

@ -841,6 +841,37 @@ class MRotaryEmbedding(RotaryEmbedding):
) -> Tuple[List[List[int]], int]:
"""Get mrope input positions and delta value."""
llm_positions, mrope_position_delta = \
MRotaryEmbedding.get_input_positions_tensor(
input_tokens,
image_grid_thw,
video_grid_thw,
image_token_id,
video_token_id,
vision_start_token_id,
vision_end_token_id,
spatial_merge_size,
context_len,
seq_len,
)
return llm_positions.tolist(), mrope_position_delta
@staticmethod
def get_input_positions_tensor(
input_tokens: List[int],
image_grid_thw: Union[List[List[int]], torch.Tensor],
video_grid_thw: Union[List[List[int]], torch.Tensor],
image_token_id: int,
video_token_id: int,
vision_start_token_id: int,
vision_end_token_id: int,
spatial_merge_size: int,
context_len: int = 0,
seq_len: Optional[int] = None,
) -> Tuple[torch.Tensor, int]:
"""Get mrope input positions and delta value."""
if isinstance(image_grid_thw, torch.Tensor):
image_grid_thw = image_grid_thw.tolist()
if isinstance(video_grid_thw, torch.Tensor):
@ -916,7 +947,7 @@ class MRotaryEmbedding(RotaryEmbedding):
len(input_tokens)).item()
llm_positions = llm_positions[:, context_len:seq_len]
return llm_positions.tolist(), mrope_position_delta
return llm_positions, mrope_position_delta
@staticmethod
def get_next_input_positions(
@ -930,6 +961,17 @@ class MRotaryEmbedding(RotaryEmbedding):
seq_len + mrope_position_delta)) for _ in range(3)
]
@staticmethod
def get_next_input_positions_tensor(
mrope_position_delta: int,
context_len: int,
seq_len: int,
) -> torch.Tensor:
return torch.arange(
mrope_position_delta + context_len,
mrope_position_delta + seq_len,
).expand(3, -1)
_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}

View File

@ -554,10 +554,12 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
# Preserve the order of modalities if there are multiple of them
# from the order of kwargs.
for input_key in kwargs:
if input_key == "pixel_values" and "images" not in modalities:
if input_key in ("pixel_values",
"image_embeds") and "images" not in modalities:
modalities["images"] = self._parse_and_validate_image_input(
**kwargs)
if input_key == "pixel_values_videos" and "videos" not in modalities: # noqa E501
if input_key in ("pixel_values_videos",
"video_embeds") and "videos" not in modalities:
modalities["videos"] = self._parse_and_validate_video_input(
**kwargs)

View File

@ -256,7 +256,15 @@ class Qwen2DecoderLayer(nn.Module):
return hidden_states, residual
@support_torch_compile
@support_torch_compile(
dynamic_arg_dims={
"input_ids": 0,
# positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
# otherwise (seq_len, ).
"positions": -1,
"intermediate_tensors": 0,
"inputs_embeds": 0,
})
class Qwen2Model(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

View File

@ -67,11 +67,15 @@ from vllm.transformers_utils.config import uses_mrope
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper,
init_vllm_registered_model, maybe_prefix)
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
from .vision import get_vit_attn_backend
logger = init_logger(__name__)
# For profile run
_MAX_FRAMES_PER_VIDEO = 16
# === Vision Inputs === #
@ -135,7 +139,7 @@ class Qwen2VLVideoEmbeddingInputs(TypedDict):
- List[`torch.Tensor`]: A list of tensors holding all videos' features.
Each tensor holds an video's features.
- `torch.Tensor`: A tensor holding all videos' features
(concatenation of all videos' feature tensors).
(concatenation of all videos' feature tensors).
Tensor shape: `(num_image_features, hidden_size)`
- `num_image_features` varies based on
@ -611,6 +615,7 @@ class Qwen2VisionTransformer(nn.Module):
# adapter
x = self.merger(x)
return x
def load_weights(self, weights: Iterable[Tuple[str,
@ -874,8 +879,8 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
max_image_tokens = self.get_max_image_tokens() * max_images
max_total_frames = self._get_max_video_frames(seq_len -
max_image_tokens)
num_frames = max(max_total_frames // max(max_videos, 1), 1)
num_frames = min(max(max_total_frames // max(max_videos, 1), 1),
_MAX_FRAMES_PER_VIDEO)
# Temporary workaround for https://github.com/huggingface/transformers/issues/35412
if num_frames > 1 and num_frames % 2 == 1:
@ -955,13 +960,14 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
"image": hf_processor.image_token,
"video": hf_processor.video_token,
}
merge_length = image_processor.merge_size**2
def get_replacement_qwen2vl(item_idx: int, modality: str):
grid_thw = out_mm_kwargs[f"{modality}_grid_thw"][item_idx]
assert isinstance(grid_thw, torch.Tensor)
num_tokens = grid_thw.prod() // merge_length
num_tokens = grid_thw.prod().item() // merge_length
return placeholder[modality] * num_tokens
return [
@ -1047,11 +1053,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config: Qwen2VLConfig = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
assert not cache_config.enable_prefix_caching, \
"Qwen2-VL currently does not support prefix caching"
self.config = config
self.multimodal_config = multimodal_config
@ -1173,59 +1176,82 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
video_embeds=video_embeds,
video_grid_thw=video_grid_thw)
def _process_image_input(self,
image_input: Qwen2VLImageInputs) -> torch.Tensor:
def _process_image_input(
self, image_input: Qwen2VLImageInputs) -> tuple[torch.Tensor, ...]:
grid_thw = image_input["image_grid_thw"]
assert grid_thw.ndim == 2
if image_input["type"] == "image_embeds":
return image_input["image_embeds"].type(self.visual.dtype)
image_embeds = image_input["image_embeds"].type(self.visual.dtype)
else:
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
image_embeds = self.visual(pixel_values,
grid_thw=image_input["image_grid_thw"])
return image_embeds
# Split concatenated embeddings for each image item.
merge_size = self.visual.spatial_merge_size
sizes = grid_thw.prod(-1) // merge_size // merge_size
return image_embeds.split(sizes.tolist())
def _process_video_input(
self, video_input: Qwen2VLVideoInputs) -> tuple[torch.Tensor, ...]:
grid_thw = video_input["video_grid_thw"]
assert grid_thw.ndim == 2
def _process_video_input(self,
video_input: Qwen2VLVideoInputs) -> torch.Tensor:
if video_input["type"] == "video_embeds":
return video_input["video_embeds"].type(self.visual.dtype)
video_embeds = video_input["video_embeds"].type(self.visual.dtype)
else:
pixel_values_videos = video_input["pixel_values_videos"].type(
self.visual.dtype)
video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
pixel_values_videos = video_input["pixel_values_videos"].type(
self.visual.dtype)
video_embeds = self.visual(pixel_values_videos,
grid_thw=video_input["video_grid_thw"])
return video_embeds
# Split concatenated embeddings for each video item.
merge_size = self.visual.spatial_merge_size
sizes = grid_thw.prod(-1) // merge_size // merge_size
def _merge_multimodal_embeddings(
self,
input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
multimodal_embeddings: torch.Tensor,
placeholder_token_id: int,
) -> torch.Tensor:
mask = (input_ids == placeholder_token_id)
inputs_embeds[mask, :] = multimodal_embeddings
return inputs_embeds
return video_embeds.split(sizes.tolist())
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
modalities = {}
# Preserve the order of modalities if there are multiple of them
# from the order of kwargs.
for input_key in kwargs:
if input_key in ("pixel_values",
"image_embeds") and "images" not in modalities:
modalities["images"] = self._parse_and_validate_image_input(
**kwargs)
if input_key in ("pixel_values_videos",
"video_embeds") and "videos" not in modalities:
modalities["videos"] = self._parse_and_validate_video_input(
**kwargs)
return modalities
def get_multimodal_embeddings(
self, **kwargs) -> Optional[List[Tuple[NestedTensors, str]]]:
image_input = self._parse_and_validate_image_input(**kwargs)
video_input = self._parse_and_validate_video_input(**kwargs)
if image_input is None and video_input is None:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities:
return None
# We make a tuple of each embedding with its modality string. This is a
# temporary workaround for models to handle mixed modalities when
# get_multimodal_embeddings and get_input_embeddings are called
# separately.
# TODO(ywang96): Add support for mixed-modality inference for v1.
multimodal_embeddings: List[Tuple[NestedTensors, str]] = []
# The result multimodal_embeddings is tuple of tensors, with each
# tensor correspoending to a multimodal data item (image or video).
multimodal_embeddings: tuple[torch.Tensor, ...] = ()
if image_input is not None:
image_embeds = self._process_image_input(image_input)
multimodal_embeddings.append((image_embeds, "image"))
if video_input is not None:
video_embeds = self._process_video_input(video_input)
multimodal_embeddings.append((video_embeds, "video"))
# NOTE: It is important to iterate over the keys in this dictionary
# to preserve the order of the modalities.
for modality in modalities:
if modality == "images":
image_input = modalities["images"]
vision_embeddings = self._process_image_input(image_input)
multimodal_embeddings += vision_embeddings
if modality == "videos":
video_input = modalities["videos"]
video_embeddings = self._process_video_input(video_input)
multimodal_embeddings += video_embeddings
return multimodal_embeddings
@ -1237,21 +1263,9 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
for embeddings, modality in multimodal_embeddings:
if modality == "image":
inputs_embeds = self._merge_multimodal_embeddings(
input_ids,
inputs_embeds,
embeddings,
placeholder_token_id=self.config.image_token_id,
)
if modality == "video":
inputs_embeds = self._merge_multimodal_embeddings(
input_ids,
inputs_embeds,
embeddings,
placeholder_token_id=self.config.video_token_id,
)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
[self.config.image_token_id, self.config.video_token_id])
return inputs_embeds
def forward(

View File

@ -30,6 +30,9 @@ class CachedRequestState:
num_computed_tokens: int
output_token_ids: List[int]
mrope_positions: Optional[torch.Tensor] = None
mrope_position_delta: Optional[int] = None
@property
def num_tokens(self) -> int:
return len(self.prompt_token_ids) + len(self.output_token_ids)

View File

@ -14,6 +14,7 @@ from vllm.distributed.parallel_state import graph_capture
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.sampling_params import SamplingType
@ -139,6 +140,32 @@ class GPUModelRunner:
self.positions = torch.zeros(self.max_num_tokens,
dtype=torch.int64,
device=self.device)
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if self.model_config.uses_mrope:
# NOTE: `mrope_positions` is implemented as a permuted tensor to
# satisfy the following properties to allow `torch.compile` to work
# properly:
# - shape: (3, <variable>)
# - stride: (1, 3)
# See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1921022256
# NOTE: When M-RoPE is enabled, position ids are 3D regardless of
# the modality of inputs. For text-only inputs, each dimension has
# identical position IDs, making M-RoPE functionally equivalent to
# 1D-RoPE.
# See page 5 of https://arxiv.org/abs/2409.12191
self.mrope_positions = torch.zeros((self.max_num_tokens, 3),
dtype=torch.int64,
device=self.device)
self.mrope_positions_cpu = torch.zeros((self.max_num_tokens, 3),
dtype=torch.int64,
device="cpu",
pin_memory=self.pin_memory)
self.mrope_positions = self.mrope_positions.permute((1, 0))
self.mrope_positions_cpu = self.mrope_positions_cpu.permute((1, 0))
self.inputs_embeds = torch.zeros(
(self.max_num_tokens, self.hidden_size),
dtype=self.dtype,
@ -246,6 +273,35 @@ class GPUModelRunner:
num_computed_tokens=new_req_data.num_computed_tokens,
output_token_ids=[],
)
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if self.model_config.uses_mrope:
image_grid_thw = []
video_grid_thw = []
for mm_input in self.requests[req_id].mm_inputs:
if mm_input.get("image_grid_thw") is not None:
image_grid_thw.extend(
mm_input["image_grid_thw"].tolist())
if mm_input.get("video_grid_thw") is not None:
video_grid_thw.extend(
mm_input["video_grid_thw"].tolist())
hf_config = self.model_config.hf_config
self.requests[req_id].mrope_positions, \
self.requests[req_id].mrope_position_delta = \
MRotaryEmbedding.get_input_positions_tensor(
self.requests[req_id].prompt_token_ids,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
image_token_id=hf_config.image_token_id,
video_token_id=hf_config.video_token_id,
vision_start_token_id=hf_config.vision_start_token_id,
vision_end_token_id=hf_config.vision_end_token_id,
spatial_merge_size=hf_config.vision_config.
spatial_merge_size,
)
req_ids_to_add.append(req_id)
# Update the cached states of the resumed requests.
@ -313,6 +369,11 @@ class GPUModelRunner:
arange,
out=positions_np)
# Calculate M-RoPE positions.
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if self.model_config.uses_mrope:
self._calc_mrope_positions(scheduler_output)
# Get token indices.
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
@ -359,8 +420,16 @@ class GPUModelRunner:
# Copy the tensors to the GPU.
self.input_ids[:total_num_scheduled_tokens].copy_(
self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
self.positions[:total_num_scheduled_tokens].copy_(
self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True)
if self.model_config.uses_mrope:
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
non_blocking=True)
else:
# Common case (1D positions)
self.positions[:total_num_scheduled_tokens].copy_(
self.positions_cpu[:total_num_scheduled_tokens],
non_blocking=True)
query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to(
self.device, non_blocking=True)
seq_start_loc = self.seq_start_loc_cpu[:num_reqs + 1].to(
@ -472,6 +541,61 @@ class GPUModelRunner:
logits_indices = query_start_loc[1:] - 1
return attn_metadata, logits_indices
def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
mrope_pos_ptr = 0
num_reqs = self.input_batch.num_reqs
for index, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
assert req_id is not None
req = self.requests[req_id]
assert req.mrope_positions is not None
num_computed_tokens = \
self.input_batch.num_computed_tokens_cpu[index]
num_scheduled_tokens = \
scheduler_output.num_scheduled_tokens[req_id]
num_prompt_tokens = len(req.prompt_token_ids)
if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
prompt_part_len = max(0,
num_prompt_tokens - num_computed_tokens)
completion_part_len = max(
0, num_scheduled_tokens - prompt_part_len)
else:
prompt_part_len = num_scheduled_tokens
completion_part_len = 0
assert num_scheduled_tokens == prompt_part_len + completion_part_len
if prompt_part_len > 0:
# prompt's mrope_positions are pre-computed
dst_start = mrope_pos_ptr
dst_end = mrope_pos_ptr + prompt_part_len
src_start = num_computed_tokens
src_end = num_computed_tokens + prompt_part_len
self.mrope_positions_cpu[:, dst_start:dst_end] = \
req.mrope_positions[:,src_start:src_end]
mrope_pos_ptr += prompt_part_len
if completion_part_len > 0:
# compute completion's mrope_positions on-the-fly
dst_start = mrope_pos_ptr
dst_end = mrope_pos_ptr + completion_part_len
self.mrope_positions_cpu[:, dst_start:dst_end] = \
MRotaryEmbedding.get_next_input_positions_tensor(
req.mrope_position_delta,
context_len=num_computed_tokens +
prompt_part_len,
seq_len=num_computed_tokens +
prompt_part_len +
completion_part_len,
)
mrope_pos_ptr += completion_part_len
def _prepare_sampling(
self,
scheduler_output: "SchedulerOutput",
@ -618,9 +742,12 @@ class GPUModelRunner:
# Run the decoder.
# Use persistent buffers for CUDA graphs.
with set_forward_context(attn_metadata, self.vllm_config):
positions = self.mrope_positions[:, :num_input_tokens] \
if self.model_config.uses_mrope \
else self.positions[:num_input_tokens]
hidden_states = self.model(
input_ids=input_ids,
positions=self.positions[:num_input_tokens],
positions=positions,
kv_caches=self.kv_caches,
attn_metadata=None,
inputs_embeds=inputs_embeds,
@ -707,9 +834,12 @@ class GPUModelRunner:
input_ids = self.input_ids[:num_tokens]
inputs_embeds = None
with set_forward_context(None, self.vllm_config):
positions = self.mrope_positions[:, :num_tokens] \
if self.model_config.uses_mrope \
else self.positions[:num_tokens]
hidden_states = model(
input_ids=input_ids,
positions=self.positions[:num_tokens],
positions=positions,
kv_caches=kv_caches,
attn_metadata=None,
inputs_embeds=inputs_embeds,