From ff47aab05640394a513a5b2ac772a115ddc2e05a Mon Sep 17 00:00:00 2001 From: "Li, Jiang" Date: Wed, 12 Mar 2025 18:41:13 +0800 Subject: [PATCH 1/9] [CPU] Upgrade CPU backend to torch-2.6 (#13381) Signed-off-by: jiang1.li Co-authored-by: Isotr0py <2037008807@qq.com> --- .buildkite/run-cpu-test.sh | 8 +++++--- Dockerfile.cpu | 2 +- cmake/cpu_extension.cmake | 2 +- requirements/cpu.txt | 2 +- tests/lora/test_qwen2vl.py | 2 +- vllm/attention/ops/ipex_attn.py | 2 +- vllm/executor/multiproc_worker_utils.py | 9 +++++---- vllm/model_executor/layers/fused_moe/layer.py | 6 +++++- vllm/platforms/cpu.py | 3 +++ 9 files changed, 23 insertions(+), 13 deletions(-) diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh index f6dad818ddc05..e45e184852f29 100644 --- a/.buildkite/run-cpu-test.sh +++ b/.buildkite/run-cpu-test.sh @@ -19,13 +19,14 @@ remove_docker_container # Run the image, setting --shm-size=4g for tensor parallel. docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus="$CORE_RANGE" \ - --cpuset-mems="$NUMA_NODE" --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" cpu-test-"$BUILDKITE_BUILD_NUMBER" + --cpuset-mems="$NUMA_NODE" --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" cpu-test-"$BUILDKITE_BUILD_NUMBER" docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus="$CORE_RANGE" \ - --cpuset-mems="$NUMA_NODE" --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2-"$NUMA_NODE" cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2 + --cpuset-mems="$NUMA_NODE" --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2-"$NUMA_NODE" 
cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2 function cpu_tests() { set -e export NUMA_NODE=$2 + export BUILDKITE_BUILD_NUMBER=$3 # offline inference docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2-"$NUMA_NODE" bash -c " @@ -36,6 +37,7 @@ function cpu_tests() { docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c " set -e pip install -r vllm/requirements/test.txt + pip install -r vllm/requirements/cpu.txt pytest -v -s tests/models/decoder_only/language -m cpu_model pytest -v -s tests/models/embedding/language -m cpu_model pytest -v -s tests/models/encoder_decoder/language -m cpu_model @@ -85,4 +87,4 @@ function cpu_tests() { # All of CPU tests are expected to be finished less than 40 mins. export -f cpu_tests -timeout 40m bash -c "cpu_tests $CORE_RANGE $NUMA_NODE" +timeout 40m bash -c "cpu_tests $CORE_RANGE $NUMA_NODE $BUILDKITE_BUILD_NUMBER" diff --git a/Dockerfile.cpu b/Dockerfile.cpu index 08a4e188f4c14..a10090529d8a9 100644 --- a/Dockerfile.cpu +++ b/Dockerfile.cpu @@ -22,7 +22,7 @@ ENV LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:/usr/local/li RUN echo 'ulimit -c 0' >> ~/.bashrc -RUN pip install intel_extension_for_pytorch==2.5.0 +RUN pip install intel_extension_for_pytorch==2.6.0 WORKDIR /workspace diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index ca2ffb1bc3c8c..345b75d622331 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -149,7 +149,7 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED) FetchContent_Declare( oneDNN GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git - GIT_TAG v3.6 + GIT_TAG v3.7.1 GIT_PROGRESS TRUE GIT_SHALLOW TRUE ) diff --git a/requirements/cpu.txt b/requirements/cpu.txt index ba059d3ff72ee..b4e6abb6e3d66 100644 --- a/requirements/cpu.txt +++ b/requirements/cpu.txt @@ -2,7 +2,7 @@ -r common.txt # Dependencies for CPUs -torch==2.5.1+cpu; platform_machine != "ppc64le" and platform_machine != "aarch64" and platform_system != "Darwin" and platform_machine != "s390x" 
+torch==2.6.0+cpu; platform_machine == "x86_64" torch==2.5.1; platform_machine == "ppc64le" or platform_machine == "aarch64" or platform_system == "Darwin" torch==2.7.0.dev20250304; platform_machine == "s390x" diff --git a/tests/lora/test_qwen2vl.py b/tests/lora/test_qwen2vl.py index 90735d55be712..7bd3e3d0fe27f 100644 --- a/tests/lora/test_qwen2vl.py +++ b/tests/lora/test_qwen2vl.py @@ -12,7 +12,7 @@ from vllm.lora.request import LoRARequest from vllm.platforms import current_platform -@pytest.fixture(autouse=True) +@pytest.fixture(autouse=not current_platform.is_cpu()) def v1(run_with_both_engines_lora): # Simple autouse wrapper to run both engines for each test # This can be promoted up to conftest.py to run for every diff --git a/vllm/attention/ops/ipex_attn.py b/vllm/attention/ops/ipex_attn.py index 598ceea130d97..6d96f58320c84 100644 --- a/vllm/attention/ops/ipex_attn.py +++ b/vllm/attention/ops/ipex_attn.py @@ -17,7 +17,7 @@ class _PagedAttention: @staticmethod def get_supported_head_sizes() -> List[int]: - return [32, 64, 80, 96, 112, 128, 256] + return [32, 64, 80, 96, 112, 128, 192, 256] @staticmethod def get_kv_cache_shape( diff --git a/vllm/executor/multiproc_worker_utils.py b/vllm/executor/multiproc_worker_utils.py index 68a83bb610a49..74237f9eb451b 100644 --- a/vllm/executor/multiproc_worker_utils.py +++ b/vllm/executor/multiproc_worker_utils.py @@ -254,10 +254,11 @@ def _run_worker_process( # online (in situ) tuning is enabled. # Offline tuning API (record_untuned_is_enabled()) only # available in PyTorch 2.6 or later. 
- import torch.cuda.tunable as tunable - if (tunable.is_enabled() and tunable.tuning_is_enabled() - and not tunable.record_untuned_is_enabled()): - tunable.write_file() + if torch.cuda.is_available(): + import torch.cuda.tunable as tunable + if (tunable.is_enabled() and tunable.tuning_is_enabled() + and not tunable.record_untuned_is_enabled()): + tunable.write_file() logger.info("Worker exiting") diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 51c4df9d4a5e2..2c5fa509c595d 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -193,10 +193,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, activation: str = "silu", **kwargs, ): - assert custom_routing_function is None assert activation == "silu", f"{activation} is not supported." 
return layer.ipex_fusion( x, @@ -206,6 +207,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): renormalize, topk_group, num_expert_group, + custom_routing_function, + scoring_func, + e_score_correction_bias, ) def forward_tpu( diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index ab8982a3a6e1c..140335dfb64a6 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -121,6 +121,9 @@ class CpuPlatform(Platform): # Disable torch async compiling which won't work with daemonic processes os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1" + # MLA attention is not supported + os.environ["VLLM_MLA_DISABLE"] = "1" + # Intel OpenMP setting ld_prealod_str = os.getenv("LD_PRELOAD", "") if "libiomp5.so" in ld_prealod_str: From 45f3f3f59e4898a11baf9bcb8d6ec5db2581af34 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Wed, 12 Mar 2025 05:00:28 -0700 Subject: [PATCH 2/9] [ROCm][Bugfix] Ensure that the moe_wna16_gemm kernel is not built on ROCm platforms. (#14629) Signed-off-by: Sage Moore --- CMakeLists.txt | 2 +- csrc/moe/moe_ops.h | 3 ++- csrc/moe/torch_bindings.cpp | 2 +- vllm/_custom_ops.py | 4 ++++ 4 files changed, 8 insertions(+), 3 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index e028bf5951a3e..ea6d52379499e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -559,7 +559,6 @@ target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1) set(VLLM_MOE_EXT_SRC "csrc/moe/torch_bindings.cpp" "csrc/moe/moe_align_sum_kernels.cu" - "csrc/moe/moe_wna16.cu" "csrc/moe/topk_softmax_kernels.cu") set_gencode_flags_for_srcs( @@ -574,6 +573,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") SRCS "${VLLM_MOE_WNA16_SRC}" CUDA_ARCHS "${CUDA_ARCHS}") + list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}") cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}") if (MARLIN_MOE_ARCHS) set(MARLIN_MOE_SRC diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index 371edb6495b13..0bae119a7c460 100644 --- 
a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -18,7 +18,7 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad); - +#ifndef USE_ROCM torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output, torch::Tensor b_qweight, torch::Tensor b_scales, std::optional b_qzeros, @@ -28,3 +28,4 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output, torch::Tensor num_tokens_post_pad, int64_t top_k, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N, int64_t BLOCK_SIZE_K, int64_t bit); +#endif \ No newline at end of file diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index d2c03c4d4bef0..957ac765290c6 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -31,6 +31,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { " Tensor! num_tokens_post_pad) -> ()"); m.impl("sgl_moe_align_block_size", torch::kCUDA, &sgl_moe_align_block_size); +#ifndef USE_ROCM m.def( "moe_wna16_gemm(Tensor input, Tensor! output, Tensor b_qweight, " "Tensor b_scales, Tensor? b_qzeros, " @@ -41,7 +42,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { m.impl("moe_wna16_gemm", torch::kCUDA, &moe_wna16_gemm); -#ifndef USE_ROCM m.def( "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, " "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! 
" diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 64175cc4e13c6..d68c097fbe84d 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1146,6 +1146,10 @@ def moe_wna16_gemm(input: torch.Tensor, output: torch.Tensor, num_tokens_post_pad: torch.Tensor, top_k: int, BLOCK_SIZE_M: int, BLOCK_SIZE_N: int, BLOCK_SIZE_K: int, bit: int) -> torch.Tensor: + if not current_platform.is_cuda(): + raise NotImplementedError( + "The optimized moe_wna16_gemm kernel is only " + "available on CUDA platforms") torch.ops._moe_C.moe_wna16_gemm(input, output, b_qweight, b_scales, b_qzeros, topk_weights, sorted_token_ids, experts_ids, num_tokens_post_pad, top_k, From c0c25e25fa93ee7c3f279abbba5597c0fafa74ee Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 12 Mar 2025 08:36:33 -0700 Subject: [PATCH 3/9] [Model] Add support for Gemma 3 (#14660) Signed-off-by: Woosuk Kwon Signed-off-by: Roger Wang Signed-off-by: DarkLight1337 Co-authored-by: Roger Wang Co-authored-by: DarkLight1337 --- docs/source/models/supported_models.md | 41 +- examples/offline_inference/vision_language.py | 20 +- .../vision_language_multi_image.py | 37 ++ .../multimodal/processing/test_common.py | 1 + tests/models/registry.py | 4 + vllm/config.py | 15 +- vllm/entrypoints/chat_utils.py | 2 + vllm/model_executor/models/gemma3.py | 533 ++++++++++++++++++ vllm/model_executor/models/gemma3_mm.py | 425 ++++++++++++++ vllm/model_executor/models/registry.py | 2 + 10 files changed, 1071 insertions(+), 9 deletions(-) create mode 100644 vllm/model_executor/models/gemma3.py create mode 100644 vllm/model_executor/models/gemma3_mm.py diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index e46934b9caebe..98e7572981dee 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -263,10 +263,15 @@ See [this page](#generative-models) for more information on how to use generativ * ✅︎ * ✅︎ - * `Gemma2ForCausalLM` - * Gemma2 + * Gemma 
2 * `google/gemma-2-9b`, `google/gemma-2-27b`, etc. * ✅︎ * ✅︎ +- * `Gemma3ForCausalLM` + * Gemma 3 + * `google/gemma-3-1b-it`, etc. + * ✅︎ + * ✅︎ - * `GlmForCausalLM` * GLM-4 * `THUDM/glm-4-9b-chat-hf`, etc. @@ -504,7 +509,7 @@ you should explicitly specify the task type to ensure that the model is used in * * - * `Gemma2Model` - * Gemma2-based + * Gemma 2-based * `BAAI/bge-multilingual-gemma2`, etc. * * ✅︎ @@ -752,6 +757,13 @@ See [this page](#generative-models) for more information on how to use generativ * * ✅︎ * ✅︎ +- * `Gemma3ForConditionalGeneration` + * Gemma 3 + * T + I+ + * `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. + * ✅︎ + * ✅︎ + * ✅︎\* - * `GLM4VForCausalLM`^ * GLM-4V * T + I @@ -937,6 +949,31 @@ For more details, please see: To use Qwen2.5-VL series models, you have to install Hugging Face Transformers library from source via `pip install git+https://github.com/huggingface/transformers`. ::: +:::{note} +To use Gemma3 series models, you have to install Hugging Face Transformers library from source via +`pip install git+https://github.com/huggingface/transformers`. +The earliest commit that supports this is [`50d3530aa04e7a7d003e6b255a98f79fd0447357`](https://github.com/huggingface/transformers/commit/50d3530aa04e7a7d003e6b255a98f79fd0447357). + +Both V0 and V1 support `Gemma3ForConditionalGeneration` for text-only inputs. 
+However, there are differences in how they handle text + image inputs: + +V0 correctly implements the model's attention pattern: +- Uses bidirectional attention between the image tokens corresponding to the same image +- Uses causal attention for other tokens +- Implemented via (naive) PyTorch SDPA with masking tensors +- Note: May use significant memory for long prompts with image + +V1 currently uses a simplified attention pattern: +- Uses causal attention for all tokens, including image tokens +- Generates reasonable outputs but does not match the original model's attention for text + image inputs +- Will be updated in the future to support the correct behavior + +This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends. + +Additionally, vLLM's current Gemma 3 implementation does not support the pan-and-scan image pre-processing algorithm, which helps handle images with skewed aspect ratios by intelligently cropping them into multiple views. +Without this feature, model performance may degrade when processing images that deviate significantly from square dimensions. +::: + ### Pooling Models See [this page](pooling-models) for more information on how to use pooling models. 
diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 716c31b96ed1c..39acab4765a30 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -118,6 +118,23 @@ def run_fuyu(questions: list[str], modality: str): return llm, prompts, stop_token_ids +# Gemma 3 +def run_gemma3(questions: list[str], modality: str): + assert modality == "image" + model_name = "google/gemma-3-4b-it" + + llm = LLM(model=model_name, + max_model_len=2048, + max_num_seqs=2, + disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache) + + prompts = [("user\n" + f"{question}\n" + "model\n") for question in questions] + stop_token_ids = None + return llm, prompts, stop_token_ids + + # GLM-4v def run_glm4v(questions: list[str], modality: str): assert modality == "image" @@ -405,7 +422,7 @@ def run_mllama(questions: list[str], modality: str): "type": "image" }, { "type": "text", - "text": f"{question}" + "text": question }] }] for question in questions] prompts = tokenizer.apply_chat_template(messages, @@ -664,6 +681,7 @@ model_example_map = { "deepseek_vl_v2": run_deepseek_vl2, "florence2": run_florence2, "fuyu": run_fuyu, + "gemma3": run_gemma3, "glm4v": run_glm4v, "h2ovl_chat": run_h2ovl, "idefics3": run_idefics3, diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py index 6fdd4383c1a1e..4963e6a8c4e72 100644 --- a/examples/offline_inference/vision_language_multi_image.py +++ b/examples/offline_inference/vision_language_multi_image.py @@ -80,6 +80,42 @@ def load_deepseek_vl2(question: str, image_urls: list[str]): ) +def load_gemma3(question, image_urls: list[str]) -> ModelRequestData: + model_name = "google/gemma-3-4b-it" + + llm = LLM(model=model_name, + max_model_len=8192, + max_num_seqs=2, + limit_mm_per_prompt={"image": len(image_urls)}) + + placeholders = [{"type": "image", "image": url} for url 
in image_urls] + messages = [{ + "role": + "user", + "content": [ + *placeholders, + { + "type": "text", + "text": question + }, + ], + }] + + processor = AutoProcessor.from_pretrained(model_name) + + prompt = processor.apply_chat_template(messages, + tokenize=False, + add_generation_prompt=True) + + return ModelRequestData( + llm=llm, + prompt=prompt, + stop_token_ids=None, + image_data=[fetch_image(url) for url in image_urls], + chat_template=None, + ) + + def load_h2ovl(question: str, image_urls: list[str]) -> ModelRequestData: model_name = "h2oai/h2ovl-mississippi-800m" @@ -496,6 +532,7 @@ def load_qwen2_5_vl(question, image_urls: list[str]) -> ModelRequestData: model_example_map = { "aria": load_aria, "deepseek_vl_v2": load_deepseek_vl2, + "gemma3": load_gemma3, "h2ovl_chat": load_h2ovl, "idefics3": load_idefics3, "internvl_chat": load_internvl, diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index e64b703cc5201..467114eedb01c 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -162,6 +162,7 @@ def _test_processing_correctness( "deepseek-ai/deepseek-vl2-tiny", "microsoft/Florence-2-base", "adept/fuyu-8b", + "google/gemma-3-4b-it", "THUDM/glm-4v-9b", "h2oai/h2ovl-mississippi-800m", "OpenGVLab/InternVL2-1B", diff --git a/tests/models/registry.py b/tests/models/registry.py index a7a88d1990479..eadbd7e6f4927 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -124,6 +124,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"), "GemmaForCausalLM": _HfExamplesInfo("google/gemma-2b"), "Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"), + "Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it", + min_transformers_version="4.50"), "GlmForCausalLM": _HfExamplesInfo("THUDM/glm-4-9b-chat-hf"), "GPT2LMHeadModel": _HfExamplesInfo("gpt2"), "GPTBigCodeForCausalLM": 
_HfExamplesInfo("bigcode/starcoder"), @@ -241,6 +243,8 @@ _MULTIMODAL_EXAMPLE_MODELS = { "DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny", # noqa: E501 hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501 "FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"), + "Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it", + min_transformers_version="4.50"), "GLM4VForCausalLM": _HfExamplesInfo("THUDM/glm-4v-9b", trust_remote_code=True, hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501 diff --git a/vllm/config.py b/vllm/config.py index a0f30d0e7b776..2ee45f1837c4e 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -350,10 +350,11 @@ class ModelConfig: if self.enforce_eager is None: self.enforce_eager = False + interleaved_attn_models = ["gemma2", "gemma3_text", "cohere2"] sliding_window = getattr(self.hf_text_config, "sliding_window", None) has_interleaved_attention = (sliding_window is not None) and ( isinstance(sliding_window, list) or - (self.hf_text_config.model_type in ["gemma2", "cohere2"])) + (self.hf_text_config.model_type in interleaved_attn_models)) if (not self.disable_sliding_window and has_interleaved_attention): if (backend := @@ -2501,11 +2502,11 @@ def _get_and_verify_dtype( dtype = dtype.lower() if dtype == "auto": if config_dtype == torch.float32: - if config.model_type == "gemma2": + if config.model_type in ("gemma2", "gemma3", "gemma3_text"): logger.info( - "For Gemma 2, we downcast float32 to bfloat16 instead " - "of float16 by default. Please specify `dtype` if you " - "want to use float16.") + "For Gemma 2 and 3, we downcast float32 to bfloat16 " + "instead of float16 by default. 
Please specify `dtype` " + "if you want to use float16.") torch_dtype = torch.bfloat16 else: # Following the common practice, we use float16 for float32 @@ -2637,7 +2638,9 @@ def _get_and_verify_max_len( derived_max_model_len = default_max_len rope_scaling = getattr(hf_config, "rope_scaling", None) - if rope_scaling is not None: + # NOTE(woosuk): Gemma3's max_model_len (128K) is already scaled by RoPE + # scaling, so we skip applying the scaling factor again. + if rope_scaling is not None and "gemma3" not in hf_config.model_type: # No need to consider "type" key because of patch_rope_scaling when # loading HF config rope_type = rope_scaling["rope_type"] diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index b51ade17def6b..61f21482f7072 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -433,6 +433,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): return "" if model_type == "aria": return "<|fim_prefix|><|img|><|fim_suffix|>" + if model_type == "gemma3": + return "" raise TypeError(f"Unknown {modality} model type: {model_type}") elif modality == "audio": diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py new file mode 100644 index 0000000000000..f1ecf7fa821d9 --- /dev/null +++ b/vllm/model_executor/models/gemma3.py @@ -0,0 +1,533 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2025 The vLLM team. +# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Iterable, Optional, Set, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import nn +from transformers import Gemma3TextConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import GeluAndMul +from vllm.model_executor.layers.layernorm import GemmaRMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (AutoWeightsLoader, extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +logger = init_logger(__name__) + + +class Gemma3MLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_activation: str, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config) + 
self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config) + if hidden_activation != "gelu_pytorch_tanh": + raise ValueError( + "Gemma3 uses `gelu_pytorch_tanh` as the hidden activation " + "function. Please set `hidden_act` and `hidden_activation` to " + "`gelu_pytorch_tanh`.") + self.act_fn = GeluAndMul(approximate="tanh") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class Gemma3Attention(nn.Module): + + def __init__(self, + config: Gemma3TextConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + max_position_embeddings: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + attn_logits_soft_cap: Optional[float] = None, + prefix: str = "") -> None: + super().__init__() + self.config = config + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = config.query_pre_attn_scalar**-0.5 + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=config.attention_bias, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=config.attention_bias, + quant_config=quant_config, + ) + + self.q_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps) + + # TODO(woosuk): Add reference to the original HF implementation. + layer_idx = extract_layer_index(prefix) + self.is_sliding = bool((layer_idx + 1) % config.sliding_window_pattern) + # Initialize the rotary embedding. + if self.is_sliding: + # Local attention. Override the values in config.json. + self.rope_theta = config.rope_local_base_freq + self.rope_scaling = {"rope_type": "default"} + self.sliding_window = config.interleaved_sliding_window + else: + # Global attention. Use the values in config.json. + self.rope_theta = config.rope_theta + self.rope_scaling = config.rope_scaling + self.sliding_window = None + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=self.rope_theta, + is_neox_style=True, + rope_scaling=self.rope_scaling, + ) + + # Initialize the attention. 
+ self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + logits_soft_cap=attn_logits_soft_cap, + per_layer_sliding_window=self.sliding_window, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + q = q.unflatten(-1, (self.num_heads, self.head_dim)) + q = self.q_norm(q) + q = q.flatten(-2, -1) + k = k.unflatten(-1, (self.num_kv_heads, self.head_dim)) + k = self.k_norm(k) + k = k.flatten(-2, -1) + + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + + if not kwargs.get("has_images", False): + # Fast path for text-only inputs. The performance for the text-only + # inputs are not affected by the naive attention below. + output, _ = self.o_proj(attn_output) + return output + + # NOTE(woosuk): Gemma3 uses bidirectional attention between image tokens + # that correspond to the same image while using causal attention + # otherwise. Current attention backends cannot handle this pattern, so + # we temporarily use a naive attention implementation with mask tensors. + + # We intentionally keep the attention backend as-is and only override + # `attn_output` with the naive implementation's output. This minimizes + # changes to existing model runners and attention backends. The call to + # `self.attn(q, k, v)` is only used to populate the KV cache - its + # output is discarded and overwritten below. While this duplicates + # computation, it maintains compatibility. + # TODO(woosuk): Optimize by implementing custom attention kernels. 
+ attn_output = self.naive_attn_with_masks(q, + k, + v, + out=attn_output, + **kwargs) + output, _ = self.o_proj(attn_output) + return output + + def naive_attn_with_masks( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + # NOTE(woosuk): As described in the comment above, this code is not + # meant to be performant. It is only meant to be correct. + q = q.view(-1, self.num_heads, self.head_dim) + # Expand the key and value to handle GQA. + num_queries_per_kv = self.num_heads // self.num_kv_heads + k = k.view(-1, self.num_kv_heads, self.head_dim) + k = k.repeat_interleave(num_queries_per_kv, dim=-2) + v = v.view(-1, self.num_kv_heads, self.head_dim) + v = v.repeat_interleave(num_queries_per_kv, dim=-2) + + if self.is_sliding: + attn_masks = kwargs["local_attn_masks"] + else: + attn_masks = kwargs["global_attn_masks"] + + seq_lens = kwargs["seq_lens"] + start_idx = 0 + for seq_len, attn_mask in zip(seq_lens, attn_masks): + end_idx = start_idx + seq_len + query = q[start_idx:end_idx].unsqueeze(0) + key = k[start_idx:end_idx].unsqueeze(0) + value = v[start_idx:end_idx].unsqueeze(0) + + # Transpose. 
+ query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + output = F.scaled_dot_product_attention( + query, + key, + value, + attn_mask, + self.scaling, + ) + output = output.transpose(1, 2).flatten(-2, -1) + out[start_idx:end_idx] = output + start_idx = end_idx + return out + + +class Gemma3DecoderLayer(nn.Module): + + def __init__( + self, + config: Gemma3TextConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Gemma3Attention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + head_dim=config.head_dim, + max_position_embeddings=config.max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + attn_logits_soft_cap=None, + prefix=f"{prefix}.self_attn", + ) + self.hidden_size = config.hidden_size + self.mlp = Gemma3MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_activation=config.hidden_activation, + quant_config=quant_config, + ) + self.input_layernorm = GemmaRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_feedforward_layernorm = GemmaRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_feedforward_layernorm = GemmaRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + 
hidden_states=hidden_states, + **kwargs, + ) + hidden_states = self.post_attention_layernorm(hidden_states) + + hidden_states, residual = self.pre_feedforward_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + return hidden_states, residual + + +@support_torch_compile +class Gemma3Model(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Gemma3DecoderLayer( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.layers") + self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + # Normalize the embedding by sqrt(hidden_size) + # The normalizer's data type should be downcasted to the model's + # data type such as bfloat16, not float32. + # See https://github.com/huggingface/transformers/pull/29402 + normalizer = self.config.hidden_size**0.5 + self.register_buffer("normalizer", torch.tensor(normalizer)) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + # NOTE(woosuk): Only apply the normalizer to the output of + # vocab embedding. Don't apply it to the vision embedding. 
+ return self.embed_tokens(input_ids) * self.normalizer + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + **kwargs, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache scales for compressed-tensors quantization + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = loaded_weight[0] + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + for (param_name, shard_name, shard_id) in stacked_params_mapping: + if shard_name not in name: + continue + name = name.replace(shard_name, param_name) + # Skip 
loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + unloaded_params = params_dict.keys() - loaded_params + if unloaded_params: + logger.warning( + "Some weights are not initialized from checkpoints: %s", + unloaded_params) + return loaded_params + + +class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + del lora_config # Unused. 
+ super().__init__() + self.config = config + # currently all existing Gemma models have `tie_word_embeddings` enabled + assert config.tie_word_embeddings + self.quant_config = quant_config + self.model = Gemma3Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.logits_processor = LogitsProcessor( + config.vocab_size, soft_cap=config.final_logit_softcapping) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds, **kwargs) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.model.embed_tokens, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py new file mode 100644 index 0000000000000..121aee51786b8 --- /dev/null +++ b/vllm/model_executor/models/gemma3_mm.py @@ -0,0 +1,425 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import (Any, Iterable, 
Literal, Mapping, Optional, Sequence, Set, + Tuple, TypedDict, Union) + +import torch +from torch import nn +from transformers import BatchFeature, Gemma3Config, ProcessorMixin + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import GemmaRMSNorm +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, + NestedTensors) +from vllm.multimodal.parse import ImageSize, MultiModalDataItems +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement, + PromptUpdate, PromptUpdateDetails) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsMultiModal, SupportsPP +from .siglip import SiglipVisionModel +from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, + maybe_prefix, merge_multimodal_embeddings) + +logger = init_logger(__name__) + + +class Gemma3ImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """Shape: `(batch_size * num_images, num_channels, height, width)`""" + + +Gemma3ImageInputs = Gemma3ImagePixelInputs + + +class Gemma3ProcessingInfo(BaseProcessingInfo): + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + hf_config = self.ctx.get_hf_config() + return {"image": hf_config.mm_tokens_per_image} + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + processor: Optional[ProcessorMixin], + ) -> int: + hf_config = self.ctx.get_hf_config() + return hf_config.mm_tokens_per_image + + def 
get_image_size_with_most_features(self) -> ImageSize:
+        # Result in the max possible feature size (h:w = 160:1)
+        return ImageSize(height=8000, width=50)
+
+
+class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]):
+
+    def get_dummy_processor_inputs(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> ProcessorInputs:
+        tokenizer = self.info.get_tokenizer()
+        boi_token = tokenizer.boi_token
+
+        num_images = mm_counts.get("image", 0)
+        target_width, target_height = \
+            self.info.get_image_size_with_most_features()
+
+        mm_data = {
+            "image":
+            self._get_dummy_images(width=target_width,
+                                   height=target_height,
+                                   num_images=num_images)
+        }
+        return ProcessorInputs(
+            prompt_text=" ".join([boi_token] * num_images),
+            mm_data=mm_data,
+        )
+
+
+class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
+
+    def _call_hf_processor(
+        self,
+        prompt: str,
+        mm_data: Mapping[str, object],
+        mm_kwargs: Mapping[str, object],
+    ) -> BatchFeature:
+        # TODO(woosuk): Support pan-and-scan.
+ img_kwargs = mm_kwargs.get("images_kwargs", {}) + img_kwargs["do_pan_and_scan"] = False + mm_kwargs["images_kwargs"] = img_kwargs + return super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + ) + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict(pixel_values=MultiModalFieldConfig.batched("image")) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + tokenizer = self.info.get_tokenizer() + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + hf_config = self.info.get_hf_config() + + boi_token = tokenizer.boi_token + image_token = tokenizer.image_token + mm_tokens_per_image = hf_config.mm_tokens_per_image + image_tokens_expanded = "".join([image_token] * mm_tokens_per_image) + + def get_replacement_gemma3(item_idx: int): + return PromptUpdateDetails( + full=hf_processor.full_image_sequence, + features=image_tokens_expanded, + ) + + return [ + PromptReplacement( + modality="image", + target=boi_token, + replacement=get_replacement_gemma3, + ) + ] + + +class Gemma3MultiModalProjector(nn.Module): + + def __init__(self, config: Gemma3Config): + super().__init__() + + self.mm_input_projection_weight = nn.Parameter( + torch.zeros(config.vision_config.hidden_size, + config.text_config.hidden_size)) + + self.mm_soft_emb_norm = GemmaRMSNorm( + config.vision_config.hidden_size, + eps=config.vision_config.layer_norm_eps) + + self.patches_per_image = int(config.vision_config.image_size // + config.vision_config.patch_size) + self.tokens_per_side = int(config.mm_tokens_per_image**0.5) + self.kernel_size = self.patches_per_image // self.tokens_per_side + self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, + stride=self.kernel_size) + + def forward(self, vision_outputs: 
torch.Tensor): + batch_size, _, seq_length = vision_outputs.shape + + reshaped_vision_outputs = vision_outputs.transpose(1, 2) + reshaped_vision_outputs = reshaped_vision_outputs.reshape( + batch_size, seq_length, self.patches_per_image, + self.patches_per_image) + reshaped_vision_outputs = reshaped_vision_outputs.contiguous() + + pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs) + pooled_vision_outputs = pooled_vision_outputs.flatten(2) + pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2) + + normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs) + + projected_vision_outputs = torch.matmul( + normed_vision_outputs, self.mm_input_projection_weight) + return projected_vision_outputs.type_as(vision_outputs) + + +@MULTIMODAL_REGISTRY.register_processor(Gemma3MultiModalProcessor, + info=Gemma3ProcessingInfo, + dummy_inputs=Gemma3DummyInputsBuilder) +class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = config + self.quant_config = quant_config + self.multimodal_config = multimodal_config + self.sliding_window = config.text_config.interleaved_sliding_window + + self.vision_tower = SiglipVisionModel(config.vision_config, + quant_config, + prefix=maybe_prefix( + prefix, "vision_tower")) + self.multi_modal_projector = Gemma3MultiModalProjector(config) + + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + architectures=["Gemma3ForCausalLM"], + ) + logit_scale = getattr(config, "logit_scale", 1.0) + 
self.language_model.logits_processor.scale *= logit_scale + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + @property + def sampler(self): + return self.language_model.sampler + + def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: + h = w = self.config.vision_config.image_size + expected_dims = (3, h, w) + + def _validate_shape(d: torch.Tensor): + if d.shape != expected_dims: + raise ValueError( + "The expected shape of pixel values per image per batch " + f"is {expected_dims}. You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) + + return data + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[Gemma3ImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + assert image_embeds is None, "Gemma3 does not support image_embeds." + if pixel_values is None: + return None + + if not isinstance(pixel_values, (torch.Tensor, list[torch.Tensor])): + raise ValueError("Incorrect type of pixel values. 
" + f"Got type: {type(pixel_values)}") + + pixel_values = flatten_bn(pixel_values, concat=True) + return Gemma3ImagePixelInputs( + type="pixel_values", + data=self._validate_pixel_values(pixel_values), + ) + + def _image_pixels_to_features( + self, + vision_tower: SiglipVisionModel, + pixel_values: torch.Tensor, + ) -> torch.Tensor: + target_dtype = vision_tower.get_input_embeddings().weight.dtype + image_features = vision_tower(pixel_values.to(dtype=target_dtype)) + return image_features + + def _process_image_input( + self, + image_input: Gemma3ImageInputs, + ) -> torch.Tensor: + assert self.vision_tower is not None + pixel_values = image_input["data"] + vision_outputs = self._image_pixels_to_features( + self.vision_tower, + pixel_values, + ) + return self.multi_modal_projector(vision_outputs) + + def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + if multimodal_embeddings is None: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + else: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + self.config.image_token_index) + return inputs_embeds + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object) -> Union[SamplerOutput, IntermediateTensors]: + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. 
+ elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + if vision_embeddings is not None: + kwargs = self.prepare_attn_masks( + input_ids, + positions, + mask_dtype=vision_embeddings.dtype, + **kwargs) + input_ids = None + + hidden_states = self.language_model.model(input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds, + **kwargs) + + return hidden_states + + def prepare_attn_masks( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + mask_dtype: torch.dtype, + **kwargs, + ): + kwargs["has_images"] = True + # NOTE(woosuk): Here, we distinguish the sequences by the position id 0. + # This is a HACK. Fix this. + start_idices = (positions == 0).cpu().nonzero() + num_seqs = len(start_idices) + seq_lens = [] + for i in range(num_seqs): + start_idx = start_idices[i].item() + if i < num_seqs - 1: + end_idx = start_idices[i + 1].item() + else: + end_idx = len(input_ids) + seq_lens.append(end_idx - start_idx) + kwargs["seq_lens"] = seq_lens + + global_attn_masks = [] + local_attn_masks = [] + start_idx = 0 + for seq_len in seq_lens: + end_idx = start_idx + seq_len + input_token_ids = input_ids[start_idx:end_idx] + start_idx = end_idx + # Create a global causal mask. + global_attn_mask = torch.empty( + 1, + 1, + seq_len, + seq_len, + dtype=mask_dtype, + device=input_ids.device, + ) + global_attn_mask.fill_(float("-inf")) + # Fill the lower triangle with 0. + global_attn_mask = global_attn_mask.triu(diagonal=1) + + # Consider the bidirectional attention between image tokens. + img_mask = torch.zeros_like(global_attn_mask) + img_pos = (input_token_ids == self.config.image_token_index) + img_mask[:, :, :, img_pos] += 1 + img_mask[:, :, img_pos, :] += 1 + global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask) + global_attn_masks.append(global_attn_mask) + + # Create a local causal mask with sliding window (1024). 
+ local_attn_mask = torch.ones_like(global_attn_mask) + local_attn_mask = torch.tril(local_attn_mask, + diagonal=-self.sliding_window) + local_attn_mask = torch.where(local_attn_mask == 0, + global_attn_mask, float("-inf")) + local_attn_masks.append(local_attn_mask) + kwargs["global_attn_masks"] = global_attn_masks + kwargs["local_attn_masks"] = local_attn_masks + return kwargs + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + return self.language_model.sample(logits, sampling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 74160e2d9ee40..5dd3aa2973cd9 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -53,6 +53,7 @@ _TEXT_GENERATION_MODELS = { "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"), "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"), "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"), + "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"), "GlmForCausalLM": ("glm", "GlmForCausalLM"), "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"), "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"), @@ -161,6 +162,7 @@ _MULTIMODAL_MODELS = { "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"), # noqa: E501 "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"), "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"), + "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"), # noqa: E501 "GLM4VForCausalLM": ("glm4v", 
"GLM4VForCausalLM"), "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"), "InternVLChatModel": ("internvl", "InternVLChatModel"), From 4a754fcf15993c53309148ababa230df870aa47b Mon Sep 17 00:00:00 2001 From: ameyanjarlekar <40833548+ameyanjarlekar@users.noreply.github.com> Date: Wed, 12 Mar 2025 08:50:49 -0700 Subject: [PATCH 4/9] [Bugfix] Missing thumbnail from NVLM-D processor (#14633) Signed-off-by: ameyanjarlekar --- vllm/model_executor/models/nvlm_d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/nvlm_d.py b/vllm/model_executor/models/nvlm_d.py index 1e1760491a974..0f5cbf082d9d4 100644 --- a/vllm/model_executor/models/nvlm_d.py +++ b/vllm/model_executor/models/nvlm_d.py @@ -45,7 +45,7 @@ class NVLMProcessor(BaseInternVLProcessor): raise NotImplementedError("Embedding inputs are not supported") tile_pos_identifiers = [f"" for i in range(1, num_patches)] - if self.use_thumbnail and num_patches != 1: + if self.use_thumbnail: tile_pos_identifiers += [""] context_size = feature_size // num_patches From d9f83d62068baa8ac4923cdb45f7945c7d6651a3 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Wed, 12 Mar 2025 08:51:20 -0700 Subject: [PATCH 5/9] [ROCm] Enable chunked prefill/paged attention in MLA on ROCm (#14316) Signed-off-by: Sage Moore --- vllm/attention/backends/mla/common.py | 18 ++---------------- vllm/config.py | 4 ++-- 2 files changed, 4 insertions(+), 18 deletions(-) diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index e912b1e9757a5..fc5f3420e394d 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -1327,21 +1327,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): [0, q.shape[-1] - v.shape[-1]], value=0) - if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN: - attn_output, attn_softmax_lse = self.triton_fa_func( - q, - k, - v_padded, - None, - prefill_metadata.query_start_loc, - prefill_metadata.context_chunk_cu_seq_lens[i], - 
prefill_metadata.max_query_len, - prefill_metadata.context_chunk_max_seq_lens[i], - False, # causal - self.scale, - None, # attn_mask is None unless applying ALiBi mask - ) - elif is_vllm_fa: + if is_vllm_fa: attn_output, attn_softmax_lse = self.flash_attn_varlen_func( q=q, k=k, @@ -1416,7 +1402,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], value=0) - if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN: + if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN and not has_context: output = self.triton_fa_func( q, k, diff --git a/vllm/config.py b/vllm/config.py index 2ee45f1837c4e..b61d1a22c8a08 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3450,9 +3450,9 @@ class VllmConfig: self.compilation_config.level = CompilationLevel.NO_COMPILATION if self.model_config and self.model_config.use_mla and \ - not current_platform.is_cuda(): + not (current_platform.is_cuda() or current_platform.is_rocm()): logger.info( - "MLA is enabled on a non-cuda platform; forcing chunked " + "MLA is enabled on a non-GPU platform; forcing chunked " "prefill and prefix caching to be disabled.") self.scheduler_config.enable_chunked_prefill = False self.scheduler_config.chunked_prefill_enabled = False From 916836bbfb7ef52077e78602439a944298dbf886 Mon Sep 17 00:00:00 2001 From: TJian Date: Thu, 13 Mar 2025 00:31:19 +0800 Subject: [PATCH 6/9] [FEAT] [ROCm] [Embedding] Add encoder-only model support into ROCm Flash Attention to enable embedding models. 
(#14664) Signed-off-by: tjtanaa --- CMakeLists.txt | 4 + csrc/moe/torch_bindings.cpp | 1 + .../embedding/language/test_cls_models.py | 17 ++- .../embedding/language/test_embedding.py | 13 +- .../models/embedding/language/test_gritlm.py | 4 +- .../vision_language/test_llava_next.py | 17 +++ vllm/attention/backends/rocm_flash_attn.py | 112 +++++++++++------- 7 files changed, 118 insertions(+), 50 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index ea6d52379499e..5baa39b6f9e59 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -561,6 +561,10 @@ set(VLLM_MOE_EXT_SRC "csrc/moe/moe_align_sum_kernels.cu" "csrc/moe/topk_softmax_kernels.cu") +if(VLLM_GPU_LANG STREQUAL "CUDA") + list(APPEND VLLM_MOE_EXT_SRC "csrc/moe/moe_wna16.cu") +endif() + set_gencode_flags_for_srcs( SRCS "${VLLM_MOE_EXT_SRC}" CUDA_ARCHS "${CUDA_ARCHS}") diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 957ac765290c6..718418e6cd497 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -52,6 +52,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "int moe_block_size, bool replicate_input, bool apply_weights)" " -> Tensor"); // conditionally compiled so impl registration is in source file + #endif } diff --git a/tests/models/embedding/language/test_cls_models.py b/tests/models/embedding/language/test_cls_models.py index b0420ff5cc78c..c6155da50b585 100644 --- a/tests/models/embedding/language/test_cls_models.py +++ b/tests/models/embedding/language/test_cls_models.py @@ -7,6 +7,8 @@ import pytest import torch from transformers import AutoModelForSequenceClassification +from vllm.platforms import current_platform + @pytest.mark.parametrize( "model", @@ -15,14 +17,21 @@ from transformers import AutoModelForSequenceClassification marks=[pytest.mark.core_model, pytest.mark.cpu_model]), ], ) -@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("dtype", + ["half"] if current_platform.is_rocm() else ["float"]) def 
test_classification_models( hf_runner, vllm_runner, example_prompts, model: str, dtype: str, + monkeypatch, ) -> None: + if current_platform.is_rocm(): + # ROCm Triton FA does not currently support sliding window attention + # switch to use ROCm CK FA backend + monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False") + with vllm_runner(model, dtype=dtype) as vllm_model: vllm_outputs = vllm_model.classify(example_prompts) @@ -43,4 +52,8 @@ def test_classification_models( hf_output = torch.tensor(hf_output) vllm_output = torch.tensor(vllm_output) - assert torch.allclose(hf_output, vllm_output, 1e-3) + # the tolerance value of 1e-2 is selected based on the + # half datatype tests in + # tests/models/embedding/language/test_embedding.py + assert torch.allclose(hf_output, vllm_output, + 1e-3 if dtype == "float" else 1e-2) diff --git a/tests/models/embedding/language/test_embedding.py b/tests/models/embedding/language/test_embedding.py index a8ac70d58e6ea..6c28ee91a50ad 100644 --- a/tests/models/embedding/language/test_embedding.py +++ b/tests/models/embedding/language/test_embedding.py @@ -6,6 +6,7 @@ Run `pytest tests/models/embedding/language/test_embedding.py`. 
import pytest from vllm.config import PoolerConfig +from vllm.platforms import current_platform from ..utils import check_embeddings_close @@ -18,15 +19,15 @@ from ..utils import check_embeddings_close marks=[pytest.mark.core_model, pytest.mark.cpu_model]), pytest.param("sentence-transformers/all-MiniLM-L12-v2"), pytest.param("intfloat/multilingual-e5-small"), + pytest.param("Alibaba-NLP/gte-Qwen2-7B-instruct"), # [Decoder-only] pytest.param("BAAI/bge-multilingual-gemma2", marks=[pytest.mark.core_model]), pytest.param("intfloat/e5-mistral-7b-instruct", marks=[pytest.mark.core_model, pytest.mark.cpu_model]), pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct"), - pytest.param("Alibaba-NLP/gte-Qwen2-7B-instruct"), pytest.param("ssmits/Qwen2-7B-Instruct-embed-base"), - # [Encoder-decoder] + # [Cross-Encoder] pytest.param("sentence-transformers/stsb-roberta-base-v2"), ], ) @@ -37,11 +38,19 @@ def test_models( example_prompts, model, dtype: str, + monkeypatch, ) -> None: + + if model == "BAAI/bge-multilingual-gemma2" and current_platform.is_rocm(): + # ROCm Triton FA does not currently support sliding window attention + # switch to use ROCm CK FA backend + monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False") + vllm_extra_kwargs = {} if model == "ssmits/Qwen2-7B-Instruct-embed-base": vllm_extra_kwargs["override_pooler_config"] = \ PoolerConfig(pooling_type="MEAN") + if model == "Alibaba-NLP/gte-Qwen2-1.5B-instruct": vllm_extra_kwargs["hf_overrides"] = {"is_causal": True} diff --git a/tests/models/embedding/language/test_gritlm.py b/tests/models/embedding/language/test_gritlm.py index 470dc04107764..cae3e1a5c6244 100644 --- a/tests/models/embedding/language/test_gritlm.py +++ b/tests/models/embedding/language/test_gritlm.py @@ -15,8 +15,8 @@ import vllm.config from ....utils import RemoteOpenAIServer # GritLM embedding implementation is only supported by XFormers backend. 
-pytest.mark.skipif(not importlib.util.find_spec("xformers"), - reason="GritLM requires XFormers") +pytestmark = pytest.mark.skipif(not importlib.util.find_spec("xformers"), + reason="GritLM requires XFormers") MODEL_NAME = "parasail-ai/GritLM-7B-vllm" MAX_MODEL_LEN = 4000 diff --git a/tests/models/embedding/vision_language/test_llava_next.py b/tests/models/embedding/vision_language/test_llava_next.py index 4c2fbd526ed1e..8b9a856d005e4 100644 --- a/tests/models/embedding/vision_language/test_llava_next.py +++ b/tests/models/embedding/vision_language/test_llava_next.py @@ -4,10 +4,27 @@ import pytest import torch.nn.functional as F from transformers import AutoModelForVision2Seq +from vllm.platforms import current_platform + from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner from ....utils import large_gpu_test from ..utils import check_embeddings_close +# Llava Next embedding implementation is only supported by CUDA. +# If run on ROCm, hf_model.model.resize_token_embeddings will +# cause the following error: +# RuntimeError: Calling torch.linalg.cholesky on a CUDA tensor +# requires compiling PyTorch with MAGMA. Please use PyTorch +# built with MAGMA support. +# If run on CPU, hf_model.model.resize_token_embeddings will +# cause the following error: +# RuntimeError: Calling torch.linalg.cholesky on a CPU tensor +# requires compiling PyTorch with LAPACK. Please use PyTorch +# built with LAPACK support. 
+pytestmark = pytest.mark.skipif( + not current_platform.is_cuda(), + reason="Llava Next model uses op that is only supported in CUDA") + llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n' # noqa: E501 HF_TEXT_PROMPTS = [ diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 02a2a48fe8593..c47202099ac60 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 """Attention layer ROCm GPUs.""" +import itertools from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type @@ -342,28 +343,27 @@ def _get_seq_len_block_table_args( Decoder attn -> select entirely decoder self-attention-related fields Encoder/decoder cross-attn -> select encoder sequence lengths Encoder attn -> select encoder sequence lengths fields + Encoder-only attn -> select prefill sequence lengths with + bidirectional attention Arguments: * attn_metadata: Attention metadata structure associated with attention op * attn_type: encoder attention, decoder self-attention, - encoder/decoder cross-attention + encoder/decoder cross-attention, encoder-only Returns: * Appropriate sequence-lengths tensors for query and key * Appropriate max sequence-length scalar + * Causal masking flag ''' - partial_prefix_sum = 0 if attn_type == AttentionType.ENCODER: assert attn_metadata.encoder_seq_lens is not None assert attn_metadata.encoder_seq_lens_tensor is not None query_seq_start_loc = torch.tensor( - [0] + [ - partial_prefix_sum := partial_prefix_sum + i - for i in attn_metadata.encoder_seq_lens - ], + list(itertools.accumulate([0] + attn_metadata.encoder_seq_lens)), device=attn_metadata.encoder_seq_lens_tensor.device, dtype=attn_metadata.encoder_seq_lens_tensor.dtype) causal_mask = False @@ -372,16 +372,29 @@ def 
_get_seq_len_block_table_args( return (query_seq_start_loc, attn_metadata.max_encoder_seq_len, query_seq_start_loc, attn_metadata.max_encoder_seq_len, attn_metadata.encoder_seq_lens, causal_mask) + + elif attn_type == AttentionType.ENCODER_ONLY: + # For encoder-only models, we use the prefill sequence lengths + assert attn_metadata.seq_lens is not None + assert attn_metadata.seq_lens_tensor is not None + query_seq_start_loc = torch.tensor( + list(itertools.accumulate([0] + attn_metadata.seq_lens)), + device=attn_metadata.seq_lens_tensor.device, + dtype=attn_metadata.seq_lens_tensor.dtype) + max_seq_len = attn_metadata.max_prefill_seq_len + # Encoder-only models typically use bidirectional attention + causal_mask = False + + return (query_seq_start_loc, max_seq_len, query_seq_start_loc, + max_seq_len, attn_metadata.seq_lens, causal_mask) + elif attn_type == AttentionType.DECODER: # Decoder self-attention # Choose max_seq_len based on whether we are in prompt_run assert attn_metadata.seq_lens is not None assert attn_metadata.seq_lens_tensor is not None query_seq_start_loc = torch.tensor( - [0] + [ - partial_prefix_sum := partial_prefix_sum + i - for i in attn_metadata.seq_lens - ], + list(itertools.accumulate([0] + attn_metadata.seq_lens)), device=attn_metadata.seq_lens_tensor.device, dtype=attn_metadata.seq_lens_tensor.dtype) max_seq_len = attn_metadata.max_prefill_seq_len @@ -393,21 +406,14 @@ def _get_seq_len_block_table_args( assert attn_metadata.seq_lens is not None assert attn_metadata.encoder_seq_lens_tensor is not None query_start_loc = torch.tensor( - [0] + [ - partial_prefix_sum := partial_prefix_sum + i - for i in attn_metadata.seq_lens - ], + list(itertools.accumulate([0] + attn_metadata.seq_lens)), device=attn_metadata.encoder_seq_lens_tensor.device, dtype=attn_metadata.encoder_seq_lens_tensor.dtype) - partial_prefix_sum = 0 assert attn_metadata.encoder_seq_lens is not None assert attn_metadata.seq_lens_tensor is not None key_seq_start_loc = 
torch.tensor( - [0] + [ - partial_prefix_sum := partial_prefix_sum + i - for i in attn_metadata.encoder_seq_lens - ], + list(itertools.accumulate([0] + attn_metadata.encoder_seq_lens)), device=attn_metadata.seq_lens_tensor.device, dtype=attn_metadata.seq_lens_tensor.dtype) causal_mask = False @@ -584,6 +590,8 @@ class ROCmFlashAttentionImpl(AttentionImpl): will match encoder sequence lengths, pass encoder sequence attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/ max_encoder_seq_len) + * ENCODER_ONLY: bidirectional attention with no KV caching; + use prefill sequence attributes Args: query: shape = [num_tokens, num_heads * head_size] @@ -608,7 +616,11 @@ class ROCmFlashAttentionImpl(AttentionImpl): else: assert value is None - if self.attn_type != AttentionType.ENCODER and kv_cache.numel() > 0: + # Only update KV cache for decoder self-attention + # and encoder-decoder cross-attention + if self.attn_type not in [ + AttentionType.ENCODER, AttentionType.ENCODER_ONLY + ] and kv_cache.numel() > 0: key_cache, value_cache = PagedAttention.split_kv_cache( kv_cache, self.num_kv_heads, self.head_size) @@ -632,6 +644,9 @@ class ROCmFlashAttentionImpl(AttentionImpl): if self.attn_type != AttentionType.ENCODER: num_prefill_tokens = attn_metadata.num_prefill_tokens + elif self.attn_type == AttentionType.ENCODER_ONLY: + # For encoder-only models, all tokens are processed in one go + num_prefill_tokens = query.shape[0] else: assert attn_metadata.num_encoder_tokens is not None num_prefill_tokens = attn_metadata.num_encoder_tokens @@ -642,8 +657,13 @@ class ROCmFlashAttentionImpl(AttentionImpl): # QKV for prefill. 
query = query[:num_prefill_tokens] + # For encoder-only and encoder models, + # we process all tokens at once + # For decoder and encoder-decoder, + # we may need to limit key/value to prefill tokens if key is not None and value is not None \ - and self.attn_type != AttentionType.ENCODER_DECODER: + and self.attn_type not in [AttentionType.ENCODER_DECODER, + AttentionType.ENCODER_ONLY]: key = key[:num_prefill_tokens] value = value[:num_prefill_tokens] @@ -678,7 +698,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): self.alibi_slopes, query.dtype, seq_lens, - make_attn_mask=False) # type: ignore + make_attn_mask=causal_mask) # type: ignore out, _ = self.attn_func( query, key, @@ -703,7 +723,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): self.alibi_slopes, query.dtype, attn_metadata.seq_lens, - make_attn_mask=True) # type: ignore + make_attn_mask=causal_mask) # type: ignore query = query.movedim(0, query.dim() - 2) key = key.movedim(0, key.dim() - 2) value = value.movedim(0, value.dim() - 2) @@ -729,7 +749,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): max_seqlen_q=prefill_meta.max_prefill_seq_len, max_seqlen_k=key_max_seq_len, softmax_scale=self.scale, - causal=True, + causal=causal_mask, window_size=self.sliding_window, alibi_slopes=self.alibi_slopes, softcap=self.logits_soft_cap, @@ -742,25 +762,29 @@ class ROCmFlashAttentionImpl(AttentionImpl): else: output = out else: - # prefix-enabled attention - output[:num_prefill_tokens] = PagedAttention.forward_prefix( - query, - key, - value, - self.kv_cache_dtype, - key_cache, - value_cache, - prefill_meta.block_tables, - prefill_meta.query_start_loc, - prefill_meta.seq_lens_tensor, - prefill_meta.max_query_len, - self.alibi_slopes, - self.sliding_window[0], - layer._k_scale, - layer._v_scale, - ) - - if decode_meta := attn_metadata.decode_metadata: + # prefix-enabled attention - + # not applicable for encoder-only models + if self.attn_type != AttentionType.ENCODER_ONLY: + output[: + num_prefill_tokens] = 
PagedAttention.forward_prefix( + query, + key, + value, + self.kv_cache_dtype, + key_cache, + value_cache, + prefill_meta.block_tables, + prefill_meta.query_start_loc, + prefill_meta.seq_lens_tensor, + prefill_meta.max_query_len, + self.alibi_slopes, + self.sliding_window[0], + layer._k_scale, + layer._v_scale, + ) + # Skip decode phase for encoder-only models + if (decode_meta := attn_metadata.decode_metadata) and ( + self.attn_type != AttentionType.ENCODER_ONLY): # Decoding run. # Whether to use rocm custom paged attention or not num_seqs, num_heads, head_size = decode_query.shape @@ -885,4 +909,4 @@ def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int, and (qtype == torch.half or qtype == torch.bfloat16) and (head_size == 64 or head_size == 128) and (block_size == 16 or block_size == 32) - and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768) \ No newline at end of file + and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768) From f5d3acd47466f094beb36f7a5d05520466713f93 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Wed, 12 Mar 2025 13:29:48 -0400 Subject: [PATCH 7/9] [BugFix][V1] Fix parallel sampling finishing/aborts (#14512) Signed-off-by: Nick Hill --- tests/v1/engine/test_async_llm.py | 56 ++++++++++++++-- .../v1/entrypoints/openai/test_completion.py | 21 ++++-- vllm/outputs.py | 64 ++++++------------- vllm/v1/engine/async_llm.py | 3 +- vllm/v1/engine/llm_engine.py | 2 +- vllm/v1/engine/output_processor.py | 49 ++++++++------ vllm/v1/engine/parallel_sampling.py | 55 +++++++--------- 7 files changed, 137 insertions(+), 113 deletions(-) diff --git a/tests/v1/engine/test_async_llm.py b/tests/v1/engine/test_async_llm.py index 0de0026eb2842..5b9725d59ddc5 100644 --- a/tests/v1/engine/test_async_llm.py +++ b/tests/v1/engine/test_async_llm.py @@ -46,6 +46,7 @@ async def generate(engine: AsyncLLM, prompt: PromptType, output_kind: RequestOutputKind, max_tokens: int, + n: int = 1, prompt_logprobs: Optional[int] = None) -> 
tuple[int, str]: # Ensure generate doesn't complete too fast for cancellation test. await asyncio.sleep(0.2) @@ -54,13 +55,15 @@ async def generate(engine: AsyncLLM, sampling_params = SamplingParams(max_tokens=max_tokens, ignore_eos=True, output_kind=output_kind, - temperature=0, + temperature=0.5, + seed=33, + n=n, prompt_logprobs=prompt_logprobs) async for out in engine.generate(request_id=request_id, prompt=prompt, sampling_params=sampling_params): - num_tokens = len(out.outputs[0].token_ids) + num_tokens = sum(len(output.token_ids) for output in out.outputs) if output_kind == RequestOutputKind.DELTA: count += num_tokens else: @@ -136,17 +139,22 @@ async def test_abort(monkeypatch, output_kind: RequestOutputKind, NUM_REQUESTS = 100 NUM_EXPECTED_TOKENS = 100 + NUM_EXPECTED_TOKENS_LONG = 50000 REQUEST_IDS_TO_ABORT = range(1, 100, 10) + PARALLEL_SAMPLE_REQ_IDS = range(1, 100, 15) request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)] # Create concurrent requests. tasks: list[asyncio.Task] = [] - for request_id in request_ids: + for idx, request_id in enumerate(request_ids): + max_tokens = NUM_EXPECTED_TOKENS_LONG if ( + idx in REQUEST_IDS_TO_ABORT) else NUM_EXPECTED_TOKENS + n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1 tasks.append( asyncio.create_task( generate(engine, request_id, prompt, output_kind, - NUM_EXPECTED_TOKENS))) + max_tokens, n))) # API server cancels requests when they disconnect. for idx in REQUEST_IDS_TO_ABORT: @@ -162,10 +170,13 @@ async def test_abort(monkeypatch, output_kind: RequestOutputKind, else: # Otherwise, make sure the request was not impacted. 
num_generated_tokens, request_id = await task - assert num_generated_tokens == NUM_EXPECTED_TOKENS, ( + n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1 + expected_tokens = NUM_EXPECTED_TOKENS * n + assert num_generated_tokens == expected_tokens, ( f"{request_id} generated {num_generated_tokens} but " - f"expected {NUM_EXPECTED_TOKENS}") + f"expected {expected_tokens}") + # Make sure all aborted requests were really aborted. assert not engine.output_processor.has_unfinished_requests() # Confirm we can do another generation. @@ -176,3 +187,36 @@ async def test_abort(monkeypatch, output_kind: RequestOutputKind, num_generated_tokens, request_id = await task assert num_generated_tokens == NUM_EXPECTED_TOKENS assert not engine.output_processor.has_unfinished_requests() + + +@pytest.mark.parametrize("n", [1, 3]) +@pytest.mark.parametrize("engine_args_and_prompt", + [(TEXT_ENGINE_ARGS, TEXT_PROMPT), + (VISION_ENGINE_ARGS, VISION_PROMPT)]) +@pytest.mark.asyncio +async def test_finished_flag(monkeypatch, n: int, + engine_args_and_prompt: tuple[AsyncEngineArgs, + PromptType]): + + with monkeypatch.context() as m, ExitStack() as after: + m.setenv("VLLM_USE_V1", "1") + engine_args, prompt = engine_args_and_prompt + + engine = AsyncLLM.from_engine_args(engine_args) + after.callback(engine.shutdown) + + sampling_params = SamplingParams(max_tokens=100, + output_kind=RequestOutputKind.DELTA, + temperature=1.0, + seed=33, + n=n) + outputs = [ + out + async for out in engine.generate(request_id="request-33", + prompt=prompt, + sampling_params=sampling_params) + ] + + # Assert only the last output has the finished flag set + assert all(not out.finished for out in outputs[:-1]) + assert outputs[-1].finished diff --git a/tests/v1/entrypoints/openai/test_completion.py b/tests/v1/entrypoints/openai/test_completion.py index 171c84176eae7..57ca99e1f68c6 100644 --- a/tests/v1/entrypoints/openai/test_completion.py +++ b/tests/v1/entrypoints/openai/test_completion.py @@ -263,15 +263,16 @@ async 
def test_parallel_no_streaming(client: openai.AsyncOpenAI, prompt = "What is an LLM?" n = 3 - max_tokens = 5 + max_tokens = 50 # we want some to finish earlier than others # High temperature to maximize chance of unique completions. completion = await client.completions.create(model=model_name, prompt=prompt, max_tokens=max_tokens, n=n, - temperature=0.95, + temperature=1.0, stream=False, + logprobs=0, seed=42) # Assert `n` completions @@ -279,6 +280,7 @@ async def test_parallel_no_streaming(client: openai.AsyncOpenAI, assert num_completions == n, ( f"Num completions {num_completions} but expected {n}.") completion_repeats: dict[str, int] = {} + output_token_lengths = set() for idx, choice in enumerate(completion.choices): # Assert correct completion index & some finish reason. assert choice.index == idx, ( @@ -287,6 +289,9 @@ async def test_parallel_no_streaming(client: openai.AsyncOpenAI, "None finish_reason is invalid.") text = choice.text completion_repeats[text] = completion_repeats.get(text, 0) + 1 + output_token_lengths.add(len(choice.logprobs.tokens)) + # Assert subrequests finished at different times + assert len(output_token_lengths) > 1 # Assert `n` unique completions num_unique = len(completion_repeats) if num_unique != n: @@ -312,16 +317,16 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str): prompt = "What is an LLM?" 
n = 3 - max_tokens = 5 + max_tokens = 50 # we want some to finish earlier than others stream = await client.completions.create(model=model_name, prompt=prompt, max_tokens=max_tokens, n=n, - temperature=0.95, + temperature=1.0, stream=True, seed=42) - chunks: list[list[str]] = [[] for i in range(n)] + chunks: list[list[str]] = [[] for _ in range(n)] finish_reason_count = 0 async for chunk in stream: index = chunk.choices[0].index @@ -333,14 +338,18 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str): assert finish_reason_count == n, ( f"Expected {n} completions with valid indices and finish_reason.") completion_repeats: dict[str, int] = {} + chunk_lengths = set() for chunk in chunks: chunk_len = len(chunk) # Assert correct number of completion tokens - assert chunk_len == max_tokens, ( + chunk_lengths.add(chunk_len) + assert chunk_len <= max_tokens, ( f"max_tokens={max_tokens} but chunk len is {chunk_len}.") text = "".join(chunk) completion_repeats[text] = completion_repeats.get(text, 0) + 1 print(text) + # Assert subrequests finished at different times + assert len(chunk_lengths) > 1 # Assert `n` unique completions num_unique = len(completion_repeats) if num_unique != n: diff --git a/vllm/outputs.py b/vllm/outputs.py index 8c355c89e3e9b..7a20c340edcf7 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -134,57 +134,29 @@ class RequestOutput: self.encoder_prompt_token_ids = encoder_prompt_token_ids self.num_cached_tokens = num_cached_tokens - @classmethod - def new( - cls, - request_id: str, - prompt: Optional[str], - prompt_token_ids: Optional[list[int]], - text: str, - token_ids: list[int], - logprobs: Optional[SampleLogprobs], - prompt_logprobs: Optional[PromptLogprobs], - cumulative_logprob: Optional[float], - finished: bool = False, - ) -> "RequestOutput": - """Initialize a new RequestOutput object.""" - - # TODO: Support `n` > 1. 
- completion_output = CompletionOutput( - index=0, - text=text, - token_ids=token_ids, - cumulative_logprob=cumulative_logprob, - logprobs=logprobs) - - return RequestOutput( - request_id=request_id, - prompt=prompt, - prompt_token_ids=prompt_token_ids, - prompt_logprobs=prompt_logprobs, - outputs=[completion_output], - finished=finished, - ) - def add(self, next_output: "RequestOutput") -> None: """Merge subsequent RequestOutput into this one""" - self.prompt = next_output.prompt - self.prompt_token_ids = next_output.prompt_token_ids - self.prompt_logprobs = next_output.prompt_logprobs self.finished |= next_output.finished - #TODO assuming n == 1 for now - completion = self.outputs[0] - next_completion = next_output.outputs[0] - completion.text += next_completion.text - if not isinstance(completion.token_ids, MutableSequence): - completion.token_ids = list(completion.token_ids) - completion.token_ids.extend(next_completion.token_ids) - if next_completion.logprobs: - assert completion.logprobs is not None - completion.logprobs.extend(next_completion.logprobs) - completion.cumulative_logprob = next_completion.cumulative_logprob + for next_completion in next_output.outputs: + for completion in self.outputs: + if completion.index == next_completion.index: + # Merge outputs with same index + completion.text += next_completion.text + if not isinstance(completion.token_ids, MutableSequence): + completion.token_ids = list(completion.token_ids) + completion.token_ids.extend(next_completion.token_ids) + if next_completion.logprobs: + assert completion.logprobs is not None + completion.logprobs.extend(next_completion.logprobs) + completion.cumulative_logprob = ( + next_completion.cumulative_logprob) + completion.finish_reason = next_completion.finish_reason + completion.stop_reason = next_completion.stop_reason + break + else: + self.outputs.append(next_completion) @classmethod def from_seq_group( diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 
3dc513a728339..05633352be6c0 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -298,9 +298,8 @@ class AsyncLLM(EngineClient): async def abort(self, request_id: str) -> None: """Abort RequestId in OutputProcessor and EngineCore.""" - request_ids = [request_id] + request_ids = self.output_processor.abort_requests((request_id, )) await self.engine_core.abort_requests_async(request_ids) - self.output_processor.abort_requests(request_ids) if self.log_requests: logger.info("Aborted request %s.", request_id) diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 213faaa451605..d56aee1accc2d 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -137,8 +137,8 @@ class LLMEngine: def abort_request(self, request_ids: list[str]) -> None: """Remove request_ids from EngineCore and Detokenizer.""" + request_ids = self.output_processor.abort_requests(request_ids) self.engine_core.abort_requests(request_ids) - self.output_processor.abort_requests(request_ids) def add_request( self, diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index aea526188a8f5..83180b66bea0d 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import asyncio +from collections.abc import Iterable from dataclasses import dataclass from typing import Optional, Union @@ -102,8 +103,7 @@ class RequestState: ) -> Optional[RequestOutput]: finished = finish_reason is not None - output_kind = self.output_kind - final_only = output_kind == RequestOutputKind.FINAL_ONLY + final_only = self.output_kind == RequestOutputKind.FINAL_ONLY # In follow up, we will switch to invariant where EngineCore # does not stream partial prefills. @@ -111,24 +111,24 @@ class RequestState: # Only the final output is required in FINAL_ONLY mode. 
return None - def new_request_output(request_id: str) -> RequestOutput: - return self._new_request_output(request_id, finished) - completion_output = self._new_completion_output( new_token_ids, finish_reason, stop_reason) - if self.parent_req is not None: - return self.parent_req.make_request_output(final_only, - completion_output, - new_request_output) + request_id = self.request_id + if self.parent_req is None: + outputs = [completion_output] + else: + request_id, outputs, finished = self.parent_req.get_outputs( + request_id, completion_output) + if not outputs: + return None - request_output = new_request_output(self.request_id) - request_output.outputs.append(completion_output) - return request_output + return self._new_request_output(request_id, outputs, finished) def _new_request_output( self, request_id: str, + outputs: list[CompletionOutput], finished: bool, ) -> RequestOutput: @@ -143,7 +143,7 @@ class RequestState: prompt=self.prompt, prompt_token_ids=self.prompt_token_ids, prompt_logprobs=prompt_logprobs, - outputs=[], + outputs=outputs, finished=finished, ) @@ -188,6 +188,7 @@ class OutputProcessor: self.log_stats = log_stats self.tokenizer = tokenizer self.request_states: dict[str, RequestState] = {} + self.parent_requests: dict[str, ParentRequest] = {} self.lora_states = LoRARequestStates() def get_num_unfinished_requests(self): @@ -198,14 +199,20 @@ class OutputProcessor: def abort_requests( self, - request_ids: list[str], - ) -> None: + request_ids: Iterable[str], + ) -> list[str]: + request_ids_to_abort = [] for request_id in request_ids: req_state = self.request_states.pop(request_id, None) if req_state is not None: self.lora_states.abort_request(req_state) - if req_state.parent_req is not None: - req_state.parent_req.finish_child_request(request_id) + request_ids_to_abort.append(request_id) + else: + parent = self.parent_requests.pop(request_id, None) + if parent and parent.child_requests: + self.abort_requests(parent.child_requests) + 
request_ids_to_abort.extend(parent.child_requests) + return request_ids_to_abort def add_request( self, @@ -227,6 +234,8 @@ class OutputProcessor: log_stats=self.log_stats) self.request_states[request_id] = req_state self.lora_states.add_request(req_state) + if parent_req: + self.parent_requests[parent_req.request_id] = parent_req def process_outputs( self, @@ -314,12 +323,14 @@ class OutputProcessor: # Free completed requests. if finish_reason is not None: self.request_states.pop(req_id) + # Remove parent request if applicable. + parent_req = req_state.parent_req + if parent_req and not parent_req.child_requests: + self.parent_requests.pop(parent_req.request_id, None) if not engine_core_output.finished: # If req not finished in EngineCore, but Detokenizer # detected stop string, abort needed in EngineCore. reqs_to_abort.append(req_id) - if req_state.parent_req is not None: - req_state.parent_req.finish_child_request(req_id) # Track per-request stats self._update_stats_from_finished(req_state, finish_reason, diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index 4e2c78173b513..0eeca657406e5 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -1,11 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 from copy import copy -from typing import Callable, Optional, Union +from typing import Optional, Union -from vllm.outputs import CompletionOutput, RequestOutput +from vllm.outputs import CompletionOutput from vllm.pooling_params import PoolingParams -from vllm.sampling_params import SamplingParams +from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.v1.metrics.stats import IterationStats @@ -23,7 +23,7 @@ class ParentRequest: child_requests: set[str] # To aggregate child completions when not streaming - output_aggregator: Optional[RequestOutput] + output_aggregator: list[CompletionOutput] # To find the max number of generated tokens across all children 
max_num_generation_tokens: int @@ -37,7 +37,9 @@ class ParentRequest: self.sampling_params = sampling_params self.child_requests = set() - self.output_aggregator = None + self.output_aggregator = [None] * sampling_params.n if ( + sampling_params.output_kind + == RequestOutputKind.FINAL_ONLY) else [] self.max_num_generation_tokens = 0 self.cached_child_sampling_params = None @@ -93,43 +95,30 @@ class ParentRequest: """ child_req_id = f"{index}_{self.request_id}" self.child_requests.add(child_req_id) - return (child_req_id, self._get_child_sampling_params(index)) - - def finish_child_request(self, req_id: str): - self.child_requests.remove(req_id) + return child_req_id, self._get_child_sampling_params(index) @property def n(self) -> int: return self.sampling_params.n - def make_request_output( + def get_outputs( self, - final_only: bool, + child_request_id: str, completion_output: CompletionOutput, - new_request_output: Callable[[str], RequestOutput], - ) -> Optional[RequestOutput]: - # Use an existing RequestOutput if we're aggregating - request_output = self.output_aggregator + ) -> tuple[str, list[CompletionOutput], bool]: + if completion_output.finished(): + self.child_requests.remove(child_request_id) - # Make new RequestOutput otherwise - if request_output is None: - request_output = new_request_output(self.request_id) + if self.sampling_params.output_kind != RequestOutputKind.FINAL_ONLY: + # If streaming, just return the current output. + outputs = [completion_output] + else: + # If not streaming, aggregate the n final outputs. 
+ self.output_aggregator[completion_output.index] = completion_output + outputs = [] if self.child_requests else self.output_aggregator - # Add a new completion - request_output.outputs.append(completion_output) - - # If not streaming, aggregate until all child requests complete - if final_only and len(request_output.outputs) != self.n: - self.output_aggregator = request_output - return None - - # We're done aggregating - self.output_aggregator = None - - # Parent completion output list must be sorted by index - request_output.outputs = sorted(request_output.outputs, - key=lambda x: x.index) - return request_output + finished = not self.child_requests + return self.request_id, outputs, finished def observe_num_generation_tokens(self, num_generation_tokens: int): self.max_num_generation_tokens = max(num_generation_tokens, From 53be4a863486d02bd96a59c674bbec23eec508f6 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 12 Mar 2025 11:21:19 -0700 Subject: [PATCH 8/9] [V1] Allow sliding window + prefix caching (#13069) Signed-off-by: Woosuk Kwon --- vllm/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index b61d1a22c8a08..aa8b16920a97f 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1146,7 +1146,7 @@ class CacheConfig: if not self.enable_prefix_caching: return - if self.sliding_window is not None: + if self.sliding_window is not None and not envs.VLLM_USE_V1: raise NotImplementedError( "Prefix caching is not supported with sliding window. " "Run with --disable-sliding-window to use prefix caching.") From ce20124671cf4580627089e02f391cc95747939f Mon Sep 17 00:00:00 2001 From: "Kevin H. 
Luu" Date: Wed, 12 Mar 2025 15:35:18 -0700 Subject: [PATCH 9/9] [release] Add force remove for TPU logs (#14697) --- .buildkite/release-pipeline.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml index 9354bdd8a4464..096a1c870c6ba 100644 --- a/.buildkite/release-pipeline.yaml +++ b/.buildkite/release-pipeline.yaml @@ -57,8 +57,8 @@ steps: agents: queue: tpu_queue_postmerge commands: - - "rm /var/log/syslog" - - "rm /var/log/kern.log" + - "rm -f /var/log/syslog" + - "rm -f /var/log/kern.log" - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --tag vllm/vllm-tpu:nightly --tag vllm/vllm-tpu:$BUILDKITE_COMMIT --progress plain -f Dockerfile.tpu ." - "docker push vllm/vllm-tpu:nightly" - "docker push vllm/vllm-tpu:$BUILDKITE_COMMIT"