From dc48ba0c750e176f6314504cc0e8370a46ed01a8 Mon Sep 17 00:00:00 2001 From: Bram Wasti Date: Fri, 26 Sep 2025 19:59:09 -0400 Subject: [PATCH 01/20] Kernel-override Determinism [1/n] (#25603) Signed-off-by: Bram Wasti --- csrc/core/batch_invariant.hpp | 16 + csrc/layernorm_kernels.cu | 8 +- csrc/layernorm_quant_kernels.cu | 5 +- csrc/moe/topk_softmax_kernels.cu | 4 +- tests/v1/generation/test_batch_invariance.py | 290 +++++++++ vllm/model_executor/layers/batch_invariant.py | 561 ++++++++++++++++++ vllm/v1/attention/backends/flex_attention.py | 7 + vllm/v1/worker/gpu_model_runner.py | 3 + 8 files changed, 890 insertions(+), 4 deletions(-) create mode 100644 csrc/core/batch_invariant.hpp create mode 100644 tests/v1/generation/test_batch_invariance.py create mode 100644 vllm/model_executor/layers/batch_invariant.py diff --git a/csrc/core/batch_invariant.hpp b/csrc/core/batch_invariant.hpp new file mode 100644 index 0000000000000..19e422e4b80cd --- /dev/null +++ b/csrc/core/batch_invariant.hpp @@ -0,0 +1,16 @@ +#pragma once +#include +#include +#include + +namespace vllm { + +// vllm_kernel_override_batch_invariant(); returns true +// if env VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT=1 +inline bool vllm_kernel_override_batch_invariant() { + std::string env_key = "VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"; + const char* val = std::getenv(env_key.c_str()); + return (val && std::atoi(val) != 0) ? 
1 : 0; +} + +} // namespace vllm diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index 93c73d58390e1..6c3685f6f7cdc 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -1,6 +1,7 @@ #include "type_convert.cuh" #include "dispatch_utils.h" #include "cub_helpers.h" +#include "core/batch_invariant.hpp" #include #include @@ -413,7 +414,9 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size] wt_ptr % req_alignment_bytes == 0; bool offsets_are_multiple_of_vector_width = hidden_size % vector_width == 0 && input_stride % vector_width == 0; - if (ptrs_are_aligned && offsets_are_multiple_of_vector_width) { + bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant(); + if (ptrs_are_aligned && offsets_are_multiple_of_vector_width && + !batch_invariant_launch) { LAUNCH_FUSED_ADD_RMS_NORM(8); } else { LAUNCH_FUSED_ADD_RMS_NORM(0); @@ -459,7 +462,8 @@ void poly_norm(torch::Tensor& out, // [..., hidden_size] auto inp_ptr = reinterpret_cast(input.data_ptr()); auto out_ptr = reinterpret_cast(out.data_ptr()); bool ptrs_are_aligned = inp_ptr % 16 == 0 && out_ptr % 16 == 0; - if (ptrs_are_aligned && hidden_size % 8 == 0) { + bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant(); + if (ptrs_are_aligned && hidden_size % 8 == 0 && !batch_invariant_launch) { LAUNCH_FUSED_POLY_NORM(8); } else { LAUNCH_FUSED_POLY_NORM(0); diff --git a/csrc/layernorm_quant_kernels.cu b/csrc/layernorm_quant_kernels.cu index be134089bd6d4..58c3d9c0981a0 100644 --- a/csrc/layernorm_quant_kernels.cu +++ b/csrc/layernorm_quant_kernels.cu @@ -9,6 +9,7 @@ #include "quantization/fp8/common.cuh" #include "dispatch_utils.h" #include "cub_helpers.h" +#include "core/batch_invariant.hpp" #include #include @@ -240,7 +241,9 @@ void fused_add_rms_norm_static_fp8_quant( auto wt_ptr = reinterpret_cast(weight.data_ptr()); bool ptrs_are_aligned = inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0; - if (ptrs_are_aligned && 
hidden_size % 8 == 0 && input_stride % 8 == 0) { + bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant(); + if (ptrs_are_aligned && hidden_size % 8 == 0 && input_stride % 8 == 0 && + !batch_invariant_launch) { LAUNCH_FUSED_ADD_RMS_NORM(8); } else { LAUNCH_FUSED_ADD_RMS_NORM(0); diff --git a/csrc/moe/topk_softmax_kernels.cu b/csrc/moe/topk_softmax_kernels.cu index 53573ada86ba9..eca021f1c1863 100644 --- a/csrc/moe/topk_softmax_kernels.cu +++ b/csrc/moe/topk_softmax_kernels.cu @@ -21,6 +21,7 @@ #include #include "../cuda_compat.h" #include "../cub_helpers.h" +#include "../core/batch_invariant.hpp" #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b)) @@ -405,7 +406,8 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f using Constants = detail::TopkConstants; static constexpr int VPT = Constants::VPT; static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP; - const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP; + const bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant(); + const int num_warps = batch_invariant_launch ? 32 : (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP; const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB; dim3 block_dim(WARP_SIZE_PARAM, WARPS_PER_TB); diff --git a/tests/v1/generation/test_batch_invariance.py b/tests/v1/generation/test_batch_invariance.py new file mode 100644 index 0000000000000..b864f9a318363 --- /dev/null +++ b/tests/v1/generation/test_batch_invariance.py @@ -0,0 +1,290 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib +import os +import random +import string + +import pytest +import torch + +from vllm import LLM, SamplingParams + + +def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str: + # Lightweight random prompt generator to vary prompt lengths and content. 
+ vocab = [ + "alpha", + "bravo", + "charlie", + "delta", + "echo", + "foxtrot", + "golf", + "hotel", + "india", + "juliet", + "kilo", + "lima", + "mike", + "november", + "oscar", + "papa", + "quebec", + "romeo", + "sierra", + "tango", + "uniform", + "victor", + "whiskey", + "xray", + "yankee", + "zulu", + ] + n = random.randint(min_words, max_words) + words = random.choices(vocab, k=n) + + # Add some noise and punctuation variability + if random.random() < 0.5: + words[0] = words[0].capitalize() + if random.random() < 0.2: + words.append("".join(random.choices(string.ascii_lowercase, k=5))) + punct = random.choice([".", "?", "!", "...", ""]) + return " ".join(words) + punct + + +@pytest.mark.timeout(1000) +def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(): + """ + Ensures that the same request (the 'needle' prompt) yields identical output + whether run alone (bs=1) or mixed into a larger batch (e.g., bs=64), + using the high-level v1 LLM() API only (no manual batching). + + Strategy: + - Create two LLM engines with identical config except max_num_seqs: 1 vs N. + - Compute a baseline output for the needle prompt with the bs=1 engine. + - For many trials, generate a batch (size N) where the needle appears at a + random position among random filler prompts using the bs=N engine. + - Track how many trials match vs mismatch, and report totals at the end. + The test fails if any mismatches occur, but we still dump pass/fail + counts. + + Notes: + - Use seeded stochastic sampling with a fixed seed to test determinism. + - Outputs are intentionally longer and sampled at higher temperature/top_p + to produce a more random-sounding phrase, yet remain deterministic by + seed. + - Keep max_tokens and max_model_len bounded for speed and memory use. 
+ """ + random.seed(12345) + + # Allow overrides from environment (useful for CI tuning) + # "facebook/opt-125m" is too small, doesn't reliably test determinism + model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B") + num_trials = int(os.getenv("VLLM_NEEDLE_TRIALS", "5")) + batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "64")) + assert batch_size >= 2, "Batch size should be >= 2 to mix needle." + + # Keep GPU memory usage low to avoid startup allocation failures. + gpu_mem_util = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.3")) + max_model_len = int(os.getenv("VLLM_MAX_MODEL_LEN", "4096")) + swap_space_gb = int(os.getenv("VLLM_SWAP_SPACE_GB", "4")) + + # Sampling parameters: longer outputs with a more random-sounding + # continuation,but still deterministic due to fixed seed. + temperature = float(os.getenv("VLLM_NEEDLE_TEMPERATURE", "0.0")) + top_p = float(os.getenv("VLLM_NEEDLE_TOP_P", "0.95")) + max_tokens = int(os.getenv("VLLM_NEEDLE_MAX_TOKENS", "128")) + + sampling = SamplingParams( + temperature=temperature, + top_p=top_p, + max_tokens=max_tokens, + seed=20240919, + ) + + needle_prompt = ("There once was a ") + + llm_bs1 = None + llm_bsN = None + try: + # Engine with bs=1 behavior + llm_bs1 = LLM_with_max_seqs( + model=model, + max_num_seqs=1, + gpu_memory_utilization=gpu_mem_util, + max_model_len=max_model_len, + swap_space=swap_space_gb, + ) + + # Baseline generation for the needle prompt alone. 
+ baseline_out = llm_bs1.generate([needle_prompt], sampling) + assert len(baseline_out) == 1 + assert len(baseline_out[0].outputs) >= 1 + baseline_text = baseline_out[0].outputs[0].text + + # Engine with larger batch limit (e.g., 64) + llm_bsN = LLM_with_max_seqs( + model=model, + max_num_seqs=batch_size, + gpu_memory_utilization=gpu_mem_util, + max_model_len=max_model_len, + swap_space=swap_space_gb, + ) + + mismatches = 0 + + for trial in range(num_trials): + # Create a batch of size `batch_size` and insert the needle at + # a random index + prompts: list[str] = [] + needle_pos = random.randint(0, batch_size - 1) + for i in range(batch_size): + if i == needle_pos: + prompts.append(needle_prompt) + else: + prompts.append(_random_prompt()) + + # Generate with the larger-batch engine + outputs = llm_bsN.generate(prompts, sampling) + # Find the needle output by position + needle_output = outputs[needle_pos] + assert needle_output.prompt == needle_prompt + assert len(needle_output.outputs) >= 1 + text = needle_output.outputs[0].text + + if text != baseline_text: + mismatches += 1 + + passes = num_trials - mismatches + # Dump how many passed vs failed + print(f"[determinism] total={num_trials}, passed={passes}, " + f"failed={mismatches}, batch_size={batch_size}") + + if mismatches > 0: + pytest.fail( + f"Nondeterministic outputs detected: {mismatches} failed out " + f"of {num_trials} trials (batch_size={batch_size}).") + + finally: + # Ensure engines are shutdown to free GPU/VRAM across test sessions + if llm_bs1 is not None: + with contextlib.suppress(Exception): + llm_bs1.shutdown() + if llm_bsN is not None: + with contextlib.suppress(Exception): + llm_bsN.shutdown() + + +def _extract_step_logprobs(request_output): + if getattr(request_output, "outputs", None): + inner = request_output.outputs[0] + if hasattr(inner, "logprobs") and inner.logprobs is not None: + t = torch.tensor( + [ + inner.logprobs[i][tid].logprob + for i, tid in enumerate(inner.token_ids) + ], + 
dtype=torch.float32, + ) + return t + + return None + + +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="Requires CUDA to match production inference path.", +) +def test_logprobs_bitwise_batch_invariance_bs1_vs_bs2(): + + #model_name = os.getenv("VLLM_TEST_MODEL", "facebook/opt-125m") + model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B") + tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1")) + + # Force float32 to avoid precision-induced differences. + llm = LLM( + model=model_name, + tensor_parallel_size=tp_size, + enforce_eager=True, # helps reduce nondeterminism from some backends + ) + + prompts = [ + "The capital of France is", + "The capital of Germany is", + ] + + sp = SamplingParams( + temperature=0.0, + top_p=1.0, + max_tokens=8, + # Seed shouldn't matter at temperature=0, but keeping it stable anyway. + seed=1234, + logprobs=5, + ) + + # BS=1: run prompts individually and collect logprobs per step. + bs1_logprobs_per_prompt = [] + for p in prompts: + outs = llm.generate([p], sp, use_tqdm=False) + assert len(outs) == 1 + step_logprobs = _extract_step_logprobs(outs[0]) + if step_logprobs is None: + pytest.skip("Logits are not available on RequestOutput; " + "enable logprobs return to run this test.") + bs1_logprobs_per_prompt.append(step_logprobs) + + # BS=2: run prompts in a batch and collect logprobs per step for each + # prompt. + outs_batched = llm.generate(prompts, sp, use_tqdm=False) + assert len(outs_batched) == len(prompts) + bs2_logprobs_per_prompt = [] + for o in outs_batched: + step_logprobs = _extract_step_logprobs(o) + if step_logprobs is None: + pytest.skip("Logits are not available on RequestOutput; " + "enable logprobs return to run this test.") + bs2_logprobs_per_prompt.append(step_logprobs) + + # Compare step-by-step logprobs for each prompt between BS=1 and BS=2 runs. 
+ for i, (logprobs_bs1, logprobs_bs2) in enumerate( + zip(bs1_logprobs_per_prompt, bs2_logprobs_per_prompt)): + assert len(logprobs_bs1) == len(logprobs_bs2), ( + f"Different number of generation steps for prompt index {i}: " + f"{len(logprobs_bs1)} (BS=1) vs {len(logprobs_bs2)} (BS=2)") + for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bs2)): + assert a.shape == b.shape, ( + f"Logits shape mismatch at prompt {i}, step {t}: " + f"{a.shape} vs {b.shape}") + # Bitwise exact equality. + assert torch.equal( + a, b), (f"Bitwise logprobs mismatch at prompt {i}, step {t} " + f"(dtype={a.dtype}, shape={a.shape}).") + + +def LLM_with_max_seqs( + model: str, + max_num_seqs: int, + gpu_memory_utilization: float, + max_model_len: int, + swap_space: int, +) -> LLM: + """ + Helper to construct an LLM with a specific max_num_seqs (batch-size limit) + using the high-level v1 LLM API, while constraining memory usage. + """ + return LLM( + model=model, + max_num_seqs=max_num_seqs, + # Constrain GPU memory pool so test can run even on busy GPUs. + gpu_memory_utilization=gpu_memory_utilization, + # Keep KV cache footprint small while allowing longer outputs. + max_model_len=max_model_len, + # Allow some CPU offload if needed. + swap_space=swap_space, + # Keep things lean and CI-friendly. + dtype="float16", + # Single-GPU by default; override externally if desired. 
+ tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")), + trust_remote_code=os.getenv("VLLM_TRUST_REMOTE_CODE", "0") == "1", + ) diff --git a/vllm/model_executor/layers/batch_invariant.py b/vllm/model_executor/layers/batch_invariant.py new file mode 100644 index 0000000000000..ae2c842af698b --- /dev/null +++ b/vllm/model_executor/layers/batch_invariant.py @@ -0,0 +1,561 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib +import os +from collections import namedtuple +from collections.abc import Callable +from typing import Any, Union + +import torch +import triton +import triton.language as tl + + +def _matmul_launch_metadata(grid: Callable[..., Any], kernel: Any, + args: dict[str, Any]) -> dict[str, Any]: + ret = {} + m, n, k = args["M"], args["N"], args["K"] + ret["name"] = f"{kernel.name} [M={m}, N={n}, K={k}]" + if "tiles_per_update" in args: + ret["name"] = (f"{kernel.name} [M={m}, N={n}, K={k}, " + f"tiles_per_update={args['tiles_per_update']:02}]") + if "c_ptr" in args: + bytes_per_elem = args["c_ptr"].element_size() + else: + bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2 + ret[f"flops{bytes_per_elem * 8}"] = 2.0 * m * n * k + ret["bytes"] = bytes_per_elem * (m * k + n * k + m * n) + return ret + + +@triton.jit +def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS): + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + return pid_m, pid_n + + +@triton.jit(launch_metadata=_matmul_launch_metadata) +def matmul_kernel_persistent( + a_ptr, + b_ptr, + c_ptr, # + bias_ptr, + M, + N, + K, # + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_SIZE_M: tl.constexpr, # + BLOCK_SIZE_N: tl.constexpr, # + BLOCK_SIZE_K: tl.constexpr, # + 
GROUP_SIZE_M: tl.constexpr, # + NUM_SMS: tl.constexpr, # + A_LARGE: tl.constexpr, + B_LARGE: tl.constexpr, + C_LARGE: tl.constexpr, + HAS_BIAS: tl.constexpr, +): + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles = num_pid_m * num_pid_n + + tile_id_c = start_pid - NUM_SMS + + offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True): + pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m, + GROUP_SIZE_M, NUM_SMS) + start_m = pid_m * BLOCK_SIZE_M + start_n = pid_n * BLOCK_SIZE_N + offs_am = start_m + tl.arange(0, BLOCK_SIZE_M) + offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N) + if A_LARGE: + offs_am = offs_am.to(tl.int64) + if B_LARGE: + offs_bn = offs_bn.to(tl.int64) + offs_am = tl.where(offs_am < M, offs_am, 0) + offs_bn = tl.where(offs_bn < N, offs_bn, 0) + offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), + BLOCK_SIZE_M) + offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), + BLOCK_SIZE_N) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for ki in range(k_tiles): + if A_LARGE or B_LARGE: + offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K).to( + tl.int64) + else: + offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + + offs_bn[None, :] * stride_bn) + + a = tl.load(a_ptrs, + mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, + other=0.0) + b = tl.load(b_ptrs, + mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, + other=0.0) + accumulator = tl.dot(a, b, accumulator) + + tile_id_c += NUM_SMS + pid_m, pid_n = _compute_pid(tile_id_c, num_pid_in_group, num_pid_m, + GROUP_SIZE_M, NUM_SMS) + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, 
BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + if C_LARGE: + offs_cm = offs_cm.to(tl.int64) + offs_cn = offs_cn.to(tl.int64) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[ + None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + if HAS_BIAS: + bias_ptrs = bias_ptr + offs_cn + bias = tl.load(bias_ptrs, mask=offs_cn < N, + other=0.0).to(tl.float32) + accumulator += bias + if c_ptr.dtype.element_ty == tl.float8e4nv: + c = accumulator.to(tl.float8e4nv) + else: + c = accumulator.to(tl.float16) + tl.store(c_ptrs, c, mask=c_mask) + + +def matmul_persistent(a: torch.Tensor, + b: torch.Tensor, + bias: Union[torch.Tensor, None] = None): + # Check constraints. + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.dtype == b.dtype, "Incompatible dtypes" + assert bias is None or bias.dim() == 1, ( + "Currently assuming bias is 1D, let Horace know if you run into this") + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + M, K = a.shape + K, N = b.shape + dtype = a.dtype + # Allocates output. + c = torch.empty((M, N), device=a.device, dtype=dtype) + + # 1D launch kernel where each block gets its own program. 
+ def grid(META): + return (min( + NUM_SMS, + triton.cdiv(M, META["BLOCK_SIZE_M"]) * + triton.cdiv(N, META["BLOCK_SIZE_N"])), ) + + configs = { + torch.bfloat16: { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_stages": 3, + "num_warps": 8, + }, + torch.float16: { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_stages": 3, + "num_warps": 8, + }, + torch.float32: { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "num_stages": 3, + "num_warps": 8, + }, + } + # print(a.device, b.device, c.device) + matmul_kernel_persistent[grid]( + a, + b, + c, # + bias, + M, + N, + K, # + a.stride(0), + a.stride(1), # + b.stride(0), + b.stride(1), # + c.stride(0), + c.stride(1), # + NUM_SMS=NUM_SMS, # + A_LARGE=a.numel() > 2**31, + B_LARGE=b.numel() > 2**31, + C_LARGE=c.numel() > 2**31, + HAS_BIAS=bias is not None, + **configs[dtype], + ) + return c + + +@triton.jit +def _log_softmax_kernel( + input_ptr, + output_ptr, + input_row_stride, + output_row_stride, + n_cols, + BLOCK_SIZE: tl.constexpr, +): + """ + Compute log_softmax along the last dimension of a 2D tensor. + Each block handles one row of the input tensor. 
+ """ + # Get the row index for this block + row_idx = tl.program_id(0).to(tl.int64) + + # Compute base pointers for input and output rows + row_start_ptr = input_ptr + row_idx * input_row_stride + output_row_start_ptr = output_ptr + row_idx * output_row_stride + + # Step 1: Find maximum value in the row for numerical stability + max_val = -float("inf") + for col_offset in range(0, n_cols, BLOCK_SIZE): + col_idx = col_offset + tl.arange(0, BLOCK_SIZE) + mask = col_idx < n_cols + + # Load values + vals = tl.load(row_start_ptr + col_idx, mask=mask, other=-float("inf")) + + # Update maximum + max_val = tl.max(tl.maximum(vals, max_val)) + + # Step 2: Compute sum of exp(x - max_val) + sum_exp = 0.0 + for col_offset in range(0, n_cols, BLOCK_SIZE): + col_idx = col_offset + tl.arange(0, BLOCK_SIZE) + mask = col_idx < n_cols + + # Load values + vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0) + + # Compute exp(x - max_val) and accumulate + exp_vals = tl.exp(vals - max_val) + sum_exp += tl.sum(tl.where(mask, exp_vals, 0.0)) + + # Compute log(sum_exp) + log_sum_exp = tl.log(sum_exp) + + # Step 3: Compute final log_softmax values: x - max_val - log_sum_exp + for col_offset in range(0, n_cols, BLOCK_SIZE): + col_idx = col_offset + tl.arange(0, BLOCK_SIZE) + mask = col_idx < n_cols + + # Load values + vals = tl.load(row_start_ptr + col_idx, mask=mask) + + # Compute log_softmax + output = vals - max_val - log_sum_exp + + # Store results + tl.store(output_row_start_ptr + col_idx, output, mask=mask) + + +def log_softmax(input: torch.Tensor, dim: int = -1) -> torch.Tensor: + """ + Compute log_softmax using Triton kernel. 
+ + Args: + input: Input tensor + dim: Dimension along which to compute log_softmax + (only -1 or last dim supported) + >> Stashed changes + Returns: + Tensor with log_softmax applied along the specified dimension + """ + if dim != -1 and dim != input.ndim - 1: + raise ValueError("This implementation only supports log_softmax along " + "the last dimension") + + # Flatten all dimensions except the last one + original_shape = input.shape + input_2d = input.reshape(-1, input.shape[-1]) + input_2d = input_2d.contiguous() + + n_rows, n_cols = input_2d.shape + + # Allocate output tensor + output = torch.empty_like(input_2d) + + # Choose block size based on the number of columns + BLOCK_SIZE = 1024 + + # Launch kernel with one block per row + grid = (n_rows, ) + _log_softmax_kernel[grid]( + input_2d, + output, + input_2d.stride(0), + output.stride(0), + n_cols, + BLOCK_SIZE=BLOCK_SIZE, + ) + # Reshape output back to original shape + return output.reshape(original_shape) + + +@triton.jit +def mean_kernel( + input_ptr, + output_ptr, + input_stride0, + input_stride1, + input_stride2, + output_stride0, + output_stride1, + M, # size before reduction dim + N, # size of reduction dim + K, # size after reduction dim + BLOCK_SIZE: tl.constexpr, +): + """ + Kernel for computing mean along a single dimension. + Input is viewed as (M, N, K) where N is the dimension being reduced. 
+ """ + # Program ID gives us which output element we're computing + pid = tl.program_id(0) + + # Compute output indices + m_idx = pid // K + k_idx = pid % K + + # Bounds check + if m_idx >= M or k_idx >= K: + return + + # Accumulate sum across reduction dimension + acc = 0.0 + for n_start in range(0, N, BLOCK_SIZE): + n_offsets = n_start + tl.arange(0, BLOCK_SIZE) + mask = n_offsets < N + + # Calculate input indices + input_idx = m_idx * input_stride0 + n_offsets * input_stride1 \ + + k_idx * input_stride2 + + # Load and accumulate + vals = tl.load(input_ptr + input_idx, mask=mask, other=0.0) + acc += tl.sum(vals) + + # Compute mean and store + mean_val = acc / N + output_idx = m_idx * output_stride0 + k_idx * output_stride1 + tl.store(output_ptr + output_idx, mean_val) + + +def mean_dim(input: torch.Tensor, + dim: int, + keepdim: bool = False, + dtype: Union[torch.dtype, None] = None) -> torch.Tensor: + """ + Triton implementation of torch.mean with single dimension reduction. + + Args: + input: Input tensor + dim: Single dimension along which to compute mean + keepdim: Whether to keep the reduced dimension + dtype: Output dtype. 
If None, uses input dtype + (or float32 for integer inputs) + + Returns: + Tensor with mean values along specified dimension + """ + # Validate inputs + assert input.is_cuda, "Input must be a CUDA tensor" + assert -input.ndim <= dim < input.ndim, ( + f"Invalid dimension {dim} for tensor with {input.ndim} dimensions") + + # Handle negative dim + if dim < 0: + dim = dim + input.ndim + + # Handle dtype + if dtype is None: + if input.dtype in [torch.int8, torch.int16, torch.int32, torch.int64]: + dtype = torch.float32 + else: + dtype = input.dtype + + # Convert input to appropriate dtype if needed + if input.dtype != dtype: + input = input.to(dtype) + + # Get input shape and strides + shape = list(input.shape) + + # Calculate dimensions for kernel + M = 1 + for i in range(dim): + M *= shape[i] + + N = shape[dim] + + K = 1 + for i in range(dim + 1, len(shape)): + K *= shape[i] + + # Reshape input to 3D view (M, N, K) + input_3d = input.reshape(M, N, K) + + # Create output shape + if keepdim: + output_shape = shape.copy() + output_shape[dim] = 1 + else: + output_shape = shape[:dim] + shape[dim + 1:] + + # Create output tensor + output = torch.empty(output_shape, dtype=dtype, device=input.device) + + # Reshape output for kernel + if keepdim: + output_2d = output.reshape(M, 1, K).squeeze(1) + else: + output_2d = output.reshape(M, K) + + # Launch kernel + grid = (M * K, ) + BLOCK_SIZE = 1024 + + mean_kernel[grid]( + input_3d, + output_2d, + input_3d.stride(0), + input_3d.stride(1), + input_3d.stride(2), + output_2d.stride(0), + output_2d.stride(1) if output_2d.ndim > 1 else 0, + M, + N, + K, + BLOCK_SIZE, + ) + + return output + + +def mm_batch_invariant(a, b): + return matmul_persistent(a, b) + + +def addmm_batch_invariant(bias, a, b): + return matmul_persistent(a, b, bias=bias) + + +def _log_softmax_batch_invariant(input, dim, _half_to_float): + assert not _half_to_float, "not implemented" + return log_softmax(input, dim=dim) + + +def mean_batch_invariant(input, + dim, + 
keepdim=False, + dtype: Union[torch.dtype, None] = None): + assert dtype is None or dtype == torch.float32, \ + f"unsupported dtype: {dtype}" + + result = input.to(torch.float32) + + # Sort dimensions to reduce from largest to smallest to handle shifting dims + # during iterative reduction. + sorted_dims = sorted([d % input.ndim for d in dim], reverse=True) + + # Iteratively apply a deterministic mean. + for d in sorted_dims: + result = mean_dim(result, dim=d, keepdim=True) + + if not keepdim: + # Squeeze the reduced dimensions. + for d in sorted_dims: + result = result.squeeze(d) + + return result + + +_batch_invariant_MODE = False +_batch_invariant_LIB = None + + +def is_batch_invariant_mode_enabled(): + return _batch_invariant_MODE + + +def enable_batch_invariant_mode(): + global _batch_invariant_MODE, _batch_invariant_LIB + if _batch_invariant_MODE: + return + + _batch_invariant_MODE = True + _batch_invariant_LIB = torch.library.Library("aten", "IMPL") + _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA") + _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA") + _batch_invariant_LIB.impl("aten::_log_softmax", + _log_softmax_batch_invariant, "CUDA") + _batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA") + + +def disable_batch_invariant_mode(): + global _batch_invariant_MODE, _batch_invariant_LIB + if _batch_invariant_LIB is not None: + _batch_invariant_LIB._destroy() + _batch_invariant_MODE = False + _batch_invariant_LIB = None + + +@contextlib.contextmanager +def set_batch_invariant_mode(enabled: bool = True): + global _batch_invariant_MODE, _batch_invariant_LIB + old_data = (_batch_invariant_MODE, _batch_invariant_LIB) + if enabled: + enable_batch_invariant_mode() + else: + disable_batch_invariant_mode() + yield + if _batch_invariant_LIB is not None: + _batch_invariant_LIB._destroy() + _batch_invariant_MODE, _batch_invariant_LIB = old_data + + +AttentionBlockSize = namedtuple("AttentionBlockSize", 
["block_m", "block_n"]) + + +def get_batch_invariant_attention_block_size() -> AttentionBlockSize: + return AttentionBlockSize(block_m=16, block_n=16) + + +def vllm_kernel_override_batch_invariant(): + env_key = "VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT" + is_overridden = False + val = os.getenv(env_key, "0") + try: + is_overridden = int(val) != 0 + except ValueError: + is_overridden = False + return is_overridden + + +def init_batch_invariance(): + # this will hit all the csrc overrides as well + if vllm_kernel_override_batch_invariant(): + os.environ["VLLM_ATTENTION_BACKEND"] = "FLEX_ATTENTION" + enable_batch_invariant_mode() diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index c3358bfa74e91..807b8d987a2d9 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -18,6 +18,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, is_quantized_kv_cache) from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.model_executor.layers.batch_invariant import ( + vllm_kernel_override_batch_invariant) from vllm.utils import cdiv, is_torch_equal_or_newer from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, CommonAttentionMetadata) @@ -839,6 +841,11 @@ def get_kernel_options(query, block_m, block_n, kernel_options: dict[str, Union[int, bool]] = { "FORCE_USE_FLEX_ATTENTION": True, } + if vllm_kernel_override_batch_invariant(): + kernel_options["BLOCK_M"] = 16 + kernel_options["BLOCK_N"] = 16 + kernel_options["IS_DIVISIBLE"] = False + return kernel_options if use_direct_build: kernel_options["BLOCK_M"] = block_m kernel_options["BLOCK_N"] = block_n diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4fd4f9128c6eb..f87a327d02a50 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -192,6 +192,9 @@ class 
GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): from vllm.model_executor.models.utils import set_cpu_offload_max_bytes set_cpu_offload_max_bytes( int(self.cache_config.cpu_offload_gb * 1024**3)) + from vllm.model_executor.layers.batch_invariant import ( + init_batch_invariance) + init_batch_invariance() model_config = self.model_config cache_config = self.cache_config From 4e33a7ea85cb702090c07fb7a8ebdbf44c472f5c Mon Sep 17 00:00:00 2001 From: Naman Lalit Date: Fri, 26 Sep 2025 17:07:36 -0700 Subject: [PATCH 02/20] [Bugfix] Optimize CpuGpuBuffer initialization (#25447) Signed-off-by: Naman Lalit --- vllm/v1/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index ec4417290f611..ee0c1168f3cd0 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -117,7 +117,7 @@ class CpuGpuBuffer: dtype=dtype, device="cpu", pin_memory=pin_memory) - self.gpu = self.cpu.to(device) + self.gpu = torch.zeros_like(self.cpu, device=device) self.np: np.ndarray # To keep type hints simple (avoiding generics and subclasses), we # only conditionally create the numpy array attribute. 
This can cause From 6f5c0931c1f618b2ca8668d114fec2d26cecfd8c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonas=20M=2E=20K=C3=BCbler?= <44084297+jmkuebler@users.noreply.github.com> Date: Sat, 27 Sep 2025 02:10:21 +0200 Subject: [PATCH 03/20] [Spec decode] automatically disable mm for text-only draft models (#25667) Signed-off-by: Jonas Kuebler --- tests/v1/e2e/test_spec_decode.py | 126 ++++++++++++++++--------------- vllm/v1/spec_decode/eagle.py | 14 ++++ 2 files changed, 78 insertions(+), 62 deletions(-) diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index c4efd7548b81b..ea8d94722859b 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -8,7 +8,7 @@ from typing import Any, Union import pytest import torch -from tests.utils import get_attn_backend_list_based_on_platform +from tests.utils import get_attn_backend_list_based_on_platform, large_gpu_mark from vllm import LLM, SamplingParams from vllm.assets.base import VLLM_S3_BUCKET_URL from vllm.assets.image import VLM_IMAGES_DIR @@ -88,69 +88,66 @@ def test_ngram_correctness( Compare the outputs of an original LLM and a speculative LLM should be the same when using ngram speculative decoding. 
''' - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") - test_prompts = get_test_prompts(mm_enabled=False) + test_prompts = get_test_prompts(mm_enabled=False) - ref_llm = LLM(model=model_name, max_model_len=1024) - ref_outputs = ref_llm.chat(test_prompts, sampling_config) - del ref_llm - torch.cuda.empty_cache() - cleanup_dist_env_and_memory() + ref_llm = LLM(model=model_name, max_model_len=1024) + ref_outputs = ref_llm.chat(test_prompts, sampling_config) + del ref_llm + torch.cuda.empty_cache() + cleanup_dist_env_and_memory() - spec_llm = LLM( - model=model_name, - speculative_config={ - "method": "ngram", - "prompt_lookup_max": 5, - "prompt_lookup_min": 3, - "num_speculative_tokens": 3, - }, - max_model_len=1024, - ) - spec_outputs = spec_llm.chat(test_prompts, sampling_config) - matches = 0 - misses = 0 - for ref_output, spec_output in zip(ref_outputs, spec_outputs): - if ref_output.outputs[0].text == spec_output.outputs[0].text: - matches += 1 - else: - misses += 1 - print(f"ref_output: {ref_output.outputs[0].text}") - print(f"spec_output: {spec_output.outputs[0].text}") + spec_llm = LLM( + model=model_name, + speculative_config={ + "method": "ngram", + "prompt_lookup_max": 5, + "prompt_lookup_min": 3, + "num_speculative_tokens": 3, + }, + max_model_len=1024, + ) + spec_outputs = spec_llm.chat(test_prompts, sampling_config) + matches = 0 + misses = 0 + for ref_output, spec_output in zip(ref_outputs, spec_outputs): + if ref_output.outputs[0].text == spec_output.outputs[0].text: + matches += 1 + else: + misses += 1 + print(f"ref_output: {ref_output.outputs[0].text}") + print(f"spec_output: {spec_output.outputs[0].text}") - # Heuristic: expect at least 66% of the prompts to match exactly - # Upon failure, inspect the outputs to check for inaccuracy. 
- assert matches >= int(0.66 * len(ref_outputs)) - del spec_llm - torch.cuda.empty_cache() - cleanup_dist_env_and_memory() + # Heuristic: expect at least 66% of the prompts to match exactly + # Upon failure, inspect the outputs to check for inaccuracy. + assert matches >= int(0.66 * len(ref_outputs)) + del spec_llm + torch.cuda.empty_cache() + cleanup_dist_env_and_memory() -@pytest.mark.parametrize(["model_setup", "mm_enabled"], [ - (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False), - (("eagle", "meta-llama/Llama-3.1-8B-Instruct", - "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False), - (("eagle3", "meta-llama/Llama-3.1-8B-Instruct", - "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False), - pytest.param( - ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", - "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), - False, - marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), - pytest.param( - ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", - "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), - True, - marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), - (("eagle", "eagle618/deepseek-v3-random", - "eagle618/eagle-deepseek-v3-random", 1), False), -], - ids=[ - "qwen3_eagle3", "llama3_eagle", "llama3_eagle3", - "llama4_eagle", "llama4_eagle_mm", - "deepseek_eagle" - ]) +@pytest.mark.parametrize( + ["model_setup", "mm_enabled"], + [ + (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False), + (("eagle", "meta-llama/Llama-3.1-8B-Instruct", + "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False), + (("eagle3", "meta-llama/Llama-3.1-8B-Instruct", + "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False), + pytest.param(("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", + "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), + False, + marks=large_gpu_mark(min_gb=80)), # works on 4x H100 + pytest.param(("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", + "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), 
+ True, + marks=large_gpu_mark(min_gb=80)), # works on 4x H100 + (("eagle", "eagle618/deepseek-v3-random", + "eagle618/eagle-deepseek-v3-random", 1), False), + ], + ids=[ + "qwen3_eagle3", "llama3_eagle", "llama3_eagle3", "llama4_eagle", + "llama4_eagle_mm", "deepseek_eagle" + ]) @pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform()) def test_eagle_correctness( @@ -174,9 +171,14 @@ def test_eagle_correctness( model_setup: (method, model_name, eagle_model_name, tp_size) ''' with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") - m.setenv("VLLM_MLA_DISABLE", "1") - m.setenv("VLLM_ATTENTION_BACKEND", attn_backend) + if "Llama-4-Scout" in model_setup[1] and attn_backend == "FLASH_ATTN": + # Scout requires default backend selection + # because vision encoder has head_dim 88 being incompatible + # with FLASH_ATTN and needs to fall back to Flex Attn + pass + else: + m.setenv("VLLM_MLA_DISABLE", "1") + m.setenv("VLLM_ATTENTION_BACKEND", attn_backend) if (attn_backend == "TRITON_ATTN" and not current_platform.is_rocm()): pytest.skip("TRITON_ATTN does not support " diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 57da8346f497f..394df48b4153f 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -804,6 +804,20 @@ class EagleProposer: self.attn_layer_names = list(draft_attn_layer_names) + if self.is_multimodal_model: + # Even if the target model is multimodal, we can also use + # text-only draft models + try: + dummy_input_ids = torch.tensor([[1]], + device=self.input_ids.device) + self.model.get_input_embeddings(dummy_input_ids, + multimodal_embeddings=None) + except (NotImplementedError, AttributeError, TypeError): + logger.warning( + "Draft model does not support multimodal inputs, " + "falling back to text-only mode") + self.is_multimodal_model = False + if supports_multimodal(target_model): # handle multimodality self.model.config.image_token_index = ( From 
8bf8f4582208ac7af230512ff5f3ac1dc36d5222 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Fri, 26 Sep 2025 17:16:40 -0700 Subject: [PATCH 04/20] [Core] Don't count preempted tokens in prefix cache hit rate (#25787) Signed-off-by: Zhuohan Li --- vllm/v1/core/kv_cache_manager.py | 24 ++++++++---- vllm/v1/core/sched/scheduler.py | 64 ++++++++++++++++---------------- vllm/v1/metrics/stats.py | 8 +++- vllm/v1/request.py | 3 ++ 4 files changed, 59 insertions(+), 40 deletions(-) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 401327f727a4a..0af98e7ba2d89 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -27,8 +27,8 @@ class KVCacheBlocks: `blocks[i][j]` refers to the i-th kv_cache_group and the j-th block of tokens.We don't use block of tokens as the outer dimension because it assumes all - kv_cache_groups have the same number of blocks, which is true for now but - will be broken if we want to give different block_size to different + kv_cache_groups have the same number of blocks, which is true for now but + will be broken if we want to give different block_size to different kv_cache_groups in the future. 
""" @@ -184,9 +184,17 @@ class KVCacheManager: if self.log_stats: assert self.prefix_cache_stats is not None - self.prefix_cache_stats.requests += 1 - self.prefix_cache_stats.queries += request.num_tokens - self.prefix_cache_stats.hits += num_new_computed_tokens + if request.num_preemptions > 0: + # Previously preempted request + self.prefix_cache_stats.preempted_requests += 1 + self.prefix_cache_stats.preempted_queries += request.num_tokens + self.prefix_cache_stats.preempted_hits += ( + num_new_computed_tokens) + else: + # New request + self.prefix_cache_stats.requests += 1 + self.prefix_cache_stats.queries += request.num_tokens + self.prefix_cache_stats.hits += num_new_computed_tokens return KVCacheBlocks(computed_blocks), num_new_computed_tokens @@ -209,10 +217,10 @@ class KVCacheManager: already been computed locally (i.e. new_computed_blocks). num_new_computed_tokens: The number of new computed tokens just hitting the prefix caching, excluding external tokens. - new_computed_blocks: The cached blocks for the above new computed + new_computed_blocks: The cached blocks for the above new computed tokens. num_lookahead_tokens: The number of speculative tokens to allocate. - This is used by spec decode proposers with kv-cache such + This is used by spec decode proposers with kv-cache such as eagle. delay_cache_blocks: Whether to skip caching the blocks. This is used by P/D when allocating blocks used in a KV transfer @@ -365,7 +373,7 @@ class KVCacheManager: requests in the current step. Returns: - list[int]: The number of common prefix blocks for each kv cache + list[int]: The number of common prefix blocks for each kv cache group. 
""" assert request.status == RequestStatus.RUNNING diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 7fc4776b02611..10d8f6bbda5cc 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -251,46 +251,48 @@ class Scheduler(SchedulerInterface): req_index += 1 continue + # Schedule newly needed KV blocks for the request. while True: new_blocks = self.kv_cache_manager.allocate_slots( request, num_new_tokens, num_lookahead_tokens=self.num_lookahead_tokens) - if new_blocks is None: - # The request cannot be scheduled. - # Preempt the lowest-priority request. - if self.policy == SchedulingPolicy.PRIORITY: - preempted_req = max( - self.running, - key=lambda r: (r.priority, r.arrival_time), - ) - self.running.remove(preempted_req) - if preempted_req in scheduled_running_reqs: - scheduled_running_reqs.remove(preempted_req) - else: - preempted_req = self.running.pop() - self.kv_cache_manager.free(preempted_req) - self.encoder_cache_manager.free(preempted_req) - preempted_req.status = RequestStatus.PREEMPTED - preempted_req.num_computed_tokens = 0 - if self.log_stats: - preempted_req.record_event( - EngineCoreEventType.PREEMPTED, scheduled_timestamp) - - self.waiting.prepend_request(preempted_req) - preempted_reqs.append(preempted_req) - if preempted_req == request: - # No more request to preempt. - can_schedule = False - break - else: + if new_blocks is not None: # The request can be scheduled. - can_schedule = True break - if not can_schedule: + + # The request cannot be scheduled. + # Preempt the lowest-priority request. 
+ if self.policy == SchedulingPolicy.PRIORITY: + preempted_req = max( + self.running, + key=lambda r: (r.priority, r.arrival_time), + ) + self.running.remove(preempted_req) + if preempted_req in scheduled_running_reqs: + scheduled_running_reqs.remove(preempted_req) + else: + preempted_req = self.running.pop() + + self.kv_cache_manager.free(preempted_req) + self.encoder_cache_manager.free(preempted_req) + preempted_req.status = RequestStatus.PREEMPTED + preempted_req.num_computed_tokens = 0 + preempted_req.num_preemptions += 1 + if self.log_stats: + preempted_req.record_event(EngineCoreEventType.PREEMPTED, + scheduled_timestamp) + + self.waiting.prepend_request(preempted_req) + preempted_reqs.append(preempted_req) + if preempted_req == request: + # No more request to preempt. Cannot schedule this request. + break + + if new_blocks is None: + # Cannot schedule this request. break - assert new_blocks is not None # Schedule the request. scheduled_running_reqs.append(request) diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py index 296c39e8cdb5c..a0d571318ba0d 100644 --- a/vllm/v1/metrics/stats.py +++ b/vllm/v1/metrics/stats.py @@ -17,13 +17,19 @@ class PrefixCacheStats: """Stores prefix cache hit statistics.""" # Whether reset_prefix_cache was invoked. reset: bool = False - # The number of requests in this update. + # The number of new requests in this update. requests: int = 0 # The number of queries in these requests. Note that "queries" here # means the number of tokens that were queried from the cache. queries: int = 0 # The number of hits in these requests. hits: int = 0 + # The number of previously preempted requests in this update. + preempted_requests: int = 0 + # The `queries` number for preempted requests. + preempted_queries: int = 0 + # The `hits` number for preempted requests. 
+ preempted_hits: int = 0 @dataclass diff --git a/vllm/v1/request.py b/vllm/v1/request.py index ff10fa00c1cf6..dd0aea645d742 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -115,6 +115,9 @@ class Request: # indicates that the output is corrupted self.num_nans_in_logits = 0 + # The number of requests being preempted by the scheduler + self.num_preemptions = 0 + self.block_hashes: list[BlockHash] = [] self.get_hash_new_full_blocks: Optional[Callable[ [], list[BlockHash]]] = None From 3958b96bf5f771560053b752424b1e7caba04a61 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Fri, 26 Sep 2025 21:23:52 -0400 Subject: [PATCH 05/20] Add option to restrict media domains (#25783) Signed-off-by: Chenheli Hua Signed-off-by: Russell Bryant Co-authored-by: Chenheli Hua --- docs/features/multimodal_inputs.md | 4 +++ docs/usage/security.md | 6 ++++ .../entrypoints/openai/test_lora_resolvers.py | 1 + tests/entrypoints/openai/test_serving_chat.py | 1 + tests/multimodal/test_utils.py | 33 ++++++++++++++++++- vllm/config/model.py | 3 ++ vllm/config/speculative.py | 2 ++ vllm/engine/arg_utils.py | 5 +++ vllm/entrypoints/chat_utils.py | 6 ++++ vllm/entrypoints/llm.py | 4 +++ vllm/multimodal/utils.py | 16 +++++++++ 11 files changed, 80 insertions(+), 1 deletion(-) diff --git a/docs/features/multimodal_inputs.md b/docs/features/multimodal_inputs.md index 7fb0337235005..bcc48e7560462 100644 --- a/docs/features/multimodal_inputs.md +++ b/docs/features/multimodal_inputs.md @@ -6,6 +6,10 @@ This page teaches you how to pass multi-modal inputs to [multi-modal models][sup We are actively iterating on multi-modal support. See [this RFC](gh-issue:4194) for upcoming changes, and [open an issue on GitHub](https://github.com/vllm-project/vllm/issues/new/choose) if you have any feedback or feature requests. +!!! 
tip + When serving multi-modal models, consider setting `--allowed-media-domains` to restrict domain that vLLM can access to prevent it from accessing arbitrary endpoints that can potentially be vulnerable to Server-Side Request Forgery (SSRF) attacks. You can provide a list of domains for this arg. For example: `--allowed-media-domains upload.wikimedia.org github.com www.bogotobogo.com` + This restriction is especially important if you run vLLM in a containerized environment where the vLLM pods may have unrestricted access to internal networks. + ## Offline Inference To input multi-modal data, follow this schema in [vllm.inputs.PromptType][]: diff --git a/docs/usage/security.md b/docs/usage/security.md index d54e2bb37ec07..5d85e889c80cc 100644 --- a/docs/usage/security.md +++ b/docs/usage/security.md @@ -60,6 +60,12 @@ Key points from the PyTorch security guide: - Implement proper authentication and authorization for management interfaces - Follow the principle of least privilege for all system components +### 4. **Restrict Domains Access for Media URLs:** + +Restrict domains that vLLM can access for media URLs by setting +`--allowed-media-domains` to prevent Server-Side Request Forgery (SSRF) attacks. +(e.g. 
`--allowed-media-domains upload.wikimedia.org github.com www.bogotobogo.com`) + ## Security and Firewalls: Protecting Exposed vLLM Systems While vLLM is designed to allow unsafe network services to be isolated to diff --git a/tests/entrypoints/openai/test_lora_resolvers.py b/tests/entrypoints/openai/test_lora_resolvers.py index 9d5ee84a19567..0561158dcf65a 100644 --- a/tests/entrypoints/openai/test_lora_resolvers.py +++ b/tests/entrypoints/openai/test_lora_resolvers.py @@ -45,6 +45,7 @@ class MockModelConfig: logits_processor_pattern: Optional[str] = None diff_sampling_param: Optional[dict] = None allowed_local_media_path: str = "" + allowed_media_domains: Optional[list[str]] = None encoder_config = None generation_config: str = "auto" skip_tokenizer_init: bool = False diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index bfed760822cdb..07f39fe2b9bd0 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -240,6 +240,7 @@ class MockModelConfig: logits_processor_pattern = None diff_sampling_param: Optional[dict] = None allowed_local_media_path: str = "" + allowed_media_domains: Optional[list[str]] = None encoder_config = None generation_config: str = "auto" media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict) diff --git a/tests/multimodal/test_utils.py b/tests/multimodal/test_utils.py index f6a93bae2afce..d1a7882a4c376 100644 --- a/tests/multimodal/test_utils.py +++ b/tests/multimodal/test_utils.py @@ -66,7 +66,12 @@ async def test_fetch_image_http(image_url: str): @pytest.mark.parametrize("suffix", get_supported_suffixes()) async def test_fetch_image_base64(url_images: dict[str, Image.Image], raw_image_url: str, suffix: str): - connector = MediaConnector() + connector = MediaConnector( + # Domain restriction should not apply to data URLs. 
+ allowed_media_domains=[ + "www.bogotobogo.com", + "github.com", + ]) url_image = url_images[raw_image_url] try: @@ -387,3 +392,29 @@ def test_argsort_mm_positions(case): modality_idxs = argsort_mm_positions(mm_positions) assert modality_idxs == expected_modality_idxs + + +@pytest.mark.asyncio +@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS) +@pytest.mark.parametrize("num_frames", [-1, 32, 1800]) +async def test_allowed_media_domains(video_url: str, num_frames: int): + connector = MediaConnector( + media_io_kwargs={"video": { + "num_frames": num_frames, + }}, + allowed_media_domains=[ + "www.bogotobogo.com", + "github.com", + ]) + + video_sync, metadata_sync = connector.fetch_video(video_url) + video_async, metadata_async = await connector.fetch_video_async(video_url) + assert np.array_equal(video_sync, video_async) + assert metadata_sync == metadata_async + + disallowed_url = "https://upload.wikimedia.org/wikipedia/commons/4/47/PNG_transparency_demonstration_1.png" + with pytest.raises(ValueError): + _, _ = connector.fetch_video(disallowed_url) + + with pytest.raises(ValueError): + _, _ = await connector.fetch_video_async(disallowed_url) diff --git a/vllm/config/model.py b/vllm/config/model.py index da01d6d4480c5..b2b68abd2c1d3 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -137,6 +137,9 @@ class ModelConfig: """Allowing API requests to read local images or videos from directories specified by the server file system. This is a security risk. Should only be enabled in trusted environments.""" + allowed_media_domains: Optional[list[str]] = None + """If set, only media URLs that belong to this domain can be used for + multi-modal inputs. """ revision: Optional[str] = None """The specific model version to use. It can be a branch name, a tag name, or a commit id. 
If unspecified, will use the default version.""" diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index 8b80ce13f96ed..cb4f0ae2cee05 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -288,6 +288,8 @@ class SpeculativeConfig: trust_remote_code, allowed_local_media_path=self.target_model_config. allowed_local_media_path, + allowed_media_domains=self.target_model_config. + allowed_media_domains, dtype=self.target_model_config.dtype, seed=self.target_model_config.seed, revision=self.revision, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 8757f4b8b7ba7..6bb794177db84 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -297,6 +297,8 @@ class EngineArgs: tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode trust_remote_code: bool = ModelConfig.trust_remote_code allowed_local_media_path: str = ModelConfig.allowed_local_media_path + allowed_media_domains: Optional[ + list[str]] = ModelConfig.allowed_media_domains download_dir: Optional[str] = LoadConfig.download_dir safetensors_load_strategy: str = LoadConfig.safetensors_load_strategy load_format: Union[str, LoadFormats] = LoadConfig.load_format @@ -531,6 +533,8 @@ class EngineArgs: **model_kwargs["hf_config_path"]) model_group.add_argument("--allowed-local-media-path", **model_kwargs["allowed_local_media_path"]) + model_group.add_argument("--allowed-media-domains", + **model_kwargs["allowed_media_domains"]) model_group.add_argument("--revision", **model_kwargs["revision"]) model_group.add_argument("--code-revision", **model_kwargs["code_revision"]) @@ -997,6 +1001,7 @@ class EngineArgs: tokenizer_mode=self.tokenizer_mode, trust_remote_code=self.trust_remote_code, allowed_local_media_path=self.allowed_local_media_path, + allowed_media_domains=self.allowed_media_domains, dtype=self.dtype, seed=self.seed, revision=self.revision, diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 
df49119d86420..4e1ecb9ed4c51 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -632,6 +632,10 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): def allowed_local_media_path(self): return self._model_config.allowed_local_media_path + @property + def allowed_media_domains(self): + return self._model_config.allowed_media_domains + @property def mm_registry(self): return MULTIMODAL_REGISTRY @@ -832,6 +836,7 @@ class MultiModalContentParser(BaseMultiModalContentParser): self._connector = MediaConnector( media_io_kwargs=media_io_kwargs, allowed_local_media_path=tracker.allowed_local_media_path, + allowed_media_domains=tracker.allowed_media_domains, ) def parse_image( @@ -916,6 +921,7 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser): self._connector = MediaConnector( media_io_kwargs=media_io_kwargs, allowed_local_media_path=tracker.allowed_local_media_path, + allowed_media_domains=tracker.allowed_media_domains, ) def parse_image( diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index dfe535b959179..862f383e4ecb2 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -86,6 +86,8 @@ class LLM: or videos from directories specified by the server file system. This is a security risk. Should only be enabled in trusted environments. + allowed_media_domains: If set, only media URLs that belong to this + domain can be used for multi-modal inputs. tensor_parallel_size: The number of GPUs to use for distributed execution with tensor parallelism. dtype: The data type for the model weights and activations. 
Currently, @@ -169,6 +171,7 @@ class LLM: skip_tokenizer_init: bool = False, trust_remote_code: bool = False, allowed_local_media_path: str = "", + allowed_media_domains: Optional[list[str]] = None, tensor_parallel_size: int = 1, dtype: ModelDType = "auto", quantization: Optional[QuantizationMethods] = None, @@ -264,6 +267,7 @@ class LLM: skip_tokenizer_init=skip_tokenizer_init, trust_remote_code=trust_remote_code, allowed_local_media_path=allowed_local_media_path, + allowed_media_domains=allowed_media_domains, tensor_parallel_size=tensor_parallel_size, dtype=dtype, quantization=quantization, diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index 9b158267040af..1f1eea6bfee75 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -50,6 +50,7 @@ class MediaConnector: connection: HTTPConnection = global_http_connection, *, allowed_local_media_path: str = "", + allowed_media_domains: Optional[list[str]] = None, ) -> None: """ Args: @@ -82,6 +83,9 @@ class MediaConnector: allowed_local_media_path_ = None self.allowed_local_media_path = allowed_local_media_path_ + if allowed_media_domains is None: + allowed_media_domains = [] + self.allowed_media_domains = allowed_media_domains def _load_data_url( self, @@ -115,6 +119,14 @@ class MediaConnector: return media_io.load_file(filepath) + def _assert_url_in_allowed_media_domains(self, url_spec) -> None: + if self.allowed_media_domains and url_spec.hostname not in \ + self.allowed_media_domains: + raise ValueError( + f"The URL must be from one of the allowed domains: " + f"{self.allowed_media_domains}. 
Input URL domain: " + f"{url_spec.hostname}") + def load_from_url( self, url: str, @@ -125,6 +137,8 @@ class MediaConnector: url_spec = urlparse(url) if url_spec.scheme.startswith("http"): + self._assert_url_in_allowed_media_domains(url_spec) + connection = self.connection data = connection.get_bytes(url, timeout=fetch_timeout) @@ -150,6 +164,8 @@ class MediaConnector: loop = asyncio.get_running_loop() if url_spec.scheme.startswith("http"): + self._assert_url_in_allowed_media_domains(url_spec) + connection = self.connection data = await connection.async_get_bytes(url, timeout=fetch_timeout) future = loop.run_in_executor(global_thread_pool, From 92da847cf5f4eedf0bc9fed45d7c076be78b8c1f Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Fri, 26 Sep 2025 21:54:09 -0400 Subject: [PATCH 06/20] Add flashinfer-build.sh and register precompiled cu128 wheel in Dockerfile (#25782) Signed-off-by: mgoin --- docker/Dockerfile | 30 ++++++++++++------- tools/flashinfer-build.sh | 63 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+), 10 deletions(-) create mode 100644 tools/flashinfer-build.sh diff --git a/docker/Dockerfile b/docker/Dockerfile index c0f55a7eeba07..fad62be798a1e 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -391,18 +391,28 @@ RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH' git clone --depth 1 --recursive --shallow-submodules \ --branch ${FLASHINFER_GIT_REF} \ ${FLASHINFER_GIT_REPO} flashinfer + # Exclude CUDA arches for older versions (11.x and 12.0-12.7) + # TODO: Update this to allow setting TORCH_CUDA_ARCH_LIST as a build arg. 
+ if [[ "${CUDA_VERSION}" == 11.* ]]; then + FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9" + elif [[ "${CUDA_VERSION}" == 12.[0-7]* ]]; then + FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a" + else + # CUDA 12.8+ supports 10.0a and 12.0 + FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0" + fi pushd flashinfer - if [ "${FLASHINFER_AOT_COMPILE}" = "true" ]; then - # Exclude CUDA arches for older versions (11.x and 12.0-12.7) - # TODO: Update this to allow setting TORCH_CUDA_ARCH_LIST as a build arg. - if [[ "${CUDA_VERSION}" == 11.* ]]; then - FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9" - elif [[ "${CUDA_VERSION}" == 12.[0-7]* ]]; then - FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a" - else - # CUDA 12.8+ supports 10.0a and 12.0 - FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0" + if [[ "${CUDA_VERSION}" == 12.8.* ]] && [ "$TARGETPLATFORM" = "linux/amd64" ]; then + # NOTE: To make new precompiled wheels, see tools/flashinfer-build.sh + echo "🏗️ Installing FlashInfer from pre-compiled wheel" + uv pip install --system https://wheels.vllm.ai/flashinfer-python/flashinfer_python-0.3.1-cp39-abi3-manylinux1_x86_64.whl \ + --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') + if [ "${FLASHINFER_AOT_COMPILE}" = "true" ]; then + # Download pre-compiled cubins + TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \ + python3 -m flashinfer --download-cubin || echo "WARNING: Failed to download flashinfer cubins." 
fi + elif [ "${FLASHINFER_AOT_COMPILE}" = "true" ]; then echo "🏗️ Installing FlashInfer with AOT compilation for arches: ${FI_TORCH_CUDA_ARCH_LIST}" export FLASHINFER_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" # HACK: We need these to run flashinfer.aot before installing flashinfer, get from the package in the future diff --git a/tools/flashinfer-build.sh b/tools/flashinfer-build.sh new file mode 100644 index 0000000000000..6c14d87348c3a --- /dev/null +++ b/tools/flashinfer-build.sh @@ -0,0 +1,63 @@ +#!/usr/bin/env bash +# This script is used to build FlashInfer wheels with AOT kernels + +set -ex + +# FlashInfer configuration +FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git" +FLASHINFER_GIT_REF="${FLASHINFER_GIT_REF}" +CUDA_VERSION="${CUDA_VERSION}" +BUILD_WHEEL="${BUILD_WHEEL:-true}" + +if [[ -z "${FLASHINFER_GIT_REF}" ]]; then + echo "❌ FLASHINFER_GIT_REF must be specified" >&2 + exit 1 +fi + +if [[ -z "${CUDA_VERSION}" ]]; then + echo "❌ CUDA_VERSION must be specified" >&2 + exit 1 +fi + +echo "🏗️ Building FlashInfer ${FLASHINFER_GIT_REF} for CUDA ${CUDA_VERSION}" + +# Clone FlashInfer +git clone --depth 1 --recursive --shallow-submodules \ + --branch ${FLASHINFER_GIT_REF} \ + ${FLASHINFER_GIT_REPO} flashinfer + +# Set CUDA arch list based on CUDA version +# Exclude CUDA arches for older versions (11.x and 12.0-12.7) +if [[ "${CUDA_VERSION}" == 11.* ]]; then + FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9" +elif [[ "${CUDA_VERSION}" == 12.[0-7]* ]]; then + FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a" +else + # CUDA 12.8+ supports 10.0a and 12.0 + FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0" +fi + +echo "🏗️ Building FlashInfer AOT for arches: ${FI_TORCH_CUDA_ARCH_LIST}" + +pushd flashinfer + # Make sure the wheel is built for the correct CUDA version + export UV_TORCH_BACKEND=cu$(echo $CUDA_VERSION | cut -d. 
-f1,2 | tr -d '.') + + # Build AOT kernels + export TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" + export FLASHINFER_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" + python3 -m flashinfer.aot + + if [[ "${BUILD_WHEEL}" == "true" ]]; then + # Build wheel for distribution + uv build --no-build-isolation --wheel --out-dir ../flashinfer-dist . + echo "✅ FlashInfer wheel built successfully in flashinfer-dist/" + else + # Install directly (for Dockerfile) + uv pip install --system --no-build-isolation --force-reinstall . + echo "✅ FlashInfer installed successfully" + fi +popd + +# Cleanup +rm -rf flashinfer \ No newline at end of file From f1d53d150c5cd9c7d94db296793fc25f955ea8a9 Mon Sep 17 00:00:00 2001 From: WeiQing Chen <40507679+david6666666@users.noreply.github.com> Date: Sat, 27 Sep 2025 11:35:47 +0800 Subject: [PATCH 07/20] [Multimodal][Speculative Decoding]Eagle Eagle3 mm support, enablement on qwen2.5vl (#22872) Signed-off-by: Junhong Signed-off-by: Junhong Liu <98734602+LJH-LBJ@users.noreply.github.com> Co-authored-by: Junhong Co-authored-by: LJH-LBJ <98734602+LJH-LBJ@users.noreply.github.com> --- tests/models/registry.py | 3 + tests/v1/e2e/test_spec_decode.py | 9 +- vllm/benchmarks/datasets.py | 80 +++++++++++++++ vllm/model_executor/models/llama_eagle3.py | 27 ++++-- vllm/model_executor/models/qwen2_5_vl.py | 11 ++- vllm/model_executor/models/registry.py | 1 + vllm/v1/spec_decode/eagle.py | 108 +++++++++++++++------ vllm/v1/worker/gpu_model_runner.py | 16 ++- 8 files changed, 210 insertions(+), 45 deletions(-) diff --git a/tests/models/registry.py b/tests/models/registry.py index 6a6e2538559ff..e321acc873c62 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -651,6 +651,9 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = { "MiMoMTPModel": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL", trust_remote_code=True, speculative_model="XiaomiMiMo/MiMo-7B-RL"), + "Eagle3Qwen2_5vlForCausalLM": _HfExamplesInfo( + "Qwen/Qwen2.5-VL-7B-Instruct", + 
speculative_model="Rayzl/qwen2.5-vl-7b-eagle3-sgl"), "Qwen3NextMTP": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct", min_transformers_version="4.56.3"), } diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index ea8d94722859b..8f048775352e6 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -129,6 +129,11 @@ def test_ngram_correctness( ["model_setup", "mm_enabled"], [ (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False), + pytest.param(("eagle3", "Qwen/Qwen2.5-VL-7B-Instruct", + "Rayzl/qwen2.5-vl-7b-eagle3-sgl", 1), + False, + marks=pytest.mark.skip(reason="Skipping due to its " \ + "head_dim not being a a multiple of 32")), (("eagle", "meta-llama/Llama-3.1-8B-Instruct", "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False), (("eagle3", "meta-llama/Llama-3.1-8B-Instruct", @@ -145,8 +150,8 @@ def test_ngram_correctness( "eagle618/eagle-deepseek-v3-random", 1), False), ], ids=[ - "qwen3_eagle3", "llama3_eagle", "llama3_eagle3", "llama4_eagle", - "llama4_eagle_mm", "deepseek_eagle" + "qwen3_eagle3", "qwen2_5_vl_eagle3", "llama3_eagle", "llama3_eagle3", + "llama4_eagle", "llama4_eagle_mm", "deepseek_eagle" ]) @pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform()) diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py index 68a937d5750ec..f0c0d829a393b 100644 --- a/vllm/benchmarks/datasets.py +++ b/vllm/benchmarks/datasets.py @@ -1450,6 +1450,13 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: ): dataset_class = MLPerfDataset args.hf_split = "train" + elif ( + args.dataset_path in MMStarDataset.SUPPORTED_DATASET_PATHS + or args.hf_name in MMStarDataset.SUPPORTED_DATASET_PATHS + ): + dataset_class = MMStarDataset + args.hf_split = "val" + args.hf_subset = None else: supported_datasets = set([ dataset_name for cls in HuggingFaceDataset.__subclasses__() @@ -2721,3 +2728,76 @@ class PrefixRepetitionRandomDataset(BenchmarkDataset): 
random.shuffle(requests) return requests + + +# ----------------------------------------------------------------------------- +# MMStar Dataset Implementation +# ----------------------------------------------------------------------------- + + +class MMStarDataset(HuggingFaceDataset): + """ + Lin-Chen/MMStar: https://huggingface.co/datasets/Lin-Chen/MMStar + refer to: https://github.com/sgl-project/SpecForge/pull/106 + """ + DEFAULT_OUTPUT_LEN = 128 + SUPPORTED_DATASET_PATHS = {"Lin-Chen/MMStar"} + IS_MULTIMODAL = True + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + request_id_prefix: str = "", + no_oversample: bool = False, + **kwargs, + ) -> list[SampleRequest]: + # If --hf-output-len is not set, use the default output length. + output_len = (output_len + if output_len is not None else self.DEFAULT_OUTPUT_LEN) + sampled_requests: list[SampleRequest] = [] + + for ind, item in enumerate(self.data): + if len(sampled_requests) >= num_requests: + break + # Split the question text from options + # (keep only the part before "Options:"). + full_q: str = item.get("question", "") + question_text = full_q.split("Options:", 1)[0].strip() + + # Multimodal image content. + mm_content = process_image(item["image"]) + + # Compute prompt token length (note: this is plain text length + # if enable_multimodal_chat is False). + prompt_len = len(tokenizer(question_text).input_ids) + + if enable_multimodal_chat: + # If multimodal content should be embedded in the chat message, + # convert to [{"role":"user","content":[...]}] + prompt = self.apply_multimodal_chat_transformation( + question_text, mm_content + ) + mm_for_request = None # Already embedded in chat content. + else: + # Default: prompt is plain text, + # image is in mm_content for the bench to assemble. 
+ prompt = question_text + mm_for_request = mm_content + + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + multi_modal_data=mm_for_request, + request_id=request_id_prefix + str(ind), + ) + ) + + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix, no_oversample + ) + return sampled_requests diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py index b99a1547918ee..55b6ae6ee0e9c 100644 --- a/vllm/model_executor/models/llama_eagle3.py +++ b/vllm/model_executor/models/llama_eagle3.py @@ -8,7 +8,6 @@ import torch import torch.nn as nn from transformers import LlamaConfig -from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import RMSNorm @@ -19,6 +18,7 @@ from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.interfaces import MultiModalEmbeddings from vllm.model_executor.models.llama import (LlamaDecoderLayer, LlamaForCausalLM) @@ -102,7 +102,6 @@ class LlamaDecoderLayer(LlamaDecoderLayer): return hidden_states, residual -@support_torch_compile class LlamaModel(nn.Module): def __init__( @@ -145,13 +144,21 @@ class LlamaModel(nn.Module): eps=self.config.rms_norm_eps, ) + def get_input_embeddings( + self, + input_ids: torch.Tensor, + ) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, hidden_states: torch.Tensor, + input_embeds: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: - input_embeds = 
self.embed_tokens(input_ids) + if input_embeds is None: + input_embeds = self.get_input_embeddings(input_ids) assert hidden_states.shape[-1] == input_embeds.shape[-1] residual = None @@ -239,11 +246,7 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM): hidden_states: torch.Tensor, inputs_embeds: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: - if inputs_embeds is not None: - raise NotImplementedError( - f"{type(self).__name__} does not support multimodal inputs yet." - ) - return self.model(input_ids, positions, hidden_states) + return self.model(input_ids, positions, hidden_states, inputs_embeds) def compute_logits( self, @@ -299,3 +302,11 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM): skip_substrs=skip_substrs, ) loader.load_weights(model_weights.items()) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.model.get_input_embeddings(input_ids) + return inputs_embeds diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 6af6faa2b2964..3199f53a0539e 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -68,7 +68,7 @@ from vllm.transformers_utils.config import uses_mrope from vllm.utils import is_pin_memory_available from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .interfaces import (MultiModalEmbeddings, SupportsLoRA, +from .interfaces import (MultiModalEmbeddings, SupportsEagle3, SupportsLoRA, SupportsMultiModal, SupportsMultiModalPruning, SupportsPP, SupportsQuant) from .qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder @@ -965,7 +965,7 @@ class Qwen2_5_VLMultiModalProcessor(Qwen2VLMultiModalProcessor): dummy_inputs=Qwen2_5_VLDummyInputsBuilder) class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, - SupportsQuant, + SupportsQuant, 
SupportsEagle3, SupportsMultiModalPruning): packed_modules_mapping = { @@ -1028,6 +1028,13 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: + self.language_model.model.aux_hidden_state_layers = layers + + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: + num_layers = len(self.language_model.model.layers) + return (2, num_layers // 2, num_layers - 3) + def _validate_and_reshape_mm_tensor(self, mm_input: object, name: str) -> torch.Tensor: if not isinstance(mm_input, (torch.Tensor, list)): diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 10e9aa4db0781..0471164ab8a6b 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -286,6 +286,7 @@ _SPECULATIVE_DECODING_MODELS = { "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"), "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"), "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"), + "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"), "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"), "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"), "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"), diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 394df48b4153f..51e54e0dc337f 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -80,9 +80,17 @@ class EagleProposer: self.input_ids = torch.zeros(self.max_num_tokens, dtype=torch.int32, device=device) - self.positions = torch.zeros(self.max_num_tokens, - dtype=torch.int64, - device=device) + self.uses_mrope = self.vllm_config.model_config.uses_mrope + if self.uses_mrope: + # M-RoPE need (3, max_num_tokens) + self.mrope_positions = 
torch.zeros((3, self.max_num_tokens), + dtype=torch.int64, + device=device) + else: + # RoPE need (max_num_tokens,) + self.positions = torch.zeros(self.max_num_tokens, + dtype=torch.int64, + device=device) self.hidden_states = torch.zeros( (self.max_num_tokens, self.hidden_size), dtype=self.dtype, @@ -143,11 +151,22 @@ class EagleProposer: dtype=torch.int32, ).repeat(max_batch_size, 1) + def _get_positions(self, num_tokens: int): + if self.uses_mrope: + return self.mrope_positions[:, :num_tokens] + return self.positions[:num_tokens] + + def _set_positions(self, num_tokens: int, positions: torch.Tensor): + if self.uses_mrope: + self.mrope_positions[:, :num_tokens] = positions + else: + self.positions[:num_tokens] = positions + def propose( self, # [num_tokens] target_token_ids: torch.Tensor, - # [num_tokens] + # [num_tokens] or [3, num_tokens] when M-RoPE is enabled target_positions: torch.Tensor, # [num_tokens, hidden_size] target_hidden_states: torch.Tensor, @@ -198,7 +217,7 @@ class EagleProposer: else: num_input_tokens = num_tokens # copy inputs to buffer for cudagraph - self.positions[:num_tokens] = target_positions + self._set_positions(num_tokens, target_positions) self.hidden_states[:num_tokens] = target_hidden_states if self.is_multimodal_model: input_ids = self.input_ids[:num_tokens] @@ -218,7 +237,7 @@ class EagleProposer: num_tokens=num_input_tokens): ret_hidden_states = self.model( input_ids=input_ids, - positions=self.positions[:num_input_tokens], + positions=self._get_positions(num_input_tokens), hidden_states=self.hidden_states[:num_input_tokens], inputs_embeds=inputs_embeds, ) @@ -235,7 +254,10 @@ class EagleProposer: draft_token_ids = logits.argmax(dim=-1) return draft_token_ids.view(-1, 1) - positions = target_positions[last_token_indices] + if self.uses_mrope: + positions = target_positions[:, last_token_indices] + else: + positions = target_positions[last_token_indices] if self.method in ("deepseek_mtp", "ernie_mtp", "longcat_flash_mtp"): 
hidden_states = self.hidden_states[last_token_indices] else: @@ -282,25 +304,34 @@ class EagleProposer: # cast to int32 is crucial when eagle model is compiled. # tensor.argmax() returns int64 by default. input_ids = draft_token_ids_list[-1].int() - positions += 1 - - # NOTE(woosuk): We should handle the case where the draft model - # generates tokens beyond the max model length. Since it is complex - # to remove such requests from the batch, we keep them in the batch - # but adjust the position ids and slot mappings to avoid the - # out-of-range access during the model execution. The draft tokens - # generated with this adjustment should be ignored. - exceeds_max_model_len = positions >= self.max_model_len - # Mask out the position ids that exceed the max model length. - # Otherwise, we may get out-of-range error in RoPE. - clamped_positions = torch.where(exceeds_max_model_len, 0, - positions) + if self.uses_mrope: + positions += 1 + # NOTE(woosuk): We should handle the case where the draft model + # generates tokens beyond the max model length. + # Since it is complex to remove such requests from the batch, + # we keep them in the batch but adjust the position ids + # and slot mappings to avoid the + # out-of-range access during the model execution. + # The draft tokens generated with this adjustment + # should be ignored. + exceeds_max_model_len = positions[0] >= self.max_model_len + # Mask out the position ids that exceed the max model length. + # Otherwise, we may get out-of-range error in RoPE. + clamped_positions = torch.where\ + (exceeds_max_model_len.unsqueeze(0), \ + torch.zeros_like(positions), positions) + else: + positions += 1 + exceeds_max_model_len = positions >= self.max_model_len + clamped_positions = torch.where(exceeds_max_model_len, 0, + positions) # Increment the sequence lengths. 
common_attn_metadata.seq_lens += 1 common_attn_metadata.seq_lens_cpu += 1 # For the requests that exceed the max model length, we set the # sequence length to 1 to minimize their overheads in attention. + common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1) @@ -308,13 +339,22 @@ class EagleProposer: common_attn_metadata.seq_lens_cpu - 1 # Compute the slot mapping. - block_numbers = clamped_positions // self.block_size + if self.uses_mrope: + # all dimensions of positions are the same + block_numbers = clamped_positions[0] // self.block_size + else: + block_numbers = clamped_positions // self.block_size block_ids = common_attn_metadata.block_table_tensor.gather( dim=1, index=block_numbers.view(-1, 1)) block_ids = block_ids.view(-1) - common_attn_metadata.slot_mapping = ( - block_ids * self.block_size + - clamped_positions % self.block_size) + if self.uses_mrope: + common_attn_metadata.slot_mapping = ( + block_ids * self.block_size + + clamped_positions[0] % self.block_size) + else: + common_attn_metadata.slot_mapping = ( + block_ids * self.block_size + + clamped_positions % self.block_size) # Mask out the slot mappings that exceed the max model length. # Otherwise, the KV cache will be inadvertently updated with the # padding tokens. 
@@ -330,7 +370,7 @@ class EagleProposer: # copy inputs to buffer for cudagraph self.input_ids[:batch_size] = input_ids - self.positions[:batch_size] = clamped_positions + self._set_positions(batch_size, clamped_positions) self.hidden_states[:batch_size] = hidden_states if self.is_multimodal_model: inputs_embeds = self.model.get_input_embeddings(input_ids) @@ -347,7 +387,7 @@ class EagleProposer: num_tokens=input_batch_size): ret_hidden_states = self.model( input_ids=input_ids, - positions=self.positions[:input_batch_size], + positions=self._get_positions(input_batch_size), hidden_states=self.hidden_states[:input_batch_size], inputs_embeds=inputs_embeds, ) @@ -787,6 +827,11 @@ class EagleProposer: return spec_common_attn_metadata, token_indices + def get_model_name(self, model: nn.Module) -> str: + if hasattr(model, 'module'): # multi-GPU + model = model.module + return model.__class__.__name__ + def load_model(self, target_model: nn.Module) -> None: draft_model_config = \ self.vllm_config.speculative_config.draft_model_config @@ -820,8 +865,13 @@ class EagleProposer: if supports_multimodal(target_model): # handle multimodality - self.model.config.image_token_index = ( - target_model.config.image_token_index) + if (self.get_model_name(target_model) == + "Qwen2_5_VLForConditionalGeneration"): + self.model.config.image_token_index = ( + target_model.config.image_token_id) + else: + self.model.config.image_token_index = ( + target_model.config.image_token_index) target_language_model = target_model.get_language_model() else: target_language_model = target_model @@ -892,7 +942,7 @@ class EagleProposer: self.model( input_ids=input_ids, - positions=self.positions[:num_tokens], + positions=self._get_positions(num_tokens), hidden_states=self.hidden_states[:num_tokens], inputs_embeds=inputs_embeds, ) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index f87a327d02a50..22a177dd7cc7a 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ 
b/vllm/v1/worker/gpu_model_runner.py @@ -442,6 +442,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): device="cpu", pin_memory=self.pin_memory) + def _get_positions(self, num_tokens: Any): + if isinstance(num_tokens, int): + if self.uses_mrope: + return self.mrope_positions.gpu[:, :num_tokens] + return self.positions.gpu[:num_tokens] + else: + if self.uses_mrope: + return self.mrope_positions.gpu[:, num_tokens] + return self.positions.gpu[num_tokens] + def _make_buffer(self, *size: Union[int, torch.SymInt], dtype: torch.dtype, @@ -2544,8 +2554,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): token_indices_to_sample = None # input_ids can be None for multimodal models. target_token_ids = self.input_ids.gpu[:num_scheduled_tokens] - # TODO(woosuk): Support M-RoPE. - target_positions = self.positions.gpu[:num_scheduled_tokens] + target_positions = self._get_positions(num_scheduled_tokens) if self.use_aux_hidden_state_outputs: assert aux_hidden_states is not None target_hidden_states = torch.cat( @@ -2570,8 +2579,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): valid_sampled_tokens_count) target_token_ids = self.input_ids.gpu[token_indices] - # TODO(woosuk): Support M-RoPE. 
- target_positions = self.positions.gpu[token_indices] + target_positions = self._get_positions(token_indices) if self.use_aux_hidden_state_outputs: assert aux_hidden_states is not None target_hidden_states = torch.cat( From c242c98031b87d00999e07dbb4aa9b2a70798c6c Mon Sep 17 00:00:00 2001 From: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Date: Fri, 26 Sep 2025 23:44:52 -0400 Subject: [PATCH 08/20] [Bugfix] Allow Only SDPA Backend for ViT on B200 for Qwen3-VL (#25788) --- vllm/model_executor/models/qwen2_5_vl.py | 73 ++++++++++++------------ vllm/model_executor/models/qwen3_vl.py | 53 ++++++++++++----- 2 files changed, 75 insertions(+), 51 deletions(-) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 3199f53a0539e..adb21373056c5 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -274,6 +274,8 @@ class Qwen2_5_VisionAttention(nn.Module): quant_config: Optional[QuantizationConfig] = None, prefix: str = "", use_data_parallel: bool = False, + attn_backend: _Backend = _Backend.TORCH_SDPA, + use_upstream_fa: bool = False, ) -> None: super().__init__() # Per attention head and per partition values. @@ -300,25 +302,8 @@ class Qwen2_5_VisionAttention(nn.Module): quant_config=quant_config, prefix=f"{prefix}.proj", disable_tp=use_data_parallel) - - # Detect attention implementation. - self.attn_backend = get_vit_attn_backend( - head_size=self.hidden_size_per_attention_head, - dtype=torch.get_default_dtype()) - self.use_upstream_fa = False - if self.attn_backend != _Backend.FLASH_ATTN and \ - check_upstream_fa_availability( - torch.get_default_dtype()): - self.attn_backend = _Backend.FLASH_ATTN - self.use_upstream_fa = True - - if self.attn_backend not in { - _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, - _Backend.ROCM_AITER_FA - }: - raise RuntimeError( - f"Qwen2.5-VL does not support {self.attn_backend} backend now." 
- ) + self.attn_backend = attn_backend + self.use_upstream_fa = use_upstream_fa self.is_flash_attn_backend = self.attn_backend in { _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA } @@ -443,6 +428,8 @@ class Qwen2_5_VisionBlock(nn.Module): quant_config: Optional[QuantizationConfig] = None, prefix: str = "", use_data_parallel: bool = False, + attn_backend: _Backend = _Backend.TORCH_SDPA, + use_upstream_fa: bool = False, ) -> None: super().__init__() if norm_layer is None: @@ -455,7 +442,9 @@ class Qwen2_5_VisionBlock(nn.Module): projection_size=dim, quant_config=quant_config, prefix=f"{prefix}.attn", - use_data_parallel=use_data_parallel) + use_data_parallel=use_data_parallel, + attn_backend=attn_backend, + use_upstream_fa=use_upstream_fa) self.mlp = Qwen2_5_VisionMLP(dim, mlp_hidden_dim, act_fn=act_fn, @@ -627,17 +616,35 @@ class Qwen2_5_VisionTransformer(nn.Module): head_dim = self.hidden_size // self.num_heads self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) + use_upstream_fa = False + self.attn_backend = get_vit_attn_backend( + head_size=head_dim, dtype=torch.get_default_dtype()) + if self.attn_backend != _Backend.FLASH_ATTN and \ + check_upstream_fa_availability( + torch.get_default_dtype()): + self.attn_backend = _Backend.FLASH_ATTN + use_upstream_fa = True + + if self.attn_backend not in { + _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, + _Backend.ROCM_AITER_FA + }: + raise RuntimeError( + f"Qwen2.5-VL does not support {self.attn_backend} backend now." 
+ ) + self.blocks = nn.ModuleList([ - Qwen2_5_VisionBlock(dim=self.hidden_size, - num_heads=self.num_heads, - mlp_hidden_dim=vision_config.intermediate_size, - act_fn=get_act_and_mul_fn( - vision_config.hidden_act), - norm_layer=norm_layer, - quant_config=quant_config, - prefix=f"{prefix}.blocks.{layer_idx}", - use_data_parallel=use_data_parallel) - for layer_idx in range(depth) + Qwen2_5_VisionBlock( + dim=self.hidden_size, + num_heads=self.num_heads, + mlp_hidden_dim=vision_config.intermediate_size, + act_fn=get_act_and_mul_fn(vision_config.hidden_act), + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}", + use_data_parallel=use_data_parallel, + attn_backend=self.attn_backend, + use_upstream_fa=use_upstream_fa) for layer_idx in range(depth) ]) self.merger = Qwen2_5_VisionPatchMerger( d_model=vision_config.out_hidden_size, @@ -648,12 +655,6 @@ class Qwen2_5_VisionTransformer(nn.Module): prefix=f"{prefix}.merger", use_data_parallel=use_data_parallel, ) - self.attn_backend = get_vit_attn_backend( - head_size=head_dim, dtype=torch.get_default_dtype()) - if self.attn_backend != _Backend.FLASH_ATTN and \ - check_upstream_fa_availability( - torch.get_default_dtype()): - self.attn_backend = _Backend.FLASH_ATTN @property def dtype(self) -> torch.dtype: diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index f3f11438eeeea..f1aeb99a4d373 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -63,7 +63,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, PromptReplacement, PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.platforms import _Backend +from vllm.platforms import _Backend, current_platform from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import uses_mrope from vllm.utils import is_list_of @@ -158,6 +158,8 @@ class 
Qwen3_VisionBlock(nn.Module): quant_config: Optional[QuantizationConfig] = None, prefix: str = "", use_data_parallel: bool = False, + attn_backend: _Backend = _Backend.TORCH_SDPA, + use_upstream_fa: bool = False, ) -> None: super().__init__() if norm_layer is None: @@ -170,7 +172,9 @@ class Qwen3_VisionBlock(nn.Module): projection_size=dim, quant_config=quant_config, prefix=f"{prefix}.attn", - use_data_parallel=use_data_parallel) + use_data_parallel=use_data_parallel, + attn_backend=attn_backend, + use_upstream_fa=use_upstream_fa) self.mlp = Qwen3_VisionMLP(dim, mlp_hidden_dim, act_fn=act_fn, @@ -287,19 +291,6 @@ class Qwen3_VisionTransformer(nn.Module): head_dim = self.hidden_size // self.num_heads self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) - self.blocks = nn.ModuleList([ - Qwen3_VisionBlock( - dim=self.hidden_size, - num_heads=self.num_heads, - mlp_hidden_dim=vision_config.intermediate_size, - act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act], - norm_layer=norm_layer, - quant_config=quant_config, - prefix=f"{prefix}.blocks.{layer_idx}", - use_data_parallel=use_data_parallel) - for layer_idx in range(vision_config.depth) - ]) - self.merger = Qwen3_VisionPatchMerger( d_model=vision_config.out_hidden_size, context_dim=self.hidden_size, @@ -325,10 +316,42 @@ class Qwen3_VisionTransformer(nn.Module): self.attn_backend = get_vit_attn_backend( head_size=head_dim, dtype=torch.get_default_dtype()) + use_upstream_fa = False if self.attn_backend != _Backend.FLASH_ATTN and \ check_upstream_fa_availability( torch.get_default_dtype()): self.attn_backend = _Backend.FLASH_ATTN + use_upstream_fa = True + + if self.attn_backend not in { + _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, + _Backend.ROCM_AITER_FA + }: + raise RuntimeError( + f"Qwen3-VL does not support {self.attn_backend} backend now.") + if current_platform.is_device_capability( + 100) and self.attn_backend != _Backend.TORCH_SDPA: + # TODO(Roger/Wentao): remove this after FA 
+ # or XFORMERS's issue fixed on Blackwell + logger.info_once("Qwen3-VL vision attention does not support " + f"{self.attn_backend} backend on Blackwell now. " + "Vision attention backend is set to TORCH_SDPA.") + self.attn_backend = _Backend.TORCH_SDPA + + self.blocks = nn.ModuleList([ + Qwen3_VisionBlock( + dim=self.hidden_size, + num_heads=self.num_heads, + mlp_hidden_dim=vision_config.intermediate_size, + act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act], + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}", + use_data_parallel=use_data_parallel, + attn_backend=self.attn_backend, + use_upstream_fa=use_upstream_fa) + for layer_idx in range(vision_config.depth) + ]) @property def dtype(self) -> torch.dtype: From d346ec695ef5dc74cde338a6bc3857e91c311ab2 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 27 Sep 2025 12:45:20 +0800 Subject: [PATCH 09/20] [CI/Build] Consolidate model loader tests and requirements (#25765) Signed-off-by: DarkLight1337 --- .buildkite/test-pipeline.yaml | 19 ++----- .github/mergify.yml | 2 +- docker/Dockerfile | 2 +- requirements/nightly_torch_test.txt | 3 +- requirements/rocm.txt | 5 +- requirements/test.in | 3 +- requirements/test.txt | 10 ++-- setup.py | 5 +- tests/model_executor/conftest.py | 52 ------------------- .../fastsafetensors_loader/__init__.py | 0 .../test_fastsafetensors_loader.py | 0 .../test_weight_utils.py | 0 .../runai_model_streamer}/__init__.py | 0 .../test_runai_model_streamer_loader.py | 0 .../runai_model_streamer}/test_runai_utils.py | 0 .../test_weight_utils.py | 0 .../tensorizer_loader/__init__.py | 0 .../tensorizer_loader/conftest.py | 0 .../tensorizer_loader/test_tensorizer.py | 2 +- .../model_loader/weight_utils.py | 35 +++++++++++-- 20 files changed, 48 insertions(+), 90 deletions(-) delete mode 100644 tests/model_executor/conftest.py rename tests/{ => model_executor/model_loader}/fastsafetensors_loader/__init__.py (100%) rename tests/{ => 
model_executor/model_loader}/fastsafetensors_loader/test_fastsafetensors_loader.py (100%) rename tests/{ => model_executor/model_loader}/fastsafetensors_loader/test_weight_utils.py (100%) rename tests/{runai_model_streamer_test => model_executor/model_loader/runai_model_streamer}/__init__.py (100%) rename tests/{runai_model_streamer_test => model_executor/model_loader/runai_model_streamer}/test_runai_model_streamer_loader.py (100%) rename tests/{runai_model_streamer_test => model_executor/model_loader/runai_model_streamer}/test_runai_utils.py (100%) rename tests/{runai_model_streamer_test => model_executor/model_loader/runai_model_streamer}/test_weight_utils.py (100%) rename tests/{ => model_executor/model_loader}/tensorizer_loader/__init__.py (100%) rename tests/{ => model_executor/model_loader}/tensorizer_loader/conftest.py (100%) rename tests/{ => model_executor/model_loader}/tensorizer_loader/test_tensorizer.py (99%) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index c178fd372bcbf..82a3b2fc199e5 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -465,29 +465,18 @@ steps: commands: - pytest -v -s kernels/mamba -- label: Tensorizer Test # 14min - timeout_in_minutes: 25 - mirror_hardwares: [amdexperimental] - source_file_dependencies: - - vllm/model_executor/model_loader - - tests/tensorizer_loader - - tests/entrypoints/openai/test_tensorizer_entrypoint.py - commands: - - apt-get update && apt-get install -y curl libsodium23 - - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - pytest -v -s tensorizer_loader - - pytest -v -s entrypoints/openai/test_tensorizer_entrypoint.py - -- label: Model Executor Test # 7min - timeout_in_minutes: 20 +- label: Model Executor Test # ??? 
+ timeout_in_minutes: 60 mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/model_executor - tests/model_executor + - tests/entrypoints/openai/test_tensorizer_entrypoint.py commands: - apt-get update && apt-get install -y curl libsodium23 - export VLLM_WORKER_MULTIPROC_METHOD=spawn - pytest -v -s model_executor + - pytest -v -s entrypoints/openai/test_tensorizer_entrypoint.py - label: Benchmarks # 11min timeout_in_minutes: 20 diff --git a/.github/mergify.yml b/.github/mergify.yml index 75ee3e3c55b46..923f708ea10c6 100644 --- a/.github/mergify.yml +++ b/.github/mergify.yml @@ -274,7 +274,7 @@ pull_request_rules: - files~=^vllm/model_executor/model_loader/tensorizer.py - files~=^vllm/model_executor/model_loader/tensorizer_loader.py - files~=^tests/entrypoints/openai/test_tensorizer_entrypoint.py - - files~=^tests/tensorizer_loader/ + - files~=^tests/model_executor/model_loader/tensorizer_loader/ actions: assign: users: diff --git a/docker/Dockerfile b/docker/Dockerfile index fad62be798a1e..c2b855be4403a 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -546,7 +546,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ else \ BITSANDBYTES_VERSION="0.46.1"; \ fi; \ - uv pip install --system accelerate hf_transfer modelscope "bitsandbytes>=${BITSANDBYTES_VERSION}" 'timm>=1.0.17' boto3 runai-model-streamer runai-model-streamer[s3] + uv pip install --system accelerate hf_transfer modelscope "bitsandbytes>=${BITSANDBYTES_VERSION}" 'timm>=1.0.17' 'runai-model-streamer[s3]>=0.14.0' ENV VLLM_USAGE_SOURCE production-docker-image diff --git a/requirements/nightly_torch_test.txt b/requirements/nightly_torch_test.txt index a529bf4504e40..790a18f28b7f5 100644 --- a/requirements/nightly_torch_test.txt +++ b/requirements/nightly_torch_test.txt @@ -43,7 +43,6 @@ tritonclient==2.51.0 numba == 0.60.0; python_version == '3.9' # v0.61 doesn't support Python 3.9. 
Required for N-gram speculative decoding numba == 0.61.2; python_version > '3.9' numpy -runai-model-streamer==0.11.0 -runai-model-streamer-s3==0.11.0 +runai-model-streamer[s3]==0.14.0 fastsafetensors>=0.1.10 pydantic>=2.10 # 2.9 leads to error on python 3.10 diff --git a/requirements/rocm.txt b/requirements/rocm.txt index c129dd345c81a..c4aabe2a73144 100644 --- a/requirements/rocm.txt +++ b/requirements/rocm.txt @@ -5,8 +5,6 @@ numba == 0.60.0; python_version == '3.9' # v0.61 doesn't support Python 3.9. Req numba == 0.61.2; python_version > '3.9' # Dependencies for AMD GPUs -boto3 -botocore datasets ray[cgraph]>=2.48.0 # Ray Compiled Graph, required for pipeline parallelism in V1. peft @@ -15,7 +13,6 @@ tensorizer==2.10.1 packaging>=24.2 setuptools>=77.0.3,<80.0.0 setuptools-scm>=8 -runai-model-streamer==0.11.0 -runai-model-streamer-s3==0.11.0 +runai-model-streamer[s3]==0.14.0 conch-triton-kernels==1.2.1 timm>=1.0.17 \ No newline at end of file diff --git a/requirements/test.in b/requirements/test.in index 451bd73879107..c9496c61a7e4f 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -51,8 +51,7 @@ tritonclient==2.51.0 numba == 0.60.0; python_version == '3.9' # v0.61 doesn't support Python 3.9. 
Required for N-gram speculative decoding numba == 0.61.2; python_version > '3.9' numpy -runai-model-streamer==0.11.0 -runai-model-streamer-s3==0.11.0 +runai-model-streamer[s3]==0.14.0 fastsafetensors>=0.1.10 pydantic>=2.10 # 2.9 leads to error on python 3.10 decord==0.6.0 diff --git a/requirements/test.txt b/requirements/test.txt index 3519aa524f418..912e04b2606c5 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -72,7 +72,9 @@ blobfile==3.0.0 bm25s==0.2.13 # via mteb boto3==1.35.57 - # via tensorizer + # via + # runai-model-streamer-s3 + # tensorizer botocore==1.35.57 # via # boto3 @@ -925,10 +927,10 @@ rsa==4.9.1 # via google-auth rtree==1.4.0 # via torchgeo -runai-model-streamer==0.11.0 - # via -r requirements/test.in -runai-model-streamer-s3==0.11.0 +runai-model-streamer==0.14.0 # via -r requirements/test.in +runai-model-streamer-s3==0.14.0 + # via runai-model-streamer s3transfer==0.10.3 # via boto3 sacrebleu==2.4.3 diff --git a/setup.py b/setup.py index e4c40d22b928d..a8fec8a028d0d 100644 --- a/setup.py +++ b/setup.py @@ -654,10 +654,7 @@ setup( "bench": ["pandas", "datasets"], "tensorizer": ["tensorizer==2.10.1"], "fastsafetensors": ["fastsafetensors >= 0.1.10"], - "runai": [ - "runai-model-streamer >= 0.14.0", "runai-model-streamer-gcs", - "google-cloud-storage", "runai-model-streamer-s3", "boto3" - ], + "runai": ["runai-model-streamer[s3,gcs] >= 0.14.0"], "audio": ["librosa", "soundfile", "mistral_common[audio]"], # Required for audio processing "video": [], # Kept for backwards compatibility diff --git a/tests/model_executor/conftest.py b/tests/model_executor/conftest.py deleted file mode 100644 index c6d89d849e9f9..0000000000000 --- a/tests/model_executor/conftest.py +++ /dev/null @@ -1,52 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest - - -@pytest.fixture -def sample_regex(): - return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" - 
r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)") - - -@pytest.fixture -def sample_json_schema(): - return { - "type": "object", - "properties": { - "name": { - "type": "string" - }, - "age": { - "type": "integer" - }, - "skills": { - "type": "array", - "items": { - "type": "string", - "maxLength": 10 - }, - "minItems": 3 - }, - "work_history": { - "type": "array", - "items": { - "type": "object", - "properties": { - "company": { - "type": "string" - }, - "duration": { - "type": "number" - }, - "position": { - "type": "string" - } - }, - "required": ["company", "position"] - } - } - }, - "required": ["name", "age", "skills", "work_history"] - } diff --git a/tests/fastsafetensors_loader/__init__.py b/tests/model_executor/model_loader/fastsafetensors_loader/__init__.py similarity index 100% rename from tests/fastsafetensors_loader/__init__.py rename to tests/model_executor/model_loader/fastsafetensors_loader/__init__.py diff --git a/tests/fastsafetensors_loader/test_fastsafetensors_loader.py b/tests/model_executor/model_loader/fastsafetensors_loader/test_fastsafetensors_loader.py similarity index 100% rename from tests/fastsafetensors_loader/test_fastsafetensors_loader.py rename to tests/model_executor/model_loader/fastsafetensors_loader/test_fastsafetensors_loader.py diff --git a/tests/fastsafetensors_loader/test_weight_utils.py b/tests/model_executor/model_loader/fastsafetensors_loader/test_weight_utils.py similarity index 100% rename from tests/fastsafetensors_loader/test_weight_utils.py rename to tests/model_executor/model_loader/fastsafetensors_loader/test_weight_utils.py diff --git a/tests/runai_model_streamer_test/__init__.py b/tests/model_executor/model_loader/runai_model_streamer/__init__.py similarity index 100% rename from tests/runai_model_streamer_test/__init__.py rename to tests/model_executor/model_loader/runai_model_streamer/__init__.py diff --git a/tests/runai_model_streamer_test/test_runai_model_streamer_loader.py 
b/tests/model_executor/model_loader/runai_model_streamer/test_runai_model_streamer_loader.py similarity index 100% rename from tests/runai_model_streamer_test/test_runai_model_streamer_loader.py rename to tests/model_executor/model_loader/runai_model_streamer/test_runai_model_streamer_loader.py diff --git a/tests/runai_model_streamer_test/test_runai_utils.py b/tests/model_executor/model_loader/runai_model_streamer/test_runai_utils.py similarity index 100% rename from tests/runai_model_streamer_test/test_runai_utils.py rename to tests/model_executor/model_loader/runai_model_streamer/test_runai_utils.py diff --git a/tests/runai_model_streamer_test/test_weight_utils.py b/tests/model_executor/model_loader/runai_model_streamer/test_weight_utils.py similarity index 100% rename from tests/runai_model_streamer_test/test_weight_utils.py rename to tests/model_executor/model_loader/runai_model_streamer/test_weight_utils.py diff --git a/tests/tensorizer_loader/__init__.py b/tests/model_executor/model_loader/tensorizer_loader/__init__.py similarity index 100% rename from tests/tensorizer_loader/__init__.py rename to tests/model_executor/model_loader/tensorizer_loader/__init__.py diff --git a/tests/tensorizer_loader/conftest.py b/tests/model_executor/model_loader/tensorizer_loader/conftest.py similarity index 100% rename from tests/tensorizer_loader/conftest.py rename to tests/model_executor/model_loader/tensorizer_loader/conftest.py diff --git a/tests/tensorizer_loader/test_tensorizer.py b/tests/model_executor/model_loader/tensorizer_loader/test_tensorizer.py similarity index 99% rename from tests/tensorizer_loader/test_tensorizer.py rename to tests/model_executor/model_loader/tensorizer_loader/test_tensorizer.py index e00d7c2f80c67..f50f046967383 100644 --- a/tests/tensorizer_loader/test_tensorizer.py +++ b/tests/model_executor/model_loader/tensorizer_loader/test_tensorizer.py @@ -14,6 +14,7 @@ import pytest import torch import vllm.model_executor.model_loader.tensorizer +from 
tests.utils import VLLM_PATH, RemoteOpenAIServer from vllm import LLM, SamplingParams from vllm.engine.arg_utils import EngineArgs # yapf: disable @@ -27,7 +28,6 @@ from vllm.model_executor.model_loader.tensorizer_loader import ( # yapf: enable from vllm.utils import PlaceholderModule -from ..utils import VLLM_PATH, RemoteOpenAIServer from .conftest import DummyExecutor, assert_from_collective_rpc try: diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index cad32fee1d0f6..f52d9dd2f5348 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -639,6 +639,19 @@ def runai_safetensors_weights_iterator( yield from tensor_iter +def _init_loader( + pg: torch.distributed.ProcessGroup, + device: torch.device, + f_list: list[str], + *, + nogds: bool = False, +): + loader = SafeTensorsFileLoader(pg, device, nogds=nogds) + rank_file_map = {i: [f] for i, f in enumerate(f_list)} + loader.add_filenames(rank_file_map) + return loader + + def fastsafetensors_weights_iterator( hf_weights_files: list[str], use_tqdm_on_load: bool, @@ -656,17 +669,31 @@ def fastsafetensors_weights_iterator( for i in range(0, len(hf_weights_files), pg.size()) ] + nogds = False + for f_list in tqdm( weight_files_sub_lists, desc="Loading safetensors using Fastsafetensor loader", disable=not enable_tqdm(use_tqdm_on_load), bar_format=_BAR_FORMAT, ): - loader = SafeTensorsFileLoader(pg, device) - rank_file_map = {i: [f] for i, f in enumerate(f_list)} - loader.add_filenames(rank_file_map) + loader = _init_loader(pg, device, f_list, nogds=nogds) try: - fb = loader.copy_files_to_device() + try: + fb = loader.copy_files_to_device() + except RuntimeError as e: + if "gds" not in str(e): + raise + + loader.close() + nogds = True + logger.warning_once( + "GDS not enabled, setting `nogds=True`.\n" + "For more information, see: 
https://github.com/foundation-model-stack/fastsafetensors?tab=readme-ov-file#basic-api-usages" + ) + loader = _init_loader(pg, device, f_list, nogds=nogds) + fb = loader.copy_files_to_device() + try: keys = list(fb.key_to_rank_lidx.keys()) for k in keys: From b3613e3acece6502c553901fe4433e3f783363b7 Mon Sep 17 00:00:00 2001 From: 22quinn <33176974+22quinn@users.noreply.github.com> Date: Fri, 26 Sep 2025 21:57:27 -0700 Subject: [PATCH 10/20] [CI/Build] Add timing to Model Executor Test (#25799) Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com> --- .buildkite/test-pipeline.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 82a3b2fc199e5..c6c4e2a2309fc 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -465,8 +465,8 @@ steps: commands: - pytest -v -s kernels/mamba -- label: Model Executor Test # ??? - timeout_in_minutes: 60 +- label: Model Executor Test # 23min + timeout_in_minutes: 35 mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/model_executor From cd87bfbf37f2300b7076b496366cd69048819777 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 27 Sep 2025 13:51:15 +0800 Subject: [PATCH 11/20] [CI/Build] Reorganize root-level V1 tests (#25767) Signed-off-by: DarkLight1337 --- .../scripts/hardware_ci/run-xpu-test.sh | 3 +- .buildkite/test-pipeline.yaml | 23 +++---- tests/v1/{ => core}/test_kv_sharing.py | 0 tests/v1/distributed/__init__.py | 0 .../v1/{ => distributed}/test_async_llm_dp.py | 0 .../{ => distributed}/test_external_lb_dp.py | 0 .../v1/{ => distributed}/test_hybrid_lb_dp.py | 2 +- .../{ => distributed}/test_internal_lb_dp.py | 2 +- .../openai/test_multi_api_servers.py | 2 +- tests/v1/{ => metrics}/test_metrics_reader.py | 0 tests/v1/{test_utils.py => utils.py} | 61 ------------------ tests/v1/worker/test_utils.py | 63 +++++++++++++++++++ 12 files changed, 75 insertions(+), 81 deletions(-) rename 
tests/v1/{ => core}/test_kv_sharing.py (100%) create mode 100644 tests/v1/distributed/__init__.py rename tests/v1/{ => distributed}/test_async_llm_dp.py (100%) rename tests/v1/{ => distributed}/test_external_lb_dp.py (100%) rename tests/v1/{ => distributed}/test_hybrid_lb_dp.py (99%) rename tests/v1/{ => distributed}/test_internal_lb_dp.py (99%) rename tests/v1/{ => metrics}/test_metrics_reader.py (100%) rename tests/v1/{test_utils.py => utils.py} (67%) create mode 100644 tests/v1/worker/test_utils.py diff --git a/.buildkite/scripts/hardware_ci/run-xpu-test.sh b/.buildkite/scripts/hardware_ci/run-xpu-test.sh index 1fc3dbd8c21f4..6b9c0121c4aa8 100644 --- a/.buildkite/scripts/hardware_ci/run-xpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-xpu-test.sh @@ -44,7 +44,6 @@ docker run \ pytest -v -s v1/structured_output pytest -v -s v1/spec_decode --ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_eagle.py --ignore=v1/spec_decode/test_tree_attention.py pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py --ignore=v1/kv_connector/unit/test_shared_storage_connector.py + pytest -v -s v1/test_metrics pytest -v -s v1/test_serial_utils.py - pytest -v -s v1/test_utils.py - pytest -v -s v1/test_metrics_reader.py ' diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index c6c4e2a2309fc..e603c1582e1fb 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -159,10 +159,7 @@ steps: - examples/offline_inference/rlhf.py - examples/offline_inference/rlhf_colocate.py - tests/examples/offline_inference/data_parallel.py - - tests/v1/test_async_llm_dp.py - - tests/v1/test_external_lb_dp.py - - tests/v1/test_internal_lb_dp.py - - tests/v1/test_hybrid_lb_dp.py + - tests/v1/distributed - tests/v1/engine/test_engine_core_client.py - tests/distributed/test_symm_mem_allreduce.py commands: @@ -180,10 +177,10 @@ steps: - TP_SIZE=2 DP_SIZE=2 
ENABLE_EP=1 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py # test with internal dp - python3 ../examples/offline_inference/data_parallel.py --enforce-eager - - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py - - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py - - TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/test_internal_lb_dp.py - - TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/test_hybrid_lb_dp.py + - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py + - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py + - TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_internal_lb_dp.py + - TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_hybrid_lb_dp.py - pytest -v -s v1/engine/test_engine_core_client.py::test_kv_cache_events_dp - pytest -v -s distributed/test_utils.py - pytest -v -s compile/test_basic_correctness.py @@ -300,12 +297,9 @@ steps: - pytest -v -s v1/spec_decode - pytest -v -s v1/kv_connector/unit - pytest -v -s v1/metrics - - pytest -v -s v1/test_kv_sharing.py - - pytest -v -s v1/test_metrics_reader.py - pytest -v -s v1/test_oracle.py - pytest -v -s v1/test_request.py - pytest -v -s v1/test_serial_utils.py - - pytest -v -s v1/test_utils.py # Integration test for streaming correctness (requires special branch). 
- pip install -U git+https://github.com/robertgshaw2-redhat/lm-evaluation-harness.git@streaming-api - pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine @@ -895,14 +889,13 @@ steps: - tests/compile/test_wrapper.py - tests/distributed/ - tests/entrypoints/llm/test_collective_rpc.py - - tests/v1/test_async_llm_dp.py - - tests/v1/test_external_lb_dp.py + - tests/v1/distributed - tests/v1/entrypoints/openai/test_multi_api_servers.py - tests/v1/shutdown - tests/v1/worker/test_worker_memory_snapshot.py commands: - - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py - - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py + - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py + - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py - DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py - pytest -v -s entrypoints/llm/test_collective_rpc.py - pytest -v -s ./compile/test_basic_correctness.py diff --git a/tests/v1/test_kv_sharing.py b/tests/v1/core/test_kv_sharing.py similarity index 100% rename from tests/v1/test_kv_sharing.py rename to tests/v1/core/test_kv_sharing.py diff --git a/tests/v1/distributed/__init__.py b/tests/v1/distributed/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/v1/test_async_llm_dp.py b/tests/v1/distributed/test_async_llm_dp.py similarity index 100% rename from tests/v1/test_async_llm_dp.py rename to tests/v1/distributed/test_async_llm_dp.py diff --git a/tests/v1/test_external_lb_dp.py b/tests/v1/distributed/test_external_lb_dp.py similarity index 100% rename from tests/v1/test_external_lb_dp.py rename to tests/v1/distributed/test_external_lb_dp.py diff --git a/tests/v1/test_hybrid_lb_dp.py b/tests/v1/distributed/test_hybrid_lb_dp.py similarity index 99% rename from tests/v1/test_hybrid_lb_dp.py rename to tests/v1/distributed/test_hybrid_lb_dp.py index 552436f818d77..21d8009a6dbb7 100644 --- 
a/tests/v1/test_hybrid_lb_dp.py +++ b/tests/v1/distributed/test_hybrid_lb_dp.py @@ -12,7 +12,7 @@ import pytest_asyncio import requests from tests.utils import RemoteOpenAIServer -from tests.v1.test_utils import check_request_balancing +from tests.v1.utils import check_request_balancing from vllm.platforms import current_platform MODEL_NAME = "ibm-research/PowerMoE-3b" diff --git a/tests/v1/test_internal_lb_dp.py b/tests/v1/distributed/test_internal_lb_dp.py similarity index 99% rename from tests/v1/test_internal_lb_dp.py rename to tests/v1/distributed/test_internal_lb_dp.py index e965645711ee6..3f9defd13dead 100644 --- a/tests/v1/test_internal_lb_dp.py +++ b/tests/v1/distributed/test_internal_lb_dp.py @@ -13,7 +13,7 @@ import pytest_asyncio import requests from tests.utils import RemoteOpenAIServer -from tests.v1.test_utils import check_request_balancing +from tests.v1.utils import check_request_balancing from vllm.platforms import current_platform MODEL_NAME = "ibm-research/PowerMoE-3b" diff --git a/tests/v1/entrypoints/openai/test_multi_api_servers.py b/tests/v1/entrypoints/openai/test_multi_api_servers.py index f7c31b0c43778..35f75191d9c8d 100644 --- a/tests/v1/entrypoints/openai/test_multi_api_servers.py +++ b/tests/v1/entrypoints/openai/test_multi_api_servers.py @@ -8,7 +8,7 @@ import pytest import pytest_asyncio from tests.utils import RemoteOpenAIServer -from tests.v1.test_utils import check_request_balancing +from tests.v1.utils import check_request_balancing MODEL_NAME = "ibm-research/PowerMoE-3b" diff --git a/tests/v1/test_metrics_reader.py b/tests/v1/metrics/test_metrics_reader.py similarity index 100% rename from tests/v1/test_metrics_reader.py rename to tests/v1/metrics/test_metrics_reader.py diff --git a/tests/v1/test_utils.py b/tests/v1/utils.py similarity index 67% rename from tests/v1/test_utils.py rename to tests/v1/utils.py index 00d98a873a310..b3f560c11e8f5 100644 --- a/tests/v1/test_utils.py +++ b/tests/v1/utils.py @@ -1,71 +1,10 @@ # 
SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - import pytest import regex as re import requests -import torch from tests.utils import RemoteOpenAIServer -from vllm.v1.worker.utils import bind_kv_cache - - -def test_bind_kv_cache(): - from vllm.attention import Attention - - ctx = { - 'layers.0.self_attn': Attention(32, 128, 0.1), - 'layers.1.self_attn': Attention(32, 128, 0.1), - 'layers.2.self_attn': Attention(32, 128, 0.1), - 'layers.3.self_attn': Attention(32, 128, 0.1), - } - kv_cache = { - 'layers.0.self_attn': torch.zeros((1, )), - 'layers.1.self_attn': torch.zeros((1, )), - 'layers.2.self_attn': torch.zeros((1, )), - 'layers.3.self_attn': torch.zeros((1, )), - } - runner_kv_caches: list[torch.Tensor] = [] - bind_kv_cache(kv_cache, ctx, runner_kv_caches) - assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[ - 'layers.0.self_attn'] - assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[ - 'layers.1.self_attn'] - assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[ - 'layers.2.self_attn'] - assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[ - 'layers.3.self_attn'] - - assert runner_kv_caches[0] is kv_cache['layers.0.self_attn'] - assert runner_kv_caches[1] is kv_cache['layers.1.self_attn'] - assert runner_kv_caches[2] is kv_cache['layers.2.self_attn'] - assert runner_kv_caches[3] is kv_cache['layers.3.self_attn'] - - -def test_bind_kv_cache_non_attention(): - from vllm.attention import Attention - - # example from Jamba PP=2 - ctx = { - 'model.layers.20.attn': Attention(32, 128, 0.1), - 'model.layers.28.attn': Attention(32, 128, 0.1), - } - kv_cache = { - 'model.layers.20.attn': torch.zeros((1, )), - 'model.layers.28.attn': torch.zeros((1, )), - } - - runner_kv_caches: list[torch.Tensor] = [] - bind_kv_cache(kv_cache, ctx, runner_kv_caches) - - assert ctx['model.layers.20.attn'].kv_cache[0] is kv_cache[ - 'model.layers.20.attn'] - assert ctx['model.layers.28.attn'].kv_cache[0] is 
kv_cache[ - 'model.layers.28.attn'] - - assert runner_kv_caches[0] is kv_cache['model.layers.20.attn'] - assert runner_kv_caches[1] is kv_cache['model.layers.28.attn'] - # Prometheus metrics utilities for testing diff --git a/tests/v1/worker/test_utils.py b/tests/v1/worker/test_utils.py new file mode 100644 index 0000000000000..fd0e630ce178a --- /dev/null +++ b/tests/v1/worker/test_utils.py @@ -0,0 +1,63 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +from vllm.v1.worker.utils import bind_kv_cache + + +def test_bind_kv_cache(): + from vllm.attention import Attention + + ctx = { + 'layers.0.self_attn': Attention(32, 128, 0.1), + 'layers.1.self_attn': Attention(32, 128, 0.1), + 'layers.2.self_attn': Attention(32, 128, 0.1), + 'layers.3.self_attn': Attention(32, 128, 0.1), + } + kv_cache = { + 'layers.0.self_attn': torch.zeros((1, )), + 'layers.1.self_attn': torch.zeros((1, )), + 'layers.2.self_attn': torch.zeros((1, )), + 'layers.3.self_attn': torch.zeros((1, )), + } + runner_kv_caches: list[torch.Tensor] = [] + bind_kv_cache(kv_cache, ctx, runner_kv_caches) + assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[ + 'layers.0.self_attn'] + assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[ + 'layers.1.self_attn'] + assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[ + 'layers.2.self_attn'] + assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[ + 'layers.3.self_attn'] + + assert runner_kv_caches[0] is kv_cache['layers.0.self_attn'] + assert runner_kv_caches[1] is kv_cache['layers.1.self_attn'] + assert runner_kv_caches[2] is kv_cache['layers.2.self_attn'] + assert runner_kv_caches[3] is kv_cache['layers.3.self_attn'] + + +def test_bind_kv_cache_non_attention(): + from vllm.attention import Attention + + # example from Jamba PP=2 + ctx = { + 'model.layers.20.attn': Attention(32, 128, 0.1), + 'model.layers.28.attn': Attention(32, 128, 0.1), + } + kv_cache = { + 
'model.layers.20.attn': torch.zeros((1, )), + 'model.layers.28.attn': torch.zeros((1, )), + } + + runner_kv_caches: list[torch.Tensor] = [] + bind_kv_cache(kv_cache, ctx, runner_kv_caches) + + assert ctx['model.layers.20.attn'].kv_cache[0] is kv_cache[ + 'model.layers.20.attn'] + assert ctx['model.layers.28.attn'].kv_cache[0] is kv_cache[ + 'model.layers.28.attn'] + + assert runner_kv_caches[0] is kv_cache['model.layers.20.attn'] + assert runner_kv_caches[1] is kv_cache['model.layers.28.attn'] From 39391520698e8b1d699ea2ccec571a6c6416ba9d Mon Sep 17 00:00:00 2001 From: 22quinn <33176974+22quinn@users.noreply.github.com> Date: Sat, 27 Sep 2025 00:47:29 -0700 Subject: [PATCH 12/20] [Misc] Fix codeowners override for v1 sample and attention (#25037) Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com> --- .github/CODEOWNERS | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 9d749fe8d3238..0b9c054b968aa 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -12,8 +12,6 @@ /vllm/model_executor/layers/mamba @tdoublep /vllm/model_executor/model_loader @22quinn /vllm/multimodal @DarkLight1337 @ywang96 @NickLucche -/vllm/v1/attention @LucasWilkinson -/vllm/v1/sample @22quinn @houseroad /vllm/vllm_flash_attn @LucasWilkinson /vllm/lora @jeejeelee /vllm/reasoning @aarnphm @chaunceyjiang @@ -28,11 +26,13 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson # vLLM V1 /vllm/v1 @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat -/vllm/v1/structured_output @mgoin @russellb @aarnphm @benchislett -/vllm/v1/spec_decode @benchislett @luccafong +/vllm/v1/attention @LucasWilkinson /vllm/v1/attention/backends/flashinfer.py @mgoin /vllm/v1/attention/backends/triton_attn.py @tdoublep /vllm/v1/core @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat @heheda12345 @ApostaC +/vllm/v1/sample @22quinn @houseroad @njhill +/vllm/v1/spec_decode @benchislett @luccafong 
+/vllm/v1/structured_output @mgoin @russellb @aarnphm @benchislett /vllm/v1/kv_cache_interface.py @heheda12345 /vllm/v1/offloading @ApostaC @@ -54,7 +54,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson /tests/weight_loading @mgoin @youkaichao @yewentao256 /tests/lora @jeejeelee /tests/models/language/generation/test_hybrid.py @tdoublep -/tests/v1/kv_connector/nixl_integration @NickLucche +/tests/v1/kv_connector/nixl_integration @NickLucche /tests/v1/kv_connector @ApostaC /tests/v1/offloading @ApostaC From 23b8ee672d7ce4c383ed1527a7f268c0ca33c16c Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Sat, 27 Sep 2025 00:57:07 -0700 Subject: [PATCH 13/20] [Misc] Update openai client example file for multimodal (#25795) Signed-off-by: Roger Wang Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- ...i_chat_completion_client_for_multimodal.py | 63 +++++++++++-------- 1 file changed, 36 insertions(+), 27 deletions(-) diff --git a/examples/online_serving/openai_chat_completion_client_for_multimodal.py b/examples/online_serving/openai_chat_completion_client_for_multimodal.py index 37216a5cfe574..5d515fbfb6716 100644 --- a/examples/online_serving/openai_chat_completion_client_for_multimodal.py +++ b/examples/online_serving/openai_chat_completion_client_for_multimodal.py @@ -38,11 +38,13 @@ client = OpenAI( base_url=openai_api_base, ) +headers = {"User-Agent": "vLLM Example Client"} + def encode_base64_content_from_url(content_url: str) -> str: """Encode a content retrieved from a remote url to base64 format.""" - with requests.get(content_url) as response: + with requests.get(content_url, headers=headers) as response: response.raise_for_status() result = base64.b64encode(response.content).decode("utf-8") @@ -50,19 +52,19 @@ def encode_base64_content_from_url(content_url: str) -> str: # Text-only inference -def run_text_only(model: str) -> None: +def run_text_only(model: str, max_completion_tokens: int) -> None: chat_completion 
= client.chat.completions.create( messages=[{"role": "user", "content": "What's the capital of France?"}], model=model, - max_completion_tokens=64, + max_completion_tokens=max_completion_tokens, ) result = chat_completion.choices[0].message.content - print("Chat completion output:", result) + print("Chat completion output:\n", result) # Single-image input inference -def run_single_image(model: str) -> None: +def run_single_image(model: str, max_completion_tokens: int) -> None: ## Use image url in the payload image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" chat_completion_from_url = client.chat.completions.create( @@ -79,11 +81,11 @@ def run_single_image(model: str) -> None: } ], model=model, - max_completion_tokens=64, + max_completion_tokens=max_completion_tokens, ) result = chat_completion_from_url.choices[0].message.content - print("Chat completion output from image url:", result) + print("Chat completion output from image url:\n", result) ## Use base64 encoded image in the payload image_base64 = encode_base64_content_from_url(image_url) @@ -101,7 +103,7 @@ def run_single_image(model: str) -> None: } ], model=model, - max_completion_tokens=64, + max_completion_tokens=max_completion_tokens, ) result = chat_completion_from_base64.choices[0].message.content @@ -109,7 +111,7 @@ def run_single_image(model: str) -> None: # Multi-image input inference -def run_multi_image(model: str) -> None: +def run_multi_image(model: str, max_completion_tokens: int) -> None: image_url_duck = "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg" image_url_lion = "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg" chat_completion_from_url = client.chat.completions.create( @@ -130,15 +132,15 @@ def run_multi_image(model: str) 
-> None: } ], model=model, - max_completion_tokens=64, + max_completion_tokens=max_completion_tokens, ) result = chat_completion_from_url.choices[0].message.content - print("Chat completion output:", result) + print("Chat completion output:\n", result) # Video input inference -def run_video(model: str) -> None: +def run_video(model: str, max_completion_tokens: int) -> None: video_url = "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/ForBiggerFun.mp4" video_base64 = encode_base64_content_from_url(video_url) @@ -157,11 +159,11 @@ def run_video(model: str) -> None: } ], model=model, - max_completion_tokens=64, + max_completion_tokens=max_completion_tokens, ) result = chat_completion_from_url.choices[0].message.content - print("Chat completion output from image url:", result) + print("Chat completion output from video url:\n", result) ## Use base64 encoded video in the payload chat_completion_from_base64 = client.chat.completions.create( @@ -178,15 +180,15 @@ def run_video(model: str) -> None: } ], model=model, - max_completion_tokens=64, + max_completion_tokens=max_completion_tokens, ) result = chat_completion_from_base64.choices[0].message.content - print("Chat completion output from base64 encoded image:", result) + print("Chat completion output from base64 encoded video:\n", result) # Audio input inference -def run_audio(model: str) -> None: +def run_audio(model: str, max_completion_tokens: int) -> None: from vllm.assets.audio import AudioAsset audio_url = AudioAsset("winning_call").url @@ -211,11 +213,11 @@ def run_audio(model: str) -> None: } ], model=model, - max_completion_tokens=64, + max_completion_tokens=max_completion_tokens, ) result = chat_completion_from_base64.choices[0].message.content - print("Chat completion output from input audio:", result) + print("Chat completion output from input audio:\n", result) # HTTP URL chat_completion_from_url = client.chat.completions.create( @@ -235,11 +237,11 @@ def run_audio(model: str) -> None: } ], 
model=model, - max_completion_tokens=64, + max_completion_tokens=max_completion_tokens, ) result = chat_completion_from_url.choices[0].message.content - print("Chat completion output from audio url:", result) + print("Chat completion output from audio url:\n", result) # base64 URL chat_completion_from_base64 = client.chat.completions.create( @@ -259,14 +261,14 @@ def run_audio(model: str) -> None: } ], model=model, - max_completion_tokens=64, + max_completion_tokens=max_completion_tokens, ) result = chat_completion_from_base64.choices[0].message.content - print("Chat completion output from base64 encoded audio:", result) + print("Chat completion output from base64 encoded audio:\n", result) -def run_multi_audio(model: str) -> None: +def run_multi_audio(model: str, max_completion_tokens: int) -> None: from vllm.assets.audio import AudioAsset # Two different audios to showcase batched inference. @@ -300,11 +302,11 @@ def run_multi_audio(model: str) -> None: } ], model=model, - max_completion_tokens=64, + max_completion_tokens=max_completion_tokens, ) result = chat_completion_from_base64.choices[0].message.content - print("Chat completion output from input audio:", result) + print("Chat completion output from input audio:\n", result) example_function_map = { @@ -330,13 +332,20 @@ def parse_args(): choices=list(example_function_map.keys()), help="Conversation type with multimodal data.", ) + parser.add_argument( + "--max-completion-tokens", + "-n", + type=int, + default=128, + help="Maximum number of tokens to generate for each completion.", + ) return parser.parse_args() def main(args) -> None: chat_type = args.chat_type model = get_first_model(client) - example_function_map[chat_type](model) + example_function_map[chat_type](model, args.max_completion_tokens) if __name__ == "__main__": From 176173989a4c5d9c3a4dca8c788d3492ac27a2e0 Mon Sep 17 00:00:00 2001 From: Xiaohan Zou Date: Sat, 27 Sep 2025 03:59:22 -0400 Subject: [PATCH 14/20] [Bugfix] Add missing `image_size` 
for phi4_multimodal (#25796) --- vllm/model_executor/models/phi4_multimodal.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/models/phi4_multimodal.py b/vllm/model_executor/models/phi4_multimodal.py index d2a3a8cc04969..bdc831354c11f 100644 --- a/vllm/model_executor/models/phi4_multimodal.py +++ b/vllm/model_executor/models/phi4_multimodal.py @@ -786,6 +786,7 @@ class Phi4MMProcessingInfo(BaseProcessingInfo): target_ratios, orig_width, orig_height, + image_size, ) # calculate the target width and height From 27d7638b9476062931a6770ed90714792e77cc83 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 27 Sep 2025 16:15:12 +0800 Subject: [PATCH 15/20] [Bugfix] Merge MM embeddings by index instead of token IDs (#16229) Signed-off-by: DarkLight1337 Signed-off-by: NickLucche Signed-off-by: Roger Wang Co-authored-by: NickLucche Co-authored-by: Roger Wang --- docs/contributing/model/multimodal.md | 33 +----- vllm/config/model.py | 7 +- vllm/model_executor/models/aria.py | 25 ++--- vllm/model_executor/models/aya_vision.py | 27 ++--- vllm/model_executor/models/bert.py | 12 +++ vllm/model_executor/models/bert_with_rope.py | 6 ++ vllm/model_executor/models/blip2.py | 22 ++-- vllm/model_executor/models/chameleon.py | 24 ++--- vllm/model_executor/models/chatglm.py | 3 + vllm/model_executor/models/cohere2_vision.py | 27 ++--- vllm/model_executor/models/deepseek_eagle.py | 6 ++ vllm/model_executor/models/deepseek_mtp.py | 6 ++ vllm/model_executor/models/deepseek_vl2.py | 25 ++--- vllm/model_executor/models/dots_ocr.py | 44 +++----- vllm/model_executor/models/ernie45_vl.py | 27 +++-- vllm/model_executor/models/ernie_mtp.py | 6 ++ vllm/model_executor/models/fuyu.py | 26 ++--- vllm/model_executor/models/gemma3_mm.py | 26 ++--- vllm/model_executor/models/gemma3n_mm.py | 23 ++-- vllm/model_executor/models/glm4_1v.py | 17 --- vllm/model_executor/models/glm4_moe_mtp.py | 6 ++ vllm/model_executor/models/glm4v.py | 35 ++---- 
vllm/model_executor/models/granite_speech.py | 34 +++--- vllm/model_executor/models/hunyuan_v1.py | 3 + .../models/hyperclovax_vision.py | 35 ++---- vllm/model_executor/models/idefics3.py | 31 ++---- vllm/model_executor/models/interfaces.py | 83 ++++++++++++-- vllm/model_executor/models/interfaces_base.py | 24 ++++- vllm/model_executor/models/interns1.py | 47 ++++---- vllm/model_executor/models/internvl.py | 46 ++++---- vllm/model_executor/models/keye.py | 18 ---- vllm/model_executor/models/kimi_vl.py | 29 +---- vllm/model_executor/models/lfm2.py | 3 + vllm/model_executor/models/llama4_eagle.py | 31 ++---- vllm/model_executor/models/llama_eagle.py | 6 ++ vllm/model_executor/models/llama_eagle3.py | 17 +-- vllm/model_executor/models/llava.py | 26 ++--- vllm/model_executor/models/llava_next.py | 31 +++--- .../model_executor/models/llava_next_video.py | 23 ++-- vllm/model_executor/models/llava_onevision.py | 13 --- vllm/model_executor/models/midashenglm.py | 25 ++--- vllm/model_executor/models/mimo_mtp.py | 6 ++ vllm/model_executor/models/minicpmv.py | 27 ++--- vllm/model_executor/models/minimax_text_01.py | 10 +- vllm/model_executor/models/minimax_vl_01.py | 25 ++--- vllm/model_executor/models/mistral3.py | 26 ++--- vllm/model_executor/models/mllama4.py | 30 ++---- vllm/model_executor/models/modernbert.py | 9 ++ vllm/model_executor/models/molmo.py | 32 ++---- .../model_executor/models/nano_nemotron_vl.py | 43 +++----- vllm/model_executor/models/nemotron_vl.py | 37 ++++--- vllm/model_executor/models/olmo2.py | 6 ++ vllm/model_executor/models/ovis.py | 21 +--- vllm/model_executor/models/ovis2_5.py | 18 +--- vllm/model_executor/models/paligemma.py | 23 ++-- vllm/model_executor/models/phi3v.py | 44 +++++--- vllm/model_executor/models/phi4_multimodal.py | 18 +--- vllm/model_executor/models/phi4mm.py | 14 --- vllm/model_executor/models/pixtral.py | 26 ++--- .../models/qwen2_5_omni_thinker.py | 26 ++--- vllm/model_executor/models/qwen2_5_vl.py | 13 --- 
vllm/model_executor/models/qwen2_audio.py | 23 ++-- vllm/model_executor/models/qwen2_vl.py | 13 --- vllm/model_executor/models/qwen3_vl.py | 89 +++++++++------ vllm/model_executor/models/qwen_vl.py | 25 ++--- vllm/model_executor/models/roberta.py | 3 + vllm/model_executor/models/skyworkr1v.py | 36 ++++--- vllm/model_executor/models/solar.py | 3 + vllm/model_executor/models/step3_text.py | 3 + vllm/model_executor/models/step3_vl.py | 52 ++++----- vllm/model_executor/models/tarsier.py | 25 ++--- vllm/model_executor/models/terratorch.py | 3 + vllm/model_executor/models/transformers.py | 62 ++++++++--- vllm/model_executor/models/ultravox.py | 36 ++++--- vllm/model_executor/models/utils.py | 102 +++++++----------- vllm/model_executor/models/voxtral.py | 29 ++--- vllm/model_executor/models/whisper.py | 10 +- vllm/v1/spec_decode/eagle.py | 46 ++++---- vllm/v1/worker/gpu_model_runner.py | 52 ++++++--- vllm/v1/worker/tpu_model_runner.py | 81 +++++++++++--- 80 files changed, 966 insertions(+), 1139 deletions(-) diff --git a/docs/contributing/model/multimodal.md b/docs/contributing/model/multimodal.md index 87d34d207cde3..1d72fe97b9665 100644 --- a/docs/contributing/model/multimodal.md +++ b/docs/contributing/model/multimodal.md @@ -66,35 +66,12 @@ Further update the model as follows: !!! important The returned `multimodal_embeddings` must be either a **3D [torch.Tensor][]** of shape `(num_items, feature_size, hidden_size)`, or a **list / tuple of 2D [torch.Tensor][]'s** of shape `(feature_size, hidden_size)`, so that `multimodal_embeddings[i]` retrieves the embeddings generated from the `i`-th multimodal data item (e.g, image) of the request. -- Implement [get_input_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_input_embeddings] to merge `multimodal_embeddings` with text embeddings from the `input_ids`. 
If input processing for the model is implemented correctly (see sections below), then you can leverage the utility function we provide to easily merge the embeddings. +!!! note + By default, vLLM merges the multimodal embeddings into text embeddings depending on the information of their locations defined in + [PlaceholderRange][vllm.multimodal.inputs.PlaceholderRange] from input processing. + This logic can be found at [get_input_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_input_embeddings]. - ??? code - - ```python - from .utils import merge_multimodal_embeddings - - class YourModelForImage2Seq(nn.Module): - ... - - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - - # `get_input_embeddings` should already be implemented for the language - # model as one of the requirements of basic vLLM model implementation. - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - - if multimodal_embeddings is not None: - inputs_embeds = merge_multimodal_embeddings( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - multimodal_embeddings=multimodal_embeddings, - placeholder_token_id=self.config.image_token_index) - - return inputs_embeds - ``` + You may override this method if additional logic is required for your model when merging embeddings. - Implement [get_language_model][vllm.model_executor.models.interfaces.SupportsMultiModal.get_language_model] getter to provide stable access to the underlying language model. 
diff --git a/vllm/config/model.py b/vllm/config/model.py index b2b68abd2c1d3..3fb448ebbf364 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -509,9 +509,14 @@ class ModelConfig: else: # task == "auto" pass else: + debug_info = { + "architectures": architectures, + "is_generative_model": is_generative_model, + "is_pooling_model": is_pooling_model, + } raise AssertionError("The model should be a generative or " "pooling model when task is set to " - f"{self.task!r}.") + f"{self.task!r}. Found: {debug_info}") self.runner = runner self.convert = convert diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index 35c1adbdd00b6..6cef5e134a4bc 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -38,8 +38,7 @@ from .idefics2_vision_model import ( from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsQuant from .llama import LlamaDecoderLayer, LlamaMLP, LlamaModel from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - is_pp_missing_parameter, maybe_prefix, - merge_multimodal_embeddings) + is_pp_missing_parameter, maybe_prefix) class AriaImagePixelInputs(TensorSchema): @@ -605,19 +604,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): multimodal_embeddings = self._process_image_input(image_input) return multimodal_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.config.image_token_index) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -628,10 +614,11 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): ) -> Union[torch.Tensor, 
IntermediateTensors]: if inputs_embeds is None: multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) - # always pass the input via `inputs_embeds` - # to make sure the computation graph is consistent - inputs_embeds = self.get_input_embeddings(input_ids, - multimodal_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + multimodal_embeddings, + is_multimodal=input_ids == self.config.image_token_index, + ) input_ids = None hidden_states = self.language_model( diff --git a/vllm/model_executor/models/aya_vision.py b/vllm/model_executor/models/aya_vision.py index 6fd8c2fb5c561..eab996e9ba22b 100644 --- a/vllm/model_executor/models/aya_vision.py +++ b/vllm/model_executor/models/aya_vision.py @@ -33,8 +33,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .siglip import SiglipVisionModel from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) + init_vllm_registered_model, maybe_prefix) class AyaVisionImagePixelInputs(TensorSchema): @@ -417,23 +416,6 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, return self._process_image_input(image_input, **kwargs) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - multimodal_embeddings=multimodal_embeddings, - placeholder_token_id=self.config.image_token_index, - ) - - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -449,8 +431,11 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, # condition is for v0 
compatibility. elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.config.image_token_index, + ) input_ids = None hidden_states = self.language_model.model( diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index ee32587f6b1b4..c984845204c4f 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -348,6 +348,9 @@ class BertModel(nn.Module, SupportsQuant): self.encoder = BertEncoder(vllm_config=vllm_config, prefix=f"{prefix}.encoder") + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -457,6 +460,9 @@ class BertEmbeddingModel(nn.Module, SupportsQuant): prefix=maybe_prefix(prefix, "model")) self.pooler = self._build_pooler(pooler_config) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -588,6 +594,9 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, ), }) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.bert.get_input_embeddings(input_ids) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): loader = AutoWeightsLoader(self) loaded_params = loader.load_weights(weights) @@ -637,6 +646,9 @@ class BertForTokenClassification(nn.Module): Pooler.for_encode(pooler_config), }) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.bert.get_input_embeddings(input_ids) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): loader = AutoWeightsLoader(self) loaded_params = loader.load_weights(weights) diff --git 
a/vllm/model_executor/models/bert_with_rope.py b/vllm/model_executor/models/bert_with_rope.py index bfc1408ddf880..4e1eba32d2594 100644 --- a/vllm/model_executor/models/bert_with_rope.py +++ b/vllm/model_executor/models/bert_with_rope.py @@ -426,6 +426,9 @@ class BertWithRope(nn.Module, SupportsQuant): prefix=f"{prefix}.encoder") self.pooler = BertPooler(self.config) if add_pooling_layer else None + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -673,6 +676,9 @@ class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding): loaded_params = loader.load_weights(weights) return loaded_params + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.new.get_input_embeddings(input_ids) + def forward( self, input_ids: Optional[torch.Tensor], diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index b7455fba62c02..4d1850d07b28e 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -27,7 +27,7 @@ from .blip import BlipVisionModel from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP, SupportsQuant) from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, - maybe_prefix, merge_multimodal_embeddings) + maybe_prefix) # We use this internally as placeholders since there is no image token # defined on the HuggingFace repo @@ -631,19 +631,6 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, vision_embeddings = self._process_image_input(image_input) return vision_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds 
= merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - _IMAGE_TOKEN_ID) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -689,8 +676,11 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, # condition is for v0 compatibility. elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == _IMAGE_TOKEN_ID, + ) input_ids = None hidden_states = self.language_model.model(input_ids, diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 79d648d749c6a..f9740adb151b5 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -44,7 +44,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP, SupportsQuant) from .utils import (flatten_bn, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix, merge_multimodal_embeddings) + maybe_prefix) logger = init_logger(__name__) @@ -1002,20 +1002,6 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, vision_embeddings = self.model.get_input_embeddings(image_tokens) return vision_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - - inputs_embeds = self.model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.model.vocabulary_mapping.image_token_id) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -1032,8 +1018,12 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, 
# condition is for v0 compatibility. elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + image_token_id = self.model.vocabulary_mapping.image_token_id + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == image_token_id, + ) input_ids = None hidden_states = self.model(input_ids, diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 879508400222f..c182201fe2567 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -433,6 +433,9 @@ class ChatGLMBaseModel(nn.Module): self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.get_input_embeddings(input_ids) + def compute_logits( self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/cohere2_vision.py b/vllm/model_executor/models/cohere2_vision.py index 6d67eb68d51a8..99edcba4d874a 100644 --- a/vllm/model_executor/models/cohere2_vision.py +++ b/vllm/model_executor/models/cohere2_vision.py @@ -37,8 +37,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .siglip import SiglipVisionModel from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) + init_vllm_registered_model, maybe_prefix) class Cohere2VisionImagePixelInputs(TensorSchema): @@ -430,23 +429,6 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal, return self._process_image_input(image_input, **kwargs) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - 
inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - multimodal_embeddings=multimodal_embeddings, - placeholder_token_id=self.config.image_token_id, - ) - - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -462,8 +444,11 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal, # condition is for v0 compatibility. elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.config.image_token_id, + ) input_ids = None hidden_states = self.language_model.model( diff --git a/vllm/model_executor/models/deepseek_eagle.py b/vllm/model_executor/models/deepseek_eagle.py index ed7e7614800fc..c42a66d86912c 100644 --- a/vllm/model_executor/models/deepseek_eagle.py +++ b/vllm/model_executor/models/deepseek_eagle.py @@ -66,6 +66,9 @@ class DeepseekV2Model(nn.Module): self.norm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -205,6 +208,9 @@ class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM): self.logits_processor = LogitsProcessor(self.config.vocab_size, scale=logit_scale) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py index 92f311ab465b5..a4623ff13cec4 100644 --- a/vllm/model_executor/models/deepseek_mtp.py +++ 
b/vllm/model_executor/models/deepseek_mtp.py @@ -101,6 +101,9 @@ class DeepSeekMultiTokenPredictor(nn.Module): ) self.logits_processor = LogitsProcessor(config.vocab_size) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -142,6 +145,9 @@ class DeepSeekMTP(nn.Module, SupportsPP): prefix=maybe_prefix( prefix, "model")) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py index c8ed759d2e972..b98008c83bdcc 100644 --- a/vllm/model_executor/models/deepseek_vl2.py +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -41,8 +41,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) + init_vllm_registered_model, maybe_prefix) # The image token id may be various _IMAGE_TOKEN = "" @@ -346,7 +345,7 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): model_config = vllm_config.model_config tokenizer = cached_tokenizer_from_config(model_config) - self.image_token_id = tokenizer.vocab[_IMAGE_TOKEN] + self.image_token_id: int = tokenizer.vocab[_IMAGE_TOKEN] self.vision = self._init_vision_module(self.vision_config, quant_config, @@ -605,19 +604,6 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): vision_embeddings = self._process_image_input(image_input) return vision_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = 
self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.image_token_id) - return inputs_embeds - def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, @@ -632,8 +618,11 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): # condition is for v0 compatibility elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.image_token_id, + ) input_ids = None hidden_states = self.language_model(input_ids, diff --git a/vllm/model_executor/models/dots_ocr.py b/vllm/model_executor/models/dots_ocr.py index 2db350c892ae7..4845f19bcbc42 100644 --- a/vllm/model_executor/models/dots_ocr.py +++ b/vllm/model_executor/models/dots_ocr.py @@ -34,8 +34,7 @@ from vllm.model_executor.models.qwen2_vl import (Qwen2VLDummyInputsBuilder, Qwen2VLProcessingInfo) from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, init_vllm_registered_model, - maybe_prefix, - merge_multimodal_embeddings) + maybe_prefix) from vllm.model_executor.models.vision import get_vit_attn_backend from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalDataDict @@ -796,33 +795,17 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings( - self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + def get_multimodal_embeddings(self, + **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] vision_embeddings = 
self._process_image_input(image_input) return vision_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - self.config.image_token_id, - ) - - return inputs_embeds - def forward( self, - input_ids: Optional[torch.Tensor], + input_ids: torch.Tensor, positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, @@ -830,17 +813,14 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, ) -> Union[torch.Tensor, IntermediateTensors]: if intermediate_tensors is not None: inputs_embeds = None - elif inputs_embeds is None and kwargs.get("pixel_values") is not None: - image_input = self._parse_and_validate_image_input(**kwargs) - if image_input is None: - inputs_embeds = None - else: - assert input_ids is not None - inputs_embeds = self.get_multimodal_embeddings( - input_ids, - image_input=image_input, - ) - input_ids = None + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.config.image_token_id, + ) + input_ids = None hidden_states = self.language_model( input_ids=input_ids, diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py index 74b358034ef3d..a73ec4f88ffe4 100644 --- a/vllm/model_executor/models/ernie45_vl.py +++ b/vllm/model_executor/models/ernie45_vl.py @@ -60,8 +60,7 @@ from vllm.sequence import IntermediateTensors from .ernie45_vl_moe import Ernie4_5_VLMoeForCausalLM from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) -from .utils 
import (AutoWeightsLoader, WeightsMapper, maybe_prefix, - merge_multimodal_embeddings) +from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix from .vision import get_vit_attn_backend logger = init_logger(__name__) @@ -1467,18 +1466,24 @@ class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal, self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, ) -> torch.Tensor: + if multimodal_embeddings is not None and len( + multimodal_embeddings) > 0: + self._set_visual_token_mask(input_ids) - inputs_embeds = self.language_model.get_input_embeddings(input_ids) + # This is to satisfy the type checker for each overload + if multimodal_embeddings is None or is_multimodal is None: + return super().get_input_embeddings(input_ids) - if multimodal_embeddings is None: - return inputs_embeds - - self._set_visual_token_mask(input_ids) - inputs_embeds = merge_multimodal_embeddings(input_ids, inputs_embeds, - multimodal_embeddings, - [self.config.im_patch_id]) - return inputs_embeds + return super().get_input_embeddings( + input_ids, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) def forward( self, diff --git a/vllm/model_executor/models/ernie_mtp.py b/vllm/model_executor/models/ernie_mtp.py index 288fbe736c32f..3b24bf2f1ef8f 100644 --- a/vllm/model_executor/models/ernie_mtp.py +++ b/vllm/model_executor/models/ernie_mtp.py @@ -116,6 +116,9 @@ class ErnieMultiTokenPredictor(nn.Module): ) self.logits_processor = LogitsProcessor(config.vocab_size) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -160,6 +163,9 @@ class ErnieMTP(nn.Module, SupportsPP): if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight + def 
get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 53e9e6fe6e460..b99fe33a1dcce 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -42,8 +42,7 @@ from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP -from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix, - merge_multimodal_embeddings) +from .utils import AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix # Cannot find the following 2 numbers from hf config. _IMAGE_TOKEN_ID = 71011 @@ -342,22 +341,6 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): return self._process_image_input(image_input) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - _IMAGE_TOKEN_ID, - ) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -373,8 +356,11 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): # condition is for v0 compatibility. 
elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == _IMAGE_TOKEN_ID, + ) input_ids = None hidden_states = self.language_model( diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index 0630ee07c347e..be75e36fe23b5 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -37,8 +37,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) from .siglip import SiglipVisionModel from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) + init_vllm_registered_model, maybe_prefix) logger = init_logger(__name__) @@ -588,22 +587,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, return self._process_image_input(image_input) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - self.config.image_token_index, - ) - return inputs_embeds - def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, @@ -618,8 +601,11 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == 
self.config.image_token_index, + ) if (vision_embeddings is not None) and len(vision_embeddings) != 0: kwargs = self.prepare_attn_masks( input_ids, diff --git a/vllm/model_executor/models/gemma3n_mm.py b/vllm/model_executor/models/gemma3n_mm.py index 2acdba54a257d..b23437a08e5ab 100644 --- a/vllm/model_executor/models/gemma3n_mm.py +++ b/vllm/model_executor/models/gemma3n_mm.py @@ -632,8 +632,10 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal, self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) # NOTE (NickLucche) Each pass needs tokens to compute PLE so we cache # them here, as the model forward has only access to the input_embeds. if input_ids is not None: @@ -645,15 +647,16 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal, self.per_layer_embeddings[:per_layer_inputs.shape[0]].copy_( per_layer_inputs) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - # NOTE: this order of processing mm items is important - [self.config.image_token_id, self.config.audio_token_id]) - return inputs_embeds + # This is to satisfy the type checker for each overload + if multimodal_embeddings is None or is_multimodal is None: + return super().get_input_embeddings(input_ids) + + return super().get_input_embeddings( + input_ids, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) def forward(self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index b088e0c0dd241..dbb5431ae4919 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ 
b/vllm/model_executor/models/glm4_1v.py @@ -1552,23 +1552,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal, multimodal_embeddings += video_embeddings return multimodal_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if (multimodal_embeddings is not None - and len(multimodal_embeddings) != 0 - and all(embed.numel() > 0 for embed in multimodal_embeddings)): - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - [self.config.image_token_id, self.config.video_token_id], - ) - return inputs_embeds - def get_input_embeddings_v0( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/glm4_moe_mtp.py b/vllm/model_executor/models/glm4_moe_mtp.py index c572978e62206..826d541e571bd 100644 --- a/vllm/model_executor/models/glm4_moe_mtp.py +++ b/vllm/model_executor/models/glm4_moe_mtp.py @@ -132,6 +132,9 @@ class Glm4MoeMultiTokenPredictor(nn.Module): ) self.logits_processor = LogitsProcessor(config.vocab_size) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -173,6 +176,9 @@ class Glm4MoeMTP(nn.Module, SupportsPP): prefix=maybe_prefix( prefix, "model")) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/glm4v.py b/vllm/model_executor/models/glm4v.py index bf33575859aea..ace9c05daf15a 100644 --- a/vllm/model_executor/models/glm4v.py +++ b/vllm/model_executor/models/glm4v.py @@ -43,7 +43,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape from .chatglm import ChatGLMBaseModel, ChatGLMModel from .interfaces import 
(MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) -from .utils import flatten_bn, merge_multimodal_embeddings +from .utils import flatten_bn, isin_list class GLMVImagePixelInputs(TensorSchema): @@ -607,28 +607,6 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, vision_embeddings = self._process_image_input(image_input) return vision_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.transformer.get_input_embeddings(input_ids) - - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - multimodal_embeddings=multimodal_embeddings, - placeholder_token_id=[ - self.config.boi_token_id, - self.config.pad_token_id, - self.config.eoi_token_id, - ], - ) - - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -644,8 +622,15 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, # condition is for v0 compatibility. 
elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=isin_list(input_ids, [ + self.config.boi_token_id, + self.config.pad_token_id, + self.config.eoi_token_id, + ]), + ) input_ids = None hidden_states = self.transformer(input_ids, positions, diff --git a/vllm/model_executor/models/granite_speech.py b/vllm/model_executor/models/granite_speech.py index a5849184339b1..8a02da58ea0b9 100644 --- a/vllm/model_executor/models/granite_speech.py +++ b/vllm/model_executor/models/granite_speech.py @@ -52,8 +52,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape from .blip2 import Blip2QFormerModel from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) -from .utils import (AutoWeightsLoader, embed_multimodal, - init_vllm_registered_model, maybe_prefix) +from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix ### Audio Input @@ -720,6 +719,9 @@ class GraniteSpeechForConditionalGeneration( # Split variable length features into a tuple return torch.split(masked_embeds, audio_input["audio_embed_sizes"]) + def get_language_model(self) -> torch.nn.Module: + return self.language_model + def get_multimodal_embeddings( self, **kwargs: object, @@ -728,7 +730,7 @@ class GraniteSpeechForConditionalGeneration( audio_input = self._parse_and_validate_audio_input(**kwargs) if audio_input is None: return [] - return None + audio_features = self._process_audio_input(audio_input) return audio_features @@ -736,19 +738,21 @@ class GraniteSpeechForConditionalGeneration( self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + # Multi-modal token ID may exceed vocab size + handle_oov_mm_token: bool = True, ) -> torch.Tensor: - 
"""Compute the merged LLM / audio embeddings.""" - if multimodal_embeddings is None \ - or len(multimodal_embeddings) == 0: - return self.language_model.get_input_embeddings(input_ids) + # This is to satisfy the type checker for each overload + if multimodal_embeddings is None or is_multimodal is None: + return super().get_input_embeddings(input_ids) - inputs_embeds = embed_multimodal( + return super().get_input_embeddings( input_ids, - self.config.audio_token_index, - self.language_model.model.get_input_embeddings, - multimodal_embeddings, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, ) - return inputs_embeds def forward( self, @@ -765,7 +769,11 @@ class GraniteSpeechForConditionalGeneration( # condition is for v0 compatibility. elif inputs_embeds is None: audio_embeds = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, audio_embeds) + inputs_embeds = self.get_input_embeddings( + input_ids, + audio_embeds, + is_multimodal=input_ids == self.config.audio_token_index, + ) input_ids = None model_output = self.language_model(input_ids, positions, diff --git a/vllm/model_executor/models/hunyuan_v1.py b/vllm/model_executor/models/hunyuan_v1.py index 8a23a6b45bc70..d28c971167902 100644 --- a/vllm/model_executor/models/hunyuan_v1.py +++ b/vllm/model_executor/models/hunyuan_v1.py @@ -989,6 +989,9 @@ class HunYuanV1Base(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts): moe.n_redundant_experts = self.num_redundant_experts moe.experts.update_expert_map() + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/hyperclovax_vision.py b/vllm/model_executor/models/hyperclovax_vision.py index 4d39ff9ae79ee..f851688bf7bab 100644 --- a/vllm/model_executor/models/hyperclovax_vision.py +++ 
b/vllm/model_executor/models/hyperclovax_vision.py @@ -45,8 +45,8 @@ from vllm.sequence import IntermediateTensors from .clip import CLIPVisionModel from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .siglip import SiglipVisionModel -from .utils import (AutoWeightsLoader, init_vllm_registered_model, - maybe_prefix, merge_multimodal_embeddings) +from .utils import (AutoWeightsLoader, init_vllm_registered_model, isin_list, + maybe_prefix) from .vision import get_vision_encoder_info EOT = "<|endofturn|>" @@ -691,7 +691,7 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): def get_multimodal_embeddings( self, **kwargs: Unpack[HCXVisionMultimodalInputs], - ) -> Optional[MultiModalEmbeddings]: + ) -> MultiModalEmbeddings: multimodal_embeddings = list() if kwargs.get("pixel_values_images") is not None: @@ -736,26 +736,6 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): multimodal_embeddings.append(_multimodal_embeddings_videos) return multimodal_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - placeholder_token_id=[ - self.config.image_token_id, - self.config.video_token_id, - ], - ) - - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -771,8 +751,13 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): # condition is for v0 compatibility. 
elif inputs_embeds is None: multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - multimodal_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + multimodal_embeddings, + is_multimodal=isin_list( + input_ids, + [self.config.image_token_id, self.config.video_token_id]), + ) input_ids = None hidden_states = self.language_model.model(input_ids, positions, diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 79e130119ae83..3334ee2242531 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -52,8 +52,7 @@ from .idefics2_vision_model import ( # yapf: enable from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal from .llama import LlamaModel -from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix, - merge_multimodal_embeddings) +from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix class Idefics3ImagePixelInputs(TensorSchema): @@ -539,10 +538,7 @@ class Idefics3Model(nn.Module): return image_hidden_states - def get_input_embeddings( - self, - input_ids: torch.Tensor, - ) -> torch.Tensor: + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.text_model.get_input_embeddings(input_ids) def forward( @@ -695,22 +691,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, return self._process_image_input(image_input) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - self.config.image_token_id, - ) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ 
-726,8 +706,11 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, # condition is for v0 compatibility. elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.config.image_token_id, + ) input_ids = None hidden_states = self.model.text_model(input_ids, diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index f13e590cd243b..d40df9b43dd43 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -2,8 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable, Mapping, MutableSequence -from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol, - Union, overload, runtime_checkable) +from typing import (TYPE_CHECKING, Callable, ClassVar, Literal, Optional, + Protocol, Union, overload, runtime_checkable) import numpy as np import torch @@ -20,7 +20,7 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.utils import supports_kw -from .interfaces_base import is_pooling_model +from .interfaces_base import VllmModel, is_pooling_model if TYPE_CHECKING: from vllm.config import VllmConfig @@ -90,7 +90,7 @@ class SupportsMultiModal(Protocol): """ ... - def get_language_model(self) -> torch.nn.Module: + def get_language_model(self) -> VllmModel: """ Returns the underlying language model used for text generation. @@ -102,17 +102,84 @@ class SupportsMultiModal(Protocol): """ ... + @overload + def get_input_embeddings(self, input_ids: Tensor) -> Tensor: + ... 
+ + @overload + def get_input_embeddings( + self, + input_ids: Tensor, + multimodal_embeddings: MultiModalEmbeddings, + *, + is_multimodal: torch.Tensor, + handle_oov_mm_token: bool = False, + ) -> Tensor: + ... + + def _get_text_embeddings( + self, + input_ids: Tensor, + get_input_embeddings: Callable[[Tensor], Tensor], + *, + is_multimodal: Optional[Tensor], + handle_oov_mm_token: bool, + ) -> Tensor: + if handle_oov_mm_token and is_multimodal is not None: + is_text = ~is_multimodal + text_embeds = get_input_embeddings(input_ids[is_text]) + + return torch.empty( + (input_ids.shape[0], text_embeds.shape[1]), + dtype=text_embeds.dtype, + device=text_embeds.device, + ).masked_scatter_(is_text.unsqueeze_(-1), text_embeds) + + return get_input_embeddings(input_ids) + def get_input_embeddings( self, input_ids: Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[Tensor] = None, + handle_oov_mm_token: bool = False, ) -> Tensor: """ - Returns the input embeddings merged from the text embeddings from - input_ids and the multimodal embeddings generated from multimodal - kwargs. + Apply token embeddings to `input_ids`. + + If `multimodal_embeddings` is passed, scatter them into + `input_ids` according to the mask `is_multimodal`. + + In case the multi-modal token IDs exceed the vocabulary size of + the language model, you can set `handle_oov_mm_token=True` + to avoid calling the language model's `get_input_embeddings` method + on those tokens. Note however that doing so increases memory usage + as an additional buffer is needed to hold the input embeddings. """ - ...
+ from .utils import _merge_multimodal_embeddings + + inputs_embeds = self._get_text_embeddings( + input_ids, + self.get_language_model().get_input_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) + + if multimodal_embeddings is None or len(multimodal_embeddings) == 0: + return inputs_embeds + + if is_multimodal is None: + raise ValueError( + "`get_input_embeddings` now requires `is_multimodal` arg, " + "please update your model runner according to " + "https://github.com/vllm-project/vllm/pull/16229.") + + return _merge_multimodal_embeddings( + inputs_embeds=inputs_embeds, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + ) @runtime_checkable diff --git a/vllm/model_executor/models/interfaces_base.py b/vllm/model_executor/models/interfaces_base.py index 8fdf70e35a2b8..84146db0943c6 100644 --- a/vllm/model_executor/models/interfaces_base.py +++ b/vllm/model_executor/models/interfaces_base.py @@ -41,6 +41,13 @@ class VllmModel(Protocol[T_co]): ) -> None: ... + def get_input_embeddings( + self, + input_ids: torch.Tensor, + ) -> torch.Tensor: + """Apply token embeddings to `input_ids`.""" + ... 
+ def forward( self, input_ids: torch.Tensor, @@ -54,6 +61,19 @@ def _check_vllm_model_init(model: Union[type[object], object]) -> bool: return supports_kw(model_init, "vllm_config") +def _check_vllm_model_get_input_embeddings( + model: Union[type[object], object]) -> bool: + model_get_input_embeddings = getattr(model, "get_input_embeddings", None) + if not callable(model_get_input_embeddings): + logger.warning( + "The model (%s) is missing the `get_input_embeddings` method.", + model, + ) + return False + + return True + + def _check_vllm_model_forward(model: Union[type[object], object]) -> bool: model_forward = getattr(model, "forward", None) if not callable(model_forward): @@ -88,7 +108,9 @@ def is_vllm_model(model: object) -> TypeIs[VllmModel]: def is_vllm_model( model: Union[type[object], object], ) -> Union[TypeIs[type[VllmModel]], TypeIs[VllmModel]]: - return _check_vllm_model_init(model) and _check_vllm_model_forward(model) + return (_check_vllm_model_init(model) + and _check_vllm_model_get_input_embeddings(model) + and _check_vllm_model_forward(model)) @runtime_checkable diff --git a/vllm/model_executor/models/interns1.py b/vllm/model_executor/models/interns1.py index 197d629b906fe..545dad1a96f5e 100644 --- a/vllm/model_executor/models/interns1.py +++ b/vllm/model_executor/models/interns1.py @@ -40,8 +40,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) + init_vllm_registered_model, isin_list, maybe_prefix) class InternS1MultiModalProjector(nn.Module): @@ -767,24 +766,24 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal, self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: 
bool = False, ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - context_token_ids = [ - token_id for token_id in (self.img_context_token_id, - self.video_context_token_id) - if token_id is not None - ] - assert len(context_token_ids) >= 1 + if multimodal_embeddings is not None and len( + multimodal_embeddings) > 0: self._set_visual_token_mask(input_ids) - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - context_token_ids, - ) - return inputs_embeds + + # This is to satisfy the type checker for each overload + if multimodal_embeddings is None or is_multimodal is None: + return super().get_input_embeddings(input_ids) + + return super().get_input_embeddings( + input_ids, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) def forward( self, @@ -802,9 +801,17 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal, # NOTE: In v1, inputs_embeds is always generated at model runner, this # condition is for v0 compatibility. 
elif inputs_embeds is None: + context_token_ids = [ + token_id for token_id in (self.img_context_token_id, + self.video_context_token_id) + if token_id is not None + ] vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=isin_list(input_ids, context_token_ids), + ) input_ids = None forward_kwargs = { diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index f4004e518e3ba..78aac85414344 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -43,7 +43,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, - maybe_prefix, merge_multimodal_embeddings) + isin_list, maybe_prefix) IMG_START = '' IMG_END = '' @@ -1339,24 +1339,24 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - context_token_ids = [ - token_id for token_id in (self.img_context_token_id, - self.video_context_token_id) - if token_id is not None - ] - assert len(context_token_ids) >= 1 + if multimodal_embeddings is not None and len( + multimodal_embeddings) > 0: self._set_visual_token_mask(input_ids) - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - context_token_ids, - ) - return inputs_embeds + + # This is to satisfy the type checker for each 
overload + if multimodal_embeddings is None or is_multimodal is None: + return super().get_input_embeddings(input_ids) + + return super().get_input_embeddings( + input_ids, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) def forward( self, @@ -1374,9 +1374,17 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, # NOTE: In v1, inputs_embeds is always generated at model runner, this # condition is for v0 compatibility. elif inputs_embeds is None: + context_token_ids = [ + token_id for token_id in (self.img_context_token_id, + self.video_context_token_id) + if token_id is not None + ] vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=isin_list(input_ids, context_token_ids), + ) input_ids = None forward_kwargs = { diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py index 3b6fdba225122..62a71b7b1fa85 100644 --- a/vllm/model_executor/models/keye.py +++ b/vllm/model_executor/models/keye.py @@ -1450,24 +1450,6 @@ class BaseKeyeModule(nn.Module): multimodal_embeddings += video_embeddings return multimodal_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - [ - self.config.image_token_id, - self.config.video_token_id, - ], - ) - return inputs_embeds - def get_input_embeddings_v0( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/kimi_vl.py b/vllm/model_executor/models/kimi_vl.py index 503627865c4a5..db032736f9148 100644 --- 
a/vllm/model_executor/models/kimi_vl.py +++ b/vllm/model_executor/models/kimi_vl.py @@ -66,7 +66,6 @@ from vllm.model_executor.models.deepseek_v2 import DeepseekV2Model from vllm.model_executor.models.interfaces import (SupportsMultiModal, SupportsPP) from vllm.model_executor.models.moonvit import MoonVitPretrainedModel -from vllm.model_executor.models.utils import merge_multimodal_embeddings from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, NestedTensors) @@ -424,26 +423,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, vision_embeddings = self._process_image_input(image_input) return vision_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[NestedTensors] = None, - ) -> torch.Tensor: - - # `get_input_embeddings` should already be implemented for the language - # model as one of the requirements of basic vLLM model implementation. - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - - if multimodal_embeddings is not None and len( - multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - multimodal_embeddings=multimodal_embeddings, - placeholder_token_id=self.config.media_placeholder_token_id) - - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -462,14 +441,12 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, if image_input is None: inputs_embeds = None else: - inputs_embeds = self.get_input_embeddings(input_ids) image_embeds = self._process_image_input(image_input) - inputs_embeds = merge_multimodal_embeddings( + inputs_embeds = self.get_input_embeddings( input_ids, - inputs_embeds, image_embeds, - placeholder_token_id=self.config. 
- media_placeholder_token_id, + is_multimodal=input_ids == + self.config.media_placeholder_token_id, ) input_ids = None diff --git a/vllm/model_executor/models/lfm2.py b/vllm/model_executor/models/lfm2.py index 53c36e4e52d81..f9def222a1ec9 100644 --- a/vllm/model_executor/models/lfm2.py +++ b/vllm/model_executor/models/lfm2.py @@ -522,6 +522,9 @@ class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/llama4_eagle.py b/vllm/model_executor/models/llama4_eagle.py index a203af53205cd..235275c0940a1 100644 --- a/vllm/model_executor/models/llama4_eagle.py +++ b/vllm/model_executor/models/llama4_eagle.py @@ -37,9 +37,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.llama4 import (Llama4DecoderLayer, Llama4ForCausalLM) from vllm.model_executor.models.utils import extract_layer_index -from vllm.multimodal.inputs import NestedTensors -from .utils import AutoWeightsLoader, maybe_prefix, merge_multimodal_embeddings +from .interfaces import SupportsMultiModal +from .utils import AutoWeightsLoader, maybe_prefix logger = init_logger(__name__) @@ -79,10 +79,7 @@ class LlamaModel(nn.Module): self.norm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - ) -> torch.Tensor: + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -194,6 +191,11 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM): self.logits_processor = LogitsProcessor(self.config.vocab_size, scale=logit_scale) + def get_language_model(self) -> torch.nn.Module: + return self.model + + 
get_input_embeddings = SupportsMultiModal.get_input_embeddings # type: ignore + def forward( self, input_ids: torch.Tensor, @@ -220,20 +222,3 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM): skip_prefixes=(["lm_head."]), ) loader.load_weights(map(transform, weights)) - - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[NestedTensors] = None, - ) -> torch.Tensor: - inputs_embeds = self.model.get_input_embeddings(input_ids) - - if multimodal_embeddings is not None: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - self.config.image_token_index, - ) - - return inputs_embeds diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py index 2ff2d54a83aa8..d6e6fd3fcfe9c 100644 --- a/vllm/model_executor/models/llama_eagle.py +++ b/vllm/model_executor/models/llama_eagle.py @@ -73,6 +73,9 @@ class LlamaModel(nn.Module): self.config.hidden_size, bias=False) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -149,6 +152,9 @@ class EagleLlamaForCausalLM(LlamaForCausalLM): self.logits_processor = LogitsProcessor(self.config.vocab_size, scale=logit_scale) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py index 55b6ae6ee0e9c..34b8ea0ca5360 100644 --- a/vllm/model_executor/models/llama_eagle3.py +++ b/vllm/model_executor/models/llama_eagle3.py @@ -18,7 +18,6 @@ from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils 
import default_weight_loader -from vllm.model_executor.models.interfaces import MultiModalEmbeddings from vllm.model_executor.models.llama import (LlamaDecoderLayer, LlamaForCausalLM) @@ -144,10 +143,7 @@ class LlamaModel(nn.Module): eps=self.config.rms_norm_eps, ) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - ) -> torch.Tensor: + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -239,6 +235,9 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM): requires_grad=False, ) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -302,11 +301,3 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM): skip_substrs=skip_substrs, ) loader.load_weights(model_weights.items()) - - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.model.get_input_embeddings(input_ids) - return inputs_embeds diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 4d8ed95b6cc8f..6f3cfd88aee23 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -41,8 +41,7 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel from .siglip import SiglipVisionModel from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) + init_vllm_registered_model, maybe_prefix) from .vision import get_vision_encoder_info @@ -676,22 +675,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): return self._process_image_input(image_input) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: 
Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - self.config.image_token_index, - ) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -744,8 +727,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): # condition is for v0 compatibility. elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.config.image_token_index, + ) input_ids = None hidden_states = self.language_model.model(input_ids, diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index c9133fde14552..e132389c4f061 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -25,8 +25,8 @@ from .llava import (BaseLlavaMultiModalProcessor, BaseLlavaProcessingInfo, LlavaDummyInputsBuilder, LlavaLikeConfig, LlavaMultiModalProjector, init_vision_tower_for_llava) from .siglip import SiglipVisionModel -from .utils import (AutoWeightsLoader, WeightsMapper, embed_multimodal, - flatten_bn, init_vllm_registered_model, maybe_prefix) +from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, + init_vllm_registered_model, maybe_prefix) class LlavaNextImagePixelInputs(TensorSchema): @@ -474,19 +474,21 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + # Multi-modal token ID may exceed vocab size + handle_oov_mm_token: bool = 
True, ) -> torch.Tensor: + # This is to satisfy the type checker for each overload + if multimodal_embeddings is None or is_multimodal is None: + return super().get_input_embeddings(input_ids) - if multimodal_embeddings is None \ - or len(multimodal_embeddings) == 0: - return self.language_model.get_input_embeddings(input_ids) - - inputs_embeds = embed_multimodal( + return super().get_input_embeddings( input_ids, - self.config.image_token_index, - self.language_model.model.get_input_embeddings, - multimodal_embeddings, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, ) - return inputs_embeds def forward( self, @@ -549,8 +551,11 @@ model_executor.models.llava_next.LlavaNextProcessingInfo.get_num_image_tokens]. # condition is for v0 compatibility. elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.config.image_token_index, + ) input_ids = None hidden_states = self.language_model.model(input_ids, diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index 610fb188d57d2..2642d8c77cf3b 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -30,8 +30,7 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .llava import init_vision_tower_for_llava from .siglip import SiglipVisionModel from .utils import (AutoWeightsLoader, WeightsMapper, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) + init_vllm_registered_model, maybe_prefix) from .vision import get_vision_encoder_info @@ -415,19 +414,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, vision_embeddings = 
self._process_video_pixels(video_input) return vision_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.config.video_token_index) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -449,8 +435,11 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, # condition is for v0 compatibility. elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.config.video_token_index, + ) input_ids = None hidden_states = self.language_model.model(input_ids, diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index cee9ddaf94cc4..906858f4e2f47 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -850,19 +850,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, return multimodal_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - [self.config.image_token_index, self.config.video_token_index]) - return inputs_embeds - def get_input_embeddings_v0( self, input_ids: 
torch.Tensor, diff --git a/vllm/model_executor/models/midashenglm.py b/vllm/model_executor/models/midashenglm.py index 82648ba668ca5..0bf04e0e7e2fa 100644 --- a/vllm/model_executor/models/midashenglm.py +++ b/vllm/model_executor/models/midashenglm.py @@ -54,8 +54,7 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.midashenglm import DashengConfig from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP -from .utils import (AutoWeightsLoader, init_vllm_registered_model, - maybe_prefix, merge_multimodal_embeddings) +from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix _Tuple2 = Union[int, tuple[int, int], Sequence[int]] @@ -744,21 +743,6 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP): return [] return self._process_audio_input(audio_input) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.decoder.get_input_embeddings(input_ids) - if multimodal_embeddings and len(multimodal_embeddings) > 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - self.config.audio_token_id, - ) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -771,8 +755,11 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP): inputs_embeds = None elif inputs_embeds is None: multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - multimodal_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + multimodal_embeddings, + is_multimodal=input_ids == self.config.audio_token_id, + ) input_ids = None return self.decoder.model(input_ids, diff --git a/vllm/model_executor/models/mimo_mtp.py b/vllm/model_executor/models/mimo_mtp.py index b4abe458e4771..9c1e36094c4a3 100644 --- a/vllm/model_executor/models/mimo_mtp.py +++ 
b/vllm/model_executor/models/mimo_mtp.py @@ -117,6 +117,9 @@ class MiMoMultiTokenPredictor(nn.Module): self.logits_processor = LogitsProcessor(config.vocab_size) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -158,6 +161,9 @@ class MiMoMTP(nn.Module): self.config.hidden_size, prefix=maybe_prefix(prefix, "lm_head")) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index a17c4f004d75c..bffc9a0c125ea 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -71,8 +71,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape from .idefics2_vision_model import Idefics2VisionTransformer from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) -from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix, - merge_multimodal_embeddings) +from .utils import AutoWeightsLoader, flatten_bn, isin_list, maybe_prefix # For profile run _MAX_FRAMES_PER_VIDEO = 16 @@ -1144,23 +1143,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): return self._process_multimodal_inputs(modalities) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.llm.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - assert len(self.mm_token_ids) > 0 - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - list(self.mm_token_ids), - ) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -1178,8 +1160,11 @@ class 
MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=isin_list(input_ids, list(self.mm_token_ids)), + ) input_ids = None hidden_states = self.llm.model( diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index cc9a959f63313..a92890c9f7b55 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -592,10 +592,7 @@ class MiniMaxText01Model(nn.Module): dtype=torch.long) minimax_cache_tensors[:, slots_tensor, ...] = 0 - def get_input_embeddings( - self, - input_ids: torch.Tensor, - ) -> torch.Tensor: + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward(self, @@ -687,10 +684,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid): return self.model.minimax_cache.get_seqlen_agnostic_capture_inputs( batch_size) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - ) -> torch.Tensor: + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) def forward(self, diff --git a/vllm/model_executor/models/minimax_vl_01.py b/vllm/model_executor/models/minimax_vl_01.py index d81ac8c704e79..d41b9d3f14fe2 100644 --- a/vllm/model_executor/models/minimax_vl_01.py +++ b/vllm/model_executor/models/minimax_vl_01.py @@ -28,7 +28,7 @@ from .llava_next import LlavaNextProcessingInfo from .pixtral import PixtralHFVisionModel from .siglip import SiglipVisionModel from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, - maybe_prefix, merge_multimodal_embeddings) + maybe_prefix) class MiniMaxVL01ImagePixelInputs(TensorSchema): @@ -218,22 
+218,6 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - self.config.image_token_index, - ) - return inputs_embeds - def get_language_model(self) -> torch.nn.Module: return self.language_model @@ -403,8 +387,11 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, inputs_embeds = None elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.config.image_token_index, + ) input_ids = None hidden_states = self.language_model.model(input_ids, diff --git a/vllm/model_executor/models/mistral3.py b/vllm/model_executor/models/mistral3.py index ba6da4403ae16..31571ce962d18 100644 --- a/vllm/model_executor/models/mistral3.py +++ b/vllm/model_executor/models/mistral3.py @@ -38,8 +38,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) + init_vllm_registered_model, maybe_prefix) from .vision import get_vision_encoder_info @@ -524,22 +523,6 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA, return vision_embeddings - def get_input_embeddings( - self, - input_ids: 
torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - self.config.image_token_index, - ) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -592,8 +575,11 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA, # condition is for v0 compatibility. elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.config.image_token_index, + ) input_ids = None hidden_states = self.language_model.model(input_ids, diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index 79e315f794893..3af5267928cde 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -42,7 +42,7 @@ from vllm.model_executor.model_loader.utils import initialize_model from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems, NestedTensors) + MultiModalKwargsItems) from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -56,8 +56,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .llama4 import Llama4ForCausalLM -from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix, - merge_multimodal_embeddings) +from .utils 
import AutoWeightsLoader, flatten_bn, maybe_prefix from .vision import run_dp_sharded_vision_model @@ -813,24 +812,6 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, return self._process_image_input(image_input) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[NestedTensors] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - - if multimodal_embeddings is not None and len( - multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - self.config.image_token_index, - ) - - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -846,8 +827,11 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, # this condition is for v0 compatibility. elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.config.image_token_index, + ) input_ids = None return self.language_model(input_ids, positions, intermediate_tensors, diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py index 1d5da3139de92..e4a51b3697370 100644 --- a/vllm/model_executor/models/modernbert.py +++ b/vllm/model_executor/models/modernbert.py @@ -43,6 +43,9 @@ class ModernBertEmbeddings(nn.Module): eps=config.layer_norm_eps, bias=config.norm_bias) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.tok_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -220,6 +223,9 @@ class ModernBertModel(nn.Module): eps=config.norm_eps, bias=config.norm_bias) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embeddings.get_input_embeddings(input_ids) + def 
load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: weights = self.hf_to_vllm_mapper.apply(weights) @@ -333,6 +339,9 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding): ), }) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): self_weights = [] diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 201bf83cac581..054caee9e8a4f 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -58,7 +58,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA, from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix, merge_multimodal_embeddings) + maybe_prefix) # TODO: hard-coded for now. Consider making it configurable. 
VIT_LAYERS = [-2, -9] @@ -819,10 +819,7 @@ class MolmoModel(nn.Module, SupportsQuant): make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - ) -> torch.Tensor: + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -1481,24 +1478,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, return self._process_image_input(image_input) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - assert self.img_patch_id is not None - - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - self.img_patch_id, - ) - return inputs_embeds - def forward( self, input_ids: torch.LongTensor, @@ -1515,8 +1494,11 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, # condition is for v0 compatibility. 
elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.img_patch_id, + ) input_ids = None hidden_states = self.model(input_ids, diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py index 2b68d40cf2c67..505806a15c891 100644 --- a/vllm/model_executor/models/nano_nemotron_vl.py +++ b/vllm/model_executor/models/nano_nemotron_vl.py @@ -35,8 +35,7 @@ from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM from vllm.model_executor.models.radio import RadioModel from vllm.model_executor.models.utils import (flatten_bn, init_vllm_registered_model, - maybe_prefix, - merge_multimodal_embeddings) + isin_list, maybe_prefix) from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargs, MultiModalKwargsItems, @@ -1096,8 +1095,8 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid, return modalities - def get_multimodal_embeddings( - self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + def get_multimodal_embeddings(self, + **kwargs: object) -> MultiModalEmbeddings: # Validate the multimodal input keyword arguments modalities = self._parse_and_validate_multimodal_inputs(**kwargs) if modalities is None: @@ -1121,30 +1120,6 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid, return multimodal_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - - if (multimodal_embeddings is not None - and len(multimodal_embeddings) != 0): - context_token_ids = [ - token_id for token_id in (self.img_context_token_id, - 
self.video_context_token_id) - if token_id is not None - ] - assert len(context_token_ids) >= 1 - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - context_token_ids, - ) - - return inputs_embeds - def get_language_model(self) -> torch.nn.Module: return self.language_model @@ -1163,9 +1138,17 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid, # NOTE: In v1, inputs_embeds is always generated at model runner, this # condition is for v0 compatibility. elif inputs_embeds is None: + context_token_ids = [ + token_id for token_id in (self.img_context_token_id, + self.video_context_token_id) + if token_id is not None + ] vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=isin_list(input_ids, context_token_ids), + ) input_ids = None hidden_states = self.language_model( diff --git a/vllm/model_executor/models/nemotron_vl.py b/vllm/model_executor/models/nemotron_vl.py index 3abbff8c717d4..2627a262e9582 100644 --- a/vllm/model_executor/models/nemotron_vl.py +++ b/vllm/model_executor/models/nemotron_vl.py @@ -38,7 +38,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, - maybe_prefix, merge_multimodal_embeddings) + maybe_prefix) IMG_START = '' IMG_END = '' @@ -576,20 +576,24 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not 
None \ - and len(multimodal_embeddings) != 0: - context_token_ids = [self.img_context_token_id] - assert len(context_token_ids) >= 1 + if multimodal_embeddings is not None and len( + multimodal_embeddings) > 0: self._set_visual_token_mask(input_ids) - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - context_token_ids, - ) - return inputs_embeds + + # This is to satisfy the type checker for each overload + if multimodal_embeddings is None or is_multimodal is None: + return super().get_input_embeddings(input_ids) + + return super().get_input_embeddings( + input_ids, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) def forward( self, @@ -608,8 +612,11 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, # condition is for v0 compatibility. elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.img_context_token_id, + ) input_ids = None forward_kwargs = { diff --git a/vllm/model_executor/models/olmo2.py b/vllm/model_executor/models/olmo2.py index 2e0b1fb2a13f7..e7e30ee8df0ff 100644 --- a/vllm/model_executor/models/olmo2.py +++ b/vllm/model_executor/models/olmo2.py @@ -295,6 +295,9 @@ class Olmo2Model(nn.Module): make_empty_intermediate_tensors_factory(["hidden_states"], self.config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -408,6 +411,9 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return 
self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/ovis.py b/vllm/model_executor/models/ovis.py index bd525b6780e0d..8503d3f71d1c9 100644 --- a/vllm/model_executor/models/ovis.py +++ b/vllm/model_executor/models/ovis.py @@ -48,7 +48,6 @@ from vllm.transformers_utils.processors.ovis import OvisProcessor from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP -from .utils import merge_multimodal_embeddings # Cannot find the following number from hf config. IMAGE_TOKEN = "" @@ -501,19 +500,6 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP): return image_features - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.llm.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.image_pad_token_id) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -529,8 +515,11 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP): # condition is for v0 compatibility. 
elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.image_pad_token_id, + ) input_ids = None # up until here we have an inputs_embeds 100% numerical identity diff --git a/vllm/model_executor/models/ovis2_5.py b/vllm/model_executor/models/ovis2_5.py index f18e38ce154d2..2ecc7bff07e07 100644 --- a/vllm/model_executor/models/ovis2_5.py +++ b/vllm/model_executor/models/ovis2_5.py @@ -585,17 +585,6 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP): return multimodal_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.llm.get_input_embeddings(input_ids) - if multimodal_embeddings is not None: - tmp = torch.concat(multimodal_embeddings, dim=0) - inputs_embeds[input_ids == self.image_pad_token_id] = tmp - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -612,8 +601,11 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP): elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.image_pad_token_id, + ) input_ids = None # up until here we have a inputs_embeds 100% numerical identity diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index aef5102304614..f07f444819f4c 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -26,8 +26,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from 
.siglip import SiglipVisionModel from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) + init_vllm_registered_model, maybe_prefix) from .vision import get_vision_encoder_info logger = init_logger(__name__) @@ -362,19 +361,6 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, vision_embeddings = vision_embeddings * (self.config.hidden_size**-0.5) return vision_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.config.image_token_index) - return inputs_embeds - def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, @@ -388,8 +374,11 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, # condition is for v0 compatibility. 
elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.config.image_token_index, + ) input_ids = None hidden_states = self.language_model.model(input_ids, diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index a2b201fe4228d..ea34c8d92f136 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -51,9 +51,9 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape from .clip import CLIPVisionModel from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP, SupportsQuant) -from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) +from .utils import (AutoWeightsLoader, WeightsMapper, + _merge_multimodal_embeddings, flatten_bn, + init_vllm_registered_model, maybe_prefix) logger = init_logger(__name__) @@ -643,14 +643,31 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, ) -> torch.Tensor: - inputs_embeds = self.embed_tokens(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.image_token_id) - return inputs_embeds + inputs_embeds = self._get_text_embeddings( + input_ids, + self.embed_tokens, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) + + if multimodal_embeddings is None or len(multimodal_embeddings) == 0: + return inputs_embeds + + if is_multimodal is None: + raise ValueError( + 
"`get_input_embeddings` now requires `is_multimodal` arg, "
+                "please update your model runner according to "
+                "https://github.com/vllm-project/vllm/pull/16229.")
+
+        return _merge_multimodal_embeddings(
+            inputs_embeds=inputs_embeds,
+            multimodal_embeddings=multimodal_embeddings,
+            is_multimodal=is_multimodal,
+        )
 
     def forward(self,
                 input_ids: torch.Tensor,
@@ -666,8 +683,11 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
         # condition is for v0 compatibility
         elif inputs_embeds is None:
             vision_embeddings = self.get_multimodal_embeddings(**kwargs)
-            inputs_embeds = self.get_input_embeddings(input_ids,
-                                                      vision_embeddings)
+            inputs_embeds = self.get_input_embeddings(
+                input_ids,
+                vision_embeddings,
+                is_multimodal=input_ids == self.image_token_id,
+            )
             input_ids = None
 
         hidden_states = self.language_model.model(input_ids,
diff --git a/vllm/model_executor/models/phi4_multimodal.py b/vllm/model_executor/models/phi4_multimodal.py
index bdc831354c11f..e8b79717d75d0 100644
--- a/vllm/model_executor/models/phi4_multimodal.py
+++ b/vllm/model_executor/models/phi4_multimodal.py
@@ -1342,12 +1342,12 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
             image_attention_mask)
         return image_embeds
 
-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
 
         if not modalities:
-            return None
+            return []
 
         # The result multimodal_embeddings is tuple of tensors, with each
         # tensor corresponding to a multimodal data item (image or video).
@@ -1371,18 +1371,6 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): return multimodal_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - [_IMAGE_PLACEHOLDER_TOKEN_ID, _AUDIO_PLACEHOLDER_TOKEN_ID]) - return inputs_embeds - def get_input_embeddings_v0( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py index 47b5ad55ab2d0..15b09c7ae2bc9 100644 --- a/vllm/model_executor/models/phi4mm.py +++ b/vllm/model_executor/models/phi4mm.py @@ -1151,7 +1151,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): modalities = self._parse_and_validate_multimodal_inputs(**kwargs) if not modalities: return [] - return None # The result multimodal_embeddings is tuple of tensors, with each # tensor corresponding to a multimodal data item (image or video). 
@@ -1175,19 +1174,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): return multimodal_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.model.embed_tokens(input_ids) - if multimodal_embeddings is not None and len( - multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - [_IMAGE_PLACEHOLDER_TOKEN_ID, _AUDIO_PLACEHOLDER_TOKEN_ID]) - return inputs_embeds - def get_input_embeddings_v0( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 7b197844c8b63..2c04b6f0f4f90 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -50,8 +50,7 @@ from vllm.transformers_utils.tokenizer import (MistralTokenizer, from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP -from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) +from .utils import flatten_bn, init_vllm_registered_model, maybe_prefix from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs try: @@ -433,22 +432,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, return self._process_image_input(image_input) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - self.vision_args.image_token_id, - ) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -465,8 
+448,11 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, # condition is for v0 compatibility. elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.vision_args.image_token_id, + ) input_ids = None hidden_states = self.language_model.model(input_ids, diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py index 5f27230c913b4..bfa398ee43b56 100644 --- a/vllm/model_executor/models/qwen2_5_omni_thinker.py +++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py @@ -865,24 +865,26 @@ class Qwen2_5OmniThinkerForConditionalGeneration( multimodal_embeddings += audio_embeddings return multimodal_embeddings + # TODO (ywang96): support overlapping modality embeddings so that + # `use_audio_in_video` will work on V1. def get_input_embeddings( self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: + # This is to satisfy the type checker for each overload + if multimodal_embeddings is None or is_multimodal is None: + return super().get_input_embeddings(input_ids) - # TODO (ywang96): support overlapping modality embeddings so that - # `use_audio_in_video` will work on V1. 
- inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, [ - self.config.image_token_index, - self.config.video_token_index, - self.config.audio_token_index - ]) - return inputs_embeds + return super().get_input_embeddings( + input_ids, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) def get_multimodal_embeddings_v0( self, **kwargs: object) -> Optional[NestedTensors]: diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index adb21373056c5..5b092b42205fa 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -1365,19 +1365,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, multimodal_embeddings += video_embeddings return multimodal_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - [self.config.image_token_id, self.config.video_token_id]) - return inputs_embeds - def get_input_embeddings_v0( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index 762ab42e5929e..9dfa29eef5ce7 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -49,8 +49,7 @@ from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP -from .utils import (AutoWeightsLoader, init_vllm_registered_model, - maybe_prefix, merge_multimodal_embeddings) +from .utils 
import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix # # === Audio Inputs === # @@ -438,19 +437,6 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, masked_audio_features = self._process_audio_input(audio_input) return masked_audio_features - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.config.audio_token_index) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -467,8 +453,11 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, # condition is for v0 compatibility. elif inputs_embeds is None: multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - multimodal_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + multimodal_embeddings, + is_multimodal=input_ids == self.config.audio_token_index, + ) input_ids = None hidden_states = self.language_model.model(input_ids, diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index d4e195246bf13..8192c3ce05dd2 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -1459,19 +1459,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, return multimodal_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - 
input_ids, inputs_embeds, multimodal_embeddings, - [self.config.image_token_id, self.config.video_token_id]) - return inputs_embeds - def get_input_embeddings_v0( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index f1aeb99a4d373..5d0b66f91aced 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -79,7 +79,8 @@ from .qwen2_5_vl import (Qwen2_5_VisionAttention, from .qwen2_vl import Qwen2VLProcessingInfo from .qwen3 import Qwen3ForCausalLM, Qwen3Model from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, - maybe_prefix, merge_multimodal_embeddings) + _merge_multimodal_embeddings, maybe_prefix, + merge_multimodal_embeddings) from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model logger = init_logger(__name__) @@ -1324,17 +1325,22 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal, return multimodal_embeddings def _compute_deepstack_embeds( - self, input_ids: torch.Tensor, inputs_embeds: torch.Tensor, - multimodal_embeddings: MultiModalEmbeddings) -> torch.Tensor: - visual_lens = [ - x.shape[0] if isinstance(x, torch.Tensor) else len(x) - for x in multimodal_embeddings - ] + self, + inputs_embeds: torch.Tensor, + multimodal_embeddings: MultiModalEmbeddings, + is_multimodal: torch.Tensor, + ) -> tuple[torch.Tensor, MultiModalEmbeddings]: + visual_lens = [len(x) for x in multimodal_embeddings] multimodal_embeddings_cat = torch.cat(multimodal_embeddings, dim=0) - multimodal_embeddings_main, multimodal_embeddings_multiscale = torch.split( # noqa:E501 - multimodal_embeddings_cat, [self.visual_dim, self.multiscale_dim], - dim=-1) + ( + multimodal_embeddings_main, + multimodal_embeddings_multiscale, + ) = torch.split( + multimodal_embeddings_cat, + [self.visual_dim, self.multiscale_dim], + dim=-1, + ) multimodal_embeddings = torch.split(multimodal_embeddings_main, visual_lens, @@ -1346,39 +1352,62 
@@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal, inputs_embeds.size(0), self.deepstack_num_level * inputs_embeds.size(1)) - deepstack_input_embeds = merge_multimodal_embeddings( - input_ids, - deepstack_input_embeds, - multimodal_embeddings_multiscale, - placeholder_token_id=[ - self.config.image_token_id, self.config.video_token_id - ], + deepstack_input_embeds = _merge_multimodal_embeddings( + inputs_embeds=deepstack_input_embeds, + multimodal_embeddings=multimodal_embeddings_multiscale, + is_multimodal=is_multimodal, ) deepstack_input_embeds = deepstack_input_embeds.view( inputs_embeds.shape[0], self.deepstack_num_level, self.visual_dim) deepstack_input_embeds = deepstack_input_embeds.permute(1, 0, 2) + return deepstack_input_embeds, multimodal_embeddings def get_input_embeddings( self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, ) -> torch.Tensor: - deepstack_input_embeds = None - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None: - if self.use_deepstack: - deepstack_input_embeds, multimodal_embeddings = self._compute_deepstack_embeds( # noqa:E501 - input_ids, inputs_embeds, multimodal_embeddings) - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - [self.config.image_token_id, self.config.video_token_id]) + inputs_embeds = self._get_text_embeddings( + input_ids, + self.language_model.get_input_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) + + if multimodal_embeddings is None or len(multimodal_embeddings) == 0: + return inputs_embeds + + if is_multimodal is None: + raise ValueError( + "`get_input_embeddings` now requires `is_multimodal` arg, " + "please update your model runner according to " + "https://github.com/vllm-project/vllm/pull/16229.") if 
self.use_deepstack: - if deepstack_input_embeds is None: - deepstack_input_embeds = torch.zeros_like( - inputs_embeds).unsqueeze(0).repeat( - self.deepstack_num_level, 1, 1).contiguous() + ( + deepstack_input_embeds, + multimodal_embeddings, + ) = self._compute_deepstack_embeds( + inputs_embeds=inputs_embeds, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + ) + else: + deepstack_input_embeds = None + + inputs_embeds = _merge_multimodal_embeddings( + inputs_embeds=inputs_embeds, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + ) + + if deepstack_input_embeds is not None: + deepstack_input_embeds = torch.zeros_like(inputs_embeds).unsqueeze( + 0).repeat(self.deepstack_num_level, 1, 1).contiguous() self._set_deepstack_input_embeds(deepstack_input_embeds) return inputs_embeds diff --git a/vllm/model_executor/models/qwen_vl.py b/vllm/model_executor/models/qwen_vl.py index 90200f319464b..dc11b60604a91 100644 --- a/vllm/model_executor/models/qwen_vl.py +++ b/vllm/model_executor/models/qwen_vl.py @@ -45,7 +45,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) from .qwen import QWenBaseModel, QWenModel -from .utils import flatten_bn, merge_multimodal_embeddings +from .utils import flatten_bn class QwenImagePixelInputs(TensorSchema): @@ -756,21 +756,6 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA, vision_embeddings = self._process_image_input(image_input) return vision_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.transformer.get_input_embeddings(input_ids) - - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - 
self.transformer.visual.image_pad_id) - - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -786,8 +771,12 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA, # condition is for v0 compatibility. elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == + self.transformer.visual.image_pad_id, + ) input_ids = None hidden_states = self.transformer(input_ids, positions, diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index ba405be416876..53e698c4fa806 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -218,6 +218,9 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding): loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.jina_to_vllm_mapper) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.roberta.get_input_embeddings(input_ids) + def forward( self, input_ids: Optional[torch.Tensor], diff --git a/vllm/model_executor/models/skyworkr1v.py b/vllm/model_executor/models/skyworkr1v.py index 893ce4497c319..f9a107c06085b 100644 --- a/vllm/model_executor/models/skyworkr1v.py +++ b/vllm/model_executor/models/skyworkr1v.py @@ -38,7 +38,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, - maybe_prefix, merge_multimodal_embeddings) + maybe_prefix) IMG_START = '' IMG_END = '' @@ -842,19 +842,24 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + 
is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - assert self.img_context_token_id is not None + if multimodal_embeddings is not None and len( + multimodal_embeddings) > 0: self._set_visual_token_mask(input_ids) - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - self.img_context_token_id, - ) - return inputs_embeds + + # This is to satisfy the type checker for each overload + if multimodal_embeddings is None or is_multimodal is None: + return super().get_input_embeddings(input_ids) + + return super().get_input_embeddings( + input_ids, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) def forward( self, @@ -873,8 +878,11 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): # condition is for v0 compatibility. 
elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.img_context_token_id, + ) input_ids = None forward_kwargs = { diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py index c774171b9dcd2..c5b82b0ca4a07 100644 --- a/vllm/model_executor/models/solar.py +++ b/vllm/model_executor/models/solar.py @@ -483,6 +483,9 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/step3_text.py b/vllm/model_executor/models/step3_text.py index 0cce0c78f8dc6..0fe723d594839 100644 --- a/vllm/model_executor/models/step3_text.py +++ b/vllm/model_executor/models/step3_text.py @@ -395,6 +395,9 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, diff --git a/vllm/model_executor/models/step3_vl.py b/vllm/model_executor/models/step3_vl.py index 5f6ad58850439..ad295ef447325 100644 --- a/vllm/model_executor/models/step3_vl.py +++ b/vllm/model_executor/models/step3_vl.py @@ -25,7 +25,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - 
MultiModalKwargsItems, NestedTensors) + MultiModalKwargsItems) from vllm.multimodal.parse import ImageSize, MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, @@ -37,8 +37,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) + init_vllm_registered_model, maybe_prefix) from .vision import run_dp_sharded_vision_model @@ -996,10 +995,13 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, 1 else cur_feature[0]) return merged_image_features - def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def get_multimodal_embeddings(self, **kwargs) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: - return None + return [] vision_embeddings = self._process_image_input(image_input) return vision_embeddings @@ -1007,24 +1009,21 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + # Multi-modal token ID may exceed vocab size + handle_oov_mm_token: bool = True, ) -> torch.Tensor: - if multimodal_embeddings is None: - inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) - else: - is_text = input_ids != self.config.image_token_id - text_ids = input_ids[is_text] - text_embeds = self.language_model.model.get_input_embeddings( - text_ids) - inputs_embeds = torch.empty(input_ids.shape[0], - text_embeds.shape[-1], - dtype=text_embeds.dtype, - device=text_embeds.device) - inputs_embeds[is_text] = text_embeds - inputs_embeds = 
merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.config.image_token_id) - return inputs_embeds + # This is to satisfy the type checker for each overload + if multimodal_embeddings is None or is_multimodal is None: + return super().get_input_embeddings(input_ids) + + return super().get_input_embeddings( + input_ids, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) def forward( self, @@ -1038,10 +1037,11 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, inputs_embeds = None elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - # always pass the input via `inputs_embeds` - # to make sure the computation graph is consistent - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.config.image_token_id, + ) input_ids = None hidden_states = self.language_model(input_ids, diff --git a/vllm/model_executor/models/tarsier.py b/vllm/model_executor/models/tarsier.py index 3660efdc079aa..1145bea414808 100644 --- a/vllm/model_executor/models/tarsier.py +++ b/vllm/model_executor/models/tarsier.py @@ -40,7 +40,7 @@ from .clip import CLIPVisionModel from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .siglip import SiglipVisionModel from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, - maybe_prefix, merge_multimodal_embeddings) + maybe_prefix) from .vision import VisionEncoderInfo, get_vision_encoder_info @@ -589,22 +589,6 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, return [] return self._process_image_input(image_input) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = 
self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - self.config.image_token_index, - ) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -617,8 +601,11 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, inputs_embeds = None elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.config.image_token_index, + ) input_ids = None hidden_states = self.language_model.model( input_ids=input_ids, diff --git a/vllm/model_executor/models/terratorch.py b/vllm/model_executor/models/terratorch.py index b9dfa8e9b6f51..938b02e3e04b3 100644 --- a/vllm/model_executor/models/terratorch.py +++ b/vllm/model_executor/models/terratorch.py @@ -233,6 +233,9 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal): self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, ) -> torch.Tensor: # We do not really use any input tokens and therefore no embeddings # to be calculated. 
However, due to the mandatory token ids in diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 19dd242f16eb6..3d7b06633f342 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -52,8 +52,8 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of -from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP, - SupportsQuant) +from .interfaces import (MultiModalEmbeddings, SupportsLoRA, + SupportsMultiModal, SupportsPP, SupportsQuant) from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, flatten_bn, make_empty_intermediate_tensors_factory, maybe_prefix) @@ -797,6 +797,9 @@ class TransformersForCausalLM(TransformersBase): else: self.lm_head = PPMissingLayer() + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings()(input_ids) + def compute_logits( self, hidden_states: torch.Tensor, @@ -873,13 +876,19 @@ class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal): multimodal_embeds = self.get_multimodal_embeddings(**kwargs) if multimodal_embeds is not None: inputs_embeds = self.get_input_embeddings( - input_ids, multimodal_embeds) + input_ids, + multimodal_embeds, + is_multimodal=input_ids == self.config.image_token_id, + ) input_ids = None model_output = super().forward(input_ids, positions, intermediate_tensors, inputs_embeds) return model_output + def get_language_model(self) -> torch.nn.Module: + return self.model + def get_multimodal_embeddings(self, **kwargs): pixel_values = kwargs.pop("pixel_values", None) pixel_values = pixel_values if pixel_values is not None else kwargs.pop( @@ -934,15 +943,42 @@ class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal): def get_input_embeddings( self, input_ids: torch.Tensor, - multimodal_embeddings=None, + 
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, ) -> torch.Tensor: - inputs_embeds = self.model.get_input_embeddings()(input_ids) - if (multimodal_embeddings is not None - and len(multimodal_embeddings) != 0): - mask = (input_ids == self.config.image_token_id) - mask = mask.unsqueeze(-1).expand_as(inputs_embeds) - multimodal_embeddings = torch.cat(multimodal_embeddings) + """ + Apply token embeddings to `input_ids`. - inputs_embeds = inputs_embeds.masked_scatter( - mask, multimodal_embeddings) - return inputs_embeds + If `multimodal_embeddings` is passed, scatter them into + `input_ids` according to the mask `is_multimodal`. + + In case the multi-modal token IDs exceed the vocabulary size of + the language model, you can set `handle_oov_mm_token=False` + to avoid calling the language model's `get_input_embeddings` method + on those tokens. + """ + from .utils import _merge_multimodal_embeddings + + inputs_embeds = self._get_text_embeddings( + input_ids, + self.model.get_input_embeddings(), + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) + + if multimodal_embeddings is None or len(multimodal_embeddings) == 0: + return inputs_embeds + + if is_multimodal is None: + raise ValueError( + "`get_input_embeddings` now requires `is_multimodal` arg, " + "please update your model runner according to " + "https://github.com/vllm-project/vllm/pull/16229.") + + return _merge_multimodal_embeddings( + inputs_embeds=inputs_embeds, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + ) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 12ae9487ad9dc..77e886c22e634 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -33,8 +33,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import 
(MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) + init_vllm_registered_model, maybe_prefix) _AUDIO_PLACEHOLDER_OVERRIDE = "<|audio|>" _MAX_ENCODER_BATCH_SIZE = 16 @@ -555,19 +554,21 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + # Multi-modal token ID may exceed vocab size + handle_oov_mm_token: bool = True, ) -> torch.Tensor: - # The audio token index is not included in the embedding table - # We need to remove it before embedding lookup - safe_input_ids = input_ids.clone() - safe_input_ids[safe_input_ids == self.config.audio_token_index] = 0 - inputs_embeds = self.language_model.get_input_embeddings( - safe_input_ids) - if multimodal_embeddings is not None and len( - multimodal_embeddings) > 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.config.audio_token_index) - return inputs_embeds + # This is to satisfy the type checker for each overload + if multimodal_embeddings is None or is_multimodal is None: + return super().get_input_embeddings(input_ids) + + return super().get_input_embeddings( + input_ids, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) def forward(self, input_ids: torch.Tensor, @@ -601,8 +602,11 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): elif inputs_embeds is None: multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - multimodal_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + multimodal_embeddings, + is_multimodal=input_ids == 
self.config.audio_token_index, + ) input_ids = None language_model = self.language_model diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 51cd41c864f09..7b3f20c6b28a1 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -4,7 +4,7 @@ import itertools from collections.abc import Iterable, Mapping from dataclasses import dataclass, field -from typing import Any, Callable, Literal, Optional, Protocol, Union, overload +from typing import Any, Literal, Optional, Protocol, Union, overload import torch import torch.nn as nn @@ -391,8 +391,8 @@ def _embedding_count_expression(embeddings: NestedTensors) -> str: def _merge_multimodal_embeddings( inputs_embeds: torch.Tensor, - is_multimodal: torch.Tensor, multimodal_embeddings: NestedTensors, + is_multimodal: torch.Tensor, ) -> torch.Tensor: """ Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the @@ -402,63 +402,37 @@ def _merge_multimodal_embeddings( Note: This updates ``inputs_embeds`` in place. """ - flattened = _flatten_embeddings(multimodal_embeddings) - try: - # This is equivalent to: inputs_embeds[is_multimodal] = flattened. 
- inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1), - flattened.to(dtype=inputs_embeds.dtype)) - except RuntimeError as e: - num_expected_tokens = is_multimodal.sum().item() - assert isinstance(num_expected_tokens, int) + if len(multimodal_embeddings) == 0: + return inputs_embeds - if flattened.shape[0] != num_expected_tokens: + mm_embeds_flat = _flatten_embeddings(multimodal_embeddings) + input_dtype = inputs_embeds.dtype + + try: + # For debugging + # inputs_embeds[is_multimodal] = mm_embeds_flat.to(dtype=input_dtype) + + # NOTE: This can avoid D2H sync (#22105), but fails to + # raise an error if is_multimodal.sum() < len(mm_embeds_flat) + inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1), + mm_embeds_flat.to(dtype=input_dtype)) + except RuntimeError as e: + num_actual_tokens = len(mm_embeds_flat) + num_expected_tokens = is_multimodal.sum().item() + + if num_actual_tokens != num_expected_tokens: expr = _embedding_count_expression(multimodal_embeddings) + raise ValueError( - f"Attempted to assign {expr} = {flattened.shape[0]} " + f"Attempted to assign {expr} = {num_actual_tokens} " f"multimodal tokens to {num_expected_tokens} placeholders" ) from e - else: - raise ValueError("Error during masked scatter operation") from e + + raise ValueError("Error during masked scatter operation") from e return inputs_embeds -def embed_multimodal( - input_ids: torch.Tensor, - multimodal_token_id: int, - get_text_embeds: Callable[[torch.Tensor], torch.Tensor], - multimodal_embeds: NestedTensors, -) -> torch.Tensor: - """ - Embed token IDs and multimodal inputs and combine their embeddings. - - ``multimodal_token_id`` is used to determine whether a token ID should - be embedded using ``get_text_embeds`` or ``get_multimodal_embeds``. 
- - Compared to ``merge_multimodal_embeddings`, this avoids running - ``get_text_embeds`` on ``input_ids[input_ids == multimodal_token_id]`` - which causes issues when the placeholder token ID exceeds the - vocabulary size of the language model. - """ - is_multimodal = input_ids == multimodal_token_id - is_text = ~is_multimodal - - text_embeds = get_text_embeds(input_ids[is_text]) - merged_embeds = torch.empty( - (input_ids.shape[0], text_embeds.shape[1]), - dtype=text_embeds.dtype, - device=text_embeds.device, - ) - - merged_embeds[is_text] = text_embeds - - return _merge_multimodal_embeddings( - merged_embeds, - is_multimodal, - multimodal_embeds, - ) - - def merge_multimodal_embeddings( input_ids: torch.Tensor, inputs_embeds: torch.Tensor, @@ -491,23 +465,29 @@ def merge_multimodal_embeddings( This updates ``inputs_embeds`` in place. """ if isinstance(placeholder_token_id, list): - placeholder_token_id = torch.tensor( - placeholder_token_id, - pin_memory=is_pin_memory_available()).to(device=input_ids.device, - non_blocking=True) - return _merge_multimodal_embeddings( - inputs_embeds, - torch.isin(input_ids, placeholder_token_id), - multimodal_embeddings, - ) + is_multimodal = isin_list(input_ids, placeholder_token_id) + else: + is_multimodal = (input_ids == placeholder_token_id) return _merge_multimodal_embeddings( inputs_embeds, - (input_ids == placeholder_token_id), - multimodal_embeddings, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, ) +def isin_list( + elements: torch.Tensor, + test_elements_list: list[int], +) -> torch.Tensor: + test_elements = torch.tensor( + test_elements_list, + pin_memory=is_pin_memory_available(), + ).to(device=elements.device, non_blocking=True) + + return torch.isin(elements, test_elements) + + class LayerFn(Protocol): def __call__(self, prefix: str) -> torch.nn.Module: diff --git a/vllm/model_executor/models/voxtral.py b/vllm/model_executor/models/voxtral.py index b33e8d09c4be1..f93e7ccfd06ff 100644 
--- a/vllm/model_executor/models/voxtral.py +++ b/vllm/model_executor/models/voxtral.py @@ -45,10 +45,8 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.tokenizer import (MistralTokenizer, cached_tokenizer_from_config) -from .interfaces import (MultiModalEmbeddings, SupportsLoRA, - SupportsMultiModal, SupportsTranscription) -from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) +from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsTranscription +from .utils import flatten_bn, init_vllm_registered_model, maybe_prefix logger = init_logger(__name__) @@ -376,9 +374,14 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal, # NOTE: In v1, inputs_embeds is always generated at model runner, this # condition is for v0 compatibility. elif inputs_embeds is None: + audio_encoder = self.tokenizer.instruct.audio_encoder + audio_tok_id = audio_encoder.audio_token audio_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - audio_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + audio_embeddings, + is_multimodal=input_ids == audio_tok_id, + ) input_ids = None hidden_states = self.language_model.model(input_ids, @@ -421,20 +424,6 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal, return audio_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - audio_encoder = self.tokenizer.instruct.audio_encoder - audio_tok_id = audio_encoder.audio_token - - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, audio_tok_id) - return inputs_embeds - def _parse_and_validate_audio_arrays( self, **kwargs: object) -> 
Union[list[torch.Tensor], None]: audio_arrays = kwargs.pop("audio_arrays", None) diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index de3e4f0592a62..7beeeddf988fe 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -579,10 +579,7 @@ class WhisperDecoder(nn.Module): hidden_states = self.layer_norm(hidden_states) return hidden_states - def get_input_embeddings( - self, - input_ids: torch.Tensor, - ) -> torch.Tensor: + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -916,7 +913,10 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, def get_input_embeddings( self, input_ids: torch.Tensor, - multimodal_embeddings: Optional[NestedTensors] = None, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, ) -> torch.Tensor: # This method just returns the decoder sequence embeddings since # Whisper does not have encoder text tokens. diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 51e54e0dc337f..1b5bafb9ca1b1 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -18,6 +18,7 @@ from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.model_executor.models import supports_multimodal from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM +from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.platforms import current_platform from vllm.utils import is_pin_memory_available from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata @@ -64,8 +65,10 @@ class EagleProposer: # hidden size (e.g., Llama 3.3 70B). 
self.hidden_size = self.draft_model_config.get_hidden_size() - self.is_multimodal_model = vllm_config.model_config \ - .is_multimodal_model + # Multi-modal data support + self.mm_registry = MULTIMODAL_REGISTRY + self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs( + vllm_config.model_config) self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None @@ -175,7 +178,8 @@ class EagleProposer: last_token_indices: Optional[torch.Tensor], common_attn_metadata: CommonAttentionMetadata, sampling_metadata: SamplingMetadata, - mm_embeds: Optional[list[torch.Tensor]] = None, + mm_embed_inputs: Optional[tuple[list[torch.Tensor], + torch.Tensor]] = None, ) -> torch.Tensor: num_tokens = target_token_ids.shape[0] batch_size = next_token_ids.shape[0] @@ -219,18 +223,21 @@ class EagleProposer: # copy inputs to buffer for cudagraph self._set_positions(num_tokens, target_positions) self.hidden_states[:num_tokens] = target_hidden_states - if self.is_multimodal_model: - input_ids = self.input_ids[:num_tokens] - inputs_embeds = self.model.get_input_embeddings( - input_ids, - multimodal_embeddings=mm_embeds or None, + + if self.supports_mm_inputs: + mm_embeds, is_mm_embed = mm_embed_inputs or (None, None) + + self.inputs_embeds[:num_tokens] = self.model.get_input_embeddings( + self.input_ids[:num_tokens], + multimodal_embeddings=mm_embeds, + is_multimodal=is_mm_embed, ) - self.inputs_embeds[:num_tokens] = inputs_embeds - inputs_embeds = self.inputs_embeds[:num_input_tokens] + input_ids = None + inputs_embeds = self.inputs_embeds[:num_input_tokens] else: - inputs_embeds = None input_ids = self.input_ids[:num_input_tokens] + inputs_embeds = None with set_forward_context(per_layer_attn_metadata, self.vllm_config, @@ -372,14 +379,15 @@ class EagleProposer: self.input_ids[:batch_size] = input_ids self._set_positions(batch_size, clamped_positions) self.hidden_states[:batch_size] = hidden_states - if self.is_multimodal_model: - inputs_embeds = 
self.model.get_input_embeddings(input_ids) - self.inputs_embeds[:batch_size] = inputs_embeds - inputs_embeds = self.inputs_embeds[:input_batch_size] + if self.supports_mm_inputs: + self.inputs_embeds[:batch_size] = \ + self.model.get_input_embeddings(input_ids) + input_ids = None + inputs_embeds = self.inputs_embeds[:input_batch_size] else: - inputs_embeds = None input_ids = self.input_ids[:input_batch_size] + inputs_embeds = None # Run the model. with set_forward_context(per_layer_attn_metadata, @@ -849,7 +857,7 @@ class EagleProposer: self.attn_layer_names = list(draft_attn_layer_names) - if self.is_multimodal_model: + if self.supports_mm_inputs: # Even if the target model is multimodal, we can also use # text-only draft models try: @@ -861,7 +869,7 @@ class EagleProposer: logger.warning( "Draft model does not support multimodal inputs, " "falling back to text-only mode") - self.is_multimodal_model = False + self.supports_mm_inputs = False if supports_multimodal(target_model): # handle multimodality @@ -933,7 +941,7 @@ class EagleProposer: ) -> None: with set_forward_context(None, self.vllm_config, num_tokens=num_tokens): - if self.is_multimodal_model: + if self.supports_mm_inputs: input_ids = None inputs_embeds = self.inputs_embeds[:num_tokens] else: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 22a177dd7cc7a..1bae0d4ce4d1f 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -368,6 +368,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.num_accepted_tokens = self._make_buffer(self.max_num_reqs, dtype=torch.int64) + # Only relevant for multimodal models + if self.supports_mm_inputs: + self.is_mm_embed = self._make_buffer(self.max_num_tokens, + dtype=torch.bool) + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: # NOTE: `mrope_positions` is implemented with one additional dummy @@ -1627,9 +1632,16 @@ class 
GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self, scheduler_output: "SchedulerOutput", shift_computed_tokens: int = 0, - ) -> list[torch.Tensor]: + ) -> tuple[list[torch.Tensor], torch.Tensor]: + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + + mm_embeds = list[torch.Tensor]() + is_mm_embed = self.is_mm_embed.cpu + is_mm_embed[:total_num_scheduled_tokens] = False + + req_start_idx = 0 should_sync_mrope_positions = False - mm_embeds: list[torch.Tensor] = [] + for req_id in self.input_batch.req_ids: mm_embeds_req: list[torch.Tensor] = [] @@ -1638,6 +1650,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): req_state = self.requests[req_id] num_computed_tokens = \ req_state.num_computed_tokens + shift_computed_tokens + for mm_feature in req_state.mm_features: pos_info = mm_feature.mm_position start_pos = pos_info.offset @@ -1670,6 +1683,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if (is_embed := pos_info.is_embed) is not None: is_embed = is_embed[start_idx:end_idx] + req_start_pos = req_start_idx + start_pos - num_computed_tokens + is_mm_embed[req_start_pos+start_idx:req_start_pos + end_idx] \ + = True if is_embed is None else is_embed + mm_embeds_item = gather_mm_placeholders( encoder_output[start_idx:end_idx], is_embed=is_embed, @@ -1677,6 +1694,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): mm_embeds_req.append(mm_embeds_item) if self.is_multimodal_pruning_enabled and self.uses_mrope: + assert req_state.mrope_positions is not None should_sync_mrope_positions = True mm_embeds_req, new_mrope_positions, new_delta = ( self.model.recompute_mrope_positions( @@ -1685,18 +1703,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): mrope_positions=req_state.mrope_positions, num_computed_tokens=req_state.num_computed_tokens, )) - assert req_state.mrope_positions is not None 
req_state.mrope_positions.copy_(new_mrope_positions) req_state.mrope_position_delta = new_delta mm_embeds.extend(mm_embeds_req) + req_start_idx += num_scheduled_tokens + + is_mm_embed = self.is_mm_embed.copy_to_gpu(total_num_scheduled_tokens) if should_sync_mrope_positions: self._calc_mrope_positions(scheduler_output) - self.mrope_positions.copy_to_gpu( - scheduler_output.total_num_scheduled_tokens) + self.mrope_positions.copy_to_gpu(total_num_scheduled_tokens) - return mm_embeds + return mm_embeds, is_mm_embed def _extract_encoder_inputs( self, @@ -1990,14 +2009,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): and not self.model_config.is_encoder_decoder): # Run the multimodal encoder if any. self._execute_mm_encoder(scheduler_output) - mm_embeds = self._gather_mm_embeddings(scheduler_output) + mm_embeds, is_mm_embed = self._gather_mm_embeddings( + scheduler_output) # NOTE(woosuk): To unify token ids and soft tokens (vision # embeddings), we always use embeddings (rather than token ids) # as input to the multimodal model, even when the input is text. inputs_embeds_scheduled = self.model.get_input_embeddings( - input_ids=self.input_ids.gpu[:num_scheduled_tokens], - multimodal_embeddings=mm_embeds or None, + self.input_ids.gpu[:num_scheduled_tokens], + multimodal_embeddings=mm_embeds, + is_multimodal=is_mm_embed, ) # TODO(woosuk): Avoid the copy. Optimize. 
@@ -2586,10 +2607,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): [h[token_indices] for h in aux_hidden_states], dim=-1) else: target_hidden_states = hidden_states[token_indices] - mm_embeds = None + if self.supports_mm_inputs: - mm_embeds = self._gather_mm_embeddings(scheduler_output, - shift_computed_tokens=1) + mm_embed_inputs = self._gather_mm_embeddings( + scheduler_output, + shift_computed_tokens=1, + ) + else: + mm_embed_inputs = None draft_token_ids = self.drafter.propose( target_token_ids=target_token_ids, @@ -2599,8 +2624,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): last_token_indices=token_indices_to_sample, sampling_metadata=sampling_metadata, common_attn_metadata=common_attn_metadata, - mm_embeds=mm_embeds, + mm_embed_inputs=mm_embed_inputs, ) + return draft_token_ids def update_config(self, overrides: dict[str, Any]) -> None: diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index a330f50875a89..2405f978ca73f 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -263,6 +263,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): pin_memory=self.pin_memory) self.seq_lens_np = self.seq_lens_cpu.numpy() + # Only relevant for multimodal models + if self.supports_mm_inputs: + self.is_mm_embed_cpu = torch.zeros(self.max_num_tokens, + dtype=torch.bool, + device="cpu", + pin_memory=self.pin_memory) + # Range tensor with values [0 .. self.max_num_tokens - 1]. 
# Used to initialize positions / context_lens / seq_lens # Keep in int64 to avoid overflow with long context @@ -879,13 +886,22 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): def _gather_mm_embeddings( self, scheduler_output: "SchedulerOutput", - ) -> list[torch.Tensor]: - mm_embeds: list[torch.Tensor] = [] + ) -> tuple[list[torch.Tensor], torch.Tensor]: + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + padded_total_num_scheduled_tokens = _get_padded_token_len( + self.num_tokens_paddings, total_num_scheduled_tokens) + + is_mm_embed = self.is_mm_embed_cpu + is_mm_embed[:padded_total_num_scheduled_tokens] = False + mm_embeds = list[torch.Tensor]() + req_start_idx = 0 + for req_id in self.input_batch.req_ids: num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ req_id] req_state = self.requests[req_id] num_computed_tokens = req_state.num_computed_tokens + # TODO unroll loop and assume/enforce --disable_chunked_mm_input # NOTE (NickLucche) here we diverge from logic in other runners, as # we assume to only have whole mm items to process. Hence we avoid @@ -906,26 +922,53 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # The encoder output is already processed and stored # in the decoder's KV cache. continue + + start_idx = max(num_computed_tokens - start_pos, 0) + end_idx = min( + num_computed_tokens - start_pos + num_scheduled_tokens, + num_encoder_tokens, + ) + assert start_idx < end_idx + mm_hash = mm_feature.identifier encoder_output = self.encoder_cache.get(mm_hash, None) assert encoder_output is not None,\ f"Encoder cache miss for {mm_hash}." + assert pos_info.is_embed is None, "Expected all positions to"\ " be contiguous and embeddings." 
- encoder_output = self.encoder_cache[mm_hash] - mm_embeds.append(encoder_output) - return mm_embeds - def _get_model_inputs(self, input_ids: torch.Tensor, - mm_embeds: list[torch.Tensor]): + req_start_pos = req_start_idx + start_pos - num_computed_tokens + is_mm_embed[req_start_pos+start_idx:req_start_pos + end_idx] \ + = True + + # Only whole mm items are processed + mm_embeds.append(encoder_output) + + req_start_idx += num_scheduled_tokens + + is_mm_embed = is_mm_embed[:padded_total_num_scheduled_tokens] \ + .to(self.device) + + return mm_embeds, is_mm_embed + + def _get_model_inputs( + self, + input_ids: torch.Tensor, + mm_embed_inputs: Optional[tuple[list[torch.Tensor], torch.Tensor]], + ): if self.supports_mm_inputs: + mm_embeds, is_mm_embed = mm_embed_inputs or (None, None) + # NOTE(woosuk): To unify token ids and soft tokens (vision # embeddings), we always use embeddings (rather than token ids) # as input to the multimodal model, even when the input is text. inputs_embeds = self.model.get_input_embeddings( - input_ids=input_ids, + input_ids, multimodal_embeddings=mm_embeds, + is_multimodal=is_mm_embed, ) + return None, inputs_embeds else: # For text-only models, we use token ids as input. @@ -953,9 +996,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if self.supports_mm_inputs: # Run the multimodal encoder if any. self._execute_mm_encoder(scheduler_output) - mm_embeds = self._gather_mm_embeddings(scheduler_output) + mm_embed_inputs = self._gather_mm_embeddings(scheduler_output) else: - mm_embeds = [] + mm_embed_inputs = None + torch_xla.sync(wait=False) # Prepare inputs, the requests might be split into multiple # executions, combine the result of each execution. 
@@ -972,7 +1016,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): attn_metadata, logits_indices, padded_num_reqs, num_reqs,\ end_index = self._prepare_inputs(scheduler_output, start_index) input_ids, inputs_embeds = self._get_model_inputs( - self.input_ids, mm_embeds) + self.input_ids, mm_embed_inputs) torch_xla.sync(wait=False) # Run the decoder with set_forward_context( @@ -1325,9 +1369,15 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): hf_config.image_token_index placeholders_ids = placeholders_ids.to(self.device) + + mm_mask = torch.tensor([False] * num_tokens) + mm_mask[:items_size] = True + mm_mask = mm_mask.to(self.device) # Assign outputs or the graph will be cut short. - a, b = self._get_model_inputs(placeholders_ids, - [mm_embeds]) + a, b = self._get_model_inputs( + placeholders_ids, + mm_embed_inputs=([mm_embeds], mm_mask), + ) assert a is None torch_xla.sync(wait=False) @@ -1338,7 +1388,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): dtype=torch.int32, device="cpu") placeholders_ids = placeholders_ids.to(self.device) - a, b = self._get_model_inputs(placeholders_ids, []) + a, b = self._get_model_inputs( + placeholders_ids, + mm_embed_inputs=None, + ) assert a is None torch_xla.sync(wait=False) From 3f5d902d2a6fa752925d49b8c219b9515b37c0a6 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Sat, 27 Sep 2025 06:09:26 -0400 Subject: [PATCH 16/20] Validate API tokens in constant time (#25781) Signed-off-by: rentianyue-jk Signed-off-by: Russell Bryant Co-authored-by: rentianyue-jk --- vllm/entrypoints/openai/api_server.py | 28 +++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 97cbda63bf426..d054e2826744f 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -3,12 +3,14 @@ import asyncio import gc +import 
hashlib import importlib import inspect import json import multiprocessing import multiprocessing.forkserver as forkserver import os +import secrets import signal import socket import tempfile @@ -1252,7 +1254,7 @@ def load_log_config(log_config_file: Optional[str]) -> Optional[dict]: class AuthenticationMiddleware: """ Pure ASGI middleware that authenticates each request by checking - if the Authorization header exists and equals "Bearer {api_key}". + if the Authorization Bearer token exists and equals anyof "{api_key}". Notes ----- @@ -1263,7 +1265,26 @@ class AuthenticationMiddleware: def __init__(self, app: ASGIApp, tokens: list[str]) -> None: self.app = app - self.api_tokens = {f"Bearer {token}" for token in tokens} + self.api_tokens = [ + hashlib.sha256(t.encode("utf-8")).digest() for t in tokens + ] + + def verify_token(self, headers: Headers) -> bool: + authorization_header_value = headers.get("Authorization") + if not authorization_header_value: + return False + + scheme, _, param = authorization_header_value.partition(" ") + if scheme.lower() != "bearer": + return False + + param_hash = hashlib.sha256(param.encode("utf-8")).digest() + + token_match = False + for token_hash in self.api_tokens: + token_match |= secrets.compare_digest(param_hash, token_hash) + + return token_match def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]: @@ -1276,8 +1297,7 @@ class AuthenticationMiddleware: url_path = URL(scope=scope).path.removeprefix(root_path) headers = Headers(scope=scope) # Type narrow to satisfy mypy. 
- if url_path.startswith("/v1") and headers.get( - "Authorization") not in self.api_tokens: + if url_path.startswith("/v1") and not self.verify_token(headers): response = JSONResponse(content={"error": "Unauthorized"}, status_code=401) return response(scope, receive, send) From 7977e5027c2250a4abc1f474c5619c40b4e5682f Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Sat, 27 Sep 2025 06:46:49 -0400 Subject: [PATCH 17/20] Add filtering for chat template kwargs (#25794) Signed-off-by: Isotr0py Co-authored-by: Isotr0py --- tests/entrypoints/test_chat_utils.py | 85 +++++++++++++++++++++++++ vllm/entrypoints/chat_utils.py | 54 +++++++++++++++- vllm/entrypoints/openai/api_server.py | 1 + vllm/entrypoints/openai/cli_args.py | 10 ++- vllm/entrypoints/openai/serving_chat.py | 14 +++- 5 files changed, 158 insertions(+), 6 deletions(-) diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py index 78370d199b566..a268f573ef905 100644 --- a/tests/entrypoints/test_chat_utils.py +++ b/tests/entrypoints/test_chat_utils.py @@ -19,6 +19,7 @@ from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template, parse_chat_messages, parse_chat_messages_futures, resolve_chat_template_content_format, + resolve_chat_template_kwargs, resolve_hf_chat_template) from vllm.multimodal import MultiModalDataDict, MultiModalUUIDDict from vllm.multimodal.utils import (encode_audio_base64, encode_image_base64, @@ -37,6 +38,7 @@ QWEN2AUDIO_MODEL_ID = "Qwen/Qwen2-Audio-7B-Instruct" QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct" QWEN25VL_MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct" QWEN25OMNI_MODEL_ID = "Qwen/Qwen2.5-Omni-7B" +QWEN3_MODEL_ID = "Qwen/Qwen3-8B" LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B" HERMES_MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B" MISTRAL_MODEL_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503" @@ -2255,6 +2257,89 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools): assert isinstance(chat_template, str) 
+@pytest.mark.parametrize( + "model, expected_kwargs", + [ + ( + QWEN2VL_MODEL_ID, + { + "add_vision_id", "add_generation_prompt", + "continue_final_message", "tools" + }, + ), + ( + QWEN3_MODEL_ID, + { + "enable_thinking", "add_generation_prompt", + "continue_final_message", "tools" + }, + ), + ], +) +def test_resolve_hf_chat_template_kwargs(sample_json_schema, model, + expected_kwargs): + """checks that chat_template is a dict type for HF models.""" + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + + tools = ([{ + "type": "function", + "function": { + "name": "dummy_function_name", + "description": "This is a dummy function", + "parameters": sample_json_schema, + }, + }]) + + chat_template_kwargs = { + # both unused + "unsed_kwargs_1": 123, + "unsed_kwargs_2": "abc", + # should not appear + "chat_template": "{% Hello world! %}", + # used by tokenizer + "continue_final_message": True, + "tools": tools, + # both used by Qwen2-VL and Qwen3 + "add_generation_prompt": True, + # only used by Qwen2-VL + "add_vision_id": True, + # only used by Qwen3 + "enable_thinking": True, + } + + model_config = ModelConfig( + model, + tokenizer=model_info.tokenizer or model, + tokenizer_mode=model_info.tokenizer_mode, + revision=model_info.revision, + trust_remote_code=model_info.trust_remote_code, + hf_overrides=model_info.hf_overrides, + skip_tokenizer_init=model_info.skip_tokenizer_init, + enforce_eager=model_info.enforce_eager, + dtype=model_info.dtype) + + # Build the tokenizer + tokenizer = get_tokenizer( + model, + trust_remote_code=model_config.trust_remote_code, + ) + + # Test detecting the tokenizer's chat_template + chat_template = resolve_hf_chat_template( + tokenizer, + chat_template=None, + tools=tools, + model_config=model_config, + ) + resolved_chat_template_kwargs = resolve_chat_template_kwargs( + tokenizer, + chat_template=chat_template, + chat_template_kwargs=chat_template_kwargs, + ) + assert 
set(resolved_chat_template_kwargs.keys()) == expected_kwargs + + # NOTE: Qwen2-Audio default chat template is specially defined inside # processor class instead of using `tokenizer_config.json` # yapf: disable diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 4e1ecb9ed4c51..6b0ed23277d36 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -11,7 +11,12 @@ from pathlib import Path from typing import (Any, Callable, Generic, Literal, Optional, TypeVar, Union, cast) +import jinja2 +import jinja2.ext +import jinja2.meta import jinja2.nodes +import jinja2.parser +import jinja2.sandbox import transformers.utils.chat_template_utils as hf_chat_utils # yapf conflicts with isort for this block # yapf: disable @@ -50,7 +55,7 @@ from vllm.transformers_utils.chat_templates import ( # yapf: enable from vllm.transformers_utils.processor import cached_get_processor from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer -from vllm.utils import random_uuid +from vllm.utils import random_uuid, supports_kw logger = init_logger(__name__) @@ -1554,6 +1559,46 @@ def parse_chat_messages_futures( return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids() +# adapted from https://github.com/huggingface/transformers/blob/v4.56.2/src/transformers/utils/chat_template_utils.py#L398-L412 +# only preserve the parse function used to resolve chat template kwargs +class AssistantTracker(jinja2.ext.Extension): + tags = {"generation"} + + def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock: + lineno = next(parser.stream).lineno + body = parser.parse_statements(["name:endgeneration"], drop_needle=True) + call = self.call_method("_generation_support") + call_block = jinja2.nodes.CallBlock(call, [], [], body) + return call_block.set_lineno(lineno) + + +def resolve_chat_template_kwargs( + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + chat_template: str, + 
chat_template_kwargs: dict[str, Any], +) -> dict[str, Any]: + fn_kw = { + k for k in chat_template_kwargs + if supports_kw(tokenizer.apply_chat_template, k, allow_var_kwargs=False) + } + + env = jinja2.sandbox.ImmutableSandboxedEnvironment( + trim_blocks=True, + lstrip_blocks=True, + extensions=[AssistantTracker, jinja2.ext.loopcontrols], + ) + parsed_content = env.parse(chat_template) + template_vars = jinja2.meta.find_undeclared_variables(parsed_content) + + # We exclude chat_template from kwargs here, because + # chat template has been already resolved at this stage + unexpected_vars = {"chat_template"} + accept_vars = (fn_kw | template_vars) - unexpected_vars + return { + k: v for k, v in chat_template_kwargs.items() if k in accept_vars + } + + def apply_hf_chat_template( tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], conversation: list[ConversationMessage], @@ -1579,12 +1624,17 @@ def apply_hf_chat_template( ) try: + resolved_kwargs = resolve_chat_template_kwargs( + tokenizer=tokenizer, + chat_template=hf_chat_template, + chat_template_kwargs=kwargs, + ) return tokenizer.apply_chat_template( conversation=conversation, # type: ignore[arg-type] tools=tools, # type: ignore[arg-type] chat_template=hf_chat_template, tokenize=tokenize, - **kwargs, + **resolved_kwargs, ) # External library exceptions can sometimes occur despite the framework's diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index d054e2826744f..15844d3162fe9 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -1716,6 +1716,7 @@ async def init_app_state( request_logger=request_logger, chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, + trust_request_chat_template=args.trust_request_chat_template, return_tokens_as_token_ids=args.return_tokens_as_token_ids, enable_auto_tools=args.enable_auto_tool_choice, exclude_tools_when_tool_choice_none=args. 
diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 1c2a6f58197d8..a306c2bb7cb5b 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -103,9 +103,13 @@ class FrontendArgs: chat_template_content_format: ChatTemplateContentFormatOption = "auto" """The format to render message content within a chat template. -* "string" will render the content as a string. Example: `"Hello World"` -* "openai" will render the content as a list of dictionaries, similar to OpenAI -schema. Example: `[{"type": "text", "text": "Hello world!"}]`""" + * "string" will render the content as a string. Example: `"Hello World"` + * "openai" will render the content as a list of dictionaries, similar to + OpenAI schema. Example: `[{"type": "text", "text": "Hello world!"}]`""" + trust_request_chat_template: bool = False + """Whether to trust the chat template provided in the request. If False, + the server will always use the chat template specified by `--chat-template` + or the ones from tokenizer.""" response_role: str = "assistant" """The role name to return if `request.add_generation_prompt=true`.""" ssl_keyfile: Optional[str] = None diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 0780448ad7332..ab4bf75102f43 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -68,6 +68,7 @@ class OpenAIServingChat(OpenAIServing): request_logger: Optional[RequestLogger], chat_template: Optional[str], chat_template_content_format: ChatTemplateContentFormatOption, + trust_request_chat_template: bool = False, return_tokens_as_token_ids: bool = False, reasoning_parser: str = "", enable_auto_tools: bool = False, @@ -89,6 +90,7 @@ class OpenAIServingChat(OpenAIServing): self.response_role = response_role self.chat_template = chat_template self.chat_template_content_format: Final = chat_template_content_format + 
self.trust_request_chat_template = trust_request_chat_template self.enable_log_outputs = enable_log_outputs # set up tool use @@ -220,6 +222,16 @@ class OpenAIServingChat(OpenAIServing): if not self.use_harmony: # Common case. + request_chat_template = request.chat_template + chat_template_kwargs = request.chat_template_kwargs + if not self.trust_request_chat_template and ( + request_chat_template is not None or + (chat_template_kwargs and + chat_template_kwargs.get("chat_template") is not None)): + return self.create_error_response( + "Chat template is passed with request, but " + "--trust-request-chat-template is not set. " + "Refused request with untrusted chat template.") ( conversation, request_prompts, @@ -228,7 +240,7 @@ class OpenAIServingChat(OpenAIServing): request, tokenizer, request.messages, - chat_template=request.chat_template or self.chat_template, + chat_template=request_chat_template or self.chat_template, chat_template_content_format=self. chat_template_content_format, add_generation_prompt=request.add_generation_prompt, From ec152c8748d0b37da157fa6a99a75920822dc30d Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Sat, 27 Sep 2025 13:18:20 +0100 Subject: [PATCH 18/20] Fix GPTQ model loading in Transformers backend (#25770) Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: Isotr0py --- tests/models/test_transformers.py | 10 +++++++--- vllm/model_executor/models/transformers.py | 22 +++++++++++++++++----- vllm/model_executor/models/utils.py | 7 +++++-- 3 files changed, 29 insertions(+), 10 deletions(-) diff --git a/tests/models/test_transformers.py b/tests/models/test_transformers.py index 1817d4aeee9f9..e4b5e7c244539 100644 --- a/tests/models/test_transformers.py +++ b/tests/models/test_transformers.py @@ -100,10 +100,9 @@ def test_distributed( kwargs_test=kwargs) -@pytest.mark.skipif( - current_platform.is_rocm(), - reason="bitsandbytes quantization is currently 
not supported in rocm.") @pytest.mark.parametrize("model, quantization_kwargs", [ + ("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {}), + ("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {}), ( "meta-llama/Llama-3.2-1B-Instruct", { @@ -121,6 +120,11 @@ def test_quantization( max_tokens: int, num_logprobs: int, ) -> None: + if (current_platform.is_rocm() + and quantization_kwargs.get("quantization", "") == "bitsandbytes"): + pytest.skip( + "bitsandbytes quantization is currently not supported in rocm.") + with vllm_runner( model, model_impl="auto", enforce_eager=True, **quantization_kwargs) as vllm_model: # type: ignore[arg-type] diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 3d7b06633f342..7cfb639f675d5 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -447,7 +447,8 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): self.device_config: DeviceConfig = vllm_config.device_config self.model_config: ModelConfig = vllm_config.model_config self.parallel_config: ParallelConfig = vllm_config.parallel_config - self.quant_config: QuantizationConfig = vllm_config.quant_config + self.quant_config: Optional[ + QuantizationConfig] = vllm_config.quant_config self.pp_group = get_pp_group() self.pp_size = self.pp_group.world_size @@ -456,7 +457,18 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): # Weights to skip in `self.load_weights` self.skip_prefixes: list[str] = [] + """Skip loading weights whose qualname starts with these prefixes.""" self.skip_substrs: list[str] = [] + """Skip loading weights whose qualname contains these substrings.""" + self.ignore_unexpected_prefixes: list[str] = [] + """Ignore unexpected weights whose qualname starts with these prefixes. 
+ """ + self.ignore_unexpected_suffixes: list[str] = [] + """Ignore unexpected weights whose qualname ends with these suffixes.""" + + # Skip loading extra bias for GPTQ models. + if self.quant_config and "gptq" in self.quant_config.get_name(): + self.ignore_unexpected_suffixes.append(".bias") # Set correct attn and init on "meta" to delay allocating GPU tensors # TODO: @raushan, use the public `model.set_attn_implementation()` @@ -563,9 +575,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): raise ValueError( f"{type(self.model)} does not support tensor parallel. {tip}") - def _tensor_parallel(module: nn.Module, - prefix: str = "", - tp_plan=None): + def _tensor_parallel(module: nn.Module, prefix: str, tp_plan=None): tp_plan = tp_plan or {} # If the current module is a PreTrainedModel, set the tp_plan for @@ -597,7 +607,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): prefix=qual_name, tp_plan=tp_plan) - _tensor_parallel(self.model) + _tensor_parallel(self.model, prefix="model") def create_attention_instances( self, @@ -696,6 +706,8 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): self, skip_prefixes=self.skip_prefixes, skip_substrs=self.skip_substrs, + ignore_unexpected_prefixes=self.ignore_unexpected_prefixes, + ignore_unexpected_suffixes=self.ignore_unexpected_suffixes, ) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 7b3f20c6b28a1..bb6a0bd022021 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -109,6 +109,7 @@ class AutoWeightsLoader: skip_prefixes: Optional[list[str]] = None, skip_substrs: Optional[list[str]] = None, ignore_unexpected_prefixes: Optional[list[str]] = None, + ignore_unexpected_suffixes: Optional[list[str]] = None, ) -> None: super().__init__() @@ -116,6 +117,7 @@ class AutoWeightsLoader: 
self.skip_prefixes = skip_prefixes or [] self.skip_substrs = skip_substrs or [] self.ignore_unexpected_prefixes = ignore_unexpected_prefixes or [] + self.ignore_unexpected_suffixes = ignore_unexpected_suffixes or [] # update default skip_substrs self.skip_substrs += self.ROTARY_EMBEDS_UNUSED_WEIGHTS @@ -149,8 +151,9 @@ class AutoWeightsLoader: or any(substr in qualname for substr in self.skip_substrs)) def _can_ignore_unexpected(self, qualname: str) -> bool: - return any( - qualname.startswith(p) for p in self.ignore_unexpected_prefixes) + iup = (qualname.startswith(p) for p in self.ignore_unexpected_prefixes) + ius = (qualname.endswith(s) for s in self.ignore_unexpected_suffixes) + return any(iup) or any(ius) def _load_param( self, From f9df8b4ad77a933659c93cb1b923c8e09d76ea3a Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Sat, 27 Sep 2025 10:13:11 -0400 Subject: [PATCH 19/20] [Bugfix] Fix triton import precommit failure (#25803) Signed-off-by: Tyler Michael Smith --- vllm/model_executor/layers/batch_invariant.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/batch_invariant.py b/vllm/model_executor/layers/batch_invariant.py index ae2c842af698b..c025d509d8626 100644 --- a/vllm/model_executor/layers/batch_invariant.py +++ b/vllm/model_executor/layers/batch_invariant.py @@ -7,8 +7,8 @@ from collections.abc import Callable from typing import Any, Union import torch -import triton -import triton.language as tl + +from vllm.triton_utils import tl, triton def _matmul_launch_metadata(grid: Callable[..., Any], kernel: Any, From a5354b3ed24723ac6e351896cc11e16dcee0b701 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Sat, 27 Sep 2025 10:22:28 -0400 Subject: [PATCH 20/20] [Bugfix][WideEP] Apply TP Attn + EP MoE fix to other models (#24982) Signed-off-by: Tyler Michael Smith --- vllm/config/parallel.py | 18 +++ .../device_communicators/all2all.py | 120 ++++++++++++------ .../base_device_communicator.py | 28 
++-- .../device_communicators/cuda_communicator.py | 22 ++-- .../device_communicators/xpu_communicator.py | 16 ++- vllm/distributed/parallel_state.py | 17 ++- vllm/forward_context.py | 97 +++++++++----- vllm/model_executor/layers/fused_moe/layer.py | 119 +++++++++-------- vllm/model_executor/models/aria.py | 16 +-- vllm/model_executor/models/deepseek_v2.py | 58 +-------- vllm/model_executor/models/ernie_mtp.py | 16 +-- vllm/model_executor/models/glm4.py | 16 ++- vllm/model_executor/models/gpt_oss.py | 42 ++++-- vllm/model_executor/models/granitemoe.py | 38 ++++-- vllm/model_executor/models/llama.py | 25 ++-- vllm/model_executor/models/llama4.py | 44 ++++--- vllm/model_executor/models/llama4_eagle.py | 4 +- vllm/model_executor/models/llama_eagle.py | 8 +- vllm/model_executor/models/llama_eagle3.py | 27 ++-- vllm/model_executor/models/qwen3_moe.py | 60 +++++---- vllm/model_executor/models/qwen3_next.py | 70 +++++----- vllm/model_executor/models/qwen3_next_mtp.py | 6 +- vllm/model_executor/models/utils.py | 48 ++++++- 23 files changed, 540 insertions(+), 375 deletions(-) diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index f80eb1adc7fd3..8b980458ddaff 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -279,6 +279,24 @@ class ParallelConfig: assert last_exc is not None raise last_exc + # The all_reduce at the end of attention (during o_proj) means that + # inputs are replicated across each rank of the tensor parallel group. + # If using expert-parallelism with DeepEP All2All ops, replicated + # tokens results in useless duplicate computation and communication. + # + # In this case, ensure the input to the experts is sequence parallel + # to avoid the excess work. + # + # Not needed for pplx-kernels as it can handle duplicate input tokens. 
+ @property + def use_sequence_parallel_moe(self) -> bool: + return (envs.VLLM_ALL2ALL_BACKEND + in ("allgather_reducescatter", "naive", + "deepep_high_throughput", "deepep_low_latency") + and self.enable_expert_parallel + and self.tensor_parallel_size > 1 + and self.data_parallel_size > 1) + @staticmethod def has_unfinished_dp(dp_group: ProcessGroup, has_unfinished: bool) -> bool: diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 661ed939608a0..bb3fd657facde 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -6,7 +6,7 @@ import torch import torch.distributed as dist import vllm.envs as envs -from vllm.distributed import get_dp_group +from vllm.distributed import get_dp_group, get_ep_group from vllm.forward_context import get_forward_context from vllm.logger import init_logger from vllm.utils import has_deep_ep, has_pplx @@ -34,41 +34,60 @@ class NaiveAll2AllManager(All2AllManagerBase): super().__init__(cpu_group) def naive_multicast(self, x: torch.Tensor, - cu_tokens_across_dp_cpu: torch.Tensor): + cu_tokens_across_sp_cpu: torch.Tensor, + is_sequence_parallel: bool) -> torch.Tensor: assert (len(x.shape) == 2) - buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)), + buffer = torch.empty((cu_tokens_across_sp_cpu[-1], x.size(1)), device=x.device, dtype=x.dtype) - start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[ - self.dp_rank - 1] - end = cu_tokens_across_dp_cpu[self.dp_rank] + rank = self.rank if is_sequence_parallel else self.dp_rank + world_size = (self.world_size + if is_sequence_parallel else self.dp_world_size) + + start = 0 if rank == 0 else cu_tokens_across_sp_cpu[rank - 1] + end = cu_tokens_across_sp_cpu[rank] buffer[start:end, :].copy_(x) - for idx in range(self.dp_world_size): - start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1] - end = cu_tokens_across_dp_cpu[idx] - 
self.dp_group.broadcast(buffer[start:end, :], idx) + for idx in range(world_size): + start = 0 if idx == 0 else cu_tokens_across_sp_cpu[idx - 1] + end = cu_tokens_across_sp_cpu[idx] + get_ep_group().broadcast(buffer[start:end, :], idx) return buffer - def dispatch(self, hidden_states: torch.Tensor, - router_logits: torch.Tensor): - sizes = get_forward_context( - ).dp_metadata.get_chunk_sizes_across_dp_rank() - hidden_states, router_logits = get_dp_group().all_gatherv( - [hidden_states, router_logits], - dim=0, - sizes=sizes, - ) + def dispatch( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + is_sequence_parallel: bool = False + ) -> tuple[torch.Tensor, torch.Tensor]: + sp_size = self.tp_group.world_size if is_sequence_parallel else 1 + dp_metadata = get_forward_context().dp_metadata + cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size) + hidden_states = self.naive_multicast(hidden_states, + cu_tokens_across_sp_cpu, + is_sequence_parallel) + router_logits = self.naive_multicast(router_logits, + cu_tokens_across_sp_cpu, + is_sequence_parallel) return hidden_states, router_logits - def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: - sizes = get_forward_context( - ).dp_metadata.get_chunk_sizes_across_dp_rank() - hidden_states = get_dp_group().reduce_scatterv(hidden_states, - dim=0, - sizes=sizes) + def combine(self, + hidden_states: torch.Tensor, + is_sequence_parallel: bool = False) -> torch.Tensor: + + ep_rank = self.rank if is_sequence_parallel else self.dp_rank + + dp_metadata = get_forward_context().dp_metadata + sp_size = self.tp_group.world_size if is_sequence_parallel else 1 + cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size) + + start = 0 if ep_rank == 0 else cu_tokens_across_sp_cpu[ep_rank - 1] + end = cu_tokens_across_sp_cpu[ep_rank] + + all_hidden_states = get_ep_group().all_reduce(hidden_states) + hidden_states = all_hidden_states[start:end, :] return hidden_states def destroy(self): @@ 
-84,29 +103,40 @@ class AgRsAll2AllManager(All2AllManagerBase): def __init__(self, cpu_group): super().__init__(cpu_group) - def dispatch(self, hidden_states: torch.Tensor, - router_logits: torch.Tensor): + def dispatch( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + is_sequence_parallel: bool = False + ) -> tuple[torch.Tensor, torch.Tensor]: """ Gather hidden_states and router_logits from all dp ranks. """ sizes = get_forward_context( ).dp_metadata.get_chunk_sizes_across_dp_rank() - hidden_states, router_logits = get_dp_group().all_gatherv( + + dist_group = get_ep_group() if is_sequence_parallel else get_dp_group() + assert sizes[dist_group.rank_in_group] == hidden_states.shape[0] + hidden_states, router_logits = dist_group.all_gatherv( [hidden_states, router_logits], dim=0, sizes=sizes, ) return hidden_states, router_logits - def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: + def combine(self, + hidden_states: torch.Tensor, + is_sequence_parallel: bool = False) -> torch.Tensor: """ Reduce-scatter hidden_states across all dp ranks. 
""" sizes = get_forward_context( ).dp_metadata.get_chunk_sizes_across_dp_rank() - hidden_states = get_dp_group().reduce_scatterv(hidden_states, - dim=0, - sizes=sizes) + + dist_group = get_ep_group() if is_sequence_parallel else get_dp_group() + hidden_states = dist_group.reduce_scatterv(hidden_states, + dim=0, + sizes=sizes) return hidden_states def destroy(self): @@ -148,11 +178,17 @@ class PPLXAll2AllManager(All2AllManagerBase): kwargs, pplx.AllToAll.internode if self.internode else pplx.AllToAll.intranode) - def dispatch(self, hidden_states: torch.Tensor, - router_logits: torch.Tensor): + def dispatch( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + is_sequence_parallel: bool = False + ) -> tuple[torch.Tensor, torch.Tensor]: raise NotImplementedError - def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: + def combine(self, + hidden_states: torch.Tensor, + is_sequence_parallel: bool = False) -> torch.Tensor: raise NotImplementedError def destroy(self): @@ -184,11 +220,17 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase): def get_handle(self, kwargs): raise NotImplementedError - def dispatch(self, hidden_states: torch.Tensor, - router_logits: torch.Tensor): + def dispatch( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + is_sequence_parallel: bool = False + ) -> tuple[torch.Tensor, torch.Tensor]: raise NotImplementedError - def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: + def combine(self, + hidden_states: torch.Tensor, + is_sequence_parallel: bool = False) -> torch.Tensor: raise NotImplementedError def destroy(self): @@ -395,4 +437,4 @@ class FlashInferAllToAllManager(All2AllManagerBase): self.workspace_tensor = None self.prepare_workspace_tensor = None self.mapping = None - self.initialized = False \ No newline at end of file + self.initialized = False diff --git a/vllm/distributed/device_communicators/base_device_communicator.py 
b/vllm/distributed/device_communicators/base_device_communicator.py index 586441c917830..a42081fb0c158 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -28,6 +28,8 @@ class Cache: class All2AllManagerBase: + rank: int + world_size: int def __init__(self, cpu_group): self.cpu_group = cpu_group @@ -40,6 +42,7 @@ class All2AllManagerBase: # all2all lives in ep group, which is merged from dp and tp group self.dp_group = get_dp_group() self.tp_group = get_tp_group() + # no self.ep_group since self.ep_group is still in construction # when we create this object self.dp_rank = self.dp_group.rank_in_group @@ -60,17 +63,21 @@ class All2AllManagerBase: # and reuse it for the same config. raise NotImplementedError + def dispatch(self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + is_sequence_parallel: bool = False): + raise NotImplementedError + def set_num_sms(self, num_sms: int): pass def max_sms_used(self) -> Optional[int]: return None # None means it could use the whole GPU - def dispatch(self, hidden_states: torch.Tensor, - router_logits: torch.Tensor): - raise NotImplementedError - - def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: + def combine(self, + hidden_states: torch.Tensor, + is_sequence_parallel: bool = False): raise NotImplementedError def destroy(self): @@ -267,15 +274,20 @@ class DeviceCommunicatorBase: module.quant_method.init_prepare_finalize(module) def dispatch( - self, hidden_states: torch.Tensor, - router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + is_sequence_parallel: bool = False + ) -> tuple[torch.Tensor, torch.Tensor]: """ Dispatch the hidden states and router logits to the appropriate device. This is a no-op in the base class. 
""" return hidden_states, router_logits - def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: + def combine(self, + hidden_states: torch.Tensor, + is_sequence_parallel: bool = False) -> torch.Tensor: """ Combine the hidden states and router logits from the appropriate device. This is a no-op in the base class. diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index bab372b722dbb..30d1bf10138bb 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -39,10 +39,6 @@ class CudaCommunicator(DeviceCommunicatorBase): use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE use_torch_symm_mem = envs.VLLM_ALLREDUCE_USE_SYMM_MEM - # ep does not use pynccl - use_pynccl = "ep" not in unique_name - - self.use_pynccl = use_pynccl self.use_custom_allreduce = use_custom_allreduce self.use_torch_symm_mem = use_torch_symm_mem @@ -57,7 +53,7 @@ class CudaCommunicator(DeviceCommunicatorBase): SymmMemCommunicator) self.pynccl_comm: Optional[PyNcclCommunicator] = None - if use_pynccl and self.world_size > 1: + if self.world_size > 1: self.pynccl_comm = PyNcclCommunicator( group=self.cpu_group, device=self.device, @@ -308,14 +304,20 @@ class CudaCommunicator(DeviceCommunicatorBase): return output_list def dispatch( - self, hidden_states: torch.Tensor, - router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + is_sequence_parallel: bool = False + ) -> tuple[torch.Tensor, torch.Tensor]: assert self.all2all_manager is not None hidden_states, router_logits = self.all2all_manager.dispatch( - hidden_states, router_logits) + hidden_states, router_logits, is_sequence_parallel) return hidden_states, router_logits - def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: + def combine(self, + hidden_states: torch.Tensor, + 
is_sequence_parallel: bool = False) -> torch.Tensor: assert self.all2all_manager is not None - hidden_states = self.all2all_manager.combine(hidden_states) + hidden_states = self.all2all_manager.combine(hidden_states, + is_sequence_parallel) return hidden_states diff --git a/vllm/distributed/device_communicators/xpu_communicator.py b/vllm/distributed/device_communicators/xpu_communicator.py index b236bae261e03..27bd176554afa 100644 --- a/vllm/distributed/device_communicators/xpu_communicator.py +++ b/vllm/distributed/device_communicators/xpu_communicator.py @@ -75,14 +75,20 @@ class XpuCommunicator(DeviceCommunicatorBase): dist.broadcast(input_, src=src, group=self.device_group) def dispatch( - self, hidden_states: torch.Tensor, - router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + is_sequence_parallel: bool = False + ) -> tuple[torch.Tensor, torch.Tensor]: assert self.all2all_manager is not None hidden_states, router_logits = self.all2all_manager.dispatch( - hidden_states, router_logits) + hidden_states, router_logits, is_sequence_parallel) return hidden_states, router_logits - def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: + def combine(self, + hidden_states: torch.Tensor, + is_sequence_parallel: bool = False) -> torch.Tensor: assert self.all2all_manager is not None - hidden_states = self.all2all_manager.combine(hidden_states) + hidden_states = self.all2all_manager.combine(hidden_states, + is_sequence_parallel) return hidden_states diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 69f98eb54f36c..638170963e2b1 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -871,17 +871,24 @@ class GroupCoordinator: model) def dispatch( - self, hidden_states: torch.Tensor, - router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + self, + hidden_states: torch.Tensor, + router_logits: 
torch.Tensor, + is_sequence_parallel: bool = False + ) -> tuple[torch.Tensor, torch.Tensor]: if self.device_communicator is not None: return self.device_communicator.dispatch(hidden_states, - router_logits) + router_logits, + is_sequence_parallel) else: return hidden_states, router_logits - def combine(self, hidden_states) -> torch.Tensor: + def combine(self, + hidden_states, + is_sequence_parallel: bool = False) -> torch.Tensor: if self.device_communicator is not None: - return self.device_communicator.combine(hidden_states) + return self.device_communicator.combine(hidden_states, + is_sequence_parallel) else: return hidden_states diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 2bf4e18045211..09defade00dc1 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -49,16 +49,29 @@ class BatchDescriptor(NamedTuple): return BatchDescriptor(self.num_tokens, uniform_decode=False) -def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int], +def _compute_sp_num_tokens(num_tokens_across_dp_cpu: torch.Tensor, + sequence_parallel_size: int) -> list[int]: + sp_tokens = ((num_tokens_across_dp_cpu + sequence_parallel_size - 1) // + sequence_parallel_size) + + sp_tokens = sp_tokens.repeat_interleave(sequence_parallel_size) + return sp_tokens.tolist() + + +def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: torch.Tensor, + sequence_parallel_size: int, max_num_tokens: int, chunk_idx: int) -> list[int]: - dp_size = len(num_tokens_across_dp_cpu) - local_size = [-1] * dp_size - for i in range(dp_size): - dp_tokens = num_tokens_across_dp_cpu[i] + sp_tokens = _compute_sp_num_tokens(num_tokens_across_dp_cpu, + sequence_parallel_size) + sp_size = len(sp_tokens) + + local_size = [-1] * sp_size + for i in range(sp_size): + # Take into account sharding if MoE activation is sequence parallel. 
local_size[i] = min(max_num_tokens, - dp_tokens - (max_num_tokens * chunk_idx)) + sp_tokens[i] - (max_num_tokens * chunk_idx)) if local_size[i] <= 0: local_size[i] = 1 # ensure lockstep even if done return local_size @@ -67,7 +80,9 @@ def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int], @dataclass class DPMetadata: max_tokens_across_dp_cpu: torch.Tensor - cu_tokens_across_dp_cpu: torch.Tensor + num_tokens_across_dp_cpu: torch.Tensor + + # NOTE: local_sizes should only be set by the chunked_sizes context manager local_sizes: Optional[list[int]] = None @staticmethod @@ -98,6 +113,17 @@ class DPMetadata: dist.all_reduce(num_tokens_tensor, group=group) return num_tokens_tensor.cpu() + # Get the cumulative tokens across sequence parallel ranks. + # In this case the input to the MoEs will be distributed w.r.t both + # DP and TP rank. + # When sp_size==1, this is just the cummulative num tokens across DP. + def cu_tokens_across_sp(self, sp_size: int) -> torch.Tensor: + num_tokens_across_sp_cpu = ( + (self.num_tokens_across_dp_cpu - 1 + sp_size) // sp_size) + num_tokens_across_sp_cpu = ( + num_tokens_across_sp_cpu.repeat_interleave(sp_size)) + return torch.cumsum(num_tokens_across_sp_cpu, dim=0) + @staticmethod def should_ubatch_across_dp( should_ubatch: bool, orig_num_tokens_per_ubatch: int, @@ -147,10 +173,10 @@ class DPMetadata: @staticmethod def make( - parallel_config: ParallelConfig, - attn_metadata: Any, - num_tokens: int, - num_tokens_across_dp: Optional[torch.Tensor] = None + parallel_config: ParallelConfig, + attn_metadata: Any, + num_tokens: int, + num_tokens_across_dp_cpu: Optional[torch.Tensor] = None ) -> "DPMetadata": assert parallel_config.data_parallel_size > 1 @@ -167,18 +193,18 @@ class DPMetadata: # If num_tokens_across_dp is None, it will be computed by all_reduce # Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize - assert (num_tokens_across_dp is None or num_tokens_across_dp[dp_rank] - == batchsize), 
f"{num_tokens_across_dp[dp_rank]} {batchsize}" - if num_tokens_across_dp is None: - num_tokens_across_dp = DPMetadata.num_tokens_across_dp( + assert (num_tokens_across_dp_cpu is None + or num_tokens_across_dp_cpu[dp_rank] == batchsize + ), f"{num_tokens_across_dp_cpu[dp_rank]} {batchsize}" + if num_tokens_across_dp_cpu is None: + num_tokens_across_dp_cpu = DPMetadata.num_tokens_across_dp( batchsize, dp_size, dp_rank) - max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp) - cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_across_dp, dim=0) - return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu, - num_tokens_across_dp) + max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp_cpu) + return DPMetadata(max_tokens_across_dp_cpu, num_tokens_across_dp_cpu) @contextmanager - def chunked_sizes(self, max_chunk_size_per_rank: int, chunk_idx: int): + def chunked_sizes(self, sequence_parallel_size: int, + max_chunk_size_per_rank: int, chunk_idx: int): """ Context manager to compute and temporarily set the per-rank local token sizes for a specific chunk during chunked forward execution. @@ -192,31 +218,40 @@ class DPMetadata: `chunk_idx`, this context manager sets `self.local_sizes` to the number of tokens to process in that chunk on each rank. - It uses cumulative sizes (`cu_tokens_across_dp_cpu`) to derive the - number of tokens per rank, and calls `_compute_chunked_local_num_tokens` - to determine the chunk-wise split. - `self.local_sizes` is only valid inside the context. Args: + sequence_parallel_size: When Attn is TP and MoE layers are EP, + we use SP between the layers to avoid + redundant ops. We need this value to + compute the chunked sizes. max_chunk_size_per_rank: The max number of tokens each rank is allowed to process in this chunk. chunk_idx: The index of the chunk to compute sizes for. 
""" - cu_sizes = self.cu_tokens_across_dp_cpu - num_tokens_across_dp_cpu = [ - (cu_sizes[i] - - cu_sizes[i - 1]).item() if i > 0 else cu_sizes[0].item() - for i in range(len(cu_sizes)) - ] self.local_sizes = _compute_chunked_local_num_tokens( - num_tokens_across_dp_cpu, max_chunk_size_per_rank, chunk_idx) + self.num_tokens_across_dp_cpu, sequence_parallel_size, + max_chunk_size_per_rank, chunk_idx) + try: + yield self.local_sizes + finally: + self.local_sizes = None + + @contextmanager + def sp_local_sizes(self, sequence_parallel_size: int): + """ + Context mamager for setting self.local_sizes. Same as self.chunked_sizes + but without any chunking. + """ + self.local_sizes = _compute_sp_num_tokens( + self.num_tokens_across_dp_cpu, sequence_parallel_size) try: yield self.local_sizes finally: self.local_sizes = None def get_chunk_sizes_across_dp_rank(self) -> Optional[list[int]]: + assert self.local_sizes is not None return self.local_sizes diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index eccae8b2a7af4..8de1d14d46b33 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -3,6 +3,7 @@ from abc import abstractmethod from collections.abc import Iterable +from contextlib import nullcontext from enum import Enum from typing import Callable, Literal, Optional, Union, get_args, overload @@ -983,8 +984,7 @@ class FusedMoE(CustomOp): if dp_size is not None else get_dp_group().world_size) self.is_sequence_parallel = is_sequence_parallel - if self.is_sequence_parallel: - self.sp_size = tp_size_ + self.sp_size = tp_size_ if is_sequence_parallel else 1 self.moe_parallel_config: FusedMoEParallelConfig = ( FusedMoEParallelConfig.make( @@ -1966,7 +1966,8 @@ class FusedMoE(CustomOp): # clamp start and end chunk_start = min(chunk_start, num_tokens - 1) chunk_end = min(chunk_end, num_tokens) - with ctx.dp_metadata.chunked_sizes(moe_dp_chunk_size_per_rank, + with 
ctx.dp_metadata.chunked_sizes(self.sp_size, + moe_dp_chunk_size_per_rank, chunk_idx): process_chunk(chunk_start, chunk_end, @@ -2011,65 +2012,73 @@ class FusedMoE(CustomOp): else: shared_output = None - if do_naive_dispatch_combine: - hidden_states, router_logits = get_ep_group().dispatch( - hidden_states, router_logits) + ctx = get_forward_context() + sp_ctx = ctx.dp_metadata.sp_local_sizes( + self.sp_size) if ctx.dp_metadata else nullcontext() - # Matrix multiply. - final_hidden_states = self.quant_method.apply( - layer=self, - x=hidden_states, - router_logits=router_logits, - top_k=self.top_k, - renormalize=self.renormalize, - use_grouped_topk=self.use_grouped_topk, - global_num_experts=self.global_num_experts, - expert_map=self.expert_map, - topk_group=self.topk_group, - num_expert_group=self.num_expert_group, - custom_routing_function=self.custom_routing_function, - scoring_func=self.scoring_func, - routed_scaling_factor=self.routed_scaling_factor, - e_score_correction_bias=self.e_score_correction_bias, - activation=self.activation, - apply_router_weight_on_input=self.apply_router_weight_on_input, - enable_eplb=self.enable_eplb, - expert_load_view=self.expert_load_view, - logical_to_physical_map=self.logical_to_physical_map, - logical_replica_count=self.logical_replica_count, - ) + with sp_ctx: + if do_naive_dispatch_combine: + hidden_states, router_logits = get_ep_group().dispatch( + hidden_states, router_logits, self.is_sequence_parallel) - if shared_output is not None: - assert not isinstance(final_hidden_states, tuple) - assert self.shared_experts is not None - final_hidden_states = ( - shared_output, - final_hidden_states, + # Matrix multiply. 
+ final_hidden_states = self.quant_method.apply( + layer=self, + x=hidden_states, + router_logits=router_logits, + top_k=self.top_k, + renormalize=self.renormalize, + use_grouped_topk=self.use_grouped_topk, + global_num_experts=self.global_num_experts, + expert_map=self.expert_map, + topk_group=self.topk_group, + num_expert_group=self.num_expert_group, + custom_routing_function=self.custom_routing_function, + scoring_func=self.scoring_func, + routed_scaling_factor=self.routed_scaling_factor, + e_score_correction_bias=self.e_score_correction_bias, + activation=self.activation, + apply_router_weight_on_input=self.apply_router_weight_on_input, + enable_eplb=self.enable_eplb, + expert_load_view=self.expert_load_view, + logical_to_physical_map=self.logical_to_physical_map, + logical_replica_count=self.logical_replica_count, ) - elif self.zero_expert_num is not None and self.zero_expert_num > 0: - assert isinstance(final_hidden_states, tuple) - final_hidden_states, zero_expert_result = final_hidden_states - def reduce_output(states: torch.Tensor, - do_combine: bool = True) -> torch.Tensor: - if do_naive_dispatch_combine and do_combine: - states = get_ep_group().combine(states) + if shared_output is not None: + assert not isinstance(final_hidden_states, tuple) + assert self.shared_experts is not None + final_hidden_states = ( + shared_output, + final_hidden_states, + ) + elif self.zero_expert_num is not None and self.zero_expert_num > 0: + assert isinstance(final_hidden_states, tuple) + final_hidden_states, zero_expert_result = final_hidden_states - if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): - states = self.maybe_all_reduce_tensor_model_parallel(states) + def reduce_output(states: torch.Tensor, + do_combine: bool = True) -> torch.Tensor: + if do_naive_dispatch_combine and do_combine: + states = get_ep_group().combine(states, + self.is_sequence_parallel) - return states + if (not self.is_sequence_parallel and self.reduce_results + and (self.tp_size 
> 1 or self.ep_size > 1)): + states = self.maybe_all_reduce_tensor_model_parallel( + states) - if self.shared_experts is not None: - return ( - reduce_output(final_hidden_states[0], do_combine=False), - reduce_output(final_hidden_states[1]), - ) - elif self.zero_expert_num is not None and self.zero_expert_num > 0: - assert isinstance(final_hidden_states, torch.Tensor) - return reduce_output(final_hidden_states) + zero_expert_result - else: - return reduce_output(final_hidden_states) + return states + + if self.shared_experts is not None: + return ( + reduce_output(final_hidden_states[0], do_combine=False), + reduce_output(final_hidden_states[1]), + ) + elif self.zero_expert_num is not None and self.zero_expert_num > 0: + assert isinstance(final_hidden_states, torch.Tensor) + return reduce_output(final_hidden_states) + zero_expert_result + else: + return reduce_output(final_hidden_states) @classmethod def make_expert_params_mapping( diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index 6cef5e134a4bc..e0d7af0b1c3e6 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -9,7 +9,7 @@ from transformers import AriaConfig, AriaTextConfig, BatchFeature from transformers.models.aria.modeling_aria import AriaCrossAttention from transformers.models.aria.processing_aria import AriaProcessor -from vllm.config import CacheConfig, QuantizationConfig, VllmConfig +from vllm.config import QuantizationConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_rank from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.fused_moe import FusedMoE @@ -297,14 +297,12 @@ class AriaTextDecoderLayer(LlamaDecoderLayer): Experts (MoE) Layer. 
""" - def __init__( - self, - config: AriaTextConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - super().__init__(config, cache_config, quant_config, prefix) + def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: + super().__init__(vllm_config, prefix) + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.mlp = AriaTextMoELayer(config, quant_config=quant_config, prefix=f"{prefix}.mlp") diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index aab522390a7a8..2e0bcbe5d2e57 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -32,7 +32,6 @@ import torch from torch import nn from transformers import DeepseekV2Config, DeepseekV3Config -import vllm.envs as envs from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ParallelConfig, VllmConfig @@ -56,8 +55,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.models.utils import sequence_parallel_chunk from vllm.sequence import IntermediateTensors -from vllm.utils import cdiv, direct_register_custom_op from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP from .utils import (PPMissingLayer, is_pp_missing_parameter, @@ -108,43 +107,6 @@ class DeepseekV2MLP(nn.Module): return x -# Chunk x along the num_tokens axis for sequence parallelism -# NOTE: This is wrapped in a torch custom op to work around the following issue: -# The output tensor can have a sequence length 0 at small input sequence lengths -# even though we explicitly pad to avoid this. 
-def sequence_parallel_chunk(x: torch.Tensor) -> torch.Tensor: - tp_size = get_tensor_model_parallel_world_size() - tp_rank = get_tensor_model_parallel_rank() - - # all_gather needs the sequence length to be divisible by tp_size - seq_len = x.size(0) - remainder = seq_len % tp_size - if remainder != 0: - pad_len = tp_size - remainder - x = nn.functional.pad(x, (0, 0, 0, pad_len)) - - chunk = x.shape[0] // tp_size - start = tp_rank * chunk - return torch.narrow(x, 0, start, chunk) - - -def sequence_parallel_chunk_fake(x: torch.Tensor) -> torch.Tensor: - tp_size = get_tensor_model_parallel_world_size() - seq_len = cdiv(x.size(0), tp_size) - shape = list(x.shape) - shape[0] = seq_len - out = torch.empty(shape, dtype=x.dtype, device=x.device) - return out - - -direct_register_custom_op( - op_name="sequence_parallel_chunk", - op_func=sequence_parallel_chunk, - fake_impl=sequence_parallel_chunk_fake, - tags=(torch.Tag.needs_fixed_stride_order, ), -) - - class DeepseekV2MoE(nn.Module): def __init__( @@ -166,20 +128,7 @@ class DeepseekV2MoE(nn.Module): self.n_routed_experts: int = config.n_routed_experts self.n_shared_experts: int = config.n_shared_experts - # The all_reduce at the end of attention (during o_proj) means that - # inputs are replicated across each rank of the tensor parallel group. - # If using expert-parallelism with DeepEP All2All ops, replicated - # tokens results in useless duplicate computation and communication. - # - # In this case, ensure the input to the experts is sequence parallel - # to avoid the excess work. - # - # Not needed for pplx-kernels as it can handle duplicate input tokens. - self.is_sequence_parallel = (envs.VLLM_ALL2ALL_BACKEND - in ("deepep_high_throughput", - "deepep_low_latency") - and parallel_config.enable_expert_parallel - and self.tp_size > 1) + self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe if config.hidden_act != "silu": raise ValueError(f"Unsupported activation: {config.hidden_act}. 
" @@ -278,8 +227,7 @@ class DeepseekV2MoE(nn.Module): # TODO: We can replace the all_reduce at the end of attn with a # reduce_scatter instead of chunking here. if self.is_sequence_parallel: - hidden_states = torch.ops.vllm.sequence_parallel_chunk( - hidden_states) + hidden_states = sequence_parallel_chunk(hidden_states) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) diff --git a/vllm/model_executor/models/ernie_mtp.py b/vllm/model_executor/models/ernie_mtp.py index 3b24bf2f1ef8f..2e6ef2d476a67 100644 --- a/vllm/model_executor/models/ernie_mtp.py +++ b/vllm/model_executor/models/ernie_mtp.py @@ -29,10 +29,9 @@ import torch import torch.nn as nn from transformers import PretrainedConfig -from vllm.config import CacheConfig, ModelConfig, VllmConfig +from vllm.config import VllmConfig from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -47,13 +46,11 @@ class ErnieMultiTokenPredictorLayer(nn.Module): def __init__( self, - config: PretrainedConfig, + vllm_config: VllmConfig, prefix: str, - model_config: ModelConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() + config = vllm_config.model_config.hf_config self.mtp_emb_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -62,8 +59,7 @@ class ErnieMultiTokenPredictorLayer(nn.Module): self.mtp_linear_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False) - self.mtp_block = LlamaDecoderLayer(config, cache_config, quant_config, - prefix) + self.mtp_block = LlamaDecoderLayer(vllm_config, prefix) def forward( self, @@ -102,10 +98,8 @@ class 
ErnieMultiTokenPredictor(nn.Module): self.layers = torch.nn.ModuleDict({ str(idx): ErnieMultiTokenPredictorLayer( - config, + vllm_config, f"{prefix}.layers.{idx}", - model_config=vllm_config.model_config, - cache_config=vllm_config.cache_config, ) for idx in range(self.mtp_start_layer_idx, self.mtp_start_layer_idx + self.num_mtp_layers) diff --git a/vllm/model_executor/models/glm4.py b/vllm/model_executor/models/glm4.py index b9d5e24e9f6fa..f49f21a40f824 100644 --- a/vllm/model_executor/models/glm4.py +++ b/vllm/model_executor/models/glm4.py @@ -136,14 +136,16 @@ class Glm4Attention(nn.Module): class Glm4DecoderLayer(nn.Module): - def __init__( - self, - config: Glm4Config, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: + def __init__(self, + vllm_config: VllmConfig, + prefix: str = "", + config: Optional[Glm4Config] = None) -> None: super().__init__() + + config = config or vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 1000000) rope_scaling = getattr(config, "rope_scaling", None) diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index 7c755a00e1c98..47ba5084d6083 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -13,7 +13,8 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_ep_group, get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather) from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (QKVParallelLinear, @@ -24,6 +25,7 @@ from 
vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.utils import sequence_parallel_chunk from vllm.sequence import IntermediateTensors from vllm.utils import cdiv @@ -132,12 +134,18 @@ class MLPBlock(torch.nn.Module): def __init__( self, - config: GptOssConfig, + vllm_config: VllmConfig, layer_idx: int, - quant_config: QuantizationConfig, prefix: str = "", ): super().__init__() + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + parallel_config = vllm_config.parallel_config + + self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe + self.layer_idx = layer_idx self.num_experts = config.num_local_experts self.experts_per_token = config.num_experts_per_tok @@ -155,11 +163,20 @@ class MLPBlock(torch.nn.Module): prefix=f"{prefix}.experts", apply_router_weight_on_input=False, has_bias=True, - activation="swigluoai") + activation="swigluoai", + is_sequence_parallel=self.is_sequence_parallel) def forward(self, x: torch.Tensor) -> torch.Tensor: + num_tokens = x.shape[0] + if self.is_sequence_parallel: + x = sequence_parallel_chunk(x) + g = self.router(x) x = self.experts(hidden_states=x, router_logits=g) + + if self.is_sequence_parallel: + x = tensor_model_parallel_all_gather(x.contiguous(), 0) + x = x[:num_tokens] return x @@ -167,19 +184,20 @@ class TransformerBlock(torch.nn.Module): def __init__( self, - config: GptOssConfig, - cache_config: CacheConfig, - quant_config: QuantizationConfig, + vllm_config: VllmConfig, prefix: str = "", ): super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + self.layer_idx = extract_layer_index(prefix) self.attn = OAIAttention(config, prefix=f"{prefix}.attn", cache_config=cache_config) - self.mlp = 
MLPBlock(config, + self.mlp = MLPBlock(vllm_config, self.layer_idx, - quant_config=quant_config, prefix=f"{prefix}.mlp") self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5) self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5) @@ -216,8 +234,6 @@ class GptOssModel(nn.Module): ): super().__init__() self.config = vllm_config.model_config.hf_config - self.cache_config = vllm_config.cache_config - self.quant_config = vllm_config.quant_config self.parallel_config = vllm_config.parallel_config self.config.hidden_size = self.config.hidden_size self.embedding = VocabParallelEmbedding( @@ -227,9 +243,7 @@ class GptOssModel(nn.Module): self.start_layer, self.end_layer, self.layers = make_layers( self.config.num_hidden_layers, lambda prefix: TransformerBlock( - self.config, - cache_config=self.cache_config, - quant_config=self.quant_config, + vllm_config, prefix=prefix, ), prefix=f"{prefix}.layers", diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py index 47ac22c4aeaa5..76a5745a4f51b 100644 --- a/vllm/model_executor/models/granitemoe.py +++ b/vllm/model_executor/models/granitemoe.py @@ -29,12 +29,13 @@ from typing import Any, Optional import torch from torch import nn -from transformers.models.granitemoe import GraniteMoeConfig from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.distributed import (get_pp_group, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather) from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (QKVParallelLinear, @@ -48,6 +49,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from 
vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.models.utils import sequence_parallel_chunk from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP @@ -71,9 +73,11 @@ class GraniteMoeMoE(nn.Module): params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, tp_size: Optional[int] = None, + is_sequence_parallel=False, prefix: str = ""): super().__init__() self.hidden_size = hidden_size + self.is_sequence_parallel = is_sequence_parallel # Gate always runs at half / full precision for now. self.gate = ReplicatedLinear(hidden_size, @@ -92,15 +96,27 @@ class GraniteMoeMoE(nn.Module): renormalize=True, quant_config=quant_config, tp_size=tp_size, - prefix=f"{prefix}.experts") + prefix=f"{prefix}.experts", + is_sequence_parallel=self.is_sequence_parallel) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # NOTE: hidden_states can have either 1D or 2D shape. 
orig_shape = hidden_states.shape hidden_states = hidden_states.view(-1, self.hidden_size) + + if self.is_sequence_parallel: + hidden_states = sequence_parallel_chunk(hidden_states) + # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) final_hidden_states = self.experts(hidden_states, router_logits) + + if self.is_sequence_parallel: + final_hidden_states = tensor_model_parallel_all_gather( + final_hidden_states, 0) + num_tokens = orig_shape[0] + final_hidden_states = final_hidden_states[:num_tokens] + return final_hidden_states.view(orig_shape) @@ -191,12 +207,16 @@ class GraniteMoeDecoderLayer(nn.Module): def __init__( self, - config: GraniteMoeConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + vllm_config: VllmConfig, prefix: str = "", ) -> None: super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + parallel_config = vllm_config.parallel_config + self.hidden_size = config.hidden_size # Requires transformers > 4.32.0 rope_theta = getattr(config, "rope_theta", 10000) @@ -218,6 +238,7 @@ class GraniteMoeDecoderLayer(nn.Module): hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, quant_config=quant_config, + is_sequence_parallel=parallel_config.use_sequence_parallel_moe, prefix=f"{prefix}.block_sparse_moe") self.input_layernorm = RMSNorm(config.hidden_size, @@ -255,7 +276,6 @@ class GraniteMoeModel(nn.Module): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config @@ -275,9 +295,7 @@ class GraniteMoeModel(nn.Module): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: GraniteMoeDecoderLayer( - config, cache_config, quant_config=quant_config, prefix=prefix - ), + lambda prefix: 
GraniteMoeDecoderLayer(vllm_config, prefix=prefix), prefix=f"{prefix}.layers") self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 1b03cbef501b3..c7dd134ea47e9 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -68,6 +68,7 @@ class LlamaMLP(nn.Module): bias: bool = False, prefix: str = "", reduce_results: bool = True, + disable_tp: bool = False, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( @@ -75,6 +76,7 @@ class LlamaMLP(nn.Module): output_sizes=[intermediate_size] * 2, bias=bias, quant_config=quant_config, + disable_tp=disable_tp, prefix=f"{prefix}.gate_up_proj", ) self.down_proj = RowParallelLinear( @@ -83,6 +85,7 @@ class LlamaMLP(nn.Module): bias=bias, quant_config=quant_config, reduce_results=reduce_results, + disable_tp=disable_tp, prefix=f"{prefix}.down_proj", ) if hidden_act != "silu": @@ -237,14 +240,16 @@ class LlamaAttention(nn.Module): class LlamaDecoderLayer(nn.Module): - def __init__( - self, - config: LlamaConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: + def __init__(self, + vllm_config: VllmConfig, + prefix: str = "", + config: Optional[LlamaConfig] = None) -> None: super().__init__() + + config = config or vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) @@ -335,7 +340,6 @@ class LlamaModel(nn.Module): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config @@ -357,10 +361,7 @@ class LlamaModel(nn.Module): self.embed_tokens = PPMissingLayer() 
self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: layer_type(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix), + lambda prefix: layer_type(vllm_config=vllm_config, prefix=prefix), prefix=f"{prefix}.layers", ) if get_pp_group().is_last_rank: diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py index ddd7e6a5936e3..32d4f69c6bf18 100644 --- a/vllm/model_executor/models/llama4.py +++ b/vllm/model_executor/models/llama4.py @@ -28,7 +28,8 @@ from vllm.attention import Attention from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import (get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather) from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (QKVParallelLinear, @@ -39,6 +40,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.models.utils import sequence_parallel_chunk from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel from .utils import (AutoWeightsLoader, extract_layer_index, fast_topk, @@ -59,13 +61,16 @@ class Llama4MoE(nn.Module): router_scores = torch.sigmoid(router_scores.float()) return (router_scores, router_indices.to(torch.int32)) - def __init__(self, - config: Llama4TextConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__(self, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + + config = 
vllm_config.model_config.hf_config + parallel_config = vllm_config.parallel_config + quant_config = vllm_config.quant_config + self.tp_size = get_tensor_model_parallel_world_size() self.top_k = config.num_experts_per_tok + self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe intermediate_size_moe = config.intermediate_size self.router = ReplicatedLinear(config.hidden_size, @@ -82,6 +87,7 @@ class Llama4MoE(nn.Module): bias=False, prefix=f"{prefix}.shared_expert", reduce_results=False, + disable_tp=self.is_sequence_parallel, ) self.experts = SharedFusedMoE( @@ -96,9 +102,14 @@ class Llama4MoE(nn.Module): renormalize=False, quant_config=quant_config, prefix=f"{prefix}.experts", + is_sequence_parallel=self.is_sequence_parallel, ) def forward(self, hidden_states): + num_tokens = hidden_states.shape[0] + if self.is_sequence_parallel: + hidden_states = sequence_parallel_chunk(hidden_states) + router_logits, _ = self.router(hidden_states) shared_out, routed_out = self.experts( @@ -107,7 +118,10 @@ class Llama4MoE(nn.Module): ) experts_out = routed_out + shared_out - if self.tp_size > 1: + if self.is_sequence_parallel: + experts_out = tensor_model_parallel_all_gather(experts_out, 0) + experts_out = experts_out[:num_tokens] + elif self.tp_size > 1: experts_out = self.experts.maybe_all_reduce_tensor_model_parallel( experts_out) @@ -257,15 +271,16 @@ class Llama4Attention(nn.Module): class Llama4DecoderLayer(nn.Module): - def __init__( - self, - config: Llama4TextConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: + def __init__(self, + vllm_config: VllmConfig, + prefix: str = "", + config: Optional[Llama4TextConfig] = None) -> None: super().__init__() + config = config or vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.layer_idx = extract_layer_index(prefix) self.global_layer = 
config.no_rope_layers[self.layer_idx] == 0 self.hidden_size = config.hidden_size @@ -291,8 +306,7 @@ class Llama4DecoderLayer(nn.Module): self.layer_idx + 1) % config.interleave_moe_layer_step == 0 if is_moe_layer: self.feed_forward = Llama4MoE( - config=config, - quant_config=quant_config, + vllm_config=vllm_config, prefix=f"{prefix}.feed_forward", ) else: diff --git a/vllm/model_executor/models/llama4_eagle.py b/vllm/model_executor/models/llama4_eagle.py index 235275c0940a1..0768edd083155 100644 --- a/vllm/model_executor/models/llama4_eagle.py +++ b/vllm/model_executor/models/llama4_eagle.py @@ -68,9 +68,9 @@ class LlamaModel(nn.Module): self.layers = nn.ModuleList([ Llama4DecoderLayer( - self.config, - quant_config=quant_config, + vllm_config=vllm_config, prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"), + config=self.config, ) for i in range(self.config.num_hidden_layers) ]) self.fc = torch.nn.Linear(self.config.hidden_size * 2, diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py index d6e6fd3fcfe9c..d7d6b1745fc8d 100644 --- a/vllm/model_executor/models/llama_eagle.py +++ b/vllm/model_executor/models/llama_eagle.py @@ -28,11 +28,12 @@ class LlamaDecoderLayer(LlamaDecoderLayer): def __init__( self, - config: LlamaConfig, + vllm_config: VllmConfig, disable_input_layernorm: bool, prefix: str = "", + config: Optional[LlamaConfig] = None, ) -> None: - super().__init__(config, prefix=prefix) + super().__init__(vllm_config, prefix=prefix, config=config) # Skip the input_layernorm # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427 @@ -64,9 +65,10 @@ class LlamaModel(nn.Module): self.layers = nn.ModuleList([ LlamaDecoderLayer( - self.config, + vllm_config, i == 0, prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"), + config=self.config, ) for i in range(self.config.num_hidden_layers) ]) self.fc = torch.nn.Linear(self.config.hidden_size * 2, diff 
--git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py index 34b8ea0ca5360..7192a76c87498 100644 --- a/vllm/model_executor/models/llama_eagle3.py +++ b/vllm/model_executor/models/llama_eagle3.py @@ -8,13 +8,11 @@ import torch import torch.nn as nn from transformers import LlamaConfig -from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config +from vllm.config import VllmConfig, get_current_vllm_config from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import QKVParallelLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -28,17 +26,14 @@ logger = init_logger(__name__) class LlamaDecoderLayer(LlamaDecoderLayer): - def __init__( - self, - config: LlamaConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - super().__init__(config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix) + def __init__(self, + vllm_config: VllmConfig, + prefix: str = "", + config: Optional[LlamaConfig] = None) -> None: + super().__init__(vllm_config, prefix=prefix, config=config) + + config = config or vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config # override qkv self.self_attn.qkv_proj = QKVParallelLinear( @@ -125,9 +120,9 @@ class LlamaModel(nn.Module): self.layers = nn.ModuleList([ LlamaDecoderLayer( - config=self.config, - cache_config=current_vllm_config.cache_config, + current_vllm_config, prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"), + config=self.config, ) ]) if 
hasattr(self.config, "target_hidden_size"): diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index cb2ff97a5df25..45b9c656a4bbd 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -29,13 +29,13 @@ from typing import Any, Optional, Union import torch from torch import nn -from transformers import Qwen3MoeConfig from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config from vllm.distributed import (get_ep_group, get_pp_group, - get_tensor_model_parallel_world_size) + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather) from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE @@ -51,6 +51,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.models.utils import sequence_parallel_chunk from vllm.sequence import IntermediateTensors from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP @@ -101,12 +102,15 @@ class Qwen3MoeSparseMoeBlock(nn.Module): def __init__( self, - config: Qwen3MoeConfig, - quant_config: Optional[QuantizationConfig] = None, + vllm_config: VllmConfig, prefix: str = "", - enable_eplb: bool = False, ): super().__init__() + + config = vllm_config.model_config.hf_config + parallel_config = vllm_config.parallel_config + quant_config = vllm_config.quant_config + self.tp_size = get_tensor_model_parallel_world_size() self.ep_group = get_ep_group().device_group @@ -114,6 +118,8 @@ class Qwen3MoeSparseMoeBlock(nn.Module): self.ep_size = self.ep_group.size() self.n_routed_experts = config.num_experts + self.is_sequence_parallel = 
parallel_config.use_sequence_parallel_moe + if self.tp_size > config.num_experts: raise ValueError( f"Tensor parallel size {self.tp_size} is greater than " @@ -122,7 +128,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module): # Load balancing settings. vllm_config = get_current_vllm_config() eplb_config = vllm_config.parallel_config.eplb_config - self.enable_eplb = enable_eplb + self.enable_eplb = parallel_config.enable_eplb self.n_logical_experts = self.n_routed_experts self.n_redundant_experts = eplb_config.num_redundant_experts @@ -144,7 +150,8 @@ class Qwen3MoeSparseMoeBlock(nn.Module): quant_config=quant_config, prefix=f"{prefix}.experts", enable_eplb=self.enable_eplb, - num_redundant_experts=self.n_redundant_experts) + num_redundant_experts=self.n_redundant_experts, + is_sequence_parallel=self.is_sequence_parallel) self.gate = ReplicatedLinear(config.hidden_size, config.num_experts, @@ -156,14 +163,22 @@ class Qwen3MoeSparseMoeBlock(nn.Module): assert hidden_states.dim( ) <= 2, "Qwen3MoeSparseMoeBlock only supports 1D or 2D inputs" is_input_1d = hidden_states.dim() == 1 - hidden_dim = hidden_states.shape[-1] + num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) + if self.is_sequence_parallel: + hidden_states = sequence_parallel_chunk(hidden_states) + # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) final_hidden_states = self.experts(hidden_states=hidden_states, router_logits=router_logits) + if self.is_sequence_parallel: + final_hidden_states = tensor_model_parallel_all_gather( + final_hidden_states, 0) + final_hidden_states = final_hidden_states[:num_tokens] + # return to 1d if input is 1d return final_hidden_states.squeeze(0) if is_input_1d else \ final_hidden_states @@ -275,15 +290,13 @@ class Qwen3MoeAttention(nn.Module): class Qwen3MoeDecoderLayer(nn.Module): - def __init__( - self, - config: Qwen3MoeConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: 
Optional[QuantizationConfig] = None, - prefix: str = "", - enable_eplb: bool = False, - ) -> None: + def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) @@ -315,10 +328,8 @@ class Qwen3MoeDecoderLayer(nn.Module): if (layer_idx not in mlp_only_layers) and ( config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0): - self.mlp = Qwen3MoeSparseMoeBlock(config=config, - quant_config=quant_config, - prefix=f"{prefix}.mlp", - enable_eplb=enable_eplb) + self.mlp = Qwen3MoeSparseMoeBlock(vllm_config=vllm_config, + prefix=f"{prefix}.mlp") else: self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, @@ -362,10 +373,8 @@ class Qwen3MoeModel(nn.Module): super().__init__() config = vllm_config.model_config.hf_config.get_text_config() - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config parallel_config = vllm_config.parallel_config - enable_eplb = parallel_config.enable_eplb eplb_config = parallel_config.eplb_config self.num_redundant_experts = eplb_config.num_redundant_experts @@ -379,11 +388,8 @@ class Qwen3MoeModel(nn.Module): prefix=f"{prefix}.embed_tokens") self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: Qwen3MoeDecoderLayer(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix, - enable_eplb=enable_eplb), + lambda prefix: Qwen3MoeDecoderLayer(vllm_config=vllm_config, + prefix=prefix), prefix=f"{prefix}.layers", ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index 
dc3153fcc826b..14d19874a51e0 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -17,7 +17,8 @@ from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig, VllmConfig, get_current_vllm_config) from vllm.distributed import (divide, get_ep_group, get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather) from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.fla.ops import ( @@ -47,6 +48,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, sharded_weight_loader) from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP +from vllm.model_executor.models.utils import sequence_parallel_chunk from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors @@ -69,14 +71,13 @@ KVCache = tuple[torch.Tensor, torch.Tensor] class Qwen3NextSparseMoeBlock(nn.Module): - def __init__( - self, - config: Qwen3NextConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - enable_eplb: bool = False, - ): + def __init__(self, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + + config = vllm_config.model_config.hf_config + parallel_config = vllm_config.parallel_config + quant_config = vllm_config.quant_config + self.tp_size = get_tensor_model_parallel_world_size() self.ep_group = get_ep_group().device_group @@ -84,6 +85,8 @@ class Qwen3NextSparseMoeBlock(nn.Module): self.ep_size = self.ep_group.size() self.n_routed_experts = config.num_experts + self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe + if self.tp_size > config.num_experts: raise ValueError( f"Tensor parallel size {self.tp_size} 
is greater than " @@ -92,7 +95,7 @@ class Qwen3NextSparseMoeBlock(nn.Module): # Load balancing settings. vllm_config = get_current_vllm_config() eplb_config = vllm_config.parallel_config.eplb_config - self.enable_eplb = enable_eplb + self.enable_eplb = parallel_config.enable_eplb self.n_logical_experts = self.n_routed_experts self.n_redundant_experts = eplb_config.num_redundant_experts @@ -114,7 +117,8 @@ class Qwen3NextSparseMoeBlock(nn.Module): quant_config=quant_config, prefix=f"{prefix}.experts", enable_eplb=self.enable_eplb, - num_redundant_experts=self.n_redundant_experts) + num_redundant_experts=self.n_redundant_experts, + is_sequence_parallel=self.is_sequence_parallel) self.gate = ReplicatedLinear(config.hidden_size, config.num_experts, @@ -141,9 +145,12 @@ class Qwen3NextSparseMoeBlock(nn.Module): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # NOTE: hidden_states can have either 1D or 2D shape. orig_shape = hidden_states.shape - hidden_dim = hidden_states.shape[-1] + num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) + if self.is_sequence_parallel: + hidden_states = sequence_parallel_chunk(hidden_states) + shared_output = None if self.shared_expert is not None: shared_output = self.shared_expert(hidden_states) @@ -158,7 +165,12 @@ class Qwen3NextSparseMoeBlock(nn.Module): if shared_output is not None: final_hidden_states = final_hidden_states + shared_output - if self.tp_size > 1: + + if self.is_sequence_parallel: + final_hidden_states = tensor_model_parallel_all_gather( + final_hidden_states, 0) + final_hidden_states = final_hidden_states[:num_tokens] + elif self.tp_size > 1: final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501 final_hidden_states) @@ -719,17 +731,17 @@ class Qwen3NextDecoderLayer(nn.Module): def __init__( self, - config: Qwen3NextConfig, + vllm_config: VllmConfig, layer_type: str, - model_config: Optional[ModelConfig] = None, - cache_config: 
Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - speculative_config: Optional[SpeculativeConfig] = None, prefix: str = "", - enable_eplb: bool = False, ) -> None: super().__init__() - self.config = config + + config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + speculative_config = vllm_config.speculative_config self.layer_type = layer_type self.layer_idx = extract_layer_index(prefix) @@ -759,10 +771,8 @@ class Qwen3NextDecoderLayer(nn.Module): config.num_experts > 0 and (self.layer_idx + 1) % config.decoder_sparse_step == 0): self.mlp = Qwen3NextSparseMoeBlock( - config=config, - quant_config=quant_config, + vllm_config=vllm_config, prefix=f"{prefix}.mlp", - enable_eplb=enable_eplb, ) else: self.mlp = Qwen3NextMLP( @@ -783,14 +793,14 @@ class Qwen3NextDecoderLayer(nn.Module): torch.zeros( 1, 1, - self.config.hidden_size, + config.hidden_size, dtype=config.torch_dtype, ), ) self.ffn_layer_scale = torch.nn.Parameter( torch.zeros( 1, 1, - self.config.hidden_size, + config.hidden_size, dtype=config.torch_dtype, ), ) @@ -858,13 +868,8 @@ class Qwen3NextModel(nn.Module): super().__init__() config: Qwen3NextConfig = vllm_config.model_config.hf_config - model_config = vllm_config.model_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config parallel_config = vllm_config.parallel_config lora_config = vllm_config.lora_config - speculative_config = vllm_config.speculative_config - enable_eplb = parallel_config.enable_eplb eplb_config = parallel_config.eplb_config self.num_redundant_experts = eplb_config.num_redundant_experts @@ -881,14 +886,9 @@ class Qwen3NextModel(nn.Module): def get_layer(prefix: str): return Qwen3NextDecoderLayer( - config, + vllm_config, layer_type=config.layer_types[extract_layer_index(prefix)], - model_config=model_config, - cache_config=cache_config, - 
quant_config=quant_config, - speculative_config=speculative_config, prefix=prefix, - enable_eplb=enable_eplb, ) self.start_layer, self.end_layer, self.layers = make_layers( diff --git a/vllm/model_executor/models/qwen3_next_mtp.py b/vllm/model_executor/models/qwen3_next_mtp.py index c054339842e64..e950699a0c499 100644 --- a/vllm/model_executor/models/qwen3_next_mtp.py +++ b/vllm/model_executor/models/qwen3_next_mtp.py @@ -38,7 +38,6 @@ class Qwen3NextMultiTokenPredictor(nn.Module): super().__init__() model_config = vllm_config.model_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config config: Qwen3NextConfig = model_config.hf_config @@ -68,11 +67,8 @@ class Qwen3NextMultiTokenPredictor(nn.Module): self.layers = torch.nn.ModuleList( Qwen3NextDecoderLayer( - config, + vllm_config, layer_type="full_attention", - model_config=model_config, - cache_config=cache_config, - quant_config=quant_config, prefix=f'{prefix}.layers.{idx}', ) for idx in range(self.num_mtp_layers)) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index bb6a0bd022021..4bf151fbf62d1 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -13,11 +13,14 @@ from transformers import PretrainedConfig import vllm.envs as envs from vllm.config import VllmConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) from vllm.logger import init_logger from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal import NestedTensors from vllm.sequence import IntermediateTensors -from vllm.utils import (get_cuda_view_from_cpu_tensor, is_pin_memory_available, +from vllm.utils import (cdiv, direct_register_custom_op, + get_cuda_view_from_cpu_tensor, is_pin_memory_available, is_uva_available) logger = init_logger(__name__) @@ -743,3 +746,46 @@ def get_model_hidden_size(hf_config: 
PretrainedConfig) -> int: return hf_config.hidden_size text_config = hf_config.get_text_config() return text_config.hidden_size + + +# Chunk x along the num_tokens axis for sequence parallelism +# NOTE: This is wrapped in a torch custom op to work around the following issue: +# The output tensor can have a sequence length 0 at small input sequence lengths +# even though we explicitly pad to avoid this. +def sequence_parallel_chunk(x: torch.Tensor) -> torch.Tensor: + return torch.ops.vllm.sequence_parallel_chunk_impl(x) + + +def sequence_parallel_chunk_impl(x: torch.Tensor) -> torch.Tensor: + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + + # all_gather needs the sequence length to be divisible by tp_size + seq_len = x.size(0) + remainder = seq_len % tp_size + if remainder != 0: + pad_len = tp_size - remainder + y = nn.functional.pad(x, (0, 0, 0, pad_len)) + else: + y = x + + chunk = y.shape[0] // tp_size + start = tp_rank * chunk + return torch.narrow(y, 0, start, chunk) + + +def sequence_parallel_chunk_impl_fake(x: torch.Tensor) -> torch.Tensor: + tp_size = get_tensor_model_parallel_world_size() + seq_len = cdiv(x.size(0), tp_size) + shape = list(x.shape) + shape[0] = seq_len + out = torch.empty(shape, dtype=x.dtype, device=x.device) + return out + + +direct_register_custom_op( + op_name="sequence_parallel_chunk_impl", + op_func=sequence_parallel_chunk_impl, + fake_impl=sequence_parallel_chunk_impl_fake, + tags=(torch.Tag.needs_fixed_stride_order, ), +)