[V1][TPU] TPU multimodal model support for ragged attention (#14158)

Signed-off-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin 2025-03-04 19:58:48 -05:00 committed by GitHub
parent 3e1d223626
commit fbfc3ee37e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 194 additions and 30 deletions

View File

@ -15,14 +15,18 @@ from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.config import VllmConfig
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sampling_params import SamplingType
from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
NUM_QUERIES_PER_BLOCK,
PallasAttentionBackend,
PallasMetadata)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
@ -72,8 +76,10 @@ class TPUModelRunner:
self.block_size = cache_config.block_size
self.max_model_len = model_config.max_model_len
self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
self.max_num_tokens = scheduler_config.max_num_batched_tokens
self.max_num_reqs = scheduler_config.max_num_seqs
self.max_num_tokens = _get_padded_number(
scheduler_config.max_num_batched_tokens, NUM_QUERIES_PER_BLOCK)
self.max_num_reqs = _get_padded_number(scheduler_config.max_num_seqs,
NUM_QUERIES_PER_BLOCK)
# Model-related.
self.num_attn_layers = model_config.get_num_layers_by_block_type(
@ -84,6 +90,28 @@ class TPUModelRunner:
self.head_size = model_config.get_head_size()
self.hidden_size = model_config.get_hidden_size()
# Multi-modal data support
self.input_registry = INPUT_REGISTRY
self.mm_registry = MULTIMODAL_REGISTRY
self.uses_mrope = model_config.uses_mrope
# TODO: Support M-RoPE (e.g, Qwen2-VL)
assert not self.uses_mrope, "TPU does not support M-RoPE yet."
encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
model_config=model_config,
scheduler_config=scheduler_config,
)
self.max_num_encoder_input_tokens = encoder_compute_budget
self.encoder_cache_size = encoder_cache_size
# Lazy initialization
# self.model: nn.Module # Set after load_model
self.kv_caches: list[torch.Tensor] = []
# req_id -> (input_id -> encoder_output)
self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
# Request states.
self.requests: dict[str, CachedRequestState] = {}
# Persistent batch.
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
@ -91,18 +119,9 @@ class TPUModelRunner:
max_num_blocks_per_req=self.max_num_blocks_per_req,
device=self.device,
pin_memory=self.pin_memory,
vocab_size=self.model_config.get_vocab_size(),
vocab_size=model_config.get_vocab_size(),
)
# Request states.
self.requests: dict[str, CachedRequestState] = {}
# req_id -> (input_id -> encoder_output)
self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
# KV caches for forward pass
self.kv_caches: list[tuple[torch.Tensor, torch.Tensor]] = []
# Cached torch/numpy tensor
# The pytorch tensor and numpy array share the same buffer.
# Sometimes the numpy op is faster so we create both.
@ -164,6 +183,7 @@ class TPUModelRunner:
# Remove finished requests from the cached states.
for req_id in scheduler_output.finished_req_ids:
self.requests.pop(req_id, None)
self.encoder_cache.pop(req_id, None)
# Remove the finished requests from the persistent batch.
# NOTE(woosuk): There could be an edge case where finished_req_ids and
@ -177,6 +197,14 @@ class TPUModelRunner:
if req_index is not None:
removed_req_indices.append(req_index)
# Free the cached encoder outputs.
for req_id, input_id in scheduler_output.free_encoder_input_ids:
encoder_outputs = self.encoder_cache.get(req_id)
if encoder_outputs is not None:
encoder_outputs.pop(input_id, None)
if not encoder_outputs:
self.encoder_cache.pop(req_id, None)
# Remove the unscheduled requests from the persistent batch.
# NOTE(woosuk): The unscheduled requests are either preempted requests
# or running requests that are not scheduled in this step. We remove
@ -426,6 +454,92 @@ class TPUModelRunner:
logits_indices = query_start_loc[1:] - 1
return attn_metadata, logits_indices
def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
if not scheduled_encoder_inputs:
return
# Batch the multi-modal inputs.
mm_inputs: list[MultiModalKwargs] = []
req_input_ids: list[tuple[str, int]] = []
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
req_state = self.requests[req_id]
for input_id in encoder_input_ids:
mm_inputs.append(req_state.mm_inputs[input_id])
req_input_ids.append((req_id, input_id))
# Batch mm inputs as much as we can: if a request in the batch has
# multiple modalities or a different modality than the previous one,
# we process it separately to preserve item order.
# FIXME(ywang96): This is a hacky way to deal with multiple modalities
# in the same batch while still being able to benefit from batching
# multimodal inputs. The proper solution should be reordering the
# encoder outputs.
grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs)
encoder_outputs = []
for grouped_mm_inputs in grouped_mm_inputs_list:
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
device=self.device)
# Run the encoder.
# `curr_group_outputs` is either of the following:
# 1. A tensor of shape (num_items, feature_size, hidden_size)
# in case feature_size is fixed across all multimodal items.
# 2. A list or tuple (length: num_items) of tensors, each of shape
# (feature_size, hidden_size) in case the feature size is dynamic
# depending on the input multimodal items.
curr_group_outputs = self.model.get_multimodal_embeddings(
**batched_mm_inputs)
for output in curr_group_outputs:
encoder_outputs.append(output)
# Cache the encoder outputs.
for (req_id, input_id), output in zip(req_input_ids, encoder_outputs):
if req_id not in self.encoder_cache:
self.encoder_cache[req_id] = {}
self.encoder_cache[req_id][input_id] = output
def _gather_encoder_outputs(
self,
scheduler_output: "SchedulerOutput",
) -> list[torch.Tensor]:
encoder_outputs: list[torch.Tensor] = []
for req_id in self.input_batch.req_ids:
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
req_id]
req_state = self.requests[req_id]
num_computed_tokens = req_state.num_computed_tokens
mm_positions = req_state.mm_positions
for i, pos_info in enumerate(mm_positions):
start_pos = pos_info["offset"]
num_encoder_tokens = pos_info["length"]
# The encoder output is needed if the two ranges overlap:
# [num_computed_tokens,
# num_computed_tokens + num_scheduled_tokens) and
# [start_pos, start_pos + num_encoder_tokens)
if start_pos >= num_computed_tokens + num_scheduled_tokens:
# The encoder output is not needed in this step.
break
if start_pos + num_encoder_tokens <= num_computed_tokens:
# The encoder output is already processed and stored
# in the decoder's KV cache.
continue
start_idx = max(num_computed_tokens - start_pos, 0)
end_idx = min(
num_computed_tokens - start_pos + num_scheduled_tokens,
num_encoder_tokens)
assert start_idx < end_idx
assert req_id in self.encoder_cache
assert i in self.encoder_cache[req_id]
encoder_output = self.encoder_cache[req_id][i]
encoder_outputs.append(encoder_output[start_idx:end_idx])
return encoder_outputs
@torch.no_grad()
def execute_model(
self,
@ -434,16 +548,42 @@ class TPUModelRunner:
# Update cached state
self._update_states(scheduler_output)
if self.is_multimodal_model:
# Run the multimodal encoder if any.
self._execute_encoder(scheduler_output)
encoder_outputs = self._gather_encoder_outputs(scheduler_output)
else:
encoder_outputs = []
# Prepare inputs
attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if self.is_multimodal_model:
# NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text.
if encoder_outputs:
inputs_embeds = self.model.get_input_embeddings(
self.input_ids, encoder_outputs)
else:
inputs_embeds = self.model.get_input_embeddings(self.input_ids)
input_ids = None
else:
# For text-only models, we use token ids as input.
# While it is possible to use embeddings as input just like the
# multimodal models, it is not desirable for performance since
# then the embedding layer is not included in the CUDA graph.
input_ids = self.input_ids
inputs_embeds = None
# Run the decoder
with set_forward_context(attn_metadata, self.vllm_config):
hidden_states = self.model(
token_ids=self.input_ids,
position_ids=self.position_ids,
input_ids=input_ids,
positions=self.position_ids,
kv_caches=self.kv_caches,
inputs_embeds=inputs_embeds,
)
hidden_states = hidden_states[:total_num_scheduled_tokens]
num_reqs = self.input_batch.num_reqs
@ -538,14 +678,21 @@ class TPUModelRunner:
fullgraph=True,
dynamic=False)
def dummy_run(
def _dummy_run(
self,
kv_caches,
num_tokens: int,
) -> None:
input_ids = torch.zeros(num_tokens,
dtype=torch.int32,
device=self.device)
if self.is_multimodal_model:
input_ids = None
inputs_embeds = torch.zeros((num_tokens, self.hidden_size),
dtype=self.dtype,
device=self.device)
else:
input_ids = torch.zeros((num_tokens),
dtype=torch.int32,
device=self.device)
inputs_embeds = None
position_ids = torch.zeros(num_tokens,
dtype=torch.int32,
device=self.device)
@ -571,7 +718,10 @@ class TPUModelRunner:
num_seqs=num_tokens,
)
torch._dynamo.mark_dynamic(input_ids, 0)
if self.is_multimodal_model:
torch._dynamo.mark_dynamic(inputs_embeds, 0)
else:
torch._dynamo.mark_dynamic(input_ids, 0)
torch._dynamo.mark_dynamic(position_ids, 0)
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
@ -580,7 +730,12 @@ class TPUModelRunner:
with set_forward_context(attn_metadata, self.vllm_config, 0):
assert self.model is not None
self.model(input_ids, position_ids, kv_caches)
self.model(
input_ids=input_ids,
positions=position_ids,
kv_caches=kv_caches,
inputs_embeds=inputs_embeds,
)
def capture_model(self) -> None:
"""Compile the model."""
@ -590,11 +745,11 @@ class TPUModelRunner:
start = time.perf_counter()
num_tokens = 16
while True:
self.dummy_run(self.kv_caches, num_tokens)
self._dummy_run(self.kv_caches, num_tokens)
logger.info(" -- num_tokens: %d", num_tokens)
xm.mark_step()
xm.wait_device_ops()
if num_tokens >= self.scheduler_config.max_num_batched_tokens:
if num_tokens >= self.max_num_tokens:
break
num_tokens *= 2
end = time.perf_counter()
@ -647,17 +802,20 @@ class ModelWrapperV1(nn.Module):
def forward(
self,
token_ids: torch.Tensor,
position_ids: torch.Tensor,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: list[tuple[torch.Tensor, torch.Tensor]],
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Executes the forward pass of the model and samples the next token.
Args:
token_ids: The input token IDs of shape [num_tokens].
position_ids: The input position IDs of shape [num_tokens].
input_ids: The input token IDs of shape [num_tokens].
positions: The input position IDs of shape [num_tokens].
kv_caches: The key and value caches. They can be None during the
memory profiling at initialization.
inputs_embeds: The input embeddings of shape [num_tokens,
hidden_size]. It is used for multimodal models.
"""
# Skip this in memory profiling at initialization.
if kv_caches[0][0].numel() > 0:
@ -684,9 +842,9 @@ class ModelWrapperV1(nn.Module):
assert self.model is not None
hidden_states = self.model(
token_ids,
position_ids,
kv_caches,
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
)
return hidden_states
@ -699,6 +857,12 @@ class ModelWrapperV1(nn.Module):
logits = self.model.compute_logits(hidden_states, sampling_metadata)
return logits
def get_multimodal_embeddings(self, *args, **kwargs):
return self.model.get_multimodal_embeddings(*args, **kwargs)
def get_input_embeddings(self, *args, **kwargs):
return self.model.get_input_embeddings(*args, **kwargs)
def _get_padded_number(n: int, multiple: int) -> int:
return ((n + multiple - 1) // multiple) * multiple

View File

@ -134,7 +134,7 @@ class TPUWorker:
self.vllm_config.compilation_config.static_forward_context,
runner_kv_caches)
self.model_runner.dummy_run(
self.model_runner._dummy_run(
runner_kv_caches,
num_tokens=self.scheduler_config.max_num_batched_tokens,
)