[CI] Fix flaky test v1/worker/test_gpu_model_runner.py::test_kv_cache_stride_order (#24640)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Chen Zhang 2025-09-12 00:49:09 -07:00 committed by GitHub
parent f592b3174b
commit 561a0baee0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,8 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
import numpy as np
import pytest
import torch
@ -409,29 +407,30 @@ def test_kv_cache_stride_order(monkeypatch, model_runner):
model_runner.model_config.get_head_size()
]
# TODO mla test
default_stride = list(range(5))
default_stride = tuple(range(5))
# Permutation that gets you back to expected kv shape
rnd_stride = tuple(random.sample(default_stride, len(default_stride)))
for test_stride in ((1, 4, 0, 2, 3), (0, 1, 2, 3, 4)):
def rnd_stride_order():
return rnd_stride
def rnd_stride_order(test_stride=test_stride):
return test_stride
# Patch the attention backend class and re-trigger the KV cache creation.
for attn_group in model_runner._attn_group_iterator():
attn_backend = attn_group.backend
monkeypatch.setattr(attn_backend, "get_kv_cache_stride_order",
rnd_stride_order)
# Patch the attention backend class and re-trigger the KV cache creation
for attn_group in model_runner._attn_group_iterator():
attn_backend = attn_group.backend
monkeypatch.setattr(attn_backend, "get_kv_cache_stride_order",
rnd_stride_order)
model_runner.attn_groups = []
model_runner.initialize_kv_cache(model_runner.kv_cache_config)
model_runner.attn_groups = []
model_runner.kv_caches = []
model_runner.initialize_kv_cache(model_runner.kv_cache_config)
# Shape is unchanged, but layout may differ
kv_cache_shape = model_runner.kv_caches[0].shape
assert list(kv_cache_shape) == expected_kv_cache_shape
if default_stride == rnd_stride:
assert all(kv.is_contiguous() for kv in model_runner.kv_caches)
else:
assert all(not kv.is_contiguous() for kv in model_runner.kv_caches)
# Shape is unchanged, but layout may differ
kv_cache_shape = model_runner.kv_caches[0].shape
assert list(kv_cache_shape) == expected_kv_cache_shape
if default_stride == test_stride:
assert all(kv.is_contiguous() for kv in model_runner.kv_caches)
else:
assert all(not kv.is_contiguous() for kv in model_runner.kv_caches)
def test_update_config(model_runner):