Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 03:05:02 +08:00)
[CI] Fix flaky test v1/worker/test_gpu_model_runner.py::test_kv_cache_stride_order (#24640)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
parent f592b3174b
commit 561a0baee0
@@ -1,8 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
@@ -409,29 +407,30 @@ def test_kv_cache_stride_order(monkeypatch, model_runner):
|
||||
model_runner.model_config.get_head_size()
|
||||
]
|
||||
# TODO mla test
|
||||
default_stride = list(range(5))
|
||||
default_stride = tuple(range(5))
|
||||
# Permutation that gets you back to expected kv shape
|
||||
rnd_stride = tuple(random.sample(default_stride, len(default_stride)))
|
||||
for test_stride in ((1, 4, 0, 2, 3), (0, 1, 2, 3, 4)):
|
||||
|
||||
def rnd_stride_order():
|
||||
return rnd_stride
|
||||
def rnd_stride_order(test_stride=test_stride):
|
||||
return test_stride
|
||||
|
||||
# Patch the attention backend class and re-trigger the KV cache creation.
|
||||
for attn_group in model_runner._attn_group_iterator():
|
||||
attn_backend = attn_group.backend
|
||||
monkeypatch.setattr(attn_backend, "get_kv_cache_stride_order",
|
||||
rnd_stride_order)
|
||||
# Patch the attention backend class and re-trigger the KV cache creation
|
||||
for attn_group in model_runner._attn_group_iterator():
|
||||
attn_backend = attn_group.backend
|
||||
monkeypatch.setattr(attn_backend, "get_kv_cache_stride_order",
|
||||
rnd_stride_order)
|
||||
|
||||
model_runner.attn_groups = []
|
||||
model_runner.initialize_kv_cache(model_runner.kv_cache_config)
|
||||
model_runner.attn_groups = []
|
||||
model_runner.kv_caches = []
|
||||
model_runner.initialize_kv_cache(model_runner.kv_cache_config)
|
||||
|
||||
# Shape is unchanged, but layout may differ
|
||||
kv_cache_shape = model_runner.kv_caches[0].shape
|
||||
assert list(kv_cache_shape) == expected_kv_cache_shape
|
||||
if default_stride == rnd_stride:
|
||||
assert all(kv.is_contiguous() for kv in model_runner.kv_caches)
|
||||
else:
|
||||
assert all(not kv.is_contiguous() for kv in model_runner.kv_caches)
|
||||
# Shape is unchanged, but layout may differ
|
||||
kv_cache_shape = model_runner.kv_caches[0].shape
|
||||
assert list(kv_cache_shape) == expected_kv_cache_shape
|
||||
if default_stride == test_stride:
|
||||
assert all(kv.is_contiguous() for kv in model_runner.kv_caches)
|
||||
else:
|
||||
assert all(not kv.is_contiguous() for kv in model_runner.kv_caches)
|
||||
|
||||
|
||||
def test_update_config(model_runner):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user