[V1][BugFix] Detect interleaved sliding window attention (#14896)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
commit 31060b2757
parent fc1f67715d
@@ -82,8 +82,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
             cache_config.cache_dtype]
 
-        self.is_multimodal_model = model_config.is_multimodal_model
+        # NOTE(woosuk): sliding_window is None for models with interleaved
+        # attention. Use interleaved_sliding_window instead.
         self.sliding_window = model_config.get_sliding_window()
+        self.interleaved_sliding_window = getattr(
+            model_config.hf_text_config, "interleaved_sliding_window", None)
+        self.window_size = (self.sliding_window
+                            or self.interleaved_sliding_window)
+
+        self.is_multimodal_model = model_config.is_multimodal_model
         self.block_size = cache_config.block_size
         self.max_model_len = model_config.max_model_len
         self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
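Context for the hunk above: models with interleaved sliding-window attention (Gemma-2-style alternation between windowed and full-attention layers) report sliding_window as None on the model config, so a check on self.sliding_window alone treats them as full-attention models; the fix reads interleaved_sliding_window from the HF text config and folds the two into self.window_size. A minimal runnable sketch of that detection logic, where detect_window_size and the SimpleNamespace configs are illustrative stand-ins rather than vLLM APIs:

from types import SimpleNamespace

def detect_window_size(hf_text_config, sliding_window):
    # Prefer the uniform sliding window; otherwise fall back to the
    # interleaved one. Use getattr with a default, since most HF configs
    # do not define interleaved_sliding_window at all.
    interleaved = getattr(hf_text_config, "interleaved_sliding_window", None)
    return sliding_window or interleaved

# Uniform sliding-window model: the window comes from sliding_window.
assert detect_window_size(SimpleNamespace(), 4096) == 4096

# Interleaved model: sliding_window is None, but the HF text config
# carries interleaved_sliding_window (the case this commit fixes).
assert detect_window_size(
    SimpleNamespace(interleaved_sliding_window=4096), None) == 4096

# Full-attention model: neither is set, so window_size stays None.
assert detect_window_size(SimpleNamespace(), None) is None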
@@ -674,7 +681,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             num_query_heads=self.num_query_heads,
             num_kv_heads=self.num_kv_heads,
             use_alibi=False,  # FIXME
-            use_sliding_window=self.sliding_window is not None,
+            use_sliding_window=self.window_size is not None,
             num_sms=self.num_sms,
         )
         return common_prefix_len if use_cascade else 0
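The second hunk is the behavioral fix: the cascade-attention heuristic previously consulted self.sliding_window alone, so interleaved-window models slipped through and could be given cascade attention, which assumes every query attends to the entire shared prefix. A hedged sketch of the gating idea under that assumption; the real check is the attention backend's use_cascade_attention and weighs more factors than shown here:

def should_use_cascade(common_prefix_len: int, window_size) -> bool:
    # Hypothetical simplification: sharing KV computation for a common
    # prefix is only sound when attention over that prefix is unwindowed.
    use_sliding_window = window_size is not None
    return common_prefix_len > 0 and not use_sliding_window

# With window_size populated from either attribute, an interleaved-window
# model now opts out of cascade attention just like a uniform one.
assert should_use_cascade(128, None) is True
assert should_use_cascade(128, 4096) is False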