fix pre-commit

Signed-off-by: bk-201 <joy25810@foxmail.com>
This commit is contained in:
bk-201 2025-12-24 03:55:48 +00:00
commit ac8afb6ae8
18 changed files with 123 additions and 322 deletions

View File

@ -9,16 +9,6 @@
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
const torch::Tensor& block_mapping);
// Note: the key_caches and value_caches vectors are constant but
// not the Tensors they contain. The vectors need to be const refs
// in order to satisfy pytorch's C++ operator registration code.
void copy_blocks(std::vector<torch::Tensor> const& key_caches,
std::vector<torch::Tensor> const& value_caches,
const torch::Tensor& block_mapping);
void copy_blocks_mla(std::vector<torch::Tensor> const& kv_caches,
const torch::Tensor& block_mapping);
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping,

View File

@ -119,94 +119,6 @@ __global__ void copy_blocks_mla_kernel(
} // namespace vllm
// Note: the key_caches and value_caches vectors are constant but
// not the Tensors they contain. The vectors need to be const refs
// in order to satisfy pytorch's C++ operator registration code.
void copy_blocks(std::vector<torch::Tensor> const& key_caches,
std::vector<torch::Tensor> const& value_caches,
const torch::Tensor& block_mapping) {
int num_layers = key_caches.size();
TORCH_CHECK(num_layers == value_caches.size());
if (num_layers == 0) {
return;
}
torch::Device cache_device = key_caches[0].device();
TORCH_CHECK(cache_device.is_cuda());
// Create data structures for the kernel.
// Create an array of pointers to the key and value caches.
int64_t key_cache_ptrs[num_layers];
int64_t value_cache_ptrs[num_layers];
for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
key_cache_ptrs[layer_idx] =
reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
value_cache_ptrs[layer_idx] =
reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
}
// block_mapping is a 2D tensor with shape (num_pairs, 2).
int num_pairs = block_mapping.size(0);
// Move the data structures to the GPU.
// NOTE: This synchronizes the CPU and GPU.
torch::Tensor key_cache_ptrs_tensor =
torch::from_blob(key_cache_ptrs, {num_layers}, torch::kInt64)
.to(cache_device);
torch::Tensor value_cache_ptrs_tensor =
torch::from_blob(value_cache_ptrs, {num_layers}, torch::kInt64)
.to(cache_device);
// Launch the kernel.
const int numel_per_block = key_caches[0][0].numel();
dim3 grid(num_layers, num_pairs);
dim3 block(std::min(1024, numel_per_block));
const at::cuda::OptionalCUDAGuard device_guard(cache_device);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
key_cache_ptrs_tensor.data_ptr<int64_t>(),
value_cache_ptrs_tensor.data_ptr<int64_t>(),
block_mapping.data_ptr<int64_t>(), numel_per_block);
}));
}
// copy blocks kernel for MLA (assumes a joint KV-cache)
void copy_blocks_mla(std::vector<torch::Tensor> const& kv_caches,
const torch::Tensor& block_mapping) {
int num_layers = kv_caches.size();
if (num_layers == 0) {
return;
}
torch::Device cache_device = kv_caches[0].device();
TORCH_CHECK(cache_device.is_cuda(), "kv_cache must be on CUDA");
std::vector<int64_t> cache_ptrs(num_layers);
for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
cache_ptrs[layer_idx] =
reinterpret_cast<int64_t>(kv_caches[layer_idx].data_ptr());
}
torch::Tensor cache_ptrs_tensor =
torch::from_blob(cache_ptrs.data(), {num_layers}, torch::kInt64)
.to(cache_device);
int num_pairs = block_mapping.size(0);
// We use the stride instead of numel in case the cache is padded for memory
// alignment reasons, we assume the blocks data (inclusive of any padding)
// is contiguous in memory
int mem_footprint_per_block = kv_caches[0].stride(0);
dim3 grid(num_layers, num_pairs);
dim3 block(std::min(1024, mem_footprint_per_block));
const at::cuda::OptionalCUDAGuard device_guard(cache_device);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
kv_caches[0].scalar_type(), "copy_blocks_mla_kernel", ([&] {
vllm::copy_blocks_mla_kernel<scalar_t><<<grid, block, 0, stream>>>(
cache_ptrs_tensor.data_ptr<int64_t>(),
block_mapping.data_ptr<int64_t>(), mem_footprint_per_block);
}));
}
namespace vllm {
// Used to copy/convert one element

View File

@ -685,16 +685,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
"swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()");
cache_ops.impl("swap_blocks", torch::kCUDA, &swap_blocks);
// Copy the cache blocks from src to dst.
cache_ops.def(
"copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, "
"Tensor block_mapping) -> ()");
cache_ops.impl("copy_blocks", torch::kCUDA, &copy_blocks);
cache_ops.def(
"copy_blocks_mla(Tensor(a!)[] kv_caches, Tensor block_mapping) -> ()");
cache_ops.impl("copy_blocks_mla", torch::kCUDA, &copy_blocks_mla);
// Reshape the key and value tensors and cache them.
cache_ops.def(
"reshape_and_cache(Tensor key, Tensor value,"

View File

@ -0,0 +1,11 @@
model_name: "Qwen/Qwen3-Next-80B-A3B-Instruct-FP8"
accuracy_threshold: 0.85
num_questions: 1319
num_fewshot: 5
server_args: >-
--max-model-len 4096
--tensor-parallel-size 2
--enable-expert-parallel
--async-scheduling
env:
VLLM_USE_FLASHINFER_MOE_FP8: "1"

View File

@ -4,3 +4,4 @@ Qwen1.5-MoE-W4A16-CT.yaml
DeepSeek-V2-Lite-Instruct-FP8.yaml
Qwen3-30B-A3B-NVFP4.yaml
Qwen3-Next-80B-A3B-NVFP4-EP2.yaml
Qwen3-Next-FP8-EP2.yaml

View File

@ -71,6 +71,7 @@ def test_gsm8k_correctness(config_filename):
print(f"Number of questions: {eval_config['num_questions']}")
print(f"Number of few-shot examples: {eval_config['num_fewshot']}")
print(f"Server args: {' '.join(server_args)}")
print(f"Environment variables: {env_dict}")
# Launch server and run evaluation
with RemoteOpenAIServer(

View File

@ -40,93 +40,6 @@ KV_CACHE_DTYPE = ["auto", "fp8"]
RESHAPE_FLASH_IMPLEMENTATIONS = ["cuda", "triton"]
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_copy_blocks(
kv_cache_factory,
num_mappings: int,
num_layers: int,
num_heads: int,
head_size: int,
block_size: int,
num_blocks: int,
dtype: torch.dtype,
seed: int,
kv_cache_dtype: str,
device: str,
) -> None:
if kv_cache_dtype == "fp8" and head_size % 16:
pytest.skip()
current_platform.seed_everything(seed)
torch.set_default_device(device)
torch.cuda.set_device(device)
# Generate random block mappings where each source block is mapped to two
# destination blocks.
assert 2 * num_mappings <= num_blocks
src_blocks = random.sample(range(num_blocks), num_mappings)
remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
dst_blocks = random.sample(remaining_blocks, 2 * num_mappings)
block_mapping: list[tuple[int, int]] = []
for i in range(num_mappings):
src = src_blocks[i]
dst1 = dst_blocks[2 * i]
dst2 = dst_blocks[2 * i + 1]
block_mapping.append((src, dst1))
block_mapping.append((src, dst2))
# Create the KV caches.
key_caches, value_caches = kv_cache_factory(
num_blocks,
block_size,
num_layers,
num_heads,
head_size,
kv_cache_dtype,
dtype,
seed,
device,
)
# Clone the KV caches.
cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
cloned_value_caches = [value_cache.clone() for value_cache in value_caches]
# Call the copy blocks kernel.
block_mapping_tensor = torch.tensor(
block_mapping, dtype=torch.int64, device=device
).view(-1, 2)
opcheck(
torch.ops._C_cache_ops.copy_blocks,
(key_caches, value_caches, block_mapping_tensor),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
cond=(head_size == HEAD_SIZES[0]),
)
ops.copy_blocks(key_caches, value_caches, block_mapping_tensor)
# Run the reference implementation.
for src, dst in block_mapping:
for cloned_key_cache in cloned_key_caches:
cloned_key_cache[dst].copy_(cloned_key_cache[src])
for cloned_value_cache in cloned_value_caches:
cloned_value_cache[dst].copy_(cloned_value_cache[src])
# Compare the results.
for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
torch.testing.assert_close(key_cache, cloned_key_cache)
for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches):
torch.testing.assert_close(value_cache, cloned_value_cache)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@ -763,73 +676,6 @@ def test_concat_and_cache_ds_mla(
torch.testing.assert_close(kv_rope, ref_rope, atol=0.001, rtol=0.1)
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_copy_blocks_mla(
kv_lora_rank: int,
qk_rope_head_dim: int,
block_size: int,
num_blocks: int,
num_layers: int,
dtype: torch.dtype,
seed: int,
device: str,
kv_cache_dtype: str,
) -> None:
current_platform.seed_everything(seed)
torch.set_default_device(device)
torch.cuda.set_device(device)
entry_size = kv_lora_rank + qk_rope_head_dim
kv_caches = []
for _ in range(num_layers):
kv_cache = _create_mla_cache(
num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
)
_fill_mla_cache(kv_cache, kv_cache_dtype=kv_cache_dtype)
kv_caches.append(kv_cache)
ref_caches = [kv_cache.clone() for kv_cache in kv_caches]
num_mappings = min(2, num_blocks // 2)
src_blocks = random.sample(range(num_blocks), num_mappings)
remaining = list(set(range(num_blocks)) - set(src_blocks))
dst_blocks = random.sample(remaining, 2 * num_mappings)
block_mapping = []
for i in range(num_mappings):
src = src_blocks[i]
dst1 = dst_blocks[2 * i]
dst2 = dst_blocks[2 * i + 1]
block_mapping.append((src, dst1))
block_mapping.append((src, dst2))
block_mapping_tensor = torch.tensor(
block_mapping, dtype=torch.int64, device=device
).view(-1, 2)
for src, dst in block_mapping:
for ref_cache in ref_caches:
ref_cache[dst].copy_(ref_cache[src])
opcheck(
torch.ops._C_cache_ops.copy_blocks_mla,
(kv_caches, block_mapping_tensor),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
)
ops.copy_blocks_mla(kv_caches, block_mapping_tensor)
for kv_cache, ref_cache in zip(kv_caches, ref_caches):
torch.testing.assert_close(kv_cache, ref_cache)
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)

View File

@ -106,6 +106,7 @@ class RemoteOpenAIServer:
env.update(env_dict)
serve_cmd = ["vllm", "serve", model, *vllm_serve_args]
print(f"Launching RemoteOpenAIServer with: {' '.join(serve_cmd)}")
print(f"Environment variables: {env}")
self.proc: subprocess.Popen = subprocess.Popen(
serve_cmd,
env=env,

View File

@ -1356,6 +1356,69 @@ def test_kv_cache_events(blocks_to_cache: int):
assert len(manager.block_pool.cached_block_hash_to_block) == 0
def test_null_parent_block_hash():
block_size = 1
num_cached_blocks = 2
num_full_blocks = 4
pool = BlockPool(
num_gpu_blocks=8,
enable_caching=True,
hash_block_size=block_size,
enable_kv_cache_events=True,
)
req = make_request(
"req_null_parent",
prompt_token_ids=[10, 11, 12, 13],
block_size=block_size,
hash_fn=sha256,
)
assert len(req.block_hashes) == num_full_blocks
# Physical parent is `null_block` (no hash), while the logical parent hash
# still exists in `request.block_hashes[num_cached_blocks - 1]`.
assert pool.null_block.block_hash is None
new_blocks = pool.get_new_blocks(num_full_blocks - 1)
blocks = [
new_blocks[: num_cached_blocks - 1],
pool.null_block, # physical parent
*new_blocks[num_cached_blocks - 1 :],
]
pool.cache_full_blocks(
request=req,
blocks=blocks,
num_cached_blocks=num_cached_blocks,
num_full_blocks=num_full_blocks,
block_size=block_size,
kv_cache_group_id=0,
)
events = pool.take_events()
assert len(events) == 1
event = events[0]
assert isinstance(event, BlockStored)
expected_parent = kv_cache_utils.maybe_convert_block_hash(
req.block_hashes[num_cached_blocks - 1]
)
assert event.parent_block_hash == expected_parent
assert event.parent_block_hash is not None
expected_new_hashes = [
kv_cache_utils.maybe_convert_block_hash(h)
for h in req.block_hashes[num_cached_blocks:num_full_blocks]
]
assert event.block_hashes == expected_new_hashes
# Ensure we didn't accidentally assign a hash to the null block.
assert pool.null_block.block_hash is None
# Sanity check: newly cached physical blocks should have hashes assigned.
assert blocks[num_cached_blocks].block_hash is not None
assert blocks[num_full_blocks - 1].block_hash is not None
@pytest.mark.parametrize("blocks_to_cache", [2, 3, 10])
def test_kv_cache_events_with_lora(blocks_to_cache: int):
"""Test BlockStored events contain correct lora_id when using LoRA requests."""

View File

@ -2328,18 +2328,6 @@ def concat_and_cache_mla(
)
def copy_blocks(
key_caches: list[torch.Tensor],
value_caches: list[torch.Tensor],
block_mapping: torch.Tensor,
) -> None:
torch.ops._C_cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
def copy_blocks_mla(kv_caches: list[torch.Tensor], block_mapping: torch.Tensor) -> None:
torch.ops._C_cache_ops.copy_blocks_mla(kv_caches, block_mapping)
def swap_blocks(
src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor
) -> None:

View File

@ -383,18 +383,6 @@ class ipex_ops:
)
return None
@staticmethod
def copy_blocks(
key_caches: list[torch.Tensor],
value_caches: list[torch.Tensor],
block_mapping: torch.Tensor,
) -> None:
torch.xpu.copy_blocks( # type: ignore
key_caches,
value_caches,
block_mapping,
)
@staticmethod
def swap_blocks(
src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor

View File

@ -408,7 +408,13 @@ class MooncakeConnectorWorker:
self.engine = TransferEngine()
self.hostname = get_ip()
ret_value = self.engine.initialize(self.hostname, "P2PHANDSHAKE", "rdma", "")
protocol = self.vllm_config.kv_transfer_config.kv_connector_extra_config.get( # type: ignore[union-attr]
"mooncake_protocol", "rdma"
)
logger.info(
"The Mooncake Transfer Engine is using %s as its protocol.", protocol
)
ret_value = self.engine.initialize(self.hostname, "P2PHANDSHAKE", protocol, "")
if ret_value != 0:
raise RuntimeError("Mooncake Transfer Engine initialization failed.")

View File

@ -111,7 +111,7 @@ class AudioFlamingo3EmbeddingInputs(TensorSchema):
audio_embeds: Annotated[
list[torch.Tensor],
TensorShape("bn", "naf", "hs"),
TensorShape("bn", "naf", "hs", dynamic_dims={"naf"}),
]

View File

@ -878,11 +878,14 @@ class Indexer(nn.Module):
)
q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1))
# `rotary_emb` is shape-preserving; `q_pe` is already
# [num_tokens, n_head, rope_dim].
# Note: RoPE (NeoX) can introduce extra leading dimensions during compilation
# so we need to reshape back to token-flattened shapes
q_pe = q_pe.reshape(-1, self.n_head, self.rope_dim)
k_pe = k_pe.reshape(-1, 1, self.rope_dim)
q = torch.cat([q_pe, q_nope], dim=-1)
# `k_pe` is [num_tokens, 1, rope_dim] (MQA).
k = torch.cat([k_pe.squeeze(1), k_nope], dim=-1)
k = torch.cat([k_pe.squeeze(-2), k_nope], dim=-1)
# we only quant q here since k quant is fused with cache insertion
q = q.view(-1, self.head_dim)

View File

@ -139,7 +139,7 @@ class MiniCPMVImageEmbeddingInputs(TensorSchema):
type: Literal["image_embeds"]
image_embeds: Annotated[
torch.Tensor | list[torch.Tensor],
TensorShape("bn", "ns", "hs"),
TensorShape("bn", "ns", "hs", dynamic_dims={"ns"}),
]

View File

@ -101,7 +101,7 @@ class Qwen2AudioEmbeddingInputs(TensorSchema):
audio_embeds: Annotated[
list[torch.Tensor],
TensorShape("bn", "naf", "hs"),
TensorShape("bn", "naf", "hs", dynamic_dims={"naf"}),
]

View File

@ -270,10 +270,8 @@ class BlockPool:
if num_cached_blocks == 0:
parent_block_hash: ExternalBlockHash | None = None
else:
parent_block = blocks[num_cached_blocks - 1]
assert parent_block.block_hash is not None
parent_block_hash = maybe_convert_block_hash(
get_block_hash(parent_block.block_hash)
block_hashes[num_cached_blocks - 1]
)
self.kv_event_queue.append(

View File

@ -62,6 +62,7 @@ from vllm.model_executor.layers.rotary_embedding import (
)
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
from vllm.model_executor.models.interfaces import (
MultiModalEmbeddings,
SupportsMRoPE,
SupportsMultiModal,
SupportsXDRoPE,
@ -2098,35 +2099,35 @@ class GPUModelRunner(
]
return logits_indices_padded
def _batch_mm_kwargs_from_scheduler(
def _batch_mm_inputs_from_scheduler(
self,
scheduler_output: "SchedulerOutput",
) -> tuple[
list[str],
list[MultiModalKwargsItem],
list[tuple[str, PlaceholderRange]],
list[str],
]:
"""Batch multimodal kwargs from scheduled encoder inputs.
"""Batch multimodal inputs from scheduled encoder inputs.
Args:
scheduler_output: The scheduler output containing scheduled encoder
inputs.
Returns:
A tuple of (mm_kwargs, mm_hashes_pos, req_ids) where:
- mm_kwargs: List of multimodal kwargs items to be batched
- mm_hashes_pos: List of (mm_hash, position_info) tuples
- req_ids: List of request IDs for each encoder input
A tuple of (mm_hashes, mm_kwargs, mm_lora_refs) where:
- mm_hashes: List of multimodal hashes for each item
- mm_kwargs: List of multimodal kwargs for each item
- mm_lora_refs: List of (req_id, placeholder_range) for each item
"""
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
if not scheduled_encoder_inputs:
return [], [], []
# Batch the multi-modal inputs.
mm_hashes = list[str]()
mm_kwargs = list[MultiModalKwargsItem]()
# list of tuple (mm_hash, position_info)
mm_hashes_pos = list[tuple[str, PlaceholderRange]]()
# list of request IDs for each encoder input
req_ids = list[str]()
# Multimodal LoRA reference info to map each multimodal item
# back to its request & position
mm_lora_refs = list[tuple[str, PlaceholderRange]]()
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
req_state = self.requests[req_id]
@ -2134,19 +2135,18 @@ class GPUModelRunner(
mm_feature = req_state.mm_features[mm_input_id]
if mm_feature.data is None:
continue
mm_hash = mm_feature.identifier
mm_kwargs.append(mm_feature.data)
mm_hashes_pos.append((mm_hash, mm_feature.mm_position))
req_ids.append(req_id)
return mm_kwargs, mm_hashes_pos, req_ids
mm_hashes.append(mm_feature.identifier)
mm_kwargs.append(mm_feature.data)
mm_lora_refs.append((req_id, mm_feature.mm_position))
return mm_hashes, mm_kwargs, mm_lora_refs
def _execute_mm_encoder(
self, scheduler_output: "SchedulerOutput"
) -> list[torch.Tensor]:
# Batch the multi-modal inputs using the helper method.
mm_kwargs, mm_hashes_pos, encoder_req_ids = (
self._batch_mm_kwargs_from_scheduler(scheduler_output)
mm_hashes, mm_kwargs, mm_lora_refs = self._batch_mm_inputs_from_scheduler(
scheduler_output
)
if not mm_kwargs:
@ -2168,7 +2168,7 @@ class GPUModelRunner(
token_lora_mapping = []
lora_requests = set()
for req_id, (_, pos_info) in zip(encoder_req_ids, mm_hashes_pos):
for req_id, pos_info in mm_lora_refs:
req_idx = self.input_batch.req_id_to_index[req_id]
lora_id = int(self.input_batch.request_lora_mapping[req_idx])
@ -2196,7 +2196,7 @@ class GPUModelRunner(
if hasattr(self.model, "get_num_mm_connector_tokens"):
num_post_op_tokens = []
for _, pos_info in mm_hashes_pos:
for _, pos_info in mm_lora_refs:
mm_token_count = self.model.get_num_mm_encoder_tokens( # type: ignore[attr-defined]
pos_info.length
)
@ -2232,7 +2232,7 @@ class GPUModelRunner(
device=self.device,
pin_memory=self.pin_memory,
):
curr_group_outputs: list[torch.Tensor] = []
curr_group_outputs: MultiModalEmbeddings
# EVS-related change.
# (ekhvedchenia): Temporary hack to limit peak memory usage when
@ -2248,6 +2248,7 @@ class GPUModelRunner(
and modality == "video"
and num_items > 1
):
curr_group_outputs_lst = list[torch.Tensor]()
for video_mm_kwargs_item in filter(
lambda item: item.modality == "video", mm_kwargs
):
@ -2263,7 +2264,9 @@ class GPUModelRunner(
**micro_batch_mm_inputs
)
curr_group_outputs.extend(micro_batch_outputs)
curr_group_outputs_lst.extend(micro_batch_outputs)
curr_group_outputs = curr_group_outputs_lst
else:
# Run the encoder.
# `curr_group_outputs` is either of the following:
@ -2272,7 +2275,7 @@ class GPUModelRunner(
# 2. A list or tuple (length: num_items) of tensors,
# each of shape (feature_size, hidden_size) in case the feature
# size is dynamic depending on the input multimodal items.
curr_group_outputs = model.embed_multimodal(**mm_kwargs_group) # type: ignore[assignment]
curr_group_outputs = model.embed_multimodal(**mm_kwargs_group)
sanity_check_mm_encoder_outputs(
curr_group_outputs,
@ -2281,7 +2284,7 @@ class GPUModelRunner(
encoder_outputs.extend(curr_group_outputs)
# Cache the encoder outputs by mm_hash
for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs):
for mm_hash, output in zip(mm_hashes, encoder_outputs):
self.encoder_cache[mm_hash] = output
logger.debug("Finish execute for mm hash %s", mm_hash)
self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash)