Merge remote-tracking branch 'origin/main' into fp8_ep_dp

Commit: 1236aebf0e
@ -113,7 +113,7 @@ WARNING: The benchmarking script will save json results by itself, so please do

### Visualizing the results

The `convert-results-json-to-markdown.py` helps you put the benchmarking results inside a markdown table, by formatting [descriptions.md](tests/descriptions.md) with real benchmarking results.
The `convert-results-json-to-markdown.py` helps you put the benchmarking results inside a markdown table, by formatting [descriptions.md](performance-benchmarks-descriptions.md) with real benchmarking results.

You can find the result presented as a table inside the `buildkite/performance-benchmark` job page.

If you do not see the table, please wait until the benchmark finishes running.

The json version of the table (together with the json version of the benchmark) will also be attached to the markdown file.
@ -2,14 +2,45 @@

set -xu

remove_docker_container() {
  docker rm -f tpu-test || true;
  docker rm -f vllm-tpu || true;
}

trap remove_docker_container EXIT

# Remove the container that might not be cleaned up in the previous run.
remove_docker_container

# Build the docker image.
docker build -f docker/Dockerfile.tpu -t vllm-tpu .

# Set up cleanup.
remove_docker_container() { docker rm -f tpu-test || true; }
trap remove_docker_container EXIT
# Remove the container that might not be cleaned up in the previous run.
remove_docker_container
cleanup_docker() {
  # Get Docker's root directory
  docker_root=$(docker info -f '{{.DockerRootDir}}')
  if [ -z "$docker_root" ]; then
    echo "Failed to determine Docker root directory."
    exit 1
  fi
  echo "Docker root directory: $docker_root"
  # Check disk usage of the filesystem where Docker's root directory is located
  disk_usage=$(df "$docker_root" | tail -1 | awk '{print $5}' | sed 's/%//')
  # Define the threshold
  threshold=70
  if [ "$disk_usage" -gt "$threshold" ]; then
    echo "Disk usage is above $threshold%. Cleaning up Docker images and volumes..."
    # Remove dangling images (those that are not tagged and not used by any container)
    docker image prune -f
    # Remove unused volumes / force the system prune for old images as well.
    docker volume prune -f && docker system prune --force --filter "until=72h" --all
    echo "Docker images and volumes cleanup completed."
  else
    echo "Disk usage is below $threshold%. No cleanup needed."
  fi
}
cleanup_docker

# For HF_TOKEN.
source /etc/environment
@ -199,8 +199,9 @@ steps:
  - tests/test_sequence
  - tests/test_config
  - tests/test_logger
  - tests/test_vllm_port
  commands:
  - pytest -v -s engine test_sequence.py test_config.py test_logger.py
  - pytest -v -s engine test_sequence.py test_config.py test_logger.py test_vllm_port.py
  # OOM in the CI unless we run this separately
  - pytest -v -s tokenization

@ -617,9 +618,11 @@ steps:
  - vllm/worker/model_runner.py
  - entrypoints/llm/test_collective_rpc.py
  - tests/v1/test_async_llm_dp.py
  - tests/v1/entrypoints/openai/test_multi_api_servers.py
  - vllm/v1/engine/
  commands:
  - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
  - DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py
  - pytest -v -s entrypoints/llm/test_collective_rpc.py
  - pytest -v -s ./compile/test_basic_correctness.py
  - pytest -v -s ./compile/test_wrapper.py
@ -8,4 +8,6 @@ Please report security issues privately using [the vulnerability submission form

---

Please see the [Security Guide in the vLLM documentation](https://docs.vllm.ai/en/latest/usage/security.html) for more information on vLLM's security assumptions and recommendations.

Please see [PyTorch's Security Policy](https://github.com/pytorch/pytorch/blob/main/SECURITY.md) for more information and recommendations on how to securely interact with models.
@ -64,6 +64,12 @@ become available.
      <td style="text-align: center;">✅</td>
      <td><code>lmms-lab/LLaVA-OneVision-Data</code>, <code>Aeala/ShareGPT_Vicuna_unfiltered</code></td>
    </tr>
    <tr>
      <td><strong>Custom</strong></td>
      <td style="text-align: center;">✅</td>
      <td style="text-align: center;">✅</td>
      <td>Local file: <code>data.jsonl</code></td>
    </tr>
  </tbody>
</table>
@ -124,6 +130,38 @@ P99 ITL (ms): 8.39
==================================================
```

### Custom Dataset

If the dataset you want to benchmark is not yet supported in vLLM, you can still benchmark it using `CustomDataset`. Your data needs to be in `.jsonl` format and needs to have a "prompt" field per entry, e.g., data.jsonl

```
{"prompt": "What is the capital of India?"}
{"prompt": "What is the capital of Iran?"}
{"prompt": "What is the capital of China?"}
```

```bash
# start server
VLLM_USE_V1=1 vllm serve meta-llama/Llama-3.1-8B-Instruct --disable-log-requests
```

```bash
# run benchmarking script
python3 benchmarks/benchmark_serving.py --port 9001 --save-result --save-detailed \
    --backend vllm \
    --model meta-llama/Llama-3.1-8B-Instruct \
    --endpoint /v1/completions \
    --dataset-name custom \
    --dataset-path <path-to-your-data-jsonl> \
    --custom-skip-chat-template \
    --num-prompts 80 \
    --max-concurrency 1 \
    --temperature=0.3 \
    --top-p=0.75 \
    --result-dir "./log/"
```

You can skip applying the chat template if your data already has it by using `--custom-skip-chat-template`.
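For readers who want to drive the new `CustomDataset` (added in `benchmarks/benchmark_dataset.py` later in this diff) programmatically rather than through `benchmark_serving.py`, here is a minimal, hypothetical usage sketch. It assumes the benchmarks directory is on `PYTHONPATH` and that a HuggingFace tokenizer can be loaded locally; names not present in this diff (the model name and `data.jsonl` path) are placeholders.

```python
# Illustrative sketch only; CustomDataset and SampleRequest come from this diff,
# the tokenizer/model name and data.jsonl path are placeholder assumptions.
from transformers import AutoTokenizer

from benchmark_dataset import CustomDataset

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

# One JSON object with a "prompt" field per line, as in the example above.
dataset = CustomDataset(dataset_path="data.jsonl")

requests = dataset.sample(
    tokenizer=tokenizer,
    num_requests=4,
    output_len=256,
    skip_chat_template=False,  # apply the model's chat template to each prompt
)
for req in requests:
    print(req.prompt_len, req.expected_output_len)
```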
### VisionArena Benchmark for Vision Language Models

```bash
@ -146,9 +184,9 @@ python3 vllm/benchmarks/benchmark_serving.py \

``` bash
VLLM_USE_V1=1 vllm serve meta-llama/Meta-Llama-3-8B-Instruct \
    --ngram_prompt_lookup_min 2 \
    --ngram-prompt-lookup-max 5 \
    --speculative_config '{"model": "[ngram]", "num_speculative_tokens": 5}
    --speculative-config $'{"method": "ngram",
        "num_speculative_tokens": 5, "prompt_lookup_max": 5,
        "prompt_lookup_min": 2}'
```

``` bash
@ -203,6 +241,16 @@ python3 vllm/benchmarks/benchmark_serving.py \
    --seed 42
```

**`philschmid/mt-bench`**

``` bash
python3 vllm/benchmarks/benchmark_serving.py \
    --model Qwen/QwQ-32B \
    --dataset-name hf \
    --dataset-path philschmid/mt-bench \
    --num-prompts 80
```

### Running With Sampling Parameters

When using OpenAI-compatible backends such as `vllm`, optional sampling
@ -273,9 +321,9 @@ python3 vllm/benchmarks/benchmark_throughput.py \
    --output-len=100 \
    --num-prompts=2048 \
    --async-engine \
    --ngram_prompt_lookup_min=2 \
    --ngram-prompt-lookup-max=5 \
    --speculative_config '{"model": "[ngram]", "num_speculative_tokens": 5}
    --speculative-config $'{"method": "ngram",
        "num_speculative_tokens": 5, "prompt_lookup_max": 5,
        "prompt_lookup_min": 2}'
```

```
@ -324,7 +324,7 @@ async def async_request_openai_completions(

                        most_recent_timestamp = timestamp
                        generated_text += text or ""
                    elif usage := data.get("usage"):
                    if usage := data.get("usage"):
                        output.output_tokens = usage.get("completion_tokens")
            if first_chunk_received:
                output.success = True

@ -611,6 +611,7 @@ ASYNC_REQUEST_FUNCS = {
    "tensorrt-llm": async_request_trt_llm,
    "scalellm": async_request_openai_completions,
    "sglang": async_request_openai_completions,
    "llama.cpp": async_request_openai_completions,
}

OPENAI_COMPATIBLE_BACKENDS = [
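As far as can be inferred from the first hunk above, the `elif` → `if` change matters because a streamed chunk may carry `usage` without any entries in `choices` (typically the final chunk), so the usage check should not depend on the choices branch. A small, self-contained Python sketch of that control flow; the chunk contents here are assumptions for illustration, not captured from a real server:

```python
# Hypothetical stream chunks: a normal token chunk and a final usage-only chunk.
chunks = [
    {"choices": [{"text": "Hello"}]},
    {"choices": [], "usage": {"completion_tokens": 42}},
]

completion_tokens = None
for data in chunks:
    if choices := data.get("choices"):
        pass  # token handling would go here
    # An independent `if` (not `elif`) still records usage when choices is empty.
    if usage := data.get("usage"):
        completion_tokens = usage.get("completion_tokens")

print(completion_tokens)  # 42
```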
@ -9,9 +9,6 @@ generation. Supported dataset types include:
- BurstGPT
- HuggingFace
- VisionArena

TODO: Implement CustomDataset to parse a JSON file and convert its contents into
SampleRequest instances, similar to the approach used in ShareGPT.
"""

import base64
@ -442,6 +439,97 @@ class ShareGPTDataset(BenchmarkDataset):
|
||||
return samples
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Custom Dataset Implementation
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class CustomDataset(BenchmarkDataset):
|
||||
"""
|
||||
Implements the Custom dataset. Loads data from a JSONL file and generates
|
||||
sample requests from its entries, e.g.,
|
||||
```
|
||||
{"prompt": "What is the capital of India?"}
|
||||
{"prompt": "What is the capital of Iran?"}
|
||||
{"prompt": "What is the capital of China?"}
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.load_data()
|
||||
|
||||
def load_data(self) -> None:
|
||||
if self.dataset_path is None:
|
||||
raise ValueError("dataset_path must be provided for loading data.")
|
||||
|
||||
# self.data will be a list of dictionaries
|
||||
# e.g., [{"prompt": "What is the capital of India?"}, ...]
|
||||
# This will be the standardized format which load_data()
|
||||
# has to convert into depending on the filetype of dataset_path.
|
||||
# sample() will assume this standardized format of self.data
|
||||
self.data = []
|
||||
|
||||
# Load the JSONL file
|
||||
if self.dataset_path.endswith(".jsonl"):
|
||||
jsonl_data = pd.read_json(path_or_buf=self.dataset_path, lines=True)
|
||||
|
||||
# check if the JSONL file has a 'prompt' column
|
||||
if "prompt" not in jsonl_data.columns:
|
||||
raise ValueError("JSONL file must contain a 'prompt' column.")
|
||||
|
||||
# Convert each row to a dictionary and append to self.data
|
||||
# This will convert the DataFrame to a list of dictionaries
|
||||
# where each dictionary corresponds to a row in the DataFrame.
|
||||
# This is the standardized format we want for self.data
|
||||
for _, row in jsonl_data.iterrows():
|
||||
self.data.append(row.to_dict())
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Only JSONL format is supported for CustomDataset."
|
||||
)
|
||||
|
||||
random.seed(self.random_seed)
|
||||
random.shuffle(self.data)
|
||||
|
||||
def sample(
|
||||
self,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
num_requests: int,
|
||||
lora_path: Optional[str] = None,
|
||||
max_loras: Optional[int] = None,
|
||||
output_len: Optional[int] = None,
|
||||
enable_multimodal_chat: bool = False,
|
||||
skip_chat_template: bool = False,
|
||||
**kwargs,
|
||||
) -> list:
|
||||
sampled_requests = []
|
||||
for item in self.data:
|
||||
if len(sampled_requests) >= num_requests:
|
||||
break
|
||||
prompt = item["prompt"]
|
||||
|
||||
# apply template
|
||||
if not skip_chat_template:
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": prompt}],
|
||||
add_generation_prompt=True,
|
||||
tokenize=False,
|
||||
)
|
||||
|
||||
prompt_len = len(tokenizer(prompt).input_ids)
|
||||
sampled_requests.append(
|
||||
SampleRequest(
|
||||
prompt=prompt,
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=output_len,
|
||||
)
|
||||
)
|
||||
self.maybe_oversample_requests(sampled_requests, num_requests)
|
||||
|
||||
return sampled_requests
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Sonnet Dataset Implementation
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
@ -60,6 +60,7 @@ from benchmark_dataset import (
|
||||
ASRDataset,
|
||||
BurstGPTDataset,
|
||||
ConversationDataset,
|
||||
CustomDataset,
|
||||
HuggingFaceDataset,
|
||||
InstructCoderDataset,
|
||||
MTBenchDataset,
|
||||
@ -627,7 +628,16 @@ def main(args: argparse.Namespace):
|
||||
"'--dataset-path' if required."
|
||||
)
|
||||
|
||||
if args.dataset_name == "sonnet":
|
||||
if args.dataset_name == "custom":
|
||||
dataset = CustomDataset(dataset_path=args.dataset_path)
|
||||
input_requests = dataset.sample(
|
||||
num_requests=args.num_prompts,
|
||||
tokenizer=tokenizer,
|
||||
output_len=args.custom_output_len,
|
||||
skip_chat_template=args.custom_skip_chat_template,
|
||||
)
|
||||
|
||||
elif args.dataset_name == "sonnet":
|
||||
dataset = SonnetDataset(dataset_path=args.dataset_path)
|
||||
# For the "sonnet" dataset, formatting depends on the backend.
|
||||
if args.backend == "openai-chat":
|
||||
@ -762,6 +772,10 @@ def main(args: argparse.Namespace):
|
||||
if "temperature" not in sampling_params:
|
||||
sampling_params["temperature"] = 0.0 # Default to greedy decoding.
|
||||
|
||||
if args.backend == "llama.cpp":
|
||||
# Disable prompt caching in llama.cpp backend
|
||||
sampling_params["cache_prompt"] = False
|
||||
|
||||
# Avoid GC processing "static" data - reduce pause times.
|
||||
gc.collect()
|
||||
gc.freeze()
|
||||
@ -834,6 +848,8 @@ def main(args: argparse.Namespace):
|
||||
]:
|
||||
if field in result_json:
|
||||
del result_json[field]
|
||||
if field in benchmark_result:
|
||||
del benchmark_result[field]
|
||||
|
||||
# Save to file
|
||||
base_model_id = model_id.split("/")[-1]
|
||||
@ -846,6 +862,7 @@ def main(args: argparse.Namespace):
|
||||
if args.result_filename:
|
||||
file_name = args.result_filename
|
||||
if args.result_dir:
|
||||
os.makedirs(args.result_dir, exist_ok=True)
|
||||
file_name = os.path.join(args.result_dir, file_name)
|
||||
with open(
|
||||
file_name, mode="a+" if args.append_result else "w", encoding="utf-8"
|
||||
@ -886,7 +903,7 @@ if __name__ == "__main__":
|
||||
"--dataset-name",
|
||||
type=str,
|
||||
default="sharegpt",
|
||||
choices=["sharegpt", "burstgpt", "sonnet", "random", "hf"],
|
||||
choices=["sharegpt", "burstgpt", "sonnet", "random", "hf", "custom"],
|
||||
help="Name of the dataset to benchmark on.",
|
||||
)
|
||||
parser.add_argument(
|
||||
@ -1056,6 +1073,19 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
# group for dataset specific arguments
|
||||
custom_group = parser.add_argument_group("custom dataset options")
|
||||
custom_group.add_argument(
|
||||
"--custom-output-len",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Number of output tokens per request, used only for custom dataset.",
|
||||
)
|
||||
custom_group.add_argument(
|
||||
"--custom-skip-chat-template",
|
||||
action="store_true",
|
||||
help="Skip applying chat template to prompt, used only for custom dataset.",
|
||||
)
|
||||
|
||||
sonnet_group = parser.add_argument_group("sonnet dataset options")
|
||||
sonnet_group.add_argument(
|
||||
"--sonnet-input-len",
|
||||
|
||||
benchmarks/kernels/bench_fp8_gemm.py (new file, 222 lines)
@ -0,0 +1,222 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import argparse
|
||||
import copy
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from weight_shapes import WEIGHT_SHAPES
|
||||
|
||||
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
|
||||
from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size"],
|
||||
x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384],
|
||||
x_log=False,
|
||||
line_arg="provider",
|
||||
line_vals=[
|
||||
"torch-bf16",
|
||||
# "fp8-tensor-w-token-a",
|
||||
"fp8-tensor-w-tensor-a",
|
||||
"fp8-channel-w-token-a",
|
||||
# "fp8-channel-w-tensor-a",
|
||||
# "fp8-tensor-w-token-a-noquant",
|
||||
"fp8-tensor-w-tensor-a-noquant",
|
||||
"fp8-channel-w-token-a-noquant",
|
||||
# "fp8-channel-w-tensor-a-noquant",
|
||||
],
|
||||
line_names=[
|
||||
"torch-bf16",
|
||||
# "fp8-tensor-w-token-a",
|
||||
"fp8-tensor-w-tensor-a",
|
||||
"fp8-channel-w-token-a",
|
||||
# "fp8-channel-w-tensor-a",
|
||||
# "fp8-tensor-w-token-a-noquant",
|
||||
"fp8-tensor-w-tensor-a-noquant",
|
||||
"fp8-channel-w-token-a-noquant",
|
||||
# "fp8-channel-w-tensor-a-noquant",
|
||||
],
|
||||
ylabel="TFLOP/s (larger is better)",
|
||||
plot_name="BF16 vs FP8 GEMMs",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(batch_size, provider, N, K):
|
||||
M = batch_size
|
||||
device = "cuda"
|
||||
dtype = torch.bfloat16
|
||||
|
||||
# Create input tensors
|
||||
a = torch.randn((M, K), device=device, dtype=dtype)
|
||||
b = torch.randn((N, K), device=device, dtype=dtype)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if "torch-bf16" in provider:
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: torch.nn.functional.linear(a, b), quantiles=quantiles
|
||||
)
|
||||
|
||||
elif "fp8" in provider:
|
||||
# Weights are always quantized ahead of time
|
||||
if "noquant" in provider:
|
||||
# For no quantization, we just measure the GEMM
|
||||
if "tensor-w-token-a" in provider:
|
||||
# Dynamic per-token quant for A, per-tensor quant for B
|
||||
b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b)
|
||||
assert scale_b_fp8.numel() == 1
|
||||
a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(
|
||||
a, use_per_token_if_dynamic=True
|
||||
)
|
||||
|
||||
def run_quant():
|
||||
return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)
|
||||
|
||||
elif "tensor-w-tensor-a" in provider:
|
||||
# Static per-tensor quantization with fixed scales
|
||||
# for both A and B
|
||||
scale_a = torch.tensor([1.0], device=device, dtype=torch.float32)
|
||||
scale_b = torch.tensor([1.0], device=device, dtype=torch.float32)
|
||||
b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
|
||||
assert scale_b_fp8.numel() == 1
|
||||
a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
|
||||
|
||||
def run_quant():
|
||||
return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)
|
||||
|
||||
elif "channel-w-token-a" in provider:
|
||||
# Static per-channel quantization for weights, per-token
|
||||
# quant for A
|
||||
scale_b = torch.tensor((N,), device=device, dtype=torch.float32)
|
||||
b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
|
||||
scale_b_fp8 = scale_b_fp8.expand(N).contiguous()
|
||||
assert scale_b_fp8.numel() == N
|
||||
a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(
|
||||
a, use_per_token_if_dynamic=True
|
||||
)
|
||||
|
||||
def run_quant():
|
||||
return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)
|
||||
|
||||
elif "channel-w-tensor-a" in provider:
|
||||
# Static per-channel quantization for weights, per-tensor
|
||||
# quant for A
|
||||
scale_a = torch.tensor([1.0], device=device, dtype=torch.float32)
|
||||
scale_b = torch.tensor((N,), device=device, dtype=torch.float32)
|
||||
b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
|
||||
scale_b_fp8 = scale_b_fp8.expand(N).contiguous()
|
||||
assert scale_b_fp8.numel() == N
|
||||
a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
|
||||
|
||||
def run_quant():
|
||||
return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)
|
||||
|
||||
else:
|
||||
# In these cases, we quantize the activations during the GEMM call
|
||||
if "tensor-w-token-a" in provider:
|
||||
# Dynamic per-token quant for A, per-tensor quant for B
|
||||
b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b)
|
||||
assert scale_b_fp8.numel() == 1
|
||||
|
||||
def run_quant():
|
||||
a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(
|
||||
a, use_per_token_if_dynamic=True
|
||||
)
|
||||
return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)
|
||||
|
||||
elif "tensor-w-tensor-a" in provider:
|
||||
# Static per-tensor quantization with fixed scales
|
||||
# for both A and B
|
||||
scale_a = torch.tensor([1.0], device=device, dtype=torch.float32)
|
||||
scale_b = torch.tensor([1.0], device=device, dtype=torch.float32)
|
||||
b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
|
||||
assert scale_b_fp8.numel() == 1
|
||||
|
||||
def run_quant():
|
||||
a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
|
||||
return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)
|
||||
|
||||
elif "channel-w-token-a" in provider:
|
||||
# Static per-channel quantization for weights, per-token
|
||||
# quant for A
|
||||
scale_b = torch.tensor((N,), device=device, dtype=torch.float32)
|
||||
b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
|
||||
scale_b_fp8 = scale_b_fp8.expand(N).contiguous()
|
||||
assert scale_b_fp8.numel() == N
|
||||
|
||||
def run_quant():
|
||||
a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(
|
||||
a, use_per_token_if_dynamic=True
|
||||
)
|
||||
return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)
|
||||
|
||||
elif "channel-w-tensor-a" in provider:
|
||||
# Static per-channel quantization for weights, per-tensor
|
||||
# quant for A
|
||||
scale_a = torch.tensor([1.0], device=device, dtype=torch.float32)
|
||||
scale_b = torch.tensor((N,), device=device, dtype=torch.float32)
|
||||
b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
|
||||
scale_b_fp8 = scale_b_fp8.expand(N).contiguous()
|
||||
assert scale_b_fp8.numel() == N
|
||||
|
||||
def run_quant():
|
||||
a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
|
||||
return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)
|
||||
|
||||
b_fp8 = b_fp8.t()
|
||||
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: run_quant(), quantiles=quantiles
|
||||
)
|
||||
|
||||
# Calculate TFLOP/s, two flops per multiply-add
|
||||
tflops = lambda ms: (2 * M * N * K) * 1e-12 / (ms * 1e-3)
|
||||
return tflops(ms), tflops(max_ms), tflops(min_ms)
|
||||
|
||||
|
||||
def prepare_shapes(args):
|
||||
KN_model_names = []
|
||||
models_tps = list(itertools.product(args.models, args.tp_sizes))
|
||||
for model, tp_size in models_tps:
|
||||
assert model in WEIGHT_SHAPES
|
||||
for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
|
||||
KN[tp_split_dim] = KN[tp_split_dim] // tp_size
|
||||
KN.append(model)
|
||||
KN_model_names.append(KN)
|
||||
return KN_model_names
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--models",
|
||||
nargs="+",
|
||||
type=str,
|
||||
default=["meta-llama/Llama-3.1-8B-Instruct"],
|
||||
choices=[*WEIGHT_SHAPES.keys()],
|
||||
help="List of models to benchmark",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tp-sizes",
|
||||
nargs="+",
|
||||
type=int,
|
||||
default=[1],
|
||||
help="List of tensor parallel sizes",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
KN_model_names = prepare_shapes(args)
|
||||
for K, N, model_name in KN_model_names:
|
||||
print(f"{model_name}, N={N} K={K}, BF16 vs FP8 GEMMs TFLOP/s:")
|
||||
benchmark.run(
|
||||
print_data=True,
|
||||
show_plots=True,
|
||||
save_path=f"bench_fp8_res_n{N}_k{K}",
|
||||
N=N,
|
||||
K=K,
|
||||
)
|
||||
|
||||
print("Benchmark finished!")
|
||||
@ -22,7 +22,7 @@ def benchmark_rope_kernels_multi_lora(
    seed: int,
    device: str,
    max_position: int = 8192,
    base: int = 10000,
    base: float = 10000,
) -> None:
    current_platform.seed_everything(seed)
    torch.set_default_device(device)
@ -48,4 +48,50 @@ WEIGHT_SHAPES = {
|
||||
([16384, 106496], 1),
|
||||
([53248, 16384], 0),
|
||||
],
|
||||
"meta-llama/Llama-3.1-8B-Instruct": [
|
||||
([4096, 6144], 1),
|
||||
([4096, 4096], 0),
|
||||
([4096, 28672], 1),
|
||||
([14336, 4096], 0),
|
||||
],
|
||||
"meta-llama/Llama-3.3-70B-Instruct": [
|
||||
([8192, 10240], 1),
|
||||
([8192, 8192], 0),
|
||||
([8192, 57344], 1),
|
||||
([28672, 8192], 0),
|
||||
],
|
||||
"mistralai/Mistral-Large-Instruct-2407": [
|
||||
([12288, 14336], 1),
|
||||
([12288, 12288], 0),
|
||||
([12288, 57344], 1),
|
||||
([28672, 12288], 0),
|
||||
],
|
||||
"Qwen/Qwen2.5-7B-Instruct": [
|
||||
([3584, 4608], 1),
|
||||
([3584, 3584], 0),
|
||||
([3584, 37888], 1),
|
||||
([18944, 3584], 0),
|
||||
],
|
||||
"Qwen/Qwen2.5-32B-Instruct": [
|
||||
([5120, 7168], 1),
|
||||
([5120, 5120], 0),
|
||||
([5120, 55296], 1),
|
||||
([27648, 5120], 0),
|
||||
],
|
||||
"Qwen/Qwen2.5-72B-Instruct": [
|
||||
([8192, 10240], 1),
|
||||
([8192, 8192], 0),
|
||||
([8192, 59136], 1),
|
||||
([29568, 8192], 0),
|
||||
],
|
||||
"deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [
|
||||
([2048, 3072], 1),
|
||||
([2048, 4096], 1),
|
||||
([2048, 2048], 0),
|
||||
([2048, 576], 0),
|
||||
([2048, 21888], 1),
|
||||
([10944, 2048], 0),
|
||||
([2048, 2816], 1),
|
||||
([1408, 2048], 0),
|
||||
],
|
||||
}
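To make the shape bookkeeping concrete, a small, purely illustrative Python sketch of how `prepare_shapes()` in the benchmark above splits one of the new Llama-3.1-8B entries across tensor-parallel ranks; the TP size of 2 is an arbitrary example, not taken from this diff:

```python
# Mirror of the per-entry logic in prepare_shapes(): divide the TP-split
# dimension of each [K, N] pair by the tensor-parallel size.
KN, tp_split_dim = [4096, 6144], 1  # first Llama-3.1-8B entry above
tp_size = 2  # example value

KN[tp_split_dim] //= tp_size
print(KN)  # [4096, 3072] -> each rank benchmarks K=4096, N=3072
```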
|
||||
|
||||
@ -13,14 +13,34 @@
|
||||
#include "dispatch_utils.h"
|
||||
#include "quantization/fp8/common.cuh"
|
||||
|
||||
#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx942__))
|
||||
#define __HIP__MI300_MI250__
|
||||
#if defined(__HIPCC__) && \
|
||||
(defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__))
|
||||
#define __HIP__GFX9__
|
||||
#endif
|
||||
|
||||
#if defined(__HIPCC__) && defined(__gfx942__)
|
||||
#define __HIP__MI300__
|
||||
#if defined(__HIPCC__) && (defined(__gfx942__) || defined(__gfx950__))
|
||||
#define __HIP__MI3XX__
|
||||
#endif
|
||||
|
||||
#if defined(__gfx950__)
|
||||
#define LDS_SIZE 160 * 1024
|
||||
#else
|
||||
#define LDS_SIZE 64 * 1024
|
||||
#endif
|
||||
|
||||
int get_lds_size() {
|
||||
static bool is_cached = false;
|
||||
static int result;
|
||||
if (is_cached == false) {
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
std::string device_arch = dprops->gcnArchName;
|
||||
size_t substring = device_arch.find("gfx95");
|
||||
result = (substring == std::string::npos ? 64 * 1024 : 160 * 1024);
|
||||
is_cached = true;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
#if defined(NDEBUG)
|
||||
#undef NDEBUG
|
||||
#include <assert.h>
|
||||
@ -267,7 +287,7 @@ torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
|
||||
V0 += (s.x + s.y); \
|
||||
}
|
||||
|
||||
#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support
|
||||
#if defined(__HIP__GFX9__) // TODO: Add NAVI support
|
||||
// This version targets cases where A[] fits LDS capacity
|
||||
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
||||
int UNRL, int N>
|
||||
@ -275,7 +295,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B,
|
||||
const scalar_t* __restrict__ A, scalar_t* C,
|
||||
const int _WvPrGrp, const int CuCount) {
|
||||
#if defined(__HIP__MI300__)
|
||||
constexpr int max_lds_len = LDS_SIZE / 2;
|
||||
#if defined(__HIP__MI3XX__)
|
||||
constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
|
||||
#else
|
||||
constexpr bool use_mfma = false;
|
||||
@ -295,13 +316,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
};
|
||||
|
||||
//----------------------------------------------------
|
||||
// Reserving 64 KB of LDS to have 1 WG / CU
|
||||
// Reserving 64/160 KB of LDS to have 1 WG / CU
|
||||
// Goal is to bring the activation matrix A to the LDS
|
||||
// and use it across the lifetime of the work group
|
||||
// TODO: When activation matrix is larger than 64 KB
|
||||
// then this is not going to work!
|
||||
//----------------------------------------------------
|
||||
__shared__ scalar_t s[1024 * 32];
|
||||
__shared__ scalar_t s[max_lds_len];
|
||||
|
||||
//----------------------------------------------------
|
||||
// Fetch the activation matrix to LDS
|
||||
@ -312,11 +333,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
// - Then the WG will move to another 8 K elements
|
||||
// TODO: Logic below will only work when K is multiple of 8
|
||||
//----------------------------------------------------
|
||||
for (uint32_t k = 0; k < min(K * N, 32 * 1024);
|
||||
for (uint32_t k = 0; k < min(K * N, max_lds_len);
|
||||
k += THRDS * WvPrGrp * A_CHUNK) {
|
||||
uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);
|
||||
|
||||
if (k_in >= min(K * N, 32 * 1024)) break;
|
||||
if (k_in >= min(K * N, max_lds_len)) break;
|
||||
|
||||
*((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
|
||||
}
|
||||
@ -517,7 +538,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
m += CuCount * _WvPrGrp * YTILE;
|
||||
}
|
||||
}
|
||||
#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
|
||||
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
|
||||
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
||||
int UNRL, int N>
|
||||
__global__ void wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B,
|
||||
@ -525,9 +546,9 @@ __global__ void wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B,
|
||||
const int _WvPrGrp, const int CuCount) {
|
||||
UNREACHABLE_CODE
|
||||
}
|
||||
#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
|
||||
#endif // defined(__HIP__GFX9__) TODO: Add NAVI support
|
||||
|
||||
#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support
|
||||
#if defined(__HIP__GFX9__) // TODO: Add NAVI support
|
||||
// This version targets cases where A[] marginally exceeds LDS capacity
|
||||
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
||||
int UNRL, int N>
|
||||
@ -535,7 +556,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
wvSplitK_hf_(const int K, const int M, const scalar_t* B,
|
||||
const scalar_t* __restrict__ A, scalar_t* C,
|
||||
const int _WvPrGrp, const int CuCount) {
|
||||
#if defined(__HIP__MI300__)
|
||||
constexpr int max_lds_len = LDS_SIZE / 2;
|
||||
#if defined(__HIP__MI3XX__)
|
||||
constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
|
||||
#else
|
||||
constexpr bool use_mfma = false;
|
||||
@ -561,7 +583,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
// TODO: When activation matrix is larger than 64 KB
|
||||
// then this is not going to work!
|
||||
//----------------------------------------------------
|
||||
__shared__ scalar_t s[1024 * 32];
|
||||
__shared__ scalar_t s[max_lds_len];
|
||||
|
||||
//----------------------------------------------------
|
||||
// Computation of columns that need to be committed to memory!
|
||||
@ -598,11 +620,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
// - Then the WG will move to another 8 K elements
|
||||
// TODO: Logic below will only work when K is multiple of 8
|
||||
//----------------------------------------------------
|
||||
for (uint32_t k = 0; k < min(K * N, 32 * 1024);
|
||||
for (uint32_t k = 0; k < min(K * N, max_lds_len);
|
||||
k += THRDS * WvPrGrp * A_CHUNK) {
|
||||
uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);
|
||||
|
||||
if (k_in >= min(K * N, 32 * 1024)) break;
|
||||
if (k_in >= min(K * N, max_lds_len)) break;
|
||||
|
||||
*((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
|
||||
}
|
||||
@ -686,7 +708,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
// Fetch A activation matrix in interleaved fashion from LDS or memory
|
||||
|
||||
for (int n = 0; n < N; n++) {
|
||||
if (k_ + K * n < 32 * 1024)
|
||||
if (k_ + K * n < max_lds_len)
|
||||
bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n])));
|
||||
else
|
||||
bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n])));
|
||||
@ -817,7 +839,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
}
|
||||
}
|
||||
|
||||
#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
|
||||
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
|
||||
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
||||
int UNRL, int N>
|
||||
__global__ void wvSplitK_hf_(const int K, const int M, const scalar_t* B,
|
||||
@ -825,9 +847,9 @@ __global__ void wvSplitK_hf_(const int K, const int M, const scalar_t* B,
|
||||
const int _WvPrGrp, const int CuCount) {
|
||||
UNREACHABLE_CODE
|
||||
}
|
||||
#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
|
||||
#endif // defined(__HIP__GFX9__) TODO: Add NAVI support
|
||||
|
||||
#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support
|
||||
#if defined(__HIP__GFX9__) // TODO: Add NAVI support
|
||||
// This version targets big A[] cases, where it is much larger than LDS capacity
|
||||
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
||||
int UNRL, int N>
|
||||
@ -835,7 +857,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
wvSplitK_hf_big_(const int K, const int M, const scalar_t* B,
|
||||
const scalar_t* __restrict__ A, scalar_t* C,
|
||||
const int _WvPrGrp, const int CuCount) {
|
||||
#if defined(__HIP__MI300__)
|
||||
constexpr int max_lds_len = LDS_SIZE / 2;
|
||||
#if defined(__HIP__MI3XX__)
|
||||
constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
|
||||
#else
|
||||
constexpr bool use_mfma = false;
|
||||
@ -855,13 +878,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
};
|
||||
|
||||
//----------------------------------------------------
|
||||
// Reserving 64 KB of LDS to have 1 WG / CU
|
||||
// Reserving 64/160 KB of LDS to have 1 WG / CU
|
||||
// Goal is to bring the activation matrix A to the LDS
|
||||
// and use it across the lifetime of the work group
|
||||
// TODO: When activation matrix is larger than 64 KB
|
||||
// then this is not going to work!
|
||||
//----------------------------------------------------
|
||||
__shared__ scalar_t s[1024 * 32];
|
||||
__shared__ scalar_t s[max_lds_len];
|
||||
|
||||
//----------------------------------------------------
|
||||
// Computation of columns that need to be committed to memory!
|
||||
@ -902,11 +925,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
//----------------------------------------------------
|
||||
#define PCML
|
||||
#ifndef PCML
|
||||
for (uint32_t k = 0; k < min(K * N, 32 * 1024);
|
||||
for (uint32_t k = 0; k < min(K * N, max_lds_len);
|
||||
k += THRDS * WvPrGrp * A_CHUNK) {
|
||||
uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);
|
||||
|
||||
if (k_in >= min(K * N, 32 * 1024)) break;
|
||||
if (k_in >= min(K * N, max_lds_len)) break;
|
||||
|
||||
*((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
|
||||
}
|
||||
@ -916,7 +939,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
#define TUC (THRDS * UNRL * A_CHUNK)
|
||||
uint32_t kBase = 0;
|
||||
// find biggest k size that fits in LDS
|
||||
uint32_t kFit = (32 * 1024) / N;
|
||||
uint32_t kFit = (max_lds_len) / N;
|
||||
// kFit = (kFit%TWC==0) ? kFit : (kFit-kFit%TWC+TWC); //round up to multiple
|
||||
// of TUC
|
||||
kFit = (kFit % TUC == 0)
|
||||
@ -1164,7 +1187,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
}
|
||||
}
|
||||
}
|
||||
#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
|
||||
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
|
||||
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
||||
int UNRL, int N>
|
||||
__global__ void wvSplitK_hf_big_(const int K, const int M, const scalar_t* B,
|
||||
@ -1172,7 +1195,7 @@ __global__ void wvSplitK_hf_big_(const int K, const int M, const scalar_t* B,
|
||||
const int _WvPrGrp, const int CuCount) {
|
||||
UNREACHABLE_CODE
|
||||
}
|
||||
#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
|
||||
#endif // defined(__HIP__GFX9__) TODO: Add NAVI support
|
||||
|
||||
int mindiv(int N, int div1, int div2) {
|
||||
int nPrRnd = div1 * div2;
|
||||
@ -1222,17 +1245,18 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
const int max_lds_len = get_lds_size() / 2;
|
||||
|
||||
#define WVSPLITK(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \
|
||||
_N) \
|
||||
{ \
|
||||
dim3 block(64, _WvPrGrp); \
|
||||
if ((K_in * N_in <= 32 * 1024) && (M_in % _YTILEs == 0)) { \
|
||||
if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \
|
||||
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \
|
||||
wvSplitK_hf_sml_<fptype, 64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \
|
||||
<<<grid, block, 0, stream>>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \
|
||||
CuCount); \
|
||||
} else if (K_in * N_in <= 32 * 1024 * 1.2) { \
|
||||
} else if (K_in * N_in <= max_lds_len * 1.2) { \
|
||||
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \
|
||||
wvSplitK_hf_<fptype, 64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \
|
||||
<<<grid, block, 0, stream>>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \
|
||||
@ -1272,7 +1296,7 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
|
||||
return out_c;
|
||||
}
|
||||
|
||||
#if defined(__HIP__MI300__) // TODO: Add NAVI support
|
||||
#if defined(__HIP__MI3XX__) // TODO: Add NAVI support
|
||||
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
|
||||
int A_CHUNK, int UNRL, int N>
|
||||
__global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
@ -1281,6 +1305,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
const float* __restrict__ s_A,
|
||||
const float* __restrict__ s_B, const int _WvPrGrp,
|
||||
const int CuCount) {
|
||||
constexpr int max_lds_len = LDS_SIZE;
|
||||
using scalar8 =
|
||||
__attribute__((__vector_size__((A_CHUNK / 4) * sizeof(float)))) float;
|
||||
using intx2 = __attribute__((__vector_size__(2 * sizeof(int)))) int;
|
||||
@ -1296,10 +1321,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
scalar8 h8;
|
||||
};
|
||||
|
||||
__shared__ fp8_t s[1024 * 64];
|
||||
__shared__ fp8_t s[max_lds_len];
|
||||
|
||||
for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
|
||||
k < min(K * N, 64 * 1024); k += THRDS * WvPrGrp * A_CHUNK) {
|
||||
k < min(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
|
||||
*((bigType*)(&s[k])) = *((bigType*)(&A[k]));
|
||||
}
|
||||
__syncthreads();
|
||||
@ -1436,7 +1461,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
m += CuCount * _WvPrGrp * YTILE;
|
||||
}
|
||||
}
|
||||
#else // !defined(__HIP__MI300__) TODO: Add NAVI support
|
||||
#else // !defined(__HIP__MI3XX__) TODO: Add NAVI support
|
||||
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
|
||||
int A_CHUNK, int UNRL, int N>
|
||||
__global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M,
|
||||
@ -1446,9 +1471,9 @@ __global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M,
|
||||
const int _WvPrGrp, const int CuCount) {
|
||||
UNREACHABLE_CODE
|
||||
}
|
||||
#endif // defined(__HIP__MI300__) TODO: Add NAVI support
|
||||
#endif // defined(__HIP__MI3XX__) TODO: Add NAVI support
|
||||
|
||||
#if defined(__HIP__MI300__) // TODO: Add NAVI support
|
||||
#if defined(__HIP__MI3XX__) // TODO: Add NAVI support
|
||||
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
|
||||
int A_CHUNK, int UNRL, int N>
|
||||
__global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
@ -1456,6 +1481,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
const fp8_t* __restrict__ A, scalar_t* C,
|
||||
const float* __restrict__ s_A, const float* __restrict__ s_B,
|
||||
const int _WvPrGrp, const int CuCount) {
|
||||
constexpr int max_lds_len = LDS_SIZE;
|
||||
using scalar8 =
|
||||
__attribute__((__vector_size__((A_CHUNK / 4) * sizeof(float)))) float;
|
||||
using intx2 = __attribute__((__vector_size__(2 * sizeof(int)))) int;
|
||||
@ -1471,10 +1497,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
scalar8 h8;
|
||||
};
|
||||
|
||||
__shared__ fp8_t s[1024 * 64];
|
||||
__shared__ fp8_t s[max_lds_len];
|
||||
|
||||
for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
|
||||
k < min(K * N, 64 * 1024); k += THRDS * WvPrGrp * A_CHUNK) {
|
||||
k < min(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
|
||||
*((bigType*)(&s[k])) = *((bigType*)(&A[k]));
|
||||
}
|
||||
__syncthreads();
|
||||
@ -1517,7 +1543,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
uint32_t k_ = k + threadIdx.x * A_CHUNK;
|
||||
if (k_ >= K) break;
|
||||
for (int n = 0; n < N; n++) {
|
||||
if (k_ + K * n < 64 * 1024)
|
||||
if (k_ + K * n < max_lds_len)
|
||||
bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n])));
|
||||
else
|
||||
bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n])));
|
||||
@ -1608,7 +1634,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
m += CuCount * _WvPrGrp * YTILE;
|
||||
}
|
||||
}
|
||||
#else // !defined(__HIP__MI300__) TODO: Add NAVI support
|
||||
#else // !defined(__HIP__MI3XX__) TODO: Add NAVI support
|
||||
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
|
||||
int A_CHUNK, int UNRL, int N>
|
||||
__global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M,
|
||||
@ -1618,7 +1644,7 @@ __global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M,
|
||||
const int CuCount) {
|
||||
UNREACHABLE_CODE
|
||||
}
|
||||
#endif // defined(__HIP__MI300__) TODO: Add NAVI support
|
||||
#endif // defined(__HIP__MI3XX__) TODO: Add NAVI support
|
||||
|
||||
void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
|
||||
at::Tensor& scale_a, at::Tensor& scale_b,
|
||||
@ -1638,12 +1664,13 @@ void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
|
||||
dim3 grid(CuCount);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
const int max_lds_len = get_lds_size();
|
||||
|
||||
#define WVSPLITKQ(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \
|
||||
_N) \
|
||||
{ \
|
||||
dim3 block(64, _WvPrGrp); \
|
||||
if ((K_in * N_in <= 64 * 1024) && (M_in % _YTILEs == 0)) { \
|
||||
if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \
|
||||
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \
|
||||
wvSplitKQ_hf_sml_<fptype, fp8_t, 64, _YTILEs, _WvPrGrp, 16, _UNRLs, _N> \
|
||||
<<<grid, block, 0, stream>>>(K_in, Kp_in, M_in, a_ptr, b_ptr, c_ptr, \
|
||||
|
||||
@ -12,6 +12,7 @@ nav:
  - User Guide: usage/README.md
  - Developer Guide: contributing/README.md
  - API Reference: api/README.md
  - CLI Reference: cli/README.md
  - Timeline:
    - Roadmap: https://roadmap.vllm.ai
    - Releases: https://github.com/vllm-project/vllm/releases

@ -56,6 +57,8 @@ nav:
    - Contents:
      - glob: api/vllm/*
        preserve_directory_names: true
  - CLI Reference:
    - Summary: cli/README.md
  - Community:
    - community/*
    - Blog: https://blog.vllm.ai
@ -12,8 +12,8 @@
<p style="text-align:center">
<script async defer src="https://buttons.github.io/buttons.js"></script>
<a class="github-button" href="https://github.com/vllm-project/vllm" data-show-count="true" data-size="large" aria-label="Star">Star</a>
<a class="github-button" href="https://github.com/vllm-project/vllm/subscription" data-icon="octicon-eye" data-size="large" aria-label="Watch">Watch</a>
<a class="github-button" href="https://github.com/vllm-project/vllm/fork" data-icon="octicon-repo-forked" data-size="large" aria-label="Fork">Fork</a>
<a class="github-button" href="https://github.com/vllm-project/vllm/subscription" data-show-count="true" data-icon="octicon-eye" data-size="large" aria-label="Watch">Watch</a>
<a class="github-button" href="https://github.com/vllm-project/vllm/fork" data-show-count="true" data-icon="octicon-repo-forked" data-size="large" aria-label="Fork">Fork</a>
</p>

vLLM is a fast and easy-to-use library for LLM inference and serving.
docs/cli/README.md (new file, 179 lines)
@ -0,0 +1,179 @@
|
||||
# vLLM CLI Guide
|
||||
|
||||
The vllm command-line tool is used to run and manage vLLM models. You can start by viewing the help message with:
|
||||
|
||||
```
|
||||
vllm --help
|
||||
```
|
||||
|
||||
Available Commands:
|
||||
|
||||
```
|
||||
vllm {chat,complete,serve,bench,collect-env,run-batch}
|
||||
```
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [serve](#serve)
|
||||
- [chat](#chat)
|
||||
- [complete](#complete)
|
||||
- [bench](#bench)
|
||||
- [latency](#latency)
|
||||
- [serve](#serve-1)
|
||||
- [throughput](#throughput)
|
||||
- [collect-env](#collect-env)
|
||||
- [run-batch](#run-batch)
|
||||
- [More Help](#more-help)
|
||||
|
||||
## serve
|
||||
|
||||
Start the vLLM OpenAI Compatible API server.
|
||||
|
||||
Examples:
|
||||
|
||||
```bash
|
||||
# Start with a model
|
||||
vllm serve meta-llama/Llama-2-7b-hf
|
||||
|
||||
# Specify the port
|
||||
vllm serve meta-llama/Llama-2-7b-hf --port 8100
|
||||
|
||||
# Check with --help for more options
|
||||
# To list all groups
|
||||
vllm serve --help=listgroup
|
||||
|
||||
# To view an argument group
|
||||
vllm serve --help=ModelConfig
|
||||
|
||||
# To view a single argument
|
||||
vllm serve --help=max-num-seqs
|
||||
|
||||
# To search by keyword
|
||||
vllm serve --help=max
|
||||
```
|
||||
|
||||
## chat
|
||||
|
||||
Generate chat completions via the running API server.
|
||||
|
||||
Examples:
|
||||
|
||||
```bash
|
||||
# Directly connect to localhost API without arguments
|
||||
vllm chat
|
||||
|
||||
# Specify API url
|
||||
vllm chat --url http://{vllm-serve-host}:{vllm-serve-port}/v1
|
||||
|
||||
# Quick chat with a single prompt
|
||||
vllm chat --quick "hi"
|
||||
```
|
||||
|
||||
## complete
|
||||
|
||||
Generate text completions based on the given prompt via the running API server.
|
||||
|
||||
Examples:
|
||||
|
||||
```bash
|
||||
# Directly connect to localhost API without arguments
|
||||
vllm complete
|
||||
|
||||
# Specify API url
|
||||
vllm complete --url http://{vllm-serve-host}:{vllm-serve-port}/v1
|
||||
|
||||
# Quick complete with a single prompt
|
||||
vllm complete --quick "The future of AI is"
|
||||
```
|
||||
|
||||
## bench
|
||||
|
||||
Run benchmark tests for latency, online serving throughput, and offline inference throughput.
|
||||
|
||||
Available Commands:
|
||||
|
||||
```bash
|
||||
vllm bench {latency, serve, throughput}
|
||||
```
|
||||
|
||||
### latency
|
||||
|
||||
Benchmark the latency of a single batch of requests.
|
||||
|
||||
Example:
|
||||
|
||||
```bash
|
||||
vllm bench latency \
|
||||
--model meta-llama/Llama-3.2-1B-Instruct \
|
||||
--input-len 32 \
|
||||
--output-len 1 \
|
||||
--enforce-eager \
|
||||
--load-format dummy
|
||||
```
|
||||
|
||||
### serve
|
||||
|
||||
Benchmark the online serving throughput.
|
||||
|
||||
Example:
|
||||
|
||||
```bash
|
||||
vllm bench serve \
|
||||
--model meta-llama/Llama-3.2-1B-Instruct \
|
||||
--host server-host \
|
||||
--port server-port \
|
||||
--random-input-len 32 \
|
||||
--random-output-len 4 \
|
||||
--num-prompts 5
|
||||
```
|
||||
|
||||
### throughput
|
||||
|
||||
Benchmark offline inference throughput.
|
||||
|
||||
Example:
|
||||
|
||||
```bash
|
||||
vllm bench throughput \
|
||||
--model meta-llama/Llama-3.2-1B-Instruct \
|
||||
--input-len 32 \
|
||||
--output-len 1 \
|
||||
--enforce-eager \
|
||||
--load-format dummy
|
||||
```
|
||||
|
||||
## collect-env
|
||||
|
||||
Start collecting environment information.
|
||||
|
||||
```bash
|
||||
vllm collect-env
|
||||
```
|
||||
|
||||
## run-batch
|
||||
|
||||
Run batch prompts and write results to file.
|
||||
|
||||
Examples:
|
||||
|
||||
```bash
|
||||
# Running with a local file
|
||||
vllm run-batch \
|
||||
-i offline_inference/openai_batch/openai_example_batch.jsonl \
|
||||
-o results.jsonl \
|
||||
--model meta-llama/Meta-Llama-3-8B-Instruct
|
||||
|
||||
# Using remote file
|
||||
vllm run-batch \
|
||||
-i https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai_batch/openai_example_batch.jsonl \
|
||||
-o results.jsonl \
|
||||
--model meta-llama/Meta-Llama-3-8B-Instruct
|
||||
```
|
||||
|
||||
## More Help
|
||||
|
||||
For detailed options of any subcommand, use:
|
||||
|
||||
```bash
|
||||
vllm <subcommand> --help
|
||||
```
|
||||
@ -29,20 +29,68 @@ See <gh-file:LICENSE>.
|
||||
Depending on the kind of development you'd like to do (e.g. Python, CUDA), you can choose to build vLLM with or without compilation.
|
||||
Check out the [building from source][build-from-source] documentation for details.
|
||||
|
||||
### Building the docs
|
||||
### Building the docs with MkDocs
|
||||
|
||||
Install the dependencies:
|
||||
#### Introduction to MkDocs
|
||||
|
||||
[MkDocs](https://github.com/mkdocs/mkdocs) is a fast, simple and downright gorgeous static site generator that's geared towards building project documentation. Documentation source files are written in Markdown, and configured with a single YAML configuration file.
|
||||
|
||||
#### Install MkDocs and Plugins
|
||||
|
||||
Install MkDocs along with the [plugins](https://github.com/vllm-project/vllm/blob/main/mkdocs.yaml) used in the vLLM documentation, as well as required dependencies:
|
||||
|
||||
```bash
|
||||
pip install -r requirements/docs.txt
|
||||
```
|
||||
|
||||
Start the autoreloading MkDocs server:
|
||||
!!! note
|
||||
Ensure that your Python version is compatible with the plugins (e.g., `mkdocs-awesome-nav` requires Python 3.10+)
|
||||
|
||||
#### Verify Installation
|
||||
|
||||
Confirm that MkDocs is correctly installed:
|
||||
|
||||
```bash
|
||||
mkdocs --version
|
||||
```
|
||||
|
||||
Example output:
|
||||
|
||||
```console
|
||||
mkdocs, version 1.6.1 from /opt/miniconda3/envs/mkdoc/lib/python3.10/site-packages/mkdocs (Python 3.10)
|
||||
```
|
||||
|
||||
#### Clone the `vLLM` repository
|
||||
|
||||
```bash
|
||||
git clone https://github.com/vllm-project/vllm.git
|
||||
cd vllm
|
||||
```
|
||||
|
||||
#### Start the Development Server
|
||||
|
||||
MkDocs comes with a built-in dev-server that lets you preview your documentation as you work on it. Make sure you're in the same directory as the `mkdocs.yaml` configuration file, and then start the server by running the `mkdocs serve` command:
|
||||
|
||||
```bash
|
||||
mkdocs serve
|
||||
```
|
||||
|
||||
Example output:
|
||||
|
||||
```console
|
||||
INFO - Documentation built in 106.83 seconds
|
||||
INFO - [22:02:02] Watching paths for changes: 'docs', 'mkdocs.yaml'
|
||||
INFO - [22:02:02] Serving on http://127.0.0.1:8000/
|
||||
```
|
||||
|
||||
#### View in Your Browser
|
||||
|
||||
Open up [http://127.0.0.1:8000/](http://127.0.0.1:8000/) in your browser to see a live preview.
|
||||
|
||||
#### Learn More
|
||||
|
||||
For additional features and advanced configurations, refer to the official [MkDocs Documentation](https://www.mkdocs.org/).
|
||||
|
||||
## Testing
|
||||
|
||||
```bash
|
||||
@ -60,6 +108,9 @@ pre-commit run mypy-3.9 --hook-stage manual --all-files
|
||||
|
||||
# Unit tests
|
||||
pytest tests/
|
||||
|
||||
# Run tests for a single test file with detailed output
|
||||
pytest -s -v tests/test_logger.py
|
||||
```
|
||||
|
||||
!!! tip
|
||||
|
||||
@ -48,8 +48,7 @@ for output in outputs:
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```

More API details can be found in the [Offline Inference]
(#offline-inference-api) section of the API docs.
More API details can be found in the [Offline Inference](#offline-inference-api) section of the API docs.

The code for the `LLM` class can be found in <gh-file:vllm/entrypoints/llm.py>.
@ -22,13 +22,13 @@ This document describes how vLLM deals with these challenges.

[Python multiprocessing methods](https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods) include:

- `spawn` - spawn a new Python process. This will be the default as of Python
  3.14. In macOS, this is already the default.
- `spawn` - spawn a new Python process. The default on Windows and macOS.

- `fork` - Use `os.fork()` to fork the Python interpreter. This is the default
  in Python versions prior to 3.14.
- `fork` - Use `os.fork()` to fork the Python interpreter. The default on
  Linux for Python versions prior to 3.14.

- `forkserver` - Spawn a server process that will fork a new process on request.
  The default on Linux for Python version 3.14 and newer.
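To make the start-method choice above concrete, a minimal standard-library sketch; this is not vLLM code, and the worker function and pool size are arbitrary examples:

```python
# Minimal sketch of selecting a start method explicitly with the stdlib.
import multiprocessing as mp


def square(x: int) -> int:
    return x * x


if __name__ == "__main__":
    # "spawn" starts a fresh interpreter (default on Windows and macOS);
    # "fork" copies the parent process (the historical default on Linux).
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes=2) as pool:
        print(pool.map(square, [1, 2, 3]))
```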
### Tradeoffs
@ -10,6 +10,7 @@ The symbols used have the following meanings:
|
||||
- ✅ = Full compatibility
|
||||
- 🟠 = Partial compatibility
|
||||
- ❌ = No compatibility
|
||||
- ❔ = Unknown or TBD
|
||||
|
||||
!!! note
|
||||
Check the ❌ or 🟠 with links to see tracking issue for unsupported feature/hardware combination.
|
||||
@ -36,23 +37,23 @@ th:not(:first-child) {
|
||||
}
|
||||
</style>
|
||||
|
||||
| Feature | [CP][chunked-prefill] | [APC][automatic-prefix-caching] | [LoRA][lora-adapter] | <abbr title="Prompt Adapter">prmpt adptr</abbr> | [SD][spec-decode] | CUDA graph | <abbr title="Pooling Models">pooling</abbr> | <abbr title="Encoder-Decoder Models">enc-dec</abbr> | <abbr title="Logprobs">logP</abbr> | <abbr title="Prompt Logprobs">prmpt logP</abbr> | <abbr title="Async Output Processing">async output</abbr> | multi-step | <abbr title="Multimodal Inputs">mm</abbr> | best-of | beam-search |
|
||||
|-----------------------------------------------------------|-------------------------|-----------------------------------|------------------------|---------------------------------------------------|---------------------|--------------|-----------------------------------------------|-------------------------------------------------------|--------------------------------------|---------------------------------------------------|-------------------------------------------------------------|--------------------|---------------------------------------------|-----------|---------------|
|
||||
| [CP][chunked-prefill] | ✅ | | | | | | | | | | | | | | |
|
||||
| [APC][automatic-prefix-caching] | ✅ | ✅ | | | | | | | | | | | | | |
|
||||
| [LoRA][lora-adapter] | ✅ | ✅ | ✅ | | | | | | | | | | | | |
|
||||
| <abbr title="Prompt Adapter">prmpt adptr</abbr> | ✅ | ✅ | ✅ | ✅ | | | | | | | | | | | |
|
||||
| [SD][spec-decode] | ✅ | ✅ | ❌ | ✅ | ✅ | | | | | | | | | | |
|
||||
| CUDA graph | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | | | | | |
|
||||
| <abbr title="Pooling Models">pooling</abbr> | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | | | | | | | | |
|
||||
| <abbr title="Encoder-Decoder Models">enc-dec</abbr> | ❌ | [❌](gh-issue:7366) | ❌ | ❌ | [❌](gh-issue:7366) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| <abbr title="Logprobs">logP</abbr> | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | | | | | | |
|
||||
| <abbr title="Prompt Logprobs">prmpt logP</abbr> | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | | | | | |
|
||||
| <abbr title="Async Output Processing">async output</abbr> | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | | | | |
|
||||
| multi-step | ❌ | ✅ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | | | |
|
||||
| <abbr title="Multimodal Inputs">mm</abbr> | ✅ | [🟠](gh-pr:8348) | [🟠](gh-pr:4194) | ❔ | ❔ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | ✅ | | |
|
||||
| best-of | ✅ | ✅ | ✅ | ✅ | [❌](gh-issue:6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](gh-issue:7968) | ✅ | ✅ | |
|
||||
| beam-search | ✅ | ✅ | ✅ | ✅ | [❌](gh-issue:6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](gh-issue:7968) | ❔ | ✅ | ✅ |
|
||||
| Feature | [CP][chunked-prefill] | [APC][automatic-prefix-caching] | [LoRA][lora-adapter] | <abbr title="Prompt Adapter">prmpt adptr</abbr> | [SD][spec-decode] | CUDA graph | <abbr title="Pooling Models">pooling</abbr> | <abbr title="Encoder-Decoder Models">enc-dec</abbr> | <abbr title="Logprobs">logP</abbr> | <abbr title="Prompt Logprobs">prmpt logP</abbr> | <abbr title="Async Output Processing">async output</abbr> | multi-step | <abbr title="Multimodal Inputs">mm</abbr> | best-of | beam-search |
|
||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
||||
| [CP][chunked-prefill] | ✅ | | | | | | | | | | | | | | |
|
||||
| [APC][automatic-prefix-caching] | ✅ | ✅ | | | | | | | | | | | | | |
|
||||
| [LoRA][lora-adapter] | ✅ | ✅ | ✅ | | | | | | | | | | | | |
|
||||
| <abbr title="Prompt Adapter">prmpt adptr</abbr> | ✅ | ✅ | ✅ | ✅ | | | | | | | | | | | |
|
||||
| [SD][spec-decode] | ✅ | ✅ | ❌ | ✅ | ✅ | | | | | | | | | | |
|
||||
| CUDA graph | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | | | | | |
|
||||
| <abbr title="Pooling Models">pooling</abbr> | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | | | | | | | | |
|
||||
| <abbr title="Encoder-Decoder Models">enc-dec</abbr> | ❌ | [❌](gh-issue:7366) | ❌ | ❌ | [❌](gh-issue:7366) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| <abbr title="Logprobs">logP</abbr> | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | | | | | | |
|
||||
| <abbr title="Prompt Logprobs">prmpt logP</abbr> | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | | | | | |
|
||||
| <abbr title="Async Output Processing">async output</abbr> | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | | | | |
|
||||
| multi-step | ❌ | ✅ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | | | |
|
||||
| <abbr title="Multimodal Inputs">mm</abbr> | ✅ | [🟠](gh-pr:8348) | [🟠](gh-pr:4194) | ❔ | ❔ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | ✅ | | |
|
||||
| best-of | ✅ | ✅ | ✅ | ✅ | [❌](gh-issue:6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](gh-issue:7968) | ✅ | ✅ | |
|
||||
| beam-search | ✅ | ✅ | ✅ | ✅ | [❌](gh-issue:6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](gh-issue:7968) | ❔ | ✅ | ✅ |
|
||||
|
||||
[](){ #feature-x-hardware }
|
||||
|
||||
@ -75,3 +76,6 @@ th:not(:first-child) {
|
||||
| multi-step | ✅ | ✅ | ✅ | ✅ | ✅ | [❌](gh-issue:8477) | ✅ |
|
||||
| best-of | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| beam-search | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
|
||||
!!! note
    Please refer to [Feature support through NxD Inference backend][feature-support-through-nxd-inference-backend] for the features supported on AWS Neuron hardware.
|
||||
|
||||
@ -165,6 +165,7 @@ it will first look in the local directory for a directory `foobar`, and attempt
|
||||
that adapter will then be available for normal use on the server.
|
||||
|
||||
Alternatively, follow these example steps to implement your own plugin:
|
||||
|
||||
1. Implement the LoRAResolver interface.
|
||||
|
||||
Example of a simple S3 LoRAResolver implementation:
|
||||
@ -198,9 +199,9 @@ Alternatively, follow these example steps to implement your own plugin:
|
||||
return lora_request
|
||||
```
|
||||
|
||||
2. Register LoRAResolver plugin.
|
||||
2. Register `LoRAResolver` plugin.
|
||||
|
||||
```python
|
||||
from vllm.lora.resolver import LoRAResolverRegistry
|
||||
|
||||
s3_resolver = S3LoRAResolver()
|
||||
|
||||
@ -5,13 +5,13 @@ title: Supported Hardware
|
||||
|
||||
The table below shows the compatibility of various quantization implementations with different hardware platforms in vLLM; a short example of selecting a method follows the table:
|
||||
|
||||
| Implementation | Volta | Turing | Ampere | Ada | Hopper | AMD GPU | Intel GPU | x86 CPU | AWS Inferentia | Google TPU |
|
||||
| Implementation | Volta | Turing | Ampere | Ada | Hopper | AMD GPU | Intel GPU | x86 CPU | AWS Neuron | Google TPU |
|
||||
|-----------------------|---------|----------|----------|-------|----------|-----------|-------------|-----------|------------------|--------------|
|
||||
| AWQ | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ✅︎ | ❌ | ❌ |
|
||||
| GPTQ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ✅︎ | ❌ | ❌ |
|
||||
| Marlin (GPTQ/AWQ/FP8) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| INT8 (W8A8) | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ✅︎ | ❌ | ✅︎ |
|
||||
| FP8 (W8A8) | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ |
|
||||
| INT8 (W8A8) | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ |
|
||||
| FP8 (W8A8) | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ✅︎ | ❌ |
|
||||
| BitBLAS (GPTQ) | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| AQLM | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| bitsandbytes | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ |
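
On hardware where a method is marked as supported, it can be selected when constructing the engine. A minimal sketch is shown below; the checkpoint name is illustrative and is assumed to already be quantized with the chosen method:

```python
from vllm import LLM

# Minimal sketch: the checkpoint below is illustrative and assumed to be an
# AWQ-quantized model; pick a method the table marks as supported for your hardware.
llm = LLM(model="TheBloke/Llama-2-7B-Chat-AWQ", quantization="awq")
print(llm.generate(["Hello, my name is"])[0].outputs[0].text)
```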
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
# --8<-- [start:installation]
|
||||
|
||||
vLLM 0.3.3 onwards supports model inferencing and serving on AWS Trainium/Inferentia with Neuron SDK with continuous batching.
|
||||
Paged Attention and Chunked Prefill are currently in development and will be available soon.
|
||||
Data types currently supported in Neuron SDK are FP16 and BF16.
|
||||
[AWS Neuron](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/) is the software development kit (SDK) used to run deep learning and
|
||||
generative AI workloads on AWS Inferentia and AWS Trainium powered Amazon EC2 instances and UltraServers (Inf1, Inf2, Trn1, Trn2,
|
||||
and Trn2 UltraServer). Both Trainium and Inferentia are powered by fully-independent heterogeneous compute-units called NeuronCores.
|
||||
This tab describes how to set up your environment to run vLLM on Neuron.
|
||||
|
||||
!!! warning
|
||||
There are no pre-built wheels or images for this device, so you must build vLLM from source.
|
||||
@ -11,59 +12,31 @@ Data types currently supported in Neuron SDK are FP16 and BF16.
|
||||
# --8<-- [start:requirements]
|
||||
|
||||
- OS: Linux
|
||||
- Python: 3.9 -- 3.11
|
||||
- Accelerator: NeuronCore_v2 (in trn1/inf2 instances)
|
||||
- Pytorch 2.0.1/2.1.1
|
||||
- AWS Neuron SDK 2.16/2.17 (Verified on python 3.8)
|
||||
- Python: 3.9 or newer
|
||||
- Pytorch 2.5/2.6
|
||||
- Accelerator: NeuronCore-v2 (in trn1/inf2 chips) or NeuronCore-v3 (in trn2 chips)
|
||||
- AWS Neuron SDK 2.23
|
||||
|
||||
## Configure a new environment
|
||||
|
||||
### Launch Trn1/Inf2 instances
|
||||
### Launch a Trn1/Trn2/Inf2 instance and verify Neuron dependencies
|
||||
|
||||
Here are the steps to launch trn1/inf2 instances, in order to install [PyTorch Neuron ("torch-neuronx") Setup on Ubuntu 22.04 LTS](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/setup/neuron-setup/pytorch/neuronx/ubuntu/torch-neuronx-ubuntu22.html).
|
||||
The easiest way to launch a Trainium or Inferentia instance with pre-installed Neuron dependencies is to follow this
|
||||
[quick start guide](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/setup/neuron-setup/multiframework/multi-framework-ubuntu22-neuron-dlami.html#setup-ubuntu22-multi-framework-dlami) using the Neuron Deep Learning AMI (Amazon machine image).
|
||||
|
||||
- Please follow the instructions at [launch an Amazon EC2 Instance](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/EC2_GetStarted.html#ec2-launch-instance) to launch an instance. When choosing the instance type at the EC2 console, please make sure to select the correct instance type.
|
||||
- To get more information about instances sizes and pricing see: [Trn1 web page](https://aws.amazon.com/ec2/instance-types/trn1/), [Inf2 web page](https://aws.amazon.com/ec2/instance-types/inf2/)
|
||||
- Select Ubuntu Server 22.04 TLS AMI
|
||||
- When launching a Trn1/Inf2, please adjust your primary EBS volume size to a minimum of 512GB.
|
||||
- After launching the instance, follow the instructions in [Connect to your instance](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/AccessingInstancesLinux.html) to connect to the instance
|
||||
|
||||
### Install drivers and tools
|
||||
|
||||
The installation of drivers and tools wouldn't be necessary, if [Deep Learning AMI Neuron](https://docs.aws.amazon.com/dlami/latest/devguide/appendix-ami-release-notes.html) is installed. In case the drivers and tools are not installed on the operating system, follow the steps below:
|
||||
|
||||
- Once inside your instance, activate the pre-installed virtual environment for inference by running
|
||||
```console
|
||||
# Configure Linux for Neuron repository updates
|
||||
. /etc/os-release
|
||||
sudo tee /etc/apt/sources.list.d/neuron.list > /dev/null <<EOF
|
||||
deb https://apt.repos.neuron.amazonaws.com ${VERSION_CODENAME} main
|
||||
EOF
|
||||
wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB \
|
||||
| sudo apt-key add -
|
||||
|
||||
# Update OS packages
|
||||
sudo apt-get update -y
|
||||
|
||||
# Install OS headers
|
||||
sudo apt-get install linux-headers-$(uname -r) -y
|
||||
|
||||
# Install git
|
||||
sudo apt-get install git -y
|
||||
|
||||
# install Neuron Driver
|
||||
sudo apt-get install aws-neuronx-dkms=2.* -y
|
||||
|
||||
# Install Neuron Runtime
|
||||
sudo apt-get install aws-neuronx-collectives=2.* -y
|
||||
sudo apt-get install aws-neuronx-runtime-lib=2.* -y
|
||||
|
||||
# Install Neuron Tools
|
||||
sudo apt-get install aws-neuronx-tools=2.* -y
|
||||
|
||||
# Add PATH
|
||||
export PATH=/opt/aws/neuron/bin:$PATH
|
||||
source /opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/bin/activate
|
||||
```
|
||||
|
||||
Refer to the [NxD Inference Setup Guide](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/nxd-inference/nxdi-setup.html)
|
||||
for alternative setup instructions including using Docker and manually installing dependencies.
|
||||
|
||||
!!! note
|
||||
NxD Inference is the default recommended backend to run inference on Neuron. If you are looking to use the legacy [transformers-neuronx](https://github.com/aws-neuron/transformers-neuronx)
|
||||
library, refer to [Transformers NeuronX Setup](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/transformers-neuronx/setup/index.html).
|
||||
|
||||
# --8<-- [end:requirements]
|
||||
# --8<-- [start:set-up-using-python]
|
||||
|
||||
@ -75,60 +48,37 @@ Currently, there are no pre-built Neuron wheels.
|
||||
# --8<-- [end:pre-built-wheels]
|
||||
# --8<-- [start:build-wheel-from-source]
|
||||
|
||||
!!! note
|
||||
The currently supported version of Pytorch for Neuron installs `triton` version `2.1.0`. This is incompatible with `vllm >= 0.5.3`. You may see an error `cannot import name 'default_dump_dir...`. To work around this, run a `pip install --upgrade triton==3.0.0` after installing the vLLM wheel.
|
||||
|
||||
Following instructions are applicable to Neuron SDK 2.16 and beyond.
|
||||
|
||||
#### Install transformers-neuronx and its dependencies
|
||||
|
||||
[transformers-neuronx](https://github.com/aws-neuron/transformers-neuronx) will be the backend to support inference on trn1/inf2 instances.
|
||||
Follow the steps below to install transformer-neuronx package and its dependencies.
|
||||
|
||||
```console
|
||||
# Install Python venv
|
||||
sudo apt-get install -y python3.10-venv g++
|
||||
|
||||
# Create Python venv
|
||||
python3.10 -m venv aws_neuron_venv_pytorch
|
||||
|
||||
# Activate Python venv
|
||||
source aws_neuron_venv_pytorch/bin/activate
|
||||
|
||||
# Install Jupyter notebook kernel
|
||||
pip install ipykernel
|
||||
python3.10 -m ipykernel install \
|
||||
--user \
|
||||
--name aws_neuron_venv_pytorch \
|
||||
--display-name "Python (torch-neuronx)"
|
||||
pip install jupyter notebook
|
||||
pip install environment_kernels
|
||||
|
||||
# Set pip repository pointing to the Neuron repository
|
||||
python -m pip config set \
|
||||
global.extra-index-url \
|
||||
https://pip.repos.neuron.amazonaws.com
|
||||
|
||||
# Install wget, awscli
|
||||
python -m pip install wget
|
||||
python -m pip install awscli
|
||||
|
||||
# Update Neuron Compiler and Framework
|
||||
python -m pip install --upgrade neuronx-cc==2.* --pre torch-neuronx==2.1.* torchvision transformers-neuronx
|
||||
```
|
||||
|
||||
#### Install vLLM from source
|
||||
|
||||
Once neuronx-cc and transformers-neuronx packages are installed, we will be able to install vllm as follows:
|
||||
Install vLLM as follows:
|
||||
|
||||
```console
|
||||
git clone https://github.com/vllm-project/vllm.git
|
||||
cd vllm
|
||||
pip install -U -r requirements/neuron.txt
|
||||
VLLM_TARGET_DEVICE="neuron" pip install .
|
||||
VLLM_TARGET_DEVICE="neuron" pip install -e .
|
||||
```
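
After the build completes, a quick sanity check from Python helps confirm that the Neuron platform was picked up. This is only a sketch; it assumes `current_platform` is exposed at the import path below in the installed version:

```python
# Post-install sanity check: the package should import cleanly and report
# a Neuron platform on trn1/trn2/inf2 instances.
import vllm
from vllm.platforms import current_platform

print(vllm.__version__)
print(current_platform)
```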
|
||||
|
||||
If neuron packages are detected correctly in the installation process, `vllm-0.3.0+neuron212` will be installed.
|
||||
AWS Neuron maintains a [Github fork of vLLM](https://github.com/aws-neuron/upstreaming-to-vllm/tree/neuron-2.23-vllm-v0.7.2) at
|
||||
[https://github.com/aws-neuron/upstreaming-to-vllm/tree/neuron-2.23-vllm-v0.7.2](https://github.com/aws-neuron/upstreaming-to-vllm/tree/neuron-2.23-vllm-v0.7.2), which contains several features in addition to what's
|
||||
available on vLLM V0. Please utilize the AWS Fork for the following features:
|
||||
|
||||
- Llama-3.2 multi-modal support
|
||||
- Multi-node distributed inference
|
||||
|
||||
Refer to [vLLM User Guide for NxD Inference](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/nxd-inference/developer_guides/vllm-user-guide.html)
|
||||
for more details and usage examples.
|
||||
|
||||
To install the AWS Neuron fork, run the following:
|
||||
|
||||
```console
|
||||
git clone -b neuron-2.23-vllm-v0.7.2 https://github.com/aws-neuron/upstreaming-to-vllm.git
|
||||
cd upstreaming-to-vllm
|
||||
pip install -r requirements/neuron.txt
|
||||
VLLM_TARGET_DEVICE="neuron" pip install -e .
|
||||
```
|
||||
|
||||
Note that the AWS Neuron fork is only intended to support Neuron hardware; compatibility with other hardware is not tested.
|
||||
|
||||
# --8<-- [end:build-wheel-from-source]
|
||||
# --8<-- [start:set-up-using-docker]
|
||||
@ -148,5 +98,57 @@ Make sure to use <gh-file:docker/Dockerfile.neuron> in place of the default Dock
|
||||
# --8<-- [end:build-image-from-source]
|
||||
# --8<-- [start:extra-information]
|
||||
|
||||
There is no extra information for this device.
|
||||
[](){ #feature-support-through-nxd-inference-backend }
|
||||
### Feature support through NxD Inference backend
|
||||
|
||||
The current vLLM and Neuron integration relies on either the `neuronx-distributed-inference` (preferred) or `transformers-neuronx` backend
to perform most of the heavy lifting, which includes PyTorch model initialization, compilation, and runtime execution. Therefore, most
[features supported on Neuron](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/nxd-inference/developer_guides/feature-guide.html) are also available via the vLLM integration.

To configure NxD Inference features through the vLLM entrypoint, use the `override_neuron_config` setting. Provide the configs you want to override
as a dictionary (or a JSON object when starting vLLM from the CLI). For example, to disable auto bucketing, include:

```console
override_neuron_config={
    "enable_bucketing":False,
}
```

or, when launching vLLM from the CLI, pass:

```console
--override-neuron-config "{\"enable_bucketing\":false}"
```
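
The same setting can also be passed through the Python API. A minimal sketch, in which the model name, parallelism degree, and overridden values are illustrative:

```python
from vllm import LLM

# Minimal sketch: model name, tensor_parallel_size and the overridden values
# are illustrative; any NxD Inference config key can be overridden the same way.
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    device="neuron",
    tensor_parallel_size=2,
    override_neuron_config={"enable_bucketing": False},
)
```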

Alternatively, users can directly call the NxDI library to trace and compile the model, then load the pre-compiled artifacts
(via the `NEURON_COMPILED_ARTIFACTS` environment variable) in vLLM to run inference workloads.
|
||||
|
||||
### Known limitations
|
||||
|
||||
- EAGLE speculative decoding: NxD Inference requires the EAGLE draft checkpoint to include the LM head weights from the target model. Refer to this
|
||||
[guide](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/nxd-inference/developer_guides/feature-guide.html#eagle-checkpoint-compatibility)
|
||||
for how to convert pretrained EAGLE model checkpoints to be compatible for NxDI.
|
||||
- Quantization: the native quantization flow in vLLM is not well supported on NxD Inference. It is recommended to follow this
|
||||
[Neuron quantization guide](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/nxd-inference/developer_guides/custom-quantization.html)
|
||||
to quantize and compile your model using NxD Inference, and then load the compiled artifacts into vLLM.
|
||||
- Multi-LoRA serving: NxD Inference only supports loading of LoRA adapters at server startup. Dynamic loading of LoRA adapters at
|
||||
runtime is not currently supported. Refer to the [multi-lora example](https://github.com/aws-neuron/upstreaming-to-vllm/blob/neuron-2.23-vllm-v0.7.2/examples/offline_inference/neuron_multi_lora.py) for how to load adapters at startup.
|
||||
- Multi-modal support: multi-modal support is only available through the AWS Neuron fork. This feature has not been upstreamed
|
||||
to vLLM main because NxD Inference currently relies on certain adaptations to the core vLLM logic to support this feature.
|
||||
- Multi-node support: distributed inference across multiple Trainium/Inferentia instances is only supported on the AWS Neuron fork. Refer
|
||||
to this [multi-node example](https://github.com/aws-neuron/upstreaming-to-vllm/tree/neuron-2.23-vllm-v0.7.2/examples/neuron/multi_node)
|
||||
to run. Note that tensor parallelism (distributed inference across NeuronCores) is available in vLLM main.
|
||||
- Known edge case bug in speculative decoding: An edge case failure may occur in speculative decoding when sequence length approaches
|
||||
max model length (e.g. when requesting max tokens up to the max model length and ignoring eos). In this scenario, vLLM may attempt
|
||||
to allocate an additional block to ensure there is enough memory for the number of lookahead slots, but since we do not have good support
|
||||
for paged attention, there isn't another Neuron block for vLLM to allocate. A workaround fix (to terminate 1 iteration early) is
|
||||
implemented in the AWS Neuron fork but is not upstreamed to vLLM main as it modifies core vLLM logic.
|
||||
|
||||
|
||||
### Environment variables

- `NEURON_COMPILED_ARTIFACTS`: set this environment variable to point to your pre-compiled model artifacts directory to avoid compilation time upon server initialization. If this variable is not set, the Neuron module will perform compilation and save the artifacts under the `neuron-compiled-artifacts/{unique_hash}/` sub-directory in the model path. If this environment variable is set but the directory does not exist or its contents are invalid, Neuron will fall back to a new compilation and store the artifacts under the specified path (see the sketch after this list).
- `NEURON_CONTEXT_LENGTH_BUCKETS`: Bucket sizes for context encoding. (Only applicable to `transformers-neuronx` backend).
- `NEURON_TOKEN_GEN_BUCKETS`: Bucket sizes for token generation. (Only applicable to `transformers-neuronx` backend).
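
A minimal sketch of wiring this up from Python; the artifacts path and model name are illustrative:

```python
import os

# Illustrative path: point vLLM at pre-compiled Neuron artifacts so that
# compilation is skipped at startup; otherwise compilation runs on first load.
os.environ["NEURON_COMPILED_ARTIFACTS"] = "/home/ubuntu/artifacts/llama-3.1-8b"

from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", device="neuron", tensor_parallel_size=2)
print(llm.generate(["Hello, Neuron!"], SamplingParams(max_tokens=16))[0].outputs[0].text)
```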
|
||||
|
||||
# --8<-- [end:extra-information]
|
||||
|
||||
@ -302,31 +302,31 @@ Specified using `--task generate`.
|
||||
| Architecture | Models | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] |
|
||||
|---------------------------------------------------|-----------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------|-----------------------------|
|
||||
| `AquilaForCausalLM` | Aquila, Aquila2 | `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc. | ✅︎ | ✅︎ |
|
||||
| `ArcticForCausalLM` | Arctic | `Snowflake/snowflake-arctic-base`, `Snowflake/snowflake-arctic-instruct`, etc. | ✅︎ | |
|
||||
| `ArcticForCausalLM` | Arctic | `Snowflake/snowflake-arctic-base`, `Snowflake/snowflake-arctic-instruct`, etc. | | ✅︎ |
|
||||
| `BaiChuanForCausalLM` | Baichuan2, Baichuan | `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc. | ✅︎ | ✅︎ |
|
||||
| `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | | |
|
||||
| `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | ✅︎ | |
|
||||
| `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ |
|
||||
| `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ |
|
||||
| `BartForConditionalGeneration` | BART | `facebook/bart-base`, `facebook/bart-large-cnn`, etc. | | |
|
||||
| `ChatGLMModel`, `ChatGLMForConditionalGeneration` | ChatGLM | `THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, `ShieldLM-6B-chatglm3`, etc. | ✅︎ | ✅︎ |
|
||||
| `CohereForCausalLM`, `Cohere2ForCausalLM` | Command-R | `CohereForAI/c4ai-command-r-v01`, `CohereForAI/c4ai-command-r7b-12-2024`, etc. | ✅︎ | ✅︎ |
|
||||
| `DbrxForCausalLM` | DBRX | `databricks/dbrx-base`, `databricks/dbrx-instruct`, etc. | ✅︎ | |
|
||||
| `DeciLMForCausalLM` | DeciLM | `nvidia/Llama-3_3-Nemotron-Super-49B-v1`, etc. | ✅︎ | |
|
||||
| `DeepseekForCausalLM` | DeepSeek | `deepseek-ai/deepseek-llm-67b-base`, `deepseek-ai/deepseek-llm-7b-chat` etc. | ✅︎ | |
|
||||
| `DeepseekV2ForCausalLM` | DeepSeek-V2 | `deepseek-ai/DeepSeek-V2`, `deepseek-ai/DeepSeek-V2-Chat` etc. | ✅︎ | |
|
||||
| `DeepseekV3ForCausalLM` | DeepSeek-V3 | `deepseek-ai/DeepSeek-V3-Base`, `deepseek-ai/DeepSeek-V3` etc. | ✅︎ | |
|
||||
| `DbrxForCausalLM` | DBRX | `databricks/dbrx-base`, `databricks/dbrx-instruct`, etc. | | ✅︎ |
|
||||
| `DeciLMForCausalLM` | DeciLM | `nvidia/Llama-3_3-Nemotron-Super-49B-v1`, etc. | ✅︎ | ✅︎ |
|
||||
| `DeepseekForCausalLM` | DeepSeek | `deepseek-ai/deepseek-llm-67b-base`, `deepseek-ai/deepseek-llm-7b-chat` etc. | | ✅︎ |
|
||||
| `DeepseekV2ForCausalLM` | DeepSeek-V2 | `deepseek-ai/DeepSeek-V2`, `deepseek-ai/DeepSeek-V2-Chat` etc. | | ✅︎ |
|
||||
| `DeepseekV3ForCausalLM` | DeepSeek-V3 | `deepseek-ai/DeepSeek-V3-Base`, `deepseek-ai/DeepSeek-V3` etc. | | ✅︎ |
|
||||
| `ExaoneForCausalLM` | EXAONE-3 | `LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`, etc. | ✅︎ | ✅︎ |
|
||||
| `FalconForCausalLM` | Falcon | `tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc. | ✅︎ | |
|
||||
| `FalconMambaForCausalLM` | FalconMamba | `tiiuae/falcon-mamba-7b`, `tiiuae/falcon-mamba-7b-instruct`, etc. | ✅︎ | ✅︎ |
|
||||
| `FalconForCausalLM` | Falcon | `tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc. | | ✅︎ |
|
||||
| `FalconMambaForCausalLM` | FalconMamba | `tiiuae/falcon-mamba-7b`, `tiiuae/falcon-mamba-7b-instruct`, etc. | | ✅︎ |
|
||||
| `FalconH1ForCausalLM` | Falcon-H1 | `tiiuae/Falcon-H1-34B-Base`, `tiiuae/Falcon-H1-34B-Instruct`, etc. | ✅︎ | ✅︎ |
|
||||
| `GemmaForCausalLM` | Gemma | `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc. | ✅︎ | ✅︎ |
|
||||
| `Gemma2ForCausalLM` | Gemma 2 | `google/gemma-2-9b`, `google/gemma-2-27b`, etc. | ✅︎ | ✅︎ |
|
||||
| `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it`, etc. | ✅︎ | ✅︎ |
|
||||
| `GlmForCausalLM` | GLM-4 | `THUDM/glm-4-9b-chat-hf`, etc. | ✅︎ | ✅︎ |
|
||||
| `Glm4ForCausalLM` | GLM-4-0414 | `THUDM/GLM-4-32B-0414`, etc. | ✅︎ | ✅︎ |
|
||||
| `GPT2LMHeadModel` | GPT-2 | `gpt2`, `gpt2-xl`, etc. | ✅︎ | |
|
||||
| `GPT2LMHeadModel` | GPT-2 | `gpt2`, `gpt2-xl`, etc. | | ✅︎ |
|
||||
| `GPTBigCodeForCausalLM` | StarCoder, SantaCoder, WizardCoder | `bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, `WizardLM/WizardCoder-15B-V1.0`, etc. | ✅︎ | ✅︎ |
|
||||
| `GPTJForCausalLM` | GPT-J | `EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc. | ✅︎ | |
|
||||
| `GPTNeoXForCausalLM` | GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM | `EleutherAI/gpt-neox-20b`, `EleutherAI/pythia-12b`, `OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc. | ✅︎ | |
|
||||
| `GPTJForCausalLM` | GPT-J | `EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc. | | ✅︎ |
|
||||
| `GPTNeoXForCausalLM` | GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM | `EleutherAI/gpt-neox-20b`, `EleutherAI/pythia-12b`, `OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc. | | ✅︎ |
|
||||
| `GraniteForCausalLM` | Granite 3.0, Granite 3.1, PowerLM | `ibm-granite/granite-3.0-2b-base`, `ibm-granite/granite-3.1-8b-instruct`, `ibm/PowerLM-3b`, etc. | ✅︎ | ✅︎ |
|
||||
| `GraniteMoeForCausalLM` | Granite 3.0 MoE, PowerMoE | `ibm-granite/granite-3.0-1b-a400m-base`, `ibm-granite/granite-3.0-3b-a800m-instruct`, `ibm/PowerMoE-3b`, etc. | ✅︎ | ✅︎ |
|
||||
| `GraniteMoeHybridForCausalLM` | Granite 4.0 MoE Hybrid | `ibm-granite/granite-4.0-tiny-preview`, etc. | ✅︎ | ✅︎ |
|
||||
@ -336,39 +336,39 @@ Specified using `--task generate`.
|
||||
| `InternLMForCausalLM` | InternLM | `internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc. | ✅︎ | ✅︎ |
|
||||
| `InternLM2ForCausalLM` | InternLM2 | `internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc. | ✅︎ | ✅︎ |
|
||||
| `InternLM3ForCausalLM` | InternLM3 | `internlm/internlm3-8b-instruct`, etc. | ✅︎ | ✅︎ |
|
||||
| `JAISLMHeadModel` | Jais | `inceptionai/jais-13b`, `inceptionai/jais-13b-chat`, `inceptionai/jais-30b-v3`, `inceptionai/jais-30b-chat-v3`, etc. | ✅︎ | |
|
||||
| `JAISLMHeadModel` | Jais | `inceptionai/jais-13b`, `inceptionai/jais-13b-chat`, `inceptionai/jais-30b-v3`, `inceptionai/jais-30b-chat-v3`, etc. | | ✅︎ |
|
||||
| `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ |
|
||||
| `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA, Yi | `meta-llama/Meta-Llama-3.1-405B-Instruct`, `meta-llama/Meta-Llama-3.1-70B`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `01-ai/Yi-34B`, etc. | ✅︎ | ✅︎ |
|
||||
| `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | ✅︎ | |
|
||||
| `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ |
|
||||
| `MiniCPMForCausalLM` | MiniCPM | `openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, `openbmb/MiniCPM-S-1B-sft`, etc. | ✅︎ | ✅︎ |
|
||||
| `MiniCPM3ForCausalLM` | MiniCPM3 | `openbmb/MiniCPM3-4B`, etc. | ✅︎ | ✅︎ |
|
||||
| `MistralForCausalLM` | Mistral, Mistral-Instruct | `mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc. | ✅︎ | ✅︎ |
|
||||
| `MixtralForCausalLM` | Mixtral-8x7B, Mixtral-8x7B-Instruct | `mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc. | ✅︎ | ✅︎ |
|
||||
| `MPTForCausalLM` | MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter | `mosaicml/mpt-7b`, `mosaicml/mpt-7b-storywriter`, `mosaicml/mpt-30b`, etc. | ✅︎ | |
|
||||
| `MPTForCausalLM` | MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter | `mosaicml/mpt-7b`, `mosaicml/mpt-7b-storywriter`, `mosaicml/mpt-30b`, etc. | | ✅︎ |
|
||||
| `NemotronForCausalLM` | Nemotron-3, Nemotron-4, Minitron | `nvidia/Minitron-8B-Base`, `mgoin/Nemotron-4-340B-Base-hf-FP8`, etc. | ✅︎ | ✅︎ |
|
||||
| `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | ✅︎ | |
|
||||
| `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | ✅︎ | |
|
||||
| `OLMoEForCausalLM` | OLMoE | `allenai/OLMoE-1B-7B-0924`, `allenai/OLMoE-1B-7B-0924-Instruct`, etc. | ✅︎ | ✅︎ |
|
||||
| `OPTForCausalLM` | OPT, OPT-IML | `facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc. | ✅︎ | |
|
||||
| `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | ✅︎ | |
|
||||
| `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | | ✅︎ |
|
||||
| `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | | ✅︎ |
|
||||
| `OLMoEForCausalLM` | OLMoE | `allenai/OLMoE-1B-7B-0924`, `allenai/OLMoE-1B-7B-0924-Instruct`, etc. | | ✅︎ |
|
||||
| `OPTForCausalLM` | OPT, OPT-IML | `facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc. | | ✅︎ |
|
||||
| `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ |
|
||||
| `PhiForCausalLM` | Phi | `microsoft/phi-1_5`, `microsoft/phi-2`, etc. | ✅︎ | ✅︎ |
|
||||
| `Phi3ForCausalLM` | Phi-4, Phi-3 | `microsoft/Phi-4-mini-instruct`, `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc. | ✅︎ | ✅︎ |
|
||||
| `Phi3SmallForCausalLM` | Phi-3-Small | `microsoft/Phi-3-small-8k-instruct`, `microsoft/Phi-3-small-128k-instruct`, etc. | ✅︎ | |
|
||||
| `Phi3SmallForCausalLM` | Phi-3-Small | `microsoft/Phi-3-small-8k-instruct`, `microsoft/Phi-3-small-128k-instruct`, etc. | | ✅︎ |
|
||||
| `PhiMoEForCausalLM` | Phi-3.5-MoE | `microsoft/Phi-3.5-MoE-instruct`, etc. | ✅︎ | ✅︎ |
|
||||
| `PersimmonForCausalLM` | Persimmon | `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc. | ✅︎ | |
|
||||
| `PersimmonForCausalLM` | Persimmon | `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc. | | ✅︎ |
|
||||
| `Plamo2ForCausalLM` | PLaMo2 | `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. | | |
|
||||
| `QWenLMHeadModel` | Qwen | `Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc. | ✅︎ | ✅︎ |
|
||||
| `Qwen2ForCausalLM` | QwQ, Qwen2 | `Qwen/QwQ-32B-Preview`, `Qwen/Qwen2-7B-Instruct`, `Qwen/Qwen2-7B`, etc. | ✅︎ | ✅︎ |
|
||||
| `Qwen2MoeForCausalLM` | Qwen2MoE | `Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. | ✅︎ | |
|
||||
| `Qwen2MoeForCausalLM` | Qwen2MoE | `Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. | | ✅︎ |
|
||||
| `Qwen3ForCausalLM` | Qwen3 | `Qwen/Qwen3-8B`, etc. | ✅︎ | ✅︎ |
|
||||
| `Qwen3MoeForCausalLM` | Qwen3MoE | `Qwen/Qwen3-30B-A3B`, etc. | ✅︎ | |
|
||||
| `StableLmForCausalLM` | StableLM | `stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc. | ✅︎ | |
|
||||
| `Starcoder2ForCausalLM` | Starcoder2 | `bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc. | ✅︎ | |
|
||||
| `Qwen3MoeForCausalLM` | Qwen3MoE | `Qwen/Qwen3-30B-A3B`, etc. | | ✅︎ |
|
||||
| `StableLmForCausalLM` | StableLM | `stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc. | | |
|
||||
| `Starcoder2ForCausalLM` | Starcoder2 | `bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc. | | ✅︎ |
|
||||
| `SolarForCausalLM` | Solar Pro | `upstage/solar-pro-preview-instruct`, etc. | ✅︎ | ✅︎ |
|
||||
| `TeleChat2ForCausalLM` | TeleChat2 | `Tele-AI/TeleChat2-3B`, `Tele-AI/TeleChat2-7B`, `Tele-AI/TeleChat2-35B`, etc. | ✅︎ | ✅︎ |
|
||||
| `TeleFLMForCausalLM` | TeleFLM | `CofeAI/FLM-2-52B-Instruct-2407`, `CofeAI/Tele-FLM`, etc. | ✅︎ | ✅︎ |
|
||||
| `XverseForCausalLM` | XVERSE | `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc. | ✅︎ | ✅︎ |
|
||||
| `MiniMaxText01ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-Text-01`, etc. | ✅︎ | |
|
||||
| `MiniMaxText01ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-Text-01`, etc. | | |
|
||||
| `Zamba2ForCausalLM` | Zamba2 | `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc. | | |
|
||||
|
||||
!!! note
|
||||
@ -401,7 +401,7 @@ Specified using `--task embed`.
|
||||
|
||||
!!! note
|
||||
`ssmits/Qwen2-7B-Instruct-embed-base` has an improperly defined Sentence Transformers config.
|
||||
You should manually set mean pooling by passing `--override-pooler-config '{"pooling_type": "MEAN"}'`.
|
||||
You need to manually set mean pooling by passing `--override-pooler-config '{"pooling_type": "MEAN"}'`.
|
||||
|
||||
!!! note
|
||||
For `Alibaba-NLP/gte-Qwen2-*`, you need to enable `--trust-remote-code` for the correct tokenizer to be loaded.
|
||||
@ -512,44 +512,44 @@ Specified using `--task generate`.
|
||||
|
||||
| Architecture | Models | Inputs | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | [V1](gh-issue:8779) |
|
||||
|----------------------------------------------|--------------------------------------------------------------------------|-----------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------|-----------------------------|-----------------------|
|
||||
| `AriaForConditionalGeneration` | Aria | T + I<sup>+</sup> | `rhymes-ai/Aria` | ✅︎ | ✅︎ | |
|
||||
| `AyaVisionForConditionalGeneration` | Aya Vision | T + I<sup>+</sup> | `CohereForAI/aya-vision-8b`, `CohereForAI/aya-vision-32b`, etc. | ✅︎ | ✅︎ | |
|
||||
| `Blip2ForConditionalGeneration` | BLIP-2 | T + I<sup>E</sup> | `Salesforce/blip2-opt-2.7b`, `Salesforce/blip2-opt-6.7b`, etc. | ✅︎ | ✅︎ | |
|
||||
| `ChameleonForConditionalGeneration` | Chameleon | T + I | `facebook/chameleon-7b` etc. | ✅︎ | ✅︎ | |
|
||||
| `DeepseekVLV2ForCausalLM`<sup>^</sup> | DeepSeek-VL2 | T + I<sup>+</sup> | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2` etc. | ✅︎ | ✅︎ | |
|
||||
| `AriaForConditionalGeneration` | Aria | T + I<sup>+</sup> | `rhymes-ai/Aria` | | | ✅︎ |
|
||||
| `AyaVisionForConditionalGeneration` | Aya Vision | T + I<sup>+</sup> | `CohereForAI/aya-vision-8b`, `CohereForAI/aya-vision-32b`, etc. | | ✅︎ | ✅︎ |
|
||||
| `Blip2ForConditionalGeneration` | BLIP-2 | T + I<sup>E</sup> | `Salesforce/blip2-opt-2.7b`, `Salesforce/blip2-opt-6.7b`, etc. | | ✅︎ | ✅︎ |
|
||||
| `ChameleonForConditionalGeneration` | Chameleon | T + I | `facebook/chameleon-7b` etc. | | ✅︎ | ✅︎ |
|
||||
| `DeepseekVLV2ForCausalLM`<sup>^</sup> | DeepSeek-VL2 | T + I<sup>+</sup> | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2` etc. | | ✅︎ | ✅︎ |
|
||||
| `Florence2ForConditionalGeneration` | Florence-2 | T + I | `microsoft/Florence-2-base`, `microsoft/Florence-2-large` etc. | | | |
|
||||
| `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b` etc. | ✅︎ | ✅︎ | |
|
||||
| `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b` etc. | | ✅︎ | ✅︎ |
|
||||
| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ |
|
||||
| `GLM4VForCausalLM`<sup>^</sup> | GLM-4V | T + I | `THUDM/glm-4v-9b`, `THUDM/cogagent-9b-20241220` etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `H2OVLChatModel` | H2OVL | T + I<sup>E+</sup> | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | ✅︎ | ✅︎\* | |
|
||||
| `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3` etc. | ✅︎ | ✅︎ | |
|
||||
| `InternVLChatModel` | InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + I<sup>E+</sup> + (V<sup>E+</sup>) | `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | |
|
||||
| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I<sup>+</sup> | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | ✅︎ | | |
|
||||
| `Llama4ForConditionalGeneration` | Llama 4 | T + I<sup>+</sup> | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | ✅︎ | ✅︎ | |
|
||||
| `LlavaForConditionalGeneration` | LLaVA-1.5 | T + I<sup>E+</sup> | `llava-hf/llava-1.5-7b-hf`, `TIGER-Lab/Mantis-8B-siglip-llama3` (see note), etc. | ✅︎ | ✅︎ | |
|
||||
| `LlavaNextForConditionalGeneration` | LLaVA-NeXT | T + I<sup>E+</sup> | `llava-hf/llava-v1.6-mistral-7b-hf`, `llava-hf/llava-v1.6-vicuna-7b-hf`, etc. | ✅︎ | ✅︎ | |
|
||||
| `LlavaNextVideoForConditionalGeneration` | LLaVA-NeXT-Video | T + V | `llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. | ✅︎ | ✅︎ | |
|
||||
| `LlavaOnevisionForConditionalGeneration` | LLaVA-Onevision | T + I<sup>+</sup> + V<sup>+</sup> | `llava-hf/llava-onevision-qwen2-7b-ov-hf`, `llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. | ✅︎ | ✅︎ | |
|
||||
| `H2OVLChatModel` | H2OVL | T + I<sup>E+</sup> | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | ✅︎\* |
|
||||
| `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3` etc. | ✅︎ | | ✅︎ |
|
||||
| `InternVLChatModel` | InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + I<sup>E+</sup> + (V<sup>E+</sup>) | `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | | ✅︎ | ✅︎ |
|
||||
| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I<sup>+</sup> | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | | ✅︎ |
|
||||
| `Llama4ForConditionalGeneration` | Llama 4 | T + I<sup>+</sup> | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | | ✅︎ | ✅︎ |
|
||||
| `LlavaForConditionalGeneration` | LLaVA-1.5 | T + I<sup>E+</sup> | `llava-hf/llava-1.5-7b-hf`, `TIGER-Lab/Mantis-8B-siglip-llama3` (see note), etc. | | ✅︎ | ✅︎ |
|
||||
| `LlavaNextForConditionalGeneration` | LLaVA-NeXT | T + I<sup>E+</sup> | `llava-hf/llava-v1.6-mistral-7b-hf`, `llava-hf/llava-v1.6-vicuna-7b-hf`, etc. | | ✅︎ | ✅︎ |
|
||||
| `LlavaNextVideoForConditionalGeneration` | LLaVA-NeXT-Video | T + V | `llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. | | ✅︎ | ✅︎ |
|
||||
| `LlavaOnevisionForConditionalGeneration` | LLaVA-Onevision | T + I<sup>+</sup> + V<sup>+</sup> | `llava-hf/llava-onevision-qwen2-7b-ov-hf`, `llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. | | ✅︎ | ✅︎ |
|
||||
| `MiniCPMO` | MiniCPM-O | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>E+</sup> | `openbmb/MiniCPM-o-2_6`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `MiniCPMV` | MiniCPM-V | T + I<sup>E+</sup> + V<sup>E+</sup> | `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `MiniMaxVL01ForConditionalGeneration` | MiniMax-VL | T + I<sup>E+</sup> | `MiniMaxAI/MiniMax-VL-01`, etc. | ✅︎ | ✅︎ | |
|
||||
| `MiniCPMV` | MiniCPM-V | T + I<sup>E+</sup> + V<sup>E+</sup> | `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, etc. | ✅︎ | | ✅︎ |
|
||||
| `MiniMaxVL01ForConditionalGeneration` | MiniMax-VL | T + I<sup>E+</sup> | `MiniMaxAI/MiniMax-VL-01`, etc. | | ✅︎ | |
|
||||
| `Mistral3ForConditionalGeneration` | Mistral3 | T + I<sup>+</sup> | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `MllamaForConditionalGeneration` | Llama 3.2 | T + I<sup>+</sup> | `meta-llama/Llama-3.2-90B-Vision-Instruct`, `meta-llama/Llama-3.2-11B-Vision`, etc. | | | |
|
||||
| `MolmoForCausalLM` | Molmo | T + I<sup>+</sup> | `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `NVLM_D_Model` | NVLM-D 1.0 | T + I<sup>+</sup> | `nvidia/NVLM-D-72B`, etc. | ✅︎ | ✅︎ | |
|
||||
| `Ovis` | Ovis2, Ovis1.6 | T + I<sup>+</sup> | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | ✅︎ | | |
|
||||
| `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + I<sup>E</sup> | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | ✅︎ | ⚠️ | |
|
||||
| `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + I<sup>E+</sup> | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | ✅︎ | ✅︎ | |
|
||||
| `Phi4MMForCausalLM` | Phi-4-multimodal | T + I<sup>+</sup> / T + A<sup>+</sup> / I<sup>+</sup> + A<sup>+</sup> | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ | |
|
||||
| `PixtralForConditionalGeneration` | Pixtral | T + I<sup>+</sup> | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, `mistral-community/pixtral-12b`, etc. | ✅︎ | ✅︎ | |
|
||||
| `NVLM_D_Model` | NVLM-D 1.0 | T + I<sup>+</sup> | `nvidia/NVLM-D-72B`, etc. | | ✅︎ | ✅︎ |
|
||||
| `Ovis` | Ovis2, Ovis1.6 | T + I<sup>+</sup> | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | ✅︎ | ✅︎ |
|
||||
| `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + I<sup>E</sup> | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | | ✅︎ | ⚠️ |
|
||||
| `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + I<sup>E+</sup> | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ | ✅︎ |
|
||||
| `Phi4MMForCausalLM` | Phi-4-multimodal | T + I<sup>+</sup> / T + A<sup>+</sup> / I<sup>+</sup> + A<sup>+</sup> | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `PixtralForConditionalGeneration` | Pixtral | T + I<sup>+</sup> | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, `mistral-community/pixtral-12b`, etc. | | ✅︎ | ✅︎ |
|
||||
| `QwenVLForConditionalGeneration`<sup>^</sup> | Qwen-VL | T + I<sup>E+</sup> | `Qwen/Qwen-VL`, `Qwen/Qwen-VL-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Qwen2AudioForConditionalGeneration` | Qwen2-Audio | T + A<sup>+</sup> | `Qwen/Qwen2-Audio-7B-Instruct` | ✅︎ | ✅︎ | |
|
||||
| `Qwen2AudioForConditionalGeneration` | Qwen2-Audio | T + A<sup>+</sup> | `Qwen/Qwen2-Audio-7B-Instruct` | | ✅︎ | ✅︎ |
|
||||
| `Qwen2VLForConditionalGeneration` | QVQ, Qwen2-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Qwen2_5_VLForConditionalGeneration` | Qwen2.5-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen2.5-VL-3B-Instruct`, `Qwen/Qwen2.5-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup> | `Qwen/Qwen2.5-Omni-7B` | ✅︎ | ✅︎\* | |
|
||||
| `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | ✅︎ | ✅︎ | |
|
||||
| `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | ✅︎ | |
|
||||
| `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup> | `Qwen/Qwen2.5-Omni-7B` | | ✅︎ | ✅︎\* |
|
||||
| `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | ✅︎ |
|
||||
| `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | | ✅︎ |
|
||||
|
||||
<sup>^</sup> You need to set the architecture name via `--hf-overrides` to match the one in vLLM.
|
||||
• For example, to use DeepSeek-VL2 series models:
|
||||
@ -647,7 +647,7 @@ The following table lists those that are tested in vLLM.
|
||||
|
||||
| Architecture | Models | Inputs | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] |
|
||||
|-------------------------------------|--------------------|----------|--------------------------|------------------------|-----------------------------|
|
||||
| `LlavaNextForConditionalGeneration` | LLaVA-NeXT-based | T / I | `royokong/e5-v` | ✅︎ | |
|
||||
| `LlavaNextForConditionalGeneration` | LLaVA-NeXT-based | T / I | `royokong/e5-v` | | |
|
||||
| `Phi3VForCausalLM` | Phi-3-Vision-based | T + I | `TIGER-Lab/VLM2Vec-Full` | 🚧 | ✅︎ |
|
||||
|
||||
#### Transcription
|
||||
|
||||
@ -12,14 +12,14 @@ All communications between nodes in a multi-node vLLM deployment are **insecure
|
||||
|
||||
The following options control inter-node communications in vLLM:
|
||||
|
||||
1. **Environment Variables:**
|
||||
#### 1. **Environment Variables:**
|
||||
- `VLLM_HOST_IP`: Sets the IP address for vLLM processes to communicate on
|
||||
|
||||
2. **KV Cache Transfer Configuration:**
|
||||
#### 2. **KV Cache Transfer Configuration:**
|
||||
- `--kv-ip`: The IP address for KV cache transfer communications (default: 127.0.0.1)
|
||||
- `--kv-port`: The port for KV cache transfer communications (default: 14579)
|
||||
|
||||
3. **Data Parallel Configuration:**
|
||||
#### 3. **Data Parallel Configuration:**
|
||||
- `data_parallel_master_ip`: IP of the data parallel master (default: 127.0.0.1)
|
||||
- `data_parallel_master_port`: Port of the data parallel master (default: 29500)
|
||||
|
||||
@ -39,16 +39,16 @@ Key points from the PyTorch security guide:
|
||||
|
||||
### Security Recommendations
|
||||
|
||||
1. **Network Isolation:**
|
||||
#### 1. **Network Isolation:**
|
||||
- Deploy vLLM nodes on a dedicated, isolated network
|
||||
- Use network segmentation to prevent unauthorized access
|
||||
- Implement appropriate firewall rules
|
||||
|
||||
2. **Configuration Best Practices:**
|
||||
#### 2. **Configuration Best Practices:**
|
||||
- Always set `VLLM_HOST_IP` to a specific IP address rather than using defaults
|
||||
- Configure firewalls to only allow necessary ports between nodes
|
||||
|
||||
3. **Access Control:**
|
||||
#### 3. **Access Control:**
|
||||
- Restrict physical and network access to the deployment environment
|
||||
- Implement proper authentication and authorization for management interfaces
|
||||
- Follow the principle of least privilege for all system components
|
||||
|
||||
@ -97,10 +97,14 @@ def main(
|
||||
# with DP, each rank should process different prompts.
|
||||
# usually all the DP ranks process a full dataset,
|
||||
# and each rank processes a different part of the dataset.
|
||||
promts_per_rank = len(prompts) // dp_size
|
||||
start = global_dp_rank * promts_per_rank
|
||||
end = start + promts_per_rank
|
||||
prompts = prompts[start:end]
|
||||
floor = len(prompts) // dp_size
|
||||
remainder = len(prompts) % dp_size
|
||||
|
||||
# Distribute prompts into even groups.
|
||||
def start(rank):
|
||||
return rank * floor + min(rank, remainder)
|
||||
|
||||
prompts = prompts[start(global_dp_rank) : start(global_dp_rank + 1)]
|
||||
if len(prompts) == 0:
|
||||
# if any rank has no prompts to process,
|
||||
# we need to set a placeholder prompt
|
||||
|
||||
examples/offline_inference/neuron_multimodal.py (new file, 105 lines)
@ -0,0 +1,105 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import requests
|
||||
import torch
|
||||
from neuronx_distributed_inference.models.mllama.utils import add_instruct
|
||||
from PIL import Image
|
||||
|
||||
from vllm import LLM, SamplingParams, TextPrompt
|
||||
|
||||
|
||||
def get_image(image_url):
|
||||
image = Image.open(requests.get(image_url, stream=True).raw)
|
||||
return image
|
||||
|
||||
|
||||
# Model Inputs
|
||||
PROMPTS = [
|
||||
"What is in this image? Tell me a story",
|
||||
"What is the recipe of mayonnaise in two sentences?",
|
||||
"Describe this image",
|
||||
"What is the capital of Italy famous for?",
|
||||
]
|
||||
IMAGES = [
|
||||
get_image(
|
||||
"https://images.pexels.com/photos/1108099/pexels-photo-1108099.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=500"
|
||||
),
|
||||
None,
|
||||
get_image(
|
||||
"https://images.pexels.com/photos/1108099/pexels-photo-1108099.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=500"
|
||||
),
|
||||
None,
|
||||
]
|
||||
SAMPLING_PARAMS = [
|
||||
dict(top_k=1, temperature=1.0, top_p=1.0, max_tokens=16)
|
||||
for _ in range(len(PROMPTS))
|
||||
]
|
||||
|
||||
|
||||
def get_VLLM_mllama_model_inputs(prompt, single_image, sampling_params):
|
||||
# Prepare all inputs for mllama generation, including:
|
||||
# 1. put text prompt into instruct chat template
|
||||
# 2. compose single text and single image prompt into Vllm's prompt class
|
||||
# 3. prepare sampling parameters
|
||||
input_image = single_image
|
||||
has_image = torch.tensor([1])
|
||||
if isinstance(single_image, torch.Tensor) and single_image.numel() == 0:
|
||||
has_image = torch.tensor([0])
|
||||
|
||||
instruct_prompt = add_instruct(prompt, has_image)
|
||||
inputs = TextPrompt(prompt=instruct_prompt)
|
||||
|
||||
if input_image is not None:
|
||||
inputs["multi_modal_data"] = {"image": input_image}
|
||||
|
||||
sampling_params = SamplingParams(**sampling_params)
|
||||
return inputs, sampling_params
|
||||
|
||||
|
||||
def print_outputs(outputs):
|
||||
# Print the outputs.
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
assert (
|
||||
len(PROMPTS) == len(IMAGES) == len(SAMPLING_PARAMS)
|
||||
), f"""Text, image prompts and sampling parameters should have the
|
||||
same batch size; but got {len(PROMPTS)}, {len(IMAGES)},
|
||||
and {len(SAMPLING_PARAMS)}"""
|
||||
|
||||
# Create an LLM.
|
||||
llm = LLM(
|
||||
model="meta-llama/Llama-3.2-11B-Vision-Instruct",
|
||||
max_num_seqs=1,
|
||||
max_model_len=4096,
|
||||
block_size=4096,
|
||||
device="neuron",
|
||||
tensor_parallel_size=32,
|
||||
override_neuron_config={
|
||||
"sequence_parallel_enabled": False,
|
||||
"skip_warmup": True,
|
||||
"save_sharded_checkpoint": True,
|
||||
"on_device_sampling_config": {
|
||||
"global_topk": 1,
|
||||
"dynamic": False,
|
||||
"deterministic": False,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
batched_inputs = []
|
||||
batched_sample_params = []
|
||||
for pmpt, img, params in zip(PROMPTS, IMAGES, SAMPLING_PARAMS):
|
||||
inputs, sampling_params = get_VLLM_mllama_model_inputs(pmpt, img, params)
|
||||
# test batch-size = 1
|
||||
outputs = llm.generate(inputs, sampling_params)
|
||||
print_outputs(outputs)
|
||||
batched_inputs.append(inputs)
|
||||
batched_sample_params.append(sampling_params)
|
||||
|
||||
# test batch-size = 4
|
||||
outputs = llm.generate(batched_inputs, batched_sample_params)
|
||||
print_outputs(outputs)
|
||||
@ -8,7 +8,6 @@ requires = [
|
||||
"setuptools-scm>=8.0",
|
||||
"torch == 2.7.0",
|
||||
"wheel",
|
||||
"regex",
|
||||
"jinja2",
|
||||
]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
@ -14,7 +14,7 @@ protobuf # Required by LlamaTokenizer.
|
||||
fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint.
|
||||
aiohttp
|
||||
openai >= 1.52.0 # Ensure modern openai package (ensure types module present and max_completion_tokens field support)
|
||||
pydantic >= 2.9
|
||||
pydantic >= 2.10
|
||||
prometheus_client >= 0.18.0
|
||||
pillow # Required for image processing
|
||||
prometheus-fastapi-instrumentator >= 7.0.0
|
||||
|
||||
@ -51,3 +51,4 @@ numpy
|
||||
runai-model-streamer==0.11.0
|
||||
runai-model-streamer-s3==0.11.0
|
||||
fastsafetensors>=0.1.10
|
||||
pydantic>=2.10 # 2.9 leads to error on python 3.10
|
||||
@ -480,12 +480,13 @@ pycparser==2.22
|
||||
# via cffi
|
||||
pycryptodomex==3.22.0
|
||||
# via blobfile
|
||||
pydantic==2.9.2
|
||||
pydantic==2.11.5
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
# datamodel-code-generator
|
||||
# mistral-common
|
||||
# mteb
|
||||
pydantic-core==2.23.4
|
||||
pydantic-core==2.33.2
|
||||
# via pydantic
|
||||
pygments==2.18.0
|
||||
# via rich
|
||||
@ -784,6 +785,9 @@ typing-extensions==4.12.2
|
||||
# pydantic-core
|
||||
# torch
|
||||
# typer
|
||||
# typing-inspection
|
||||
typing-inspection==0.4.1
|
||||
# via pydantic
|
||||
tzdata==2024.2
|
||||
# via pandas
|
||||
uri-template==1.3.0
|
||||
|
||||
@ -18,9 +18,9 @@ setuptools==78.1.0
|
||||
--find-links https://storage.googleapis.com/libtpu-releases/index.html
|
||||
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
|
||||
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
|
||||
torch==2.8.0.dev20250518
|
||||
torchvision==0.22.0.dev20250518
|
||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250518-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
|
||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250518-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
|
||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250518-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
|
||||
torch==2.8.0.dev20250529
|
||||
torchvision==0.22.0.dev20250529
|
||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250529-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
|
||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250529-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
|
||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250529-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
|
||||
|
||||
|
||||
setup.py (2 lines changed)
@ -5,12 +5,12 @@ import importlib.util
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from shutil import which
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
from packaging.version import Version, parse
|
||||
from setuptools import Extension, setup
|
||||
|
||||
@ -60,7 +60,6 @@ def _fix_prompt_embed_outputs(
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("backend", ["FLASH_ATTN"])
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [5])
|
||||
@pytest.mark.parametrize("enforce_eager", [False])
|
||||
@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
|
||||
@ -69,7 +68,6 @@ def test_models(
|
||||
hf_runner,
|
||||
model: str,
|
||||
backend: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
enforce_eager: bool,
|
||||
enable_prompt_embeds: bool,
|
||||
@ -97,7 +95,7 @@ def test_models(
|
||||
str(i) for i in range(1024)) + " are:"
|
||||
example_prompts = [prompt]
|
||||
|
||||
with hf_runner(model, dtype=dtype) as hf_model:
|
||||
with hf_runner(model) as hf_model:
|
||||
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||
if enable_prompt_embeds:
|
||||
with torch.no_grad():
|
||||
@ -106,7 +104,6 @@ def test_models(
|
||||
|
||||
with VllmRunner(model,
|
||||
max_model_len=8192,
|
||||
dtype=dtype,
|
||||
enforce_eager=enforce_eager,
|
||||
enable_prompt_embeds=enable_prompt_embeds,
|
||||
gpu_memory_utilization=0.7) as vllm_model:
|
||||
|
||||
@ -74,11 +74,12 @@ class SillyModel(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
def test_simple_piecewise_compile():
|
||||
def _test_simple_piecewise_compile(*, use_inductor):
|
||||
|
||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
use_cudagraph=True,
|
||||
use_inductor=use_inductor,
|
||||
splitting_ops=["silly.attention"],
|
||||
cudagraph_copy_inputs=True,
|
||||
cudagraph_capture_sizes=[1, 2],
|
||||
@ -108,3 +109,11 @@ def test_simple_piecewise_compile():
|
||||
output = model(input)
|
||||
assert global_counter == 2
|
||||
assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))
|
||||
|
||||
|
||||
def test_simple_piecewise_compile_inductor():
|
||||
_test_simple_piecewise_compile(use_inductor=True)
|
||||
|
||||
|
||||
def test_simple_piecewise_compile_no_inductor():
|
||||
_test_simple_piecewise_compile(use_inductor=False)
|
||||
|
||||
@ -261,12 +261,14 @@ def tractable_computation(input_ids: torch.Tensor,
|
||||
@torch.inference_mode
|
||||
def run_model(llama_config,
|
||||
use_compile: bool,
|
||||
use_inductor: bool,
|
||||
split_attn: bool = False) -> torch.Tensor:
|
||||
|
||||
if use_compile:
|
||||
compilation_config = CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
use_cudagraph=True,
|
||||
use_inductor=use_inductor,
|
||||
cudagraph_capture_sizes=[1, 2],
|
||||
)
|
||||
if split_attn:
|
||||
@ -304,7 +306,7 @@ def run_model(llama_config,
|
||||
return output.cpu()
|
||||
|
||||
|
||||
def test_toy_llama():
|
||||
def _test_toy_llama(*, use_inductor):
|
||||
# compare output with and without piecewise compilation
|
||||
|
||||
llama_config = LlamaConfig(hidden_size=128,
|
||||
@ -326,8 +328,14 @@ def test_toy_llama():
|
||||
num_backend_compilations=0,
|
||||
num_cudagraph_caputured=0,
|
||||
):
|
||||
outputs.append(run_model(llama_config, use_compile=False))
|
||||
run_model(tractable_config, use_compile=False)
|
||||
outputs.append(
|
||||
run_model(llama_config, use_inductor=False, use_compile=False))
|
||||
run_model(tractable_config, use_inductor=False, use_compile=False)
|
||||
|
||||
if use_inductor:
|
||||
kwargs = {"num_inductor_compiles": 1, "num_eager_compiles": 0}
|
||||
else:
|
||||
kwargs = {"num_eager_compiles": 1, "num_inductor_compiles": 0}
|
||||
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=1, # one graph for the model
|
||||
@ -336,9 +344,13 @@ def test_toy_llama():
|
||||
num_backend_compilations=1, # num_piecewise_capturable_graphs_seen
|
||||
num_cudagraph_caputured=
|
||||
2, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
**kwargs,
|
||||
):
|
||||
outputs.append(run_model(llama_config, use_compile=True))
|
||||
run_model(tractable_config, use_compile=True)
|
||||
outputs.append(
|
||||
run_model(llama_config,
|
||||
use_inductor=use_inductor,
|
||||
use_compile=True))
|
||||
run_model(tractable_config, use_inductor=use_inductor, use_compile=True)
|
||||
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=1, # one graph for the model
|
||||
@ -353,13 +365,27 @@ def test_toy_llama():
|
||||
), # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
):
|
||||
outputs.append(
|
||||
run_model(llama_config, use_compile=True, split_attn=True))
|
||||
run_model(tractable_config, use_compile=True, split_attn=True)
|
||||
run_model(llama_config,
|
||||
use_inductor=use_inductor,
|
||||
use_compile=True,
|
||||
split_attn=True))
|
||||
run_model(tractable_config,
|
||||
use_inductor=use_inductor,
|
||||
use_compile=True,
|
||||
split_attn=True)
|
||||
|
||||
for i in range(1, len(outputs)):
|
||||
assert torch.allclose(outputs[0], outputs[i])
|
||||
|
||||
|
||||
def test_toy_llama_inductor():
|
||||
_test_toy_llama(use_inductor=True)
|
||||
|
||||
|
||||
def test_toy_no_inductor():
|
||||
_test_toy_llama(use_inductor=False)
|
||||
|
||||
|
||||
@torch.inference_mode
|
||||
def benchmark():
|
||||
from triton.testing import do_bench
|
||||
|
||||
@ -311,6 +311,7 @@ class HfRunner:
|
||||
dtype: str = "auto",
|
||||
*,
|
||||
model_kwargs: Optional[dict[str, Any]] = None,
|
||||
trust_remote_code: bool = True,
|
||||
is_sentence_transformer: bool = False,
|
||||
is_cross_encoder: bool = False,
|
||||
skip_tokenizer_init: bool = False,
|
||||
@ -320,10 +321,15 @@ class HfRunner:
|
||||
|
||||
self.config = AutoConfig.from_pretrained(
|
||||
model_name,
|
||||
trust_remote_code=True,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
self.device = self.get_default_device()
|
||||
self.dtype = torch_dtype = _get_and_verify_dtype(self.config, dtype)
|
||||
self.dtype = torch_dtype = _get_and_verify_dtype(
|
||||
self.model_name,
|
||||
self.config,
|
||||
dtype=dtype,
|
||||
is_pooling_model=is_sentence_transformer or is_cross_encoder,
|
||||
)
|
||||
|
||||
model_kwargs = model_kwargs if model_kwargs is not None else {}
|
||||
model_kwargs.setdefault("torch_dtype", torch_dtype)
|
||||
@ -336,7 +342,7 @@ class HfRunner:
|
||||
model_name,
|
||||
device=self.device,
|
||||
model_kwargs=model_kwargs,
|
||||
trust_remote_code=True,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
elif is_cross_encoder:
|
||||
# Lazy init required for AMD CI
|
||||
@ -346,12 +352,12 @@ class HfRunner:
|
||||
model_name,
|
||||
device=self.device,
|
||||
automodel_args=model_kwargs,
|
||||
trust_remote_code=True,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
else:
|
||||
model = auto_cls.from_pretrained(
|
||||
model_name,
|
||||
trust_remote_code=True,
|
||||
trust_remote_code=trust_remote_code,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
@ -372,7 +378,7 @@ class HfRunner:
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch_dtype,
|
||||
trust_remote_code=True,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
# don't put this import at the top level
|
||||
@ -381,7 +387,7 @@ class HfRunner:
|
||||
self.processor = AutoProcessor.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch_dtype,
|
||||
trust_remote_code=True,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
if skip_tokenizer_init:
|
||||
self.tokenizer = self.processor.tokenizer
|
||||
|
||||
@ -227,6 +227,7 @@ MULTIMODAL_MODELS = {
|
||||
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf": PPTestSettings.fast(),
|
||||
"openbmb/MiniCPM-Llama3-V-2_5": PPTestSettings.fast(),
|
||||
"allenai/Molmo-7B-D-0924": PPTestSettings.fast(),
|
||||
"AIDC-AI/Ovis2-1B": PPTestSettings.fast(),
|
||||
"microsoft/Phi-3.5-vision-instruct": PPTestSettings.fast(),
|
||||
"mistralai/Pixtral-12B-2409": PPTestSettings.fast(load_format="dummy"),
|
||||
"Qwen/Qwen-VL-Chat": PPTestSettings.fast(),
|
||||
|
||||
@ -1,24 +0,0 @@
# SPDX-License-Identifier: Apache-2.0

import pytest

from vllm import LLM

from ...utils import error_on_warning

MODEL_NAME = "facebook/opt-125m"


def test_pos_args_deprecated():
with error_on_warning(DeprecationWarning):
LLM(model=MODEL_NAME, tokenizer=MODEL_NAME)

with error_on_warning(DeprecationWarning):
LLM(MODEL_NAME, tokenizer=MODEL_NAME)

with pytest.warns(DeprecationWarning, match="'tokenizer'"):
LLM(MODEL_NAME, MODEL_NAME)

with pytest.warns(DeprecationWarning,
match="'tokenizer', 'tokenizer_mode'"):
LLM(MODEL_NAME, MODEL_NAME, "auto")
@ -4,6 +4,8 @@ import json
import subprocess
import tempfile

import pytest

from vllm.entrypoints.openai.protocol import BatchRequestOutput

# ruff: noqa: E501
@ -23,9 +25,13 @@ INPUT_EMBEDDING_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "
{"custom_id": "request-3", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/multilingual-e5-small", "input": "Hello world!"}}
{"custom_id": "request-4", "method": "POST", "url": "/v1/embeddings", "body": {"model": "NonExistModel", "input": "Hello world!"}}"""

INPUT_SCORE_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/v1/score", "body": {"model": "BAAI/bge-reranker-v2-m3", "text_1": "What is the capital of France?", "text_2": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}}
INPUT_SCORE_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/score", "body": {"model": "BAAI/bge-reranker-v2-m3", "text_1": "What is the capital of France?", "text_2": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/score", "body": {"model": "BAAI/bge-reranker-v2-m3", "text_1": "What is the capital of France?", "text_2": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}}"""

INPUT_RERANK_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/rerank", "body": {"model": "BAAI/bge-reranker-v2-m3", "query": "What is the capital of France?", "documents": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/rerank", "body": {"model": "BAAI/bge-reranker-v2-m3", "query": "What is the capital of France?", "documents": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}}
{"custom_id": "request-2", "method": "POST", "url": "/v2/rerank", "body": {"model": "BAAI/bge-reranker-v2-m3", "query": "What is the capital of France?", "documents": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}}"""


def test_empty_file():
with tempfile.NamedTemporaryFile(
@ -105,11 +111,13 @@ def test_embeddings():
BatchRequestOutput.model_validate_json(line)


def test_score():
@pytest.mark.parametrize("input_batch",
[INPUT_SCORE_BATCH, INPUT_RERANK_BATCH])
def test_score(input_batch):
with tempfile.NamedTemporaryFile(
"w") as input_file, tempfile.NamedTemporaryFile(
"r") as output_file:
input_file.write(INPUT_SCORE_BATCH)
input_file.write(input_batch)
input_file.flush()
proc = subprocess.Popen([
"vllm",

@ -76,11 +76,11 @@ async def test_tokenize_completions(
|
||||
})
|
||||
response.raise_for_status()
|
||||
|
||||
assert response.json() == {
|
||||
"tokens": tokens,
|
||||
"count": len(tokens),
|
||||
"max_model_len": 8192
|
||||
}
|
||||
result = response.json()
|
||||
assert result["tokens"] == tokens
|
||||
assert result["count"] == len(tokens)
|
||||
assert result["max_model_len"] == 8192
|
||||
assert result["token_strs"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -138,11 +138,11 @@ async def test_tokenize_chat(
|
||||
})
|
||||
response.raise_for_status()
|
||||
|
||||
assert response.json() == {
|
||||
"tokens": tokens,
|
||||
"count": len(tokens),
|
||||
"max_model_len": 8192
|
||||
}
|
||||
result = response.json()
|
||||
assert result["tokens"] == tokens
|
||||
assert result["count"] == len(tokens)
|
||||
assert result["max_model_len"] == 8192
|
||||
assert result["token_strs"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -215,11 +215,46 @@ async def test_tokenize_chat_with_tools(
)
response.raise_for_status()

assert response.json() == {
"tokens": tokens,
"count": len(tokens),
"max_model_len": 8192,
}
result = response.json()
assert result["tokens"] == tokens
assert result["count"] == len(tokens)
assert result["max_model_len"] == 8192
assert result["token_strs"] is None


@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name, tokenizer_name",
[(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
indirect=["tokenizer_name"],
)
async def test_tokenize_with_return_token_strs(
server: RemoteOpenAIServer,
model_name: str,
tokenizer_name: str,
):
tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
tokenizer_mode="fast")

prompt = "This is a token_strs test prompt! vllm1"
response = requests.post(
server.url_for("tokenize"),
json={
"prompt": prompt,
"model": model_name,
"return_token_strs": True
},
)
response.raise_for_status()

tokens = tokenizer.encode(prompt, add_special_tokens=True)
tokens_str = tokenizer.convert_ids_to_tokens(tokens)

result = response.json()
assert result["tokens"] == tokens
assert result["count"] == len(tokens)
assert result["max_model_len"] == 8192
assert result["token_strs"] == tokens_str


@pytest.mark.asyncio

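For reference, the new `return_token_strs` flag exercised above can be hit directly against a running server. The snippet below is only an illustrative client call mirroring the test payload; host, port, and model name are placeholders:

```python
# Illustrative request mirroring the test above; values are placeholders.
import requests

resp = requests.post(
    "http://localhost:8000/tokenize",
    json={
        "model": "my-model",  # placeholder model name
        "prompt": "This is a token_strs test prompt! vllm1",
        "return_token_strs": True,
    },
)
resp.raise_for_status()
data = resp.json()
# The response carries the token ids, their count, the model context length,
# and (with the flag set) the per-token strings.
print(data["tokens"], data["count"], data["max_model_len"], data["token_strs"])
```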
@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@ -191,3 +191,27 @@ def test_streaming_tool_call_with_large_steps():
|
||||
assert reconstructor.tool_calls[0].function == SIMPLE_FUNCTION_CALL
|
||||
assert reconstructor.tool_calls[1].function == PARAMETERLESS_FUNCTION_CALL
|
||||
assert reconstructor.tool_calls[2].function == EMPTY_LIST_FUNCTION_CALL
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming", [False])
|
||||
def test_regex_timeout_handling(streaming: bool):
|
||||
"""test regex timeout is handled gracefully"""
|
||||
mock_tokenizer = MagicMock()
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser(
|
||||
"llama4_pythonic")(mock_tokenizer)
|
||||
|
||||
fake_problematic_input = "hello world[A(A=" + "\t)A(A=,\t" * 2
|
||||
|
||||
# create a mock regex that raises TimeoutError
|
||||
mock_regex = MagicMock()
|
||||
mock_regex.match.side_effect = TimeoutError("Regex timeout")
|
||||
|
||||
with patch.object(tool_parser, 'TOOL_CALL_REGEX', mock_regex):
|
||||
content, tool_calls = run_tool_extraction(tool_parser,
|
||||
fake_problematic_input,
|
||||
streaming=streaming)
|
||||
|
||||
# should treat as regular text when regex times out
|
||||
assert content == fake_problematic_input
|
||||
assert len(tool_calls) == 0
|
||||
mock_regex.match.assert_called_once()
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@ -159,3 +159,27 @@ def test_streaming_tool_call_with_large_steps():
|
||||
assert reconstructor.tool_calls[0].function == SIMPLE_FUNCTION_CALL
|
||||
assert reconstructor.tool_calls[1].function == PARAMETERLESS_FUNCTION_CALL
|
||||
assert reconstructor.tool_calls[2].function == EMPTY_LIST_FUNCTION_CALL
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming", [False])
|
||||
def test_regex_timeout_handling(streaming: bool):
|
||||
"""test regex timeout is handled gracefully"""
|
||||
mock_tokenizer = MagicMock()
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser(
|
||||
"llama4_pythonic")(mock_tokenizer)
|
||||
|
||||
fake_problematic_input = "hello world[A(A=" + "\t)A(A=,\t" * 2
|
||||
|
||||
# create a mock regex that raises TimeoutError
|
||||
mock_regex = MagicMock()
|
||||
mock_regex.match.side_effect = TimeoutError("Regex timeout")
|
||||
|
||||
with patch.object(tool_parser, 'TOOL_CALL_REGEX', mock_regex):
|
||||
content, tool_calls = run_tool_extraction(tool_parser,
|
||||
fake_problematic_input,
|
||||
streaming=streaming)
|
||||
|
||||
# should treat as regular text when regex times out
|
||||
assert content == fake_problematic_input
|
||||
assert len(tool_calls) == 0
|
||||
mock_regex.match.assert_called_once()
|
||||
|
||||
268
tests/entrypoints/test_api_server_process_manager.py
Normal file
@ -0,0 +1,268 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import multiprocessing
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
from typing import Optional
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.v1.utils import (APIServerProcessManager,
|
||||
wait_for_completion_or_failure)
|
||||
|
||||
# Global variables to control worker behavior
|
||||
WORKER_RUNTIME_SECONDS = 0.5
|
||||
|
||||
|
||||
# Mock implementation of run_api_server_worker
|
||||
def mock_run_api_server_worker(listen_address, sock, args, client_config=None):
|
||||
"""Mock run_api_server_worker that runs for a specific time."""
|
||||
print(f"Mock worker started with client_config: {client_config}")
|
||||
time.sleep(WORKER_RUNTIME_SECONDS)
|
||||
print("Mock worker completed successfully")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_server_args():
|
||||
"""Fixture to provide arguments for APIServerProcessManager."""
|
||||
sock = socket.socket()
|
||||
return {
|
||||
"target_server_fn":
|
||||
mock_run_api_server_worker,
|
||||
"listen_address":
|
||||
"localhost:8000",
|
||||
"sock":
|
||||
sock,
|
||||
"args":
|
||||
"test_args", # Simple string to avoid pickling issues
|
||||
"num_servers":
|
||||
3,
|
||||
"input_addresses": [
|
||||
"tcp://127.0.0.1:5001", "tcp://127.0.0.1:5002",
|
||||
"tcp://127.0.0.1:5003"
|
||||
],
|
||||
"output_addresses": [
|
||||
"tcp://127.0.0.1:6001", "tcp://127.0.0.1:6002",
|
||||
"tcp://127.0.0.1:6003"
|
||||
],
|
||||
"stats_update_address":
|
||||
"tcp://127.0.0.1:7000",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("with_stats_update", [True, False])
|
||||
def test_api_server_process_manager_init(api_server_args, with_stats_update):
|
||||
"""Test initializing the APIServerProcessManager."""
|
||||
# Set the worker runtime to ensure tests complete in reasonable time
|
||||
global WORKER_RUNTIME_SECONDS
|
||||
WORKER_RUNTIME_SECONDS = 0.5
|
||||
|
||||
# Copy the args to avoid mutating the fixture
|
||||
args = api_server_args.copy()
|
||||
|
||||
if not with_stats_update:
|
||||
args.pop("stats_update_address")
|
||||
manager = APIServerProcessManager(**args)
|
||||
|
||||
try:
|
||||
# Verify the manager was initialized correctly
|
||||
assert len(manager.processes) == 3
|
||||
|
||||
# Verify all processes are running
|
||||
for proc in manager.processes:
|
||||
assert proc.is_alive()
|
||||
|
||||
print("Waiting for processes to run...")
|
||||
time.sleep(WORKER_RUNTIME_SECONDS / 2)
|
||||
|
||||
# They should still be alive at this point
|
||||
for proc in manager.processes:
|
||||
assert proc.is_alive()
|
||||
|
||||
finally:
|
||||
# Always clean up the processes
|
||||
print("Cleaning up processes...")
|
||||
manager.close()
|
||||
|
||||
# Give processes time to terminate
|
||||
time.sleep(0.2)
|
||||
|
||||
# Verify all processes were terminated
|
||||
for proc in manager.processes:
|
||||
assert not proc.is_alive()
|
||||
|
||||
|
||||
@patch("vllm.entrypoints.cli.serve.run_api_server_worker",
|
||||
mock_run_api_server_worker)
|
||||
def test_wait_for_completion_or_failure(api_server_args):
|
||||
"""Test that wait_for_completion_or_failure works with failures."""
|
||||
global WORKER_RUNTIME_SECONDS
|
||||
WORKER_RUNTIME_SECONDS = 1.0
|
||||
|
||||
# Create the manager
|
||||
manager = APIServerProcessManager(**api_server_args)
|
||||
|
||||
try:
|
||||
assert len(manager.processes) == 3
|
||||
|
||||
# Create a result capture for the thread
|
||||
result: dict[str, Optional[Exception]] = {"exception": None}
|
||||
|
||||
def run_with_exception_capture():
|
||||
try:
|
||||
wait_for_completion_or_failure(api_server_manager=manager)
|
||||
except Exception as e:
|
||||
result["exception"] = e
|
||||
|
||||
# Start a thread to run wait_for_completion_or_failure
|
||||
wait_thread = threading.Thread(target=run_with_exception_capture,
|
||||
daemon=True)
|
||||
wait_thread.start()
|
||||
|
||||
# Let all processes run for a short time
|
||||
time.sleep(0.2)
|
||||
|
||||
# All processes should still be running
|
||||
assert all(proc.is_alive() for proc in manager.processes)
|
||||
|
||||
# Now simulate a process failure
|
||||
print("Simulating process failure...")
|
||||
manager.processes[0].terminate()
|
||||
|
||||
# Wait for the wait_for_completion_or_failure
|
||||
# to detect and handle the failure
|
||||
# This should trigger it to terminate all other processes
|
||||
wait_thread.join(timeout=1.0)
|
||||
|
||||
# The wait thread should have exited
|
||||
assert not wait_thread.is_alive()
|
||||
|
||||
# Verify that an exception was raised with appropriate error message
|
||||
assert result["exception"] is not None
|
||||
assert "died with exit code" in str(result["exception"])
|
||||
|
||||
# All processes should now be terminated
|
||||
for i, proc in enumerate(manager.processes):
|
||||
assert not proc.is_alive(), f"Process {i} should not be alive"
|
||||
|
||||
finally:
|
||||
manager.close()
|
||||
time.sleep(0.2)
|
||||
|
||||
|
||||
@pytest.mark.timeout(30)
|
||||
def test_normal_completion(api_server_args):
|
||||
"""Test that wait_for_completion_or_failure works in normal completion."""
|
||||
global WORKER_RUNTIME_SECONDS
|
||||
WORKER_RUNTIME_SECONDS = 0.1
|
||||
|
||||
# Create the manager
|
||||
manager = APIServerProcessManager(**api_server_args)
|
||||
|
||||
try:
|
||||
# Give processes time to terminate
|
||||
# wait for processes to complete
|
||||
remaining_processes = manager.processes.copy()
|
||||
while remaining_processes:
|
||||
for proc in remaining_processes:
|
||||
if not proc.is_alive():
|
||||
remaining_processes.remove(proc)
|
||||
time.sleep(0.1)
|
||||
|
||||
# Verify all processes have terminated
|
||||
for i, proc in enumerate(manager.processes):
|
||||
assert not proc.is_alive(
|
||||
), f"Process {i} still alive after terminate()"
|
||||
|
||||
# Now call wait_for_completion_or_failure
|
||||
# since all processes have already
|
||||
# terminated, it should return immediately
|
||||
# with no error
|
||||
wait_for_completion_or_failure(api_server_manager=manager)
|
||||
|
||||
finally:
|
||||
# Clean up just in case
|
||||
manager.close()
|
||||
time.sleep(0.2)
|
||||
|
||||
|
||||
@pytest.mark.timeout(30)
|
||||
def test_external_process_monitoring(api_server_args):
|
||||
"""Test that wait_for_completion_or_failure handles additional processes."""
|
||||
global WORKER_RUNTIME_SECONDS
|
||||
WORKER_RUNTIME_SECONDS = 100
|
||||
|
||||
# Create and start the external process
|
||||
# (simulates local_engine_manager or coordinator)
|
||||
spawn_context = multiprocessing.get_context("spawn")
|
||||
external_proc = spawn_context.Process(target=mock_run_api_server_worker,
|
||||
name="MockExternalProcess")
|
||||
external_proc.start()
|
||||
|
||||
# Create the class to simulate a coordinator
|
||||
class MockCoordinator:
|
||||
|
||||
def __init__(self, proc):
|
||||
self.proc = proc
|
||||
|
||||
def close(self):
|
||||
if self.proc.is_alive():
|
||||
self.proc.terminate()
|
||||
self.proc.join(timeout=0.5)
|
||||
|
||||
# Create a mock coordinator with the external process
|
||||
mock_coordinator = MockCoordinator(external_proc)
|
||||
|
||||
# Create the API server manager
|
||||
manager = APIServerProcessManager(**api_server_args)
|
||||
|
||||
try:
|
||||
# Verify manager initialization
|
||||
assert len(manager.processes) == 3
|
||||
|
||||
# Create a result capture for the thread
|
||||
result: dict[str, Optional[Exception]] = {"exception": None}
|
||||
|
||||
def run_with_exception_capture():
|
||||
try:
|
||||
wait_for_completion_or_failure(api_server_manager=manager,
|
||||
coordinator=mock_coordinator)
|
||||
except Exception as e:
|
||||
result["exception"] = e
|
||||
|
||||
# Start a thread to run wait_for_completion_or_failure
|
||||
wait_thread = threading.Thread(target=run_with_exception_capture,
|
||||
daemon=True)
|
||||
wait_thread.start()
|
||||
|
||||
# Terminate the external process to trigger a failure
|
||||
time.sleep(0.2)
|
||||
external_proc.terminate()
|
||||
|
||||
# Wait for the thread to detect the failure
|
||||
wait_thread.join(timeout=1.0)
|
||||
|
||||
# The wait thread should have completed
|
||||
assert not wait_thread.is_alive(
|
||||
), "wait_for_completion_or_failure thread still running"
|
||||
|
||||
# Verify that an exception was raised with appropriate error message
|
||||
assert result["exception"] is not None, "No exception was raised"
|
||||
error_message = str(result["exception"])
|
||||
assert "died with exit code" in error_message, \
|
||||
f"Unexpected error message: {error_message}"
|
||||
assert "MockExternalProcess" in error_message, \
|
||||
f"Error doesn't mention external process: {error_message}"
|
||||
|
||||
# Verify that all API server processes were terminated as a result
|
||||
for i, proc in enumerate(manager.processes):
|
||||
assert not proc.is_alive(
|
||||
), f"API server process {i} was not terminated"
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
manager.close()
|
||||
mock_coordinator.close()
|
||||
time.sleep(0.2)
|
||||
@ -70,7 +70,7 @@ def test_rotary_embedding(
device: str,
use_key: bool,
max_position: int = 8192,
base: int = 10000,
base: float = 10000,
) -> None:
if rotary_dim is None:
rotary_dim = head_size
@ -135,7 +135,7 @@ def test_batched_rotary_embedding(
device: str,
use_key: bool,
max_position: int = 8192,
base: int = 10000,
base: float = 10000,
) -> None:
current_platform.seed_everything(seed)
torch.set_default_device(device)
@ -203,7 +203,7 @@ def test_batched_rotary_embedding_multi_lora(
device: str,
use_key: bool,
max_position: int = 8192,
base: int = 10000,
base: float = 10000,
) -> None:
current_platform.seed_everything(seed)
torch.set_default_device(device)

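The only change in these hunks is widening the `base` parameter from `int` to `float`. That matches how the rotary frequencies are derived: the base enters a fractional power, so non-integer values are equally valid. A hedged sketch of the standard computation, shown only to illustrate the point and not copied from vLLM's kernels:

```python
# Sketch of the standard RoPE inverse-frequency computation, illustrating
# why `base` is naturally a float; not taken from vLLM's implementation.
import torch

def rope_inv_freq(rotary_dim: int, base: float = 10000.0) -> torch.Tensor:
    exponents = torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim
    return 1.0 / (base ** exponents)

print(rope_inv_freq(8, base=10000.0))
print(rope_inv_freq(8, base=500000.5))  # fractional bases behave the same way
```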
@ -35,6 +35,15 @@ def test_rocm_aiter_biased_grouped_topk_custom_op_registration():
|
||||
assert callable(torch.ops.vllm.rocm_aiter_biased_grouped_topk)
|
||||
|
||||
|
||||
def test_rocm_aiter_grouped_topk_custom_op_registration():
|
||||
"""Test that the custom op is correctly registered."""
|
||||
# Check if the op exists in torch.ops.vllm
|
||||
assert hasattr(torch.ops.vllm, 'rocm_aiter_grouped_topk')
|
||||
|
||||
# Check if the op is callable
|
||||
assert callable(torch.ops.vllm.rocm_aiter_grouped_topk)
|
||||
|
||||
|
||||
def test_rocm_aiter_biased_grouped_topk_torch_compile_compatibility():
|
||||
"""Test that the op can be used with torch.compile."""
|
||||
# Create test tensors
|
||||
@ -120,3 +129,87 @@ def test_rocm_aiter_biased_grouped_topk_torch_compile_compatibility():
|
||||
rtol=1e-2,
|
||||
atol=1e-2)
|
||||
assert torch.allclose(topk_ids_original, topk_ids_compiled)
|
||||
|
||||
|
||||
def test_rocm_aiter_grouped_topk_torch_compile_compatibility():
|
||||
"""Test that the op can be used with torch.compile."""
|
||||
# Create test tensors
|
||||
token = 64
|
||||
expert = 256
|
||||
num_expert_group = 8
|
||||
topk = 8
|
||||
topk_group = 4
|
||||
renormalize = True
|
||||
scoring_func = "softmax"
|
||||
scale_factor = 1.0
|
||||
|
||||
gating_output = torch.randn((token, expert),
|
||||
dtype=torch.bfloat16,
|
||||
device="cuda")
|
||||
|
||||
device = gating_output.device
|
||||
topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
|
||||
topk_weights = torch.empty((token, topk),
|
||||
dtype=torch.float32,
|
||||
device=device)
|
||||
|
||||
# Define a function that uses the op
|
||||
def grouped_topk_fn(gating_output, topk_weights, topk_ids, scoring_func):
|
||||
return torch.ops.vllm.rocm_aiter_grouped_topk(
|
||||
gating_output, topk_weights, topk_ids, num_expert_group,
|
||||
topk_group, renormalize, scoring_func, scale_factor)
|
||||
|
||||
# Verify the op's fake implementation
|
||||
torch.library.opcheck(torch.ops.vllm.rocm_aiter_grouped_topk,
|
||||
(gating_output, topk_weights, topk_ids),
|
||||
kwargs={
|
||||
"num_expert_group": num_expert_group,
|
||||
"topk_group": topk_group,
|
||||
"need_renorm": renormalize,
|
||||
"scoring_func": scoring_func,
|
||||
"routed_scaling_factor": scale_factor
|
||||
},
|
||||
test_utils=("test_faketensor"))
|
||||
|
||||
# Compile the function with appropriate settings
|
||||
compiled_fn = torch.compile(grouped_topk_fn,
|
||||
fullgraph=True,
|
||||
backend="inductor",
|
||||
mode="reduce-overhead",
|
||||
dynamic=False)
|
||||
|
||||
topk_weights_original = torch.empty((token, topk),
|
||||
dtype=torch.float32,
|
||||
device=device)
|
||||
topk_ids_original = torch.empty((token, topk),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
topk_weights_compiled = torch.empty((token, topk),
|
||||
dtype=torch.float32,
|
||||
device=device)
|
||||
topk_ids_compiled = torch.empty((token, topk),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
# Run both compiled (V1 graph mode) and uncompiled versions (V1 eager mode)
|
||||
grouped_topk_fn(gating_output, topk_weights_original, topk_ids_original,
|
||||
scoring_func)
|
||||
compiled_fn(gating_output, topk_weights_compiled, topk_ids_compiled,
|
||||
scoring_func)
|
||||
|
||||
# Sort the results for comparison since the order might not be deterministic
|
||||
topk_ids_original, indices_original = torch.sort(topk_ids_original)
|
||||
topk_weights_original = torch.gather(topk_weights_original, 1,
|
||||
indices_original)
|
||||
|
||||
topk_ids_compiled, indices_compiled = torch.sort(topk_ids_compiled)
|
||||
topk_weights_compiled = torch.gather(topk_weights_compiled, 1,
|
||||
indices_compiled)
|
||||
|
||||
# Verify results match
|
||||
assert torch.allclose(topk_weights_original,
|
||||
topk_weights_compiled,
|
||||
rtol=1e-2,
|
||||
atol=1e-2)
|
||||
assert torch.allclose(topk_ids_original, topk_ids_compiled)
|
||||
|
||||
@ -8,7 +8,7 @@ from vllm.platforms import current_platform

# Using the default value (240.0) from pytorch will cause accuracy
# issue on dynamic quantization models. Here use 224.0 for rocm.
ROCM_FP8_MAX = 224.0
ROCM_FP8FNUZ_MAX = 224.0
FP8_DTYPE = current_platform.fp8_dtype()


@ -26,9 +26,11 @@ def ref_dynamic_per_token_quant(x: torch.tensor,

qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \
else torch.finfo(quant_dtype)
qtype_traits_max = ROCM_FP8_MAX if current_platform.is_rocm() \
qtype_traits_max = ROCM_FP8FNUZ_MAX if current_platform.is_rocm() \
and current_platform.is_fp8_fnuz() \
else qtype_traits.max
qtype_traits_min = -ROCM_FP8_MAX if current_platform.is_rocm() \
qtype_traits_min = -ROCM_FP8FNUZ_MAX if current_platform.is_rocm() \
and current_platform.is_fp8_fnuz() \
else qtype_traits.min
qtype_max = as_float32_tensor(qtype_traits_max)
s_1 = as_float32_tensor(1.0)
@ -70,9 +72,11 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
-> tuple[torch.tensor, torch.tensor]:

fp8_traits = torch.finfo(FP8_DTYPE)
fp8_traits_max = ROCM_FP8_MAX if current_platform.is_rocm() \
fp8_traits_max = ROCM_FP8FNUZ_MAX if current_platform.is_rocm() \
and current_platform.is_fp8_fnuz() \
else fp8_traits.max
fp8_traits_min = -ROCM_FP8_MAX if current_platform.is_rocm() \
fp8_traits_min = -ROCM_FP8FNUZ_MAX if current_platform.is_rocm() \
and current_platform.is_fp8_fnuz() \
else fp8_traits.min
fp8_max = as_float32_tensor(fp8_traits_max)
one = as_float32_tensor(1.0)

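Context for the rename above: on ROCm the fp8 type in play is the "fnuz" variant, whose representable maximum differs from the OCP `float8_e4m3fn` type, and the reference quantization deliberately clamps below it (224.0 rather than 240.0) for accuracy. A quick check of the dtype ranges, assuming a PyTorch build that exposes both fp8 dtypes:

```python
# Assumes a PyTorch build with fp8 dtypes available; shown only to
# illustrate the range difference the constant rename refers to.
import torch

print(torch.finfo(torch.float8_e4m3fn).max)    # 448.0 (OCP e4m3)
print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0 (fnuz variant)
# The tests clamp to 224.0 on fnuz hardware instead of using the full 240.0.
```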
@ -102,21 +102,18 @@ def mteb_test_embed_models(hf_runner,
|
||||
vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model),
|
||||
MTEB_EMBED_TASKS)
|
||||
vllm_dtype = vllm_model.model.llm_engine.model_config.dtype
|
||||
model_dtype = getattr(
|
||||
vllm_model.model.llm_engine.model_config.hf_config, "torch_dtype",
|
||||
vllm_dtype)
|
||||
|
||||
with set_default_torch_dtype(model_dtype) and hf_runner(
|
||||
with set_default_torch_dtype(vllm_dtype) and hf_runner(
|
||||
model_info.name, is_sentence_transformer=True,
|
||||
dtype=model_dtype) as hf_model:
|
||||
dtype=vllm_dtype) as hf_model:
|
||||
|
||||
if hf_model_callback is not None:
|
||||
hf_model_callback(hf_model)
|
||||
|
||||
st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS)
|
||||
|
||||
print("VLLM:", vllm_dtype, vllm_main_score)
|
||||
print("SentenceTransformer:", model_dtype, st_main_score)
|
||||
print("VLLM:", vllm_main_score)
|
||||
print("SentenceTransformers:", st_main_score)
|
||||
print("Difference:", st_main_score - vllm_main_score)
|
||||
|
||||
assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_EMBED_TOL)
|
||||
|
||||
@ -43,6 +43,6 @@ def test_models(
|
||||
|
||||
# the tolerance value of 1e-2 is selected based on the
|
||||
# half datatype tests in
|
||||
# tests/models/embedding/language/test_embedding.py
|
||||
# tests/models/language/pooling/test_embedding.py
|
||||
assert torch.allclose(hf_output, vllm_output,
|
||||
1e-3 if dtype == "float" else 1e-2)
|
||||
|
||||
@ -10,29 +10,31 @@ from ...utils import check_embeddings_close
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
[
|
||||
# [Encoder-only]
|
||||
pytest.param("BAAI/bge-base-en-v1.5",
|
||||
marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
|
||||
pytest.param("sentence-transformers/all-MiniLM-L12-v2"),
|
||||
pytest.param("intfloat/multilingual-e5-small"),
|
||||
pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct"),
|
||||
# Be careful of the order of models, decoder-only models should be
|
||||
# placed before encoder-only models, otherwise `Qwen2.5-0.5B-Instruct`
|
||||
# case won't pass because gte-Qwen2-1.5B-instruct will cache custom
|
||||
# model code with bidirectional attention.
|
||||
# [Decoder-only]
|
||||
pytest.param("BAAI/bge-multilingual-gemma2",
|
||||
marks=[pytest.mark.core_model]),
|
||||
pytest.param("intfloat/e5-mistral-7b-instruct",
|
||||
marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
|
||||
pytest.param("ssmits/Qwen2-7B-Instruct-embed-base"),
|
||||
# [Encoder-only]
|
||||
pytest.param("BAAI/bge-base-en-v1.5",
|
||||
marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
|
||||
pytest.param("sentence-transformers/all-MiniLM-L12-v2"),
|
||||
pytest.param("intfloat/multilingual-e5-small"),
|
||||
pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct"),
|
||||
# [Cross-Encoder]
|
||||
pytest.param("sentence-transformers/stsb-roberta-base-v2"),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
def test_models(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model,
|
||||
dtype: str,
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
|
||||
@ -44,7 +46,7 @@ def test_models(
|
||||
vllm_extra_kwargs = {}
|
||||
if model == "ssmits/Qwen2-7B-Instruct-embed-base":
|
||||
vllm_extra_kwargs["override_pooler_config"] = \
|
||||
PoolerConfig(pooling_type="MEAN")
|
||||
PoolerConfig(pooling_type="MEAN", normalize=False)
|
||||
|
||||
# The example_prompts has ending "\n", for example:
|
||||
# "Write a short story about a robot that dreams for the first time.\n"
|
||||
@ -54,13 +56,11 @@ def test_models(
|
||||
# So we need to strip the input texts to avoid test failing.
|
||||
example_prompts = [str(s).strip() for s in example_prompts]
|
||||
|
||||
with hf_runner(model, dtype=dtype,
|
||||
is_sentence_transformer=True) as hf_model:
|
||||
with hf_runner(model, is_sentence_transformer=True) as hf_model:
|
||||
hf_outputs = hf_model.encode(example_prompts)
|
||||
|
||||
with vllm_runner(model,
|
||||
task="embed",
|
||||
dtype=dtype,
|
||||
max_model_len=None,
|
||||
**vllm_extra_kwargs) as vllm_model:
|
||||
vllm_outputs = vllm_model.encode(example_prompts)
|
||||
|
||||
@ -45,6 +45,7 @@ MODELS = [
|
||||
########### Qwen2ForCausalLM
|
||||
EmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct",
|
||||
architecture="Qwen2ForCausalLM",
|
||||
dtype="float32",
|
||||
enable_test=True),
|
||||
########## ModernBertModel
|
||||
EmbedModelInfo("Alibaba-NLP/gte-modernbert-base",
|
||||
|
||||
@ -100,6 +100,7 @@ def run_test(
|
||||
|
||||
with vllm_runner(
|
||||
model,
|
||||
dtype="half",
|
||||
max_model_len=448,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
|
||||
@ -40,7 +40,7 @@ def _test_processing_correctness(
|
||||
tokenizer_mode=model_info.tokenizer_mode,
|
||||
trust_remote_code=model_info.trust_remote_code,
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
dtype="auto",
|
||||
revision=None,
|
||||
hf_overrides=model_info.hf_overrides,
|
||||
)
|
||||
|
||||
@ -434,6 +434,11 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
|
||||
trust_remote_code=True,
|
||||
speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
|
||||
tokenizer="meta-llama/Llama-3.1-8B-Instruct"),
|
||||
"EagleMiniCPMForCausalLM": _HfExamplesInfo("openbmb/MiniCPM-1B-sft-bf16",
|
||||
trust_remote_code=True,
|
||||
is_available_online=False,
|
||||
speculative_model="openbmb/MiniCPM-2B-sft-bf16",
|
||||
tokenizer="openbmb/MiniCPM-2B-sft-bf16"),
|
||||
"MiMoMTPModel": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL",
|
||||
trust_remote_code=True,
|
||||
speculative_model="XiaomiMiMo/MiMo-7B-RL")
|
||||
|
||||
@ -314,6 +314,7 @@ def check_embeddings_close(
|
||||
dim=0)
|
||||
|
||||
fail_msg = (f"Test{prompt_idx}:"
|
||||
f"\nCosine similarity: \t{sim:.4f}"
|
||||
f"\n{name_0}:\t{embeddings_0[:16]!r}"
|
||||
f"\n{name_1}:\t{embeddings_1[:16]!r}")
|
||||
|
||||
|
||||
98
tests/neuron/2_core/test_multi_lora.py
Normal file
@ -0,0 +1,98 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
|
||||
def test_llama_single_lora():
|
||||
sql_lora_files = snapshot_download(
|
||||
repo_id="yard1/llama-2-7b-sql-lora-test")
|
||||
llm = LLM(model="meta-llama/Llama-2-7b-hf",
|
||||
tensor_parallel_size=2,
|
||||
max_num_seqs=4,
|
||||
max_model_len=512,
|
||||
use_v2_block_manager=True,
|
||||
override_neuron_config={
|
||||
"sequence_parallel_enabled": False,
|
||||
"skip_warmup": True,
|
||||
"lora_modules": [{
|
||||
"name": "lora_id_1",
|
||||
"path": sql_lora_files
|
||||
}]
|
||||
},
|
||||
enable_lora=True,
|
||||
max_loras=1,
|
||||
max_lora_rank=256,
|
||||
device="neuron")
|
||||
"""For multi-lora requests using NxDI as the backend, only the lora_name
|
||||
needs to be specified. The lora_id and lora_path are supplied at the LLM
|
||||
class/server initialization, after which the paths are handled by NxDI"""
|
||||
lora_req_1 = LoRARequest("lora_id_1", 0, " ")
|
||||
prompts = [
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
]
|
||||
outputs = llm.generate(prompts,
|
||||
SamplingParams(top_k=1),
|
||||
lora_request=[lora_req_1, lora_req_1])
|
||||
|
||||
expected_outputs = [
|
||||
" the head of state and head of government of the United States. "
|
||||
"The president direct",
|
||||
" a city of contrasts. The city is home to the Eiffel Tower"
|
||||
]
|
||||
|
||||
for expected_output, output in zip(expected_outputs, outputs):
|
||||
generated_text = output.outputs[0].text
|
||||
assert (expected_output == generated_text)
|
||||
|
||||
|
||||
def test_llama_multiple_lora():
|
||||
sql_lora_files = snapshot_download(
|
||||
repo_id="yard1/llama-2-7b-sql-lora-test")
|
||||
llm = LLM(model="meta-llama/Llama-2-7b-hf",
|
||||
tensor_parallel_size=2,
|
||||
max_num_seqs=4,
|
||||
max_model_len=512,
|
||||
use_v2_block_manager=True,
|
||||
override_neuron_config={
|
||||
"sequence_parallel_enabled":
|
||||
False,
|
||||
"skip_warmup":
|
||||
True,
|
||||
"lora_modules": [{
|
||||
"name": "lora_id_1",
|
||||
"path": sql_lora_files
|
||||
}, {
|
||||
"name": "lora_id_2",
|
||||
"path": sql_lora_files
|
||||
}]
|
||||
},
|
||||
enable_lora=True,
|
||||
max_loras=2,
|
||||
max_lora_rank=256,
|
||||
device="neuron")
|
||||
"""For multi-lora requests using NxDI as the backend, only the lora_name
|
||||
needs to be specified. The lora_id and lora_path are supplied at the LLM
|
||||
class/server initialization, after which the paths are handled by NxDI"""
|
||||
lora_req_1 = LoRARequest("lora_id_1", 0, " ")
|
||||
lora_req_2 = LoRARequest("lora_id_2", 1, " ")
|
||||
prompts = [
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
]
|
||||
outputs = llm.generate(prompts,
|
||||
SamplingParams(top_k=1),
|
||||
lora_request=[lora_req_1, lora_req_2])
|
||||
|
||||
expected_outputs = [
|
||||
" the head of state and head of government of the United States. "
|
||||
"The president direct",
|
||||
" a city of contrasts. The city is home to the Eiffel Tower"
|
||||
]
|
||||
|
||||
for expected_output, output in zip(expected_outputs, outputs):
|
||||
generated_text = output.outputs[0].text
|
||||
assert (expected_output == generated_text)
|
||||
@ -103,7 +103,7 @@ class TestTwoTokenBadWord:
add_special_tokens=False)[0]

def test_two_token_bad_word(self, vllm_runner):
with vllm_runner(self.MODEL) as llm:
with vllm_runner(self.MODEL, dtype="half") as llm:
output_token_ids = self._generate(llm)
assert output_token_ids[:2] == [
self.target_token_id1, self.target_token_id2

@ -4,7 +4,6 @@ import gc
|
||||
import os
|
||||
import pathlib
|
||||
import subprocess
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@ -16,7 +15,6 @@ from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
|
||||
TensorSerializer,
|
||||
is_vllm_tensorized,
|
||||
load_with_tensorizer,
|
||||
open_stream,
|
||||
tensorize_vllm_model)
|
||||
# yapf: enable
|
||||
@ -61,21 +59,6 @@ def write_keyfile(keyfile_path: str):
|
||||
f.write(encryption_params.key)
|
||||
|
||||
|
||||
@patch('vllm.model_executor.model_loader.tensorizer.TensorizerAgent')
|
||||
def test_load_with_tensorizer(mock_agent, tensorizer_config):
|
||||
mock_linear_method = MagicMock()
|
||||
mock_agent_instance = mock_agent.return_value
|
||||
mock_agent_instance.deserialize.return_value = MagicMock()
|
||||
|
||||
result = load_with_tensorizer(tensorizer_config,
|
||||
quant_method=mock_linear_method)
|
||||
|
||||
mock_agent.assert_called_once_with(tensorizer_config,
|
||||
quant_method=mock_linear_method)
|
||||
mock_agent_instance.deserialize.assert_called_once()
|
||||
assert result == mock_agent_instance.deserialize.return_value
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
|
||||
def test_can_deserialize_s3(vllm_runner):
|
||||
model_ref = "EleutherAI/pythia-1.4b"
|
||||
|
||||
@ -17,7 +17,8 @@ from vllm_test_utils.monitor import monitor
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache,
MemorySnapshot, PlaceholderModule, StoreBoolean,
bind_kv_cache, deprecate_kwargs, get_open_port,
bind_kv_cache, common_broadcastable_dtype,
deprecate_kwargs, get_open_port, is_lossless_cast,
make_zmq_path, make_zmq_socket, memory_profiling,
merge_async_iterators, sha256, split_zmq_path,
supports_kw, swap_dict_values)
@ -567,12 +568,65 @@ def test_lru_cache():
assert 6 in cache


# yapf: disable
@pytest.mark.parametrize(
("src_dtype", "tgt_dtype", "expected_result"),
[
# Different precision_levels
(torch.bool, torch.int8, True),
(torch.bool, torch.float16, True),
(torch.bool, torch.complex32, True),
(torch.int64, torch.bool, False),
(torch.int64, torch.float16, True),
(torch.int64, torch.complex32, True),
(torch.float64, torch.bool, False),
(torch.float64, torch.int8, False),
(torch.float64, torch.complex32, True),
(torch.complex128, torch.bool, False),
(torch.complex128, torch.int8, False),
(torch.complex128, torch.float16, False),
# precision_level=0
(torch.bool, torch.bool, True),
# precision_level=1
(torch.int8, torch.int16, True),
(torch.int16, torch.int8, False),
(torch.uint8, torch.int8, False),
(torch.int8, torch.uint8, False),
# precision_level=2
(torch.float16, torch.float32, True),
(torch.float32, torch.float16, False),
(torch.bfloat16, torch.float32, True),
(torch.float32, torch.bfloat16, False),
# precision_level=3
(torch.complex32, torch.complex64, True),
(torch.complex64, torch.complex32, False),
],
)
# yapf: enable
def test_is_lossless_cast(src_dtype, tgt_dtype, expected_result):
assert is_lossless_cast(src_dtype, tgt_dtype) == expected_result


# yapf: disable
@pytest.mark.parametrize(
("dtypes", "expected_result"),
[
([torch.bool], torch.bool),
([torch.bool, torch.int8], torch.int8),
([torch.bool, torch.int8, torch.float16], torch.float16),
([torch.bool, torch.int8, torch.float16, torch.complex32], torch.complex32), # noqa: E501
],
)
# yapf: enable
def test_common_broadcastable_dtype(dtypes, expected_result):
assert common_broadcastable_dtype(dtypes) == expected_result


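To make the expectations above concrete, here is a minimal sketch of a helper with the behaviour these parametrizations encode: an ordering of precision levels (bool < int < float < complex) with width and signedness checks inside a level. It reflects only the test table and is not claimed to be vLLM's actual implementation; `common_broadcastable_dtype` would then simply reduce a list of dtypes under this relation.

```python
# Hedged sketch matching the test table above; not vLLM's implementation.
import torch

# Per-component precision in bits, used only for the complex branch below.
_COMPLEX_BITS = {torch.complex32: 16, torch.complex64: 32, torch.complex128: 64}

def _precision_level(dtype: torch.dtype) -> int:
    if dtype == torch.bool:
        return 0
    if dtype.is_complex:
        return 3
    if dtype.is_floating_point:
        return 2
    return 1  # integer dtypes

def _bits(dtype: torch.dtype) -> int:
    level = _precision_level(dtype)
    if level == 3:
        return _COMPLEX_BITS[dtype]
    if level == 2:
        return torch.finfo(dtype).bits
    return torch.iinfo(dtype).bits

def is_lossless_cast_sketch(src: torch.dtype, tgt: torch.dtype) -> bool:
    src_level, tgt_level = _precision_level(src), _precision_level(tgt)
    if src_level != tgt_level:
        # Moving up a precision level counts as lossless, moving down does not.
        return src_level < tgt_level
    if src_level == 0:
        return True  # bool -> bool
    if src_level == 1:
        # Signed/unsigned mismatch is lossy in both directions (uint8 <-> int8).
        if (torch.iinfo(src).min < 0) != (torch.iinfo(tgt).min < 0):
            return False
    # Within a level, the target must be at least as wide as the source.
    return _bits(tgt) >= _bits(src)
```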
def test_placeholder_module_error_handling():
|
||||
placeholder = PlaceholderModule("placeholder_1234")
|
||||
|
||||
def build_ctx():
|
||||
return pytest.raises(ModuleNotFoundError,
|
||||
match="No module named")
|
||||
return pytest.raises(ModuleNotFoundError, match="No module named")
|
||||
|
||||
with build_ctx():
|
||||
int(placeholder)
|
||||
@ -608,6 +662,7 @@ def test_placeholder_module_error_handling():
|
||||
_ = placeholder_attr.module
|
||||
|
||||
|
||||
# yapf: disable
|
||||
@pytest.mark.parametrize(
|
||||
"obj,key1,key2",
|
||||
[
|
||||
@ -618,6 +673,7 @@ def test_placeholder_module_error_handling():
|
||||
# Tests for both keys do not exist
|
||||
({1: "a", 2: "b"}, 3, 4),
|
||||
])
|
||||
# yapf: enable
|
||||
def test_swap_dict_values(obj, key1, key2):
|
||||
original_obj = obj.copy()
|
||||
swap_dict_values(obj, key1, key2)
|
||||
@ -631,19 +687,19 @@ def test_swap_dict_values(obj, key1, key2):
|
||||
assert key1 not in obj
|
||||
|
||||
|
||||
def test_model_specification(parser_with_config,
|
||||
cli_config_file,
|
||||
def test_model_specification(parser_with_config, cli_config_file,
|
||||
cli_config_file_with_model):
|
||||
# Test model in CLI takes precedence over config
|
||||
args = parser_with_config.parse_args([
|
||||
'serve', 'cli-model', '--config', cli_config_file_with_model
|
||||
])
|
||||
args = parser_with_config.parse_args(
|
||||
['serve', 'cli-model', '--config', cli_config_file_with_model])
|
||||
assert args.model_tag == 'cli-model'
|
||||
assert args.served_model_name == 'mymodel'
|
||||
|
||||
# Test model from config file works
|
||||
args = parser_with_config.parse_args([
|
||||
'serve', '--config', cli_config_file_with_model,
|
||||
'serve',
|
||||
'--config',
|
||||
cli_config_file_with_model,
|
||||
])
|
||||
assert args.model == 'config-model'
|
||||
assert args.served_model_name == 'mymodel'
|
||||
@ -654,17 +710,19 @@ def test_model_specification(parser_with_config,
|
||||
|
||||
# Test using --model option raises error
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
"With `vllm serve`, you should provide the model as a positional "
|
||||
"argument or in a config file instead of via the `--model` option."
|
||||
),
|
||||
ValueError,
|
||||
match=
|
||||
("With `vllm serve`, you should provide the model as a positional "
|
||||
"argument or in a config file instead of via the `--model` option."),
|
||||
):
|
||||
parser_with_config.parse_args(['serve', '--model', 'my-model'])
|
||||
|
||||
# Test other config values are preserved
|
||||
args = parser_with_config.parse_args([
|
||||
'serve', 'cli-model', '--config', cli_config_file_with_model,
|
||||
'serve',
|
||||
'cli-model',
|
||||
'--config',
|
||||
cli_config_file_with_model,
|
||||
])
|
||||
assert args.tensor_parallel_size == 2
|
||||
assert args.trust_remote_code is True
|
||||
@ -673,7 +731,7 @@ def test_model_specification(parser_with_config,
|
||||
|
||||
|
||||
@pytest.mark.parametrize("input", [(), ("abc", ), (None, ),
|
||||
(None, bool, [1, 2, 3])])
|
||||
(None, bool, [1, 2, 3])])
|
||||
@pytest.mark.parametrize("output", [0, 1, 2])
|
||||
def test_sha256(input: tuple, output: int):
|
||||
hash = sha256(input)
|
||||
@ -682,7 +740,8 @@ def test_sha256(input: tuple, output: int):
|
||||
assert hash != 0
|
||||
|
||||
bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
assert hash == int.from_bytes(hashlib.sha256(bytes).digest(), byteorder="big")
|
||||
assert hash == int.from_bytes(hashlib.sha256(bytes).digest(),
|
||||
byteorder="big")
|
||||
|
||||
# hashing again, returns the same value
|
||||
assert hash == sha256(input)
|
||||
@ -698,8 +757,7 @@ def test_sha256(input: tuple, output: int):
|
||||
("tcp://127.0.0.1:5555", ("tcp", "127.0.0.1", "5555")),
|
||||
("tcp://[::1]:5555", ("tcp", "::1", "5555")), # IPv6 address
|
||||
("inproc://some_identifier", ("inproc", "some_identifier", "")),
|
||||
]
|
||||
)
|
||||
])
|
||||
def test_split_zmq_path(path, expected):
|
||||
assert split_zmq_path(path) == expected
|
||||
|
||||
@ -711,8 +769,7 @@ def test_split_zmq_path(path, expected):
|
||||
"tcp://127.0.0.1", # Missing port
|
||||
"tcp://[::1]", # Missing port for IPv6
|
||||
"tcp://:5555", # Missing host
|
||||
]
|
||||
)
|
||||
])
|
||||
def test_split_zmq_path_invalid(invalid_path):
|
||||
with pytest.raises(ValueError):
|
||||
split_zmq_path(invalid_path)
|
||||
@ -734,7 +791,8 @@ def test_make_zmq_socket_ipv6():
|
||||
zsock: zmq.Socket = make_zmq_socket(ctx, ipv6_path, socket_type)
|
||||
|
||||
# Verify that the IPV6 option is set
|
||||
assert zsock.getsockopt(zmq.IPV6) == 1, "IPV6 option should be enabled for IPv6 addresses"
|
||||
assert zsock.getsockopt(
|
||||
zmq.IPV6) == 1, "IPV6 option should be enabled for IPv6 addresses"
|
||||
|
||||
# Clean up
|
||||
zsock.close()
|
||||
|
||||
@ -26,7 +26,7 @@ TOP_KS = [2, 6]
# The Pallas GMM kernel requires num_tokens * topk to be a multiple of 16
@pytest.mark.parametrize("m", [8, 16, 64, 2048])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("k", [128, 512, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE)

@ -28,7 +28,7 @@ from tests.models.utils import TextTextLogprobs
|
||||
from vllm.distributed import (ensure_model_parallel_initialized,
|
||||
init_distributed_environment)
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.entrypoints.openai.cli_args import make_arg_parser
|
||||
from vllm.entrypoints.cli.serve import ServeSubcommand
|
||||
from vllm.model_executor.model_loader import get_model_loader
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
@ -99,7 +99,8 @@ class RemoteOpenAIServer:
|
||||
|
||||
parser = FlexibleArgumentParser(
|
||||
description="vLLM's remote OpenAI server.")
|
||||
parser = make_arg_parser(parser)
|
||||
subparsers = parser.add_subparsers(required=False, dest="subparser")
|
||||
parser = ServeSubcommand().subparser_init(subparsers)
|
||||
args = parser.parse_args(["--model", model, *vllm_serve_args])
|
||||
self.host = str(args.host or 'localhost')
|
||||
self.port = int(args.port)
|
||||
|
||||
@ -45,7 +45,6 @@ def make_request(request_id,
|
||||
multi_modal_placeholders=mm_positions,
|
||||
sampling_params=SamplingParams(max_tokens=17),
|
||||
eos_token_id=100,
|
||||
arrival_time=0,
|
||||
lora_request=None,
|
||||
cache_salt=cache_salt,
|
||||
)
|
||||
|
||||
@ -38,7 +38,6 @@ def make_request(request_id,
|
||||
sampling_params=SamplingParams(max_tokens=17,
|
||||
prompt_logprobs=prompt_logprobs),
|
||||
eos_token_id=100,
|
||||
arrival_time=0,
|
||||
lora_request=None,
|
||||
cache_salt=cache_salt,
|
||||
)
|
||||
|
||||
@ -138,7 +138,6 @@ def create_requests(num_requests: int,
|
||||
multi_modal_placeholders=mm_position,
|
||||
multi_modal_hashes=None,
|
||||
eos_token_id=EOS_TOKEN_ID,
|
||||
arrival_time=0,
|
||||
)
|
||||
requests.append(request)
|
||||
return requests
|
||||
@ -744,7 +743,8 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
|
||||
assert running_req.num_tokens_with_spec == 2 + len(spec_tokens[i])
|
||||
|
||||
# No draft or accepted tokens counted yet
|
||||
assert engine_core_outputs.scheduler_stats.spec_decoding_stats is None
|
||||
assert not engine_core_outputs or (
|
||||
engine_core_outputs[0].scheduler_stats.spec_decoding_stats is None)
|
||||
|
||||
# Schedule the speculated tokens for validation
|
||||
output = scheduler.schedule()
|
||||
@ -772,7 +772,8 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
|
||||
engine_core_outputs = scheduler.update_from_output(output,
|
||||
model_runner_output)
|
||||
|
||||
scheduler_stats = engine_core_outputs.scheduler_stats
|
||||
scheduler_stats = engine_core_outputs[0].scheduler_stats \
|
||||
if engine_core_outputs else None
|
||||
if expected[0] == 0:
|
||||
assert scheduler_stats.spec_decoding_stats is None
|
||||
else:
|
||||
@ -843,7 +844,7 @@ def _step_until_done(
|
||||
# We should be in the decode phase now.
|
||||
assert num_scheduled_tokens == 1
|
||||
assert len(output.kv_connector_metadata.requests) == 0
|
||||
ecos = scheduler.update_from_output(output, model_runner_output)
|
||||
ecos = scheduler.update_from_output(output, model_runner_output)[0]
|
||||
all_done = True
|
||||
for eco in ecos.outputs:
|
||||
if eco.finish_reason is None:
|
||||
|
||||
@ -88,7 +88,7 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
|
||||
assert len(engine_core.scheduler.running) == 4
|
||||
|
||||
# Loop through until they are all done.
|
||||
while len(engine_core.step().outputs) > 0:
|
||||
while (outs := engine_core.step()[0].get(0)) and outs.outputs:
|
||||
pass
|
||||
|
||||
assert len(engine_core.scheduler.waiting) == 0
|
||||
@ -163,11 +163,11 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
|
||||
req0.request_id = req1.request_id = "test"
|
||||
engine_core.add_request(req0)
|
||||
|
||||
while len(engine_core.step().outputs) > 0:
|
||||
while (outs := engine_core.step()[0].get(0)) and outs.outputs:
|
||||
pass
|
||||
|
||||
engine_core.add_request(req1)
|
||||
while len(engine_core.step().outputs) > 0:
|
||||
while (outs := engine_core.step()[0].get(0)) and outs.outputs:
|
||||
pass
|
||||
|
||||
assert len(engine_core.scheduler.waiting) == 0
|
||||
@ -207,7 +207,7 @@ def test_engine_core_advanced_sampling(monkeypatch: pytest.MonkeyPatch):
|
||||
assert len(engine_core.scheduler.waiting) == 1
|
||||
assert len(engine_core.scheduler.running) == 0
|
||||
# Loop through until they are all done.
|
||||
while len(engine_core.step().outputs) > 0:
|
||||
while (outs := engine_core.step()[0].get(0)) and outs.outputs:
|
||||
pass
|
||||
assert len(engine_core.scheduler.waiting) == 0
|
||||
assert len(engine_core.scheduler.running) == 0
|
||||
@ -296,7 +296,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
|
||||
engine_core.add_request(req1)
|
||||
|
||||
# Schedule Batch 1: (10, req0)
|
||||
assert engine_core.step_with_batch_queue() is None
|
||||
assert engine_core.step_with_batch_queue()[0] is None
|
||||
assert engine_core.batch_queue.qsize() == 1
|
||||
scheduler_output = engine_core.batch_queue.queue[-1][1]
|
||||
assert scheduler_output.num_scheduled_tokens[0] == 10
|
||||
@ -305,7 +305,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
|
||||
req0.request_id].num_computed_tokens == 10
|
||||
|
||||
# Schedule Batch 2: (2, req0), (8, req1)
|
||||
assert engine_core.step_with_batch_queue() is None
|
||||
assert engine_core.step_with_batch_queue()[0] is None
|
||||
assert engine_core.batch_queue.qsize() == 2
|
||||
scheduler_output = engine_core.batch_queue.queue[-1][1]
|
||||
assert scheduler_output.num_scheduled_tokens[0] == 2
|
||||
@ -327,7 +327,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
|
||||
assert scheduler_output.num_scheduled_tokens[1] == 4
|
||||
|
||||
# Batch queue is full. Finish Batch 2. Get first token of req0.
|
||||
output = engine_core.step_with_batch_queue()
|
||||
output = engine_core.step_with_batch_queue()[0].get(0)
|
||||
assert output is not None
|
||||
assert len(output.outputs) == 1
|
||||
assert engine_core.scheduler.requests[req0.request_id].num_tokens == 13
|
||||
@ -339,7 +339,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
|
||||
assert scheduler_output.num_scheduled_tokens[0] == 1
|
||||
|
||||
# Batch queue is full. Finish Batch 3. Get first token of req1.
|
||||
output = engine_core.step_with_batch_queue()
|
||||
output = engine_core.step_with_batch_queue()[0].get(0)
|
||||
assert output is not None
|
||||
assert len(output.outputs) == 1
|
||||
assert engine_core.scheduler.requests[req1.request_id].num_tokens == 13
|
||||
@ -358,11 +358,11 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
|
||||
engine_core.scheduler.requests[1].num_tokens + 1,
|
||||
]
|
||||
while engine_core.scheduler.get_num_unfinished_requests() == 2:
|
||||
output = engine_core.step_with_batch_queue()
|
||||
output = engine_core.step_with_batch_queue()[0]
|
||||
if step % 2 == 0:
|
||||
# Even steps consumes an output.
|
||||
assert output is not None
|
||||
assert len(output.outputs) == 1
|
||||
assert len(output[0].outputs) == 1
|
||||
if req_id in engine_core.scheduler.requests:
|
||||
assert engine_core.scheduler.requests[
|
||||
req_id].num_tokens == expected_num_tokens[req_id]
|
||||
|
||||
171
tests/v1/entrypoints/openai/test_multi_api_servers.py
Normal file
@ -0,0 +1,171 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import os

import openai # use the official client for correctness check
import pytest
import pytest_asyncio

from tests.utils import RemoteOpenAIServer

MODEL_NAME = "ibm-research/PowerMoE-3b"

DP_SIZE = os.getenv("DP_SIZE", "1")


@pytest.fixture(scope="module")
def default_server_args():
return [
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
"--max-model-len",
"2048",
"--max-num-seqs",
"128",
"--enforce-eager",
"--api-server-count",
"4",
"--data_parallel_size",
DP_SIZE,
]


@pytest.fixture(scope="module")
def server(default_server_args):
with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
yield remote_server


@pytest_asyncio.fixture
async def client(server):
async with server.get_async_client() as async_client:
yield async_client


@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_single_completion(client: openai.AsyncOpenAI,
model_name: str) -> None:

async def make_request():
completion = await client.completions.create(
model=model_name,
prompt="Hello, my name is",
max_tokens=10,
temperature=1.0)

assert completion.id is not None
assert completion.choices is not None and len(completion.choices) == 1

choice = completion.choices[0]
# The exact number of tokens can vary slightly with temperature=1.0,
# so we check for a reasonable minimum length.
assert len(choice.text) >= 1
# Finish reason might not always be 'length' if the model finishes early
# or due to other reasons, especially with high temperature.
# So, we'll accept 'length' or 'stop'.
assert choice.finish_reason in ("length", "stop")

# Token counts can also vary, so we check they are positive.
assert completion.usage.completion_tokens > 0
assert completion.usage.prompt_tokens > 0
assert completion.usage.total_tokens > 0
return completion

# Test single request
result = await make_request()
assert result is not None

await asyncio.sleep(0.5)

# Send two bursts of requests
num_requests = 100
tasks = [make_request() for _ in range(num_requests)]
results = await asyncio.gather(*tasks)
assert len(results) == num_requests
assert all(completion is not None for completion in results)

await asyncio.sleep(0.5)

tasks = [make_request() for _ in range(num_requests)]
results = await asyncio.gather(*tasks)
assert len(results) == num_requests
assert all(completion is not None for completion in results)


@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_completion_streaming(client: openai.AsyncOpenAI,
model_name: str) -> None:
prompt = "What is an LLM?"

async def make_streaming_request():
# Perform a non-streaming request to get the expected full output
single_completion = await client.completions.create(
model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
)
single_output = single_completion.choices[0].text

# Perform the streaming request
stream = await client.completions.create(model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=True)
chunks: list[str] = []
finish_reason_count = 0
last_chunk = None
async for chunk in stream:
chunks.append(chunk.choices[0].text)
if chunk.choices[0].finish_reason is not None:
finish_reason_count += 1
last_chunk = chunk # Keep track of the last chunk

# finish reason should only return in the last block for OpenAI API
assert finish_reason_count == 1, (
"Finish reason should appear exactly once.")
assert last_chunk is not None, (
"Stream should have yielded at least one chunk.")
assert last_chunk.choices[
0].finish_reason == "length", "Finish reason should be 'length'."
# Check that the combined text matches the non-streamed version.
assert "".join(
chunks
) == single_output, "Streamed output should match non-streamed output."
return True # Indicate success for this request

# Test single request
result = await make_streaming_request()
assert result is not None

await asyncio.sleep(0.5)

# Send two bursts of requests
num_requests = 100
tasks = [make_streaming_request() for _ in range(num_requests)]
results = await asyncio.gather(*tasks)

assert len(
results
) == num_requests, f"Expected {num_requests} results, got {len(results)}"
assert all(results), "Not all streaming requests completed successfully."

await asyncio.sleep(0.5)

tasks = [make_streaming_request() for _ in range(num_requests)]
results = await asyncio.gather(*tasks)

assert len(
results
) == num_requests, f"Expected {num_requests} results, got {len(results)}"
assert all(results), "Not all streaming requests completed successfully."
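The fixture above drives four API server processes in front of a data-parallel engine via RemoteOpenAIServer. A hedged sketch of launching the same configuration outside the test harness with the vllm serve CLI; the model name and DP size simply mirror the fixture's values and can be adjusted:

import subprocess

# Flags mirror default_server_args above; "--api-server-count" is the new
# serve option introduced later in this diff.
cmd = [
    "vllm", "serve", "ibm-research/PowerMoE-3b",
    "--dtype", "bfloat16",
    "--max-model-len", "2048",
    "--max-num-seqs", "128",
    "--enforce-eager",
    "--api-server-count", "4",
    "--data_parallel_size", "1",
]
subprocess.run(cmd, check=True)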
@ -43,7 +43,7 @@ def test_basic_lifecycle():
# Ensure the request is finished after 1 tokens.
assert request.is_finished()
assert request.status == RequestStatus.FINISHED_LENGTH_CAPPED
output = engine_core_outputs.outputs[0]
output = engine_core_outputs[0].outputs[0]
assert output.finish_reason == FinishReason.LENGTH
assert output.kv_transfer_params is not None

@ -165,7 +165,7 @@ def test_prefix_cache_lifecycle():
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(reqs=[request_remote])
eco = scheduler.update_from_output(scheduler_output, model_runner_output)
kv_transfer_params = eco.outputs[0].kv_transfer_params
kv_transfer_params = eco[0].outputs[0].kv_transfer_params

# Ensure we send all block ids, even if there is a cache hit.
assert (len(
@ -61,7 +61,7 @@ def test_basic_lifecycle():
# (1c): update_from_output()
engine_core_outputs = scheduler.update_from_output(scheduler_output,
model_runner_output)
assert len(engine_core_outputs.outputs) == 0
assert not engine_core_outputs or not engine_core_outputs[0].outputs

# STEP (2):
# (2a): schedule(): nothing happens!
@ -112,7 +112,7 @@ def test_basic_lifecycle():
model_runner_output)
scheduler.schedule()

outputs = engine_core_outputs.outputs
outputs = engine_core_outputs[0].outputs
assert len(outputs) == 1
output = outputs[0]
assert output.finish_reason == FinishReason.STOP
@ -335,7 +335,7 @@ def test_full_block_prompt():
model_runner_output)
scheduler.schedule()

outputs = engine_core_outputs.outputs
outputs = engine_core_outputs[0].outputs
assert len(outputs) == 1
output = outputs[0]
assert output.finish_reason == FinishReason.STOP
@ -153,7 +153,6 @@ def create_request(
multi_modal_placeholders=None,
multi_modal_hashes=None,
eos_token_id=EOS_TOKEN_ID,
arrival_time=0,
)
req.kv_transfer_params = kv_transfer_params
return req
@ -81,7 +81,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
mm_hashes=[],
mm_positions=[],
sampling_params=SamplingParams(),
block_ids=[0],
block_ids=[[0]], # block_ids should be list[list[int]]
num_computed_tokens=0,
lora_request=None,
))
@ -112,14 +112,35 @@ def _is_req_added(model_runner, req_id: str) -> bool:


def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
"""Check if the request state block IDs match the block table.

This function handles both legacy BlockTable and new MultiGroupBlockTable
structures for backward compatibility.
"""

req_index = model_runner.input_batch.req_id_to_index[req_id]
block_table = model_runner.input_batch.block_table
multi_group_block_table = model_runner.input_batch.block_table
req_state = model_runner.requests[req_id]
if block_table.num_blocks_per_row[req_index] != len(req_state.block_ids):

# Access the first block table from MultiGroupBlockTable
# This is safe since we currently only use single KV cache groups
block_table = multi_group_block_table[0]

# req_state.block_ids is now list[list[int]] for MultiGroupBlockTable
# Extract the first group's block IDs
if isinstance(req_state.block_ids[0], list):
# New format: list[list[int]] - extract first group
req_block_ids = req_state.block_ids[0]
else:
# Legacy format: list[int] - use directly
req_block_ids = req_state.block_ids

if block_table.num_blocks_per_row[req_index] != len(req_block_ids):
return False

num_blocks = block_table.num_blocks_per_row[req_index]
return (block_table.block_table_np[req_index, :num_blocks] ==
req_state.block_ids).all()
block_table_values = block_table.block_table_np[req_index, :num_blocks]
return (block_table_values == req_block_ids).all()
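The helper above tolerates the request state's block_ids moving from a flat list[int] to a per-KV-cache-group list[list[int]], matching the block_ids=[[0]] change earlier in this diff. A small illustrative sketch of the two shapes and the compatibility check; the values are made up and not taken from the repository:

# Illustrative only: legacy vs. new block_ids shapes.
legacy_block_ids = [0, 1, 2]    # list[int], one flat block table
new_block_ids = [[0, 1, 2]]     # list[list[int]], one entry per KV cache group

def first_group(block_ids):
    # Mirrors the compatibility branch above: works for both formats.
    return block_ids[0] if isinstance(block_ids[0], list) else block_ids

assert first_group(legacy_block_ids) == first_group(new_block_ids)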
def test_update_states_new_request(model_runner):

@ -1,7 +1,10 @@
# SPDX-License-Identifier: Apache-2.0

import random

import pytest

from vllm.attention import Attention
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig, VllmConfig)
from vllm.sampling_params import SamplingParams
@ -13,27 +16,30 @@ from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner

BLOCK_SIZE = 16
NUM_BLOCKS = 10


def initialize_kv_cache(runner: GPUModelRunner):
"""
Only perform necessary steps in GPUModelRunner.initialize_kv_cache()
"""
attn_spec = FullAttentionSpec(
block_size=BLOCK_SIZE,
num_kv_heads=runner.model_config.get_num_kv_heads(
runner.parallel_config),
head_size=runner.model_config.get_head_size(),
dtype=runner.kv_cache_dtype,
use_mla=False,
)
tensor_size = attn_spec.page_size_bytes * NUM_BLOCKS
kv_cache_config = KVCacheConfig(
num_blocks=10,
num_blocks=NUM_BLOCKS,
tensors={
"layer.0": KVCacheTensor(size=1024),
"layer.0": KVCacheTensor(size=tensor_size),
},
kv_cache_groups=[
KVCacheGroupSpec(
layer_names=["layer.0"],
kv_cache_spec=FullAttentionSpec(
block_size=16,
num_kv_heads=runner.model_config.get_num_kv_heads(
runner.parallel_config),
head_size=runner.model_config.get_head_size(),
dtype=runner.kv_cache_dtype,
use_mla=False,
))
KVCacheGroupSpec(layer_names=["layer.0"], kv_cache_spec=attn_spec)
])
runner.kv_cache_config = kv_cache_config
runner.input_batch = InputBatch(
@ -65,7 +71,7 @@ def model_runner():
seed=42,
)
cache_config = CacheConfig(
block_size=16,
block_size=BLOCK_SIZE,
gpu_memory_utilization=0.9,
swap_space=0,
cache_dtype="auto",
@ -77,6 +83,10 @@ def model_runner():
scheduler_config=scheduler_config,
parallel_config=parallel_config,
)
num_heads = model_config.get_num_kv_heads(parallel_config)
head_size = model_config.get_head_size()
vllm_config.compilation_config.static_forward_context[
"layer.0"] = Attention(num_heads, head_size, 0.1)

device = "cuda"
runner = GPUModelRunner(vllm_config, device)
@ -84,6 +94,9 @@ def model_runner():
return runner


model_runner_2 = model_runner


def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
new_reqs = []
num_scheduled_tokens = {}
@ -321,3 +334,53 @@ def test_update_states_request_unscheduled(model_runner):

assert _is_req_added(model_runner, req_ids[1])
assert not _is_req_scheduled(model_runner, req_ids[1])


def test_kv_cache_stride_order(monkeypatch, model_runner):
# This test checks if GPUModelRunner initializes correctly when an attention
# backend enforces a non-default KV cache stride order.
n_heads = model_runner.model_config.get_num_kv_heads(
model_runner.parallel_config)
expected_kv_cache_shape = [
2, NUM_BLOCKS, BLOCK_SIZE, n_heads,
model_runner.model_config.get_head_size()
]
# TODO mla test
default_stride = list(range(5))
# Permutation that gets you back to expected kv shape
rnd_stride = tuple(random.sample(default_stride, len(default_stride)))

def rnd_stride_order():
return rnd_stride

# Patch the attention backend class and re-trigger the KV cache creation.
for attn_backend in model_runner.attn_backends:
monkeypatch.setattr(attn_backend, "get_kv_cache_stride_order",
rnd_stride_order)

model_runner.attn_backends = []
model_runner.attn_metadata_builders = []
model_runner.initialize_kv_cache(model_runner.kv_cache_config)

# Shape is unchanged, but layout may differ
kv_cache_shape = model_runner.kv_caches[0].shape
assert list(kv_cache_shape) == expected_kv_cache_shape
if default_stride == rnd_stride:
assert all(kv.is_contiguous() for kv in model_runner.kv_caches)
else:
assert all(not kv.is_contiguous() for kv in model_runner.kv_caches)


def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2):
# In this test, model_runner loads model + weights in one go, while
# model_runner_2 loads dummy weights first then load real weights inplace
model_runner.load_model()
original_load_format = model_runner_2.load_config.load_format
model_runner_2.load_config.load_format = "dummy"
model_runner_2.load_model() # Initial model loading with dummy weights
assert str(model_runner.get_model().state_dict()) != str(
model_runner_2.get_model().state_dict())
model_runner_2.load_config.load_format = original_load_format
model_runner_2.load_model() # Load real weights inplace
assert str(model_runner.get_model().state_dict()) == str(
model_runner_2.get_model().state_dict())
@ -58,6 +58,9 @@ def main() -> int:
if not Path(filepath).exists():
continue

if filepath == "setup.py":
continue

violations = check_file(filepath)
if violations:
print(f"\n❌ {filepath}:")
@ -132,8 +132,6 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):

def __init__(self, input_builder: "ModelInputForGPUBuilder"):
super().__init__(input_builder)
assert self.runner.model_config.max_model_len == 32768,\
"AITER MLA requires max model len to be set to 32768"
assert self.block_size == 1, "AITER MLA requires only block size 1."

def prepare(self):

@ -770,8 +770,9 @@ class ROCmFlashAttentionImpl(AttentionImpl):
and layer._v_scale and layer._prob_scale
and self.kv_cache_dtype == "fp8")
full_scales = (
layer._q_scale, layer._k_scale, layer._v_scale,
layer._prob_scale) if use_fp8_scales else None
layer._q_scale.item(), layer._k_scale.item(),
layer._v_scale.item(),
layer._prob_scale.item()) if use_fp8_scales else None
self.triton_attn_func(
query,
key,
@ -264,8 +264,8 @@ def chunked_prefill_paged_decode(
# Conversion of FP8 Tensor from uint8 storage to
# appropriate torch.dtype for interpretation by Triton
if "fp8" in kv_cache_dtype:
assert key_cache.dtype == torch.uint8
assert value_cache.dtype == torch.uint8
assert key_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
assert value_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]

if kv_cache_dtype in ("fp8", "fp8_e4m3"):
target_dtype = current_platform.fp8_dtype()

@ -744,8 +744,8 @@ def context_attention_fwd(q,
# Conversion of FP8 Tensor from uint8 storage to
# appropriate torch.dtype for interpretation by Triton
if "fp8" in kv_cache_dtype:
assert (k_cache.dtype == torch.uint8)
assert (v_cache.dtype == torch.uint8)
assert k_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
assert v_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]

if kv_cache_dtype in ("fp8", "fp8_e4m3"):
target_dtype = current_platform.fp8_dtype()
File diff suppressed because it is too large
@ -9,9 +9,6 @@ generation. Supported dataset types include:
- BurstGPT
- HuggingFace
- VisionArena

TODO: Implement CustomDataset to parse a JSON file and convert its contents into
SampleRequest instances, similar to the approach used in ShareGPT.
"""
import base64
import io
@ -26,6 +23,7 @@ from io import BytesIO
from typing import Any, Callable, Optional, Union

import numpy as np
import pandas as pd
from PIL import Image
from transformers import PreTrainedTokenizerBase

@ -443,6 +441,99 @@ class ShareGPTDataset(BenchmarkDataset):
return samples


# -----------------------------------------------------------------------------
# Custom Dataset Implementation
# -----------------------------------------------------------------------------


class CustomDataset(BenchmarkDataset):
"""
Implements the Custom dataset. Loads data from a JSONL file and generates
sample requests based on conversation turns. E.g.,
```
{"prompt": "What is the capital of India?"}
{"prompt": "What is the capital of Iran?"}
{"prompt": "What is the capital of China?"}
```
"""

def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
self.load_data()

def load_data(self) -> None:
if self.dataset_path is None:
raise ValueError("dataset_path must be provided for loading data.")

# self.data will be a list of dictionaries
# e.g., [{"prompt": "What is the capital of India?"}, ...]
# This will be the standardized format which load_data()
# has to convert into depending on the filetype of dataset_path.
# sample() will assume this standardized format of self.data
self.data = []

# Load the JSONL file
if self.dataset_path.endswith(".jsonl"):
jsonl_data = pd.read_json(path_or_buf=self.dataset_path,
lines=True)

# check if the JSONL file has a 'prompt' column
if "prompt" not in jsonl_data.columns:
raise ValueError("JSONL file must contain a 'prompt' column.")

# Convert each row to a dictionary and append to self.data
# This will convert the DataFrame to a list of dictionaries
# where each dictionary corresponds to a row in the DataFrame.
# This is the standardized format we want for self.data
for _, row in jsonl_data.iterrows():
self.data.append(row.to_dict())
else:
raise NotImplementedError(
"Only JSONL format is supported for CustomDataset.")

random.seed(self.random_seed)
random.shuffle(self.data)

def sample(
self,
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
lora_path: Optional[str] = None,
max_loras: Optional[int] = None,
output_len: Optional[int] = None,
enable_multimodal_chat: bool = False,
skip_chat_template: bool = False,
**kwargs,
) -> list:
sampled_requests = []
for item in self.data:
if len(sampled_requests) >= num_requests:
break
prompt = item["prompt"]

# apply template
if not skip_chat_template:
prompt = tokenizer.apply_chat_template(
[{
"role": "user",
"content": prompt
}],
add_generation_prompt=True,
tokenize=False,
)

prompt_len = len(tokenizer(prompt).input_ids)
sampled_requests.append(
SampleRequest(
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
))
self.maybe_oversample_requests(sampled_requests, num_requests)

return sampled_requests
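A hedged usage sketch for the new CustomDataset: the import path, the constructor keywords dataset_path and random_seed, and the tokenizer choice are assumptions inferred from the attributes used above rather than confirmed by the source.

from transformers import AutoTokenizer
from benchmark_dataset import CustomDataset  # import path assumed

# Write a tiny JSONL file in the format the docstring above describes.
with open("/tmp/prompts.jsonl", "w") as f:
    f.write('{"prompt": "What is the capital of India?"}\n')
    f.write('{"prompt": "What is the capital of Iran?"}\n')

tokenizer = AutoTokenizer.from_pretrained("ibm-research/PowerMoE-3b")
dataset = CustomDataset(dataset_path="/tmp/prompts.jsonl", random_seed=0)
requests = dataset.sample(tokenizer=tokenizer, num_requests=2, output_len=64,
                          skip_chat_template=True)
print([r.prompt_len for r in requests])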


# -----------------------------------------------------------------------------
# Sonnet Dataset Implementation
# -----------------------------------------------------------------------------
@ -1110,6 +1110,8 @@ def main(args: argparse.Namespace):
]:
if field in result_json:
del result_json[field]
if field in benchmark_result:
del benchmark_result[field]

# Save to file
base_model_id = model_id.split("/")[-1]
@ -1120,6 +1122,7 @@ def main(args: argparse.Namespace):
if args.result_filename:
file_name = args.result_filename
if args.result_dir:
os.makedirs(args.result_dir, exist_ok=True)
file_name = os.path.join(args.result_dir, file_name)
with open(file_name,
mode="a+" if args.append_result else "w",
@ -16,7 +16,7 @@ import vllm.envs as envs
from vllm.config import CompilationConfig, VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import resolve_obj_by_qualname
from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname

from .compiler_interface import (CompilerInterface, EagerAdaptor,
InductorAdaptor, InductorStandaloneAdaptor)
@ -29,7 +29,8 @@ logger = init_logger(__name__)

def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
if compilation_config.use_inductor:
if envs.VLLM_TEST_STANDALONE_COMPILE:
if envs.VLLM_USE_STANDALONE_COMPILE and is_torch_equal_or_newer(
"2.8.0"):
logger.info("Using InductorStandaloneAdaptor")
return InductorStandaloneAdaptor()
else:
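The hunk above renames the gate from VLLM_TEST_STANDALONE_COMPILE to VLLM_USE_STANDALONE_COMPILE and additionally requires torch >= 2.8.0 before the standalone Inductor adaptor is selected. A minimal sketch of opting in, assuming the variable is read from the process environment when vllm is imported:

import os

# Assumption: set before importing vllm so vllm.envs picks it up.
os.environ["VLLM_USE_STANDALONE_COMPILE"] = "1"  # replaces VLLM_TEST_STANDALONE_COMPILE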
@ -12,6 +12,7 @@ import torch._inductor.compile_fx
import torch.fx as fx

import vllm.envs as envs
from vllm.compilation.counter import compilation_counter
from vllm.config import VllmConfig
from vllm.utils import is_torch_equal_or_newer

@ -154,7 +155,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
This is not on by default yet, but we plan to turn it on by default for
PyTorch 2.8.

Use VLLM_TEST_STANDALONE_COMPILE to toggle this on or off.
Use VLLM_USE_STANDALONE_COMPILE to toggle this on or off.
"""
name = "inductor_standalone"

@ -175,6 +176,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
runtime_shape: Optional[int] = None,
key: Optional[str] = None,
) -> tuple[Optional[Callable], Optional[Any]]:
compilation_counter.num_inductor_compiles += 1
current_config = {}
if compiler_config is not None:
current_config.update(compiler_config)
@ -262,6 +264,7 @@ class InductorAdaptor(CompilerInterface):
runtime_shape: Optional[int] = None,
key: Optional[str] = None,
) -> tuple[Optional[Callable], Optional[Any]]:
compilation_counter.num_inductor_compiles += 1
from torch._inductor.compile_fx import compile_fx
current_config = {}
if compiler_config is not None:
@ -412,8 +415,14 @@ class InductorAdaptor(CompilerInterface):
# compilation cache. So turn off the checks if we disable the
# compilation cache.
if not envs.VLLM_DISABLE_COMPILE_CACHE:
assert hash_str is not None, (
"failed to get the hash of the compiled graph")
if hash_str is None:
raise RuntimeError(
"vLLM failed to compile the model. The most "
"likely reason for this is that a previous compilation "
"failed, leading to a corrupted compilation artifact. "
"We recommend trying to "
"remove ~/.cache/vllm/torch_compile_cache and try again "
"to see the real issue. ")
assert file_path is not None, (
"failed to get the file path of the compiled graph")
return compiled_graph, (hash_str, file_path)
@ -528,6 +537,7 @@ class EagerAdaptor(CompilerInterface):
runtime_shape: Optional[int] = None,
key: Optional[str] = None,
) -> tuple[Optional[Callable], Optional[Any]]:
compilation_counter.num_eager_compiles += 1
# we don't need to compile the graph, just return the graph itself.
# It does not support caching, return None for the handle.
return graph, None

@ -15,6 +15,10 @@ class CompilationCounter:
num_piecewise_capturable_graphs_seen: int = 0
num_backend_compilations: int = 0
num_cudagraph_caputured: int = 0
# InductorAdapter.compile calls
num_inductor_compiles: int = 0
# EagerAdapter.compile calls
num_eager_compiles: int = 0

def clone(self) -> "CompilationCounter":
return copy.deepcopy(self)
230 vllm/config.py
@ -24,6 +24,7 @@ import torch
|
||||
from pydantic import (ConfigDict, SkipValidation, TypeAdapter, field_validator,
|
||||
model_validator)
|
||||
from pydantic.dataclasses import dataclass
|
||||
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
|
||||
from torch.distributed import ProcessGroup, ReduceOp
|
||||
from transformers import PretrainedConfig
|
||||
from typing_extensions import deprecated, runtime_checkable
|
||||
@ -42,15 +43,16 @@ from vllm.transformers_utils.config import (
|
||||
ConfigFormat, get_config, get_hf_image_processor_config,
|
||||
get_hf_text_config, get_pooling_config,
|
||||
get_sentence_transformer_tokenizer_config, is_encoder_decoder,
|
||||
try_get_generation_config, uses_mrope)
|
||||
try_get_generation_config, try_get_safetensors_metadata, uses_mrope)
|
||||
from vllm.transformers_utils.s3_utils import S3Model
|
||||
from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
|
||||
from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS,
|
||||
MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
|
||||
POOLING_MODEL_MAX_NUM_BATCHED_TOKENS, GiB_bytes,
|
||||
LayerBlockType, cuda_device_count_stateless,
|
||||
get_cpu_memory, get_open_port, is_torch_equal_or_newer,
|
||||
random_uuid, resolve_obj_by_qualname)
|
||||
LayerBlockType, common_broadcastable_dtype,
|
||||
cuda_device_count_stateless, get_cpu_memory,
|
||||
get_open_port, is_torch_equal_or_newer, random_uuid,
|
||||
resolve_obj_by_qualname)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from _typeshed import DataclassInstance
|
||||
@ -304,7 +306,7 @@ class ModelConfig:
|
||||
- 25.6k -> 25,600"""
|
||||
spec_target_max_model_len: Optional[int] = None
|
||||
"""Specify the maximum length for spec decoding draft models."""
|
||||
quantization: Optional[QuantizationMethods] = None
|
||||
quantization: SkipValidation[Optional[QuantizationMethods]] = None
|
||||
"""Method used to quantize the weights. If `None`, we first check the
|
||||
`quantization_config` attribute in the model config file. If that is
|
||||
`None`, we assume the model weights are not quantized and use `dtype` to
|
||||
@ -380,7 +382,7 @@ class ModelConfig:
|
||||
"""Initialize non-default neuron config or override default neuron config
|
||||
that are specific to Neuron devices, this argument will be used to
|
||||
configure the neuron config that can not be gathered from the vllm
|
||||
arguments. e.g. `{"cast_logits_dtype": "bloat16"}`."""
|
||||
arguments. e.g. `{"cast_logits_dtype": "bfloat16"}`."""
|
||||
pooler_config: Optional["PoolerConfig"] = field(init=False)
|
||||
"""Pooler config which controls the behaviour of output pooling in pooling
|
||||
models."""
|
||||
@ -540,7 +542,24 @@ class ModelConfig:
|
||||
self.encoder_config = self._get_encoder_config()
|
||||
self.hf_image_processor_config = get_hf_image_processor_config(
|
||||
self.model, hf_token=self.hf_token, revision=self.revision)
|
||||
self.dtype = _get_and_verify_dtype(self.hf_config, self.dtype)
|
||||
|
||||
supported_tasks, task = self._resolve_task(self.task)
|
||||
self.supported_tasks = supported_tasks
|
||||
self.task = task
|
||||
if self.task in ("draft", "generate"):
|
||||
self.truncation_side = "left"
|
||||
else:
|
||||
self.truncation_side = "right"
|
||||
|
||||
self.pooler_config = self._init_pooler_config()
|
||||
|
||||
self.dtype = _get_and_verify_dtype(
|
||||
self.model,
|
||||
self.hf_config,
|
||||
self.dtype,
|
||||
is_pooling_model=self.runner_type == "pooling",
|
||||
revision=self.revision,
|
||||
)
|
||||
|
||||
# Workaround for Gemma 2 which uses interleaved sliding window
|
||||
# attention, but it's not specified in its config. TODO: remove this
|
||||
@ -597,16 +616,6 @@ class ModelConfig:
|
||||
raise ValueError(
|
||||
"`override_neuron_config` is only supported on Neuron.")
|
||||
|
||||
supported_tasks, task = self._resolve_task(self.task)
|
||||
self.supported_tasks = supported_tasks
|
||||
self.task = task
|
||||
if self.task in ("draft", "generate"):
|
||||
self.truncation_side = "left"
|
||||
else:
|
||||
self.truncation_side = "right"
|
||||
|
||||
self.pooler_config = self._init_pooler_config()
|
||||
|
||||
self._verify_quantization()
|
||||
self._verify_cuda_graph()
|
||||
self._verify_bnb_config()
|
||||
@ -692,7 +701,6 @@ class ModelConfig:
|
||||
self.model, self.revision)
|
||||
|
||||
def _init_pooler_config(self) -> Optional["PoolerConfig"]:
|
||||
|
||||
if self.runner_type == "pooling":
|
||||
if isinstance(self.override_pooler_config, dict):
|
||||
self.override_pooler_config = PoolerConfig(
|
||||
@ -1360,6 +1368,16 @@ class ModelConfig:
|
||||
@property
|
||||
def is_encoder_decoder(self) -> bool:
|
||||
"""Extract the HF encoder/decoder model flag."""
|
||||
"""
|
||||
For Mllama, VLLM overrides HF's is_encoder_decoder flag and sets it to
|
||||
True to enable cross-attention
|
||||
Neuron needs all multimodal data to be in the decoder and does not
|
||||
need to explicitly enable cross-attention
|
||||
"""
|
||||
if (current_platform.is_neuron()
|
||||
and self.hf_config.model_type == "mllama"):
|
||||
return False
|
||||
|
||||
return is_encoder_decoder(self.hf_config)
|
||||
|
||||
@property
|
||||
@ -1772,6 +1790,10 @@ class ParallelConfig:
|
||||
rank: int = 0
|
||||
"""Global rank in distributed setup."""
|
||||
|
||||
enable_multimodal_encoder_data_parallel: bool = False
|
||||
""" Use data parallelism instead of tensor parallelism for vision encoder.
|
||||
Only support LLama4 for now"""
|
||||
|
||||
@property
|
||||
def world_size_across_dp(self) -> int:
|
||||
"""world_size_across_dp is TPxPPxDP, it is the size of the world
|
||||
@ -2231,7 +2253,7 @@ Device = Literal["auto", "cuda", "neuron", "cpu", "tpu", "xpu", "hpu"]
|
||||
class DeviceConfig:
|
||||
"""Configuration for the device to use for vLLM execution."""
|
||||
|
||||
device: Union[Device, torch.device] = "auto"
|
||||
device: SkipValidation[Union[Device, torch.device]] = "auto"
|
||||
"""Device type for vLLM execution.
|
||||
This parameter is deprecated and will be
|
||||
removed in a future release.
|
||||
@ -2573,7 +2595,8 @@ class SpeculativeConfig:
|
||||
else:
|
||||
eagle_config = EAGLEConfig(
|
||||
self.draft_model_config.hf_config,
|
||||
method=self.method)
|
||||
method=self.method,
|
||||
model_type="eagle")
|
||||
self.draft_model_config.hf_config = eagle_config
|
||||
|
||||
if (self.num_speculative_tokens is not None
|
||||
@ -3064,13 +3087,37 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
|
||||
"bfloat16": torch.bfloat16,
|
||||
}
|
||||
|
||||
_ROCM_NOT_SUPPORTED_DTYPE: list[str] = [] #
|
||||
# model_type -> reason
|
||||
_FLOAT16_NOT_SUPPORTED_MODELS = {
|
||||
"gemma2": "Numerical instability. Please use bfloat16 or float32 instead.",
|
||||
"gemma3": "Numerical instability. Please use bfloat16 or float32 instead.",
|
||||
"plamo2": "Numerical instability. Please use bfloat16 or float32 instead.",
|
||||
"glm4": "Numerical instability. Please use bfloat16 or float32 instead.",
|
||||
}
|
||||
|
||||
|
||||
def _get_and_verify_dtype(
|
||||
def _is_valid_dtype(model_type: str, dtype: torch.dtype):
|
||||
if model_type in _FLOAT16_NOT_SUPPORTED_MODELS and dtype == torch.float16: # noqa: E501, SIM103
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _check_valid_dtype(model_type: str, dtype: torch.dtype):
|
||||
if model_type in _FLOAT16_NOT_SUPPORTED_MODELS and dtype == torch.float16:
|
||||
reason = _FLOAT16_NOT_SUPPORTED_MODELS[model_type]
|
||||
raise ValueError(f"The model type {model_type!r} "
|
||||
f"does not support float16. Reason: {reason}")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _find_dtype(
|
||||
model_id: str,
|
||||
config: PretrainedConfig,
|
||||
dtype: Union[str, torch.dtype],
|
||||
) -> torch.dtype:
|
||||
*,
|
||||
revision: Optional[str],
|
||||
):
|
||||
# NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
|
||||
# because config.torch_dtype can be None.
|
||||
config_dtype = getattr(config, "torch_dtype", None)
|
||||
@ -3082,75 +3129,111 @@ def _get_and_verify_dtype(
|
||||
if config_dtype is None and hasattr(config, "vision_config"):
|
||||
config_dtype = getattr(config.vision_config, "torch_dtype", None)
|
||||
|
||||
# Try to read the dtype of the weights if they are in safetensors format
|
||||
if config_dtype is None:
|
||||
repo_mt = try_get_safetensors_metadata(model_id, revision=revision)
|
||||
|
||||
if repo_mt and (files_mt := repo_mt.files_metadata):
|
||||
param_dtypes: set[torch.dtype] = {
|
||||
_SAFETENSORS_TO_TORCH_DTYPE[dtype_str]
|
||||
for file_mt in files_mt.values()
|
||||
for dtype_str in file_mt.parameter_count
|
||||
if dtype_str in _SAFETENSORS_TO_TORCH_DTYPE
|
||||
}
|
||||
|
||||
if param_dtypes:
|
||||
return common_broadcastable_dtype(param_dtypes)
|
||||
|
||||
if config_dtype is None:
|
||||
config_dtype = torch.float32
|
||||
|
||||
return config_dtype
|
||||
|
||||
|
||||
def _resolve_auto_dtype(
|
||||
model_type: str,
|
||||
config_dtype: torch.dtype,
|
||||
*,
|
||||
is_pooling_model: bool,
|
||||
):
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
supported_dtypes = [
|
||||
dtype for dtype in current_platform.supported_dtypes
|
||||
if _is_valid_dtype(model_type, dtype)
|
||||
]
|
||||
|
||||
if is_pooling_model and torch.float16 in supported_dtypes:
|
||||
preferred_dtype = torch.float16
|
||||
else:
|
||||
preferred_dtype = supported_dtypes[0]
|
||||
|
||||
# Downcast for float32 models
|
||||
if config_dtype == torch.float32:
|
||||
config_dtype = preferred_dtype
|
||||
|
||||
if config_dtype in supported_dtypes:
|
||||
return config_dtype
|
||||
|
||||
# Ensure device compatibility
|
||||
device_name = current_platform.get_device_name()
|
||||
device_capability = current_platform.get_device_capability()
|
||||
|
||||
if device_capability is None:
|
||||
device_str = f"{device_name!r}"
|
||||
else:
|
||||
version_str = device_capability.as_version_str()
|
||||
device_str = f"{device_name!r} (with compute capability {version_str})"
|
||||
|
||||
logger.warning(
|
||||
"Your device %s doesn't support %s. "
|
||||
"Falling back to %s for compatibility.",
|
||||
device_str,
|
||||
config_dtype,
|
||||
preferred_dtype,
|
||||
)
|
||||
|
||||
return preferred_dtype
|
||||
|
||||
|
||||
def _get_and_verify_dtype(
|
||||
model_id: str,
|
||||
config: PretrainedConfig,
|
||||
dtype: Union[str, torch.dtype],
|
||||
*,
|
||||
is_pooling_model: bool,
|
||||
revision: Optional[str] = None,
|
||||
) -> torch.dtype:
|
||||
config_dtype = _find_dtype(model_id, config, revision=revision)
|
||||
model_type = config.model_type
|
||||
|
||||
if isinstance(dtype, str):
|
||||
dtype = dtype.lower()
|
||||
if dtype == "auto":
|
||||
# Set default dtype from model config
|
||||
if config_dtype == torch.float32:
|
||||
# Following common practice, we use float16 for float32 models
|
||||
torch_dtype = torch.float16
|
||||
else:
|
||||
torch_dtype = config_dtype
|
||||
|
||||
if config.model_type == "plamo2":
|
||||
logger.warning(
|
||||
"For PLaMo2, we cast models to bfloat16 instead of using "
|
||||
"float16 by default. This is because float16 does not work."
|
||||
)
|
||||
torch_dtype = torch.bfloat16
|
||||
|
||||
# Deal with torch dtype fallback for device compatibility.
|
||||
from vllm.platforms import current_platform
|
||||
if torch_dtype not in current_platform.supported_dtypes:
|
||||
device_name = current_platform.get_device_name()
|
||||
|
||||
if ((capability := current_platform.get_device_capability())
|
||||
is None):
|
||||
compute_str = ""
|
||||
else:
|
||||
version_str = capability.as_version_str()
|
||||
compute_str = f" (with compute capability {version_str})"
|
||||
fallback_dtype = current_platform.supported_dtypes[0]
|
||||
logger.warning(
|
||||
"Your %s device%s doesn't support %s. " \
|
||||
"Falling back to %s for compatibility.",
|
||||
device_name, compute_str, torch_dtype, fallback_dtype
|
||||
)
|
||||
torch_dtype = fallback_dtype
|
||||
|
||||
if current_platform.is_hpu() and torch_dtype == torch.float16:
|
||||
logger.warning(
|
||||
"For HPU, we cast models to bfloat16 instead of "
|
||||
"using float16 by default. Please specify `dtype` if you "
|
||||
"want to use float16.")
|
||||
torch_dtype = torch.bfloat16
|
||||
elif dtype == "float16" and config.model_type == "plamo2":
|
||||
logger.warning(
|
||||
"For PLaMo2, using float16 is unstable and might cause "
|
||||
"unexpected behavior. Please use bfloat16 or float32 instead.")
|
||||
torch_dtype = torch.float16
|
||||
torch_dtype = _resolve_auto_dtype(
|
||||
model_type,
|
||||
config_dtype,
|
||||
is_pooling_model=is_pooling_model,
|
||||
)
|
||||
else:
|
||||
if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
|
||||
raise ValueError(f"Unknown dtype: {dtype}")
|
||||
raise ValueError(f"Unknown dtype: {dtype!r}")
|
||||
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
|
||||
elif isinstance(dtype, torch.dtype):
|
||||
torch_dtype = dtype
|
||||
else:
|
||||
raise ValueError(f"Unknown dtype: {dtype}")
|
||||
|
||||
# Verify the dtype.
|
||||
_check_valid_dtype(model_type, torch_dtype)
|
||||
|
||||
if torch_dtype != config_dtype:
|
||||
if torch_dtype == torch.float32:
|
||||
# Upcasting to float32 is allowed.
|
||||
logger.info("Upcasting %s to %s.", config_dtype, torch_dtype)
|
||||
pass
|
||||
elif config_dtype == torch.float32:
|
||||
# Downcasting from float32 to float16 or bfloat16 is allowed.
|
||||
logger.info("Downcasting %s to %s.", config_dtype, torch_dtype)
|
||||
pass
|
||||
else:
|
||||
# Casting between float16 and bfloat16 is allowed with a warning.
|
||||
logger.warning("Casting %s to %s.", config_dtype, torch_dtype)
|
||||
@ -4315,15 +4398,10 @@ class VllmConfig:
|
||||
self.compilation_config.custom_ops.append("+rms_norm")
|
||||
if envs.VLLM_USE_V1 and self.model_config is not None and \
|
||||
not self.model_config.enforce_eager:
|
||||
# NOTE(woosuk): Currently, we use inductor because the piecewise
|
||||
# CUDA graphs do not work properly with the custom CUDA kernels.
|
||||
# FIXME(woosuk): Disable inductor to reduce the compilation time
|
||||
# and avoid any potential issues with the inductor.
|
||||
# FIXME(rob): Add function to set all of these.
|
||||
if not self.compilation_config.custom_ops:
|
||||
self.compilation_config.custom_ops = ["none"]
|
||||
self.compilation_config.use_cudagraph = True
|
||||
self.compilation_config.use_inductor = True
|
||||
self.compilation_config.cudagraph_num_of_warmups = 1
|
||||
self.compilation_config.pass_config.enable_fusion = False
|
||||
self.compilation_config.pass_config.enable_noop = False
|
||||
|
||||
@ -70,7 +70,8 @@ class KVConnectorFactory:
|
||||
connector_module = importlib.import_module(connector_module_path)
|
||||
connector_cls = getattr(connector_module, connector_name)
|
||||
assert issubclass(connector_cls, KVConnectorBase_V1)
|
||||
logger.info("Creating v1 connector with name: %s", connector_name)
|
||||
logger.info("Creating v1 connector with name: %s and engine_id: %s",
|
||||
connector_name, kv_transfer_config.engine_id)
|
||||
# NOTE(Kuntai): v1 connector is explicitly separated into two roles.
|
||||
# Scheduler connector:
|
||||
# - Co-locate with scheduler process
|
||||
|
||||
@ -172,6 +172,11 @@ class NixlConnectorScheduler:
|
||||
self.vllm_config = vllm_config
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
self.engine_id = engine_id
|
||||
self.side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
|
||||
self.side_channel_port = (
|
||||
envs.VLLM_NIXL_SIDE_CHANNEL_PORT +
|
||||
vllm_config.parallel_config.data_parallel_rank_local *
|
||||
vllm_config.parallel_config.tensor_parallel_size)
|
||||
logger.info("Initializing NIXL Scheduler %s", engine_id)
|
||||
|
||||
# Requests that need to start recv.
|
||||
@ -310,8 +315,8 @@ class NixlConnectorScheduler:
|
||||
do_remote_decode=False,
|
||||
remote_block_ids=computed_block_ids,
|
||||
remote_engine_id=self.engine_id,
|
||||
remote_host=envs.VLLM_NIXL_SIDE_CHANNEL_HOST,
|
||||
remote_port=envs.VLLM_NIXL_SIDE_CHANNEL_PORT,
|
||||
remote_host=self.side_channel_host,
|
||||
remote_port=self.side_channel_port,
|
||||
)
|
||||
|
||||
|
||||
@ -330,9 +335,18 @@ class NixlConnectorWorker:
|
||||
# Map of engine_id -> agent_name.
|
||||
self._remote_agents: dict[str, str] = {}
|
||||
|
||||
# NIXL handshake port.
|
||||
# NOTE(rob): Within a DP group, each DP rank gets its own
|
||||
# base port (which is sent in the KVTransferParams).
|
||||
# Each TP rank listens/queries on the base_port + tp_rank.
|
||||
self.side_channel_port = (
|
||||
envs.VLLM_NIXL_SIDE_CHANNEL_PORT +
|
||||
vllm_config.parallel_config.data_parallel_rank_local *
|
||||
vllm_config.parallel_config.tensor_parallel_size)
|
||||
|
||||
# Metadata.
|
||||
self.engine_id = engine_id
|
||||
self.rank = get_tensor_model_parallel_rank()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.world_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_group = get_tp_group()
|
||||
|
||||
@ -382,15 +396,11 @@ class NixlConnectorWorker:
|
||||
|
||||
@staticmethod
|
||||
def _nixl_handshake_listener(metadata: NixlAgentMetadata,
|
||||
ready_event: threading.Event, rank: int):
|
||||
ready_event: threading.Event, base_port: int,
|
||||
tp_rank: int):
|
||||
"""Background thread for getting new NIXL handshakes."""
|
||||
# NOTE(rob): this is a simple implementation. We will move
|
||||
# to a better approach like an ETCD server in the future.
|
||||
|
||||
# NOTE(rob): to support heterogeneous TP, we will have to
|
||||
# move this into the scheduler rather than worker, since
|
||||
# each rank needs the metadata of all other ranks (whereas
|
||||
# in this setup, each rank only gets one other rank's meta.
|
||||
# to a better approach via HTTP endpoint soon.
|
||||
|
||||
encoder = msgspec.msgpack.Encoder()
|
||||
encoded_data = encoder.encode(metadata)
|
||||
@ -400,11 +410,7 @@ class NixlConnectorWorker:
|
||||
|
||||
# Listen for new requests for metadata.
|
||||
host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
|
||||
# NOTE(rob): we need each rank to have a unique port. This
|
||||
# hack to keeps us moving. We will switch when moving to etcd
|
||||
# or where we have a single ZMQ socket in the scheduler.
|
||||
port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT + rank
|
||||
path = make_zmq_path("tcp", host, port)
|
||||
path = make_zmq_path("tcp", host, base_port + tp_rank)
|
||||
logger.debug("Starting listening on path: %s", path)
|
||||
with zmq_ctx(zmq.ROUTER, path) as sock:
|
||||
ready_event.set()
|
||||
@ -419,10 +425,10 @@ class NixlConnectorWorker:
|
||||
"""Do a NIXL handshake with a remote instance."""
|
||||
|
||||
start_time = time.perf_counter()
|
||||
# NOTE(rob): we need each rank to have a unique port. This is
|
||||
# a hack to keep us moving. We will switch when moving to etcd
|
||||
# or where we have a single ZMQ socket in the scheduler.
|
||||
path = make_zmq_path("tcp", host, port + self.rank)
|
||||
# NOTE(rob): we need each tp_rank to have a unique port.
|
||||
# This is a hack to keep us moving. We will switch when
|
||||
# we switch to HTTP-based NIXL metadata exchange.
|
||||
path = make_zmq_path("tcp", host, port + self.tp_rank)
|
||||
logger.debug("Querying metadata on path: %s", path)
|
||||
with zmq_ctx(zmq.REQ, path) as sock:
|
||||
# Send query for the request.
|
||||
@ -486,7 +492,8 @@ class NixlConnectorWorker:
|
||||
for cache in cache_list:
|
||||
base_addr = cache.data_ptr()
|
||||
region_len = self.num_blocks * self.block_len
|
||||
caches_data.append((base_addr, region_len, self.rank, ""))
|
||||
caches_data.append(
|
||||
(base_addr, region_len, cache.device.index, ""))
|
||||
kv_caches_base_addr.append(base_addr)
|
||||
self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
|
||||
self.num_regions = len(caches_data)
|
||||
@ -529,7 +536,7 @@ class NixlConnectorWorker:
|
||||
ready_event = threading.Event()
|
||||
self._nixl_handshake_listener_t = threading.Thread(
|
||||
target=self._nixl_handshake_listener,
|
||||
args=(metadata, ready_event, self.rank),
|
||||
args=(metadata, ready_event, self.side_channel_port, self.tp_rank),
|
||||
daemon=True,
|
||||
name="nixl_handshake_listener")
|
||||
self._nixl_handshake_listener_t.start()
|
||||
@ -553,9 +560,9 @@ class NixlConnectorWorker:
|
||||
block_offset = block_id * self.block_len
|
||||
# (addr, len, device id)
|
||||
blocks_data.append(
|
||||
(base_addr + block_offset, self.block_len, self.rank))
|
||||
logger.debug("Created %s blocks for src engine %s and rank %s",
|
||||
len(blocks_data), self.engine_id, self.rank)
|
||||
(base_addr + block_offset, self.block_len, self.tp_rank))
|
||||
logger.debug("Created %s blocks for src engine %s and tp_rank %s",
|
||||
len(blocks_data), self.engine_id, self.tp_rank)
|
||||
|
||||
# Register with NIXL.
|
||||
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
|
||||
@ -570,9 +577,9 @@ class NixlConnectorWorker:
|
||||
block_offset = block_id * self.block_len
|
||||
# (addr, len, device id)
|
||||
blocks_data.append(
|
||||
(base_addr + block_offset, self.block_len, self.rank))
|
||||
logger.debug("Created %s blocks for dst engine %s and rank %s",
|
||||
len(blocks_data), engine_id, self.rank)
|
||||
(base_addr + block_offset, self.block_len, self.tp_rank))
|
||||
logger.debug("Created %s blocks for dst engine %s and tp_rank %s",
|
||||
len(blocks_data), engine_id, self.tp_rank)
|
||||
|
||||
# Register with NIXL.
|
||||
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
|
||||
@ -597,14 +604,14 @@ class NixlConnectorWorker:
|
||||
if len(done_sending) > 0 or len(done_recving) > 0:
|
||||
logger.debug(
|
||||
"Rank %s, get_finished: %s requests done sending "
|
||||
"and %s requests done recving", self.rank, len(done_sending),
|
||||
len(done_recving))
|
||||
"and %s requests done recving", self.tp_rank,
|
||||
len(done_sending), len(done_recving))
|
||||
|
||||
if self.world_size == 1:
|
||||
return done_sending, done_recving
|
||||
|
||||
# Rank 0: get finished from all other ranks.
|
||||
if self.rank == 0:
|
||||
if self.tp_rank == 0:
|
||||
for req_id in done_sending:
|
||||
self._done_sending_count[req_id] += 1
|
||||
for req_id in done_recving:
|
||||
|
||||
@ -41,8 +41,8 @@ from vllm.distributed.device_communicators.base_device_communicator import (
|
||||
DeviceCommunicatorBase)
|
||||
from vllm.distributed.utils import StatelessProcessGroup
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import (direct_register_custom_op, resolve_obj_by_qualname,
|
||||
supports_custom_op)
|
||||
from vllm.utils import (direct_register_custom_op, get_distributed_init_method,
|
||||
resolve_obj_by_qualname, supports_custom_op)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -929,7 +929,7 @@ def init_distributed_environment(
|
||||
world_size = parallel_config.world_size_across_dp
|
||||
ip = parallel_config.data_parallel_master_ip
|
||||
port = parallel_config.get_next_dp_init_port()
|
||||
distributed_init_method = f"tcp://{ip}:{port}" # noqa
|
||||
distributed_init_method = get_distributed_init_method(ip, port)
|
||||
logger.info(
|
||||
"Adjusting world_size=%d rank=%d distributed_init_method=%s for DP",
|
||||
world_size, rank, distributed_init_method)
|
||||
|
||||
@ -224,7 +224,7 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
|
||||
elif contains_type(type_hints, int):
|
||||
kwargs[name]["type"] = int
|
||||
# Special case for large integers
|
||||
if name in {"max_model_len"}:
|
||||
if name in {"max_model_len", "max_num_batched_tokens"}:
|
||||
kwargs[name]["type"] = human_readable_int
|
||||
elif contains_type(type_hints, float):
|
||||
kwargs[name]["type"] = float
|
||||
@ -423,6 +423,9 @@ class EngineArgs:
|
||||
use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
|
||||
pt_load_map_location: str = LoadConfig.pt_load_map_location
|
||||
|
||||
enable_multimodal_encoder_data_parallel: bool = \
|
||||
ParallelConfig.enable_multimodal_encoder_data_parallel
|
||||
|
||||
def __post_init__(self):
|
||||
# support `EngineArgs(compilation_config={...})`
|
||||
# without having to manually construct a
|
||||
@ -637,6 +640,9 @@ class EngineArgs:
|
||||
**parallel_kwargs["worker_cls"])
|
||||
parallel_group.add_argument("--worker-extension-cls",
|
||||
**parallel_kwargs["worker_extension_cls"])
|
||||
parallel_group.add_argument(
|
||||
"--enable-multimodal-encoder-data-parallel",
|
||||
**parallel_kwargs["enable_multimodal_encoder_data_parallel"])
|
||||
|
||||
# KV cache arguments
|
||||
cache_kwargs = get_kwargs(CacheConfig)
|
||||
@ -1078,6 +1084,8 @@ class EngineArgs:
|
||||
distributed_executor_backend=self.distributed_executor_backend,
|
||||
worker_cls=self.worker_cls,
|
||||
worker_extension_cls=self.worker_extension_cls,
|
||||
enable_multimodal_encoder_data_parallel=self.
|
||||
enable_multimodal_encoder_data_parallel,
|
||||
)
|
||||
|
||||
speculative_config = self.create_speculative_config(
|
||||
@ -1380,7 +1388,8 @@ class EngineArgs:
|
||||
|
||||
if (self.pipeline_parallel_size > 1
|
||||
and self.distributed_executor_backend
|
||||
not in ("ray", "mp", "external_launcher")):
|
||||
not in (ParallelConfig.distributed_executor_backend, "ray",
|
||||
"mp", "external_launcher")):
|
||||
name = "Pipeline Parallelism without Ray distributed executor " \
|
||||
"or multiprocessing executor or external launcher"
|
||||
_raise_or_fallback(feature_name=name, recommend_to_remove=False)
|
||||
|
||||
@ -1,24 +1,35 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
|
||||
import uvloop
|
||||
import zmq
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import AsyncEngineArgs
|
||||
from vllm.entrypoints.cli.types import CLISubcommand
|
||||
from vllm.entrypoints.openai.api_server import run_server
|
||||
from vllm.entrypoints.openai.api_server import (run_server, run_server_worker,
|
||||
setup_server)
|
||||
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
|
||||
validate_parsed_serve_args)
|
||||
from vllm.entrypoints.utils import (VLLM_SERVE_PARSER_EPILOG,
|
||||
show_filtered_argument_or_group_from_help)
|
||||
from vllm.executor.multiproc_worker_utils import _add_prefix
|
||||
from vllm.logger import init_logger
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import FlexibleArgumentParser, get_tcp_uri
|
||||
from vllm.utils import FlexibleArgumentParser, get_tcp_uri, zmq_socket_ctx
|
||||
from vllm.v1.engine.coordinator import DPCoordinator
|
||||
from vllm.v1.engine.core import EngineCoreProc
|
||||
from vllm.v1.engine.core_client import CoreEngineProcManager
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
|
||||
from vllm.v1.utils import (APIServerProcessManager, CoreEngine,
|
||||
EngineZmqAddresses, get_engine_client_zmq_addr,
|
||||
wait_for_completion_or_failure,
|
||||
wait_for_engine_startup)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -36,9 +47,12 @@ class ServeSubcommand(CLISubcommand):
|
||||
if hasattr(args, 'model_tag') and args.model_tag is not None:
|
||||
args.model = args.model_tag
|
||||
|
||||
if args.headless:
|
||||
if args.headless or args.api_server_count < 1:
|
||||
run_headless(args)
|
||||
elif args.api_server_count > 1:
|
||||
run_multi_api_server(args)
|
||||
else:
|
||||
# Single API server (this process).
|
||||
uvloop.run(run_server(args))
|
||||
|
||||
def validate(self, args: argparse.Namespace) -> None:
|
||||
@ -69,6 +83,11 @@ class ServeSubcommand(CLISubcommand):
|
||||
type=int,
|
||||
default=0,
|
||||
help='Starting data parallel rank for secondary nodes.')
|
||||
serve_parser.add_argument('--api-server-count',
|
||||
'-asc',
|
||||
        type=int,
        default=1,
        help='How many API server processes to run.')

    serve_parser.add_argument(
        "--config",
        type=str,
@ -91,23 +110,26 @@ def cmd_init() -> list[CLISubcommand]:

def run_headless(args: argparse.Namespace):

    if args.api_server_count > 1:
        raise ValueError("api_server_count can't be set in headless mode")

    # Create the EngineConfig.
    engine_args = AsyncEngineArgs.from_cli_args(args)
    usage_context = UsageContext.OPENAI_API_SERVER
    vllm_config = engine_args.create_engine_config(usage_context=usage_context)

    if not envs.VLLM_USE_V1:
        raise RuntimeError("Headless mode is only supported for V1")
        raise ValueError("Headless mode is only supported for V1")

    parallel_config = vllm_config.parallel_config
    local_engine_count = parallel_config.data_parallel_size_local
    host = parallel_config.data_parallel_master_ip
    port = engine_args.data_parallel_rpc_port  # add to config too
    input_address = get_tcp_uri(host, port)
    handshake_address = get_tcp_uri(host, port)

    if local_engine_count <= 0:
        raise RuntimeError("data_parallel_size_local must be > 0 in "
                           "headless mode")
        raise ValueError("data_parallel_size_local must be > 0 in "
                         "headless mode")

    # Catch SIGTERM and SIGINT to allow graceful shutdown.
    def signal_handler(signum, frame):
@ -119,7 +141,7 @@ def run_headless(args: argparse.Namespace):

    logger.info(
        "Launching %d data parallel engine(s) in headless mode, "
        "with head node address %s.", local_engine_count, input_address)
        "with head node address %s.", local_engine_count, handshake_address)

    # Create the engines.
    engine_manager = CoreEngineProcManager(
@ -129,7 +151,7 @@ def run_headless(args: argparse.Namespace):
        local_start_index=0,
        vllm_config=vllm_config,
        on_head_node=False,
        input_address=input_address,
        handshake_address=handshake_address,
        executor_class=Executor.get_class(vllm_config),
        log_stats=not engine_args.disable_log_stats,
    )
@ -139,3 +161,142 @@ def run_headless(args: argparse.Namespace):
    finally:
        logger.info("Shutting down.")
        engine_manager.close()


def run_multi_api_server(args: argparse.Namespace):

    assert not args.headless
    num_api_servers = args.api_server_count
    assert num_api_servers > 0

    if num_api_servers > 1:
        setup_multiprocess_prometheus()

    listen_address, sock = setup_server(args)

    engine_args = AsyncEngineArgs.from_cli_args(args)
    usage_context = UsageContext.OPENAI_API_SERVER
    vllm_config = engine_args.create_engine_config(usage_context=usage_context)
    model_config = vllm_config.model_config

    if num_api_servers > 1:
        if not envs.VLLM_USE_V1:
            raise ValueError("api_server_count > 1 is only supported for V1")

        if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
            raise ValueError("VLLM_ALLOW_RUNTIME_LORA_UPDATING cannot be used "
                             "with api_server_count > 1")

        if model_config.is_multimodal_model and not (
                model_config.disable_mm_preprocessor_cache):
            logger.warning(
                "Multi-model preprocessor cache will be disabled for"
                " api_server_count > 1")
            model_config.disable_mm_preprocessor_cache = True

    parallel_config = vllm_config.parallel_config

    assert parallel_config.data_parallel_rank == 0

    dp_size = parallel_config.data_parallel_size
    local_engine_count = parallel_config.data_parallel_size_local
    host = parallel_config.data_parallel_master_ip
    local_only = local_engine_count == dp_size

    # Set up input and output addresses.
    input_addresses = [
        get_engine_client_zmq_addr(local_only, host)
        for _ in range(num_api_servers)
    ]
    output_addresses = [
        get_engine_client_zmq_addr(local_only, host)
        for _ in range(num_api_servers)
    ]

    addresses = EngineZmqAddresses(
        inputs=input_addresses,
        outputs=output_addresses,
    )

    # Set up coordinator for dp > 1.
    coordinator = None
    stats_update_address = None
    if dp_size > 1:
        coordinator = DPCoordinator(parallel_config)
        addresses.coordinator_input, addresses.coordinator_output = (
            coordinator.get_engine_socket_addresses())
        stats_update_address = coordinator.get_stats_publish_address()
        logger.info("Started DP Coordinator process (PID: %d)",
                    coordinator.proc.pid)

    handshake_address = get_engine_client_zmq_addr(
        local_only, host, parallel_config.data_parallel_rpc_port)

    with zmq_socket_ctx(handshake_address, zmq.ROUTER,
                        bind=True) as handshake_socket:

        # Start local engines.
        if not local_engine_count:
            local_engine_manager = None
        else:
            local_engine_manager = CoreEngineProcManager(
                EngineCoreProc.run_engine_core,
                vllm_config=vllm_config,
                executor_class=Executor.get_class(vllm_config),
                log_stats=not engine_args.disable_log_stats,
                handshake_address=handshake_address,
                on_head_node=True,
                local_engine_count=local_engine_count,
                start_index=0,
                local_start_index=0)

        # Start API servers using the manager
        api_server_manager = APIServerProcessManager(
            target_server_fn=run_api_server_worker_proc,
            listen_address=listen_address,
            sock=sock,
            args=args,
            num_servers=num_api_servers,
            input_addresses=input_addresses,
            output_addresses=output_addresses,
            stats_update_address=stats_update_address)

        # Wait for engine handshakes to complete.
        core_engines = [
            CoreEngine(index=i, local=(i < local_engine_count))
            for i in range(dp_size)
        ]
        wait_for_engine_startup(
            handshake_socket,
            addresses,
            core_engines,
            parallel_config,
            vllm_config.cache_config,
            local_engine_manager,
            coordinator.proc if coordinator else None,
        )

        # Wait for API servers
        wait_for_completion_or_failure(
            api_server_manager=api_server_manager,
            local_engine_manager=local_engine_manager,
            coordinator=coordinator)


def run_api_server_worker_proc(listen_address,
                               sock,
                               args,
                               client_config=None,
                               **uvicorn_kwargs) -> None:
    """Entrypoint for individual API server worker processes."""

    # Add process-specific prefix to stdout and stderr.
    from multiprocessing import current_process
    process_name = current_process().name
    pid = os.getpid()
    _add_prefix(sys.stdout, process_name, pid)
    _add_prefix(sys.stderr, process_name, pid)

    uvloop.run(
        run_server_worker(listen_address, sock, args, client_config,
                          **uvicorn_kwargs))
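For context, a minimal sketch of the fan-out that `run_multi_api_server` establishes: one (input, output) ZMQ address pair per API server process, plus a single ROUTER handshake socket shared by all engines. The `ipc://` paths and counts below are illustrative stand-ins, not vLLM's real addresses.

```python
num_api_servers = 4      # from --api-server-count
dp_size = 2              # number of data-parallel engine cores

# One socket pair per API server process (illustrative addresses).
input_addresses = [f"ipc:///tmp/api-in-{i}" for i in range(num_api_servers)]
output_addresses = [f"ipc:///tmp/api-out-{i}" for i in range(num_api_servers)]

# All engines perform their startup handshake on a single ROUTER socket.
handshake_address = "ipc:///tmp/engine-handshake"

assert len(input_addresses) == len(output_addresses) == num_api_servers
print(f"{dp_size} engine(s) handshake on {handshake_address}; "
      f"{num_api_servers} API server(s) each get their own socket pair.")
```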
@ -45,8 +45,7 @@ from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
                                               get_cached_tokenizer)
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs,
                        is_list_of)
from vllm.utils import Counter, Device, deprecate_kwargs, is_list_of

if TYPE_CHECKING:
    from vllm.v1.metrics.reader import Metric
@ -143,12 +142,6 @@ class LLM:
    DEPRECATE_LEGACY: ClassVar[bool] = True
    """A flag to toggle whether to deprecate the legacy generate/encode API."""

    DEPRECATE_INIT_POSARGS: ClassVar[bool] = True
    """
    A flag to toggle whether to deprecate positional arguments in
    [LLM.__init__][].
    """

    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
@ -158,16 +151,11 @@ class LLM:

        cls.DEPRECATE_LEGACY = False

    @deprecate_args(
        start_index=2,  # Ignore self and model
        is_deprecated=lambda: LLM.DEPRECATE_INIT_POSARGS,
        additional_message=(
            "All positional arguments other than `model` will be "
            "replaced with keyword arguments in an upcoming version."),
    )
    def __init__(
        self,
        model: str,
        *,
        task: TaskOption = "auto",
        tokenizer: Optional[str] = None,
        tokenizer_mode: TokenizerMode = "auto",
        skip_tokenizer_init: bool = False,
@ -189,8 +177,6 @@ class LLM:
        hf_token: Optional[Union[bool, str]] = None,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
        # After positional args are removed, move this right below `model`
        task: TaskOption = "auto",
        override_pooler_config: Optional[PoolerConfig] = None,
        compilation_config: Optional[Union[int, dict[str, Any]]] = None,
        **kwargs,
@ -5,6 +5,7 @@ import atexit
import gc
import importlib
import inspect
import json
import multiprocessing
import os
import signal
@ -16,8 +17,7 @@ from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from functools import partial
from http import HTTPStatus
from json import JSONDecodeError
from typing import Annotated, Optional
from typing import Annotated, Any, Optional

import prometheus_client
import regex as re
@ -26,6 +26,8 @@ from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from prometheus_client import make_asgi_app
from prometheus_fastapi_instrumentator import Instrumentator
from starlette.concurrency import iterate_in_threadpool
from starlette.datastructures import State
from starlette.routing import Mount
@ -97,6 +99,7 @@ from vllm.transformers_utils.tokenizer import MistralTokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (Device, FlexibleArgumentParser, get_open_zmq_ipc_path,
                        is_valid_ipv6_address, set_ulimit)
from vllm.v1.metrics.prometheus import get_prometheus_registry
from vllm.version import __version__ as VLLM_VERSION

TIMEOUT_KEEP_ALIVE = 5  # seconds
@ -142,14 +145,17 @@ async def lifespan(app: FastAPI):

@asynccontextmanager
async def build_async_engine_client(
        args: Namespace) -> AsyncIterator[EngineClient]:
        args: Namespace,
        client_config: Optional[dict[str, Any]] = None,
) -> AsyncIterator[EngineClient]:

    # Context manager to handle engine_client lifecycle
    # Ensures everything is shutdown and cleaned up on error/exit
    engine_args = AsyncEngineArgs.from_cli_args(args)

    async with build_async_engine_client_from_engine_args(
            engine_args, args.disable_frontend_multiprocessing) as engine:
            engine_args, args.disable_frontend_multiprocessing,
            client_config) as engine:
        yield engine

@ -157,6 +163,7 @@ async def build_async_engine_client(
async def build_async_engine_client_from_engine_args(
    engine_args: AsyncEngineArgs,
    disable_frontend_multiprocessing: bool = False,
    client_config: Optional[dict[str, Any]] = None,
) -> AsyncIterator[EngineClient]:
    """
    Create EngineClient, either:
@ -179,12 +186,16 @@ async def build_async_engine_client_from_engine_args(

        from vllm.v1.engine.async_llm import AsyncLLM
        async_llm: Optional[AsyncLLM] = None
        client_index = client_config.pop(
            "client_index") if client_config else 0
        try:
            async_llm = AsyncLLM.from_vllm_config(
                vllm_config=vllm_config,
                usage_context=usage_context,
                disable_log_requests=engine_args.disable_log_requests,
                disable_log_stats=engine_args.disable_log_stats)
                disable_log_stats=engine_args.disable_log_stats,
                client_addresses=client_config,
                client_index=client_index)

            # Don't keep the dummy data in memory
            await async_llm.reset_mm_cache()
@ -318,22 +329,9 @@ class PrometheusResponse(Response):


def mount_metrics(app: FastAPI):
    # Lazy import for prometheus multiprocessing.
    # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
    # before prometheus_client is imported.
    # See https://prometheus.github.io/client_python/multiprocess/
    from prometheus_client import (REGISTRY, CollectorRegistry, make_asgi_app,
                                   multiprocess)
    from prometheus_fastapi_instrumentator import Instrumentator
    """Mount prometheus metrics to a FastAPI app."""

    registry = REGISTRY

    prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
    if prometheus_multiproc_dir_path is not None:
        logger.debug("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR",
                     prometheus_multiproc_dir_path)
        registry = CollectorRegistry()
        multiprocess.MultiProcessCollector(registry)
    registry = get_prometheus_registry()

    # `response_class=PrometheusResponse` is needed to return an HTTP response
    # with header "Content-Type: text/plain; version=0.0.4; charset=utf-8"
@ -932,7 +930,7 @@ async def invocations(raw_request: Request):
    """
    try:
        body = await raw_request.json()
    except JSONDecodeError as e:
    except json.JSONDecodeError as e:
        raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value,
                            detail=f"JSON decode error: {e}") from e
@ -1005,6 +1003,18 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
        return Response(status_code=200, content=response)


def load_log_config(log_config_file: Optional[str]) -> Optional[dict]:
    if not log_config_file:
        return None
    try:
        with open(log_config_file) as f:
            return json.load(f)
    except Exception as e:
        logger.warning("Failed to load log config from file %s: error %s",
                       log_config_file, e)
        return None


def build_app(args: Namespace) -> FastAPI:
    if args.disable_fastapi_docs:
        app = FastAPI(openapi_url=None,
@ -1256,16 +1266,10 @@ def create_server_socket(addr: tuple[str, int]) -> socket.socket:
    return sock


async def run_server(args, **uvicorn_kwargs) -> None:
    logger.info("vLLM API server version %s", VLLM_VERSION)
    log_non_default_args(args)

    if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
        ToolParserManager.import_tool_parser(args.tool_parser_plugin)

def validate_api_server_args(args):
    valid_tool_parses = ToolParserManager.tool_parsers.keys()
    if args.enable_auto_tool_choice \
        and args.tool_call_parser not in valid_tool_parses:
            and args.tool_call_parser not in valid_tool_parses:
        raise KeyError(f"invalid tool call parser: {args.tool_call_parser} "
                       f"(chose from {{ {','.join(valid_tool_parses)} }})")

@ -1276,6 +1280,19 @@ async def run_server(args, **uvicorn_kwargs) -> None:
            f"invalid reasoning parser: {args.reasoning_parser} "
            f"(chose from {{ {','.join(valid_reasoning_parses)} }})")


def setup_server(args):
    """Validate API server args, set up signal handler, create socket
    ready to serve."""

    logger.info("vLLM API server version %s", VLLM_VERSION)
    log_non_default_args(args)

    if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
        ToolParserManager.import_tool_parser(args.tool_parser_plugin)

    validate_api_server_args(args)

    # workaround to make sure that we bind the port before the engine is set up.
    # This avoids race conditions with ray.
    # see https://github.com/vllm-project/vllm/issues/8204
@ -1292,22 +1309,46 @@ async def run_server(args, **uvicorn_kwargs) -> None:

    signal.signal(signal.SIGTERM, signal_handler)

    async with build_async_engine_client(args) as engine_client:
    addr, port = sock_addr
    is_ssl = args.ssl_keyfile and args.ssl_certfile
    host_part = f"[{addr}]" if is_valid_ipv6_address(
        addr) else addr or "0.0.0.0"
    listen_address = f"http{'s' if is_ssl else ''}://{host_part}:{port}"

    return listen_address, sock


async def run_server(args, **uvicorn_kwargs) -> None:
    """Run a single-worker API server."""
    listen_address, sock = setup_server(args)
    await run_server_worker(listen_address, sock, args, **uvicorn_kwargs)


async def run_server_worker(listen_address,
                            sock,
                            args,
                            client_config=None,
                            **uvicorn_kwargs) -> None:
    """Run a single API server worker."""

    if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
        ToolParserManager.import_tool_parser(args.tool_parser_plugin)

    server_index = client_config.get("client_index", 0) if client_config else 0

    # Load logging config for uvicorn if specified
    log_config = load_log_config(args.log_config_file)
    if log_config is not None:
        uvicorn_kwargs['log_config'] = log_config

    async with build_async_engine_client(args, client_config) as engine_client:
        app = build_app(args)

        vllm_config = await engine_client.get_vllm_config()
        await init_app_state(engine_client, vllm_config, app.state, args)

        def _listen_addr(a: str) -> str:
            if is_valid_ipv6_address(a):
                return '[' + a + ']'
            return a or "0.0.0.0"

        is_ssl = args.ssl_keyfile and args.ssl_certfile
        logger.info("Starting vLLM API server on http%s://%s:%d",
                    "s" if is_ssl else "", _listen_addr(sock_addr[0]),
                    sock_addr[1])

        logger.info("Starting vLLM API server %d on %s", server_index,
                    listen_address)
        shutdown_task = await serve_http(
            app,
            sock=sock,
@ -11,6 +11,7 @@ import ssl
from collections.abc import Sequence
from typing import Optional, Union, get_args

import vllm.envs as envs
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
                                         validate_chat_template)
@ -243,6 +244,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
        " into OpenAI API format, the name register in this plugin can be used "
        "in ``--tool-call-parser``.")

    parser.add_argument(
        "--log-config-file",
        type=str,
        default=envs.VLLM_LOGGING_CONFIG_PATH,
        help="Path to logging config JSON file for both vllm and uvicorn",
    )

    parser = AsyncEngineArgs.add_cli_args(parser)

    parser.add_argument('--max-log-len',
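For illustration, a logging config that `--log-config-file` could point at; this is a standard `logging.config.dictConfig`-style document serialized as JSON, and the handler/formatter names are assumptions rather than anything vLLM prescribes.

```python
import json

# Minimal dictConfig-style logging configuration (illustrative names).
log_config = {
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": {
        "default": {"format": "%(asctime)s %(levelname)s %(name)s: %(message)s"}
    },
    "handlers": {
        "console": {"class": "logging.StreamHandler", "formatter": "default"}
    },
    "root": {"level": "INFO", "handlers": ["console"]},
}

with open("vllm_log_config.json", "w") as f:
    json.dump(log_config, f, indent=2)
# Then pass it to the server, e.g.: vllm serve ... --log-config-file vllm_log_config.json
```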
@ -1481,6 +1481,10 @@ class TranscriptionStreamResponse(OpenAIBaseModel):
    usage: Optional[UsageInfo] = Field(default=None)


BatchRequestInputBody = Union[ChatCompletionRequest, EmbeddingRequest,
                              ScoreRequest, RerankRequest]


class BatchRequestInput(OpenAIBaseModel):
    """
    The per-line object of the batch input file.
@ -1501,21 +1505,22 @@ class BatchRequestInput(OpenAIBaseModel):
    url: str

    # The parameters of the request.
    body: Union[ChatCompletionRequest, EmbeddingRequest, ScoreRequest]
    body: BatchRequestInputBody

    @field_validator('body', mode='plain')
    @classmethod
    def check_type_for_url(cls, value: Any, info: ValidationInfo):
        # Use url to disambiguate models
        url = info.data['url']
        url: str = info.data["url"]
        if url == "/v1/chat/completions":
            return ChatCompletionRequest.model_validate(value)
        if url == "/v1/embeddings":
            return TypeAdapter(EmbeddingRequest).validate_python(value)
        if url == "/v1/score":
        if url.endswith("/score"):
            return ScoreRequest.model_validate(value)
        return TypeAdapter(Union[ChatCompletionRequest, EmbeddingRequest,
                                 ScoreRequest]).validate_python(value)
        if url.endswith("/rerank"):
            return RerankRequest.model_validate(value)
        return TypeAdapter(BatchRequestInputBody).validate_python(value)


class BatchResponseData(OpenAIBaseModel):
@ -1527,7 +1532,7 @@ class BatchResponseData(OpenAIBaseModel):

    # The body of the response.
    body: Optional[Union[ChatCompletionResponse, EmbeddingResponse,
                         ScoreResponse]] = None
                         ScoreResponse, RerankResponse]] = None


class BatchRequestOutput(OpenAIBaseModel):
@ -1558,6 +1563,11 @@ class TokenizeCompletionRequest(OpenAIBaseModel):
        "If true (the default), special tokens (e.g. BOS) will be added to "
        "the prompt."),
    )
    return_token_strs: Optional[bool] = Field(
        default=False,
        description=("If true, also return the token strings "
                     "corresponding to the token ids."),
    )


class TokenizeChatRequest(OpenAIBaseModel):
@ -1571,6 +1581,11 @@ class TokenizeChatRequest(OpenAIBaseModel):
        "This is a parameter used by chat template in tokenizer config of the "
        "model."),
    )
    return_token_strs: Optional[bool] = Field(
        default=False,
        description=("If true, also return the token strings "
                     "corresponding to the token ids."),
    )
    continue_final_message: bool = Field(
        default=False,
        description=
@ -1628,6 +1643,7 @@ class TokenizeResponse(OpenAIBaseModel):
    count: int
    max_model_len: int
    tokens: list[int]
    token_strs: Optional[list[str]] = None


class DetokenizeRequest(OpenAIBaseModel):
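A hypothetical client call against the tokenization endpoint showing the new `return_token_strs` flag; the server URL and model name are assumptions for the example.

```python
import requests

# Ask the server to return token strings alongside the token ids.
resp = requests.post(
    "http://localhost:8000/tokenize",
    json={
        "model": "meta-llama/Llama-3.1-8B-Instruct",
        "prompt": "Hello world",
        "return_token_strs": True,
    },
)
# Expected keys in the response: tokens, token_strs, count, max_model_len.
print(resp.json())
```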
@ -21,7 +21,7 @@ from vllm.entrypoints.openai.protocol import (BatchRequestInput,
                                              BatchResponseData,
                                              ChatCompletionResponse,
                                              EmbeddingResponse, ErrorResponse,
                                              ScoreResponse)
                                              RerankResponse, ScoreResponse)
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
@ -274,8 +274,11 @@ async def run_request(serving_engine_func: Callable,
                      tracker: BatchProgressTracker) -> BatchRequestOutput:
    response = await serving_engine_func(request.body)

    if isinstance(response,
                  (ChatCompletionResponse, EmbeddingResponse, ScoreResponse)):
    if isinstance(
            response,
            (ChatCompletionResponse, EmbeddingResponse, ScoreResponse,
             RerankResponse),
    ):
        batch_output = BatchRequestOutput(
            id=f"vllm-{random_uuid()}",
            custom_id=request.custom_id,
@ -397,7 +400,7 @@ async def main(args):
            response_futures.append(
                run_request(embed_handler_fn, request, tracker))
            tracker.submitted()
        elif request.url == "/v1/score":
        elif request.url.endswith("/score"):
            score_handler_fn = openai_serving_scores.create_score if \
                openai_serving_scores is not None else None
            if score_handler_fn is None:
@ -411,13 +414,29 @@ async def main(args):
            response_futures.append(
                run_request(score_handler_fn, request, tracker))
            tracker.submitted()
        elif request.url.endswith("/rerank"):
            rerank_handler_fn = openai_serving_scores.do_rerank if \
                openai_serving_scores is not None else None
            if rerank_handler_fn is None:
                response_futures.append(
                    make_async_error_request_output(
                        request,
                        error_msg="The model does not support Rerank API",
                    ))
                continue

            response_futures.append(
                run_request(rerank_handler_fn, request, tracker))
            tracker.submitted()
        else:
            response_futures.append(
                make_async_error_request_output(
                    request,
                    error_msg=
                    "Only /v1/chat/completions, /v1/embeddings, and /v1/score "
                    "are supported in the batch endpoint.",
                    error_msg=f"URL {request.url} was used. "
                    "Supported endpoints: /v1/chat/completions, /v1/embeddings,"
                    " /score, /rerank ."
                    "See vllm/entrypoints/openai/api_server.py for supported "
                    "score/rerank versions.",
                ))

    with tracker.pbar():
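As a sketch, here is a batch-input JSONL line that the suffix-based dispatch above would route to the rerank handler; the custom_id, model name, and field values are illustrative only.

```python
import json

# One line of the batch input file targeting the rerank endpoint.
line = {
    "custom_id": "rerank-request-1",
    "method": "POST",
    "url": "/v1/rerank",
    "body": {
        "model": "BAAI/bge-reranker-base",
        "query": "What is the capital of France?",
        "documents": [
            "Paris is the capital of France.",
            "Berlin is in Germany.",
        ],
    },
}
print(json.dumps(line))
```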
@ -988,7 +988,8 @@ class OpenAIServingChat(OpenAIServing):
                tool_calls=[
                    tool_call_class(function=FunctionCall(
                        name=tool_call.name,
                        arguments=json.dumps(tool_call.parameters)))
                        arguments=json.dumps(tool_call.parameters,
                                             ensure_ascii=False)))
                    for tool_call in tool_calls
                ])

@ -110,7 +110,12 @@ class OpenAIServingTokenization(OpenAIServing):
                    dict) and "prompt_token_ids" in engine_prompt:
                input_ids.extend(engine_prompt["prompt_token_ids"])

        token_strs = None
        if request.return_token_strs:
            token_strs = tokenizer.convert_ids_to_tokens(input_ids)

        return TokenizeResponse(tokens=input_ids,
                                token_strs=token_strs,
                                count=len(input_ids),
                                max_model_len=self.max_model_len)

@ -278,7 +278,9 @@ class OpenAIServingTranscription(OpenAIServing):

        result_generator: Optional[AsyncGenerator[RequestOutput, None]] = None
        try:
            # TODO(rob): subtract len of tokenized prompt.
            # Unlike most decoder-only models, whisper generation length is not
            # constrained by the size of the input audio, which is mapped to a
            # fixed-size log-mel-spectogram.
            default_max_tokens = self.model_config.max_model_len
            sampling_params = request.to_sampling_params(
                default_max_tokens, self.default_sampling_params)

@ -7,6 +7,7 @@ from typing import Any, Union
import regex as re
from transformers import PreTrainedTokenizerBase

import vllm.envs as envs
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              DeltaFunctionCall, DeltaMessage,
                                              DeltaToolCall,
@ -64,7 +65,19 @@ class Llama4PythonicToolParser(ToolParser):
        if model_output.startswith("<|python_start|>"):
            model_output = model_output[len("<|python_start|>"):]
            model_output = model_output.replace("<|python_end|>", "")
        if not (self.TOOL_CALL_REGEX.match(model_output)):

        is_tool_call_pattern = False
        try:
            is_tool_call_pattern = self.TOOL_CALL_REGEX.match(
                model_output,
                timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS) is not None
        except TimeoutError:
            logger.warning(
                "Regex timeout occurred when matching tool call pattern.")
            logger.debug("Regex timeout occurred when matching user input: %s",
                         model_output)

        if not is_tool_call_pattern:
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

@ -68,8 +68,8 @@ class Phi4MiniJsonToolParser(ToolParser):
                     len(function_call_arr))
        except json.JSONDecodeError as e:
            logger.error(
                "Failed to parse function calls from model output: %s. "
                "Error: %s", model_output, str(e))
                "Failed to parse function calls from model output. "
                "Error: %s", str(e))

        tool_calls: list[ToolCall] = [
            ToolCall(

@ -8,6 +8,7 @@ from typing import Any, Union
import regex as re
from transformers import PreTrainedTokenizerBase

import vllm.envs as envs
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              DeltaFunctionCall, DeltaMessage,
                                              DeltaToolCall,
@ -61,8 +62,18 @@ class PythonicToolParser(ToolParser):
        """
        Extract the tool calls from a complete model response.
        """
        is_tool_call_pattern = False
        try:
            is_tool_call_pattern = self.TOOL_CALL_REGEX.match(
                model_output,
                timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS) is not None
        except TimeoutError:
            logger.warning(
                "Regex timeout occurred when matching tool call pattern.")
            logger.debug("Regex timeout occurred when matching user input: %s",
                         model_output)

        if not (self.TOOL_CALL_REGEX.match(model_output)):
        if not is_tool_call_pattern:
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)
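A minimal sketch of the timeout guard the pythonic parsers now use; it relies on the third-party `regex` package's `timeout=` support (the package is imported as `re` in these files), and the pattern and input below are illustrative only.

```python
import regex as re

# Illustrative stand-in for the parser's tool-call pattern.
TOOL_CALL_REGEX = re.compile(r"\[([a-zA-Z_][\w]*\(.*\),?\s*)+\]")

def looks_like_tool_call(model_output: str, timeout_s: int = 1) -> bool:
    try:
        # The regex package aborts matching after timeout_s seconds.
        return TOOL_CALL_REGEX.match(model_output, timeout=timeout_s) is not None
    except TimeoutError:
        # Same behaviour as the parsers: treat a timeout as "not a tool call".
        return False

print(looks_like_tool_call('[get_weather(city="Paris")]'))
```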
vllm/envs.py
@ -15,6 +15,7 @@ if TYPE_CHECKING:
    VLLM_NCCL_SO_PATH: Optional[str] = None
    LD_LIBRARY_PATH: Optional[str] = None
    VLLM_USE_TRITON_FLASH_ATTN: bool = False
    VLLM_V1_USE_PREFILL_DECODE_ATTENTION: bool = False
    VLLM_FLASH_ATTN_VERSION: Optional[int] = None
    LOCAL_RANK: int = 0
    CUDA_VISIBLE_DEVICES: Optional[str] = None
@ -118,6 +119,7 @@ if TYPE_CHECKING:
    VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557
    VLLM_ALL2ALL_BACKEND: str = "naive"
    VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840
    VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1


def get_default_cache_root():
@ -142,10 +144,10 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:

def get_vllm_port() -> Optional[int]:
    """Get the port from VLLM_PORT environment variable.

    Returns:
        The port number as an integer if VLLM_PORT is set, None otherwise.

    Raises:
        ValueError: If VLLM_PORT is a URI, suggest k8s service discovery issue.
    """
@ -158,17 +160,13 @@ def get_vllm_port() -> Optional[int]:
        return int(port)
    except ValueError as err:
        from urllib.parse import urlparse
        try:
            parsed = urlparse(port)
            if parsed.scheme:
                raise ValueError(
                    f"VLLM_PORT '{port}' appears to be a URI. "
                    "This may be caused by a Kubernetes service discovery issue"
                    "check the warning in: https://docs.vllm.ai/en/stable/usage/env_vars.html"
                )
        except Exception:
            pass

        parsed = urlparse(port)
        if parsed.scheme:
            raise ValueError(
                f"VLLM_PORT '{port}' appears to be a URI. "
                "This may be caused by a Kubernetes service discovery issue,"
                "check the warning in: https://docs.vllm.ai/en/stable/serving/env_vars.html"
            ) from None
        raise ValueError(
            f"VLLM_PORT '{port}' must be a valid integer") from err
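A small sketch of the failure mode this rewrite guards against: under Kubernetes service discovery, `VLLM_PORT` can arrive as a URI such as `tcp://10.0.0.1:8000` instead of a bare port number. The values below are illustrative.

```python
from urllib.parse import urlparse

for value in ("8000", "tcp://10.0.0.1:8000"):
    try:
        print(int(value))                 # normal case: a plain integer port
    except ValueError:
        parsed = urlparse(value)
        if parsed.scheme:                 # a URI => likely k8s service env var
            print(f"VLLM_PORT {value!r} looks like a URI, not a port")
```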
@ -290,6 +288,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
    lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in
             ("true", "1")),

    # Use separate prefill and decode kernels for V1 attention instead of
    # the unified triton kernel.
    "VLLM_V1_USE_PREFILL_DECODE_ATTENTION":
    lambda:
    (os.getenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "False").lower() in
     ("true", "1")),

    # Force vllm to use a specific flash-attention version (2 or 3), only valid
    # when using the flash-attention backend.
    "VLLM_FLASH_ATTN_VERSION":
@ -300,9 +305,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
    lambda: bool(
        os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"),

    # Internal flag to enable/disable Inductor standalone compile
    "VLLM_TEST_STANDALONE_COMPILE":
    lambda: os.environ.get("VLLM_TEST_STANDALONE_COMPILE", "0") != "0",
    # Feature flag to enable/disable Inductor standalone compile.
    # In torch <= 2.7 we ignore this flag; in torch >= 2.8 this is
    # enabled by default.
    "VLLM_USE_STANDALONE_COMPILE":
    lambda: os.environ.get("VLLM_USE_STANDALONE_COMPILE", "1") == "1",

    # local rank of the process in the distributed setting, used to determine
    # the GPU device id
@ -323,8 +330,8 @@ environment_variables: dict[str, Callable[[], Any]] = {

    # Whether to log responses from API Server for debugging
    "VLLM_DEBUG_LOG_API_SERVER_RESPONSE":
    lambda: os.environ.get("VLLM_DEBUG_LOG_API_SERVER_RESPONSE", "False").
    lower() == "true",
    lambda: os.environ.get("VLLM_DEBUG_LOG_API_SERVER_RESPONSE", "False"
                           ).lower() == "true",

    # S3 access information, used for tensorizer to load model from S3
    "S3_ACCESS_KEY_ID":
@ -822,6 +829,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
    # This is used to prevent the kernel from running out of memory.
    "VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE":
    lambda: int(os.getenv("VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE", "163840")),

    # Regex timeout for use by the vLLM tool parsing plugins.
    "VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS":
    lambda: int(os.getenv("VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS", "1")),
}

# --8<-- [end:env-vars-definition]
@ -884,7 +895,7 @@ def compute_hash() -> str:
        "VLLM_USE_TRITON_AWQ",
        "VLLM_DP_RANK",
        "VLLM_DP_SIZE",
        "VLLM_TEST_STANDALONE_COMPILE",
        "VLLM_USE_STANDALONE_COMPILE",
    ]
    for key in environment_variables_to_hash:
        if key in environment_variables:

@ -47,8 +47,12 @@ class DPMetadata:
        return num_tokens_tensor

    @staticmethod
    def make(parallel_config: ParallelConfig, attn_metadata: Any,
             num_tokens: int) -> "DPMetadata":
    def make(
        parallel_config: ParallelConfig,
        attn_metadata: Any,
        num_tokens: int,
        num_tokens_across_dp: Optional[torch.Tensor] = None
    ) -> "DPMetadata":

        assert parallel_config.data_parallel_size > 1
        dp_size = parallel_config.data_parallel_size
@ -62,10 +66,15 @@ class DPMetadata:
            # for v1 attention backends or no attn_metadata
            batchsize = num_tokens

        num_tokens_tensor = DPMetadata.num_tokens_across_dp(
            batchsize, dp_size, dp_rank)
        max_tokens_across_dp_cpu = torch.max(num_tokens_tensor)
        cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)
        # If num_tokens_across_dp is None, it will be computed by all_reduce
        # Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize
        assert (num_tokens_across_dp is None
                or num_tokens_across_dp[dp_rank] == batchsize)
        if num_tokens_across_dp is None:
            num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
                batchsize, dp_size, dp_rank)
        max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp)
        cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_across_dp, dim=0)
        return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu)


@ -101,7 +110,8 @@ def get_forward_context() -> ForwardContext:
def set_forward_context(attn_metadata: Any,
                        vllm_config: VllmConfig,
                        virtual_engine: int = 0,
                        num_tokens: int = 0):
                        num_tokens: Optional[int] = None,
                        num_tokens_across_dp: Optional[torch.Tensor] = None):
    """A context manager that stores the current forward context,
    can be attention metadata, etc.
    Here we can inject common logic for every model forward pass.
@ -111,9 +121,11 @@ def set_forward_context(attn_metadata: Any,
    if need_to_track_batchsize:
        forward_start_time = time.perf_counter()
    dp_metadata: Optional[DPMetadata] = None
    if vllm_config.parallel_config.data_parallel_size > 1:
    if vllm_config.parallel_config.data_parallel_size > 1 and (
            attn_metadata is not None or num_tokens is not None):
        dp_metadata = DPMetadata.make(vllm_config.parallel_config,
                                      attn_metadata, num_tokens)
                                      attn_metadata, num_tokens or 0,
                                      num_tokens_across_dp)

    global _forward_context
    prev_context = _forward_context
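A sketch of the new calling pattern: when the per-rank token counts are already known (for example from an earlier collective), they can be passed in so `DPMetadata.make` skips recomputing them. The values below are illustrative; only the tensor math mirrors what the updated method does.

```python
import torch

dp_size, dp_rank, batchsize = 2, 0, 8
# Precomputed token counts for each data-parallel rank (illustrative).
num_tokens_across_dp = torch.tensor([8, 6], dtype=torch.int32)

# The new assertion in DPMetadata.make(): this rank's entry must match the batch size.
assert num_tokens_across_dp[dp_rank] == batchsize

max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp)
cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_across_dp, dim=0)
print(max_tokens_across_dp_cpu.item(), cu_tokens_across_dp_cpu.tolist())
```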