From 2b22290ce01b033cc692e7dce159d74a43f6f2c5 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 20 Mar 2025 15:24:16 -0700 Subject: [PATCH 1/3] [V1] Add flag to disable cascade attention (#15243) Signed-off-by: Woosuk Kwon --- vllm/config.py | 2 ++ vllm/engine/arg_utils.py | 12 ++++++++++++ vllm/v1/worker/gpu_model_runner.py | 14 +++++++++----- 3 files changed, 23 insertions(+), 5 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 74d7d9b17ce1b..1f7147f7cfd41 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -246,6 +246,7 @@ class ModelConfig: max_seq_len_to_capture: Optional[int] = None, max_logprobs: int = 20, disable_sliding_window: bool = False, + disable_cascade_attn: bool = False, skip_tokenizer_init: bool = False, served_model_name: Optional[Union[str, list[str]]] = None, limit_mm_per_prompt: Optional[Mapping[str, int]] = None, @@ -322,6 +323,7 @@ class ModelConfig: self.max_seq_len_to_capture = max_seq_len_to_capture self.max_logprobs = max_logprobs self.disable_sliding_window = disable_sliding_window + self.disable_cascade_attn = disable_cascade_attn self.skip_tokenizer_init = skip_tokenizer_init self.enable_sleep_mode = enable_sleep_mode diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 43bf2fe8f0932..5015f1d684b76 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -120,6 +120,7 @@ class EngineArgs: block_size: Optional[int] = None enable_prefix_caching: Optional[bool] = None disable_sliding_window: bool = False + disable_cascade_attn: bool = False use_v2_block_manager: bool = True swap_space: float = 4 # GiB cpu_offload_gb: float = 0 # GiB @@ -1096,6 +1097,16 @@ class EngineArgs: "using. This is used to parse the reasoning content into OpenAI " "API format. Required for ``--enable-reasoning``.") + parser.add_argument( + "--disable-cascade-attn", + action="store_true", + default=False, + help="Disable cascade attention for V1. While cascade attention " + "does not change the mathematical correctness, disabling it " + "could be useful for preventing potential numerical issues. " + "Note that even if this is set to False, cascade attention will be " + "only used when the heuristic tells that it's beneficial.") + return parser @classmethod @@ -1141,6 +1152,7 @@ class EngineArgs: max_seq_len_to_capture=self.max_seq_len_to_capture, max_logprobs=self.max_logprobs, disable_sliding_window=self.disable_sliding_window, + disable_cascade_attn=self.disable_cascade_attn, skip_tokenizer_init=self.skip_tokenizer_init, served_model_name=self.served_model_name, limit_mm_per_prompt=self.limit_mm_per_prompt, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 7faf666dc61c2..c82bcec25d245 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -127,6 +127,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): self.attn_metadata_builder = self.attn_backend.get_builder_cls()( weakref.proxy(self)) + self.cascade_attn_enabled = not self.model_config.disable_cascade_attn # Multi-modal data support self.input_registry = INPUT_REGISTRY @@ -565,11 +566,14 @@ class GPUModelRunner(LoRAModelRunnerMixin): self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True) - # Prepare for cascade attention if needed. - common_prefix_len = self._compute_cascade_attn_prefix_len( - num_scheduled_tokens, - scheduler_output.num_common_prefix_blocks, - ) + # Prepare for cascade attention if enabled & beneficial. 
+ common_prefix_len = 0 + if self.cascade_attn_enabled: + common_prefix_len = self._compute_cascade_attn_prefix_len( + num_scheduled_tokens, + scheduler_output.num_common_prefix_blocks, + ) + attn_metadata = self.attn_metadata_builder.build( num_reqs=num_reqs, num_actual_tokens=total_num_scheduled_tokens, From 06dd08256f076689945418cd61397c1759f4abfa Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Fri, 21 Mar 2025 08:44:37 +0800 Subject: [PATCH 2/3] Enforce that TP > 1 is not supported for Mamba2 if Quantization is Enabled. (#14617) Signed-off-by: Yu Chin Fabian Lim --- .../layers/mamba/mamba_mixer2.py | 37 +++++++++++-------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 53d68b60f2fde..fec6d6112d665 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -251,6 +251,9 @@ class MambaMixer2(CustomOp): "then num_groups must equal 1." ) + assert self.tp_size == 1 or quant_config is None, \ + "Tensor parallel currently not supported for quantized models." + self.ssm_state_size = ssm_state_size self.activation = activation @@ -331,22 +334,24 @@ class MambaMixer2(CustomOp): ], self.tp_size, tp_rank) }) - delattr(self.in_proj.weight, "weight_loader") - set_weight_attrs( - self.in_proj.weight, - { - "weight_loader": - mamba_v2_sharded_weight_loader( - [ - intermediate_settings, # for gate - intermediate_settings, - group_shard_settings, - group_shard_settings, - head_setings, # for dt - ], - self.tp_size, - tp_rank) - }) + if quant_config is None: + # - quant layers do not have a weight loader + delattr(self.in_proj.weight, "weight_loader") + set_weight_attrs( + self.in_proj.weight, + { + "weight_loader": + mamba_v2_sharded_weight_loader( + [ + intermediate_settings, # for gate + intermediate_settings, + group_shard_settings, + group_shard_settings, + head_setings, # for dt + ], + self.tp_size, + tp_rank) + }) # - these are TPed by heads to reduce the size of the # temporal shape From 0c6f5023c390075e842bb7c70bb8f5aa433c584c Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 20 Mar 2025 17:50:43 -0700 Subject: [PATCH 3/3] [V1] Scheduler Refactoring [1/N] - Add Scheduler Interface (#15250) Signed-off-by: Woosuk Kwon Co-authored-by: Cody Yu Co-authored-by: Nick Hill --- tests/plugins_tests/test_scheduler_plugins.py | 2 +- tests/v1/core/test_scheduler.py | 3 +- tests/v1/worker/test_gpu_model_runner.py | 4 +- vllm/engine/arg_utils.py | 2 +- vllm/executor/ray_utils.py | 2 +- vllm/v1/attention/backends/flash_attn.py | 2 +- vllm/v1/attention/backends/mla/common.py | 2 +- vllm/v1/core/sched/__init__.py | 0 vllm/v1/core/sched/interface.py | 139 ++++++++++++++++++ .../{scheduler_output.py => sched/output.py} | 0 vllm/v1/core/{ => sched}/scheduler.py | 37 +---- vllm/v1/core/sched/utils.py | 22 +++ vllm/v1/engine/core.py | 4 +- vllm/v1/worker/gpu_model_runner.py | 2 +- vllm/v1/worker/gpu_worker.py | 2 +- vllm/v1/worker/tpu_model_runner.py | 2 +- vllm/v1/worker/tpu_worker.py | 2 +- 17 files changed, 182 insertions(+), 45 deletions(-) create mode 100644 vllm/v1/core/sched/__init__.py create mode 100644 vllm/v1/core/sched/interface.py rename vllm/v1/core/{scheduler_output.py => sched/output.py} (100%) rename vllm/v1/core/{ => sched}/scheduler.py (96%) create mode 100644 vllm/v1/core/sched/utils.py diff --git a/tests/plugins_tests/test_scheduler_plugins.py b/tests/plugins_tests/test_scheduler_plugins.py index 
7abf5066a4133..4c95a52a967bd 100644 --- a/tests/plugins_tests/test_scheduler_plugins.py +++ b/tests/plugins_tests/test_scheduler_plugins.py @@ -6,7 +6,7 @@ from vllm.core.scheduler import Scheduler from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine from vllm.sampling_params import SamplingParams -from vllm.v1.core.scheduler import Scheduler as V1Scheduler +from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 9413373390fe2..8916aa580000a 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -6,7 +6,8 @@ import pytest from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.sampling_params import SamplingParams -from vllm.v1.core.scheduler import Scheduler, SchedulerOutput +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus from vllm.v1.structured_output import StructuredOutputManager diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 345519a07e411..dd95a7f53064e 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -3,8 +3,8 @@ import pytest from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig from vllm.sampling_params import SamplingParams -from vllm.v1.core.scheduler_output import (CachedRequestData, NewRequestData, - SchedulerOutput) +from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, + SchedulerOutput) from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.worker.gpu_model_runner import GPUModelRunner diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 5015f1d684b76..bbe780a0ec118 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1695,7 +1695,7 @@ class EngineArgs: # V1 should use the new scheduler by default. # Swap it only if this arg is set to the original V0 default if self.scheduler_cls == EngineArgs.scheduler_cls: - self.scheduler_cls = "vllm.v1.core.scheduler.Scheduler" + self.scheduler_cls = "vllm.v1.core.sched.scheduler.Scheduler" # When no user override, set the default values based on the usage # context. 
diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index c1bf2fb316d9b..a7042ca8df17c 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -17,7 +17,7 @@ from vllm.utils import get_ip from vllm.worker.worker_base import WorkerWrapperBase if TYPE_CHECKING: - from vllm.v1.core.scheduler import SchedulerOutput + from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.outputs import ModelRunnerOutput logger = init_logger(__name__) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 637c01556ac1c..27b3aabbc3504 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -17,7 +17,7 @@ from vllm.platforms import current_platform from vllm.utils import cdiv if TYPE_CHECKING: - from vllm.v1.core.scheduler_output import SchedulerOutput + from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.worker.gpu_input_batch import InputBatch from vllm.v1.worker.gpu_model_runner import GPUModelRunner diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index f801745ab5c7d..188a425b107e4 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -212,7 +212,7 @@ except ImportError: from flash_attn import flash_attn_varlen_func if TYPE_CHECKING: - from vllm.v1.core.scheduler_output import SchedulerOutput + from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.worker.gpu_input_batch import InputBatch from vllm.v1.worker.gpu_model_runner import GPUModelRunner diff --git a/vllm/v1/core/sched/__init__.py b/vllm/v1/core/sched/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/v1/core/sched/interface.py b/vllm/v1/core/sched/interface.py new file mode 100644 index 0000000000000..bfed44f9d58c8 --- /dev/null +++ b/vllm/v1/core/sched/interface.py @@ -0,0 +1,139 @@ +# SPDX-License-Identifier: Apache-2.0 +from abc import ABC, abstractmethod +from collections.abc import Iterable +from typing import TYPE_CHECKING, Optional, Union + +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.engine import EngineCoreOutputs + from vllm.v1.metrics.stats import SchedulerStats + from vllm.v1.outputs import ModelRunnerOutput + from vllm.v1.request import Request, RequestStatus + + +class SchedulerInterface(ABC): + + @abstractmethod + def schedule(self) -> "SchedulerOutput": + """Schedule the requests to process in this scheduling step. + + The scheduling decision is made at the iteration level. Each scheduling + step corresponds to a single forward pass of the model. Therefore, this + method is called repeatedly by a busy loop in the engine. + + Essentially, the scheduler produces a dictionary of {req_id: num_tokens} + that specifies how many tokens to process for each request in this + scheduling step. For example, num_tokens can be as large as the number + of prompt tokens for new requests, or it can be 1 for the requests that + are auto-regressively generating new tokens one by one. Otherwise, it + can be somewhere in between in case of chunked prefills, prefix caching, + speculative decoding, etc. + + Additionally, the scheduler also returns useful data about each request + or the batch as a whole. The model runner will use this information in + preparing inputs to the model. + + Returns: + A SchedulerOutput object containing information about the scheduled + requests. 
+ """ + raise NotImplementedError + + @abstractmethod + def update_from_output( + self, + scheduler_output: "SchedulerOutput", + model_runner_output: "ModelRunnerOutput", + ) -> "EngineCoreOutputs": + """Update the scheduler state based on the model runner output. + + This method is called after the model runner has processed the scheduled + requests. The model runner output includes generated token ids, draft + token ids for next step, etc. The scheduler uses this information to + update its states, checks the finished requests, and returns the output + for each request. + + Returns: + A EngineCoreOutputs object containing the outputs for each request. + """ + raise NotImplementedError + + @abstractmethod + def add_request(self, request: "Request") -> None: + """Add a new request to the scheduler's internal queue. + + Args: + request: The new request being added. + """ + raise NotImplementedError + + @abstractmethod + def finish_requests( + self, + request_ids: Union[str, Iterable[str]], + finished_status: "RequestStatus", + ) -> None: + """Finish the requests in the scheduler's internal queue. If the request + is not in the queue, this method will do nothing. + + This method is called in two cases: + 1. When the request is aborted by the client. + 2. When the frontend process detects a stop string of the request after + de-tokenizing its generated tokens. + + Args: + request_ids: A single or a list of request IDs. + finished_status: The finished status of the given requests. + """ + raise NotImplementedError + + @abstractmethod + def get_num_unfinished_requests(self) -> int: + """Number of unfinished requests in the scheduler's internal queue.""" + raise NotImplementedError + + def has_unfinished_requests(self) -> bool: + """Returns True if there are unfinished requests in the scheduler's + internal queue.""" + return self.get_num_unfinished_requests() > 0 + + @abstractmethod + def has_finished_requests(self) -> bool: + """Returns True if there are finished requests that need to be cleared. + NOTE: This is different from `not self.has_unfinished_requests()`. + + The scheduler maintains an internal list of the requests finished in the + previous step. This list is returned from the next call to schedule(), + to be sent to the model runner in the next step to clear cached states + for these finished requests. + + This method checks if this internal list of finished requests is + non-empty. This information is useful for DP attention. + """ + raise NotImplementedError + + def has_requests(self) -> bool: + """Returns True if there are unfinished requests, or finished requests + not yet returned in SchedulerOutputs.""" + return self.has_unfinished_requests() or self.has_finished_requests() + + @abstractmethod + def get_num_unscheduled_requests(self) -> int: + """Number of requests that are not being processed by the executor.""" + raise NotImplementedError + + @abstractmethod + def reset_prefix_cache(self) -> bool: + """Reset the prefix cache for KV cache. + + This is particularly required when the model weights are live-updated. + """ + raise NotImplementedError + + @abstractmethod + def make_stats(self) -> Optional["SchedulerStats"]: + """Make a SchedulerStats object for logging. + + The SchedulerStats object is created for every scheduling step. 
+ """ + raise NotImplementedError diff --git a/vllm/v1/core/scheduler_output.py b/vllm/v1/core/sched/output.py similarity index 100% rename from vllm/v1/core/scheduler_output.py rename to vllm/v1/core/sched/output.py diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/sched/scheduler.py similarity index 96% rename from vllm/v1/core/scheduler.py rename to vllm/v1/core/sched/scheduler.py index 056458ef9dd28..d002a19b08a41 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -13,8 +13,10 @@ from vllm.logger import init_logger from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, compute_encoder_budget) from vllm.v1.core.kv_cache_manager import KVCacheManager -from vllm.v1.core.scheduler_output import (CachedRequestData, NewRequestData, - SchedulerOutput) +from vllm.v1.core.sched.interface import SchedulerInterface +from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, + SchedulerOutput) +from vllm.v1.core.sched.utils import check_stop from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs) from vllm.v1.metrics.stats import SchedulerStats @@ -25,7 +27,7 @@ from vllm.v1.structured_output import StructuredOutputManager logger = init_logger(__name__) -class Scheduler: +class Scheduler(SchedulerInterface): def __init__( self, @@ -602,7 +604,7 @@ class Scheduler: # Check for stop and update request state. # This must be called before we make the EngineCoreOutput. - stopped = self._check_stop(request) + stopped = check_stop(request, self.max_model_len) if stopped: self._free_request(request) break @@ -648,25 +650,6 @@ class Scheduler: scheduler_stats=self.make_stats(), ) - def _check_stop(self, request: Request) -> bool: - if (request.num_tokens >= self.max_model_len - or request.num_output_tokens >= request.max_tokens): - request.status = RequestStatus.FINISHED_LENGTH_CAPPED - return True - - sampling_params = request.sampling_params - last_token_id = request.output_token_ids[-1] - if (not sampling_params.ignore_eos - and last_token_id == request.eos_token_id): - request.status = RequestStatus.FINISHED_STOPPED - return True - - if last_token_id in (sampling_params.stop_token_ids or ()): - request.status = RequestStatus.FINISHED_STOPPED - request.stop_reason = last_token_id - return True - return False - def add_request(self, request: Request) -> None: self.waiting.append(request) self.requests[request.request_id] = request @@ -715,17 +698,9 @@ class Scheduler: def get_num_unfinished_requests(self) -> int: return len(self.waiting) + len(self.running) - def has_unfinished_requests(self) -> bool: - return self.get_num_unfinished_requests() > 0 - def has_finished_requests(self) -> bool: return len(self.finished_req_ids) > 0 - def has_requests(self): - """Returns True if there are unfinished requests, or finished requests - not yet returned in SchedulerOutputs.""" - return self.has_unfinished_requests() or self.has_finished_requests() - def get_num_unscheduled_requests(self) -> int: """Number of requests that are not being processed by the executor.""" return self.get_num_unfinished_requests() - len(self.scheduled_req_ids) diff --git a/vllm/v1/core/sched/utils.py b/vllm/v1/core/sched/utils.py new file mode 100644 index 0000000000000..3a0028a59016e --- /dev/null +++ b/vllm/v1/core/sched/utils.py @@ -0,0 +1,22 @@ +# SPDX-License-Identifier: Apache-2.0 +from vllm.v1.request import Request, RequestStatus + + +def check_stop(request: Request, max_model_len: int) -> bool: + if (request.num_tokens >= max_model_len + 
or request.num_output_tokens >= request.max_tokens): + request.status = RequestStatus.FINISHED_LENGTH_CAPPED + return True + + sampling_params = request.sampling_params + last_token_id = request.output_token_ids[-1] + if (not sampling_params.ignore_eos + and last_token_id == request.eos_token_id): + request.status = RequestStatus.FINISHED_STOPPED + return True + + if last_token_id in (sampling_params.stop_token_ids or ()): + request.status = RequestStatus.FINISHED_STOPPED + request.stop_reason = last_token_id + return True + return False diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index b0c18aee97c28..1598e6b8443fe 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -22,8 +22,8 @@ from vllm.transformers_utils.config import ( from vllm.utils import (get_exception_traceback, resolve_obj_by_qualname, zmq_socket_ctx) from vllm.v1.core.kv_cache_utils import get_kv_cache_configs -from vllm.v1.core.scheduler import Scheduler as V1Scheduler -from vllm.v1.core.scheduler import SchedulerOutput +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, EngineCoreRequestType, UtilityOutput) from vllm.v1.engine.mm_input_cache import MMInputCacheServer diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index c82bcec25d245..b186300a00330 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -45,7 +45,7 @@ from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin if TYPE_CHECKING: import xgrammar as xgr - from vllm.v1.core.scheduler_output import SchedulerOutput + from vllm.v1.core.sched.output import SchedulerOutput else: xgr = LazyLoader("xgr", globals(), "xgrammar") diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 241869e35c620..a63a2d022378e 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -28,7 +28,7 @@ from vllm.v1.worker.worker_base import WorkerBase logger = init_logger(__name__) if TYPE_CHECKING: - from vllm.v1.core.scheduler_output import SchedulerOutput + from vllm.v1.core.sched.output import SchedulerOutput class Worker(WorkerBase): diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index b7924752aec8d..ec3dcbc064cba 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -37,7 +37,7 @@ from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch if TYPE_CHECKING: - from vllm.v1.core.scheduler import SchedulerOutput + from vllm.v1.core.sched.output import SchedulerOutput logger = init_logger(__name__) diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index 9f59561192753..dbb231950d08d 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -17,7 +17,7 @@ from vllm.distributed import (ensure_model_parallel_initialized, from vllm.logger import init_logger from vllm.model_executor import set_random_seed from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE -from vllm.v1.core.scheduler import SchedulerOutput +from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheSpec) from vllm.v1.outputs import ModelRunnerOutput
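
Usage sketch for PATCH 1/3: the new EngineArgs field is exposed on the CLI as --disable-cascade-attn and flows into ModelConfig.disable_cascade_attn, which GPUModelRunner reads into self.cascade_attn_enabled. The snippet below is a minimal sketch, assuming the offline LLM entrypoint forwards engine keyword arguments to EngineArgs as it does for the other fields; the model name is only a placeholder.

    from vllm import LLM, SamplingParams

    # Force the V1 model runner to skip the cascade-attention path entirely.
    llm = LLM(
        model="facebook/opt-125m",      # placeholder model
        disable_cascade_attn=True,      # new EngineArgs/ModelConfig field
    )
    out = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=16))
    print(out[0].outputs[0].text)

Leaving the flag unset keeps the previous behavior: cascade attention is still applied only when _compute_cascade_attn_prefix_len decides the shared prefix is long enough to be beneficial.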
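
For PATCH 2/3, the new assertion in MambaMixer2 makes the unsupported combination fail fast at layer construction instead of loading incorrectly sharded weights. A sketch of a configuration that would trip it; the checkpoint name and quantization method are placeholders for any Mamba2-based model whose layers receive a quant_config.

    from vllm import LLM

    llm = LLM(
        model="ibm-ai-platform/Bamba-9B",  # placeholder Mamba2-style model
        quantization="fp8",                # any method that sets quant_config
        tensor_parallel_size=2,            # TP > 1
    )
    # AssertionError: Tensor parallel currently not supported for quantized models.

Quantized single-GPU runs and unquantized TP > 1 runs are unaffected; only quantized + TP > 1 is rejected, and the sharded weight_loader override is likewise skipped when quant_config is set.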
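
For PATCH 3/3, the point of SchedulerInterface is that --scheduler-cls (whose V1 default now resolves to vllm.v1.core.sched.scheduler.Scheduler) can point at any implementation of the interface. A minimal sketch, assuming you subclass the concrete Scheduler rather than re-implementing every abstract method; the module path in the trailing comment is hypothetical.

    from vllm.v1.core.sched.output import SchedulerOutput
    from vllm.v1.core.sched.scheduler import Scheduler


    class LoggingScheduler(Scheduler):
        """Default V1 scheduler that also reports what each step scheduled."""

        def schedule(self) -> SchedulerOutput:
            output = super().schedule()
            # num_scheduled_tokens is the {req_id: num_tokens} mapping that
            # SchedulerInterface.schedule() documents.
            print(f"step scheduled {len(output.num_scheduled_tokens)} request(s), "
                  f"{output.total_num_scheduled_tokens} token(s)")
            return output

    # e.g. EngineArgs(scheduler_cls="my_pkg.logging_scheduler.LoggingScheduler")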