mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-13 19:47:04 +08:00
[CI/Test Fix] Fix CP tests on Blackwell (#28404)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
parent
35d801f13f
commit
39029d5192
@@ -14,6 +14,7 @@ from dataclasses import dataclass
|
||||
from typing import Literal, NamedTuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.config.model import RunnerOption
|
||||
from vllm.logger import init_logger
|
||||
@@ -254,6 +255,17 @@ def test_cp_generation(
|
||||
test_options: CPTestOptions,
|
||||
num_gpus_available,
|
||||
):
|
||||
if (
|
||||
model_id == "deepseek-ai/DeepSeek-V2-Lite-Chat"
|
||||
and torch.cuda.get_device_capability() < (9, 0)
|
||||
):
|
||||
pytest.skip(reason="MLA+DCP requires compute capability of 9.0 or higher")
|
||||
if (
|
||||
model_id == "bigcode/gpt_bigcode-santacoder"
|
||||
and torch.cuda.get_device_capability() != (9, 0)
|
||||
):
|
||||
pytest.skip(reason="GQA+DCP currently requires compute capability of 9.0")
|
||||
|
||||
_compare_cp_with_tp(
|
||||
model_id,
|
||||
parallel_setup,
|
||||
|
||||
@@ -195,7 +195,6 @@ def cp_lse_ag_out_rs(
|
||||
cp_attn_lse = cp_attn_lse.contiguous()
|
||||
lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
|
||||
out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
|
||||
assert out.is_contiguous()
|
||||
out = cp_group.reduce_scatter(out, dim=1)
|
||||
|
||||
if return_lse:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user