[Bugfix] Remove tile_size=64 for mm_prefix triton attention (#30973)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Isotr0py 2025-12-19 03:42:32 +08:00 committed by GitHub
parent b8c477c115
commit d2dc5dfc6e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -800,7 +800,6 @@ def _get_tile_size(
head_size: int, head_size: int,
sliding_window: int, sliding_window: int,
element_size: int, element_size: int,
is_mm_prefix: bool,
is_prefill: bool, is_prefill: bool,
) -> int: ) -> int:
"""Select tile size with Gemma3-specific optimization. """Select tile size with Gemma3-specific optimization.
@ -809,10 +808,6 @@ def _get_tile_size(
the larger head dimension (128/256). For other models, use the larger head dimension (128/256). For other models, use
the default vLLM behavior. the default vLLM behavior.
""" """
if is_mm_prefix:
# Multimodal bidirectional attention needs a larger tile size
return 64
if _is_gemma3_attention(head_size, sliding_window): if _is_gemma3_attention(head_size, sliding_window):
# Gemma3: use 32 for decode (default is 16) # Gemma3: use 32 for decode (default is 16)
return 32 return 32
@ -903,14 +898,12 @@ def unified_attention(
head_size, head_size,
sliding_window_val, sliding_window_val,
q.element_size(), q.element_size(),
is_mm_prefix=use_mm_prefix,
is_prefill=True, is_prefill=True,
) )
TILE_SIZE_DECODE = _get_tile_size( TILE_SIZE_DECODE = _get_tile_size(
head_size, head_size,
sliding_window_val, sliding_window_val,
q.element_size(), q.element_size(),
is_mm_prefix=use_mm_prefix,
is_prefill=False, is_prefill=False,
) )