diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 5a9823bb6bae7..f5d9e3b22f2a6 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -625,6 +625,7 @@ See [this page](generative_models.md) for more information on how to use generat
| `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + IE+ + VE+ + A+ | `Qwen/Qwen2.5-Omni-7B` | | ✅︎ | ✅︎ |
| `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | ✅︎ |
| `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | | ✅︎ |
+| `Step3VLForConditionalGeneration` | Step3-VL | T + I+ | `stepfun-ai/step3` | | ✅︎ | ✅︎ |
| `TarsierForConditionalGeneration` | Tarsier | T + IE+ | `omni-search/Tarsier-7b`, `omni-search/Tarsier-34b` | | ✅︎ | ✅︎ |
| `Tarsier2ForConditionalGeneration`^ | Tarsier2 | T + IE+ + VE+ | `omni-research/Tarsier2-Recap-7b`, `omni-research/Tarsier2-7b-0115` | | ✅︎ | ✅︎ |
diff --git a/tests/models/registry.py b/tests/models/registry.py
index 8fcff5a8c5113..b9e7de4e9fd11 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -279,6 +279,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"), # noqa: E501
"StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
"Starcoder2ForCausalLM": _HfExamplesInfo("bigcode/starcoder2-3b"),
+ "Step3TextForCausalLM": _HfExamplesInfo("stepfun-ai/step3",
+ trust_remote_code=True,
+ is_available_online=False),
"SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct",
trust_remote_code=True),
"TeleChat2ForCausalLM": _HfExamplesInfo("Tele-AI/TeleChat2-3B",
@@ -457,6 +460,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B",
trust_remote_code=True),
"SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct"), # noqa: E501
+ "Step3VLForConditionalGeneration": _HfExamplesInfo("stepfun-ai/step3",
+ trust_remote_code=True,
+ is_available_online=False),
"UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501
trust_remote_code=True),
"TarsierForConditionalGeneration": _HfExamplesInfo("omni-research/Tarsier-7b", # noqa: E501
diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py
index 88c8aa929b78d..099e456aa486f 100644
--- a/vllm/entrypoints/openai/tool_parsers/__init__.py
+++ b/vllm/entrypoints/openai/tool_parsers/__init__.py
@@ -18,6 +18,7 @@ from .mistral_tool_parser import MistralToolParser
from .phi4mini_tool_parser import Phi4MiniJsonToolParser
from .pythonic_tool_parser import PythonicToolParser
from .qwen3coder_tool_parser import Qwen3CoderToolParser
+from .step3_tool_parser import Step3ToolParser
from .xlam_tool_parser import xLAMToolParser
__all__ = [
@@ -40,4 +41,5 @@ __all__ = [
"HunyuanA13BToolParser",
"Glm4MoeModelToolParser",
"Qwen3CoderToolParser",
+ "Step3ToolParser",
]
diff --git a/vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py
new file mode 100644
index 0000000000000..a20d18eb52544
--- /dev/null
+++ b/vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py
@@ -0,0 +1,296 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import contextlib
+import json
+from collections.abc import Sequence
+from typing import Any, Optional, Union
+
+import regex as re
+
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+ DeltaFunctionCall, DeltaMessage,
+ DeltaToolCall,
+ ExtractedToolCallInformation,
+ FunctionCall, ToolCall)
+from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
+ ToolParser, ToolParserManager)
+from vllm.logger import init_logger
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.utils import random_uuid
+
+logger = init_logger(__name__)
+
+
+@ToolParserManager.register_module(["step3"])
+class Step3ToolParser(ToolParser):
+ """
+ Tool parser for a model that uses a specific XML-like format for tool calls.
+ This version uses a robust, stateful, cursor-based streaming parser and
+ consolidates tool arguments into a single message.
+ """
+
+ TOOL_CALLS_BEGIN = "<|tool_calls_begin|>"
+ TOOL_CALLS_END = "<|tool_calls_end|>"
+ TOOL_CALL_BEGIN = "<|tool_call_begin|>"
+ TOOL_CALL_END = "<|tool_call_end|>"
+ TOOL_SEP = "<|tool_sep|>"
+ SPECIAL_TOKENS = [
+ TOOL_CALLS_BEGIN, TOOL_CALLS_END, TOOL_CALL_BEGIN, TOOL_CALL_END
+ ]
+
+ def __init__(self, tokenizer: AnyTokenizer):
+ super().__init__(tokenizer)
+ self.position = 0
+ # Explicit state flags for robust streaming
+ self.tool_block_started = False
+ self.tool_block_finished = False
+
+ def adjust_request(
+ self, request: ChatCompletionRequest) -> ChatCompletionRequest:
+ if request.tools and request.tool_choice != 'none':
+ request.skip_special_tokens = False
+ return request
+
+ @staticmethod
+ def _parse_steptml_invoke(
+ action_text: str
+ ) -> tuple[Optional[str], Optional[dict[str, str]]]:
+        func_name_match = re.search(r'<steptml:invoke name="([^"]+)">',
+                                    action_text)
+ if not func_name_match:
+ return None, None
+ func_name = func_name_match.group(1)
+
+ params: dict[str, str] = {}
+        param_matches = re.findall(
+            r'<steptml:parameter name="([^"]+)">([^<]*)</steptml:parameter>',
+            action_text)
+ for name, value in param_matches:
+ params[name] = value.strip()
+ return func_name, params
+
+ def _cast_arguments(
+ self,
+ func_name: str,
+ params: dict[str, Any],
+ request: ChatCompletionRequest,
+ ) -> dict[str, Any]:
+ for tool in request.tools or []:
+ if tool.function.name == func_name:
+ schema = tool.function.parameters or {}
+ properties = schema.get("properties", {})
+ for key, value in params.items():
+ if not isinstance(value, str):
+ continue
+ prop = properties.get(key, {})
+ typ = prop.get("type")
+ if typ == "string":
+ params[key] = value.strip()
+ elif typ == "integer":
+ with contextlib.suppress(ValueError):
+ params[key] = int(value)
+ elif typ == "number":
+ with contextlib.suppress(ValueError):
+ params[key] = float(value)
+ elif typ == "boolean":
+ lower_val = value.lower()
+ params[key] = lower_val == "true" if lower_val in (
+ "true", "false") else value
+ elif typ == "null":
+ params[key] = None if value.lower(
+ ) == "null" else value
+ break
+ return params
+
+ def extract_tool_calls_streaming(
+ self,
+ previous_text: str,
+ current_text: str,
+ delta_text: str,
+ previous_token_ids: Sequence[int],
+ current_token_ids: Sequence[int],
+ delta_token_ids: Sequence[int],
+ request: ChatCompletionRequest,
+ ) -> Union[DeltaMessage, None]:
+
+ # The main loop processes the stream from the last known position.
+ while True:
+ if self.position >= len(current_text):
+ return None # We've processed the entire stream.
+
+ unprocessed_text = current_text[self.position:]
+
+ # STATE: After all tools are done, all subsequent text is content.
+ if self.tool_block_finished:
+ self.position = len(current_text)
+ return DeltaMessage(content=unprocessed_text)
+
+ # STATE: Before the tool block has started.
+ if not self.tool_block_started:
+ if unprocessed_text.startswith(self.TOOL_CALLS_BEGIN):
+ self.position += len(self.TOOL_CALLS_BEGIN)
+ self.tool_block_started = True
+ continue # Token consumed, re-loop.
+
+ start_pos = unprocessed_text.find(self.TOOL_CALLS_BEGIN)
+ if start_pos == -1:
+ if self.TOOL_CALLS_BEGIN.startswith(
+ unprocessed_text.strip()) and unprocessed_text:
+ return None # It's a prefix, wait.
+ self.position = len(current_text)
+ return DeltaMessage(content=unprocessed_text)
+ else:
+ content = unprocessed_text[:start_pos]
+ self.position += len(content)
+ return DeltaMessage(content=content)
+
+ # STATE: Inside the main tool block.
+ offset = len(unprocessed_text) - len(unprocessed_text.lstrip())
+ unprocessed_text = unprocessed_text.lstrip()
+ self.position += offset
+
+ if unprocessed_text.startswith(self.TOOL_CALLS_END):
+ self.position += len(self.TOOL_CALLS_END)
+ self.tool_block_finished = True
+ self.current_tool_id = -1
+ continue
+
+ # Check if we are between tool calls.
+ tool_finished = (
+ self.current_tool_id != -1 and
+ self.prev_tool_call_arr[self.current_tool_id].get("finished"))
+ if self.current_tool_id == -1 or tool_finished:
+ if unprocessed_text.startswith(self.TOOL_CALL_BEGIN):
+ self.position += len(self.TOOL_CALL_BEGIN)
+ if self.current_tool_id == -1:
+ self.current_tool_id = 0
+ else:
+ self.current_tool_id += 1
+ self.current_tool_name_sent = False
+ while len(self.prev_tool_call_arr) <= self.current_tool_id:
+ self.prev_tool_call_arr.append({})
+ self.prev_tool_call_arr[
+ self.current_tool_id]["finished"] = False
+ continue
+
+ if self.TOOL_CALL_BEGIN.startswith(unprocessed_text):
+ return None
+
+ # STATE: Parsing an active tool call.
+ if self.current_tool_id != -1 and not self.prev_tool_call_arr[
+ self.current_tool_id].get("finished", False):
+ end_tool_pos = unprocessed_text.find(self.TOOL_CALL_END)
+ if end_tool_pos == -1:
+ tool_body = unprocessed_text
+ else:
+ tool_body = unprocessed_text[:end_tool_pos]
+
+ if end_tool_pos == -1 and self.TOOL_CALL_END.startswith(
+ tool_body):
+ return None
+
+ function_name, arguments = self._parse_steptml_invoke(
+ tool_body)
+ if not function_name:
+ return None
+
+ tool_call_arr = {
+ "name": function_name,
+ "parameters": arguments or {}
+ }
+
+ # Send the function name as soon as it's parsed.
+ if not self.current_tool_name_sent:
+ self.current_tool_name_sent = True
+ self.prev_tool_call_arr[self.current_tool_id].update(
+ tool_call_arr)
+ return DeltaMessage(tool_calls=[
+ DeltaToolCall(index=self.current_tool_id,
+ type="function",
+ id=f"chatcmpl-tool-{random_uuid()}",
+ function=DeltaFunctionCall(
+ name=function_name))
+ ])
+
+ # Update our internal state with the latest parsed arguments.
+ self.prev_tool_call_arr[
+ self.current_tool_id].update( # noqa: E501
+ tool_call_arr)
+
+ # Only send arguments when the tool call is complete.
+ if end_tool_pos != -1:
+ self.position += end_tool_pos + len(self.TOOL_CALL_END)
+ self.prev_tool_call_arr[
+ self.current_tool_id]["finished"] = True
+
+ final_args = self._cast_arguments(
+ function_name,
+ tool_call_arr.get("parameters", {}), # type: ignore
+ request)
+ if final_args:
+ final_args_json = json.dumps(final_args,
+ ensure_ascii=False)
+ return DeltaMessage(tool_calls=[
+ DeltaToolCall(index=self.current_tool_id,
+ function=DeltaFunctionCall(
+ arguments=final_args_json))
+ ])
+
+ # If tool is not finished, return None to wait for more tokens.
+ return None
+
+ return None
+
+ def extract_tool_calls(
+ self,
+ model_output: str,
+ request: ChatCompletionRequest,
+ ) -> ExtractedToolCallInformation:
+ if self.TOOL_CALLS_BEGIN not in model_output:
+ return ExtractedToolCallInformation(tools_called=False,
+ tool_calls=[],
+ content=model_output)
+
+ pre_text, rest = model_output.split(self.TOOL_CALLS_BEGIN, 1)
+ if self.TOOL_CALLS_END not in rest:
+ return ExtractedToolCallInformation(tools_called=False,
+ tool_calls=[],
+ content=model_output)
+
+ tool_block, post_text = rest.split(self.TOOL_CALLS_END, 1)
+ content = (pre_text + post_text).strip()
+
+ tool_calls: list[ToolCall] = []
+ call_parts = tool_block.split(self.TOOL_CALL_BEGIN)
+
+ for part in call_parts:
+ if not part or self.TOOL_CALL_END not in part:
+ continue
+
+ call_content = part.split(self.TOOL_CALL_END, 1)[0]
+ if self.TOOL_SEP not in call_content:
+ continue
+
+ type_part, invoke_part = call_content.split(self.TOOL_SEP, 1)
+ if type_part.strip() != "function":
+ continue
+
+ function_name, params_dict = self._parse_steptml_invoke(
+ invoke_part)
+
+ if function_name and params_dict is not None:
+ params_dict = self._cast_arguments(function_name, params_dict,
+ request)
+ params_str = json.dumps(params_dict, ensure_ascii=False)
+ tool_calls.append(
+ ToolCall(function=FunctionCall(name=function_name,
+ arguments=params_str)))
+ if tool_calls:
+ return ExtractedToolCallInformation(
+ tools_called=True,
+ tool_calls=tool_calls,
+ content=content if content else None)
+ return ExtractedToolCallInformation(tools_called=False,
+ tool_calls=[],
+ content=model_output)
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 51831a770347a..848c04b9b32f7 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -129,6 +129,7 @@ _TEXT_GENERATION_MODELS = {
"Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
"Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
"RWForCausalLM": ("falcon", "FalconForCausalLM"),
+ "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
"Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
@@ -238,6 +239,7 @@ _MULTIMODAL_MODELS = {
"Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501
"Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501
"UltravoxModel": ("ultravox", "UltravoxModel"),
+ "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"), # noqa: E501
"TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"), # noqa: E501
"Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"), # noqa: E501
"VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"), # noqa: E501
diff --git a/vllm/model_executor/models/step3_text.py b/vllm/model_executor/models/step3_text.py
new file mode 100644
index 0000000000000..47d2af5c2a140
--- /dev/null
+++ b/vllm/model_executor/models/step3_text.py
@@ -0,0 +1,521 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Inference-only Jurassic model."""
+from collections.abc import Iterable
+from typing import Any, Optional
+
+import torch
+from torch import nn
+
+from vllm.attention import Attention
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import CacheConfig, ModelConfig, VllmConfig
+from vllm.distributed import (get_pp_group,
+ get_tensor_model_parallel_world_size,
+ tensor_model_parallel_all_reduce)
+from vllm.logger import init_logger
+from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+ MergedColumnParallelLinear,
+ ReplicatedLinear,
+ RowParallelLinear)
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+ QuantizationConfig)
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import IntermediateTensors
+
+from .interfaces import SupportsPP
+from .utils import (PPMissingLayer, is_pp_missing_parameter,
+ make_empty_intermediate_tensors_factory, make_layers)
+
+logger = init_logger(__name__)
+
+
+class FusedMoEBlock(nn.Module):
+
+ def __init__(self,
+ config: ModelConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = ""):
+ super().__init__()
+ self.tp_size = get_tensor_model_parallel_world_size()
+
+ if self.tp_size > config.moe_num_experts:
+ raise ValueError(
+ f"Tensor parallel size {self.tp_size} is greater than "
+ f"the number of experts {config.moe_num_experts}.")
+
+ self.experts = FusedMoE(num_experts=config.moe_num_experts,
+ top_k=config.moe_top_k,
+ hidden_size=config.hidden_size,
+ intermediate_size=config.moe_intermediate_size,
+ reduce_results=False,
+ renormalize=config.norm_expert_weight,
+ quant_config=quant_config,
+ prefix=f"{prefix}.experts")
+ self.gate = ReplicatedLinear(config.hidden_size,
+ config.moe_num_experts,
+ bias=False,
+ quant_config=None,
+ prefix=f"{prefix}.gate")
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ orig_shape = hidden_states.shape
+ hidden_dim = hidden_states.shape[-1]
+ hidden_states = hidden_states.view(-1, hidden_dim)
+
+ router_logits, _ = self.gate(hidden_states)
+
+ final_hidden_states = self.experts(hidden_states=hidden_states,
+ router_logits=router_logits)
+ if self.tp_size > 1:
+ final_hidden_states = tensor_model_parallel_all_reduce(
+ final_hidden_states)
+
+ return final_hidden_states.view(orig_shape)
+
+
+class Step3TextMLP(nn.Module):
+
+ def __init__(
+ self,
+ hidden_size: int,
+ intermediate_size: int,
+ hidden_act: str,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ self.gate_up_proj = MergedColumnParallelLinear(
+ hidden_size, [intermediate_size] * 2,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.gate_up_proj")
+ self.down_proj = RowParallelLinear(intermediate_size,
+ hidden_size,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.down_proj")
+ if hidden_act != "silu":
+ raise ValueError(f"Unsupported activation: {hidden_act}. "
+ "Only silu is supported for now.")
+ self.act_fn = SiluAndMul()
+ self.hidden_size = hidden_size
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ gate_up, _ = self.gate_up_proj(hidden_states)
+ intermediate_act = self.act_fn(gate_up)
+ output, _ = self.down_proj(intermediate_act)
+ return output
+
+
+class Step3TextAttention(nn.Module):
+
+ def __init__(
+ self,
+ hidden_size: int,
+ num_heads: int,
+ num_kv_heads: int,
+ norm_eps: float,
+ rope_theta: int,
+ share_q_dim: Optional[int] = None,
+ rope_scaling: Optional[dict[str, Any]] = None,
+ max_position_embedding: int = 8192,
+ head_dim: int = 256,
+ cache_config: Optional[CacheConfig] = None,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ):
+ super().__init__()
+ self.hidden_size = hidden_size
+ tp_size = get_tensor_model_parallel_world_size()
+
+ self.total_num_heads = num_heads
+ assert self.total_num_heads % tp_size == 0
+ self.num_heads = self.total_num_heads // tp_size
+
+ if num_kv_heads != 1:
+ raise ValueError(f"Step3TextAttention num_kv_heads must be 1, "
+ f"but got {num_kv_heads}.")
+ self.num_kv_heads = num_kv_heads
+
+ self.head_dim = head_dim
+ self.kv_size = self.num_kv_heads * self.head_dim
+ self.q_size = share_q_dim if share_q_dim else self.head_dim
+
+ self.qkv_proj = ReplicatedLinear(
+ hidden_size,
+ self.q_size + self.kv_size * 2,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.qkv_proj",
+ )
+
+ self.o_proj = RowParallelLinear(
+ self.total_num_heads * self.head_dim,
+ hidden_size,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.o_proj",
+ )
+ self.inter_norm = RMSNorm(self.q_size, eps=norm_eps)
+ self.wq = ColumnParallelLinear(
+ self.q_size,
+ self.head_dim * self.total_num_heads,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.wq",
+ )
+ self.rotary_emb = get_rope(self.head_dim,
+ rotary_dim=self.head_dim,
+ max_position=max_position_embedding,
+ base=rope_theta,
+ rope_scaling=rope_scaling)
+ scaling = self.head_dim**-0.5
+ self.attn = Attention(self.num_heads,
+ self.head_dim,
+ scaling,
+ self.num_kv_heads,
+ cache_config=cache_config,
+ prefix=f"{prefix}.attn")
+
+ def forward(self, positions: torch.Tensor,
+ hidden_states: torch.Tensor) -> torch.Tensor:
+ qkv, _ = self.qkv_proj(hidden_states)
+ q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+ q = self.inter_norm(q)
+ q = self.wq(q)[0]
+ q, k = self.rotary_emb(positions, q, k)
+ attn_output = self.attn(q, k, v)
+ residual, _ = self.o_proj(attn_output)
+ return residual
+
+
+class Step3TextDecoderLayer(nn.Module):
+
+ def __init__(self,
+ config: ModelConfig,
+ cache_config: Optional[CacheConfig] = None,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "") -> None:
+ super().__init__()
+ config = config.hf_config
+ self.hidden_size = config.hidden_size
+ rope_scaling = getattr(config, "rope_scaling", None)
+
+ self.self_attn = Step3TextAttention(
+ hidden_size=self.hidden_size,
+ num_heads=config.num_attention_heads,
+ num_kv_heads=1,
+ cache_config=cache_config,
+ quant_config=quant_config,
+ norm_eps=config.rms_norm_eps,
+ max_position_embedding=config.max_position_embedding,
+ head_dim=config.head_dim,
+ share_q_dim=config.share_q_dim,
+ rope_theta=config.rope_theta,
+ rope_scaling=rope_scaling,
+ prefix=f"{prefix}.self_attn")
+
+ layer_idx = int(prefix.split("layers.")[1].split(".")[0])
+ moe_layers_enum = getattr(config, "moe_layers_enum", None)
+ if moe_layers_enum is not None:
+ moe_layers_idx = [
+ int(i) for i in moe_layers_enum.strip().split(',')
+ ]
+ else:
+ # Default to 1dense.
+ moe_layers_idx = [i for i in range(1, config.num_hidden_layers)]
+
+ if layer_idx in moe_layers_idx:
+ self.moe = FusedMoEBlock(config=config,
+ quant_config=quant_config,
+ prefix=f"{prefix}.moe")
+ self.share_expert = Step3TextMLP(
+ hidden_size=self.hidden_size,
+ intermediate_size=config.share_expert_dim,
+ hidden_act="silu",
+ quant_config=quant_config,
+ prefix=f"{prefix}.share_expert")
+ self.use_moe = True
+ else:
+ self.mlp = Step3TextMLP(hidden_size=config.hidden_size,
+ intermediate_size=config.intermediate_size,
+ hidden_act="silu",
+ quant_config=quant_config,
+ prefix=f"{prefix}.mlp")
+ self.use_moe = False
+ self.input_layernorm = RMSNorm(config.hidden_size,
+ eps=config.rms_norm_eps)
+ self.post_attention_layernorm = RMSNorm(config.hidden_size,
+ eps=config.rms_norm_eps)
+
+ def forward(
+ self, positions: torch.Tensor, hidden_states: torch.Tensor,
+ residual: Optional[torch.Tensor]
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ if residual is None:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ else:
+ hidden_states, residual = self.input_layernorm(
+ hidden_states, residual)
+
+ hidden_states = self.self_attn(
+ positions=positions,
+ hidden_states=hidden_states,
+ )
+
+ hidden_states, residual = self.post_attention_layernorm(
+ hidden_states, residual)
+
+ if self.use_moe:
+ share_output = self.share_expert(hidden_states)
+ moe_output = self.moe(hidden_states)
+ hidden_states = share_output + moe_output
+ else:
+ hidden_states = self.mlp(hidden_states)
+
+ return hidden_states, residual
+
+
+@support_torch_compile
+class Step3TextModel(nn.Module):
+
+ def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
+ super().__init__()
+ config = vllm_config.model_config.hf_config
+ cache_config = vllm_config.cache_config
+ quant_config = vllm_config.quant_config
+ self.vocab_size = config.vocab_size
+ self.config = config
+
+ if get_pp_group().is_first_rank or (config.tie_word_embeddings
+ and get_pp_group().is_last_rank):
+ self.embed_tokens = VocabParallelEmbedding(
+ self.vocab_size,
+ config.hidden_size,
+ )
+ else:
+ self.embed_tokens = PPMissingLayer()
+
+ self.start_layer, self.end_layer, self.layers = make_layers(
+ config.num_hidden_layers,
+ lambda prefix: Step3TextDecoderLayer(config=vllm_config.
+ model_config,
+ cache_config=cache_config,
+ quant_config=quant_config,
+ prefix=prefix),
+ prefix=f"{prefix}.layers",
+ )
+ if get_pp_group().is_last_rank:
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ else:
+ self.norm = PPMissingLayer()
+
+ self.make_empty_intermediate_tensors = (
+ make_empty_intermediate_tensors_factory(["hidden_states"],
+ config.hidden_size))
+
+ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+ return self.embed_tokens(input_ids)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if get_pp_group().is_first_rank:
+ if inputs_embeds is not None:
+ hidden_states = inputs_embeds
+ else:
+ hidden_states = self.get_input_embeddings(input_ids)
+ residual = None
+ else:
+ assert intermediate_tensors is not None
+ hidden_states = intermediate_tensors["hidden_states"]
+ residual = intermediate_tensors["residual"]
+
+ for i in range(self.start_layer, self.end_layer):
+ layer = self.layers[i]
+ hidden_states, residual = layer(positions, hidden_states, residual)
+
+ if not get_pp_group().is_last_rank:
+ return IntermediateTensors({
+ "hidden_states": hidden_states,
+ "residual": residual,
+ })
+
+ hidden_states, _ = self.norm(hidden_states, residual)
+ return hidden_states
+
+
+class Step3TextForCausalLM(nn.Module, SupportsPP):
+
+ def __init__(
+ self,
+ *,
+ vllm_config: VllmConfig,
+ prefix: str = "",
+ ):
+ super().__init__()
+ config = vllm_config.model_config.hf_config
+ lora_config = vllm_config.lora_config
+ self.config = config
+ self.vllm_config = vllm_config
+
+ self.model = Step3TextModel(vllm_config=vllm_config, prefix=prefix)
+
+ if get_pp_group().is_last_rank:
+ self.unpadded_vocab_size = config.vocab_size
+ if lora_config:
+ self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
+ self.lm_head = ParallelLMHead(
+ self.unpadded_vocab_size,
+ config.hidden_size,
+ org_num_embeddings=config.vocab_size,
+ padding_size=DEFAULT_VOCAB_PADDING_SIZE
+ if not lora_config else lora_config.lora_vocab_padding_size,
+ )
+ self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+ config.vocab_size)
+ self.sampler = get_sampler()
+ else:
+ self.lm_head = PPMissingLayer()
+
+ self.make_empty_intermediate_tensors = (
+ self.model.make_empty_intermediate_tensors)
+
+ def forward(self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ inputs_embeds: Optional[torch.Tensor] = None):
+ hidden_states = self.model(input_ids, positions, intermediate_tensors,
+ inputs_embeds)
+ return hidden_states
+
+ def compute_logits(self, hidden_states: torch.Tensor,
+ sampling_metadata: SamplingMetadata) -> torch.Tensor:
+ logits = self.logits_processor(self.lm_head, hidden_states,
+ sampling_metadata)
+ return logits
+
+ def sample(
+ self,
+ logits: Optional[torch.Tensor],
+ sampling_metadata: SamplingMetadata,
+ ) -> Optional[SamplerOutput]:
+ next_tokens = self.sampler(logits, sampling_metadata)
+ return next_tokens
+
+ def load_weights(self, weights: Iterable[tuple[str,
+ torch.Tensor]]) -> set[str]:
+ qkv_params_mapping = [
+ # (param_name, shard_name, relative_start_idx, relative_end_idx)
+ (".qkv_proj", ".q_proj", 0, self.config.share_q_dim /
+ (self.config.share_q_dim + self.config.head_dim * 2)),
+ (".qkv_proj", ".k_proj", self.config.share_q_dim /
+ (self.config.share_q_dim + self.config.head_dim * 2),
+ (self.config.share_q_dim + self.config.head_dim) /
+ (self.config.share_q_dim + self.config.head_dim * 2)),
+ (".qkv_proj", ".v_proj",
+ (self.config.share_q_dim + self.config.head_dim) /
+ (self.config.share_q_dim + self.config.head_dim * 2),
+ (self.config.share_q_dim + self.config.head_dim * 2) /
+ (self.config.share_q_dim + self.config.head_dim * 2)),
+ ]
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ (".gate_up_proj", ".gate_proj", 0),
+ (".gate_up_proj", ".up_proj", 1),
+ ]
+ params_dict = dict(self.named_parameters())
+ loaded_params: set[str] = set()
+
+ expert_params_mapping = [
+ (".moe.experts.w13_weight", ".moe.gate_proj.weight", "w1"),
+ (".moe.experts.w13_weight", ".moe.up_proj.weight", "w3"),
+ (".moe.experts.w2_weight", ".moe.down_proj.weight", "w2")
+ ]
+
+ disable_moe_stacked_params = [
+ data[1] for data in expert_params_mapping
+ ]
+
+ for name, loaded_weight in weights:
+ for (param_name, weight_name, shard_id) in stacked_params_mapping:
+ if weight_name not in name:
+ continue
+ if any(disable_moe_stacked_param in name
+ for disable_moe_stacked_param in
+ disable_moe_stacked_params):
+ continue
+ name = name.replace(weight_name, param_name)
+ if is_pp_missing_parameter(name, self):
+ continue
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(param, loaded_weight, shard_id)
+ loaded_params.add(name)
+ break
+ else:
+ for mapping in expert_params_mapping:
+ param_name, weight_name, shard_id = mapping
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+ # Skip layers on other devices.
+ if is_pp_missing_parameter(name, self):
+ continue
+ # Skip loading extra bias for GPTQ models.
+ if ((name.endswith(".bias") or name.endswith("_bias"))
+ and name not in params_dict):
+ continue
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ for expert_id in range(loaded_weight.shape[0]):
+ loaded_weight_expert = loaded_weight[expert_id]
+ weight_loader(param,
+ loaded_weight_expert,
+ name,
+ shard_id=shard_id,
+ expert_id=expert_id)
+ loaded_params.add(name)
+ break
+ else:
+ for (param_name, weight_name, start_idx,
+ end_idx) in qkv_params_mapping:
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+ if is_pp_missing_parameter(name, self):
+ continue
+ param = params_dict[name]
+ dim = param.shape[param.output_dim]
+ begin_idx = int(start_idx * dim)
+ end_idx = int(end_idx * dim)
+ param_slice = param.narrow(param.output_dim, begin_idx,
+ end_idx - begin_idx)
+ param_slice.copy_(loaded_weight)
+ loaded_params.add(name)
+ break
+ else:
+ if is_pp_missing_parameter(name, self):
+ continue
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader",
+ default_weight_loader)
+ weight_loader(param, loaded_weight)
+ loaded_params.add(name)
+ return loaded_params
diff --git a/vllm/model_executor/models/step3_vl.py b/vllm/model_executor/models/step3_vl.py
new file mode 100644
index 0000000000000..363c12a4bf2b8
--- /dev/null
+++ b/vllm/model_executor/models/step3_vl.py
@@ -0,0 +1,1052 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import math
+from collections.abc import Iterable, Mapping, Sequence
+from functools import cached_property
+from itertools import product
+from math import ceil, sqrt
+from typing import Any, Literal, Optional, TypedDict, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from PIL import Image
+from torchvision import transforms
+from torchvision.transforms.functional import InterpolationMode
+from transformers import BatchFeature, PretrainedConfig, TensorType
+
+from vllm.config import VllmConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+ QKVParallelLinear,
+ RowParallelLinear)
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
+ MultiModalKwargs, NestedTensors)
+from vllm.multimodal.parse import ImageSize, MultiModalDataItems
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+ BaseProcessingInfo, PromptReplacement,
+ PromptUpdate, PromptUpdateDetails)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.sequence import IntermediateTensors
+from vllm.transformers_utils.configs import Step3VisionEncoderConfig
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+
+from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
+from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+ init_vllm_registered_model, maybe_prefix,
+ merge_multimodal_embeddings)
+
+
+class Step3VLImagePixelInputs(TypedDict):
+ type: Literal["pixel_values"]
+ pixel_values: torch.Tensor
+ patch_pixel_values: Optional[torch.Tensor]
+ num_patches: list[int]
+
+
+class Step3VLImageEmbeddingInputs(TypedDict):
+ type: Literal["image_embeds"]
+ image_embeds: torch.Tensor
+
+
+Step3VLImageInputs = Union[Step3VLImagePixelInputs,
+ Step3VLImageEmbeddingInputs]
+
+ImageWithPatches = tuple[Image.Image, list[Image.Image], list[bool] | None]
+
+MAX_IMAGE_SIZE: int = 3024
+
+
+class Step3VisionProcessor:
+
+ def __init__(self, size, interpolation_mode="bicubic", patch_size=None):
+ mean = [0.48145466, 0.4578275, 0.40821073]
+ std = [0.26862954, 0.26130258, 0.27577711]
+ patch_size = patch_size if patch_size is not None else size
+
+ self.transform = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize(mean, std),
+ transforms.Resize(
+ (size, size),
+ interpolation=InterpolationMode.BICUBIC if interpolation_mode
+ == "bicubic" else InterpolationMode.BILINEAR,
+ antialias=True),
+ ])
+
+ self.patch_transform = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize(mean, std),
+ transforms.Resize(
+ (patch_size, patch_size),
+ interpolation=InterpolationMode.BICUBIC if interpolation_mode
+ == "bicubic" else InterpolationMode.BILINEAR,
+ antialias=True),
+ ]) if patch_size is not None else None
+
+ def __call__(self, image, is_patch=False):
+ if is_patch:
+ return {"pixel_values": self.patch_transform(image).unsqueeze(0)}
+ else:
+ return {"pixel_values": self.transform(image).unsqueeze(0)}
+
+
+class ImagePatcher:
+
+ def determine_window_size(self, long: int, short: int) -> int:
+ if long <= 728:
+ return short if long / short > 1.5 else 0
+ return min(short, 504) if long / short > 4 else 504
+
+ def slide_window(
+ self,
+ width: int,
+ height: int,
+ sizes: list[tuple[int, int]],
+ steps: list[tuple[int, int]],
+ img_rate_thr: float = 0.6,
+ ) -> tuple[list[tuple[int, int, int, int]], tuple[int, int]]:
+ assert 1 >= img_rate_thr >= 0, "The `in_rate_thr` should lie in 0~1"
+ windows = []
+ # Sliding windows.
+ for size, step in zip(sizes, steps):
+ size_w, size_h = size
+ step_w, step_h = step
+
+ x_num = 1 if width <= size_w else ceil((width - size_w) / step_w +
+ 1)
+ x_start = [step_w * i for i in range(x_num)]
+ if len(x_start) > 1 and x_start[-1] + size_w > width:
+ x_start[-1] = width - size_w
+
+ y_num = 1 if height <= size_h else ceil((height - size_h) /
+ step_h + 1)
+ y_start = [step_h * i for i in range(y_num)]
+ if len(y_start) > 1 and y_start[-1] + size_h > height:
+ y_start[-1] = height - size_h
+
+ start = np.array(list(product(y_start, x_start)), dtype=int)
+ start[:, [0, 1]] = start[:, [1, 0]]
+ windows.append(np.concatenate([start, start + size], axis=1))
+ windows = np.concatenate(windows, axis=0)
+
+ return [(int(box[0]), int(box[1]), int(box[2] - box[0]),
+ int(box[3] - box[1])) for box in windows], (x_num, y_num)
+
+ def square_pad(self, img: Image.Image) -> Image.Image:
+ w, h = img.size
+ if w == h:
+ return img
+ size = max(w, h)
+ padded = Image.new(img.mode, (size, size), 0)
+ padded.paste(img, (0, 0))
+ return padded
+
+ def get_image_size_for_padding(self, img_width: int,
+ img_height: int) -> tuple[int, int]:
+ ratio = img_width / img_height
+ if min(img_height, img_width) < 32 and (ratio > 4 or ratio < 1 / 4):
+ new_size = max(img_height, img_width)
+ return new_size, new_size
+ return img_width, img_height
+
+ def get_image_size_for_preprocess(self, img_width: int,
+ img_height: int) -> tuple[int, int]:
+
+ if max(img_height, img_width) > MAX_IMAGE_SIZE:
+ scale_factor = MAX_IMAGE_SIZE / max(img_height, img_width)
+ img_width = int(img_width * scale_factor)
+ img_height = int(img_height * scale_factor)
+ return img_width, img_height
+
+ def get_image_size_for_crop(self, img_width: int, img_height: int,
+ window_size: int):
+ w_ratio = img_width / window_size
+ h_ratio = img_height / window_size
+
+ if w_ratio < 1:
+ width_new = img_width
+ else:
+ decimal_w = w_ratio - img_width // window_size
+ w_ratio = int(w_ratio) + 1 if decimal_w > 0.2 else int(w_ratio)
+ width_new = window_size * w_ratio
+ if h_ratio < 1:
+ height_new = img_height
+ else:
+ decimal_h = h_ratio - img_height // window_size
+ h_ratio = int(h_ratio) + 1 if decimal_h > 0.2 else int(h_ratio)
+ height_new = window_size * h_ratio
+ return int(width_new), int(height_new)
+
+ def patch_crop(self, img: Image.Image, i: int, j: int, th: int, tw: int):
+ target = img.crop((j, i, j + tw, i + th))
+ return target
+
+ def get_num_patches(self, img_width: int,
+ img_height: int) -> tuple[int, int]:
+ img_width, img_height = self.get_image_size_for_padding(
+ img_width, img_height)
+ img_width, img_height = self.get_image_size_for_preprocess(
+ img_width, img_height)
+ window_size = self.determine_window_size(max(img_height, img_width),
+ min(img_height, img_width))
+ if window_size == 0:
+ return 0, 0
+ else:
+ img_width, img_height = self.get_image_size_for_crop(
+ img_width, img_height, window_size)
+ center_list, (x_num, y_num) = self.slide_window(
+ img_width, img_height, [(window_size, window_size)],
+ [(window_size, window_size)])
+ full_rows = (len(center_list) - 1) // x_num + 1
+ if len(center_list) > 0 and len(center_list) % x_num == 0:
+ full_rows -= 1
+ return len(center_list), full_rows
+
+ def __call__(
+ self, img: Image.Image
+ ) -> tuple[Image.Image, list[Image.Image], list[bool] | None]:
+ img_width, img_height = img.size
+ new_img_width, new_img_height = self.get_image_size_for_padding(
+ img_width, img_height)
+ if new_img_width != img_width or new_img_height != img_height:
+ img = self.square_pad(img)
+ img_width, img_height = img.size
+
+ new_img_width, new_img_height = self.get_image_size_for_preprocess(
+ img_width, img_height)
+ img = img.resize((new_img_width, new_img_height),
+ Image.Resampling.BILINEAR)
+ window_size = self.determine_window_size(
+ max(new_img_height, new_img_width),
+ min(new_img_height, new_img_width))
+
+ if window_size == 0:
+ return img, [], None
+ else:
+ new_img_width, new_img_height = self.get_image_size_for_crop(
+ new_img_width, new_img_height, window_size)
+ if (new_img_width, new_img_height) != (img_width, img_height):
+ img_for_crop = img.resize((new_img_width, new_img_height),
+ Image.Resampling.BILINEAR)
+ else:
+ img_for_crop = img
+
+ patches = []
+ newlines = []
+ center_list, (x_num, y_num) = self.slide_window(
+ new_img_width, new_img_height, [(window_size, window_size)],
+ [(window_size, window_size)])
+ for patch_id, center_lf_point in enumerate(center_list):
+ x, y, patch_w, patch_h = center_lf_point
+ big_patch = self.patch_crop(img_for_crop, y, x, patch_h,
+ patch_w)
+ patches.append(big_patch)
+ if (patch_id + 1) % x_num == 0:
+ newlines.append(patch_id)
+
+ if newlines and newlines[-1] == len(patches) - 1:
+ newlines.pop()
+
+ return img, patches, [i in newlines for i in range(len(patches))
+ ] if len(patches) > 0 else None
+
+
+class Step3VLProcessor:
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ tokenizer: AnyTokenizer,
+ ) -> None:
+ super().__init__()
+
+ self.config = config
+ self.tokenizer = tokenizer
+
+ self.image_size = 728
+ self.patch_size = 504
+ self.image_preprocessor = Step3VisionProcessor(self.image_size,
+ "bilinear",
+ self.patch_size)
+
+ self.num_image_feature_size = 169
+ self.num_patch_feature_size = 81
+        self.image_token = "<im_patch>"
+ self.image_feature_placeholder = (self.image_token *
+ self.num_image_feature_size)
+ self.patch_feature_placeholder = (self.image_token *
+ self.num_patch_feature_size)
+
+ self.patcher = ImagePatcher()
+
+ @property
+ def image_token_id(self) -> int:
+ return self.tokenizer.get_vocab()[self.image_token]
+
+ def get_num_image_tokens(self, img_width: int, img_height: int) -> int:
+ num_patches, num_newlines = self.patcher.get_num_patches(
+ img_width, img_height)
+
+ return num_patches * (
+ self.num_patch_feature_size +
+ 2) + self.num_image_feature_size + 2 + num_newlines
+
+ def _split_images(self,
+ images: list[Image.Image]) -> list[ImageWithPatches]:
+ result = []
+ for img in images:
+ result.append(self.patcher(img))
+ return result
+
+ def _convert_images_to_pixel_values(
+ self,
+ images: list[Image.Image],
+ is_patch: bool = False,
+ ) -> list[torch.Tensor]:
+ return [
+ self.image_preprocessor(img, is_patch=is_patch)["pixel_values"]
+ for img in images
+ ]
+
+ def _get_patch_repl(
+ self,
+ num_patches: int,
+ patch_newline_mask: list[bool] | None,
+ ) -> tuple[str, list[int]]:
+ text = ""
+ token_ids = []
+ for i in range(num_patches):
+ assert len(patch_newline_mask) == num_patches
+            text += f"<patch_start>{self.patch_feature_placeholder}<patch_end>"
+            token_ids.extend(
+                [self.tokenizer.convert_tokens_to_ids("<patch_start>")] +
+                [self.image_token_id] * self.num_patch_feature_size +
+                [self.tokenizer.convert_tokens_to_ids("<patch_end>")])
+            if patch_newline_mask and patch_newline_mask[i]:
+                text += "<patch_newline>"
+                token_ids.append(
+                    self.tokenizer.convert_tokens_to_ids("<patch_newline>"))
+ return text, token_ids
+
+ def _get_image_repl(
+ self,
+ num_images: int,
+ ) -> tuple[str, list[int]]:
+        text = f"<im_start>{self.image_feature_placeholder}<im_end>"
+        token_ids = [
+            self.tokenizer.convert_tokens_to_ids("<im_start>")
+        ] + [self.image_token_id] * self.num_image_feature_size + [
+            self.tokenizer.convert_tokens_to_ids("<im_end>")
+        ]
+ return text * num_images, token_ids * num_images
+
+ def _get_image_repl_features(
+ self,
+ num_images: int,
+ num_patches: int,
+ patch_new_line_idx: Optional[list[bool]],
+ ) -> tuple[str, list[int]]:
+ if num_patches > 0:
+ patch_repl, patch_repl_ids = self._get_patch_repl(
+ num_patches, patch_new_line_idx)
+ else:
+ patch_repl = ""
+ patch_repl_ids = []
+ image_repl, image_repl_ids = self._get_image_repl(num_images)
+ return patch_repl + image_repl, patch_repl_ids + image_repl_ids
+
+ def replace_placeholder(self, text: str, placeholder: str,
+ repls: list[str]) -> str:
+ parts = text.split(placeholder)
+
+ if len(parts) - 1 != len(repls):
+ raise ValueError(
+ "The number of placeholders does not match the number of replacements." # noqa: E501
+ )
+
+ result = [parts[0]]
+ for i, repl in enumerate(repls):
+ result.append(repl)
+ result.append(parts[i + 1])
+
+ return "".join(result)
+
+ def __call__(
+ self,
+ text: Optional[Union[str, list[str]]] = None,
+ images: Optional[Union[Image.Image, list[Image.Image]]] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ ) -> BatchFeature:
+ if text is None:
+ text = []
+ if not isinstance(text, list):
+ text = [text]
+ if images is None:
+ images = []
+ if not isinstance(images, list):
+ images = [images]
+
+ if len(images) == 0:
+ image_inputs = {}
+ text_inputs = self.tokenizer(text)
+ else:
+ splitted_images_data = self._split_images(images)
+ pixel_values_lst = []
+ patch_pixel_values_lst = []
+ patch_newline_mask_lst = []
+ image_repl_str_lst = []
+ image_repl_ids_lst = []
+ num_patches = []
+ for raw_img, img_patches, patch_newline_mask in splitted_images_data: # noqa: E501
+ pixel_values_lst.extend(
+ self._convert_images_to_pixel_values([raw_img]))
+
+ if len(img_patches) > 0:
+ patch_pixel_values_lst.extend(
+ self._convert_images_to_pixel_values(img_patches,
+ is_patch=True))
+ num_patches.append(len(img_patches))
+
+ image_repl_str, image_repl_ids = self._get_image_repl_features(
+ 1, len(img_patches), patch_newline_mask)
+ image_repl_str_lst.append(image_repl_str)
+ image_repl_ids_lst.extend(image_repl_ids)
+
+ if patch_newline_mask is not None:
+ patch_newline_mask_lst.extend(patch_newline_mask)
+
+ image_inputs = {
+ "pixel_values": torch.cat(pixel_values_lst),
+ "num_patches": num_patches,
+ }
+ if patch_pixel_values_lst:
+ image_inputs["patch_pixel_values"] = torch.cat(
+ patch_pixel_values_lst)
+ if patch_newline_mask_lst:
+ image_inputs["patch_newline_mask"] = torch.tensor(
+ patch_newline_mask_lst, dtype=torch.bool)
+
+ text = [
+ self.replace_placeholder(t, self.image_token,
+ image_repl_str_lst) for t in text
+ ]
+ text_inputs = self.tokenizer(text)
+
+ return BatchFeature(
+ {
+ **text_inputs,
+ **image_inputs,
+ },
+ tensor_type=return_tensors,
+ )
+
+
+class Step3VLProcessingInfo(BaseProcessingInfo):
+
+ def get_hf_processor(self) -> Step3VLProcessor:
+ return Step3VLProcessor(
+ self.get_hf_config(),
+ self.get_tokenizer(),
+ )
+
+ def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+ return {"image": None}
+
+ def get_max_image_tokens(self) -> int:
+ hf_processor = self.get_hf_processor()
+ return hf_processor.get_num_image_tokens(
+ self.get_image_size_with_most_features().width,
+ self.get_image_size_with_most_features().height)
+
+ def get_mm_max_tokens_per_item(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> Mapping[str, int]:
+ return {"image": self.get_max_image_tokens()}
+
+ def get_image_size_with_most_features(self) -> ImageSize:
+ return ImageSize(3024, 3024)
+
+ def get_num_mm_tokens(self, mm_data: MultiModalDataDict) -> int:
+ if len(mm_data) != 1 or "image" not in mm_data:
+ raise ValueError(
+                "mm_data can only contain one key 'image' for step3")
+
+ image_data = mm_data["image"]
+ if not isinstance(image_data, (list, tuple)):
+ image_data = [image_data]
+
+ return sum(self.get_hf_processor().get_num_image_tokens(
+ img.width, img.height) for img in image_data)
+
+
+class Step3VLDummyInputsBuilder(BaseDummyInputsBuilder[Step3VLProcessingInfo]):
+
+ def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
+ num_images = mm_counts.get("image", 0)
+        return "<im_patch>" * num_images
+
+ def get_dummy_mm_data(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> MultiModalDataDict:
+ target_width, target_height = \
+ self.info.get_image_size_with_most_features()
+ num_images = mm_counts.get("image", 0)
+
+ return {
+ "image":
+ self._get_dummy_images(width=target_width,
+ height=target_height,
+ num_images=num_images)
+ }
+
+
+class Step3VLMultiModalProcessor(BaseMultiModalProcessor[Step3VLProcessingInfo]
+ ):
+
+ def _get_prompt_updates(
+ self,
+ mm_items: MultiModalDataItems,
+ hf_processor_mm_kwargs: Mapping[str, Any],
+ out_mm_kwargs: MultiModalKwargs,
+ ) -> Sequence[PromptUpdate]:
+ hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
+ image_placeholder_token_id = hf_processor.image_token_id
+ batch_num_patches = out_mm_kwargs["num_patches"].tolist()
+
+ def get_replacement_step1o(item_idx: int):
+ img_out = out_mm_kwargs.get_item("image", item_idx)
+ num_patches = batch_num_patches[item_idx]
+ if num_patches > 0:
+ patch_newline_mask = img_out["patch_newline_mask"].data.tolist(
+ )
+ image_repl_ids = hf_processor._get_image_repl_features(
+ 1, num_patches, patch_newline_mask)[1]
+ else:
+ image_repl_ids = hf_processor._get_image_repl_features(
+ 1, 0, None)[1]
+ return PromptUpdateDetails.select_token_id(
+ seq=image_repl_ids,
+ embed_token_id=image_placeholder_token_id,
+ )
+
+ return [
+ PromptReplacement(
+ modality="image",
+ target=[image_placeholder_token_id],
+ replacement=get_replacement_step1o,
+ )
+ ]
+
+ def _get_mm_fields_config(
+ self,
+ hf_inputs: BatchFeature,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ ) -> Mapping[str, MultiModalFieldConfig]:
+ num_patches = hf_inputs.get("num_patches", torch.empty(0))
+
+ return dict(
+ pixel_values=MultiModalFieldConfig.batched("image"),
+ patch_pixel_values=MultiModalFieldConfig.flat_from_sizes(
+ "image", num_patches),
+ num_patches=MultiModalFieldConfig.batched("image"),
+ patch_newline_mask=MultiModalFieldConfig.flat_from_sizes(
+ "image", num_patches),
+ )
+
+
+def get_abs_pos(abs_pos, tgt_size):
+ dim = abs_pos.size(-1)
+ abs_pos_new = abs_pos.squeeze(0)
+ cls_token, old_pos_embed = abs_pos_new[:1], abs_pos_new[1:]
+
+ src_size = int(math.sqrt(abs_pos_new.shape[0] - 1))
+ tgt_size = int(math.sqrt(tgt_size))
+ dtype = abs_pos.dtype
+
+ if src_size != tgt_size:
+ old_pos_embed = old_pos_embed.view(1, src_size, src_size,
+ dim).permute(0, 3, 1,
+ 2).contiguous()
+ old_pos_embed = old_pos_embed.to(torch.float32)
+ new_pos_embed = F.interpolate(
+ old_pos_embed,
+ size=(tgt_size, tgt_size),
+ mode='bicubic',
+ antialias=True,
+ align_corners=False,
+ ).to(dtype)
+ new_pos_embed = new_pos_embed.permute(0, 2, 3, 1)
+ new_pos_embed = new_pos_embed.view(tgt_size * tgt_size, dim)
+ vision_pos_embed = torch.cat([cls_token, new_pos_embed], dim=0)
+ vision_pos_embed = vision_pos_embed.view(1, tgt_size * tgt_size + 1,
+ dim)
+ return vision_pos_embed
+ else:
+ return abs_pos
+
+
+class Step3VisionEmbeddings(nn.Module):
+
+ def __init__(self, config: Step3VisionEncoderConfig):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+
+ self.class_embedding = nn.Parameter(torch.randn(1, self.embed_dim))
+
+ self.patch_embedding = nn.Conv2d(
+ in_channels=config.num_channels,
+ out_channels=self.embed_dim,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ bias=True,
+ )
+
+ self.num_patches = (self.image_size // self.patch_size)**2
+ self.pad_tp_size = 4 # hard code for padding
+ # To load the pretrained weights, we still use P+1 as the seqlen
+ self.position_embedding = torch.nn.Embedding(self.num_patches + 1,
+ self.embed_dim)
+ self.register_buffer("position_ids",
+ torch.arange(self.num_patches + 1).expand(
+ (1, -1)),
+ persistent=False)
+
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+ batch_size = pixel_values.shape[0]
+ patch_embeds = self.patch_embedding(
+ pixel_values) # shape = [*, width, grid, grid]
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
+
+ # pad
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1)
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
+ embeddings = embeddings + get_abs_pos(
+ self.position_embedding(self.position_ids), patch_embeds.size(1))
+ embeddings = torch.cat([
+ embeddings[:, 0, :].unsqueeze(1).repeat(1, self.pad_tp_size - 1,
+ 1), embeddings
+ ],
+ dim=1)
+ return embeddings
+
+
+class Step3VisionAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self,
+ config,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = ""):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.total_num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.total_num_heads
+
+ self.scale = self.head_dim**-0.5
+
+ tp_size = get_tensor_model_parallel_world_size()
+ assert self.total_num_heads % tp_size == 0
+ self.num_heads = self.total_num_heads // tp_size
+ self.qkv_proj = QKVParallelLinear(self.embed_dim,
+ self.head_dim,
+ self.total_num_heads,
+ bias=True,
+ quant_config=quant_config,
+ prefix=prefix)
+ self.out_proj = RowParallelLinear(self.embed_dim,
+ self.embed_dim,
+ bias=True,
+ quant_config=quant_config,
+ prefix=prefix)
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads,
+ self.head_dim).transpose(1, 2).contiguous()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ ):
+ """Input shape: Batch x Time x Channel"""
+ bsz, tgt_len, _ = hidden_states.size()
+
+ # get query proj
+ qkv, _ = self.qkv_proj(hidden_states)
+ q, k, v = qkv.chunk(chunks=3, dim=-1)
+ q = q.view(bsz, tgt_len, self.num_heads, self.head_dim)
+ k = k.view(bsz, tgt_len, self.num_heads, self.head_dim)
+ v = v.view(bsz, tgt_len, self.num_heads, self.head_dim)
+ q = q.transpose(1, 2)
+ k = k.transpose(1, 2)
+ v = v.transpose(1, 2)
+ attn_output = F.scaled_dot_product_attention(q,
+ k,
+ v,
+ scale=self.scale,
+ is_causal=False)
+ attn_output = attn_output.transpose(1, 2).reshape(
+ bsz, tgt_len, self.num_heads * self.head_dim)
+
+ attn_output, _ = self.out_proj(attn_output)
+
+ return attn_output
+
+
+class Step3VisionMLP(nn.Module):
+
+ def __init__(self,
+ config,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = ""):
+ super().__init__()
+ self.config = config
+ self.activation_fn = get_act_fn(config.hidden_act)
+ self.fc1 = ColumnParallelLinear(config.hidden_size,
+ config.intermediate_size,
+ bias=True,
+ quant_config=quant_config,
+ prefix=prefix)
+ self.fc2 = RowParallelLinear(config.intermediate_size,
+ config.hidden_size,
+ bias=True,
+ quant_config=quant_config,
+ prefix=prefix)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states, _ = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states, _ = self.fc2(hidden_states)
+ return hidden_states
+
+
+class Step3VisionEncoderLayer(nn.Module):
+
+ def __init__(self,
+ config: Step3VisionEncoderConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = ""):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.self_attn = Step3VisionAttention(config,
+ quant_config,
+ prefix=f"{prefix}.self_attn")
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim,
+ eps=config.layer_norm_eps)
+ self.mlp = Step3VisionMLP(config, quant_config, prefix=f"{prefix}.mlp")
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim,
+ eps=config.layer_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ ) -> torch.FloatTensor:
+ hidden_states = hidden_states + self.layer_norm1(
+ self.self_attn(hidden_states))
+ hidden_states = hidden_states + self.layer_norm2(
+ self.mlp(hidden_states))
+ return hidden_states
+
+
+class Step3VisionEncoder(nn.Module):
+
+ def __init__(self,
+ config: Step3VisionEncoderConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = ""):
+ super().__init__()
+ self.config = config
+ self.layers = nn.ModuleList([
+ Step3VisionEncoderLayer(config,
+ quant_config,
+ prefix=f"{prefix}.layers.{i}")
+ for i in range(config.num_hidden_layers)
+ ])
+
+ def forward(
+ self,
+ inputs_embeds,
+ ):
+ hidden_states = inputs_embeds
+ for encoder_layer in self.layers:
+ hidden_states = encoder_layer(hidden_states)
+ return hidden_states
+
+
+class Step3VisionTransformer(nn.Module):
+
+ def __init__(self,
+ config: Step3VisionEncoderConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = ""):
+ super().__init__()
+ self.config = config
+ self.image_size = config.image_size
+ self.embeddings = Step3VisionEmbeddings(config)
+ self.transformer = Step3VisionEncoder(config,
+ quant_config,
+ prefix=f"{prefix}.transformer")
+
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ ):
+ hidden_states = self.embeddings(pixel_values)
+ hidden_states = self.transformer(inputs_embeds=hidden_states)
+ return hidden_states
+
+
+@MULTIMODAL_REGISTRY.register_processor(Step3VLMultiModalProcessor,
+ info=Step3VLProcessingInfo,
+ dummy_inputs=Step3VLDummyInputsBuilder)
+class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
+ SupportsPP):
+
+ hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
+ "model.": "language_model.model.",
+ "lm_head.": "language_model.lm_head.",
+ })
+
+ @classmethod
+ def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
+ if modality.startswith("image"):
+            return "<im_patch>"
+
+ raise ValueError("Only image modality is supported")
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
+ super().__init__()
+
+ config = vllm_config.model_config.hf_config
+ multimodal_config = vllm_config.model_config.multimodal_config
+
+ self.config = config
+ self.multimodal_config = multimodal_config
+
+ self.vision_model = Step3VisionTransformer(config.vision_config,
+ None,
+ prefix=maybe_prefix(
+ prefix, "vision_model"))
+ self.vit_downsampler = nn.Conv2d(
+ config.vision_config.hidden_size,
+ config.vision_config.output_hidden_size,
+ kernel_size=2,
+ stride=config.understand_projector_stride)
+ self.vit_downsampler2 = nn.Conv2d(
+ config.vision_config.output_hidden_size,
+ config.vision_config.output_hidden_size * 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ )
+ self.vit_large_projector = nn.Linear(
+ config.vision_config.output_hidden_size * 2,
+ config.hidden_size,
+ bias=config.projector_bias,
+ )
+ self.language_model = init_vllm_registered_model(
+ vllm_config=vllm_config,
+ hf_config=config.text_config,
+ prefix=maybe_prefix(prefix, "language_model"))
+
+ self.make_empty_intermediate_tensors = (
+ self.language_model.make_empty_intermediate_tensors)
+
+ @cached_property
+ def sampler(self):
+ if hasattr(self.language_model, "sampler"):
+ return self.language_model.sampler
+
+ return get_sampler()
+
+ @property
+ def device(self):
+ return next(self.parameters()).device
+
+ @property
+ def dtype(self):
+ return next(self.parameters()).dtype
+
+ def _parse_and_validate_image_input(
+ self, **kwargs: object) -> Optional[Step3VLImageInputs]:
+ pixel_values = kwargs.pop("pixel_values", None)
+ patch_pixel_values = kwargs.pop("patch_pixel_values", None)
+ num_patches = kwargs.pop("num_patches", None)
+ image_embeds = kwargs.pop("image_embeds", None)
+
+ if pixel_values is None and image_embeds is None:
+ return None
+
+ if pixel_values is not None:
+ pixel_values = flatten_bn(pixel_values, concat=True)
+ if pixel_values.dim() >= 3:
+ pixel_values = pixel_values.view(-1, *pixel_values.shape[-3:])
+ if patch_pixel_values is not None:
+ patch_pixel_values = flatten_bn(patch_pixel_values,
+ concat=True)
+ patch_pixel_values = patch_pixel_values.view(
+ -1, *patch_pixel_values.shape[-3:])
+ # Handle empty patch_pixel_values by setting to None
+ if patch_pixel_values.shape[0] == 0:
+ patch_pixel_values = None
+ num_patches = flatten_bn(num_patches, concat=True).tolist()
+
+ return Step3VLImagePixelInputs(
+ type="pixel_values",
+ pixel_values=pixel_values.to(self.dtype).to(self.device),
+ patch_pixel_values=patch_pixel_values.to(self.dtype).to(
+ self.device) if patch_pixel_values is not None else None,
+ num_patches=num_patches,
+ )
+
+ if image_embeds is not None:
+ if image_embeds.dim() == 2 or image_embeds.dim() >= 3:
+ image_embeds = image_embeds.view(-1, image_embeds.shape[-1])
+ else:
+ raise ValueError(
+ f"Unexpected shape for image_embeds: {image_embeds.shape}")
+
+ return Step3VLImageEmbeddingInputs(
+ type="image_embeds",
+ image_embeds=image_embeds.to(self.dtype).to(self.device),
+ )
+ return None
+
+ def _process_image_features(self,
+ image_features: torch.Tensor) -> torch.Tensor:
+ B, P = image_features.shape[:2]
+ HW = int(sqrt(P))
+ image_features = image_features.permute(0, 2, 1).view(B, -1, HW, HW)
+ image_features = self.vit_downsampler(image_features)
+ image_features = self.vit_downsampler2(image_features)
+ n_dim = image_features.size(1)
+ image_features = image_features.view(B, n_dim, -1).permute(0, 2, 1)
+ image_features = self.vit_large_projector(image_features)
+ return image_features
+
+ def _get_vision_model_output(self,
+ input_tensor: torch.Tensor) -> torch.Tensor:
+ return self.vision_model(input_tensor)[:, 4:]
+
+ def _process_image_input(
+ self, image_input: Step3VLImageInputs) -> tuple[torch.Tensor, ...]:
+
+ if image_input["type"] == "image_embeds":
+ image_features = image_input["image_embeds"]
+ else:
+ image_features = self._get_vision_model_output(
+ image_input["pixel_values"])
+ patch_image_features = self._get_vision_model_output(
+ image_input["patch_pixel_values"]
+ ) if image_input["patch_pixel_values"] is not None else None
+ num_patches = image_input["num_patches"]
+
+ image_features = self._process_image_features(image_features)
+ patch_image_features = self._process_image_features(
+ patch_image_features) if patch_image_features is not None else None
+
+ merged_image_features = []
+ cur_patch_idx = 0
+ for i, num_patch in enumerate(num_patches):
+ cur_feature = []
+ if num_patch > 0:
+ patch_slice = patch_image_features[
+ cur_patch_idx:cur_patch_idx + num_patch]
+ cur_feature.append(patch_slice.view(-1, patch_slice.shape[-1]))
+ cur_feature.append(image_features[i].view(
+ -1, image_features.shape[-1]))
+ cur_patch_idx += num_patch
+ merged_image_features.append(
+ torch.cat(cur_feature) if len(cur_feature) >
+ 1 else cur_feature[0])
+ return merged_image_features
+
+ def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+ image_input = self._parse_and_validate_image_input(**kwargs)
+ if image_input is None:
+ return None
+ vision_embeddings = self._process_image_input(image_input)
+ return vision_embeddings
+
+ def get_input_embeddings(
+ self,
+ input_ids: torch.Tensor,
+ multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
+ ) -> torch.Tensor:
+ if multimodal_embeddings is None:
+ inputs_embeds = self.language_model.model.get_input_embeddings(
+ input_ids)
+ else:
+ is_text = input_ids != self.config.image_token_id
+ text_ids = input_ids[is_text]
+ text_embeds = self.language_model.model.get_input_embeddings(
+ text_ids)
+ inputs_embeds = torch.empty(input_ids.shape[0],
+ text_embeds.shape[-1],
+ dtype=text_embeds.dtype,
+ device=text_embeds.device)
+ inputs_embeds[is_text] = text_embeds
+ inputs_embeds = merge_multimodal_embeddings(
+ input_ids, inputs_embeds, multimodal_embeddings,
+ self.config.image_token_id)
+ return inputs_embeds
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ **kwargs: object,
+ ) -> Union[torch.Tensor, IntermediateTensors]:
+ if intermediate_tensors is not None:
+ inputs_embeds = None
+ elif inputs_embeds is None:
+ vision_embeddings = self.get_multimodal_embeddings(**kwargs)
+ # always pass the input via `inputs_embeds`
+ # to make sure the computation graph is consistent
+ inputs_embeds = self.get_input_embeddings(input_ids,
+ vision_embeddings)
+ input_ids = None
+
+ hidden_states = self.language_model(input_ids,
+ positions,
+ intermediate_tensors,
+ inputs_embeds=inputs_embeds)
+
+ return hidden_states
+
+ def compute_logits(
+ self,
+ hidden_states: torch.Tensor,
+ sampling_metadata: SamplingMetadata,
+ ) -> Optional[torch.Tensor]:
+ return self.language_model.compute_logits(hidden_states,
+ sampling_metadata)
+
+ def sample(
+ self,
+ logits: torch.Tensor,
+ sampling_metadata: SamplingMetadata,
+ ) -> Optional[SamplerOutput]:
+ return self.language_model.sample(logits, sampling_metadata)
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+ loader = AutoWeightsLoader(self)
+ loaded_weights = loader.load_weights(weights,
+ mapper=self.hf_to_vllm_mapper)
+ return loaded_weights
diff --git a/vllm/reasoning/__init__.py b/vllm/reasoning/__init__.py
index d61e4f11dfa29..1c3f78f2edbfb 100644
--- a/vllm/reasoning/__init__.py
+++ b/vllm/reasoning/__init__.py
@@ -8,6 +8,7 @@ from .granite_reasoning_parser import GraniteReasoningParser
from .hunyuan_a13b_reasoning_parser import HunyuanA13BReasoningParser
from .mistral_reasoning_parser import MistralReasoningParser
from .qwen3_reasoning_parser import Qwen3ReasoningParser
+from .step3_reasoning_parser import Step3ReasoningParser
__all__ = [
"ReasoningParser",
@@ -18,4 +19,5 @@ __all__ = [
"Qwen3ReasoningParser",
"Glm4MoeModelReasoningParser",
"MistralReasoningParser",
+ "Step3ReasoningParser",
]
diff --git a/vllm/reasoning/step3_reasoning_parser.py b/vllm/reasoning/step3_reasoning_parser.py
new file mode 100644
index 0000000000000..f642ea977c580
--- /dev/null
+++ b/vllm/reasoning/step3_reasoning_parser.py
@@ -0,0 +1,109 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from collections.abc import Sequence
+from typing import Optional, Union
+
+import regex as re
+from transformers import PreTrainedTokenizerBase
+
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+ DeltaMessage)
+from vllm.logger import init_logger
+from vllm.reasoning import ReasoningParser, ReasoningParserManager
+
+logger = init_logger(__name__)
+
+
@ReasoningParserManager.register_module("step3")
class Step3ReasoningParser(ReasoningParser):
    """
    Reasoning parser for the Step3 model.

    The Step3 model uses the ``</think>`` token to denote the end of
    reasoning text. This parser extracts all content before ``</think>``
    as reasoning content.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        """Locate the end-of-thinking token in the tokenizer vocabulary.

        Raises:
            ValueError: if no tokenizer was supplied.
            RuntimeError: if ``</think>`` is not in the vocabulary.
        """
        super().__init__(tokenizer)
        # FIX: the marker was an empty string (the literal tag was lost
        # during extraction). An empty marker makes the vocab lookup below
        # always fail and makes every find()/slice on it match at offset 0.
        self.think_end_token = "</think>"

        self.reasoning_regex = re.compile(rf"(.*?){self.think_end_token}",
                                          re.DOTALL)

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ReasoningParser "
                "constructor during construction.")

        self.think_end_token_id = self.vocab.get(self.think_end_token)
        if self.think_end_token_id is None:
            raise RuntimeError(
                "Step3 reasoning parser could not locate think end "
                "token in the tokenizer!")

    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> Union[DeltaMessage, None]:
        """
        Extract reasoning content from a delta message.
        Handles streaming output where previous + delta = current.
        Uses token IDs for faster processing.
        For text "abc</think>xyz":
        - 'abc' goes to reasoning_content
        - 'xyz' goes to content
        """
        # Skip a delta that is only the special end-of-thinking token.
        if len(delta_token_ids
               ) == 1 and delta_token_ids[0] == self.think_end_token_id:
            return None

        if self.think_end_token_id in delta_token_ids:
            # </think> is in this delta: split it into reasoning content
            # and remaining content.
            end_index = delta_text.find(self.think_end_token)
            reasoning_content = delta_text[:end_index]
            content = delta_text[end_index + len(self.think_end_token):]
            return DeltaMessage(reasoning_content=reasoning_content,
                                content=content if content else None)
        elif self.think_end_token_id in previous_token_ids:
            # </think> already seen in previous text; everything is content.
            return DeltaMessage(content=delta_text)
        else:
            # No </think> seen yet; everything is reasoning.
            return DeltaMessage(reasoning_content=delta_text)

    def extract_reasoning_content(
        self, model_output: str, request: ChatCompletionRequest
    ) -> tuple[Optional[str], Optional[str]]:
        """Split a complete model output into (reasoning_content, content)."""
        # Check if the model output contains the </think> token.
        if self.think_end_token not in model_output:
            # If there is no </think> token, everything is reasoning content.
            return model_output, None
        else:
            # Find the first occurrence of </think>.
            end_index = model_output.find(self.think_end_token)
            reasoning_content = model_output[:end_index]

            # Content after the </think> token.
            content = model_output[end_index + len(self.think_end_token):]

            if len(content) == 0:
                content = None

            return reasoning_content, content

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        """Return True once the </think> token id has been generated."""
        return self.think_end_token_id in input_ids

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """Return token ids after </think>, or [] if it has not appeared.

        The final position is deliberately excluded from the search so that
        a </think> emitted as the very last token yields no content yet.
        """
        if self.think_end_token_id not in input_ids[:-1]:
            return []
        else:
            return input_ids[input_ids.index(self.think_end_token_id) + 1:]
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index 4ce56cb3a6aac..fcaa48c1392a3 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -35,7 +35,8 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DeepseekVLV2Config,
MllamaConfig, MLPSpeculatorConfig,
Nemotron_Nano_VL_Config,
NemotronConfig, NVLM_D_Config,
- RWConfig, UltravoxConfig)
+ RWConfig, Step3TextConfig,
+ Step3VLConfig, UltravoxConfig)
# yapf: enable
from vllm.transformers_utils.configs.mistral import adapt_config_dict
from vllm.transformers_utils.utils import check_gguf_file
@@ -83,6 +84,8 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = {
"nemotron": NemotronConfig,
"NVLM_D": NVLM_D_Config,
"ultravox": UltravoxConfig,
+ "step3_vl": Step3VLConfig,
+ "step3_text": Step3TextConfig,
**_CONFIG_REGISTRY_OVERRIDE_HF
}
diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py
index 7c7d859e4a325..96733da726181 100644
--- a/vllm/transformers_utils/configs/__init__.py
+++ b/vllm/transformers_utils/configs/__init__.py
@@ -24,6 +24,9 @@ from vllm.transformers_utils.configs.nemotron import NemotronConfig
from vllm.transformers_utils.configs.nemotron_h import NemotronHConfig
from vllm.transformers_utils.configs.nemotron_vl import Nemotron_Nano_VL_Config
from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
+from vllm.transformers_utils.configs.step3_vl import (Step3TextConfig,
+ Step3VisionEncoderConfig,
+ Step3VLConfig)
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
__all__ = [
@@ -42,4 +45,7 @@ __all__ = [
"Nemotron_Nano_VL_Config",
"NVLM_D_Config",
"UltravoxConfig",
+ "Step3VLConfig",
+ "Step3VisionEncoderConfig",
+ "Step3TextConfig",
]
diff --git a/vllm/transformers_utils/configs/step3_vl.py b/vllm/transformers_utils/configs/step3_vl.py
new file mode 100644
index 0000000000000..fe3c72de69d28
--- /dev/null
+++ b/vllm/transformers_utils/configs/step3_vl.py
@@ -0,0 +1,123 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Any, Optional, Union
+
+from transformers.configuration_utils import PretrainedConfig
+
+
class Step3VisionEncoderConfig(PretrainedConfig):
    """Configuration for the Step3 vision encoder tower."""

    model_type = "step3_vision_encoder"

    def __init__(
        self,
        hidden_size=1792,
        intermediate_size=3072,
        output_hidden_size=4096,
        num_hidden_layers=63,
        num_attention_heads=16,
        num_channels=3,
        image_size=728,
        patch_size=14,
        hidden_act="quick_gelu",
        layer_norm_eps=1e-5,
        **kwargs,
    ):
        # Transformer geometry.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        # Width of the embedding handed to the multimodal projector.
        self.output_hidden_size = output_hidden_size
        # Input image geometry.
        self.num_channels = num_channels
        self.image_size = image_size
        self.patch_size = patch_size
        # Activation and normalization settings.
        self.hidden_act = hidden_act
        self.layer_norm_eps = layer_norm_eps
        # Forward remaining HF config kwargs to the base class last,
        # matching the original initialization order.
        super().__init__(**kwargs)
+
+
class Step3TextConfig(PretrainedConfig):
    """Configuration for the Step3 text (MoE decoder) backbone."""

    model_type = "step3_text"
    architectures = ["Step3TextForCausalLM"]

    def __init__(
        self,
        hidden_size: int = 7168,
        intermediate_size: int = 18432,
        num_attention_heads: int = 64,
        num_attention_groups: int = 1,
        num_hidden_layers: int = 61,
        max_seq_len: int = 65536,
        vocab_size: int = 128815,
        rms_norm_eps: float = 1e-5,
        moe_intermediate_size: int = 5120,
        moe_num_experts: int = 48,
        moe_top_k: int = 3,
        rope_theta: float = 500000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embedding: int = 65536,
        share_expert_dim: int = 5120,
        share_q_dim: int = 2048,
        head_dim: int = 256,
        norm_expert_weight: bool = False,
        # Layers 4..59 (inclusive) are MoE layers by default — identical to
        # the explicit tuple this replaces, just written as a range.
        moe_layers_enum: tuple[int, ...] = tuple(range(4, 60)),
        **kwargs,
    ) -> None:
        # Dense transformer geometry.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.num_attention_groups = num_attention_groups
        self.num_hidden_layers = num_hidden_layers
        self.head_dim = head_dim
        self.share_q_dim = share_q_dim
        # Vocabulary / sequence limits.
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.max_position_embedding = max_position_embedding
        # Normalization and RoPE.
        self.rms_norm_eps = rms_norm_eps
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        # Mixture-of-experts settings.
        self.moe_intermediate_size = moe_intermediate_size
        self.moe_num_experts = moe_num_experts
        self.moe_top_k = moe_top_k
        self.norm_expert_weight = norm_expert_weight
        self.moe_layers_enum = moe_layers_enum
        self.share_expert_dim = share_expert_dim

        super().__init__(**kwargs)
+
+
class Step3VLConfig(PretrainedConfig):
    """Top-level configuration tying the Step3 vision encoder to the text
    backbone, plus the projector settings that bridge the two."""

    model_type = "step3_vl"

    def __init__(
        self,
        vision_config: Optional[Union[dict, Step3VisionEncoderConfig]] = None,
        text_config: Optional[Union[dict, Step3TextConfig]] = None,
        understand_projector_stride: int = 1,
        projector_bias: bool = True,
        image_token_id: int = 128001,
        **kwargs,
    ) -> None:
        # Each sub-config may arrive as a ready instance, a plain dict
        # (e.g. parsed from config.json), or None (use defaults).
        if isinstance(vision_config, dict):
            vision_config = Step3VisionEncoderConfig(**vision_config)
        elif vision_config is None:
            vision_config = Step3VisionEncoderConfig()
        self.vision_config = vision_config

        if isinstance(text_config, dict):
            text_config = Step3TextConfig(**text_config)
        elif text_config is None:
            text_config = Step3TextConfig()
        self.text_config = text_config

        self.understand_projector_stride = understand_projector_stride
        self.projector_bias = projector_bias
        # Mirror the text backbone's hidden size at the top level.
        self.hidden_size = text_config.hidden_size
        # Placeholder token id marking image positions in the prompt.
        self.image_token_id = image_token_id

        super().__init__(**kwargs)