[V1] Add flag to disable cascade attention (#15243)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
parent d8e82bc06d
commit 2b22290ce0
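For reference, the new flag can be exercised from the CLI or the Python API. A minimal usage sketch (model name and prompt are illustrative; it assumes the usual forwarding of LLM keyword arguments into EngineArgs and that the V1 engine is selected, e.g. via VLLM_USE_V1=1):

    # CLI (OpenAI-compatible server):
    #   vllm serve meta-llama/Llama-3.1-8B-Instruct --disable-cascade-attn

    # Python API:
    import os
    os.environ["VLLM_USE_V1"] = "1"  # cascade attention is a V1 feature

    from vllm import LLM, SamplingParams

    llm = LLM(
        model="meta-llama/Llama-3.1-8B-Instruct",  # illustrative model
        disable_cascade_attn=True,  # the flag added by this commit
    )
    out = llm.generate(["Hello"], SamplingParams(max_tokens=8))
    print(out[0].outputs[0].text)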
@@ -246,6 +246,7 @@ class ModelConfig:
         max_seq_len_to_capture: Optional[int] = None,
         max_logprobs: int = 20,
         disable_sliding_window: bool = False,
+        disable_cascade_attn: bool = False,
         skip_tokenizer_init: bool = False,
         served_model_name: Optional[Union[str, list[str]]] = None,
         limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
@@ -322,6 +323,7 @@ class ModelConfig:
         self.max_seq_len_to_capture = max_seq_len_to_capture
         self.max_logprobs = max_logprobs
         self.disable_sliding_window = disable_sliding_window
+        self.disable_cascade_attn = disable_cascade_attn
         self.skip_tokenizer_init = skip_tokenizer_init
         self.enable_sleep_mode = enable_sleep_mode
@@ -120,6 +120,7 @@ class EngineArgs:
     block_size: Optional[int] = None
     enable_prefix_caching: Optional[bool] = None
     disable_sliding_window: bool = False
+    disable_cascade_attn: bool = False
     use_v2_block_manager: bool = True
     swap_space: float = 4  # GiB
     cpu_offload_gb: float = 0  # GiB
@@ -1096,6 +1097,16 @@ class EngineArgs:
             "using. This is used to parse the reasoning content into OpenAI "
             "API format. Required for ``--enable-reasoning``.")

+        parser.add_argument(
+            "--disable-cascade-attn",
+            action="store_true",
+            default=False,
+            help="Disable cascade attention for V1. While cascade attention "
+            "does not change the mathematical correctness, disabling it "
+            "could be useful for preventing potential numerical issues. "
+            "Note that even if this is set to False, cascade attention will be "
+            "only used when the heuristic tells that it's beneficial.")
+
         return parser

     @classmethod
@@ -1141,6 +1152,7 @@ class EngineArgs:
             max_seq_len_to_capture=self.max_seq_len_to_capture,
             max_logprobs=self.max_logprobs,
             disable_sliding_window=self.disable_sliding_window,
+            disable_cascade_attn=self.disable_cascade_attn,
             skip_tokenizer_init=self.skip_tokenizer_init,
             served_model_name=self.served_model_name,
             limit_mm_per_prompt=self.limit_mm_per_prompt,
@@ -127,6 +127,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):

         self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
             weakref.proxy(self))
+        self.cascade_attn_enabled = not self.model_config.disable_cascade_attn

         # Multi-modal data support
         self.input_registry = INPUT_REGISTRY
@@ -565,11 +566,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             self.positions_cpu[:total_num_scheduled_tokens],
             non_blocking=True)

-        # Prepare for cascade attention if needed.
-        common_prefix_len = self._compute_cascade_attn_prefix_len(
-            num_scheduled_tokens,
-            scheduler_output.num_common_prefix_blocks,
-        )
+        # Prepare for cascade attention if enabled & beneficial.
+        common_prefix_len = 0
+        if self.cascade_attn_enabled:
+            common_prefix_len = self._compute_cascade_attn_prefix_len(
+                num_scheduled_tokens,
+                scheduler_output.num_common_prefix_blocks,
+            )

         attn_metadata = self.attn_metadata_builder.build(
             num_reqs=num_reqs,
             num_actual_tokens=total_num_scheduled_tokens,
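Taken together with the flag's help text, the runner hunks mean the cascade path is used only when the flag is left unset and the existing heuristic still reports a benefit. A small illustrative sketch of that gating (the helper name effective_common_prefix_len is hypothetical, not part of the commit):

    def effective_common_prefix_len(disable_cascade_attn: bool,
                                    heuristic_prefix_len: int) -> int:
        # The flag wins first: forcing 0 keeps the runner on the regular
        # (non-cascade) attention path regardless of the heuristic.
        if disable_cascade_attn:
            return 0
        # Otherwise the heuristic decides; it returns 0 when cascading
        # would not help, so leaving the flag at its default changes nothing.
        return heuristic_prefix_len

    assert effective_common_prefix_len(True, 128) == 0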