mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-22 10:45:54 +08:00
[CI] Fix flaky test v1/worker/test_gpu_model_runner.py::test_kv_cache_stride_order (#24640)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
parent
f592b3174b
commit
561a0baee0
@ -1,8 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import random
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
@ -409,29 +407,30 @@ def test_kv_cache_stride_order(monkeypatch, model_runner):
|
|||||||
model_runner.model_config.get_head_size()
|
model_runner.model_config.get_head_size()
|
||||||
]
|
]
|
||||||
# TODO mla test
|
# TODO mla test
|
||||||
default_stride = list(range(5))
|
default_stride = tuple(range(5))
|
||||||
# Permutation that gets you back to expected kv shape
|
# Permutation that gets you back to expected kv shape
|
||||||
rnd_stride = tuple(random.sample(default_stride, len(default_stride)))
|
for test_stride in ((1, 4, 0, 2, 3), (0, 1, 2, 3, 4)):
|
||||||
|
|
||||||
def rnd_stride_order():
|
def rnd_stride_order(test_stride=test_stride):
|
||||||
return rnd_stride
|
return test_stride
|
||||||
|
|
||||||
# Patch the attention backend class and re-trigger the KV cache creation.
|
# Patch the attention backend class and re-trigger the KV cache creation
|
||||||
for attn_group in model_runner._attn_group_iterator():
|
for attn_group in model_runner._attn_group_iterator():
|
||||||
attn_backend = attn_group.backend
|
attn_backend = attn_group.backend
|
||||||
monkeypatch.setattr(attn_backend, "get_kv_cache_stride_order",
|
monkeypatch.setattr(attn_backend, "get_kv_cache_stride_order",
|
||||||
rnd_stride_order)
|
rnd_stride_order)
|
||||||
|
|
||||||
model_runner.attn_groups = []
|
model_runner.attn_groups = []
|
||||||
model_runner.initialize_kv_cache(model_runner.kv_cache_config)
|
model_runner.kv_caches = []
|
||||||
|
model_runner.initialize_kv_cache(model_runner.kv_cache_config)
|
||||||
|
|
||||||
# Shape is unchanged, but layout may differ
|
# Shape is unchanged, but layout may differ
|
||||||
kv_cache_shape = model_runner.kv_caches[0].shape
|
kv_cache_shape = model_runner.kv_caches[0].shape
|
||||||
assert list(kv_cache_shape) == expected_kv_cache_shape
|
assert list(kv_cache_shape) == expected_kv_cache_shape
|
||||||
if default_stride == rnd_stride:
|
if default_stride == test_stride:
|
||||||
assert all(kv.is_contiguous() for kv in model_runner.kv_caches)
|
assert all(kv.is_contiguous() for kv in model_runner.kv_caches)
|
||||||
else:
|
else:
|
||||||
assert all(not kv.is_contiguous() for kv in model_runner.kv_caches)
|
assert all(not kv.is_contiguous() for kv in model_runner.kv_caches)
|
||||||
|
|
||||||
|
|
||||||
def test_update_config(model_runner):
|
def test_update_config(model_runner):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user