Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2025-09-01 01:16:26 -07:00
parent af7b6c5dd4
commit 01bf16ede4
2 changed files with 5 additions and 5 deletions

View File

@ -305,4 +305,5 @@ def _compute_slot_mappings_kernel(
@triton.jit
def _load_ptr(ptr_to_ptr, elem_dtype):
    """Dereference a pointer-to-pointer and return a typed, alignment-hinted pointer.

    Loads the raw pointer value stored at ``ptr_to_ptr``, casts it to a
    pointer of ``elem_dtype``, and marks it as 16-byte aligned so the Triton
    compiler can emit vectorized memory accesses.
    """
    ptr = tl.load(ptr_to_ptr)
    ptr = tl.cast(ptr, tl.pointer_type(elem_dtype))
    # NOTE: the alignment hint is a compiler promise, not a runtime check —
    # callers must ensure the stored pointer really is 16-byte aligned.
    return tl.multiple_of(ptr, 16)

View File

@ -687,6 +687,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if self.cascade_attn_enabled:
common_prefix_len = self._compute_cascade_attn_prefix_len(
num_scheduled_tokens,
num_computed_tokens_np,
num_common_prefix_blocks,
kv_cache_group_spec.kv_cache_spec,
builder,
@ -722,6 +723,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def _compute_cascade_attn_prefix_len(
self,
num_scheduled_tokens: np.ndarray,
num_computed_tokens: np.ndarray,
num_common_prefix_blocks: int,
kv_cache_spec: KVCacheSpec,
attn_metadata_builder: AttentionMetadataBuilder,
@ -787,10 +789,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# and the second kernel will get an empty input. While this is not
# a fundamental problem, our current implementation does not support
# this case.
num_reqs = len(num_scheduled_tokens)
common_prefix_len = min(
common_prefix_len,
self.requests.num_computed_tokens.np[:num_reqs].min())
common_prefix_len = min(common_prefix_len, num_computed_tokens.min())
# common_prefix_len should be a multiple of the block size.
common_prefix_len = (common_prefix_len // kv_cache_spec.block_size *
kv_cache_spec.block_size)