[New Model] Add Seed-Oss model (#23241)

Signed-off-by: jiabin.00 <jiabin.00@bytedance.com>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Bin Jia 2025-08-22 12:58:10 +08:00 committed by GitHub
parent de9c085e17
commit 5964069367
7 changed files with 1629 additions and 0 deletions

View File

@@ -401,6 +401,7 @@ th {
| `Qwen2MoeForCausalLM` | Qwen2MoE | `Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen3ForCausalLM` | Qwen3 | `Qwen/Qwen3-8B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen3MoeForCausalLM` | Qwen3MoE | `Qwen/Qwen3-30B-A3B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `SeedOssForCausalLM` | SeedOss | `ByteDance-Seed/Seed-OSS-36B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `StableLmForCausalLM` | StableLM | `stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc. | | | ✅︎ |
| `Starcoder2ForCausalLM` | Starcoder2 | `bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc. | | ✅︎ | ✅︎ |
| `SolarForCausalLM` | Solar Pro | `upstage/solar-pro-preview-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
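
For a quick check of the newly listed architecture, a minimal offline sketch (model name taken from the table above; `tensor_parallel_size` is illustrative and should be sized to your hardware):

from vllm import LLM, SamplingParams

llm = LLM(model="ByteDance-Seed/Seed-OSS-36B-Instruct",
          trust_remote_code=True,
          tensor_parallel_size=4)  # illustrative; adjust to your GPUs
outputs = llm.generate(["What is the weather in Barcelona, Spain?"],
                       SamplingParams(max_tokens=128))
print(outputs[0].outputs[0].text)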

View File

@@ -292,6 +292,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"),
"Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"),
"RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"),
"SeedOssForCausalLM": _HfExamplesInfo("ByteDance-Seed/Seed-OSS-36B-Instruct", # noqa: E501
trust_remote_code=True,
is_available_online=False),
"SmolLM3ForCausalLM": _HfExamplesInfo("HuggingFaceTB/SmolLM3-3B"),
"StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"), # noqa: E501
"StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),

View File

@@ -0,0 +1,459 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
import json
from collections.abc import Generator
from typing import Optional
import pytest
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ChatCompletionToolsParam,
DeltaMessage, FunctionCall,
ToolCall)
from vllm.entrypoints.openai.tool_parsers import SeedOssToolParser
from vllm.transformers_utils.detokenizer import detokenize_incrementally
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
# Use a common model that is likely to be available
MODEL = "ByteDance-Seed/Seed-OSS-36B-Instruct"
@pytest.fixture(scope="module")
def seed_oss_tokenizer():
return get_tokenizer(tokenizer_name=MODEL, trust_remote_code=True)
@pytest.fixture
def seed_oss_tool_parser(seed_oss_tokenizer):
return SeedOssToolParser(seed_oss_tokenizer)
@pytest.fixture
def sample_tools():
return [
ChatCompletionToolsParam(
type="function",
function={
"name": "get_weather",
"description": "Get current temperature for a given location.",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description":
"City and country e.g. Bogotá, Colombia"
},
"unit": {
"type": "string",
"description": "this is the unit of temperature"
}
},
"required": ["location"],
"additionalProperties": False
},
"returns": {
"type": "object",
"properties": {
"temperature": {
"type": "number",
"description": "temperature in celsius"
}
},
"required": ["temperature"],
"additionalProperties": False
},
"strict": True
}),
]
def assert_tool_calls(actual_tool_calls: list[ToolCall],
expected_tool_calls: list[ToolCall]):
assert len(actual_tool_calls) == len(expected_tool_calls)
for actual_tool_call, expected_tool_call in zip(actual_tool_calls,
expected_tool_calls):
# Seed-OSS tool calls do not generate an id
assert actual_tool_call.type == "function"
assert actual_tool_call.function == expected_tool_call.function
assert actual_tool_call.function.name == expected_tool_call.function.name
assert actual_tool_call.function.arguments == expected_tool_call.function.arguments
def test_extract_tool_calls_no_tools(seed_oss_tool_parser):
model_output = "This is a test response without any tool calls"
extracted_tool_calls = seed_oss_tool_parser.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type]
assert not extracted_tool_calls.tools_called
assert extracted_tool_calls.tool_calls == []
assert extracted_tool_calls.content == model_output
@pytest.mark.parametrize(
ids=[
"tool_call_0_thinking_budget",
"tool_call_512_thinkg_budget",
"tool_call_unlimited_thinking_budget",
],
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
("""<seed:think>\n</seed:cot_budget_reflect>\n</seed:cot_budget_reflect>\n"""
"""The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\n"""
"""<seed:tool_call>\n<function=get_weather>\n"""
"""<parameter=location>Barcelona, Spain</parameter>\n</function>\n</seed:tool_call>""",
[
ToolCall(function=FunctionCall(
name="get_weather",
arguments=json.dumps({
"location": "Barcelona, Spain",
}, ),
),
type='function')
],
"""<seed:think>\n</seed:cot_budget_reflect>\n</seed:cot_budget_reflect>\n"""
"""The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\n"""
),
(
"""<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """
"""question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """
"""there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """
"""check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """
"""optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """
"""country). \n<seed:cot_budget_reflect>I have used 131 tokens, and there are 381 tokens remaining for use."""
"""</seed:cot_budget_reflect>\n Since the unit isn\'t specified, the function will default to Celsius, which """
"""is fine. \n\nThere\'s no need to ask for more information because the location is clear. So I should call """
"""the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """
"""user\'s input has a space, but the function might accept either; to be safe, using the standard format """
"""with a comma).\n<seed:cot_budget_reflect>I have used 257 tokens, and there are 255 tokens remaining for """
"""use.</seed:cot_budget_reflect>\n The unit parameter can be omitted since it\'s optional.</seed:think>\n"""
"""<seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, Spain</parameter>\n</function>"""
"""\n</seed:tool_call>""",
[
ToolCall(function=FunctionCall(
name="get_weather",
arguments=json.dumps({
"location": "Barcelona, Spain",
}, ),
),
type='function')
],
"""<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """
"""question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """
"""there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """
"""check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """
"""optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """
"""country). \n<seed:cot_budget_reflect>I have used 131 tokens, and there are 381 tokens remaining for use."""
"""</seed:cot_budget_reflect>\n Since the unit isn\'t specified, the function will default to Celsius, which """
"""is fine. \n\nThere\'s no need to ask for more information because the location is clear. So I should call """
"""the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """
"""user\'s input has a space, but the function might accept either; to be safe, using the standard format """
"""with a comma).\n<seed:cot_budget_reflect>I have used 257 tokens, and there are 255 tokens remaining for """
"""use.</seed:cot_budget_reflect>\n The unit parameter can be omitted since it\'s optional.</seed:think>\n""",
),
(
"""<seed:think>\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """
"""First, I need to remember the function I can use: get_weather. The function requires a """
"""location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """
"""the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """
"""let me check the function docstring again. Oh, the function says unit is optional, and """
"""returns temperature in Celsius. So I should call get_weather with location "Barcelona, """
"""Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """
"""The format is <seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, """
"""Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>. """
"""Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """
"""of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """
"""it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """
"""call should be as above. Then wait for the result to come back and tell the user the """
"""temperature in Celsius.</seed:think><seed:tool_call>\n<function=get_weather>\n<parameter=location>"""
"""Barcelona, Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>""",
[
ToolCall(function=FunctionCall(
name="get_weather",
arguments=json.dumps(
{
"location": "Barcelona, Spain",
"unit": "celsius",
}, ),
),
type='function')
],
"""<seed:think>\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """
"""First, I need to remember the function I can use: get_weather. The function requires a """
"""location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """
"""the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """
"""let me check the function docstring again. Oh, the function says unit is optional, and """
"""returns temperature in Celsius. So I should call get_weather with location "Barcelona, """
"""Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """
"""The format is <seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, """
"""Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>. """
"""Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """
"""of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """
"""it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """
"""call should be as above. Then wait for the result to come back and tell the user the """
"""temperature in Celsius.</seed:think>""",
),
],
)
def test_extract_tool_calls(seed_oss_tool_parser, sample_tools, model_output,
expected_tool_calls, expected_content):
request = ChatCompletionRequest(model=MODEL,
messages=[],
tools=sample_tools)
extracted_tool_calls = seed_oss_tool_parser.extract_tool_calls(
model_output, request=request) # type: ignore[arg-type]
assert extracted_tool_calls.tools_called
assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
assert extracted_tool_calls.content == expected_content
def test_streaming_tool_calls_no_tools(seed_oss_tool_parser):
model_output = "This is a test response without any tool calls"
result = seed_oss_tool_parser.extract_tool_calls_streaming(
previous_text="his is a test response",
current_text=model_output,
delta_text=" without any tool calls.",
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[],
request=None,
)
# Should return the delta text as content
assert result is not None
assert hasattr(result, 'content')
assert result.content == " without any tool calls."
def stream_delta_message_generator(
seed_oss_tool_parser: SeedOssToolParser,
seed_oss_tokenizer: AnyTokenizer,
model_output: str,
request: Optional[ChatCompletionRequest] = None
) -> Generator[DeltaMessage, None, None]:
all_token_ids = seed_oss_tokenizer.encode(model_output,
add_special_tokens=False)
previous_text = ""
previous_tokens = None
prefix_offset = 0
read_offset = 0
for i, delta_token in enumerate(all_token_ids):
delta_token_ids = [delta_token]
previous_token_ids = all_token_ids[:i]
current_token_ids = all_token_ids[:i + 1]
(new_tokens, delta_text, new_prefix_offset,
new_read_offset) = detokenize_incrementally(
tokenizer=seed_oss_tokenizer,
all_input_ids=current_token_ids,
prev_tokens=previous_tokens,
prefix_offset=prefix_offset,
read_offset=read_offset,
skip_special_tokens=False,
spaces_between_special_tokens=True,
)
current_text = previous_text + delta_text
delta_message = seed_oss_tool_parser.extract_tool_calls_streaming(
previous_text,
current_text,
delta_text,
previous_token_ids,
current_token_ids,
delta_token_ids,
request=request,
)
if delta_message:
yield delta_message
previous_text = current_text
previous_tokens = (previous_tokens +
new_tokens if previous_tokens else new_tokens)
prefix_offset = new_prefix_offset
read_offset = new_read_offset
@pytest.mark.parametrize(
ids=[
"tool_call_0_thinking_budget",
"tool_call_512_thinkg_budget",
"tool_call_unlimited_thinking_budget",
],
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
("""<seed:think>\n</seed:cot_budget_reflect>\n</seed:cot_budget_reflect>\n"""
"""The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\n"""
"""<seed:tool_call>\n<function=get_weather>\n"""
"""<parameter=location>Barcelona, Spain</parameter>\n</function>\n</seed:tool_call>""",
[
ToolCall(function=FunctionCall(
name="get_weather",
arguments=json.dumps({
"location": "Barcelona, Spain",
}, ),
),
type='function')
],
"""<seed:think>\n</seed:cot_budget_reflect>\n</seed:cot_budget_reflect>\n"""
"""The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\n"""
),
(
"""<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """
"""question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """
"""there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """
"""check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """
"""optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """
"""country). \n<seed:cot_budget_reflect>I have used 131 tokens, and there are 381 tokens remaining for use."""
"""</seed:cot_budget_reflect>\n Since the unit isn\'t specified, the function will default to Celsius, which """
"""is fine. \n\nThere\'s no need to ask for more information because the location is clear. So I should call """
"""the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """
"""user\'s input has a space, but the function might accept either; to be safe, using the standard format """
"""with a comma).\n<seed:cot_budget_reflect>I have used 257 tokens, and there are 255 tokens remaining for """
"""use.</seed:cot_budget_reflect>\n The unit parameter can be omitted since it\'s optional.</seed:think>\n"""
"""<seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, Spain</parameter>\n</function>"""
"""\n</seed:tool_call>""",
[
ToolCall(function=FunctionCall(
name="get_weather",
arguments=json.dumps({
"location": "Barcelona, Spain",
}, ),
),
type='function')
],
"""<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """
"""question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """
"""there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """
"""check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """
"""optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """
"""country). \n<seed:cot_budget_reflect>I have used 131 tokens, and there are 381 tokens remaining for use."""
"""</seed:cot_budget_reflect>\n Since the unit isn\'t specified, the function will default to Celsius, which """
"""is fine. \n\nThere\'s no need to ask for more information because the location is clear. So I should call """
"""the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """
"""user\'s input has a space, but the function might accept either; to be safe, using the standard format """
"""with a comma).\n<seed:cot_budget_reflect>I have used 257 tokens, and there are 255 tokens remaining for """
"""use.</seed:cot_budget_reflect>\n The unit parameter can be omitted since it\'s optional.</seed:think>\n""",
),
(
"""<seed:think>\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """
"""First, I need to remember the function I can use: get_weather. The function requires a """
"""location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """
"""the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """
"""let me check the function docstring again. Oh, the function says unit is optional, and """
"""returns temperature in Celsius. So I should call get_weather with location "Barcelona, """
"""Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """
"""The format is <seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, """
"""Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>. """
"""Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """
"""of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """
"""it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """
"""call should be as above. Then wait for the result to come back and tell the user the """
"""temperature in Celsius.</seed:think><seed:tool_call>\n<function=get_weather>\n<parameter=location>"""
"""Barcelona, Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>""",
[
ToolCall(function=FunctionCall(
name="get_weather",
arguments=json.dumps(
{
"location": "Barcelona, Spain",
"unit": "celsius",
}, ),
),
type='function')
],
"""<seed:think>\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """
"""First, I need to remember the function I can use: get_weather. The function requires a """
"""location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """
"""the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """
"""let me check the function docstring again. Oh, the function says unit is optional, and """
"""returns temperature in Celsius. So I should call get_weather with location "Barcelona, """
"""Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """
"""The format is <seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, """
"""Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>. """
"""Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """
"""of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """
"""it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """
"""call should be as above. Then wait for the result to come back and tell the user the """
"""temperature in Celsius.</seed:think>""",
),
],
)
def test_streaming_tool_calls(seed_oss_tool_parser, seed_oss_tokenizer,
sample_tools, model_output, expected_tool_calls,
expected_content):
"""Test incremental streaming behavior"""
request = ChatCompletionRequest(model=MODEL,
messages=[],
tools=sample_tools)
other_content = ''
tool_states = {} # Track state per tool index
for delta_message in stream_delta_message_generator(
seed_oss_tool_parser, seed_oss_tokenizer, model_output, request):
# role should never be streamed from tool parser
assert not delta_message.role
if delta_message.content:
other_content += delta_message.content
if delta_message.tool_calls:
for tool_call in delta_message.tool_calls:
idx = tool_call.index
# Initialize state for new tool
if idx not in tool_states:
tool_states[idx] = {
"id": None,
"name": None,
"arguments": "",
"type": None
}
# First chunk should have id, name, and type
if tool_call.id:
tool_states[idx]["id"] = tool_call.id
if tool_call.type:
assert tool_call.type == "function"
tool_states[idx]["type"] = tool_call.type
if tool_call.function:
if tool_call.function.name:
# Should only be set once
assert tool_states[idx]["name"] is None
tool_states[idx]["name"] = tool_call.function.name
if tool_call.function.arguments is not None:
# Accumulate arguments incrementally
tool_states[idx][
"arguments"] += tool_call.function.arguments
# Verify final content
assert other_content == expected_content
# Verify we got all expected tool calls
assert len(tool_states) == len(expected_tool_calls)
# Verify each tool call
for idx, expected_tool in enumerate(expected_tool_calls):
state = tool_states[idx]
assert state["id"] is not None
assert state["type"] == "function"
assert state["name"] == expected_tool.function.name
# Parse accumulated arguments
arguments_str = state["arguments"]
assert arguments_str is not None
actual_args = json.loads(arguments_str)
expected_args = json.loads(expected_tool.function.arguments)
assert actual_args == expected_args

View File

@@ -18,6 +18,7 @@ from .mistral_tool_parser import MistralToolParser
from .phi4mini_tool_parser import Phi4MiniJsonToolParser
from .pythonic_tool_parser import PythonicToolParser
from .qwen3coder_tool_parser import Qwen3CoderToolParser
from .seed_oss_tool_parser import SeedOssToolParser
from .step3_tool_parser import Step3ToolParser
from .xlam_tool_parser import xLAMToolParser
@@ -41,5 +42,6 @@ __all__ = [
"HunyuanA13BToolParser",
"Glm4MoeModelToolParser",
"Qwen3CoderToolParser",
"SeedOssToolParser",
"Step3ToolParser",
]
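
With the parser exported here and registered as `seed_oss`, it can also be exercised directly, mirroring the fixtures in the test file above; a hedged sketch (the minimal tool schema is illustrative, and the tokenizer must define the Seed-OSS special tokens):

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              ChatCompletionToolsParam)
from vllm.entrypoints.openai.tool_parsers import SeedOssToolParser
from vllm.transformers_utils.tokenizer import get_tokenizer

tokenizer = get_tokenizer(tokenizer_name="ByteDance-Seed/Seed-OSS-36B-Instruct",
                          trust_remote_code=True)
parser = SeedOssToolParser(tokenizer)
tools = [
    ChatCompletionToolsParam(type="function",
                             function={
                                 "name": "get_weather",
                                 "parameters": {
                                     "type": "object",
                                     "properties": {
                                         "location": {"type": "string"}
                                     }
                                 }
                             })
]
request = ChatCompletionRequest(model="ByteDance-Seed/Seed-OSS-36B-Instruct",
                                messages=[],
                                tools=tools)
model_output = ("<seed:think>done thinking</seed:think>\n"
                "<seed:tool_call>\n<function=get_weather>\n"
                "<parameter=location>Barcelona, Spain</parameter>\n"
                "</function>\n</seed:tool_call>")
result = parser.extract_tool_calls(model_output, request=request)
# result.tools_called is True and result.tool_calls[0].function.arguments
# is '{"location": "Barcelona, Spain"}'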

View File

@@ -0,0 +1,676 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from the qwen3coder XML tool parser. All rights reserved.
# ruff: noqa: E501
import ast
import json
import uuid
from collections.abc import Sequence
from typing import Any, Optional, Union
import regex as re
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ChatCompletionToolsParam,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
logger = init_logger(__name__)
@ToolParserManager.register_module("seed_oss")
class SeedOssToolParser(ToolParser):
TOOL_CALL_START = "<seed:tool_call>"
TOOL_CALL_END = "</seed:tool_call>"
def __init__(self, tokenizer: AnyTokenizer):
super().__init__(tokenizer)
# --- streaming state ---
self._reset_streaming_state()
self.prev_tool_call_arr: list[dict] = []
self.tool_call_start_token: str = self.TOOL_CALL_START
self.tool_call_end_token: str = self.TOOL_CALL_END
# Sentinel tokens for streaming mode
self.tool_call_prefix: str = "<function="
self.function_end_token: str = "</function>"
self.parameter_prefix: str = "<parameter="
self.parameter_end_token: str = "</parameter>"
self.think_start_token: str = "<seed:think>"
self.think_end_token: str = "</seed:think>"
self.is_tool_call_started: bool = False
self.is_thinking_end: bool = False
self.failed_count: int = 0
self._reset_streaming_state()
self.tool_call_start_token_id = self.vocab.get(
self.tool_call_start_token)
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
self.think_end_token_id = self.vocab.get(self.think_end_token)
if (self.tool_call_start_token_id is None
or self.tool_call_end_token_id is None):
raise RuntimeError(
"Seed_Oss XML parser: tokenizer did not include "
"<seed:tool_call> or its closing tag.")
tool_start_re = re.escape(self.tool_call_start_token)
tool_end_re = re.escape(self.tool_call_end_token)
self.tool_call_complete_regex = re.compile(
rf"{tool_start_re}(.*?){tool_end_re}", re.DOTALL)
self.tool_call_regex = re.compile(
rf"{tool_start_re}(.*?){tool_end_re}|{tool_start_re}(.*?)$",
re.DOTALL)
self.tool_call_function_regex = re.compile(
r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL)
self.tool_call_parameter_regex = re.compile(
r"<parameter=(.*?)</parameter>|<parameter=(.*?)$", re.DOTALL)
logger.info("vLLM Seed-Oss XML tool parser loaded (%s).",
self.__class__.__name__)
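# For reference, the regexes above target tool-call markup of this shape
# (the same format exercised by the tool-parser tests):
#
#   <seed:think> ... </seed:think>
#   <seed:tool_call>
#   <function=get_weather>
#   <parameter=location>Barcelona, Spain</parameter>
#   </function>
#   </seed:tool_call>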
def _generate_tool_call_id(self) -> str:
"""Generate a unique tool call ID."""
return f"call_{uuid.uuid4().hex[:24]}"
def _reset_streaming_state(self):
"""Reset all streaming state."""
self.current_tool_index = 0
self.is_tool_call_started = False
self.header_sent = False
self.current_tool_id = -1
self.current_function_name = None
self.current_param_name = None
self.current_param_value = ""
self.param_count = 0
self.in_param = False
self.in_function = False
self.accumulated_text = ""
self.json_started = False
self.json_closed = False
def _parse_xml_function_call(
self, function_call_str: str,
tools: Optional[list[ChatCompletionToolsParam]]
) -> Optional[ToolCall]:
def get_arguments_config(func_name: str) -> dict:
if tools is None:
return {}
for config in tools:
if not hasattr(config, "type") or not (
hasattr(config, "function")
and hasattr(config.function, "name")):
continue
if (config.type == "function"
and config.function.name == func_name):
if not hasattr(config.function, "parameters"):
return {}
params = config.function.parameters
if isinstance(params, dict) and "properties" in params:
return params["properties"]
elif isinstance(params, dict):
return params
else:
return {}
logger.warning("Tool '%s' is not defined in the tools list.",
func_name)
return {}
def convert_param_value(param_value: str, param_name: str,
param_config: dict, func_name: str) -> Any:
# Handle null value for any type
if param_value.lower() == "null":
return None
if param_name not in param_config:
if param_config != {}:
logger.warning(
"Parsed parameter '%s' is not defined in "
"the tool parameters for tool '%s', "
"directly returning the string value.", param_name,
func_name)
return param_value
if (isinstance(param_config[param_name], dict)
and "type" in param_config[param_name]):
param_type = str(
param_config[param_name]["type"]).strip().lower()
else:
param_type = "string"
if param_type in [
"string", "str", "text", "varchar", "char", "enum"
]:
return param_value
elif (param_type.startswith("int") or param_type.startswith("uint")
or param_type.startswith("long")
or param_type.startswith("short")
or param_type.startswith("unsigned")):
try:
param_value = int(param_value) # type: ignore
except (ValueError, TypeError):
logger.warning(
"Parsed value '%s' of parameter '%s' is not an integer in tool "
"'%s', degenerating to string.", param_value,
param_name, func_name)
return param_value
elif param_type.startswith("num") or param_type.startswith(
"float"):
try:
float_param_value = float(param_value)
param_value = float_param_value if float_param_value - int(
float_param_value) != 0 else int(
float_param_value) # type: ignore
except (ValueError, TypeError):
logger.warning(
"Parsed value '%s' of parameter '%s' is not a float in tool "
"'%s', degenerating to string.", param_value,
param_name, func_name)
return param_value
elif param_type in ["boolean", "bool", "binary"]:
param_value = param_value.lower()
if param_value not in ["true", "false"]:
logger.warning(
"Parsed value '%s' of parameter '%s' is not a boolean "
"(`true` of `false`) in tool '%s', degenerating to false.",
param_value, param_name, func_name)
return param_value == "true"
else:
if param_type == "object" or param_type.startswith("dict"):
try:
param_value = json.loads(param_value)
return param_value
except (ValueError, TypeError, json.JSONDecodeError):
logger.warning(
"Parsed value '%s' of parameter '%s' is not a valid JSON "
"object in tool '%s', will try other methods to parse it.",
param_value, param_name, func_name)
try:
param_value = ast.literal_eval(param_value)
except (ValueError, SyntaxError):
logger.warning(
"Parsed value '%s' of parameter '%s' cannot be converted via "
"Python `ast.literal_eval()` in tool '%s', degenerating to string.",
param_value, param_name, func_name)
return param_value
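# Examples of the coercions above, assuming the schema declares the types:
# "42" with an "integer" type -> 42, "3.5" with a "number" type -> 3.5,
# "true" with a "boolean" type -> True, "null" -> None; values of
# parameters not declared in the schema are returned as strings.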
# Extract function name
end_index = function_call_str.index(">")
function_name = function_call_str[:end_index]
param_config = get_arguments_config(function_name)
parameters = function_call_str[end_index + 1:]
param_dict = {}
for match in self.tool_call_parameter_regex.findall(parameters):
match_text = match[0] if match[0] else match[1]
idx = match_text.index(">")
param_name = match_text[:idx]
param_value = str(match_text[idx + 1:])
# Remove prefix and trailing \n
if param_value.startswith("\n"):
param_value = param_value[1:]
if param_value.endswith("\n"):
param_value = param_value[:-1]
param_dict[param_name] = convert_param_value(
param_value, param_name, param_config, function_name)
return ToolCall(
type="function",
function=FunctionCall(name=function_name,
arguments=json.dumps(param_dict,
ensure_ascii=False)),
)
def _get_function_calls(self, model_output: str) -> list[str]:
# Find all tool calls
matched_ranges = self.tool_call_regex.findall(model_output)
raw_tool_calls = [
match[0] if match[0] else match[1] for match in matched_ranges
]
# Back-off strategy if no tool_call tags found
if len(raw_tool_calls) == 0:
raw_tool_calls = [model_output]
raw_function_calls = []
for tool_call in raw_tool_calls:
raw_function_calls.extend(
self.tool_call_function_regex.findall(tool_call))
function_calls = [
match[0] if match[0] else match[1] for match in raw_function_calls
]
return function_calls
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
# Quick check to avoid unnecessary processing
if self.tool_call_prefix not in model_output:
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
# Check if both think start and end tokens are present
if (self.think_start_token in model_output
and self.think_end_token in model_output):
# Find the position of think end token
think_end_index = model_output.find(self.think_end_token) + len(
self.think_end_token)
# Extract content after think end token
result_content = model_output[think_end_index:]
thinking_content = model_output[:think_end_index]
try:
function_calls = self._get_function_calls(result_content)
if len(function_calls) == 0:
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
tool_calls = [
self._parse_xml_function_call(function_call_str, request.tools)
for function_call_str in function_calls
]
# Populate prev_tool_call_arr for serving layer to set finish_reason
self.prev_tool_call_arr.clear() # Clear previous calls
for tool_call in tool_calls:
if tool_call:
self.prev_tool_call_arr.append({
"name":
tool_call.function.name,
"arguments":
tool_call.function.arguments,
})
# Extract content before tool calls
tool_call_start_index = result_content.find(
self.tool_call_start_token)
tool_call_start_index = (
tool_call_start_index if tool_call_start_index >= 0 else
result_content.find(self.tool_call_prefix))
content = thinking_content + result_content[:tool_call_start_index]
return ExtractedToolCallInformation(
tools_called=(len(tool_calls) > 0),
tool_calls=tool_calls,
content=content if content else None,
)
except Exception:
logger.exception("Error in extracting tool call from response.")
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
# If no delta text, return None unless
# it's an EOS token after tool calls
if not delta_text:
# Check if this is an EOS token after all tool calls are complete
# We check for tool calls in the text even if is_tool_call_started
# is False because it might have been reset after processing all tools
if (delta_token_ids
and self.tool_call_end_token_id not in delta_token_ids):
# Count complete tool calls
complete_calls = len(
self.tool_call_complete_regex.findall(current_text))
# If we have completed tool calls and populated prev_tool_call_arr
if complete_calls > 0 and len(self.prev_tool_call_arr) > 0:
# Check if all tool calls are closed
open_calls = current_text.count(
self.tool_call_start_token) - current_text.count(
self.tool_call_end_token)
if open_calls == 0:
# Return empty delta message to allow finish_reason processing
return DeltaMessage(content="")
elif not self.is_tool_call_started and current_text:
# This is a regular content response that's now complete
return DeltaMessage(content="")
return None
# Check if this is the first call (reset state if needed)
if not previous_text:
self._reset_streaming_state()
# Update accumulated text
self.accumulated_text = current_text
# Check if we need to advance to next tool
if self.json_closed and not self.in_function:
# Check if this tool call has ended
tool_ends = current_text.count(self.tool_call_end_token)
if tool_ends > self.current_tool_index:
# This tool has ended, advance to next
self.current_tool_index += 1
self.header_sent = False
self.param_count = 0
self.json_started = False
self.json_closed = False
# Check if there are more tool calls
if self.current_tool_index >= current_text.count(
self.tool_call_start_token):
# No more tool calls
self.is_tool_call_started = False
# Continue processing next tool
return None
# Check if end thinking
if (not self.is_thinking_end
and (self.think_end_token_id in delta_token_ids
or self.think_end_token in delta_text)):
self.is_thinking_end = True
# If thinking hasn't ended yet, don't process any tool calls
if not self.is_thinking_end:
return DeltaMessage(content=delta_text)
# Handle normal content before tool calls
if not self.is_tool_call_started:
# Check if tool call is starting
if (self.tool_call_start_token_id in delta_token_ids
or self.tool_call_start_token in delta_text):
self.is_tool_call_started = True
# Return any content before the tool call
if self.tool_call_start_token in delta_text:
content_before = delta_text[:delta_text.index(
self.tool_call_start_token)]
if content_before:
return DeltaMessage(content=content_before)
return None
else:
# Check if we're between tool calls - skip whitespace
if (current_text.rstrip().endswith(self.tool_call_end_token)
and delta_text.strip() == ""):
# We just ended a tool call, skip whitespace
return None
# Normal content, no tool call
return DeltaMessage(content=delta_text)
# Check if we're between tool calls (waiting for next one)
# Count tool calls we've seen vs processed
tool_starts_count = current_text.count(self.tool_call_start_token)
if self.current_tool_index >= tool_starts_count:
# We're past all tool calls, shouldn't be here
return None
# We're in a tool call, find the current tool call portion
# Need to find the correct tool call based on current_tool_index
# Only process tool calls after think_end_token
think_end_index = current_text.find(self.think_end_token) + len(
self.think_end_token
) if self.think_end_token in current_text else 0
tool_starts: list[int] = []
idx = think_end_index
while True:
idx = current_text.find(self.tool_call_start_token, idx)
if idx == -1:
break
tool_starts.append(idx)
idx += len(self.tool_call_start_token)
if self.current_tool_index >= len(tool_starts):
# No more tool calls to process yet
return None
tool_start_idx = tool_starts[self.current_tool_index]
# Find where this tool call ends (or current position if not ended yet)
tool_end_idx = current_text.find(self.tool_call_end_token,
tool_start_idx)
if tool_end_idx == -1:
tool_text = current_text[tool_start_idx:]
else:
tool_text = current_text[tool_start_idx:tool_end_idx +
len(self.tool_call_end_token)]
# Looking for function header
if not self.header_sent:
if self.tool_call_prefix in tool_text:
func_start = tool_text.find(self.tool_call_prefix) + len(
self.tool_call_prefix)
func_end = tool_text.find(">", func_start)
if func_end != -1:
# Found complete function name
self.current_function_name = tool_text[func_start:func_end]
self.current_tool_id = self._generate_tool_call_id(
) # type: ignore
self.header_sent = True
self.in_function = True
# IMPORTANT: Add to prev_tool_call_arr immediately when we detect a tool call
# This ensures finish_reason="tool_calls" even if parsing isn't complete
already_added = any(
tool.get("name") == self.current_function_name
for tool in self.prev_tool_call_arr)
if not already_added:
self.prev_tool_call_arr.append({
"name": self.current_function_name,
"arguments":
"{}", # Placeholder, will be updated later
})
# Send header with function info
return DeltaMessage(tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
id=self.current_tool_id,
function=DeltaFunctionCall(
name=self.current_function_name, arguments=""),
type="function",
)
])
return None
# We've sent header, now handle function body
if self.in_function:
# Send opening brace if not sent yet
if (not self.json_started
and self.parameter_prefix not in delta_text):
self.json_started = True
return DeltaMessage(tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(arguments="{"),
)
])
# Make sure json_started is set if we're processing parameters
if not self.json_started:
self.json_started = True
# Check for function end in accumulated text
if not self.json_closed and self.function_end_token in tool_text:
# Close JSON
self.json_closed = True
# Extract the complete tool call to update prev_tool_call_arr with final arguments
# Find the function content
func_start = tool_text.find(self.tool_call_prefix) + len(
self.tool_call_prefix)
func_content_end = tool_text.find(self.function_end_token,
func_start)
if func_content_end != -1:
func_content = tool_text[func_start:func_content_end]
# Parse to get the complete arguments
try:
parsed_tool = self._parse_xml_function_call(
func_content, request.tools if request else None)
if parsed_tool:
# Update existing entry in prev_tool_call_arr with complete arguments
for i, tool in enumerate(self.prev_tool_call_arr):
if tool.get(
"name") == parsed_tool.function.name:
self.prev_tool_call_arr[i]["arguments"] = (
parsed_tool.function.arguments)
break
except Exception:
logger.warning(
"Failed to parse tool arguments during streaming.",
exc_info=True)
result = DeltaMessage(tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(arguments="}"),
)
])
# Reset state for next tool
self.in_function = False
self.json_closed = True
return result
# Look for parameters
# Count how many complete parameters we have processed
complete_params = tool_text.count(self.parameter_end_token)
# Check if we should start a new parameter
if not self.in_param and self.param_count < complete_params:
# Find the unprocessed parameter
# Count parameter starts
param_starts = []
idx = 0
while True:
idx = tool_text.find(self.parameter_prefix, idx)
if idx == -1:
break
param_starts.append(idx)
idx += len(self.parameter_prefix)
if len(param_starts) > self.param_count:
# Process the next parameter
param_idx = param_starts[self.param_count]
param_start = param_idx + len(self.parameter_prefix)
remaining = tool_text[param_start:]
if ">" in remaining:
# We have the complete parameter name
name_end = remaining.find(">")
self.current_param_name = remaining[:name_end]
# Find the parameter value
value_start = param_start + name_end + 1
value_text = tool_text[value_start:]
if value_text.startswith("\n"):
value_text = value_text[1:]
# Find where this parameter ends
param_end_idx = value_text.find(
self.parameter_end_token)
if param_end_idx != -1:
# Complete parameter found
param_value = value_text[:param_end_idx]
if param_value.endswith("\n"):
param_value = param_value[:-1]
# Build complete JSON fragment for this parameter
if self.param_count == 0:
json_fragment = (
'"' + self.current_param_name + '": "' +
json.dumps(param_value)[1:-1] + '"')
else:
json_fragment = (
', "' + self.current_param_name + '": "' +
json.dumps(param_value)[1:-1] + '"')
self.param_count += 1
return DeltaMessage(tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(
arguments=json_fragment),
)
])
# Continue parameter value
if self.in_param:
if self.parameter_end_token in delta_text:
# End of parameter
end_idx = delta_text.find(self.parameter_end_token)
value_chunk = delta_text[:end_idx]
# Skip past > if at start
if not self.current_param_value and ">" in value_chunk:
gt_idx = value_chunk.find(">")
value_chunk = value_chunk[gt_idx + 1:]
if not self.current_param_value and value_chunk.startswith(
"\n"):
value_chunk = value_chunk[1:]
# Calculate incremental JSON
full_value = self.current_param_value + value_chunk
prev_escaped = (json.dumps(self.current_param_value)[1:-1]
if self.current_param_value else "")
full_escaped = json.dumps(full_value)[1:-1]
delta_escaped = full_escaped[len(prev_escaped):]
self.in_param = False
self.current_param_value = ""
return DeltaMessage(tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(
arguments=delta_escaped + '"'),
)
])
else:
# Continue accumulating value
value_chunk = delta_text
# Handle first chunk after param name
if not self.current_param_value and ">" in value_chunk:
gt_idx = value_chunk.find(">")
value_chunk = value_chunk[gt_idx + 1:]
if not self.current_param_value and value_chunk.startswith(
"\n"):
value_chunk = value_chunk[1:]
if value_chunk:
# Stream the escaped delta
prev_escaped = (json.dumps(
self.current_param_value)[1:-1]
if self.current_param_value else "")
self.current_param_value += value_chunk
full_escaped = json.dumps(
self.current_param_value)[1:-1]
delta_escaped = full_escaped[len(prev_escaped):]
if delta_escaped:
return DeltaMessage(tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(
arguments=delta_escaped),
)
])
return None

View File

@@ -130,6 +130,7 @@ _TEXT_GENERATION_MODELS = {
"Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
"Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
"RWForCausalLM": ("falcon", "FalconForCausalLM"),
"SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
"Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),

View File

@@ -0,0 +1,487 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 The Seed team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only SeedOss model compatible with HuggingFace weights."""
from collections.abc import Iterable
from typing import Optional, Union
import torch
from torch import nn
from transformers import PretrainedConfig as SeedOssConfig
from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
logger = init_logger(__name__)
class SeedOssMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class SeedOssAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
head_dim: int,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
rope_scaling: Optional[tuple] = None,
prefix: str = "",
attn_type: str = AttentionType.DECODER,
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
self.head_dim = head_dim
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is at least the TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
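# e.g. 8 total KV heads with tp_size=4 gives 2 KV heads per rank
# (partitioned), while 2 total KV heads with tp_size=4 gives 1 KV head
# per rank (each KV head replicated across 2 ranks).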
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
base=self.rope_theta,
rope_scaling=rope_scaling,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
attn_type=attn_type,
prefix=f"{prefix}.attn",
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
class SeedOssDecoderLayer(nn.Module):
def __init__(
self,
config: SeedOssConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 1000000)
rope_scaling = getattr(config, "rope_scaling", None)
# By default, SeedOss uses causal attention as it is a
# decoder-only model.
# You can override the HF config with `is_causal=False` to enable
# bidirectional attention, which is used in some embedding models
if getattr(config, "is_causal", True):
attn_type = AttentionType.DECODER
else:
attn_type = AttentionType.ENCODER_ONLY
self.self_attn = SeedOssAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
head_dim=config.head_dim,
rope_theta=rope_theta,
cache_config=cache_config,
quant_config=quant_config,
rope_scaling=rope_scaling,
prefix=f"{prefix}.self_attn",
attn_type=attn_type,
)
self.mlp = SeedOssMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
@support_torch_compile(
dynamic_arg_dims={
"input_ids": 0,
"positions": -1,
"intermediate_tensors": 0,
"inputs_embeds": 0,
})
class SeedOssModel(nn.Module):
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
decoder_layer_type: type[nn.Module] = SeedOssDecoderLayer):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
# TODO (@robertgshaw2): see if this can be moved out
if (cache_config.sliding_window is not None
and hasattr(config, "max_window_layers")):
assert config.max_window_layers == config.num_hidden_layers, (
"Sliding window for some but all layers is not supported. "
"This model uses sliding window but `max_window_layers` = {} "
"is less than `num_hidden_layers` = {}. Please open an issue "
"to discuss this feature.".format(
config.max_window_layers,
config.num_hidden_layers,
))
self.config = config
self.quant_config = quant_config
self.vocab_size = config.vocab_size
if get_pp_group().is_first_rank or (config.tie_word_embeddings
and get_pp_group().is_last_rank):
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.embed_tokens",
)
else:
self.embed_tokens = PPMissingLayer()
# Use the provided decoder layer type or default to SeedOssDecoderLayer
decoder_layer_type = decoder_layer_type or SeedOssDecoderLayer
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: decoder_layer_type(config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix),
prefix=f"{prefix}.layers",
)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]:
hidden_states, residual = layer(
positions,
hidden_states,
residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class SeedOssForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.quant_config = quant_config
self.model = SeedOssModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
if get_pp_group().is_last_rank:
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(
prefix, "lm_head"))
else:
self.lm_head = PPMissingLayer()
self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)