From 5964069367a7d54c3816ce3faba79e02110cde17 Mon Sep 17 00:00:00 2001 From: Bin Jia <45593998+FoolPlayer@users.noreply.github.com> Date: Fri, 22 Aug 2025 12:58:10 +0800 Subject: [PATCH 01/38] [New Model] Add Seed-Oss model (#23241) Signed-off-by: jiabin.00 Signed-off-by: Jee Jee Li Co-authored-by: Jee Jee Li --- docs/models/supported_models.md | 1 + tests/models/registry.py | 3 + tests/tool_use/test_seed_oss_tool_parser.py | 459 ++++++++++++ .../openai/tool_parsers/__init__.py | 2 + .../tool_parsers/seed_oss_tool_parser.py | 676 ++++++++++++++++++ vllm/model_executor/models/registry.py | 1 + vllm/model_executor/models/seed_oss.py | 487 +++++++++++++ 7 files changed, 1629 insertions(+) create mode 100644 tests/tool_use/test_seed_oss_tool_parser.py create mode 100644 vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py create mode 100644 vllm/model_executor/models/seed_oss.py diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index ad3db1cf2100f..297d98142b5f2 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -401,6 +401,7 @@ th { | `Qwen2MoeForCausalLM` | Qwen2MoE | `Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen3ForCausalLM` | Qwen3 | `Qwen/Qwen3-8B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen3MoeForCausalLM` | Qwen3MoE | `Qwen/Qwen3-30B-A3B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `SeedOssForCausalLM` | SeedOss | `ByteDance-Seed/Seed-OSS-36B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `StableLmForCausalLM` | StableLM | `stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc. | | | ✅︎ | | `Starcoder2ForCausalLM` | Starcoder2 | `bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc. | | ✅︎ | ✅︎ | | `SolarForCausalLM` | Solar Pro | `upstage/solar-pro-preview-instruct`, etc. 
| ✅︎ | ✅︎ | ✅︎ | diff --git a/tests/models/registry.py b/tests/models/registry.py index 4871ade231044..4035319b45ce4 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -292,6 +292,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"), "Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"), "RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"), + "SeedOssForCausalLM": _HfExamplesInfo("ByteDance-Seed/Seed-OSS-36B-Instruct", # noqa: E501 + trust_remote_code=True, + is_available_online=False), "SmolLM3ForCausalLM": _HfExamplesInfo("HuggingFaceTB/SmolLM3-3B"), "StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"), # noqa: E501 "StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"), diff --git a/tests/tool_use/test_seed_oss_tool_parser.py b/tests/tool_use/test_seed_oss_tool_parser.py new file mode 100644 index 0000000000000..d85bc9bbf1b30 --- /dev/null +++ b/tests/tool_use/test_seed_oss_tool_parser.py @@ -0,0 +1,459 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# ruff: noqa: E501 + +import json +from collections.abc import Generator +from typing import Optional + +import pytest + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + ChatCompletionToolsParam, + DeltaMessage, FunctionCall, + ToolCall) +from vllm.entrypoints.openai.tool_parsers import SeedOssToolParser +from vllm.transformers_utils.detokenizer import detokenize_incrementally +from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer + +# Use a common model that is likely to be available +MODEL = "ByteDance-Seed/Seed-OSS-36B-Instruct" + + +@pytest.fixture(scope="module") +def seed_oss_tokenizer(): + return get_tokenizer(tokenizer_name=MODEL, trust_remote_code=True) + + +@pytest.fixture +def seed_oss_tool_parser(seed_oss_tokenizer): + return SeedOssToolParser(seed_oss_tokenizer) + + +@pytest.fixture +def 
sample_tools(): + return [ + ChatCompletionToolsParam( + type="function", + function={ + "name": "get_weather", + "description": "Get current temperature for a given location.", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": + "City and country e.g. Bogotá, Colombia" + }, + "unit": { + "type": "string", + "description": "this is the unit of temperature" + } + }, + "required": ["location"], + "additionalProperties": False + }, + "returns": { + "type": "object", + "properties": { + "temperature": { + "type": "number", + "description": "temperature in celsius" + } + }, + "required": ["temperature"], + "additionalProperties": False + }, + "strict": True + }), + ] + + +def assert_tool_calls(actual_tool_calls: list[ToolCall], + expected_tool_calls: list[ToolCall]): + assert len(actual_tool_calls) == len(expected_tool_calls) + + for actual_tool_call, expected_tool_call in zip(actual_tool_calls, + expected_tool_calls): + # Seed-OSS tool call will not generate id + assert actual_tool_call.type == "function" + assert actual_tool_call.function == expected_tool_call.function + + assert actual_tool_call.function.name == expected_tool_call.function.name + assert actual_tool_call.function.arguments == expected_tool_call.function.arguments + + +def test_extract_tool_calls_no_tools(seed_oss_tool_parser): + model_output = "This is a test response without any tool calls" + extracted_tool_calls = seed_oss_tool_parser.extract_tool_calls( + model_output, request=None) # type: ignore[arg-type] + + assert not extracted_tool_calls.tools_called + assert extracted_tool_calls.tool_calls == [] + assert extracted_tool_calls.content == model_output + + +@pytest.mark.parametrize( + ids=[ + "tool_call_0_thinking_budget", + "tool_call_512_thinkg_budget", + "tool_call_unlimited_thinking_budget", + ], + argnames=["model_output", "expected_tool_calls", "expected_content"], + argvalues=[ + ("""\n\n\n""" + """The current thinking budget is 0, 
so I will directly start answering the question.\n\n""" + """\n\n""" + """Barcelona, Spain\n\n""", + [ + ToolCall(function=FunctionCall( + name="get_weather", + arguments=json.dumps({ + "location": "Barcelona, Spain", + }, ), + ), + type='function') + ], + """\n\n\n""" + """The current thinking budget is 0, so I will directly start answering the question.\n\n""" + ), + ( + """The user\'s current thinking budget is 512.\nLet me analyze the """ + """question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """ + """there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """ + """check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """ + """optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """ + """country). \nI have used 131 tokens, and there are 381 tokens remaining for use.""" + """\n Since the unit isn\'t specified, the function will default to Celsius, which """ + """is fine. \n\nThere\'s no need to ask for more information because the location is clear. So I should call """ + """the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """ + """user\'s input has a space, but the function might accept either; to be safe, using the standard format """ + """with a comma).\nI have used 257 tokens, and there are 255 tokens remaining for """ + """use.\n The unit parameter can be omitted since it\'s optional.\n""" + """\n\nBarcelona, Spain\n""" + """\n""", + [ + ToolCall(function=FunctionCall( + name="get_weather", + arguments=json.dumps({ + "location": "Barcelona, Spain", + }, ), + ), + type='function') + ], + """The user\'s current thinking budget is 512.\nLet me analyze the """ + """question. The user wants to know the weather in Barcelona, Spain. 
Looking at the functions available, """ + """there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """ + """check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """ + """optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """ + """country). \nI have used 131 tokens, and there are 381 tokens remaining for use.""" + """\n Since the unit isn\'t specified, the function will default to Celsius, which """ + """is fine. \n\nThere\'s no need to ask for more information because the location is clear. So I should call """ + """the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """ + """user\'s input has a space, but the function might accept either; to be safe, using the standard format """ + """with a comma).\nI have used 257 tokens, and there are 255 tokens remaining for """ + """use.\n The unit parameter can be omitted since it\'s optional.\n""", + ), + ( + """\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """ + """First, I need to remember the function I can use: get_weather. The function requires a """ + """location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """ + """the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """ + """let me check the function docstring again. Oh, the function says unit is optional, and """ + """returns temperature in Celsius. So I should call get_weather with location "Barcelona, """ + """Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """ + """The format is \n\nBarcelona, """ + """Spain\ncelsius\n\n. """ + """Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """ + """of temperature, but the return is in Celsius anyway. 
Maybe even if I don\'t pass unit, """ + """it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """ + """call should be as above. Then wait for the result to come back and tell the user the """ + """temperature in Celsius.\n\n""" + """Barcelona, Spain\ncelsius\n\n""", + [ + ToolCall(function=FunctionCall( + name="get_weather", + arguments=json.dumps( + { + "location": "Barcelona, Spain", + "unit": "celsius", + }, ), + ), + type='function') + ], + """\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """ + """First, I need to remember the function I can use: get_weather. The function requires a """ + """location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """ + """the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """ + """let me check the function docstring again. Oh, the function says unit is optional, and """ + """returns temperature in Celsius. So I should call get_weather with location "Barcelona, """ + """Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """ + """The format is \n\nBarcelona, """ + """Spain\ncelsius\n\n. """ + """Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """ + """of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """ + """it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """ + """call should be as above. 
Then wait for the result to come back and tell the user the """ + """temperature in Celsius.""", + ), + ], +) +def test_extract_tool_calls(seed_oss_tool_parser, sample_tools, model_output, + expected_tool_calls, expected_content): + request = ChatCompletionRequest(model=MODEL, + messages=[], + tools=sample_tools) + extracted_tool_calls = seed_oss_tool_parser.extract_tool_calls( + model_output, request=request) # type: ignore[arg-type] + assert extracted_tool_calls.tools_called + + assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls) + + assert extracted_tool_calls.content == expected_content + + +def test_streaming_tool_calls_no_tools(seed_oss_tool_parser): + model_output = "This is a test response without any tool calls" + + result = seed_oss_tool_parser.extract_tool_calls_streaming( + previous_text="his is a test response", + current_text=model_output, + delta_text=" without any tool calls.", + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[], + request=None, + ) + + # Should return the delta text as content + assert result is not None + assert hasattr(result, 'content') + assert result.content == " without any tool calls." 
+ + +def stream_delta_message_generator( + seed_oss_tool_parser: SeedOssToolParser, + seed_oss_tokenizer: AnyTokenizer, + model_output: str, + request: Optional[ChatCompletionRequest] = None +) -> Generator[DeltaMessage, None, None]: + all_token_ids = seed_oss_tokenizer.encode(model_output, + add_special_tokens=False) + + previous_text = "" + previous_tokens = None + prefix_offset = 0 + read_offset = 0 + for i, delta_token in enumerate(all_token_ids): + delta_token_ids = [delta_token] + previous_token_ids = all_token_ids[:i] + current_token_ids = all_token_ids[:i + 1] + + (new_tokens, delta_text, new_prefix_offset, + new_read_offset) = detokenize_incrementally( + tokenizer=seed_oss_tokenizer, + all_input_ids=current_token_ids, + prev_tokens=previous_tokens, + prefix_offset=prefix_offset, + read_offset=read_offset, + skip_special_tokens=False, + spaces_between_special_tokens=True, + ) + + current_text = previous_text + delta_text + + delta_message = seed_oss_tool_parser.extract_tool_calls_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + delta_token_ids, + request=request, + ) + if delta_message: + yield delta_message + + previous_text = current_text + previous_tokens = (previous_tokens + + new_tokens if previous_tokens else new_tokens) + prefix_offset = new_prefix_offset + read_offset = new_read_offset + + +@pytest.mark.parametrize( + ids=[ + "tool_call_0_thinking_budget", + "tool_call_512_thinkg_budget", + "tool_call_unlimited_thinking_budget", + ], + argnames=["model_output", "expected_tool_calls", "expected_content"], + argvalues=[ + ("""\n\n\n""" + """The current thinking budget is 0, so I will directly start answering the question.\n\n""" + """\n\n""" + """Barcelona, Spain\n\n""", + [ + ToolCall(function=FunctionCall( + name="get_weather", + arguments=json.dumps({ + "location": "Barcelona, Spain", + }, ), + ), + type='function') + ], + """\n\n\n""" + """The current thinking budget is 0, so I will directly 
start answering the question.\n\n""" + ), + ( + """The user\'s current thinking budget is 512.\nLet me analyze the """ + """question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """ + """there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """ + """check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """ + """optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """ + """country). \nI have used 131 tokens, and there are 381 tokens remaining for use.""" + """\n Since the unit isn\'t specified, the function will default to Celsius, which """ + """is fine. \n\nThere\'s no need to ask for more information because the location is clear. So I should call """ + """the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """ + """user\'s input has a space, but the function might accept either; to be safe, using the standard format """ + """with a comma).\nI have used 257 tokens, and there are 255 tokens remaining for """ + """use.\n The unit parameter can be omitted since it\'s optional.\n""" + """\n\nBarcelona, Spain\n""" + """\n""", + [ + ToolCall(function=FunctionCall( + name="get_weather", + arguments=json.dumps({ + "location": "Barcelona, Spain", + }, ), + ), + type='function') + ], + """The user\'s current thinking budget is 512.\nLet me analyze the """ + """question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """ + """there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """ + """check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """ + """optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """ + """country). 
\nI have used 131 tokens, and there are 381 tokens remaining for use.""" + """\n Since the unit isn\'t specified, the function will default to Celsius, which """ + """is fine. \n\nThere\'s no need to ask for more information because the location is clear. So I should call """ + """the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """ + """user\'s input has a space, but the function might accept either; to be safe, using the standard format """ + """with a comma).\nI have used 257 tokens, and there are 255 tokens remaining for """ + """use.\n The unit parameter can be omitted since it\'s optional.\n""", + ), + ( + """\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """ + """First, I need to remember the function I can use: get_weather. The function requires a """ + """location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """ + """the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """ + """let me check the function docstring again. Oh, the function says unit is optional, and """ + """returns temperature in Celsius. So I should call get_weather with location "Barcelona, """ + """Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """ + """The format is \n\nBarcelona, """ + """Spain\ncelsius\n\n. """ + """Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """ + """of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """ + """it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """ + """call should be as above. 
Then wait for the result to come back and tell the user the """ + """temperature in Celsius.\n\n""" + """Barcelona, Spain\ncelsius\n\n""", + [ + ToolCall(function=FunctionCall( + name="get_weather", + arguments=json.dumps( + { + "location": "Barcelona, Spain", + "unit": "celsius", + }, ), + ), + type='function') + ], + """\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """ + """First, I need to remember the function I can use: get_weather. The function requires a """ + """location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """ + """the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """ + """let me check the function docstring again. Oh, the function says unit is optional, and """ + """returns temperature in Celsius. So I should call get_weather with location "Barcelona, """ + """Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """ + """The format is \n\nBarcelona, """ + """Spain\ncelsius\n\n. """ + """Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """ + """of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """ + """it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """ + """call should be as above. 
Then wait for the result to come back and tell the user the """ + """temperature in Celsius.""", + ), + ], +) +def test_streaming_tool_calls(seed_oss_tool_parser, seed_oss_tokenizer, + sample_tools, model_output, expected_tool_calls, + expected_content): + """Test incremental streaming behavior""" + request = ChatCompletionRequest(model=MODEL, + messages=[], + tools=sample_tools) + + other_content = '' + tool_states = {} # Track state per tool index + + for delta_message in stream_delta_message_generator( + seed_oss_tool_parser, seed_oss_tokenizer, model_output, request): + # role should never be streamed from tool parser + assert not delta_message.role + + if delta_message.content: + other_content += delta_message.content + + if delta_message.tool_calls: + for tool_call in delta_message.tool_calls: + idx = tool_call.index + + # Initialize state for new tool + if idx not in tool_states: + tool_states[idx] = { + "id": None, + "name": None, + "arguments": "", + "type": None + } + + # First chunk should have id, name, and type + if tool_call.id: + tool_states[idx]["id"] = tool_call.id + + if tool_call.type: + assert tool_call.type == "function" + tool_states[idx]["type"] = tool_call.type + + if tool_call.function: + if tool_call.function.name: + # Should only be set once + assert tool_states[idx]["name"] is None + tool_states[idx]["name"] = tool_call.function.name + + if tool_call.function.arguments is not None: + # Accumulate arguments incrementally + tool_states[idx][ + "arguments"] += tool_call.function.arguments + + # Verify final content + assert other_content == expected_content + + # Verify we got all expected tool calls + assert len(tool_states) == len(expected_tool_calls) + + # Verify each tool call + for idx, expected_tool in enumerate(expected_tool_calls): + state = tool_states[idx] + assert state["id"] is not None + assert state["type"] == "function" + assert state["name"] == expected_tool.function.name + + # Parse accumulated arguments + arguments_str = 
state["arguments"] + assert arguments_str is not None + actual_args = json.loads(arguments_str) + expected_args = json.loads(expected_tool.function.arguments) + assert actual_args == expected_args diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py index 099e456aa486f..468c3799bd1f8 100644 --- a/vllm/entrypoints/openai/tool_parsers/__init__.py +++ b/vllm/entrypoints/openai/tool_parsers/__init__.py @@ -18,6 +18,7 @@ from .mistral_tool_parser import MistralToolParser from .phi4mini_tool_parser import Phi4MiniJsonToolParser from .pythonic_tool_parser import PythonicToolParser from .qwen3coder_tool_parser import Qwen3CoderToolParser +from .seed_oss_tool_parser import SeedOssToolParser from .step3_tool_parser import Step3ToolParser from .xlam_tool_parser import xLAMToolParser @@ -41,5 +42,6 @@ __all__ = [ "HunyuanA13BToolParser", "Glm4MoeModelToolParser", "Qwen3CoderToolParser", + "SeedOssToolParser", "Step3ToolParser", ] diff --git a/vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py new file mode 100644 index 0000000000000..69cf2e68f7c41 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py @@ -0,0 +1,676 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adapted from qwen3coder xml parser, All rights reserved. 
+# ruff: noqa: E501 + +import ast +import json +import uuid +from collections.abc import Sequence +from typing import Any, Optional, Union + +import regex as re + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + ChatCompletionToolsParam, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer + +logger = init_logger(__name__) + + +@ToolParserManager.register_module("seed_oss") +class SeedOssToolParser(ToolParser): + TOOL_CALL_START = "" + TOOL_CALL_END = "" + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + # --- streaming state --- + self._reset_streaming_state() + self.prev_tool_call_arr: list[dict] = [] + + self.tool_call_start_token: str = self.TOOL_CALL_START + self.tool_call_end_token: str = self.TOOL_CALL_END + # Sentinel tokens for streaming mode + self.tool_call_prefix: str = " or its closing tag.") + + tool_start_re = re.escape(self.tool_call_start_token) + tool_end_re = re.escape(self.tool_call_end_token) + + self.tool_call_complete_regex = re.compile( + rf"{tool_start_re}(.*?){tool_end_re}", re.DOTALL) + self.tool_call_regex = re.compile( + rf"{tool_start_re}(.*?){tool_end_re}|{tool_start_re}(.*?)$", + re.DOTALL) + + self.tool_call_function_regex = re.compile( + r"|| str: + """Generate a unique tool call ID.""" + return f"call_{uuid.uuid4().hex[:24]}" + + def _reset_streaming_state(self): + """Reset all streaming state.""" + self.current_tool_index = 0 + self.is_tool_call_started = False + self.header_sent = False + self.current_tool_id = -1 + self.current_function_name = None + self.current_param_name = None + self.current_param_value = "" + self.param_count = 0 + self.in_param = False + self.in_function = False + self.accumulated_text = "" + 
self.json_started = False + self.json_closed = False + + def _parse_xml_function_call( + self, function_call_str: str, + tools: Optional[list[ChatCompletionToolsParam]] + ) -> Optional[ToolCall]: + + def get_arguments_config(func_name: str) -> dict: + if tools is None: + return {} + for config in tools: + if not hasattr(config, "type") or not ( + hasattr(config, "function") + and hasattr(config.function, "name")): + continue + if (config.type == "function" + and config.function.name == func_name): + if not hasattr(config.function, "parameters"): + return {} + params = config.function.parameters + if isinstance(params, dict) and "properties" in params: + return params["properties"] + elif isinstance(params, dict): + return params + else: + return {} + logger.warning("Tool '%s' is not defined in the tools list.", + func_name) + return {} + + def convert_param_value(param_value: str, param_name: str, + param_config: dict, func_name: str) -> Any: + # Handle null value for any type + if param_value.lower() == "null": + return None + + if param_name not in param_config: + if param_config != {}: + logger.warning( + "Parsed parameter '%s' is not defined in " + "the tool parameters for tool '%s', " + "directly returning the string value.", param_name, + func_name) + return param_value + + if (isinstance(param_config[param_name], dict) + and "type" in param_config[param_name]): + param_type = str( + param_config[param_name]["type"]).strip().lower() + else: + param_type = "string" + if param_type in [ + "string", "str", "text", "varchar", "char", "enum" + ]: + return param_value + elif (param_type.startswith("int") or param_type.startswith("uint") + or param_type.startswith("long") + or param_type.startswith("short") + or param_type.startswith("unsigned")): + try: + param_value = int(param_value) # type: ignore + except (ValueError, TypeError): + logger.warning( + "Parsed value '%s' of parameter '%s' is not an integer in tool " + "'%s', degenerating to string.", param_value, 
+ param_name, func_name) + return param_value + elif param_type.startswith("num") or param_type.startswith( + "float"): + try: + float_param_value = float(param_value) + param_value = float_param_value if float_param_value - int( + float_param_value) != 0 else int( + float_param_value) # type: ignore + except (ValueError, TypeError): + logger.warning( + "Parsed value '%s' of parameter '%s' is not a float in tool " + "'%s', degenerating to string.", param_value, + param_name, func_name) + return param_value + elif param_type in ["boolean", "bool", "binary"]: + param_value = param_value.lower() + if param_value not in ["true", "false"]: + logger.warning( + "Parsed value '%s' of parameter '%s' is not a boolean " + "(`true` of `false`) in tool '%s', degenerating to false.", + param_value, param_name, func_name) + return param_value == "true" + else: + if param_type == "object" or param_type.startswith("dict"): + try: + param_value = json.loads(param_value) + return param_value + except (ValueError, TypeError, json.JSONDecodeError): + logger.warning( + "Parsed value '%s' of parameter '%s' is not a valid JSON " + "object in tool '%s', will try other methods to parse it.", + param_value, param_name, func_name) + try: + param_value = ast.literal_eval(param_value) + except (ValueError, SyntaxError): + logger.warning( + "Parsed value '%s' of parameter '%s' cannot be converted via " + "Python `ast.literal_eval()` in tool '%s', degenerating to string.", + param_value, param_name, func_name) + return param_value + + # Extract function name + end_index = function_call_str.index(">") + function_name = function_call_str[:end_index] + param_config = get_arguments_config(function_name) + parameters = function_call_str[end_index + 1:] + param_dict = {} + for match in self.tool_call_parameter_regex.findall(parameters): + match_text = match[0] if match[0] else match[1] + idx = match_text.index(">") + param_name = match_text[:idx] + param_value = str(match_text[idx + 1:]) + # Remove 
prefix and trailing \n + if param_value.startswith("\n"): + param_value = param_value[1:] + if param_value.endswith("\n"): + param_value = param_value[:-1] + + param_dict[param_name] = convert_param_value( + param_value, param_name, param_config, function_name) + return ToolCall( + type="function", + function=FunctionCall(name=function_name, + arguments=json.dumps(param_dict, + ensure_ascii=False)), + ) + + def _get_function_calls(self, model_output: str) -> list[str]: + # Find all tool calls + matched_ranges = self.tool_call_regex.findall(model_output) + raw_tool_calls = [ + match[0] if match[0] else match[1] for match in matched_ranges + ] + + # Back-off strategy if no tool_call tags found + if len(raw_tool_calls) == 0: + raw_tool_calls = [model_output] + + raw_function_calls = [] + for tool_call in raw_tool_calls: + raw_function_calls.extend( + self.tool_call_function_regex.findall(tool_call)) + + function_calls = [ + match[0] if match[0] else match[1] for match in raw_function_calls + ] + return function_calls + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + # Quick check to avoid unnecessary processing + if self.tool_call_prefix not in model_output: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + # Check if both think start and end tokens are present + if (self.think_start_token in model_output + and self.think_end_token in model_output): + # Find the position of think end token + think_end_index = model_output.find(self.think_end_token) + len( + self.think_end_token) + # Extract content after think end token + result_content = model_output[think_end_index:] + thinking_content = model_output[:think_end_index] + + try: + function_calls = self._get_function_calls(result_content) + if len(function_calls) == 0: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + tool_calls = [ + 
self._parse_xml_function_call(function_call_str, request.tools) + for function_call_str in function_calls + ] + + # Populate prev_tool_call_arr for serving layer to set finish_reason + self.prev_tool_call_arr.clear() # Clear previous calls + for tool_call in tool_calls: + if tool_call: + self.prev_tool_call_arr.append({ + "name": + tool_call.function.name, + "arguments": + tool_call.function.arguments, + }) + + # Extract content before tool calls + tool_call_start_index = result_content.find( + self.tool_call_start_token) + tool_call_start_index = ( + tool_call_start_index if tool_call_start_index >= 0 else + result_content.find(self.tool_call_prefix)) + content = thinking_content + result_content[:tool_call_start_index] + + return ExtractedToolCallInformation( + tools_called=(len(tool_calls) > 0), + tool_calls=tool_calls, + content=content if content else None, + ) + + except Exception: + logger.exception("Error in extracting tool call from response.") + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + # If no delta text, return None unless + # it's an EOS token after tool calls + if not delta_text: + # Check if this is an EOS token after all tool calls are complete + # We check for tool calls in the text even if is_tool_call_started + # is False because it might have been reset after processing all tools + if (delta_token_ids + and self.tool_call_end_token_id not in delta_token_ids): + # Count complete tool calls + complete_calls = len( + self.tool_call_complete_regex.findall(current_text)) + + # If we have completed tool calls and populated prev_tool_call_arr + if complete_calls > 0 and len(self.prev_tool_call_arr) > 0: + # Check if 
                    # all tool calls are closed
                    open_calls = current_text.count(
                        self.tool_call_start_token) - current_text.count(
                            self.tool_call_end_token)
                    if open_calls == 0:
                        # Return empty delta message to allow finish_reason
                        # processing
                        return DeltaMessage(content="")
            elif not self.is_tool_call_started and current_text:
                # This is a regular content response that's now complete
                return DeltaMessage(content="")
            return None

        # Check if this is the first call (reset state if needed)
        if not previous_text:
            self._reset_streaming_state()

        # Update accumulated text
        self.accumulated_text = current_text

        # Check if we need to advance to next tool
        if self.json_closed and not self.in_function:
            # Check if this tool call has ended
            tool_ends = current_text.count(self.tool_call_end_token)
            if tool_ends > self.current_tool_index:
                # This tool has ended, advance to next
                self.current_tool_index += 1
                self.header_sent = False
                self.param_count = 0
                self.json_started = False
                self.json_closed = False

                # Check if there are more tool calls
                if self.current_tool_index >= current_text.count(
                        self.tool_call_start_token):
                    # No more tool calls
                    self.is_tool_call_started = False
                # Continue processing next tool
                return None

        # Check if thinking has ended (by token id or by literal text)
        if (not self.is_thinking_end
                and (self.think_end_token_id in delta_token_ids
                     or self.think_end_token in delta_text)):
            self.is_thinking_end = True

        # If thinking hasn't ended yet, don't process any tool calls
        if not self.is_thinking_end:
            return DeltaMessage(content=delta_text)

        # Handle normal content before tool calls
        if not self.is_tool_call_started:
            # Check if tool call is starting
            if (self.tool_call_start_token_id in delta_token_ids
                    or self.tool_call_start_token in delta_text):
                self.is_tool_call_started = True
                # Return any content before the tool call
                if self.tool_call_start_token in delta_text:
                    content_before = delta_text[:delta_text.index(
                        self.tool_call_start_token)]
                    if content_before:
                        return DeltaMessage(content=content_before)
                return None
            else:
                # Check if we're between tool calls - skip whitespace
                if (current_text.rstrip().endswith(self.tool_call_end_token)
                        and delta_text.strip() == ""):
                    # We just ended a tool call, skip whitespace
                    return None
                # Normal content, no tool call
                return DeltaMessage(content=delta_text)

        # Check if we're between tool calls (waiting for next one)
        # Count tool calls we've seen vs processed
        tool_starts_count = current_text.count(self.tool_call_start_token)
        if self.current_tool_index >= tool_starts_count:
            # We're past all tool calls, shouldn't be here
            return None

        # We're in a tool call, find the current tool call portion
        # Need to find the correct tool call based on current_tool_index
        # Only process tool calls after think_end_token
        think_end_index = current_text.find(self.think_end_token) + len(
            self.think_end_token
        ) if self.think_end_token in current_text else 0
        tool_starts: list[int] = []
        idx = think_end_index
        while True:
            idx = current_text.find(self.tool_call_start_token, idx)
            if idx == -1:
                break
            tool_starts.append(idx)
            idx += len(self.tool_call_start_token)

        if self.current_tool_index >= len(tool_starts):
            # No more tool calls to process yet
            return None

        tool_start_idx = tool_starts[self.current_tool_index]
        # Find where this tool call ends (or current position if not ended yet)
        tool_end_idx = current_text.find(self.tool_call_end_token,
                                         tool_start_idx)
        if tool_end_idx == -1:
            tool_text = current_text[tool_start_idx:]
        else:
            tool_text = current_text[tool_start_idx:tool_end_idx +
                                     len(self.tool_call_end_token)]

        # Looking for function header
        if not self.header_sent:
            if self.tool_call_prefix in tool_text:
                func_start = tool_text.find(self.tool_call_prefix) + len(
                    self.tool_call_prefix)
                func_end = tool_text.find(">", func_start)

                if func_end != -1:
                    # Found complete function name
                    self.current_function_name = tool_text[func_start:func_end]
                    self.current_tool_id = self._generate_tool_call_id(
                    )  # type: ignore
                    self.header_sent = True
                    self.in_function = True

                    # IMPORTANT: Add to prev_tool_call_arr immediately when we
                    # detect a tool call
                    # This ensures finish_reason="tool_calls" even if parsing
                    # isn't complete
                    already_added = any(
                        tool.get("name") == self.current_function_name
                        for tool in self.prev_tool_call_arr)
                    if not already_added:
                        self.prev_tool_call_arr.append({
                            "name": self.current_function_name,
                            "arguments":
                            "{}",  # Placeholder, will be updated later
                        })

                    # Send header with function info
                    return DeltaMessage(tool_calls=[
                        DeltaToolCall(
                            index=self.current_tool_index,
                            id=self.current_tool_id,
                            function=DeltaFunctionCall(
                                name=self.current_function_name, arguments=""),
                            type="function",
                        )
                    ])
            return None

        # We've sent header, now handle function body
        if self.in_function:
            # Send opening brace if not sent yet
            if (not self.json_started
                    and self.parameter_prefix not in delta_text):
                self.json_started = True
                return DeltaMessage(tool_calls=[
                    DeltaToolCall(
                        index=self.current_tool_index,
                        function=DeltaFunctionCall(arguments="{"),
                    )
                ])

            # Make sure json_started is set if we're processing parameters
            if not self.json_started:
                self.json_started = True

            # Check for function end in accumulated text
            if not self.json_closed and self.function_end_token in tool_text:
                # Close JSON
                self.json_closed = True

                # Extract the complete tool call to update prev_tool_call_arr
                # with final arguments
                # Find the function content
                func_start = tool_text.find(self.tool_call_prefix) + len(
                    self.tool_call_prefix)
                func_content_end = tool_text.find(self.function_end_token,
                                                  func_start)
                if func_content_end != -1:
                    func_content = tool_text[func_start:func_content_end]
                    # Parse to get the complete arguments
                    try:
                        parsed_tool = self._parse_xml_function_call(
                            func_content,
                            request.tools if request else None)
                        if parsed_tool:
                            # Update existing entry in prev_tool_call_arr with
                            # complete arguments
                            for i, tool in enumerate(self.prev_tool_call_arr):
                                if tool.get(
                                        "name") == parsed_tool.function.name:
                                    self.prev_tool_call_arr[i]["arguments"] = (
                                        parsed_tool.function.arguments)
                                    break
                    except Exception:
                        logger.warning(
                            "Failed to parse tool arguments during streaming.",
                            exc_info=True)

                result = DeltaMessage(tool_calls=[
                    DeltaToolCall(
                        index=self.current_tool_index,
                        function=DeltaFunctionCall(arguments="}"),
                    )
                ])

                # Reset state for next tool
                # NOTE(review): `self.json_closed` was already set True a few
                # lines above; this second assignment is redundant but
                # harmless — confirm before removing.
                self.in_function = False
                self.json_closed = True

                return result

            # Look for parameters
            # Count how many complete parameters we have processed
            complete_params = tool_text.count(self.parameter_end_token)

            # Check if we should start a new parameter
            if not self.in_param and self.param_count < complete_params:
                # Find the unprocessed parameter
                # Count parameter starts
                param_starts = []
                idx = 0
                while True:
                    idx = tool_text.find(self.parameter_prefix, idx)
                    if idx == -1:
                        break
                    param_starts.append(idx)
                    idx += len(self.parameter_prefix)

                if len(param_starts) > self.param_count:
                    # Process the next parameter
                    param_idx = param_starts[self.param_count]
                    param_start = param_idx + len(self.parameter_prefix)
                    remaining = tool_text[param_start:]

                    if ">" in remaining:
                        # We have the complete parameter name
                        name_end = remaining.find(">")
                        self.current_param_name = remaining[:name_end]

                        # Find the parameter value
                        value_start = param_start + name_end + 1
                        value_text = tool_text[value_start:]
                        if value_text.startswith("\n"):
                            value_text = value_text[1:]

                        # Find where this parameter ends
                        param_end_idx = value_text.find(
                            self.parameter_end_token)
                        if param_end_idx != -1:
                            # Complete parameter found
                            param_value = value_text[:param_end_idx]
                            if param_value.endswith("\n"):
                                param_value = param_value[:-1]

                            # Build complete JSON fragment for this parameter.
                            # json.dumps(...)[1:-1] yields the JSON-escaped
                            # value without its surrounding quotes.
                            # NOTE(review): every value is emitted as a JSON
                            # string here, even for non-string parameter
                            # types — presumably reconciled by the final
                            # _parse_xml_function_call pass; verify.
                            if self.param_count == 0:
                                json_fragment = (
                                    '"' + self.current_param_name + '": "' +
                                    json.dumps(param_value)[1:-1] + '"')
                            else:
                                json_fragment = (
                                    ', "' + self.current_param_name + '": "' +
                                    json.dumps(param_value)[1:-1] + '"')

                            self.param_count += 1

                            return DeltaMessage(tool_calls=[
                                DeltaToolCall(
                                    index=self.current_tool_index,
                                    function=DeltaFunctionCall(
                                        arguments=json_fragment),
                                )
                            ])

            # Continue parameter value
            if self.in_param:
                if self.parameter_end_token in delta_text:
                    # End of parameter
                    end_idx = delta_text.find(self.parameter_end_token)
                    value_chunk = delta_text[:end_idx]

                    # Skip past > if at start
                    if not self.current_param_value and ">" in value_chunk:
                        gt_idx = value_chunk.find(">")
                        value_chunk = value_chunk[gt_idx + 1:]

                    if not self.current_param_value and value_chunk.startswith(
                            "\n"):
                        value_chunk = value_chunk[1:]

                    # Calculate incremental JSON: escape the full value and
                    # emit only the suffix not already streamed.
                    full_value = self.current_param_value + value_chunk
                    prev_escaped = (json.dumps(self.current_param_value)[1:-1]
                                    if self.current_param_value else "")
                    full_escaped = json.dumps(full_value)[1:-1]
                    delta_escaped = full_escaped[len(prev_escaped):]

                    self.in_param = False
                    self.current_param_value = ""

                    return DeltaMessage(tool_calls=[
                        DeltaToolCall(
                            index=self.current_tool_index,
                            function=DeltaFunctionCall(
                                arguments=delta_escaped + '"'),
                        )
                    ])
                else:
                    # Continue accumulating value
                    value_chunk = delta_text

                    # Handle first chunk after param name
                    if not self.current_param_value and ">" in value_chunk:
                        gt_idx = value_chunk.find(">")
                        value_chunk = value_chunk[gt_idx + 1:]

                    if not self.current_param_value and value_chunk.startswith(
                            "\n"):
                        value_chunk = value_chunk[1:]

                    if value_chunk:
                        # Stream the escaped delta
                        prev_escaped = (json.dumps(
                            self.current_param_value)[1:-1]
                                        if self.current_param_value else "")
                        self.current_param_value += value_chunk
                        full_escaped = json.dumps(
                            self.current_param_value)[1:-1]
                        delta_escaped = full_escaped[len(prev_escaped):]

                        if delta_escaped:
                            return DeltaMessage(tool_calls=[
                                DeltaToolCall(
                                    index=self.current_tool_index,
                                    function=DeltaFunctionCall(
                                        arguments=delta_escaped),
                                )
                            ])

        return None
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 28d7e93af91a9..465c25f094806 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -130,6 +130,7 @@ _TEXT_GENERATION_MODELS = {
     "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
     "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
     "RWForCausalLM": ("falcon", "FalconForCausalLM"),
+    "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
     "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
     "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
     "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
diff --git a/vllm/model_executor/models/seed_oss.py b/vllm/model_executor/models/seed_oss.py
new file mode 100644
index 0000000000000..34a87a6a69a39
--- /dev/null
+++ b/vllm/model_executor/models/seed_oss.py
@@ -0,0 +1,487 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2025 The Seed team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only SeedOss model compatible with HuggingFace weights.""" +from collections.abc import Iterable +from typing import Optional, Union + +import torch +from torch import nn +from transformers import PretrainedConfig as SeedOssConfig + +from vllm.attention import Attention, AttentionType +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +logger = init_logger(__name__) + + +class 
SeedOssMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class SeedOssAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + rope_scaling: Optional[tuple] = None, + prefix: str = "", + attn_type: str = AttentionType.DECODER, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + self.head_dim = head_dim + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position, + base=self.rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + attn_type=attn_type, + prefix=f"{prefix}.attn", + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class SeedOssDecoderLayer(nn.Module): + + def __init__( + self, + config: SeedOssConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 1000000) + rope_scaling = getattr(config, "rope_scaling", None) + + # By default, SeedOss uses causal attention as it is a + # decoder-only model. 
+ # You can override the HF config with `is_causal=False` to enable + # bidirectional attention, which is used in some embedding models + if getattr(config, "is_causal", True): + attn_type = AttentionType.DECODER + else: + attn_type = AttentionType.ENCODER_ONLY + + self.self_attn = SeedOssAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + head_dim=config.head_dim, + rope_theta=rope_theta, + cache_config=cache_config, + quant_config=quant_config, + rope_scaling=rope_scaling, + prefix=f"{prefix}.self_attn", + attn_type=attn_type, + ) + self.mlp = SeedOssMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + }) +class SeedOssModel(nn.Module): + + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = "", + decoder_layer_type: type[nn.Module] = SeedOssDecoderLayer): + 
super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + # TODO (@robertgshaw2): see if this can be moved out + if (cache_config.sliding_window is not None + and hasattr(config, "max_window_layers")): + assert config.max_window_layers == config.num_hidden_layers, ( + "Sliding window for some but all layers is not supported. " + "This model uses sliding window but `max_window_layers` = {} " + "is less than `num_hidden_layers` = {}. Please open an issue " + "to discuss this feature.".format( + config.max_window_layers, + config.num_hidden_layers, + )) + + self.config = config + self.quant_config = quant_config + self.vocab_size = config.vocab_size + + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens", + ) + else: + self.embed_tokens = PPMissingLayer() + + # Use the provided decoder layer type or default to SeedDecoderLayer + decoder_layer_type = decoder_layer_type or SeedOssDecoderLayer + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: decoder_layer_type(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.layers", + ) + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + 
inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class SeedOssForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + + self.quant_config = quant_config + self.model = SeedOssModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + if get_pp_group().is_last_rank: + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "lm_head")) + else: + self.lm_head = PPMissingLayer() + + self.logits_processor = LogitsProcessor(config.vocab_size) + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + 
positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) From 17373dcd93ca60554d72cef4e159e70abbfd15af Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Thu, 21 Aug 2025 22:05:59 -0700 Subject: [PATCH 02/38] [Attention] Refactor AttentionMetadata Preparation for Encoder-only Models (#23154) Signed-off-by: Chen Zhang --- tests/v1/worker/test_gpu_model_runner.py | 11 +- .../layers/chunked_local_attention.py | 29 +-- .../layers/encoder_only_attention.py | 86 +++++++ vllm/model_executor/models/bert.py | 17 +- vllm/model_executor/models/bert_with_rope.py | 17 +- vllm/model_executor/models/llama.py | 6 +- vllm/model_executor/models/modernbert.py | 14 +- vllm/model_executor/models/qwen2.py | 5 +- vllm/v1/attention/backends/utils.py | 32 +-- vllm/v1/kv_cache_interface.py | 8 + vllm/v1/worker/gpu_model_runner.py | 211 ++++++------------ vllm/v1/worker/utils.py | 4 + 12 files changed, 226 insertions(+), 214 deletions(-) create mode 100644 vllm/attention/layers/encoder_only_attention.py diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 4bcc63f293e03..b9b2314ce573f 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -680,6 +680,7 @@ def 
test_init_kv_cache_with_kv_sharing_valid(): kv_cache_spec[layer_0].page_size_bytes runner.initialize_kv_cache(kv_cache_config) + kv_cache_config_after_init = runner.kv_cache_config layer_0_kv = vllm_ctx[layer_0].kv_cache[0] layer_1_kv = vllm_ctx[layer_1].kv_cache[0] @@ -687,10 +688,12 @@ def test_init_kv_cache_with_kv_sharing_valid(): assert id(layer_1_kv) == id(layer_0_kv) # check layer 1 added to kv cache group's layer names - assert len(kv_cache_config.kv_cache_groups) == 1 - assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2 - assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0 - assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1 + assert len(kv_cache_config_after_init.kv_cache_groups) == 1 + assert len(kv_cache_config_after_init.kv_cache_groups[0].layer_names) == 2 + assert kv_cache_config_after_init.kv_cache_groups[0].layer_names[ + 0] == layer_0 + assert kv_cache_config_after_init.kv_cache_groups[0].layer_names[ + 1] == layer_1 def test_hybrid_attention_mamba_tensor_shapes(monkeypatch): diff --git a/vllm/attention/layers/chunked_local_attention.py b/vllm/attention/layers/chunked_local_attention.py index 892077ba91e07..087c5004bde06 100644 --- a/vllm/attention/layers/chunked_local_attention.py +++ b/vllm/attention/layers/chunked_local_attention.py @@ -6,12 +6,13 @@ from typing import List, Optional import torch from vllm import envs -from vllm.attention.backends.abstract import AttentionBackend +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadata) from vllm.attention.selector import get_attn_backend from vllm.config import CacheConfig, QuantizationConfig from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata, make_local_attention_virtual_batches, - subclass_attention_backend, subclass_attention_metadata_builder) + subclass_attention_backend) from ..layer import Attention @@ -24,21 +25,23 @@ def create_chunked_local_attention_backend( ) -> type[AttentionBackend]: 
prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_" - def build_preprocess_fn(cm: CommonAttentionMetadata): - return make_local_attention_virtual_batches(attention_chunk_size, cm, - block_size) + underlying_builder = underlying_attn_backend.get_builder_cls() + + class ChunkedLocalAttentionBuilder(underlying_builder): # type: ignore + + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> AttentionMetadata: + common_attn_metadata = make_local_attention_virtual_batches( + attention_chunk_size, common_attn_metadata, block_size) + return super().build(common_prefix_len, common_attn_metadata, + fast_build) - # Dynamically create a new attention backend that wraps the - # underlying attention backend but applies - # `make_local_attention_virtual_batches` before calling `build(...)` - builder_cls = subclass_attention_metadata_builder( - name_prefix=prefix, - builder_cls=underlying_attn_backend.get_builder_cls(), - build_preprocess_fn=build_preprocess_fn) attn_backend = subclass_attention_backend( name_prefix=prefix, attention_backend_cls=underlying_attn_backend, - builder_cls=builder_cls) + builder_cls=ChunkedLocalAttentionBuilder) return attn_backend diff --git a/vllm/attention/layers/encoder_only_attention.py b/vllm/attention/layers/encoder_only_attention.py new file mode 100644 index 0000000000000..7b3dcbd823c06 --- /dev/null +++ b/vllm/attention/layers/encoder_only_attention.py @@ -0,0 +1,86 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import functools +from copy import copy +from typing import Optional + +import torch +from transformers import CacheConfig + +from vllm import envs +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadata, AttentionType) +from vllm.attention.layer import Attention +from vllm.attention.selector import get_attn_backend +from vllm.v1.attention.backends.utils 
import (CommonAttentionMetadata, + subclass_attention_backend) + + +@functools.lru_cache +def create_encoder_only_attention_backend( + underlying_attn_backend: AttentionBackend, ) -> type[AttentionBackend]: + prefix = "EncoderOnlyAttention_" + underlying_builder = underlying_attn_backend.get_builder_cls() + + class EncoderOnlyAttentionBuilder(underlying_builder): # type: ignore + + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> AttentionMetadata: + new_common_attn_metadata = copy(common_attn_metadata) + new_common_attn_metadata.causal = False + return super().build(common_prefix_len, new_common_attn_metadata, + fast_build) + + attn_backend = subclass_attention_backend( + name_prefix=prefix, + attention_backend_cls=underlying_attn_backend, + builder_cls=EncoderOnlyAttentionBuilder) + + return attn_backend + + +class EncoderOnlyAttention(Attention): + """ + Encoder attention is a special case that doesn't need a KV Cache. 
+ """ + + def __init__(self, + num_heads: int, + head_size: int, + scale: float, + cache_config: Optional[CacheConfig] = None, + attn_type: Optional[str] = None, + **kwargs): + dtype = torch.get_default_dtype() + + if cache_config is not None: + kv_cache_dtype = cache_config.cache_dtype + block_size = cache_config.block_size + else: + kv_cache_dtype = "auto" + block_size = 16 + + if envs.VLLM_USE_V1: + underlying_attn_backend = get_attn_backend(head_size, dtype, + kv_cache_dtype, + block_size) + + attn_backend = create_encoder_only_attention_backend( + underlying_attn_backend) + else: + # in v0 encoder only attention is handled inside the backends + attn_backend = None + + if attn_type is not None: + assert attn_type == AttentionType.ENCODER_ONLY, \ + "EncoderOnlyAttention only supports AttentionType.ENCODER_ONLY" + + super().__init__(num_heads=num_heads, + head_size=head_size, + scale=scale, + cache_config=cache_config, + attn_backend=attn_backend, + attn_type=AttentionType.ENCODER_ONLY, + **kwargs) diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index 2bd5eb5bb7aa8..22b6c4401213c 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -8,7 +8,7 @@ import torch from torch import nn from transformers import BertConfig -from vllm.attention import Attention, AttentionType +from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, PoolerConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size @@ -239,14 +239,13 @@ class BertSelfAttention(nn.Module): quant_config=quant_config, prefix=f"{prefix}.qkv_proj") - self.attn = Attention(num_heads=self.num_heads, - head_size=self.head_dim, - scale=self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - 
attn_type=AttentionType.ENCODER_ONLY) + self.attn = EncoderOnlyAttention(num_heads=self.num_heads, + head_size=self.head_dim, + scale=self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, diff --git a/vllm/model_executor/models/bert_with_rope.py b/vllm/model_executor/models/bert_with_rope.py index e18b7b7ffabab..129450927e564 100644 --- a/vllm/model_executor/models/bert_with_rope.py +++ b/vllm/model_executor/models/bert_with_rope.py @@ -7,7 +7,7 @@ import torch from torch import nn from transformers import PretrainedConfig -from vllm.attention import Attention, AttentionType +from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (divide, get_tensor_model_parallel_rank, @@ -119,14 +119,13 @@ class BertWithRopeAttention(nn.Module): self.rotary_emb = get_rope(**rotary_kwargs) - self.attn = Attention(num_heads=self.num_heads, - head_size=self.head_dim, - scale=self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - attn_type=AttentionType.ENCODER_ONLY) + self.attn = EncoderOnlyAttention(num_heads=self.num_heads, + head_size=self.head_dim, + scale=self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") self.out_proj = RowParallelLinear(input_size=hidden_size, output_size=hidden_size, diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 24cd448d8361f..f99f1c3643fd4 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -31,6 +31,7 @@ from torch import nn from transformers import LlamaConfig from vllm.attention import Attention, AttentionType +from 
vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -173,7 +174,10 @@ class LlamaAttention(nn.Module): if is_sliding: sliding_window = config.sliding_window - self.attn = Attention( + attn_cls = (EncoderOnlyAttention + if attn_type == AttentionType.ENCODER_ONLY else Attention) + + self.attn = attn_cls( self.num_heads, self.head_dim, self.scaling, diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py index c6e84e2d4e040..72290bf2ee29f 100644 --- a/vllm/model_executor/models/modernbert.py +++ b/vllm/model_executor/models/modernbert.py @@ -7,7 +7,7 @@ import torch from torch import nn from transformers import ModernBertConfig -from vllm.attention import Attention, AttentionType +from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size @@ -104,12 +104,12 @@ class ModernBertAttention(nn.Module): head_size=self.head_dim, dim=self.head_dim, base=rope_theta) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - prefix=f"{layer_id}.attn", - attn_type=AttentionType.ENCODER_ONLY, - per_layer_sliding_window=sliding_window) + self.attn = EncoderOnlyAttention( + self.num_heads, + self.head_dim, + self.scaling, + prefix=f"{layer_id}.attn", + per_layer_sliding_window=sliding_window) self.Wo = RowParallelLinear(config.hidden_size, config.hidden_size, bias=config.attention_bias) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 7304fbf120ccd..b6a1d2db303c7 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -32,6 +32,7 @@ from torch import nn 
from transformers import Qwen2Config from vllm.attention import Attention, AttentionType +from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -159,7 +160,9 @@ class Qwen2Attention(nn.Module): rope_scaling=rope_scaling, dual_chunk_attention_config=dual_chunk_attention_config, ) - self.attn = Attention( + attn_cls = (EncoderOnlyAttention + if attn_type == AttentionType.ENCODER_ONLY else Attention) + self.attn = attn_cls( self.num_heads, self.head_dim, self.scaling, diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 57c4d436c5b6b..39bdbe125635b 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -5,8 +5,7 @@ import enum import functools from abc import abstractmethod from dataclasses import dataclass, make_dataclass -from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Generic, Optional, - TypeVar) +from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar import numpy as np import torch @@ -543,35 +542,6 @@ def make_local_attention_virtual_batches( ) -def subclass_attention_metadata_builder( - name_prefix: str, - builder_cls: type[AttentionMetadataBuilder[M]], - build_preprocess_fn: Callable[[CommonAttentionMetadata], - CommonAttentionMetadata], -) -> type[AttentionMetadataBuilder[M]]: - """ - Return a new subclass of `builder_cls` whose .build(...) method - first calls build_preprocess_fn(common_attn_metadata) on the metadata. 
- """ - name: str = name_prefix + builder_cls.__name__ # type: ignore - - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False): - return builder_cls.build(self, common_prefix_len, - build_preprocess_fn(common_attn_metadata), - fast_build) - - Wrapped = type( - name, - (builder_cls, ), # inherit from the original - { - "build": build, - }) - return Wrapped # type: ignore - - def subclass_attention_backend( name_prefix: str, attention_backend_cls: type[AttentionBackend], builder_cls: type[AttentionMetadataBuilder[M]] diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 429416afa2483..ed8e0bf798988 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -203,6 +203,14 @@ class MambaSpec(KVCacheSpec): return self.page_size_bytes +@dataclass(frozen=True) +class EncoderOnlyAttentionSpec(AttentionSpec): + + def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: + # Encoder-only layers do not need KV cache + return 0 + + @dataclass class KVCacheTensor: """ diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 870aca41ec2ab..d520b71de3ff9 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -8,6 +8,7 @@ import time from collections import defaultdict from collections.abc import Iterator from contextlib import contextmanager +from copy import deepcopy from typing import TYPE_CHECKING, Any, Optional, Union, cast import numpy as np @@ -62,9 +63,10 @@ from vllm.v1.attention.backends.utils import ( from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher from vllm.v1.kv_cache_interface import (AttentionSpec, ChunkedLocalAttentionSpec, + EncoderOnlyAttentionSpec, FullAttentionSpec, KVCacheConfig, - KVCacheSpec, MambaSpec, - SlidingWindowSpec) + KVCacheGroupSpec, KVCacheSpec, + MambaSpec, SlidingWindowSpec) from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, 
DraftTokenIds, LogprobsTensors, ModelRunnerOutput) from vllm.v1.pool.metadata import PoolingMetadata @@ -136,7 +138,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): cache_config.cache_dtype] self.is_pooling_model = model_config.pooler_config is not None - self.is_encoder_only_model = False self.is_multimodal_raw_input_supported = ( model_config.is_multimodal_raw_input_supported) self.max_model_len = model_config.max_model_len @@ -345,6 +346,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.reorder_batch_threshold: Optional[int] = None + # Attention layers that are only in the KVCacheConfig of the runner + # (e.g., KV sharing, encoder-only attention), but not in the + # KVCacheConfig of the scheduler. + self.runner_only_attn_layers: set[str] = set() + # Cached outputs. self._draft_token_ids: Optional[Union[list[list[int]], torch.Tensor]] = None @@ -834,23 +840,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): attn_metadata: dict[str, Any] = {} - # Prepare encoder attention metadata separately - # (encoder layers are not in KV cache groups) - if self.is_encoder_only_model: - - per_layer_metadata = \ - self._build_encoder_only_attn_metadata( - scheduler_output) - - # Add encoder attention metadata for all encoder layers - attention_layers = get_layers_from_vllm_config( - self.vllm_config, Attention) - for layer_name, attn_module in attention_layers.items(): - if attn_module.attn_type == AttentionType.ENCODER_ONLY: - common_attn_metadata, encoder_attn_metadata =\ - per_layer_metadata[layer_name] - attn_metadata[layer_name] = encoder_attn_metadata - # Used in the below loop. 
query_start_loc_cpu = self.query_start_loc_cpu[:num_reqs + 1] seq_lens_cpu = self.seq_lens_cpu[:num_reqs] @@ -863,13 +852,33 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): for kv_cache_group_id, kv_cache_group_spec in enumerate( self.kv_cache_config.kv_cache_groups): - blk_table = self.input_batch.block_table[kv_cache_group_id] - blk_table_tensor = blk_table.get_device_tensor()[:num_reqs] - slot_mapping = blk_table.slot_mapping[:total_num_scheduled_tokens] + if isinstance(kv_cache_group_spec.kv_cache_spec, + EncoderOnlyAttentionSpec): + # Encoder-only layers do not have KV cache, so we need to + # create a dummy block table and slot mapping for them. + blk_table_tensor = torch.zeros( + (num_reqs, 1), + dtype=torch.int32, + pin_memory=self.pin_memory, + device="cpu").to(self.device, non_blocking=True) + slot_mapping = torch.zeros((total_num_scheduled_tokens, ), + dtype=torch.int32, + pin_memory=self.pin_memory, + device="cpu").to(self.device, + non_blocking=True) + num_common_prefix_blocks = 0 + else: + blk_table = self.input_batch.block_table[kv_cache_group_id] + blk_table_tensor = blk_table.get_device_tensor()[:num_reqs] + slot_mapping = blk_table.slot_mapping[: + total_num_scheduled_tokens] - # Fill unused with -1. Needed for reshape_and_cache in full cuda - # graph mode. - blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1) + # Fill unused with -1. Needed for reshape_and_cache in full cuda + # graph mode. + blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1) + num_common_prefix_blocks = ( + scheduler_output. + num_common_prefix_blocks[kv_cache_group_id]) common_attn_metadata = CommonAttentionMetadata( query_start_loc=query_start_loc, @@ -897,8 +906,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if self.cascade_attn_enabled: common_prefix_len = self._compute_cascade_attn_prefix_len( num_scheduled_tokens, - scheduler_output. 
- num_common_prefix_blocks[kv_cache_group_id], + num_common_prefix_blocks, kv_cache_group_spec.kv_cache_spec, builder, ) @@ -2812,49 +2820,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Calculate reorder batch threshold (if neeeded) self.calculate_reorder_batch_threshold() - if len(self.attn_groups) > 0: - return - - # Check if model is encoder-only - block_size = self.vllm_config.cache_config.block_size - use_mla = self.vllm_config.model_config.use_mla - attn_specs: dict[AttentionSpec, list[str]] = defaultdict(list) - for layer_name, attn_module in attn_layers.items(): - - if attn_module.attn_type == AttentionType.ENCODER_ONLY: - if attn_module.sliding_window is None: - attn_spec: AttentionSpec = FullAttentionSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, - use_mla=use_mla) - else: - attn_spec = SlidingWindowSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, - sliding_window=attn_module.sliding_window, - use_mla=use_mla) - attn_specs[attn_spec].append(layer_name) - - else: - raise ValueError("Expected only encoder-only layers") - - if len(attn_specs) > 0: - total_layers = 0 - for attn_spec, layer_names in attn_specs.items(): - - attn_backends = get_attn_backends_for_layers(layer_names) - total_layers += len(layer_names) - - self.attn_groups.append( - create_attn_groups(attn_backends, attn_spec)) - assert total_layers == len(attn_layers), \ - "All or none of the layers are expected to be encoder-only" - self.is_encoder_only_model = True - def initialize_cudagraph_capture(self) -> None: min_cg_support = AttentionCGSupport.ALWAYS min_cg_builder_name = None @@ -3002,7 +2967,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): layer_names = set() for group in kv_cache_config.kv_cache_groups: - layer_names.update(group.layer_names) + for 
layer_name in group.layer_names: + if layer_name in self.runner_only_attn_layers: + continue + layer_names.add(layer_name) assert layer_names == set(kv_cache_raw_tensors.keys( )), "Some layers are not correctly initialized" return kv_cache_raw_tensors @@ -3040,6 +3008,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator(): attn_backend = group.backend for layer_name in group.layer_names: + if layer_name in self.runner_only_attn_layers: + continue raw_tensor = kv_cache_raw_tensors[layer_name] assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0 num_blocks = (raw_tensor.numel() // @@ -3161,6 +3131,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): kv_cache_config.kv_cache_groups, kv_caches, self.attn_groups, + self.runner_only_attn_layers, ) attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) @@ -3185,8 +3156,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): kv_cache_config: Configuration for the KV cache, including the KV cache size of each layer """ + kv_cache_config = deepcopy(kv_cache_config) self.kv_cache_config = kv_cache_config self.may_reinitialize_input_batch(kv_cache_config) + self.may_add_encoder_only_layers_to_kv_cache_config() self.initialize_attn_backend(kv_cache_config) kv_caches = self.initialize_kv_cache_tensors(kv_cache_config) @@ -3199,6 +3172,33 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if has_kv_transfer_group(): get_kv_transfer_group().register_kv_caches(kv_caches) + def may_add_encoder_only_layers_to_kv_cache_config(self) -> None: + """ + Add encoder-only layers to the KV cache config. 
+ """ + block_size = self.vllm_config.cache_config.block_size + use_mla = self.vllm_config.model_config.use_mla + encoder_only_attn_specs: dict[AttentionSpec, + list[str]] = defaultdict(list) + attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) + for layer_name, attn_module in attn_layers.items(): + if attn_module.attn_type == AttentionType.ENCODER_ONLY: + attn_spec = EncoderOnlyAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype, + use_mla=use_mla) + encoder_only_attn_specs[attn_spec].append(layer_name) + self.runner_only_attn_layers.add(layer_name) + if len(encoder_only_attn_specs) > 0: + assert len( + encoder_only_attn_specs + ) == 1, "Only support one encoder-only attention spec now" + spec, layer_names = encoder_only_attn_specs.popitem() + self.kv_cache_config.kv_cache_groups.append( + KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec)) + def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: """ Generates the KVCacheSpec by parsing the kv cache format from each @@ -3287,70 +3287,3 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): mamba_type=mamba_module.mamba_type) return kv_cache_spec - - def _build_encoder_only_attn_metadata( - self, scheduler_output: "SchedulerOutput") -> \ - dict[str, tuple[CommonAttentionMetadata, Any]]: - """Prepare encoder attention metadata for encoder-only models. - - Args: - scheduler_output: Scheduler output - - Returns: - dict[str, Any]: Encoder attention metadata - """ - num_reqs = self.input_batch.num_reqs - total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - - # Get the number of scheduled tokens for each request. 
- req_ids = self.input_batch.req_ids - tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] - max_num_scheduled_tokens = max(tokens) - - dummy_block_table = torch.zeros((num_reqs, 1), - dtype=torch.int32, - pin_memory=self.pin_memory, - device="cpu").to(self.device, - non_blocking=True) - dummy_slot_mapping = torch.zeros((total_num_scheduled_tokens, ), - dtype=torch.int32, - pin_memory=self.pin_memory, - device="cpu").to(self.device, - non_blocking=True) - - group_metadata = dict[str, tuple[CommonAttentionMetadata, Any]]() - - for attn_group_list in self.attn_groups: - - assert len(attn_group_list) == 1 - attn_group = attn_group_list[0] - - # Use the first attention metadata builder - # to create encoder attention metadata - builder = attn_group.metadata_builder - - common_metadata = CommonAttentionMetadata( - query_start_loc=self.query_start_loc[:num_reqs + 1], - query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1], - seq_lens=self.seq_lens[:num_reqs], - seq_lens_cpu=self.seq_lens_cpu[:num_reqs], - num_computed_tokens_cpu=self.input_batch. 
- num_computed_tokens_cpu_tensor[:num_reqs], - num_reqs=num_reqs, - num_actual_tokens=total_num_scheduled_tokens, - max_query_len=max_num_scheduled_tokens, - max_seq_len=self.seq_lens_cpu[:num_reqs].max().item(), - block_table_tensor=dummy_block_table, - slot_mapping=dummy_slot_mapping, - causal=False, - ) - - metadata = builder.build( - common_prefix_len=0, # No cascade for encoder - common_attn_metadata=common_metadata, - ) - - for layer_name in attn_group.layer_names: - group_metadata[layer_name] = (common_metadata, metadata) - - return group_metadata diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index c7ccd2e254976..ffc1a11bc3ba1 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -204,6 +204,7 @@ def initialize_kv_cache_for_kv_sharing( kv_caches: dict[str, torch.Tensor], # Optional for now to avoid breaking TPU attn_groups: Optional[list[list[AttentionGroup]]] = None, + runner_only_attn_layers: Optional[set[str]] = None, ) -> None: """ Sets up KV cache sharing by reusing the allocated KV caches in `kv_caches` @@ -250,6 +251,9 @@ def initialize_kv_cache_for_kv_sharing( attn_groups[kv_cache_group_idx][attn_group_idx].layer_names.append( layer_name) + if runner_only_attn_layers is not None: + runner_only_attn_layers.add(layer_name) + def bind_kv_cache( kv_caches: dict[str, torch.Tensor], From 53415653ff24be03e7c90f5b42ef9cb3f72aad71 Mon Sep 17 00:00:00 2001 From: Flora Feng <4florafeng@gmail.com> Date: Thu, 21 Aug 2025 22:30:48 -0700 Subject: [PATCH 03/38] [P/D][Nixl] Make kv cache register compatible with hybrid memory allocator (#23079) Signed-off-by: sfeng33 <4florafeng@gmail.com> --- .../kv_connector/unit/test_nixl_connector.py | 86 +++++++++- .../kv_transfer/kv_connector/v1/base.py | 4 +- .../kv_connector/v1/nixl_connector.py | 155 +++++++----------- 3 files changed, 150 insertions(+), 95 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 
e6859ea738277..040b44dc5d2ca 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -14,6 +14,7 @@ from unittest.mock import patch import pytest import ray +import torch from vllm import LLM from vllm.config import KVTransferConfig @@ -22,6 +23,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( NixlConnectorWorker) from vllm.forward_context import ForwardContext from vllm.sampling_params import SamplingParams +from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend from .utils import create_request, create_scheduler, create_vllm_config @@ -98,7 +100,6 @@ class FakeNixlWrapper: def set_cycles_before_xfer_done(self, cycles: int): """Set the number of cycles before a transfer is considered done.""" - self._cycles_before_xfer_done = cycles @contextlib.contextmanager @@ -562,3 +563,86 @@ def _run_abort_timeout_test(llm_kwargs: dict, timeout: int): sampling_params) # Request-0 times out and is cleared! assert '0' not in req_to_blocks + + +def test_register_kv_caches(dist_init): + """ + Test that register_kv_caches() properly calls nixl_wrapper methods with + correct data. + + This test verifies: + 1. nixl_wrapper.get_reg_descs() is called with caches_data containing + tensor metadata + 2. 
nixl_wrapper.get_xfer_descs() is called with blocks_data containing + block layout info + """ + + vllm_config = create_vllm_config() + + # Create test kv cache tensors using proper backend shape + kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(num_blocks=2, + block_size=16, + num_kv_heads=4, + head_size=64) + shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) + unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) + kv_caches = { + "layer0": shared_tensor, + "layer1": unique_tensor, + "layer2": shared_tensor, + } + + # Store tensor info for validation + expected_tensor_size = shared_tensor[0].element_size( + ) * shared_tensor[0].numel() + expected_base_addrs = [ + shared_tensor[0].data_ptr(), shared_tensor[1].data_ptr(), + unique_tensor[0].data_ptr(), unique_tensor[1].data_ptr() + ] + + with patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper") as mock_nixl_wrapper, \ + patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Event"), \ + patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread"): # noqa: E501 + + # Create connector + connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + connector.connector_worker = FakeNixlConnectorWorker( + vllm_config, connector.engine_id, hand_shake_latency=0) + + # Get the mock instance + mock_wrapper_instance = mock_nixl_wrapper.return_value + connector.connector_worker.nixl_wrapper = mock_wrapper_instance + + # Execute register_kv_caches + connector.register_kv_caches(kv_caches) + + # Verify get_reg_descs was called with caches_data + assert mock_wrapper_instance.get_reg_descs.called + caches_data, _ = mock_wrapper_instance.get_reg_descs.call_args[0] + assert len(caches_data) == 4 + + for i, cache_entry in enumerate(caches_data): + base_addr, size, _tp_rank, _ = cache_entry + assert size == expected_tensor_size, \ + f"Entry {i}: Expected tensor size {expected_tensor_size}, " \ + f"got {size}" + assert 
base_addr == expected_base_addrs[i], \ + f"Entry {i}: Expected base address {expected_base_addrs[i]}, " \ + f"got {base_addr}" + + # Verify get_xfer_descs was called with blocks_data + assert mock_wrapper_instance.get_xfer_descs.called + blocks_data, _ = mock_wrapper_instance.get_xfer_descs.call_args[0] + + # Validate blocks_data structure and size + expected_blocks_count = 8 + assert len(blocks_data) == expected_blocks_count, \ + f"Expected {expected_blocks_count} blocks, " \ + f"got {len(blocks_data)}" + + expected_block_len = expected_tensor_size // 2 + for i, block_entry in enumerate(blocks_data): + block_start_addr, block_len, tp_rank = block_entry + assert block_len == expected_block_len, \ + f"Block entry {i}: Expected block len {expected_block_len}, " \ + f"got {block_len}" diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 07fcdecac6276..5601ee74be110 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -131,8 +131,8 @@ class KVConnectorBase_V1(ABC): Initialize with the KV caches. Useful for pre-registering the KV Caches in the KVConnector (e.g. for NIXL). 
- Args: kv_caches: - dictionary of layer names, kv cache + Args: + kv_caches: dictionary of layer names, kv cache """ return diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 4f51229ffbd26..6608d2a4a9e09 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -686,9 +686,6 @@ class NixlConnectorWorker: def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """Register the KV Cache data in nixl.""" - _, first_kv_cache = next(iter(kv_caches.items())) - kv_elem_size = first_kv_cache.element_size() - if self.use_host_buffer: self.initialize_host_xfer_buffer(kv_caches=kv_caches) assert len(self.host_xfer_buffers) == len(kv_caches), ( @@ -701,66 +698,16 @@ class NixlConnectorWorker: "host_xfer_buffer should not be initialized when " f"kv_buffer_device is {self.kv_buffer_device}") - # TODO(tms): Find a more robust way to detect and handle MLA - # NOTE (NickLucche) To move blocks efficiently with NIXL, the expected - # KV memory layout is HND, as opposed to the default NHD. Note that it - # will only affects the strides. For MLA instead, we make require no - # such thing and resort to the standard layout. - use_mla = len(first_kv_cache.shape) == 3 - if self.device_type == "tpu": - assert not use_mla, f"{self.kv_buffer_device} does not support MLA." 
- assert self._use_pallas_v1, f"attn backend: {self.backend_name}" - # tpu (v1) kv shape per layer: - # (num_blocks, block_size, num_kv_heads * 2, head_size) - self.num_blocks = first_kv_cache.shape[0] - block_rank = 3 # [block_size, kv_heads, head_dim] - block_shape = first_kv_cache.shape[-block_rank:] - block_size, n_kv_heads_x_2, head_dim = block_shape - self.slot_size_bytes = kv_elem_size * n_kv_heads_x_2 * head_dim - elif self.device_type == "cuda": - assert use_mla == self.use_mla - # TODO (NickLucche) not compatible with hybrid allocator. - # Enforce check once it goes live, as a single kv layout - # is expected for xfers. - if use_mla: - # MLA case. - self.num_blocks = first_kv_cache.shape[0] - block_rank = 2 # [block_size, latent_dim] - block_shape = first_kv_cache.shape[-block_rank:] - block_size, kv_latent_dim = block_shape - self.slot_size_bytes = kv_elem_size * kv_latent_dim - else: - # [2 (k and v), num_blocks, ...] - if self._use_flashinfer: - # FlashInfer swaps 2<->num_blocks dimensions. - self.num_blocks = first_kv_cache.shape[0] - block_rank = 4 # [2, block_size, kv_heads, head_dim] - else: - self.num_blocks = first_kv_cache.shape[1] - block_rank = 3 # [block_size, kv_heads, head_dim] - block_shape = first_kv_cache.shape[-block_rank:] - block_size, n_kv_heads, head_dim = block_shape[-3:] - # head size in bytes. - self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim - assert block_size == self.block_size - else: - raise RuntimeError( - f"{self.device_type} ({self.backend_name}) is not supported.") - - # TODO(tms): self.block_len needs to be per-layer for sliding window, - # hybrid attn, etc - # block size in bytes - self.block_len = kv_elem_size * math.prod(block_shape) logger.info( "Registering KV_Caches. 
use_mla: %s, kv_buffer_device: %s, " - "use_host_buffer: %s, num_blocks: %s, block_shape: %s, " - "per_layer_kv_cache_shape: %s", use_mla, self.kv_buffer_device, - self.use_host_buffer, self.num_blocks, block_shape, - first_kv_cache.shape) - self.dst_num_blocks[self.engine_id] = self.num_blocks - self.device_kv_caches = kv_caches - kv_caches_base_addr = [] + "use_host_buffer: %s", self.use_mla, self.kv_buffer_device, + self.use_host_buffer) + caches_data = [] + # With hybrid allocator, layers can share a kv cache tensor + seen_base_addresses = [] + xfer_buffers = (self.host_xfer_buffers + if self.use_host_buffer else kv_caches) # Note(tms): I modified this from the original region setup code. # K and V are now in different regions. Advantage is that we can @@ -770,42 +717,35 @@ class NixlConnectorWorker: # (roughly 8KB vs 5KB). # Conversely for FlashInfer, K and V are transferred in the same tensor # to better exploit the memory layout (ie num_blocks is the first dim). - for cache_or_caches in xfer_buffers.values(): - # Normalize to always be a list of caches - cache_list = [cache_or_caches] if use_mla \ - or self._use_pallas_v1 or self._use_flashinfer \ - else cache_or_caches + split_k_and_v = not (self.use_mla or self._use_pallas_v1 + or self._use_flashinfer) + tensor_size_bytes = None + for layer_name, cache_or_caches in xfer_buffers.items(): + cache_list = cache_or_caches if split_k_and_v else [ + cache_or_caches + ] + for cache in cache_list: base_addr = cache.data_ptr() - region_len = self.num_blocks * self.block_len - # NOTE: use tp_rank for device_id since multi-node TP - # is rarely used. 
- caches_data.append((base_addr, region_len, self.tp_rank, "")) - kv_caches_base_addr.append(base_addr) - self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr + if base_addr in seen_base_addresses: + continue + + seen_base_addresses.append(base_addr) + curr_tensor_size_bytes = cache.numel() * cache.element_size() + + if tensor_size_bytes is None: + tensor_size_bytes = curr_tensor_size_bytes + self.num_blocks = cache.shape[0] + + assert tensor_size_bytes == curr_tensor_size_bytes, \ + "All kv cache tensors must have the same size" + caches_data.append( + (base_addr, tensor_size_bytes, self.tp_rank, "")) + + self.kv_caches_base_addr[self.engine_id] = seen_base_addresses self.num_regions = len(caches_data) self.num_layers = len(xfer_buffers.keys()) - # TODO(mgoin): remove this once we have hybrid memory allocator - # Optimization for models with local attention (Llama 4) - if self.vllm_config.model_config.hf_config.model_type == "llama4": - from transformers import Llama4TextConfig - assert isinstance(self.vllm_config.model_config.hf_text_config, - Llama4TextConfig) - llama4_config = self.vllm_config.model_config.hf_text_config - no_rope_layers = llama4_config.no_rope_layers - chunk_size = llama4_config.attention_chunk_size - chunk_block_size = math.ceil(chunk_size / self.block_size) - for layer_idx in range(self.num_layers): - # no_rope_layers[layer_idx] == 0 means NoPE (global) - # Any other value means RoPE (local chunked) - is_local_attention = no_rope_layers[layer_idx] != 0 - block_window = chunk_block_size if is_local_attention else None - self.block_window_per_layer.append(block_window) - logger.debug("Llama 4 block window per layer mapping: %s", - self.block_window_per_layer) - assert len(self.block_window_per_layer) == self.num_layers - descs = self.nixl_wrapper.get_reg_descs(caches_data, self.nixl_memory_type) logger.debug("Registering descs: %s", caches_data) @@ -813,9 +753,20 @@ class NixlConnectorWorker: logger.debug("Done registering descs") 
self._registered_descs.append(descs) + assert tensor_size_bytes is not None + assert self.num_blocks != 0 + assert tensor_size_bytes % self.num_blocks == 0 + self.block_len = tensor_size_bytes // self.num_blocks + self.slot_size_bytes = self.block_len // self.block_size + if self._use_flashinfer: + assert self.slot_size_bytes % 2 == 0 + self.slot_size_bytes /= 2 + self.device_kv_caches = kv_caches + self.dst_num_blocks[self.engine_id] = self.num_blocks + # Register local/src descr for NIXL xfer. blocks_data = [] - for base_addr in self.kv_caches_base_addr[self.engine_id]: + for base_addr in seen_base_addresses: # NOTE With heter-TP, more blocks are prepared than what are # needed as self.num_blocks >= nixl_agent_meta.num_blocks. We # could create fewer, but then _get_block_descs_ids needs to @@ -836,6 +787,26 @@ class NixlConnectorWorker: self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist( "NIXL_INIT_AGENT", descs) + # TODO(mgoin): Hybrid memory allocator is currently diabled for + # models with local attention (Llama 4). Can remove this once enabled. 
+ if self.vllm_config.model_config.hf_config.model_type == "llama4": + from transformers import Llama4TextConfig + assert isinstance(self.vllm_config.model_config.hf_text_config, + Llama4TextConfig) + llama4_config = self.vllm_config.model_config.hf_text_config + no_rope_layers = llama4_config.no_rope_layers + chunk_size = llama4_config.attention_chunk_size + chunk_block_size = math.ceil(chunk_size / self.block_size) + for layer_idx in range(self.num_layers): + # no_rope_layers[layer_idx] == 0 means NoPE (global) + # Any other value means RoPE (local chunked) + is_local_attention = no_rope_layers[layer_idx] != 0 + block_window = chunk_block_size if is_local_attention else None + self.block_window_per_layer.append(block_window) + logger.debug("Llama 4 block window per layer mapping: %s", + self.block_window_per_layer) + assert len(self.block_window_per_layer) == self.num_layers + # After KV Caches registered, listen for new connections. metadata = NixlAgentMetadata( engine_id=self.engine_id, From 0ba1b54ac6958de7a02dfd39be7b59dd430be9ca Mon Sep 17 00:00:00 2001 From: Guillaume Calmettes Date: Fri, 22 Aug 2025 10:32:24 +0200 Subject: [PATCH 04/38] [gpt-oss] add input/output usage in responses api when harmony context is leveraged (#22667) Signed-off-by: Guillaume Calmettes --- vllm/entrypoints/context.py | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/context.py b/vllm/entrypoints/context.py index e817f07ef5947..f70e1fc207f86 100644 --- a/vllm/entrypoints/context.py +++ b/vllm/entrypoints/context.py @@ -3,6 +3,7 @@ import json import logging from abc import ABC, abstractmethod +from collections.abc import Sequence from typing import TYPE_CHECKING, Union from openai_harmony import Author, Message, Role, StreamState, TextContent @@ -67,15 +68,27 @@ class HarmonyContext(ConversationContext): self.parser = get_streamable_parser_for_assistant() self.num_init_messages = len(messages) - # TODO(woosuk): 
Implement the following fields. self.num_prompt_tokens = 0 - self.num_cached_tokens = 0 self.num_output_tokens = 0 + # TODO(woosuk): Implement the following fields. + self.num_cached_tokens = 0 self.num_reasoning_tokens = 0 + def _update_num_prompt_tokens(self, output: RequestOutput): + if output.prompt_token_ids and len(output.prompt_token_ids) > 0: + # NOTE: with built-in tools, there might be multiple rounds in + # the conversation, with the full conversation being resent + # as new prompt each time. Hence the sum. + self.num_prompt_tokens += len(output.prompt_token_ids) + + def _update_num_output_tokens(self, token_ids: Sequence[int]): + self.num_output_tokens += len(token_ids) + def append_output(self, output) -> None: if isinstance(output, RequestOutput): + self._update_num_prompt_tokens(output) output_token_ids = output.outputs[0].token_ids + self._update_num_output_tokens(output_token_ids) self.parser = get_streamable_parser_for_assistant() for token_id in output_token_ids: self.parser.process(token_id) @@ -158,6 +171,7 @@ class StreamingHarmonyContext(HarmonyContext): self.parser = get_streamable_parser_for_assistant() self.encoding = get_encoding() self.last_tok = None + self.first_tok_of_message = True @property def messages(self) -> list: @@ -165,8 +179,18 @@ class StreamingHarmonyContext(HarmonyContext): def append_output(self, output) -> None: if isinstance(output, RequestOutput): + # append_output is called for each output token in streaming case, + # so we only want to add the prompt tokens once for each message. 
+ if self.first_tok_of_message: + self._update_num_prompt_tokens(output) + # Reset self.first_tok_of_message if needed: + # if the current token is the last one of the current message + # (finished=True), then the next token processed will mark the + # beginning of a new message + self.first_tok_of_message = output.finished tok = output.outputs[0].token_ids[0] self.parser.process(tok) + self._update_num_output_tokens(output.outputs[0].token_ids) self.last_tok = tok else: # Handle the case of tool output in direct message format From 998720859caadd8a8d2a3e2af8b3e6e34a42e8da Mon Sep 17 00:00:00 2001 From: Benji Beck Date: Fri, 22 Aug 2025 01:43:29 -0700 Subject: [PATCH 05/38] Migrate MiniCPMOAudioInputs to TensorSchema (#21847) Signed-off-by: Benji Beck Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- vllm/model_executor/models/minicpmo.py | 52 +++++++++++++++++--------- 1 file changed, 35 insertions(+), 17 deletions(-) diff --git a/vllm/model_executor/models/minicpmo.py b/vllm/model_executor/models/minicpmo.py index 98ea366d3a6e4..225668d87facb 100644 --- a/vllm/model_executor/models/minicpmo.py +++ b/vllm/model_executor/models/minicpmo.py @@ -24,7 +24,7 @@ # limitations under the License. 
"""Inference-only MiniCPM-O model compatible with HuggingFace weights.""" from collections.abc import Iterable, Mapping, Sequence -from typing import Any, Callable, Literal, Optional, TypedDict, Union +from typing import Annotated, Any, Callable, Literal, Optional, Union import torch from torch import nn @@ -49,6 +49,7 @@ from vllm.multimodal.parse import (AudioItem, AudioProcessorItems, MultiModalDataParser) from vllm.multimodal.processing import (PromptReplacement, PromptUpdate, PromptUpdateDetails) +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .minicpmv import (_MAX_FRAMES_PER_VIDEO, MiniCPMV2_6, MiniCPMVDummyInputsBuilder, @@ -61,35 +62,52 @@ from .utils import (AutoWeightsLoader, cast_overflow_tensors, flatten_bn, CPU_DEVICE = torch.device("cpu") -class MiniCPMOAudioFeatureInputs(TypedDict): - type: Literal["audio_features"] - audio_features: Union[torch.Tensor, list[torch.Tensor]] +class MiniCPMOAudioFeatureInputs(TensorSchema): + """ + Dimensions: + - bns: Batch size * number of audios * number of slices + - bn: Batch size * number of audios + - c: Number of channels + - l: Length + - s: Number of slices + """ + type: Literal["audio_features"] = "audio_features" + + audio_features: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bns", "c", "l", dynamic_dims={"l"}), + ] """ - Shape: `(batch_size * num_audios * num_slices, num_channels, length)` Slice here means chunk. Audio that is too long will be split into slices, - which is the same as image. - Padding is used therefore `audio_features` is `torch.Tensor`. + which is the same as image. Padding is used therefore `audio_features` is + `torch.Tensor`. 
""" - audio_feature_lens: Union[torch.Tensor, list[torch.Tensor]] + audio_feature_lens: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "s"), + ] """ - Shape: `(batch_size * num_audios, num_slices)` - This should be feature length of each audio slice, which equals to `audio_features.shape[-1]` """ -class MiniCPMOAudioEmbeddingInputs(TypedDict): - type: Literal["audio_embeds"] - audio_embeds: Union[torch.Tensor, list[torch.Tensor]] +class MiniCPMOAudioEmbeddingInputs(TensorSchema): """ - Shape: `(batch_size * num_audios, num_slices, hidden_size)` - - `hidden_size` must match the hidden size of language model backbone. - instead of a batched tensor. + Dimensions: + - bn: Batch size * number of audios + - s: Number of slices + - h: Hidden size (must match language model backbone) + Length of each slice may vary, so pass it as a list. """ + type: Literal["audio_embeds"] = "audio_embeds" + + audio_embeds: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "s", "h", dynamic_dims={"s"}), + ] MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs, From 88016c372a5962eb98f4dfc71243ccd64433710e Mon Sep 17 00:00:00 2001 From: "Li, Jiang" Date: Fri, 22 Aug 2025 17:47:17 +0800 Subject: [PATCH 06/38] [Bugfix] Fix pooling models on CPU backend (#23392) Signed-off-by: jiang1.li --- vllm/utils/__init__.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index 1eefb32eaa90b..7079bfb8dbcee 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -1440,6 +1440,12 @@ def _patched_set_stream(stream: torch.cuda.Stream) -> None: torch.cuda.set_stream = _patched_set_stream +class _StreamPlaceholder: + + def __init__(self): + self.synchronize = lambda: None + + def current_stream() -> torch.cuda.Stream: """ replace `torch.cuda.current_stream()` with `vllm.utils.current_stream()`. 
@@ -1459,8 +1465,18 @@ def current_stream() -> torch.cuda.Stream: # On ROCm using the default 0 stream in combination with RCCL # is hurting performance. Therefore creating a dedicated stream # per process - _current_stream_tls.value = torch.cuda.Stream( - ) if current_platform.is_rocm() else torch.cuda.current_stream() + if current_platform.is_rocm(): + _current_stream_tls.value = torch.cuda.Stream() + elif current_platform.is_cpu(): + _current_stream_tls.value = _StreamPlaceholder() + else: + current_stream = current_platform.current_stream + if current_stream is not None: + _current_stream_tls.value = current_stream() + else: + raise ValueError( + "Fail to set current stream, current platform " + "may not support current_stream with torch API") return _current_stream_tls.value From 285178b3b824d70b46b351daa8f8942d23da264a Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Fri, 22 Aug 2025 17:56:51 +0800 Subject: [PATCH 07/38] [V0 Deprecation] Remove V0 LoRA test (#23418) Signed-off-by: Jee Jee Li --- tests/lora/conftest.py | 31 +------- tests/lora/test_add_lora.py | 11 +-- tests/lora/test_llama_tp.py | 5 +- tests/lora/test_lora_manager.py | 130 +++++++++++++++++--------------- tests/lora/test_mixtral.py | 1 - tests/lora/test_worker.py | 20 ++--- tests/lora/utils.py | 76 +++++++++++++++++++ 7 files changed, 158 insertions(+), 116 deletions(-) diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 909b73933139d..cba573b63c045 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -3,15 +3,13 @@ import tempfile from collections import OrderedDict -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import pytest import torch import torch.nn as nn from huggingface_hub import snapshot_download -import vllm -from vllm.config import LoRAConfig from vllm.distributed import (cleanup_dist_env_and_memory, init_distributed_environment, initialize_model_parallel) @@ -21,7 +19,6 @@ from vllm.model_executor.layers.linear import 
(ColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead -from vllm.model_executor.model_loader import get_model from vllm.model_executor.models.interfaces import SupportsLoRA from vllm.platforms import current_platform @@ -104,6 +101,7 @@ def dummy_model() -> nn.Module: ])) model.config = MagicMock() model.embedding_modules = {"lm_head": "lm_head"} + model.unpadded_vocab_size = 32000 return model @@ -137,6 +135,8 @@ def dummy_model_gate_up() -> nn.Module: ], } model.embedding_modules = {"lm_head": "lm_head"} + model.unpadded_vocab_size = 32000 + return model @@ -221,29 +221,6 @@ def phi2_lora_files(): return snapshot_download(repo_id="isotr0py/phi-2-test-sql-lora") -@pytest.fixture -def llama_2_7b_engine_extra_embeddings(): - cleanup_dist_env_and_memory(shutdown_ray=True) - get_model_old = get_model - - def get_model_patched(**kwargs): - kwargs["vllm_config"].lora_config = LoRAConfig(max_loras=4, - max_lora_rank=8) - return get_model_old(**kwargs) - - with patch("vllm.worker.model_runner.get_model", get_model_patched): - engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False) - yield engine.llm_engine - del engine - cleanup_dist_env_and_memory(shutdown_ray=True) - - -@pytest.fixture -def llama_2_7b_model_extra_embeddings(llama_2_7b_engine_extra_embeddings): - yield (llama_2_7b_engine_extra_embeddings.model_executor.driver_worker. 
- model_runner.model) - - @pytest.fixture def reset_default_device(): """ diff --git a/tests/lora/test_add_lora.py b/tests/lora/test_add_lora.py index d7b019509fa3e..44755c603f281 100644 --- a/tests/lora/test_add_lora.py +++ b/tests/lora/test_add_lora.py @@ -5,7 +5,6 @@ import time import pytest -import vllm.envs as env from vllm.engine.arg_utils import AsyncEngineArgs from vllm.entrypoints.openai.api_server import ( build_async_engine_client_from_engine_args) @@ -98,12 +97,10 @@ async def test_add_lora(chatglm3_lora_files): # Run with warmup add_lora_tasks = [llm.add_lora(lr) for lr in warmup_run_requests] add_lora_results = await asyncio.gather(*add_lora_tasks) - if env.VLLM_USE_V1: - # Test that all all_lora calls are successful. - assert all(add_lora_results) - else: - # No way to check V0 engine results as the calls just return None. - pass + + # Test that all all_lora calls are successful. + assert all(add_lora_results) + time_with_add_lora = await requests_processing_time( llm, warmup_run_requests) diff --git a/tests/lora/test_llama_tp.py b/tests/lora/test_llama_tp.py index b1ad1fdd06064..06196cc697cec 100644 --- a/tests/lora/test_llama_tp.py +++ b/tests/lora/test_llama_tp.py @@ -113,8 +113,7 @@ def test_llama_lora(sql_lora_files): enable_lora=True, # also test odd max_num_seqs max_num_seqs=13, - max_loras=4, - enable_chunked_prefill=True) + max_loras=4) generate_and_test(llm, sql_lora_files) @@ -128,7 +127,6 @@ def test_llama_lora_tp4(sql_lora_files): max_num_seqs=16, max_loras=4, tensor_parallel_size=4, - enable_chunked_prefill=True, ) generate_and_test(llm, sql_lora_files) @@ -144,7 +142,6 @@ def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files): max_loras=4, tensor_parallel_size=4, fully_sharded_loras=True, - enable_chunked_prefill=True, ) generate_and_test(llm, sql_lora_files) diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py index 8f8a27006cf67..c9ab32edc7f32 100644 --- a/tests/lora/test_lora_manager.py +++ 
b/tests/lora/test_lora_manager.py @@ -21,6 +21,8 @@ from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager, WorkerLoRAManager) from vllm.platforms import current_platform +from .utils import create_peft_lora + EMBEDDING_MODULES = { "embed_tokens": "input_embeddings", "lm_head": "output_embeddings", @@ -35,17 +37,6 @@ DEVICES = ([ DEFAULT_DTYPE = torch.get_default_dtype() -@pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch: pytest.MonkeyPatch): - """ - Some tests depend on V0 internals. Since both V0 and V1 use the same - LoRAModelManager it is okay to just test V0. - """ - with monkeypatch.context() as m: - m.setenv('VLLM_USE_V1', '0') - yield - - @pytest.mark.parametrize("device", DEVICES) def test_from_lora_tensors(sql_lora_files, device): tensors = load_file( @@ -326,7 +317,6 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device): max_loras=2, lora_dtype=DEFAULT_DTYPE), device=device) - assert all(x is None for x in manager.lora_index_to_id) # Add up to capacity @@ -430,32 +420,40 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device): @pytest.mark.parametrize("device", DEVICES) -def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings, - sql_lora_files, device): +def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device, + tmp_path): lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4, lora_dtype=DEFAULT_DTYPE) + + dummy_lora_files = f"{tmp_path}/lora_adapter" + os.makedirs(dummy_lora_files, exist_ok=True) + create_peft_lora( + dummy_model, + save_dir=dummy_lora_files, + target_modules=["layer1.dense1", "dense2"], + lora_dtype=DEFAULT_DTYPE, + ) worker_adapter_manager = LRUCacheWorkerLoRAManager( - 4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size - - lora_config.lora_extra_vocab_size, lora_config, device, - EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES) - worker_adapter_manager.create_lora_manager( - llama_2_7b_model_extra_embeddings) + 4, 
2, + dummy_model.unpadded_vocab_size - lora_config.lora_extra_vocab_size, + lora_config, device, EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES) + worker_adapter_manager.create_lora_manager(dummy_model) mapping = LoRAMapping([], []) worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, sql_lora_files), - LoRARequest("2", 2, sql_lora_files) + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("2", 2, dummy_lora_files) ], mapping) assert worker_adapter_manager.list_adapters() == {1, 2} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, sql_lora_files), - LoRARequest("3", 3, sql_lora_files), - LoRARequest("4", 4, sql_lora_files) + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("3", 3, dummy_lora_files), + LoRARequest("4", 4, dummy_lora_files) ], mapping) assert worker_adapter_manager.list_adapters() == {1, 2, 3, 4} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 @@ -464,9 +462,9 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings, assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4 worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, sql_lora_files), - LoRARequest("2", 2, sql_lora_files), - LoRARequest("5", 5, sql_lora_files) + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("2", 2, dummy_lora_files), + LoRARequest("5", 5, dummy_lora_files) ], mapping) assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 @@ -475,9 +473,9 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings, assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4 worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, sql_lora_files), - LoRARequest("1", 1, sql_lora_files), - LoRARequest("1", 1, 
sql_lora_files) + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("1", 1, dummy_lora_files) ], mapping) assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 @@ -486,9 +484,9 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings, assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4 worker_adapter_manager.set_active_adapters([ - LoRARequest("6", 6, sql_lora_files), - LoRARequest("7", 7, sql_lora_files), - LoRARequest("8", 8, sql_lora_files) + LoRARequest("6", 6, dummy_lora_files), + LoRARequest("7", 7, dummy_lora_files), + LoRARequest("8", 8, dummy_lora_files) ], mapping) assert worker_adapter_manager.list_adapters() == {1, 6, 7, 8} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 @@ -499,11 +497,11 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings, # Over capacity with pytest.raises(RuntimeError): worker_adapter_manager.set_active_adapters([ - LoRARequest("10", 10, sql_lora_files), - LoRARequest("11", 11, sql_lora_files), - LoRARequest("12", 12, sql_lora_files), - LoRARequest("13", 13, sql_lora_files), - LoRARequest("14", 14, sql_lora_files) + LoRARequest("10", 10, dummy_lora_files), + LoRARequest("11", 11, dummy_lora_files), + LoRARequest("12", 12, dummy_lora_files), + LoRARequest("13", 13, dummy_lora_files), + LoRARequest("14", 14, dummy_lora_files) ], mapping) assert worker_adapter_manager.device == device @@ -512,33 +510,41 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings, @pytest.mark.parametrize("device", DEVICES) -def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings, - sql_lora_files, device): +def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device, + tmp_path): # Should remove every LoRA not specified in the request. 
lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4, lora_dtype=DEFAULT_DTYPE) worker_adapter_manager = WorkerLoRAManager( - 4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size - + 4, 2, dummy_model_gate_up.unpadded_vocab_size - lora_config.lora_extra_vocab_size, lora_config, device, EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES) - worker_adapter_manager.create_lora_manager( - llama_2_7b_model_extra_embeddings) + worker_adapter_manager.create_lora_manager(dummy_model_gate_up) + + dummy_lora_files = f"{tmp_path}/lora_adapter" + os.makedirs(dummy_lora_files, exist_ok=True) + create_peft_lora( + dummy_model_gate_up, + save_dir=dummy_lora_files, + target_modules=["layer1.dense1", "dense2"], + lora_dtype=DEFAULT_DTYPE, + ) mapping = LoRAMapping([], []) worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, sql_lora_files), - LoRARequest("2", 2, sql_lora_files) + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("2", 2, dummy_lora_files) ], mapping) assert worker_adapter_manager.list_adapters() == {1, 2} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, sql_lora_files), - LoRARequest("3", 3, sql_lora_files), - LoRARequest("4", 4, sql_lora_files) + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("3", 3, dummy_lora_files), + LoRARequest("4", 4, dummy_lora_files) ], mapping) assert worker_adapter_manager.list_adapters() == {1, 3, 4} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 @@ -546,9 +552,9 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings, assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 4 worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, sql_lora_files), - LoRARequest("2", 2, sql_lora_files), - LoRARequest("5", 5, sql_lora_files) + LoRARequest("1", 1, dummy_lora_files), + 
LoRARequest("2", 2, dummy_lora_files), + LoRARequest("5", 5, dummy_lora_files) ], mapping) assert worker_adapter_manager.list_adapters() == {1, 2, 5} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 @@ -556,9 +562,9 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings, assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5 worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, sql_lora_files), - LoRARequest("1", 1, sql_lora_files), - LoRARequest("1", 1, sql_lora_files) + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("1", 1, dummy_lora_files) ], mapping) assert worker_adapter_manager.list_adapters() == {1} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 @@ -566,9 +572,9 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings, assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] is None worker_adapter_manager.set_active_adapters([ - LoRARequest("6", 6, sql_lora_files), - LoRARequest("7", 7, sql_lora_files), - LoRARequest("8", 8, sql_lora_files) + LoRARequest("6", 6, dummy_lora_files), + LoRARequest("7", 7, dummy_lora_files), + LoRARequest("8", 8, dummy_lora_files) ], mapping) assert worker_adapter_manager.list_adapters() == {6, 7, 8} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 8 @@ -578,11 +584,11 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings, # Over capacity with pytest.raises(RuntimeError): worker_adapter_manager.set_active_adapters([ - LoRARequest("10", 10, sql_lora_files), - LoRARequest("11", 11, sql_lora_files), - LoRARequest("12", 12, sql_lora_files), - LoRARequest("13", 13, sql_lora_files), - LoRARequest("14", 14, sql_lora_files) + LoRARequest("10", 10, dummy_lora_files), + LoRARequest("11", 11, dummy_lora_files), + LoRARequest("12", 12, dummy_lora_files), + LoRARequest("13", 13, dummy_lora_files), + LoRARequest("14", 14, dummy_lora_files) ], 
mapping) assert worker_adapter_manager.device == device diff --git a/tests/lora/test_mixtral.py b/tests/lora/test_mixtral.py index 0ea07793311cb..03e5d8d5d6728 100644 --- a/tests/lora/test_mixtral.py +++ b/tests/lora/test_mixtral.py @@ -50,7 +50,6 @@ def test_mixtral_lora(mixtral_lora_files, tp_size): max_loras=4, distributed_executor_backend="ray", tensor_parallel_size=tp_size, - enable_chunked_prefill=True, ) expected_lora_output = [ diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py index bd0aea67b9702..a836ff94ba3ed 100644 --- a/tests/lora/test_worker.py +++ b/tests/lora/test_worker.py @@ -4,17 +4,14 @@ import os import random import tempfile -from typing import Union from unittest.mock import patch -import vllm.envs as envs from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VllmConfig) from vllm.lora.models import LoRAMapping from vllm.lora.request import LoRARequest -from vllm.v1.worker.gpu_worker import Worker as V1Worker -from vllm.worker.worker import Worker +from vllm.v1.worker.gpu_worker import Worker NUM_LORAS = 16 @@ -22,18 +19,11 @@ NUM_LORAS = 16 @patch.dict(os.environ, {"RANK": "0"}) def test_worker_apply_lora(sql_lora_files): - def set_active_loras(worker: Union[Worker, V1Worker], - lora_requests: list[LoRARequest]): + def set_active_loras(worker: Worker, lora_requests: list[LoRARequest]): lora_mapping = LoRAMapping([], []) - if isinstance(worker, Worker): - # v0 case - worker.model_runner.set_active_loras(lora_requests, lora_mapping) - else: - # v1 case - worker.model_runner.lora_manager.set_active_adapters( - lora_requests, lora_mapping) - worker_cls = V1Worker if envs.VLLM_USE_V1 else Worker + worker.model_runner.lora_manager.set_active_adapters( + lora_requests, lora_mapping) vllm_config = VllmConfig( model_config=ModelConfig( @@ -62,7 +52,7 @@ def test_worker_apply_lora(sql_lora_files): max_cpu_loras=NUM_LORAS, max_loras=NUM_LORAS), ) - worker = worker_cls( 
+ worker = Worker( vllm_config=vllm_config, local_rank=0, rank=0, diff --git a/tests/lora/utils.py b/tests/lora/utils.py index cc1b0d81955bc..7cda90787b6f1 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -1,10 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import json +import os from dataclasses import dataclass from typing import Optional, Union import torch +from safetensors.torch import save_file from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights @@ -340,3 +343,76 @@ def generate_data_for_nslices( seq_len_tensor, indices, ) + + +def create_peft_lora( + model: torch.nn.Module, + save_dir: str, + target_modules: list[str], + rank: int = 8, + alpha: int = 16, + dropout: float = 0.1, + lora_dtype: torch.dtype = torch.float16, +) -> dict[str, torch.Tensor]: + lora_weights = {} + adapter_config = { + "peft_type": "LORA", + "auto_mapping": None, + "base_model_name_or_path": "dummy_model", + "revision": None, + "task_type": "CAUSAL_LM", + "inference_mode": False, + "r": rank, + "lora_alpha": alpha, + "lora_dropout": dropout, + "fan_in_fan_out": False, + "bias": "none", + "modules_to_save": None, + "init_lora_weights": True, + "layers_to_transform": None, + "layers_pattern": None, + "target_modules": target_modules, + "exclude_modules": None, + "use_rslora": False, + "use_dora": False, + "loftq_config": None, + } + + for module_name in target_modules: + + module = model + for attr in module_name.split("."): + module = getattr(module, attr) + + if hasattr(module, "input_size") and hasattr(module, "output_size"): + + in_features = module.input_size + out_features = module.output_size + + elif hasattr(module, "embedding_dim") and hasattr( + module, "num_embeddings"): + # ParallelLMHead + in_features = module.embedding_dim + out_features = module.num_embeddings + else: + raise ValueError( + f"Unable to determine dimensions for module {module_name}") + + lora_A = 
torch.randn(rank, in_features, dtype=lora_dtype) + + torch.nn.init.kaiming_uniform_(lora_A, a=5**0.5) + + lora_B = torch.zeros(out_features, rank, dtype=lora_dtype) + + # PEFT style + lora_weights[f"base_model.model.{module_name}.lora_A.weight"] = lora_A + lora_weights[f"base_model.model.{module_name}.lora_B.weight"] = lora_B + + config_path = os.path.join(save_dir, "adapter_config.json") + with open(config_path, "w", encoding="utf-8") as f: + json.dump(adapter_config, f, indent=2, ensure_ascii=False) + + weights_path = os.path.join(save_dir, "adapter_model.safetensors") + save_file(lora_weights, weights_path) + + return lora_weights From 808d2e9aa0f302bf9667b09b9dcf297f86927dac Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 22 Aug 2025 03:07:22 -0700 Subject: [PATCH 08/38] [Misc] Move M-RoPE init logic to _init_mrope_positions (#23422) Signed-off-by: Woosuk Kwon --- vllm/v1/worker/gpu_model_runner.py | 63 +++++++++++++++--------------- 1 file changed, 31 insertions(+), 32 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d520b71de3ff9..7160894b4acda 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -507,42 +507,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): output_token_ids=[], lora_request=new_req_data.lora_request, ) - self.requests[req_id] = req_state # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: - image_grid_thw = [] - video_grid_thw = [] - second_per_grid_ts = [] - audio_feature_lengths = [] - use_audio_in_video = False - for mm_item in req_state.mm_kwargs: - mm_input = mm_item.get_data() - if (t := mm_input.get("image_grid_thw")) is not None: - image_grid_thw.append(t.tolist()) - if (t := mm_input.get("video_grid_thw")) is not None: - video_grid_thw.append(t.tolist()) - if (t := mm_input.get("second_per_grid_ts")) is not None: - second_per_grid_ts.append(t) - if (t := - 
mm_input.get("audio_feature_lengths")) is not None: - audio_feature_lengths.append(t) - if mm_input.get("use_audio_in_video") is True: - use_audio_in_video = True - - hf_config = self.model_config.hf_config - - req_state.mrope_positions, req_state.mrope_position_delta = \ - MRotaryEmbedding.get_input_positions_tensor( - req_state.prompt_token_ids, - hf_config=hf_config, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - second_per_grid_ts=second_per_grid_ts, - audio_feature_lengths=audio_feature_lengths, - use_audio_in_video=use_audio_in_video, - ) + self._init_mrope_positions(req_state) reqs_to_add.append(req_state) @@ -639,6 +608,36 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Refresh batch metadata with any pending updates. self.input_batch.refresh_metadata() + def _init_mrope_positions(self, req_state: CachedRequestState): + image_grid_thw = [] + video_grid_thw = [] + second_per_grid_ts = [] + audio_feature_lengths = [] + use_audio_in_video = False + for mm_item in req_state.mm_kwargs: + mm_input = mm_item.get_data() + if (t := mm_input.get("image_grid_thw")) is not None: + image_grid_thw.append(t.tolist()) + if (t := mm_input.get("video_grid_thw")) is not None: + video_grid_thw.append(t.tolist()) + if (t := mm_input.get("second_per_grid_ts")) is not None: + second_per_grid_ts.append(t) + if (t := mm_input.get("audio_feature_lengths")) is not None: + audio_feature_lengths.append(t) + if mm_input.get("use_audio_in_video") is True: + use_audio_in_video = True + + req_state.mrope_positions, req_state.mrope_position_delta = \ + MRotaryEmbedding.get_input_positions_tensor( + req_state.prompt_token_ids, + hf_config=self.model_config.hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video, + ) + def _extract_mm_kwargs( self, scheduler_output: "SchedulerOutput", From 
281710ef9a2a795d57bce997d89a3ed69287464e Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Fri, 22 Aug 2025 08:10:16 -0400 Subject: [PATCH 09/38] [Attention] Allow V1 flash_attn to support cross-attention (#23297) Signed-off-by: Russell Bryant --- vllm/v1/attention/backends/flash_attn.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index eed3cba9a2ca7..eca83b6d2ee45 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -405,13 +405,6 @@ class FlashAttentionImpl(AttentionImpl): FlashAttentionBackend.validate_head_size(head_size) - if attn_type not in [ - AttentionType.DECODER, AttentionType.ENCODER_ONLY - ]: - raise NotImplementedError("Encoder/decoder cross-attention " - "is not implemented for " - "FlashAttentionImpl") - self.attn_type = attn_type self.vllm_flash_attn_version = get_flash_attn_version() if is_quantized_kv_cache(self.kv_cache_dtype) \ @@ -477,7 +470,7 @@ class FlashAttentionImpl(AttentionImpl): num_actual_tokens = attn_metadata.num_actual_tokens # Handle encoder attention differently - no KV cache needed - if attn_type in (AttentionType.ENCODER_ONLY, ): + if attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER): # For encoder attention, # we use direct Q, K, V tensors without caching return self._forward_encoder_attention(query[:num_actual_tokens], @@ -489,7 +482,11 @@ class FlashAttentionImpl(AttentionImpl): # For decoder and cross-attention, use KV cache as before key_cache, value_cache = kv_cache.unbind(0) - if self.kv_sharing_target_layer_name is None: + # key and value may be None in the case of cross attention. They are + # calculated once based on the output from the encoder and then cached + # in KV cache. 
+ if (self.kv_sharing_target_layer_name is None and key is not None + and value is not None): # Reshape the input keys and values and store them in the cache. # Skip this if sharing KV cache with an earlier attention layer. # NOTE(woosuk): Here, key and value are padded while slot_mapping is @@ -528,7 +525,7 @@ class FlashAttentionImpl(AttentionImpl): block_table = attn_metadata.block_table scheduler_metadata = attn_metadata.scheduler_metadata - descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1]) + descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads) flash_attn_varlen_func( q=query[:num_actual_tokens], From 695e7adcd22c25b859a6d4b3af99617aaf425708 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=A8=E6=9C=B1=20=C2=B7=20Kiki?= Date: Fri, 22 Aug 2025 21:08:53 +0800 Subject: [PATCH 10/38] [misc] Remove outdate comment about runai_model_streamer (#23421) Signed-off-by: carlory --- vllm/model_executor/model_loader/weight_utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 78b186265dd04..7053c5bc515cf 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -31,9 +31,7 @@ from vllm.utils import PlaceholderModule try: from runai_model_streamer import SafetensorsStreamer -except (ImportError, OSError): - # see https://github.com/run-ai/runai-model-streamer/issues/26 - # OSError will be raised on arm64 platform +except ImportError: runai_model_streamer = PlaceholderModule( "runai_model_streamer") # type: ignore[assignment] SafetensorsStreamer = runai_model_streamer.placeholder_attr( From a073be6d87c6480ecd725bd475cc4f30fd747aa4 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 22 Aug 2025 06:20:39 -0700 Subject: [PATCH 11/38] [Doc] Update the doc for log probs + prefix caching (#23399) Signed-off-by: Chen Zhang Co-authored-by: gemini-code-assist[bot] 
<176961590+gemini-code-assist[bot]@users.noreply.github.com> --- docs/usage/v1_guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/usage/v1_guide.md b/docs/usage/v1_guide.md index b89768913681e..7fc615d4c042f 100644 --- a/docs/usage/v1_guide.md +++ b/docs/usage/v1_guide.md @@ -166,7 +166,7 @@ Processed means the values after applying all processors, including temperature ##### Prompt Logprobs with Prefix Caching -Currently prompt logprobs are only supported when prefix caching is turned off via `--no-enable-prefix-caching`. In a future release, prompt logprobs will be compatible with prefix caching, but a recomputation will be triggered to recover the full prompt logprobs even upon a prefix cache hit. See details in [RFC #13414](gh-issue:13414). +Logprobs are not cached. For a request requiring prompt logprobs, the engine will ignore the prefix cache and recompute the prefill of full prompt to generate the logprobs. #### Deprecated Features From 325aa3dee922b344a26b9e74d9ae3c769828e70e Mon Sep 17 00:00:00 2001 From: Ning Xie Date: Fri, 22 Aug 2025 22:01:35 +0800 Subject: [PATCH 12/38] [Misc] local import code clean (#23420) Signed-off-by: Andy Xie --- vllm/v1/worker/gpu_worker.py | 1 - vllm/worker/worker.py | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index d61177d4245dd..f83a4f4faeb5e 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -292,7 +292,6 @@ class Worker(WorkerBase): allocator = CuMemAllocator.get_instance() context = allocator.use_memory_pool(tag="kv_cache") else: - from contextlib import nullcontext context = nullcontext() with context: self.model_runner.initialize_kv_cache(kv_cache_config) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 7a01e585ba6d0..fc24d95b80f2c 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -3,6 +3,7 @@ """A GPU worker class.""" import gc import os +from 
contextlib import nullcontext from typing import Dict, List, Optional, Set, Tuple, Type, Union import torch @@ -206,7 +207,6 @@ class Worker(LocalOrDistributedWorkerBase): "used for one instance per process.") context = allocator.use_memory_pool(tag="weights") else: - from contextlib import nullcontext context = nullcontext() with context: self.model_runner.load_model() @@ -330,7 +330,6 @@ class Worker(LocalOrDistributedWorkerBase): allocator = CuMemAllocator.get_instance() context = allocator.use_memory_pool(tag="kv_cache") else: - from contextlib import nullcontext context = nullcontext() with context: self._init_cache_engine() From ebe14621e353217ff16da329c2e76b80ca233b1b Mon Sep 17 00:00:00 2001 From: Naman Lalit Date: Fri, 22 Aug 2025 08:12:28 -0700 Subject: [PATCH 13/38] [Bug fix] Dynamically setting the backend variable for genai_perf_tests in the run-nightly-benchmark script (#23375) Signed-off-by: Naman Lalit --- .buildkite/nightly-benchmarks/scripts/run-nightly-benchmarks.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.buildkite/nightly-benchmarks/scripts/run-nightly-benchmarks.sh b/.buildkite/nightly-benchmarks/scripts/run-nightly-benchmarks.sh index 06d7b5ed484da..a00de940cbbb8 100644 --- a/.buildkite/nightly-benchmarks/scripts/run-nightly-benchmarks.sh +++ b/.buildkite/nightly-benchmarks/scripts/run-nightly-benchmarks.sh @@ -382,7 +382,7 @@ run_genai_perf_tests() { client_command="genai-perf profile \ -m $model \ --service-kind openai \ - --backend vllm \ + --backend "$backend" \ --endpoint-type chat \ --streaming \ --url localhost:$port \ From 51a215300bb9df3b5730ef7dedeb46eb5f5a0138 Mon Sep 17 00:00:00 2001 From: Burkhard Ringlein Date: Fri, 22 Aug 2025 17:13:39 +0200 Subject: [PATCH 14/38] [Fix] Bump triton version in rocm-build requirements (#21630) Signed-off-by: Burkhard Ringlein --- requirements/rocm-build.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/rocm-build.txt 
b/requirements/rocm-build.txt index 94201543cd4f3..cbae9bbb8a9b3 100644 --- a/requirements/rocm-build.txt +++ b/requirements/rocm-build.txt @@ -6,7 +6,7 @@ torch==2.7.0 torchvision==0.22.0 torchaudio==2.7.0 -triton==3.2 +triton==3.3.0 cmake>=3.26.1,<4 packaging>=24.2 setuptools>=77.0.3,<80.0.0 From 613a23b57f02cad9138e69399bea2d2413bb6802 Mon Sep 17 00:00:00 2001 From: Martin Hickey Date: Fri, 22 Aug 2025 17:22:29 +0100 Subject: [PATCH 15/38] [Bugfix]: Installing dev environment due to pydantic incompatible version (#23353) Signed-off-by: Martin Hickey --- requirements/common.txt | 2 +- requirements/test.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements/common.txt b/requirements/common.txt index 365457436faa8..8acf634526ff1 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -13,7 +13,7 @@ protobuf # Required by LlamaTokenizer. fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint. 
aiohttp openai >= 1.99.1 # For Responses API with reasoning content -pydantic >= 2.10 +pydantic >= 2.11.7 prometheus_client >= 0.18.0 pillow # Required for image processing prometheus-fastapi-instrumentator >= 7.0.0 diff --git a/requirements/test.txt b/requirements/test.txt index 85b677c00b1d3..8b872752d875c 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -742,7 +742,7 @@ pycparser==2.22 # via cffi pycryptodomex==3.22.0 # via blobfile -pydantic==2.11.5 +pydantic==2.11.7 # via # -r requirements/test.in # albumentations From 88491c1b6bcac8fa6adfa22489c92419c5e89055 Mon Sep 17 00:00:00 2001 From: PapaGoose <56637198+PapaGoose@users.noreply.github.com> Date: Fri, 22 Aug 2025 19:39:19 +0300 Subject: [PATCH 16/38] [Speculators][Speculative Decoding] Fix Qwen 2 Eagle3 Support (#23337) --- vllm/model_executor/models/qwen2.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index b6a1d2db303c7..801741ecaf3b8 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -52,7 +52,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import is_interleaved -from .interfaces import SupportsLoRA, SupportsPP +from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, @@ -442,7 +442,7 @@ class Qwen2Model(nn.Module): return loaded_params -class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): +class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -488,6 +488,13 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): def get_input_embeddings(self, input_ids: torch.Tensor) -> 
torch.Tensor: return self.model.get_input_embeddings(input_ids) + def set_aux_hidden_state_layers(self, layers: tuple[int]) -> None: + self.model.aux_hidden_state_layers = layers + + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int]: + num_layers = len(self.model.layers) + return (2, num_layers // 2, num_layers - 3) + def forward( self, input_ids: torch.Tensor, From 424fb7a5d22c013ec0ef6244c62cd75ed076375a Mon Sep 17 00:00:00 2001 From: bppps <44322223+bppps@users.noreply.github.com> Date: Sat, 23 Aug 2025 00:56:46 +0800 Subject: [PATCH 17/38] =?UTF-8?q?[BugFix]=20Fix=20the=20issue=20where=20im?= =?UTF-8?q?age=20embeddings=20were=20incorrectly=20split.=E2=80=A6=20(#233?= =?UTF-8?q?66)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: bppps Co-authored-by: zouyu.zzx Co-authored-by: bppps --- vllm/model_executor/models/glm4_1v.py | 7 +- .../models/qwen2_5_omni_thinker.py | 80 ++++++++++++------- vllm/model_executor/models/qwen2_vl.py | 64 ++++++++++----- 3 files changed, 99 insertions(+), 52 deletions(-) diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index 08252c51310be..662728e6b1393 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -74,7 +74,8 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape from ..layers.activation import SiluAndMul from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) -from .qwen2_vl import _qwen2vl_field_config, apply_rotary_pos_emb_vision +from .qwen2_vl import (_create_qwen2vl_field_factory, + apply_rotary_pos_emb_vision) from .utils import (AutoWeightsLoader, WeightsMapper, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) @@ -1153,7 +1154,9 @@ class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]): hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, 
MultiModalFieldConfig]: - return _qwen2vl_field_config(hf_inputs) + return _create_qwen2vl_field_factory( + self.info.get_hf_config().vision_config.spatial_merge_size)( + hf_inputs) def _get_prompt_updates( self, diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py index 5aadebc33324c..664e3f2985a59 100644 --- a/vllm/model_executor/models/qwen2_5_omni_thinker.py +++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py @@ -25,7 +25,7 @@ from collections.abc import Iterable, Mapping, Sequence from copy import copy from functools import partial -from typing import Any, Optional, Union +from typing import Any, Callable, Optional, Union import torch import torch.nn as nn @@ -79,40 +79,57 @@ except (ImportError, ModuleNotFoundError): logger = init_logger(__name__) -def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str, torch.Tensor]): - audio_feature_lengths = hf_inputs.get("audio_feature_lengths", - torch.empty((0, ))) +def create_qwen2_5_omni_thinker_field_factory( + spatial_merge_size: int +) -> Callable[[Mapping[str, torch.Tensor]], Mapping[str, + MultiModalFieldConfig]]: - image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) - image_grid_sizes = image_grid_thw.prod(-1) + def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str, + torch.Tensor]): + audio_feature_lengths = hf_inputs.get("audio_feature_lengths", + torch.empty((0, ))) - video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) - video_grid_sizes = video_grid_thw.prod(-1) + image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) + image_pixel_grid_sizes = image_grid_thw.prod(-1) + image_embed_grid_sizes = (image_pixel_grid_sizes // + spatial_merge_size // spatial_merge_size) - num_videos = len(video_grid_sizes) + video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) + video_grid_sizes = video_grid_thw.prod(-1) + video_embed_grid_sizes = (video_grid_sizes // 
spatial_merge_size // + spatial_merge_size) - return dict( - input_audio_features=MultiModalFieldConfig.flat_from_sizes( - "audio", audio_feature_lengths, dim=1), - feature_attention_mask=MultiModalFieldConfig.batched("audio"), - audio_feature_lengths=MultiModalFieldConfig.batched("audio"), - pixel_values=MultiModalFieldConfig.flat_from_sizes( - "image", image_grid_sizes), - image_embeds=MultiModalFieldConfig.flat_from_sizes( - "image", image_grid_sizes), - image_grid_thw=MultiModalFieldConfig.batched("image"), - pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( - "video", video_grid_sizes), - video_embeds=MultiModalFieldConfig.flat_from_sizes( - "video", video_grid_sizes), - video_grid_thw=MultiModalFieldConfig.batched("video"), - second_per_grid_ts=MultiModalFieldConfig.batched("video"), - use_audio_in_video=MultiModalFieldConfig.shared("video", num_videos), - ) + num_videos = len(video_grid_sizes) + + return dict( + input_audio_features=MultiModalFieldConfig.flat_from_sizes( + "audio", audio_feature_lengths, dim=1), + feature_attention_mask=MultiModalFieldConfig.batched("audio"), + audio_feature_lengths=MultiModalFieldConfig.batched("audio"), + pixel_values=MultiModalFieldConfig.flat_from_sizes( + "image", image_pixel_grid_sizes), + image_embeds=MultiModalFieldConfig.flat_from_sizes( + "image", image_embed_grid_sizes), + image_grid_thw=MultiModalFieldConfig.batched("image"), + pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( + "video", video_grid_sizes), + video_embeds=MultiModalFieldConfig.flat_from_sizes( + "video", video_embed_grid_sizes), + video_grid_thw=MultiModalFieldConfig.batched("video"), + second_per_grid_ts=MultiModalFieldConfig.batched("video"), + use_audio_in_video=MultiModalFieldConfig.shared( + "video", num_videos), + ) + + return _qwen2_5_omni_thinker_field_config class Qwen2_5OmniThinkerMultiModalDataParser(Qwen2VLMultiModalDataParser): + def __init__(self, spatial_merge_size: int, *args, **kwargs): + self._spatial_merge_size 
= spatial_merge_size + super().__init__(self._spatial_merge_size, *args, **kwargs) + def _parse_audio_data( self, data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], @@ -124,7 +141,8 @@ class Qwen2_5OmniThinkerMultiModalDataParser(Qwen2VLMultiModalDataParser): required_fields={ "input_audio_features", "audio_feature_lengths" }, - fields_factory=_qwen2_5_omni_thinker_field_config, + fields_factory=create_qwen2_5_omni_thinker_field_factory( + self._spatial_merge_size), ) return super()._parse_audio_data(data) @@ -214,6 +232,8 @@ class Qwen2_5OmniThinkerMultiModalProcessor( def _get_data_parser(self) -> MultiModalDataParser: feature_extractor = self.info.get_feature_extractor() return Qwen2_5OmniThinkerMultiModalDataParser( + spatial_merge_size=self.info.get_hf_config( + ).vision_config.spatial_merge_size, target_sr=feature_extractor.sampling_rate) def _call_hf_processor( @@ -265,7 +285,9 @@ class Qwen2_5OmniThinkerMultiModalProcessor( hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: - return _qwen2_5_omni_thinker_field_config(hf_inputs) + return create_qwen2_5_omni_thinker_field_factory( + self.info.get_hf_config().vision_config.spatial_merge_size)( + hf_inputs) def _maybe_apply_prompt_updates( self, diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 2315fe2ab92b5..ae7a8d8d7a5b9 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -699,29 +699,46 @@ class Qwen2VisionTransformer(nn.Module): return loaded_params -def _qwen2vl_field_config(hf_inputs: Mapping[str, torch.Tensor]): - image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) - image_grid_sizes = image_grid_thw.prod(-1) +def _create_qwen2vl_field_factory( + spatial_merge_size: int +) -> Callable[ + [Mapping[str, torch.Tensor]], + Mapping[str, MultiModalFieldConfig], +]: - video_grid_thw = hf_inputs.get("video_grid_thw", 
torch.empty((0, 3))) - video_grid_sizes = video_grid_thw.prod(-1) + def _qwen2vl_field_config(hf_inputs: Mapping[str, torch.Tensor]): + image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) + image_pixel_grid_sizes = image_grid_thw.prod(-1) + image_embed_grid_sizes = (image_pixel_grid_sizes // + spatial_merge_size // spatial_merge_size) - return dict( - pixel_values=MultiModalFieldConfig.flat_from_sizes( - "image", image_grid_sizes), - image_embeds=MultiModalFieldConfig.flat_from_sizes( - "image", image_grid_sizes), - image_grid_thw=MultiModalFieldConfig.batched("image"), - pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( - "video", video_grid_sizes), - video_embeds=MultiModalFieldConfig.flat_from_sizes( - "video", video_grid_sizes), - video_grid_thw=MultiModalFieldConfig.batched("video"), - ) + video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) + video_grid_sizes = video_grid_thw.prod(-1) + video_embed_grid_sizes = (video_grid_sizes // spatial_merge_size // + spatial_merge_size) + + return dict( + pixel_values=MultiModalFieldConfig.flat_from_sizes( + "image", image_pixel_grid_sizes), + image_embeds=MultiModalFieldConfig.flat_from_sizes( + "image", image_embed_grid_sizes), + image_grid_thw=MultiModalFieldConfig.batched("image"), + pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( + "video", video_grid_sizes), + video_embeds=MultiModalFieldConfig.flat_from_sizes( + "video", video_embed_grid_sizes), + video_grid_thw=MultiModalFieldConfig.batched("video"), + ) + + return _qwen2vl_field_config class Qwen2VLMultiModalDataParser(MultiModalDataParser): + def __init__(self, spatial_merge_size: int, *args, **kwargs): + self._spatial_merge_size = spatial_merge_size + super().__init__(*args, **kwargs) + def _parse_image_data( self, data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], @@ -731,7 +748,8 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser): data, modality="image", 
required_fields={"image_embeds", "image_grid_thw"}, - fields_factory=_qwen2vl_field_config, + fields_factory=_create_qwen2vl_field_factory( + self._spatial_merge_size), ) return super()._parse_image_data(data) @@ -745,7 +763,8 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser): data, modality="video", required_fields={"video_embeds", "video_grid_thw"}, - fields_factory=_qwen2vl_field_config, + fields_factory=_create_qwen2vl_field_factory( + self._spatial_merge_size), ) return super()._parse_video_data(data) @@ -967,7 +986,8 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo] ): def _get_data_parser(self) -> MultiModalDataParser: - return Qwen2VLMultiModalDataParser() + return Qwen2VLMultiModalDataParser( + self.info.get_hf_config().vision_config.spatial_merge_size) def _get_prompt_updates( self, @@ -1010,7 +1030,9 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo] hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: - return _qwen2vl_field_config(hf_inputs) + return _create_qwen2vl_field_factory( + self.info.get_hf_config().vision_config.spatial_merge_size)( + hf_inputs) @MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor, From 341923b9820ea1dc437445e2e81644e6ba47e5b6 Mon Sep 17 00:00:00 2001 From: Aziz Date: Fri, 22 Aug 2025 19:20:59 +0200 Subject: [PATCH 18/38] fix(tests): Ensure reliable CUDA cache clearing in MoE test (#23416) Signed-off-by: AzizCode92 Signed-off-by: Michael Goin Co-authored-by: Michael Goin Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- tests/kernels/moe/test_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 1951eb0c61802..0ea9667914fd5 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -429,11 +429,11 @@ def 
test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[..., 0:-128], requires_grad=False) - torch.cuda.empty_cache() vllm_moe.experts.w2_weight = Parameter(F.pad( vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[..., 0:-128], requires_grad=False) + torch.cuda.synchronize() torch.cuda.empty_cache() # Run forward passes for both MoE blocks From b6d7d34fc62947eadf9adcfbc0264da388cb830c Mon Sep 17 00:00:00 2001 From: Yong Hoon Shin <48474650+sarckk@users.noreply.github.com> Date: Fri, 22 Aug 2025 10:31:24 -0700 Subject: [PATCH 19/38] Add unit tests for batched guided and non-guided requests (#23389) Signed-off-by: Yong Hoon Shin --- .../llm/test_struct_output_generate.py | 82 +++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index 58b6297762d3c..572af0175d114 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -11,9 +11,11 @@ from typing import TYPE_CHECKING, Any import jsonschema import pytest import regex as re +import torch from pydantic import BaseModel from tests.reasoning.utils import run_reasoning_extraction +from vllm.distributed import cleanup_dist_env_and_memory from vllm.entrypoints.llm import LLM from vllm.outputs import RequestOutput from vllm.platforms import current_platform @@ -727,3 +729,83 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch): assert "a4" not in generated assert "a5" not in generated assert "a6" not in generated + + +@pytest.mark.parametrize("guided_decoding_backend", + ["guidance", "xgrammar", "outlines"]) +def test_structured_output_batched_with_non_guided_requests( + monkeypatch: pytest.MonkeyPatch, + sample_json_schema: dict[str, Any], + guided_decoding_backend: str, +): + monkeypatch.setenv("VLLM_USE_V1", "1") + + # Don't 
use eager execution on TPUs because we want to test for no + # recompilation at runtime + enforce_eager = bool(not current_platform.is_tpu()) + + llm = LLM( + model="meta-llama/Meta-Llama-3.1-8B-Instruct", + enforce_eager=enforce_eager, + max_model_len=1024, + guided_decoding_backend=guided_decoding_backend, + guided_decoding_disable_any_whitespace=(guided_decoding_backend + in {"xgrammar", "guidance"}), + ) + + guided_prompt = ( + "Give an example JSON for an employee profile that fits this " + "schema. Make the response as short as possible. Schema: " + f"{sample_json_schema}") + + non_guided_prompt = "The diameter of the Earth in kilometers is " + + prompts = [guided_prompt, non_guided_prompt] + sampling_params = [ + SamplingParams( + temperature=1.0, + max_tokens=400, + guided_decoding=GuidedDecodingParams(json=sample_json_schema)), + # No max tokens, temp=0 to assert on contents + SamplingParams( + seed=42, + temperature=0, + top_p=1.0, + ), + ] + + outputs = llm.generate(prompts=prompts, + sampling_params=sampling_params, + use_tqdm=True) + + assert outputs is not None + + # Free memory as soon as possible as failed assertions + # will short circuit and not free up memory + del llm + torch.cuda.empty_cache() + cleanup_dist_env_and_memory() + + for index, output in enumerate(outputs): + assert output is not None + assert isinstance(output, RequestOutput) + prompt = output.prompt + + generated_text = output.outputs[0].text + assert generated_text is not None + print(f"Prompt:\n{prompt!r}\nGenerated text:\n{generated_text!r}") + + if index == 0: + # First prompt is guided, expect valid JSON + assert "\n" not in generated_text + output_json = json.loads(generated_text) + jsonschema.validate(instance=output_json, + schema=sample_json_schema) + else: + # Second prompt is not guided, expect valid output + # Cannot assert on exact output, but we can expect it to be factual + assert "12,742" in generated_text + + # non-guided requests should not return a valid JSON 
here + with pytest.raises(ValueError): + output_json = json.loads(generated_text) From 22cf679aadca99311cfb5a9f894039e464e366aa Mon Sep 17 00:00:00 2001 From: Didier Durand <2927957+didier-durand@users.noreply.github.com> Date: Fri, 22 Aug 2025 19:38:46 +0200 Subject: [PATCH 20/38] [Doc]: fix various typos in multiple files (#23179) Signed-off-by: Didier Durand --- vllm/beam_search.py | 2 +- vllm/compilation/backends.py | 2 +- vllm/engine/arg_utils.py | 6 +++--- vllm/engine/multiprocessing/client.py | 4 ++-- vllm/entrypoints/chat_utils.py | 2 +- vllm/utils/__init__.py | 4 ++-- vllm/v1/structured_output/__init__.py | 4 ++-- 7 files changed, 12 insertions(+), 12 deletions(-) diff --git a/vllm/beam_search.py b/vllm/beam_search.py index f3bc4218323d8..5a2e79e1b5c74 100644 --- a/vllm/beam_search.py +++ b/vllm/beam_search.py @@ -18,7 +18,7 @@ class BeamSearchSequence: The text field is optional and will only be filled when the sequence is about to be returned to the user. """ - # The tokens includes the prompt. + # The tokens include the prompt. tokens: list[int] logprobs: list[dict[int, Logprob]] lora_request: Optional[LoRARequest] = None diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 059e7a3b29761..56494dffc96b3 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -484,7 +484,7 @@ class VllmBackend: factors = [] # 0. factors come from the env, for example, The values of - # VLLM_PP_LAYER_PARTITION will affects the computation graph. + # VLLM_PP_LAYER_PARTITION will affect the computation graph. 
env_hash = envs.compute_hash() factors.append(env_hash) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 4700a93dd6da3..965264ee3097a 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -605,7 +605,7 @@ class EngineArgs: **guided_decoding_kwargs["disable_additional_properties"]) guided_decoding_group.add_argument( "--reasoning-parser", - # This choices is a special case because it's not static + # This choice is a special case because it's not static choices=list(ReasoningParserManager.reasoning_parsers), **guided_decoding_kwargs["reasoning_backend"]) @@ -1047,7 +1047,7 @@ class EngineArgs: # details from the config directly # no user input required / expected if isinstance(hf_config, SpeculatorsConfig): - # We create one since we dont create one + # We create one since we don't create one self.speculative_config = {} self.speculative_config[ "num_speculative_tokens"] = hf_config.num_lookahead_tokens @@ -1775,7 +1775,7 @@ class AsyncEngineArgs(EngineArgs): def add_cli_args(parser: FlexibleArgumentParser, async_args_only: bool = False) -> FlexibleArgumentParser: # Initialize plugin to update the parser, for example, The plugin may - # adding a new kind of quantization method to --quantization argument or + # add a new kind of quantization method to --quantization argument or # a new device to --device argument. load_general_plugins() if not async_args_only: diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index eca29af50055f..0bb11328b1db5 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -539,7 +539,7 @@ class MQLLMEngineClient(EngineClient): if request_id in self.output_queues: raise ValueError(f"Request {request_id} already exists") - # 1) Create output queue for this requests. + # 1) Create output queue for this request. 
queue: asyncio.Queue[Union[RequestOutput, BaseException]] = asyncio.Queue() self.output_queues[request_id] = queue @@ -651,7 +651,7 @@ class MQLLMEngineClient(EngineClient): # Uses the same I/O as generate requests request = RPCLoadAdapterRequest(lora_request) - # Create output queue for this requests. + # Create output queue for this request. queue: asyncio.Queue[Union[None, BaseException]] = asyncio.Queue() self.output_queues[request.request_id] = queue diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 87772a499f423..7b11a50642de9 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -1330,7 +1330,7 @@ def apply_mistral_chat_template( # mistral-common uses assert statements to stop processing of input # if input does not comply with the expected format. # We convert those assertion errors to ValueErrors so they can be - # are properly caught in the preprocessing_input step + # properly caught in the preprocessing_input step except (AssertionError, MistralCommonException) as e: raise ValueError(str(e)) from e diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index 7079bfb8dbcee..7c34a858c0a21 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -2482,7 +2482,7 @@ class PlaceholderModule(_PlaceholderBase): A placeholder object to use when a module does not exist. This enables more informative errors when trying to access attributes - of a module that does not exists. + of a module that does not exist. """ def __init__(self, name: str) -> None: @@ -3109,7 +3109,7 @@ class LazyLoader(types.ModuleType): """ LazyLoader module borrowed from Tensorflow https://github.com/tensorflow/tensorflow/blob/main/tensorflow/python/util/lazy_loader.py - with a addition of "module caching". + with an addition of "module caching". Lazily import a module, mainly to avoid pulling in large dependencies. 
Modules such as `xgrammar` might do additional side effects, so we diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 63604a335d9f0..3bafa61044abc 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -267,7 +267,7 @@ class StructuredOutputManager: assert request.structured_output_request is not None assert request.structured_output_request.grammar is not None # by default, we should always advance - # for cases that doesn't uses thinking mode. + # for cases that don't use thinking mode. if self.reasoner is not None: structured_req = request.structured_output_request @@ -276,7 +276,7 @@ class StructuredOutputManager: # Check if reasoning ends in *this* step if self.reasoner.is_reasoning_end(request.all_token_ids): - # Reasoning just ended, so we shouldn't advanced til + # Reasoning just ended, so we shouldn't advance til # next pass structured_req.reasoning_ended = True From 32d2b4064feea38802489b71e47703d1f901a17e Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Sat, 23 Aug 2025 01:46:34 +0800 Subject: [PATCH 21/38] [Model] Add Ovis2.5 PP support (#23405) Signed-off-by: Isotr0py --- tests/distributed/test_pipeline_parallel.py | 1 + .../multimodal/generation/test_common.py | 6 +- tests/models/registry.py | 4 +- vllm/model_executor/models/ovis2_5.py | 36 +-- vllm/model_executor/models/siglip2navit.py | 243 ++++++++++++------ 5 files changed, 185 insertions(+), 105 deletions(-) diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 12dd7c4222630..28150d7682378 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -233,6 +233,7 @@ MULTIMODAL_MODELS = { "openbmb/MiniCPM-Llama3-V-2_5": PPTestSettings.fast(), "allenai/Molmo-7B-D-0924": PPTestSettings.fast(), "AIDC-AI/Ovis2-1B": PPTestSettings.fast(), + "AIDC-AI/Ovis2.5-2B": PPTestSettings.fast(), 
"microsoft/Phi-3.5-vision-instruct": PPTestSettings.fast(), "mistralai/Pixtral-12B-2409": PPTestSettings.fast(load_format="dummy"), "Qwen/Qwen-VL-Chat": PPTestSettings.fast(), diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py index ea5de9d9f5c5b..96208f8eda628 100644 --- a/tests/models/multimodal/generation/test_common.py +++ b/tests/models/multimodal/generation/test_common.py @@ -11,7 +11,6 @@ from pathlib import PosixPath import pytest from transformers import (AutoModel, AutoModelForImageTextToText, AutoModelForTextToWaveform, AutoModelForVision2Seq) -from transformers.utils import is_flash_attn_2_available from vllm.platforms import current_platform from vllm.utils import identity @@ -637,10 +636,7 @@ VLM_TEST_SETTINGS = { dtype="half", num_logprobs=10, patch_hf_runner=model_utils.ovis2_5_patch_hf_runner, - marks=[pytest.mark.skipif( - not is_flash_attn_2_available(), - reason="HF model needs `flash_attn` installed" - )], + hf_model_kwargs={"revision": "refs/pr/5"}, ), "phi3v": VLMTestInfo( models=["microsoft/Phi-3.5-vision-instruct"], diff --git a/tests/models/registry.py b/tests/models/registry.py index 4035319b45ce4..25dbbd7fa9832 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -468,9 +468,7 @@ _MULTIMODAL_EXAMPLE_MODELS = { extras={"1.6-llama": "AIDC-AI/Ovis1.6-Llama3.2-3B", "1.6-gemma": "AIDC-AI/Ovis1.6-Gemma2-9B"}), # noqa: E501 "Ovis2_5": _HfExamplesInfo("AIDC-AI/Ovis2.5-2B", - trust_remote_code=True, - max_transformers_version="4.53", - transformers_version_reason="HF model is not compatible"), # noqa: E501 + trust_remote_code=True), "PaliGemmaForConditionalGeneration": _HfExamplesInfo("google/paligemma-3b-mix-224", # noqa: E501 extras={"v2": "google/paligemma2-3b-ft-docci-448"}), # noqa: E501 "Phi3VForCausalLM": _HfExamplesInfo("microsoft/Phi-3-vision-128k-instruct", diff --git a/vllm/model_executor/models/ovis2_5.py b/vllm/model_executor/models/ovis2_5.py index 
aa4ea3dd48f6e..58a14072443cb 100644 --- a/vllm/model_executor/models/ovis2_5.py +++ b/vllm/model_executor/models/ovis2_5.py @@ -30,7 +30,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor -from .interfaces import MultiModalEmbeddings, SupportsMultiModal +from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP IMAGE_TOKEN = "" VIDEO_TOKEN = "