diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 5a9823bb6bae7..f5d9e3b22f2a6 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -625,6 +625,7 @@ See [this page](generative_models.md) for more information on how to use generat | `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + IE+ + VE+ + A+ | `Qwen/Qwen2.5-Omni-7B` | | ✅︎ | ✅︎ | | `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | ✅︎ | | `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | | ✅︎ | +| `Step3VLForConditionalGeneration` | Step3-VL | T + I+ | `stepfun-ai/step3` | | ✅︎ | ✅︎ | | `TarsierForConditionalGeneration` | Tarsier | T + IE+ | `omni-search/Tarsier-7b`, `omni-search/Tarsier-34b` | | ✅︎ | ✅︎ | | `Tarsier2ForConditionalGeneration`^ | Tarsier2 | T + IE+ + VE+ | `omni-research/Tarsier2-Recap-7b`, `omni-research/Tarsier2-7b-0115` | | ✅︎ | ✅︎ | diff --git a/tests/models/registry.py b/tests/models/registry.py index 8fcff5a8c5113..b9e7de4e9fd11 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -279,6 +279,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"), # noqa: E501 "StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"), "Starcoder2ForCausalLM": _HfExamplesInfo("bigcode/starcoder2-3b"), + "Step3TextForCausalLM": _HfExamplesInfo("stepfun-ai/step3", + trust_remote_code=True, + is_available_online=False), "SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct", trust_remote_code=True), "TeleChat2ForCausalLM": _HfExamplesInfo("Tele-AI/TeleChat2-3B", @@ -457,6 +460,9 @@ _MULTIMODAL_EXAMPLE_MODELS = { "SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B", trust_remote_code=True), "SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct"), # noqa: E501 + "Step3VLForConditionalGeneration": _HfExamplesInfo("stepfun-ai/step3", + trust_remote_code=True, + is_available_online=False), "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501 trust_remote_code=True), "TarsierForConditionalGeneration": _HfExamplesInfo("omni-research/Tarsier-7b", # noqa: E501 diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py index 88c8aa929b78d..099e456aa486f 100644 --- a/vllm/entrypoints/openai/tool_parsers/__init__.py +++ b/vllm/entrypoints/openai/tool_parsers/__init__.py @@ -18,6 +18,7 @@ from .mistral_tool_parser import MistralToolParser from .phi4mini_tool_parser import Phi4MiniJsonToolParser from .pythonic_tool_parser import PythonicToolParser from .qwen3coder_tool_parser import Qwen3CoderToolParser +from .step3_tool_parser import Step3ToolParser from .xlam_tool_parser import xLAMToolParser __all__ = [ @@ -40,4 +41,5 @@ __all__ = [ "HunyuanA13BToolParser", "Glm4MoeModelToolParser", "Qwen3CoderToolParser", + "Step3ToolParser", ] diff --git a/vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py new file mode 100644 index 0000000000000..a20d18eb52544 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py @@ -0,0 +1,296 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import contextlib +import json +from collections.abc import Sequence +from typing import Any, Optional, Union + +import regex as re + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import random_uuid + +logger = init_logger(__name__) + + +@ToolParserManager.register_module(["step3"]) +class Step3ToolParser(ToolParser): + """ + Tool parser for a model that uses a specific XML-like format for tool calls. + This version uses a robust, stateful, cursor-based streaming parser and + consolidates tool arguments into a single message. + """ + + TOOL_CALLS_BEGIN = "<|tool_calls_begin|>" + TOOL_CALLS_END = "<|tool_calls_end|>" + TOOL_CALL_BEGIN = "<|tool_call_begin|>" + TOOL_CALL_END = "<|tool_call_end|>" + TOOL_SEP = "<|tool_sep|>" + SPECIAL_TOKENS = [ + TOOL_CALLS_BEGIN, TOOL_CALLS_END, TOOL_CALL_BEGIN, TOOL_CALL_END + ] + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + self.position = 0 + # Explicit state flags for robust streaming + self.tool_block_started = False + self.tool_block_finished = False + + def adjust_request( + self, request: ChatCompletionRequest) -> ChatCompletionRequest: + if request.tools and request.tool_choice != 'none': + request.skip_special_tokens = False + return request + + @staticmethod + def _parse_steptml_invoke( + action_text: str + ) -> tuple[Optional[str], Optional[dict[str, str]]]: + func_name_match = re.search(r'', + action_text) + if not func_name_match: + return None, None + func_name = func_name_match.group(1) + + params: dict[str, str] = {} + param_matches = re.findall( + r'([^<]*)', + action_text) + for name, value in param_matches: + params[name] = value.strip() + return func_name, params + + def _cast_arguments( + self, + func_name: str, + params: dict[str, Any], + request: ChatCompletionRequest, + ) -> dict[str, Any]: + for tool in request.tools or []: + if tool.function.name == func_name: + schema = tool.function.parameters or {} + properties = schema.get("properties", {}) + for key, value in params.items(): + if not isinstance(value, str): + continue + prop = properties.get(key, {}) + typ = prop.get("type") + if typ == "string": + params[key] = value.strip() + elif typ == "integer": + with contextlib.suppress(ValueError): + params[key] = int(value) + elif typ == "number": + with contextlib.suppress(ValueError): + params[key] = float(value) + elif typ == "boolean": + lower_val = value.lower() + params[key] = lower_val == "true" if lower_val in ( + "true", "false") else value + elif typ == "null": + params[key] = None if value.lower( + ) == "null" else value + break + return params + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + + # The main loop processes the stream from the last known position. + while True: + if self.position >= len(current_text): + return None # We've processed the entire stream. + + unprocessed_text = current_text[self.position:] + + # STATE: After all tools are done, all subsequent text is content. + if self.tool_block_finished: + self.position = len(current_text) + return DeltaMessage(content=unprocessed_text) + + # STATE: Before the tool block has started. + if not self.tool_block_started: + if unprocessed_text.startswith(self.TOOL_CALLS_BEGIN): + self.position += len(self.TOOL_CALLS_BEGIN) + self.tool_block_started = True + continue # Token consumed, re-loop. + + start_pos = unprocessed_text.find(self.TOOL_CALLS_BEGIN) + if start_pos == -1: + if self.TOOL_CALLS_BEGIN.startswith( + unprocessed_text.strip()) and unprocessed_text: + return None # It's a prefix, wait. + self.position = len(current_text) + return DeltaMessage(content=unprocessed_text) + else: + content = unprocessed_text[:start_pos] + self.position += len(content) + return DeltaMessage(content=content) + + # STATE: Inside the main tool block. + offset = len(unprocessed_text) - len(unprocessed_text.lstrip()) + unprocessed_text = unprocessed_text.lstrip() + self.position += offset + + if unprocessed_text.startswith(self.TOOL_CALLS_END): + self.position += len(self.TOOL_CALLS_END) + self.tool_block_finished = True + self.current_tool_id = -1 + continue + + # Check if we are between tool calls. + tool_finished = ( + self.current_tool_id != -1 and + self.prev_tool_call_arr[self.current_tool_id].get("finished")) + if self.current_tool_id == -1 or tool_finished: + if unprocessed_text.startswith(self.TOOL_CALL_BEGIN): + self.position += len(self.TOOL_CALL_BEGIN) + if self.current_tool_id == -1: + self.current_tool_id = 0 + else: + self.current_tool_id += 1 + self.current_tool_name_sent = False + while len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + self.prev_tool_call_arr[ + self.current_tool_id]["finished"] = False + continue + + if self.TOOL_CALL_BEGIN.startswith(unprocessed_text): + return None + + # STATE: Parsing an active tool call. + if self.current_tool_id != -1 and not self.prev_tool_call_arr[ + self.current_tool_id].get("finished", False): + end_tool_pos = unprocessed_text.find(self.TOOL_CALL_END) + if end_tool_pos == -1: + tool_body = unprocessed_text + else: + tool_body = unprocessed_text[:end_tool_pos] + + if end_tool_pos == -1 and self.TOOL_CALL_END.startswith( + tool_body): + return None + + function_name, arguments = self._parse_steptml_invoke( + tool_body) + if not function_name: + return None + + tool_call_arr = { + "name": function_name, + "parameters": arguments or {} + } + + # Send the function name as soon as it's parsed. + if not self.current_tool_name_sent: + self.current_tool_name_sent = True + self.prev_tool_call_arr[self.current_tool_id].update( + tool_call_arr) + return DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + type="function", + id=f"chatcmpl-tool-{random_uuid()}", + function=DeltaFunctionCall( + name=function_name)) + ]) + + # Update our internal state with the latest parsed arguments. + self.prev_tool_call_arr[ + self.current_tool_id].update( # noqa: E501 + tool_call_arr) + + # Only send arguments when the tool call is complete. + if end_tool_pos != -1: + self.position += end_tool_pos + len(self.TOOL_CALL_END) + self.prev_tool_call_arr[ + self.current_tool_id]["finished"] = True + + final_args = self._cast_arguments( + function_name, + tool_call_arr.get("parameters", {}), # type: ignore + request) + if final_args: + final_args_json = json.dumps(final_args, + ensure_ascii=False) + return DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=final_args_json)) + ]) + + # If tool is not finished, return None to wait for more tokens. + return None + + return None + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + if self.TOOL_CALLS_BEGIN not in model_output: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + pre_text, rest = model_output.split(self.TOOL_CALLS_BEGIN, 1) + if self.TOOL_CALLS_END not in rest: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + tool_block, post_text = rest.split(self.TOOL_CALLS_END, 1) + content = (pre_text + post_text).strip() + + tool_calls: list[ToolCall] = [] + call_parts = tool_block.split(self.TOOL_CALL_BEGIN) + + for part in call_parts: + if not part or self.TOOL_CALL_END not in part: + continue + + call_content = part.split(self.TOOL_CALL_END, 1)[0] + if self.TOOL_SEP not in call_content: + continue + + type_part, invoke_part = call_content.split(self.TOOL_SEP, 1) + if type_part.strip() != "function": + continue + + function_name, params_dict = self._parse_steptml_invoke( + invoke_part) + + if function_name and params_dict is not None: + params_dict = self._cast_arguments(function_name, params_dict, + request) + params_str = json.dumps(params_dict, ensure_ascii=False) + tool_calls.append( + ToolCall(function=FunctionCall(name=function_name, + arguments=params_str))) + if tool_calls: + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=content if content else None) + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 51831a770347a..848c04b9b32f7 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -129,6 +129,7 @@ _TEXT_GENERATION_MODELS = { "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"), "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"), "RWForCausalLM": ("falcon", "FalconForCausalLM"), + "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"), "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"), "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"), @@ -238,6 +239,7 @@ _MULTIMODAL_MODELS = { "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501 "Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501 "UltravoxModel": ("ultravox", "UltravoxModel"), + "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"), # noqa: E501 "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"), # noqa: E501 "Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"), # noqa: E501 "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"), # noqa: E501 diff --git a/vllm/model_executor/models/step3_text.py b/vllm/model_executor/models/step3_text.py new file mode 100644 index 0000000000000..47d2af5c2a140 --- /dev/null +++ b/vllm/model_executor/models/step3_text.py @@ -0,0 +1,521 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Inference-only Jurassic model.""" +from collections.abc import Iterable +from typing import Any, Optional + +import torch +from torch import nn + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, ModelConfig, VllmConfig +from vllm.distributed import (get_pp_group, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + +logger = init_logger(__name__) + + +class FusedMoEBlock(nn.Module): + + def __init__(self, + config: ModelConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + + if self.tp_size > config.moe_num_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.moe_num_experts}.") + + self.experts = FusedMoE(num_experts=config.moe_num_experts, + top_k=config.moe_top_k, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_expert_weight, + quant_config=quant_config, + prefix=f"{prefix}.experts") + self.gate = ReplicatedLinear(config.hidden_size, + config.moe_num_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate") + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + orig_shape = hidden_states.shape + hidden_dim = hidden_states.shape[-1] + hidden_states = hidden_states.view(-1, hidden_dim) + + router_logits, _ = self.gate(hidden_states) + + final_hidden_states = self.experts(hidden_states=hidden_states, + router_logits=router_logits) + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + return final_hidden_states.view(orig_shape) + + +class Step3TextMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj") + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj") + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + self.hidden_size = hidden_size + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(hidden_states) + intermediate_act = self.act_fn(gate_up) + output, _ = self.down_proj(intermediate_act) + return output + + +class Step3TextAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + norm_eps: float, + rope_theta: int, + share_q_dim: Optional[int] = None, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embedding: int = 8192, + head_dim: int = 256, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + + if num_kv_heads != 1: + raise ValueError(f"Step3TextAttention num_kv_heads must be 1, " + f"but got {num_kv_heads}.") + self.num_kv_heads = num_kv_heads + + self.head_dim = head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.q_size = share_q_dim if share_q_dim else self.head_dim + + self.qkv_proj = ReplicatedLinear( + hidden_size, + self.q_size + self.kv_size * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + self.inter_norm = RMSNorm(self.q_size, eps=norm_eps) + self.wq = ColumnParallelLinear( + self.q_size, + self.head_dim * self.total_num_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.wq", + ) + self.rotary_emb = get_rope(self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embedding, + base=rope_theta, + rope_scaling=rope_scaling) + scaling = self.head_dim**-0.5 + self.attn = Attention(self.num_heads, + self.head_dim, + scaling, + self.num_kv_heads, + cache_config=cache_config, + prefix=f"{prefix}.attn") + + def forward(self, positions: torch.Tensor, + hidden_states: torch.Tensor) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = self.inter_norm(q) + q = self.wq(q)[0] + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + residual, _ = self.o_proj(attn_output) + return residual + + +class Step3TextDecoderLayer(nn.Module): + + def __init__(self, + config: ModelConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__() + config = config.hf_config + self.hidden_size = config.hidden_size + rope_scaling = getattr(config, "rope_scaling", None) + + self.self_attn = Step3TextAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=1, + cache_config=cache_config, + quant_config=quant_config, + norm_eps=config.rms_norm_eps, + max_position_embedding=config.max_position_embedding, + head_dim=config.head_dim, + share_q_dim=config.share_q_dim, + rope_theta=config.rope_theta, + rope_scaling=rope_scaling, + prefix=f"{prefix}.self_attn") + + layer_idx = int(prefix.split("layers.")[1].split(".")[0]) + moe_layers_enum = getattr(config, "moe_layers_enum", None) + if moe_layers_enum is not None: + moe_layers_idx = [ + int(i) for i in moe_layers_enum.strip().split(',') + ] + else: + # Default to 1dense. + moe_layers_idx = [i for i in range(1, config.num_hidden_layers)] + + if layer_idx in moe_layers_idx: + self.moe = FusedMoEBlock(config=config, + quant_config=quant_config, + prefix=f"{prefix}.moe") + self.share_expert = Step3TextMLP( + hidden_size=self.hidden_size, + intermediate_size=config.share_expert_dim, + hidden_act="silu", + quant_config=quant_config, + prefix=f"{prefix}.share_expert") + self.use_moe = True + else: + self.mlp = Step3TextMLP(hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act="silu", + quant_config=quant_config, + prefix=f"{prefix}.mlp") + self.use_moe = False + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, positions: torch.Tensor, hidden_states: torch.Tensor, + residual: Optional[torch.Tensor] + ) -> tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + + if self.use_moe: + share_output = self.share_expert(hidden_states) + moe_output = self.moe(hidden_states) + hidden_states = share_output + moe_output + else: + hidden_states = self.mlp(hidden_states) + + return hidden_states, residual + + +@support_torch_compile +class Step3TextModel(nn.Module): + + def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: + super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.vocab_size = config.vocab_size + self.config = config + + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + ) + else: + self.embed_tokens = PPMissingLayer() + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Step3TextDecoderLayer(config=vllm_config. + model_config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.layers", + ) + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, residual) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual, + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class Step3TextForCausalLM(nn.Module, SupportsPP): + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + ): + super().__init__() + config = vllm_config.model_config.hf_config + lora_config = vllm_config.lora_config + self.config = config + self.vllm_config = vllm_config + + self.model = Step3TextModel(vllm_config=vllm_config, prefix=prefix) + + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + if not lora_config else lora_config.lora_vocab_padding_size, + ) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = get_sampler() + else: + self.lm_head = PPMissingLayer() + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None): + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + qkv_params_mapping = [ + # (param_name, shard_name, relative_start_idx, relative_end_idx) + (".qkv_proj", ".q_proj", 0, self.config.share_q_dim / + (self.config.share_q_dim + self.config.head_dim * 2)), + (".qkv_proj", ".k_proj", self.config.share_q_dim / + (self.config.share_q_dim + self.config.head_dim * 2), + (self.config.share_q_dim + self.config.head_dim) / + (self.config.share_q_dim + self.config.head_dim * 2)), + (".qkv_proj", ".v_proj", + (self.config.share_q_dim + self.config.head_dim) / + (self.config.share_q_dim + self.config.head_dim * 2), + (self.config.share_q_dim + self.config.head_dim * 2) / + (self.config.share_q_dim + self.config.head_dim * 2)), + ] + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + expert_params_mapping = [ + (".moe.experts.w13_weight", ".moe.gate_proj.weight", "w1"), + (".moe.experts.w13_weight", ".moe.up_proj.weight", "w3"), + (".moe.experts.w2_weight", ".moe.down_proj.weight", "w2") + ] + + disable_moe_stacked_params = [ + data[1] for data in expert_params_mapping + ] + + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + if any(disable_moe_stacked_param in name + for disable_moe_stacked_param in + disable_moe_stacked_params): + continue + name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + loaded_params.add(name) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + param = params_dict[name] + weight_loader = param.weight_loader + for expert_id in range(loaded_weight.shape[0]): + loaded_weight_expert = loaded_weight[expert_id] + weight_loader(param, + loaded_weight_expert, + name, + shard_id=shard_id, + expert_id=expert_id) + loaded_params.add(name) + break + else: + for (param_name, weight_name, start_idx, + end_idx) in qkv_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + dim = param.shape[param.output_dim] + begin_idx = int(start_idx * dim) + end_idx = int(end_idx * dim) + param_slice = param.narrow(param.output_dim, begin_idx, + end_idx - begin_idx) + param_slice.copy_(loaded_weight) + loaded_params.add(name) + break + else: + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/step3_vl.py b/vllm/model_executor/models/step3_vl.py new file mode 100644 index 0000000000000..363c12a4bf2b8 --- /dev/null +++ b/vllm/model_executor/models/step3_vl.py @@ -0,0 +1,1052 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math +from collections.abc import Iterable, Mapping, Sequence +from functools import cached_property +from itertools import product +from math import ceil, sqrt +from typing import Any, Literal, Optional, TypedDict, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from PIL import Image +from torchvision import transforms +from torchvision.transforms.functional import InterpolationMode +from transformers import BatchFeature, PretrainedConfig, TensorType + +from vllm.config import VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs, NestedTensors) +from vllm.multimodal.parse import ImageSize, MultiModalDataItems +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement, + PromptUpdate, PromptUpdateDetails) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs import Step3VisionEncoderConfig +from vllm.transformers_utils.tokenizer import AnyTokenizer + +from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP +from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, + init_vllm_registered_model, maybe_prefix, + merge_multimodal_embeddings) + + +class Step3VLImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + pixel_values: torch.Tensor + patch_pixel_values: Optional[torch.Tensor] + num_patches: list[int] + + +class Step3VLImageEmbeddingInputs(TypedDict): + type: Literal["image_embeds"] + image_embeds: torch.Tensor + + +Step3VLImageInputs = Union[Step3VLImagePixelInputs, + Step3VLImageEmbeddingInputs] + +ImageWithPatches = tuple[Image.Image, list[Image.Image], list[int] | None] + +MAX_IMAGE_SIZE: int = 3024 + + +class Step3VisionProcessor: + + def __init__(self, size, interpolation_mode="bicubic", patch_size=None): + mean = [0.48145466, 0.4578275, 0.40821073] + std = [0.26862954, 0.26130258, 0.27577711] + patch_size = patch_size if patch_size is not None else size + + self.transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean, std), + transforms.Resize( + (size, size), + interpolation=InterpolationMode.BICUBIC if interpolation_mode + == "bicubic" else InterpolationMode.BILINEAR, + antialias=True), + ]) + + self.patch_transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean, std), + transforms.Resize( + (patch_size, patch_size), + interpolation=InterpolationMode.BICUBIC if interpolation_mode + == "bicubic" else InterpolationMode.BILINEAR, + antialias=True), + ]) if patch_size is not None else None + + def __call__(self, image, is_patch=False): + if is_patch: + return {"pixel_values": self.patch_transform(image).unsqueeze(0)} + else: + return {"pixel_values": self.transform(image).unsqueeze(0)} + + +class ImagePatcher: + + def determine_window_size(self, long: int, short: int) -> int: + if long <= 728: + return short if long / short > 1.5 else 0 + return min(short, 504) if long / short > 4 else 504 + + def slide_window( + self, + width: int, + height: int, + sizes: list[tuple[int, int]], + steps: list[tuple[int, int]], + img_rate_thr: float = 0.6, + ) -> tuple[list[tuple[int, int, int, int]], tuple[int, int]]: + assert 1 >= img_rate_thr >= 0, "The `in_rate_thr` should lie in 0~1" + windows = [] + # Sliding windows. + for size, step in zip(sizes, steps): + size_w, size_h = size + step_w, step_h = step + + x_num = 1 if width <= size_w else ceil((width - size_w) / step_w + + 1) + x_start = [step_w * i for i in range(x_num)] + if len(x_start) > 1 and x_start[-1] + size_w > width: + x_start[-1] = width - size_w + + y_num = 1 if height <= size_h else ceil((height - size_h) / + step_h + 1) + y_start = [step_h * i for i in range(y_num)] + if len(y_start) > 1 and y_start[-1] + size_h > height: + y_start[-1] = height - size_h + + start = np.array(list(product(y_start, x_start)), dtype=int) + start[:, [0, 1]] = start[:, [1, 0]] + windows.append(np.concatenate([start, start + size], axis=1)) + windows = np.concatenate(windows, axis=0) + + return [(int(box[0]), int(box[1]), int(box[2] - box[0]), + int(box[3] - box[1])) for box in windows], (x_num, y_num) + + def square_pad(self, img: Image.Image) -> Image.Image: + w, h = img.size + if w == h: + return img + size = max(w, h) + padded = Image.new(img.mode, (size, size), 0) + padded.paste(img, (0, 0)) + return padded + + def get_image_size_for_padding(self, img_width: int, + img_height: int) -> tuple[int, int]: + ratio = img_width / img_height + if min(img_height, img_width) < 32 and (ratio > 4 or ratio < 1 / 4): + new_size = max(img_height, img_width) + return new_size, new_size + return img_width, img_height + + def get_image_size_for_preprocess(self, img_width: int, + img_height: int) -> tuple[int, int]: + + if max(img_height, img_width) > MAX_IMAGE_SIZE: + scale_factor = MAX_IMAGE_SIZE / max(img_height, img_width) + img_width = int(img_width * scale_factor) + img_height = int(img_height * scale_factor) + return img_width, img_height + + def get_image_size_for_crop(self, img_width: int, img_height: int, + window_size: int): + w_ratio = img_width / window_size + h_ratio = img_height / window_size + + if w_ratio < 1: + width_new = img_width + else: + decimal_w = w_ratio - img_width // window_size + w_ratio = int(w_ratio) + 1 if decimal_w > 0.2 else int(w_ratio) + width_new = window_size * w_ratio + if h_ratio < 1: + height_new = img_height + else: + decimal_h = h_ratio - img_height // window_size + h_ratio = int(h_ratio) + 1 if decimal_h > 0.2 else int(h_ratio) + height_new = window_size * h_ratio + return int(width_new), int(height_new) + + def patch_crop(self, img: Image.Image, i: int, j: int, th: int, tw: int): + target = img.crop((j, i, j + tw, i + th)) + return target + + def get_num_patches(self, img_width: int, + img_height: int) -> tuple[int, int]: + img_width, img_height = self.get_image_size_for_padding( + img_width, img_height) + img_width, img_height = self.get_image_size_for_preprocess( + img_width, img_height) + window_size = self.determine_window_size(max(img_height, img_width), + min(img_height, img_width)) + if window_size == 0: + return 0, 0 + else: + img_width, img_height = self.get_image_size_for_crop( + img_width, img_height, window_size) + center_list, (x_num, y_num) = self.slide_window( + img_width, img_height, [(window_size, window_size)], + [(window_size, window_size)]) + full_rows = (len(center_list) - 1) // x_num + 1 + if len(center_list) > 0 and len(center_list) % x_num == 0: + full_rows -= 1 + return len(center_list), full_rows + + def __call__( + self, img: Image.Image + ) -> tuple[Image.Image, list[Image.Image], list[bool] | None]: + img_width, img_height = img.size + new_img_width, new_img_height = self.get_image_size_for_padding( + img_width, img_height) + if new_img_width != img_width or new_img_height != img_height: + img = self.square_pad(img) + img_width, img_height = img.size + + new_img_width, new_img_height = self.get_image_size_for_preprocess( + img_width, img_height) + img = img.resize((new_img_width, new_img_height), + Image.Resampling.BILINEAR) + window_size = self.determine_window_size( + max(new_img_height, new_img_width), + min(new_img_height, new_img_width)) + + if window_size == 0: + return img, [], None + else: + new_img_width, new_img_height = self.get_image_size_for_crop( + new_img_width, new_img_height, window_size) + if (new_img_width, new_img_height) != (img_width, img_height): + img_for_crop = img.resize((new_img_width, new_img_height), + Image.Resampling.BILINEAR) + else: + img_for_crop = img + + patches = [] + newlines = [] + center_list, (x_num, y_num) = self.slide_window( + new_img_width, new_img_height, [(window_size, window_size)], + [(window_size, window_size)]) + for patch_id, center_lf_point in enumerate(center_list): + x, y, patch_w, patch_h = center_lf_point + big_patch = self.patch_crop(img_for_crop, y, x, patch_h, + patch_w) + patches.append(big_patch) + if (patch_id + 1) % x_num == 0: + newlines.append(patch_id) + + if newlines and newlines[-1] == len(patches) - 1: + newlines.pop() + + return img, patches, [i in newlines for i in range(len(patches)) + ] if len(patches) > 0 else None + + +class Step3VLProcessor: + + def __init__( + self, + config: PretrainedConfig, + tokenizer: AnyTokenizer, + ) -> None: + super().__init__() + + self.config = config + self.tokenizer = tokenizer + + self.image_size = 728 + self.patch_size = 504 + self.image_preprocessor = Step3VisionProcessor(self.image_size, + "bilinear", + self.patch_size) + + self.num_image_feature_size = 169 + self.num_patch_feature_size = 81 + self.image_token = "" + self.image_feature_placeholder = (self.image_token * + self.num_image_feature_size) + self.patch_feature_placeholder = (self.image_token * + self.num_patch_feature_size) + + self.patcher = ImagePatcher() + + @property + def image_token_id(self) -> int: + return self.tokenizer.get_vocab()[self.image_token] + + def get_num_image_tokens(self, img_width: int, img_height: int) -> int: + num_patches, num_newlines = self.patcher.get_num_patches( + img_width, img_height) + + return num_patches * ( + self.num_patch_feature_size + + 2) + self.num_image_feature_size + 2 + num_newlines + + def _split_images(self, + images: list[Image.Image]) -> list[ImageWithPatches]: + result = [] + for img in images: + result.append(self.patcher(img)) + return result + + def _convert_images_to_pixel_values( + self, + images: list[Image.Image], + is_patch: bool = False, + ) -> list[torch.Tensor]: + return [ + self.image_preprocessor(img, is_patch=is_patch)["pixel_values"] + for img in images + ] + + def _get_patch_repl( + self, + num_patches: int, + patch_newline_mask: list[bool] | None, + ) -> tuple[str, list[int]]: + text = "" + token_ids = [] + for i in range(num_patches): + assert len(patch_newline_mask) == num_patches + text += f"{self.patch_feature_placeholder}" + token_ids.extend( + [self.tokenizer.convert_tokens_to_ids("")] + + [self.image_token_id] * self.num_patch_feature_size + + [self.tokenizer.convert_tokens_to_ids("")]) + if patch_newline_mask and patch_newline_mask[i]: + text += "" + token_ids.append( + self.tokenizer.convert_tokens_to_ids("")) + return text, token_ids + + def _get_image_repl( + self, + num_images: int, + ) -> tuple[str, list[int]]: + text = f"{self.image_feature_placeholder}" + token_ids = [ + self.tokenizer.convert_tokens_to_ids("") + ] + [self.image_token_id] * self.num_image_feature_size + [ + self.tokenizer.convert_tokens_to_ids("") + ] + return text * num_images, token_ids * num_images + + def _get_image_repl_features( + self, + num_images: int, + num_patches: int, + patch_new_line_idx: Optional[list[bool]], + ) -> tuple[str, list[int]]: + if num_patches > 0: + patch_repl, patch_repl_ids = self._get_patch_repl( + num_patches, patch_new_line_idx) + else: + patch_repl = "" + patch_repl_ids = [] + image_repl, image_repl_ids = self._get_image_repl(num_images) + return patch_repl + image_repl, patch_repl_ids + image_repl_ids + + def replace_placeholder(self, text: str, placeholder: str, + repls: list[str]) -> str: + parts = text.split(placeholder) + + if len(parts) - 1 != len(repls): + raise ValueError( + "The number of placeholders does not match the number of replacements." # noqa: E501 + ) + + result = [parts[0]] + for i, repl in enumerate(repls): + result.append(repl) + result.append(parts[i + 1]) + + return "".join(result) + + def __call__( + self, + text: Optional[Union[str, list[str]]] = None, + images: Optional[Union[Image.Image, list[Image.Image]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + ) -> BatchFeature: + if text is None: + text = [] + if not isinstance(text, list): + text = [text] + if images is None: + images = [] + if not isinstance(images, list): + images = [images] + + if len(images) == 0: + image_inputs = {} + text_inputs = self.tokenizer(text) + else: + splitted_images_data = self._split_images(images) + pixel_values_lst = [] + patch_pixel_values_lst = [] + patch_newline_mask_lst = [] + image_repl_str_lst = [] + image_repl_ids_lst = [] + num_patches = [] + for raw_img, img_patches, patch_newline_mask in splitted_images_data: # noqa: E501 + pixel_values_lst.extend( + self._convert_images_to_pixel_values([raw_img])) + + if len(img_patches) > 0: + patch_pixel_values_lst.extend( + self._convert_images_to_pixel_values(img_patches, + is_patch=True)) + num_patches.append(len(img_patches)) + + image_repl_str, image_repl_ids = self._get_image_repl_features( + 1, len(img_patches), patch_newline_mask) + image_repl_str_lst.append(image_repl_str) + image_repl_ids_lst.extend(image_repl_ids) + + if patch_newline_mask is not None: + patch_newline_mask_lst.extend(patch_newline_mask) + + image_inputs = { + "pixel_values": torch.cat(pixel_values_lst), + "num_patches": num_patches, + } + if patch_pixel_values_lst: + image_inputs["patch_pixel_values"] = torch.cat( + patch_pixel_values_lst) + if patch_newline_mask_lst: + image_inputs["patch_newline_mask"] = torch.tensor( + patch_newline_mask_lst, dtype=torch.bool) + + text = [ + self.replace_placeholder(t, self.image_token, + image_repl_str_lst) for t in text + ] + text_inputs = self.tokenizer(text) + + return BatchFeature( + { + **text_inputs, + **image_inputs, + }, + tensor_type=return_tensors, + ) + + +class Step3VLProcessingInfo(BaseProcessingInfo): + + def get_hf_processor(self) -> Step3VLProcessor: + return Step3VLProcessor( + self.get_hf_config(), + self.get_tokenizer(), + ) + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} + + def get_max_image_tokens(self) -> int: + hf_processor = self.get_hf_processor() + return hf_processor.get_num_image_tokens( + self.get_image_size_with_most_features().width, + self.get_image_size_with_most_features().height) + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + return {"image": self.get_max_image_tokens()} + + def get_image_size_with_most_features(self) -> ImageSize: + return ImageSize(3024, 3024) + + def get_num_mm_tokens(self, mm_data: MultiModalDataDict) -> int: + if len(mm_data) != 1 or "image" not in mm_data: + raise ValueError( + "mm_data could only contain one key 'image' for steo1o") + + image_data = mm_data["image"] + if not isinstance(image_data, (list, tuple)): + image_data = [image_data] + + return sum(self.get_hf_processor().get_num_image_tokens( + img.width, img.height) for img in image_data) + + +class Step3VLDummyInputsBuilder(BaseDummyInputsBuilder[Step3VLProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + return "" * num_images + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + target_width, target_height = \ + self.info.get_image_size_with_most_features() + num_images = mm_counts.get("image", 0) + + return { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + + +class Step3VLMultiModalProcessor(BaseMultiModalProcessor[Step3VLProcessingInfo] + ): + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + image_placeholder_token_id = hf_processor.image_token_id + batch_num_patches = out_mm_kwargs["num_patches"].tolist() + + def get_replacement_step1o(item_idx: int): + img_out = out_mm_kwargs.get_item("image", item_idx) + num_patches = batch_num_patches[item_idx] + if num_patches > 0: + patch_newline_mask = img_out["patch_newline_mask"].data.tolist( + ) + image_repl_ids = hf_processor._get_image_repl_features( + 1, num_patches, patch_newline_mask)[1] + else: + image_repl_ids = hf_processor._get_image_repl_features( + 1, 0, None)[1] + return PromptUpdateDetails.select_token_id( + seq=image_repl_ids, + embed_token_id=image_placeholder_token_id, + ) + + return [ + PromptReplacement( + modality="image", + target=[image_placeholder_token_id], + replacement=get_replacement_step1o, + ) + ] + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + num_patches = hf_inputs.get("num_patches", torch.empty(0)) + + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + patch_pixel_values=MultiModalFieldConfig.flat_from_sizes( + "image", num_patches), + num_patches=MultiModalFieldConfig.batched("image"), + patch_newline_mask=MultiModalFieldConfig.flat_from_sizes( + "image", num_patches), + ) + + +def get_abs_pos(abs_pos, tgt_size): + dim = abs_pos.size(-1) + abs_pos_new = abs_pos.squeeze(0) + cls_token, old_pos_embed = abs_pos_new[:1], abs_pos_new[1:] + + src_size = int(math.sqrt(abs_pos_new.shape[0] - 1)) + tgt_size = int(math.sqrt(tgt_size)) + dtype = abs_pos.dtype + + if src_size != tgt_size: + old_pos_embed = old_pos_embed.view(1, src_size, src_size, + dim).permute(0, 3, 1, + 2).contiguous() + old_pos_embed = old_pos_embed.to(torch.float32) + new_pos_embed = F.interpolate( + old_pos_embed, + size=(tgt_size, tgt_size), + mode='bicubic', + antialias=True, + align_corners=False, + ).to(dtype) + new_pos_embed = new_pos_embed.permute(0, 2, 3, 1) + new_pos_embed = new_pos_embed.view(tgt_size * tgt_size, dim) + vision_pos_embed = torch.cat([cls_token, new_pos_embed], dim=0) + vision_pos_embed = vision_pos_embed.view(1, tgt_size * tgt_size + 1, + dim) + return vision_pos_embed + else: + return abs_pos + + +class Step3VisionEmbeddings(nn.Module): + + def __init__(self, config: Step3VisionEncoderConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter(torch.randn(1, self.embed_dim)) + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=True, + ) + + self.num_patches = (self.image_size // self.patch_size)**2 + self.pad_tp_size = 4 # hard code for padding + # To load the pretrained weights, we still use P+1 as the seqlen + self.position_embedding = torch.nn.Embedding(self.num_patches + 1, + self.embed_dim) + self.register_buffer("position_ids", + torch.arange(self.num_patches + 1).expand( + (1, -1)), + persistent=False) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + batch_size = pixel_values.shape[0] + patch_embeds = self.patch_embedding( + pixel_values) # shape = [*, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + # pad + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + embeddings = embeddings + get_abs_pos( + self.position_embedding(self.position_ids), patch_embeds.size(1)) + embeddings = torch.cat([ + embeddings[:, 0, :].unsqueeze(1).repeat(1, self.pad_tp_size - 1, + 1), embeddings + ], + dim=1) + return embeddings + + +class Step3VisionAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.total_num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.total_num_heads + + self.scale = self.head_dim**-0.5 + + tp_size = get_tensor_model_parallel_world_size() + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.qkv_proj = QKVParallelLinear(self.embed_dim, + self.head_dim, + self.total_num_heads, + bias=True, + quant_config=quant_config, + prefix=prefix) + self.out_proj = RowParallelLinear(self.embed_dim, + self.embed_dim, + bias=True, + quant_config=quant_config, + prefix=prefix) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, + self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + ): + """Input shape: Batch x Time x Channel""" + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + q = q.view(bsz, tgt_len, self.num_heads, self.head_dim) + k = k.view(bsz, tgt_len, self.num_heads, self.head_dim) + v = v.view(bsz, tgt_len, self.num_heads, self.head_dim) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + attn_output = F.scaled_dot_product_attention(q, + k, + v, + scale=self.scale, + is_causal=False) + attn_output = attn_output.transpose(1, 2).reshape( + bsz, tgt_len, self.num_heads * self.head_dim) + + attn_output, _ = self.out_proj(attn_output) + + return attn_output + + +class Step3VisionMLP(nn.Module): + + def __init__(self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.config = config + self.activation_fn = get_act_fn(config.hidden_act) + self.fc1 = ColumnParallelLinear(config.hidden_size, + config.intermediate_size, + bias=True, + quant_config=quant_config, + prefix=prefix) + self.fc2 = RowParallelLinear(config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quant_config, + prefix=prefix) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + return hidden_states + + +class Step3VisionEncoderLayer(nn.Module): + + def __init__(self, + config: Step3VisionEncoderConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = Step3VisionAttention(config, + quant_config, + prefix=f"{prefix}.self_attn") + self.layer_norm1 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) + self.mlp = Step3VisionMLP(config, quant_config, prefix=f"{prefix}.mlp") + self.layer_norm2 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.FloatTensor: + hidden_states = hidden_states + self.layer_norm1( + self.self_attn(hidden_states)) + hidden_states = hidden_states + self.layer_norm2( + self.mlp(hidden_states)) + return hidden_states + + +class Step3VisionEncoder(nn.Module): + + def __init__(self, + config: Step3VisionEncoderConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.config = config + self.layers = nn.ModuleList([ + Step3VisionEncoderLayer(config, + quant_config, + prefix=f"{prefix}.layers.{i}") + for i in range(config.num_hidden_layers) + ]) + + def forward( + self, + inputs_embeds, + ): + hidden_states = inputs_embeds + for encoder_layer in self.layers: + hidden_states = encoder_layer(hidden_states) + return hidden_states + + +class Step3VisionTransformer(nn.Module): + + def __init__(self, + config: Step3VisionEncoderConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.config = config + self.image_size = config.image_size + self.embeddings = Step3VisionEmbeddings(config) + self.transformer = Step3VisionEncoder(config, + quant_config, + prefix=f"{prefix}.transformer") + + def forward( + self, + pixel_values: torch.Tensor, + ): + hidden_states = self.embeddings(pixel_values) + hidden_states = self.transformer(inputs_embeds=hidden_states) + return hidden_states + + +@MULTIMODAL_REGISTRY.register_processor(Step3VLMultiModalProcessor, + info=Step3VLProcessingInfo, + dummy_inputs=Step3VLDummyInputsBuilder) +class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsPP): + + hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={ + "model.": "language_model.model.", + "lm_head.": "language_model.lm_head.", + }) + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + if modality.startswith("image"): + return "" + + raise ValueError("Only image modality is supported") + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + super().__init__() + + config = vllm_config.model_config.hf_config + multimodal_config = vllm_config.model_config.multimodal_config + + self.config = config + self.multimodal_config = multimodal_config + + self.vision_model = Step3VisionTransformer(config.vision_config, + None, + prefix=maybe_prefix( + prefix, "vision_model")) + self.vit_downsampler = nn.Conv2d( + config.vision_config.hidden_size, + config.vision_config.output_hidden_size, + kernel_size=2, + stride=config.understand_projector_stride) + self.vit_downsampler2 = nn.Conv2d( + config.vision_config.output_hidden_size, + config.vision_config.output_hidden_size * 2, + kernel_size=3, + stride=2, + padding=1, + ) + self.vit_large_projector = nn.Linear( + config.vision_config.output_hidden_size * 2, + config.hidden_size, + bias=config.projector_bias, + ) + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model")) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return get_sampler() + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[Step3VLImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + patch_pixel_values = kwargs.pop("patch_pixel_values", None) + num_patches = kwargs.pop("num_patches", None) + image_embeds = kwargs.pop("image_embeds", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None: + pixel_values = flatten_bn(pixel_values, concat=True) + if pixel_values.dim() >= 3: + pixel_values = pixel_values.view(-1, *pixel_values.shape[-3:]) + if patch_pixel_values is not None: + patch_pixel_values = flatten_bn(patch_pixel_values, + concat=True) + patch_pixel_values = patch_pixel_values.view( + -1, *patch_pixel_values.shape[-3:]) + # Handle empty patch_pixel_values by setting to None + if patch_pixel_values.shape[0] == 0: + patch_pixel_values = None + num_patches = flatten_bn(num_patches, concat=True).tolist() + + return Step3VLImagePixelInputs( + type="pixel_values", + pixel_values=pixel_values.to(self.dtype).to(self.device), + patch_pixel_values=patch_pixel_values.to(self.dtype).to( + self.device) if patch_pixel_values is not None else None, + num_patches=num_patches, + ) + + if image_embeds is not None: + if image_embeds.dim() == 2 or image_embeds.dim() >= 3: + image_embeds = image_embeds.view(-1, image_embeds.shape[-1]) + else: + raise ValueError( + f"Unexpected shape for image_embeds: {image_embeds.shape}") + + return Step3VLImageEmbeddingInputs( + type="image_embeds", + image_embeds=image_embeds.to(self.dtype).to(self.device), + ) + return None + + def _process_image_features(self, + image_features: torch.Tensor) -> torch.Tensor: + B, P = image_features.shape[:2] + HW = int(sqrt(P)) + image_features = image_features.permute(0, 2, 1).view(B, -1, HW, HW) + image_features = self.vit_downsampler(image_features) + image_features = self.vit_downsampler2(image_features) + n_dim = image_features.size(1) + image_features = image_features.view(B, n_dim, -1).permute(0, 2, 1) + image_features = self.vit_large_projector(image_features) + return image_features + + def _get_vision_model_output(self, + input_tensor: torch.Tensor) -> torch.Tensor: + return self.vision_model(input_tensor)[:, 4:] + + def _process_image_input( + self, image_input: Step3VLImageInputs) -> tuple[torch.Tensor, ...]: + + if image_input["type"] == "image_embeds": + image_features = image_input["image_embeds"] + else: + image_features = self._get_vision_model_output( + image_input["pixel_values"]) + patch_image_features = self._get_vision_model_output( + image_input["patch_pixel_values"] + ) if image_input["patch_pixel_values"] is not None else None + num_patches = image_input["num_patches"] + + image_features = self._process_image_features(image_features) + patch_image_features = self._process_image_features( + patch_image_features) if patch_image_features is not None else None + + merged_image_features = [] + cur_patch_idx = 0 + for i, num_patch in enumerate(num_patches): + cur_feature = [] + if num_patch > 0: + patch_slice = patch_image_features[ + cur_patch_idx:cur_patch_idx + num_patch] + cur_feature.append(patch_slice.view(-1, patch_slice.shape[-1])) + cur_feature.append(image_features[i].view( + -1, image_features.shape[-1])) + cur_patch_idx += num_patch + merged_image_features.append( + torch.cat(cur_feature) if len(cur_feature) > + 1 else cur_feature[0]) + return merged_image_features + + def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + if multimodal_embeddings is None: + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) + else: + is_text = input_ids != self.config.image_token_id + text_ids = input_ids[is_text] + text_embeds = self.language_model.model.get_input_embeddings( + text_ids) + inputs_embeds = torch.empty(input_ids.shape[0], + text_embeds.shape[-1], + dtype=text_embeds.dtype, + device=text_embeds.device) + inputs_embeds[is_text] = text_embeds + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + self.config.image_token_id) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + if intermediate_tensors is not None: + inputs_embeds = None + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + # always pass the input via `inputs_embeds` + # to make sure the computation graph is consistent + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + hidden_states = self.language_model(input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds) + + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + return self.language_model.sample(logits, sampling_metadata) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader(self) + loaded_weights = loader.load_weights(weights, + mapper=self.hf_to_vllm_mapper) + return loaded_weights diff --git a/vllm/reasoning/__init__.py b/vllm/reasoning/__init__.py index d61e4f11dfa29..1c3f78f2edbfb 100644 --- a/vllm/reasoning/__init__.py +++ b/vllm/reasoning/__init__.py @@ -8,6 +8,7 @@ from .granite_reasoning_parser import GraniteReasoningParser from .hunyuan_a13b_reasoning_parser import HunyuanA13BReasoningParser from .mistral_reasoning_parser import MistralReasoningParser from .qwen3_reasoning_parser import Qwen3ReasoningParser +from .step3_reasoning_parser import Step3ReasoningParser __all__ = [ "ReasoningParser", @@ -18,4 +19,5 @@ __all__ = [ "Qwen3ReasoningParser", "Glm4MoeModelReasoningParser", "MistralReasoningParser", + "Step3ReasoningParser", ] diff --git a/vllm/reasoning/step3_reasoning_parser.py b/vllm/reasoning/step3_reasoning_parser.py new file mode 100644 index 0000000000000..f642ea977c580 --- /dev/null +++ b/vllm/reasoning/step3_reasoning_parser.py @@ -0,0 +1,109 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Sequence +from typing import Optional, Union + +import regex as re +from transformers import PreTrainedTokenizerBase + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaMessage) +from vllm.logger import init_logger +from vllm.reasoning import ReasoningParser, ReasoningParserManager + +logger = init_logger(__name__) + + +@ReasoningParserManager.register_module("step3") +class Step3ReasoningParser(ReasoningParser): + """ + Reasoning parser for Step3 model. + + The Step3 model uses token to denote the end of reasoning + text. This parser extracts all content before as reasoning content. + """ + + def __init__(self, tokenizer: PreTrainedTokenizerBase): + super().__init__(tokenizer) + self.think_end_token = "" + + self.reasoning_regex = re.compile(rf"(.*?){self.think_end_token}", + re.DOTALL) + + if not self.model_tokenizer: + raise ValueError( + "The model tokenizer must be passed to the ReasoningParser " + "constructor during construction.") + + self.think_end_token_id = self.vocab.get(self.think_end_token) + if self.think_end_token_id is None: + raise RuntimeError( + "Step3 reasoning parser could not locate think end " + "token in the tokenizer!") + + def extract_reasoning_content_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> Union[DeltaMessage, None]: + """ + Extract reasoning content from a delta message. + Handles streaming output where previous + delta = current. + Uses token IDs for faster processing. + For text "abcxyz": + - 'abc' goes to reasoning_content + - 'xyz' goes to content + """ + # Skip single special token + if len(delta_token_ids + ) == 1 and delta_token_ids[0] == self.think_end_token_id: + return None + + if self.think_end_token_id in delta_token_ids: + # in delta, extract reasoning content and remaining content + end_index = delta_text.find(self.think_end_token) + reasoning_content = delta_text[:end_index] + content = delta_text[end_index + len(self.think_end_token):] + return DeltaMessage(reasoning_content=reasoning_content, + content=content if content else None) + elif self.think_end_token_id in previous_token_ids: + # already seen in previous text, everything is content + return DeltaMessage(content=delta_text) + else: + # No seen yet, everything is reasoning + return DeltaMessage(reasoning_content=delta_text) + + def extract_reasoning_content( + self, model_output: str, request: ChatCompletionRequest + ) -> tuple[Optional[str], Optional[str]]: + + # Check if the model output contains the token + if self.think_end_token not in model_output: + # If no token, everything is reasoning content + return model_output, None + else: + # Find the first occurrence of + end_index = model_output.find(self.think_end_token) + reasoning_content = model_output[:end_index] + + # Content after token + content = model_output[end_index + len(self.think_end_token):] + + if len(content) == 0: + content = None + + return reasoning_content, content + + def is_reasoning_end(self, input_ids: list[int]) -> bool: + return self.think_end_token_id in input_ids + + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + if self.think_end_token_id not in input_ids[:-1]: + return [] + else: + return input_ids[input_ids.index(self.think_end_token_id) + 1:] diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 4ce56cb3a6aac..fcaa48c1392a3 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -35,7 +35,8 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DeepseekVLV2Config, MllamaConfig, MLPSpeculatorConfig, Nemotron_Nano_VL_Config, NemotronConfig, NVLM_D_Config, - RWConfig, UltravoxConfig) + RWConfig, Step3TextConfig, + Step3VLConfig, UltravoxConfig) # yapf: enable from vllm.transformers_utils.configs.mistral import adapt_config_dict from vllm.transformers_utils.utils import check_gguf_file @@ -83,6 +84,8 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = { "nemotron": NemotronConfig, "NVLM_D": NVLM_D_Config, "ultravox": UltravoxConfig, + "step3_vl": Step3VLConfig, + "step3_text": Step3TextConfig, **_CONFIG_REGISTRY_OVERRIDE_HF } diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 7c7d859e4a325..96733da726181 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -24,6 +24,9 @@ from vllm.transformers_utils.configs.nemotron import NemotronConfig from vllm.transformers_utils.configs.nemotron_h import NemotronHConfig from vllm.transformers_utils.configs.nemotron_vl import Nemotron_Nano_VL_Config from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config +from vllm.transformers_utils.configs.step3_vl import (Step3TextConfig, + Step3VisionEncoderConfig, + Step3VLConfig) from vllm.transformers_utils.configs.ultravox import UltravoxConfig __all__ = [ @@ -42,4 +45,7 @@ __all__ = [ "Nemotron_Nano_VL_Config", "NVLM_D_Config", "UltravoxConfig", + "Step3VLConfig", + "Step3VisionEncoderConfig", + "Step3TextConfig", ] diff --git a/vllm/transformers_utils/configs/step3_vl.py b/vllm/transformers_utils/configs/step3_vl.py new file mode 100644 index 0000000000000..fe3c72de69d28 --- /dev/null +++ b/vllm/transformers_utils/configs/step3_vl.py @@ -0,0 +1,123 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any, Optional, Union + +from transformers.configuration_utils import PretrainedConfig + + +class Step3VisionEncoderConfig(PretrainedConfig): + model_type = "step3_vision_encoder" + + def __init__( + self, + hidden_size=1792, + intermediate_size=3072, + output_hidden_size=4096, + num_hidden_layers=63, + num_attention_heads=16, + num_channels=3, + image_size=728, + patch_size=14, + hidden_act="quick_gelu", + layer_norm_eps=1e-5, + **kwargs, + ): + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.output_hidden_size = output_hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + super().__init__(**kwargs) + + +class Step3TextConfig(PretrainedConfig): + model_type = "step3_text" + architectures = ["Step3TextForCausalLM"] + + def __init__( + self, + hidden_size: int = 7168, + intermediate_size: int = 18432, + num_attention_heads: int = 64, + num_attention_groups: int = 1, + num_hidden_layers: int = 61, + max_seq_len: int = 65536, + vocab_size: int = 128815, + rms_norm_eps: float = 1e-5, + moe_intermediate_size: int = 5120, + moe_num_experts: int = 48, + moe_top_k: int = 3, + rope_theta: float = 500000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embedding: int = 65536, + share_expert_dim: int = 5120, + share_q_dim: int = 2048, + head_dim: int = 256, + norm_expert_weight: bool = False, + moe_layers_enum: tuple[int, + ...] = (4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, + 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, + 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59), + **kwargs, + ) -> None: + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_attention_heads = num_attention_heads + self.num_attention_groups = num_attention_groups + self.num_hidden_layers = num_hidden_layers + self.max_seq_len = max_seq_len + self.vocab_size = vocab_size + self.rms_norm_eps = rms_norm_eps + self.moe_intermediate_size = moe_intermediate_size + self.moe_num_experts = moe_num_experts + self.moe_top_k = moe_top_k + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.max_position_embedding = max_position_embedding + self.share_expert_dim = share_expert_dim + self.share_q_dim = share_q_dim + self.head_dim = head_dim + self.norm_expert_weight = norm_expert_weight + self.moe_layers_enum = moe_layers_enum + + super().__init__(**kwargs) + + +class Step3VLConfig(PretrainedConfig): + model_type = "step3_vl" + + def __init__( + self, + vision_config: Optional[Union[dict, Step3VisionEncoderConfig]] = None, + text_config: Optional[Union[dict, Step3TextConfig]] = None, + understand_projector_stride: int = 1, + projector_bias: bool = True, + image_token_id: int = 128001, + **kwargs, + ) -> None: + if vision_config is None: + vision_config = Step3VisionEncoderConfig() + elif isinstance(vision_config, dict): + vision_config = Step3VisionEncoderConfig(**vision_config) + self.vision_config = vision_config + + if text_config is None: + text_config = Step3TextConfig() + elif isinstance(text_config, dict): + text_config = Step3TextConfig(**text_config) + self.text_config = text_config + + self.understand_projector_stride = understand_projector_stride + self.projector_bias = projector_bias + self.hidden_size = text_config.hidden_size + self.image_token_id = image_token_id + + super().__init__(**kwargs)