mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-16 16:36:02 +08:00
[TPU][V1] Capture multimodal encoder during model compilation (#15051)
Signed-off-by: Michael Goin <mgoin64@gmail.com> Signed-off-by: NickLucche <nlucches@redhat.com> Co-authored-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Siyuan Liu <lsiyuan@google.com>
This commit is contained in:
parent
71eda0bb76
commit
210207525e
@ -17,7 +17,7 @@ source /etc/environment
|
|||||||
docker run --privileged --net host --shm-size=16G -it \
|
docker run --privileged --net host --shm-size=16G -it \
|
||||||
-e "HF_TOKEN=$HF_TOKEN" --name tpu-test \
|
-e "HF_TOKEN=$HF_TOKEN" --name tpu-test \
|
||||||
vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git \
|
vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git \
|
||||||
&& python3 -m pip install pytest tpu-info \
|
&& python3 -m pip install pytest pytest-asyncio tpu-info \
|
||||||
&& python3 -m pip install lm_eval[api]==0.4.4 \
|
&& python3 -m pip install lm_eval[api]==0.4.4 \
|
||||||
&& export VLLM_USE_V1=1 \
|
&& export VLLM_USE_V1=1 \
|
||||||
&& export VLLM_XLA_CHECK_RECOMPILATION=1 \
|
&& export VLLM_XLA_CHECK_RECOMPILATION=1 \
|
||||||
@ -42,6 +42,8 @@ docker run --privileged --net host --shm-size=16G -it \
|
|||||||
&& echo TEST_8 \
|
&& echo TEST_8 \
|
||||||
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_topk_topp_sampler.py \
|
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_topk_topp_sampler.py \
|
||||||
&& echo TEST_9 \
|
&& echo TEST_9 \
|
||||||
|
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_multimodal.py \
|
||||||
|
&& echo TEST_10 \
|
||||||
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py" \
|
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py" \
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
91
tests/v1/tpu/test_multimodal.py
Normal file
91
tests/v1/tpu/test_multimodal.py
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import openai
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from vllm import envs
|
||||||
|
from vllm.multimodal.utils import encode_image_base64, fetch_image
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
from ...entrypoints.openai.test_vision import TEST_IMAGE_URLS
|
||||||
|
from ...utils import RemoteOpenAIServer
|
||||||
|
|
||||||
|
if not envs.VLLM_USE_V1:
|
||||||
|
pytest.skip(
|
||||||
|
"Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test.",
|
||||||
|
allow_module_level=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def base64_encoded_image() -> dict[str, str]:
|
||||||
|
return {
|
||||||
|
image_url: encode_image_base64(fetch_image(image_url))
|
||||||
|
for image_url in TEST_IMAGE_URLS
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.skipif(not current_platform.is_tpu(),
|
||||||
|
reason="This test needs a TPU")
|
||||||
|
@pytest.mark.parametrize("model_name", ["llava-hf/llava-1.5-7b-hf"])
|
||||||
|
async def test_basic_vision(model_name: str, base64_encoded_image: dict[str,
|
||||||
|
str]):
|
||||||
|
|
||||||
|
def whats_in_this_image_msg(b64):
|
||||||
|
return [{
|
||||||
|
"role":
|
||||||
|
"user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "What's in this image?"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": f"data:image/jpeg;base64,{b64}"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}]
|
||||||
|
|
||||||
|
server_args = [
|
||||||
|
"--max-model-len",
|
||||||
|
"1024",
|
||||||
|
"--max-num-seqs",
|
||||||
|
"16",
|
||||||
|
"--gpu-memory-utilization",
|
||||||
|
"0.95",
|
||||||
|
"--trust-remote-code",
|
||||||
|
"--max-num-batched-tokens",
|
||||||
|
"576",
|
||||||
|
# NOTE: max-num-batched-tokens>=mm_item_size
|
||||||
|
"--disable_chunked_mm_input",
|
||||||
|
"--chat-template",
|
||||||
|
"examples/template_llava.jinja"
|
||||||
|
]
|
||||||
|
|
||||||
|
# Server will pre-compile on first startup (takes a long time).
|
||||||
|
with RemoteOpenAIServer(model_name, server_args,
|
||||||
|
max_wait_seconds=600) as remote_server:
|
||||||
|
client: openai.AsyncOpenAI = remote_server.get_async_client()
|
||||||
|
|
||||||
|
# Other requests now should be much faster
|
||||||
|
for image_url in TEST_IMAGE_URLS:
|
||||||
|
image_base64 = base64_encoded_image[image_url]
|
||||||
|
chat_completion_from_base64 = await client.chat.completions\
|
||||||
|
.create(
|
||||||
|
model=model_name,
|
||||||
|
messages=whats_in_this_image_msg(image_base64),
|
||||||
|
max_completion_tokens=24,
|
||||||
|
temperature=0.0)
|
||||||
|
result = chat_completion_from_base64
|
||||||
|
assert result
|
||||||
|
choice = result.choices[0]
|
||||||
|
assert choice.finish_reason == "length"
|
||||||
|
|
||||||
|
message = choice.message
|
||||||
|
message = result.choices[0].message
|
||||||
|
assert message.content is not None and len(message.content) >= 10
|
||||||
|
assert message.role == "assistant"
|
||||||
@ -1,5 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
import bisect
|
import bisect
|
||||||
|
import gc
|
||||||
import time
|
import time
|
||||||
from typing import TYPE_CHECKING, Optional, cast
|
from typing import TYPE_CHECKING, Optional, cast
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
@ -21,7 +22,8 @@ from vllm.forward_context import set_forward_context
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.model_loader import get_model
|
from vllm.model_executor.model_loader import get_model
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
|
||||||
|
PlaceholderRange)
|
||||||
from vllm.multimodal.utils import group_mm_inputs_by_modality
|
from vllm.multimodal.utils import group_mm_inputs_by_modality
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
|
from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
|
||||||
@ -37,8 +39,7 @@ from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
|
|||||||
from vllm.v1.utils import bind_kv_cache
|
from vllm.v1.utils import bind_kv_cache
|
||||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||||
|
|
||||||
from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs,
|
from .utils import sanity_check_mm_encoder_outputs
|
||||||
scatter_mm_placeholders)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
@ -198,7 +199,7 @@ class TPUModelRunner:
|
|||||||
device="cpu")
|
device="cpu")
|
||||||
self.slot_mapping_np = self.slot_mapping_cpu.numpy()
|
self.slot_mapping_np = self.slot_mapping_cpu.numpy()
|
||||||
self.block_table_cpu = torch.zeros(
|
self.block_table_cpu = torch.zeros(
|
||||||
(self.max_num_tokens, self.max_num_blocks_per_req),
|
(self.max_num_reqs, self.max_num_blocks_per_req),
|
||||||
dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
|
dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
|
||||||
device="cpu")
|
device="cpu")
|
||||||
|
|
||||||
@ -220,6 +221,37 @@ class TPUModelRunner:
|
|||||||
self.num_reqs_paddings = _get_req_paddings(
|
self.num_reqs_paddings = _get_req_paddings(
|
||||||
min_req_size=MIN_NUM_SEQS, max_req_size=self.max_num_reqs)
|
min_req_size=MIN_NUM_SEQS, max_req_size=self.max_num_reqs)
|
||||||
|
|
||||||
|
# Get maximum number of mm items per modality (batch size).
|
||||||
|
self.max_num_mm_items_by_modality = dict()
|
||||||
|
if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0
|
||||||
|
and self.encoder_cache_size > 0):
|
||||||
|
max_tokens_by_modality_dict = (
|
||||||
|
MULTIMODAL_REGISTRY.
|
||||||
|
get_max_tokens_per_item_by_nonzero_modality(self.model_config))
|
||||||
|
for modality, max_tokens in max_tokens_by_modality_dict.items():
|
||||||
|
# Check how many items of this modality can be supported by
|
||||||
|
# the encoder budget.
|
||||||
|
encoder_budget = min(self.max_num_encoder_input_tokens,
|
||||||
|
self.encoder_cache_size)
|
||||||
|
|
||||||
|
max_num_mm_items_encoder_budget = cdiv(encoder_budget,
|
||||||
|
max_tokens)
|
||||||
|
|
||||||
|
# Check how many items of this modality can be supported by
|
||||||
|
# the decoder budget.
|
||||||
|
max_mm_items_per_req = self.mm_registry.\
|
||||||
|
get_mm_limits_per_prompt(self.model_config)[modality]
|
||||||
|
|
||||||
|
# NOTE: We do not consider max_num_batched_tokens on purpose
|
||||||
|
# because the multimodal embeddings can be generated in advance
|
||||||
|
# and chunked prefilled.
|
||||||
|
max_num_mm_items_decoder_budget = self.max_num_reqs * \
|
||||||
|
max_mm_items_per_req
|
||||||
|
|
||||||
|
max_num_mm_items = min(max_num_mm_items_encoder_budget,
|
||||||
|
max_num_mm_items_decoder_budget)
|
||||||
|
self.max_num_mm_items_by_modality[modality] = max_num_mm_items
|
||||||
|
|
||||||
def _update_num_xla_graphs(self, case_str):
|
def _update_num_xla_graphs(self, case_str):
|
||||||
check_comp = self.check_recompilation and not self.enforce_eager
|
check_comp = self.check_recompilation and not self.enforce_eager
|
||||||
if not check_comp:
|
if not check_comp:
|
||||||
@ -606,29 +638,36 @@ class TPUModelRunner:
|
|||||||
# 2. A list or tuple (length: num_items) of tensors, each of shape
|
# 2. A list or tuple (length: num_items) of tensors, each of shape
|
||||||
# (feature_size, hidden_size) in case the feature size is dynamic
|
# (feature_size, hidden_size) in case the feature size is dynamic
|
||||||
# depending on the input multimodal items.
|
# depending on the input multimodal items.
|
||||||
|
xm.mark_step()
|
||||||
curr_group_outputs = self.model.get_multimodal_embeddings(
|
curr_group_outputs = self.model.get_multimodal_embeddings(
|
||||||
**batched_mm_inputs)
|
**batched_mm_inputs)
|
||||||
|
xm.mark_step()
|
||||||
|
|
||||||
sanity_check_mm_encoder_outputs(
|
sanity_check_mm_encoder_outputs(
|
||||||
curr_group_outputs,
|
curr_group_outputs,
|
||||||
expected_num_items=len(grouped_mm_inputs),
|
expected_num_items=len(grouped_mm_inputs),
|
||||||
)
|
)
|
||||||
|
|
||||||
for output in curr_group_outputs:
|
if isinstance(curr_group_outputs, torch.Tensor):
|
||||||
encoder_outputs.append(output)
|
encoder_outputs.append(curr_group_outputs)
|
||||||
|
else:
|
||||||
|
assert isinstance(curr_group_outputs, (list, tuple))
|
||||||
|
for output in curr_group_outputs:
|
||||||
|
encoder_outputs.append(output)
|
||||||
|
|
||||||
# Cache the encoder outputs.
|
# Cache the encoder outputs.
|
||||||
|
# NOTE (NickLucche) here we diverge from logic in other runners, as we
|
||||||
|
# assume to only have whole mm items to process. Hence we avoid the
|
||||||
|
# intrinsic dynamism that `scatter_mm_placeholders` introduces.
|
||||||
for (req_id, input_id, pos_info), output in zip(
|
for (req_id, input_id, pos_info), output in zip(
|
||||||
req_ids_pos,
|
req_ids_pos,
|
||||||
encoder_outputs,
|
encoder_outputs,
|
||||||
):
|
):
|
||||||
if req_id not in self.encoder_cache:
|
if req_id not in self.encoder_cache:
|
||||||
self.encoder_cache[req_id] = {}
|
self.encoder_cache[req_id] = {}
|
||||||
|
assert pos_info.is_embed is None, "Expected all positions to be"\
|
||||||
self.encoder_cache[req_id][input_id] = scatter_mm_placeholders(
|
" contiguous and embeddings."
|
||||||
output,
|
self.encoder_cache[req_id][input_id] = output
|
||||||
is_embed=pos_info.is_embed,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _gather_mm_embeddings(
|
def _gather_mm_embeddings(
|
||||||
self,
|
self,
|
||||||
@ -641,6 +680,10 @@ class TPUModelRunner:
|
|||||||
req_state = self.requests[req_id]
|
req_state = self.requests[req_id]
|
||||||
num_computed_tokens = req_state.num_computed_tokens
|
num_computed_tokens = req_state.num_computed_tokens
|
||||||
mm_positions = req_state.mm_positions
|
mm_positions = req_state.mm_positions
|
||||||
|
# TODO unroll loop and assume/enforce --disable_chunked_mm_input
|
||||||
|
# NOTE (NickLucche) here we diverge from logic in other runners, as
|
||||||
|
# we assume to only have whole mm items to process. Hence we avoid
|
||||||
|
# the intrinsic dynamism that `gather_mm_placeholders` introduces.
|
||||||
for i, pos_info in enumerate(mm_positions):
|
for i, pos_info in enumerate(mm_positions):
|
||||||
start_pos = pos_info.offset
|
start_pos = pos_info.offset
|
||||||
num_encoder_tokens = pos_info.length
|
num_encoder_tokens = pos_info.length
|
||||||
@ -657,25 +700,33 @@ class TPUModelRunner:
|
|||||||
# in the decoder's KV cache.
|
# in the decoder's KV cache.
|
||||||
continue
|
continue
|
||||||
|
|
||||||
start_idx = max(num_computed_tokens - start_pos, 0)
|
|
||||||
end_idx = min(
|
|
||||||
num_computed_tokens - start_pos + num_scheduled_tokens,
|
|
||||||
num_encoder_tokens)
|
|
||||||
assert start_idx < end_idx
|
|
||||||
assert req_id in self.encoder_cache
|
assert req_id in self.encoder_cache
|
||||||
assert i in self.encoder_cache[req_id]
|
assert i in self.encoder_cache[req_id]
|
||||||
|
assert pos_info.is_embed is None, "Expected all positions to"\
|
||||||
|
" be contiguous and embeddings."
|
||||||
encoder_output = self.encoder_cache[req_id][i]
|
encoder_output = self.encoder_cache[req_id][i]
|
||||||
|
mm_embeds.append(encoder_output)
|
||||||
if (is_embed := pos_info.is_embed) is not None:
|
|
||||||
is_embed = is_embed[start_idx:end_idx]
|
|
||||||
|
|
||||||
mm_embeds_item = gather_mm_placeholders(
|
|
||||||
encoder_output[start_idx:end_idx],
|
|
||||||
is_embed=is_embed,
|
|
||||||
)
|
|
||||||
mm_embeds.append(mm_embeds_item)
|
|
||||||
return mm_embeds
|
return mm_embeds
|
||||||
|
|
||||||
|
def _get_model_inputs(self, input_ids: torch.Tensor,
|
||||||
|
mm_embeds: list[torch.Tensor]):
|
||||||
|
if self.is_multimodal_model:
|
||||||
|
# NOTE(woosuk): To unify token ids and soft tokens (vision
|
||||||
|
# embeddings), we always use embeddings (rather than token ids)
|
||||||
|
# as input to the multimodal model, even when the input is text.
|
||||||
|
if mm_embeds:
|
||||||
|
inputs_embeds = self.model.get_input_embeddings(
|
||||||
|
input_ids, mm_embeds)
|
||||||
|
else:
|
||||||
|
inputs_embeds = self.model.get_input_embeddings(input_ids)
|
||||||
|
return None, inputs_embeds
|
||||||
|
else:
|
||||||
|
# For text-only models, we use token ids as input.
|
||||||
|
# While it is possible to use embeddings as input just like the
|
||||||
|
# multimodal models, it is not desirable for performance since
|
||||||
|
# then the embedding layer is not included in the CUDA graph.
|
||||||
|
return input_ids, None
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def execute_model(
|
def execute_model(
|
||||||
self,
|
self,
|
||||||
@ -694,27 +745,13 @@ class TPUModelRunner:
|
|||||||
mm_embeds = self._gather_mm_embeddings(scheduler_output)
|
mm_embeds = self._gather_mm_embeddings(scheduler_output)
|
||||||
else:
|
else:
|
||||||
mm_embeds = []
|
mm_embeds = []
|
||||||
|
xm.mark_step()
|
||||||
# Prepare inputs
|
# Prepare inputs
|
||||||
attn_metadata, logits_indices, padded_num_reqs = self._prepare_inputs(
|
attn_metadata, logits_indices, padded_num_reqs = self._prepare_inputs(
|
||||||
scheduler_output)
|
scheduler_output)
|
||||||
if self.is_multimodal_model:
|
input_ids, inputs_embeds = self._get_model_inputs(
|
||||||
# NOTE(woosuk): To unify token ids and soft tokens (vision
|
self.input_ids, mm_embeds)
|
||||||
# embeddings), we always use embeddings (rather than token ids)
|
xm.mark_step()
|
||||||
# as input to the multimodal model, even when the input is text.
|
|
||||||
if mm_embeds:
|
|
||||||
inputs_embeds = self.model.get_input_embeddings(
|
|
||||||
self.input_ids, mm_embeds)
|
|
||||||
else:
|
|
||||||
inputs_embeds = self.model.get_input_embeddings(self.input_ids)
|
|
||||||
input_ids = None
|
|
||||||
else:
|
|
||||||
# For text-only models, we use token ids as input.
|
|
||||||
# While it is possible to use embeddings as input just like the
|
|
||||||
# multimodal models, it is not desirable for performance since
|
|
||||||
# then the embedding layer is not included in the CUDA graph.
|
|
||||||
input_ids = self.input_ids
|
|
||||||
inputs_embeds = None
|
|
||||||
num_reqs = self.input_batch.num_reqs
|
num_reqs = self.input_batch.num_reqs
|
||||||
# Run the decoder
|
# Run the decoder
|
||||||
with set_forward_context(attn_metadata, self.vllm_config):
|
with set_forward_context(attn_metadata, self.vllm_config):
|
||||||
@ -890,9 +927,70 @@ class TPUModelRunner:
|
|||||||
inputs_embeds=inputs_embeds)
|
inputs_embeds=inputs_embeds)
|
||||||
self._hidden_states_dtype = out.dtype
|
self._hidden_states_dtype = out.dtype
|
||||||
|
|
||||||
|
def _precompile_mm_encoder(self) -> None:
|
||||||
|
# Pre-compile MM encoder for all supported data modalities.
|
||||||
|
hf_config = self.vllm_config.model_config.hf_config
|
||||||
|
for mode, max_items_by_mode in \
|
||||||
|
self.max_num_mm_items_by_modality.items():
|
||||||
|
logger.info(
|
||||||
|
"Compiling Multimodal %s Encoder with different input"
|
||||||
|
" shapes.", mode)
|
||||||
|
start = time.perf_counter()
|
||||||
|
# No padding for MM encoder just yet.
|
||||||
|
for num_items in range(1, max_items_by_mode + 1):
|
||||||
|
logger.info(" -- mode: %s items: %d", mode, num_items)
|
||||||
|
batched_dummy_mm_inputs = self._get_mm_dummy_batch(
|
||||||
|
mode, num_items)
|
||||||
|
# Run multimodal encoder.
|
||||||
|
xm.mark_step()
|
||||||
|
mm_embeds = self.model.\
|
||||||
|
get_multimodal_embeddings(**batched_dummy_mm_inputs)
|
||||||
|
xm.mark_step()
|
||||||
|
num_patches = mm_embeds[0].shape[0]
|
||||||
|
items_size = num_patches * num_items
|
||||||
|
|
||||||
|
# NOTE (NickLucche) pre-compile `get_input_embeddings` when mm
|
||||||
|
# embeddings are present. We assume `--disable-mm-chunked`,
|
||||||
|
# hence only whole items can be scheduled. This implies we just
|
||||||
|
# need to compile when `num_items` fit the (padded) `input_ids`
|
||||||
|
for num_tokens in self.num_tokens_paddings:
|
||||||
|
if num_tokens >= items_size:
|
||||||
|
# XLA Workaround: if torch.zeros(..device) is used, XLA
|
||||||
|
# compiles a scalar+expansion op, which won't match
|
||||||
|
# the graph generated at runtime. CPU->TPU must be used
|
||||||
|
placeholders_ids = torch.zeros(num_tokens,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device="cpu")
|
||||||
|
# Align placeholders and actual num mm_embeddings.
|
||||||
|
placeholders_ids[:items_size] = \
|
||||||
|
hf_config.image_token_index
|
||||||
|
|
||||||
|
placeholders_ids = placeholders_ids.to(self.device)
|
||||||
|
# Assign outputs or the graph will be cut short.
|
||||||
|
a, b = self._get_model_inputs(placeholders_ids,
|
||||||
|
[mm_embeds])
|
||||||
|
assert a is None
|
||||||
|
xm.mark_step()
|
||||||
|
|
||||||
|
# Pre-compile `get_input_embeddings` when mm_embeddings are not
|
||||||
|
# present. Chunk is only made of text, no mm_placeholders.
|
||||||
|
for num_tokens in self.num_tokens_paddings:
|
||||||
|
placeholders_ids = torch.zeros(num_tokens,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device="cpu")
|
||||||
|
placeholders_ids = placeholders_ids.to(self.device)
|
||||||
|
a, b = self._get_model_inputs(placeholders_ids, [])
|
||||||
|
assert a is None
|
||||||
|
xm.mark_step()
|
||||||
|
|
||||||
|
xm.wait_device_ops()
|
||||||
|
end = time.perf_counter()
|
||||||
|
logger.info(
|
||||||
|
"Multimodal %s Encoder compilation finished in in %.2f "
|
||||||
|
"[secs].", mode, end - start)
|
||||||
|
|
||||||
def _precompile_backbone(self) -> None:
|
def _precompile_backbone(self) -> None:
|
||||||
logger.info("Compiling the model with different input shapes.")
|
logger.info("Compiling the model with different input shapes.")
|
||||||
|
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
for num_tokens in self.num_tokens_paddings:
|
for num_tokens in self.num_tokens_paddings:
|
||||||
logger.info(" -- num_tokens: %d", num_tokens)
|
logger.info(" -- num_tokens: %d", num_tokens)
|
||||||
@ -962,11 +1060,70 @@ class TPUModelRunner:
|
|||||||
"""
|
"""
|
||||||
Precompile all the subgraphs with possible input shapes.
|
Precompile all the subgraphs with possible input shapes.
|
||||||
"""
|
"""
|
||||||
# TODO: precompile encoder
|
self._precompile_mm_encoder()
|
||||||
self._precompile_backbone()
|
self._precompile_backbone()
|
||||||
self._precompile_select_hidden_states()
|
self._precompile_select_hidden_states()
|
||||||
self._precompile_sample_from_hidden()
|
self._precompile_sample_from_hidden()
|
||||||
|
|
||||||
|
def profile_run(
|
||||||
|
self,
|
||||||
|
num_tokens: int,
|
||||||
|
) -> None:
|
||||||
|
# Profile with multimodal encoder & encoder cache.
|
||||||
|
# TODO: handle encoder-decoder models once we support them.
|
||||||
|
if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0
|
||||||
|
and self.encoder_cache_size > 0):
|
||||||
|
|
||||||
|
# NOTE: Currently model is profiled with a single non-text
|
||||||
|
# modality with the max possible input tokens even when
|
||||||
|
# it supports multiple.
|
||||||
|
dummy_data_modality, max_num_mm_items = max(
|
||||||
|
self.max_num_mm_items_by_modality.items(), key=lambda t: t[1])
|
||||||
|
|
||||||
|
encoder_budget = min(self.max_num_encoder_input_tokens,
|
||||||
|
self.encoder_cache_size)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Encoder cache will be initialized with a budget of %d tokens,"
|
||||||
|
" and profiled with %s %s items of the maximum feature size.",
|
||||||
|
encoder_budget, max_num_mm_items, dummy_data_modality)
|
||||||
|
|
||||||
|
# Create dummy batch of multimodal inputs.
|
||||||
|
batched_dummy_mm_inputs = self._get_mm_dummy_batch(
|
||||||
|
dummy_data_modality, max_num_mm_items)
|
||||||
|
|
||||||
|
# Run multimodal encoder.
|
||||||
|
# Isolate encoder graph from post-processing to minimize
|
||||||
|
# impact of recompilation until it's fixed.
|
||||||
|
start = time.perf_counter()
|
||||||
|
xm.mark_step()
|
||||||
|
dummy_encoder_outputs = self.model.get_multimodal_embeddings(
|
||||||
|
**batched_dummy_mm_inputs)
|
||||||
|
xm.mark_step()
|
||||||
|
xm.wait_device_ops()
|
||||||
|
end = time.perf_counter()
|
||||||
|
logger.info(
|
||||||
|
"Multimodal Encoder profiling finished in in %.2f [secs].",
|
||||||
|
end - start)
|
||||||
|
|
||||||
|
assert len(dummy_encoder_outputs) == max_num_mm_items, (
|
||||||
|
"Expected dimension 0 of encoder outputs to match the number "
|
||||||
|
f"of multimodal data items: {max_num_mm_items}, got "
|
||||||
|
f"{len(dummy_encoder_outputs)=} instead. This is most likely "
|
||||||
|
"due to the 'get_multimodal_embeddings' method of the model "
|
||||||
|
"not implemented correctly.")
|
||||||
|
|
||||||
|
# Cache the dummy encoder outputs.
|
||||||
|
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
|
||||||
|
|
||||||
|
# Trigger compilation for general shape.
|
||||||
|
self._dummy_run(num_tokens)
|
||||||
|
|
||||||
|
xm.mark_step()
|
||||||
|
xm.wait_device_ops()
|
||||||
|
self.encoder_cache.clear()
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
|
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
|
||||||
"""
|
"""
|
||||||
Initialize KV cache based on `kv_cache_config`.
|
Initialize KV cache based on `kv_cache_config`.
|
||||||
@ -1045,6 +1202,36 @@ class TPUModelRunner:
|
|||||||
def get_input_embeddings(self, *args, **kwargs):
|
def get_input_embeddings(self, *args, **kwargs):
|
||||||
return self.model.get_input_embeddings(*args, **kwargs)
|
return self.model.get_input_embeddings(*args, **kwargs)
|
||||||
|
|
||||||
|
def _get_mm_dummy_batch(self, modality: str,
|
||||||
|
batch_size: int) -> BatchedTensorInputs:
|
||||||
|
# Dummy data for pre-compiling multimodal models.
|
||||||
|
dummy_request_data = self.mm_registry.get_decoder_dummy_data(
|
||||||
|
model_config=self.model_config,
|
||||||
|
seq_len=self.max_num_tokens,
|
||||||
|
)
|
||||||
|
dummy_mm_data = dummy_request_data.multi_modal_data
|
||||||
|
|
||||||
|
# Dummy data definition in V0 may contain multiple multimodal items
|
||||||
|
# (e.g, multiple images) for a single request, therefore here we
|
||||||
|
# always replicate first item by max_num_mm_items times since in V1
|
||||||
|
# they are scheduled to be processed separately.
|
||||||
|
assert isinstance(dummy_mm_data, MultiModalKwargs), (
|
||||||
|
"Expected dummy multimodal data to be of type "
|
||||||
|
f"MultiModalKwargs, got {type(dummy_mm_data)=} instead. "
|
||||||
|
"This is most likely due to the model not having a merged "
|
||||||
|
"processor.")
|
||||||
|
|
||||||
|
# When models have a merged processor, their dummy data is
|
||||||
|
# already batched `MultiModalKwargs`, therefore we take the first
|
||||||
|
# `MultiModalKwargsItem` from the desired modality to profile on.
|
||||||
|
dummy_mm_item = dummy_mm_data.get_item(modality=modality, item_index=0)
|
||||||
|
dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item])
|
||||||
|
|
||||||
|
batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] *
|
||||||
|
batch_size)
|
||||||
|
return MultiModalKwargs.as_kwargs(batched_dummy_mm_inputs,
|
||||||
|
device=self.device)
|
||||||
|
|
||||||
|
|
||||||
def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]:
|
def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]:
|
||||||
logger.info("Preparing request paddings:")
|
logger.info("Preparing request paddings:")
|
||||||
@ -1088,7 +1275,6 @@ def _get_token_paddings(min_token_size: int, max_token_size: int,
|
|||||||
if num >= max_token_size:
|
if num >= max_token_size:
|
||||||
break
|
break
|
||||||
num *= 2
|
num *= 2
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logger.info("Using incremental token paddings:")
|
logger.info("Using incremental token paddings:")
|
||||||
while num <= padding_gap:
|
while num <= padding_gap:
|
||||||
|
|||||||
@ -157,7 +157,7 @@ class TPUWorker:
|
|||||||
runner_kv_caches)
|
runner_kv_caches)
|
||||||
|
|
||||||
# `max_num_tokens >= max_num_batched_tokens` due to padding.
|
# `max_num_tokens >= max_num_batched_tokens` due to padding.
|
||||||
self.model_runner._dummy_run(self.model_runner.max_num_tokens)
|
self.model_runner.profile_run(self.model_runner.max_num_tokens)
|
||||||
|
|
||||||
# Synchronize before measuring the memory usage.
|
# Synchronize before measuring the memory usage.
|
||||||
xm.wait_device_ops()
|
xm.wait_device_ops()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user