From bc0a5a0c089844b17cb93f3294348f411e523586 Mon Sep 17 00:00:00 2001 From: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com> Date: Wed, 24 Dec 2025 05:21:50 +0400 Subject: [PATCH 1/7] [CI] Add Qwen3-Next-FP8 to Blackwell model tests (#31049) Signed-off-by: Vadim Gimpelson --- tests/evals/gsm8k/configs/Qwen3-Next-FP8-EP2.yaml | 11 +++++++++++ tests/evals/gsm8k/configs/models-blackwell.txt | 1 + tests/evals/gsm8k/test_gsm8k_correctness.py | 1 + tests/utils.py | 1 + 4 files changed, 14 insertions(+) create mode 100644 tests/evals/gsm8k/configs/Qwen3-Next-FP8-EP2.yaml diff --git a/tests/evals/gsm8k/configs/Qwen3-Next-FP8-EP2.yaml b/tests/evals/gsm8k/configs/Qwen3-Next-FP8-EP2.yaml new file mode 100644 index 0000000000000..9fae32734d753 --- /dev/null +++ b/tests/evals/gsm8k/configs/Qwen3-Next-FP8-EP2.yaml @@ -0,0 +1,11 @@ +model_name: "Qwen/Qwen3-Next-80B-A3B-Instruct-FP8" +accuracy_threshold: 0.85 +num_questions: 1319 +num_fewshot: 5 +server_args: >- + --max-model-len 4096 + --tensor-parallel-size 2 + --enable-expert-parallel + --async-scheduling +env: + VLLM_USE_FLASHINFER_MOE_FP8: "1" diff --git a/tests/evals/gsm8k/configs/models-blackwell.txt b/tests/evals/gsm8k/configs/models-blackwell.txt index 39978aa6ffbe9..c27031d25fb8c 100644 --- a/tests/evals/gsm8k/configs/models-blackwell.txt +++ b/tests/evals/gsm8k/configs/models-blackwell.txt @@ -4,3 +4,4 @@ Qwen1.5-MoE-W4A16-CT.yaml DeepSeek-V2-Lite-Instruct-FP8.yaml Qwen3-30B-A3B-NVFP4.yaml Qwen3-Next-80B-A3B-NVFP4-EP2.yaml +Qwen3-Next-FP8-EP2.yaml diff --git a/tests/evals/gsm8k/test_gsm8k_correctness.py b/tests/evals/gsm8k/test_gsm8k_correctness.py index ea6715f5cb532..dd0d3ae0cca47 100644 --- a/tests/evals/gsm8k/test_gsm8k_correctness.py +++ b/tests/evals/gsm8k/test_gsm8k_correctness.py @@ -71,6 +71,7 @@ def test_gsm8k_correctness(config_filename): print(f"Number of questions: {eval_config['num_questions']}") print(f"Number of few-shot examples: {eval_config['num_fewshot']}") print(f"Server args: {' '.join(server_args)}") + print(f"Environment variables: {env_dict}") # Launch server and run evaluation with RemoteOpenAIServer( diff --git a/tests/utils.py b/tests/utils.py index d8102331b3612..1b338e93182a5 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -106,6 +106,7 @@ class RemoteOpenAIServer: env.update(env_dict) serve_cmd = ["vllm", "serve", model, *vllm_serve_args] print(f"Launching RemoteOpenAIServer with: {' '.join(serve_cmd)}") + print(f"Environment variables: {env}") self.proc: subprocess.Popen = subprocess.Popen( serve_cmd, env=env, From ca6a95ba259aaa4c89eccdd254fa7922e31eddc2 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 24 Dec 2025 10:15:16 +0800 Subject: [PATCH 2/7] [Chore] Simplify logic of `_execute_mm_encoder` (#31222) Signed-off-by: DarkLight1337 --- vllm/v1/worker/gpu_model_runner.py | 48 ++++++++++++++---------------- 1 file changed, 22 insertions(+), 26 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 00c585aaaacbb..16fc9fd7cb4d8 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -61,6 +61,7 @@ from vllm.model_executor.layers.rotary_embedding import ( ) from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader from vllm.model_executor.models.interfaces import ( + MultiModalEmbeddings, SupportsMRoPE, SupportsMultiModal, SupportsXDRoPE, @@ -78,11 +79,7 @@ from vllm.model_executor.models.interfaces_base import ( is_text_generation_model, ) from vllm.multimodal import 
MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import ( - BatchedTensorInputs, - MultiModalKwargsItem, - PlaceholderRange, -) +from vllm.multimodal.inputs import BatchedTensorInputs, MultiModalKwargsItem from vllm.multimodal.utils import group_mm_kwargs_by_modality from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingType @@ -2097,28 +2094,27 @@ class GPUModelRunner( ] return logits_indices_padded - def _batch_mm_kwargs_from_scheduler( + def _batch_mm_inputs_from_scheduler( self, scheduler_output: "SchedulerOutput", - ) -> tuple[list[MultiModalKwargsItem], list[tuple[str, PlaceholderRange]]]: - """Batch multimodal kwargs from scheduled encoder inputs. + ) -> tuple[list[str], list[MultiModalKwargsItem]]: + """Batch multimodal inputs from scheduled encoder inputs. Args: scheduler_output: The scheduler output containing scheduled encoder inputs. Returns: - A tuple of (mm_kwargs, req_ids_pos) where: - - mm_kwargs: List of multimodal kwargs items to be batched - - mm_hashes_pos: List of (mm_hash, position_info) tuples + A tuple of (mm_hashes, mm_kwargs) where: + - mm_hashes: List of multimodal hashes for each item + - mm_kwargs: List of multimodal kwargs for each item """ scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs if not scheduled_encoder_inputs: return [], [] - # Batch the multi-modal inputs. + + mm_hashes = list[str]() mm_kwargs = list[MultiModalKwargsItem]() - # list of tuple (mm_hash, position_info) - mm_hashes_pos = list[tuple[str, PlaceholderRange]]() for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): req_state = self.requests[req_id] @@ -2126,19 +2122,16 @@ class GPUModelRunner( mm_feature = req_state.mm_features[mm_input_id] if mm_feature.data is None: continue - mm_hash = mm_feature.identifier - mm_kwargs.append(mm_feature.data) - mm_hashes_pos.append((mm_hash, mm_feature.mm_position)) - return mm_kwargs, mm_hashes_pos + mm_hashes.append(mm_feature.identifier) + mm_kwargs.append(mm_feature.data) + + return mm_hashes, mm_kwargs def _execute_mm_encoder( self, scheduler_output: "SchedulerOutput" ) -> list[torch.Tensor]: - # Batch the multi-modal inputs using the helper method. - mm_kwargs, mm_hashes_pos = self._batch_mm_kwargs_from_scheduler( - scheduler_output - ) + mm_hashes, mm_kwargs = self._batch_mm_inputs_from_scheduler(scheduler_output) if not mm_kwargs: return [] @@ -2157,7 +2150,7 @@ class GPUModelRunner( device=self.device, pin_memory=self.pin_memory, ): - curr_group_outputs: list[torch.Tensor] = [] + curr_group_outputs: MultiModalEmbeddings # EVS-related change. # (ekhvedchenia): Temporary hack to limit peak memory usage when @@ -2173,6 +2166,7 @@ class GPUModelRunner( and modality == "video" and num_items > 1 ): + curr_group_outputs_lst = list[torch.Tensor]() for video_mm_kwargs_item in filter( lambda item: item.modality == "video", mm_kwargs ): @@ -2188,7 +2182,9 @@ class GPUModelRunner( **micro_batch_mm_inputs ) - curr_group_outputs.extend(micro_batch_outputs) + curr_group_outputs_lst.extend(micro_batch_outputs) + + curr_group_outputs = curr_group_outputs_lst else: # Run the encoder. # `curr_group_outputs` is either of the following: @@ -2197,7 +2193,7 @@ class GPUModelRunner( # 2. A list or tuple (length: num_items) of tensors, # each of shape (feature_size, hidden_size) in case the feature # size is dynamic depending on the input multimodal items. 
- curr_group_outputs = model.embed_multimodal(**mm_kwargs_group) # type: ignore[assignment] + curr_group_outputs = model.embed_multimodal(**mm_kwargs_group) sanity_check_mm_encoder_outputs( curr_group_outputs, @@ -2206,7 +2202,7 @@ class GPUModelRunner( encoder_outputs.extend(curr_group_outputs) # Cache the encoder outputs by mm_hash - for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs): + for mm_hash, output in zip(mm_hashes, encoder_outputs): self.encoder_cache[mm_hash] = output logger.debug("Finish execute for mm hash %s", mm_hash) self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash) From dd424571c8edaf7c68fe4ff400da8ef9f26a1e48 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 24 Dec 2025 10:15:47 +0800 Subject: [PATCH 3/7] [Bugfix] Enable `dynamic_dims` for different embeds shape (#31223) Signed-off-by: DarkLight1337 --- vllm/model_executor/models/audioflamingo3.py | 2 +- vllm/model_executor/models/minicpmv.py | 2 +- vllm/model_executor/models/qwen2_audio.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/audioflamingo3.py b/vllm/model_executor/models/audioflamingo3.py index 0ca5f2c4e0a75..3609cc26a4c6b 100644 --- a/vllm/model_executor/models/audioflamingo3.py +++ b/vllm/model_executor/models/audioflamingo3.py @@ -111,7 +111,7 @@ class AudioFlamingo3EmbeddingInputs(TensorSchema): audio_embeds: Annotated[ list[torch.Tensor], - TensorShape("bn", "naf", "hs"), + TensorShape("bn", "naf", "hs", dynamic_dims={"naf"}), ] diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index c45bdf95e7487..930ff737bcdac 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -139,7 +139,7 @@ class MiniCPMVImageEmbeddingInputs(TensorSchema): type: Literal["image_embeds"] image_embeds: Annotated[ torch.Tensor | list[torch.Tensor], - TensorShape("bn", "ns", "hs"), + TensorShape("bn", "ns", "hs", dynamic_dims={"ns"}), ] diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index f84ddfa84f6ab..c97e6873e0d17 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -101,7 +101,7 @@ class Qwen2AudioEmbeddingInputs(TensorSchema): audio_embeds: Annotated[ list[torch.Tensor], - TensorShape("bn", "naf", "hs"), + TensorShape("bn", "naf", "hs", dynamic_dims={"naf"}), ] From 4ed11105d7b81fe9005b119c86ef42f198ced2cf Mon Sep 17 00:00:00 2001 From: "rongfu.leng" Date: Wed, 24 Dec 2025 10:22:35 +0800 Subject: [PATCH 4/7] [Misc] Remove unused custom ops `copy_blocks` and `copy_blocks_mla` (#30967) Signed-off-by: rongfu.leng --- csrc/cache.h | 10 -- csrc/cache_kernels.cu | 88 --------------- csrc/torch_bindings.cpp | 10 -- tests/kernels/attention/test_cache.py | 154 -------------------------- vllm/_custom_ops.py | 12 -- vllm/_ipex_ops.py | 12 -- 6 files changed, 286 deletions(-) diff --git a/csrc/cache.h b/csrc/cache.h index cbe44c09eb624..42ccb589683a9 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -9,16 +9,6 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst, const torch::Tensor& block_mapping); -// Note: the key_caches and value_caches vectors are constant but -// not the Tensors they contain. The vectors need to be const refs -// in order to satisfy pytorch's C++ operator registration code. 
-void copy_blocks(std::vector<torch::Tensor> const& key_caches,
-                 std::vector<torch::Tensor> const& value_caches,
-                 const torch::Tensor& block_mapping);
-
-void copy_blocks_mla(std::vector<torch::Tensor> const& kv_caches,
-                     const torch::Tensor& block_mapping);
-
 void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
                        torch::Tensor& key_cache, torch::Tensor& value_cache,
                        torch::Tensor& slot_mapping,
diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index f11c5f24c12ec..cf26ae544deaa 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -119,94 +119,6 @@ __global__ void copy_blocks_mla_kernel(
 
 }  // namespace vllm
 
-// Note: the key_caches and value_caches vectors are constant but
-// not the Tensors they contain. The vectors need to be const refs
-// in order to satisfy pytorch's C++ operator registration code.
-void copy_blocks(std::vector<torch::Tensor> const& key_caches,
-                 std::vector<torch::Tensor> const& value_caches,
-                 const torch::Tensor& block_mapping) {
-  int num_layers = key_caches.size();
-  TORCH_CHECK(num_layers == value_caches.size());
-  if (num_layers == 0) {
-    return;
-  }
-  torch::Device cache_device = key_caches[0].device();
-  TORCH_CHECK(cache_device.is_cuda());
-
-  // Create data structures for the kernel.
-  // Create an array of pointers to the key and value caches.
-  int64_t key_cache_ptrs[num_layers];
-  int64_t value_cache_ptrs[num_layers];
-  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
-    key_cache_ptrs[layer_idx] =
-        reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
-    value_cache_ptrs[layer_idx] =
-        reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
-  }
-
-  // block_mapping is a 2D tensor with shape (num_pairs, 2).
-  int num_pairs = block_mapping.size(0);
-
-  // Move the data structures to the GPU.
-  // NOTE: This synchronizes the CPU and GPU.
-  torch::Tensor key_cache_ptrs_tensor =
-      torch::from_blob(key_cache_ptrs, {num_layers}, torch::kInt64)
-          .to(cache_device);
-  torch::Tensor value_cache_ptrs_tensor =
-      torch::from_blob(value_cache_ptrs, {num_layers}, torch::kInt64)
-          .to(cache_device);
-
-  // Launch the kernel.
-  const int numel_per_block = key_caches[0][0].numel();
-  dim3 grid(num_layers, num_pairs);
-  dim3 block(std::min(1024, numel_per_block));
-  const at::cuda::OptionalCUDAGuard device_guard(cache_device);
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
-      key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
-        vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
-            key_cache_ptrs_tensor.data_ptr<int64_t>(),
-            value_cache_ptrs_tensor.data_ptr<int64_t>(),
-            block_mapping.data_ptr<int64_t>(), numel_per_block);
-      }));
-}
-
-// copy blocks kernel for MLA (assumes a joint KV-cache)
-void copy_blocks_mla(std::vector<torch::Tensor> const& kv_caches,
-                     const torch::Tensor& block_mapping) {
-  int num_layers = kv_caches.size();
-  if (num_layers == 0) {
-    return;
-  }
-  torch::Device cache_device = kv_caches[0].device();
-  TORCH_CHECK(cache_device.is_cuda(), "kv_cache must be on CUDA");
-
-  std::vector<int64_t> cache_ptrs(num_layers);
-  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
-    cache_ptrs[layer_idx] =
-        reinterpret_cast<int64_t>(kv_caches[layer_idx].data_ptr());
-  }
-  torch::Tensor cache_ptrs_tensor =
-      torch::from_blob(cache_ptrs.data(), {num_layers}, torch::kInt64)
-          .to(cache_device);
-
-  int num_pairs = block_mapping.size(0);
-  // We use the stride instead of numel in case the cache is padded for memory
-  // alignment reasons, we assume the blocks data (inclusive of any padding)
-  // is contiguous in memory
-  int mem_footprint_per_block = kv_caches[0].stride(0);
-  dim3 grid(num_layers, num_pairs);
-  dim3 block(std::min(1024, mem_footprint_per_block));
-  const at::cuda::OptionalCUDAGuard device_guard(cache_device);
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
-      kv_caches[0].scalar_type(), "copy_blocks_mla_kernel", ([&] {
-        vllm::copy_blocks_mla_kernel<scalar_t><<<grid, block, 0, stream>>>(
-            cache_ptrs_tensor.data_ptr<int64_t>(),
-            block_mapping.data_ptr<int64_t>(), mem_footprint_per_block);
-      }));
-}
-
 namespace vllm {
 
 // Used to copy/convert one element
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 461f74ca184fd..6f2c8e915b5cb 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -685,16 +685,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
       "swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()");
   cache_ops.impl("swap_blocks", torch::kCUDA, &swap_blocks);
 
-  // Copy the cache blocks from src to dst.
-  cache_ops.def(
-      "copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, "
-      "Tensor block_mapping) -> ()");
-  cache_ops.impl("copy_blocks", torch::kCUDA, &copy_blocks);
-
-  cache_ops.def(
-      "copy_blocks_mla(Tensor(a!)[] kv_caches, Tensor block_mapping) -> ()");
-  cache_ops.impl("copy_blocks_mla", torch::kCUDA, &copy_blocks_mla);
-
   // Reshape the key and value tensors and cache them.
cache_ops.def( "reshape_and_cache(Tensor key, Tensor value," diff --git a/tests/kernels/attention/test_cache.py b/tests/kernels/attention/test_cache.py index acf46d75d62eb..3f76033254d32 100644 --- a/tests/kernels/attention/test_cache.py +++ b/tests/kernels/attention/test_cache.py @@ -40,93 +40,6 @@ KV_CACHE_DTYPE = ["auto", "fp8"] RESHAPE_FLASH_IMPLEMENTATIONS = ["cuda", "triton"] -@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS) -@pytest.mark.parametrize("num_layers", NUM_LAYERS) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) -@torch.inference_mode() -def test_copy_blocks( - kv_cache_factory, - num_mappings: int, - num_layers: int, - num_heads: int, - head_size: int, - block_size: int, - num_blocks: int, - dtype: torch.dtype, - seed: int, - kv_cache_dtype: str, - device: str, -) -> None: - if kv_cache_dtype == "fp8" and head_size % 16: - pytest.skip() - current_platform.seed_everything(seed) - torch.set_default_device(device) - torch.cuda.set_device(device) - # Generate random block mappings where each source block is mapped to two - # destination blocks. - assert 2 * num_mappings <= num_blocks - src_blocks = random.sample(range(num_blocks), num_mappings) - remaining_blocks = list(set(range(num_blocks)) - set(src_blocks)) - dst_blocks = random.sample(remaining_blocks, 2 * num_mappings) - block_mapping: list[tuple[int, int]] = [] - for i in range(num_mappings): - src = src_blocks[i] - dst1 = dst_blocks[2 * i] - dst2 = dst_blocks[2 * i + 1] - block_mapping.append((src, dst1)) - block_mapping.append((src, dst2)) - - # Create the KV caches. - key_caches, value_caches = kv_cache_factory( - num_blocks, - block_size, - num_layers, - num_heads, - head_size, - kv_cache_dtype, - dtype, - seed, - device, - ) - - # Clone the KV caches. - cloned_key_caches = [key_cache.clone() for key_cache in key_caches] - cloned_value_caches = [value_cache.clone() for value_cache in value_caches] - - # Call the copy blocks kernel. - block_mapping_tensor = torch.tensor( - block_mapping, dtype=torch.int64, device=device - ).view(-1, 2) - - opcheck( - torch.ops._C_cache_ops.copy_blocks, - (key_caches, value_caches, block_mapping_tensor), - test_utils=DEFAULT_OPCHECK_TEST_UTILS, - cond=(head_size == HEAD_SIZES[0]), - ) - ops.copy_blocks(key_caches, value_caches, block_mapping_tensor) - - # Run the reference implementation. - for src, dst in block_mapping: - for cloned_key_cache in cloned_key_caches: - cloned_key_cache[dst].copy_(cloned_key_cache[src]) - for cloned_value_cache in cloned_value_caches: - cloned_value_cache[dst].copy_(cloned_value_cache[src]) - - # Compare the results. 
- for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches): - torch.testing.assert_close(key_cache, cloned_key_cache) - for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches): - torch.testing.assert_close(value_cache, cloned_value_cache) - - @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -763,73 +676,6 @@ def test_concat_and_cache_ds_mla( torch.testing.assert_close(kv_rope, ref_rope, atol=0.001, rtol=0.1) -@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS) -@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS) -@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA) -@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA) -@pytest.mark.parametrize("num_layers", NUM_LAYERS) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) -@torch.inference_mode() -def test_copy_blocks_mla( - kv_lora_rank: int, - qk_rope_head_dim: int, - block_size: int, - num_blocks: int, - num_layers: int, - dtype: torch.dtype, - seed: int, - device: str, - kv_cache_dtype: str, -) -> None: - current_platform.seed_everything(seed) - torch.set_default_device(device) - torch.cuda.set_device(device) - - entry_size = kv_lora_rank + qk_rope_head_dim - - kv_caches = [] - for _ in range(num_layers): - kv_cache = _create_mla_cache( - num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device - ) - _fill_mla_cache(kv_cache, kv_cache_dtype=kv_cache_dtype) - kv_caches.append(kv_cache) - - ref_caches = [kv_cache.clone() for kv_cache in kv_caches] - - num_mappings = min(2, num_blocks // 2) - src_blocks = random.sample(range(num_blocks), num_mappings) - remaining = list(set(range(num_blocks)) - set(src_blocks)) - dst_blocks = random.sample(remaining, 2 * num_mappings) - block_mapping = [] - for i in range(num_mappings): - src = src_blocks[i] - dst1 = dst_blocks[2 * i] - dst2 = dst_blocks[2 * i + 1] - block_mapping.append((src, dst1)) - block_mapping.append((src, dst2)) - block_mapping_tensor = torch.tensor( - block_mapping, dtype=torch.int64, device=device - ).view(-1, 2) - - for src, dst in block_mapping: - for ref_cache in ref_caches: - ref_cache[dst].copy_(ref_cache[src]) - - opcheck( - torch.ops._C_cache_ops.copy_blocks_mla, - (kv_caches, block_mapping_tensor), - test_utils=DEFAULT_OPCHECK_TEST_UTILS, - ) - ops.copy_blocks_mla(kv_caches, block_mapping_tensor) - - for kv_cache, ref_cache in zip(kv_caches, ref_caches): - torch.testing.assert_close(kv_cache, ref_cache) - - @pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS) @pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS) @pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 78bd8d4e64115..c1519fc177250 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -2328,18 +2328,6 @@ def concat_and_cache_mla( ) -def copy_blocks( - key_caches: list[torch.Tensor], - value_caches: list[torch.Tensor], - block_mapping: torch.Tensor, -) -> None: - torch.ops._C_cache_ops.copy_blocks(key_caches, value_caches, block_mapping) - - -def copy_blocks_mla(kv_caches: list[torch.Tensor], block_mapping: torch.Tensor) -> None: - torch.ops._C_cache_ops.copy_blocks_mla(kv_caches, block_mapping) - - def swap_blocks( src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor ) -> None: diff --git a/vllm/_ipex_ops.py 
b/vllm/_ipex_ops.py index 95c17cb331f67..239f5376eb462 100644 --- a/vllm/_ipex_ops.py +++ b/vllm/_ipex_ops.py @@ -383,18 +383,6 @@ class ipex_ops: ) return None - @staticmethod - def copy_blocks( - key_caches: list[torch.Tensor], - value_caches: list[torch.Tensor], - block_mapping: torch.Tensor, - ) -> None: - torch.xpu.copy_blocks( # type: ignore - key_caches, - value_caches, - block_mapping, - ) - @staticmethod def swap_blocks( src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor From 538e830caab8d0e7c2557adb975dca3c5af296be Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 23 Dec 2025 18:23:43 -0800 Subject: [PATCH 5/7] [KVEvent] User request.block_hash for parent block_hash (#30544) Signed-off-by: Chen Zhang Signed-off-by: Yifan Qiao Co-authored-by: Yifan Qiao --- tests/v1/core/test_prefix_caching.py | 63 ++++++++++++++++++++++++++++ vllm/v1/core/block_pool.py | 4 +- 2 files changed, 64 insertions(+), 3 deletions(-) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 0880a17c78d40..977ec71bcbecf 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -1356,6 +1356,69 @@ def test_kv_cache_events(blocks_to_cache: int): assert len(manager.block_pool.cached_block_hash_to_block) == 0 +def test_null_parent_block_hash(): + block_size = 1 + num_cached_blocks = 2 + num_full_blocks = 4 + + pool = BlockPool( + num_gpu_blocks=8, + enable_caching=True, + hash_block_size=block_size, + enable_kv_cache_events=True, + ) + + req = make_request( + "req_null_parent", + prompt_token_ids=[10, 11, 12, 13], + block_size=block_size, + hash_fn=sha256, + ) + assert len(req.block_hashes) == num_full_blocks + + # Physical parent is `null_block` (no hash), while the logical parent hash + # still exists in `request.block_hashes[num_cached_blocks - 1]`. + assert pool.null_block.block_hash is None + new_blocks = pool.get_new_blocks(num_full_blocks - 1) + blocks = [ + new_blocks[: num_cached_blocks - 1], + pool.null_block, # physical parent + *new_blocks[num_cached_blocks - 1 :], + ] + + pool.cache_full_blocks( + request=req, + blocks=blocks, + num_cached_blocks=num_cached_blocks, + num_full_blocks=num_full_blocks, + block_size=block_size, + kv_cache_group_id=0, + ) + + events = pool.take_events() + assert len(events) == 1 + event = events[0] + assert isinstance(event, BlockStored) + + expected_parent = kv_cache_utils.maybe_convert_block_hash( + req.block_hashes[num_cached_blocks - 1] + ) + assert event.parent_block_hash == expected_parent + assert event.parent_block_hash is not None + + expected_new_hashes = [ + kv_cache_utils.maybe_convert_block_hash(h) + for h in req.block_hashes[num_cached_blocks:num_full_blocks] + ] + assert event.block_hashes == expected_new_hashes + + # Ensure we didn't accidentally assign a hash to the null block. + assert pool.null_block.block_hash is None + # Sanity check: newly cached physical blocks should have hashes assigned. 
+ assert blocks[num_cached_blocks].block_hash is not None + assert blocks[num_full_blocks - 1].block_hash is not None + + @pytest.mark.parametrize("blocks_to_cache", [2, 3, 10]) def test_kv_cache_events_with_lora(blocks_to_cache: int): """Test BlockStored events contain correct lora_id when using LoRA requests.""" diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index c779e3d34b3ed..a6f06d1b16a34 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -270,10 +270,8 @@ class BlockPool: if num_cached_blocks == 0: parent_block_hash: ExternalBlockHash | None = None else: - parent_block = blocks[num_cached_blocks - 1] - assert parent_block.block_hash is not None parent_block_hash = maybe_convert_block_hash( - get_block_hash(parent_block.block_hash) + block_hashes[num_cached_blocks - 1] ) self.kv_event_queue.append( From 8b59753cdb5eb2a672fb44e3a281cebea197355b Mon Sep 17 00:00:00 2001 From: Chao Lei Date: Wed, 24 Dec 2025 10:24:07 +0800 Subject: [PATCH 6/7] [P/D] Mooncake connector support more protocols (#30133) Signed-off-by: LCAIZJ --- .../kv_transfer/kv_connector/v1/mooncake_connector.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake_connector.py index 705960aebe2da..9a15d3fa6ed09 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake_connector.py @@ -408,7 +408,13 @@ class MooncakeConnectorWorker: self.engine = TransferEngine() self.hostname = get_ip() - ret_value = self.engine.initialize(self.hostname, "P2PHANDSHAKE", "rdma", "") + protocol = self.vllm_config.kv_transfer_config.kv_connector_extra_config.get( # type: ignore[union-attr] + "mooncake_protocol", "rdma" + ) + logger.info( + "The Mooncake Transfer Engine is using %s as its protocol.", protocol + ) + ret_value = self.engine.initialize(self.hostname, "P2PHANDSHAKE", protocol, "") if ret_value != 0: raise RuntimeError("Mooncake Transfer Engine initialization failed.") From 76e6a951925bf37c49f88ad155dc9fcec01a3faf Mon Sep 17 00:00:00 2001 From: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Date: Tue, 23 Dec 2025 21:41:09 -0500 Subject: [PATCH 7/7] [Bug] Fix `Number of dimensions of tensors must match.` for Deepseek V3.2 (#31160) Signed-off-by: yewentao256 --- vllm/model_executor/models/deepseek_v2.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 22d43a4bae18a..4899f5476f955 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -878,11 +878,14 @@ class Indexer(nn.Module): ) q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1)) - # `rotary_emb` is shape-preserving; `q_pe` is already - # [num_tokens, n_head, rope_dim]. + # Note: RoPE (NeoX) can introduce extra leading dimensions during compilation + # so we need to reshape back to token-flattened shapes + q_pe = q_pe.reshape(-1, self.n_head, self.rope_dim) + k_pe = k_pe.reshape(-1, 1, self.rope_dim) + q = torch.cat([q_pe, q_nope], dim=-1) # `k_pe` is [num_tokens, 1, rope_dim] (MQA). - k = torch.cat([k_pe.squeeze(1), k_nope], dim=-1) + k = torch.cat([k_pe.squeeze(-2), k_nope], dim=-1) # we only quant q here since k quant is fused with cache insertion q = q.view(-1, self.head_dim)