mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-01 01:37:04 +08:00
Merge branch 'main' of https://github.com/neuralmagic/vllm into sage/dbo-full-cudagraphs
This commit is contained in:
commit
e283eff060
@ -82,7 +82,7 @@ steps:
|
|||||||
- bash standalone_tests/python_only_compile.sh
|
- bash standalone_tests/python_only_compile.sh
|
||||||
|
|
||||||
- label: Basic Correctness Test # 30min
|
- label: Basic Correctness Test # 30min
|
||||||
mirror_hardwares: [amdexperimental, amdproduction]
|
mirror_hardwares: [amdexperimental]
|
||||||
fast_check: true
|
fast_check: true
|
||||||
torch_nightly: true
|
torch_nightly: true
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
@ -99,7 +99,7 @@ steps:
|
|||||||
- VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py
|
- VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py
|
||||||
|
|
||||||
- label: Chunked Prefill Test
|
- label: Chunked Prefill Test
|
||||||
mirror_hardwares: [amdexperimental, amdproduction]
|
mirror_hardwares: [amdexperimental]
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/
|
- vllm/
|
||||||
- tests/basic_correctness/test_chunked_prefill
|
- tests/basic_correctness/test_chunked_prefill
|
||||||
@ -108,7 +108,7 @@ steps:
|
|||||||
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
|
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
|
||||||
|
|
||||||
- label: Core Test # 10min
|
- label: Core Test # 10min
|
||||||
mirror_hardwares: [amdexperimental, amdproduction]
|
mirror_hardwares: [amdexperimental]
|
||||||
fast_check: true
|
fast_check: true
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/core
|
- vllm/core
|
||||||
@ -209,7 +209,7 @@ steps:
|
|||||||
- pytest -v -s distributed/test_eplb_execute.py
|
- pytest -v -s distributed/test_eplb_execute.py
|
||||||
|
|
||||||
- label: Metrics, Tracing Test # 10min
|
- label: Metrics, Tracing Test # 10min
|
||||||
mirror_hardwares: [amdexperimental, amdproduction]
|
mirror_hardwares: [amdexperimental]
|
||||||
num_gpus: 2
|
num_gpus: 2
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/
|
- vllm/
|
||||||
@ -228,7 +228,7 @@ steps:
|
|||||||
##### 1 GPU test #####
|
##### 1 GPU test #####
|
||||||
|
|
||||||
- label: Regression Test # 5min
|
- label: Regression Test # 5min
|
||||||
mirror_hardwares: [amdexperimental, amdproduction]
|
mirror_hardwares: [amdexperimental]
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/
|
- vllm/
|
||||||
- tests/test_regression
|
- tests/test_regression
|
||||||
@ -280,7 +280,7 @@ steps:
|
|||||||
- pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine
|
- pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine
|
||||||
|
|
||||||
- label: Examples Test # 25min
|
- label: Examples Test # 25min
|
||||||
mirror_hardwares: [amdexperimental, amdproduction]
|
mirror_hardwares: [amdexperimental]
|
||||||
working_dir: "/vllm-workspace/examples"
|
working_dir: "/vllm-workspace/examples"
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/entrypoints
|
- vllm/entrypoints
|
||||||
@ -305,7 +305,7 @@ steps:
|
|||||||
- VLLM_USE_V1=0 python3 offline_inference/profiling.py --model facebook/opt-125m run_num_steps --num-steps 2
|
- VLLM_USE_V1=0 python3 offline_inference/profiling.py --model facebook/opt-125m run_num_steps --num-steps 2
|
||||||
|
|
||||||
- label: Prefix Caching Test # 9min
|
- label: Prefix Caching Test # 9min
|
||||||
mirror_hardwares: [amdexperimental, amdproduction]
|
mirror_hardwares: [amdexperimental]
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/
|
- vllm/
|
||||||
- tests/prefix_caching
|
- tests/prefix_caching
|
||||||
@ -314,7 +314,7 @@ steps:
|
|||||||
|
|
||||||
|
|
||||||
- label: Platform Tests (CUDA)
|
- label: Platform Tests (CUDA)
|
||||||
mirror_hardwares: [amdexperimental, amdproduction]
|
mirror_hardwares: [amdexperimental]
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/
|
- vllm/
|
||||||
- tests/cuda
|
- tests/cuda
|
||||||
@ -353,9 +353,10 @@ steps:
|
|||||||
- pytest -v -s compile/test_silu_mul_quant_fusion.py
|
- pytest -v -s compile/test_silu_mul_quant_fusion.py
|
||||||
- pytest -v -s compile/test_sequence_parallelism.py
|
- pytest -v -s compile/test_sequence_parallelism.py
|
||||||
- pytest -v -s compile/test_async_tp.py
|
- pytest -v -s compile/test_async_tp.py
|
||||||
|
- pytest -v -s compile/test_fusion_all_reduce.py
|
||||||
|
|
||||||
- label: PyTorch Fullgraph Smoke Test # 9min
|
- label: PyTorch Fullgraph Smoke Test # 9min
|
||||||
mirror_hardwares: [amdexperimental, amdproduction]
|
mirror_hardwares: [amdexperimental]
|
||||||
torch_nightly: true
|
torch_nightly: true
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/
|
- vllm/
|
||||||
@ -368,7 +369,7 @@ steps:
|
|||||||
- pytest -v -s compile/piecewise/test_full_cudagraph.py
|
- pytest -v -s compile/piecewise/test_full_cudagraph.py
|
||||||
|
|
||||||
- label: PyTorch Fullgraph Test # 18min
|
- label: PyTorch Fullgraph Test # 18min
|
||||||
mirror_hardwares: [amdexperimental, amdproduction]
|
mirror_hardwares: [amdexperimental]
|
||||||
torch_nightly: true
|
torch_nightly: true
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/
|
- vllm/
|
||||||
@ -377,7 +378,7 @@ steps:
|
|||||||
- pytest -v -s compile/test_full_graph.py
|
- pytest -v -s compile/test_full_graph.py
|
||||||
|
|
||||||
- label: Kernels Core Operation Test
|
- label: Kernels Core Operation Test
|
||||||
mirror_hardwares: [amdexperimental, amdproduction]
|
mirror_hardwares: [amdexperimental]
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- csrc/
|
- csrc/
|
||||||
- tests/kernels/core
|
- tests/kernels/core
|
||||||
@ -416,7 +417,7 @@ steps:
|
|||||||
parallelism: 2
|
parallelism: 2
|
||||||
|
|
||||||
- label: Kernels Mamba Test
|
- label: Kernels Mamba Test
|
||||||
mirror_hardwares: [amdexperimental, amdproduction]
|
mirror_hardwares: [amdexperimental]
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- csrc/mamba/
|
- csrc/mamba/
|
||||||
- tests/kernels/mamba
|
- tests/kernels/mamba
|
||||||
@ -424,7 +425,7 @@ steps:
|
|||||||
- pytest -v -s kernels/mamba
|
- pytest -v -s kernels/mamba
|
||||||
|
|
||||||
- label: Tensorizer Test # 11min
|
- label: Tensorizer Test # 11min
|
||||||
mirror_hardwares: [amdexperimental, amdproduction]
|
mirror_hardwares: [amdexperimental]
|
||||||
soft_fail: true
|
soft_fail: true
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/model_executor/model_loader
|
- vllm/model_executor/model_loader
|
||||||
@ -437,7 +438,7 @@ steps:
|
|||||||
- pytest -v -s entrypoints/openai/test_tensorizer_entrypoint.py
|
- pytest -v -s entrypoints/openai/test_tensorizer_entrypoint.py
|
||||||
|
|
||||||
- label: Model Executor Test
|
- label: Model Executor Test
|
||||||
mirror_hardwares: [amdexperimental, amdproduction]
|
mirror_hardwares: [amdexperimental]
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/model_executor
|
- vllm/model_executor
|
||||||
- tests/model_executor
|
- tests/model_executor
|
||||||
@ -447,7 +448,7 @@ steps:
|
|||||||
- pytest -v -s model_executor
|
- pytest -v -s model_executor
|
||||||
|
|
||||||
- label: Benchmarks # 9min
|
- label: Benchmarks # 9min
|
||||||
mirror_hardwares: [amdexperimental, amdproduction]
|
mirror_hardwares: [amdexperimental]
|
||||||
working_dir: "/vllm-workspace/.buildkite"
|
working_dir: "/vllm-workspace/.buildkite"
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- benchmarks/
|
- benchmarks/
|
||||||
@ -455,7 +456,7 @@ steps:
|
|||||||
- bash scripts/run-benchmarks.sh
|
- bash scripts/run-benchmarks.sh
|
||||||
|
|
||||||
- label: Benchmarks CLI Test # 10min
|
- label: Benchmarks CLI Test # 10min
|
||||||
mirror_hardwares: [amdexperimental, amdproduction]
|
mirror_hardwares: [amdexperimental]
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/
|
- vllm/
|
||||||
- tests/benchmarks/
|
- tests/benchmarks/
|
||||||
@ -494,7 +495,7 @@ steps:
|
|||||||
- pytest -s entrypoints/openai/correctness/
|
- pytest -s entrypoints/openai/correctness/
|
||||||
|
|
||||||
- label: Encoder Decoder tests # 5min
|
- label: Encoder Decoder tests # 5min
|
||||||
mirror_hardwares: [amdexperimental, amdproduction]
|
mirror_hardwares: [amdexperimental]
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/
|
- vllm/
|
||||||
- tests/encoder_decoder
|
- tests/encoder_decoder
|
||||||
@ -502,7 +503,7 @@ steps:
|
|||||||
- pytest -v -s encoder_decoder
|
- pytest -v -s encoder_decoder
|
||||||
|
|
||||||
- label: OpenAI-Compatible Tool Use # 20 min
|
- label: OpenAI-Compatible Tool Use # 20 min
|
||||||
mirror_hardwares: [amdexperimental, amdproduction]
|
mirror_hardwares: [amdexperimental]
|
||||||
fast_check: false
|
fast_check: false
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/
|
- vllm/
|
||||||
@ -623,7 +624,7 @@ steps:
|
|||||||
|
|
||||||
# This test is used only in PR development phase to test individual models and should never run on main
|
# This test is used only in PR development phase to test individual models and should never run on main
|
||||||
- label: Custom Models Test
|
- label: Custom Models Test
|
||||||
mirror_hardwares: [amdexperimental, amdproduction]
|
mirror_hardwares: [amdexperimental]
|
||||||
optional: true
|
optional: true
|
||||||
commands:
|
commands:
|
||||||
- echo 'Testing custom models...'
|
- echo 'Testing custom models...'
|
||||||
@ -658,7 +659,7 @@ steps:
|
|||||||
##### multi gpus test #####
|
##### multi gpus test #####
|
||||||
|
|
||||||
- label: Distributed Comm Ops Test # 7min
|
- label: Distributed Comm Ops Test # 7min
|
||||||
mirror_hardwares: [amdexperimental, amdproduction]
|
mirror_hardwares: [amdexperimental]
|
||||||
working_dir: "/vllm-workspace/tests"
|
working_dir: "/vllm-workspace/tests"
|
||||||
num_gpus: 2
|
num_gpus: 2
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
@ -755,7 +756,7 @@ steps:
|
|||||||
- pytest -v -s plugins/lora_resolvers # unit tests for in-tree lora resolver plugins
|
- pytest -v -s plugins/lora_resolvers # unit tests for in-tree lora resolver plugins
|
||||||
|
|
||||||
- label: Multi-step Tests (4 GPUs) # 36min
|
- label: Multi-step Tests (4 GPUs) # 36min
|
||||||
mirror_hardwares: [amdexperimental, amdproduction]
|
mirror_hardwares: [amdexperimental]
|
||||||
working_dir: "/vllm-workspace/tests"
|
working_dir: "/vllm-workspace/tests"
|
||||||
num_gpus: 4
|
num_gpus: 4
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
@ -776,7 +777,7 @@ steps:
|
|||||||
- pytest -v -s multi_step/test_correctness_llm.py
|
- pytest -v -s multi_step/test_correctness_llm.py
|
||||||
|
|
||||||
- label: Pipeline Parallelism Test # 45min
|
- label: Pipeline Parallelism Test # 45min
|
||||||
mirror_hardwares: [amdexperimental, amdproduction]
|
mirror_hardwares: [amdexperimental]
|
||||||
working_dir: "/vllm-workspace/tests"
|
working_dir: "/vllm-workspace/tests"
|
||||||
num_gpus: 4
|
num_gpus: 4
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
@ -790,7 +791,7 @@ steps:
|
|||||||
- pytest -v -s distributed/test_pipeline_parallel.py
|
- pytest -v -s distributed/test_pipeline_parallel.py
|
||||||
|
|
||||||
- label: LoRA TP Test (Distributed)
|
- label: LoRA TP Test (Distributed)
|
||||||
mirror_hardwares: [amdexperimental, amdproduction]
|
mirror_hardwares: [amdexperimental]
|
||||||
num_gpus: 4
|
num_gpus: 4
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/lora
|
- vllm/lora
|
||||||
|
|||||||
@ -370,6 +370,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
# Install vllm wheel first, so that torch etc will be installed.
|
# Install vllm wheel first, so that torch etc will be installed.
|
||||||
|
# !bang
|
||||||
RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \
|
RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \
|
||||||
--mount=type=cache,target=/root/.cache/uv \
|
--mount=type=cache,target=/root/.cache/uv \
|
||||||
uv pip install --system dist/*.whl --verbose \
|
uv pip install --system dist/*.whl --verbose \
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
ARG NIGHTLY_DATE="20250724"
|
ARG NIGHTLY_DATE="20250730"
|
||||||
ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.12_tpuvm_$NIGHTLY_DATE"
|
ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.12_tpuvm_$NIGHTLY_DATE"
|
||||||
|
|
||||||
FROM $BASE_IMAGE
|
FROM $BASE_IMAGE
|
||||||
|
|||||||
@ -625,6 +625,7 @@ See [this page](generative_models.md) for more information on how to use generat
|
|||||||
| `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup> | `Qwen/Qwen2.5-Omni-7B` | | ✅︎ | ✅︎ |
|
| `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup> | `Qwen/Qwen2.5-Omni-7B` | | ✅︎ | ✅︎ |
|
||||||
| `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | ✅︎ |
|
| `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | ✅︎ |
|
||||||
| `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | | ✅︎ |
|
| `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | | ✅︎ |
|
||||||
|
| `Step3VLForConditionalGeneration` | Step3-VL | T + I<sup>+</sup> | `stepfun-ai/step3` | | ✅︎ | ✅︎ |
|
||||||
| `TarsierForConditionalGeneration` | Tarsier | T + I<sup>E+</sup> | `omni-search/Tarsier-7b`, `omni-search/Tarsier-34b` | | ✅︎ | ✅︎ |
|
| `TarsierForConditionalGeneration` | Tarsier | T + I<sup>E+</sup> | `omni-search/Tarsier-7b`, `omni-search/Tarsier-34b` | | ✅︎ | ✅︎ |
|
||||||
| `Tarsier2ForConditionalGeneration`<sup>^</sup> | Tarsier2 | T + I<sup>E+</sup> + V<sup>E+</sup> | `omni-research/Tarsier2-Recap-7b`, `omni-research/Tarsier2-7b-0115` | | ✅︎ | ✅︎ |
|
| `Tarsier2ForConditionalGeneration`<sup>^</sup> | Tarsier2 | T + I<sup>E+</sup> + V<sup>E+</sup> | `omni-research/Tarsier2-Recap-7b`, `omni-research/Tarsier2-7b-0115` | | ✅︎ | ✅︎ |
|
||||||
|
|
||||||
|
|||||||
@ -13,6 +13,38 @@ except ImportError:
|
|||||||
from argparse import ArgumentParser as FlexibleArgumentParser
|
from argparse import ArgumentParser as FlexibleArgumentParser
|
||||||
|
|
||||||
|
|
||||||
|
# Text question attached to every image prompt.
QUESTION = "What is the content of each image?"
# Publicly hosted sample images; reused cyclically when more prompts are
# requested than there are URLs.
IMAGE_URLS = [
    "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg",
    "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg",
    "https://upload.wikimedia.org/wikipedia/commons/2/26/Ultramarine_Flycatcher_%28Ficedula_superciliaris%29_Naggar%2C_Himachal_Pradesh%2C_2013_%28cropped%29.JPG",
    "https://upload.wikimedia.org/wikipedia/commons/thumb/e/e5/Anim1754_-_Flickr_-_NOAA_Photo_Library_%281%29.jpg/2560px-Anim1754_-_Flickr_-_NOAA_Photo_Library_%281%29.jpg",
    "https://upload.wikimedia.org/wikipedia/commons/d/d4/Starfish%2C_Caswell_Bay_-_geograph.org.uk_-_409413.jpg",
    "https://upload.wikimedia.org/wikipedia/commons/6/69/Grapevinesnail_01.jpg",
    "https://upload.wikimedia.org/wikipedia/commons/thumb/0/0b/Texas_invasive_Musk_Thistle_1.jpg/1920px-Texas_invasive_Musk_Thistle_1.jpg",
    "https://upload.wikimedia.org/wikipedia/commons/thumb/7/7a/Huskiesatrest.jpg/2880px-Huskiesatrest.jpg",
    "https://upload.wikimedia.org/wikipedia/commons/thumb/6/68/Orange_tabby_cat_sitting_on_fallen_leaves-Hisashi-01A.jpg/1920px-Orange_tabby_cat_sitting_on_fallen_leaves-Hisashi-01A.jpg",
    "https://upload.wikimedia.org/wikipedia/commons/3/30/George_the_amazing_guinea_pig.jpg",
    "https://upload.wikimedia.org/wikipedia/commons/thumb/1/1f/Oryctolagus_cuniculus_Rcdo.jpg/1920px-Oryctolagus_cuniculus_Rcdo.jpg",
    "https://upload.wikimedia.org/wikipedia/commons/9/98/Horse-and-pony.jpg",
]


def get_custom_mm_prompts(num_prompts):
    """Build ``num_prompts`` single-turn chat conversations, one image each.

    Every conversation is a one-element list holding a ``user`` message whose
    content is an ``image_url`` part followed by a ``text`` part carrying
    ``QUESTION``. Images are drawn from ``IMAGE_URLS`` in order and reused
    cyclically once the list is exhausted.

    Args:
        num_prompts: Number of conversations to return; must be >= 0.

    Returns:
        A list with exactly ``num_prompts`` conversations (empty when
        ``num_prompts`` is 0 or ``IMAGE_URLS`` is empty).

    Raises:
        ValueError: If ``num_prompts`` is negative. A negative count would
            otherwise be interpreted as a "drop the last N" slice bound.
    """
    from itertools import cycle, islice

    if num_prompts < 0:
        raise ValueError(f"num_prompts must be non-negative, got {num_prompts}")

    # cycle() wraps around IMAGE_URLS; islice() stops after num_prompts items.
    return [
        [
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": url}},
                    {"type": "text", "text": QUESTION},
                ],
            }
        ]
        for url in islice(cycle(IMAGE_URLS), num_prompts)
    ]
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
parser = FlexibleArgumentParser()
|
parser = FlexibleArgumentParser()
|
||||||
add_dataset_parser(parser)
|
add_dataset_parser(parser)
|
||||||
@ -35,6 +67,7 @@ def parse_args():
|
|||||||
parser.add_argument("--output-len", type=int, default=256)
|
parser.add_argument("--output-len", type=int, default=256)
|
||||||
parser.add_argument("--model-dir", type=str, default=None)
|
parser.add_argument("--model-dir", type=str, default=None)
|
||||||
parser.add_argument("--eagle-dir", type=str, default=None)
|
parser.add_argument("--eagle-dir", type=str, default=None)
|
||||||
|
parser.add_argument("--custom-mm-prompts", action="store_true")
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
@ -44,14 +77,26 @@ def main():
|
|||||||
|
|
||||||
model_dir = args.model_dir
|
model_dir = args.model_dir
|
||||||
if args.model_dir is None:
|
if args.model_dir is None:
|
||||||
|
if args.custom_mm_prompts:
|
||||||
|
raise ValueError(
|
||||||
|
"custom_mm_prompts requires mm based models"
|
||||||
|
"default llama3.1-8b-instruct is not mm based"
|
||||||
|
"please specify model_dir to give a mm based model"
|
||||||
|
)
|
||||||
model_dir = "meta-llama/Llama-3.1-8B-Instruct"
|
model_dir = "meta-llama/Llama-3.1-8B-Instruct"
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
||||||
|
args.custom_skip_chat_template = True
|
||||||
|
|
||||||
prompts = get_samples(args, tokenizer)
|
if not args.custom_mm_prompts:
|
||||||
# add_special_tokens is False to avoid adding bos twice when using chat templates
|
prompts = get_samples(args, tokenizer)
|
||||||
prompt_ids = [
|
# add_special_tokens is False to avoid adding bos twice
|
||||||
tokenizer.encode(prompt.prompt, add_special_tokens=False) for prompt in prompts
|
# when using chat templates
|
||||||
]
|
prompt_ids = [
|
||||||
|
tokenizer.encode(prompt.prompt, add_special_tokens=False)
|
||||||
|
for prompt in prompts
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
prompts = get_custom_mm_prompts(args.num_prompts)
|
||||||
|
|
||||||
if args.method == "eagle" or args.method == "eagle3":
|
if args.method == "eagle" or args.method == "eagle3":
|
||||||
eagle_dir = args.eagle_dir
|
eagle_dir = args.eagle_dir
|
||||||
@ -85,10 +130,17 @@ def main():
|
|||||||
speculative_config=speculative_config,
|
speculative_config=speculative_config,
|
||||||
disable_log_stats=False,
|
disable_log_stats=False,
|
||||||
max_model_len=16384,
|
max_model_len=16384,
|
||||||
|
limit_mm_per_prompt={"image": 5},
|
||||||
|
disable_chunked_mm_input=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)
|
sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)
|
||||||
outputs = llm.generate(prompt_token_ids=prompt_ids, sampling_params=sampling_params)
|
if not args.custom_mm_prompts:
|
||||||
|
outputs = llm.generate(
|
||||||
|
prompt_token_ids=prompt_ids, sampling_params=sampling_params
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
outputs = llm.chat(prompts, sampling_params=sampling_params)
|
||||||
|
|
||||||
# print the generated text
|
# print the generated text
|
||||||
if args.print_output:
|
if args.print_output:
|
||||||
|
|||||||
@ -22,9 +22,7 @@ aiohttp==3.10.11
|
|||||||
aiohttp-cors==0.8.1
|
aiohttp-cors==0.8.1
|
||||||
# via ray
|
# via ray
|
||||||
aiosignal==1.3.1
|
aiosignal==1.3.1
|
||||||
# via
|
# via aiohttp
|
||||||
# aiohttp
|
|
||||||
# ray
|
|
||||||
albucore==0.0.16
|
albucore==0.0.16
|
||||||
# via terratorch
|
# via terratorch
|
||||||
albumentations==1.4.6
|
albumentations==1.4.6
|
||||||
@ -139,7 +137,7 @@ contourpy==1.3.0
|
|||||||
# via matplotlib
|
# via matplotlib
|
||||||
cramjam==2.9.0
|
cramjam==2.9.0
|
||||||
# via fastparquet
|
# via fastparquet
|
||||||
cupy-cuda12x==13.3.0
|
cupy-cuda12x==13.5.1
|
||||||
# via ray
|
# via ray
|
||||||
cycler==0.12.1
|
cycler==0.12.1
|
||||||
# via matplotlib
|
# via matplotlib
|
||||||
@ -226,7 +224,6 @@ frozenlist==1.5.0
|
|||||||
# via
|
# via
|
||||||
# aiohttp
|
# aiohttp
|
||||||
# aiosignal
|
# aiosignal
|
||||||
# ray
|
|
||||||
fsspec==2024.9.0
|
fsspec==2024.9.0
|
||||||
# via
|
# via
|
||||||
# datasets
|
# datasets
|
||||||
@ -603,10 +600,18 @@ opencv-python-headless==4.11.0.86
|
|||||||
opentelemetry-api==1.35.0
|
opentelemetry-api==1.35.0
|
||||||
# via
|
# via
|
||||||
# mlflow-skinny
|
# mlflow-skinny
|
||||||
|
# opentelemetry-exporter-prometheus
|
||||||
# opentelemetry-sdk
|
# opentelemetry-sdk
|
||||||
# opentelemetry-semantic-conventions
|
# opentelemetry-semantic-conventions
|
||||||
|
opentelemetry-exporter-prometheus==0.56b0
|
||||||
|
# via ray
|
||||||
|
opentelemetry-proto==1.36.0
|
||||||
|
# via ray
|
||||||
opentelemetry-sdk==1.35.0
|
opentelemetry-sdk==1.35.0
|
||||||
# via mlflow-skinny
|
# via
|
||||||
|
# mlflow-skinny
|
||||||
|
# opentelemetry-exporter-prometheus
|
||||||
|
# ray
|
||||||
opentelemetry-semantic-conventions==0.56b0
|
opentelemetry-semantic-conventions==0.56b0
|
||||||
# via opentelemetry-sdk
|
# via opentelemetry-sdk
|
||||||
packaging==24.2
|
packaging==24.2
|
||||||
@ -697,7 +702,9 @@ pqdm==0.2.0
|
|||||||
pretrainedmodels==0.7.4
|
pretrainedmodels==0.7.4
|
||||||
# via segmentation-models-pytorch
|
# via segmentation-models-pytorch
|
||||||
prometheus-client==0.22.0
|
prometheus-client==0.22.0
|
||||||
# via ray
|
# via
|
||||||
|
# opentelemetry-exporter-prometheus
|
||||||
|
# ray
|
||||||
propcache==0.2.0
|
propcache==0.2.0
|
||||||
# via yarl
|
# via yarl
|
||||||
proto-plus==1.26.1
|
proto-plus==1.26.1
|
||||||
@ -707,6 +714,7 @@ protobuf==5.28.3
|
|||||||
# google-api-core
|
# google-api-core
|
||||||
# googleapis-common-protos
|
# googleapis-common-protos
|
||||||
# mlflow-skinny
|
# mlflow-skinny
|
||||||
|
# opentelemetry-proto
|
||||||
# proto-plus
|
# proto-plus
|
||||||
# ray
|
# ray
|
||||||
# tensorboardx
|
# tensorboardx
|
||||||
@ -854,7 +862,7 @@ rasterio==1.4.3
|
|||||||
# rioxarray
|
# rioxarray
|
||||||
# terratorch
|
# terratorch
|
||||||
# torchgeo
|
# torchgeo
|
||||||
ray==2.43.0
|
ray==2.48.0
|
||||||
# via -r requirements/test.in
|
# via -r requirements/test.in
|
||||||
redis==5.2.0
|
redis==5.2.0
|
||||||
# via tensorizer
|
# via tensorizer
|
||||||
|
|||||||
@ -19,8 +19,8 @@ nixl==0.3.0
|
|||||||
--find-links https://storage.googleapis.com/libtpu-releases/index.html
|
--find-links https://storage.googleapis.com/libtpu-releases/index.html
|
||||||
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
|
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
|
||||||
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
|
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
|
||||||
torch==2.9.0.dev20250724
|
torch==2.9.0.dev20250730
|
||||||
torchvision==0.24.0.dev20250724
|
torchvision==0.24.0.dev20250730
|
||||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250724-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
|
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250730-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
|
||||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250724-cp312-cp312-linux_x86_64.whl ; python_version == "3.12"
|
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250730-cp312-cp312-linux_x86_64.whl ; python_version == "3.12"
|
||||||
|
|
||||||
|
|||||||
203
setup.py
203
setup.py
@ -282,10 +282,69 @@ class cmake_build_ext(build_ext):
|
|||||||
self.copy_file(file, dst_file)
|
self.copy_file(file, dst_file)
|
||||||
|
|
||||||
|
|
||||||
class repackage_wheel(build_ext):
|
class precompiled_wheel_utils:
|
||||||
"""Extracts libraries and other files from an existing wheel."""
|
"""Extracts libraries and other files from an existing wheel."""
|
||||||
|
|
||||||
def get_base_commit_in_main_branch(self) -> str:
|
@staticmethod
|
||||||
|
def extract_precompiled_and_patch_package(wheel_url_or_path: str) -> dict:
|
||||||
|
import tempfile
|
||||||
|
import zipfile
|
||||||
|
|
||||||
|
temp_dir = None
|
||||||
|
try:
|
||||||
|
if not os.path.isfile(wheel_url_or_path):
|
||||||
|
wheel_filename = wheel_url_or_path.split("/")[-1]
|
||||||
|
temp_dir = tempfile.mkdtemp(prefix="vllm-wheels")
|
||||||
|
wheel_path = os.path.join(temp_dir, wheel_filename)
|
||||||
|
print(f"Downloading wheel from {wheel_url_or_path} "
|
||||||
|
f"to {wheel_path}")
|
||||||
|
from urllib.request import urlretrieve
|
||||||
|
urlretrieve(wheel_url_or_path, filename=wheel_path)
|
||||||
|
else:
|
||||||
|
wheel_path = wheel_url_or_path
|
||||||
|
print(f"Using existing wheel at {wheel_path}")
|
||||||
|
|
||||||
|
package_data_patch = {}
|
||||||
|
|
||||||
|
with zipfile.ZipFile(wheel_path) as wheel:
|
||||||
|
files_to_copy = [
|
||||||
|
"vllm/_C.abi3.so",
|
||||||
|
"vllm/_moe_C.abi3.so",
|
||||||
|
"vllm/_flashmla_C.abi3.so",
|
||||||
|
"vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so",
|
||||||
|
"vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so",
|
||||||
|
"vllm/cumem_allocator.abi3.so",
|
||||||
|
]
|
||||||
|
|
||||||
|
compiled_regex = re.compile(
|
||||||
|
r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py")
|
||||||
|
file_members = list(
|
||||||
|
filter(lambda x: x.filename in files_to_copy,
|
||||||
|
wheel.filelist))
|
||||||
|
file_members += list(
|
||||||
|
filter(lambda x: compiled_regex.match(x.filename),
|
||||||
|
wheel.filelist))
|
||||||
|
|
||||||
|
for file in file_members:
|
||||||
|
print(f"[extract] {file.filename}")
|
||||||
|
target_path = os.path.join(".", file.filename)
|
||||||
|
os.makedirs(os.path.dirname(target_path), exist_ok=True)
|
||||||
|
with wheel.open(file.filename) as src, open(
|
||||||
|
target_path, "wb") as dst:
|
||||||
|
shutil.copyfileobj(src, dst)
|
||||||
|
|
||||||
|
pkg = os.path.dirname(file.filename).replace("/", ".")
|
||||||
|
package_data_patch.setdefault(pkg, []).append(
|
||||||
|
os.path.basename(file.filename))
|
||||||
|
|
||||||
|
return package_data_patch
|
||||||
|
finally:
|
||||||
|
if temp_dir is not None:
|
||||||
|
print(f"Removing temporary directory {temp_dir}")
|
||||||
|
shutil.rmtree(temp_dir)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_base_commit_in_main_branch() -> str:
|
||||||
# Force to use the nightly wheel. This is mainly used for CI testing.
|
# Force to use the nightly wheel. This is mainly used for CI testing.
|
||||||
if envs.VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL:
|
if envs.VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL:
|
||||||
return "nightly"
|
return "nightly"
|
||||||
@ -334,115 +393,6 @@ class repackage_wheel(build_ext):
|
|||||||
"wheel may not be compatible with your dev branch: %s", err)
|
"wheel may not be compatible with your dev branch: %s", err)
|
||||||
return "nightly"
|
return "nightly"
|
||||||
|
|
||||||
def run(self) -> None:
|
|
||||||
assert _is_cuda(
|
|
||||||
), "VLLM_USE_PRECOMPILED is only supported for CUDA builds"
|
|
||||||
|
|
||||||
wheel_location = os.getenv("VLLM_PRECOMPILED_WHEEL_LOCATION", None)
|
|
||||||
if wheel_location is None:
|
|
||||||
base_commit = self.get_base_commit_in_main_branch()
|
|
||||||
wheel_location = f"https://wheels.vllm.ai/{base_commit}/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl"
|
|
||||||
# Fallback to nightly wheel if latest commit wheel is unavailable,
|
|
||||||
# in this rare case, the nightly release CI hasn't finished on main.
|
|
||||||
if not is_url_available(wheel_location):
|
|
||||||
wheel_location = "https://wheels.vllm.ai/nightly/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl"
|
|
||||||
|
|
||||||
import zipfile
|
|
||||||
|
|
||||||
if os.path.isfile(wheel_location):
|
|
||||||
wheel_path = wheel_location
|
|
||||||
print(f"Using existing wheel={wheel_path}")
|
|
||||||
else:
|
|
||||||
# Download the wheel from a given URL, assume
|
|
||||||
# the filename is the last part of the URL
|
|
||||||
wheel_filename = wheel_location.split("/")[-1]
|
|
||||||
|
|
||||||
import tempfile
|
|
||||||
|
|
||||||
# create a temporary directory to store the wheel
|
|
||||||
temp_dir = tempfile.mkdtemp(prefix="vllm-wheels")
|
|
||||||
wheel_path = os.path.join(temp_dir, wheel_filename)
|
|
||||||
print(f"Downloading wheel from {wheel_location} to {wheel_path}")
|
|
||||||
from urllib.request import urlretrieve
|
|
||||||
try:
|
|
||||||
urlretrieve(wheel_location, filename=wheel_path)
|
|
||||||
except Exception as e:
|
|
||||||
from setuptools.errors import SetupError
|
|
||||||
raise SetupError(
|
|
||||||
f"Failed to get vLLM wheel from {wheel_location}") from e
|
|
||||||
|
|
||||||
# Set the dist_dir for Docker build context
|
|
||||||
dist_dir = ("/workspace/dist"
|
|
||||||
if envs.VLLM_DOCKER_BUILD_CONTEXT else "dist")
|
|
||||||
os.makedirs(dist_dir, exist_ok=True)
|
|
||||||
|
|
||||||
# Extract only necessary compiled .so files from precompiled wheel
|
|
||||||
with zipfile.ZipFile(wheel_path) as wheel:
|
|
||||||
# Get version from METADATA (optional, mostly useful for logging)
|
|
||||||
metadata_file = next((n for n in wheel.namelist()
|
|
||||||
if n.endswith(".dist-info/METADATA")), None)
|
|
||||||
if not metadata_file:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Could not find METADATA in precompiled wheel.")
|
|
||||||
metadata = wheel.read(metadata_file).decode()
|
|
||||||
version_line = next((line for line in metadata.splitlines()
|
|
||||||
if line.startswith("Version: ")), None)
|
|
||||||
if not version_line:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Could not determine version from METADATA.")
|
|
||||||
version = version_line.split(": ")[1].strip()
|
|
||||||
|
|
||||||
print(f"Extracting precompiled kernels from vLLM wheel version: "
|
|
||||||
f"{version}")
|
|
||||||
|
|
||||||
# List of compiled shared objects to extract
|
|
||||||
files_to_copy = [
|
|
||||||
"vllm/_C.abi3.so",
|
|
||||||
"vllm/_moe_C.abi3.so",
|
|
||||||
"vllm/_flashmla_C.abi3.so",
|
|
||||||
"vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so",
|
|
||||||
"vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so",
|
|
||||||
"vllm/cumem_allocator.abi3.so",
|
|
||||||
]
|
|
||||||
|
|
||||||
file_members = list(
|
|
||||||
filter(lambda x: x.filename in files_to_copy, wheel.filelist))
|
|
||||||
compiled_regex = re.compile(
|
|
||||||
r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py")
|
|
||||||
file_members += list(
|
|
||||||
filter(lambda x: compiled_regex.match(x.filename),
|
|
||||||
wheel.filelist))
|
|
||||||
|
|
||||||
for file in file_members:
|
|
||||||
print(f"Extracting and including {file.filename} "
|
|
||||||
"from existing wheel")
|
|
||||||
package_name = os.path.dirname(file.filename).replace("/", ".")
|
|
||||||
file_name = os.path.basename(file.filename)
|
|
||||||
|
|
||||||
if package_name not in package_data:
|
|
||||||
package_data[package_name] = []
|
|
||||||
|
|
||||||
output_base = (dist_dir
|
|
||||||
if envs.VLLM_DOCKER_BUILD_CONTEXT else ".")
|
|
||||||
target_path = os.path.join(output_base, file.filename)
|
|
||||||
os.makedirs(os.path.dirname(target_path), exist_ok=True)
|
|
||||||
with wheel.open(file.filename) as src, open(target_path,
|
|
||||||
"wb") as dst:
|
|
||||||
shutil.copyfileobj(src, dst)
|
|
||||||
|
|
||||||
package_data[package_name].append(file_name)
|
|
||||||
|
|
||||||
# Copy wheel into dist dir for Docker to consume (e.g., via --mount)
|
|
||||||
if envs.VLLM_DOCKER_BUILD_CONTEXT:
|
|
||||||
arch_tag = "cp38-abi3-manylinux1_x86_64"
|
|
||||||
corrected_wheel_name = f"vllm-{version}-{arch_tag}.whl"
|
|
||||||
final_wheel_path = os.path.join(dist_dir, corrected_wheel_name)
|
|
||||||
|
|
||||||
print(
|
|
||||||
"Docker build context detected, copying precompiled wheel to "
|
|
||||||
f"{final_wheel_path}")
|
|
||||||
shutil.copy2(wheel_path, final_wheel_path)
|
|
||||||
|
|
||||||
|
|
||||||
def _no_device() -> bool:
|
def _no_device() -> bool:
|
||||||
return VLLM_TARGET_DEVICE == "empty"
|
return VLLM_TARGET_DEVICE == "empty"
|
||||||
@ -676,16 +626,37 @@ package_data = {
|
|||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# If using precompiled, extract and patch package_data (in advance of setup)
|
||||||
|
if envs.VLLM_USE_PRECOMPILED:
|
||||||
|
assert _is_cuda(), "VLLM_USE_PRECOMPILED is only supported for CUDA builds"
|
||||||
|
wheel_location = os.getenv("VLLM_PRECOMPILED_WHEEL_LOCATION", None)
|
||||||
|
if wheel_location is not None:
|
||||||
|
wheel_url = wheel_location
|
||||||
|
else:
|
||||||
|
base_commit = precompiled_wheel_utils.get_base_commit_in_main_branch()
|
||||||
|
wheel_url = f"https://wheels.vllm.ai/{base_commit}/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl"
|
||||||
|
from urllib.request import urlopen
|
||||||
|
try:
|
||||||
|
with urlopen(wheel_url) as resp:
|
||||||
|
if resp.status != 200:
|
||||||
|
wheel_url = "https://wheels.vllm.ai/nightly/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl"
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[warn] Falling back to nightly wheel: {e}")
|
||||||
|
wheel_url = "https://wheels.vllm.ai/nightly/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl"
|
||||||
|
|
||||||
|
patch = precompiled_wheel_utils.extract_precompiled_and_patch_package(
|
||||||
|
wheel_url)
|
||||||
|
for pkg, files in patch.items():
|
||||||
|
package_data.setdefault(pkg, []).extend(files)
|
||||||
|
|
||||||
if _no_device():
|
if _no_device():
|
||||||
ext_modules = []
|
ext_modules = []
|
||||||
|
|
||||||
if not ext_modules:
|
if not ext_modules or envs.VLLM_USE_PRECOMPILED:
|
||||||
|
# Disable build_ext when using precompiled wheel
|
||||||
cmdclass = {}
|
cmdclass = {}
|
||||||
else:
|
else:
|
||||||
cmdclass = {
|
cmdclass = {"build_ext": cmake_build_ext}
|
||||||
"build_ext":
|
|
||||||
repackage_wheel if envs.VLLM_USE_PRECOMPILED else cmake_build_ext
|
|
||||||
}
|
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
# static metadata should rather go in pyproject.toml
|
# static metadata should rather go in pyproject.toml
|
||||||
|
|||||||
@ -7,22 +7,26 @@ import torch
|
|||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.compilation.collective_fusion import AllReduceFusionPass
|
from vllm.compilation.collective_fusion import AllReduceFusionPass
|
||||||
|
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
|
||||||
|
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||||
from vllm.config import (CompilationConfig, CompilationLevel, DeviceConfig,
|
from vllm.config import (CompilationConfig, CompilationLevel, DeviceConfig,
|
||||||
ModelConfig, PassConfig, VllmConfig)
|
ModelConfig, PassConfig, VllmConfig)
|
||||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||||
from vllm.distributed.parallel_state import (init_distributed_environment,
|
from vllm.distributed.parallel_state import (init_distributed_environment,
|
||||||
initialize_model_parallel)
|
initialize_model_parallel)
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
|
GroupShape, QuantFP8)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import update_environment_variables
|
from vllm.utils import update_environment_variables
|
||||||
|
|
||||||
from ..utils import multi_gpu_test
|
from ..utils import has_module_attribute, multi_gpu_test
|
||||||
from .backend import TestBackend
|
from .backend import TestBackend
|
||||||
|
|
||||||
|
|
||||||
class TestAllReduceRMSNormModel(torch.nn.Module):
|
class TestAllReduceRMSNormModel(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, hidden_size=16, eps=1e-6):
|
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.eps = eps
|
self.eps = eps
|
||||||
@ -43,7 +47,7 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
|
|||||||
|
|
||||||
class TestAllReduceFusedAddRMSNormModel(torch.nn.Module):
|
class TestAllReduceFusedAddRMSNormModel(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, hidden_size=16, eps=1e-6):
|
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.eps = eps
|
self.eps = eps
|
||||||
@ -62,24 +66,101 @@ class TestAllReduceFusedAddRMSNormModel(torch.nn.Module):
|
|||||||
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
|
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
|
||||||
|
|
||||||
|
|
||||||
|
class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.eps = eps
|
||||||
|
self.norm = RMSNorm(hidden_size, eps)
|
||||||
|
self.quant_fp8 = QuantFP8(static=True,
|
||||||
|
group_shape=GroupShape.PER_TENSOR)
|
||||||
|
self.scale = torch.rand(1, dtype=torch.float32)
|
||||||
|
self.output = torch.empty((token_num, hidden_size),
|
||||||
|
dtype=torch.float32)
|
||||||
|
|
||||||
|
def forward(self, hidden_states, residual):
|
||||||
|
view = hidden_states.reshape(-1, self.hidden_size)
|
||||||
|
all_reduce = tensor_model_parallel_all_reduce(view)
|
||||||
|
norm_output, residual_output = self.norm(all_reduce, residual)
|
||||||
|
torch.ops._C.static_scaled_fp8_quant(self.output,
|
||||||
|
norm_output.contiguous(),
|
||||||
|
self.scale)
|
||||||
|
return self.output, residual_output
|
||||||
|
|
||||||
|
def ops_in_model_after(self):
|
||||||
|
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
|
||||||
|
|
||||||
|
def ops_in_model_before(self):
|
||||||
|
return [
|
||||||
|
torch.ops.vllm.all_reduce.default,
|
||||||
|
torch.ops._C.static_scaled_fp8_quant.default
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.eps = eps
|
||||||
|
self.norm = RMSNorm(hidden_size, eps)
|
||||||
|
self.scale = torch.rand(1, dtype=torch.float32)
|
||||||
|
self.output = torch.empty((token_num, hidden_size),
|
||||||
|
dtype=torch.float32)
|
||||||
|
|
||||||
|
round_up = lambda x, y: (x + y - 1) // y * y
|
||||||
|
rounded_m = round_up(token_num, 128)
|
||||||
|
scale_n = hidden_size // 16
|
||||||
|
rounded_n = round_up(scale_n, 4)
|
||||||
|
self.output_scale = torch.empty((rounded_m, rounded_n // 4),
|
||||||
|
dtype=torch.int32)
|
||||||
|
|
||||||
|
def forward(self, hidden_states, residual):
|
||||||
|
view = hidden_states.reshape(-1, self.hidden_size)
|
||||||
|
all_reduce = tensor_model_parallel_all_reduce(view)
|
||||||
|
norm_output, residual_output = self.norm(all_reduce, residual)
|
||||||
|
norm_output = norm_output.reshape(-1, norm_output.shape[-1])
|
||||||
|
torch.ops._C.scaled_fp4_quant(self.output, norm_output,
|
||||||
|
self.output_scale, self.scale)
|
||||||
|
return self.output, residual_output, self.output_scale
|
||||||
|
|
||||||
|
def ops_in_model_after(self):
|
||||||
|
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
|
||||||
|
|
||||||
|
def ops_in_model_before(self):
|
||||||
|
return [
|
||||||
|
torch.ops.vllm.all_reduce.default,
|
||||||
|
torch.ops._C.scaled_fp4_quant.default
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@multi_gpu_test(num_gpus=2)
|
@multi_gpu_test(num_gpus=2)
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize("test_model", [
|
||||||
"test_model",
|
TestAllReduceRMSNormModel,
|
||||||
[TestAllReduceRMSNormModel, TestAllReduceFusedAddRMSNormModel])
|
TestAllReduceFusedAddRMSNormModel,
|
||||||
|
TestAllReduceFusedAddRMSNormStaticQuantFP8Model,
|
||||||
|
TestAllReduceFusedAddRMSNormStaticQuantFP4Model,
|
||||||
|
])
|
||||||
@pytest.mark.parametrize("batch_size", [8])
|
@pytest.mark.parametrize("batch_size", [8])
|
||||||
@pytest.mark.parametrize("seq_len", [8])
|
@pytest.mark.parametrize("seq_len", [8])
|
||||||
@pytest.mark.parametrize("hidden_size", [4096])
|
@pytest.mark.parametrize("hidden_size", [16])
|
||||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
|
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
|
||||||
reason="Only test on CUDA")
|
reason="Only test on CUDA")
|
||||||
@pytest.mark.skipif(not find_spec("flashinfer"),
|
@pytest.mark.skipif(
|
||||||
reason="flashinfer is not installed")
|
not find_spec("flashinfer")
|
||||||
@pytest.mark.skipif(not current_platform.is_device_capability(100),
|
or not has_module_attribute("flashinfer.comm", "trtllm_allreduce_fusion"),
|
||||||
reason="Only test on SM100")
|
reason="flashinfer is not found or flashinfer "
|
||||||
|
"is not compiled with trtllm_allreduce_fusion")
|
||||||
def test_all_reduce_fusion_pass_replace(test_model: torch.nn.Module,
|
def test_all_reduce_fusion_pass_replace(test_model: torch.nn.Module,
|
||||||
batch_size: int, seq_len: int,
|
batch_size: int, seq_len: int,
|
||||||
hidden_size: int, dtype: torch.dtype):
|
hidden_size: int, dtype: torch.dtype):
|
||||||
num_processes = 2
|
num_processes = 2
|
||||||
|
if (test_model == TestAllReduceFusedAddRMSNormStaticQuantFP4Model
|
||||||
|
and not current_platform.has_device_capability(100)):
|
||||||
|
pytest.skip("Skip as nvfp4 is only supported on "
|
||||||
|
"devices with compute capability 10.0 (Blackwell)")
|
||||||
|
|
||||||
def run_torch_spawn(fn, nprocs):
|
def run_torch_spawn(fn, nprocs):
|
||||||
torch.multiprocessing.spawn(fn,
|
torch.multiprocessing.spawn(fn,
|
||||||
@ -113,12 +194,11 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
|
|||||||
init_distributed_environment()
|
init_distributed_environment()
|
||||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||||
|
|
||||||
vllm_config = VllmConfig(
|
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
||||||
compilation_config=CompilationConfig(level=CompilationLevel.PIECEWISE,
|
level=CompilationLevel.PIECEWISE,
|
||||||
custom_ops=["+rms_norm"],
|
custom_ops=["+rms_norm", "+quant_fp8"]))
|
||||||
compile_sizes=[2, 4, 8]))
|
|
||||||
vllm_config.compilation_config.pass_config = PassConfig(
|
vllm_config.compilation_config.pass_config = PassConfig(
|
||||||
enable_fi_allreduce_fusion=True)
|
enable_fi_allreduce_fusion=True, enable_noop=True)
|
||||||
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
|
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
|
||||||
|
|
||||||
# this is a fake model name to construct the model config
|
# this is a fake model name to construct the model config
|
||||||
@ -130,14 +210,16 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
|
|||||||
seed=42)
|
seed=42)
|
||||||
|
|
||||||
all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
|
all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
|
||||||
backend = TestBackend(all_reduce_fusion_pass)
|
noop_pass = NoOpEliminationPass(vllm_config)
|
||||||
|
func_pass = FixFunctionalizationPass(vllm_config)
|
||||||
|
|
||||||
model = test_model_cls(hidden_size)
|
backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass)
|
||||||
|
|
||||||
hidden_states = torch.randn((batch_size * seq_len, hidden_size),
|
token_num = batch_size * seq_len
|
||||||
requires_grad=False)
|
model = test_model_cls(hidden_size, token_num)
|
||||||
residual = torch.randn((batch_size * seq_len, hidden_size),
|
|
||||||
requires_grad=False)
|
hidden_states = torch.randn((token_num, hidden_size), requires_grad=False)
|
||||||
|
residual = torch.randn((token_num, hidden_size), requires_grad=False)
|
||||||
|
|
||||||
compiled_model = torch.compile(model, backend=backend)
|
compiled_model = torch.compile(model, backend=backend)
|
||||||
compiled_model(hidden_states, residual)
|
compiled_model(hidden_states, residual)
|
||||||
|
|||||||
@ -279,6 +279,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
|||||||
"StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"), # noqa: E501
|
"StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"), # noqa: E501
|
||||||
"StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
|
"StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
|
||||||
"Starcoder2ForCausalLM": _HfExamplesInfo("bigcode/starcoder2-3b"),
|
"Starcoder2ForCausalLM": _HfExamplesInfo("bigcode/starcoder2-3b"),
|
||||||
|
"Step3TextForCausalLM": _HfExamplesInfo("stepfun-ai/step3",
|
||||||
|
trust_remote_code=True,
|
||||||
|
is_available_online=False),
|
||||||
"SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct",
|
"SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct",
|
||||||
trust_remote_code=True),
|
trust_remote_code=True),
|
||||||
"TeleChat2ForCausalLM": _HfExamplesInfo("Tele-AI/TeleChat2-3B",
|
"TeleChat2ForCausalLM": _HfExamplesInfo("Tele-AI/TeleChat2-3B",
|
||||||
@ -457,6 +460,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
|||||||
"SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B",
|
"SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B",
|
||||||
trust_remote_code=True),
|
trust_remote_code=True),
|
||||||
"SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct"), # noqa: E501
|
"SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct"), # noqa: E501
|
||||||
|
"Step3VLForConditionalGeneration": _HfExamplesInfo("stepfun-ai/step3",
|
||||||
|
trust_remote_code=True,
|
||||||
|
is_available_online=False),
|
||||||
"UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501
|
"UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501
|
||||||
trust_remote_code=True),
|
trust_remote_code=True),
|
||||||
"TarsierForConditionalGeneration": _HfExamplesInfo("omni-research/Tarsier-7b", # noqa: E501
|
"TarsierForConditionalGeneration": _HfExamplesInfo("omni-research/Tarsier-7b", # noqa: E501
|
||||||
|
|||||||
@ -4,6 +4,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import copy
|
import copy
|
||||||
import functools
|
import functools
|
||||||
|
import importlib
|
||||||
import os
|
import os
|
||||||
import signal
|
import signal
|
||||||
import subprocess
|
import subprocess
|
||||||
@ -974,3 +975,14 @@ def get_client_text_logprob_generations(
|
|||||||
return [(text_generations, text,
|
return [(text_generations, text,
|
||||||
(None if x.logprobs is None else x.logprobs.top_logprobs))
|
(None if x.logprobs is None else x.logprobs.top_logprobs))
|
||||||
for completion in completions for x in completion.choices]
|
for completion in completions for x in completion.choices]
|
||||||
|
|
||||||
|
|
||||||
|
def has_module_attribute(module_name, attribute_name):
|
||||||
|
"""
|
||||||
|
Helper function to check if a module has a specific attribute.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
module = importlib.import_module(module_name)
|
||||||
|
return hasattr(module, attribute_name)
|
||||||
|
except ImportError:
|
||||||
|
return False
|
||||||
|
|||||||
196
tests/v1/attention/test_chunked_local_attention.py
Normal file
196
tests/v1/attention/test_chunked_local_attention.py
Normal file
@ -0,0 +1,196 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
|
||||||
|
from vllm.v1.attention.backends.utils import (
|
||||||
|
make_local_attention_virtual_batches)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LocalAttentionTestData:
|
||||||
|
# Input parameters
|
||||||
|
batch_spec: BatchSpec
|
||||||
|
attn_chunk_size: int
|
||||||
|
block_size: int
|
||||||
|
# Expected return values
|
||||||
|
expected_q_seqlens: list[int]
|
||||||
|
expected_k_seqlens: list[int]
|
||||||
|
expected_local_block_table: list[list[int]]
|
||||||
|
|
||||||
|
|
||||||
|
test_data_list = [
|
||||||
|
# Same as example in docstring of make_local_attention_virtual_batches
|
||||||
|
# except block table has 9 columns instead of 10
|
||||||
|
LocalAttentionTestData(
|
||||||
|
batch_spec=BatchSpec(
|
||||||
|
query_lens=[4, 10, 5],
|
||||||
|
seq_lens=[6, 17, 9],
|
||||||
|
),
|
||||||
|
attn_chunk_size=4,
|
||||||
|
block_size=2,
|
||||||
|
expected_q_seqlens=[2, 2, 1, 4, 4, 1, 4, 1],
|
||||||
|
expected_k_seqlens=[4, 2, 4, 4, 4, 1, 4, 1],
|
||||||
|
# 2 pages per local batch
|
||||||
|
# (chunk size 4 // block size 2)
|
||||||
|
expected_local_block_table=[
|
||||||
|
[0, 1], # local-batch 0, (batch 0, starting from k[0])
|
||||||
|
[2, 3], # local-batch 1, (batch 0, starting from k[4])
|
||||||
|
[11, 12], # local-batch 2, (batch 1, starting from k[4])
|
||||||
|
[13, 14], # local-batch 3, (batch 1, starting from k[8])
|
||||||
|
[15, 16], # local-batch 4, (batch 1, starting from k[12])
|
||||||
|
[17, 17], # local-batch 5, (batch 1, starting from k[16])
|
||||||
|
[20, 21], # local-batch 6, (batch 2, starting from k[4])
|
||||||
|
[22, 23], # local-batch 7, (batch 2, starting from k[8])
|
||||||
|
]),
|
||||||
|
# Case where block indices are not clipped to block table ncols-1
|
||||||
|
# because tokens_in_last_block == attn_chunk_size
|
||||||
|
LocalAttentionTestData(batch_spec=BatchSpec(
|
||||||
|
query_lens=[8],
|
||||||
|
seq_lens=[12],
|
||||||
|
),
|
||||||
|
attn_chunk_size=4,
|
||||||
|
block_size=2,
|
||||||
|
expected_q_seqlens=[4, 4],
|
||||||
|
expected_k_seqlens=[4, 4],
|
||||||
|
expected_local_block_table=[
|
||||||
|
[2, 3],
|
||||||
|
[4, 5],
|
||||||
|
]),
|
||||||
|
# Case where all kv_seq positions are involved in attn
|
||||||
|
LocalAttentionTestData(
|
||||||
|
batch_spec=BatchSpec(
|
||||||
|
query_lens=[7],
|
||||||
|
# 10 - 7 = 3 previously computed tokens
|
||||||
|
seq_lens=[10],
|
||||||
|
),
|
||||||
|
attn_chunk_size=4,
|
||||||
|
block_size=2,
|
||||||
|
expected_q_seqlens=[1, 4, 2],
|
||||||
|
expected_k_seqlens=[4, 4, 2],
|
||||||
|
expected_local_block_table=[
|
||||||
|
[0, 1],
|
||||||
|
[2, 3],
|
||||||
|
[4, 4],
|
||||||
|
]),
|
||||||
|
# Case where attn_chunk_size > kv_seq_len
|
||||||
|
# so no extra mini virtual batches are created
|
||||||
|
LocalAttentionTestData(
|
||||||
|
batch_spec=BatchSpec(
|
||||||
|
query_lens=[4],
|
||||||
|
seq_lens=[6],
|
||||||
|
),
|
||||||
|
# Larger than kv_seq_len
|
||||||
|
attn_chunk_size=10,
|
||||||
|
block_size=2,
|
||||||
|
# No change to q_seqlens and k_seqlens
|
||||||
|
expected_q_seqlens=[4],
|
||||||
|
expected_k_seqlens=[6],
|
||||||
|
# In this case, we only need a block-table like:
|
||||||
|
# block_table = [ [0, 1, 2] ] # 1 batch, 3 pages
|
||||||
|
# But we need to pad it to 5 pages per local batch
|
||||||
|
# because currently the pages_per_local_batch
|
||||||
|
# is calculated as (attn_chunk_size // block_size)
|
||||||
|
expected_local_block_table=[
|
||||||
|
[0, 1, 2, 2, 2],
|
||||||
|
]),
|
||||||
|
# Block size equal to chunk size
|
||||||
|
# Expect single page per batch in local batch table
|
||||||
|
LocalAttentionTestData(
|
||||||
|
batch_spec=BatchSpec(
|
||||||
|
query_lens=[6, 6],
|
||||||
|
seq_lens=[8, 8],
|
||||||
|
),
|
||||||
|
attn_chunk_size=4,
|
||||||
|
block_size=4,
|
||||||
|
expected_q_seqlens=[2, 4, 2, 4],
|
||||||
|
expected_k_seqlens=[4, 4, 4, 4],
|
||||||
|
# Initial block table = [
|
||||||
|
# [0, 1], < batch 0
|
||||||
|
# [2, 3], < batch 1
|
||||||
|
# ]
|
||||||
|
expected_local_block_table=[
|
||||||
|
[0], # local-batch 0, (batch 0, starting from k[0])
|
||||||
|
[1], # local-batch 1, (batch 0, starting from k[4])
|
||||||
|
[2], # local-batch 2, (batch 1, starting from k[0])
|
||||||
|
[3], # local-batch 3, (batch 1, starting from k[4])
|
||||||
|
]),
|
||||||
|
# Case where query falls in the second attention chunk
|
||||||
|
# k_toks > 0 1 2 3 4
|
||||||
|
# q_toks v _____________
|
||||||
|
# 0 | 1
|
||||||
|
# 1 | 1 1
|
||||||
|
# 2 | 1 1 1
|
||||||
|
# 3 | 1 1 1 1
|
||||||
|
# 4 | 1
|
||||||
|
# where tokens 0,1,2,3 have been pre-computed
|
||||||
|
LocalAttentionTestData(batch_spec=BatchSpec(
|
||||||
|
query_lens=[1],
|
||||||
|
seq_lens=[5],
|
||||||
|
),
|
||||||
|
attn_chunk_size=4,
|
||||||
|
block_size=2,
|
||||||
|
expected_q_seqlens=[1],
|
||||||
|
expected_k_seqlens=[1],
|
||||||
|
expected_local_block_table=[
|
||||||
|
[2, 2],
|
||||||
|
]),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("test_data", test_data_list)
|
||||||
|
def test_local_attention_virtual_batches(test_data: LocalAttentionTestData):
|
||||||
|
device = torch.device("cuda:0")
|
||||||
|
batch_spec = test_data.batch_spec
|
||||||
|
attn_chunk_size = test_data.attn_chunk_size
|
||||||
|
block_size = test_data.block_size
|
||||||
|
expected_q_seqlens = test_data.expected_q_seqlens
|
||||||
|
expected_k_seqlens = test_data.expected_k_seqlens
|
||||||
|
expected_local_block_table = test_data.expected_local_block_table
|
||||||
|
|
||||||
|
# Create common attention metadata
|
||||||
|
common_attn_metadata = create_common_attn_metadata(
|
||||||
|
batch_spec,
|
||||||
|
block_size,
|
||||||
|
device,
|
||||||
|
# Use torch.arange instead of torch.randint so we can assert on
|
||||||
|
# block table tensor values. The block table will have shape
|
||||||
|
# (num_batches, cdiv(max_seq_len, block_size)) and the values will be
|
||||||
|
# aranged from 0 to batch_size * cdiv(max_seq_len, block_size) - 1
|
||||||
|
arange_block_indices=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Call the function
|
||||||
|
result = make_local_attention_virtual_batches(attn_chunk_size,
|
||||||
|
common_attn_metadata,
|
||||||
|
block_size)
|
||||||
|
|
||||||
|
# Convert to numpy for easier comparison
|
||||||
|
actual_q_seqlens = np.diff(result.query_start_loc_cpu.numpy())
|
||||||
|
actual_k_seqlens = result.seq_lens_cpu.numpy()
|
||||||
|
|
||||||
|
# Check that all query lengths are less than or equal to attn_chunk_size
|
||||||
|
assert all(q_len <= attn_chunk_size for q_len in actual_q_seqlens)
|
||||||
|
# Check that all key lengths are less than or equal to attn_chunk_size
|
||||||
|
assert all(k_len <= attn_chunk_size for k_len in actual_k_seqlens)
|
||||||
|
# Check that the total number of query tokens is preserved
|
||||||
|
assert sum(actual_q_seqlens) == sum(batch_spec.query_lens)
|
||||||
|
|
||||||
|
# Verify results
|
||||||
|
np.testing.assert_array_equal(actual_q_seqlens, expected_q_seqlens)
|
||||||
|
np.testing.assert_array_equal(actual_k_seqlens, expected_k_seqlens)
|
||||||
|
|
||||||
|
expected_block_table_tensor =\
|
||||||
|
torch.tensor(expected_local_block_table,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=device)
|
||||||
|
|
||||||
|
print(f"Expected block table:\n{expected_block_table_tensor}")
|
||||||
|
print(f"Actual block table:\n{result.block_table_tensor}")
|
||||||
|
|
||||||
|
torch.testing.assert_close(result.block_table_tensor,
|
||||||
|
expected_block_table_tensor)
|
||||||
@ -40,7 +40,8 @@ def create_common_attn_metadata(
|
|||||||
batch_spec: BatchSpec,
|
batch_spec: BatchSpec,
|
||||||
block_size: int,
|
block_size: int,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
max_block_idx: int = 1000) -> CommonAttentionMetadata:
|
max_block_idx: int = 1000,
|
||||||
|
arange_block_indices: bool = False) -> CommonAttentionMetadata:
|
||||||
"""Create CommonAttentionMetadata from a BatchSpec and ModelParams."""
|
"""Create CommonAttentionMetadata from a BatchSpec and ModelParams."""
|
||||||
# Create query start locations
|
# Create query start locations
|
||||||
query_start_loc = torch.zeros(batch_spec.batch_size + 1,
|
query_start_loc = torch.zeros(batch_spec.batch_size + 1,
|
||||||
@ -65,19 +66,28 @@ def create_common_attn_metadata(
|
|||||||
]
|
]
|
||||||
num_computed_tokens_cpu = torch.tensor(context_lens, dtype=torch.int32)
|
num_computed_tokens_cpu = torch.tensor(context_lens, dtype=torch.int32)
|
||||||
|
|
||||||
# Create block table (random for testing)
|
# Create block table and slot mapping
|
||||||
max_blocks = (max(batch_spec.seq_lens) + block_size - 1) // block_size
|
max_blocks = (max(batch_spec.seq_lens) + block_size - 1) // block_size
|
||||||
block_table_tensor = torch.randint(0,
|
if arange_block_indices:
|
||||||
max_block_idx,
|
num_blocks = batch_spec.batch_size * max_blocks
|
||||||
(batch_spec.batch_size, max_blocks),
|
block_table_tensor = torch.arange(num_blocks,
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device=device)
|
device=device).view(
|
||||||
|
batch_spec.batch_size,
|
||||||
# Create slot mapping
|
max_blocks)
|
||||||
slot_mapping = torch.randint(0,
|
slot_mapping = torch.arange(num_tokens,
|
||||||
max_block_idx, (num_tokens, ),
|
dtype=torch.int64,
|
||||||
dtype=torch.int64,
|
device=device).view(num_tokens)
|
||||||
device=device)
|
else:
|
||||||
|
block_table_tensor = torch.randint(0,
|
||||||
|
max_block_idx,
|
||||||
|
(batch_spec.batch_size, max_blocks),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=device)
|
||||||
|
slot_mapping = torch.randint(0,
|
||||||
|
max_block_idx, (num_tokens, ),
|
||||||
|
dtype=torch.int64,
|
||||||
|
device=device)
|
||||||
|
|
||||||
# Calculate max query length
|
# Calculate max query length
|
||||||
max_query_len = max(batch_spec.query_lens)
|
max_query_len = max(batch_spec.query_lens)
|
||||||
|
|||||||
@ -3,29 +3,34 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import random
|
import random
|
||||||
from typing import Any
|
from typing import Any, Union
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
|
from vllm.assets.base import VLLM_S3_BUCKET_URL
|
||||||
|
from vllm.assets.image import VLM_IMAGES_DIR
|
||||||
from vllm.distributed import cleanup_dist_env_and_memory
|
from vllm.distributed import cleanup_dist_env_and_memory
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
def get_test_prompts(mm_enabled: bool):
|
||||||
def test_prompts():
|
|
||||||
prompt_types = ["repeat", "sentence"]
|
prompt_types = ["repeat", "sentence"]
|
||||||
|
if mm_enabled:
|
||||||
|
prompt_types.append("mm")
|
||||||
num_prompts = 100
|
num_prompts = 100
|
||||||
prompts = []
|
prompts = []
|
||||||
|
|
||||||
random.seed(0)
|
random.seed(0)
|
||||||
random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
|
random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
|
||||||
|
print(f"Prompt types: {random_prompt_type_choices}")
|
||||||
|
|
||||||
# Generate a mixed batch of prompts, some of which can be easily
|
# Generate a mixed batch of prompts, some of which can be easily
|
||||||
# predicted by n-gram matching and some which likely cannot.
|
# predicted by n-gram matching and some which likely cannot.
|
||||||
for kind in random_prompt_type_choices:
|
for kind in random_prompt_type_choices:
|
||||||
word_choices = ["test", "temp", "hello", "where"]
|
word_choices = ["test", "temp", "hello", "where"]
|
||||||
word = random.choice(word_choices)
|
word = random.choice(word_choices)
|
||||||
|
prompt: Union[str, list[dict[str, Any]]] = ""
|
||||||
if kind == "repeat":
|
if kind == "repeat":
|
||||||
prompt = f"""
|
prompt = f"""
|
||||||
please repeat the word '{word}' 10 times.
|
please repeat the word '{word}' 10 times.
|
||||||
@ -38,6 +43,21 @@ def test_prompts():
|
|||||||
uses the word {word} at least once.
|
uses the word {word} at least once.
|
||||||
give no other output than that simple sentence without quotes.
|
give no other output than that simple sentence without quotes.
|
||||||
"""
|
"""
|
||||||
|
elif kind == "mm":
|
||||||
|
placeholders = [{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url":
|
||||||
|
f"{VLLM_S3_BUCKET_URL}/{VLM_IMAGES_DIR}/stop_sign.jpg"
|
||||||
|
},
|
||||||
|
}]
|
||||||
|
prompt = [
|
||||||
|
*placeholders,
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "The meaning of the image is"
|
||||||
|
},
|
||||||
|
]
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown prompt type: {kind}")
|
raise ValueError(f"Unknown prompt type: {kind}")
|
||||||
prompts.append([{"role": "user", "content": prompt}])
|
prompts.append([{"role": "user", "content": prompt}])
|
||||||
@ -57,7 +77,6 @@ def model_name():
|
|||||||
|
|
||||||
def test_ngram_correctness(
|
def test_ngram_correctness(
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
test_prompts: list[list[dict[str, Any]]],
|
|
||||||
sampling_config: SamplingParams,
|
sampling_config: SamplingParams,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
):
|
):
|
||||||
@ -67,6 +86,7 @@ def test_ngram_correctness(
|
|||||||
'''
|
'''
|
||||||
with monkeypatch.context() as m:
|
with monkeypatch.context() as m:
|
||||||
m.setenv("VLLM_USE_V1", "1")
|
m.setenv("VLLM_USE_V1", "1")
|
||||||
|
test_prompts = get_test_prompts(mm_enabled=False)
|
||||||
|
|
||||||
ref_llm = LLM(model=model_name, max_model_len=1024)
|
ref_llm = LLM(model=model_name, max_model_len=1024)
|
||||||
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
|
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
|
||||||
@ -103,23 +123,32 @@ def test_ngram_correctness(
|
|||||||
cleanup_dist_env_and_memory()
|
cleanup_dist_env_and_memory()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model_setup", [
|
@pytest.mark.parametrize(
|
||||||
("eagle", "meta-llama/Llama-3.1-8B-Instruct",
|
["model_setup", "mm_enabled"], [
|
||||||
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1),
|
(("eagle", "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
|
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
|
||||||
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1),
|
(("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
pytest.param(
|
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
|
||||||
("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
pytest.param(
|
||||||
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
|
("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||||
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
|
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
|
||||||
],
|
False,
|
||||||
ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle"])
|
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
|
||||||
|
pytest.param(
|
||||||
|
("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||||
|
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
|
||||||
|
True,
|
||||||
|
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
|
||||||
|
],
|
||||||
|
ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle", "llama4_eagle_mm"])
|
||||||
def test_eagle_correctness(
|
def test_eagle_correctness(
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
test_prompts: list[list[dict[str, Any]]],
|
|
||||||
sampling_config: SamplingParams,
|
sampling_config: SamplingParams,
|
||||||
model_setup: tuple[str, str, str, int],
|
model_setup: tuple[str, str, str, int],
|
||||||
|
mm_enabled: bool,
|
||||||
):
|
):
|
||||||
|
# Generate test prompts inside the function instead of using fixture
|
||||||
|
test_prompts = get_test_prompts(mm_enabled)
|
||||||
'''
|
'''
|
||||||
Compare the outputs of an original LLM and a speculative LLM
|
Compare the outputs of an original LLM and a speculative LLM
|
||||||
should be the same when using eagle speculative decoding.
|
should be the same when using eagle speculative decoding.
|
||||||
|
|||||||
@ -37,6 +37,8 @@ logger = init_logger(__name__)
|
|||||||
ALLREDUCE_OP = torch.ops.vllm.all_reduce.default
|
ALLREDUCE_OP = torch.ops.vllm.all_reduce.default
|
||||||
RMS_OP = torch.ops._C.rms_norm.default
|
RMS_OP = torch.ops._C.rms_norm.default
|
||||||
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
|
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
|
||||||
|
STATIC_FP8_QUANT_OP = torch.ops._C.static_scaled_fp8_quant.default
|
||||||
|
STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default
|
||||||
|
|
||||||
|
|
||||||
class BasePattern:
|
class BasePattern:
|
||||||
@ -394,7 +396,7 @@ if flashinfer_comm is not None:
|
|||||||
# Max size of the input tensor per world size
|
# Max size of the input tensor per world size
|
||||||
# to use flashinfer fused allreduce
|
# to use flashinfer fused allreduce
|
||||||
_FI_MAX_SIZES = {
|
_FI_MAX_SIZES = {
|
||||||
2: MiB, # 1MB
|
2: 64 * MiB, # 64MB
|
||||||
4: MiB, # 1MB
|
4: MiB, # 1MB
|
||||||
6: MiB // 2, # 512KB
|
6: MiB // 2, # 512KB
|
||||||
8: MiB // 2, # 512KB
|
8: MiB // 2, # 512KB
|
||||||
@ -414,9 +416,13 @@ if flashinfer_comm is not None:
|
|||||||
trigger_completion_at_end: bool,
|
trigger_completion_at_end: bool,
|
||||||
fp32_acc: bool,
|
fp32_acc: bool,
|
||||||
max_token_num: int,
|
max_token_num: int,
|
||||||
|
pattern_code: int,
|
||||||
|
fuse_rms_quant: bool,
|
||||||
norm_out: Optional[torch.Tensor] = None,
|
norm_out: Optional[torch.Tensor] = None,
|
||||||
|
quant_out: Optional[torch.Tensor] = None,
|
||||||
|
scale_out: Optional[torch.Tensor] = None,
|
||||||
|
scale_factor: Optional[torch.Tensor] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
||||||
num_tokens, hidden_size = allreduce_in.shape
|
num_tokens, hidden_size = allreduce_in.shape
|
||||||
element_size = allreduce_in.element_size()
|
element_size = allreduce_in.element_size()
|
||||||
current_tensor_size = num_tokens * hidden_size * element_size
|
current_tensor_size = num_tokens * hidden_size * element_size
|
||||||
@ -425,7 +431,6 @@ if flashinfer_comm is not None:
|
|||||||
_FI_MAX_SIZES.get(world_size, _DEFAULT_FI_MAX_SIZE),
|
_FI_MAX_SIZES.get(world_size, _DEFAULT_FI_MAX_SIZE),
|
||||||
max_fusion_size,
|
max_fusion_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
if use_flashinfer:
|
if use_flashinfer:
|
||||||
assert (_FI_WORKSPACE_TENSOR is not None
|
assert (_FI_WORKSPACE_TENSOR is not None
|
||||||
), "Flashinfer must be enabled when using flashinfer"
|
), "Flashinfer must be enabled when using flashinfer"
|
||||||
@ -455,37 +460,65 @@ if flashinfer_comm is not None:
|
|||||||
use_oneshot=True,
|
use_oneshot=True,
|
||||||
trigger_completion_at_end=trigger_completion_at_end,
|
trigger_completion_at_end=trigger_completion_at_end,
|
||||||
fp32_acc=fp32_acc,
|
fp32_acc=fp32_acc,
|
||||||
pattern_code=flashinfer_comm.AllReduceFusionPattern.
|
pattern_code=pattern_code,
|
||||||
kARResidualRMSNorm,
|
|
||||||
allreduce_out=None,
|
allreduce_out=None,
|
||||||
quant_out=None,
|
quant_out=quant_out,
|
||||||
scale_out=None,
|
scale_out=scale_out,
|
||||||
layout_code=None,
|
# in vllm we only support swizzled layout
|
||||||
scale_factor=None,
|
layout_code=flashinfer_comm.FP4QuantizationSFLayout.SWIZZLED,
|
||||||
|
scale_factor=scale_factor,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
allreduce_out = tensor_model_parallel_all_reduce(allreduce_in)
|
allreduce_out = tensor_model_parallel_all_reduce(allreduce_in)
|
||||||
if norm_out is None:
|
if (scale_factor is not None and scale_out is None
|
||||||
torch.ops._C.fused_add_rms_norm(allreduce_out, residual,
|
and fuse_rms_quant):
|
||||||
rms_gamma, rms_eps)
|
# Use the fused rms norm + static fp8 quant op
|
||||||
|
if norm_out is None:
|
||||||
|
torch.ops._C.fused_add_rms_norm_static_fp8_quant(
|
||||||
|
quant_out, allreduce_out, residual, rms_gamma,
|
||||||
|
scale_factor, rms_eps)
|
||||||
|
else:
|
||||||
|
torch.ops._C.rms_norm_static_fp8_quant(
|
||||||
|
quant_out, allreduce_out, rms_gamma, scale_factor,
|
||||||
|
rms_eps)
|
||||||
else:
|
else:
|
||||||
torch.ops._C.rms_norm(norm_out, allreduce_out, rms_gamma,
|
if norm_out is None:
|
||||||
rms_eps)
|
torch.ops._C.fused_add_rms_norm(allreduce_out, residual,
|
||||||
allreduce_in.copy_(allreduce_out)
|
rms_gamma, rms_eps)
|
||||||
|
norm_out = allreduce_out
|
||||||
|
else:
|
||||||
|
torch.ops._C.rms_norm(norm_out, allreduce_out, rms_gamma,
|
||||||
|
rms_eps)
|
||||||
|
if scale_factor is not None:
|
||||||
|
if scale_out is not None:
|
||||||
|
torch.ops._C.scaled_fp4_quant(quant_out, norm_out,
|
||||||
|
scale_out, scale_factor)
|
||||||
|
else:
|
||||||
|
torch.ops._C.static_scaled_fp8_quant(
|
||||||
|
quant_out, norm_out, scale_factor)
|
||||||
|
if scale_factor is None or norm_out is not None:
|
||||||
|
# we need to return allreduce output
|
||||||
|
# in cases of non quant fused AR + RMS norm
|
||||||
|
# and fused AR + RMS norm + quant without fused add
|
||||||
|
allreduce_in.copy_(allreduce_out)
|
||||||
|
|
||||||
def call_trtllm_fused_allreduce_norm_fake(
|
def call_trtllm_fused_allreduce_norm_fake(
|
||||||
allreduce_in: torch.Tensor,
|
allreduce_in: torch.Tensor,
|
||||||
residual: torch.Tensor,
|
residual: torch.Tensor,
|
||||||
rms_gamma: torch.Tensor,
|
rms_gamma: torch.Tensor,
|
||||||
rms_eps: float,
|
rms_eps: float,
|
||||||
world_rank: int,
|
world_rank: int,
|
||||||
world_size: int,
|
world_size: int,
|
||||||
launch_with_pdl: bool,
|
launch_with_pdl: bool,
|
||||||
trigger_completion_at_end: bool,
|
trigger_completion_at_end: bool,
|
||||||
fp32_acc: bool,
|
fp32_acc: bool,
|
||||||
max_token_num: int,
|
max_token_num: int,
|
||||||
norm_out: Optional[torch.Tensor] = None,
|
pattern_code: int,
|
||||||
) -> None:
|
fuse_rms_quant: bool,
|
||||||
|
norm_out: Optional[torch.Tensor] = None,
|
||||||
|
quant_out: Optional[torch.Tensor] = None,
|
||||||
|
scale_out: Optional[torch.Tensor] = None,
|
||||||
|
scale_factor: Optional[torch.Tensor] = None) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
@ -495,6 +528,8 @@ if flashinfer_comm is not None:
|
|||||||
"allreduce_in",
|
"allreduce_in",
|
||||||
"residual",
|
"residual",
|
||||||
"norm_out",
|
"norm_out",
|
||||||
|
"quant_out",
|
||||||
|
"scale_out",
|
||||||
],
|
],
|
||||||
fake_impl=call_trtllm_fused_allreduce_norm_fake,
|
fake_impl=call_trtllm_fused_allreduce_norm_fake,
|
||||||
dispatch_key=current_platform.dispatch_key,
|
dispatch_key=current_platform.dispatch_key,
|
||||||
@ -512,6 +547,7 @@ class FlashInferFusedAllReduceParams:
|
|||||||
world_size: int,
|
world_size: int,
|
||||||
use_fp32_lamport: bool = False,
|
use_fp32_lamport: bool = False,
|
||||||
max_token_num: int = 1024,
|
max_token_num: int = 1024,
|
||||||
|
fuse_rms_quant: bool = False,
|
||||||
):
|
):
|
||||||
self.rank = rank
|
self.rank = rank
|
||||||
self.world_size = world_size
|
self.world_size = world_size
|
||||||
@ -521,6 +557,7 @@ class FlashInferFusedAllReduceParams:
|
|||||||
self.fp32_acc = True
|
self.fp32_acc = True
|
||||||
self.use_oneshot = False
|
self.use_oneshot = False
|
||||||
self.max_token_num = max_token_num
|
self.max_token_num = max_token_num
|
||||||
|
self.fuse_rms_quant = fuse_rms_quant
|
||||||
|
|
||||||
def get_trtllm_fused_allreduce_kwargs(self):
|
def get_trtllm_fused_allreduce_kwargs(self):
|
||||||
return {
|
return {
|
||||||
@ -530,10 +567,16 @@ class FlashInferFusedAllReduceParams:
|
|||||||
"trigger_completion_at_end": self.trigger_completion_at_end,
|
"trigger_completion_at_end": self.trigger_completion_at_end,
|
||||||
"fp32_acc": self.fp32_acc,
|
"fp32_acc": self.fp32_acc,
|
||||||
"max_token_num": self.max_token_num,
|
"max_token_num": self.max_token_num,
|
||||||
|
"fuse_rms_quant": self.fuse_rms_quant,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class AllReduceRMSNORMPattern(BasePattern):
|
class AllReduceRMSNormPattern(BasePattern):
|
||||||
|
"""
|
||||||
|
This pattern replaces the allreduce + rms norm (without residual)
|
||||||
|
with fused flashinfer implementation.
|
||||||
|
Applies to allreduce + rmsnorm before attn in the first Transformer block.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -559,29 +602,34 @@ class AllReduceRMSNORMPattern(BasePattern):
|
|||||||
|
|
||||||
def pattern(input: torch.Tensor, rms_result: torch.Tensor,
|
def pattern(input: torch.Tensor, rms_result: torch.Tensor,
|
||||||
weight: torch.Tensor):
|
weight: torch.Tensor):
|
||||||
all_reduce_output = tensor_model_parallel_all_reduce(input)
|
allreduce_output = tensor_model_parallel_all_reduce(input)
|
||||||
rms = auto_functionalized(
|
rms = auto_functionalized(
|
||||||
RMS_OP,
|
RMS_OP,
|
||||||
result=rms_result,
|
result=rms_result,
|
||||||
input=all_reduce_output,
|
input=allreduce_output,
|
||||||
weight=weight,
|
weight=weight,
|
||||||
epsilon=self.epsilon,
|
epsilon=self.epsilon,
|
||||||
)
|
)
|
||||||
return rms[1], all_reduce_output
|
# rms_result, allreduce_output
|
||||||
|
return rms[1], allreduce_output
|
||||||
|
|
||||||
def replacement(input: torch.Tensor, rms_result: torch.Tensor,
|
def replacement(input: torch.Tensor, rms_result: torch.Tensor,
|
||||||
weight: torch.Tensor):
|
weight: torch.Tensor):
|
||||||
residual = torch.zeros_like(input)
|
residual = torch.zeros_like(input)
|
||||||
allreduce = auto_functionalized(
|
allreduce = auto_functionalized(
|
||||||
torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default,
|
flashinfer_trtllm_fused_allreduce_norm,
|
||||||
allreduce_in=input,
|
allreduce_in=input,
|
||||||
residual=residual,
|
residual=residual,
|
||||||
norm_out=rms_result,
|
norm_out=rms_result,
|
||||||
|
quant_out=None,
|
||||||
|
scale_out=None,
|
||||||
rms_gamma=weight,
|
rms_gamma=weight,
|
||||||
rms_eps=self.epsilon,
|
rms_eps=self.epsilon,
|
||||||
|
pattern_code=flashinfer_comm.AllReduceFusionPattern.
|
||||||
|
kARResidualRMSNorm,
|
||||||
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
|
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
|
||||||
)
|
)
|
||||||
|
# rms_result, allreduce_in
|
||||||
return allreduce[3], allreduce[1]
|
return allreduce[3], allreduce[1]
|
||||||
|
|
||||||
pm.register_replacement(pattern, replacement, self.get_inputs(),
|
pm.register_replacement(pattern, replacement, self.get_inputs(),
|
||||||
@ -589,6 +637,11 @@ class AllReduceRMSNORMPattern(BasePattern):
|
|||||||
|
|
||||||
|
|
||||||
class AllReduceFusedAddRMSNormPattern(BasePattern):
|
class AllReduceFusedAddRMSNormPattern(BasePattern):
|
||||||
|
"""
|
||||||
|
This pattern replaces the allreduce + rms norm (with residual)
|
||||||
|
with fused flashinfer implementation.
|
||||||
|
Applies to o_proj + rmsnorm after attn and mlp + rmsnorm before attn.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -615,33 +668,390 @@ class AllReduceFusedAddRMSNormPattern(BasePattern):
|
|||||||
|
|
||||||
def pattern(residual: torch.Tensor, input: torch.Tensor,
|
def pattern(residual: torch.Tensor, input: torch.Tensor,
|
||||||
weight: torch.Tensor):
|
weight: torch.Tensor):
|
||||||
all_reduce_output = tensor_model_parallel_all_reduce(input)
|
allreduce_output = tensor_model_parallel_all_reduce(input)
|
||||||
rms = auto_functionalized(
|
rms = auto_functionalized(
|
||||||
RMS_ADD_OP,
|
RMS_ADD_OP,
|
||||||
input=all_reduce_output,
|
input=allreduce_output,
|
||||||
residual=residual,
|
residual=residual,
|
||||||
weight=weight,
|
weight=weight,
|
||||||
epsilon=self.epsilon,
|
epsilon=self.epsilon,
|
||||||
)
|
)
|
||||||
|
# input, residual
|
||||||
return rms[1], rms[2]
|
return rms[1], rms[2]
|
||||||
|
|
||||||
def replacement(residual: torch.Tensor, input: torch.Tensor,
|
def replacement(residual: torch.Tensor, input: torch.Tensor,
|
||||||
weight: torch.Tensor):
|
weight: torch.Tensor):
|
||||||
allreduce = auto_functionalized(
|
allreduce = auto_functionalized(
|
||||||
torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default,
|
flashinfer_trtllm_fused_allreduce_norm,
|
||||||
allreduce_in=input,
|
allreduce_in=input,
|
||||||
residual=residual,
|
residual=residual,
|
||||||
|
norm_out=None,
|
||||||
|
quant_out=None,
|
||||||
|
scale_out=None,
|
||||||
rms_gamma=weight,
|
rms_gamma=weight,
|
||||||
rms_eps=self.epsilon,
|
rms_eps=self.epsilon,
|
||||||
norm_out=None,
|
pattern_code=flashinfer_comm.AllReduceFusionPattern.
|
||||||
|
kARResidualRMSNorm,
|
||||||
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
|
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
|
||||||
)
|
)
|
||||||
|
# allreduce_in, residual
|
||||||
return allreduce[1], allreduce[2]
|
return allreduce[1], allreduce[2]
|
||||||
|
|
||||||
pm.register_replacement(pattern, replacement, self.get_inputs(),
|
pm.register_replacement(pattern, replacement, self.get_inputs(),
|
||||||
pm.fwd_only, pm_pass)
|
pm.fwd_only, pm_pass)
|
||||||
|
|
||||||
|
|
||||||
|
class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
|
||||||
|
"""
|
||||||
|
This pattern replaces the allreduce + rms norm (without residual)
|
||||||
|
+ static fp8 quant with fused flashinfer implementation.
|
||||||
|
Applies to allreduce + rmsnorm + quant before attn
|
||||||
|
in the first Transformer block.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
|
||||||
|
allreduce_params: FlashInferFusedAllReduceParams):
|
||||||
|
super().__init__(dtype, device)
|
||||||
|
self.epsilon = epsilon
|
||||||
|
self.allreduce_params = allreduce_params
|
||||||
|
self.quant_dtype = torch.float8_e4m3fn
|
||||||
|
|
||||||
|
def register(self, pm_pass: PatternMatcherPass):
|
||||||
|
|
||||||
|
def get_inputs():
|
||||||
|
input = torch.zeros([1, 8, 4],
|
||||||
|
device=self.device,
|
||||||
|
dtype=self.dtype)
|
||||||
|
rmsnorm_result = torch.empty([1, 8, 4],
|
||||||
|
device=self.device,
|
||||||
|
dtype=self.dtype)
|
||||||
|
quant_result = torch.empty([1, 8, 4],
|
||||||
|
device=self.device,
|
||||||
|
dtype=self.quant_dtype)
|
||||||
|
weight = torch.empty([4], device=self.device, dtype=self.dtype)
|
||||||
|
scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
|
||||||
|
return [input, rmsnorm_result, quant_result, weight, scale]
|
||||||
|
|
||||||
|
def pattern(
|
||||||
|
input: torch.Tensor,
|
||||||
|
rmsnorm_result: torch.Tensor,
|
||||||
|
quant_result: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
scale: torch.Tensor,
|
||||||
|
):
|
||||||
|
all_reduce = tensor_model_parallel_all_reduce(input)
|
||||||
|
rmsnorm_out_tuple = auto_functionalized(RMS_OP,
|
||||||
|
result=rmsnorm_result,
|
||||||
|
input=all_reduce,
|
||||||
|
weight=weight,
|
||||||
|
epsilon=self.epsilon)
|
||||||
|
|
||||||
|
quant_out_tuple = auto_functionalized(STATIC_FP8_QUANT_OP,
|
||||||
|
result=quant_result,
|
||||||
|
input=rmsnorm_out_tuple[1],
|
||||||
|
scale=scale)
|
||||||
|
|
||||||
|
# quant_out, allreduce_output
|
||||||
|
return quant_out_tuple[1], all_reduce
|
||||||
|
|
||||||
|
def replacement(input: torch.Tensor, result_rms: torch.Tensor,
|
||||||
|
quant_result: torch.Tensor, weight: torch.Tensor,
|
||||||
|
scale: torch.Tensor):
|
||||||
|
residual = torch.zeros_like(input)
|
||||||
|
allreduce = auto_functionalized(
|
||||||
|
flashinfer_trtllm_fused_allreduce_norm,
|
||||||
|
allreduce_in=input,
|
||||||
|
residual=residual,
|
||||||
|
norm_out=result_rms,
|
||||||
|
quant_out=quant_result,
|
||||||
|
scale_out=None,
|
||||||
|
rms_gamma=weight,
|
||||||
|
rms_eps=self.epsilon,
|
||||||
|
pattern_code=flashinfer_comm.AllReduceFusionPattern.
|
||||||
|
kARResidualRMSNormFP8Quant, # we don't use norm_out afterwards
|
||||||
|
scale_factor=scale,
|
||||||
|
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# quant_out, allreduce_output
|
||||||
|
return allreduce[4], allreduce[1]
|
||||||
|
|
||||||
|
pm.register_replacement(pattern, replacement, get_inputs(),
|
||||||
|
pm.fwd_only, pm_pass)
|
||||||
|
|
||||||
|
|
||||||
|
class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
|
||||||
|
"""
|
||||||
|
This pattern replaces the allreduce + rms norm (with residual)
|
||||||
|
+ static fp8 quant with fused flashinfer implementation.
|
||||||
|
Applies to o_proj + rmsnorm after attn + quant and
|
||||||
|
mlp + rmsnorm + quant before attn.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
|
||||||
|
allreduce_params: FlashInferFusedAllReduceParams):
|
||||||
|
super().__init__(dtype, device)
|
||||||
|
self.epsilon = epsilon
|
||||||
|
self.allreduce_params = allreduce_params
|
||||||
|
self.quant_dtype = torch.float8_e4m3fn
|
||||||
|
|
||||||
|
def register(self, pm_pass: PatternMatcherPass):
|
||||||
|
|
||||||
|
def get_inputs():
|
||||||
|
input = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||||
|
|
||||||
|
residual = torch.empty([4, 4],
|
||||||
|
device=self.device,
|
||||||
|
dtype=self.dtype)
|
||||||
|
weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||||
|
quant_result = torch.empty([4, 4],
|
||||||
|
device=self.device,
|
||||||
|
dtype=self.quant_dtype)
|
||||||
|
scale = torch.empty([1, 1],
|
||||||
|
device=self.device,
|
||||||
|
dtype=torch.float32)
|
||||||
|
|
||||||
|
return [
|
||||||
|
quant_result,
|
||||||
|
residual,
|
||||||
|
input,
|
||||||
|
weight,
|
||||||
|
scale,
|
||||||
|
]
|
||||||
|
|
||||||
|
def pattern(
|
||||||
|
quant_result: torch.Tensor,
|
||||||
|
residual: torch.Tensor,
|
||||||
|
input: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
scale: torch.Tensor,
|
||||||
|
):
|
||||||
|
allreduce_output = tensor_model_parallel_all_reduce(input)
|
||||||
|
|
||||||
|
fused_add_rmsnorm_out_tuple = \
|
||||||
|
auto_functionalized(
|
||||||
|
RMS_ADD_OP,
|
||||||
|
input=allreduce_output,
|
||||||
|
residual=residual,
|
||||||
|
weight=weight,
|
||||||
|
epsilon=self.epsilon)
|
||||||
|
quant_out_tuple = auto_functionalized(
|
||||||
|
STATIC_FP8_QUANT_OP,
|
||||||
|
result=quant_result,
|
||||||
|
input=fused_add_rmsnorm_out_tuple[1],
|
||||||
|
scale=scale)
|
||||||
|
|
||||||
|
# quant_out, rms_norm_residual
|
||||||
|
return quant_out_tuple[1], fused_add_rmsnorm_out_tuple[2]
|
||||||
|
|
||||||
|
def replacement(quant_result: torch.Tensor, residual: torch.Tensor,
|
||||||
|
input: torch.Tensor, weight: torch.Tensor,
|
||||||
|
scale: torch.Tensor):
|
||||||
|
allreduce = auto_functionalized(
|
||||||
|
flashinfer_trtllm_fused_allreduce_norm,
|
||||||
|
allreduce_in=input,
|
||||||
|
residual=residual,
|
||||||
|
norm_out=None,
|
||||||
|
quant_out=quant_result,
|
||||||
|
scale_out=None,
|
||||||
|
rms_gamma=weight,
|
||||||
|
rms_eps=self.epsilon,
|
||||||
|
pattern_code=flashinfer_comm.AllReduceFusionPattern.
|
||||||
|
kARResidualRMSNormFP8Quant, # we don't use norm_out afterwards
|
||||||
|
scale_factor=scale,
|
||||||
|
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
|
||||||
|
)
|
||||||
|
# quant_out, rms_norm_residual
|
||||||
|
return allreduce[4], allreduce[2]
|
||||||
|
|
||||||
|
pm.register_replacement(pattern, replacement, get_inputs(),
|
||||||
|
pm.fwd_only, pm_pass)
|
||||||
|
|
||||||
|
|
||||||
|
class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
|
||||||
|
"""
|
||||||
|
This pattern replaces the allreduce + rms norm (without residual)
|
||||||
|
+ static nvfp4 quant with fused flashinfer implementation.
|
||||||
|
Applies to allreduce + rmsnorm + quant before attn
|
||||||
|
in the first Transformer block.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
|
||||||
|
allreduce_params: FlashInferFusedAllReduceParams):
|
||||||
|
super().__init__(dtype, device)
|
||||||
|
self.epsilon = epsilon
|
||||||
|
self.allreduce_params = allreduce_params
|
||||||
|
|
||||||
|
def register(self, pm_pass: PatternMatcherPass):
|
||||||
|
|
||||||
|
def get_inputs():
|
||||||
|
input = torch.empty([1, 16, 16],
|
||||||
|
device=self.device,
|
||||||
|
dtype=self.dtype)
|
||||||
|
|
||||||
|
rmsnorm_result = torch.empty([1, 16, 16],
|
||||||
|
device=self.device,
|
||||||
|
dtype=self.dtype)
|
||||||
|
quant_result = torch.empty((16, 8),
|
||||||
|
device=self.device,
|
||||||
|
dtype=torch.uint8)
|
||||||
|
input_global_scale = torch.empty([1, 1],
|
||||||
|
device=self.device,
|
||||||
|
dtype=torch.float32)
|
||||||
|
weight = torch.empty([16], device=self.device, dtype=self.dtype)
|
||||||
|
output_scale = torch.empty([128, 4],
|
||||||
|
device=self.device,
|
||||||
|
dtype=torch.int32)
|
||||||
|
|
||||||
|
return [
|
||||||
|
input, rmsnorm_result, quant_result, weight,
|
||||||
|
input_global_scale, output_scale
|
||||||
|
]
|
||||||
|
|
||||||
|
def pattern(
|
||||||
|
input: torch.Tensor,
|
||||||
|
rmsnorm_result: torch.Tensor,
|
||||||
|
quant_result: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
input_global_scale: torch.Tensor,
|
||||||
|
output_scale: torch.Tensor,
|
||||||
|
):
|
||||||
|
all_reduce = tensor_model_parallel_all_reduce(input)
|
||||||
|
rmsnorm_out_tuple = auto_functionalized(RMS_OP,
|
||||||
|
result=rmsnorm_result,
|
||||||
|
input=all_reduce,
|
||||||
|
weight=weight,
|
||||||
|
epsilon=self.epsilon)
|
||||||
|
|
||||||
|
quant_out_tuple = auto_functionalized(
|
||||||
|
STATIC_FP4_QUANT_OP,
|
||||||
|
output=quant_result,
|
||||||
|
input=rmsnorm_out_tuple[1],
|
||||||
|
output_scale=output_scale,
|
||||||
|
input_scale=input_global_scale)
|
||||||
|
|
||||||
|
# quant_out, allreduce_output, output_scale
|
||||||
|
return quant_out_tuple[1], all_reduce, quant_out_tuple[2]
|
||||||
|
|
||||||
|
def replacement(input: torch.Tensor, result_rms: torch.Tensor,
|
||||||
|
quant_result: torch.Tensor, weight: torch.Tensor,
|
||||||
|
input_global_scale: torch.Tensor,
|
||||||
|
output_scale: torch.Tensor):
|
||||||
|
residual = torch.zeros_like(input)
|
||||||
|
allreduce = auto_functionalized(
|
||||||
|
flashinfer_trtllm_fused_allreduce_norm,
|
||||||
|
allreduce_in=input,
|
||||||
|
residual=residual,
|
||||||
|
norm_out=result_rms,
|
||||||
|
quant_out=quant_result,
|
||||||
|
scale_out=output_scale,
|
||||||
|
rms_gamma=weight,
|
||||||
|
rms_eps=self.epsilon,
|
||||||
|
pattern_code=flashinfer_comm.AllReduceFusionPattern.
|
||||||
|
kARResidualRMSNormFP4Quant, # we don't use norm_out afterwards
|
||||||
|
scale_factor=input_global_scale,
|
||||||
|
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# quant_out, allreduce_output, output_scale
|
||||||
|
return allreduce[4], allreduce[1], allreduce[5]
|
||||||
|
|
||||||
|
pm.register_replacement(pattern, replacement, get_inputs(),
|
||||||
|
pm.fwd_only, pm_pass)
|
||||||
|
|
||||||
|
|
||||||
|
class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
|
||||||
|
"""
|
||||||
|
This pattern replaces the allreduce + rms norm (with residual)
|
||||||
|
+ static nvfp4 quant with fused flashinfer implementation.
|
||||||
|
Applies to o_proj + rmsnorm after attn + quant and
|
||||||
|
mlp + rmsnorm + quant before attn.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
|
||||||
|
allreduce_params: FlashInferFusedAllReduceParams):
|
||||||
|
super().__init__(dtype, device)
|
||||||
|
self.epsilon = epsilon
|
||||||
|
self.allreduce_params = allreduce_params
|
||||||
|
|
||||||
|
def register(self, pm_pass: PatternMatcherPass):
|
||||||
|
|
||||||
|
def get_inputs():
|
||||||
|
input = torch.empty([16, 16], device=self.device, dtype=self.dtype)
|
||||||
|
|
||||||
|
residual = torch.empty([16, 16],
|
||||||
|
device=self.device,
|
||||||
|
dtype=self.dtype)
|
||||||
|
weight = torch.empty([16, 16],
|
||||||
|
device=self.device,
|
||||||
|
dtype=self.dtype)
|
||||||
|
quant_result = torch.empty((16, 8),
|
||||||
|
device=self.device,
|
||||||
|
dtype=torch.uint8)
|
||||||
|
input_global_scale = torch.empty([1, 1],
|
||||||
|
device=self.device,
|
||||||
|
dtype=torch.float32)
|
||||||
|
output_scale = torch.empty([128, 4],
|
||||||
|
device=self.device,
|
||||||
|
dtype=torch.int32)
|
||||||
|
|
||||||
|
return [
|
||||||
|
quant_result,
|
||||||
|
residual,
|
||||||
|
input,
|
||||||
|
output_scale,
|
||||||
|
weight,
|
||||||
|
input_global_scale,
|
||||||
|
]
|
||||||
|
|
||||||
|
def pattern(quant_result: torch.Tensor, residual: torch.Tensor,
|
||||||
|
input: torch.Tensor, output_scale: torch.Tensor,
|
||||||
|
weight: torch.Tensor, input_global_scale: torch.Tensor):
|
||||||
|
allreduce_output = tensor_model_parallel_all_reduce(input)
|
||||||
|
|
||||||
|
fused_add_rmsnorm_out_tuple = \
|
||||||
|
auto_functionalized(
|
||||||
|
RMS_ADD_OP,
|
||||||
|
input=allreduce_output,
|
||||||
|
residual=residual,
|
||||||
|
weight=weight,
|
||||||
|
epsilon=self.epsilon)
|
||||||
|
quant_out_tuple = auto_functionalized(
|
||||||
|
STATIC_FP4_QUANT_OP,
|
||||||
|
output=quant_result,
|
||||||
|
input=fused_add_rmsnorm_out_tuple[1],
|
||||||
|
output_scale=output_scale,
|
||||||
|
input_scale=input_global_scale)
|
||||||
|
|
||||||
|
# quant_out, rms_norm_residual, output_scale
|
||||||
|
return quant_out_tuple[1], fused_add_rmsnorm_out_tuple[
|
||||||
|
2], quant_out_tuple[2]
|
||||||
|
|
||||||
|
def replacement(quant_result: torch.Tensor, residual: torch.Tensor,
|
||||||
|
input: torch.Tensor, output_scale: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
input_global_scale: torch.Tensor):
|
||||||
|
allreduce = auto_functionalized(
|
||||||
|
flashinfer_trtllm_fused_allreduce_norm,
|
||||||
|
allreduce_in=input,
|
||||||
|
residual=residual,
|
||||||
|
norm_out=None,
|
||||||
|
quant_out=quant_result,
|
||||||
|
scale_out=output_scale,
|
||||||
|
rms_gamma=weight,
|
||||||
|
rms_eps=self.epsilon,
|
||||||
|
pattern_code=flashinfer_comm.AllReduceFusionPattern.
|
||||||
|
kARResidualRMSNormFP4Quant, # we don't use norm_out afterwards
|
||||||
|
scale_factor=input_global_scale,
|
||||||
|
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
|
||||||
|
)
|
||||||
|
# quant_out, rms_norm_residual, output_scale
|
||||||
|
return allreduce[4], allreduce[2], allreduce[5]
|
||||||
|
|
||||||
|
pm.register_replacement(pattern, replacement, get_inputs(),
|
||||||
|
pm.fwd_only, pm_pass)
|
||||||
|
|
||||||
|
|
||||||
class AllReduceFusionPass(VllmInductorPass):
|
class AllReduceFusionPass(VllmInductorPass):
|
||||||
|
|
||||||
def __init__(self, config: VllmConfig):
|
def __init__(self, config: VllmConfig):
|
||||||
@ -671,13 +1081,16 @@ class AllReduceFusionPass(VllmInductorPass):
|
|||||||
self.tp_size,
|
self.tp_size,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
max_num_token = min(
|
||||||
|
_FI_MAX_SIZES.get(self.tp_size, _DEFAULT_FI_MAX_SIZE) //
|
||||||
|
(self.hidden_dim * self.tp_size * (4 if use_fp32_lamport else 2)),
|
||||||
|
config.compilation_config.pass_config.
|
||||||
|
fi_allreduce_fusion_max_token_num)
|
||||||
self.ipc_handles, workspace_tensor = (
|
self.ipc_handles, workspace_tensor = (
|
||||||
flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
|
flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
|
||||||
tp_rank=rank,
|
tp_rank=rank,
|
||||||
tp_size=self.tp_size,
|
tp_size=self.tp_size,
|
||||||
max_token_num=config.compilation_config.pass_config.
|
max_token_num=max_num_token,
|
||||||
fi_allreduce_fusion_max_token_num,
|
|
||||||
hidden_dim=self.hidden_dim,
|
hidden_dim=self.hidden_dim,
|
||||||
group=self.group,
|
group=self.group,
|
||||||
use_fp32_lamport=use_fp32_lamport,
|
use_fp32_lamport=use_fp32_lamport,
|
||||||
@ -689,12 +1102,38 @@ class AllReduceFusionPass(VllmInductorPass):
|
|||||||
rank=rank,
|
rank=rank,
|
||||||
world_size=self.tp_size,
|
world_size=self.tp_size,
|
||||||
use_fp32_lamport=use_fp32_lamport,
|
use_fp32_lamport=use_fp32_lamport,
|
||||||
max_token_num=config.compilation_config.pass_config.
|
max_token_num=max_num_token,
|
||||||
fi_allreduce_fusion_max_token_num,
|
# fuse rms norm static fp8 quant fused op
|
||||||
)
|
# in fallback path, when we don't use flashinfer
|
||||||
|
fuse_rms_quant=config.compilation_config.pass_config.enable_fusion)
|
||||||
|
|
||||||
for epsilon in [1e-5, 1e-6]:
|
for epsilon in [1e-5, 1e-6]:
|
||||||
AllReduceRMSNORMPattern(
|
AllReduceFusedRMSNormStaticQuantFP8Pattern(
|
||||||
|
epsilon,
|
||||||
|
self.model_dtype,
|
||||||
|
self.device,
|
||||||
|
self.allreduce_params,
|
||||||
|
).register(self.patterns)
|
||||||
|
AllReduceFusedAddRMSNormStaticQuantFP8Pattern(
|
||||||
|
epsilon,
|
||||||
|
self.model_dtype,
|
||||||
|
self.device,
|
||||||
|
self.allreduce_params,
|
||||||
|
).register(self.patterns)
|
||||||
|
if current_platform.has_device_capability(100):
|
||||||
|
AllReduceFusedRMSNormStaticQuantNVFP4Pattern(
|
||||||
|
epsilon,
|
||||||
|
self.model_dtype,
|
||||||
|
self.device,
|
||||||
|
self.allreduce_params,
|
||||||
|
).register(self.patterns)
|
||||||
|
AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(
|
||||||
|
epsilon,
|
||||||
|
self.model_dtype,
|
||||||
|
self.device,
|
||||||
|
self.allreduce_params,
|
||||||
|
).register(self.patterns)
|
||||||
|
AllReduceRMSNormPattern(
|
||||||
epsilon,
|
epsilon,
|
||||||
self.model_dtype,
|
self.model_dtype,
|
||||||
self.device,
|
self.device,
|
||||||
@ -707,6 +1146,10 @@ class AllReduceFusionPass(VllmInductorPass):
|
|||||||
self.allreduce_params,
|
self.allreduce_params,
|
||||||
).register(self.patterns)
|
).register(self.patterns)
|
||||||
|
|
||||||
|
# WARNING: This is a hack to clear the pattern matcher cache
|
||||||
|
# and allow multiple values of epsilon.
|
||||||
|
torch._inductor.pattern_matcher._seen_patterns.clear()
|
||||||
|
|
||||||
self.disabled = False
|
self.disabled = False
|
||||||
|
|
||||||
def __call__(self, graph: fx.Graph):
|
def __call__(self, graph: fx.Graph):
|
||||||
@ -723,5 +1166,5 @@ class AllReduceFusionPass(VllmInductorPass):
|
|||||||
if self.disabled:
|
if self.disabled:
|
||||||
return
|
return
|
||||||
if flashinfer_comm is not None:
|
if flashinfer_comm is not None:
|
||||||
flashinfer_comm.trtllm_destroy_ipc_workspace(
|
flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(
|
||||||
self.ipc_handles, self.group)
|
self.ipc_handles, self.group)
|
||||||
|
|||||||
@ -4062,7 +4062,7 @@ class PassConfig:
|
|||||||
"""Whether to enable async TP."""
|
"""Whether to enable async TP."""
|
||||||
enable_fi_allreduce_fusion: bool = False
|
enable_fi_allreduce_fusion: bool = False
|
||||||
"""Whether to enable flashinfer allreduce fusion."""
|
"""Whether to enable flashinfer allreduce fusion."""
|
||||||
fi_allreduce_fusion_max_token_num: int = 1024
|
fi_allreduce_fusion_max_token_num: int = 16384
|
||||||
"""Max number of tokens to use in flashinfer allreduce fusion."""
|
"""Max number of tokens to use in flashinfer allreduce fusion."""
|
||||||
|
|
||||||
# TODO(luka) better pass enabling system.
|
# TODO(luka) better pass enabling system.
|
||||||
|
|||||||
@ -18,6 +18,7 @@ from .mistral_tool_parser import MistralToolParser
|
|||||||
from .phi4mini_tool_parser import Phi4MiniJsonToolParser
|
from .phi4mini_tool_parser import Phi4MiniJsonToolParser
|
||||||
from .pythonic_tool_parser import PythonicToolParser
|
from .pythonic_tool_parser import PythonicToolParser
|
||||||
from .qwen3coder_tool_parser import Qwen3CoderToolParser
|
from .qwen3coder_tool_parser import Qwen3CoderToolParser
|
||||||
|
from .step3_tool_parser import Step3ToolParser
|
||||||
from .xlam_tool_parser import xLAMToolParser
|
from .xlam_tool_parser import xLAMToolParser
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -40,4 +41,5 @@ __all__ = [
|
|||||||
"HunyuanA13BToolParser",
|
"HunyuanA13BToolParser",
|
||||||
"Glm4MoeModelToolParser",
|
"Glm4MoeModelToolParser",
|
||||||
"Qwen3CoderToolParser",
|
"Qwen3CoderToolParser",
|
||||||
|
"Step3ToolParser",
|
||||||
]
|
]
|
||||||
|
|||||||
296
vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py
Normal file
296
vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py
Normal file
@ -0,0 +1,296 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import json
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
|
import regex as re
|
||||||
|
|
||||||
|
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||||
|
DeltaFunctionCall, DeltaMessage,
|
||||||
|
DeltaToolCall,
|
||||||
|
ExtractedToolCallInformation,
|
||||||
|
FunctionCall, ToolCall)
|
||||||
|
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||||
|
ToolParser, ToolParserManager)
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||||
|
from vllm.utils import random_uuid
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ToolParserManager.register_module(["step3"])
class Step3ToolParser(ToolParser):
    """Tool parser for the Step3 XML-like ("steptml") tool-call format.

    All tool calls live in a single block delimited by ``TOOL_CALLS_BEGIN``
    and ``TOOL_CALLS_END``. Inside it, every call is wrapped in
    ``TOOL_CALL_BEGIN`` ... ``TOOL_CALL_END`` and has the shape
    ``function<|tool_sep|><steptml:invoke name="...">`` followed by zero or
    more ``<steptml:parameter name="...">value</steptml:parameter>`` elements.

    Streaming is handled with a cursor (``self.position``) into the
    accumulated text plus two explicit state flags. For each tool call the
    function name is sent first, and the complete argument object is sent as
    one JSON string once the call's end marker has been seen.
    """

    TOOL_CALLS_BEGIN = "<|tool_calls_begin|>"
    TOOL_CALLS_END = "<|tool_calls_end|>"
    TOOL_CALL_BEGIN = "<|tool_call_begin|>"
    TOOL_CALL_END = "<|tool_call_end|>"
    TOOL_SEP = "<|tool_sep|>"
    SPECIAL_TOKENS = [
        TOOL_CALLS_BEGIN, TOOL_CALLS_END, TOOL_CALL_BEGIN, TOOL_CALL_END
    ]

    def __init__(self, tokenizer: AnyTokenizer):
        """Initialise the streaming cursor and state flags.

        NOTE(review): ``current_tool_id``, ``current_tool_name_sent`` and
        ``prev_tool_call_arr`` are used by the streaming path but are not set
        here; they are presumably initialised by ``ToolParser.__init__``.
        """
        super().__init__(tokenizer)
        # Index into ``current_text`` up to which the stream was consumed.
        self.position = 0
        # Explicit state flags for robust streaming
        self.tool_block_started = False
        self.tool_block_finished = False

    def adjust_request(
            self, request: ChatCompletionRequest) -> ChatCompletionRequest:
        """Disable special-token skipping when tools may be called.

        The parser matches the ``<|tool_...|>`` markers in the decoded text,
        so they must not be removed during detokenization.
        """
        if request.tools and request.tool_choice != 'none':
            request.skip_special_tokens = False
        return request

    @staticmethod
    def _parse_steptml_invoke(
        action_text: str
    ) -> tuple[Optional[str], Optional[dict[str, str]]]:
        """Extract the function name and raw string parameters of one call.

        Args:
            action_text: Text containing a ``<steptml:invoke ...>`` element;
                may be incomplete while streaming.

        Returns:
            ``(name, params)`` where ``params`` maps parameter names to their
            whitespace-stripped string values, or ``(None, None)`` when no
            ``<steptml:invoke name="...">`` tag is present. Only parameters
            whose closing tag is present are returned.
        """
        func_name_match = re.search(r'<steptml:invoke name="([^"]+)">',
                                    action_text)
        if not func_name_match:
            return None, None
        func_name = func_name_match.group(1)

        # A lazy match up to the first closing tag (instead of ``[^<]*``)
        # keeps values that contain '<', e.g. "a < b" or embedded markup,
        # which would otherwise be dropped silently. DOTALL preserves the
        # previous ability to match multi-line values.
        param_matches = re.findall(
            r'<steptml:parameter name="([^"]+)">(.*?)</steptml:parameter>',
            action_text,
            flags=re.DOTALL)
        params: dict[str, str] = {
            name: value.strip()
            for name, value in param_matches
        }
        return func_name, params

    def _cast_arguments(
        self,
        func_name: str,
        params: dict[str, Any],
        request: ChatCompletionRequest,
    ) -> dict[str, Any]:
        """Convert string parameter values according to the tool's schema.

        The first tool in ``request.tools`` whose function name equals
        ``func_name`` supplies the JSON schema. For each string value, the
        declared ``type`` of the property decides the conversion: ``string``
        is stripped, ``integer``/``number`` go through ``int``/``float``,
        ``boolean`` accepts "true"/"false" (case-insensitive) and ``null``
        accepts "null". Values that cannot be converted, or whose type is
        not listed above, are left unchanged.

        Args:
            func_name: Name of the called function.
            params: Parsed parameters; updated in place.
            request: Request carrying the tool definitions.

        Returns:
            The same ``params`` dict.
        """
        for tool in request.tools or []:
            if tool.function.name != func_name:
                continue
            schema = tool.function.parameters or {}
            properties = schema.get("properties", {})
            # Only existing keys are reassigned, so mutating during
            # iteration does not change the dict's size.
            for key, value in params.items():
                if not isinstance(value, str):
                    continue
                typ = properties.get(key, {}).get("type")
                if typ == "string":
                    params[key] = value.strip()
                elif typ == "integer":
                    with contextlib.suppress(ValueError):
                        params[key] = int(value)
                elif typ == "number":
                    with contextlib.suppress(ValueError):
                        params[key] = float(value)
                elif typ == "boolean":
                    lower_val = value.lower()
                    if lower_val in ("true", "false"):
                        params[key] = lower_val == "true"
                elif typ == "null":
                    if value.lower() == "null":
                        params[key] = None
            # Only the first matching tool definition is consulted.
            break
        return params

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Produce the next streaming delta from the accumulated text.

        Only ``current_text`` and ``request`` are read; progress is tracked
        with ``self.position`` rather than with ``delta_text``.

        Returns:
            A ``DeltaMessage`` carrying plain content, a tool-call name, or a
            tool call's full JSON arguments; ``None`` when more text is
            needed before anything can be emitted.
        """
        # The main loop processes the stream from the last known position.
        while True:
            if self.position >= len(current_text):
                return None  # We've processed the entire stream.

            unprocessed_text = current_text[self.position:]

            # STATE: After all tools are done, all subsequent text is content.
            if self.tool_block_finished:
                self.position = len(current_text)
                return DeltaMessage(content=unprocessed_text)

            # STATE: Before the tool block has started.
            if not self.tool_block_started:
                if unprocessed_text.startswith(self.TOOL_CALLS_BEGIN):
                    self.position += len(self.TOOL_CALLS_BEGIN)
                    self.tool_block_started = True
                    continue  # Token consumed, re-loop.

                start_pos = unprocessed_text.find(self.TOOL_CALLS_BEGIN)
                if start_pos == -1:
                    if self.TOOL_CALLS_BEGIN.startswith(
                            unprocessed_text.strip()) and unprocessed_text:
                        return None  # It's a prefix, wait.
                    self.position = len(current_text)
                    return DeltaMessage(content=unprocessed_text)
                else:
                    # Emit the content that precedes the tool block; the
                    # marker itself is consumed on the next call.
                    content = unprocessed_text[:start_pos]
                    self.position += len(content)
                    return DeltaMessage(content=content)

            # STATE: Inside the main tool block.
            # Whitespace between markers carries no information; skip it.
            offset = len(unprocessed_text) - len(unprocessed_text.lstrip())
            unprocessed_text = unprocessed_text.lstrip()
            self.position += offset

            if unprocessed_text.startswith(self.TOOL_CALLS_END):
                self.position += len(self.TOOL_CALLS_END)
                self.tool_block_finished = True
                self.current_tool_id = -1
                continue

            # Check if we are between tool calls.
            tool_finished = (
                self.current_tool_id != -1 and
                self.prev_tool_call_arr[self.current_tool_id].get("finished"))
            if self.current_tool_id == -1 or tool_finished:
                if unprocessed_text.startswith(self.TOOL_CALL_BEGIN):
                    self.position += len(self.TOOL_CALL_BEGIN)
                    if self.current_tool_id == -1:
                        self.current_tool_id = 0
                    else:
                        self.current_tool_id += 1
                    self.current_tool_name_sent = False
                    while len(self.prev_tool_call_arr) <= self.current_tool_id:
                        self.prev_tool_call_arr.append({})
                    self.prev_tool_call_arr[
                        self.current_tool_id]["finished"] = False
                    continue

                # Possibly a partially received begin marker: wait.
                if self.TOOL_CALL_BEGIN.startswith(unprocessed_text):
                    return None

            # STATE: Parsing an active tool call.
            if self.current_tool_id != -1 and not self.prev_tool_call_arr[
                    self.current_tool_id].get("finished", False):
                end_tool_pos = unprocessed_text.find(self.TOOL_CALL_END)
                if end_tool_pos == -1:
                    tool_body = unprocessed_text
                else:
                    tool_body = unprocessed_text[:end_tool_pos]

                if end_tool_pos == -1 and self.TOOL_CALL_END.startswith(
                        tool_body):
                    return None

                function_name, arguments = self._parse_steptml_invoke(
                    tool_body)
                if not function_name:
                    return None

                tool_call_arr = {
                    "name": function_name,
                    "parameters": arguments or {}
                }

                # Send the function name as soon as it's parsed.
                if not self.current_tool_name_sent:
                    self.current_tool_name_sent = True
                    self.prev_tool_call_arr[self.current_tool_id].update(
                        tool_call_arr)
                    return DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      type="function",
                                      id=f"chatcmpl-tool-{random_uuid()}",
                                      function=DeltaFunctionCall(
                                          name=function_name))
                    ])

                # Update our internal state with the latest parsed arguments.
                self.prev_tool_call_arr[self.current_tool_id].update(
                    tool_call_arr)

                # Only send arguments when the tool call is complete.
                if end_tool_pos != -1:
                    self.position += end_tool_pos + len(self.TOOL_CALL_END)
                    self.prev_tool_call_arr[
                        self.current_tool_id]["finished"] = True

                    final_args = self._cast_arguments(
                        function_name,
                        tool_call_arr.get("parameters", {}),  # type: ignore
                        request)
                    # Always emit the argument object, even when it is empty
                    # ("{}"), so a parameterless call still streams valid
                    # JSON arguments, matching ``extract_tool_calls``.
                    final_args_json = json.dumps(final_args,
                                                 ensure_ascii=False)
                    return DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      function=DeltaFunctionCall(
                                          arguments=final_args_json))
                    ])

                # If tool is not finished, return None to wait for more tokens.
                return None

            return None

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Parse a complete model output into tool calls and content.

        Returns:
            ``tools_called=True`` with the parsed calls and the stripped text
            surrounding the tool block (``None`` if empty) when at least one
            call was parsed. Otherwise ``tools_called=False`` with the
            unmodified ``model_output`` as content; this also covers a
            missing begin or end marker.
        """
        if self.TOOL_CALLS_BEGIN not in model_output:
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

        pre_text, rest = model_output.split(self.TOOL_CALLS_BEGIN, 1)
        if self.TOOL_CALLS_END not in rest:
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

        tool_block, post_text = rest.split(self.TOOL_CALLS_END, 1)
        content = (pre_text + post_text).strip()

        tool_calls: list[ToolCall] = []
        call_parts = tool_block.split(self.TOOL_CALL_BEGIN)

        for part in call_parts:
            # Skip empty pieces and calls without an end marker.
            if not part or self.TOOL_CALL_END not in part:
                continue

            call_content = part.split(self.TOOL_CALL_END, 1)[0]
            if self.TOOL_SEP not in call_content:
                continue

            type_part, invoke_part = call_content.split(self.TOOL_SEP, 1)
            if type_part.strip() != "function":
                continue

            function_name, params_dict = self._parse_steptml_invoke(
                invoke_part)

            if function_name and params_dict is not None:
                params_dict = self._cast_arguments(function_name, params_dict,
                                                   request)
                params_str = json.dumps(params_dict, ensure_ascii=False)
                tool_calls.append(
                    ToolCall(function=FunctionCall(name=function_name,
                                                   arguments=params_str)))
        if tool_calls:
            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=content if content else None)
        return ExtractedToolCallInformation(tools_called=False,
                                            tool_calls=[],
                                            content=model_output)
|
||||||
@ -182,6 +182,8 @@ class DeepSeekMTP(nn.Module, SupportsPP):
|
|||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
("gate_up_proj", "gate_proj", 0),
|
("gate_up_proj", "gate_proj", 0),
|
||||||
("gate_up_proj", "up_proj", 1),
|
("gate_up_proj", "up_proj", 1),
|
||||||
|
("fused_qkv_a_proj", "q_a_proj", 0),
|
||||||
|
("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
|
||||||
]
|
]
|
||||||
|
|
||||||
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
||||||
@ -212,6 +214,13 @@ class DeepSeekMTP(nn.Module, SupportsPP):
|
|||||||
if (("mlp.experts." in name) and name not in params_dict):
|
if (("mlp.experts." in name) and name not in params_dict):
|
||||||
continue
|
continue
|
||||||
name = name.replace(weight_name, param_name)
|
name = name.replace(weight_name, param_name)
|
||||||
|
|
||||||
|
# QKV fusion is optional, fall back to normal
|
||||||
|
# weight loading if it's not enabled
|
||||||
|
if ((param_name == "fused_qkv_a_proj")
|
||||||
|
and name not in params_dict):
|
||||||
|
continue
|
||||||
|
|
||||||
# Skip loading extra bias for GPTQ models.
|
# Skip loading extra bias for GPTQ models.
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
continue
|
continue
|
||||||
|
|||||||
@ -256,6 +256,7 @@ class Llama4DecoderLayer(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.layer_idx = extract_layer_index(prefix)
|
self.layer_idx = extract_layer_index(prefix)
|
||||||
|
self.global_layer = config.no_rope_layers[self.layer_idx] == 0
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
rope_theta = config.rope_theta
|
rope_theta = config.rope_theta
|
||||||
rope_scaling = config.rope_scaling
|
rope_scaling = config.rope_scaling
|
||||||
|
|||||||
@ -37,8 +37,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
from vllm.model_executor.models.llama4 import (Llama4DecoderLayer,
|
from vllm.model_executor.models.llama4 import (Llama4DecoderLayer,
|
||||||
Llama4ForCausalLM)
|
Llama4ForCausalLM)
|
||||||
from vllm.model_executor.models.utils import extract_layer_index
|
from vllm.model_executor.models.utils import extract_layer_index
|
||||||
|
from vllm.multimodal.inputs import NestedTensors
|
||||||
|
|
||||||
from .utils import AutoWeightsLoader, maybe_prefix
|
from .utils import AutoWeightsLoader, maybe_prefix, merge_multimodal_embeddings
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -78,15 +79,23 @@ class LlamaModel(nn.Module):
|
|||||||
self.norm = RMSNorm(self.config.hidden_size,
|
self.norm = RMSNorm(self.config.hidden_size,
|
||||||
eps=self.config.rms_norm_eps)
|
eps=self.config.rms_norm_eps)
|
||||||
|
|
||||||
|
def get_input_embeddings(
    self,
    input_ids: torch.Tensor,
) -> torch.Tensor:
    """Return the result of ``self.embed_tokens`` applied to ``input_ids``."""
    token_embeddings = self.embed_tokens(input_ids)
    return token_embeddings
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.Tensor],
|
input_ids: Optional[torch.Tensor],
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
input_embeds = self.embed_tokens(input_ids)
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.get_input_embeddings(input_ids)
|
||||||
hidden_states = self.fc(
|
hidden_states = self.fc(
|
||||||
torch.cat((input_embeds, hidden_states), dim=-1))
|
torch.cat((inputs_embeds, hidden_states), dim=-1))
|
||||||
residual = None
|
residual = None
|
||||||
for layer in self.layers:
|
for layer in self.layers:
|
||||||
hidden_states, residual = layer(
|
hidden_states, residual = layer(
|
||||||
@ -190,8 +199,9 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
|
|||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
return self.model(input_ids, positions, hidden_states)
|
return self.model(input_ids, positions, hidden_states, inputs_embeds)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[tuple[str,
|
def load_weights(self, weights: Iterable[tuple[str,
|
||||||
torch.Tensor]]) -> None:
|
torch.Tensor]]) -> None:
|
||||||
@ -212,3 +222,20 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
|
|||||||
model_weights[name] = loaded_weight
|
model_weights[name] = loaded_weight
|
||||||
|
|
||||||
loader.load_weights(model_weights.items())
|
loader.load_weights(model_weights.items())
|
||||||
|
|
||||||
|
def get_input_embeddings(
    self,
    input_ids: torch.Tensor,
    multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
    """Embed ``input_ids`` and optionally merge multimodal embeddings.

    The text embeddings come from ``self.model.get_input_embeddings``. When
    ``multimodal_embeddings`` is given, the result of
    ``merge_multimodal_embeddings`` (called with
    ``self.config.image_token_index`` as the placeholder token id) is
    returned instead.
    """
    text_embeds = self.model.get_input_embeddings(input_ids)

    if multimodal_embeddings is None:
        return text_embeds

    return merge_multimodal_embeddings(
        input_ids,
        text_embeds,
        multimodal_embeddings,
        self.config.image_token_index,
    )
|
||||||
|
|||||||
@ -2,6 +2,7 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -148,7 +149,12 @@ class EagleLlamaForCausalLM(LlamaForCausalLM):
|
|||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
if inputs_embeds is not None:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"{type(self).__name__} does not support multimodal inputs yet."
|
||||||
|
)
|
||||||
return self.model(input_ids, positions, hidden_states)
|
return self.model(input_ids, positions, hidden_states)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||||
|
|||||||
@ -202,7 +202,12 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
|
|||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
if inputs_embeds is not None:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"{type(self).__name__} does not support multimodal inputs yet."
|
||||||
|
)
|
||||||
return self.model(input_ids, positions, hidden_states)
|
return self.model(input_ids, positions, hidden_states)
|
||||||
|
|
||||||
def compute_logits(
|
def compute_logits(
|
||||||
|
|||||||
@ -129,6 +129,7 @@ _TEXT_GENERATION_MODELS = {
|
|||||||
"Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
|
"Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
|
||||||
"Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
|
"Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
|
||||||
"RWForCausalLM": ("falcon", "FalconForCausalLM"),
|
"RWForCausalLM": ("falcon", "FalconForCausalLM"),
|
||||||
|
"Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
|
||||||
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
|
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
|
||||||
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
|
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
|
||||||
"Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
|
"Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
|
||||||
@ -238,6 +239,7 @@ _MULTIMODAL_MODELS = {
|
|||||||
"Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501
|
"Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501
|
||||||
"Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501
|
"Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501
|
||||||
"UltravoxModel": ("ultravox", "UltravoxModel"),
|
"UltravoxModel": ("ultravox", "UltravoxModel"),
|
||||||
|
"Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"), # noqa: E501
|
||||||
"TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"), # noqa: E501
|
"TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"), # noqa: E501
|
||||||
"Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"), # noqa: E501
|
"Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"), # noqa: E501
|
||||||
"VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"), # noqa: E501
|
"VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"), # noqa: E501
|
||||||
|
|||||||
521
vllm/model_executor/models/step3_text.py
Normal file
521
vllm/model_executor/models/step3_text.py
Normal file
@ -0,0 +1,521 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""Inference-only Jurassic model."""
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from vllm.attention import Attention
|
||||||
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
|
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||||
|
from vllm.distributed import (get_pp_group,
|
||||||
|
get_tensor_model_parallel_world_size,
|
||||||
|
tensor_model_parallel_all_reduce)
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
|
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||||
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
|
MergedColumnParallelLinear,
|
||||||
|
ReplicatedLinear,
|
||||||
|
RowParallelLinear)
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
|
QuantizationConfig)
|
||||||
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||||
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
|
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||||
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
|
from .interfaces import SupportsPP
|
||||||
|
from .utils import (PPMissingLayer, is_pp_missing_parameter,
|
||||||
|
make_empty_intermediate_tensors_factory, make_layers)
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class FusedMoEBlock(nn.Module):
    """Mixture-of-experts block: a replicated router plus fused experts.

    The experts are built with ``reduce_results=False``; the all-reduce over
    the tensor-parallel group is done explicitly in ``forward``.

    NOTE(review): ``config`` is annotated as ``ModelConfig`` but only the
    attributes ``moe_num_experts``, ``moe_top_k``, ``hidden_size``,
    ``moe_intermediate_size`` and ``norm_expert_weight`` are read - confirm
    which config object callers actually pass.
    """

    def __init__(self,
                 config: ModelConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()

        num_experts = config.moe_num_experts
        if num_experts < self.tp_size:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {num_experts}.")

        expert_kwargs = dict(
            num_experts=num_experts,
            top_k=config.moe_top_k,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=config.norm_expert_weight,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
        )
        self.experts = FusedMoE(**expert_kwargs)

        # The router is built without a quantization config.
        self.gate = ReplicatedLinear(
            config.hidden_size,
            num_experts,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route the tokens through the experts; output keeps the input shape."""
        input_shape = hidden_states.shape
        # Flatten all leading dimensions into a single token dimension.
        tokens = hidden_states.view(-1, input_shape[-1])

        router_logits, _ = self.gate(tokens)
        expert_out = self.experts(hidden_states=tokens,
                                  router_logits=router_logits)

        if self.tp_size > 1:
            expert_out = tensor_model_parallel_all_reduce(expert_out)

        return expert_out.view(input_shape)
|
||||||
|
|
||||||
|
|
||||||
|
class Step3TextMLP(nn.Module):
    """Gated feed-forward block: merged gate/up projection, SiLU-and-mul
    activation, then a down projection.

    Raises:
        ValueError: If ``hidden_act`` is anything other than ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Gate and up projections share one merged column-parallel layer.
        merged_output_sizes = [intermediate_size] * 2
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            merged_output_sizes,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        if not hidden_act == "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()
        self.hidden_size = hidden_size

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, activation and down projection."""
        # Each linear layer returns a pair; only its first element is used.
        merged, _ = self.gate_up_proj(hidden_states)
        projected, _ = self.down_proj(self.act_fn(merged))
        return projected
|
||||||
|
|
||||||
|
|
||||||
|
class Step3TextAttention(nn.Module):
    """Attention layer with one KV head and a two-stage query projection.

    ``qkv_proj`` (replicated) yields a query of width ``q_size`` plus key and
    value of width ``kv_size`` each. The query is RMS-normalised and then
    expanded by the column-parallel ``wq`` to
    ``head_dim * total_num_heads`` before rotary embedding and attention.

    Raises:
        AssertionError: If ``num_heads`` is not divisible by the
            tensor-parallel world size.
        ValueError: If ``num_kv_heads`` is not 1.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        norm_eps: float,
        rope_theta: int,
        share_q_dim: Optional[int] = None,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embedding: int = 8192,
        head_dim: int = 256,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = hidden_size
        world_size = get_tensor_model_parallel_world_size()

        # Query heads are sharded evenly across tensor-parallel ranks.
        self.total_num_heads = num_heads
        assert self.total_num_heads % world_size == 0
        self.num_heads = self.total_num_heads // world_size

        if not num_kv_heads == 1:
            raise ValueError(f"Step3TextAttention num_kv_heads must be 1, "
                             f"but got {num_kv_heads}.")
        self.num_kv_heads = num_kv_heads

        self.head_dim = head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        # Falls back to ``head_dim`` when ``share_q_dim`` is None or 0.
        self.q_size = share_q_dim or self.head_dim

        fused_width = self.q_size + self.kv_size * 2
        self.qkv_proj = ReplicatedLinear(
            hidden_size,
            fused_width,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        attn_width = self.total_num_heads * self.head_dim
        self.o_proj = RowParallelLinear(
            attn_width,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.inter_norm = RMSNorm(self.q_size, eps=norm_eps)
        self.wq = ColumnParallelLinear(
            self.q_size,
            self.head_dim * self.total_num_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.wq",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embedding,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        softmax_scale = self.head_dim**-0.5
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            softmax_scale,
            self.num_kv_heads,
            cache_config=cache_config,
            prefix=f"{prefix}.attn",
        )

    def forward(self, positions: torch.Tensor,
                hidden_states: torch.Tensor) -> torch.Tensor:
        """Run attention over ``hidden_states`` and return the ``o_proj``
        output."""
        fused, _ = self.qkv_proj(hidden_states)
        section_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused.split(section_sizes, dim=-1)
        # Normalise the compact query, then expand it to all local heads.
        query = self.wq(self.inter_norm(query))[0]
        query, key = self.rotary_emb(positions, query, key)
        context = self.attn(query, key, value)
        projected, _ = self.o_proj(context)
        return projected
|
||||||
|
|
||||||
|
|
||||||
|
class Step3TextDecoderLayer(nn.Module):
    """A single Step3 text decoder layer.

    The layer runs self-attention and then either a dense MLP, or a fused
    MoE block whose output is summed with a shared expert. Which of the two
    is built depends on whether this layer's index (parsed from ``prefix``)
    is listed in the config's ``moe_layers_enum``.
    """

    def __init__(self,
                 config: ModelConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "") -> None:
        """Build the attention, feed-forward and normalization sub-modules.

        Args:
            config: Model config; only its ``hf_config`` is read here.
            cache_config: Forwarded to the attention module.
            quant_config: Forwarded to every linear / MoE sub-module.
            prefix: Dotted module name. It must contain ``layers.<idx>``,
                from which the layer index is parsed.
        """
        super().__init__()
        config = config.hf_config
        self.hidden_size = config.hidden_size
        rope_scaling = getattr(config, "rope_scaling", None)

        self.self_attn = Step3TextAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=1,
            cache_config=cache_config,
            quant_config=quant_config,
            norm_eps=config.rms_norm_eps,
            max_position_embedding=config.max_position_embedding,
            head_dim=config.head_dim,
            share_q_dim=config.share_q_dim,
            rope_theta=config.rope_theta,
            rope_scaling=rope_scaling,
            prefix=f"{prefix}.self_attn")

        layer_idx = int(prefix.split("layers.")[1].split(".")[0])
        moe_layers_enum = getattr(config, "moe_layers_enum", None)
        if moe_layers_enum is None:
            # Default: only the first layer is dense, all others are MoE.
            moe_layers_idx = list(range(1, config.num_hidden_layers))
        elif isinstance(moe_layers_enum, str):
            # Comma-separated list of layer indices, e.g. "4,5,6".
            moe_layers_idx = [
                int(i) for i in moe_layers_enum.strip().split(',')
            ]
        else:
            # Already a sequence of indices (Step3TextConfig defaults to a
            # tuple of ints); calling str methods on it would raise.
            moe_layers_idx = [int(i) for i in moe_layers_enum]

        if layer_idx in moe_layers_idx:
            self.moe = FusedMoEBlock(config=config,
                                     quant_config=quant_config,
                                     prefix=f"{prefix}.moe")
            self.share_expert = Step3TextMLP(
                hidden_size=self.hidden_size,
                intermediate_size=config.share_expert_dim,
                hidden_act="silu",
                quant_config=quant_config,
                prefix=f"{prefix}.share_expert")
            self.use_moe = True
        else:
            self.mlp = Step3TextMLP(hidden_size=config.hidden_size,
                                    intermediate_size=config.intermediate_size,
                                    hidden_act="silu",
                                    quant_config=quant_config,
                                    prefix=f"{prefix}.mlp")
            self.use_moe = False
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self, positions: torch.Tensor, hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the layer and return ``(hidden_states, residual)``.

        Args:
            positions: Positions forwarded to the attention module.
            hidden_states: Input activations.
            residual: Residual stream from the previous layer, or ``None``
                when there is none yet; in that case the input itself
                becomes the residual.

        Returns:
            The feed-forward output and the updated residual. The two are
            not summed here.
        """
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # Called with a residual, the norm returns both the normalized
            # activations and the updated residual.
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)

        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)

        if self.use_moe:
            share_output = self.share_expert(hidden_states)
            moe_output = self.moe(hidden_states)
            hidden_states = share_output + moe_output
        else:
            hidden_states = self.mlp(hidden_states)

        return hidden_states, residual
|
||||||
|
|
||||||
|
|
||||||
|
@support_torch_compile
class Step3TextModel(nn.Module):
    """Step3 text backbone: token embedding, decoder layers and final norm.

    Sub-modules that do not belong to the current pipeline-parallel rank
    are replaced by ``PPMissingLayer`` placeholders.
    """

    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the sub-modules owned by the current pipeline rank.

        Args:
            vllm_config: Source of the HF model config, cache config and
                quantization config.
            prefix: Dotted module-name prefix; layers are created under
                ``{prefix}.layers``.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.vocab_size = config.vocab_size
        self.config = config

        # The embedding lives on the first rank, and also on the last rank
        # when word embeddings are tied.
        if get_pp_group().is_first_rank or (config.tie_word_embeddings
                                            and get_pp_group().is_last_rank):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
            )
        else:
            self.embed_tokens = PPMissingLayer()

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Step3TextDecoderLayer(config=vllm_config.
                                                 model_config,
                                                 cache_config=cache_config,
                                                 quant_config=quant_config,
                                                 prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        # forward() exchanges both "hidden_states" and "residual" between
        # pipeline ranks, so the empty placeholder tensors must provide
        # both keys; otherwise a non-first rank fails on the "residual"
        # lookup.
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the decoder layers owned by this pipeline rank.

        Args:
            input_ids: Token ids; used on the first rank when
                ``inputs_embeds`` is not given.
            positions: Positions forwarded to every layer.
            intermediate_tensors: ``hidden_states`` and ``residual`` from
                the previous rank; required on every non-first rank.
            inputs_embeds: Optional precomputed embeddings that replace the
                embedding lookup on the first rank.

        Returns:
            The normalized hidden states on the last rank; on any other
            rank, an ``IntermediateTensors`` holding ``hidden_states`` and
            ``residual`` for the next rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual,
            })

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class Step3TextForCausalLM(nn.Module, SupportsPP):
    """Step3 text model with a language-model head.

    The LM head, logits processor and sampler are only created on the last
    pipeline-parallel rank; other ranks hold a ``PPMissingLayer`` head.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        """Build the backbone and, on the last pipeline rank, the LM head.

        Args:
            vllm_config: Source of the HF model config and LoRA config.
            prefix: Dotted module-name prefix passed to the backbone.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        lora_config = vllm_config.lora_config
        self.config = config
        self.vllm_config = vllm_config

        self.model = Step3TextModel(vllm_config=vllm_config, prefix=prefix)

        if get_pp_group().is_last_rank:
            self.unpadded_vocab_size = config.vocab_size
            if lora_config:
                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=DEFAULT_VOCAB_PADDING_SIZE
                if not lora_config else lora_config.lora_vocab_padding_size,
            )
            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                    config.vocab_size)
            self.sampler = get_sampler()
        else:
            self.lm_head = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None):
        """Delegate to the backbone and return its output unchanged."""
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Compute logits from ``hidden_states`` through the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Run the sampler on ``logits`` and return its output."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights into this module's parameters.

        Each checkpoint entry is tried against four routes, in order:
        stacked gate/up projections, stacked MoE expert weights, slices of
        the fused ``qkv_proj`` parameter, and finally a direct load by
        name. Entries for parameters that live on another pipeline rank
        are skipped.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The names of the parameters that were loaded.
        """
        # Offsets of the q/k/v slices inside the fused projection, counted
        # in units of the unscaled fused width. They are scaled to each
        # parameter's actual output dim with exact integer arithmetic;
        # multiplying float fractions and truncating with int() can land
        # one element short of the intended boundary.
        q_dim = self.config.share_q_dim
        head_dim = self.config.head_dim
        fused_dim = q_dim + head_dim * 2
        qkv_params_mapping = [
            # (param_name, shard_name, start_offset, end_offset)
            (".qkv_proj", ".q_proj", 0, q_dim),
            (".qkv_proj", ".k_proj", q_dim, q_dim + head_dim),
            (".qkv_proj", ".v_proj", q_dim + head_dim, fused_dim),
        ]
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        expert_params_mapping = [
            (".moe.experts.w13_weight", ".moe.gate_proj.weight", "w1"),
            (".moe.experts.w13_weight", ".moe.up_proj.weight", "w3"),
            (".moe.experts.w2_weight", ".moe.down_proj.weight", "w2")
        ]

        # MoE gate/up weights also contain ".gate_proj"/".up_proj"; keep
        # them out of the dense stacked-parameter route.
        disable_moe_stacked_params = [
            data[1] for data in expert_params_mapping
        ]

        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                if any(disable_moe_stacked_param in name
                       for disable_moe_stacked_param in
                       disable_moe_stacked_params):
                    continue
                name = name.replace(weight_name, param_name)
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(name)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Skip loading extra bias for GPTQ models.
                    if ((name.endswith(".bias") or name.endswith("_bias"))
                            and name not in params_dict):
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    # The leading dimension of the checkpoint tensor is
                    # iterated as the expert index.
                    for expert_id in range(loaded_weight.shape[0]):
                        loaded_weight_expert = loaded_weight[expert_id]
                        weight_loader(param,
                                      loaded_weight_expert,
                                      name,
                                      shard_id=shard_id,
                                      expert_id=expert_id)
                    loaded_params.add(name)
                    break
                else:
                    for (param_name, weight_name, start_off,
                         end_off) in qkv_params_mapping:
                        if weight_name not in name:
                            continue
                        name = name.replace(weight_name, param_name)
                        if is_pp_missing_parameter(name, self):
                            continue
                        param = params_dict[name]
                        dim = param.shape[param.output_dim]
                        begin_idx = start_off * dim // fused_dim
                        stop_idx = end_off * dim // fused_dim
                        param_slice = param.narrow(param.output_dim, begin_idx,
                                                   stop_idx - begin_idx)
                        param_slice.copy_(loaded_weight)
                        loaded_params.add(name)
                        break
                    else:
                        if is_pp_missing_parameter(name, self):
                            continue
                        param = params_dict[name]
                        weight_loader = getattr(param, "weight_loader",
                                                default_weight_loader)
                        weight_loader(param, loaded_weight)
                        loaded_params.add(name)
        return loaded_params
|
||||||
1052
vllm/model_executor/models/step3_vl.py
Normal file
1052
vllm/model_executor/models/step3_vl.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -8,6 +8,7 @@ from .granite_reasoning_parser import GraniteReasoningParser
|
|||||||
from .hunyuan_a13b_reasoning_parser import HunyuanA13BReasoningParser
|
from .hunyuan_a13b_reasoning_parser import HunyuanA13BReasoningParser
|
||||||
from .mistral_reasoning_parser import MistralReasoningParser
|
from .mistral_reasoning_parser import MistralReasoningParser
|
||||||
from .qwen3_reasoning_parser import Qwen3ReasoningParser
|
from .qwen3_reasoning_parser import Qwen3ReasoningParser
|
||||||
|
from .step3_reasoning_parser import Step3ReasoningParser
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ReasoningParser",
|
"ReasoningParser",
|
||||||
@ -18,4 +19,5 @@ __all__ = [
|
|||||||
"Qwen3ReasoningParser",
|
"Qwen3ReasoningParser",
|
||||||
"Glm4MoeModelReasoningParser",
|
"Glm4MoeModelReasoningParser",
|
||||||
"MistralReasoningParser",
|
"MistralReasoningParser",
|
||||||
|
"Step3ReasoningParser",
|
||||||
]
|
]
|
||||||
|
|||||||
109
vllm/reasoning/step3_reasoning_parser.py
Normal file
109
vllm/reasoning/step3_reasoning_parser.py
Normal file
@ -0,0 +1,109 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import regex as re
|
||||||
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
|
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||||
|
DeltaMessage)
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.reasoning import ReasoningParser, ReasoningParserManager
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ReasoningParserManager.register_module("step3")
class Step3ReasoningParser(ReasoningParser):
    """
    Reasoning parser for the Step3 model.

    Step3 marks the end of its reasoning text with a </think> token. Text
    before the first </think> is reported as reasoning content and text
    after it as regular content.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        """Resolve the </think> token id from the tokenizer vocabulary.

        Raises:
            ValueError: If no tokenizer is available.
            RuntimeError: If the vocabulary has no </think> entry.
        """
        super().__init__(tokenizer)
        self.think_end_token = "</think>"

        self.reasoning_regex = re.compile(rf"(.*?){self.think_end_token}",
                                          re.DOTALL)

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ReasoningParser "
                "constructor during construction.")

        self.think_end_token_id = self.vocab.get(self.think_end_token)
        if self.think_end_token_id is None:
            raise RuntimeError(
                "Step3 reasoning parser could not locate think end "
                "token in the tokenizer!")

    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> Union[DeltaMessage, None]:
        """
        Classify one streamed delta as reasoning and/or content.

        The decision is made on token ids. For the text "abc</think>xyz":
        - 'abc' goes to reasoning_content
        - 'xyz' goes to content

        Returns None when the delta is exactly the </think> token.
        """
        end_id = self.think_end_token_id

        # A delta that is only the end marker carries no text to emit.
        if len(delta_token_ids) == 1 and delta_token_ids[0] == end_id:
            return None

        if end_id in delta_token_ids:
            # The marker is inside this delta: split the text around it.
            marker = self.think_end_token
            split_at = delta_text.find(marker)
            reasoning_part = delta_text[:split_at]
            content_part = delta_text[split_at + len(marker):]
            return DeltaMessage(reasoning_content=reasoning_part,
                                content=content_part or None)

        if end_id in previous_token_ids:
            # The marker was already emitted earlier: all of it is content.
            return DeltaMessage(content=delta_text)

        # No marker seen so far: still reasoning.
        return DeltaMessage(reasoning_content=delta_text)

    def extract_reasoning_content(
            self, model_output: str, request: ChatCompletionRequest
    ) -> tuple[Optional[str], Optional[str]]:
        """Split a complete output into ``(reasoning_content, content)``.

        Without a </think> marker the whole output is reasoning and the
        content is None. Otherwise the output is split at the first marker;
        empty content after the marker is reported as None.
        """
        reasoning_part, marker, content_part = model_output.partition(
            self.think_end_token)
        if not marker:
            return model_output, None
        return reasoning_part, content_part or None

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        """Return whether the </think> token id occurs in ``input_ids``."""
        return self.think_end_token_id in input_ids

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """Return the token ids that follow the first </think> token.

        An empty list is returned when the marker is absent or is only
        present as the last token.
        """
        end_id = self.think_end_token_id
        if end_id not in input_ids[:-1]:
            return []
        return input_ids[input_ids.index(end_id) + 1:]
|
||||||
@ -35,7 +35,8 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DeepseekVLV2Config,
|
|||||||
MllamaConfig, MLPSpeculatorConfig,
|
MllamaConfig, MLPSpeculatorConfig,
|
||||||
Nemotron_Nano_VL_Config,
|
Nemotron_Nano_VL_Config,
|
||||||
NemotronConfig, NVLM_D_Config,
|
NemotronConfig, NVLM_D_Config,
|
||||||
RWConfig, UltravoxConfig)
|
RWConfig, Step3TextConfig,
|
||||||
|
Step3VLConfig, UltravoxConfig)
|
||||||
# yapf: enable
|
# yapf: enable
|
||||||
from vllm.transformers_utils.configs.mistral import adapt_config_dict
|
from vllm.transformers_utils.configs.mistral import adapt_config_dict
|
||||||
from vllm.transformers_utils.utils import check_gguf_file
|
from vllm.transformers_utils.utils import check_gguf_file
|
||||||
@ -83,6 +84,8 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = {
|
|||||||
"nemotron": NemotronConfig,
|
"nemotron": NemotronConfig,
|
||||||
"NVLM_D": NVLM_D_Config,
|
"NVLM_D": NVLM_D_Config,
|
||||||
"ultravox": UltravoxConfig,
|
"ultravox": UltravoxConfig,
|
||||||
|
"step3_vl": Step3VLConfig,
|
||||||
|
"step3_text": Step3TextConfig,
|
||||||
**_CONFIG_REGISTRY_OVERRIDE_HF
|
**_CONFIG_REGISTRY_OVERRIDE_HF
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -24,6 +24,9 @@ from vllm.transformers_utils.configs.nemotron import NemotronConfig
|
|||||||
from vllm.transformers_utils.configs.nemotron_h import NemotronHConfig
|
from vllm.transformers_utils.configs.nemotron_h import NemotronHConfig
|
||||||
from vllm.transformers_utils.configs.nemotron_vl import Nemotron_Nano_VL_Config
|
from vllm.transformers_utils.configs.nemotron_vl import Nemotron_Nano_VL_Config
|
||||||
from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
|
from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
|
||||||
|
from vllm.transformers_utils.configs.step3_vl import (Step3TextConfig,
|
||||||
|
Step3VisionEncoderConfig,
|
||||||
|
Step3VLConfig)
|
||||||
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
|
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -42,4 +45,7 @@ __all__ = [
|
|||||||
"Nemotron_Nano_VL_Config",
|
"Nemotron_Nano_VL_Config",
|
||||||
"NVLM_D_Config",
|
"NVLM_D_Config",
|
||||||
"UltravoxConfig",
|
"UltravoxConfig",
|
||||||
|
"Step3VLConfig",
|
||||||
|
"Step3VisionEncoderConfig",
|
||||||
|
"Step3TextConfig",
|
||||||
]
|
]
|
||||||
|
|||||||
123
vllm/transformers_utils/configs/step3_vl.py
Normal file
123
vllm/transformers_utils/configs/step3_vl.py
Normal file
@ -0,0 +1,123 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
|
|
||||||
|
|
||||||
|
class Step3VisionEncoderConfig(PretrainedConfig):
    """Configuration values for the Step3 vision encoder.

    Every constructor argument is stored as an attribute of the same name;
    remaining keyword arguments are forwarded to ``PretrainedConfig``.
    """

    model_type = "step3_vision_encoder"

    def __init__(
        self,
        hidden_size=1792,
        intermediate_size=3072,
        output_hidden_size=4096,
        num_hidden_layers=63,
        num_attention_heads=16,
        num_channels=3,
        image_size=728,
        patch_size=14,
        hidden_act="quick_gelu",
        layer_norm_eps=1e-5,
        **kwargs,
    ):
        # Layer sizes.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.output_hidden_size = output_hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # Image input layout.
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.image_size = image_size

        # Normalization and activation.
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act

        super().__init__(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class Step3TextConfig(PretrainedConfig):
    """Configuration values for the Step3 text model.

    Every constructor argument is stored as an attribute of the same name;
    remaining keyword arguments are forwarded to ``PretrainedConfig``.
    """

    model_type = "step3_text"
    architectures = ["Step3TextForCausalLM"]

    def __init__(
        self,
        hidden_size: int = 7168,
        intermediate_size: int = 18432,
        num_attention_heads: int = 64,
        num_attention_groups: int = 1,
        num_hidden_layers: int = 61,
        max_seq_len: int = 65536,
        vocab_size: int = 128815,
        rms_norm_eps: float = 1e-5,
        moe_intermediate_size: int = 5120,
        moe_num_experts: int = 48,
        moe_top_k: int = 3,
        rope_theta: float = 500000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embedding: int = 65536,
        share_expert_dim: int = 5120,
        share_q_dim: int = 2048,
        head_dim: int = 256,
        norm_expert_weight: bool = False,
        # Default is the contiguous run of layer indices 4..59.
        moe_layers_enum: tuple[int, ...] = tuple(range(4, 60)),
        **kwargs,
    ) -> None:
        # Core transformer dimensions.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.num_attention_groups = num_attention_groups
        self.num_hidden_layers = num_hidden_layers
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.rms_norm_eps = rms_norm_eps

        # Mixture-of-experts settings.
        self.moe_intermediate_size = moe_intermediate_size
        self.moe_num_experts = moe_num_experts
        self.moe_top_k = moe_top_k

        # Rotary embedding settings.
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.max_position_embedding = max_position_embedding

        # Shared expert, attention head and MoE layer selection settings.
        self.share_expert_dim = share_expert_dim
        self.share_q_dim = share_q_dim
        self.head_dim = head_dim
        self.norm_expert_weight = norm_expert_weight
        self.moe_layers_enum = moe_layers_enum

        super().__init__(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class Step3VLConfig(PretrainedConfig):
    """Top-level Step3 vision-language configuration.

    It bundles a ``Step3VisionEncoderConfig`` and a ``Step3TextConfig``.
    Either one may be passed as a config object or a dict of constructor
    arguments, or omitted to use that sub-config's defaults.
    """

    model_type = "step3_vl"

    def __init__(
        self,
        vision_config: Optional[Union[dict, Step3VisionEncoderConfig]] = None,
        text_config: Optional[Union[dict, Step3TextConfig]] = None,
        understand_projector_stride: int = 1,
        projector_bias: bool = True,
        image_token_id: int = 128001,
        **kwargs,
    ) -> None:
        # Normalize the vision sub-config: defaults when missing, built
        # from keyword arguments when given as a dict.
        if vision_config is None:
            vision_config = Step3VisionEncoderConfig()
        elif isinstance(vision_config, dict):
            vision_config = Step3VisionEncoderConfig(**vision_config)
        self.vision_config = vision_config

        # Same normalization for the text sub-config.
        if text_config is None:
            text_config = Step3TextConfig()
        elif isinstance(text_config, dict):
            text_config = Step3TextConfig(**text_config)
        self.text_config = text_config

        self.understand_projector_stride = understand_projector_stride
        self.projector_bias = projector_bias
        # The top-level hidden size mirrors the text model's.
        self.hidden_size = text_config.hidden_size
        self.image_token_id = image_token_id

        super().__init__(**kwargs)
|
||||||
@ -1,5 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -51,6 +53,9 @@ class EagleProposer:
|
|||||||
# hidden size (e.g., Llama 3.3 70B).
|
# hidden size (e.g., Llama 3.3 70B).
|
||||||
self.hidden_size = self.draft_model_config.get_hidden_size()
|
self.hidden_size = self.draft_model_config.get_hidden_size()
|
||||||
|
|
||||||
|
self.is_multimodal_model = vllm_config.model_config \
|
||||||
|
.is_multimodal_model
|
||||||
|
|
||||||
self.use_cuda_graph = (self.vllm_config.compilation_config.level
|
self.use_cuda_graph = (self.vllm_config.compilation_config.level
|
||||||
== CompilationLevel.PIECEWISE and
|
== CompilationLevel.PIECEWISE and
|
||||||
not self.vllm_config.model_config.enforce_eager)
|
not self.vllm_config.model_config.enforce_eager)
|
||||||
@ -76,6 +81,11 @@ class EagleProposer:
|
|||||||
device=device,
|
device=device,
|
||||||
dtype=torch.int32)
|
dtype=torch.int32)
|
||||||
|
|
||||||
|
self.inputs_embeds = torch.zeros(
|
||||||
|
(self.max_num_tokens, self.hidden_size),
|
||||||
|
dtype=self.dtype,
|
||||||
|
device=device)
|
||||||
|
|
||||||
def propose(
|
def propose(
|
||||||
self,
|
self,
|
||||||
# [num_tokens]
|
# [num_tokens]
|
||||||
@ -88,6 +98,7 @@ class EagleProposer:
|
|||||||
next_token_ids: torch.Tensor,
|
next_token_ids: torch.Tensor,
|
||||||
common_attn_metadata: CommonAttentionMetadata,
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
|
mm_embeds: Optional[list[torch.Tensor]] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
num_tokens = target_token_ids.shape[0]
|
num_tokens = target_token_ids.shape[0]
|
||||||
batch_size = next_token_ids.shape[0]
|
batch_size = next_token_ids.shape[0]
|
||||||
@ -128,14 +139,27 @@ class EagleProposer:
|
|||||||
# copy inputs to buffer for cudagraph
|
# copy inputs to buffer for cudagraph
|
||||||
self.positions[:num_tokens] = target_positions
|
self.positions[:num_tokens] = target_positions
|
||||||
self.hidden_states[:num_tokens] = target_hidden_states
|
self.hidden_states[:num_tokens] = target_hidden_states
|
||||||
|
if self.is_multimodal_model:
|
||||||
|
input_ids = self.input_ids[:num_tokens]
|
||||||
|
inputs_embeds = self.model.get_input_embeddings(
|
||||||
|
input_ids,
|
||||||
|
multimodal_embeddings=mm_embeds or None,
|
||||||
|
)
|
||||||
|
self.inputs_embeds[:num_tokens] = inputs_embeds
|
||||||
|
inputs_embeds = self.inputs_embeds[:num_input_tokens]
|
||||||
|
input_ids = None
|
||||||
|
else:
|
||||||
|
inputs_embeds = None
|
||||||
|
input_ids = self.input_ids[:num_input_tokens]
|
||||||
|
|
||||||
with set_forward_context(per_layer_attn_metadata,
|
with set_forward_context(per_layer_attn_metadata,
|
||||||
self.vllm_config,
|
self.vllm_config,
|
||||||
num_tokens=num_input_tokens):
|
num_tokens=num_input_tokens):
|
||||||
ret_hidden_states = self.model(
|
ret_hidden_states = self.model(
|
||||||
self.input_ids[:num_input_tokens],
|
input_ids=input_ids,
|
||||||
self.positions[:num_input_tokens],
|
positions=self.positions[:num_input_tokens],
|
||||||
self.hidden_states[:num_input_tokens],
|
hidden_states=self.hidden_states[:num_input_tokens],
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
)
|
)
|
||||||
if self.method == "deepseek_mtp":
|
if self.method == "deepseek_mtp":
|
||||||
last_hidden_states = ret_hidden_states
|
last_hidden_states = ret_hidden_states
|
||||||
@ -218,15 +242,24 @@ class EagleProposer:
|
|||||||
self.input_ids[:batch_size] = input_ids
|
self.input_ids[:batch_size] = input_ids
|
||||||
self.positions[:batch_size] = clamped_positions
|
self.positions[:batch_size] = clamped_positions
|
||||||
self.hidden_states[:batch_size] = hidden_states
|
self.hidden_states[:batch_size] = hidden_states
|
||||||
|
if self.is_multimodal_model:
|
||||||
|
inputs_embeds = self.model.get_input_embeddings(input_ids)
|
||||||
|
self.inputs_embeds[:batch_size] = inputs_embeds
|
||||||
|
inputs_embeds = self.inputs_embeds[:input_batch_size]
|
||||||
|
input_ids = None
|
||||||
|
else:
|
||||||
|
inputs_embeds = None
|
||||||
|
input_ids = self.input_ids[:input_batch_size]
|
||||||
|
|
||||||
# Run the model.
|
# Run the model.
|
||||||
with set_forward_context(per_layer_attn_metadata,
|
with set_forward_context(per_layer_attn_metadata,
|
||||||
self.vllm_config,
|
self.vllm_config,
|
||||||
num_tokens=input_batch_size):
|
num_tokens=input_batch_size):
|
||||||
last_hidden_states, hidden_states = self.model(
|
last_hidden_states, hidden_states = self.model(
|
||||||
self.input_ids[:input_batch_size],
|
input_ids=input_ids,
|
||||||
self.positions[:input_batch_size],
|
positions=self.positions[:input_batch_size],
|
||||||
self.hidden_states[:input_batch_size],
|
hidden_states=self.hidden_states[:input_batch_size],
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
)
|
)
|
||||||
hidden_states = hidden_states[:batch_size]
|
hidden_states = hidden_states[:batch_size]
|
||||||
logits = self.model.compute_logits(last_hidden_states[:batch_size],
|
logits = self.model.compute_logits(last_hidden_states[:batch_size],
|
||||||
@ -391,10 +424,18 @@ class EagleProposer:
|
|||||||
) -> None:
|
) -> None:
|
||||||
with set_forward_context(None, self.vllm_config,
|
with set_forward_context(None, self.vllm_config,
|
||||||
num_tokens=num_tokens):
|
num_tokens=num_tokens):
|
||||||
|
if self.is_multimodal_model:
|
||||||
|
input_ids = None
|
||||||
|
inputs_embeds = self.inputs_embeds[:num_tokens]
|
||||||
|
else:
|
||||||
|
input_ids = self.input_ids[:num_tokens]
|
||||||
|
inputs_embeds = None
|
||||||
|
|
||||||
self.model(
|
self.model(
|
||||||
self.input_ids[:num_tokens],
|
input_ids=input_ids,
|
||||||
self.positions[:num_tokens],
|
positions=self.positions[:num_tokens],
|
||||||
self.hidden_states[:num_tokens],
|
hidden_states=self.hidden_states[:num_tokens],
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
)
|
)
|
||||||
|
|
||||||
def validate_same_kv_cache_group(self,
|
def validate_same_kv_cache_group(self,
|
||||||
|
|||||||
@ -1314,13 +1314,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
def _gather_mm_embeddings(
|
def _gather_mm_embeddings(
|
||||||
self,
|
self,
|
||||||
scheduler_output: "SchedulerOutput",
|
scheduler_output: "SchedulerOutput",
|
||||||
|
shift_computed_tokens: int = 0,
|
||||||
) -> list[torch.Tensor]:
|
) -> list[torch.Tensor]:
|
||||||
mm_embeds: list[torch.Tensor] = []
|
mm_embeds: list[torch.Tensor] = []
|
||||||
for req_id in self.input_batch.req_ids:
|
for req_id in self.input_batch.req_ids:
|
||||||
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
|
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
|
||||||
req_id]
|
req_id]
|
||||||
req_state = self.requests[req_id]
|
req_state = self.requests[req_id]
|
||||||
num_computed_tokens = req_state.num_computed_tokens
|
num_computed_tokens = \
|
||||||
|
req_state.num_computed_tokens + shift_computed_tokens
|
||||||
mm_positions = req_state.mm_positions
|
mm_positions = req_state.mm_positions
|
||||||
for i, pos_info in enumerate(mm_positions):
|
for i, pos_info in enumerate(mm_positions):
|
||||||
start_pos = pos_info.offset
|
start_pos = pos_info.offset
|
||||||
@ -2298,6 +2300,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
[h[token_indices] for h in aux_hidden_states], dim=-1)
|
[h[token_indices] for h in aux_hidden_states], dim=-1)
|
||||||
else:
|
else:
|
||||||
target_hidden_states = hidden_states[token_indices]
|
target_hidden_states = hidden_states[token_indices]
|
||||||
|
mm_embeds = None
|
||||||
|
if self.is_multimodal_model:
|
||||||
|
mm_embeds = self._gather_mm_embeddings(scheduler_output,
|
||||||
|
shift_computed_tokens=1)
|
||||||
|
|
||||||
draft_token_ids = self.drafter.propose(
|
draft_token_ids = self.drafter.propose(
|
||||||
target_token_ids=target_token_ids,
|
target_token_ids=target_token_ids,
|
||||||
target_positions=target_positions,
|
target_positions=target_positions,
|
||||||
@ -2305,6 +2312,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
next_token_ids=next_token_ids,
|
next_token_ids=next_token_ids,
|
||||||
sampling_metadata=sampling_metadata,
|
sampling_metadata=sampling_metadata,
|
||||||
common_attn_metadata=common_attn_metadata,
|
common_attn_metadata=common_attn_metadata,
|
||||||
|
mm_embeds=mm_embeds,
|
||||||
)
|
)
|
||||||
spec_token_ids = draft_token_ids.tolist()
|
spec_token_ids = draft_token_ids.tolist()
|
||||||
return spec_token_ids
|
return spec_token_ids
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user