[V1] Support VLMs with fine-grained scheduling (#9871)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: Roger Wang <ywang@roblox.com>
Woosuk Kwon 2024-11-12 20:53:13 -08:00 committed by GitHub
parent 0d4ea3fb5c
commit bbd3e86926
12 changed files with 542 additions and 96 deletions


@ -216,9 +216,11 @@ class GPT2Model(nn.Module):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
inputs_embeds = self.wte(input_ids)
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
else:
@ -263,6 +265,9 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.transformer.wte(input_ids)
def forward(
self,
input_ids: torch.Tensor,
@ -270,9 +275,11 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
attn_metadata, intermediate_tensors,
inputs_embeds)
return hidden_states
def compute_logits(


@ -538,6 +538,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
normalize=False,
softmax=False)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
@ -545,9 +548,11 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
attn_metadata, intermediate_tensors,
inputs_embeds)
return model_output
def compute_logits(


@ -17,6 +17,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import NestedTensors
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
@ -448,6 +449,25 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
image_features = self._process_image_pixels(image_input)
return self.multi_modal_projector(image_features)
def process_mm_inputs(self, **kwargs):
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
vision_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if vision_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.config.image_token_index)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
@ -455,6 +475,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
"""Run forward pass for LLaVA-1.5.
@ -494,24 +515,13 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
"""
if intermediate_tensors is not None:
inputs_embeds = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.config.image_token_index)
else:
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent
# for `torch.compile` integration
input_ids = None
elif inputs_embeds is None:
vision_embeddings = self.process_mm_inputs(**kwargs)
# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
input_ids = None
hidden_states = self.language_model.model(input_ids,
positions,

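The split into `process_mm_inputs`, `get_input_embeddings`, and an `inputs_embeds`-accepting `forward` lets the caller run the vision encoder and the language model as separate steps. Below is a minimal sketch of that calling convention; `run_vlm_step` is a hypothetical helper added here for illustration, while the V1 GPU model runner later in this diff does the equivalent with its persistent buffers.

def run_vlm_step(model, input_ids, positions, kv_caches, attn_metadata, **mm_kwargs):
    # Returns None when the request carries no image data.
    vision_embeddings = model.process_mm_inputs(**mm_kwargs)
    inputs_embeds = model.get_input_embeddings(input_ids, vision_embeddings)
    # Drop the token ids and always feed embeddings, so the decoder sees a
    # single computation graph whether or not images are present.
    return model(input_ids=None,
                 positions=positions,
                 kv_caches=kv_caches,
                 attn_metadata=attn_metadata,
                 inputs_embeds=inputs_embeds)
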

@ -360,6 +360,9 @@ class OPTForCausalLM(nn.Module, SupportsPP):
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
@ -367,9 +370,11 @@ class OPTForCausalLM(nn.Module, SupportsPP):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
attn_metadata, intermediate_tensors,
inputs_embeds)
return hidden_states
def compute_logits(


@ -39,6 +39,7 @@ from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import NestedTensors, PlaceholderRange
from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.utils import is_list_of
@ -500,15 +501,20 @@ def input_processor_for_phi3v(ctx: InputContext,
# TODO: Move this to utils or integrate with clip.
new_token_ids: List[int] = []
placeholder_ranges: List[PlaceholderRange] = []
placeholder_idx = 0
while merged_token_ids:
token_id = merged_token_ids.pop(0)
if token_id == _IMAGE_TOKEN_ID:
new_token_ids.extend(
repeat_and_pad_token(
_IMAGE_TOKEN_ID,
repeat_count=image_feature_size[placeholder_idx],
))
replacement_ids = repeat_and_pad_token(
_IMAGE_TOKEN_ID,
repeat_count=image_feature_size[placeholder_idx],
)
placeholder_ranges.append({
"offset": len(new_token_ids),
"length": len(replacement_ids)
})
new_token_ids.extend(replacement_ids)
placeholder_idx += 1
else:
new_token_ids.append(token_id)
@ -516,7 +522,8 @@ def input_processor_for_phi3v(ctx: InputContext,
# NOTE: Create a defensive copy of the original inputs
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
multi_modal_data=multi_modal_data,
multi_modal_placeholders={"image": placeholder_ranges})
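For intuition, here is a standalone rerun of the loop above on made-up inputs; the `-1` image token id and the 3-token feature size are placeholders, and plain repetition stands in for `repeat_and_pad_token`:

_IMAGE_TOKEN_ID = -1                      # stand-in value, not the real Phi-3-V id
merged_token_ids = [101, _IMAGE_TOKEN_ID, 102]
image_feature_size = [3]                  # the single image expands to 3 tokens

new_token_ids, placeholder_ranges, placeholder_idx = [], [], 0
for token_id in merged_token_ids:
    if token_id == _IMAGE_TOKEN_ID:
        replacement_ids = [_IMAGE_TOKEN_ID] * image_feature_size[placeholder_idx]
        placeholder_ranges.append({"offset": len(new_token_ids),
                                   "length": len(replacement_ids)})
        new_token_ids.extend(replacement_ids)
        placeholder_idx += 1
    else:
        new_token_ids.append(token_id)

print(new_token_ids)       # [101, -1, -1, -1, 102]
print(placeholder_ranges)  # [{'offset': 1, 'length': 3}]
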
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@ -669,32 +676,42 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return image_embeds
def process_mm_inputs(self, **kwargs):
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
vision_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
inputs_embeds = self.embed_tokens(input_ids)
if vision_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.image_token_id)
return inputs_embeds
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object):
if intermediate_tensors is not None:
inputs_embeds = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.embed_tokens(input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.image_token_id)
else:
inputs_embeds = self.language_model.model.embed_tokens(
input_ids)
# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent
# for `torch.compile` integration
input_ids = None
elif inputs_embeds is None:
vision_embeddings = self.process_mm_inputs(**kwargs)
# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
input_ids = None
hidden_states = self.language_model.model(input_ids,
positions,


@ -441,6 +441,9 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
@ -448,9 +451,11 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
attn_metadata, intermediate_tensors,
inputs_embeds)
return hidden_states
def compute_logits(


@ -0,0 +1,48 @@
from typing import Dict, List, Set, Tuple
from vllm.v1.request import Request
class EncoderCacheManager:
def __init__(self, cache_size: int):
self.cache_size = cache_size
self.num_free_slots = cache_size
# req_id -> cached input ids
self.cached: Dict[str, Set[int]] = {}
# List of [req_id, input_id]
self.freed: List[Tuple[str, int]] = []
def has_cache(self, request: Request, input_id: int) -> bool:
req_id = request.request_id
return req_id in self.cached and input_id in self.cached[req_id]
def can_allocate(self, request: Request, input_id: int) -> bool:
num_tokens = request.get_num_encoder_tokens(input_id)
return num_tokens <= self.num_free_slots
def allocate(self, request: Request, input_id: int) -> None:
req_id = request.request_id
if req_id not in self.cached:
self.cached[req_id] = set()
self.cached[req_id].add(input_id)
self.num_free_slots -= request.get_num_encoder_tokens(input_id)
def get_cached_input_ids(self, request: Request) -> Set[int]:
return self.cached.get(request.request_id, set())
def free(self, request: Request, input_id: int) -> None:
req_id = request.request_id
if req_id not in self.cached:
return
self.cached[req_id].discard(input_id)
if len(self.cached[req_id]) == 0:
del self.cached[req_id]
self.num_free_slots += request.get_num_encoder_tokens(input_id)
self.freed.append((req_id, input_id))
def get_freed_ids(self) -> List[Tuple[str, int]]:
freed = self.freed
self.freed = []
return freed
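A small usage sketch, assuming this module is importable as `vllm.v1.core.encoder_cache_manager` (the import path the scheduler below uses). `_FakeRequest` and the 576-token image size are illustrative stand-ins for `vllm.v1.request.Request`:

from vllm.v1.core.encoder_cache_manager import EncoderCacheManager

class _FakeRequest:
    """Duck-typed stand-in exposing only what the cache manager needs."""
    def __init__(self, request_id, encoder_token_counts):
        self.request_id = request_id
        self._counts = encoder_token_counts

    def get_num_encoder_tokens(self, input_id: int) -> int:
        return self._counts[input_id]

manager = EncoderCacheManager(cache_size=2048)
req = _FakeRequest("req-0", encoder_token_counts=[576, 576])   # two images

assert manager.can_allocate(req, 0)
manager.allocate(req, 0)                 # reserves 576 of the 2048 slots
assert manager.has_cache(req, 0)

# Once the image's positions are fully in the decoder KV cache, the scheduler
# frees the slot and forwards the freed id to the model runner.
manager.free(req, 0)
print(manager.get_freed_ids())           # [('req-0', 0)]
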


@ -1,16 +1,21 @@
from collections import deque
from dataclasses import dataclass
from typing import Deque, Dict, Iterable, List, Optional, Set, Union
from typing import (TYPE_CHECKING, Deque, Dict, Iterable, List, Optional, Set,
Tuple, Union)
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict
from vllm.sampling_params import SamplingParams
from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
from vllm.v1.core.kv_cache_manager import KVCacheManager
from vllm.v1.engine import EngineCoreOutput
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
if TYPE_CHECKING:
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.base import PlaceholderRange
logger = init_logger(__name__)
@ -61,12 +66,20 @@ class Scheduler:
# Request id -> RunningRequestData
self.running_reqs_data: Dict[str, RunningRequestData] = {}
def schedule(self) -> "SchedulerOutput":
scheduled_new_reqs: List[Request] = []
scheduled_resumed_reqs: List[Request] = []
scheduled_running_reqs: List[Request] = []
preempted_reqs: List[Request] = []
# Encoder-related.
# NOTE(woosuk): Here, "encoder" includes the vision encoder (and
# projector if needed). Currently, we assume that the encoder also
# has the Transformer architecture (e.g., ViT).
# FIXME(woosuk): Below are placeholder values. We need to calculate the
# actual values from the configurations.
self.max_num_encoder_input_tokens = 2048
# NOTE(woosuk): For models without an encoder (e.g., text-only models),
# the encoder cache will not be initialized or used, regardless of
# the cache size. This is because the memory space for the encoder cache
# is preallocated in the profiling run.
self.encoder_cache_manager = EncoderCacheManager(cache_size=2048)
def schedule(self) -> "SchedulerOutput":
# NOTE(woosuk) on the scheduling algorithm:
# There's no "decoding phase" nor "prefill phase" in the scheduler.
# Each request just has the num_computed_tokens and num_tokens,
@ -74,23 +87,45 @@ class Scheduler:
# At each step, the scheduler tries to assign tokens to the requests
# so that each request's num_computed_tokens can catch up its
# num_tokens. This is general enough to cover chunked prefills,
# prefix caching, and the "jump forward" optimization in the future.
# prefix caching, and the "jump decoding" optimization in the future.
scheduled_new_reqs: List[Request] = []
scheduled_resumed_reqs: List[Request] = []
scheduled_running_reqs: List[Request] = []
preempted_reqs: List[Request] = []
req_to_new_block_ids: Dict[str, List[int]] = {}
num_scheduled_tokens: Dict[str, int] = {}
token_budget = self.max_num_scheduled_tokens
# Encoder-related.
scheduled_encoder_inputs: Dict[str, List[int]] = {}
encoder_budget = self.max_num_encoder_input_tokens
# First, schedule the RUNNING requests.
# NOTE(woosuk): At most 1 request in the RUNNING queue is allowed to be
# in the "partial" state, where the request has some tokens computed
# but not all. The constraint is due to the persistent batch in the
# V1 model runner.
# TODO(woosuk): Remove this constraint after refactoring model runner.
has_partial_request = False
req_index = 0
while req_index < len(self.running):
if token_budget == 0:
break
# Only the last request in the RUNNING queue can be "partial".
assert not has_partial_request
assert token_budget > 0
request = self.running[req_index]
num_new_tokens = request.num_tokens - request.num_computed_tokens
num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0
# Schedule encoder inputs.
encoder_inputs_to_schedule, num_new_tokens, new_encoder_budget = (
self._try_schedule_encoder_inputs(request,
request.num_computed_tokens,
num_new_tokens,
encoder_budget))
assert num_new_tokens > 0
while True:
new_blocks = self.kv_cache_manager.append_slots(
request, num_new_tokens)
@ -106,22 +141,40 @@ class Scheduler:
preempted_reqs.append(preempted_req)
if preempted_req == request:
# No more request to preempt.
can_schedule = False
break
else:
# The request can be scheduled.
scheduled_running_reqs.append(request)
req_to_new_block_ids[request.request_id] = [
b.block_id for b in new_blocks
]
num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens
req_index += 1
can_schedule = True
break
if not can_schedule:
break
# Schedule the request.
scheduled_running_reqs.append(request)
req_to_new_block_ids[request.request_id] = [
b.block_id for b in new_blocks
]
num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens
req_index += 1
has_partial_request = (request.num_computed_tokens + num_new_tokens
< request.num_tokens)
# Encoder-related.
if encoder_inputs_to_schedule:
scheduled_encoder_inputs[request.request_id] = (
encoder_inputs_to_schedule)
# Allocate the encoder cache.
for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i)
encoder_budget = new_encoder_budget
# Next, schedule the WAITING requests.
if not preempted_reqs:
while self.waiting:
if has_partial_request:
break
if len(self.running) == self.max_num_running_reqs:
break
if token_budget == 0:
@ -149,12 +202,21 @@ class Scheduler:
computed_blocks.pop()
num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0
# Schedule encoder inputs.
(encoder_inputs_to_schedule, num_new_tokens,
new_encoder_budget) = self._try_schedule_encoder_inputs(
request, num_computed_tokens, num_new_tokens,
encoder_budget)
if num_new_tokens == 0:
# The request cannot be scheduled.
break
new_blocks = self.kv_cache_manager.allocate_slots(
request, num_new_tokens, computed_blocks)
if new_blocks is None:
# The request cannot be scheduled.
break
request.num_computed_tokens = num_computed_tokens
self.waiting.popleft()
self.running.append(request)
@ -172,6 +234,18 @@ class Scheduler:
num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens
has_partial_request = (num_computed_tokens + num_new_tokens <
request.num_tokens)
# Encoder-related.
if encoder_inputs_to_schedule:
scheduled_encoder_inputs[request.request_id] = (
encoder_inputs_to_schedule)
# Allocate the encoder cache.
for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i)
encoder_budget = new_encoder_budget
# Check if the scheduling constraints are satisfied.
total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
@ -205,12 +279,14 @@ class Scheduler:
scheduled_running_reqs=running_reqs_data,
num_scheduled_tokens=num_scheduled_tokens,
total_num_scheduled_tokens=total_num_scheduled_tokens,
scheduled_encoder_inputs=scheduled_encoder_inputs,
preempted_req_ids=preempted_req_ids,
# finished_req_ids is an existing state in the scheduler,
# instead of being newly scheduled in this step.
# It contains the request IDs that are finished in between
# the previous and the current steps.
finished_req_ids=self.finished_req_ids,
free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(),
)
self.finished_req_ids = set()
@ -234,6 +310,72 @@ class Scheduler:
self.running_reqs_data[request.request_id] = req_data
return req_data
def _try_schedule_encoder_inputs(
self,
request: Request,
num_computed_tokens: int,
num_new_tokens: int,
encoder_budget: int,
) -> Tuple[List[int], int, int]:
"""
Determine which encoder inputs need to be scheduled in the current step,
and update `num_new_tokens` and encoder token budget accordingly.
An encoder input will be scheduled if:
- Its output tokens overlap with the range of tokens being computed
in this step, i.e.,
[num_computed_tokens, num_computed_tokens + num_new_tokens).
- It is not already computed and stored in the encoder cache.
- There is sufficient encoder token budget to process it.
- The encoder cache has space to store it.
If an encoder input cannot be scheduled due to cache or budget
limitations, the method adjusts `num_new_tokens` to schedule only the
decoder tokens up to just before the unschedulable encoder input.
"""
if not request.has_encoder_inputs():
return [], num_new_tokens, encoder_budget
encoder_inputs_to_schedule: List[int] = []
mm_positions = request.mm_positions
assert mm_positions is not None
assert len(mm_positions) > 0
for i, pos_info in enumerate(mm_positions):
start_pos = pos_info["offset"]
num_encoder_tokens = pos_info["length"]
# The encoder output is needed if the two ranges overlap:
# [num_computed_tokens, num_computed_tokens + num_new_tokens) and
# [start_pos, start_pos + num_encoder_tokens)
if start_pos >= num_computed_tokens + num_new_tokens:
# The encoder input is not needed in this step.
break
if start_pos + num_encoder_tokens <= num_computed_tokens:
# The encoder input is already computed and stored
# in the decoder's KV cache.
continue
if self.encoder_cache_manager.has_cache(request, i):
# The encoder input is already computed and cached.
continue
if not self.encoder_cache_manager.can_allocate(request, i):
# The encoder cache is full. We can only schedule the decoder
# tokens just before the encoder input.
num_new_tokens = start_pos - num_computed_tokens
break
if num_encoder_tokens > encoder_budget:
# The encoder budget is exhausted. We can only schedule the
# decoder tokens up until the encoder input.
# NOTE(woosuk): We assume that the encoder tokens should be
# processed altogether, as the encoder usually uses
# bidirectional attention.
num_new_tokens = start_pos - num_computed_tokens
break
encoder_budget -= num_encoder_tokens
encoder_inputs_to_schedule.append(i)
return encoder_inputs_to_schedule, num_new_tokens, encoder_budget
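To make the truncation rule concrete, here is a hand-worked, single-image restatement of the checks above; the helper function and all numbers are illustrative, not part of the commit:

def clip_decoder_tokens(num_computed, num_new, start_pos, num_encoder_tokens,
                        encoder_budget, cache_has_room):
    """Return (schedule_encoder_now, adjusted_num_new) for one mm input."""
    if start_pos >= num_computed + num_new:
        return False, num_new                    # image not reached in this step
    if start_pos + num_encoder_tokens <= num_computed:
        return False, num_new                    # image already in the KV cache
    if not cache_has_room or num_encoder_tokens > encoder_budget:
        return False, start_pos - num_computed   # stop right before the image
    return True, num_new

# A prompt whose image occupies positions [100, 676), prefilled 512 tokens/step:
print(clip_decoder_tokens(0, 512, 100, 576, encoder_budget=512,
                          cache_has_room=True))    # (False, 100): defer the image
print(clip_decoder_tokens(0, 512, 100, 576, encoder_budget=2048,
                          cache_has_room=True))    # (True, 512): encode it now
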
def update_from_output(
self,
scheduler_output: "SchedulerOutput",
@ -251,6 +393,17 @@ class Scheduler:
# the request generates output tokens. Otherwise, we ignore the
# sampler output for the request.
assert request.num_computed_tokens <= request.num_tokens
cached_encoder_input_ids = (
self.encoder_cache_manager.get_cached_input_ids(request))
for input_id in list(cached_encoder_input_ids):
start_pos = request.mm_positions[input_id]["offset"]
num_tokens = request.mm_positions[input_id]["length"]
if start_pos + num_tokens <= request.num_computed_tokens:
# The encoder output is already processed and stored
# in the decoder's KV cache.
self.encoder_cache_manager.free(request, input_id)
if request.num_computed_tokens == request.num_tokens:
req_index = model_runner_output.req_id_to_index[req_id]
# NOTE(woosuk): Currently, we assume that each request
@ -355,7 +508,8 @@ class NewRequestData:
req_id: str
prompt_token_ids: List[int]
prompt: Optional[str]
multi_modal_data: Optional[MultiModalDataDict]
mm_inputs: List["MultiModalKwargs"]
mm_positions: List["PlaceholderRange"]
sampling_params: SamplingParams
block_ids: List[int]
num_computed_tokens: int
@ -369,9 +523,10 @@ class NewRequestData:
) -> "NewRequestData":
return cls(
req_id=request.request_id,
prompt_token_ids=request.inputs["prompt_token_ids"],
prompt=request.inputs.get("prompt"),
multi_modal_data=request.inputs.get("multi_modal_data"),
prompt_token_ids=request.prompt_token_ids,
prompt=request.prompt,
mm_inputs=request.mm_inputs,
mm_positions=request.mm_positions,
sampling_params=request.sampling_params,
block_ids=block_ids,
num_computed_tokens=num_computed_tokens,
@ -429,6 +584,8 @@ class SchedulerOutput:
num_scheduled_tokens: Dict[str, int]
total_num_scheduled_tokens: int
scheduled_encoder_inputs: Dict[str, List[int]]
preempted_req_ids: Set[str]
finished_req_ids: Set[str]
free_encoder_input_ids: List[Tuple[str, int]]


@ -17,6 +17,7 @@ from vllm.usage.usage_lib import UsageContext
from vllm.v1.core.scheduler import Scheduler
from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
EngineCoreRequest, EngineCoreRequestType)
from vllm.v1.engine.mm_input_mapper import MMInputMapper
from vllm.v1.executor.gpu_executor import GPUExecutor
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import PickleEncoder
@ -65,6 +66,9 @@ class EngineCore:
vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
# Set up multimodal input mapper (e.g., convert PIL images to tensors).
self.mm_input_mapper = MMInputMapper(vllm_config.model_config)
# Setup scheduler.
self.scheduler = Scheduler(vllm_config.scheduler_config,
vllm_config.cache_config,
@ -93,6 +97,12 @@ class EngineCore:
"""Add request to the scheduler."""
req = Request.from_engine_core_request(request)
# FIXME(woosuk): The input mapping (e.g., PIL images to tensors) may
# take 10-50 ms, which can cause a spike in the latency. We should
# consider moving this to a separate thread.
if req.mm_data:
req.mm_inputs = self.mm_input_mapper.process_inputs(
req.mm_data, req.mm_processor_kwargs)
self.scheduler.add_request(req)
def abort_requests(self, request_ids: List[str]):


@ -0,0 +1,39 @@
from typing import Any, Dict, List, Optional
from vllm.config import ModelConfig
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
MultiModalKwargs, MultiModalRegistry)
class MMInputMapper:
def __init__(
self,
model_config: ModelConfig,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
):
self.mm_registry = mm_registry
self.multi_modal_input_mapper = mm_registry.create_input_mapper(
model_config)
self.mm_registry.init_mm_limits_per_prompt(model_config)
def process_inputs(
self,
mm_data: MultiModalDataDict,
mm_processor_kwargs: Optional[Dict[str, Any]],
) -> List[MultiModalKwargs]:
image_inputs = mm_data["image"]
if not isinstance(image_inputs, list):
image_inputs = [image_inputs]
# Process each image input separately so that later we can schedule
# them in a fine-grained manner.
mm_inputs: List[MultiModalKwargs] = []
num_images = len(image_inputs)
for i in range(num_images):
mm_input = self.multi_modal_input_mapper(
{"image": [image_inputs[i]]},
mm_processor_kwargs=mm_processor_kwargs,
)
mm_inputs.append(mm_input)
return mm_inputs
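A hedged sketch of how this mapper is meant to be driven, mirroring the `EngineCore.add_request` change above; `model_config` is assumed to be the engine's existing `ModelConfig` and `images` a list of PIL images:

from vllm.v1.engine.mm_input_mapper import MMInputMapper

def map_images(model_config, images):
    mapper = MMInputMapper(model_config)
    mm_inputs = mapper.process_inputs({"image": images},
                                      mm_processor_kwargs=None)
    # One MultiModalKwargs per image, so the scheduler can later pick
    # individual images instead of treating the request as all-or-nothing.
    assert len(mm_inputs) == len(images)
    return mm_inputs
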


@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, List, Optional, Union
from vllm.inputs.data import DecoderOnlyInputs
from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs
from vllm.sampling_params import SamplingParams
from vllm.sequence import RequestMetrics
from vllm.v1.engine import EngineCoreRequest
@ -47,14 +48,30 @@ class Request:
self._all_token_ids: List[int] = self.prompt_token_ids.copy()
self.num_computed_tokens = 0
# Raw multimodal data before the mm input mapper (e.g., PIL images).
self.mm_data = inputs.get("multi_modal_data")
self.mm_processor_kwargs = inputs.get("mm_processor_kwargs")
mm_positions = inputs.get("multi_modal_placeholders")
if mm_positions:
# FIXME(woosuk): Support other modalities.
self.mm_positions = mm_positions.get("image", [])
else:
self.mm_positions = []
# Output of the mm input mapper (e.g., image tensors).
self.mm_inputs: List[MultiModalKwargs] = []
@classmethod
def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
return cls(
request_id=request.request_id,
inputs=DecoderOnlyInputs(type="token",
prompt_token_ids=request.prompt_token_ids,
prompt=request.prompt),
inputs=DecoderOnlyInputs(
type="token",
prompt_token_ids=request.prompt_token_ids,
prompt=request.prompt,
multi_modal_data=request.mm_data,
multi_modal_placeholders=request.mm_placeholders,
mm_processor_kwargs=request.mm_processor_kwargs,
),
sampling_params=request.sampling_params,
eos_token_id=request.eos_token_id,
arrival_time=request.arrival_time,
@ -96,9 +113,21 @@ class Request:
def get_finished_reason(self) -> Union[str, None]:
return RequestStatus.get_finished_reason(self.status)
def has_encoder_inputs(self) -> bool:
return self.mm_data is not None
@property
def num_encoder_inputs(self) -> int:
return len(self.mm_positions)
def get_num_encoder_tokens(self, input_id: int) -> int:
assert input_id < len(self.mm_positions)
num_tokens = self.mm_positions[input_id]["length"]
return num_tokens
class RequestStatus(enum.IntEnum):
"""Status of a sequence."""
"""Status of a request."""
WAITING = 0
RUNNING = 1
PREEMPTED = 2
@ -119,7 +148,7 @@ class RequestStatus(enum.IntEnum):
# Mapping of finished statuses to their finish reasons.
# NOTE: The ignored sequences are the sequences whose prompt lengths
# NOTE: The ignored requests are the requests whose prompt lengths
# are longer than the model's length cap. Therefore, the stop
# reason should also be "length" as in OpenAI API.
_FINISHED_REASON_MAP = {


@ -1,7 +1,7 @@
import os
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Set
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
import numpy as np
import torch
@ -14,9 +14,10 @@ from vllm.compilation.config import CompilationConfig
from vllm.compilation.levels import CompilationLevel
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal import MultiModalKwargs
from vllm.plugins import set_compilation_config
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, cdiv,
@ -27,6 +28,7 @@ from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata
if TYPE_CHECKING:
from vllm.multimodal.base import PlaceholderRange
from vllm.v1.core.scheduler import SchedulerOutput
logger = init_logger(__name__)
@ -37,8 +39,8 @@ class GPUModelRunner:
def __init__(
self,
vllm_config: VllmConfig,
input_registry: InputRegistry = INPUT_REGISTRY,
):
# TODO: use ModelRunnerBase.__init__(self, vllm_config=vllm_config)
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
@ -75,10 +77,16 @@ class GPUModelRunner:
parallel_config)
self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
self.head_size = model_config.get_head_size()
self.hidden_size = model_config.get_hidden_size()
# Multi-modal data support
self.input_registry = input_registry
# Lazy initialization
# self.model: nn.Module # Set after load_model
self.kv_caches: List[torch.Tensor] = []
# req_id -> (input_id -> encoder_output)
self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}
# Request states.
self.requests: Dict[str, CachedRequestState] = {}
@ -96,18 +104,28 @@ class GPUModelRunner:
and not self.model_config.enforce_eager)
# TODO(woosuk): Provide an option to tune the max cudagraph batch size.
self.cudagraph_batch_sizes = [1, 2, 4] + [i for i in range(8, 513, 8)]
self.input_ids = torch.zeros(self.max_num_tokens,
dtype=torch.int32,
device=self.device)
self.positions = torch.zeros(self.max_num_tokens,
dtype=torch.int64,
device=self.device)
self.inputs_embeds = torch.zeros(
(self.max_num_tokens, self.hidden_size),
dtype=self.dtype,
device=self.device)
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
# Remove stopped requests from the cached states.
# Keep the states of the pre-empted requests.
for req_id in scheduler_output.finished_req_ids:
self.requests.pop(req_id, None)
self.encoder_cache.pop(req_id, None)
# Free the cached encoder outputs.
for req_id, input_id in scheduler_output.free_encoder_input_ids:
encoder_outputs = self.encoder_cache.get(req_id)
if encoder_outputs is not None:
encoder_outputs.pop(input_id, None)
if not encoder_outputs:
self.encoder_cache.pop(req_id, None)
# Remove the requests from the persistent batch.
stopped_req_ids = set().union(
@ -156,7 +174,8 @@ class GPUModelRunner:
req_id=req_id,
prompt_token_ids=req_data.prompt_token_ids,
prompt=req_data.prompt,
multi_modal_data=req_data.multi_modal_data,
mm_inputs=req_data.mm_inputs,
mm_positions=req_data.mm_positions,
sampling_params=sampling_params,
generator=generator,
block_ids=req_data.block_ids,
@ -285,11 +304,9 @@ class GPUModelRunner:
seq_start_loc_np[0] = 0
np.cumsum(seq_lens, out=seq_start_loc_np[1:])
self.input_ids[:total_num_scheduled_tokens].copy_(input_ids,
non_blocking=True)
input_ids = input_ids.to(self.device, non_blocking=True)
self.positions[:total_num_scheduled_tokens].copy_(positions,
non_blocking=True)
query_start_loc = query_start_loc.to(self.device, non_blocking=True)
seq_start_loc = seq_start_loc.to(self.device, non_blocking=True)
slot_mapping = slot_mapping.to(self.device, non_blocking=True).long()
@ -308,7 +325,7 @@ class GPUModelRunner:
# token from the partial request.
# TODO: Support prompt logprobs.
logits_indices = query_start_loc[1:] - 1
return attn_metadata, logits_indices
return input_ids, attn_metadata, logits_indices
def _prepare_sampling(
self,
@ -325,13 +342,91 @@ class GPUModelRunner:
sampling_metadata = self.input_batch.make_sampling_metadata(skip_copy)
return sampling_metadata
def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
if not scheduled_encoder_inputs:
return
# Batch the multi-modal inputs.
mm_inputs: List[MultiModalKwargs] = []
req_input_ids: List[Tuple[int, int]] = []
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
req_state = self.requests[req_id]
for input_id in encoder_input_ids:
mm_inputs.append(req_state.mm_inputs[input_id])
req_input_ids.append((req_id, input_id))
batched_mm_inputs = MultiModalKwargs.batch(mm_inputs)
batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
device=self.device)
# Run the encoder.
# `encoder_outputs` is either of the following:
# 1. A tensor of shape [num_images, feature_size, hidden_size]
#    when the feature size is the same for all images.
# 2. A list of num_images tensors, each of shape
#    [feature_size, hidden_size], when the feature size varies
#    across the input images.
encoder_outputs = self.model.process_mm_inputs(**batched_mm_inputs)
# Cache the encoder outputs.
for (req_id, input_id), output in zip(req_input_ids, encoder_outputs):
if req_id not in self.encoder_cache:
self.encoder_cache[req_id] = {}
self.encoder_cache[req_id][input_id] = output
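The zip over `req_input_ids` works for both output layouts described in the comment above; a tiny illustration with made-up sizes (576/144 feature tokens, hidden size 4096):

import torch

fixed = torch.zeros(2, 576, 4096)        # [num_images, feature_size, hidden_size]
ragged = [torch.zeros(576, 4096),        # per-image [feature_size, hidden_size]
          torch.zeros(144, 4096)]

for encoder_outputs in (fixed, ragged):
    for (req_id, input_id), output in zip([("r0", 0), ("r0", 1)], encoder_outputs):
        print(req_id, input_id, tuple(output.shape))   # always a per-image slice
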
def _gather_encoder_outputs(
self,
scheduler_output: "SchedulerOutput",
) -> List[torch.Tensor]:
encoder_outputs: List[torch.Tensor] = []
num_reqs = self.input_batch.num_reqs
for req_id in self.input_batch.req_ids[:num_reqs]:
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
req_id]
req_state = self.requests[req_id]
num_computed_tokens = req_state.num_computed_tokens
mm_positions = req_state.mm_positions
for i, pos_info in enumerate(mm_positions):
start_pos = pos_info["offset"]
num_encoder_tokens = pos_info["length"]
# The encoder output is needed if the two ranges overlap:
# [num_computed_tokens,
# num_computed_tokens + num_scheduled_tokens) and
# [start_pos, start_pos + num_encoder_tokens)
if start_pos >= num_computed_tokens + num_scheduled_tokens:
# The encoder output is not needed in this step.
break
if start_pos + num_encoder_tokens <= num_computed_tokens:
# The encoder output is already processed and stored
# in the decoder's KV cache.
continue
start_idx = max(num_computed_tokens - start_pos, 0)
end_idx = min(
num_computed_tokens - start_pos + num_scheduled_tokens,
num_encoder_tokens)
assert start_idx < end_idx
assert req_id in self.encoder_cache
assert i in self.encoder_cache[req_id]
encoder_output = self.encoder_cache[req_id][i]
encoder_outputs.append(encoder_output[start_idx:end_idx])
return encoder_outputs
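A hand-worked slice for one image, with illustrative numbers: the image occupies prompt positions [100, 676) and this step computes tokens [512, 1024):

num_computed_tokens, num_scheduled_tokens = 512, 512
start_pos, num_encoder_tokens = 100, 576

start_idx = max(num_computed_tokens - start_pos, 0)                   # 412
end_idx = min(num_computed_tokens - start_pos + num_scheduled_tokens,
              num_encoder_tokens)                                     # 576
# encoder_output[412:576] supplies the 164 image embeddings whose prompt
# positions fall inside this step's chunk.
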
@torch.inference_mode()
def execute_model(
self,
scheduler_output: "SchedulerOutput",
) -> ModelRunnerOutput:
self._update_states(scheduler_output)
attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
# Run the encoder.
self._execute_encoder(scheduler_output)
encoder_outputs = self._gather_encoder_outputs(scheduler_output)
# Prepare the decoder inputs.
input_ids, attn_metadata, logits_indices = self._prepare_inputs(
scheduler_output)
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if (self.use_cuda_graph
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
@ -343,12 +438,26 @@ class GPUModelRunner:
# Eager mode.
num_input_tokens = num_scheduled_tokens
# Get the inputs embeds.
if encoder_outputs:
inputs_embeds = self.model.get_input_embeddings(
input_ids, encoder_outputs)
else:
inputs_embeds = self.model.get_input_embeddings(input_ids)
# NOTE(woosuk): To unify token ids and soft tokens (vision embeddings),
# always use embeddings (rather than token ids) as input to the model.
# TODO(woosuk): Avoid the copy. Optimize.
self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds)
# Run the decoder.
# Use persistent buffers for CUDA graphs.
with set_forward_context(attn_metadata):
hidden_states = self.model(
input_ids=self.input_ids[:num_input_tokens],
input_ids=None,
positions=self.positions[:num_input_tokens],
kv_caches=self.kv_caches,
attn_metadata=None,
inputs_embeds=self.inputs_embeds[:num_input_tokens],
)
hidden_states = hidden_states[:num_scheduled_tokens]
hidden_states = hidden_states[logits_indices]
@ -440,13 +549,16 @@ class GPUModelRunner:
with set_forward_context(None): # noqa: SIM117
with set_compile_context(self.cudagraph_batch_sizes):
# Trigger compilation for general shape.
model(self.input_ids,
self.positions,
dummy_kv_caches,
attn_metadata=None)
model(input_ids=None,
positions=self.positions,
kv_caches=dummy_kv_caches,
attn_metadata=None,
inputs_embeds=self.inputs_embeds)
@torch.inference_mode()
def profile_run(self) -> None:
# TODO(woosuk): Profile the max memory usage of the encoder and
# the encoder cache.
self._dummy_run(self.model, self.max_num_tokens)
torch.cuda.synchronize()
@ -468,10 +580,11 @@ class GPUModelRunner:
# can reuse the memory pool allocated for the large shapes.
for num_tokens in reversed(self.cudagraph_batch_sizes):
self.model(
self.input_ids[:num_tokens],
self.positions[:num_tokens],
input_ids=None,
positions=self.positions[:num_tokens],
kv_caches=self.kv_caches,
attn_metadata=None,
inputs_embeds=self.inputs_embeds[:num_tokens],
)
end_time = time.perf_counter()
@ -506,7 +619,8 @@ class CachedRequestState:
req_id: str
prompt_token_ids: List[int]
prompt: Optional[str]
multi_modal_data: Optional["MultiModalDataDict"]
mm_inputs: List[MultiModalKwargs]
mm_positions: List["PlaceholderRange"]
sampling_params: SamplingParams
generator: Optional[torch.Generator]