From 94846416166c731939892350d7ab26dcbcb2982d Mon Sep 17 00:00:00 2001 From: Song <44120206+Oliver-ss@users.noreply.github.com> Date: Thu, 31 Jul 2025 23:19:06 +0800 Subject: [PATCH 1/9] [Model] Add step3 vl (#21998) Signed-off-by: oliveryuan Co-authored-by: oliveryuan --- docs/models/supported_models.md | 1 + tests/models/registry.py | 6 + .../openai/tool_parsers/__init__.py | 2 + .../openai/tool_parsers/step3_tool_parser.py | 296 +++++ vllm/model_executor/models/registry.py | 2 + vllm/model_executor/models/step3_text.py | 521 ++++++++ vllm/model_executor/models/step3_vl.py | 1052 +++++++++++++++++ vllm/reasoning/__init__.py | 2 + vllm/reasoning/step3_reasoning_parser.py | 109 ++ vllm/transformers_utils/config.py | 5 +- vllm/transformers_utils/configs/__init__.py | 6 + vllm/transformers_utils/configs/step3_vl.py | 123 ++ 12 files changed, 2124 insertions(+), 1 deletion(-) create mode 100644 vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py create mode 100644 vllm/model_executor/models/step3_text.py create mode 100644 vllm/model_executor/models/step3_vl.py create mode 100644 vllm/reasoning/step3_reasoning_parser.py create mode 100644 vllm/transformers_utils/configs/step3_vl.py diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 5a9823bb6bae7..f5d9e3b22f2a6 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -625,6 +625,7 @@ See [this page](generative_models.md) for more information on how to use generat | `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + IE+ + VE+ + A+ | `Qwen/Qwen2.5-Omni-7B` | | ✅︎ | ✅︎ | | `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | ✅︎ | | `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | | ✅︎ | +| `Step3VLForConditionalGeneration` | Step3-VL | T + I+ | `stepfun-ai/step3` | | ✅︎ | ✅︎ | | `TarsierForConditionalGeneration` | Tarsier | T + IE+ | `omni-search/Tarsier-7b`, 
`omni-search/Tarsier-34b` | | ✅︎ | ✅︎ | | `Tarsier2ForConditionalGeneration`^ | Tarsier2 | T + IE+ + VE+ | `omni-research/Tarsier2-Recap-7b`, `omni-research/Tarsier2-7b-0115` | | ✅︎ | ✅︎ | diff --git a/tests/models/registry.py b/tests/models/registry.py index 8fcff5a8c5113..b9e7de4e9fd11 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -279,6 +279,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"), # noqa: E501 "StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"), "Starcoder2ForCausalLM": _HfExamplesInfo("bigcode/starcoder2-3b"), + "Step3TextForCausalLM": _HfExamplesInfo("stepfun-ai/step3", + trust_remote_code=True, + is_available_online=False), "SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct", trust_remote_code=True), "TeleChat2ForCausalLM": _HfExamplesInfo("Tele-AI/TeleChat2-3B", @@ -457,6 +460,9 @@ _MULTIMODAL_EXAMPLE_MODELS = { "SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B", trust_remote_code=True), "SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct"), # noqa: E501 + "Step3VLForConditionalGeneration": _HfExamplesInfo("stepfun-ai/step3", + trust_remote_code=True, + is_available_online=False), "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501 trust_remote_code=True), "TarsierForConditionalGeneration": _HfExamplesInfo("omni-research/Tarsier-7b", # noqa: E501 diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py index 88c8aa929b78d..099e456aa486f 100644 --- a/vllm/entrypoints/openai/tool_parsers/__init__.py +++ b/vllm/entrypoints/openai/tool_parsers/__init__.py @@ -18,6 +18,7 @@ from .mistral_tool_parser import MistralToolParser from .phi4mini_tool_parser import Phi4MiniJsonToolParser from .pythonic_tool_parser import PythonicToolParser from .qwen3coder_tool_parser import 
Qwen3CoderToolParser +from .step3_tool_parser import Step3ToolParser from .xlam_tool_parser import xLAMToolParser __all__ = [ @@ -40,4 +41,5 @@ __all__ = [ "HunyuanA13BToolParser", "Glm4MoeModelToolParser", "Qwen3CoderToolParser", + "Step3ToolParser", ] diff --git a/vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py new file mode 100644 index 0000000000000..a20d18eb52544 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py @@ -0,0 +1,296 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import contextlib +import json +from collections.abc import Sequence +from typing import Any, Optional, Union + +import regex as re + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import random_uuid + +logger = init_logger(__name__) + + +@ToolParserManager.register_module(["step3"]) +class Step3ToolParser(ToolParser): + """ + Tool parser for a model that uses a specific XML-like format for tool calls. + This version uses a robust, stateful, cursor-based streaming parser and + consolidates tool arguments into a single message. 
+ """ + + TOOL_CALLS_BEGIN = "<|tool_calls_begin|>" + TOOL_CALLS_END = "<|tool_calls_end|>" + TOOL_CALL_BEGIN = "<|tool_call_begin|>" + TOOL_CALL_END = "<|tool_call_end|>" + TOOL_SEP = "<|tool_sep|>" + SPECIAL_TOKENS = [ + TOOL_CALLS_BEGIN, TOOL_CALLS_END, TOOL_CALL_BEGIN, TOOL_CALL_END + ] + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + self.position = 0 + # Explicit state flags for robust streaming + self.tool_block_started = False + self.tool_block_finished = False + + def adjust_request( + self, request: ChatCompletionRequest) -> ChatCompletionRequest: + if request.tools and request.tool_choice != 'none': + request.skip_special_tokens = False + return request + + @staticmethod + def _parse_steptml_invoke( + action_text: str + ) -> tuple[Optional[str], Optional[dict[str, str]]]: + func_name_match = re.search(r'', + action_text) + if not func_name_match: + return None, None + func_name = func_name_match.group(1) + + params: dict[str, str] = {} + param_matches = re.findall( + r'([^<]*)', + action_text) + for name, value in param_matches: + params[name] = value.strip() + return func_name, params + + def _cast_arguments( + self, + func_name: str, + params: dict[str, Any], + request: ChatCompletionRequest, + ) -> dict[str, Any]: + for tool in request.tools or []: + if tool.function.name == func_name: + schema = tool.function.parameters or {} + properties = schema.get("properties", {}) + for key, value in params.items(): + if not isinstance(value, str): + continue + prop = properties.get(key, {}) + typ = prop.get("type") + if typ == "string": + params[key] = value.strip() + elif typ == "integer": + with contextlib.suppress(ValueError): + params[key] = int(value) + elif typ == "number": + with contextlib.suppress(ValueError): + params[key] = float(value) + elif typ == "boolean": + lower_val = value.lower() + params[key] = lower_val == "true" if lower_val in ( + "true", "false") else value + elif typ == "null": + params[key] = None if 
value.lower( + ) == "null" else value + break + return params + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + + # The main loop processes the stream from the last known position. + while True: + if self.position >= len(current_text): + return None # We've processed the entire stream. + + unprocessed_text = current_text[self.position:] + + # STATE: After all tools are done, all subsequent text is content. + if self.tool_block_finished: + self.position = len(current_text) + return DeltaMessage(content=unprocessed_text) + + # STATE: Before the tool block has started. + if not self.tool_block_started: + if unprocessed_text.startswith(self.TOOL_CALLS_BEGIN): + self.position += len(self.TOOL_CALLS_BEGIN) + self.tool_block_started = True + continue # Token consumed, re-loop. + + start_pos = unprocessed_text.find(self.TOOL_CALLS_BEGIN) + if start_pos == -1: + if self.TOOL_CALLS_BEGIN.startswith( + unprocessed_text.strip()) and unprocessed_text: + return None # It's a prefix, wait. + self.position = len(current_text) + return DeltaMessage(content=unprocessed_text) + else: + content = unprocessed_text[:start_pos] + self.position += len(content) + return DeltaMessage(content=content) + + # STATE: Inside the main tool block. + offset = len(unprocessed_text) - len(unprocessed_text.lstrip()) + unprocessed_text = unprocessed_text.lstrip() + self.position += offset + + if unprocessed_text.startswith(self.TOOL_CALLS_END): + self.position += len(self.TOOL_CALLS_END) + self.tool_block_finished = True + self.current_tool_id = -1 + continue + + # Check if we are between tool calls. 
+ tool_finished = ( + self.current_tool_id != -1 and + self.prev_tool_call_arr[self.current_tool_id].get("finished")) + if self.current_tool_id == -1 or tool_finished: + if unprocessed_text.startswith(self.TOOL_CALL_BEGIN): + self.position += len(self.TOOL_CALL_BEGIN) + if self.current_tool_id == -1: + self.current_tool_id = 0 + else: + self.current_tool_id += 1 + self.current_tool_name_sent = False + while len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + self.prev_tool_call_arr[ + self.current_tool_id]["finished"] = False + continue + + if self.TOOL_CALL_BEGIN.startswith(unprocessed_text): + return None + + # STATE: Parsing an active tool call. + if self.current_tool_id != -1 and not self.prev_tool_call_arr[ + self.current_tool_id].get("finished", False): + end_tool_pos = unprocessed_text.find(self.TOOL_CALL_END) + if end_tool_pos == -1: + tool_body = unprocessed_text + else: + tool_body = unprocessed_text[:end_tool_pos] + + if end_tool_pos == -1 and self.TOOL_CALL_END.startswith( + tool_body): + return None + + function_name, arguments = self._parse_steptml_invoke( + tool_body) + if not function_name: + return None + + tool_call_arr = { + "name": function_name, + "parameters": arguments or {} + } + + # Send the function name as soon as it's parsed. + if not self.current_tool_name_sent: + self.current_tool_name_sent = True + self.prev_tool_call_arr[self.current_tool_id].update( + tool_call_arr) + return DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + type="function", + id=f"chatcmpl-tool-{random_uuid()}", + function=DeltaFunctionCall( + name=function_name)) + ]) + + # Update our internal state with the latest parsed arguments. + self.prev_tool_call_arr[ + self.current_tool_id].update( # noqa: E501 + tool_call_arr) + + # Only send arguments when the tool call is complete. 
+ if end_tool_pos != -1: + self.position += end_tool_pos + len(self.TOOL_CALL_END) + self.prev_tool_call_arr[ + self.current_tool_id]["finished"] = True + + final_args = self._cast_arguments( + function_name, + tool_call_arr.get("parameters", {}), # type: ignore + request) + if final_args: + final_args_json = json.dumps(final_args, + ensure_ascii=False) + return DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=final_args_json)) + ]) + + # If tool is not finished, return None to wait for more tokens. + return None + + return None + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + if self.TOOL_CALLS_BEGIN not in model_output: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + pre_text, rest = model_output.split(self.TOOL_CALLS_BEGIN, 1) + if self.TOOL_CALLS_END not in rest: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + tool_block, post_text = rest.split(self.TOOL_CALLS_END, 1) + content = (pre_text + post_text).strip() + + tool_calls: list[ToolCall] = [] + call_parts = tool_block.split(self.TOOL_CALL_BEGIN) + + for part in call_parts: + if not part or self.TOOL_CALL_END not in part: + continue + + call_content = part.split(self.TOOL_CALL_END, 1)[0] + if self.TOOL_SEP not in call_content: + continue + + type_part, invoke_part = call_content.split(self.TOOL_SEP, 1) + if type_part.strip() != "function": + continue + + function_name, params_dict = self._parse_steptml_invoke( + invoke_part) + + if function_name and params_dict is not None: + params_dict = self._cast_arguments(function_name, params_dict, + request) + params_str = json.dumps(params_dict, ensure_ascii=False) + tool_calls.append( + ToolCall(function=FunctionCall(name=function_name, + arguments=params_str))) + if tool_calls: + return ExtractedToolCallInformation( 
+ tools_called=True, + tool_calls=tool_calls, + content=content if content else None) + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 51831a770347a..848c04b9b32f7 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -129,6 +129,7 @@ _TEXT_GENERATION_MODELS = { "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"), "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"), "RWForCausalLM": ("falcon", "FalconForCausalLM"), + "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"), "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"), "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"), @@ -238,6 +239,7 @@ _MULTIMODAL_MODELS = { "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501 "Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501 "UltravoxModel": ("ultravox", "UltravoxModel"), + "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"), # noqa: E501 "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"), # noqa: E501 "Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"), # noqa: E501 "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"), # noqa: E501 diff --git a/vllm/model_executor/models/step3_text.py b/vllm/model_executor/models/step3_text.py new file mode 100644 index 0000000000000..47d2af5c2a140 --- /dev/null +++ b/vllm/model_executor/models/step3_text.py @@ -0,0 +1,521 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Inference-only Jurassic model.""" +from collections.abc import 
Iterable +from typing import Any, Optional + +import torch +from torch import nn + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, ModelConfig, VllmConfig +from vllm.distributed import (get_pp_group, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + +logger = init_logger(__name__) + + +class FusedMoEBlock(nn.Module): + + def __init__(self, + config: ModelConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + + if self.tp_size > config.moe_num_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.moe_num_experts}.") + + self.experts = 
FusedMoE(num_experts=config.moe_num_experts, + top_k=config.moe_top_k, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_expert_weight, + quant_config=quant_config, + prefix=f"{prefix}.experts") + self.gate = ReplicatedLinear(config.hidden_size, + config.moe_num_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate") + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + orig_shape = hidden_states.shape + hidden_dim = hidden_states.shape[-1] + hidden_states = hidden_states.view(-1, hidden_dim) + + router_logits, _ = self.gate(hidden_states) + + final_hidden_states = self.experts(hidden_states=hidden_states, + router_logits=router_logits) + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + return final_hidden_states.view(orig_shape) + + +class Step3TextMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj") + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj") + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. 
" + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + self.hidden_size = hidden_size + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(hidden_states) + intermediate_act = self.act_fn(gate_up) + output, _ = self.down_proj(intermediate_act) + return output + + +class Step3TextAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + norm_eps: float, + rope_theta: int, + share_q_dim: Optional[int] = None, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embedding: int = 8192, + head_dim: int = 256, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + + if num_kv_heads != 1: + raise ValueError(f"Step3TextAttention num_kv_heads must be 1, " + f"but got {num_kv_heads}.") + self.num_kv_heads = num_kv_heads + + self.head_dim = head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.q_size = share_q_dim if share_q_dim else self.head_dim + + self.qkv_proj = ReplicatedLinear( + hidden_size, + self.q_size + self.kv_size * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + self.inter_norm = RMSNorm(self.q_size, eps=norm_eps) + self.wq = ColumnParallelLinear( + self.q_size, + self.head_dim * self.total_num_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.wq", + ) + self.rotary_emb = get_rope(self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embedding, + base=rope_theta, + 
rope_scaling=rope_scaling) + scaling = self.head_dim**-0.5 + self.attn = Attention(self.num_heads, + self.head_dim, + scaling, + self.num_kv_heads, + cache_config=cache_config, + prefix=f"{prefix}.attn") + + def forward(self, positions: torch.Tensor, + hidden_states: torch.Tensor) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = self.inter_norm(q) + q = self.wq(q)[0] + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + residual, _ = self.o_proj(attn_output) + return residual + + +class Step3TextDecoderLayer(nn.Module): + + def __init__(self, + config: ModelConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__() + config = config.hf_config + self.hidden_size = config.hidden_size + rope_scaling = getattr(config, "rope_scaling", None) + + self.self_attn = Step3TextAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=1, + cache_config=cache_config, + quant_config=quant_config, + norm_eps=config.rms_norm_eps, + max_position_embedding=config.max_position_embedding, + head_dim=config.head_dim, + share_q_dim=config.share_q_dim, + rope_theta=config.rope_theta, + rope_scaling=rope_scaling, + prefix=f"{prefix}.self_attn") + + layer_idx = int(prefix.split("layers.")[1].split(".")[0]) + moe_layers_enum = getattr(config, "moe_layers_enum", None) + if moe_layers_enum is not None: + moe_layers_idx = [ + int(i) for i in moe_layers_enum.strip().split(',') + ] + else: + # Default to 1dense. 
+ moe_layers_idx = [i for i in range(1, config.num_hidden_layers)] + + if layer_idx in moe_layers_idx: + self.moe = FusedMoEBlock(config=config, + quant_config=quant_config, + prefix=f"{prefix}.moe") + self.share_expert = Step3TextMLP( + hidden_size=self.hidden_size, + intermediate_size=config.share_expert_dim, + hidden_act="silu", + quant_config=quant_config, + prefix=f"{prefix}.share_expert") + self.use_moe = True + else: + self.mlp = Step3TextMLP(hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act="silu", + quant_config=quant_config, + prefix=f"{prefix}.mlp") + self.use_moe = False + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, positions: torch.Tensor, hidden_states: torch.Tensor, + residual: Optional[torch.Tensor] + ) -> tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + + if self.use_moe: + share_output = self.share_expert(hidden_states) + moe_output = self.moe(hidden_states) + hidden_states = share_output + moe_output + else: + hidden_states = self.mlp(hidden_states) + + return hidden_states, residual + + +@support_torch_compile +class Step3TextModel(nn.Module): + + def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: + super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.vocab_size = config.vocab_size + self.config = config + + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and 
get_pp_group().is_last_rank): + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + ) + else: + self.embed_tokens = PPMissingLayer() + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Step3TextDecoderLayer(config=vllm_config. + model_config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.layers", + ) + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, residual) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual, + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class Step3TextForCausalLM(nn.Module, SupportsPP): + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + ): + super().__init__() + config = vllm_config.model_config.hf_config + lora_config = vllm_config.lora_config + self.config = 
config + self.vllm_config = vllm_config + + self.model = Step3TextModel(vllm_config=vllm_config, prefix=prefix) + + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + if not lora_config else lora_config.lora_vocab_padding_size, + ) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = get_sampler() + else: + self.lm_head = PPMissingLayer() + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None): + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + qkv_params_mapping = [ + # (param_name, shard_name, relative_start_idx, relative_end_idx) + (".qkv_proj", ".q_proj", 0, self.config.share_q_dim / + (self.config.share_q_dim + self.config.head_dim * 2)), + (".qkv_proj", ".k_proj", self.config.share_q_dim / + (self.config.share_q_dim + self.config.head_dim * 2), + (self.config.share_q_dim + self.config.head_dim) / + (self.config.share_q_dim + self.config.head_dim * 2)), + 
(".qkv_proj", ".v_proj", + (self.config.share_q_dim + self.config.head_dim) / + (self.config.share_q_dim + self.config.head_dim * 2), + (self.config.share_q_dim + self.config.head_dim * 2) / + (self.config.share_q_dim + self.config.head_dim * 2)), + ] + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + expert_params_mapping = [ + (".moe.experts.w13_weight", ".moe.gate_proj.weight", "w1"), + (".moe.experts.w13_weight", ".moe.up_proj.weight", "w3"), + (".moe.experts.w2_weight", ".moe.down_proj.weight", "w2") + ] + + disable_moe_stacked_params = [ + data[1] for data in expert_params_mapping + ] + + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + if any(disable_moe_stacked_param in name + for disable_moe_stacked_param in + disable_moe_stacked_params): + continue + name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + loaded_params.add(name) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Skip loading extra bias for GPTQ models. 
+ if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + param = params_dict[name] + weight_loader = param.weight_loader + for expert_id in range(loaded_weight.shape[0]): + loaded_weight_expert = loaded_weight[expert_id] + weight_loader(param, + loaded_weight_expert, + name, + shard_id=shard_id, + expert_id=expert_id) + loaded_params.add(name) + break + else: + for (param_name, weight_name, start_idx, + end_idx) in qkv_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + dim = param.shape[param.output_dim] + begin_idx = int(start_idx * dim) + end_idx = int(end_idx * dim) + param_slice = param.narrow(param.output_dim, begin_idx, + end_idx - begin_idx) + param_slice.copy_(loaded_weight) + loaded_params.add(name) + break + else: + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/step3_vl.py b/vllm/model_executor/models/step3_vl.py new file mode 100644 index 0000000000000..363c12a4bf2b8 --- /dev/null +++ b/vllm/model_executor/models/step3_vl.py @@ -0,0 +1,1052 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math +from collections.abc import Iterable, Mapping, Sequence +from functools import cached_property +from itertools import product +from math import ceil, sqrt +from typing import Any, Literal, Optional, TypedDict, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from PIL import Image +from torchvision import transforms +from torchvision.transforms.functional import InterpolationMode +from transformers import BatchFeature, PretrainedConfig, 
from transformers import BatchFeature, PretrainedConfig, TensorType

from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                    MultiModalKwargs, NestedTensors)
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo, PromptReplacement,
                                        PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import Step3VisionEncoderConfig
from vllm.transformers_utils.tokenizer import AnyTokenizer

from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                    init_vllm_registered_model, maybe_prefix,
                    merge_multimodal_embeddings)


class Step3VLImagePixelInputs(TypedDict):
    """Raw pixel inputs for the Step3 vision tower."""
    type: Literal["pixel_values"]
    # Global views of each image, shape (num_images, C, H, W).
    pixel_values: torch.Tensor
    # Sliding-window patch crops, or None when no image was tiled.
    patch_pixel_values: Optional[torch.Tensor]
    # Number of patch crops contributed by each image.
    num_patches: list[int]


class Step3VLImageEmbeddingInputs(TypedDict):
    """Precomputed image embeddings, bypassing the vision tower."""
    type: Literal["image_embeds"]
    image_embeds: torch.Tensor


Step3VLImageInputs = Union[Step3VLImagePixelInputs,
                           Step3VLImageEmbeddingInputs]

# (resized full image, patch crops, per-patch "row ends here" mask).
# Fixed: the mask element holds bools (see ImagePatcher.__call__), the
# original annotation said list[int].
ImageWithPatches = tuple[Image.Image, list[Image.Image], list[bool] | None]

# Longest allowed image side; larger inputs are downscaled to fit.
MAX_IMAGE_SIZE: int = 3024


class Step3VisionProcessor:
    """Normalizes and resizes global images and patch crops to fixed sizes."""

    def __init__(self, size, interpolation_mode="bicubic", patch_size=None):
        # CLIP normalization constants.
        mean = [0.48145466, 0.4578275, 0.40821073]
        std = [0.26862954, 0.26130258, 0.27577711]
        # Patches fall back to the full-image size when unspecified, so
        # `patch_size` is always set from here on (the original
        # `if patch_size is not None` guard below was dead code).
        patch_size = patch_size if patch_size is not None else size

        interpolation = (InterpolationMode.BICUBIC if interpolation_mode
                         == "bicubic" else InterpolationMode.BILINEAR)

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
            transforms.Resize((size, size),
                              interpolation=interpolation,
                              antialias=True),
        ])

        self.patch_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
            transforms.Resize((patch_size, patch_size),
                              interpolation=interpolation,
                              antialias=True),
        ])

    def __call__(self, image, is_patch=False):
        """Return {"pixel_values": (1, C, H, W)} for one PIL image."""
        transform = self.patch_transform if is_patch else self.transform
        return {"pixel_values": transform(image).unsqueeze(0)}
class ImagePatcher:
    """Splits large images into fixed-size sliding-window patches.

    The model sees one downsized "global" view of every image plus, for
    sufficiently large images, a grid of window crops ("patches").
    """

    def determine_window_size(self, long: int, short: int) -> int:
        """Pick the crop window size from the long/short image sides.

        Returns 0 when the image is small enough to need no tiling.
        """
        aspect = long / short
        if long <= 728:
            # Small image: tile only when clearly elongated.
            return short if aspect > 1.5 else 0
        # Large image: cap the window at 504, shrinking further for
        # extremely elongated inputs.
        return min(short, 504) if aspect > 4 else 504

    def slide_window(
        self,
        width: int,
        height: int,
        sizes: list[tuple[int, int]],
        steps: list[tuple[int, int]],
        img_rate_thr: float = 0.6,
    ) -> tuple[list[tuple[int, int, int, int]], tuple[int, int]]:
        """Enumerate sliding windows over a width x height canvas.

        Returns ((x, y, w, h) boxes, (windows per row, windows per col)).
        """
        assert 1 >= img_rate_thr >= 0, "The `in_rate_thr` should lie in 0~1"
        windows = []
        for (size_w, size_h), (step_w, step_h) in zip(sizes, steps):
            # Window origins along x; the last origin is clamped so the
            # final window never overruns the right edge.
            x_num = 1 if width <= size_w else ceil((width - size_w) / step_w
                                                   + 1)
            x_start = [step_w * i for i in range(x_num)]
            if len(x_start) > 1 and x_start[-1] + size_w > width:
                x_start[-1] = width - size_w

            # Same clamping along y.
            y_num = 1 if height <= size_h else ceil((height - size_h) /
                                                    step_h + 1)
            y_start = [step_h * i for i in range(y_num)]
            if len(y_start) > 1 and y_start[-1] + size_h > height:
                y_start[-1] = height - size_h

            origins = np.array(list(product(y_start, x_start)), dtype=int)
            # product() yields (y, x); swap columns to (x, y).
            origins[:, [0, 1]] = origins[:, [1, 0]]
            windows.append(
                np.concatenate([origins, origins + (size_w, size_h)],
                               axis=1))
        windows = np.concatenate(windows, axis=0)

        return [(int(box[0]), int(box[1]), int(box[2] - box[0]),
                 int(box[3] - box[1])) for box in windows], (x_num, y_num)

    def square_pad(self, img: Image.Image) -> Image.Image:
        """Pad the short side with black so the image becomes square."""
        w, h = img.size
        if w == h:
            return img
        side = max(w, h)
        canvas = Image.new(img.mode, (side, side), 0)
        canvas.paste(img, (0, 0))
        return canvas

    def get_image_size_for_padding(self, img_width: int,
                                   img_height: int) -> tuple[int, int]:
        # Tiny, extremely elongated images get squared up front; every
        # other image keeps its size.
        ratio = img_width / img_height
        if min(img_height, img_width) < 32 and (ratio > 4 or ratio < 1 / 4):
            side = max(img_height, img_width)
            return side, side
        return img_width, img_height

    def get_image_size_for_preprocess(self, img_width: int,
                                      img_height: int) -> tuple[int, int]:
        # Downscale so the longest side fits within MAX_IMAGE_SIZE.
        longest = max(img_height, img_width)
        if longest > MAX_IMAGE_SIZE:
            scale = MAX_IMAGE_SIZE / longest
            img_width = int(img_width * scale)
            img_height = int(img_height * scale)
        return img_width, img_height

    def get_image_size_for_crop(self, img_width: int, img_height: int,
                                window_size: int):
        # Snap each dimension to a whole number of windows, rounding the
        # fractional remainder up only when it exceeds 20% of a window.
        def snap(dim: int) -> int:
            ratio = dim / window_size
            if ratio < 1:
                return dim
            remainder = ratio - dim // window_size
            whole = int(ratio) + 1 if remainder > 0.2 else int(ratio)
            return window_size * whole

        return int(snap(img_width)), int(snap(img_height))

    def patch_crop(self, img: Image.Image, i: int, j: int, th: int, tw: int):
        """Crop a (th x tw) region whose top-left corner is (i, j)."""
        return img.crop((j, i, j + tw, i + th))

    def get_num_patches(self, img_width: int,
                        img_height: int) -> tuple[int, int]:
        """Return (patch count, newline count) without touching pixels.

        Mirrors __call__ so prompt lengths can be computed cheaply.
        """
        img_width, img_height = self.get_image_size_for_padding(
            img_width, img_height)
        img_width, img_height = self.get_image_size_for_preprocess(
            img_width, img_height)
        window_size = self.determine_window_size(max(img_height, img_width),
                                                 min(img_height, img_width))
        if window_size == 0:
            return 0, 0
        img_width, img_height = self.get_image_size_for_crop(
            img_width, img_height, window_size)
        boxes, (x_num, _) = self.slide_window(
            img_width, img_height, [(window_size, window_size)],
            [(window_size, window_size)])
        # One newline marker per completed row of patches, except that a
        # newline directly after the final patch is dropped.
        full_rows = (len(boxes) - 1) // x_num + 1
        if len(boxes) > 0 and len(boxes) % x_num == 0:
            full_rows -= 1
        return len(boxes), full_rows

    def __call__(
        self, img: Image.Image
    ) -> tuple[Image.Image, list[Image.Image], list[bool] | None]:
        img_width, img_height = img.size
        pad_w, pad_h = self.get_image_size_for_padding(img_width, img_height)
        if (pad_w, pad_h) != (img_width, img_height):
            img = self.square_pad(img)
            img_width, img_height = img.size

        new_w, new_h = self.get_image_size_for_preprocess(
            img_width, img_height)
        img = img.resize((new_w, new_h), Image.Resampling.BILINEAR)
        window_size = self.determine_window_size(max(new_h, new_w),
                                                 min(new_h, new_w))
        if window_size == 0:
            # Small image: one global view, no patches, no mask.
            return img, [], None

        new_w, new_h = self.get_image_size_for_crop(new_w, new_h,
                                                    window_size)
        if (new_w, new_h) != (img_width, img_height):
            img_for_crop = img.resize((new_w, new_h),
                                      Image.Resampling.BILINEAR)
        else:
            img_for_crop = img

        patches: list[Image.Image] = []
        row_ends: list[int] = []
        boxes, (x_num, _) = self.slide_window(
            new_w, new_h, [(window_size, window_size)],
            [(window_size, window_size)])
        for patch_id, (x, y, patch_w, patch_h) in enumerate(boxes):
            patches.append(
                self.patch_crop(img_for_crop, y, x, patch_h, patch_w))
            if (patch_id + 1) % x_num == 0:
                row_ends.append(patch_id)

        # No newline marker after the very last patch.
        if row_ends and row_ends[-1] == len(patches) - 1:
            row_ends.pop()

        if not patches:
            return img, patches, None
        return img, patches, [i in row_ends for i in range(len(patches))]
class Step3VLProcessor:
    """HF-style processor pairing the Step3 tokenizer with image tiling.

    NOTE(review): the special-token strings below appear as "" in this
    source; kept verbatim.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        tokenizer: AnyTokenizer,
    ) -> None:
        super().__init__()

        self.config = config
        self.tokenizer = tokenizer

        self.image_size = 728
        self.patch_size = 504
        self.image_preprocessor = Step3VisionProcessor(self.image_size,
                                                       "bilinear",
                                                       self.patch_size)

        # Feature lengths produced by the vision tower + projector for the
        # global view and for each patch crop, respectively.
        self.num_image_feature_size = 169
        self.num_patch_feature_size = 81
        self.image_token = ""
        self.image_feature_placeholder = (self.image_token *
                                          self.num_image_feature_size)
        self.patch_feature_placeholder = (self.image_token *
                                          self.num_patch_feature_size)

        self.patcher = ImagePatcher()

    @property
    def image_token_id(self) -> int:
        return self.tokenizer.get_vocab()[self.image_token]

    def get_num_image_tokens(self, img_width: int, img_height: int) -> int:
        """Total prompt tokens one image expands to (incl. delimiters)."""
        num_patches, num_newlines = self.patcher.get_num_patches(
            img_width, img_height)
        # Each patch contributes its features plus 2 delimiter tokens; the
        # global view likewise; newlines add one token each.
        return num_patches * (
            self.num_patch_feature_size +
            2) + self.num_image_feature_size + 2 + num_newlines

    def _split_images(self,
                      images: list[Image.Image]) -> list[ImageWithPatches]:
        return [self.patcher(img) for img in images]

    def _convert_images_to_pixel_values(
        self,
        images: list[Image.Image],
        is_patch: bool = False,
    ) -> list[torch.Tensor]:
        return [
            self.image_preprocessor(img, is_patch=is_patch)["pixel_values"]
            for img in images
        ]

    def _get_patch_repl(
        self,
        num_patches: int,
        patch_newline_mask: list[bool] | None,
    ) -> tuple[str, list[int]]:
        """Prompt text and token ids for `num_patches` patch crops."""
        # Invariant hoisted out of the loop; also makes the None case an
        # explicit assertion instead of a TypeError on len(None).
        if num_patches > 0:
            assert patch_newline_mask is not None
            assert len(patch_newline_mask) == num_patches

        # Delimiter ids looked up once instead of on every iteration.
        start_id = self.tokenizer.convert_tokens_to_ids("")
        end_id = self.tokenizer.convert_tokens_to_ids("")
        newline_id = self.tokenizer.convert_tokens_to_ids("")
        image_token_id = self.image_token_id

        text = ""
        token_ids: list[int] = []
        for i in range(num_patches):
            text += f"{self.patch_feature_placeholder}"
            token_ids.extend([start_id] +
                             [image_token_id] * self.num_patch_feature_size +
                             [end_id])
            if patch_newline_mask and patch_newline_mask[i]:
                text += ""
                token_ids.append(newline_id)
        return text, token_ids

    def _get_image_repl(
        self,
        num_images: int,
    ) -> tuple[str, list[int]]:
        """Prompt text and token ids for `num_images` global views."""
        text = f"{self.image_feature_placeholder}"
        token_ids = [
            self.tokenizer.convert_tokens_to_ids("")
        ] + [self.image_token_id] * self.num_image_feature_size + [
            self.tokenizer.convert_tokens_to_ids("")
        ]
        return text * num_images, token_ids * num_images

    def _get_image_repl_features(
        self,
        num_images: int,
        num_patches: int,
        patch_new_line_idx: Optional[list[bool]],
    ) -> tuple[str, list[int]]:
        # Patch features come first, then the global image features.
        if num_patches > 0:
            patch_repl, patch_repl_ids = self._get_patch_repl(
                num_patches, patch_new_line_idx)
        else:
            patch_repl, patch_repl_ids = "", []
        image_repl, image_repl_ids = self._get_image_repl(num_images)
        return patch_repl + image_repl, patch_repl_ids + image_repl_ids

    def replace_placeholder(self, text: str, placeholder: str,
                            repls: list[str]) -> str:
        """Replace the i-th occurrence of `placeholder` with `repls[i]`.

        Raises ValueError when occurrence/replacement counts differ.
        """
        parts = text.split(placeholder)

        if len(parts) - 1 != len(repls):
            raise ValueError(
                "The number of placeholders does not match the number of replacements."  # noqa: E501
            )

        result = [parts[0]]
        for repl, part in zip(repls, parts[1:]):
            result.append(repl)
            result.append(part)

        return "".join(result)

    def __call__(
        self,
        text: Optional[Union[str, list[str]]] = None,
        images: Optional[Union[Image.Image, list[Image.Image]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
    ) -> BatchFeature:
        if text is None:
            text = []
        if not isinstance(text, list):
            text = [text]
        if images is None:
            images = []
        if not isinstance(images, list):
            images = [images]

        if len(images) == 0:
            image_inputs = {}
            text_inputs = self.tokenizer(text)
        else:
            splitted_images_data = self._split_images(images)
            pixel_values_lst = []
            patch_pixel_values_lst = []
            patch_newline_mask_lst = []
            image_repl_str_lst = []
            image_repl_ids_lst = []
            num_patches = []
            for raw_img, img_patches, patch_newline_mask in splitted_images_data:  # noqa: E501
                pixel_values_lst.extend(
                    self._convert_images_to_pixel_values([raw_img]))

                if len(img_patches) > 0:
                    patch_pixel_values_lst.extend(
                        self._convert_images_to_pixel_values(img_patches,
                                                             is_patch=True))
                # One entry per image (possibly 0) so downstream
                # per-image bookkeeping stays aligned.
                num_patches.append(len(img_patches))

                image_repl_str, image_repl_ids = self._get_image_repl_features(
                    1, len(img_patches), patch_newline_mask)
                image_repl_str_lst.append(image_repl_str)
                image_repl_ids_lst.extend(image_repl_ids)

                if patch_newline_mask is not None:
                    patch_newline_mask_lst.extend(patch_newline_mask)

            image_inputs = {
                "pixel_values": torch.cat(pixel_values_lst),
                "num_patches": num_patches,
            }
            if patch_pixel_values_lst:
                image_inputs["patch_pixel_values"] = torch.cat(
                    patch_pixel_values_lst)
            if patch_newline_mask_lst:
                image_inputs["patch_newline_mask"] = torch.tensor(
                    patch_newline_mask_lst, dtype=torch.bool)

            # Expand each image placeholder into its full replacement text.
            text = [
                self.replace_placeholder(t, self.image_token,
                                         image_repl_str_lst) for t in text
            ]
            text_inputs = self.tokenizer(text)

        return BatchFeature(
            {
                **text_inputs,
                **image_inputs,
            },
            tensor_type=return_tensors,
        )
**image_inputs, + }, + tensor_type=return_tensors, + ) + + +class Step3VLProcessingInfo(BaseProcessingInfo): + + def get_hf_processor(self) -> Step3VLProcessor: + return Step3VLProcessor( + self.get_hf_config(), + self.get_tokenizer(), + ) + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} + + def get_max_image_tokens(self) -> int: + hf_processor = self.get_hf_processor() + return hf_processor.get_num_image_tokens( + self.get_image_size_with_most_features().width, + self.get_image_size_with_most_features().height) + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + return {"image": self.get_max_image_tokens()} + + def get_image_size_with_most_features(self) -> ImageSize: + return ImageSize(3024, 3024) + + def get_num_mm_tokens(self, mm_data: MultiModalDataDict) -> int: + if len(mm_data) != 1 or "image" not in mm_data: + raise ValueError( + "mm_data could only contain one key 'image' for steo1o") + + image_data = mm_data["image"] + if not isinstance(image_data, (list, tuple)): + image_data = [image_data] + + return sum(self.get_hf_processor().get_num_image_tokens( + img.width, img.height) for img in image_data) + + +class Step3VLDummyInputsBuilder(BaseDummyInputsBuilder[Step3VLProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + return "" * num_images + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + target_width, target_height = \ + self.info.get_image_size_with_most_features() + num_images = mm_counts.get("image", 0) + + return { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + + +class Step3VLMultiModalProcessor(BaseMultiModalProcessor[Step3VLProcessingInfo] + ): + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: 
Mapping[str, Any], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + image_placeholder_token_id = hf_processor.image_token_id + batch_num_patches = out_mm_kwargs["num_patches"].tolist() + + def get_replacement_step1o(item_idx: int): + img_out = out_mm_kwargs.get_item("image", item_idx) + num_patches = batch_num_patches[item_idx] + if num_patches > 0: + patch_newline_mask = img_out["patch_newline_mask"].data.tolist( + ) + image_repl_ids = hf_processor._get_image_repl_features( + 1, num_patches, patch_newline_mask)[1] + else: + image_repl_ids = hf_processor._get_image_repl_features( + 1, 0, None)[1] + return PromptUpdateDetails.select_token_id( + seq=image_repl_ids, + embed_token_id=image_placeholder_token_id, + ) + + return [ + PromptReplacement( + modality="image", + target=[image_placeholder_token_id], + replacement=get_replacement_step1o, + ) + ] + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + num_patches = hf_inputs.get("num_patches", torch.empty(0)) + + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + patch_pixel_values=MultiModalFieldConfig.flat_from_sizes( + "image", num_patches), + num_patches=MultiModalFieldConfig.batched("image"), + patch_newline_mask=MultiModalFieldConfig.flat_from_sizes( + "image", num_patches), + ) + + +def get_abs_pos(abs_pos, tgt_size): + dim = abs_pos.size(-1) + abs_pos_new = abs_pos.squeeze(0) + cls_token, old_pos_embed = abs_pos_new[:1], abs_pos_new[1:] + + src_size = int(math.sqrt(abs_pos_new.shape[0] - 1)) + tgt_size = int(math.sqrt(tgt_size)) + dtype = abs_pos.dtype + + if src_size != tgt_size: + old_pos_embed = old_pos_embed.view(1, src_size, src_size, + dim).permute(0, 3, 1, + 2).contiguous() + old_pos_embed = old_pos_embed.to(torch.float32) + new_pos_embed = F.interpolate( + old_pos_embed, + size=(tgt_size, 
tgt_size), + mode='bicubic', + antialias=True, + align_corners=False, + ).to(dtype) + new_pos_embed = new_pos_embed.permute(0, 2, 3, 1) + new_pos_embed = new_pos_embed.view(tgt_size * tgt_size, dim) + vision_pos_embed = torch.cat([cls_token, new_pos_embed], dim=0) + vision_pos_embed = vision_pos_embed.view(1, tgt_size * tgt_size + 1, + dim) + return vision_pos_embed + else: + return abs_pos + + +class Step3VisionEmbeddings(nn.Module): + + def __init__(self, config: Step3VisionEncoderConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter(torch.randn(1, self.embed_dim)) + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=True, + ) + + self.num_patches = (self.image_size // self.patch_size)**2 + self.pad_tp_size = 4 # hard code for padding + # To load the pretrained weights, we still use P+1 as the seqlen + self.position_embedding = torch.nn.Embedding(self.num_patches + 1, + self.embed_dim) + self.register_buffer("position_ids", + torch.arange(self.num_patches + 1).expand( + (1, -1)), + persistent=False) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + batch_size = pixel_values.shape[0] + patch_embeds = self.patch_embedding( + pixel_values) # shape = [*, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + # pad + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + embeddings = embeddings + get_abs_pos( + self.position_embedding(self.position_ids), patch_embeds.size(1)) + embeddings = torch.cat([ + embeddings[:, 0, :].unsqueeze(1).repeat(1, self.pad_tp_size - 1, + 1), embeddings + ], + dim=1) + return embeddings + + +class Step3VisionAttention(nn.Module): + """Multi-headed attention 
from 'Attention Is All You Need' paper""" + + def __init__(self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.total_num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.total_num_heads + + self.scale = self.head_dim**-0.5 + + tp_size = get_tensor_model_parallel_world_size() + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.qkv_proj = QKVParallelLinear(self.embed_dim, + self.head_dim, + self.total_num_heads, + bias=True, + quant_config=quant_config, + prefix=prefix) + self.out_proj = RowParallelLinear(self.embed_dim, + self.embed_dim, + bias=True, + quant_config=quant_config, + prefix=prefix) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, + self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + ): + """Input shape: Batch x Time x Channel""" + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + q = q.view(bsz, tgt_len, self.num_heads, self.head_dim) + k = k.view(bsz, tgt_len, self.num_heads, self.head_dim) + v = v.view(bsz, tgt_len, self.num_heads, self.head_dim) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + attn_output = F.scaled_dot_product_attention(q, + k, + v, + scale=self.scale, + is_causal=False) + attn_output = attn_output.transpose(1, 2).reshape( + bsz, tgt_len, self.num_heads * self.head_dim) + + attn_output, _ = self.out_proj(attn_output) + + return attn_output + + +class Step3VisionMLP(nn.Module): + + def __init__(self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.config = config + self.activation_fn = get_act_fn(config.hidden_act) + self.fc1 = 
ColumnParallelLinear(config.hidden_size, + config.intermediate_size, + bias=True, + quant_config=quant_config, + prefix=prefix) + self.fc2 = RowParallelLinear(config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quant_config, + prefix=prefix) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + return hidden_states + + +class Step3VisionEncoderLayer(nn.Module): + + def __init__(self, + config: Step3VisionEncoderConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = Step3VisionAttention(config, + quant_config, + prefix=f"{prefix}.self_attn") + self.layer_norm1 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) + self.mlp = Step3VisionMLP(config, quant_config, prefix=f"{prefix}.mlp") + self.layer_norm2 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.FloatTensor: + hidden_states = hidden_states + self.layer_norm1( + self.self_attn(hidden_states)) + hidden_states = hidden_states + self.layer_norm2( + self.mlp(hidden_states)) + return hidden_states + + +class Step3VisionEncoder(nn.Module): + + def __init__(self, + config: Step3VisionEncoderConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.config = config + self.layers = nn.ModuleList([ + Step3VisionEncoderLayer(config, + quant_config, + prefix=f"{prefix}.layers.{i}") + for i in range(config.num_hidden_layers) + ]) + + def forward( + self, + inputs_embeds, + ): + hidden_states = inputs_embeds + for encoder_layer in self.layers: + hidden_states = encoder_layer(hidden_states) + return hidden_states + + +class Step3VisionTransformer(nn.Module): + + def __init__(self, + config: 
@MULTIMODAL_REGISTRY.register_processor(Step3VLMultiModalProcessor,
                                        info=Step3VLProcessingInfo,
                                        dummy_inputs=Step3VLDummyInputsBuilder)
class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
                                      SupportsPP):
    """Step3 vision-language model: ViT tower + conv downsampler + LLM."""

    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
        "model.": "language_model.model.",
        "lm_head.": "language_model.lm_head.",
    })

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        if modality.startswith("image"):
            return ""

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        self.vision_model = Step3VisionTransformer(config.vision_config,
                                                   None,
                                                   prefix=maybe_prefix(
                                                       prefix, "vision_model"))
        # Two conv stages shrink the ViT token grid before projection
        # into the LLM hidden size.
        self.vit_downsampler = nn.Conv2d(
            config.vision_config.hidden_size,
            config.vision_config.output_hidden_size,
            kernel_size=2,
            stride=config.understand_projector_stride)
        self.vit_downsampler2 = nn.Conv2d(
            config.vision_config.output_hidden_size,
            config.vision_config.output_hidden_size * 2,
            kernel_size=3,
            stride=2,
            padding=1,
        )
        self.vit_large_projector = nn.Linear(
            config.vision_config.output_hidden_size * 2,
            config.hidden_size,
            bias=config.projector_bias,
        )
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"))

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        # Defer to the wrapped LLM's sampler when it has one.
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return get_sampler()

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def dtype(self):
        return next(self.parameters()).dtype

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[Step3VLImageInputs]:
        """Normalize incoming multimodal kwargs into typed image inputs."""
        pixel_values = kwargs.pop("pixel_values", None)
        patch_pixel_values = kwargs.pop("patch_pixel_values", None)
        num_patches = kwargs.pop("num_patches", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            pixel_values = flatten_bn(pixel_values, concat=True)
            if pixel_values.dim() >= 3:
                pixel_values = pixel_values.view(-1, *pixel_values.shape[-3:])
            if patch_pixel_values is not None:
                patch_pixel_values = flatten_bn(patch_pixel_values,
                                                concat=True)
                patch_pixel_values = patch_pixel_values.view(
                    -1, *patch_pixel_values.shape[-3:])
                # An empty patch batch is treated the same as no patches.
                if patch_pixel_values.shape[0] == 0:
                    patch_pixel_values = None
            num_patches = flatten_bn(num_patches, concat=True).tolist()

            return Step3VLImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values.to(self.dtype).to(self.device),
                patch_pixel_values=patch_pixel_values.to(self.dtype).to(
                    self.device) if patch_pixel_values is not None else None,
                num_patches=num_patches,
            )

        # image_embeds is not None here. The original condition
        # (dim == 2 or dim >= 3) was simply dim >= 2; anything flatter
        # is malformed. The trailing unreachable `return None` was removed.
        if image_embeds.dim() < 2:
            raise ValueError(
                f"Unexpected shape for image_embeds: {image_embeds.shape}")
        image_embeds = image_embeds.view(-1, image_embeds.shape[-1])
        return Step3VLImageEmbeddingInputs(
            type="image_embeds",
            image_embeds=image_embeds.to(self.dtype).to(self.device),
        )

    def _process_image_features(self,
                                image_features: torch.Tensor) -> torch.Tensor:
        """Downsample the ViT token grid and project to LLM hidden size."""
        B, P = image_features.shape[:2]
        HW = int(sqrt(P))
        # (B, P, D) -> (B, D, H, W) for the conv downsamplers.
        image_features = image_features.permute(0, 2, 1).view(B, -1, HW, HW)
        image_features = self.vit_downsampler(image_features)
        image_features = self.vit_downsampler2(image_features)
        n_dim = image_features.size(1)
        image_features = image_features.view(B, n_dim, -1).permute(0, 2, 1)
        return self.vit_large_projector(image_features)

    def _get_vision_model_output(self,
                                 input_tensor: torch.Tensor) -> torch.Tensor:
        # Drop the 4 leading padded positions prepended by
        # Step3VisionEmbeddings (pad_tp_size).
        return self.vision_model(input_tensor)[:, 4:]

    def _process_image_input(
            self,
            image_input: Step3VLImageInputs) -> tuple[torch.Tensor, ...]:
        # NOTE(review): the image_embeds branch never sets
        # patch_image_features/num_patches, so precomputed embeddings
        # cannot carry patches — confirm this is intentional.
        if image_input["type"] == "image_embeds":
            image_features = image_input["image_embeds"]
        else:
            image_features = self._get_vision_model_output(
                image_input["pixel_values"])
            patch_image_features = self._get_vision_model_output(
                image_input["patch_pixel_values"]
            ) if image_input["patch_pixel_values"] is not None else None
            num_patches = image_input["num_patches"]

        image_features = self._process_image_features(image_features)
        patch_image_features = self._process_image_features(
            patch_image_features) if patch_image_features is not None else None

        # Interleave: for each image, its patch features (if any) come
        # before its global-view features.
        merged_image_features = []
        cur_patch_idx = 0
        for i, num_patch in enumerate(num_patches):
            cur_feature = []
            if num_patch > 0:
                patch_slice = patch_image_features[
                    cur_patch_idx:cur_patch_idx + num_patch]
                cur_feature.append(patch_slice.view(-1, patch_slice.shape[-1]))
            cur_feature.append(image_features[i].view(
                -1, image_features.shape[-1]))
            cur_patch_idx += num_patch
            merged_image_features.append(
                torch.cat(cur_feature) if len(cur_feature) >
                1 else cur_feature[0])
        return merged_image_features

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        return self._process_image_input(image_input)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        if multimodal_embeddings is None:
            inputs_embeds = self.language_model.model.get_input_embeddings(
                input_ids)
        else:
            # Embed only the text positions, then scatter the image
            # embeddings into the placeholder positions.
            is_text = input_ids != self.config.image_token_id
            text_ids = input_ids[is_text]
            text_embeds = self.language_model.model.get_input_embeddings(
                text_ids)
            inputs_embeds = torch.empty(input_ids.shape[0],
                                        text_embeds.shape[-1],
                                        dtype=text_embeds.dtype,
                                        device=text_embeds.device)
            inputs_embeds[is_text] = text_embeds
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                self.config.image_token_id)
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if intermediate_tensors is not None:
            inputs_embeds = None
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            # always pass the input via `inputs_embeds`
            # to make sure the computation graph is consistent
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.language_model(input_ids,
                                            positions,
                                            intermediate_tensors,
                                            inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)
sampling_metadata) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + return self.language_model.sample(logits, sampling_metadata) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader(self) + loaded_weights = loader.load_weights(weights, + mapper=self.hf_to_vllm_mapper) + return loaded_weights diff --git a/vllm/reasoning/__init__.py b/vllm/reasoning/__init__.py index d61e4f11dfa29..1c3f78f2edbfb 100644 --- a/vllm/reasoning/__init__.py +++ b/vllm/reasoning/__init__.py @@ -8,6 +8,7 @@ from .granite_reasoning_parser import GraniteReasoningParser from .hunyuan_a13b_reasoning_parser import HunyuanA13BReasoningParser from .mistral_reasoning_parser import MistralReasoningParser from .qwen3_reasoning_parser import Qwen3ReasoningParser +from .step3_reasoning_parser import Step3ReasoningParser __all__ = [ "ReasoningParser", @@ -18,4 +19,5 @@ __all__ = [ "Qwen3ReasoningParser", "Glm4MoeModelReasoningParser", "MistralReasoningParser", + "Step3ReasoningParser", ] diff --git a/vllm/reasoning/step3_reasoning_parser.py b/vllm/reasoning/step3_reasoning_parser.py new file mode 100644 index 0000000000000..f642ea977c580 --- /dev/null +++ b/vllm/reasoning/step3_reasoning_parser.py @@ -0,0 +1,109 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Sequence +from typing import Optional, Union + +import regex as re +from transformers import PreTrainedTokenizerBase + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaMessage) +from vllm.logger import init_logger +from vllm.reasoning import ReasoningParser, ReasoningParserManager + +logger = init_logger(__name__) + + +@ReasoningParserManager.register_module("step3") +class Step3ReasoningParser(ReasoningParser): + """ + Reasoning parser for Step3 model. 
+ + The Step3 model uses token to denote the end of reasoning + text. This parser extracts all content before as reasoning content. + """ + + def __init__(self, tokenizer: PreTrainedTokenizerBase): + super().__init__(tokenizer) + self.think_end_token = "" + + self.reasoning_regex = re.compile(rf"(.*?){self.think_end_token}", + re.DOTALL) + + if not self.model_tokenizer: + raise ValueError( + "The model tokenizer must be passed to the ReasoningParser " + "constructor during construction.") + + self.think_end_token_id = self.vocab.get(self.think_end_token) + if self.think_end_token_id is None: + raise RuntimeError( + "Step3 reasoning parser could not locate think end " + "token in the tokenizer!") + + def extract_reasoning_content_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> Union[DeltaMessage, None]: + """ + Extract reasoning content from a delta message. + Handles streaming output where previous + delta = current. + Uses token IDs for faster processing. 
+ For text "abc</think>xyz": + - 'abc' goes to reasoning_content + - 'xyz' goes to content + """ + # Skip single special token + if len(delta_token_ids + ) == 1 and delta_token_ids[0] == self.think_end_token_id: + return None + + if self.think_end_token_id in delta_token_ids: + # </think> in delta, extract reasoning content and remaining content + end_index = delta_text.find(self.think_end_token) + reasoning_content = delta_text[:end_index] + content = delta_text[end_index + len(self.think_end_token):] + return DeltaMessage(reasoning_content=reasoning_content, + content=content if content else None) + elif self.think_end_token_id in previous_token_ids: + # </think> already seen in previous text, everything is content + return DeltaMessage(content=delta_text) + else: + # No </think> seen yet, everything is reasoning + return DeltaMessage(reasoning_content=delta_text) + + def extract_reasoning_content( + self, model_output: str, request: ChatCompletionRequest + ) -> tuple[Optional[str], Optional[str]]: + + # Check if the model output contains the </think> token + if self.think_end_token not in model_output: + # If no </think> token, everything is reasoning content + return model_output, None + else: + # Find the first occurrence of </think> + end_index = model_output.find(self.think_end_token) + reasoning_content = model_output[:end_index] + + # Content after </think> token + content = model_output[end_index + len(self.think_end_token):] + + if len(content) == 0: + content = None + + return reasoning_content, content + + def is_reasoning_end(self, input_ids: list[int]) -> bool: + return self.think_end_token_id in input_ids + + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + if self.think_end_token_id not in input_ids[:-1]: + return [] + else: + return input_ids[input_ids.index(self.think_end_token_id) + 1:] diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 4ce56cb3a6aac..fcaa48c1392a3 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -35,7 
+35,8 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DeepseekVLV2Config, MllamaConfig, MLPSpeculatorConfig, Nemotron_Nano_VL_Config, NemotronConfig, NVLM_D_Config, - RWConfig, UltravoxConfig) + RWConfig, Step3TextConfig, + Step3VLConfig, UltravoxConfig) # yapf: enable from vllm.transformers_utils.configs.mistral import adapt_config_dict from vllm.transformers_utils.utils import check_gguf_file @@ -83,6 +84,8 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = { "nemotron": NemotronConfig, "NVLM_D": NVLM_D_Config, "ultravox": UltravoxConfig, + "step3_vl": Step3VLConfig, + "step3_text": Step3TextConfig, **_CONFIG_REGISTRY_OVERRIDE_HF } diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 7c7d859e4a325..96733da726181 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -24,6 +24,9 @@ from vllm.transformers_utils.configs.nemotron import NemotronConfig from vllm.transformers_utils.configs.nemotron_h import NemotronHConfig from vllm.transformers_utils.configs.nemotron_vl import Nemotron_Nano_VL_Config from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config +from vllm.transformers_utils.configs.step3_vl import (Step3TextConfig, + Step3VisionEncoderConfig, + Step3VLConfig) from vllm.transformers_utils.configs.ultravox import UltravoxConfig __all__ = [ @@ -42,4 +45,7 @@ __all__ = [ "Nemotron_Nano_VL_Config", "NVLM_D_Config", "UltravoxConfig", + "Step3VLConfig", + "Step3VisionEncoderConfig", + "Step3TextConfig", ] diff --git a/vllm/transformers_utils/configs/step3_vl.py b/vllm/transformers_utils/configs/step3_vl.py new file mode 100644 index 0000000000000..fe3c72de69d28 --- /dev/null +++ b/vllm/transformers_utils/configs/step3_vl.py @@ -0,0 +1,123 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any, Optional, Union + +from 
transformers.configuration_utils import PretrainedConfig + + +class Step3VisionEncoderConfig(PretrainedConfig): + model_type = "step3_vision_encoder" + + def __init__( + self, + hidden_size=1792, + intermediate_size=3072, + output_hidden_size=4096, + num_hidden_layers=63, + num_attention_heads=16, + num_channels=3, + image_size=728, + patch_size=14, + hidden_act="quick_gelu", + layer_norm_eps=1e-5, + **kwargs, + ): + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.output_hidden_size = output_hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + super().__init__(**kwargs) + + +class Step3TextConfig(PretrainedConfig): + model_type = "step3_text" + architectures = ["Step3TextForCausalLM"] + + def __init__( + self, + hidden_size: int = 7168, + intermediate_size: int = 18432, + num_attention_heads: int = 64, + num_attention_groups: int = 1, + num_hidden_layers: int = 61, + max_seq_len: int = 65536, + vocab_size: int = 128815, + rms_norm_eps: float = 1e-5, + moe_intermediate_size: int = 5120, + moe_num_experts: int = 48, + moe_top_k: int = 3, + rope_theta: float = 500000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embedding: int = 65536, + share_expert_dim: int = 5120, + share_q_dim: int = 2048, + head_dim: int = 256, + norm_expert_weight: bool = False, + moe_layers_enum: tuple[int, + ...] 
= (4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, + 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, + 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59), + **kwargs, + ) -> None: + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_attention_heads = num_attention_heads + self.num_attention_groups = num_attention_groups + self.num_hidden_layers = num_hidden_layers + self.max_seq_len = max_seq_len + self.vocab_size = vocab_size + self.rms_norm_eps = rms_norm_eps + self.moe_intermediate_size = moe_intermediate_size + self.moe_num_experts = moe_num_experts + self.moe_top_k = moe_top_k + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.max_position_embedding = max_position_embedding + self.share_expert_dim = share_expert_dim + self.share_q_dim = share_q_dim + self.head_dim = head_dim + self.norm_expert_weight = norm_expert_weight + self.moe_layers_enum = moe_layers_enum + + super().__init__(**kwargs) + + +class Step3VLConfig(PretrainedConfig): + model_type = "step3_vl" + + def __init__( + self, + vision_config: Optional[Union[dict, Step3VisionEncoderConfig]] = None, + text_config: Optional[Union[dict, Step3TextConfig]] = None, + understand_projector_stride: int = 1, + projector_bias: bool = True, + image_token_id: int = 128001, + **kwargs, + ) -> None: + if vision_config is None: + vision_config = Step3VisionEncoderConfig() + elif isinstance(vision_config, dict): + vision_config = Step3VisionEncoderConfig(**vision_config) + self.vision_config = vision_config + + if text_config is None: + text_config = Step3TextConfig() + elif isinstance(text_config, dict): + text_config = Step3TextConfig(**text_config) + self.text_config = text_config + + self.understand_projector_stride = understand_projector_stride + self.projector_bias = projector_bias + self.hidden_size = text_config.hidden_size + self.image_token_id = image_token_id + + 
super().__init__(**kwargs) From 7349d5268bf70b7a530c1e649884e4f926615f8e Mon Sep 17 00:00:00 2001 From: Zhengxu Chen Date: Thu, 31 Jul 2025 12:46:07 -0400 Subject: [PATCH 2/9] [ez] Remove a trailing space from compilation/decorators.py (#22028) --- vllm/compilation/decorators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index f3592324d8cfa..1370862d580a5 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -108,7 +108,7 @@ def support_torch_compile( During runtime, when we actually mark dimensions of tensors, it depends on the value of arguments: - - if it is a single integer (can be negative), the corresponding dimension + - if it is a single integer (can be negative), the corresponding dimension of the argument will be marked as dynamic. - if it is `None`, ignored. - if it is `IntermediateTensors`, all the tensors in the intermediate From 58bb902186a87007deeeef2d2af02ed2b13bb182 Mon Sep 17 00:00:00 2001 From: Doug Smith Date: Thu, 31 Jul 2025 12:52:48 -0400 Subject: [PATCH 3/9] fix(setup): improve precompiled wheel setup for Docker builds (#22025) Signed-off-by: dougbtv --- docker/Dockerfile | 1 + requirements/test.txt | 24 +++-- setup.py | 203 ++++++++++++++++++------------------------ 3 files changed, 104 insertions(+), 124 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 43522ef8fb8dd..69aeee67a4300 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -370,6 +370,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ fi # Install vllm wheel first, so that torch etc will be installed. 
+# !bang RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \ --mount=type=cache,target=/root/.cache/uv \ uv pip install --system dist/*.whl --verbose \ diff --git a/requirements/test.txt b/requirements/test.txt index d45048aae5809..4aaca2afea266 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -22,9 +22,7 @@ aiohttp==3.10.11 aiohttp-cors==0.8.1 # via ray aiosignal==1.3.1 - # via - # aiohttp - # ray + # via aiohttp albucore==0.0.16 # via terratorch albumentations==1.4.6 @@ -139,7 +137,7 @@ contourpy==1.3.0 # via matplotlib cramjam==2.9.0 # via fastparquet -cupy-cuda12x==13.3.0 +cupy-cuda12x==13.5.1 # via ray cycler==0.12.1 # via matplotlib @@ -226,7 +224,6 @@ frozenlist==1.5.0 # via # aiohttp # aiosignal - # ray fsspec==2024.9.0 # via # datasets @@ -603,10 +600,18 @@ opencv-python-headless==4.11.0.86 opentelemetry-api==1.35.0 # via # mlflow-skinny + # opentelemetry-exporter-prometheus # opentelemetry-sdk # opentelemetry-semantic-conventions +opentelemetry-exporter-prometheus==0.56b0 + # via ray +opentelemetry-proto==1.36.0 + # via ray opentelemetry-sdk==1.35.0 - # via mlflow-skinny + # via + # mlflow-skinny + # opentelemetry-exporter-prometheus + # ray opentelemetry-semantic-conventions==0.56b0 # via opentelemetry-sdk packaging==24.2 @@ -697,7 +702,9 @@ pqdm==0.2.0 pretrainedmodels==0.7.4 # via segmentation-models-pytorch prometheus-client==0.22.0 - # via ray + # via + # opentelemetry-exporter-prometheus + # ray propcache==0.2.0 # via yarl proto-plus==1.26.1 @@ -707,6 +714,7 @@ protobuf==5.28.3 # google-api-core # googleapis-common-protos # mlflow-skinny + # opentelemetry-proto # proto-plus # ray # tensorboardx @@ -854,7 +862,7 @@ rasterio==1.4.3 # rioxarray # terratorch # torchgeo -ray==2.43.0 +ray==2.48.0 # via -r requirements/test.in redis==5.2.0 # via tensorizer diff --git a/setup.py b/setup.py index bf3391e2db19e..6d615d122d69e 100644 --- a/setup.py +++ b/setup.py @@ -282,10 +282,69 @@ class 
cmake_build_ext(build_ext): self.copy_file(file, dst_file) -class repackage_wheel(build_ext): +class precompiled_wheel_utils: """Extracts libraries and other files from an existing wheel.""" - def get_base_commit_in_main_branch(self) -> str: + @staticmethod + def extract_precompiled_and_patch_package(wheel_url_or_path: str) -> dict: + import tempfile + import zipfile + + temp_dir = None + try: + if not os.path.isfile(wheel_url_or_path): + wheel_filename = wheel_url_or_path.split("/")[-1] + temp_dir = tempfile.mkdtemp(prefix="vllm-wheels") + wheel_path = os.path.join(temp_dir, wheel_filename) + print(f"Downloading wheel from {wheel_url_or_path} " + f"to {wheel_path}") + from urllib.request import urlretrieve + urlretrieve(wheel_url_or_path, filename=wheel_path) + else: + wheel_path = wheel_url_or_path + print(f"Using existing wheel at {wheel_path}") + + package_data_patch = {} + + with zipfile.ZipFile(wheel_path) as wheel: + files_to_copy = [ + "vllm/_C.abi3.so", + "vllm/_moe_C.abi3.so", + "vllm/_flashmla_C.abi3.so", + "vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so", + "vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so", + "vllm/cumem_allocator.abi3.so", + ] + + compiled_regex = re.compile( + r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py") + file_members = list( + filter(lambda x: x.filename in files_to_copy, + wheel.filelist)) + file_members += list( + filter(lambda x: compiled_regex.match(x.filename), + wheel.filelist)) + + for file in file_members: + print(f"[extract] {file.filename}") + target_path = os.path.join(".", file.filename) + os.makedirs(os.path.dirname(target_path), exist_ok=True) + with wheel.open(file.filename) as src, open( + target_path, "wb") as dst: + shutil.copyfileobj(src, dst) + + pkg = os.path.dirname(file.filename).replace("/", ".") + package_data_patch.setdefault(pkg, []).append( + os.path.basename(file.filename)) + + return package_data_patch + finally: + if temp_dir is not None: + print(f"Removing temporary directory {temp_dir}") + 
shutil.rmtree(temp_dir) + + @staticmethod + def get_base_commit_in_main_branch() -> str: # Force to use the nightly wheel. This is mainly used for CI testing. if envs.VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL: return "nightly" @@ -334,115 +393,6 @@ class repackage_wheel(build_ext): "wheel may not be compatible with your dev branch: %s", err) return "nightly" - def run(self) -> None: - assert _is_cuda( - ), "VLLM_USE_PRECOMPILED is only supported for CUDA builds" - - wheel_location = os.getenv("VLLM_PRECOMPILED_WHEEL_LOCATION", None) - if wheel_location is None: - base_commit = self.get_base_commit_in_main_branch() - wheel_location = f"https://wheels.vllm.ai/{base_commit}/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl" - # Fallback to nightly wheel if latest commit wheel is unavailable, - # in this rare case, the nightly release CI hasn't finished on main. - if not is_url_available(wheel_location): - wheel_location = "https://wheels.vllm.ai/nightly/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl" - - import zipfile - - if os.path.isfile(wheel_location): - wheel_path = wheel_location - print(f"Using existing wheel={wheel_path}") - else: - # Download the wheel from a given URL, assume - # the filename is the last part of the URL - wheel_filename = wheel_location.split("/")[-1] - - import tempfile - - # create a temporary directory to store the wheel - temp_dir = tempfile.mkdtemp(prefix="vllm-wheels") - wheel_path = os.path.join(temp_dir, wheel_filename) - print(f"Downloading wheel from {wheel_location} to {wheel_path}") - from urllib.request import urlretrieve - try: - urlretrieve(wheel_location, filename=wheel_path) - except Exception as e: - from setuptools.errors import SetupError - raise SetupError( - f"Failed to get vLLM wheel from {wheel_location}") from e - - # Set the dist_dir for Docker build context - dist_dir = ("/workspace/dist" - if envs.VLLM_DOCKER_BUILD_CONTEXT else "dist") - os.makedirs(dist_dir, exist_ok=True) - - # Extract only necessary compiled .so 
files from precompiled wheel - with zipfile.ZipFile(wheel_path) as wheel: - # Get version from METADATA (optional, mostly useful for logging) - metadata_file = next((n for n in wheel.namelist() - if n.endswith(".dist-info/METADATA")), None) - if not metadata_file: - raise RuntimeError( - "Could not find METADATA in precompiled wheel.") - metadata = wheel.read(metadata_file).decode() - version_line = next((line for line in metadata.splitlines() - if line.startswith("Version: ")), None) - if not version_line: - raise RuntimeError( - "Could not determine version from METADATA.") - version = version_line.split(": ")[1].strip() - - print(f"Extracting precompiled kernels from vLLM wheel version: " - f"{version}") - - # List of compiled shared objects to extract - files_to_copy = [ - "vllm/_C.abi3.so", - "vllm/_moe_C.abi3.so", - "vllm/_flashmla_C.abi3.so", - "vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so", - "vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so", - "vllm/cumem_allocator.abi3.so", - ] - - file_members = list( - filter(lambda x: x.filename in files_to_copy, wheel.filelist)) - compiled_regex = re.compile( - r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py") - file_members += list( - filter(lambda x: compiled_regex.match(x.filename), - wheel.filelist)) - - for file in file_members: - print(f"Extracting and including {file.filename} " - "from existing wheel") - package_name = os.path.dirname(file.filename).replace("/", ".") - file_name = os.path.basename(file.filename) - - if package_name not in package_data: - package_data[package_name] = [] - - output_base = (dist_dir - if envs.VLLM_DOCKER_BUILD_CONTEXT else ".") - target_path = os.path.join(output_base, file.filename) - os.makedirs(os.path.dirname(target_path), exist_ok=True) - with wheel.open(file.filename) as src, open(target_path, - "wb") as dst: - shutil.copyfileobj(src, dst) - - package_data[package_name].append(file_name) - - # Copy wheel into dist dir for Docker to consume (e.g., via --mount) - if 
envs.VLLM_DOCKER_BUILD_CONTEXT: - arch_tag = "cp38-abi3-manylinux1_x86_64" - corrected_wheel_name = f"vllm-{version}-{arch_tag}.whl" - final_wheel_path = os.path.join(dist_dir, corrected_wheel_name) - - print( - "Docker build context detected, copying precompiled wheel to " - f"{final_wheel_path}") - shutil.copy2(wheel_path, final_wheel_path) - def _no_device() -> bool: return VLLM_TARGET_DEVICE == "empty" @@ -676,16 +626,37 @@ package_data = { ] } +# If using precompiled, extract and patch package_data (in advance of setup) +if envs.VLLM_USE_PRECOMPILED: + assert _is_cuda(), "VLLM_USE_PRECOMPILED is only supported for CUDA builds" + wheel_location = os.getenv("VLLM_PRECOMPILED_WHEEL_LOCATION", None) + if wheel_location is not None: + wheel_url = wheel_location + else: + base_commit = precompiled_wheel_utils.get_base_commit_in_main_branch() + wheel_url = f"https://wheels.vllm.ai/{base_commit}/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl" + from urllib.request import urlopen + try: + with urlopen(wheel_url) as resp: + if resp.status != 200: + wheel_url = "https://wheels.vllm.ai/nightly/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl" + except Exception as e: + print(f"[warn] Falling back to nightly wheel: {e}") + wheel_url = "https://wheels.vllm.ai/nightly/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl" + + patch = precompiled_wheel_utils.extract_precompiled_and_patch_package( + wheel_url) + for pkg, files in patch.items(): + package_data.setdefault(pkg, []).extend(files) + if _no_device(): ext_modules = [] -if not ext_modules: +if not ext_modules or envs.VLLM_USE_PRECOMPILED: + # Disable build_ext when using precompiled wheel cmdclass = {} else: - cmdclass = { - "build_ext": - repackage_wheel if envs.VLLM_USE_PRECOMPILED else cmake_build_ext - } + cmdclass = {"build_ext": cmake_build_ext} setup( # static metadata should rather go in pyproject.toml From 0780bb57835dcd9ee666aaf807c37086de67422b Mon Sep 17 00:00:00 2001 From: Alexei-V-Ivanov-AMD 
<156011006+Alexei-V-Ivanov-AMD@users.noreply.github.com> Date: Thu, 31 Jul 2025 11:53:27 -0500 Subject: [PATCH 4/9] Removing amdproduction Tests (#22027) Signed-off-by: Alexei V. Ivanov --- .buildkite/test-pipeline.yaml | 46 +++++++++++++++++------------------ 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 2bf0b6fd9a169..a7fe200559305 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -82,7 +82,7 @@ steps: - bash standalone_tests/python_only_compile.sh - label: Basic Correctness Test # 30min - mirror_hardwares: [amdexperimental, amdproduction] + mirror_hardwares: [amdexperimental] fast_check: true torch_nightly: true source_file_dependencies: @@ -99,7 +99,7 @@ steps: - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py - label: Chunked Prefill Test - mirror_hardwares: [amdexperimental, amdproduction] + mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ - tests/basic_correctness/test_chunked_prefill @@ -108,7 +108,7 @@ steps: - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py - label: Core Test # 10min - mirror_hardwares: [amdexperimental, amdproduction] + mirror_hardwares: [amdexperimental] fast_check: true source_file_dependencies: - vllm/core @@ -209,7 +209,7 @@ steps: - pytest -v -s distributed/test_eplb_execute.py - label: Metrics, Tracing Test # 10min - mirror_hardwares: [amdexperimental, amdproduction] + mirror_hardwares: [amdexperimental] num_gpus: 2 source_file_dependencies: - vllm/ @@ -228,7 +228,7 @@ steps: ##### 1 GPU test ##### - label: Regression Test # 5min - mirror_hardwares: [amdexperimental, amdproduction] + mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ - tests/test_regression @@ -280,7 +280,7 @@ steps: - pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine - label: Examples Test 
# 25min - mirror_hardwares: [amdexperimental, amdproduction] + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/examples" source_file_dependencies: - vllm/entrypoints @@ -305,7 +305,7 @@ steps: - VLLM_USE_V1=0 python3 offline_inference/profiling.py --model facebook/opt-125m run_num_steps --num-steps 2 - label: Prefix Caching Test # 9min - mirror_hardwares: [amdexperimental, amdproduction] + mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ - tests/prefix_caching @@ -314,7 +314,7 @@ steps: - label: Platform Tests (CUDA) - mirror_hardwares: [amdexperimental, amdproduction] + mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ - tests/cuda @@ -355,7 +355,7 @@ steps: - pytest -v -s compile/test_async_tp.py - label: PyTorch Fullgraph Smoke Test # 9min - mirror_hardwares: [amdexperimental, amdproduction] + mirror_hardwares: [amdexperimental] torch_nightly: true source_file_dependencies: - vllm/ @@ -368,7 +368,7 @@ steps: - pytest -v -s compile/piecewise/test_full_cudagraph.py - label: PyTorch Fullgraph Test # 18min - mirror_hardwares: [amdexperimental, amdproduction] + mirror_hardwares: [amdexperimental] torch_nightly: true source_file_dependencies: - vllm/ @@ -377,7 +377,7 @@ steps: - pytest -v -s compile/test_full_graph.py - label: Kernels Core Operation Test - mirror_hardwares: [amdexperimental, amdproduction] + mirror_hardwares: [amdexperimental] source_file_dependencies: - csrc/ - tests/kernels/core @@ -416,7 +416,7 @@ steps: parallelism: 2 - label: Kernels Mamba Test - mirror_hardwares: [amdexperimental, amdproduction] + mirror_hardwares: [amdexperimental] source_file_dependencies: - csrc/mamba/ - tests/kernels/mamba @@ -424,7 +424,7 @@ steps: - pytest -v -s kernels/mamba - label: Tensorizer Test # 11min - mirror_hardwares: [amdexperimental, amdproduction] + mirror_hardwares: [amdexperimental] soft_fail: true source_file_dependencies: - vllm/model_executor/model_loader @@ -437,7 +437,7 @@ steps: - pytest -v -s 
entrypoints/openai/test_tensorizer_entrypoint.py - label: Model Executor Test - mirror_hardwares: [amdexperimental, amdproduction] + mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/model_executor - tests/model_executor @@ -447,7 +447,7 @@ steps: - pytest -v -s model_executor - label: Benchmarks # 9min - mirror_hardwares: [amdexperimental, amdproduction] + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/.buildkite" source_file_dependencies: - benchmarks/ @@ -455,7 +455,7 @@ steps: - bash scripts/run-benchmarks.sh - label: Benchmarks CLI Test # 10min - mirror_hardwares: [amdexperimental, amdproduction] + mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ - tests/benchmarks/ @@ -494,7 +494,7 @@ steps: - pytest -s entrypoints/openai/correctness/ - label: Encoder Decoder tests # 5min - mirror_hardwares: [amdexperimental, amdproduction] + mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ - tests/encoder_decoder @@ -502,7 +502,7 @@ steps: - pytest -v -s encoder_decoder - label: OpenAI-Compatible Tool Use # 20 min - mirror_hardwares: [amdexperimental, amdproduction] + mirror_hardwares: [amdexperimental] fast_check: false source_file_dependencies: - vllm/ @@ -623,7 +623,7 @@ steps: # This test is used only in PR development phase to test individual models and should never run on main - label: Custom Models Test - mirror_hardwares: [amdexperimental, amdproduction] + mirror_hardwares: [amdexperimental] optional: true commands: - echo 'Testing custom models...' 
@@ -658,7 +658,7 @@ steps: ##### multi gpus test ##### - label: Distributed Comm Ops Test # 7min - mirror_hardwares: [amdexperimental, amdproduction] + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 2 source_file_dependencies: @@ -755,7 +755,7 @@ steps: - pytest -v -s plugins/lora_resolvers # unit tests for in-tree lora resolver plugins - label: Multi-step Tests (4 GPUs) # 36min - mirror_hardwares: [amdexperimental, amdproduction] + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 4 source_file_dependencies: @@ -776,7 +776,7 @@ steps: - pytest -v -s multi_step/test_correctness_llm.py - label: Pipeline Parallelism Test # 45min - mirror_hardwares: [amdexperimental, amdproduction] + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 4 source_file_dependencies: @@ -790,7 +790,7 @@ steps: - pytest -v -s distributed/test_pipeline_parallel.py - label: LoRA TP Test (Distributed) - mirror_hardwares: [amdexperimental, amdproduction] + mirror_hardwares: [amdexperimental] num_gpus: 4 source_file_dependencies: - vllm/lora From 53c21e492e0acd140a9984c8ec7cc3a7123efee5 Mon Sep 17 00:00:00 2001 From: XiongfeiWei Date: Thu, 31 Jul 2025 10:26:43 -0700 Subject: [PATCH 5/9] Update torch_xla pin to 20250730 (#21956) Signed-off-by: Xiongfei Wei --- docker/Dockerfile.tpu | 2 +- requirements/tpu.txt | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docker/Dockerfile.tpu b/docker/Dockerfile.tpu index b9fc9def88190..2190151369761 100644 --- a/docker/Dockerfile.tpu +++ b/docker/Dockerfile.tpu @@ -1,4 +1,4 @@ -ARG NIGHTLY_DATE="20250724" +ARG NIGHTLY_DATE="20250730" ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.12_tpuvm_$NIGHTLY_DATE" FROM $BASE_IMAGE diff --git a/requirements/tpu.txt b/requirements/tpu.txt index 2d0d8bd8457e3..7bb77c4a99636 100644 --- a/requirements/tpu.txt +++ b/requirements/tpu.txt @@ -19,8 +19,8 @@ 
nixl==0.3.0 --find-links https://storage.googleapis.com/libtpu-releases/index.html --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html -torch==2.9.0.dev20250724 -torchvision==0.24.0.dev20250724 -torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250724-cp311-cp311-linux_x86_64.whl ; python_version == "3.11" -torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250724-cp312-cp312-linux_x86_64.whl ; python_version == "3.12" +torch==2.9.0.dev20250730 +torchvision==0.24.0.dev20250730 +torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250730-cp311-cp311-linux_x86_64.whl ; python_version == "3.11" +torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250730-cp312-cp312-linux_x86_64.whl ; python_version == "3.12" From 9e0726e5bfd201fa2c9209e3997d24c72ecc3b13 Mon Sep 17 00:00:00 2001 From: zhiweiz Date: Thu, 31 Jul 2025 10:35:07 -0700 Subject: [PATCH 6/9] [Meta] Official Eagle mm support, first enablement on llama4 (#20788) Signed-off-by: morgendave Co-authored-by: Roger Wang --- examples/offline_inference/spec_decode.py | 64 ++++++++++++++++++++-- tests/v1/e2e/test_spec_decode.py | 61 +++++++++++++++------ vllm/model_executor/models/llama4.py | 1 + vllm/model_executor/models/llama4_eagle.py | 35 ++++++++++-- vllm/model_executor/models/llama_eagle.py | 6 ++ vllm/model_executor/models/llama_eagle3.py | 5 ++ vllm/v1/spec_decode/eagle.py | 59 +++++++++++++++++--- vllm/v1/worker/gpu_model_runner.py | 10 +++- 8 files changed, 205 insertions(+), 36 deletions(-) diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py index ce735f3b27dfe..184c30891eca7 100644 --- 
a/examples/offline_inference/spec_decode.py +++ b/examples/offline_inference/spec_decode.py @@ -13,6 +13,38 @@ except ImportError: from argparse import ArgumentParser as FlexibleArgumentParser +QUESTION = "What is the content of each image?" +IMAGE_URLS = [ + "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg", + "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg", + "https://upload.wikimedia.org/wikipedia/commons/2/26/Ultramarine_Flycatcher_%28Ficedula_superciliaris%29_Naggar%2C_Himachal_Pradesh%2C_2013_%28cropped%29.JPG", + "https://upload.wikimedia.org/wikipedia/commons/thumb/e/e5/Anim1754_-_Flickr_-_NOAA_Photo_Library_%281%29.jpg/2560px-Anim1754_-_Flickr_-_NOAA_Photo_Library_%281%29.jpg", + "https://upload.wikimedia.org/wikipedia/commons/d/d4/Starfish%2C_Caswell_Bay_-_geograph.org.uk_-_409413.jpg", + "https://upload.wikimedia.org/wikipedia/commons/6/69/Grapevinesnail_01.jpg", + "https://upload.wikimedia.org/wikipedia/commons/thumb/0/0b/Texas_invasive_Musk_Thistle_1.jpg/1920px-Texas_invasive_Musk_Thistle_1.jpg", + "https://upload.wikimedia.org/wikipedia/commons/thumb/7/7a/Huskiesatrest.jpg/2880px-Huskiesatrest.jpg", + "https://upload.wikimedia.org/wikipedia/commons/thumb/6/68/Orange_tabby_cat_sitting_on_fallen_leaves-Hisashi-01A.jpg/1920px-Orange_tabby_cat_sitting_on_fallen_leaves-Hisashi-01A.jpg", + "https://upload.wikimedia.org/wikipedia/commons/3/30/George_the_amazing_guinea_pig.jpg", + "https://upload.wikimedia.org/wikipedia/commons/thumb/1/1f/Oryctolagus_cuniculus_Rcdo.jpg/1920px-Oryctolagus_cuniculus_Rcdo.jpg", + "https://upload.wikimedia.org/wikipedia/commons/9/98/Horse-and-pony.jpg", +] + + +def get_custom_mm_prompts(num_prompts): + prompts = [] + for url in IMAGE_URLS: + prompts.append( + [ + {"type": "image_url", "image_url": {"url": url}}, + {"type": "text", "text": QUESTION}, + ] + ) + if num_prompts > 
len(IMAGE_URLS): + prompts = prompts * (num_prompts // len(IMAGE_URLS) + 1) + + return [[{"role": "user", "content": prompt}] for prompt in prompts[:num_prompts]] + + def parse_args(): parser = FlexibleArgumentParser() add_dataset_parser(parser) @@ -35,6 +67,7 @@ def parse_args(): parser.add_argument("--output-len", type=int, default=256) parser.add_argument("--model-dir", type=str, default=None) parser.add_argument("--eagle-dir", type=str, default=None) + parser.add_argument("--custom-mm-prompts", action="store_true") return parser.parse_args() @@ -44,14 +77,26 @@ def main(): model_dir = args.model_dir if args.model_dir is None: + if args.custom_mm_prompts: + raise ValueError( + "custom_mm_prompts requires mm based models" + "default llama3.1-8b-instruct is not mm based" + "please specify model_dir to give a mm based model" + ) model_dir = "meta-llama/Llama-3.1-8B-Instruct" tokenizer = AutoTokenizer.from_pretrained(model_dir) + args.custom_skip_chat_template = True - prompts = get_samples(args, tokenizer) - # add_special_tokens is False to avoid adding bos twice when using chat templates - prompt_ids = [ - tokenizer.encode(prompt.prompt, add_special_tokens=False) for prompt in prompts - ] + if not args.custom_mm_prompts: + prompts = get_samples(args, tokenizer) + # add_special_tokens is False to avoid adding bos twice + # when using chat templates + prompt_ids = [ + tokenizer.encode(prompt.prompt, add_special_tokens=False) + for prompt in prompts + ] + else: + prompts = get_custom_mm_prompts(args.num_prompts) if args.method == "eagle" or args.method == "eagle3": eagle_dir = args.eagle_dir @@ -85,10 +130,17 @@ def main(): speculative_config=speculative_config, disable_log_stats=False, max_model_len=16384, + limit_mm_per_prompt={"image": 5}, + disable_chunked_mm_input=True, ) sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len) - outputs = llm.generate(prompt_token_ids=prompt_ids, sampling_params=sampling_params) + if not 
args.custom_mm_prompts: + outputs = llm.generate( + prompt_token_ids=prompt_ids, sampling_params=sampling_params + ) + else: + outputs = llm.chat(prompts, sampling_params=sampling_params) # print the generated text if args.print_output: diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 2423f966acfab..31f25e94c5b4b 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -3,29 +3,34 @@ from __future__ import annotations import random -from typing import Any +from typing import Any, Union import pytest import torch from vllm import LLM, SamplingParams +from vllm.assets.base import VLLM_S3_BUCKET_URL +from vllm.assets.image import VLM_IMAGES_DIR from vllm.distributed import cleanup_dist_env_and_memory -@pytest.fixture -def test_prompts(): +def get_test_prompts(mm_enabled: bool): prompt_types = ["repeat", "sentence"] + if mm_enabled: + prompt_types.append("mm") num_prompts = 100 prompts = [] random.seed(0) random_prompt_type_choices = random.choices(prompt_types, k=num_prompts) + print(f"Prompt types: {random_prompt_type_choices}") # Generate a mixed batch of prompts, some of which can be easily # predicted by n-gram matching and some which likely cannot. for kind in random_prompt_type_choices: word_choices = ["test", "temp", "hello", "where"] word = random.choice(word_choices) + prompt: Union[str, list[dict[str, Any]]] = "" if kind == "repeat": prompt = f""" please repeat the word '{word}' 10 times. @@ -38,6 +43,21 @@ def test_prompts(): uses the word {word} at least once. give no other output than that simple sentence without quotes. 
""" + elif kind == "mm": + placeholders = [{ + "type": "image_url", + "image_url": { + "url": + f"{VLLM_S3_BUCKET_URL}/{VLM_IMAGES_DIR}/stop_sign.jpg" + }, + }] + prompt = [ + *placeholders, + { + "type": "text", + "text": "The meaning of the image is" + }, + ] else: raise ValueError(f"Unknown prompt type: {kind}") prompts.append([{"role": "user", "content": prompt}]) @@ -57,7 +77,6 @@ def model_name(): def test_ngram_correctness( monkeypatch: pytest.MonkeyPatch, - test_prompts: list[list[dict[str, Any]]], sampling_config: SamplingParams, model_name: str, ): @@ -67,6 +86,7 @@ def test_ngram_correctness( ''' with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") + test_prompts = get_test_prompts(mm_enabled=False) ref_llm = LLM(model=model_name, max_model_len=1024) ref_outputs = ref_llm.chat(test_prompts, sampling_config) @@ -103,23 +123,32 @@ def test_ngram_correctness( cleanup_dist_env_and_memory() -@pytest.mark.parametrize("model_setup", [ - ("eagle", "meta-llama/Llama-3.1-8B-Instruct", - "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), - ("eagle3", "meta-llama/Llama-3.1-8B-Instruct", - "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), - pytest.param( - ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", - "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), - marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), -], - ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle"]) +@pytest.mark.parametrize( + ["model_setup", "mm_enabled"], [ + (("eagle", "meta-llama/Llama-3.1-8B-Instruct", + "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False), + (("eagle3", "meta-llama/Llama-3.1-8B-Instruct", + "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False), + pytest.param( + ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", + "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), + False, + marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), + pytest.param( + ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", + 
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), + True, + marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), + ], + ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle", "llama4_eagle_mm"]) def test_eagle_correctness( monkeypatch: pytest.MonkeyPatch, - test_prompts: list[list[dict[str, Any]]], sampling_config: SamplingParams, model_setup: tuple[str, str, str, int], + mm_enabled: bool, ): + # Generate test prompts inside the function instead of using fixture + test_prompts = get_test_prompts(mm_enabled) ''' Compare the outputs of a original LLM and a speculative LLM should be the same when using eagle speculative decoding. diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py index 470e701d98013..60098209c39ac 100644 --- a/vllm/model_executor/models/llama4.py +++ b/vllm/model_executor/models/llama4.py @@ -256,6 +256,7 @@ class Llama4DecoderLayer(nn.Module): super().__init__() self.layer_idx = extract_layer_index(prefix) + self.global_layer = config.no_rope_layers[self.layer_idx] == 0 self.hidden_size = config.hidden_size rope_theta = config.rope_theta rope_scaling = config.rope_scaling diff --git a/vllm/model_executor/models/llama4_eagle.py b/vllm/model_executor/models/llama4_eagle.py index 222ab5dfaee4a..ece490ff2f2a8 100644 --- a/vllm/model_executor/models/llama4_eagle.py +++ b/vllm/model_executor/models/llama4_eagle.py @@ -37,8 +37,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.llama4 import (Llama4DecoderLayer, Llama4ForCausalLM) from vllm.model_executor.models.utils import extract_layer_index +from vllm.multimodal.inputs import NestedTensors -from .utils import AutoWeightsLoader, maybe_prefix +from .utils import AutoWeightsLoader, maybe_prefix, merge_multimodal_embeddings logger = init_logger(__name__) @@ -78,15 +79,23 @@ class LlamaModel(nn.Module): self.norm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) + def 
get_input_embeddings( + self, + input_ids: torch.Tensor, + ) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: Optional[torch.Tensor], positions: torch.Tensor, hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: - input_embeds = self.embed_tokens(input_ids) + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings(input_ids) hidden_states = self.fc( - torch.cat((input_embeds, hidden_states), dim=-1)) + torch.cat((inputs_embeds, hidden_states), dim=-1)) residual = None for layer in self.layers: hidden_states, residual = layer( @@ -190,8 +199,9 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM): input_ids: torch.Tensor, positions: torch.Tensor, hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: - return self.model(input_ids, positions, hidden_states) + return self.model(input_ids, positions, hidden_states, inputs_embeds) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> None: @@ -212,3 +222,20 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM): model_weights[name] = loaded_weight loader.load_weights(model_weights.items()) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + inputs_embeds = self.model.get_input_embeddings(input_ids) + + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + multimodal_embeddings, + self.config.image_token_index, + ) + + return inputs_embeds diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py index c7690604c1d09..a4933b77e3a53 100644 --- a/vllm/model_executor/models/llama_eagle.py +++ b/vllm/model_executor/models/llama_eagle.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import 
Iterable +from typing import Optional import torch import torch.nn as nn @@ -148,7 +149,12 @@ class EagleLlamaForCausalLM(LlamaForCausalLM): input_ids: torch.Tensor, positions: torch.Tensor, hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: + if inputs_embeds is not None: + raise NotImplementedError( + f"{type(self).__name__} does not support multimodal inputs yet." + ) return self.model(input_ids, positions, hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py index 7fc9fe2ebb6f6..71275f0d58579 100644 --- a/vllm/model_executor/models/llama_eagle3.py +++ b/vllm/model_executor/models/llama_eagle3.py @@ -202,7 +202,12 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM): input_ids: torch.Tensor, positions: torch.Tensor, hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: + if inputs_embeds is not None: + raise NotImplementedError( + f"{type(self).__name__} does not support multimodal inputs yet." + ) return self.model(input_ids, positions, hidden_states) def compute_logits( diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 63f6fc276189d..302126dbe3d5f 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + import numpy as np import torch import torch.nn as nn @@ -51,6 +53,9 @@ class EagleProposer: # hidden size (e.g., Llama 3.3 70B). 
self.hidden_size = self.draft_model_config.get_hidden_size() + self.is_multimodal_model = vllm_config.model_config \ + .is_multimodal_model + self.use_cuda_graph = (self.vllm_config.compilation_config.level == CompilationLevel.PIECEWISE and not self.vllm_config.model_config.enforce_eager) @@ -76,6 +81,11 @@ class EagleProposer: device=device, dtype=torch.int32) + self.inputs_embeds = torch.zeros( + (self.max_num_tokens, self.hidden_size), + dtype=self.dtype, + device=device) + def propose( self, # [num_tokens] @@ -88,6 +98,7 @@ class EagleProposer: next_token_ids: torch.Tensor, common_attn_metadata: CommonAttentionMetadata, sampling_metadata: SamplingMetadata, + mm_embeds: Optional[list[torch.Tensor]] = None, ) -> torch.Tensor: num_tokens = target_token_ids.shape[0] batch_size = next_token_ids.shape[0] @@ -128,14 +139,27 @@ class EagleProposer: # copy inputs to buffer for cudagraph self.positions[:num_tokens] = target_positions self.hidden_states[:num_tokens] = target_hidden_states + if self.is_multimodal_model: + input_ids = self.input_ids[:num_tokens] + inputs_embeds = self.model.get_input_embeddings( + input_ids, + multimodal_embeddings=mm_embeds or None, + ) + self.inputs_embeds[:num_tokens] = inputs_embeds + inputs_embeds = self.inputs_embeds[:num_input_tokens] + input_ids = None + else: + inputs_embeds = None + input_ids = self.input_ids[:num_input_tokens] with set_forward_context(per_layer_attn_metadata, self.vllm_config, num_tokens=num_input_tokens): ret_hidden_states = self.model( - self.input_ids[:num_input_tokens], - self.positions[:num_input_tokens], - self.hidden_states[:num_input_tokens], + input_ids=input_ids, + positions=self.positions[:num_input_tokens], + hidden_states=self.hidden_states[:num_input_tokens], + inputs_embeds=inputs_embeds, ) if self.method == "deepseek_mtp": last_hidden_states = ret_hidden_states @@ -218,15 +242,24 @@ class EagleProposer: self.input_ids[:batch_size] = input_ids self.positions[:batch_size] = clamped_positions 
self.hidden_states[:batch_size] = hidden_states + if self.is_multimodal_model: + inputs_embeds = self.model.get_input_embeddings(input_ids) + self.inputs_embeds[:batch_size] = inputs_embeds + inputs_embeds = self.inputs_embeds[:input_batch_size] + input_ids = None + else: + inputs_embeds = None + input_ids = self.input_ids[:input_batch_size] # Run the model. with set_forward_context(per_layer_attn_metadata, self.vllm_config, num_tokens=input_batch_size): last_hidden_states, hidden_states = self.model( - self.input_ids[:input_batch_size], - self.positions[:input_batch_size], - self.hidden_states[:input_batch_size], + input_ids=input_ids, + positions=self.positions[:input_batch_size], + hidden_states=self.hidden_states[:input_batch_size], + inputs_embeds=inputs_embeds, ) hidden_states = hidden_states[:batch_size] logits = self.model.compute_logits(last_hidden_states[:batch_size], @@ -391,10 +424,18 @@ class EagleProposer: ) -> None: with set_forward_context(None, self.vllm_config, num_tokens=num_tokens): + if self.is_multimodal_model: + input_ids = None + inputs_embeds = self.inputs_embeds[:num_tokens] + else: + input_ids = self.input_ids[:num_tokens] + inputs_embeds = None + self.model( - self.input_ids[:num_tokens], - self.positions[:num_tokens], - self.hidden_states[:num_tokens], + input_ids=input_ids, + positions=self.positions[:num_tokens], + hidden_states=self.hidden_states[:num_tokens], + inputs_embeds=inputs_embeds, ) def validate_same_kv_cache_group(self, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 987ef22a1b7fb..29cda4d837bf3 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1205,13 +1205,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): def _gather_mm_embeddings( self, scheduler_output: "SchedulerOutput", + shift_computed_tokens: int = 0, ) -> list[torch.Tensor]: mm_embeds: list[torch.Tensor] = [] for req_id in self.input_batch.req_ids: 
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ req_id] req_state = self.requests[req_id] - num_computed_tokens = req_state.num_computed_tokens + num_computed_tokens = \ + req_state.num_computed_tokens + shift_computed_tokens mm_positions = req_state.mm_positions for i, pos_info in enumerate(mm_positions): start_pos = pos_info.offset @@ -1858,6 +1860,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): [h[token_indices] for h in aux_hidden_states], dim=-1) else: target_hidden_states = hidden_states[token_indices] + mm_embeds = None + if self.is_multimodal_model: + mm_embeds = self._gather_mm_embeddings(scheduler_output, + shift_computed_tokens=1) + draft_token_ids = self.drafter.propose( target_token_ids=target_token_ids, target_positions=target_positions, @@ -1865,6 +1872,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): next_token_ids=next_token_ids, sampling_metadata=sampling_metadata, common_attn_metadata=common_attn_metadata, + mm_embeds=mm_embeds, ) spec_token_ids = draft_token_ids.tolist() return spec_token_ids From 71470bc4afdab89eccc232b668a69571ffede1dc Mon Sep 17 00:00:00 2001 From: Yong Hoon Shin <48474650+sarckk@users.noreply.github.com> Date: Thu, 31 Jul 2025 11:39:16 -0700 Subject: [PATCH 7/9] [Misc] Add unit tests for chunked local attention (#21692) Signed-off-by: Yong Hoon Shin --- .../attention/test_chunked_local_attention.py | 196 ++++++++++++++++++ tests/v1/attention/utils.py | 36 ++-- 2 files changed, 219 insertions(+), 13 deletions(-) create mode 100644 tests/v1/attention/test_chunked_local_attention.py diff --git a/tests/v1/attention/test_chunked_local_attention.py b/tests/v1/attention/test_chunked_local_attention.py new file mode 100644 index 0000000000000..8c5a63653db9f --- /dev/null +++ b/tests/v1/attention/test_chunked_local_attention.py @@ -0,0 +1,196 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from 
dataclasses import dataclass + +import numpy as np +import pytest +import torch + +from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata +from vllm.v1.attention.backends.utils import ( + make_local_attention_virtual_batches) + + +@dataclass +class LocalAttentionTestData: + # Input parameters + batch_spec: BatchSpec + attn_chunk_size: int + block_size: int + # Expected return values + expected_q_seqlens: list[int] + expected_k_seqlens: list[int] + expected_local_block_table: list[list[int]] + + +test_data_list = [ + # Same as example in docstring of make_local_attention_virtual_batches + # except block table has 9 columns instead of 10 + LocalAttentionTestData( + batch_spec=BatchSpec( + query_lens=[4, 10, 5], + seq_lens=[6, 17, 9], + ), + attn_chunk_size=4, + block_size=2, + expected_q_seqlens=[2, 2, 1, 4, 4, 1, 4, 1], + expected_k_seqlens=[4, 2, 4, 4, 4, 1, 4, 1], + # 2 pages per local batch + # (chunk size 4 // block size 2) + expected_local_block_table=[ + [0, 1], # local-batch 0, (batch 0, starting from k[0]) + [2, 3], # local-batch 1, (batch 0, starting from k[4]) + [11, 12], # local-batch 2, (batch 1, starting from k[4]) + [13, 14], # local-batch 3, (batch 1, starting from k[8]) + [15, 16], # local-batch 4, (batch 1, starting from k[12]) + [17, 17], # local-batch 5, (batch 1, starting from k[16]) + [20, 21], # local-batch 6, (batch 2, starting from k[4]) + [22, 23], # local-batch 7, (batch 2, starting from k[8]) + ]), + # Case where block indices are not clipped to block table ncols-1 + # because tokens_in_last_block == attn_chunk_size + LocalAttentionTestData(batch_spec=BatchSpec( + query_lens=[8], + seq_lens=[12], + ), + attn_chunk_size=4, + block_size=2, + expected_q_seqlens=[4, 4], + expected_k_seqlens=[4, 4], + expected_local_block_table=[ + [2, 3], + [4, 5], + ]), + # Case where all kv_seq positions are involved in attn + LocalAttentionTestData( + batch_spec=BatchSpec( + query_lens=[7], + # 10 - 7 = 3 previously computed tokens + 
seq_lens=[10], + ), + attn_chunk_size=4, + block_size=2, + expected_q_seqlens=[1, 4, 2], + expected_k_seqlens=[4, 4, 2], + expected_local_block_table=[ + [0, 1], + [2, 3], + [4, 4], + ]), + # Case where attn_chunk_size > kv_seq_len + # so no extra mini virtual batches are created + LocalAttentionTestData( + batch_spec=BatchSpec( + query_lens=[4], + seq_lens=[6], + ), + # Larger than kv_seq_len + attn_chunk_size=10, + block_size=2, + # No change to q_seqlens and k_seqlens + expected_q_seqlens=[4], + expected_k_seqlens=[6], + # In this case, we only need a block-table like: + # block_table = [ [0, 1, 2] ] # 1 batch, 3 pages + # But we need to pad it to 5 pages per local batch + # because currently the pages_per_local_batch + # is calculated as (attn_chunk_size // block_size) + expected_local_block_table=[ + [0, 1, 2, 2, 2], + ]), + # Block size equal to chunk size + # Expect single page per batch in local batch table + LocalAttentionTestData( + batch_spec=BatchSpec( + query_lens=[6, 6], + seq_lens=[8, 8], + ), + attn_chunk_size=4, + block_size=4, + expected_q_seqlens=[2, 4, 2, 4], + expected_k_seqlens=[4, 4, 4, 4], + # Initial block table = [ + # [0, 1], < batch 0 + # [2, 3], < batch 1 + # ] + expected_local_block_table=[ + [0], # local-batch 0, (batch 0, starting from k[0]) + [1], # local-batch 1, (batch 0, starting from k[4]) + [2], # local-batch 2, (batch 1, starting from k[0]) + [3], # local-batch 3, (batch 1, starting from k[4]) + ]), + # Case where query falls in the second attention chunk + # k_toks > 0 1 2 3 4 + # q_toks v _____________ + # 0 | 1 + # 1 | 1 1 + # 2 | 1 1 1 + # 3 | 1 1 1 1 + # 4 | 1 + # where tokens 0,1,2,3 have been pre-computed + LocalAttentionTestData(batch_spec=BatchSpec( + query_lens=[1], + seq_lens=[5], + ), + attn_chunk_size=4, + block_size=2, + expected_q_seqlens=[1], + expected_k_seqlens=[1], + expected_local_block_table=[ + [2, 2], + ]), +] + + +@pytest.mark.parametrize("test_data", test_data_list) +def 
test_local_attention_virtual_batches(test_data: LocalAttentionTestData): + device = torch.device("cuda:0") + batch_spec = test_data.batch_spec + attn_chunk_size = test_data.attn_chunk_size + block_size = test_data.block_size + expected_q_seqlens = test_data.expected_q_seqlens + expected_k_seqlens = test_data.expected_k_seqlens + expected_local_block_table = test_data.expected_local_block_table + + # Create common attention metadata + common_attn_metadata = create_common_attn_metadata( + batch_spec, + block_size, + device, + # Use torch.arange instead of torch.randint so we can assert on + # block table tensor values. The block table will have shape + # (num_batches, cdiv(max_seq_len, block_size)) and the values will be + # aranged from 0 to cdiv(max_seq_len, block_size)-1 + arange_block_indices=True, + ) + + # Call the function + result = make_local_attention_virtual_batches(attn_chunk_size, + common_attn_metadata, + block_size) + + # Convert to numpy for easier comparison + actual_q_seqlens = np.diff(result.query_start_loc_cpu.numpy()) + actual_k_seqlens = result.seq_lens_cpu.numpy() + + # Check that all query lengths are less than or equal to attn_chunk_size + assert all(q_len <= attn_chunk_size for q_len in actual_q_seqlens) + # Check that all key lengths are less than or equal to attn_chunk_size + assert all(k_len <= attn_chunk_size for k_len in actual_k_seqlens) + # Check that the total number of query tokens is preserved + assert sum(actual_q_seqlens) == sum(batch_spec.query_lens) + + # Verify results + np.testing.assert_array_equal(actual_q_seqlens, expected_q_seqlens) + np.testing.assert_array_equal(actual_k_seqlens, expected_k_seqlens) + + expected_block_table_tensor =\ + torch.tensor(expected_local_block_table, + dtype=torch.int32, + device=device) + + print(f"Expected block table:\n{expected_block_table_tensor}") + print(f"Actual block table:\n{result.block_table_tensor}") + + torch.testing.assert_close(result.block_table_tensor, + 
expected_block_table_tensor) diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py index ae2ab6e6413c0..be6cfce6fba8a 100644 --- a/tests/v1/attention/utils.py +++ b/tests/v1/attention/utils.py @@ -40,7 +40,8 @@ def create_common_attn_metadata( batch_spec: BatchSpec, block_size: int, device: torch.device, - max_block_idx: int = 1000) -> CommonAttentionMetadata: + max_block_idx: int = 1000, + arange_block_indices: bool = False) -> CommonAttentionMetadata: """Create CommonAttentionMetadata from a BatchSpec and ModelParams.""" # Create query start locations query_start_loc = torch.zeros(batch_spec.batch_size + 1, @@ -65,19 +66,28 @@ def create_common_attn_metadata( ] num_computed_tokens_cpu = torch.tensor(context_lens, dtype=torch.int32) - # Create block table (random for testing) + # Create block table and slot mapping max_blocks = (max(batch_spec.seq_lens) + block_size - 1) // block_size - block_table_tensor = torch.randint(0, - max_block_idx, - (batch_spec.batch_size, max_blocks), - dtype=torch.int32, - device=device) - - # Create slot mapping - slot_mapping = torch.randint(0, - max_block_idx, (num_tokens, ), - dtype=torch.int64, - device=device) + if arange_block_indices: + num_blocks = batch_spec.batch_size * max_blocks + block_table_tensor = torch.arange(num_blocks, + dtype=torch.int32, + device=device).view( + batch_spec.batch_size, + max_blocks) + slot_mapping = torch.arange(num_tokens, + dtype=torch.int64, + device=device).view(num_tokens) + else: + block_table_tensor = torch.randint(0, + max_block_idx, + (batch_spec.batch_size, max_blocks), + dtype=torch.int32, + device=device) + slot_mapping = torch.randint(0, + max_block_idx, (num_tokens, ), + dtype=torch.int64, + device=device) # Calculate max query length max_query_len = max(batch_spec.query_lens) From 2dff2e21d928129e985b23897e9f326abe3f1417 Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Thu, 31 Jul 2025 16:33:53 -0400 Subject: [PATCH 8/9] [Bugfix] Fix MTP weight loading (#21941) 
--- vllm/model_executor/models/deepseek_mtp.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py index 911f0036c2dd6..2e026d582a6de 100644 --- a/vllm/model_executor/models/deepseek_mtp.py +++ b/vllm/model_executor/models/deepseek_mtp.py @@ -182,6 +182,8 @@ class DeepSeekMTP(nn.Module, SupportsPP): stacked_params_mapping = [ ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), + ("fused_qkv_a_proj", "q_a_proj", 0), + ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1), ] expert_params_mapping = FusedMoE.make_expert_params_mapping( @@ -212,6 +214,13 @@ class DeepSeekMTP(nn.Module, SupportsPP): if (("mlp.experts." in name) and name not in params_dict): continue name = name.replace(weight_name, param_name) + + # QKV fusion is optional, fall back to normal + # weight loading if it's not enabled + if ((param_name == "fused_qkv_a_proj") + and name not in params_dict): + continue + # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in params_dict: continue From 6e672daf62e7b03ff1dcf74e4206dad07d39d4ec Mon Sep 17 00:00:00 2001 From: Ilya Markov Date: Thu, 31 Jul 2025 22:58:38 +0200 Subject: [PATCH 9/9] Add FlashInfer allreduce RMSNorm Quant fusion (#21069) Signed-off-by: ilmarkov Signed-off-by: ilmarkov Co-authored-by: ilmarkov --- .buildkite/test-pipeline.yaml | 1 + tests/compile/test_fusion_all_reduce.py | 126 +++++- tests/utils.py | 12 + vllm/compilation/collective_fusion.py | 533 ++++++++++++++++++++++-- vllm/config.py | 2 +- 5 files changed, 606 insertions(+), 68 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index a7fe200559305..2f6cc45be77e6 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -353,6 +353,7 @@ steps: - pytest -v -s compile/test_silu_mul_quant_fusion.py - pytest -v -s compile/test_sequence_parallelism.py - pytest -v -s compile/test_async_tp.py + - pytest -v -s compile/test_fusion_all_reduce.py - label: PyTorch Fullgraph Smoke Test # 9min mirror_hardwares: [amdexperimental] diff --git a/tests/compile/test_fusion_all_reduce.py b/tests/compile/test_fusion_all_reduce.py index b8d64247f6beb..b394e0035c689 100644 --- a/tests/compile/test_fusion_all_reduce.py +++ b/tests/compile/test_fusion_all_reduce.py @@ -7,22 +7,26 @@ import torch import vllm.envs as envs from vllm.compilation.collective_fusion import AllReduceFusionPass +from vllm.compilation.fix_functionalization import FixFunctionalizationPass +from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.config import (CompilationConfig, CompilationLevel, DeviceConfig, ModelConfig, PassConfig, VllmConfig) from vllm.distributed import tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import (init_distributed_environment, initialize_model_parallel) from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + 
GroupShape, QuantFP8) from vllm.platforms import current_platform from vllm.utils import update_environment_variables -from ..utils import multi_gpu_test +from ..utils import has_module_attribute, multi_gpu_test from .backend import TestBackend class TestAllReduceRMSNormModel(torch.nn.Module): - def __init__(self, hidden_size=16, eps=1e-6): + def __init__(self, hidden_size=16, token_num=16, eps=1e-6): super().__init__() self.hidden_size = hidden_size self.eps = eps @@ -43,7 +47,7 @@ class TestAllReduceRMSNormModel(torch.nn.Module): class TestAllReduceFusedAddRMSNormModel(torch.nn.Module): - def __init__(self, hidden_size=16, eps=1e-6): + def __init__(self, hidden_size=16, token_num=16, eps=1e-6): super().__init__() self.hidden_size = hidden_size self.eps = eps @@ -62,24 +66,101 @@ class TestAllReduceFusedAddRMSNormModel(torch.nn.Module): return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default] +class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module): + + def __init__(self, hidden_size=16, token_num=16, eps=1e-6): + super().__init__() + self.hidden_size = hidden_size + self.eps = eps + self.norm = RMSNorm(hidden_size, eps) + self.quant_fp8 = QuantFP8(static=True, + group_shape=GroupShape.PER_TENSOR) + self.scale = torch.rand(1, dtype=torch.float32) + self.output = torch.empty((token_num, hidden_size), + dtype=torch.float32) + + def forward(self, hidden_states, residual): + view = hidden_states.reshape(-1, self.hidden_size) + all_reduce = tensor_model_parallel_all_reduce(view) + norm_output, residual_output = self.norm(all_reduce, residual) + torch.ops._C.static_scaled_fp8_quant(self.output, + norm_output.contiguous(), + self.scale) + return self.output, residual_output + + def ops_in_model_after(self): + return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default] + + def ops_in_model_before(self): + return [ + torch.ops.vllm.all_reduce.default, + torch.ops._C.static_scaled_fp8_quant.default + ] + + +class 
TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module): + + def __init__(self, hidden_size=16, token_num=16, eps=1e-6): + super().__init__() + self.hidden_size = hidden_size + self.eps = eps + self.norm = RMSNorm(hidden_size, eps) + self.scale = torch.rand(1, dtype=torch.float32) + self.output = torch.empty((token_num, hidden_size), + dtype=torch.float32) + + round_up = lambda x, y: (x + y - 1) // y * y + rounded_m = round_up(token_num, 128) + scale_n = hidden_size // 16 + rounded_n = round_up(scale_n, 4) + self.output_scale = torch.empty((rounded_m, rounded_n // 4), + dtype=torch.int32) + + def forward(self, hidden_states, residual): + view = hidden_states.reshape(-1, self.hidden_size) + all_reduce = tensor_model_parallel_all_reduce(view) + norm_output, residual_output = self.norm(all_reduce, residual) + norm_output = norm_output.reshape(-1, norm_output.shape[-1]) + torch.ops._C.scaled_fp4_quant(self.output, norm_output, + self.output_scale, self.scale) + return self.output, residual_output, self.output_scale + + def ops_in_model_after(self): + return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default] + + def ops_in_model_before(self): + return [ + torch.ops.vllm.all_reduce.default, + torch.ops._C.scaled_fp4_quant.default + ] + + @multi_gpu_test(num_gpus=2) -@pytest.mark.parametrize( - "test_model", - [TestAllReduceRMSNormModel, TestAllReduceFusedAddRMSNormModel]) +@pytest.mark.parametrize("test_model", [ + TestAllReduceRMSNormModel, + TestAllReduceFusedAddRMSNormModel, + TestAllReduceFusedAddRMSNormStaticQuantFP8Model, + TestAllReduceFusedAddRMSNormStaticQuantFP4Model, +]) @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("seq_len", [8]) -@pytest.mark.parametrize("hidden_size", [4096]) +@pytest.mark.parametrize("hidden_size", [16]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") -@pytest.mark.skipif(not 
find_spec("flashinfer"), - reason="flashinfer is not installed") -@pytest.mark.skipif(not current_platform.is_device_capability(100), - reason="Only test on SM100") +@pytest.mark.skipif( + not find_spec("flashinfer") + or not has_module_attribute("flashinfer.comm", "trtllm_allreduce_fusion"), + reason="flashinfer is not found or flashinfer " + "is not compiled with trtllm_allreduce_fusion") def test_all_reduce_fusion_pass_replace(test_model: torch.nn.Module, batch_size: int, seq_len: int, hidden_size: int, dtype: torch.dtype): num_processes = 2 + if (test_model == TestAllReduceFusedAddRMSNormStaticQuantFP4Model + and not current_platform.has_device_capability(100)): + pytest.skip("Skip as nvfp4 is only supported on " + "devices with compute capability 10.0 (Blackwell)") def run_torch_spawn(fn, nprocs): torch.multiprocessing.spawn(fn, @@ -113,12 +194,11 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int, init_distributed_environment() initialize_model_parallel(tensor_model_parallel_size=world_size) - vllm_config = VllmConfig( - compilation_config=CompilationConfig(level=CompilationLevel.PIECEWISE, - custom_ops=["+rms_norm"], - compile_sizes=[2, 4, 8])) + vllm_config = VllmConfig(compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + custom_ops=["+rms_norm", "+quant_fp8"])) vllm_config.compilation_config.pass_config = PassConfig( - enable_fi_allreduce_fusion=True) + enable_fi_allreduce_fusion=True, enable_noop=True) vllm_config.device_config = DeviceConfig(device=torch.device("cuda")) # this is a fake model name to construct the model config @@ -130,14 +210,16 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int, seed=42) all_reduce_fusion_pass = AllReduceFusionPass(vllm_config) - backend = TestBackend(all_reduce_fusion_pass) + noop_pass = NoOpEliminationPass(vllm_config) + func_pass = FixFunctionalizationPass(vllm_config) - model = test_model_cls(hidden_size) + backend = 
TestBackend(all_reduce_fusion_pass, noop_pass, func_pass) - hidden_states = torch.randn((batch_size * seq_len, hidden_size), - requires_grad=False) - residual = torch.randn((batch_size * seq_len, hidden_size), - requires_grad=False) + token_num = batch_size * seq_len + model = test_model_cls(hidden_size, token_num) + + hidden_states = torch.randn((token_num, hidden_size), requires_grad=False) + residual = torch.randn((token_num, hidden_size), requires_grad=False) compiled_model = torch.compile(model, backend=backend) compiled_model(hidden_states, residual) diff --git a/tests/utils.py b/tests/utils.py index f4317e6bdb406..1c1a1cc6014ec 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -4,6 +4,7 @@ import asyncio import copy import functools +import importlib import os import signal import subprocess @@ -974,3 +975,14 @@ def get_client_text_logprob_generations( return [(text_generations, text, (None if x.logprobs is None else x.logprobs.top_logprobs)) for completion in completions for x in completion.choices] + + +def has_module_attribute(module_name, attribute_name): + """ + Helper function to check if a module has a specific attribute. 
+ """ + try: + module = importlib.import_module(module_name) + return hasattr(module, attribute_name) + except ImportError: + return False diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index cb99fe8310e73..6ae50245ed3a8 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -37,6 +37,8 @@ logger = init_logger(__name__) ALLREDUCE_OP = torch.ops.vllm.all_reduce.default RMS_OP = torch.ops._C.rms_norm.default RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default +STATIC_FP8_QUANT_OP = torch.ops._C.static_scaled_fp8_quant.default +STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default class BasePattern: @@ -394,7 +396,7 @@ if flashinfer_comm is not None: # Max size of the input tensor per world size # to use flashinfer fused allreduce _FI_MAX_SIZES = { - 2: MiB, # 1MB + 2: 64 * MiB, # 64MB 4: MiB, # 1MB 6: MiB // 2, # 512KB 8: MiB // 2, # 512KB @@ -414,9 +416,13 @@ if flashinfer_comm is not None: trigger_completion_at_end: bool, fp32_acc: bool, max_token_num: int, + pattern_code: int, + fuse_rms_quant: bool, norm_out: Optional[torch.Tensor] = None, + quant_out: Optional[torch.Tensor] = None, + scale_out: Optional[torch.Tensor] = None, + scale_factor: Optional[torch.Tensor] = None, ) -> None: - num_tokens, hidden_size = allreduce_in.shape element_size = allreduce_in.element_size() current_tensor_size = num_tokens * hidden_size * element_size @@ -425,7 +431,6 @@ if flashinfer_comm is not None: _FI_MAX_SIZES.get(world_size, _DEFAULT_FI_MAX_SIZE), max_fusion_size, ) - if use_flashinfer: assert (_FI_WORKSPACE_TENSOR is not None ), "Flashinfer must be enabled when using flashinfer" @@ -455,37 +460,65 @@ if flashinfer_comm is not None: use_oneshot=True, trigger_completion_at_end=trigger_completion_at_end, fp32_acc=fp32_acc, - pattern_code=flashinfer_comm.AllReduceFusionPattern. 
- kARResidualRMSNorm, + pattern_code=pattern_code, allreduce_out=None, - quant_out=None, - scale_out=None, - layout_code=None, - scale_factor=None, + quant_out=quant_out, + scale_out=scale_out, + # in vllm we only support swizzled layout + layout_code=flashinfer_comm.FP4QuantizationSFLayout.SWIZZLED, + scale_factor=scale_factor, ) else: allreduce_out = tensor_model_parallel_all_reduce(allreduce_in) - if norm_out is None: - torch.ops._C.fused_add_rms_norm(allreduce_out, residual, - rms_gamma, rms_eps) + if (scale_factor is not None and scale_out is None + and fuse_rms_quant): + # Do fused rms norm static fp8 quant fused op + if norm_out is None: + torch.ops._C.fused_add_rms_norm_static_fp8_quant( + quant_out, allreduce_out, residual, rms_gamma, + scale_factor, rms_eps) + else: + torch.ops._C.rms_norm_static_fp8_quant( + quant_out, allreduce_out, rms_gamma, scale_factor, + rms_eps) else: - torch.ops._C.rms_norm(norm_out, allreduce_out, rms_gamma, - rms_eps) - allreduce_in.copy_(allreduce_out) + if norm_out is None: + torch.ops._C.fused_add_rms_norm(allreduce_out, residual, + rms_gamma, rms_eps) + norm_out = allreduce_out + else: + torch.ops._C.rms_norm(norm_out, allreduce_out, rms_gamma, + rms_eps) + if scale_factor is not None: + if scale_out is not None: + torch.ops._C.scaled_fp4_quant(quant_out, norm_out, + scale_out, scale_factor) + else: + torch.ops._C.static_scaled_fp8_quant( + quant_out, norm_out, scale_factor) + if scale_factor is None or norm_out is not None: + # we need to return allreduce output + # in cases of non quant fused AR + RMS norm + # and fused AR + RMS norm + quant without fused add + allreduce_in.copy_(allreduce_out) def call_trtllm_fused_allreduce_norm_fake( - allreduce_in: torch.Tensor, - residual: torch.Tensor, - rms_gamma: torch.Tensor, - rms_eps: float, - world_rank: int, - world_size: int, - launch_with_pdl: bool, - trigger_completion_at_end: bool, - fp32_acc: bool, - max_token_num: int, - norm_out: Optional[torch.Tensor] = None, - ) -> 
None: + allreduce_in: torch.Tensor, + residual: torch.Tensor, + rms_gamma: torch.Tensor, + rms_eps: float, + world_rank: int, + world_size: int, + launch_with_pdl: bool, + trigger_completion_at_end: bool, + fp32_acc: bool, + max_token_num: int, + pattern_code: int, + fuse_rms_quant: bool, + norm_out: Optional[torch.Tensor] = None, + quant_out: Optional[torch.Tensor] = None, + scale_out: Optional[torch.Tensor] = None, + scale_factor: Optional[torch.Tensor] = None) -> None: pass direct_register_custom_op( @@ -495,6 +528,8 @@ if flashinfer_comm is not None: "allreduce_in", "residual", "norm_out", + "quant_out", + "scale_out", ], fake_impl=call_trtllm_fused_allreduce_norm_fake, dispatch_key=current_platform.dispatch_key, @@ -512,6 +547,7 @@ class FlashInferFusedAllReduceParams: world_size: int, use_fp32_lamport: bool = False, max_token_num: int = 1024, + fuse_rms_quant: bool = False, ): self.rank = rank self.world_size = world_size @@ -521,6 +557,7 @@ class FlashInferFusedAllReduceParams: self.fp32_acc = True self.use_oneshot = False self.max_token_num = max_token_num + self.fuse_rms_quant = fuse_rms_quant def get_trtllm_fused_allreduce_kwargs(self): return { @@ -530,10 +567,16 @@ class FlashInferFusedAllReduceParams: "trigger_completion_at_end": self.trigger_completion_at_end, "fp32_acc": self.fp32_acc, "max_token_num": self.max_token_num, + "fuse_rms_quant": self.fuse_rms_quant, } -class AllReduceRMSNORMPattern(BasePattern): +class AllReduceRMSNormPattern(BasePattern): + """ + This pattern replaces the allreduce + rms norm (without residual) + with fused flashinfer implementation. + Applies to allreduce + rmsnorm before attn in the first Transformer block. 
+ """ def __init__( self, @@ -559,29 +602,34 @@ class AllReduceRMSNORMPattern(BasePattern): def pattern(input: torch.Tensor, rms_result: torch.Tensor, weight: torch.Tensor): - all_reduce_output = tensor_model_parallel_all_reduce(input) + allreduce_output = tensor_model_parallel_all_reduce(input) rms = auto_functionalized( RMS_OP, result=rms_result, - input=all_reduce_output, + input=allreduce_output, weight=weight, epsilon=self.epsilon, ) - return rms[1], all_reduce_output + # rms_result, allreduce_output + return rms[1], allreduce_output def replacement(input: torch.Tensor, rms_result: torch.Tensor, weight: torch.Tensor): residual = torch.zeros_like(input) allreduce = auto_functionalized( - torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default, + flashinfer_trtllm_fused_allreduce_norm, allreduce_in=input, residual=residual, norm_out=rms_result, + quant_out=None, + scale_out=None, rms_gamma=weight, rms_eps=self.epsilon, + pattern_code=flashinfer_comm.AllReduceFusionPattern. + kARResidualRMSNorm, **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(), ) - + # rms_result, allreduce_in return allreduce[3], allreduce[1] pm.register_replacement(pattern, replacement, self.get_inputs(), @@ -589,6 +637,11 @@ class AllReduceRMSNORMPattern(BasePattern): class AllReduceFusedAddRMSNormPattern(BasePattern): + """ + This pattern replaces the allreduce + rms norm (with residual) + with fused flashinfer implementation. + Applies to o_proj + rmsnorm after attn and mlp + rmsnorm before attn. 
+ """ def __init__( self, @@ -615,33 +668,390 @@ class AllReduceFusedAddRMSNormPattern(BasePattern): def pattern(residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor): - all_reduce_output = tensor_model_parallel_all_reduce(input) + allreduce_output = tensor_model_parallel_all_reduce(input) rms = auto_functionalized( RMS_ADD_OP, - input=all_reduce_output, + input=allreduce_output, residual=residual, weight=weight, epsilon=self.epsilon, ) + # input, residual return rms[1], rms[2] def replacement(residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor): allreduce = auto_functionalized( - torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default, + flashinfer_trtllm_fused_allreduce_norm, allreduce_in=input, residual=residual, + norm_out=None, + quant_out=None, + scale_out=None, rms_gamma=weight, rms_eps=self.epsilon, - norm_out=None, + pattern_code=flashinfer_comm.AllReduceFusionPattern. + kARResidualRMSNorm, **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(), ) + # allreduce_in, residual return allreduce[1], allreduce[2] pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass) +class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern): + """ + This pattern replaces the allreduce + rms norm (without residual) + + static fp8 quant with fused flashinfer implementation. + Applies to allreduce + rmsnorm + quant before attn + in the first Transformer block. 
+ """ + + def __init__(self, epsilon: float, dtype: torch.dtype, device: str, + allreduce_params: FlashInferFusedAllReduceParams): + super().__init__(dtype, device) + self.epsilon = epsilon + self.allreduce_params = allreduce_params + self.quant_dtype = torch.float8_e4m3fn + + def register(self, pm_pass: PatternMatcherPass): + + def get_inputs(): + input = torch.zeros([1, 8, 4], + device=self.device, + dtype=self.dtype) + rmsnorm_result = torch.empty([1, 8, 4], + device=self.device, + dtype=self.dtype) + quant_result = torch.empty([1, 8, 4], + device=self.device, + dtype=self.quant_dtype) + weight = torch.empty([4], device=self.device, dtype=self.dtype) + scale = torch.tensor(1.0, device=self.device, dtype=torch.float32) + return [input, rmsnorm_result, quant_result, weight, scale] + + def pattern( + input: torch.Tensor, + rmsnorm_result: torch.Tensor, + quant_result: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + ): + all_reduce = tensor_model_parallel_all_reduce(input) + rmsnorm_out_tuple = auto_functionalized(RMS_OP, + result=rmsnorm_result, + input=all_reduce, + weight=weight, + epsilon=self.epsilon) + + quant_out_tuple = auto_functionalized(STATIC_FP8_QUANT_OP, + result=quant_result, + input=rmsnorm_out_tuple[1], + scale=scale) + + # quant_out, allreduce_output + return quant_out_tuple[1], all_reduce + + def replacement(input: torch.Tensor, result_rms: torch.Tensor, + quant_result: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + residual = torch.zeros_like(input) + allreduce = auto_functionalized( + flashinfer_trtllm_fused_allreduce_norm, + allreduce_in=input, + residual=residual, + norm_out=result_rms, + quant_out=quant_result, + scale_out=None, + rms_gamma=weight, + rms_eps=self.epsilon, + pattern_code=flashinfer_comm.AllReduceFusionPattern. 
+ kARResidualRMSNormFP8Quant, # we don't use norm_out afterwards + scale_factor=scale, + **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(), + ) + + # quant_out, allreduce_output + return allreduce[4], allreduce[1] + + pm.register_replacement(pattern, replacement, get_inputs(), + pm.fwd_only, pm_pass) + + +class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern): + """ + This pattern replaces the allreduce + rms norm (with residual) + + static fp8 quant with fused flashinfer implementation. + Applies to o_proj + rmsnorm after attn + quant and + mlp + rmsnorm + quant before attn. + """ + + def __init__(self, epsilon: float, dtype: torch.dtype, device: str, + allreduce_params: FlashInferFusedAllReduceParams): + super().__init__(dtype, device) + self.epsilon = epsilon + self.allreduce_params = allreduce_params + self.quant_dtype = torch.float8_e4m3fn + + def register(self, pm_pass: PatternMatcherPass): + + def get_inputs(): + input = torch.empty([4, 4], device=self.device, dtype=self.dtype) + + residual = torch.empty([4, 4], + device=self.device, + dtype=self.dtype) + weight = torch.empty([4, 4], device=self.device, dtype=self.dtype) + quant_result = torch.empty([4, 4], + device=self.device, + dtype=self.quant_dtype) + scale = torch.empty([1, 1], + device=self.device, + dtype=torch.float32) + + return [ + quant_result, + residual, + input, + weight, + scale, + ] + + def pattern( + quant_result: torch.Tensor, + residual: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + ): + allreduce_output = tensor_model_parallel_all_reduce(input) + + fused_add_rmsnorm_out_tuple = \ + auto_functionalized( + RMS_ADD_OP, + input=allreduce_output, + residual=residual, + weight=weight, + epsilon=self.epsilon) + quant_out_tuple = auto_functionalized( + STATIC_FP8_QUANT_OP, + result=quant_result, + input=fused_add_rmsnorm_out_tuple[1], + scale=scale) + + # quant_out, allreduce_output + return quant_out_tuple[1], 
fused_add_rmsnorm_out_tuple[2] + + def replacement(quant_result: torch.Tensor, residual: torch.Tensor, + input: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + allreduce = auto_functionalized( + flashinfer_trtllm_fused_allreduce_norm, + allreduce_in=input, + residual=residual, + norm_out=None, + quant_out=quant_result, + scale_out=None, + rms_gamma=weight, + rms_eps=self.epsilon, + pattern_code=flashinfer_comm.AllReduceFusionPattern. + kARResidualRMSNormFP8Quant, # we don't use norm_out afterwards + scale_factor=scale, + **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(), + ) + # # quant_out, rms_norm_residual + return allreduce[4], allreduce[2] + + pm.register_replacement(pattern, replacement, get_inputs(), + pm.fwd_only, pm_pass) + + +class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern): + """ + This pattern replaces the allreduce + rms norm (without residual) + + static nvfp4 quant with fused flashinfer implementation. + Applies to allreduce + rmsnorm + quant before attn + in the first Transformer block. 
+ """ + + def __init__(self, epsilon: float, dtype: torch.dtype, device: str, + allreduce_params: FlashInferFusedAllReduceParams): + super().__init__(dtype, device) + self.epsilon = epsilon + self.allreduce_params = allreduce_params + + def register(self, pm_pass: PatternMatcherPass): + + def get_inputs(): + input = torch.empty([1, 16, 16], + device=self.device, + dtype=self.dtype) + + rmsnorm_result = torch.empty([1, 16, 16], + device=self.device, + dtype=self.dtype) + quant_result = torch.empty((16, 8), + device=self.device, + dtype=torch.uint8) + input_global_scale = torch.empty([1, 1], + device=self.device, + dtype=torch.float32) + weight = torch.empty([16], device=self.device, dtype=self.dtype) + output_scale = torch.empty([128, 4], + device=self.device, + dtype=torch.int32) + + return [ + input, rmsnorm_result, quant_result, weight, + input_global_scale, output_scale + ] + + def pattern( + input: torch.Tensor, + rmsnorm_result: torch.Tensor, + quant_result: torch.Tensor, + weight: torch.Tensor, + input_global_scale: torch.Tensor, + output_scale: torch.Tensor, + ): + all_reduce = tensor_model_parallel_all_reduce(input) + rmsnorm_out_tuple = auto_functionalized(RMS_OP, + result=rmsnorm_result, + input=all_reduce, + weight=weight, + epsilon=self.epsilon) + + quant_out_tuple = auto_functionalized( + STATIC_FP4_QUANT_OP, + output=quant_result, + input=rmsnorm_out_tuple[1], + output_scale=output_scale, + input_scale=input_global_scale) + + # quant_out, allreduce_output, output_scale + return quant_out_tuple[1], all_reduce, quant_out_tuple[2] + + def replacement(input: torch.Tensor, result_rms: torch.Tensor, + quant_result: torch.Tensor, weight: torch.Tensor, + input_global_scale: torch.Tensor, + output_scale: torch.Tensor): + residual = torch.zeros_like(input) + allreduce = auto_functionalized( + flashinfer_trtllm_fused_allreduce_norm, + allreduce_in=input, + residual=residual, + norm_out=result_rms, + quant_out=quant_result, + scale_out=output_scale, + 
rms_gamma=weight, + rms_eps=self.epsilon, + pattern_code=flashinfer_comm.AllReduceFusionPattern. + kARResidualRMSNormFP4Quant, # we don't use norm_out afterwards + scale_factor=input_global_scale, + **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(), + ) + + # quant_out, allreduce_output, output_scale + return allreduce[4], allreduce[1], allreduce[5] + + pm.register_replacement(pattern, replacement, get_inputs(), + pm.fwd_only, pm_pass) + + +class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern): + """ + This pattern replaces the allreduce + rms norm (with residual) + + static nvfp4 quant with fused flashinfer implementation. + Applies to o_proj + rmsnorm after attn + quant and + mlp + rmsnorm + quant before attn. + """ + + def __init__(self, epsilon: float, dtype: torch.dtype, device: str, + allreduce_params: FlashInferFusedAllReduceParams): + super().__init__(dtype, device) + self.epsilon = epsilon + self.allreduce_params = allreduce_params + + def register(self, pm_pass: PatternMatcherPass): + + def get_inputs(): + input = torch.empty([16, 16], device=self.device, dtype=self.dtype) + + residual = torch.empty([16, 16], + device=self.device, + dtype=self.dtype) + weight = torch.empty([16, 16], + device=self.device, + dtype=self.dtype) + quant_result = torch.empty((16, 8), + device=self.device, + dtype=torch.uint8) + input_global_scale = torch.empty([1, 1], + device=self.device, + dtype=torch.float32) + output_scale = torch.empty([128, 4], + device=self.device, + dtype=torch.int32) + + return [ + quant_result, + residual, + input, + output_scale, + weight, + input_global_scale, + ] + + def pattern(quant_result: torch.Tensor, residual: torch.Tensor, + input: torch.Tensor, output_scale: torch.Tensor, + weight: torch.Tensor, input_global_scale: torch.Tensor): + allreduce_output = tensor_model_parallel_all_reduce(input) + + fused_add_rmsnorm_out_tuple = \ + auto_functionalized( + RMS_ADD_OP, + input=allreduce_output, + residual=residual, + 
weight=weight, + epsilon=self.epsilon) + quant_out_tuple = auto_functionalized( + STATIC_FP4_QUANT_OP, + output=quant_result, + input=fused_add_rmsnorm_out_tuple[1], + output_scale=output_scale, + input_scale=input_global_scale) + + # quant_out, allreduce_output, output_scale + return quant_out_tuple[1], fused_add_rmsnorm_out_tuple[ + 2], quant_out_tuple[2] + + def replacement(quant_result: torch.Tensor, residual: torch.Tensor, + input: torch.Tensor, output_scale: torch.Tensor, + weight: torch.Tensor, + input_global_scale: torch.Tensor): + allreduce = auto_functionalized( + flashinfer_trtllm_fused_allreduce_norm, + allreduce_in=input, + residual=residual, + norm_out=None, + quant_out=quant_result, + scale_out=output_scale, + rms_gamma=weight, + rms_eps=self.epsilon, + pattern_code=flashinfer_comm.AllReduceFusionPattern. + kARResidualRMSNormFP4Quant, # we don't use norm_out afterwards + scale_factor=input_global_scale, + **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(), + ) + # quant_out, rms_norm_residual, output_scale + return allreduce[4], allreduce[2], allreduce[5] + + pm.register_replacement(pattern, replacement, get_inputs(), + pm.fwd_only, pm_pass) + + class AllReduceFusionPass(VllmInductorPass): def __init__(self, config: VllmConfig): @@ -671,13 +1081,16 @@ class AllReduceFusionPass(VllmInductorPass): self.tp_size, ) return - + max_num_token = min( + _FI_MAX_SIZES.get(self.tp_size, _DEFAULT_FI_MAX_SIZE) // + (self.hidden_dim * self.tp_size * (4 if use_fp32_lamport else 2)), + config.compilation_config.pass_config. + fi_allreduce_fusion_max_token_num) self.ipc_handles, workspace_tensor = ( flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( tp_rank=rank, tp_size=self.tp_size, - max_token_num=config.compilation_config.pass_config. 
- fi_allreduce_fusion_max_token_num, + max_token_num=max_num_token, hidden_dim=self.hidden_dim, group=self.group, use_fp32_lamport=use_fp32_lamport, @@ -689,12 +1102,38 @@ class AllReduceFusionPass(VllmInductorPass): rank=rank, world_size=self.tp_size, use_fp32_lamport=use_fp32_lamport, - max_token_num=config.compilation_config.pass_config. - fi_allreduce_fusion_max_token_num, - ) + max_token_num=max_num_token, + # fuse rms norm static fp8 quant fused op + # in fallback path, when we don't use flashinfer + fuse_rms_quant=config.compilation_config.pass_config.enable_fusion) for epsilon in [1e-5, 1e-6]: - AllReduceRMSNORMPattern( + AllReduceFusedRMSNormStaticQuantFP8Pattern( + epsilon, + self.model_dtype, + self.device, + self.allreduce_params, + ).register(self.patterns) + AllReduceFusedAddRMSNormStaticQuantFP8Pattern( + epsilon, + self.model_dtype, + self.device, + self.allreduce_params, + ).register(self.patterns) + if current_platform.has_device_capability(100): + AllReduceFusedRMSNormStaticQuantNVFP4Pattern( + epsilon, + self.model_dtype, + self.device, + self.allreduce_params, + ).register(self.patterns) + AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern( + epsilon, + self.model_dtype, + self.device, + self.allreduce_params, + ).register(self.patterns) + AllReduceRMSNormPattern( epsilon, self.model_dtype, self.device, @@ -707,6 +1146,10 @@ class AllReduceFusionPass(VllmInductorPass): self.allreduce_params, ).register(self.patterns) + # WARNING: This is a hack to clear the pattern matcher cache + # and allow multiple values of epsilon. 
+ torch._inductor.pattern_matcher._seen_patterns.clear() + self.disabled = False def __call__(self, graph: fx.Graph): @@ -723,5 +1166,5 @@ class AllReduceFusionPass(VllmInductorPass): if self.disabled: return if flashinfer_comm is not None: - flashinfer_comm.trtllm_destroy_ipc_workspace( + flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce( self.ipc_handles, self.group) diff --git a/vllm/config.py b/vllm/config.py index 27dde5f1b1f6f..edad5dd0406bf 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -4051,7 +4051,7 @@ class PassConfig: """Whether to enable async TP.""" enable_fi_allreduce_fusion: bool = False """Whether to enable flashinfer allreduce fusion.""" - fi_allreduce_fusion_max_token_num: int = 1024 + fi_allreduce_fusion_max_token_num: int = 16384 """Max number of tokens to used in flashinfer allreduce fusion.""" # TODO(luka) better pass enabling system.