From 0bec63fa317e1fbd62e19b0fc31c43c81bf89077 Mon Sep 17 00:00:00 2001 From: JackieWu Date: Wed, 3 Dec 2025 14:20:37 +0800 Subject: [PATCH 01/11] [BugFix] fix imgs_pos in hunyuan_vl (#29879) Co-authored-by: Isotr0py --- vllm/transformers_utils/processors/hunyuan_vl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/transformers_utils/processors/hunyuan_vl.py b/vllm/transformers_utils/processors/hunyuan_vl.py index 615a8bff85912..f32ce115c866d 100644 --- a/vllm/transformers_utils/processors/hunyuan_vl.py +++ b/vllm/transformers_utils/processors/hunyuan_vl.py @@ -123,7 +123,7 @@ class HunYuanVLProcessor(ProcessorMixin): attention_mask = input_ids.ne(self.pad_id) text_inputs["attention_mask"] = attention_mask - text_inputs["imgs_pos"] = [self.get_imgs_pos(input_ids)] + text_inputs["imgs_pos"] = [self.get_imgs_pos(e) for e in input_ids] # image_inputs["imgs"] = [[image_inputs["pixel_values"]]] return_tensors = kwargs.pop("return_tensors", None) From bbfb55c29e7febb91e90f261dd9adb4200ee3a09 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 3 Dec 2025 15:49:34 +0800 Subject: [PATCH 02/11] [Misc] Allow `fetch_*` utils to access local files by default (#29932) Signed-off-by: DarkLight1337 --- vllm/multimodal/utils.py | 38 ++++++++++++++++++++++++++++++-------- vllm/multimodal/video.py | 2 +- 2 files changed, 31 insertions(+), 9 deletions(-) diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index 1020554e2e073..1840220854858 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -67,8 +67,9 @@ class MediaConnector: to set num_frames for video, set `--media-io-kwargs '{"video":{"num_frames":40}}'` connection: HTTP connection client to download media contents. - allowed_local_media_path: A local directory to load media files - from. + allowed_local_media_path: A local directory to load media files from. + allowed_media_domains: If set, only media URLs that belong to this + domain can be used for multi-modal inputs. 
""" super().__init__() @@ -123,16 +124,16 @@ class MediaConnector: "Cannot load local files without `--allowed-local-media-path`." ) - filepath = Path(url2pathname(url_spec.path)) + filepath = Path(url2pathname(url_spec.netloc + url_spec.path)) if allowed_local_media_path not in filepath.resolve().parents: raise ValueError( f"The file path {filepath} must be a subpath " - f"of `--allowed-local-media-path` {allowed_local_media_path}." + f"of `--allowed-local-media-path {allowed_local_media_path}`." ) return media_io.load_file(filepath) - def _assert_url_in_allowed_media_domains(self, url_spec) -> None: + def _assert_url_in_allowed_media_domains(self, url_spec: ParseResult) -> None: if ( self.allowed_media_domains and url_spec.hostname not in self.allowed_media_domains @@ -489,9 +490,16 @@ def fetch_audio( Args: audio_url: URL of the audio file to fetch. audio_io_kwargs: Additional kwargs passed to handle audio IO. + + Warning: + This method has direct access to local files and is only intended + to be called by user code. Never call this from the online server! """ media_io_kwargs = None if not audio_io_kwargs else {"audio": audio_io_kwargs} - media_connector = MediaConnector(media_io_kwargs=media_io_kwargs) + media_connector = MediaConnector( + media_io_kwargs=media_io_kwargs, + allowed_local_media_path="/", + ) return media_connector.fetch_audio(audio_url) @@ -503,9 +511,16 @@ def fetch_image( Args: image_url: URL of the image file to fetch. image_io_kwargs: Additional kwargs passed to handle image IO. + + Warning: + This method has direct access to local files and is only intended + to be called by user code. Never call this from the online server! 
""" media_io_kwargs = None if not image_io_kwargs else {"image": image_io_kwargs} - media_connector = MediaConnector(media_io_kwargs=media_io_kwargs) + media_connector = MediaConnector( + media_io_kwargs=media_io_kwargs, + allowed_local_media_path="/", + ) return media_connector.fetch_image(image_url) @@ -517,7 +532,14 @@ def fetch_video( Args: video_url: URL of the video file to fetch. video_io_kwargs: Additional kwargs passed to handle video IO. + + Warning: + This method has direct access to local files and is only intended + to be called by user code. Never call this from the online server! """ media_io_kwargs = None if not video_io_kwargs else {"video": video_io_kwargs} - media_connector = MediaConnector(media_io_kwargs=media_io_kwargs) + media_connector = MediaConnector( + media_io_kwargs=media_io_kwargs, + allowed_local_media_path="/", + ) return media_connector.fetch_video(video_url) diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py index 763f90fde7b6d..abfc226a689c2 100644 --- a/vllm/multimodal/video.py +++ b/vllm/multimodal/video.py @@ -267,7 +267,7 @@ class OpenCVDynamicVideoBackend(OpenCVVideoBackend): return frames, metadata -class VideoMediaIO(MediaIO[npt.NDArray]): +class VideoMediaIO(MediaIO[tuple[npt.NDArray, dict[str, Any]]]): def __init__( self, image_io: ImageMediaIO, From 3a7751485b71ce5ef927e4aa03b28602cb90811c Mon Sep 17 00:00:00 2001 From: Andrew Xia Date: Tue, 2 Dec 2025 23:59:23 -0800 Subject: [PATCH 03/11] [responsesAPI] support input output messages for non harmony models (#29549) Signed-off-by: Andrew Xia Co-authored-by: Andrew Xia --- .../openai/test_response_api_simple.py | 18 +++++++++++++++ vllm/entrypoints/context.py | 22 +++++++++++++++++++ vllm/entrypoints/openai/protocol.py | 22 +++++++++++++++---- vllm/entrypoints/openai/serving_responses.py | 13 +++++------ 4 files changed, 64 insertions(+), 11 deletions(-) diff --git a/tests/entrypoints/openai/test_response_api_simple.py 
b/tests/entrypoints/openai/test_response_api_simple.py index 425b8199a0fd0..aee03199bc6f4 100644 --- a/tests/entrypoints/openai/test_response_api_simple.py +++ b/tests/entrypoints/openai/test_response_api_simple.py @@ -42,6 +42,24 @@ async def test_basic(client: OpenAI, model_name: str): assert response.status == "completed" +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_enable_response_messages(client: OpenAI, model_name: str): + response = await client.responses.create( + model=model_name, + input="Hello?", + extra_body={"enable_response_messages": True}, + ) + assert response.status == "completed" + assert response.input_messages[0]["type"] == "raw_message_tokens" + assert type(response.input_messages[0]["message"]) is str + assert len(response.input_messages[0]["message"]) > 10 + assert type(response.input_messages[0]["tokens"][0]) is int + assert type(response.output_messages[0]["message"]) is str + assert len(response.output_messages[0]["message"]) > 10 + assert type(response.output_messages[0]["tokens"][0]) is int + + @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_reasoning_item(client: OpenAI, model_name: str): diff --git a/vllm/entrypoints/context.py b/vllm/entrypoints/context.py index 1260f65dba59a..43783c92667af 100644 --- a/vllm/entrypoints/context.py +++ b/vllm/entrypoints/context.py @@ -23,6 +23,7 @@ from vllm.entrypoints.openai.parser.responses_parser import ( ) from vllm.entrypoints.openai.protocol import ( ResponseInputOutputItem, + ResponseRawMessageAndToken, ResponsesRequest, ) from vllm.entrypoints.responses_utils import construct_tool_dicts @@ -148,6 +149,8 @@ def _create_json_parse_error_messages( class SimpleContext(ConversationContext): + """This is a context that cannot handle MCP tool calls""" + def __init__(self): self.last_output = None self.num_prompt_tokens = 0 @@ -158,6 +161,9 @@ class SimpleContext(ConversationContext): # not implemented yet for 
SimpleContext self.all_turn_metrics = [] + self.input_messages: list[ResponseRawMessageAndToken] = [] + self.output_messages: list[ResponseRawMessageAndToken] = [] + def append_output(self, output) -> None: self.last_output = output if not isinstance(output, RequestOutput): @@ -166,6 +172,22 @@ class SimpleContext(ConversationContext): self.num_cached_tokens = output.num_cached_tokens or 0 self.num_output_tokens += len(output.outputs[0].token_ids or []) + if len(self.input_messages) == 0: + output_prompt = output.prompt or "" + output_prompt_token_ids = output.prompt_token_ids or [] + self.input_messages.append( + ResponseRawMessageAndToken( + message=output_prompt, + tokens=output_prompt_token_ids, + ) + ) + self.output_messages.append( + ResponseRawMessageAndToken( + message=output.outputs[0].text, + tokens=output.outputs[0].token_ids, + ) + ) + def append_tool_output(self, output) -> None: raise NotImplementedError("Should not be called.") diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 0f4b2b4d7aad0..2d34a6a0cd5ad 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1598,6 +1598,20 @@ def serialize_messages(msgs): return [serialize_message(msg) for msg in msgs] if msgs else None +class ResponseRawMessageAndToken(OpenAIBaseModel): + """Class to show the raw message. 
+ If message / tokens diverge, tokens is the source of truth""" + + message: str + tokens: list[int] + type: Literal["raw_message_tokens"] = "raw_message_tokens" + + +ResponseInputOutputMessage: TypeAlias = ( + list[ChatCompletionMessageParam] | list[ResponseRawMessageAndToken] +) + + class ResponsesResponse(OpenAIBaseModel): id: str = Field(default_factory=lambda: f"resp_{random_uuid()}") created_at: int = Field(default_factory=lambda: int(time.time())) @@ -1631,8 +1645,8 @@ class ResponsesResponse(OpenAIBaseModel): # These are populated when enable_response_messages is set to True # NOTE: custom serialization is needed # see serialize_input_messages and serialize_output_messages - input_messages: list[ChatCompletionMessageParam] | None = None - output_messages: list[ChatCompletionMessageParam] | None = None + input_messages: ResponseInputOutputMessage | None = None + output_messages: ResponseInputOutputMessage | None = None # --8<-- [end:responses-extra-params] # NOTE: openAI harmony doesn't serialize TextContent properly, @@ -1658,8 +1672,8 @@ class ResponsesResponse(OpenAIBaseModel): output: list[ResponseOutputItem], status: ResponseStatus, usage: ResponseUsage | None = None, - input_messages: list[ChatCompletionMessageParam] | None = None, - output_messages: list[ChatCompletionMessageParam] | None = None, + input_messages: ResponseInputOutputMessage | None = None, + output_messages: ResponseInputOutputMessage | None = None, ) -> "ResponsesResponse": incomplete_details: IncompleteDetails | None = None if status == "incomplete": diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py index 5ad86194ce1b2..3c9ae8e8c8087 100644 --- a/vllm/entrypoints/openai/serving_responses.py +++ b/vllm/entrypoints/openai/serving_responses.py @@ -86,6 +86,7 @@ from vllm.entrypoints.openai.protocol import ( ResponseCompletedEvent, ResponseCreatedEvent, ResponseInProgressEvent, + ResponseInputOutputMessage, 
ResponseReasoningPartAddedEvent, ResponseReasoningPartDoneEvent, ResponsesRequest, @@ -629,8 +630,8 @@ class OpenAIServingResponses(OpenAIServing): # "completed" is implemented as the "catch-all" for now. status: ResponseStatus = "completed" - input_messages = None - output_messages = None + input_messages: ResponseInputOutputMessage | None = None + output_messages: ResponseInputOutputMessage | None = None if self.use_harmony: assert isinstance(context, HarmonyContext) output = self._make_response_output_items_with_harmony(context) @@ -670,12 +671,10 @@ class OpenAIServingResponses(OpenAIServing): output = self._make_response_output_items(request, final_output, tokenizer) - # TODO: context for non-gptoss models doesn't use messages - # so we can't get them out yet if request.enable_response_messages: - raise NotImplementedError( - "enable_response_messages is currently only supported for gpt-oss" - ) + input_messages = context.input_messages + output_messages = context.output_messages + # Calculate usage. 
assert final_res.prompt_token_ids is not None num_tool_output_tokens = 0 From 69520bc695ff8fa7fda66ef7c1a16761824ad354 Mon Sep 17 00:00:00 2001 From: Yong Hoon Shin <48474650+sarckk@users.noreply.github.com> Date: Tue, 2 Dec 2025 23:01:48 -1000 Subject: [PATCH 04/11] Add logging for cudagraph related info (#29825) Signed-off-by: Yong Hoon Shin --- vllm/compilation/cuda_graph.py | 94 ++++++++++++++++++++++++++++++ vllm/config/observability.py | 4 ++ vllm/engine/arg_utils.py | 6 ++ vllm/v1/core/sched/scheduler.py | 8 ++- vllm/v1/metrics/loggers.py | 14 +++++ vllm/v1/metrics/stats.py | 3 + vllm/v1/outputs.py | 4 ++ vllm/v1/worker/gpu_model_runner.py | 32 ++++++++-- vllm/v1/worker/gpu_worker.py | 2 +- 9 files changed, 161 insertions(+), 6 deletions(-) diff --git a/vllm/compilation/cuda_graph.py b/vllm/compilation/cuda_graph.py index a2e0abfebc2c9..0748643a5299f 100644 --- a/vllm/compilation/cuda_graph.py +++ b/vllm/compilation/cuda_graph.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import dataclasses +from collections import Counter from collections.abc import Callable from contextlib import ExitStack from typing import Any @@ -22,6 +23,99 @@ from vllm.utils.torch_utils import weak_ref_tensors logger = init_logger(__name__) +@dataclasses.dataclass(frozen=True) +class CUDAGraphStat: + num_unpadded_tokens: int + num_padded_tokens: int + num_paddings: int + runtime_mode: str + + +class CUDAGraphLogging: + """Aggregate and log cudagraph metrics""" + + COLUMN_HEADERS = [ + "Unpadded Tokens", + "Padded Tokens", + "Num Paddings", + "Runtime Mode", + "Count", + ] + + def __init__(self, cg_mode: CUDAGraphMode, cg_capture_sizes: list[int] | None): + self.reset() + self.cg_mode = str(cg_mode) + self.cg_capture_sizes = str(cg_capture_sizes or []) + + self.settings_header = ( + "**CUDAGraph Config Settings:**\n\n" + f"- Mode: {self.cg_mode}\n" + f"- Capture sizes: {self.cg_capture_sizes}\n\n" + "**CUDAGraph Stats:**\n\n" + ) + + def 
reset(self): + self.stats = [] + + def observe(self, cudagraph_stat: CUDAGraphStat): + self.stats.append(cudagraph_stat) + + def generate_metric_table(self) -> str: + stats_counts = Counter(self.stats) + + # Convert stats to rows of strings, in descending order of observed frequencies + rows = [] + for stat, count in sorted( + stats_counts.items(), key=lambda item: item[1], reverse=True + ): + rows.append( + [ + str(stat.num_unpadded_tokens), + str(stat.num_padded_tokens), + str(stat.num_paddings), + stat.runtime_mode, + str(count), + ] + ) + + # Calculate column widths (max of header and data) + col_widths = [] + for i, header_text in enumerate(self.COLUMN_HEADERS): + max_width = len(header_text) + for row in rows: + max_width = max(max_width, len(row[i])) + col_widths.append(max_width) + + table_header_list = [ + h.ljust(w) for h, w in zip(self.COLUMN_HEADERS, col_widths) + ] + table_header = "| " + " | ".join(table_header_list) + " |\n" + + table_separator = "|" + "|".join("-" * (w + 2) for w in col_widths) + "|\n" + + # Create data rows with proper alignment + data_rows = [] + for row in rows: + formatted_row = [ + str(val).ljust(width) for val, width in zip(row, col_widths) + ] + data_rows.append("| " + " | ".join(formatted_row) + " |") + + return ( + self.settings_header + + table_header + + table_separator + + "\n".join(data_rows) + + "\n" + ) + + def log(self, log_fn=logger.info): + if not self.stats: + return + log_fn(self.generate_metric_table()) + self.reset() + + @dataclasses.dataclass class CUDAGraphEntry: batch_descriptor: BatchDescriptor diff --git a/vllm/config/observability.py b/vllm/config/observability.py index 656a5f8a9068e..fdc27aee380ef 100644 --- a/vllm/config/observability.py +++ b/vllm/config/observability.py @@ -55,6 +55,10 @@ class ObservabilityConfig: kv_cache_metrics_sample: float = Field(default=0.01, gt=0, le=1) """Sampling rate for KV cache metrics (0.0, 1.0]. 
Default 0.01 = 1% of blocks.""" + cudagraph_metrics: bool = False + """Enable CUDA graph metrics (number of padded/unpadded tokens, runtime cudagraph + dispatch modes, and their observed frequencies at every logging interval).""" + @cached_property def collect_model_forward_time(self) -> bool: """Whether to collect model forward time for the request.""" diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 096217da4fe44..fd07cded7bc51 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -518,6 +518,7 @@ class EngineArgs: kv_cache_metrics_sample: float = get_field( ObservabilityConfig, "kv_cache_metrics_sample" ) + cudagraph_metrics: bool = ObservabilityConfig.cudagraph_metrics scheduling_policy: SchedulerPolicy = SchedulerConfig.policy scheduler_cls: str | type[object] | None = SchedulerConfig.scheduler_cls @@ -1021,6 +1022,10 @@ class EngineArgs: "--kv-cache-metrics-sample", **observability_kwargs["kv_cache_metrics_sample"], ) + observability_group.add_argument( + "--cudagraph-metrics", + **observability_kwargs["cudagraph_metrics"], + ) # Scheduler arguments scheduler_kwargs = get_kwargs(SchedulerConfig) @@ -1698,6 +1703,7 @@ class EngineArgs: collect_detailed_traces=self.collect_detailed_traces, kv_cache_metrics=self.kv_cache_metrics, kv_cache_metrics_sample=self.kv_cache_metrics_sample, + cudagraph_metrics=self.cudagraph_metrics, ) # Compilation config overrides diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 52b98ef654592..75a7385df38b1 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -7,6 +7,7 @@ from collections.abc import Iterable from typing import Any from vllm import envs +from vllm.compilation.cuda_graph import CUDAGraphStat from vllm.config import VllmConfig from vllm.distributed.ec_transfer.ec_connector.base import ( ECConnectorMetadata, @@ -1037,6 +1038,7 @@ class Scheduler(SchedulerInterface): pooler_outputs = 
model_runner_output.pooler_output num_nans_in_logits = model_runner_output.num_nans_in_logits kv_connector_output = model_runner_output.kv_connector_output + cudagraph_stats = model_runner_output.cudagraph_stats outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) spec_decoding_stats: SpecDecodingStats | None = None @@ -1219,7 +1221,9 @@ class Scheduler(SchedulerInterface): finished_req_ids.clear() if ( - stats := self.make_stats(spec_decoding_stats, kv_connector_stats) + stats := self.make_stats( + spec_decoding_stats, kv_connector_stats, cudagraph_stats + ) ) is not None: # Return stats to only one of the front-ends. if (eco := next(iter(engine_core_outputs.values()), None)) is None: @@ -1420,6 +1424,7 @@ class Scheduler(SchedulerInterface): self, spec_decoding_stats: SpecDecodingStats | None = None, kv_connector_stats: KVConnectorStats | None = None, + cudagraph_stats: CUDAGraphStat | None = None, ) -> SchedulerStats | None: if not self.log_stats: return None @@ -1444,6 +1449,7 @@ class Scheduler(SchedulerInterface): kv_cache_eviction_events=eviction_events, spec_decoding_stats=spec_stats, kv_connector_stats=connector_stats_payload, + cudagraph_stats=cudagraph_stats, ) def make_spec_decoding_stats( diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index dec0e2d00aea8..6961e15c2d0c5 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -10,6 +10,7 @@ from typing import TypeAlias from prometheus_client import Counter, Gauge, Histogram import vllm.envs as envs +from vllm.compilation.cuda_graph import CUDAGraphLogging from vllm.config import SupportsMetricsInfo, VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( KVConnectorLogging, @@ -106,6 +107,12 @@ class LoggingStatLogger(StatLoggerBase): self.spec_decoding_logging = SpecDecodingLogging() kv_transfer_config = self.vllm_config.kv_transfer_config self.kv_connector_logging = KVConnectorLogging(kv_transfer_config) + self.cudagraph_logging 
= None + if self.vllm_config.observability_config.cudagraph_metrics: + self.cudagraph_logging = CUDAGraphLogging( + self.vllm_config.compilation_config.cudagraph_mode, + self.vllm_config.compilation_config.cudagraph_capture_sizes, + ) self.last_prompt_throughput: float = 0.0 self.last_generation_throughput: float = 0.0 self.engine_is_idle = False @@ -161,6 +168,11 @@ class LoggingStatLogger(StatLoggerBase): self.spec_decoding_logging.observe(scheduler_stats.spec_decoding_stats) if kv_connector_stats := scheduler_stats.kv_connector_stats: self.kv_connector_logging.observe(kv_connector_stats) + if ( + self.cudagraph_logging is not None + and scheduler_stats.cudagraph_stats is not None + ): + self.cudagraph_logging.observe(scheduler_stats.cudagraph_stats) if not self.aggregated: self.last_scheduler_stats = scheduler_stats if mm_cache_stats: @@ -240,6 +252,8 @@ class LoggingStatLogger(StatLoggerBase): self.spec_decoding_logging.log(log_fn=log_fn) self.kv_connector_logging.log(log_fn=log_fn) + if self.cudagraph_logging is not None: + self.cudagraph_logging.log(log_fn=log_fn) def log_engine_initialized(self): if self.vllm_config.cache_config.num_gpu_blocks: diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py index a3078eaa75dc5..733d3ae12e67f 100644 --- a/vllm/v1/metrics/stats.py +++ b/vllm/v1/metrics/stats.py @@ -7,6 +7,7 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any import vllm.envs as envs +from vllm.compilation.cuda_graph import CUDAGraphStat from vllm.v1.spec_decode.metrics import SpecDecodingStats if TYPE_CHECKING: @@ -183,6 +184,8 @@ class SchedulerStats: waiting_lora_adapters: dict[str, int] = field(default_factory=dict) running_lora_adapters: dict[str, int] = field(default_factory=dict) + cudagraph_stats: CUDAGraphStat | None = None + @dataclass class RequestStateStats: diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index 8110deb5a610b..88ac6b4aeb4bb 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py 
@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, NamedTuple import numpy as np import torch +from vllm.compilation.cuda_graph import CUDAGraphStat from vllm.v1.core.sched.output import SchedulerOutput if TYPE_CHECKING: @@ -169,6 +170,9 @@ class ModelRunnerOutput: # req_id -> num_nans_in_logits num_nans_in_logits: dict[str, int] | None = None + # information related to cudagraph execution + cudagraph_stats: CUDAGraphStat | None = None + # ModelRunnerOutput wrapper for async scheduling. class AsyncModelRunnerOutput(ABC): diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 1b250a8bd009c..3f043e3b2648b 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -27,7 +27,7 @@ from vllm.attention.backends.abstract import ( ) from vllm.attention.layer import Attention, MLAAttention from vllm.compilation.counter import compilation_counter -from vllm.compilation.cuda_graph import CUDAGraphWrapper +from vllm.compilation.cuda_graph import CUDAGraphStat, CUDAGraphWrapper from vllm.compilation.monitor import set_cudagraph_capturing_enabled from vllm.config import ( CompilationMode, @@ -257,6 +257,7 @@ class ExecuteModelState(NamedTuple): sample_hidden_states: torch.Tensor aux_hidden_states: list[torch.Tensor] | None ec_connector_output: ECConnectorOutput | None + cudagraph_stats: CUDAGraphStat | None class GPUModelRunner( @@ -2755,7 +2756,11 @@ class GPUModelRunner( force_uniform_decode: bool | None = None, force_has_lora: bool | None = None, ) -> tuple[ - CUDAGraphMode, BatchDescriptor, UBatchSlices | None, torch.Tensor | None + CUDAGraphMode, + BatchDescriptor, + UBatchSlices | None, + torch.Tensor | None, + CUDAGraphStat | None, ]: num_tokens_padded = self._pad_for_sequence_parallelism(num_tokens) uniform_decode = ( @@ -2820,7 +2825,22 @@ class GPUModelRunner( # num_tokens_across_dp will no-longer be valid assert batch_descriptor.num_tokens == num_tokens_padded - return cudagraph_mode, batch_descriptor, 
ubatch_slices, num_tokens_across_dp + cudagraph_stats = None + if self.vllm_config.observability_config.cudagraph_metrics: + cudagraph_stats = CUDAGraphStat( + num_unpadded_tokens=num_tokens, + num_padded_tokens=batch_descriptor.num_tokens, + num_paddings=batch_descriptor.num_tokens - num_tokens, + runtime_mode=str(cudagraph_mode), + ) + + return ( + cudagraph_mode, + batch_descriptor, + ubatch_slices, + num_tokens_across_dp, + cudagraph_stats, + ) @torch.inference_mode() def execute_model( @@ -2918,6 +2938,7 @@ class GPUModelRunner( batch_desc, ubatch_slices, num_tokens_across_dp, + cudagraph_stats, ) = self._determine_batch_execution_and_padding( num_tokens=num_tokens_unpadded, num_reqs=num_reqs, @@ -3067,6 +3088,7 @@ class GPUModelRunner( sample_hidden_states, aux_hidden_states, ec_connector_output, + cudagraph_stats, ) self.kv_connector_output = kv_connector_output return None @@ -3102,6 +3124,7 @@ class GPUModelRunner( sample_hidden_states, aux_hidden_states, ec_connector_output, + cudagraph_stats, ) = self.execute_model_state # Clear ephemeral state. 
self.execute_model_state = None @@ -3217,6 +3240,7 @@ class GPUModelRunner( if self.supports_mm_inputs else None, num_nans_in_logits=num_nans_in_logits, + cudagraph_stats=cudagraph_stats, ) if not self.use_async_scheduling: @@ -3937,7 +3961,7 @@ class GPUModelRunner( num_sampled_tokens = np.ones(num_reqs, dtype=np.int32) - _cudagraph_mode, batch_desc, ubatch_slices, num_tokens_across_dp = ( + _cudagraph_mode, batch_desc, ubatch_slices, num_tokens_across_dp, _ = ( self._determine_batch_execution_and_padding( num_tokens=num_tokens_unpadded, num_reqs=num_reqs, diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index edba07a423cda..a133575cbbced 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -564,7 +564,7 @@ class Worker(WorkerBase): # TODO(lucas): This is pretty gross; ideally we should only ever call # `_determine_batch_execution_and_padding` once (will get called again # in `execute_model`) but this requires a larger refactor of PP. - _, batch_desc, _, _ = ( + _, batch_desc, _, _, _ = ( self.model_runner._determine_batch_execution_and_padding( num_tokens=num_scheduled_tokens, num_reqs=len(num_scheduled_tokens_np), From 3f42b05fbc53e50813a1619f5fc770f17ac2a1b6 Mon Sep 17 00:00:00 2001 From: Chauncey Date: Wed, 3 Dec 2025 17:26:39 +0800 Subject: [PATCH 05/11] [Refactor] [1/N] to simplify the vLLM serving architecture (#28040) Signed-off-by: chaunceyjiang --- tests/entrypoints/openai/test_basic.py | 2 +- vllm/entrypoints/api_server.py | 1 + vllm/entrypoints/openai/api_server.py | 455 +----------------- vllm/entrypoints/openai/serving_engine.py | 3 +- vllm/entrypoints/sagemaker/routes.py | 2 +- vllm/entrypoints/serve/__init__.py | 60 +++ vllm/entrypoints/serve/disagg/__init__.py | 0 vllm/entrypoints/serve/disagg/api_router.py | 110 +++++ vllm/entrypoints/serve/disagg/protocol.py | 90 ++++ .../disagg/serving.py} | 10 +- vllm/entrypoints/serve/elastic_ep/__init__.py | 0 .../serve/elastic_ep/api_router.py | 96 ++++ 
.../serve/elastic_ep/middleware.py | 49 ++ .../serve/instrumentator/__init__.py | 0 .../serve/instrumentator/health.py | 33 ++ .../serve/instrumentator/metrics.py | 46 ++ vllm/entrypoints/serve/lora/__init__.py | 0 .../lora/api_router.py} | 19 +- vllm/entrypoints/serve/profile/__init__.py | 0 vllm/entrypoints/serve/profile/api_router.py | 49 ++ vllm/entrypoints/serve/rlhf/__init__.py | 0 vllm/entrypoints/serve/rlhf/api_router.py | 102 ++++ vllm/entrypoints/serve/sleep/__init__.py | 0 vllm/entrypoints/serve/sleep/api_router.py | 60 +++ vllm/entrypoints/serve/tokenize/__init__.py | 0 vllm/entrypoints/serve/tokenize/api_router.py | 118 +++++ .../tokenize/serving.py} | 0 27 files changed, 850 insertions(+), 455 deletions(-) create mode 100644 vllm/entrypoints/serve/__init__.py create mode 100644 vllm/entrypoints/serve/disagg/__init__.py create mode 100644 vllm/entrypoints/serve/disagg/api_router.py create mode 100644 vllm/entrypoints/serve/disagg/protocol.py rename vllm/entrypoints/{openai/serving_tokens.py => serve/disagg/serving.py} (99%) create mode 100644 vllm/entrypoints/serve/elastic_ep/__init__.py create mode 100644 vllm/entrypoints/serve/elastic_ep/api_router.py create mode 100644 vllm/entrypoints/serve/elastic_ep/middleware.py create mode 100644 vllm/entrypoints/serve/instrumentator/__init__.py create mode 100644 vllm/entrypoints/serve/instrumentator/health.py create mode 100644 vllm/entrypoints/serve/instrumentator/metrics.py create mode 100644 vllm/entrypoints/serve/lora/__init__.py rename vllm/entrypoints/{dynamic_lora.py => serve/lora/api_router.py} (80%) create mode 100644 vllm/entrypoints/serve/profile/__init__.py create mode 100644 vllm/entrypoints/serve/profile/api_router.py create mode 100644 vllm/entrypoints/serve/rlhf/__init__.py create mode 100644 vllm/entrypoints/serve/rlhf/api_router.py create mode 100644 vllm/entrypoints/serve/sleep/__init__.py create mode 100644 vllm/entrypoints/serve/sleep/api_router.py create mode 100644 
vllm/entrypoints/serve/tokenize/__init__.py create mode 100644 vllm/entrypoints/serve/tokenize/api_router.py rename vllm/entrypoints/{openai/serving_tokenization.py => serve/tokenize/serving.py} (100%) diff --git a/tests/entrypoints/openai/test_basic.py b/tests/entrypoints/openai/test_basic.py index 3d581a300b6a9..1ff30de31bbe5 100644 --- a/tests/entrypoints/openai/test_basic.py +++ b/tests/entrypoints/openai/test_basic.py @@ -232,7 +232,7 @@ async def test_server_load(server: RemoteOpenAIServer): @pytest.mark.asyncio async def test_health_check_engine_dead_error(): # Import the health function directly to test it in isolation - from vllm.entrypoints.openai.api_server import health + from vllm.entrypoints.serve.instrumentator.health import health # Create a mock request that simulates what FastAPI would provide mock_request = Mock(spec=Request) diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index 154cdeb42a3ea..b59f7120551e0 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -118,6 +118,7 @@ async def init_app( ) ) app.state.engine_client = engine + app.state.args = args return app diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index cdc316b65ba78..2fa6afa2bacb5 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -20,21 +20,15 @@ from http import HTTPStatus from typing import Annotated, Any, Literal import model_hosting_container_standards.sagemaker as sagemaker_standards -import prometheus_client import pydantic -import regex as re import uvloop from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Query, Request from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response, StreamingResponse -from prometheus_client import make_asgi_app -from prometheus_fastapi_instrumentator import Instrumentator from 
starlette.concurrency import iterate_in_threadpool from starlette.datastructures import URL, Headers, MutableHeaders, State -from starlette.routing import Mount from starlette.types import ASGIApp, Message, Receive, Scope, Send -from typing_extensions import assert_never import vllm.envs as envs from vllm.config import VllmConfig @@ -56,17 +50,11 @@ from vllm.entrypoints.openai.protocol import ( ChatCompletionResponse, CompletionRequest, CompletionResponse, - DetokenizeRequest, - DetokenizeResponse, ErrorInfo, ErrorResponse, - GenerateRequest, - GenerateResponse, ResponsesRequest, ResponsesResponse, StreamingResponsesResponse, - TokenizeRequest, - TokenizeResponse, TranscriptionRequest, TranscriptionResponseVariant, TranslationRequest, @@ -80,8 +68,6 @@ from vllm.entrypoints.openai.serving_models import ( OpenAIServingModels, ) from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses -from vllm.entrypoints.openai.serving_tokenization import OpenAIServingTokenization -from vllm.entrypoints.openai.serving_tokens import ServingTokens from vllm.entrypoints.openai.serving_transcription import ( OpenAIServingTranscription, OpenAIServingTranslation, @@ -92,6 +78,11 @@ from vllm.entrypoints.pooling.classify.serving import ServingClassification from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling from vllm.entrypoints.pooling.score.serving import ServingScores +from vllm.entrypoints.serve.disagg.serving import ServingTokens +from vllm.entrypoints.serve.elastic_ep.middleware import ( + ScalingMiddleware, +) +from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization from vllm.entrypoints.tool_server import DemoToolServer, MCPToolServer, ToolServer from vllm.entrypoints.utils import ( cli_env_setup, @@ -109,8 +100,6 @@ from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.gc_utils import freeze_gc_heap from 
vllm.utils.network_utils import is_valid_ipv6_address from vllm.utils.system_utils import decorate_logs, set_ulimit -from vllm.v1.engine.exceptions import EngineDeadError -from vllm.v1.metrics.prometheus import get_prometheus_registry from vllm.version import __version__ as VLLM_VERSION prometheus_multiproc_dir: tempfile.TemporaryDirectory @@ -245,39 +234,6 @@ async def build_async_engine_client_from_engine_args( router = APIRouter() -class PrometheusResponse(Response): - media_type = prometheus_client.CONTENT_TYPE_LATEST - - -def mount_metrics(app: FastAPI): - """Mount prometheus metrics to a FastAPI app.""" - - registry = get_prometheus_registry() - - # `response_class=PrometheusResponse` is needed to return an HTTP response - # with header "Content-Type: text/plain; version=0.0.4; charset=utf-8" - # instead of the default "application/json" which is incorrect. - # See https://github.com/trallnag/prometheus-fastapi-instrumentator/issues/163#issue-1296092364 - Instrumentator( - excluded_handlers=[ - "/metrics", - "/health", - "/load", - "/ping", - "/version", - "/server_info", - ], - registry=registry, - ).add().instrument(app).expose(app, response_class=PrometheusResponse) - - # Add prometheus asgi middleware to route /metrics requests - metrics_route = Mount("/metrics", make_asgi_app(registry=registry)) - - # Workaround for 307 Redirect for /metrics - metrics_route.path_regex = re.compile("^/metrics(?P.*)$") - app.routes.append(metrics_route) - - def base(request: Request) -> OpenAIServing: # Reuse the existing instance return tokenization(request) @@ -323,16 +279,6 @@ def generate_tokens(request: Request) -> ServingTokens | None: return request.app.state.serving_tokens -@router.get("/health", response_class=Response) -async def health(raw_request: Request) -> Response: - """Health check.""" - try: - await engine_client(raw_request).check_health() - return Response(status_code=200) - except EngineDeadError: - return Response(status_code=503) - - 
@router.get("/load") async def get_server_load_metrics(request: Request): # This endpoint returns the current server load metrics. @@ -352,167 +298,6 @@ async def get_server_load_metrics(request: Request): return JSONResponse(content={"server_load": request.app.state.server_load_metrics}) -@router.post("/pause") -async def pause_generation( - raw_request: Request, - wait_for_inflight_requests: bool = Query(False), - clear_cache: bool = Query(True), -) -> JSONResponse: - """Pause generation requests to allow weight updates. - - Args: - wait_for_inflight_requests: When ``True`` waits for in-flight - requests to finish before pausing. When ``False`` (default), - aborts any in-flight requests immediately. - clear_cache: Whether to clear KV/prefix caches after draining. - """ - - engine = engine_client(raw_request) - - try: - await engine.pause_generation( - wait_for_inflight_requests=wait_for_inflight_requests, - clear_cache=clear_cache, - ) - return JSONResponse( - content={"status": "paused"}, - status_code=HTTPStatus.OK.value, - ) - - except ValueError as err: - return JSONResponse( - content={"error": str(err)}, - status_code=HTTPStatus.BAD_REQUEST.value, - ) - except Exception as err: # pragma: no cover - defensive - logger.exception("Failed to pause generation") - return JSONResponse( - content={"error": f"Failed to pause generation: {err}"}, - status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, - ) - - -@router.post("/resume") -async def resume_generation(raw_request: Request) -> JSONResponse: - """Resume generation after a pause.""" - - engine = engine_client(raw_request) - - try: - await engine.resume_generation() - return JSONResponse( - content={"status": "resumed"}, - status_code=HTTPStatus.OK.value, - ) - except Exception as err: # pragma: no cover - defensive - logger.exception("Failed to resume generation") - return JSONResponse( - content={"error": f"Failed to resume generation: {err}"}, - status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, - ) - - 
-@router.get("/is_paused") -async def is_paused(raw_request: Request) -> JSONResponse: - """Return the current pause status.""" - - engine = engine_client(raw_request) - - try: - paused = await engine.is_paused() - except Exception as err: # pragma: no cover - defensive - logger.exception("Failed to fetch pause status") - return JSONResponse( - content={"error": f"Failed to fetch pause status: {err}"}, - status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, - ) - - return JSONResponse(content={"is_paused": paused}) - - -@router.post( - "/tokenize", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, - HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, - HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, - HTTPStatus.NOT_IMPLEMENTED.value: {"model": ErrorResponse}, - }, -) -@with_cancellation -async def tokenize(request: TokenizeRequest, raw_request: Request): - handler = tokenization(raw_request) - - try: - generator = await handler.create_tokenize(request, raw_request) - except NotImplementedError as e: - raise HTTPException( - status_code=HTTPStatus.NOT_IMPLEMENTED.value, detail=str(e) - ) from e - except Exception as e: - raise HTTPException( - status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) - ) from e - - if isinstance(generator, ErrorResponse): - return JSONResponse( - content=generator.model_dump(), status_code=generator.error.code - ) - elif isinstance(generator, TokenizeResponse): - return JSONResponse(content=generator.model_dump()) - - assert_never(generator) - - -@router.post( - "/detokenize", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, - HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, - HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, - }, -) -@with_cancellation -async def detokenize(request: DetokenizeRequest, raw_request: Request): - handler = 
tokenization(raw_request) - - try: - generator = await handler.create_detokenize(request, raw_request) - except OverflowError as e: - raise RequestValidationError(errors=[str(e)]) from e - except Exception as e: - raise HTTPException( - status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) - ) from e - - if isinstance(generator, ErrorResponse): - return JSONResponse( - content=generator.model_dump(), status_code=generator.error.code - ) - elif isinstance(generator, DetokenizeResponse): - return JSONResponse(content=generator.model_dump()) - - assert_never(generator) - - -def maybe_register_tokenizer_info_endpoint(args): - """Conditionally register the tokenizer info endpoint if enabled.""" - if getattr(args, "enable_tokenizer_info_endpoint", False): - - @router.get("/tokenizer_info") - async def get_tokenizer_info(raw_request: Request): - """Get comprehensive tokenizer information.""" - result = await tokenization(raw_request).get_tokenizer_info() - return JSONResponse( - content=result.model_dump(), - status_code=result.error.code - if isinstance(result, ErrorResponse) - else 200, - ) - - @router.get("/v1/models") async def show_available_models(raw_request: Request): handler = models(raw_request) @@ -898,33 +683,6 @@ if envs.VLLM_SERVER_DEV_MODE: await engine_client(raw_request).reset_mm_cache() return Response(status_code=200) - @router.post("/sleep") - async def sleep(raw_request: Request): - # get POST params - level = raw_request.query_params.get("level", "1") - await engine_client(raw_request).sleep(int(level)) - # FIXME: in v0 with frontend multiprocessing, the sleep command - # is sent but does not finish yet when we return a response. 
- return Response(status_code=200) - - @router.post("/wake_up") - async def wake_up(raw_request: Request): - tags = raw_request.query_params.getlist("tags") - if tags == []: - # set to None to wake up all tags if no tags are provided - tags = None - logger.info("wake up the engine with tags: %s", tags) - await engine_client(raw_request).wake_up(tags) - # FIXME: in v0 with frontend multiprocessing, the wake-up command - # is sent but does not finish yet when we return a response. - return Response(status_code=200) - - @router.get("/is_sleeping") - async def is_sleeping(raw_request: Request): - logger.info("check whether the engine is sleeping") - is_sleeping = await engine_client(raw_request).is_sleeping() - return JSONResponse(content={"is_sleeping": is_sleeping}) - @router.post("/collective_rpc") async def collective_rpc(raw_request: Request): try: @@ -952,138 +710,13 @@ if envs.VLLM_SERVER_DEV_MODE: return Response(status_code=200) response: list[Any] = [] for result in results: - if result is None or isinstance(result, (dict, list)): + if result is None or isinstance(result, dict | list): response.append(result) else: response.append(str(result)) return JSONResponse(content={"results": response}) -@router.post( - "/scale_elastic_ep", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.OK.value: {"model": dict}, - HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, - HTTPStatus.REQUEST_TIMEOUT.value: {"model": ErrorResponse}, - HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, - }, -) -async def scale_elastic_ep(raw_request: Request): - try: - body = await raw_request.json() - except json.JSONDecodeError as e: - raise HTTPException(status_code=400, detail="Invalid JSON format") from e # noqa: B904 - - new_data_parallel_size = body.get("new_data_parallel_size") - drain_timeout = body.get("drain_timeout", 120) # Default 2 minutes - - if new_data_parallel_size is None: - raise HTTPException( - status_code=400, 
detail="new_data_parallel_size is required" - ) - - if not isinstance(new_data_parallel_size, int) or new_data_parallel_size <= 0: - raise HTTPException( - status_code=400, detail="new_data_parallel_size must be a positive integer" - ) - - if not isinstance(drain_timeout, int) or drain_timeout <= 0: - raise HTTPException( - status_code=400, detail="drain_timeout must be a positive integer" - ) - - # Set scaling flag to prevent new requests - global _scaling_elastic_ep - _scaling_elastic_ep = True - client = engine_client(raw_request) - try: - await client.scale_elastic_ep(new_data_parallel_size, drain_timeout) - return JSONResponse( - { - "message": f"Scaled to {new_data_parallel_size} data parallel engines", - } - ) - except TimeoutError as e: - raise HTTPException( - status_code=408, - detail="Scale failed due to request drain timeout " - f"after {drain_timeout} seconds", - ) from e - except Exception as e: - logger.error("Scale failed: %s", e) - raise HTTPException(status_code=500, detail="Scale failed") from e - finally: - _scaling_elastic_ep = False - - -@router.post("/is_scaling_elastic_ep") -async def is_scaling_elastic_ep(raw_request: Request): - return JSONResponse({"is_scaling_elastic_ep": _scaling_elastic_ep}) - - -@router.post( - "/inference/v1/generate", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.OK.value: {"content": {"text/event-stream": {}}}, - HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, - HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, - HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, - }, -) -@with_cancellation -@load_aware_call -async def generate(request: GenerateRequest, raw_request: Request): - handler = generate_tokens(raw_request) - if handler is None: - return base(raw_request).create_error_response( - message="The model does not support generate tokens API" - ) - try: - generator = await handler.serve_tokens(request, raw_request) - except Exception as e: - raise 
HTTPException( - status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) - ) from e - if isinstance(generator, ErrorResponse): - return JSONResponse( - content=generator.model_dump(), status_code=generator.error.code - ) - - elif isinstance(generator, GenerateResponse): - return JSONResponse(content=generator.model_dump()) - - return StreamingResponse(content=generator, media_type="text/event-stream") - - -if envs.VLLM_TORCH_PROFILER_DIR: - logger.warning_once( - "Torch Profiler is enabled in the API server. This should ONLY be " - "used for local development!" - ) -elif envs.VLLM_TORCH_CUDA_PROFILE: - logger.warning_once( - "CUDA Profiler is enabled in the API server. This should ONLY be " - "used for local development!" - ) -if envs.VLLM_TORCH_PROFILER_DIR or envs.VLLM_TORCH_CUDA_PROFILE: - - @router.post("/start_profile") - async def start_profile(raw_request: Request): - logger.info("Starting profiler...") - await engine_client(raw_request).start_profile() - logger.info("Profiler started.") - return Response(status_code=200) - - @router.post("/stop_profile") - async def stop_profile(raw_request: Request): - logger.info("Stopping profiler...") - await engine_client(raw_request).stop_profile() - logger.info("Profiler stopped.") - return Response(status_code=200) - - def load_log_config(log_config_file: str | None) -> dict | None: if not log_config_file: return None @@ -1176,41 +809,6 @@ class XRequestIdMiddleware: return self.app(scope, receive, send_with_request_id) -# Global variable to track scaling state -_scaling_elastic_ep = False - - -class ScalingMiddleware: - """ - Middleware that checks if the model is currently scaling and - returns a 503 Service Unavailable response if it is. - - This middleware applies to all HTTP requests and prevents - processing when the model is in a scaling state. 
- """ - - def __init__(self, app: ASGIApp) -> None: - self.app = app - - def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]: - if scope["type"] != "http": - return self.app(scope, receive, send) - - # Check global scaling state - global _scaling_elastic_ep - if _scaling_elastic_ep: - # Return 503 Service Unavailable response - response = JSONResponse( - content={ - "error": "The model is currently scaling. Please try again later." - }, - status_code=503, - ) - return response(scope, receive, send) - - return self.app(scope, receive, send) - - def _extract_content_from_chunk(chunk_data: dict) -> str: """Extract content from a streaming response chunk.""" try: @@ -1353,15 +951,10 @@ def build_app(args: Namespace) -> FastAPI: ) else: app = FastAPI(lifespan=lifespan) + app.state.args = args + from vllm.entrypoints.serve import register_vllm_serve_api_routers - if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: - logger.warning( - "LoRA dynamic loading & unloading is enabled in the API server. " - "This should ONLY be used for local development!" - ) - from vllm.entrypoints.dynamic_lora import register_dynamic_lora_routes - - register_dynamic_lora_routes(router) + register_vllm_serve_api_routers(app) from vllm.entrypoints.sagemaker.routes import register_sagemaker_routes @@ -1370,8 +963,6 @@ def build_app(args: Namespace) -> FastAPI: app.root_path = args.root_path - mount_metrics(app) - from vllm.entrypoints.pooling import register_pooling_api_routers register_pooling_api_routers(app) @@ -1462,31 +1053,6 @@ def build_app(args: Namespace) -> FastAPI: ) app = sagemaker_standards.bootstrap(app) - # Optional endpoints - if args.tokens_only: - - @app.post("/abort_requests") - async def abort_requests(raw_request: Request): - """ - Abort one or more requests. To be used in a - Disaggregated Everything setup. 
- """ - try: - body = await raw_request.json() - except json.JSONDecodeError as e: - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST.value, - detail=f"JSON decode error: {e}", - ) from e - request_ids = body.get("request_ids") - if request_ids is None: - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST.value, - detail="Missing 'request_ids' in request body", - ) - # Abort requests in background - asyncio.create_task(engine_client(raw_request).abort(request_ids)) - return Response(status_code=200) return app @@ -1515,7 +1081,7 @@ async def init_app_state( state.engine_client = engine_client state.log_stats = not args.disable_log_stats state.vllm_config = vllm_config - + state.args = args supported_tasks = await engine_client.get_supported_tasks() logger.info("Supported tasks: %s", supported_tasks) @@ -1839,7 +1405,6 @@ async def run_server_worker( args, client_config=client_config, ) as engine_client: - maybe_register_tokenizer_info_endpoint(args) app = build_app(args) await init_app_state(engine_client, app.state, args) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 1d89aa011af21..67291f45a9251 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -74,8 +74,6 @@ from vllm.entrypoints.openai.protocol import ( ErrorResponse, FunctionCall, FunctionDefinition, - GenerateRequest, - GenerateResponse, ResponsesRequest, TokenizeChatRequest, TokenizeCompletionRequest, @@ -87,6 +85,7 @@ from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig +from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse from vllm.entrypoints.utils import _validate_truncation_size from vllm.inputs.data import PromptType 
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt diff --git a/vllm/entrypoints/sagemaker/routes.py b/vllm/entrypoints/sagemaker/routes.py index 108fdd773e321..ea88c0fc4b979 100644 --- a/vllm/entrypoints/sagemaker/routes.py +++ b/vllm/entrypoints/sagemaker/routes.py @@ -16,7 +16,6 @@ from vllm.entrypoints.openai.api_server import ( completion, create_chat_completion, create_completion, - health, validate_json_request, ) from vllm.entrypoints.openai.protocol import ( @@ -38,6 +37,7 @@ from vllm.entrypoints.pooling.score.api_router import ( score, ) from vllm.entrypoints.pooling.score.protocol import RerankRequest, ScoreRequest +from vllm.entrypoints.serve.instrumentator.health import health # TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers # (requires typing_extensions >= 4.13) diff --git a/vllm/entrypoints/serve/__init__.py b/vllm/entrypoints/serve/__init__.py new file mode 100644 index 0000000000000..c4fcc92db931f --- /dev/null +++ b/vllm/entrypoints/serve/__init__.py @@ -0,0 +1,60 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +from fastapi import FastAPI + + +def register_vllm_serve_api_routers(app: FastAPI): + from vllm.entrypoints.serve.lora.api_router import ( + attach_router as attach_lora_router, + ) + + attach_lora_router(app) + from vllm.entrypoints.serve.elastic_ep.api_router import ( + attach_router as attach_elastic_ep_router, + ) + + attach_elastic_ep_router(app) + + from vllm.entrypoints.serve.profile.api_router import ( + attach_router as attach_profile_router, + ) + + attach_profile_router(app) + + from vllm.entrypoints.serve.sleep.api_router import ( + attach_router as attach_sleep_router, + ) + + attach_sleep_router(app) + + from vllm.entrypoints.serve.tokenize.api_router import ( + attach_router as attach_tokenize_router, + ) + + attach_tokenize_router(app) + + from vllm.entrypoints.serve.disagg.api_router import ( + attach_router as 
attach_disagg_router, + ) + + attach_disagg_router(app) + + from vllm.entrypoints.serve.rlhf.api_router import ( + attach_router as attach_rlhf_router, + ) + + attach_rlhf_router(app) + + from vllm.entrypoints.serve.instrumentator.metrics import ( + attach_router as attach_metrics_router, + ) + + attach_metrics_router(app) + + from vllm.entrypoints.serve.instrumentator.health import ( + attach_router as attach_health_router, + ) + + attach_health_router(app) diff --git a/vllm/entrypoints/serve/disagg/__init__.py b/vllm/entrypoints/serve/disagg/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/entrypoints/serve/disagg/api_router.py b/vllm/entrypoints/serve/disagg/api_router.py new file mode 100644 index 0000000000000..c38ede30dad1c --- /dev/null +++ b/vllm/entrypoints/serve/disagg/api_router.py @@ -0,0 +1,110 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +import asyncio +import json +from http import HTTPStatus + +from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Response +from fastapi.responses import JSONResponse, StreamingResponse + +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.openai.api_server import validate_json_request +from vllm.entrypoints.openai.protocol import ( + ErrorResponse, +) +from vllm.entrypoints.serve.disagg.protocol import ( + GenerateRequest, + GenerateResponse, +) +from vllm.entrypoints.serve.disagg.serving import ( + ServingTokens, +) +from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization +from vllm.entrypoints.utils import ( + load_aware_call, + with_cancellation, +) +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def tokenization(request: Request) -> OpenAIServingTokenization: + return request.app.state.openai_serving_tokenization + + +def generate_tokens(request: Request) -> ServingTokens | None: + return request.app.state.serving_tokens 
+ + +def engine_client(request: Request) -> EngineClient: + return request.app.state.engine_client + + +router = APIRouter() + + +@router.post( + "/inference/v1/generate", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.OK.value: {"content": {"text/event-stream": {}}}, + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) +@with_cancellation +@load_aware_call +async def generate(request: GenerateRequest, raw_request: Request): + handler = generate_tokens(raw_request) + if handler is None: + return tokenization(raw_request).create_error_response( + message="The model does not support generate tokens API" + ) + try: + generator = await handler.serve_tokens(request, raw_request) + except Exception as e: + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) + ) from e + if isinstance(generator, ErrorResponse): + return JSONResponse( + content=generator.model_dump(), status_code=generator.error.code + ) + + elif isinstance(generator, GenerateResponse): + return JSONResponse(content=generator.model_dump()) + + return StreamingResponse(content=generator, media_type="text/event-stream") + + +def attach_router(app: FastAPI): + if getattr(app.state.args, "tokens_only", False): + + @router.post("/abort_requests") + async def abort_requests(raw_request: Request): + """ + Abort one or more requests. To be used in a + Disaggregated Everything setup. 
+ """ + try: + body = await raw_request.json() + except json.JSONDecodeError as e: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail=f"JSON decode error: {e}", + ) from e + request_ids = body.get("request_ids") + if request_ids is None: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail="Missing 'request_ids' in request body", + ) + # Abort requests in background + asyncio.create_task(engine_client(raw_request).abort(request_ids)) + return Response(status_code=200) + + app.include_router(router) diff --git a/vllm/entrypoints/serve/disagg/protocol.py b/vllm/entrypoints/serve/disagg/protocol.py new file mode 100644 index 0000000000000..251fcf12ed7dd --- /dev/null +++ b/vllm/entrypoints/serve/disagg/protocol.py @@ -0,0 +1,90 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any + +from pydantic import BaseModel, Field + +from vllm.entrypoints.openai.protocol import ( + ChatCompletionLogProbs, + Logprob, + SamplingParams, + StreamOptions, +) +from vllm.utils import random_uuid + + +####### Tokens IN <> Tokens OUT ####### +class GenerateRequest(BaseModel): + request_id: str = Field( + default_factory=lambda: f"{random_uuid()}", + description=( + "The request_id related to this request. If the caller does " + "not set it, a random_uuid will be generated. This id is used " + "through out the inference process and return in response." 
+ ), + ) + token_ids: list[int] + """The token ids to generate text from.""" + + # features: MultiModalFeatureSpec + # TODO (NickLucche): implement once Renderer work is completed + features: str | None = None + """The processed MM inputs for the model.""" + + sampling_params: SamplingParams + """The sampling parameters for the model.""" + + model: str | None = None + + stream: bool | None = False + stream_options: StreamOptions | None = None + cache_salt: str | None = Field( + default=None, + description=( + "If specified, the prefix cache will be salted with the provided " + "string to prevent an attacker to guess prompts in multi-user " + "environments. The salt should be random, protected from " + "access by 3rd parties, and long enough to be " + "unpredictable (e.g., 43 characters base64-encoded, corresponding " + "to 256 bit)." + ), + ) + priority: int = Field( + default=0, + description=( + "The priority of the request (lower means earlier handling; " + "default: 0). Any priority other than 0 will raise an error " + "if the served model does not use priority scheduling." + ), + ) + kv_transfer_params: dict[str, Any] | None = Field( + default=None, + description="KVTransfer parameters used for disaggregated serving.", + ) + + +class GenerateResponseChoice(BaseModel): + index: int + logprobs: ChatCompletionLogProbs | None = None + # per OpenAI spec this is the default + finish_reason: str | None = "stop" + token_ids: list[int] | None = None + + +class GenerateResponse(BaseModel): + request_id: str = Field( + default_factory=lambda: f"{random_uuid()}", + description=( + "The request_id related to this request. If the caller does " + "not set it, a random_uuid will be generated. This id is used " + "through out the inference process and return in response." 
+ ), + ) + choices: list[GenerateResponseChoice] + + prompt_logprobs: list[dict[int, Logprob] | None] | None = None + + kv_transfer_params: dict[str, Any] | None = Field( + default=None, + description="KVTransfer parameters used for disaggregated serving.", + ) diff --git a/vllm/entrypoints/openai/serving_tokens.py b/vllm/entrypoints/serve/disagg/serving.py similarity index 99% rename from vllm/entrypoints/openai/serving_tokens.py rename to vllm/entrypoints/serve/disagg/serving.py index daa739e41fa07..5c1d17156a90d 100644 --- a/vllm/entrypoints/openai/serving_tokens.py +++ b/vllm/entrypoints/serve/disagg/serving.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + import asyncio import time from collections.abc import AsyncGenerator @@ -14,15 +16,17 @@ from vllm.entrypoints.openai.protocol import ( ChatCompletionLogProbs, ChatCompletionLogProbsContent, ErrorResponse, - GenerateRequest, - GenerateResponse, - GenerateResponseChoice, PromptTokenUsageInfo, RequestResponseMetadata, UsageInfo, ) from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.entrypoints.serve.disagg.protocol import ( + GenerateRequest, + GenerateResponse, + GenerateResponseChoice, +) from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.logger import init_logger from vllm.logprobs import Logprob diff --git a/vllm/entrypoints/serve/elastic_ep/__init__.py b/vllm/entrypoints/serve/elastic_ep/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/entrypoints/serve/elastic_ep/api_router.py b/vllm/entrypoints/serve/elastic_ep/api_router.py new file mode 100644 index 0000000000000..21d5d2e60778a --- /dev/null +++ b/vllm/entrypoints/serve/elastic_ep/api_router.py @@ -0,0 +1,96 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright 
contributors to the vLLM project + + +import json +from http import HTTPStatus + +from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request +from fastapi.responses import JSONResponse + +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.openai.api_server import validate_json_request +from vllm.entrypoints.openai.protocol import ( + ErrorResponse, +) +from vllm.entrypoints.serve.elastic_ep.middleware import ( + get_scaling_elastic_ep, + set_scaling_elastic_ep, +) +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def engine_client(request: Request) -> EngineClient: + return request.app.state.engine_client + + +router = APIRouter() + + +@router.post( + "/scale_elastic_ep", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.OK.value: {"model": dict}, + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.REQUEST_TIMEOUT.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) +async def scale_elastic_ep(raw_request: Request): + try: + body = await raw_request.json() + except json.JSONDecodeError as e: + raise HTTPException(status_code=400, detail="Invalid JSON format") from e # noqa: B904 + + new_data_parallel_size = body.get("new_data_parallel_size") + drain_timeout = body.get("drain_timeout", 120) # Default 2 minutes + + if new_data_parallel_size is None: + raise HTTPException( + status_code=400, detail="new_data_parallel_size is required" + ) + + if not isinstance(new_data_parallel_size, int) or new_data_parallel_size <= 0: + raise HTTPException( + status_code=400, + detail="new_data_parallel_size must be a positive integer", + ) + + if not isinstance(drain_timeout, int) or drain_timeout <= 0: + raise HTTPException( + status_code=400, detail="drain_timeout must be a positive integer" + ) + + # Set scaling flag to prevent new requests + set_scaling_elastic_ep(True) + client = engine_client(raw_request) + try: + await 
client.scale_elastic_ep(new_data_parallel_size, drain_timeout) + return JSONResponse( + { + "message": f"Scaled to {new_data_parallel_size} data parallel engines", + } + ) + except TimeoutError as e: + raise HTTPException( + status_code=408, + detail="Scale failed due to request drain timeout " + f"after {drain_timeout} seconds", + ) from e + except Exception as e: + logger.error("Scale failed: %s", e) + raise HTTPException(status_code=500, detail="Scale failed") from e + finally: + set_scaling_elastic_ep(False) + + +@router.post("/is_scaling_elastic_ep") +async def is_scaling_elastic_ep(raw_request: Request): + return JSONResponse({"is_scaling_elastic_ep": get_scaling_elastic_ep()}) + + +def attach_router(app: FastAPI): + app.include_router(router) diff --git a/vllm/entrypoints/serve/elastic_ep/middleware.py b/vllm/entrypoints/serve/elastic_ep/middleware.py new file mode 100644 index 0000000000000..23f45eafeaa0d --- /dev/null +++ b/vllm/entrypoints/serve/elastic_ep/middleware.py @@ -0,0 +1,49 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Awaitable + +from fastapi.responses import JSONResponse +from starlette.types import ASGIApp, Receive, Scope, Send + +# Global variable to track scaling state +_scaling_elastic_ep = False + + +def get_scaling_elastic_ep(): + return _scaling_elastic_ep + + +def set_scaling_elastic_ep(value): + global _scaling_elastic_ep + _scaling_elastic_ep = value + + +class ScalingMiddleware: + """ + Middleware that checks if the model is currently scaling and + returns a 503 Service Unavailable response if it is. + + This middleware applies to all HTTP requests and prevents + processing when the model is in a scaling state. 
+ """ + + def __init__(self, app: ASGIApp) -> None: + self.app = app + + def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]: + if scope["type"] != "http": + return self.app(scope, receive, send) + + # Check global scaling state + if get_scaling_elastic_ep(): + # Return 503 Service Unavailable response + response = JSONResponse( + content={ + "error": "The model is currently scaling. Please try again later." + }, + status_code=503, + ) + return response(scope, receive, send) + + return self.app(scope, receive, send) diff --git a/vllm/entrypoints/serve/instrumentator/__init__.py b/vllm/entrypoints/serve/instrumentator/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/entrypoints/serve/instrumentator/health.py b/vllm/entrypoints/serve/instrumentator/health.py new file mode 100644 index 0000000000000..029ef677aaa25 --- /dev/null +++ b/vllm/entrypoints/serve/instrumentator/health.py @@ -0,0 +1,33 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +from fastapi import APIRouter, Request +from fastapi.responses import Response + +from vllm.engine.protocol import EngineClient +from vllm.logger import init_logger +from vllm.v1.engine.exceptions import EngineDeadError + +logger = init_logger(__name__) + + +router = APIRouter() + + +def engine_client(request: Request) -> EngineClient: + return request.app.state.engine_client + + +@router.get("/health", response_class=Response) +async def health(raw_request: Request) -> Response: + """Health check.""" + try: + await engine_client(raw_request).check_health() + return Response(status_code=200) + except EngineDeadError: + return Response(status_code=503) + + +def attach_router(app): + app.include_router(router) diff --git a/vllm/entrypoints/serve/instrumentator/metrics.py b/vllm/entrypoints/serve/instrumentator/metrics.py new file mode 100644 index 0000000000000..efe0c63a90714 --- /dev/null +++ 
b/vllm/entrypoints/serve/instrumentator/metrics.py @@ -0,0 +1,46 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +import re + +import prometheus_client +from fastapi import FastAPI, Response +from prometheus_client import make_asgi_app +from prometheus_fastapi_instrumentator import Instrumentator +from starlette.routing import Mount + +from vllm.v1.metrics.prometheus import get_prometheus_registry + + +class PrometheusResponse(Response): + media_type = prometheus_client.CONTENT_TYPE_LATEST + + +def attach_router(app: FastAPI): + """Mount prometheus metrics to a FastAPI app.""" + + registry = get_prometheus_registry() + + # `response_class=PrometheusResponse` is needed to return an HTTP response + # with header "Content-Type: text/plain; version=0.0.4; charset=utf-8" + # instead of the default "application/json" which is incorrect. + # See https://github.com/trallnag/prometheus-fastapi-instrumentator/issues/163#issue-1296092364 + Instrumentator( + excluded_handlers=[ + "/metrics", + "/health", + "/load", + "/ping", + "/version", + "/server_info", + ], + registry=registry, + ).add().instrument(app).expose(app, response_class=PrometheusResponse) + + # Add prometheus asgi middleware to route /metrics requests + metrics_route = Mount("/metrics", make_asgi_app(registry=registry)) + + # Workaround for 307 Redirect for /metrics + metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$") + app.routes.append(metrics_route) diff --git a/vllm/entrypoints/serve/lora/__init__.py b/vllm/entrypoints/serve/lora/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/entrypoints/dynamic_lora.py b/vllm/entrypoints/serve/lora/api_router.py similarity index 80% rename from vllm/entrypoints/dynamic_lora.py rename to vllm/entrypoints/serve/lora/api_router.py index cc0f437e5c77f..6a57e73f334f2 100644 --- a/vllm/entrypoints/dynamic_lora.py +++ b/vllm/entrypoints/serve/lora/api_router.py @@ 
-1,9 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + import model_hosting_container_standards.sagemaker as sagemaker_standards -from fastapi import APIRouter, Depends, Request +from fastapi import APIRouter, Depends, FastAPI, Request from fastapi.responses import JSONResponse, Response +from vllm import envs from vllm.entrypoints.openai.api_server import models, validate_json_request from vllm.entrypoints.openai.protocol import ( ErrorResponse, @@ -14,9 +17,18 @@ from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.logger import init_logger logger = init_logger(__name__) +router = APIRouter() -def register_dynamic_lora_routes(router: APIRouter): +def attach_router(app: FastAPI): + if not envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: + """If LoRA dynamic loading & unloading is not enabled, do nothing.""" + return + logger.warning( + "LoRA dynamic loading & unloading is enabled in the API server. " + "This should ONLY be used for local development!" 
+ ) + @sagemaker_standards.register_load_adapter_handler( request_shape={ "lora_name": "body.name", @@ -54,4 +66,5 @@ def register_dynamic_lora_routes(router: APIRouter): return Response(status_code=200, content=response) - return router + # register the router + app.include_router(router) diff --git a/vllm/entrypoints/serve/profile/__init__.py b/vllm/entrypoints/serve/profile/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/entrypoints/serve/profile/api_router.py b/vllm/entrypoints/serve/profile/api_router.py new file mode 100644 index 0000000000000..166f13764eb36 --- /dev/null +++ b/vllm/entrypoints/serve/profile/api_router.py @@ -0,0 +1,49 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +from fastapi import APIRouter, FastAPI, Request +from fastapi.responses import Response + +import vllm.envs as envs +from vllm.engine.protocol import EngineClient +from vllm.logger import init_logger + +logger = init_logger(__name__) + +router = APIRouter() + + +def engine_client(request: Request) -> EngineClient: + return request.app.state.engine_client + + +@router.post("/start_profile") +async def start_profile(raw_request: Request): + logger.info("Starting profiler...") + await engine_client(raw_request).start_profile() + logger.info("Profiler started.") + return Response(status_code=200) + + +@router.post("/stop_profile") +async def stop_profile(raw_request: Request): + logger.info("Stopping profiler...") + await engine_client(raw_request).stop_profile() + logger.info("Profiler stopped.") + return Response(status_code=200) + + +def attach_router(app: FastAPI): + if envs.VLLM_TORCH_PROFILER_DIR: + logger.warning_once( + "Torch Profiler is enabled in the API server. This should ONLY be " + "used for local development!" + ) + elif envs.VLLM_TORCH_CUDA_PROFILE: + logger.warning_once( + "CUDA Profiler is enabled in the API server. 
This should ONLY be " + "used for local development!" + ) + if envs.VLLM_TORCH_PROFILER_DIR or envs.VLLM_TORCH_CUDA_PROFILE: + app.include_router(router) diff --git a/vllm/entrypoints/serve/rlhf/__init__.py b/vllm/entrypoints/serve/rlhf/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/entrypoints/serve/rlhf/api_router.py b/vllm/entrypoints/serve/rlhf/api_router.py new file mode 100644 index 0000000000000..3b37840ae0899 --- /dev/null +++ b/vllm/entrypoints/serve/rlhf/api_router.py @@ -0,0 +1,102 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +from http import HTTPStatus + +from fastapi import APIRouter, FastAPI, Query, Request +from fastapi.responses import JSONResponse + +from vllm.engine.protocol import EngineClient +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def engine_client(request: Request) -> EngineClient: + return request.app.state.engine_client + + +router = APIRouter() + + +@router.post("/pause") +async def pause_generation( + raw_request: Request, + wait_for_inflight_requests: bool = Query(False), + clear_cache: bool = Query(True), +) -> JSONResponse: + """Pause generation requests to allow weight updates. + + Args: + wait_for_inflight_requests: When ``True`` waits for in-flight + requests to finish before pausing. When ``False`` (default), + aborts any in-flight requests immediately. + clear_cache: Whether to clear KV/prefix caches after draining. 
+ """ + + engine = engine_client(raw_request) + + try: + await engine.pause_generation( + wait_for_inflight_requests=wait_for_inflight_requests, + clear_cache=clear_cache, + ) + return JSONResponse( + content={"status": "paused"}, + status_code=HTTPStatus.OK.value, + ) + + except ValueError as err: + return JSONResponse( + content={"error": str(err)}, + status_code=HTTPStatus.BAD_REQUEST.value, + ) + except Exception as err: # pragma: no cover - defensive + logger.exception("Failed to pause generation") + return JSONResponse( + content={"error": f"Failed to pause generation: {err}"}, + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + ) + + +@router.post("/resume") +async def resume_generation(raw_request: Request) -> JSONResponse: + """Resume generation after a pause.""" + + engine = engine_client(raw_request) + + try: + await engine.resume_generation() + return JSONResponse( + content={"status": "resumed"}, + status_code=HTTPStatus.OK.value, + ) + except Exception as err: # pragma: no cover - defensive + logger.exception("Failed to resume generation") + return JSONResponse( + content={"error": f"Failed to resume generation: {err}"}, + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + ) + + +@router.get("/is_paused") +async def is_paused(raw_request: Request) -> JSONResponse: + """Return the current pause status.""" + + engine = engine_client(raw_request) + + try: + paused = await engine.is_paused() + except Exception as err: # pragma: no cover - defensive + logger.exception("Failed to fetch pause status") + return JSONResponse( + content={"error": f"Failed to fetch pause status: {err}"}, + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + ) + + return JSONResponse(content={"is_paused": paused}) + + +def attach_router(app: FastAPI): + app.include_router(router) diff --git a/vllm/entrypoints/serve/sleep/__init__.py b/vllm/entrypoints/serve/sleep/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git 
a/vllm/entrypoints/serve/sleep/api_router.py b/vllm/entrypoints/serve/sleep/api_router.py new file mode 100644 index 0000000000000..bc01e185315c8 --- /dev/null +++ b/vllm/entrypoints/serve/sleep/api_router.py @@ -0,0 +1,60 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +from fastapi import APIRouter, FastAPI, Request +from fastapi.responses import JSONResponse, Response + +import vllm.envs as envs +from vllm.engine.protocol import EngineClient +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def engine_client(request: Request) -> EngineClient: + return request.app.state.engine_client + + +router = APIRouter() + + +@router.post("/sleep") +async def sleep(raw_request: Request): + # get POST params + level = raw_request.query_params.get("level", "1") + await engine_client(raw_request).sleep(int(level)) + # FIXME: in v0 with frontend multiprocessing, the sleep command + # is sent but does not finish yet when we return a response. + return Response(status_code=200) + + +@router.post("/wake_up") +async def wake_up(raw_request: Request): + tags = raw_request.query_params.getlist("tags") + if tags == []: + # set to None to wake up all tags if no tags are provided + tags = None + logger.info("wake up the engine with tags: %s", tags) + await engine_client(raw_request).wake_up(tags) + # FIXME: in v0 with frontend multiprocessing, the wake-up command + # is sent but does not finish yet when we return a response. + return Response(status_code=200) + + +@router.get("/is_sleeping") +async def is_sleeping(raw_request: Request): + logger.info("check whether the engine is sleeping") + is_sleeping = await engine_client(raw_request).is_sleeping() + return JSONResponse(content={"is_sleeping": is_sleeping}) + + +def attach_router(app: FastAPI): + if not envs.VLLM_SERVER_DEV_MODE: + return + logger.warning( + "SECURITY WARNING: Development endpoints are enabled! 
" + "This should NOT be used in production!" + ) + + app.include_router(router) diff --git a/vllm/entrypoints/serve/tokenize/__init__.py b/vllm/entrypoints/serve/tokenize/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/entrypoints/serve/tokenize/api_router.py b/vllm/entrypoints/serve/tokenize/api_router.py new file mode 100644 index 0000000000000..a10e78c8d28ee --- /dev/null +++ b/vllm/entrypoints/serve/tokenize/api_router.py @@ -0,0 +1,118 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +from http import HTTPStatus + +from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request +from fastapi.exceptions import RequestValidationError +from fastapi.responses import JSONResponse +from typing_extensions import assert_never + +from vllm.entrypoints.openai.api_server import validate_json_request +from vllm.entrypoints.openai.protocol import ( + DetokenizeRequest, + DetokenizeResponse, + ErrorResponse, + TokenizeRequest, + TokenizeResponse, +) +from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization +from vllm.entrypoints.utils import ( + with_cancellation, +) +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def tokenization(request: Request) -> OpenAIServingTokenization: + return request.app.state.openai_serving_tokenization + + +router = APIRouter() + + +@router.post( + "/tokenize", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + HTTPStatus.NOT_IMPLEMENTED.value: {"model": ErrorResponse}, + }, +) +@with_cancellation +async def tokenize(request: TokenizeRequest, raw_request: Request): + handler = tokenization(raw_request) + + try: + generator = await handler.create_tokenize(request, raw_request) + except 
NotImplementedError as e: + raise HTTPException( + status_code=HTTPStatus.NOT_IMPLEMENTED.value, detail=str(e) + ) from e + except Exception as e: + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) + ) from e + + if isinstance(generator, ErrorResponse): + return JSONResponse( + content=generator.model_dump(), status_code=generator.error.code + ) + elif isinstance(generator, TokenizeResponse): + return JSONResponse(content=generator.model_dump()) + + assert_never(generator) + + +@router.post( + "/detokenize", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) +@with_cancellation +async def detokenize(request: DetokenizeRequest, raw_request: Request): + handler = tokenization(raw_request) + + try: + generator = await handler.create_detokenize(request, raw_request) + except OverflowError as e: + raise RequestValidationError(errors=[str(e)]) from e + except Exception as e: + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) + ) from e + + if isinstance(generator, ErrorResponse): + return JSONResponse( + content=generator.model_dump(), status_code=generator.error.code + ) + elif isinstance(generator, DetokenizeResponse): + return JSONResponse(content=generator.model_dump()) + + assert_never(generator) + + +def attach_router(app: FastAPI): + if getattr(app.state.args, "enable_tokenizer_info_endpoint", False): + """Conditionally register the tokenizer info endpoint if enabled.""" + + @router.get("/tokenizer_info") + async def get_tokenizer_info(raw_request: Request): + """Get comprehensive tokenizer information.""" + result = await tokenization(raw_request).get_tokenizer_info() + return JSONResponse( + content=result.model_dump(), + status_code=result.error.code + if isinstance(result, 
ErrorResponse) + else 200, + ) + + app.include_router(router) diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/serve/tokenize/serving.py similarity index 100% rename from vllm/entrypoints/openai/serving_tokenization.py rename to vllm/entrypoints/serve/tokenize/serving.py From 7fe9c1a2232275ee4cc7d65af3bc5b648543f367 Mon Sep 17 00:00:00 2001 From: WeiQing Chen <40507679+david6666666@users.noreply.github.com> Date: Wed, 3 Dec 2025 17:51:08 +0800 Subject: [PATCH 06/11] [CI] Add Async Eplb nightly CI tests (#29385) Signed-off-by: David Chen <530634352@qq.com> Signed-off-by: WeiQing Chen <40507679+david6666666@users.noreply.github.com> Co-authored-by: Cyrus Leung --- .../deepseek_v2_lite_ep_async_eplb.sh | 73 ++++++++++++++++++ .../deepseek_v2_lite_ep_eplb.sh | 1 + .../qwen3_next_mtp_async_eplb.sh | 74 +++++++++++++++++++ .buildkite/test-pipeline.yaml | 20 ++++- vllm/distributed/eplb/rebalance_execute.py | 3 - 5 files changed, 167 insertions(+), 4 deletions(-) create mode 100644 .buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_async_eplb.sh create mode 100644 .buildkite/scripts/scheduled_integration_test/qwen3_next_mtp_async_eplb.sh diff --git a/.buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_async_eplb.sh b/.buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_async_eplb.sh new file mode 100644 index 0000000000000..d7167161b0059 --- /dev/null +++ b/.buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_async_eplb.sh @@ -0,0 +1,73 @@ +#!/usr/bin/env bash +set -euxo pipefail + +# args: [THRESHOLD] [NUM_QUESTIONS] [START_PORT] +THRESHOLD=${1:-0.25} +NUM_Q=${2:-1319} +PORT=${3:-8030} +OUT_DIR=${OUT_DIR:-/tmp/vllm-scheduled} +mkdir -p "${OUT_DIR}" + +wait_for_server() { + local port=$1 + timeout 600 bash -c ' + until curl -sf "http://127.0.0.1:'"$port"'/health" > /dev/null; do + sleep 1 + done' +} + +MODEL="deepseek-ai/DeepSeek-V2-lite" + +# Set BACKENDS based on platform +if command -v 
rocm-smi &> /dev/null || [[ -d /opt/rocm ]] || [[ -n "${ROCM_PATH:-}" ]]; then + # ROCm platform + BACKENDS=("allgather_reducescatter") + # Disable MOE padding for ROCm since it is causing eplb to fail + export VLLM_ROCM_MOE_PADDING=0 +else + # Non-ROCm platform (CUDA/other) + BACKENDS=("deepep_high_throughput" "deepep_low_latency") +fi + +cleanup() { + if [[ -n "${SERVER_PID:-}" ]] && kill -0 "${SERVER_PID}" 2>/dev/null; then + kill "${SERVER_PID}" 2>/dev/null || true + for _ in {1..20}; do + kill -0 "${SERVER_PID}" 2>/dev/null || break + sleep 0.5 + done + kill -9 "${SERVER_PID}" 2>/dev/null || true + fi +} +trap cleanup EXIT + +for BACK in "${BACKENDS[@]}"; do + VLLM_DEEP_GEMM_WARMUP=skip \ + VLLM_ALL2ALL_BACKEND=$BACK \ + vllm serve "$MODEL" \ + --enforce-eager \ + --tensor-parallel-size 2 \ + --data-parallel-size 2 \ + --enable-expert-parallel \ + --enable-eplb \ + --eplb-config '{"window_size":200,"step_interval":600,"use_async":true}' \ + --trust-remote-code \ + --max-model-len 2048 \ + --port $PORT & + SERVER_PID=$! 
+ wait_for_server $PORT + + TAG=$(echo "$MODEL" | tr '/: \\n' '_____') + OUT="${OUT_DIR}/${TAG}_${BACK}_async_eplb.json" + python3 tests/evals/gsm8k/gsm8k_eval.py --host http://127.0.0.1 --port $PORT --num-questions ${NUM_Q} --save-results ${OUT} + python3 - <<PY +import json; acc=json.load(open('${OUT}'))['accuracy'] +print(f"${MODEL} ${BACK}: accuracy {acc:.3f}") +assert acc >= ${THRESHOLD}, f"${MODEL} ${BACK} accuracy {acc}" +PY + + cleanup + SERVER_PID= + sleep 1 + PORT=$((PORT+1)) +done diff --git a/.buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_eplb.sh b/.buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_eplb.sh index 8106f50f18f66..693418da6093e 100644 --- a/.buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_eplb.sh +++ b/.buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_eplb.sh @@ -50,6 +50,7 @@ for BACK in "${BACKENDS[@]}"; do --data-parallel-size 2 \ --enable-expert-parallel \ --enable-eplb \ + --eplb-config '{"window_size":200,"step_interval":600}' \ --trust-remote-code \ --max-model-len 2048 \ --port $PORT & diff --git a/.buildkite/scripts/scheduled_integration_test/qwen3_next_mtp_async_eplb.sh b/.buildkite/scripts/scheduled_integration_test/qwen3_next_mtp_async_eplb.sh new file mode 100644 index 0000000000000..937a43d1a3221 --- /dev/null +++ b/.buildkite/scripts/scheduled_integration_test/qwen3_next_mtp_async_eplb.sh @@ -0,0 +1,74 @@ +#!/usr/bin/env bash +set -euxo pipefail + +# args: [THRESHOLD] [NUM_QUESTIONS] [START_PORT] +THRESHOLD=${1:-0.25} +NUM_Q=${2:-1319} +PORT=${3:-8040} +OUT_DIR=${OUT_DIR:-/tmp/vllm-scheduled} +mkdir -p "${OUT_DIR}" + +wait_for_server() { + local port=$1 + timeout 600 bash -c ' + until curl -sf "http://127.0.0.1:'"$port"'/health" > /dev/null; do + sleep 1 + done' +} + +MODEL="Qwen/Qwen3-Next-80B-A3B-Instruct" + +# Set BACKENDS based on platform +if command -v rocm-smi &> /dev/null || [[ -d /opt/rocm ]] || [[ -n "${ROCM_PATH:-}" ]]; then + # ROCm platform + BACKENDS=("allgather_reducescatter") + # Disable MOE padding for ROCm since it is causing eplb to fail + export
VLLM_ROCM_MOE_PADDING=0 +else + # Non-ROCm platform (CUDA/other) + BACKENDS=("deepep_high_throughput" "deepep_low_latency") +fi + +cleanup() { + if [[ -n "${SERVER_PID:-}" ]] && kill -0 "${SERVER_PID}" 2>/dev/null; then + kill "${SERVER_PID}" 2>/dev/null || true + for _ in {1..20}; do + kill -0 "${SERVER_PID}" 2>/dev/null || break + sleep 0.5 + done + kill -9 "${SERVER_PID}" 2>/dev/null || true + fi +} +trap cleanup EXIT + +for BACK in "${BACKENDS[@]}"; do + VLLM_DEEP_GEMM_WARMUP=skip \ + VLLM_ALL2ALL_BACKEND=$BACK \ + vllm serve "$MODEL" \ + --enforce-eager \ + --tensor-parallel-size 4 \ + --enable-expert-parallel \ + --enable-eplb \ + --eplb-config '{"window_size":200,"step_interval":600,"use_async":true}' \ + --speculative-config '{"method":"qwen3_next_mtp","num_speculative_tokens":1}' \ + --trust-remote-code \ + --max-model-len 2048 \ + --gpu-memory-utilization 0.9 \ + --port $PORT & + SERVER_PID=$! + wait_for_server $PORT + + TAG=$(echo "$MODEL" | tr '/: \\n' '_____') + OUT="${OUT_DIR}/${TAG}_${BACK}.json" + python3 tests/evals/gsm8k/gsm8k_eval.py --host http://127.0.0.1 --port $PORT --num-questions ${NUM_Q} --save-results ${OUT} + python3 - <<PY +import json; acc=json.load(open('${OUT}'))['accuracy'] +print(f"${MODEL} ${BACK}: accuracy {acc:.3f}") +assert acc >= ${THRESHOLD}, f"${MODEL} ${BACK} accuracy {acc}" +PY + + cleanup + SERVER_PID= + sleep 1 + PORT=$((PORT+1)) +done diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 52c848c784e53..f79e9266559f6 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -1373,4 +1373,22 @@ steps: num_gpus: 2 working_dir: "/vllm-workspace" commands: - - bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 2 1 \ No newline at end of file + - bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 2 1 + +- label: DeepSeek V2-Lite Async EPLB Accuracy + timeout_in_minutes: 60 + gpu: h100 + optional: true + num_gpus: 4 + working_dir: "/vllm-workspace" + commands: + - bash
.buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_async_eplb.sh 0.25 1319 8030 + +- label: Qwen3-Next-80B-A3B-Instruct MTP Async EPLB Accuracy + timeout_in_minutes: 60 + gpu: h100 + optional: true + num_gpus: 4 + working_dir: "/vllm-workspace" + commands: + - bash .buildkite/scripts/scheduled_integration_test/qwen3_next_mtp_async_eplb.sh 0.8 1319 8040 diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py index 376dad8a72ef1..55856d940f001 100644 --- a/vllm/distributed/eplb/rebalance_execute.py +++ b/vllm/distributed/eplb/rebalance_execute.py @@ -322,9 +322,6 @@ async def transfer_layer( num_local_physical_experts = next(iter(expert_weights[0])).shape[0] assert new_global_expert_indices.shape == (num_moe_layers, num_physical_experts) assert num_physical_experts == ep_size * num_local_physical_experts - # A buffer to hold the expert weights in one layer during the exchange. - # NOTE: Currently we assume the same weights across different layers - # have the same shape. 
is_unchanged, is_received_locally, experts_recv_loc = move_to_buffer( num_local_experts=num_local_physical_experts, From a21cd9ed239b853bd587ffe3c9140fe68cd41f59 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Wed, 3 Dec 2025 18:05:10 +0800 Subject: [PATCH 07/11] [Bugfix] Fix incorrect `image_grid_thw` rank for HunyuanOCR from missing `merge_by_field_config=True` (#29950) Signed-off-by: Isotr0py --- .../vision_language_multi_image.py | 23 +++++++++++++++++++ vllm/model_executor/models/hunyuan_vision.py | 1 + 2 files changed, 24 insertions(+) diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py index 2193b1ca9cf48..560ca768d1a6c 100755 --- a/examples/offline_inference/vision_language_multi_image.py +++ b/examples/offline_inference/vision_language_multi_image.py @@ -309,6 +309,28 @@ def load_h2ovl(question: str, image_urls: list[str]) -> ModelRequestData: ) +# HunyuanOCR +def load_hunyuan_vl(question: str, image_urls: list[str]) -> ModelRequestData: + model_name = "tencent/HunyuanOCR" + + engine_args = EngineArgs( + model=model_name, + max_model_len=8192, + limit_mm_per_prompt={"image": len(image_urls)}, + ) + + placeholder = ( + "<|hy_place▁holder▁no▁100|><|hy_place▁holder▁no▁102|><|hy_place▁holder▁no▁101|>" # noqa: E501 + ) * len(image_urls) + prompt = f"<|hy_begin▁of▁sentence|>{placeholder}{question}<|hy_User|>" + + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + image_data=[fetch_image(url) for url in image_urls], + ) + + def load_hyperclovax_seed_vision( question: str, image_urls: list[str] ) -> ModelRequestData: @@ -1322,6 +1344,7 @@ model_example_map = { "deepseek_ocr": load_deepseek_ocr, "gemma3": load_gemma3, "h2ovl_chat": load_h2ovl, + "hunyuan_vl": load_hunyuan_vl, "hyperclovax_seed_vision": load_hyperclovax_seed_vision, "idefics3": load_idefics3, "interns1": load_interns1, diff --git a/vllm/model_executor/models/hunyuan_vision.py 
b/vllm/model_executor/models/hunyuan_vision.py index 2950db571e6ee..6537b6df876a9 100644 --- a/vllm/model_executor/models/hunyuan_vision.py +++ b/vllm/model_executor/models/hunyuan_vision.py @@ -785,6 +785,7 @@ class HunYuanVLForConditionalGeneration( SupportsQuant, SupportsXDRoPE, ): + merge_by_field_config = True multimodal_cpu_fields = {"image_grid_thw"} # To ensure correct weight loading and mapping. From cc4e296ea62226632de5285621fd0cd287621ddc Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Wed, 3 Dec 2025 18:27:36 +0800 Subject: [PATCH 08/11] [CI/Build] Avoid duplicate empty inputs test for common multimodal generation tests (#29907) Signed-off-by: Isotr0py --- .../multimodal/generation/test_common.py | 14 +-- .../generation/vlm_utils/case_filtering.py | 114 +++++++++--------- .../multimodal/generation/vlm_utils/types.py | 4 +- 3 files changed, 69 insertions(+), 63 deletions(-) diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py index deaeea059ccaf..0eaf7198f91b7 100644 --- a/tests/models/multimodal/generation/test_common.py +++ b/tests/models/multimodal/generation/test_common.py @@ -137,7 +137,7 @@ VLM_TEST_SETTINGS = { max_num_seqs=2, auto_cls=AutoModelForImageTextToText, vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output, - image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], + image_size_factors=[(0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], marks=[pytest.mark.core_model, pytest.mark.cpu_model], ), "qwen2_5_omni": VLMTestInfo( @@ -152,7 +152,7 @@ VLM_TEST_SETTINGS = { auto_cls=AutoModelForTextToWaveform, vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output, patch_hf_runner=model_utils.qwen2_5_omni_patch_hf_runner, - image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], + image_size_factors=[(0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], marks=[pytest.mark.core_model, pytest.mark.cpu_model], ), "qwen3_vl": VLMTestInfo( @@ -173,7 +173,7 @@ 
VLM_TEST_SETTINGS = { auto_cls=AutoModelForImageTextToText, vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output, patch_hf_runner=model_utils.qwen3_vl_patch_hf_runner, - image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], + image_size_factors=[(0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], marks=[ pytest.mark.core_model, ], @@ -350,7 +350,7 @@ VLM_TEST_SETTINGS = { patch_hf_runner=model_utils.deepseekvl2_patch_hf_runner, hf_output_post_proc=model_utils.deepseekvl2_trunc_hf_output, stop_str=["<|end▁of▁sentence|>", "<|begin▁of▁sentence|>"], - image_size_factors=[(), (1.0,), (1.0, 1.0, 1.0), (0.1, 0.5, 1.0)], + image_size_factors=[(1.0,), (1.0, 1.0, 1.0), (0.1, 0.5, 1.0)], ), "fuyu": VLMTestInfo( models=["adept/fuyu-8b"], @@ -707,7 +707,7 @@ VLM_TEST_SETTINGS = { max_model_len=8192, max_num_seqs=2, auto_cls=AutoModelForCausalLM, - image_size_factors=[(), (0.25,)], + image_size_factors=[(0.25,)], marks=[ pytest.mark.skipif( Version(TRANSFORMERS_VERSION) == Version("4.57.3"), @@ -760,7 +760,7 @@ VLM_TEST_SETTINGS = { max_num_seqs=2, auto_cls=AutoModelForImageTextToText, vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output, - image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], + image_size_factors=[(0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], marks=[pytest.mark.cpu_model], ), "skywork_r1v": VLMTestInfo( @@ -812,7 +812,7 @@ VLM_TEST_SETTINGS = { max_model_len=4096, max_num_seqs=2, auto_cls=AutoModelForImageTextToText, - image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], + image_size_factors=[(0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], marks=[pytest.mark.skip("Model initialization hangs")], ), ### Tensor parallel / multi-gpu broadcast tests diff --git a/tests/models/multimodal/generation/vlm_utils/case_filtering.py b/tests/models/multimodal/generation/vlm_utils/case_filtering.py index d42150bcbf672..116eead7a70ad 100644 --- a/tests/models/multimodal/generation/vlm_utils/case_filtering.py 
+++ b/tests/models/multimodal/generation/vlm_utils/case_filtering.py @@ -62,6 +62,65 @@ def get_filtered_test_settings( return matching_tests +def get_model_type_cases( + model_type: str, + test_info: VLMTestInfo, + test_type: VLMTestType, +): + # Ensure that something is wrapped as an iterable it's not already + ensure_wrapped = lambda e: e if isinstance(e, (list, tuple)) else (e,) + + # This is essentially the same as nesting a bunch of mark.parametrize + # decorators, but we do it programmatically to allow overrides for on + # a per-model basis, while still being able to execute each of these + # as individual test cases in pytest. + iter_kwargs = OrderedDict( + [ + ("model", ensure_wrapped(test_info.models)), + ("max_tokens", ensure_wrapped(test_info.max_tokens)), + ("num_logprobs", ensure_wrapped(test_info.num_logprobs)), + ("dtype", ensure_wrapped(test_info.dtype)), + ( + "distributed_executor_backend", + ensure_wrapped(test_info.distributed_executor_backend), + ), + ] + ) + + # num_frames is video only + if test_type == VLMTestType.VIDEO: + iter_kwargs["num_video_frames"] = ensure_wrapped(test_info.num_video_frames) + iter_kwargs["needs_video_metadata"] = ensure_wrapped( + test_info.needs_video_metadata + ) + + # No sizes passed for custom inputs, since inputs are directly provided + if test_type not in ( + VLMTestType.CUSTOM_INPUTS, + VLMTestType.AUDIO, + ): + wrapped_sizes = get_wrapped_test_sizes(test_info, test_type) + if wrapped_sizes is None: + raise ValueError(f"Sizes must be set for test type {test_type}") + iter_kwargs["size_wrapper"] = wrapped_sizes + + # Otherwise expand the custom test options instead + elif test_type == VLMTestType.CUSTOM_INPUTS: + if test_info.custom_test_opts is None: + raise ValueError("Test has type CUSTOM_INPUTS, but none given") + iter_kwargs["custom_test_opts"] = test_info.custom_test_opts + + # Wrap all model cases in a pytest parameter & pass marks through + return [ + pytest.param( + model_type, + 
ExpandableVLMTestArgs(**{k: v for k, v in zip(iter_kwargs.keys(), case)}), + marks=test_info.marks if test_info.marks is not None else [], + ) + for case in list(itertools.product(*iter_kwargs.values())) + ] + + def get_parametrized_options( test_settings: dict[str, VLMTestInfo], test_type: VLMTestType, @@ -76,64 +135,11 @@ def get_parametrized_options( test_settings, test_type, create_new_process_for_each_test ) - # Ensure that something is wrapped as an iterable it's not already - ensure_wrapped = lambda e: e if isinstance(e, (list, tuple)) else (e,) - - def get_model_type_cases(model_type: str, test_info: VLMTestInfo): - # This is essentially the same as nesting a bunch of mark.parametrize - # decorators, but we do it programmatically to allow overrides for on - # a per-model basis, while still being able to execute each of these - # as individual test cases in pytest. - iter_kwargs = OrderedDict( - [ - ("model", ensure_wrapped(test_info.models)), - ("max_tokens", ensure_wrapped(test_info.max_tokens)), - ("num_logprobs", ensure_wrapped(test_info.num_logprobs)), - ("dtype", ensure_wrapped(test_info.dtype)), - ( - "distributed_executor_backend", - ensure_wrapped(test_info.distributed_executor_backend), - ), - ] - ) - - # num_frames is video only - if test_type == VLMTestType.VIDEO: - iter_kwargs["num_video_frames"] = ensure_wrapped(test_info.num_video_frames) - iter_kwargs["needs_video_metadata"] = ensure_wrapped( - test_info.needs_video_metadata - ) - - # No sizes passed for custom inputs, since inputs are directly provided - if test_type not in (VLMTestType.CUSTOM_INPUTS, VLMTestType.AUDIO): - wrapped_sizes = get_wrapped_test_sizes(test_info, test_type) - if wrapped_sizes is None: - raise ValueError(f"Sizes must be set for test type {test_type}") - iter_kwargs["size_wrapper"] = wrapped_sizes - - # Otherwise expand the custom test options instead - elif test_type == VLMTestType.CUSTOM_INPUTS: - if test_info.custom_test_opts is None: - raise ValueError("Test has 
type CUSTOM_INPUTS, but none given") - iter_kwargs["custom_test_opts"] = test_info.custom_test_opts - - # Wrap all model cases in a pytest parameter & pass marks through - return [ - pytest.param( - model_type, - ExpandableVLMTestArgs( - **{k: v for k, v in zip(iter_kwargs.keys(), case)} - ), - marks=test_info.marks if test_info.marks is not None else [], - ) - for case in list(itertools.product(*iter_kwargs.values())) - ] - # Get a list per model type, where each entry contains a tuple of all of # that model type's cases, then flatten them into the top level so that # we can consume them in one mark.parametrize call. cases_by_model_type = [ - get_model_type_cases(model_type, test_info) + get_model_type_cases(model_type, test_info, test_type) for model_type, test_info in matching_tests.items() ] return list(itertools.chain(*cases_by_model_type)) diff --git a/tests/models/multimodal/generation/vlm_utils/types.py b/tests/models/multimodal/generation/vlm_utils/types.py index 0c03c84497125..ae2f754813590 100644 --- a/tests/models/multimodal/generation/vlm_utils/types.py +++ b/tests/models/multimodal/generation/vlm_utils/types.py @@ -50,8 +50,8 @@ MULTI_IMAGE_BASE_PROMPT = f"Image-1: {TEST_IMG_PLACEHOLDER}Image-2: {TEST_IMG_PL VIDEO_BASE_PROMPT = f"{TEST_VIDEO_PLACEHOLDER}Why is this video funny?" 
-IMAGE_SIZE_FACTORS = [(), (1.0,), (1.0, 1.0, 1.0), (0.25, 0.5, 1.0)] -EMBEDDING_SIZE_FACTORS = [(), (1.0,), (1.0, 1.0, 1.0)] +IMAGE_SIZE_FACTORS = [(1.0,), (1.0, 1.0, 1.0), (0.25, 0.5, 1.0)] +EMBEDDING_SIZE_FACTORS = [(1.0,), (1.0, 1.0, 1.0)] RunnerOutput = tuple[list[int], str, SampleLogprobs | None] From 42c194964341bea9fc59e0d35db04dfafc3c473d Mon Sep 17 00:00:00 2001 From: Tsukasa OI Date: Wed, 3 Dec 2025 19:33:46 +0900 Subject: [PATCH 09/11] [Bugfix][Quantization] Support BF16 tensors on GGUF (#29948) Signed-off-by: Tsukasa OI --- tests/models/quantization/test_gguf.py | 7 +++++++ vllm/model_executor/model_loader/weight_utils.py | 12 +++++++++++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/tests/models/quantization/test_gguf.py b/tests/models/quantization/test_gguf.py index 3b9597507ac1b..064ca94f3cbac 100644 --- a/tests/models/quantization/test_gguf.py +++ b/tests/models/quantization/test_gguf.py @@ -47,6 +47,12 @@ QWEN2_CONFIG = GGUFTestConfig( gguf_filename="qwen2.5-1.5b-instruct-q6_k.gguf", ) +QWEN3_CONFIG = GGUFTestConfig( + original_model="Qwen/Qwen3-0.6B", + gguf_repo="unsloth/Qwen3-0.6B-GGUF", + gguf_filename="Qwen3-0.6B-BF16.gguf", +) + PHI3_CONFIG = GGUFTestConfig( original_model="microsoft/Phi-3.5-mini-instruct", gguf_repo="bartowski/Phi-3.5-mini-instruct-GGUF", @@ -87,6 +93,7 @@ GEMMA3_CONFIG = GGUFTestConfig( MODELS = [ # LLAMA_CONFIG, # broken: https://github.com/vllm-project/vllm/issues/19458 QWEN2_CONFIG, + QWEN3_CONFIG, PHI3_CONFIG, GPT2_CONFIG, STABLELM_CONFIG, diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 0809bdfa9d4c2..0496b7a84507b 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -921,7 +921,17 @@ def gguf_quant_weights_iterator( name = gguf_to_hf_name_map[tensor.name] if weight_type.name not in ("F32", "BF16", "F16"): name = name.replace("weight", "qweight") - param = 
torch.tensor(weight) + if weight_type.name == "BF16" and tensor.data.dtype == np.uint8: + # BF16 is currently the only "quantization" type that isn't + # actually quantized but is read as a raw byte tensor. + # Reinterpret as `torch.bfloat16` tensor. + weight = weight.view(np.uint16) + if reader.byte_order == "S": + # GGUF endianness != system endianness + weight = weight.byteswap() + param = torch.tensor(weight).view(torch.bfloat16) + else: + param = torch.tensor(weight) yield name, param From 787b84a9fc9d1744f82addf40912e9fb84c0b4c5 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Wed, 3 Dec 2025 02:42:49 -0800 Subject: [PATCH 10/11] [Bugfix] Follow-up fix on MediaWithBytes (#29951) Signed-off-by: Roger Wang --- vllm/multimodal/base.py | 2 ++ vllm/multimodal/inputs.py | 3 ++- vllm/multimodal/parse.py | 2 +- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 4a619fd303ca9..53eb4c591ef99 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -21,6 +21,8 @@ class MediaWithBytes(Generic[_T]): The wrapper delegates attribute access to the underlying media object, making it behave transparently like the wrapped type (e.g., PIL.Image). + + NOTE: Currently, this wrapper is used only for the image modality. """ media: _T diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index f4e38b1f3325f..397684fa2f83c 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -32,6 +32,7 @@ if TYPE_CHECKING: from PIL.Image import Image from transformers.feature_extraction_utils import BatchFeature + from .base import MediaWithBytes from .processing import MultiModalHashes else: @@ -59,7 +60,7 @@ Represents a single audio item, which can be passed to a HuggingFace `AudioProcessor`. 
""" -ImageItem: TypeAlias = Union[HfImageItem, "torch.Tensor"] +ImageItem: TypeAlias = Union[HfImageItem, "torch.Tensor", "MediaWithBytes[HfImageItem]"] """ A `transformers.image_utils.ImageInput` representing a single image item, which can be passed to a HuggingFace `ImageProcessor`. diff --git a/vllm/multimodal/parse.py b/vllm/multimodal/parse.py index 650368dcb8fcd..c3c7cc2c3da0e 100644 --- a/vllm/multimodal/parse.py +++ b/vllm/multimodal/parse.py @@ -484,7 +484,7 @@ class MultiModalDataParser: return ImageEmbeddingItems(data) if ( - isinstance(data, PILImage.Image) + isinstance(data, (PILImage.Image, MediaWithBytes)) or isinstance(data, (np.ndarray, torch.Tensor)) and data.ndim == 3 ): From b294e28db2c5dee61bc25157664edcada8b90b31 Mon Sep 17 00:00:00 2001 From: HDCharles <39544797+HDCharles@users.noreply.github.com> Date: Wed, 3 Dec 2025 06:00:56 -0500 Subject: [PATCH 11/11] [refactor] CTMoEMethods to use QuantizationArgs (#28871) Signed-off-by: HDCharles Signed-off-by: Isotr0py Co-authored-by: Isotr0py --- .../compressed_tensors/compressed_tensors.py | 6 +- .../compressed_tensors_moe.py | 155 +++++++++--------- 2 files changed, 86 insertions(+), 75 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 02086c3c0052d..b91ecb59fee18 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -767,8 +767,10 @@ class CompressedTensorsConfig(QuantizationConfig): targets=self.target_scheme_map.keys(), fused_mapping=self.packed_modules_mapping, ) - - return self.target_scheme_map[matched_target] + scheme_dict = self.target_scheme_map[matched_target] + if scheme_dict.get("format") is None: + scheme_dict["format"] = self.quant_format + return scheme_dict return None diff --git 
a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 80ee443d4dd6a..c7368bf427fe1 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -7,7 +7,11 @@ from enum import Enum import torch from compressed_tensors import CompressionFormat -from compressed_tensors.quantization import ActivationOrdering, QuantizationStrategy +from compressed_tensors.quantization import ( + ActivationOrdering, + QuantizationArgs, + QuantizationStrategy, +) from torch.nn.parameter import Parameter import vllm.envs as envs @@ -142,10 +146,26 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): # are supported + check if the layer is being ignored. weight_quant = scheme_dict.get("weights") input_quant = scheme_dict.get("input_activations") + format = scheme_dict.get("format") if quant_config._is_wNa16_group_channel(weight_quant, input_quant): # group_size=None means channelwise group_size = weight_quant.group_size or -1 + + valid_format_and_bits = ( + weight_quant.num_bits in WNA16_SUPPORTED_BITS + and format == CompressionFormat.pack_quantized.value + ) + + if not valid_format_and_bits: + raise ValueError( + "For Fused MoE layers, only format: ", + f"{CompressionFormat.pack_quantized.value} ", + f" and bits: {WNA16_SUPPORTED_BITS} is supported ", + f"but got format: {CompressionFormat.pack_quantized.value} " + f" and bits: {weight_quant.num_bits}", + ) + # Prefer to use the MarlinMoE kernel when it is supported. 
if ( not check_moe_marlin_supports_layer(layer, group_size) @@ -161,12 +181,12 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): ) logger.info_once("Using CompressedTensorsWNA16MoEMethod") return CompressedTensorsWNA16MoEMethod( - quant_config, layer.moe_config, layer_name + weight_quant, input_quant, layer.moe_config ) else: logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod") return CompressedTensorsWNA16MarlinMoEMethod( - quant_config, layer.moe_config, layer_name + weight_quant, input_quant, layer.moe_config ) elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant): return CompressedTensorsW4A4Nvfp4MoEMethod(layer.moe_config, layer_name) @@ -176,15 +196,15 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): or quant_config._is_fp8_w8a8(weight_quant, input_quant) ): return CompressedTensorsW8A8Fp8MoEMethod( - quant_config, layer.moe_config, layer_name + weight_quant, input_quant, layer.moe_config ) elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant): return CompressedTensorsW8A8Int8MoEMethod( - quant_config, layer.moe_config, layer_name + weight_quant, input_quant, layer.moe_config ) elif quant_config._is_dynamic_token_w4a8_int(weight_quant, input_quant): return CompressedTensorsW4A8Int8MoEMethod( - quant_config, layer.moe_config, layer_name + weight_quant, input_quant, layer.moe_config ) else: raise RuntimeError( @@ -650,17 +670,19 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): def __init__( self, - quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 + weight_quant: QuantizationArgs, + input_quant: QuantizationArgs, moe: FusedMoEConfig, layer_name: str | None = None, ): - super().__init__(moe) - self.quant_config = quant_config - self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights") - self.input_quant = self.quant_config.target_scheme_map["Linear"].get( - "input_activations" + from 
vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 + CompressedTensorsConfig, ) + super().__init__(moe) + self.weight_quant = weight_quant + self.input_quant = input_quant + per_tensor = ( self.weight_quant.strategy == QuantizationStrategy.TENSOR and self.input_quant.strategy == QuantizationStrategy.TENSOR @@ -698,11 +720,13 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled() # cutlass path - self.is_fp8_w8a8_sm100 = quant_config._is_fp8_w8a8_sm100( + self.is_fp8_w8a8_sm100 = CompressedTensorsConfig._is_fp8_w8a8_sm100( self.weight_quant, self.input_quant ) self.use_cutlass = not self.block_quant and ( - quant_config._is_fp8_w8a8_sm90(self.weight_quant, self.input_quant) + CompressedTensorsConfig._is_fp8_w8a8_sm90( + self.weight_quant, self.input_quant + ) or self.is_fp8_w8a8_sm100 ) self.disable_expert_map = False @@ -1261,16 +1285,14 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): def __init__( self, - quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 + weight_quant: QuantizationArgs, + input_quant: QuantizationArgs, moe: FusedMoEConfig, layer_name: str | None = None, ): super().__init__(moe) - self.quant_config = quant_config - self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights") - self.input_quant = self.quant_config.target_scheme_map["Linear"].get( - "input_activations" - ) + self.weight_quant = weight_quant + self.input_quant = input_quant per_channel = ( self.weight_quant.strategy == QuantizationStrategy.CHANNEL @@ -1414,36 +1436,27 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): def __init__( self, - quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 + weight_quant: 
QuantizationArgs, + input_quant: QuantizationArgs | None, moe: FusedMoEConfig, layer_name: str | None = None, ): super().__init__(moe) - self.quant_config = quant_config - # TODO: @dsikka: refactor this to use schemes as other kernels - # are supported + check if the layer is being ignored. - config = self.quant_config.target_scheme_map["Linear"].get("weights") - self.num_bits = config.num_bits - self.packed_factor = 32 // config.num_bits - self.strategy = config.strategy - self.group_size = config.group_size - self.actorder = config.actorder - self.layer_name = layer_name - self.marlin_input_dtype = get_marlin_input_dtype(layer_name) - assert config.symmetric, "Only symmetric quantization is supported for MoE" + self.weight_quant = weight_quant + self.input_quant = input_quant + assert weight_quant.symmetric, ( + "Only symmetric quantization is supported for MoE" + ) + # Extract properties from weight_quant + self.num_bits = weight_quant.num_bits + self.packed_factor = 32 // weight_quant.num_bits + self.strategy = weight_quant.strategy + self.group_size = weight_quant.group_size + self.actorder = weight_quant.actorder - if not ( - self.quant_config.quant_format == CompressionFormat.pack_quantized.value - and self.num_bits in WNA16_SUPPORTED_BITS - ): - raise ValueError( - "For Fused MoE layers, only ", - f"{CompressionFormat.pack_quantized.value} ", - "is supported for the following bits: ", - f"{WNA16_SUPPORTED_BITS}", - ) self.quant_type = WNA16_SUPPORTED_TYPES_MAP[self.num_bits] self.use_marlin = True + self.marlin_input_dtype = get_marlin_input_dtype(layer_name) def create_weights( self, @@ -1812,35 +1825,26 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): def __init__( self, - quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 + weight_quant: QuantizationArgs, + input_quant: QuantizationArgs | None, moe: FusedMoEConfig, layer_name: str | None = None, 
): super().__init__(moe) - self.quant_config = quant_config - # TODO: @dsikka: refactor this to use schemes as other kernels - # are supported + check if the layer is being ignored. - config = self.quant_config.target_scheme_map["Linear"].get("weights") - self.num_bits = config.num_bits - self.packed_factor = 32 // config.num_bits - self.strategy = config.strategy + self.weight_quant = weight_quant + self.input_quant = input_quant + # Extract properties from weight_quant + self.num_bits = weight_quant.num_bits + self.packed_factor = 32 // weight_quant.num_bits + self.strategy = weight_quant.strategy # channelwise is not supported by this kernel - assert config.strategy == "group" - self.group_size = config.group_size + assert weight_quant.strategy == "group" + self.group_size = weight_quant.group_size # grouped actorder isn't supported by this kernel - assert config.actorder != "group" - assert config.symmetric, "Only symmetric quantization is supported for MoE" - - if not ( - self.quant_config.quant_format == CompressionFormat.pack_quantized.value - and self.num_bits in WNA16_SUPPORTED_BITS - ): - raise ValueError( - "For Fused MoE layers, only ", - f"{CompressionFormat.pack_quantized.value} ", - "is supported for the following bits: ", - f"{WNA16_SUPPORTED_BITS}", - ) + assert weight_quant.actorder != "group" + assert weight_quant.symmetric, ( + "Only symmetric quantization is supported for MoE" + ) def create_weights( self, @@ -2065,28 +2069,33 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod): def __init__( self, - quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 + weight_quant: QuantizationArgs, + input_quant: QuantizationArgs, moe: FusedMoEConfig, layer_name: str | None = None, ): super().__init__(moe) self.has_bias = self.moe.has_bias - self.quant_config = quant_config + self.weight_quant = weight_quant + self.input_quant = input_quant # Validate scheme: weights=W4 (channel or group), # activations=dynamic TOKEN (A8) 
- wq = self.quant_config.target_scheme_map["Linear"].get("weights") - aq = self.quant_config.target_scheme_map["Linear"].get("input_activations") # Must be dynamic per-token activations - if aq.strategy != QuantizationStrategy.TOKEN or not aq.dynamic: + if ( + input_quant.strategy != QuantizationStrategy.TOKEN + or not input_quant.dynamic + ): raise ValueError( "W4A8-int MoE needs dynamic per-token activation quantization." ) # Weight can be channel-wise (group_size=None) or group-wise - self.group_size = wq.group_size if (wq.group_size is not None) else -1 - if wq.num_bits != 4: + self.group_size = ( + weight_quant.group_size if (weight_quant.group_size is not None) else -1 + ) + if weight_quant.num_bits != 4: raise ValueError("This method only supports 4-bit weights (num_bits=4).") # CPU only