From 0258b7a94b08321ca01cf170f867b67c1920af87 Mon Sep 17 00:00:00 2001 From: Travis Johnson Date: Wed, 10 Apr 2024 02:39:56 -0600 Subject: [PATCH 01/50] [Bugfix] handle prompt_logprobs in _apply_min_tokens_penalty (#3876) Signed-off-by: Travis Johnson --- tests/samplers/test_sampler.py | 116 +++++++++++++++++++++----- vllm/model_executor/layers/sampler.py | 19 ++++- 2 files changed, 112 insertions(+), 23 deletions(-) diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 1626b72282072..26e2d29ffd04c 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -1,3 +1,4 @@ +import itertools import random from typing import List, Optional, Tuple from unittest.mock import patch @@ -194,11 +195,15 @@ def test_sampler_min_tokens_penalty(seed: int, device: str): def create_sampling_params(min_tokens, eos_token_id=0, - stop_token_ids=None): + *, + stop_token_ids: Optional[List[str]] = None, + prompt_logprobs: Optional[int] = None): sampling_params = SamplingParams( min_tokens=min_tokens, max_tokens=9999, # keep higher than max of min_tokens stop_token_ids=stop_token_ids, + # requesting prompt_logprobs changes the structure of `logits` + prompt_logprobs=prompt_logprobs, ) sampling_params.eos_token_id = eos_token_id return sampling_params @@ -217,9 +222,9 @@ def test_sampler_min_tokens_penalty(seed: int, device: str): expected_penalization = [] sequence_metadata_list = [] + # 20% chance to generate seq group metadata list with all prompts + is_prompt = random.random() < 0.2 while batch_size > 0: - # 20% chance to generate prompt seq group with single sequence - is_prompt = random.random() < 0.2 num_seqs = 1 if is_prompt else random.randint(1, batch_size) eos_token_id = random.randint(0, VOCAB_SIZE - 1) @@ -240,7 +245,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str): seq_group_penalization = [] for _ in range(num_seqs): num_input = random.randint(1, 100) - num_generated = random.randint(1, 100) if not is_prompt else 0 + num_generated = 0 if is_prompt else random.randint(1, 100) seq_data[next(seq_id_counter)] = create_sequence_data( num_input=num_input, num_generated=num_generated) seq_group_penalization.append(num_generated < min_tokens) @@ -292,6 +297,21 @@ def test_sampler_min_tokens_penalty(seed: int, device: str): ] } + prompt_with_penalization_and_prompt_logprobs = { + "expected_penalization": [False, False, True], + "seq_group_metadata_list": [ + SequenceGroupMetadata( + request_id="test_1", + is_prompt=True, + seq_data={ + next(seq_id_counter): create_sequence_data(num_input=3), + }, + sampling_params=create_sampling_params(1, prompt_logprobs=3), + block_tables={}, + ), + ] + } + stop_penalizing_after_min_tokens = { "expected_penalization": [False], "seq_group_metadata_list": [ @@ -309,8 +329,34 @@ def test_sampler_min_tokens_penalty(seed: int, device: str): } stop_token_ids = [42, 99, 42, 0] # intentional duplication - simple_combination = { - "expected_penalization": [True, False, False], + prompt_combination = { + "expected_penalization": [False, True, False], + "seq_group_metadata_list": [ + SequenceGroupMetadata( + request_id="test_2", + is_prompt=True, + seq_data={ + next(seq_id_counter): create_sequence_data(num_input=2), + }, + sampling_params=create_sampling_params(1, prompt_logprobs=3), + block_tables={}, + ), + SequenceGroupMetadata( + request_id="test_3", + is_prompt=True, + seq_data={ + next(seq_id_counter): create_sequence_data(), + }, + sampling_params=create_sampling_params( + 0, stop_token_ids=stop_token_ids), + block_tables={}, + ) + ] + } + + stop_token_ids = [1, 999, 37, 37] # intentional duplication + decode_combination = { + "expected_penalization": [True, False, False, True, False], "seq_group_metadata_list": [ SequenceGroupMetadata( request_id="test_1", @@ -327,14 +373,19 @@ def test_sampler_min_tokens_penalty(seed: int, device: str): ), SequenceGroupMetadata( request_id="test_2", - is_prompt=True, + is_prompt=False, seq_data={ - next(seq_id_counter): create_sequence_data(), + next(seq_id_counter): + create_sequence_data(num_generated=20), + next(seq_id_counter): + create_sequence_data(num_generated=1), + next(seq_id_counter): + create_sequence_data(num_generated=10), }, sampling_params=create_sampling_params( - 0, stop_token_ids=stop_token_ids), + 10, prompt_logprobs=5, stop_token_ids=stop_token_ids), block_tables={}, - ) + ), ] } @@ -342,8 +393,10 @@ def test_sampler_min_tokens_penalty(seed: int, device: str): test_cases = [ prompt_without_penalization, prompt_with_penalization, + prompt_with_penalization_and_prompt_logprobs, stop_penalizing_after_min_tokens, - simple_combination, + prompt_combination, + decode_combination, ] else: test_cases = [generate_test_case()] @@ -351,30 +404,49 @@ def test_sampler_min_tokens_penalty(seed: int, device: str): def run_test_case(*, expected_penalization=None, seq_group_metadata_list=None): - assert expected_penalization, "Invalid test case" - assert seq_group_metadata_list, "Invalid test case" + assert expected_penalization, \ + "Invalid test case, need expected_penalization" + assert seq_group_metadata_list, \ + "Invalid test case, need seq_group_metadata_list" batch_size = 0 prompt_lens = [] - sampling_params_per_seq = [] + sampling_params_per_row = [] for sgm in seq_group_metadata_list: - num_seqs = len(sgm.seq_data) - batch_size += num_seqs sampling_params = sgm.sampling_params - for seq_id in sgm.seq_data: - prompt_lens.append(sgm.seq_data[seq_id].get_prompt_len()) - sampling_params_per_seq.append(sampling_params) + + num_rows = len(sgm.seq_data) + if sgm.is_prompt: + # a prompt seq_group has only one sequence + seq_data = next(iter(sgm.seq_data.values())) + prompt_len = seq_data.get_prompt_len() + prompt_lens.append(prompt_len) + + if sgm.sampling_params.prompt_logprobs: + # with prompt_logprobs each token in the prompt has a row in + # logits + num_rows = prompt_len + + batch_size += num_rows + sampling_params_per_row.extend( + itertools.repeat(sampling_params, num_rows)) + + assert len( + expected_penalization + ) == batch_size, \ + ("Invalid test case, expected_penalization does not match computed" + "batch size") _, fake_logits, sampler, model_runner = _prepare_test(batch_size) sampling_metadata = model_runner._prepare_sample( seq_group_metadata_list, - prompt_lens=prompt_lens, - subquery_lens=prompt_lens) + prompt_lens=prompt_lens if prompt_lens else None, + subquery_lens=prompt_lens if prompt_lens else None) # the logits tensor is modified in-place by the sampler _ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata) for logits_idx, (should_penalize, sampling_params) in enumerate( - zip(expected_penalization, sampling_params_per_seq)): + zip(expected_penalization, sampling_params_per_row)): tokens_to_check = [sampling_params.eos_token_id] if sampling_params.stop_token_ids: diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index cb1480de03e3a..03bf38caebe0e 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -27,6 +27,12 @@ class Sampler(nn.Module): 6. Sample the next tokens. Here, each sequence group within the batch can have different sampling parameters (e.g., sampling method, temperature, top-p, top-k, etc.). + + The structure of the logits tensor is coupled with the seq_groups in + sampling_metadata. Typically, each sequence in each seq_group has one row in + logits for the next token to be sampled; however, for a seq_group with a + prompt request with the prompt_logprobs sampling parameter, there are rows + in logits for each token in the input prompt. """ def forward( @@ -106,7 +112,16 @@ def _apply_min_tokens_penalty( # list of indices in logits that will be set to -inf logits_to_penalize = [] start_idx = 0 - for seq_ids, sampling_params in sampling_metadata.seq_groups: + for i, seq_group in enumerate(sampling_metadata.seq_groups): + seq_ids, sampling_params = seq_group + + # handle prompt_logprobs by skipping rows in logits added for the prompt + # tokens (prompt logprobs are not penalized) + if (i < sampling_metadata.num_prompts + and sampling_params.prompt_logprobs is not None): + assert len(seq_ids) == 1 + start_idx += sampling_metadata.prompt_lens[i] - 1 + min_tokens = sampling_params.min_tokens if min_tokens > 0: seqs_to_penalize = [] @@ -132,6 +147,8 @@ def _apply_min_tokens_penalty( # eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) ) logits[tuple(zip(*logits_to_penalize))] = -float("inf") + # verifies that no rows in logits were missed unexpectedly + assert start_idx == logits.shape[0] return logits From bd3c144e0b8e82c9b3c5c40c6d557fe8665de5a3 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 10 Apr 2024 07:37:17 -0700 Subject: [PATCH 02/50] [Bugfix][ROCm] Add numba to Dockerfile.rocm (#3962) --- Dockerfile.rocm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 10b8bf1e7fabd..b1c5fac9d78ef 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -91,7 +91,7 @@ RUN if [ "$BUILD_TRITON" = "1" ]; then \ COPY ./ /app/vllm -RUN python3 -m pip install --upgrade pip +RUN python3 -m pip install --upgrade pip numba RUN python3 -m pip install xformers==0.0.23 --no-deps RUN cd /app \ From 8b317c6dd09ce566f4b4abeb446585ac75262cce Mon Sep 17 00:00:00 2001 From: James Whedbee Date: Wed, 10 Apr 2024 10:12:00 -0500 Subject: [PATCH 03/50] [Model][AMD] ROCm support for 256 head dims for Gemma (#3972) --- vllm/attention/ops/triton_flash_attention.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py index b86e845020b07..87cf30cbef79a 100644 --- a/vllm/attention/ops/triton_flash_attention.py +++ b/vllm/attention/ops/triton_flash_attention.py @@ -677,8 +677,7 @@ def check_args( assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] # TODO: Change assert if we support qkl f8 and v f16 assert q.dtype == k.dtype and q.dtype == v.dtype - # TODO: Fix assert to check head size <=256 once supported - assert head_size <= 128 + assert head_size <= 256 assert o.shape == q.shape assert (nheads_q % nheads_k) == 0 @@ -729,7 +728,7 @@ class _attention(torch.autograd.Function): o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) # Get closest power of 2 over or equal to 32. - unpadded_head_dims = {32, 64, 128} + unpadded_head_dims = {32, 64, 128, 256} if head_size not in unpadded_head_dims: padded_d_model = None for i in unpadded_head_dims: From e35397468f36a857b8d2b7d92a472265e1c500cc Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 10 Apr 2024 10:03:02 -0700 Subject: [PATCH 04/50] [Doc] Add doc to state our model support policy (#3948) Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com> --- docs/source/models/supported_models.rst | 26 +++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index e7bfdcb65316e..c09b0ff250437 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -168,3 +168,29 @@ Alternatively, you can raise an issue on our `GitHub `_ and `test_big_models.py `_ for the models that have passed this test. +2. **Output Sensibility**: We check if the output of the model is sensible and coherent, by measuring the perplexity of the output and checking for any obvious errors. This is a less stringent test. +3. **Runtime Functionality**: We check if the model can be loaded and run without errors. This is the least stringent test. Please refer to `functionality tests `_ and `examples `_ for the models that have passed this test. +4. **Community Feedback**: We rely on the community to provide feedback on the models. If a model is broken or not working as expected, we encourage users to raise issues to report it or open pull requests to fix it. The rest of the models fall under this category. From e4c4072c94b346053768691451566c56664e26a7 Mon Sep 17 00:00:00 2001 From: Daniel E Marasco Date: Wed, 10 Apr 2024 13:15:51 -0400 Subject: [PATCH 05/50] [Bugfix] Remove key sorting for `guided_json` parameter in OpenAi compatible Server (#3945) --- vllm/model_executor/guided_decoding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/guided_decoding.py b/vllm/model_executor/guided_decoding.py index e56f74c7794fb..8e710f1ac2b53 100644 --- a/vllm/model_executor/guided_decoding.py +++ b/vllm/model_executor/guided_decoding.py @@ -91,7 +91,7 @@ def _get_guide_and_mode( json = request.guided_json if isinstance(json, dict): # turn dict into hashable string - json = json_dumps(json, sort_keys=True) + json = json_dumps(json) elif isinstance(json, BaseModel): # use pydantic signature so that different model classes # with the same fields will get hashed the same From 92cd2e2f21e8ec65b2cb635a9f15de38157a1359 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fr=CE=B1n=C3=A7ois?= Date: Wed, 10 Apr 2024 20:05:52 +0200 Subject: [PATCH 06/50] [Doc] Fix getting stared to use publicly available model (#3963) --- docs/source/serving/openai_compatible_server.md | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index 032fe5d03bd52..388b5daa79a92 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -4,7 +4,7 @@ vLLM provides an HTTP server that implements OpenAI's [Completions](https://plat You can start the server using Python, or using [Docker](deploying_with_docker.rst): ```bash -python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-hf --dtype float32 --api-key token-abc123 +python -m vllm.entrypoints.openai.api_server --model mistralai/Mistral-7B-Instruct-v0.2 --dtype auto --api-key token-abc123 ``` To call the server, you can use the official OpenAI Python client library, or any other HTTP client. @@ -16,9 +16,8 @@ client = OpenAI( ) completion = client.chat.completions.create( - model="meta-llama/Llama-2-7b-hf", + model="mistralai/Mistral-7B-Instruct-v0.2", messages=[ - {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello!"} ] ) @@ -38,9 +37,8 @@ Or directly merge them into the JSON payload if you are using HTTP call directly ```python completion = client.chat.completions.create( - model="meta-llama/Llama-2-7b-hf", + model="mistralai/Mistral-7B-Instruct-v0.2", messages=[ - {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"} ], extra_body={ @@ -89,7 +87,7 @@ In order for the language model to support chat protocol, vLLM requires the mode a chat template in its tokenizer configuration. The chat template is a Jinja2 template that specifies how are roles, messages, and other chat-specific tokens are encoded in the input. -An example chat template for `meta-llama/Llama-2-7b-chat-hf` can be found [here](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/09bd0f49e16738cdfaa6e615203e126038736eb0/tokenizer_config.json#L12) +An example chat template for `mistralai/Mistral-7B-Instruct-v0.2` can be found [here](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2#instruction-format) Some models do not provide a chat template even though they are instruction/chat fine-tuned. For those model, you can manually specify their chat template in the `--chat-template` parameter with the file path to the chat From 934d3662f716d60abfb04cf9fdd6d20f6e75f140 Mon Sep 17 00:00:00 2001 From: Travis Johnson Date: Wed, 10 Apr 2024 16:28:25 -0600 Subject: [PATCH 07/50] [Bugfix] handle hf_config with architectures == None (#3982) Signed-off-by: Travis Johnson Co-authored-by: Simon Mo --- vllm/config.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index 753fc33e9b717..bca250e922288 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -158,7 +158,9 @@ class ModelConfig: # TODO: Remove this check once HF updates the pt weights of Mixtral. architectures = getattr(self.hf_config, "architectures", []) - if "MixtralForCausalLM" in architectures and load_format == "pt": + # architectures can be None instead of [] + if architectures and "MixtralForCausalLM" in architectures \ + and load_format == "pt": raise ValueError( "Currently, the 'pt' format is not supported for Mixtral. " "Please use the 'safetensors' format instead. ") From 63e7176f265be43dcc425f5ab4ab45c90234f5c3 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 10 Apr 2024 15:33:30 -0700 Subject: [PATCH 08/50] [Core][Refactor] move parallel_utils into vllm/distributed (#3950) [WIP][Core][Refactor] move vllm/model_executor/parallel_utils into vllm/distributed and vllm/device_communicators (#3950) --- tests/conftest.py | 3 +-- tests/distributed/test_comm_ops.py | 6 +++--- tests/distributed/test_custom_all_reduce.py | 13 ++++++------- tests/distributed/test_pynccl.py | 4 ++-- tests/lora/conftest.py | 3 +-- vllm/distributed/__init__.py | 3 +++ .../communication_op.py | 14 ++++++++------ .../device_communicators}/__init__.py | 0 .../device_communicators}/custom_all_reduce.py | 5 +++-- .../device_communicators}/pynccl.py | 0 .../device_communicators}/pynccl_utils.py | 4 ++-- .../parallel_state.py | 4 ++-- .../parallel_utils => distributed}/utils.py | 0 vllm/lora/layers.py | 13 ++++++------- vllm/model_executor/layers/activation.py | 5 ++--- vllm/model_executor/layers/linear.py | 11 +++++------ vllm/model_executor/layers/logits_processor.py | 3 +-- .../layers/vocab_parallel_embedding.py | 8 +++----- vllm/model_executor/models/baichuan.py | 4 ++-- vllm/model_executor/models/bloom.py | 4 ++-- vllm/model_executor/models/chatglm.py | 3 +-- vllm/model_executor/models/commandr.py | 4 ++-- vllm/model_executor/models/dbrx.py | 7 +++---- vllm/model_executor/models/deepseek.py | 7 +++---- vllm/model_executor/models/falcon.py | 7 +++---- vllm/model_executor/models/gemma.py | 3 +-- vllm/model_executor/models/gpt2.py | 3 +-- vllm/model_executor/models/gpt_bigcode.py | 3 +-- vllm/model_executor/models/gpt_j.py | 3 +-- vllm/model_executor/models/gpt_neox.py | 3 +-- vllm/model_executor/models/internlm2.py | 3 +-- vllm/model_executor/models/jais.py | 4 ++-- vllm/model_executor/models/llama.py | 4 ++-- vllm/model_executor/models/minicpm.py | 7 +++---- vllm/model_executor/models/mixtral.py | 7 +++---- vllm/model_executor/models/mixtral_quant.py | 7 +++---- vllm/model_executor/models/mpt.py | 4 ++-- vllm/model_executor/models/olmo.py | 3 +-- vllm/model_executor/models/opt.py | 3 +-- vllm/model_executor/models/orion.py | 3 +-- vllm/model_executor/models/phi.py | 3 +-- vllm/model_executor/models/qwen.py | 3 +-- vllm/model_executor/models/qwen2.py | 3 +-- vllm/model_executor/models/qwen2_moe.py | 7 +++---- vllm/model_executor/models/stablelm.py | 3 +-- vllm/model_executor/models/starcoder2.py | 3 +-- vllm/model_executor/models/xverse.py | 3 +-- vllm/model_executor/parallel_utils/README.md | 1 - vllm/test_utils.py | 4 ++-- vllm/worker/cpu_worker.py | 7 +++---- vllm/worker/model_runner.py | 8 +++----- vllm/worker/worker.py | 12 ++++++------ 52 files changed, 111 insertions(+), 141 deletions(-) create mode 100644 vllm/distributed/__init__.py rename vllm/{model_executor/parallel_utils => distributed}/communication_op.py (94%) rename vllm/{model_executor/parallel_utils => distributed/device_communicators}/__init__.py (100%) rename vllm/{model_executor/parallel_utils => distributed/device_communicators}/custom_all_reduce.py (98%) rename vllm/{model_executor/parallel_utils => distributed/device_communicators}/pynccl.py (100%) rename vllm/{model_executor/parallel_utils => distributed/device_communicators}/pynccl_utils.py (91%) rename vllm/{model_executor/parallel_utils => distributed}/parallel_state.py (98%) rename vllm/{model_executor/parallel_utils => distributed}/utils.py (100%) delete mode 100644 vllm/model_executor/parallel_utils/README.md diff --git a/tests/conftest.py b/tests/conftest.py index e00f3eb871e37..a7e8963af0eda 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,8 +11,7 @@ from transformers import (AutoModelForCausalLM, AutoProcessor, from vllm import LLM, SamplingParams from vllm.config import TokenizerPoolConfig, VisionLanguageConfig -from vllm.model_executor.parallel_utils.parallel_state import ( - destroy_model_parallel) +from vllm.distributed import destroy_model_parallel from vllm.sequence import MultiModalData from vllm.transformers_utils.tokenizer import get_tokenizer diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index d1811cb694db6..aa9e0537c6910 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -8,9 +8,9 @@ import pytest import ray import torch -from vllm.model_executor.parallel_utils.communication_op import ( - broadcast_tensor_dict, tensor_model_parallel_all_gather, - tensor_model_parallel_all_reduce) +from vllm.distributed import (broadcast_tensor_dict, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce) from vllm.test_utils import (init_test_distributed_environment, multi_process_tensor_parallel) diff --git a/tests/distributed/test_custom_all_reduce.py b/tests/distributed/test_custom_all_reduce.py index 1e6e7f89a528c..3b1cd1773af19 100644 --- a/tests/distributed/test_custom_all_reduce.py +++ b/tests/distributed/test_custom_all_reduce.py @@ -6,9 +6,8 @@ import ray import torch import torch.distributed as dist -from vllm.model_executor.parallel_utils import custom_all_reduce as custom_ar -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_reduce) +from vllm.distributed import tensor_model_parallel_all_reduce +from vllm.distributed.device_communicators import custom_all_reduce from vllm.test_utils import (init_test_distributed_environment, multi_process_tensor_parallel) @@ -26,10 +25,10 @@ def graph_allreduce(world_size, rank, distributed_init_port): init_test_distributed_environment(1, world_size, rank, distributed_init_port) - custom_ar.init_custom_ar() + custom_all_reduce.init_custom_all_reduce() for sz in test_sizes: for dtype in [torch.float32, torch.float16, torch.bfloat16]: - with custom_ar.capture(): + with custom_all_reduce.capture(): # use integers so result matches NCCL exactly inp1 = torch.randint(1, 16, (sz, ), @@ -62,8 +61,8 @@ def eager_allreduce(world_size, rank, distributed_init_port): distributed_init_port) sz = 1024 - custom_ar.init_custom_ar() - fa = custom_ar.get_handle() + custom_all_reduce.init_custom_all_reduce() + fa = custom_all_reduce.get_handle() inp = torch.ones(sz, dtype=torch.float32, device=device) out = fa.all_reduce_unreg(inp) assert torch.allclose(out, inp * world_size) diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index 29782045130a6..b50eed1c8c722 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -4,8 +4,8 @@ import os import pytest import torch -from vllm.model_executor.parallel_utils.pynccl import (NCCLCommunicator, - ncclGetUniqueId) +from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator, + ncclGetUniqueId) def distributed_run(fn, world_size): diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index acb5fa91e2012..207c635e2dc86 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -12,6 +12,7 @@ from huggingface_hub import snapshot_download import vllm from vllm.config import LoRAConfig +from vllm.distributed import destroy_model_parallel, initialize_model_parallel from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, RowParallelLinear) @@ -19,8 +20,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader import get_model -from vllm.model_executor.parallel_utils.parallel_state import ( - destroy_model_parallel, initialize_model_parallel) def cleanup(): diff --git a/vllm/distributed/__init__.py b/vllm/distributed/__init__.py new file mode 100644 index 0000000000000..db325cfabf55e --- /dev/null +++ b/vllm/distributed/__init__.py @@ -0,0 +1,3 @@ +from .communication_op import * +from .parallel_state import * +from .utils import * diff --git a/vllm/model_executor/parallel_utils/communication_op.py b/vllm/distributed/communication_op.py similarity index 94% rename from vllm/model_executor/parallel_utils/communication_op.py rename to vllm/distributed/communication_op.py index 9cbb40708dd5b..cf15db099b304 100644 --- a/vllm/model_executor/parallel_utils/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -4,12 +4,10 @@ from typing import Any, Dict, List, Optional, Union import torch from torch.distributed import ProcessGroup -from vllm.model_executor.parallel_utils import pynccl_utils -from vllm.model_executor.parallel_utils.custom_all_reduce import ( - custom_all_reduce) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, is_pynccl_enabled_for_all_reduce) +from .parallel_state import (get_tensor_model_parallel_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + is_pynccl_enabled_for_all_reduce) def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: @@ -24,6 +22,10 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: TLDR: always assume this function modifies its input, but use the return value as the output. """ + from vllm.distributed.device_communicators import pynccl_utils + from vllm.distributed.device_communicators.custom_all_reduce import ( + custom_all_reduce) + # Bypass the function if we are using only 1 GPU. if get_tensor_model_parallel_world_size() == 1: return input_ diff --git a/vllm/model_executor/parallel_utils/__init__.py b/vllm/distributed/device_communicators/__init__.py similarity index 100% rename from vllm/model_executor/parallel_utils/__init__.py rename to vllm/distributed/device_communicators/__init__.py diff --git a/vllm/model_executor/parallel_utils/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py similarity index 98% rename from vllm/model_executor/parallel_utils/custom_all_reduce.py rename to vllm/distributed/device_communicators/custom_all_reduce.py index bf8ee07070c8a..84238d2e46076 100644 --- a/vllm/model_executor/parallel_utils/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -5,8 +5,6 @@ import torch import torch.distributed as dist from vllm.logger import init_logger -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) try: import pynvml @@ -25,6 +23,9 @@ _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] def init_custom_ar() -> None: + from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) + global _CA_HANDLE if _CA_HANDLE is not None: return diff --git a/vllm/model_executor/parallel_utils/pynccl.py b/vllm/distributed/device_communicators/pynccl.py similarity index 100% rename from vllm/model_executor/parallel_utils/pynccl.py rename to vllm/distributed/device_communicators/pynccl.py diff --git a/vllm/model_executor/parallel_utils/pynccl_utils.py b/vllm/distributed/device_communicators/pynccl_utils.py similarity index 91% rename from vllm/model_executor/parallel_utils/pynccl_utils.py rename to vllm/distributed/device_communicators/pynccl_utils.py index a099777aa0005..aeb73015733d1 100644 --- a/vllm/model_executor/parallel_utils/pynccl_utils.py +++ b/vllm/distributed/device_communicators/pynccl_utils.py @@ -9,8 +9,8 @@ from vllm.logger import init_logger logger = init_logger(__name__) try: - from vllm.model_executor.parallel_utils.pynccl import (NCCLCommunicator, - ncclGetVersion) + from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator, + ncclGetVersion) except Exception as e: # in non-NVIDIA environments, we can't import the nccl module # e.g. when running on machines with AMD GPUs diff --git a/vllm/model_executor/parallel_utils/parallel_state.py b/vllm/distributed/parallel_state.py similarity index 98% rename from vllm/model_executor/parallel_utils/parallel_state.py rename to vllm/distributed/parallel_state.py index 3bbfa1bd5443a..4bb77146295af 100644 --- a/vllm/model_executor/parallel_utils/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -8,8 +8,6 @@ from typing import Optional import torch -from vllm.model_executor.parallel_utils import pynccl_utils - # Tensor model parallel group that the current rank belongs to. _TENSOR_MODEL_PARALLEL_GROUP = None # Pipeline model parallel group that the current rank belongs to. @@ -266,6 +264,7 @@ def destroy_model_parallel(): _PIPELINE_MODEL_PARALLEL_GROUP = None global _PIPELINE_GLOBAL_RANKS _PIPELINE_GLOBAL_RANKS = None + from vllm.distributed.device_communicators import pynccl_utils # Destroy the pynccl states if any. pynccl_utils.destroy_process_group() @@ -279,6 +278,7 @@ _ENABLE_PYNCCL_FOR_ALL_REDUCE = False @contextlib.contextmanager def with_pynccl_for_all_reduce(): + from vllm.distributed.device_communicators import pynccl_utils """use pynccl instead of torch.distributed for all reduce""" tp_size = get_tensor_model_parallel_world_size() if tp_size == 1: diff --git a/vllm/model_executor/parallel_utils/utils.py b/vllm/distributed/utils.py similarity index 100% rename from vllm/model_executor/parallel_utils/utils.py rename to vllm/distributed/utils.py diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 0505014753951..dd33868f76302 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -10,6 +10,12 @@ import torch.nn.functional as F from transformers import PretrainedConfig from vllm.config import LoRAConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, + tensor_model_parallel_gather) from vllm.lora.punica import add_lora, add_lora_slice, bgmv from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, @@ -18,13 +24,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce, - tensor_model_parallel_gather) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.utils import ( - split_tensor_along_last_dim) if TYPE_CHECKING: pass diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index f569a5a49cbdf..6786c48e0caba 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -7,10 +7,9 @@ import torch.nn as nn import torch.nn.functional as F from vllm._C import ops +from vllm.distributed import (divide, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.utils import divide from vllm.model_executor.utils import set_weight_attrs diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index f3d4d1789db2d..8f42b3e8a4abe 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -5,13 +5,12 @@ import torch import torch.nn.functional as F from torch.nn.parameter import Parameter +from vllm.distributed import (divide, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce) from vllm.logger import init_logger -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.utils import ( - divide, split_tensor_along_last_dim) from vllm.model_executor.utils import set_weight_attrs logger = init_logger(__name__) diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index ec531f79ced52..e556e31f99378 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -4,8 +4,7 @@ from typing import Optional import torch import torch.nn as nn -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_gather) +from vllm.distributed import tensor_model_parallel_gather from vllm.model_executor.sampling_metadata import SamplingMetadata diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index 73bbfac33ed13..088c0849243c0 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -4,11 +4,9 @@ import torch import torch.nn.functional as F from torch.nn.parameter import Parameter -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_reduce) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.utils import divide +from vllm.distributed import (divide, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) from vllm.model_executor.utils import set_weight_attrs DEFAULT_VOCAB_PADDING_SIZE = 64 diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index fa5a27b5a6974..30588aecdebe9 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -27,6 +27,8 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata from vllm.config import LoRAConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, @@ -38,8 +40,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index a9ff909090586..40966ab33631a 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -24,6 +24,8 @@ from torch import nn from transformers import BloomConfig from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, @@ -33,8 +35,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 4008896e48dd1..7b46ba306619a 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -10,6 +10,7 @@ from torch.nn import LayerNorm from vllm.attention import Attention, AttentionMetadata from vllm.config import LoRAConfig +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, @@ -21,8 +22,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 29ba3844eb11d..aa27f0a96c745 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -29,6 +29,8 @@ from torch.nn.parameter import Parameter from transformers import CohereConfig from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, @@ -39,8 +41,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.weight_utils import (default_weight_loader, diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index 14c0fece69214..49eb7f1b2c185 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -5,6 +5,9 @@ import torch import torch.nn as nn from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.linear import (LinearMethodBase, QKVParallelLinear, @@ -15,10 +18,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_reduce) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.weight_utils import (default_weight_loader, diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 2a2182ff4ebad..c7dd11d07e6da 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -28,6 +28,9 @@ from torch import nn from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm @@ -41,10 +44,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_reduce) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 77c19b227d213..4f1ebcd5fb43c 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -27,6 +27,9 @@ from torch.nn import LayerNorm from transformers import FalconConfig as HF_FalconConfig from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, @@ -37,10 +40,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_reduce) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index 08609532b8b3e..fc1fc35570368 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -23,6 +23,7 @@ from transformers import GemmaConfig from vllm.attention import Attention, AttentionMetadata from vllm.config import LoRAConfig +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -35,8 +36,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 3f816a9996be5..43f0d47fcb122 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -24,6 +24,7 @@ from torch import nn from transformers import GPT2Config from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, @@ -33,8 +34,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 07c647c2e1c41..cec2d771adfa8 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -25,6 +25,7 @@ from torch import nn from transformers import GPTBigCodeConfig from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, @@ -34,8 +35,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 94048efe48420..5660097652748 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -23,6 +23,7 @@ from torch import nn from transformers import GPTJConfig from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, @@ -33,8 +34,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index a5b5d717d9846..2f9e2171cf114 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -23,6 +23,7 @@ from torch import nn from transformers import GPTNeoXConfig from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, @@ -33,8 +34,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index bdb48bf21042e..6e9cbd3f9f43f 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -6,6 +6,7 @@ from torch import nn from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, @@ -17,8 +18,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index 12fc9dbd50732..a041b0c9a0452 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -26,6 +26,8 @@ import torch from torch import nn from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, QKVParallelLinear, @@ -34,8 +36,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 72fe21df67d8a..c86e292e7df1a 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -29,6 +29,8 @@ from transformers import LlamaConfig from vllm.attention import Attention, AttentionMetadata from vllm.config import LoRAConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, @@ -40,8 +42,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator, diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 99d1b4eb97bb8..49eda9c9a8112 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -29,6 +29,9 @@ from torch import nn from vllm.attention import Attention, AttentionMetadata from vllm.config import LoRAConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm @@ -42,10 +45,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_reduce) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.weight_utils import (default_weight_loader, diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 429bc8109b9f8..ff552a9d86536 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -29,6 +29,9 @@ from transformers import MixtralConfig from vllm.attention import Attention, AttentionMetadata from vllm.config import LoRAConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, @@ -40,10 +43,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_reduce) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.weight_utils import (default_weight_loader, diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 75f86bc134ee3..1f0c0e912beea 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -30,6 +30,9 @@ from torch import nn from transformers import MixtralConfig from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, QKVParallelLinear, @@ -40,10 +43,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_reduce) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index a39f94359a948..af4cdce29d085 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -7,6 +7,8 @@ import torch import torch.nn as nn from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, @@ -16,8 +18,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 611a48a9aad2b..3513c72879102 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -44,6 +44,7 @@ from hf_olmo import OLMoConfig from torch import nn from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, @@ -55,8 +56,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index c1ae1b2ae0f03..3a640850662c0 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -24,6 +24,7 @@ from torch import nn from transformers import OPTConfig from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, @@ -34,8 +35,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index ee910563b20df..c606ac027e9d9 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -11,6 +11,7 @@ from torch import nn from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, @@ -21,8 +22,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 40e068acaba7d..e91624da90955 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -42,6 +42,7 @@ from torch import nn from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, @@ -52,8 +53,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index a63b9c8d63d13..6213a2ded65ab 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -11,6 +11,7 @@ from torch import nn from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, @@ -22,8 +23,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 8c92cd773f6b9..796e30e633e85 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -30,6 +30,7 @@ from transformers import Qwen2Config from vllm.attention import Attention, AttentionMetadata from vllm.config import LoRAConfig +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, @@ -41,8 +42,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 6b4a74198fd52..f920b4f5a40c7 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -30,6 +30,9 @@ from torch import nn from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm @@ -43,10 +46,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_reduce) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index b83637fd50dc7..651598b770f13 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -26,6 +26,7 @@ from torch import nn from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, @@ -36,8 +37,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index 50d23e0a3b6ef..76e8e48673413 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -25,6 +25,7 @@ from torch import nn from transformers import Starcoder2Config from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, @@ -35,8 +36,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py index 83d2ddb2bcf35..7e9ce9e5c8e15 100644 --- a/vllm/model_executor/models/xverse.py +++ b/vllm/model_executor/models/xverse.py @@ -28,6 +28,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata from vllm.config import LoRAConfig +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, @@ -39,8 +40,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/parallel_utils/README.md b/vllm/model_executor/parallel_utils/README.md deleted file mode 100644 index b25e3afddad9c..0000000000000 --- a/vllm/model_executor/parallel_utils/README.md +++ /dev/null @@ -1 +0,0 @@ -The files in this folder are ported from [Megatron-LM](https://github.com/NVIDIA/Megatron-LM/tree/main/megatron/core). We only keep the codes that are used in inference. \ No newline at end of file diff --git a/vllm/test_utils.py b/vllm/test_utils.py index bc220d3b8a430..0cf23e4bb7e75 100644 --- a/vllm/test_utils.py +++ b/vllm/test_utils.py @@ -1,7 +1,7 @@ import ray -from vllm.model_executor.parallel_utils.parallel_state import ( - ensure_model_parallel_initialized, init_distributed_environment) +from vllm.distributed import (ensure_model_parallel_initialized, + init_distributed_environment) from vllm.utils import get_open_port diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 42f0828b826e2..751384eb72af3 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -7,13 +7,12 @@ import torch.distributed from vllm.attention import get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig) +from vllm.distributed import (broadcast_tensor_dict, + ensure_model_parallel_initialized, + init_distributed_environment) from vllm.logger import init_logger from vllm.model_executor import set_random_seed from vllm.model_executor.model_loader import get_model -from vllm.model_executor.parallel_utils.communication_op import ( - broadcast_tensor_dict) -from vllm.model_executor.parallel_utils.parallel_state import ( - ensure_model_parallel_initialized, init_distributed_environment) from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.worker.model_runner import ModelRunner diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index e7f20475ab1a7..1de4748b7bcc9 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -9,17 +9,15 @@ import torch.nn as nn from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) +from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce +from vllm.distributed.device_communicators import (custom_all_reduce, + pynccl_utils) from vllm.logger import init_logger from vllm.lora.layers import LoRAMapping from vllm.lora.request import LoRARequest from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader import get_model -from vllm.model_executor.parallel_utils import custom_all_reduce, pynccl_utils -from vllm.model_executor.parallel_utils.communication_op import ( - broadcast_tensor_dict) -from vllm.model_executor.parallel_utils.parallel_state import ( - with_pynccl_for_all_reduce) from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData, SequenceGroupMetadata) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 19de33089b2db..3f0b2fd83f3e5 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -8,14 +8,14 @@ import torch.distributed from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) +from vllm.distributed import (broadcast_tensor_dict, + ensure_model_parallel_initialized, + init_distributed_environment) +from vllm.distributed.device_communicators import pynccl_utils +from vllm.distributed.device_communicators.custom_all_reduce import ( + init_custom_ar) from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed -from vllm.model_executor.parallel_utils import pynccl_utils -from vllm.model_executor.parallel_utils.communication_op import ( - broadcast_tensor_dict) -from vllm.model_executor.parallel_utils.custom_all_reduce import init_custom_ar -from vllm.model_executor.parallel_utils.parallel_state import ( - ensure_model_parallel_initialized, init_distributed_environment) from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.worker.cache_engine import CacheEngine from vllm.worker.model_runner import ModelRunner From 67b4221a61ace91a79aff507df0a95a01978300e Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Thu, 11 Apr 2024 09:56:48 +0900 Subject: [PATCH 09/50] [Core][5/N] Fully working chunked prefill e2e (#3884) --- .buildkite/test-pipeline.yaml | 2 + benchmarks/benchmark_latency.py | 3 +- benchmarks/benchmark_throughput.py | 62 ++-- .../basic_correctness/test_chunked_prefill.py | 70 ++++ tests/core/test_chunked_prefill_scheduler.py | 16 +- .../test_basic_distributed_correctness.py | 7 +- .../test_chunked_prefill_distributed.py | 66 ++++ tests/entrypoints/test_openai_server.py | 2 +- tests/models/test_models.py | 2 +- tests/worker/test_model_runner.py | 189 ++++++++-- vllm/attention/__init__.py | 4 +- vllm/attention/backends/abstract.py | 42 ++- vllm/attention/backends/flash_attn.py | 85 +++-- vllm/attention/backends/rocm_flash_attn.py | 97 ++++-- vllm/attention/backends/torch_sdpa.py | 67 ++-- vllm/attention/backends/xformers.py | 138 ++++---- vllm/attention/layer.py | 5 +- vllm/attention/ops/paged_attn.py | 6 - vllm/config.py | 13 +- vllm/core/scheduler.py | 15 +- vllm/distributed/communication_op.py | 10 +- vllm/engine/arg_utils.py | 5 +- vllm/engine/llm_engine.py | 5 +- vllm/lora/layers.py | 5 +- vllm/sequence.py | 3 +- vllm/worker/model_runner.py | 323 +++++++++++++----- 26 files changed, 927 insertions(+), 315 deletions(-) create mode 100644 tests/basic_correctness/test_chunked_prefill.py create mode 100644 tests/distributed/test_chunked_prefill_distributed.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 27e44463a30a6..695290ed74ab5 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -29,6 +29,8 @@ steps: - pytest -v -s test_pynccl.py - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py + - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_chunked_prefill_distributed.py + - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_chunked_prefill_distributed.py - label: Engine Test command: pytest -v -s engine tokenization test_sequence.py test_config.py diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 91510dafc57a5..aadbc441713fc 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -177,8 +177,7 @@ if __name__ == '__main__': help='block size of key/value cache') parser.add_argument( '--enable-chunked-prefill', - type=bool, - default=False, + action='store_true', help='If True, the prefill requests can be chunked based on the ' 'max_num_batched_tokens') parser.add_argument( diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index e71338273d1e5..6df1e1d628e6c 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -74,25 +74,31 @@ def run_vllm( quantization_param_path: Optional[str], device: str, enable_prefix_caching: bool, + enable_chunked_prefill: bool, + max_num_batched_tokens: int, gpu_memory_utilization: float = 0.9, download_dir: Optional[str] = None, ) -> float: from vllm import LLM, SamplingParams - llm = LLM(model=model, - tokenizer=tokenizer, - quantization=quantization, - tensor_parallel_size=tensor_parallel_size, - seed=seed, - trust_remote_code=trust_remote_code, - dtype=dtype, - max_model_len=max_model_len, - gpu_memory_utilization=gpu_memory_utilization, - enforce_eager=enforce_eager, - kv_cache_dtype=kv_cache_dtype, - quantization_param_path=quantization_param_path, - device=device, - enable_prefix_caching=enable_prefix_caching, - download_dir=download_dir) + llm = LLM( + model=model, + tokenizer=tokenizer, + quantization=quantization, + tensor_parallel_size=tensor_parallel_size, + seed=seed, + trust_remote_code=trust_remote_code, + dtype=dtype, + max_model_len=max_model_len, + gpu_memory_utilization=gpu_memory_utilization, + enforce_eager=enforce_eager, + kv_cache_dtype=kv_cache_dtype, + quantization_param_path=quantization_param_path, + device=device, + enable_prefix_caching=enable_prefix_caching, + download_dir=download_dir, + enable_chunked_prefill=enable_chunked_prefill, + max_num_batched_tokens=max_num_batched_tokens, + ) # Add the requests to the engine. for prompt, _, output_len in requests: @@ -213,15 +219,15 @@ def main(args: argparse.Namespace): args.output_len) if args.backend == "vllm": - elapsed_time = run_vllm(requests, args.model, args.tokenizer, - args.quantization, args.tensor_parallel_size, - args.seed, args.n, args.use_beam_search, - args.trust_remote_code, args.dtype, - args.max_model_len, args.enforce_eager, - args.kv_cache_dtype, - args.quantization_param_path, args.device, - args.enable_prefix_caching, - args.gpu_memory_utilization, args.download_dir) + elapsed_time = run_vllm( + requests, args.model, args.tokenizer, args.quantization, + args.tensor_parallel_size, args.seed, args.n, args.use_beam_search, + args.trust_remote_code, args.dtype, args.max_model_len, + args.enforce_eager, args.kv_cache_dtype, + args.quantization_param_path, args.device, + args.enable_prefix_caching, args.enable_chunked_prefill, + args.max_num_batched_tokens, args.gpu_memory_utilization, + args.download_dir) elif args.backend == "hf": assert args.tensor_parallel_size == 1 elapsed_time = run_hf(requests, args.model, tokenizer, args.n, @@ -335,6 +341,14 @@ if __name__ == "__main__": "--enable-prefix-caching", action='store_true', help="enable automatic prefix caching for vLLM backend.") + parser.add_argument("--enable-chunked-prefill", + action='store_true', + help="enable chunked prefill for vLLM backend.") + parser.add_argument('--max-num-batched-tokens', + type=int, + default=None, + help='maximum number of batched tokens per ' + 'iteration') parser.add_argument('--download-dir', type=str, default=None, diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py new file mode 100644 index 0000000000000..9ff07b3c09020 --- /dev/null +++ b/tests/basic_correctness/test_chunked_prefill.py @@ -0,0 +1,70 @@ +"""Compare the outputs of HF and vLLM when using greedy sampling. + +It tests chunked prefill. Chunked prefill can be enabled by +enable_chunked_prefill=True. If prefill size exceeds max_num_batched_tokens, +prefill requests are chunked. + +Run `pytest tests/models/test_chunked_prefill.py`. +""" +import pytest + +MODELS = [ + "facebook/opt-125m", + "meta-llama/Llama-2-7b-hf", +] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [32]) +@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16]) +@pytest.mark.parametrize("enforce_eager", [False, True]) +# NOTE: Increasing this in this suite will fail CI because we currently cannot +# reset distributed env properly. Use a value > 1 just when you test. +@pytest.mark.parametrize("tensor_parallel_size", [1]) +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + chunked_prefill_token_size: int, + enforce_eager: bool, + tensor_parallel_size: int, +) -> None: + if (tensor_parallel_size == 2 and chunked_prefill_token_size != 16 + and not enforce_eager): + pytest.skip(f"Skip {chunked_prefill_token_size=} and {enforce_eager=} " + "for high TP to save testing time.") + max_num_seqs = min(chunked_prefill_token_size, 256) + enable_chunked_prefill = False + max_num_batched_tokens = None + if chunked_prefill_token_size != -1: + enable_chunked_prefill = True + max_num_batched_tokens = chunked_prefill_token_size + + hf_model = hf_runner(model, dtype=dtype) + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + del hf_model + + vllm_model = vllm_runner( + model, + dtype=dtype, + max_num_batched_tokens=max_num_batched_tokens, + enable_chunked_prefill=enable_chunked_prefill, + tensor_parallel_size=tensor_parallel_size, + enforce_eager=enforce_eager, + max_num_seqs=max_num_seqs, + ) + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + del vllm_model + print(vllm_outputs[0]) + + for i in range(len(example_prompts)): + hf_output_ids, hf_output_str = hf_outputs[i] + vllm_output_ids, vllm_output_str = vllm_outputs[i] + assert hf_output_str == vllm_output_str, ( + f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") + assert hf_output_ids == vllm_output_ids, ( + f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py index 05e62ced5898f..cce396bf4953c 100644 --- a/tests/core/test_chunked_prefill_scheduler.py +++ b/tests/core/test_chunked_prefill_scheduler.py @@ -104,10 +104,10 @@ def test_chunk(): # One chunked prefill, and one decoding. seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) assert set(get_sequence_groups(out)) == set(running) - # The first one is decoding. - assert seq_group_meta[0].token_chunk_size == 1 + # The first one is prefill. Scheduler guarantees ordering. + assert seq_group_meta[0].token_chunk_size == 56 # The second one is a chunked prefill. - assert seq_group_meta[1].token_chunk_size == 56 + assert seq_group_meta[1].token_chunk_size == 1 assert out.num_prefill_groups == 1 assert out.num_batched_tokens == 57 @@ -157,12 +157,12 @@ def test_complex(): # Decoding & chunked prefill & first chunk of 3rd request is scheduled. seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) assert len(get_sequence_groups(out)) == 3 - # The first one is decoding. - assert seq_group_meta[0].token_chunk_size == 1 - # The second one is a chunked prefill. + # The first one is the first chunked prefill. + assert seq_group_meta[0].token_chunk_size == 7 + # The second one is the second new chunked prefill. assert seq_group_meta[1].token_chunk_size == 56 - # The third one is also chunked. - assert seq_group_meta[2].token_chunk_size == 7 + # The last one is decode. + assert seq_group_meta[2].token_chunk_size == 1 # Two of them are in chunked prefill. assert out.num_prefill_groups == 2 assert out.num_batched_tokens == 64 diff --git a/tests/distributed/test_basic_distributed_correctness.py b/tests/distributed/test_basic_distributed_correctness.py index 1eba14d7a6422..77aa90b12bf8f 100644 --- a/tests/distributed/test_basic_distributed_correctness.py +++ b/tests/distributed/test_basic_distributed_correctness.py @@ -33,11 +33,16 @@ def test_models( dtype: str, max_tokens: int, ) -> None: + hf_model = hf_runner(model, dtype=dtype) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) del hf_model - vllm_model = vllm_runner(model, dtype=dtype, tensor_parallel_size=2) + vllm_model = vllm_runner( + model, + dtype=dtype, + tensor_parallel_size=2, + ) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) del vllm_model diff --git a/tests/distributed/test_chunked_prefill_distributed.py b/tests/distributed/test_chunked_prefill_distributed.py new file mode 100644 index 0000000000000..737b1f3169519 --- /dev/null +++ b/tests/distributed/test_chunked_prefill_distributed.py @@ -0,0 +1,66 @@ +"""Compare the outputs of HF and distributed vLLM when using greedy sampling. +vLLM will allocate all the available memory, so we need to run the tests one +by one. The solution is to pass arguments (model name) by environment +variables. + +Run: +```sh +TEST_DIST_MODEL=facebook/opt-125m pytest \ + test_chunked_prefill_distributed.py +TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf \ + test_chunked_prefill_distributed.py +``` +""" +import os + +import pytest +import torch + +MODELS = [ + os.environ["TEST_DIST_MODEL"], +] + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason="Need at least 2 GPUs to run the test.") +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [5]) +@pytest.mark.parametrize("chunked_prefill_token_size", [16]) +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + chunked_prefill_token_size: int, +) -> None: + # Add a chunked prefill config. + max_num_seqs = min(chunked_prefill_token_size, 256) + assert chunked_prefill_token_size != -1 + enable_chunked_prefill = True + max_num_batched_tokens = chunked_prefill_token_size + + hf_model = hf_runner(model, dtype=dtype) + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + del hf_model + + vllm_model = vllm_runner( + model, + dtype=dtype, + tensor_parallel_size=2, + max_num_seqs=max_num_seqs, + enable_chunked_prefill=enable_chunked_prefill, + max_num_batched_tokens=max_num_batched_tokens, + ) + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + del vllm_model + + for i in range(len(example_prompts)): + hf_output_ids, hf_output_str = hf_outputs[i] + vllm_output_ids, vllm_output_str = vllm_outputs[i] + assert hf_output_str == vllm_output_str, ( + f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") + assert hf_output_ids == vllm_output_ids, ( + f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 442f8bdf3b4ba..6f2086c4dd269 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -141,7 +141,7 @@ def server(zephyr_lora_files): "--max-cpu-loras", "2", "--max-num-seqs", - "128" + "128", ]) ray.get(server_runner.ready.remote()) yield server_runner diff --git a/tests/models/test_models.py b/tests/models/test_models.py index 53a80d4619646..cfe2539e3a052 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -12,7 +12,7 @@ MODELS = [ "gpt2", "bigcode/tiny_starcoder_py", "EleutherAI/pythia-70m", - "bigscience/bloom-560m", + "bigscience/bloom-560m", # Testing alibi slopes. "microsoft/phi-2", "stabilityai/stablelm-3b-4e1t", # "allenai/OLMo-1B", # Broken diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 5b6f001f62fa7..dcaae4af4a6f8 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -1,14 +1,18 @@ import pytest import torch -from vllm.config import ModelConfig +from vllm.config import ModelConfig, SchedulerConfig from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size @pytest.mark.parametrize("batch_size", list(range(1, 257))) def test_prepare_prompt(batch_size): - model_runner = ModelRunner(None, None, None, None, None) + scheduler_config = SchedulerConfig(100000, + 100000, + 100000, + enable_chunked_prefill=False) + model_runner = ModelRunner(None, None, scheduler_config, None, None) model_runner.set_block_size(16) prompt_lens = [] @@ -36,8 +40,10 @@ def test_prepare_prompt(batch_size): prompt_len - 1) selected_token_start_idx += prompt_len (input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _, - _, _) = (model_runner._prepare_prompt(seq_group_metadata_list)) + _, _, + slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list)) assert return_prompt_lens == prompt_lens + assert len(slot_mapping) == len(input_tokens) # Verify input metadata is correct for prompts. device = model_runner.device @@ -45,8 +51,6 @@ def test_prepare_prompt(batch_size): assert torch.allclose(attn_metadata.prompt_lens_tensor, torch.tensor(prompt_lens, device=device)) assert attn_metadata.prompt_lens == prompt_lens - assert attn_metadata.num_prompt_tokens == sum(prompt_lens) - assert attn_metadata.num_generation_tokens == 0 assert attn_metadata.max_prompt_len == max(prompt_lens) # Test subquery start locs. @@ -83,23 +87,22 @@ def test_prepare_prompt(batch_size): assert torch.allclose(attn_metadata.block_tables, expected) # Cuda graph should not be used for prerill. assert attn_metadata.use_cuda_graph is False - assert attn_metadata.kv_cache_dtype == "auto" - assert input_tokens.shape == (sum(prompt_lens), ) - assert input_positions.shape == (sum(prompt_lens), ) + assert len(input_tokens) == sum(prompt_lens) + assert len(input_positions) == sum(prompt_lens) torch.testing.assert_close(input_tokens, input_positions) sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens) - assert input_tokens.shape == (sum(prompt_lens), ) - assert input_positions.shape == (sum(prompt_lens), ) + assert len(input_tokens) == sum(prompt_lens) + assert len(input_positions) == sum(prompt_lens) actual = sampling_metadata.selected_token_indices expected = torch.tensor(expected_selected_token_indices, device=actual.device, dtype=actual.dtype) torch.testing.assert_close(actual, expected) - torch.testing.assert_close(input_tokens, input_positions) + assert input_tokens == input_positions actual = sampling_metadata.selected_token_indices expected = torch.tensor(expected_selected_token_indices, @@ -122,7 +125,12 @@ def test_prepare_decode_cuda_graph(batch_size): revision=None, enforce_eager=False, ) - model_runner = ModelRunner(model_config, None, None, None, None) + scheduler_config = SchedulerConfig(100000, + 100000, + 100000, + enable_chunked_prefill=False) + model_runner = ModelRunner(model_config, None, scheduler_config, None, + None) model_runner.set_block_size(16) prompt_lens = [] @@ -143,16 +151,15 @@ def test_prepare_decode_cuda_graph(batch_size): assert seq_group_metadata.token_chunk_size == 1 seq_group_metadata_list.append(seq_group_metadata) - input_tokens, input_positions, attn_metadata, _, _, _ = ( + input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = ( model_runner._prepare_decode(seq_group_metadata_list)) + assert len(slot_mapping) == len(input_tokens) expected_bs = _get_graph_batch_size(len(seq_group_metadata_list)) # Verify input metadata is correct for prompts. device = model_runner.device assert attn_metadata.is_prompt is False assert attn_metadata.prompt_lens is None - assert attn_metadata.num_prompt_tokens == 0 - assert attn_metadata.num_generation_tokens == expected_bs assert attn_metadata.max_prompt_len is None assert attn_metadata.subquery_start_loc is None assert attn_metadata.seq_start_loc is None @@ -170,11 +177,10 @@ def test_prepare_decode_cuda_graph(batch_size): model_runner.get_max_block_per_batch()) # Cuda graph should not be used for prerill. assert attn_metadata.use_cuda_graph is True - assert attn_metadata.kv_cache_dtype == "auto" - assert input_tokens.shape == (expected_bs, ) - assert input_positions.shape == (expected_bs, ) - torch.testing.assert_close(input_tokens, input_positions) + assert len(input_tokens) == expected_bs + assert len(input_positions) == expected_bs + assert input_tokens == input_positions # Verify Sampling expected_selected_token_indices = [] @@ -190,3 +196,148 @@ def test_prepare_decode_cuda_graph(batch_size): device=actual.device, dtype=actual.dtype) torch.testing.assert_close(actual, expected) + + +def test_empty_seq_group(): + """Verify prepare prompt and decode returns empty output.""" + model_config = ModelConfig( + "facebook/opt-125m", + "facebook/opt-125m", + tokenizer_mode="auto", + trust_remote_code=False, + download_dir=None, + load_format="dummy", + seed=0, + dtype="float16", + revision=None, + enforce_eager=False, + ) + model_runner = ModelRunner(model_config, None, None, None, None) + model_runner.set_block_size(16) + seq_group_metadata_list = [] + input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = ( + model_runner._prepare_decode(seq_group_metadata_list)) + assert len(input_tokens) == 0 + assert len(input_positions) == 0 + assert attn_metadata is None + assert len(slot_mapping) == 0 + + (input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _, + _, _, + slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list)) + assert len(input_tokens) == 0 + assert len(input_positions) == 0 + assert attn_metadata is None + assert len(slot_mapping) == 0 + assert len(return_prompt_lens) == 0 + + +@pytest.mark.parametrize("batch_size", list(range(2, 128))) +@pytest.mark.parametrize("enforce_eager", [True, False]) +def test_hybrid_batches(batch_size, enforce_eager, monkeypatch): + + def get_world_size(group=None): + return 1 + + def mock_get_process_group_ranks(group=None): + return [0] + + monkeypatch.setattr(torch.distributed, "get_world_size", get_world_size) + monkeypatch.setattr(torch.distributed, "get_process_group_ranks", + mock_get_process_group_ranks) + + model_config = ModelConfig( + "facebook/opt-125m", + "facebook/opt-125m", + tokenizer_mode="auto", + trust_remote_code=False, + download_dir=None, + load_format="dummy", + seed=0, + dtype="float16", + revision=None, + enforce_eager=enforce_eager, + ) + scheduler_config = SchedulerConfig(100000, + 100000, + 100000, + enable_chunked_prefill=True) + model_runner = ModelRunner(model_config, + None, + scheduler_config, + None, + None, + is_driver_worker=True) + model_runner.set_block_size(16) + + # Add prefill requests. + prompt_lens = [] + seq_group_metadata_list = [] + prefill_metadata_list = [] + decode_metadata_list = [] + block_tables = {0: [1]} + prefill_batch_size = batch_size // 2 + decode_batch_size = batch_size - prefill_batch_size + for i in range(prefill_batch_size): + # make sure all tokens fit into one block + prompt_len = i % (model_runner.block_size - 1) + 1 + prompt_lens.append(prompt_len) + seq_data = SequenceData(list(range(prompt_len))) + seq_group_metadata = SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=True, + seq_data={0: seq_data}, + sampling_params=SamplingParams(temperature=0), + block_tables=block_tables, + ) + assert seq_group_metadata.token_chunk_size == seq_data.get_len() + seq_group_metadata_list.append(seq_group_metadata) + prefill_metadata_list.append(seq_group_metadata) + + # Add decode requests + for i in range(prefill_batch_size, batch_size): + # make sure all tokens fit into one block + prompt_len = i % (model_runner.block_size - 1) + 1 + prompt_toks = list(range(prompt_len)) + seq_data = SequenceData(prompt_toks) + seq_group_metadata = SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=False, + seq_data={0: seq_data}, + sampling_params=SamplingParams(temperature=0), + block_tables={0: [1]}, + ) + assert seq_group_metadata.token_chunk_size == 1 + seq_group_metadata_list.append(seq_group_metadata) + decode_metadata_list.append(seq_group_metadata) + + (input_tokens, input_positions, attn_metadata, _, _, _, + _) = model_runner.prepare_input_tensors(seq_group_metadata_list) + + prefill_meta_actual = attn_metadata.prefill_metadata + decode_meta_actual = attn_metadata.decode_metadata + + assert len(attn_metadata.slot_mapping) == len(input_tokens) + assert len(input_positions) == len(input_tokens) + assert attn_metadata.kv_cache_dtype == "auto" + assert attn_metadata.num_prefills == prefill_batch_size + if enforce_eager: + assert attn_metadata.num_decode_tokens == decode_batch_size + else: + assert attn_metadata.num_decode_tokens == _get_graph_batch_size( + decode_batch_size) + assert attn_metadata.num_prefill_tokens == sum(prompt_lens) + + # Verify attn metadata is consistent. We don't need to test individual + # values here because they are tested above. + prefill_meta = model_runner._prepare_prompt( + prefill_metadata_list).attn_metadata + decode_meta = model_runner._prepare_decode( + decode_metadata_list).attn_metadata + + for attr_expected, attr_actual in zip(vars(prefill_meta), + vars(prefill_meta_actual)): + assert attr_expected[1] == attr_actual[1] + for attr_expected, attr_actual in zip(vars(decode_meta), + vars(decode_meta_actual)): + assert attr_expected[1] == attr_actual[1] diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py index 9acb82c0df2c2..7636b34a16fed 100644 --- a/vllm/attention/__init__.py +++ b/vllm/attention/__init__.py @@ -1,5 +1,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, - AttentionMetadata) + AttentionMetadata, + AttentionMetadataPerStage) from vllm.attention.layer import Attention from vllm.attention.selector import get_attn_backend @@ -8,4 +9,5 @@ __all__ = [ "AttentionMetadata", "Attention", "get_attn_backend", + "AttentionMetadataPerStage", ] diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index a03cf2dd7a6fa..7a4ccecf702f4 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, fields -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar import torch @@ -47,7 +47,8 @@ class AttentionBackend(ABC): @dataclass -class AttentionMetadata: +class AttentionMetadataPerStage: + """Attention metadata for a specific stage. I.e., prefill or decode.""" def asdict_zerocopy(self) -> Dict[str, Any]: """Similar to dataclasses.asdict, but avoids deepcopying.""" @@ -59,6 +60,41 @@ class AttentionMetadata: } +T = TypeVar("T", bound=AttentionMetadataPerStage) + + +@dataclass +class AttentionMetadata(Generic[T]): + """Attention metadata for prefill and decode batched together.""" + # Total number of prefill requests. + num_prefills: int + # Number of prefill tokens. + num_prefill_tokens: int + # Number of decode tokens. Note that it is equivalent to the number of + # decode requests. + num_decode_tokens: int + # The attention metadata for prefill requests in a batch. + # None if there's no prefill requests in a batch. + prefill_metadata: Optional[T] + # The attention metadata for decode requests in a batch. + # None if there's no decode requests in a batch. + decode_metadata: Optional[T] + # (num_tokens,). The indices of the token slots that input tokens will be + # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size + # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot + # in block 0, and 1st slot in block 1, respectively. + slot_mapping: torch.Tensor + # The kv cache's data type. + kv_cache_dtype: str + + def __post_init__(self): + if self.num_prefill_tokens > 0: + assert self.num_prefills > 0 + assert self.prefill_metadata is not None + if self.num_decode_tokens > 0: + assert self.decode_metadata is not None + + class AttentionImpl(ABC): @abstractmethod @@ -80,7 +116,7 @@ class AttentionImpl(ABC): key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, + attn_metadata: AttentionMetadata[AttentionMetadataPerStage], kv_scale: float, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 4e0d9d1418b32..12e8c4404b94e 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -11,7 +11,8 @@ import torch from flash_attn import flash_attn_varlen_func from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, + AttentionMetadataPerStage) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) @@ -53,7 +54,8 @@ class FlashAttentionBackend(AttentionBackend): @dataclass -class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): +class FlashAttentionMetadata(AttentionMetadataPerStage, + PagedAttentionMetadata): """Metadata for FlashAttentionBackend. NOTE: Any python object stored here is not updated when it is @@ -68,10 +70,6 @@ class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): prompt_lens: Optional[List[int]] # prompt_lens stored as a tensor. prompt_lens_tensor: Optional[torch.Tensor] - # The number of prompt tokens. Doesn't include padding. - num_prompt_tokens: int - # The number of generation tokens. Doesn't include padding. - num_generation_tokens: int # NOTE(sang): Definition of context_len, subquery_len, and seqlen. # |---------- N-1 iteration --------| @@ -107,18 +105,27 @@ class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): class FlashAttentionImpl(AttentionImpl): """ If the input tensors contain prompt tokens, the layout is as follows: - |<--------------- num_prompt_tokens -------------->| - |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->| + |<--------------- num_prefill_tokens ----------------->| + |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| Otherwise, the layout is as follows: - |<------------------ num_generation_tokens (M) ----------------->| - |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| + |<----------------- num_decode_tokens ------------------>| + |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| Generation tokens can contain padding when cuda-graph is used. Currently, prompt tokens don't contain any padding. The prompts might have different lengths, while the generation tokens always have length 1. + + If chunked prefill is enabled, prefill tokens and decode tokens can be + batched together in a flattened 1D query. + + |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->| + |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->| + + Currently, cuda graph is disabled for chunked prefill, meaning there's no + padding between prefill and decode tokens. """ def __init__( @@ -155,7 +162,7 @@ class FlashAttentionImpl(AttentionImpl): key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: FlashAttentionMetadata, + attn_metadata: AttentionMetadata[FlashAttentionMetadata], kv_scale: float, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. @@ -188,52 +195,70 @@ class FlashAttentionImpl(AttentionImpl): attn_metadata.kv_cache_dtype, kv_scale) - if attn_metadata.is_prompt: + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + assert key.shape[0] == num_prefill_tokens + num_decode_tokens + assert value.shape[0] == num_prefill_tokens + num_decode_tokens + + output = torch.empty_like(query) + # Query for decode. KV is not needed because it is already cached. + decode_query = query[num_prefill_tokens:] + # QKV for prefill. + query = query[:num_prefill_tokens] + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + + assert query.shape[0] == num_prefill_tokens + assert decode_query.shape[0] == num_decode_tokens + + if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. - if kv_cache is None or attn_metadata.block_tables.numel() == 0: + if kv_cache is None or prefill_meta.block_tables.numel() == 0: # normal attention # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. - output = flash_attn_varlen_func( + out = flash_attn_varlen_func( q=query, k=key, v=value, - cu_seqlens_q=attn_metadata.seq_start_loc, - cu_seqlens_k=attn_metadata.seq_start_loc, - max_seqlen_q=attn_metadata.max_prompt_len, - max_seqlen_k=attn_metadata.max_prompt_len, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_q=prefill_meta.max_prompt_len, + max_seqlen_k=prefill_meta.max_prompt_len, softmax_scale=self.scale, causal=True, window_size=self.sliding_window, alibi_slopes=self.alibi_slopes, ) + assert output[:num_prefill_tokens].shape == out.shape + output[:num_prefill_tokens] = out else: # prefix-enabled attention # TODO(Hai) this triton kernel has regression issue (broke) to # deal with different data types between KV and FP8 KV cache, # to be addressed separately. - output = PagedAttention.forward_prefix( + output[:num_prefill_tokens] = PagedAttention.forward_prefix( query, key, value, key_cache, value_cache, - attn_metadata.block_tables, - attn_metadata.subquery_start_loc, - attn_metadata.prompt_lens_tensor, - attn_metadata.context_lens, - attn_metadata.max_subquery_len, + prefill_meta.block_tables, + prefill_meta.subquery_start_loc, + prefill_meta.prompt_lens_tensor, + prefill_meta.context_lens, + prefill_meta.max_subquery_len, self.alibi_slopes, ) - else: + if decode_meta := attn_metadata.decode_metadata: # Decoding run. - output = PagedAttention.forward_decode( - query, + output[num_prefill_tokens:] = PagedAttention.forward_decode( + decode_query, key_cache, value_cache, - attn_metadata.block_tables, - attn_metadata.context_lens, - attn_metadata.max_context_len, + decode_meta.block_tables, + decode_meta.context_lens, + decode_meta.max_context_len, attn_metadata.kv_cache_dtype, self.num_kv_heads, self.scale, diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 6019d917b4494..e55435cd2c947 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -6,7 +6,8 @@ from typing import Dict, List, Optional, Tuple, Type import torch from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, + AttentionMetadataPerStage) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger @@ -51,7 +52,8 @@ class ROCmFlashAttentionBackend(AttentionBackend): @dataclass -class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): +class ROCmFlashAttentionMetadata(AttentionMetadataPerStage, + PagedAttentionMetadata): """Metadata for FlashAttentionBackend. NOTE: Any python object stored here is not updated when it is @@ -66,10 +68,6 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): prompt_lens: Optional[List[int]] # prompt_lens stored as a tensor. prompt_lens_tensor: Optional[torch.Tensor] - # The number of prompt tokens. Doesn't include padding. - num_prompt_tokens: int - # The number of generation tokens. Doesn't include padding. - num_generation_tokens: int # NOTE(sang): Definition of context_len, subquery_len, and seqlen. # |---------- N-1 iteration --------| @@ -117,6 +115,15 @@ class ROCmFlashAttentionImpl(AttentionImpl): The prompts might have different lengths, while the generation tokens always have length 1. + + If chunked prefill is enabled, prefill tokens and decode tokens can be + batched together in a flattened 1D query. + + |<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->| + |<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->| + + Currently, cuda graph is disabled for chunked prefill, meaning there's no + padding between prefill and decode tokens. """ def __init__( @@ -181,7 +188,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: ROCmFlashAttentionMetadata, + attn_metadata: AttentionMetadata[ROCmFlashAttentionMetadata], kv_scale: float = 1.0, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. @@ -218,9 +225,25 @@ class ROCmFlashAttentionImpl(AttentionImpl): kv_scale, ) - if attn_metadata.is_prompt: + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + assert key.shape[0] == num_prefill_tokens + num_decode_tokens + assert value.shape[0] == num_prefill_tokens + num_decode_tokens + + output = torch.empty_like(query) + # Query for decode. KV is not needed because it is already cached. + decode_query = query[num_prefill_tokens:] + # QKV for prefill. + query = query[:num_prefill_tokens] + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + + assert query.shape[0] == num_prefill_tokens + assert decode_query.shape[0] == num_decode_tokens + + if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. - if kv_cache is None or attn_metadata.block_tables.numel() == 0: + if kv_cache is None or prefill_meta.block_tables.numel() == 0: # triton attention # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. @@ -230,63 +253,69 @@ class ROCmFlashAttentionImpl(AttentionImpl): key = self.repeat_kv(key, self.num_queries_per_kv) value = self.repeat_kv(value, self.num_queries_per_kv) if self.use_naive_attn: - output = self.attn_fuc( + out = self.attn_fuc( query, key, value, - attn_metadata.prompt_lens, + prefill_meta.prompt_lens, self.scale, ) + assert output[:num_prefill_tokens].shape == out.shape + output[:num_prefill_tokens] = out else: - output, _ = self.attn_func( + out, _ = self.attn_func( query, key, value, None, - attn_metadata.seq_start_loc, - attn_metadata.seq_start_loc, - attn_metadata.max_prompt_len, - attn_metadata.max_prompt_len, + prefill_meta.seq_start_loc, + prefill_meta.seq_start_loc, + prefill_meta.max_prompt_len, + prefill_meta.max_prompt_len, True, self.scale, ) + assert output[:num_prefill_tokens].shape == out.shape + output[:num_prefill_tokens] = out else: - output = self.attn_func( + out = self.attn_func( q=query, k=key, v=value, - cu_seqlens_q=attn_metadata.seq_start_loc, - cu_seqlens_k=attn_metadata.seq_start_loc, - max_seqlen_q=attn_metadata.max_prompt_len, - max_seqlen_k=attn_metadata.max_prompt_len, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_q=prefill_meta.max_prompt_len, + max_seqlen_k=prefill_meta.max_prompt_len, softmax_scale=self.scale, causal=True, ) - + assert output[:num_prefill_tokens].shape == out.shape + output[:num_prefill_tokens] = out else: # prefix-enabled attention - output = PagedAttention.forward_prefix( + output[:num_prefill_tokens] = PagedAttention.forward_prefix( query, key, value, key_cache, value_cache, - attn_metadata.block_tables, - attn_metadata.subquery_start_loc, - attn_metadata.prompt_lens_tensor, - attn_metadata.context_lens, - attn_metadata.max_subquery_len, + prefill_meta.block_tables, + prefill_meta.subquery_start_loc, + prefill_meta.prompt_lens_tensor, + prefill_meta.context_lens, + prefill_meta.max_subquery_len, self.alibi_slopes, ) - else: + + if decode_meta := attn_metadata.decode_metadata: # Decoding run. - output = PagedAttention.forward_decode( - query, + output[num_prefill_tokens:] = PagedAttention.forward_decode( + decode_query, key_cache, value_cache, - attn_metadata.block_tables, - attn_metadata.context_lens, - attn_metadata.max_context_len, + decode_meta.block_tables, + decode_meta.context_lens, + decode_meta.max_context_len, attn_metadata.kv_cache_dtype, self.num_kv_heads, self.scale, diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 9706e1910cb79..63904ea929870 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -7,7 +7,8 @@ import torch from torch.nn.functional import scaled_dot_product_attention from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, + AttentionMetadataPerStage) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) @@ -49,17 +50,14 @@ class TorchSDPABackend(AttentionBackend): @dataclass -class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): +class TorchSDPAMetadata(AttentionMetadataPerStage, PagedAttentionMetadata): """Metadata for TorchSDPABackend. """ # Currently, input sequences can only contain all prompts # or all decoding. True if all sequences are prompts. is_prompt: bool - slot_mapping: torch.Tensor prompt_lens: Optional[List[int]] prompt_lens_tensor: Optional[torch.Tensor] - num_prompt_tokens: int - num_generation_tokens: int max_subquery_len: Optional[int] = None max_prompt_len: Optional[int] = None @@ -113,7 +111,7 @@ class TorchSDPABackendImpl(AttentionImpl): key: torch.Tensor, value: torch.Tensor, kv_cache: Optional[torch.Tensor], - attn_metadata: TorchSDPAMetadata, + attn_metadata: AttentionMetadata[TorchSDPAMetadata], kv_scale: float, ) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. @@ -142,36 +140,51 @@ class TorchSDPABackendImpl(AttentionImpl): attn_metadata.kv_cache_dtype, kv_scale) - if attn_metadata.is_prompt: - if (kv_cache is None or attn_metadata.block_tables.numel() == 0): + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + assert key.shape[0] == num_prefill_tokens + num_decode_tokens + assert value.shape[0] == num_prefill_tokens + num_decode_tokens + + output = torch.empty_like(query) + # Query for decode. KV is not needed because it is already cached. + decode_query = query[num_prefill_tokens:] + # QKV for prefill. + query = query[:num_prefill_tokens] + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + + assert query.shape[0] == num_prefill_tokens + assert decode_query.shape[0] == num_decode_tokens + + if prefill_meta := attn_metadata.prefill_metadata: + if (kv_cache is None or prefill_meta.block_tables.numel() == 0): if self.num_kv_heads != self.num_heads: key = key.repeat_interleave(self.num_queries_per_kv, dim=1) value = value.repeat_interleave(self.num_queries_per_kv, dim=1) - if attn_metadata.attn_bias is None: + if prefill_meta.attn_bias is None: if self.alibi_slopes is not None: att_masks = _make_alibi_bias( self.alibi_slopes, query.dtype, - attn_metadata.prompt_lens) # type: ignore + prefill_meta.prompt_lens) # type: ignore elif self.sliding_window is not None: att_masks = _make_sliding_window_bias( - attn_metadata.prompt_lens, self.sliding_window, + prefill_meta.prompt_lens, self.sliding_window, query.dtype) # type: ignore else: - att_masks = [None] * len(attn_metadata.prompt_lens) - attn_metadata.attn_bias = att_masks + att_masks = [None] * len(prefill_meta.prompt_lens) + prefill_meta.attn_bias = att_masks query = query.movedim(0, query.dim() - 2) key = key.movedim(0, key.dim() - 2) value = value.movedim(0, value.dim() - 2) start = 0 - output = torch.empty( - (num_tokens, self.num_heads, self.head_size), - dtype=query.dtype) - for prompt_len, mask in zip(attn_metadata.prompt_lens, - attn_metadata.attn_bias): + out = torch.empty((num_tokens, self.num_heads, self.head_size), + dtype=query.dtype) + for prompt_len, mask in zip(prefill_meta.prompt_lens, + prefill_meta.attn_bias): end = start + prompt_len sub_out = scaled_dot_product_attention( query[:, start:end, :], @@ -181,28 +194,32 @@ class TorchSDPABackendImpl(AttentionImpl): dropout_p=0.0, is_causal=not self.need_mask, scale=self.scale).movedim(query.dim() - 2, 0) - output[start:end, :, :] = sub_out + out[start:end, :, :] = sub_out start = end + assert out.shape == output[:num_prefill_tokens].shape + output[:num_prefill_tokens] = out else: # prefix-enabled attention raise RuntimeError( "Torch SDPA backend doesn't support prefix decoding.") - else: + if decode_meta := attn_metadata.decode_metadata: # Decoding run. - output = PagedAttention.forward_decode( - query, + out = PagedAttention.forward_decode( + decode_query, key_cache, value_cache, - attn_metadata.block_tables, - attn_metadata.context_lens, - attn_metadata.max_context_len, + decode_meta.block_tables, + decode_meta.context_lens, + decode_meta.max_context_len, attn_metadata.kv_cache_dtype, self.num_kv_heads, self.scale, self.alibi_slopes, kv_scale, ) + assert out.shape == output[num_prefill_tokens:].shape + output[num_prefill_tokens:] # Reshape the output tensor. return output.view(-1, self.num_heads * self.head_size) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 05b68bba5e6eb..b745a04a143b4 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -9,7 +9,8 @@ from xformers.ops.fmha.attn_bias import (AttentionBias, LowerTriangularMaskWithTensorBias) from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, + AttentionMetadataPerStage) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger @@ -54,7 +55,7 @@ class XFormersBackend(AttentionBackend): @dataclass -class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): +class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata): """Metadata for XFormersbackend. NOTE: Any python object stored here is not updated when it is @@ -65,19 +66,10 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): # Currently, input sequences can only contain all prompts # or all decoding. True if all sequences are prompts. is_prompt: bool - # (num_tokens,). The indices of the token slots that input tokens will be - # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size - # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot - # in block 0, and 1st slot in block 1, respectively. - slot_mapping: torch.Tensor # (batch_size,). The prompt length per sequence. None if it is a decoding. prompt_lens: Optional[List[int]] # prompt_lens stored as a tensor. prompt_lens_tensor: Optional[torch.Tensor] - # The number of prompt tokens. Doesn't include padding. - num_prompt_tokens: int - # The number of generation tokens. Doesn't include padding. - num_generation_tokens: int # NOTE(sang): Definition of context_len, subquery_len, and seqlen. # |---------- N-1 iteration --------| @@ -123,18 +115,27 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): class XFormersImpl(AttentionImpl): """ If the input tensors contain prompt tokens, the layout is as follows: - |<--------------- num_prompt_tokens --------------->| - |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1--->| + |<--------------- num_prefill_tokens ----------------->| + |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| Otherwise, the layout is as follows: - |<------------------ num_generation_tokens (M) ----------------->| - |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| + |<----------------- num_decode_tokens ------------------>| + |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| Generation tokens can contain padding when cuda-graph is used. Currently, prompt tokens don't contain any padding. The prompts might have different lengths, while the generation tokens always have length 1. + + If chunked prefill is enabled, prefill tokens and decode tokens can be + batched together in a flattened 1D query. + + |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->| + |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->| + + Currently, cuda graph is disabled for chunked prefill, meaning there's no + padding between prefill and decode tokens. """ def __init__( @@ -170,7 +171,7 @@ class XFormersImpl(AttentionImpl): key: torch.Tensor, value: torch.Tensor, kv_cache: Optional[torch.Tensor], - attn_metadata: XFormersMetadata, + attn_metadata: AttentionMetadata[XFormersMetadata], kv_scale: float, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. @@ -202,59 +203,61 @@ class XFormersImpl(AttentionImpl): attn_metadata.kv_cache_dtype, kv_scale) - if attn_metadata.is_prompt: + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + assert key.shape[0] == num_prefill_tokens + num_decode_tokens + assert value.shape[0] == num_prefill_tokens + num_decode_tokens + + output = torch.empty_like(query) + # Query for decode. KV is not needed because it is already cached. + decode_query = query[num_prefill_tokens:] + # QKV for prefill. + query = query[:num_prefill_tokens] + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + + assert query.shape[0] == num_prefill_tokens + assert decode_query.shape[0] == num_decode_tokens + + if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. - if kv_cache is None or attn_metadata.block_tables.numel() == 0: + if kv_cache is None or prefill_meta.block_tables.numel() == 0: # normal attention. # block tables are empty if the prompt does not have a cached # prefix. - if self.num_kv_heads != self.num_heads: - # As of Nov 2023, xformers only supports MHA. For MQA/GQA, - # project the key and value tensors to the desired number of - # heads. - # TODO(woosuk): Use MQA/GQA kernels for higher performance. - query = query.view(query.shape[0], self.num_kv_heads, - self.num_queries_per_kv, - query.shape[-1]) - key = key[:, :, - None, :].expand(key.shape[0], self.num_kv_heads, - self.num_queries_per_kv, - key.shape[-1]) - value = value[:, :, - None, :].expand(value.shape[0], - self.num_kv_heads, - self.num_queries_per_kv, - value.shape[-1]) - - output = self._run_memory_efficient_xformers_forward( - query, key, value, attn_metadata) + out = self._run_memory_efficient_xformers_forward( + query, key, value, prefill_meta) + assert out.shape == output[:num_prefill_tokens].shape + output[:num_prefill_tokens] = out else: # prefix-enabled attention # TODO(Hai) this triton kernel has regression issue (broke) to # deal with different data types between KV and FP8 KV cache, # to be addressed separately. - output = PagedAttention.forward_prefix( + out = PagedAttention.forward_prefix( query, key, value, key_cache, value_cache, - attn_metadata.block_tables, - attn_metadata.subquery_start_loc, - attn_metadata.prompt_lens_tensor, - attn_metadata.context_lens, - attn_metadata.max_subquery_len, + prefill_meta.block_tables, + prefill_meta.subquery_start_loc, + prefill_meta.prompt_lens_tensor, + prefill_meta.context_lens, + prefill_meta.max_subquery_len, self.alibi_slopes, ) - else: - # Decoding run. - output = PagedAttention.forward_decode( - query, + assert output[:num_prefill_tokens].shape == out.shape + output[:num_prefill_tokens] = out + + if decode_meta := attn_metadata.decode_metadata: + output[num_prefill_tokens:] = PagedAttention.forward_decode( + decode_query, key_cache, value_cache, - attn_metadata.block_tables, - attn_metadata.context_lens, - attn_metadata.max_context_len, + decode_meta.block_tables, + decode_meta.context_lens, + decode_meta.max_context_len, attn_metadata.kv_cache_dtype, self.num_kv_heads, self.scale, @@ -275,13 +278,30 @@ class XFormersImpl(AttentionImpl): """Attention for 1D query of multiple prompts. Multiple prompt tokens are flattened in to `query` input. + See https://facebookresearch.github.io/xformers/components/ops.html + for API spec. + Args: - output: shape = [num_prompt_tokens, num_heads, head_size] - query: shape = [num_prompt_tokens, num_heads, head_size] - key: shape = [num_prompt_tokens, num_kv_heads, head_size] - value: shape = [num_prompt_tokens, num_kv_heads, head_size] + output: shape = [num_prefill_tokens, num_heads, head_size] + query: shape = [num_prefill_tokens, num_heads, head_size] + key: shape = [num_prefill_tokens, num_kv_heads, head_size] + value: shape = [num_prefill_tokens, num_kv_heads, head_size] attn_metadata: Metadata for attention. """ + original_query = query + if self.num_kv_heads != self.num_heads: + # GQA/MQA requires the shape [B, M, G, H, K]. + # Note that the output also has the same shape (which is different + # from a spec from the doc). + query = query.view(query.shape[0], self.num_kv_heads, + self.num_queries_per_kv, query.shape[-1]) + key = key[:, :, + None, :].expand(key.shape[0], self.num_kv_heads, + self.num_queries_per_kv, key.shape[-1]) + value = value[:, :, + None, :].expand(value.shape[0], self.num_kv_heads, + self.num_queries_per_kv, + value.shape[-1]) # Set attention bias if not provided. This typically happens at # the very attention layer of every iteration. # FIXME(woosuk): This is a hack. @@ -302,6 +322,7 @@ class XFormersImpl(AttentionImpl): # TODO(woosuk): Too many view operations. Let's try to reduce # them in the future for code readability. if self.alibi_slopes is None: + # Add the batch dimension. query = query.unsqueeze(0) key = key.unsqueeze(0) value = value.unsqueeze(0) @@ -312,14 +333,13 @@ class XFormersImpl(AttentionImpl): attn_bias=attn_metadata.attn_bias[0], p=0.0, scale=self.scale) - - return out.view_as(query) + return out.view_as(original_query) # Attention with alibi slopes. # FIXME(woosuk): Because xformers does not support dynamic sequence # lengths with custom attention bias, we process each prompt one by # one. This is inefficient, especially when we have many short prompts. - output = torch.empty_like(query) + output = torch.empty_like(original_query) start = 0 for i, prompt_len in enumerate(attn_metadata.prompt_lens): end = start + prompt_len @@ -331,7 +351,7 @@ class XFormersImpl(AttentionImpl): p=0.0, scale=self.scale) # TODO(woosuk): Unnecessary copy. Optimize. - output[start:end].copy_(out.squeeze(0)) + output[start:end].copy_(out.view_as(original_query[start:end])) start += prompt_len return output diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 9856654fc5f94..fc65ae108dbb1 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -4,7 +4,8 @@ from typing import List, Optional import torch import torch.nn as nn -from vllm.attention.backends.abstract import AttentionMetadata +from vllm.attention.backends.abstract import (AttentionMetadata, + AttentionMetadataPerStage) from vllm.attention.selector import get_attn_backend @@ -41,7 +42,7 @@ class Attention(nn.Module): key: torch.Tensor, value: torch.Tensor, kv_cache: Optional[torch.Tensor], - attn_metadata: AttentionMetadata, + attn_metadata: AttentionMetadata[AttentionMetadataPerStage], kv_scale: float = 1.0, ) -> torch.Tensor: return self.impl.forward(query, key, value, kv_cache, attn_metadata, diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 256bffdf032eb..2d918491d6576 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -13,11 +13,6 @@ _PARTITION_SIZE = 512 @dataclass class PagedAttentionMetadata: """Metadata for PagedAttention.""" - # (num_tokens,). The indices of the token slots that input tokens will be - # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size - # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot - # in block 0, and 1st slot in block 1, respectively. - slot_mapping: torch.Tensor # (batch_size,). The length of context (tokens stored in KV cache) per # sequence. WARNING: When it is a prefill request, it doesn't include new # tokens. When it is for decoding, it includes a new token. @@ -31,7 +26,6 @@ class PagedAttentionMetadata: # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph # captured. block_tables: Optional[torch.Tensor] - kv_cache_dtype: str class PagedAttention: diff --git a/vllm/config.py b/vllm/config.py index bca250e922288..4102edbe01d35 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -565,9 +565,16 @@ class SchedulerConfig: if max_num_batched_tokens is not None: self.max_num_batched_tokens = max_num_batched_tokens else: - # If max_model_len is too short, use 2048 as the default value for - # higher throughput. - self.max_num_batched_tokens = max(max_model_len, 2048) + if enable_chunked_prefill: + # For chunked prefill, choose the well-tuned batch size. + self.max_num_batched_tokens = 768 + else: + # If max_model_len is too short, use 2048 as the default value + # for higher throughput. + self.max_num_batched_tokens = max(max_model_len, 2048) + if enable_chunked_prefill: + logger.info("Chunked prefill is enabled (EXPERIMENTAL).") + self.max_num_seqs = max_num_seqs self.max_model_len = max_model_len self.use_v2_block_manager = use_v2_block_manager diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 0ae53f9374960..2942eab735a92 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -140,7 +140,11 @@ class SchedulerOutputs: @property def lora_requests(self) -> Set[LoRARequest]: - return {g.seq_group.lora_request for g in self.scheduled_seq_groups} + return { + g.seq_group.lora_request + for g in self.scheduled_seq_groups + if g.seq_group.lora_request is not None + } @dataclass @@ -826,13 +830,12 @@ class Scheduler: # Update swapped requests. self.swapped = remaining_swapped self.swapped.extend(running_scheduled.swapped_out) - return SchedulerOutputs( scheduled_seq_groups=(prefills.seq_groups + - running_scheduled.decode_seq_groups + running_scheduled.prefill_seq_groups + - swapped_in.decode_seq_groups + - swapped_in.prefill_seq_groups), + swapped_in.prefill_seq_groups + + running_scheduled.decode_seq_groups + + swapped_in.decode_seq_groups), num_prefill_groups=(len(prefills.seq_groups) + len(swapped_in.prefill_seq_groups) + len(running_scheduled.prefill_seq_groups)), @@ -907,7 +910,7 @@ class Scheduler: # It assumes the scheduled_seq_groups is ordered by # prefill < decoding. - is_prompt = i < scheduler_outputs.num_prefill_groups + is_prompt = seq_group.is_prefill() seq_group_metadata = SequenceGroupMetadata( request_id=seq_group.request_id, is_prompt=is_prompt, diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index cf15db099b304..1004d626b6a4b 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -173,10 +173,18 @@ def broadcast_tensor_dict( torch.distributed.broadcast_object_list([metadata_list], src=src, group=group) + async_handles = [] for key, value in metadata_list: if isinstance(value, TensorMetadata): tensor = tensor_dict[key] - torch.distributed.broadcast(tensor, src=src, group=group) + async_handles.append( + torch.distributed.broadcast(tensor, + src=src, + group=group, + async_op=True)) + for async_handle in async_handles: + async_handle.wait() + else: recv_metadata_list = [None] torch.distributed.broadcast_object_list(recv_metadata_list, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index d4b573992c06c..daefddc01b431 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -386,9 +386,8 @@ class EngineArgs: 'prompt latency) before scheduling next prompt.') parser.add_argument( '--enable-chunked-prefill', - type=bool, - default=False, - help='If True, the prefill requests can be chunked based on the ' + action='store_true', + help='If set, the prefill requests can be chunked based on the ' 'max_num_batched_tokens') parser.add_argument( diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 1c639af696544..ddfdda898a5c6 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -633,7 +633,10 @@ class LLMEngine: seq_group = scheduled_seq_group.seq_group seq_group.update_num_computed_tokens( scheduled_seq_group.token_chunk_size) - self._process_sequence_group_outputs(seq_group, outputs) + # If uncomputed tokens > 0, it means prefill is chunked. + # We don't need to process outputs in that case. + if seq_group.get_num_uncomputed_tokens() == 0: + self._process_sequence_group_outputs(seq_group, outputs) # Free the finished sequence groups. self.scheduler.free_finished_seq_groups() diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index dd33868f76302..84a94091486d7 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -267,12 +267,13 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): def forward(self, x: torch.Tensor) -> torch.Tensor: added_tokens_mask = x > self.base_layer.org_vocab_size - 1 - indices = self.embeddings_indices[1][:self.indices_len[3]].view_as(x) + embedding_len = self.indices_len[3] + indices = self.embeddings_indices[1][:embedding_len].view_as(x) full_lora_a_embeddings = F.embedding( x + indices, self.lora_a_stacked_2d, ) - indices = self.embeddings_indices[0][:self.indices_len[3]].view_as(x) + indices = self.embeddings_indices[0][:embedding_len].view_as(x) full_output = self.base_layer.forward( x.add_(indices * added_tokens_mask)) diff --git a/vllm/sequence.py b/vllm/sequence.py index 576bbe8c4f6c4..77029908c2218 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -500,7 +500,8 @@ class SequenceGroup: def get_num_uncomputed_tokens(self) -> int: num_uncomputed_tokens = 0 for seq in self.get_seqs(): - num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens() + if not seq.is_finished(): + num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens() return num_uncomputed_tokens def num_seqs(self, status: Optional[SequenceStatus] = None) -> int: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 1de4748b7bcc9..47ad8f0c9b78b 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,12 +1,14 @@ import contextlib import time -from typing import Dict, List, Optional, Set, Tuple +from enum import IntEnum +from typing import Dict, List, NamedTuple, Optional, Set, Tuple import numpy as np import torch import torch.nn as nn -from vllm.attention import AttentionMetadata, get_attn_backend +from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage, + get_attn_backend) from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce @@ -37,6 +39,66 @@ _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [ ] +class PreparePromptMetadata(NamedTuple): + input_tokens: List[int] + input_positions: List[int] + attn_metadata: Optional[AttentionMetadataPerStage] + prompt_lens: List[int] + subquery_lens: List[int] + lora_index_mapping: List[int] + lora_prompt_mapping: List[int] + lora_requests: Set[LoRARequest] + multi_modal_input: Optional[torch.Tensor] + slot_mapping: List[int] + + @classmethod + def empty(cls): + return PreparePromptMetadata( + input_tokens=[], + input_positions=[], + attn_metadata=None, + prompt_lens=[], + subquery_lens=[], + lora_index_mapping=[], + lora_prompt_mapping=[], + lora_requests=set(), + multi_modal_input=None, + slot_mapping=[], + ) + + +class PrepareDecodeMetadata(NamedTuple): + input_tokens: List[int] + input_positions: List[int] + attn_metadata: Optional[AttentionMetadata] + lora_index_mapping: List[int] + lora_prompt_mapping: List[int] + lora_requests: Set[LoRARequest] + slot_mapping: List[int] + + @classmethod + def empty(cls): + return PrepareDecodeMetadata( + input_tokens=[], + input_positions=[], + attn_metadata=None, + lora_index_mapping=[], + lora_prompt_mapping=[], + lora_requests=set(), + slot_mapping=[], + ) + + +# How batches are constructed. +class BatchType(IntEnum): + # Every batch is prefill. + PREFILL = 0 + # Every batch is decode. + DECODE = 1 + # Batch is a mixture of prefill and decode. + MIXED = 2 + + class ModelRunner: def __init__( @@ -152,10 +214,7 @@ class ModelRunner: def _prepare_prompt( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], - List[int], List[int], List[int], Set[LoRARequest], - torch.Tensor]: - assert len(seq_group_metadata_list) > 0 + ) -> PreparePromptMetadata: input_tokens: List[int] = [] input_positions: List[int] = [] slot_mapping: List[int] = [] @@ -169,6 +228,9 @@ class ModelRunner: prefix_block_tables: List[List[int]] = [] multi_modal_input_list: List[torch.Tensor] = [] + if len(seq_group_metadata_list) == 0: + return PreparePromptMetadata.empty() + for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt seq_ids = list(seq_group_metadata.seq_data.keys()) @@ -178,7 +240,8 @@ class ModelRunner: computed_block_nums = seq_group_metadata.computed_block_nums if (self.scheduler_config is not None and self.scheduler_config.chunked_prefill_enabled - and computed_block_nums is not None): + and not (computed_block_nums is None + or computed_block_nums == [])): raise RuntimeError( "chunked prefill cannot be used with prefix caching " "now.") @@ -190,13 +253,8 @@ class ModelRunner: # it contains output tokens. prefill_end = min(seq_data.get_len(), computed_len + token_chunk_size) - # TODO(sang): Rename it after chunked prefill is introduced. prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end] - prompt_len = len(prompt_tokens) - # Right now, the prefill_end is always same as the length of - # sequence. However, once chunked prefill is introduced, this - # assumption can be changed. - assert prefill_end == seq_data.get_len() + prompt_len = prefill_end prompt_lens.append(prompt_len) # NOTE: This only works for oooooooxxx style attention. @@ -206,6 +264,14 @@ class ModelRunner: computed_len = len(computed_block_nums) * self.block_size prompt_tokens = prompt_tokens[computed_len:] prefix_block_tables.append(computed_block_nums) + elif self.scheduler_config.chunked_prefill_enabled: + if seq_group_metadata.block_tables is not None: + # Prefill has chunked before. + block_table = seq_group_metadata.block_tables[seq_id] + prefix_block_tables.append(block_table) + else: + # The first prefill. + prefix_block_tables.append([]) else: prefix_block_tables.append([]) # Right now, prefill start is always 0. However, this @@ -267,20 +333,8 @@ class ModelRunner: max_subquery_len = max(subquery_lens) max_prompt_len = max(prompt_lens) - num_prompt_tokens = len(input_tokens) assert max_subquery_len > 0 - input_tokens = torch.tensor(input_tokens, - dtype=torch.long, - device=self.device) - input_positions = torch.tensor(input_positions, - dtype=torch.long, - device=self.device) - slot_mapping = torch.tensor(slot_mapping, - dtype=torch.long, - device=self.device) - lora_index_mapping = lora_index_mapping - context_lens_tensor = torch.tensor(context_lens, dtype=torch.int, device=self.device) @@ -332,11 +386,8 @@ class ModelRunner: attn_metadata = self.attn_backend.make_metadata( is_prompt=True, - slot_mapping=slot_mapping, prompt_lens=prompt_lens, prompt_lens_tensor=prompt_lens_tensor, - num_prompt_tokens=num_prompt_tokens, - num_generation_tokens=0, max_subquery_len=max_subquery_len, max_context_len=None, max_prompt_len=max_prompt_len, @@ -345,18 +396,25 @@ class ModelRunner: context_lens=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, - kv_cache_dtype=self.kv_cache_dtype, ) - return (input_tokens, input_positions, attn_metadata, prompt_lens, - subquery_lens, lora_index_mapping, lora_prompt_mapping, - lora_requests, multi_modal_input) + + return PreparePromptMetadata( + input_tokens=input_tokens, + input_positions=input_positions, + attn_metadata=attn_metadata, + prompt_lens=prompt_lens, + subquery_lens=subquery_lens, + lora_index_mapping=lora_index_mapping, + lora_prompt_mapping=lora_prompt_mapping, + lora_requests=lora_requests, + multi_modal_input=multi_modal_input, + slot_mapping=slot_mapping, + ) def _prepare_decode( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], - List[int], Set[LoRARequest]]: - assert len(seq_group_metadata_list) > 0 + ) -> PrepareDecodeMetadata: input_tokens: List[int] = [] input_positions: List[int] = [] slot_mapping: List[int] = [] @@ -366,6 +424,9 @@ class ModelRunner: lora_prompt_mapping: List[int] = [] lora_requests: Set[LoRARequest] = set() + if len(seq_group_metadata_list) == 0: + return PrepareDecodeMetadata.empty() + for seq_group_metadata in seq_group_metadata_list: assert not seq_group_metadata.is_prompt assert seq_group_metadata.token_chunk_size == 1 @@ -424,15 +485,6 @@ class ModelRunner: lora_index_mapping.append(0) batch_size = graph_batch_size - input_tokens = torch.tensor(input_tokens, - dtype=torch.long, - device=self.device) - input_positions = torch.tensor(input_positions, - dtype=torch.long, - device=self.device) - slot_mapping = torch.tensor(slot_mapping, - dtype=torch.long, - device=self.device) context_lens = torch.tensor(context_lens, dtype=torch.int, device=self.device) @@ -440,9 +492,9 @@ class ModelRunner: if use_captured_graph: # When using cuda-graph all these tensors should be # padded. - assert context_lens.shape[0] == input_tokens.shape[0] - assert context_lens.shape[0] == input_positions.shape[0] - assert context_lens.shape[0] == slot_mapping.shape[0] + assert context_lens.shape[0] == len(input_tokens) + assert context_lens.shape[0] == len(input_positions) + assert context_lens.shape[0] == len(slot_mapping) # The shape of graph_block_tables is # [max batch size, max context len // block size]. @@ -464,11 +516,8 @@ class ModelRunner: attn_metadata = self.attn_backend.make_metadata( is_prompt=False, - slot_mapping=slot_mapping, prompt_lens=None, prompt_lens_tensor=None, - num_prompt_tokens=0, - num_generation_tokens=len(input_tokens), max_subquery_len=None, max_context_len=max_context_len, max_prompt_len=None, @@ -477,10 +526,16 @@ class ModelRunner: context_lens=context_lens, block_tables=block_tables, use_cuda_graph=use_captured_graph, - kv_cache_dtype=self.kv_cache_dtype, ) - return (input_tokens, input_positions, attn_metadata, - lora_index_mapping, lora_prompt_mapping, lora_requests) + return PrepareDecodeMetadata( + input_tokens=input_tokens, + input_positions=input_positions, + attn_metadata=attn_metadata, + lora_index_mapping=lora_index_mapping, + lora_prompt_mapping=lora_prompt_mapping, + lora_requests=lora_requests, + slot_mapping=slot_mapping, + ) def _prepare_sample( self, @@ -586,26 +641,66 @@ class ModelRunner: ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, Set[int], LoRAMapping, torch.Tensor]: if self.is_driver_worker: - # NOTE: We assume that all sequences in the group are all prompts or - # all decodes. - is_prompt = seq_group_metadata_list[0].is_prompt + prefill_reqs = [] + decode_reqs = [] + for seq_group_meta in seq_group_metadata_list: + if seq_group_meta.is_prompt: + prefill_reqs.append(seq_group_meta) + else: + decode_reqs.append(seq_group_meta) + # Prepare input tensors. - if is_prompt: - (input_tokens, input_positions, attn_metadata, prompt_lens, - subquery_lens, lora_index_mapping, lora_prompt_mapping, - lora_requests, multi_modal_input - ) = self._prepare_prompt(seq_group_metadata_list) - else: - (input_tokens, input_positions, attn_metadata, - lora_index_mapping, lora_prompt_mapping, - lora_requests) = self._prepare_decode(seq_group_metadata_list) - prompt_lens = [] - subquery_lens = None - multi_modal_input = None + ( + input_tokens, + input_positions, + prefill_attn_metadata, + prompt_lens, + subquery_lens, + lora_index_mapping, + lora_prompt_mapping, + lora_requests, + multi_modal_input, + slot_mapping, + ) = self._prepare_prompt(prefill_reqs) + ( + decode_input_tokens, + decode_input_positions, + decode_attn_metadata, + decode_lora_index_mapping, + decode_lora_prompt_mapping, + decode_lora_requests, + decode_slot_mapping, + ) = self._prepare_decode(decode_reqs) sampling_metadata = self._prepare_sample(seq_group_metadata_list, prompt_lens, subquery_lens) + if not self.scheduler_config.chunked_prefill_enabled: + assert (len(prefill_reqs) and len(decode_reqs)) == 0 + + num_prefills = len(prompt_lens) + num_prefill_tokens = len(input_tokens) + num_decode_tokens = len(decode_input_tokens) + + # Coalesce tensors. Note that attn_metadata is currently not + # coalesced for simplicity. + input_tokens.extend(decode_input_tokens) + input_positions.extend(decode_input_positions) + slot_mapping.extend(decode_slot_mapping) + lora_index_mapping.extend(decode_lora_index_mapping) + lora_prompt_mapping.extend(decode_lora_prompt_mapping) + lora_requests.update(decode_lora_requests) + + input_tokens = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) + input_positions = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) + slot_mapping = torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) + if self.lora_config: lora_mapping = LoRAMapping( lora_index_mapping, @@ -615,6 +710,16 @@ class ModelRunner: lora_mapping = None # Broadcast the metadata. + # If batch contains both prefill and decode, it sends 2 broadcasts. + # If it only contains 1 type, it triggers a single broadcast. + if (prefill_attn_metadata is not None + and decode_attn_metadata is not None): + batch_type = BatchType.MIXED + elif prefill_attn_metadata is not None: + batch_type = BatchType.PREFILL + else: + batch_type = BatchType.DECODE + metadata_dict = { "input_tokens": input_tokens, "input_positions": input_positions, @@ -623,19 +728,49 @@ class ModelRunner: "lora_requests": lora_requests, "lora_mapping": lora_mapping, "multi_modal_input": multi_modal_input, + "num_prefill_tokens": num_prefill_tokens, + "num_decode_tokens": num_decode_tokens, + "slot_mapping": slot_mapping, + "num_prefills": num_prefills, + "batch_type": batch_type, } - metadata_dict.update(attn_metadata.asdict_zerocopy()) + if prefill_attn_metadata is not None: + metadata_dict.update(prefill_attn_metadata.asdict_zerocopy()) + else: + metadata_dict.update(decode_attn_metadata.asdict_zerocopy()) broadcast_tensor_dict(metadata_dict, src=0) + + # Broadcast decode attn metadata for mixed batch type. + # The additional broadcast costs 300us overhead on 4 A10 GPUs. + # We can potentially reduce the overhead by coelescing tensors. + if batch_type == BatchType.MIXED: + assert decode_attn_metadata is not None + metadata_dict = decode_attn_metadata.asdict_zerocopy() + broadcast_tensor_dict(metadata_dict, src=0) else: metadata_dict = broadcast_tensor_dict(src=0) input_tokens = metadata_dict.pop("input_tokens") input_positions = metadata_dict.pop("input_positions") + slot_mapping = metadata_dict.pop("slot_mapping") + num_prefills = metadata_dict.pop("num_prefills") selected_token_indices = metadata_dict.pop( "selected_token_indices") lora_mapping = metadata_dict.pop("lora_mapping") lora_requests = metadata_dict.pop("lora_requests") multi_modal_input = metadata_dict.pop("multi_modal_input") - attn_metadata = self.attn_backend.make_metadata(**metadata_dict) + num_prefill_tokens = metadata_dict.pop("num_prefill_tokens") + num_decode_tokens = metadata_dict.pop("num_decode_tokens") + batch_type = metadata_dict.pop("batch_type") + + # Create an attention metadata. + prefill_attn_metadata = None + decode_attn_metadata = None + if batch_type == BatchType.PREFILL or batch_type == BatchType.MIXED: + prefill_attn_metadata = self.attn_backend.make_metadata( + **metadata_dict) + else: + decode_attn_metadata = self.attn_backend.make_metadata( + **metadata_dict) sampling_metadata = SamplingMetadata( seq_groups=None, seq_data=None, @@ -646,6 +781,23 @@ class ModelRunner: perform_sampling=False, ) + # if it is a mixed batch, decode attn_metadata is broadcasted + # separately. + if batch_type == BatchType.MIXED: + metadata_dict = broadcast_tensor_dict(src=0) + decode_attn_metadata = self.attn_backend.make_metadata( + **metadata_dict) + + attn_metadata = AttentionMetadata( + num_prefills=num_prefills, + slot_mapping=slot_mapping, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + prefill_metadata=prefill_attn_metadata, + decode_metadata=decode_attn_metadata, + kv_cache_dtype=self.kv_cache_dtype, + ) + return (input_tokens, input_positions, attn_metadata, sampling_metadata, lora_requests, lora_mapping, multi_modal_input) @@ -663,8 +815,10 @@ class ModelRunner: if self.lora_config: self.set_active_loras(lora_requests, lora_mapping) - # Execute the model. - if attn_metadata.use_cuda_graph: + # Currently cuda graph is only supported by the decode phase. + prefill_meta = attn_metadata.prefill_metadata + decode_meta = attn_metadata.decode_metadata + if prefill_meta is None and decode_meta.use_cuda_graph: graph_batch_size = input_tokens.shape[0] model_executable = self.graph_runners[graph_batch_size] else: @@ -842,13 +996,10 @@ class ModelRunner: # memory usage of CUDA graph. for batch_size in reversed(batch_size_capture_list): # Create dummy attn_metadata. - attn_metadata = self.attn_backend.make_metadata( + decode_metadata = self.attn_backend.make_metadata( is_prompt=False, - slot_mapping=slot_mapping[:batch_size], prompt_lens=None, prompt_lens_tensor=None, - num_prompt_tokens=0, - num_generation_tokens=batch_size, max_subquery_len=None, max_context_len=self.max_context_len_to_capture, max_prompt_len=None, @@ -857,6 +1008,14 @@ class ModelRunner: context_lens=context_lens[:batch_size], block_tables=block_tables[:batch_size], use_cuda_graph=True, + ) + attn_metadata = AttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=batch_size, + slot_mapping=slot_mapping[:batch_size], + prefill_metadata=None, + decode_metadata=decode_metadata, kv_cache_dtype=self.kv_cache_dtype, ) @@ -950,8 +1109,8 @@ class CUDAGraphRunner: "positions": positions, "kv_caches": kv_caches, "slot_mapping": attn_metadata.slot_mapping, - "context_lens": attn_metadata.context_lens, - "block_tables": attn_metadata.block_tables, + "context_lens": attn_metadata.decode_metadata.context_lens, + "block_tables": attn_metadata.decode_metadata.block_tables, } self.output_buffers = {"hidden_states": hidden_states} return @@ -972,10 +1131,10 @@ class CUDAGraphRunner: self.input_buffers["positions"].copy_(positions, non_blocking=True) self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping, non_blocking=True) - self.input_buffers["context_lens"].copy_(attn_metadata.context_lens, - non_blocking=True) - self.input_buffers["block_tables"].copy_(attn_metadata.block_tables, - non_blocking=True) + self.input_buffers["context_lens"].copy_( + attn_metadata.decode_metadata.context_lens, non_blocking=True) + self.input_buffers["block_tables"].copy_( + attn_metadata.decode_metadata.block_tables, non_blocking=True) # Run the graph. self.graph.replay() From caada5e50aa16cd5f59bd7889128a83588ca1f99 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 10 Apr 2024 18:48:26 -0700 Subject: [PATCH 10/50] [Core][Model] torch.compile for layernorm in commandr (#3985) [Core][Model] Use torch.compile to accelerate layernorm in commandr (#3985) --- vllm/model_executor/models/commandr.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index aa27f0a96c745..aa9b28b676e0b 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -48,6 +48,18 @@ from vllm.model_executor.weight_utils import (default_weight_loader, from vllm.sequence import SamplerOutput +@torch.compile +def layer_norm_func(hidden_states, weight, variance_epsilon): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + mean = hidden_states.mean(-1, keepdim=True) + variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True) + hidden_states = (hidden_states - mean) * torch.rsqrt(variance + + variance_epsilon) + hidden_states = weight.to(torch.float32) * hidden_states + return hidden_states.to(input_dtype) + + class LayerNorm(nn.Module): def __init__(self, param_shape=None, eps=1e-5): @@ -57,14 +69,9 @@ class LayerNorm(nn.Module): set_weight_attrs(self.weight, {"weight_loader": self.weight_loader}) def forward(self, hidden_states, residuals=None): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - mean = hidden_states.mean(-1, keepdim=True) - variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True) - hidden_states = (hidden_states - - mean) * torch.rsqrt(variance + self.variance_epsilon) - hidden_states = self.weight.to(torch.float32) * hidden_states - return hidden_states.to(input_dtype), residuals + hidden_states = layer_norm_func(hidden_states, self.weight, + self.variance_epsilon) + return hidden_states, residuals def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): tp_rank = get_tensor_model_parallel_rank() From e42df7227d18e2b96785f8ee52053663ade05b63 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Thu, 11 Apr 2024 12:09:50 +0900 Subject: [PATCH 11/50] [Test] Add xformer and flash attn tests (#3961) Co-authored-by: Simon Mo --- tests/basic_correctness/test_basic_correctness.py | 6 ++++++ vllm/attention/selector.py | 9 +++++++++ 2 files changed, 15 insertions(+) diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 97cff623c5e1d..bd4c7ea3301be 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -4,6 +4,8 @@ Run `pytest tests/basic_correctness/test_basic_correctness.py`. """ import pytest +from vllm.attention.selector import VLLM_ATTENTION_BACKEND + MODELS = [ "facebook/opt-125m", "meta-llama/Llama-2-7b-hf", @@ -14,6 +16,7 @@ MODELS = [ @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [5]) @pytest.mark.parametrize("enforce_eager", [False, True]) +@pytest.mark.parametrize("attn_backend", ["XFORMERS", "FLASH_ATTN"]) def test_models( hf_runner, vllm_runner, @@ -22,7 +25,10 @@ def test_models( dtype: str, max_tokens: int, enforce_eager: bool, + attn_backend: str, + monkeypatch, ) -> None: + monkeypatch.setenv(VLLM_ATTENTION_BACKEND, attn_backend) hf_model = hf_runner(model, dtype=dtype) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) del hf_model diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 4c699aed48d49..554e802cd5513 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -1,4 +1,5 @@ import enum +import os from functools import lru_cache from typing import Type @@ -10,6 +11,8 @@ from vllm.utils import is_cpu, is_hip logger = init_logger(__name__) +VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND" + class _Backend(enum.Enum): FLASH_ATTN = enum.auto() @@ -75,4 +78,10 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend: "Cannot use FlashAttention backend because the flash_attn package " "is not found. Please install it for better performance.") return _Backend.XFORMERS + + backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND) + if backend_by_env_var is not None: + return _Backend[backend_by_env_var] + + # Default case. return _Backend.FLASH_ATTN From e9da5a40c63ce7f8a85438d3c7d919b46e7939f5 Mon Sep 17 00:00:00 2001 From: Kunshang Ji Date: Thu, 11 Apr 2024 03:26:07 +0000 Subject: [PATCH 12/50] [Misc] Add indirection layer for custom ops (#3913) --- .../kernels/benchmark_paged_attention.py | 2 +- tests/kernels/test_attention.py | 6 +- tests/kernels/test_cache.py | 25 ++- vllm/_custom_ops.py | 193 ++++++++++++++++++ vllm/attention/ops/paged_attn.py | 10 +- vllm/model_executor/layers/activation.py | 2 +- .../layers/fused_moe/fused_moe.py | 2 +- vllm/model_executor/layers/layernorm.py | 2 +- .../model_executor/layers/quantization/awq.py | 2 +- .../layers/quantization/gptq.py | 2 +- .../layers/quantization/marlin.py | 2 +- .../layers/quantization/squeezellm.py | 2 +- .../model_executor/layers/rotary_embedding.py | 2 +- vllm/utils.py | 4 +- 14 files changed, 224 insertions(+), 32 deletions(-) create mode 100644 vllm/_custom_ops.py diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index f71d1fcaaef50..5c3650fa72d17 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -5,7 +5,7 @@ from typing import Optional import torch -from vllm._C import ops +from vllm import _custom_ops as ops from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random NUM_BLOCKS = 1024 diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 03ea72924921e..9b1f3e30b6dca 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -7,7 +7,7 @@ from allclose_default import get_default_atol, get_default_rtol from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask -from vllm._C import cache_ops, ops +from vllm import _custom_ops as ops from vllm.utils import get_max_shared_memory_bytes, is_hip FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 @@ -237,14 +237,14 @@ def test_paged_attention( dequantized_key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device=device) - cache_ops.convert_fp8(key_cache, dequantized_key_cache) + ops.convert_fp8(key_cache, dequantized_key_cache) key_cache = dequantized_key_cache value_cache_shape = value_cache.shape dequantized_value_cache = torch.empty(size=value_cache_shape, dtype=dtype, device=device) - cache_ops.convert_fp8(value_cache, dequantized_value_cache) + ops.convert_fp8(value_cache, dequantized_value_cache) value_cache = dequantized_value_cache ref_output = torch.empty_like(query) diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 4141aacafd0b2..d1051fd7e2f4d 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -4,7 +4,7 @@ from typing import Tuple import pytest import torch -from vllm._C import cache_ops +from vllm import _custom_ops as ops from vllm.utils import is_hip COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')] @@ -80,7 +80,7 @@ def test_copy_blocks( cloned_value_caches = [value_cache.clone() for value_cache in value_caches] # Call the copy blocks kernel. - cache_ops.copy_blocks(key_caches, value_caches, block_mapping) + ops.copy_blocks(key_caches, value_caches, block_mapping) # Run the reference implementation. for src, dsts in block_mapping.items(): @@ -145,9 +145,9 @@ def test_reshape_and_cache( # Clone the KV caches. if kv_cache_dtype == "fp8": cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16) - cache_ops.convert_fp8(key_cache, cloned_key_cache) + ops.convert_fp8(key_cache, cloned_key_cache) cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16) - cache_ops.convert_fp8(value_cache, cloned_value_cache) + ops.convert_fp8(value_cache, cloned_value_cache) else: cloned_key_cache = key_cache.clone() cloned_value_cache = value_cache.clone() @@ -156,14 +156,14 @@ def test_reshape_and_cache( kv_scale = 1.0 # Call the reshape_and_cache kernel. - cache_ops.reshape_and_cache(key, value, key_cache, value_cache, - slot_mapping, kv_cache_dtype, kv_scale) + ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping, + kv_cache_dtype, kv_scale) if kv_cache_dtype == "fp8": result_key_cache = torch.empty_like(key_cache, dtype=torch.float16) - cache_ops.convert_fp8(key_cache, result_key_cache) + ops.convert_fp8(key_cache, result_key_cache) result_value_cache = torch.empty_like(value_cache, dtype=torch.float16) - cache_ops.convert_fp8(value_cache, result_value_cache) + ops.convert_fp8(value_cache, result_value_cache) # Run the reference implementation. reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape) @@ -251,9 +251,8 @@ def test_swap_blocks( src_value_caches_clone = src_value_caches[0].clone() # Call the swap_blocks kernel. - cache_ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping) - cache_ops.swap_blocks(src_value_caches[0], dist_value_caches[0], - block_mapping) + ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping) + ops.swap_blocks(src_value_caches[0], dist_value_caches[0], block_mapping) for src, dst in block_mapping.items(): assert torch.allclose(src_key_caches_clone[src].cpu(), @@ -291,9 +290,9 @@ def test_fp8_conversion( cache.uniform_(low, high) cache_fp8 = torch.empty_like(cache, dtype=torch.uint8) - cache_ops.convert_fp8(cache, cache_fp8) + ops.convert_fp8(cache, cache_fp8) converted_cache = torch.empty_like(cache) - cache_ops.convert_fp8(cache_fp8, converted_cache) + ops.convert_fp8(cache_fp8, converted_cache) assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py new file mode 100644 index 0000000000000..a0837a20875fe --- /dev/null +++ b/vllm/_custom_ops.py @@ -0,0 +1,193 @@ +from typing import Dict, Optional + +import torch + +try: + from vllm._C import cache_ops as vllm_cache_ops + from vllm._C import ops as vllm_ops +except ImportError: + pass + + +# activation ops +def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: + vllm_ops.silu_and_mul(out, x) + + +def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: + vllm_ops.gelu_and_mul(out, x) + + +def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: + vllm_ops.gelu_tanh_and_mul(out, x) + + +def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None: + vllm_ops.gelu_fast(out, x) + + +def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None: + vllm_ops.gelu_new(out, x) + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + block_size: int, + max_context_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + kv_scale: float, +) -> None: + vllm_ops.paged_attention_v1(out, query, key_cache, value_cache, + num_kv_heads, scale, block_tables, + context_lens, block_size, max_context_len, + alibi_slopes, kv_cache_dtype, kv_scale) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + block_size: int, + max_context_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + kv_scale: float, +) -> None: + vllm_ops.paged_attention_v2(out, exp_sum, max_logits, tmp_out, query, + key_cache, value_cache, num_kv_heads, scale, + block_tables, context_lens, block_size, + max_context_len, alibi_slopes, kv_cache_dtype, + kv_scale) + + +# pos encoding ops +def rotary_embedding( + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + head_size: int, + cos_sin_cache: torch.Tensor, + is_neox: bool, +) -> None: + vllm_ops.rotary_embedding(positions, query, key, head_size, cos_sin_cache, + is_neox) + + +def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, + key: torch.Tensor, head_size: int, + cos_sin_cache: torch.Tensor, is_neox: bool, + rot_dim: int, + cos_sin_cache_offsets: torch.Tensor) -> None: + vllm_ops.batched_rotary_embedding(positions, query, key, head_size, + cos_sin_cache, is_neox, rot_dim, + cos_sin_cache_offsets) + + +# layer norm ops +def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, + epsilon: float) -> None: + vllm_ops.rms_norm(out, input, weight, epsilon) + + +def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, + weight: torch.Tensor, epsilon: float) -> None: + vllm_ops.fused_add_rms_norm(input, residual, weight, epsilon) + + +# quantization ops +# awq +def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor, + zeros: torch.Tensor, split_k_iters: int, thx: int, + thy: int) -> torch.Tensor: + return vllm_ops.awq_dequantize(qweight, scales, zeros, split_k_iters, thx, + thy) + + +def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor, + scales: torch.Tensor, split_k_iters: int) -> torch.Tensor: + return vllm_ops.awq_gemm(input, qweight, qzeros, scales, split_k_iters) + + +# gptq +def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, + b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor, + b_g_idx: torch.Tensor, use_exllama: bool, + bit: int) -> torch.Tensor: + return vllm_ops.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, + b_g_idx, use_exllama, bit) + + +def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, + bit: int) -> None: + vllm_ops.gptq_shuffle(q_weight, q_perm, bit) + + +# squeezellm +def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor, + lookup_table: torch.Tensor) -> None: + vllm_ops.squeezellm_gemm(vec, mat, mul, lookup_table) + + +# marlin +def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, + b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int, + size_n: int, size_k: int) -> torch.Tensor: + return vllm_ops.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m, + size_n, size_k) + + +# moe +def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, + block_size: int, sorted_token_ids: torch.Tensor, + experts_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor) -> None: + vllm_ops.moe_align_block_size(topk_ids, num_experts, block_size, + sorted_token_ids, experts_ids, + num_tokens_post_pad) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + kv_scale: float, +) -> None: + vllm_cache_ops.reshape_and_cache(key, value, key_cache, value_cache, + slot_mapping, kv_cache_dtype, kv_scale) + + +def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor, + block_mapping: torch.Tensor) -> None: + vllm_cache_ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks(src: torch.Tensor, dst: torch.Tensor, + block_mapping: Dict[int, int]) -> None: + vllm_cache_ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8(output: torch.Tensor, input: torch.Tensor) -> None: + vllm_cache_ops.convert_fp8(output, input) + + +#TODO: cuda_utils, custom_ar diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 2d918491d6576..cd0690a4ba957 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -3,7 +3,7 @@ from typing import Dict, List, Optional, Tuple import torch -from vllm._C import cache_ops, ops +from vllm import _custom_ops as ops from vllm.attention.ops.prefix_prefill import context_attention_fwd # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. @@ -69,7 +69,7 @@ class PagedAttention: kv_cache_dtype: str, kv_scale: float, ) -> None: - cache_ops.reshape_and_cache( + ops.reshape_and_cache( key, value, key_cache, @@ -199,11 +199,11 @@ class PagedAttention: ) -> None: src_key_cache = src_kv_cache[0] dst_key_cache = dst_kv_cache[0] - cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) + ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) src_value_cache = src_kv_cache[1] dst_value_cache = dst_kv_cache[1] - cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) + ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) @staticmethod def copy_blocks( @@ -212,4 +212,4 @@ class PagedAttention: ) -> None: key_caches = [kv_cache[0] for kv_cache in kv_caches] value_caches = [kv_cache[1] for kv_cache in kv_caches] - cache_ops.copy_blocks(key_caches, value_caches, src_to_dists) + ops.copy_blocks(key_caches, value_caches, src_to_dists) diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index 6786c48e0caba..baf1d4f266181 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -6,7 +6,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from vllm._C import ops +from vllm import _custom_ops as ops from vllm.distributed import (divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.quantization import QuantizationConfig diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 1ec09f0cd4c28..377b6588dbf47 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -8,7 +8,7 @@ import torch import triton import triton.language as tl -from vllm._C import ops +from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.utils import is_hip diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index cb3cee2bad5ad..a6619714b8aab 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -4,7 +4,7 @@ from typing import Optional, Tuple, Union import torch import torch.nn as nn -from vllm._C import ops +from vllm import _custom_ops as ops class RMSNorm(nn.Module): diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 2caef5f1ebf50..daea5ac73e429 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional import torch from torch.nn.parameter import Parameter -from vllm._C import ops +from vllm import _custom_ops as ops from vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index 53baf710ed811..757ab1af8392e 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional import torch from torch.nn.parameter import Parameter -from vllm._C import ops +from vllm import _custom_ops as ops from vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py index 784229878edf4..a6482c059cc41 100644 --- a/vllm/model_executor/layers/quantization/marlin.py +++ b/vllm/model_executor/layers/quantization/marlin.py @@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional import torch from torch.nn.parameter import Parameter -from vllm._C import ops +from vllm import _custom_ops as ops from vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py index ed25455e6ec1f..bb295df2acc3f 100644 --- a/vllm/model_executor/layers/quantization/squeezellm.py +++ b/vllm/model_executor/layers/quantization/squeezellm.py @@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional import torch from torch.nn.parameter import Parameter -from vllm._C import ops +from vllm import _custom_ops as ops from vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index d80e73bbe39e9..eb8d5f6dfb2a9 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -27,7 +27,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn -from vllm._C import ops +from vllm import _custom_ops as ops def _rotate_neox(x: torch.Tensor) -> torch.Tensor: diff --git a/vllm/utils.py b/vllm/utils.py index 8ba03333d3b6c..8ab8927512cc9 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -279,10 +279,10 @@ def _generate_random_fp8( #-----|-------------|------------------- # Inf | N/A | s.11111.00 # NaN | s.1111.111 | s.11111.{01,10,11} - from vllm._C import cache_ops + from vllm import _custom_ops as ops tensor_tmp = torch.empty_like(tensor, dtype=torch.float16) tensor_tmp.uniform_(low, high) - cache_ops.convert_fp8(tensor_tmp, tensor) + ops.convert_fp8(tensor_tmp, tensor) del tensor_tmp From f3d0bf7589d6e63a691dcbb9d1db538c184fde29 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 10 Apr 2024 20:33:02 -0700 Subject: [PATCH 13/50] [Doc][Installation] delete python setup.py develop (#3989) --- docs/source/getting_started/installation.rst | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/docs/source/getting_started/installation.rst b/docs/source/getting_started/installation.rst index 5dfb32080f97a..e7826114ffa9d 100644 --- a/docs/source/getting_started/installation.rst +++ b/docs/source/getting_started/installation.rst @@ -85,13 +85,3 @@ You can also build and install vLLM from source: $ nvcc --version # verify that nvcc is in your PATH $ ${CUDA_HOME}/bin/nvcc --version # verify that nvcc is in your CUDA_HOME - -.. note:: - If you are developing the C++ backend of vLLM, consider building vLLM with - - .. code-block:: console - - $ python setup.py develop - - since it will give you incremental builds. The downside is that this method - is `deprecated by setuptools `_. From c1dc547129f5faaa2ca5ba557145b8ec8838693c Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Thu, 11 Apr 2024 07:50:00 -0700 Subject: [PATCH 14/50] [Kernel] Fused MoE Config for Mixtral 8x22 (#4002) --- ...048,device_name=NVIDIA_A100-SXM4-80GB.json | 146 ++++++++++++++++++ ...048,device_name=NVIDIA_H100_80GB_HBM3.json | 146 ++++++++++++++++++ ...096,device_name=NVIDIA_A100-SXM4-80GB.json | 146 ++++++++++++++++++ ...096,device_name=NVIDIA_H100_80GB_HBM3.json | 146 ++++++++++++++++++ 4 files changed, 584 insertions(+) create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json new file mode 100644 index 0000000000000..0bb423b28f5ab --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000000..26bcbf26970c7 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json new file mode 100644 index 0000000000000..dbc624731f5cb --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000000..32c0c9da471cb --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + } +} From 08ccee1e830d39ecdb3c6cf382c843dbf5ae830e Mon Sep 17 00:00:00 2001 From: "fuchen.ljl" Date: Thu, 11 Apr 2024 23:59:26 +0800 Subject: [PATCH 15/50] punica fix-bgmv-kernel-640 (#4007) --- csrc/punica/bgmv/bgmv_config.h | 1 + 1 file changed, 1 insertion(+) diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h index 2219d960ae62f..1084a0f20df6b 100644 --- a/csrc/punica/bgmv/bgmv_config.h +++ b/csrc/punica/bgmv/bgmv_config.h @@ -14,6 +14,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 128) \ f(in_T, out_T, W_T, narrow, 256) \ f(in_T, out_T, W_T, narrow, 512) \ + f(in_T, out_T, W_T, narrow, 640) \ f(in_T, out_T, W_T, narrow, 768) \ f(in_T, out_T, W_T, narrow, 1024) \ f(in_T, out_T, W_T, narrow, 1152) \ From 8afca50889bad6ad987c523c48c31fc52fcb72e4 Mon Sep 17 00:00:00 2001 From: bigPYJ1151 Date: Fri, 12 Apr 2024 02:56:49 +0800 Subject: [PATCH 16/50] [Hardware][Intel] Isolate CPUModelRunner and ModelRunner for better maintenance (#3824) --- vllm/attention/backends/torch_sdpa.py | 72 ++--- vllm/executor/cpu_executor.py | 10 + vllm/utils.py | 1 - vllm/worker/cpu_model_runner.py | 408 ++++++++++++++++++++++++++ vllm/worker/cpu_worker.py | 13 +- 5 files changed, 443 insertions(+), 61 deletions(-) create mode 100644 vllm/worker/cpu_model_runner.py diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 63904ea929870..d21b54b16db4b 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -50,20 +50,15 @@ class TorchSDPABackend(AttentionBackend): @dataclass -class TorchSDPAMetadata(AttentionMetadataPerStage, PagedAttentionMetadata): +class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata, + AttentionMetadataPerStage): """Metadata for TorchSDPABackend. """ # Currently, input sequences can only contain all prompts # or all decoding. True if all sequences are prompts. is_prompt: bool + slot_mapping: torch.Tensor prompt_lens: Optional[List[int]] - prompt_lens_tensor: Optional[torch.Tensor] - - max_subquery_len: Optional[int] = None - max_prompt_len: Optional[int] = None - subquery_start_loc: Optional[torch.Tensor] = None - seq_start_loc: Optional[torch.Tensor] = None - use_cuda_graph: bool = False def __post_init__(self): # Set during the execution of the first attention op. @@ -111,7 +106,7 @@ class TorchSDPABackendImpl(AttentionImpl): key: torch.Tensor, value: torch.Tensor, kv_cache: Optional[torch.Tensor], - attn_metadata: AttentionMetadata[TorchSDPAMetadata], + attn_metadata: TorchSDPAMetadata, kv_scale: float, ) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. @@ -140,51 +135,36 @@ class TorchSDPABackendImpl(AttentionImpl): attn_metadata.kv_cache_dtype, kv_scale) - num_prefill_tokens = attn_metadata.num_prefill_tokens - num_decode_tokens = attn_metadata.num_decode_tokens - assert key.shape[0] == num_prefill_tokens + num_decode_tokens - assert value.shape[0] == num_prefill_tokens + num_decode_tokens - - output = torch.empty_like(query) - # Query for decode. KV is not needed because it is already cached. - decode_query = query[num_prefill_tokens:] - # QKV for prefill. - query = query[:num_prefill_tokens] - key = key[:num_prefill_tokens] - value = value[:num_prefill_tokens] - - assert query.shape[0] == num_prefill_tokens - assert decode_query.shape[0] == num_decode_tokens - - if prefill_meta := attn_metadata.prefill_metadata: - if (kv_cache is None or prefill_meta.block_tables.numel() == 0): + if attn_metadata.is_prompt: + if (kv_cache is None or attn_metadata.block_tables.numel() == 0): if self.num_kv_heads != self.num_heads: key = key.repeat_interleave(self.num_queries_per_kv, dim=1) value = value.repeat_interleave(self.num_queries_per_kv, dim=1) - if prefill_meta.attn_bias is None: + if attn_metadata.attn_bias is None: if self.alibi_slopes is not None: att_masks = _make_alibi_bias( self.alibi_slopes, query.dtype, - prefill_meta.prompt_lens) # type: ignore + attn_metadata.prompt_lens) # type: ignore elif self.sliding_window is not None: att_masks = _make_sliding_window_bias( - prefill_meta.prompt_lens, self.sliding_window, + attn_metadata.prompt_lens, self.sliding_window, query.dtype) # type: ignore else: - att_masks = [None] * len(prefill_meta.prompt_lens) - prefill_meta.attn_bias = att_masks + att_masks = [None] * len(attn_metadata.prompt_lens) + attn_metadata.attn_bias = att_masks query = query.movedim(0, query.dim() - 2) key = key.movedim(0, key.dim() - 2) value = value.movedim(0, value.dim() - 2) start = 0 - out = torch.empty((num_tokens, self.num_heads, self.head_size), - dtype=query.dtype) - for prompt_len, mask in zip(prefill_meta.prompt_lens, - prefill_meta.attn_bias): + output = torch.empty( + (num_tokens, self.num_heads, self.head_size), + dtype=query.dtype) + for prompt_len, mask in zip(attn_metadata.prompt_lens, + attn_metadata.attn_bias): end = start + prompt_len sub_out = scaled_dot_product_attention( query[:, start:end, :], @@ -194,32 +174,28 @@ class TorchSDPABackendImpl(AttentionImpl): dropout_p=0.0, is_causal=not self.need_mask, scale=self.scale).movedim(query.dim() - 2, 0) - out[start:end, :, :] = sub_out + output[start:end, :, :] = sub_out start = end - assert out.shape == output[:num_prefill_tokens].shape - output[:num_prefill_tokens] = out else: # prefix-enabled attention raise RuntimeError( "Torch SDPA backend doesn't support prefix decoding.") - if decode_meta := attn_metadata.decode_metadata: + else: # Decoding run. - out = PagedAttention.forward_decode( - decode_query, + output = PagedAttention.forward_decode( + query, key_cache, value_cache, - decode_meta.block_tables, - decode_meta.context_lens, - decode_meta.max_context_len, + attn_metadata.block_tables, + attn_metadata.context_lens, + attn_metadata.max_context_len, attn_metadata.kv_cache_dtype, self.num_kv_heads, self.scale, self.alibi_slopes, kv_scale, ) - assert out.shape == output[num_prefill_tokens:].shape - output[num_prefill_tokens:] # Reshape the output tensor. return output.view(-1, self.num_heads * self.head_size) @@ -241,7 +217,7 @@ def _make_alibi_bias( bias = bias[None, :] - bias[:, None] num_heads = alibi_slopes.shape[0] - bias = bias[None, :].expand(num_heads, prompt_len, prompt_len) + bias = bias[None, :].repeat((num_heads, 1, 1)) bias.mul_(alibi_slopes[:, None, None]) inf_mask = torch.empty( (1, prompt_len, prompt_len), diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index 2bf97338da0ed..eda4e8989c163 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -25,6 +25,7 @@ class CPUExecutor(ExecutorBase): assert lora_config is None, "cpu backend doesn't support LoRA" model_config = _verify_and_get_model_config(model_config) cache_config = _verify_and_get_cache_config(cache_config) + scheduler_config = _verify_and_get_scheduler_config(scheduler_config) self.model_config = model_config self.cache_config = cache_config @@ -116,6 +117,15 @@ def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig: return config +def _verify_and_get_scheduler_config( + config: SchedulerConfig) -> SchedulerConfig: + if config.chunked_prefill_enabled: + logger.warning("Chunked prefill is not supported on CPU, disable it.") + config.chunked_prefill_enabled = False + + return config + + def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig: _GB = 1 << 30 if config.enable_prefix_caching: diff --git a/vllm/utils.py b/vllm/utils.py index 8ab8927512cc9..fdb0a3768ab0d 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -372,7 +372,6 @@ def is_pin_memory_available() -> bool: print_warning_once("Pin memory is not supported on Neuron.") return False elif is_cpu(): - print_warning_once("Pin memory is not supported on CPU.") return False return True diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py new file mode 100644 index 0000000000000..49e1ad5709f5d --- /dev/null +++ b/vllm/worker/cpu_model_runner.py @@ -0,0 +1,408 @@ +from typing import Dict, List, Optional, Tuple + +import torch + +from vllm.attention import AttentionMetadata, get_attn_backend +from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, + SchedulerConfig) +from vllm.distributed import broadcast_tensor_dict +from vllm.logger import init_logger +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.model_loader import get_model +from vllm.sampling_params import SamplingParams, SamplingType +from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata +from vllm.utils import make_tensor_with_pad, maybe_expand_dim + +logger = init_logger(__name__) + +_PAD_SLOT_ID = -1 + + +class CPUModelRunner: + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + *args, + **kwargs, + ): + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.lora_config = lora_config + self.is_driver_worker = is_driver_worker + + # model_config can be None in tests/samplers/test_sampler.py. + # FIXME(woosuk): This is a hack to make the tests work. Refactor this. + self.sliding_window = (model_config.get_sliding_window() + if model_config is not None else None) + self.device_config = (device_config + if device_config is not None else DeviceConfig()) + self.device = self.device_config.device + + self.model = None + self.block_size = None # Set after initial profiling. + + self.kv_cache_dtype = kv_cache_dtype + + self.attn_backend = get_attn_backend( + self.model_config.dtype if model_config is not None else None) + + def load_model(self) -> None: + self.model = get_model(self.model_config, + self.device_config, + lora_config=self.lora_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config) + + def _prepare_prompt( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int]]: + assert len(seq_group_metadata_list) > 0 + input_tokens: List[int] = [] + input_positions: List[int] = [] + slot_mapping: List[int] = [] + prompt_lens: List[int] = [] + + for seq_group_metadata in seq_group_metadata_list: + assert seq_group_metadata.is_prompt + seq_ids = list(seq_group_metadata.seq_data.keys()) + assert len(seq_ids) == 1 + seq_id = seq_ids[0] + + seq_data = seq_group_metadata.seq_data[seq_id] + prompt_tokens = seq_data.get_token_ids() + computed_len = seq_data.get_num_computed_tokens() + prompt_len = len(prompt_tokens) + + prompt_lens.append(prompt_len) # Prompt token num + input_tokens.extend(prompt_tokens) # Token ids + + # Token position ids + # NOTE(woosuk): Here we assume that the first token in the prompt + # is always the first token in the sequence. + input_positions.extend(list(range(computed_len, prompt_len))) + + # Compute the slot mapping. + block_table = seq_group_metadata.block_tables[seq_id] + # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, + # where start_idx is max(0, prompt_len - sliding_window). + # For example, if the prompt len is 10, sliding window is 8, and + # block size is 4, the first two tokens are masked and the slot + # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. + start_idx = 0 + if self.sliding_window is not None: + start_idx = max(0, prompt_len - self.sliding_window) + + for i in range(computed_len, prompt_len): + if i < start_idx: + slot_mapping.append(_PAD_SLOT_ID) + continue + + block_number = block_table[i // + self.block_size] # type: ignore + block_offset = i % self.block_size # type: ignore + slot = block_number * self.block_size + block_offset + slot_mapping.append(slot) + + num_prompt_tokens = len(input_tokens) + + input_tokens = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) # type: ignore + input_positions = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) # type: ignore + slot_mapping = torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) # type: ignore + + attn_metadata = self.attn_backend.make_metadata( + is_prompt=True, + prompt_lens=prompt_lens, + num_prefills=len(prompt_lens), + num_prefill_tokens=num_prompt_tokens, + num_decode_tokens=0, + prefill_metadata=None, + decode_metadata=None, + max_context_len=None, + context_lens=None, + block_tables=torch.tensor([]), + slot_mapping=slot_mapping, + kv_cache_dtype=self.kv_cache_dtype, + ) + return ( + input_tokens, + input_positions, + attn_metadata, + prompt_lens, + ) + + def _prepare_decode( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]: + assert len(seq_group_metadata_list) > 0 + input_tokens: List[int] = [] + input_positions: List[int] = [] + slot_mapping: List[int] = [] + context_lens: List[int] = [] + block_tables: List[List[int]] = [] + + for seq_group_metadata in seq_group_metadata_list: + assert not seq_group_metadata.is_prompt + assert seq_group_metadata.token_chunk_size == 1 + + seq_ids = list(seq_group_metadata.seq_data.keys()) + + for seq_id in seq_ids: + seq_data = seq_group_metadata.seq_data[seq_id] + generation_token = seq_data.get_last_token_id() + input_tokens.append(generation_token) + + seq_len = seq_data.get_len() + position = seq_len - 1 + input_positions.append(position) + + context_len = seq_len if self.sliding_window is None else min( + seq_len, self.sliding_window) + context_lens.append(context_len) + + block_table = seq_group_metadata.block_tables[seq_id] + block_number = block_table[position // self.block_size] + block_offset = position % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping.append(slot) + + if self.sliding_window is not None: + sliding_window_blocks = (self.sliding_window // + self.block_size) + block_table = block_table[-sliding_window_blocks:] + block_tables.append(block_table) + + max_context_len = max(context_lens) + + input_tokens = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) + input_positions = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) + slot_mapping = torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) + context_lens = torch.tensor(context_lens, + dtype=torch.int, + device=self.device) + + max_block_table_len = max( + len(block_table) for block_table in block_tables) + block_tables = make_tensor_with_pad( + block_tables, + max_len=max_block_table_len, + pad=0, + dtype=torch.int, + device=self.device, + ) + + attn_metadata = self.attn_backend.make_metadata( + is_prompt=False, + slot_mapping=slot_mapping, + prompt_lens=None, + num_prefill_tokens=0, + num_decode_tokens=len(input_tokens), + max_context_len=max_context_len, + num_prefills=0, + prefill_metadata=None, + decode_metadata=None, + context_lens=context_lens, + block_tables=block_tables, + kv_cache_dtype=self.kv_cache_dtype, + ) + return ( + input_tokens, + input_positions, + attn_metadata, + ) + + def _prepare_sample( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + prompt_lens: List[int], + ) -> SamplingMetadata: + seq_groups: List[Tuple[List[int], SamplingParams]] = [] + selected_token_indices: List[int] = [] + generators: List[torch.Generator] = [] + selected_token_start_idx = 0 + categorized_sample_indices = {t: [] for t in SamplingType} + categorized_sample_indices_start_idx = 0 + categorized_sampled_token_indices_start_idx = 0 + + for i, seq_group_metadata in enumerate(seq_group_metadata_list): + seq_ids = list(seq_group_metadata.seq_data.keys()) + sampling_params = seq_group_metadata.sampling_params + seq_groups.append((seq_ids, sampling_params)) + + if seq_group_metadata.is_prompt: + assert len(seq_ids) == 1 + subquery_len = prompt_lens[i] + if sampling_params.prompt_logprobs is not None: + # NOTE: prompt token positions do not need sample, skip + categorized_sample_indices_start_idx += subquery_len - 1 + + categorized_sample_indices[ + sampling_params.sampling_type].append([ + categorized_sample_indices_start_idx, + categorized_sampled_token_indices_start_idx + ]) + categorized_sample_indices_start_idx += 1 + categorized_sampled_token_indices_start_idx += 1 + + if sampling_params.prompt_logprobs is not None: + selected_token_indices.extend( + range(selected_token_start_idx, + selected_token_start_idx + subquery_len - 1)) + selected_token_indices.append(selected_token_start_idx + + subquery_len - 1) + selected_token_start_idx += subquery_len + + if sampling_params.seed is not None: + seq_group_metadata.state.generator = torch.Generator( + device=self.device).manual_seed(sampling_params.seed) + else: + num_seqs = len(seq_ids) + selected_token_indices.extend( + range(selected_token_start_idx, + selected_token_start_idx + num_seqs)) + selected_token_start_idx += num_seqs + + categorized_sample_indices[ + sampling_params.sampling_type].extend( + zip( + range( + categorized_sample_indices_start_idx, + categorized_sample_indices_start_idx + + num_seqs), + range( + categorized_sampled_token_indices_start_idx, + categorized_sampled_token_indices_start_idx + + num_seqs))) + categorized_sample_indices_start_idx += num_seqs + categorized_sampled_token_indices_start_idx += num_seqs + + if sampling_params.seed is not None: + generators.append(seq_group_metadata.state.generator) + + selected_token_indices = torch.tensor(selected_token_indices, + dtype=torch.long) + + categorized_sample_indices = { + t: maybe_expand_dim(torch.tensor(seq_ids, dtype=torch.int), 2, 2) + for t, seq_ids in categorized_sample_indices.items() + } + + seq_data: Dict[int, SequenceData] = {} + for seq_group_metadata in seq_group_metadata_list: + seq_data.update(seq_group_metadata.seq_data) + + sampling_metadata = SamplingMetadata( + seq_groups=seq_groups, + seq_data=seq_data, + prompt_lens=prompt_lens, + selected_token_indices=selected_token_indices, + categorized_sample_indices=categorized_sample_indices, + generators=generators, + ) + return sampling_metadata + + def prepare_input_tensors( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, + SamplingMetadata]: + if self.is_driver_worker: + # NOTE: We assume that all sequences in the group are all prompts or + # all decodes. + is_prompt = seq_group_metadata_list[0].is_prompt + # Prepare input tensors. + if is_prompt: + (input_tokens, input_positions, attn_metadata, + prompt_lens) = self._prepare_prompt(seq_group_metadata_list) + else: + (input_tokens, input_positions, + attn_metadata) = self._prepare_decode(seq_group_metadata_list) + prompt_lens = [] + sampling_metadata = self._prepare_sample(seq_group_metadata_list, + prompt_lens) + # Broadcast the metadata. + metadata_dict = { + "input_tokens": input_tokens, + "input_positions": input_positions, + "selected_token_indices": + sampling_metadata.selected_token_indices, + } + metadata_dict.update(attn_metadata.asdict_zerocopy()) + broadcast_tensor_dict(metadata_dict, src=0) + else: + metadata_dict = broadcast_tensor_dict(src=0) + input_tokens = metadata_dict.pop("input_tokens") + input_positions = metadata_dict.pop("input_positions") + selected_token_indices = metadata_dict.pop( + "selected_token_indices") + attn_metadata = self.attn_backend.make_metadata(**metadata_dict) + sampling_metadata = SamplingMetadata( + seq_groups=None, + seq_data=None, + prompt_lens=None, + selected_token_indices=selected_token_indices, + categorized_sample_indices=None, + generators=None, + perform_sampling=False, + ) + + return ( + input_tokens, + input_positions, + attn_metadata, + sampling_metadata, + ) + + @torch.inference_mode() + def execute_model( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + kv_caches: List[torch.Tensor], + ) -> Optional[SamplerOutput]: + (input_tokens, input_positions, attn_metadata, sampling_metadata + ) = self.prepare_input_tensors(seq_group_metadata_list) + + model_executable = self.model + execute_model_kwargs = { + "input_ids": input_tokens, + "positions": input_positions, + "kv_caches": kv_caches, + "attn_metadata": attn_metadata, + } + + hidden_states = model_executable(**execute_model_kwargs) + + # Compute the logits. + logits = self.model.compute_logits(hidden_states, sampling_metadata) + + # Only perform sampling in the driver worker. + if not sampling_metadata.perform_sampling: + return None + + # Sample the next token. + output = self.model.sample( + logits=logits, + sampling_metadata=sampling_metadata, + ) + return output diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 751384eb72af3..3989207e8dd83 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -12,25 +12,14 @@ from vllm.distributed import (broadcast_tensor_dict, init_distributed_environment) from vllm.logger import init_logger from vllm.model_executor import set_random_seed -from vllm.model_executor.model_loader import get_model from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE -from vllm.worker.model_runner import ModelRunner +from vllm.worker.cpu_model_runner import CPUModelRunner from vllm.worker.worker_base import LoraNotSupportedWorkerBase logger = init_logger(__name__) -class CPUModelRunner(ModelRunner): - - def load_model(self) -> None: - self.model = get_model(self.model_config, - self.device_config, - lora_config=self.lora_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config) - - class CPUCacheEngine: """Manages the KV cache for CPU backend. From a10d3056da644c31e4ebf95a2b6ad65a626a7350 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 11 Apr 2024 13:35:51 -0700 Subject: [PATCH 17/50] [Core] Set `linear_weights` directly on the layer (#3977) --- csrc/quantization/gptq/q_gemm.cu | 2 +- tests/kernels/test_moe.py | 2 +- vllm/lora/layers.py | 12 +-- vllm/model_executor/layers/linear.py | 77 ++++++++++--------- .../model_executor/layers/quantization/awq.py | 29 +++---- .../layers/quantization/gptq.py | 47 ++++++----- .../layers/quantization/marlin.py | 23 +++--- .../layers/quantization/squeezellm.py | 24 +++--- 8 files changed, 114 insertions(+), 102 deletions(-) diff --git a/csrc/quantization/gptq/q_gemm.cu b/csrc/quantization/gptq/q_gemm.cu index 655158e38f557..cc56649917a8a 100644 --- a/csrc/quantization/gptq/q_gemm.cu +++ b/csrc/quantization/gptq/q_gemm.cu @@ -2067,7 +2067,7 @@ void gptq_shuffle const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight)); vllm::gptq::shuffle_exllama_weight( (uint32_t*) q_weight.data_ptr(), - q_perm.device().is_meta() ? NULL : (int*) q_perm.data_ptr(), + q_perm.device().is_meta() || q_perm.numel() == 0 ? NULL : (int*) q_perm.data_ptr(), q_weight.size(0) * 32 / bit, q_weight.size(1), bit diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index affbbfb4aa94e..046f11d957bdd 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -73,7 +73,7 @@ def test_mixtral_moe(dtype: torch.dtype): ).cuda() # Load the weights - vllm_moe.gate.linear_weights["weight"][:] = hf_moe.gate.weight.data + vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data for i in range(config.num_local_experts): weights = (hf_moe.experts[i].w1.weight.data, hf_moe.experts[i].w3.weight.data) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 84a94091486d7..a8ec4dcfd6137 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -368,7 +368,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): def apply_weights(self, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: output = self.base_layer.linear_method.apply_weights( - self.base_layer.linear_weights, x, bias) + self.base_layer, x, bias) _apply_lora( x, self.lora_a_stacked, @@ -402,10 +402,6 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): if self.base_layer.skip_bias_add else None) return output, output_bias - @property - def linear_weights(self): - return self.base_layer.linear_weights - @classmethod def can_replace_layer(cls, source_layer: nn.Module, lora_config: LoRAConfig, packed_modules_list: List, @@ -505,7 +501,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): def apply_weights(self, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: output = self.base_layer.linear_method.apply_weights( - self.base_layer.linear_weights, x, bias) + self.base_layer, x, bias) _apply_lora_packed_nslice( x, self.lora_a_stacked, @@ -746,7 +742,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): def apply_weights(self, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: output = self.base_layer.linear_method.apply_weights( - self.base_layer.linear_weights, x, bias) + self.base_layer, x, bias) _apply_lora_packed_nslice( x, self.lora_a_stacked, @@ -838,7 +834,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA): def apply_weights(self, x: torch.Tensor) -> torch.Tensor: output = self.base_layer.linear_method.apply_weights( - self.base_layer.linear_weights, x) + self.base_layer, x) _apply_lora( x, self.lora_a_stacked, diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 8f42b3e8a4abe..3ca870742efc5 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional +from typing import List, Optional import torch import torch.nn.functional as F @@ -28,19 +28,24 @@ class LinearMethodBase(ABC): """Base class for different (maybe quantized) linear methods.""" @abstractmethod - def create_weights(self, input_size_per_partition: int, + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, output_size_per_partition: int, input_size: int, - output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: - """Create weights for a linear layer.""" + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + """Create weights for a linear layer. + + The weights will be set as attributes of the layer.""" raise NotImplementedError @abstractmethod def apply_weights(self, - weights: Dict[str, torch.Tensor], + layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - """Apply the weights to the input tensor.""" + """Apply the weights in layer to the input tensor. + + Expects create_weights to have been called before on the layer.""" raise NotImplementedError @@ -55,22 +60,24 @@ class UnquantizedLinearMethod(LinearMethodBase): def __init__(self, separate_bias_add: bool = False): self.separate_bias_add = separate_bias_add - def create_weights(self, input_size_per_partition: int, + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, output_size_per_partition: int, input_size: int, - output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): weight = Parameter(torch.empty(output_size_per_partition, input_size_per_partition, dtype=params_dtype), requires_grad=False) set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) - return {"weight": weight} + layer.register_parameter("weight", weight) + set_weight_attrs(weight, extra_weight_attrs) def apply_weights(self, - weights: Dict[str, torch.Tensor], + layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - weight = weights["weight"] + weight = layer.weight if self.separate_bias_add: if bias is not None: return F.linear(x, weight) + bias @@ -111,12 +118,9 @@ class ReplicatedLinear(torch.nn.Module): if linear_method is None: linear_method = UnquantizedLinearMethod() self.linear_method = linear_method - self.linear_weights = self.linear_method.create_weights( - self.input_size, self.output_size, self.input_size, - self.output_size, self.params_dtype) - for name, weight in self.linear_weights.items(): - if isinstance(weight, torch.Tensor): - self.register_parameter(name, weight) + self.linear_method.create_weights(self, self.input_size, + self.output_size, self.input_size, + self.output_size, self.params_dtype) if bias: self.bias = Parameter( torch.empty(self.output_size, dtype=self.params_dtype)) @@ -126,7 +130,7 @@ class ReplicatedLinear(torch.nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: bias = self.bias if not self.skip_bias_add else None - output = self.linear_method.apply_weights(self.linear_weights, x, bias) + output = self.linear_method.apply_weights(self, x, bias) output_bias = self.bias if self.skip_bias_add else None return output, output_bias @@ -177,13 +181,13 @@ class ColumnParallelLinear(torch.nn.Module): if linear_method is None: linear_method = UnquantizedLinearMethod() self.linear_method = linear_method - self.linear_weights = self.linear_method.create_weights( - self.input_size, self.output_size_per_partition, self.input_size, - self.output_size, self.params_dtype) - for name, weight in self.linear_weights.items(): - if isinstance(weight, torch.Tensor): - self.register_parameter(name, weight) - set_weight_attrs(weight, {"weight_loader": self.weight_loader}) + self.linear_method.create_weights(self, + self.input_size, + self.output_size_per_partition, + self.input_size, + self.output_size, + self.params_dtype, + weight_loader=self.weight_loader) if bias: self.bias = Parameter( torch.empty(self.output_size_per_partition, @@ -211,8 +215,7 @@ class ColumnParallelLinear(torch.nn.Module): bias = self.bias if not self.skip_bias_add else None # Matrix multiply. - output_parallel = self.linear_method.apply_weights( - self.linear_weights, input_, bias) + output_parallel = self.linear_method.apply_weights(self, input_, bias) if self.gather_output: # All-gather across the partitions. output = tensor_model_parallel_all_gather(output_parallel) @@ -523,13 +526,13 @@ class RowParallelLinear(torch.nn.Module): if linear_method is None: linear_method = UnquantizedLinearMethod() self.linear_method = linear_method - self.linear_weights = self.linear_method.create_weights( - self.input_size_per_partition, self.output_size, self.input_size, - self.output_size, self.params_dtype) - for name, weight in self.linear_weights.items(): - if isinstance(weight, torch.Tensor): - self.register_parameter(name, weight) - set_weight_attrs(weight, {"weight_loader": self.weight_loader}) + self.linear_method.create_weights(self, + self.input_size_per_partition, + self.output_size, + self.input_size, + self.output_size, + self.params_dtype, + weight_loader=self.weight_loader) if not reduce_results and (bias and not skip_bias_add): raise ValueError("When not reduce the results, adding bias to the " @@ -569,7 +572,7 @@ class RowParallelLinear(torch.nn.Module): # Matrix multiply. output_parallel = self.linear_method.apply_weights( - self.linear_weights, input_parallel) + self, input_parallel) if self.reduce_results and self.tp_size > 1: output_ = tensor_model_parallel_all_reduce(output_parallel) else: diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index daea5ac73e429..98651aed8be0e 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -79,10 +79,11 @@ class AWQLinearMethod(LinearMethodBase): def __init__(self, quant_config: AWQConfig): self.quant_config = quant_config - def create_weights(self, input_size_per_partition: int, + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, output_size_per_partition: int, input_size: int, - output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): if input_size_per_partition % self.quant_config.group_size != 0: raise ValueError( "The input size is not aligned with the quantized " @@ -136,19 +137,21 @@ class AWQLinearMethod(LinearMethodBase): "input_dim": 0, "output_dim": 1, }) - return { - "qweight": qweight, - "qzeros": qzeros, - "scales": scales, - } + + layer.register_parameter("qweight", qweight) + set_weight_attrs(qweight, extra_weight_attrs) + layer.register_parameter("qzeros", qzeros) + set_weight_attrs(qzeros, extra_weight_attrs) + layer.register_parameter("scales", scales) + set_weight_attrs(scales, extra_weight_attrs) def apply_weights(self, - weights: Dict[str, Any], + layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - qweight = weights["qweight"] - scales = weights["scales"] - qzeros = weights["qzeros"] + qweight = layer.qweight + scales = layer.scales + qzeros = layer.qzeros pack_factor = self.quant_config.pack_factor out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, )) reshaped_x = x.reshape(-1, x.shape[-1]) @@ -163,5 +166,5 @@ class AWQLinearMethod(LinearMethodBase): out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, pack_factor) if bias is not None: - out = out + bias + out.add_(bias) return out.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index 757ab1af8392e..f370b94a210ee 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -89,12 +89,14 @@ class GPTQLinearMethod(LinearMethodBase): def create_weights( self, + layer: torch.nn.Module, input_size_per_partition: int, output_size_per_partition: int, input_size: int, output_size: int, params_dtype: torch.dtype, - ) -> Dict[str, Any]: + **extra_weight_attrs, + ): del output_size # Unused. if input_size_per_partition % self.quant_config.group_size != 0: raise ValueError( @@ -179,37 +181,40 @@ class GPTQLinearMethod(LinearMethodBase): "input_dim": scale_and_zero_input_dim, "output_dim": 1, }) - return { - "qweight": qweight, - "g_idx": g_idx, - "qzeros": qzeros, - "scales": scales, - "exllama_state": exllama_state, - } + + layer.register_parameter("qweight", qweight) + set_weight_attrs(qweight, extra_weight_attrs) + layer.register_parameter("g_idx", g_idx) + set_weight_attrs(g_idx, extra_weight_attrs) + layer.register_parameter("qzeros", qzeros) + set_weight_attrs(qzeros, extra_weight_attrs) + layer.register_parameter("scales", scales) + set_weight_attrs(scales, extra_weight_attrs) + + layer.exllama_state = exllama_state def apply_weights(self, - weights: Dict[str, Any], + layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - qweight = weights["qweight"] + qweight = layer.qweight out_shape = x.shape[:-1] + (qweight.shape[-1], ) reshaped_x = x.reshape(-1, x.shape[-1]) # exllama needs to shuffle the weight after the weight is loaded # here we do the shuffle on first forward pass - if weights["exllama_state"] == ExllamaState.UNINITIALIZED: + if layer.exllama_state == ExllamaState.UNINITIALIZED: if self.quant_config.desc_act: - weights["g_idx"] = torch.argsort(weights["g_idx"]).to( - torch.int) + layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int) else: - weights["g_idx"] = torch.empty((1, 1), device="meta") - weights["exllama_state"] = ExllamaState.READY - ops.gptq_shuffle(weights["qweight"], weights["g_idx"], + layer.g_idx.data = torch.empty((0, ), + device=layer.g_idx.device) + layer.exllama_state = ExllamaState.READY + ops.gptq_shuffle(layer.qweight, layer.g_idx, self.quant_config.weight_bits) - output = ops.gptq_gemm(reshaped_x, weights["qweight"], - weights["qzeros"], weights["scales"], - weights["g_idx"], - weights["exllama_state"] == ExllamaState.READY, + output = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros, + layer.scales, layer.g_idx, + layer.exllama_state == ExllamaState.READY, self.quant_config.weight_bits) if bias is not None: - output = output + bias + output.add_(bias) return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py index a6482c059cc41..bf0500f1155a1 100644 --- a/vllm/model_executor/layers/quantization/marlin.py +++ b/vllm/model_executor/layers/quantization/marlin.py @@ -91,12 +91,14 @@ class MarlinLinearMethod(LinearMethodBase): def create_weights( self, + layer: torch.nn.Module, input_size_per_partition: int, output_size_per_partition: int, input_size: int, output_size: int, params_dtype: torch.dtype, - ) -> Dict[str, Any]: + **extra_weight_attrs, + ): del output_size # Unused. if params_dtype != torch.float16: @@ -187,21 +189,22 @@ class MarlinLinearMethod(LinearMethodBase): dtype=torch.int), requires_grad=False) - return { - "B": qweight, - "s": scales, - "workspace": workspace, - } + layer.register_parameter("B", qweight) + set_weight_attrs(qweight, extra_weight_attrs) + layer.register_parameter("s", scales) + set_weight_attrs(scales, extra_weight_attrs) + layer.register_parameter("workspace", workspace) + set_weight_attrs(workspace, extra_weight_attrs) def apply_weights( self, - weights: Dict[str, Any], + layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - qweight = weights["B"] - scales = weights["s"] - workspace = weights["workspace"] + qweight = layer.B + scales = layer.s + workspace = layer.workspace x_2d = x.view(-1, x.shape[-1]) diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py index bb295df2acc3f..661ff9c55d0d1 100644 --- a/vllm/model_executor/layers/quantization/squeezellm.py +++ b/vllm/model_executor/layers/quantization/squeezellm.py @@ -68,10 +68,11 @@ class SqueezeLLMLinearMethod(LinearMethodBase): def __init__(self, quant_config: SqueezeLLMConfig): self.quant_config = quant_config - def create_weights(self, input_size_per_partition: int, + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, output_size_per_partition: int, input_size: int, - output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): if input_size_per_partition % self.quant_config.pack_factor != 0: raise ValueError( "The input size is not aligned with the quantized " @@ -103,17 +104,18 @@ class SqueezeLLMLinearMethod(LinearMethodBase): set_weight_attrs(lookup_table, { "output_dim": 0, }) - return { - "qweight": qweight, - "lookup_table": lookup_table, - } + + layer.register_parameter("qweight", qweight) + set_weight_attrs(qweight, extra_weight_attrs) + layer.register_parameter("lookup_table", lookup_table) + set_weight_attrs(lookup_table, extra_weight_attrs) def apply_weights(self, - weights: Dict[str, Any], + layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - qweight = weights["qweight"] - lookup_table = weights["lookup_table"] + qweight = layer.qweight + lookup_table = layer.lookup_table out_shape = x.shape[:-1] + (qweight.shape[-1], ) reshaped_x = x.reshape(-1, x.shape[-1]) if is_hip(): @@ -126,5 +128,5 @@ class SqueezeLLMLinearMethod(LinearMethodBase): ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table) if bias is not None: - out = out + bias + out.add_(bias) return out.reshape(out_shape) From 559eb852f83fe7867390dd2986b4f93a6572cf10 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 11 Apr 2024 14:00:48 -0700 Subject: [PATCH 18/50] [Core] init_distributed_environment align with init_process_group(#4014) [Core][Distributed] make init_distributed_environment compatible with init_process_group (#4014) --- vllm/distributed/parallel_state.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 4bb77146295af..9fceffe7cb88b 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -39,9 +39,9 @@ _PIPELINE_GLOBAL_RANKS = None def init_distributed_environment( - world_size: int, - rank: int, - distributed_init_method: Optional[str] = None, + world_size: int = -1, + rank: int = -1, + distributed_init_method: str = "env://", local_rank: int = -1, backend: str = "nccl", ): From 95e7d4a97cd64f8c6dc226ec0bbceebef6458701 Mon Sep 17 00:00:00 2001 From: Dylan Hawk <51147702+dylanwhawk@users.noreply.github.com> Date: Thu, 11 Apr 2024 15:15:50 -0700 Subject: [PATCH 19/50] Fix echo/logprob OpenAI completion bug (#3441) Co-authored-by: Dylan Hawk --- tests/entrypoints/test_openai_server.py | 31 ++++++++++++ vllm/entrypoints/openai/serving_chat.py | 9 ++-- vllm/entrypoints/openai/serving_completion.py | 15 ++++-- vllm/entrypoints/openai/serving_engine.py | 47 +++++++++++-------- 4 files changed, 73 insertions(+), 29 deletions(-) diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 6f2086c4dd269..7940430b8b654 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -742,5 +742,36 @@ number: "1" | "2" assert content.strip() == ground_truth +@pytest.mark.parametrize( + # first test base model, then test loras + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-lora2"], +) +async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI, + model_name: str): + tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) + # test using text and token IDs + for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]): + completion = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + echo=True, + logprobs=1) + + prompt_text = tokenizer.decode(prompt) if isinstance(prompt, + list) else prompt + assert (completion.choices[0].text is not None + and re.search(r"^" + prompt_text, completion.choices[0].text)) + logprobs = completion.choices[0].logprobs + assert logprobs is not None + assert len(logprobs.text_offset) > 5 + assert (len(logprobs.token_logprobs) > 5 + and logprobs.token_logprobs[0] is None) + assert (len(logprobs.top_logprobs) > 5 + and logprobs.top_logprobs[0] is None) + assert len(logprobs.tokens) > 5 + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 0980c3d3cb614..a03c5dc88108f 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -63,8 +63,9 @@ class OpenAIServingChat(OpenAIServing): request_id = f"cmpl-{random_uuid()}" try: - token_ids = self._validate_prompt_and_tokenize(request, - prompt=prompt) + # Tokenize/detokenize depending on prompt format (string/token list) + prompt_ids, prompt_text = self._validate_prompt_and_tokenize( + request, prompt=prompt) sampling_params = request.to_sampling_params() lora_request = self._maybe_get_lora(request) guided_decode_logits_processor = ( @@ -78,8 +79,8 @@ class OpenAIServingChat(OpenAIServing): except ValueError as e: return self.create_error_response(str(e)) - result_generator = self.engine.generate(prompt, sampling_params, - request_id, token_ids, + result_generator = self.engine.generate(prompt_text, sampling_params, + request_id, prompt_ids, lora_request) # Streaming response if request.stream: diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 06e7a9225fefb..c1f1744a118bd 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -136,23 +136,24 @@ class OpenAIServingCompletion(OpenAIServing): for i, prompt in enumerate(prompts): if prompt_is_tokens: - input_ids = self._validate_prompt_and_tokenize( + prompt_formats = self._validate_prompt_and_tokenize( request, prompt_ids=prompt, truncate_prompt_tokens=sampling_params. truncate_prompt_tokens) else: - input_ids = self._validate_prompt_and_tokenize( + prompt_formats = self._validate_prompt_and_tokenize( request, prompt=prompt, truncate_prompt_tokens=sampling_params. truncate_prompt_tokens) + prompt_ids, prompt_text = prompt_formats generators.append( - self.engine.generate(prompt, + self.engine.generate(prompt_text, sampling_params, f"{request_id}-{i}", - prompt_token_ids=input_ids, + prompt_token_ids=prompt_ids, lora_request=lora_request)) except ValueError as e: # TODO: Use a vllm-specific Validation Error @@ -326,7 +327,8 @@ class OpenAIServingCompletion(OpenAIServing): output_text = prompt_text elif request.echo and request.max_tokens > 0: token_ids = prompt_token_ids + output.token_ids - top_logprobs = prompt_logprobs + output.logprobs + top_logprobs = (prompt_logprobs + output.logprobs + if request.logprobs else None) output_text = prompt_text + output.text else: token_ids = output.token_ids @@ -334,6 +336,9 @@ class OpenAIServingCompletion(OpenAIServing): output_text = output.text if request.logprobs is not None: + assert top_logprobs is not None, ( + "top_logprobs must be provided when logprobs " + "is requested") logprobs = self._create_logprobs( token_ids=token_ids, top_logprobs=top_logprobs, diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 8f69388c0251e..77a568b564039 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -2,7 +2,7 @@ import asyncio import json from dataclasses import dataclass from http import HTTPStatus -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Tuple, Union from pydantic import conint @@ -99,27 +99,32 @@ class OpenAIServing: last_token_len = 0 if num_output_top_logprobs: logprobs.top_logprobs = [] + for i, token_id in enumerate(token_ids): step_top_logprobs = top_logprobs[i] - if step_top_logprobs is not None: - token_logprob = step_top_logprobs[token_id].logprob + if step_top_logprobs is None: + token = self.tokenizer.decode(token_id) + logprobs.tokens.append(token) + logprobs.token_logprobs.append(None) + logprobs.top_logprobs.append(None) else: - token_logprob = None - token = step_top_logprobs[token_id].decoded_token - logprobs.tokens.append(token) - logprobs.token_logprobs.append(token_logprob) + token_logprob = step_top_logprobs[token_id].logprob + token = step_top_logprobs[token_id].decoded_token + logprobs.tokens.append(token) + logprobs.token_logprobs.append(token_logprob) + + if num_output_top_logprobs: + logprobs.top_logprobs.append({ + p.decoded_token: p.logprob + for i, p in step_top_logprobs.items() + } if step_top_logprobs else None) + if len(logprobs.text_offset) == 0: logprobs.text_offset.append(initial_text_offset) else: logprobs.text_offset.append(logprobs.text_offset[-1] + last_token_len) last_token_len = len(token) - - if num_output_top_logprobs: - logprobs.top_logprobs.append({ - p.decoded_token: p.logprob - for i, p in step_top_logprobs.items() - } if step_top_logprobs else None) return logprobs def create_error_response( @@ -164,12 +169,12 @@ class OpenAIServing: raise ValueError("The model `{request.model}` does not exist.") def _validate_prompt_and_tokenize( - self, - request: Union[ChatCompletionRequest, CompletionRequest], - prompt: Optional[str] = None, - prompt_ids: Optional[List[int]] = None, - truncate_prompt_tokens: Optional[conint(ge=1)] = None - ) -> List[int]: + self, + request: Union[ChatCompletionRequest, CompletionRequest], + prompt: Optional[str] = None, + prompt_ids: Optional[List[int]] = None, + truncate_prompt_tokens: Optional[conint(ge=1)] = None + ) -> Tuple[List[int], str]: if not (prompt or prompt_ids): raise ValueError("Either prompt or prompt_ids should be provided.") if (prompt and prompt_ids): @@ -187,6 +192,8 @@ class OpenAIServing: else: input_ids = prompt_ids + input_text = prompt if prompt is not None else self.tokenizer.decode( + prompt_ids) token_num = len(input_ids) if request.max_tokens is None: @@ -201,4 +208,4 @@ class OpenAIServing: f"{request.max_tokens} in the completion). " f"Please reduce the length of the messages or completion.", ) else: - return input_ids + return input_ids, input_text From 1e96c3341a4e055ae392085fecc7a672295b71c2 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 11 Apr 2024 15:18:57 -0700 Subject: [PATCH 20/50] Add extra punica sizes to support bigger vocabs (#4015) --- csrc/punica/bgmv/bgmv_config.h | 12 +++++- csrc/punica/punica_ops.cc | 14 +++--- tests/lora/test_layers.py | 78 +++++++++++++++++++--------------- tests/lora/test_punica.py | 49 +++++++++++++++++++-- vllm/lora/layers.py | 4 +- 5 files changed, 109 insertions(+), 48 deletions(-) diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h index 1084a0f20df6b..9b76b98ab3322 100644 --- a/csrc/punica/bgmv/bgmv_config.h +++ b/csrc/punica/bgmv/bgmv_config.h @@ -60,7 +60,17 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 33024) \ f(in_T, out_T, W_T, narrow, 36864) \ f(in_T, out_T, W_T, narrow, 49152) \ -// Keep above in sync with vllm/lora/layers::SamplerWithLoRA + f(in_T, out_T, W_T, narrow, 64000) \ + f(in_T, out_T, W_T, narrow, 64256) \ + f(in_T, out_T, W_T, narrow, 64512) \ + f(in_T, out_T, W_T, narrow, 102400) \ + f(in_T, out_T, W_T, narrow, 102656) \ + f(in_T, out_T, W_T, narrow, 102912) \ + f(in_T, out_T, W_T, narrow, 128000) \ + f(in_T, out_T, W_T, narrow, 128256) \ + f(in_T, out_T, W_T, narrow, 128512) \ +// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA +// and vllm/tests/lora/test_punica.py // Keep this in sync with vllm/config::LoRAConfig #define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \ diff --git a/csrc/punica/punica_ops.cc b/csrc/punica/punica_ops.cc index 28739be14b862..7ebfd851c4feb 100644 --- a/csrc/punica/punica_ops.cc +++ b/csrc/punica/punica_ops.cc @@ -20,8 +20,8 @@ inline void check_shape(const torch::Tensor &a, const torch::Tensor &b, } } -inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) { - return (uint32_t(a) << 16) | uint32_t(b); +inline constexpr uint64_t pack_u32(uint32_t a, uint32_t b) { + return (uint64_t(a) << 32) | uint64_t(b); } #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") @@ -46,13 +46,13 @@ inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) { template inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W, const int64_t *lora_indices, - uint16_t in_features, uint16_t out_features, + uint32_t in_features, uint32_t out_features, int64_t y_offset, int64_t full_y_size, int64_t batch_size, int64_t num_layers, int64_t layer_idx, float scale) { - switch (pack_u16(in_features, out_features)) { + switch (pack_u32(in_features, out_features)) { #define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \ - case pack_u16(feat_in, feat_out): \ + case pack_u32(feat_in, feat_out): \ bgmv_kernel(Y, X, W, lora_indices, y_offset, \ full_y_size, batch_size, num_layers, \ layer_idx, scale); \ @@ -93,7 +93,7 @@ void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, CHECK_EQ(y.size(0), x.size(0)); const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); bool ok = false; - if (h_in < 65536 && h_out < 65536) { + if (h_in <= 128512 && h_out <= 128512) { // TODO: See if we can get rid of this massive nested switch switch (x.scalar_type()) { case at::ScalarType::Half: @@ -325,7 +325,7 @@ void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w, CHECK_EQ(y.size(0), x.size(0)); const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); bool ok = false; - if (h_in < 65536 && h_out < 65536) { + if (h_in <= 128512 && h_out <= 128512) { // TODO: See if we can get rid of this massive nested switch switch (x.scalar_type()) { case at::ScalarType::Half: diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 71ce6f1764832..e9e0c8554c1ef 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -170,7 +170,8 @@ def create_random_inputs( @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("device", CUDA_DEVICES) -def test_embeddings(dist_init, num_loras, device) -> None: +@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) +def test_embeddings(dist_init, num_loras, device, vocab_size) -> None: torch.set_default_device(device) max_loras = 8 @@ -179,9 +180,9 @@ def test_embeddings(dist_init, num_loras, device) -> None: lora_dtype=torch.float16) def create_random_embedding_layer(): - embedding = VocabParallelEmbedding(512, 256) + embedding = VocabParallelEmbedding(vocab_size, 256) embedding.weight.data = torch.rand_like(embedding.weight.data) - embedding.weight.data[512:, :] = 0 + embedding.weight.data[vocab_size:, :] = 0 lora_embedding = VocabParallelEmbeddingWithLoRA(embedding) lora_embedding.create_lora_weights(max_loras, lora_config) @@ -203,12 +204,13 @@ def test_embeddings(dist_init, num_loras, device) -> None: active_lora_ids=list(lora_dict.keys()), num_inputs=num_loras * 3, input_size=(200, ), - input_range=(1, 512), + input_range=(1, vocab_size), ) lora_mapping = LoRAMapping(index_mapping, prompt_mapping) mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, - 512, lora_config.lora_extra_vocab_size) + vocab_size, + lora_config.lora_extra_vocab_size) lora_embedding.set_mapping(*mapping_info) lora_result = lora_embedding(torch.cat(inputs)) @@ -240,12 +242,13 @@ def test_embeddings(dist_init, num_loras, device) -> None: active_lora_ids=[0], num_inputs=num_loras * 3, input_size=(200, ), - input_range=(1, 512), + input_range=(1, vocab_size), ) lora_mapping = LoRAMapping(index_mapping, prompt_mapping) mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, - 512, lora_config.lora_extra_vocab_size) + vocab_size, + lora_config.lora_extra_vocab_size) lora_embedding.set_mapping(*mapping_info, ) lora_result = lora_embedding(torch.cat(inputs)) @@ -263,7 +266,9 @@ def test_embeddings(dist_init, num_loras, device) -> None: # reason="Fails when loras are in any slot other than the first.") @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("device", CUDA_DEVICES) -def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None: +@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) +def test_embeddings_with_new_embeddings(dist_init, num_loras, device, + vocab_size) -> None: torch.set_default_device(device) max_loras = 8 @@ -272,15 +277,15 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None: lora_dtype=torch.float16) def create_random_embedding_layer(): - embedding = VocabParallelEmbedding(512, 256) + embedding = VocabParallelEmbedding(vocab_size, 256) embedding_data = torch.rand_like(embedding.weight.data) embedding.weight.data = embedding_data - embedding.weight.data[512:, :] = 0 + embedding.weight.data[vocab_size:, :] = 0 expanded_embedding = VocabParallelEmbedding( - 512 + lora_config.lora_extra_vocab_size * max_loras, + vocab_size + lora_config.lora_extra_vocab_size * max_loras, 256, - org_num_embeddings=512) - expanded_embedding.weight.data[:512, :] = embedding_data + org_num_embeddings=vocab_size) + expanded_embedding.weight.data[:vocab_size, :] = embedding_data # We need to deepcopy the embedding as it will be modified # in place lora_embedding = VocabParallelEmbeddingWithLoRA( @@ -298,7 +303,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None: id_to_index, layer=lora_embedding, layer_weights=torch.zeros( - (256, 512 + lora_config.lora_extra_vocab_size)), + (256, vocab_size + lora_config.lora_extra_vocab_size)), generate_embeddings_tensor=256, ) @@ -316,7 +321,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None: active_lora_ids=list(lora_dict.keys()), num_inputs=num_loras * 3, input_size=(200, ), - input_range=(1, 512), + input_range=(1, vocab_size), ) lora_mapping = LoRAMapping(index_mapping, prompt_mapping) @@ -327,16 +332,18 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None: for input_, original_input_, lora_id in zip(inputs, original_inputs, prompt_mapping): embedding_id = lora_id - 1 - input_[-1] = 512 + (embedding_id * embeddings_tensor_len) - original_input_[-1] = 512 - input_[-2] = 512 + ((embedding_id + 1) * embeddings_tensor_len - 1) - original_input_[-2] = 512 + embeddings_tensor_len - 1 + input_[-1] = vocab_size + (embedding_id * embeddings_tensor_len) + original_input_[-1] = vocab_size + input_[-2] = vocab_size + ( + (embedding_id + 1) * embeddings_tensor_len - 1) + original_input_[-2] = vocab_size + embeddings_tensor_len - 1 mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, - 512, lora_config.lora_extra_vocab_size) + vocab_size, + lora_config.lora_extra_vocab_size) lora_embedding.set_mapping(*mapping_info, ) - expanded_embedding.weight[512:512 + + expanded_embedding.weight[vocab_size:vocab_size + (embeddings_tensor_len * max_loras)] = torch.cat(embeddings_tensors) @@ -370,14 +377,15 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None: active_lora_ids=[0], num_inputs=num_loras * 3, input_size=(200, ), - input_range=(1, 512), + input_range=(1, vocab_size), ) lora_mapping = LoRAMapping(index_mapping, prompt_mapping) original_inputs = deepcopy(inputs) mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, - 512, lora_config.lora_extra_vocab_size) + vocab_size, + lora_config.lora_extra_vocab_size) lora_embedding.set_mapping(*mapping_info, ) lora_result = lora_embedding(torch.cat(original_inputs)) @@ -393,7 +401,9 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None: @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("device", CUDA_DEVICES) -def test_lm_head_logits_processor(dist_init, num_loras, device) -> None: +@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) +def test_lm_head_logits_processor(dist_init, num_loras, device, + vocab_size) -> None: torch.set_default_device(device) max_loras = 8 @@ -402,12 +412,12 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None: lora_dtype=torch.float16) def _pretest(): - linear = ParallelLMHead(32000 + lora_config.lora_extra_vocab_size, - 1024, 32000) + linear = ParallelLMHead(vocab_size + lora_config.lora_extra_vocab_size, + 1024, vocab_size) linear.weight.data = torch.rand_like(linear.weight.data) - linear.weight.data[:, 32000:] = 0 + linear.weight.data[:, vocab_size:] = 0 logits_processor = LogitsProcessor( - 32000 + lora_config.lora_extra_vocab_size, 32000) + vocab_size + lora_config.lora_extra_vocab_size, vocab_size) lora_logits_processor = LogitsProcessorWithLoRA( logits_processor, 1024, linear.weight.dtype, linear.weight.device) lora_logits_processor.create_lora_weights(max_loras, lora_config) @@ -444,7 +454,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None: lora_mapping, id_to_index, max_loras, - 32000, + vocab_size, lora_config.lora_extra_vocab_size, ) lora_logits_processor.set_mapping(*mapping_info, ) @@ -460,7 +470,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None: org_vocab_size:logits_processor.org_vocab_size + embeddings_tensor_len] = embeddings_tensor - logits_processor.org_vocab_size = (32000 + + logits_processor.org_vocab_size = (vocab_size + lora_config.lora_extra_vocab_size) expected_results = [] for input_, lora_id in zip(inputs, prompt_mapping): @@ -468,11 +478,11 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None: result = logits_processor._get_logits(hidden_states=input_, embedding=linear.weight, embedding_bias=None) - result[:, 32000 + embeddings_tensor_len:] = float("-inf") + result[:, vocab_size + embeddings_tensor_len:] = float("-inf") result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling expected_results.append(result) expected_result = torch.cat(expected_results) - logits_processor.org_vocab_size = 32000 + logits_processor.org_vocab_size = vocab_size # Check that resetting the lora weights succeeds @@ -489,14 +499,14 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None: lora_mapping = LoRAMapping(index_mapping, prompt_mapping) mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, - 32000, + vocab_size, lora_config.lora_extra_vocab_size) lora_logits_processor.set_mapping(*mapping_info, ) lora_result = lora_logits_processor._get_logits( hidden_states=torch.cat(inputs), embedding=original_weight, - embedding_bias=None)[:, :32000] + embedding_bias=None)[:, :vocab_size] expected_result = logits_processor._get_logits( hidden_states=torch.cat(inputs), embedding=original_weight, diff --git a/tests/lora/test_punica.py b/tests/lora/test_punica.py index 2736a1c7ade27..cab8b44ccd2df 100644 --- a/tests/lora/test_punica.py +++ b/tests/lora/test_punica.py @@ -43,10 +43,51 @@ def _lora_ref_impl( H1 = H2 = [ - 128, 256, 512, 1024, 1152, 1280, 1536, 2048, 2304, 2560, 2752, 3072, 3456, - 3584, 4096, 4608, 5120, 5504, 5632, 6144, 6848, 6912, 7168, 8192, 9216, - 10240, 11008, 13824, 14336, 22016, 24576, 27392, 32000, 32256, 32512, - 32768, 33024 + 128, + 256, + 512, + 1024, + 1152, + 1280, + 1536, + 2048, + 2304, + 2560, + 2752, + 3072, + 3456, + 3584, + 4096, + 4608, + 5120, + 5504, + 5632, + 6144, + 6848, + 6912, + 7168, + 8192, + 9216, + 10240, + 11008, + 13824, + 14336, + 22016, + 24576, + 27392, + 32000, + 32256, + 32512, + 32768, + 33024, + 36864, + 49152, + 64000, + 64256, + 102400, + 102656, + 128000, + 128256, ] SEED = [0xabcdabcd987] CUDA_DEVICES = [ diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index a8ec4dcfd6137..5456b5613c47a 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -935,9 +935,9 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): model_config: Optional[PretrainedConfig] = None, ) -> None: # Keep this in sync with csrc/punica/bgmv/bgmv_config.h - if 32000 < self.base_layer.vocab_size > 33024: + if 32000 < self.base_layer.vocab_size > 128512: raise ValueError("When using LoRA, vocab size must be " - "32000 >= vocab_size <= 33024") + "32000 >= vocab_size <= 128512") self.lora_a_stacked = torch.zeros( ( max_loras, From e46a60aa4c90cf3dfd9b90782f2eeabbda935eef Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 11 Apr 2024 23:34:12 +0100 Subject: [PATCH 21/50] [BugFix] Fix handling of stop strings and stop token ids (#3672) --- tests/conftest.py | 2 +- .../{samplers => engine}/test_stop_reason.py | 2 +- tests/engine/test_stop_strings.py | 111 ++++++++++++++++++ vllm/engine/llm_engine.py | 106 +++++++++++------ vllm/outputs.py | 4 +- vllm/sampling_params.py | 9 ++ vllm/sequence.py | 6 + vllm/transformers_utils/detokenizer.py | 7 +- 8 files changed, 206 insertions(+), 41 deletions(-) rename tests/{samplers => engine}/test_stop_reason.py (97%) create mode 100644 tests/engine/test_stop_strings.py diff --git a/tests/conftest.py b/tests/conftest.py index a7e8963af0eda..5c50fc2d1bab6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -401,7 +401,7 @@ class VllmRunner: cleanup() -@pytest.fixture +@pytest.fixture(scope="session") def vllm_runner(): return VllmRunner diff --git a/tests/samplers/test_stop_reason.py b/tests/engine/test_stop_reason.py similarity index 97% rename from tests/samplers/test_stop_reason.py rename to tests/engine/test_stop_reason.py index b242c405a4fb6..b2f521a8ae4ce 100644 --- a/tests/samplers/test_stop_reason.py +++ b/tests/engine/test_stop_reason.py @@ -3,7 +3,7 @@ 2. One of the provided stop tokens 3. The EOS token -Run `pytest tests/samplers/test_stop_reason.py`. +Run `pytest tests/engine/test_stop_reason.py`. """ import pytest diff --git a/tests/engine/test_stop_strings.py b/tests/engine/test_stop_strings.py new file mode 100644 index 0000000000000..6b747beb4b543 --- /dev/null +++ b/tests/engine/test_stop_strings.py @@ -0,0 +1,111 @@ +from typing import Any, List, Optional + +import pytest + +from vllm import CompletionOutput, LLMEngine, SamplingParams + +MODEL = "meta-llama/llama-2-7b-hf" +MAX_TOKENS = 200 + + +@pytest.fixture(scope="session") +def vllm_model(vllm_runner): + return vllm_runner(MODEL) + + +@pytest.mark.skip_global_cleanup +def test_stop_basic(vllm_model): + _test_stopping(vllm_model.model.llm_engine, + stop=["."], + include_in_output=False, + expected_output="VLLM is a 100% volunteer organization", + expected_reason=".") + + _test_stopping(vllm_model.model.llm_engine, + stop=["."], + include_in_output=True, + expected_output="VLLM is a 100% volunteer organization.", + expected_reason=".") + + +@pytest.mark.skip_global_cleanup +def test_stop_multi_tokens(vllm_model): + _test_stopping( + vllm_model.model.llm_engine, + stop=["group of peo", "short"], + include_in_output=False, + expected_output="VLLM is a 100% volunteer organization. We are a ", + expected_reason="group of peo") + + _test_stopping( + vllm_model.model.llm_engine, + stop=["group of peo", "short"], + include_in_output=True, + expected_output= + "VLLM is a 100% volunteer organization. We are a group of peo", + expected_reason="group of peo") + + +@pytest.mark.skip_global_cleanup +def test_stop_partial_token(vllm_model): + _test_stopping(vllm_model.model.llm_engine, + stop=["gani"], + include_in_output=False, + expected_output="VLLM is a 100% volunteer or", + expected_reason="gani") + + _test_stopping(vllm_model.model.llm_engine, + stop=["gani"], + include_in_output=True, + expected_output="VLLM is a 100% volunteer organi", + expected_reason="gani") + + +@pytest.mark.skip_global_cleanup +def test_stop_token_id(vllm_model): + # token id 13013 => " organization" + + _test_stopping(vllm_model.model.llm_engine, + stop_token_ids=[13013], + include_in_output=False, + expected_output="VLLM is a 100% volunteer", + expected_reason=13013) + + _test_stopping(vllm_model.model.llm_engine, + stop_token_ids=[13013], + include_in_output=True, + expected_output="VLLM is a 100% volunteer organization", + expected_reason=13013) + + +def _test_stopping(llm_engine: LLMEngine, + expected_output: str, + expected_reason: Any, + stop: Optional[List[str]] = None, + stop_token_ids: Optional[List[int]] = None, + include_in_output: bool = False) -> None: + llm_engine.add_request( + "id", "A story about vLLM:\n", + SamplingParams( + temperature=0.0, + max_tokens=MAX_TOKENS, + stop=stop, + stop_token_ids=stop_token_ids, + include_stop_str_in_output=include_in_output, + ), None) + + output: Optional[CompletionOutput] = None + output_text = "" + stop_reason = None + while llm_engine.has_unfinished_requests(): + (request_output, ) = llm_engine.step() + (output, ) = request_output.outputs + + # Ensure we don't backtrack + assert output.text.startswith(output_text) + output_text = output.text + stop_reason = output.stop_reason + + assert output is not None + assert output_text == expected_output + assert stop_reason == expected_reason diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index ddfdda898a5c6..a91629a630591 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -501,9 +501,11 @@ class LLMEngine: for seq, _ in child_seqs: if seq_group.sampling_params.detokenize: - self.detokenizer.decode_sequence_inplace( + new_char_count = self.detokenizer.decode_sequence_inplace( seq, seq_group.sampling_params) - self._check_stop(seq, seq_group.sampling_params) + else: + new_char_count = 0 + self._check_stop(seq, new_char_count, seq_group.sampling_params) # Non-beam search case if not seq_group.sampling_params.use_beam_search: @@ -798,9 +800,45 @@ class LLMEngine: time_e2e_requests=time_e2e_requests, ) - def _check_stop(self, seq: Sequence, + def _check_stop(self, seq: Sequence, new_char_count: int, sampling_params: SamplingParams) -> None: - """Stop the finished sequences.""" + """Stop the finished sequences. + + new_char_count is the number of chars added to the + sequence's output text for the newly generated token + """ + + # Check if the minimum number of tokens has been generated yet; + # skip the stop string/token checks if not + if seq.get_output_len() < sampling_params.min_tokens: + return + + # Check if the sequence has generated the EOS token. + if ((not sampling_params.ignore_eos) + and seq.get_last_token_id() == seq.eos_token_id): + seq.status = SequenceStatus.FINISHED_STOPPED + return + + # Check if a stop token was encountered. + # This assumes a single token produced per step. + last_token_id = seq.get_last_token_id() + if last_token_id in sampling_params.stop_token_ids: + if new_char_count and ( + not sampling_params.include_stop_str_in_output): + # Remove last token + seq.output_text = seq.output_text[:-new_char_count] + seq.status = SequenceStatus.FINISHED_STOPPED + seq.stop_reason = last_token_id + return + + # Check if any stop strings are matched. + stop_str = self._check_stop_strings(seq, new_char_count, + sampling_params) + if stop_str is not None: + seq.status = SequenceStatus.FINISHED_STOPPED + seq.stop_reason = stop_str + return + # Check if the sequence has reached max_model_len. if seq.get_len() > self.scheduler_config.max_model_len: seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED @@ -811,43 +849,37 @@ class LLMEngine: seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED return - # Check if the minimum number of tokens has been generated yet; - # skip the stop string/token checks if not - if seq.get_output_len() < sampling_params.min_tokens: - return + @staticmethod + def _check_stop_strings(seq: Sequence, new_char_count: int, + sampling_params: SamplingParams) -> Optional[str]: + """Check if any stop strings are matched and truncate sequence + output text accordingly. - if sampling_params.detokenize: - for stop_str in sampling_params.stop: - if seq.output_text.endswith(stop_str): - self._finalize_sequence(seq, sampling_params, stop_str) - seq.status = SequenceStatus.FINISHED_STOPPED - seq.stop_reason = stop_str - return - last_token_id = seq.get_last_token_id() - if last_token_id in sampling_params.stop_token_ids: - stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens( - last_token_id) - self._finalize_sequence(seq, sampling_params, stop_str) - seq.status = SequenceStatus.FINISHED_STOPPED - seq.stop_reason = last_token_id - return + Returns the stop string if matched or else None. + """ + if not new_char_count: + return None - # Check if the sequence has generated the EOS token. - if ((not sampling_params.ignore_eos) - and seq.get_last_token_id() == seq.eos_token_id): - seq.status = SequenceStatus.FINISHED_STOPPED - return + for stop_str in sampling_params.stop: + stop_string_len = len(stop_str) + # Avoid searching already-searched text. + stop_index = seq.output_text.find( + stop_str, -new_char_count - stop_string_len) + if stop_index == -1: + continue - def _finalize_sequence(self, seq: Sequence, - sampling_params: SamplingParams, - stop_string: str) -> None: - if sampling_params.include_stop_str_in_output: - return + if sampling_params.include_stop_str_in_output: + # Truncate to end of stop string. + stop_index += stop_string_len + if stop_index >= len(seq.output_text): + # No truncation required. + return stop_str - if stop_string and seq.output_text.endswith(stop_string): - # Truncate the output text so that the stop string is - # not included in the output. - seq.output_text = seq.output_text[:-len(stop_string)] + # Truncate the output text to either the beginning + # or end of the stop string. + seq.output_text = seq.output_text[:stop_index] + return stop_str + return None def add_lora(self, lora_request: LoRARequest) -> bool: return self.model_executor.add_lora(lora_request) diff --git a/vllm/outputs.py b/vllm/outputs.py index 61fe20bfc2744..d01be0eb0efd2 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -112,8 +112,10 @@ class RequestOutput: # always has the logprobs of the sampled tokens even if the # logprobs are not requested. include_logprobs = seq_group.sampling_params.logprobs is not None + text_buffer_length = seq_group.sampling_params.output_text_buffer_length outputs = [ - CompletionOutput(seqs.index(seq), seq.output_text, + CompletionOutput(seqs.index(seq), + seq.get_output_text_to_return(text_buffer_length), seq.get_output_token_ids(), seq.get_cumulative_logprob(), seq.output_logprobs if include_logprobs else None, diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 4fdc3c6dedaef..0b9787608798c 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -166,6 +166,13 @@ class SamplingParams: self.logits_processors = logits_processors self.include_stop_str_in_output = include_stop_str_in_output self.truncate_prompt_tokens = truncate_prompt_tokens + # Number of characters to hold back for stop string evaluation + # until sequence is finished. + if self.stop and not include_stop_str_in_output: + self.output_text_buffer_length = max(len(s) for s in self.stop) - 1 + else: + self.output_text_buffer_length = 0 + self._verify_args() if self.use_beam_search: self._verify_beam_search() @@ -226,6 +233,8 @@ class SamplingParams: and self.truncate_prompt_tokens < 1): raise ValueError(f"truncate_prompt_tokens must be >= 1, " f"got {self.truncate_prompt_tokens}") + if any(not stop_str for stop_str in self.stop): + raise ValueError("stop cannot contain an empty string.") if self.stop and not self.detokenize: raise ValueError( "stop strings are only supported when detokenize is True. " diff --git a/vllm/sequence.py b/vllm/sequence.py index 77029908c2218..cdb6cce6f0255 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -235,6 +235,12 @@ class Sequence: def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 + def get_output_text_to_return(self, buffer_length: int): + # We return the full output text if the sequence is finished. + truncate = buffer_length and not self.is_finished() + return self.output_text[:-buffer_length] if truncate else ( + self.output_text) + def hash_of_block(self, logical_idx: int) -> int: # TODO This can produce incorrect hash when block size > prompt size diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py index 486c1938e1e10..005932f1e3df4 100644 --- a/vllm/transformers_utils/detokenizer.py +++ b/vllm/transformers_utils/detokenizer.py @@ -87,12 +87,15 @@ class Detokenizer: prev_tokens.extend(next_iter_tokens) def decode_sequence_inplace(self, seq: Sequence, - prms: SamplingParams) -> None: + prms: SamplingParams) -> int: """Decodes the new token for a sequence. In-place operation. Args: seq: The sequence to decode. prms: The sampling parameters used to generate the sequence. + + Returns: + The number of characters added to the output text. """ all_input_ids = seq.get_token_ids() token_id_generated_this_iteration = all_input_ids[-1] @@ -151,6 +154,8 @@ class Detokenizer: seq.read_offset = read_offset seq.output_text += new_decoded_token_text + return len(new_decoded_token_text) + def _convert_tokens_to_string_with_added_encoders( tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], From c2b4a1bce9a7707179cdfab2fb498c20b2b221e6 Mon Sep 17 00:00:00 2001 From: Michael Feil <63565275+michaelfeil@users.noreply.github.com> Date: Thu, 11 Apr 2024 17:17:21 -0700 Subject: [PATCH 22/50] [Doc] Add typing hints / mypy types cleanup (#3816) Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com> --- benchmarks/backend_request_func.py | 62 ++++++++++--------- docs/source/conf.py | 3 +- setup.py | 5 +- vllm/core/block/interfaces.py | 31 ++++++---- vllm/engine/metrics.py | 10 ++- vllm/logger.py | 8 ++- .../model_executor/layers/rotary_embedding.py | 15 ++--- vllm/transformers_utils/config.py | 4 +- vllm/transformers_utils/configs/dbrx.py | 2 +- .../transformers_utils/tokenizers/baichuan.py | 10 +-- vllm/utils.py | 4 +- 11 files changed, 90 insertions(+), 64 deletions(-) diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index ad428bd1c3644..bab570252c929 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -27,8 +27,8 @@ class RequestFuncInput: class RequestFuncOutput: generated_text: str = "" success: bool = False - latency: float = 0 - ttft: float = 0 # Time to first token + latency: float = 0.0 + ttft: float = 0.0 # Time to first token itl: List[float] = field( default_factory=list) # List of inter-token latencies prompt_len: int = 0 @@ -58,23 +58,24 @@ async def async_request_tgi( output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len - ttft = 0 + ttft = 0.0 st = time.perf_counter() most_recent_timestamp = st try: async with session.post(url=api_url, json=payload) as response: if response.status == 200: - async for chunk in response.content: - chunk = chunk.strip() - if not chunk: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: continue - chunk = remove_prefix(chunk.decode("utf-8"), "data:") + chunk = remove_prefix(chunk_bytes.decode("utf-8"), + "data:") data = json.loads(chunk) timestamp = time.perf_counter() # First token - if ttft == 0: + if ttft == 0.0: ttft = time.perf_counter() - st output.ttft = ttft @@ -119,23 +120,24 @@ async def async_request_trt_llm( output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len - ttft = 0 + ttft = 0.0 st = time.perf_counter() most_recent_timestamp = st try: async with session.post(url=api_url, json=payload) as response: if response.status == 200: - async for chunk in response.content: - chunk = chunk.strip() - if not chunk: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: continue - chunk = remove_prefix(chunk.decode("utf-8"), "data:") + chunk = remove_prefix(chunk_bytes.decode("utf-8"), + "data:") data = json.loads(chunk) timestamp = time.perf_counter() # First token - if ttft == 0: + if ttft == 0.0: ttft = time.perf_counter() - st output.ttft = ttft @@ -151,7 +153,7 @@ async def async_request_trt_llm( output.success = True else: - output.error = response.reason + output.error = response.reason or "" output.success = False except Exception: output.success = False @@ -195,7 +197,7 @@ async def async_request_deepspeed_mii( output.generated_text = parsed_resp["text"][0] output.success = True else: - output.error = response.reason + output.error = response.reason or "" output.success = False except Exception: output.success = False @@ -234,19 +236,20 @@ async def async_request_openai_completions( output.prompt_len = request_func_input.prompt_len generated_text = "" - ttft = 0 + ttft = 0.0 st = time.perf_counter() most_recent_timestamp = st try: async with session.post(url=api_url, json=payload, headers=headers) as response: if response.status == 200: - async for chunk in response.content: - chunk = chunk.strip() - if not chunk: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: continue - chunk = remove_prefix(chunk.decode("utf-8"), "data: ") + chunk = remove_prefix(chunk_bytes.decode("utf-8"), + "data: ") if chunk == "[DONE]": latency = time.perf_counter() - st else: @@ -255,7 +258,7 @@ async def async_request_openai_completions( if data["choices"][0]["text"]: timestamp = time.perf_counter() # First token - if ttft == 0: + if ttft == 0.0: ttft = time.perf_counter() - st output.ttft = ttft @@ -315,19 +318,20 @@ async def async_request_openai_chat_completions( output.prompt_len = request_func_input.prompt_len generated_text = "" - ttft = 0 + ttft = 0.0 st = time.perf_counter() most_recent_timestamp = st try: async with session.post(url=api_url, json=payload, headers=headers) as response: if response.status == 200: - async for chunk in response.content: - chunk = chunk.strip() - if not chunk: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: continue - chunk = remove_prefix(chunk.decode("utf-8"), "data: ") + chunk = remove_prefix(chunk_bytes.decode("utf-8"), + "data: ") if chunk == "[DONE]": latency = time.perf_counter() - st else: @@ -337,7 +341,7 @@ async def async_request_openai_chat_completions( delta = data["choices"][0]["delta"] if delta.get("content", None): # First token - if ttft == 0: + if ttft == 0.0: ttft = time.perf_counter() - st output.ttft = ttft @@ -354,7 +358,7 @@ async def async_request_openai_chat_completions( output.success = True output.latency = latency else: - output.error = response.reason + output.error = response.reason or "" output.success = False except Exception: output.success = False diff --git a/docs/source/conf.py b/docs/source/conf.py index 44cda7c99cdd5..7a8c365ffb3bb 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -12,6 +12,7 @@ import logging import sys +from typing import List from sphinx.ext import autodoc @@ -45,7 +46,7 @@ templates_path = ['_templates'] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = [] +exclude_patterns: List[str] = [] # Exclude the prompt "$" when copying code copybutton_prompt_text = r"\$ " diff --git a/setup.py b/setup.py index 98c92f9196e7e..9f0814e9f3bff 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ import re import subprocess import sys from shutil import which -from typing import List +from typing import Dict, List import torch from packaging.version import Version, parse @@ -52,7 +52,7 @@ class CMakeExtension(Extension): class cmake_build_ext(build_ext): # A dict of extension directories that have been configured. - did_config = {} + did_config: Dict[str, bool] = {} # # Determine number of compilation jobs and optionally nvcc compile threads. @@ -261,6 +261,7 @@ def get_nvcc_cuda_version() -> Version: Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py """ + assert CUDA_HOME is not None, "CUDA_HOME is not set" nvcc_output = subprocess.check_output([CUDA_HOME + "/bin/nvcc", "-V"], universal_newlines=True) output = nvcc_output.split() diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py index 9f466566f096b..fbceacf0ec417 100644 --- a/vllm/core/block/interfaces.py +++ b/vllm/core/block/interfaces.py @@ -1,5 +1,5 @@ -from abc import ABC, abstractmethod, abstractproperty -from typing import Dict, List, Optional, Protocol +from abc import ABC, abstractmethod +from typing import Dict, FrozenSet, List, Optional, Protocol from vllm.utils import Device @@ -10,23 +10,28 @@ class Block(ABC): def append_token_ids(self, token_ids: List[int]) -> None: pass - @abstractproperty + @property + @abstractmethod def block_id(self) -> Optional[int]: pass - @abstractproperty + @property + @abstractmethod def token_ids(self) -> List[int]: pass - @abstractproperty + @property + @abstractmethod def num_empty_slots(self) -> int: pass - @abstractproperty + @property + @abstractmethod def is_full(self) -> bool: pass - @abstractproperty + @property + @abstractmethod def prev_block(self) -> Optional["Block"]: pass @@ -47,12 +52,13 @@ class Block(ABC): class BlockAllocator(ABC): @abstractmethod - def allocate_mutable(self, prev_block: Optional[Block]) -> Block: + def allocate_mutable(self, prev_block: Optional[Block], + device: Device) -> Block: pass @abstractmethod def allocate_immutable(self, prev_block: Optional[Block], - token_ids: List[int]) -> Block: + token_ids: List[int], device: Device) -> Block: pass @abstractmethod @@ -64,11 +70,12 @@ class BlockAllocator(ABC): pass @abstractmethod - def get_num_free_blocks(self) -> int: + def get_num_free_blocks(self, device: Device) -> int: pass - @abstractproperty - def all_block_ids(self) -> frozenset[int]: + @property + @abstractmethod + def all_block_ids(self) -> FrozenSet[int]: pass @abstractmethod diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index 905db52a1912b..02560907a1282 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -1,6 +1,6 @@ import time from dataclasses import dataclass -from typing import Dict, List +from typing import Dict, List, Protocol import numpy as np from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info, @@ -119,6 +119,12 @@ class Stats: time_e2e_requests: List[float] +class SupportsMetricsInfo(Protocol): + + def metrics_info(self) -> Dict[str, str]: + ... + + class StatLogger: """StatLogger is used LLMEngine to log to Promethus and Stdout.""" @@ -135,7 +141,7 @@ class StatLogger: self.labels = labels self.metrics = Metrics(labelnames=list(labels.keys())) - def info(self, type: str, obj: object) -> None: + def info(self, type: str, obj: SupportsMetricsInfo) -> None: if type == "cache_config": self.metrics.info_cache_config.info(obj.metrics_info()) diff --git a/vllm/logger.py b/vllm/logger.py index e5e46f5cce3fe..af9575085ef37 100644 --- a/vllm/logger.py +++ b/vllm/logger.py @@ -4,6 +4,7 @@ import logging import os import sys +from typing import Optional VLLM_CONFIGURE_LOGGING = int(os.getenv("VLLM_CONFIGURE_LOGGING", "1")) @@ -26,7 +27,7 @@ class NewLineFormatter(logging.Formatter): _root_logger = logging.getLogger("vllm") -_default_handler = None +_default_handler: Optional[logging.Handler] = None def _setup_logger(): @@ -55,7 +56,12 @@ def init_logger(name: str): # Use the same settings as above for root logger logger = logging.getLogger(name) logger.setLevel(os.getenv("LOG_LEVEL", "DEBUG")) + if VLLM_CONFIGURE_LOGGING: + if _default_handler is None: + raise ValueError( + "_default_handler is not set up. This should never happen!" + " Please open an issue on Github.") logger.addHandler(_default_handler) logger.propagate = False return logger diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index eb8d5f6dfb2a9..6519781c8a8eb 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -247,11 +247,12 @@ def _yarn_find_correction_dim(num_rotations: int, # Find dim range bounds based on rotations -def _yarn_find_correction_range(low_rot: int, - high_rot: int, - dim: int, - base: float = 10000, - max_position_embeddings: int = 2048) -> int: +def _yarn_find_correction_range( + low_rot: int, + high_rot: int, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048) -> Tuple[int, int]: low = math.floor( _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)) high = math.ceil( @@ -293,8 +294,8 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding): *, extrapolation_factor: float = 1, attn_factor: float = 1, - beta_fast: float = 32, - beta_slow: float = 1, + beta_fast: int = 32, + beta_slow: int = 1, ) -> None: self.scaling_factor = scaling_factor self.extrapolation_factor = extrapolation_factor diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 8a6ba6c5b396c..ce7a30dce72fa 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -1,10 +1,10 @@ -from typing import Optional +from typing import Dict, Optional from transformers import AutoConfig, PretrainedConfig from vllm.transformers_utils.configs import * -_CONFIG_REGISTRY = { +_CONFIG_REGISTRY: Dict[str, PretrainedConfig] = { "chatglm": ChatGLMConfig, "dbrx": DbrxConfig, "mpt": MPTConfig, diff --git a/vllm/transformers_utils/configs/dbrx.py b/vllm/transformers_utils/configs/dbrx.py index 3a19af7129e73..1d2724f22abd6 100644 --- a/vllm/transformers_utils/configs/dbrx.py +++ b/vllm/transformers_utils/configs/dbrx.py @@ -12,7 +12,7 @@ from transformers.utils import logging logger = logging.get_logger(__name__) -DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {} +DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {} # type: ignore class DbrxAttentionConfig(PretrainedConfig): diff --git a/vllm/transformers_utils/tokenizers/baichuan.py b/vllm/transformers_utils/tokenizers/baichuan.py index 02045bdcb2ccf..79894035cb1f1 100644 --- a/vllm/transformers_utils/tokenizers/baichuan.py +++ b/vllm/transformers_utils/tokenizers/baichuan.py @@ -16,11 +16,11 @@ logger = logging.get_logger(__name__) VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"} -PRETRAINED_VOCAB_FILES_MAP = { +PRETRAINED_VOCAB_FILES_MAP = { # type: ignore "vocab_file": {}, "tokenizer_file": {}, } -PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {} +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {} # type: ignore class BaichuanTokenizer(PreTrainedTokenizer): @@ -148,9 +148,9 @@ class BaichuanTokenizer(PreTrainedTokenizer): `Tuple(str)`: Paths to the files saved. """ if not os.path.isdir(save_directory): - logger.error(f"Vocabulary path ({save_directory}) " - "should be a directory") - return + raise ValueError(f"Vocabulary path ({save_directory}) " + "should be a directory") + out_vocab_file = os.path.join( save_directory, (filename_prefix + "-" if filename_prefix else "") + diff --git a/vllm/utils.py b/vllm/utils.py index fdb0a3768ab0d..669b65891d0db 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -294,7 +294,7 @@ def create_kv_caches_with_random( head_size: int, cache_dtype: Optional[Union[str, torch.dtype]], model_dtype: Optional[Union[str, torch.dtype]] = None, - seed: Optional[int] = 0, + seed: int = 0, device: Optional[str] = "cuda", ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: torch.random.manual_seed(seed) @@ -400,7 +400,7 @@ class CudaMemoryProfiler: gc.collect() -def str_to_int_tuple(s: str) -> Tuple[int]: +def str_to_int_tuple(s: str) -> Tuple[int, ...]: """Convert a string to a tuple of integers.""" try: return tuple(map(int, s.split(","))) From 1096717ae9e0b414ad625c1a12354dd1d949ffb1 Mon Sep 17 00:00:00 2001 From: Jee Li Date: Fri, 12 Apr 2024 12:02:44 +0800 Subject: [PATCH 23/50] [Core] Support LoRA on quantized models (#4012) --- tests/lora/conftest.py | 5 + tests/lora/test_quant_model.py | 179 +++++++++++++++++++++++++++++++++ vllm/config.py | 9 +- vllm/lora/layers.py | 67 +++++++----- 4 files changed, 234 insertions(+), 26 deletions(-) create mode 100644 tests/lora/test_quant_model.py diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 207c635e2dc86..1127cc33183c9 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -143,6 +143,11 @@ def baichuan_lora_files(): return snapshot_download(repo_id="jeeejeee/baichuan7b-text2sql-spider") +@pytest.fixture(scope="session") +def tinyllama_lora_files(): + return snapshot_download(repo_id="jashing/tinyllama-colorist-lora") + + @pytest.fixture def llama_2_7b_engine_extra_embeddings() -> nn.Module: cleanup() diff --git a/tests/lora/test_quant_model.py b/tests/lora/test_quant_model.py new file mode 100644 index 0000000000000..3d86a4366aa57 --- /dev/null +++ b/tests/lora/test_quant_model.py @@ -0,0 +1,179 @@ +# Adapted from +# https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/tests/lora/test_llama.py +from dataclasses import dataclass +from typing import List + +import pytest + +import vllm +from vllm.lora.request import LoRARequest + +from .conftest import cleanup + + +@dataclass +class ModelWithQuantization: + model_path: str + quantization: str + + +MODELS: List[ModelWithQuantization] = [ + ModelWithQuantization(model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", + quantization="AWQ"), + ModelWithQuantization(model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", + quantization="GPTQ"), +] + + +def do_sample(llm, lora_path: str, lora_id: int, max_tokens=256): + raw_prompts = [ + "Give me an orange-ish brown color", + "Give me a neon pink color", + ] + + def format_prompt_tuples(prompt): + return f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" + + prompts = [format_prompt_tuples(p) for p in raw_prompts] + + sampling_params = vllm.SamplingParams(temperature=0, + max_tokens=max_tokens, + stop=["<|im_end|>"]) + outputs = llm.generate( + prompts, + sampling_params, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) + if lora_id else None) + # Print the outputs. + generated_texts = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + generated_texts.append(generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + return generated_texts + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("tp_size", [1]) +def test_quant_model_lora(tinyllama_lora_files, model, tp_size): + # Cannot use as it will initialize torch.cuda too early... + # if torch.cuda.device_count() < tp_size: + # pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") + + llm = vllm.LLM(model=model.model_path, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + max_model_len=400, + tensor_parallel_size=tp_size, + quantization=model.quantization, + trust_remote_code=True) + + if model.quantization is None: + expected_no_lora_output = [ + "Here are some examples of orange-brown colors", + "I'm sorry, I don't have" + ] + expected_lora_output = [ + "#ff8050", + "#ff8080", + ] + elif model.quantization == "AWQ": + expected_no_lora_output = [ + "I'm sorry, I don't understand", + "I'm sorry, I don't understand", + ] + expected_lora_output = [ + "#f07700: A v", + "#f00000: A v", + ] + elif model.quantization == "GPTQ": + expected_no_lora_output = [ + "I'm sorry, I don't have", + "I'm sorry, I don't have", + ] + expected_lora_output = [ + "#f08800: This is", + "#f07788 \n#", + ] + + def expect_match(output, expected_output): + # HACK: GPTQ lora outputs are just incredibly unstable. + # Assert that the outputs changed. + if (model.quantization == "GPTQ" + and expected_output is expected_lora_output): + assert output != expected_no_lora_output + for i, o in enumerate(output): + assert o.startswith( + '#'), f"Expected example {i} to start with # but got {o}" + return + assert output == expected_output + + max_tokens = 10 + + print("lora adapter created") + output = do_sample(llm, + tinyllama_lora_files, + lora_id=0, + max_tokens=max_tokens) + expect_match(output, expected_no_lora_output) + + print("lora 1") + output = do_sample(llm, + tinyllama_lora_files, + lora_id=1, + max_tokens=max_tokens) + expect_match(output, expected_lora_output) + + print("no lora") + output = do_sample(llm, + tinyllama_lora_files, + lora_id=0, + max_tokens=max_tokens) + expect_match(output, expected_no_lora_output) + + print("lora 2") + output = do_sample(llm, + tinyllama_lora_files, + lora_id=2, + max_tokens=max_tokens) + expect_match(output, expected_lora_output) + + print("removing lora") + + del llm + cleanup() + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.skip("Requires multiple GPUs") +def test_quant_model_tp_equality(tinyllama_lora_files, model): + # Cannot use as it will initialize torch.cuda too early... + # if torch.cuda.device_count() < 2: + # pytest.skip(f"Not enough GPUs for tensor parallelism {2}") + + llm_tp1 = vllm.LLM(model=model.model_path, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + tensor_parallel_size=1, + quantization=model.quantization, + trust_remote_code=True) + output_tp1 = do_sample(llm_tp1, tinyllama_lora_files, lora_id=1) + + del llm_tp1 + cleanup() + + llm_tp2 = vllm.LLM(model=model.model_path, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + tensor_parallel_size=2, + quantization=model.quantization) + output_tp2 = do_sample(llm_tp2, tinyllama_lora_files, lora_id=1) + + del llm_tp2 + cleanup() + + assert output_tp1 == output_tp2 diff --git a/vllm/config.py b/vllm/config.py index 4102edbe01d35..da7eb2810ff05 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -822,9 +822,12 @@ class LoRAConfig: self.lora_dtype = model_config.dtype elif isinstance(self.lora_dtype, str): self.lora_dtype = getattr(torch, self.lora_dtype) - if model_config.quantization is not None: - raise ValueError( - "LoRA is not supported with quantized models yet.") + if model_config.quantization and model_config.quantization not in [ + "awq", "gptq" + ]: + # TODO support marlin and squeezellm + logger.warning(f"{model_config.quantization} quantization is not " + "tested with LoRA yet.") def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig): if scheduler_config.max_num_batched_tokens > 65528: diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 5456b5613c47a..4b9653de73a88 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -29,6 +29,19 @@ if TYPE_CHECKING: pass +def _get_lora_device(base_layer: nn.Module) -> torch.device: + # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34 + """Returns the device for where to place the LoRA tensors.""" + if hasattr(base_layer, "weight"): + return base_layer.weight.device + if hasattr(base_layer, "linear_weights") and isinstance( + base_layer.linear_weights, dict): + values = list(base_layer.linear_weights.values()) + if len(values) and isinstance(values[0], torch.Tensor): + return values[0].device + raise ValueError(f"Unsupported base layer: {base_layer}") + + def _apply_lora( x: torch.Tensor, lora_a_stacked: torch.Tensor, @@ -302,6 +315,9 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): super().__init__() self.base_layer = base_layer self.tp_size = get_tensor_model_parallel_world_size() + self.input_size = self.base_layer.input_size + self.output_size = self.base_layer.output_size_per_partition + self.device = _get_lora_device(self.base_layer) def create_lora_weights( self, @@ -312,17 +328,17 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): max_loras, 1, lora_config.max_lora_rank, - self.base_layer.weight.shape[1], + self.input_size, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ) self.lora_b_stacked = torch.zeros( max_loras, 1, - self.base_layer.weight.shape[0], + self.output_size, lora_config.max_lora_rank, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ) self.indices: Optional[torch.Tensor] = None @@ -442,18 +458,18 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): max_loras, 1, lora_config.max_lora_rank, - self.base_layer.weight.shape[1], + self.input_size, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ) for _ in range(n_slices)) self.lora_b_stacked = tuple( torch.zeros( max_loras, 1, - self.base_layer.weight.shape[0] // 2, + self.output_size // 2, lora_config.max_lora_rank, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ) for _ in range(n_slices)) self.indices: Optional[torch.Tensor] = None @@ -619,25 +635,25 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): max_loras, 1, lora_config.max_lora_rank, - self.base_layer.weight.shape[1], + self.input_size, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ), torch.zeros( max_loras, 1, lora_config.max_lora_rank, - self.base_layer.weight.shape[1], + self.input_size, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ), torch.zeros( max_loras, 1, lora_config.max_lora_rank, - self.base_layer.weight.shape[1], + self.input_size, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ), ) self.lora_b_stacked = ( @@ -647,7 +663,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): self.q_proj_shard_size, lora_config.max_lora_rank, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ), torch.zeros( max_loras, @@ -655,7 +671,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): self.kv_proj_shard_size, lora_config.max_lora_rank, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ), torch.zeros( max_loras, @@ -663,7 +679,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): self.kv_proj_shard_size, lora_config.max_lora_rank, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ), ) @@ -766,6 +782,9 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA): def __init__(self, base_layer: RowParallelLinear) -> None: super().__init__() self.base_layer = base_layer + self.input_size = self.base_layer.input_size_per_partition + self.output_size = self.base_layer.output_size + self.device = _get_lora_device(self.base_layer) def create_lora_weights( self, @@ -777,20 +796,20 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA): max_loras, 1, lora_config.max_lora_rank, - self.base_layer.weight.shape[1], + self.input_size, ), dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ) self.lora_b_stacked = torch.zeros( ( max_loras, 1, - self.base_layer.weight.shape[0], + self.output_size, lora_config.max_lora_rank, ), dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ) self.indices: Optional[torch.Tensor] = None self.indices_len: Optional[List[int]] = None @@ -809,7 +828,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA): self.reset_lora(index) if self.base_layer.tp_size > 1: tensor_model_parallel_rank = get_tensor_model_parallel_rank() - shard_size = self.base_layer.weight.shape[1] + shard_size = self.input_size start_idx = tensor_model_parallel_rank * shard_size end_idx = (tensor_model_parallel_rank + 1) * shard_size lora_a = lora_a[start_idx:end_idx, :] @@ -884,7 +903,9 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA): @property def weight(self): - return self.base_layer.weight + + return self.base_layer.weight if hasattr( + self.base_layer, "weight") else self.base_layer.qweight @classmethod def can_replace_layer(cls, source_layer: nn.Module, From 7fd3949a0b1c6cd0dcd7066aca48d9d589f2f68e Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 12 Apr 2024 13:30:54 +0800 Subject: [PATCH 24/50] [Frontend][Core] Move `merge_async_iterators` to utils (#4026) --- vllm/entrypoints/openai/serving_completion.py | 38 +----------------- vllm/utils.py | 40 ++++++++++++++++++- 2 files changed, 39 insertions(+), 39 deletions(-) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index c1f1744a118bd..e24aa2489a80f 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -1,4 +1,3 @@ -import asyncio import time from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List, Optional, Tuple) @@ -17,7 +16,7 @@ from vllm.logger import init_logger from vllm.model_executor.guided_decoding import ( get_guided_decoding_logits_processor) from vllm.outputs import RequestOutput -from vllm.utils import random_uuid +from vllm.utils import merge_async_iterators, random_uuid logger = init_logger(__name__) @@ -50,41 +49,6 @@ def parse_prompt_format(prompt) -> Tuple[bool, list]: return prompt_is_tokens, prompts -def merge_async_iterators(*iterators): - """Merge multiple asynchronous iterators into a single iterator. - - This method handle the case where some iterators finish before others. - When it yields, it yields a tuple (i, item) where i is the index of the - iterator that yields the item. - """ - queue = asyncio.Queue() - - finished = [False] * len(iterators) - - async def producer(i, iterator): - try: - async for item in iterator: - await queue.put((i, item)) - except Exception as e: - await queue.put(e) - finished[i] = True - - _tasks = [ - asyncio.create_task(producer(i, iterator)) - for i, iterator in enumerate(iterators) - ] - - async def consumer(): - while not all(finished) or not queue.empty(): - item = await queue.get() - if isinstance(item, Exception): - raise item - yield item - await asyncio.gather(*_tasks) - - return consumer() - - class OpenAIServingCompletion(OpenAIServing): def __init__(self, diff --git a/vllm/utils.py b/vllm/utils.py index 669b65891d0db..0967dfc969c8a 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -9,8 +9,8 @@ import warnings from collections import OrderedDict, defaultdict from functools import lru_cache, partial from platform import uname -from typing import (Any, Awaitable, Callable, Dict, Generic, Hashable, List, - Optional, Tuple, TypeVar, Union) +from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic, + Hashable, List, Optional, Tuple, TypeVar, Union) import psutil import torch @@ -181,6 +181,42 @@ def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]: return _async_wrapper +def merge_async_iterators( + *iterators: AsyncIterator[T]) -> AsyncIterator[Tuple[int, T]]: + """Merge multiple asynchronous iterators into a single iterator. + + This method handle the case where some iterators finish before others. + When it yields, it yields a tuple (i, item) where i is the index of the + iterator that yields the item. + """ + queue: asyncio.Queue[Union[Tuple[int, T], Exception]] = asyncio.Queue() + + finished = [False] * len(iterators) + + async def producer(i: int, iterator: AsyncIterator[T]): + try: + async for item in iterator: + await queue.put((i, item)) + except Exception as e: + await queue.put(e) + finished[i] = True + + _tasks = [ + asyncio.create_task(producer(i, iterator)) + for i, iterator in enumerate(iterators) + ] + + async def consumer(): + while not all(finished) or not queue.empty(): + item = await queue.get() + if isinstance(item, Exception): + raise item + yield item + await asyncio.gather(*_tasks) + + return consumer() + + def get_ip() -> str: host_ip = os.environ.get("HOST_IP") if host_ip: From 36729bac1303b655b816b77f45b17237bfafd692 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Sat, 13 Apr 2024 01:56:57 +0900 Subject: [PATCH 25/50] [Test] Test multiple attn backend for chunked prefill. (#4023) --- .buildkite/test-pipeline.yaml | 8 +++++++- .../test_basic_correctness.py | 6 ------ .../basic_correctness/test_chunked_prefill.py | 4 ---- vllm/attention/backends/rocm_flash_attn.py | 18 ++++++------------ 4 files changed, 13 insertions(+), 23 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 695290ed74ab5..8d7d6304cf12e 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -12,7 +12,13 @@ steps: command: pytest -v -s async_engine - label: Basic Correctness Test - command: pytest -v -s basic_correctness + commands: + - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py + - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py + - VLLM_ATTENTION_BACKEND=ROCM_FLASH pytest -v -s basic_correctness/test_basic_correctness.py + - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py + - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py + - VLLM_ATTENTION_BACKEND=ROCM_FLASH pytest -v -s basic_correctness/test_chunked_prefill.py - label: Core Test command: pytest -v -s core diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index bd4c7ea3301be..97cff623c5e1d 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -4,8 +4,6 @@ Run `pytest tests/basic_correctness/test_basic_correctness.py`. """ import pytest -from vllm.attention.selector import VLLM_ATTENTION_BACKEND - MODELS = [ "facebook/opt-125m", "meta-llama/Llama-2-7b-hf", @@ -16,7 +14,6 @@ MODELS = [ @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [5]) @pytest.mark.parametrize("enforce_eager", [False, True]) -@pytest.mark.parametrize("attn_backend", ["XFORMERS", "FLASH_ATTN"]) def test_models( hf_runner, vllm_runner, @@ -25,10 +22,7 @@ def test_models( dtype: str, max_tokens: int, enforce_eager: bool, - attn_backend: str, - monkeypatch, ) -> None: - monkeypatch.setenv(VLLM_ATTENTION_BACKEND, attn_backend) hf_model = hf_runner(model, dtype=dtype) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) del hf_model diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py index 9ff07b3c09020..d83416eb51b43 100644 --- a/tests/basic_correctness/test_chunked_prefill.py +++ b/tests/basic_correctness/test_chunked_prefill.py @@ -33,10 +33,6 @@ def test_models( enforce_eager: bool, tensor_parallel_size: int, ) -> None: - if (tensor_parallel_size == 2 and chunked_prefill_token_size != 16 - and not enforce_eager): - pytest.skip(f"Skip {chunked_prefill_token_size=} and {enforce_eager=} " - "for high TP to save testing time.") max_num_seqs = min(chunked_prefill_token_size, 256) enable_chunked_prefill = False max_num_batched_tokens = None diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index e55435cd2c947..c42660fb8f74f 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -162,7 +162,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): # AMD Radeon 7900 series (gfx1100) currently does not support # xFormers nor FlashAttention. As a temporary workaround, we use # naive PyTorch implementation of attention. - self.attn_fuc = _naive_attention() + self.attn_fuc = _naive_attention logger.debug("Using naive attention in ROCmBackend") elif self.use_triton_flash_attn: from vllm.attention.ops.triton_flash_attention import ( # noqa: F401 @@ -334,26 +334,21 @@ def _naive_attention( prompt_lens: List[int], scale: float, ) -> torch.Tensor: - num_tokens = query.shape[0] output = torch.empty_like(query) start = 0 for _, prompt_len in enumerate(prompt_lens): end = start + prompt_len out = _naive_masked_attention( - query[None, start:end], - key[None, start:end], - value[None, start:end], + query[start:end], + key[start:end], + value[start:end], scale, ) # TODO(woosuk): Unnecessary copy. Optimize. output[start:end].copy_(out) start += prompt_len - # Using view got RuntimeError: view size is not compatible - # with input tensor's size and stride (at least one - # dimension spans across two contiguous subspaces). - # Use reshape instead. - return output.reshape(num_tokens, -1) + return output def _naive_masked_attention( @@ -362,14 +357,13 @@ def _naive_masked_attention( value: torch.Tensor, scale: float, ) -> torch.Tensor: - seq_len, _, _ = query.shape + seq_len, head_size, head_dim = query.shape attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=query.dtype, device=query.device), diagonal=1) attn_mask = attn_mask * torch.finfo(query.dtype).min - attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() attn_weights = attn_weights + attn_mask.float() attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) From 96b6a6d790115d04bb87d410f3bdd5d7d85b43f1 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 12 Apr 2024 12:35:44 -0700 Subject: [PATCH 26/50] [Bugfix] fix type hint for py 3.8 (#4036) --- vllm/executor/executor_base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index c18edd75d7a4d..55bccfa8e3ca9 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, @@ -31,7 +31,7 @@ class ExecutorBase(ABC): raise NotImplementedError @abstractmethod - def determine_num_available_blocks(self) -> tuple[int, int]: + def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of available blocks for the GPU KV cache and swappable CPU KV cache. From d4ec9ffb9574988132d927fd1615180522877262 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Fri, 12 Apr 2024 13:56:04 -0700 Subject: [PATCH 27/50] [Misc] Fix typo in scheduler.py (#4022) --- vllm/core/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 2942eab735a92..e44f983e15374 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -674,7 +674,7 @@ class Scheduler: def _schedule_default(self) -> SchedulerOutputs: """Schedule queued requests. - The current policy is designed to opimimize the throughput. First, + The current policy is designed to optimize the throughput. First, it batches as many prefill requests as possible. And it schedules decodes. If there's a pressure on GPU memory, decode requests can be swapped or preempted. From 09473ee41c0a22c4d18936ea7eb2328071c19308 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Sat, 13 Apr 2024 06:35:50 +0900 Subject: [PATCH 28/50] [mypy] Add mypy type annotation part 1 (#4006) --- .github/workflows/mypy.yaml | 50 ++++++++++++++++++++++++++ format.sh | 22 +++++++++--- pyproject.toml | 5 ++- requirements-common.txt | 3 +- requirements-dev.txt | 2 +- vllm/config.py | 9 +++-- vllm/core/block_manager_v1.py | 12 ++++--- vllm/core/block_manager_v2.py | 4 ++- vllm/core/interfaces.py | 4 ++- vllm/core/scheduler.py | 25 +++++++------ vllm/distributed/communication_op.py | 10 +++--- vllm/engine/ray_utils.py | 18 ++++++---- vllm/entrypoints/api_server.py | 1 + vllm/entrypoints/llm.py | 8 +++-- vllm/executor/cpu_executor.py | 4 +-- vllm/executor/gpu_executor.py | 4 +-- vllm/executor/neuron_executor.py | 4 +-- vllm/executor/ray_gpu_executor.py | 11 +++--- vllm/sampling_params.py | 5 +-- vllm/sequence.py | 8 ++--- vllm/transformers_utils/config.py | 3 +- vllm/transformers_utils/detokenizer.py | 7 ++-- vllm/transformers_utils/tokenizer.py | 4 +-- vllm/usage/usage_lib.py | 8 ++--- vllm/utils.py | 12 ++++--- 25 files changed, 171 insertions(+), 72 deletions(-) create mode 100644 .github/workflows/mypy.yaml diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml new file mode 100644 index 0000000000000..fbe0f816fd4af --- /dev/null +++ b/.github/workflows/mypy.yaml @@ -0,0 +1,50 @@ +name: mypy + +on: + # Trigger the workflow on push or pull request, + # but only for the main branch + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + ruff: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.8"] + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install mypy==1.9.0 + pip install types-setuptools + pip install types-PyYAML + pip install types-requests + pip install types-setuptools + - name: Mypy + run: | + mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml + mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml + mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml + mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml + mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml + mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml + mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml + mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml + + # TODO(sang): Follow up + # mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml + # mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml + # mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml + # mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml + # mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml + diff --git a/format.sh b/format.sh index deb57b2b049d1..1c195b899c742 100755 --- a/format.sh +++ b/format.sh @@ -93,9 +93,23 @@ fi echo 'vLLM yapf: Done' # Run mypy -# TODO(zhuohan): Enable mypy -# echo 'vLLM mypy:' -# mypy +echo 'vLLM mypy:' +mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml +mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml +mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml +mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml +mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml +mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml +mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml +mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml + +# TODO(sang): Follow up +# mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml +# mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml +# mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml +# mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml +# mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml + CODESPELL_EXCLUDES=( '--skip' '*docs/source/_build/**' @@ -228,5 +242,3 @@ if ! git diff --quiet &>/dev/null; then exit 1 fi - - diff --git a/pyproject.toml b/pyproject.toml index 2a00d6796ee02..b870a4b85897b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,10 +46,13 @@ ignore = [ python_version = "3.8" ignore_missing_imports = true + check_untyped_defs = true files = "vllm" # TODO(woosuk): Include the code from Megatron and HuggingFace. -exclude = "vllm/model_executor/parallel_utils/|vllm/model_executor/models/" +exclude = [ + "vllm/model_executor/parallel_utils/|vllm/model_executor/models/", +] [tool.codespell] diff --git a/requirements-common.txt b/requirements-common.txt index ff053388a23e1..c96f9c9937fb0 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -11,4 +11,5 @@ uvicorn[standard] pydantic >= 2.0 # Required for OpenAI server. prometheus_client >= 0.18.0 tiktoken == 0.6.0 # Required for DBRX tokenizer -outlines == 0.0.34 # Requires torch >= 2.1.0 \ No newline at end of file +outlines == 0.0.34 # Requires torch >= 2.1.0 +typing_extensions \ No newline at end of file diff --git a/requirements-dev.txt b/requirements-dev.txt index 75d22bbdb2a1b..96dfda6faf00f 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -7,7 +7,7 @@ codespell==2.2.6 isort==5.13.2 # type checking -mypy==0.991 +mypy==1.9.0 types-PyYAML types-requests types-setuptools diff --git a/vllm/config.py b/vllm/config.py index da7eb2810ff05..bbda4ecf3cc56 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2,7 +2,7 @@ import enum import json import os from dataclasses import dataclass, fields -from typing import TYPE_CHECKING, ClassVar, Optional, Union +from typing import TYPE_CHECKING, ClassVar, List, Optional, Union import torch from packaging.version import Version @@ -141,7 +141,7 @@ class ModelConfig: supported_load_format = [ "auto", "pt", "safetensors", "npcache", "dummy" ] - rocm_not_supported_load_format = [] + rocm_not_supported_load_format: List[str] = [] if load_format not in supported_load_format: raise ValueError( f"Unknown load format: {self.load_format}. Must be one of " @@ -679,6 +679,9 @@ class SpeculativeConfig: "num_speculative_tokens to be provided, but found " f"{speculative_model=} and {num_speculative_tokens=}.") + assert (speculative_model is not None + and num_speculative_tokens is not None) + # TODO: The user should be able to specify revision/quantization/max # model len for the draft model. It is not currently supported. draft_revision = None @@ -993,7 +996,7 @@ def _get_and_verify_max_len( derived_max_model_len *= scaling_factor if max_model_len is None: - max_model_len = derived_max_model_len + max_model_len = int(derived_max_model_len) elif max_model_len > derived_max_model_len: # Some models might have a separate key for specifying model_max_length # that will be bigger than derived_max_model_len. We compare user input diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index e7e3b4dc1e9b4..e391a3b1e5a33 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -1,5 +1,6 @@ """A block manager that manages token blocks.""" from abc import ABC, abstractmethod +from collections.abc import Sequence as GenericSequence from itertools import count, takewhile from os.path import commonprefix from typing import Dict, List, Optional, Set @@ -231,10 +232,10 @@ class BlockSpaceManagerV1(BlockSpaceManager): if self.enable_caching: logger.info("Automatic prefix caching is enabled.") - self.gpu_allocator = CachedBlockAllocator(Device.GPU, block_size, - num_gpu_blocks) - self.cpu_allocator = CachedBlockAllocator(Device.CPU, block_size, - num_cpu_blocks) + self.gpu_allocator: BlockAllocatorBase = CachedBlockAllocator( + Device.GPU, block_size, num_gpu_blocks) + self.cpu_allocator: BlockAllocatorBase = CachedBlockAllocator( + Device.CPU, block_size, num_cpu_blocks) else: self.gpu_allocator = UncachedBlockAllocator( Device.GPU, block_size, num_gpu_blocks) @@ -588,7 +589,8 @@ class BlockSpaceManagerV1(BlockSpaceManager): for b in takewhile(lambda b: b.computed, block_table[:-1]) ] - def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]: + def get_common_computed_block_ids( + self, seqs: List[Sequence]) -> GenericSequence[int]: """Return the block ids that are common for a given sequence group. Used in prefill (can skip prefill of some blocks). diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 813e71ad883b2..19f0cf415eb34 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -1,4 +1,5 @@ """A block manager that manages token blocks.""" +from collections.abc import Sequence as GenericSequence from typing import Dict, List, Optional from vllm.core.block.block_table import BlockTable @@ -205,7 +206,8 @@ class BlockSpaceManagerV2(BlockSpaceManager): # as computed. self.block_allocator.mark_blocks_as_computed() - def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]: + def get_common_computed_block_ids( + self, seqs: List[Sequence]) -> GenericSequence[int]: """Determine which blocks for which we skip prefill. With prefix caching we can skip prefill for previously-generated blocks. diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index 711536bcc97be..c1f68a2e891bf 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -1,5 +1,6 @@ import enum from abc import ABC, abstractmethod +from collections.abc import Sequence as GenericSequence from typing import Dict, List from vllm.sequence import Sequence, SequenceGroup @@ -103,7 +104,8 @@ class BlockSpaceManager(ABC): pass @abstractmethod - def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]: + def get_common_computed_block_ids( + self, seqs: List[Sequence]) -> GenericSequence[int]: pass @abstractmethod diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index e44f983e15374..18ddcd1d6d466 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -42,8 +42,8 @@ class SchedulingBudget: """ token_budget: int max_num_seqs: int - _requeset_ids_num_batched_tokens: Set[int] = field(default_factory=set) - _requeset_ids_num_curr_seqs: Set[int] = field(default_factory=set) + _requeset_ids_num_batched_tokens: Set[str] = field(default_factory=set) + _requeset_ids_num_curr_seqs: Set[str] = field(default_factory=set) _num_batched_tokens: int = 0 _num_curr_seqs: int = 0 @@ -133,7 +133,7 @@ class SchedulerOutputs: return (not self.scheduled_seq_groups and not self.blocks_to_swap_in and not self.blocks_to_swap_out and not self.blocks_to_copy) - def _sort_by_lora_ids(self) -> bool: + def _sort_by_lora_ids(self): self.scheduled_seq_groups = sorted( self.scheduled_seq_groups, key=lambda g: (g.seq_group.lora_int_id, g.seq_group.request_id)) @@ -337,7 +337,8 @@ class Scheduler: self.free_seq(seq) def has_unfinished_seqs(self) -> bool: - return self.waiting or self.running or self.swapped + return len(self.waiting) != 0 or len(self.running) != 0 or len( + self.swapped) != 0 def get_num_unfinished_seq_groups(self) -> int: return len(self.waiting) + len(self.running) + len(self.swapped) @@ -404,7 +405,7 @@ class Scheduler: budget.subtract_num_seqs(seq_group.request_id, num_running_seqs) if curr_loras is not None and seq_group.lora_int_id > 0: - curr_loras.pop(seq_group.lora_int_id) + curr_loras.remove(seq_group.lora_int_id) if running_queue: # Preempt the lowest-priority sequence groups. @@ -496,7 +497,7 @@ class Scheduler: now = time.time() swapped_queue = policy.sort_by_priority(now, swapped_queue) - leftover_swapped = deque() + leftover_swapped: Deque[SequenceGroup] = deque() while swapped_queue: seq_group = swapped_queue[0] @@ -507,7 +508,9 @@ class Scheduler: lora_int_id = 0 if self.lora_enabled: lora_int_id = seq_group.lora_int_id - if (lora_int_id > 0 and lora_int_id not in curr_loras + assert curr_loras is not None + assert self.lora_config is not None + if (lora_int_id > 0 and (lora_int_id not in curr_loras) and len(curr_loras) >= self.lora_config.max_loras): # We don't have a space for another LoRA, so # we ignore this request for now. @@ -593,7 +596,7 @@ class Scheduler: # Copy the queue so that the input queue is not modified. waiting_queue = deque([s for s in waiting_queue]) - leftover_waiting_sequences = deque() + leftover_waiting_sequences: Deque[SequenceGroup] = deque() while self._passed_delay(time.time()) and waiting_queue: seq_group = waiting_queue[0] @@ -635,6 +638,8 @@ class Scheduler: lora_int_id = 0 if self.lora_enabled: lora_int_id = seq_group.lora_int_id + assert curr_loras is not None + assert self.lora_config is not None if (self.lora_enabled and lora_int_id > 0 and lora_int_id not in curr_loras and len(curr_loras) >= self.lora_config.max_loras): @@ -780,7 +785,7 @@ class Scheduler: token_budget=self.scheduler_config.max_num_batched_tokens, max_num_seqs=self.scheduler_config.max_num_seqs, ) - curr_loras = set() + curr_loras: Set[int] = set() remaining_waiting, prefills = (self.waiting, SchedulerPrefillOutputs.create_empty()) @@ -1087,7 +1092,7 @@ class Scheduler: def _get_num_new_tokens(self, seq_group: SequenceGroup, status: SequenceStatus, enable_chunking: bool, - budget: SchedulingBudget) -> Tuple[int, bool]: + budget: SchedulingBudget) -> int: """Get the next new tokens to compute for a given sequence group that's in a given `status`. diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index 1004d626b6a4b..a3e93691a1e8e 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -1,5 +1,5 @@ from collections import namedtuple -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch from torch.distributed import ProcessGroup @@ -144,7 +144,7 @@ def broadcast_tensor_dict( tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None, src: int = 0, group: Optional[ProcessGroup] = None, -) -> Dict[Any, Union[torch.Tensor, Any]]: +) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]: """Broadcast the input tensor dictionary.""" group = group or torch.distributed.group.WORLD ranks = torch.distributed.get_process_group_ranks(group) @@ -157,10 +157,10 @@ def broadcast_tensor_dict( rank = torch.distributed.get_rank() if rank == src: + metadata_list: List[Tuple[Any, Any]] = [] assert isinstance( tensor_dict, dict), (f"Expecting a dictionary, got {type(tensor_dict)}") - metadata_list = [] for key, value in tensor_dict.items(): if isinstance(value, torch.Tensor): assert value.is_cuda, ( @@ -190,10 +190,10 @@ def broadcast_tensor_dict( torch.distributed.broadcast_object_list(recv_metadata_list, src=src, group=group) - metadata_list = recv_metadata_list[0] + assert recv_metadata_list[0] is not None tensor_dict = {} async_handles = [] - for key, value in metadata_list: + for key, value in recv_metadata_list[0]: if isinstance(value, TensorMetadata): tensor = torch.empty(value.size, dtype=value.dtype, diff --git a/vllm/engine/ray_utils.py b/vllm/engine/ray_utils.py index 70d5c9b1fae05..04d4ed83976d0 100644 --- a/vllm/engine/ray_utils.py +++ b/vllm/engine/ray_utils.py @@ -1,9 +1,10 @@ import pickle -from typing import List, Optional, Tuple +from typing import Callable, List, Optional, Tuple from vllm.config import ParallelConfig from vllm.logger import init_logger from vllm.utils import get_ip, is_hip, set_cuda_visible_devices +from vllm.worker.worker import Worker logger = init_logger(__name__) @@ -18,15 +19,20 @@ try: if init_cached_hf_modules: from transformers.dynamic_module_utils import init_hf_modules init_hf_modules() - self.worker = None + self._worker: Optional[Worker] = None # Since the compiled DAG runs a main execution # in a different thread that calls cuda.set_device. # The flag indicates is set_device is called on # that thread. self.compiled_dag_cuda_device_set = False - def init_worker(self, worker_init_fn): - self.worker = worker_init_fn() + def init_worker(self, worker_init_fn: Callable[[], Worker]): + self._worker = worker_init_fn() + + @property + def worker(self) -> Worker: + assert self._worker is not None + return self._worker def __getattr__(self, name): return getattr(self.worker, name) @@ -70,8 +76,8 @@ except ImportError as e: logger.warning(f"Failed to import Ray with {e!r}. " "For distributed inference, please install Ray with " "`pip install ray`.") - ray = None - RayWorkerVllm = None + ray = None # type: ignore + RayWorkerVllm = None # type: ignore def initialize_ray_cluster( diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index 2a47eae112c12..587142adb9c6b 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -47,6 +47,7 @@ async def generate(request: Request) -> Response: sampling_params = SamplingParams(**request_dict) request_id = random_uuid() + assert engine is not None results_generator = engine.generate(prompt, sampling_params, request_id) # Streaming case diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 5777e8179a1c1..63ff0b30da552 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -170,8 +170,12 @@ class LLM: multi_modal_data.data = multi_modal_data.data.to(torch.float16) # Add requests to the engine. - num_requests = len(prompts) if prompts is not None else len( - prompt_token_ids) + if prompts is not None: + num_requests = len(prompts) + else: + assert prompt_token_ids is not None + num_requests = len(prompt_token_ids) + for i in range(num_requests): prompt = prompts[i] if prompts is not None else None token_ids = None if prompt_token_ids is None else prompt_token_ids[ diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index eda4e8989c163..33e67d8b3eec2 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -1,5 +1,5 @@ import os -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple import torch @@ -61,7 +61,7 @@ class CPUExecutor(ExecutorBase): self.driver_worker.init_device() self.driver_worker.load_model() - def determine_num_available_blocks(self) -> tuple[int, int]: + def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of available KV blocks by invoking the underlying worker. """ diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 80ca5cb7367c5..f20221a0b941a 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, @@ -66,7 +66,7 @@ class GPUExecutor(ExecutorBase): self.driver_worker.init_device() self.driver_worker.load_model() - def determine_num_available_blocks(self) -> tuple[int, int]: + def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of available KV blocks by invoking the underlying worker. """ diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py index 57436a85cfa27..ee8e87432fa67 100644 --- a/vllm/executor/neuron_executor.py +++ b/vllm/executor/neuron_executor.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, @@ -47,7 +47,7 @@ class NeuronExecutor(ExecutorBase): self.driver_worker.init_device() self.driver_worker.load_model() - def determine_num_available_blocks(self) -> tuple[int, int]: + def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of available KV blocks by invoking the underlying worker. """ diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 6c0ccd7e64c90..b937693c92257 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -3,7 +3,7 @@ import copy import os import pickle from collections import defaultdict -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, @@ -197,7 +197,7 @@ class RayGPUExecutor(ExecutorBase): max_parallel_loading_workers, ) - def determine_num_available_blocks(self) -> tuple[int, int]: + def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of available KV blocks. This invokes `determine_num_available_blocks` on each worker and takes @@ -205,7 +205,7 @@ class RayGPUExecutor(ExecutorBase): compatible with all workers. Returns: - - tuple[num_gpu_blocks, num_cpu_blocks] + - Tuple[num_gpu_blocks, num_cpu_blocks] """ # Get the maximum number of blocks that can be allocated on GPU and CPU. num_blocks = self._run_workers("determine_num_available_blocks", ) @@ -276,7 +276,7 @@ class RayGPUExecutor(ExecutorBase): self, method: str, *args, - driver_args: Optional[List[Any]] = None, + driver_args: Optional[Tuple[Any, ...]] = None, driver_kwargs: Optional[Dict[str, Any]] = None, max_concurrent_workers: Optional[int] = None, use_ray_compiled_dag: bool = False, @@ -291,6 +291,7 @@ class RayGPUExecutor(ExecutorBase): if use_ray_compiled_dag: # Right now, compiled DAG can only accept a single # input. TODO(sang): Fix it. + assert self.forward_dag is not None output_channels = self.forward_dag.execute(1) else: # Start the ray workers first. @@ -369,7 +370,7 @@ class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase): self, method: str, *args, - driver_args: Optional[List[Any]] = None, + driver_args: Optional[Tuple[Any, ...]] = None, driver_kwargs: Optional[Dict[str, Any]] = None, **kwargs, ) -> Any: diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 0b9787608798c..53a38b25bfdac 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -5,7 +5,8 @@ from functools import cached_property from typing import Callable, List, Optional, Union import torch -from pydantic import conint +from pydantic import Field +from typing_extensions import Annotated _SAMPLING_EPS = 1e-5 @@ -127,7 +128,7 @@ class SamplingParams: skip_special_tokens: bool = True, spaces_between_special_tokens: bool = True, logits_processors: Optional[List[LogitsProcessor]] = None, - truncate_prompt_tokens: Optional[conint(ge=1)] = None, + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, ) -> None: self.n = n self.best_of = best_of if best_of is not None else n diff --git a/vllm/sequence.py b/vllm/sequence.py index cdb6cce6f0255..dcde81df19923 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -171,10 +171,10 @@ class SequenceData: return self.prompt_token_ids[-1] return self.output_token_ids[-1] - def get_prompt_token_ids(self) -> int: + def get_prompt_token_ids(self) -> List[int]: return self.prompt_token_ids - def get_output_token_ids(self) -> int: + def get_output_token_ids(self) -> List[int]: return self.output_token_ids @property @@ -370,7 +370,7 @@ class SequenceGroupState: """Mutable state tied to a specific sequence group""" # torch.Generator used in seeded sampling - generator: Optional = None + generator: Optional = None # type: ignore class MultiModalData: @@ -599,7 +599,7 @@ class SequenceGroupMetadata: return self.lora_request.lora_int_id if self.lora_request else 0 @property - def token_chunk_size(self) -> int: + def token_chunk_size(self) -> Optional[int]: """Return the number of tokens to be processed (chunk size).""" return self._token_chunk_size diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index ce7a30dce72fa..1756c91a612f0 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -2,7 +2,8 @@ from typing import Dict, Optional from transformers import AutoConfig, PretrainedConfig -from vllm.transformers_utils.configs import * +from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig, + JAISConfig, MPTConfig, RWConfig) _CONFIG_REGISTRY: Dict[str, PretrainedConfig] = { "chatglm": ChatGLMConfig, diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py index 005932f1e3df4..f064c26c3f40c 100644 --- a/vllm/transformers_utils/detokenizer.py +++ b/vllm/transformers_utils/detokenizer.py @@ -168,8 +168,8 @@ def _convert_tokens_to_string_with_added_encoders( # NOTE(woosuk): The following code is slow because it runs a for loop over # the output_tokens. In Python, running a for loop over a list can be slow # even when the loop body is very simple. - sub_texts = [] - current_sub_text = [] + sub_texts: List[str] = [] + current_sub_text: List[str] = [] all_special_tokens = set(tokenizer.all_special_tokens) for token in output_tokens: if skip_special_tokens and token in all_special_tokens: @@ -263,6 +263,7 @@ def detokenize_incrementally( tokenizer, all_input_ids[:-1], skip_special_tokens=skip_special_tokens) + assert prev_tokens is not None # If the new token id is out of bounds, return an empty string. if new_token_id >= len(tokenizer): @@ -271,6 +272,8 @@ def detokenize_incrementally( # Put new_token_id in a list so skip_special_tokens is respected new_tokens = tokenizer.convert_ids_to_tokens( [new_token_id], skip_special_tokens=skip_special_tokens) + if isinstance(new_tokens, str): + new_tokens = [new_tokens] output_tokens = prev_tokens + new_tokens # If this is the first iteration, return all tokens. diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index e216a99af91f9..5d3d5801c960d 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -5,7 +5,7 @@ from transformers import (AutoTokenizer, PreTrainedTokenizer, from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.transformers_utils.tokenizers import * +from vllm.transformers_utils.tokenizers import BaichuanTokenizer from vllm.utils import make_async logger = init_logger(__name__) @@ -28,7 +28,7 @@ def get_cached_tokenizer( tokenizer_all_special_tokens = set(tokenizer.all_special_tokens) tokenizer_len = len(tokenizer) - class CachedTokenizer(tokenizer.__class__): + class CachedTokenizer(tokenizer.__class__): # type: ignore @property def all_special_ids(self): diff --git a/vllm/usage/usage_lib.py b/vllm/usage/usage_lib.py index 658fe5c98f5ee..b2672f7f1da61 100644 --- a/vllm/usage/usage_lib.py +++ b/vllm/usage/usage_lib.py @@ -7,7 +7,7 @@ import time from enum import Enum from pathlib import Path from threading import Thread -from typing import Dict, Optional +from typing import Any, Dict, Optional from uuid import uuid4 import cpuinfo @@ -124,7 +124,7 @@ class UsageMessage: def report_usage(self, model_architecture: str, usage_context: UsageContext, - extra_kvs: Dict[str, any] = None) -> None: + extra_kvs: Optional[Dict[str, Any]] = None) -> None: t = Thread(target=self._report_usage_worker, args=(model_architecture, usage_context, extra_kvs or {}), daemon=True) @@ -132,13 +132,13 @@ class UsageMessage: def _report_usage_worker(self, model_architecture: str, usage_context: UsageContext, - extra_kvs: Dict[str, any]) -> None: + extra_kvs: Dict[str, Any]) -> None: self._report_usage_once(model_architecture, usage_context, extra_kvs) self._report_continous_usage() def _report_usage_once(self, model_architecture: str, usage_context: UsageContext, - extra_kvs: Dict[str, any]) -> None: + extra_kvs: Dict[str, Any]) -> None: # Platform information if torch.cuda.is_available(): device_property = torch.cuda.get_device_properties(0) diff --git a/vllm/utils.py b/vllm/utils.py index 0967dfc969c8a..4c0dc9ca729a9 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -60,7 +60,7 @@ class LRUCache(Generic[T]): def __len__(self) -> int: return len(self.cache) - def __getitem__(self, key: Hashable) -> T: + def __getitem__(self, key: Hashable) -> Optional[T]: return self.get(key) def __setitem__(self, key: Hashable, value: T) -> None: @@ -76,7 +76,7 @@ class LRUCache(Generic[T]): key: Hashable, default_value: Optional[T] = None) -> Optional[T]: if key in self.cache: - value = self.cache[key] + value: Optional[T] = self.cache[key] self.cache.move_to_end(key) else: value = default_value @@ -87,7 +87,7 @@ class LRUCache(Generic[T]): self.cache.move_to_end(key) self._remove_old_if_needed() - def _on_remove(self, key: Hashable, value: T): + def _on_remove(self, key: Hashable, value: Optional[T]): pass def remove_oldest(self): @@ -100,9 +100,11 @@ class LRUCache(Generic[T]): while len(self.cache) > self.capacity: self.remove_oldest() - def pop(self, key: Hashable, default_value: Optional[Any] = None) -> T: + def pop(self, + key: Hashable, + default_value: Optional[T] = None) -> Optional[T]: run_on_remove = key in self.cache - value = self.cache.pop(key, default_value) + value: Optional[T] = self.cache.pop(key, default_value) if run_on_remove: self._on_remove(key, value) return value From fbb9d9eef48a29e0ea821bbf399e4bf9a08d6ac1 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 12 Apr 2024 16:40:39 -0700 Subject: [PATCH 29/50] [Core] fix custom allreduce default value (#4040) --- vllm/entrypoints/llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 63ff0b30da552..9e08c253dc539 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -86,7 +86,7 @@ class LLM: swap_space: int = 4, enforce_eager: bool = False, max_context_len_to_capture: int = 8192, - disable_custom_all_reduce: bool = True, + disable_custom_all_reduce: bool = False, **kwargs, ) -> None: if "disable_log_stats" not in kwargs: From d04973ad5446fe05c06035f6b2d99402fc3ac7bf Mon Sep 17 00:00:00 2001 From: Bellk17 Date: Fri, 12 Apr 2024 16:41:26 -0700 Subject: [PATCH 30/50] Fix triton compilation issue (#3984) Co-authored-by: Woosuk Kwon --- vllm/attention/ops/triton_flash_attention.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py index 87cf30cbef79a..e160411859f0b 100644 --- a/vllm/attention/ops/triton_flash_attention.py +++ b/vllm/attention/ops/triton_flash_attention.py @@ -415,7 +415,11 @@ def attn_fwd( return is_mqa = hq != hk - off_h_k = off_h_q % hk if is_mqa else off_h_q + if is_mqa: # noqa: SIM108 + off_h_k = off_h_q % hk + else: + off_h_k = off_h_q + n_extra_tokens = 0 if seqlen_k < BLOCK_N: n_extra_tokens = BLOCK_N - seqlen_k From b8aacac31a4e2e03381fdaef6f1e4bbb895f3b64 Mon Sep 17 00:00:00 2001 From: Jee Li Date: Sat, 13 Apr 2024 07:56:37 +0800 Subject: [PATCH 31/50] [Bugfix] Fix LoRA bug (#4032) --- vllm/lora/layers.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 4b9653de73a88..aac86351b15e1 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -32,14 +32,17 @@ if TYPE_CHECKING: def _get_lora_device(base_layer: nn.Module) -> torch.device: # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34 """Returns the device for where to place the LoRA tensors.""" + # unquantizedLinear if hasattr(base_layer, "weight"): return base_layer.weight.device - if hasattr(base_layer, "linear_weights") and isinstance( - base_layer.linear_weights, dict): - values = list(base_layer.linear_weights.values()) - if len(values) and isinstance(values[0], torch.Tensor): - return values[0].device - raise ValueError(f"Unsupported base layer: {base_layer}") + # GPTQ/AWQ/SqueezeLLM + elif hasattr(base_layer, "qweight"): + return base_layer.qweight.device + # marlin + elif hasattr(base_layer, "B"): + return base_layer.B.device + else: + raise ValueError(f"Unsupported base layer: {base_layer}") def _apply_lora( From 546e7211684a28bbe53088961b4cf5123e235760 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 12 Apr 2024 18:43:37 -0700 Subject: [PATCH 32/50] [CI/Test] expand ruff and yapf for all supported python version (#4037) --- .github/workflows/mypy.yaml | 2 +- .github/workflows/ruff.yml | 2 +- .github/workflows/yapf.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml index fbe0f816fd4af..6db0bb7645ecd 100644 --- a/.github/workflows/mypy.yaml +++ b/.github/workflows/mypy.yaml @@ -15,7 +15,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.8"] + python-version: ["3.8", "3.9", "3.10", "3.11"] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml index e8060e369a889..e71033f828006 100644 --- a/.github/workflows/ruff.yml +++ b/.github/workflows/ruff.yml @@ -15,7 +15,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11"] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} diff --git a/.github/workflows/yapf.yml b/.github/workflows/yapf.yml index b163c960db555..04f307bcf8b0e 100644 --- a/.github/workflows/yapf.yml +++ b/.github/workflows/yapf.yml @@ -14,7 +14,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11"] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} From 5c2e66e4871917c5d59cc4a8b89ef53e690e9bd9 Mon Sep 17 00:00:00 2001 From: Dylan Hawk <51147702+dylanwhawk@users.noreply.github.com> Date: Fri, 12 Apr 2024 21:07:04 -0700 Subject: [PATCH 33/50] [Bugfix] More type hint fixes for py 3.8 (#4039) --- vllm/executor/executor_base.py | 2 +- vllm/worker/cpu_worker.py | 4 ++-- vllm/worker/neuron_worker.py | 4 ++-- vllm/worker/worker_base.py | 6 +++--- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 55bccfa8e3ca9..bbfbfc689c99f 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -39,7 +39,7 @@ class ExecutorBase(ABC): ExecutorBase may require modification of the result, e.g. to ensure the selected cache sizes are compatible with all workers. - Returns a tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks + Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks are blocks that are "active" on the device and can be appended to. num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be appended to. diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 3989207e8dd83..41341b063bed7 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -1,5 +1,5 @@ """A CPU worker class.""" -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple import torch import torch.distributed @@ -157,7 +157,7 @@ class CPUWorker(LoraNotSupportedWorkerBase): def load_model(self): self.model_runner.load_model() - def determine_num_available_blocks(self) -> tuple[int, int]: + def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of blocks available for the KV cache. This determines how many KV blocks can fit into the configured CPU diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py index 6136d50d0c068..2f22f82c045db 100644 --- a/vllm/worker/neuron_worker.py +++ b/vllm/worker/neuron_worker.py @@ -1,5 +1,5 @@ """A Neuron worker class.""" -from typing import List, Optional +from typing import List, Optional, Tuple import torch import torch.distributed @@ -40,7 +40,7 @@ class NeuronWorker(LoraNotSupportedWorkerBase): def load_model(self): self.model_runner.load_model() - def determine_num_available_blocks(self) -> tuple[int, int]: + def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of available KV blocks. Swapping is not yet supported, so always return num_cpu_blocks=0. diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index e3027c406ffeb..d8c9febb11584 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Dict, List +from typing import Dict, List, Tuple from vllm.lora.request import LoRARequest from vllm.sequence import SamplerOutput, SequenceGroupMetadata @@ -18,14 +18,14 @@ class WorkerBase(ABC): raise NotImplementedError @abstractmethod - def determine_num_available_blocks(self) -> tuple[int, int]: + def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of available blocks for the GPU KV cache and swappable CPU KV cache. The implementation may run profiling or other heuristics to determine the size of caches. - Returns a tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks + Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks are blocks that are "active" on the device and can be appended to. num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be appended to. From 98afde19fc273b1e6a695990b93ec07157b856f1 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 13 Apr 2024 07:12:53 -0700 Subject: [PATCH 34/50] [Core][Distributed] improve logging for init dist (#4042) --- vllm/distributed/parallel_state.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 9fceffe7cb88b..1258bf58cb453 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -8,6 +8,10 @@ from typing import Optional import torch +from vllm.logger import init_logger + +logger = init_logger(__name__) + # Tensor model parallel group that the current rank belongs to. _TENSOR_MODEL_PARALLEL_GROUP = None # Pipeline model parallel group that the current rank belongs to. @@ -45,6 +49,8 @@ def init_distributed_environment( local_rank: int = -1, backend: str = "nccl", ): + logger.debug(f"{world_size=} {rank=} {local_rank=} " + f"{distributed_init_method=} {backend=}") if not torch.distributed.is_initialized(): assert distributed_init_method is not None, ( "distributed_init_method must be provided when initializing " From ec8e3c695f2dce080bde569746180300e91084a3 Mon Sep 17 00:00:00 2001 From: zspo Date: Sat, 13 Apr 2024 22:52:36 +0800 Subject: [PATCH 35/50] [Bugfix] fix_log_time_in_metrics (#4050) --- vllm/engine/metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index 02560907a1282..04e27e69ce0f3 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -130,7 +130,7 @@ class StatLogger: def __init__(self, local_interval: float, labels: Dict[str, str]) -> None: # Metadata for logging locally. - self.last_local_log = time.monotonic() + self.last_local_log = time.time() self.local_interval = local_interval # Tracked stats over current local logging interval. From 0a430b4ae2763c2f161e3bfb1529acf4685f7caa Mon Sep 17 00:00:00 2001 From: zspo Date: Sat, 13 Apr 2024 22:54:03 +0800 Subject: [PATCH 36/50] [Bugfix] fix_small_bug_in_neuron_executor (#4051) --- vllm/executor/neuron_executor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py index ee8e87432fa67..d45f18e466256 100644 --- a/vllm/executor/neuron_executor.py +++ b/vllm/executor/neuron_executor.py @@ -25,6 +25,7 @@ class NeuronExecutor(ExecutorBase): speculative_config: Optional[SpeculativeConfig], ) -> None: self.model_config = model_config + self.cache_config = cache_config assert lora_config is None, "LoRA is not supported for Neuron backend." self.parallel_config = parallel_config self.scheduler_config = scheduler_config @@ -43,6 +44,7 @@ class NeuronExecutor(ExecutorBase): self.parallel_config, self.scheduler_config, self.device_config, + self.cache_config, ) self.driver_worker.init_device() self.driver_worker.load_model() From 989ae2538df211ca3a31f77ac8e106c5c97c6e53 Mon Sep 17 00:00:00 2001 From: Jee Li Date: Sat, 13 Apr 2024 22:55:05 +0800 Subject: [PATCH 37/50] [Kernel] Add punica dimension for Baichuan-13B (#4053) --- csrc/punica/bgmv/bgmv_config.h | 1 + tests/lora/test_baichuan.py | 2 +- tests/lora/test_punica.py | 1 + 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h index 9b76b98ab3322..d2906914f927e 100644 --- a/csrc/punica/bgmv/bgmv_config.h +++ b/csrc/punica/bgmv/bgmv_config.h @@ -47,6 +47,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 13696) \ f(in_T, out_T, W_T, narrow, 13824) \ f(in_T, out_T, W_T, narrow, 14336) \ + f(in_T, out_T, W_T, narrow, 15360) \ f(in_T, out_T, W_T, narrow, 16384) \ f(in_T, out_T, W_T, narrow, 20480) \ f(in_T, out_T, W_T, narrow, 22016) \ diff --git a/tests/lora/test_baichuan.py b/tests/lora/test_baichuan.py index 2178266d2e0c8..5ab863eea94b3 100644 --- a/tests/lora/test_baichuan.py +++ b/tests/lora/test_baichuan.py @@ -62,7 +62,7 @@ def test_baichuan_lora(baichuan_lora_files): @pytest.mark.skip("Requires multiple GPUs") -def test_llama_tensor_parallel_equality(baichuan_lora_files): +def test_baichuan_tensor_parallel_equality(baichuan_lora_files): # Cannot use as it will initialize torch.cuda too early... # if torch.cuda.device_count() < 4: # pytest.skip(f"Not enough GPUs for tensor parallelism {4}") diff --git a/tests/lora/test_punica.py b/tests/lora/test_punica.py index cab8b44ccd2df..8b174f01d87d4 100644 --- a/tests/lora/test_punica.py +++ b/tests/lora/test_punica.py @@ -72,6 +72,7 @@ H1 = H2 = [ 11008, 13824, 14336, + 15360, 22016, 24576, 27392, From 711a000255eac3e034f0b73aa5cc62b45201a571 Mon Sep 17 00:00:00 2001 From: Sanger Steel Date: Sat, 13 Apr 2024 20:13:01 -0400 Subject: [PATCH 38/50] [Frontend] [Core] feat: Add model loading using `tensorizer` (#3476) --- .buildkite/test-pipeline.yaml | 3 + docs/source/conf.py | 1 + docs/source/models/engine_args.rst | 3 +- examples/tensorize_vllm_model.py | 254 ++++++++++++++ requirements-cpu.txt | 2 +- requirements-dev.txt | 1 + setup.py | 3 + tests/tensorizer/__init__.py | 0 .../tensorize_vllm_model_for_testing.py | 245 ++++++++++++++ tests/tensorizer/test_tensorizer.py | 302 +++++++++++++++++ vllm/config.py | 74 +++- vllm/engine/arg_utils.py | 45 ++- vllm/engine/llm_engine.py | 8 +- vllm/executor/gpu_executor.py | 23 +- vllm/executor/ray_gpu_executor.py | 6 +- vllm/model_executor/model_loader.py | 61 +++- vllm/model_executor/tensorizer_loader.py | 319 ++++++++++++++++++ vllm/model_executor/weight_utils.py | 34 +- vllm/worker/model_runner.py | 9 +- vllm/worker/worker.py | 9 +- 20 files changed, 1351 insertions(+), 51 deletions(-) create mode 100644 examples/tensorize_vllm_model.py create mode 100644 tests/tensorizer/__init__.py create mode 100644 tests/tensorizer/tensorize_vllm_model_for_testing.py create mode 100644 tests/tensorizer/test_tensorizer.py create mode 100644 vllm/model_executor/tensorizer_loader.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 8d7d6304cf12e..aa4582bbda0c7 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -91,6 +91,9 @@ steps: command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT parallelism: 4 +- label: Tensorizer Test + command: apt-get install curl libsodium23 && pytest -v -s tensorizer + - label: Metrics Test command: pytest -v -s metrics diff --git a/docs/source/conf.py b/docs/source/conf.py index 7a8c365ffb3bb..19cc8557a7541 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -83,6 +83,7 @@ autodoc_mock_imports = [ "vllm._C", "numpy", "tqdm", + "tensorizer", ] for mock_target in autodoc_mock_imports: diff --git a/docs/source/models/engine_args.rst b/docs/source/models/engine_args.rst index d8a7ac72e0175..886a806934c04 100644 --- a/docs/source/models/engine_args.rst +++ b/docs/source/models/engine_args.rst @@ -36,7 +36,7 @@ Below, you can find an explanation of every engine argument for vLLM: Directory to download and load the weights, default to the default cache dir of huggingface. -.. option:: --load-format {auto,pt,safetensors,npcache,dummy} +.. option:: --load-format {auto,pt,safetensors,npcache,dummy,tensorizer} The format of the model weights to load. @@ -45,6 +45,7 @@ Below, you can find an explanation of every engine argument for vLLM: * "safetensors" will load the weights in the safetensors format. * "npcache" will load the weights in pytorch format and store a numpy cache to speed up the loading. * "dummy" will initialize the weights with random values, mainly for profiling. + * "tensorizer" will load serialized weights using `CoreWeave's Tensorizer model deserializer. `_. See `tensorized_vllm_model.py` in the examples folder to serialize a vLLM model, and for more information. Tensorizer support for vLLM can be installed with `pip install vllm[tensorizer]`. .. option:: --dtype {auto,half,float16,bfloat16,float,float32} diff --git a/examples/tensorize_vllm_model.py b/examples/tensorize_vllm_model.py new file mode 100644 index 0000000000000..3c20a38c7f726 --- /dev/null +++ b/examples/tensorize_vllm_model.py @@ -0,0 +1,254 @@ +import argparse +import dataclasses +import os +import time +import uuid +from functools import partial +from typing import Type + +import torch +import torch.nn as nn +from tensorizer import (DecryptionParams, EncryptionParams, TensorDeserializer, + TensorSerializer, stream_io) +from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor +from transformers import AutoConfig, PretrainedConfig + +from vllm.distributed import initialize_model_parallel +from vllm.engine.arg_utils import EngineArgs +from vllm.engine.llm_engine import LLMEngine +from vllm.model_executor.models import ModelRegistry +from vllm.model_executor.tensorizer_loader import TensorizerArgs + +# yapf conflicts with isort for this docstring +# yapf: disable +""" +tensorize_vllm_model.py is a script that can be used to serialize and +deserialize vLLM models. These models can be loaded using tensorizer directly +to the GPU extremely quickly. Tensor encryption and decryption is also +supported, although libsodium must be installed to use it. Install +vllm with tensorizer support using `pip install vllm[tensorizer]`. + +To serialize a model, you can run something like this: + +python tensorize_vllm_model.py \ + --model EleutherAI/gpt-j-6B \ + --dtype float16 \ + serialize \ + --serialized-directory s3://my-bucket/ \ + --suffix vllm + +Which downloads the model from HuggingFace, loads it into vLLM, serializes it, +and saves it to your S3 bucket. A local directory can also be used. + +You can also encrypt the model weights with a randomly-generated key by +providing a `--keyfile` argument. + +To deserialize a model, you can run something like this: + +python tensorize_vllm_model.py \ + --model EleutherAI/gpt-j-6B \ + --dtype float16 \ + deserialize \ + --path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors + +Which downloads the model tensors from your S3 bucket and deserializes them. +To provide S3 credentials, you can provide `--s3-access-key-id` and +`--s3-secret-access-key`, as well as `--s3-endpoint` as CLI args to this script, +the OpenAI entrypoint, as arguments for LLM(), or as environment variables +in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and `S3_ENDPOINT`. + + +You can also provide a `--keyfile` argument to decrypt the model weights if +they were serialized with encryption. + +For more information on the available arguments, run +`python tensorize_vllm_model.py --help`. +""" + + +def parse_args(): + parser = argparse.ArgumentParser( + description="An example script that can be used to serialize and " + "deserialize vLLM models. These models " + "can be loaded using tensorizer directly to the GPU " + "extremely quickly. Tensor encryption and decryption is " + "also supported, although libsodium must be installed to " + "use it.") + parser = EngineArgs.add_cli_args(parser) + subparsers = parser.add_subparsers(dest='command') + + serialize_parser = subparsers.add_parser( + 'serialize', help="Serialize a model to `--serialized-directory`") + + serialize_parser.add_argument( + "--suffix", + type=str, + required=False, + help=( + "The suffix to append to the serialized model directory, which is " + "used to construct the location of the serialized model tensors, " + "e.g. if `--serialized-directory` is `s3://my-bucket/` and " + "`--suffix` is `v1`, the serialized model tensors will be " + "saved to " + "`s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors`. " + "If none is provided, a random UUID will be used.")) + serialize_parser.add_argument( + "--serialized-directory", + type=str, + required=True, + help="The directory to serialize the model to. " + "This can be a local directory or S3 URI. The path to where the " + "tensors are saved is a combination of the supplied `dir` and model " + "reference ID. For instance, if `dir` is the serialized directory, " + "and the model HuggingFace ID is `EleutherAI/gpt-j-6B`, tensors will " + "be saved to `dir/vllm/EleutherAI/gpt-j-6B/suffix/model.tensors`, " + "where `suffix` is given by `--suffix` or a random UUID if not " + "provided.") + + serialize_parser.add_argument( + "--keyfile", + type=str, + required=False, + help=("Encrypt the model weights with a randomly-generated binary key," + " and save the key at this path")) + + deserialize_parser = subparsers.add_parser( + 'deserialize', + help=("Deserialize a model from `--path-to-tensors`" + " to verify it can be loaded and used.")) + + deserialize_parser.add_argument( + "--path-to-tensors", + type=str, + required=True, + help="The local path or S3 URI to the model tensors to deserialize. ") + + deserialize_parser.add_argument( + "--keyfile", + type=str, + required=False, + help=("Path to a binary key to use to decrypt the model weights," + " if the model was serialized with encryption")) + + return parser.parse_args() + + +def make_model_contiguous(model): + # Ensure tensors are saved in memory contiguously + for param in model.parameters(): + param.data = param.data.contiguous() + + +def _get_vllm_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: + architectures = getattr(config, "architectures", []) + for arch in architectures: + model_cls = ModelRegistry.load_model_cls(arch) + if model_cls is not None: + return model_cls + raise ValueError( + f"Model architectures {architectures} are not supported for now. " + f"Supported architectures: {ModelRegistry.get_supported_archs()}") + + +def serialize(): + + eng_args_dict = {f.name: getattr(args, f.name) for f in + dataclasses.fields(EngineArgs)} + engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict)) + engine = LLMEngine.from_engine_args(engine_args) + + model = (engine.model_executor.driver_worker. + model_runner.model) + + encryption_params = EncryptionParams.random() if keyfile else None + if keyfile: + with _write_stream(keyfile) as stream: + stream.write(encryption_params.key) + + with _write_stream(model_path) as stream: + serializer = TensorSerializer(stream, encryption=encryption_params) + serializer.write_module(model) + serializer.close() + + print("Serialization complete. Model tensors saved to", model_path) + if keyfile: + print("Key saved to", keyfile) + + +def deserialize(): + config = AutoConfig.from_pretrained(model_ref) + + with no_init_or_tensor(): + model_class = _get_vllm_model_architecture(config) + model = model_class(config) + + before_mem = get_mem_usage() + start = time.time() + + if keyfile: + with _read_stream(keyfile) as stream: + key = stream.read() + decryption_params = DecryptionParams.from_key(key) + tensorizer_args.deserializer_params['encryption'] = \ + decryption_params + + with (_read_stream(model_path)) as stream, TensorDeserializer( + stream, **tensorizer_args.deserializer_params) as deserializer: + deserializer.load_into_module(model) + end = time.time() + + # Brag about how fast we are. + total_bytes_str = convert_bytes(deserializer.total_tensor_bytes) + duration = end - start + per_second = convert_bytes(deserializer.total_tensor_bytes / duration) + after_mem = get_mem_usage() + print( + f"Deserialized {total_bytes_str} in {end - start:0.2f}s, {per_second}/s" + ) + print(f"Memory usage before: {before_mem}") + print(f"Memory usage after: {after_mem}") + + return model + + +args = parse_args() + +s3_access_key_id = (args.s3_access_key_id or os.environ.get("S3_ACCESS_KEY_ID") + or None) +s3_secret_access_key = (args.s3_secret_access_key + or os.environ.get("S3_SECRET_ACCESS_KEY") or None) + +s3_endpoint = (args.s3_endpoint or os.environ.get("S3_ENDPOINT_URL") or None) + +_read_stream, _write_stream = (partial( + stream_io.open_stream, + mode=mode, + s3_access_key_id=s3_access_key_id, + s3_secret_access_key=s3_secret_access_key, + s3_endpoint=s3_endpoint, +) for mode in ("rb", "wb+")) + +model_ref = args.model + +model_name = model_ref.split("/")[1] + +os.environ["MASTER_ADDR"] = "127.0.0.1" +os.environ["MASTER_PORT"] = "8080" + +torch.distributed.init_process_group(world_size=1, rank=0) +initialize_model_parallel() + +keyfile = args.keyfile if args.keyfile else None + +if args.command == "serialize": + input_dir = args.serialized_directory.rstrip('/') + suffix = args.suffix if args.suffix else uuid.uuid4().hex + base_path = f"{input_dir}/vllm/{model_ref}/{suffix}" + model_path = f"{base_path}/model.tensors" + serialize() +elif args.command == "deserialize": + tensorizer_args = TensorizerArgs.from_cli_args(args) + model_path = args.path_to_tensors + deserialize() +else: + raise ValueError("Either serialize or deserialize must be specified.") diff --git a/requirements-cpu.txt b/requirements-cpu.txt index 36d20bc9473ea..5779b38b24e69 100644 --- a/requirements-cpu.txt +++ b/requirements-cpu.txt @@ -3,4 +3,4 @@ # Dependencies for x86_64 CPUs torch == 2.2.1+cpu -triton >= 2.1.0 # FIXME(woosuk): This is a hack to avoid import error. +triton >= 2.1.0 # FIXME(woosuk): This is a hack to avoid import error. \ No newline at end of file diff --git a/requirements-dev.txt b/requirements-dev.txt index 96dfda6faf00f..1317e51b2dd11 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -14,6 +14,7 @@ types-setuptools # testing pytest +tensorizer==2.9.0a0 pytest-forked pytest-asyncio pytest-rerunfailures diff --git a/setup.py b/setup.py index 9f0814e9f3bff..813321efe796d 100644 --- a/setup.py +++ b/setup.py @@ -405,6 +405,9 @@ setup( python_requires=">=3.8", install_requires=get_requirements(), ext_modules=ext_modules, + extras_require={ + "optional": ["tensorizer==2.9.0a1"], + }, cmdclass={"build_ext": cmake_build_ext} if not _is_neuron() else {}, package_data=package_data, ) diff --git a/tests/tensorizer/__init__.py b/tests/tensorizer/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/tensorizer/tensorize_vllm_model_for_testing.py b/tests/tensorizer/tensorize_vllm_model_for_testing.py new file mode 100644 index 0000000000000..d0be08329fd64 --- /dev/null +++ b/tests/tensorizer/tensorize_vllm_model_for_testing.py @@ -0,0 +1,245 @@ +import argparse +import dataclasses +import os +import time +import uuid +from functools import partial +from typing import Type + +import torch +import torch.nn as nn +from tensorizer import (DecryptionParams, EncryptionParams, TensorDeserializer, + TensorSerializer, stream_io) +from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor +from transformers import AutoConfig, PretrainedConfig + +from vllm.distributed import initialize_model_parallel +from vllm.engine.arg_utils import EngineArgs +from vllm.engine.llm_engine import LLMEngine +from vllm.model_executor.models import ModelRegistry +from vllm.model_executor.tensorizer_loader import TensorizerArgs + +# yapf conflicts with isort for this docstring +# yapf: disable +""" +tensorize_vllm_model.py is a script that can be used to serialize and +deserialize vLLM models. These models can be loaded using tensorizer directly +to the GPU extremely quickly. Tensor encryption and decryption is also +supported, although libsodium must be installed to use it. Install +vllm with tensorizer support using `pip install vllm[tensorizer]`. + +To serialize a model, you can run something like this: + +python tensorize_vllm_model.py \ + --model EleutherAI/gpt-j-6B \ + --dtype float16 \ + serialize \ + --serialized-directory s3://my-bucket/ \ + --suffix vllm + +Which downloads the model from HuggingFace, loads it into vLLM, serializes it, +and saves it to your S3 bucket. A local directory can also be used. + +You can also encrypt the model weights with a randomly-generated key by +providing a `--keyfile` argument. + +To deserialize a model, you can run something like this: + +python tensorize_vllm_model.py \ + --model EleutherAI/gpt-j-6B \ + --dtype float16 \ + deserialize \ + --path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors + +Which downloads the model tensors from your S3 bucket and deserializes them. +To provide S3 credentials, you can provide `--s3-access-key-id` and +`--s3-secret-access-key`, as well as `--s3-endpoint` as CLI args to this script, +the OpenAI entrypoint, as arguments for LLM(), or as environment variables +in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and `S3_ENDPOINT`. + + +You can also provide a `--keyfile` argument to decrypt the model weights if +they were serialized with encryption. + +For more information on the available arguments, run +`python tensorize_vllm_model.py --help`. +""" + + +def parse_args(): + parser = argparse.ArgumentParser( + description="An example script that can be used to serialize and " + "deserialize vLLM models. These models " + "can be loaded using tensorizer directly to the GPU " + "extremely quickly. Tensor encryption and decryption is " + "also supported, although libsodium must be installed to " + "use it.") + parser = EngineArgs.add_cli_args(parser) + subparsers = parser.add_subparsers(dest='command') + + serialize_parser = subparsers.add_parser( + 'serialize', help="Serialize a model to `--serialized-directory`") + + serialize_parser.add_argument( + "--suffix", + type=str, + required=False, + help=( + "The suffix to append to the serialized model directory, which is " + "used to construct the location of the serialized model tensors, " + "e.g. if `--serialized-directory` is `s3://my-bucket/` and " + "`--suffix` is `v1`, the serialized model tensors will be " + "saved to " + "`s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors`. " + "If none is provided, a random UUID will be used.")) + serialize_parser.add_argument( + "--serialized-directory", + type=str, + required=True) + + serialize_parser.add_argument( + "--keyfile", + type=str, + required=False, + help=("Encrypt the model weights with a randomly-generated binary key," + " and save the key at this path")) + + deserialize_parser = subparsers.add_parser( + 'deserialize', + help=("Deserialize a model from `--path-to-tensors`" + " to verify it can be loaded and used.")) + + deserialize_parser.add_argument( + "--path-to-tensors", + type=str, + required=True, + help="The local path or S3 URI to the model tensors to deserialize. ") + + deserialize_parser.add_argument( + "--keyfile", + type=str, + required=False, + help=("Path to a binary key to use to decrypt the model weights," + " if the model was serialized with encryption")) + + return parser.parse_args() + + +def make_model_contiguous(model): + # Ensure tensors are saved in memory contiguously + for param in model.parameters(): + param.data = param.data.contiguous() + + +def _get_vllm_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: + architectures = getattr(config, "architectures", []) + for arch in architectures: + model_cls = ModelRegistry.load_model_cls(arch) + if model_cls is not None: + return model_cls + raise ValueError( + f"Model architectures {architectures} are not supported for now. " + f"Supported architectures: {ModelRegistry.get_supported_archs()}") + + +def serialize(): + eng_args_dict = {f.name: getattr(args, f.name) for f in + dataclasses.fields(EngineArgs)} + engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict)) + engine = LLMEngine.from_engine_args(engine_args) + + model = (engine.model_executor.driver_worker. + model_runner.model) + + encryption_params = EncryptionParams.random() if keyfile else None + if keyfile: + with _write_stream(keyfile) as stream: + stream.write(encryption_params.key) + + with _write_stream(model_path) as stream: + serializer = TensorSerializer(stream, encryption=encryption_params) + serializer.write_module(model) + serializer.close() + + print("Serialization complete. Model tensors saved to", model_path) + if keyfile: + print("Key saved to", keyfile) + + +def deserialize(): + config = AutoConfig.from_pretrained(model_ref) + + with no_init_or_tensor(): + model_class = _get_vllm_model_architecture(config) + model = model_class(config) + + before_mem = get_mem_usage() + start = time.time() + + if keyfile: + with _read_stream(keyfile) as stream: + key = stream.read() + decryption_params = DecryptionParams.from_key(key) + tensorizer_args.deserializer_params['encryption'] = \ + decryption_params + + with (_read_stream(model_path)) as stream, TensorDeserializer( + stream, **tensorizer_args.deserializer_params) as deserializer: + deserializer.load_into_module(model) + end = time.time() + + # Brag about how fast we are. + total_bytes_str = convert_bytes(deserializer.total_tensor_bytes) + duration = end - start + per_second = convert_bytes(deserializer.total_tensor_bytes / duration) + after_mem = get_mem_usage() + print( + f"Deserialized {total_bytes_str} in {end - start:0.2f}s, {per_second}/s" + ) + print(f"Memory usage before: {before_mem}") + print(f"Memory usage after: {after_mem}") + + return model + + +args = parse_args() + +s3_access_key_id = (args.s3_access_key_id or os.environ.get("S3_ACCESS_KEY_ID") + or None) +s3_secret_access_key = (args.s3_secret_access_key + or os.environ.get("S3_SECRET_ACCESS_KEY") or None) + +s3_endpoint = (args.s3_endpoint or os.environ.get("S3_ENDPOINT_URL") or None) + +_read_stream, _write_stream = (partial( + stream_io.open_stream, + mode=mode, + s3_access_key_id=s3_access_key_id, + s3_secret_access_key=s3_secret_access_key, + s3_endpoint=s3_endpoint, +) for mode in ("rb", "wb+")) + +model_ref = args.model + +model_name = model_ref.split("/")[1] + +os.environ["MASTER_ADDR"] = "127.0.0.1" +os.environ["MASTER_PORT"] = "8080" + +torch.distributed.init_process_group(world_size=1, rank=0) +initialize_model_parallel() + +keyfile = args.keyfile if args.keyfile else None + +if args.command == "serialize": + input_dir = args.serialized_directory.rstrip('/') + suffix = args.suffix if args.suffix else uuid.uuid4().hex + base_path = f"{input_dir}/vllm/{model_ref}/{suffix}" + model_path = f"{base_path}/model.tensors" + serialize() +elif args.command == "deserialize": + tensorizer_args = TensorizerArgs.from_cli_args(args) + model_path = args.path_to_tensors + deserialize() +else: + raise ValueError("Either serialize or deserialize must be specified.") diff --git a/tests/tensorizer/test_tensorizer.py b/tests/tensorizer/test_tensorizer.py new file mode 100644 index 0000000000000..2ab893e95da9c --- /dev/null +++ b/tests/tensorizer/test_tensorizer.py @@ -0,0 +1,302 @@ +import gc +import subprocess +from unittest.mock import MagicMock, patch + +import pytest +import torch + +from tests.entrypoints.test_openai_server import ServerRunner +from vllm import SamplingParams +from vllm.config import TensorizerConfig +from vllm.model_executor.tensorizer_loader import ( + EncryptionParams, TensorSerializer, is_vllm_serialized_tensorizer, + load_with_tensorizer, open_stream) + +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +# Create a sampling params object. +sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0) + +model_ref = "facebook/opt-125m" + + +def is_curl_installed(): + try: + subprocess.check_call(['curl', '--version']) + return True + except (subprocess.CalledProcessError, FileNotFoundError): + return False + + +@pytest.fixture(autouse=True) +def tensorizer_config(): + config = TensorizerConfig(tensorizer_uri="vllm", vllm_tensorized=True) + return config + + +@patch('vllm.model_executor.tensorizer_loader.TensorizerAgent') +def test_load_with_tensorizer(mock_agent, tensorizer_config): + mock_linear_method = MagicMock() + mock_agent_instance = mock_agent.return_value + mock_agent_instance.deserialize.return_value = MagicMock() + + result = load_with_tensorizer(tensorizer_config, + linear_method=mock_linear_method) + + mock_agent.assert_called_once_with(tensorizer_config, + linear_method=mock_linear_method) + mock_agent_instance.deserialize.assert_called_once() + assert result == mock_agent_instance.deserialize.return_value + + +def test_is_vllm_model_with_vllm_in_uri(tensorizer_config): + tensorizer_config.vllm_tensorized = True + + result = is_vllm_serialized_tensorizer(tensorizer_config) + + assert result is True + + +def test_is_vllm_model_without_vllm_in_uri(tensorizer_config): + tensorizer_config.vllm_tensorized = False + + result = is_vllm_serialized_tensorizer(tensorizer_config) + + assert result is False + + +def test_deserialized_vllm_model_has_same_outputs(vllm_runner, tmp_path): + vllm_model = vllm_runner(model_ref) + model_path = tmp_path / (model_ref + ".tensors") + outputs = vllm_model.generate(prompts, sampling_params) + model = (vllm_model.model.llm_engine.model_executor.driver_worker. + model_runner.model) + with open_stream(model_path, "wb+") as stream: + serializer = TensorSerializer(stream) + serializer.write_module(model) + del vllm_model, model + gc.collect() + torch.cuda.empty_cache() + loaded_vllm_model = vllm_runner(model_ref, + load_format="tensorizer", + tensorizer_uri=model_path, + num_readers=1, + vllm_tensorized=True) + deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) + + # Assumes SamplingParams being seeded ensures the outputs are deterministic + assert outputs == deserialized_outputs + + +@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") +def test_can_deserialize_s3(vllm_runner): + model_ref = "EleutherAI/pythia-1.4b" + tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors" + + loaded_hf_model = vllm_runner( + model_ref, + tensorizer_uri=tensorized_path, + load_format="tensorizer", + num_readers=1, + vllm_tensorized=False, + s3_endpoint="object.ord1.coreweave.com", + ) + + deserialized_outputs = loaded_hf_model.generate(prompts, sampling_params) + + assert deserialized_outputs + + +@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") +def test_deserialized_encrypted_vllm_model_has_same_outputs( + vllm_runner, tmp_path): + vllm_model = vllm_runner(model_ref) + model_path = tmp_path / (model_ref + ".tensors") + key_path = tmp_path / (model_ref + ".key") + outputs = vllm_model.generate(prompts, sampling_params) + model = (vllm_model.model.llm_engine.model_executor.driver_worker. + model_runner.model) + + encryption_params = EncryptionParams.random() + with open_stream(model_path, "wb+") as stream: + serializer = TensorSerializer(stream, encryption=encryption_params) + serializer.write_module(model) + with open_stream(key_path, "wb+") as stream: + stream.write(encryption_params.key) + del vllm_model, model + gc.collect() + torch.cuda.empty_cache() + loaded_vllm_model = vllm_runner(model_ref, + tensorizer_uri=model_path, + load_format="tensorizer", + encryption_keyfile=key_path, + num_readers=1, + vllm_tensorized=True) + + deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) + + # Assumes SamplingParams being seeded ensures the outputs are deterministic + assert outputs == deserialized_outputs + + +def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner, + tmp_path): + hf_model = hf_runner(model_ref) + model_path = tmp_path / (model_ref + ".tensors") + max_tokens = 50 + outputs = hf_model.generate_greedy(prompts, max_tokens=max_tokens) + with open_stream(model_path, "wb+") as stream: + serializer = TensorSerializer(stream) + serializer.write_module(hf_model.model) + del hf_model + gc.collect() + torch.cuda.empty_cache() + loaded_hf_model = vllm_runner(model_ref, + tensorizer_uri=model_path, + load_format="tensorizer", + num_readers=1, + vllm_tensorized=False) + + deserialized_outputs = loaded_hf_model.generate_greedy( + prompts, max_tokens=max_tokens) + + assert outputs == deserialized_outputs + + +def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path): + from huggingface_hub import snapshot_download + + from examples.multilora_inference import (create_test_prompts, + process_requests) + + model_ref = "meta-llama/Llama-2-7b-hf" + lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test") + test_prompts = create_test_prompts(lora_path) + + # Serialize model before deserializing and binding LoRA adapters + vllm_model = vllm_runner(model_ref, ) + model_path = tmp_path / (model_ref + ".tensors") + model = (vllm_model.model.llm_engine.model_executor.driver_worker. + model_runner.model) + with open_stream(model_path, "wb+") as stream: + serializer = TensorSerializer(stream) + serializer.write_module(model) + del vllm_model, model + gc.collect() + torch.cuda.empty_cache() + loaded_vllm_model = vllm_runner( + model_ref, + tensorizer_uri=model_path, + load_format="tensorizer", + num_readers=1, + vllm_tensorized=True, + enable_lora=True, + max_loras=1, + max_lora_rank=8, + max_cpu_loras=2, + max_num_seqs=50, + max_model_len=1000, + ) + process_requests(loaded_vllm_model.model.llm_engine, test_prompts) + + assert loaded_vllm_model + + +def test_load_without_tensorizer_load_format(vllm_runner): + with pytest.raises(ValueError): + vllm_runner(model_ref, tensorizer_uri="test") + + +@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") +def test_tensorize_vllm_model(tmp_path): + # Test serialize command + serialize_args = [ + "python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model", + model_ref, "--dtype", "float16", "serialize", "--serialized-directory", + tmp_path, "--suffix", "tests" + ] + result = subprocess.run(serialize_args, capture_output=True, text=True) + print(result.stdout) # Print the output of the serialize command + + assert result.returncode == 0, (f"Serialize command failed with output:" + f"\n{result.stdout}\n{result.stderr}") + + path_to_tensors = f"{tmp_path}/vllm/{model_ref}/tests/model.tensors" + + # Test deserialize command + deserialize_args = [ + "python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model", + model_ref, "--dtype", "float16", "deserialize", "--path-to-tensors", + path_to_tensors + ] + result = subprocess.run(deserialize_args, capture_output=True, text=True) + assert result.returncode == 0, (f"Deserialize command failed with output:" + f"\n{result.stdout}\n{result.stderr}") + + +@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") +def test_openai_apiserver_with_tensorizer(tmp_path): + ## Serialize model + serialize_args = [ + "python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model", + model_ref, "--dtype", "float16", "serialize", "--serialized-directory", + tmp_path, "--suffix", "tests" + ] + result = subprocess.run(serialize_args, capture_output=True, text=True) + print(result.stdout) # Print the output of the serialize command + + assert result.returncode == 0, (f"Serialize command failed with output:" + f"\n{result.stdout}\n{result.stderr}") + + path_to_tensors = f"{tmp_path}/vllm/{model_ref}/tests/model.tensors" + + ## Start OpenAI API server + openai_args = [ + "--model", model_ref, "--dtype", "float16", "--load-format", + "tensorizer", "--tensorizer-uri", path_to_tensors, "--vllm-tensorized", + "--port", "8000" + ] + + server = ServerRunner.remote(openai_args) + + print("Server ready.") + assert server.ready.remote() + + +def test_raise_value_error_on_invalid_load_format(vllm_runner): + with pytest.raises(ValueError): + vllm_runner(model_ref, + load_format="safetensors", + tensorizer_uri="test") + + +def test_tensorizer_with_tp(vllm_runner): + with pytest.raises(ValueError): + model_ref = "EleutherAI/pythia-1.4b" + tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors" + + vllm_runner( + model_ref, + tensorizer_uri=tensorized_path, + load_format="tensorizer", + num_readers=1, + vllm_tensorized=False, + s3_endpoint="object.ord1.coreweave.com", + tensor_parallel_size=2, + ) + + +@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") +def test_tensorizer_warn_quant(tmp_path): + model_ref = "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit" + serialize_args = [ + "python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model", + model_ref, "--quantization", "gptq", "--tensorizer-uri", "test", + "serialize", "--serialized-directory", tmp_path, "--suffix", "tests" + ] + result = subprocess.run(serialize_args, capture_output=True, text=True) + assert 'PerformanceWarning' in result.stderr diff --git a/vllm/config.py b/vllm/config.py index bbda4ecf3cc56..dce2944b2ee8a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,6 +1,8 @@ import enum +import io import json import os +import typing from dataclasses import dataclass, fields from typing import TYPE_CHECKING, ClassVar, List, Optional, Union @@ -16,6 +18,8 @@ from vllm.utils import (get_cpu_memory, get_nvcc_cuda_version, is_cpu, is_hip, if TYPE_CHECKING: from ray.util.placement_group import PlacementGroup + from vllm.model_executor.tensorizer_loader import TensorizerArgs + logger = init_logger(__name__) _GB = 1 << 30 @@ -139,13 +143,14 @@ class ModelConfig: def _verify_load_format(self) -> None: load_format = self.load_format.lower() supported_load_format = [ - "auto", "pt", "safetensors", "npcache", "dummy" + "auto", "pt", "safetensors", "npcache", "dummy", "tensorizer" ] rocm_not_supported_load_format: List[str] = [] if load_format not in supported_load_format: raise ValueError( f"Unknown load format: {self.load_format}. Must be one of " - "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.") + "'auto', 'pt', 'safetensors', 'npcache', 'tensorizer', or " + "'dummy'.") if is_hip() and load_format in rocm_not_supported_load_format: rocm_supported_load_format = [ f for f in supported_load_format @@ -882,6 +887,65 @@ class VisionLanguageConfig: f"{[x.name for x in cls.ImageInputType]}.") from e +@dataclass +class TensorizerConfig: + tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO, + str, bytes, os.PathLike, int] + vllm_tensorized: bool + verify_hash: Optional[bool] = False + num_readers: Optional[int] = 1 + encryption_keyfile: Optional[str] = None + s3_access_key_id: Optional[str] = None + s3_secret_access_key: Optional[str] = None + s3_endpoint: Optional[str] = None + model_class: Optional[torch.nn.Module] = None + hf_config: Optional[PretrainedConfig] = None + dtype: Union[str, torch.dtype] = None + + def _construct_tensorizer_args(self) -> "TensorizerArgs": + from vllm.model_executor.tensorizer_loader import TensorizerArgs + tensorizer_args = { + "tensorizer_uri": self.tensorizer_uri, + "vllm_tensorized": self.vllm_tensorized, + "verify_hash": self.verify_hash, + "num_readers": self.num_readers, + "encryption_keyfile": self.encryption_keyfile, + "s3_access_key_id": self.s3_access_key_id, + "s3_secret_access_key": self.s3_secret_access_key, + "s3_endpoint": self.s3_endpoint, + } + return TensorizerArgs(**tensorizer_args) + + def verify_with_parallel_config( + self, + parallel_config: "ParallelConfig", + ) -> None: + if (parallel_config.tensor_parallel_size > 1 + and self.tensorizer_uri is not None): + raise ValueError( + "Loading to multiple GPUs is not currently supported with " + "vLLM-serialized models. Please set tensor_parallel_size=1." + " or use a non-vLLM-serialized model, such as a " + "serialized Hugging Face `PretrainedModel`.") + + def verify_with_model_config(self, model_config) -> None: + if (model_config.quantization is not None + and self.tensorizer_uri is not None): + from vllm.model_executor.tensorizer_loader import ( + tensorizer_warning) + tensorizer_warning( + "Loading a model using Tensorizer with quantization on vLLM" + " is unstable and may lead to errors.") + + if (model_config.load_format != "tensorizer" + and self.tensorizer_uri is not None): + raise ValueError( + "A tensorizer uri was passed for tensorizer loading, but the " + f"load format was set to {model_config.load_format}. " + "Please set the load format to 'tensorizer' to use " + f"tensorizer args.") + + _STR_DTYPE_TO_TORCH_DTYPE = { "half": torch.float16, "float16": torch.float16, @@ -1029,6 +1093,7 @@ class EngineConfig: lora_config: Optional[LoRAConfig] vision_language_config: Optional[VisionLanguageConfig] speculative_config: Optional[SpeculativeConfig] + tensorizer_config: Optional[TensorizerConfig] def __post_init__(self): """Verify configs are valid & consistent with each other. @@ -1036,6 +1101,11 @@ class EngineConfig: self.model_config.verify_with_parallel_config(self.parallel_config) self.cache_config.verify_with_parallel_config(self.parallel_config) + if self.tensorizer_config: + self.tensorizer_config.verify_with_parallel_config( + self.parallel_config) + self.tensorizer_config.verify_with_model_config(self.model_config) + if self.lora_config: self.lora_config.verify_with_model_config(self.model_config) self.lora_config.verify_with_scheduler_config( diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index daefddc01b431..831a03be65f61 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1,12 +1,15 @@ import argparse import dataclasses +import io +import os from dataclasses import dataclass -from typing import Optional +from typing import BinaryIO, Optional, Union from vllm.config import (CacheConfig, DeviceConfig, EngineConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, - SpeculativeConfig, TokenizerPoolConfig, - VisionLanguageConfig) + SpeculativeConfig, TensorizerConfig, + TokenizerPoolConfig, VisionLanguageConfig) +from vllm.model_executor.tensorizer_loader import TensorizerArgs from vllm.utils import str_to_int_tuple @@ -58,12 +61,22 @@ class EngineArgs: num_gpu_blocks_override: Optional[int] = None num_lookahead_slots: int = 0 + # Tensorizer configuration parameters + tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, BinaryIO, str, + bytes, os.PathLike, int] = None + vllm_tensorized: bool = False + verify_hash: Optional[bool] = False + num_readers: Optional[int] = 1 + encryption_keyfile: Optional[str] = None + s3_access_key_id: Optional[str] = None + s3_secret_access_key: Optional[str] = None + s3_endpoint: Optional[str] = None + # Related to Vision-language models such as llava image_input_type: Optional[str] = None image_token_id: Optional[int] = None image_input_shape: Optional[str] = None image_feature_size: Optional[int] = None - scheduler_delay_factor: float = 0.0 enable_chunked_prefill: bool = False @@ -135,7 +148,9 @@ class EngineArgs: '--load-format', type=str, default=EngineArgs.load_format, - choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'], + choices=[ + 'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer' + ], help='The format of the model weights to load. ' '"auto" will try to load the weights in the safetensors format ' 'and fall back to the pytorch bin format if safetensors format ' @@ -145,7 +160,10 @@ class EngineArgs: '"npcache" will load the weights in pytorch format and store ' 'a numpy cache to speed up the loading. ' '"dummy" will initialize the weights with random values, ' - 'which is mainly for profiling.') + 'which is mainly for profiling.' + '"tensorizer" will load the weights using tensorizer from CoreWeave' + 'which assumes tensorizer_uri is set to the location of the ' + 'serialized weights.') parser.add_argument( '--dtype', type=str, @@ -403,6 +421,7 @@ class EngineArgs: default=None, help='The number of speculative tokens to sample from ' 'the draft model in speculative decoding') + parser = TensorizerArgs.add_cli_args(parser) return parser @classmethod @@ -465,6 +484,17 @@ class EngineArgs: max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras and self.max_cpu_loras > 0 else None) if self.enable_lora else None + tensorizer_config = TensorizerConfig( + tensorizer_uri=self.tensorizer_uri, + vllm_tensorized=self.vllm_tensorized, + verify_hash=self.verify_hash, + num_readers=self.num_readers, + encryption_keyfile=self.encryption_keyfile, + s3_access_key_id=self.s3_access_key_id, + s3_secret_access_key=self.s3_secret_access_key, + s3_endpoint=self.s3_endpoint, + ) + if self.image_input_type: if (not self.image_token_id or not self.image_input_shape or not self.image_feature_size): @@ -488,7 +518,8 @@ class EngineArgs: device_config=device_config, lora_config=lora_config, vision_language_config=vision_language_config, - speculative_config=speculative_config) + speculative_config=speculative_config, + tensorizer_config=tensorizer_config) @dataclass diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index a91629a630591..8c37c5a9d6ee9 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -6,7 +6,7 @@ from transformers import PreTrainedTokenizer import vllm from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, - VisionLanguageConfig) + TensorizerConfig, VisionLanguageConfig) from vllm.core.scheduler import Scheduler, SchedulerOutputs from vllm.engine.arg_utils import EngineArgs from vllm.engine.metrics import StatLogger, Stats @@ -74,6 +74,7 @@ class LLMEngine: lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], speculative_config: Optional[SpeculativeConfig], + tensorizer_config: Optional[TensorizerConfig], executor_class: Type[ExecutorBase], log_stats: bool, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, @@ -110,6 +111,7 @@ class LLMEngine: self.scheduler_config = scheduler_config self.device_config = device_config self.speculative_config = speculative_config + self.tensorizer_config = tensorizer_config self.log_stats = log_stats self._init_tokenizer() @@ -125,6 +127,7 @@ class LLMEngine: lora_config=lora_config, vision_language_config=vision_language_config, speculative_config=speculative_config, + tensorizer_config=tensorizer_config, ) self._initialize_kv_caches() @@ -264,6 +267,9 @@ class LLMEngine: def _verify_args(self) -> None: self.model_config.verify_with_parallel_config(self.parallel_config) self.cache_config.verify_with_parallel_config(self.parallel_config) + if self.tensorizer_config: + self.tensorizer_config.verify_with_parallel_config( + self.parallel_config) if self.lora_config: self.lora_config.verify_with_model_config(self.model_config) self.lora_config.verify_with_scheduler_config( diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index f20221a0b941a..30577ecf62faa 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -2,7 +2,7 @@ from typing import Dict, List, Optional, Tuple from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, - VisionLanguageConfig) + TensorizerConfig, VisionLanguageConfig) from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -15,17 +15,14 @@ logger = init_logger(__name__) class GPUExecutor(ExecutorBase): - def __init__( - self, - model_config: ModelConfig, - cache_config: CacheConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - vision_language_config: Optional[VisionLanguageConfig], - speculative_config: Optional[SpeculativeConfig], - ) -> None: + def __init__(self, model_config: ModelConfig, cache_config: CacheConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + speculative_config: Optional[SpeculativeConfig], + tensorizer_config: Optional[TensorizerConfig]) -> None: self.model_config = model_config self.cache_config = cache_config self.lora_config = lora_config @@ -33,6 +30,7 @@ class GPUExecutor(ExecutorBase): self.scheduler_config = scheduler_config self.device_config = device_config self.vision_language_config = vision_language_config + self.tensorizer_config = tensorizer_config assert (not speculative_config ), "Speculative decoding not yet supported for GPU backend" @@ -61,6 +59,7 @@ class GPUExecutor(ExecutorBase): distributed_init_method=distributed_init_method, lora_config=self.lora_config, vision_language_config=self.vision_language_config, + tensorizer_config=self.tensorizer_config, is_driver_worker=True, ) self.driver_worker.init_device() diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index b937693c92257..28dc3e0db312a 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, - VisionLanguageConfig) + TensorizerConfig, VisionLanguageConfig) from vllm.engine.ray_utils import RayWorkerVllm, ray from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger @@ -42,6 +42,7 @@ class RayGPUExecutor(ExecutorBase): lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], speculative_config: Optional[SpeculativeConfig], + tensorizer_config: Optional[TensorizerConfig], ) -> None: self.model_config = model_config self.cache_config = cache_config @@ -50,6 +51,7 @@ class RayGPUExecutor(ExecutorBase): self.scheduler_config = scheduler_config self.device_config = device_config self.vision_language_config = vision_language_config + self.tensorizer_config = tensorizer_config assert (not speculative_config ), "Speculative decoding not yet supported for RayGPU backend." @@ -171,6 +173,7 @@ class RayGPUExecutor(ExecutorBase): distributed_init_method=distributed_init_method, lora_config=lora_config, vision_language_config=vision_language_config, + tensorizer_config=self.tensorizer_config, )) # Initialize the driver worker with the Worker class. @@ -187,6 +190,7 @@ class RayGPUExecutor(ExecutorBase): distributed_init_method=distributed_init_method, lora_config=self.lora_config, vision_language_config=self.vision_language_config, + tensorizer_config=self.tensorizer_config, is_driver_worker=True, ) diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index 2745dbd89ab0f..c70ca48bca70a 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -3,11 +3,14 @@ import contextlib from typing import Tuple, Type import torch -import torch.nn as nn +from torch import nn from vllm.config import DeviceConfig, ModelConfig from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models.llava import LlavaForConditionalGeneration +from vllm.model_executor.tensorizer_loader import ( + ParameterizedLoadFormat, is_vllm_serialized_tensorizer, + load_with_tensorizer) from vllm.model_executor.weight_utils import (get_quant_config, initialize_dummy_weights) @@ -51,6 +54,7 @@ def get_model(model_config: ModelConfig, device_config: DeviceConfig, **kwargs) -> nn.Module: lora_config = kwargs.get("lora_config", None) vision_language_config = kwargs.get("vision_language_config", None) + tensorizer_config = kwargs.get("tensorizer_config", None) model_class = _get_model_architecture(model_config)[0] # Get the (maybe quantized) linear method. @@ -71,33 +75,54 @@ def get_model(model_config: ModelConfig, device_config: DeviceConfig, f"{model_config.dtype} is not supported for quantization " f"method {model_config.quantization}. Supported dtypes: " f"{supported_dtypes}") + linear_method = quant_config.get_linear_method() with _set_default_torch_dtype(model_config.dtype): # Create a model instance. # The weights will be initialized as empty tensors. + extra_kwargs = {} + if hasattr(model_class, "supported_lora_modules"): + extra_kwargs["lora_config"] = lora_config + elif lora_config: + raise ValueError( + f"Model {model_class.__name__} does not support LoRA, " + "but LoRA is enabled. Support for this model may " + "be added in the future. If this is important to you, " + "please open an issue on github.") + elif model_class in _VISION_MODEL_CLASSES: + extra_kwargs["vision_language_config"] = vision_language_config + with torch.device(device_config.device): - if hasattr(model_class, "supported_lora_modules"): - model = model_class(model_config.hf_config, linear_method, - lora_config) - elif lora_config: - raise ValueError( - f"Model {model_class.__name__} does not support LoRA, " - "but LoRA is enabled. Support for this model may " - "be added in the future. If this is important to you, " - "please open an issue on github.") - else: - if model_class not in _VISION_MODEL_CLASSES: - model = model_class(model_config.hf_config, linear_method) - else: - model = model_class(model_config.hf_config, - vision_language_config, linear_method) + if (model_config.load_format == "tensorizer" + and is_vllm_serialized_tensorizer(tensorizer_config)): + extra_kwargs["linear_method"] = linear_method + tensorizer_config.model_class = model_class + tensorizer_config.hf_config = model_config.hf_config + tensorizer_config.dtype = model_config.dtype + model = load_with_tensorizer(tensorizer_config, **extra_kwargs) + return model.eval() + model = model_class(config=model_config.hf_config, + linear_method=linear_method, + **extra_kwargs) if model_config.load_format == "dummy": # NOTE(woosuk): For accurate performance evaluation, we assign # random values to the weights. initialize_dummy_weights(model) else: # Load the weights from the cached or downloaded files. - model.load_weights(model_config.model, model_config.download_dir, - model_config.load_format, model_config.revision) + if model_config.load_format == "tensorizer": + # Provide a dynamic load format for `model.load_weights` + # to retain tensorizer args from CLI. + model_config.load_format = ParameterizedLoadFormat( + model_config.load_format) + model_config.load_format.params = ( + tensorizer_config._construct_tensorizer_args()) + + model.load_weights( + model_config.model, + model_config.download_dir, + model_config.load_format, + model_config.revision, + ) return model.eval() diff --git a/vllm/model_executor/tensorizer_loader.py b/vllm/model_executor/tensorizer_loader.py new file mode 100644 index 0000000000000..ed3ad9e2ffa15 --- /dev/null +++ b/vllm/model_executor/tensorizer_loader.py @@ -0,0 +1,319 @@ +import argparse +import dataclasses +import io +import os +import time +import typing +import warnings +from dataclasses import dataclass +from typing import Optional, Union + +import torch +from torch import nn + +from vllm.config import TensorizerConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import LinearMethodBase +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) + +tensorizer_load_fail = False + +try: + from tensorizer import (DecryptionParams, EncryptionParams, + TensorDeserializer, TensorSerializer) + from tensorizer.stream_io import open_stream + from tensorizer.utils import (convert_bytes, get_mem_usage, + no_init_or_tensor) +except ImportError: + tensorizer_load_fail = True + +__all__ = [ + 'EncryptionParams', 'DecryptionParams', 'TensorDeserializer', + 'TensorSerializer', 'open_stream', 'convert_bytes', 'get_mem_usage', + 'no_init_or_tensor' +] + +logger = init_logger(__name__) + + +def load_with_tensorizer(tensorizer_config: TensorizerConfig, + **extra_kwargs) -> nn.Module: + tensorizer = TensorizerAgent(tensorizer_config, **extra_kwargs) + return tensorizer.deserialize() + + +def tensorizer_warning(message: str): + return warnings.warn(message, category=PerformanceWarning, stacklevel=2) + + +def is_vllm_serialized_tensorizer(tensorizer_config: TensorizerConfig) -> bool: + if tensorizer_config is None: + return False + return tensorizer_config.vllm_tensorized + + +class ParameterizedLoadFormat(str): + __slots__ = "params" + + +class PerformanceWarning(UserWarning): + + def __str__(self): + return (f"{super().__str__()}" + " (set the VLLM_SILENCE_PERFORMANCE_WARNINGS" + " environment variable to hide this)") + + +if (os.getenv("VLLM_SILENCE_PERFORMANCE_WARNINGS", "").lower() + not in ("", "0", "n", "no", "off", "disable")): + warnings.simplefilter("ignore", category=PerformanceWarning) + + +@dataclass +class TensorizerArgs: + tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO, + str, bytes, os.PathLike, int] + vllm_tensorized: bool + verify_hash: Optional[bool] = False + num_readers: Optional[int] = 1 + encryption_keyfile: Optional[str] = None + s3_access_key_id: Optional[str] = None + s3_secret_access_key: Optional[str] = None + s3_endpoint: Optional[str] = None + """ + Args for the TensorizerAgent class. These are used to configure the behavior + of the TensorDeserializer when loading tensors from a serialized model. + + Args: + tensorizer_uri: Path to serialized model tensors. Can be a local file + path or a S3 URI. + vllm_tensorized: If True, indicates that the serialized model is a + vLLM model. This is used to determine the behavior of the + TensorDeserializer when loading tensors from a serialized model. + It is far faster to deserialize a vLLM model as it utilizes + tensorizer's optimized GPU loading. + verify_hash: If True, the hashes of each tensor will be verified against + the hashes stored in the metadata. A `HashMismatchError` will be + raised if any of the hashes do not match. + num_readers: Controls how many threads are allowed to read concurrently + from the source file. Default is 1. This greatly increases + performance. + encryption_keyfile: File path to a binary file containing a + binary key to use for decryption. `None` (the default) means + no decryption. See the example script in + examples/tensorize_vllm_model.py. + s3_access_key_id: The access key for the S3 bucket. Can also be set via + the S3_ACCESS_KEY_ID environment variable. + s3_secret_access_key: The secret access key for the S3 bucket. Can also + be set via the S3_SECRET_ACCESS_KEY environment variable. + s3_endpoint: The endpoint for the S3 bucket. Can also be set via the + S3_ENDPOINT_URL environment variable. + """ + + def __post_init__(self): + self.file_obj = self.tensorizer_uri + self.s3_access_key_id = (self.s3_access_key_id + or os.environ.get("S3_ACCESS_KEY_ID")) or None + self.s3_secret_access_key = ( + self.s3_secret_access_key + or os.environ.get("S3_SECRET_ACCESS_KEY")) or None + self.s3_endpoint = (self.s3_endpoint + or os.environ.get("S3_ENDPOINT_URL")) or None + self.stream_params = { + "s3_access_key_id": self.s3_access_key_id, + "s3_secret_access_key": self.s3_secret_access_key, + "s3_endpoint": self.s3_endpoint, + } + + # Omitting self.dtype and self.device as this behaves weirdly + self.deserializer_params = { + "verify_hash": self.verify_hash, + "encryption": self.encryption_keyfile, + "num_readers": self.num_readers + } + if self.encryption_keyfile: + with open_stream( + self.encryption_keyfile, + **self.stream_params, + ) as stream: + key = stream.read() + decryption_params = DecryptionParams.from_key(key) + self.deserializer_params['encryption'] = decryption_params + + def add_cli_args( + parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + """Tensorizer CLI arguments""" + + # Create the argument group + group = parser.add_argument_group( + 'tensorizer options', + description=('Options for configuring the behavior of the' + ' tensorizer deserializer when ' + '--load-format=tensorizer')) + + group.add_argument( + "--tensorizer-uri", + help="Path to serialized model tensors. Can be a local file path," + " or an HTTP(S) or S3 URI.", + ) + group.add_argument( + "--verify-hash", + action="store_true", + help="If enabled, the hashes of each tensor will be verified" + " against the hashes stored in the file metadata. An exception" + " will be raised if any of the hashes do not match.", + ) + group.add_argument( + "--encryption-keyfile", + default=None, + help="The file path to a binary file containing a binary key to " + "use for decryption. Can be a file path or S3 network URI.") + group.add_argument( + "--num-readers", + default=1, + type=int, + help="Controls how many threads are allowed to read concurrently " + "from the source file.") + group.add_argument( + "--s3-access-key-id", + default=None, + help="The access key for the S3 bucket. Can also be set via the " + "S3_ACCESS_KEY_ID environment variable.", + ) + group.add_argument( + "--s3-secret-access-key", + default=None, + help="The secret access key for the S3 bucket. Can also be set via " + "the S3_SECRET_ACCESS_KEY environment variable.", + ) + group.add_argument( + "--s3-endpoint", + default=None, + help="The endpoint for the S3 bucket. Can also be set via the " + "S3_ENDPOINT_URL environment variable.", + ) + group.add_argument( + "--vllm-tensorized", + action="store_true", + help="If enabled, indicates that the serialized model is a vLLM " + "model. This is used to determine the behavior of the " + "TensorDeserializer when loading tensors from a " + "serialized model.") + + return parser + + @classmethod + def from_cli_args(cls, args: argparse.Namespace) -> "TensorizerArgs": + # Get the list of attributes of this dataclass. + attrs = [attr.name for attr in dataclasses.fields(cls)] + # Set the attributes from the parsed arguments. + tensorizer_args = cls(**{ + attr: getattr(args, attr) + for attr in attrs if hasattr(args, attr) + }) + return tensorizer_args + + +class TensorizerAgent: + """ + A class for performing tensorizer deserializations specifically for + vLLM models using plaid_mode. Uses TensorizerArgs to configure the + behavior of the TensorDeserializer when loading tensors from a serialized + model. For deserializations of HuggingFace models, TensorDeserializer is + instead used as an iterator directly in the func hf_model_weights_iterator + in vllm/model_executor/weight_utils.py + """ + + def __init__(self, tensorizer_config: TensorizerConfig, + linear_method: LinearMethodBase, **extra_kwargs): + self.tensorizer_config = tensorizer_config + self.tensorizer_args = ( + self.tensorizer_config._construct_tensorizer_args()) + self.extra_kwargs = extra_kwargs + if extra_kwargs.get("linear_method", None) is not None: + self.linear_method = extra_kwargs["linear_method"] + else: + self.linear_method = linear_method + self.model = self._init_model() + + if tensorizer_load_fail: + raise ImportError( + "Tensorizer is not installed. Please install tensorizer " + "to use this feature with `pip install vllm[tensorizer]`.") + + def _init_model(self): + model_args = self.tensorizer_config.hf_config + model_args.torch_dtype = self.tensorizer_config.dtype + with no_init_or_tensor(): + return self.tensorizer_config.model_class( + config=model_args, + linear_method=self.linear_method, + **self.extra_kwargs) + + def _resize_lora_embeddings(self): + """Modify LoRA embedding layers to use bigger tensors + to allow for adapter added tokens.""" + for child in self.model.modules(): + if (isinstance(child, VocabParallelEmbedding) + and child.weight.shape[0] < + child.num_embeddings_per_partition): + new_weight = torch.empty(child.num_embeddings_per_partition, + child.embedding_dim, + dtype=child.weight.dtype, + device=child.weight.device) + new_weight[:child.weight.shape[0]].copy_(child.weight.data) + new_weight[child.weight.shape[0]:].fill_(0) + child.weight.data = new_weight + + def _check_tensors_on_meta_device(self): + for tensor in self.model.state_dict().values(): + if tensor.device.type == 'meta': + raise ValueError( + "The serialized model contains tensors on the meta device," + " indicating that some tensors were not loaded properly." + " Please check that the parameters of the model being" + " specified match that of the serialized model, such as" + " its quantization.") + + def deserialize(self): + """ + Deserialize the model using the TensorDeserializer. This method is + specifically for vLLM models using tensorizer's plaid_mode. + + The deserializer makes use of tensorizer_args.stream_params + to configure the behavior of the stream when loading tensors from a + serialized model. The deserializer_params are used to configure the + behavior of the TensorDeserializer when loading tensors themselves. + Documentation on these params can be found in TensorizerArgs + + Returns: + nn.Module: The deserialized model. + """ + before_mem = get_mem_usage() + # Lazy load the tensors from S3 into the model. + start = time.perf_counter() + with open_stream( + self.tensorizer_args.tensorizer_uri, + mode="rb", + **self.tensorizer_args.stream_params, + ) as stream, TensorDeserializer( + stream, + dtype=self.tensorizer_config.dtype, + **self.tensorizer_args.deserializer_params) as deserializer: + deserializer.load_into_module(self.model) + end = time.perf_counter() + + total_bytes_str = convert_bytes(deserializer.total_tensor_bytes) + duration = end - start + per_second = convert_bytes(deserializer.total_tensor_bytes / duration) + after_mem = get_mem_usage() + deserializer.close() + logger.info(f"Deserialized {total_bytes_str} in " + f"{end - start:0.2f}s, {per_second}/s") + logger.info(f"Memory usage before: {before_mem}") + logger.info(f"Memory usage after: {after_mem}") + + self._check_tensors_on_meta_device() + self._resize_lora_embeddings() + return self.model.eval() diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/weight_utils.py index 0961478930d74..08425604f0511 100644 --- a/vllm/model_executor/weight_utils.py +++ b/vllm/model_executor/weight_utils.py @@ -5,7 +5,7 @@ import hashlib import json import os from collections import defaultdict -from typing import Any, Iterable, Iterator, List, Optional, Tuple +from typing import Any, Iterable, Iterator, List, Optional, Tuple, Union import filelock import huggingface_hub.constants @@ -161,7 +161,8 @@ def prepare_hf_model_weights( revision: Optional[str] = None, ) -> Tuple[str, List[str], bool]: # Download model weights from huggingface. - is_local = os.path.isdir(model_name_or_path) + is_local = os.path.isdir(model_name_or_path) \ + and load_format != "tensorizer" use_safetensors = False # Some quantized models use .pt files for storing the weights. if load_format == "auto": @@ -173,13 +174,15 @@ def prepare_hf_model_weights( allow_patterns = ["*.pt"] elif load_format == "npcache": allow_patterns = ["*.bin"] + elif load_format == "tensorizer": + allow_patterns = ["*.tensors"] else: raise ValueError(f"Unknown load_format: {load_format}") if fall_back_to_pt: allow_patterns += ["*.pt"] - if not is_local: + if not is_local and load_format != "tensorizer": # Before we download we look at that is available: fs = HfFileSystem() file_list = fs.ls(model_name_or_path, detail=False, revision=revision) @@ -224,6 +227,9 @@ def prepare_hf_model_weights( if not any(f.endswith(x) for x in blacklist) ] + if load_format == "tensorizer": + return hf_folder, hf_weights_files, use_safetensors + if len(hf_weights_files) == 0: raise RuntimeError( f"Cannot find any model weights with `{model_name_or_path}`") @@ -234,7 +240,7 @@ def prepare_hf_model_weights( def hf_model_weights_iterator( model_name_or_path: str, cache_dir: Optional[str] = None, - load_format: str = "auto", + load_format: Union[Tuple, str] = "auto", revision: Optional[str] = None, fall_back_to_pt: Optional[bool] = True, ) -> Iterator[Tuple[str, torch.Tensor]]: @@ -277,6 +283,26 @@ def hf_model_weights_iterator( with open(param_path, "rb") as f: param = np.load(f) yield name, torch.from_numpy(param) + elif load_format == "tensorizer": + from vllm.model_executor.tensorizer_loader import (TensorDeserializer, + open_stream, + tensorizer_warning) + tensorizer_args = load_format.params + tensorizer_warning( + "Deserializing HuggingFace models is not optimized for " + "loading on vLLM, as tensorizer is forced to load to CPU. " + "Consider deserializing a vLLM model instead for faster " + "load times. See the examples/tensorize_vllm_model.py example " + "script for serializing vLLM models.") + + deserializer_args = tensorizer_args.deserializer_params + stream_params = tensorizer_args.stream_params + stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params) + with TensorDeserializer(stream, **deserializer_args, + device="cpu") as state: + for name, param in state.items(): + yield name, param + del state elif use_safetensors: for st_file in hf_weights_files: with safe_open(st_file, framework="pt") as f: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 47ad8f0c9b78b..7dbe14ead0976 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -10,7 +10,8 @@ import torch.nn as nn from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage, get_attn_backend) from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, - SchedulerConfig, VisionLanguageConfig) + SchedulerConfig, TensorizerConfig, + VisionLanguageConfig) from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce from vllm.distributed.device_communicators import (custom_all_reduce, pynccl_utils) @@ -111,11 +112,13 @@ class ModelRunner: kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, vision_language_config: Optional[VisionLanguageConfig] = None, + tensorizer_config: Optional[TensorizerConfig] = None, ): self.model_config = model_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.lora_config = lora_config + self.tensorizer_config = tensorizer_config self.is_driver_worker = is_driver_worker # model_config can be None in tests/samplers/test_sampler.py. @@ -158,7 +161,9 @@ class ModelRunner: lora_config=self.lora_config, vision_language_config=self.vision_language_config, parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config) + scheduler_config=self.scheduler_config, + tensorizer_config=self.tensorizer_config, + ) self.model_memory_usage = m.consumed_memory logger.info(f"Loading model weights took " diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 3f0b2fd83f3e5..82491c6df6616 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -7,7 +7,8 @@ import torch import torch.distributed from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, VisionLanguageConfig) + ParallelConfig, SchedulerConfig, TensorizerConfig, + VisionLanguageConfig) from vllm.distributed import (broadcast_tensor_dict, ensure_model_parallel_initialized, init_distributed_environment) @@ -42,6 +43,7 @@ class Worker(WorkerBase): distributed_init_method: str, lora_config: Optional[LoRAConfig] = None, vision_language_config: Optional[VisionLanguageConfig] = None, + tensorizer_config: Optional[TensorizerConfig] = None, is_driver_worker: bool = False, ) -> None: self.model_config = model_config @@ -53,6 +55,7 @@ class Worker(WorkerBase): self.rank = rank self.distributed_init_method = distributed_init_method self.lora_config = lora_config + self.tensorizer_config = tensorizer_config self.is_driver_worker = is_driver_worker if self.is_driver_worker: assert self.rank == 0, "The driver worker must have rank 0." @@ -70,7 +73,9 @@ class Worker(WorkerBase): lora_config=self.lora_config, kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=is_driver_worker, - vision_language_config=vision_language_config) + vision_language_config=vision_language_config, + tensorizer_config=tensorizer_config, + ) # Uninitialized cache engine. Will be initialized by # initialize_cache. self.cache_engine = None From 2cd6b4f3625466eb5849bcfd7a6fb316735adab8 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 13 Apr 2024 23:40:21 -0700 Subject: [PATCH 39/50] [Core] avoid too many cuda context by caching p2p test (#4021) --- .../device_communicators/custom_all_reduce.py | 53 +++++------ vllm/distributed/parallel_state.py | 9 ++ vllm/distributed/utils.py | 87 ++++++++++++++++++- 3 files changed, 116 insertions(+), 33 deletions(-) diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 84238d2e46076..f83caef879da3 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -42,12 +42,17 @@ def init_custom_ar() -> None: " disable_custom_all_reduce=True explicitly.", world_size, str(_SUPPORTED_WORLD_SIZES)) return - if not _can_p2p(rank, world_size): + num_dev = torch.cuda.device_count() + # note: num dev can be larger than world_size if we're only using + # first few GPUs + if num_dev < world_size: logger.warn( - "Custom allreduce is disabled because your platform lacks GPU P2P" - " capability or P2P test failed. To silence this warning, specify" - " disable_custom_all_reduce=True explicitly.") - return + "Cannot test GPU P2P because not all GPUs are visible to the " + "current process. This might be the case if 'CUDA_VISIBLE_DEVICES'" + " is set.") + return False + # test nvlink first, this will filter out most of the cases + # where custom allreduce is not supported full_nvlink = _is_full_nvlink(rank, world_size) if world_size > 2 and not full_nvlink: logger.warn( @@ -55,6 +60,15 @@ def init_custom_ar() -> None: " than two PCIe-only GPUs. To silence this warning, specify" " disable_custom_all_reduce=True explicitly.") return + # test P2P capability + # this is expensive to compute at the first time + # then we cache the result + if not _can_p2p(rank, world_size): + logger.warn( + "Custom allreduce is disabled because your platform lacks GPU P2P" + " capability or P2P test failed. To silence this warning, specify" + " disable_custom_all_reduce=True explicitly.") + return _CA_HANDLE = CustomAllreduce(rank, world_size, full_nvlink) @@ -143,40 +157,15 @@ def _is_full_nvlink(rank, world_size): def _can_p2p(rank: int, world_size: int) -> bool: - num_dev = torch.cuda.device_count() - # note: num dev can be larger than world_size if we're only using - # first few GPUs - if num_dev < world_size: - logger.warn( - "Cannot test GPU P2P because not all GPUs are visible to the " - "current process. This might be the case if 'CUDA_VISIBLE_DEVICES'" - " is set.") - return False + from vllm.distributed.utils import gpu_p2p_access_check for i in range(world_size): if i == rank: continue - if not torch.cuda.can_device_access_peer(rank, i): - return False - # on some platforms, P2P support might be buggy and we need - # additional checks. See also: - # https://github.com/vllm-project/vllm/issues/2728 - if not _can_actually_p2p(rank, i): + if not gpu_p2p_access_check(rank, i): return False return True -# code partly borrowed from -# https://github.com/turboderp/exllamav2/blob/1c67f97f3d2a968605a9c31ab791a05c85bb7879/exllamav2/compat.py#L10 -# License: MIT -def _can_actually_p2p(idx_a, idx_b): - dev_i = f"cuda:{idx_a}" - dev_j = f"cuda:{idx_b}" - a = torch.randn(5, device=dev_i) + 123.0 - b = a.to(dev_j) - c = b.to(dev_i) - return torch.all(a == c) - - class CustomAllreduce: # max_size: max supported allreduce size diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 1258bf58cb453..e2473736375e0 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -41,6 +41,13 @@ _CPU_WORLD_GROUP = None # source rank when broadcasting from the first or last pipeline stage. _PIPELINE_GLOBAL_RANKS = None +_LOCAL_RANK = -1 + + +def get_local_rank(): + global _LOCAL_RANK + return _LOCAL_RANK + def init_distributed_environment( world_size: int = -1, @@ -66,6 +73,8 @@ def init_distributed_environment( ranks = list(range(torch.distributed.get_world_size())) _CPU_WORLD_GROUP = torch.distributed.new_group(ranks=ranks, backend="gloo") + global _LOCAL_RANK + _LOCAL_RANK = local_rank def initialize_model_parallel( diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index 0cd420c8e11b5..e0a871ebe1756 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -2,9 +2,18 @@ # Adapted from # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -from typing import Sequence +import json +import os +from typing import Dict, Optional, Sequence import torch +import torch.distributed as dist + +from vllm.logger import init_logger + +from .parallel_state import get_cpu_world_group, get_local_rank + +logger = init_logger(__name__) def ensure_divisibility(numerator, denominator): @@ -46,3 +55,79 @@ def split_tensor_along_last_dim( return tuple(chunk.contiguous() for chunk in tensor_list) return tensor_list + + +# code partly borrowed from +# https://github.com/turboderp/exllamav2/blob/1c67f97f3d2a968605a9c31ab791a05c85bb7879/exllamav2/compat.py#L10 +# License: MIT +def _can_actually_p2p(idx_a, idx_b): + dev_i = f"cuda:{idx_a}" + dev_j = f"cuda:{idx_b}" + a = torch.randn(5, device=dev_i) + 123.0 + b = a.to(dev_j) + c = b.to(dev_i) + return torch.all(a == c).cpu().item() + + +# why do we need this cache? +# 1. we can have runtime checks for P2P access, where every process checks +# P2P access to all other GPUs. Unfortunately, the test might cost many +# (world_size * world_size) cuda context, and reduce the memory available +# for the model. see https://github.com/vllm-project/vllm/issues/3821 +# 2. alternatively, we can have a p2p map that is generated by the master +# process and broadcasted to all other processes. This still requires +# #world_size of cuda context, belonging to the master process, on each GPU. +# 3. we can have a cache file, that records the p2p access status. The first +# time the master process checks the p2p access, it will generate the cache +# file, at the cost of #world_size of cuda context. Later on, all processes +# can read the cache file to check the p2p access status without any cost of +# additional cuda context. +# Note that the cache file is suffixed by the CUDA_VISIBLE_DEVICES, so that we +# can have different cache files for different CUDA_VISIBLE_DEVICES settings, +# e.g. used by different vllm engines. The device id in the cache file is a +# **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number +# of visible devices in the vllm engine. +_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None + + +def gpu_p2p_access_check(i: int, j: int) -> bool: + """Check if GPU i can access GPU j.""" + + # if the cache variable is already calculated, + # read from the cache instead of checking it again + global _gpu_p2p_access_cache + if _gpu_p2p_access_cache is not None: + return _gpu_p2p_access_cache[f"{i}->{j}"] + + is_distributed = dist.is_initialized() + + num_dev = torch.cuda.device_count() + cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None) + if cuda_visible_devices is None: + cuda_visible_devices = ",".join(str(i) for i in range(num_dev)) + path = os.path.expanduser( + f"~/.config/vllm/gpu_p2p_access_cache_for_{cuda_visible_devices}.json") + os.makedirs(os.path.dirname(path), exist_ok=True) + if (not is_distributed or get_local_rank() == 0) \ + and (not os.path.exists(path)): + # only the local master process (with local_rank == 0) can + # enter this block to calculate the cache + logger.info(f"generating GPU P2P access cache for in {path}") + cache = {} + for _i in range(num_dev): + for _j in range(num_dev): + # on some platforms, P2P support might be buggy and we need + # additional checks. See also: + # https://github.com/vllm-project/vllm/issues/2728 + cache[f"{_i}->{_j}"] = torch.cuda.can_device_access_peer( + _i, _j) and _can_actually_p2p(_i, _j) + with open(path, "w") as f: + json.dump(cache, f, indent=4) + if is_distributed: + cpu_world_group = get_cpu_world_group() + dist.barrier(cpu_world_group) + logger.info(f"reading GPU P2P access cache from {path}") + with open(path, "r") as f: + cache = json.load(f) + _gpu_p2p_access_cache = cache + return _gpu_p2p_access_cache[f"{i}->{j}"] From 563c54f760f870ae44c7662c8a9ec3a223a3c4c4 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Sun, 14 Apr 2024 22:12:42 +0100 Subject: [PATCH 40/50] [BugFix] Fix tensorizer extra in setup.py (#4072) --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 813321efe796d..19a9150ad2e64 100644 --- a/setup.py +++ b/setup.py @@ -406,7 +406,7 @@ setup( install_requires=get_requirements(), ext_modules=ext_modules, extras_require={ - "optional": ["tensorizer==2.9.0a1"], + "tensorizer": ["tensorizer==2.9.0a1"], }, cmdclass={"build_ext": cmake_build_ext} if not _is_neuron() else {}, package_data=package_data, From aceb17cf2d629175a484c3d9df355f44bd334cb3 Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Sun, 14 Apr 2024 14:35:55 -0700 Subject: [PATCH 41/50] [Docs] document that mixtral 8x22b is supported (#4073) --- README.md | 2 +- docs/source/models/supported_models.rst | 36 ++++++++++++------------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index d53227b82d87a..8434c11883341 100644 --- a/README.md +++ b/README.md @@ -72,7 +72,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi - LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.) - MiniCPM (`openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, etc.) - Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.) -- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, etc.) +- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc.) - MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.) - OLMo (`allenai/OLMo-1B`, `allenai/OLMo-7B`, etc.) - OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index c09b0ff250437..5e5ce871f61dd 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -30,23 +30,23 @@ Alongside each architecture, we include some popular models that use it. * - :code:`CohereForCausalLM` - Command-R - :code:`CohereForAI/c4ai-command-r-v01`, etc. - - + - * - :code:`DbrxForCausalLM` - DBRX - :code:`databricks/dbrx-base`, :code:`databricks/dbrx-instruct`, etc. - - + - * - :code:`DeciLMForCausalLM` - DeciLM - :code:`Deci/DeciLM-7B`, :code:`Deci/DeciLM-7B-instruct`, etc. - - + - * - :code:`BloomForCausalLM` - BLOOM, BLOOMZ, BLOOMChat - :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc. - - + - * - :code:`FalconForCausalLM` - Falcon - :code:`tiiuae/falcon-7b`, :code:`tiiuae/falcon-40b`, :code:`tiiuae/falcon-rw-7b`, etc. - - + - * - :code:`GemmaForCausalLM` - Gemma - :code:`google/gemma-2b`, :code:`google/gemma-7b`, etc. @@ -54,19 +54,19 @@ Alongside each architecture, we include some popular models that use it. * - :code:`GPT2LMHeadModel` - GPT-2 - :code:`gpt2`, :code:`gpt2-xl`, etc. - - + - * - :code:`GPTBigCodeForCausalLM` - StarCoder, SantaCoder, WizardCoder - :code:`bigcode/starcoder`, :code:`bigcode/gpt_bigcode-santacoder`, :code:`WizardLM/WizardCoder-15B-V1.0`, etc. - - + - * - :code:`GPTJForCausalLM` - GPT-J - :code:`EleutherAI/gpt-j-6b`, :code:`nomic-ai/gpt4all-j`, etc. - - + - * - :code:`GPTNeoXForCausalLM` - GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM - :code:`EleutherAI/gpt-neox-20b`, :code:`EleutherAI/pythia-12b`, :code:`OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, :code:`databricks/dolly-v2-12b`, :code:`stabilityai/stablelm-tuned-alpha-7b`, etc. - - + - * - :code:`InternLMForCausalLM` - InternLM - :code:`internlm/internlm-7b`, :code:`internlm/internlm-chat-7b`, etc. @@ -93,32 +93,32 @@ Alongside each architecture, we include some popular models that use it. - ✅︎ * - :code:`MixtralForCausalLM` - Mixtral-8x7B, Mixtral-8x7B-Instruct - - :code:`mistralai/Mixtral-8x7B-v0.1`, :code:`mistralai/Mixtral-8x7B-Instruct-v0.1`, etc. + - :code:`mistralai/Mixtral-8x7B-v0.1`, :code:`mistralai/Mixtral-8x7B-Instruct-v0.1`, :code:`mistral-community/Mixtral-8x22B-v0.1`, etc. - ✅︎ * - :code:`MPTForCausalLM` - MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter - :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc. - - + - * - :code:`OLMoForCausalLM` - OLMo - :code:`allenai/OLMo-1B`, :code:`allenai/OLMo-7B`, etc. - - + - * - :code:`OPTForCausalLM` - OPT, OPT-IML - :code:`facebook/opt-66b`, :code:`facebook/opt-iml-max-30b`, etc. - - + - * - :code:`OrionForCausalLM` - Orion - :code:`OrionStarAI/Orion-14B-Base`, :code:`OrionStarAI/Orion-14B-Chat`, etc. - - + - * - :code:`PhiForCausalLM` - Phi - :code:`microsoft/phi-1_5`, :code:`microsoft/phi-2`, etc. - - + - * - :code:`QWenLMHeadModel` - Qwen - :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc. - - + - * - :code:`Qwen2ForCausalLM` - Qwen2 - :code:`Qwen/Qwen2-beta-7B`, :code:`Qwen/Qwen2-beta-7B-Chat`, etc. @@ -126,11 +126,11 @@ Alongside each architecture, we include some popular models that use it. * - :code:`Qwen2MoeForCausalLM` - Qwen2MoE - :code:`Qwen/Qwen1.5-MoE-A2.7B`, :code:`Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. - - + - * - :code:`StableLmForCausalLM` - StableLM - :code:`stabilityai/stablelm-3b-4e1t/` , :code:`stabilityai/stablelm-base-alpha-7b-v2`, etc. - - + - If your model uses one of the above model architectures, you can seamlessly run your model with vLLM. Otherwise, please refer to :ref:`Adding a New Model ` for instructions on how to implement support for your model. From 8db1bf32f8924403c6a845b5ce71ba0f14beb038 Mon Sep 17 00:00:00 2001 From: Roy Date: Mon, 15 Apr 2024 08:43:54 +0800 Subject: [PATCH 42/50] [Misc] Upgrade triton to 2.2.0 (#4061) --- requirements-cpu.txt | 2 +- requirements-cuda.txt | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/requirements-cpu.txt b/requirements-cpu.txt index 5779b38b24e69..e911ad03295f0 100644 --- a/requirements-cpu.txt +++ b/requirements-cpu.txt @@ -3,4 +3,4 @@ # Dependencies for x86_64 CPUs torch == 2.2.1+cpu -triton >= 2.1.0 # FIXME(woosuk): This is a hack to avoid import error. \ No newline at end of file +triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error. \ No newline at end of file diff --git a/requirements-cuda.txt b/requirements-cuda.txt index 6ee75e8139c04..c6d2cd46aee54 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -7,4 +7,3 @@ pynvml == 11.5.0 vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library torch == 2.2.1 xformers == 0.0.25 # Requires PyTorch 2.2.1 -triton >= 2.1.0 From e11e2007368b22fce05b9ecdf00dd48eda471f9e Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Sun, 14 Apr 2024 21:50:08 -0700 Subject: [PATCH 43/50] [Bugfix] Fix filelock version requirement (#4075) --- requirements-common.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements-common.txt b/requirements-common.txt index c96f9c9937fb0..90a3bc8abc1db 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -12,4 +12,5 @@ pydantic >= 2.0 # Required for OpenAI server. prometheus_client >= 0.18.0 tiktoken == 0.6.0 # Required for DBRX tokenizer outlines == 0.0.34 # Requires torch >= 2.1.0 -typing_extensions \ No newline at end of file +typing_extensions +filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 From 0003e9154bf1091d0de7ca7a6c7f1253df1eca5b Mon Sep 17 00:00:00 2001 From: "Li, Jiang" Date: Mon, 15 Apr 2024 23:35:55 +0800 Subject: [PATCH 44/50] [Misc][Minor] Fix CPU block num log in CPUExecutor. (#4088) --- vllm/executor/cpu_executor.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index 33e67d8b3eec2..e63a88be7868f 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -74,7 +74,10 @@ class CPUExecutor(ExecutorBase): # NOTE: We log here to avoid multiple logs when number of workers is # greater than one. We could log in the engine, but not all executors # have GPUs. - logger.info(f"# CPU blocks: {num_cpu_blocks}") + # NOTE: `cpu block` for CPU backend is located on CPU memory but is + # referred as `gpu block`. Because we want to reuse the existing block + # management procedure. + logger.info(f"# CPU blocks: {num_gpu_blocks}") self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) def execute_model(self, From eb46fbfda25348422918c4a876e17aef05fc5e34 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Mon, 15 Apr 2024 13:05:09 -0700 Subject: [PATCH 45/50] [Core] Simplifications to executor classes (#4071) --- vllm/executor/cpu_executor.py | 31 +++++++++------------------- vllm/executor/executor_base.py | 27 +++++++++++++++++------- vllm/executor/gpu_executor.py | 32 ++++------------------------- vllm/executor/neuron_executor.py | 29 ++++++-------------------- vllm/executor/ray_gpu_executor.py | 34 ++++--------------------------- 5 files changed, 44 insertions(+), 109 deletions(-) diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index e63a88be7868f..f562e4e0ae3de 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -1,10 +1,9 @@ import os -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Set, Tuple import torch -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig) +from vllm.config import CacheConfig, ModelConfig, SchedulerConfig from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -16,23 +15,13 @@ logger = init_logger(__name__) class CPUExecutor(ExecutorBase): - def __init__(self, model_config: ModelConfig, cache_config: CacheConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], *args, **kwargs) -> None: - assert device_config.device_type == "cpu" - assert lora_config is None, "cpu backend doesn't support LoRA" - model_config = _verify_and_get_model_config(model_config) - cache_config = _verify_and_get_cache_config(cache_config) - scheduler_config = _verify_and_get_scheduler_config(scheduler_config) - - self.model_config = model_config - self.cache_config = cache_config - self.lora_config = lora_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config + def _init_executor(self) -> None: + assert self.device_config.device_type == "cpu" + assert self.lora_config is None, "cpu backend doesn't support LoRA" + self.model_config = _verify_and_get_model_config(self.model_config) + self.cache_config = _verify_and_get_cache_config(self.cache_config) + self.scheduler_config = _verify_and_get_scheduler_config( + self.scheduler_config) # Instantiate the worker and load the model to CPU. self._init_worker() @@ -99,7 +88,7 @@ class CPUExecutor(ExecutorBase): def remove_lora(self, lora_id: int) -> bool: return self.driver_worker.remove_lora(lora_id) - def list_loras(self) -> List[int]: + def list_loras(self) -> Set[int]: return self.driver_worker.list_loras() def check_health(self) -> None: diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index bbfbfc689c99f..bbb6ec80f7b7e 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -1,9 +1,9 @@ from abc import ABC, abstractmethod -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Set, Tuple from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, - VisionLanguageConfig) + TensorizerConfig, VisionLanguageConfig) from vllm.lora.request import LoRARequest from vllm.sequence import SamplerOutput, SequenceGroupMetadata @@ -16,7 +16,6 @@ class ExecutorBase(ABC): that can execute the model on multiple devices. """ - @abstractmethod def __init__( self, model_config: ModelConfig, @@ -27,8 +26,23 @@ class ExecutorBase(ABC): lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], speculative_config: Optional[SpeculativeConfig], + tensorizer_config: Optional[TensorizerConfig], ) -> None: - raise NotImplementedError + self.model_config = model_config + self.cache_config = cache_config + self.lora_config = lora_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.vision_language_config = vision_language_config + self.speculative_config = speculative_config + self.tensorizer_config = tensorizer_config + + self._init_executor() + + @abstractmethod + def _init_executor(self) -> None: + pass @abstractmethod def determine_num_available_blocks(self) -> Tuple[int, int]: @@ -71,7 +85,7 @@ class ExecutorBase(ABC): raise NotImplementedError @abstractmethod - def list_loras(self) -> List[int]: + def list_loras(self) -> Set[int]: raise NotImplementedError @abstractmethod @@ -94,8 +108,7 @@ class ExecutorAsyncBase(ExecutorBase): """Executes one model step on the given sequences.""" raise NotImplementedError - @abstractmethod async def check_health_async(self) -> None: """Checks if the executor is healthy. If not, it should raise an exception.""" - raise NotImplementedError + self.check_health() diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 30577ecf62faa..bae509f48025b 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -1,8 +1,5 @@ -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Set, Tuple -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, SpeculativeConfig, - TensorizerConfig, VisionLanguageConfig) from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -15,24 +12,8 @@ logger = init_logger(__name__) class GPUExecutor(ExecutorBase): - def __init__(self, model_config: ModelConfig, cache_config: CacheConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - vision_language_config: Optional[VisionLanguageConfig], - speculative_config: Optional[SpeculativeConfig], - tensorizer_config: Optional[TensorizerConfig]) -> None: - self.model_config = model_config - self.cache_config = cache_config - self.lora_config = lora_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.vision_language_config = vision_language_config - self.tensorizer_config = tensorizer_config - - assert (not speculative_config + def _init_executor(self) -> None: + assert (not self.speculative_config ), "Speculative decoding not yet supported for GPU backend" # Instantiate the worker and load the model to GPU. @@ -103,7 +84,7 @@ class GPUExecutor(ExecutorBase): assert lora_id > 0, "lora_id must be greater than 0." return self.driver_worker.remove_lora(lora_id) - def list_loras(self) -> List[int]: + def list_loras(self) -> Set[int]: return self.driver_worker.list_loras() def check_health(self) -> None: @@ -127,8 +108,3 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase): blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy) return output - - async def check_health_async(self) -> None: - # GPUExecutor will always be healthy as long as - # it's running. - return diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py index d45f18e466256..273b17a927efd 100644 --- a/vllm/executor/neuron_executor.py +++ b/vllm/executor/neuron_executor.py @@ -1,8 +1,5 @@ -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Set, Tuple -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, SpeculativeConfig, - VisionLanguageConfig) from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -13,24 +10,10 @@ logger = init_logger(__name__) class NeuronExecutor(ExecutorBase): - def __init__( - self, - model_config: ModelConfig, - cache_config: CacheConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - vision_language_config: Optional[VisionLanguageConfig], - speculative_config: Optional[SpeculativeConfig], - ) -> None: - self.model_config = model_config - self.cache_config = cache_config - assert lora_config is None, "LoRA is not supported for Neuron backend." - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - assert (not speculative_config + def _init_executor(self) -> None: + assert (self.lora_config is + None), "LoRA is not supported for Neuron backend." + assert (not self.speculative_config ), "Speculative decoding not yet supported for Neuron backend." # Instantiate the worker and load the model to the device. @@ -80,7 +63,7 @@ class NeuronExecutor(ExecutorBase): def remove_lora(self, lora_id: int) -> bool: return self.driver_worker.remove_lora(lora_id) - def list_loras(self) -> List[int]: + def list_loras(self) -> Set[int]: return self.driver_worker.list_loras() def check_health(self) -> None: diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 28dc3e0db312a..5db2f3f652532 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -3,11 +3,8 @@ import copy import os import pickle from collections import defaultdict -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, SpeculativeConfig, - TensorizerConfig, VisionLanguageConfig) from vllm.engine.ray_utils import RayWorkerVllm, ray from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger @@ -32,27 +29,8 @@ USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0)) class RayGPUExecutor(ExecutorBase): - def __init__( - self, - model_config: ModelConfig, - cache_config: CacheConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - vision_language_config: Optional[VisionLanguageConfig], - speculative_config: Optional[SpeculativeConfig], - tensorizer_config: Optional[TensorizerConfig], - ) -> None: - self.model_config = model_config - self.cache_config = cache_config - self.lora_config = lora_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.vision_language_config = vision_language_config - self.tensorizer_config = tensorizer_config - assert (not speculative_config + def _init_executor(self) -> None: + assert (not self.speculative_config ), "Speculative decoding not yet supported for RayGPU backend." assert self.parallel_config.worker_use_ray @@ -273,7 +251,7 @@ class RayGPUExecutor(ExecutorBase): lora_id=lora_id, ) - def list_loras(self) -> List[int]: + def list_loras(self) -> Set[int]: return self._run_workers("list_loras") def _run_workers( @@ -416,7 +394,3 @@ class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase): # Only the driver worker returns the sampling results. output = all_outputs[0] return output - - async def check_health_async(self) -> None: - """Raises an error if engine is unhealthy.""" - self._check_if_any_actor_is_dead() From d619ae2d19c41d9aa8f68fa0e5e32cc410dc2522 Mon Sep 17 00:00:00 2001 From: Sanger Steel Date: Mon, 15 Apr 2024 16:28:25 -0400 Subject: [PATCH 46/50] [Doc] Add better clarity for tensorizer usage (#4090) Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com> --- docs/source/models/engine_args.rst | 2 +- examples/tensorize_vllm_model.py | 60 +++++++++++++++++------- vllm/model_executor/tensorizer_loader.py | 6 +-- 3 files changed, 46 insertions(+), 22 deletions(-) diff --git a/docs/source/models/engine_args.rst b/docs/source/models/engine_args.rst index 886a806934c04..235cb4e128c99 100644 --- a/docs/source/models/engine_args.rst +++ b/docs/source/models/engine_args.rst @@ -45,7 +45,7 @@ Below, you can find an explanation of every engine argument for vLLM: * "safetensors" will load the weights in the safetensors format. * "npcache" will load the weights in pytorch format and store a numpy cache to speed up the loading. * "dummy" will initialize the weights with random values, mainly for profiling. - * "tensorizer" will load serialized weights using `CoreWeave's Tensorizer model deserializer. `_. See `tensorized_vllm_model.py` in the examples folder to serialize a vLLM model, and for more information. Tensorizer support for vLLM can be installed with `pip install vllm[tensorizer]`. + * "tensorizer" will load serialized weights using `CoreWeave's Tensorizer model deserializer. `_ See `examples/tensorize_vllm_model.py `_ to serialize a vLLM model, and for more information. .. option:: --dtype {auto,half,float16,bfloat16,float,float32} diff --git a/examples/tensorize_vllm_model.py b/examples/tensorize_vllm_model.py index 3c20a38c7f726..8cf8be09d0b9c 100644 --- a/examples/tensorize_vllm_model.py +++ b/examples/tensorize_vllm_model.py @@ -23,14 +23,16 @@ from vllm.model_executor.tensorizer_loader import TensorizerArgs # yapf: disable """ tensorize_vllm_model.py is a script that can be used to serialize and -deserialize vLLM models. These models can be loaded using tensorizer directly -to the GPU extremely quickly. Tensor encryption and decryption is also -supported, although libsodium must be installed to use it. Install -vllm with tensorizer support using `pip install vllm[tensorizer]`. +deserialize vLLM models. These models can be loaded using tensorizer +to the GPU extremely quickly over an HTTP/HTTPS endpoint, an S3 endpoint, +or locally. Tensor encryption and decryption is also supported, although +libsodium must be installed to use it. Install vllm with tensorizer support +using `pip install vllm[tensorizer]`. -To serialize a model, you can run something like this: +To serialize a model, install vLLM from source, then run something +like this from the root level of this repository: -python tensorize_vllm_model.py \ +python -m examples.tensorize_vllm_model \ --model EleutherAI/gpt-j-6B \ --dtype float16 \ serialize \ @@ -38,31 +40,57 @@ python tensorize_vllm_model.py \ --suffix vllm Which downloads the model from HuggingFace, loads it into vLLM, serializes it, -and saves it to your S3 bucket. A local directory can also be used. +and saves it to your S3 bucket. A local directory can also be used. This +assumes your S3 credentials are specified as environment variables +in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and `S3_ENDPOINT`. +To provide S3 credentials directly, you can provide `--s3-access-key-id` and +`--s3-secret-access-key`, as well as `--s3-endpoint` as CLI args to this +script. You can also encrypt the model weights with a randomly-generated key by providing a `--keyfile` argument. -To deserialize a model, you can run something like this: +To deserialize a model, you can run something like this from the root +level of this repository: -python tensorize_vllm_model.py \ +python -m examples.tensorize_vllm_model \ --model EleutherAI/gpt-j-6B \ --dtype float16 \ deserialize \ --path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors Which downloads the model tensors from your S3 bucket and deserializes them. -To provide S3 credentials, you can provide `--s3-access-key-id` and -`--s3-secret-access-key`, as well as `--s3-endpoint` as CLI args to this script, -the OpenAI entrypoint, as arguments for LLM(), or as environment variables -in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and `S3_ENDPOINT`. - You can also provide a `--keyfile` argument to decrypt the model weights if they were serialized with encryption. -For more information on the available arguments, run -`python tensorize_vllm_model.py --help`. +For more information on the available arguments for serializing, run +`python -m examples.tensorize_vllm_model serialize --help`. + +Or for deserializing: + +`python -m examples.tensorize_vllm_model deserialize --help`. + +Once a model is serialized, it can be used to load the model when running the +OpenAI inference client at `vllm/entrypoints/openai/api_server.py` by providing +the `--tensorizer-uri` CLI argument that is functionally the same as the +`--path-to-tensors` argument in this script, along with `--vllm-tensorized`, to +signify that the model to be deserialized is a vLLM model, rather than a +HuggingFace `PreTrainedModel`, which can also be deserialized using tensorizer +in the same inference server, albeit without the speed optimizations. To +deserialize an encrypted file, the `--encryption-keyfile` argument can be used +to provide the path to the keyfile used to encrypt the model weights. For +information on all the arguments that can be used to configure tensorizer's +deserialization, check out the tensorizer options argument group in the +`vllm/entrypoints/openai/api_server.py` script with `--help`. + +Tensorizer can also be invoked with the `LLM` class directly to load models: + + llm = LLM(model="facebook/opt-125m", + load_format="tensorizer", + tensorizer_uri=path_to_opt_tensors, + num_readers=3, + vllm_tensorized=True) """ diff --git a/vllm/model_executor/tensorizer_loader.py b/vllm/model_executor/tensorizer_loader.py index ed3ad9e2ffa15..8550cc97aefe8 100644 --- a/vllm/model_executor/tensorizer_loader.py +++ b/vllm/model_executor/tensorizer_loader.py @@ -126,7 +126,6 @@ class TensorizerArgs: "s3_endpoint": self.s3_endpoint, } - # Omitting self.dtype and self.device as this behaves weirdly self.deserializer_params = { "verify_hash": self.verify_hash, "encryption": self.encryption_keyfile, @@ -145,7 +144,7 @@ class TensorizerArgs: parser: argparse.ArgumentParser) -> argparse.ArgumentParser: """Tensorizer CLI arguments""" - # Create the argument group + # Tensorizer options arg group group = parser.add_argument_group( 'tensorizer options', description=('Options for configuring the behavior of the' @@ -205,9 +204,7 @@ class TensorizerArgs: @classmethod def from_cli_args(cls, args: argparse.Namespace) -> "TensorizerArgs": - # Get the list of attributes of this dataclass. attrs = [attr.name for attr in dataclasses.fields(cls)] - # Set the attributes from the parsed arguments. tensorizer_args = cls(**{ attr: getattr(args, attr) for attr in attrs if hasattr(args, attr) @@ -291,7 +288,6 @@ class TensorizerAgent: nn.Module: The deserialized model. """ before_mem = get_mem_usage() - # Lazy load the tensors from S3 into the model. start = time.perf_counter() with open_stream( self.tensorizer_args.tensorizer_uri, From 4695397dcfef693a0a10f1eb8bf77ea905c54829 Mon Sep 17 00:00:00 2001 From: Ricky Xu Date: Mon, 15 Apr 2024 14:24:45 -0700 Subject: [PATCH 47/50] [Bugfix] Fix ray workers profiling with nsight (#4095) --- vllm/executor/ray_gpu_executor.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 5db2f3f652532..7aca5e36107aa 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -48,6 +48,21 @@ class RayGPUExecutor(ExecutorBase): if USE_RAY_COMPILED_DAG: self.forward_dag = self._compiled_ray_dag() + def _configure_ray_workers_use_nsight(self, + ray_remote_kwargs) -> Dict[str, Any]: + # If nsight profiling is enabled, we need to set the profiling + # configuration for the ray workers as runtime env. + runtime_env = ray_remote_kwargs.setdefault("runtime_env", {}) + runtime_env.update({ + "nsight": { + "t": "cuda,cudnn,cublas", + "o": "'worker_process_%p'", + "cuda-graph-trace": "node", + } + }) + + return ray_remote_kwargs + def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs): if self.parallel_config.tensor_parallel_size == 1: @@ -63,6 +78,10 @@ class RayGPUExecutor(ExecutorBase): # The remaining workers are the actual ray actors. self.workers: List[RayWorkerVllm] = [] + if self.parallel_config.ray_workers_use_nsight: + ray_remote_kwargs = self._configure_ray_workers_use_nsight( + ray_remote_kwargs) + # Create the workers. driver_ip = get_ip() for bundle_id, bundle in enumerate(placement_group.bundle_specs): From 37e84a403d6d11b670a42e84153204cd8b76b849 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Tue, 16 Apr 2024 06:47:31 +0900 Subject: [PATCH 48/50] [Typing] Fix Sequence type GenericAlias only available after Python 3.9. (#4092) --- vllm/core/block_manager_v1.py | 5 +++-- vllm/core/block_manager_v2.py | 2 +- vllm/core/interfaces.py | 2 +- vllm/utils.py | 7 ++++--- 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index e391a3b1e5a33..be093922b84f2 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -1,9 +1,10 @@ """A block manager that manages token blocks.""" from abc import ABC, abstractmethod -from collections.abc import Sequence as GenericSequence from itertools import count, takewhile from os.path import commonprefix -from typing import Dict, List, Optional, Set +from typing import Dict, List, Optional +from typing import Sequence as GenericSequence +from typing import Set from vllm.block import BlockTable, PhysicalTokenBlock from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 19f0cf415eb34..6339a6baf4161 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -1,6 +1,6 @@ """A block manager that manages token blocks.""" -from collections.abc import Sequence as GenericSequence from typing import Dict, List, Optional +from typing import Sequence as GenericSequence from vllm.core.block.block_table import BlockTable from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index c1f68a2e891bf..56c2c5995c38b 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -1,7 +1,7 @@ import enum from abc import ABC, abstractmethod -from collections.abc import Sequence as GenericSequence from typing import Dict, List +from typing import Sequence as GenericSequence from vllm.sequence import Sequence, SequenceGroup diff --git a/vllm/utils.py b/vllm/utils.py index 4c0dc9ca729a9..aad62516ad1b9 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -6,11 +6,12 @@ import socket import subprocess import uuid import warnings -from collections import OrderedDict, defaultdict +from collections import defaultdict from functools import lru_cache, partial from platform import uname from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic, - Hashable, List, Optional, Tuple, TypeVar, Union) + Hashable, List, Optional, OrderedDict, Tuple, TypeVar, + Union) import psutil import torch @@ -51,7 +52,7 @@ class Counter: class LRUCache(Generic[T]): def __init__(self, capacity: int): - self.cache = OrderedDict[Hashable, T]() + self.cache: OrderedDict[Hashable, T] = OrderedDict() self.capacity = capacity def __contains__(self, key: Hashable) -> bool: From 4e7ee664e201442e24e2298a36a5264b98691626 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Tue, 16 Apr 2024 14:24:53 +0900 Subject: [PATCH 49/50] [Core] Fix engine-use-ray broken (#4105) --- tests/async_engine/test_api_server.py | 17 +++++++++++++---- vllm/engine/async_llm_engine.py | 7 +++---- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/tests/async_engine/test_api_server.py b/tests/async_engine/test_api_server.py index 248bfbc8ab5c0..7f57d5cf9b182 100644 --- a/tests/async_engine/test_api_server.py +++ b/tests/async_engine/test_api_server.py @@ -25,21 +25,30 @@ def _query_server_long(prompt: str) -> dict: @pytest.fixture -def api_server(tokenizer_pool_size: int): +def api_server(tokenizer_pool_size: int, engine_use_ray: bool, + worker_use_ray: bool): script_path = Path(__file__).parent.joinpath( "api_server_async_engine.py").absolute() - uvicorn_process = subprocess.Popen([ + commands = [ sys.executable, "-u", str(script_path), "--model", "facebook/opt-125m", "--host", "127.0.0.1", "--tokenizer-pool-size", str(tokenizer_pool_size) - ]) + ] + if engine_use_ray: + commands.append("--engine-use-ray") + if worker_use_ray: + commands.append("--worker-use-ray") + uvicorn_process = subprocess.Popen(commands) yield uvicorn_process.terminate() @pytest.mark.parametrize("tokenizer_pool_size", [0, 2]) -def test_api_server(api_server, tokenizer_pool_size: int): +@pytest.mark.parametrize("worker_use_ray", [False, True]) +@pytest.mark.parametrize("engine_use_ray", [False, True]) +def test_api_server(api_server, tokenizer_pool_size: int, worker_use_ray: bool, + engine_use_ray: bool): """ Run the API server and test it. diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index f610495135121..1dbf58904541c 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -333,8 +333,7 @@ class AsyncLLMEngine: if engine_config.device_config.device_type == "neuron": raise NotImplementedError("Neuron is not supported for " "async engine yet.") - elif (engine_config.parallel_config.worker_use_ray - or engine_args.engine_use_ray): + elif engine_config.parallel_config.worker_use_ray: initialize_ray_cluster(engine_config.parallel_config) from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync executor_class = RayGPUExecutorAsync @@ -410,8 +409,8 @@ class AsyncLLMEngine: else: # FIXME(woosuk): This is a bit hacky. Be careful when changing the # order of the arguments. - cache_config = args[1] - parallel_config = args[2] + cache_config = kwargs["cache_config"] + parallel_config = kwargs["parallel_config"] if parallel_config.tensor_parallel_size == 1: num_gpus = cache_config.gpu_memory_utilization else: From 05434764cd99990035779cf9a4ed86623b528825 Mon Sep 17 00:00:00 2001 From: Noam Gat Date: Tue, 16 Apr 2024 08:54:57 +0300 Subject: [PATCH 50/50] LM Format Enforcer Guided Decoding Support (#3868) Co-authored-by: Simon Mo --- requirements-common.txt | 1 + tests/entrypoints/test_guided_processors.py | 42 +++++++- tests/entrypoints/test_openai_server.py | 69 ++++++++---- vllm/config.py | 26 ++++- vllm/engine/arg_utils.py | 18 +++- vllm/engine/llm_engine.py | 10 +- vllm/entrypoints/openai/protocol.py | 12 +++ vllm/entrypoints/openai/serving_chat.py | 6 +- vllm/entrypoints/openai/serving_completion.py | 6 +- .../guided_decoding/__init__.py | 25 +++++ .../lm_format_enforcer_decoding.py | 69 ++++++++++++ .../outlines_decoding.py} | 7 +- .../outlines_logits_processors.py} | 100 +++++++++--------- 13 files changed, 304 insertions(+), 87 deletions(-) create mode 100644 vllm/model_executor/guided_decoding/__init__.py create mode 100644 vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py rename vllm/model_executor/{guided_decoding.py => guided_decoding/outlines_decoding.py} (93%) rename vllm/model_executor/{guided_logits_processors.py => guided_decoding/outlines_logits_processors.py} (70%) diff --git a/requirements-common.txt b/requirements-common.txt index 90a3bc8abc1db..c1614d2537b25 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -11,6 +11,7 @@ uvicorn[standard] pydantic >= 2.0 # Required for OpenAI server. prometheus_client >= 0.18.0 tiktoken == 0.6.0 # Required for DBRX tokenizer +lm-format-enforcer == 0.9.3 outlines == 0.0.34 # Requires torch >= 2.1.0 typing_extensions filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 diff --git a/tests/entrypoints/test_guided_processors.py b/tests/entrypoints/test_guided_processors.py index 5622744566bcc..30f0ad5d8272f 100644 --- a/tests/entrypoints/test_guided_processors.py +++ b/tests/entrypoints/test_guided_processors.py @@ -1,11 +1,14 @@ # This unit test should be moved to a new # tests/test_guided_decoding directory. - +import pytest import torch from transformers import AutoTokenizer -from vllm.model_executor.guided_logits_processors import (JSONLogitsProcessor, - RegexLogitsProcessor) +from vllm.entrypoints.openai.protocol import CompletionRequest +from vllm.model_executor.guided_decoding import ( + get_guided_decoding_logits_processor) +from vllm.model_executor.guided_decoding.outlines_logits_processors import ( + JSONLogitsProcessor, RegexLogitsProcessor) TEST_SCHEMA = { "type": "object", @@ -73,3 +76,36 @@ def test_guided_logits_processors(): json_LP(token_ids, tensor) assert tensor.shape == original_tensor.shape assert not torch.allclose(tensor, original_tensor) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("backend", ["outlines", "lm-format-enforcer"]) +async def test_guided_logits_processor_black_box(backend: str): + tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta') + token_ids = tokenizer.encode( + f"Give an example IPv4 address with this regex: {TEST_REGEX}") + regex_request = CompletionRequest(model='test', + prompt=token_ids, + guided_regex=TEST_REGEX) + regex_lp = await get_guided_decoding_logits_processor( + backend, regex_request, tokenizer) + assert regex_lp is not None + tensor = torch.rand(32000) + original_tensor = torch.clone(tensor) + tensor = regex_lp(token_ids, tensor) + assert tensor.shape == original_tensor.shape + assert not torch.allclose(tensor, original_tensor) + + token_ids = tokenizer.encode( + f"Give an employee profile that fits this schema: {TEST_SCHEMA}") + json_request = CompletionRequest(model='test', + prompt=token_ids, + guided_json=TEST_SCHEMA) + json_lp = await get_guided_decoding_logits_processor( + backend, json_request, tokenizer) + assert json_lp is not None + tensor = torch.rand(32000) + original_tensor = torch.clone(tensor) + tensor = json_lp(token_ids, tensor) + assert tensor.shape == original_tensor.shape + assert not torch.allclose(tensor, original_tensor) diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 7940430b8b654..14e6ee0ffe9d9 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -506,7 +506,10 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI): assert first_response != completion.choices[0].text -async def test_guided_json_completion(server, client: openai.AsyncOpenAI): +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_json_completion(server, client: openai.AsyncOpenAI, + guided_decoding_backend: str): completion = await client.completions.create( model=MODEL_NAME, prompt=f"Give an example JSON for an employee profile " @@ -514,7 +517,8 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI): n=3, temperature=1.0, max_tokens=500, - extra_body=dict(guided_json=TEST_SCHEMA)) + extra_body=dict(guided_json=TEST_SCHEMA, + guided_decoding_backend=guided_decoding_backend)) assert completion.id is not None assert completion.choices is not None and len(completion.choices) == 3 @@ -524,7 +528,10 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI): jsonschema.validate(instance=output_json, schema=TEST_SCHEMA) -async def test_guided_json_chat(server, client: openai.AsyncOpenAI): +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_json_chat(server, client: openai.AsyncOpenAI, + guided_decoding_backend: str): messages = [{ "role": "system", "content": "you are a helpful assistant" @@ -538,8 +545,9 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI): chat_completion = await client.chat.completions.create( model=MODEL_NAME, messages=messages, - max_tokens=500, - extra_body=dict(guided_json=TEST_SCHEMA)) + max_tokens=1000, + extra_body=dict(guided_json=TEST_SCHEMA, + guided_decoding_backend=guided_decoding_backend)) message = chat_completion.choices[0].message assert message.content is not None json1 = json.loads(message.content) @@ -555,8 +563,9 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI): chat_completion = await client.chat.completions.create( model=MODEL_NAME, messages=messages, - max_tokens=500, - extra_body=dict(guided_json=TEST_SCHEMA)) + max_tokens=1000, + extra_body=dict(guided_json=TEST_SCHEMA, + guided_decoding_backend=guided_decoding_backend)) message = chat_completion.choices[0].message assert message.content is not None json2 = json.loads(message.content) @@ -565,14 +574,18 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI): assert json1["age"] != json2["age"] -async def test_guided_regex_completion(server, client: openai.AsyncOpenAI): +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_regex_completion(server, client: openai.AsyncOpenAI, + guided_decoding_backend: str): completion = await client.completions.create( model=MODEL_NAME, prompt=f"Give an example IPv4 address with this regex: {TEST_REGEX}", n=3, temperature=1.0, max_tokens=20, - extra_body=dict(guided_regex=TEST_REGEX)) + extra_body=dict(guided_regex=TEST_REGEX, + guided_decoding_backend=guided_decoding_backend)) assert completion.id is not None assert completion.choices is not None and len(completion.choices) == 3 @@ -581,7 +594,10 @@ async def test_guided_regex_completion(server, client: openai.AsyncOpenAI): assert re.fullmatch(TEST_REGEX, completion.choices[i].text) is not None -async def test_guided_regex_chat(server, client: openai.AsyncOpenAI): +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_regex_chat(server, client: openai.AsyncOpenAI, + guided_decoding_backend: str): messages = [{ "role": "system", "content": "you are a helpful assistant" @@ -595,7 +611,8 @@ async def test_guided_regex_chat(server, client: openai.AsyncOpenAI): model=MODEL_NAME, messages=messages, max_tokens=20, - extra_body=dict(guided_regex=TEST_REGEX)) + extra_body=dict(guided_regex=TEST_REGEX, + guided_decoding_backend=guided_decoding_backend)) ip1 = chat_completion.choices[0].message.content assert ip1 is not None assert re.fullmatch(TEST_REGEX, ip1) is not None @@ -606,21 +623,26 @@ async def test_guided_regex_chat(server, client: openai.AsyncOpenAI): model=MODEL_NAME, messages=messages, max_tokens=20, - extra_body=dict(guided_regex=TEST_REGEX)) + extra_body=dict(guided_regex=TEST_REGEX, + guided_decoding_backend=guided_decoding_backend)) ip2 = chat_completion.choices[0].message.content assert ip2 is not None assert re.fullmatch(TEST_REGEX, ip2) is not None assert ip1 != ip2 -async def test_guided_choice_completion(server, client: openai.AsyncOpenAI): +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_choice_completion(server, client: openai.AsyncOpenAI, + guided_decoding_backend: str): completion = await client.completions.create( model=MODEL_NAME, prompt="The best language for type-safe systems programming is ", n=2, temperature=1.0, max_tokens=10, - extra_body=dict(guided_choice=TEST_CHOICE)) + extra_body=dict(guided_choice=TEST_CHOICE, + guided_decoding_backend=guided_decoding_backend)) assert completion.id is not None assert completion.choices is not None and len(completion.choices) == 2 @@ -628,7 +650,10 @@ async def test_guided_choice_completion(server, client: openai.AsyncOpenAI): assert completion.choices[i].text in TEST_CHOICE -async def test_guided_choice_chat(server, client: openai.AsyncOpenAI): +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_choice_chat(server, client: openai.AsyncOpenAI, + guided_decoding_backend: str): messages = [{ "role": "system", "content": "you are a helpful assistant" @@ -642,7 +667,8 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI): model=MODEL_NAME, messages=messages, max_tokens=10, - extra_body=dict(guided_choice=TEST_CHOICE)) + extra_body=dict(guided_choice=TEST_CHOICE, + guided_decoding_backend=guided_decoding_backend)) choice1 = chat_completion.choices[0].message.content assert choice1 in TEST_CHOICE @@ -655,18 +681,23 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI): model=MODEL_NAME, messages=messages, max_tokens=10, - extra_body=dict(guided_choice=TEST_CHOICE)) + extra_body=dict(guided_choice=TEST_CHOICE, + guided_decoding_backend=guided_decoding_backend)) choice2 = chat_completion.choices[0].message.content assert choice2 in TEST_CHOICE assert choice1 != choice2 -async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI): +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI, + guided_decoding_backend: str): with pytest.raises(openai.BadRequestError): _ = await client.completions.create( model=MODEL_NAME, prompt="Give an example JSON that fits this schema: 42", - extra_body=dict(guided_json=42)) + extra_body=dict(guided_json=42, + guided_decoding_backend=guided_decoding_backend)) messages = [{ "role": "system", diff --git a/vllm/config.py b/vllm/config.py index dce2944b2ee8a..bf31b03b7c6c4 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -66,8 +66,8 @@ class ModelConfig: weights. If None, we assume the model weights are not quantized. quantization_param_path: Path to JSON file containing scaling factors. Used to load KV cache scaling factors into the model when KV cache - type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also - be used to load activation and weight scaling factors when the + type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also + be used to load activation and weight scaling factors when the model dtype is FP8_E4M3 on ROCm. enforce_eager: Whether to enforce eager execution. If True, we will disable CUDA graph and always execute the model in eager mode. @@ -422,7 +422,7 @@ class CacheConfig: @dataclass class TokenizerPoolConfig: """Configuration for the tokenizer pool. - + Args: pool_size: Number of tokenizer workers in the pool. pool_type: Type of the pool. @@ -446,9 +446,9 @@ class TokenizerPoolConfig: tokenizer_pool_extra_config: Optional[Union[str, dict]] ) -> Optional["TokenizerPoolConfig"]: """Create a TokenizerPoolConfig from the given parameters. - + If tokenizer_pool_size is 0, return None. - + Args: tokenizer_pool_size: Number of tokenizer workers in the pool. tokenizer_pool_type: Type of the pool. @@ -1079,6 +1079,21 @@ def _get_and_verify_max_len( return int(max_model_len) +@dataclass +class DecodingConfig: + """Dataclass which contains the decoding strategy of the engine""" + + # Which guided decoding algo to use. 'outlines' / 'lm-format-enforcer' + guided_decoding_backend: str = 'outlines' + + def __post_init__(self): + valid_guided_backends = ['outlines', 'lm-format-enforcer'] + backend = self.guided_decoding_backend + if backend not in valid_guided_backends: + raise ValueError(f"Invalid guided_decoding_backend '{backend}," + f"must be one of {valid_guided_backends}") + + @dataclass(frozen=True) class EngineConfig: """Dataclass which contains all engine-related configuration. This @@ -1093,6 +1108,7 @@ class EngineConfig: lora_config: Optional[LoRAConfig] vision_language_config: Optional[VisionLanguageConfig] speculative_config: Optional[SpeculativeConfig] + decoding_config: Optional[DecodingConfig] tensorizer_config: Optional[TensorizerConfig] def __post_init__(self): diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 831a03be65f61..3de74b0ac28b9 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -5,9 +5,9 @@ import os from dataclasses import dataclass from typing import BinaryIO, Optional, Union -from vllm.config import (CacheConfig, DeviceConfig, EngineConfig, LoRAConfig, - ModelConfig, ParallelConfig, SchedulerConfig, - SpeculativeConfig, TensorizerConfig, +from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, + EngineConfig, LoRAConfig, ModelConfig, ParallelConfig, + SchedulerConfig, SpeculativeConfig, TensorizerConfig, TokenizerPoolConfig, VisionLanguageConfig) from vllm.model_executor.tensorizer_loader import TensorizerArgs from vllm.utils import str_to_int_tuple @@ -80,6 +80,7 @@ class EngineArgs: scheduler_delay_factor: float = 0.0 enable_chunked_prefill: bool = False + guided_decoding_backend: str = 'outlines' # Speculative decoding configuration. speculative_model: Optional[str] = None num_speculative_tokens: Optional[int] = None @@ -200,6 +201,13 @@ class EngineArgs: default=EngineArgs.max_model_len, help='model context length. If unspecified, ' 'will be automatically derived from the model.') + parser.add_argument( + '--guided-decoding-backend', + type=str, + default='outlines', + choices=['outlines', 'lm-format-enforcer'], + help='Which engine will be used for guided decoding' + ' (JSON schema / regex etc)') # Parallel arguments parser.add_argument('--worker-use-ray', action='store_true', @@ -511,6 +519,9 @@ class EngineArgs: else: vision_language_config = None + decoding_config = DecodingConfig( + guided_decoding_backend=self.guided_decoding_backend) + return EngineConfig(model_config=model_config, cache_config=cache_config, parallel_config=parallel_config, @@ -519,6 +530,7 @@ class EngineArgs: lora_config=lora_config, vision_language_config=vision_language_config, speculative_config=speculative_config, + decoding_config=decoding_config, tensorizer_config=tensorizer_config) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 8c37c5a9d6ee9..f06c1d18ace4b 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -4,9 +4,10 @@ from typing import Iterable, List, Optional, Tuple, Type, Union from transformers import PreTrainedTokenizer import vllm -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, SpeculativeConfig, - TensorizerConfig, VisionLanguageConfig) +from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig, + SpeculativeConfig, TensorizerConfig, + VisionLanguageConfig) from vllm.core.scheduler import Scheduler, SchedulerOutputs from vllm.engine.arg_utils import EngineArgs from vllm.engine.metrics import StatLogger, Stats @@ -74,6 +75,7 @@ class LLMEngine: lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], speculative_config: Optional[SpeculativeConfig], + decoding_config: Optional[DecodingConfig], tensorizer_config: Optional[TensorizerConfig], executor_class: Type[ExecutorBase], log_stats: bool, @@ -100,6 +102,7 @@ class LLMEngine: f"kv_cache_dtype={cache_config.cache_dtype}, " f"quantization_param_path={model_config.quantization_param_path}, " f"device_config={device_config.device}, " + f"decoding_config={decoding_config!r}, " f"seed={model_config.seed})") # TODO(woosuk): Print more configs in debug mode. @@ -111,6 +114,7 @@ class LLMEngine: self.scheduler_config = scheduler_config self.device_config = device_config self.speculative_config = speculative_config + self.decoding_config = decoding_config or DecodingConfig() self.tensorizer_config = tensorizer_config self.log_stats = log_stats diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index f94d22d279cc4..cf779d44c816b 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -133,6 +133,12 @@ class ChatCompletionRequest(BaseModel): description=( "If specified, the output will follow the context free grammar."), ) + guided_decoding_backend: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default guided decoding backend " + "of the server for this specific request. If set, must be either " + "'outlines' / 'lm-format-enforcer'")) # doc: end-chat-completion-extra-params @@ -265,6 +271,12 @@ class CompletionRequest(BaseModel): description=( "If specified, the output will follow the context free grammar."), ) + guided_decoding_backend: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default guided decoding backend " + "of the server for this specific request. If set, must be one of " + "'outlines' / 'lm-format-enforcer'")) # doc: end-completion-extra-params diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index a03c5dc88108f..c9ed4a9de20f4 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -68,9 +68,13 @@ class OpenAIServingChat(OpenAIServing): request, prompt=prompt) sampling_params = request.to_sampling_params() lora_request = self._maybe_get_lora(request) + decoding_config = self.engine.engine.decoding_config + guided_decoding_backend = request.guided_decoding_backend \ + or decoding_config.guided_decoding_backend guided_decode_logits_processor = ( await get_guided_decoding_logits_processor( - request, await self.engine.get_tokenizer())) + guided_decoding_backend, request, await + self.engine.get_tokenizer())) if guided_decode_logits_processor: if sampling_params.logits_processors is None: sampling_params.logits_processors = [] diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index e24aa2489a80f..a71f2d6a4426a 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -88,9 +88,13 @@ class OpenAIServingCompletion(OpenAIServing): try: sampling_params = request.to_sampling_params() lora_request = self._maybe_get_lora(request) + decoding_config = self.engine.engine.decoding_config + guided_decoding_backend = request.guided_decoding_backend \ + or decoding_config.guided_decoding_backend guided_decode_logit_processor = ( await get_guided_decoding_logits_processor( - request, await self.engine.get_tokenizer())) + guided_decoding_backend, request, await + self.engine.get_tokenizer())) if guided_decode_logit_processor is not None: if sampling_params.logits_processors is None: sampling_params.logits_processors = [] diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py new file mode 100644 index 0000000000000..0558d6c95d97b --- /dev/null +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -0,0 +1,25 @@ +from typing import Optional, Union + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + CompletionRequest) +from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( + get_lm_format_enforcer_guided_decoding_logits_processor) +from vllm.model_executor.guided_decoding.outlines_decoding import ( + get_outlines_guided_decoding_logits_processor) +from vllm.sampling_params import LogitsProcessor + + +async def get_guided_decoding_logits_processor( + guided_decoding_backend: str, request: Union[CompletionRequest, + ChatCompletionRequest], + tokenizer) -> Optional[LogitsProcessor]: + if guided_decoding_backend == 'outlines': + return await get_outlines_guided_decoding_logits_processor( + request, tokenizer) + if guided_decoding_backend == 'lm-format-enforcer': + return await get_lm_format_enforcer_guided_decoding_logits_processor( + request, tokenizer) + + raise ValueError( + f"Unknown guided decoding backend '{guided_decoding_backend}'. " + "Must be one of 'outlines, 'lm-format-enforcer'") diff --git a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py new file mode 100644 index 0000000000000..0d74a5f8e81ff --- /dev/null +++ b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py @@ -0,0 +1,69 @@ +from functools import lru_cache +from json import loads as json_loads +from typing import Optional, Union + +from lmformatenforcer import (CharacterLevelParser, JsonSchemaParser, + RegexParser, StringParser, + TokenEnforcerTokenizerData, UnionParser) +from lmformatenforcer.integrations.vllm import ( + build_vllm_logits_processor, build_vllm_token_enforcer_tokenizer_data) +from pydantic import BaseModel +from transformers import PreTrainedTokenizerBase + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + CompletionRequest) +from vllm.model_executor.guided_decoding.outlines_decoding import ( + get_outlines_guided_decoding_logits_processor) +from vllm.sampling_params import LogitsProcessor + + +async def get_lm_format_enforcer_guided_decoding_logits_processor( + request: Union[CompletionRequest, ChatCompletionRequest], + tokenizer) -> Optional[LogitsProcessor]: + """ + Given an OpenAI-compatible request, check for guided decoding parameters + and get the necessary logits processor for the given guide. + We cache logit processors by (guide, tokenizer), and on cache hit + we make a shallow copy to reuse the same underlying FSM. + """ + + tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data( + tokenizer) + character_level_parser: CharacterLevelParser + if request.guided_json: + schema = _normalize_json_schema_object(request.guided_json) + character_level_parser = JsonSchemaParser(schema) + elif request.guided_choice: + character_level_parser = UnionParser( + [StringParser(choice) for choice in request.guided_choice]) + elif request.guided_regex: + character_level_parser = RegexParser(request.guided_regex) + elif request.guided_grammar: + # CFG grammar not supported by LMFE, revert to outlines + return await get_outlines_guided_decoding_logits_processor( + request, tokenizer) + elif (request.response_format is not None + and request.response_format.type == "json_object"): + character_level_parser = JsonSchemaParser( + None) # None means any json object + else: + return None + + logits_processor = build_vllm_logits_processor(tokenizer_data, + character_level_parser) + return logits_processor + + +def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict: + if isinstance(schema, str): + return json_loads(schema) + if isinstance(schema, dict): + return schema + if isinstance(schema, BaseModel): + return schema.model_json_schema() + + +@lru_cache +def _cached_build_vllm_token_enforcer_tokenizer_data( + tokenizer: PreTrainedTokenizerBase) -> TokenEnforcerTokenizerData: + return build_vllm_token_enforcer_tokenizer_data(tokenizer) diff --git a/vllm/model_executor/guided_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py similarity index 93% rename from vllm/model_executor/guided_decoding.py rename to vllm/model_executor/guided_decoding/outlines_decoding.py index 8e710f1ac2b53..bd4564a36e1ed 100644 --- a/vllm/model_executor/guided_decoding.py +++ b/vllm/model_executor/guided_decoding/outlines_decoding.py @@ -12,9 +12,8 @@ from transformers import PreTrainedTokenizerBase from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, CompletionRequest) -from vllm.model_executor.guided_logits_processors import (CFGLogitsProcessor, - JSONLogitsProcessor, - RegexLogitsProcessor) +from vllm.model_executor.guided_decoding.outlines_logits_processors import ( + CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor) class GuidedDecodingMode(Enum): @@ -54,7 +53,7 @@ pair : UNESCAPED_STRING ":" value global_thread_pool = None # used for generating logits processor fsm -async def get_guided_decoding_logits_processor( +async def get_outlines_guided_decoding_logits_processor( request: Union[CompletionRequest, ChatCompletionRequest], tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor]: """ diff --git a/vllm/model_executor/guided_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py similarity index 70% rename from vllm/model_executor/guided_logits_processors.py rename to vllm/model_executor/guided_decoding/outlines_logits_processors.py index 035fe00037328..28041695546dc 100644 --- a/vllm/model_executor/guided_logits_processors.py +++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py @@ -13,9 +13,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import copy import json import math from collections import defaultdict +from functools import lru_cache from typing import Callable, DefaultDict, Dict, List, Optional, Union import torch @@ -27,50 +29,6 @@ from transformers import PreTrainedTokenizerBase class BaseLogitsProcessor: - def adapt_tokenizer(self, tokenizer: PreTrainedTokenizerBase): - """Adapt vLLM's tokenizer to use to compile the FSM. - - The API of Outlines tokenizers is slightly different to that of - `transformers`. The decoder of outlines, returns a list whereas - the decode of vLLM returns an str. To sync the vLLM decoder with - outlines internal api, the decoder should be adapted. In addition - we need to handle the missing spaces to Llama's tokenizer to be - able to compile FSMs for this model. - - """ - if getattr(tokenizer, "_outlines_adapted", False): - return tokenizer - - tokenizer.vocabulary = tokenizer.get_vocab() - tokenizer.special_tokens = set(tokenizer.all_special_tokens) - - def convert_token_to_string(token: str) -> str: - from transformers.file_utils import SPIECE_UNDERLINE - - string = tokenizer.convert_tokens_to_string([token]) - - # A hack to handle missing spaces to HF's Llama tokenizers - if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>": - return " " + string - - return string - - def change_decoder( - decoder: Callable[[List[int]], str] - ) -> Callable[[List[int]], List[str]]: - """Sync vLLM's decoder with the outlines by returning list.""" - - def new_decoder(inp_tokens: List[int]) -> List[str]: - return [decoder(inp_tokens)] - - return new_decoder - - tokenizer.convert_token_to_string = convert_token_to_string - tokenizer.decode = change_decoder(tokenizer.decode) - setattr(tokenizer, "_outlines_adapted", True) # noqa: B010 - - return tokenizer - def init_state(self): """Initialize the FSM states.""" self.fsm_state: DefaultDict[int, int] = defaultdict(int) @@ -78,7 +36,6 @@ class BaseLogitsProcessor: def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor: """Use the FSM to bias the logits before sampling the next token.""" - seq_id = hash(tuple(input_ids)) if len(input_ids) == 0: @@ -96,7 +53,6 @@ class BaseLogitsProcessor: device=scores.device) mask[allowed_tokens] = 0 scores.add_(mask) - return scores @@ -113,7 +69,7 @@ class RegexLogitsProcessor(BaseLogitsProcessor): The model's tokenizer """ - tokenizer = self.adapt_tokenizer(tokenizer) + tokenizer = _adapt_tokenizer(tokenizer) fsm = RegexFSM(regex_string, tokenizer) self.fsm = fsm @@ -167,6 +123,54 @@ class CFGLogitsProcessor(BaseLogitsProcessor): The model's tokenizer """ - tokenizer = self.adapt_tokenizer(tokenizer) + tokenizer = _adapt_tokenizer(tokenizer) fsm = CFGFSM(cfg, tokenizer) self.fsm = fsm + + +@lru_cache +def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase): + """Adapt vLLM's tokenizer to use to compile the FSM. + + The API of Outlines tokenizers is slightly different to that of + `transformers`. The decoder of outlines, returns a list whereas + the decode of vLLM returns an str. To sync the vLLM decoder with + outlines internal api, the decoder should be adapted. In addition + we need to handle the missing spaces to Llama's tokenizer to be + able to compile FSMs for this model. + + """ + if getattr(tokenizer, "_outlines_adapted", False): + return tokenizer + + tokenizer = copy.deepcopy(tokenizer) + + tokenizer.vocabulary = tokenizer.get_vocab() + tokenizer.special_tokens = set(tokenizer.all_special_tokens) + + def convert_token_to_string(token: str) -> str: + from transformers.file_utils import SPIECE_UNDERLINE + + string = tokenizer.convert_tokens_to_string([token]) + + # A hack to handle missing spaces to HF's Llama tokenizers + if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>": + return " " + string + + return string + + def change_decoder( + decoder: Callable[[List[int]], + str]) -> Callable[[List[int]], List[str]]: + """Sync vLLM's decoder with the outlines by returning list.""" + + def new_decoder(inp_tokens: List[int]) -> List[str]: + return [decoder(inp_tokens)] + + return new_decoder + + tokenizer.convert_token_to_string = convert_token_to_string + tokenizer.decode = change_decoder(tokenizer.decode) + setattr(tokenizer, "_outlines_adapted", True) # noqa: B010 + + return tokenizer