Merge branch 'main' into v1-sched-interface-2

This commit is contained in:
Woosuk Kwon 2025-03-20 17:56:13 -07:00
commit 8db54c7912
5 changed files with 116 additions and 21 deletions

View File

@ -246,6 +246,7 @@ class ModelConfig:
max_seq_len_to_capture: Optional[int] = None,
max_logprobs: int = 20,
disable_sliding_window: bool = False,
disable_cascade_attn: bool = False,
skip_tokenizer_init: bool = False,
served_model_name: Optional[Union[str, list[str]]] = None,
limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
@ -322,6 +323,7 @@ class ModelConfig:
self.max_seq_len_to_capture = max_seq_len_to_capture
self.max_logprobs = max_logprobs
self.disable_sliding_window = disable_sliding_window
self.disable_cascade_attn = disable_cascade_attn
self.skip_tokenizer_init = skip_tokenizer_init
self.enable_sleep_mode = enable_sleep_mode

View File

@ -120,6 +120,7 @@ class EngineArgs:
block_size: Optional[int] = None
enable_prefix_caching: Optional[bool] = None
disable_sliding_window: bool = False
disable_cascade_attn: bool = False
use_v2_block_manager: bool = True
swap_space: float = 4 # GiB
cpu_offload_gb: float = 0 # GiB
@ -1096,6 +1097,16 @@ class EngineArgs:
"using. This is used to parse the reasoning content into OpenAI "
"API format. Required for ``--enable-reasoning``.")
parser.add_argument(
"--disable-cascade-attn",
action="store_true",
default=False,
help="Disable cascade attention for V1. While cascade attention "
"does not change the mathematical correctness, disabling it "
"could be useful for preventing potential numerical issues. "
"Note that even if this is set to False, cascade attention will be "
"only used when the heuristic tells that it's beneficial.")
return parser
@classmethod
@ -1141,6 +1152,7 @@ class EngineArgs:
max_seq_len_to_capture=self.max_seq_len_to_capture,
max_logprobs=self.max_logprobs,
disable_sliding_window=self.disable_sliding_window,
disable_cascade_attn=self.disable_cascade_attn,
skip_tokenizer_init=self.skip_tokenizer_init,
served_model_name=self.served_model_name,
limit_mm_per_prompt=self.limit_mm_per_prompt,

View File

@ -251,6 +251,9 @@ class MambaMixer2(CustomOp):
"then num_groups must equal 1."
)
assert self.tp_size == 1 or quant_config is None, \
"Tensor parallel currently not supported for quantized models."
self.ssm_state_size = ssm_state_size
self.activation = activation
@ -331,22 +334,24 @@ class MambaMixer2(CustomOp):
], self.tp_size, tp_rank)
})
delattr(self.in_proj.weight, "weight_loader")
set_weight_attrs(
self.in_proj.weight,
{
"weight_loader":
mamba_v2_sharded_weight_loader(
[
intermediate_settings, # for gate
intermediate_settings,
group_shard_settings,
group_shard_settings,
head_setings, # for dt
],
self.tp_size,
tp_rank)
})
if quant_config is None:
# - quant layers do not have a weight loader
delattr(self.in_proj.weight, "weight_loader")
set_weight_attrs(
self.in_proj.weight,
{
"weight_loader":
mamba_v2_sharded_weight_loader(
[
intermediate_settings, # for gate
intermediate_settings,
group_shard_settings,
group_shard_settings,
head_setings, # for dt
],
self.tp_size,
tp_rank)
})
# - these are TPed by heads to reduce the size of the
# temporal shape

View File

@ -15,6 +15,28 @@ class SchedulerInterface(ABC):
@abstractmethod
def schedule(self) -> "SchedulerOutput":
"""Schedule the requests to process in this scheduling step.
The scheduling decision is made at the iteration level. Each scheduling
step corresponds to a single forward pass of the model. Therefore, this
method is called repeatedly by a busy loop in the engine.
Essentially, the scheduler produces a dictionary of {req_id: num_tokens}
that specifies how many tokens to process for each request in this
scheduling step. For example, num_tokens can be as large as the number
of prompt tokens for new requests, or it can be 1 for the requests that
are auto-regressively generating new tokens one by one. Otherwise, it
can be somewhere in between in case of chunked prefills, prefix caching,
speculative decoding, etc.
Additionally, the scheduler also returns useful data about each request
or the batch as a whole. The model runner will use this information in
preparing inputs to the model.
Returns:
A SchedulerOutput object containing information about the scheduled
requests.
"""
raise NotImplementedError
@abstractmethod
@ -23,10 +45,26 @@ class SchedulerInterface(ABC):
scheduler_output: "SchedulerOutput",
model_runner_output: "ModelRunnerOutput",
) -> "EngineCoreOutputs":
"""Update the scheduler state based on the model runner output.
This method is called after the model runner has processed the scheduled
requests. The model runner output includes generated token ids, draft
token ids for next step, etc. The scheduler uses this information to
update its states, checks the finished requests, and returns the output
for each request.
Returns:
A EngineCoreOutputs object containing the outputs for each request.
"""
raise NotImplementedError
@abstractmethod
def add_request(self, request: "Request") -> None:
"""Add a new request to the scheduler's internal queue.
Args:
request: The new request being added.
"""
raise NotImplementedError
@abstractmethod
@ -35,17 +73,43 @@ class SchedulerInterface(ABC):
request_ids: Union[str, Iterable[str]],
finished_status: "RequestStatus",
) -> None:
"""Finish the requests in the scheduler's internal queue. If the request
is not in the queue, this method will do nothing.
This method is called in two cases:
1. When the request is aborted by the client.
2. When the frontend process detects a stop string of the request after
de-tokenizing its generated tokens.
Args:
request_ids: A single or a list of request IDs.
finished_status: The finished status of the given requests.
"""
raise NotImplementedError
@abstractmethod
def get_num_unfinished_requests(self) -> int:
"""Number of unfinished requests in the scheduler's internal queue."""
raise NotImplementedError
def has_unfinished_requests(self) -> bool:
"""Returns True if there are unfinished requests in the scheduler's
internal queue."""
return self.get_num_unfinished_requests() > 0
@abstractmethod
def has_finished_requests(self) -> bool:
"""Returns True if there are finished requests that need to be cleared.
NOTE: This is different from `not self.has_unfinished_requests()`.
The scheduler maintains an internal list of the requests finished in the
previous step. This list is returned from the next call to schedule(),
to be sent to the model runner in the next step to clear cached states
for these finished requests.
This method checks if this internal list of finished requests is
non-empty. This information is useful for DP attention.
"""
raise NotImplementedError
def has_requests(self) -> bool:
@ -60,8 +124,16 @@ class SchedulerInterface(ABC):
@abstractmethod
def reset_prefix_cache(self) -> bool:
"""Reset the prefix cache for KV cache.
This is particularly required when the model weights are live-updated.
"""
raise NotImplementedError
@abstractmethod
def make_stats(self) -> Optional["SchedulerStats"]:
"""Make a SchedulerStats object for logging.
The SchedulerStats object is created for every scheduling step.
"""
raise NotImplementedError

View File

@ -127,6 +127,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
weakref.proxy(self))
self.cascade_attn_enabled = not self.model_config.disable_cascade_attn
# Multi-modal data support
self.input_registry = INPUT_REGISTRY
@ -565,11 +566,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.positions_cpu[:total_num_scheduled_tokens],
non_blocking=True)
# Prepare for cascade attention if needed.
common_prefix_len = self._compute_cascade_attn_prefix_len(
num_scheduled_tokens,
scheduler_output.num_common_prefix_blocks,
)
# Prepare for cascade attention if enabled & beneficial.
common_prefix_len = 0
if self.cascade_attn_enabled:
common_prefix_len = self._compute_cascade_attn_prefix_len(
num_scheduled_tokens,
scheduler_output.num_common_prefix_blocks,
)
attn_metadata = self.attn_metadata_builder.build(
num_reqs=num_reqs,
num_actual_tokens=total_num_scheduled_tokens,