[V1] Support VLMs with fine-grained scheduling (#9871)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: Roger Wang <ywang@roblox.com>

parent 0d4ea3fb5c
commit bbd3e86926
@@ -216,9 +216,11 @@ class GPT2Model(nn.Module):
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
+        inputs_embeds: Optional[torch.Tensor],
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if get_pp_group().is_first_rank:
-            inputs_embeds = self.wte(input_ids)
+            if inputs_embeds is None:
+                inputs_embeds = self.wte(input_ids)
             position_embeds = self.wpe(position_ids)
             hidden_states = inputs_embeds + position_embeds
         else:
@@ -263,6 +265,9 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
         self.make_empty_intermediate_tensors = (
             self.transformer.make_empty_intermediate_tensors)

+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.transformer.wte(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -270,9 +275,11 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
-                                         attn_metadata, intermediate_tensors)
+                                         attn_metadata, intermediate_tensors,
+                                         inputs_embeds)
         return hidden_states

     def compute_logits(
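The same mechanical change is applied to GPT-2 above and to Llama, OPT, and Qwen2 below: the model exposes get_input_embeddings() and its forward() accepts an optional inputs_embeds tensor, embedding input_ids only when no embeddings are supplied. Below is a minimal standalone sketch of that pattern; ToyLM and its sizes are made up for illustration and are not vLLM code.

from typing import Optional

import torch
import torch.nn as nn


class ToyLM(nn.Module):
    """Illustrative stand-in for the decoder-only models touched by this diff."""

    def __init__(self, vocab_size: int = 128, hidden_size: int = 16):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, hidden_size)
        self.layer = nn.Linear(hidden_size, hidden_size)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Embed the token ids only when the caller did not already provide
        # embeddings (e.g., merged text + vision embeddings).
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings(input_ids)
        return self.layer(inputs_embeds)


model = ToyLM()
token_ids = torch.tensor([1, 2, 3])
# Path 1: the model embeds the token ids itself.
out_a = model(input_ids=token_ids)
# Path 2: the caller embeds first and passes inputs_embeds, as the V1 GPU
# model runner does after this change.
embeds = model.get_input_embeddings(token_ids)
out_b = model(input_ids=None, inputs_embeds=embeds)
assert torch.allclose(out_a, out_b)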
@@ -538,6 +538,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
            normalize=False,
            softmax=False)

+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -545,9 +548,11 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         model_output = self.model(input_ids, positions, kv_caches,
-                                  attn_metadata, intermediate_tensors)
+                                  attn_metadata, intermediate_tensors,
+                                  inputs_embeds)
         return model_output

     def compute_logits(
@@ -17,6 +17,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.base import NestedTensors
 from vllm.sequence import IntermediateTensors
 from vllm.utils import is_list_of

@@ -448,6 +449,25 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         image_features = self._process_image_pixels(image_input)
         return self.multi_modal_projector(image_features)

+    def process_mm_inputs(self, **kwargs):
+        image_input = self._parse_and_validate_image_input(**kwargs)
+        if image_input is None:
+            return None
+        vision_embeddings = self._process_image_input(image_input)
+        return vision_embeddings
+
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        vision_embeddings: Optional[NestedTensors] = None,
+    ) -> torch.Tensor:
+        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
+        if vision_embeddings is not None:
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids, inputs_embeds, vision_embeddings,
+                self.config.image_token_index)
+        return inputs_embeds
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -455,6 +475,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs: object,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         """Run forward pass for LLaVA-1.5.
@@ -494,24 +515,13 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         """
         if intermediate_tensors is not None:
             inputs_embeds = None
-        else:
-            image_input = self._parse_and_validate_image_input(**kwargs)
-            if image_input is not None:
-                vision_embeddings = self._process_image_input(image_input)
-                inputs_embeds = self.language_model.model.get_input_embeddings(
-                    input_ids)
-
-                inputs_embeds = merge_multimodal_embeddings(
-                    input_ids, inputs_embeds, vision_embeddings,
-                    self.config.image_token_index)
-            else:
-                inputs_embeds = self.language_model.model.get_input_embeddings(
-                    input_ids)
-
-            # always pass the input via `inputs_embeds`
-            # to make sure the computation graph is consistent
-            # for `torch.compile` integration
-            input_ids = None
+        elif inputs_embeds is None:
+            vision_embeddings = self.process_mm_inputs(**kwargs)
+            # always pass the input via `inputs_embeds`
+            # to make sure the computation graph is consistent
+            inputs_embeds = self.get_input_embeddings(input_ids,
+                                                      vision_embeddings)
+            input_ids = None

         hidden_states = self.language_model.model(input_ids,
                                                   positions,
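The refactored LLaVA forward splits multi-modal processing into process_mm_inputs() and get_input_embeddings(), with merge_multimodal_embeddings() scattering the vision features into the image-placeholder positions before the text decoder runs on inputs_embeds only. Here is a hedged, self-contained sketch of that merge step; merge_embeddings_sketch is a toy re-implementation, not vLLM's merge_multimodal_embeddings.

import torch


def merge_embeddings_sketch(
    input_ids: torch.Tensor,      # [num_tokens]
    inputs_embeds: torch.Tensor,  # [num_tokens, hidden_size]
    vision_embeds: torch.Tensor,  # [num_image_tokens, hidden_size]
    image_token_id: int,
) -> torch.Tensor:
    """Replace the rows at image-placeholder positions with vision features."""
    mask = input_ids == image_token_id
    assert int(mask.sum()) == vision_embeds.shape[0]
    merged = inputs_embeds.clone()
    merged[mask] = vision_embeds
    return merged


hidden_size = 8
image_token_id = 99  # made-up placeholder id
# Prompt: <text> <image placeholder x3> <text>
input_ids = torch.tensor([5, 99, 99, 99, 7])
text_embeds = torch.randn(5, hidden_size)
vision_embeds = torch.randn(3, hidden_size)
merged = merge_embeddings_sketch(input_ids, text_embeds, vision_embeds,
                                 image_token_id)
# The merged tensor is what the language model receives as inputs_embeds,
# with input_ids set to None.
assert torch.equal(merged[1:4], vision_embeds)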
@@ -360,6 +360,9 @@ class OPTForCausalLM(nn.Module, SupportsPP):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)

+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -367,9 +370,11 @@ class OPTForCausalLM(nn.Module, SupportsPP):
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors)
+                                   attn_metadata, intermediate_tensors,
+                                   inputs_embeds)
         return hidden_states

     def compute_logits(
@@ -39,6 +39,7 @@ from vllm.model_executor.models.llama import LlamaForCausalLM
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.base import NestedTensors, PlaceholderRange
 from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token
 from vllm.sequence import IntermediateTensors, PoolerOutput
 from vllm.utils import is_list_of
@@ -500,15 +501,20 @@ def input_processor_for_phi3v(ctx: InputContext,

     # TODO: Move this to utils or integrate with clip.
     new_token_ids: List[int] = []
+    placeholder_ranges: List[PlaceholderRange] = []
     placeholder_idx = 0
     while merged_token_ids:
         token_id = merged_token_ids.pop(0)
         if token_id == _IMAGE_TOKEN_ID:
-            new_token_ids.extend(
-                repeat_and_pad_token(
-                    _IMAGE_TOKEN_ID,
-                    repeat_count=image_feature_size[placeholder_idx],
-                ))
+            replacement_ids = repeat_and_pad_token(
+                _IMAGE_TOKEN_ID,
+                repeat_count=image_feature_size[placeholder_idx],
+            )
+            placeholder_ranges.append({
+                "offset": len(new_token_ids),
+                "length": len(replacement_ids)
+            })
+            new_token_ids.extend(replacement_ids)
             placeholder_idx += 1
         else:
             new_token_ids.append(token_id)
@@ -516,7 +522,8 @@ def input_processor_for_phi3v(ctx: InputContext,
     # NOTE: Create a defensive copy of the original inputs
     return token_inputs(prompt_token_ids=new_token_ids,
                         prompt=new_prompt,
-                        multi_modal_data=multi_modal_data)
+                        multi_modal_data=multi_modal_data,
+                        multi_modal_placeholders={"image": placeholder_ranges})


 @MULTIMODAL_REGISTRY.register_image_input_mapper()
@@ -669,32 +676,42 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):

         return image_embeds

+    def process_mm_inputs(self, **kwargs):
+        image_input = self._parse_and_validate_image_input(**kwargs)
+        if image_input is None:
+            return None
+        vision_embeddings = self._process_image_input(image_input)
+        return vision_embeddings
+
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        vision_embeddings: Optional[NestedTensors] = None,
+    ) -> torch.Tensor:
+        inputs_embeds = self.embed_tokens(input_ids)
+        if vision_embeddings is not None:
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids, inputs_embeds, vision_embeddings,
+                self.image_token_id)
+        return inputs_embeds
+
     def forward(self,
                 input_ids: torch.Tensor,
                 positions: torch.Tensor,
                 kv_caches: List[torch.Tensor],
                 attn_metadata: AttentionMetadata,
                 intermediate_tensors: Optional[IntermediateTensors] = None,
+                inputs_embeds: Optional[torch.Tensor] = None,
                 **kwargs: object):
         if intermediate_tensors is not None:
             inputs_embeds = None
-        else:
-            image_input = self._parse_and_validate_image_input(**kwargs)
-
-            if image_input is not None:
-                vision_embeddings = self._process_image_input(image_input)
-                inputs_embeds = self.embed_tokens(input_ids)
-                inputs_embeds = merge_multimodal_embeddings(
-                    input_ids, inputs_embeds, vision_embeddings,
-                    self.image_token_id)
-            else:
-                inputs_embeds = self.language_model.model.embed_tokens(
-                    input_ids)
-
-            # always pass the input via `inputs_embeds`
-            # to make sure the computation graph is consistent
-            # for `torch.compile` integration
-            input_ids = None
+        elif inputs_embeds is None:
+            vision_embeddings = self.process_mm_inputs(**kwargs)
+            # always pass the input via `inputs_embeds`
+            # to make sure the computation graph is consistent
+            inputs_embeds = self.get_input_embeddings(input_ids,
+                                                      vision_embeddings)
+            input_ids = None

         hidden_states = self.language_model.model(input_ids,
                                                   positions,
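The Phi-3-Vision input processor now records a PlaceholderRange ({"offset", "length"}) for every expanded image placeholder; the V1 scheduler later uses these ranges to decide when each encoder input must run. A rough standalone sketch of that bookkeeping follows; the token ids and repeat counts are made up for illustration.

from typing import Dict, List

IMAGE_TOKEN_ID = -1  # stands in for _IMAGE_TOKEN_ID


def expand_image_tokens(
    token_ids: List[int],
    feature_sizes: List[int],
) -> tuple:
    """Expand each image token and record its offset/length range."""
    new_token_ids: List[int] = []
    placeholder_ranges: List[Dict[str, int]] = []
    image_idx = 0
    for token_id in token_ids:
        if token_id == IMAGE_TOKEN_ID:
            repeat = feature_sizes[image_idx]
            placeholder_ranges.append({
                "offset": len(new_token_ids),
                "length": repeat,
            })
            new_token_ids.extend([IMAGE_TOKEN_ID] * repeat)
            image_idx += 1
        else:
            new_token_ids.append(token_id)
    return new_token_ids, placeholder_ranges


tokens, ranges = expand_image_tokens([1, IMAGE_TOKEN_ID, 2, IMAGE_TOKEN_ID, 3],
                                     feature_sizes=[4, 2])
# Two images: 4 feature tokens starting at offset 1, 2 starting at offset 6.
assert ranges == [{"offset": 1, "length": 4}, {"offset": 6, "length": 2}]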
@@ -441,6 +441,9 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)

+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -448,9 +451,11 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors)
+                                   attn_metadata, intermediate_tensors,
+                                   inputs_embeds)
         return hidden_states

     def compute_logits(
vllm/v1/core/encoder_cache_manager.py (new file, 48 lines)
@@ -0,0 +1,48 @@
+from typing import Dict, List, Set, Tuple
+
+from vllm.v1.request import Request
+
+
+class EncoderCacheManager:
+
+    def __init__(self, cache_size: int):
+        self.cache_size = cache_size
+        self.num_free_slots = cache_size
+        # req_id -> cached input ids
+        self.cached: Dict[str, Set[int]] = {}
+        # List of [req_id, input_id]
+        self.freed: List[Tuple[str, int]] = []
+
+    def has_cache(self, request: Request, input_id: int) -> bool:
+        req_id = request.request_id
+        return req_id in self.cached and input_id in self.cached[req_id]
+
+    def can_allocate(self, request: Request, input_id: int) -> bool:
+        num_tokens = request.get_num_encoder_tokens(input_id)
+        return num_tokens <= self.num_free_slots
+
+    def allocate(self, request: Request, input_id: int) -> None:
+        req_id = request.request_id
+        if req_id not in self.cached:
+            self.cached[req_id] = set()
+        self.cached[req_id].add(input_id)
+        self.num_free_slots -= request.get_num_encoder_tokens(input_id)
+
+    def get_cached_input_ids(self, request: Request) -> Set[int]:
+        return self.cached.get(request.request_id, set())
+
+    def free(self, request: Request, input_id: int) -> None:
+        req_id = request.request_id
+        if req_id not in self.cached:
+            return
+
+        self.cached[req_id].discard(input_id)
+        if len(self.cached[req_id]) == 0:
+            del self.cached[req_id]
+        self.num_free_slots += request.get_num_encoder_tokens(input_id)
+        self.freed.append((req_id, input_id))
+
+    def get_freed_ids(self) -> List[Tuple[str, int]]:
+        freed = self.freed
+        self.freed = []
+        return freed
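A hedged usage sketch of the new EncoderCacheManager, assuming this version of vLLM is importable. The stub request below only provides the two members the manager actually touches (request_id and get_num_encoder_tokens); a real vLLM Request carries much more state.

from vllm.v1.core.encoder_cache_manager import EncoderCacheManager


class _StubRequest:
    """Minimal stand-in exposing only what EncoderCacheManager uses."""

    def __init__(self, request_id: str, encoder_token_counts):
        self.request_id = request_id
        self._counts = encoder_token_counts

    def get_num_encoder_tokens(self, input_id: int) -> int:
        return self._counts[input_id]


manager = EncoderCacheManager(cache_size=1000)
req = _StubRequest("req-0", encoder_token_counts=[576, 576])

# Scheduler side: reserve space before running the vision encoder.
assert manager.can_allocate(req, 0)
manager.allocate(req, 0)
assert manager.has_cache(req, 0)
assert manager.num_free_slots == 1000 - 576

# Once the decoder has consumed the image tokens, the slot is released and
# the (req_id, input_id) pair is reported to the model runner so it can drop
# its cached encoder output as well.
manager.free(req, 0)
assert manager.get_freed_ids() == [("req-0", 0)]
assert manager.num_free_slots == 1000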
@@ -1,16 +1,21 @@
 from collections import deque
 from dataclasses import dataclass
-from typing import Deque, Dict, Iterable, List, Optional, Set, Union
+from typing import (TYPE_CHECKING, Deque, Dict, Iterable, List, Optional, Set,
+                    Tuple, Union)

 from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
 from vllm.logger import init_logger
 from vllm.multimodal import MultiModalDataDict
 from vllm.sampling_params import SamplingParams
+from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
 from vllm.v1.core.kv_cache_manager import KVCacheManager
 from vllm.v1.engine import EngineCoreOutput
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus

+if TYPE_CHECKING:
+    from vllm.multimodal import MultiModalKwargs
+    from vllm.multimodal.base import PlaceholderRange
+
 logger = init_logger(__name__)

@@ -61,12 +66,20 @@ class Scheduler:
         # Request id -> RunningRequestData
         self.running_reqs_data: Dict[str, RunningRequestData] = {}

-    def schedule(self) -> "SchedulerOutput":
-        scheduled_new_reqs: List[Request] = []
-        scheduled_resumed_reqs: List[Request] = []
-        scheduled_running_reqs: List[Request] = []
-        preempted_reqs: List[Request] = []
-
+        # Encoder-related.
+        # NOTE(woosuk): Here, "encoder" includes the vision encoder (and
+        # projector if needed). Currently, we assume that the encoder also
+        # has the Transformer architecture (e.g., ViT).
+        # FIXME(woosuk): Below are placeholder values. We need to calculate the
+        # actual values from the configurations.
+        self.max_num_encoder_input_tokens = 2048
+        # NOTE(woosuk): For the models without encoder (e.g., text-only models),
+        # the encoder cache will not be initialized and used, regardless of
+        # the cache size. This is because the memory space for the encoder cache
+        # is preallocated in the profiling run.
+        self.encoder_cache_manager = EncoderCacheManager(cache_size=2048)
+
+    def schedule(self) -> "SchedulerOutput":
         # NOTE(woosuk) on the scheduling algorithm:
         # There's no "decoding phase" nor "prefill phase" in the scheduler.
         # Each request just has the num_computed_tokens and num_tokens,
@@ -74,23 +87,45 @@
         # At each step, the scheduler tries to assign tokens to the requests
         # so that each request's num_computed_tokens can catch up its
         # num_tokens. This is general enough to cover chunked prefills,
-        # prefix caching, and the "jump forward" optimization in the future.
+        # prefix caching, and the "jump decoding" optimization in the future.

+        scheduled_new_reqs: List[Request] = []
+        scheduled_resumed_reqs: List[Request] = []
+        scheduled_running_reqs: List[Request] = []
+        preempted_reqs: List[Request] = []
+
         req_to_new_block_ids: Dict[str, List[int]] = {}
         num_scheduled_tokens: Dict[str, int] = {}
         token_budget = self.max_num_scheduled_tokens
+        # Encoder-related.
+        scheduled_encoder_inputs: Dict[str, List[int]] = {}
+        encoder_budget = self.max_num_encoder_input_tokens

+        # First, schedule the RUNNING requests.
+        # NOTE(woosuk): At most 1 request in the RUNNING queue is allowed to be
+        # in the "partial" state, where the request has some tokens computed
+        # but not all. The constraint is due to the persistent batch in the
+        # V1 model runner.
+        # TODO(woosuk): Remove this constraint after refactoring model runner.
+        has_partial_request = False
         req_index = 0
         while req_index < len(self.running):
-            if token_budget == 0:
-                break
-
+            # Only the last request in the RUNNING queue can be "partial".
+            assert not has_partial_request
+            assert token_budget > 0
             request = self.running[req_index]
             num_new_tokens = request.num_tokens - request.num_computed_tokens
             num_new_tokens = min(num_new_tokens, token_budget)
             assert num_new_tokens > 0

+            # Schedule encoder inputs.
+            encoder_inputs_to_schedule, num_new_tokens, new_encoder_budget = (
+                self._try_schedule_encoder_inputs(request,
+                                                  request.num_computed_tokens,
+                                                  num_new_tokens,
+                                                  encoder_budget))
+            assert num_new_tokens > 0
+
             while True:
                 new_blocks = self.kv_cache_manager.append_slots(
                     request, num_new_tokens)
@@ -106,22 +141,40 @@
                     preempted_reqs.append(preempted_req)
                     if preempted_req == request:
                         # No more request to preempt.
+                        can_schedule = False
                         break
                 else:
                     # The request can be scheduled.
-                    scheduled_running_reqs.append(request)
-
-                    req_to_new_block_ids[request.request_id] = [
-                        b.block_id for b in new_blocks
-                    ]
-                    num_scheduled_tokens[request.request_id] = num_new_tokens
-                    token_budget -= num_new_tokens
-                    req_index += 1
+                    can_schedule = True
                     break
+            if not can_schedule:
+                break
+
+            # Schedule the request.
+            scheduled_running_reqs.append(request)
+            req_to_new_block_ids[request.request_id] = [
+                b.block_id for b in new_blocks
+            ]
+            num_scheduled_tokens[request.request_id] = num_new_tokens
+            token_budget -= num_new_tokens
+            req_index += 1
+            has_partial_request = (request.num_computed_tokens + num_new_tokens
+                                   < request.num_tokens)
+
+            # Encoder-related.
+            if encoder_inputs_to_schedule:
+                scheduled_encoder_inputs[request.request_id] = (
+                    encoder_inputs_to_schedule)
+                # Allocate the encoder cache.
+                for i in encoder_inputs_to_schedule:
+                    self.encoder_cache_manager.allocate(request, i)
+                encoder_budget = new_encoder_budget

         # Next, schedule the WAITING requests.
         if not preempted_reqs:
             while self.waiting:
+                if has_partial_request:
+                    break
                 if len(self.running) == self.max_num_running_reqs:
                     break
                 if token_budget == 0:
@@ -149,12 +202,21 @@
                     computed_blocks.pop()
                 num_new_tokens = min(num_new_tokens, token_budget)
                 assert num_new_tokens > 0
+
+                # Schedule encoder inputs.
+                (encoder_inputs_to_schedule, num_new_tokens,
+                 new_encoder_budget) = self._try_schedule_encoder_inputs(
+                     request, num_computed_tokens, num_new_tokens,
+                     encoder_budget)
+                if num_new_tokens == 0:
+                    # The request cannot be scheduled.
+                    break
+
                 new_blocks = self.kv_cache_manager.allocate_slots(
                     request, num_new_tokens, computed_blocks)
                 if new_blocks is None:
                     # The request cannot be scheduled.
                     break
-                request.num_computed_tokens = num_computed_tokens

                 self.waiting.popleft()
                 self.running.append(request)
@@ -172,6 +234,18 @@
                 num_scheduled_tokens[request.request_id] = num_new_tokens
                 token_budget -= num_new_tokens
                 request.status = RequestStatus.RUNNING
+                request.num_computed_tokens = num_computed_tokens
+                has_partial_request = (num_computed_tokens + num_new_tokens <
+                                       request.num_tokens)
+
+                # Encoder-related.
+                if encoder_inputs_to_schedule:
+                    scheduled_encoder_inputs[request.request_id] = (
+                        encoder_inputs_to_schedule)
+                    # Allocate the encoder cache.
+                    for i in encoder_inputs_to_schedule:
+                        self.encoder_cache_manager.allocate(request, i)
+                    encoder_budget = new_encoder_budget

         # Check if the scheduling constraints are satisfied.
         total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
@@ -205,12 +279,14 @@
             scheduled_running_reqs=running_reqs_data,
             num_scheduled_tokens=num_scheduled_tokens,
             total_num_scheduled_tokens=total_num_scheduled_tokens,
+            scheduled_encoder_inputs=scheduled_encoder_inputs,
             preempted_req_ids=preempted_req_ids,
             # finished_req_ids is an existing state in the scheduler,
             # instead of being newly scheduled in this step.
             # It contains the request IDs that are finished in between
             # the previous and the current steps.
             finished_req_ids=self.finished_req_ids,
+            free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(),
         )

         self.finished_req_ids = set()
@@ -234,6 +310,72 @@
             self.running_reqs_data[request.request_id] = req_data
         return req_data

+    def _try_schedule_encoder_inputs(
+        self,
+        request: Request,
+        num_computed_tokens: int,
+        num_new_tokens: int,
+        encoder_budget: int,
+    ) -> Tuple[List[int], int, int]:
+        """
+        Determine which encoder inputs need to be scheduled in the current step,
+        and update `num_new_tokens` and encoder token budget accordingly.
+
+        An encoder input will be scheduled if:
+        - Its output tokens overlap with the range of tokens being computed
+          in this step, i.e.,
+          [num_computed_tokens, num_computed_tokens + num_new_tokens).
+        - It is not already computed and stored in the encoder cache.
+        - There is sufficient encoder token budget to process it.
+        - The encoder cache has space to store it.
+
+        If an encoder input cannot be scheduled due to cache or budget
+        limitations, the method adjusts `num_new_tokens` to schedule only the
+        decoder tokens up to just before the unschedulable encoder input.
+        """
+        if not request.has_encoder_inputs():
+            return [], num_new_tokens, encoder_budget
+
+        encoder_inputs_to_schedule: List[int] = []
+        mm_positions = request.mm_positions
+        assert mm_positions is not None
+        assert len(mm_positions) > 0
+        for i, pos_info in enumerate(mm_positions):
+            start_pos = pos_info["offset"]
+            num_encoder_tokens = pos_info["length"]
+
+            # The encoder output is needed if the two ranges overlap:
+            # [num_computed_tokens, num_computed_tokens + num_new_tokens) and
+            # [start_pos, start_pos + num_encoder_tokens)
+            if start_pos >= num_computed_tokens + num_new_tokens:
+                # The encoder input is not needed in this step.
+                break
+            if start_pos + num_encoder_tokens <= num_computed_tokens:
+                # The encoder input is already computed and stored
+                # in the decoder's KV cache.
+                continue
+
+            if self.encoder_cache_manager.has_cache(request, i):
+                # The encoder input is already computed and cached.
+                continue
+            if not self.encoder_cache_manager.can_allocate(request, i):
+                # The encoder cache is full. We can only schedule the decoder
+                # tokens just before the encoder input.
+                num_new_tokens = start_pos - num_computed_tokens
+                break
+            if num_encoder_tokens > encoder_budget:
+                # The encoder budget is exhausted. We can only schedule the
+                # decoder tokens up until the encoder input.
+                # NOTE(woosuk): We assume that the encoder tokens should be
+                # processed altogether, as the encoder usually uses
+                # bidirectional attention.
+                num_new_tokens = start_pos - num_computed_tokens
+                break
+
+            encoder_budget -= num_encoder_tokens
+            encoder_inputs_to_schedule.append(i)
+        return encoder_inputs_to_schedule, num_new_tokens, encoder_budget
+
     def update_from_output(
         self,
         scheduler_output: "SchedulerOutput",
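A small, self-contained re-statement of the overlap rule used by _try_schedule_encoder_inputs: an encoder input occupying decoder positions [start_pos, start_pos + length) must run in the step that computes decoder tokens [num_computed_tokens, num_computed_tokens + num_new_tokens) whenever the two ranges overlap and the output is not already available. The function below is an illustrative simplification that only models the budget check, not the encoder-cache checks.

from typing import Dict, List, Tuple


def plan_encoder_inputs(
    mm_positions: List[Dict[str, int]],
    num_computed_tokens: int,
    num_new_tokens: int,
    encoder_budget: int,
) -> Tuple[List[int], int, int]:
    """Return (inputs to run, possibly reduced num_new_tokens, budget left)."""
    to_schedule: List[int] = []
    for i, pos in enumerate(mm_positions):
        start, length = pos["offset"], pos["length"]
        if start >= num_computed_tokens + num_new_tokens:
            break  # not reached in this step
        if start + length <= num_computed_tokens:
            continue  # already folded into the decoder's KV cache
        if length > encoder_budget:
            # Cannot afford this input now: only schedule decoder tokens
            # up to just before the placeholder.
            num_new_tokens = start - num_computed_tokens
            break
        encoder_budget -= length
        to_schedule.append(i)
    return to_schedule, num_new_tokens, encoder_budget


# One image whose 4 placeholder tokens start at decoder position 6.
positions = [{"offset": 6, "length": 4}]
# Step covers tokens [0, 8) and overlaps the image, so it is scheduled.
assert plan_encoder_inputs(positions, 0, 8, encoder_budget=10) == ([0], 8, 6)
# With too little budget, the step is truncated to tokens [0, 6).
assert plan_encoder_inputs(positions, 0, 8, encoder_budget=2) == ([], 6, 2)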
@@ -251,6 +393,17 @@
             # the request generates output tokens. Otherwise, we ignore the
             # sampler output for the request.
             assert request.num_computed_tokens <= request.num_tokens

+            cached_encoder_input_ids = (
+                self.encoder_cache_manager.get_cached_input_ids(request))
+            for input_id in list(cached_encoder_input_ids):
+                start_pos = request.mm_positions[input_id]["offset"]
+                num_tokens = request.mm_positions[input_id]["length"]
+                if start_pos + num_tokens <= request.num_computed_tokens:
+                    # The encoder output is already processed and stored
+                    # in the decoder's KV cache.
+                    self.encoder_cache_manager.free(request, input_id)
+
             if request.num_computed_tokens == request.num_tokens:
                 req_index = model_runner_output.req_id_to_index[req_id]
                 # NOTE(woosuk): Currently, we assume that each request
@@ -355,7 +508,8 @@ class NewRequestData:
     req_id: str
     prompt_token_ids: List[int]
     prompt: Optional[str]
-    multi_modal_data: Optional[MultiModalDataDict]
+    mm_inputs: List["MultiModalKwargs"]
+    mm_positions: List["PlaceholderRange"]
     sampling_params: SamplingParams
     block_ids: List[int]
     num_computed_tokens: int
@@ -369,9 +523,10 @@
     ) -> "NewRequestData":
         return cls(
             req_id=request.request_id,
-            prompt_token_ids=request.inputs["prompt_token_ids"],
-            prompt=request.inputs.get("prompt"),
-            multi_modal_data=request.inputs.get("multi_modal_data"),
+            prompt_token_ids=request.prompt_token_ids,
+            prompt=request.prompt,
+            mm_inputs=request.mm_inputs,
+            mm_positions=request.mm_positions,
             sampling_params=request.sampling_params,
             block_ids=block_ids,
             num_computed_tokens=num_computed_tokens,
@@ -429,6 +584,8 @@ class SchedulerOutput:

     num_scheduled_tokens: Dict[str, int]
     total_num_scheduled_tokens: int
+    scheduled_encoder_inputs: Dict[str, List[int]]

     preempted_req_ids: Set[str]
     finished_req_ids: Set[str]
+    free_encoder_input_ids: List[Tuple[str, int]]
@@ -17,6 +17,7 @@ from vllm.usage.usage_lib import UsageContext
 from vllm.v1.core.scheduler import Scheduler
 from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
                             EngineCoreRequest, EngineCoreRequestType)
+from vllm.v1.engine.mm_input_mapper import MMInputMapper
 from vllm.v1.executor.gpu_executor import GPUExecutor
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.serial_utils import PickleEncoder
@@ -65,6 +66,9 @@ class EngineCore:
         vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
         vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks

+        # Set up multimodal input mapper (e.g., convert PIL images to tensors).
+        self.mm_input_mapper = MMInputMapper(vllm_config.model_config)
+
         # Setup scheduler.
         self.scheduler = Scheduler(vllm_config.scheduler_config,
                                    vllm_config.cache_config,
@@ -93,6 +97,12 @@ class EngineCore:
         """Add request to the scheduler."""

         req = Request.from_engine_core_request(request)
+        # FIXME(woosuk): The input mapping (e.g., PIL images to tensors) may
+        # take 10-50 ms, which can cause a spike in the latency. We should
+        # consider moving this to a separate thread.
+        if req.mm_data:
+            req.mm_inputs = self.mm_input_mapper.process_inputs(
+                req.mm_data, req.mm_processor_kwargs)
         self.scheduler.add_request(req)

     def abort_requests(self, request_ids: List[str]):
vllm/v1/engine/mm_input_mapper.py (new file, 39 lines)
@@ -0,0 +1,39 @@
+from typing import Any, Dict, List, Optional
+
+from vllm.config import ModelConfig
+from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
+                             MultiModalKwargs, MultiModalRegistry)
+
+
+class MMInputMapper:
+
+    def __init__(
+        self,
+        model_config: ModelConfig,
+        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
+    ):
+        self.mm_registry = mm_registry
+        self.multi_modal_input_mapper = mm_registry.create_input_mapper(
+            model_config)
+        self.mm_registry.init_mm_limits_per_prompt(model_config)
+
+    def process_inputs(
+        self,
+        mm_data: MultiModalDataDict,
+        mm_processor_kwargs: Optional[Dict[str, Any]],
+    ) -> List[MultiModalKwargs]:
+        image_inputs = mm_data["image"]
+        if not isinstance(image_inputs, list):
+            image_inputs = [image_inputs]
+
+        # Process each image input separately so that later we can schedule
+        # them in a fine-grained manner.
+        mm_inputs: List[MultiModalKwargs] = []
+        num_images = len(image_inputs)
+        for i in range(num_images):
+            mm_input = self.multi_modal_input_mapper(
+                {"image": [image_inputs[i]]},
+                mm_processor_kwargs=mm_processor_kwargs,
+            )
+            mm_inputs.append(mm_input)
+        return mm_inputs
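The point of MMInputMapper processing each image separately is that every image becomes its own schedulable encoder input, rather than one monolithic batch per request. A rough sketch of that splitting step on plain data; the real mapper call goes through the multimodal registry, which is replaced here by a trivial per-image dict.

from typing import Any, Dict, List


def split_images_per_input(mm_data: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Turn {"image": one_or_many} into one mapper input per image."""
    image_inputs = mm_data["image"]
    if not isinstance(image_inputs, list):
        image_inputs = [image_inputs]
    # Each element can later be mapped, cached, and scheduled independently.
    return [{"image": [image]} for image in image_inputs]


# A request with two images yields two independently schedulable inputs.
assert split_images_per_input({"image": ["img_a", "img_b"]}) == [
    {"image": ["img_a"]},
    {"image": ["img_b"]},
]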
@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, List, Optional, Union

 from vllm.inputs.data import DecoderOnlyInputs
 from vllm.lora.request import LoRARequest
+from vllm.multimodal import MultiModalKwargs
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import RequestMetrics
 from vllm.v1.engine import EngineCoreRequest
@@ -47,14 +48,30 @@ class Request:
         self._all_token_ids: List[int] = self.prompt_token_ids.copy()
         self.num_computed_tokens = 0

+        # Raw multimodal data before the mm input mapper (e.g., PIL images).
+        self.mm_data = inputs.get("multi_modal_data")
+        self.mm_processor_kwargs = inputs.get("mm_processor_kwargs")
+        mm_positions = inputs.get("multi_modal_placeholders")
+        if mm_positions:
+            # FIXME(woosuk): Support other modalities.
+            self.mm_positions = mm_positions.get("image", [])
+        else:
+            self.mm_positions = []
+        # Output of the mm input mapper (e.g., image tensors).
+        self.mm_inputs: List[MultiModalKwargs] = []
+
     @classmethod
     def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":

         return cls(
             request_id=request.request_id,
-            inputs=DecoderOnlyInputs(type="token",
-                                     prompt_token_ids=request.prompt_token_ids,
-                                     prompt=request.prompt),
+            inputs=DecoderOnlyInputs(
+                type="token",
+                prompt_token_ids=request.prompt_token_ids,
+                prompt=request.prompt,
+                multi_modal_data=request.mm_data,
+                multi_modal_placeholders=request.mm_placeholders,
+                mm_processor_kwargs=request.mm_processor_kwargs,
+            ),
             sampling_params=request.sampling_params,
             eos_token_id=request.eos_token_id,
             arrival_time=request.arrival_time,
@@ -96,9 +113,21 @@
     def get_finished_reason(self) -> Union[str, None]:
         return RequestStatus.get_finished_reason(self.status)

+    def has_encoder_inputs(self) -> bool:
+        return self.mm_data is not None
+
+    @property
+    def num_encoder_inputs(self) -> int:
+        return len(self.mm_positions)
+
+    def get_num_encoder_tokens(self, input_id: int) -> int:
+        assert input_id < len(self.mm_positions)
+        num_tokens = self.mm_positions[input_id]["length"]
+        return num_tokens
+

 class RequestStatus(enum.IntEnum):
-    """Status of a sequence."""
+    """Status of a request."""
     WAITING = 0
     RUNNING = 1
     PREEMPTED = 2
@@ -119,7 +148,7 @@


 # Mapping of finished statuses to their finish reasons.
-# NOTE: The ignored sequences are the sequences whose prompt lengths
+# NOTE: The ignored requests are the requests whose prompt lengths
# are longer than the model's length cap. Therefore, the stop
# reason should also be "length" as in OpenAI API.
_FINISHED_REASON_MAP = {
@@ -1,7 +1,7 @@
 import os
 import time
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List, Optional, Set
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple

 import numpy as np
 import torch
@@ -14,9 +14,10 @@ from vllm.compilation.config import CompilationConfig
 from vllm.compilation.levels import CompilationLevel
 from vllm.config import VllmConfig
 from vllm.forward_context import set_forward_context
+from vllm.inputs import INPUT_REGISTRY, InputRegistry
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
-from vllm.multimodal import MultiModalDataDict
+from vllm.multimodal import MultiModalKwargs
 from vllm.plugins import set_compilation_config
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, cdiv,
@@ -27,6 +28,7 @@ from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.sample.metadata import SamplingMetadata

 if TYPE_CHECKING:
+    from vllm.multimodal.base import PlaceholderRange
     from vllm.v1.core.scheduler import SchedulerOutput

 logger = init_logger(__name__)
@@ -37,8 +39,8 @@ class GPUModelRunner:
     def __init__(
         self,
         vllm_config: VllmConfig,
+        input_registry: InputRegistry = INPUT_REGISTRY,
     ):
         # TODO: use ModelRunnerBase.__init__(self, vllm_config=vllm_config)
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
         self.cache_config = vllm_config.cache_config
@@ -75,10 +77,16 @@ class GPUModelRunner:
                                                  parallel_config)
         self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
         self.head_size = model_config.get_head_size()
+        self.hidden_size = model_config.get_hidden_size()

+        # Multi-modal data support
+        self.input_registry = input_registry
+
         # Lazy initialization
         # self.model: nn.Module  # Set after load_model
         self.kv_caches: List[torch.Tensor] = []
+        # req_id -> (input_id -> encoder_output)
+        self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}

         # Request states.
         self.requests: Dict[str, CachedRequestState] = {}
@@ -96,18 +104,28 @@ class GPUModelRunner:
                                and not self.model_config.enforce_eager)
         # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
         self.cudagraph_batch_sizes = [1, 2, 4] + [i for i in range(8, 513, 8)]
-        self.input_ids = torch.zeros(self.max_num_tokens,
-                                     dtype=torch.int32,
-                                     device=self.device)
         self.positions = torch.zeros(self.max_num_tokens,
                                      dtype=torch.int64,
                                      device=self.device)
+        self.inputs_embeds = torch.zeros(
+            (self.max_num_tokens, self.hidden_size),
+            dtype=self.dtype,
+            device=self.device)

     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         # Remove stopped requests from the cached states.
         # Keep the states of the pre-empted requests.
         for req_id in scheduler_output.finished_req_ids:
             self.requests.pop(req_id, None)
+            self.encoder_cache.pop(req_id, None)
+
+        # Free the cached encoder outputs.
+        for req_id, input_id in scheduler_output.free_encoder_input_ids:
+            encoder_outputs = self.encoder_cache.get(req_id)
+            if encoder_outputs is not None:
+                encoder_outputs.pop(input_id, None)
+                if not encoder_outputs:
+                    self.encoder_cache.pop(req_id, None)

         # Remove the requests from the persistent batch.
         stopped_req_ids = set().union(
@@ -156,7 +174,8 @@ class GPUModelRunner:
                 req_id=req_id,
                 prompt_token_ids=req_data.prompt_token_ids,
                 prompt=req_data.prompt,
-                multi_modal_data=req_data.multi_modal_data,
+                mm_inputs=req_data.mm_inputs,
+                mm_positions=req_data.mm_positions,
                 sampling_params=sampling_params,
                 generator=generator,
                 block_ids=req_data.block_ids,
@@ -285,11 +304,9 @@ class GPUModelRunner:
         seq_start_loc_np[0] = 0
         np.cumsum(seq_lens, out=seq_start_loc_np[1:])

-        self.input_ids[:total_num_scheduled_tokens].copy_(input_ids,
-                                                          non_blocking=True)
+        input_ids = input_ids.to(self.device, non_blocking=True)
         self.positions[:total_num_scheduled_tokens].copy_(positions,
                                                           non_blocking=True)

         query_start_loc = query_start_loc.to(self.device, non_blocking=True)
         seq_start_loc = seq_start_loc.to(self.device, non_blocking=True)
         slot_mapping = slot_mapping.to(self.device, non_blocking=True).long()
@@ -308,7 +325,7 @@ class GPUModelRunner:
         # token from the partial request.
         # TODO: Support prompt logprobs.
         logits_indices = query_start_loc[1:] - 1
-        return attn_metadata, logits_indices
+        return input_ids, attn_metadata, logits_indices

     def _prepare_sampling(
         self,
@@ -325,13 +342,91 @@ class GPUModelRunner:
         sampling_metadata = self.input_batch.make_sampling_metadata(skip_copy)
         return sampling_metadata

+    def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
+        scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
+        if not scheduled_encoder_inputs:
+            return
+
+        # Batch the multi-modal inputs.
+        mm_inputs: List[MultiModalKwargs] = []
+        req_input_ids: List[Tuple[int, int]] = []
+        for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
+            req_state = self.requests[req_id]
+            for input_id in encoder_input_ids:
+                mm_inputs.append(req_state.mm_inputs[input_id])
+                req_input_ids.append((req_id, input_id))
+        batched_mm_inputs = MultiModalKwargs.batch(mm_inputs)
+        batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
+                                                       device=self.device)
+
+        # Run the encoder.
+        # `encoder_outputs` is either of the following:
+        # 1. A tensor of shape [num_images, feature_size, hidden_size]
+        # in case when feature_size is fixed across all images.
+        # 2. A list (length: num_images) of tensors, each of shape
+        # [feature_size, hidden_size] in case when the feature size is
+        # dynamic depending on input images.
+        encoder_outputs = self.model.process_mm_inputs(**batched_mm_inputs)
+
+        # Cache the encoder outputs.
+        for (req_id, input_id), output in zip(req_input_ids, encoder_outputs):
+            if req_id not in self.encoder_cache:
+                self.encoder_cache[req_id] = {}
+            self.encoder_cache[req_id][input_id] = output
+
+    def _gather_encoder_outputs(
+        self,
+        scheduler_output: "SchedulerOutput",
+    ) -> List[torch.Tensor]:
+        encoder_outputs: List[torch.Tensor] = []
+        num_reqs = self.input_batch.num_reqs
+        for req_id in self.input_batch.req_ids[:num_reqs]:
+            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
+                req_id]
+            req_state = self.requests[req_id]
+            num_computed_tokens = req_state.num_computed_tokens
+            mm_positions = req_state.mm_positions
+            for i, pos_info in enumerate(mm_positions):
+                start_pos = pos_info["offset"]
+                num_encoder_tokens = pos_info["length"]
+
+                # The encoder output is needed if the two ranges overlap:
+                # [num_computed_tokens,
+                #  num_computed_tokens + num_scheduled_tokens) and
+                # [start_pos, start_pos + num_encoder_tokens)
+                if start_pos >= num_computed_tokens + num_scheduled_tokens:
+                    # The encoder output is not needed in this step.
+                    break
+                if start_pos + num_encoder_tokens <= num_computed_tokens:
+                    # The encoder output is already processed and stored
+                    # in the decoder's KV cache.
+                    continue
+
+                start_idx = max(num_computed_tokens - start_pos, 0)
+                end_idx = min(
+                    num_computed_tokens - start_pos + num_scheduled_tokens,
+                    num_encoder_tokens)
+                assert start_idx < end_idx
+                assert req_id in self.encoder_cache
+                assert i in self.encoder_cache[req_id]
+                encoder_output = self.encoder_cache[req_id][i]
+                encoder_outputs.append(encoder_output[start_idx:end_idx])
+        return encoder_outputs
+
     @torch.inference_mode()
     def execute_model(
         self,
         scheduler_output: "SchedulerOutput",
     ) -> ModelRunnerOutput:
         self._update_states(scheduler_output)
-        attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
+
+        # Run the encoder.
+        self._execute_encoder(scheduler_output)
+        encoder_outputs = self._gather_encoder_outputs(scheduler_output)
+
+        # Prepare the decoder inputs.
+        input_ids, attn_metadata, logits_indices = self._prepare_inputs(
+            scheduler_output)
         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         if (self.use_cuda_graph
                 and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
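_gather_encoder_outputs only hands the decoder the slice of each cached encoder output that falls inside the current step's token window, so a partially prefilling request consumes an image's features across several steps. A minimal sketch of that index arithmetic with made-up sizes:

import torch

# A cached encoder output for one image: 8 feature tokens of width 4.
encoder_output = torch.arange(32, dtype=torch.float32).reshape(8, 4)

start_pos = 6             # placeholder offset in the decoder token sequence
num_encoder_tokens = 8    # placeholder length
num_computed_tokens = 4   # decoder progress before this step
num_scheduled_tokens = 6  # decoder tokens computed in this step: [4, 10)

# Overlap of [4, 10) with [6, 14) is [6, 10) -> encoder rows [0, 4).
start_idx = max(num_computed_tokens - start_pos, 0)
end_idx = min(num_computed_tokens - start_pos + num_scheduled_tokens,
              num_encoder_tokens)
assert (start_idx, end_idx) == (0, 4)

# Only these rows are folded into this step's inputs_embeds; the remaining
# rows are consumed by a later step once the decoder catches up.
window = encoder_output[start_idx:end_idx]
assert window.shape == (4, 4)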
@@ -343,12 +438,26 @@ class GPUModelRunner:
             # Eager mode.
             num_input_tokens = num_scheduled_tokens

+        # Get the inputs embeds.
+        if encoder_outputs:
+            inputs_embeds = self.model.get_input_embeddings(
+                input_ids, encoder_outputs)
+        else:
+            inputs_embeds = self.model.get_input_embeddings(input_ids)
+        # NOTE(woosuk): To unify token ids and soft tokens (vision embeddings),
+        # always use embeddings (rather than token ids) as input to the model.
+        # TODO(woosuk): Avoid the copy. Optimize.
+        self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds)
+
+        # Run the decoder.
+        # Use persistent buffers for CUDA graphs.
         with set_forward_context(attn_metadata):
             hidden_states = self.model(
-                input_ids=self.input_ids[:num_input_tokens],
+                input_ids=None,
                 positions=self.positions[:num_input_tokens],
                 kv_caches=self.kv_caches,
                 attn_metadata=None,
+                inputs_embeds=self.inputs_embeds[:num_input_tokens],
             )
         hidden_states = hidden_states[:num_scheduled_tokens]
         hidden_states = hidden_states[logits_indices]
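Because the decoder is now always driven through inputs_embeds, the runner copies each step's embeddings into a preallocated persistent buffer so the same storage (and hence the captured CUDA graphs) is reused across steps. A hedged sketch of that buffer reuse, independent of vLLM; the sizes below are arbitrary.

import torch

max_num_tokens = 16
hidden_size = 8

# Allocated once in __init__, reused every step (CUDA-graph friendly).
inputs_embeds_buffer = torch.zeros(max_num_tokens, hidden_size)


def run_step(step_embeds: torch.Tensor) -> torch.Tensor:
    num_tokens = step_embeds.shape[0]
    # Copy into the persistent buffer instead of allocating a new tensor.
    inputs_embeds_buffer[:num_tokens].copy_(step_embeds)
    # The model would be called with this buffer slice; here we just return it.
    return inputs_embeds_buffer[:num_tokens]


out = run_step(torch.randn(5, hidden_size))
assert out.shape == (5, hidden_size)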
@@ -440,13 +549,16 @@ class GPUModelRunner:
         with set_forward_context(None):  # noqa: SIM117
             with set_compile_context(self.cudagraph_batch_sizes):
                 # Trigger compilation for general shape.
-                model(self.input_ids,
-                      self.positions,
-                      dummy_kv_caches,
-                      attn_metadata=None)
+                model(input_ids=None,
+                      positions=self.positions,
+                      kv_caches=dummy_kv_caches,
+                      attn_metadata=None,
+                      inputs_embeds=self.inputs_embeds)

     @torch.inference_mode()
     def profile_run(self) -> None:
+        # TODO(woosuk): Profile the max memory usage of the encoder and
+        # the encoder cache.
         self._dummy_run(self.model, self.max_num_tokens)
         torch.cuda.synchronize()

@@ -468,10 +580,11 @@ class GPUModelRunner:
         # can reuse the memory pool allocated for the large shapes.
         for num_tokens in reversed(self.cudagraph_batch_sizes):
             self.model(
-                self.input_ids[:num_tokens],
-                self.positions[:num_tokens],
+                input_ids=None,
+                positions=self.positions[:num_tokens],
                 kv_caches=self.kv_caches,
                 attn_metadata=None,
+                inputs_embeds=self.inputs_embeds[:num_tokens],
             )

         end_time = time.perf_counter()
@@ -506,7 +619,8 @@ class CachedRequestState:
     req_id: str
     prompt_token_ids: List[int]
    prompt: Optional[str]
-    multi_modal_data: Optional["MultiModalDataDict"]
+    mm_inputs: List[MultiModalKwargs]
+    mm_positions: List["PlaceholderRange"]
     sampling_params: SamplingParams
     generator: Optional[torch.Generator]