[V1] Support VLMs with fine-grained scheduling (#9871)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: Roger Wang <ywang@roblox.com>
Author: Woosuk Kwon, 2024-11-12 20:53:13 -08:00 (committed by GitHub)
commit bbd3e86926 (parent 0d4ea3fb5c)
12 changed files with 542 additions and 96 deletions


@@ -216,9 +216,11 @@ class GPT2Model(nn.Module):
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
+        inputs_embeds: Optional[torch.Tensor],
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if get_pp_group().is_first_rank:
-            inputs_embeds = self.wte(input_ids)
+            if inputs_embeds is None:
+                inputs_embeds = self.wte(input_ids)
             position_embeds = self.wpe(position_ids)
             hidden_states = inputs_embeds + position_embeds
         else:
@@ -263,6 +265,9 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
         self.make_empty_intermediate_tensors = (
             self.transformer.make_empty_intermediate_tensors)
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.transformer.wte(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -270,9 +275,11 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
-                                         attn_metadata, intermediate_tensors)
+                                         attn_metadata, intermediate_tensors,
+                                         inputs_embeds)
         return hidden_states
 
     def compute_logits(
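
The same two additions recur in every decoder-only language model touched by this commit (Llama, OPT, Qwen2 below): a get_input_embeddings() helper and an optional inputs_embeds argument that bypasses the embedding lookup. A minimal sketch of the two call paths, assuming model, input_ids, positions, kv_caches, and attn_metadata are set up the way the V1 runner does (illustration only, not part of the diff):

# Old path: the model embeds the token ids internally.
hidden_states = model(input_ids, positions, kv_caches, attn_metadata)

# New path: the caller embeds the ids (optionally mixing in vision
# embeddings) and feeds embeddings directly; input_ids is then unused.
inputs_embeds = model.get_input_embeddings(input_ids)
hidden_states = model(None, positions, kv_caches, attn_metadata,
                      inputs_embeds=inputs_embeds)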


@@ -538,6 +538,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
                                                    normalize=False,
                                                    softmax=False)
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -545,9 +548,11 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         model_output = self.model(input_ids, positions, kv_caches,
-                                  attn_metadata, intermediate_tensors)
+                                  attn_metadata, intermediate_tensors,
+                                  inputs_embeds)
         return model_output
 
     def compute_logits(


@@ -17,6 +17,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.base import NestedTensors
 from vllm.sequence import IntermediateTensors
 from vllm.utils import is_list_of
 
@@ -448,6 +449,25 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         image_features = self._process_image_pixels(image_input)
         return self.multi_modal_projector(image_features)
 
+    def process_mm_inputs(self, **kwargs):
+        image_input = self._parse_and_validate_image_input(**kwargs)
+        if image_input is None:
+            return None
+        vision_embeddings = self._process_image_input(image_input)
+        return vision_embeddings
+
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        vision_embeddings: Optional[NestedTensors] = None,
+    ) -> torch.Tensor:
+        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
+        if vision_embeddings is not None:
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids, inputs_embeds, vision_embeddings,
+                self.config.image_token_index)
+        return inputs_embeds
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -455,6 +475,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs: object,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         """Run forward pass for LLaVA-1.5.
@@ -494,24 +515,13 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         """
         if intermediate_tensors is not None:
             inputs_embeds = None
-        else:
-            image_input = self._parse_and_validate_image_input(**kwargs)
-            if image_input is not None:
-                vision_embeddings = self._process_image_input(image_input)
-                inputs_embeds = self.language_model.model.get_input_embeddings(
-                    input_ids)
-
-                inputs_embeds = merge_multimodal_embeddings(
-                    input_ids, inputs_embeds, vision_embeddings,
-                    self.config.image_token_index)
-            else:
-                inputs_embeds = self.language_model.model.get_input_embeddings(
-                    input_ids)
-
-            # always pass the input via `inputs_embeds`
-            # to make sure the computation graph is consistent
-            # for `torch.compile` integration
-            input_ids = None
+        elif inputs_embeds is None:
+            vision_embeddings = self.process_mm_inputs(**kwargs)
+            # always pass the input via `inputs_embeds`
+            # to make sure the computation graph is consistent
+            inputs_embeds = self.get_input_embeddings(input_ids,
+                                                      vision_embeddings)
+            input_ids = None
 
         hidden_states = self.language_model.model(input_ids,
                                                   positions,
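
With these hooks the multimodal forward pass is split into pieces the V1 scheduler can invoke separately. A rough sketch of how the runner is expected to drive LLaVA (the pixel_values keyword is an assumption based on the image input mapper; positions, kv_caches, and attn_metadata are placeholders):

# Encoder pass: may be scheduled in a different step than the decoder tokens.
vision_embeddings = model.process_mm_inputs(pixel_values=pixel_values)

# Merge text-token embeddings with image embeddings at the <image> positions.
inputs_embeds = model.get_input_embeddings(input_ids, vision_embeddings)

# Decoder pass: only embeddings are fed in, so the computation graph is the
# same whether or not this step contains image tokens.
hidden_states = model(input_ids=None, positions=positions,
                      kv_caches=kv_caches, attn_metadata=attn_metadata,
                      inputs_embeds=inputs_embeds)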


@@ -360,6 +360,9 @@ class OPTForCausalLM(nn.Module, SupportsPP):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -367,9 +370,11 @@ class OPTForCausalLM(nn.Module, SupportsPP):
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors)
+                                   attn_metadata, intermediate_tensors,
+                                   inputs_embeds)
         return hidden_states
 
     def compute_logits(


@@ -39,6 +39,7 @@ from vllm.model_executor.models.llama import LlamaForCausalLM
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.base import NestedTensors, PlaceholderRange
 from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token
 from vllm.sequence import IntermediateTensors, PoolerOutput
 from vllm.utils import is_list_of
@@ -500,15 +501,20 @@ def input_processor_for_phi3v(ctx: InputContext,
     # TODO: Move this to utils or integrate with clip.
     new_token_ids: List[int] = []
+    placeholder_ranges: List[PlaceholderRange] = []
     placeholder_idx = 0
     while merged_token_ids:
         token_id = merged_token_ids.pop(0)
         if token_id == _IMAGE_TOKEN_ID:
-            new_token_ids.extend(
-                repeat_and_pad_token(
-                    _IMAGE_TOKEN_ID,
-                    repeat_count=image_feature_size[placeholder_idx],
-                ))
+            replacement_ids = repeat_and_pad_token(
+                _IMAGE_TOKEN_ID,
+                repeat_count=image_feature_size[placeholder_idx],
+            )
+            placeholder_ranges.append({
+                "offset": len(new_token_ids),
+                "length": len(replacement_ids)
+            })
+            new_token_ids.extend(replacement_ids)
             placeholder_idx += 1
         else:
             new_token_ids.append(token_id)
 
@@ -516,7 +522,8 @@ def input_processor_for_phi3v(ctx: InputContext,
     # NOTE: Create a defensive copy of the original inputs
     return token_inputs(prompt_token_ids=new_token_ids,
                         prompt=new_prompt,
-                        multi_modal_data=multi_modal_data)
+                        multi_modal_data=multi_modal_data,
+                        multi_modal_placeholders={"image": placeholder_ranges})
 
 
 @MULTIMODAL_REGISTRY.register_image_input_mapper()
@@ -669,32 +676,42 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
 
         return image_embeds
 
+    def process_mm_inputs(self, **kwargs):
+        image_input = self._parse_and_validate_image_input(**kwargs)
+        if image_input is None:
+            return None
+        vision_embeddings = self._process_image_input(image_input)
+        return vision_embeddings
+
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        vision_embeddings: Optional[NestedTensors] = None,
+    ) -> torch.Tensor:
+        inputs_embeds = self.embed_tokens(input_ids)
+        if vision_embeddings is not None:
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids, inputs_embeds, vision_embeddings,
+                self.image_token_id)
+        return inputs_embeds
+
     def forward(self,
                 input_ids: torch.Tensor,
                 positions: torch.Tensor,
                 kv_caches: List[torch.Tensor],
                 attn_metadata: AttentionMetadata,
                 intermediate_tensors: Optional[IntermediateTensors] = None,
+                inputs_embeds: Optional[torch.Tensor] = None,
                 **kwargs: object):
         if intermediate_tensors is not None:
             inputs_embeds = None
-        else:
-            image_input = self._parse_and_validate_image_input(**kwargs)
-
-            if image_input is not None:
-                vision_embeddings = self._process_image_input(image_input)
-                inputs_embeds = self.embed_tokens(input_ids)
-                inputs_embeds = merge_multimodal_embeddings(
-                    input_ids, inputs_embeds, vision_embeddings,
-                    self.image_token_id)
-            else:
-                inputs_embeds = self.language_model.model.embed_tokens(
-                    input_ids)
-
-            # always pass the input via `inputs_embeds`
-            # to make sure the computation graph is consistent
-            # for `torch.compile` integration
-            input_ids = None
+        elif inputs_embeds is None:
+            vision_embeddings = self.process_mm_inputs(**kwargs)
+            # always pass the input via `inputs_embeds`
+            # to make sure the computation graph is consistent
+            inputs_embeds = self.get_input_embeddings(input_ids,
+                                                      vision_embeddings)
+            input_ids = None
 
         hidden_states = self.language_model.model(input_ids,
                                                   positions,
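
To make the placeholder bookkeeping concrete, here is a toy, self-contained version of the range computation in input_processor_for_phi3v above (the token ids and feature sizes are made up; the real ones come from the tokenizer and image processor):

# Toy re-implementation of the placeholder-range bookkeeping.
image_token_id, feature_sizes = -1, [4, 6]   # assumed values
prompt = [101, -1, 102, -1, 103]             # -1 marks an image placeholder

new_token_ids, placeholder_ranges = [], []
placeholder_idx = 0
for tok in prompt:
    if tok == image_token_id:
        repl = [image_token_id] * feature_sizes[placeholder_idx]
        placeholder_ranges.append({"offset": len(new_token_ids),
                                   "length": len(repl)})
        new_token_ids.extend(repl)
        placeholder_idx += 1
    else:
        new_token_ids.append(tok)

# placeholder_ranges == [{"offset": 1, "length": 4}, {"offset": 6, "length": 6}]
# These ranges become Request.mm_positions and drive the encoder scheduling.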


@@ -441,6 +441,9 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -448,9 +451,11 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors)
+                                   attn_metadata, intermediate_tensors,
+                                   inputs_embeds)
         return hidden_states
 
     def compute_logits(


@@ -0,0 +1,48 @@
+from typing import Dict, List, Set, Tuple
+
+from vllm.v1.request import Request
+
+
+class EncoderCacheManager:
+
+    def __init__(self, cache_size: int):
+        self.cache_size = cache_size
+        self.num_free_slots = cache_size
+        # req_id -> cached input ids
+        self.cached: Dict[str, Set[int]] = {}
+        # List of [req_id, input_id]
+        self.freed: List[Tuple[str, int]] = []
+
+    def has_cache(self, request: Request, input_id: int) -> bool:
+        req_id = request.request_id
+        return req_id in self.cached and input_id in self.cached[req_id]
+
+    def can_allocate(self, request: Request, input_id: int) -> bool:
+        num_tokens = request.get_num_encoder_tokens(input_id)
+        return num_tokens <= self.num_free_slots
+
+    def allocate(self, request: Request, input_id: int) -> None:
+        req_id = request.request_id
+        if req_id not in self.cached:
+            self.cached[req_id] = set()
+        self.cached[req_id].add(input_id)
+        self.num_free_slots -= request.get_num_encoder_tokens(input_id)
+
+    def get_cached_input_ids(self, request: Request) -> Set[int]:
+        return self.cached.get(request.request_id, set())
+
+    def free(self, request: Request, input_id: int) -> None:
+        req_id = request.request_id
+        if req_id not in self.cached:
+            return
+
+        self.cached[req_id].discard(input_id)
+        if len(self.cached[req_id]) == 0:
+            del self.cached[req_id]
+
+        self.num_free_slots += request.get_num_encoder_tokens(input_id)
+        self.freed.append((req_id, input_id))
+
+    def get_freed_ids(self) -> List[Tuple[str, int]]:
+        freed = self.freed
+        self.freed = []
+        return freed
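
A small usage sketch of the manager's accounting (the FakeRequest stub stands in for vllm.v1.request.Request, which only needs request_id and get_num_encoder_tokens() here; the numbers are made up):

class FakeRequest:
    def __init__(self, request_id, encoder_token_lens):
        self.request_id = request_id
        self._lens = encoder_token_lens

    def get_num_encoder_tokens(self, input_id):
        return self._lens[input_id]

manager = EncoderCacheManager(cache_size=1000)
req = FakeRequest("req-0", [576, 576])       # two images, 576 tokens each

assert manager.can_allocate(req, 0)
manager.allocate(req, 0)                     # 424 free slots remain
assert not manager.can_allocate(req, 1)      # the second image does not fit

manager.free(req, 0)                         # decoder has consumed image 0
assert manager.get_freed_ids() == [("req-0", 0)]
assert manager.can_allocate(req, 1)          # space is available again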


@@ -1,16 +1,21 @@
 from collections import deque
 from dataclasses import dataclass
-from typing import Deque, Dict, Iterable, List, Optional, Set, Union
+from typing import (TYPE_CHECKING, Deque, Dict, Iterable, List, Optional, Set,
+                    Tuple, Union)
 
 from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
 from vllm.logger import init_logger
-from vllm.multimodal import MultiModalDataDict
 from vllm.sampling_params import SamplingParams
+from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
 from vllm.v1.core.kv_cache_manager import KVCacheManager
 from vllm.v1.engine import EngineCoreOutput
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
 
+if TYPE_CHECKING:
+    from vllm.multimodal import MultiModalKwargs
+    from vllm.multimodal.base import PlaceholderRange
+
 logger = init_logger(__name__)
@@ -61,12 +66,20 @@ class Scheduler:
         # Request id -> RunningRequestData
         self.running_reqs_data: Dict[str, RunningRequestData] = {}
 
-    def schedule(self) -> "SchedulerOutput":
-        scheduled_new_reqs: List[Request] = []
-        scheduled_resumed_reqs: List[Request] = []
-        scheduled_running_reqs: List[Request] = []
-        preempted_reqs: List[Request] = []
+        # Encoder-related.
+        # NOTE(woosuk): Here, "encoder" includes the vision encoder (and
+        # projector if needed). Currently, we assume that the encoder also
+        # has the Transformer architecture (e.g., ViT).
+        # FIXME(woosuk): Below are placeholder values. We need to calculate the
+        # actual values from the configurations.
+        self.max_num_encoder_input_tokens = 2048
+        # NOTE(woosuk): For the models without encoder (e.g., text-only models),
+        # the encoder cache will not be initialized and used, regardless of
+        # the cache size. This is because the memory space for the encoder cache
+        # is preallocated in the profiling run.
+        self.encoder_cache_manager = EncoderCacheManager(cache_size=2048)
 
+    def schedule(self) -> "SchedulerOutput":
         # NOTE(woosuk) on the scheduling algorithm:
         # There's no "decoding phase" nor "prefill phase" in the scheduler.
         # Each request just has the num_computed_tokens and num_tokens,
@@ -74,23 +87,45 @@ class Scheduler:
         # At each step, the scheduler tries to assign tokens to the requests
         # so that each request's num_computed_tokens can catch up its
         # num_tokens. This is general enough to cover chunked prefills,
-        # prefix caching, and the "jump forward" optimization in the future.
+        # prefix caching, and the "jump decoding" optimization in the future.
+        scheduled_new_reqs: List[Request] = []
+        scheduled_resumed_reqs: List[Request] = []
+        scheduled_running_reqs: List[Request] = []
+        preempted_reqs: List[Request] = []
+
         req_to_new_block_ids: Dict[str, List[int]] = {}
         num_scheduled_tokens: Dict[str, int] = {}
         token_budget = self.max_num_scheduled_tokens
+        # Encoder-related.
+        scheduled_encoder_inputs: Dict[str, List[int]] = {}
+        encoder_budget = self.max_num_encoder_input_tokens
 
         # First, schedule the RUNNING requests.
+        # NOTE(woosuk): At most 1 request in the RUNNING queue is allowed to be
+        # in the "partial" state, where the request has some tokens computed
+        # but not all. The constraint is due to the persistent batch in the
+        # V1 model runner.
+        # TODO(woosuk): Remove this constraint after refactoring model runner.
+        has_partial_request = False
         req_index = 0
         while req_index < len(self.running):
-            if token_budget == 0:
-                break
+            # Only the last request in the RUNNING queue can be "partial".
+            assert not has_partial_request
+            assert token_budget > 0
             request = self.running[req_index]
             num_new_tokens = request.num_tokens - request.num_computed_tokens
             num_new_tokens = min(num_new_tokens, token_budget)
             assert num_new_tokens > 0
+
+            # Schedule encoder inputs.
+            encoder_inputs_to_schedule, num_new_tokens, new_encoder_budget = (
+                self._try_schedule_encoder_inputs(request,
+                                                  request.num_computed_tokens,
+                                                  num_new_tokens,
+                                                  encoder_budget))
+            assert num_new_tokens > 0
+
             while True:
                 new_blocks = self.kv_cache_manager.append_slots(
                     request, num_new_tokens)
@@ -106,22 +141,40 @@ class Scheduler:
                     preempted_reqs.append(preempted_req)
                     if preempted_req == request:
                         # No more request to preempt.
+                        can_schedule = False
                         break
                 else:
                     # The request can be scheduled.
-                    scheduled_running_reqs.append(request)
-
-                    req_to_new_block_ids[request.request_id] = [
-                        b.block_id for b in new_blocks
-                    ]
-                    num_scheduled_tokens[request.request_id] = num_new_tokens
-                    token_budget -= num_new_tokens
-                    req_index += 1
+                    can_schedule = True
                     break
+            if not can_schedule:
+                break
+
+            # Schedule the request.
+            scheduled_running_reqs.append(request)
+            req_to_new_block_ids[request.request_id] = [
+                b.block_id for b in new_blocks
+            ]
+            num_scheduled_tokens[request.request_id] = num_new_tokens
+            token_budget -= num_new_tokens
+            req_index += 1
+            has_partial_request = (request.num_computed_tokens + num_new_tokens
+                                   < request.num_tokens)
+
+            # Encoder-related.
+            if encoder_inputs_to_schedule:
+                scheduled_encoder_inputs[request.request_id] = (
+                    encoder_inputs_to_schedule)
+                # Allocate the encoder cache.
+                for i in encoder_inputs_to_schedule:
+                    self.encoder_cache_manager.allocate(request, i)
+                encoder_budget = new_encoder_budget
 
         # Next, schedule the WAITING requests.
         if not preempted_reqs:
             while self.waiting:
+                if has_partial_request:
+                    break
                 if len(self.running) == self.max_num_running_reqs:
                     break
                 if token_budget == 0:
@@ -149,12 +202,21 @@ class Scheduler:
                         computed_blocks.pop()
                 num_new_tokens = min(num_new_tokens, token_budget)
                 assert num_new_tokens > 0
+
+                # Schedule encoder inputs.
+                (encoder_inputs_to_schedule, num_new_tokens,
+                 new_encoder_budget) = self._try_schedule_encoder_inputs(
+                     request, num_computed_tokens, num_new_tokens,
+                     encoder_budget)
+                if num_new_tokens == 0:
+                    # The request cannot be scheduled.
+                    break
+
                 new_blocks = self.kv_cache_manager.allocate_slots(
                     request, num_new_tokens, computed_blocks)
                 if new_blocks is None:
                     # The request cannot be scheduled.
                     break
-                request.num_computed_tokens = num_computed_tokens
 
                 self.waiting.popleft()
                 self.running.append(request)
@@ -172,6 +234,18 @@ class Scheduler:
                 num_scheduled_tokens[request.request_id] = num_new_tokens
                 token_budget -= num_new_tokens
                 request.status = RequestStatus.RUNNING
+                request.num_computed_tokens = num_computed_tokens
+                has_partial_request = (num_computed_tokens + num_new_tokens <
+                                       request.num_tokens)
+
+                # Encoder-related.
+                if encoder_inputs_to_schedule:
+                    scheduled_encoder_inputs[request.request_id] = (
+                        encoder_inputs_to_schedule)
+                    # Allocate the encoder cache.
+                    for i in encoder_inputs_to_schedule:
+                        self.encoder_cache_manager.allocate(request, i)
+                    encoder_budget = new_encoder_budget
 
         # Check if the scheduling constraints are satisfied.
         total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
@@ -205,12 +279,14 @@ class Scheduler:
             scheduled_running_reqs=running_reqs_data,
             num_scheduled_tokens=num_scheduled_tokens,
             total_num_scheduled_tokens=total_num_scheduled_tokens,
+            scheduled_encoder_inputs=scheduled_encoder_inputs,
             preempted_req_ids=preempted_req_ids,
             # finished_req_ids is an existing state in the scheduler,
            # instead of being newly scheduled in this step.
             # It contains the request IDs that are finished in between
             # the previous and the current steps.
             finished_req_ids=self.finished_req_ids,
+            free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(),
         )
 
         self.finished_req_ids = set()
@@ -234,6 +310,72 @@ class Scheduler:
             self.running_reqs_data[request.request_id] = req_data
         return req_data
 
+    def _try_schedule_encoder_inputs(
+        self,
+        request: Request,
+        num_computed_tokens: int,
+        num_new_tokens: int,
+        encoder_budget: int,
+    ) -> Tuple[List[int], int, int]:
+        """
+        Determine which encoder inputs need to be scheduled in the current step,
+        and update `num_new_tokens` and encoder token budget accordingly.
+
+        An encoder input will be scheduled if:
+        - Its output tokens overlap with the range of tokens being computed
+          in this step, i.e.,
+          [num_computed_tokens, num_computed_tokens + num_new_tokens).
+        - It is not already computed and stored in the encoder cache.
+        - There is sufficient encoder token budget to process it.
+        - The encoder cache has space to store it.
+
+        If an encoder input cannot be scheduled due to cache or budget
+        limitations, the method adjusts `num_new_tokens` to schedule only the
+        decoder tokens up to just before the unschedulable encoder input.
+        """
+        if not request.has_encoder_inputs():
+            return [], num_new_tokens, encoder_budget
+
+        encoder_inputs_to_schedule: List[int] = []
+        mm_positions = request.mm_positions
+        assert mm_positions is not None
+        assert len(mm_positions) > 0
+        for i, pos_info in enumerate(mm_positions):
+            start_pos = pos_info["offset"]
+            num_encoder_tokens = pos_info["length"]
+
+            # The encoder output is needed if the two ranges overlap:
+            # [num_computed_tokens, num_computed_tokens + num_new_tokens) and
+            # [start_pos, start_pos + num_encoder_tokens)
+            if start_pos >= num_computed_tokens + num_new_tokens:
+                # The encoder input is not needed in this step.
+                break
+            if start_pos + num_encoder_tokens <= num_computed_tokens:
+                # The encoder input is already computed and stored
+                # in the decoder's KV cache.
+                continue
+
+            if self.encoder_cache_manager.has_cache(request, i):
+                # The encoder input is already computed and cached.
+                continue
+            if not self.encoder_cache_manager.can_allocate(request, i):
+                # The encoder cache is full. We can only schedule the decoder
+                # tokens just before the encoder input.
+                num_new_tokens = start_pos - num_computed_tokens
+                break
+            if num_encoder_tokens > encoder_budget:
+                # The encoder budget is exhausted. We can only schedule the
+                # decoder tokens up until the encoder input.
+                # NOTE(woosuk): We assume that the encoder tokens should be
+                # processed altogether, as the encoder usually uses
+                # bidirectional attention.
+                num_new_tokens = start_pos - num_computed_tokens
+                break
+
+            encoder_budget -= num_encoder_tokens
+            encoder_inputs_to_schedule.append(i)
+        return encoder_inputs_to_schedule, num_new_tokens, encoder_budget
+
     def update_from_output(
         self,
         scheduler_output: "SchedulerOutput",
@@ -251,6 +393,17 @@ class Scheduler:
             # the request generates output tokens. Otherwise, we ignore the
             # sampler output for the request.
             assert request.num_computed_tokens <= request.num_tokens
+
+            cached_encoder_input_ids = (
+                self.encoder_cache_manager.get_cached_input_ids(request))
+            for input_id in list(cached_encoder_input_ids):
+                start_pos = request.mm_positions[input_id]["offset"]
+                num_tokens = request.mm_positions[input_id]["length"]
+                if start_pos + num_tokens <= request.num_computed_tokens:
+                    # The encoder output is already processed and stored
+                    # in the decoder's KV cache.
+                    self.encoder_cache_manager.free(request, input_id)
+
             if request.num_computed_tokens == request.num_tokens:
                 req_index = model_runner_output.req_id_to_index[req_id]
                 # NOTE(woosuk): Currently, we assume that each request
@@ -355,7 +508,8 @@ class NewRequestData:
     req_id: str
     prompt_token_ids: List[int]
     prompt: Optional[str]
-    multi_modal_data: Optional[MultiModalDataDict]
+    mm_inputs: List["MultiModalKwargs"]
+    mm_positions: List["PlaceholderRange"]
    sampling_params: SamplingParams
     block_ids: List[int]
     num_computed_tokens: int
@@ -369,9 +523,10 @@ class NewRequestData:
    ) -> "NewRequestData":
         return cls(
             req_id=request.request_id,
-            prompt_token_ids=request.inputs["prompt_token_ids"],
-            prompt=request.inputs.get("prompt"),
-            multi_modal_data=request.inputs.get("multi_modal_data"),
+            prompt_token_ids=request.prompt_token_ids,
+            prompt=request.prompt,
+            mm_inputs=request.mm_inputs,
+            mm_positions=request.mm_positions,
             sampling_params=request.sampling_params,
             block_ids=block_ids,
             num_computed_tokens=num_computed_tokens,
@@ -429,6 +584,8 @@ class SchedulerOutput:
 
     num_scheduled_tokens: Dict[str, int]
     total_num_scheduled_tokens: int
+    scheduled_encoder_inputs: Dict[str, List[int]]
 
     preempted_req_ids: Set[str]
     finished_req_ids: Set[str]
+    free_encoder_input_ids: List[Tuple[str, int]]
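
A concrete walk-through of the truncation rule in _try_schedule_encoder_inputs (the numbers are made up): a request has an image placeholder covering tokens [30, 50), nothing computed yet, 64 decoder tokens requested, and only 8 encoder tokens left in the budget. The placeholder overlaps [0, 64) but does not fit the budget, so the step is cut back to the tokens before the image:

num_computed_tokens, num_new_tokens = 0, 64
start_pos, num_encoder_tokens = 30, 20       # placeholder range [30, 50)
encoder_budget = 8                           # not enough for the whole image

overlaps = (start_pos < num_computed_tokens + num_new_tokens
            and start_pos + num_encoder_tokens > num_computed_tokens)
assert overlaps
if num_encoder_tokens > encoder_budget:
    # Schedule only the decoder tokens before the image; the image and the
    # tokens that depend on it wait for a later step with a fresh budget.
    num_new_tokens = start_pos - num_computed_tokens

assert num_new_tokens == 30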


@@ -17,6 +17,7 @@ from vllm.usage.usage_lib import UsageContext
 from vllm.v1.core.scheduler import Scheduler
 from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
                             EngineCoreRequest, EngineCoreRequestType)
+from vllm.v1.engine.mm_input_mapper import MMInputMapper
 from vllm.v1.executor.gpu_executor import GPUExecutor
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.serial_utils import PickleEncoder
@@ -65,6 +66,9 @@ class EngineCore:
         vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
         vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
 
+        # Set up multimodal input mapper (e.g., convert PIL images to tensors).
+        self.mm_input_mapper = MMInputMapper(vllm_config.model_config)
+
         # Setup scheduler.
         self.scheduler = Scheduler(vllm_config.scheduler_config,
                                    vllm_config.cache_config,
@@ -93,6 +97,12 @@ class EngineCore:
         """Add request to the scheduler."""
 
         req = Request.from_engine_core_request(request)
+        # FIXME(woosuk): The input mapping (e.g., PIL images to tensors) may
+        # take 10-50 ms, which can cause a spike in the latency. We should
+        # consider moving this to a separate thread.
+        if req.mm_data:
+            req.mm_inputs = self.mm_input_mapper.process_inputs(
+                req.mm_data, req.mm_processor_kwargs)
         self.scheduler.add_request(req)
 
     def abort_requests(self, request_ids: List[str]):


@@ -0,0 +1,39 @@
+from typing import Any, Dict, List, Optional
+
+from vllm.config import ModelConfig
+from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
+                             MultiModalKwargs, MultiModalRegistry)
+
+
+class MMInputMapper:
+
+    def __init__(
+        self,
+        model_config: ModelConfig,
+        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
+    ):
+        self.mm_registry = mm_registry
+        self.multi_modal_input_mapper = mm_registry.create_input_mapper(
+            model_config)
+        self.mm_registry.init_mm_limits_per_prompt(model_config)
+
+    def process_inputs(
+        self,
+        mm_data: MultiModalDataDict,
+        mm_processor_kwargs: Optional[Dict[str, Any]],
+    ) -> List[MultiModalKwargs]:
+        image_inputs = mm_data["image"]
+        if not isinstance(image_inputs, list):
+            image_inputs = [image_inputs]
+
+        # Process each image input separately so that later we can schedule
+        # them in a fine-grained manner.
+        mm_inputs: List[MultiModalKwargs] = []
+        num_images = len(image_inputs)
+        for i in range(num_images):
+            mm_input = self.multi_modal_input_mapper(
+                {"image": [image_inputs[i]]},
+                mm_processor_kwargs=mm_processor_kwargs,
+            )
+            mm_inputs.append(mm_input)
+        return mm_inputs
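
A rough usage sketch mirroring EngineCore.add_request above (model_config is an assumed vllm.config.ModelConfig for a multimodal checkpoint, and the image files are placeholders):

from PIL import Image

mapper = MMInputMapper(model_config)
mm_data = {"image": [Image.open("a.jpg"), Image.open("b.jpg")]}

# One MultiModalKwargs per image, so the scheduler can later decide which
# image(s) to run through the vision encoder in a given step.
mm_inputs = mapper.process_inputs(mm_data, mm_processor_kwargs=None)
assert len(mm_inputs) == 2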


@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, List, Optional, Union
 
 from vllm.inputs.data import DecoderOnlyInputs
 from vllm.lora.request import LoRARequest
+from vllm.multimodal import MultiModalKwargs
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import RequestMetrics
 from vllm.v1.engine import EngineCoreRequest
@@ -47,14 +48,30 @@ class Request:
         self._all_token_ids: List[int] = self.prompt_token_ids.copy()
         self.num_computed_tokens = 0
 
+        # Raw multimodal data before the mm input mapper (e.g., PIL images).
+        self.mm_data = inputs.get("multi_modal_data")
+        self.mm_processor_kwargs = inputs.get("mm_processor_kwargs")
+        mm_positions = inputs.get("multi_modal_placeholders")
+        if mm_positions:
+            # FIXME(woosuk): Support other modalities.
+            self.mm_positions = mm_positions.get("image", [])
+        else:
+            self.mm_positions = []
+        # Output of the mm input mapper (e.g., image tensors).
+        self.mm_inputs: List[MultiModalKwargs] = []
+
     @classmethod
     def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
         return cls(
             request_id=request.request_id,
-            inputs=DecoderOnlyInputs(type="token",
-                                     prompt_token_ids=request.prompt_token_ids,
-                                     prompt=request.prompt),
+            inputs=DecoderOnlyInputs(
+                type="token",
+                prompt_token_ids=request.prompt_token_ids,
+                prompt=request.prompt,
+                multi_modal_data=request.mm_data,
+                multi_modal_placeholders=request.mm_placeholders,
+                mm_processor_kwargs=request.mm_processor_kwargs,
+            ),
             sampling_params=request.sampling_params,
             eos_token_id=request.eos_token_id,
             arrival_time=request.arrival_time,
@@ -96,9 +113,21 @@ class Request:
     def get_finished_reason(self) -> Union[str, None]:
         return RequestStatus.get_finished_reason(self.status)
 
+    def has_encoder_inputs(self) -> bool:
+        return self.mm_data is not None
+
+    @property
+    def num_encoder_inputs(self) -> int:
+        return len(self.mm_positions)
+
+    def get_num_encoder_tokens(self, input_id: int) -> int:
+        assert input_id < len(self.mm_positions)
+        num_tokens = self.mm_positions[input_id]["length"]
+        return num_tokens
+
 
 class RequestStatus(enum.IntEnum):
-    """Status of a sequence."""
+    """Status of a request."""
     WAITING = 0
     RUNNING = 1
     PREEMPTED = 2
@@ -119,7 +148,7 @@ class RequestStatus(enum.IntEnum):
 
 # Mapping of finished statuses to their finish reasons.
-# NOTE: The ignored sequences are the sequences whose prompt lengths
+# NOTE: The ignored requests are the requests whose prompt lengths
 # are longer than the model's length cap. Therefore, the stop
 # reason should also be "length" as in OpenAI API.
 _FINISHED_REASON_MAP = {


@@ -1,7 +1,7 @@
 import os
 import time
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List, Optional, Set
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
 
 import numpy as np
 import torch
@@ -14,9 +14,10 @@ from vllm.compilation.config import CompilationConfig
 from vllm.compilation.levels import CompilationLevel
 from vllm.config import VllmConfig
 from vllm.forward_context import set_forward_context
+from vllm.inputs import INPUT_REGISTRY, InputRegistry
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
-from vllm.multimodal import MultiModalDataDict
+from vllm.multimodal import MultiModalKwargs
 from vllm.plugins import set_compilation_config
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, cdiv,
@@ -27,6 +28,7 @@ from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
 
 if TYPE_CHECKING:
+    from vllm.multimodal.base import PlaceholderRange
     from vllm.v1.core.scheduler import SchedulerOutput
 
 logger = init_logger(__name__)
@@ -37,8 +39,8 @@ class GPUModelRunner:
     def __init__(
         self,
         vllm_config: VllmConfig,
+        input_registry: InputRegistry = INPUT_REGISTRY,
     ):
-        # TODO: use ModelRunnerBase.__init__(self, vllm_config=vllm_config)
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
         self.cache_config = vllm_config.cache_config
@@ -75,10 +77,16 @@ class GPUModelRunner:
                                                         parallel_config)
         self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
         self.head_size = model_config.get_head_size()
+        self.hidden_size = model_config.get_hidden_size()
+
+        # Multi-modal data support
+        self.input_registry = input_registry
 
         # Lazy initialization
         # self.model: nn.Module  # Set after load_model
         self.kv_caches: List[torch.Tensor] = []
+        # req_id -> (input_id -> encoder_output)
+        self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}
 
         # Request states.
         self.requests: Dict[str, CachedRequestState] = {}
@@ -96,18 +104,28 @@ class GPUModelRunner:
                                and not self.model_config.enforce_eager)
         # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
         self.cudagraph_batch_sizes = [1, 2, 4] + [i for i in range(8, 513, 8)]
-        self.input_ids = torch.zeros(self.max_num_tokens,
-                                     dtype=torch.int32,
-                                     device=self.device)
         self.positions = torch.zeros(self.max_num_tokens,
                                      dtype=torch.int64,
                                      device=self.device)
+        self.inputs_embeds = torch.zeros(
+            (self.max_num_tokens, self.hidden_size),
+            dtype=self.dtype,
+            device=self.device)
 
     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         # Remove stopped requests from the cached states.
         # Keep the states of the pre-empted requests.
         for req_id in scheduler_output.finished_req_ids:
             self.requests.pop(req_id, None)
+            self.encoder_cache.pop(req_id, None)
+
+        # Free the cached encoder outputs.
+        for req_id, input_id in scheduler_output.free_encoder_input_ids:
+            encoder_outputs = self.encoder_cache.get(req_id)
+            if encoder_outputs is not None:
+                encoder_outputs.pop(input_id, None)
+                if not encoder_outputs:
+                    self.encoder_cache.pop(req_id, None)
 
         # Remove the requests from the persistent batch.
         stopped_req_ids = set().union(
@@ -156,7 +174,8 @@ class GPUModelRunner:
                 req_id=req_id,
                 prompt_token_ids=req_data.prompt_token_ids,
                 prompt=req_data.prompt,
-                multi_modal_data=req_data.multi_modal_data,
+                mm_inputs=req_data.mm_inputs,
+                mm_positions=req_data.mm_positions,
                 sampling_params=sampling_params,
                 generator=generator,
                 block_ids=req_data.block_ids,
@@ -285,11 +304,9 @@ class GPUModelRunner:
         seq_start_loc_np[0] = 0
         np.cumsum(seq_lens, out=seq_start_loc_np[1:])
 
-        self.input_ids[:total_num_scheduled_tokens].copy_(input_ids,
-                                                          non_blocking=True)
+        input_ids = input_ids.to(self.device, non_blocking=True)
         self.positions[:total_num_scheduled_tokens].copy_(positions,
                                                           non_blocking=True)
-
         query_start_loc = query_start_loc.to(self.device, non_blocking=True)
         seq_start_loc = seq_start_loc.to(self.device, non_blocking=True)
         slot_mapping = slot_mapping.to(self.device, non_blocking=True).long()
@@ -308,7 +325,7 @@ class GPUModelRunner:
         # token from the partial request.
         # TODO: Support prompt logprobs.
         logits_indices = query_start_loc[1:] - 1
-        return attn_metadata, logits_indices
+        return input_ids, attn_metadata, logits_indices
 
     def _prepare_sampling(
         self,
@@ -325,13 +342,91 @@ class GPUModelRunner:
         sampling_metadata = self.input_batch.make_sampling_metadata(skip_copy)
         return sampling_metadata
 
+    def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
+        scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
+        if not scheduled_encoder_inputs:
+            return
+
+        # Batch the multi-modal inputs.
+        mm_inputs: List[MultiModalKwargs] = []
+        req_input_ids: List[Tuple[int, int]] = []
+        for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
+            req_state = self.requests[req_id]
+            for input_id in encoder_input_ids:
+                mm_inputs.append(req_state.mm_inputs[input_id])
+                req_input_ids.append((req_id, input_id))
+        batched_mm_inputs = MultiModalKwargs.batch(mm_inputs)
+        batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
+                                                       device=self.device)
+
+        # Run the encoder.
+        # `encoder_outputs` is either of the following:
+        # 1. A tensor of shape [num_images, feature_size, hidden_size]
+        # in case when feature_size is fixed across all images.
+        # 2. A list (length: num_images) of tensors, each of shape
+        # [feature_size, hidden_size] in case when the feature size is
+        # dynamic depending on input images.
+        encoder_outputs = self.model.process_mm_inputs(**batched_mm_inputs)
+
+        # Cache the encoder outputs.
+        for (req_id, input_id), output in zip(req_input_ids, encoder_outputs):
+            if req_id not in self.encoder_cache:
+                self.encoder_cache[req_id] = {}
+            self.encoder_cache[req_id][input_id] = output
+
+    def _gather_encoder_outputs(
+        self,
+        scheduler_output: "SchedulerOutput",
+    ) -> List[torch.Tensor]:
+        encoder_outputs: List[torch.Tensor] = []
+        num_reqs = self.input_batch.num_reqs
+        for req_id in self.input_batch.req_ids[:num_reqs]:
+            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
+                req_id]
+            req_state = self.requests[req_id]
+            num_computed_tokens = req_state.num_computed_tokens
+            mm_positions = req_state.mm_positions
+            for i, pos_info in enumerate(mm_positions):
+                start_pos = pos_info["offset"]
+                num_encoder_tokens = pos_info["length"]
+
+                # The encoder output is needed if the two ranges overlap:
+                # [num_computed_tokens,
+                #  num_computed_tokens + num_scheduled_tokens) and
+                # [start_pos, start_pos + num_encoder_tokens)
+                if start_pos >= num_computed_tokens + num_scheduled_tokens:
+                    # The encoder output is not needed in this step.
+                    break
+                if start_pos + num_encoder_tokens <= num_computed_tokens:
+                    # The encoder output is already processed and stored
+                    # in the decoder's KV cache.
+                    continue
+
+                start_idx = max(num_computed_tokens - start_pos, 0)
+                end_idx = min(
+                    num_computed_tokens - start_pos + num_scheduled_tokens,
+                    num_encoder_tokens)
+                assert start_idx < end_idx
+                assert req_id in self.encoder_cache
+                assert i in self.encoder_cache[req_id]
+                encoder_output = self.encoder_cache[req_id][i]
+                encoder_outputs.append(encoder_output[start_idx:end_idx])
+        return encoder_outputs
+
     @torch.inference_mode()
     def execute_model(
         self,
         scheduler_output: "SchedulerOutput",
     ) -> ModelRunnerOutput:
         self._update_states(scheduler_output)
-        attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
+
+        # Run the encoder.
+        self._execute_encoder(scheduler_output)
+        encoder_outputs = self._gather_encoder_outputs(scheduler_output)
+
+        # Prepare the decoder inputs.
+        input_ids, attn_metadata, logits_indices = self._prepare_inputs(
+            scheduler_output)
         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         if (self.use_cuda_graph
                 and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
@@ -343,12 +438,26 @@ class GPUModelRunner:
             # Eager mode.
             num_input_tokens = num_scheduled_tokens
 
+        # Get the inputs embeds.
+        if encoder_outputs:
+            inputs_embeds = self.model.get_input_embeddings(
+                input_ids, encoder_outputs)
+        else:
+            inputs_embeds = self.model.get_input_embeddings(input_ids)
+        # NOTE(woosuk): To unify token ids and soft tokens (vision embeddings),
+        # always use embeddings (rather than token ids) as input to the model.
+        # TODO(woosuk): Avoid the copy. Optimize.
+        self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds)
+
+        # Run the decoder.
+        # Use persistent buffers for CUDA graphs.
         with set_forward_context(attn_metadata):
             hidden_states = self.model(
-                input_ids=self.input_ids[:num_input_tokens],
+                input_ids=None,
                 positions=self.positions[:num_input_tokens],
                 kv_caches=self.kv_caches,
                 attn_metadata=None,
+                inputs_embeds=self.inputs_embeds[:num_input_tokens],
             )
         hidden_states = hidden_states[:num_scheduled_tokens]
         hidden_states = hidden_states[logits_indices]
@@ -440,13 +549,16 @@ class GPUModelRunner:
         with set_forward_context(None):  # noqa: SIM117
             with set_compile_context(self.cudagraph_batch_sizes):
                 # Trigger compilation for general shape.
-                model(self.input_ids,
-                      self.positions,
-                      dummy_kv_caches,
-                      attn_metadata=None)
+                model(input_ids=None,
+                      positions=self.positions,
+                      kv_caches=dummy_kv_caches,
+                      attn_metadata=None,
+                      inputs_embeds=self.inputs_embeds)
 
     @torch.inference_mode()
     def profile_run(self) -> None:
+        # TODO(woosuk): Profile the max memory usage of the encoder and
+        # the encoder cache.
         self._dummy_run(self.model, self.max_num_tokens)
         torch.cuda.synchronize()
 
@@ -468,10 +580,11 @@ class GPUModelRunner:
             # can reuse the memory pool allocated for the large shapes.
             for num_tokens in reversed(self.cudagraph_batch_sizes):
                 self.model(
-                    self.input_ids[:num_tokens],
-                    self.positions[:num_tokens],
+                    input_ids=None,
+                    positions=self.positions[:num_tokens],
                     kv_caches=self.kv_caches,
                     attn_metadata=None,
+                    inputs_embeds=self.inputs_embeds[:num_tokens],
                 )
         end_time = time.perf_counter()
 
@@ -506,7 +619,8 @@ class CachedRequestState:
 
     req_id: str
     prompt_token_ids: List[int]
    prompt: Optional[str]
-    multi_modal_data: Optional["MultiModalDataDict"]
+    mm_inputs: List[MultiModalKwargs]
+    mm_positions: List["PlaceholderRange"]
     sampling_params: SamplingParams
     generator: Optional[torch.Generator]