diff --git a/.buildkite/nightly-benchmarks/README.md b/.buildkite/nightly-benchmarks/README.md index d3f5fc5cd4cee..72c52d5bb5e9b 100644 --- a/.buildkite/nightly-benchmarks/README.md +++ b/.buildkite/nightly-benchmarks/README.md @@ -113,7 +113,7 @@ WARNING: The benchmarking script will save json results by itself, so please do ### Visualizing the results -The `convert-results-json-to-markdown.py` helps you put the benchmarking results inside a markdown table, by formatting [descriptions.md](tests/descriptions.md) with real benchmarking results. +The `convert-results-json-to-markdown.py` helps you put the benchmarking results inside a markdown table, by formatting [descriptions.md](performance-benchmarks-descriptions.md) with real benchmarking results. You can find the result presented as a table inside the `buildkite/performance-benchmark` job page. If you do not see the table, please wait till the benchmark finish running. The json version of the table (together with the json version of the benchmark) will be also attached to the markdown file. diff --git a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh index 5dd53420dfdfa..6102431456210 100755 --- a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh +++ b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh @@ -2,14 +2,45 @@ set -xu + +remove_docker_container() { + docker rm -f tpu-test || true; + docker rm -f vllm-tpu || true; +} + +trap remove_docker_container EXIT + +# Remove the container that might not be cleaned up in the previous run. +remove_docker_container + # Build the docker image. docker build -f docker/Dockerfile.tpu -t vllm-tpu . # Set up cleanup. -remove_docker_container() { docker rm -f tpu-test || true; } -trap remove_docker_container EXIT -# Remove the container that might not be cleaned up in the previous run. 
-remove_docker_container +cleanup_docker() { + # Get Docker's root directory + docker_root=$(docker info -f '{{.DockerRootDir}}') + if [ -z "$docker_root" ]; then + echo "Failed to determine Docker root directory." + exit 1 + fi + echo "Docker root directory: $docker_root" + # Check disk usage of the filesystem where Docker's root directory is located + disk_usage=$(df "$docker_root" | tail -1 | awk '{print $5}' | sed 's/%//') + # Define the threshold + threshold=70 + if [ "$disk_usage" -gt "$threshold" ]; then + echo "Disk usage is above $threshold%. Cleaning up Docker images and volumes..." + # Remove dangling images (those that are not tagged and not used by any container) + docker image prune -f + # Remove unused volumes / force the system prune for old images as well. + docker volume prune -f && docker system prune --force --filter "until=72h" --all + echo "Docker images and volumes cleanup completed." + else + echo "Disk usage is below $threshold%. No cleanup needed." + fi +} +cleanup_docker # For HF_TOKEN. 
source /etc/environment diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 4e7bea25e1717..bff2f69c17ba7 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -199,8 +199,9 @@ steps: - tests/test_sequence - tests/test_config - tests/test_logger + - tests/test_vllm_port commands: - - pytest -v -s engine test_sequence.py test_config.py test_logger.py + - pytest -v -s engine test_sequence.py test_config.py test_logger.py test_vllm_port.py # OOM in the CI unless we run this separately - pytest -v -s tokenization @@ -617,9 +618,11 @@ steps: - vllm/worker/model_runner.py - entrypoints/llm/test_collective_rpc.py - tests/v1/test_async_llm_dp.py + - tests/v1/entrypoints/openai/test_multi_api_servers.py - vllm/v1/engine/ commands: - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py + - DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py - pytest -v -s entrypoints/llm/test_collective_rpc.py - pytest -v -s ./compile/test_basic_correctness.py - pytest -v -s ./compile/test_wrapper.py diff --git a/SECURITY.md b/SECURITY.md index 47196a1f1221e..6053cfb41f35b 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -8,4 +8,6 @@ Please report security issues privately using [the vulnerability submission form --- +Please see the [Security Guide in the vLLM documentation](https://docs.vllm.ai/en/latest/usage/security.html) for more information on vLLM's security assumptions and recommendations. + Please see [PyTorch's Security Policy](https://github.com/pytorch/pytorch/blob/main/SECURITY.md) for more information and recommendations on how to securely interact with models. diff --git a/benchmarks/README.md b/benchmarks/README.md index ecab570bb31c4..6f9fbb91cbd91 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -64,6 +64,12 @@ become available. 
✅ lmms-lab/LLaVA-OneVision-Data, Aeala/ShareGPT_Vicuna_unfiltered + + Custom + ✅ + ✅ + Local file: data.jsonl + @@ -124,6 +130,38 @@ P99 ITL (ms): 8.39 ================================================== ``` +### Custom Dataset +If the dataset you want to benchmark is not supported yet in vLLM, even then you can benchmark on it using `CustomDataset`. Your data needs to be in `.jsonl` format and needs to have "prompt" field per entry, e.g., data.jsonl + +``` +{"prompt": "What is the capital of India?"} +{"prompt": "What is the capital of Iran?"} +{"prompt": "What is the capital of China?"} +``` + +```bash +# start server +VLLM_USE_V1=1 vllm serve meta-llama/Llama-3.1-8B-Instruct --disable-log-requests +``` + +```bash +# run benchmarking script +python3 benchmarks/benchmark_serving.py --port 9001 --save-result --save-detailed \ + --backend vllm \ + --model meta-llama/Llama-3.1-8B-Instruct \ + --endpoint /v1/completions \ + --dataset-name custom \ + --dataset-path \ + --custom-skip-chat-template \ + --num-prompts 80 \ + --max-concurrency 1 \ + --temperature=0.3 \ + --top-p=0.75 \ + --result-dir "./log/" +``` + +You can skip applying chat template if your data already has it by using `--custom-skip-chat-template`. 
+ ### VisionArena Benchmark for Vision Language Models ```bash @@ -146,9 +184,9 @@ python3 vllm/benchmarks/benchmark_serving.py \ ``` bash VLLM_USE_V1=1 vllm serve meta-llama/Meta-Llama-3-8B-Instruct \ - --ngram_prompt_lookup_min 2 \ - --ngram-prompt-lookup-max 5 \ - --speculative_config '{"model": "[ngram]", "num_speculative_tokens": 5} + --speculative-config $'{"method": "ngram", + "num_speculative_tokens": 5, "prompt_lookup_max": 5, + "prompt_lookup_min": 2}' ``` ``` bash @@ -203,6 +241,16 @@ python3 vllm/benchmarks/benchmark_serving.py \ --seed 42 ``` +**`philschmid/mt-bench`** + +``` bash +python3 vllm/benchmarks/benchmark_serving.py \ + --model Qwen/QwQ-32B \ + --dataset-name hf \ + --dataset-path philschmid/mt-bench \ + --num-prompts 80 +``` + ### Running With Sampling Parameters When using OpenAI-compatible backends such as `vllm`, optional sampling @@ -273,9 +321,9 @@ python3 vllm/benchmarks/benchmark_throughput.py \ --output-len=100 \ --num-prompts=2048 \ --async-engine \ - --ngram_prompt_lookup_min=2 \ - --ngram-prompt-lookup-max=5 \ - --speculative_config '{"model": "[ngram]", "num_speculative_tokens": 5} + --speculative-config $'{"method": "ngram", + "num_speculative_tokens": 5, "prompt_lookup_max": 5, + "prompt_lookup_min": 2}' ``` ``` diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index 88616e1108c52..85e6eda7f36fd 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -324,7 +324,7 @@ async def async_request_openai_completions( most_recent_timestamp = timestamp generated_text += text or "" - elif usage := data.get("usage"): + if usage := data.get("usage"): output.output_tokens = usage.get("completion_tokens") if first_chunk_received: output.success = True @@ -611,6 +611,7 @@ ASYNC_REQUEST_FUNCS = { "tensorrt-llm": async_request_trt_llm, "scalellm": async_request_openai_completions, "sglang": async_request_openai_completions, + "llama.cpp": async_request_openai_completions, } 
OPENAI_COMPATIBLE_BACKENDS = [ diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py index 5513a5f78f1ce..d86bf045ea47e 100644 --- a/benchmarks/benchmark_dataset.py +++ b/benchmarks/benchmark_dataset.py @@ -9,9 +9,6 @@ generation. Supported dataset types include: - BurstGPT - HuggingFace - VisionArena - -TODO: Implement CustomDataset to parse a JSON file and convert its contents into -SampleRequest instances, similar to the approach used in ShareGPT. """ import base64 @@ -442,6 +439,97 @@ class ShareGPTDataset(BenchmarkDataset): return samples +# ----------------------------------------------------------------------------- +# Custom Dataset Implementation +# ----------------------------------------------------------------------------- + + +class CustomDataset(BenchmarkDataset): + """ + Implements the Custom dataset. Loads data from a JSONL file and generates + sample requests based on conversation turns. E.g., + ``` + {"prompt": "What is the capital of India?"} + {"prompt": "What is the capital of Iran?"} + {"prompt": "What is the capital of China?"} + ``` + """ + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.load_data() + + def load_data(self) -> None: + if self.dataset_path is None: + raise ValueError("dataset_path must be provided for loading data.") + + # self.data will be a list of dictionaries + # e.g., [{"prompt": "What is the capital of India?"}, ...] + # This will be the standardized format which load_data() + # has to convert into depending on the filetype of dataset_path. 
+ # sample() will assume this standardized format of self.data + self.data = [] + + # Load the JSONL file + if self.dataset_path.endswith(".jsonl"): + jsonl_data = pd.read_json(path_or_buf=self.dataset_path, lines=True) + + # check if the JSONL file has a 'prompt' column + if "prompt" not in jsonl_data.columns: + raise ValueError("JSONL file must contain a 'prompt' column.") + + # Convert each row to a dictionary and append to self.data + # This will convert the DataFrame to a list of dictionaries + # where each dictionary corresponds to a row in the DataFrame. + # This is the standardized format we want for self.data + for _, row in jsonl_data.iterrows(): + self.data.append(row.to_dict()) + else: + raise NotImplementedError( + "Only JSONL format is supported for CustomDataset." + ) + + random.seed(self.random_seed) + random.shuffle(self.data) + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + lora_path: Optional[str] = None, + max_loras: Optional[int] = None, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + skip_chat_template: bool = False, + **kwargs, + ) -> list: + sampled_requests = [] + for item in self.data: + if len(sampled_requests) >= num_requests: + break + prompt = item["prompt"] + + # apply template + if not skip_chat_template: + prompt = tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + add_generation_prompt=True, + tokenize=False, + ) + + prompt_len = len(tokenizer(prompt).input_ids) + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + ) + ) + self.maybe_oversample_requests(sampled_requests, num_requests) + + return sampled_requests + + # ----------------------------------------------------------------------------- # Sonnet Dataset Implementation # ----------------------------------------------------------------------------- diff --git a/benchmarks/benchmark_serving.py 
b/benchmarks/benchmark_serving.py index a887e7150dc78..6bd9f1b49c2ec 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -60,6 +60,7 @@ from benchmark_dataset import ( ASRDataset, BurstGPTDataset, ConversationDataset, + CustomDataset, HuggingFaceDataset, InstructCoderDataset, MTBenchDataset, @@ -627,7 +628,16 @@ def main(args: argparse.Namespace): "'--dataset-path' if required." ) - if args.dataset_name == "sonnet": + if args.dataset_name == "custom": + dataset = CustomDataset(dataset_path=args.dataset_path) + input_requests = dataset.sample( + num_requests=args.num_prompts, + tokenizer=tokenizer, + output_len=args.custom_output_len, + skip_chat_template=args.custom_skip_chat_template, + ) + + elif args.dataset_name == "sonnet": dataset = SonnetDataset(dataset_path=args.dataset_path) # For the "sonnet" dataset, formatting depends on the backend. if args.backend == "openai-chat": @@ -762,6 +772,10 @@ def main(args: argparse.Namespace): if "temperature" not in sampling_params: sampling_params["temperature"] = 0.0 # Default to greedy decoding. + if args.backend == "llama.cpp": + # Disable prompt caching in llama.cpp backend + sampling_params["cache_prompt"] = False + # Avoid GC processing "static" data - reduce pause times. 
gc.collect() gc.freeze() @@ -834,6 +848,8 @@ def main(args: argparse.Namespace): ]: if field in result_json: del result_json[field] + if field in benchmark_result: + del benchmark_result[field] # Save to file base_model_id = model_id.split("/")[-1] @@ -846,6 +862,7 @@ def main(args: argparse.Namespace): if args.result_filename: file_name = args.result_filename if args.result_dir: + os.makedirs(args.result_dir, exist_ok=True) file_name = os.path.join(args.result_dir, file_name) with open( file_name, mode="a+" if args.append_result else "w", encoding="utf-8" @@ -886,7 +903,7 @@ if __name__ == "__main__": "--dataset-name", type=str, default="sharegpt", - choices=["sharegpt", "burstgpt", "sonnet", "random", "hf"], + choices=["sharegpt", "burstgpt", "sonnet", "random", "hf", "custom"], help="Name of the dataset to benchmark on.", ) parser.add_argument( @@ -1056,6 +1073,19 @@ if __name__ == "__main__": ) # group for dataset specific arguments + custom_group = parser.add_argument_group("custom dataset options") + custom_group.add_argument( + "--custom-output-len", + type=int, + default=256, + help="Number of output tokens per request, used only for custom dataset.", + ) + custom_group.add_argument( + "--custom-skip-chat-template", + action="store_true", + help="Skip applying chat template to prompt, used only for custom dataset.", + ) + sonnet_group = parser.add_argument_group("sonnet dataset options") sonnet_group.add_argument( "--sonnet-input-len", diff --git a/benchmarks/kernels/bench_fp8_gemm.py b/benchmarks/kernels/bench_fp8_gemm.py new file mode 100644 index 0000000000000..36d03e40ef9a1 --- /dev/null +++ b/benchmarks/kernels/bench_fp8_gemm.py @@ -0,0 +1,222 @@ +# SPDX-License-Identifier: Apache-2.0 +import argparse +import copy +import itertools + +import torch +import triton +from weight_shapes import WEIGHT_SHAPES + +from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm +from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant + + 
+@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size"], + x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384], + x_log=False, + line_arg="provider", + line_vals=[ + "torch-bf16", + # "fp8-tensor-w-token-a", + "fp8-tensor-w-tensor-a", + "fp8-channel-w-token-a", + # "fp8-channel-w-tensor-a", + # "fp8-tensor-w-token-a-noquant", + "fp8-tensor-w-tensor-a-noquant", + "fp8-channel-w-token-a-noquant", + # "fp8-channel-w-tensor-a-noquant", + ], + line_names=[ + "torch-bf16", + # "fp8-tensor-w-token-a", + "fp8-tensor-w-tensor-a", + "fp8-channel-w-token-a", + # "fp8-channel-w-tensor-a", + # "fp8-tensor-w-token-a-noquant", + "fp8-tensor-w-tensor-a-noquant", + "fp8-channel-w-token-a-noquant", + # "fp8-channel-w-tensor-a-noquant", + ], + ylabel="TFLOP/s (larger is better)", + plot_name="BF16 vs FP8 GEMMs", + args={}, + ) +) +def benchmark(batch_size, provider, N, K): + M = batch_size + device = "cuda" + dtype = torch.bfloat16 + + # Create input tensors + a = torch.randn((M, K), device=device, dtype=dtype) + b = torch.randn((N, K), device=device, dtype=dtype) + + quantiles = [0.5, 0.2, 0.8] + + if "torch-bf16" in provider: + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: torch.nn.functional.linear(a, b), quantiles=quantiles + ) + + elif "fp8" in provider: + # Weights are always quantized ahead of time + if "noquant" in provider: + # For no quantization, we just measure the GEMM + if "tensor-w-token-a" in provider: + # Dynamic per-token quant for A, per-tensor quant for B + b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b) + assert scale_b_fp8.numel() == 1 + a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant( + a, use_per_token_if_dynamic=True + ) + + def run_quant(): + return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) + + elif "tensor-w-tensor-a" in provider: + # Static per-tensor quantization with fixed scales + # for both A and B + scale_a = torch.tensor([1.0], device=device, dtype=torch.float32) + scale_b = 
torch.tensor([1.0], device=device, dtype=torch.float32) + b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b) + assert scale_b_fp8.numel() == 1 + a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a) + + def run_quant(): + return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) + + elif "channel-w-token-a" in provider: + # Static per-channel quantization for weights, per-token + # quant for A + scale_b = torch.tensor((N,), device=device, dtype=torch.float32) + b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b) + scale_b_fp8 = scale_b_fp8.expand(N).contiguous() + assert scale_b_fp8.numel() == N + a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant( + a, use_per_token_if_dynamic=True + ) + + def run_quant(): + return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) + + elif "channel-w-tensor-a" in provider: + # Static per-channel quantization for weights, per-tensor + # quant for A + scale_a = torch.tensor([1.0], device=device, dtype=torch.float32) + scale_b = torch.tensor((N,), device=device, dtype=torch.float32) + b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b) + scale_b_fp8 = scale_b_fp8.expand(N).contiguous() + assert scale_b_fp8.numel() == N + a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a) + + def run_quant(): + return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) + + else: + # In these cases, we quantize the activations during the GEMM call + if "tensor-w-token-a" in provider: + # Dynamic per-token quant for A, per-tensor quant for B + b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b) + assert scale_b_fp8.numel() == 1 + + def run_quant(): + a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant( + a, use_per_token_if_dynamic=True + ) + return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) + + elif "tensor-w-tensor-a" in provider: + # Static per-tensor quantization with fixed scales + # for both A and B + scale_a = torch.tensor([1.0], device=device, dtype=torch.float32) + scale_b = torch.tensor([1.0], 
device=device, dtype=torch.float32) + b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b) + assert scale_b_fp8.numel() == 1 + + def run_quant(): + a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a) + return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) + + elif "channel-w-token-a" in provider: + # Static per-channel quantization for weights, per-token + # quant for A + scale_b = torch.tensor((N,), device=device, dtype=torch.float32) + b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b) + scale_b_fp8 = scale_b_fp8.expand(N).contiguous() + assert scale_b_fp8.numel() == N + + def run_quant(): + a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant( + a, use_per_token_if_dynamic=True + ) + return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) + + elif "channel-w-tensor-a" in provider: + # Static per-channel quantization for weights, per-tensor + # quant for A + scale_a = torch.tensor([1.0], device=device, dtype=torch.float32) + scale_b = torch.tensor((N,), device=device, dtype=torch.float32) + b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b) + scale_b_fp8 = scale_b_fp8.expand(N).contiguous() + assert scale_b_fp8.numel() == N + + def run_quant(): + a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a) + return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) + + b_fp8 = b_fp8.t() + + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: run_quant(), quantiles=quantiles + ) + + # Calculate TFLOP/s, two flops per multiply-add + tflops = lambda ms: (2 * M * N * K) * 1e-12 / (ms * 1e-3) + return tflops(ms), tflops(max_ms), tflops(min_ms) + + +def prepare_shapes(args): + KN_model_names = [] + models_tps = list(itertools.product(args.models, args.tp_sizes)) + for model, tp_size in models_tps: + assert model in WEIGHT_SHAPES + for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]): + KN[tp_split_dim] = KN[tp_split_dim] // tp_size + KN.append(model) + KN_model_names.append(KN) + return KN_model_names + + +if 
__name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--models", + nargs="+", + type=str, + default=["meta-llama/Llama-3.1-8B-Instruct"], + choices=[*WEIGHT_SHAPES.keys()], + help="List of models to benchmark", + ) + parser.add_argument( + "--tp-sizes", + nargs="+", + type=int, + default=[1], + help="List of tensor parallel sizes", + ) + args = parser.parse_args() + + KN_model_names = prepare_shapes(args) + for K, N, model_name in KN_model_names: + print(f"{model_name}, N={N} K={K}, BF16 vs FP8 GEMMs TFLOP/s:") + benchmark.run( + print_data=True, + show_plots=True, + save_path=f"bench_fp8_res_n{N}_k{K}", + N=N, + K=K, + ) + + print("Benchmark finished!") diff --git a/benchmarks/kernels/benchmark_rope.py b/benchmarks/kernels/benchmark_rope.py index 110d36db157fd..944024ca35725 100644 --- a/benchmarks/kernels/benchmark_rope.py +++ b/benchmarks/kernels/benchmark_rope.py @@ -22,7 +22,7 @@ def benchmark_rope_kernels_multi_lora( seed: int, device: str, max_position: int = 8192, - base: int = 10000, + base: float = 10000, ) -> None: current_platform.seed_everything(seed) torch.set_default_device(device) diff --git a/benchmarks/kernels/weight_shapes.py b/benchmarks/kernels/weight_shapes.py index 89b05d5882a38..afe159ddda6e8 100644 --- a/benchmarks/kernels/weight_shapes.py +++ b/benchmarks/kernels/weight_shapes.py @@ -48,4 +48,50 @@ WEIGHT_SHAPES = { ([16384, 106496], 1), ([53248, 16384], 0), ], + "meta-llama/Llama-3.1-8B-Instruct": [ + ([4096, 6144], 1), + ([4096, 4096], 0), + ([4096, 28672], 1), + ([14336, 4096], 0), + ], + "meta-llama/Llama-3.3-70B-Instruct": [ + ([8192, 10240], 1), + ([8192, 8192], 0), + ([8192, 57344], 1), + ([28672, 8192], 0), + ], + "mistralai/Mistral-Large-Instruct-2407": [ + ([12288, 14336], 1), + ([12288, 12288], 0), + ([12288, 57344], 1), + ([28672, 12288], 0), + ], + "Qwen/Qwen2.5-7B-Instruct": [ + ([3584, 4608], 1), + ([3584, 3584], 0), + ([3584, 37888], 1), + ([18944, 3584], 0), + ], + 
"Qwen/Qwen2.5-32B-Instruct": [ + ([5120, 7168], 1), + ([5120, 5120], 0), + ([5120, 55296], 1), + ([27648, 5120], 0), + ], + "Qwen/Qwen2.5-72B-Instruct": [ + ([8192, 10240], 1), + ([8192, 8192], 0), + ([8192, 59136], 1), + ([29568, 8192], 0), + ], + "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [ + ([2048, 3072], 1), + ([2048, 4096], 1), + ([2048, 2048], 0), + ([2048, 576], 0), + ([2048, 21888], 1), + ([10944, 2048], 0), + ([2048, 2816], 1), + ([1408, 2048], 0), + ], } diff --git a/csrc/rocm/skinny_gemms.cu b/csrc/rocm/skinny_gemms.cu index b3717892db784..e31aa0162628f 100644 --- a/csrc/rocm/skinny_gemms.cu +++ b/csrc/rocm/skinny_gemms.cu @@ -13,14 +13,34 @@ #include "dispatch_utils.h" #include "quantization/fp8/common.cuh" -#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx942__)) - #define __HIP__MI300_MI250__ +#if defined(__HIPCC__) && \ + (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) + #define __HIP__GFX9__ #endif -#if defined(__HIPCC__) && defined(__gfx942__) - #define __HIP__MI300__ +#if defined(__HIPCC__) && (defined(__gfx942__) || defined(__gfx950__)) + #define __HIP__MI3XX__ #endif +#if defined(__gfx950__) + #define LDS_SIZE 160 * 1024 +#else + #define LDS_SIZE 64 * 1024 +#endif + +int get_lds_size() { + static bool is_cached = false; + static int result; + if (is_cached == false) { + auto dprops = at::cuda::getCurrentDeviceProperties(); + std::string device_arch = dprops->gcnArchName; + size_t substring = device_arch.find("gfx95"); + result = (substring == std::string::npos ? 
64 * 1024 : 160 * 1024); + is_cached = true; + } + return result; +} + #if defined(NDEBUG) #undef NDEBUG #include @@ -267,7 +287,7 @@ torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b, V0 += (s.x + s.y); \ } -#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support +#if defined(__HIP__GFX9__) // TODO: Add NAVI support // This version targets cases where A[] fits LDS capacity template @@ -275,7 +295,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B, const scalar_t* __restrict__ A, scalar_t* C, const int _WvPrGrp, const int CuCount) { - #if defined(__HIP__MI300__) + constexpr int max_lds_len = LDS_SIZE / 2; + #if defined(__HIP__MI3XX__) constexpr bool use_mfma = (std::is_same_v); #else constexpr bool use_mfma = false; @@ -295,13 +316,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) }; //---------------------------------------------------- - // Reserving 64 KB of LDS to have 1 WG / CU + // Reserving 64/160 KB of LDS to have 1 WG / CU // Goal is to bring the activation matrix A to the LDS // and use it across the lifetime of the work group // TODO: When activation matrix is larger than 64 KB // then this is not goint to work! 
//---------------------------------------------------- - __shared__ scalar_t s[1024 * 32]; + __shared__ scalar_t s[max_lds_len]; //---------------------------------------------------- // Fetch the activation matrix to LDS @@ -312,11 +333,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // - Then the WG will move to another 8 K elements // TODO: Logic below will only work when K is multiple of 8 //---------------------------------------------------- - for (uint32_t k = 0; k < min(K * N, 32 * 1024); + for (uint32_t k = 0; k < min(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) { uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); - if (k_in >= min(K * N, 32 * 1024)) break; + if (k_in >= min(K * N, max_lds_len)) break; *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); } @@ -517,7 +538,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) m += CuCount * _WvPrGrp * YTILE; } } -#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support +#else // !defined(__HIP__GFX9__) TODO: Add NAVI support template __global__ void wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B, @@ -525,9 +546,9 @@ __global__ void wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B, const int _WvPrGrp, const int CuCount) { UNREACHABLE_CODE } -#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support +#endif // defined(__HIP__GFX9__) TODO: Add NAVI support -#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support +#if defined(__HIP__GFX9__) // TODO: Add NAVI support // This version targets cases where A[] marginally exceeds LDS capacity template @@ -535,7 +556,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) wvSplitK_hf_(const int K, const int M, const scalar_t* B, const scalar_t* __restrict__ A, scalar_t* C, const int _WvPrGrp, const int CuCount) { - #if defined(__HIP__MI300__) + constexpr int max_lds_len = LDS_SIZE / 2; + #if defined(__HIP__MI3XX__) constexpr bool use_mfma = (std::is_same_v); #else constexpr bool use_mfma = false; 
@@ -561,7 +583,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // TODO: When activation matrix is larger than 64 KB // then this is not goint to work! //---------------------------------------------------- - __shared__ scalar_t s[1024 * 32]; + __shared__ scalar_t s[max_lds_len]; //---------------------------------------------------- // Computation of columns that need to be committed to memory! @@ -598,11 +620,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // - Then the WG will move to another 8 K elements // TODO: Logic below will only work when K is multiple of 8 //---------------------------------------------------- - for (uint32_t k = 0; k < min(K * N, 32 * 1024); + for (uint32_t k = 0; k < min(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) { uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); - if (k_in >= min(K * N, 32 * 1024)) break; + if (k_in >= min(K * N, max_lds_len)) break; *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); } @@ -686,7 +708,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // Fetch A activation matrix in interleaved fashion from LDS or memory for (int n = 0; n < N; n++) { - if (k_ + K * n < 32 * 1024) + if (k_ + K * n < max_lds_len) bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n]))); else bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n]))); @@ -817,7 +839,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } } -#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support +#else // !defined(__HIP__GFX9__) TODO: Add NAVI support template __global__ void wvSplitK_hf_(const int K, const int M, const scalar_t* B, @@ -825,9 +847,9 @@ __global__ void wvSplitK_hf_(const int K, const int M, const scalar_t* B, const int _WvPrGrp, const int CuCount) { UNREACHABLE_CODE } -#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support +#endif // defined(__HIP__GFX9__) TODO: Add NAVI support -#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support +#if defined(__HIP__GFX9__) // TODO: Add NAVI 
support // This version targets big A[] cases, where it is much larger than LDS capacity template @@ -835,7 +857,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) wvSplitK_hf_big_(const int K, const int M, const scalar_t* B, const scalar_t* __restrict__ A, scalar_t* C, const int _WvPrGrp, const int CuCount) { - #if defined(__HIP__MI300__) + constexpr int max_lds_len = LDS_SIZE / 2; + #if defined(__HIP__MI3XX__) constexpr bool use_mfma = (std::is_same_v); #else constexpr bool use_mfma = false; @@ -855,13 +878,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) }; //---------------------------------------------------- - // Reserving 64 KB of LDS to have 1 WG / CU + // Reserving 64/160 KB of LDS to have 1 WG / CU // Goal is to bring the activation matrix A to the LDS // and use it across the lifetime of the work group // TODO: When activation matrix is larger than 64 KB // then this is not goint to work! //---------------------------------------------------- - __shared__ scalar_t s[1024 * 32]; + __shared__ scalar_t s[max_lds_len]; //---------------------------------------------------- // Computation of columns that need to be committed to memory! @@ -902,11 +925,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) //---------------------------------------------------- #define PCML #ifndef PCML - for (uint32_t k = 0; k < min(K * N, 32 * 1024); + for (uint32_t k = 0; k < min(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) { uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); - if (k_in >= min(K * N, 32 * 1024)) break; + if (k_in >= min(K * N, max_lds_len)) break; *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); } @@ -916,7 +939,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) #define TUC (THRDS * UNRL * A_CHUNK) uint32_t kBase = 0; // find biggest k size that fits in LDS - uint32_t kFit = (32 * 1024) / N; + uint32_t kFit = (max_lds_len) / N; // kFit = (kFit%TWC==0) ? 
kFit : (kFit-kFit%TWC+TWC); //round up to multiple // of TUC kFit = (kFit % TUC == 0) @@ -1164,7 +1187,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } } } -#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support +#else // !defined(__HIP__GFX9__) TODO: Add NAVI support template __global__ void wvSplitK_hf_big_(const int K, const int M, const scalar_t* B, @@ -1172,7 +1195,7 @@ __global__ void wvSplitK_hf_big_(const int K, const int M, const scalar_t* B, const int _WvPrGrp, const int CuCount) { UNREACHABLE_CODE } -#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support +#endif // defined(__HIP__GFX9__) TODO: Add NAVI support int mindiv(int N, int div1, int div2) { int nPrRnd = div1 * div2; @@ -1222,17 +1245,18 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b, const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const int max_lds_len = get_lds_size() / 2; #define WVSPLITK(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \ _N) \ { \ dim3 block(64, _WvPrGrp); \ - if ((K_in * N_in <= 32 * 1024) && (M_in % _YTILEs == 0)) { \ + if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \ wvSplitK_hf_sml_ \ <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ CuCount); \ - } else if (K_in * N_in <= 32 * 1024 * 1.2) { \ + } else if (K_in * N_in <= max_lds_len * 1.2) { \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \ wvSplitK_hf_ \ <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ @@ -1272,7 +1296,7 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b, return out_c; } -#if defined(__HIP__MI300__) // TODO: Add NAVI support +#if defined(__HIP__MI3XX__) // TODO: Add NAVI support template __global__ void __launch_bounds__(WvPrGrp* THRDS) @@ -1281,6 +1305,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) const float* __restrict__ s_A, const float* __restrict__ s_B, const int _WvPrGrp, 
const int CuCount) { + constexpr int max_lds_len = LDS_SIZE; using scalar8 = __attribute__((__vector_size__((A_CHUNK / 4) * sizeof(float)))) float; using intx2 = __attribute__((__vector_size__(2 * sizeof(int)))) int; @@ -1296,10 +1321,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) scalar8 h8; }; - __shared__ fp8_t s[1024 * 64]; + __shared__ fp8_t s[max_lds_len]; for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK; - k < min(K * N, 64 * 1024); k += THRDS * WvPrGrp * A_CHUNK) { + k < min(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) { *((bigType*)(&s[k])) = *((bigType*)(&A[k])); } __syncthreads(); @@ -1436,7 +1461,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) m += CuCount * _WvPrGrp * YTILE; } } -#else // !defined(__HIP__MI300__) TODO: Add NAVI support +#else // !defined(__HIP__MI3XX__) TODO: Add NAVI support template __global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M, @@ -1446,9 +1471,9 @@ __global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M, const int _WvPrGrp, const int CuCount) { UNREACHABLE_CODE } -#endif // defined(__HIP__MI300__) TODO: Add NAVI support +#endif // defined(__HIP__MI3XX__) TODO: Add NAVI support -#if defined(__HIP__MI300__) // TODO: Add NAVI support +#if defined(__HIP__MI3XX__) // TODO: Add NAVI support template __global__ void __launch_bounds__(WvPrGrp* THRDS) @@ -1456,6 +1481,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) const fp8_t* __restrict__ A, scalar_t* C, const float* __restrict__ s_A, const float* __restrict__ s_B, const int _WvPrGrp, const int CuCount) { + constexpr int max_lds_len = LDS_SIZE; using scalar8 = __attribute__((__vector_size__((A_CHUNK / 4) * sizeof(float)))) float; using intx2 = __attribute__((__vector_size__(2 * sizeof(int)))) int; @@ -1471,10 +1497,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) scalar8 h8; }; - __shared__ fp8_t s[1024 * 64]; + __shared__ fp8_t s[max_lds_len]; for (uint32_t k = (threadIdx.y * THRDS + 
threadIdx.x) * A_CHUNK; - k < min(K * N, 64 * 1024); k += THRDS * WvPrGrp * A_CHUNK) { + k < min(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) { *((bigType*)(&s[k])) = *((bigType*)(&A[k])); } __syncthreads(); @@ -1517,7 +1543,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K) break; for (int n = 0; n < N; n++) { - if (k_ + K * n < 64 * 1024) + if (k_ + K * n < max_lds_len) bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n]))); else bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n]))); @@ -1608,7 +1634,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) m += CuCount * _WvPrGrp * YTILE; } } -#else // !defined(__HIP__MI300__) TODO: Add NAVI support +#else // !defined(__HIP__MI3XX__) TODO: Add NAVI support template __global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M, @@ -1618,7 +1644,7 @@ __global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M, const int CuCount) { UNREACHABLE_CODE } -#endif // defined(__HIP__MI300__) TODO: Add NAVI support +#endif // defined(__HIP__MI3XX__) TODO: Add NAVI support void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, at::Tensor& scale_a, at::Tensor& scale_b, @@ -1638,12 +1664,13 @@ void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, dim3 grid(CuCount); const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const int max_lds_len = get_lds_size(); #define WVSPLITKQ(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \ _N) \ { \ dim3 block(64, _WvPrGrp); \ - if ((K_in * N_in <= 64 * 1024) && (M_in % _YTILEs == 0)) { \ + if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \ wvSplitKQ_hf_sml_ \ <<>>(K_in, Kp_in, M_in, a_ptr, b_ptr, c_ptr, \ diff --git a/docs/.nav.yml b/docs/.nav.yml index 42aba97753605..a9c594c291777 100644 --- a/docs/.nav.yml +++ b/docs/.nav.yml 
@@ -12,6 +12,7 @@ nav: - User Guide: usage/README.md - Developer Guide: contributing/README.md - API Reference: api/README.md + - CLI Reference: cli/README.md - Timeline: - Roadmap: https://roadmap.vllm.ai - Releases: https://github.com/vllm-project/vllm/releases @@ -56,6 +57,8 @@ nav: - Contents: - glob: api/vllm/* preserve_directory_names: true + - CLI Reference: + - Summary: cli/README.md - Community: - community/* - Blog: https://blog.vllm.ai diff --git a/docs/README.md b/docs/README.md index 57b1d03deee28..0c6aff5fa07c3 100644 --- a/docs/README.md +++ b/docs/README.md @@ -12,8 +12,8 @@

Star -Watch -Fork +Watch +Fork

vLLM is a fast and easy-to-use library for LLM inference and serving. diff --git a/docs/cli/README.md b/docs/cli/README.md new file mode 100644 index 0000000000000..5feb316d61a89 --- /dev/null +++ b/docs/cli/README.md @@ -0,0 +1,179 @@ +# vLLM CLI Guide + +The vllm command-line tool is used to run and manage vLLM models. You can start by viewing the help message with: + +``` +vllm --help +``` + +Available Commands: + +``` +vllm {chat,complete,serve,bench,collect-env,run-batch} +``` + +## Table of Contents + +- [serve](#serve) +- [chat](#chat) +- [complete](#complete) +- [bench](#bench) + - [latency](#latency) + - [serve](#serve-1) + - [throughput](#throughput) +- [collect-env](#collect-env) +- [run-batch](#run-batch) +- [More Help](#more-help) + +## serve + +Start the vLLM OpenAI Compatible API server. + +Examples: + +```bash +# Start with a model +vllm serve meta-llama/Llama-2-7b-hf + +# Specify the port +vllm serve meta-llama/Llama-2-7b-hf --port 8100 + +# Check with --help for more options +# To list all groups +vllm serve --help=listgroup + +# To view a argument group +vllm serve --help=ModelConfig + +# To view a single argument +vllm serve --help=max-num-seqs + +# To search by keyword +vllm serve --help=max +``` + +## chat + +Generate chat completions via the running API server. + +Examples: + +```bash +# Directly connect to localhost API without arguments +vllm chat + +# Specify API url +vllm chat --url http://{vllm-serve-host}:{vllm-serve-port}/v1 + +# Quick chat with a single prompt +vllm chat --quick "hi" +``` + +## complete + +Generate text completions based on the given prompt via the running API server. 
+ +Examples: + +```bash +# Directly connect to localhost API without arguments +vllm complete + +# Specify API url +vllm complete --url http://{vllm-serve-host}:{vllm-serve-port}/v1 + +# Quick complete with a single prompt +vllm complete --quick "The future of AI is" +``` + +## bench + +Run benchmark tests for latency online serving throughput and offline inference throughput. + +Available Commands: + +```bash +vllm bench {latency, serve, throughput} +``` + +### latency + +Benchmark the latency of a single batch of requests. + +Example: + +```bash +vllm bench latency \ + --model meta-llama/Llama-3.2-1B-Instruct \ + --input-len 32 \ + --output-len 1 \ + --enforce-eager \ + --load-format dummy +``` + +### serve + +Benchmark the online serving throughput. + +Example: + +```bash +vllm bench serve \ + --model meta-llama/Llama-3.2-1B-Instruct \ + --host server-host \ + --port server-port \ + --random-input-len 32 \ + --random-output-len 4 \ + --num-prompts 5 +``` + +### throughput + +Benchmark offline inference throughput. + +Example: + +```bash +vllm bench throughput \ + --model meta-llama/Llama-3.2-1B-Instruct \ + --input-len 32 \ + --output-len 1 \ + --enforce-eager \ + --load-format dummy +``` + +## collect-env + +Start collecting environment information. + +```bash +vllm collect-env +``` + +## run-batch + +Run batch prompts and write results to file. 
+ +Examples: + +```bash +# Running with a local file +vllm run-batch \ + -i offline_inference/openai_batch/openai_example_batch.jsonl \ + -o results.jsonl \ + --model meta-llama/Meta-Llama-3-8B-Instruct + +# Using remote file +vllm run-batch \ + -i https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai_batch/openai_example_batch.jsonl \ + -o results.jsonl \ + --model meta-llama/Meta-Llama-3-8B-Instruct +``` + +## More Help + +For detailed options of any subcommand, use: + +```bash +vllm --help +``` diff --git a/docs/contributing/README.md b/docs/contributing/README.md index 2517436afcc11..65ae9cc963676 100644 --- a/docs/contributing/README.md +++ b/docs/contributing/README.md @@ -29,20 +29,68 @@ See . Depending on the kind of development you'd like to do (e.g. Python, CUDA), you can choose to build vLLM with or without compilation. Check out the [building from source][build-from-source] documentation for details. -### Building the docs +### Building the docs with MkDocs -Install the dependencies: +#### Introduction to MkDocs + +[MkDocs](https://github.com/mkdocs/mkdocs) is a fast, simple and downright gorgeous static site generator that's geared towards building project documentation. Documentation source files are written in Markdown, and configured with a single YAML configuration file. + +#### Install MkDocs and Plugins + +Install MkDocs along with the [plugins](https://github.com/vllm-project/vllm/blob/main/mkdocs.yaml) used in the vLLM documentation, as well as required dependencies: ```bash pip install -r requirements/docs.txt ``` -Start the autoreloading MkDocs server: +!!! 
note + Ensure that your Python version is compatible with the plugins (e.g., `mkdocs-awesome-nav` requires Python 3.10+) + +#### Verify Installation + +Confirm that MkDocs is correctly installed: + +```bash +mkdocs --version +``` + +Example output: + +```console +mkdocs, version 1.6.1 from /opt/miniconda3/envs/mkdoc/lib/python3.10/site-packages/mkdocs (Python 3.10) +``` + +#### Clone the `vLLM` repository + +```bash +git clone https://github.com/vllm-project/vllm.git +cd vllm +``` + +#### Start the Development Server + +MkDocs comes with a built-in dev-server that lets you preview your documentation as you work on it. Make sure you're in the same directory as the `mkdocs.yml` configuration file, and then start the server by running the `mkdocs serve` command: ```bash mkdocs serve ``` +Example output: + +```console +INFO - Documentation built in 106.83 seconds +INFO - [22:02:02] Watching paths for changes: 'docs', 'mkdocs.yaml' +INFO - [22:02:02] Serving on http://127.0.0.1:8000/ +``` + +#### View in Your Browser + +Open up [http://127.0.0.1:8000/](http://127.0.0.1:8000/) in your browser to see a live preview:. + +#### Learn More + +For additional features and advanced configurations, refer to the official [MkDocs Documentation](https://www.mkdocs.org/). + ## Testing ```bash @@ -60,6 +108,9 @@ pre-commit run mypy-3.9 --hook-stage manual --all-files # Unit tests pytest tests/ + +# Run tests for a single test file with detailed output +pytest -s -v tests/test_logger.py ``` !!! tip diff --git a/docs/design/arch_overview.md b/docs/design/arch_overview.md index 75d3e1b7ccc78..14720a392aafb 100644 --- a/docs/design/arch_overview.md +++ b/docs/design/arch_overview.md @@ -48,8 +48,7 @@ for output in outputs: print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") ``` -More API details can be found in the [Offline Inference] -(#offline-inference-api) section of the API docs. 
+More API details can be found in the [Offline Inference](#offline-inference-api) section of the API docs. The code for the `LLM` class can be found in . diff --git a/docs/design/multiprocessing.md b/docs/design/multiprocessing.md index 412c42fd580e5..4d58fae20f06c 100644 --- a/docs/design/multiprocessing.md +++ b/docs/design/multiprocessing.md @@ -22,13 +22,13 @@ This document describes how vLLM deals with these challenges. [Python multiprocessing methods](https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods) include: -- `spawn` - spawn a new Python process. This will be the default as of Python - 3.14. In macOS, this is already the default. +- `spawn` - spawn a new Python process. The default on Windows and macOS. -- `fork` - Use `os.fork()` to fork the Python interpreter. This is the default - in Python versions prior to 3.14. +- `fork` - Use `os.fork()` to fork the Python interpreter. The default on + Linux for Python versions prior to 3.14. - `forkserver` - Spawn a server process that will fork a new process on request. + The default on Linux for Python version 3.14 and newer. ### Tradeoffs diff --git a/docs/features/compatibility_matrix.md b/docs/features/compatibility_matrix.md index 77ceea49f1732..5d448eb5c03d8 100644 --- a/docs/features/compatibility_matrix.md +++ b/docs/features/compatibility_matrix.md @@ -10,6 +10,7 @@ The symbols used have the following meanings: - ✅ = Full compatibility - 🟠 = Partial compatibility - ❌ = No compatibility +- ❔ = Unknown or TBD !!! note Check the ❌ or 🟠 with links to see tracking issue for unsupported feature/hardware combination. 
@@ -36,23 +37,23 @@ th:not(:first-child) { } -| Feature | [CP][chunked-prefill] | [APC][automatic-prefix-caching] | [LoRA][lora-adapter] | prmpt adptr | [SD][spec-decode] | CUDA graph | pooling | enc-dec | logP | prmpt logP | async output | multi-step | mm | best-of | beam-search | -|-----------------------------------------------------------|-------------------------|-----------------------------------|------------------------|---------------------------------------------------|---------------------|--------------|-----------------------------------------------|-------------------------------------------------------|--------------------------------------|---------------------------------------------------|-------------------------------------------------------------|--------------------|---------------------------------------------|-----------|---------------| -| [CP][chunked-prefill] | ✅ | | | | | | | | | | | | | | | -| [APC][automatic-prefix-caching] | ✅ | ✅ | | | | | | | | | | | | | | -| [LoRA][lora-adapter] | ✅ | ✅ | ✅ | | | | | | | | | | | | | -| prmpt adptr | ✅ | ✅ | ✅ | ✅ | | | | | | | | | | | | -| [SD][spec-decode] | ✅ | ✅ | ❌ | ✅ | ✅ | | | | | | | | | | | -| CUDA graph | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | | | | | | -| pooling | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | | | | | | | | | -| enc-dec | ❌ | [❌](gh-issue:7366) | ❌ | ❌ | [❌](gh-issue:7366) | ✅ | ✅ | ✅ | | | | | | | | -| logP | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | | | | | | | -| prmpt logP | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | | | | | | -| async output | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | | | | | -| multi-step | ❌ | ✅ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | | | | -| mm | ✅ | [🟠](gh-pr:8348) | [🟠](gh-pr:4194) | ❔ | ❔ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | ✅ | | | -| best-of | ✅ | ✅ | ✅ | ✅ | [❌](gh-issue:6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](gh-issue:7968) | ✅ | ✅ | | -| beam-search | ✅ | ✅ | ✅ | ✅ | [❌](gh-issue:6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](gh-issue:7968) | ❔ | ✅ | ✅ | +| Feature | [CP][chunked-prefill] 
| [APC][automatic-prefix-caching] | [LoRA][lora-adapter] | prmpt adptr | [SD][spec-decode] | CUDA graph | pooling | enc-dec | logP | prmpt logP | async output | multi-step | mm | best-of | beam-search | +|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---| +| [CP][chunked-prefill] | ✅ | | | | | | | | | | | | | | | +| [APC][automatic-prefix-caching] | ✅ | ✅ | | | | | | | | | | | | | | +| [LoRA][lora-adapter] | ✅ | ✅ | ✅ | | | | | | | | | | | | | +| prmpt adptr | ✅ | ✅ | ✅ | ✅ | | | | | | | | | | | | +| [SD][spec-decode] | ✅ | ✅ | ❌ | ✅ | ✅ | | | | | | | | | | | +| CUDA graph | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | | | | | | +| pooling | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | | | | | | | | | +| enc-dec | ❌ | [❌](gh-issue:7366) | ❌ | ❌ | [❌](gh-issue:7366) | ✅ | ✅ | ✅ | | | | | | | | +| logP | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | | | | | | | +| prmpt logP | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | | | | | | +| async output | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | | | | | +| multi-step | ❌ | ✅ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | | | | +| mm | ✅ | [🟠](gh-pr:8348) | [🟠](gh-pr:4194) | ❔ | ❔ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | ✅ | | | +| best-of | ✅ | ✅ | ✅ | ✅ | [❌](gh-issue:6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](gh-issue:7968) | ✅ | ✅ | | +| beam-search | ✅ | ✅ | ✅ | ✅ | [❌](gh-issue:6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](gh-issue:7968) | ❔ | ✅ | ✅ | [](){ #feature-x-hardware } @@ -75,3 +76,6 @@ th:not(:first-child) { | multi-step | ✅ | ✅ | ✅ | ✅ | ✅ | [❌](gh-issue:8477) | ✅ | | best-of | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | beam-search | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | + +!!! 
note + Please refer to [Feature support through NxD Inference backend][feature-support-through-nxd-inference-backend] for features supported on AWS Neuron hardware diff --git a/docs/features/lora.md b/docs/features/lora.md index 642462f7c4557..04e92dbc45924 100644 --- a/docs/features/lora.md +++ b/docs/features/lora.md @@ -165,6 +165,7 @@ it will first look in the local directory for a directory `foobar`, and attempt that adapter will then be available for normal use on the server. Alternatively, follow these example steps to implement your own plugin: + 1. Implement the LoRAResolver interface. Example of a simple S3 LoRAResolver implementation: @@ -198,9 +199,9 @@ Alternatively, follow these example steps to implement your own plugin: return lora_request ``` -2. Register LoRAResolver plugin. +2. Register `LoRAResolver` plugin. - ```python + ```python from vllm.lora.resolver import LoRAResolverRegistry s3_resolver = S3LoRAResolver() diff --git a/docs/features/quantization/supported_hardware.md b/docs/features/quantization/supported_hardware.md index 2967bf9c7504a..6a585b1ccb2ca 100644 --- a/docs/features/quantization/supported_hardware.md +++ b/docs/features/quantization/supported_hardware.md @@ -5,13 +5,13 @@ title: Supported Hardware The table below shows the compatibility of various quantization implementations with different hardware platforms in vLLM: -| Implementation | Volta | Turing | Ampere | Ada | Hopper | AMD GPU | Intel GPU | x86 CPU | AWS Inferentia | Google TPU | +| Implementation | Volta | Turing | Ampere | Ada | Hopper | AMD GPU | Intel GPU | x86 CPU | AWS Neuron | Google TPU | |-----------------------|---------|----------|----------|-------|----------|-----------|-------------|-----------|------------------|--------------| | AWQ | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ✅︎ | ❌ | ❌ | | GPTQ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ✅︎ | ❌ | ❌ | | Marlin (GPTQ/AWQ/FP8) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | -| INT8 (W8A8) | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | 
✅︎ | ❌ | ✅︎ | -| FP8 (W8A8) | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | +| INT8 (W8A8) | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | +| FP8 (W8A8) | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ✅︎ | ❌ | | BitBLAS (GPTQ) | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | | AQLM | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | | bitsandbytes | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | diff --git a/docs/getting_started/installation/ai_accelerator/neuron.inc.md b/docs/getting_started/installation/ai_accelerator/neuron.inc.md index f08c78fba6c85..86c12472fb360 100644 --- a/docs/getting_started/installation/ai_accelerator/neuron.inc.md +++ b/docs/getting_started/installation/ai_accelerator/neuron.inc.md @@ -1,8 +1,9 @@ # --8<-- [start:installation] -vLLM 0.3.3 onwards supports model inferencing and serving on AWS Trainium/Inferentia with Neuron SDK with continuous batching. -Paged Attention and Chunked Prefill are currently in development and will be available soon. -Data types currently supported in Neuron SDK are FP16 and BF16. +[AWS Neuron](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/) is the software development kit (SDK) used to run deep learning and + generative AI workloads on AWS Inferentia and AWS Trainium powered Amazon EC2 instances and UltraServers (Inf1, Inf2, Trn1, Trn2, + and Trn2 UltraServer). Both Trainium and Inferentia are powered by fully-independent heterogeneous compute-units called NeuronCores. + This tab describes how to set up your environment to run vLLM on Neuron. !!! warning There are no pre-built wheels or images for this device, so you must build vLLM from source. @@ -11,59 +12,31 @@ Data types currently supported in Neuron SDK are FP16 and BF16. 
# --8<-- [start:requirements] - OS: Linux -- Python: 3.9 -- 3.11 -- Accelerator: NeuronCore_v2 (in trn1/inf2 instances) -- Pytorch 2.0.1/2.1.1 -- AWS Neuron SDK 2.16/2.17 (Verified on python 3.8) +- Python: 3.9 or newer +- Pytorch 2.5/2.6 +- Accelerator: NeuronCore-v2 (in trn1/inf2 chips) or NeuronCore-v3 (in trn2 chips) +- AWS Neuron SDK 2.23 ## Configure a new environment -### Launch Trn1/Inf2 instances +### Launch a Trn1/Trn2/Inf2 instance and verify Neuron dependencies -Here are the steps to launch trn1/inf2 instances, in order to install [PyTorch Neuron ("torch-neuronx") Setup on Ubuntu 22.04 LTS](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/setup/neuron-setup/pytorch/neuronx/ubuntu/torch-neuronx-ubuntu22.html). +The easiest way to launch a Trainium or Inferentia instance with pre-installed Neuron dependencies is to follow this +[quick start guide](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/setup/neuron-setup/multiframework/multi-framework-ubuntu22-neuron-dlami.html#setup-ubuntu22-multi-framework-dlami) using the Neuron Deep Learning AMI (Amazon machine image). -- Please follow the instructions at [launch an Amazon EC2 Instance](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/EC2_GetStarted.html#ec2-launch-instance) to launch an instance. When choosing the instance type at the EC2 console, please make sure to select the correct instance type. -- To get more information about instances sizes and pricing see: [Trn1 web page](https://aws.amazon.com/ec2/instance-types/trn1/), [Inf2 web page](https://aws.amazon.com/ec2/instance-types/inf2/) -- Select Ubuntu Server 22.04 TLS AMI -- When launching a Trn1/Inf2, please adjust your primary EBS volume size to a minimum of 512GB. 
- After launching the instance, follow the instructions in [Connect to your instance](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/AccessingInstancesLinux.html) to connect to the instance - -### Install drivers and tools - -The installation of drivers and tools wouldn't be necessary, if [Deep Learning AMI Neuron](https://docs.aws.amazon.com/dlami/latest/devguide/appendix-ami-release-notes.html) is installed. In case the drivers and tools are not installed on the operating system, follow the steps below: - +- Once inside your instance, activate the pre-installed virtual environment for inference by running ```console -# Configure Linux for Neuron repository updates -. /etc/os-release -sudo tee /etc/apt/sources.list.d/neuron.list > /dev/null <= 0.5.3`. You may see an error `cannot import name 'default_dump_dir...`. To work around this, run a `pip install --upgrade triton==3.0.0` after installing the vLLM wheel. - -Following instructions are applicable to Neuron SDK 2.16 and beyond. - -#### Install transformers-neuronx and its dependencies - -[transformers-neuronx](https://github.com/aws-neuron/transformers-neuronx) will be the backend to support inference on trn1/inf2 instances. -Follow the steps below to install transformer-neuronx package and its dependencies. 
- -```console -# Install Python venv -sudo apt-get install -y python3.10-venv g++ - -# Create Python venv -python3.10 -m venv aws_neuron_venv_pytorch - -# Activate Python venv -source aws_neuron_venv_pytorch/bin/activate - -# Install Jupyter notebook kernel -pip install ipykernel -python3.10 -m ipykernel install \ - --user \ - --name aws_neuron_venv_pytorch \ - --display-name "Python (torch-neuronx)" -pip install jupyter notebook -pip install environment_kernels - -# Set pip repository pointing to the Neuron repository -python -m pip config set \ - global.extra-index-url \ - https://pip.repos.neuron.amazonaws.com - -# Install wget, awscli -python -m pip install wget -python -m pip install awscli - -# Update Neuron Compiler and Framework -python -m pip install --upgrade neuronx-cc==2.* --pre torch-neuronx==2.1.* torchvision transformers-neuronx -``` - #### Install vLLM from source -Once neuronx-cc and transformers-neuronx packages are installed, we will be able to install vllm as follows: +Install vllm as follows: ```console git clone https://github.com/vllm-project/vllm.git cd vllm pip install -U -r requirements/neuron.txt -VLLM_TARGET_DEVICE="neuron" pip install . +VLLM_TARGET_DEVICE="neuron" pip install -e . ``` -If neuron packages are detected correctly in the installation process, `vllm-0.3.0+neuron212` will be installed. +AWS Neuron maintains a [Github fork of vLLM](https://github.com/aws-neuron/upstreaming-to-vllm/tree/neuron-2.23-vllm-v0.7.2) at + [https://github.com/aws-neuron/upstreaming-to-vllm/tree/neuron-2.23-vllm-v0.7.2](https://github.com/aws-neuron/upstreaming-to-vllm/tree/neuron-2.23-vllm-v0.7.2), which contains several features in addition to what's + available on vLLM V0. 
Please utilize the AWS Fork for the following features: + +- Llama-3.2 multi-modal support +- Multi-node distributed inference + +Refer to [vLLM User Guide for NxD Inference](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/nxd-inference/developer_guides/vllm-user-guide.html) + for more details and usage examples. + +To install the AWS Neuron fork, run the following: + +```console +git clone -b neuron-2.23-vllm-v0.7.2 https://github.com/aws-neuron/upstreaming-to-vllm.git +cd upstreaming-to-vllm +pip install -r requirements/neuron.txt +VLLM_TARGET_DEVICE="neuron" pip install -e . +``` + +Note that the AWS Neuron fork is only intended to support Neuron hardware; compatibility with other hardwares is not tested. # --8<-- [end:build-wheel-from-source] # --8<-- [start:set-up-using-docker] @@ -148,5 +98,57 @@ Make sure to use in place of the default Dock # --8<-- [end:build-image-from-source] # --8<-- [start:extra-information] -There is no extra information for this device. +[](){ #feature-support-through-nxd-inference-backend } +### Feature support through NxD Inference backend + +The current vLLM and Neuron integration relies on either the `neuronx-distributed-inference` (preferred) or `transformers-neuronx` backend + to perform most of the heavy lifting which includes PyTorch model initialization, compilation, and runtime execution. Therefore, most + [features supported on Neuron](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/nxd-inference/developer_guides/feature-guide.html) are also available via the vLLM integration. + +To configure NxD Inference features through the vLLM entrypoint, use the `override_neuron_config` setting. Provide the configs you want to override +as a dictionary (or JSON object when starting vLLM from the CLI). 
For example, to disable auto bucketing, include +```console +override_neuron_config={ + "enable_bucketing":False, +} +``` +or when launching vLLM from the CLI, pass +```console +--override-neuron-config "{\"enable_bucketing\":false}" +``` + +Alternatively, users can directly call the NxDI library to trace and compile your model, then load the pre-compiled artifacts +(via `NEURON_COMPILED_ARTIFACTS` environment variable) in vLLM to run inference workloads. + +### Known limitations + +- EAGLE speculative decoding: NxD Inference requires the EAGLE draft checkpoint to include the LM head weights from the target model. Refer to this + [guide](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/nxd-inference/developer_guides/feature-guide.html#eagle-checkpoint-compatibility) + for how to convert pretrained EAGLE model checkpoints to be compatible for NxDI. +- Quantization: the native quantization flow in vLLM is not well supported on NxD Inference. It is recommended to follow this + [Neuron quantization guide](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/nxd-inference/developer_guides/custom-quantization.html) + to quantize and compile your model using NxD Inference, and then load the compiled artifacts into vLLM. +- Multi-LoRA serving: NxD Inference only supports loading of LoRA adapters at server startup. Dynamic loading of LoRA adapters at + runtime is not currently supported. Refer to [multi-lora example](https://github.com/aws-neuron/upstreaming-to-vllm/blob/neuron-2.23-vllm-v0.7.2/examples/offline_inference/neuron_multi_lora.py) +- Multi-modal support: multi-modal support is only available through the AWS Neuron fork. This feature has not been upstreamed + to vLLM main because NxD Inference currently relies on certain adaptations to the core vLLM logic to support this feature. +- Multi-node support: distributed inference across multiple Trainium/Inferentia instances is only supported on the AWS Neuron fork. 
Refer + to this [multi-node example](https://github.com/aws-neuron/upstreaming-to-vllm/tree/neuron-2.23-vllm-v0.7.2/examples/neuron/multi_node) + to run. Note that tensor parallelism (distributed inference across NeuronCores) is available in vLLM main. +- Known edge case bug in speculative decoding: An edge case failure may occur in speculative decoding when sequence length approaches + max model length (e.g. when requesting max tokens up to the max model length and ignoring eos). In this scenario, vLLM may attempt + to allocate an additional block to ensure there is enough memory for number of lookahead slots, but since we do not have good support + for paged attention, there isn't another Neuron block for vLLM to allocate. A workaround fix (to terminate 1 iteration early) is + implemented in the AWS Neuron fork but is not upstreamed to vLLM main as it modifies core vLLM logic. + + +### Environment variables +- `NEURON_COMPILED_ARTIFACTS`: set this environment variable to point to your pre-compiled model artifacts directory to avoid + compilation time upon server initialization. If this variable is not set, the Neuron module will perform compilation and save the + artifacts under `neuron-compiled-artifacts/{unique_hash}/` sub-directory in the model path. If this environment variable is set, + but the directory does not exist, or the contents are invalid, Neuron will also fallback to a new compilation and store the artifacts + under this specified path. +- `NEURON_CONTEXT_LENGTH_BUCKETS`: Bucket sizes for context encoding. (Only applicable to `transformers-neuronx` backend). +- `NEURON_TOKEN_GEN_BUCKETS`: Bucket sizes for token generation. (Only applicable to `transformers-neuronx` backend). 
+ # --8<-- [end:extra-information] diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 7594c6e6fbf1e..b60fefdda2793 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -302,31 +302,31 @@ Specified using `--task generate`. | Architecture | Models | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | |---------------------------------------------------|-----------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------|-----------------------------| | `AquilaForCausalLM` | Aquila, Aquila2 | `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc. | ✅︎ | ✅︎ | -| `ArcticForCausalLM` | Arctic | `Snowflake/snowflake-arctic-base`, `Snowflake/snowflake-arctic-instruct`, etc. | ✅︎ | | +| `ArcticForCausalLM` | Arctic | `Snowflake/snowflake-arctic-base`, `Snowflake/snowflake-arctic-instruct`, etc. | | ✅︎ | | `BaiChuanForCausalLM` | Baichuan2, Baichuan | `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc. | ✅︎ | ✅︎ | -| `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | | | -| `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | ✅︎ | | +| `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ | +| `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | | `BartForConditionalGeneration` | BART | `facebook/bart-base`, `facebook/bart-large-cnn`, etc. | | | | `ChatGLMModel`, `ChatGLMForConditionalGeneration` | ChatGLM | `THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, `ShieldLM-6B-chatglm3`, etc. | ✅︎ | ✅︎ | | `CohereForCausalLM`, `Cohere2ForCausalLM` | Command-R | `CohereForAI/c4ai-command-r-v01`, `CohereForAI/c4ai-command-r7b-12-2024`, etc. 
| ✅︎ | ✅︎ | -| `DbrxForCausalLM` | DBRX | `databricks/dbrx-base`, `databricks/dbrx-instruct`, etc. | ✅︎ | | -| `DeciLMForCausalLM` | DeciLM | `nvidia/Llama-3_3-Nemotron-Super-49B-v1`, etc. | ✅︎ | | -| `DeepseekForCausalLM` | DeepSeek | `deepseek-ai/deepseek-llm-67b-base`, `deepseek-ai/deepseek-llm-7b-chat` etc. | ✅︎ | | -| `DeepseekV2ForCausalLM` | DeepSeek-V2 | `deepseek-ai/DeepSeek-V2`, `deepseek-ai/DeepSeek-V2-Chat` etc. | ✅︎ | | -| `DeepseekV3ForCausalLM` | DeepSeek-V3 | `deepseek-ai/DeepSeek-V3-Base`, `deepseek-ai/DeepSeek-V3` etc. | ✅︎ | | +| `DbrxForCausalLM` | DBRX | `databricks/dbrx-base`, `databricks/dbrx-instruct`, etc. | | ✅︎ | +| `DeciLMForCausalLM` | DeciLM | `nvidia/Llama-3_3-Nemotron-Super-49B-v1`, etc. | ✅︎ | ✅︎ | +| `DeepseekForCausalLM` | DeepSeek | `deepseek-ai/deepseek-llm-67b-base`, `deepseek-ai/deepseek-llm-7b-chat` etc. | | ✅︎ | +| `DeepseekV2ForCausalLM` | DeepSeek-V2 | `deepseek-ai/DeepSeek-V2`, `deepseek-ai/DeepSeek-V2-Chat` etc. | | ✅︎ | +| `DeepseekV3ForCausalLM` | DeepSeek-V3 | `deepseek-ai/DeepSeek-V3-Base`, `deepseek-ai/DeepSeek-V3` etc. | | ✅︎ | | `ExaoneForCausalLM` | EXAONE-3 | `LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`, etc. | ✅︎ | ✅︎ | -| `FalconForCausalLM` | Falcon | `tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc. | ✅︎ | | -| `FalconMambaForCausalLM` | FalconMamba | `tiiuae/falcon-mamba-7b`, `tiiuae/falcon-mamba-7b-instruct`, etc. | ✅︎ | ✅︎ | +| `FalconForCausalLM` | Falcon | `tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc. | | ✅︎ | +| `FalconMambaForCausalLM` | FalconMamba | `tiiuae/falcon-mamba-7b`, `tiiuae/falcon-mamba-7b-instruct`, etc. | | ✅︎ | | `FalconH1ForCausalLM` | Falcon-H1 | `tiiuae/Falcon-H1-34B-Base`, `tiiuae/Falcon-H1-34B-Instruct`, etc. | ✅︎ | ✅︎ | | `GemmaForCausalLM` | Gemma | `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc. | ✅︎ | ✅︎ | | `Gemma2ForCausalLM` | Gemma 2 | `google/gemma-2-9b`, `google/gemma-2-27b`, etc. 
| ✅︎ | ✅︎ | | `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it`, etc. | ✅︎ | ✅︎ | | `GlmForCausalLM` | GLM-4 | `THUDM/glm-4-9b-chat-hf`, etc. | ✅︎ | ✅︎ | | `Glm4ForCausalLM` | GLM-4-0414 | `THUDM/GLM-4-32B-0414`, etc. | ✅︎ | ✅︎ | -| `GPT2LMHeadModel` | GPT-2 | `gpt2`, `gpt2-xl`, etc. | ✅︎ | | +| `GPT2LMHeadModel` | GPT-2 | `gpt2`, `gpt2-xl`, etc. | | ✅︎ | | `GPTBigCodeForCausalLM` | StarCoder, SantaCoder, WizardCoder | `bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, `WizardLM/WizardCoder-15B-V1.0`, etc. | ✅︎ | ✅︎ | -| `GPTJForCausalLM` | GPT-J | `EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc. | ✅︎ | | -| `GPTNeoXForCausalLM` | GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM | `EleutherAI/gpt-neox-20b`, `EleutherAI/pythia-12b`, `OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc. | ✅︎ | | +| `GPTJForCausalLM` | GPT-J | `EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc. | | ✅︎ | +| `GPTNeoXForCausalLM` | GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM | `EleutherAI/gpt-neox-20b`, `EleutherAI/pythia-12b`, `OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc. | | ✅︎ | | `GraniteForCausalLM` | Granite 3.0, Granite 3.1, PowerLM | `ibm-granite/granite-3.0-2b-base`, `ibm-granite/granite-3.1-8b-instruct`, `ibm/PowerLM-3b`, etc. | ✅︎ | ✅︎ | | `GraniteMoeForCausalLM` | Granite 3.0 MoE, PowerMoE | `ibm-granite/granite-3.0-1b-a400m-base`, `ibm-granite/granite-3.0-3b-a800m-instruct`, `ibm/PowerMoE-3b`, etc. | ✅︎ | ✅︎ | | `GraniteMoeHybridForCausalLM` | Granite 4.0 MoE Hybrid | `ibm-granite/granite-4.0-tiny-preview`, etc. | ✅︎ | ✅︎ | @@ -336,39 +336,39 @@ Specified using `--task generate`. | `InternLMForCausalLM` | InternLM | `internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc. | ✅︎ | ✅︎ | | `InternLM2ForCausalLM` | InternLM2 | `internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc. 
| ✅︎ | ✅︎ | | `InternLM3ForCausalLM` | InternLM3 | `internlm/internlm3-8b-instruct`, etc. | ✅︎ | ✅︎ | -| `JAISLMHeadModel` | Jais | `inceptionai/jais-13b`, `inceptionai/jais-13b-chat`, `inceptionai/jais-30b-v3`, `inceptionai/jais-30b-chat-v3`, etc. | ✅︎ | | +| `JAISLMHeadModel` | Jais | `inceptionai/jais-13b`, `inceptionai/jais-13b-chat`, `inceptionai/jais-30b-v3`, `inceptionai/jais-30b-chat-v3`, etc. | | ✅︎ | | `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ | | `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA, Yi | `meta-llama/Meta-Llama-3.1-405B-Instruct`, `meta-llama/Meta-Llama-3.1-70B`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `01-ai/Yi-34B`, etc. | ✅︎ | ✅︎ | -| `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | ✅︎ | | +| `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | | `MiniCPMForCausalLM` | MiniCPM | `openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, `openbmb/MiniCPM-S-1B-sft`, etc. | ✅︎ | ✅︎ | | `MiniCPM3ForCausalLM` | MiniCPM3 | `openbmb/MiniCPM3-4B`, etc. | ✅︎ | ✅︎ | | `MistralForCausalLM` | Mistral, Mistral-Instruct | `mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc. | ✅︎ | ✅︎ | | `MixtralForCausalLM` | Mixtral-8x7B, Mixtral-8x7B-Instruct | `mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc. | ✅︎ | ✅︎ | -| `MPTForCausalLM` | MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter | `mosaicml/mpt-7b`, `mosaicml/mpt-7b-storywriter`, `mosaicml/mpt-30b`, etc. | ✅︎ | | +| `MPTForCausalLM` | MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter | `mosaicml/mpt-7b`, `mosaicml/mpt-7b-storywriter`, `mosaicml/mpt-30b`, etc. 
| | ✅︎ | | `NemotronForCausalLM` | Nemotron-3, Nemotron-4, Minitron | `nvidia/Minitron-8B-Base`, `mgoin/Nemotron-4-340B-Base-hf-FP8`, etc. | ✅︎ | ✅︎ | -| `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | ✅︎ | | -| `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | ✅︎ | | -| `OLMoEForCausalLM` | OLMoE | `allenai/OLMoE-1B-7B-0924`, `allenai/OLMoE-1B-7B-0924-Instruct`, etc. | ✅︎ | ✅︎ | -| `OPTForCausalLM` | OPT, OPT-IML | `facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc. | ✅︎ | | -| `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | ✅︎ | | +| `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | | ✅︎ | +| `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | | ✅︎ | +| `OLMoEForCausalLM` | OLMoE | `allenai/OLMoE-1B-7B-0924`, `allenai/OLMoE-1B-7B-0924-Instruct`, etc. | | ✅︎ | +| `OPTForCausalLM` | OPT, OPT-IML | `facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc. | | ✅︎ | +| `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ | | `PhiForCausalLM` | Phi | `microsoft/phi-1_5`, `microsoft/phi-2`, etc. | ✅︎ | ✅︎ | | `Phi3ForCausalLM` | Phi-4, Phi-3 | `microsoft/Phi-4-mini-instruct`, `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc. | ✅︎ | ✅︎ | -| `Phi3SmallForCausalLM` | Phi-3-Small | `microsoft/Phi-3-small-8k-instruct`, `microsoft/Phi-3-small-128k-instruct`, etc. | ✅︎ | | +| `Phi3SmallForCausalLM` | Phi-3-Small | `microsoft/Phi-3-small-8k-instruct`, `microsoft/Phi-3-small-128k-instruct`, etc. | | ✅︎ | | `PhiMoEForCausalLM` | Phi-3.5-MoE | `microsoft/Phi-3.5-MoE-instruct`, etc. | ✅︎ | ✅︎ | -| `PersimmonForCausalLM` | Persimmon | `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc. | ✅︎ | | +| `PersimmonForCausalLM` | Persimmon | `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc. 
| | ✅︎ | | `Plamo2ForCausalLM` | PLaMo2 | `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. | | | | `QWenLMHeadModel` | Qwen | `Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc. | ✅︎ | ✅︎ | | `Qwen2ForCausalLM` | QwQ, Qwen2 | `Qwen/QwQ-32B-Preview`, `Qwen/Qwen2-7B-Instruct`, `Qwen/Qwen2-7B`, etc. | ✅︎ | ✅︎ | -| `Qwen2MoeForCausalLM` | Qwen2MoE | `Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. | ✅︎ | | +| `Qwen2MoeForCausalLM` | Qwen2MoE | `Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. | | ✅︎ | | `Qwen3ForCausalLM` | Qwen3 | `Qwen/Qwen3-8B`, etc. | ✅︎ | ✅︎ | -| `Qwen3MoeForCausalLM` | Qwen3MoE | `Qwen/Qwen3-30B-A3B`, etc. | ✅︎ | | -| `StableLmForCausalLM` | StableLM | `stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc. | ✅︎ | | -| `Starcoder2ForCausalLM` | Starcoder2 | `bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc. | ✅︎ | | +| `Qwen3MoeForCausalLM` | Qwen3MoE | `Qwen/Qwen3-30B-A3B`, etc. | | ✅︎ | +| `StableLmForCausalLM` | StableLM | `stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc. | | | +| `Starcoder2ForCausalLM` | Starcoder2 | `bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc. | | ✅︎ | | `SolarForCausalLM` | Solar Pro | `upstage/solar-pro-preview-instruct`, etc. | ✅︎ | ✅︎ | | `TeleChat2ForCausalLM` | TeleChat2 | `Tele-AI/TeleChat2-3B`, `Tele-AI/TeleChat2-7B`, `Tele-AI/TeleChat2-35B`, etc. | ✅︎ | ✅︎ | | `TeleFLMForCausalLM` | TeleFLM | `CofeAI/FLM-2-52B-Instruct-2407`, `CofeAI/Tele-FLM`, etc. | ✅︎ | ✅︎ | | `XverseForCausalLM` | XVERSE | `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc. | ✅︎ | ✅︎ | -| `MiniMaxText01ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-Text-01`, etc. | ✅︎ | | +| `MiniMaxText01ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-Text-01`, etc. | | | | `Zamba2ForCausalLM` | Zamba2 | `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc. 
| | | !!! note @@ -401,7 +401,7 @@ Specified using `--task embed`. !!! note `ssmits/Qwen2-7B-Instruct-embed-base` has an improperly defined Sentence Transformers config. - You should manually set mean pooling by passing `--override-pooler-config '{"pooling_type": "MEAN"}'`. + You need to manually set mean pooling by passing `--override-pooler-config '{"pooling_type": "MEAN"}'`. !!! note For `Alibaba-NLP/gte-Qwen2-*`, you need to enable `--trust-remote-code` for the correct tokenizer to be loaded. @@ -512,44 +512,44 @@ Specified using `--task generate`. | Architecture | Models | Inputs | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | [V1](gh-issue:8779) | |----------------------------------------------|--------------------------------------------------------------------------|-----------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------|-----------------------------|-----------------------| -| `AriaForConditionalGeneration` | Aria | T + I+ | `rhymes-ai/Aria` | ✅︎ | ✅︎ | | -| `AyaVisionForConditionalGeneration` | Aya Vision | T + I+ | `CohereForAI/aya-vision-8b`, `CohereForAI/aya-vision-32b`, etc. | ✅︎ | ✅︎ | | -| `Blip2ForConditionalGeneration` | BLIP-2 | T + IE | `Salesforce/blip2-opt-2.7b`, `Salesforce/blip2-opt-6.7b`, etc. | ✅︎ | ✅︎ | | -| `ChameleonForConditionalGeneration` | Chameleon | T + I | `facebook/chameleon-7b` etc. | ✅︎ | ✅︎ | | -| `DeepseekVLV2ForCausalLM`^ | DeepSeek-VL2 | T + I+ | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2` etc. | ✅︎ | ✅︎ | | +| `AriaForConditionalGeneration` | Aria | T + I+ | `rhymes-ai/Aria` | | | ✅︎ | +| `AyaVisionForConditionalGeneration` | Aya Vision | T + I+ | `CohereForAI/aya-vision-8b`, `CohereForAI/aya-vision-32b`, etc. 
| | ✅︎ | ✅︎ | +| `Blip2ForConditionalGeneration` | BLIP-2 | T + IE | `Salesforce/blip2-opt-2.7b`, `Salesforce/blip2-opt-6.7b`, etc. | | ✅︎ | ✅︎ | +| `ChameleonForConditionalGeneration` | Chameleon | T + I | `facebook/chameleon-7b` etc. | | ✅︎ | ✅︎ | +| `DeepseekVLV2ForCausalLM`^ | DeepSeek-VL2 | T + I+ | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2` etc. | | ✅︎ | ✅︎ | | `Florence2ForConditionalGeneration` | Florence-2 | T + I | `microsoft/Florence-2-base`, `microsoft/Florence-2-large` etc. | | | | -| `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b` etc. | ✅︎ | ✅︎ | | +| `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b` etc. | | ✅︎ | ✅︎ | | `Gemma3ForConditionalGeneration` | Gemma 3 | T + I+ | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ | | `GLM4VForCausalLM`^ | GLM-4V | T + I | `THUDM/glm-4v-9b`, `THUDM/cogagent-9b-20241220` etc. | ✅︎ | ✅︎ | ✅︎ | | `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ | -| `H2OVLChatModel` | H2OVL | T + IE+ | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | ✅︎ | ✅︎\* | | -| `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3` etc. | ✅︎ | ✅︎ | | -| `InternVLChatModel` | InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + IE+ + (VE+) | `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | | -| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I+ | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | ✅︎ | | | -| `Llama4ForConditionalGeneration` | Llama 4 | T + I+ | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. 
| ✅︎ | ✅︎ | | -| `LlavaForConditionalGeneration` | LLaVA-1.5 | T + IE+ | `llava-hf/llava-1.5-7b-hf`, `TIGER-Lab/Mantis-8B-siglip-llama3` (see note), etc. | ✅︎ | ✅︎ | | -| `LlavaNextForConditionalGeneration` | LLaVA-NeXT | T + IE+ | `llava-hf/llava-v1.6-mistral-7b-hf`, `llava-hf/llava-v1.6-vicuna-7b-hf`, etc. | ✅︎ | ✅︎ | | -| `LlavaNextVideoForConditionalGeneration` | LLaVA-NeXT-Video | T + V | `llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. | ✅︎ | ✅︎ | | -| `LlavaOnevisionForConditionalGeneration` | LLaVA-Onevision | T + I+ + V+ | `llava-hf/llava-onevision-qwen2-7b-ov-hf`, `llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. | ✅︎ | ✅︎ | | +| `H2OVLChatModel` | H2OVL | T + IE+ | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | ✅︎\* | +| `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3` etc. | ✅︎ | | ✅︎ | +| `InternVLChatModel` | InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + IE+ + (VE+) | `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | | ✅︎ | ✅︎ | +| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I+ | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | | ✅︎ | +| `Llama4ForConditionalGeneration` | Llama 4 | T + I+ | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | | ✅︎ | ✅︎ | +| `LlavaForConditionalGeneration` | LLaVA-1.5 | T + IE+ | `llava-hf/llava-1.5-7b-hf`, `TIGER-Lab/Mantis-8B-siglip-llama3` (see note), etc. | | ✅︎ | ✅︎ | +| `LlavaNextForConditionalGeneration` | LLaVA-NeXT | T + IE+ | `llava-hf/llava-v1.6-mistral-7b-hf`, `llava-hf/llava-v1.6-vicuna-7b-hf`, etc. | | ✅︎ | ✅︎ | +| `LlavaNextVideoForConditionalGeneration` | LLaVA-NeXT-Video | T + V | `llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. 
| | ✅︎ | ✅︎ | +| `LlavaOnevisionForConditionalGeneration` | LLaVA-Onevision | T + I+ + V+ | `llava-hf/llava-onevision-qwen2-7b-ov-hf`, `llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. | | ✅︎ | ✅︎ | | `MiniCPMO` | MiniCPM-O | T + IE+ + VE+ + AE+ | `openbmb/MiniCPM-o-2_6`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `MiniCPMV` | MiniCPM-V | T + IE+ + VE+ | `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `MiniMaxVL01ForConditionalGeneration` | MiniMax-VL | T + IE+ | `MiniMaxAI/MiniMax-VL-01`, etc. | ✅︎ | ✅︎ | | +| `MiniCPMV` | MiniCPM-V | T + IE+ + VE+ | `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, etc. | ✅︎ | | ✅︎ | +| `MiniMaxVL01ForConditionalGeneration` | MiniMax-VL | T + IE+ | `MiniMaxAI/MiniMax-VL-01`, etc. | | ✅︎ | | | `Mistral3ForConditionalGeneration` | Mistral3 | T + I+ | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `MllamaForConditionalGeneration` | Llama 3.2 | T + I+ | `meta-llama/Llama-3.2-90B-Vision-Instruct`, `meta-llama/Llama-3.2-11B-Vision`, etc. | | | | +| `MllamaForConditionalGeneration` | Llama 3.2 | T + I+ | `meta-llama/Llama-3.2-90B-Vision-Instruct`, `meta-llama/Llama-3.2-11B-Vision`, etc. | | | | | `MolmoForCausalLM` | Molmo | T + I+ | `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `NVLM_D_Model` | NVLM-D 1.0 | T + I+ | `nvidia/NVLM-D-72B`, etc. | ✅︎ | ✅︎ | | -| `Ovis` | Ovis2, Ovis1.6 | T + I+ | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | ✅︎ | | | -| `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + IE | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | ✅︎ | ⚠️ | | -| `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + IE+ | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. 
| ✅︎ | ✅︎ | | -| `Phi4MMForCausalLM` | Phi-4-multimodal | T + I+ / T + A+ / I+ + A+ | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ | | -| `PixtralForConditionalGeneration` | Pixtral | T + I+ | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, `mistral-community/pixtral-12b`, etc. | ✅︎ | ✅︎ | | +| `NVLM_D_Model` | NVLM-D 1.0 | T + I+ | `nvidia/NVLM-D-72B`, etc. | | ✅︎ | ✅︎ | +| `Ovis` | Ovis2, Ovis1.6 | T + I+ | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | ✅︎ | ✅︎ | +| `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + IE | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | | ✅︎ | ⚠️ | +| `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + IE+ | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ | ✅︎ | +| `Phi4MMForCausalLM` | Phi-4-multimodal | T + I+ / T + A+ / I+ + A+ | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `PixtralForConditionalGeneration` | Pixtral | T + I+ | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, `mistral-community/pixtral-12b`, etc. | | ✅︎ | ✅︎ | | `QwenVLForConditionalGeneration`^ | Qwen-VL | T + IE+ | `Qwen/Qwen-VL`, `Qwen/Qwen-VL-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Qwen2AudioForConditionalGeneration` | Qwen2-Audio | T + A+ | `Qwen/Qwen2-Audio-7B-Instruct` | ✅︎ | ✅︎ | | +| `Qwen2AudioForConditionalGeneration` | Qwen2-Audio | T + A+ | `Qwen/Qwen2-Audio-7B-Instruct` | | ✅︎ | ✅︎ | | `Qwen2VLForConditionalGeneration` | QVQ, Qwen2-VL | T + IE+ + VE+ | `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen2_5_VLForConditionalGeneration` | Qwen2.5-VL | T + IE+ + VE+ | `Qwen/Qwen2.5-VL-3B-Instruct`, `Qwen/Qwen2.5-VL-72B-Instruct`, etc. 
| ✅︎ | ✅︎ | ✅︎ | -| `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + IE+ + VE+ + A+ | `Qwen/Qwen2.5-Omni-7B` | ✅︎ | ✅︎\* | | -| `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | ✅︎ | ✅︎ | | -| `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | ✅︎ | | +| `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + IE+ + VE+ + A+ | `Qwen/Qwen2.5-Omni-7B` | | ✅︎ | ✅︎\* | +| `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | ✅︎ | +| `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | | ✅︎ | ^ You need to set the architecture name via `--hf-overrides` to match the one in vLLM.     • For example, to use DeepSeek-VL2 series models: @@ -647,7 +647,7 @@ The following table lists those that are tested in vLLM. | Architecture | Models | Inputs | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | |-------------------------------------|--------------------|----------|--------------------------|------------------------|-----------------------------| -| `LlavaNextForConditionalGeneration` | LLaVA-NeXT-based | T / I | `royokong/e5-v` | ✅︎ | | +| `LlavaNextForConditionalGeneration` | LLaVA-NeXT-based | T / I | `royokong/e5-v` | | | | `Phi3VForCausalLM` | Phi-3-Vision-based | T + I | `TIGER-Lab/VLM2Vec-Full` | 🚧 | ✅︎ | #### Transcription diff --git a/docs/usage/security.md b/docs/usage/security.md index f1661828d68a4..1209cc8dd4572 100644 --- a/docs/usage/security.md +++ b/docs/usage/security.md @@ -12,14 +12,14 @@ All communications between nodes in a multi-node vLLM deployment are **insecure The following options control inter-node communications in vLLM: -1. **Environment Variables:** +#### 1. **Environment Variables:** - `VLLM_HOST_IP`: Sets the IP address for vLLM processes to communicate on -2. **KV Cache Transfer Configuration:** +#### 2. 
**KV Cache Transfer Configuration:** - `--kv-ip`: The IP address for KV cache transfer communications (default: 127.0.0.1) - `--kv-port`: The port for KV cache transfer communications (default: 14579) -3. **Data Parallel Configuration:** +#### 3. **Data Parallel Configuration:** - `data_parallel_master_ip`: IP of the data parallel master (default: 127.0.0.1) - `data_parallel_master_port`: Port of the data parallel master (default: 29500) @@ -39,16 +39,16 @@ Key points from the PyTorch security guide: ### Security Recommendations -1. **Network Isolation:** +#### 1. **Network Isolation:** - Deploy vLLM nodes on a dedicated, isolated network - Use network segmentation to prevent unauthorized access - Implement appropriate firewall rules -2. **Configuration Best Practices:** +#### 2. **Configuration Best Practices:** - Always set `VLLM_HOST_IP` to a specific IP address rather than using defaults - Configure firewalls to only allow necessary ports between nodes -3. **Access Control:** +#### 3. **Access Control:** - Restrict physical and network access to the deployment environment - Implement proper authentication and authorization for management interfaces - Follow the principle of least privilege for all system components diff --git a/examples/offline_inference/data_parallel.py b/examples/offline_inference/data_parallel.py index bf60d883c410e..15906e1a2768d 100644 --- a/examples/offline_inference/data_parallel.py +++ b/examples/offline_inference/data_parallel.py @@ -97,10 +97,14 @@ def main( # with DP, each rank should process different prompts. # usually all the DP ranks process a full dataset, # and each rank processes a different part of the dataset. - promts_per_rank = len(prompts) // dp_size - start = global_dp_rank * promts_per_rank - end = start + promts_per_rank - prompts = prompts[start:end] + floor = len(prompts) // dp_size + remainder = len(prompts) % dp_size + + # Distribute prompts into even groups. 
+ def start(rank): + return rank * floor + min(rank, remainder) + + prompts = prompts[start(global_dp_rank) : start(global_dp_rank + 1)] if len(prompts) == 0: # if any rank has no prompts to process, # we need to set a placeholder prompt diff --git a/examples/offline_inference/neuron_multimodal.py b/examples/offline_inference/neuron_multimodal.py new file mode 100644 index 0000000000000..a9478650b16f1 --- /dev/null +++ b/examples/offline_inference/neuron_multimodal.py @@ -0,0 +1,105 @@ +# SPDX-License-Identifier: Apache-2.0 +import requests +import torch +from neuronx_distributed_inference.models.mllama.utils import add_instruct +from PIL import Image + +from vllm import LLM, SamplingParams, TextPrompt + + +def get_image(image_url): + image = Image.open(requests.get(image_url, stream=True).raw) + return image + + +# Model Inputs +PROMPTS = [ + "What is in this image? Tell me a story", + "What is the recipe of mayonnaise in two sentences?", + "Describe this image", + "What is the capital of Italy famous for?", +] +IMAGES = [ + get_image( + "https://images.pexels.com/photos/1108099/pexels-photo-1108099.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=500" + ), + None, + get_image( + "https://images.pexels.com/photos/1108099/pexels-photo-1108099.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=500" + ), + None, +] +SAMPLING_PARAMS = [ + dict(top_k=1, temperature=1.0, top_p=1.0, max_tokens=16) + for _ in range(len(PROMPTS)) +] + + +def get_VLLM_mllama_model_inputs(prompt, single_image, sampling_params): + # Prepare all inputs for mllama generation, including: + # 1. put text prompt into instruct chat template + # 2. compose single text and single image prompt into Vllm's prompt class + # 3. 
prepare sampling parameters + input_image = single_image + has_image = torch.tensor([1]) + if isinstance(single_image, torch.Tensor) and single_image.numel() == 0: + has_image = torch.tensor([0]) + + instruct_prompt = add_instruct(prompt, has_image) + inputs = TextPrompt(prompt=instruct_prompt) + + if input_image is not None: + inputs["multi_modal_data"] = {"image": input_image} + + sampling_params = SamplingParams(**sampling_params) + return inputs, sampling_params + + +def print_outputs(outputs): + # Print the outputs. + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + +if __name__ == "__main__": + assert ( + len(PROMPTS) == len(IMAGES) == len(SAMPLING_PARAMS) + ), f"""Text, image prompts and sampling parameters should have the + same batch size; but got {len(PROMPTS)}, {len(IMAGES)}, + and {len(SAMPLING_PARAMS)}""" + + # Create an LLM. + llm = LLM( + model="meta-llama/Llama-3.2-11B-Vision-Instruct", + max_num_seqs=1, + max_model_len=4096, + block_size=4096, + device="neuron", + tensor_parallel_size=32, + override_neuron_config={ + "sequence_parallel_enabled": False, + "skip_warmup": True, + "save_sharded_checkpoint": True, + "on_device_sampling_config": { + "global_topk": 1, + "dynamic": False, + "deterministic": False, + }, + }, + ) + + batched_inputs = [] + batched_sample_params = [] + for pmpt, img, params in zip(PROMPTS, IMAGES, SAMPLING_PARAMS): + inputs, sampling_params = get_VLLM_mllama_model_inputs(pmpt, img, params) + # test batch-size = 1 + outputs = llm.generate(inputs, sampling_params) + print_outputs(outputs) + batched_inputs.append(inputs) + batched_sample_params.append(sampling_params) + + # test batch-size = 4 + outputs = llm.generate(batched_inputs, batched_sample_params) + print_outputs(outputs) diff --git a/pyproject.toml b/pyproject.toml index 5286724b5ca5f..10f5dbeae6851 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,6 
@@ requires = [ "setuptools-scm>=8.0", "torch == 2.7.0", "wheel", - "regex", "jinja2", ] build-backend = "setuptools.build_meta" diff --git a/requirements/common.txt b/requirements/common.txt index 625efc3366f48..de4b3b53166c9 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -14,7 +14,7 @@ protobuf # Required by LlamaTokenizer. fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint. aiohttp openai >= 1.52.0 # Ensure modern openai package (ensure types module present and max_completion_tokens field support) -pydantic >= 2.9 +pydantic >= 2.10 prometheus_client >= 0.18.0 pillow # Required for image processing prometheus-fastapi-instrumentator >= 7.0.0 diff --git a/requirements/test.in b/requirements/test.in index 87af617690388..e906752ff875b 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -51,3 +51,4 @@ numpy runai-model-streamer==0.11.0 runai-model-streamer-s3==0.11.0 fastsafetensors>=0.1.10 +pydantic>=2.10 # 2.9 leads to error on python 3.10 \ No newline at end of file diff --git a/requirements/test.txt b/requirements/test.txt index 89d477017342e..60dcaca816a2b 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -480,12 +480,13 @@ pycparser==2.22 # via cffi pycryptodomex==3.22.0 # via blobfile -pydantic==2.9.2 +pydantic==2.11.5 # via + # -r requirements/test.in # datamodel-code-generator # mistral-common # mteb -pydantic-core==2.23.4 +pydantic-core==2.33.2 # via pydantic pygments==2.18.0 # via rich @@ -784,6 +785,9 @@ typing-extensions==4.12.2 # pydantic-core # torch # typer + # typing-inspection +typing-inspection==0.4.1 + # via pydantic tzdata==2024.2 # via pandas uri-template==1.3.0 diff --git a/requirements/tpu.txt b/requirements/tpu.txt index 3b204a8f99056..edc8b2a456670 100644 --- a/requirements/tpu.txt +++ b/requirements/tpu.txt @@ -18,9 +18,9 @@ setuptools==78.1.0 --find-links https://storage.googleapis.com/libtpu-releases/index.html 
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html -torch==2.8.0.dev20250518 -torchvision==0.22.0.dev20250518 -torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250518-cp39-cp39-linux_x86_64.whl ; python_version == "3.9" -torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250518-cp310-cp310-linux_x86_64.whl ; python_version == "3.10" -torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250518-cp311-cp311-linux_x86_64.whl ; python_version == "3.11" +torch==2.8.0.dev20250529 +torchvision==0.22.0.dev20250529 +torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250529-cp39-cp39-linux_x86_64.whl ; python_version == "3.9" +torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250529-cp310-cp310-linux_x86_64.whl ; python_version == "3.10" +torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250529-cp311-cp311-linux_x86_64.whl ; python_version == "3.11" diff --git a/setup.py b/setup.py index b822a4ec36a4f..c190864dda94e 100644 --- a/setup.py +++ b/setup.py @@ -5,12 +5,12 @@ import importlib.util import json import logging import os +import re import subprocess import sys from pathlib import Path from shutil import which -import regex as re import torch from packaging.version import Version, parse from setuptools import Extension, setup diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 86b5e1e0ab7cf..11c8e7a4b9d1c 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -60,7 
+60,6 @@ def _fix_prompt_embed_outputs( @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("backend", ["FLASH_ATTN"]) -@pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [5]) @pytest.mark.parametrize("enforce_eager", [False]) @pytest.mark.parametrize("enable_prompt_embeds", [True, False]) @@ -69,7 +68,6 @@ def test_models( hf_runner, model: str, backend: str, - dtype: str, max_tokens: int, enforce_eager: bool, enable_prompt_embeds: bool, @@ -97,7 +95,7 @@ def test_models( str(i) for i in range(1024)) + " are:" example_prompts = [prompt] - with hf_runner(model, dtype=dtype) as hf_model: + with hf_runner(model) as hf_model: hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) if enable_prompt_embeds: with torch.no_grad(): @@ -106,7 +104,6 @@ def test_models( with VllmRunner(model, max_model_len=8192, - dtype=dtype, enforce_eager=enforce_eager, enable_prompt_embeds=enable_prompt_embeds, gpu_memory_utilization=0.7) as vllm_model: diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index 143cb49697f5b..5ce520a440257 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -74,11 +74,12 @@ class SillyModel(nn.Module): return x -def test_simple_piecewise_compile(): +def _test_simple_piecewise_compile(*, use_inductor): vllm_config = VllmConfig(compilation_config=CompilationConfig( level=CompilationLevel.PIECEWISE, use_cudagraph=True, + use_inductor=use_inductor, splitting_ops=["silly.attention"], cudagraph_copy_inputs=True, cudagraph_capture_sizes=[1, 2], @@ -108,3 +109,11 @@ def test_simple_piecewise_compile(): output = model(input) assert global_counter == 2 assert torch.allclose(output.cpu(), torch.tensor([3., 1.])) + + +def test_simple_piecewise_compile_inductor(): + _test_simple_piecewise_compile(use_inductor=True) + + +def test_simple_piecewise_compile_no_inductor(): + _test_simple_piecewise_compile(use_inductor=False) diff 
--git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index d4551b1cc3aec..22560befcbd56 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -261,12 +261,14 @@ def tractable_computation(input_ids: torch.Tensor, @torch.inference_mode def run_model(llama_config, use_compile: bool, + use_inductor: bool, split_attn: bool = False) -> torch.Tensor: if use_compile: compilation_config = CompilationConfig( level=CompilationLevel.PIECEWISE, use_cudagraph=True, + use_inductor=use_inductor, cudagraph_capture_sizes=[1, 2], ) if split_attn: @@ -304,7 +306,7 @@ def run_model(llama_config, return output.cpu() -def test_toy_llama(): +def _test_toy_llama(*, use_inductor): # compare output with and without piecewise compilation llama_config = LlamaConfig(hidden_size=128, @@ -326,8 +328,14 @@ def test_toy_llama(): num_backend_compilations=0, num_cudagraph_caputured=0, ): - outputs.append(run_model(llama_config, use_compile=False)) - run_model(tractable_config, use_compile=False) + outputs.append( + run_model(llama_config, use_inductor=False, use_compile=False)) + run_model(tractable_config, use_inductor=False, use_compile=False) + + if use_inductor: + kwargs = {"num_inductor_compiles": 1, "num_eager_compiles": 0} + else: + kwargs = {"num_eager_compiles": 1, "num_inductor_compiles": 0} with compilation_counter.expect( num_graphs_seen=1, # one graph for the model @@ -336,9 +344,13 @@ def test_toy_llama(): num_backend_compilations=1, # num_piecewise_capturable_graphs_seen num_cudagraph_caputured= 2, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + **kwargs, ): - outputs.append(run_model(llama_config, use_compile=True)) - run_model(tractable_config, use_compile=True) + outputs.append( + run_model(llama_config, + use_inductor=use_inductor, + use_compile=True)) + run_model(tractable_config, use_inductor=use_inductor, use_compile=True) with compilation_counter.expect( 
num_graphs_seen=1, # one graph for the model @@ -353,13 +365,27 @@ def test_toy_llama(): ), # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen ): outputs.append( - run_model(llama_config, use_compile=True, split_attn=True)) - run_model(tractable_config, use_compile=True, split_attn=True) + run_model(llama_config, + use_inductor=use_inductor, + use_compile=True, + split_attn=True)) + run_model(tractable_config, + use_inductor=use_inductor, + use_compile=True, + split_attn=True) for i in range(1, len(outputs)): assert torch.allclose(outputs[0], outputs[i]) +def test_toy_llama_inductor(): + _test_toy_llama(use_inductor=True) + + +def test_toy_no_inductor(): + _test_toy_llama(use_inductor=False) + + @torch.inference_mode def benchmark(): from triton.testing import do_bench diff --git a/tests/conftest.py b/tests/conftest.py index 19c2c62471295..6336c6c2ce011 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -311,6 +311,7 @@ class HfRunner: dtype: str = "auto", *, model_kwargs: Optional[dict[str, Any]] = None, + trust_remote_code: bool = True, is_sentence_transformer: bool = False, is_cross_encoder: bool = False, skip_tokenizer_init: bool = False, @@ -320,10 +321,15 @@ class HfRunner: self.config = AutoConfig.from_pretrained( model_name, - trust_remote_code=True, + trust_remote_code=trust_remote_code, ) self.device = self.get_default_device() - self.dtype = torch_dtype = _get_and_verify_dtype(self.config, dtype) + self.dtype = torch_dtype = _get_and_verify_dtype( + self.model_name, + self.config, + dtype=dtype, + is_pooling_model=is_sentence_transformer or is_cross_encoder, + ) model_kwargs = model_kwargs if model_kwargs is not None else {} model_kwargs.setdefault("torch_dtype", torch_dtype) @@ -336,7 +342,7 @@ class HfRunner: model_name, device=self.device, model_kwargs=model_kwargs, - trust_remote_code=True, + trust_remote_code=trust_remote_code, ) elif is_cross_encoder: # Lazy init required for AMD CI @@ -346,12 +352,12 @@ class HfRunner: model_name, 
device=self.device, automodel_args=model_kwargs, - trust_remote_code=True, + trust_remote_code=trust_remote_code, ) else: model = auto_cls.from_pretrained( model_name, - trust_remote_code=True, + trust_remote_code=trust_remote_code, **model_kwargs, ) @@ -372,7 +378,7 @@ class HfRunner: self.tokenizer = AutoTokenizer.from_pretrained( model_name, torch_dtype=torch_dtype, - trust_remote_code=True, + trust_remote_code=trust_remote_code, ) # don't put this import at the top level @@ -381,7 +387,7 @@ class HfRunner: self.processor = AutoProcessor.from_pretrained( model_name, torch_dtype=torch_dtype, - trust_remote_code=True, + trust_remote_code=trust_remote_code, ) if skip_tokenizer_init: self.tokenizer = self.processor.tokenizer diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 5346d67b10d16..e6410ab068d23 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -227,6 +227,7 @@ MULTIMODAL_MODELS = { "llava-hf/llava-onevision-qwen2-0.5b-ov-hf": PPTestSettings.fast(), "openbmb/MiniCPM-Llama3-V-2_5": PPTestSettings.fast(), "allenai/Molmo-7B-D-0924": PPTestSettings.fast(), + "AIDC-AI/Ovis2-1B": PPTestSettings.fast(), "microsoft/Phi-3.5-vision-instruct": PPTestSettings.fast(), "mistralai/Pixtral-12B-2409": PPTestSettings.fast(load_format="dummy"), "Qwen/Qwen-VL-Chat": PPTestSettings.fast(), diff --git a/tests/entrypoints/llm/test_init.py b/tests/entrypoints/llm/test_init.py deleted file mode 100644 index 925bf56a93402..0000000000000 --- a/tests/entrypoints/llm/test_init.py +++ /dev/null @@ -1,24 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -import pytest - -from vllm import LLM - -from ...utils import error_on_warning - -MODEL_NAME = "facebook/opt-125m" - - -def test_pos_args_deprecated(): - with error_on_warning(DeprecationWarning): - LLM(model=MODEL_NAME, tokenizer=MODEL_NAME) - - with error_on_warning(DeprecationWarning): - LLM(MODEL_NAME, 
tokenizer=MODEL_NAME) - - with pytest.warns(DeprecationWarning, match="'tokenizer'"): - LLM(MODEL_NAME, MODEL_NAME) - - with pytest.warns(DeprecationWarning, - match="'tokenizer', 'tokenizer_mode'"): - LLM(MODEL_NAME, MODEL_NAME, "auto") diff --git a/tests/entrypoints/openai/test_run_batch.py b/tests/entrypoints/openai/test_run_batch.py index 27802945a2164..99639ce51aa74 100644 --- a/tests/entrypoints/openai/test_run_batch.py +++ b/tests/entrypoints/openai/test_run_batch.py @@ -4,6 +4,8 @@ import json import subprocess import tempfile +import pytest + from vllm.entrypoints.openai.protocol import BatchRequestOutput # ruff: noqa: E501 @@ -23,9 +25,13 @@ INPUT_EMBEDDING_BATCH = """{"custom_id": "request-1", "method": "POST", "url": " {"custom_id": "request-3", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/multilingual-e5-small", "input": "Hello world!"}} {"custom_id": "request-4", "method": "POST", "url": "/v1/embeddings", "body": {"model": "NonExistModel", "input": "Hello world!"}}""" -INPUT_SCORE_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/v1/score", "body": {"model": "BAAI/bge-reranker-v2-m3", "text_1": "What is the capital of France?", "text_2": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}} +INPUT_SCORE_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/score", "body": {"model": "BAAI/bge-reranker-v2-m3", "text_1": "What is the capital of France?", "text_2": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}} {"custom_id": "request-2", "method": "POST", "url": "/v1/score", "body": {"model": "BAAI/bge-reranker-v2-m3", "text_1": "What is the capital of France?", "text_2": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}}""" +INPUT_RERANK_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/rerank", "body": {"model": "BAAI/bge-reranker-v2-m3", "query": "What is the capital of France?", "documents": ["The capital of 
Brazil is Brasilia.", "The capital of France is Paris."]}} +{"custom_id": "request-2", "method": "POST", "url": "/v1/rerank", "body": {"model": "BAAI/bge-reranker-v2-m3", "query": "What is the capital of France?", "documents": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}} +{"custom_id": "request-2", "method": "POST", "url": "/v2/rerank", "body": {"model": "BAAI/bge-reranker-v2-m3", "query": "What is the capital of France?", "documents": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}}""" + def test_empty_file(): with tempfile.NamedTemporaryFile( @@ -105,11 +111,13 @@ def test_embeddings(): BatchRequestOutput.model_validate_json(line) -def test_score(): +@pytest.mark.parametrize("input_batch", + [INPUT_SCORE_BATCH, INPUT_RERANK_BATCH]) +def test_score(input_batch): with tempfile.NamedTemporaryFile( "w") as input_file, tempfile.NamedTemporaryFile( "r") as output_file: - input_file.write(INPUT_SCORE_BATCH) + input_file.write(input_batch) input_file.flush() proc = subprocess.Popen([ "vllm", diff --git a/tests/entrypoints/openai/test_tokenization.py b/tests/entrypoints/openai/test_tokenization.py index 9773f3e45b99c..7d823542e3744 100644 --- a/tests/entrypoints/openai/test_tokenization.py +++ b/tests/entrypoints/openai/test_tokenization.py @@ -76,11 +76,11 @@ async def test_tokenize_completions( }) response.raise_for_status() - assert response.json() == { - "tokens": tokens, - "count": len(tokens), - "max_model_len": 8192 - } + result = response.json() + assert result["tokens"] == tokens + assert result["count"] == len(tokens) + assert result["max_model_len"] == 8192 + assert result["token_strs"] is None @pytest.mark.asyncio @@ -138,11 +138,11 @@ async def test_tokenize_chat( }) response.raise_for_status() - assert response.json() == { - "tokens": tokens, - "count": len(tokens), - "max_model_len": 8192 - } + result = response.json() + assert result["tokens"] == tokens + assert result["count"] == len(tokens) + 
assert result["max_model_len"] == 8192 + assert result["token_strs"] is None @pytest.mark.asyncio @@ -215,11 +215,46 @@ async def test_tokenize_chat_with_tools( ) response.raise_for_status() - assert response.json() == { - "tokens": tokens, - "count": len(tokens), - "max_model_len": 8192, - } + result = response.json() + assert result["tokens"] == tokens + assert result["count"] == len(tokens) + assert result["max_model_len"] == 8192 + assert result["token_strs"] is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name, tokenizer_name", + [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")], + indirect=["tokenizer_name"], +) +async def test_tokenize_with_return_token_strs( + server: RemoteOpenAIServer, + model_name: str, + tokenizer_name: str, +): + tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, + tokenizer_mode="fast") + + prompt = "This is a token_strs test prompt! vllm1" + response = requests.post( + server.url_for("tokenize"), + json={ + "prompt": prompt, + "model": model_name, + "return_token_strs": True + }, + ) + response.raise_for_status() + + tokens = tokenizer.encode(prompt, add_special_tokens=True) + tokens_str = tokenizer.convert_ids_to_tokens(tokens) + + result = response.json() + assert result["tokens"] == tokens + assert result["count"] == len(tokens) + assert result["max_model_len"] == 8192 + assert result["token_strs"] == tokens_str @pytest.mark.asyncio diff --git a/tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py index 92ba1376e2002..f5f327ea068c6 100644 --- a/tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest @@ -191,3 +191,27 @@ def 
test_streaming_tool_call_with_large_steps(): assert reconstructor.tool_calls[0].function == SIMPLE_FUNCTION_CALL assert reconstructor.tool_calls[1].function == PARAMETERLESS_FUNCTION_CALL assert reconstructor.tool_calls[2].function == EMPTY_LIST_FUNCTION_CALL + + +@pytest.mark.parametrize("streaming", [False]) +def test_regex_timeout_handling(streaming: bool): + """test regex timeout is handled gracefully""" + mock_tokenizer = MagicMock() + tool_parser: ToolParser = ToolParserManager.get_tool_parser( + "llama4_pythonic")(mock_tokenizer) + + fake_problematic_input = "hello world[A(A=" + "\t)A(A=,\t" * 2 + + # create a mock regex that raises TimeoutError + mock_regex = MagicMock() + mock_regex.match.side_effect = TimeoutError("Regex timeout") + + with patch.object(tool_parser, 'TOOL_CALL_REGEX', mock_regex): + content, tool_calls = run_tool_extraction(tool_parser, + fake_problematic_input, + streaming=streaming) + + # should treat as regular text when regex times out + assert content == fake_problematic_input + assert len(tool_calls) == 0 + mock_regex.match.assert_called_once() diff --git a/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py index fbbbc1fb2a596..71f41ea7d93b4 100644 --- a/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest @@ -159,3 +159,27 @@ def test_streaming_tool_call_with_large_steps(): assert reconstructor.tool_calls[0].function == SIMPLE_FUNCTION_CALL assert reconstructor.tool_calls[1].function == PARAMETERLESS_FUNCTION_CALL assert reconstructor.tool_calls[2].function == EMPTY_LIST_FUNCTION_CALL + + +@pytest.mark.parametrize("streaming", [False]) +def test_regex_timeout_handling(streaming: bool): + """test regex timeout is 
handled gracefully"""
+    mock_tokenizer = MagicMock()
+    tool_parser: ToolParser = ToolParserManager.get_tool_parser(
+        "pythonic")(mock_tokenizer)
+
+    fake_problematic_input = "hello world[A(A=" + "\t)A(A=,\t" * 2
+
+    # create a mock regex that raises TimeoutError
+    mock_regex = MagicMock()
+    mock_regex.match.side_effect = TimeoutError("Regex timeout")
+
+    with patch.object(tool_parser, 'TOOL_CALL_REGEX', mock_regex):
+        content, tool_calls = run_tool_extraction(tool_parser,
+                                                  fake_problematic_input,
+                                                  streaming=streaming)
+
+    # should treat as regular text when regex times out
+    assert content == fake_problematic_input
+    assert len(tool_calls) == 0
+    mock_regex.match.assert_called_once()
diff --git a/tests/entrypoints/test_api_server_process_manager.py b/tests/entrypoints/test_api_server_process_manager.py
new file mode 100644
index 0000000000000..0dd1fdd996948
--- /dev/null
+++ b/tests/entrypoints/test_api_server_process_manager.py
@@ -0,0 +1,268 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import multiprocessing
+import socket
+import threading
+import time
+from typing import Optional
+from unittest.mock import patch
+
+import pytest
+
+from vllm.v1.utils import (APIServerProcessManager,
+                           wait_for_completion_or_failure)
+
+# Global variables to control worker behavior
+WORKER_RUNTIME_SECONDS = 0.5
+
+
+# Mock implementation of run_api_server_worker
+def mock_run_api_server_worker(listen_address, sock, args, client_config=None):
+    """Mock run_api_server_worker that runs for a specific time."""
+    print(f"Mock worker started with client_config: {client_config}")
+    time.sleep(WORKER_RUNTIME_SECONDS)
+    print("Mock worker completed successfully")
+
+
+@pytest.fixture
+def api_server_args():
+    """Fixture to provide arguments for APIServerProcessManager."""
+    sock = socket.socket()
+    return {
+        "target_server_fn":
+        mock_run_api_server_worker,
+        "listen_address":
+        "localhost:8000",
+        "sock":
+        sock,
+        "args":
+        "test_args",  # Simple string to avoid pickling 
issues + "num_servers": + 3, + "input_addresses": [ + "tcp://127.0.0.1:5001", "tcp://127.0.0.1:5002", + "tcp://127.0.0.1:5003" + ], + "output_addresses": [ + "tcp://127.0.0.1:6001", "tcp://127.0.0.1:6002", + "tcp://127.0.0.1:6003" + ], + "stats_update_address": + "tcp://127.0.0.1:7000", + } + + +@pytest.mark.parametrize("with_stats_update", [True, False]) +def test_api_server_process_manager_init(api_server_args, with_stats_update): + """Test initializing the APIServerProcessManager.""" + # Set the worker runtime to ensure tests complete in reasonable time + global WORKER_RUNTIME_SECONDS + WORKER_RUNTIME_SECONDS = 0.5 + + # Copy the args to avoid mutating the + args = api_server_args.copy() + + if not with_stats_update: + args.pop("stats_update_address") + manager = APIServerProcessManager(**args) + + try: + # Verify the manager was initialized correctly + assert len(manager.processes) == 3 + + # Verify all processes are running + for proc in manager.processes: + assert proc.is_alive() + + print("Waiting for processes to run...") + time.sleep(WORKER_RUNTIME_SECONDS / 2) + + # They should still be alive at this point + for proc in manager.processes: + assert proc.is_alive() + + finally: + # Always clean up the processes + print("Cleaning up processes...") + manager.close() + + # Give processes time to terminate + time.sleep(0.2) + + # Verify all processes were terminated + for proc in manager.processes: + assert not proc.is_alive() + + +@patch("vllm.entrypoints.cli.serve.run_api_server_worker", + mock_run_api_server_worker) +def test_wait_for_completion_or_failure(api_server_args): + """Test that wait_for_completion_or_failure works with failures.""" + global WORKER_RUNTIME_SECONDS + WORKER_RUNTIME_SECONDS = 1.0 + + # Create the manager + manager = APIServerProcessManager(**api_server_args) + + try: + assert len(manager.processes) == 3 + + # Create a result capture for the thread + result: dict[str, Optional[Exception]] = {"exception": None} + + def 
run_with_exception_capture(): + try: + wait_for_completion_or_failure(api_server_manager=manager) + except Exception as e: + result["exception"] = e + + # Start a thread to run wait_for_completion_or_failure + wait_thread = threading.Thread(target=run_with_exception_capture, + daemon=True) + wait_thread.start() + + # Let all processes run for a short time + time.sleep(0.2) + + # All processes should still be running + assert all(proc.is_alive() for proc in manager.processes) + + # Now simulate a process failure + print("Simulating process failure...") + manager.processes[0].terminate() + + # Wait for the wait_for_completion_or_failure + # to detect and handle the failure + # This should trigger it to terminate all other processes + wait_thread.join(timeout=1.0) + + # The wait thread should have exited + assert not wait_thread.is_alive() + + # Verify that an exception was raised with appropriate error message + assert result["exception"] is not None + assert "died with exit code" in str(result["exception"]) + + # All processes should now be terminated + for i, proc in enumerate(manager.processes): + assert not proc.is_alive(), f"Process {i} should not be alive" + + finally: + manager.close() + time.sleep(0.2) + + +@pytest.mark.timeout(30) +def test_normal_completion(api_server_args): + """Test that wait_for_completion_or_failure works in normal completion.""" + global WORKER_RUNTIME_SECONDS + WORKER_RUNTIME_SECONDS = 0.1 + + # Create the manager + manager = APIServerProcessManager(**api_server_args) + + try: + # Give processes time to terminate + # wait for processes to complete + remaining_processes = manager.processes.copy() + while remaining_processes: + for proc in remaining_processes: + if not proc.is_alive(): + remaining_processes.remove(proc) + time.sleep(0.1) + + # Verify all processes have terminated + for i, proc in enumerate(manager.processes): + assert not proc.is_alive( + ), f"Process {i} still alive after terminate()" + + # Now call 
wait_for_completion_or_failure + # since all processes have already + # terminated, it should return immediately + # with no error + wait_for_completion_or_failure(api_server_manager=manager) + + finally: + # Clean up just in case + manager.close() + time.sleep(0.2) + + +@pytest.mark.timeout(30) +def test_external_process_monitoring(api_server_args): + """Test that wait_for_completion_or_failure handles additional processes.""" + global WORKER_RUNTIME_SECONDS + WORKER_RUNTIME_SECONDS = 100 + + # Create and start the external process + # (simulates local_engine_manager or coordinator) + spawn_context = multiprocessing.get_context("spawn") + external_proc = spawn_context.Process(target=mock_run_api_server_worker, + name="MockExternalProcess") + external_proc.start() + + # Create the class to simulate a coordinator + class MockCoordinator: + + def __init__(self, proc): + self.proc = proc + + def close(self): + if self.proc.is_alive(): + self.proc.terminate() + self.proc.join(timeout=0.5) + + # Create a mock coordinator with the external process + mock_coordinator = MockCoordinator(external_proc) + + # Create the API server manager + manager = APIServerProcessManager(**api_server_args) + + try: + # Verify manager initialization + assert len(manager.processes) == 3 + + # Create a result capture for the thread + result: dict[str, Optional[Exception]] = {"exception": None} + + def run_with_exception_capture(): + try: + wait_for_completion_or_failure(api_server_manager=manager, + coordinator=mock_coordinator) + except Exception as e: + result["exception"] = e + + # Start a thread to run wait_for_completion_or_failure + wait_thread = threading.Thread(target=run_with_exception_capture, + daemon=True) + wait_thread.start() + + # Terminate the external process to trigger a failure + time.sleep(0.2) + external_proc.terminate() + + # Wait for the thread to detect the failure + wait_thread.join(timeout=1.0) + + # The wait thread should have completed + assert not 
wait_thread.is_alive( + ), "wait_for_completion_or_failure thread still running" + + # Verify that an exception was raised with appropriate error message + assert result["exception"] is not None, "No exception was raised" + error_message = str(result["exception"]) + assert "died with exit code" in error_message, \ + f"Unexpected error message: {error_message}" + assert "MockExternalProcess" in error_message, \ + f"Error doesn't mention external process: {error_message}" + + # Verify that all API server processes were terminated as a result + for i, proc in enumerate(manager.processes): + assert not proc.is_alive( + ), f"API server process {i} was not terminated" + + finally: + # Clean up + manager.close() + mock_coordinator.close() + time.sleep(0.2) diff --git a/tests/kernels/core/test_pos_encoding.py b/tests/kernels/core/test_pos_encoding.py index f327deb0e549e..8cb56314cf94a 100644 --- a/tests/kernels/core/test_pos_encoding.py +++ b/tests/kernels/core/test_pos_encoding.py @@ -70,7 +70,7 @@ def test_rotary_embedding( device: str, use_key: bool, max_position: int = 8192, - base: int = 10000, + base: float = 10000, ) -> None: if rotary_dim is None: rotary_dim = head_size @@ -135,7 +135,7 @@ def test_batched_rotary_embedding( device: str, use_key: bool, max_position: int = 8192, - base: int = 10000, + base: float = 10000, ) -> None: current_platform.seed_everything(seed) torch.set_default_device(device) @@ -203,7 +203,7 @@ def test_batched_rotary_embedding_multi_lora( device: str, use_key: bool, max_position: int = 8192, - base: int = 10000, + base: float = 10000, ) -> None: current_platform.seed_everything(seed) torch.set_default_device(device) diff --git a/tests/kernels/moe/test_rocm_aiter_topk.py b/tests/kernels/moe/test_rocm_aiter_topk.py index b0d34ddfd4234..922fd66dbef49 100644 --- a/tests/kernels/moe/test_rocm_aiter_topk.py +++ b/tests/kernels/moe/test_rocm_aiter_topk.py @@ -35,6 +35,15 @@ def test_rocm_aiter_biased_grouped_topk_custom_op_registration(): 
assert callable(torch.ops.vllm.rocm_aiter_biased_grouped_topk) +def test_rocm_aiter_grouped_topk_custom_op_registration(): + """Test that the custom op is correctly registered.""" + # Check if the op exists in torch.ops.vllm + assert hasattr(torch.ops.vllm, 'rocm_aiter_grouped_topk') + + # Check if the op is callable + assert callable(torch.ops.vllm.rocm_aiter_grouped_topk) + + def test_rocm_aiter_biased_grouped_topk_torch_compile_compatibility(): """Test that the op can be used with torch.compile.""" # Create test tensors @@ -120,3 +129,87 @@ def test_rocm_aiter_biased_grouped_topk_torch_compile_compatibility(): rtol=1e-2, atol=1e-2) assert torch.allclose(topk_ids_original, topk_ids_compiled) + + +def test_rocm_aiter_grouped_topk_torch_compile_compatibility(): + """Test that the op can be used with torch.compile.""" + # Create test tensors + token = 64 + expert = 256 + num_expert_group = 8 + topk = 8 + topk_group = 4 + renormalize = True + scoring_func = "softmax" + scale_factor = 1.0 + + gating_output = torch.randn((token, expert), + dtype=torch.bfloat16, + device="cuda") + + device = gating_output.device + topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device) + topk_weights = torch.empty((token, topk), + dtype=torch.float32, + device=device) + + # Define a function that uses the op + def grouped_topk_fn(gating_output, topk_weights, topk_ids, scoring_func): + return torch.ops.vllm.rocm_aiter_grouped_topk( + gating_output, topk_weights, topk_ids, num_expert_group, + topk_group, renormalize, scoring_func, scale_factor) + + # Verify the op's fake implementation + torch.library.opcheck(torch.ops.vllm.rocm_aiter_grouped_topk, + (gating_output, topk_weights, topk_ids), + kwargs={ + "num_expert_group": num_expert_group, + "topk_group": topk_group, + "need_renorm": renormalize, + "scoring_func": scoring_func, + "routed_scaling_factor": scale_factor + }, + test_utils=("test_faketensor")) + + # Compile the function with appropriate settings + compiled_fn 
= torch.compile(grouped_topk_fn, + fullgraph=True, + backend="inductor", + mode="reduce-overhead", + dynamic=False) + + topk_weights_original = torch.empty((token, topk), + dtype=torch.float32, + device=device) + topk_ids_original = torch.empty((token, topk), + dtype=torch.int32, + device=device) + + topk_weights_compiled = torch.empty((token, topk), + dtype=torch.float32, + device=device) + topk_ids_compiled = torch.empty((token, topk), + dtype=torch.int32, + device=device) + + # Run both compiled (V1 graph mode) and uncompiled versions (V1 eager mode) + grouped_topk_fn(gating_output, topk_weights_original, topk_ids_original, + scoring_func) + compiled_fn(gating_output, topk_weights_compiled, topk_ids_compiled, + scoring_func) + + # Sort the results for comparison since the order might not be deterministic + topk_ids_original, indices_original = torch.sort(topk_ids_original) + topk_weights_original = torch.gather(topk_weights_original, 1, + indices_original) + + topk_ids_compiled, indices_compiled = torch.sort(topk_ids_compiled) + topk_weights_compiled = torch.gather(topk_weights_compiled, 1, + indices_compiled) + + # Verify results match + assert torch.allclose(topk_weights_original, + topk_weights_compiled, + rtol=1e-2, + atol=1e-2) + assert torch.allclose(topk_ids_original, topk_ids_compiled) diff --git a/tests/kernels/quant_utils.py b/tests/kernels/quant_utils.py index 764924f26783d..892309a017e43 100644 --- a/tests/kernels/quant_utils.py +++ b/tests/kernels/quant_utils.py @@ -8,7 +8,7 @@ from vllm.platforms import current_platform # Using the default value (240.0) from pytorch will cause accuracy # issue on dynamic quantization models. Here use 224.0 for rocm. 
-ROCM_FP8_MAX = 224.0 +ROCM_FP8FNUZ_MAX = 224.0 FP8_DTYPE = current_platform.fp8_dtype() @@ -26,9 +26,11 @@ def ref_dynamic_per_token_quant(x: torch.tensor, qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \ else torch.finfo(quant_dtype) - qtype_traits_max = ROCM_FP8_MAX if current_platform.is_rocm() \ + qtype_traits_max = ROCM_FP8FNUZ_MAX if current_platform.is_rocm() \ + and current_platform.is_fp8_fnuz() \ else qtype_traits.max - qtype_traits_min = -ROCM_FP8_MAX if current_platform.is_rocm() \ + qtype_traits_min = -ROCM_FP8FNUZ_MAX if current_platform.is_rocm() \ + and current_platform.is_fp8_fnuz() \ else qtype_traits.min qtype_max = as_float32_tensor(qtype_traits_max) s_1 = as_float32_tensor(1.0) @@ -70,9 +72,11 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \ -> tuple[torch.tensor, torch.tensor]: fp8_traits = torch.finfo(FP8_DTYPE) - fp8_traits_max = ROCM_FP8_MAX if current_platform.is_rocm() \ + fp8_traits_max = ROCM_FP8FNUZ_MAX if current_platform.is_rocm() \ + and current_platform.is_fp8_fnuz() \ else fp8_traits.max - fp8_traits_min = -ROCM_FP8_MAX if current_platform.is_rocm() \ + fp8_traits_min = -ROCM_FP8FNUZ_MAX if current_platform.is_rocm() \ + and current_platform.is_fp8_fnuz() \ else fp8_traits.min fp8_max = as_float32_tensor(fp8_traits_max) one = as_float32_tensor(1.0) diff --git a/tests/models/language/pooling/mteb_utils.py b/tests/models/language/pooling/mteb_utils.py index f4837ae952c3f..f45168bc0f1d6 100644 --- a/tests/models/language/pooling/mteb_utils.py +++ b/tests/models/language/pooling/mteb_utils.py @@ -102,21 +102,18 @@ def mteb_test_embed_models(hf_runner, vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model), MTEB_EMBED_TASKS) vllm_dtype = vllm_model.model.llm_engine.model_config.dtype - model_dtype = getattr( - vllm_model.model.llm_engine.model_config.hf_config, "torch_dtype", - vllm_dtype) - with set_default_torch_dtype(model_dtype) and hf_runner( + with set_default_torch_dtype(vllm_dtype) 
and hf_runner( model_info.name, is_sentence_transformer=True, - dtype=model_dtype) as hf_model: + dtype=vllm_dtype) as hf_model: if hf_model_callback is not None: hf_model_callback(hf_model) st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS) - print("VLLM:", vllm_dtype, vllm_main_score) - print("SentenceTransformer:", model_dtype, st_main_score) + print("VLLM:", vllm_main_score) + print("SentenceTransformers:", st_main_score) print("Difference:", st_main_score - vllm_main_score) assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_EMBED_TOL) diff --git a/tests/models/language/pooling/test_classification.py b/tests/models/language/pooling/test_classification.py index 44af3df08a867..57b3cb58d88ba 100644 --- a/tests/models/language/pooling/test_classification.py +++ b/tests/models/language/pooling/test_classification.py @@ -43,6 +43,6 @@ def test_models( # the tolerance value of 1e-2 is selected based on the # half datatype tests in - # tests/models/embedding/language/test_embedding.py + # tests/models/language/pooling/test_embedding.py assert torch.allclose(hf_output, vllm_output, 1e-3 if dtype == "float" else 1e-2) diff --git a/tests/models/language/pooling/test_embedding.py b/tests/models/language/pooling/test_embedding.py index a44b2154b1376..8f82c8091af37 100644 --- a/tests/models/language/pooling/test_embedding.py +++ b/tests/models/language/pooling/test_embedding.py @@ -10,29 +10,31 @@ from ...utils import check_embeddings_close @pytest.mark.parametrize( "model", [ - # [Encoder-only] - pytest.param("BAAI/bge-base-en-v1.5", - marks=[pytest.mark.core_model, pytest.mark.cpu_model]), - pytest.param("sentence-transformers/all-MiniLM-L12-v2"), - pytest.param("intfloat/multilingual-e5-small"), - pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct"), + # Be careful of the order of models, decoder-only models should be + # placed before encoder-only models, otherwise `Qwen2.5-0.5B-Instruct` + # case won't pass because gte-Qwen2-1.5B-instruct will 
cache custom + # model code with bidirectional attention. # [Decoder-only] pytest.param("BAAI/bge-multilingual-gemma2", marks=[pytest.mark.core_model]), pytest.param("intfloat/e5-mistral-7b-instruct", marks=[pytest.mark.core_model, pytest.mark.cpu_model]), pytest.param("ssmits/Qwen2-7B-Instruct-embed-base"), + # [Encoder-only] + pytest.param("BAAI/bge-base-en-v1.5", + marks=[pytest.mark.core_model, pytest.mark.cpu_model]), + pytest.param("sentence-transformers/all-MiniLM-L12-v2"), + pytest.param("intfloat/multilingual-e5-small"), + pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct"), # [Cross-Encoder] pytest.param("sentence-transformers/stsb-roberta-base-v2"), ], ) -@pytest.mark.parametrize("dtype", ["half"]) def test_models( hf_runner, vllm_runner, example_prompts, model, - dtype: str, monkeypatch, ) -> None: @@ -44,7 +46,7 @@ def test_models( vllm_extra_kwargs = {} if model == "ssmits/Qwen2-7B-Instruct-embed-base": vllm_extra_kwargs["override_pooler_config"] = \ - PoolerConfig(pooling_type="MEAN") + PoolerConfig(pooling_type="MEAN", normalize=False) # The example_prompts has ending "\n", for example: # "Write a short story about a robot that dreams for the first time.\n" @@ -54,13 +56,11 @@ def test_models( # So we need to strip the input texts to avoid test failing. 
example_prompts = [str(s).strip() for s in example_prompts] - with hf_runner(model, dtype=dtype, - is_sentence_transformer=True) as hf_model: + with hf_runner(model, is_sentence_transformer=True) as hf_model: hf_outputs = hf_model.encode(example_prompts) with vllm_runner(model, task="embed", - dtype=dtype, max_model_len=None, **vllm_extra_kwargs) as vllm_model: vllm_outputs = vllm_model.encode(example_prompts) diff --git a/tests/models/language/pooling/test_gte.py b/tests/models/language/pooling/test_gte.py index 18b27a688146d..725e3d168408b 100644 --- a/tests/models/language/pooling/test_gte.py +++ b/tests/models/language/pooling/test_gte.py @@ -45,6 +45,7 @@ MODELS = [ ########### Qwen2ForCausalLM EmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct", architecture="Qwen2ForCausalLM", + dtype="float32", enable_test=True), ########## ModernBertModel EmbedModelInfo("Alibaba-NLP/gte-modernbert-base", diff --git a/tests/models/multimodal/generation/test_whisper.py b/tests/models/multimodal/generation/test_whisper.py index 4e48bdbd04289..d0b85842a3d8f 100644 --- a/tests/models/multimodal/generation/test_whisper.py +++ b/tests/models/multimodal/generation/test_whisper.py @@ -100,6 +100,7 @@ def run_test( with vllm_runner( model, + dtype="half", max_model_len=448, tensor_parallel_size=tensor_parallel_size, distributed_executor_backend=distributed_executor_backend, diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 572fa366d3325..d7f950c23d954 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -40,7 +40,7 @@ def _test_processing_correctness( tokenizer_mode=model_info.tokenizer_mode, trust_remote_code=model_info.trust_remote_code, seed=0, - dtype="float16", + dtype="auto", revision=None, hf_overrides=model_info.hf_overrides, ) diff --git a/tests/models/registry.py b/tests/models/registry.py index 18342b671e0d8..fe49d2427c744 100644 --- 
a/tests/models/registry.py +++ b/tests/models/registry.py @@ -434,6 +434,11 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = { trust_remote_code=True, speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", tokenizer="meta-llama/Llama-3.1-8B-Instruct"), + "EagleMiniCPMForCausalLM": _HfExamplesInfo("openbmb/MiniCPM-1B-sft-bf16", + trust_remote_code=True, + is_available_online=False, + speculative_model="openbmb/MiniCPM-2B-sft-bf16", + tokenizer="openbmb/MiniCPM-2B-sft-bf16"), "MiMoMTPModel": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL", trust_remote_code=True, speculative_model="XiaomiMiMo/MiMo-7B-RL") diff --git a/tests/models/utils.py b/tests/models/utils.py index ac1fc6c8f0e2e..ffc904bd10f46 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -314,6 +314,7 @@ def check_embeddings_close( dim=0) fail_msg = (f"Test{prompt_idx}:" + f"\nCosine similarity: \t{sim:.4f}" f"\n{name_0}:\t{embeddings_0[:16]!r}" f"\n{name_1}:\t{embeddings_1[:16]!r}") diff --git a/tests/neuron/2_core/test_multi_lora.py b/tests/neuron/2_core/test_multi_lora.py new file mode 100644 index 0000000000000..6fa8f9128def7 --- /dev/null +++ b/tests/neuron/2_core/test_multi_lora.py @@ -0,0 +1,98 @@ +# SPDX-License-Identifier: Apache-2.0 + +from huggingface_hub import snapshot_download + +from vllm import LLM, SamplingParams +from vllm.lora.request import LoRARequest + + +def test_llama_single_lora(): + sql_lora_files = snapshot_download( + repo_id="yard1/llama-2-7b-sql-lora-test") + llm = LLM(model="meta-llama/Llama-2-7b-hf", + tensor_parallel_size=2, + max_num_seqs=4, + max_model_len=512, + use_v2_block_manager=True, + override_neuron_config={ + "sequence_parallel_enabled": False, + "skip_warmup": True, + "lora_modules": [{ + "name": "lora_id_1", + "path": sql_lora_files + }] + }, + enable_lora=True, + max_loras=1, + max_lora_rank=256, + device="neuron") + """For multi-lora requests using NxDI as the backend, only the lora_name + needs to be specified. 
The lora_id and lora_path are supplied at the LLM + class/server initialization, after which the paths are handled by NxDI""" + lora_req_1 = LoRARequest("lora_id_1", 0, " ") + prompts = [ + "The president of the United States is", + "The capital of France is", + ] + outputs = llm.generate(prompts, + SamplingParams(top_k=1), + lora_request=[lora_req_1, lora_req_1]) + + expected_outputs = [ + " the head of state and head of government of the United States. " + "The president direct", + " a city of contrasts. The city is home to the Eiffel Tower" + ] + + for expected_output, output in zip(expected_outputs, outputs): + generated_text = output.outputs[0].text + assert (expected_output == generated_text) + + +def test_llama_multiple_lora(): + sql_lora_files = snapshot_download( + repo_id="yard1/llama-2-7b-sql-lora-test") + llm = LLM(model="meta-llama/Llama-2-7b-hf", + tensor_parallel_size=2, + max_num_seqs=4, + max_model_len=512, + use_v2_block_manager=True, + override_neuron_config={ + "sequence_parallel_enabled": + False, + "skip_warmup": + True, + "lora_modules": [{ + "name": "lora_id_1", + "path": sql_lora_files + }, { + "name": "lora_id_2", + "path": sql_lora_files + }] + }, + enable_lora=True, + max_loras=2, + max_lora_rank=256, + device="neuron") + """For multi-lora requests using NxDI as the backend, only the lora_name + needs to be specified. The lora_id and lora_path are supplied at the LLM + class/server initialization, after which the paths are handled by NxDI""" + lora_req_1 = LoRARequest("lora_id_1", 0, " ") + lora_req_2 = LoRARequest("lora_id_2", 1, " ") + prompts = [ + "The president of the United States is", + "The capital of France is", + ] + outputs = llm.generate(prompts, + SamplingParams(top_k=1), + lora_request=[lora_req_1, lora_req_2]) + + expected_outputs = [ + " the head of state and head of government of the United States. " + "The president direct", + " a city of contrasts. 
The city is home to the Eiffel Tower" + ] + + for expected_output, output in zip(expected_outputs, outputs): + generated_text = output.outputs[0].text + assert (expected_output == generated_text) diff --git a/tests/samplers/test_no_bad_words.py b/tests/samplers/test_no_bad_words.py index 355e3adcf5f30..f9688b4b9b272 100644 --- a/tests/samplers/test_no_bad_words.py +++ b/tests/samplers/test_no_bad_words.py @@ -103,7 +103,7 @@ class TestTwoTokenBadWord: add_special_tokens=False)[0] def test_two_token_bad_word(self, vllm_runner): - with vllm_runner(self.MODEL) as llm: + with vllm_runner(self.MODEL, dtype="half") as llm: output_token_ids = self._generate(llm) assert output_token_ids[:2] == [ self.target_token_id1, self.target_token_id2 diff --git a/tests/tensorizer_loader/test_tensorizer.py b/tests/tensorizer_loader/test_tensorizer.py index b6286e1483976..747ec56ad6298 100644 --- a/tests/tensorizer_loader/test_tensorizer.py +++ b/tests/tensorizer_loader/test_tensorizer.py @@ -4,7 +4,6 @@ import gc import os import pathlib import subprocess -from unittest.mock import MagicMock, patch import pytest import torch @@ -16,7 +15,6 @@ from vllm.engine.arg_utils import EngineArgs from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig, TensorSerializer, is_vllm_tensorized, - load_with_tensorizer, open_stream, tensorize_vllm_model) # yapf: enable @@ -61,21 +59,6 @@ def write_keyfile(keyfile_path: str): f.write(encryption_params.key) -@patch('vllm.model_executor.model_loader.tensorizer.TensorizerAgent') -def test_load_with_tensorizer(mock_agent, tensorizer_config): - mock_linear_method = MagicMock() - mock_agent_instance = mock_agent.return_value - mock_agent_instance.deserialize.return_value = MagicMock() - - result = load_with_tensorizer(tensorizer_config, - quant_method=mock_linear_method) - - mock_agent.assert_called_once_with(tensorizer_config, - quant_method=mock_linear_method) - mock_agent_instance.deserialize.assert_called_once() - assert result == 
mock_agent_instance.deserialize.return_value - - @pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") def test_can_deserialize_s3(vllm_runner): model_ref = "EleutherAI/pythia-1.4b" diff --git a/tests/test_utils.py b/tests/test_utils.py index 0b88d05efeaad..dd8777f068887 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -17,7 +17,8 @@ from vllm_test_utils.monitor import monitor from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache, MemorySnapshot, PlaceholderModule, StoreBoolean, - bind_kv_cache, deprecate_kwargs, get_open_port, + bind_kv_cache, common_broadcastable_dtype, + deprecate_kwargs, get_open_port, is_lossless_cast, make_zmq_path, make_zmq_socket, memory_profiling, merge_async_iterators, sha256, split_zmq_path, supports_kw, swap_dict_values) @@ -567,12 +568,65 @@ def test_lru_cache(): assert 6 in cache +# yapf: disable +@pytest.mark.parametrize( + ("src_dtype", "tgt_dtype", "expected_result"), + [ + # Different precision_levels + (torch.bool, torch.int8, True), + (torch.bool, torch.float16, True), + (torch.bool, torch.complex32, True), + (torch.int64, torch.bool, False), + (torch.int64, torch.float16, True), + (torch.int64, torch.complex32, True), + (torch.float64, torch.bool, False), + (torch.float64, torch.int8, False), + (torch.float64, torch.complex32, True), + (torch.complex128, torch.bool, False), + (torch.complex128, torch.int8, False), + (torch.complex128, torch.float16, False), + # precision_level=0 + (torch.bool, torch.bool, True), + # precision_level=1 + (torch.int8, torch.int16, True), + (torch.int16, torch.int8, False), + (torch.uint8, torch.int8, False), + (torch.int8, torch.uint8, False), + # precision_level=2 + (torch.float16, torch.float32, True), + (torch.float32, torch.float16, False), + (torch.bfloat16, torch.float32, True), + (torch.float32, torch.bfloat16, False), + # precision_level=3 + (torch.complex32, 
torch.complex64, True), + (torch.complex64, torch.complex32, False), + ], +) +# yapf: enable +def test_is_lossless_cast(src_dtype, tgt_dtype, expected_result): + assert is_lossless_cast(src_dtype, tgt_dtype) == expected_result + + +# yapf: disable +@pytest.mark.parametrize( + ("dtypes", "expected_result"), + [ + ([torch.bool], torch.bool), + ([torch.bool, torch.int8], torch.int8), + ([torch.bool, torch.int8, torch.float16], torch.float16), + ([torch.bool, torch.int8, torch.float16, torch.complex32], torch.complex32), # noqa: E501 + ], +) +# yapf: enable +def test_common_broadcastable_dtype(dtypes, expected_result): + assert common_broadcastable_dtype(dtypes) == expected_result + + def test_placeholder_module_error_handling(): placeholder = PlaceholderModule("placeholder_1234") def build_ctx(): - return pytest.raises(ModuleNotFoundError, - match="No module named") + return pytest.raises(ModuleNotFoundError, match="No module named") with build_ctx(): int(placeholder) @@ -608,6 +662,7 @@ def test_placeholder_module_error_handling(): _ = placeholder_attr.module +# yapf: disable @pytest.mark.parametrize( "obj,key1,key2", [ @@ -618,6 +673,7 @@ def test_placeholder_module_error_handling(): # Tests for both keys do not exist ({1: "a", 2: "b"}, 3, 4), ]) +# yapf: enable def test_swap_dict_values(obj, key1, key2): original_obj = obj.copy() swap_dict_values(obj, key1, key2) @@ -631,19 +687,19 @@ def test_swap_dict_values(obj, key1, key2): assert key1 not in obj -def test_model_specification(parser_with_config, - cli_config_file, +def test_model_specification(parser_with_config, cli_config_file, cli_config_file_with_model): # Test model in CLI takes precedence over config - args = parser_with_config.parse_args([ - 'serve', 'cli-model', '--config', cli_config_file_with_model - ]) + args = parser_with_config.parse_args( + ['serve', 'cli-model', '--config', cli_config_file_with_model]) assert args.model_tag == 'cli-model' assert args.served_model_name == 'mymodel' # Test model 
from config file works args = parser_with_config.parse_args([ - 'serve', '--config', cli_config_file_with_model, + 'serve', + '--config', + cli_config_file_with_model, ]) assert args.model == 'config-model' assert args.served_model_name == 'mymodel' @@ -654,17 +710,19 @@ def test_model_specification(parser_with_config, # Test using --model option raises error with pytest.raises( - ValueError, - match=( - "With `vllm serve`, you should provide the model as a positional " - "argument or in a config file instead of via the `--model` option." - ), + ValueError, + match= + ("With `vllm serve`, you should provide the model as a positional " + "argument or in a config file instead of via the `--model` option."), ): parser_with_config.parse_args(['serve', '--model', 'my-model']) # Test other config values are preserved args = parser_with_config.parse_args([ - 'serve', 'cli-model', '--config', cli_config_file_with_model, + 'serve', + 'cli-model', + '--config', + cli_config_file_with_model, ]) assert args.tensor_parallel_size == 2 assert args.trust_remote_code is True @@ -673,7 +731,7 @@ def test_model_specification(parser_with_config, @pytest.mark.parametrize("input", [(), ("abc", ), (None, ), - (None, bool, [1, 2, 3])]) + (None, bool, [1, 2, 3])]) @pytest.mark.parametrize("output", [0, 1, 2]) def test_sha256(input: tuple, output: int): hash = sha256(input) @@ -682,7 +740,8 @@ def test_sha256(input: tuple, output: int): assert hash != 0 bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL) - assert hash == int.from_bytes(hashlib.sha256(bytes).digest(), byteorder="big") + assert hash == int.from_bytes(hashlib.sha256(bytes).digest(), + byteorder="big") # hashing again, returns the same value assert hash == sha256(input) @@ -698,8 +757,7 @@ def test_sha256(input: tuple, output: int): ("tcp://127.0.0.1:5555", ("tcp", "127.0.0.1", "5555")), ("tcp://[::1]:5555", ("tcp", "::1", "5555")), # IPv6 address ("inproc://some_identifier", ("inproc", "some_identifier", "")), - ] -) 
+ ]) def test_split_zmq_path(path, expected): assert split_zmq_path(path) == expected @@ -711,8 +769,7 @@ def test_split_zmq_path(path, expected): "tcp://127.0.0.1", # Missing port "tcp://[::1]", # Missing port for IPv6 "tcp://:5555", # Missing host - ] -) + ]) def test_split_zmq_path_invalid(invalid_path): with pytest.raises(ValueError): split_zmq_path(invalid_path) @@ -734,7 +791,8 @@ def test_make_zmq_socket_ipv6(): zsock: zmq.Socket = make_zmq_socket(ctx, ipv6_path, socket_type) # Verify that the IPV6 option is set - assert zsock.getsockopt(zmq.IPV6) == 1, "IPV6 option should be enabled for IPv6 addresses" + assert zsock.getsockopt( + zmq.IPV6) == 1, "IPV6 option should be enabled for IPv6 addresses" # Clean up zsock.close() diff --git a/tests/tpu/test_moe_pallas.py b/tests/tpu/test_moe_pallas.py index 13fc8bc8fa2ed..19df22f780396 100644 --- a/tests/tpu/test_moe_pallas.py +++ b/tests/tpu/test_moe_pallas.py @@ -26,7 +26,7 @@ TOP_KS = [2, 6] # The Pallas GMM kernel requires num_tokens * topk to be a multiple of 16 @pytest.mark.parametrize("m", [8, 16, 64, 2048]) @pytest.mark.parametrize("n", [128, 1024, 2048]) -@pytest.mark.parametrize("k", [128, 511, 1024]) +@pytest.mark.parametrize("k", [128, 512, 1024]) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("ep_size", EP_SIZE) diff --git a/tests/utils.py b/tests/utils.py index bf38d7843853d..d21b18470b1bb 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -28,7 +28,7 @@ from tests.models.utils import TextTextLogprobs from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.entrypoints.openai.cli_args import make_arg_parser +from vllm.entrypoints.cli.serve import ServeSubcommand from vllm.model_executor.model_loader import get_model_loader from vllm.platforms import current_platform from vllm.transformers_utils.tokenizer import get_tokenizer @@ -99,7 +99,8 @@ 
class RemoteOpenAIServer: parser = FlexibleArgumentParser( description="vLLM's remote OpenAI server.") - parser = make_arg_parser(parser) + subparsers = parser.add_subparsers(required=False, dest="subparser") + parser = ServeSubcommand().subparser_init(subparsers) args = parser.parse_args(["--model", model, *vllm_serve_args]) self.host = str(args.host or 'localhost') self.port = int(args.port) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 43a27da2dbe43..d3d62cf09232d 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -45,7 +45,6 @@ def make_request(request_id, multi_modal_placeholders=mm_positions, sampling_params=SamplingParams(max_tokens=17), eos_token_id=100, - arrival_time=0, lora_request=None, cache_salt=cache_salt, ) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 3da27786b1f2f..ba3c0b3cf3169 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -38,7 +38,6 @@ def make_request(request_id, sampling_params=SamplingParams(max_tokens=17, prompt_logprobs=prompt_logprobs), eos_token_id=100, - arrival_time=0, lora_request=None, cache_salt=cache_salt, ) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index f40d477a00363..f38454b1b2889 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -138,7 +138,6 @@ def create_requests(num_requests: int, multi_modal_placeholders=mm_position, multi_modal_hashes=None, eos_token_id=EOS_TOKEN_ID, - arrival_time=0, ) requests.append(request) return requests @@ -744,7 +743,8 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): assert running_req.num_tokens_with_spec == 2 + len(spec_tokens[i]) # No draft or accepted tokens counted yet - assert engine_core_outputs.scheduler_stats.spec_decoding_stats is None + assert not engine_core_outputs or ( + 
engine_core_outputs[0].scheduler_stats.spec_decoding_stats is None) # Schedule the speculated tokens for validation output = scheduler.schedule() @@ -772,7 +772,8 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): engine_core_outputs = scheduler.update_from_output(output, model_runner_output) - scheduler_stats = engine_core_outputs.scheduler_stats + scheduler_stats = engine_core_outputs[0].scheduler_stats \ + if engine_core_outputs else None if expected[0] == 0: assert scheduler_stats.spec_decoding_stats is None else: @@ -843,7 +844,7 @@ def _step_until_done( # We should be in the decode phase now. assert num_scheduled_tokens == 1 assert len(output.kv_connector_metadata.requests) == 0 - ecos = scheduler.update_from_output(output, model_runner_output) + ecos = scheduler.update_from_output(output, model_runner_output)[0] all_done = True for eco in ecos.outputs: if eco.finish_reason is None: diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py index dcf494825b0d4..e78c7480a837a 100644 --- a/tests/v1/engine/test_engine_core.py +++ b/tests/v1/engine/test_engine_core.py @@ -88,7 +88,7 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch): assert len(engine_core.scheduler.running) == 4 # Loop through until they are all done. 
- while len(engine_core.step().outputs) > 0: + while (outs := engine_core.step()[0].get(0)) and outs.outputs: pass assert len(engine_core.scheduler.waiting) == 0 @@ -163,11 +163,11 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch): req0.request_id = req1.request_id = "test" engine_core.add_request(req0) - while len(engine_core.step().outputs) > 0: + while (outs := engine_core.step()[0].get(0)) and outs.outputs: pass engine_core.add_request(req1) - while len(engine_core.step().outputs) > 0: + while (outs := engine_core.step()[0].get(0)) and outs.outputs: pass assert len(engine_core.scheduler.waiting) == 0 @@ -207,7 +207,7 @@ def test_engine_core_advanced_sampling(monkeypatch: pytest.MonkeyPatch): assert len(engine_core.scheduler.waiting) == 1 assert len(engine_core.scheduler.running) == 0 # Loop through until they are all done. - while len(engine_core.step().outputs) > 0: + while (outs := engine_core.step()[0].get(0)) and outs.outputs: pass assert len(engine_core.scheduler.waiting) == 0 assert len(engine_core.scheduler.running) == 0 @@ -296,7 +296,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch): engine_core.add_request(req1) # Schedule Batch 1: (10, req0) - assert engine_core.step_with_batch_queue() is None + assert engine_core.step_with_batch_queue()[0] is None assert engine_core.batch_queue.qsize() == 1 scheduler_output = engine_core.batch_queue.queue[-1][1] assert scheduler_output.num_scheduled_tokens[0] == 10 @@ -305,7 +305,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch): req0.request_id].num_computed_tokens == 10 # Schedule Batch 2: (2, req0), (8, req1) - assert engine_core.step_with_batch_queue() is None + assert engine_core.step_with_batch_queue()[0] is None assert engine_core.batch_queue.qsize() == 2 scheduler_output = engine_core.batch_queue.queue[-1][1] assert scheduler_output.num_scheduled_tokens[0] == 2 @@ -327,7 +327,7 @@ def test_engine_core_concurrent_batches(monkeypatch: 
pytest.MonkeyPatch): assert scheduler_output.num_scheduled_tokens[1] == 4 # Batch queue is full. Finish Batch 2. Get first token of req0. - output = engine_core.step_with_batch_queue() + output = engine_core.step_with_batch_queue()[0].get(0) assert output is not None assert len(output.outputs) == 1 assert engine_core.scheduler.requests[req0.request_id].num_tokens == 13 @@ -339,7 +339,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch): assert scheduler_output.num_scheduled_tokens[0] == 1 # Batch queue is full. Finish Batch 3. Get first token of req1. - output = engine_core.step_with_batch_queue() + output = engine_core.step_with_batch_queue()[0].get(0) assert output is not None assert len(output.outputs) == 1 assert engine_core.scheduler.requests[req1.request_id].num_tokens == 13 @@ -358,11 +358,11 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch): engine_core.scheduler.requests[1].num_tokens + 1, ] while engine_core.scheduler.get_num_unfinished_requests() == 2: - output = engine_core.step_with_batch_queue() + output = engine_core.step_with_batch_queue()[0] if step % 2 == 0: # Even steps consumes an output. 
assert output is not None - assert len(output.outputs) == 1 + assert len(output[0].outputs) == 1 if req_id in engine_core.scheduler.requests: assert engine_core.scheduler.requests[ req_id].num_tokens == expected_num_tokens[req_id] diff --git a/tests/v1/entrypoints/openai/test_multi_api_servers.py b/tests/v1/entrypoints/openai/test_multi_api_servers.py new file mode 100644 index 0000000000000..7b4583bc3bf37 --- /dev/null +++ b/tests/v1/entrypoints/openai/test_multi_api_servers.py @@ -0,0 +1,171 @@ +# SPDX-License-Identifier: Apache-2.0 +import asyncio +import os + +import openai # use the official client for correctness check +import pytest +import pytest_asyncio + +from tests.utils import RemoteOpenAIServer + +MODEL_NAME = "ibm-research/PowerMoE-3b" + +DP_SIZE = os.getenv("DP_SIZE", "1") + + +@pytest.fixture(scope="module") +def default_server_args(): + return [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--max-num-seqs", + "128", + "--enforce-eager", + "--api-server-count", + "4", + "--data_parallel_size", + DP_SIZE, + ] + + +@pytest.fixture(scope="module") +def server(default_server_args): + with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_single_completion(client: openai.AsyncOpenAI, + model_name: str) -> None: + + async def make_request(): + completion = await client.completions.create( + model=model_name, + prompt="Hello, my name is", + max_tokens=10, + temperature=1.0) + + assert completion.id is not None + assert completion.choices is not None and len(completion.choices) == 1 + + choice = completion.choices[0] + # The exact number of tokens can vary slightly with temperature=1.0, + # so we 
check for a reasonable minimum length. + assert len(choice.text) >= 1 + # Finish reason might not always be 'length' if the model finishes early + # or due to other reasons, especially with high temperature. + # So, we'll accept 'length' or 'stop'. + assert choice.finish_reason in ("length", "stop") + + # Token counts can also vary, so we check they are positive. + assert completion.usage.completion_tokens > 0 + assert completion.usage.prompt_tokens > 0 + assert completion.usage.total_tokens > 0 + return completion + + # Test single request + result = await make_request() + assert result is not None + + await asyncio.sleep(0.5) + + # Send two bursts of requests + num_requests = 100 + tasks = [make_request() for _ in range(num_requests)] + results = await asyncio.gather(*tasks) + assert len(results) == num_requests + assert all(completion is not None for completion in results) + + await asyncio.sleep(0.5) + + tasks = [make_request() for _ in range(num_requests)] + results = await asyncio.gather(*tasks) + assert len(results) == num_requests + assert all(completion is not None for completion in results) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_completion_streaming(client: openai.AsyncOpenAI, + model_name: str) -> None: + prompt = "What is an LLM?" 
+ + async def make_streaming_request(): + # Perform a non-streaming request to get the expected full output + single_completion = await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + ) + single_output = single_completion.choices[0].text + + # Perform the streaming request + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True) + chunks: list[str] = [] + finish_reason_count = 0 + last_chunk = None + async for chunk in stream: + chunks.append(chunk.choices[0].text) + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + last_chunk = chunk # Keep track of the last chunk + + # finish reason should only return in the last block for OpenAI API + assert finish_reason_count == 1, ( + "Finish reason should appear exactly once.") + assert last_chunk is not None, ( + "Stream should have yielded at least one chunk.") + assert last_chunk.choices[ + 0].finish_reason == "length", "Finish reason should be 'length'." + # Check that the combined text matches the non-streamed version. + assert "".join( + chunks + ) == single_output, "Streamed output should match non-streamed output." + return True # Indicate success for this request + + # Test single request + result = await make_streaming_request() + assert result is not None + + await asyncio.sleep(0.5) + + # Send two bursts of requests + num_requests = 100 + tasks = [make_streaming_request() for _ in range(num_requests)] + results = await asyncio.gather(*tasks) + + assert len( + results + ) == num_requests, f"Expected {num_requests} results, got {len(results)}" + assert all(results), "Not all streaming requests completed successfully." 
+ + await asyncio.sleep(0.5) + + tasks = [make_streaming_request() for _ in range(num_requests)] + results = await asyncio.gather(*tasks) + + assert len( + results + ) == num_requests, f"Expected {num_requests} results, got {len(results)}" + assert all(results), "Not all streaming requests completed successfully." diff --git a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py index 77098140343a0..dc963251c962b 100644 --- a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py +++ b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py @@ -43,7 +43,7 @@ def test_basic_lifecycle(): # Ensure the request is finished after 1 tokens. assert request.is_finished() assert request.status == RequestStatus.FINISHED_LENGTH_CAPPED - output = engine_core_outputs.outputs[0] + output = engine_core_outputs[0].outputs[0] assert output.finish_reason == FinishReason.LENGTH assert output.kv_transfer_params is not None @@ -165,7 +165,7 @@ def test_prefix_cache_lifecycle(): scheduler_output = scheduler.schedule() model_runner_output = create_model_runner_output(reqs=[request_remote]) eco = scheduler.update_from_output(scheduler_output, model_runner_output) - kv_transfer_params = eco.outputs[0].kv_transfer_params + kv_transfer_params = eco[0].outputs[0].kv_transfer_params # Ensure we send all block ids, even if there is a cache hit. 
assert (len( diff --git a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py index 6fcff0d620452..86eacb693869d 100644 --- a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py +++ b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py @@ -61,7 +61,7 @@ def test_basic_lifecycle(): # (1c): update_from_output() engine_core_outputs = scheduler.update_from_output(scheduler_output, model_runner_output) - assert len(engine_core_outputs.outputs) == 0 + assert not engine_core_outputs or not engine_core_outputs[0].outputs # STEP (2): # (2a): schedule(): nothing happens! @@ -112,7 +112,7 @@ def test_basic_lifecycle(): model_runner_output) scheduler.schedule() - outputs = engine_core_outputs.outputs + outputs = engine_core_outputs[0].outputs assert len(outputs) == 1 output = outputs[0] assert output.finish_reason == FinishReason.STOP @@ -335,7 +335,7 @@ def test_full_block_prompt(): model_runner_output) scheduler.schedule() - outputs = engine_core_outputs.outputs + outputs = engine_core_outputs[0].outputs assert len(outputs) == 1 output = outputs[0] assert output.finish_reason == FinishReason.STOP diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index 53e2d6fda1aea..3c3190b325636 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -153,7 +153,6 @@ def create_request( multi_modal_placeholders=None, multi_modal_hashes=None, eos_token_id=EOS_TOKEN_ID, - arrival_time=0, ) req.kv_transfer_params = kv_transfer_params return req diff --git a/tests/v1/tpu/worker/test_tpu_model_runner.py b/tests/v1/tpu/worker/test_tpu_model_runner.py index 319b38b4ca09d..348f12887a446 100644 --- a/tests/v1/tpu/worker/test_tpu_model_runner.py +++ b/tests/v1/tpu/worker/test_tpu_model_runner.py @@ -81,7 +81,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: mm_hashes=[], mm_positions=[], sampling_params=SamplingParams(), - 
block_ids=[0], + block_ids=[[0]], # block_ids should be list[list[int]] num_computed_tokens=0, lora_request=None, )) @@ -112,14 +112,35 @@ def _is_req_added(model_runner, req_id: str) -> bool: def _is_req_state_block_table_match(model_runner, req_id: str) -> bool: + """Check if the request state block IDs match the block table. + + This function handles both legacy BlockTable and new MultiGroupBlockTable + structures for backward compatibility. + """ + req_index = model_runner.input_batch.req_id_to_index[req_id] - block_table = model_runner.input_batch.block_table + multi_group_block_table = model_runner.input_batch.block_table req_state = model_runner.requests[req_id] - if block_table.num_blocks_per_row[req_index] != len(req_state.block_ids): + + # Access the first block table from MultiGroupBlockTable + # This is safe since we currently only use single KV cache groups + block_table = multi_group_block_table[0] + + # req_state.block_ids is now list[list[int]] for MultiGroupBlockTable + # Extract the first group's block IDs + if isinstance(req_state.block_ids[0], list): + # New format: list[list[int]] - extract first group + req_block_ids = req_state.block_ids[0] + else: + # Legacy format: list[int] - use directly + req_block_ids = req_state.block_ids + + if block_table.num_blocks_per_row[req_index] != len(req_block_ids): return False + num_blocks = block_table.num_blocks_per_row[req_index] - return (block_table.block_table_np[req_index, :num_blocks] == - req_state.block_ids).all() + block_table_values = block_table.block_table_np[req_index, :num_blocks] + return (block_table_values == req_block_ids).all() def test_update_states_new_request(model_runner): diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index b8c3d88617d0d..6ba6d1f6f131d 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -1,7 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 +import random + import 
pytest +from vllm.attention import Attention from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig, VllmConfig) from vllm.sampling_params import SamplingParams @@ -13,27 +16,30 @@ from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.worker.gpu_input_batch import InputBatch from vllm.v1.worker.gpu_model_runner import GPUModelRunner +BLOCK_SIZE = 16 +NUM_BLOCKS = 10 + def initialize_kv_cache(runner: GPUModelRunner): """ Only perform necessary steps in GPUModelRunner.initialize_kv_cache() """ + attn_spec = FullAttentionSpec( + block_size=BLOCK_SIZE, + num_kv_heads=runner.model_config.get_num_kv_heads( + runner.parallel_config), + head_size=runner.model_config.get_head_size(), + dtype=runner.kv_cache_dtype, + use_mla=False, + ) + tensor_size = attn_spec.page_size_bytes * NUM_BLOCKS kv_cache_config = KVCacheConfig( - num_blocks=10, + num_blocks=NUM_BLOCKS, tensors={ - "layer.0": KVCacheTensor(size=1024), + "layer.0": KVCacheTensor(size=tensor_size), }, kv_cache_groups=[ - KVCacheGroupSpec( - layer_names=["layer.0"], - kv_cache_spec=FullAttentionSpec( - block_size=16, - num_kv_heads=runner.model_config.get_num_kv_heads( - runner.parallel_config), - head_size=runner.model_config.get_head_size(), - dtype=runner.kv_cache_dtype, - use_mla=False, - )) + KVCacheGroupSpec(layer_names=["layer.0"], kv_cache_spec=attn_spec) ]) runner.kv_cache_config = kv_cache_config runner.input_batch = InputBatch( @@ -65,7 +71,7 @@ def model_runner(): seed=42, ) cache_config = CacheConfig( - block_size=16, + block_size=BLOCK_SIZE, gpu_memory_utilization=0.9, swap_space=0, cache_dtype="auto", @@ -77,6 +83,10 @@ def model_runner(): scheduler_config=scheduler_config, parallel_config=parallel_config, ) + num_heads = model_config.get_num_kv_heads(parallel_config) + head_size = model_config.get_head_size() + vllm_config.compilation_config.static_forward_context[ + "layer.0"] = Attention(num_heads, head_size, 0.1) device = "cuda" runner = 
GPUModelRunner(vllm_config, device) @@ -84,6 +94,9 @@ def model_runner(): return runner +model_runner_2 = model_runner + + def _schedule_new_request(*req_ids: str) -> SchedulerOutput: new_reqs = [] num_scheduled_tokens = {} @@ -321,3 +334,53 @@ def test_update_states_request_unscheduled(model_runner): assert _is_req_added(model_runner, req_ids[1]) assert not _is_req_scheduled(model_runner, req_ids[1]) + + +def test_kv_cache_stride_order(monkeypatch, model_runner): + # This test checks if GPUModelRunner initializes correctly when an attention + # backend enforces a non-default KV cache stride order. + n_heads = model_runner.model_config.get_num_kv_heads( + model_runner.parallel_config) + expected_kv_cache_shape = [ + 2, NUM_BLOCKS, BLOCK_SIZE, n_heads, + model_runner.model_config.get_head_size() + ] + # TODO mla test + default_stride = list(range(5)) + # Permutation that gets you back to expected kv shape + rnd_stride = tuple(random.sample(default_stride, len(default_stride))) + + def rnd_stride_order(): + return rnd_stride + + # Patch the attention backend class and re-trigger the KV cache creation. 
+ for attn_backend in model_runner.attn_backends: + monkeypatch.setattr(attn_backend, "get_kv_cache_stride_order", + rnd_stride_order) + + model_runner.attn_backends = [] + model_runner.attn_metadata_builders = [] + model_runner.initialize_kv_cache(model_runner.kv_cache_config) + + # Shape is unchanged, but layout may differ + kv_cache_shape = model_runner.kv_caches[0].shape + assert list(kv_cache_shape) == expected_kv_cache_shape + if default_stride == rnd_stride: + assert all(kv.is_contiguous() for kv in model_runner.kv_caches) + else: + assert all(not kv.is_contiguous() for kv in model_runner.kv_caches) + + +def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2): + # In this test, model_runner loads model + weights in one go, while + # model_runner_2 loads dummy weights first then load real weights inplace + model_runner.load_model() + original_load_format = model_runner_2.load_config.load_format + model_runner_2.load_config.load_format = "dummy" + model_runner_2.load_model() # Initial model loading with dummy weights + assert str(model_runner.get_model().state_dict()) != str( + model_runner_2.get_model().state_dict()) + model_runner_2.load_config.load_format = original_load_format + model_runner_2.load_model() # Load real weights inplace + assert str(model_runner.get_model().state_dict()) == str( + model_runner_2.get_model().state_dict()) diff --git a/tools/enforce_regex_import.py b/tools/enforce_regex_import.py index b55c4a94eac80..6c201dd2543e9 100644 --- a/tools/enforce_regex_import.py +++ b/tools/enforce_regex_import.py @@ -58,6 +58,9 @@ def main() -> int: if not Path(filepath).exists(): continue + if filepath == "setup.py": + continue + violations = check_file(filepath) if violations: print(f"\n❌ {filepath}:") diff --git a/vllm/attention/backends/rocm_aiter_mla.py b/vllm/attention/backends/rocm_aiter_mla.py index b048220020f14..c974f2a15a0ef 100644 --- a/vllm/attention/backends/rocm_aiter_mla.py +++ 
b/vllm/attention/backends/rocm_aiter_mla.py @@ -132,8 +132,6 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): def __init__(self, input_builder: "ModelInputForGPUBuilder"): super().__init__(input_builder) - assert self.runner.model_config.max_model_len == 32768,\ - "AITER MLA requires max model len to be set to 32768" assert self.block_size == 1, "AITER MLA requires only block size 1." def prepare(self): diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index abcb68911a8bb..7134472daa605 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -770,8 +770,9 @@ class ROCmFlashAttentionImpl(AttentionImpl): and layer._v_scale and layer._prob_scale and self.kv_cache_dtype == "fp8") full_scales = ( - layer._q_scale, layer._k_scale, layer._v_scale, - layer._prob_scale) if use_fp8_scales else None + layer._q_scale.item(), layer._k_scale.item(), + layer._v_scale.item(), + layer._prob_scale.item()) if use_fp8_scales else None self.triton_attn_func( query, key, diff --git a/vllm/attention/ops/chunked_prefill_paged_decode.py b/vllm/attention/ops/chunked_prefill_paged_decode.py index 785799b6bf684..6ca2a64145bd6 100644 --- a/vllm/attention/ops/chunked_prefill_paged_decode.py +++ b/vllm/attention/ops/chunked_prefill_paged_decode.py @@ -264,8 +264,8 @@ def chunked_prefill_paged_decode( # Conversion of FP8 Tensor from uint8 storage to # appropriate torch.dtype for interpretation by Triton if "fp8" in kv_cache_dtype: - assert key_cache.dtype == torch.uint8 - assert value_cache.dtype == torch.uint8 + assert key_cache.dtype in [torch.uint8, current_platform.fp8_dtype()] + assert value_cache.dtype in [torch.uint8, current_platform.fp8_dtype()] if kv_cache_dtype in ("fp8", "fp8_e4m3"): target_dtype = current_platform.fp8_dtype() diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 86d256b630bf5..729b61b029063 100644 
--- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -744,8 +744,8 @@ def context_attention_fwd(q, # Conversion of FP8 Tensor from uint8 storage to # appropriate torch.dtype for interpretation by Triton if "fp8" in kv_cache_dtype: - assert (k_cache.dtype == torch.uint8) - assert (v_cache.dtype == torch.uint8) + assert k_cache.dtype in [torch.uint8, current_platform.fp8_dtype()] + assert v_cache.dtype in [torch.uint8, current_platform.fp8_dtype()] if kv_cache_dtype in ("fp8", "fp8_e4m3"): target_dtype = current_platform.fp8_dtype() diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py index 8940d0b662258..62cfb813d5f94 100644 --- a/vllm/attention/ops/triton_flash_attention.py +++ b/vllm/attention/ops/triton_flash_attention.py @@ -1,236 +1,33 @@ +#!/usr/bin/env python # SPDX-License-Identifier: Apache-2.0 """ Fused Attention =============== -This is a Triton implementation of the Flash Attention v2 algorithm -See https://tridao.me/publications/flash2/flash2.pdf +This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao +(https://tridao.me/publications/flash2/flash2.pdf) +Credits: OpenAI kernel team, AMD ML Frameworks Triton team -Credits: -AMD Triton kernels team -OpenAI kernel team - -Currently only the forward kernel is supported, and contains these features: +Features supported: 1) Fwd with causal masking -2) Arbitrary Q and KV sequence lengths -3) Arbitrary head sizes -4) Multi and grouped query attention -5) Variable sequence lengths -6) ALiBi and matrix bias -7) FP8 support +2) Any sequence lengths without padding (currently fwd kernel only) +3) Support for different sequence lengths for q and k +4) Nested tensor API currently does not support dropout or bias. 
+ +Not currently supported: + +1) Non power of two head dims """ -from typing import Optional - import torch -from vllm import _custom_ops as ops from vllm.platforms import current_platform +from vllm.platforms.rocm import on_gfx1x from vllm.triton_utils import tl, triton -SUPPORTED_LAYOUTS = ['thd', 'bhsd', 'bshd'] - -default_eight_bit_dtype_triton = tl.float8e4b8 -default_eight_bit_dtype_torch = current_platform.fp8_dtype() -default_float8_info = torch.finfo(default_eight_bit_dtype_torch) - -FP8_MIN = triton.language.constexpr(default_float8_info.min) - -# According to https://github.com/vllm-project/vllm/blob/main -# /csrc/quantization/utils.cuh#L31, -# need to make the max for the uz datatype be 224.0 for accuracy reasons. -FP8_MAX = triton.language.constexpr( - default_float8_info.max if default_eight_bit_dtype_torch != - torch.float8_e4m3fnuz else 224.0) - - -class MetaData: - cu_seqlens_q = None - cu_seqlens_k = None - max_seqlens_q = 0 - max_seqlens_k = 0 - bias = None - alibi_slopes = None - causal = False - num_contexts = 0 - varlen = False - eight_bit = False - layout = None - return_encoded_softmax = False - eight_bit_dtype_triton = default_eight_bit_dtype_triton - eight_bit_dtype_torch = default_eight_bit_dtype_torch - output_dtype = None - - # Note about layouts: - # - # thd - [num_tokens, num_heads, head_size] - # bshd - [batch_size, seq_len, num_heads, head_size] - # bhsd - [batch_size, num_heads, seq_len, head_size] - # - # This is for each tensor, all tensors must have same layout. - # Q can have num_heads and seq_len differ from from K and V, - # however K and V must agree on this. - # - # Notes about varlen and bias: - # Only one or the other is implemented, meaning can't combine - # both varlen and bias right now. - # - # Note about quantization: - # Only 8-bit quantization supported (for now) and specifically fp8. - # Scales must be tensors. 
- # o_scale: This is 'output scaling', but comes from parameter called - # 'input_scale', this is applied to the output from the kernel. - # o_scale should be None if none of the other quantization parameters - # are used. - # - # NOTE: Object is in a tentatively good state after initialized, however, - # to verify, call check_args(q,k,v,o) where o is the output tensor. - def __init__( - self, - sm_scale=1.0, - layout=None, # layout can be 'bshd', 'bhsd', or 'thd' - output_dtype=None, - max_seqlens_q=0, - max_seqlens_k=0, - # varlen params - cu_seqlens_q=None, # only 'thd' layout supported for varlen - cu_seqlens_k=None, - # quant params - q_descale=None, - k_descale=None, - v_descale=None, - p_scale=None, - o_scale=None, - # bias params - bias=None, # varlen not implemented for bias - seqlen_q=None, - seqlen_k=None, - # alibi params - alibi_slopes=None, - alibi_batch=None, - alibi_nheads=None, - # causal - causal=None, - ): - self.sm_scale = sm_scale - self.output_dtype = output_dtype - self.max_seqlens_q = max_seqlens_q - self.max_seqlens_k = max_seqlens_k - self.layout = layout - if cu_seqlens_q is not None or cu_seqlens_k is not None: - assert cu_seqlens_q is not None and cu_seqlens_k is not None - assert layout is None or layout not in [ - 'bshd', 'bhsd' - ], "Varlen only implemented for thd layout" - self.set_varlen_params(cu_seqlens_q, cu_seqlens_k) - quant_params = [q_descale, k_descale, v_descale, p_scale, o_scale] - if any(x is not None for x in quant_params): - p_descale = 1.0 / p_scale if p_scale is not None else None - self.set_eight_bit_params(q_descale, k_descale, v_descale, p_scale, - p_descale, o_scale) - if bias is not None: - self.need_bias(bias, seqlen_q, seqlen_k) - if alibi_slopes is not None: - self.need_alibi(alibi_slopes, alibi_batch, alibi_nheads) - if causal is not None and causal: - self.need_causal() - - def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k): - self.varlen = True - self.layout = 'thd' - self.cu_seqlens_q = cu_seqlens_q 
- self.cu_seqlens_k = cu_seqlens_k - # Without "varlen", there should still be one sequence. - assert len(cu_seqlens_q) >= 2 - assert len(cu_seqlens_q) == len(cu_seqlens_k) - self.num_contexts = len(cu_seqlens_q) - 1 - for i in range(0, self.num_contexts): - self.max_seqlens_q = max( - cu_seqlens_q[i + 1].item() - cu_seqlens_q[i].item(), - self.max_seqlens_q) - self.max_seqlens_k = max( - cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item(), - self.max_seqlens_k) - - def set_eight_bit_params(self, q_descale, k_descale, v_descale, p_scale, - p_descale, o_scale): - self.eight_bit = True - self.q_descale = q_descale - self.k_descale = k_descale - self.v_descale = v_descale - self.p_scale = p_scale - self.p_descale = p_descale - self.o_scale = o_scale - self.use_p_scale = (p_scale is not None) and ( - p_descale is not None) and (v_descale is not None) - self.eight_bit_kv = ((q_descale is None) and (k_descale is not None) - and (v_descale is not None)) - self.eight_bit_dtype_torch = default_eight_bit_dtype_torch - - def need_bias(self, bias, seqlen_q, seqlen_k): - assert bias is not None - assert bias.is_cuda - assert bias.dim() == 4 - assert bias.shape[0] == 1 - assert bias.shape[2:] == (seqlen_q, seqlen_k) - self.bias = bias - - def need_alibi(self, alibi_slopes, batch, nheads): - assert alibi_slopes.is_cuda - assert alibi_slopes.dim() == 2 - assert alibi_slopes.shape[0] == batch - assert alibi_slopes.shape[1] == nheads - self.alibi_slopes = alibi_slopes - - def need_causal(self): - self.causal = True - - def check_args(self, q, k, v, o): - assert q.dim() == k.dim() and q.dim() == v.dim() - - batch, nheads_q, nheads_k, head_size = get_shape_from_layout( - q, k, self) - if self.varlen: - assert q.dim() == 3 - assert self.cu_seqlens_q is not None - assert self.cu_seqlens_k is not None - assert len(self.cu_seqlens_q) == len(self.cu_seqlens_k) - # TODO: Remove once bias is supported with varlen - assert self.bias is None - assert not self.return_encoded_softmax - else: - 
assert q.dim() == 4 - assert self.max_seqlens_q > 0 and self.max_seqlens_k > 0 - assert self.cu_seqlens_q is None and self.cu_seqlens_k is None - assert k.shape == v.shape - assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] - # TODO: Change assert if we support qkl f8 and v f16 - if self.eight_bit: - if self.eight_bit_kv: - assert (v.dtype == k.dtype - and k.dtype == self.eight_bit_dtype_torch) - assert q.dtype != k.dtype - assert (self.v_descale is not None) and (self.k_descale - is not None) - else: - assert (q.dtype == k.dtype and q.dtype == v.dtype - and q.dtype == self.eight_bit_dtype_torch) - assert (self.q_descale - is not None) and (self.k_descale - is not None) and (self.v_descale - is not None) - if self.use_p_scale: - assert (self.p_scale is not None) and (self.p_descale - is not None) - else: - assert (q.dtype == k.dtype) and (q.dtype == v.dtype) - assert head_size <= 256 - assert o.shape == q.shape - assert (nheads_q % nheads_k) == 0 - assert self.layout is not None - assert self.layout == 'thd' or not self.varlen +torch_dtype: tl.constexpr = torch.float16 @triton.jit @@ -243,155 +40,103 @@ def max_fn(x, y): return tl.math.max(x, y) -# Convenience function to load with optional boundary checks. -# "First" is the major dim, "second" is the minor dim. 
@triton.jit -def masked_load(ptrs, offset_first, offset_second, boundary_first, - boundary_second): - if offset_first is not None and offset_second is not None: - mask = (offset_first[:, None] < boundary_first) & \ - (offset_second[None, :] < boundary_second) - tensor = tl.load(ptrs, mask=mask, other=0.0) - elif offset_first is not None: - mask = offset_first[:, None] < boundary_first - tensor = tl.load(ptrs, mask=mask, other=0.0) - elif offset_second is not None: - mask = offset_second[None, :] < boundary_second - tensor = tl.load(ptrs, mask=mask, other=0.0) +def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride): + ms = tl.arange(0, m) + ns = tl.arange(0, n) + return philox_offset + ms[:, None] * stride + ns[None, :] + + +@triton.jit +def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride): + rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, + stride).to(tl.uint32) + # TODO: use tl.randint for better performance + return tl.rand(philox_seed, rng_offsets) + + +@triton.jit +def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): + rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, + stride) + rng_keep = rng_output > dropout_p + return rng_keep + + +@triton.jit +def load_fn(block_ptr, first, second, pad): + if first and second: + tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad) + elif first: + tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad) + elif second: + tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad) else: - tensor = tl.load(ptrs) + tensor = tl.load(block_ptr) return tensor -@triton.jit -def compute_alibi_block(alibi_slope, - seqlen_q, - seqlen_k, - offs_m, - offs_n, - transpose=False): - # when seqlen_k and seqlen_q are different we want the diagonal to stick to - # the bottom right of the attention matrix - # for casual mask we want something like this where (1 is kept and 0 is - # masked) - # seqlen_q = 2 
and seqlen_k = 5 - # 1 1 1 1 0 - # 1 1 1 1 1 - # seqlen_q = 5 and seqlen_k = 2 - # 0 0 - # 0 0 - # 0 0 - # 1 0 - # 1 1 - # for alibi the diagonal is 0 indicating no penalty for attending to that - # spot and increasing penalty for attending further from the diagonal - # e.g. alibi_slope = 1, seqlen_q = 2, seqlen_k = 5, - # offs_m = [0, 1, 2, 3], offs_n = [0, 1, 2, 3, 4], transpose = False - # 1. offs_m[:,None] = [[0], - # [1], - # 2. offs_m[:,None] + seqlen_k = [[5], - # [6], - # 3. offs_m[:,None] + seqlen_k - seqlen_q = [[3], - # [4], - # 4. offs_m[:,None] + seqlen_k - seqlen_q - offs_n[None,:] = - # [[3], - [[0, 1, 2, 3, 4]] = [[ 3, 2, 1, 0,-1], [4], [ 4, 3, 2, 1, 0]] - # 5. -1 * alibi_slope * tl.abs(relative_pos_block) = [[ -3, -2, -1, 0,-1], - # [ -4, -3, -2, -1, 0]], - relative_pos_block = (offs_m[:, None] + seqlen_k - seqlen_q - - offs_n[None, :]) - alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) - if transpose: - return alibi_block.T - else: - return alibi_block - - -def compute_alibi_tensor(alibi_slopes, seqlen_q, seqlen_k): - q_idx = torch.arange(seqlen_q, dtype=torch.int32, - device="cuda").unsqueeze(-1) # (N_CTX_Q, 1) - k_idx = torch.arange(seqlen_k, dtype=torch.int32, - device="cuda").unsqueeze(0) # (1, N_CTX_K) - relative_pos = torch.abs(q_idx + seqlen_k - seqlen_q - - k_idx) # (N_CTX_Q, N_CTX_K) - return -1 * alibi_slopes.unsqueeze(-1).unsqueeze( - -1) * relative_pos # (Z, H, N_CTX_Q, N_CTX_K) - - -@triton.jit -def quant_fp8(x, scale): - x *= scale - x = tl.clamp(x, FP8_MIN, FP8_MAX) - return x - - @triton.jit def _attn_fwd_inner( acc, l_i, m_i, q, - k_ptrs, - v_ptrs, - bias_ptrs, - stride_kn, - stride_vk, - stride_bn, + K_block_ptr, + V_block_ptr, start_m, actual_seqlen_k, - actual_seqlen_q, + dropout_p, philox_seed, batch_philox_offset, - encoded_sm_ptrs, + encoded_softmax_block_ptr, block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, - alibi_slope, - q_descale, - k_descale, - v_descale, - p_scale, + bias_ptr, IS_CAUSAL: 
tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, OFFS_M: tl.constexpr, OFFS_N: tl.constexpr, - SHOULD_PRE_LOAD_V: tl.constexpr, - SHOULD_MASK_STEPS: tl.constexpr, - SHOULD_RETURN_ENCODED_SOFTMAX: tl.constexpr, - USE_PADDED_HEAD: tl.constexpr, - IS_ACTUAL_BLOCK_DMODEL: tl.constexpr, - QK_SCALE: tl.constexpr, - IS_EIGHT_BIT_GEMM: tl.constexpr, - USE_P_SCALE: tl.constexpr, - IS_EIGHT_BIT_KV: tl.constexpr, - QUANT_DTYPE: tl.constexpr = default_eight_bit_dtype_triton, + PRE_LOAD_V: tl.constexpr, + MASK_STEPS: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + RETURN_ENCODED_SOFTMAX: tl.constexpr, + PADDED_HEAD: tl.constexpr, + USE_FP8: tl.constexpr, + qk_scale, + p_descale, ): - # loop over k, v, and update accumulator for start_n in range(block_min, block_max, BLOCK_N): # For padded blocks, we will overrun the tensor size if # we load all BLOCK_N. For others, the blocks are all within range. - k_offs_n = start_n + tl.arange(0, - BLOCK_N) if SHOULD_MASK_STEPS else None - k_offs_k = None if not USE_PADDED_HEAD else tl.arange(0, BLOCK_DMODEL) - k = masked_load(k_ptrs, k_offs_k, k_offs_n, IS_ACTUAL_BLOCK_DMODEL, - actual_seqlen_k) - if SHOULD_PRE_LOAD_V: - # We can use the same offsets as k, just with dims transposed. - v = masked_load(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, - IS_ACTUAL_BLOCK_DMODEL) + k = load_fn( + K_block_ptr, + PADDED_HEAD, + MASK_STEPS and (n_extra_tokens != 0), + "zero", + ) + if PRE_LOAD_V: + v = load_fn( + V_block_ptr, + MASK_STEPS and (n_extra_tokens != 0), + PADDED_HEAD, + "zero", + ) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # We start from end of seqlen_k so only the first iteration would need # to be checked for padding if it is not a multiple of block_n # TODO: This can be optimized to only be true for the padded block. 
- if SHOULD_MASK_STEPS: # noqa: SIM102 + if MASK_STEPS: # noqa: SIM102 # If this is the last block / iteration, we want to # mask if the sequence length is not a multiple of block size - # a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not - # is_modulo_mn. last step might get wasted but that is okay. + # a solution is to always do BLOCK_M // BLOCK_N + 1 steps + # if not is_modulo_mn. last step might get wasted but that is okay. # check if this masking works for that case. if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): boundary_m = tl.full([BLOCK_M], @@ -404,97 +149,112 @@ def _attn_fwd_inner( causal_boundary = start_n + offs_n_causal causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] qk = tl.where(causal_mask, qk, float("-inf")) - # -- compute qk ---- - if IS_EIGHT_BIT_GEMM: - qk += ((((tl.dot(q, k).to(tl.float32) * q_descale)) * k_descale) * - QK_SCALE) - else: - if IS_EIGHT_BIT_KV: - k = (k * k_descale).to(q.type.element_ty) - qk += (tl.dot(q, k) * QK_SCALE) - - if bias_ptrs is not None: - bias_offs_n = start_n + tl.arange( - 0, BLOCK_N) if SHOULD_MASK_STEPS else None - bias = masked_load(bias_ptrs, OFFS_M, bias_offs_n, actual_seqlen_q, - actual_seqlen_k) - # While bias is added after multiplying qk with sm_scale, - # our optimization to use 2^x instead of e^x results in an - # additional scale factor of log2(e) which we must also multiply - # the bias with. 
- qk += (bias * 1.44269504089) - - if alibi_slope is not None: - # Compute the global position of each token within the sequence - global_m_positions = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - global_n_positions = start_n + tl.arange(0, BLOCK_N) - alibi_block = compute_alibi_block(alibi_slope, actual_seqlen_q, - actual_seqlen_k, - global_m_positions, - global_n_positions) - qk += (alibi_block * 1.44269504089) # scale factor of log2(e) - - # softmax + qk += tl.dot(q, k) + if USE_FP8: + qk *= qk_scale + if bias_ptr is not None: + bias = load_fn(bias_ptr, False, MASK_STEPS + and (n_extra_tokens != 0), "zero") + # While bias is added after multiplying qk with sm_scale, our + # optimization to use 2^x instead of e^x results in an additional + # scale factor of log2(e) which we must also multiply the bias with. + qk += bias * 1.44269504089 m_ij = tl.maximum(m_i, tl.max(qk, 1)) qk = qk - m_ij[:, None] p = tl.math.exp2(qk) # CAVEAT: Must update l_ij before applying dropout l_ij = tl.sum(p, 1) - if SHOULD_RETURN_ENCODED_SOFTMAX: - tl.store(encoded_sm_ptrs, p.to(encoded_sm_ptrs.type.element_ty)) + if ENABLE_DROPOUT: + philox_offset = (batch_philox_offset + + start_m * BLOCK_M * actual_seqlen_k + start_n - + BLOCK_N) + keep = dropout_mask( + philox_seed, + philox_offset, + dropout_p, + BLOCK_M, + BLOCK_N, + actual_seqlen_k, + ) + if RETURN_ENCODED_SOFTMAX: + tl.store( + encoded_softmax_block_ptr, + tl.where(keep, p, + -p).to(encoded_softmax_block_ptr.type.element_ty), + ) + p = tl.where(keep, p, 0.0) + elif RETURN_ENCODED_SOFTMAX: + tl.store( + encoded_softmax_block_ptr, + p.to(encoded_softmax_block_ptr.type.element_ty), + ) # -- update output accumulator -- alpha = tl.math.exp2(m_i - m_ij) acc = acc * alpha[:, None] - if not SHOULD_PRE_LOAD_V: - v = masked_load(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, - IS_ACTUAL_BLOCK_DMODEL) + if not PRE_LOAD_V: + v = load_fn( + V_block_ptr, + MASK_STEPS and (n_extra_tokens != 0), + PADDED_HEAD, + "zero", + ) # -- update m_i and l_i 
l_i = l_i * alpha + l_ij # update m_i and l_i m_i = m_ij - if IS_EIGHT_BIT_GEMM: - if USE_P_SCALE: - p = quant_fp8(p, p_scale).to(QUANT_DTYPE) - acc += tl.dot(p, v) - else: - # v is in eight_bit but p is not, we want the gemm in p's type - acc += tl.dot(p, v.to(p.type.element_ty)) - else: - if IS_EIGHT_BIT_KV: - v = (v * v_descale).to(p.type.element_ty) - acc += tl.dot(p.to(v.type.element_ty), v) + if USE_FP8: + p *= p_descale - k_ptrs += BLOCK_N * stride_kn - v_ptrs += BLOCK_N * stride_vk - if bias_ptrs is not None: - bias_ptrs += BLOCK_N * stride_bn - if SHOULD_RETURN_ENCODED_SOFTMAX: - encoded_sm_ptrs += BLOCK_N + acc += tl.dot(p.to(V_block_ptr.type.element_ty), v) + + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + if bias_ptr is not None: + bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N)) + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, + (0, BLOCK_N)) return acc, l_i, m_i def get_cdna_autotune_configs(): return [ + triton.Config( + { + 'BLOCK_M': 256, + 'BLOCK_N': 64, + 'waves_per_eu': 2, + 'PRE_LOAD_V': False + }, + num_stages=1, + num_warps=8), triton.Config( { 'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, - 'SHOULD_PRE_LOAD_V': False, - 'GRID_CU_MULTIP': 2 + 'PRE_LOAD_V': False }, num_stages=1, num_warps=4), + triton.Config( + { + 'BLOCK_M': 256, + 'BLOCK_N': 128, + 'waves_per_eu': 2, + 'PRE_LOAD_V': False + }, + num_stages=1, + num_warps=8), triton.Config( { 'BLOCK_M': 128, 'BLOCK_N': 64, - 'waves_per_eu': 2, - 'SHOULD_PRE_LOAD_V': False, - 'GRID_CU_MULTIP': 2 + 'waves_per_eu': 1, + 'PRE_LOAD_V': False }, num_stages=1, num_warps=4), @@ -503,8 +263,7 @@ def get_cdna_autotune_configs(): 'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, - 'SHOULD_PRE_LOAD_V': False, - 'GRID_CU_MULTIP': 2 + 'PRE_LOAD_V': True }, num_stages=1, num_warps=4), @@ -512,26 +271,45 @@ def get_cdna_autotune_configs(): { 'BLOCK_M': 128, 'BLOCK_N': 64, - 'waves_per_eu': 1, - 
'SHOULD_PRE_LOAD_V': False, - 'GRID_CU_MULTIP': 2 + 'waves_per_eu': 3, + 'PRE_LOAD_V': False }, num_stages=1, num_warps=4), triton.Config( { - 'BLOCK_M': 128, - 'BLOCK_N': 32, - 'waves_per_eu': 2, - 'SHOULD_PRE_LOAD_V': False, - 'GRID_CU_MULTIP': 2 + 'BLOCK_M': 64, + 'BLOCK_N': 64, + 'waves_per_eu': 4, + 'PRE_LOAD_V': False }, num_stages=1, - num_warps=4), - ], [ - 'IS_CAUSAL', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', - 'IS_ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK' - ] + num_warps=8), + triton.Config( + { + 'BLOCK_M': 32, + 'BLOCK_N': 32, + 'waves_per_eu': 4, + 'PRE_LOAD_V': False + }, + num_stages=1, + num_warps=8), + # TODO: This config fails with head_size not pow2 with data mismatches. + # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, + # 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + + # Fails in AccelerateAMDMatmul (Triton) assert when using FP8: + # triton.Config( + # { + # "BLOCK_M": 16, + # "BLOCK_N": 16, + # "waves_per_eu": 1, + # "PRE_LOAD_V": False, + # }, + # num_stages=1, + # num_warps=4, + # ), + ], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8'] def get_rdna_autotune_configs(): @@ -541,8 +319,7 @@ def get_rdna_autotune_configs(): 'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, - 'SHOULD_PRE_LOAD_V': False, - 'GRID_CU_MULTIP': 2 + 'PRE_LOAD_V': False }, num_stages=1, num_warps=2), @@ -551,8 +328,7 @@ def get_rdna_autotune_configs(): 'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 2, - 'SHOULD_PRE_LOAD_V': False, - 'GRID_CU_MULTIP': 2 + 'PRE_LOAD_V': False }, num_stages=1, num_warps=2), @@ -561,8 +337,7 @@ def get_rdna_autotune_configs(): 'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 4, - 'SHOULD_PRE_LOAD_V': False, - 'GRID_CU_MULTIP': 2 + 'PRE_LOAD_V': False }, num_stages=1, num_warps=2), @@ -571,109 +346,57 @@ def get_rdna_autotune_configs(): 'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 2, - 'SHOULD_PRE_LOAD_V': False, - 'GRID_CU_MULTIP': 2 + 'PRE_LOAD_V': False }, num_stages=1, num_warps=2), - triton.Config( - { - 'BLOCK_M': 
16, - 'BLOCK_N': 16, - 'waves_per_eu': 4, - 'SHOULD_PRE_LOAD_V': False, - 'GRID_CU_MULTIP': 2 - }, - num_stages=1, - num_warps=2), - triton.Config( - { - 'BLOCK_M': 16, - 'BLOCK_N': 16, - 'waves_per_eu': 2, - 'SHOULD_PRE_LOAD_V': False, - 'GRID_CU_MULTIP': 2 - }, - num_stages=1, - num_warps=2), - # Fall-back config. - triton.Config( - { - 'BLOCK_M': 16, - 'BLOCK_N': 16, - 'waves_per_eu': 1, - 'SHOULD_PRE_LOAD_V': False, - 'GRID_CU_MULTIP': 2 - }, - num_stages=1, - num_warps=2), - ], [ - 'IS_CAUSAL', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', - 'IS_ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK' - ] - - -def get_general_autotune_configs(): - return [ - triton.Config( - { - 'BLOCK_M': 128, - 'BLOCK_N': 128, - 'SHOULD_PRE_LOAD_V': False, - 'GRID_CU_MULTIP': 2 - }, - num_stages=1, - num_warps=4), - triton.Config( - { - 'BLOCK_M': 128, - 'BLOCK_N': 64, - 'SHOULD_PRE_LOAD_V': False, - 'GRID_CU_MULTIP': 2 - }, - num_stages=1, - num_warps=4), - triton.Config( - { - 'BLOCK_M': 128, - 'BLOCK_N': 32, - 'SHOULD_PRE_LOAD_V': False, - 'GRID_CU_MULTIP': 2 - }, - num_stages=1, - num_warps=4), - ], [ - 'IS_CAUSAL', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', - 'IS_ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK' - ] - - -def has_cdna_target(): - ROCM_CDNA_TARGETS = ["gfx942", "gfx90a", "gfx908"] - return triton.runtime.driver.active.get_current_target( - ).arch in ROCM_CDNA_TARGETS - - -def is_rocm_cdna(): - return current_platform.is_rocm() and has_cdna_target() + # Fails in AccelerateAMDMatmul (Triton) assert when using FP8: + # triton.Config( + # { + # 'BLOCK_M': 16, + # 'BLOCK_N': 16, + # 'waves_per_eu': 4, + # 'PRE_LOAD_V': False + # }, + # num_stages=1, + # num_warps=2), + # triton.Config( + # { + # 'BLOCK_M': 16, + # 'BLOCK_N': 16, + # 'waves_per_eu': 2, + # 'PRE_LOAD_V': False + # }, + # num_stages=1, + # num_warps=2), + # # Fall-back config. 
+ # triton.Config( + # { + # 'BLOCK_M': 16, + # 'BLOCK_N': 16, + # 'waves_per_eu': 1, + # 'PRE_LOAD_V': False + # }, + # num_stages=1, + # num_warps=2), + ], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8'] def get_autotune_configs(): - if is_rocm_cdna(): - return get_cdna_autotune_configs() - elif current_platform.is_rocm(): + if on_gfx1x(): return get_rdna_autotune_configs() else: - return get_general_autotune_configs() + return get_cdna_autotune_configs() autotune_configs, autotune_keys = get_autotune_configs() +float8_info = torch.finfo(current_platform.fp8_dtype()) + @triton.autotune( configs=autotune_configs, key=autotune_keys, - use_cuda_graph=True, ) @triton.jit def attn_fwd( @@ -681,7 +404,13 @@ def attn_fwd( K, V, bias, - SM_SCALE: tl.constexpr, + sm_scale, + q_scale, + k_scale, + v_scale, + p_scale, + p_descale, + o_descale, L, Out, stride_qz: tl.int64, @@ -704,70 +433,44 @@ def attn_fwd( stride_bh: tl.int64, stride_bm: tl.int64, stride_bn: tl.int64, - stride_az: tl.int64, - stride_ah: tl.int64, - q_descale_ptr, - k_descale_ptr, - p_scale_ptr, - p_descale_ptr, - o_descale_ptr, - v_descale_ptr, - q_descale_has_singleton: tl.constexpr, - k_descale_has_singleton: tl.constexpr, - p_descale_has_singleton: tl.constexpr, - v_descale_has_singleton: tl.constexpr, cu_seqlens_q, cu_seqlens_k, + dropout_p, philox_seed, - NUM_CU: tl.constexpr, - GRID_CU_MULTIP: tl.constexpr, - B: tl.constexpr, philox_offset_base, encoded_softmax, - alibi_slopes, HQ: tl.constexpr, HK: tl.constexpr, - IS_ACTUAL_BLOCK_DMODEL: tl.constexpr, + ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr, MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + USE_FP8: tl.constexpr, + USE_FP8_OUT: tl.constexpr, BLOCK_N: tl.constexpr, - SHOULD_PRE_LOAD_V: tl.constexpr, - USE_BIAS: tl.constexpr, - SHOULD_RETURN_ENCODED_SOFTMAX: tl.constexpr, - USE_ALIBI: tl.constexpr, - IS_EIGHT_BIT: tl.constexpr, - USE_P_SCALE: 
tl.constexpr, - IS_EIGHT_BIT_KV: tl.constexpr, - QUANT_DTYPE: tl.constexpr = default_eight_bit_dtype_triton, + PRE_LOAD_V: tl.constexpr, + BIAS_TYPE: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + RETURN_ENCODED_SOFTMAX: tl.constexpr, + FP8_MIN: tl.constexpr = float8_info.min, + FP8_MAX: tl.constexpr = float8_info.max, ): - - if o_descale_ptr is not None: - o_descale = tl.load(o_descale_ptr) - - start_m: tl.int64 = tl.program_id(0) - off_h_q: tl.int64 = tl.program_id(1) - off_z: tl.int64 = tl.program_id(2) - - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M).to(tl.int64) - offs_n = tl.arange(0, BLOCK_N).to(tl.int64) - offs_d = tl.arange(0, BLOCK_DMODEL).to(tl.int64) - - # as we can't have return statements inside while loop in Triton - continue_condition = True - + start_m = tl.program_id(0) + off_h_q = tl.program_id(1) + off_z = tl.program_id(2) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) if VARLEN: cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start - # We have a one-size-fits-all grid in id(0). Some seqlens might be - # too small for all start_m so for those we return early. + # We have a one-size-fits-all grid in id(0). Some seqlens might be too + # small for all start_m so for those we return early. if start_m * BLOCK_M > seqlen_q: - continue_condition = False - # return + return cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start @@ -777,598 +480,499 @@ def attn_fwd( seqlen_q = MAX_SEQLENS_Q seqlen_k = MAX_SEQLENS_K - if continue_condition: - # Now we compute whether we need to exit early due to causal - # masking. This is because for seqlen_q > seqlen_k, M rows of the - # attn scores are completely masked, resulting in 0s written to the - # output, and inf written to LSE. 
We don't need to do any GEMMs in - # this case. This block of code determines what N is, and if this - # WG is operating on those M rows. - n_blocks = cdiv_fn(seqlen_k, BLOCK_N) - if (IS_CAUSAL): - # If seqlen_q == seqlen_k, the attn scores are a square matrix. - # If seqlen_q != seqlen_k, attn scores are rectangular which - # means the causal mask boundary is bottom right aligned, and - # ends at either the top edge (seqlen_q < seqlen_k) or left - # edge. This captures the decrease in n_blocks if we have a - # rectangular attn matrix - n_blocks_seqlen = cdiv_fn( - (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) - # This is what adjusts the block_max for the current WG, only - # if IS_CAUSAL. Otherwise we want to always iterate through all - # n_blocks - n_blocks = min(n_blocks, n_blocks_seqlen) - # If we have no blocks after adjusting for seqlen deltas, this - # WG is part of the blocks that are all 0. We exit early. - if n_blocks <= 0: - o_offset = (Out + off_z * stride_oz + off_h_q * stride_oh + - cu_seqlens_q_start * stride_om) - o_ptrs = (o_offset + offs_m[:, None] * stride_om + - offs_d[None, :] * stride_on) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - o_ptrs_mask = (offs_m[:, None] < seqlen_q).broadcast_to( - [BLOCK_M, BLOCK_DMODEL]) - # We still need to write 0s to the result - tl.store(o_ptrs, acc, mask=o_ptrs_mask) - # The tensor allocated for L is based on MAX_SEQLENS_Q as - # that is statically known. - l_ptrs = (L + off_z * HQ * MAX_SEQLENS_Q + - off_h_q * MAX_SEQLENS_Q + offs_m) - # We store inf to LSE, not -inf because in the bwd pass, - # we subtract this from qk which makes it -inf, such that - # exp(qk - inf) = 0 for these masked blocks. - l_value = tl.full([BLOCK_M], - value=float("inf"), - dtype=tl.float32) - l_ptrs_mask = offs_m < MAX_SEQLENS_Q - tl.store(l_ptrs, l_value, mask=l_ptrs_mask) - # TODO: Should dropout and return encoded softmax be - # handled here too? 
- continue_condition = False - # return + # Now we compute whether we need to exit early due to causal masking. + # This is because for seqlen_q > seqlen_k, M rows of the attn scores + # are completely masked, resulting in 0s written to the output, and + # inf written to LSE. We don't need to do any GEMMs in this case. + # This block of code determines what N is, and if this WG is operating + # on those M rows. + n_blocks = cdiv_fn(seqlen_k, BLOCK_N) + if IS_CAUSAL: + # If seqlen_q == seqlen_k, the attn scores are a square matrix. + # If seqlen_q != seqlen_k, attn scores are rectangular which means + # the causal mask boundary is bottom right aligned, and ends at either + # the top edge (seqlen_q < seqlen_k) or left edge. + # This captures the decrease in n_blocks if we have a rectangular attn + # matrix + n_blocks_seqlen = cdiv_fn( + (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) + # This is what adjusts the block_max for the current WG, only + # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks + n_blocks = min(n_blocks, n_blocks_seqlen) + # If we have no blocks after adjusting for seqlen deltas, this WG is + # part of the blocks that are all 0. We exit early. + if n_blocks <= 0: + o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + + off_h_q * stride_oh) + O_block_ptr = tl.make_block_ptr( + base=Out + o_offset, + shape=(seqlen_q, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) + # We still need to write 0s to the result + # tl.store(O_block_ptr, + # acc.to(Out.type.element_ty), boundary_check=(0,1)) + # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + # + offs_m + # We store inf to LSE, not -inf because in the bwd pass, + # we subtract this + # from qk which makes it -inf, such that exp(qk - inf) = 0 + # for these masked blocks. 
+ # l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) + # tl.store(l_ptrs, l) + # TODO: Should dropout and return encoded softmax be handled here? + return - if continue_condition: - # If MQA / GQA, set the K and V head offsets appropriately. - GROUP_SIZE: tl.constexpr = HQ // HK - off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q - n_extra_tokens = 0 - if seqlen_k < BLOCK_N: - n_extra_tokens = BLOCK_N - seqlen_k - elif seqlen_k % BLOCK_N: - n_extra_tokens = seqlen_k % BLOCK_N - USE_PADDED_HEAD: tl.constexpr = (IS_ACTUAL_BLOCK_DMODEL - != BLOCK_DMODEL) + # If MQA / GQA, set the K and V head offsets appropriately. + GROUP_SIZE: tl.constexpr = HQ // HK + off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q - # Compute pointers for all the tensors used in this kernel. - q_offset = (Q + off_z * stride_qz + off_h_q * stride_qh + - cu_seqlens_q_start * stride_qm) - q_ptrs = (q_offset + offs_m[:, None] * stride_qm + - offs_d[None, :] * stride_qk) - k_offset = (K + off_z * stride_kz + off_h_k * stride_kh + - cu_seqlens_k_start * stride_kn) - k_ptrs = (k_offset + offs_d[:, None] * stride_kk + - offs_n[None, :] * stride_kn) - v_offset = (V + off_z * stride_vz + off_h_k * stride_vh + - cu_seqlens_k_start * stride_vk) - v_ptrs = (v_offset + offs_n[:, None] * stride_vk + - offs_d[None, :] * stride_vn) - # Compute pointers for all scale tensors used in this kernel. + n_extra_tokens = 0 + if seqlen_k < BLOCK_N: + n_extra_tokens = BLOCK_N - seqlen_k + elif seqlen_k % BLOCK_N: + n_extra_tokens = seqlen_k % BLOCK_N + padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL - IS_EIGHT_BIT_GEMM: tl.constexpr = IS_EIGHT_BIT & ( - not IS_EIGHT_BIT_KV) - if IS_EIGHT_BIT: - if k_descale_has_singleton: - k_descale_ptrs = k_descale_ptr - else: - k_descale_ptrs = k_descale_ptr + off_h_k + # Compute pointers for all the tensors used in this kernel. 
+ q_offset = (off_z * stride_qz + off_h_q * stride_qh + + cu_seqlens_q_start * stride_qm) + Q_block_ptr = tl.make_block_ptr( + base=Q + q_offset, + shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + k_offset = (off_z * stride_kz + off_h_k * stride_kh + + cu_seqlens_k_start * stride_kn) + K_block_ptr = tl.make_block_ptr( + base=K + k_offset, + shape=(ACTUAL_BLOCK_DMODEL, seqlen_k), + strides=(stride_kk, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1), + ) + v_offset = (off_z * stride_vz + off_h_k * stride_vh + + cu_seqlens_k_start * stride_vk) + V_block_ptr = tl.make_block_ptr( + base=V + v_offset, + shape=(seqlen_k, ACTUAL_BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0), + ) + if BIAS_TYPE != 0: + bias_ptr = tl.make_block_ptr( + base=bias + off_h_q * stride_bh, + shape=(seqlen_q, seqlen_k), + strides=(stride_bm, stride_bn), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + else: + bias_ptr = None + if ENABLE_DROPOUT: + batch_philox_offset = philox_offset_base \ + + (off_z * HQ + off_h_q) \ + * seqlen_q * seqlen_k + else: + batch_philox_offset = 0 + # We can ask to return the dropout mask without actually doing any dropout. + # In this case, we return an invalid pointer so indicate the mask is not i + # valid. + # TODO: Fix encoded softmax. It currently uses just h_q in the base offset. 
+ if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.make_block_ptr( + base=encoded_softmax + off_h_q * seqlen_q * seqlen_k, + shape=(seqlen_q, seqlen_k), + strides=(seqlen_k, 1), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + else: + encoded_softmax_block_ptr = 0 + # initialize pointer to m and l + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # scale sm_scale by log_2(e) and use 2^x in the loop as we do not + # have native e^x support in HW. + qk_scale = sm_scale * 1.44269504089 + # Q is loaded once at the beginning and shared by all N blocks. + q = load_fn(Q_block_ptr, True, padded_head, "zero") + if not USE_FP8: + q = (q * qk_scale).to(Q_block_ptr.type.element_ty) + acc_scale = 1.0 + else: + qk_scale *= q_scale * k_scale + acc_scale = p_scale * v_scale - if v_descale_has_singleton: - v_descale_ptrs = v_descale_ptr - else: - v_descale_ptrs = v_descale_ptr + off_h_k + # Here we compute how many full and masked blocks we have. + padded_block_k = n_extra_tokens != 0 + is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) + if IS_CAUSAL: + # There are always at least BLOCK_M // BLOCK_N masked blocks. + # Additionally there might be one more due to dissimilar seqlens. + masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) + else: + # Padding on Q does not need to be masked in the FA loop. + masked_blocks = padded_block_k + # if IS_CAUSAL, not is_modulo_mn does not always result in an additional + # block. In this case we might exceed n_blocks so pick the min. + masked_blocks = min(masked_blocks, n_blocks) + n_full_blocks = n_blocks - masked_blocks + block_min = 0 + block_max = n_blocks * BLOCK_N + # Compute for full blocks. Here we set causal to false regardless of its + # value because there is no masking. Similarly we do not need padding. 
+ if n_full_blocks > 0: + block_max = (n_blocks - masked_blocks) * BLOCK_N + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + start_m, + seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + encoded_softmax_block_ptr, + # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ + block_min, + block_max, + 0, + 0, + 0, + bias_ptr, + # IS_CAUSAL, .... + False, + BLOCK_M, + BLOCK_DMODEL, + BLOCK_N, + offs_m, + offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, + False, + ENABLE_DROPOUT, + RETURN_ENCODED_SOFTMAX, + padded_head, + USE_FP8, + qk_scale, + p_descale, + ) + block_min = block_max + block_max = n_blocks * BLOCK_N - if not IS_EIGHT_BIT_KV: - if q_descale_has_singleton: - q_descale_ptrs = q_descale_ptr - else: - q_descale_ptrs = q_descale_ptr + off_h_q - if USE_P_SCALE: - if p_descale_has_singleton: - p_scale_ptrs = p_scale_ptr - p_descale_ptrs = p_descale_ptr - else: - p_scale_ptrs = p_scale_ptr + off_h_q - p_descale_ptrs = p_descale_ptr + off_h_q + tl.debug_barrier() + # Remaining blocks, if any, are full / not masked. + if masked_blocks > 0: + offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0 + K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0)) + if bias_ptr is not None: + bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N)) + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, + (0, n_full_blocks)) + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + start_m, + seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + encoded_softmax_block_ptr, + block_min, + block_max, + offs_n_causal, + masked_blocks, + n_extra_tokens, + bias_ptr, + IS_CAUSAL, + BLOCK_M, + BLOCK_DMODEL, + BLOCK_N, + offs_m, + offs_n, + # _, MASK_STEPS, ... 
+ PRE_LOAD_V, + True, + ENABLE_DROPOUT, + RETURN_ENCODED_SOFTMAX, + padded_head, + USE_FP8, + qk_scale, + p_descale, + ) + # epilogue - if USE_BIAS: - bias_offset = off_h_q * stride_bh - bias_ptrs = (bias + bias_offset + offs_m[:, None] * stride_bm + - offs_n[None, :] * stride_bn) - else: - bias_ptrs = None + if USE_FP8: + acc *= acc_scale + acc = acc / l_i[:, None] + if ENABLE_DROPOUT: + acc = acc / (1 - dropout_p) + # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, + # then we have one block with a row of all NaNs which come from computing + # softmax over a row of all -infs (-inf - inf = NaN). We check for that here + # and store 0s where there are NaNs as these rows should've been zeroed out. + end_m_idx = (start_m + 1) * BLOCK_M + start_m_idx = start_m * BLOCK_M + causal_start_idx = seqlen_q - seqlen_k + if USE_FP8_OUT: + acc *= o_descale + acc = tl.clamp(acc, FP8_MIN, FP8_MAX) + acc = acc.to(Out.type.element_ty) + if IS_CAUSAL: # noqa: SIM102 + if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: + out_mask_boundary = tl.full((BLOCK_DMODEL, ), + causal_start_idx, + dtype=tl.int32) + mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) + out_ptrs_mask = (mask_m_offsets[:, None] + >= out_mask_boundary[None, :]) + z = tl.zeros((1, ), tl.float32) + acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) + # write back LSE + # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m + # If seqlen_q not multiple of BLOCK_M, we need to mask out the last + # few rows. This is only true for the last M block. For others, + # overflow_size will be -ve + # overflow_size = end_m_idx - seqlen_q + # if overflow_size > 0: + # boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32) + # # This is a > check because mask being 0 blocks the store. 
+ # l_ptrs_mask = boundary > tl.arange(0, BLOCK_M) + # tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) + # else: + # tl.store(l_ptrs, m_i + tl.math.log2(l_i)) - if USE_ALIBI: - a_offset = off_z * stride_az + off_h_q * stride_ah - alibi_slope = tl.load(alibi_slopes + a_offset) - else: - alibi_slope = None - - batch_philox_offset = 0 - # We can ask to return the dropout mask without doing any - # dropout. In this case, we return an invalid pointer so - # indicate the mask is not valid. - if SHOULD_RETURN_ENCODED_SOFTMAX: - encoded_sm_base = (encoded_softmax + - off_h_q * seqlen_q * seqlen_k) - encoded_sm_ptrs = (encoded_sm_base + - offs_m[:, None] * seqlen_k + - offs_n[None, :]) - else: - encoded_sm_ptrs = None - # initialize pointer to m and l - m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) - l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - # scale sm_scale by log_2(e) and use 2^x in the loop as we do - # not have native e^x support in HW. - QK_SCALE: tl.constexpr = SM_SCALE * 1.44269504089 - # Q is loaded once at the beginning and shared by all N blocks. - q_ptrs_mask = offs_m[:, None] < seqlen_q - if USE_PADDED_HEAD: - q_ptrs_mask = q_ptrs_mask & (offs_d[None, :] - < IS_ACTUAL_BLOCK_DMODEL) - q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0) - - if IS_EIGHT_BIT: - k_descale = tl.load(k_descale_ptrs) - v_descale = tl.load(v_descale_ptrs) - q_descale = None if IS_EIGHT_BIT_KV else tl.load( - q_descale_ptrs) - if USE_P_SCALE: - p_scale = tl.load(p_scale_ptrs) - p_descale = tl.load(p_descale_ptrs) - else: - p_scale = None - p_descale = None - else: - q_descale = None - k_descale = None - v_descale = None - p_scale = None - p_descale = None - # Here we compute how many full and masked blocks we have. - padded_block_k = n_extra_tokens != 0 - is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) - if IS_CAUSAL: - # There are always at least BLOCK_M // BLOCK_N masked - # blocks. 
Additionally there might be one more due to - # dissimilar seqlens. - masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) - else: - # Padding on Q does not need to be masked in the FA loop. - masked_blocks = padded_block_k - # if IS_CAUSAL, not is_modulo_mn does not always result in an - # additional block. In this case we might exceed n_blocks so - # pick the min. - masked_blocks = min(masked_blocks, n_blocks) - n_full_blocks = n_blocks - masked_blocks - block_min = 0 - block_max = n_blocks * BLOCK_N - # Compute for full blocks. Here we set causal to false - # regardless of its actual value because there is no masking. - # Similarly we do not need padding. - if n_full_blocks > 0: - block_max = (n_blocks - masked_blocks) * BLOCK_N - acc, l_i, m_i = _attn_fwd_inner( - acc, - l_i, - m_i, - q, - k_ptrs, - v_ptrs, - bias_ptrs, - stride_kn, - stride_vk, - stride_bn, - start_m, - seqlen_k, - seqlen_q, - philox_seed, - batch_philox_offset, - encoded_sm_ptrs, - # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ - block_min, - block_max, - 0, - 0, - 0, - alibi_slope, - q_descale, - k_descale, - v_descale, - p_scale, - # IS_CAUSAL, .... - False, - BLOCK_M, - BLOCK_DMODEL, - BLOCK_N, - offs_m, - offs_n, - # _, SHOULD_MASK_STEPS, ... - SHOULD_PRE_LOAD_V, - False, - SHOULD_RETURN_ENCODED_SOFTMAX, - USE_PADDED_HEAD, - IS_ACTUAL_BLOCK_DMODEL, - QK_SCALE, - IS_EIGHT_BIT_GEMM, - USE_P_SCALE, - IS_EIGHT_BIT_KV, - QUANT_DTYPE) - block_min = block_max - block_max = n_blocks * BLOCK_N - - tl.debug_barrier() - # Remaining blocks, if any, are full / not masked. 
- if (masked_blocks > 0): - if IS_CAUSAL: - offs_n_causal = offs_n + (seqlen_q - seqlen_k) - else: - offs_n_causal = 0 - k_ptrs += n_full_blocks * BLOCK_N * stride_kn - v_ptrs += n_full_blocks * BLOCK_N * stride_vk - if USE_BIAS: - bias_ptrs += n_full_blocks * BLOCK_N * stride_bn - if SHOULD_RETURN_ENCODED_SOFTMAX: - encoded_sm_ptrs += n_full_blocks * BLOCK_N - acc, l_i, m_i = _attn_fwd_inner( - acc, - l_i, - m_i, - q, - k_ptrs, - v_ptrs, - bias_ptrs, - stride_kn, - stride_vk, - stride_bn, - start_m, - seqlen_k, - seqlen_q, - philox_seed, - batch_philox_offset, - encoded_sm_ptrs, - block_min, - block_max, - offs_n_causal, - masked_blocks, - n_extra_tokens, - alibi_slope, - q_descale, - k_descale, - v_descale, - p_scale, - IS_CAUSAL, - BLOCK_M, - BLOCK_DMODEL, - BLOCK_N, - offs_m, - offs_n, - # _, SHOULD_MASK_STEPS, ... - SHOULD_PRE_LOAD_V, - True, - SHOULD_RETURN_ENCODED_SOFTMAX, - USE_PADDED_HEAD, - IS_ACTUAL_BLOCK_DMODEL, - QK_SCALE, - IS_EIGHT_BIT_GEMM, - USE_P_SCALE, - IS_EIGHT_BIT_KV, - QUANT_DTYPE) - - if IS_EIGHT_BIT and not IS_EIGHT_BIT_KV: - if USE_P_SCALE: - acc *= p_descale - acc *= v_descale - - # epilogue - # This helps the compiler do Newton Raphson on l_i vs on acc - # which is much larger. - l_recip = 1 / l_i[:, None] - acc = acc * l_recip - - # If seqlen_q > seqlen_k but the delta is not a multiple of - # BLOCK_M, then we have one block with a row of all NaNs which - # come from computing softmax over a row of all - # -infs (-inf - inf = NaN). We check for that here and store 0s - # where there are NaNs as these rows should've been zeroed out. 
- end_m_idx = (start_m + 1) * BLOCK_M - start_m_idx = start_m * BLOCK_M - causal_start_idx = seqlen_q - seqlen_k - if IS_EIGHT_BIT and not IS_EIGHT_BIT_KV: # noqa: SIM102 - if o_descale_ptr is not None: - acc = quant_fp8(acc, o_descale) - - acc = acc.to(Out.type.element_ty) - if IS_CAUSAL: # noqa: SIM102 - if (causal_start_idx > start_m_idx - and causal_start_idx < end_m_idx): - out_mask_boundary = tl.full((BLOCK_DMODEL, ), - causal_start_idx, - dtype=tl.int32) - mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) - out_ptrs_mask = (mask_m_offsets[:, None] - >= out_mask_boundary[None, :]) - z = tl.zeros((1, ), tl.float32) - acc = tl.where(out_ptrs_mask, acc, - z.to(acc.type.element_ty)) - # write back LSE - l_ptrs = (L + off_z * HQ * MAX_SEQLENS_Q + - off_h_q * MAX_SEQLENS_Q + offs_m) - # If seqlen_q not multiple of BLOCK_M, we need to mask out the - # last few rows. This is only true for the last M block. - # For others, overflow_size will be -ve - overflow_size = end_m_idx - seqlen_q - if overflow_size > 0: - boundary = tl.full((BLOCK_M, ), - BLOCK_M - overflow_size, - dtype=tl.int32) - l_ptrs_mask = tl.arange(0, BLOCK_M) < boundary - tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) - else: - tl.store(l_ptrs, m_i + tl.math.log2(l_i)) - - # write back O - o_offset = (Out + off_z * stride_oz + off_h_q * stride_oh + - cu_seqlens_q_start * stride_om) - o_ptrs = (o_offset + offs_m[:, None] * stride_om + - offs_d[None, :] * stride_on) - o_ptrs_mask = tl.full([BLOCK_M, BLOCK_DMODEL], 1, dtype=tl.int1) - if overflow_size > 0: - o_ptrs_mask = o_ptrs_mask & (offs_m[:, None] < seqlen_q) - if USE_PADDED_HEAD: - o_ptrs_mask = o_ptrs_mask & (offs_d[None, :] - < IS_ACTUAL_BLOCK_DMODEL) - tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=o_ptrs_mask) + # write back O + o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + + off_h_q * stride_oh) + O_block_ptr = tl.make_block_ptr( + base=Out + o_offset, + shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), + 
strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + # Need boundary check on this to make sure the padding from the + # Q and KV tensors in both dims are not part of what we store back. + # TODO: Do the boundary check optionally. + tl.store(O_block_ptr, acc, boundary_check=(0, 1)) -def get_shape_from_layout(q, k, metadata): - assert metadata.layout in SUPPORTED_LAYOUTS, "Got unsupported layout." - - if metadata.layout == 'thd': - nheads_q, nheads_k = q.shape[1], k.shape[1] - head_size = q.shape[-1] - batch = metadata.num_contexts - elif metadata.layout == 'bhsd': - batch, nheads_q, _, head_size = q.shape - nheads_k = k.shape[1] - elif metadata.layout == 'bshd': - batch, _, nheads_q, head_size = q.shape - nheads_k = k.shape[2] - return batch, nheads_q, nheads_k, head_size - - -def get_strides_from_layout(q, k, v, o, metadata): - assert metadata.layout in SUPPORTED_LAYOUTS, "Got unsupported layout." - - STRIDE_PERMUTATIONS = { - 'thd': (None, 1, 0, 2), - 'bhsd': (0, 1, 2, 3), - 'bshd': (0, 2, 1, 3), - } - - perm = STRIDE_PERMUTATIONS[metadata.layout] - stride = lambda x, p: (0 if p is None else x.stride(p)) - strides = lambda x: (stride(x, p) for p in perm) - - return tuple(strides(x) for x in [q, k, v, o]) +def check_args( + q, + k, + v, + o, + varlen=True, + max_seqlens=None, + cu_seqlens_q=None, + cu_seqlens_k=None, +): + assert q.dim() == k.dim() and q.dim() == v.dim() + if varlen: + assert q.dim() == 3 + total_q, nheads_q, head_size = q.shape + total_k, nheads_k, _ = k.shape + assert cu_seqlens_q is not None + assert cu_seqlens_k is not None + assert len(cu_seqlens_q) == len(cu_seqlens_k) + else: + assert q.dim() == 4 + batch, nheads_q, seqlen_q, head_size = q.shape + _, nheads_k, seqlen_k, _ = k.shape + assert max_seqlens > 0 + assert k.shape == v.shape + assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] + # TODO: Change assert if we support qkl f8 and v f16 + assert 
q.dtype == k.dtype and q.dtype == v.dtype + assert head_size <= 256 + assert o.shape == q.shape + assert (nheads_q % nheads_k) == 0 class _attention(torch.autograd.Function): @staticmethod - def forward(ctx, q, k, v, o, metadata: MetaData): - # NOTE: a large bias tensor leads to overflow during pointer arithmetic - if (metadata.bias is not None): - assert (metadata.bias.numel() < 2**31) + def forward( + ctx, + q, + k, + v, + o, + cu_seqlens_q, + cu_seqlens_k, + max_seqlens_q, + max_seqlens_k, + causal=False, + sm_scale=1.0, + bias=None, + fp8_scales=None, + fp8_out_scale=None, + ): + if fp8_scales is not None: + use_fp8 = True + (q_scale, k_scale, v_scale, p_scale) = fp8_scales + float8 = current_platform.fp8_dtype() + + def check_and_convert(t, scale): + if t.dtype != float8: + descale = 1.0 / scale + ts = (t * descale).clamp(min=float8_info.min, + max=float8_info.max) + return ts.to(float8) + else: + return t + + q = check_and_convert(q, q_scale) + k = check_and_convert(k, k_scale) + v = check_and_convert(v, v_scale) + else: + use_fp8 = False + q_scale = k_scale = v_scale = p_scale = 1.0 if o is None: - if metadata.eight_bit: - o = torch.empty_like( - q, - dtype=metadata.output_dtype if metadata.output_dtype - is not None else metadata.eight_bit_dtype_torch) - else: - o = torch.empty_like(q, dtype=q.dtype) + o = torch.empty_like(q, dtype=v.dtype) - metadata.check_args(q, k, v, o) - - batch, nheads_q, nheads_k, head_size = get_shape_from_layout( - q, k, metadata) - q_strides, k_strides, v_strides, o_strides = get_strides_from_layout( - q, k, v, o, metadata) + check_args( + q, + k, + v, + o, + varlen=True, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + ) + if True: # varlen + total_q, nheads_q, head_size = q.shape + total_k, nheads_k, _ = k.shape + batch = len(cu_seqlens_q) - 1 + q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) + k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) + v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) + 
o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) + else: + batch, seqlen_q, nheads_q, head_size = q.shape + _, seqlen_k, nheads_k, _ = k.shape + q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) + k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) + v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) + o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) # Get closest power of 2 over or equal to 32. - padded_d_model = 1 << (head_size - 1).bit_length() - # Smallest head_dim supported is 16. If smaller, the tile in the - # kernel is padded - there is no padding in memory for any dims. - padded_d_model = max(padded_d_model, 16) - - # encoded_softmax is used to validate dropout behavior vs the - # PyTorch SDPA math backend reference. We zero this out to give a - # consistent starting point and then populate it with the output of - # softmax with the sign bit set according to the dropout mask. - # The resulting return allows this mask to be fed into the reference - # implementation for testing only. This return holds no useful output - # aside from debugging. - if metadata.return_encoded_softmax: - encoded_softmax = torch.zeros( - (q.shape[0], q.shape[1], q.shape[2], k.shape[2]), - device=q.device, - dtype=torch.float32) + unpadded_head_dims = {32, 64, 128, 256} + if head_size not in unpadded_head_dims: + padded_d_model = None + for i in unpadded_head_dims: + if i > head_size: + padded_d_model = i + break + assert padded_d_model is not None else: - encoded_softmax = None + padded_d_model = head_size - M = torch.empty((batch, nheads_q, metadata.max_seqlens_q), - device=q.device, - dtype=torch.float32) + grid = lambda META: ( + triton.cdiv(max_seqlens_q, META["BLOCK_M"]), + nheads_q, + batch, + ) + + encoded_softmax = None # Seed the RNG so we get reproducible results for testing. 
philox_seed = 0x1BF52 philox_offset = 0x1D4B42 - if metadata.bias is not None: - bias_strides = (metadata.bias.stride(0), metadata.bias.stride(1), - metadata.bias.stride(2), metadata.bias.stride(3)) + if bias is not None: + bias_strides = ( + bias.stride(0), + bias.stride(1), + bias.stride(2), + bias.stride(3), + ) else: bias_strides = (0, 0, 0, 0) - if metadata.alibi_slopes is not None: - alibi_strides = (metadata.alibi_slopes.stride(0), - metadata.alibi_slopes.stride(1)) - else: - alibi_strides = (0, 0) + p_descale = 1.0 / p_scale + o_descale = 1.0 / fp8_out_scale.item( + ) if fp8_out_scale is not None else 1.0 - if metadata.eight_bit: - q_descale, k_descale, p_scale, p_descale, v_descale, o_scale = ( - metadata.q_descale, metadata.k_descale, metadata.p_scale, - metadata.p_descale, metadata.v_descale, metadata.o_scale) - o_descale = 1.0 / o_scale if o_scale is not None else None - else: - q_descale = k_descale = p_scale = None - p_descale = v_descale = o_descale = None - - # number of compute units available - NUM_CU = torch.cuda.get_device_properties("cuda").multi_processor_count - - grid = lambda META: (triton.cdiv(metadata.max_seqlens_q, META[ - 'BLOCK_M']), nheads_q, batch) + arg_max_seqlens_q = 0 if on_gfx1x() else max_seqlens_q + arg_max_seqlens_k = 0 if on_gfx1x() else max_seqlens_k attn_fwd[grid]( q, k, v, - metadata.bias, - metadata.sm_scale, - M, + bias, + sm_scale, + q_scale, + k_scale, + v_scale, + p_scale, + p_descale, + o_descale, + None, o, *q_strides, *k_strides, *v_strides, *o_strides, *bias_strides, - *alibi_strides, - q_descale, - k_descale, - p_scale, - p_descale, - o_descale, - v_descale, - q_descale.numel() == 1 if q_descale is not None else False, - k_descale.numel() == 1 if k_descale is not None else False, - p_descale.numel() == 1 if p_descale is not None else False, - v_descale.numel() == 1 if v_descale is not None else False, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, + cu_seqlens_q, + cu_seqlens_k, + dropout_p=0.0, 
philox_seed=philox_seed, philox_offset_base=philox_offset, encoded_softmax=encoded_softmax, - alibi_slopes=metadata.alibi_slopes, HQ=nheads_q, HK=nheads_k, - IS_ACTUAL_BLOCK_DMODEL=head_size, - MAX_SEQLENS_Q=metadata.max_seqlens_q, - MAX_SEQLENS_K=metadata.max_seqlens_k, - IS_CAUSAL=metadata.causal, - VARLEN=metadata.varlen, + ACTUAL_BLOCK_DMODEL=head_size, + MAX_SEQLENS_Q=arg_max_seqlens_q, + MAX_SEQLENS_K=arg_max_seqlens_k, + IS_CAUSAL=causal, + VARLEN=True, BLOCK_DMODEL=padded_d_model, - USE_BIAS=metadata.bias is not None, - USE_ALIBI=metadata.alibi_slopes is not None, - SHOULD_RETURN_ENCODED_SOFTMAX=metadata.return_encoded_softmax, - IS_EIGHT_BIT=metadata.eight_bit, - USE_P_SCALE=metadata.eight_bit and metadata.use_p_scale, - IS_EIGHT_BIT_KV=metadata.eight_bit and metadata.eight_bit_kv, - NUM_CU=NUM_CU, - B=batch, - QUANT_DTYPE=metadata.eight_bit_dtype_triton) + BIAS_TYPE=0 if bias is None else 1, + ENABLE_DROPOUT=False, + RETURN_ENCODED_SOFTMAX=False, + USE_FP8=use_fp8, + USE_FP8_OUT=fp8_out_scale is not None, + ) ctx.grid = grid - ctx.sm_scale = metadata.sm_scale + ctx.sm_scale = sm_scale ctx.BLOCK_DMODEL = head_size - ctx.causal = metadata.causal - ctx.alibi_slopes = metadata.alibi_slopes + ctx.causal = causal + ctx.dropout_p = 0.0 ctx.philox_seed = philox_seed ctx.philox_offset = philox_offset ctx.encoded_softmax = encoded_softmax - ctx.return_encoded_softmax = metadata.return_encoded_softmax + ctx.return_encoded_softmax = False return o, encoded_softmax -triton_attention_rocm = _attention.apply - - -def scale_fp8(t, scale=None): - t_scaled, scale_out = ops.scaled_fp8_quant(t.reshape(-1, t.shape[-1]), - scale) - return t_scaled.reshape(t.shape), scale_out - - -def maybe_quantize_fp8(t, scale): - eight_bit_dtype = current_platform.fp8_dtype() - if t.dtype != eight_bit_dtype: - t, _ = scale_fp8(t, scale) - return t - - -def check_and_maybe_quantize_qkv(q, k, v, fp8_scales): - (q_scale, k_scale, v_scale, p_scale) = fp8_scales - - q = maybe_quantize_fp8(q, 
q_scale) - k = maybe_quantize_fp8(k, k_scale) - v = maybe_quantize_fp8(v, v_scale) - - return q, k, v - - -# query - [num_tokens, num_heads, head_size] -# key - [num_tokens, num_kv_heads, head_size] -# value - [num_tokens, num_kv_heads, head_size -# output - [num_tokens, num_heads, head_size] -def triton_attention( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - o: torch.Tensor, - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlens_q: int, - max_seqlens_k: int, - causal: bool = False, - sm_scale: float = 1.0, - bias: Optional[torch.Tensor] = None, - fp8_scales: Optional[tuple[float, ...]] = None, - input_scale: Optional[torch.Tensor] = None, -) -> torch.Tensor: - if fp8_scales is not None: - q_descale, k_descale, v_descale, p_scale = fp8_scales - else: - q_descale = k_descale = v_descale = p_scale = None - - attn_metadata = MetaData(sm_scale=sm_scale, - max_seqlens_q=max_seqlens_q, - max_seqlens_k=max_seqlens_k, - causal=causal, - bias=bias, - output_dtype=q.dtype, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - q_descale=q_descale, - k_descale=k_descale, - v_descale=v_descale, - p_scale=p_scale, - o_scale=input_scale) - - if fp8_scales is not None: - q, k, v = check_and_maybe_quantize_qkv(q, k, v, fp8_scales) - - return triton_attention_rocm(q, k, v, o, attn_metadata) +triton_attention = _attention.apply diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py index 712e83528f122..35cc303f60eeb 100644 --- a/vllm/benchmarks/datasets.py +++ b/vllm/benchmarks/datasets.py @@ -9,9 +9,6 @@ generation. Supported dataset types include: - BurstGPT - HuggingFace - VisionArena - -TODO: Implement CustomDataset to parse a JSON file and convert its contents into -SampleRequest instances, similar to the approach used in ShareGPT. 
""" import base64 import io @@ -26,6 +23,7 @@ from io import BytesIO from typing import Any, Callable, Optional, Union import numpy as np +import pandas as pd from PIL import Image from transformers import PreTrainedTokenizerBase @@ -443,6 +441,99 @@ class ShareGPTDataset(BenchmarkDataset): return samples +# ----------------------------------------------------------------------------- +# Custom Dataset Implementation +# ----------------------------------------------------------------------------- + + +class CustomDataset(BenchmarkDataset): + """ + Implements the Custom dataset. Loads data from a JSONL file and generates + sample requests based on conversation turns. E.g., + ``` + {"prompt": "What is the capital of India?"} + {"prompt": "What is the capital of Iran?"} + {"prompt": "What is the capital of China?"} + ``` + """ + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.load_data() + + def load_data(self) -> None: + if self.dataset_path is None: + raise ValueError("dataset_path must be provided for loading data.") + + # self.data will be a list of dictionaries + # e.g., [{"prompt": "What is the capital of India?"}, ...] + # This will be the standardized format which load_data() + # has to convert into depending on the filetype of dataset_path. + # sample() will assume this standardized format of self.data + self.data = [] + + # Load the JSONL file + if self.dataset_path.endswith(".jsonl"): + jsonl_data = pd.read_json(path_or_buf=self.dataset_path, + lines=True) + + # check if the JSONL file has a 'prompt' column + if "prompt" not in jsonl_data.columns: + raise ValueError("JSONL file must contain a 'prompt' column.") + + # Convert each row to a dictionary and append to self.data + # This will convert the DataFrame to a list of dictionaries + # where each dictionary corresponds to a row in the DataFrame. 
+ # This is the standardized format we want for self.data + for _, row in jsonl_data.iterrows(): + self.data.append(row.to_dict()) + else: + raise NotImplementedError( + "Only JSONL format is supported for CustomDataset.") + + random.seed(self.random_seed) + random.shuffle(self.data) + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + lora_path: Optional[str] = None, + max_loras: Optional[int] = None, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + skip_chat_template: bool = False, + **kwargs, + ) -> list: + sampled_requests = [] + for item in self.data: + if len(sampled_requests) >= num_requests: + break + prompt = item["prompt"] + + # apply template + if not skip_chat_template: + prompt = tokenizer.apply_chat_template( + [{ + "role": "user", + "content": prompt + }], + add_generation_prompt=True, + tokenize=False, + ) + + prompt_len = len(tokenizer(prompt).input_ids) + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + )) + self.maybe_oversample_requests(sampled_requests, num_requests) + + return sampled_requests + + # ----------------------------------------------------------------------------- # Sonnet Dataset Implementation # ----------------------------------------------------------------------------- diff --git a/vllm/benchmarks/serve.py b/vllm/benchmarks/serve.py index 040815e879f0c..858a0c6a00e4b 100644 --- a/vllm/benchmarks/serve.py +++ b/vllm/benchmarks/serve.py @@ -1110,6 +1110,8 @@ def main(args: argparse.Namespace): ]: if field in result_json: del result_json[field] + if field in benchmark_result: + del benchmark_result[field] # Save to file base_model_id = model_id.split("/")[-1] @@ -1120,6 +1122,7 @@ def main(args: argparse.Namespace): if args.result_filename: file_name = args.result_filename if args.result_dir: + os.makedirs(args.result_dir, exist_ok=True) file_name = os.path.join(args.result_dir, file_name) with 
open(file_name, mode="a+" if args.append_result else "w", diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 0358c9d0d1b5c..b724479a95dee 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -16,7 +16,7 @@ import vllm.envs as envs from vllm.config import CompilationConfig, VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import resolve_obj_by_qualname +from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname from .compiler_interface import (CompilerInterface, EagerAdaptor, InductorAdaptor, InductorStandaloneAdaptor) @@ -29,7 +29,8 @@ logger = init_logger(__name__) def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface: if compilation_config.use_inductor: - if envs.VLLM_TEST_STANDALONE_COMPILE: + if envs.VLLM_USE_STANDALONE_COMPILE and is_torch_equal_or_newer( + "2.8.0"): logger.info("Using InductorStandaloneAdaptor") return InductorStandaloneAdaptor() else: diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 21af5eb76ee8a..9293610cc2469 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -12,6 +12,7 @@ import torch._inductor.compile_fx import torch.fx as fx import vllm.envs as envs +from vllm.compilation.counter import compilation_counter from vllm.config import VllmConfig from vllm.utils import is_torch_equal_or_newer @@ -154,7 +155,7 @@ class InductorStandaloneAdaptor(CompilerInterface): This is not on by default yet, but we plan to turn it on by default for PyTorch 2.8. - Use VLLM_TEST_STANDALONE_COMPILE to toggle this on or off. + Use VLLM_USE_STANDALONE_COMPILE to toggle this on or off. 
""" name = "inductor_standalone" @@ -175,6 +176,7 @@ class InductorStandaloneAdaptor(CompilerInterface): runtime_shape: Optional[int] = None, key: Optional[str] = None, ) -> tuple[Optional[Callable], Optional[Any]]: + compilation_counter.num_inductor_compiles += 1 current_config = {} if compiler_config is not None: current_config.update(compiler_config) @@ -262,6 +264,7 @@ class InductorAdaptor(CompilerInterface): runtime_shape: Optional[int] = None, key: Optional[str] = None, ) -> tuple[Optional[Callable], Optional[Any]]: + compilation_counter.num_inductor_compiles += 1 from torch._inductor.compile_fx import compile_fx current_config = {} if compiler_config is not None: @@ -412,8 +415,14 @@ class InductorAdaptor(CompilerInterface): # compilation cache. So turn off the checks if we disable the # compilation cache. if not envs.VLLM_DISABLE_COMPILE_CACHE: - assert hash_str is not None, ( - "failed to get the hash of the compiled graph") + if hash_str is None: + raise RuntimeError( + "vLLM failed to compile the model. The most " + "likely reason for this is that a previous compilation " + "failed, leading to a corrupted compilation artifact. " + "We recommend trying to " + "remove ~/.cache/vllm/torch_compile_cache and try again " + "to see the real issue. ") assert file_path is not None, ( "failed to get the file path of the compiled graph") return compiled_graph, (hash_str, file_path) @@ -528,6 +537,7 @@ class EagerAdaptor(CompilerInterface): runtime_shape: Optional[int] = None, key: Optional[str] = None, ) -> tuple[Optional[Callable], Optional[Any]]: + compilation_counter.num_eager_compiles += 1 # we don't need to compile the graph, just return the graph itself. # It does not support caching, return None for the handle. 
return graph, None diff --git a/vllm/compilation/counter.py b/vllm/compilation/counter.py index 5be452593c620..2200671b8848b 100644 --- a/vllm/compilation/counter.py +++ b/vllm/compilation/counter.py @@ -15,6 +15,10 @@ class CompilationCounter: num_piecewise_capturable_graphs_seen: int = 0 num_backend_compilations: int = 0 num_cudagraph_caputured: int = 0 + # InductorAdapter.compile calls + num_inductor_compiles: int = 0 + # EagerAdapter.compile calls + num_eager_compiles: int = 0 def clone(self) -> "CompilationCounter": return copy.deepcopy(self) diff --git a/vllm/config.py b/vllm/config.py index fe2ad70f5aac8..d0891d670b76d 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -24,6 +24,7 @@ import torch from pydantic import (ConfigDict, SkipValidation, TypeAdapter, field_validator, model_validator) from pydantic.dataclasses import dataclass +from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE from torch.distributed import ProcessGroup, ReduceOp from transformers import PretrainedConfig from typing_extensions import deprecated, runtime_checkable @@ -42,15 +43,16 @@ from vllm.transformers_utils.config import ( ConfigFormat, get_config, get_hf_image_processor_config, get_hf_text_config, get_pooling_config, get_sentence_transformer_tokenizer_config, is_encoder_decoder, - try_get_generation_config, uses_mrope) + try_get_generation_config, try_get_safetensors_metadata, uses_mrope) from vllm.transformers_utils.s3_utils import S3Model from vllm.transformers_utils.utils import is_s3, maybe_model_redirect from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS, MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS, POOLING_MODEL_MAX_NUM_BATCHED_TOKENS, GiB_bytes, - LayerBlockType, cuda_device_count_stateless, - get_cpu_memory, get_open_port, is_torch_equal_or_newer, - random_uuid, resolve_obj_by_qualname) + LayerBlockType, common_broadcastable_dtype, + cuda_device_count_stateless, get_cpu_memory, + get_open_port, is_torch_equal_or_newer, random_uuid, + 
resolve_obj_by_qualname) if TYPE_CHECKING: from _typeshed import DataclassInstance @@ -304,7 +306,7 @@ class ModelConfig: - 25.6k -> 25,600""" spec_target_max_model_len: Optional[int] = None """Specify the maximum length for spec decoding draft models.""" - quantization: Optional[QuantizationMethods] = None + quantization: SkipValidation[Optional[QuantizationMethods]] = None """Method used to quantize the weights. If `None`, we first check the `quantization_config` attribute in the model config file. If that is `None`, we assume the model weights are not quantized and use `dtype` to @@ -380,7 +382,7 @@ class ModelConfig: """Initialize non-default neuron config or override default neuron config that are specific to Neuron devices, this argument will be used to configure the neuron config that can not be gathered from the vllm - arguments. e.g. `{"cast_logits_dtype": "bloat16"}`.""" + arguments. e.g. `{"cast_logits_dtype": "bfloat16"}`.""" pooler_config: Optional["PoolerConfig"] = field(init=False) """Pooler config which controls the behaviour of output pooling in pooling models.""" @@ -540,7 +542,24 @@ class ModelConfig: self.encoder_config = self._get_encoder_config() self.hf_image_processor_config = get_hf_image_processor_config( self.model, hf_token=self.hf_token, revision=self.revision) - self.dtype = _get_and_verify_dtype(self.hf_config, self.dtype) + + supported_tasks, task = self._resolve_task(self.task) + self.supported_tasks = supported_tasks + self.task = task + if self.task in ("draft", "generate"): + self.truncation_side = "left" + else: + self.truncation_side = "right" + + self.pooler_config = self._init_pooler_config() + + self.dtype = _get_and_verify_dtype( + self.model, + self.hf_config, + self.dtype, + is_pooling_model=self.runner_type == "pooling", + revision=self.revision, + ) # Workaround for Gemma 2 which uses interleaved sliding window # attention, but it's not specified in its config. 
TODO: remove this @@ -597,16 +616,6 @@ class ModelConfig: raise ValueError( "`override_neuron_config` is only supported on Neuron.") - supported_tasks, task = self._resolve_task(self.task) - self.supported_tasks = supported_tasks - self.task = task - if self.task in ("draft", "generate"): - self.truncation_side = "left" - else: - self.truncation_side = "right" - - self.pooler_config = self._init_pooler_config() - self._verify_quantization() self._verify_cuda_graph() self._verify_bnb_config() @@ -692,7 +701,6 @@ class ModelConfig: self.model, self.revision) def _init_pooler_config(self) -> Optional["PoolerConfig"]: - if self.runner_type == "pooling": if isinstance(self.override_pooler_config, dict): self.override_pooler_config = PoolerConfig( @@ -1360,6 +1368,16 @@ class ModelConfig: @property def is_encoder_decoder(self) -> bool: """Extract the HF encoder/decoder model flag.""" + """ + For Mllama, VLLM overrides HF's is_encoder_decoder flag and sets it to + True to enable cross-attention + Neuron needs all multimodal data to be in the decoder and does not + need to explicitly enable cross-attention + """ + if (current_platform.is_neuron() + and self.hf_config.model_type == "mllama"): + return False + return is_encoder_decoder(self.hf_config) @property @@ -1772,6 +1790,10 @@ class ParallelConfig: rank: int = 0 """Global rank in distributed setup.""" + enable_multimodal_encoder_data_parallel: bool = False + """ Use data parallelism instead of tensor parallelism for vision encoder. + Only support LLama4 for now""" + @property def world_size_across_dp(self) -> int: """world_size_across_dp is TPxPPxDP, it is the size of the world @@ -2231,7 +2253,7 @@ Device = Literal["auto", "cuda", "neuron", "cpu", "tpu", "xpu", "hpu"] class DeviceConfig: """Configuration for the device to use for vLLM execution.""" - device: Union[Device, torch.device] = "auto" + device: SkipValidation[Union[Device, torch.device]] = "auto" """Device type for vLLM execution. 
This parameter is deprecated and will be removed in a future release. @@ -2573,7 +2595,8 @@ class SpeculativeConfig: else: eagle_config = EAGLEConfig( self.draft_model_config.hf_config, - method=self.method) + method=self.method, + model_type="eagle") self.draft_model_config.hf_config = eagle_config if (self.num_speculative_tokens is not None @@ -3064,13 +3087,37 @@ _STR_DTYPE_TO_TORCH_DTYPE = { "bfloat16": torch.bfloat16, } -_ROCM_NOT_SUPPORTED_DTYPE: list[str] = [] # +# model_type -> reason +_FLOAT16_NOT_SUPPORTED_MODELS = { + "gemma2": "Numerical instability. Please use bfloat16 or float32 instead.", + "gemma3": "Numerical instability. Please use bfloat16 or float32 instead.", + "plamo2": "Numerical instability. Please use bfloat16 or float32 instead.", + "glm4": "Numerical instability. Please use bfloat16 or float32 instead.", +} -def _get_and_verify_dtype( +def _is_valid_dtype(model_type: str, dtype: torch.dtype): + if model_type in _FLOAT16_NOT_SUPPORTED_MODELS and dtype == torch.float16: # noqa: E501, SIM103 + return False + + return True + + +def _check_valid_dtype(model_type: str, dtype: torch.dtype): + if model_type in _FLOAT16_NOT_SUPPORTED_MODELS and dtype == torch.float16: + reason = _FLOAT16_NOT_SUPPORTED_MODELS[model_type] + raise ValueError(f"The model type {model_type!r} " + f"does not support float16. Reason: {reason}") + + return True + + +def _find_dtype( + model_id: str, config: PretrainedConfig, - dtype: Union[str, torch.dtype], -) -> torch.dtype: + *, + revision: Optional[str], +): # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct # because config.torch_dtype can be None. 
config_dtype = getattr(config, "torch_dtype", None) @@ -3082,75 +3129,111 @@ def _get_and_verify_dtype( if config_dtype is None and hasattr(config, "vision_config"): config_dtype = getattr(config.vision_config, "torch_dtype", None) + # Try to read the dtype of the weights if they are in safetensors format + if config_dtype is None: + repo_mt = try_get_safetensors_metadata(model_id, revision=revision) + + if repo_mt and (files_mt := repo_mt.files_metadata): + param_dtypes: set[torch.dtype] = { + _SAFETENSORS_TO_TORCH_DTYPE[dtype_str] + for file_mt in files_mt.values() + for dtype_str in file_mt.parameter_count + if dtype_str in _SAFETENSORS_TO_TORCH_DTYPE + } + + if param_dtypes: + return common_broadcastable_dtype(param_dtypes) + if config_dtype is None: config_dtype = torch.float32 + return config_dtype + + +def _resolve_auto_dtype( + model_type: str, + config_dtype: torch.dtype, + *, + is_pooling_model: bool, +): + from vllm.platforms import current_platform + + supported_dtypes = [ + dtype for dtype in current_platform.supported_dtypes + if _is_valid_dtype(model_type, dtype) + ] + + if is_pooling_model and torch.float16 in supported_dtypes: + preferred_dtype = torch.float16 + else: + preferred_dtype = supported_dtypes[0] + + # Downcast for float32 models + if config_dtype == torch.float32: + config_dtype = preferred_dtype + + if config_dtype in supported_dtypes: + return config_dtype + + # Ensure device compatibility + device_name = current_platform.get_device_name() + device_capability = current_platform.get_device_capability() + + if device_capability is None: + device_str = f"{device_name!r}" + else: + version_str = device_capability.as_version_str() + device_str = f"{device_name!r} (with compute capability {version_str})" + + logger.warning( + "Your device %s doesn't support %s. 
" + "Falling back to %s for compatibility.", + device_str, + config_dtype, + preferred_dtype, + ) + + return preferred_dtype + + +def _get_and_verify_dtype( + model_id: str, + config: PretrainedConfig, + dtype: Union[str, torch.dtype], + *, + is_pooling_model: bool, + revision: Optional[str] = None, +) -> torch.dtype: + config_dtype = _find_dtype(model_id, config, revision=revision) + model_type = config.model_type + if isinstance(dtype, str): dtype = dtype.lower() if dtype == "auto": # Set default dtype from model config - if config_dtype == torch.float32: - # Following common practice, we use float16 for float32 models - torch_dtype = torch.float16 - else: - torch_dtype = config_dtype - - if config.model_type == "plamo2": - logger.warning( - "For PLaMo2, we cast models to bfloat16 instead of using " - "float16 by default. This is because float16 does not work." - ) - torch_dtype = torch.bfloat16 - - # Deal with torch dtype fallback for device compatibility. - from vllm.platforms import current_platform - if torch_dtype not in current_platform.supported_dtypes: - device_name = current_platform.get_device_name() - - if ((capability := current_platform.get_device_capability()) - is None): - compute_str = "" - else: - version_str = capability.as_version_str() - compute_str = f" (with compute capability {version_str})" - fallback_dtype = current_platform.supported_dtypes[0] - logger.warning( - "Your %s device%s doesn't support %s. " \ - "Falling back to %s for compatibility.", - device_name, compute_str, torch_dtype, fallback_dtype - ) - torch_dtype = fallback_dtype - - if current_platform.is_hpu() and torch_dtype == torch.float16: - logger.warning( - "For HPU, we cast models to bfloat16 instead of " - "using float16 by default. 
Please specify `dtype` if you " - "want to use float16.") - torch_dtype = torch.bfloat16 - elif dtype == "float16" and config.model_type == "plamo2": - logger.warning( - "For PLaMo2, using float16 is unstable and might cause " - "unexpected behavior. Please use bfloat16 or float32 instead.") - torch_dtype = torch.float16 + torch_dtype = _resolve_auto_dtype( + model_type, + config_dtype, + is_pooling_model=is_pooling_model, + ) else: if dtype not in _STR_DTYPE_TO_TORCH_DTYPE: - raise ValueError(f"Unknown dtype: {dtype}") + raise ValueError(f"Unknown dtype: {dtype!r}") torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype] elif isinstance(dtype, torch.dtype): torch_dtype = dtype else: raise ValueError(f"Unknown dtype: {dtype}") - # Verify the dtype. + _check_valid_dtype(model_type, torch_dtype) + if torch_dtype != config_dtype: if torch_dtype == torch.float32: # Upcasting to float32 is allowed. logger.info("Upcasting %s to %s.", config_dtype, torch_dtype) - pass elif config_dtype == torch.float32: # Downcasting from float32 to float16 or bfloat16 is allowed. logger.info("Downcasting %s to %s.", config_dtype, torch_dtype) - pass else: # Casting between float16 and bfloat16 is allowed with a warning. logger.warning("Casting %s to %s.", config_dtype, torch_dtype) @@ -4315,15 +4398,10 @@ class VllmConfig: self.compilation_config.custom_ops.append("+rms_norm") if envs.VLLM_USE_V1 and self.model_config is not None and \ not self.model_config.enforce_eager: - # NOTE(woosuk): Currently, we use inductor because the piecewise - # CUDA graphs do not work properly with the custom CUDA kernels. - # FIXME(woosuk): Disable inductor to reduce the compilation time - # and avoid any potential issues with the inductor. # FIXME(rob): Add function to set all of these. 
if not self.compilation_config.custom_ops: self.compilation_config.custom_ops = ["none"] self.compilation_config.use_cudagraph = True - self.compilation_config.use_inductor = True self.compilation_config.cudagraph_num_of_warmups = 1 self.compilation_config.pass_config.enable_fusion = False self.compilation_config.pass_config.enable_noop = False diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index 06b3983ed68bd..dce0b545c188e 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -70,7 +70,8 @@ class KVConnectorFactory: connector_module = importlib.import_module(connector_module_path) connector_cls = getattr(connector_module, connector_name) assert issubclass(connector_cls, KVConnectorBase_V1) - logger.info("Creating v1 connector with name: %s", connector_name) + logger.info("Creating v1 connector with name: %s and engine_id: %s", + connector_name, kv_transfer_config.engine_id) # NOTE(Kuntai): v1 connector is explicitly separated into two roles. # Scheduler connector: # - Co-locate with scheduler process diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 6303d77ad3055..4d228dbc9d492 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -172,6 +172,11 @@ class NixlConnectorScheduler: self.vllm_config = vllm_config self.block_size = vllm_config.cache_config.block_size self.engine_id = engine_id + self.side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST + self.side_channel_port = ( + envs.VLLM_NIXL_SIDE_CHANNEL_PORT + + vllm_config.parallel_config.data_parallel_rank_local * + vllm_config.parallel_config.tensor_parallel_size) logger.info("Initializing NIXL Scheduler %s", engine_id) # Requests that need to start recv. 
@@ -310,8 +315,8 @@ class NixlConnectorScheduler: do_remote_decode=False, remote_block_ids=computed_block_ids, remote_engine_id=self.engine_id, - remote_host=envs.VLLM_NIXL_SIDE_CHANNEL_HOST, - remote_port=envs.VLLM_NIXL_SIDE_CHANNEL_PORT, + remote_host=self.side_channel_host, + remote_port=self.side_channel_port, ) @@ -330,9 +335,18 @@ class NixlConnectorWorker: # Map of engine_id -> agent_name. self._remote_agents: dict[str, str] = {} + # NIXL handshake port. + # NOTE(rob): Within a DP group, each DP rank gets its own + # base port (which is sent in the KVTransferParams). + # Each TP rank listens/queries on the base_port + tp_rank. + self.side_channel_port = ( + envs.VLLM_NIXL_SIDE_CHANNEL_PORT + + vllm_config.parallel_config.data_parallel_rank_local * + vllm_config.parallel_config.tensor_parallel_size) + # Metadata. self.engine_id = engine_id - self.rank = get_tensor_model_parallel_rank() + self.tp_rank = get_tensor_model_parallel_rank() self.world_size = get_tensor_model_parallel_world_size() self.tp_group = get_tp_group() @@ -382,15 +396,11 @@ class NixlConnectorWorker: @staticmethod def _nixl_handshake_listener(metadata: NixlAgentMetadata, - ready_event: threading.Event, rank: int): + ready_event: threading.Event, base_port: int, + tp_rank: int): """Background thread for getting new NIXL handshakes.""" # NOTE(rob): this is a simple implementation. We will move - # to a better approach like an ETCD server in the future. - - # NOTE(rob): to support heterogeneous TP, we will have to - # move this into the scheduler rather than worker, since - # each rank needs the metadata of all other ranks (whereas - # in this setup, each rank only gets one other rank's meta. + # to a better approach via HTTP endpoint soon. encoder = msgspec.msgpack.Encoder() encoded_data = encoder.encode(metadata) @@ -400,11 +410,7 @@ class NixlConnectorWorker: # Listen for new requests for metadata. 
host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST - # NOTE(rob): we need each rank to have a unique port. This - # hack to keeps us moving. We will switch when moving to etcd - # or where we have a single ZMQ socket in the scheduler. - port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT + rank - path = make_zmq_path("tcp", host, port) + path = make_zmq_path("tcp", host, base_port + tp_rank) logger.debug("Starting listening on path: %s", path) with zmq_ctx(zmq.ROUTER, path) as sock: ready_event.set() @@ -419,10 +425,10 @@ class NixlConnectorWorker: """Do a NIXL handshake with a remote instance.""" start_time = time.perf_counter() - # NOTE(rob): we need each rank to have a unique port. This is - # a hack to keep us moving. We will switch when moving to etcd - # or where we have a single ZMQ socket in the scheduler. - path = make_zmq_path("tcp", host, port + self.rank) + # NOTE(rob): we need each tp_rank to have a unique port. + # This is a hack to keep us moving. We will switch when + # we switch to HTTP-based NIXL metadata exchange. + path = make_zmq_path("tcp", host, port + self.tp_rank) logger.debug("Querying metadata on path: %s", path) with zmq_ctx(zmq.REQ, path) as sock: # Send query for the request. 
@@ -486,7 +492,8 @@ class NixlConnectorWorker: for cache in cache_list: base_addr = cache.data_ptr() region_len = self.num_blocks * self.block_len - caches_data.append((base_addr, region_len, self.rank, "")) + caches_data.append( + (base_addr, region_len, cache.device.index, "")) kv_caches_base_addr.append(base_addr) self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr self.num_regions = len(caches_data) @@ -529,7 +536,7 @@ class NixlConnectorWorker: ready_event = threading.Event() self._nixl_handshake_listener_t = threading.Thread( target=self._nixl_handshake_listener, - args=(metadata, ready_event, self.rank), + args=(metadata, ready_event, self.side_channel_port, self.tp_rank), daemon=True, name="nixl_handshake_listener") self._nixl_handshake_listener_t.start() @@ -553,9 +560,9 @@ class NixlConnectorWorker: block_offset = block_id * self.block_len # (addr, len, device id) blocks_data.append( - (base_addr + block_offset, self.block_len, self.rank)) - logger.debug("Created %s blocks for src engine %s and rank %s", - len(blocks_data), self.engine_id, self.rank) + (base_addr + block_offset, self.block_len, self.tp_rank)) + logger.debug("Created %s blocks for src engine %s and tp_rank %s", + len(blocks_data), self.engine_id, self.tp_rank) # Register with NIXL. descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") @@ -570,9 +577,9 @@ class NixlConnectorWorker: block_offset = block_id * self.block_len # (addr, len, device id) blocks_data.append( - (base_addr + block_offset, self.block_len, self.rank)) - logger.debug("Created %s blocks for dst engine %s and rank %s", - len(blocks_data), engine_id, self.rank) + (base_addr + block_offset, self.block_len, self.tp_rank)) + logger.debug("Created %s blocks for dst engine %s and tp_rank %s", + len(blocks_data), engine_id, self.tp_rank) # Register with NIXL. 
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") @@ -597,14 +604,14 @@ class NixlConnectorWorker: if len(done_sending) > 0 or len(done_recving) > 0: logger.debug( "Rank %s, get_finished: %s requests done sending " - "and %s requests done recving", self.rank, len(done_sending), - len(done_recving)) + "and %s requests done recving", self.tp_rank, + len(done_sending), len(done_recving)) if self.world_size == 1: return done_sending, done_recving # Rank 0: get finished from all other ranks. - if self.rank == 0: + if self.tp_rank == 0: for req_id in done_sending: self._done_sending_count[req_id] += 1 for req_id in done_recving: diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index b674d05a7771b..6e48c02da6692 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -41,8 +41,8 @@ from vllm.distributed.device_communicators.base_device_communicator import ( DeviceCommunicatorBase) from vllm.distributed.utils import StatelessProcessGroup from vllm.logger import init_logger -from vllm.utils import (direct_register_custom_op, resolve_obj_by_qualname, - supports_custom_op) +from vllm.utils import (direct_register_custom_op, get_distributed_init_method, + resolve_obj_by_qualname, supports_custom_op) @dataclass @@ -929,7 +929,7 @@ def init_distributed_environment( world_size = parallel_config.world_size_across_dp ip = parallel_config.data_parallel_master_ip port = parallel_config.get_next_dp_init_port() - distributed_init_method = f"tcp://{ip}:{port}" # noqa + distributed_init_method = get_distributed_init_method(ip, port) logger.info( "Adjusting world_size=%d rank=%d distributed_init_method=%s for DP", world_size, rank, distributed_init_method) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 2a1a342110ba5..299c8347f458a 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -224,7 +224,7 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]: elif 
contains_type(type_hints, int): kwargs[name]["type"] = int # Special case for large integers - if name in {"max_model_len"}: + if name in {"max_model_len", "max_num_batched_tokens"}: kwargs[name]["type"] = human_readable_int elif contains_type(type_hints, float): kwargs[name]["type"] = float @@ -423,6 +423,9 @@ class EngineArgs: use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load pt_load_map_location: str = LoadConfig.pt_load_map_location + enable_multimodal_encoder_data_parallel: bool = \ + ParallelConfig.enable_multimodal_encoder_data_parallel + def __post_init__(self): # support `EngineArgs(compilation_config={...})` # without having to manually construct a @@ -637,6 +640,9 @@ class EngineArgs: **parallel_kwargs["worker_cls"]) parallel_group.add_argument("--worker-extension-cls", **parallel_kwargs["worker_extension_cls"]) + parallel_group.add_argument( + "--enable-multimodal-encoder-data-parallel", + **parallel_kwargs["enable_multimodal_encoder_data_parallel"]) # KV cache arguments cache_kwargs = get_kwargs(CacheConfig) @@ -1078,6 +1084,8 @@ class EngineArgs: distributed_executor_backend=self.distributed_executor_backend, worker_cls=self.worker_cls, worker_extension_cls=self.worker_extension_cls, + enable_multimodal_encoder_data_parallel=self. 
+ enable_multimodal_encoder_data_parallel, ) speculative_config = self.create_speculative_config( @@ -1380,7 +1388,8 @@ class EngineArgs: if (self.pipeline_parallel_size > 1 and self.distributed_executor_backend - not in ("ray", "mp", "external_launcher")): + not in (ParallelConfig.distributed_executor_backend, "ray", + "mp", "external_launcher")): name = "Pipeline Parallelism without Ray distributed executor " \ "or multiprocessing executor or external launcher" _raise_or_fallback(feature_name=name, recommend_to_remove=False) diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index 957fec290bf26..e65c97073218b 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -1,24 +1,35 @@ # SPDX-License-Identifier: Apache-2.0 import argparse +import os import signal +import sys import uvloop +import zmq import vllm.envs as envs from vllm import AsyncEngineArgs from vllm.entrypoints.cli.types import CLISubcommand -from vllm.entrypoints.openai.api_server import run_server +from vllm.entrypoints.openai.api_server import (run_server, run_server_worker, + setup_server) from vllm.entrypoints.openai.cli_args import (make_arg_parser, validate_parsed_serve_args) from vllm.entrypoints.utils import (VLLM_SERVE_PARSER_EPILOG, show_filtered_argument_or_group_from_help) +from vllm.executor.multiproc_worker_utils import _add_prefix from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext -from vllm.utils import FlexibleArgumentParser, get_tcp_uri +from vllm.utils import FlexibleArgumentParser, get_tcp_uri, zmq_socket_ctx +from vllm.v1.engine.coordinator import DPCoordinator from vllm.v1.engine.core import EngineCoreProc from vllm.v1.engine.core_client import CoreEngineProcManager from vllm.v1.executor.abstract import Executor +from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus +from vllm.v1.utils import (APIServerProcessManager, CoreEngine, + EngineZmqAddresses, get_engine_client_zmq_addr, + 
wait_for_completion_or_failure, + wait_for_engine_startup) logger = init_logger(__name__) @@ -36,9 +47,12 @@ class ServeSubcommand(CLISubcommand): if hasattr(args, 'model_tag') and args.model_tag is not None: args.model = args.model_tag - if args.headless: + if args.headless or args.api_server_count < 1: run_headless(args) + elif args.api_server_count > 1: + run_multi_api_server(args) else: + # Single API server (this process). uvloop.run(run_server(args)) def validate(self, args: argparse.Namespace) -> None: @@ -69,6 +83,11 @@ class ServeSubcommand(CLISubcommand): type=int, default=0, help='Starting data parallel rank for secondary nodes.') + serve_parser.add_argument('--api-server-count', + '-asc', + type=int, + default=1, + help='How many API server processes to run.') serve_parser.add_argument( "--config", type=str, @@ -91,23 +110,26 @@ def cmd_init() -> list[CLISubcommand]: def run_headless(args: argparse.Namespace): + if args.api_server_count > 1: + raise ValueError("api_server_count can't be set in headless mode") + # Create the EngineConfig. engine_args = AsyncEngineArgs.from_cli_args(args) usage_context = UsageContext.OPENAI_API_SERVER vllm_config = engine_args.create_engine_config(usage_context=usage_context) if not envs.VLLM_USE_V1: - raise RuntimeError("Headless mode is only supported for V1") + raise ValueError("Headless mode is only supported for V1") parallel_config = vllm_config.parallel_config local_engine_count = parallel_config.data_parallel_size_local host = parallel_config.data_parallel_master_ip port = engine_args.data_parallel_rpc_port # add to config too - input_address = get_tcp_uri(host, port) + handshake_address = get_tcp_uri(host, port) if local_engine_count <= 0: - raise RuntimeError("data_parallel_size_local must be > 0 in " - "headless mode") + raise ValueError("data_parallel_size_local must be > 0 in " + "headless mode") # Catch SIGTERM and SIGINT to allow graceful shutdown. 
def signal_handler(signum, frame): @@ -119,7 +141,7 @@ def run_headless(args: argparse.Namespace): logger.info( "Launching %d data parallel engine(s) in headless mode, " - "with head node address %s.", local_engine_count, input_address) + "with head node address %s.", local_engine_count, handshake_address) # Create the engines. engine_manager = CoreEngineProcManager( @@ -129,7 +151,7 @@ def run_headless(args: argparse.Namespace): local_start_index=0, vllm_config=vllm_config, on_head_node=False, - input_address=input_address, + handshake_address=handshake_address, executor_class=Executor.get_class(vllm_config), log_stats=not engine_args.disable_log_stats, ) @@ -139,3 +161,142 @@ def run_headless(args: argparse.Namespace): finally: logger.info("Shutting down.") engine_manager.close() + + +def run_multi_api_server(args: argparse.Namespace): + + assert not args.headless + num_api_servers = args.api_server_count + assert num_api_servers > 0 + + if num_api_servers > 1: + setup_multiprocess_prometheus() + + listen_address, sock = setup_server(args) + + engine_args = AsyncEngineArgs.from_cli_args(args) + usage_context = UsageContext.OPENAI_API_SERVER + vllm_config = engine_args.create_engine_config(usage_context=usage_context) + model_config = vllm_config.model_config + + if num_api_servers > 1: + if not envs.VLLM_USE_V1: + raise ValueError("api_server_count > 1 is only supported for V1") + + if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: + raise ValueError("VLLM_ALLOW_RUNTIME_LORA_UPDATING cannot be used " + "with api_server_count > 1") + + if model_config.is_multimodal_model and not ( + model_config.disable_mm_preprocessor_cache): + logger.warning( + "Multi-modal preprocessor cache will be disabled for" + " api_server_count > 1") + model_config.disable_mm_preprocessor_cache = True + + parallel_config = vllm_config.parallel_config + + assert parallel_config.data_parallel_rank == 0 + + dp_size = parallel_config.data_parallel_size + local_engine_count = 
parallel_config.data_parallel_size_local + host = parallel_config.data_parallel_master_ip + local_only = local_engine_count == dp_size + + # Set up input and output addresses. + input_addresses = [ + get_engine_client_zmq_addr(local_only, host) + for _ in range(num_api_servers) + ] + output_addresses = [ + get_engine_client_zmq_addr(local_only, host) + for _ in range(num_api_servers) + ] + + addresses = EngineZmqAddresses( + inputs=input_addresses, + outputs=output_addresses, + ) + + # Set up coordinator for dp > 1. + coordinator = None + stats_update_address = None + if dp_size > 1: + coordinator = DPCoordinator(parallel_config) + addresses.coordinator_input, addresses.coordinator_output = ( + coordinator.get_engine_socket_addresses()) + stats_update_address = coordinator.get_stats_publish_address() + logger.info("Started DP Coordinator process (PID: %d)", + coordinator.proc.pid) + + handshake_address = get_engine_client_zmq_addr( + local_only, host, parallel_config.data_parallel_rpc_port) + + with zmq_socket_ctx(handshake_address, zmq.ROUTER, + bind=True) as handshake_socket: + + # Start local engines. + if not local_engine_count: + local_engine_manager = None + else: + local_engine_manager = CoreEngineProcManager( + EngineCoreProc.run_engine_core, + vllm_config=vllm_config, + executor_class=Executor.get_class(vllm_config), + log_stats=not engine_args.disable_log_stats, + handshake_address=handshake_address, + on_head_node=True, + local_engine_count=local_engine_count, + start_index=0, + local_start_index=0) + + # Start API servers using the manager + api_server_manager = APIServerProcessManager( + target_server_fn=run_api_server_worker_proc, + listen_address=listen_address, + sock=sock, + args=args, + num_servers=num_api_servers, + input_addresses=input_addresses, + output_addresses=output_addresses, + stats_update_address=stats_update_address) + + # Wait for engine handshakes to complete. 
+ core_engines = [ + CoreEngine(index=i, local=(i < local_engine_count)) + for i in range(dp_size) + ] + wait_for_engine_startup( + handshake_socket, + addresses, + core_engines, + parallel_config, + vllm_config.cache_config, + local_engine_manager, + coordinator.proc if coordinator else None, + ) + + # Wait for API servers + wait_for_completion_or_failure( + api_server_manager=api_server_manager, + local_engine_manager=local_engine_manager, + coordinator=coordinator) + + +def run_api_server_worker_proc(listen_address, + sock, + args, + client_config=None, + **uvicorn_kwargs) -> None: + """Entrypoint for individual API server worker processes.""" + + # Add process-specific prefix to stdout and stderr. + from multiprocessing import current_process + process_name = current_process().name + pid = os.getpid() + _add_prefix(sys.stdout, process_name, pid) + _add_prefix(sys.stderr, process_name, pid) + + uvloop.run( + run_server_worker(listen_address, sock, args, client_config, + **uvicorn_kwargs)) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index f8eeae61fc913..e05189ef49611 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -45,8 +45,7 @@ from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams, from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, get_cached_tokenizer) from vllm.usage.usage_lib import UsageContext -from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs, - is_list_of) +from vllm.utils import Counter, Device, deprecate_kwargs, is_list_of if TYPE_CHECKING: from vllm.v1.metrics.reader import Metric @@ -143,12 +142,6 @@ class LLM: DEPRECATE_LEGACY: ClassVar[bool] = True """A flag to toggle whether to deprecate the legacy generate/encode API.""" - DEPRECATE_INIT_POSARGS: ClassVar[bool] = True - """ - A flag to toggle whether to deprecate positional arguments in - [LLM.__init__][]. 
- """ - @classmethod @contextmanager def deprecate_legacy_api(cls): @@ -158,16 +151,11 @@ class LLM: cls.DEPRECATE_LEGACY = False - @deprecate_args( - start_index=2, # Ignore self and model - is_deprecated=lambda: LLM.DEPRECATE_INIT_POSARGS, - additional_message=( - "All positional arguments other than `model` will be " - "replaced with keyword arguments in an upcoming version."), - ) def __init__( self, model: str, + *, + task: TaskOption = "auto", tokenizer: Optional[str] = None, tokenizer_mode: TokenizerMode = "auto", skip_tokenizer_init: bool = False, @@ -189,8 +177,6 @@ class LLM: hf_token: Optional[Union[bool, str]] = None, hf_overrides: Optional[HfOverrides] = None, mm_processor_kwargs: Optional[dict[str, Any]] = None, - # After positional args are removed, move this right below `model` - task: TaskOption = "auto", override_pooler_config: Optional[PoolerConfig] = None, compilation_config: Optional[Union[int, dict[str, Any]]] = None, **kwargs, diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index b991cb3a444bc..5a4295ff716db 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -5,6 +5,7 @@ import atexit import gc import importlib import inspect +import json import multiprocessing import os import signal @@ -16,8 +17,7 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager from functools import partial from http import HTTPStatus -from json import JSONDecodeError -from typing import Annotated, Optional +from typing import Annotated, Any, Optional import prometheus_client import regex as re @@ -26,6 +26,8 @@ from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response, StreamingResponse +from prometheus_client import make_asgi_app +from prometheus_fastapi_instrumentator 
import Instrumentator from starlette.concurrency import iterate_in_threadpool from starlette.datastructures import State from starlette.routing import Mount @@ -97,6 +99,7 @@ from vllm.transformers_utils.tokenizer import MistralTokenizer from vllm.usage.usage_lib import UsageContext from vllm.utils import (Device, FlexibleArgumentParser, get_open_zmq_ipc_path, is_valid_ipv6_address, set_ulimit) +from vllm.v1.metrics.prometheus import get_prometheus_registry from vllm.version import __version__ as VLLM_VERSION TIMEOUT_KEEP_ALIVE = 5 # seconds @@ -142,14 +145,17 @@ async def lifespan(app: FastAPI): @asynccontextmanager async def build_async_engine_client( - args: Namespace) -> AsyncIterator[EngineClient]: + args: Namespace, + client_config: Optional[dict[str, Any]] = None, +) -> AsyncIterator[EngineClient]: # Context manager to handle engine_client lifecycle # Ensures everything is shutdown and cleaned up on error/exit engine_args = AsyncEngineArgs.from_cli_args(args) async with build_async_engine_client_from_engine_args( - engine_args, args.disable_frontend_multiprocessing) as engine: + engine_args, args.disable_frontend_multiprocessing, + client_config) as engine: yield engine @@ -157,6 +163,7 @@ async def build_async_engine_client( async def build_async_engine_client_from_engine_args( engine_args: AsyncEngineArgs, disable_frontend_multiprocessing: bool = False, + client_config: Optional[dict[str, Any]] = None, ) -> AsyncIterator[EngineClient]: """ Create EngineClient, either: @@ -179,12 +186,16 @@ async def build_async_engine_client_from_engine_args( from vllm.v1.engine.async_llm import AsyncLLM async_llm: Optional[AsyncLLM] = None + client_index = client_config.pop( + "client_index") if client_config else 0 try: async_llm = AsyncLLM.from_vllm_config( vllm_config=vllm_config, usage_context=usage_context, disable_log_requests=engine_args.disable_log_requests, - disable_log_stats=engine_args.disable_log_stats) + disable_log_stats=engine_args.disable_log_stats, + 
client_addresses=client_config, + client_index=client_index) # Don't keep the dummy data in memory await async_llm.reset_mm_cache() @@ -318,22 +329,9 @@ class PrometheusResponse(Response): def mount_metrics(app: FastAPI): - # Lazy import for prometheus multiprocessing. - # We need to set PROMETHEUS_MULTIPROC_DIR environment variable - # before prometheus_client is imported. - # See https://prometheus.github.io/client_python/multiprocess/ - from prometheus_client import (REGISTRY, CollectorRegistry, make_asgi_app, - multiprocess) - from prometheus_fastapi_instrumentator import Instrumentator + """Mount prometheus metrics to a FastAPI app.""" - registry = REGISTRY - - prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None) - if prometheus_multiproc_dir_path is not None: - logger.debug("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR", - prometheus_multiproc_dir_path) - registry = CollectorRegistry() - multiprocess.MultiProcessCollector(registry) + registry = get_prometheus_registry() # `response_class=PrometheusResponse` is needed to return an HTTP response # with header "Content-Type: text/plain; version=0.0.4; charset=utf-8" @@ -932,7 +930,7 @@ async def invocations(raw_request: Request): """ try: body = await raw_request.json() - except JSONDecodeError as e: + except json.JSONDecodeError as e: raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value, detail=f"JSON decode error: {e}") from e @@ -1005,6 +1003,18 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: return Response(status_code=200, content=response) +def load_log_config(log_config_file: Optional[str]) -> Optional[dict]: + if not log_config_file: + return None + try: + with open(log_config_file) as f: + return json.load(f) + except Exception as e: + logger.warning("Failed to load log config from file %s: error %s", + log_config_file, e) + return None + + def build_app(args: Namespace) -> FastAPI: if args.disable_fastapi_docs: app = FastAPI(openapi_url=None, @@ -1256,16 +1266,10 @@ def 
create_server_socket(addr: tuple[str, int]) -> socket.socket: return sock -async def run_server(args, **uvicorn_kwargs) -> None: - logger.info("vLLM API server version %s", VLLM_VERSION) - log_non_default_args(args) - - if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: - ToolParserManager.import_tool_parser(args.tool_parser_plugin) - +def validate_api_server_args(args): valid_tool_parses = ToolParserManager.tool_parsers.keys() if args.enable_auto_tool_choice \ - and args.tool_call_parser not in valid_tool_parses: + and args.tool_call_parser not in valid_tool_parses: raise KeyError(f"invalid tool call parser: {args.tool_call_parser} " f"(chose from {{ {','.join(valid_tool_parses)} }})") @@ -1276,6 +1280,19 @@ async def run_server(args, **uvicorn_kwargs) -> None: f"invalid reasoning parser: {args.reasoning_parser} " f"(chose from {{ {','.join(valid_reasoning_parses)} }})") + +def setup_server(args): + """Validate API server args, set up signal handler, create socket + ready to serve.""" + + logger.info("vLLM API server version %s", VLLM_VERSION) + log_non_default_args(args) + + if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: + ToolParserManager.import_tool_parser(args.tool_parser_plugin) + + validate_api_server_args(args) + # workaround to make sure that we bind the port before the engine is set up. # This avoids race conditions with ray. 
# see https://github.com/vllm-project/vllm/issues/8204 @@ -1292,22 +1309,46 @@ async def run_server(args, **uvicorn_kwargs) -> None: signal.signal(signal.SIGTERM, signal_handler) - async with build_async_engine_client(args) as engine_client: + addr, port = sock_addr + is_ssl = args.ssl_keyfile and args.ssl_certfile + host_part = f"[{addr}]" if is_valid_ipv6_address( + addr) else addr or "0.0.0.0" + listen_address = f"http{'s' if is_ssl else ''}://{host_part}:{port}" + + return listen_address, sock + + +async def run_server(args, **uvicorn_kwargs) -> None: + """Run a single-worker API server.""" + listen_address, sock = setup_server(args) + await run_server_worker(listen_address, sock, args, **uvicorn_kwargs) + + +async def run_server_worker(listen_address, + sock, + args, + client_config=None, + **uvicorn_kwargs) -> None: + """Run a single API server worker.""" + + if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: + ToolParserManager.import_tool_parser(args.tool_parser_plugin) + + server_index = client_config.get("client_index", 0) if client_config else 0 + + # Load logging config for uvicorn if specified + log_config = load_log_config(args.log_config_file) + if log_config is not None: + uvicorn_kwargs['log_config'] = log_config + + async with build_async_engine_client(args, client_config) as engine_client: app = build_app(args) vllm_config = await engine_client.get_vllm_config() await init_app_state(engine_client, vllm_config, app.state, args) - def _listen_addr(a: str) -> str: - if is_valid_ipv6_address(a): - return '[' + a + ']' - return a or "0.0.0.0" - - is_ssl = args.ssl_keyfile and args.ssl_certfile - logger.info("Starting vLLM API server on http%s://%s:%d", - "s" if is_ssl else "", _listen_addr(sock_addr[0]), - sock_addr[1]) - + logger.info("Starting vLLM API server %d on %s", server_index, + listen_address) shutdown_task = await serve_http( app, sock=sock, diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py 
index d01af5e422666..f196ff6ed3021 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -11,6 +11,7 @@ import ssl from collections.abc import Sequence from typing import Optional, Union, get_args +import vllm.envs as envs from vllm.engine.arg_utils import AsyncEngineArgs, optional_type from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption, validate_chat_template) @@ -243,6 +244,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: " into OpenAI API format, the name register in this plugin can be used " "in ``--tool-call-parser``.") + parser.add_argument( + "--log-config-file", + type=str, + default=envs.VLLM_LOGGING_CONFIG_PATH, + help="Path to logging config JSON file for both vllm and uvicorn", + ) + parser = AsyncEngineArgs.add_cli_args(parser) parser.add_argument('--max-log-len', diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index a7f85e9eef394..e72c23993ac8c 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1481,6 +1481,10 @@ class TranscriptionStreamResponse(OpenAIBaseModel): usage: Optional[UsageInfo] = Field(default=None) +BatchRequestInputBody = Union[ChatCompletionRequest, EmbeddingRequest, + ScoreRequest, RerankRequest] + + class BatchRequestInput(OpenAIBaseModel): """ The per-line object of the batch input file. @@ -1501,21 +1505,22 @@ class BatchRequestInput(OpenAIBaseModel): url: str # The parameters of the request. 
- body: Union[ChatCompletionRequest, EmbeddingRequest, ScoreRequest] + body: BatchRequestInputBody @field_validator('body', mode='plain') @classmethod def check_type_for_url(cls, value: Any, info: ValidationInfo): # Use url to disambiguate models - url = info.data['url'] + url: str = info.data["url"] if url == "/v1/chat/completions": return ChatCompletionRequest.model_validate(value) if url == "/v1/embeddings": return TypeAdapter(EmbeddingRequest).validate_python(value) - if url == "/v1/score": + if url.endswith("/score"): return ScoreRequest.model_validate(value) - return TypeAdapter(Union[ChatCompletionRequest, EmbeddingRequest, - ScoreRequest]).validate_python(value) + if url.endswith("/rerank"): + return RerankRequest.model_validate(value) + return TypeAdapter(BatchRequestInputBody).validate_python(value) class BatchResponseData(OpenAIBaseModel): @@ -1527,7 +1532,7 @@ class BatchResponseData(OpenAIBaseModel): # The body of the response. body: Optional[Union[ChatCompletionResponse, EmbeddingResponse, - ScoreResponse]] = None + ScoreResponse, RerankResponse]] = None class BatchRequestOutput(OpenAIBaseModel): @@ -1558,6 +1563,11 @@ class TokenizeCompletionRequest(OpenAIBaseModel): "If true (the default), special tokens (e.g. 
BOS) will be added to " "the prompt."), ) + return_token_strs: Optional[bool] = Field( + default=False, + description=("If true, also return the token strings " + "corresponding to the token ids."), + ) class TokenizeChatRequest(OpenAIBaseModel): @@ -1571,6 +1581,11 @@ class TokenizeChatRequest(OpenAIBaseModel): "This is a parameter used by chat template in tokenizer config of the " "model."), ) + return_token_strs: Optional[bool] = Field( + default=False, + description=("If true, also return the token strings " + "corresponding to the token ids."), + ) continue_final_message: bool = Field( default=False, description= @@ -1628,6 +1643,7 @@ class TokenizeResponse(OpenAIBaseModel): count: int max_model_len: int tokens: list[int] + token_strs: Optional[list[str]] = None class DetokenizeRequest(OpenAIBaseModel): diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py index f38465b22bcca..ac250b3cb4fbf 100644 --- a/vllm/entrypoints/openai/run_batch.py +++ b/vllm/entrypoints/openai/run_batch.py @@ -21,7 +21,7 @@ from vllm.entrypoints.openai.protocol import (BatchRequestInput, BatchResponseData, ChatCompletionResponse, EmbeddingResponse, ErrorResponse, - ScoreResponse) + RerankResponse, ScoreResponse) # yapf: enable from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding @@ -274,8 +274,11 @@ async def run_request(serving_engine_func: Callable, tracker: BatchProgressTracker) -> BatchRequestOutput: response = await serving_engine_func(request.body) - if isinstance(response, - (ChatCompletionResponse, EmbeddingResponse, ScoreResponse)): + if isinstance( + response, + (ChatCompletionResponse, EmbeddingResponse, ScoreResponse, + RerankResponse), + ): batch_output = BatchRequestOutput( id=f"vllm-{random_uuid()}", custom_id=request.custom_id, @@ -397,7 +400,7 @@ async def main(args): response_futures.append( run_request(embed_handler_fn, request, tracker)) 
tracker.submitted() - elif request.url == "/v1/score": + elif request.url.endswith("/score"): score_handler_fn = openai_serving_scores.create_score if \ openai_serving_scores is not None else None if score_handler_fn is None: @@ -411,13 +414,29 @@ async def main(args): response_futures.append( run_request(score_handler_fn, request, tracker)) tracker.submitted() + elif request.url.endswith("/rerank"): + rerank_handler_fn = openai_serving_scores.do_rerank if \ + openai_serving_scores is not None else None + if rerank_handler_fn is None: + response_futures.append( + make_async_error_request_output( + request, + error_msg="The model does not support Rerank API", + )) + continue + + response_futures.append( + run_request(rerank_handler_fn, request, tracker)) + tracker.submitted() else: response_futures.append( make_async_error_request_output( request, - error_msg= - "Only /v1/chat/completions, /v1/embeddings, and /v1/score " - "are supported in the batch endpoint.", + error_msg=f"URL {request.url} was used. " + "Supported endpoints: /v1/chat/completions, /v1/embeddings," + " /score, /rerank ." 
+ "See vllm/entrypoints/openai/api_server.py for supported " + "score/rerank versions.", )) with tracker.pbar(): diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 6a0e3b14d07bb..ea8e187dc6b7f 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -988,7 +988,8 @@ class OpenAIServingChat(OpenAIServing): tool_calls=[ tool_call_class(function=FunctionCall( name=tool_call.name, - arguments=json.dumps(tool_call.parameters))) + arguments=json.dumps(tool_call.parameters, + ensure_ascii=False))) for tool_call in tool_calls ]) diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 5ef1a486d86c8..0d739bbf9bf22 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -110,7 +110,12 @@ class OpenAIServingTokenization(OpenAIServing): dict) and "prompt_token_ids" in engine_prompt: input_ids.extend(engine_prompt["prompt_token_ids"]) + token_strs = None + if request.return_token_strs: + token_strs = tokenizer.convert_ids_to_tokens(input_ids) + return TokenizeResponse(tokens=input_ids, + token_strs=token_strs, count=len(input_ids), max_model_len=self.max_model_len) diff --git a/vllm/entrypoints/openai/serving_transcription.py b/vllm/entrypoints/openai/serving_transcription.py index 13565d0ef8dd7..9fc5b562e7d5c 100644 --- a/vllm/entrypoints/openai/serving_transcription.py +++ b/vllm/entrypoints/openai/serving_transcription.py @@ -278,7 +278,9 @@ class OpenAIServingTranscription(OpenAIServing): result_generator: Optional[AsyncGenerator[RequestOutput, None]] = None try: - # TODO(rob): subtract len of tokenized prompt. + # Unlike most decoder-only models, whisper generation length is not + # constrained by the size of the input audio, which is mapped to a + # fixed-size log-mel-spectogram. 
default_max_tokens = self.model_config.max_model_len sampling_params = request.to_sampling_params( default_max_tokens, self.default_sampling_params) diff --git a/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py index 858c8db99fd29..323fb144181ea 100644 --- a/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py @@ -7,6 +7,7 @@ from typing import Any, Union import regex as re from transformers import PreTrainedTokenizerBase +import vllm.envs as envs from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -64,7 +65,19 @@ class Llama4PythonicToolParser(ToolParser): if model_output.startswith("<|python_start|>"): model_output = model_output[len("<|python_start|>"):] model_output = model_output.replace("<|python_end|>", "") - if not (self.TOOL_CALL_REGEX.match(model_output)): + + is_tool_call_pattern = False + try: + is_tool_call_pattern = self.TOOL_CALL_REGEX.match( + model_output, + timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS) is not None + except TimeoutError: + logger.warning( + "Regex timeout occurred when matching tool call pattern.") + logger.debug("Regex timeout occurred when matching user input: %s", + model_output) + + if not is_tool_call_pattern: return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=model_output) diff --git a/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py index b403a146716d5..00690ad79a7ac 100644 --- a/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py @@ -68,8 +68,8 @@ class Phi4MiniJsonToolParser(ToolParser): len(function_call_arr)) except json.JSONDecodeError as e: logger.error( - "Failed to parse function calls from 
model output: %s. " - "Error: %s", model_output, str(e)) + "Failed to parse function calls from model output. " + "Error: %s", str(e)) tool_calls: list[ToolCall] = [ ToolCall( diff --git a/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py index 548ff39d1ca4f..bc5d15dcb82f4 100644 --- a/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py @@ -8,6 +8,7 @@ from typing import Any, Union import regex as re from transformers import PreTrainedTokenizerBase +import vllm.envs as envs from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -61,8 +62,18 @@ class PythonicToolParser(ToolParser): """ Extract the tool calls from a complete model response. """ + is_tool_call_pattern = False + try: + is_tool_call_pattern = self.TOOL_CALL_REGEX.match( + model_output, + timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS) is not None + except TimeoutError: + logger.warning( + "Regex timeout occurred when matching tool call pattern.") + logger.debug("Regex timeout occurred when matching user input: %s", + model_output) - if not (self.TOOL_CALL_REGEX.match(model_output)): + if not is_tool_call_pattern: return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=model_output) diff --git a/vllm/envs.py b/vllm/envs.py index b007bf8c59b72..44baf5a189b43 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -15,6 +15,7 @@ if TYPE_CHECKING: VLLM_NCCL_SO_PATH: Optional[str] = None LD_LIBRARY_PATH: Optional[str] = None VLLM_USE_TRITON_FLASH_ATTN: bool = False + VLLM_V1_USE_PREFILL_DECODE_ATTENTION: bool = False VLLM_FLASH_ATTN_VERSION: Optional[int] = None LOCAL_RANK: int = 0 CUDA_VISIBLE_DEVICES: Optional[str] = None @@ -118,6 +119,7 @@ if TYPE_CHECKING: VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557 VLLM_ALL2ALL_BACKEND: str = "naive" VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 
163840 + VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1 def get_default_cache_root(): @@ -142,10 +144,10 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: def get_vllm_port() -> Optional[int]: """Get the port from VLLM_PORT environment variable. - + Returns: The port number as an integer if VLLM_PORT is set, None otherwise. - + Raises: ValueError: If VLLM_PORT is a URI, suggest k8s service discovery issue. """ @@ -158,17 +160,13 @@ def get_vllm_port() -> Optional[int]: return int(port) except ValueError as err: from urllib.parse import urlparse - try: - parsed = urlparse(port) - if parsed.scheme: - raise ValueError( - f"VLLM_PORT '{port}' appears to be a URI. " - "This may be caused by a Kubernetes service discovery issue" - "check the warning in: https://docs.vllm.ai/en/stable/usage/env_vars.html" - ) - except Exception: - pass - + parsed = urlparse(port) + if parsed.scheme: + raise ValueError( + f"VLLM_PORT '{port}' appears to be a URI. " + "This may be caused by a Kubernetes service discovery issue," + "check the warning in: https://docs.vllm.ai/en/stable/serving/env_vars.html" + ) from None raise ValueError( f"VLLM_PORT '{port}' must be a valid integer") from err @@ -290,6 +288,13 @@ environment_variables: dict[str, Callable[[], Any]] = { lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in ("true", "1")), + # Use separate prefill and decode kernels for V1 attention instead of + # the unified triton kernel. + "VLLM_V1_USE_PREFILL_DECODE_ATTENTION": + lambda: + (os.getenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "False").lower() in + ("true", "1")), + # Force vllm to use a specific flash-attention version (2 or 3), only valid # when using the flash-attention backend. 
"VLLM_FLASH_ATTN_VERSION": @@ -300,9 +305,11 @@ environment_variables: dict[str, Callable[[], Any]] = { lambda: bool( os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"), - # Internal flag to enable/disable Inductor standalone compile - "VLLM_TEST_STANDALONE_COMPILE": - lambda: os.environ.get("VLLM_TEST_STANDALONE_COMPILE", "0") != "0", + # Feature flag to enable/disable Inductor standalone compile. + # In torch <= 2.7 we ignore this flag; in torch >= 2.8 this is + # enabled by default. + "VLLM_USE_STANDALONE_COMPILE": + lambda: os.environ.get("VLLM_USE_STANDALONE_COMPILE", "1") == "1", # local rank of the process in the distributed setting, used to determine # the GPU device id @@ -323,8 +330,8 @@ environment_variables: dict[str, Callable[[], Any]] = { # Whether to log responses from API Server for debugging "VLLM_DEBUG_LOG_API_SERVER_RESPONSE": - lambda: os.environ.get("VLLM_DEBUG_LOG_API_SERVER_RESPONSE", "False"). - lower() == "true", + lambda: os.environ.get("VLLM_DEBUG_LOG_API_SERVER_RESPONSE", "False" + ).lower() == "true", # S3 access information, used for tensorizer to load model from S3 "S3_ACCESS_KEY_ID": @@ -822,6 +829,10 @@ environment_variables: dict[str, Callable[[], Any]] = { # This is used to prevent the kernel from running out of memory. "VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE": lambda: int(os.getenv("VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE", "163840")), + + # Regex timeout for use by the vLLM tool parsing plugins. 
+ "VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS": + lambda: int(os.getenv("VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS", "1")), } # --8<-- [end:env-vars-definition] @@ -884,7 +895,7 @@ def compute_hash() -> str: "VLLM_USE_TRITON_AWQ", "VLLM_DP_RANK", "VLLM_DP_SIZE", - "VLLM_TEST_STANDALONE_COMPILE", + "VLLM_USE_STANDALONE_COMPILE", ] for key in environment_variables_to_hash: if key in environment_variables: diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 592ca650a5546..f192be1c40d54 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -47,8 +47,12 @@ class DPMetadata: return num_tokens_tensor @staticmethod - def make(parallel_config: ParallelConfig, attn_metadata: Any, - num_tokens: int) -> "DPMetadata": + def make( + parallel_config: ParallelConfig, + attn_metadata: Any, + num_tokens: int, + num_tokens_across_dp: Optional[torch.Tensor] = None + ) -> "DPMetadata": assert parallel_config.data_parallel_size > 1 dp_size = parallel_config.data_parallel_size @@ -62,10 +66,15 @@ class DPMetadata: # for v1 attention backends or no attn_metadata batchsize = num_tokens - num_tokens_tensor = DPMetadata.num_tokens_across_dp( - batchsize, dp_size, dp_rank) - max_tokens_across_dp_cpu = torch.max(num_tokens_tensor) - cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0) + # If num_tokens_across_dp is None, it will be computed by all_reduce + # Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize + assert (num_tokens_across_dp is None + or num_tokens_across_dp[dp_rank] == batchsize) + if num_tokens_across_dp is None: + num_tokens_across_dp = DPMetadata.num_tokens_across_dp( + batchsize, dp_size, dp_rank) + max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp) + cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_across_dp, dim=0) return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu) @@ -101,7 +110,8 @@ def get_forward_context() -> ForwardContext: def set_forward_context(attn_metadata: Any, vllm_config: 
VllmConfig, virtual_engine: int = 0, - num_tokens: int = 0): + num_tokens: Optional[int] = None, + num_tokens_across_dp: Optional[torch.Tensor] = None): """A context manager that stores the current forward context, can be attention metadata, etc. Here we can inject common logic for every model forward pass. @@ -111,9 +121,11 @@ def set_forward_context(attn_metadata: Any, if need_to_track_batchsize: forward_start_time = time.perf_counter() dp_metadata: Optional[DPMetadata] = None - if vllm_config.parallel_config.data_parallel_size > 1: + if vllm_config.parallel_config.data_parallel_size > 1 and ( + attn_metadata is not None or num_tokens is not None): dp_metadata = DPMetadata.make(vllm_config.parallel_config, - attn_metadata, num_tokens) + attn_metadata, num_tokens or 0, + num_tokens_across_dp) global _forward_context prev_context = _forward_context diff --git a/vllm/lora/models.py b/vllm/lora/models.py index d3b1374a9dd29..dfdc908d7e05b 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -import copy import math import os from collections.abc import Sequence @@ -34,6 +33,7 @@ from vllm.model_executor.models import SupportsLoRA, supports_multimodal from vllm.model_executor.models.interfaces import is_pooling_model from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper +from vllm.model_executor.utils import get_packed_modules_mapping from vllm.utils import is_pin_memory_available logger = init_logger(__name__) @@ -364,8 +364,8 @@ class LoRAModelManager(AdapterModelManager): # We need to replace rotary emb layer to do batch computation # for long lora. 
self.supported_lora_modules.append("rotary_emb") - self.packed_modules_mapping = copy.deepcopy( - self.model.packed_modules_mapping) + + self.packed_modules_mapping = get_packed_modules_mapping(self.model) # Used to indicate whether the model is a multimodal model self.supports_mm: bool = ( supports_multimodal(self.model) diff --git a/vllm/lora/ops/torch_ops/lora_ops.py b/vllm/lora/ops/torch_ops/lora_ops.py index af79f98415cbc..ab65faceb2c10 100644 --- a/vllm/lora/ops/torch_ops/lora_ops.py +++ b/vllm/lora/ops/torch_ops/lora_ops.py @@ -36,10 +36,13 @@ def bgmv_expand(inputs: torch.Tensor, if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: limit = 1 + # LoRA adapter and model may add different amounts of padding to output + common_len = min(outputs.shape[1], output_tensor.shape[1]) + if add_inputs: - output_tensor[:, :outputs.shape[1]] += outputs[:limit, :] + output_tensor[:, :common_len] += outputs[:limit, :common_len] else: - output_tensor[:, :outputs.shape[1]] = outputs[:limit, :] + output_tensor[:, :common_len] = outputs[:limit, :common_len] def sgmv_shrink( diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index afc8a8dc3b260..f1ae030975074 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -229,6 +229,11 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager): self.add_adapter(lora) def add_adapter(self, lora_request: LoRARequest) -> bool: + # Note that this method is not thread-safe. It may be invoked multiple + # times for the same adapter when using multiple API servers. + # This is ok because it's currently only called from + # the single-threaded core engine loop. + if lora_request.lora_int_id not in self.list_adapters(): # Load the new adapter first to ensure it is actually valid, before # evicting any existing adapters. 
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index ae72826ee9765..9242e4ae57468 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -48,7 +48,7 @@ else: FusedMoEPrepareAndFinalize = None # type: ignore if is_rocm_aiter_moe_enabled(): from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 - rocm_aiter_biased_group_topk as grouped_topk) + rocm_aiter_grouped_topk as grouped_topk) else: from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk if current_platform.is_tpu(): diff --git a/vllm/model_executor/layers/fused_moe/moe_pallas.py b/vllm/model_executor/layers/fused_moe/moe_pallas.py index babeb97308a9f..9d8bd62c6969a 100644 --- a/vllm/model_executor/layers/fused_moe/moe_pallas.py +++ b/vllm/model_executor/layers/fused_moe/moe_pallas.py @@ -2,6 +2,7 @@ import torch import torch.nn.functional as F +import torch_xla.experimental.custom_kernel # noqa: F401 def _histogram(input: torch.Tensor, min: int, max: int) -> torch.Tensor: @@ -66,15 +67,10 @@ def fused_moe( token_indices = token_indices[topk_argsort_indices] group_sizes = _histogram(topk_indices.to(torch.int32), 0, num_experts - 1) - # NOTE(woosuk): The GMM Pallas kernel requires a different weight layout - # from HF Transformers. 
- w1 = w1.transpose(1, 2) - w2 = w2.transpose(1, 2) - x = hidden_states[token_indices] - x = torch.ops.xla.gmm(x, w1, group_sizes) + x = torch.ops.xla.gmm(x, w1, group_sizes, transpose_rhs=True) x = F.silu(x[..., :intermediate_size]) * x[..., intermediate_size:] - x = torch.ops.xla.gmm(x, w2, group_sizes) + x = torch.ops.xla.gmm(x, w2, group_sizes, transpose_rhs=True) x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size) x = x * topk_weights.unsqueeze(dim=-1) diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index 10b61fcda1767..824062491f0ed 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -140,6 +140,36 @@ def rocm_aiter_biased_grouped_topk_fake( pass +def rocm_aiter_grouped_topk_impl( + gating_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_expert_group: int, + topk_group: int, + need_renorm: bool, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0 # mul to topk_weights +) -> None: + + from aiter import grouped_topk + + grouped_topk(gating_output, topk_weights, topk_ids, num_expert_group, + topk_group, need_renorm, scoring_func, routed_scaling_factor) + + +def rocm_aiter_grouped_topk_fake( + gating_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_expert_group: int, + topk_group: int, + need_renorm: bool, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0 # mul to topk_weights +) -> None: + pass + + def rocm_aiter_fused_moe_impl( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -218,36 +248,54 @@ if current_platform.is_rocm(): dispatch_key=current_platform.dispatch_key, ) + direct_register_custom_op( + op_name="rocm_aiter_grouped_topk", + op_func=rocm_aiter_grouped_topk_impl, + mutates_args=["topk_weights", "topk_ids"], + fake_impl=rocm_aiter_grouped_topk_fake, + 
dispatch_key=current_platform.dispatch_key, + ) -def rocm_aiter_biased_group_topk( + +def rocm_aiter_grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, topk: int, renormalize: bool, num_expert_group: int = 0, topk_group: int = 0, - scoring_func: str = "sigmoid", + scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None ) -> tuple[torch.Tensor, torch.Tensor]: - assert scoring_func == "sigmoid", ( - "rocm_aiter_biased_group_topk only supports 'sigmoid' scoring_func.") - assert e_score_correction_bias is not None, ( - "'e_score_correction_bias' must not be None.") token = hidden_states.shape[0] device = hidden_states.device topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device) topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device) - torch.ops.vllm.rocm_aiter_biased_grouped_topk( - gating_output, - e_score_correction_bias, - topk_weights, - topk_ids, - num_expert_group, - topk_group, - renormalize, - ) + + if e_score_correction_bias is not None: + torch.ops.vllm.rocm_aiter_biased_grouped_topk( + gating_output, + e_score_correction_bias, + topk_weights, + topk_ids, + num_expert_group, + topk_group, + renormalize, + ) + else: + assert (scoring_func == "softmax" or scoring_func == "sigmoid") + torch.ops.vllm.rocm_aiter_grouped_topk( + gating_output, + topk_weights, + topk_ids, + num_expert_group, + topk_group, + renormalize, + scoring_func, + ) + return topk_weights, topk_ids diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 6abbc90819a82..d2c42191bb3ff 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -6,10 +6,9 @@ from typing import Optional, Union import torch import torch.nn as nn import torch.nn.functional as F -from transformers import PretrainedConfig from typing_extensions import assert_never -from vllm.config import PoolerConfig +from vllm.config import ModelConfig, PoolerConfig from 
vllm.model_executor.pooling_metadata import (PoolingMetadata, PoolingTensors) from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput @@ -283,30 +282,37 @@ class Pooler(nn.Module): ) -class CrossEncodingPooler(nn.Module): - """A layer that pools specific information from hidden states. +class ClassifierPooler(nn.Module): + """A pooling layer for classification tasks. This layer does the following: - 1. Extracts specific tokens or aggregates data based on pooling method. - 2. Normalizes output if specified. - 3. Returns structured results as `PoolerOutput`. - - Attributes: - pooling_type: The type of pooling to use. - normalize: Whether to normalize the pooled data. + 1. Applies a classification layer to the hidden states. + 2. Optionally applies a pooler layer. + 3. Applies an activation function to the output. In the case of + classification models it is either sigmoid or softmax. In the + case of scoring models, the same behavior is configuration + dependent, as in the sentence-transformers library. 
""" def __init__( self, - config: PretrainedConfig, + config: ModelConfig, classifier: nn.Module, pooler: Optional[nn.Module] = None, ): super().__init__() self.classifier = classifier self.pooler = pooler - self.default_activation_function = \ - get_cross_encoder_activation_function(config) + + if config.task == "score": + self.default_activation_function = \ + get_cross_encoder_activation_function(config.hf_config) + elif config.task == "classify": + self.default_activation_function = nn.Sigmoid() \ + if config.hf_config.num_labels == 1 else nn.Softmax() + else: + raise NotImplementedError(f"task={config.task!r} is not supported" + " with the classification pooler") def forward( self, diff --git a/vllm/model_executor/layers/quantization/auto_round.py b/vllm/model_executor/layers/quantization/auto_round.py index 2d9f5e52bd65a..eb8ffa37882cb 100644 --- a/vllm/model_executor/layers/quantization/auto_round.py +++ b/vllm/model_executor/layers/quantization/auto_round.py @@ -116,8 +116,9 @@ class AutoRoundConfig(QuantizationConfig): quantized = True if self.block_name_to_quantize: - quantized = any(name in layer_name - for name in self.block_name_to_quantize) + quantized = any( + layer_name.startswith(name) + for name in self.block_name_to_quantize) elif isinstance(layer, ParallelLMHead): quantized = False diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 4660c28c8de4a..87afdb623d912 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -101,7 +101,13 @@ class AWQLinearMethod(LinearMethodBase): output_partition_sizes: list[int], input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs): - if input_size_per_partition % self.quant_config.group_size != 0: + # Normalize group_size + if self.quant_config.group_size != -1: + group_size = self.quant_config.group_size + else: + group_size = input_size + + if 
input_size_per_partition % group_size != 0: raise ValueError( "The input size is not aligned with the quantized " "weight shape. This can be caused by too large " @@ -127,9 +133,11 @@ class AWQLinearMethod(LinearMethodBase): packed_factor=self.quant_config.pack_factor, weight_loader=weight_loader) + num_groups = input_size_per_partition // group_size + qzeros = PackedvLLMParameter( data=torch.empty( - input_size_per_partition // self.quant_config.group_size, + num_groups, output_size_per_partition // self.quant_config.pack_factor, dtype=torch.int32, ), @@ -140,7 +148,7 @@ class AWQLinearMethod(LinearMethodBase): weight_loader=weight_loader) scales = GroupQuantScaleParameter(data=torch.empty( - input_size_per_partition // self.quant_config.group_size, + num_groups, output_size_per_partition, dtype=params_dtype, ), diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 1c5680f952ab5..2abe16a08a265 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -585,9 +585,11 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # GEMM 1 - assert torch.allclose( - layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]), ( - "w1_weight_scale_2 must match w3_weight_scale_2") + if not torch.allclose(layer.w13_weight_scale_2[:, 0], + layer.w13_weight_scale_2[:, 1]): + logger.warning_once( + "w1_weight_scale_2 must match w3_weight_scale_2. 
" + "Accuracy may be affected.") w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0] layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py index 15177af58ae6e..13dcdc00a2156 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py @@ -22,7 +22,12 @@ def is_fp4_marlin_supported(): def fp4_marlin_process_scales(marlin_scales): - assert (marlin_scales >= 0).all() + if not (marlin_scales >= 0).all(): + logger.warning_once( + "NVFP4 Marlin assumes the scales to be >=0, but has encountered " + "negative scales. Accuracy will likely be degraded. This is " + "because it changes the scales from FP8-S1E4M3 to a special " + "FP8-S0E5M3 format to speedup the dequantization.") # convert to half first, we would convert to fp8 later marlin_scales = marlin_scales.to(torch.half) diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 4b041cff2eccb..eed8998fe3da5 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -155,8 +155,8 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor, input_2d: torch.Tensor, output_shape: list) -> torch.Tensor: - from vllm.platforms.rocm import on_mi250_mi300 - if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi250_mi300( + from vllm.platforms.rocm import on_mi3xx + if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx( ) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0: output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b, current_platform.get_cu_count()) diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 
70463ecd90ae7..afc0597197962 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -96,7 +96,7 @@ class RotaryEmbedding(CustomOp): head_size: int, rotary_dim: int, max_position_embeddings: int, - base: int, + base: float, is_neox_style: bool, dtype: torch.dtype, ) -> None: @@ -113,7 +113,7 @@ class RotaryEmbedding(CustomOp): self.cos_sin_cache: torch.Tensor self.register_buffer("cos_sin_cache", cache, persistent=False) - def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + def _compute_inv_freq(self, base: float) -> torch.Tensor: """Compute the inverse frequency.""" # NOTE(woosuk): To exactly match the HF implementation, we need to # use CPU to compute the cache and then move it to GPU. However, we @@ -404,7 +404,7 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding): head_size: int, rotary_dim: int, max_position_embeddings: int, - base: int, + base: float, is_neox_style: bool, scaling_factors: Union[list[float], float], dtype: torch.dtype, @@ -464,7 +464,7 @@ class NTKScalingRotaryEmbedding(RotaryEmbedding): head_size: int, rotary_dim: int, max_position_embeddings: int, - base: int, + base: float, is_neox_style: bool, scaling_factor: float, dtype: torch.dtype, @@ -474,7 +474,7 @@ class NTKScalingRotaryEmbedding(RotaryEmbedding): super().__init__(head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype) - def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + def _compute_inv_freq(self, base: float) -> torch.Tensor: base = self.base * (self.scaling_factor if self.mixed_b is None else 1) inv_freq = super()._compute_inv_freq(base) @@ -501,7 +501,7 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding): head_size: int, rotary_dim: int, max_position_embeddings: int, - base: int, + base: float, is_neox_style: bool, scaling_factor: float, dtype: torch.dtype, @@ -582,7 +582,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding): head_size: int, 
rotary_dim: int, max_position_embeddings: int, - base: int, + base: float, is_neox_style: bool, scaling_factor: float, dtype: torch.dtype, @@ -644,7 +644,7 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module): rotary_dim: int, max_position_embeddings: int, original_max_position_embeddings: int, - base: int, + base: float, is_neox_style: bool, dtype: torch.dtype, short_factor: list[float], @@ -769,7 +769,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding): head_size: int, rotary_dim: int, max_position_embeddings: int, - base: int, + base: float, is_neox_style: bool, scaling_factor: float, dtype: torch.dtype, @@ -877,7 +877,7 @@ class Llama3RotaryEmbedding(RotaryEmbedding): head_size: int, rotary_dim: int, max_position_embeddings: int, - base: int, + base: float, is_neox_style: bool, dtype: torch.dtype, scaling_factor: float, @@ -892,7 +892,7 @@ class Llama3RotaryEmbedding(RotaryEmbedding): super().__init__(head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype) - def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + def _compute_inv_freq(self, base: float) -> torch.Tensor: inv_freqs = super()._compute_inv_freq(base) low_freq_wavelen = self.orig_max_position / self.low_freq_factor high_freq_wavelen = self.orig_max_position / self.high_freq_factor @@ -923,14 +923,14 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding): head_size: int, rotary_dim: int, max_position_embeddings: int, - base: int, + base: float, is_neox_style: bool, dtype: torch.dtype, ): super().__init__(head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype) - def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + def _compute_inv_freq(self, base: float) -> torch.Tensor: inv_freqs = super()._compute_inv_freq(base) inv_freqs = inv_freqs[:(self.rotary_dim // 2)] return inv_freqs @@ -989,7 +989,7 @@ class MRotaryEmbedding(RotaryEmbedding): head_size: int, rotary_dim: int, max_position_embeddings: int, - base: int, + base: 
float, is_neox_style: bool, dtype: torch.dtype, mrope_section: Optional[list[int]] = None, @@ -1529,7 +1529,7 @@ class DualChunkRotaryEmbedding(CustomOp): head_size: int, rotary_dim: int, max_position_embeddings: int, - base: int, + base: float, is_neox_style: bool, dtype: torch.dtype, chunk_size: int, @@ -1558,7 +1558,7 @@ class DualChunkRotaryEmbedding(CustomOp): q_inter_cache, persistent=False) - def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + def _compute_inv_freq(self, base: float) -> torch.Tensor: """Compute the inverse frequency.""" # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`. # However, we use `torch.arange(..., dtype=torch.float)` instead to @@ -1705,7 +1705,7 @@ def get_rope( head_size: int, rotary_dim: int, max_position: int, - base: int, + base: float, is_neox_style: bool = True, rope_scaling: Optional[dict[str, Any]] = None, dtype: Optional[torch.dtype] = None, diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index 18783d0d77856..001e6aaf0cc7f 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -70,9 +70,9 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, def rocm_unquantized_gemm(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None): - from vllm.platforms.rocm import on_mi250_mi300 + from vllm.platforms.rocm import on_gfx9 k = weight.shape[1] - use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi250_mi300() and \ + use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_gfx9() and \ x.dtype in [torch.float16, torch.bfloat16] \ and k % 8 == 0 and bias is None) diff --git a/vllm/model_executor/model_loader/base_loader.py b/vllm/model_executor/model_loader/base_loader.py index 010dd515784af..d619d9f25e087 100644 --- a/vllm/model_executor/model_loader/base_loader.py +++ b/vllm/model_executor/model_loader/base_loader.py @@ -1,9 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 
from abc import ABC, abstractmethod +import torch import torch.nn as nn from vllm.config import LoadConfig, ModelConfig, VllmConfig +from vllm.model_executor.model_loader.utils import ( + initialize_model, process_weights_after_loading, set_default_torch_dtype) class BaseModelLoader(ABC): @@ -18,7 +21,22 @@ class BaseModelLoader(ABC): raise NotImplementedError @abstractmethod - def load_model(self, *, vllm_config: VllmConfig, + def load_weights(self, model: nn.Module, + model_config: ModelConfig) -> None: + """Load weights into a model. This standalone API allows + inplace weights loading for an already-initialized model""" + raise NotImplementedError + + def load_model(self, vllm_config: VllmConfig, model_config: ModelConfig) -> nn.Module: """Load a model with the given configurations.""" - raise NotImplementedError + device_config = vllm_config.device_config + target_device = torch.device(device_config.device) + with set_default_torch_dtype(model_config.dtype): + with target_device: + model = initialize_model(vllm_config=vllm_config, + model_config=model_config) + # Quantization does not happen in `load_weights` but after it + self.load_weights(model, model_config) + process_weights_after_loading(model, model_config, target_device) + return model.eval() diff --git a/vllm/model_executor/model_loader/bitsandbytes_loader.py b/vllm/model_executor/model_loader/bitsandbytes_loader.py index 0d83c8d534199..3df835a938968 100644 --- a/vllm/model_executor/model_loader/bitsandbytes_loader.py +++ b/vllm/model_executor/model_loader/bitsandbytes_loader.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # ruff: noqa: SIM117 -import copy import fnmatch import glob import itertools @@ -15,7 +14,7 @@ from huggingface_hub import HfApi from torch import nn from transformers.utils import SAFE_WEIGHTS_INDEX_NAME -from vllm.config import LoadConfig, ModelConfig, VllmConfig +from vllm.config import LoadConfig, ModelConfig from vllm.distributed import 
(get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) # yapf: enable @@ -29,14 +28,14 @@ from vllm.model_executor.layers.linear import (LinearBase, RowParallelLinear) from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.utils import (ParamMapping, - initialize_model, set_default_torch_dtype) from vllm.model_executor.model_loader.weight_utils import ( download_safetensors_index_file_from_hf, download_weights_from_hf, filter_duplicate_safetensors_files, filter_files_not_needed_for_inference, pt_weights_iterator, safetensors_weights_iterator) from vllm.model_executor.models import is_pooling_model -from vllm.model_executor.utils import set_weight_attrs +from vllm.model_executor.utils import (get_packed_modules_mapping, + set_weight_attrs) from vllm.platforms import current_platform logger = init_logger(__name__) @@ -408,8 +407,7 @@ class BitsAndBytesModelLoader(BaseModelLoader): ), "vllm currently does not support BNB quantization for" f" {type(model).__name__}" - def _load_weights(self, model_config: ModelConfig, - model: nn.Module) -> None: + def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None: if not hasattr(model, "load_weights"): raise AttributeError( "The required method 'load_weights' is not defined in class" @@ -420,8 +418,8 @@ class BitsAndBytesModelLoader(BaseModelLoader): f"Model {type(model).__name__} does not support BitsAndBytes " "quantization yet. No 'packed_modules_mapping' found.") self.is_pool_model=is_pooling_model(model) - self.modules_mapping = ParamMapping( - copy.deepcopy(model.packed_modules_mapping)) + + self.modules_mapping = ParamMapping(get_packed_modules_mapping(model)) # For some models like Molmo, we need to use hf_to_vllm_mapper # to ensure correct loading of weights. 
@@ -568,15 +566,3 @@ class BitsAndBytesModelLoader(BaseModelLoader): def download_model(self, model_config: ModelConfig) -> None: self._prepare_weights(model_config.model, model_config.revision) - - def load_model(self, vllm_config: VllmConfig, - model_config: ModelConfig) -> nn.Module: - device_config = vllm_config.device_config - with set_default_torch_dtype(model_config.dtype): - with torch.device(device_config.device): - - model = initialize_model(vllm_config=vllm_config) - - self._load_weights(model_config, model) - - return model.eval() diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py index 29a6e0af4bc67..6946627a54d24 100644 --- a/vllm/model_executor/model_loader/default_loader.py +++ b/vllm/model_executor/model_loader/default_loader.py @@ -12,11 +12,9 @@ from torch import nn from transformers.utils import SAFE_WEIGHTS_INDEX_NAME from vllm import envs -from vllm.config import LoadConfig, LoadFormat, ModelConfig, VllmConfig +from vllm.config import LoadConfig, LoadFormat, ModelConfig from vllm.logger import init_logger from vllm.model_executor.model_loader.base_loader import BaseModelLoader -from vllm.model_executor.model_loader.utils import ( - initialize_model, process_weights_after_loading, set_default_torch_dtype) from vllm.model_executor.model_loader.weight_utils import ( download_safetensors_index_file_from_hf, download_weights_from_hf, fastsafetensors_weights_iterator, filter_duplicate_safetensors_files, @@ -264,32 +262,20 @@ class DefaultModelLoader(BaseModelLoader): fall_back_to_pt=True, allow_patterns_overrides=None) - def load_model(self, vllm_config: VllmConfig, - model_config: ModelConfig) -> nn.Module: - device_config = vllm_config.device_config - target_device = torch.device(device_config.device) - with set_default_torch_dtype(model_config.dtype): - with target_device: - model = initialize_model(vllm_config=vllm_config, - model_config=model_config) - - weights_to_load = {name 
for name, _ in model.named_parameters()} - loaded_weights = model.load_weights( - self.get_all_weights(model_config, model)) - self.counter_after_loading_weights = time.perf_counter() - logger.info( - "Loading weights took %.2f seconds", - self.counter_after_loading_weights - - self.counter_before_loading_weights) - # We only enable strict check for non-quantized models - # that have loaded weights tracking currently. - if model_config.quantization is None and loaded_weights is not None: - weights_not_loaded = weights_to_load - loaded_weights - if weights_not_loaded: - raise ValueError( - "Following weights were not initialized from " - f"checkpoint: {weights_not_loaded}") - - process_weights_after_loading(model, model_config, target_device) - - return model.eval() + def load_weights(self, model: nn.Module, + model_config: ModelConfig) -> None: + weights_to_load = {name for name, _ in model.named_parameters()} + loaded_weights = model.load_weights( + self.get_all_weights(model_config, model)) + self.counter_after_loading_weights = time.perf_counter() + logger.info( + "Loading weights took %.2f seconds", + self.counter_after_loading_weights - + self.counter_before_loading_weights) + # We only enable strict check for non-quantized models + # that have loaded weights tracking currently. 
+ if model_config.quantization is None and loaded_weights is not None: + weights_not_loaded = weights_to_load - loaded_weights + if weights_not_loaded: + raise ValueError("Following weights were not initialized from " + f"checkpoint: {weights_not_loaded}") diff --git a/vllm/model_executor/model_loader/dummy_loader.py b/vllm/model_executor/model_loader/dummy_loader.py index 0e2f0be1ec26c..64fa2be76d08b 100644 --- a/vllm/model_executor/model_loader/dummy_loader.py +++ b/vllm/model_executor/model_loader/dummy_loader.py @@ -1,11 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 -import torch import torch.nn as nn -from vllm.config import LoadConfig, ModelConfig, VllmConfig +from vllm.config import LoadConfig, ModelConfig from vllm.model_executor.model_loader.base_loader import BaseModelLoader -from vllm.model_executor.model_loader.utils import ( - initialize_model, process_weights_after_loading, set_default_torch_dtype) from vllm.model_executor.model_loader.weight_utils import ( initialize_dummy_weights) @@ -22,16 +19,8 @@ class DummyModelLoader(BaseModelLoader): def download_model(self, model_config: ModelConfig) -> None: pass # Nothing to download - def load_model(self, vllm_config: VllmConfig, - model_config: ModelConfig) -> nn.Module: - device_config = vllm_config.device_config - target_device = torch.device(device_config.device) - with set_default_torch_dtype(model_config.dtype): - with target_device: - model = initialize_model(vllm_config=vllm_config) - # NOTE(woosuk): For accurate performance evaluation, we assign - # random values to the weights. - initialize_dummy_weights(model) - - process_weights_after_loading(model, model_config, target_device) - return model.eval() + def load_weights(self, model: nn.Module, + model_config: ModelConfig) -> None: + # NOTE(woosuk): For accurate performance evaluation, we assign + # random values to the weights. 
+ initialize_dummy_weights(model) diff --git a/vllm/model_executor/model_loader/gguf_loader.py b/vllm/model_executor/model_loader/gguf_loader.py index 806004bf9604f..1eac504227e25 100644 --- a/vllm/model_executor/model_loader/gguf_loader.py +++ b/vllm/model_executor/model_loader/gguf_loader.py @@ -92,6 +92,13 @@ class GGUFModelLoader(BaseModelLoader): def download_model(self, model_config: ModelConfig) -> None: self._prepare_weights(model_config.model) + def load_weights(self, model: nn.Module, + model_config: ModelConfig) -> None: + local_model_path = self._prepare_weights(model_config.model) + gguf_weights_map = self._get_gguf_weights_map(model_config) + model.load_weights( + self._get_weights_iterator(local_model_path, gguf_weights_map)) + def load_model(self, vllm_config: VllmConfig, model_config: ModelConfig) -> nn.Module: device_config = vllm_config.device_config @@ -106,8 +113,7 @@ class GGUFModelLoader(BaseModelLoader): with set_default_torch_dtype(model_config.dtype): with target_device: model = initialize_model(vllm_config=vllm_config) - model.load_weights( - self._get_weights_iterator(local_model_path, gguf_weights_map)) + self.load_weights(model, model_config) process_weights_after_loading(model, model_config, target_device) return model diff --git a/vllm/model_executor/model_loader/neuronx_distributed.py b/vllm/model_executor/model_loader/neuronx_distributed.py index 557feea46a907..72ad4da296ac6 100644 --- a/vllm/model_executor/model_loader/neuronx_distributed.py +++ b/vllm/model_executor/model_loader/neuronx_distributed.py @@ -17,6 +17,8 @@ from neuronx_distributed_inference.models.config import ( FusedSpecNeuronConfig, OnDeviceSamplingConfig) from neuronx_distributed_inference.models.mllama.utils import ( create_vision_mask) +from neuronx_distributed_inference.modules.lora_serving import ( + LoraServingConfig) from neuronx_distributed_inference.utils.hf_adapter import ( load_pretrained_config) from transformers import AutoModelForCausalLM, 
AutoTokenizer, PretrainedConfig @@ -80,25 +82,26 @@ class NeuronCausalLM(nn.Module): # Lazy initialized self.model: nn.Module - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - input_block_ids: torch.Tensor, - sampling_params: torch.Tensor, - ) -> torch.Tensor: + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_block_ids: torch.Tensor, + sampling_params: torch.Tensor, + prev_hidden: Optional[torch.Tensor] = None, + adapter_ids: Optional[torch.Tensor] = None) -> torch.Tensor: # sort block ids sequentially for perf/neuron support reasons sorted_input_block_ids, sorted_indices = torch.sort(input_block_ids) input_ids = torch.index_select(input_ids, 0, sorted_indices) positions = torch.index_select(positions, 0, sorted_indices) sampling_params = torch.index_select(sampling_params, 0, sorted_indices) - output = self.model(input_ids, attention_mask=None, position_ids=positions, seq_ids=sorted_input_block_ids, - sampling_params=sampling_params) + sampling_params=sampling_params, + prev_hidden=prev_hidden, + adapter_ids=adapter_ids) # on-device sampling if self.config.neuron_config.on_device_sampling_config: output = output.hidden_states @@ -201,6 +204,11 @@ class NeuronMllamaForCausalLM(nn.Module): config: PretrainedConfig, on_device_sampling_disabled: bool = False) -> None: super().__init__() + # has_image is the only multimodal input that is used in + # token-generation + # This is a cache (on CPU) that saves has_image data per sequence id + # The number of entries in this cache is <= Batch-Size + self.has_image_cache: dict[int, torch.Tensor] = {} self.config = config self.logits_processor = LogitsProcessor( config.get_text_config().vocab_size, logits_as_input=True) @@ -212,11 +220,57 @@ class NeuronMllamaForCausalLM(nn.Module): # Lazy initialized self.model: nn.Module + self.is_reorder_needed: bool = True + + def read_from_has_image_cache(self, seq_ids: torch.Tensor): + has_image_list = [] + for index in 
range(len(seq_ids)): + seq_id = seq_ids[index].item() + if seq_id in self.has_image_cache: + has_image_list.append(self.has_image_cache[seq_id]) + else: + has_image_list.append(torch.tensor([0])) + return torch.tensor(has_image_list) + + def write_to_has_image_cache(self, seq_ids: torch.Tensor, + has_image: torch.Tensor): + for index in range(len(seq_ids)): + seq_id = seq_ids[index].item() + if index < len(has_image): + self.has_image_cache[seq_id] = has_image[index] + else: + self.has_image_cache[seq_id] = torch.zeros(1) def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, seq_ids: torch.Tensor, pixel_values: torch.Tensor, aspect_ratios: torch.Tensor, num_chunks: torch.Tensor, has_image: torch.Tensor, sampling_params) -> torch.Tensor: + + # We update the has_image cache during prefill + # and read the has_image cache during decode + if input_ids.shape[-1] > 1: # prefill + self.write_to_has_image_cache(seq_ids, has_image) + else: + has_image = self.read_from_has_image_cache(seq_ids) + bs = input_ids.shape[0] + num_chunks = torch.zeros((bs, 1)) + aspect_ratios = torch.zeros((bs, 1, 2)) + + input_block_ids = seq_ids + origin_input_block_ids = seq_ids + if self.is_reorder_needed: + # sort block ids sequentially for perf/neuron support reasons + input_block_ids, sorted_indices = torch.sort(input_block_ids) + input_ids = torch.index_select(input_ids, 0, sorted_indices) + positions = torch.index_select(positions, 0, sorted_indices) + sampling_params = torch.index_select(sampling_params, 0, + sorted_indices) + pixel_values = torch.index_select(pixel_values, 0, sorted_indices) + aspect_ratios = torch.index_select(aspect_ratios, 0, + sorted_indices) + num_chunks = torch.index_select(num_chunks, 0, sorted_indices) + has_image = torch.index_select(has_image, 0, sorted_indices) + self.vision_mask = create_vision_mask(input_ids, self.vision_token_id) output = self.model( input_ids.to(torch.int32), @@ -232,8 +286,14 @@ class NeuronMllamaForCausalLM(nn.Module): 
has_image=has_image.to(torch.int32), ) if self.config.neuron_config.on_device_sampling_config: - return output.hidden_states - return output.logits[:, -1, :] + output = output.hidden_states + else: + output = output.logits[:, -1, :] + + if self.is_reorder_needed and origin_input_block_ids.shape[0] != 1: + restored_indices = torch.argsort(sorted_indices) + output = torch.index_select(output, 0, restored_indices) + return output def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: @@ -296,7 +356,7 @@ class NeuronMllamaForCausalLM(nn.Module): self.model = neuronx_model_cls(compiled_model_path) tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) self.vision_token_id = tokenizer( - "<|image|>", add_special_tokens=False).input_ids + "<|image|>", add_special_tokens=False).input_ids[0] self.model.load(compiled_model_path) return except (FileNotFoundError, ValueError): @@ -323,7 +383,7 @@ class NeuronMllamaForCausalLM(nn.Module): # Read "<|image|>" token_id from the tokenizer self.vision_token_id = tokenizer("<|image|>", - add_special_tokens=False).input_ids + add_special_tokens=False).input_ids[0] logger.info("\nLoading model from compiled checkpoint...") self.model.load(compiled_model_path) @@ -522,7 +582,8 @@ def _get_model_architecture(config: PretrainedConfig) -> str: def _get_default_neuron_config(model_config: ModelConfig, parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig): + scheduler_config: SchedulerConfig, + lora_serving_config: LoraServingConfig): """Generate a neuron config based on vllm config args.""" on_device_sampling_config = OnDeviceSamplingConfig(dynamic=True, deterministic=False) @@ -541,7 +602,7 @@ def _get_default_neuron_config(model_config: ModelConfig, padding_side="right", on_device_sampling_config=on_device_sampling_config, sequence_parallel_enabled=True, - ) + lora_serving_config=lora_serving_config) return neuron_config @@ -581,7 +642,8 @@ def 
_get_neuron_config_after_override(default_neuron_config, def get_neuron_model(model_config: ModelConfig, parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig) -> nn.Module: + scheduler_config: SchedulerConfig, + lora_serving_config: LoraServingConfig) -> nn.Module: """Initializes a neuron-optimized model for inference.""" model_arch = _get_model_architecture(model_config.hf_config) if model_arch == "MllamaForConditionalGeneration": @@ -589,7 +651,7 @@ def get_neuron_model(model_config: ModelConfig, else: model = NeuronCausalLM(model_config.hf_config) default_neuron_config_args = _get_default_neuron_config( - model_config, parallel_config, scheduler_config) + model_config, parallel_config, scheduler_config, lora_serving_config) neuron_config = _get_neuron_config_after_override( default_neuron_config_args, model_config.override_neuron_config) diff --git a/vllm/model_executor/model_loader/runai_streamer_loader.py b/vllm/model_executor/model_loader/runai_streamer_loader.py index 9f1022c259251..a39e26c6da50d 100644 --- a/vllm/model_executor/model_loader/runai_streamer_loader.py +++ b/vllm/model_executor/model_loader/runai_streamer_loader.py @@ -9,10 +9,8 @@ import torch from torch import nn from transformers.utils import SAFE_WEIGHTS_INDEX_NAME -from vllm.config import LoadConfig, ModelConfig, VllmConfig +from vllm.config import LoadConfig, ModelConfig from vllm.model_executor.model_loader.base_loader import BaseModelLoader -from vllm.model_executor.model_loader.utils import ( - initialize_model, process_weights_after_loading, set_default_torch_dtype) from vllm.model_executor.model_loader.weight_utils import ( download_safetensors_index_file_from_hf, download_weights_from_hf, runai_safetensors_weights_iterator) @@ -100,21 +98,11 @@ class RunaiModelStreamerLoader(BaseModelLoader): """Download model if necessary""" self._prepare_weights(model_config.model, model_config.revision) - def load_model(self, vllm_config: VllmConfig, - model_config: ModelConfig) 
-> nn.Module: - """Perform streaming of the model to destination""" - device_config = vllm_config.device_config - target_device = torch.device(device_config.device) - with set_default_torch_dtype(model_config.dtype): - with target_device: - model = initialize_model(vllm_config=vllm_config) - - model_weights = model_config.model - if hasattr(model_config, "model_weights"): - model_weights = model_config.model_weights - model.load_weights( - self._get_weights_iterator(model_weights, - model_config.revision)) - - process_weights_after_loading(model, model_config, target_device) - return model.eval() + def load_weights(self, model: nn.Module, + model_config: ModelConfig) -> None: + """Load weights into a model.""" + model_weights = model_config.model + if hasattr(model_config, "model_weights"): + model_weights = model_config.model_weights + model.load_weights( + self._get_weights_iterator(model_weights, model_config.revision)) diff --git a/vllm/model_executor/model_loader/sharded_state_loader.py b/vllm/model_executor/model_loader/sharded_state_loader.py index 78bca89f0015e..b5a5031bb6f91 100644 --- a/vllm/model_executor/model_loader/sharded_state_loader.py +++ b/vllm/model_executor/model_loader/sharded_state_loader.py @@ -9,11 +9,9 @@ from typing import Any, Optional import torch from torch import nn -from vllm.config import LoadConfig, ModelConfig, VllmConfig +from vllm.config import LoadConfig, ModelConfig from vllm.logger import init_logger from vllm.model_executor.model_loader.base_loader import BaseModelLoader -from vllm.model_executor.model_loader.utils import ( - initialize_model, process_weights_after_loading, set_default_torch_dtype) from vllm.model_executor.model_loader.weight_utils import ( download_weights_from_hf, runai_safetensors_weights_iterator) from vllm.transformers_utils.s3_utils import glob as s3_glob @@ -100,11 +98,8 @@ class ShardedStateLoader(BaseModelLoader): def download_model(self, model_config: ModelConfig) -> None: 
self._prepare_weights(model_config.model, model_config.revision) - def load_model(self, vllm_config: VllmConfig, - model_config: ModelConfig) -> nn.Module: - device_config = vllm_config.device_config - target_device = torch.device(device_config.device) - + def load_weights(self, model: nn.Module, + model_config: ModelConfig) -> None: from vllm.distributed import get_tensor_model_parallel_rank model_weights = model_config.model @@ -112,53 +107,47 @@ class ShardedStateLoader(BaseModelLoader): model_weights = model_config.model_weights local_model_path = model_weights - with set_default_torch_dtype(model_config.dtype): - with target_device: - model = initialize_model(vllm_config=vllm_config) - process_weights_after_loading(model, model_config, - target_device) - rank = get_tensor_model_parallel_rank() - pattern = os.path.join( - local_model_path, - self.pattern.format(rank=rank, part="*"), - ) + rank = get_tensor_model_parallel_rank() + pattern = os.path.join( + local_model_path, + self.pattern.format(rank=rank, part="*"), + ) - filepaths = [] - if is_s3(local_model_path): - file_pattern = f"*{self.pattern.format(rank=rank, part=' * ')}" - filepaths = s3_glob(path=local_model_path, - allow_pattern=[file_pattern]) - else: - filepaths = glob.glob(pattern) - if not filepaths: - # TODO: support un-sharded checkpoints too - raise ValueError( - f"Could not find checkpoint files '{pattern}', only " - f"pre-sharded checkpoints are currently supported!") - state_dict = self._filter_subtensors(model.state_dict()) - for key, tensor in self.iterate_over_files(filepaths): - # If loading with LoRA enabled, additional padding may - # be added to certain parameters. We only load into a - # narrowed view of the parameter data. 
- param_data = state_dict[key].data - param_shape = state_dict[key].shape - for dim, size in enumerate(tensor.shape): - if size < param_shape[dim]: - param_data = param_data.narrow(dim, 0, size) - if tensor.shape != param_shape: - logger.warning( - "loading tensor of shape %s into " - "parameter '%s' of shape %s", - tensor.shape, - key, - param_shape, - ) - param_data.copy_(tensor) - state_dict.pop(key) - if state_dict: - raise ValueError( - f"Missing keys {tuple(state_dict)} in loaded state!") - return model.eval() + filepaths = [] + if is_s3(local_model_path): + file_pattern = f"*{self.pattern.format(rank=rank, part=' * ')}" + filepaths = s3_glob(path=local_model_path, + allow_pattern=[file_pattern]) + else: + filepaths = glob.glob(pattern) + if not filepaths: + # TODO: support un-sharded checkpoints too + raise ValueError( + f"Could not find checkpoint files '{pattern}', only " + f"pre-sharded checkpoints are currently supported!") + state_dict = self._filter_subtensors(model.state_dict()) + for key, tensor in self.iterate_over_files(filepaths): + # If loading with LoRA enabled, additional padding may + # be added to certain parameters. We only load into a + # narrowed view of the parameter data. 
+ param_data = state_dict[key].data + param_shape = state_dict[key].shape + for dim, size in enumerate(tensor.shape): + if size < param_shape[dim]: + param_data = param_data.narrow(dim, 0, size) + if tensor.shape != param_shape: + logger.warning( + "loading tensor of shape %s into " + "parameter '%s' of shape %s", + tensor.shape, + key, + param_shape, + ) + param_data.copy_(tensor) + state_dict.pop(key) + if state_dict: + raise ValueError( + f"Missing keys {tuple(state_dict)} in loaded state!") def iterate_over_files( self, paths) -> Generator[tuple[str, torch.Tensor], None, None]: diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index 4c4502284a6af..90c0bdf08ef88 100644 --- a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -21,7 +21,8 @@ from torch.utils._python_dispatch import TorchDispatchMode from transformers import PretrainedConfig import vllm.envs as envs -from vllm.config import ModelConfig, ParallelConfig, set_current_vllm_config +from vllm.config import (ModelConfig, ParallelConfig, VllmConfig, + set_current_vllm_config) from vllm.engine.arg_utils import EngineArgs from vllm.logger import init_logger from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -208,12 +209,6 @@ class TensorizerConfig: **tensorizer_args.stream_params) -def load_with_tensorizer(tensorizer_config: TensorizerConfig, - **extra_kwargs) -> nn.Module: - tensorizer = TensorizerAgent(tensorizer_config, **extra_kwargs) - return tensorizer.deserialize() - - @dataclass class TensorizerArgs: tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, BinaryIO, str, @@ -366,100 +361,72 @@ class TensorizerArgs: return tensorizer_args -class TensorizerAgent: - """ - A class for performing tensorizer deserializations specifically for - vLLM models using plaid_mode. 
Uses TensorizerArgs to configure the - behavior of the TensorDeserializer when loading tensors from a serialized - model. For deserializations of HuggingFace models, TensorDeserializer is - instead used as an iterator directly in the func hf_model_weights_iterator - in vllm/model_executor/model_loader/weight_utils.py - """ +def _check_tensors_on_meta_device(model: nn.Module) -> None: + for tensor in model.state_dict().values(): + if tensor.device.type == 'meta': + raise ValueError( + "The serialized model contains tensors on the meta device," + " indicating that some tensors were not loaded properly." + " Please check that the parameters of the model being" + " specified match that of the serialized model, such as" + " its quantization.") - def __init__(self, tensorizer_config: TensorizerConfig, vllm_config): - self.tensorizer_config = tensorizer_config - self.tensorizer_args = ( - self.tensorizer_config._construct_tensorizer_args()) - self.vllm_config = vllm_config - self.model = self._init_model() - def _init_model(self): - assert self.tensorizer_config.hf_config is not None - model_args = self.tensorizer_config.hf_config - model_args.torch_dtype = self.tensorizer_config.dtype - assert self.tensorizer_config.model_class is not None - # TODO: Do we need to consider old-style model class? 
- with meta_tensor_mode(), set_current_vllm_config(self.vllm_config, - check_compile=True): - return self.tensorizer_config.model_class( - vllm_config=self.vllm_config) +def _resize_lora_embeddings(model: nn.Module): + """Modify LoRA embedding layers to use bigger tensors + to allow for adapter added tokens.""" + for child in model.modules(): + if (isinstance(child, VocabParallelEmbedding) and child.weight.shape[0] + < child.num_embeddings_per_partition): + new_weight = torch.empty(child.num_embeddings_per_partition, + child.embedding_dim, + dtype=child.weight.dtype, + device=child.weight.device) + new_weight[:child.weight.shape[0]].copy_(child.weight.data) + new_weight[child.weight.shape[0]:].fill_(0) + child.weight.data = new_weight - def _resize_lora_embeddings(self): - """Modify LoRA embedding layers to use bigger tensors - to allow for adapter added tokens.""" - for child in self.model.modules(): - if (isinstance(child, VocabParallelEmbedding) - and child.weight.shape[0] - < child.num_embeddings_per_partition): - new_weight = torch.empty(child.num_embeddings_per_partition, - child.embedding_dim, - dtype=child.weight.dtype, - device=child.weight.device) - new_weight[:child.weight.shape[0]].copy_(child.weight.data) - new_weight[child.weight.shape[0]:].fill_(0) - child.weight.data = new_weight - def _check_tensors_on_meta_device(self): - for tensor in self.model.state_dict().values(): - if tensor.device.type == 'meta': - raise ValueError( - "The serialized model contains tensors on the meta device," - " indicating that some tensors were not loaded properly." 
- " Please check that the parameters of the model being" - " specified match that of the serialized model, such as" - " its quantization.") +def init_tensorizer_model(tensorizer_config: TensorizerConfig, + vllm_config: VllmConfig) -> nn.Module: + assert tensorizer_config.hf_config is not None + model_args = tensorizer_config.hf_config + model_args.torch_dtype = tensorizer_config.dtype + assert tensorizer_config.model_class is not None + # TODO: Do we need to consider old-style model class? + with meta_tensor_mode(), set_current_vllm_config(vllm_config, + check_compile=True): + return tensorizer_config.model_class(vllm_config=vllm_config) - def deserialize(self): - """ - Deserialize the model using the TensorDeserializer. This method is - specifically for vLLM models using tensorizer's plaid_mode. - The deserializer makes use of tensorizer_args.stream_params - to configure the behavior of the stream when loading tensors from a - serialized model. The deserializer_params are used to configure the - behavior of the TensorDeserializer when loading tensors themselves. - Documentation on these params can be found in TensorizerArgs - - Returns: - nn.Module: The deserialized model. 
- """ - before_mem = get_mem_usage() - start = time.perf_counter() - with _read_stream( - self.tensorizer_config.tensorizer_uri, - **self.tensorizer_args.stream_params - ) as stream, TensorDeserializer( +def deserialize_tensorizer_model(model: nn.Module, + tensorizer_config: TensorizerConfig) -> None: + tensorizer_args = tensorizer_config._construct_tensorizer_args() + before_mem = get_mem_usage() + start = time.perf_counter() + with _read_stream( + tensorizer_config.tensorizer_uri, + **tensorizer_args.stream_params) as stream, TensorDeserializer( stream, - dtype=self.tensorizer_config.dtype, + dtype=tensorizer_config.dtype, device=f'cuda:{torch.cuda.current_device()}', - **self.tensorizer_args.deserializer_params) as deserializer: - deserializer.load_into_module(self.model) - end = time.perf_counter() + **tensorizer_args.deserializer_params) as deserializer: + deserializer.load_into_module(model) + end = time.perf_counter() - total_bytes_str = convert_bytes(deserializer.total_tensor_bytes) - duration = end - start - per_second = convert_bytes(deserializer.total_tensor_bytes / duration) - after_mem = get_mem_usage() - deserializer.close() - logger.info("Deserialized %s in %0.2fs, %s/s", total_bytes_str, - end - start, per_second) - logger.info("Memory usage before: %s", before_mem) - logger.info("Memory usage after: %s", after_mem) + total_bytes_str = convert_bytes(deserializer.total_tensor_bytes) + duration = end - start + per_second = convert_bytes(deserializer.total_tensor_bytes / duration) + after_mem = get_mem_usage() + deserializer.close() + logger.info("Deserialized %s in %0.2fs, %s/s", total_bytes_str, + end - start, per_second) + logger.info("Memory usage before: %s", before_mem) + logger.info("Memory usage after: %s", after_mem) - self._check_tensors_on_meta_device() - self._resize_lora_embeddings() - del self.model.vllm_tensorized_marker - return self.model.eval() + _check_tensors_on_meta_device(model) + _resize_lora_embeddings(model) + del 
model.vllm_tensorized_marker def tensorizer_weights_iterator( diff --git a/vllm/model_executor/model_loader/tensorizer_loader.py b/vllm/model_executor/model_loader/tensorizer_loader.py index 2afe2b59e2f9a..1923e040af381 100644 --- a/vllm/model_executor/model_loader/tensorizer_loader.py +++ b/vllm/model_executor/model_loader/tensorizer_loader.py @@ -11,8 +11,8 @@ from vllm.config import LoadConfig, ModelConfig, ParallelConfig, VllmConfig from vllm.logger import init_logger from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.tensorizer import ( - TensorizerConfig, is_vllm_tensorized, load_with_tensorizer, - serialize_vllm_model, tensorizer_weights_iterator) + TensorizerConfig, deserialize_tensorizer_model, init_tensorizer_model, + is_vllm_tensorized, serialize_vllm_model, tensorizer_weights_iterator) from vllm.model_executor.model_loader.utils import (get_model_architecture, initialize_model, set_default_torch_dtype) @@ -61,38 +61,34 @@ class TensorizerLoader(BaseModelLoader): model.load_weights(self._get_weights_iterator()) return model.eval() - def _load_model_serialized( - self, - vllm_config: VllmConfig, - ) -> nn.Module: - """Load a serialized model with tensorizer. - - Expects a vLLM-tensorized model. 
See the - examples/others/tensorize_vllm_model.py example script - for serializing vLLM models.""" - - device_config = vllm_config.device_config - model_config = vllm_config.model_config - - with set_default_torch_dtype(model_config.dtype): - with torch.device(device_config.device): - model_class = get_model_architecture(model_config)[0] - - tensorizer_config = copy.copy(self.tensorizer_config) - tensorizer_config.model_class = model_class - tensorizer_config.hf_config = model_config.hf_config - tensorizer_config.dtype = model_config.dtype - - model = load_with_tensorizer(tensorizer_config, - vllm_config=vllm_config) - return model.eval() - def download_model(self, model_config: ModelConfig) -> None: self.tensorizer_config.verify_with_model_config(model_config) with self.tensorizer_config.open_stream(): pass + def _patch_tensorizer_config( + self, model_config: ModelConfig) -> TensorizerConfig: + model_class = get_model_architecture(model_config)[0] + tensorizer_config = copy.copy(self.tensorizer_config) + tensorizer_config.model_class = model_class + tensorizer_config.hf_config = model_config.hf_config + tensorizer_config.dtype = model_config.dtype + return tensorizer_config + + def load_weights(self, model: nn.Module, + model_config: ModelConfig) -> None: + """Load serialized model weights with tensorizer. + + Expects a vLLM-tensorized model. 
See the + examples/others/tensorize_vllm_model.py example script + for serializing vLLM models.""" + if is_vllm_tensorized(self.tensorizer_config): + tensorizer_config = self._patch_tensorizer_config(model_config) + deserialize_tensorizer_model(model, tensorizer_config) + else: + model.load_weights(self._get_weights_iterator()) + def load_model(self, vllm_config: VllmConfig, model_config: ModelConfig) -> nn.Module: parallel_config = vllm_config.parallel_config @@ -106,7 +102,11 @@ class TensorizerLoader(BaseModelLoader): get_tensor_model_parallel_rank()) if is_vllm_tensorized(self.tensorizer_config): - return self._load_model_serialized(vllm_config=vllm_config) + tensorizer_config = self._patch_tensorizer_config(model_config) + model = init_tensorizer_model(tensorizer_config=tensorizer_config, + vllm_config=vllm_config) + self.load_weights(model, model_config) + return model return self._load_model_serialized_cpu(vllm_config=vllm_config) @staticmethod diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index f61956f4e8e01..7a9a68be8805e 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -696,7 +696,7 @@ def initialize_dummy_weights( # Note: We avoid using torch.rank_like as it doesn't currently # support the generator argument. 
param.copy_((high - low) * - torch.rand(*param.shape, + torch.rand(param.shape, generator=generator, dtype=param.dtype, layout=param.layout, diff --git a/vllm/model_executor/models/aimv2.py b/vllm/model_executor/models/aimv2.py index aefd6c9737552..2e2a18abd03dd 100644 --- a/vllm/model_executor/models/aimv2.py +++ b/vllm/model_executor/models/aimv2.py @@ -2,16 +2,23 @@ # A modified implementation of the AIMv2 Transformer # inserted here also the image tokenizer used by Ovis2 +from collections.abc import Iterable from typing import Optional import torch import torch.nn as nn -from torch.nn import functional as F +from vllm.attention.layer import MultiHeadAttention +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed.utils import divide +from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.transformers_utils.configs.ovis import AIMv2Config @@ -24,29 +31,27 @@ class AIMv2SwiGLUFFN(nn.Module): in_features = config.hidden_size bias = config.use_bias - # TODO(Isotr0py): investigate if we can add TP to visual tokenizer - self.fc1 = ReplicatedLinear(in_features, - hidden_features, - bias=bias, - quant_config=quant_config, - prefix=f"{prefix}.fc1") - self.fc2 = ReplicatedLinear(hidden_features, - in_features, - bias=bias, - quant_config=quant_config, - prefix=f"{prefix}.fc2") - self.fc3 = ReplicatedLinear(in_features, - hidden_features, - bias=bias, - quant_config=quant_config, - prefix=f"{prefix}.fc3") + self.fc13 = MergedColumnParallelLinear( + in_features, + [hidden_features] * 2, + bias=bias, + 
quant_config=quant_config, + prefix=f"{prefix}.fc13", + ) + self.fc2 = RowParallelLinear( + input_size=hidden_features, + output_size=in_features, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + ) + self.act_fn = SiluAndMul() def forward(self, x: torch.Tensor) -> torch.Tensor: - x_parallel, _ = self.fc1(x) - gate, _ = self.fc3(x) - x_parallel = F.silu(x_parallel) * gate - out, _ = self.fc2(x_parallel) - return out + x, _ = self.fc13(x) + x = self.act_fn(x) + x, _ = self.fc2(x) + return x class AIMv2PatchEmbed(nn.Module): @@ -90,39 +95,45 @@ class AIMv2Attention(nn.Module): def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig, prefix: str): super().__init__() - dim = config.hidden_size - - # TODO(Isotr0py): investigate if we can add TP to visual tokenizer + self.config = config + self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads - self.qkv = ReplicatedLinear(dim, dim * 3, bias=config.qkv_bias) - # self.qkv = QKVParallelLinear( - # hidden_size=dim, - # head_size=dim // config.num_attention_heads, - # total_num_heads=config.num_attention_heads, - # bias=config.qkv_bias, - # quant_config=quant_config, - # prefix=f"{prefix}.qkv") - self.proj = ReplicatedLinear(dim, dim, bias=config.use_bias) - # self.proj = RowParallelLinear(input_size=dim, - # output_size=dim, - # bias = config.use_bias, - # quant_config=quant_config, - # prefix=f"{prefix}.proj") + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + "embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads}).") + self.scale = self.head_dim**-0.5 - def forward( # todo might implement multiple attn implementations - self, - x: torch.Tensor, - mask: Optional[torch.Tensor] = None) -> torch.Tensor: - B, N, C = x.shape + self.qkv = QKVParallelLinear( + hidden_size=self.embed_dim, + head_size=self.head_dim, + 
total_num_heads=self.num_heads, + bias=config.qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv", + ) + + self.proj = RowParallelLinear( + input_size=self.embed_dim, + output_size=self.embed_dim, + bias=config.use_bias, + quant_config=quant_config, + prefix=f"{prefix}.proj", + ) + + self.tp_size = get_tensor_model_parallel_world_size() + self.num_heads_per_partition = divide(self.num_heads, self.tp_size) + + self.attn = MultiHeadAttention(self.num_heads_per_partition, + self.head_dim, self.scale) + + def forward(self, x: torch.Tensor) -> torch.Tensor: qkv, _ = self.qkv(x) + q, k, v = qkv.chunk(3, dim=-1) - qkv = qkv.reshape(B, N, 3, self.num_heads, - C // self.num_heads).permute(2, 0, 3, 1, 4) - - q, k, v = qkv.unbind(0) - - x = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) - x = x.transpose(1, 2).contiguous().reshape(B, N, C) + x = self.attn(q, k, v) x, _ = self.proj(x) return x @@ -141,37 +152,40 @@ class AIMv2Block(nn.Module): prefix=f"{prefix}.mlp") self.norm_2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - def forward(self, - x: torch.Tensor, - mask: Optional[torch.Tensor] = None) -> torch.Tensor: - x = x + self.attn(self.norm_1.forward_native(x), mask) + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + self.attn(self.norm_1.forward_native(x)) x = x + self.mlp(self.norm_2.forward_native(x)) return x class AIMv2Transformer(nn.Module): - def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig, - prefix: str): + def __init__( + self, + config: AIMv2Config, + quant_config: QuantizationConfig, + *, + require_post_norm: Optional[bool] = None, + prefix: str = "", + ): super().__init__() self.blocks = nn.ModuleList([ AIMv2Block(config, quant_config, prefix=f"{prefix}.blocks.{i}") for i in range(config.num_hidden_layers) ]) - self.post_trunk_norm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + if require_post_norm: + self.post_trunk_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + 
else: + self.post_trunk_norm = None - def forward( - self, - tokens: torch.Tensor, - mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + def forward(self, tokens: torch.Tensor) -> torch.Tensor: # they take the -1 as the ref embeddings, like a clip skip for block in self.blocks: - tokens = block(tokens, mask) - # NO NORM IN THE OG IMPLEMENTATION - # tokens = self.post_trunk_norm(tokens) + tokens = block(tokens) + if self.post_trunk_norm is not None: + tokens = self.post_trunk_norm(tokens) return tokens @@ -180,20 +194,52 @@ class AIMv2Model(torch.nn.Module): def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig, + *, + require_post_norm: Optional[bool] = None, prefix: str = ""): super().__init__() self.preprocessor = AIMv2ViTPreprocessor(config) self.trunk = AIMv2Transformer(config, quant_config=quant_config, + require_post_norm=require_post_norm, prefix=f"{prefix}.trunk") - def forward( - self, - pixel_values: torch.Tensor, - mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: x = self.preprocessor(pixel_values) - x = self.trunk(x, mask) + x = self.trunk(x) return x + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".fc13", ".fc1", 0), + (".fc13", ".fc3", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + # post_layernorm is optional in SiglipVisionModel + if (name.startswith("trunk.post_trunk_norm") + and self.trunk.post_trunk_norm is None): + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = 
getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index 0c6593bbe3a10..0b1d0f1034083 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -16,7 +16,7 @@ from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear) -from vllm.model_executor.layers.pooler import (CrossEncodingPooler, Pooler, +from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler, PoolingType) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -470,8 +470,8 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, embedding_class=BertEmbedding, add_pooling_layer=True) self.classifier = nn.Linear(config.hidden_size, config.num_labels) - self._pooler = CrossEncodingPooler(config, self.classifier, - self.bert.pooler) + self._pooler = ClassifierPooler(vllm_config.model_config, + self.classifier, self.bert.pooler) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index e8f3ae2156e02..9fd528fd79779 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -106,7 +106,6 @@ class CLIPAttention(nn.Module): f"(got `embed_dim`: {self.embed_dim} and `num_heads`:" f" {self.num_heads}).") self.scale = self.head_dim**-0.5 - self.dropout = config.attention_dropout self.qkv_proj = QKVParallelLinear( hidden_size=self.embed_dim, @@ -129,10 +128,6 @@ class CLIPAttention(nn.Module): self.attn = MultiHeadAttention(self.num_heads_per_partition, self.head_dim, self.scale) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return 
tensor.view(bsz, seq_len, self.num_heads, - self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py index d9d9002bd5baa..538e9de4f78fc 100644 --- a/vllm/model_executor/models/intern_vit.py +++ b/vllm/model_executor/models/intern_vit.py @@ -415,6 +415,10 @@ class InternVisionEncoder(nn.Module): class InternVisionModel(nn.Module): + packed_modules_mapping = { + "qkv": ["qkv"], + } + def __init__( self, config: PretrainedConfig, diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 4612fc4387412..c37d3afb4e440 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -22,6 +22,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.models.intern_vit import (InternVisionModel, InternVisionPatchModel) +from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import convert_image_mode @@ -36,7 +37,8 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.tokenizer import AnyTokenizer -from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP +from .interfaces import (MultiModalEmbeddings, SupportsLoRA, + SupportsMultiModal, SupportsPP) from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) @@ -1014,7 +1016,8 @@ class InternVLMultiModalProcessor( InternVLMultiModalProcessor, info=InternVLProcessingInfo, dummy_inputs=InternVLDummyInputsBuilder) -class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): +class 
InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, + SupportsLoRA): def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() @@ -1403,3 +1406,12 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): ] loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes) return loader.load_weights(weights) + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="language_model", + connector="mlp1", + tower_model="vision_model") diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py index f211bfe54a7d7..1e40017fc792a 100644 --- a/vllm/model_executor/models/llama_eagle3.py +++ b/vllm/model_executor/models/llama_eagle3.py @@ -215,6 +215,9 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM): logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) if self.draft_id_to_target_id is None: + assert logits.shape[1] == self.config.vocab_size, \ + "Expected logits to have shape " \ + f"(*, {self.config.vocab_size}), but got {logits.shape}" return logits base = torch.arange(self.config.draft_vocab_size, device=logits.device) @@ -234,24 +237,22 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM): return self.model.fc(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): - loader = AutoWeightsLoader( - self, - skip_prefixes=None, - ) - model_weights = {} + includes_draft_id_mapping = False for name, loaded_weight in weights: if "t2d" in name: continue if "d2t" in name: name = name.replace("d2t", "draft_id_to_target_id") + includes_draft_id_mapping = True elif "lm_head" not in name: name = "model." 
+ name model_weights[name] = loaded_weight - loaded_weights = loader.load_weights(model_weights.items()) - - if 'd2t' not in loaded_weights: - self.draft_id_to_target_id = None - - return loaded_weights + loader = AutoWeightsLoader( + self, + skip_prefixes=None, + skip_substrs=["draft_id_to_target_id"] \ + if not includes_draft_id_mapping else None, + ) + loader.load_weights(model_weights.items()) diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py index 858a1633befa0..65c6467bcf5fb 100644 --- a/vllm/model_executor/models/mamba2.py +++ b/vllm/model_executor/models/mamba2.py @@ -32,7 +32,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from vllm.utils import LayerBlockType -from .utils import (is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -167,6 +167,27 @@ class Mamba2Model(nn.Module): return hidden_states + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "A_log" in name: + name = name.replace("A_log", "A") + + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsV0Only): @@ -282,21 +303,5 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree, def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - if "A_log" in name: - name = name.replace("A_log", "A") - - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 0397b552ce9f9..f471a86ffba34 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -242,6 +242,7 @@ class MiniCPMAttention(nn.Module): base=rope_theta, rope_scaling=rope_scaling, ) + self.attn = Attention(self.num_heads, self.head_dim, self.scaling, @@ -444,6 +445,7 @@ class MiniCPMModel(nn.Module): for weight_name in ["w1", "w2", "w3"] ] params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: @@ -567,7 +569,7 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = 
self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + inputs_embeds) / self.scale_width return hidden_states def compute_logits( @@ -575,7 +577,6 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - hidden_states = hidden_states / self.scale_width logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/minicpm_eagle.py b/vllm/model_executor/models/minicpm_eagle.py new file mode 100644 index 0000000000000..039c3d22d1604 --- /dev/null +++ b/vllm/model_executor/models/minicpm_eagle.py @@ -0,0 +1,390 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only EagleMiniCPM model compatible with HuggingFace weights.""" +import math +from collections.abc import Iterable +from typing import Optional, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .minicpm import MiniCPMAttention as EagleMiniCPMAttention +from .minicpm import MiniCPMMLP as EagleMiniCPMMLP +from .minicpm import MiniCPMMoE as EagleMiniCPMMoE +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, maybe_prefix) + + +class EagleMiniCPMDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.cache_config = cache_config + self.quant_config = quant_config + self.hidden_size = config.hidden_size + self.rope_theta = getattr(config, "rope_theta", 10000) + self.rope_scaling = getattr(config, "rope_scaling", None) + self.max_position_embeddings = getattr(config, + "max_position_embeddings", 8192) + self.prefix = prefix + self._init_attn_block() + self._init_ffn_block() + + def _init_attn_block(self): + self.input_layernorm = RMSNorm(self.config.hidden_size, + 
eps=self.config.rms_norm_eps) + self.self_attn = EagleMiniCPMAttention( + hidden_size=self.hidden_size, + num_heads=self.config.num_attention_heads, + num_kv_heads=self.config.num_key_value_heads, + rope_theta=self.rope_theta, + rope_scaling=self.rope_scaling, + max_position_embeddings=self.max_position_embeddings, + cache_config=self.cache_config, + quant_config=self.quant_config, + prefix=f"{self.prefix}.self_attn", + ) + + def _init_ffn_block(self): + self.post_attention_layernorm = RMSNorm(self.config.hidden_size, + eps=self.config.rms_norm_eps) + self.num_experts = getattr(self.config, "num_experts", 0) + if self.num_experts == 0: + self.mlp = EagleMiniCPMMLP( + hidden_size=self.hidden_size, + intermediate_size=self.config.intermediate_size, + hidden_act=self.config.hidden_act, + hidden_act_param=getattr(self.config, "hidden_act_param", 0.), + quant_config=self.quant_config, + ) + else: + self.mlp = EagleMiniCPMMoE( + num_experts=self.config.num_experts, + top_k=self.config.num_experts_per_tok, + hidden_size=self.config.hidden_size, + intermediate_size=self.config.intermediate_size) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + hidden_states = residual + hidden_states * \ + (self.config.scale_depth / math.sqrt(self.config.mup_denominator)) + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states * \ + (self.config.scale_depth / math.sqrt(self.config.mup_denominator)) + + return hidden_states, None + + +@support_torch_compile +class EagleMiniCPMModel(nn.Module): + + def __init__(self, + *, + vllm_config: VllmConfig, + 
prefix: str = "", + start_layer: int = 0): + super().__init__() + + config = vllm_config.speculative_config.draft_model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.cache_config = cache_config + self.quant_config = quant_config + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + self.fc = torch.nn.Linear(self.config.hidden_size * 2, + self.config.hidden_size, + bias=False) + self.input_norm1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.input_norm2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + self.num_experts = getattr(self.config, "num_experts", 0) + self._init_layers(prefix, config, cache_config, quant_config, + start_layer) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], self.config.hidden_size)) + + def _init_layers( + self, + prefix: str, + config: PretrainedConfig, + cache_config: Optional[CacheConfig], + quant_config: Optional[QuantizationConfig], + start_layer: int, + ): + self.eagle_layers = nn.ModuleList([ + EagleMiniCPMDecoderLayer( + config, + cache_config, + quant_config, + f"{prefix}.eagle_layers.{i + start_layer}", + ) for i in range(self.config.num_hidden_layers) + ]) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + embedding = self.embed_tokens(input_ids) + return embedding * self.config.scale_emb + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> Union[torch.Tensor, IntermediateTensors]: + 
input_embeds = self.get_input_embeddings(input_ids) + input_embeds = self.input_norm1(input_embeds) + hidden_states = self.input_norm2(hidden_states) + + hidden_states = self.fc( + torch.cat((input_embeds, hidden_states), dim=-1)) + residual = None + for layer in self.eagle_layers: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + + return hidden_states, hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + expert_params_mapping = [ + # (param_name, weight_name, expert_id) + ("ws" if weight_name in ["w1", "w3"] else "w2s", + f"experts.{expert_id}.{weight_name}.weight", expert_id) + for expert_id in range(self.num_experts) + for weight_name in ["w1", "w2", "w3"] + ] + params_dict = dict(self.named_parameters()) + + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for param_name, weight_name, expert_id in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + weight_name, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class EagleMiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.speculative_config.draft_model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.prefix = prefix + self.vllm_config = vllm_config + self.config = config + self.lora_config = lora_config + self.cache_config = cache_config + self.quant_config = quant_config + + target_layer_num = vllm_config.model_config.get_num_layers( + vllm_config.parallel_config) + + self.model = self._init_model(vllm_config=vllm_config, + 
prefix=maybe_prefix(prefix, "model"), + start_layer=target_layer_num) + + unpadded_vocab_size = config.vocab_size + if lora_config: + unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, + ) + if config.tie_word_embeddings: + self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) + self.scale_width = self.config.hidden_size / self.config.dim_model_base + + self.logits_processor = LogitsProcessor(unpadded_vocab_size, + config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def _init_model(self, + *, + vllm_config: VllmConfig, + prefix: str = "", + start_layer: int = 0): + return EagleMiniCPMModel(vllm_config=vllm_config, + prefix=prefix, + start_layer=start_layer) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + hidden_states, hidden_states2 = self.model(input_ids, positions, + hidden_states) + hidden_states = hidden_states / self.scale_width + hidden_states2 = hidden_states2 / self.scale_width + return hidden_states, hidden_states2 + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings 
else None), + ) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index 36bab9ee13b17..ac0fe7b10c836 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -141,7 +141,7 @@ class MiniMaxText01RotaryEmbedding(CustomOp): head_size: int, rotary_dim: int, max_position: int, - base: int, + base: float, is_neox_style: bool, cache_dtype: torch.dtype, ) -> None: @@ -155,10 +155,7 @@ class MiniMaxText01RotaryEmbedding(CustomOp): cache = self._compute_cos_sin_cache().to(cache_dtype) self.register_buffer("cos_sin_cache", cache, persistent=False) - def _compute_inv_freq( - self, - base: Union[int, float], - ) -> torch.Tensor: + def _compute_inv_freq(self, base: float) -> torch.Tensor: """Compute the inverse frequency.""" inv_freq = 1.0 / (base**(torch.arange( 0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)) diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index 8c98492c0bedd..58549b10e9666 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -34,6 +34,7 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.inputs import InputProcessingContext from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, + ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope @@ -49,6 +50,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.multimodal.utils import run_dp_sharded_vision_model from vllm.sequence import IntermediateTensors from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP @@ 
-84,23 +86,29 @@ class Llama4ImagePatchInputs(TypedDict): class Llama4VisionMLP(nn.Module): - def __init__(self, - input_size: int, - intermediate_size: int, - output_size: int, - bias: bool, - output_activation: bool, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + input_size: int, + intermediate_size: int, + output_size: int, + bias: bool, + output_activation: bool, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ): super().__init__() - self.fc1 = ColumnParallelLinear( + cls_fc1 = (ReplicatedLinear + if use_data_parallel else ColumnParallelLinear) + self.fc1 = cls_fc1( input_size=input_size, output_size=intermediate_size, bias=bias, quant_config=quant_config, prefix=f"{prefix}.fc1", ) - self.fc2 = RowParallelLinear( + cls_fc2 = ReplicatedLinear if use_data_parallel else RowParallelLinear + self.fc2 = cls_fc2( input_size=intermediate_size, output_size=output_size, bias=bias, @@ -155,10 +163,12 @@ def pixel_shuffle(input_tensor, shuffle_ratio): int(channels / shuffle_ratio)) reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous() - reshaped_tensor = reshaped_tensor.view(batch_size, - int(height * shuffle_ratio), - int(width * shuffle_ratio), - int(channels / (shuffle_ratio**2))) + reshaped_tensor = reshaped_tensor.view( + batch_size, + int(height * shuffle_ratio), + int(width * shuffle_ratio), + int(channels / (shuffle_ratio**2)), + ) reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous() output_tensor = reshaped_tensor.view(batch_size, -1, @@ -173,6 +183,7 @@ class Llama4VisionPixelShuffleMLP(nn.Module): config, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ): super().__init__() self.pixel_shuffle_ratio = config.pixel_shuffle_ratio @@ -186,7 +197,9 @@ class Llama4VisionPixelShuffleMLP(nn.Module): bias=config.multi_modal_projector_bias, output_activation=True, 
quant_config=quant_config, - prefix=f"{prefix}.mlp") + prefix=f"{prefix}.mlp", + use_data_parallel=use_data_parallel, + ) def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor: encoded_patches = pixel_shuffle(encoded_patches, @@ -201,10 +214,12 @@ class Llama4VisionAttention(nn.Module): config: Llama4VisionConfig, quant_config: Optional[QuantizationConfig], prefix: str = "", + use_data_parallel: bool = False, ): super().__init__() self.config = config - self.tp_size = get_tensor_model_parallel_world_size() + self.tp_size = (1 if use_data_parallel else + get_tensor_model_parallel_world_size()) self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = config.hidden_size // self.num_heads @@ -217,22 +232,39 @@ class Llama4VisionAttention(nn.Module): self.attn = MultiHeadAttention(self.num_local_heads, self.head_dim, self.scaling) - self.qkv_proj = QKVParallelLinear( - self.embed_dim, - self.head_dim, - self.num_heads, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.qkv_proj", - ) - self.o_proj = RowParallelLinear( - self.num_heads * self.head_dim, - self.embed_dim, - bias=True, - input_is_parallel=True, - quant_config=quant_config, - prefix=f"{prefix}.o_proj", - ) + + if use_data_parallel: + self.qkv_proj = ReplicatedLinear( + self.embed_dim, + self.q_size + 2 * self.kv_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = ReplicatedLinear( + self.num_heads * self.head_dim, + self.embed_dim, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + else: + self.qkv_proj = QKVParallelLinear( + self.embed_dim, + self.head_dim, + self.num_heads, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.num_heads * self.head_dim, + self.embed_dim, + bias=True, + input_is_parallel=True, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) self.rotary_emb = get_rope( 
head_size=self.head_dim, @@ -275,22 +307,29 @@ class Llama4VisionEncoderLayer(nn.Module): config: Llama4VisionConfig, quant_config: Optional[QuantizationConfig], prefix: str = "", + use_data_parallel: bool = False, ): super().__init__() self.hidden_size = config.hidden_size self.num_attention_heads = config.num_attention_heads self.intermediate_size = config.intermediate_size - self.self_attn = Llama4VisionAttention(config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn") - self.mlp = Llama4VisionMLP(input_size=config.hidden_size, - intermediate_size=config.intermediate_size, - output_size=config.hidden_size, - bias=True, - output_activation=False, - quant_config=quant_config, - prefix=f"{prefix}.mlp") + self.self_attn = Llama4VisionAttention( + config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + use_data_parallel=use_data_parallel, + ) + self.mlp = Llama4VisionMLP( + input_size=config.hidden_size, + intermediate_size=config.intermediate_size, + output_size=config.hidden_size, + bias=True, + output_activation=False, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + use_data_parallel=use_data_parallel, + ) self.input_layernorm = nn.LayerNorm(config.hidden_size) self.post_attention_layernorm = nn.LayerNorm(config.hidden_size) @@ -322,6 +361,7 @@ class Llama4VisionEncoder(nn.Module): config: Llama4VisionConfig, quant_config: Optional[QuantizationConfig], prefix: str = "", + use_data_parallel: bool = False, ): super().__init__() self.config = config @@ -330,6 +370,7 @@ class Llama4VisionEncoder(nn.Module): config, quant_config=quant_config, prefix=f"{prefix}.layers.{layer_idx}", + use_data_parallel=use_data_parallel, ) for layer_idx in range(config.num_hidden_layers) ]) @@ -357,23 +398,33 @@ class Llama4VisionEncoder(nn.Module): class Llama4UnfoldConvolution(nn.Module): - def __init__(self, - config: Llama4VisionConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + config: 
Llama4VisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ): super().__init__() kernel_size = config.patch_size if isinstance(kernel_size, int): kernel_size = (kernel_size, kernel_size) self.unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=config.patch_size) - self.linear = ColumnParallelLinear(config.num_channels * - kernel_size[0] * kernel_size[1], - config.hidden_size, - bias=False, - quant_config=quant_config, - gather_output=True, - prefix=f"{prefix}.linear") + params = { + "input_size": + config.num_channels * kernel_size[0] * kernel_size[1], + "output_size": config.hidden_size, + "bias": False, + "quant_config": quant_config, + "prefix": f"{prefix}.linear", + } + if use_data_parallel: + cls = ReplicatedLinear + else: + cls = ColumnParallelLinear + params["gather_output"] = True + self.linear = cls(**params) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.unfold(hidden_states) @@ -389,6 +440,7 @@ class Llama4VisionModel(nn.Module): config: Llama4VisionConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ): super().__init__() self.config = config @@ -403,7 +455,9 @@ class Llama4VisionModel(nn.Module): self.patch_embedding = Llama4UnfoldConvolution( config, quant_config=quant_config, - prefix=f"{prefix}.patch_embedding") + prefix=f"{prefix}.patch_embedding", + use_data_parallel=use_data_parallel, + ) self.class_embedding = nn.Parameter(self.scale * torch.randn(self.hidden_size)) @@ -415,11 +469,18 @@ class Llama4VisionModel(nn.Module): self.layernorm_post = nn.LayerNorm(self.hidden_size, eps=1e-5) # encoders - self.model = Llama4VisionEncoder(config, - quant_config=quant_config, - prefix=f"{prefix}.model") + self.model = Llama4VisionEncoder( + config, + quant_config=quant_config, + prefix=f"{prefix}.model", + use_data_parallel=use_data_parallel, + ) self.vision_adapter = 
Llama4VisionPixelShuffleMLP( - config, quant_config, prefix=f"{prefix}.vision_adapter") + config, + quant_config, + prefix=f"{prefix}.vision_adapter", + use_data_parallel=use_data_parallel, + ) def forward( self, @@ -528,8 +589,9 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo] vision_config = self.info.get_hf_config().vision_config if processed_outputs.get("pixel_values") is not None: - assert "images" in mm_data, \ - "images expected to be in mm_data when pixel_values is present" + assert ( + "images" in mm_data + ), "images expected to be in mm_data when pixel_values is present" images = mm_data["images"] parsed_images = (self._get_data_parser().parse_mm_data({ @@ -546,8 +608,8 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo] get_best_fit( (image.size[1], image.size[0]), torch.tensor(possible_resolutions), - resize_to_max_canvas=image_processor.resize_to_max_canvas) - for image in parsed_images + resize_to_max_canvas=image_processor.resize_to_max_canvas, + ) for image in parsed_images ] # TODO tile height/width do not necessarily need to match aspect_ratios = [(image_size[0] // tile_size, @@ -659,13 +721,17 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config + self.use_data_parallel = (vllm_config.parallel_config. 
+ enable_multimodal_encoder_data_parallel) self.config = config self.quant_config = quant_config self.multimodal_config = multimodal_config - self.vision_model = Llama4VisionModel(config.vision_config, - None, - prefix=maybe_prefix( - prefix, "vision_model")) + self.vision_model = Llama4VisionModel( + config.vision_config, + None, + prefix=maybe_prefix(prefix, "vision_model"), + use_data_parallel=self.use_data_parallel, + ) self.multi_modal_projector = Llama4MultiModalProjector( self.config, None, @@ -709,7 +775,13 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, flat_data = image_input["flat_data"] patches_per_image = image_input["patches_per_image"].tolist() - vision_embeddings_flat = self.vision_model(flat_data) + # shard image input + if self.use_data_parallel: + vision_embeddings_flat = run_dp_sharded_vision_model( + flat_data, self.vision_model) + else: + vision_embeddings_flat = self.vision_model(flat_data) + vision_embeddings_flat = self.multi_modal_projector( vision_embeddings_flat) @@ -796,6 +868,30 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, return get_prefix_weights(), get_other_weights() + def _consolidate_qkv_weights( + self, weights: Iterable[tuple[str, torch.Tensor]] + ) -> Iterable[tuple[str, torch.Tensor]]: + qkv_idx_mappings = { + ".self_attn.q_proj": 0, + ".self_attn.k_proj": 1, + ".self_attn.v_proj": 2, + } + qkv_weights = {} + for name, loaded_weight in weights: + for weight_name, idx in qkv_idx_mappings.items(): + if weight_name not in name: + continue + new_name = name.replace(weight_name, ".self_attn.qkv_proj") + if new_name not in qkv_weights: + qkv_weights[new_name] = [None] * 3 + qkv_weights[new_name][idx] = loaded_weight + break + else: + yield name, loaded_weight + for key, weight in qkv_weights.items(): + qkv_weight = torch.cat(weight, dim=0) + yield key, qkv_weight + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: @@ -818,9 +914,12 @@ class 
Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, assert loaded_language_model_params is not None updated_params.update(loaded_language_model_params) + if self.use_data_parallel: + other_weights = self._consolidate_qkv_weights(other_weights) + for name, loaded_weight in other_weights: for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: + if weight_name not in name or self.use_data_parallel: continue name = name.replace(weight_name, param_name) param = params_dict[name] diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py index 86552aa05bf95..18eab6051736f 100644 --- a/vllm/model_executor/models/modernbert.py +++ b/vllm/model_executor/models/modernbert.py @@ -12,7 +12,7 @@ from vllm.config import VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.linear import (QKVParallelLinear, RowParallelLinear) -from vllm.model_executor.layers.pooler import CrossEncodingPooler +from vllm.model_executor.layers.pooler import ClassifierPooler from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -278,8 +278,9 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding): self.model = ModernBertModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "modernbert")) self.classifier = nn.Linear(config.hidden_size, config.num_labels) - self._pooler = CrossEncodingPooler(config, self.classifier, - ModernBertPooler(config)) + self._pooler = ClassifierPooler(vllm_config.model_config, + self.classifier, + ModernBertPooler(config)) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): diff --git a/vllm/model_executor/models/ovis.py b/vllm/model_executor/models/ovis.py index e03705d48f3e8..232a63c506890 100644 --- a/vllm/model_executor/models/ovis.py +++ b/vllm/model_executor/models/ovis.py @@ 
-30,6 +30,9 @@ from vllm.config import VllmConfig from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.layers.quantization.gptq import GPTQConfig +from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQMarlinConfig) from vllm.model_executor.models.aimv2 import AIMv2Model from vllm.model_executor.models.siglip import SiglipVisionModel from vllm.model_executor.models.utils import (AutoWeightsLoader, flatten_bn, @@ -48,7 +51,7 @@ from vllm.transformers_utils.configs.ovis import (BaseVisualTokenizerConfig, OvisConfig) from vllm.transformers_utils.processors.ovis import OvisProcessor -from .interfaces import MultiModalEmbeddings, SupportsMultiModal +from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .utils import merge_multimodal_embeddings # Cannot find the following number from hf config. @@ -106,12 +109,14 @@ class VisualTokenizer(torch.nn.Module): config: BaseVisualTokenizerConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", - ): + ) -> nn.Module: model_type = config.backbone_config.model_type if model_type == "aimv2": + # No post rms_norm in Ovis2's AIMv2 ViT. 
return AIMv2Model( config=config.backbone_config, quant_config=quant_config, + require_post_norm=False, prefix=prefix, ) elif model_type == "siglip_vision_model": @@ -124,14 +129,14 @@ class VisualTokenizer(torch.nn.Module): f"Unsupported visual tokenizer model_type: {model_type}") @property - def dtype(self): + def dtype(self) -> torch.dtype: return next(self.head.parameters()).dtype @property - def device(self): + def device(self) -> torch.device: return next(self.head.parameters()).device - def tokenize(self, logits): + def tokenize(self, logits: torch.Tensor) -> torch.Tensor: if self.config.tokenize_function == 'softmax': tokens = softmax(logits, dim=-1) elif self.config.tokenize_function == 'gumbel_argmax': @@ -144,7 +149,7 @@ class VisualTokenizer(torch.nn.Module): f'or st_argmax, but got {self.config.tokenize_function}') return tokens - def encode(self, pixel_values): + def encode(self, pixel_values: torch.Tensor) -> torch.Tensor: features = self.backbone(pixel_values) if self.config.drop_cls_token: features = features[:, 1:, :] @@ -395,7 +400,7 @@ class OvisMultiModalProcessor(BaseMultiModalProcessor[OvisProcessingInfo]): @MULTIMODAL_REGISTRY.register_processor(OvisMultiModalProcessor, info=OvisProcessingInfo, dummy_inputs=OvisDummyInputsBuilder) -class Ovis(nn.Module, SupportsMultiModal): +class Ovis(nn.Module, SupportsMultiModal, SupportsPP): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -410,7 +415,7 @@ class Ovis(nn.Module, SupportsMultiModal): self.visual_tokenizer = VisualTokenizer( config=config.visual_tokenizer_config, - quant_config=quant_config, + quant_config=self._maybe_ignore_quant_config(quant_config), prefix=f"{prefix}.visual_tokenizer", ) @@ -421,9 +426,16 @@ class Ovis(nn.Module, SupportsMultiModal): text_model_type = self.config.get_text_config().model_type self.image_pad_token_id = IMAGE_PAD_TOKEN_ID_MAP[text_model_type] - # TODO(Isotr0py): PP support - # self.make_empty_intermediate_tensors = ( - 
# self.language_model.make_empty_intermediate_tensors) + self.make_empty_intermediate_tensors = ( + self.get_language_model().make_empty_intermediate_tensors) + + def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): + # GPTQ configs do not have a list of ignored modules, however AutoGPTQ + # seems to avoid vision encoder sections for some models. + # See: https://huggingface.co/AIDC-AI/Ovis2-2B-GPTQ-Int4 + if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)): + return None + return quant_config def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[OvisImagePatchInputs]: diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 0d0d98c59dbc7..a664864ff898f 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -34,32 +34,27 @@ from vllm.attention import Attention, AttentionType from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size -from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata -from 
vllm.sequence import IntermediateTensors, PoolerOutput +from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, - extract_layer_index, is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, + is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) -logger = init_logger(__name__) - class Qwen2MLP(nn.Module): @@ -499,69 +494,3 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) - - -class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP): - packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], - } - - hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config - pooler_config = vllm_config.model_config.pooler_config - - self.config = config - self.lora_config = lora_config - - self.quant_config = quant_config - self.model = Qwen2Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - - # TODO: Replace this model class with as_embedding_model( - # Qwen2ForCausalLM) after changing the default pooling method - if pooler_config.pooling_type is None: - logger.warning( - "This embedding model will default to last-token pooling in " - "an upcoming version. 
To avoid breaking changes, you should " - "pass `--override-pooler-config '{\"pooling_type\": \"MEAN\"}'`" - " explicitly.") - - self._pooler = Pooler.from_config_with_defaults( - pooler_config, - pooling_type=PoolingType.MEAN, - normalize=True, - softmax=False) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: - return self.model(input_ids, positions, intermediate_tensors) - - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) - - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): - weights = self.hf_to_vllm_mapper.apply(weights) - weights = ((name, data) for name, data in weights - if not name.startswith("lm_head.")) - self.model.load_weights(weights) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 68dd07820189e..e3fa9f67ca078 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -821,17 +821,6 @@ class Qwen2_5_VLMultiModalProcessor(Qwen2VLMultiModalProcessor): dummy_inputs=Qwen2_5_VLDummyInputsBuilder) class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP): - packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], - } # To ensure correct weight loading and mapping. 
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={ diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 0ff0836b08975..873baa56faf37 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -1069,17 +1069,6 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo] dummy_inputs=Qwen2VLDummyInputsBuilder) class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP): - packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], - } # To ensure correct weight loading and mapping. hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={ diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 97ea12de65373..8efd4825beea9 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -142,7 +142,7 @@ _EMBEDDING_MODELS = { "ModernBertModel": ("modernbert", "ModernBertModel"), "NomicBertModel": ("bert_with_rope", "NomicBertModel"), "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), - "Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"), + "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"), "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"), @@ -223,6 +223,7 @@ _SPECULATIVE_DECODING_MODELS = { "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"), "EAGLEModel": ("eagle", "EAGLE"), "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"), + "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"), "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"), "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"), "MedusaModel": ("medusa", "Medusa"), diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index 
9a4d0ab2dd4d7..76008b72941da 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -9,7 +9,7 @@ from torch import nn from transformers import RobertaConfig from vllm.config import VllmConfig -from vllm.model_executor.layers.pooler import CrossEncodingPooler +from vllm.model_executor.layers.pooler import ClassifierPooler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -186,7 +186,9 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding, embedding_class=RobertaEmbedding, add_pooling_layer=False) self.classifier = RobertaClassificationHead(config) - self._pooler = CrossEncodingPooler(config, self.classifier) + + self._pooler = ClassifierPooler(vllm_config.model_config, + self.classifier) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): bert_weights, task_weights = roberta_task_weights_filter(weights) diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index 3b5334afa7af8..4803da2956ef1 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -130,11 +130,10 @@ class SiglipVisionEmbeddings(nn.Module): embeddings = patch_embeds.flatten(2).transpose(1, 2) if interpolate_pos_encoding: - embeddings = embeddings + self.interpolate_pos_encoding( + embeddings += self.interpolate_pos_encoding( embeddings, height, width) else: - embeddings = embeddings + self.position_embedding( - self.position_ids) + embeddings += self.position_embedding(self.position_ids) return embeddings @@ -271,12 +270,12 @@ class SiglipEncoderLayer(nn.Module): hidden_states = self.layer_norm1(hidden_states) hidden_states, _ = self.self_attn(hidden_states=hidden_states) - hidden_states = residual + hidden_states + hidden_states += residual residual = hidden_states hidden_states = self.layer_norm2(hidden_states) hidden_states = 
self.mlp(hidden_states) - hidden_states = residual + hidden_states + hidden_states += residual return hidden_states, None @@ -354,7 +353,8 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module): residual = hidden_state hidden_state = self.layernorm(hidden_state) - hidden_state = residual + self.mlp(hidden_state) + hidden_state = self.mlp(hidden_state) + hidden_state += residual return hidden_state[:, 0] diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py index f9d89e64bd9db..1b120c3545a56 100644 --- a/vllm/model_executor/utils.py +++ b/vllm/model_executor/utils.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 """Utils for model executor.""" +import copy from typing import Any, Optional import torch @@ -51,3 +52,23 @@ def _make_synced_weight_loader(original_weight_loader): torch._sync(param) return _synced_weight_loader + + +def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]: + parent_map = copy.deepcopy(getattr(model, "packed_modules_mapping", {})) + + # don't infer mapping if the model has defined it explicitly. 
+ if parent_map: + return parent_map + + # We only check main components instead of whole model submodules + for child in model.children(): + child_map = getattr(child, "packed_modules_mapping", {}) + if any((k in parent_map and parent_map[k] != v) + for k, v in child_map.items()): + raise ValueError( + f"Can't update {type(model).__name__}'s packed_modules_mapping " + f"safely because of conflicts from {type(child).__name__}.") + else: + parent_map.update(child_map) + return parent_map \ No newline at end of file diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index 9ddba67bff702..1d838f66f1dec 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -12,6 +12,9 @@ from PIL import Image import vllm.envs as envs from vllm.connections import HTTPConnection, global_http_connection +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather) from .audio import AudioMediaIO from .base import MediaIO @@ -390,3 +393,35 @@ def group_mm_inputs_by_modality( return [ list(group) for _, group in groupby(mm_inputs, key=modality_group_func) ] + + +def run_dp_sharded_vision_model(image_input: torch.Tensor, + vision_model: torch.nn.Module) -> torch.Tensor: + """Run a vision model with data parallelism (DP) sharding. The function + will shard the input image tensor on the first dimension and run the vision + model + + Args: + image_input (torch.Tensor): Image input tensor. + vision_model (torch.nn.Module): Vision model. 
+ + Returns: + torch.Tensor: Output image embeddings + """ + + num_chunks = image_input.shape[0] + mp_world_size = get_tensor_model_parallel_world_size() + num_chunks_per_rank = (num_chunks + mp_world_size - 1) // mp_world_size + num_padded_chunks = num_chunks_per_rank * mp_world_size - num_chunks + pad = (0, ) * (2 * (image_input.dim() - 1)) + (0, num_padded_chunks) + image_input_padded = torch.nn.functional.pad(image_input, pad) + rank = get_tensor_model_parallel_rank() + image_input_per_rank = image_input_padded[rank * + num_chunks_per_rank:(rank + 1) * + num_chunks_per_rank, ...] + + vision_embeddings = vision_model(image_input_per_rank) + vision_embeddings = tensor_model_parallel_all_gather(vision_embeddings, + dim=0) + vision_embeddings = vision_embeddings[:num_chunks, ...] + return vision_embeddings diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index c79c603c02ebc..eaffaac78cce9 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -28,7 +28,7 @@ class CpuPlatform(Platform): dispatch_key: str = "CPU" @property - def supported_dtypes(self) -> list: + def supported_dtypes(self) -> list[torch.dtype]: if self.get_cpu_architecture() == CpuArchEnum.POWERPC: return [torch.bfloat16, torch.float32] elif sys.platform.startswith( diff --git a/vllm/platforms/neuron.py b/vllm/platforms/neuron.py index 474c70d04140b..56f204e71da17 100644 --- a/vllm/platforms/neuron.py +++ b/vllm/platforms/neuron.py @@ -49,9 +49,6 @@ class NeuronPlatform(Platform): if parallel_config.world_size > 1: parallel_config.distributed_executor_backend = "uni" - assert (vllm_config.lora_config - is None), "LoRA is not supported for Neuron backend." 
- if vllm_config.cache_config and vllm_config.model_config: # neuron needs block_size = max_model_len vllm_config.cache_config.block_size = \ diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index d544b4ab4b020..ef1c632a53989 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -99,9 +99,21 @@ def with_amdsmi_context(fn): @cache -def on_mi250_mi300() -> bool: +def on_gfx1x() -> bool: GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName - return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942"]) + return any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"]) + + +@cache +def on_mi3xx() -> bool: + GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName + return any(arch in GPU_ARCH for arch in ["gfx942", "gfx950"]) + + +@cache +def on_gfx9() -> bool: + GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName + return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"]) @cache diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 69e7207cc3500..8774f95a2f60b 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -4,12 +4,12 @@ import enum import json import os import time -from functools import cache +from functools import cache, partial from pathlib import Path -from typing import Any, Callable, Literal, Optional, Union +from typing import Any, Callable, Literal, Optional, TypeVar, Union import huggingface_hub -from huggingface_hub import hf_hub_download +from huggingface_hub import get_safetensors_metadata, hf_hub_download from huggingface_hub import list_repo_files as hf_list_repo_files from huggingface_hub import try_to_load_from_cache from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError, @@ -93,10 +93,15 @@ class ConfigFormat(str, enum.Enum): MISTRAL = "mistral" -def with_retry(func: Callable[[], Any], - log_msg: str, - max_retries: int = 2, - retry_delay: int = 2): +_R = TypeVar("_R") + + +def with_retry( + 
func: Callable[[], _R], + log_msg: str, + max_retries: int = 2, + retry_delay: int = 2, +) -> _R: for attempt in range(max_retries): try: return func() @@ -109,6 +114,8 @@ def with_retry(func: Callable[[], Any], time.sleep(retry_delay) retry_delay *= 2 + raise AssertionError("Should not be reached") + # @cache doesn't cache exceptions @cache @@ -823,13 +830,39 @@ def try_get_generation_config( def get_cross_encoder_activation_function(config: PretrainedConfig): - if (hasattr(config, "sbert_ce_default_activation_function") - and config.sbert_ce_default_activation_function is not None): + function_name: Optional[str] = None + if hasattr(config, "sentence_transformers") and "activation_fn" in \ + config.sentence_transformers: + function_name = config.sentence_transformers["activation_fn"] + + elif (hasattr(config, "sbert_ce_default_activation_function") + and config.sbert_ce_default_activation_function is not None): function_name = config.sbert_ce_default_activation_function + + if function_name is not None: assert function_name.startswith("torch.nn.modules."), \ "Loading of activation functions is restricted to " \ "torch.nn.modules for security reasons" return resolve_obj_by_qualname(function_name)() else: return nn.Sigmoid() if config.num_labels == 1 else nn.Identity() + + +def try_get_safetensors_metadata( + model: str, + *, + revision: Optional[str] = None, +): + get_safetensors_metadata_partial = partial( + get_safetensors_metadata, + model, + revision=revision, + token=os.getenv('HF_TOKEN', None), + ) + + try: + return with_retry(get_safetensors_metadata_partial, + "Error retrieving safetensors") + except Exception: + return None diff --git a/vllm/utils.py b/vllm/utils.py index c1213d463c212..c879b38d065aa 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -37,8 +37,8 @@ from argparse import (Action, ArgumentDefaultsHelpFormatter, ArgumentParser, _ArgumentGroup) from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task from collections import UserDict, 
defaultdict -from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable, - Iterable, Iterator, KeysView, Mapping) +from collections.abc import (AsyncGenerator, Awaitable, Collection, Generator, + Hashable, Iterable, Iterator, KeysView, Mapping) from concurrent.futures.process import ProcessPoolExecutor from dataclasses import dataclass, field from functools import cache, lru_cache, partial, wraps @@ -107,7 +107,7 @@ STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP = ( "currently not supported for encoder/decoder " "models.") -STR_NOT_IMPL_ENC_DEC_LORA = ("LoRA is currently not currently " +STR_NOT_IMPL_ENC_DEC_LORA = ("LoRA is not currently " "supported with encoder/decoder " "models.") @@ -979,6 +979,53 @@ def get_dtype_size(dtype: torch.dtype) -> int: return torch.tensor([], dtype=dtype).element_size() +# bool = 0, int = 1, float = 2, complex = 3 +def _get_precision_level(dtype: torch.dtype) -> int: + # NOTE: Complex dtypes return `is_floating_point=False` + return ((dtype != torch.bool) + dtype.is_floating_point + + dtype.is_complex * 2) + + +def is_lossless_cast(src_dtype: torch.dtype, tgt_dtype: torch.dtype): + """ + Test whether it is lossless to cast a tensor from + `src_dtype` to `tgt_dtype`. 
+ """ + if src_dtype == tgt_dtype: + return True + + src_level = _get_precision_level(src_dtype) + tgt_level = _get_precision_level(tgt_dtype) + + if src_level < tgt_level: + return True + if src_level > tgt_level: + return False + + # Compare integral types + if not src_dtype.is_floating_point and not src_dtype.is_complex: + src_info = torch.iinfo(src_dtype) + tgt_info = torch.iinfo(tgt_dtype) + return src_info.min >= tgt_info.min and src_info.max <= tgt_info.max + + # Compare floating-point types + src_info = torch.finfo(src_dtype) + tgt_info = torch.finfo(tgt_dtype) + return (src_info.min >= tgt_info.min and src_info.max <= tgt_info.max + and src_info.resolution >= tgt_info.resolution) + + +def common_broadcastable_dtype(dtypes: Collection[torch.dtype]): + """ + Get the common `dtype` where all of the other `dtypes` can be + cast to it without losing any information. + """ + return max( + dtypes, + key=lambda dtype: sum(is_lossless_cast(dt, dtype) for dt in dtypes), + ) + + # `collections` helpers def is_list_of( value: object, @@ -2420,6 +2467,7 @@ def make_zmq_socket( socket_type: Any, bind: Optional[bool] = None, identity: Optional[bytes] = None, + linger: Optional[int] = None, ) -> Union[zmq.Socket, zmq.asyncio.Socket]: # type: ignore[name-defined] """Make a ZMQ socket with the proper bind/connect semantics.""" @@ -2439,7 +2487,7 @@ def make_zmq_socket( buf_size = -1 # Use system default buffer size if bind is None: - bind = socket_type != zmq.PUSH + bind = socket_type not in (zmq.PUSH, zmq.SUB, zmq.XSUB) if socket_type in (zmq.PULL, zmq.DEALER, zmq.ROUTER): socket.setsockopt(zmq.RCVHWM, 0) @@ -2452,6 +2500,9 @@ def make_zmq_socket( if identity is not None: socket.setsockopt(zmq.IDENTITY, identity) + if linger is not None: + socket.setsockopt(zmq.LINGER, linger) + # Determine if the path is a TCP socket with an IPv6 address. # Enable IPv6 on the zmq socket if so. 
scheme, host, _ = split_zmq_path(path) diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 31980e94a0376..d1e823bbe3965 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -66,9 +66,6 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): def __init__(self, runner, kv_cache_spec: AttentionSpec, block_table: BlockTable): super().__init__(runner, kv_cache_spec, block_table) - max_model_len = self.runner.model_config.max_model_len - assert max_model_len == 32768,\ - "AITER MLA requires max_model_len=32768" assert self.kv_cache_spec.block_size == 1, "AITER MLA" \ "only supports block size 1." diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 4000f93984d39..a97bb85004f6f 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Optional import torch from vllm import _custom_ops as ops +from vllm import envs from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) from vllm.attention.ops.chunked_prefill_paged_decode import ( @@ -126,6 +127,8 @@ class TritonAttentionImpl(AttentionImpl): "TritonAttentionImpl") self.fp8_dtype = current_platform.fp8_dtype() + self.force_prefill_decode_attn = \ + envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION def forward( self, @@ -166,9 +169,9 @@ class TritonAttentionImpl(AttentionImpl): # performance to make sure it does not introduce any overhead. 
num_queries_per_kv = query.shape[1] // key.shape[1] - use_prefill_decode_attn = (num_queries_per_kv & - (num_queries_per_kv - 1)) != 0 - + num_q_is_pow2 = (num_queries_per_kv & (num_queries_per_kv - 1)) == 0 + use_prefill_decode_attn = (self.force_prefill_decode_attn + or not num_q_is_pow2) num_actual_tokens = attn_metadata.num_actual_tokens if use_prefill_decode_attn: diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 403b5401be75a..a41fe48818702 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -544,16 +544,17 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig, available_memory) estimated_msg = "" if estimated_max_len > 0: - estimated_msg = " Based on the available memory," - f" the estimated maximum model length is {estimated_max_len}." + estimated_msg = ( + "Based on the available memory, " + f"the estimated maximum model length is {estimated_max_len}.") raise ValueError( f"To serve at least one request with the models's max seq len " f"({max_model_len}), ({needed_memory/GiB_bytes:.2f} GiB KV " f"cache is needed, which is larger than the available KV cache " - f"memory ({available_memory/GiB_bytes:.2f} GiB)." + f"memory ({available_memory/GiB_bytes:.2f} GiB). " f"{estimated_msg} " - f" Try increasing `gpu_memory_utilization` or decreasing " + f"Try increasing `gpu_memory_utilization` or decreasing " f"`max_model_len` when initializing the engine.") diff --git a/vllm/v1/core/sched/interface.py b/vllm/v1/core/sched/interface.py index c17f80b6ae78a..055ce446051ef 100644 --- a/vllm/v1/core/sched/interface.py +++ b/vllm/v1/core/sched/interface.py @@ -45,7 +45,7 @@ class SchedulerInterface(ABC): self, scheduler_output: "SchedulerOutput", model_runner_output: "ModelRunnerOutput", - ) -> "EngineCoreOutputs": + ) -> dict[int, "EngineCoreOutputs"]: """Update the scheduler state based on the model runner output. 
This method is called after the model runner has processed the scheduled @@ -55,7 +55,8 @@ class SchedulerInterface(ABC): for each request. Returns: - A EngineCoreOutputs object containing the outputs for each request. + A dict of client index to EngineCoreOutputs object containing the + outputs for each request originating from that client. """ raise NotImplementedError @@ -126,6 +127,11 @@ class SchedulerInterface(ABC): """ raise NotImplementedError + @abstractmethod + def get_request_counts(self) -> tuple[int, int]: + """Returns (num_running_reqs, num_waiting_reqs).""" + raise NotImplementedError + @abstractmethod def make_stats(self) -> Optional["SchedulerStats"]: """Make a SchedulerStats object for logging. diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 4c6b3eea0cb75..ce16a1ed5a096 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -58,7 +58,8 @@ class Scheduler(SchedulerInterface): # request ids should be included in the EngineCoreOutputs returned # by update_from_outputs(). This is currently used in the multi-engine # case to track request lifetimes efficiently. - self.include_finished_set = include_finished_set + self.finished_req_ids_dict: Optional[dict[int, set[str]]] = ( + defaultdict(set) if include_finished_set else None) # Scheduling constraints. 
self.max_num_running_reqs = self.scheduler_config.max_num_seqs @@ -693,7 +694,7 @@ class Scheduler(SchedulerInterface): self, scheduler_output: SchedulerOutput, model_runner_output: ModelRunnerOutput, - ) -> EngineCoreOutputs: + ) -> dict[int, EngineCoreOutputs]: sampled_token_ids = model_runner_output.sampled_token_ids spec_token_ids = model_runner_output.spec_token_ids logprobs = model_runner_output.logprobs @@ -701,7 +702,7 @@ class Scheduler(SchedulerInterface): num_scheduled_tokens = scheduler_output.num_scheduled_tokens new_running: list[Request] = [] - outputs: list[EngineCoreOutput] = [] + outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) spec_decoding_stats: Optional[SpecDecodingStats] = None # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below @@ -797,7 +798,7 @@ class Scheduler(SchedulerInterface): if new_token_ids or kv_transfer_params: # Add EngineCoreOutput for this Request. - outputs.append( + outputs[request.client_index].append( EngineCoreOutput( request_id=req_id, new_token_ids=new_token_ids, @@ -828,17 +829,38 @@ class Scheduler(SchedulerInterface): self._cached_reqs_data[req_data.req_id].append(req_data) self.running = new_running - engine_core_outputs = EngineCoreOutputs( - outputs=outputs, - scheduler_stats=self.make_stats(spec_decoding_stats), - ) - if self.include_finished_set: - #TODO currently sending duplicates here, improve this - engine_core_outputs.finished_requests = ( - scheduler_output.finished_req_ids | self.finished_req_ids) + + # Create EngineCoreOutputs for all clients that have requests with + # outputs in this step. + engine_core_outputs = { + client_index: EngineCoreOutputs(outputs=outs) + for client_index, outs in outputs.items() + } + + finished_req_ids = self.finished_req_ids_dict + if finished_req_ids is not None: + # Include ids of requests that finished since last outputs + # were sent. 
+ for client_index, finished_set in finished_req_ids.items(): + # Set finished request set in EngineCoreOutputs for this client. + if (eco := engine_core_outputs.get(client_index)) is not None: + eco.finished_requests = finished_set + else: + engine_core_outputs[client_index] = EngineCoreOutputs( + finished_requests=finished_set) + finished_req_ids.clear() + + if engine_core_outputs: + # Return stats to only one of the front-ends. + next(iter(engine_core_outputs.values())).scheduler_stats = ( + self.make_stats(spec_decoding_stats)) return engine_core_outputs + def get_request_counts(self) -> tuple[int, int]: + """Returns (num_running_reqs, num_waiting_reqs).""" + return len(self.running), len(self.waiting) + def add_request(self, request: Request) -> None: self.waiting.append(request) self.requests[request.request_id] = request @@ -880,8 +902,11 @@ class Scheduler(SchedulerInterface): delay_free_blocks, kv_xfer_params = self._connector_finished(request) self.encoder_cache_manager.free(request) - self._cached_reqs_data.pop(request.request_id, None) - self.finished_req_ids.add(request.request_id) + request_id = request.request_id + self._cached_reqs_data.pop(request_id, None) + self.finished_req_ids.add(request_id) + if self.finished_req_ids_dict is not None: + self.finished_req_ids_dict[request.client_index].add(request_id) if not delay_free_blocks: self._free_blocks(request) diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 41db99beaad5e..0c9f61a764279 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -44,10 +44,6 @@ class EngineCoreRequest( omit_defaults=True, # type: ignore[call-arg] gc=False): # type: ignore[call-arg] - # NOTE: prompt and prompt_token_ids should be DecoderOnlyInput, - # but this object is currently not playing well with msgspec - # due to circular imports and typing we have in data.py - request_id: str prompt_token_ids: list[int] mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] @@ -59,6 
+55,10 @@ class EngineCoreRequest( lora_request: Optional[LoRARequest] cache_salt: Optional[str] + # Index of the client, used to ensure outputs are sent back to the same + # client for this request when scaling out the front-end. + client_index: int = 0 + # Used in DP case to indicate which wave of requests this is expected to # belong to, to cover a race condition where the request is sent before # a wave finished notification is received. diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 74c2251c75214..86781e7528fa3 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -36,6 +36,7 @@ from vllm.v1.engine.processor import Processor from vllm.v1.executor.abstract import Executor from vllm.v1.metrics.loggers import (StatLoggerBase, StatLoggerFactory, setup_default_loggers) +from vllm.v1.metrics.prometheus import shutdown_prometheus from vllm.v1.metrics.stats import IterationStats, SchedulerStats logger = init_logger(__name__) @@ -54,6 +55,8 @@ class AsyncLLM(EngineClient): log_requests: bool = True, start_engine_loop: bool = True, stat_loggers: Optional[list[StatLoggerFactory]] = None, + client_addresses: Optional[dict[str, str]] = None, + client_index: int = 0, ) -> None: """ Create an AsyncLLM. 
@@ -124,6 +127,8 @@ class AsyncLLM(EngineClient): vllm_config=vllm_config, executor_class=executor_class, log_stats=self.log_stats, + client_addresses=client_addresses, + client_index=client_index, ) if self.stat_loggers: for stat_logger in self.stat_loggers[0]: @@ -145,6 +150,8 @@ class AsyncLLM(EngineClient): stat_loggers: Optional[list[StatLoggerFactory]] = None, disable_log_requests: bool = False, disable_log_stats: bool = False, + client_addresses: Optional[dict[str, str]] = None, + client_index: int = 0, ) -> "AsyncLLM": if not envs.VLLM_USE_V1: raise ValueError( @@ -162,6 +169,8 @@ class AsyncLLM(EngineClient): log_requests=not disable_log_requests, log_stats=not disable_log_stats, usage_context=usage_context, + client_addresses=client_addresses, + client_index=client_index, ) @classmethod @@ -195,6 +204,8 @@ class AsyncLLM(EngineClient): def shutdown(self): """Shutdown, cleaning up the background proc and IPC.""" + shutdown_prometheus() + if engine_core := getattr(self, "engine_core", None): engine_core.shutdown() @@ -398,7 +409,6 @@ class AsyncLLM(EngineClient): # TODO(rob): make into a coroutine and launch it in # background thread once Prometheus overhead is non-trivial. 
if stat_loggers: - assert outputs.scheduler_stats is not None AsyncLLM._record_stats( stat_loggers[outputs.engine_index], scheduler_stats=outputs.scheduler_stats, @@ -422,7 +432,7 @@ class AsyncLLM(EngineClient): @staticmethod def _record_stats( stat_loggers: list[StatLoggerBase], - scheduler_stats: SchedulerStats, + scheduler_stats: Optional[SchedulerStats], iteration_stats: Optional[IterationStats], ): """static so that it can be used from the output_handler task diff --git a/vllm/v1/engine/coordinator.py b/vllm/v1/engine/coordinator.py new file mode 100644 index 0000000000000..b84d4b144b5f2 --- /dev/null +++ b/vllm/v1/engine/coordinator.py @@ -0,0 +1,252 @@ +# SPDX-License-Identifier: Apache-2.0 +import multiprocessing +import time +import weakref +from typing import Optional + +import msgspec.msgpack +import zmq + +from vllm.config import ParallelConfig +from vllm.logger import init_logger +from vllm.utils import get_mp_context, get_open_zmq_ipc_path, make_zmq_socket +from vllm.v1.engine import EngineCoreOutputs, EngineCoreRequestType +from vllm.v1.serial_utils import MsgpackDecoder +from vllm.v1.utils import get_engine_client_zmq_addr, shutdown + +logger = init_logger(__name__) + + +class DPCoordinator: + """Coordinator process used for data-parallel deployments (DP>1). + + Intermediates between multiple DP engine rank processes and one or more + front-end API server processes. + + * Collects stats from each DP engine (currently just waiting and running + queue lengths), and publishes these to all front-ends for use in + load-balancing decisions. + + * Keeps track of the current DP "request wave" number and running state + of the engines. This is received from the DP rank 0 engine and published + to the front-end processes along with the current load stats. + + The engines alternate between a global running/paused state. 
The global + "request wave" number is a count of the number of times that the workers + collectively move from a running state to a paused state. This transition + is synchronized via the all-reduce operation performed in the + DPEngineCoreProc._has_global_unfinished_reqs method. + + * Broadcasts the START_DP_WAVE message to engines to move them from paused + to running state when one engine receives a new request. This can happen + in two cases: + 1) A front-end sending a new request while the engines are paused will + concurrently notify the coordinator. + 2) An engine receiving a request for a stale request wave while in paused + state will notify the coordinator. + + Engines will move into running state when receiving a new request or + START_DP_WAVE message. + """ + + def __init__(self, parallel_config: ParallelConfig): + + # Assume coordinator is colocated with front-end procs. + front_publish_address = get_open_zmq_ipc_path() + + dp_size = parallel_config.data_parallel_size + assert dp_size > 1, "Coordinator only used for data parallel" + + local_only = dp_size == parallel_config.data_parallel_size_local + host = parallel_config.data_parallel_master_ip + back_publish_address = get_engine_client_zmq_addr(local_only, host) + back_output_address = get_engine_client_zmq_addr(local_only, host) + + context = get_mp_context() + self.proc: multiprocessing.Process = context.Process( + target=CoordinatorProc.run_coordinator, + name="VLLM_DP_Coordinator", + kwargs={ + "engine_count": parallel_config.data_parallel_size, + "front_publish_address": front_publish_address, + "back_output_address": back_output_address, + "back_publish_address": back_publish_address, + }, + daemon=True) + self.proc.start() + + self.stats_publish_address = front_publish_address + self.coord_in_address = back_publish_address + self.coord_out_address = back_output_address + self._finalizer = weakref.finalize(self, shutdown, [self.proc]) + + def get_stats_publish_address(self) -> str: + return 
self.stats_publish_address + + def get_engine_socket_addresses(self) -> tuple[str, str]: + """Returns tuple of ZMQ input address, output address.""" + return self.coord_in_address, self.coord_out_address + + def close(self): + self._finalizer() + + +class EngineState: + + def __init__(self): + self.request_counts = [0, 0] # [waiting, running] + + +class CoordinatorProc: + + def __init__(self, engine_count: int): + + self.ctx = zmq.Context() + + self.engines = [EngineState() for _ in range(engine_count)] + + self.current_wave = 0 + self.engines_running = False + self.stats_changed = False + + @staticmethod + def run_coordinator( + engine_count: int, + front_publish_address: str, + back_output_address: str, + back_publish_address: str, + ): + coordinator = CoordinatorProc(engine_count=engine_count) + try: + coordinator.process_input_socket( + front_publish_address, + back_output_address, + back_publish_address, + ) + except KeyboardInterrupt: + logger.info("DP Coordinator process exiting") + + def process_input_socket(self, front_publish_address: str, + back_output_address: str, + back_publish_address: str): + + decoder = MsgpackDecoder(EngineCoreOutputs) + + with make_zmq_socket( + path=front_publish_address, # IPC + ctx=self.ctx, + socket_type=zmq.XPUB, + bind=True, + ) as publish_front, make_zmq_socket( + path=back_output_address, # IPC or TCP + ctx=self.ctx, + socket_type=zmq.PULL, + bind=True, + ) as output_back, make_zmq_socket( + path=back_publish_address, # IPC or TCP + ctx=self.ctx, + socket_type=zmq.XPUB, + bind=True, + ) as publish_back: + + poller = zmq.Poller() + poller.register(publish_front, zmq.POLLIN) + poller.register(output_back, zmq.POLLIN) + last_publish_time = 0 + while True: + elapsed = int(time.time() * 1000) - last_publish_time + # Send at 100 ms interval if the stats have changed, + # or otherwise every 3 seconds. 
+ wait_for = 100 if self.stats_changed else 3000 + events = poller.poll(timeout=max(0, wait_for - elapsed)) + if not events: + # Poller timeout - publish current stats to front-ends. + engine_req_counts_list = self._get_engine_counts() + to_publish = (engine_req_counts_list, self.current_wave, + self.engines_running) + publish_front.send(msgspec.msgpack.encode(to_publish)) + last_publish_time = int(time.time() * 1000) + self.stats_changed = False + continue + + events = dict(events) + + if publish_front in events: + buffer = publish_front.recv() + if buffer == b'\x01': + # Ignore subscription messages. + continue + + # We received a message on the front-end XPUB socket, + # from an API server sending a new request while the + # engines are paused, so that we can wake the other + # engines. + engine_to_exclude, wave = msgspec.msgpack.decode(buffer) + if wave < self.current_wave: + # If the wave number is stale, ensure the message is + # handled by all the engines. + engine_to_exclude = None + if not self.engines_running: + self.engines_running = True + self.stats_changed = True + self._send_start_wave(publish_back, self.current_wave, + engine_to_exclude) + + if output_back in events: + # We received a message from one of the engines. + + buffer = output_back.recv() + outputs: EngineCoreOutputs = decoder.decode(buffer) + + assert not outputs.outputs + assert outputs.utility_output is None + + eng_index = outputs.engine_index + if outputs.scheduler_stats: + # 1. Updated request load stats - update our local + # state with these. + stats = self.engines[eng_index].request_counts + stats[0] = outputs.scheduler_stats.num_waiting_reqs + stats[1] = outputs.scheduler_stats.num_running_reqs + self.stats_changed = True + + if (wave := outputs.wave_complete) is not None: + # 2. 
Notification from rank 0 engine that we've
+                        # moved into the global paused state
+                        # (engines_running==False)
+                        if self.current_wave <= wave:
+                            logger.debug("Moving DP wave from %d to %d.",
+                                         self.current_wave, wave)
+                            self.current_wave = wave + 1
+                            self.engines_running = False
+                            self.stats_changed = True
+                    elif (wave := outputs.start_wave) is not None and (
+                            wave > self.current_wave or
+                            (wave == self.current_wave
+                             and not self.engines_running)):
+                        # 3. The engine received a request for a non-current
+                        # wave, so we must ensure that other engines progress
+                        # to the next wave (race condition handling).
+                        logger.debug(
+                            "Starting wave %d after notification of "
+                            "stale wave request from engine.", wave)
+                        self.current_wave = wave
+                        self.engines_running = True
+                        self.stats_changed = True
+                        self._send_start_wave(publish_back, wave, eng_index)
+
+    @staticmethod
+    def _send_start_wave(socket: zmq.Socket, wave: int,
+                         exclude_engine_index: Optional[int]):
+        """Broadcast the START_DP_WAVE message to all the engines.
+        It includes the current wave number and the index of the engine
+        which has already received a request with this wave number and
+        so doesn't require additional notification.
+ """ + wave_encoded = msgspec.msgpack.encode((wave, exclude_engine_index)) + socket.send_multipart( + (EngineCoreRequestType.START_DP_WAVE.value, wave_encoded)) + + def _get_engine_counts(self) -> list[list[int]]: + """Return list of [waiting, running] count lists for each engine.""" + return [e.request_counts for e in self.engines] diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 740ba60fe231b..a02abb62b1f36 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -7,6 +7,7 @@ import threading import time from collections import deque from concurrent.futures import Future +from contextlib import ExitStack from inspect import isclass, signature from logging import DEBUG from typing import Any, Callable, Optional, TypeVar, Union @@ -22,7 +23,7 @@ from vllm.logging_utils.dump_input import dump_engine_exception from vllm.lora.request import LoRARequest from vllm.transformers_utils.config import ( maybe_register_config_serialize_by_value) -from vllm.utils import make_zmq_socket, resolve_obj_by_qualname, zmq_socket_ctx +from vllm.utils import make_zmq_socket, resolve_obj_by_qualname from vllm.v1.core.kv_cache_utils import (get_kv_cache_config, unify_kv_cache_configs) from vllm.v1.core.sched.interface import SchedulerInterface @@ -33,10 +34,12 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, from vllm.v1.engine.mm_input_cache import MirroredProcessingCache from vllm.v1.executor.abstract import Executor from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder from vllm.v1.structured_output import StructuredOutputManager +from vllm.v1.utils import EngineHandshakeMetadata, EngineZmqAddresses from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) @@ -123,7 +126,6 @@ class EngineCore: 
logger.info("Batch queue is enabled with size %d", self.batch_queue_size) self.batch_queue = queue.Queue(self.batch_queue_size) - self.vllm_config = vllm_config def _initialize_kv_caches( self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]: @@ -212,24 +214,27 @@ class EngineCore: # Re-raise exception raise err - def step(self) -> EngineCoreOutputs: - """Schedule, execute, and make output.""" + def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]: + """Schedule, execute, and make output. + + Returns tuple of outputs and a flag indicating whether the model + was executed. + """ # Check for any requests remaining in the scheduler - unfinished, # or finished and not yet removed from the batch. if not self.scheduler.has_requests(): - return EngineCoreOutputs( - outputs=[], - scheduler_stats=self.scheduler.make_stats(), - ) + return {}, False scheduler_output = self.scheduler.schedule() model_output = self.execute_model(scheduler_output) engine_core_outputs = self.scheduler.update_from_output( scheduler_output, model_output) # type: ignore - return engine_core_outputs + return (engine_core_outputs, + scheduler_output.total_num_scheduled_tokens > 0) - def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]: + def step_with_batch_queue( + self) -> tuple[Optional[dict[int, EngineCoreOutputs]], bool]: """Schedule and execute batches with the batch queue. Note that if nothing to output in this step, None is returned. @@ -271,10 +276,10 @@ class EngineCore: # Blocking until the first result is available. 
model_output = future.result() self.batch_queue.task_done() - engine_core_outputs = self.scheduler.update_from_output( - scheduler_output, model_output) + engine_core_outputs = (self.scheduler.update_from_output( + scheduler_output, model_output)) - return engine_core_outputs + return engine_core_outputs, scheduled_batch def shutdown(self): self.structured_output_manager.clear_backend() @@ -357,7 +362,7 @@ class EngineCoreProc(EngineCore): self, vllm_config: VllmConfig, on_head_node: bool, - input_address: str, + handshake_address: str, executor_class: type[Executor], log_stats: bool, engine_index: int = 0, @@ -370,65 +375,70 @@ class EngineCoreProc(EngineCore): # Create input socket. input_ctx = zmq.Context() identity = engine_index.to_bytes(length=2, byteorder="little") - input_socket = make_zmq_socket(input_ctx, - input_address, - zmq.DEALER, - identity=identity, - bind=False) - try: + with make_zmq_socket(input_ctx, + handshake_address, + zmq.DEALER, + identity=identity, + linger=5000, + bind=False) as handshake_socket: + # Register engine with front-end. - output_address = self.startup_handshake( - input_socket, on_head_node, vllm_config.parallel_config) + addresses = self.startup_handshake(handshake_socket, on_head_node, + vllm_config.parallel_config) + self.client_count = len(addresses.outputs) # Update config which may have changed from the handshake. vllm_config.__post_init__() # Set up data parallel environment. + self.has_coordinator = addresses.coordinator_output is not None self._init_data_parallel(vllm_config) # Initialize engine core and model. super().__init__(vllm_config, executor_class, log_stats, executor_fail_callback) + self.engine_index = engine_index self.step_fn = (self.step if self.batch_queue is None else self.step_with_batch_queue) self.engines_running = False + self.last_counts = (0, 0) # Send ready message. 
num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks - input_socket.send( + handshake_socket.send( msgspec.msgpack.encode({ "status": "READY", "local": on_head_node, "num_gpu_blocks": num_gpu_blocks, })) - # Background Threads and Queues for IO. These enable us to - # overlap ZMQ socket IO with GPU since they release the GIL, - # and to overlap some serialization/deserialization with the - # model forward pass. - # Threads handle Socket <-> Queues and core_busy_loop uses Queue. - self.input_queue = input_queue - self.output_queue = queue.Queue[Union[EngineCoreOutputs, bytes]]() - threading.Thread(target=self.process_input_socket, - args=(input_socket, ), - daemon=True).start() - input_socket = None - self.output_thread = threading.Thread( - target=self.process_output_socket, - args=(output_address, engine_index), - daemon=True) - self.output_thread.start() - finally: - if input_socket is not None: - input_socket.close(linger=0) + # Background Threads and Queues for IO. These enable us to + # overlap ZMQ socket IO with GPU since they release the GIL, + # and to overlap some serialization/deserialization with the + # model forward pass. + # Threads handle Socket <-> Queues and core_busy_loop uses Queue. + self.input_queue = input_queue + self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs], + bytes]]() + threading.Thread(target=self.process_input_sockets, + args=(addresses.inputs, addresses.coordinator_input, + identity), + daemon=True).start() + self.output_thread = threading.Thread( + target=self.process_output_sockets, + args=(addresses.outputs, addresses.coordinator_output, + engine_index), + daemon=True) + self.output_thread.start() @staticmethod - def startup_handshake(input_socket: zmq.Socket, on_head_node: bool, - parallel_config: ParallelConfig) -> str: + def startup_handshake( + handshake_socket: zmq.Socket, on_head_node: bool, + parallel_config: ParallelConfig) -> EngineZmqAddresses: # Send registration message. 
- input_socket.send( + handshake_socket.send( msgspec.msgpack.encode({ "status": "HELLO", "local": on_head_node, @@ -436,22 +446,20 @@ class EngineCoreProc(EngineCore): # Receive initialization message. logger.info("Waiting for init message from front-end.") - if not input_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60 * 1000): + if not handshake_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60_000): raise RuntimeError("Did not receive response from front-end " f"process within {HANDSHAKE_TIMEOUT_MINS} " f"minutes") - init_bytes = input_socket.recv() - init_message = msgspec.msgpack.decode(init_bytes) + init_bytes = handshake_socket.recv() + init_message: EngineHandshakeMetadata = msgspec.msgpack.decode( + init_bytes, type=EngineHandshakeMetadata) logger.debug("Received init message: %s", init_message) - output_socket_address = init_message["output_socket_address"] - #TBD(nick) maybe replace IP with configured head node address - - received_parallel_config = init_message["parallel_config"] + received_parallel_config = init_message.parallel_config for key, value in received_parallel_config.items(): setattr(parallel_config, key, value) - return output_socket_address + return init_message.addresses @staticmethod def run_engine_core(*args, @@ -523,7 +531,7 @@ class EngineCoreProc(EngineCore): """Exits when an engine step needs to be performed.""" waited = False - while not self.engines_running and not (self.scheduler.has_requests()): + while not self.engines_running and not self.scheduler.has_requests(): if logger.isEnabledFor(DEBUG) and self.input_queue.empty(): logger.debug("EngineCore waiting for work.") waited = True @@ -538,14 +546,16 @@ class EngineCoreProc(EngineCore): req = self.input_queue.get_nowait() self._handle_client_request(*req) - def _process_engine_step(self): + def _process_engine_step(self) -> bool: """Called only when there are unfinished local requests.""" # Step the engine core. 
- outputs = self.step_fn() + outputs, model_executed = self.step_fn() # Put EngineCoreOutputs into the output queue. - if outputs is not None: - self.output_queue.put_nowait(outputs) + for output in (outputs.items() if outputs else ()): + self.output_queue.put_nowait(output) + + return model_executed def _handle_client_request(self, request_type: EngineCoreRequestType, request: Any) -> None: @@ -556,7 +566,7 @@ class EngineCoreProc(EngineCore): elif request_type == EngineCoreRequestType.ABORT: self.abort_requests(request) elif request_type == EngineCoreRequestType.UTILITY: - call_id, method_name, args = request + client_idx, call_id, method_name, args = request output = UtilityOutput(call_id) try: method = getattr(self, method_name) @@ -567,7 +577,7 @@ class EngineCoreProc(EngineCore): output.failure_message = (f"Call to {method_name} method" f" failed: {str(e)}") self.output_queue.put_nowait( - EngineCoreOutputs(utility_output=output)) + (client_idx, EngineCoreOutputs(utility_output=output))) elif request_type == EngineCoreRequestType.EXECUTOR_FAILED: raise RuntimeError("Executor failed.") else: @@ -600,27 +610,68 @@ class EngineCoreProc(EngineCore): logger.fatal("vLLM shutdown signal from EngineCore failed " "to send. Please report this issue.") - def process_input_socket(self, input_socket: zmq.Socket): + def process_input_sockets(self, input_addresses: list[str], + coord_input_address: Optional[str], + identity: bytes): """Input socket IO thread.""" # Msgpack serialization decoding. 
add_request_decoder = MsgpackDecoder(EngineCoreRequest) generic_decoder = MsgpackDecoder() - while True: - # (RequestType, RequestData) - type_frame, *data_frames = input_socket.recv_multipart(copy=False) - request_type = EngineCoreRequestType(bytes(type_frame.buffer)) + with ExitStack() as stack, zmq.Context() as ctx: + input_sockets = [ + stack.enter_context( + make_zmq_socket(ctx, + input_address, + zmq.DEALER, + identity=identity, + bind=False)) + for input_address in input_addresses + ] + if coord_input_address is None: + coord_socket = None + else: + coord_socket = stack.enter_context( + make_zmq_socket(ctx, + coord_input_address, + zmq.XSUB, + identity=identity, + bind=False)) + # Send subscription message to coordinator. + coord_socket.send(b'\x01') - # Deserialize the request data. - decoder = add_request_decoder if ( - request_type == EngineCoreRequestType.ADD) else generic_decoder - request = decoder.decode(data_frames) + # Register sockets with poller. + poller = zmq.Poller() + for input_socket in input_sockets: + # Send initial message to each input socket - this is required + # before the front-end ROUTER socket can send input messages + # back to us. + input_socket.send(b'') + poller.register(input_socket, zmq.POLLIN) + if coord_socket is not None: + poller.register(coord_socket, zmq.POLLIN) - # Push to input queue for core busy loop. - self.input_queue.put_nowait((request_type, request)) + while True: + for input_socket, _ in poller.poll(): + # (RequestType, RequestData) + type_frame, *data_frames = input_socket.recv_multipart( + copy=False) + request_type = EngineCoreRequestType( + bytes(type_frame.buffer)) - def process_output_socket(self, output_path: str, engine_index: int): + # Deserialize the request data. + decoder = add_request_decoder if ( + request_type + == EngineCoreRequestType.ADD) else generic_decoder + request = decoder.decode(data_frames) + + # Push to input queue for core busy loop. 
+ self.input_queue.put_nowait((request_type, request)) + + def process_output_sockets(self, output_paths: list[str], + coord_output_path: Optional[str], + engine_index: int): """Output socket IO thread.""" # Msgpack serialization encoding. @@ -634,30 +685,49 @@ class EngineCoreProc(EngineCore): # We must set linger to ensure the ENGINE_CORE_DEAD # message is sent prior to closing the socket. - with zmq_socket_ctx(output_path, zmq.constants.PUSH, - linger=4000) as socket: + with ExitStack() as stack, zmq.Context() as ctx: + sockets = [ + stack.enter_context( + make_zmq_socket(ctx, output_path, zmq.PUSH, linger=4000)) + for output_path in output_paths + ] + coord_socket = stack.enter_context( + make_zmq_socket( + ctx, coord_output_path, zmq.PUSH, bind=False, + linger=4000)) if coord_output_path is not None else None + max_reuse_bufs = len(sockets) + 1 + while True: - outputs = self.output_queue.get() - if outputs == EngineCoreProc.ENGINE_CORE_DEAD: - socket.send(outputs, copy=False) + output = self.output_queue.get() + if output == EngineCoreProc.ENGINE_CORE_DEAD: + for socket in sockets: + socket.send(output) break - assert not isinstance(outputs, bytes) + assert not isinstance(output, bytes) + client_index, outputs = output outputs.engine_index = engine_index + if client_index == -1: + # Don't reuse buffer for coordinator message + # which will be very small. + assert coord_socket is not None + coord_socket.send_multipart(encoder.encode(outputs)) + continue + # Reclaim buffers that zmq is finished with. 
while pending and pending[-1][0].done: reuse_buffers.append(pending.pop()[2]) buffer = reuse_buffers.pop() if reuse_buffers else bytearray() buffers = encoder.encode_into(outputs, buffer) - tracker = socket.send_multipart(buffers, - copy=False, - track=True) + tracker = sockets[client_index].send_multipart(buffers, + copy=False, + track=True) if not tracker.done: ref = outputs if len(buffers) > 1 else None pending.appendleft((tracker, ref, buffer)) - elif len(reuse_buffers) < 2: - # Keep at most 2 buffers to reuse. + elif len(reuse_buffers) < max_reuse_bufs: + # Limit the number of buffers to reuse. reuse_buffers.append(buffer) @@ -669,7 +739,7 @@ class DPEngineCoreProc(EngineCoreProc): self, vllm_config: VllmConfig, on_head_node: bool, - input_address: str, + handshake_address: str, executor_class: type[Executor], log_stats: bool, ): @@ -684,10 +754,11 @@ class DPEngineCoreProc(EngineCoreProc): # Counts forward-passes of the model so that we can synchronize # finished with DP peers every N steps. self.counter = 0 + self.current_wave = 0 # Initialize the engine. dp_rank = vllm_config.parallel_config.data_parallel_rank - super().__init__(vllm_config, on_head_node, input_address, + super().__init__(vllm_config, on_head_node, handshake_address, executor_class, log_stats, dp_rank) def _init_data_parallel(self, vllm_config: VllmConfig): @@ -700,6 +771,15 @@ class DPEngineCoreProc(EngineCoreProc): assert dp_size > 1 assert 0 <= local_dp_rank <= dp_rank < dp_size + if vllm_config.kv_transfer_config is not None: + # modify the engine_id and append the local_dp_rank to it to ensure + # that the kv_transfer_config is unique for each DP rank. 
+ vllm_config.kv_transfer_config.engine_id = ( + f"{vllm_config.kv_transfer_config.engine_id}_dp{local_dp_rank}" + ) + logger.debug("Setting kv_transfer_config.engine_id to %s", + vllm_config.kv_transfer_config.engine_id) + from vllm.platforms import current_platform device_control_env_var = current_platform.device_control_env_var world_size = vllm_config.parallel_config.world_size @@ -710,7 +790,6 @@ class DPEngineCoreProc(EngineCoreProc): self.dp_rank = dp_rank self.dp_group = vllm_config.parallel_config.stateless_init_dp_group() - self.current_wave = 0 def shutdown(self): super().shutdown() @@ -718,22 +797,23 @@ class DPEngineCoreProc(EngineCoreProc): stateless_destroy_torch_distributed_process_group(dp_group) def add_request(self, request: EngineCoreRequest): - if request.current_wave != self.current_wave: + if self.has_coordinator and request.current_wave != self.current_wave: if request.current_wave > self.current_wave: self.current_wave = request.current_wave elif not self.engines_running: # Request received for an already-completed wave, notify # front-end that we need to start the next one. self.output_queue.put_nowait( - EngineCoreOutputs(start_wave=self.current_wave)) + (-1, EngineCoreOutputs(start_wave=self.current_wave))) super().add_request(request) def _handle_client_request(self, request_type: EngineCoreRequestType, request: Any) -> None: if request_type == EngineCoreRequestType.START_DP_WAVE: - new_wave: int = request - if new_wave >= self.current_wave: + new_wave, exclude_eng_index = request + if exclude_eng_index != self.engine_index and ( + new_wave >= self.current_wave): self.current_wave = new_wave if not self.engines_running: logger.debug("EngineCore starting idle loop for wave %d.", @@ -742,6 +822,18 @@ class DPEngineCoreProc(EngineCoreProc): else: super()._handle_client_request(request_type, request) + def _maybe_publish_request_counts(self): + if not self.has_coordinator: + return + + # Publish our request counts (if they've changed). 
+ counts = self.scheduler.get_request_counts() + if counts != self.last_counts: + self.last_counts = counts + stats = SchedulerStats(*counts) + self.output_queue.put_nowait( + (-1, EngineCoreOutputs(scheduler_stats=stats))) + def run_busy_loop(self): """Core busy loop of the EngineCore for data parallel case.""" @@ -750,30 +842,18 @@ class DPEngineCoreProc(EngineCoreProc): # 1) Poll the input queue until there is work to do. self._process_input_queue() + # 2) Step the engine core. + executed = self._process_engine_step() + self._maybe_publish_request_counts() + local_unfinished_reqs = self.scheduler.has_unfinished_requests() - - if local_unfinished_reqs: - # 2) Step the engine core. - self._process_engine_step() - - # Check if we have now finished all requests. - local_unfinished_reqs = ( - self.scheduler.has_unfinished_requests()) - else: - if self.scheduler.has_finished_requests(): - # There are no unfinished requests, but there are some - # finished requests remaining to be removed from the - # batch state. This engine step won't perform a forward - # pass but will flush the finished requests to ensure - # up-to-date state is returned in the engine outputs. - self._process_engine_step() - - if not self.engines_running: + if not executed: + if not local_unfinished_reqs and not self.engines_running: # All engines are idle. continue - # There must be unfinished requests in DP peers, run a - # dummy forward pass. + # We are in a running state and so must execute a dummy pass + # if the model didn't execute any ready requests. self.execute_dummy_batch() # 3) All-reduce operation to determine global unfinished reqs. 
@@ -786,7 +866,8 @@ class DPEngineCoreProc(EngineCoreProc): logger.debug("Wave %d finished, pausing engine loop.", self.current_wave) self.output_queue.put_nowait( - EngineCoreOutputs(wave_complete=self.current_wave)) + (-1, + EngineCoreOutputs(wave_complete=self.current_wave))) self.current_wave += 1 def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool: diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 0d52bc9a68148..232d6742b7718 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -2,6 +2,7 @@ import asyncio import contextlib import queue +import sys import uuid import weakref from abc import ABC, abstractmethod @@ -9,26 +10,28 @@ from collections import deque from collections.abc import Awaitable, Sequence from concurrent.futures import Future from dataclasses import dataclass -from enum import Enum, auto from threading import Thread from typing import Any, Callable, Optional, TypeVar, Union -import msgspec +import msgspec.msgpack import zmq import zmq.asyncio -from vllm.config import ParallelConfig, VllmConfig +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.utils import (get_open_port, get_open_zmq_inproc_path, - get_open_zmq_ipc_path, get_tcp_uri, make_zmq_socket) +from vllm.utils import (get_open_zmq_inproc_path, make_zmq_socket, + zmq_socket_ctx) from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, EngineCoreRequestType, UtilityOutput) +from vllm.v1.engine.coordinator import DPCoordinator from vllm.v1.engine.core import EngineCore, EngineCoreProc from vllm.v1.engine.exceptions import EngineDeadError from vllm.v1.executor.abstract import Executor from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr -from vllm.v1.utils import CoreEngineProcManager +from vllm.v1.utils import (CoreEngine, CoreEngineProcManager, + EngineZmqAddresses, get_engine_client_zmq_addr, + 
wait_for_engine_startup) logger = init_logger(__name__) @@ -36,8 +39,6 @@ AnyFuture = Union[asyncio.Future[Any], Future[Any]] _R = TypeVar('_R') # Return type for collective_rpc -STARTUP_POLL_PERIOD_MS = 10000 - class EngineCoreClient(ABC): """ @@ -206,7 +207,8 @@ class InprocClient(EngineCoreClient): self.engine_core = EngineCore(*args, **kwargs) def get_output(self) -> EngineCoreOutputs: - return self.engine_core.step() + outputs, _ = self.engine_core.step() + return outputs.get(0) or EngineCoreOutputs() def add_request(self, request: EngineCoreRequest) -> None: self.engine_core.add_request(request) @@ -265,24 +267,6 @@ class InprocClient(EngineCoreClient): return self.engine_core.collective_rpc(method, timeout, args, kwargs) -class CoreEngineState(Enum): - NEW = auto() - CONNECTED = auto() - READY = auto() - - -class CoreEngine: - """One per data parallel rank.""" - - def __init__(self, index: int = 0, local: bool = True): - self.local = local - self.index = index - self.identity = index.to_bytes(length=2, byteorder="little") - - self.state = CoreEngineState.NEW - self.num_reqs_in_flight = 0 - - @dataclass class BackgroundResources: """Used as a finalizer for clean shutdown, avoiding @@ -290,9 +274,12 @@ class BackgroundResources: ctx: Union[zmq.Context] local_engine_manager: Optional[CoreEngineProcManager] = None + coordinator: Optional[DPCoordinator] = None output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None + first_req_send_socket: Optional[zmq.asyncio.Socket] = None output_queue_task: Optional[asyncio.Task] = None + stats_update_task: Optional[asyncio.Task] = None shutdown_path: Optional[str] = None # Set if any of the engines are dead. 
Here so that the output @@ -305,16 +292,21 @@ class BackgroundResources: self.engine_dead = True if self.local_engine_manager is not None: self.local_engine_manager.close() + if self.coordinator is not None: + self.coordinator.close() if self.output_queue_task is not None: self.output_queue_task.cancel() + if self.stats_update_task is not None: + self.stats_update_task.cancel() # ZMQ context termination can hang if the sockets # aren't explicitly closed first. - if self.output_socket is not None: - self.output_socket.close(linger=0) - if self.input_socket is not None: - self.input_socket.close(linger=0) + for socket in (self.output_socket, self.input_socket, + self.first_req_send_socket): + if socket is not None: + socket.close(linger=0) + if self.shutdown_path is not None: # We must ensure that the sync output socket is # closed cleanly in its own thread. @@ -349,6 +341,7 @@ class MPClient(EngineCoreClient): vllm_config: VllmConfig, executor_class: type[Executor], log_stats: bool, + client_addresses: Optional[dict[str, str]] = None, ): self.vllm_config = vllm_config # Serialization setup. 
@@ -368,8 +361,9 @@ class MPClient(EngineCoreClient): try: parallel_config = vllm_config.parallel_config local_engine_count = parallel_config.data_parallel_size_local - start_index = parallel_config.data_parallel_rank local_start_index = parallel_config.data_parallel_rank_local + dp_size = parallel_config.data_parallel_size + dp_rank = parallel_config.data_parallel_rank # SPMD mode is where there is an LLM instance per DP rank and # one core engine per LLM, see @@ -377,46 +371,55 @@ class MPClient(EngineCoreClient): spmd_mode = local_start_index is not None if spmd_mode: assert local_engine_count == 1 - self.core_engines = [ - CoreEngine(index=local_start_index, local=True) - ] + self.core_engines = [CoreEngine(index=dp_rank, local=True)] else: - assert start_index == 0 + assert dp_rank == 0 local_start_index = 0 self.core_engines = [ CoreEngine(index=i, local=(i < local_engine_count)) - for i in range(parallel_config.data_parallel_size) + for i in range(dp_size) ] - input_address, output_address = self._get_zmq_addresses( - parallel_config, spmd_mode) + local_only = spmd_mode or local_engine_count == dp_size + + self.stats_update_address: Optional[str] = None + if client_addresses is not None: + input_address = client_addresses["input_address"] + output_address = client_addresses["output_address"] + self.stats_update_address = client_addresses.get( + "stats_update_address") + else: + host = parallel_config.data_parallel_master_ip + input_address = get_engine_client_zmq_addr(local_only, host) + output_address = get_engine_client_zmq_addr(local_only, host) # Create input and output sockets. self.input_socket = self.resources.input_socket = make_zmq_socket( self.ctx, input_address, zmq.ROUTER, bind=True) - self.resources.output_socket = make_zmq_socket( - self.ctx, output_address, zmq.constants.PULL) - # Start local engines. - if local_engine_count: - # In server mode, start_index and local_start_index will - # both be 0. 
- self.resources.local_engine_manager = CoreEngineProcManager( - EngineCoreProc.run_engine_core, - vllm_config=vllm_config, - executor_class=executor_class, - log_stats=log_stats, - input_address=input_address, - on_head_node=True, - local_engine_count=local_engine_count, - start_index=start_index, - local_start_index=local_start_index) + self.ctx, output_address, zmq.PULL) + + if client_addresses is None: + self._init_engines_direct(vllm_config, local_only, + local_start_index, input_address, + output_address, executor_class, + log_stats) + coordinator = self.resources.coordinator + if coordinator: + self.stats_update_address = ( + coordinator.get_stats_publish_address()) + + # Wait for ready messages from each engine on the input socket. + identities = set(e.identity for e in self.core_engines) + sync_input_socket = zmq.Socket.shadow(self.input_socket) + while identities: + if not sync_input_socket.poll(timeout=600_000): + raise TimeoutError("Timed out waiting for engines to send" + "initial message on input socket.") + identity, _ = sync_input_socket.recv_multipart() + identities.remove(identity) self.core_engine = self.core_engines[0] - - # Wait for engine core process(es) to start. 
- self._wait_for_engine_startup(output_address, parallel_config) - self.utility_results: dict[int, AnyFuture] = {} # Request objects which may contain pytorch-allocated tensors @@ -429,116 +432,67 @@ class MPClient(EngineCoreClient): if not success: self._finalizer() - @staticmethod - def _get_zmq_addresses(parallel_config: ParallelConfig, - spmd_mode: bool) -> tuple[str, str]: - """Returns (input_address, output_address).""" - dp_size = parallel_config.data_parallel_size + def _init_engines_direct(self, vllm_config: VllmConfig, local_only: bool, + local_start_index: int, input_address: str, + output_address: str, + executor_class: type[Executor], log_stats: bool): + """Self-contained client mode, launch engine and coordinator process + as needed.""" + + parallel_config = vllm_config.parallel_config local_engine_count = parallel_config.data_parallel_size_local + start_index = parallel_config.data_parallel_rank + host = parallel_config.data_parallel_master_ip - if local_engine_count == dp_size or spmd_mode: - input_address = get_open_zmq_ipc_path() - output_address = get_open_zmq_ipc_path() - else: - host = parallel_config.data_parallel_master_ip - input_port = parallel_config.data_parallel_rpc_port - output_port = get_open_port() - input_address = get_tcp_uri(host, input_port) - output_address = get_tcp_uri(host, output_port) + if len(self.core_engines) > 1: + self.resources.coordinator = DPCoordinator(parallel_config) - return input_address, output_address + handshake_address = get_engine_client_zmq_addr( + local_only, host, parallel_config.data_parallel_rpc_port) - def _wait_for_engine_startup(self, output_address: str, - parallel_config: ParallelConfig): - # Get a sync handle to the socket which can be sync or async. - sync_input_socket = zmq.Socket.shadow(self.input_socket) + with zmq_socket_ctx(handshake_address, zmq.ROUTER, + bind=True) as handshake_socket: - # Wait for engine core process(es) to send ready messages. 
- local_count = parallel_config.data_parallel_size_local - remote_count = len(self.core_engines) - local_count - # [local, remote] counts - conn_pending, start_pending = [local_count, remote_count], [0, 0] + # Start local engines. + if local_engine_count: + # In server mode, start_index and local_start_index will + # both be 0. + self.resources.local_engine_manager = CoreEngineProcManager( + EngineCoreProc.run_engine_core, + vllm_config=vllm_config, + executor_class=executor_class, + log_stats=log_stats, + handshake_address=handshake_address, + on_head_node=True, + local_engine_count=local_engine_count, + start_index=start_index, + local_start_index=local_start_index) - poller = zmq.Poller() - poller.register(sync_input_socket, zmq.POLLIN) - proc_manager = self.resources.local_engine_manager - if proc_manager is not None: - for sentinel in proc_manager.sentinels(): - poller.register(sentinel, zmq.POLLIN) - while any(conn_pending) or any(start_pending): - events = poller.poll(STARTUP_POLL_PERIOD_MS) - if not events: - if any(conn_pending): - logger.debug( - "Waiting for %d local, %d remote core engine proc(s) " - "to connect.", *conn_pending) - if any(start_pending): - logger.debug( - "Waiting for %d local, %d remote core engine proc(s) " - "to start.", *start_pending) - continue - if len(events) > 1 or events[0][0] != sync_input_socket: - # One of the local core processes exited. - finished = proc_manager.finished_procs( - ) if proc_manager else {} - raise RuntimeError("Engine core initialization failed. " - "See root cause above. " - f"Failed core proc(s): {finished}") + # Wait for engine core process(es) to start. + self._wait_for_engine_startup(handshake_socket, input_address, + output_address) - # Receive HELLO and READY messages from the input socket. 
- eng_identity, ready_msg_bytes = sync_input_socket.recv_multipart() - eng_index = int.from_bytes(eng_identity, byteorder="little") - engine = next( - (e for e in self.core_engines if e.identity == eng_identity), - None) - if engine is None: - raise RuntimeError(f"Message from engine with unexpected data " - f"parallel rank: {eng_index}") - msg = msgspec.msgpack.decode(ready_msg_bytes) - status, local = msg["status"], msg["local"] - if local != engine.local: - raise RuntimeError(f"{status} message from " - f"{'local' if local else 'remote'} " - f"engine {eng_index}, expected it to be " - f"{'local' if engine.local else 'remote'}") + def _wait_for_engine_startup(self, handshake_socket: zmq.Socket, + input_address: str, output_address: str): + addresses = EngineZmqAddresses( + inputs=[input_address], + outputs=[output_address], + ) - if status == "HELLO" and engine.state == CoreEngineState.NEW: + coordinator = self.resources.coordinator + if coordinator is not None: + addresses.coordinator_input, addresses.coordinator_output = ( + coordinator.get_engine_socket_addresses()) - # Send init message with DP config info. - init_message = self.encoder.encode({ - "output_socket_address": output_address, - "parallel_config": { - "data_parallel_master_ip": - parallel_config.data_parallel_master_ip, - "data_parallel_master_port": - parallel_config.data_parallel_master_port, - "data_parallel_size": - parallel_config.data_parallel_size, - }, - }) - sync_input_socket.send_multipart((eng_identity, *init_message), - copy=False) - conn_pending[0 if local else 1] -= 1 - start_pending[0 if local else 1] += 1 - engine.state = CoreEngineState.CONNECTED - elif status == "READY" and (engine.state - == CoreEngineState.CONNECTED): - # Setup KV cache config with initialization state from - # engine core process. Sum values from all engines in DP case. 
- cache_config = self.vllm_config.cache_config - num_gpu_blocks = cache_config.num_gpu_blocks or 0 - num_gpu_blocks += msg['num_gpu_blocks'] - cache_config.num_gpu_blocks = num_gpu_blocks - - start_pending[0 if local else 1] -= 1 - engine.state = CoreEngineState.READY - else: - raise RuntimeError(f"Unexpected {status} message for " - f"{'local' if local else 'remote'} engine " - f"{eng_index} in {engine.state} state.") - - logger.debug("%s from %s core engine process %s.", status, - "local" if local else "remote", eng_index) + wait_for_engine_startup( + handshake_socket, + addresses, + self.core_engines, + self.vllm_config.parallel_config, + self.vllm_config.cache_config, + self.resources.local_engine_manager, + coordinator.proc if coordinator else None, + ) def shutdown(self): # Terminate background resources. @@ -604,8 +558,8 @@ class SyncMPClient(MPClient): try: shutdown_socket.bind(shutdown_path) poller = zmq.Poller() - poller.register(shutdown_socket) - poller.register(out_socket) + poller.register(shutdown_socket, zmq.POLLIN) + poller.register(out_socket, zmq.POLLIN) while True: socks = poller.poll() if not socks: @@ -667,7 +621,7 @@ class SyncMPClient(MPClient): future: Future[Any] = Future() self.utility_results[call_id] = future self._send_input(EngineCoreRequestType.UTILITY, - (call_id, method, args)) + (0, call_id, method, args)) return future.result() @@ -729,15 +683,21 @@ class SyncMPClient(MPClient): class AsyncMPClient(MPClient): """Asyncio-compatible client for multi-proc EngineCore.""" - def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor], - log_stats: bool): + def __init__(self, + vllm_config: VllmConfig, + executor_class: type[Executor], + log_stats: bool, + client_addresses: Optional[dict[str, str]] = None, + client_index: int = 0): super().__init__( asyncio_mode=True, vllm_config=vllm_config, executor_class=executor_class, log_stats=log_stats, + client_addresses=client_addresses, ) + self.client_index = client_index 
self.outputs_queue = asyncio.Queue[Union[EngineCoreOutputs, Exception]]() try: @@ -853,12 +813,13 @@ class AsyncMPClient(MPClient): future = asyncio.get_running_loop().create_future() self.utility_results[call_id] = future message = (EngineCoreRequestType.UTILITY.value, *self.encoder.encode( - (call_id, method, args))) + (self.client_index, call_id, method, args))) await self._send_input_message(message, engine, args) self._ensure_output_queue_task() return await future async def add_request_async(self, request: EngineCoreRequest) -> None: + request.client_index = self.client_index await self._send_input(EngineCoreRequestType.ADD, request) self._ensure_output_queue_task() @@ -920,17 +881,120 @@ class DPAsyncMPClient(AsyncMPClient): """Asyncio-compatible client for multi-proc, multi-engine (data parallel) EngineCore.""" - def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor], - log_stats: bool): + def __init__(self, + vllm_config: VllmConfig, + executor_class: type[Executor], + log_stats: bool, + client_addresses: Optional[dict[str, str]] = None, + client_index: int = 0): self.current_wave = 0 self.engines_running = False + # To route aborts to the correct engine. self.reqs_in_flight: dict[str, CoreEngine] = {} - super().__init__(vllm_config, executor_class, log_stats) + super().__init__(vllm_config, executor_class, log_stats, + client_addresses, client_index) assert len(self.core_engines) > 1 + # List of [waiting, running] pair per engine. + self.lb_engines: list[list[int]] = [] + + self.first_req_sock_addr = get_open_zmq_inproc_path() + self.first_req_send_socket = self.resources.first_req_send_socket = ( + make_zmq_socket(self.ctx, + self.first_req_sock_addr, + zmq.PAIR, + bind=True)) + try: + # If we are running in an asyncio event loop, start the stats task. + # Otherwise, it will be started lazily. 
+ asyncio.get_running_loop() + self._ensure_stats_update_task() + except RuntimeError: + pass + + def _ensure_stats_update_task(self): + resources = self.resources + if resources.stats_update_task is not None: + return + + assert self.stats_update_address is not None + + async def run_engine_stats_update_task(): + with make_zmq_socket(self.ctx, self.stats_update_address, + zmq.XSUB) as socket, make_zmq_socket( + self.ctx, + self.first_req_sock_addr, + zmq.PAIR, + bind=False) as first_req_rcv_socket: + # Send subscription message. + await socket.send(b'\x01') + + poller = zmq.asyncio.Poller() + poller.register(socket, zmq.POLLIN) + poller.register(first_req_rcv_socket, zmq.POLLIN) + + while True: + events = await poller.poll() + if not self.engines_running and len(events) == 2 or ( + events[0][0] == first_req_rcv_socket): + # Send a message to notify the coordinator that + # we're sending a request while the engines are + # paused, so that it can wake the others up + # (to run dummy EP loop). + self.engines_running = True + buf = first_req_rcv_socket.recv( + flags=zmq.NOBLOCK).result() + target_eng_index = int.from_bytes(buf, "little") + msg = msgspec.msgpack.encode( + (target_eng_index, self.current_wave)) + await socket.send(msg) + + buf = None + while True: + # Drain all stats events (we only care about latest). + future: asyncio.Future[bytes] = socket.recv( + flags=zmq.NOBLOCK) + if isinstance(future.exception(), zmq.Again): + break + buf = future.result() + if buf is None: + continue + + # Update local load-balancing state. 
+ counts, wave, running = msgspec.msgpack.decode(buf) + self.current_wave = wave + self.engines_running = running + self.lb_engines = counts + + resources.stats_update_task = asyncio.create_task( + run_engine_stats_update_task()) + + def get_core_engine_for_request(self) -> CoreEngine: + if not self.lb_engines: + return self.core_engines[0] + # TODO use P2C alg for larger DP sizes + num_engines = len(self.lb_engines) + min_counts = [sys.maxsize, sys.maxsize] + eng_index = 0 + for i in range(num_engines): + # Start from client_index to help with balancing when engines + # are empty. + idx = (self.client_index + i) % num_engines + counts = self.lb_engines[idx] + if counts < min_counts: + min_counts = counts + eng_index = idx + # Adjust local counts for better balancing between stats updates + # from the coordinator (which happen every 100ms). + if min_counts[0]: + min_counts[0] += 1 + else: + min_counts[1] += 1 + return self.core_engines[eng_index] + async def call_utility_async(self, method: str, *args) -> Any: # Only the result from the first engine is returned. return (await asyncio.gather(*[ @@ -939,62 +1003,30 @@ class DPAsyncMPClient(AsyncMPClient): ]))[0] async def add_request_async(self, request: EngineCoreRequest) -> None: + self._ensure_stats_update_task() + request.current_wave = self.current_wave + request.client_index = self.client_index chosen_engine = self.get_core_engine_for_request() self.reqs_in_flight[request.request_id] = chosen_engine - chosen_engine.num_reqs_in_flight += 1 to_await = self._send_input(EngineCoreRequestType.ADD, request, chosen_engine) if not self.engines_running: - # Send request to chosen engine and dp start loop - # control message to all other engines. 
- self.engines_running = True - to_await = asyncio.gather( - to_await, # type: ignore[assignment] - *self._start_wave_coros(exclude_index=chosen_engine.index)) + # Notify coordinator that we're sending a request + await self.first_req_send_socket.send(chosen_engine.identity) await to_await self._ensure_output_queue_task() - def get_core_engine_for_request(self) -> CoreEngine: - return min(self.core_engines, key=lambda e: e.num_reqs_in_flight) - @staticmethod async def process_engine_outputs(self: "DPAsyncMPClient", outputs: EngineCoreOutputs): - if self.reqs_in_flight: - for req_id in outputs.finished_requests or (): - if engine := self.reqs_in_flight.pop(req_id, None): - engine.num_reqs_in_flight -= 1 - - if outputs.wave_complete is not None: - # Current wave is complete, move to next wave number - # and mark engines as paused. - if self.current_wave <= outputs.wave_complete: - self.current_wave = outputs.wave_complete + 1 - self.engines_running = False - - elif outputs.start_wave is not None and ( - outputs.start_wave > self.current_wave or - (outputs.start_wave == self.current_wave - and not self.engines_running)): - # Engine received request for a non-current wave so we must ensure - # that other engines progress to the next wave. 
- self.current_wave = outputs.start_wave - self.engines_running = True - await asyncio.gather(*self._start_wave_coros( - exclude_index=outputs.engine_index)) - - def _start_wave_coros(self, exclude_index: int) -> list[Awaitable[None]]: - logger.debug("Sending start DP wave %d.", self.current_wave) - return [ - self._send_input(EngineCoreRequestType.START_DP_WAVE, - self.current_wave, engine) - for engine in self.core_engines if engine.index != exclude_index - ] + if outputs.finished_requests and self.reqs_in_flight: + for req_id in outputs.finished_requests: + self.reqs_in_flight.pop(req_id, None) async def abort_requests_async(self, request_ids: list[str]) -> None: if not request_ids: diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 3dc2f77444f63..665e5873d5891 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -12,13 +12,12 @@ from vllm.config import SupportsMetricsInfo, VllmConfig from vllm.logger import init_logger from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics from vllm.v1.engine import FinishReason +from vllm.v1.metrics.prometheus import unregister_vllm_metrics from vllm.v1.metrics.stats import IterationStats, SchedulerStats from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm logger = init_logger(__name__) -_LOCAL_LOGGING_INTERVAL_SEC = 5.0 - StatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"] @@ -35,7 +34,7 @@ class StatLoggerBase(ABC): ... @abstractmethod - def record(self, scheduler_stats: SchedulerStats, + def record(self, scheduler_stats: Optional[SchedulerStats], iteration_stats: Optional[IterationStats]): ... 
@@ -78,20 +77,22 @@ class LoggingStatLogger(StatLoggerBase): # Compute summary metrics for tracked stats return float(np.sum(tracked_stats) / (now - self.last_log_time)) - def record(self, scheduler_stats: SchedulerStats, + def record(self, scheduler_stats: Optional[SchedulerStats], iteration_stats: Optional[IterationStats]): """Log Stats to standard output.""" if iteration_stats: self._track_iteration_stats(iteration_stats) - self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats) + if scheduler_stats is not None: + self.prefix_caching_metrics.observe( + scheduler_stats.prefix_cache_stats) - if scheduler_stats.spec_decoding_stats is not None: - self.spec_decoding_logging.observe( - scheduler_stats.spec_decoding_stats) + if scheduler_stats.spec_decoding_stats is not None: + self.spec_decoding_logging.observe( + scheduler_stats.spec_decoding_stats) - self.last_scheduler_stats = scheduler_stats + self.last_scheduler_stats = scheduler_stats def log(self): now = time.monotonic() @@ -131,10 +132,11 @@ class LoggingStatLogger(StatLoggerBase): self.spec_decoding_logging.log(log_fn=log_fn) def log_engine_initialized(self): - logger.info( - "vllm cache_config_info with initialization " \ - "after num_gpu_blocks is: %d", - self.vllm_config.cache_config.num_gpu_blocks) + if self.vllm_config.cache_config.num_gpu_blocks: + logger.info( + "Engine %03d: vllm cache_config_info with initialization " + "after num_gpu_blocks is: %d", self.engine_index, + self.vllm_config.cache_config.num_gpu_blocks) class PrometheusStatLogger(StatLoggerBase): @@ -144,7 +146,8 @@ class PrometheusStatLogger(StatLoggerBase): _spec_decoding_cls = SpecDecodingProm def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): - self._unregister_vllm_metrics() + + unregister_vllm_metrics() self.vllm_config = vllm_config self.engine_index = engine_index # Use this flag to hide metrics that were deprecated in @@ -169,11 +172,13 @@ class PrometheusStatLogger(StatLoggerBase): 
self.gauge_scheduler_running = self._gauge_cls( name="vllm:num_requests_running", documentation="Number of requests in model execution batches.", + multiprocess_mode="mostrecent", labelnames=labelnames).labels(*labelvalues) self.gauge_scheduler_waiting = self._gauge_cls( name="vllm:num_requests_waiting", documentation="Number of requests waiting to be processed.", + multiprocess_mode="mostrecent", labelnames=labelnames).labels(*labelvalues) # @@ -182,6 +187,7 @@ class PrometheusStatLogger(StatLoggerBase): self.gauge_gpu_cache_usage = self._gauge_cls( name="vllm:gpu_cache_usage_perc", documentation="GPU KV-cache usage. 1 means 100 percent usage.", + multiprocess_mode="mostrecent", labelnames=labelnames).labels(*labelvalues) self.counter_gpu_prefix_cache_queries = self._counter_cls( @@ -242,6 +248,9 @@ class PrometheusStatLogger(StatLoggerBase): buckets=build_1_2_5_buckets(max_model_len), labelnames=labelnames).labels(*labelvalues) + # TODO: This metric might be incorrect in case of using multiple + # api_server counts which uses prometheus mp. + # See: https://github.com/vllm-project/vllm/pull/18053 self.histogram_iteration_tokens = \ self._histogram_cls( name="vllm:iteration_tokens_total", @@ -340,6 +349,9 @@ class PrometheusStatLogger(StatLoggerBase): # # LoRA metrics # + + # TODO: This metric might be incorrect in case of using multiple + # api_server counts which uses prometheus mp. 
self.gauge_lora_info: Optional[prometheus_client.Gauge] = None if vllm_config.lora_config is not None: self.labelname_max_lora = "max_lora" @@ -350,13 +362,16 @@ class PrometheusStatLogger(StatLoggerBase): self._gauge_cls( name="vllm:lora_requests_info", documentation="Running stats on lora requests.", + multiprocess_mode="sum", labelnames=[ self.labelname_max_lora, self.labelname_waiting_lora_adapters, self.labelname_running_lora_adapters, - ]) + ], + ) def log_metrics_info(self, type: str, config_obj: SupportsMetricsInfo): + metrics_info = config_obj.metrics_info() metrics_info["engine"] = self.engine_index @@ -372,25 +387,28 @@ class PrometheusStatLogger(StatLoggerBase): info_gauge = self._gauge_cls( name=name, documentation=documentation, - labelnames=metrics_info.keys()).labels(**metrics_info) + multiprocess_mode="mostrecent", + labelnames=metrics_info.keys(), + ).labels(**metrics_info) info_gauge.set(1) - def record(self, scheduler_stats: SchedulerStats, + def record(self, scheduler_stats: Optional[SchedulerStats], iteration_stats: Optional[IterationStats]): """Log to prometheus.""" - self.gauge_scheduler_running.set(scheduler_stats.num_running_reqs) - self.gauge_scheduler_waiting.set(scheduler_stats.num_waiting_reqs) + if scheduler_stats is not None: + self.gauge_scheduler_running.set(scheduler_stats.num_running_reqs) + self.gauge_scheduler_waiting.set(scheduler_stats.num_waiting_reqs) - self.gauge_gpu_cache_usage.set(scheduler_stats.gpu_cache_usage) + self.gauge_gpu_cache_usage.set(scheduler_stats.gpu_cache_usage) - self.counter_gpu_prefix_cache_queries.inc( - scheduler_stats.prefix_cache_stats.queries) - self.counter_gpu_prefix_cache_hits.inc( - scheduler_stats.prefix_cache_stats.hits) + self.counter_gpu_prefix_cache_queries.inc( + scheduler_stats.prefix_cache_stats.queries) + self.counter_gpu_prefix_cache_hits.inc( + scheduler_stats.prefix_cache_stats.hits) - if scheduler_stats.spec_decoding_stats is not None: - self.spec_decoding_prom.observe( - 
scheduler_stats.spec_decoding_stats) + if scheduler_stats.spec_decoding_stats is not None: + self.spec_decoding_prom.observe( + scheduler_stats.spec_decoding_stats) if iteration_stats is None: return @@ -445,13 +463,6 @@ class PrometheusStatLogger(StatLoggerBase): self.gauge_lora_info.labels(**lora_info_labels)\ .set_to_current_time() - @staticmethod - def _unregister_vllm_metrics(): - # Unregister any existing vLLM collectors (for CI/CD - for collector in list(prometheus_client.REGISTRY._collector_to_names): - if hasattr(collector, "_name") and "vllm" in collector._name: - prometheus_client.REGISTRY.unregister(collector) - def log_engine_initialized(self): self.log_metrics_info("cache_config", self.vllm_config.cache_config) diff --git a/vllm/v1/metrics/prometheus.py b/vllm/v1/metrics/prometheus.py new file mode 100644 index 0000000000000..a364b286d21b9 --- /dev/null +++ b/vllm/v1/metrics/prometheus.py @@ -0,0 +1,81 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os +import tempfile +from typing import Optional + +from prometheus_client import REGISTRY, CollectorRegistry, multiprocess + +from vllm.logger import init_logger + +logger = init_logger(__name__) + +# Global temporary directory for prometheus multiprocessing +_prometheus_multiproc_dir: Optional[tempfile.TemporaryDirectory] = None + + +def setup_multiprocess_prometheus(): + """Set up prometheus multiprocessing directory if not already configured. + + """ + global _prometheus_multiproc_dir + + if "PROMETHEUS_MULTIPROC_DIR" not in os.environ: + # Make TemporaryDirectory for prometheus multiprocessing + # Note: global TemporaryDirectory will be automatically + # cleaned up upon exit. + _prometheus_multiproc_dir = tempfile.TemporaryDirectory() + os.environ["PROMETHEUS_MULTIPROC_DIR"] = _prometheus_multiproc_dir.name + logger.debug("Created PROMETHEUS_MULTIPROC_DIR at %s", + _prometheus_multiproc_dir.name) + else: + logger.warning("Found PROMETHEUS_MULTIPROC_DIR was set by user. 
" + "This directory must be wiped between vLLM runs or " + "you will find inaccurate metrics. Unset the variable " + "and vLLM will properly handle cleanup.") + + +def get_prometheus_registry(): + """Get the appropriate prometheus registry based on multiprocessing + configuration. + + Returns: + Registry: A prometheus registry + """ + if os.getenv("PROMETHEUS_MULTIPROC_DIR") is not None: + logger.debug("Using multiprocess registry for prometheus metrics") + registry = CollectorRegistry() + multiprocess.MultiProcessCollector(registry) + return registry + + return REGISTRY + + +def unregister_vllm_metrics(): + """Unregister any existing vLLM collectors from the prometheus registry. + + This is useful for testing and CI/CD where metrics may be registered + multiple times across test runs. + + Also, in case of multiprocess, we need to unregister the metrics from the + global registry. + """ + registry = REGISTRY + # Unregister any existing vLLM collectors + for collector in list(registry._collector_to_names): + if hasattr(collector, "_name") and "vllm" in collector._name: + registry.unregister(collector) + + +def shutdown_prometheus(): + """Shutdown prometheus metrics.""" + + path = _prometheus_multiproc_dir + if path is None: + return + try: + pid = os.getpid() + multiprocess.mark_process_dead(pid, path) + logger.debug("Marked Prometheus metrics for process %d as dead", pid) + except Exception as e: + logger.error("Error during metrics cleanup: %s", str(e)) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index b4c84507532a1..42c75ef964016 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -26,12 +26,13 @@ class Request: multi_modal_placeholders: Optional[list[PlaceholderRange]], sampling_params: SamplingParams, eos_token_id: Optional[int], - arrival_time: float, + client_index: int = 0, lora_request: Optional["LoRARequest"] = None, structured_output_request: Optional["StructuredOutputRequest"] = None, cache_salt: Optional[str] = None, ) -> None: 
self.request_id = request_id + self.client_index = client_index self.sampling_params = sampling_params # Because of LoRA, the eos token id can be different for each request. self.eos_token_id = eos_token_id @@ -90,13 +91,13 @@ class Request: return cls( request_id=request.request_id, + client_index=request.client_index, prompt_token_ids=request.prompt_token_ids, multi_modal_inputs=request.mm_inputs, multi_modal_hashes=request.mm_hashes, multi_modal_placeholders=request.mm_placeholders, sampling_params=request.sampling_params, eos_token_id=request.eos_token_id, - arrival_time=request.arrival_time, lora_request=request.lora_request, structured_output_request=StructuredOutputRequest( sampling_params=request.sampling_params), diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index fbd38fc472031..78f37c1e8b218 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -158,8 +158,8 @@ class MsgpackEncoder: self, obj: torch.Tensor ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]: assert self.aux_buffers is not None - # view the tensor as a 1D array of bytes - arr = obj.flatten().view(torch.uint8).numpy() + # view the tensor as a contiguous 1D array of bytes + arr = obj.flatten().contiguous().view(torch.uint8).numpy() if obj.nbytes < self.size_threshold: # Smaller tensors are encoded inline, just like ndarrays. data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data) diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index c701ab1d35a58..07b422814e13a 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -149,31 +149,37 @@ class StructuredOutputManager: # NOTE: This outer loop can likely be parallelized to improve # performance of bitmask generation for large batches. 
for req_id, _ in ordered_seq: - request = requests[req_id].structured_output_request - if TYPE_CHECKING: - assert request is not None - assert request.grammar is not None + request = requests[req_id] + structured_output_request = request.structured_output_request - apply_bitmask = ( - request.reasoning_ended if self.reasoner is not None else True - ) # noqa: E501 + if TYPE_CHECKING: + assert structured_output_request is not None + assert structured_output_request.grammar is not None + apply_bitmask: bool = True + if self.reasoner is not None: + if structured_output_request.reasoning_ended is None: + structured_output_request.reasoning_ended = \ + self.reasoner.is_reasoning_end(request.prompt_token_ids) + apply_bitmask = structured_output_request.reasoning_ended state_advancements = 0 req_tokens = scheduled_spec_decode_tokens.get(req_id, []) + [None] for i, token in enumerate(req_tokens): - if apply_bitmask and not request.grammar.is_terminated(): - request.grammar.fill_bitmask(bitmask_tensor, - cumulative_index) + if apply_bitmask and not \ + structured_output_request.grammar.is_terminated(): + structured_output_request.grammar.fill_bitmask( + bitmask_tensor, cumulative_index) if token is not None: # In order to generate the correct bitmask for each # position in the speculative sequence, we advance # the FSM state for each speculative token and rollback # to restore the previous state when we are finished. 
- assert request.grammar.accept_tokens(req_id, [token]) + assert structured_output_request.grammar.accept_tokens( + req_id, [token]) state_advancements += 1 cumulative_index += 1 if state_advancements > 0: - request.grammar.rollback(state_advancements) + structured_output_request.grammar.rollback(state_advancements) if cumulative_index < bitmask_tensor.shape[0]: bitmask_tensor = bitmask_tensor[:cumulative_index] diff --git a/vllm/v1/structured_output/request.py b/vllm/v1/structured_output/request.py index c16320b9e74c6..9a7e30d41aaa8 100644 --- a/vllm/v1/structured_output/request.py +++ b/vllm/v1/structured_output/request.py @@ -20,7 +20,7 @@ class StructuredOutputRequest: sampling_params: SamplingParams _grammar: Optional[Union[Future[StructuredOutputGrammar], StructuredOutputGrammar]] = None - reasoning_ended: bool = False + reasoning_ended: Optional[bool] = None def _check_grammar_completion(self) -> bool: # NOTE: We have to lazy import to gate circular imports diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 0758747a83cc6..a26794561a526 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -1,31 +1,41 @@ # SPDX-License-Identifier: Apache-2.0 -import os +import argparse +import multiprocessing import time import weakref from collections import defaultdict from collections.abc import Sequence +from dataclasses import dataclass +from enum import Enum, auto from multiprocessing import Process, connection -from typing import (TYPE_CHECKING, Callable, Generic, Optional, TypeVar, Union, - overload) +from multiprocessing.process import BaseProcess +from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar, + Union, overload) +import msgspec import torch +import zmq -from vllm.config import VllmConfig +from vllm.config import CacheConfig, ParallelConfig, VllmConfig from vllm.logger import init_logger from vllm.model_executor.models.utils import extract_layer_index from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, 
usage_message) -from vllm.utils import get_mp_context, kill_process_tree +from vllm.utils import (get_mp_context, get_open_port, get_open_zmq_ipc_path, + get_tcp_uri, kill_process_tree) from vllm.v1.executor.abstract import Executor if TYPE_CHECKING: from vllm.attention.layer import Attention + from vllm.v1.engine.coordinator import DPCoordinator logger = init_logger(__name__) T = TypeVar("T") +STARTUP_POLL_PERIOD_MS = 10000 + class ConstantList(Generic[T], Sequence): @@ -95,6 +105,78 @@ class ConstantList(Generic[T], Sequence): return f"ConstantList({self._x})" +def get_engine_client_zmq_addr(local_only: bool, + host: str, + port: int = 0) -> str: + return get_open_zmq_ipc_path() if local_only else (get_tcp_uri( + host, port or get_open_port())) + + +class APIServerProcessManager: + """Manages a group of API server processes. + + Handles creation, monitoring, and termination of API server worker + processes. Also monitors extra processes to check if they are healthy. + """ + + def __init__( + self, + target_server_fn: Callable, + listen_address: str, + sock: Any, + args: argparse.Namespace, + num_servers: int, + input_addresses: list[str], + output_addresses: list[str], + stats_update_address: Optional[str] = None, + ): + """Initialize and start API server worker processes. 
+ + Args: + target_server_fn: Function to call for each API server process + listen_address: Address to listen for client connections + sock: Socket for client connections + args: Command line arguments + num_servers: Number of API server processes to start + input_addresses: Input addresses for each API server + output_addresses: Output addresses for each API server + stats_update_address: Optional stats update address + """ + self.listen_address = listen_address + self.sock = sock + self.args = args + + # Start API servers + spawn_context = multiprocessing.get_context("spawn") + self.processes: list[BaseProcess] = [] + + for i, in_addr, out_addr in zip(range(num_servers), input_addresses, + output_addresses): + client_config = { + "input_address": in_addr, + "output_address": out_addr, + "client_index": i + } + if stats_update_address is not None: + client_config["stats_update_address"] = stats_update_address + + proc = spawn_context.Process(target=target_server_fn, + name=f"ApiServer_{i}", + args=(listen_address, sock, args, + client_config)) + self.processes.append(proc) + proc.start() + + logger.info("Started %d API server processes", len(self.processes)) + + # Shutdown only the API server processes on garbage collection + # The extra processes are managed by their owners + self._finalizer = weakref.finalize(self, shutdown, self.processes) + + def close(self) -> None: + self._finalizer() + + class CoreEngineProcManager: """ Utility class to handle creation, readiness, and shutdown @@ -109,7 +191,7 @@ class CoreEngineProcManager: local_start_index: int, vllm_config: VllmConfig, on_head_node: bool, - input_address: str, + handshake_address: str, executor_class: type[Executor], log_stats: bool, ): @@ -117,12 +199,12 @@ class CoreEngineProcManager: common_kwargs = { "vllm_config": vllm_config, "on_head_node": on_head_node, - "input_address": input_address, + "handshake_address": handshake_address, "executor_class": executor_class, "log_stats": log_stats, } - 
self.processes: list[Process] = [] + self.processes: list[BaseProcess] = [] for index in range(local_engine_count): local_index = local_start_index + index global_index = start_index + index @@ -135,8 +217,7 @@ class CoreEngineProcManager: "local_dp_rank": local_index, })) - self._finalizer = weakref.finalize(self, shutdown, self.processes, - input_address) + self._finalizer = weakref.finalize(self, shutdown, self.processes) try: for proc in self.processes: proc.start() @@ -164,9 +245,199 @@ class CoreEngineProcManager: } +class CoreEngineState(Enum): + NEW = auto() + CONNECTED = auto() + READY = auto() + + +class CoreEngine: + """One per data parallel rank.""" + + def __init__(self, index: int = 0, local: bool = True): + self.local = local + self.index = index + self.identity = index.to_bytes(2, "little") + + self.state = CoreEngineState.NEW + + +@dataclass +class EngineZmqAddresses: + # ZMQ input socket addresses for each front-end client (requests) + inputs: list[str] + # ZMQ output socket addresses for each front-end client (responses) + outputs: list[str] + # ZMQ input socket address of DP coordinator if applicable + coordinator_input: Optional[str] = None + # ZMQ output socket address of DP coordinator if applicable + coordinator_output: Optional[str] = None + + +@dataclass +class EngineHandshakeMetadata: + """Metadata sent to each engine process during startup handshake, + including addresses of the front-end ZMQ queues that they should + connect to. + """ + addresses: EngineZmqAddresses + parallel_config: dict[str, Union[int, str]] + + +def wait_for_engine_startup( + handshake_socket: zmq.Socket, + addresses: EngineZmqAddresses, + core_engines: list[CoreEngine], + parallel_config: ParallelConfig, + cache_config: CacheConfig, + proc_manager: Optional[CoreEngineProcManager], + coord_process: Optional[Process], +): + + # Wait for engine core process(es) to send ready messages. 
+ local_count = parallel_config.data_parallel_size_local + remote_count = len(core_engines) - local_count + # [local, remote] counts + conn_pending, start_pending = [local_count, remote_count], [0, 0] + poller = zmq.Poller() + poller.register(handshake_socket, zmq.POLLIN) + + if proc_manager is not None: + for sentinel in proc_manager.sentinels(): + poller.register(sentinel, zmq.POLLIN) + if coord_process is not None: + poller.register(coord_process.sentinel, zmq.POLLIN) + while any(conn_pending) or any(start_pending): + events = poller.poll(STARTUP_POLL_PERIOD_MS) + if not events: + if any(conn_pending): + logger.debug( + "Waiting for %d local, %d remote core engine proc(s) " + "to connect.", *conn_pending) + if any(start_pending): + logger.debug( + "Waiting for %d local, %d remote core engine proc(s) " + "to start.", *start_pending) + continue + if len(events) > 1 or events[0][0] != handshake_socket: + # One of the local core processes exited. + finished = proc_manager.finished_procs() if proc_manager else {} + if coord_process is not None and coord_process.exitcode is not None: + finished[coord_process.name] = coord_process.exitcode + raise RuntimeError("Engine core initialization failed. " + "See root cause above. " + f"Failed core proc(s): {finished}") + + # Receive HELLO and READY messages from the input socket. 
+ eng_identity, ready_msg_bytes = handshake_socket.recv_multipart() + eng_index = int.from_bytes(eng_identity, "little") + engine = next((e for e in core_engines if e.identity == eng_identity), + None) + if engine is None: + raise RuntimeError(f"Message from engine with unexpected data " + f"parallel rank: {eng_index}") + msg = msgspec.msgpack.decode(ready_msg_bytes) + status, local = msg["status"], msg["local"] + if local != engine.local: + raise RuntimeError(f"{status} message from " + f"{'local' if local else 'remote'} " + f"engine {eng_index}, expected it to be " + f"{'local' if engine.local else 'remote'}") + + if status == "HELLO" and engine.state == CoreEngineState.NEW: + + # Send init message with DP config info. + init_message = msgspec.msgpack.encode( + EngineHandshakeMetadata( + addresses=addresses, + parallel_config={ + "data_parallel_master_ip": + parallel_config.data_parallel_master_ip, + "data_parallel_master_port": + parallel_config.data_parallel_master_port, + "data_parallel_size": + parallel_config.data_parallel_size, + })) + handshake_socket.send_multipart((eng_identity, init_message), + copy=False) + conn_pending[0 if local else 1] -= 1 + start_pending[0 if local else 1] += 1 + engine.state = CoreEngineState.CONNECTED + elif status == "READY" and (engine.state == CoreEngineState.CONNECTED): + # Setup KV cache config with initialization state from + # engine core process. Sum values from all engines in DP case. 
+ num_gpu_blocks = cache_config.num_gpu_blocks or 0 + num_gpu_blocks += msg["num_gpu_blocks"] + cache_config.num_gpu_blocks = num_gpu_blocks + + start_pending[0 if local else 1] -= 1 + engine.state = CoreEngineState.READY + else: + raise RuntimeError(f"Unexpected {status} message for " + f"{'local' if local else 'remote'} engine " + f"{eng_index} in {engine.state} state.") + + logger.debug("%s from %s core engine process %s.", status, + "local" if local else "remote", eng_index) + + +def wait_for_completion_or_failure( + api_server_manager: APIServerProcessManager, + local_engine_manager: Optional[CoreEngineProcManager] = None, + coordinator: Optional["DPCoordinator"] = None) -> None: + """Wait for all processes to complete or detect if any fail. + + Raises an exception if any process exits with a non-zero status. + """ + + try: + logger.info("Waiting for API servers to complete ...") + # Create a mapping of sentinels to their corresponding processes + # for efficient lookup + sentinel_to_proc: dict[Any, BaseProcess] = { + proc.sentinel: proc + for proc in api_server_manager.processes + } + + if coordinator: + sentinel_to_proc[coordinator.proc.sentinel] = coordinator.proc + + if local_engine_manager: + for proc in local_engine_manager.processes: + sentinel_to_proc[proc.sentinel] = proc + + # Check if any process terminates + while sentinel_to_proc: + # Wait for any process to terminate + ready_sentinels: list[Any] = connection.wait(sentinel_to_proc) + + # Process any terminated processes + for sentinel in ready_sentinels: + proc = sentinel_to_proc.pop(sentinel) + + # Check if process exited with error + if proc.exitcode != 0: + raise RuntimeError( + f"Process {proc.name} (PID: {proc.pid}) " + f"died with exit code {proc.exitcode}") + except KeyboardInterrupt: + logger.info("Received KeyboardInterrupt, shutting down API servers...") + except Exception as e: + logger.exception("Exception occurred while running API servers: %s", + str(e)) + raise + finally: + 
logger.info("Terminating remaining processes ...") + api_server_manager.close() + if coordinator: + coordinator.close() + if local_engine_manager: + local_engine_manager.close() + + # Note(rob): shutdown function cannot be a bound method, -# else the gc cannot collect the objedecoupct. -def shutdown(procs: list[Process], input_address: str): +# else the gc cannot collect the object. +def shutdown(procs: list[BaseProcess]): # Shutdown the process. for proc in procs: if proc.is_alive(): @@ -185,12 +456,6 @@ def shutdown(procs: list[Process], input_address: str): if proc.is_alive() and (pid := proc.pid) is not None: kill_process_tree(pid) - # Remove zmq ipc socket files. - if input_address.startswith("ipc://"): - socket_file = input_address[len("ipc://"):] - if os and os.path.exists(socket_file): - os.remove(socket_file) - def bind_kv_cache( kv_caches: dict[str, torch.Tensor], diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d1195bcfb27b9..9f7c474c71cbc 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -28,7 +28,7 @@ from vllm.forward_context import (DPMetadata, get_forward_context, set_forward_context) from vllm.logger import init_logger from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding -from vllm.model_executor.model_loader import TensorizerLoader, get_model +from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.utils import group_mm_inputs_by_modality @@ -500,6 +500,26 @@ class GPUModelRunner(LoRAModelRunnerMixin): if batch_changed or batch_reordered: self.input_batch.refresh_sampling_metadata() + def _get_cumsum_and_arange( + self, + num_tokens: np.ndarray, + cumsum_dtype: Optional[np.dtype] = None, + ) -> tuple[np.ndarray, np.ndarray]: + """Get the cumulative sum and batched arange of the given 
array. + # E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]) + # Equivalent to but faster than: + # np.concatenate([np.arange(n) for n in num_tokens]) + """ + # Step 1. [2, 5, 3] -> [2, 7, 10] + cu_num_tokens = np.cumsum(num_tokens, dtype=cumsum_dtype) + total_num_tokens = cu_num_tokens[-1] + # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7] + cumsums_offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens) + # Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + arange = self.arange_np[:total_num_tokens] - cumsums_offsets + + return cu_num_tokens, arange + def _prepare_inputs( self, scheduler_output: "SchedulerOutput", @@ -525,17 +545,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens) - # Get batched arange. - # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - # Equivalent to but faster than: - # np.concatenate([np.arange(n) for n in num_scheduled_tokens]) - # Step 1. [2, 5, 3] -> [2, 7, 10] - cu_num_tokens = np.cumsum(num_scheduled_tokens) - # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7] - cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens, - num_scheduled_tokens) - # Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets + # cu_num_tokens: [2, 5, 3] -> [2, 7, 10] + # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + cu_num_tokens, arange = self._get_cumsum_and_arange( + num_scheduled_tokens) # Get positions. positions_np = self.positions_np[:total_num_scheduled_tokens] @@ -841,32 +854,25 @@ class GPUModelRunner(LoRAModelRunnerMixin): # Compute the logits indices. # [4, 1, 3, 1, 2] num_sampled_tokens = num_draft_tokens + 1 - # Step 1. [4, 5, 8, 9, 11] - cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32) - total_num_sampled_tokens = cu_num_sampled_tokens[-1] - # Step 2. 
[0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9] - cumsums_offsets = np.repeat(cu_num_sampled_tokens - num_sampled_tokens, - num_sampled_tokens) - # Step 3. [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1] - arange = self.arange_np[:total_num_sampled_tokens] - cumsums_offsets - # Step 4. [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] + + # Step 1. cu_num_sampled_tokens: [4, 5, 8, 9, 11] + # arange: [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1] + cu_num_sampled_tokens, arange = self._get_cumsum_and_arange( + num_sampled_tokens, cumsum_dtype=np.int32) + # Step 2. [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] logits_indices = np.repeat( cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens) - # Step 5. [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208] + # Step 3. [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208] logits_indices += arange # Compute the bonus logits indices. bonus_logits_indices = cu_num_sampled_tokens - 1 # Compute the draft logits indices. - # [3, 3, 5, 5, 6] - cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32) - total_num_draft_tokens = cu_num_draft_tokens[-1] - # [0, 0, 0, 3, 3, 5] - cumsums_offsets = np.repeat(cu_num_draft_tokens - num_draft_tokens, - num_draft_tokens) - # [0, 1, 2, 0, 1, 0] - arange = self.arange_np[:total_num_draft_tokens] - cumsums_offsets + # cu_num_draft_tokens: [3, 3, 5, 5, 6] + # arange: [0, 1, 2, 0, 1, 0] + cu_num_draft_tokens, arange = self._get_cumsum_and_arange( + num_draft_tokens, cumsum_dtype=np.int32) # [0, 0, 0, 5, 5, 9] target_logits_indices = np.repeat( cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens) @@ -1105,17 +1111,30 @@ class GPUModelRunner(LoRAModelRunnerMixin): for k, v in self.intermediate_tensors.items() }) - def get_dp_padding(self, num_tokens: int): + def get_dp_padding(self, + num_tokens: int) -> tuple[int, Optional[torch.Tensor]]: dp_size = self.vllm_config.parallel_config.data_parallel_size dp_rank = self.vllm_config.parallel_config.data_parallel_rank - if dp_size == 1: + + # For DP: Don't pad when setting 
enforce_eager. + # This lets us set enforce_eager on the prefiller in a P/D setup and + # still use CUDA graphs (enabled by this padding) on the decoder. + # + # TODO(tms) : There are many cases where padding is enabled for + # prefills, causing unnecessary and excessive padding of activations. + + if dp_size == 1 or self.vllm_config.model_config.enforce_eager: # Early exit. - return 0 + return 0, None num_tokens_across_dp = DPMetadata.num_tokens_across_dp( num_tokens, dp_size, dp_rank) max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item() - return max_tokens_across_dp_cpu - num_tokens + num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] * + dp_size, + device="cpu", + dtype=torch.int32) + return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding @torch.inference_mode() def execute_model( @@ -1155,7 +1174,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): num_input_tokens = num_scheduled_tokens # Padding for DP - num_input_tokens += self.get_dp_padding(num_input_tokens) + num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens) + num_input_tokens += num_pad # _prepare_inputs may reorder the batch, so we must gather multi # modal outputs after that to ensure the correct order @@ -1202,7 +1222,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): # Use persistent buffers for CUDA graphs. 
with set_forward_context(attn_metadata, self.vllm_config, - num_tokens=num_input_tokens): + num_tokens=num_input_tokens, + num_tokens_across_dp=num_tokens_across_dp): self.maybe_setup_kv_connector(scheduler_output) model_output = self.model( @@ -1543,7 +1564,18 @@ class GPUModelRunner(LoRAModelRunnerMixin): logger.info("Starting to load model %s...", self.model_config.model) with DeviceMemoryProfiler() as m: # noqa: SIM117 time_before_load = time.perf_counter() - self.model = get_model(vllm_config=self.vllm_config) + model_loader = get_model_loader(self.load_config) + if not hasattr(self, "model"): + logger.info("Loading model from scratch...") + self.model = model_loader.load_model( + vllm_config=self.vllm_config, + model_config=self.model_config) + else: + logger.info( + "Model was already initialized. Loading weights inplace..." + ) + model_loader.load_weights(self.model, + model_config=self.model_config) if self.lora_config: self.model = self.load_lora_model(self.model, self.model_config, @@ -1675,7 +1707,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ) -> torch.Tensor: # Padding for DP - num_tokens += self.get_dp_padding(num_tokens) + num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) + num_tokens += num_pad # Set num_scheduled_tokens based on num_tokens and max_num_seqs # for dummy run with LoRA so that the num_reqs collectively @@ -1741,9 +1774,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): intermediate_tensors = self.sync_and_slice_intermediate_tensors( num_tokens, None, False) - with set_forward_context(attn_metadata, - self.vllm_config, - num_tokens=num_tokens): + with set_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_tokens, + num_tokens_across_dp=num_tokens_across_dp): outputs = model( input_ids=input_ids, positions=positions, @@ -2033,9 +2068,29 @@ class GPUModelRunner(LoRAModelRunnerMixin): num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) dtype = kv_cache_spec.dtype - 
kv_caches[layer_name] = torch.zeros(kv_cache_shape, - dtype=dtype, - device=self.device) + try: + kv_cache_stride_order = self.attn_backends[ + i].get_kv_cache_stride_order() + assert len(kv_cache_stride_order) == len( + kv_cache_shape) + except (AttributeError, NotImplementedError): + kv_cache_stride_order = tuple( + range(len(kv_cache_shape))) + # The allocation respects the backend-defined stride order + # to ensure the semantic remains consistent for each + # backend. We first obtain the generic kv cache shape and + # then permute it according to the stride order which could + # result in a non-contiguous tensor. + kv_cache_shape = tuple(kv_cache_shape[i] + for i in kv_cache_stride_order) + # Maintain original KV shape view. + inv_order = [ + kv_cache_stride_order.index(i) + for i in range(len(kv_cache_stride_order)) + ] + kv_caches[layer_name] = torch.zeros( + kv_cache_shape, dtype=dtype, + device=self.device).permute(*inv_order) else: # TODO: add new branches when introducing more types of # KV cache specs. 
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 669908cb577bf..5de92351e24ba 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -21,7 +21,7 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.lora.layers import BaseLayerWithLoRA -from vllm.model_executor.model_loader import get_model +from vllm.model_executor.model_loader import get_model_loader from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs, PlaceholderRange) @@ -171,15 +171,25 @@ class TPUModelRunner(LoRAModelRunnerMixin): self.encoder_cache_size = encoder_cache_size # Lazy initialization - # self.model: nn.Module # Set after load_model + self.model: nn.Module # Set after load_model self.kv_caches: list[torch.Tensor] = [] # req_id -> (input_id -> encoder_output) self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {} - # self.input_batch: InputBatch # Persistent batch. # Request states. self.requests: dict[str, CachedRequestState] = {} + # Initialize input batch early to avoid AttributeError in _update_states + self.input_batch = InputBatch( + max_num_reqs=self.max_num_reqs, + max_model_len=self.max_model_len, + max_num_batched_tokens=self.max_num_tokens, + device=self.device, + pin_memory=self.pin_memory, + vocab_size=self.model_config.get_vocab_size(), + block_size=self.block_size, + ) + # Cached torch/numpy tensor # The pytorch tensor and numpy array share the same buffer. # Sometimes the numpy op is faster so we create both. 
@@ -409,7 +419,6 @@ class TPUModelRunner(LoRAModelRunnerMixin): return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0 def get_model(self) -> nn.Module: - assert self.model is not None return self.model def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: @@ -926,7 +935,18 @@ class TPUModelRunner(LoRAModelRunnerMixin): "vllm.model_executor.layers.vocab_parallel_embedding." "get_tensor_model_parallel_rank", return_value=xm_tp_rank): - model = get_model(vllm_config=self.vllm_config) + # model = get_model(vllm_config=self.vllm_config) + model_loader = get_model_loader(self.load_config) + if not hasattr(self, "model"): + logger.info("Loading model from scratch...") + model = model_loader.load_model(vllm_config=self.vllm_config, + model_config=self.model_config) + else: + logger.info( + "Model was already initialized. Loading weights inplace..." + ) + model_loader.load_weights(self.model, + model_config=self.model_config) if self.lora_config is not None: model = self.load_lora_model(model, self.model_config, self.scheduler_config, @@ -937,7 +957,8 @@ class TPUModelRunner(LoRAModelRunnerMixin): # loading. xm.mark_step() xm.wait_device_ops() - self.model = model + if not hasattr(self, "model"): + self.model = model self.sampler = TPUSampler() @torch.no_grad() @@ -1286,16 +1307,19 @@ class TPUModelRunner(LoRAModelRunnerMixin): "Hybrid models with more than one KV cache type are not " "supported yet.") - self.input_batch = InputBatch( - max_num_reqs=self.max_num_reqs, - max_model_len=self.max_model_len, - max_num_batched_tokens=self.max_num_tokens, - device=self.device, - pin_memory=self.pin_memory, - vocab_size=self.model_config.get_vocab_size(), - block_size=kv_cache_config.kv_cache_groups[0].kv_cache_spec. 
- block_size, - ) + if kv_cache_config.kv_cache_groups[ + 0].kv_cache_spec.block_size != self.block_size: + self.input_batch = InputBatch( + max_num_reqs=self.max_num_reqs, + max_model_len=self.max_model_len, + max_num_batched_tokens=self.max_num_tokens, + device=self.device, + pin_memory=self.pin_memory, + vocab_size=self.model_config.get_vocab_size(), + block_size=kv_cache_config.kv_cache_groups[0].kv_cache_spec. + block_size, + ) + # Verify dtype compatibility between block_table_cpu and input_batch assert self.block_table_cpu.dtype == self.input_batch.block_table[ 0].get_cpu_tensor().dtype diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index 968596471a26e..3aff3e01aef16 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -2,13 +2,15 @@ import os from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union import torch from torch import nn from vllm.config import DeviceConfig, VllmConfig from vllm.logger import init_logger +from vllm.lora.layers import LoRAMapping +from vllm.lora.request import LoRARequest from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.neuron import get_neuron_model @@ -36,6 +38,7 @@ class ModelInputForNeuron(ModelRunnerInputBase): input_block_ids: Optional[torch.Tensor] = None sampling_metadata: SamplingMetadata = None multi_modal_kwargs: BatchedTensorInputs = None + adapter_ids: Optional[str] = None def as_broadcastable_tensor_dict( self) -> Dict[str, Union[int, torch.Tensor]]: @@ -80,6 +83,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): "The model will run without sliding window.") self.device_config = (self.device_config if self.device_config is not None else DeviceConfig()) + self.lora_config = vllm_config.lora_config 
self.device = self.device_config.device self.pin_memory = is_pin_memory_available() @@ -165,6 +169,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): mm_kwargs = seq_group_metadata.multi_modal_data if mm_kwargs: + mm_kwargs = self.process_multi_modal_data_neuron(mm_kwargs) multi_modal_kwargs_list.append(mm_kwargs) max_seq_len = max(seq_lens) @@ -270,6 +275,14 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): sampling_params.top_p = top_p sampling_params.temperature = temperature + # we need multi_modal_data for later tokens as well + multi_modal_kwargs_list: List[MultiModalKwargs] = [] + for seq_group_metadata in seq_group_metadata_list: + mm_data = seq_group_metadata.multi_modal_data + if mm_data: + multi_modal_kwargs_list.append(mm_data) + multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) + sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, seq_lens, @@ -378,6 +391,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): positions=model_input.input_positions, input_block_ids=model_input.input_block_ids, sampling_params=sampling_params, + adapter_ids=model_input.adapter_ids, **MultiModalKwargs.as_kwargs( model_input.multi_modal_kwargs or {}, dtype=self.model_config.dtype, @@ -416,3 +430,32 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): @property def vocab_size(self) -> int: return self.model_config.get_vocab_size() + + def process_multi_modal_data_neuron(self, mm_data): + # this is a no-op for NeuronModelRunner + return mm_data + + def remove_all_loras(self): + raise NotImplementedError( + "LoRAs are not supported for Transformers NeuronX framework") + + def set_active_loras(self, lora_requests: Set[LoRARequest], + lora_mapping: LoRAMapping) -> None: + raise NotImplementedError( + "LoRAs are not supported for Transformers NeuronX framework") + + def add_lora(self, lora_request: LoRARequest): + raise NotImplementedError( + "LoRAs are not supported for Transformers 
NeuronX framework") + + def remove_lora(self, lora_id: int) -> bool: + raise NotImplementedError( + "LoRAs are not supported for Transformers NeuronX framework") + + def pin_lora(self, lora_id: int) -> bool: + raise NotImplementedError( + "LoRAs are not supported for Transformers NeuronX framework") + + def list_loras(self) -> Set[int]: + raise NotImplementedError( + "LoRAs are not supported for Transformers NeuronX framework") diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py index aa8e39613eec8..64daee31bbdf5 100644 --- a/vllm/worker/neuron_worker.py +++ b/vllm/worker/neuron_worker.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 """A Neuron worker class.""" import os -from typing import List, Optional, Tuple +from typing import List, Optional, Set, Tuple import torch.distributed @@ -9,19 +9,19 @@ from vllm.config import VllmConfig from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.logger import init_logger +from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed from vllm.platforms import current_platform from vllm.platforms.neuron import NeuronFramework from vllm.sequence import ExecuteModelRequest from vllm.worker.neuron_model_runner import NeuronModelRunner -from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, - LoRANotSupportedWorkerBase, WorkerBase, +from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase, WorkerInput) logger = init_logger(__name__) -class NeuronWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase): +class NeuronWorker(LocalOrDistributedWorkerBase): """A worker class that executes the model on a group of neuron cores. 
""" @@ -38,6 +38,7 @@ class NeuronWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase): self.rank = rank self.distributed_init_method = distributed_init_method self.is_driver_worker = is_driver_worker + self.lora_config = vllm_config.lora_config if self.model_config.trust_remote_code: # note: lazy import to avoid importing torch before initializing @@ -59,6 +60,9 @@ class NeuronWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase): "[transformers-neuronx, neuronx-distributed-inference]") def get_tnx_model_runner(self, vllm_config): + assert (self.lora_config + is None), ("LoRA is not supported for TransformersNeuronX " + "framework.") from vllm.worker.multi_step_neuron_model_runner import ( MultiStepNeuronModelRunner) if self.speculative_config is not None: @@ -72,6 +76,8 @@ class NeuronWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase): from vllm.worker.neuronx_distributed_model_runner import ( NeuronxDistributedModelRunner) if self.speculative_config is not None: + assert (self.lora_config + is None), "LoRA is not supported for Speculative Decoding" return MultiStepNeuronxDistributedModelRunner( vllm_config=vllm_config) else: @@ -156,3 +162,31 @@ class NeuronWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase): 1, 1, ) + + def add_lora(self, lora_request: LoRARequest) -> bool: + if current_platform.use_transformers_neuronx(): + raise NotImplementedError( + f"{type(self)} does not support LoRA with Neuron Framework " + f"Transformers NeuronX") + return self.model_runner.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + if current_platform.use_transformers_neuronx(): + raise NotImplementedError( + f"{type(self)} does not support LoRA with Neuron Framework " + f"Transformers NeuronX") + return self.model_runner.remove_lora(lora_id) + + def pin_lora(self, lora_id: int) -> bool: + if current_platform.use_transformers_neuronx(): + raise NotImplementedError( + f"{type(self)} does not support LoRA 
with Neuron Framework " + f"Transformers NeuronX") + return self.model_runner.pin_lora(lora_id) + + def list_loras(self) -> Set[int]: + if current_platform.use_transformers_neuronx(): + raise NotImplementedError( + f"{type(self)} does not support LoRA with Neuron Framework " + f"Transformers NeuronX") + return self.model_runner.list_loras() diff --git a/vllm/worker/neuronx_distributed_model_runner.py b/vllm/worker/neuronx_distributed_model_runner.py index 4e784e5e0302d..9cd4f88d32f06 100644 --- a/vllm/worker/neuronx_distributed_model_runner.py +++ b/vllm/worker/neuronx_distributed_model_runner.py @@ -1,17 +1,26 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List, Optional +from typing import List, Optional, Set import torch +from neuronx_distributed_inference.models.mllama.aspect_ratio_utils import ( + get_all_supported_aspect_ratios) from neuronx_distributed_inference.modules.generation.sampling import ( prepare_sampling_params) +from neuronx_distributed_inference.modules.lora_serving import ( + LoraCheckpoint, LoraServingConfig) from vllm.config import VllmConfig +from vllm.entrypoints.openai.serving_models import LoRAModulePath from vllm.logger import init_logger +from vllm.lora.layers import LoRAMapping +from vllm.lora.request import LoRARequest +from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.neuronx_distributed import ( _get_model_architecture, get_neuron_model) -from vllm.sequence import IntermediateTensors +from vllm.multimodal import MultiModalKwargs +from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.worker.neuron_model_runner import (ModelInputForNeuron, NeuronModelRunner) @@ -25,11 +34,44 @@ class NeuronxDistributedModelRunner(NeuronModelRunner): vllm_config: VllmConfig, ): super().__init__(vllm_config) + self.lora_checkpoint = None + self.model = None + self.lora_serving_config = None + + @staticmethod + def 
_get_lora_paths_strings(lora_modules: List[LoRAModulePath]): + if not lora_modules: + return None + return {_.get("name"): _.get("path") for _ in lora_modules} + + def _get_nxdi_lora_config(self): + override_neuron_config = self.model_config.override_neuron_config + lora_modules = override_neuron_config.pop("lora_modules", None) + target_modules = override_neuron_config.pop("target_modules", None) + lora_ckpt_paths = self._get_lora_paths_strings(lora_modules) + if self.lora_config.max_loras < len(lora_ckpt_paths): + raise ValueError( + "Number of LoRAs (%s) exceeds maximum " + "allowed (%s)", len(lora_ckpt_paths), + self.lora_config.max_loras) + + return LoraServingConfig( + max_loras=self.lora_config.max_loras, + max_lora_rank=self.lora_config.max_lora_rank, + target_modules=target_modules, + lora_ckpt_paths=lora_ckpt_paths, + ) def load_model(self) -> None: - self.model = get_neuron_model(self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config) + # Update LoRA config + if self.lora_config is not None: + self.lora_serving_config = self._get_nxdi_lora_config() + self.lora_checkpoint = LoraCheckpoint(self.lora_serving_config) + self.model = get_neuron_model( + self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + lora_serving_config=self.lora_serving_config) def get_nxd_sampling_params(self, sampling_metadata): if self.model.config.neuron_config.on_device_sampling_config: @@ -81,42 +123,28 @@ class NeuronxDistributedModelRunner(NeuronModelRunner): sampling_params = self.get_nxd_sampling_params( model_input.sampling_metadata) - if model_input.multi_modal_kwargs.get('image') is not None: - pixel_values = [] - aspect_ratios = [] - num_chunks = [] - has_image = [] - for multi_modal_input in model_input.multi_modal_kwargs.get( - 'image'): - image_tensors = self.get_multi_modal_data_neuron( - multi_modal_input.squeeze(0)) - pixel_values.append(image_tensors[0]) - 
aspect_ratios.append(image_tensors[1]) - num_chunks.append(image_tensors[2]) - has_image.append(image_tensors[3]) - - pixel_values = torch.cat(pixel_values, dim=0) - aspect_ratios = torch.cat(aspect_ratios, dim=0) - num_chunks = torch.cat(num_chunks, dim=0) - has_image = torch.cat(has_image, dim=0) - + if model_input.multi_modal_kwargs.get('pixel_values') is not None: hidden_states = self.model( input_ids=model_input.input_tokens, positions=model_input.input_positions, seq_ids=model_input.input_block_ids, - pixel_values=pixel_values, - aspect_ratios=aspect_ratios, + pixel_values=model_input.multi_modal_kwargs.get( + 'pixel_values'), + aspect_ratios=model_input.multi_modal_kwargs.get( + 'aspect_ratios'), sampling_params=sampling_params, - num_chunks=num_chunks, - has_image=has_image, + num_chunks=model_input.multi_modal_kwargs.get('num_chunks'), + has_image=model_input.multi_modal_kwargs.get( + 'has_image').squeeze(1), ) else: - empty_pixel_values = torch.zeros([1, 1, 4, 3, 560, 560], + bs = model_input.input_tokens.shape[0] if (model_input.input_tokens + is not None) else 1 + empty_pixel_values = torch.zeros([bs, 1, 4, 3, 560, 560], dtype=torch.bfloat16) - empty_aspect_ratios = torch.ones([1, 1, 2], dtype=torch.int64) - num_chunks = torch.tensor([[1] - ]) # dummy num_chunks, will not be used - has_image = torch.tensor([0]) + empty_aspect_ratios = torch.ones([bs, 1, 2], dtype=torch.int64) + num_chunks = torch.zeros((bs, 1), dtype=torch.int32) + has_image = torch.zeros([bs], dtype=torch.int32) hidden_states = self.model( input_ids=model_input.input_tokens, positions=model_input.input_positions, @@ -134,3 +162,132 @@ class NeuronxDistributedModelRunner(NeuronModelRunner): ) return [output] + + def process_multi_modal_data_neuron(self, mm_data): + # Neuron uses aspect_ratios instead of aspect_ratio_ids + all_supported_aspect_ratios = get_all_supported_aspect_ratios( + self.model.config.vision_config.max_num_tiles) + aspect_ratio_ids = mm_data.get("aspect_ratio_ids") + 
mm_data["aspect_ratios"] = torch.tensor( + all_supported_aspect_ratios[aspect_ratio_ids]).unsqueeze(0) + + # Neuron's num_chunks is HF's num_tiles + mm_data["num_chunks"] = mm_data.get("num_tiles") + + # Input has an image if it has pixel_values + bs = mm_data["num_chunks"].shape[0] + pixel_values = mm_data.get("pixel_values") + if pixel_values is not None and not torch.all(pixel_values == 0): + mm_data["has_image"] = torch.ones(bs) + + else: + mm_data["has_image"] = torch.zeros(bs) + return mm_data + + def _get_lora_adapter_ids(self, seq_group_metadata_list): + # set LoRA adapter IDs for multi-lora serving + batch_size = len(seq_group_metadata_list) + if self.lora_checkpoint is not None: + # "0" indicates NxDI to use the base model for inference + adapter_ids = ["0"] * batch_size + for idx, seq_group_metadata in enumerate(seq_group_metadata_list): + if seq_group_metadata.lora_request is not None: + adapter_ids[ + idx] = seq_group_metadata.lora_request.lora_name + + # convert adapter_ids from strings to integers + adapter_ids = self.lora_checkpoint.convert_adapter_ids_to_indices( + adapter_ids, batch_size) + else: + adapter_ids = torch.zeros((batch_size), dtype=torch.int32) + + return adapter_ids + + def prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None + ) -> ModelInputForNeuron: + # NOTE: We assume that all sequences in the group are all prompts or + # all decodes. + is_prompt = seq_group_metadata_list[0].is_prompt + # Prepare input tensors. 
+ if is_prompt: + (input_tokens, input_positions, input_block_ids, seq_lens, + multi_modal_kwargs + ) = self._prepare_prompt(seq_group_metadata_list) + else: + (input_tokens, input_positions, + input_block_ids) = self._prepare_decode(seq_group_metadata_list) + seq_lens = None + + if not self._on_device_sampling_disabled: + for seq_group_metadata in seq_group_metadata_list: + sampling_params = seq_group_metadata.sampling_params + top_k, top_p, temperature = ( + self._convert_to_neuron_sampling_params(sampling_params)) + sampling_params.top_k = top_k + sampling_params.top_p = top_p + sampling_params.temperature = temperature + + # we need multi_modal_data for later tokens as well + multi_modal_kwargs_list: List[MultiModalKwargs] = [] + for seq_group_metadata in seq_group_metadata_list: + mm_data = seq_group_metadata.multi_modal_data + if mm_data: + multi_modal_kwargs_list.append(mm_data) + multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) + + lora_adapter_ids = self._get_lora_adapter_ids(seq_group_metadata_list) + + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, + seq_lens, + # query_lens is not needed if chunked prefill is not + # supported. Since neuron worker doesn't support chunked prefill + # just use seq_lens instead. 
+ seq_lens, + self.device, + self.pin_memory, + generators=self.get_generators(finished_requests_ids)) + + return ModelInputForNeuron(input_tokens=input_tokens, + input_positions=input_positions, + input_block_ids=input_block_ids, + sampling_metadata=sampling_metadata, + multi_modal_kwargs=multi_modal_kwargs, + adapter_ids=lora_adapter_ids) + + def remove_all_loras(self): + raise NotImplementedError( + "Managing LoRAs is only supported through the " + "lora_modules parameter in override_neuron_config") + + def set_active_loras(self, lora_requests: Set[LoRARequest], + lora_mapping: LoRAMapping) -> None: + raise NotImplementedError( + "Managing LoRAs is only supported through the " + "lora_modules parameter in override_neuron_config") + + def add_lora(self, lora_request: LoRARequest): + logger.warning( + "Adding LoRAs is only supported through the " + "lora_modules parameter in override_neuron_config. If you supplied " + "the parameter, you can ignore this warning. Ignoring" + "lora request: ", lora_request) + + def remove_lora(self, lora_id: int) -> bool: + raise NotImplementedError( + "Managing LoRAs is only supported through the " + "lora_modules parameter in override_neuron_config") + + def pin_lora(self, lora_id: int) -> bool: + raise NotImplementedError( + "Managing LoRAs is only supported through the " + "lora_modules parameter in override_neuron_config") + + def list_loras(self) -> Set[int]: + raise NotImplementedError( + "Managing LoRAs is only supported through the " + "lora_modules parameter in override_neuron_config")