mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-07 10:09:08 +08:00
[New Model] Add Seed-Oss model (#23241)
Signed-off-by: jiabin.00 <jiabin.00@bytedance.com> Signed-off-by: Jee Jee Li <pandaleefree@gmail.com> Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent
de9c085e17
commit
5964069367
@ -401,6 +401,7 @@ th {
|
|||||||
| `Qwen2MoeForCausalLM` | Qwen2MoE | `Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
| `Qwen2MoeForCausalLM` | Qwen2MoE | `Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||||
| `Qwen3ForCausalLM` | Qwen3 | `Qwen/Qwen3-8B`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
| `Qwen3ForCausalLM` | Qwen3 | `Qwen/Qwen3-8B`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||||
| `Qwen3MoeForCausalLM` | Qwen3MoE | `Qwen/Qwen3-30B-A3B`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
| `Qwen3MoeForCausalLM` | Qwen3MoE | `Qwen/Qwen3-30B-A3B`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||||
|
| `SeedOssForCausalLM` | SeedOss | `ByteDance-Seed/Seed-OSS-36B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||||
| `StableLmForCausalLM` | StableLM | `stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc. | | | ✅︎ |
|
| `StableLmForCausalLM` | StableLM | `stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc. | | | ✅︎ |
|
||||||
| `Starcoder2ForCausalLM` | Starcoder2 | `bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc. | | ✅︎ | ✅︎ |
|
| `Starcoder2ForCausalLM` | Starcoder2 | `bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc. | | ✅︎ | ✅︎ |
|
||||||
| `SolarForCausalLM` | Solar Pro | `upstage/solar-pro-preview-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
| `SolarForCausalLM` | Solar Pro | `upstage/solar-pro-preview-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||||
|
|||||||
@ -292,6 +292,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
|||||||
"Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"),
|
"Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"),
|
||||||
"Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"),
|
"Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"),
|
||||||
"RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"),
|
"RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"),
|
||||||
|
"SeedOssForCausalLM": _HfExamplesInfo("ByteDance-Seed/Seed-OSS-36B-Instruct", # noqa: E501
|
||||||
|
trust_remote_code=True,
|
||||||
|
is_available_online=False),
|
||||||
"SmolLM3ForCausalLM": _HfExamplesInfo("HuggingFaceTB/SmolLM3-3B"),
|
"SmolLM3ForCausalLM": _HfExamplesInfo("HuggingFaceTB/SmolLM3-3B"),
|
||||||
"StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"), # noqa: E501
|
"StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"), # noqa: E501
|
||||||
"StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
|
"StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
|
||||||
|
|||||||
459
tests/tool_use/test_seed_oss_tool_parser.py
Normal file
459
tests/tool_use/test_seed_oss_tool_parser.py
Normal file
@ -0,0 +1,459 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
# ruff: noqa: E501
|
||||||
|
|
||||||
|
import json
|
||||||
|
from collections.abc import Generator
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||||
|
ChatCompletionToolsParam,
|
||||||
|
DeltaMessage, FunctionCall,
|
||||||
|
ToolCall)
|
||||||
|
from vllm.entrypoints.openai.tool_parsers import SeedOssToolParser
|
||||||
|
from vllm.transformers_utils.detokenizer import detokenize_incrementally
|
||||||
|
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
|
||||||
|
|
||||||
|
# Use a common model that is likely to be available
|
||||||
|
MODEL = "ByteDance-Seed/Seed-OSS-36B-Instruct"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def seed_oss_tokenizer():
|
||||||
|
return get_tokenizer(tokenizer_name=MODEL, trust_remote_code=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def seed_oss_tool_parser(seed_oss_tokenizer):
|
||||||
|
return SeedOssToolParser(seed_oss_tokenizer)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_tools():
|
||||||
|
return [
|
||||||
|
ChatCompletionToolsParam(
|
||||||
|
type="function",
|
||||||
|
function={
|
||||||
|
"name": "get_weather",
|
||||||
|
"description": "Get current temperature for a given location.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description":
|
||||||
|
"City and country e.g. Bogotá, Colombia"
|
||||||
|
},
|
||||||
|
"unit": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "this is the unit of temperature"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["location"],
|
||||||
|
"additionalProperties": False
|
||||||
|
},
|
||||||
|
"returns": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"temperature": {
|
||||||
|
"type": "number",
|
||||||
|
"description": "temperature in celsius"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["temperature"],
|
||||||
|
"additionalProperties": False
|
||||||
|
},
|
||||||
|
"strict": True
|
||||||
|
}),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def assert_tool_calls(actual_tool_calls: list[ToolCall],
|
||||||
|
expected_tool_calls: list[ToolCall]):
|
||||||
|
assert len(actual_tool_calls) == len(expected_tool_calls)
|
||||||
|
|
||||||
|
for actual_tool_call, expected_tool_call in zip(actual_tool_calls,
|
||||||
|
expected_tool_calls):
|
||||||
|
# Seed-OSS tool call will not generate id
|
||||||
|
assert actual_tool_call.type == "function"
|
||||||
|
assert actual_tool_call.function == expected_tool_call.function
|
||||||
|
|
||||||
|
assert actual_tool_call.function.name == expected_tool_call.function.name
|
||||||
|
assert actual_tool_call.function.arguments == expected_tool_call.function.arguments
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tool_calls_no_tools(seed_oss_tool_parser):
|
||||||
|
model_output = "This is a test response without any tool calls"
|
||||||
|
extracted_tool_calls = seed_oss_tool_parser.extract_tool_calls(
|
||||||
|
model_output, request=None) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
assert not extracted_tool_calls.tools_called
|
||||||
|
assert extracted_tool_calls.tool_calls == []
|
||||||
|
assert extracted_tool_calls.content == model_output
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
ids=[
|
||||||
|
"tool_call_0_thinking_budget",
|
||||||
|
"tool_call_512_thinkg_budget",
|
||||||
|
"tool_call_unlimited_thinking_budget",
|
||||||
|
],
|
||||||
|
argnames=["model_output", "expected_tool_calls", "expected_content"],
|
||||||
|
argvalues=[
|
||||||
|
("""<seed:think>\n</seed:cot_budget_reflect>\n</seed:cot_budget_reflect>\n"""
|
||||||
|
"""The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\n"""
|
||||||
|
"""<seed:tool_call>\n<function=get_weather>\n"""
|
||||||
|
"""<parameter=location>Barcelona, Spain</parameter>\n</function>\n</seed:tool_call>""",
|
||||||
|
[
|
||||||
|
ToolCall(function=FunctionCall(
|
||||||
|
name="get_weather",
|
||||||
|
arguments=json.dumps({
|
||||||
|
"location": "Barcelona, Spain",
|
||||||
|
}, ),
|
||||||
|
),
|
||||||
|
type='function')
|
||||||
|
],
|
||||||
|
"""<seed:think>\n</seed:cot_budget_reflect>\n</seed:cot_budget_reflect>\n"""
|
||||||
|
"""The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\n"""
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"""<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """
|
||||||
|
"""question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """
|
||||||
|
"""there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """
|
||||||
|
"""check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """
|
||||||
|
"""optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """
|
||||||
|
"""country). \n<seed:cot_budget_reflect>I have used 131 tokens, and there are 381 tokens remaining for use."""
|
||||||
|
"""</seed:cot_budget_reflect>\n Since the unit isn\'t specified, the function will default to Celsius, which """
|
||||||
|
"""is fine. \n\nThere\'s no need to ask for more information because the location is clear. So I should call """
|
||||||
|
"""the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """
|
||||||
|
"""user\'s input has a space, but the function might accept either; to be safe, using the standard format """
|
||||||
|
"""with a comma).\n<seed:cot_budget_reflect>I have used 257 tokens, and there are 255 tokens remaining for """
|
||||||
|
"""use.</seed:cot_budget_reflect>\n The unit parameter can be omitted since it\'s optional.</seed:think>\n"""
|
||||||
|
"""<seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, Spain</parameter>\n</function>"""
|
||||||
|
"""\n</seed:tool_call>""",
|
||||||
|
[
|
||||||
|
ToolCall(function=FunctionCall(
|
||||||
|
name="get_weather",
|
||||||
|
arguments=json.dumps({
|
||||||
|
"location": "Barcelona, Spain",
|
||||||
|
}, ),
|
||||||
|
),
|
||||||
|
type='function')
|
||||||
|
],
|
||||||
|
"""<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """
|
||||||
|
"""question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """
|
||||||
|
"""there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """
|
||||||
|
"""check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """
|
||||||
|
"""optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """
|
||||||
|
"""country). \n<seed:cot_budget_reflect>I have used 131 tokens, and there are 381 tokens remaining for use."""
|
||||||
|
"""</seed:cot_budget_reflect>\n Since the unit isn\'t specified, the function will default to Celsius, which """
|
||||||
|
"""is fine. \n\nThere\'s no need to ask for more information because the location is clear. So I should call """
|
||||||
|
"""the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """
|
||||||
|
"""user\'s input has a space, but the function might accept either; to be safe, using the standard format """
|
||||||
|
"""with a comma).\n<seed:cot_budget_reflect>I have used 257 tokens, and there are 255 tokens remaining for """
|
||||||
|
"""use.</seed:cot_budget_reflect>\n The unit parameter can be omitted since it\'s optional.</seed:think>\n""",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"""<seed:think>\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """
|
||||||
|
"""First, I need to remember the function I can use: get_weather. The function requires a """
|
||||||
|
"""location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """
|
||||||
|
"""the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """
|
||||||
|
"""let me check the function docstring again. Oh, the function says unit is optional, and """
|
||||||
|
"""returns temperature in Celsius. So I should call get_weather with location "Barcelona, """
|
||||||
|
"""Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """
|
||||||
|
"""The format is <seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, """
|
||||||
|
"""Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>. """
|
||||||
|
"""Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """
|
||||||
|
"""of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """
|
||||||
|
"""it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """
|
||||||
|
"""call should be as above. Then wait for the result to come back and tell the user the """
|
||||||
|
"""temperature in Celsius.</seed:think><seed:tool_call>\n<function=get_weather>\n<parameter=location>"""
|
||||||
|
"""Barcelona, Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>""",
|
||||||
|
[
|
||||||
|
ToolCall(function=FunctionCall(
|
||||||
|
name="get_weather",
|
||||||
|
arguments=json.dumps(
|
||||||
|
{
|
||||||
|
"location": "Barcelona, Spain",
|
||||||
|
"unit": "celsius",
|
||||||
|
}, ),
|
||||||
|
),
|
||||||
|
type='function')
|
||||||
|
],
|
||||||
|
"""<seed:think>\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """
|
||||||
|
"""First, I need to remember the function I can use: get_weather. The function requires a """
|
||||||
|
"""location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """
|
||||||
|
"""the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """
|
||||||
|
"""let me check the function docstring again. Oh, the function says unit is optional, and """
|
||||||
|
"""returns temperature in Celsius. So I should call get_weather with location "Barcelona, """
|
||||||
|
"""Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """
|
||||||
|
"""The format is <seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, """
|
||||||
|
"""Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>. """
|
||||||
|
"""Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """
|
||||||
|
"""of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """
|
||||||
|
"""it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """
|
||||||
|
"""call should be as above. Then wait for the result to come back and tell the user the """
|
||||||
|
"""temperature in Celsius.</seed:think>""",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_extract_tool_calls(seed_oss_tool_parser, sample_tools, model_output,
|
||||||
|
expected_tool_calls, expected_content):
|
||||||
|
request = ChatCompletionRequest(model=MODEL,
|
||||||
|
messages=[],
|
||||||
|
tools=sample_tools)
|
||||||
|
extracted_tool_calls = seed_oss_tool_parser.extract_tool_calls(
|
||||||
|
model_output, request=request) # type: ignore[arg-type]
|
||||||
|
assert extracted_tool_calls.tools_called
|
||||||
|
|
||||||
|
assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
|
||||||
|
|
||||||
|
assert extracted_tool_calls.content == expected_content
|
||||||
|
|
||||||
|
|
||||||
|
def test_streaming_tool_calls_no_tools(seed_oss_tool_parser):
|
||||||
|
model_output = "This is a test response without any tool calls"
|
||||||
|
|
||||||
|
result = seed_oss_tool_parser.extract_tool_calls_streaming(
|
||||||
|
previous_text="his is a test response",
|
||||||
|
current_text=model_output,
|
||||||
|
delta_text=" without any tool calls.",
|
||||||
|
previous_token_ids=[],
|
||||||
|
current_token_ids=[],
|
||||||
|
delta_token_ids=[],
|
||||||
|
request=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should return the delta text as content
|
||||||
|
assert result is not None
|
||||||
|
assert hasattr(result, 'content')
|
||||||
|
assert result.content == " without any tool calls."
|
||||||
|
|
||||||
|
|
||||||
|
def stream_delta_message_generator(
|
||||||
|
seed_oss_tool_parser: SeedOssToolParser,
|
||||||
|
seed_oss_tokenizer: AnyTokenizer,
|
||||||
|
model_output: str,
|
||||||
|
request: Optional[ChatCompletionRequest] = None
|
||||||
|
) -> Generator[DeltaMessage, None, None]:
|
||||||
|
all_token_ids = seed_oss_tokenizer.encode(model_output,
|
||||||
|
add_special_tokens=False)
|
||||||
|
|
||||||
|
previous_text = ""
|
||||||
|
previous_tokens = None
|
||||||
|
prefix_offset = 0
|
||||||
|
read_offset = 0
|
||||||
|
for i, delta_token in enumerate(all_token_ids):
|
||||||
|
delta_token_ids = [delta_token]
|
||||||
|
previous_token_ids = all_token_ids[:i]
|
||||||
|
current_token_ids = all_token_ids[:i + 1]
|
||||||
|
|
||||||
|
(new_tokens, delta_text, new_prefix_offset,
|
||||||
|
new_read_offset) = detokenize_incrementally(
|
||||||
|
tokenizer=seed_oss_tokenizer,
|
||||||
|
all_input_ids=current_token_ids,
|
||||||
|
prev_tokens=previous_tokens,
|
||||||
|
prefix_offset=prefix_offset,
|
||||||
|
read_offset=read_offset,
|
||||||
|
skip_special_tokens=False,
|
||||||
|
spaces_between_special_tokens=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
current_text = previous_text + delta_text
|
||||||
|
|
||||||
|
delta_message = seed_oss_tool_parser.extract_tool_calls_streaming(
|
||||||
|
previous_text,
|
||||||
|
current_text,
|
||||||
|
delta_text,
|
||||||
|
previous_token_ids,
|
||||||
|
current_token_ids,
|
||||||
|
delta_token_ids,
|
||||||
|
request=request,
|
||||||
|
)
|
||||||
|
if delta_message:
|
||||||
|
yield delta_message
|
||||||
|
|
||||||
|
previous_text = current_text
|
||||||
|
previous_tokens = (previous_tokens +
|
||||||
|
new_tokens if previous_tokens else new_tokens)
|
||||||
|
prefix_offset = new_prefix_offset
|
||||||
|
read_offset = new_read_offset
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
ids=[
|
||||||
|
"tool_call_0_thinking_budget",
|
||||||
|
"tool_call_512_thinkg_budget",
|
||||||
|
"tool_call_unlimited_thinking_budget",
|
||||||
|
],
|
||||||
|
argnames=["model_output", "expected_tool_calls", "expected_content"],
|
||||||
|
argvalues=[
|
||||||
|
("""<seed:think>\n</seed:cot_budget_reflect>\n</seed:cot_budget_reflect>\n"""
|
||||||
|
"""The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\n"""
|
||||||
|
"""<seed:tool_call>\n<function=get_weather>\n"""
|
||||||
|
"""<parameter=location>Barcelona, Spain</parameter>\n</function>\n</seed:tool_call>""",
|
||||||
|
[
|
||||||
|
ToolCall(function=FunctionCall(
|
||||||
|
name="get_weather",
|
||||||
|
arguments=json.dumps({
|
||||||
|
"location": "Barcelona, Spain",
|
||||||
|
}, ),
|
||||||
|
),
|
||||||
|
type='function')
|
||||||
|
],
|
||||||
|
"""<seed:think>\n</seed:cot_budget_reflect>\n</seed:cot_budget_reflect>\n"""
|
||||||
|
"""The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\n"""
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"""<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """
|
||||||
|
"""question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """
|
||||||
|
"""there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """
|
||||||
|
"""check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """
|
||||||
|
"""optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """
|
||||||
|
"""country). \n<seed:cot_budget_reflect>I have used 131 tokens, and there are 381 tokens remaining for use."""
|
||||||
|
"""</seed:cot_budget_reflect>\n Since the unit isn\'t specified, the function will default to Celsius, which """
|
||||||
|
"""is fine. \n\nThere\'s no need to ask for more information because the location is clear. So I should call """
|
||||||
|
"""the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """
|
||||||
|
"""user\'s input has a space, but the function might accept either; to be safe, using the standard format """
|
||||||
|
"""with a comma).\n<seed:cot_budget_reflect>I have used 257 tokens, and there are 255 tokens remaining for """
|
||||||
|
"""use.</seed:cot_budget_reflect>\n The unit parameter can be omitted since it\'s optional.</seed:think>\n"""
|
||||||
|
"""<seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, Spain</parameter>\n</function>"""
|
||||||
|
"""\n</seed:tool_call>""",
|
||||||
|
[
|
||||||
|
ToolCall(function=FunctionCall(
|
||||||
|
name="get_weather",
|
||||||
|
arguments=json.dumps({
|
||||||
|
"location": "Barcelona, Spain",
|
||||||
|
}, ),
|
||||||
|
),
|
||||||
|
type='function')
|
||||||
|
],
|
||||||
|
"""<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """
|
||||||
|
"""question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """
|
||||||
|
"""there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """
|
||||||
|
"""check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """
|
||||||
|
"""optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """
|
||||||
|
"""country). \n<seed:cot_budget_reflect>I have used 131 tokens, and there are 381 tokens remaining for use."""
|
||||||
|
"""</seed:cot_budget_reflect>\n Since the unit isn\'t specified, the function will default to Celsius, which """
|
||||||
|
"""is fine. \n\nThere\'s no need to ask for more information because the location is clear. So I should call """
|
||||||
|
"""the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """
|
||||||
|
"""user\'s input has a space, but the function might accept either; to be safe, using the standard format """
|
||||||
|
"""with a comma).\n<seed:cot_budget_reflect>I have used 257 tokens, and there are 255 tokens remaining for """
|
||||||
|
"""use.</seed:cot_budget_reflect>\n The unit parameter can be omitted since it\'s optional.</seed:think>\n""",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"""<seed:think>\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """
|
||||||
|
"""First, I need to remember the function I can use: get_weather. The function requires a """
|
||||||
|
"""location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """
|
||||||
|
"""the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """
|
||||||
|
"""let me check the function docstring again. Oh, the function says unit is optional, and """
|
||||||
|
"""returns temperature in Celsius. So I should call get_weather with location "Barcelona, """
|
||||||
|
"""Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """
|
||||||
|
"""The format is <seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, """
|
||||||
|
"""Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>. """
|
||||||
|
"""Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """
|
||||||
|
"""of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """
|
||||||
|
"""it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """
|
||||||
|
"""call should be as above. Then wait for the result to come back and tell the user the """
|
||||||
|
"""temperature in Celsius.</seed:think><seed:tool_call>\n<function=get_weather>\n<parameter=location>"""
|
||||||
|
"""Barcelona, Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>""",
|
||||||
|
[
|
||||||
|
ToolCall(function=FunctionCall(
|
||||||
|
name="get_weather",
|
||||||
|
arguments=json.dumps(
|
||||||
|
{
|
||||||
|
"location": "Barcelona, Spain",
|
||||||
|
"unit": "celsius",
|
||||||
|
}, ),
|
||||||
|
),
|
||||||
|
type='function')
|
||||||
|
],
|
||||||
|
"""<seed:think>\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """
|
||||||
|
"""First, I need to remember the function I can use: get_weather. The function requires a """
|
||||||
|
"""location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """
|
||||||
|
"""the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """
|
||||||
|
"""let me check the function docstring again. Oh, the function says unit is optional, and """
|
||||||
|
"""returns temperature in Celsius. So I should call get_weather with location "Barcelona, """
|
||||||
|
"""Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """
|
||||||
|
"""The format is <seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, """
|
||||||
|
"""Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>. """
|
||||||
|
"""Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """
|
||||||
|
"""of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """
|
||||||
|
"""it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """
|
||||||
|
"""call should be as above. Then wait for the result to come back and tell the user the """
|
||||||
|
"""temperature in Celsius.</seed:think>""",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_streaming_tool_calls(seed_oss_tool_parser, seed_oss_tokenizer,
|
||||||
|
sample_tools, model_output, expected_tool_calls,
|
||||||
|
expected_content):
|
||||||
|
"""Test incremental streaming behavior"""
|
||||||
|
request = ChatCompletionRequest(model=MODEL,
|
||||||
|
messages=[],
|
||||||
|
tools=sample_tools)
|
||||||
|
|
||||||
|
other_content = ''
|
||||||
|
tool_states = {} # Track state per tool index
|
||||||
|
|
||||||
|
for delta_message in stream_delta_message_generator(
|
||||||
|
seed_oss_tool_parser, seed_oss_tokenizer, model_output, request):
|
||||||
|
# role should never be streamed from tool parser
|
||||||
|
assert not delta_message.role
|
||||||
|
|
||||||
|
if delta_message.content:
|
||||||
|
other_content += delta_message.content
|
||||||
|
|
||||||
|
if delta_message.tool_calls:
|
||||||
|
for tool_call in delta_message.tool_calls:
|
||||||
|
idx = tool_call.index
|
||||||
|
|
||||||
|
# Initialize state for new tool
|
||||||
|
if idx not in tool_states:
|
||||||
|
tool_states[idx] = {
|
||||||
|
"id": None,
|
||||||
|
"name": None,
|
||||||
|
"arguments": "",
|
||||||
|
"type": None
|
||||||
|
}
|
||||||
|
|
||||||
|
# First chunk should have id, name, and type
|
||||||
|
if tool_call.id:
|
||||||
|
tool_states[idx]["id"] = tool_call.id
|
||||||
|
|
||||||
|
if tool_call.type:
|
||||||
|
assert tool_call.type == "function"
|
||||||
|
tool_states[idx]["type"] = tool_call.type
|
||||||
|
|
||||||
|
if tool_call.function:
|
||||||
|
if tool_call.function.name:
|
||||||
|
# Should only be set once
|
||||||
|
assert tool_states[idx]["name"] is None
|
||||||
|
tool_states[idx]["name"] = tool_call.function.name
|
||||||
|
|
||||||
|
if tool_call.function.arguments is not None:
|
||||||
|
# Accumulate arguments incrementally
|
||||||
|
tool_states[idx][
|
||||||
|
"arguments"] += tool_call.function.arguments
|
||||||
|
|
||||||
|
# Verify final content
|
||||||
|
assert other_content == expected_content
|
||||||
|
|
||||||
|
# Verify we got all expected tool calls
|
||||||
|
assert len(tool_states) == len(expected_tool_calls)
|
||||||
|
|
||||||
|
# Verify each tool call
|
||||||
|
for idx, expected_tool in enumerate(expected_tool_calls):
|
||||||
|
state = tool_states[idx]
|
||||||
|
assert state["id"] is not None
|
||||||
|
assert state["type"] == "function"
|
||||||
|
assert state["name"] == expected_tool.function.name
|
||||||
|
|
||||||
|
# Parse accumulated arguments
|
||||||
|
arguments_str = state["arguments"]
|
||||||
|
assert arguments_str is not None
|
||||||
|
actual_args = json.loads(arguments_str)
|
||||||
|
expected_args = json.loads(expected_tool.function.arguments)
|
||||||
|
assert actual_args == expected_args
|
||||||
@ -18,6 +18,7 @@ from .mistral_tool_parser import MistralToolParser
|
|||||||
from .phi4mini_tool_parser import Phi4MiniJsonToolParser
|
from .phi4mini_tool_parser import Phi4MiniJsonToolParser
|
||||||
from .pythonic_tool_parser import PythonicToolParser
|
from .pythonic_tool_parser import PythonicToolParser
|
||||||
from .qwen3coder_tool_parser import Qwen3CoderToolParser
|
from .qwen3coder_tool_parser import Qwen3CoderToolParser
|
||||||
|
from .seed_oss_tool_parser import SeedOssToolParser
|
||||||
from .step3_tool_parser import Step3ToolParser
|
from .step3_tool_parser import Step3ToolParser
|
||||||
from .xlam_tool_parser import xLAMToolParser
|
from .xlam_tool_parser import xLAMToolParser
|
||||||
|
|
||||||
@ -41,5 +42,6 @@ __all__ = [
|
|||||||
"HunyuanA13BToolParser",
|
"HunyuanA13BToolParser",
|
||||||
"Glm4MoeModelToolParser",
|
"Glm4MoeModelToolParser",
|
||||||
"Qwen3CoderToolParser",
|
"Qwen3CoderToolParser",
|
||||||
|
"SeedOssToolParser",
|
||||||
"Step3ToolParser",
|
"Step3ToolParser",
|
||||||
]
|
]
|
||||||
|
|||||||
676
vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py
Normal file
676
vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py
Normal file
@ -0,0 +1,676 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
# Adapted from qwen3coder xml parser, All rights reserved.
|
||||||
|
# ruff: noqa: E501
|
||||||
|
|
||||||
|
import ast
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
|
import regex as re
|
||||||
|
|
||||||
|
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||||
|
ChatCompletionToolsParam,
|
||||||
|
DeltaFunctionCall, DeltaMessage,
|
||||||
|
DeltaToolCall,
|
||||||
|
ExtractedToolCallInformation,
|
||||||
|
FunctionCall, ToolCall)
|
||||||
|
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||||
|
ToolParser, ToolParserManager)
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ToolParserManager.register_module("seed_oss")
|
||||||
|
class SeedOssToolParser(ToolParser):
|
||||||
|
TOOL_CALL_START = "<seed:tool_call>"
|
||||||
|
TOOL_CALL_END = "</seed:tool_call>"
|
||||||
|
|
||||||
|
def __init__(self, tokenizer: AnyTokenizer):
|
||||||
|
super().__init__(tokenizer)
|
||||||
|
|
||||||
|
# --- streaming state ---
|
||||||
|
self._reset_streaming_state()
|
||||||
|
self.prev_tool_call_arr: list[dict] = []
|
||||||
|
|
||||||
|
self.tool_call_start_token: str = self.TOOL_CALL_START
|
||||||
|
self.tool_call_end_token: str = self.TOOL_CALL_END
|
||||||
|
# Sentinel tokens for streaming mode
|
||||||
|
self.tool_call_prefix: str = "<function="
|
||||||
|
self.function_end_token: str = "</function>"
|
||||||
|
self.parameter_prefix: str = "<parameter="
|
||||||
|
self.parameter_end_token: str = "</parameter>"
|
||||||
|
self.think_start_token: str = "<seed:think>"
|
||||||
|
self.think_end_token: str = "</seed:think>"
|
||||||
|
self.is_tool_call_started: bool = False
|
||||||
|
self.is_thinking_end: bool = False
|
||||||
|
self.failed_count: int = 0
|
||||||
|
self._reset_streaming_state()
|
||||||
|
|
||||||
|
self.tool_call_start_token_id = self.vocab.get(
|
||||||
|
self.tool_call_start_token)
|
||||||
|
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
|
||||||
|
self.think_end_token_id = self.vocab.get(self.think_end_token)
|
||||||
|
|
||||||
|
if (self.tool_call_start_token_id is None
|
||||||
|
or self.tool_call_end_token_id is None):
|
||||||
|
raise RuntimeError(
|
||||||
|
"Seed_Oss XML parser: tokenizer did not include "
|
||||||
|
"<seed:tool_call> or its closing tag.")
|
||||||
|
|
||||||
|
tool_start_re = re.escape(self.tool_call_start_token)
|
||||||
|
tool_end_re = re.escape(self.tool_call_end_token)
|
||||||
|
|
||||||
|
self.tool_call_complete_regex = re.compile(
|
||||||
|
rf"{tool_start_re}(.*?){tool_end_re}", re.DOTALL)
|
||||||
|
self.tool_call_regex = re.compile(
|
||||||
|
rf"{tool_start_re}(.*?){tool_end_re}|{tool_start_re}(.*?)$",
|
||||||
|
re.DOTALL)
|
||||||
|
|
||||||
|
self.tool_call_function_regex = re.compile(
|
||||||
|
r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL)
|
||||||
|
self.tool_call_parameter_regex = re.compile(
|
||||||
|
r"<parameter=(.*?)</parameter>|<parameter=(.*?)$", re.DOTALL)
|
||||||
|
|
||||||
|
logger.info("vLLM Seed-Oss XML tool parser loaded (%s).",
|
||||||
|
self.__class__.__name__)
|
||||||
|
|
||||||
|
def _generate_tool_call_id(self) -> str:
|
||||||
|
"""Generate a unique tool call ID."""
|
||||||
|
return f"call_{uuid.uuid4().hex[:24]}"
|
||||||
|
|
||||||
|
def _reset_streaming_state(self):
|
||||||
|
"""Reset all streaming state."""
|
||||||
|
self.current_tool_index = 0
|
||||||
|
self.is_tool_call_started = False
|
||||||
|
self.header_sent = False
|
||||||
|
self.current_tool_id = -1
|
||||||
|
self.current_function_name = None
|
||||||
|
self.current_param_name = None
|
||||||
|
self.current_param_value = ""
|
||||||
|
self.param_count = 0
|
||||||
|
self.in_param = False
|
||||||
|
self.in_function = False
|
||||||
|
self.accumulated_text = ""
|
||||||
|
self.json_started = False
|
||||||
|
self.json_closed = False
|
||||||
|
|
||||||
|
def _parse_xml_function_call(
|
||||||
|
self, function_call_str: str,
|
||||||
|
tools: Optional[list[ChatCompletionToolsParam]]
|
||||||
|
) -> Optional[ToolCall]:
|
||||||
|
|
||||||
|
def get_arguments_config(func_name: str) -> dict:
|
||||||
|
if tools is None:
|
||||||
|
return {}
|
||||||
|
for config in tools:
|
||||||
|
if not hasattr(config, "type") or not (
|
||||||
|
hasattr(config, "function")
|
||||||
|
and hasattr(config.function, "name")):
|
||||||
|
continue
|
||||||
|
if (config.type == "function"
|
||||||
|
and config.function.name == func_name):
|
||||||
|
if not hasattr(config.function, "parameters"):
|
||||||
|
return {}
|
||||||
|
params = config.function.parameters
|
||||||
|
if isinstance(params, dict) and "properties" in params:
|
||||||
|
return params["properties"]
|
||||||
|
elif isinstance(params, dict):
|
||||||
|
return params
|
||||||
|
else:
|
||||||
|
return {}
|
||||||
|
logger.warning("Tool '%s' is not defined in the tools list.",
|
||||||
|
func_name)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def convert_param_value(param_value: str, param_name: str,
|
||||||
|
param_config: dict, func_name: str) -> Any:
|
||||||
|
# Handle null value for any type
|
||||||
|
if param_value.lower() == "null":
|
||||||
|
return None
|
||||||
|
|
||||||
|
if param_name not in param_config:
|
||||||
|
if param_config != {}:
|
||||||
|
logger.warning(
|
||||||
|
"Parsed parameter '%s' is not defined in "
|
||||||
|
"the tool parameters for tool '%s', "
|
||||||
|
"directly returning the string value.", param_name,
|
||||||
|
func_name)
|
||||||
|
return param_value
|
||||||
|
|
||||||
|
if (isinstance(param_config[param_name], dict)
|
||||||
|
and "type" in param_config[param_name]):
|
||||||
|
param_type = str(
|
||||||
|
param_config[param_name]["type"]).strip().lower()
|
||||||
|
else:
|
||||||
|
param_type = "string"
|
||||||
|
if param_type in [
|
||||||
|
"string", "str", "text", "varchar", "char", "enum"
|
||||||
|
]:
|
||||||
|
return param_value
|
||||||
|
elif (param_type.startswith("int") or param_type.startswith("uint")
|
||||||
|
or param_type.startswith("long")
|
||||||
|
or param_type.startswith("short")
|
||||||
|
or param_type.startswith("unsigned")):
|
||||||
|
try:
|
||||||
|
param_value = int(param_value) # type: ignore
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
logger.warning(
|
||||||
|
"Parsed value '%s' of parameter '%s' is not an integer in tool "
|
||||||
|
"'%s', degenerating to string.", param_value,
|
||||||
|
param_name, func_name)
|
||||||
|
return param_value
|
||||||
|
elif param_type.startswith("num") or param_type.startswith(
|
||||||
|
"float"):
|
||||||
|
try:
|
||||||
|
float_param_value = float(param_value)
|
||||||
|
param_value = float_param_value if float_param_value - int(
|
||||||
|
float_param_value) != 0 else int(
|
||||||
|
float_param_value) # type: ignore
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
logger.warning(
|
||||||
|
"Parsed value '%s' of parameter '%s' is not a float in tool "
|
||||||
|
"'%s', degenerating to string.", param_value,
|
||||||
|
param_name, func_name)
|
||||||
|
return param_value
|
||||||
|
elif param_type in ["boolean", "bool", "binary"]:
|
||||||
|
param_value = param_value.lower()
|
||||||
|
if param_value not in ["true", "false"]:
|
||||||
|
logger.warning(
|
||||||
|
"Parsed value '%s' of parameter '%s' is not a boolean "
|
||||||
|
"(`true` of `false`) in tool '%s', degenerating to false.",
|
||||||
|
param_value, param_name, func_name)
|
||||||
|
return param_value == "true"
|
||||||
|
else:
|
||||||
|
if param_type == "object" or param_type.startswith("dict"):
|
||||||
|
try:
|
||||||
|
param_value = json.loads(param_value)
|
||||||
|
return param_value
|
||||||
|
except (ValueError, TypeError, json.JSONDecodeError):
|
||||||
|
logger.warning(
|
||||||
|
"Parsed value '%s' of parameter '%s' is not a valid JSON "
|
||||||
|
"object in tool '%s', will try other methods to parse it.",
|
||||||
|
param_value, param_name, func_name)
|
||||||
|
try:
|
||||||
|
param_value = ast.literal_eval(param_value)
|
||||||
|
except (ValueError, SyntaxError):
|
||||||
|
logger.warning(
|
||||||
|
"Parsed value '%s' of parameter '%s' cannot be converted via "
|
||||||
|
"Python `ast.literal_eval()` in tool '%s', degenerating to string.",
|
||||||
|
param_value, param_name, func_name)
|
||||||
|
return param_value
|
||||||
|
|
||||||
|
# Extract function name
|
||||||
|
end_index = function_call_str.index(">")
|
||||||
|
function_name = function_call_str[:end_index]
|
||||||
|
param_config = get_arguments_config(function_name)
|
||||||
|
parameters = function_call_str[end_index + 1:]
|
||||||
|
param_dict = {}
|
||||||
|
for match in self.tool_call_parameter_regex.findall(parameters):
|
||||||
|
match_text = match[0] if match[0] else match[1]
|
||||||
|
idx = match_text.index(">")
|
||||||
|
param_name = match_text[:idx]
|
||||||
|
param_value = str(match_text[idx + 1:])
|
||||||
|
# Remove prefix and trailing \n
|
||||||
|
if param_value.startswith("\n"):
|
||||||
|
param_value = param_value[1:]
|
||||||
|
if param_value.endswith("\n"):
|
||||||
|
param_value = param_value[:-1]
|
||||||
|
|
||||||
|
param_dict[param_name] = convert_param_value(
|
||||||
|
param_value, param_name, param_config, function_name)
|
||||||
|
return ToolCall(
|
||||||
|
type="function",
|
||||||
|
function=FunctionCall(name=function_name,
|
||||||
|
arguments=json.dumps(param_dict,
|
||||||
|
ensure_ascii=False)),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_function_calls(self, model_output: str) -> list[str]:
|
||||||
|
# Find all tool calls
|
||||||
|
matched_ranges = self.tool_call_regex.findall(model_output)
|
||||||
|
raw_tool_calls = [
|
||||||
|
match[0] if match[0] else match[1] for match in matched_ranges
|
||||||
|
]
|
||||||
|
|
||||||
|
# Back-off strategy if no tool_call tags found
|
||||||
|
if len(raw_tool_calls) == 0:
|
||||||
|
raw_tool_calls = [model_output]
|
||||||
|
|
||||||
|
raw_function_calls = []
|
||||||
|
for tool_call in raw_tool_calls:
|
||||||
|
raw_function_calls.extend(
|
||||||
|
self.tool_call_function_regex.findall(tool_call))
|
||||||
|
|
||||||
|
function_calls = [
|
||||||
|
match[0] if match[0] else match[1] for match in raw_function_calls
|
||||||
|
]
|
||||||
|
return function_calls
|
||||||
|
|
||||||
|
def extract_tool_calls(
|
||||||
|
self,
|
||||||
|
model_output: str,
|
||||||
|
request: ChatCompletionRequest,
|
||||||
|
) -> ExtractedToolCallInformation:
|
||||||
|
# Quick check to avoid unnecessary processing
|
||||||
|
if self.tool_call_prefix not in model_output:
|
||||||
|
return ExtractedToolCallInformation(tools_called=False,
|
||||||
|
tool_calls=[],
|
||||||
|
content=model_output)
|
||||||
|
|
||||||
|
# Check if both think start and end tokens are present
|
||||||
|
if (self.think_start_token in model_output
|
||||||
|
and self.think_end_token in model_output):
|
||||||
|
# Find the position of think end token
|
||||||
|
think_end_index = model_output.find(self.think_end_token) + len(
|
||||||
|
self.think_end_token)
|
||||||
|
# Extract content after think end token
|
||||||
|
result_content = model_output[think_end_index:]
|
||||||
|
thinking_content = model_output[:think_end_index]
|
||||||
|
|
||||||
|
try:
|
||||||
|
function_calls = self._get_function_calls(result_content)
|
||||||
|
if len(function_calls) == 0:
|
||||||
|
return ExtractedToolCallInformation(tools_called=False,
|
||||||
|
tool_calls=[],
|
||||||
|
content=model_output)
|
||||||
|
|
||||||
|
tool_calls = [
|
||||||
|
self._parse_xml_function_call(function_call_str, request.tools)
|
||||||
|
for function_call_str in function_calls
|
||||||
|
]
|
||||||
|
|
||||||
|
# Populate prev_tool_call_arr for serving layer to set finish_reason
|
||||||
|
self.prev_tool_call_arr.clear() # Clear previous calls
|
||||||
|
for tool_call in tool_calls:
|
||||||
|
if tool_call:
|
||||||
|
self.prev_tool_call_arr.append({
|
||||||
|
"name":
|
||||||
|
tool_call.function.name,
|
||||||
|
"arguments":
|
||||||
|
tool_call.function.arguments,
|
||||||
|
})
|
||||||
|
|
||||||
|
# Extract content before tool calls
|
||||||
|
tool_call_start_index = result_content.find(
|
||||||
|
self.tool_call_start_token)
|
||||||
|
tool_call_start_index = (
|
||||||
|
tool_call_start_index if tool_call_start_index >= 0 else
|
||||||
|
result_content.find(self.tool_call_prefix))
|
||||||
|
content = thinking_content + result_content[:tool_call_start_index]
|
||||||
|
|
||||||
|
return ExtractedToolCallInformation(
|
||||||
|
tools_called=(len(tool_calls) > 0),
|
||||||
|
tool_calls=tool_calls,
|
||||||
|
content=content if content else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Error in extracting tool call from response.")
|
||||||
|
return ExtractedToolCallInformation(tools_called=False,
|
||||||
|
tool_calls=[],
|
||||||
|
content=model_output)
|
||||||
|
|
||||||
|
def extract_tool_calls_streaming(
|
||||||
|
self,
|
||||||
|
previous_text: str,
|
||||||
|
current_text: str,
|
||||||
|
delta_text: str,
|
||||||
|
previous_token_ids: Sequence[int],
|
||||||
|
current_token_ids: Sequence[int],
|
||||||
|
delta_token_ids: Sequence[int],
|
||||||
|
request: ChatCompletionRequest,
|
||||||
|
) -> Union[DeltaMessage, None]:
|
||||||
|
# If no delta text, return None unless
|
||||||
|
# it's an EOS token after tool calls
|
||||||
|
if not delta_text:
|
||||||
|
# Check if this is an EOS token after all tool calls are complete
|
||||||
|
# We check for tool calls in the text even if is_tool_call_started
|
||||||
|
# is False because it might have been reset after processing all tools
|
||||||
|
if (delta_token_ids
|
||||||
|
and self.tool_call_end_token_id not in delta_token_ids):
|
||||||
|
# Count complete tool calls
|
||||||
|
complete_calls = len(
|
||||||
|
self.tool_call_complete_regex.findall(current_text))
|
||||||
|
|
||||||
|
# If we have completed tool calls and populated prev_tool_call_arr
|
||||||
|
if complete_calls > 0 and len(self.prev_tool_call_arr) > 0:
|
||||||
|
# Check if all tool calls are closed
|
||||||
|
open_calls = current_text.count(
|
||||||
|
self.tool_call_start_token) - current_text.count(
|
||||||
|
self.tool_call_end_token)
|
||||||
|
if open_calls == 0:
|
||||||
|
# Return empty delta message to allow finish_reason processing
|
||||||
|
return DeltaMessage(content="")
|
||||||
|
elif not self.is_tool_call_started and current_text:
|
||||||
|
# This is a regular content response that's now complete
|
||||||
|
return DeltaMessage(content="")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Check if this is the first call (reset state if needed)
|
||||||
|
if not previous_text:
|
||||||
|
self._reset_streaming_state()
|
||||||
|
|
||||||
|
# Update accumulated text
|
||||||
|
self.accumulated_text = current_text
|
||||||
|
|
||||||
|
# Check if we need to advance to next tool
|
||||||
|
if self.json_closed and not self.in_function:
|
||||||
|
# Check if this tool call has ended
|
||||||
|
tool_ends = current_text.count(self.tool_call_end_token)
|
||||||
|
if tool_ends > self.current_tool_index:
|
||||||
|
# This tool has ended, advance to next
|
||||||
|
self.current_tool_index += 1
|
||||||
|
self.header_sent = False
|
||||||
|
self.param_count = 0
|
||||||
|
self.json_started = False
|
||||||
|
self.json_closed = False
|
||||||
|
|
||||||
|
# Check if there are more tool calls
|
||||||
|
if self.current_tool_index >= current_text.count(
|
||||||
|
self.tool_call_start_token):
|
||||||
|
# No more tool calls
|
||||||
|
self.is_tool_call_started = False
|
||||||
|
# Continue processing next tool
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Check if end thinking
|
||||||
|
if (not self.is_thinking_end
|
||||||
|
and (self.think_end_token_id in delta_token_ids
|
||||||
|
or self.think_end_token in delta_text)):
|
||||||
|
self.is_thinking_end = True
|
||||||
|
|
||||||
|
# If thinking hasn't ended yet, don't process any tool calls
|
||||||
|
if not self.is_thinking_end:
|
||||||
|
return DeltaMessage(content=delta_text)
|
||||||
|
|
||||||
|
# Handle normal content before tool calls
|
||||||
|
if not self.is_tool_call_started:
|
||||||
|
# Check if tool call is starting
|
||||||
|
if (self.tool_call_start_token_id in delta_token_ids
|
||||||
|
or self.tool_call_start_token in delta_text):
|
||||||
|
self.is_tool_call_started = True
|
||||||
|
# Return any content before the tool call
|
||||||
|
if self.tool_call_start_token in delta_text:
|
||||||
|
content_before = delta_text[:delta_text.index(
|
||||||
|
self.tool_call_start_token)]
|
||||||
|
if content_before:
|
||||||
|
return DeltaMessage(content=content_before)
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
# Check if we're between tool calls - skip whitespace
|
||||||
|
if (current_text.rstrip().endswith(self.tool_call_end_token)
|
||||||
|
and delta_text.strip() == ""):
|
||||||
|
# We just ended a tool call, skip whitespace
|
||||||
|
return None
|
||||||
|
# Normal content, no tool call
|
||||||
|
return DeltaMessage(content=delta_text)
|
||||||
|
|
||||||
|
# Check if we're between tool calls (waiting for next one)
|
||||||
|
# Count tool calls we've seen vs processed
|
||||||
|
tool_starts_count = current_text.count(self.tool_call_start_token)
|
||||||
|
if self.current_tool_index >= tool_starts_count:
|
||||||
|
# We're past all tool calls, shouldn't be here
|
||||||
|
return None
|
||||||
|
|
||||||
|
# We're in a tool call, find the current tool call portion
|
||||||
|
# Need to find the correct tool call based on current_tool_index
|
||||||
|
# Only process tool calls after think_end_token
|
||||||
|
think_end_index = current_text.find(self.think_end_token) + len(
|
||||||
|
self.think_end_token
|
||||||
|
) if self.think_end_token in current_text else 0
|
||||||
|
tool_starts: list[int] = []
|
||||||
|
idx = think_end_index
|
||||||
|
while True:
|
||||||
|
idx = current_text.find(self.tool_call_start_token, idx)
|
||||||
|
if idx == -1:
|
||||||
|
break
|
||||||
|
tool_starts.append(idx)
|
||||||
|
idx += len(self.tool_call_start_token)
|
||||||
|
|
||||||
|
if self.current_tool_index >= len(tool_starts):
|
||||||
|
# No more tool calls to process yet
|
||||||
|
return None
|
||||||
|
|
||||||
|
tool_start_idx = tool_starts[self.current_tool_index]
|
||||||
|
# Find where this tool call ends (or current position if not ended yet)
|
||||||
|
tool_end_idx = current_text.find(self.tool_call_end_token,
|
||||||
|
tool_start_idx)
|
||||||
|
if tool_end_idx == -1:
|
||||||
|
tool_text = current_text[tool_start_idx:]
|
||||||
|
else:
|
||||||
|
tool_text = current_text[tool_start_idx:tool_end_idx +
|
||||||
|
len(self.tool_call_end_token)]
|
||||||
|
|
||||||
|
# Looking for function header
|
||||||
|
if not self.header_sent:
|
||||||
|
if self.tool_call_prefix in tool_text:
|
||||||
|
func_start = tool_text.find(self.tool_call_prefix) + len(
|
||||||
|
self.tool_call_prefix)
|
||||||
|
func_end = tool_text.find(">", func_start)
|
||||||
|
|
||||||
|
if func_end != -1:
|
||||||
|
# Found complete function name
|
||||||
|
self.current_function_name = tool_text[func_start:func_end]
|
||||||
|
self.current_tool_id = self._generate_tool_call_id(
|
||||||
|
) # type: ignore
|
||||||
|
self.header_sent = True
|
||||||
|
self.in_function = True
|
||||||
|
|
||||||
|
# IMPORTANT: Add to prev_tool_call_arr immediately when we detect a tool call
|
||||||
|
# This ensures finish_reason="tool_calls" even if parsing isn't complete
|
||||||
|
already_added = any(
|
||||||
|
tool.get("name") == self.current_function_name
|
||||||
|
for tool in self.prev_tool_call_arr)
|
||||||
|
if not already_added:
|
||||||
|
self.prev_tool_call_arr.append({
|
||||||
|
"name": self.current_function_name,
|
||||||
|
"arguments":
|
||||||
|
"{}", # Placeholder, will be updated later
|
||||||
|
})
|
||||||
|
|
||||||
|
# Send header with function info
|
||||||
|
return DeltaMessage(tool_calls=[
|
||||||
|
DeltaToolCall(
|
||||||
|
index=self.current_tool_index,
|
||||||
|
id=self.current_tool_id,
|
||||||
|
function=DeltaFunctionCall(
|
||||||
|
name=self.current_function_name, arguments=""),
|
||||||
|
type="function",
|
||||||
|
)
|
||||||
|
])
|
||||||
|
return None
|
||||||
|
|
||||||
|
# We've sent header, now handle function body
|
||||||
|
if self.in_function:
|
||||||
|
# Send opening brace if not sent yet
|
||||||
|
if (not self.json_started
|
||||||
|
and self.parameter_prefix not in delta_text):
|
||||||
|
self.json_started = True
|
||||||
|
return DeltaMessage(tool_calls=[
|
||||||
|
DeltaToolCall(
|
||||||
|
index=self.current_tool_index,
|
||||||
|
function=DeltaFunctionCall(arguments="{"),
|
||||||
|
)
|
||||||
|
])
|
||||||
|
|
||||||
|
# Make sure json_started is set if we're processing parameters
|
||||||
|
if not self.json_started:
|
||||||
|
self.json_started = True
|
||||||
|
|
||||||
|
# Check for function end in accumulated text
|
||||||
|
if not self.json_closed and self.function_end_token in tool_text:
|
||||||
|
# Close JSON
|
||||||
|
self.json_closed = True
|
||||||
|
|
||||||
|
# Extract the complete tool call to update prev_tool_call_arr with final arguments
|
||||||
|
# Find the function content
|
||||||
|
func_start = tool_text.find(self.tool_call_prefix) + len(
|
||||||
|
self.tool_call_prefix)
|
||||||
|
func_content_end = tool_text.find(self.function_end_token,
|
||||||
|
func_start)
|
||||||
|
if func_content_end != -1:
|
||||||
|
func_content = tool_text[func_start:func_content_end]
|
||||||
|
# Parse to get the complete arguments
|
||||||
|
try:
|
||||||
|
parsed_tool = self._parse_xml_function_call(
|
||||||
|
func_content, request.tools if request else None)
|
||||||
|
if parsed_tool:
|
||||||
|
# Update existing entry in prev_tool_call_arr with complete arguments
|
||||||
|
for i, tool in enumerate(self.prev_tool_call_arr):
|
||||||
|
if tool.get(
|
||||||
|
"name") == parsed_tool.function.name:
|
||||||
|
self.prev_tool_call_arr[i]["arguments"] = (
|
||||||
|
parsed_tool.function.arguments)
|
||||||
|
break
|
||||||
|
except Exception:
|
||||||
|
logger.warning(
|
||||||
|
"Failed to parse tool arguments during streaming.",
|
||||||
|
exc_info=True)
|
||||||
|
|
||||||
|
result = DeltaMessage(tool_calls=[
|
||||||
|
DeltaToolCall(
|
||||||
|
index=self.current_tool_index,
|
||||||
|
function=DeltaFunctionCall(arguments="}"),
|
||||||
|
)
|
||||||
|
])
|
||||||
|
|
||||||
|
# Reset state for next tool
|
||||||
|
self.in_function = False
|
||||||
|
self.json_closed = True
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
# Look for parameters
|
||||||
|
# Count how many complete parameters we have processed
|
||||||
|
complete_params = tool_text.count(self.parameter_end_token)
|
||||||
|
|
||||||
|
# Check if we should start a new parameter
|
||||||
|
if not self.in_param and self.param_count < complete_params:
|
||||||
|
# Find the unprocessed parameter
|
||||||
|
# Count parameter starts
|
||||||
|
param_starts = []
|
||||||
|
idx = 0
|
||||||
|
while True:
|
||||||
|
idx = tool_text.find(self.parameter_prefix, idx)
|
||||||
|
if idx == -1:
|
||||||
|
break
|
||||||
|
param_starts.append(idx)
|
||||||
|
idx += len(self.parameter_prefix)
|
||||||
|
|
||||||
|
if len(param_starts) > self.param_count:
|
||||||
|
# Process the next parameter
|
||||||
|
param_idx = param_starts[self.param_count]
|
||||||
|
param_start = param_idx + len(self.parameter_prefix)
|
||||||
|
remaining = tool_text[param_start:]
|
||||||
|
|
||||||
|
if ">" in remaining:
|
||||||
|
# We have the complete parameter name
|
||||||
|
name_end = remaining.find(">")
|
||||||
|
self.current_param_name = remaining[:name_end]
|
||||||
|
|
||||||
|
# Find the parameter value
|
||||||
|
value_start = param_start + name_end + 1
|
||||||
|
value_text = tool_text[value_start:]
|
||||||
|
if value_text.startswith("\n"):
|
||||||
|
value_text = value_text[1:]
|
||||||
|
|
||||||
|
# Find where this parameter ends
|
||||||
|
param_end_idx = value_text.find(
|
||||||
|
self.parameter_end_token)
|
||||||
|
if param_end_idx != -1:
|
||||||
|
# Complete parameter found
|
||||||
|
param_value = value_text[:param_end_idx]
|
||||||
|
if param_value.endswith("\n"):
|
||||||
|
param_value = param_value[:-1]
|
||||||
|
|
||||||
|
# Build complete JSON fragment for this parameter
|
||||||
|
if self.param_count == 0:
|
||||||
|
json_fragment = (
|
||||||
|
'"' + self.current_param_name + '": "' +
|
||||||
|
json.dumps(param_value)[1:-1] + '"')
|
||||||
|
else:
|
||||||
|
json_fragment = (
|
||||||
|
', "' + self.current_param_name + '": "' +
|
||||||
|
json.dumps(param_value)[1:-1] + '"')
|
||||||
|
|
||||||
|
self.param_count += 1
|
||||||
|
|
||||||
|
return DeltaMessage(tool_calls=[
|
||||||
|
DeltaToolCall(
|
||||||
|
index=self.current_tool_index,
|
||||||
|
function=DeltaFunctionCall(
|
||||||
|
arguments=json_fragment),
|
||||||
|
)
|
||||||
|
])
|
||||||
|
|
||||||
|
# Continue parameter value
|
||||||
|
if self.in_param:
|
||||||
|
if self.parameter_end_token in delta_text:
|
||||||
|
# End of parameter
|
||||||
|
end_idx = delta_text.find(self.parameter_end_token)
|
||||||
|
value_chunk = delta_text[:end_idx]
|
||||||
|
|
||||||
|
# Skip past > if at start
|
||||||
|
if not self.current_param_value and ">" in value_chunk:
|
||||||
|
gt_idx = value_chunk.find(">")
|
||||||
|
value_chunk = value_chunk[gt_idx + 1:]
|
||||||
|
|
||||||
|
if not self.current_param_value and value_chunk.startswith(
|
||||||
|
"\n"):
|
||||||
|
value_chunk = value_chunk[1:]
|
||||||
|
|
||||||
|
# Calculate incremental JSON
|
||||||
|
full_value = self.current_param_value + value_chunk
|
||||||
|
prev_escaped = (json.dumps(self.current_param_value)[1:-1]
|
||||||
|
if self.current_param_value else "")
|
||||||
|
full_escaped = json.dumps(full_value)[1:-1]
|
||||||
|
delta_escaped = full_escaped[len(prev_escaped):]
|
||||||
|
|
||||||
|
self.in_param = False
|
||||||
|
self.current_param_value = ""
|
||||||
|
|
||||||
|
return DeltaMessage(tool_calls=[
|
||||||
|
DeltaToolCall(
|
||||||
|
index=self.current_tool_index,
|
||||||
|
function=DeltaFunctionCall(
|
||||||
|
arguments=delta_escaped + '"'),
|
||||||
|
)
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
# Continue accumulating value
|
||||||
|
value_chunk = delta_text
|
||||||
|
|
||||||
|
# Handle first chunk after param name
|
||||||
|
if not self.current_param_value and ">" in value_chunk:
|
||||||
|
gt_idx = value_chunk.find(">")
|
||||||
|
value_chunk = value_chunk[gt_idx + 1:]
|
||||||
|
|
||||||
|
if not self.current_param_value and value_chunk.startswith(
|
||||||
|
"\n"):
|
||||||
|
value_chunk = value_chunk[1:]
|
||||||
|
|
||||||
|
if value_chunk:
|
||||||
|
# Stream the escaped delta
|
||||||
|
prev_escaped = (json.dumps(
|
||||||
|
self.current_param_value)[1:-1]
|
||||||
|
if self.current_param_value else "")
|
||||||
|
self.current_param_value += value_chunk
|
||||||
|
full_escaped = json.dumps(
|
||||||
|
self.current_param_value)[1:-1]
|
||||||
|
delta_escaped = full_escaped[len(prev_escaped):]
|
||||||
|
|
||||||
|
if delta_escaped:
|
||||||
|
return DeltaMessage(tool_calls=[
|
||||||
|
DeltaToolCall(
|
||||||
|
index=self.current_tool_index,
|
||||||
|
function=DeltaFunctionCall(
|
||||||
|
arguments=delta_escaped),
|
||||||
|
)
|
||||||
|
])
|
||||||
|
|
||||||
|
return None
|
||||||
@ -130,6 +130,7 @@ _TEXT_GENERATION_MODELS = {
|
|||||||
"Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
|
"Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
|
||||||
"Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
|
"Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
|
||||||
"RWForCausalLM": ("falcon", "FalconForCausalLM"),
|
"RWForCausalLM": ("falcon", "FalconForCausalLM"),
|
||||||
|
"SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
|
||||||
"Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
|
"Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
|
||||||
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
|
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
|
||||||
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
|
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
|
||||||
|
|||||||
487
vllm/model_executor/models/seed_oss.py
Normal file
487
vllm/model_executor/models/seed_oss.py
Normal file
@ -0,0 +1,487 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
# Copyright 2025 The Seed team.
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||||
|
# and OPT implementations in this library. It has been modified from its
|
||||||
|
# original forms to accommodate minor architectural differences compared
|
||||||
|
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Inference-only SeedOss model compatible with HuggingFace weights."""
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from transformers import PretrainedConfig as SeedOssConfig
|
||||||
|
|
||||||
|
from vllm.attention import Attention, AttentionType
|
||||||
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
|
from vllm.config import CacheConfig, VllmConfig
|
||||||
|
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
|
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||||
|
QKVParallelLinear,
|
||||||
|
RowParallelLinear)
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
|
ParallelLMHead, VocabParallelEmbedding)
|
||||||
|
from vllm.model_executor.model_loader.weight_utils import (
|
||||||
|
default_weight_loader, maybe_remap_kv_scale_name)
|
||||||
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
|
from .interfaces import SupportsLoRA, SupportsPP
|
||||||
|
from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
|
||||||
|
make_empty_intermediate_tensors_factory, make_layers,
|
||||||
|
maybe_prefix)
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SeedOssMLP(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
intermediate_size: int,
|
||||||
|
hidden_act: str,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.gate_up_proj = MergedColumnParallelLinear(
|
||||||
|
hidden_size,
|
||||||
|
[intermediate_size] * 2,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.gate_up_proj",
|
||||||
|
)
|
||||||
|
self.down_proj = RowParallelLinear(
|
||||||
|
intermediate_size,
|
||||||
|
hidden_size,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.down_proj",
|
||||||
|
)
|
||||||
|
if hidden_act != "silu":
|
||||||
|
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||||
|
"Only silu is supported for now.")
|
||||||
|
self.act_fn = SiluAndMul()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
gate_up, _ = self.gate_up_proj(x)
|
||||||
|
x = self.act_fn(gate_up)
|
||||||
|
x, _ = self.down_proj(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SeedOssAttention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
num_heads: int,
|
||||||
|
num_kv_heads: int,
|
||||||
|
head_dim: int,
|
||||||
|
max_position: int = 4096 * 32,
|
||||||
|
rope_theta: float = 10000,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
rope_scaling: Optional[tuple] = None,
|
||||||
|
prefix: str = "",
|
||||||
|
attn_type: str = AttentionType.DECODER,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
self.total_num_heads = num_heads
|
||||||
|
assert self.total_num_heads % tp_size == 0
|
||||||
|
self.num_heads = self.total_num_heads // tp_size
|
||||||
|
self.total_num_kv_heads = num_kv_heads
|
||||||
|
self.head_dim = head_dim
|
||||||
|
if self.total_num_kv_heads >= tp_size:
|
||||||
|
# Number of KV heads is greater than TP size, so we partition
|
||||||
|
# the KV heads across multiple tensor parallel GPUs.
|
||||||
|
assert self.total_num_kv_heads % tp_size == 0
|
||||||
|
else:
|
||||||
|
# Number of KV heads is less than TP size, so we replicate
|
||||||
|
# the KV heads across multiple tensor parallel GPUs.
|
||||||
|
assert tp_size % self.total_num_kv_heads == 0
|
||||||
|
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||||
|
self.q_size = self.num_heads * self.head_dim
|
||||||
|
self.kv_size = self.num_kv_heads * self.head_dim
|
||||||
|
self.scaling = self.head_dim**-0.5
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
|
||||||
|
self.qkv_proj = QKVParallelLinear(
|
||||||
|
hidden_size,
|
||||||
|
self.head_dim,
|
||||||
|
self.total_num_heads,
|
||||||
|
self.total_num_kv_heads,
|
||||||
|
bias=True,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.qkv_proj",
|
||||||
|
)
|
||||||
|
self.o_proj = RowParallelLinear(
|
||||||
|
self.total_num_heads * self.head_dim,
|
||||||
|
hidden_size,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.o_proj",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.rotary_emb = get_rope(
|
||||||
|
self.head_dim,
|
||||||
|
rotary_dim=self.head_dim,
|
||||||
|
max_position=max_position,
|
||||||
|
base=self.rope_theta,
|
||||||
|
rope_scaling=rope_scaling,
|
||||||
|
)
|
||||||
|
self.attn = Attention(
|
||||||
|
self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
self.scaling,
|
||||||
|
num_kv_heads=self.num_kv_heads,
|
||||||
|
cache_config=cache_config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
attn_type=attn_type,
|
||||||
|
prefix=f"{prefix}.attn",
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
qkv, _ = self.qkv_proj(hidden_states)
|
||||||
|
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||||
|
q, k = self.rotary_emb(positions, q, k)
|
||||||
|
attn_output = self.attn(q, k, v)
|
||||||
|
output, _ = self.o_proj(attn_output)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class SeedOssDecoderLayer(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: SeedOssConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
# Requires transformers > 4.32.0
|
||||||
|
rope_theta = getattr(config, "rope_theta", 1000000)
|
||||||
|
rope_scaling = getattr(config, "rope_scaling", None)
|
||||||
|
|
||||||
|
# By default, SeedOss uses causal attention as it is a
|
||||||
|
# decoder-only model.
|
||||||
|
# You can override the HF config with `is_causal=False` to enable
|
||||||
|
# bidirectional attention, which is used in some embedding models
|
||||||
|
if getattr(config, "is_causal", True):
|
||||||
|
attn_type = AttentionType.DECODER
|
||||||
|
else:
|
||||||
|
attn_type = AttentionType.ENCODER_ONLY
|
||||||
|
|
||||||
|
self.self_attn = SeedOssAttention(
|
||||||
|
hidden_size=self.hidden_size,
|
||||||
|
num_heads=config.num_attention_heads,
|
||||||
|
max_position=config.max_position_embeddings,
|
||||||
|
num_kv_heads=config.num_key_value_heads,
|
||||||
|
head_dim=config.head_dim,
|
||||||
|
rope_theta=rope_theta,
|
||||||
|
cache_config=cache_config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
rope_scaling=rope_scaling,
|
||||||
|
prefix=f"{prefix}.self_attn",
|
||||||
|
attn_type=attn_type,
|
||||||
|
)
|
||||||
|
self.mlp = SeedOssMLP(
|
||||||
|
hidden_size=self.hidden_size,
|
||||||
|
intermediate_size=config.intermediate_size,
|
||||||
|
hidden_act=config.hidden_act,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.mlp",
|
||||||
|
)
|
||||||
|
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||||
|
eps=config.rms_norm_eps)
|
||||||
|
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||||
|
eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
residual: Optional[torch.Tensor],
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
# Self Attention
|
||||||
|
if residual is None:
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
else:
|
||||||
|
hidden_states, residual = self.input_layernorm(
|
||||||
|
hidden_states, residual)
|
||||||
|
hidden_states = self.self_attn(
|
||||||
|
positions=positions,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fully Connected
|
||||||
|
hidden_states, residual = self.post_attention_layernorm(
|
||||||
|
hidden_states, residual)
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
return hidden_states, residual
|
||||||
|
|
||||||
|
|
||||||
|
@support_torch_compile(
|
||||||
|
dynamic_arg_dims={
|
||||||
|
"input_ids": 0,
|
||||||
|
"positions": -1,
|
||||||
|
"intermediate_tensors": 0,
|
||||||
|
"inputs_embeds": 0,
|
||||||
|
})
|
||||||
|
class SeedOssModel(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
*,
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
prefix: str = "",
|
||||||
|
decoder_layer_type: type[nn.Module] = SeedOssDecoderLayer):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
config = vllm_config.model_config.hf_config
|
||||||
|
cache_config = vllm_config.cache_config
|
||||||
|
quant_config = vllm_config.quant_config
|
||||||
|
|
||||||
|
# TODO (@robertgshaw2): see if this can be moved out
|
||||||
|
if (cache_config.sliding_window is not None
|
||||||
|
and hasattr(config, "max_window_layers")):
|
||||||
|
assert config.max_window_layers == config.num_hidden_layers, (
|
||||||
|
"Sliding window for some but all layers is not supported. "
|
||||||
|
"This model uses sliding window but `max_window_layers` = {} "
|
||||||
|
"is less than `num_hidden_layers` = {}. Please open an issue "
|
||||||
|
"to discuss this feature.".format(
|
||||||
|
config.max_window_layers,
|
||||||
|
config.num_hidden_layers,
|
||||||
|
))
|
||||||
|
|
||||||
|
self.config = config
|
||||||
|
self.quant_config = quant_config
|
||||||
|
self.vocab_size = config.vocab_size
|
||||||
|
|
||||||
|
if get_pp_group().is_first_rank or (config.tie_word_embeddings
|
||||||
|
and get_pp_group().is_last_rank):
|
||||||
|
self.embed_tokens = VocabParallelEmbedding(
|
||||||
|
config.vocab_size,
|
||||||
|
config.hidden_size,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.embed_tokens",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.embed_tokens = PPMissingLayer()
|
||||||
|
|
||||||
|
# Use the provided decoder layer type or default to SeedDecoderLayer
|
||||||
|
decoder_layer_type = decoder_layer_type or SeedOssDecoderLayer
|
||||||
|
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||||
|
config.num_hidden_layers,
|
||||||
|
lambda prefix: decoder_layer_type(config=config,
|
||||||
|
cache_config=cache_config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=prefix),
|
||||||
|
prefix=f"{prefix}.layers",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.make_empty_intermediate_tensors = (
|
||||||
|
make_empty_intermediate_tensors_factory(
|
||||||
|
["hidden_states", "residual"], config.hidden_size))
|
||||||
|
if get_pp_group().is_last_rank:
|
||||||
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
else:
|
||||||
|
self.norm = PPMissingLayer()
|
||||||
|
|
||||||
|
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.embed_tokens(input_ids)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
|
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||||
|
if get_pp_group().is_first_rank:
|
||||||
|
if inputs_embeds is not None:
|
||||||
|
hidden_states = inputs_embeds
|
||||||
|
else:
|
||||||
|
hidden_states = self.get_input_embeddings(input_ids)
|
||||||
|
residual = None
|
||||||
|
else:
|
||||||
|
assert intermediate_tensors is not None
|
||||||
|
hidden_states = intermediate_tensors["hidden_states"]
|
||||||
|
residual = intermediate_tensors["residual"]
|
||||||
|
for layer in self.layers[self.start_layer:self.end_layer]:
|
||||||
|
hidden_states, residual = layer(
|
||||||
|
positions,
|
||||||
|
hidden_states,
|
||||||
|
residual,
|
||||||
|
)
|
||||||
|
if not get_pp_group().is_last_rank:
|
||||||
|
return IntermediateTensors({
|
||||||
|
"hidden_states": hidden_states,
|
||||||
|
"residual": residual
|
||||||
|
})
|
||||||
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[tuple[str,
|
||||||
|
torch.Tensor]]) -> set[str]:
|
||||||
|
stacked_params_mapping = [
|
||||||
|
# (param_name, shard_name, shard_id)
|
||||||
|
("qkv_proj", "q_proj", "q"),
|
||||||
|
("qkv_proj", "k_proj", "k"),
|
||||||
|
("qkv_proj", "v_proj", "v"),
|
||||||
|
("gate_up_proj", "gate_proj", 0),
|
||||||
|
("gate_up_proj", "up_proj", 1),
|
||||||
|
]
|
||||||
|
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||||
|
loaded_params: set[str] = set()
|
||||||
|
for name, loaded_weight in weights:
|
||||||
|
if "rotary_emb.inv_freq" in name:
|
||||||
|
continue
|
||||||
|
if (self.quant_config is not None and
|
||||||
|
(scale_name := self.quant_config.get_cache_scale(name))):
|
||||||
|
# Loading kv cache quantization scales
|
||||||
|
param = params_dict[scale_name]
|
||||||
|
weight_loader = getattr(param, "weight_loader",
|
||||||
|
default_weight_loader)
|
||||||
|
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
|
||||||
|
loaded_weight[0])
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
loaded_params.add(scale_name)
|
||||||
|
continue
|
||||||
|
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||||
|
if weight_name not in name:
|
||||||
|
continue
|
||||||
|
name = name.replace(weight_name, param_name)
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
if is_pp_missing_parameter(name, self):
|
||||||
|
continue
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = param.weight_loader
|
||||||
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
# Remapping the name of FP8 kv-scale.
|
||||||
|
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||||
|
if name is None:
|
||||||
|
continue
|
||||||
|
if is_pp_missing_parameter(name, self):
|
||||||
|
continue
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = getattr(param, "weight_loader",
|
||||||
|
default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
loaded_params.add(name)
|
||||||
|
return loaded_params
|
||||||
|
|
||||||
|
|
||||||
|
class SeedOssForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||||
|
packed_modules_mapping = {
|
||||||
|
"qkv_proj": [
|
||||||
|
"q_proj",
|
||||||
|
"k_proj",
|
||||||
|
"v_proj",
|
||||||
|
],
|
||||||
|
"gate_up_proj": [
|
||||||
|
"gate_proj",
|
||||||
|
"up_proj",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
|
super().__init__()
|
||||||
|
config = vllm_config.model_config.hf_config
|
||||||
|
quant_config = vllm_config.quant_config
|
||||||
|
lora_config = vllm_config.lora_config
|
||||||
|
|
||||||
|
self.config = config
|
||||||
|
self.lora_config = lora_config
|
||||||
|
|
||||||
|
self.quant_config = quant_config
|
||||||
|
self.model = SeedOssModel(vllm_config=vllm_config,
|
||||||
|
prefix=maybe_prefix(prefix, "model"))
|
||||||
|
|
||||||
|
if get_pp_group().is_last_rank:
|
||||||
|
if config.tie_word_embeddings:
|
||||||
|
self.lm_head = self.model.embed_tokens
|
||||||
|
else:
|
||||||
|
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||||
|
config.hidden_size,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=maybe_prefix(
|
||||||
|
prefix, "lm_head"))
|
||||||
|
else:
|
||||||
|
self.lm_head = PPMissingLayer()
|
||||||
|
|
||||||
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
|
|
||||||
|
self.make_empty_intermediate_tensors = (
|
||||||
|
self.model.make_empty_intermediate_tensors)
|
||||||
|
|
||||||
|
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.model.get_input_embeddings(input_ids)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
|
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||||
|
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||||
|
inputs_embeds)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def compute_logits(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata,
|
||||||
|
) -> Optional[torch.Tensor]:
|
||||||
|
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||||
|
sampling_metadata)
|
||||||
|
return logits
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[tuple[str,
|
||||||
|
torch.Tensor]]) -> set[str]:
|
||||||
|
loader = AutoWeightsLoader(
|
||||||
|
self,
|
||||||
|
skip_prefixes=(["lm_head."]
|
||||||
|
if self.config.tie_word_embeddings else None),
|
||||||
|
)
|
||||||
|
return loader.load_weights(weights)
|
||||||
Loading…
x
Reference in New Issue
Block a user