[V1] Add V1 support of Qwen2-VL (#12128)
Signed-off-by: Roger Wang <ywang@roblox.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: imkero <kerorek@outlook.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent edaae198e7
commit 81763c58a0
@@ -754,7 +754,7 @@ See [this page](#generative-models) for more information on how to use generativ
     - `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc.
     - ✅︎
     - ✅︎
-    -
+    - ✅︎
   * - `UltravoxModel`
     - Ultravox
     - T + A<sup>E+</sup>
@@ -105,7 +105,7 @@ def batch_make_image_embeddings(
     pixel_values = preprocess_result["pixel_values"]
     image_grid_thw = preprocess_result["image_grid_thw"]

-    # pixel values to embeddinds & grid_thws
+    # pixel values to embeddings & grid_thws
     with torch.no_grad():
         visual = llm.llm_engine.model_executor.driver_worker. \
             model_runner.model.visual
@@ -124,11 +124,10 @@ def batch_make_image_embeddings(
     for image_batch in image_batches_:
         cur_batch_image_count = len(image_batch)
         merge_size = image_processor.merge_size
-        cur_batch_embed_len = sum([
-            grid_thw.prod() // merge_size // merge_size
+        cur_batch_embed_len = sum(
+            grid_thw.prod(-1) // merge_size // merge_size
             for grid_thw in image_grid_thw[image_counter:image_counter +
-                                           cur_batch_image_count]
-        ])
+                                           cur_batch_image_count])

         result.append({
             "image_embeds":
@@ -187,7 +186,7 @@ def batch_make_video_embeddings(
     pixel_values = preprocess_result["pixel_values_videos"]
     video_grid_thw = preprocess_result["video_grid_thw"]

-    # pixel values to embeddinds & grid_thws
+    # pixel values to embeddings & grid_thws
    with torch.no_grad():
        visual = llm.llm_engine.model_executor.driver_worker.\
            model_runner.model.visual
@@ -206,11 +205,10 @@ def batch_make_video_embeddings(
     for video_batch in video_batches_:
         cur_batch_video_count = len(video_batch)
         merge_size = image_processor.merge_size
-        cur_batch_embed_len = sum([
-            grid_thw.prod() // merge_size // merge_size
+        cur_batch_embed_len = sum(
+            grid_thw.prod(-1) // merge_size // merge_size
             for grid_thw in video_grid_thw[video_counter:video_counter +
-                                           cur_batch_video_count]
-        ])
+                                           cur_batch_video_count])

         result.append({
             "video_embeds":
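The two test hunks above compute the expected number of output embeddings per item from its `grid_thw`: the product of `(t, h, w)` divided twice by the spatial merge size. A standalone sketch of that arithmetic, using made-up grid values rather than anything from the test:

```python
import torch

# Illustrative values only: a 1-frame 4x4 patch grid (image) and a
# 2-frame 4x4 patch grid (video).
image_grid_thw = torch.tensor([[1, 4, 4]])
video_grid_thw = torch.tensor([[2, 4, 4]])
merge_size = 2  # Qwen2-VL merges 2x2 patches into one output embedding


def embed_len(grid_thw: torch.Tensor, merge_size: int) -> int:
    # prod(-1) gives t*h*w per item; dividing by merge_size twice accounts
    # for the 2x2 spatial merge performed by the vision adapter.
    return int(sum(g.prod(-1) // merge_size // merge_size for g in grid_thw))


print(embed_len(image_grid_thw, merge_size))  # 1*4*4 // 4 = 4
print(embed_len(video_grid_thw, merge_size))  # 2*4*4 // 4 = 8
```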
@@ -76,8 +76,8 @@ def support_torch_compile(
     During runtime, when we actually mark dimensions of tensors,
     it depends on the value of arguments:

-    - if it is a single integer, the corresponding dimension of the argument
-      will be marked as dynamic.
+    - if it is a single integer (can be negative), the corresponding dimension
+      of the argument will be marked as dynamic.
     - if it is `None`, ignored.
     - if it is `IntermediateTensors`, all the tensors in the intermediate
       tensors will be marked as dynamic.
@@ -177,10 +177,20 @@ def _support_torch_compile(
         for k, dims in dynamic_arg_dims.items():
             arg = bound_args.arguments.get(k)
             if arg is not None:
+                dims = [dims] if isinstance(dims, int) else dims
                 if isinstance(arg, torch.Tensor):
+                    # In case dims is specified with negative indexing
+                    dims = [
+                        arg.ndim + dim if dim < 0 else dim for dim in dims
+                    ]
                     torch._dynamo.mark_dynamic(arg, dims)
                 elif isinstance(arg, IntermediateTensors):
                     for tensor in arg.tensors.values():
+                        # In case dims is specified with negative indexing
+                        dims = [
+                            tensor.ndim + dim if dim < 0 else dim
+                            for dim in dims
+                        ]
                         torch._dynamo.mark_dynamic(tensor, dims)
                 else:
                     raise ValueError(
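A minimal sketch of the negative-index handling introduced above, pulled out of the decorator machinery (the tensors and dims below are arbitrary examples, not vLLM code):

```python
import torch


def normalize_dims(t: torch.Tensor, dims) -> list:
    # Convert negative dims (e.g. -1) into absolute indices for this tensor,
    # mirroring the normalization done before torch._dynamo.mark_dynamic.
    dims = [dims] if isinstance(dims, int) else dims
    return [t.ndim + d if d < 0 else d for d in dims]


positions_1d = torch.zeros(128, dtype=torch.int64)         # (seq_len,)
positions_mrope = torch.zeros(3, 128, dtype=torch.int64)   # (3, seq_len)

print(normalize_dims(positions_1d, -1))     # [0]
print(normalize_dims(positions_mrope, -1))  # [1]
```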
@@ -841,6 +841,37 @@ class MRotaryEmbedding(RotaryEmbedding):
     ) -> Tuple[List[List[int]], int]:
         """Get mrope input positions and delta value."""

+        llm_positions, mrope_position_delta = \
+            MRotaryEmbedding.get_input_positions_tensor(
+                input_tokens,
+                image_grid_thw,
+                video_grid_thw,
+                image_token_id,
+                video_token_id,
+                vision_start_token_id,
+                vision_end_token_id,
+                spatial_merge_size,
+                context_len,
+                seq_len,
+            )
+
+        return llm_positions.tolist(), mrope_position_delta
+
+    @staticmethod
+    def get_input_positions_tensor(
+        input_tokens: List[int],
+        image_grid_thw: Union[List[List[int]], torch.Tensor],
+        video_grid_thw: Union[List[List[int]], torch.Tensor],
+        image_token_id: int,
+        video_token_id: int,
+        vision_start_token_id: int,
+        vision_end_token_id: int,
+        spatial_merge_size: int,
+        context_len: int = 0,
+        seq_len: Optional[int] = None,
+    ) -> Tuple[torch.Tensor, int]:
+        """Get mrope input positions and delta value."""
+
         if isinstance(image_grid_thw, torch.Tensor):
             image_grid_thw = image_grid_thw.tolist()
         if isinstance(video_grid_thw, torch.Tensor):
@@ -916,7 +947,7 @@ class MRotaryEmbedding(RotaryEmbedding):
                                 len(input_tokens)).item()
         llm_positions = llm_positions[:, context_len:seq_len]

-        return llm_positions.tolist(), mrope_position_delta
+        return llm_positions, mrope_position_delta

     @staticmethod
     def get_next_input_positions(
@@ -930,6 +961,17 @@ class MRotaryEmbedding(RotaryEmbedding):
                 seq_len + mrope_position_delta)) for _ in range(3)
         ]

+    @staticmethod
+    def get_next_input_positions_tensor(
+        mrope_position_delta: int,
+        context_len: int,
+        seq_len: int,
+    ) -> torch.Tensor:
+        return torch.arange(
+            mrope_position_delta + context_len,
+            mrope_position_delta + seq_len,
+        ).expand(3, -1)
+

_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
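A rough sketch of what the new `get_next_input_positions_tensor` produces during decode (the delta and lengths below are invented): a single 1-D range is broadcast to all three M-RoPE dimensions, so text-only decoding behaves like ordinary 1-D RoPE.

```python
import torch

mrope_position_delta = -20   # example delta for a prompt containing vision tokens
context_len, seq_len = 100, 103

positions = torch.arange(
    mrope_position_delta + context_len,
    mrope_position_delta + seq_len,
).expand(3, -1)

print(positions.shape)  # torch.Size([3, 3])
print(positions)        # tensor([[80, 81, 82],
                        #         [80, 81, 82],
                        #         [80, 81, 82]])
```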
@@ -554,10 +554,12 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
         # Preserve the order of modalities if there are multiple of them
         # from the order of kwargs.
         for input_key in kwargs:
-            if input_key == "pixel_values" and "images" not in modalities:
+            if input_key in ("pixel_values",
+                             "image_embeds") and "images" not in modalities:
                 modalities["images"] = self._parse_and_validate_image_input(
                     **kwargs)
-            if input_key == "pixel_values_videos" and "videos" not in modalities: # noqa E501
+            if input_key in ("pixel_values_videos",
+                             "video_embeds") and "videos" not in modalities:
                 modalities["videos"] = self._parse_and_validate_video_input(
                     **kwargs)
@@ -256,7 +256,15 @@ class Qwen2DecoderLayer(nn.Module):
         return hidden_states, residual


-@support_torch_compile
+@support_torch_compile(
+    dynamic_arg_dims={
+        "input_ids": 0,
+        # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
+        # otherwise (seq_len, ).
+        "positions": -1,
+        "intermediate_tensors": 0,
+        "inputs_embeds": 0,
+    })
 class Qwen2Model(nn.Module):

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
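The `"positions": -1` entry relies on the negative-index handling added in `_support_torch_compile` above: `-1` resolves to dim 0 for a `(seq_len,)` positions tensor and to dim 1 for the `(3, seq_len)` M-RoPE layout, so the same decorator argument covers both. A quick illustrative check with invented shapes:

```python
import torch

positions_1d = torch.zeros(8, dtype=torch.int64)        # plain RoPE layout
positions_mrope = torch.zeros(3, 8, dtype=torch.int64)  # M-RoPE layout

for t in (positions_1d, positions_mrope):
    dim = t.ndim - 1  # what "positions": -1 resolves to for this tensor
    torch._dynamo.mark_dynamic(t, dim)
    print(tuple(t.shape), "-> dynamic dim", dim)
```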
@@ -67,11 +67,15 @@ from vllm.transformers_utils.config import uses_mrope

 from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, WeightsMapper,
-                    init_vllm_registered_model, maybe_prefix)
+                    init_vllm_registered_model, maybe_prefix,
+                    merge_multimodal_embeddings)
 from .vision import get_vit_attn_backend

 logger = init_logger(__name__)

+# For profile run
+_MAX_FRAMES_PER_VIDEO = 16
+
 # === Vision Inputs === #
@@ -135,7 +139,7 @@ class Qwen2VLVideoEmbeddingInputs(TypedDict):
     - List[`torch.Tensor`]: A list of tensors holding all videos' features.
       Each tensor holds an video's features.
     - `torch.Tensor`: A tensor holding all videos' features
-    (concatenation of all videos' feature tensors).
+      (concatenation of all videos' feature tensors).

     Tensor shape: `(num_image_features, hidden_size)`
     - `num_image_features` varies based on
@@ -611,6 +615,7 @@ class Qwen2VisionTransformer(nn.Module):

         # adapter
         x = self.merger(x)
+
         return x

     def load_weights(self, weights: Iterable[Tuple[str,
@@ -874,8 +879,8 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
         max_image_tokens = self.get_max_image_tokens() * max_images
         max_total_frames = self._get_max_video_frames(seq_len -
                                                       max_image_tokens)

-        num_frames = max(max_total_frames // max(max_videos, 1), 1)
+        num_frames = min(max(max_total_frames // max(max_videos, 1), 1),
+                         _MAX_FRAMES_PER_VIDEO)

         # Temporary workaround for https://github.com/huggingface/transformers/issues/35412
         if num_frames > 1 and num_frames % 2 == 1:
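A small illustration of the new per-video frame cap during the profile run (the numbers are invented; `_MAX_FRAMES_PER_VIDEO` is the constant of 16 defined earlier, and the even-frame workaround whose body is not shown in this hunk is omitted):

```python
_MAX_FRAMES_PER_VIDEO = 16


def profile_num_frames(max_total_frames: int, max_videos: int) -> int:
    # Split the profiling frame budget across videos, but cap each video at
    # _MAX_FRAMES_PER_VIDEO frames so the dummy inputs stay bounded.
    return min(max(max_total_frames // max(max_videos, 1), 1),
               _MAX_FRAMES_PER_VIDEO)


print(profile_num_frames(max_total_frames=600, max_videos=2))  # 16, not 300
print(profile_num_frames(max_total_frames=10, max_videos=3))   # 3
```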
@@ -955,13 +960,14 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
             "image": hf_processor.image_token,
             "video": hf_processor.video_token,
         }

         merge_length = image_processor.merge_size**2

         def get_replacement_qwen2vl(item_idx: int, modality: str):
             grid_thw = out_mm_kwargs[f"{modality}_grid_thw"][item_idx]
+            assert isinstance(grid_thw, torch.Tensor)

-            num_tokens = grid_thw.prod() // merge_length
+            num_tokens = grid_thw.prod().item() // merge_length
             return placeholder[modality] * num_tokens

         return [
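For reference, a hedged sketch of the placeholder expansion this replacement function performs (the grid and placeholder text are made-up examples; `.item()` keeps `num_tokens` a plain Python int so the repetition below behaves as expected):

```python
import torch

merge_length = 2 ** 2                     # image_processor.merge_size ** 2
grid_thw = torch.tensor([1, 16, 16])      # (t, h, w) patch grid for one image

num_tokens = grid_thw.prod().item() // merge_length  # 256 // 4 = 64
placeholder = "<|image_pad|>"             # illustrative placeholder token text
prompt_piece = placeholder * num_tokens   # one placeholder per visual embedding

print(num_tokens, len(prompt_piece) // len(placeholder))  # 64 64
```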
@@ -1047,11 +1053,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config: Qwen2VLConfig = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         multimodal_config = vllm_config.model_config.multimodal_config
-        assert not cache_config.enable_prefix_caching, \
-            "Qwen2-VL currently does not support prefix caching"

         self.config = config
         self.multimodal_config = multimodal_config
@@ -1173,59 +1176,82 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
             video_embeds=video_embeds,
             video_grid_thw=video_grid_thw)

-    def _process_image_input(self,
-                             image_input: Qwen2VLImageInputs) -> torch.Tensor:
+    def _process_image_input(
+            self, image_input: Qwen2VLImageInputs) -> tuple[torch.Tensor, ...]:
+
+        grid_thw = image_input["image_grid_thw"]
+        assert grid_thw.ndim == 2
+
         if image_input["type"] == "image_embeds":
-            return image_input["image_embeds"].type(self.visual.dtype)
+            image_embeds = image_input["image_embeds"].type(self.visual.dtype)
+        else:
+            pixel_values = image_input["pixel_values"].type(self.visual.dtype)
+            image_embeds = self.visual(pixel_values, grid_thw=grid_thw)

-        pixel_values = image_input["pixel_values"].type(self.visual.dtype)
-        image_embeds = self.visual(pixel_values,
-                                   grid_thw=image_input["image_grid_thw"])
-        return image_embeds
+        # Split concatenated embeddings for each image item.
+        merge_size = self.visual.spatial_merge_size
+        sizes = grid_thw.prod(-1) // merge_size // merge_size
+
+        return image_embeds.split(sizes.tolist())
+
+    def _process_video_input(
+            self, video_input: Qwen2VLVideoInputs) -> tuple[torch.Tensor, ...]:
+
+        grid_thw = video_input["video_grid_thw"]
+        assert grid_thw.ndim == 2

-    def _process_video_input(self,
-                             video_input: Qwen2VLVideoInputs) -> torch.Tensor:
         if video_input["type"] == "video_embeds":
-            return video_input["video_embeds"].type(self.visual.dtype)
+            video_embeds = video_input["video_embeds"].type(self.visual.dtype)
+        else:
+            pixel_values_videos = video_input["pixel_values_videos"].type(
+                self.visual.dtype)
+            video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)

-        pixel_values_videos = video_input["pixel_values_videos"].type(
-            self.visual.dtype)
-        video_embeds = self.visual(pixel_values_videos,
-                                   grid_thw=video_input["video_grid_thw"])
-        return video_embeds
+        # Split concatenated embeddings for each video item.
+        merge_size = self.visual.spatial_merge_size
+        sizes = grid_thw.prod(-1) // merge_size // merge_size

-    def _merge_multimodal_embeddings(
-        self,
-        input_ids: torch.Tensor,
-        inputs_embeds: torch.Tensor,
-        multimodal_embeddings: torch.Tensor,
-        placeholder_token_id: int,
-    ) -> torch.Tensor:
-        mask = (input_ids == placeholder_token_id)
-        inputs_embeds[mask, :] = multimodal_embeddings
-        return inputs_embeds
+        return video_embeds.split(sizes.tolist())
+
+    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
+        modalities = {}
+
+        # Preserve the order of modalities if there are multiple of them
+        # from the order of kwargs.
+        for input_key in kwargs:
+            if input_key in ("pixel_values",
+                             "image_embeds") and "images" not in modalities:
+                modalities["images"] = self._parse_and_validate_image_input(
+                    **kwargs)
+            if input_key in ("pixel_values_videos",
+                             "video_embeds") and "videos" not in modalities:
+                modalities["videos"] = self._parse_and_validate_video_input(
+                    **kwargs)
+
+        return modalities

     def get_multimodal_embeddings(
             self, **kwargs) -> Optional[List[Tuple[NestedTensors, str]]]:

-        image_input = self._parse_and_validate_image_input(**kwargs)
-        video_input = self._parse_and_validate_video_input(**kwargs)
-        if image_input is None and video_input is None:
+        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
+        if not modalities:
             return None

-        # We make a tuple of each embedding with its modality string. This is a
-        # temporary workaround for models to handle mixed modalities when
-        # get_multimodal_embeddings and get_input_embeddings are called
-        # separately.
-        # TODO(ywang96): Add support for mixed-modality inference for v1.
-        multimodal_embeddings: List[Tuple[NestedTensors, str]] = []
+        # The result multimodal_embeddings is tuple of tensors, with each
+        # tensor correspoending to a multimodal data item (image or video).
+        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

-        if image_input is not None:
-            image_embeds = self._process_image_input(image_input)
-            multimodal_embeddings.append((image_embeds, "image"))
-        if video_input is not None:
-            video_embeds = self._process_video_input(video_input)
-            multimodal_embeddings.append((video_embeds, "video"))
+        # NOTE: It is important to iterate over the keys in this dictionary
+        # to preserve the order of the modalities.
+        for modality in modalities:
+            if modality == "images":
+                image_input = modalities["images"]
+                vision_embeddings = self._process_image_input(image_input)
+                multimodal_embeddings += vision_embeddings
+            if modality == "videos":
+                video_input = modalities["videos"]
+                video_embeddings = self._process_video_input(video_input)
+                multimodal_embeddings += video_embeddings

         return multimodal_embeddings
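A compact sketch of the per-item split that the rewritten `_process_image_input`/`_process_video_input` now perform on the concatenated vision-tower output (shapes and hidden size below are invented):

```python
import torch

hidden_size = 8
merge_size = 2
# Two images: a 4x4 patch grid and an 8x4 patch grid (t=1 for images).
grid_thw = torch.tensor([[1, 4, 4], [1, 8, 4]])

# Number of output embeddings contributed by each item after the 2x2 merge.
sizes = grid_thw.prod(-1) // merge_size // merge_size      # tensor([4, 8])

# The vision tower returns one concatenated tensor for the whole batch.
image_embeds = torch.randn(int(sizes.sum()), hidden_size)

per_item = image_embeds.split(sizes.tolist())
print([tuple(t.shape) for t in per_item])  # [(4, 8), (8, 8)]
```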
@@ -1237,21 +1263,9 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
     ) -> torch.Tensor:
         inputs_embeds = self.language_model.get_input_embeddings(input_ids)
         if multimodal_embeddings is not None:
-            for embeddings, modality in multimodal_embeddings:
-                if modality == "image":
-                    inputs_embeds = self._merge_multimodal_embeddings(
-                        input_ids,
-                        inputs_embeds,
-                        embeddings,
-                        placeholder_token_id=self.config.image_token_id,
-                    )
-                if modality == "video":
-                    inputs_embeds = self._merge_multimodal_embeddings(
-                        input_ids,
-                        inputs_embeds,
-                        embeddings,
-                        placeholder_token_id=self.config.video_token_id,
-                    )
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids, inputs_embeds, multimodal_embeddings,
+                [self.config.image_token_id, self.config.video_token_id])
         return inputs_embeds

     def forward(
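The shared `merge_multimodal_embeddings` helper replaces the per-modality loop by scattering all vision embeddings into the rows whose token is any of the given placeholder ids. A simplified stand-in, not the vLLM helper itself (token ids and sizes are illustrative):

```python
from typing import List

import torch


def merge_by_placeholders(input_ids: torch.Tensor,
                          inputs_embeds: torch.Tensor,
                          mm_embeds: torch.Tensor,
                          placeholder_ids: List[int]) -> torch.Tensor:
    # Rows whose token id matches any placeholder id receive the multimodal
    # embeddings, in order of appearance in the prompt.
    mask = torch.isin(input_ids, torch.tensor(placeholder_ids))
    inputs_embeds[mask] = mm_embeds
    return inputs_embeds


IMAGE_ID, VIDEO_ID = 1001, 1002            # made-up placeholder token ids
input_ids = torch.tensor([7, IMAGE_ID, IMAGE_ID, 9, VIDEO_ID])
inputs_embeds = torch.zeros(5, 4)
mm_embeds = torch.arange(12, dtype=torch.float32).reshape(3, 4)

out = merge_by_placeholders(input_ids, inputs_embeds, mm_embeds,
                            [IMAGE_ID, VIDEO_ID])
print(out[1])  # row for the first image placeholder
```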
@@ -30,6 +30,9 @@ class CachedRequestState:
     num_computed_tokens: int
     output_token_ids: List[int]

+    mrope_positions: Optional[torch.Tensor] = None
+    mrope_position_delta: Optional[int] = None
+
     @property
     def num_tokens(self) -> int:
         return len(self.prompt_token_ids) + len(self.output_token_ids)
@@ -14,6 +14,7 @@ from vllm.distributed.parallel_state import graph_capture
 from vllm.forward_context import set_forward_context
 from vllm.inputs import INPUT_REGISTRY
 from vllm.logger import init_logger
+from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import get_model
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.sampling_params import SamplingType
@@ -139,6 +140,32 @@ class GPUModelRunner:
         self.positions = torch.zeros(self.max_num_tokens,
                                      dtype=torch.int64,
                                      device=self.device)
+
+        # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
+        if self.model_config.uses_mrope:
+            # NOTE: `mrope_positions` is implemented as a permuted tensor to
+            # satisfy the following properties to allow `torch.compile` to work
+            # properly:
+            # - shape: (3, <variable>)
+            # - stride: (1, 3)
+            # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1921022256

+            # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
+            # the modality of inputs. For text-only inputs, each dimension has
+            # identical position IDs, making M-RoPE functionally equivalent to
+            # 1D-RoPE.
+            # See page 5 of https://arxiv.org/abs/2409.12191
+            self.mrope_positions = torch.zeros((self.max_num_tokens, 3),
+                                               dtype=torch.int64,
+                                               device=self.device)
+            self.mrope_positions_cpu = torch.zeros((self.max_num_tokens, 3),
+                                                   dtype=torch.int64,
+                                                   device="cpu",
+                                                   pin_memory=self.pin_memory)
+
+            self.mrope_positions = self.mrope_positions.permute((1, 0))
+            self.mrope_positions_cpu = self.mrope_positions_cpu.permute((1, 0))
+
         self.inputs_embeds = torch.zeros(
             (self.max_num_tokens, self.hidden_size),
             dtype=self.dtype,
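The permute trick described in the comment above can be checked in isolation: allocating `(max_num_tokens, 3)` and permuting yields a `(3, N)` view with stride `(1, 3)`, and slicing along the token dimension keeps that stride, which is what lets `torch.compile` treat the buffer consistently (sizes here are arbitrary):

```python
import torch

max_num_tokens = 8
buf = torch.zeros((max_num_tokens, 3), dtype=torch.int64)

mrope_positions = buf.permute((1, 0))   # a view, no copy
print(mrope_positions.shape)            # torch.Size([3, 8])
print(mrope_positions.stride())         # (1, 3)

# Slicing along the token dimension preserves the (1, 3) stride.
print(mrope_positions[:, :5].stride())  # (1, 3)
```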
@@ -246,6 +273,35 @@ class GPUModelRunner:
                 num_computed_tokens=new_req_data.num_computed_tokens,
                 output_token_ids=[],
             )
+
+            # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
+            if self.model_config.uses_mrope:
+                image_grid_thw = []
+                video_grid_thw = []
+                for mm_input in self.requests[req_id].mm_inputs:
+                    if mm_input.get("image_grid_thw") is not None:
+                        image_grid_thw.extend(
+                            mm_input["image_grid_thw"].tolist())
+                    if mm_input.get("video_grid_thw") is not None:
+                        video_grid_thw.extend(
+                            mm_input["video_grid_thw"].tolist())
+
+                hf_config = self.model_config.hf_config
+
+                self.requests[req_id].mrope_positions, \
+                    self.requests[req_id].mrope_position_delta = \
+                    MRotaryEmbedding.get_input_positions_tensor(
+                        self.requests[req_id].prompt_token_ids,
+                        image_grid_thw=image_grid_thw,
+                        video_grid_thw=video_grid_thw,
+                        image_token_id=hf_config.image_token_id,
+                        video_token_id=hf_config.video_token_id,
+                        vision_start_token_id=hf_config.vision_start_token_id,
+                        vision_end_token_id=hf_config.vision_end_token_id,
+                        spatial_merge_size=hf_config.vision_config.
+                        spatial_merge_size,
+                    )
+
             req_ids_to_add.append(req_id)

             # Update the cached states of the resumed requests.
@@ -313,6 +369,11 @@ class GPUModelRunner:
                            arange,
                            out=positions_np)

+        # Calculate M-RoPE positions.
+        # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
+        if self.model_config.uses_mrope:
+            self._calc_mrope_positions(scheduler_output)
+
         # Get token indices.
         # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
         # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
@@ -359,8 +420,16 @@ class GPUModelRunner:
         # Copy the tensors to the GPU.
         self.input_ids[:total_num_scheduled_tokens].copy_(
             self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
-        self.positions[:total_num_scheduled_tokens].copy_(
-            self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True)
+        if self.model_config.uses_mrope:
+            # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
+            self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
+                self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
+                non_blocking=True)
+        else:
+            # Common case (1D positions)
+            self.positions[:total_num_scheduled_tokens].copy_(
+                self.positions_cpu[:total_num_scheduled_tokens],
+                non_blocking=True)
         query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to(
             self.device, non_blocking=True)
         seq_start_loc = self.seq_start_loc_cpu[:num_reqs + 1].to(
@@ -472,6 +541,61 @@ class GPUModelRunner:
         logits_indices = query_start_loc[1:] - 1
         return attn_metadata, logits_indices

+    def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
+        mrope_pos_ptr = 0
+        num_reqs = self.input_batch.num_reqs
+        for index, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
+            assert req_id is not None
+
+            req = self.requests[req_id]
+            assert req.mrope_positions is not None
+
+            num_computed_tokens = \
+                self.input_batch.num_computed_tokens_cpu[index]
+            num_scheduled_tokens = \
+                scheduler_output.num_scheduled_tokens[req_id]
+            num_prompt_tokens = len(req.prompt_token_ids)
+
+            if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
+                prompt_part_len = max(0,
+                                      num_prompt_tokens - num_computed_tokens)
+                completion_part_len = max(
+                    0, num_scheduled_tokens - prompt_part_len)
+            else:
+                prompt_part_len = num_scheduled_tokens
+                completion_part_len = 0
+
+            assert num_scheduled_tokens == prompt_part_len + completion_part_len
+
+            if prompt_part_len > 0:
+                # prompt's mrope_positions are pre-computed
+                dst_start = mrope_pos_ptr
+                dst_end = mrope_pos_ptr + prompt_part_len
+                src_start = num_computed_tokens
+                src_end = num_computed_tokens + prompt_part_len
+
+                self.mrope_positions_cpu[:, dst_start:dst_end] = \
+                    req.mrope_positions[:,src_start:src_end]
+
+                mrope_pos_ptr += prompt_part_len
+
+            if completion_part_len > 0:
+                # compute completion's mrope_positions on-the-fly
+                dst_start = mrope_pos_ptr
+                dst_end = mrope_pos_ptr + completion_part_len
+
+                self.mrope_positions_cpu[:, dst_start:dst_end] = \
+                    MRotaryEmbedding.get_next_input_positions_tensor(
+                        req.mrope_position_delta,
+                        context_len=num_computed_tokens +
+                        prompt_part_len,
+                        seq_len=num_computed_tokens +
+                        prompt_part_len +
+                        completion_part_len,
+                    )
+
+                mrope_pos_ptr += completion_part_len
+
     def _prepare_sampling(
         self,
         scheduler_output: "SchedulerOutput",
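A worked example of the prompt/completion split used above, with invented request sizes: for a 10-token prompt with 8 tokens already computed and 5 tokens scheduled this step, 2 tokens still belong to the prompt (positions copied from the precomputed table) and 3 are new completion tokens (positions generated on the fly).

```python
num_prompt_tokens = 10
num_computed_tokens = 8
num_scheduled_tokens = 5

if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
    # The scheduled chunk straddles the prompt/completion boundary.
    prompt_part_len = max(0, num_prompt_tokens - num_computed_tokens)
    completion_part_len = max(0, num_scheduled_tokens - prompt_part_len)
else:
    # The entire chunk is still inside the prompt (chunked prefill).
    prompt_part_len = num_scheduled_tokens
    completion_part_len = 0

assert (prompt_part_len, completion_part_len) == (2, 3)
assert num_scheduled_tokens == prompt_part_len + completion_part_len
```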
@@ -618,9 +742,12 @@ class GPUModelRunner:
         # Run the decoder.
         # Use persistent buffers for CUDA graphs.
         with set_forward_context(attn_metadata, self.vllm_config):
+            positions = self.mrope_positions[:, :num_input_tokens] \
+                if self.model_config.uses_mrope \
+                else self.positions[:num_input_tokens]
             hidden_states = self.model(
                 input_ids=input_ids,
-                positions=self.positions[:num_input_tokens],
+                positions=positions,
                 kv_caches=self.kv_caches,
                 attn_metadata=None,
                 inputs_embeds=inputs_embeds,
@@ -707,9 +834,12 @@ class GPUModelRunner:
             input_ids = self.input_ids[:num_tokens]
             inputs_embeds = None
         with set_forward_context(None, self.vllm_config):
+            positions = self.mrope_positions[:, :num_tokens] \
+                if self.model_config.uses_mrope \
+                else self.positions[:num_tokens]
             hidden_states = model(
                 input_ids=input_ids,
-                positions=self.positions[:num_tokens],
+                positions=positions,
                 kv_caches=kv_caches,
                 attn_metadata=None,
                 inputs_embeds=inputs_embeds,