[V1][TPU] TPU multimodal model support for ragged attention (#14158)
Signed-off-by: Michael Goin <mgoin64@gmail.com>
parent 3e1d223626
commit fbfc3ee37e
@@ -15,14 +15,18 @@ from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.layer import Attention
 from vllm.config import VllmConfig
 from vllm.forward_context import get_forward_context, set_forward_context
+from vllm.inputs import INPUT_REGISTRY
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
+from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
+from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.sampling_params import SamplingType
 from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
 from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
                                                NUM_QUERIES_PER_BLOCK,
                                                PallasAttentionBackend,
                                                PallasMetadata)
+from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
 from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
@@ -72,8 +76,10 @@ class TPUModelRunner:
         self.block_size = cache_config.block_size
         self.max_model_len = model_config.max_model_len
         self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
-        self.max_num_tokens = scheduler_config.max_num_batched_tokens
-        self.max_num_reqs = scheduler_config.max_num_seqs
+        self.max_num_tokens = _get_padded_number(
+            scheduler_config.max_num_batched_tokens, NUM_QUERIES_PER_BLOCK)
+        self.max_num_reqs = _get_padded_number(scheduler_config.max_num_seqs,
+                                               NUM_QUERIES_PER_BLOCK)

         # Model-related.
         self.num_attn_layers = model_config.get_num_layers_by_block_type(
@@ -84,6 +90,28 @@ class TPUModelRunner:
         self.head_size = model_config.get_head_size()
         self.hidden_size = model_config.get_hidden_size()

+        # Multi-modal data support
+        self.input_registry = INPUT_REGISTRY
+        self.mm_registry = MULTIMODAL_REGISTRY
+        self.uses_mrope = model_config.uses_mrope
+        # TODO: Support M-RoPE (e.g, Qwen2-VL)
+        assert not self.uses_mrope, "TPU does not support M-RoPE yet."
+
+        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
+            model_config=model_config,
+            scheduler_config=scheduler_config,
+        )
+        self.max_num_encoder_input_tokens = encoder_compute_budget
+        self.encoder_cache_size = encoder_cache_size
+
+        # Lazy initialization
+        # self.model: nn.Module  # Set after load_model
+        self.kv_caches: list[torch.Tensor] = []
+        # req_id -> (input_id -> encoder_output)
+        self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
+
+        # Request states.
+        self.requests: dict[str, CachedRequestState] = {}
         # Persistent batch.
         self.input_batch = InputBatch(
             max_num_reqs=self.max_num_reqs,
@@ -91,18 +119,9 @@ class TPUModelRunner:
             max_num_blocks_per_req=self.max_num_blocks_per_req,
             device=self.device,
             pin_memory=self.pin_memory,
-            vocab_size=self.model_config.get_vocab_size(),
+            vocab_size=model_config.get_vocab_size(),
         )

-        # Request states.
-        self.requests: dict[str, CachedRequestState] = {}
-
-        # req_id -> (input_id -> encoder_output)
-        self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
-
-        # KV caches for forward pass
-        self.kv_caches: list[tuple[torch.Tensor, torch.Tensor]] = []
-
         # Cached torch/numpy tensor
         # The pytorch tensor and numpy array share the same buffer.
         # Sometimes the numpy op is faster so we create both.
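The cached torch/numpy pair mentioned in the comment above works because a CPU torch tensor and the numpy array returned by its .numpy() call share one buffer, so whichever API is faster for a given update can be used without copies. A minimal sketch of that sharing (names are illustrative, not the runner's actual buffers):

import numpy as np
import torch

# A CPU torch tensor and a numpy array backed by the same memory.
token_ids_cpu = torch.zeros(8, dtype=torch.int32, device="cpu")
token_ids_np = token_ids_cpu.numpy()  # shares the buffer with the torch tensor

token_ids_np[:4] = np.arange(4, dtype=np.int32)  # fast numpy write ...
assert token_ids_cpu[3].item() == 3              # ... visible through torch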
@@ -164,6 +183,7 @@ class TPUModelRunner:
         # Remove finished requests from the cached states.
         for req_id in scheduler_output.finished_req_ids:
             self.requests.pop(req_id, None)
+            self.encoder_cache.pop(req_id, None)

         # Remove the finished requests from the persistent batch.
         # NOTE(woosuk): There could be an edge case where finished_req_ids and
@@ -177,6 +197,14 @@ class TPUModelRunner:
             if req_index is not None:
                 removed_req_indices.append(req_index)

+        # Free the cached encoder outputs.
+        for req_id, input_id in scheduler_output.free_encoder_input_ids:
+            encoder_outputs = self.encoder_cache.get(req_id)
+            if encoder_outputs is not None:
+                encoder_outputs.pop(input_id, None)
+                if not encoder_outputs:
+                    self.encoder_cache.pop(req_id, None)
+
         # Remove the unscheduled requests from the persistent batch.
         # NOTE(woosuk): The unscheduled requests are either preempted requests
         # or running requests that are not scheduled in this step. We remove
@@ -426,6 +454,92 @@ class TPUModelRunner:
         logits_indices = query_start_loc[1:] - 1
         return attn_metadata, logits_indices

+    def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
+        scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
+        if not scheduled_encoder_inputs:
+            return
+
+        # Batch the multi-modal inputs.
+        mm_inputs: list[MultiModalKwargs] = []
+        req_input_ids: list[tuple[str, int]] = []
+        for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
+            req_state = self.requests[req_id]
+            for input_id in encoder_input_ids:
+                mm_inputs.append(req_state.mm_inputs[input_id])
+                req_input_ids.append((req_id, input_id))
+
+        # Batch mm inputs as much as we can: if a request in the batch has
+        # multiple modalities or a different modality than the previous one,
+        # we process it separately to preserve item order.
+        # FIXME(ywang96): This is a hacky way to deal with multiple modalities
+        # in the same batch while still being able to benefit from batching
+        # multimodal inputs. The proper solution should be reordering the
+        # encoder outputs.
+        grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs)
+
+        encoder_outputs = []
+        for grouped_mm_inputs in grouped_mm_inputs_list:
+            batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
+            batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
+                                                           device=self.device)
+
+            # Run the encoder.
+            # `curr_group_outputs` is either of the following:
+            # 1. A tensor of shape (num_items, feature_size, hidden_size)
+            # in case feature_size is fixed across all multimodal items.
+            # 2. A list or tuple (length: num_items) of tensors, each of shape
+            # (feature_size, hidden_size) in case the feature size is dynamic
+            # depending on the input multimodal items.
+            curr_group_outputs = self.model.get_multimodal_embeddings(
+                **batched_mm_inputs)
+
+            for output in curr_group_outputs:
+                encoder_outputs.append(output)
+
+        # Cache the encoder outputs.
+        for (req_id, input_id), output in zip(req_input_ids, encoder_outputs):
+            if req_id not in self.encoder_cache:
+                self.encoder_cache[req_id] = {}
+            self.encoder_cache[req_id][input_id] = output
+
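The grouping step above batches as aggressively as the ordering constraint allows: consecutive inputs of the same modality form one group, and a modality change starts a new group so per-item order is preserved. A rough illustration of that grouping idea, with plain strings standing in for MultiModalKwargs items (this is not the real group_mm_inputs_by_modality implementation):

from itertools import groupby

def group_by_modality_sketch(modalities: list[str]) -> list[list[str]]:
    # Consecutive items of the same modality form one batchable group;
    # a change of modality closes the group to keep item order intact.
    return [list(group) for _, group in groupby(modalities)]

assert group_by_modality_sketch(
    ["image", "image", "audio", "image"]) == [
        ["image", "image"], ["audio"], ["image"]]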
+    def _gather_encoder_outputs(
+        self,
+        scheduler_output: "SchedulerOutput",
+    ) -> list[torch.Tensor]:
+        encoder_outputs: list[torch.Tensor] = []
+        for req_id in self.input_batch.req_ids:
+            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
+                req_id]
+            req_state = self.requests[req_id]
+            num_computed_tokens = req_state.num_computed_tokens
+            mm_positions = req_state.mm_positions
+            for i, pos_info in enumerate(mm_positions):
+                start_pos = pos_info["offset"]
+                num_encoder_tokens = pos_info["length"]
+
+                # The encoder output is needed if the two ranges overlap:
+                # [num_computed_tokens,
+                #  num_computed_tokens + num_scheduled_tokens) and
+                # [start_pos, start_pos + num_encoder_tokens)
+                if start_pos >= num_computed_tokens + num_scheduled_tokens:
+                    # The encoder output is not needed in this step.
+                    break
+                if start_pos + num_encoder_tokens <= num_computed_tokens:
+                    # The encoder output is already processed and stored
+                    # in the decoder's KV cache.
+                    continue
+
+                start_idx = max(num_computed_tokens - start_pos, 0)
+                end_idx = min(
+                    num_computed_tokens - start_pos + num_scheduled_tokens,
+                    num_encoder_tokens)
+                assert start_idx < end_idx
+                assert req_id in self.encoder_cache
+                assert i in self.encoder_cache[req_id]
+                encoder_output = self.encoder_cache[req_id][i]
+                encoder_outputs.append(encoder_output[start_idx:end_idx])
+        return encoder_outputs
+
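The index arithmetic above can be sanity-checked with a worked example: if an image's placeholder tokens occupy prompt positions [10, 30) (start_pos=10, num_encoder_tokens=20), 15 prompt tokens are already computed and 8 more are scheduled this step, then only rows [5, 13) of that image's encoder output are needed, covering prompt positions [15, 23). A self-contained restatement of the overlap logic with example values:

def needed_encoder_slice(num_computed_tokens: int, num_scheduled_tokens: int,
                         start_pos: int, num_encoder_tokens: int):
    # Mirrors the overlap test between [num_computed_tokens,
    # num_computed_tokens + num_scheduled_tokens) and
    # [start_pos, start_pos + num_encoder_tokens).
    if start_pos >= num_computed_tokens + num_scheduled_tokens:
        return None  # not reached yet in this step
    if start_pos + num_encoder_tokens <= num_computed_tokens:
        return None  # already fed into the decoder's KV cache
    start_idx = max(num_computed_tokens - start_pos, 0)
    end_idx = min(num_computed_tokens - start_pos + num_scheduled_tokens,
                  num_encoder_tokens)
    return start_idx, end_idx

assert needed_encoder_slice(15, 8, 10, 20) == (5, 13)  # positions 15..22 -> rows 5..12
assert needed_encoder_slice(0, 8, 10, 20) is None      # image not reached yet
assert needed_encoder_slice(40, 8, 10, 20) is None     # image fully consumed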
     @torch.no_grad()
     def execute_model(
         self,
@@ -434,16 +548,42 @@ class TPUModelRunner:
         # Update cached state
         self._update_states(scheduler_output)

+        if self.is_multimodal_model:
+            # Run the multimodal encoder if any.
+            self._execute_encoder(scheduler_output)
+            encoder_outputs = self._gather_encoder_outputs(scheduler_output)
+        else:
+            encoder_outputs = []
+
         # Prepare inputs
         attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens

+        if self.is_multimodal_model:
+            # NOTE(woosuk): To unify token ids and soft tokens (vision
+            # embeddings), we always use embeddings (rather than token ids)
+            # as input to the multimodal model, even when the input is text.
+            if encoder_outputs:
+                inputs_embeds = self.model.get_input_embeddings(
+                    self.input_ids, encoder_outputs)
+            else:
+                inputs_embeds = self.model.get_input_embeddings(self.input_ids)
+            input_ids = None
+        else:
+            # For text-only models, we use token ids as input.
+            # While it is possible to use embeddings as input just like the
+            # multimodal models, it is not desirable for performance since
+            # then the embedding layer is not included in the CUDA graph.
+            input_ids = self.input_ids
+            inputs_embeds = None
+
         # Run the decoder
         with set_forward_context(attn_metadata, self.vllm_config):
             hidden_states = self.model(
-                token_ids=self.input_ids,
-                position_ids=self.position_ids,
+                input_ids=input_ids,
+                positions=self.position_ids,
                 kv_caches=self.kv_caches,
+                inputs_embeds=inputs_embeds,
             )
         hidden_states = hidden_states[:total_num_scheduled_tokens]
         num_reqs = self.input_batch.num_reqs
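The NOTE above is the central contract for multimodal inputs on this path: the runner always feeds the decoder a [num_tokens, hidden_size] embedding tensor, with encoder outputs spliced in at the placeholder positions, so text-only and mixed batches look identical to the compiled graph. A conceptual sketch of such a splice; the real merge happens inside the model's get_input_embeddings, and the helper below is only illustrative:

import torch

def merge_embeddings_sketch(text_embeds: torch.Tensor,
                            encoder_chunks: list[torch.Tensor],
                            offsets: list[int]) -> torch.Tensor:
    # text_embeds: [num_tokens, hidden_size] from the token embedding table.
    # Each encoder chunk overwrites the rows reserved by its placeholder tokens.
    merged = text_embeds.clone()
    for chunk, offset in zip(encoder_chunks, offsets):
        merged[offset:offset + chunk.shape[0]] = chunk
    return merged

hidden_size = 4
text_embeds = torch.zeros(10, hidden_size)
image_embeds = torch.ones(3, hidden_size)  # one image, 3 feature rows
merged = merge_embeddings_sketch(text_embeds, [image_embeds], offsets=[5])
assert torch.equal(merged[5:8], image_embeds)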
@@ -538,14 +678,21 @@ class TPUModelRunner:
                                   fullgraph=True,
                                   dynamic=False)

-    def dummy_run(
+    def _dummy_run(
         self,
         kv_caches,
         num_tokens: int,
     ) -> None:
-        input_ids = torch.zeros(num_tokens,
-                                dtype=torch.int32,
-                                device=self.device)
+        if self.is_multimodal_model:
+            input_ids = None
+            inputs_embeds = torch.zeros((num_tokens, self.hidden_size),
+                                        dtype=self.dtype,
+                                        device=self.device)
+        else:
+            input_ids = torch.zeros((num_tokens),
+                                    dtype=torch.int32,
+                                    device=self.device)
+            inputs_embeds = None
         position_ids = torch.zeros(num_tokens,
                                    dtype=torch.int32,
                                    device=self.device)
@@ -571,7 +718,10 @@ class TPUModelRunner:
             num_seqs=num_tokens,
         )

-        torch._dynamo.mark_dynamic(input_ids, 0)
+        if self.is_multimodal_model:
+            torch._dynamo.mark_dynamic(inputs_embeds, 0)
+        else:
+            torch._dynamo.mark_dynamic(input_ids, 0)
         torch._dynamo.mark_dynamic(position_ids, 0)
         torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
         torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
@@ -580,7 +730,12 @@ class TPUModelRunner:

         with set_forward_context(attn_metadata, self.vllm_config, 0):
             assert self.model is not None
-            self.model(input_ids, position_ids, kv_caches)
+            self.model(
+                input_ids=input_ids,
+                positions=position_ids,
+                kv_caches=kv_caches,
+                inputs_embeds=inputs_embeds,
+            )

     def capture_model(self) -> None:
         """Compile the model."""
@@ -590,11 +745,11 @@ class TPUModelRunner:
         start = time.perf_counter()
         num_tokens = 16
         while True:
-            self.dummy_run(self.kv_caches, num_tokens)
+            self._dummy_run(self.kv_caches, num_tokens)
             logger.info("  -- num_tokens: %d", num_tokens)
             xm.mark_step()
             xm.wait_device_ops()
-            if num_tokens >= self.scheduler_config.max_num_batched_tokens:
+            if num_tokens >= self.max_num_tokens:
                 break
             num_tokens *= 2
         end = time.perf_counter()
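Since num_tokens starts at 16 and doubles until it reaches self.max_num_tokens (which the constructor pads to a multiple of NUM_QUERIES_PER_BLOCK), capture_model warms up one compiled graph per power-of-two bucket; a runtime batch can then presumably be padded up to the nearest precompiled size. A small sketch of the resulting bucket schedule, assuming a padded maximum of 512 tokens purely for illustration:

def precompile_buckets(max_num_tokens: int, start: int = 16) -> list[int]:
    # Mirrors the while-loop above: 16, 32, 64, ... up to the padded maximum.
    buckets = []
    num_tokens = start
    while True:
        buckets.append(num_tokens)
        if num_tokens >= max_num_tokens:
            break
        num_tokens *= 2
    return buckets

assert precompile_buckets(512) == [16, 32, 64, 128, 256, 512]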
@@ -647,17 +802,20 @@ class ModelWrapperV1(nn.Module):

     def forward(
         self,
-        token_ids: torch.Tensor,
-        position_ids: torch.Tensor,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
         kv_caches: list[tuple[torch.Tensor, torch.Tensor]],
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Executes the forward pass of the model and samples the next token.

         Args:
-            token_ids: The input token IDs of shape [num_tokens].
-            position_ids: The input position IDs of shape [num_tokens].
+            input_ids: The input token IDs of shape [num_tokens].
+            positions: The input position IDs of shape [num_tokens].
             kv_caches: The key and value caches. They can be None during the
                 memory profiling at initialization.
+            inputs_embeds: The input embeddings of shape [num_tokens,
+                hidden_size]. It is used for multimodal models.
         """
         # Skip this in memory profiling at initialization.
         if kv_caches[0][0].numel() > 0:
@@ -684,9 +842,9 @@ class ModelWrapperV1(nn.Module):

         assert self.model is not None
         hidden_states = self.model(
-            token_ids,
-            position_ids,
-            kv_caches,
+            input_ids=input_ids,
+            positions=positions,
+            inputs_embeds=inputs_embeds,
         )

         return hidden_states
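After this change the wrapper is always called with keyword arguments and accepts either token ids or precomputed embeddings, matching the execute_model and _dummy_run call sites above. A toy stand-in for that dual-input contract (not the real ModelWrapperV1; shapes and layer choices are illustrative):

from typing import Optional

import torch

class DualInputWrapperSketch(torch.nn.Module):
    """Toy stand-in for the dual text/embedding input contract."""

    def __init__(self, vocab_size: int = 32, hidden_size: int = 8):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, hidden_size)
        self.proj = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self,
                input_ids: Optional[torch.Tensor],
                positions: torch.Tensor,
                inputs_embeds: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Text-only path: embed token ids.  Multimodal path: the caller has
        # already merged text and vision embeddings and passes them directly.
        # (positions is accepted but unused in this toy.)
        hidden = inputs_embeds if inputs_embeds is not None else self.embed(input_ids)
        return self.proj(hidden)

model = DualInputWrapperSketch()
text_out = model(torch.tensor([1, 2, 3]), positions=torch.arange(3))
mm_out = model(None, positions=torch.arange(3), inputs_embeds=torch.zeros(3, 8))
assert text_out.shape == mm_out.shape == (3, 8)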
@@ -699,6 +857,12 @@ class ModelWrapperV1(nn.Module):
         logits = self.model.compute_logits(hidden_states, sampling_metadata)
         return logits

+    def get_multimodal_embeddings(self, *args, **kwargs):
+        return self.model.get_multimodal_embeddings(*args, **kwargs)
+
+    def get_input_embeddings(self, *args, **kwargs):
+        return self.model.get_input_embeddings(*args, **kwargs)
+

 def _get_padded_number(n: int, multiple: int) -> int:
     return ((n + multiple - 1) // multiple) * multiple
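_get_padded_number is plain round-up-to-a-multiple arithmetic, which is why max_num_tokens and max_num_reqs in the constructor become multiples of NUM_QUERIES_PER_BLOCK (a block-size constant imported from the Pallas attention backend; 16 below is only an illustrative value, as are the example configuration numbers):

def _get_padded_number(n: int, multiple: int) -> int:
    return ((n + multiple - 1) // multiple) * multiple

# Round-up behaviour: already-aligned values are unchanged.
assert _get_padded_number(16, 16) == 16
assert _get_padded_number(17, 16) == 32
# e.g. max_num_batched_tokens=8192 stays 8192, max_num_seqs=100 -> 112
assert _get_padded_number(8192, 16) == 8192
assert _get_padded_number(100, 16) == 112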
@@ -134,7 +134,7 @@ class TPUWorker:
             self.vllm_config.compilation_config.static_forward_context,
             runner_kv_caches)

-        self.model_runner.dummy_run(
+        self.model_runner._dummy_run(
             runner_kv_caches,
             num_tokens=self.scheduler_config.max_num_batched_tokens,
         )