mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-08 22:35:47 +08:00
[CI] Replace large models with tiny alternatives in tests (#24057)
Signed-off-by: Tahsin Tunan <tahsintunan@gmail.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
02d709a6f1
commit
43721bc67f
@ -20,7 +20,7 @@ from ..models.utils import check_outputs_equal
|
|||||||
from ..utils import multi_gpu_test
|
from ..utils import multi_gpu_test
|
||||||
|
|
||||||
MODELS = [
|
MODELS = [
|
||||||
"google/gemma-2-2b-it",
|
"hmellor/tiny-random-Gemma2ForCausalLM",
|
||||||
"meta-llama/Llama-3.2-1B-Instruct",
|
"meta-llama/Llama-3.2-1B-Instruct",
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -29,7 +29,7 @@ TARGET_TEST_SUITE = os.environ.get("TARGET_TEST_SUITE", "L4")
|
|||||||
|
|
||||||
def test_vllm_gc_ed():
|
def test_vllm_gc_ed():
|
||||||
"""Verify vllm instance is GC'ed when it is deleted"""
|
"""Verify vllm instance is GC'ed when it is deleted"""
|
||||||
llm = LLM("distilbert/distilgpt2")
|
llm = LLM("hmellor/tiny-random-LlamaForCausalLM")
|
||||||
weak_llm = weakref.ref(llm)
|
weak_llm = weakref.ref(llm)
|
||||||
del llm
|
del llm
|
||||||
# If there's any circular reference to vllm, this fails
|
# If there's any circular reference to vllm, this fails
|
||||||
@ -125,14 +125,14 @@ def test_models(
|
|||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"model, distributed_executor_backend, attention_backend, test_suite, extra_env",
|
"model, distributed_executor_backend, attention_backend, test_suite, extra_env",
|
||||||
[
|
[
|
||||||
("distilbert/distilgpt2", "ray", "", "L4", {}),
|
("facebook/opt-125m", "ray", "", "L4", {}),
|
||||||
("distilbert/distilgpt2", "mp", "", "L4", {}),
|
("facebook/opt-125m", "mp", "", "L4", {}),
|
||||||
("distilbert/distilgpt2", "ray", "", "L4", {"VLLM_SLEEP_WHEN_IDLE": "1"}),
|
("facebook/opt-125m", "ray", "", "L4", {"VLLM_SLEEP_WHEN_IDLE": "1"}),
|
||||||
("distilbert/distilgpt2", "mp", "", "L4", {"VLLM_SLEEP_WHEN_IDLE": "1"}),
|
("facebook/opt-125m", "mp", "", "L4", {"VLLM_SLEEP_WHEN_IDLE": "1"}),
|
||||||
("meta-llama/Llama-3.2-1B-Instruct", "ray", "", "L4", {}),
|
("meta-llama/Llama-3.2-1B-Instruct", "ray", "", "L4", {}),
|
||||||
("meta-llama/Llama-3.2-1B-Instruct", "mp", "", "L4", {}),
|
("meta-llama/Llama-3.2-1B-Instruct", "mp", "", "L4", {}),
|
||||||
("distilbert/distilgpt2", "ray", "", "A100", {}),
|
("facebook/opt-125m", "ray", "", "A100", {}),
|
||||||
("distilbert/distilgpt2", "mp", "", "A100", {}),
|
("facebook/opt-125m", "mp", "", "A100", {}),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
|
@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
|
||||||
|
|||||||
@ -6,5 +6,5 @@ from ..utils import compare_two_settings
|
|||||||
|
|
||||||
def test_cpu_offload():
|
def test_cpu_offload():
|
||||||
compare_two_settings(
|
compare_two_settings(
|
||||||
"meta-llama/Llama-3.2-1B-Instruct", [], ["--cpu-offload-gb", "1"]
|
"hmellor/tiny-random-LlamaForCausalLM", [], ["--cpu-offload-gb", "1"]
|
||||||
)
|
)
|
||||||
|
|||||||
@ -120,7 +120,7 @@ def test_cumem_with_cudagraph():
|
|||||||
"model",
|
"model",
|
||||||
[
|
[
|
||||||
# sleep mode with safetensors
|
# sleep mode with safetensors
|
||||||
"meta-llama/Llama-3.2-1B",
|
"hmellor/tiny-random-LlamaForCausalLM",
|
||||||
# sleep mode with pytorch checkpoint
|
# sleep mode with pytorch checkpoint
|
||||||
"facebook/opt-125m",
|
"facebook/opt-125m",
|
||||||
],
|
],
|
||||||
@ -174,7 +174,7 @@ def test_end_to_end(model: str):
|
|||||||
|
|
||||||
@create_new_process_for_each_test()
|
@create_new_process_for_each_test()
|
||||||
def test_deep_sleep():
|
def test_deep_sleep():
|
||||||
model = "Qwen/Qwen3-0.6B"
|
model = "hmellor/tiny-random-LlamaForCausalLM"
|
||||||
free, total = torch.cuda.mem_get_info()
|
free, total = torch.cuda.mem_get_info()
|
||||||
used_bytes_baseline = total - free # in case other process is running
|
used_bytes_baseline = total - free # in case other process is running
|
||||||
llm = LLM(model, enable_sleep_mode=True)
|
llm = LLM(model, enable_sleep_mode=True)
|
||||||
|
|||||||
@ -273,14 +273,14 @@ def _compare_sp(
|
|||||||
|
|
||||||
SP_TEXT_GENERATION_MODELS = {
|
SP_TEXT_GENERATION_MODELS = {
|
||||||
# [Decoder-only]
|
# [Decoder-only]
|
||||||
"meta-llama/Llama-3.2-1B-Instruct": SPTestSettings.fast(),
|
"hmellor/tiny-random-LlamaForCausalLM": SPTestSettings.fast(),
|
||||||
"RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8": SPTestSettings.fp8_quant(),
|
"RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8": SPTestSettings.fp8_quant(),
|
||||||
}
|
}
|
||||||
|
|
||||||
SP_TEST_MODELS = [
|
SP_TEST_MODELS = [
|
||||||
# TODO support other models
|
# TODO support other models
|
||||||
# [LANGUAGE GENERATION]
|
# [LANGUAGE GENERATION]
|
||||||
"meta-llama/Llama-3.2-1B-Instruct",
|
"hmellor/tiny-random-LlamaForCausalLM",
|
||||||
"RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
|
"RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -2,6 +2,7 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
from vllm import LLM
|
from vllm import LLM
|
||||||
|
|
||||||
@ -12,6 +13,8 @@ from ...utils import create_new_process_for_each_test
|
|||||||
@pytest.mark.parametrize("backend", ["mp", "ray"])
|
@pytest.mark.parametrize("backend", ["mp", "ray"])
|
||||||
@create_new_process_for_each_test()
|
@create_new_process_for_each_test()
|
||||||
def test_collective_rpc(tp_size, backend, monkeypatch):
|
def test_collective_rpc(tp_size, backend, monkeypatch):
|
||||||
|
if torch.cuda.device_count() < tp_size:
|
||||||
|
pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
|
||||||
if tp_size == 1 and backend == "ray":
|
if tp_size == 1 and backend == "ray":
|
||||||
pytest.skip("Skip duplicate test case")
|
pytest.skip("Skip duplicate test case")
|
||||||
if tp_size == 1:
|
if tp_size == 1:
|
||||||
@ -24,7 +27,7 @@ def test_collective_rpc(tp_size, backend, monkeypatch):
|
|||||||
|
|
||||||
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||||
llm = LLM(
|
llm = LLM(
|
||||||
model="meta-llama/Llama-3.2-1B-Instruct",
|
model="hmellor/tiny-random-LlamaForCausalLM",
|
||||||
enforce_eager=True,
|
enforce_eager=True,
|
||||||
load_format="dummy",
|
load_format="dummy",
|
||||||
tensor_parallel_size=tp_size,
|
tensor_parallel_size=tp_size,
|
||||||
|
|||||||
@ -9,7 +9,7 @@ import pytest
|
|||||||
|
|
||||||
from vllm.entrypoints.openai.protocol import BatchRequestOutput
|
from vllm.entrypoints.openai.protocol import BatchRequestOutput
|
||||||
|
|
||||||
MODEL_NAME = "Qwen/Qwen3-0.6B"
|
MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
|
||||||
|
|
||||||
# ruff: noqa: E501
|
# ruff: noqa: E501
|
||||||
INPUT_BATCH = (
|
INPUT_BATCH = (
|
||||||
|
|||||||
@ -16,7 +16,7 @@ from vllm.entrypoints.openai.protocol import (
|
|||||||
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
|
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
|
|
||||||
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
|
MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
|
||||||
BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
|
BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
|
||||||
LORA_LOADING_SUCCESS_MESSAGE = "Success: LoRA adapter '{lora_name}' added successfully."
|
LORA_LOADING_SUCCESS_MESSAGE = "Success: LoRA adapter '{lora_name}' added successfully."
|
||||||
LORA_UNLOADING_SUCCESS_MESSAGE = (
|
LORA_UNLOADING_SUCCESS_MESSAGE = (
|
||||||
|
|||||||
@ -1,37 +1,93 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import signal
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from ...utils import RemoteOpenAIServer
|
from ...utils import get_open_port
|
||||||
|
|
||||||
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
|
MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_shutdown_on_engine_failure():
|
async def test_shutdown_on_engine_failure():
|
||||||
# dtype, max-len etc set so that this can run in CI
|
"""Verify that API returns connection error when server process is killed.
|
||||||
args = [
|
|
||||||
"--dtype",
|
|
||||||
"bfloat16",
|
|
||||||
"--max-model-len",
|
|
||||||
"8192",
|
|
||||||
"--enforce-eager",
|
|
||||||
"--max-num-seqs",
|
|
||||||
"128",
|
|
||||||
]
|
|
||||||
|
|
||||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
Starts a vLLM server, kills it to simulate a crash, then verifies that
|
||||||
async with remote_server.get_async_client() as client:
|
subsequent API calls fail appropriately.
|
||||||
with pytest.raises((openai.APIConnectionError, openai.InternalServerError)):
|
"""
|
||||||
# Asking for lots of prompt logprobs will currently crash the
|
|
||||||
# engine. This may change in the future when that bug is fixed
|
port = get_open_port()
|
||||||
prompt = "Hello " * 4000
|
|
||||||
await client.completions.create(
|
proc = subprocess.Popen(
|
||||||
model=MODEL_NAME, prompt=prompt, extra_body={"prompt_logprobs": 10}
|
[
|
||||||
|
# dtype, max-len etc set so that this can run in CI
|
||||||
|
sys.executable,
|
||||||
|
"-m",
|
||||||
|
"vllm.entrypoints.openai.api_server",
|
||||||
|
"--model",
|
||||||
|
MODEL_NAME,
|
||||||
|
"--dtype",
|
||||||
|
"bfloat16",
|
||||||
|
"--max-model-len",
|
||||||
|
"128",
|
||||||
|
"--enforce-eager",
|
||||||
|
"--port",
|
||||||
|
str(port),
|
||||||
|
"--gpu-memory-utilization",
|
||||||
|
"0.05",
|
||||||
|
"--max-num-seqs",
|
||||||
|
"2",
|
||||||
|
"--disable-frontend-multiprocessing",
|
||||||
|
],
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.PIPE,
|
||||||
|
text=True,
|
||||||
|
preexec_fn=lambda: signal.signal(signal.SIGINT, signal.SIG_IGN),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Wait for server startup
|
||||||
|
start_time = time.time()
|
||||||
|
client = openai.AsyncOpenAI(
|
||||||
|
base_url=f"http://localhost:{port}/v1",
|
||||||
|
api_key="dummy",
|
||||||
|
max_retries=0,
|
||||||
|
timeout=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Poll until server is ready
|
||||||
|
while time.time() - start_time < 30:
|
||||||
|
try:
|
||||||
|
await client.completions.create(
|
||||||
|
model=MODEL_NAME, prompt="Hello", max_tokens=1
|
||||||
|
)
|
||||||
|
break
|
||||||
|
except Exception:
|
||||||
|
time.sleep(0.5)
|
||||||
|
if proc.poll() is not None:
|
||||||
|
stdout, stderr = proc.communicate(timeout=1)
|
||||||
|
pytest.fail(
|
||||||
|
f"Server died during startup. stdout: {stdout}, stderr: {stderr}"
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
proc.terminate()
|
||||||
|
proc.wait(timeout=5)
|
||||||
|
pytest.fail("Server failed to start in 30 seconds")
|
||||||
|
|
||||||
# Now the server should shut down
|
# Kill server to simulate crash
|
||||||
return_code = remote_server.proc.wait(timeout=8)
|
proc.terminate()
|
||||||
assert return_code is not None
|
time.sleep(1)
|
||||||
|
|
||||||
|
# Verify API calls now fail
|
||||||
|
with pytest.raises((openai.APIConnectionError, openai.APIStatusError)):
|
||||||
|
await client.completions.create(
|
||||||
|
model=MODEL_NAME, prompt="This should fail", max_tokens=1
|
||||||
|
)
|
||||||
|
|
||||||
|
return_code = proc.wait(timeout=5)
|
||||||
|
assert return_code is not None
|
||||||
|
|||||||
@ -330,6 +330,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
|||||||
"guard": "meta-llama/Llama-Guard-3-1B",
|
"guard": "meta-llama/Llama-Guard-3-1B",
|
||||||
"hermes": "NousResearch/Hermes-3-Llama-3.1-8B",
|
"hermes": "NousResearch/Hermes-3-Llama-3.1-8B",
|
||||||
"fp8": "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
|
"fp8": "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
|
||||||
|
"tiny": "hmellor/tiny-random-LlamaForCausalLM",
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
"LLaMAForCausalLM": _HfExamplesInfo(
|
"LLaMAForCausalLM": _HfExamplesInfo(
|
||||||
|
|||||||
@ -35,15 +35,13 @@ def _generate(
|
|||||||
|
|
||||||
|
|
||||||
class TestOneTokenBadWord:
|
class TestOneTokenBadWord:
|
||||||
MODEL = "TheBloke/Llama-2-7B-fp16"
|
MODEL = "hmellor/tiny-random-LlamaForCausalLM"
|
||||||
|
|
||||||
PROMPT = "Hi! How are"
|
PROMPT = "How old are "
|
||||||
TARGET_TOKEN = "you"
|
TARGET_TOKEN = "mn"
|
||||||
|
|
||||||
def setup_method(self, method):
|
def setup_method(self, method):
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL)
|
||||||
self.MODEL, add_prefix_space=True
|
|
||||||
)
|
|
||||||
|
|
||||||
self.num_prompt_tokens = len(self._encode(self.PROMPT))
|
self.num_prompt_tokens = len(self._encode(self.PROMPT))
|
||||||
self.target_token_id = self._encode(
|
self.target_token_id = self._encode(
|
||||||
|
|||||||
@ -5,7 +5,7 @@ import pytest
|
|||||||
|
|
||||||
from vllm import LLM
|
from vllm import LLM
|
||||||
|
|
||||||
MODEL = "meta-llama/Llama-3.2-1B"
|
MODEL = "hmellor/tiny-random-LlamaForCausalLM"
|
||||||
PROMPT = "Hello my name is Robert and I"
|
PROMPT = "Hello my name is Robert and I"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -24,9 +24,11 @@ from ...utils import create_new_process_for_each_test, multi_gpu_test
|
|||||||
if not current_platform.is_cuda():
|
if not current_platform.is_cuda():
|
||||||
pytest.skip(reason="V1 currently only supported on CUDA.", allow_module_level=True)
|
pytest.skip(reason="V1 currently only supported on CUDA.", allow_module_level=True)
|
||||||
|
|
||||||
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
|
MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
|
||||||
TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME)
|
TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||||
PROMPT = "Hello my name is Robert and I love quantization kernels"
|
# test_engine_core_concurrent_batches assumes exactly 12 tokens per prompt.
|
||||||
|
# Adjust prompt if changing model to maintain 12-token length.
|
||||||
|
PROMPT = "I am Gyoubu Masataka Oniwa"
|
||||||
PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids
|
PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -10,7 +10,7 @@ import pytest_asyncio
|
|||||||
from tests.utils import RemoteOpenAIServer
|
from tests.utils import RemoteOpenAIServer
|
||||||
from tests.v1.utils import check_request_balancing
|
from tests.v1.utils import check_request_balancing
|
||||||
|
|
||||||
MODEL_NAME = "ibm-research/PowerMoE-3b"
|
MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
|
||||||
|
|
||||||
DP_SIZE = os.getenv("DP_SIZE", "1")
|
DP_SIZE = os.getenv("DP_SIZE", "1")
|
||||||
|
|
||||||
|
|||||||
@ -5,16 +5,13 @@ import pytest
|
|||||||
|
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
|
|
||||||
MODEL = "meta-llama/Llama-3.2-1B"
|
MODEL = "hmellor/tiny-random-LlamaForCausalLM"
|
||||||
PROMPT = "Hello my name is Robert and I"
|
PROMPT = "Hello my name is Robert and I"
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def llm() -> LLM:
|
def llm() -> LLM:
|
||||||
# Disable prefix caching so that we can test prompt logprobs.
|
return LLM(MODEL, enforce_eager=True)
|
||||||
# TODO remove this after https://github.com/vllm-project/vllm/pull/13949
|
|
||||||
# is merged
|
|
||||||
return LLM(MODEL, enforce_eager=True, enable_prefix_caching=False)
|
|
||||||
|
|
||||||
|
|
||||||
def test_n_gt_1(llm):
|
def test_n_gt_1(llm):
|
||||||
|
|||||||
@ -15,7 +15,7 @@ from vllm.sampling_params import RequestOutputKind
|
|||||||
from vllm.utils import cuda_device_count_stateless
|
from vllm.utils import cuda_device_count_stateless
|
||||||
from vllm.v1.engine.async_llm import AsyncLLM
|
from vllm.v1.engine.async_llm import AsyncLLM
|
||||||
|
|
||||||
MODELS = ["meta-llama/Llama-3.2-1B"]
|
MODELS = ["hmellor/tiny-random-LlamaForCausalLM"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
@ -18,7 +18,7 @@ from vllm.utils import cuda_device_count_stateless
|
|||||||
from vllm.v1.engine.async_llm import AsyncLLM
|
from vllm.v1.engine.async_llm import AsyncLLM
|
||||||
from vllm.v1.engine.exceptions import EngineDeadError
|
from vllm.v1.engine.exceptions import EngineDeadError
|
||||||
|
|
||||||
MODELS = ["meta-llama/Llama-3.2-1B"]
|
MODELS = ["hmellor/tiny-random-LlamaForCausalLM"]
|
||||||
|
|
||||||
|
|
||||||
def evil_forward(self, *args, **kwargs):
|
def evil_forward(self, *args, **kwargs):
|
||||||
|
|||||||
@ -16,7 +16,7 @@ from vllm.model_executor.models.llama import LlamaForCausalLM
|
|||||||
from vllm.utils import cuda_device_count_stateless
|
from vllm.utils import cuda_device_count_stateless
|
||||||
from vllm.v1.engine.async_llm import AsyncLLM
|
from vllm.v1.engine.async_llm import AsyncLLM
|
||||||
|
|
||||||
MODELS = ["meta-llama/Llama-3.2-1B"]
|
MODELS = ["hmellor/tiny-random-LlamaForCausalLM"]
|
||||||
|
|
||||||
|
|
||||||
def evil_method(self, *args, **kwargs):
|
def evil_method(self, *args, **kwargs):
|
||||||
@ -76,8 +76,10 @@ def test_llm_startup_error(
|
|||||||
Test profiling (forward()) and load weights failures.
|
Test profiling (forward()) and load weights failures.
|
||||||
TODO(andy) - LLM without multiprocessing.
|
TODO(andy) - LLM without multiprocessing.
|
||||||
"""
|
"""
|
||||||
if model != "meta-llama/Llama-3.2-1B":
|
# Skip non-Llama models since we monkeypatch LlamaForCausalLM specifically.
|
||||||
pytest.skip(reason="Only test meta-llama/Llama-3.2-1B")
|
# If MODELS list grows, each architecture needs its own test variant.
|
||||||
|
if model != "JackFram/llama-68m":
|
||||||
|
pytest.skip(reason="Only test JackFram/llama-68m")
|
||||||
if cuda_device_count_stateless() < tensor_parallel_size:
|
if cuda_device_count_stateless() < tensor_parallel_size:
|
||||||
pytest.skip(reason="Not enough CUDA devices")
|
pytest.skip(reason="Not enough CUDA devices")
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user