diff --git a/vllm/config.py b/vllm/config.py index 74d7d9b17ce1b..1f7147f7cfd41 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -246,6 +246,7 @@ class ModelConfig: max_seq_len_to_capture: Optional[int] = None, max_logprobs: int = 20, disable_sliding_window: bool = False, + disable_cascade_attn: bool = False, skip_tokenizer_init: bool = False, served_model_name: Optional[Union[str, list[str]]] = None, limit_mm_per_prompt: Optional[Mapping[str, int]] = None, @@ -322,6 +323,7 @@ class ModelConfig: self.max_seq_len_to_capture = max_seq_len_to_capture self.max_logprobs = max_logprobs self.disable_sliding_window = disable_sliding_window + self.disable_cascade_attn = disable_cascade_attn self.skip_tokenizer_init = skip_tokenizer_init self.enable_sleep_mode = enable_sleep_mode diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 5d22aa51e948f..bbe780a0ec118 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -120,6 +120,7 @@ class EngineArgs: block_size: Optional[int] = None enable_prefix_caching: Optional[bool] = None disable_sliding_window: bool = False + disable_cascade_attn: bool = False use_v2_block_manager: bool = True swap_space: float = 4 # GiB cpu_offload_gb: float = 0 # GiB @@ -1096,6 +1097,16 @@ class EngineArgs: "using. This is used to parse the reasoning content into OpenAI " "API format. Required for ``--enable-reasoning``.") + parser.add_argument( + "--disable-cascade-attn", + action="store_true", + default=False, + help="Disable cascade attention for V1. While cascade attention " + "does not change the mathematical correctness, disabling it " + "could be useful for preventing potential numerical issues. " + "Note that even if this is set to False, cascade attention will be " + "only used when the heuristic tells that it's beneficial.") + return parser @classmethod @@ -1141,6 +1152,7 @@ class EngineArgs: max_seq_len_to_capture=self.max_seq_len_to_capture, max_logprobs=self.max_logprobs, disable_sliding_window=self.disable_sliding_window, + disable_cascade_attn=self.disable_cascade_attn, skip_tokenizer_init=self.skip_tokenizer_init, served_model_name=self.served_model_name, limit_mm_per_prompt=self.limit_mm_per_prompt, diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 53d68b60f2fde..fec6d6112d665 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -251,6 +251,9 @@ class MambaMixer2(CustomOp): "then num_groups must equal 1." ) + assert self.tp_size == 1 or quant_config is None, \ + "Tensor parallel currently not supported for quantized models." + self.ssm_state_size = ssm_state_size self.activation = activation @@ -331,22 +334,24 @@ class MambaMixer2(CustomOp): ], self.tp_size, tp_rank) }) - delattr(self.in_proj.weight, "weight_loader") - set_weight_attrs( - self.in_proj.weight, - { - "weight_loader": - mamba_v2_sharded_weight_loader( - [ - intermediate_settings, # for gate - intermediate_settings, - group_shard_settings, - group_shard_settings, - head_setings, # for dt - ], - self.tp_size, - tp_rank) - }) + if quant_config is None: + # - quant layers do not have a weight loader + delattr(self.in_proj.weight, "weight_loader") + set_weight_attrs( + self.in_proj.weight, + { + "weight_loader": + mamba_v2_sharded_weight_loader( + [ + intermediate_settings, # for gate + intermediate_settings, + group_shard_settings, + group_shard_settings, + head_setings, # for dt + ], + self.tp_size, + tp_rank) + }) # - these are TPed by heads to reduce the size of the # temporal shape diff --git a/vllm/v1/core/sched/interface.py b/vllm/v1/core/sched/interface.py index f42d546ba29b0..bfed44f9d58c8 100644 --- a/vllm/v1/core/sched/interface.py +++ b/vllm/v1/core/sched/interface.py @@ -15,6 +15,28 @@ class SchedulerInterface(ABC): @abstractmethod def schedule(self) -> "SchedulerOutput": + """Schedule the requests to process in this scheduling step. + + The scheduling decision is made at the iteration level. Each scheduling + step corresponds to a single forward pass of the model. Therefore, this + method is called repeatedly by a busy loop in the engine. + + Essentially, the scheduler produces a dictionary of {req_id: num_tokens} + that specifies how many tokens to process for each request in this + scheduling step. For example, num_tokens can be as large as the number + of prompt tokens for new requests, or it can be 1 for the requests that + are auto-regressively generating new tokens one by one. Otherwise, it + can be somewhere in between in case of chunked prefills, prefix caching, + speculative decoding, etc. + + Additionally, the scheduler also returns useful data about each request + or the batch as a whole. The model runner will use this information in + preparing inputs to the model. + + Returns: + A SchedulerOutput object containing information about the scheduled + requests. + """ raise NotImplementedError @abstractmethod @@ -23,10 +45,26 @@ class SchedulerInterface(ABC): scheduler_output: "SchedulerOutput", model_runner_output: "ModelRunnerOutput", ) -> "EngineCoreOutputs": + """Update the scheduler state based on the model runner output. + + This method is called after the model runner has processed the scheduled + requests. The model runner output includes generated token ids, draft + token ids for next step, etc. The scheduler uses this information to + update its states, checks the finished requests, and returns the output + for each request. + + Returns: + A EngineCoreOutputs object containing the outputs for each request. + """ raise NotImplementedError @abstractmethod def add_request(self, request: "Request") -> None: + """Add a new request to the scheduler's internal queue. + + Args: + request: The new request being added. + """ raise NotImplementedError @abstractmethod @@ -35,17 +73,43 @@ class SchedulerInterface(ABC): request_ids: Union[str, Iterable[str]], finished_status: "RequestStatus", ) -> None: + """Finish the requests in the scheduler's internal queue. If the request + is not in the queue, this method will do nothing. + + This method is called in two cases: + 1. When the request is aborted by the client. + 2. When the frontend process detects a stop string of the request after + de-tokenizing its generated tokens. + + Args: + request_ids: A single or a list of request IDs. + finished_status: The finished status of the given requests. + """ raise NotImplementedError @abstractmethod def get_num_unfinished_requests(self) -> int: + """Number of unfinished requests in the scheduler's internal queue.""" raise NotImplementedError def has_unfinished_requests(self) -> bool: + """Returns True if there are unfinished requests in the scheduler's + internal queue.""" return self.get_num_unfinished_requests() > 0 @abstractmethod def has_finished_requests(self) -> bool: + """Returns True if there are finished requests that need to be cleared. + NOTE: This is different from `not self.has_unfinished_requests()`. + + The scheduler maintains an internal list of the requests finished in the + previous step. This list is returned from the next call to schedule(), + to be sent to the model runner in the next step to clear cached states + for these finished requests. + + This method checks if this internal list of finished requests is + non-empty. This information is useful for DP attention. + """ raise NotImplementedError def has_requests(self) -> bool: @@ -60,8 +124,16 @@ class SchedulerInterface(ABC): @abstractmethod def reset_prefix_cache(self) -> bool: + """Reset the prefix cache for KV cache. + + This is particularly required when the model weights are live-updated. + """ raise NotImplementedError @abstractmethod def make_stats(self) -> Optional["SchedulerStats"]: + """Make a SchedulerStats object for logging. + + The SchedulerStats object is created for every scheduling step. + """ raise NotImplementedError diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 37a20186ad2de..b186300a00330 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -127,6 +127,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): self.attn_metadata_builder = self.attn_backend.get_builder_cls()( weakref.proxy(self)) + self.cascade_attn_enabled = not self.model_config.disable_cascade_attn # Multi-modal data support self.input_registry = INPUT_REGISTRY @@ -565,11 +566,14 @@ class GPUModelRunner(LoRAModelRunnerMixin): self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True) - # Prepare for cascade attention if needed. - common_prefix_len = self._compute_cascade_attn_prefix_len( - num_scheduled_tokens, - scheduler_output.num_common_prefix_blocks, - ) + # Prepare for cascade attention if enabled & beneficial. + common_prefix_len = 0 + if self.cascade_attn_enabled: + common_prefix_len = self._compute_cascade_attn_prefix_len( + num_scheduled_tokens, + scheduler_output.num_common_prefix_blocks, + ) + attn_metadata = self.attn_metadata_builder.build( num_reqs=num_reqs, num_actual_tokens=total_num_scheduled_tokens,