From 4e0db57fff89cc968794650c9b9caf4ccc51b399 Mon Sep 17 00:00:00 2001 From: QiliangCui Date: Wed, 25 Jun 2025 13:48:17 -0700 Subject: [PATCH 001/175] Fix the path to the testing script. (#20082) Signed-off-by: Qiliang Cui --- .buildkite/scripts/tpu/docker_run_bm.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.buildkite/scripts/tpu/docker_run_bm.sh b/.buildkite/scripts/tpu/docker_run_bm.sh index 6705da03e3d76..715afce5f71ab 100755 --- a/.buildkite/scripts/tpu/docker_run_bm.sh +++ b/.buildkite/scripts/tpu/docker_run_bm.sh @@ -68,7 +68,7 @@ docker run \ echo "run script..." echo -docker exec "$CONTAINER_NAME" /bin/bash -c ".buildkite/scripts/hardware_ci/run_bm.sh" +docker exec "$CONTAINER_NAME" /bin/bash -c ".buildkite/scripts/tpu/run_bm.sh" echo "copy result back..." VLLM_LOG="$LOG_ROOT/$TEST_NAME"_vllm_log.txt From 9f0608fc166ba0173dac4a470753464b969c7043 Mon Sep 17 00:00:00 2001 From: zhrrr <43847754+izhuhaoran@users.noreply.github.com> Date: Thu, 26 Jun 2025 05:03:17 +0800 Subject: [PATCH 002/175] [Bugfix] default set cuda_graph_sizes to max_num_seqs for v1 engine (#20062) Signed-off-by: izhuhaoran --- vllm/config.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 96ea47a0dce38..e90ad5e9c8b65 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2042,11 +2042,12 @@ class SchedulerConfig: NOTE: This will be replaced by speculative config in the future; it is present to enable correctness tests until then.""" - cuda_graph_sizes: list[int] = field(default_factory=lambda: [512]) - """Cuda graph capture sizes, default is 512. - 1. if one value is provided, then the capture list would follow the + cuda_graph_sizes: list[int] = field(default_factory=list) + """Cuda graph capture sizes + 1. if none provided, then default set to [max_num_seqs] + 2. if one value is provided, then the capture list would follow the pattern: [1, 2, 4] + [i for i in range(8, cuda_graph_sizes + 1, 8)] - 2. more than one value (e.g. 1 2 128) is provided, then the capture list + 3. more than one value (e.g. 1 2 128) is provided, then the capture list will follow the provided list.""" delay_factor: float = 0.0 @@ -2211,6 +2212,10 @@ class SchedulerConfig: self.max_num_partial_prefills, self.max_long_partial_prefills, self.long_prefill_token_threshold) + # If cuda_graph_sizes is not specified, default set to [max_num_seqs]. + if not self.cuda_graph_sizes: + self.cuda_graph_sizes = [self.max_num_seqs] + @model_validator(mode='after') def _verify_args(self) -> Self: if (self.max_num_batched_tokens < self.max_model_len From 2cc206997012057152f194c0f25e19e3ab3297ea Mon Sep 17 00:00:00 2001 From: Chengji Yao Date: Wed, 25 Jun 2025 14:24:10 -0700 Subject: [PATCH 003/175] [TPU][Bugfix] fix kv cache padding (#20048) Signed-off-by: Chengji Yao --- vllm/v1/attention/backends/pallas.py | 8 +------- vllm/v1/worker/tpu_worker.py | 15 +++++++++++++-- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index 1069578cfd292..e0aeea439794a 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -48,13 +48,7 @@ class PallasAttentionBackend(AttentionBackend): ) -> tuple[int, ...]: padded_head_size = cdiv( head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT - num_blocks = num_blocks * head_size // padded_head_size - if padded_head_size != head_size: - logger.warning_once( - "head size is padded to %d, and num_blocks is adjusted to %d" - " accordingly", padded_head_size, num_blocks) - head_size = padded_head_size - return (num_blocks, block_size, num_kv_heads * 2, head_size) + return (num_blocks, block_size, num_kv_heads * 2, padded_head_size) @staticmethod def swap_blocks( diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index 87af8e476707c..a64ce881fe318 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -18,7 +18,8 @@ from vllm.distributed import (ensure_model_parallel_initialized, from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv +from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig, KVCacheSpec) @@ -221,7 +222,17 @@ class TPUWorker: usable_memory_size = int(total_memory_size * self.cache_config.gpu_memory_utilization) tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0) - + head_size = self.model_config.get_head_size() + if head_size > 0: + padded_head_size = cdiv( + head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT + if padded_head_size != head_size: + logger.warning_once("head size is padded to %d", + padded_head_size) + # We adjust the usable memory size for the KV cache to prevent OOM + # errors, even after padding the head_size. + tpu_kv_cache_bytes = (tpu_kv_cache_bytes * head_size // + padded_head_size) return int(tpu_kv_cache_bytes) def execute_model( From 55c65ab495f5d270f65f89dcc737e9694b278002 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Wed, 25 Jun 2025 15:19:44 -0700 Subject: [PATCH 004/175] [P/D] Avoid stranding blocks in P when aborted in D's waiting queue (#19223) Signed-off-by: Nick Hill --- .../kv_transfer/kv_connector/v1/nixl_connector.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index a962a9241d73e..92a9184d318c7 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -298,8 +298,21 @@ class NixlConnectorScheduler: logger.debug( "NIXLConnector request_finished, request_status=%s, " "kv_transfer_params=%s", request.status, params) + if not params: + return False, None - if (params is None or not params.get("do_remote_decode") + if params.get("do_remote_prefill"): + # If do_remote_prefill is still True when the request is finished, + # update_state_after_alloc must not have been called (the request + # must have been aborted before it was scheduled). + # To avoid stranding the prefill blocks in the prefill instance, + # we must add empty block_ids to _reqs_need_recv so that our + # worker side will notify and free blocks in the prefill instance. + self._reqs_need_recv[request.request_id] = (request, []) + params["do_remote_prefill"] = False + return False, None + + if (not params.get("do_remote_decode") or request.status != RequestStatus.FINISHED_LENGTH_CAPPED): return False, None From 2d7620c3ebb3a3e0e600dd2781d7e5dfbd1c2382 Mon Sep 17 00:00:00 2001 From: Chenyaaang <42742451+Chenyaaang@users.noreply.github.com> Date: Wed, 25 Jun 2025 15:51:02 -0700 Subject: [PATCH 005/175] [TPU] Add TPU specific var VLLM_TPU_MOST_MODEL_LEN (#19919) Signed-off-by: Chenyaaang --- tests/v1/tpu/worker/test_tpu_model_runner.py | 14 ++ vllm/envs.py | 3 + vllm/platforms/tpu.py | 10 - vllm/v1/attention/backends/pallas.py | 5 + vllm/v1/worker/tpu_model_runner.py | 228 +++++++++++++------ 5 files changed, 184 insertions(+), 76 deletions(-) diff --git a/tests/v1/tpu/worker/test_tpu_model_runner.py b/tests/v1/tpu/worker/test_tpu_model_runner.py index d22ddf5c7e581..25839d0897a4c 100644 --- a/tests/v1/tpu/worker/test_tpu_model_runner.py +++ b/tests/v1/tpu/worker/test_tpu_model_runner.py @@ -587,3 +587,17 @@ def test_init_kv_cache_with_kv_sharing_valid(): assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2 assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0 assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1 + + +def test_most_model_len(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("VLLM_TPU_MOST_MODEL_LEN", "2048") + vllm_config = get_vllm_config() + vllm_config.model_config.max_model_len = 32000 + vllm_config.scheduler_config.max_num_seqs = 1200 + model_runner = get_model_runner(vllm_config) + + # verify model runner will adjust num_reqs to avoid SMEM OOM. + assert model_runner.num_reqs_most_model_len == 1200 + # num_page_per_req = 32k // 128 + # num_reqs = 1024 ** 2 // 2 // num_page_per_req // 4 = 524 + assert model_runner.num_reqs_max_model_len == 524 diff --git a/vllm/envs.py b/vllm/envs.py index 43fc0f5a36e83..c9c81603a75a8 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -119,6 +119,7 @@ if TYPE_CHECKING: VLLM_MARLIN_USE_ATOMIC_ADD: bool = False VLLM_V0_USE_OUTLINES_CACHE: bool = False VLLM_TPU_BUCKET_PADDING_GAP: int = 0 + VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None VLLM_USE_DEEP_GEMM: bool = False VLLM_XGRAMMAR_CACHE_MB: int = 0 VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256 @@ -833,6 +834,8 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_TPU_BUCKET_PADDING_GAP": lambda: int(os.environ["VLLM_TPU_BUCKET_PADDING_GAP"]) if "VLLM_TPU_BUCKET_PADDING_GAP" in os.environ else 0, + "VLLM_TPU_MOST_MODEL_LEN": + lambda: maybe_convert_int(os.environ.get("VLLM_TPU_MOST_MODEL_LEN", None)), # Allow use of DeepGemm kernels for fused moe ops. "VLLM_USE_DEEP_GEMM": diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 07e52017f5a53..0387e348965d7 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -122,16 +122,6 @@ class TpuPlatform(Platform): PallasAttentionBackend) cache_config.block_size = PallasAttentionBackend.get_page_size( vllm_config) # type: ignore[assignment] - min_page_size = PallasAttentionBackend.get_min_page_size( - vllm_config) - if min_page_size > cache_config.block_size: - logger.warning( - "Increase the page size from %s to %s to make sure there's" - "no SMEM OOM", - cache_config.block_size, - min_page_size, - ) - cache_config.block_size = min_page_size # type: ignore[assignment] parallel_config = vllm_config.parallel_config scheduler_config = vllm_config.scheduler_config diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index e0aeea439794a..ff2862edaa01b 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -71,6 +71,11 @@ class PallasAttentionBackend(AttentionBackend): min_page_size = 1 << (min_page_size - 1).bit_length() return min_page_size + @staticmethod + def get_max_num_seqs(model_len: int, page_size: int) -> int: + num_page_per_req = cdiv(model_len, page_size) + return 1024 * 1024 // 2 // num_page_per_req // 4 + # TPU has limited SREGs (scalar registers), if page_size is too small, we # can spill SREGs easily which leads to bad performance. The strategy we # apply here is trying to split max-model-len to 16 pages which make the diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 774caa1a3d98f..2d80bac3c9546 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -37,8 +37,8 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec, KVCacheConfig, KVCacheSpec, SlidingWindowSpec) -from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, - ModelRunnerOutput) +from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsLists, + LogprobsTensors, ModelRunnerOutput) from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler from vllm.v1.utils import bind_kv_cache @@ -150,7 +150,11 @@ class TPUModelRunner(LoRAModelRunnerMixin): self.sliding_window = model_config.get_sliding_window() self.block_size = cache_config.block_size self.max_model_len = model_config.max_model_len + self.most_model_len = envs.VLLM_TPU_MOST_MODEL_LEN self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) + self.num_blocks_per_most_len_req = cdiv( + self.most_model_len, + self.block_size) if self.most_model_len is not None else None # InputBatch needs to work with sampling tensors greater than padding # to avoid dynamic shapes. Also, avoid suboptimal alignment. self.max_num_reqs = max(scheduler_config.max_num_seqs, MIN_NUM_SEQS) @@ -220,12 +224,19 @@ class TPUModelRunner(LoRAModelRunnerMixin): dtype=torch.int32, device="cpu") self.positions_np = self.positions_cpu.numpy() - self.block_table_cpu = torch.zeros( (self.max_num_reqs, self.max_num_blocks_per_req), dtype=torch.int32, device="cpu") - + # adjust num_reqs to avoid SMEM OOM. + self.num_reqs_most_model_len = min( + PallasAttentionBackend.get_max_num_seqs(self.most_model_len, + self.block_size), + self.max_num_reqs) if self.most_model_len is not None else None + self.num_reqs_max_model_len = min( + PallasAttentionBackend.get_max_num_seqs(self.max_model_len, + self.block_size), + self.max_num_reqs) self.query_start_loc_cpu = torch.zeros(self.max_num_tokens + 1, dtype=torch.int32, device="cpu", @@ -515,25 +526,50 @@ class TPUModelRunner(LoRAModelRunnerMixin): return kv_cache_spec - def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): - total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - assert total_num_scheduled_tokens > 0 + def _prepare_inputs(self, scheduler_output: "SchedulerOutput", + start_index: int): + assert scheduler_output.total_num_scheduled_tokens > 0 num_reqs = self.input_batch.num_reqs assert num_reqs > 0 + assert start_index < num_reqs # Get the number of scheduled tokens for each request. + use_max_model_len = self.most_model_len is None num_scheduled_tokens_per_req = [] max_num_scheduled_tokens_all_reqs = 0 - for req_id in self.input_batch.req_ids[:num_reqs]: + end_index = start_index + + # Use either most_model_len or max_model_len depending on request size. + for i in range(start_index, num_reqs): + req_id = self.input_batch.req_ids[i] assert req_id is not None num_tokens = scheduler_output.num_scheduled_tokens[req_id] + if not use_max_model_len and num_tokens > self.most_model_len: + use_max_model_len = True num_scheduled_tokens_per_req.append(num_tokens) - max_num_scheduled_tokens_all_reqs = max( - max_num_scheduled_tokens_all_reqs, num_tokens) + if use_max_model_len: + if len(num_scheduled_tokens_per_req) > self.num_reqs_max_model_len: + num_scheduled_tokens_per_req = \ + num_scheduled_tokens_per_req[:self.num_reqs_max_model_len] + end_index = start_index + self.num_reqs_max_model_len + else: + end_index = num_reqs + else: + if len(num_scheduled_tokens_per_req + ) > self.num_reqs_most_model_len: + num_scheduled_tokens_per_req = \ + num_scheduled_tokens_per_req[:self.num_reqs_most_model_len] + end_index = start_index + self.num_reqs_most_model_len + else: + end_index = num_reqs + max_num_scheduled_tokens_all_reqs = max(num_scheduled_tokens_per_req) num_scheduled_tokens_per_req = np.array(num_scheduled_tokens_per_req, dtype=np.int32) + total_num_scheduled_tokens = sum(num_scheduled_tokens_per_req) assert max_num_scheduled_tokens_all_reqs > 0 + num_reqs = len(num_scheduled_tokens_per_req) + # Get request indices. # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] # For each scheduled token, what are the corresponding req index. @@ -615,13 +651,29 @@ class TPUModelRunner(LoRAModelRunnerMixin): self.input_batch.block_table[0]. slot_mapping_cpu[:padded_total_num_scheduled_tokens].to( self.device)) - block_tables = self.block_table_cpu[:self.max_num_reqs] - block_tables[:num_reqs, :self.max_num_blocks_per_req] = ( - self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs]) + if use_max_model_len: + block_tables = self.block_table_cpu[:self.num_reqs_max_model_len, : + self.max_num_blocks_per_req] + block_tables[:num_reqs, :self.max_num_blocks_per_req] = ( + self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs]) + query_start_loc = self.query_start_loc_cpu[:self. + num_reqs_max_model_len + + 1].to(self.device) + seq_lens = self.seq_lens_cpu[:self.num_reqs_max_model_len].to( + self.device) + else: + block_tables = self.block_table_cpu[:self. + num_reqs_most_model_len, :self. + num_blocks_per_most_len_req] + block_tables[:num_reqs, :self.num_blocks_per_most_len_req] = ( + self.input_batch.block_table[0].get_cpu_tensor() + [:num_reqs, :self.num_blocks_per_most_len_req]) + query_start_loc = self.query_start_loc_cpu[:self. + num_reqs_most_model_len + + 1].to(self.device) + seq_lens = self.seq_lens_cpu[:self.num_reqs_most_model_len].to( + self.device) block_tables = block_tables.to(self.device) - query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1].to( - self.device) - seq_lens = self.seq_lens_cpu[:self.max_num_reqs].to(self.device) if self.lora_config is not None: # We need to respect padding when activating LoRA adapters @@ -672,7 +724,8 @@ class TPUModelRunner(LoRAModelRunnerMixin): layer_name: attn_metadata for layer_name in layer_names } - return per_layer_attn_metadata, logits_indices, padded_num_reqs + return per_layer_attn_metadata, logits_indices, padded_num_reqs,\ + num_reqs, end_index def _scatter_placeholders( self, @@ -847,52 +900,84 @@ class TPUModelRunner(LoRAModelRunnerMixin): else: mm_embeds = [] xm.mark_step() - # Prepare inputs - attn_metadata, logits_indices, padded_num_reqs = self._prepare_inputs( - scheduler_output) - input_ids, inputs_embeds = self._get_model_inputs( - self.input_ids, mm_embeds) - xm.mark_step() - num_reqs = self.input_batch.num_reqs - # Run the decoder - with set_forward_context( - attn_metadata, - self.vllm_config, - num_tokens=scheduler_output.total_num_scheduled_tokens): - hidden_states = self.model( - input_ids=input_ids, - positions=self.position_ids, - inputs_embeds=inputs_embeds, - ) - hidden_states = self.select_hidden_states(hidden_states, - logits_indices) - logits = self.compute_logits(hidden_states) - tpu_sampling_metadata = TPUSupportedSamplingMetadata.\ - from_input_batch(self.input_batch, padded_num_reqs, self.device) - if scheduler_output.grammar_bitmask is not None: - require_struct_decoding, grammar_bitmask_padded, arange = \ - self.prepare_structured_decoding_input(logits, scheduler_output) - logits = self.structured_decode(require_struct_decoding, - grammar_bitmask_padded, logits, - arange) - selected_token_ids = self.sample_from_logits_func( - logits, tpu_sampling_metadata) - # NOTE (NickLucche) Use the original logits (before any penalties or - # temperature scaling) for the top-k logprobs. We can't enforce it due - # to recompilations outside torch.compiled code, so just make sure - # `sample_from_logits` does not modify the logits in-place. - logprobs = self.gather_logprobs(logits, selected_token_ids) \ - if tpu_sampling_metadata.logprobs else None + # Prepare inputs, the requests might be splitted into multiple + # executions, combine the result of each execution. + start_index = 0 + combined_selected_tokens: list[torch.Tensor] = [] + combined_logprobs: list[LogprobsLists] = [] + while start_index < self.input_batch.num_reqs: + attn_metadata, logits_indices, padded_num_reqs, num_reqs,\ + end_index = self._prepare_inputs(scheduler_output, start_index) + input_ids, inputs_embeds = self._get_model_inputs( + self.input_ids, mm_embeds) + xm.mark_step() + # Run the decoder + with set_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=scheduler_output.total_num_scheduled_tokens): + hidden_states = self.model( + input_ids=input_ids, + positions=self.position_ids, + inputs_embeds=inputs_embeds, + ) + hidden_states = self.select_hidden_states(hidden_states, + logits_indices) + logits = self.compute_logits(hidden_states) + tpu_sampling_metadata = TPUSupportedSamplingMetadata.\ + from_input_batch(self.input_batch, padded_num_reqs, self.device) + if scheduler_output.grammar_bitmask is not None: + require_struct_decoding, grammar_bitmask_padded, arange = \ + self.prepare_structured_decoding_input(logits, + scheduler_output) + logits = self.structured_decode(require_struct_decoding, + grammar_bitmask_padded, logits, + arange) + selected_token_ids = self.sample_from_logits_func( + logits, tpu_sampling_metadata) + # NOTE (NickLucche) Use the original logits (before any penalties or + # temperature scaling) for the top-k logprobs. We can't enforce it + # due to recompilations outside torch.compiled code, so just make + # sure `sample_from_logits` does not modify the logits in-place. + logprobs = self.gather_logprobs(logits, selected_token_ids) \ + if tpu_sampling_metadata.logprobs else None - # Remove padding on cpu and keep dynamic op outside of xla graph. - selected_token_ids = selected_token_ids.cpu()[:num_reqs] - logprobs_lists = logprobs.tolists() \ - if tpu_sampling_metadata.logprobs else None + # Remove padding on cpu and keep dynamic op outside of xla graph. + selected_token_ids = selected_token_ids.cpu()[:num_reqs] + + combined_selected_tokens.append(selected_token_ids) + if tpu_sampling_metadata.logprobs: + combined_logprobs.append(logprobs.tolists()) + + start_index = end_index + + selected_token_ids = torch.cat(combined_selected_tokens, dim=0) + if tpu_sampling_metadata.logprobs: + + def concat_lists(input_lists): + result = [] + for input_list in input_lists: + result.extend(input_list) + return result + + logprobs_lists = LogprobsLists(logprob_token_ids=concat_lists( + [lp.logprob_token_ids for lp in combined_logprobs]), + logprobs=concat_lists([ + lp.logprobs + for lp in combined_logprobs + ]), + sampled_token_ranks=concat_lists([ + lp.sampled_token_ranks + for lp in combined_logprobs + ])) + else: + logprobs_lists = None # Update the cache state concurrently. Code above will not block until # we use `selected_token_ids`. Add mark_step if post-processing changes request_seq_lens: list[tuple[int, CachedRequestState, int]] = [] discard_sampled_tokens_req_indices = [] + num_reqs = self.input_batch.num_reqs for i, req_id in zip(range(num_reqs), self.input_batch.req_ids): assert req_id is not None req_state = self.requests[req_id] @@ -1020,7 +1105,8 @@ class TPUModelRunner(LoRAModelRunnerMixin): self.sampler = TPUSampler() @torch.no_grad() - def _dummy_run(self, num_tokens: int) -> None: + def _dummy_run(self, num_tokens: int, num_reqs: int, + num_blocks: int) -> None: if self.is_multimodal_model: input_ids = None inputs_embeds = torch.zeros((num_tokens, self.hidden_size), @@ -1030,20 +1116,19 @@ class TPUModelRunner(LoRAModelRunnerMixin): input_ids = torch.zeros((num_tokens), dtype=torch.int32).to(self.device) inputs_embeds = None - actual_num_reqs = min(num_tokens, self.max_num_reqs) + actual_num_reqs = min(num_tokens, num_reqs) position_ids = torch.zeros(num_tokens, dtype=torch.int32).to(self.device) slot_mapping = torch.zeros(num_tokens, dtype=torch.int64).to(self.device) - block_tables = torch.zeros( - (self.max_num_reqs, self.block_table_cpu.shape[1]), - dtype=torch.int32).to(self.device) - query_lens = [1] * self.max_num_reqs + block_tables = torch.zeros((num_reqs, num_blocks), + dtype=torch.int32).to(self.device) + query_lens = [1] * num_reqs query_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32), dim=0, dtype=torch.int32).to(self.device) - context_lens = torch.ones((self.max_num_reqs, ), + context_lens = torch.ones((num_reqs, ), dtype=torch.int32).to(self.device) num_seqs = torch.tensor([actual_num_reqs], dtype=torch.int32).to(self.device) @@ -1061,6 +1146,9 @@ class TPUModelRunner(LoRAModelRunnerMixin): torch._dynamo.mark_dynamic(input_ids, 0) torch._dynamo.mark_dynamic(position_ids, 0) torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) + torch._dynamo.mark_dynamic(attn_metadata.block_tables, (0, 1)) + torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0) + torch._dynamo.mark_dynamic(attn_metadata.query_start_loc, 0) layer_names = get_layers_from_vllm_config(self.vllm_config, Attention).keys() @@ -1152,7 +1240,11 @@ class TPUModelRunner(LoRAModelRunnerMixin): start = time.perf_counter() for num_tokens in self.num_tokens_paddings: logger.info(" -- num_tokens: %d", num_tokens) - self._dummy_run(num_tokens) + self._dummy_run(num_tokens, self.num_reqs_max_model_len, + self.max_num_blocks_per_req) + if self.most_model_len is not None: + self._dummy_run(num_tokens, self.num_reqs_most_model_len, + self.num_blocks_per_most_len_req) xm.wait_device_ops() end = time.perf_counter() logger.info("Compilation finished in %.2f [secs].", end - start) @@ -1341,7 +1433,11 @@ class TPUModelRunner(LoRAModelRunnerMixin): self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) # Trigger compilation for general shape. - self._dummy_run(num_tokens) + self._dummy_run(num_tokens, self.num_reqs_max_model_len, + self.max_num_blocks_per_req) + if self.most_model_len is not None: + self._dummy_run(num_tokens, self.num_reqs_most_model_len, + self.num_blocks_per_most_len_req) xm.mark_step() xm.wait_device_ops() From 296ce95d8e72f4c6680bda539058f48dbe0f340a Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Thu, 26 Jun 2025 08:23:56 +0900 Subject: [PATCH 006/175] [CI] Add SM120 to the Dockerfile (#19794) Signed-off-by: mgoin --- docker/Dockerfile | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index cf9c245a95174..8d4375470adf9 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -77,7 +77,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ # can be useful for both `dev` and `test` # explicitly set the list to avoid issues with torch 2.2 # see https://github.com/pytorch/pytorch/pull/123243 -ARG torch_cuda_arch_list='7.0 7.5 8.0 8.9 9.0 10.0+PTX' +ARG torch_cuda_arch_list='7.0 7.5 8.0 8.9 9.0 10.0 12.0' ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list} # Override the arch list for flash-attn to reduce the binary size ARG vllm_fa_cmake_gpu_arches='80-real;90-real' @@ -244,7 +244,7 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist # If we need to build FlashInfer wheel before its release: # $ # Note we remove 7.0 from the arch list compared to the list below, since FlashInfer only supports sm75+ -# $ export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0a 10.0a' +# $ export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0a 10.0a 12.0' # $ git clone https://github.com/flashinfer-ai/flashinfer.git --recursive # $ cd flashinfer # $ git checkout v0.2.6.post1 @@ -261,7 +261,7 @@ if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \ if [[ "$CUDA_VERSION" == 12.8* ]]; then \ uv pip install --system https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl; \ else \ - export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0a 10.0a' && \ + export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0a 10.0a 12.0' && \ git clone https://github.com/flashinfer-ai/flashinfer.git --single-branch --branch v0.2.6.post1 --recursive && \ # Needed to build AOT kernels (cd flashinfer && \ From 754b00edb3fd2642da08c40363a07f1d60a54977 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Thu, 26 Jun 2025 10:01:17 +0900 Subject: [PATCH 007/175] [Bugfix] Fix Mistral tool-parser regex for nested JSON (#20093) Signed-off-by: mgoin --- .../language/generation/test_mistral.py | 51 +++++++++++++++++++ .../tool_parsers/mistral_tool_parser.py | 4 +- 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/tests/models/language/generation/test_mistral.py b/tests/models/language/generation/test_mistral.py index bdd857ff50620..c70698ede37a5 100644 --- a/tests/models/language/generation/test_mistral.py +++ b/tests/models/language/generation/test_mistral.py @@ -10,6 +10,7 @@ import pytest from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( MistralToolCall, MistralToolParser) from vllm.sampling_params import GuidedDecodingParams, SamplingParams +from vllm.transformers_utils.tokenizer import MistralTokenizer from ...utils import check_logprobs_close @@ -318,3 +319,53 @@ def test_mistral_guided_decoding( schema=SAMPLE_JSON_SCHEMA) except jsonschema.exceptions.ValidationError: pytest.fail("Generated response is not valid with JSON schema") + + +def test_mistral_function_call_nested_json(): + """Ensure that the function-name regex captures the entire outer-most + JSON block, including nested braces.""" + + # Create a minimal stub tokenizer that provides the few attributes the + # parser accesses (`version` and `get_vocab`). + class _StubMistralTokenizer(MistralTokenizer): + version = 11 # Satisfy the version check + + def __init__(self): + pass + + @staticmethod + def get_vocab(): + # Provide the special TOOL_CALLS token expected by the parser. + return {"[TOOL_CALLS]": 0} + + tokenizer = _StubMistralTokenizer() + parser = MistralToolParser(tokenizer) + + # Craft a model output featuring nested JSON inside the arguments. + args_dict = { + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit", + "sub_dict": { + "foo": "bar", + "inner": { + "x": 1, + "y": 2 + } + }, + } + + model_output = ( + f"{parser.bot_token}get_current_weather{json.dumps(args_dict)}") + + parsed = parser.extract_tool_calls(model_output, None) + + # Assertions: the tool call is detected and the full nested JSON is parsed + # without truncation. + assert parsed.tools_called + + assert MistralToolCall.is_valid_id(parsed.tool_calls[0].id) + assert parsed.tool_calls[0].function.name == "get_current_weather" + assert json.loads(parsed.tool_calls[0].function.arguments) == args_dict + # No additional content outside the tool call should be returned. + assert parsed.content is None diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index ab1cfd4b6eabe..c0691f122904e 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -77,8 +77,8 @@ class MistralToolParser(ToolParser): self.bot_token_id = self.vocab.get(self.bot_token) self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL) if _is_fn_name_regex_support(self.model_tokenizer): - self.fn_name_regex = re.compile(r'([a-zA-Z0-9_-]+)(\{.*?\})', - re.DOTALL) + self.fn_name_regex = re.compile( + r'([a-zA-Z0-9_-]+)(\{[\s\S]*?\})(?=\s*$|,|\s)', re.DOTALL) else: self.fn_name_regex = None From 2582683566ed676a811f4311f1048f0b323676b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicol=C3=B2=20Lucchesi?= Date: Thu, 26 Jun 2025 05:04:39 +0200 Subject: [PATCH 008/175] [PD] Skip `tp_size` exchange with rank0 (#19413) Signed-off-by: NickLucche --- .../kv_connector/unit/test_nixl_connector.py | 29 ++++- .../kv_connector/v1/nixl_connector.py | 109 ++++++++---------- 2 files changed, 72 insertions(+), 66 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index ab9729aae2e9f..e30a250449aaa 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -7,6 +7,8 @@ from collections import defaultdict from typing import Optional from unittest.mock import patch +import pytest + from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata, NixlConnectorWorker) @@ -161,7 +163,8 @@ class FakeNixlConnectorWorker(NixlConnectorWorker): super().__init__(*args, **kwargs) self._hand_shake_latency = hand_shake_latency - def _nixl_handshake(self, host: str, port: int) -> dict[int, str]: + def _nixl_handshake(self, host: str, port: int, + remote_tp_size: int) -> dict[int, str]: # Mimic slow _nixl_handshake, as well as bypass zmq communication. time.sleep(self._hand_shake_latency) # These should've been done in register_kv_caches(), called by @@ -177,10 +180,10 @@ class FakeNixlConnectorWorker(NixlConnectorWorker): agent_metadata=FakeNixlWrapper.AGENT_METADATA, kv_caches_base_addr=[0], num_blocks=1, - tp_size=1, block_len=self.block_len, attn_backend_name=self.backend_name, - )) + ), + remote_tp_size=remote_tp_size) return {0: remote_agent_name} @@ -233,6 +236,8 @@ class TestNixlHandshake: "localhost", "remote_port": 1234, + "remote_tp_size": + 1, }) connector.bind_connector_metadata(metadata) @@ -259,13 +264,23 @@ class TestNixlHandshake: @patch( "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", FakeNixlWrapper) + @pytest.mark.parametrize("decode_tp_size, prefill_tp_size", [ + (1, 1), + (2, 1), + (4, 2), + (4, 4), + ]) def test_async_load_kv( - self, - # dist_init is a fixture that initializes the distributed environment. - dist_init): + self, + # Fixture that initializes the distributed environment. + dist_init, + # Simulate consumer-producer TP sizes. + decode_tp_size, + prefill_tp_size): """Test that NixlConnector's start_load_kv should be non-blocking.""" vllm_config = create_vllm_config() + vllm_config.parallel_config.tensor_parallel_size = decode_tp_size # Test worker role in decode server. connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) @@ -280,6 +295,7 @@ class TestNixlHandshake: FakeNixlConnectorWorker.REMOTE_ENGINE_ID, "remote_host": "localhost", "remote_port": 1234, + "remote_tp_size": prefill_tp_size, }) connector.bind_connector_metadata(metadata) @@ -329,6 +345,7 @@ class TestNixlHandshake: FakeNixlConnectorWorker.REMOTE_ENGINE_ID, "remote_host": "localhost", "remote_port": 1234, + "remote_tp_size": 1, }) connector.bind_connector_metadata(metadata) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 92a9184d318c7..7a077dce7706c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -62,7 +62,6 @@ class NixlAgentMetadata( agent_metadata: bytes kv_caches_base_addr: list[int] num_blocks: int - tp_size: int block_len: int attn_backend_name: str @@ -73,7 +72,8 @@ class ReqMeta: remote_block_ids: list[int] remote_host: str remote_port: int - remote_engine_id: EngineId + remote_engine_id: str + tp_size: int class NixlConnectorMetadata(KVConnectorMetadata): @@ -93,6 +93,8 @@ class NixlConnectorMetadata(KVConnectorMetadata): remote_engine_id=kv_transfer_params["remote_engine_id"], remote_host=kv_transfer_params["remote_host"], remote_port=kv_transfer_params["remote_port"], + # P workers don't need to receive tp_size from proxy here. + tp_size=kv_transfer_params.get("tp_size", 1), ) @@ -330,7 +332,7 @@ class NixlConnectorScheduler: remote_engine_id=self.engine_id, remote_host=self.side_channel_host, remote_port=self.side_channel_port, - ) + tp_size=self.vllm_config.parallel_config.tensor_parallel_size) class NixlConnectorWorker: @@ -473,7 +475,8 @@ class NixlConnectorWorker: "Connection listener got unexpected message %s", msg) sock.send_multipart((identity, b"", encoded_data)) - def _nixl_handshake(self, host: str, port: int) -> dict[int, str]: + def _nixl_handshake(self, host: str, port: int, + remote_tp_size: int) -> dict[int, str]: """Do a NIXL handshake with a remote instance.""" start_time = time.perf_counter() @@ -482,7 +485,7 @@ class NixlConnectorWorker: # a hack to keep us moving. We will switch when moving to etcd # or where we have a single ZMQ socket in the scheduler. - def handshake(path: str, rank: int) -> tuple[NixlAgentMetadata, str]: + def handshake(path: str, rank: int) -> str: # Send query for the request. with zmq_ctx(zmq.REQ, path) as sock: sock.send(GET_META_MSG) @@ -492,33 +495,25 @@ class NixlConnectorWorker: got_metadata_time = time.perf_counter() # Register Remote agent. - remote_agent_name = self.add_remote_agent(metadata, rank) + remote_agent_name = self.add_remote_agent( + metadata, rank, remote_tp_size) setup_agent_time = time.perf_counter() logger.debug("NIXL handshake: get metadata took: %s", got_metadata_time - start_time) logger.debug("NIXL handshake: add agent took: %s", setup_agent_time - got_metadata_time) - return metadata, remote_agent_name + return remote_agent_name - # Handshake with remote agent-rank0 first to get the tp_size of remote - path = make_zmq_path("tcp", host, port) - logger.debug("Querying master rank metadata on path: %s", path) - rank_to_agent_name: dict[int, str] = {} - metadata, rank_to_agent_name[0] = handshake(path, 0) - - # Handshake only with the other TP remote the current local rank will + # Handshake only with the remote TP rank that current local rank will # pull from. With homogeneous TP it happens to be the same rank_i. - tp_ratio = self._tp_size[self.engine_id] // metadata.tp_size + tp_ratio = self._tp_size[self.engine_id] // remote_tp_size p_remote_rank = self.tp_rank // tp_ratio - if p_remote_rank > 0: - path = make_zmq_path("tcp", host, port + p_remote_rank) - logger.debug("Querying metadata on path: %s at remote rank %s", - path, p_remote_rank) - _, rank_to_agent_name[p_remote_rank] = handshake( - path, p_remote_rank) - - return rank_to_agent_name + path = make_zmq_path("tcp", host, port + p_remote_rank) + logger.debug("Querying metadata on path: %s at remote rank %s", path, + p_remote_rank) + # Remote rank -> agent name. + return {p_remote_rank: handshake(path, p_remote_rank)} def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """Register the KV Cache data in nixl.""" @@ -645,7 +640,6 @@ class NixlConnectorWorker: agent_metadata=self.nixl_wrapper.get_agent_metadata(), kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id], num_blocks=self.num_blocks, - tp_size=self.world_size, block_len=self.block_len, attn_backend_name=self.backend_name) ready_event = threading.Event() @@ -659,7 +653,8 @@ class NixlConnectorWorker: def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, - remote_tp_rank: int = 0) -> str: + remote_tp_rank: int = 0, + remote_tp_size: int = 1) -> str: """ Add the remote NIXL agent and prepare the descriptors for reading cache blocks from remote. @@ -704,9 +699,9 @@ class NixlConnectorWorker: return self._remote_agents[engine_id][remote_tp_rank] if engine_id in self._tp_size: - assert self._tp_size[engine_id] == nixl_agent_meta.tp_size + assert self._tp_size[engine_id] == remote_tp_size else: - self._tp_size[engine_id] = nixl_agent_meta.tp_size + self._tp_size[engine_id] = remote_tp_size # We may eventually enable this after asserting equality in cache # layout and close outputs. assert nixl_agent_meta.attn_backend_name == self.backend_name @@ -756,33 +751,31 @@ class NixlConnectorWorker: # rank. With heterogeneous TP, prepare the descriptors by splitting the # P KV cache along kv_head dim, of D worker's kv_head size (D>P). # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..]. - p_remote_tp_rank = self.tp_rank // tp_ratio # Only register the remote's descriptors if current rank pulls from it. - if p_remote_tp_rank == remote_tp_rank: - self.kv_caches_base_addr[ - engine_id] = nixl_agent_meta.kv_caches_base_addr - rank_offset = self.tp_rank % tp_ratio * self.block_len \ - if not (self.use_mla or is_kv_replicated) else 0 - # Register all remote blocks, but only the corresponding kv heads. - for base_addr in nixl_agent_meta.kv_caches_base_addr: - for block_id in range(nixl_agent_meta.num_blocks): - block_offset = block_id * nixl_agent_meta.block_len - # For each block, grab the heads chunk belonging to rank_i - # of size remote_nheads // tp_ratio, which correspond to - # self.block_len == remote_block_len//tp_ratio bytes. - addr = base_addr + block_offset + rank_offset - # (addr, len, device id) - blocks_data.append((addr, self.block_len, remote_tp_rank)) - logger.debug( - "Created %s blocks for dst engine %s with remote rank %s and " - "local rank %s", len(blocks_data), engine_id, remote_tp_rank, - self.tp_rank) + self.kv_caches_base_addr[ + engine_id] = nixl_agent_meta.kv_caches_base_addr + rank_offset = self.tp_rank % tp_ratio * self.block_len \ + if not (self.use_mla or is_kv_replicated) else 0 + # Register all remote blocks, but only the corresponding kv heads. + for base_addr in nixl_agent_meta.kv_caches_base_addr: + for block_id in range(nixl_agent_meta.num_blocks): + block_offset = block_id * nixl_agent_meta.block_len + # For each block, grab the heads chunk belonging to rank_i + # of size remote_nheads // tp_ratio, which correspond to + # self.block_len == remote_block_len//tp_ratio bytes. + addr = base_addr + block_offset + rank_offset + # (addr, len, device id) + blocks_data.append((addr, self.block_len, remote_tp_rank)) + logger.debug( + "Created %s blocks for dst engine %s with remote rank %s and " + "local rank %s", len(blocks_data), engine_id, remote_tp_rank, + self.tp_rank) - # Register with NIXL. - descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") - self.dst_xfer_side_handles[ - engine_id] = self.nixl_wrapper.prep_xfer_dlist( - remote_agent_name, descs) + # Register with NIXL. + descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") + self.dst_xfer_side_handles[ + engine_id] = self.nixl_wrapper.prep_xfer_dlist( + remote_agent_name, descs) return remote_agent_name @@ -917,7 +910,7 @@ class NixlConnectorWorker: if fut is None: fut = self._handshake_initiation_executor.submit( self._nixl_handshake, meta.remote_host, - meta.remote_port) + meta.remote_port, meta.tp_size) self._handshake_futures[remote_engine_id] = fut def done_callback(f: Future[dict[int, str]], @@ -957,13 +950,9 @@ class NixlConnectorWorker: remote_block_ids=meta.remote_block_ids, ) - def _read_blocks( - self, - local_block_ids: list[int], - remote_block_ids: list[int], - dst_engine_id: str, - request_id: str, - ): + def _read_blocks(self, local_block_ids: list[int], + remote_block_ids: list[int], dst_engine_id: str, + request_id: str): # NOTE(rob): having the staging blocks be on the READER side is # not going to work well (since we will have to call rearrange tensors). # after we detect the txn is complete (which means we cannot make the From 9502c38138a03669c4d54225336553db70ad799d Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Thu, 26 Jun 2025 01:06:27 -0400 Subject: [PATCH 009/175] [Benchmark][Bug] Fix multiple bugs in bench and add args to spec_decode offline (#20083) --- benchmarks/benchmark_dataset.py | 3 ++- examples/offline_inference/spec_decode.py | 20 +++++++++++++------- vllm/benchmarks/datasets.py | 12 ++++++++---- vllm/benchmarks/serve.py | 6 ++++++ 4 files changed, 29 insertions(+), 12 deletions(-) diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py index 8671719bce72f..55c0cf851264f 100644 --- a/benchmarks/benchmark_dataset.py +++ b/benchmarks/benchmark_dataset.py @@ -349,8 +349,9 @@ class RandomDataset(BenchmarkDataset): # [1650, 939, 486] -> ['Ġcall', 'sh', 'ere'] # To avoid uncontrolled change of the prompt length, # the encoded sequence is truncated before being decode again. + total_input_len = prefix_len + int(input_lens[i]) re_encoded_sequence = tokenizer.encode(prompt, add_special_tokens=False)[ - : input_lens[i] + :total_input_len ] prompt = tokenizer.decode(re_encoded_sequence) total_input_len = len(re_encoded_sequence) diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py index eece8beced510..6fa68d2ecee1d 100644 --- a/examples/offline_inference/spec_decode.py +++ b/examples/offline_inference/spec_decode.py @@ -39,6 +39,9 @@ def parse_args(): parser.add_argument("--top-k", type=int, default=-1) parser.add_argument("--print-output", action="store_true") parser.add_argument("--output-len", type=int, default=256) + parser.add_argument("--model-dir", type=str, default=None) + parser.add_argument("--eagle-dir", type=str, default=None) + parser.add_argument("--max-model-len", type=int, default=2048) return parser.parse_args() @@ -46,9 +49,10 @@ def main(): args = parse_args() args.endpoint_type = "openai-chat" - model_dir = "meta-llama/Llama-3.1-8B-Instruct" + model_dir = args.model_dir + if args.model_dir is None: + model_dir = "meta-llama/Llama-3.1-8B-Instruct" tokenizer = AutoTokenizer.from_pretrained(model_dir) - max_model_len = 2048 prompts = get_samples(args, tokenizer) # add_special_tokens is False to avoid adding bos twice when using chat templates @@ -57,16 +61,18 @@ def main(): ] if args.method == "eagle" or args.method == "eagle3": - if args.method == "eagle": + eagle_dir = args.eagle_dir + if args.method == "eagle" and eagle_dir is None: eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B" - elif args.method == "eagle3": + + elif args.method == "eagle3" and eagle_dir is None: eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" speculative_config = { "method": args.method, "model": eagle_dir, "num_speculative_tokens": args.num_spec_tokens, "draft_tensor_parallel_size": args.draft_tp, - "max_model_len": max_model_len, + "max_model_len": args.max_model_len, } elif args.method == "ngram": speculative_config = { @@ -74,7 +80,7 @@ def main(): "num_speculative_tokens": args.num_spec_tokens, "prompt_lookup_max": args.prompt_lookup_max, "prompt_lookup_min": args.prompt_lookup_min, - "max_model_len": max_model_len, + "max_model_len": args.max_model_len, } else: raise ValueError(f"unknown method: {args.method}") @@ -86,7 +92,7 @@ def main(): enable_chunked_prefill=args.enable_chunked_prefill, max_num_batched_tokens=args.max_num_batched_tokens, enforce_eager=args.enforce_eager, - max_model_len=max_model_len, + max_model_len=args.max_model_len, max_num_seqs=args.max_num_seqs, gpu_memory_utilization=0.8, speculative_config=speculative_config, diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py index 3efbe5695711f..b3688d2340e44 100644 --- a/vllm/benchmarks/datasets.py +++ b/vllm/benchmarks/datasets.py @@ -320,6 +320,8 @@ class RandomDataset(BenchmarkDataset): **kwargs, ) -> None: super().__init__(**kwargs) + random.seed(self.random_seed) + np.random.seed(self.random_seed) def sample( self, @@ -376,10 +378,11 @@ class RandomDataset(BenchmarkDataset): # [1650, 939, 486] -> ['Ġcall', 'sh', 'ere'] # To avoid uncontrolled change of the prompt length, # the encoded sequence is truncated before being decode again. - re_encoded_sequence = tokenizer.encode( - prompt, add_special_tokens=False)[:input_lens[i]] - prompt = tokenizer.decode(re_encoded_sequence) total_input_len = prefix_len + int(input_lens[i]) + re_encoded_sequence = tokenizer.encode( + prompt, add_special_tokens=False)[:total_input_len] + prompt = tokenizer.decode(re_encoded_sequence) + total_input_len = len(re_encoded_sequence) requests.append( SampleRequest( prompt=prompt, @@ -692,7 +695,8 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: dataset_path=args.dataset_path). sample(tokenizer=tokenizer, num_requests=args.num_prompts), "random": - lambda: RandomDataset(dataset_path=args.dataset_path).sample( + lambda: RandomDataset(random_seed=args.seed, + dataset_path=args.dataset_path).sample( tokenizer=tokenizer, num_requests=args.num_prompts, prefix_len=args.random_prefix_len, diff --git a/vllm/benchmarks/serve.py b/vllm/benchmarks/serve.py index 302f655f424a3..419284cca042e 100644 --- a/vllm/benchmarks/serve.py +++ b/vllm/benchmarks/serve.py @@ -631,6 +631,12 @@ def add_cli_args(parser: argparse.ArgumentParser): help="The label (prefix) of the benchmark results. If not specified, " "the endpoint type will be used as the label.", ) + parser.add_argument( + "--backend", + type=str, + default="vllm", + choices=list(ASYNC_REQUEST_FUNCS.keys()), + ) parser.add_argument( "--base-url", type=str, From 65397e40f58ff5657d9e8bbd860ed9d3fdf734a0 Mon Sep 17 00:00:00 2001 From: Seiji Eicher <58963096+eicherseiji@users.noreply.github.com> Date: Thu, 26 Jun 2025 00:01:57 -0700 Subject: [PATCH 010/175] [Bugfix] Allow `CUDA_VISIBLE_DEVICES=''` in `Platform.device_id_to_physical_device_id` (#18979) Signed-off-by: Seiji Eicher --- tests/config/test_config_generation.py | 38 ++++++++++++ tests/v1/engine/test_engine_core_client.py | 71 ++++++++++++++++++++++ vllm/platforms/interface.py | 15 ++--- 3 files changed, 114 insertions(+), 10 deletions(-) create mode 100644 tests/config/test_config_generation.py diff --git a/tests/config/test_config_generation.py b/tests/config/test_config_generation.py new file mode 100644 index 0000000000000..024e81fccc5f1 --- /dev/null +++ b/tests/config/test_config_generation.py @@ -0,0 +1,38 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest + +from vllm.engine.arg_utils import EngineArgs +from vllm.model_executor.layers.quantization.quark.utils import deep_compare + + +def test_cuda_empty_vs_unset_configs(monkeypatch: pytest.MonkeyPatch): + """Test that configs created with normal (untouched) CUDA_VISIBLE_DEVICES + and CUDA_VISIBLE_DEVICES="" are equivalent. This ensures consistent + behavior regardless of whether GPU visibility is disabled via empty string + or left in its normal state. + """ + + def create_config(): + engine_args = EngineArgs(model="deepseek-ai/DeepSeek-V2-Lite", + trust_remote_code=True) + return engine_args.create_engine_config() + + # Create config with CUDA_VISIBLE_DEVICES set normally + normal_config = create_config() + + # Create config with CUDA_VISIBLE_DEVICES="" + with monkeypatch.context() as m: + m.setenv("CUDA_VISIBLE_DEVICES", "") + empty_config = create_config() + + normal_config_dict = vars(normal_config) + empty_config_dict = vars(empty_config) + + # Remove instance_id before comparison as it's expected to be different + normal_config_dict.pop("instance_id", None) + empty_config_dict.pop("instance_id", None) + + assert deep_compare(normal_config_dict, empty_config_dict), ( + "Configs with normal CUDA_VISIBLE_DEVICES and CUDA_VISIBLE_DEVICES=\"\"" + " should be equivalent") diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index 16c36cd5c6b9b..d5ff78c1449a6 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -8,8 +8,10 @@ import time import uuid from threading import Thread from typing import Optional +from unittest.mock import MagicMock import pytest +import torch from transformers import AutoTokenizer from tests.utils import multi_gpu_test @@ -517,3 +519,72 @@ def test_startup_failure(monkeypatch: pytest.MonkeyPatch): ) assert "Engine core initialization failed" in str(e_info.value) + + +@create_new_process_for_each_test() +def test_engine_core_proc_instantiation_cuda_empty( + monkeypatch: pytest.MonkeyPatch): + """ + Test that EngineCoreProc can be instantiated when CUDA_VISIBLE_DEVICES + is empty. This ensures the engine frontend does not need access to GPUs. + """ + + from vllm.v1.engine.core import EngineCoreProc + from vllm.v1.executor.abstract import Executor + + # Create a simple mock executor instead of a complex custom class + mock_executor_class = MagicMock(spec=Executor) + + def create_mock_executor(vllm_config): + mock_executor = MagicMock() + + # Only implement the methods that are actually called during init + from vllm.v1.kv_cache_interface import FullAttentionSpec + mock_spec = FullAttentionSpec(block_size=16, + num_kv_heads=1, + head_size=64, + dtype=torch.float16, + use_mla=False) + + mock_executor.get_kv_cache_specs.return_value = [{ + "default": mock_spec + }] + mock_executor.determine_available_memory.return_value = [ + 1024 * 1024 * 1024 + ] + mock_executor.initialize_from_config.return_value = None + mock_executor.max_concurrent_batches = 1 + + return mock_executor + + mock_executor_class.side_effect = create_mock_executor + + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + m.setenv("CUDA_VISIBLE_DEVICES", "") # No CUDA devices + + from vllm.v1.utils import EngineZmqAddresses + + def mock_startup_handshake(self, handshake_socket, on_head_node, + parallel_config): + return EngineZmqAddresses(inputs=["tcp://127.0.0.1:5555"], + outputs=["tcp://127.0.0.1:5556"], + coordinator_input=None, + coordinator_output=None) + + # Background processes are not important here + m.setattr(EngineCoreProc, "startup_handshake", mock_startup_handshake) + + vllm_config = EngineArgs( + model="deepseek-ai/DeepSeek-V2-Lite", + trust_remote_code=True).create_engine_config() + engine_core_proc = EngineCoreProc( + vllm_config=vllm_config, + on_head_node=True, + handshake_address="tcp://127.0.0.1:12345", + executor_class=mock_executor_class, + log_stats=False, + engine_index=0, + ) + + engine_core_proc.shutdown() diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index f962fafabf502..0f08bf986333b 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -173,17 +173,12 @@ class Platform: @classmethod def device_id_to_physical_device_id(cls, device_id: int): - if cls.device_control_env_var in os.environ: + # Treat empty device control env var as unset. This is a valid + # configuration in Ray setups where the engine is launched in + # a CPU-only placement group located on a GPU node. + if cls.device_control_env_var in os.environ and os.environ[ + cls.device_control_env_var] != "": device_ids = os.environ[cls.device_control_env_var].split(",") - if device_ids == [""]: - msg = (f"{cls.device_control_env_var} is set to empty string, " - "which means current platform support is disabled. If " - "you are using ray, please unset the environment " - f"variable `{cls.device_control_env_var}` inside the " - "worker/actor. Check " - "https://github.com/vllm-project/vllm/issues/8402 for " - "more information.") - raise RuntimeError(msg) physical_device_id = device_ids[device_id] return int(physical_device_id) else: From 1d7c29f5fecab930fbb28bf59f1bc4510abe335b Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 26 Jun 2025 15:47:06 +0800 Subject: [PATCH 011/175] [Doc] Update docs for New Model Implementation (#20115) Signed-off-by: DarkLight1337 --- docs/.nav.yml | 7 ++++++- docs/contributing/model/README.md | 24 +++++++++++++----------- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/docs/.nav.yml b/docs/.nav.yml index a9c594c291777..e679807f75346 100644 --- a/docs/.nav.yml +++ b/docs/.nav.yml @@ -48,7 +48,12 @@ nav: - General: - glob: contributing/* flatten_single_child_sections: true - - Model Implementation: contributing/model + - Model Implementation: + - contributing/model/README.md + - contributing/model/basic.md + - contributing/model/registration.md + - contributing/model/tests.md + - contributing/model/multimodal.md - Design Documents: - V0: design - V1: design/v1 diff --git a/docs/contributing/model/README.md b/docs/contributing/model/README.md index b7727f02c11bf..82541924bc028 100644 --- a/docs/contributing/model/README.md +++ b/docs/contributing/model/README.md @@ -1,21 +1,23 @@ --- -title: Adding a New Model +title: Summary --- [](){ #new-model } -This section provides more information on how to integrate a [PyTorch](https://pytorch.org/) model into vLLM. +!!! important + Many decoder language models can now be automatically loaded using the [Transformers backend][transformers-backend] without having to implement them in vLLM. See if `vllm serve ` works first! -Contents: +vLLM models are specialized [PyTorch](https://pytorch.org/) models that take advantage of various [features][compatibility-matrix] to optimize their performance. -- [Basic](basic.md) -- [Registration](registration.md) -- [Tests](tests.md) -- [Multimodal](multimodal.md) +The complexity of integrating a model into vLLM depends heavily on the model's architecture. +The process is considerably straightforward if the model shares a similar architecture with an existing model in vLLM. +However, this can be more complex for models that include new operators (e.g., a new attention mechanism). -!!! note - The complexity of adding a new model depends heavily on the model's architecture. - The process is considerably straightforward if the model shares a similar architecture with an existing model in vLLM. - However, for models that include new operators (e.g., a new attention mechanism), the process can be a bit more complex. +Read through these pages for a step-by-step guide: + +- [Implementing a Basic Model](basic.md) +- [Registering a Model to vLLM](registration.md) +- [Writing Unit Tests](tests.md) +- [Multi-Modal Support](multimodal.md) !!! tip If you are encountering issues while integrating your model into vLLM, feel free to open a [GitHub issue](https://github.com/vllm-project/vllm/issues) From d188913d99bbdfc699bc4f7c2c23187f3745f94b Mon Sep 17 00:00:00 2001 From: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Date: Thu, 26 Jun 2025 05:16:10 -0400 Subject: [PATCH 012/175] [Refactor] Remove unused library (#20099) Signed-off-by: yewentao256 --- vllm/_custom_ops.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 8ebe694eefd0e..d5a41284385e6 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -5,7 +5,6 @@ import contextlib from typing import TYPE_CHECKING, Optional, Union import torch -import torch.library import vllm.envs as envs from vllm.logger import init_logger From 0567c8249fdbff59a05f000cb326aed7cf5c8567 Mon Sep 17 00:00:00 2001 From: "Li, Jiang" Date: Thu, 26 Jun 2025 18:34:47 +0800 Subject: [PATCH 013/175] [CPU] Fix torch version in x86 CPU backend (#19258) Signed-off-by: jiang1.li --- csrc/cpu/torch_bindings.cpp | 13 +++++--- docker/Dockerfile.cpu | 33 +++++++++++-------- requirements/cpu-build.txt | 12 +++++++ requirements/cpu.txt | 5 +-- .../multimodal/generation/test_common.py | 2 ++ .../generation/vlm_utils/builders.py | 3 ++ vllm/model_executor/layers/fused_moe/layer.py | 2 ++ .../layers/quantization/ipex_quant.py | 2 +- 8 files changed, 52 insertions(+), 20 deletions(-) create mode 100644 requirements/cpu-build.txt diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index 447e826bc1c09..60304d229a8f5 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -131,16 +131,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Quantization #ifdef __AVX512F__ + at::Tag stride_tag = at::Tag::needs_fixed_stride_order; // Compute int8 quantized tensor for given scaling factor. ops.def( "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale," - "Tensor? azp) -> ()"); + "Tensor? azp) -> ()", + {stride_tag}); ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant); // Compute int8 quantized tensor and scaling factor ops.def( "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, " - "Tensor!? azp) -> ()"); + "Tensor!? azp) -> ()", + {stride_tag}); ops.impl("dynamic_scaled_int8_quant", torch::kCPU, &dynamic_scaled_int8_quant); // W8A8 GEMM, supporting symmetric per-tensor or per-row/column @@ -148,7 +151,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def( "cutlass_scaled_mm(Tensor! out, Tensor a," " Tensor b, Tensor a_scales," - " Tensor b_scales, Tensor? bias) -> ()"); + " Tensor b_scales, Tensor? bias) -> ()", + {stride_tag}); ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm); // w8a8 GEMM, supporting asymmetric per-tensor or per-row/column // quantization. @@ -156,7 +160,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "cutlass_scaled_mm_azp(Tensor! out, Tensor a," " Tensor b, Tensor a_scales," " Tensor b_scales, Tensor azp_adj," - " Tensor? azp, Tensor? bias) -> ()"); + " Tensor? azp, Tensor? bias) -> ()", + {stride_tag}); ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp); #elif defined(__powerpc64__) // Compute int8 quantized tensor for given scaling factor. diff --git a/docker/Dockerfile.cpu b/docker/Dockerfile.cpu index 3e9fa0e7af2dc..13bd03c5696ab 100644 --- a/docker/Dockerfile.cpu +++ b/docker/Dockerfile.cpu @@ -66,7 +66,7 @@ ENV VLLM_CPU_DISABLE_AVX512=${VLLM_CPU_DISABLE_AVX512} WORKDIR /workspace/vllm RUN --mount=type=cache,target=/root/.cache/uv \ - --mount=type=bind,src=requirements/build.txt,target=requirements/build.txt \ + --mount=type=bind,src=requirements/cpu-build.txt,target=requirements/build.txt \ uv pip install -r requirements/build.txt COPY . . @@ -79,6 +79,22 @@ RUN --mount=type=cache,target=/root/.cache/uv \ --mount=type=bind,source=.git,target=.git \ VLLM_TARGET_DEVICE=cpu python3 setup.py bdist_wheel +######################### TEST DEPS ######################### +FROM base AS vllm-test-deps + +WORKDIR /workspace/vllm + +RUN --mount=type=bind,src=requirements/test.in,target=requirements/test.in \ + cp requirements/test.in requirements/cpu-test.in && \ + sed -i '/mamba_ssm/d' requirements/cpu-test.in && \ + sed -i 's/torch==.*/torch==2.6.0/g' requirements/cpu-test.in && \ + sed -i 's/torchaudio.*/torchaudio/g' requirements/cpu-test.in && \ + sed -i 's/torchvision.*/torchvision/g' requirements/cpu-test.in && \ + uv pip compile requirements/cpu-test.in -o requirements/cpu-test.txt --index-strategy unsafe-best-match --torch-backend cpu + +RUN --mount=type=cache,target=/root/.cache/uv \ + uv pip install -r requirements/cpu-test.txt + ######################### DEV IMAGE ######################### FROM vllm-build AS vllm-dev @@ -97,28 +113,19 @@ RUN --mount=type=cache,target=/root/.cache/uv \ --mount=type=bind,source=.git,target=.git \ VLLM_TARGET_DEVICE=cpu python3 setup.py develop +COPY --from=vllm-test-deps /workspace/vllm/requirements/cpu-test.txt requirements/test.txt + RUN --mount=type=cache,target=/root/.cache/uv \ - --mount=type=bind,src=requirements/test.in,target=requirements/test.in \ - cp requirements/test.in requirements/test-cpu.in && \ - sed -i '/mamba_ssm/d' requirements/test-cpu.in && \ - uv pip compile requirements/test-cpu.in -o requirements/test.txt && \ uv pip install -r requirements/dev.txt && \ pre-commit install --hook-type pre-commit --hook-type commit-msg ENTRYPOINT ["bash"] ######################### TEST IMAGE ######################### -FROM base AS vllm-test +FROM vllm-test-deps AS vllm-test WORKDIR /workspace/ -RUN --mount=type=cache,target=/root/.cache/uv \ - --mount=type=bind,src=requirements/test.in,target=requirements/test.in \ - cp requirements/test.in requirements/test-cpu.in && \ - sed -i '/mamba_ssm/d' requirements/test-cpu.in && \ - uv pip compile requirements/test-cpu.in -o requirements/cpu-test.txt && \ - uv pip install -r requirements/cpu-test.txt - RUN --mount=type=cache,target=/root/.cache/uv \ --mount=type=bind,from=vllm-build,src=/workspace/vllm/dist,target=dist \ uv pip install dist/*.whl diff --git a/requirements/cpu-build.txt b/requirements/cpu-build.txt new file mode 100644 index 0000000000000..37f072202bd71 --- /dev/null +++ b/requirements/cpu-build.txt @@ -0,0 +1,12 @@ +# Temporarily used for x86 CPU backend to avoid performance regression of torch>2.6.0+cpu, +# see https://github.com/pytorch/pytorch/pull/151218 +cmake>=3.26.1 +ninja +packaging>=24.2 +setuptools>=77.0.3,<80.0.0 +setuptools-scm>=8 +--extra-index-url https://download.pytorch.org/whl/cpu +torch==2.6.0+cpu +wheel +jinja2>=3.1.6 +regex diff --git a/requirements/cpu.txt b/requirements/cpu.txt index 8742898cff00f..df3a3393563a0 100644 --- a/requirements/cpu.txt +++ b/requirements/cpu.txt @@ -8,7 +8,7 @@ numba == 0.61.2; python_version > '3.9' packaging>=24.2 setuptools>=77.0.3,<80.0.0 --extra-index-url https://download.pytorch.org/whl/cpu -torch==2.7.0+cpu; platform_machine == "x86_64" +torch==2.6.0+cpu; platform_machine == "x86_64" # torch>2.6.0+cpu has performance regression on x86 platform, see https://github.com/pytorch/pytorch/pull/151218 torch==2.7.0; platform_system == "Darwin" torch==2.7.0; platform_machine == "ppc64le" or platform_machine == "aarch64" @@ -23,6 +23,7 @@ datasets # for benchmark scripts # Intel Extension for PyTorch, only for x86_64 CPUs intel-openmp==2024.2.1; platform_machine == "x86_64" -intel_extension_for_pytorch==2.7.0; platform_machine == "x86_64" +intel_extension_for_pytorch==2.6.0; platform_machine == "x86_64" # torch>2.6.0+cpu has performance regression on x86 platform, see https://github.com/pytorch/pytorch/pull/151218 py-libnuma; platform_system != "Darwin" psutil; platform_system != "Darwin" +triton==3.2.0; platform_machine == "x86_64" # Triton is required for torch 2.6+cpu, as it is imported in torch.compile. diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py index 496850b19af4f..9d63339737ce6 100644 --- a/tests/models/multimodal/generation/test_common.py +++ b/tests/models/multimodal/generation/test_common.py @@ -107,6 +107,8 @@ VLM_TEST_SETTINGS = { ), limit_mm_per_prompt={"image": 4}, )], + # TODO: Revert to "auto" when CPU backend can use torch > 2.6 + dtype="bfloat16" if current_platform.is_cpu() else "auto", marks=[pytest.mark.core_model, pytest.mark.cpu_model], ), "paligemma": VLMTestInfo( diff --git a/tests/models/multimodal/generation/vlm_utils/builders.py b/tests/models/multimodal/generation/vlm_utils/builders.py index 7d20dd66089bb..03c08240d6a81 100644 --- a/tests/models/multimodal/generation/vlm_utils/builders.py +++ b/tests/models/multimodal/generation/vlm_utils/builders.py @@ -203,6 +203,9 @@ def build_embedding_inputs_from_test_info( images = [asset.pil_image for asset in image_assets] embeds = test_info.convert_assets_to_embeddings(image_assets) + if test_info.dtype != "auto": + dtype = getattr(torch, test_info.dtype) # type: ignore + embeds = [e.to(dtype=dtype) for e in embeds] assert len(images) == len(model_prompts) inputs = build_single_image_inputs(images, model_prompts, size_wrapper) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index c1bae033c2b4b..133881fd04990 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -54,6 +54,8 @@ else: if is_rocm_aiter_moe_enabled(): from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 rocm_aiter_grouped_topk as grouped_topk) +elif current_platform.is_cpu(): + pass else: from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk if current_platform.is_tpu(): diff --git a/vllm/model_executor/layers/quantization/ipex_quant.py b/vllm/model_executor/layers/quantization/ipex_quant.py index 31ad96eccaf3e..428e9b882bca7 100644 --- a/vllm/model_executor/layers/quantization/ipex_quant.py +++ b/vllm/model_executor/layers/quantization/ipex_quant.py @@ -15,7 +15,7 @@ from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod from vllm.platforms import current_platform -MIN_IPEX_VERSION = "2.7.0" +MIN_IPEX_VERSION = "2.6.0" class IPEXConfig(QuantizationConfig): From 167aca45cbbfd8c56d700dfc9a6a5a3482a5bd74 Mon Sep 17 00:00:00 2001 From: Reid <61492567+reidliu41@users.noreply.github.com> Date: Thu, 26 Jun 2025 18:35:16 +0800 Subject: [PATCH 014/175] [Misc] Use collapsible blocks for benchmark examples. (#20017) Signed-off-by: reidliu41 Co-authored-by: reidliu41 --- benchmarks/README.md | 94 ++++++++++++++++++++++++++++---------------- 1 file changed, 60 insertions(+), 34 deletions(-) diff --git a/benchmarks/README.md b/benchmarks/README.md index 2714b8b49821c..fb8690d42db98 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -4,7 +4,7 @@ This README guides you through running benchmark tests with the extensive datasets supported on vLLM. It’s a living document, updated as new features and datasets become available. -## Dataset Overview +**Dataset Overview** @@ -82,7 +82,10 @@ become available. **Note**: HuggingFace dataset's `dataset-name` should be set to `hf` --- -## Example - Online Benchmark +
+🚀 Example - Online Benchmark + +
First start serving your model @@ -130,7 +133,8 @@ P99 ITL (ms): 8.39 ================================================== ``` -### Custom Dataset +**Custom Dataset** + If the dataset you want to benchmark is not supported yet in vLLM, even then you can benchmark on it using `CustomDataset`. Your data needs to be in `.jsonl` format and needs to have "prompt" field per entry, e.g., data.jsonl ``` @@ -162,7 +166,7 @@ python3 benchmarks/benchmark_serving.py --port 9001 --save-result --save-detaile You can skip applying chat template if your data already has it by using `--custom-skip-chat-template`. -### VisionArena Benchmark for Vision Language Models +**VisionArena Benchmark for Vision Language Models** ```bash # need a model with vision capability here @@ -180,7 +184,7 @@ python3 vllm/benchmarks/benchmark_serving.py \ --num-prompts 1000 ``` -### InstructCoder Benchmark with Speculative Decoding +**InstructCoder Benchmark with Speculative Decoding** ``` bash VLLM_USE_V1=1 vllm serve meta-llama/Meta-Llama-3-8B-Instruct \ @@ -197,7 +201,7 @@ python3 benchmarks/benchmark_serving.py \ --num-prompts 2048 ``` -### Other HuggingFaceDataset Examples +**Other HuggingFaceDataset Examples** ```bash vllm serve Qwen/Qwen2-VL-7B-Instruct --disable-log-requests @@ -251,7 +255,7 @@ python3 vllm/benchmarks/benchmark_serving.py \ --num-prompts 80 ``` -### Running With Sampling Parameters +**Running With Sampling Parameters** When using OpenAI-compatible backends such as `vllm`, optional sampling parameters can be specified. Example client command: @@ -269,7 +273,7 @@ python3 vllm/benchmarks/benchmark_serving.py \ --num-prompts 10 ``` -### Running With Ramp-Up Request Rate +**Running With Ramp-Up Request Rate** The benchmark tool also supports ramping up the request rate over the duration of the benchmark run. This can be useful for stress testing the @@ -284,8 +288,12 @@ The following arguments can be used to control the ramp-up: - `--ramp-up-start-rps`: The request rate at the beginning of the benchmark. - `--ramp-up-end-rps`: The request rate at the end of the benchmark. ---- -## Example - Offline Throughput Benchmark +
+ +
+📈 Example - Offline Throughput Benchmark + +
```bash python3 vllm/benchmarks/benchmark_throughput.py \ @@ -303,7 +311,7 @@ Total num prompt tokens: 5014 Total num output tokens: 1500 ``` -### VisionArena Benchmark for Vision Language Models +**VisionArena Benchmark for Vision Language Models** ``` bash python3 vllm/benchmarks/benchmark_throughput.py \ @@ -323,7 +331,7 @@ Total num prompt tokens: 14527 Total num output tokens: 1280 ``` -### InstructCoder Benchmark with Speculative Decoding +**InstructCoder Benchmark with Speculative Decoding** ``` bash VLLM_WORKER_MULTIPROC_METHOD=spawn \ @@ -347,7 +355,7 @@ Total num prompt tokens: 261136 Total num output tokens: 204800 ``` -### Other HuggingFaceDataset Examples +**Other HuggingFaceDataset Examples** **`lmms-lab/LLaVA-OneVision-Data`** @@ -386,7 +394,7 @@ python3 benchmarks/benchmark_throughput.py \ --num-prompts 10 ``` -### Benchmark with LoRA Adapters +**Benchmark with LoRA Adapters** ``` bash # download dataset @@ -403,18 +411,22 @@ python3 vllm/benchmarks/benchmark_throughput.py \ --lora-path yard1/llama-2-7b-sql-lora-test ``` ---- -## Example - Structured Output Benchmark +
+ +
+🛠️ Example - Structured Output Benchmark + +
Benchmark the performance of structured output generation (JSON, grammar, regex). -### Server Setup +**Server Setup** ```bash vllm serve NousResearch/Hermes-3-Llama-3.1-8B --disable-log-requests ``` -### JSON Schema Benchmark +**JSON Schema Benchmark** ```bash python3 benchmarks/benchmark_serving_structured_output.py \ @@ -426,7 +438,7 @@ python3 benchmarks/benchmark_serving_structured_output.py \ --num-prompts 1000 ``` -### Grammar-based Generation Benchmark +**Grammar-based Generation Benchmark** ```bash python3 benchmarks/benchmark_serving_structured_output.py \ @@ -438,7 +450,7 @@ python3 benchmarks/benchmark_serving_structured_output.py \ --num-prompts 1000 ``` -### Regex-based Generation Benchmark +**Regex-based Generation Benchmark** ```bash python3 benchmarks/benchmark_serving_structured_output.py \ @@ -449,7 +461,7 @@ python3 benchmarks/benchmark_serving_structured_output.py \ --num-prompts 1000 ``` -### Choice-based Generation Benchmark +**Choice-based Generation Benchmark** ```bash python3 benchmarks/benchmark_serving_structured_output.py \ @@ -460,7 +472,7 @@ python3 benchmarks/benchmark_serving_structured_output.py \ --num-prompts 1000 ``` -### XGrammar Benchmark Dataset +**XGrammar Benchmark Dataset** ```bash python3 benchmarks/benchmark_serving_structured_output.py \ @@ -471,12 +483,16 @@ python3 benchmarks/benchmark_serving_structured_output.py \ --num-prompts 1000 ``` ---- -## Example - Long Document QA Throughput Benchmark +
+ +
+📚 Example - Long Document QA Benchmark + +
Benchmark the performance of long document question-answering with prefix caching. -### Basic Long Document QA Test +**Basic Long Document QA Test** ```bash python3 benchmarks/benchmark_long_document_qa_throughput.py \ @@ -488,7 +504,7 @@ python3 benchmarks/benchmark_long_document_qa_throughput.py \ --repeat-count 5 ``` -### Different Repeat Modes +**Different Repeat Modes** ```bash # Random mode (default) - shuffle prompts randomly @@ -519,12 +535,16 @@ python3 benchmarks/benchmark_long_document_qa_throughput.py \ --repeat-mode interleave ``` ---- -## Example - Prefix Caching Benchmark +
+ +
+🗂️ Example - Prefix Caching Benchmark + +
Benchmark the efficiency of automatic prefix caching. -### Fixed Prompt with Prefix Caching +**Fixed Prompt with Prefix Caching** ```bash python3 benchmarks/benchmark_prefix_caching.py \ @@ -535,7 +555,7 @@ python3 benchmarks/benchmark_prefix_caching.py \ --input-length-range 128:256 ``` -### ShareGPT Dataset with Prefix Caching +**ShareGPT Dataset with Prefix Caching** ```bash # download dataset @@ -550,12 +570,16 @@ python3 benchmarks/benchmark_prefix_caching.py \ --input-length-range 128:256 ``` ---- -## Example - Request Prioritization Benchmark +
+ +
+⚡ Example - Request Prioritization Benchmark + +
Benchmark the performance of request prioritization in vLLM. -### Basic Prioritization Test +**Basic Prioritization Test** ```bash python3 benchmarks/benchmark_prioritization.py \ @@ -566,7 +590,7 @@ python3 benchmarks/benchmark_prioritization.py \ --scheduling-policy priority ``` -### Multiple Sequences per Prompt +**Multiple Sequences per Prompt** ```bash python3 benchmarks/benchmark_prioritization.py \ @@ -577,3 +601,5 @@ python3 benchmarks/benchmark_prioritization.py \ --scheduling-policy priority \ --n 2 ``` + +
From 84c260caeb88d25840ec0653c0b978a46eae6a84 Mon Sep 17 00:00:00 2001 From: Michael Yao Date: Thu, 26 Jun 2025 18:41:51 +0800 Subject: [PATCH 015/175] [Docs] Improve frameworks/helm.md (#20113) Signed-off-by: windsonsea --- docs/deployment/frameworks/helm.md | 120 +++++++++++++++-------------- 1 file changed, 64 insertions(+), 56 deletions(-) diff --git a/docs/deployment/frameworks/helm.md b/docs/deployment/frameworks/helm.md index cff8af2c09d29..d929665e8a3df 100644 --- a/docs/deployment/frameworks/helm.md +++ b/docs/deployment/frameworks/helm.md @@ -5,9 +5,9 @@ title: Helm A Helm chart to deploy vLLM for Kubernetes -Helm is a package manager for Kubernetes. It will help you to deploy vLLM on k8s and automate the deployment of vLLM Kubernetes applications. With Helm, you can deploy the same framework architecture with different configurations to multiple namespaces by overriding variable values. +Helm is a package manager for Kubernetes. It helps automate the deployment of vLLM applications on Kubernetes. With Helm, you can deploy the same framework architecture with different configurations to multiple namespaces by overriding variable values. -This guide will walk you through the process of deploying vLLM with Helm, including the necessary prerequisites, steps for helm installation and documentation on architecture and values file. +This guide will walk you through the process of deploying vLLM with Helm, including the necessary prerequisites, steps for Helm installation and documentation on architecture and values file. ## Prerequisites @@ -16,17 +16,23 @@ Before you begin, ensure that you have the following: - A running Kubernetes cluster - NVIDIA Kubernetes Device Plugin (`k8s-device-plugin`): This can be found at [https://github.com/NVIDIA/k8s-device-plugin](https://github.com/NVIDIA/k8s-device-plugin) - Available GPU resources in your cluster -- S3 with the model which will be deployed +- An S3 with the model which will be deployed ## Installing the chart To install the chart with the release name `test-vllm`: ```bash -helm upgrade --install --create-namespace --namespace=ns-vllm test-vllm . -f values.yaml --set secrets.s3endpoint=$ACCESS_POINT --set secrets.s3bucketname=$BUCKET --set secrets.s3accesskeyid=$ACCESS_KEY --set secrets.s3accesskey=$SECRET_KEY +helm upgrade --install --create-namespace \ + --namespace=ns-vllm test-vllm . \ + -f values.yaml \ + --set secrets.s3endpoint=$ACCESS_POINT \ + --set secrets.s3bucketname=$BUCKET \ + --set secrets.s3accesskeyid=$ACCESS_KEY \ + --set secrets.s3accesskey=$SECRET_KEY ``` -## Uninstalling the Chart +## Uninstalling the chart To uninstall the `test-vllm` deployment: @@ -39,57 +45,59 @@ chart **including persistent volumes** and deletes the release. ## Architecture -![](../../assets/deployment/architecture_helm_deployment.png) +![helm deployment architecture](../../assets/deployment/architecture_helm_deployment.png) ## Values -| Key | Type | Default | Description | -|--------------------------------------------|---------|----------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------| -| autoscaling | object | {"enabled":false,"maxReplicas":100,"minReplicas":1,"targetCPUUtilizationPercentage":80} | Autoscaling configuration | -| autoscaling.enabled | bool | false | Enable autoscaling | -| autoscaling.maxReplicas | int | 100 | Maximum replicas | -| autoscaling.minReplicas | int | 1 | Minimum replicas | -| autoscaling.targetCPUUtilizationPercentage | int | 80 | Target CPU utilization for autoscaling | -| configs | object | {} | Configmap | -| containerPort | int | 8000 | Container port | -| customObjects | list | [] | Custom Objects configuration | -| deploymentStrategy | object | {} | Deployment strategy configuration | -| externalConfigs | list | [] | External configuration | -| extraContainers | list | [] | Additional containers configuration | -| extraInit | object | {"pvcStorage":"1Gi","s3modelpath":"relative_s3_model_path/opt-125m", "awsEc2MetadataDisabled": true} | Additional configuration for the init container | -| extraInit.pvcStorage | string | "50Gi" | Storage size of the s3 | -| extraInit.s3modelpath | string | "relative_s3_model_path/opt-125m" | Path of the model on the s3 which hosts model weights and config files | -| extraInit.awsEc2MetadataDisabled | boolean | true | Disables the use of the Amazon EC2 instance metadata service | -| extraPorts | list | [] | Additional ports configuration | -| gpuModels | list | ["TYPE_GPU_USED"] | Type of gpu used | -| image | object | {"command":["vllm","serve","/data/","--served-model-name","opt-125m","--host","0.0.0.0","--port","8000"],"repository":"vllm/vllm-openai","tag":"latest"} | Image configuration | -| image.command | list | ["vllm","serve","/data/","--served-model-name","opt-125m","--host","0.0.0.0","--port","8000"] | Container launch command | -| image.repository | string | "vllm/vllm-openai" | Image repository | -| image.tag | string | "latest" | Image tag | -| livenessProbe | object | {"failureThreshold":3,"httpGet":{"path":"/health","port":8000},"initialDelaySeconds":15,"periodSeconds":10} | Liveness probe configuration | -| livenessProbe.failureThreshold | int | 3 | Number of times after which if a probe fails in a row, Kubernetes considers that the overall check has failed: the container is not alive | -| livenessProbe.httpGet | object | {"path":"/health","port":8000} | Configuration of the Kubelet http request on the server | -| livenessProbe.httpGet.path | string | "/health" | Path to access on the HTTP server | -| livenessProbe.httpGet.port | int | 8000 | Name or number of the port to access on the container, on which the server is listening | -| livenessProbe.initialDelaySeconds | int | 15 | Number of seconds after the container has started before liveness probe is initiated | -| livenessProbe.periodSeconds | int | 10 | How often (in seconds) to perform the liveness probe | -| maxUnavailablePodDisruptionBudget | string | "" | Disruption Budget Configuration | -| readinessProbe | object | {"failureThreshold":3,"httpGet":{"path":"/health","port":8000},"initialDelaySeconds":5,"periodSeconds":5} | Readiness probe configuration | -| readinessProbe.failureThreshold | int | 3 | Number of times after which if a probe fails in a row, Kubernetes considers that the overall check has failed: the container is not ready | -| readinessProbe.httpGet | object | {"path":"/health","port":8000} | Configuration of the Kubelet http request on the server | -| readinessProbe.httpGet.path | string | "/health" | Path to access on the HTTP server | -| readinessProbe.httpGet.port | int | 8000 | Name or number of the port to access on the container, on which the server is listening | -| readinessProbe.initialDelaySeconds | int | 5 | Number of seconds after the container has started before readiness probe is initiated | -| readinessProbe.periodSeconds | int | 5 | How often (in seconds) to perform the readiness probe | -| replicaCount | int | 1 | Number of replicas | -| resources | object | {"limits":{"cpu":4,"memory":"16Gi","nvidia.com/gpu":1},"requests":{"cpu":4,"memory":"16Gi","nvidia.com/gpu":1}} | Resource configuration | -| resources.limits."nvidia.com/gpu" | int | 1 | Number of gpus used | -| resources.limits.cpu | int | 4 | Number of CPUs | -| resources.limits.memory | string | "16Gi" | CPU memory configuration | -| resources.requests."nvidia.com/gpu" | int | 1 | Number of gpus used | -| resources.requests.cpu | int | 4 | Number of CPUs | -| resources.requests.memory | string | "16Gi" | CPU memory configuration | -| secrets | object | {} | Secrets configuration | -| serviceName | string | Service name | | -| servicePort | int | 80 | Service port | -| labels.environment | string | test | Environment name | +The following table describes configurable parameters of the chart in `values.yaml`: + +| Key | Type | Default | Description | +|-----|------|---------|-------------| +| autoscaling | object | {"enabled":false,"maxReplicas":100,"minReplicas":1,"targetCPUUtilizationPercentage":80} | Autoscaling configuration | +| autoscaling.enabled | bool | false | Enable autoscaling | +| autoscaling.maxReplicas | int | 100 | Maximum replicas | +| autoscaling.minReplicas | int | 1 | Minimum replicas | +| autoscaling.targetCPUUtilizationPercentage | int | 80 | Target CPU utilization for autoscaling | +| configs | object | {} | Configmap | +| containerPort | int | 8000 | Container port | +| customObjects | list | [] | Custom Objects configuration | +| deploymentStrategy | object | {} | Deployment strategy configuration | +| externalConfigs | list | [] | External configuration | +| extraContainers | list | [] | Additional containers configuration | +| extraInit | object | {"pvcStorage":"1Gi","s3modelpath":"relative_s3_model_path/opt-125m", "awsEc2MetadataDisabled": true} | Additional configuration for the init container | +| extraInit.pvcStorage | string | "1Gi" | Storage size of the s3 | +| extraInit.s3modelpath | string | "relative_s3_model_path/opt-125m" | Path of the model on the s3 which hosts model weights and config files | +| extraInit.awsEc2MetadataDisabled | boolean | true | Disables the use of the Amazon EC2 instance metadata service | +| extraPorts | list | [] | Additional ports configuration | +| gpuModels | list | ["TYPE_GPU_USED"] | Type of gpu used | +| image | object | {"command":["vllm","serve","/data/","--served-model-name","opt-125m","--host","0.0.0.0","--port","8000"],"repository":"vllm/vllm-openai","tag":"latest"} | Image configuration | +| image.command | list | ["vllm","serve","/data/","--served-model-name","opt-125m","--host","0.0.0.0","--port","8000"] | Container launch command | +| image.repository | string | "vllm/vllm-openai" | Image repository | +| image.tag | string | "latest" | Image tag | +| livenessProbe | object | {"failureThreshold":3,"httpGet":{"path":"/health","port":8000},"initialDelaySeconds":15,"periodSeconds":10} | Liveness probe configuration | +| livenessProbe.failureThreshold | int | 3 | Number of times after which if a probe fails in a row, Kubernetes considers that the overall check has failed: the container is not alive | +| livenessProbe.httpGet | object | {"path":"/health","port":8000} | Configuration of the kubelet http request on the server | +| livenessProbe.httpGet.path | string | "/health" | Path to access on the HTTP server | +| livenessProbe.httpGet.port | int | 8000 | Name or number of the port to access on the container, on which the server is listening | +| livenessProbe.initialDelaySeconds | int | 15 | Number of seconds after the container has started before liveness probe is initiated | +| livenessProbe.periodSeconds | int | 10 | How often (in seconds) to perform the liveness probe | +| maxUnavailablePodDisruptionBudget | string | "" | Disruption Budget Configuration | +| readinessProbe | object | {"failureThreshold":3,"httpGet":{"path":"/health","port":8000},"initialDelaySeconds":5,"periodSeconds":5} | Readiness probe configuration | +| readinessProbe.failureThreshold | int | 3 | Number of times after which if a probe fails in a row, Kubernetes considers that the overall check has failed: the container is not ready | +| readinessProbe.httpGet | object | {"path":"/health","port":8000} | Configuration of the kubelet http request on the server | +| readinessProbe.httpGet.path | string | "/health" | Path to access on the HTTP server | +| readinessProbe.httpGet.port | int | 8000 | Name or number of the port to access on the container, on which the server is listening | +| readinessProbe.initialDelaySeconds | int | 5 | Number of seconds after the container has started before readiness probe is initiated | +| readinessProbe.periodSeconds | int | 5 | How often (in seconds) to perform the readiness probe | +| replicaCount | int | 1 | Number of replicas | +| resources | object | {"limits":{"cpu":4,"memory":"16Gi","nvidia.com/gpu":1},"requests":{"cpu":4,"memory":"16Gi","nvidia.com/gpu":1}} | Resource configuration | +| resources.limits."nvidia.com/gpu" | int | 1 | Number of GPUs used | +| resources.limits.cpu | int | 4 | Number of CPUs | +| resources.limits.memory | string | "16Gi" | CPU memory configuration | +| resources.requests."nvidia.com/gpu" | int | 1 | Number of GPUs used | +| resources.requests.cpu | int | 4 | Number of CPUs | +| resources.requests.memory | string | "16Gi" | CPU memory configuration | +| secrets | object | {} | Secrets configuration | +| serviceName | string | "" | Service name | +| servicePort | int | 80 | Service port | +| labels.environment | string | test | Environment name | From 27c065df50407f6b801d0053378c442ccea37d39 Mon Sep 17 00:00:00 2001 From: TJian Date: Thu, 26 Jun 2025 05:42:31 -0700 Subject: [PATCH 016/175] [Bugfix][V1][ROCm] Fix AITER Flash Attention Backend (Fix API Break and Local Attention Logic: affecting Llama4) (#19904) Signed-off-by: tjtanaa --- vllm/attention/layer.py | 14 ++++-- vllm/v1/attention/backends/rocm_aiter_fa.py | 55 ++++++++++++++------- 2 files changed, 46 insertions(+), 23 deletions(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index f7d230c5d7d6f..0c79aaf135518 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -306,12 +306,16 @@ class MultiHeadAttention(nn.Module): block_size=16, is_attention_free=False) backend = backend_name_to_enum(attn_backend.get_name()) - if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}: - backend = _Backend.XFORMERS + if current_platform.is_rocm(): + # currently, only torch_sdpa is supported on rocm + self.attn_backend = _Backend.TORCH_SDPA + else: + if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}: + backend = _Backend.XFORMERS - self.attn_backend = backend if backend in { - _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1 - } else _Backend.TORCH_SDPA + self.attn_backend = backend if backend in { + _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1 + } else _Backend.TORCH_SDPA def forward( self, diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index e011e95efd41b..dc8ff22613061 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -243,8 +243,8 @@ class AiterFlashAttentionMetadataBuilder: self.runner.device, non_blocking=True) local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to( self.runner.device, non_blocking=True) - local_max_query_len = seqlens_q_local_np.max() - local_max_seq_len = virt_k_seqlens_np.max() + local_max_query_len = int(seqlens_q_local_np.max()) + local_max_seq_len = int(virt_k_seqlens_np.max()) local_scheduler_metadata = schedule( batch_size=local_query_start_loc.shape[0] - 1, cu_query_lens=local_query_start_loc, @@ -253,6 +253,17 @@ class AiterFlashAttentionMetadataBuilder: max_seq_len=local_max_seq_len, causal=True) + local_cu_seq_lens = torch.zeros(virt_k_seqlens_np.shape[0] + 1, + dtype=torch.int32, + device=self.runner.device) + local_cu_seq_lens[1:] = torch.cumsum( + torch.from_numpy(virt_k_seqlens_np).to( + device=self.runner.device, + dtype=torch.int32, + non_blocking=True), + dim=0) + + local_attn_metadata = \ AiterFlashAttentionMetadata.LocalAttentionMetadata( local_query_start_loc=local_query_start_loc, @@ -260,6 +271,7 @@ class AiterFlashAttentionMetadataBuilder: local_block_table=virt_block_table_tensor, local_max_query_len=local_max_query_len, local_max_seq_len=local_max_seq_len, + local_cu_seq_lens=local_cu_seq_lens, local_scheduler_metadata=local_scheduler_metadata, ) @@ -368,6 +380,7 @@ class AiterFlashAttentionMetadata: local_block_table: torch.Tensor local_max_query_len: int local_max_seq_len: int + local_cu_seq_lens: torch.Tensor local_scheduler_metadata: Optional[torch.Tensor] local_attn_metadata: Optional[LocalAttentionMetadata] = None @@ -387,6 +400,7 @@ class AiterFlashAttentionImpl(AttentionImpl): blocksparse_params: Optional[dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: AttentionType = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[int] = None, use_irope: bool = False, ) -> None: if blocksparse_params is not None: @@ -408,6 +422,7 @@ class AiterFlashAttentionImpl(AttentionImpl): # In flash-attn, setting logits_soft_cap as 0 means no soft cap. logits_soft_cap = 0. self.logits_soft_cap = logits_soft_cap + self.kv_sharing_target_layer_name = kv_sharing_target_layer_name assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads @@ -478,22 +493,25 @@ class AiterFlashAttentionImpl(AttentionImpl): # performance to make sure it does not introduce any overhead. num_actual_tokens = attn_metadata.num_actual_tokens - # Reshape the input keys and values and store them in the cache. - # NOTE(woosuk): Here, key and value are padded while slot_mapping is - # not padded. However, we don't need to do key[:num_actual_tokens] and - # value[:num_actual_tokens] because the reshape_and_cache_flash op uses - # the slot_mapping's shape to determine the number of actual tokens. key_cache, value_cache = kv_cache.unbind(0) - torch.ops._C_cache_ops.reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) + if self.kv_sharing_target_layer_name is None: + # Reshape the input keys and values and store them in the cache. + # Skip this if sharing KV cache with an earlier attention layer. + # NOTE(woosuk): Here, key and value are padded while slot_mapping is + # not padded. However, we don't need to do key[:num_actual_tokens] + # and value[:num_actual_tokens] because the reshape_and_cache_flash + # op uses the slot_mapping's shape to determine the number of + # actual tokens. + torch.ops._C_cache_ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) if self.kv_cache_dtype.startswith("fp8"): key_cache = key_cache.view(torch.float8_e4m3fnuz) @@ -541,7 +559,8 @@ class AiterFlashAttentionImpl(AttentionImpl): alibi_slopes=self.alibi_slopes, window_size=self.sliding_window, block_table=block_table, - cu_seqlens_k=cu_seq_lens, + cu_seqlens_k=(cu_seq_lens if not use_local_attn else + local_metadata.local_cu_seq_lens), ) _, num_heads, head_size = query.shape From 1f5d178e9cc02a49e9d734420b0c0afaff2fd7af Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Thu, 26 Jun 2025 23:32:22 +0900 Subject: [PATCH 017/175] Revert "[Bugfix] default set cuda_graph_sizes to max_num_seqs for v1 engine" (#20128) --- vllm/config.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index e90ad5e9c8b65..96ea47a0dce38 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2042,12 +2042,11 @@ class SchedulerConfig: NOTE: This will be replaced by speculative config in the future; it is present to enable correctness tests until then.""" - cuda_graph_sizes: list[int] = field(default_factory=list) - """Cuda graph capture sizes - 1. if none provided, then default set to [max_num_seqs] - 2. if one value is provided, then the capture list would follow the + cuda_graph_sizes: list[int] = field(default_factory=lambda: [512]) + """Cuda graph capture sizes, default is 512. + 1. if one value is provided, then the capture list would follow the pattern: [1, 2, 4] + [i for i in range(8, cuda_graph_sizes + 1, 8)] - 3. more than one value (e.g. 1 2 128) is provided, then the capture list + 2. more than one value (e.g. 1 2 128) is provided, then the capture list will follow the provided list.""" delay_factor: float = 0.0 @@ -2212,10 +2211,6 @@ class SchedulerConfig: self.max_num_partial_prefills, self.max_long_partial_prefills, self.long_prefill_token_threshold) - # If cuda_graph_sizes is not specified, default set to [max_num_seqs]. - if not self.cuda_graph_sizes: - self.cuda_graph_sizes = [self.max_num_seqs] - @model_validator(mode='after') def _verify_args(self) -> Self: if (self.max_num_batched_tokens < self.max_model_len From c894c5dc1ffadee8979f3a051bfccea0441ae09a Mon Sep 17 00:00:00 2001 From: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Date: Thu, 26 Jun 2025 10:33:13 -0400 Subject: [PATCH 018/175] [Bug Fix] Fix address/port already in use error for deep_ep test (#20094) Signed-off-by: yewentao256 --- tests/kernels/moe/deepep_utils.py | 5 ++++- vllm/model_executor/layers/fused_moe/utils.py | 9 +++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/tests/kernels/moe/deepep_utils.py b/tests/kernels/moe/deepep_utils.py index 117f1babdf62a..e4cd8386e1020 100644 --- a/tests/kernels/moe/deepep_utils.py +++ b/tests/kernels/moe/deepep_utils.py @@ -4,6 +4,7 @@ DeepEP test utilities """ import dataclasses import importlib +import os import traceback from typing import Callable, Optional @@ -13,6 +14,8 @@ from torch.multiprocessing import ( spawn) # pyright: ignore[reportPrivateImportUsage] from typing_extensions import Concatenate, ParamSpec +from vllm.model_executor.layers.fused_moe.utils import find_free_port + has_deep_ep = importlib.util.find_spec("deep_ep") is not None if has_deep_ep: from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501 @@ -92,7 +95,7 @@ def parallel_launch( world_size, world_size, 0, - "tcp://localhost:29500", + f"tcp://{os.getenv('LOCALHOST', 'localhost')}:{find_free_port()}", worker, ) + args, nprocs=world_size, diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index 692482c2ea692..8f3191db680fd 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import socket +from contextlib import closing from math import prod from typing import Optional @@ -96,3 +98,10 @@ def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype) else: return m[idx, ...] + + +def find_free_port(): + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: + s.bind(('', 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return s.getsockname()[1] \ No newline at end of file From 0907d507bf389b908a267155de4162d725ae1c54 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Thu, 26 Jun 2025 22:34:17 +0800 Subject: [PATCH 019/175] [Doc] Automatically signed-off by PyCharm (#20120) Signed-off-by: wang.yuqi --- docs/contributing/README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/contributing/README.md b/docs/contributing/README.md index c0c338b426951..d472366c43b5a 100644 --- a/docs/contributing/README.md +++ b/docs/contributing/README.md @@ -151,6 +151,11 @@ the terms of the DCO. Using `-s` with `git commit` will automatically add this header. +!!! tip + If you develop using PyCharm, there is a `Show Commit Options` icon to the right of the `Commit and Push...` button in the `Commit` window. + Opening it will bring up a `git` window where you can modify the `Author` and enable `Sign-off commit`. + This ensures that all your commits are automatically signed-off by PyCharm. + ### PR Title and Classification Only specific types of PRs will be reviewed. The PR title is prefixed From 6393b039865b35c79c5c397e5dca0218d3c26622 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 26 Jun 2025 23:18:36 +0800 Subject: [PATCH 020/175] [Doc] Auto sign-off for VSCode (#20132) Signed-off-by: DarkLight1337 --- docs/contributing/README.md | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/docs/contributing/README.md b/docs/contributing/README.md index d472366c43b5a..83525436be139 100644 --- a/docs/contributing/README.md +++ b/docs/contributing/README.md @@ -152,9 +152,12 @@ the terms of the DCO. Using `-s` with `git commit` will automatically add this header. !!! tip - If you develop using PyCharm, there is a `Show Commit Options` icon to the right of the `Commit and Push...` button in the `Commit` window. - Opening it will bring up a `git` window where you can modify the `Author` and enable `Sign-off commit`. - This ensures that all your commits are automatically signed-off by PyCharm. + You can enable automatic sign-off via your IDE: + + - **PyCharm**: Click on the `Show Commit Options` icon to the right of the `Commit and Push...` button in the `Commit` window. + It will bring up a `git` window where you can modify the `Author` and enable `Sign-off commit`. + - **VSCode**: Open the [Settings editor](https://code.visualstudio.com/docs/configure/settings) + and enable the `Git: Always Sign Off` (`git.alwaysSignOff`) field. ### PR Title and Classification From 34878a0b481bbbb65bf17923b1eae5ebbb56f896 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 26 Jun 2025 23:18:49 +0800 Subject: [PATCH 021/175] [Doc] Rename page titles (#20130) Signed-off-by: DarkLight1337 --- docs/contributing/incremental_build.md | 2 +- docs/contributing/model/README.md | 6 +++--- docs/contributing/model/basic.md | 2 +- docs/contributing/model/registration.md | 2 +- docs/contributing/model/tests.md | 2 +- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/contributing/incremental_build.md b/docs/contributing/incremental_build.md index 8efa34825ecaf..14c3aaead51ed 100644 --- a/docs/contributing/incremental_build.md +++ b/docs/contributing/incremental_build.md @@ -1,4 +1,4 @@ -# Incremental Compilation Workflow for vLLM Development +# Incremental Compilation Workflow When working on vLLM's C++/CUDA kernels located in the `csrc/` directory, recompiling the entire project with `uv pip install -e .` for every change can be time-consuming. An incremental compilation workflow using CMake allows for faster iteration by only recompiling the necessary components after an initial setup. This guide details how to set up and use such a workflow, which complements your editable Python installation. diff --git a/docs/contributing/model/README.md b/docs/contributing/model/README.md index 82541924bc028..63abb7991050d 100644 --- a/docs/contributing/model/README.md +++ b/docs/contributing/model/README.md @@ -14,9 +14,9 @@ However, this can be more complex for models that include new operators (e.g., a Read through these pages for a step-by-step guide: -- [Implementing a Basic Model](basic.md) -- [Registering a Model to vLLM](registration.md) -- [Writing Unit Tests](tests.md) +- [Basic Model](basic.md) +- [Registering a Model](registration.md) +- [Unit Testing](tests.md) - [Multi-Modal Support](multimodal.md) !!! tip diff --git a/docs/contributing/model/basic.md b/docs/contributing/model/basic.md index 644d21482ef6f..d552cd06be204 100644 --- a/docs/contributing/model/basic.md +++ b/docs/contributing/model/basic.md @@ -1,5 +1,5 @@ --- -title: Implementing a Basic Model +title: Basic Model --- [](){ #new-model-basic } diff --git a/docs/contributing/model/registration.md b/docs/contributing/model/registration.md index a6dc1e32dfb95..758caa72cd4a0 100644 --- a/docs/contributing/model/registration.md +++ b/docs/contributing/model/registration.md @@ -1,5 +1,5 @@ --- -title: Registering a Model to vLLM +title: Registering a Model --- [](){ #new-model-registration } diff --git a/docs/contributing/model/tests.md b/docs/contributing/model/tests.md index a8cb457453b91..c7bcc02a8b809 100644 --- a/docs/contributing/model/tests.md +++ b/docs/contributing/model/tests.md @@ -1,5 +1,5 @@ --- -title: Writing Unit Tests +title: Unit Testing --- [](){ #new-model-tests } From 0bceac9810a5f51b06bf3e4cace182b639326ed2 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Thu, 26 Jun 2025 11:19:46 -0400 Subject: [PATCH 022/175] Spam folks if config.py changes (#20131) Signed-off-by: Tyler Michael Smith --- .github/CODEOWNERS | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index e98ccd035ee90..da7f89747a16d 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -18,6 +18,10 @@ /vllm/entrypoints @aarnphm CMakeLists.txt @tlrmchlsmth +# Any change to the VllmConfig changes can have a large user-facing impact, +# so spam a lot of people +/vllm/config.py @simon-mo @WoosukKwon @youkaichao @robertgshaw2-redhat @mgoin @tlrmchlsmth @houseroad @hmellor + # vLLM V1 /vllm/v1 @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat /vllm/v1/structured_output @mgoin @russellb @aarnphm From b69781f107b7ad847a351f584178cfafbee2b32a Mon Sep 17 00:00:00 2001 From: Kunshang Ji Date: Fri, 27 Jun 2025 00:27:18 +0800 Subject: [PATCH 023/175] [Hardware][Intel GPU] Add v1 Intel GPU support with Flash attention backend. (#19560) Signed-off-by: Kunshang Ji --- .../scripts/hardware_ci/run-xpu-test.sh | 1 + docker/Dockerfile.xpu | 1 + requirements/xpu.txt | 1 + vllm/_ipex_ops.py | 105 +++++++++++ vllm/attention/utils/fa_utils.py | 15 +- vllm/executor/ray_distributed_executor.py | 2 +- vllm/platforms/xpu.py | 102 +++++++---- vllm/v1/attention/backends/flash_attn.py | 12 +- vllm/v1/worker/xpu_model_runner.py | 32 ++++ vllm/v1/worker/xpu_worker.py | 164 ++++++++++++++++++ 10 files changed, 393 insertions(+), 42 deletions(-) create mode 100644 vllm/v1/worker/xpu_model_runner.py create mode 100644 vllm/v1/worker/xpu_worker.py diff --git a/.buildkite/scripts/hardware_ci/run-xpu-test.sh b/.buildkite/scripts/hardware_ci/run-xpu-test.sh index f54010c4231f9..827649bfcf548 100644 --- a/.buildkite/scripts/hardware_ci/run-xpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-xpu-test.sh @@ -28,4 +28,5 @@ docker run \ sh -c ' VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m -tp 2 + VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager ' diff --git a/docker/Dockerfile.xpu b/docker/Dockerfile.xpu index 681102b9d18be..466ba98333635 100644 --- a/docker/Dockerfile.xpu +++ b/docker/Dockerfile.xpu @@ -35,6 +35,7 @@ RUN --mount=type=bind,source=.git,target=.git \ if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi ENV VLLM_TARGET_DEVICE=xpu +ENV VLLM_WORKER_MULTIPROC_METHOD=spawn RUN --mount=type=cache,target=/root/.cache/pip \ --mount=type=bind,source=.git,target=.git \ diff --git a/requirements/xpu.txt b/requirements/xpu.txt index 3cb6a4a8addac..0d95dc57152de 100644 --- a/requirements/xpu.txt +++ b/requirements/xpu.txt @@ -9,6 +9,7 @@ setuptools>=77.0.3,<80.0.0 wheel jinja2>=3.1.6 datasets # for benchmark scripts +numba == 0.60.0 # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding torch==2.7.0+xpu torchaudio diff --git a/vllm/_ipex_ops.py b/vllm/_ipex_ops.py index ae63e06030dd1..2be02411ec05e 100644 --- a/vllm/_ipex_ops.py +++ b/vllm/_ipex_ops.py @@ -228,6 +228,111 @@ class ipex_ops: ipex.llm.modules.PagedAttention.reshape_and_cache( key, value, key_cache, value_cache, slot_mapping) + @staticmethod + def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: Optional[torch.Tensor] = None, + v_scale: Optional[torch.Tensor] = None, + k_scale_float: float = 1.0, + v_scale_float: float = 1.0, + ) -> None: + assert kv_cache_dtype == "auto" + # TODO: support FP8 kv cache. + ipex.llm.modules.PagedAttention.reshape_and_cache_flash( + key, value, key_cache, value_cache, slot_mapping) + + @staticmethod + def flash_attn_varlen_func( + out: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: torch.Tensor, + seqused_k: torch.Tensor, # we don't support this in ipex kernel + max_seqlen_q: int, + max_seqlen_k: int, + softmax_scale: float, + causal: bool, + block_table: torch.Tensor, + alibi_slopes: Optional[torch.Tensor], + window_size: Optional[list[int]] = None, + softcap: Optional[float] = 0.0, + cu_seqlens_k: Optional[torch.Tensor] = None, + # The following parameters are not used in ipex kernel currently, + # we keep API compatible to CUDA's. + scheduler_metadata=None, + fa_version: int = 2, + q_descale=None, + k_descale=None, + v_descale=None, + ): + if cu_seqlens_k is None: + # cu_seqlens_k is not used in ipex kernel. + cu_seqlens_k = torch.cumsum(seqused_k, dim=0) + cu_seqlens_k = torch.cat([ + torch.tensor([0], device=seqused_k.device, dtype=torch.int32), + cu_seqlens_k + ]).to(torch.int32) + + real_window_size: tuple[int, int] + if window_size is None: + real_window_size = (-1, -1) + else: + assert len(window_size) == 2 + real_window_size = (window_size[0], window_size[1]) + return ipex.llm.modules.PagedAttention.flash_attn_varlen_func( + out, + q.contiguous(), + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + causal, + block_table, + alibi_slopes, + softcap=softcap, + window_size_left=real_window_size[0], + window_size_right=real_window_size[1], + k_scale=1.0, + v_scale=1.0, + ) + + @staticmethod + def get_scheduler_metadata( + batch_size, + max_seqlen_q, + max_seqlen_k, + num_heads_q, + num_heads_kv, + headdim, + cache_seqlens: torch.Tensor, + qkv_dtype=torch.bfloat16, + headdim_v=None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k_new: Optional[torch.Tensor] = None, + cache_leftpad: Optional[torch.Tensor] = None, + page_size: Optional[int] = None, + max_seqlen_k_new=0, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + has_softcap=False, + num_splits=0, # Can be tuned for speed + pack_gqa=None, # Can be tuned for speed + sm_margin=0, # Can be tuned if some SMs are used for communication + ) -> None: + logger.warning_once( + "get_scheduler_metadata is not implemented for ipex_ops, " + "returning None.") + return None + @staticmethod def copy_blocks(key_caches: list[torch.Tensor], value_caches: list[torch.Tensor], diff --git a/vllm/attention/utils/fa_utils.py b/vllm/attention/utils/fa_utils.py index 69cde06fd72e9..36fd2d231bc5f 100644 --- a/vllm/attention/utils/fa_utils.py +++ b/vllm/attention/utils/fa_utils.py @@ -4,13 +4,27 @@ from typing import Optional from vllm import envs from vllm.logger import init_logger +from vllm.platforms import current_platform logger = init_logger(__name__) +if current_platform.is_cuda(): + from vllm import _custom_ops as ops + reshape_and_cache_flash = ops.reshape_and_cache_flash + from vllm.vllm_flash_attn import (flash_attn_varlen_func, + get_scheduler_metadata) +elif current_platform.is_xpu(): + from vllm._ipex_ops import ipex_ops as ops + reshape_and_cache_flash = ops.reshape_and_cache_flash + flash_attn_varlen_func = ops.flash_attn_varlen_func + get_scheduler_metadata = ops.get_scheduler_metadata + def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]: # import here to avoid circular dependencies from vllm.platforms import current_platform + if current_platform.is_xpu(): + return 2 try: from vllm.vllm_flash_attn.flash_attn_interface import ( fa_version_unsupported_reason, is_fa_version_supported) @@ -50,6 +64,5 @@ def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]: def flash_attn_supports_fp8() -> bool: - from vllm.platforms import current_platform return get_flash_attn_version() == 3 and \ current_platform.get_device_capability().major == 9 diff --git a/vllm/executor/ray_distributed_executor.py b/vllm/executor/ray_distributed_executor.py index a3f05ec5ea3f2..84e8ddd8e274d 100644 --- a/vllm/executor/ray_distributed_executor.py +++ b/vllm/executor/ray_distributed_executor.py @@ -73,7 +73,7 @@ class RayDistributedExecutor(DistributedExecutorBase): def _init_executor(self) -> None: self.forward_dag: Optional[ray.dag.CompiledDAG] = None - if envs.VLLM_USE_V1: + if envs.VLLM_USE_V1 and not current_platform.is_xpu(): # V1 uses SPMD worker and compiled DAG os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1" os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1" diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 73f6f3d417671..f361f5e2616ef 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -1,18 +1,21 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os from typing import TYPE_CHECKING, Optional import torch +import vllm.envs as envs from vllm.logger import init_logger from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS from .interface import DeviceCapability, Platform, PlatformEnum, _Backend if TYPE_CHECKING: - from vllm.config import VllmConfig + from vllm.config import ModelConfig, VllmConfig else: + ModelConfig = None VllmConfig = None logger = init_logger(__name__) @@ -35,8 +38,13 @@ class XPUPlatform(Platform): use_mla: bool) -> str: if selected_backend != _Backend.IPEX: logger.info("Cannot use %s backend on XPU.", selected_backend) - logger.info("Using IPEX attention backend.") - return "vllm.attention.backends.ipex_attn.IpexAttnBackend" + use_v1 = envs.VLLM_USE_V1 + if use_v1: + logger.info("Using Flash Attention backend on V1 engine.") + return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" + else: + logger.info("Using IPEX attention backend.") + return "vllm.attention.backends.ipex_attn.IpexAttnBackend" @classmethod def get_device_capability( @@ -67,25 +75,27 @@ class XPUPlatform(Platform): @classmethod def check_and_update_config(cls, vllm_config: VllmConfig) -> None: cache_config = vllm_config.cache_config + # in V1(or with ipex chunked prefill) block_size is 64 if cache_config and cache_config.block_size is None: - cache_config.block_size = 16 + if envs.VLLM_USE_V1: + cache_config.block_size = 64 + else: + cache_config.block_size = 16 - # check and update model config - model_config = vllm_config.model_config - if model_config.dtype == torch.bfloat16: - bf16_supported = cls.device_support_bf16() - if not bf16_supported: + # Instances created using VllmConfig() typically have model_config as + # None by default. The modification involves adding a check to prevent + # potential null exceptions check and update model config. + if vllm_config.model_config is not None: + model_config = vllm_config.model_config + if model_config.dtype == torch.bfloat16: + bf16_supported = cls.device_support_bf16() + if not bf16_supported: + model_config.dtype = torch.float16 + if not model_config.enforce_eager: logger.warning( - "bfloat16 is only supported on Intel Data Center GPU, " - "Intel Arc GPU is not supported yet. Your device is %s," - " which is not supported. will fallback to float16", - cls.get_device_name()) - model_config.dtype = torch.float16 - if not model_config.enforce_eager: - logger.warning( - "CUDA graph is not supported on XPU, fallback to the eager " - "mode.") - model_config.enforce_eager = True + "CUDA graph is not supported on XPU, fallback to the eager " + "mode.") + model_config.enforce_eager = True if vllm_config.speculative_config is not None: raise NotImplementedError( @@ -96,21 +106,27 @@ class XPUPlatform(Platform): # check and update parallel config parallel_config = vllm_config.parallel_config - if parallel_config.worker_cls == "auto": + if envs.VLLM_USE_V1: + parallel_config.worker_cls =\ + "vllm.v1.worker.xpu_worker.XPUWorker" + else: parallel_config.worker_cls = "vllm.worker.xpu_worker.XPUWorker" if parallel_config.distributed_executor_backend is None: - parallel_config.distributed_executor_backend = "ray" + if parallel_config.world_size > 1: + parallel_config.distributed_executor_backend = "ray" + else: + parallel_config.distributed_executor_backend = "uni" elif parallel_config.distributed_executor_backend == "mp": # FIXME(kunshang): # spawn needs calling `if __name__ == '__main__':`` # fork is not supported for xpu start new process. - logger.error( - "Both start methods (spawn and fork) have issue " - "on XPU if you use mp backend, setting it to ray instead.") - parallel_config.distributed_executor_backend = "ray" - - elif parallel_config.distributed_executor_backend != "ray": + if envs.VLLM_WORKER_MULTIPROC_METHOD != "spawn": + os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + logger.warning( + "Please use spawn as start method if you want to use mp.") + elif parallel_config.distributed_executor_backend != "ray" and \ + parallel_config.distributed_executor_backend != "uni": logger.warning( "%s is not supported on XPU, fallback to ray distributed" " executor backend.", @@ -142,15 +158,35 @@ class XPUPlatform(Platform): @classmethod def device_support_bf16(cls) -> bool: device_name = cls.get_device_name().lower() - if device_name.count("arc") > 0: + if cls.is_client_gpu_a770(): + logger.warning("Intel Arc A770 have bfloat16 accuracy known issue," + " fallback to float16") return False - elif device_name.count("data center gpu") > 0: - return True else: - logger.warning("Unknown device name %s, always use float16", - device_name) - return False + logger.info( + "Device name %s supports bfloat16. Please file an issue " + "if you encounter any accuracy problems with bfloat16.", + device_name) + return True + + @classmethod + def is_data_center_gpu(cls) -> bool: + device_name = cls.get_device_name().lower() + return device_name.count("data center gpu") > 0 + + @classmethod + def is_client_gpu_a770(cls) -> bool: + device_name = cls.get_device_name().lower() + return device_name.count("a770") > 0 @classmethod def get_device_communicator_cls(cls) -> str: return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator" # noqa + + @classmethod + def supports_v1(cls, model_config: ModelConfig) -> bool: + return True + + @classmethod + def device_count(cls) -> int: + return torch.xpu.device_count() diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index ef65d2ea36e4f..42b5997f085b1 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -14,10 +14,12 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.layer import Attention from vllm.attention.ops.merge_attn_states import merge_attn_states from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8, - get_flash_attn_version) + flash_attn_varlen_func, + get_flash_attn_version, + get_scheduler_metadata, + reshape_and_cache_flash) from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.logger import init_logger -from vllm.platforms import current_platform from vllm.utils import cdiv from vllm.v1.attention.backends.utils import ( AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout, @@ -28,10 +30,6 @@ from vllm.v1.worker.block_table import BlockTable if TYPE_CHECKING: from vllm.v1.worker.gpu_model_runner import GPUModelRunner -if current_platform.is_cuda(): - from vllm.vllm_flash_attn import (flash_attn_varlen_func, - get_scheduler_metadata) - logger = init_logger(__name__) @@ -443,7 +441,7 @@ class FlashAttentionImpl(AttentionImpl): # and value[:num_actual_tokens] because the reshape_and_cache_flash # op uses the slot_mapping's shape to determine the number of # actual tokens. - torch.ops._C_cache_ops.reshape_and_cache_flash( + reshape_and_cache_flash( key, value, key_cache, diff --git a/vllm/v1/worker/xpu_model_runner.py b/vllm/v1/worker/xpu_model_runner.py new file mode 100644 index 0000000000000..55d116dcd4968 --- /dev/null +++ b/vllm/v1/worker/xpu_model_runner.py @@ -0,0 +1,32 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import TYPE_CHECKING + +import torch + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.v1.worker.gpu_model_runner import GPUModelRunner + +if TYPE_CHECKING: + pass + +logger = init_logger(__name__) + + +class XPUModelRunner(GPUModelRunner): + """A model runner for XPU devices.""" + + def __init__( + self, + vllm_config: VllmConfig, + device: torch.device, + ): + super().__init__(vllm_config, device) + # FIXME: To be verified. + self.cascade_attn_enabled = False + + def _init_device_properties(self) -> None: + pass + + def _sync_device(self) -> None: + torch.xpu.synchronize() diff --git a/vllm/v1/worker/xpu_worker.py b/vllm/v1/worker/xpu_worker.py new file mode 100644 index 0000000000000..d9ea03986566b --- /dev/null +++ b/vllm/v1/worker/xpu_worker.py @@ -0,0 +1,164 @@ +# SPDX-License-Identifier: Apache-2.0 +import os + +import torch +import torch.distributed + +import vllm.envs as envs +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor import set_random_seed +from vllm.platforms import current_platform +from vllm.v1.worker.gpu_worker import (Worker, + init_worker_distributed_environment) +from vllm.v1.worker.xpu_model_runner import XPUModelRunner + +logger = init_logger(__name__) + + +class XPUWorker(Worker): + """A XPU worker class.""" + + def __init__( + self, + vllm_config: VllmConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + is_driver_worker: bool = False, + ): + super().__init__(vllm_config, local_rank, rank, + distributed_init_method, is_driver_worker) + device_config = self.device_config + assert device_config.device_type == "xpu" + assert current_platform.is_xpu() + + # Torch profiler. Enabled and configured through env vars: + # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace + if envs.VLLM_TORCH_PROFILER_DIR: + torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR + logger.info("Profiling enabled. Traces will be saved to: %s", + torch_profiler_trace_dir) + self.profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.XPU, + ], + with_stack=True, + on_trace_ready=torch.profiler.tensorboard_trace_handler( + torch_profiler_trace_dir, use_gzip=True)) + else: + self.profiler = None + + # we provide this function due to `torch.xpu.mem_get_info()` doesn't + # return correct free_gpu_memory on intel client GPU. We need to + # calculate/estiamte it. + def xpu_get_mem_info(self): + if current_platform.is_data_center_gpu(): + return torch.xpu.mem_get_info() + else: + _, total_gpu_memory = torch.xpu.mem_get_info() + # FIXME: memory_allocated() doesn't count non-torch allocations, + # and we don't have any API to get it. so we mark it as 128MB. + used_memory = torch.xpu.memory_allocated() + non_torch_allocations = 128 * 1024 * 1024 + free_gpu_memory = total_gpu_memory - (used_memory + + non_torch_allocations) + return free_gpu_memory, total_gpu_memory + + @torch.inference_mode() + def determine_available_memory(self) -> int: + """Profiles the peak memory usage of the model to determine how many + KV blocks may be allocated without OOMs. + The engine will first conduct a profiling of the existing memory usage. + Then, it calculate the maximum possible number of GPU and CPU blocks + that can be allocated with the remaining free memory. + .. tip:: + You may limit the usage of GPU memory + by adjusting the `gpu_memory_utilization` parameter. + """ + # Profile the memory usage of the model and get the maximum number of + # cache blocks that can be allocated with the remaining free memory. + torch.xpu.empty_cache() + torch.xpu.reset_peak_memory_stats() + + free_gpu_memory, total_gpu_memory = torch.xpu.mem_get_info() + current_allocated_bytes = torch.xpu.memory_allocated() + msg = ("Before memory profiling run, " + f"total GPU memory: {total_gpu_memory / 1024**2:.2f} MB, " + f"model load takes {current_allocated_bytes / 1024**2:.2f} MB, " + f"free gpu memory is {free_gpu_memory / 1024**2:.2f} MB.") + logger.info(msg) + # Execute a forward pass with dummy inputs to profile the memory usage + # of the model. + self.model_runner.profile_run() + + free_gpu_memory, _ = self.xpu_get_mem_info() + # NOTE(woosuk): Here we assume that the other processes using the same + # GPU did not change their memory usage during the profiling. + assert self.init_gpu_memory > free_gpu_memory, ( + "Error in memory profiling. " + f"Initial free memory {self.init_gpu_memory}, current free memory" + f" {free_gpu_memory}. This happens when the GPU memory was " + "not properly cleaned up before initializing the vLLM instance.") + + # Get the peak memory allocation recorded by torch + peak_memory = torch.xpu.memory_stats()["allocated_bytes.all.peak"] + + torch.xpu.empty_cache() + torch_allocated_bytes = torch.xpu.memory_stats( + )["allocated_bytes.all.current"] + total_allocated_bytes = self.xpu_get_mem_info( + )[1] - self.xpu_get_mem_info()[0] + + non_torch_allocations = total_allocated_bytes - torch_allocated_bytes + if non_torch_allocations > 0: + peak_memory += non_torch_allocations + available_kv_cache_memory = ( + total_gpu_memory * self.cache_config.gpu_memory_utilization - + peak_memory) + + msg = ("After memory profiling run, " + f"peak memory usage is {peak_memory / 1024**2:.2f} MB," + f"torch mem is {torch_allocated_bytes / 1024**2:.2f} MB, " + f"non-torch mem is {non_torch_allocations / 1024**2:.2f} MB, " + f"free gpu memory is {free_gpu_memory / 1024**2:.2f} MB.") + logger.info(msg) + + return int(available_kv_cache_memory) + + def init_device(self): + if self.device_config.device.type == "xpu" and current_platform.is_xpu( + ): + self.device = torch.device(f"xpu:{self.local_rank}") + torch.xpu.set_device(self.device) + torch.xpu.empty_cache() + self.init_gpu_memory = torch.xpu.get_device_properties( + self.local_rank).total_memory + else: + raise RuntimeError( + f"Not support device type: {self.device_config.device}") + + ENV_CCL_ZE_IPC_EXCHANGE = os.getenv("CCL_ZE_IPC_EXCHANGE", "drmfd") + ENV_CCL_ATL_TRANSPORT = os.getenv("CCL_ATL_TRANSPORT", "ofi") + ENV_LOCAL_WORLD_SIZE = os.getenv("LOCAL_WORLD_SIZE", + str(self.parallel_config.world_size)) + os.environ["CCL_ZE_IPC_EXCHANGE"] = ENV_CCL_ZE_IPC_EXCHANGE + os.environ["CCL_ATL_TRANSPORT"] = ENV_CCL_ATL_TRANSPORT + os.environ["LOCAL_WORLD_SIZE"] = ENV_LOCAL_WORLD_SIZE + os.environ["LOCAL_RANK"] = str(self.local_rank) + dist_backend = "ccl" + + init_worker_distributed_environment(self.vllm_config, self.rank, + self.distributed_init_method, + self.local_rank, dist_backend) + + # global all_reduce needed for overall oneccl warm up + torch.distributed.all_reduce(torch.zeros(1).xpu()) + + # Set random seed. + set_random_seed(self.model_config.seed) + + # Construct the model runner + self.model_runner = XPUModelRunner( # type: ignore + self.vllm_config, self.device) From 04e1642e3251fc575d104c84782fafea348cfbaf Mon Sep 17 00:00:00 2001 From: Chengji Yao Date: Thu, 26 Jun 2025 10:01:37 -0700 Subject: [PATCH 024/175] [TPU] add kv cache update kernel (#19928) Signed-off-by: Chengji Yao --- .../scripts/hardware_ci/run-tpu-v1-test.sh | 2 + tests/v1/tpu/test_kv_cache_update_kernel.py | 71 ++++++++++ tests/v1/tpu/test_pallas.py | 3 +- vllm/attention/ops/pallas_kv_cache_update.py | 117 ++++++++++++++++ vllm/v1/attention/backends/pallas.py | 55 +++++++- vllm/v1/worker/tpu_model_runner.py | 132 +++++++++++++----- 6 files changed, 342 insertions(+), 38 deletions(-) create mode 100644 tests/v1/tpu/test_kv_cache_update_kernel.py create mode 100644 vllm/attention/ops/pallas_kv_cache_update.py diff --git a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh index a2a5c2a02cbb9..90cad506ab1e9 100755 --- a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh +++ b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh @@ -159,6 +159,8 @@ run_and_track_test 14 "test_tpu_qkv_linear.py" \ "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_tpu_qkv_linear.py" run_and_track_test 15 "test_spmd_model_weight_loading.py" \ "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_spmd_model_weight_loading.py" +run_and_track_test 16 "test_kv_cache_update_kernel.py" \ + "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_kv_cache_update_kernel.py" # After all tests have been attempted, exit with the overall status. if [ "$overall_script_exit_code" -ne 0 ]; then diff --git a/tests/v1/tpu/test_kv_cache_update_kernel.py b/tests/v1/tpu/test_kv_cache_update_kernel.py new file mode 100644 index 0000000000000..63a1f6777e4df --- /dev/null +++ b/tests/v1/tpu/test_kv_cache_update_kernel.py @@ -0,0 +1,71 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import numpy as np +import pytest +import torch +import torch_xla + +import vllm.v1.attention.backends.pallas # noqa: F401 +from vllm.platforms import current_platform + + +@pytest.mark.skipif(not current_platform.is_tpu(), + reason="This is a test for TPU only") +@pytest.mark.parametrize("page_size", [32, 33]) +@pytest.mark.parametrize("combined_kv_head_num", [2, 16]) +@pytest.mark.parametrize("head_dim", [128, 256]) +@pytest.mark.parametrize("num_slices_per_block", [4, 8]) +def test_kv_cache_update_kernel(page_size: int, combined_kv_head_num: int, + head_dim: int, num_slices_per_block: int): + page_num = 1000 + padded_num_tokens = 128 + kv_cache_cpu = torch.zeros( + (page_num * page_size, combined_kv_head_num, head_dim), + dtype=torch.bfloat16, + device="cpu") + kv_cache_xla = kv_cache_cpu.to(torch_xla.device()) + new_kv_cpu = torch.randn( + (padded_num_tokens, combined_kv_head_num, head_dim), + dtype=torch.bfloat16, + device="cpu") + new_kv_xla = new_kv_cpu.to(torch_xla.device()) + slice_lens = np.array([7, page_size, page_size, 1, 1, 1, 9], + dtype=np.int32) + kv_cache_start_indices = np.array([ + page_size * 2 - 7, page_size * 2, page_size * 3, page_size * 4 + 6, + page_size * 5 + 7, page_size * 6 + 8, page_size * 15 + 3 + ], + dtype=np.int32) + new_kv_cache_indices = np.concatenate( + [np.array([0], dtype=np.int32), + np.cumsum(slice_lens[:-1])]) + slot_mapping = np.stack( + [kv_cache_start_indices, new_kv_cache_indices, slice_lens], axis=1) + padded_size = (slot_mapping.shape[0] + num_slices_per_block - + 1) // num_slices_per_block * num_slices_per_block + slot_mapping = np.pad(slot_mapping, + [[0, padded_size - slot_mapping.shape[0]], [0, 0]], + constant_values=0) + slot_mapping = np.transpose(slot_mapping) + slot_mapping_cpu = torch.tensor(slot_mapping, + device="cpu", + dtype=torch.int32) + slot_mapping_xla = slot_mapping_cpu.to(torch_xla.device()) + torch_xla.sync() + + torch.ops.xla.dynamo_set_buffer_donor_(kv_cache_xla, True) + new_kv_cache_xla = torch.ops.xla.kv_cache_update_op( + new_kv_xla, slot_mapping_xla, kv_cache_xla, page_size, + num_slices_per_block) + kv_cache_xla.copy_(new_kv_cache_xla) + torch_xla.sync() + + for ni, ci, sl in zip(new_kv_cache_indices, kv_cache_start_indices, + slice_lens): + kv_cache_cpu[ci:ci + sl, :, :] = new_kv_cpu[ni:ni + sl, :, :] + + assert torch.allclose(kv_cache_xla.cpu(), + kv_cache_cpu, + atol=1e-4, + rtol=1e-4) diff --git a/tests/v1/tpu/test_pallas.py b/tests/v1/tpu/test_pallas.py index 3a9d80847a16b..e279edfffbc72 100644 --- a/tests/v1/tpu/test_pallas.py +++ b/tests/v1/tpu/test_pallas.py @@ -47,7 +47,7 @@ def test_ragged_paged_attention(): key = torch.zeros(num_tokens, num_kv_heads * head_size) value = torch.zeros(num_tokens, num_kv_heads * head_size) kv_cache = torch.zeros(num_blocks, block_size, num_kv_heads * 2, head_size) - slot_mapping = torch.zeros(num_tokens, dtype=torch.int64) + slot_mapping = torch.zeros((3, num_tokens), dtype=torch.int64) max_num_reqs = 8 max_num_blocks_per_req = 8 block_tables = torch.zeros((max_num_reqs, max_num_blocks_per_req), @@ -65,6 +65,7 @@ def test_ragged_paged_attention(): context_lens=context_lens, query_start_loc=query_start_loc, num_seqs=num_seqs, + num_slices_per_kv_cache_update_block=8, ) with patch("torch.ops.xla.ragged_paged_attention" diff --git a/vllm/attention/ops/pallas_kv_cache_update.py b/vllm/attention/ops/pallas_kv_cache_update.py new file mode 100644 index 0000000000000..1a92b10e4f9c7 --- /dev/null +++ b/vllm/attention/ops/pallas_kv_cache_update.py @@ -0,0 +1,117 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import functools + +import jax +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu + + +def _kv_cache_update_kernel( + # Prefetch + slices_ref, # [3, num_slices], list of (kv_cache_start, new_kv_start, + # slice_len) + # Input + new_kv_hbm_ref, # [num_tokens, num_combined_kv_heads, head_dim] + kv_cache_hbm_ref, # [total_num_pages * page_size, num_combined_kv_heads, + # head_dim] + # Output + _, # [total_num_pages * page_size, num_combined_kv_heads, head_dim] + # Scratch + scratch, # [num_slices_per_block, page_size, num_combined_kv_heads, + # head_dim] + sem, +): + async_copies = [] + block_idx = pl.program_id(0) + num_slices_per_block = scratch.shape[0] + + # Copy from new_kv_hbm_ref to scratch + for i in range(num_slices_per_block): + offset_i = i + block_idx * num_slices_per_block + new_kv_start = slices_ref[1, offset_i] + length = slices_ref[2, offset_i] + async_copy = pltpu.make_async_copy( + new_kv_hbm_ref.at[pl.ds(new_kv_start, length), ...], + scratch.at[i, pl.ds(0, length), ...], + sem, + ) + async_copy.start() + async_copies.append(async_copy) + + for async_copy in async_copies: + async_copy.wait() + + # Copy from scratch to kv_cache_hbm_ref + async_copies.clear() + for i in range(num_slices_per_block): + offset_i = i + block_idx * num_slices_per_block + kv_cache_start = slices_ref[0, offset_i] + length = slices_ref[2, offset_i] + async_copy = pltpu.make_async_copy( + scratch.at[i, pl.ds(0, length), ...], + kv_cache_hbm_ref.at[pl.ds(kv_cache_start, length), ...], + sem, + ) + async_copy.start() + async_copies.append(async_copy) + for async_copy in async_copies: + async_copy.wait() + + +@functools.partial( + jax.jit, + static_argnames=["page_size", "num_slices_per_block"], +) +def kv_cache_update( + new_kv: jax.Array, # [total_num_token, num_combined_kv_heads, head_dim] + slices: jax. + Array, # [3, slices], list of (kv_cache_start, new_kv_start, slice_len) + kv_cache: jax. + Array, # [total_num_pages * page_size, num_combined_kv_heads, head_dim] + *, + page_size: int = 32, + num_slices_per_block: int = 8, +): + assert slices.shape[1] % num_slices_per_block == 0 + _, num_combined_kv_heads, head_dim = new_kv.shape + assert kv_cache.shape[1] == num_combined_kv_heads + assert kv_cache.shape[2] == head_dim + assert head_dim % 128 == 0 + # TODO: Add dynamic check to make sure that the all the slice lengths are + # smaller or equal to page_size + + in_specs = [ + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + ] + + out_specs = [pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)] + out_shape = [jax.ShapeDtypeStruct(kv_cache.shape, dtype=kv_cache.dtype)] + + scalar_prefetches = [slices] + scratch = pltpu.VMEM( + (num_slices_per_block, page_size, num_combined_kv_heads, head_dim), + new_kv.dtype, + ) + + scratch_shapes = [ + scratch, + pltpu.SemaphoreType.DMA, + ] + + kernel = pl.pallas_call( + _kv_cache_update_kernel, + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=len(scalar_prefetches), + in_specs=in_specs, + out_specs=out_specs, + grid=(slices.shape[1] // num_slices_per_block, ), + scratch_shapes=scratch_shapes, + ), + out_shape=out_shape, + input_output_aliases={len(scalar_prefetches) + 1: 0}, + ) + + return kernel(*scalar_prefetches, new_kv, kv_cache)[0] diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index ff2862edaa01b..49f0772c62d13 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -5,8 +5,12 @@ from dataclasses import dataclass from typing import Any, Optional import torch -# Required to register custom ops. +import torch_xla.core.xla_builder as xb import torch_xla.experimental.custom_kernel # noqa: F401 +# Required to register custom ops. +from torch.library import impl +from torch_xla._internal.jax_workarounds import requires_jax +from torch_xla.experimental.custom_kernel import XLA_LIB from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionLayer, AttentionType) @@ -107,6 +111,7 @@ class PallasMetadata: context_lens: torch.Tensor query_start_loc: torch.Tensor num_seqs: torch.Tensor + num_slices_per_kv_cache_update_block: int class PallasAttentionBackendImpl(AttentionImpl): @@ -212,7 +217,9 @@ class PallasAttentionBackendImpl(AttentionImpl): # Write input keys and values to the KV cache. # Skip this if sharing KV cache with an earlier attention layer. slot_mapping = attn_metadata.slot_mapping - write_to_kv_cache(key, value, kv_cache, slot_mapping) + write_to_kv_cache( + key, value, kv_cache, slot_mapping, + attn_metadata.num_slices_per_kv_cache_update_block) output = torch.ops.xla.ragged_paged_attention( query, @@ -244,6 +251,7 @@ def write_to_kv_cache( value: torch.Tensor, kv_cache: torch.Tensor, slot_mapping: torch.Tensor, + num_slices_per_kv_cache_update_block: int, ) -> None: """ Write the key and values to the KV cache. @@ -251,9 +259,9 @@ def write_to_kv_cache( key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size] - + num_slices_per_kv_cache_update_block: int """ - _, _, num_combined_kv_heads, head_size = kv_cache.shape + _, page_size, num_combined_kv_heads, head_size = kv_cache.shape head_size = cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads, @@ -262,4 +270,41 @@ def write_to_kv_cache( torch.ops.xla.dynamo_set_buffer_donor_(kv_cache, True) kv_cache = kv_cache.flatten(0, 1) - kv_cache.index_copy_(0, slot_mapping, kv) + new_kv_cache = torch.ops.xla.kv_cache_update_op( + kv, slot_mapping, kv_cache, page_size, + num_slices_per_kv_cache_update_block) + # NOTE: the in-place copy will be optimized away by XLA compiler. + kv_cache.copy_(new_kv_cache) + + +@requires_jax +def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor, + kv_cache: torch.Tensor, page_size: int, + num_slices_per_block: int): + from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update + new_kv_cache = xb.call_jax(kv_cache_update, (kv, slot_mapping, kv_cache), { + "page_size": page_size, + "num_slices_per_block": num_slices_per_block + }) + return new_kv_cache + + +XLA_LIB.define( + "kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache, " + "int page_size, int num_slices_per_block) -> Tensor", ) + + +@impl(XLA_LIB, "kv_cache_update_op", "XLA") +def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor, + kv_cache: torch.Tensor, page_size: int, + num_slices_per_block: int) -> torch.Tensor: + new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache, + page_size, num_slices_per_block) + return new_kv_cache + + +@impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd") +def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor, + kv_cache: torch.Tensor, page_size: int, + num_slices_per_block: int) -> torch.Tensor: + return kv_cache diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 2d80bac3c9546..bc334419c4cec 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -53,12 +53,11 @@ if TYPE_CHECKING: logger = init_logger(__name__) -# Here we utilize the behavior that out-of-bound index is ignored. -# FIXME(woosuk): Find a more reliable way to prevent possible bugs. -_PAD_SLOT_ID = 1_000_000_000 INVALID_TOKEN_ID = -1 # Smallest output size MIN_NUM_SEQS = 8 +# Block size used for kv cache updating kernel +NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK = 8 ######################################################### @@ -526,6 +525,69 @@ class TPUModelRunner(LoRAModelRunnerMixin): return kv_cache_spec + def _get_slot_mapping_metadata(self, num_reqs, + num_scheduled_tokens_per_req): + """ + Computes metadata for mapping slots to blocks in the key-value (KV) + cache for a batch of requests. + + This function determines, for each request in the batch, how the + scheduled tokens are distributed across memory blocks, and generates + metadata needed to map slices of tokens to their corresponding positions + in the KV cache. + + Args: + num_reqs (int): Number of requests in the current batch. + num_scheduled_tokens_per_req (int or np.ndarray): Number of tokens + to be scheduled for each request. + + Returns: + np.ndarray: A 2D array of shape (total_block_len, 3), where each row + contains: + - kv_cache_start_index (int): The starting index in the KV cache + for the corresponding slice. + - new_kv_start_index (int): The starting index in the new KV + cache for the corresponding slice. + - slice_len (int): The length of the slice. + """ + slices_start = self.input_batch.num_computed_tokens_cpu[:num_reqs] + slices_end = self.input_batch.num_computed_tokens_cpu[:num_reqs] + \ + num_scheduled_tokens_per_req + local_block_start_idx = slices_start // self.block_size + local_block_end_idx = (slices_end - 1) // self.block_size + no_repeat_req_indices = self.arange_np[:num_reqs] + global_block_start_idx = ( + no_repeat_req_indices * self.max_num_blocks_per_req + + local_block_start_idx) + block_lens = local_block_end_idx - local_block_start_idx + 1 + global_block_start_idx = np.repeat(global_block_start_idx, block_lens) + slice_arange = np.concatenate([self.arange_np[:n] for n in block_lens]) + global_block_indices = global_block_start_idx + slice_arange + block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor() + block_numbers = block_table_cpu.flatten()[global_block_indices].numpy() + total_block_len = np.sum(block_lens) + slot_mapping_slices = np.repeat(np.array([[0, self.block_size]], + dtype=np.int32), + total_block_len, + axis=0) + cu_block_lens = np.zeros(len(block_lens) + 1, dtype=np.int32) + np.cumsum(block_lens, out=cu_block_lens[1:]) + for req_idx in range(num_reqs): + slot_mapping_slices[cu_block_lens[req_idx]][ + 0] = slices_start[req_idx] % self.block_size + slot_mapping_slices[ + cu_block_lens[req_idx + 1] - + 1][1] = (slices_end[req_idx] - 1) % self.block_size + 1 + slice_lens = slot_mapping_slices[:, 1] - slot_mapping_slices[:, 0] + cu_slices_lens = np.zeros(len(slice_lens) + 1, dtype=np.int32) + np.cumsum(slice_lens, out=cu_slices_lens[1:]) + kv_cache_start_indices = slot_mapping_slices[:, 0] + \ + (block_numbers * self.block_size) + new_kv_start_indices = cu_slices_lens[:-1] + slot_mapping_metadata = np.stack( + [kv_cache_start_indices, new_kv_start_indices, slice_lens], axis=1) + return slot_mapping_metadata + def _prepare_inputs(self, scheduler_output: "SchedulerOutput", start_index: int): assert scheduler_output.total_num_scheduled_tokens > 0 @@ -603,26 +665,6 @@ class TPUModelRunner(LoRAModelRunnerMixin): torch.from_numpy(token_indices), out=self.input_ids_cpu[:total_num_scheduled_tokens]) - # Calculate the slot mapping. - # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] - # where K is the max_num_blocks_per_req and the block size is 2. - # NOTE(woosuk): We can't simply use `token_indices // block_size` here - # because M (max_model_len) is not necessarily divisible by block_size. - # req_indices: # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] - block_table_indices = (req_indices * self.max_num_blocks_per_req + - positions_np // self.block_size) - # NOTE(woosuk): We use torch.index_select instead of np.take here - # because torch.index_select is much faster than np.take for large - # tensors. - block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor() - block_numbers = block_table_cpu.flatten()[block_table_indices].numpy() - block_offsets = positions_np % self.block_size - np.add(block_numbers * self.block_size, - block_offsets, - out=self.input_batch.block_table[0]. - slot_mapping_np[:total_num_scheduled_tokens]) - # Prepare the attention metadata. self.query_start_loc_np[0] = 0 np.cumsum(num_scheduled_tokens_per_req, @@ -645,12 +687,6 @@ class TPUModelRunner(LoRAModelRunnerMixin): self.position_ids = self.positions_cpu[: padded_total_num_scheduled_tokens].to( self.device) - self.input_batch.block_table[0].slot_mapping_cpu[ - total_num_scheduled_tokens:] = _PAD_SLOT_ID - slot_mapping = ( - self.input_batch.block_table[0]. - slot_mapping_cpu[:padded_total_num_scheduled_tokens].to( - self.device)) if use_max_model_len: block_tables = self.block_table_cpu[:self.num_reqs_max_model_len, : self.max_num_blocks_per_req] @@ -675,6 +711,19 @@ class TPUModelRunner(LoRAModelRunnerMixin): self.device) block_tables = block_tables.to(self.device) + slot_mapping_metadata = self._get_slot_mapping_metadata( + num_reqs, num_scheduled_tokens_per_req) + padded_num_slices = _get_padded_num_kv_cache_update_slices( + padded_total_num_scheduled_tokens, self.max_num_reqs, + self.block_size) + slot_mapping_metadata = np.pad( + slot_mapping_metadata, + [[0, padded_num_slices - len(slot_mapping_metadata)], [0, 0]], + constant_values=0) + slot_mapping_metadata = np.transpose(slot_mapping_metadata) + slot_mapping_metadata = torch.tensor(slot_mapping_metadata, + device=self.device) + if self.lora_config is not None: # We need to respect padding when activating LoRA adapters padded_num_scheduled_tokens_per_req = np.copy( @@ -687,13 +736,15 @@ class TPUModelRunner(LoRAModelRunnerMixin): padded_num_scheduled_tokens_per_req) attn_metadata = PallasMetadata( - slot_mapping=slot_mapping, + slot_mapping=slot_mapping_metadata, block_tables=block_tables, context_lens=seq_lens, query_start_loc=query_start_loc, num_seqs=torch.tensor([num_reqs], dtype=torch.int32, device=self.device), + num_slices_per_kv_cache_update_block= + NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK, ) # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial # request in the batch. While we should not sample any token from this @@ -1119,8 +1170,10 @@ class TPUModelRunner(LoRAModelRunnerMixin): actual_num_reqs = min(num_tokens, num_reqs) position_ids = torch.zeros(num_tokens, dtype=torch.int32).to(self.device) - slot_mapping = torch.zeros(num_tokens, - dtype=torch.int64).to(self.device) + padded_num_slices = _get_padded_num_kv_cache_update_slices( + num_tokens, self.max_num_reqs, self.block_size) + slot_mapping = torch.zeros((3, padded_num_slices), + dtype=torch.int32).to(self.device) block_tables = torch.zeros((num_reqs, num_blocks), dtype=torch.int32).to(self.device) query_lens = [1] * num_reqs @@ -1138,6 +1191,8 @@ class TPUModelRunner(LoRAModelRunnerMixin): context_lens=context_lens, query_start_loc=query_start_loc, num_seqs=num_seqs, + num_slices_per_kv_cache_update_block= + NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK, ) if self.is_multimodal_model: @@ -1742,6 +1797,19 @@ def _get_padded_token_len(paddings: list[int], x: int) -> int: return paddings[index] +def _get_padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int, + page_size: int) -> int: + """Calculates the padded number of KV cache update slices to avoid + recompilation.""" + padded_num_slices = 2 * max_num_reqs + num_tokens // page_size + padded_num_slices = min(padded_num_slices, num_tokens) + padded_num_slices = ( + padded_num_slices + NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK - 1 + ) // NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK * \ + NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK + return padded_num_slices + + def replace_set_lora(model): def _tpu_set_lora( From 562308816ceabd8414f49ff2aa291480f69fa1a5 Mon Sep 17 00:00:00 2001 From: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Date: Thu, 26 Jun 2025 18:19:32 -0400 Subject: [PATCH 025/175] [Refactor] Rename commnication utils (#20091) Signed-off-by: yewentao256 --- tests/kernels/moe/test_deepep_deepgemm_moe.py | 2 +- tests/kernels/moe/test_deepep_moe.py | 2 +- tests/kernels/moe/test_pplx_cutlass_moe.py | 2 +- tests/kernels/moe/test_pplx_moe.py | 2 +- tests/kernels/moe/{deepep_utils.py => utils.py} | 0 5 files changed, 4 insertions(+), 4 deletions(-) rename tests/kernels/moe/{deepep_utils.py => utils.py} (100%) diff --git a/tests/kernels/moe/test_deepep_deepgemm_moe.py b/tests/kernels/moe/test_deepep_deepgemm_moe.py index 2d7cf39a8cca5..f580dee4c9285 100644 --- a/tests/kernels/moe/test_deepep_deepgemm_moe.py +++ b/tests/kernels/moe/test_deepep_deepgemm_moe.py @@ -22,7 +22,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8) from vllm.platforms import current_platform -from .deepep_utils import ProcessGroupInfo, parallel_launch +from .utils import ProcessGroupInfo, parallel_launch has_deep_ep = importlib.util.find_spec("deep_ep") is not None diff --git a/tests/kernels/moe/test_deepep_moe.py b/tests/kernels/moe/test_deepep_moe.py index 7e029ea950555..380eb43c42a40 100644 --- a/tests/kernels/moe/test_deepep_moe.py +++ b/tests/kernels/moe/test_deepep_moe.py @@ -23,7 +23,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8) from vllm.platforms import current_platform -from .deepep_utils import ProcessGroupInfo, parallel_launch +from .utils import ProcessGroupInfo, parallel_launch has_deep_ep = importlib.util.find_spec("deep_ep") is not None diff --git a/tests/kernels/moe/test_pplx_cutlass_moe.py b/tests/kernels/moe/test_pplx_cutlass_moe.py index 0caf14f040bbe..ee2bdc838b0d1 100644 --- a/tests/kernels/moe/test_pplx_cutlass_moe.py +++ b/tests/kernels/moe/test_pplx_cutlass_moe.py @@ -15,7 +15,7 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEModularKernel) from vllm.platforms import current_platform -from .deepep_utils import ProcessGroupInfo, parallel_launch +from .utils import ProcessGroupInfo, parallel_launch try: from pplx_kernels import AllToAll diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index c4ad3af6802d4..1da14eddff317 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -29,7 +29,7 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEModularKernel) from vllm.platforms import current_platform -from .deepep_utils import ProcessGroupInfo, parallel_launch +from .utils import ProcessGroupInfo, parallel_launch requires_pplx = pytest.mark.skipif( not has_pplx, diff --git a/tests/kernels/moe/deepep_utils.py b/tests/kernels/moe/utils.py similarity index 100% rename from tests/kernels/moe/deepep_utils.py rename to tests/kernels/moe/utils.py From 07b8fae219b1fff51ef115c38c44b51395be5bb5 Mon Sep 17 00:00:00 2001 From: Kyle Yu <153807854+kyolebu@users.noreply.github.com> Date: Thu, 26 Jun 2025 18:22:12 -0400 Subject: [PATCH 026/175] [Doc] correct LoRA capitalization (#20135) Signed-off-by: kyolebu --- docs/README.md | 2 +- docs/models/supported_models.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/README.md b/docs/README.md index 0c6aff5fa07c3..9fb3137b31928 100644 --- a/docs/README.md +++ b/docs/README.md @@ -40,7 +40,7 @@ vLLM is flexible and easy to use with: - OpenAI-compatible API server - Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs, Gaudi® accelerators and GPUs, IBM Power CPUs, TPU, and AWS Trainium and Inferentia Accelerators. - Prefix caching support -- Multi-lora support +- Multi-LoRA support For more information, check out the following: diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index a435c59a3042b..04d9923f92105 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -427,7 +427,7 @@ Specified using `--task embed`. See [relevant issue on HF Transformers](https://github.com/huggingface/transformers/issues/34882). !!! note - `jinaai/jina-embeddings-v3` supports multiple tasks through lora, while vllm temporarily only supports text-matching tasks by merging lora weights. + `jinaai/jina-embeddings-v3` supports multiple tasks through LoRA, while vllm temporarily only supports text-matching tasks by merging LoRA weights. !!! note The second-generation GTE model (mGTE-TRM) is named `NewModel`. The name `NewModel` is too generic, you should set `--hf-overrides '{"architectures": ["GteNewModel"]}'` to specify the use of the `GteNewModel` architecture. From e9fd658a736a4d30f7a367c317506c87ad7f5359 Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Thu, 26 Jun 2025 15:30:21 -0700 Subject: [PATCH 027/175] [Feature] Expert Parallelism Load Balancer (EPLB) (#18343) Signed-off-by: Bowen Wang --- .buildkite/test-pipeline.yaml | 17 + tests/distributed/test_eplb_algo.py | 292 ++++++++++ tests/distributed/test_eplb_execute.py | 504 ++++++++++++++++++ tests/models/test_initialization.py | 12 +- vllm/config.py | 33 ++ vllm/distributed/eplb/__init__.py | 7 + vllm/distributed/eplb/eplb_state.py | 431 +++++++++++++++ vllm/distributed/eplb/rebalance_algo.py | 233 ++++++++ vllm/distributed/eplb/rebalance_execute.py | 306 +++++++++++ vllm/engine/arg_utils.py | 20 + vllm/model_executor/layers/fused_moe/layer.py | 264 ++++++++- .../layers/quantization/awq_marlin.py | 8 + .../compressed_tensors_moe.py | 42 ++ .../layers/quantization/experts_int8.py | 8 + .../model_executor/layers/quantization/fp8.py | 14 + .../layers/quantization/gguf.py | 8 + .../layers/quantization/gptq_marlin.py | 8 + .../layers/quantization/modelopt.py | 8 + .../layers/quantization/moe_wna16.py | 8 + .../layers/quantization/quark/quark_moe.py | 8 + vllm/model_executor/models/deepseek_v2.py | 127 ++++- vllm/model_executor/models/interfaces.py | 68 +++ vllm/v1/worker/gpu_model_runner.py | 65 ++- vllm/v1/worker/gpu_worker.py | 9 +- 24 files changed, 2446 insertions(+), 54 deletions(-) create mode 100644 tests/distributed/test_eplb_algo.py create mode 100644 tests/distributed/test_eplb_execute.py create mode 100644 vllm/distributed/eplb/__init__.py create mode 100644 vllm/distributed/eplb/eplb_state.py create mode 100644 vllm/distributed/eplb/rebalance_algo.py create mode 100644 vllm/distributed/eplb/rebalance_execute.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 1536759c06bd2..26f70ad457b67 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -168,6 +168,23 @@ steps: - VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py - popd +- label: EPLB Algorithm Test + working_dir: "/vllm-workspace/tests" + source_file_dependencies: + - vllm/distributed/eplb + - tests/distributed/test_eplb_algo.py + commands: + - pytest -v -s distributed/test_eplb_algo.py + +- label: EPLB Execution Test # 5min + working_dir: "/vllm-workspace/tests" + num_gpus: 4 + source_file_dependencies: + - vllm/distributed/eplb + - tests/distributed/test_eplb_execute.py + commands: + - pytest -v -s distributed/test_eplb_execute.py + - label: Metrics, Tracing Test # 10min mirror_hardwares: [amdexperimental, amdproduction] num_gpus: 2 diff --git a/tests/distributed/test_eplb_algo.py b/tests/distributed/test_eplb_algo.py new file mode 100644 index 0000000000000..e47ccba99c81d --- /dev/null +++ b/tests/distributed/test_eplb_algo.py @@ -0,0 +1,292 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +from vllm.distributed.eplb.rebalance_algo import rebalance_experts + + +def test_basic_rebalance(): + """Test basic rebalancing functionality""" + # Example from https://github.com/deepseek-ai/eplb + weight = torch.tensor([ + [90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86], + [20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27], + ]) + + num_layers = weight.shape[0] + num_replicas = 16 + num_groups = 4 + num_nodes = 2 + num_gpus = 8 + + phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, + num_groups, num_nodes, + num_gpus) + + # Verify output shapes + assert phy2log.shape == ( + 2, + 16, + ), f"Expected `phy2log` shape (2, 16), got {phy2log.shape}" + assert (log2phy.shape[0] == 2 + ), f"Expected `log2phy` first dimension 2, got {log2phy.shape[0]}" + assert ( + log2phy.shape[1] == 12 + ), f"Expected `log2phy` second dimension 12, got {log2phy.shape[1]}" + assert logcnt.shape == ( + 2, + 12, + ), f"Expected `logcnt` shape (2, 12), got {logcnt.shape}" + + # Verify physical to logical expert mapping range is correct + assert torch.all(phy2log >= 0) and torch.all( + phy2log < 12), "Physical to logical mapping should be in range [0, 12)" + + # Verify expert count reasonableness + assert torch.all( + logcnt >= 1), "Each logical expert should have at least 1 replica" + assert ( + torch.sum(logcnt, dim=1).sum() == num_replicas * + num_layers), f"Total replicas should be {num_replicas * num_layers}" + + # Verify expected output + expected_phy2log = torch.tensor([ + [5, 6, 5, 7, 8, 4, 3, 4, 10, 9, 10, 2, 0, 1, 11, 1], + [7, 10, 6, 8, 6, 11, 8, 9, 2, 4, 5, 1, 5, 0, 3, 1], + ]) + assert torch.all(phy2log == expected_phy2log) + + expected_logcnt = torch.tensor([[1, 2, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1], + [1, 2, 1, 1, 1, 2, 2, 1, 2, 1, 1, 1]]) + assert torch.all(logcnt == expected_logcnt) + + +def test_single_gpu_case(): + """Test single GPU case""" + weight = torch.tensor([[10, 20, 30, 40]]) + num_replicas = 4 + num_groups = 1 + num_nodes = 1 + num_gpus = 1 + + phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, + num_groups, num_nodes, + num_gpus) + + # Verify shapes + assert phy2log.shape == (1, 4) + assert log2phy.shape[0] == 1 + assert log2phy.shape[1] == 4 + assert logcnt.shape == (1, 4) + + # Verify all logical experts are mapped + assert set(phy2log[0].tolist()) == {0, 1, 2, 3} + + +def test_equal_weights(): + """Test case with equal weights""" + weight = torch.tensor([[50, 50, 50, 50, 50, 50, 50, 50]]) + num_replicas = 8 + num_groups = 2 + num_nodes = 2 + num_gpus = 4 + + phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, + num_groups, num_nodes, + num_gpus) + + # Verify shapes + assert phy2log.shape == (1, 8) + assert logcnt.shape == (1, 8) + + # With equal weights, each expert should have exactly one replica + assert torch.all( + logcnt == 1 + ), "With equal weights and no replication, " \ + "each expert should have exactly 1 replica" + + +def test_extreme_weight_imbalance(): + """Test extreme weight imbalance case""" + weight = torch.tensor([[1000, 1, 1, 1, 1, 1, 1, 1]]) + num_replicas = 12 + num_groups = 2 + num_nodes = 2 + num_gpus = 4 + + phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, + num_groups, num_nodes, + num_gpus) + + # Verify shapes + assert phy2log.shape == (1, 12) + assert logcnt.shape == (1, 8) + + # Expert with highest weight (index 0) should have more replicas + assert ( + logcnt[0, 0] + > logcnt[0, 1]), "Expert with highest weight should have more replicas" + + +def test_multiple_layers(): + """Test multiple layers case""" + weight = torch.tensor([ + [10, 20, 30, 40, 50, 60], # First layer + [60, 50, 40, 30, 20, 10], # Second layer (opposite weight pattern) + [25, 25, 25, 25, 25, 25], # Third layer (equal weights) + ]) + num_replicas = 8 + num_groups = 2 + num_nodes = 2 + num_gpus = 4 + + phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, + num_groups, num_nodes, + num_gpus) + + # Verify shapes + assert phy2log.shape == (3, 8) + assert logcnt.shape == (3, 6) + + # Verify expert allocation is reasonable for each layer + for layer in range(3): + assert torch.all(phy2log[layer] >= 0) and torch.all( + phy2log[layer] < 6 + ), f"Layer {layer} physical to logical mapping" \ + "should be in range [0, 6)" + assert (torch.sum(logcnt[layer]) == num_replicas + ), f"Layer {layer} total replicas should be {num_replicas}" + + +def test_parameter_validation(): + """Test parameter validation""" + weight = torch.tensor([[10, 20, 30, 40]]) + + # Test non-divisible case - this should handle normally without throwing + # errors because the function will fall back to global load balancing + # strategy + phy2log, log2phy, logcnt = rebalance_experts(weight, 8, 3, 2, 4) + assert phy2log.shape == (1, 8) + assert logcnt.shape == (1, 4) + + # Test cases that will actually cause errors: + # num_physical_experts not divisible by num_gpus + with pytest.raises(AssertionError): + rebalance_experts(weight, 7, 2, 2, 4) # 7 not divisible by 4 + + +def test_small_scale_hierarchical(): + """Test small-scale hierarchical load balancing""" + weight = torch.tensor([ + [100, 50, 200, 75, 150, 25, 300, 80], # 8 experts + ]) + num_replicas = 12 + num_groups = 4 # 4 groups, 2 experts each + num_nodes = 2 # 2 nodes + num_gpus = 4 # 4 GPUs + + phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, + num_groups, num_nodes, + num_gpus) + + # Verify basic constraints + assert phy2log.shape == (1, 12) + assert logcnt.shape == (1, 8) + assert torch.sum(logcnt) == num_replicas + assert torch.all(logcnt >= 1) + + # Expert with highest weight should have more replicas + max_weight_expert = torch.argmax(weight[0]) + assert (logcnt[0, max_weight_expert] + >= 2), "Highest weight expert should have multiple replicas" + + +def test_global_load_balance_fallback(): + """Test global load balancing fallback case""" + # When num_groups % num_nodes != 0, should fall back to global load + # balancing + weight = torch.tensor([[10, 20, 30, 40, 50, 60]]) + num_replicas = 8 + num_groups = 3 # Cannot be divided evenly by num_nodes=2 + num_nodes = 2 + num_gpus = 4 + + phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, + num_groups, num_nodes, + num_gpus) + + # Should work normally, just using global load balancing strategy + assert phy2log.shape == (1, 8) + assert logcnt.shape == (1, 6) + assert torch.sum(logcnt) == num_replicas + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_device_compatibility(device): + """Test device compatibility""" + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + weight = torch.tensor([[10, 20, 30, 40]], device=device) + num_replicas = 6 + num_groups = 2 + num_nodes = 1 + num_gpus = 2 + + phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, + num_groups, num_nodes, + num_gpus) + + # Function will convert to CPU internally, but should handle different + # device inputs normally + assert phy2log.shape == (1, 6) + assert logcnt.shape == (1, 4) + + +def test_additional_cases(): + """Test more edge cases and different parameter combinations""" + + # Test case 1: Large-scale distributed setup + weight1 = torch.tensor( + [[50, 100, 75, 120, 90, 60, 80, 110, 40, 70, 95, 85, 65, 55, 45, 35]]) + phy2log1, log2phy1, logcnt1 = rebalance_experts(weight1, 24, 8, 4, 8) + + assert phy2log1.shape == (1, 24) + assert logcnt1.shape == (1, 16) + assert torch.sum(logcnt1) == 24 + + # Test case 2: Different weight distributions + weight2 = torch.tensor([ + [200, 150, 100, 50, 25, 12], # Decreasing weights + [12, 25, 50, 100, 150, 200], # Increasing weights + ]) + phy2log2, log2phy2, logcnt2 = rebalance_experts(weight2, 10, 3, 1, 2) + + assert phy2log2.shape == (2, 10) + assert logcnt2.shape == (2, 6) + + # Verify high-weight experts have more replicas + for layer in range(2): + max_weight_idx = torch.argmax(weight2[layer]) + assert logcnt2[layer, max_weight_idx] >= 2 + + +if __name__ == "__main__": + weight = torch.tensor([ + [90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86], + [20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27], + ]) + + num_replicas = 16 + num_groups = 4 + num_nodes = 2 + num_gpus = 8 + + phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, + num_groups, num_nodes, + num_gpus) + print(phy2log) + + test_basic_rebalance() diff --git a/tests/distributed/test_eplb_execute.py b/tests/distributed/test_eplb_execute.py new file mode 100644 index 0000000000000..de9ed1eabbac6 --- /dev/null +++ b/tests/distributed/test_eplb_execute.py @@ -0,0 +1,504 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import multiprocessing +import os +import random + +import pytest +import torch +import torch.distributed + +from vllm.distributed.eplb.rebalance_execute import ( + rearrange_expert_weights_inplace) +from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, + get_tp_group, + init_distributed_environment) +from vllm.utils import update_environment_variables + + +def distributed_run(fn, world_size): + number_of_processes = world_size + processes: list[multiprocessing.Process] = [] + for i in range(number_of_processes): + env: dict[str, str] = {} + env['RANK'] = str(i) + env['LOCAL_RANK'] = str(i) + env['WORLD_SIZE'] = str(number_of_processes) + env['LOCAL_WORLD_SIZE'] = str(number_of_processes) + env['MASTER_ADDR'] = 'localhost' + env['MASTER_PORT'] = '12345' + p = multiprocessing.Process(target=fn, args=(env, )) + processes.append(p) + p.start() + + for p in processes: + p.join() + + for p in processes: + assert p.exitcode == 0 + + +def worker_fn_wrapper(fn): + # `multiprocessing.Process` cannot accept environment variables directly + # so we need to pass the environment variables as arguments + # and update the environment variables in the function + def wrapped_fn(env): + update_environment_variables(env) + local_rank = os.environ['LOCAL_RANK'] + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(device) + init_distributed_environment() + + # Ensure each worker process has the same random seed + random.seed(42) + torch.manual_seed(42) + + fn() + + return wrapped_fn + + +def create_expert_indices_with_redundancy( + num_layers: int, + num_logical_experts: int, + total_physical_experts: int, + redundancy_config: list[int], # redundancy for each logical expert +) -> torch.Tensor: + """ + Create expert indices with redundancy. + + Args: + num_layers: number of layers + num_logical_experts: number of logical experts + total_physical_experts: total number of physical experts + redundancy_config: redundancy for each logical expert + + Returns: + indices: Shape (num_layers, total_physical_experts) + """ + assert sum(redundancy_config) == total_physical_experts + assert len(redundancy_config) == num_logical_experts + + indices = torch.zeros(num_layers, total_physical_experts, dtype=torch.long) + + for layer in range(num_layers): + physical_pos = 0 + for logical_expert_id, redundancy in enumerate(redundancy_config): + for _ in range(redundancy): + indices[layer, physical_pos] = logical_expert_id + physical_pos += 1 + + # Shuffle the indices at dim 1 + for layer in range(num_layers): + indices[layer] = indices[layer][torch.randperm(indices.shape[1])] + + return indices + + +def create_expert_weights( + num_layers: int, + num_local_experts: int, + hidden_sizes: list[int], + rank: int, + device: torch.device, + physical_to_logical_mapping: torch.Tensor, +) -> list[list[torch.Tensor]]: + """ + Create fake expert weights tensor for testing. + + Use `arange` to generate predictable weights values, based on logical + expert ID. + All replicas of the same logical expert should have the same weights. + + Args: + physical_to_logical_mapping: Shape (num_layers, num_local_experts) + mapping[layer, physical_pos] = logical_expert_id + """ + expert_weights = [] + + for layer in range(num_layers): + layer_weights = [] + for weight_idx, hidden_size in enumerate(hidden_sizes): + weight_tensor = torch.zeros(num_local_experts, + hidden_size, + device=device, + dtype=torch.float32) + + for local_expert in range(num_local_experts): + # Get the logical expert ID for this physical expert + global_pos = rank * num_local_experts + local_expert + logical_expert_id = physical_to_logical_mapping[ + layer, global_pos].item() + + # Generate weights based on logical expert ID + # (so that all replicas of the same logical expert have the + # same weights) + base_value = (logical_expert_id * 1000 + layer * 100 + + weight_idx * 10) + weight_tensor[local_expert] = torch.arange(base_value, + base_value + + hidden_size, + device=device, + dtype=torch.float32) + + layer_weights.append(weight_tensor) + expert_weights.append(layer_weights) + + return expert_weights + + +def create_redundancy_config( + num_logical_experts: int, + num_physical_experts: int, +) -> list[int]: + """Create a redundancy configuration.""" + redundancy_config = [1] * num_logical_experts + remaining = num_physical_experts - num_logical_experts + # Randomly assign the remaining physical experts to the logical experts + for _ in range(remaining): + redundancy_config[random.choice(range(num_logical_experts))] += 1 + return redundancy_config + + +def verify_expert_weights_after_shuffle( + expert_weights: list[list[torch.Tensor]], + new_indices: torch.Tensor, + hidden_sizes: list[int], + ep_rank: int, + num_local_experts: int, +): + """Verify the weights after shuffling are correct.""" + num_layers = len(expert_weights) + + for layer in range(num_layers): + for weight_idx, hidden_size in enumerate(hidden_sizes): + weight_tensor = expert_weights[layer][weight_idx] + + for local_expert in range(num_local_experts): + # Calculate the global expert ID for this local expert + global_pos = ep_rank * num_local_experts + local_expert + expected_logical_expert = new_indices[layer, global_pos].item() + + # Check if the weights are correct + actual_weights = weight_tensor[local_expert] + expected_base = (expected_logical_expert * 1000 + layer * 100 + + weight_idx * 10) + expected_weights = torch.arange(expected_base, + expected_base + hidden_size, + device=actual_weights.device, + dtype=actual_weights.dtype) + + torch.testing.assert_close( + actual_weights, + expected_weights, + msg=f"Layer {layer}, weight {weight_idx}," + f"local expert {local_expert}: " + f"weights do not match. " + f"Expected logical expert {expected_logical_expert}") + + +def verify_redundant_experts_have_same_weights( + expert_weights: list[list[torch.Tensor]], + indices: torch.Tensor, + hidden_sizes: list[int], + world_size: int, + num_local_experts: int, +): + """ + Verify that all replicas of the same logical expert have the same weights. + """ + num_layers = len(expert_weights) + total_physical_experts = world_size * num_local_experts + + for layer in range(num_layers): + # Collect weights for all physical experts for each weight matrix + all_weights: list[torch.Tensor] = [] + + for weight_idx, hidden_size in enumerate(hidden_sizes): + # Create tensor to store all expert weights + # Shape: [total_physical_experts, hidden_size] + gathered_weights = torch.zeros( + total_physical_experts, + hidden_size, + device=expert_weights[layer][weight_idx].device, + dtype=expert_weights[layer][weight_idx].dtype) + + # Use all_gather to collect expert weights from current node + # expert_weights[layer][weight_idx] shape: + # [num_local_experts, hidden_size] + local_weights = expert_weights[layer][ + weight_idx] # [num_local_experts, hidden_size] + + # Split tensor along dim 0 into a list for all_gather + gathered_weights_list = torch.chunk(gathered_weights, + world_size, + dim=0) + + torch.distributed.all_gather( + # Output list: each element corresponds to one rank's weights + list(gathered_weights_list), + local_weights # Input: current rank's local weights + ) + + all_weights.append(gathered_weights) + + # Verify that all replicas of the same logical expert have the same + # weights + logical_expert_weights: dict[int, dict[int, torch.Tensor]] = {} + + for physical_pos in range(total_physical_experts): + logical_expert_id = int(indices[layer, physical_pos].item()) + + if logical_expert_id not in logical_expert_weights: + # First time encountering this logical expert, save its weights + logical_expert_weights[logical_expert_id] = { + weight_idx: all_weights[weight_idx][physical_pos] + for weight_idx in range(len(hidden_sizes)) + } + else: + # Verify that current physical expert's weights match the + # previously saved logical expert weights + for weight_idx in range(len(hidden_sizes)): + torch.testing.assert_close( + all_weights[weight_idx][physical_pos], + logical_expert_weights[logical_expert_id][weight_idx], + msg=f"Layer {layer}, weight {weight_idx}," + f"logical expert {logical_expert_id}: " + f"Physical expert {physical_pos} has different weights" + f"than expected") + + +@pytest.mark.parametrize( + "world_size,num_layers,num_local_experts,num_logical_experts", + [ + # 2 GPU, 2 experts per GPU + # 3 logical experts, 4 physical experts, 1 redundant experts + (2, 1, 2, 3), + # 2 GPU, 3 experts per GPU + # 4 logical experts, 6 physical experts, 2 redundant experts + (2, 2, 3, 4), + # 2 GPU, 8 experts per GPU + # 16 logical experts, 16 physical experts, 0 redundant experts + (2, 4, 8, 16), + # 4 GPU, 2 experts per GPU + # 6 logical experts, 8 physical experts, 2 redundant experts + (4, 1, 2, 6), + # 4 GPU, 2 experts per GPU + # 5 logical experts, 8 physical experts, 3 redundant experts + (4, 2, 2, 5), + # 4 GPU, 8 experts per GPU + # 16 logical experts, 32 physical experts, 16 redundant experts + (4, 8, 8, 16), + ]) +def test_rearrange_expert_weights_with_redundancy(world_size, num_layers, + num_local_experts, + num_logical_experts): + """Test the functionality of rearranging expert weights with redundancy.""" + + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need at least {world_size} GPUs to run the test") + + @worker_fn_wrapper + def worker_fn(): + # Initialize model parallel (using tensor parallel as an entrypoint + # to expert parallel) + ensure_model_parallel_initialized( + tensor_model_parallel_size=world_size, + pipeline_model_parallel_size=1) + + ep_group = get_tp_group().cpu_group + ep_rank = torch.distributed.get_rank() + device = torch.device(f"cuda:{ep_rank}") + + # Test parameters + total_physical_experts = world_size * num_local_experts + hidden_sizes = [32, 64] # Two different weight matrices + + # Create old expert indices (with redundancy) + redundancy_config = create_redundancy_config(num_logical_experts, + total_physical_experts) + + old_indices = create_expert_indices_with_redundancy( + num_layers, + num_logical_experts, + total_physical_experts, + redundancy_config, + ) + + # Create new expert indices (with redundancy) + new_redundancy_config = create_redundancy_config( + num_logical_experts, total_physical_experts) + new_indices = create_expert_indices_with_redundancy( + num_layers, + num_logical_experts, + total_physical_experts, + new_redundancy_config, + ) + + # Create expert weights + expert_weights = create_expert_weights(num_layers, num_local_experts, + hidden_sizes, ep_rank, device, + old_indices) + + # Execute weight rearrangement + rearrange_expert_weights_inplace( + old_indices, + new_indices, + expert_weights, + ep_group, + is_profile=False, + ) + + # Verify the rearrangement result + verify_expert_weights_after_shuffle( + expert_weights, + new_indices, + hidden_sizes, + ep_rank, + num_local_experts, + ) + + verify_redundant_experts_have_same_weights( + expert_weights, + new_indices, + hidden_sizes, + world_size, + num_local_experts, + ) + + distributed_run(worker_fn, world_size) + + +@pytest.mark.parametrize("world_size", [2, 4]) +def test_rearrange_expert_weights_no_change(world_size): + """ + Test that when the indices do not change, the weights should remain + unchanged. + """ + + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need at least {world_size} GPUs to run the test") + + @worker_fn_wrapper + def worker_fn(): + ensure_model_parallel_initialized( + tensor_model_parallel_size=world_size, + pipeline_model_parallel_size=1) + + ep_group = get_tp_group().cpu_group + ep_rank = torch.distributed.get_rank() + device = torch.device(f"cuda:{ep_rank}") + + num_layers = 2 + num_local_experts = 2 + total_physical_experts = world_size * num_local_experts + num_logical_experts = total_physical_experts // 2 # Some redundancy + hidden_sizes = [32, 64] + + # Create redundancy configuration + redundancy_config = [2] * num_logical_experts + + # Same indices - no change + indices = create_expert_indices_with_redundancy( + num_layers, num_logical_experts, total_physical_experts, + redundancy_config) + + expert_weights = create_expert_weights(num_layers, num_local_experts, + hidden_sizes, ep_rank, device, + indices) + + # Save original weights + original_weights = [] + for layer_weights in expert_weights: + layer_copy = [] + for weight in layer_weights: + layer_copy.append(weight.clone()) + original_weights.append(layer_copy) + + # Execute rearrangement (should be no change) + rearrange_expert_weights_inplace( + indices, + indices, # Same indices + expert_weights, + ep_group, + is_profile=False) + + # Verify that the weights have not changed + for layer in range(num_layers): + for weight_idx in range(len(hidden_sizes)): + torch.testing.assert_close( + expert_weights[layer][weight_idx], + original_weights[layer][weight_idx], + msg=f"Layer {layer}, weight {weight_idx} should remain " + f"unchanged") + + distributed_run(worker_fn, world_size) + + +@pytest.mark.parametrize("world_size", [2, 4]) +def test_rearrange_expert_weights_profile_mode(world_size): + """Test profile mode (should not copy actual weights)""" + + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need at least {world_size} GPUs to run the test") + + @worker_fn_wrapper + def worker_fn(): + ensure_model_parallel_initialized( + tensor_model_parallel_size=world_size, + pipeline_model_parallel_size=1) + + ep_group = get_tp_group().cpu_group + ep_rank = torch.distributed.get_rank() + device = torch.device(f"cuda:{ep_rank}") + + num_layers = 1 + num_local_experts = 2 + total_physical_experts = world_size * num_local_experts + num_logical_experts = total_physical_experts // 2 + hidden_sizes = [32] + + # Create different index distributions + old_redundancy = create_redundancy_config(num_logical_experts, + total_physical_experts) + new_redundancy = create_redundancy_config(num_logical_experts, + total_physical_experts) + + old_indices = create_expert_indices_with_redundancy( + num_layers, num_logical_experts, total_physical_experts, + old_redundancy) + new_indices = create_expert_indices_with_redundancy( + num_layers, num_logical_experts, total_physical_experts, + new_redundancy) + + expert_weights = create_expert_weights(num_layers, num_local_experts, + hidden_sizes, ep_rank, device, + old_indices) + + # Save original weights + original_weights = [] + for layer_weights in expert_weights: + layer_copy = [] + for weight in layer_weights: + layer_copy.append(weight.clone()) + original_weights.append(layer_copy) + + # Execute profile mode rearrangement + rearrange_expert_weights_inplace( + old_indices, + new_indices, + expert_weights, + ep_group, + is_profile=True # Profile mode + ) + + # In profile mode, the weights should remain unchanged + for layer in range(num_layers): + for weight_idx in range(len(hidden_sizes)): + torch.testing.assert_close( + expert_weights[layer][weight_idx], + original_weights[layer][weight_idx], + msg="In profile mode, the weights should remain unchanged") + + distributed_run(worker_fn, world_size) diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index 54e8cd597bfc4..e56bc925c9c40 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -31,12 +31,20 @@ def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch): text_config = hf_config.get_text_config() + # Ensure at least 2 expert per group + # Since `grouped_topk` assums top-2 + num_experts = getattr(text_config, 'n_group', 1) * 2 + text_config.update({ "num_layers": 1, "num_hidden_layers": 1, - "num_experts": 2, + "num_experts": num_experts, "num_experts_per_tok": 2, - "num_local_experts": 2, + "num_local_experts": num_experts, + # Otherwise there will not be any expert layers + "first_k_dense_replace": 0, + # To avoid OOM on DeepSeek-V3 + "n_routed_experts": num_experts, }) if hasattr(hf_config, "vision_config"): diff --git a/vllm/config.py b/vllm/config.py index 96ea47a0dce38..856b361531168 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1775,6 +1775,25 @@ class ParallelConfig: """Backend to use for data parallel, either "mp" or "ray".""" enable_expert_parallel: bool = False """Use expert parallelism instead of tensor parallelism for MoE layers.""" + enable_eplb: bool = False + """Enable expert parallelism load balancing for MoE layers.""" + num_redundant_experts: int = 0 + """Number of redundant experts to use for expert parallelism.""" + eplb_window_size: int = 1000 + """Window size for expert load recording.""" + eplb_step_interval: int = 3000 + """ + Interval for rearranging experts in expert parallelism. + + Note that if this is greater than the EPLB window size, only the metrics + of the last `eplb_window_size` steps will be used for rearranging experts. + """ + eplb_log_balancedness: bool = False + """ + Log the balancedness each step of expert parallelism. + This is turned off by default since it will cause communication overhead. + """ + max_parallel_loading_workers: Optional[int] = None """Maximum number of parallel loading workers when loading model sequentially in multiple batches. To avoid RAM OOM when using tensor @@ -1913,6 +1932,20 @@ class ParallelConfig: os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" logger.info("Disabling V1 multiprocessing for external launcher.") + if self.enable_eplb: + if not current_platform.is_cuda(): + raise ValueError( + "Expert parallelism load balancing is only supported on " + "CUDA devices now.") + if self.num_redundant_experts < 0: + raise ValueError( + "num_redundant_experts must be non-negative, but got " + f"{self.num_redundant_experts}.") + else: + if self.num_redundant_experts != 0: + raise ValueError( + "num_redundant_experts should be used with EPLB." + f"{self.num_redundant_experts}.") if self.distributed_executor_backend is None and self.world_size > 1: # We use multiprocessing by default if world_size fits on the # current node and we aren't in a ray placement group. diff --git a/vllm/distributed/eplb/__init__.py b/vllm/distributed/eplb/__init__.py new file mode 100644 index 0000000000000..c87b039afd73d --- /dev/null +++ b/vllm/distributed/eplb/__init__.py @@ -0,0 +1,7 @@ +# SPDX-License-Identifier: Apache-2.0 +''' +Expert parallelism load balancer (EPLB). +''' + +from .eplb_state import * +from .rebalance_algo import * diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py new file mode 100644 index 0000000000000..2185df865c1f6 --- /dev/null +++ b/vllm/distributed/eplb/eplb_state.py @@ -0,0 +1,431 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Expert parallelism load balancer (EPLB) metrics and states. + +# Glossary + +- **Logical Expert**: An expert that is part of the model's logical structure. + It holds a set of weights and is replicated across multiple physical + experts. +- **Redundant Expert**: To achieve load balancing, for some popular logical + experts, we create additional copies of the expert weights. During inference, + each of these copies can be routed to by the same set of tokens. +- **Physical Expert**: An expert that is instantiated on a specific device. + It is a replica of a logical expert and can be rearranged across devices. + I.e., one logical expert may have multiple sets of weights initialized on + different devices, and each of these sets is a physical expert. +- **Local Physical Expert**: A physical expert that is instantiated on the + current device. + +For example: DeepSeek-R1 has 256 logical experts, so each MoE layer +has 256 sets of linear layer weights in the model parameters. If we add 32 +redundant experts, DeepSeek-R1 will have 256 + 32 = 288 physical experts in +total. And when deploying, we'll have 288 sets of linear layer weights for each +MoE layer. If we have 32 EP ranks, then each GPU will hold 288 / 32 = 9 local +physical experts. +""" + +import time +from collections.abc import Sequence +from dataclasses import dataclass + +import torch +from torch.distributed import all_gather, all_reduce + +from vllm.config import ParallelConfig +from vllm.distributed.parallel_state import get_ep_group, get_node_count +from vllm.logger import init_logger +from vllm.model_executor.models.interfaces import MixtureOfExperts + +from .rebalance_algo import rebalance_experts +from .rebalance_execute import rearrange_expert_weights_inplace + +logger = init_logger(__name__) + + +@dataclass +class EplbState: + """EPLB metrics.""" + + physical_to_logical_map: torch.Tensor + """ + Mapping from physical experts to logical experts. + + Shape: (num_moe_layers, num_physical_experts) + + # Example + + For a 2-layer MoE model with 6 physical experts and 4 logical experts on 3 + EP ranks, the mapping could look like this: + + ``` + [[0, 1, 2, 3, 0, 1], + [0, 2, 0, 1, 0, 3]] + ``` + """ + logical_to_physical_map: torch.Tensor + """ + Mapping from logical experts to physical experts. + + This is a sparse matrix, where -1 indicates no mapping. + + Shape: (num_moe_layers, num_logical_experts, num_redundant_experts + 1) + + # Example + + For a 2-layer MoE model with 6 physical experts and 4 logical experts on 3 + EP ranks, the mapping could look like this: + + ``` + [[[0, 4, -1], + [1, 5, -1], + [2, -1, -1], + [3, -1, -1]], + [[0, 2, 4], + [3, -1, -1], + [1, -1, -1], + [5, -1, -1]]] + ``` + """ + logical_replica_count: torch.Tensor + """ + Number of replicas for each logical expert. + This is exactly the non-`-1` count in the `logical_to_physical_map`. + + Shape: (num_moe_layers, num_logical_experts) + + # Example + For a 2-layer MoE model with 6 physical experts and 4 logical experts on 3 + EP ranks, the count could look like this: + + ``` + [[2, 2, 1, 1], + [3, 1, 1, 1]] + """ + + expert_load_pass: torch.Tensor + """ + Expert load during this forward pass. + We use the token count each expert processes as the load. + + Shape: (num_moe_layers, num_local_physical_experts) + """ + expert_load_window: torch.Tensor + """ + A sliding window of expert load. + + Shape: (window_size, num_moe_layers, num_local_physical_experts) + """ + expert_load_window_step: int = 0 + """ + Current step in the sliding window. + + Different from `expert_rearrangement_step`, each EP rank may have its own + `expert_load_window_step`. + """ + expert_load_window_size: int = 0 + """ + Size of the expert load sliding window. + This is a constant and is taken from the config. + """ + + expert_rearrangement_step: int = 0 + """ + Steps after last rearrangement. + Will trigger a rearrangement if it exceeds the threshold. + + NOTE: Keep in mind that all EP ranks need to have the same + `expert_rearrangement_step` value to ensure synchronization. + Otherwise, the rearrangement will hang at collective + communication calls. + """ + expert_rearrangement_step_interval: int = 0 + """ + Interval for expert rearrangement steps. + This is a constant and is taken from the config. + """ + + @staticmethod + def build_initial_global_physical_to_logical_map( + num_routed_experts: int, + num_redundant_experts: int, + ) -> Sequence[int]: + """ + Build an initial expert arrangement using the following structure: + [original routed experts, redundant experts] + + Returns: + physical_to_logical_map (Sequence[int]): A list of integers, + where each integer is the index of the logical expert + that the corresponding physical expert maps to. + """ + global_physical_to_logical_map = list(range(num_routed_experts)) + global_physical_to_logical_map += [ + i % num_routed_experts for i in range(num_redundant_experts) + ] + return global_physical_to_logical_map + + @classmethod + def build( + cls, + model: MixtureOfExperts, + device: torch.device, + parallel_config: ParallelConfig, + ) -> "EplbState": + """ + Build the initial EPLB state. + """ + physical_to_logical_map_list = ( + cls.build_initial_global_physical_to_logical_map( + model.num_routed_experts, + model.num_redundant_experts, + )) + physical_to_logical_map = torch.tensor( + physical_to_logical_map_list, + device=device, + ) + logical_to_physical_map = torch.full( + (model.num_logical_experts, model.num_redundant_experts + 1), + -1, + device=device, + ) + logical_replica_count = torch.zeros( + (model.num_logical_experts, ), + device=device, + dtype=torch.long, + ) + + for i in range(model.num_physical_experts): + logical_idx = physical_to_logical_map[i] + logical_to_physical_map[logical_idx, + logical_replica_count[logical_idx]] = i + logical_replica_count[logical_idx] += 1 + + # Duplicate initial mapping for all layers + physical_to_logical_map = physical_to_logical_map.unsqueeze(0).expand( + model.num_moe_layers, + -1, + ).contiguous() + logical_to_physical_map = logical_to_physical_map.unsqueeze(0).expand( + model.num_moe_layers, + -1, + -1, + ).contiguous() + logical_replica_count = logical_replica_count.unsqueeze(0).expand( + model.num_moe_layers, + -1, + ).contiguous() + + expert_load_pass = torch.zeros( + (model.num_moe_layers, model.num_local_physical_experts), + dtype=torch.int32, + device=device, + ) + expert_load_window_size = parallel_config.eplb_window_size + expert_load_window = torch.zeros( + (expert_load_window_size, model.num_moe_layers, + model.num_local_physical_experts), + dtype=torch.int32, + device=device, + ) + + # Set the initial progress of rearrangement to 3/4 + eplb_step_interval = parallel_config.eplb_step_interval + expert_rearrangement_step = max( + 0, eplb_step_interval - eplb_step_interval // 4) + + model.set_eplb_state( + expert_load_pass, + logical_to_physical_map, + logical_replica_count, + ) + + return cls( + physical_to_logical_map, + logical_to_physical_map, + logical_replica_count, + expert_load_pass, + expert_load_window, + expert_load_window_size=expert_load_window_size, + expert_rearrangement_step=expert_rearrangement_step, + expert_rearrangement_step_interval=eplb_step_interval, + ) + + def step(self, + model: MixtureOfExperts, + is_dummy: bool = False, + is_profile: bool = False, + log_stats: bool = False) -> None: + """ + Step the EPLB state. + + Args: + model (MixtureOfExperts): The MoE model. + is_dummy (bool): If `True`, this is a dummy step and the load + metrics recorded in this forward pass will not count. Defaults + to `False`. + is_profile (bool): If `True`, perform a dummy rearrangement + with maximum communication cost. This is used in `profile_run` + to reserve enough memory for the communication buffer. + log_stats (bool): If `True`, log the expert load metrics. + + # Stats + The metrics are all summed up across layers. + - `avg_tokens`: The average load across ranks. + - `max_tokens`: The maximum load across ranks. + - `balancedness`: The ratio of average load to maximum load. + """ + + if is_profile: + self.rearrange(model, is_profile=True) + return + + if is_dummy: + # Do not record load metrics for dummy steps + self.expert_load_pass.zero_() + + if log_stats: + # `num_tokens`: (num_moe_layers,) + num_tokens = self.expert_load_pass.sum(dim=-1) + + # Collect load metrics from all ranks + ep_group = get_ep_group().device_group + num_tokens_list = [ + torch.empty_like(num_tokens) for _ in range(ep_group.size()) + ] + all_gather(num_tokens_list, num_tokens, group=ep_group) + # Stack to get (num_ranks, num_moe_layers) + num_tokens_per_rank = torch.stack(num_tokens_list).float() + + # Compute balancedness ratio: + # for each layer: + # (mean load across ranks) / (max load across ranks) + avg_tokens_tensor = num_tokens_per_rank.mean(dim=0).sum(dim=0) + max_tokens_tensor = num_tokens_per_rank.max(dim=0).values.sum( + dim=0) + + # Just to make type checker happy + tokens_tensors: list[float] = torch.stack( + [avg_tokens_tensor, max_tokens_tensor]).tolist() + avg_tokens, max_tokens = tokens_tensors + balancedness = avg_tokens / max_tokens if max_tokens > 0 else 0.0 + + if ep_group.rank() == 0: + logger.info( + "EPLB step: avg_tokens=%.2f, max_tokens=%d, " + "balancedness=%.4f", avg_tokens, max_tokens, balancedness) + + # Update the expert load sliding window + if not is_dummy: + self.expert_load_window[self.expert_load_window_step] = ( + self.expert_load_pass.clone()) + self.expert_load_window_step += 1 + if self.expert_load_window_step >= self.expert_load_window_size: + self.expert_load_window_step = 0 + self.expert_load_pass.zero_() + + # Step the expert rearrangement step + # Note that even if this is a dummy step, we still increment the + # rearrangement step and perform rearrangement to ensure all ranks are + # performing collective communication. + self.expert_rearrangement_step += 1 + if (self.expert_rearrangement_step + >= self.expert_rearrangement_step_interval): + self.expert_rearrangement_step = 0 + self.rearrange(model) + + def rearrange(self, + model: MixtureOfExperts, + is_profile: bool = False) -> None: + """ + Rearrange the experts according to the current load. + """ + + ep_group = get_ep_group().device_group + ep_rank = ep_group.rank() + + time_start = None + is_main_rank = ep_rank == 0 + if is_main_rank: + torch.cuda.synchronize() + time_start = time.perf_counter() + logger.info("Rearranging experts %s...", + "(profile)" if is_profile else "") + + # This mapping is only used here, so we do not store it in the state + physical_expert_start = ep_rank * model.num_local_physical_experts + physical_expert_end = (physical_expert_start + + model.num_local_physical_experts) + # (num_moe_layers, num_local_physical_experts) + local_physical_to_logical_map = self.physical_to_logical_map[ + :, + physical_expert_start:physical_expert_end, + ] + + # Map the local physical expert load to global logical experts + logical_expert_load_window = torch.zeros( + self.expert_load_window_size, + model.num_moe_layers, + model.num_logical_experts, + dtype=self.expert_load_window.dtype, + device=self.expert_load_window.device, + ) + logical_expert_load_window.scatter_add_( + dim=-1, + index=local_physical_to_logical_map.unsqueeze(0).expand_as( + self.expert_load_window).long(), + src=self.expert_load_window, + ) + + # Perform all-reduce to get the expert load across all ranks + global_expert_load_window = logical_expert_load_window.sum(dim=0) + all_reduce(global_expert_load_window, group=ep_group) + + # TODO(bowen): Treat differently for prefill and decode nodes + num_replicas = model.num_physical_experts + num_groups = model.num_expert_groups + num_nodes = get_node_count() + num_gpus = ep_group.size() + + if num_gpus % num_nodes != 0: + logger.warning_once( + f"num_gpus % num_nodes != 0, " + "not using hierarchical rearrangement algorithm.\n" + f"{num_gpus=}, {num_nodes=}") + + # Get new expert mappings + ( + new_physical_to_logical_map, + new_logical_to_physical_map, + new_logical_replica_count, + ) = (rebalance_experts( + global_expert_load_window, + num_replicas, + num_groups, + num_nodes, + num_gpus, + )) + + # Update expert weights + rearrange_expert_weights_inplace( + self.physical_to_logical_map, + new_physical_to_logical_map, + model.expert_weights, + ep_group, + is_profile, + ) + + if not is_profile: + self.physical_to_logical_map.copy_(new_physical_to_logical_map) + self.logical_to_physical_map.copy_(new_logical_to_physical_map) + self.logical_replica_count.copy_(new_logical_replica_count) + + if is_main_rank: + assert time_start is not None + torch.cuda.synchronize() + time_end = time.perf_counter() + logger.info( + "Rearranged experts%sin %.2f seconds.", + " (profile) " if is_profile else " ", + time_end - time_start, + ) diff --git a/vllm/distributed/eplb/rebalance_algo.py b/vllm/distributed/eplb/rebalance_algo.py new file mode 100644 index 0000000000000..7ad6d566b55bb --- /dev/null +++ b/vllm/distributed/eplb/rebalance_algo.py @@ -0,0 +1,233 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Expert parallelism load balancer (EPLB) for vLLM. + +This module implements the core rearrangement algorithm. + +The rearrangement algorithm is adapted from +[DeepSeek EPLB](https://github.com/deepseek-ai/eplb). + +Please find at [#12](https://github.com/deepseek-ai/EPLB/issues/12) an example +on how the EPLB algorithm works. +""" + +import torch + + +def balanced_packing(weight: torch.Tensor, + num_packs: int) -> tuple[torch.Tensor, torch.Tensor]: + """ + Pack n weighted objects to m packs, such that each bin contains exactly + n/m objects and the weights of all packs are as balanced as possible. + + Parameters: + weight: [X, n], the weight of each item + num_packs: number of packs + + Returns: + pack_index: [X, n], the pack index of each item + rank_in_pack: [X, n], the rank of the item in the pack + """ + num_layers, num_groups = weight.shape + assert num_groups % num_packs == 0 + groups_per_pack = num_groups // num_packs + + if groups_per_pack == 1: + pack_index = torch.arange(weight.size(-1), + dtype=torch.int64, + device=weight.device).expand(weight.shape) + rank_in_pack = torch.zeros_like(weight, dtype=torch.int64) + return pack_index, rank_in_pack + + indices = weight.float().sort(-1, descending=True).indices.cpu() + pack_index = torch.full_like(weight, + fill_value=-1, + dtype=torch.int64, + device="cpu") + rank_in_pack = torch.full_like(pack_index, fill_value=-1) + for i in range(num_layers): + pack_weights = [0] * num_packs + pack_items = [0] * num_packs + for group in indices[i]: + pack = min( + (i + for i in range(num_packs) if pack_items[i] < groups_per_pack), + key=pack_weights.__getitem__, + ) + assert pack_items[pack] < groups_per_pack + pack_index[i, group] = pack + rank_in_pack[i, group] = pack_items[pack] + pack_weights[pack] += weight[i, group] + pack_items[pack] += 1 + return pack_index, rank_in_pack + + +def replicate_experts( + weight: torch.Tensor, + num_phy: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Replicate `num_log` experts to `num_phy` replicas, such that the maximum + load of all replicas is minimized. + + Parameters: + weight: [X, num_log] + num_phy: total number of experts after replication + + Returns: + phy2log: [X, num_phy], logical expert id of each physical expert + rank: [X, num_phy], the replica rank + logcnt: [X, num_log], number of replicas for each logical expert + """ + n, num_log = weight.shape + num_redundant = num_phy - num_log + assert num_redundant >= 0 + device = weight.device + phy2log = torch.arange(num_phy, dtype=torch.int64, + device=device).repeat(n, 1) + rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device) + logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device) + arangen = torch.arange(n, dtype=torch.int64, device=device) + for i in range(num_log, num_phy): + redundant_indices = (weight / logcnt).max(dim=-1).indices + phy2log[:, i] = redundant_indices + rank[:, i] = logcnt[arangen, redundant_indices] + logcnt[arangen, redundant_indices] += 1 + return phy2log, rank, logcnt + + +def rebalance_experts_hierarchical( + weight: torch.Tensor, + num_physical_experts: int, + num_groups: int, + num_nodes: int, + num_gpus: int, +): + """ + Parameters: + weight: [num_moe_layers, num_logical_experts] + num_physical_experts: number of physical experts after replication + num_groups: number of expert groups + num_nodes: number of server nodes, where the intra-node network + (e.g, NVLink) is faster + num_gpus: number of GPUs, must be a multiple of `num_nodes` + + Returns: + physical_to_logical_map: [num_moe_layers, num_physical_experts] + logical_to_physical_map: [num_moe_layers, num_logical_experts, X] + logical_count: [num_moe_layers, num_logical_experts] + """ + num_layers, num_logical_experts = weight.shape + assert num_logical_experts % num_groups == 0 + group_size = num_logical_experts // num_groups + assert num_groups % num_nodes == 0 + groups_per_node = num_groups // num_nodes + assert num_gpus % num_nodes == 0 + assert num_physical_experts % num_gpus == 0 + phy_experts_per_gpu = num_physical_experts // num_gpus + + def inverse(perm: torch.Tensor) -> torch.Tensor: + inv = torch.empty_like(perm) + inv.scatter_( + 1, + perm, + torch.arange(perm.size(1), dtype=torch.int64, + device=perm.device).expand(perm.shape), + ) + return inv + + # Step 1: pack groups to nodes + tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1) + group_pack_index, group_rank_in_pack = balanced_packing( + tokens_per_group, num_nodes) + log2mlog = (((group_pack_index * groups_per_node + group_rank_in_pack) * + group_size).unsqueeze(-1) + + torch.arange(group_size, + dtype=torch.int64, + device=group_pack_index.device)).flatten(-2) + mlog2log = inverse(log2mlog) + + # Step 2: construct redundant experts within nodes + # [num_layers * num_nodes, num_logical_experts // num_nodes] + tokens_per_mlog = weight.gather(-1, mlog2log).view( + -1, num_logical_experts // num_nodes) + phy2mlog, phyrank, mlogcnt = replicate_experts( + tokens_per_mlog, num_physical_experts // num_nodes) + + # Step 3: pack physical_experts to GPUs + # [num_layers * num_nodes, num_physical_experts // num_nodes] + tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog) + pack_index, rank_in_pack = balanced_packing(tokens_per_phy, + num_gpus // num_nodes) + phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack + pphy2phy = inverse(phy2pphy) + + pphy2mlog = phy2mlog.gather( + -1, pphy2phy) # [num_layers * num_nodes, num_log_per_nodes] + pphy2mlog = (pphy2mlog.view(num_layers, num_nodes, -1) + torch.arange( + 0, + num_logical_experts, + num_logical_experts // num_nodes, + device=group_pack_index.device, + ).view(1, -1, 1)).flatten(-2) + pphy2log = mlog2log.gather(-1, pphy2mlog) + pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1) + logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog) + return pphy2log, pphyrank, logcnt + + +def rebalance_experts( + weight: torch.Tensor, + num_replicas: int, + num_groups: int, + num_nodes: int, + num_gpus: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Entry point for expert-parallelism load balancer. + + Parameters: + weight: [layers, num_logical_experts], the load statistics for all + logical experts + num_replicas: number of physical experts, must be a multiple of + `num_gpus` + num_groups: number of expert groups + num_nodes: number of server nodes, where the intra-node network + (e.g, NVLink) is faster + num_gpus: number of GPUs, must be a multiple of `num_nodes` + + Returns: + physical_to_logical_map: [layers, num_replicas], the expert index of + each replica + logical_to_physical_map: [layers, num_logical_experts, X], the replica + indices for each expert + expert_count: [layers, num_logical_experts], number of physical + replicas for each logical expert + """ + num_layers, num_logical_experts = weight.shape + weight = weight.float().cpu() + if num_groups % num_nodes == 0: + # use hierarchical load-balance policy + phy2log, phyrank, logcnt = rebalance_experts_hierarchical( + weight, num_replicas, num_groups, num_nodes, num_gpus) + else: + # use global load-balance policy + phy2log, phyrank, logcnt = rebalance_experts_hierarchical( + weight, num_replicas, 1, 1, num_gpus) + num_redundant_experts = num_replicas - num_logical_experts + maxlogcnt = num_redundant_experts + 1 + log2phy: torch.Tensor = torch.full( + (num_layers, num_logical_experts, maxlogcnt), + -1, + dtype=torch.int64, + device=logcnt.device, + ) + log2phy.view(num_layers, -1).scatter_( + -1, + phy2log * maxlogcnt + phyrank, + torch.arange(num_replicas, dtype=torch.int64, + device=log2phy.device).expand(num_layers, -1), + ) + return phy2log, log2phy, logcnt + + +__all__ = ["rebalance_experts"] diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py new file mode 100644 index 0000000000000..cf173c734afd1 --- /dev/null +++ b/vllm/distributed/eplb/rebalance_execute.py @@ -0,0 +1,306 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +The actual execution of the rearrangement. + +This involves the exchange of expert weights between GPUs. +""" + +from collections.abc import Iterable, MutableSequence, Sequence +from functools import partial + +import torch +from torch.distributed import (P2POp, ProcessGroup, all_gather, + batch_isend_irecv, get_global_rank) + + +def idx_local_to_global( + local_idx: int, + local_cnt: int, + ep_rank: int, +) -> int: + """ + Convert a local expert index to a global expert index. + """ + return ep_rank * local_cnt + local_idx + + +def idx_global_to_local( + global_idx: int, + local_cnt: int, + ep_rank: int, +) -> int: + """ + Convert a global expert index to a local expert index. + """ + return global_idx - ep_rank * local_cnt + + +def global_idx_to_rank( + global_idx: int, + local_cnt: int, +) -> int: + """ + Convert a global expert index to a rank index. + """ + return global_idx // local_cnt + + +def get_ep_ranks_with_expert( + idx: int, + num_local_experts: int, + old_indices: Sequence[int], + new_indices: Sequence[int], +) -> tuple[MutableSequence[int], MutableSequence[int]]: + """ + Get the ranks of the experts that need to be exchanged. + + Args: + idx: The index of the expert. + num_local_experts: The number of local experts. + old_indices: The old indices of the experts. + new_indices: The new indices of the experts. + + Returns: + A tuple of two lists: + - The ranks of the experts that need to be sent. + - The ranks of the experts that need to be received. + """ + global2rank = partial( + global_idx_to_rank, + local_cnt=num_local_experts, + ) + + ranks_to_send: list[int] = [] + ranks_to_recv: list[int] = [] + + for i, e in enumerate(old_indices): + if e == idx: + rank = global2rank(i) + if not ranks_to_send or ranks_to_send[-1] != rank: + ranks_to_send.append(rank) + + for i, e in enumerate(new_indices): + if e == idx: + rank = global2rank(i) + if not ranks_to_recv or ranks_to_recv[-1] != rank: + ranks_to_recv.append(rank) + + # Remove those ranks that can get this expert locally. + ranks_to_send_set = set(ranks_to_send) + ranks_to_recv_actual = [ + rank for rank in ranks_to_recv if rank not in ranks_to_send_set + ] + + return ranks_to_send, ranks_to_recv_actual + + +def shuffle_layer( + num_local_experts: int, + ep_rank: int, + old_indices: Sequence[int], + new_indices: Sequence[int], + expert_weights: Iterable[torch.Tensor], + expert_weights_buffer: Sequence[torch.Tensor], + ep_group: ProcessGroup, +) -> None: + """ + Perform expert weights rearrangement of one layer. + """ + local2global = partial( + idx_local_to_global, + local_cnt=num_local_experts, + ep_rank=ep_rank, + ) + + # 0. Do nothing for experts that did not change. + is_unchanged = [ + old_indices[local2global(i)] == new_indices[local2global(i)] + for i in range(num_local_experts) + ] + + # 1. Perform weight copy inside the local rank. + is_received_locally = is_unchanged[:] + for src in range(num_local_experts): + src_global = local2global(src) + for dst in range(num_local_experts): + dst_global = local2global(dst) + if is_received_locally[dst]: + continue + if old_indices[src_global] == new_indices[dst_global]: + is_received_locally[dst] = True + for weight, buffer in zip(expert_weights, + expert_weights_buffer): + buffer[dst].copy_(weight[src]) + + p2p_ops: list[P2POp] = [] + + # 2. Initiate sending of weights. + experts_send_loc: dict[int, int] = {} + for src in range(num_local_experts): + expert = old_indices[local2global(src)] + if expert in experts_send_loc: + continue + experts_send_loc[expert] = src + + # We need to sort here to match send/recv + for expert, src in sorted(experts_send_loc.items()): + ranks_to_send, ranks_to_recv = get_ep_ranks_with_expert( + expert, + num_local_experts, + old_indices, + new_indices, + ) + + # Calculate the ranks to send by this rank + num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send) + sender_pos = ranks_to_send.index(ep_rank) + recv_begin = sender_pos * num_dst_per_sender + recv_end = recv_begin + num_dst_per_sender + recv_ranks = ranks_to_recv[recv_begin:recv_end] + + # Tackle remainders + remainder_start = len(ranks_to_send) * num_dst_per_sender + recver_pos = remainder_start + sender_pos + if recver_pos < len(ranks_to_recv): + recv_ranks.append(ranks_to_recv[recver_pos]) + + for dst in recv_ranks: + dst_global = get_global_rank(ep_group, dst) + p2p_ops += [ + P2POp( + torch.distributed.isend, + weight[src], + dst_global, + ) for weight in expert_weights + ] + + # 3. Initiate receiving of weights. + experts_recv_loc: dict[int, int] = {} + for dst in range(num_local_experts): + if is_received_locally[dst]: + continue + expert = new_indices[local2global(dst)] + if expert in experts_recv_loc: + continue + experts_recv_loc[expert] = dst + + # We need to sort here to match send/recv + for expert, dst in sorted(experts_recv_loc.items()): + ranks_to_send, ranks_to_recv = get_ep_ranks_with_expert( + expert, + num_local_experts, + old_indices, + new_indices, + ) + + # Calculate the rank to recv by this rank + num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send) + recver_pos = ranks_to_recv.index(ep_rank) + remainder_start = len(ranks_to_send) * num_dst_per_sender + if recver_pos < remainder_start: + src = ranks_to_send[recver_pos // num_dst_per_sender] + else: + src = ranks_to_send[recver_pos - remainder_start] + + src_global = get_global_rank(ep_group, src) + p2p_ops += [ + P2POp( + torch.distributed.irecv, + weight[dst], + src_global, + ) for weight in expert_weights_buffer + ] + + # 4. Execute the P2P operations. The real communication happens here. + if p2p_ops: + reqs = batch_isend_irecv(p2p_ops) + for req in reqs: + req.wait() + + # 5. Copy the weights from the buffer back to the original weights. + for dst in range(num_local_experts): + if is_unchanged[dst]: + continue + if is_received_locally[dst]: + for weight, buffer in zip(expert_weights, expert_weights_buffer): + weight[dst].copy_(buffer[dst]) + else: + expert = new_indices[local2global(dst)] + src = experts_recv_loc[expert] + for weight, buffer in zip(expert_weights, expert_weights_buffer): + weight[dst].copy_(buffer[src]) + + +def rearrange_expert_weights_inplace( + old_global_expert_indices: torch.Tensor, + new_global_expert_indices: torch.Tensor, + expert_weights: Sequence[Iterable[torch.Tensor]], + ep_group: ProcessGroup, + is_profile: bool = False, +) -> None: + """ + Rearranges the expert weights in place according to the new expert indices. + + The value of the indices arguments are logical indices of the experts, + while keys are physical. + + Args: + old_global_expert_indices: Shape (num_moe_layers, num_physical_experts). + new_global_expert_indices: Shape (num_moe_layers, num_physical_experts). + expert_weights: A sequence of shape (num_moe_layers)(weight_count) + of tensors of shape (num_local_physical_experts, hidden_size_i). + For example, a linear layer may have up and down projection, + so weight_count = 2. Each weight's hidden size can be different. + ep_group: The device process group for expert parallelism. + is_profile (bool): If `True`, do not perform any actual weight copy. + This is used during profile run, where we only perform dummy + communications to reserve enough memory for the buffers. + """ + num_moe_layers, num_physical_experts = old_global_expert_indices.shape + assert len(expert_weights) == num_moe_layers + + num_local_physical_experts = next(iter(expert_weights[0])).shape[0] + assert new_global_expert_indices.shape == (num_moe_layers, + num_physical_experts) + + ep_rank = ep_group.rank() + ep_size = ep_group.size() + assert num_physical_experts == ep_size * num_local_physical_experts + + # A buffer to hold the expert weights in one layer during the exchange. + # NOTE: Currently we assume the same weights across different layers + # have the same shape. + expert_weights_buffer = [torch.empty_like(w) for w in expert_weights[0]] + + if is_profile: + # Maximum send size is to send all local experts to all ranks, + # So we use a dummy `all_gather` to reserve enough communication buffer + for weight, buffer in zip(expert_weights[0], expert_weights_buffer): + # A `/dev/null`-like buffer to avoid real memory allocation + dummy_recv_buffer = [buffer for _ in range(ep_size)] + # NOTE(bowen): Needed this barrier to avoid OOM during actual + # execution. I'm not very sure why this is needed + torch.distributed.barrier() + all_gather( + dummy_recv_buffer, + weight, + group=ep_group, + ) + return + + for layer in range(num_moe_layers): + # NOTE(bowen): We need this synchronize to run, but I don't know why. + # If you figure out the reason, please let me know -- thank you! + torch.cuda.synchronize() + shuffle_layer( + num_local_physical_experts, + ep_rank, + old_global_expert_indices[layer].tolist(), + new_global_expert_indices[layer].tolist(), + expert_weights[layer], + expert_weights_buffer, + ep_group, + ) + + +__all__ = ["rearrange_expert_weights_inplace"] diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 9d1008b6b350a..6c908f88b9a92 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -320,6 +320,11 @@ class EngineArgs: data_parallel_rpc_port: Optional[int] = None data_parallel_backend: str = ParallelConfig.data_parallel_backend enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel + enable_eplb: bool = ParallelConfig.enable_eplb + num_redundant_experts: int = ParallelConfig.num_redundant_experts + eplb_window_size: int = ParallelConfig.eplb_window_size + eplb_step_interval: int = ParallelConfig.eplb_step_interval + eplb_log_balancedness: bool = ParallelConfig.eplb_log_balancedness max_parallel_loading_workers: Optional[ int] = ParallelConfig.max_parallel_loading_workers block_size: Optional[BlockSize] = CacheConfig.block_size @@ -666,6 +671,16 @@ class EngineArgs: parallel_group.add_argument( "--enable-expert-parallel", **parallel_kwargs["enable_expert_parallel"]) + parallel_group.add_argument("--enable-eplb", + **parallel_kwargs["enable_eplb"]) + parallel_group.add_argument("--num-redundant-experts", + **parallel_kwargs["num_redundant_experts"]) + parallel_group.add_argument("--eplb-window-size", + **parallel_kwargs["eplb_window_size"]) + parallel_group.add_argument("--eplb-step-interval", + **parallel_kwargs["eplb_step_interval"]) + parallel_group.add_argument("--eplb-log-balancedness", + **parallel_kwargs["eplb_log_balancedness"]) parallel_group.add_argument( "--max-parallel-loading-workers", **parallel_kwargs["max_parallel_loading_workers"]) @@ -1135,6 +1150,11 @@ class EngineArgs: data_parallel_rpc_port=data_parallel_rpc_port, data_parallel_backend=data_parallel_backend, enable_expert_parallel=self.enable_expert_parallel, + enable_eplb=self.enable_eplb, + num_redundant_experts=self.num_redundant_experts, + eplb_window_size=self.eplb_window_size, + eplb_step_interval=self.eplb_step_interval, + eplb_log_balancedness=self.eplb_log_balancedness, max_parallel_loading_workers=self.max_parallel_loading_workers, disable_custom_all_reduce=self.disable_custom_all_reduce, ray_workers_use_nsight=self.ray_workers_use_nsight, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 133881fd04990..6fe95d32a10e7 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -3,9 +3,10 @@ import importlib from abc import abstractmethod +from collections.abc import Iterable from dataclasses import dataclass from enum import Enum -from typing import Callable, Optional, Union +from typing import Callable, Literal, Optional, Union, overload import torch import torch.nn.functional as F @@ -20,6 +21,7 @@ from vllm.distributed import (get_dp_group, get_ep_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) +from vllm.distributed.eplb.eplb_state import EplbState from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger from vllm.model_executor.custom_op import CustomOp @@ -435,6 +437,10 @@ class FusedMoEMethodBase(QuantizeMethodBase): e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: raise NotImplementedError @@ -574,7 +580,15 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for `UnquantizedFusedMoEMethod` yet.") + return self.forward( x=x, layer=layer, @@ -821,6 +835,7 @@ class FusedMoE(torch.nn.Module): reduce_results: Whether to all all_reduce on the output of the layer renomalize: Whether to renormalize the logits in the fused_moe kernel quant_config: Quantization configure. + enable_eplb: Whether to enable expert parallelism load balancer. """ def __init__( @@ -845,6 +860,8 @@ class FusedMoE(torch.nn.Module): e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + num_redundant_experts: int = 0, ): super().__init__() if params_dtype is None: @@ -860,7 +877,7 @@ class FusedMoE(torch.nn.Module): get_dp_group().world_size), vllm_parallel_config=vllm_config.parallel_config)) - self.global_num_experts = num_experts + self.global_num_experts = num_experts + num_redundant_experts # For smuggling this layer into the fused moe custom op compilation_config = vllm_config.compilation_config @@ -869,8 +886,20 @@ class FusedMoE(torch.nn.Module): compilation_config.static_forward_context[prefix] = self self.layer_name = prefix + self.enable_eplb = enable_eplb + self.expert_load_view: Optional[torch.Tensor] = None + self.logical_to_physical_map: Optional[torch.Tensor] = None + self.logical_replica_count: Optional[torch.Tensor] = None + # Determine expert maps if self.use_ep: + if self.enable_eplb: + assert self.global_num_experts % self.ep_size == 0, \ + "EPLB currently only supports even distribution of " \ + "experts across ranks." + else: + assert num_redundant_experts == 0, \ + "Redundant experts are only supported with EPLB." self.local_num_experts, self.expert_map = determine_expert_map( ep_size=self.ep_size, ep_rank=self.ep_rank, @@ -937,6 +966,20 @@ class FusedMoE(torch.nn.Module): assert isinstance(quant_method, FusedMoEMethodBase) self.quant_method = quant_method + if self.enable_eplb: + from vllm.model_executor.layers.quantization.fp8 import ( + Fp8MoEMethod) + if not isinstance(quant_method, Fp8MoEMethod): + # TODO: Add support for additional quantization methods. + # The implementation for other quantization methods does not + # contain essential differences, but the current quant API + # design causes duplicated work when extending to new + # quantization methods, so I'm leaving it for now. + # If you plan to add support for more quantization methods, + # please refer to the implementation in `Fp8MoEMethod`. + raise NotImplementedError("EPLB is only supported for FP8 " + "quantization for now.") + moe_quant_params = { "num_experts": self.local_num_experts, "hidden_size": hidden_size, @@ -965,8 +1008,9 @@ class FusedMoE(torch.nn.Module): dtype=act_dtype, device=torch.cuda.current_device()) + # Note here we use `num_experts` which is logical expert count self.batched_router_logits = torch.zeros( - (envs.VLLM_MOE_DP_CHUNK_SIZE, self.global_num_experts), + (envs.VLLM_MOE_DP_CHUNK_SIZE, num_experts), dtype=act_dtype, device=torch.cuda.current_device()) @@ -1130,13 +1174,33 @@ class FusedMoE(torch.nn.Module): return expert_id return self.expert_map[expert_id].item() + @overload def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, weight_name: str, - shard_id: str, expert_id: int) -> None: + shard_id: str, expert_id: int, + return_success: Literal[False]) -> None: + ... + @overload + def weight_loader(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, weight_name: str, + shard_id: str, expert_id: int, + return_success: Literal[True]) -> bool: + ... + + def weight_loader(self, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + shard_id: str, + expert_id: int, + return_success: bool = False) -> Optional[bool]: expert_id = self._map_global_expert_id_to_local_expert_id(expert_id) if expert_id == -1: - return + # Failed to load this param since it's not local to this rank + return False if return_success else None + # Hereafter, `expert_id` is local physical id + quant_method_name = self.quant_method.__class__.__name__ # compressed-tensors checkpoints with packed weights are stored flipped # TODO (mgoin): check self.quant_method.quant_config.quant_format @@ -1163,7 +1227,7 @@ class FusedMoE(torch.nn.Module): if is_gguf_weight_type: param.weight_type = loaded_weight.item() param.data.copy_(loaded_weight) - return + return True if return_success else None # is_transposed: if the dim to shard the weight # should be flipped. Required by GPTQ, compressed-tensors @@ -1202,7 +1266,7 @@ class FusedMoE(torch.nn.Module): self._load_single_value(param=param, loaded_weight=loaded_weight, expert_id=expert_id) - return + return True if return_success else None # Case g_idx if "g_idx" in weight_name: @@ -1211,7 +1275,7 @@ class FusedMoE(torch.nn.Module): loaded_weight=loaded_weight, expert_data=expert_data, tp_rank=self.tp_rank) - return + return True if return_success else None if "ModelOpt" in quant_method_name: if ('weight_scale_2' in weight_name @@ -1227,7 +1291,7 @@ class FusedMoE(torch.nn.Module): loaded_weight=loaded_weight, expert_data=expert_data, tp_rank=self.tp_rank) - return + return True if return_success else None # Case weight scales, zero_points and offset if ("scale" in weight_name or "zero" in weight_name @@ -1264,7 +1328,7 @@ class FusedMoE(torch.nn.Module): else: raise ValueError( f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}") - return + return True if return_success else None # Case weight_shape if "weight_shape" in weight_name: @@ -1272,7 +1336,7 @@ class FusedMoE(torch.nn.Module): self._load_single_value(param=param, loaded_weight=loaded_weight, expert_id=expert_id) - return + return True if return_success else None # Case model weights if "weight" in weight_name: @@ -1282,23 +1346,77 @@ class FusedMoE(torch.nn.Module): loaded_weight=loaded_weight, expert_data=expert_data, tp_rank=self.tp_rank) - return + return True if return_success else None + + return False if return_success else None + + def get_expert_weights(self) -> Iterable[torch.Tensor]: + weights = list(self.named_parameters()) + assert all(weight.is_contiguous() for _, weight in weights) + + # Filter out the non-expert weights. + # `e_score_correction_bias` is a bias for each logical expert, + # with shape (num_logical_experts,), not an expert weight. + NON_EXPERT_WEIGHTS = { + "e_score_correction_bias", + } + + return [ + weight.view(self.local_num_experts, -1) for name, weight in weights + if name not in NON_EXPERT_WEIGHTS + ] + + def set_eplb_state( + self, + moe_layer_idx: int, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + ) -> None: + """ + Register the EPLB state in this layer. + + This is used later in forward pass, where we get the expert mapping + and record the load metrics in `expert_load_view`. + """ + self.expert_load_view = expert_load_view[moe_layer_idx] + self.logical_to_physical_map = logical_to_physical_map[moe_layer_idx] + self.logical_replica_count = logical_replica_count[moe_layer_idx] @staticmethod - def select_experts(hidden_states: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - use_grouped_topk: bool, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, - scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None, - indices_type: Optional[torch.dtype] = None): + def select_experts( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + use_grouped_topk: bool, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + indices_type: Optional[torch.dtype] = None, + enable_eplb: bool = False, + expert_map: Optional[torch.Tensor] = None, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Route the input hidden states to the top-k experts based on the + router logits. + + Returns: + (topk_weights, topk_ids) (tuple[torch.Tensor, torch.Tensor]): + The weights and *global physical* expert ids of the top-k experts. + + **Compatibility**: When EPLB is not enabled, the returned ids are + equivalent to global logical ids, so should be compatible with + plain MoE implementations without redundant experts. + """ from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk - # DeekSeekv2 uses grouped_top_k + # DeepSeekv2 uses grouped_top_k if use_grouped_topk: assert topk_group is not None assert num_expert_group is not None @@ -1330,6 +1448,74 @@ class FusedMoE(torch.nn.Module): if indices_type is not None: topk_ids = topk_ids.to(dtype=indices_type) + if enable_eplb: + assert expert_load_view is not None + assert logical_to_physical_map is not None + assert logical_replica_count is not None + + # 1. Convert the logical expert ids to physical expert ids + # Directly select a random replica for each logical expert + + # TODO: maybe optimize this by using specified kernels, + # or compute pseudo-random indices by modulo + + # In case `indices_type` is not `torch.long` or `torch.int`, + # e.g. `torch.uint32` as required by dispatch/combine kernels + topk_ids_long = topk_ids.long() + replica_indices = ( + torch.rand_like(topk_ids, dtype=torch.float) * + logical_replica_count[topk_ids_long]).long().unsqueeze(-1) + physical_ids = logical_to_physical_map[topk_ids_long].gather( + -1, replica_indices).squeeze(-1) + + topk_ids = physical_ids + + # 2. Record expert load metrics. + + # TODO(bowen): When using `FusedMoEModularKernel`, this + # can be done in a more unified way, since + # `FusedMoEPrepareAndFinalize` will return the expert + # token count, in some cases directly from the kernel. + # However, now there are many code paths not using + # the modular kernel, e.g. calling `fused_experts`, + # so we decide to keep the logic here. + # + # If later refactor moved all the MoE kernel calls + # to the modular kernel, we can move this logic there + # to achieve better efficiency. + + # `expert_load_view`: (num_logical_experts,) + + # Mask out non-local experts + if expert_map is not None: + topk_ids_local = expert_map[topk_ids] + topk_ids_flatten = topk_ids_local.flatten() + else: + topk_ids_flatten = topk_ids.flatten() + + # Should be equivalent to: + # ``` + # topk_ids_masked = topk_ids_local[topk_ids_local >= 0] + # expert_load_view += topk_ids_masked.bincount( + # minlength=expert_load_view.shape[0]) + # ``` + # We use `scatter_add_` since `bincount` cannot be compiled + + # Performance optimization: + # `masked_fill` is significantly faster than `masked_select` + invalid_mask = topk_ids_flatten < 0 + # Replace invalid expert ids with 0 (just a dummy position) + # to avoid out-of-bounds errors in scatter_add_ + index = topk_ids_flatten.masked_fill_(invalid_mask, 0) + # `src` is the valid mask, which is 1 for valid and 0 for invalid + src = ~invalid_mask + + expert_load_view.scatter_add_(dim=0, + index=index.long(), + src=src.to(expert_load_view)) + + topk_ids = topk_ids.to(dtype=indices_type) + return topk_weights, topk_ids def must_reduce_shared_expert_outputs(self) -> bool: @@ -1410,6 +1596,10 @@ class FusedMoE(torch.nn.Module): scoring_func=self.scoring_func, e_score_correction_bias=self.e_score_correction_bias, activation=self.activation, + enable_eplb=self.enable_eplb, + expert_load_view=self.expert_load_view, + logical_to_physical_map=self.logical_to_physical_map, + logical_replica_count=self.logical_replica_count, ) if not skip_result_store: @@ -1467,6 +1657,10 @@ class FusedMoE(torch.nn.Module): e_score_correction_bias=self.e_score_correction_bias, activation=self.activation, apply_router_weight_on_input=self.apply_router_weight_on_input, + enable_eplb=self.enable_eplb, + expert_load_view=self.expert_load_view, + logical_to_physical_map=self.logical_to_physical_map, + logical_replica_count=self.logical_replica_count, ) if do_naive_dispatch_combine: @@ -1481,16 +1675,30 @@ class FusedMoE(torch.nn.Module): @classmethod def make_expert_params_mapping( - cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str, + cls, + ckpt_gate_proj_name: str, + ckpt_down_proj_name: str, ckpt_up_proj_name: str, - num_experts: int) -> list[tuple[str, str, int, str]]: + num_experts: int, + num_redundant_experts: int = 0) -> list[tuple[str, str, int, str]]: + + num_physical_experts = num_experts + num_redundant_experts + + # In the returned mapping: + # - `expert_id` is the physical expert id + # - `weight_name` contains the weight name of the logical expert + # So that we should map the expert id to logical in `weight_name` + physical_to_logical_map = \ + EplbState.build_initial_global_physical_to_logical_map( + num_experts, num_redundant_experts) return [ # (param_name, weight_name, expert_id, shard_id) ("experts.w13_" if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_", - f"experts.{expert_id}.{weight_name}.", expert_id, shard_id) - for expert_id in range(num_experts) for shard_id, weight_name in [ + f"experts.{physical_to_logical_map[expert_id]}.{weight_name}.", + expert_id, shard_id) for expert_id in range(num_physical_experts) + for shard_id, weight_name in [ ("w1", ckpt_gate_proj_name), ("w2", ckpt_down_proj_name), ("w3", ckpt_up_proj_name), diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index 56d803c6baf12..aff54bc495b2d 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -482,7 +482,15 @@ class AWQMoEMethod(FusedMoEMethodBase): e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for `AWQMoEMethod` yet.") + assert activation == "silu", "Only SiLU activation is supported." if apply_router_weight_on_input: diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index f14131c5f05b3..7703b9e687c4a 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -331,7 +331,15 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for " + "`CompressedTensorsW8A8Fp8MoEMethod` yet.") topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, @@ -593,7 +601,15 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod): e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for " + "`CompressedTensorsW8A8Fp8MoECutlassMethod` yet.") topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, @@ -722,7 +738,16 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for " + "`CompressedTensorsW8A8Int8MoEMethod` yet.") + from vllm.model_executor.layers.fused_moe import fused_experts topk_weights, topk_ids = FusedMoE.select_experts( @@ -1012,7 +1037,16 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for " + "`CompressedTensorsWNA16MarlinMoEMethod` yet.") + assert activation == "silu", ( f"{activation} not supported for Marlin MoE.") assert not apply_router_weight_on_input, ( @@ -1228,7 +1262,15 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError("EPLB not supported for " + "`CompressedTensorsWNA16MoEMethod` yet.") + from vllm.model_executor.layers.fused_moe import fused_experts topk_weights, topk_ids = FusedMoE.select_experts( diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py index 01b0064f08058..47eca80609e0e 100644 --- a/vllm/model_executor/layers/quantization/experts_int8.py +++ b/vllm/model_executor/layers/quantization/experts_int8.py @@ -117,7 +117,15 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for `ExpertsInt8MoEMethod` yet.") + from vllm.model_executor.layers.fused_moe import fused_experts topk_weights, topk_ids = FusedMoE.select_experts( diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index b3042bfaed3d7..d2eda541f7a40 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -825,7 +825,16 @@ class Fp8MoEMethod(FusedMoEMethodBase): e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if enable_eplb: + assert expert_load_view is not None + assert logical_to_physical_map is not None + assert logical_replica_count is not None + assert isinstance(layer, FusedMoE) topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, @@ -839,6 +848,11 @@ class Fp8MoEMethod(FusedMoEMethodBase): scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias, indices_type=self.topk_indices_dtype, + enable_eplb=enable_eplb, + expert_map=expert_map, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, ) if self.rocm_aiter_moe_enabled: diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index 9c8f74545d37d..86da04c39989b 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -520,7 +520,15 @@ class GGUFMoEMethod(FusedMoEMethodBase): e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ): + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for `GGUFMoEMethod` yet.") + assert activation == "silu", "Only SiLU activation is supported." if apply_router_weight_on_input: raise NotImplementedError( diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index e9b8dc3266b4a..48ab04c9ab37f 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -635,7 +635,15 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for `GPTQMarlinMoEMethod` yet.") + assert activation == "silu", "Only SiLU activation is supported." if apply_router_weight_on_input: raise NotImplementedError( diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 3f79b203aa170..e35db5b31dba7 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -664,7 +664,15 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ): + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for `ModelOptNvFp4FusedMoE` yet.") + if self.use_marlin: topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index 3aa23f0682576..c5055a02fa3d5 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -297,7 +297,15 @@ class MoeWNA16Method(FusedMoEMethodBase): e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for `MoeWNA16Method` yet.") + from vllm.model_executor.layers.fused_moe import fused_experts assert activation == "silu", "Only SiLU activation is supported." topk_weights, topk_ids = FusedMoE.select_experts( diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index 4c2da4c8b04ee..a040c430cbcaa 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -205,7 +205,15 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for `QuarkW8A8Fp8MoEMethod` yet.") + from vllm.model_executor.layers.fused_moe import fused_experts topk_weights, topk_ids = FusedMoE.select_experts( diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 0f996d04e6e80..f712b626c74c3 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -23,7 +23,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only DeepseekV2/DeepseekV3 model.""" -from collections.abc import Iterable +import typing +from collections.abc import Callable, Iterable from typing import Any, Optional, Union import torch @@ -32,8 +33,10 @@ from transformers import PretrainedConfig from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, ModelConfig, VllmConfig -from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.config import (CacheConfig, ModelConfig, VllmConfig, + get_current_vllm_config) +from vllm.distributed import (get_ep_group, get_pp_group, + get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm @@ -51,7 +54,7 @@ from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsPP +from .interfaces import MixtureOfExperts, SupportsPP from .utils import (PPMissingLayer, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -99,11 +102,17 @@ class DeepseekV2MoE(nn.Module): config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + enable_eplb: bool = False, ): super().__init__() self.tp_size = get_tensor_model_parallel_world_size() self.routed_scaling_factor = config.routed_scaling_factor - self.n_shared_experts = config.n_shared_experts + + self.ep_group = get_ep_group().device_group + self.ep_rank = self.ep_group.rank() + self.ep_size = self.ep_group.size() + self.n_routed_experts: int = config.n_routed_experts + self.n_shared_experts: int = config.n_shared_experts if config.hidden_act != "silu": raise ValueError(f"Unsupported activation: {config.hidden_act}. " @@ -120,6 +129,22 @@ class DeepseekV2MoE(nn.Module): else: self.gate.e_score_correction_bias = None + # Load balancing settings. + vllm_config = get_current_vllm_config() + parallel_config = vllm_config.parallel_config + self.enable_eplb = enable_eplb + + self.n_redundant_experts = parallel_config.num_redundant_experts + self.n_logical_experts = self.n_routed_experts + self.n_physical_experts = (self.n_logical_experts + + self.n_redundant_experts) + self.n_local_physical_experts = self.n_physical_experts // self.ep_size + + self.physical_expert_start = (self.ep_rank * + self.n_local_physical_experts) + self.physical_expert_end = (self.physical_expert_start + + self.n_local_physical_experts) + self.experts = FusedMoE( num_experts=config.n_routed_experts, top_k=config.num_experts_per_tok, @@ -133,7 +158,9 @@ class DeepseekV2MoE(nn.Module): topk_group=config.topk_group, prefix=f"{prefix}.experts", scoring_func=config.scoring_func, - e_score_correction_bias=self.gate.e_score_correction_bias) + e_score_correction_bias=self.gate.e_score_correction_bias, + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts) if config.n_shared_experts is not None: intermediate_size = (config.moe_intermediate_size * @@ -503,6 +530,7 @@ class DeepseekV2DecoderLayer(nn.Module): model_config: ModelConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + enable_eplb: bool = False, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -543,6 +571,7 @@ class DeepseekV2DecoderLayer(nn.Module): config=config, quant_config=quant_config, prefix=f"{prefix}.mlp", + enable_eplb=enable_eplb, ) else: self.mlp = DeepseekV2MLP( @@ -615,6 +644,7 @@ class DeepseekV2Model(nn.Module): model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config + enable_eplb = vllm_config.parallel_config.enable_eplb self.config = config self.vocab_size = config.vocab_size @@ -636,6 +666,7 @@ class DeepseekV2Model(nn.Module): model_config=model_config, cache_config=cache_config, quant_config=quant_config, + enable_eplb=enable_eplb, ), prefix=f"{prefix}.layers") @@ -681,7 +712,7 @@ class DeepseekV2Model(nn.Module): return hidden_states -class DeepseekV2ForCausalLM(nn.Module, SupportsPP): +class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -700,6 +731,44 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP): self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + self.expert_weights = [] + + # Set MoE hyperparameters + self.num_moe_layers = (config.num_hidden_layers - + config.first_k_dense_replace) + self.num_expert_groups = config.n_group + + self.moe_layers: list[FusedMoE] = [] + for layer in self.model.layers: + assert isinstance(layer, DeepseekV2DecoderLayer) + if isinstance(layer.mlp, DeepseekV2MoE): + self.moe_layers.append(layer.mlp.experts) + + # Pick last one layer since the first ones may be dense layers. + example_moe = typing.cast( + DeepseekV2MoE, self.model.layers[config.num_hidden_layers - 1].mlp) + self.num_logical_experts = example_moe.n_logical_experts + self.num_physical_experts = example_moe.n_physical_experts + self.num_local_physical_experts = example_moe.n_local_physical_experts + self.num_routed_experts = example_moe.n_routed_experts + self.num_shared_experts = example_moe.n_shared_experts + self.num_redundant_experts = example_moe.n_redundant_experts + + def set_eplb_state( + self, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + ) -> None: + for layer_idx, layer in enumerate(self.moe_layers): + # Register the expert weights. + self.expert_weights.append(layer.get_expert_weights()) + layer.set_eplb_state( + moe_layer_idx=layer_idx, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -752,7 +821,8 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP): ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", - num_experts=self.config.n_routed_experts) + num_experts=self.config.n_routed_experts, + num_redundant_experts=self.num_redundant_experts) params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() @@ -789,24 +859,44 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP): weight_loader(param, loaded_weight, shard_id) break else: + is_expert_weight = False for mapping in expert_params_mapping: param_name, weight_name, expert_id, shard_id = mapping if weight_name not in name: continue - name = name.replace(weight_name, param_name) - if is_pp_missing_parameter(name, self): + # Anyway, this is an expert weight and should not be + # attempted to load as other weights later + is_expert_weight = True + + # Do not modify `name` since the loop may continue here + # Instead, create a new variable + name_mapped = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name_mapped, self): continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id) - break + param = params_dict[name_mapped] + # We should ask the weight loader to return success or not + # here since otherwise we may skip experts with other + # available replicas. + weight_loader = typing.cast(Callable[..., bool], + param.weight_loader) + success = weight_loader(param, + loaded_weight, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True) + if success: + break else: + if is_expert_weight: + # We've checked that this is an expert weight + # However it's not mapped locally to this rank + # So we simply skip it + continue + # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue @@ -824,6 +914,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP): default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index f759f8f1f2731..3ea424e44b62e 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterable, MutableSequence from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol, Union, overload, runtime_checkable) @@ -426,6 +427,73 @@ def is_hybrid( return isinstance(model, IsHybrid) +@runtime_checkable +class MixtureOfExperts(Protocol): + """ + Check if the model is a mixture of experts (MoE) model. + """ + + expert_weights: MutableSequence[Iterable[Tensor]] + """ + Expert weights saved in this rank. + + The first dimension is the layer, and the second dimension is different + parameters in the layer, e.g. up/down projection weights. + """ + + num_moe_layers: int + """Number of MoE layers in this model.""" + + num_expert_groups: int + """Number of expert groups in this model.""" + + num_logical_experts: int + """Number of logical experts in this model.""" + + num_physical_experts: int + """Number of physical experts in this model.""" + + num_local_physical_experts: int + """Number of local physical experts in this model.""" + + num_routed_experts: int + """Number of routed experts in this model.""" + + num_shared_experts: int + """Number of shared experts in this model.""" + + num_redundant_experts: int + """Number of redundant experts in this model.""" + + def set_eplb_state( + self, + expert_load_view: Tensor, + logical_to_physical_map: Tensor, + logical_replica_count: Tensor, + ) -> None: + """ + Register the EPLB state in the MoE model. + + Since these are views of the actual EPLB state, any changes made by + the EPLB algorithm are automatically reflected in the model's behavior + without requiring additional method calls to set new states. + + You should also collect model's `expert_weights` here instead of in + the weight loader, since after initial weight loading, further + processing like quantization may be applied to the weights. + + Args: + expert_load_view: A view of the expert load metrics tensor. + logical_to_physical_map: Mapping from logical to physical experts. + logical_replica_count: Count of replicas for each logical expert. + """ + ... + + +def is_mixture_of_experts(model: object) -> TypeIs[MixtureOfExperts]: + return isinstance(model, MixtureOfExperts) + + @runtime_checkable class HasNoOps(Protocol): has_noops: ClassVar[Literal[True]] = True diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 40639fdf24338..3c9de57204051 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -21,6 +21,7 @@ from vllm.attention.layer import Attention from vllm.compilation.counter import compilation_counter from vllm.config import (CompilationLevel, VllmConfig, get_layers_from_vllm_config) +from vllm.distributed.eplb.eplb_state import EplbState from vllm.distributed.kv_transfer import (get_kv_transfer_group, has_kv_transfer_group) from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 @@ -33,7 +34,8 @@ from vllm.logger import init_logger from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader -from vllm.model_executor.models.interfaces import has_step_pooler +from vllm.model_executor.models.interfaces import (has_step_pooler, + is_mixture_of_experts) from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.utils import group_mm_inputs_by_modality @@ -150,6 +152,13 @@ class GPUModelRunner(LoRAModelRunnerMixin): # Sampler self.sampler = Sampler() + self.eplb_state: Optional[EplbState] = None + """ + State of the expert parallelism load balancer. + + Will be lazily initialized when the model is loaded. + """ + # Lazy initializations # self.model: nn.Module # Set after load_model # Initialize in initialize_kv_cache @@ -1178,6 +1187,24 @@ class GPUModelRunner(LoRAModelRunnerMixin): for k, v in self.intermediate_tensors.items() }) + def eplb_step(self, + is_dummy: bool = False, + is_profile: bool = False) -> None: + """ + Step for the EPLB (Expert Parallelism Load Balancing) state. + """ + if not self.parallel_config.enable_eplb: + return + + assert self.eplb_state is not None + assert is_mixture_of_experts(self.model) + self.eplb_state.step( + self.model, + is_dummy, + is_profile, + log_stats=self.parallel_config.eplb_log_balancedness, + ) + def get_dp_padding(self, num_tokens: int) -> tuple[int, Optional[torch.Tensor]]: dp_size = self.vllm_config.parallel_config.data_parallel_size @@ -1595,6 +1622,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): if has_kv_transfer_group(): get_kv_transfer_group().clear_connector_metadata() + self.eplb_step() + return ModelRunnerOutput( req_ids=self.input_batch.req_ids, req_id_to_index=self.input_batch.req_id_to_index, @@ -1729,6 +1758,16 @@ class GPUModelRunner(LoRAModelRunnerMixin): time_after_load - time_before_load) prepare_communication_buffer_for_model(self.model) + if is_mixture_of_experts( + self.model) and self.parallel_config.enable_eplb: + logger.info("EPLB is enabled for model %s.", + self.model_config.model) + self.eplb_state = EplbState.build( + self.model, + self.device, + self.parallel_config, + ) + def save_tensorized_model( self, tensorizer_config: "TensorizerConfig", @@ -1887,6 +1926,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): self, num_tokens: int, capture_attn_cudagraph: bool = False, + skip_eplb: bool = False, + is_profile: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: # Padding for DP @@ -1983,6 +2024,16 @@ class GPUModelRunner(LoRAModelRunnerMixin): assert isinstance(self.drafter, EagleProposer) self.drafter.dummy_run(num_tokens) + # This is necessary to avoid blocking DP. + # For dummy runs, we typically skip EPLB since we don't have any real + # requests to process. + # However, in DP settings, there may be cases when some DP ranks do + # not have any requests to process, so they're executing dummy batches. + # In such cases, we still have to trigger EPLB to make sure + # ranks execute the rearrangement in synchronization. + if not skip_eplb: + self.eplb_step(is_dummy=True, is_profile=is_profile) + logit_indices = np.cumsum(num_scheduled_tokens) - 1 return hidden_states, hidden_states[logit_indices] @@ -2175,8 +2226,9 @@ class GPUModelRunner(LoRAModelRunnerMixin): # Cache the dummy encoder outputs. self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) + # Add `is_profile` here to pre-allocate communication buffers hidden_states, last_hidden_states \ - = self._dummy_run(self.max_num_tokens) + = self._dummy_run(self.max_num_tokens, is_profile=True) if get_pp_group().is_last_rank: if self.is_pooling_model: output = self._dummy_pooler_run(hidden_states) @@ -2210,10 +2262,15 @@ class GPUModelRunner(LoRAModelRunnerMixin): for num_tokens in tqdm(reversed(self.cudagraph_batch_sizes), desc="Capturing CUDA graphs", total=len(self.cudagraph_batch_sizes)): + # We skip EPLB here since we don't want to record dummy metrics for _ in range( self.compilation_config.cudagraph_num_of_warmups): - self._dummy_run(num_tokens, capture_attn_cudagraph=full_cg) - self._dummy_run(num_tokens, capture_attn_cudagraph=full_cg) + self._dummy_run(num_tokens, + capture_attn_cudagraph=full_cg, + skip_eplb=True) + self._dummy_run(num_tokens, + capture_attn_cudagraph=full_cg, + skip_eplb=True) end_time = time.perf_counter() end_free_gpu_memory = torch.cuda.mem_get_info()[0] diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index b0f80c701325f..9e7e44d068612 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -259,9 +259,10 @@ class Worker(WorkerBase): x for x in warmup_sizes if x not in self.vllm_config.compilation_config.cudagraph_capture_sizes ] + # We skip EPLB here since we don't want to record dummy metrics for size in sorted(warmup_sizes, reverse=True): logger.info("Compile and warming up model for size %d", size) - self.model_runner._dummy_run(size) + self.model_runner._dummy_run(size, skip_eplb=True) if not self.model_config.enforce_eager: self.model_runner.capture_model() @@ -274,8 +275,12 @@ class Worker(WorkerBase): max_num_reqs = min(self.scheduler_config.max_num_seqs, self.scheduler_config.max_num_batched_tokens) + # We skip EPLB here since we don't want to record dummy metrics hidden_states, last_hidden_states = \ - self.model_runner._dummy_run(num_tokens=max_num_reqs) + self.model_runner._dummy_run( + num_tokens=max_num_reqs, + skip_eplb=True, + ) if self.model_runner.is_pooling_model: self.model_runner._dummy_pooler_run(hidden_states) else: From 71799fd005ca08c9c362e548945a3dde93790fec Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Fri, 27 Jun 2025 12:21:04 +0900 Subject: [PATCH 028/175] [CI Failure] Fix OOM with test_oot_registration_embedding (#20144) Signed-off-by: mgoin --- tests/models/test_oot_registration.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/models/test_oot_registration.py b/tests/models/test_oot_registration.py index ef0ad613d5252..59de35644c12d 100644 --- a/tests/models/test_oot_registration.py +++ b/tests/models/test_oot_registration.py @@ -53,7 +53,9 @@ def test_oot_registration_embedding( with monkeypatch.context() as m: m.setenv("VLLM_PLUGINS", "register_dummy_model") prompts = ["Hello, my name is", "The text does not matter"] - llm = LLM(model=dummy_gemma2_embedding_path, load_format="dummy") + llm = LLM(model=dummy_gemma2_embedding_path, + load_format="dummy", + max_model_len=2048) outputs = llm.embed(prompts) for output in outputs: From a57d57fa72f092b9b8ed8415553ec02609daa644 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Thu, 26 Jun 2025 23:50:06 -0400 Subject: [PATCH 029/175] [Quantization] Bump to use latest `compressed-tensors` (#20033) Signed-off-by: Dipika Co-authored-by: Kyle Sayers --- requirements/common.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/common.txt b/requirements/common.txt index 9a9ae1d93896b..6cc304e5b1f6d 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -37,7 +37,7 @@ pyyaml six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12 setuptools>=77.0.3,<80; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12 einops # Required for Qwen2-VL. -compressed-tensors == 0.10.1 # required for compressed-tensors +compressed-tensors == 0.10.2 # required for compressed-tensors depyf==0.18.0 # required for profiling and debugging with compilation config cloudpickle # allows pickling lambda functions in model_executor/models/registry.py watchfiles # required for http server to monitor the updates of TLS files From 2d7779f888f6443c067e0c36bab808ef6b368221 Mon Sep 17 00:00:00 2001 From: Ilya Markov Date: Fri, 27 Jun 2025 05:50:09 +0200 Subject: [PATCH 030/175] [Perf] SM100 FP8 GEMM Optimizations after cutlass_profiler (#20071) Signed-off-by: ilmarkov Co-authored-by: ilmarkov --- .../c3x/scaled_mm_sm100_fp8_dispatch.cuh | 58 +++++++++---------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh index 1549ed96aa2be..24564efbd21be 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh @@ -29,26 +29,12 @@ struct sm100_fp8_config_default { template typename Epilogue> struct sm100_fp8_config_M256 { - // M in (128, 256] + // M in (64, 256] static_assert(std::is_same()); using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; using TileShape = Shape<_128, _128, _128>; - using ClusterShape = Shape<_2, _2, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm_sm100; -}; - -template typename Epilogue> -struct sm100_fp8_config_M128 { - // M in (64, 128] - static_assert(std::is_same()); - using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; - using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; - using TileShape = Shape<_128, _128, _256>; - using ClusterShape = Shape<_2, _4, _1>; + using ClusterShape = Shape<_2, _1, _1>; using Cutlass3xGemm = cutlass_3x_gemm_sm100; @@ -57,12 +43,26 @@ struct sm100_fp8_config_M128 { template typename Epilogue> struct sm100_fp8_config_M64 { - // M in [1, 64] + // M in (16, 64] static_assert(std::is_same()); using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; - using TileShape = Shape<_64, _64, _256>; - using ClusterShape = Shape<_1, _8, _1>; + using TileShape = Shape<_64, _64, _128>; + using ClusterShape = Shape<_1, _1, _1>; + using Cutlass3xGemm = + cutlass_3x_gemm_sm100; +}; + +template typename Epilogue> +struct sm100_fp8_config_M16 { + // M in [1, 16] + static_assert(std::is_same()); + using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; + using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; + using TileShape = Shape<_64, _64, _128>; + using ClusterShape = Shape<_1, _4, _1>; using Cutlass3xGemm = cutlass_3x_gemm_sm100; @@ -82,27 +82,27 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out, using Cutlass3xGemmDefault = typename sm100_fp8_config_default::Cutlass3xGemm; + using Cutlass3xGemmM16 = + typename sm100_fp8_config_M16::Cutlass3xGemm; using Cutlass3xGemmM64 = typename sm100_fp8_config_M64::Cutlass3xGemm; - using Cutlass3xGemmM128 = - typename sm100_fp8_config_M128::Cutlass3xGemm; using Cutlass3xGemmM256 = typename sm100_fp8_config_M256::Cutlass3xGemm; uint32_t const m = a.size(0); uint32_t const mp2 = - std::max(static_cast(64), next_pow_2(m)); // next power of 2 + std::max(static_cast(16), next_pow_2(m)); // next power of 2 - if (mp2 <= 64) { - // m in [1, 64] + if (mp2 <= 16) { + // m in [1, 16] + return cutlass_gemm_caller( + out, a, b, std::forward(args)...); + } else if (mp2 <= 64) { + // m in (16, 64] return cutlass_gemm_caller( out, a, b, std::forward(args)...); - } else if (mp2 <= 128) { - // m in (64, 128] - return cutlass_gemm_caller( - out, a, b, std::forward(args)...); } else if (mp2 <= 256) { - // m in (128, 256] + // m in (64, 256] return cutlass_gemm_caller( out, a, b, std::forward(args)...); } else { From 44d2e6af636b7a62dbec1bd985543cbe2918049b Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Fri, 27 Jun 2025 12:50:12 +0900 Subject: [PATCH 031/175] [Bugfix] Build moe_data for both sm100 and sm90 (#20086) Signed-off-by: mgoin --- CMakeLists.txt | 14 ++++++++++++-- csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu | 9 +++++---- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 402131b7a1e7a..8966a663d3ccf 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -513,6 +513,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") CUDA_ARCHS "${FP4_ARCHS}") list(APPEND VLLM_EXT_SRC "${SRCS}") list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4=1") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1") message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}") else() message(STATUS "Not building NVFP4 as no compatible archs were found.") @@ -547,8 +548,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # if it's possible to compile MoE kernels that use its output. cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS) - set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu" - "csrc/quantization/cutlass_w8a8/moe/moe_data.cu") + set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${SCALED_MM_ARCHS}") @@ -566,6 +566,16 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() endif() + # moe_data.cu is used by all CUTLASS MoE kernels. + cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS) + set(SRCS "csrc/quantization/cutlass_w8a8/moe/moe_data.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${CUTLASS_MOE_DATA_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + endif() + # # Machete kernels diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 348525810810c..a2080c3001190 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -241,7 +241,7 @@ void get_cutlass_moe_mm_data( // mm to run it for. int32_t version_num = get_sm_version_num(); #if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \ - (defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM90) + (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1, problem_sizes2, input_permutation, output_permutation, num_experts, n, k, @@ -252,7 +252,7 @@ void get_cutlass_moe_mm_data( false, "No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for " "CUDA device capability: ", - version_num, ". Required capability: 90"); + version_num, ". Required capability: 90 or 100"); } void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets, @@ -265,7 +265,8 @@ void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets, // This function currently gets compiled only if we have a valid cutlass moe // mm to run it for. int32_t version_num = get_sm_version_num(); -#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90 +#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \ + (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) get_cutlass_pplx_moe_mm_data_caller(expert_offsets, problem_sizes1, problem_sizes2, expert_num_tokens, num_local_experts, padded_m, n, k); @@ -275,7 +276,7 @@ void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets, false, "No compiled get_cutlass_pplx_moe_mm_data: no cutlass_scaled_mm kernel " "for CUDA device capability: ", - version_num, ". Required capability: 90"); + version_num, ". Required capability: 90 or 100"); } void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a, From 0740e29b66ca5589f7f35a7c25b6c3de1a749da1 Mon Sep 17 00:00:00 2001 From: li haoyang Date: Fri, 27 Jun 2025 11:54:24 +0800 Subject: [PATCH 032/175] [Feature] add quick all reduce (#19744) Signed-off-by: ilmarkov Signed-off-by: Haoyang Li Co-authored-by: ilmarkov --- CMakeLists.txt | 8 + csrc/custom_quickreduce.cu | 114 +++ csrc/ops.h | 11 + csrc/quickreduce/base.h | 338 +++++++++ csrc/quickreduce/quick_reduce.h | 196 +++++ csrc/quickreduce/quick_reduce_impl.cuh | 698 ++++++++++++++++++ csrc/torch_bindings.cpp | 18 + tests/distributed/test_quick_all_reduce.py | 138 ++++ vllm/_custom_ops.py | 32 + .../device_communicators/cuda_communicator.py | 22 +- .../device_communicators/quick_all_reduce.py | 278 +++++++ vllm/envs.py | 28 + 12 files changed, 1879 insertions(+), 2 deletions(-) create mode 100644 csrc/custom_quickreduce.cu create mode 100644 csrc/quickreduce/base.h create mode 100644 csrc/quickreduce/quick_reduce.h create mode 100644 csrc/quickreduce/quick_reduce_impl.cuh create mode 100644 tests/distributed/test_quick_all_reduce.py create mode 100644 vllm/distributed/device_communicators/quick_all_reduce.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 8966a663d3ccf..b1adeac586f2e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -648,6 +648,14 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # if CUDA endif endif() +if (VLLM_GPU_LANG STREQUAL "HIP") + # Add QuickReduce kernels + list(APPEND VLLM_EXT_SRC + "csrc/custom_quickreduce.cu" + ) +# if ROCM endif +endif() + message(STATUS "Enabling C extension.") define_gpu_extension_target( _C diff --git a/csrc/custom_quickreduce.cu b/csrc/custom_quickreduce.cu new file mode 100644 index 0000000000000..33d0d4a7226e6 --- /dev/null +++ b/csrc/custom_quickreduce.cu @@ -0,0 +1,114 @@ +#include +#include +#include +#include + +#ifdef USE_ROCM + + #include "quickreduce/quick_reduce.h" + +quickreduce::fptr_t init_custom_qr(int64_t rank, int64_t world_size, + std::optional qr_max_size) { + if (world_size > 8) + throw std::invalid_argument("world size > 8 is not supported"); + if (world_size == 6) + throw std::invalid_argument("world size == 6 is not supported"); + if (world_size % 2 != 0) + throw std::invalid_argument("Odd num gpus is not supported for now"); + if (rank < 0 || rank >= world_size) + throw std::invalid_argument("invalid rank passed in"); + quickreduce::DeviceComms* fptr = new quickreduce::DeviceComms(); + fptr->init(world_size, rank, qr_max_size); + return (quickreduce::fptr_t)fptr; +} + +void qr_destroy(quickreduce::fptr_t _fa) { + if (_fa) { + auto fa = reinterpret_cast(_fa); + fa->destroy(); + delete fa; + } +} + +torch::Tensor qr_get_handle(quickreduce::fptr_t _fa) { + auto fa = reinterpret_cast(_fa); + hipIpcMemHandle_t handle = fa->get_handle(); + auto options = + torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU); + auto data_handle = + torch::empty({static_cast(sizeof(hipIpcMemHandle_t))}, options); + std::memcpy(data_handle.data_ptr(), &handle, sizeof(hipIpcMemHandle_t)); + return data_handle; +} + +void qr_open_handles(quickreduce::fptr_t _fa, + const std::vector& handles) { + auto fa = reinterpret_cast(_fa); + std::vector ipc_handles; + ipc_handles.reserve(handles.size()); + for (auto& handle : handles) { + // Ensure the tensor is on the same device as the current device. + hipIpcMemHandle_t ipc_handle; + std::memcpy(&ipc_handle, handle.data_ptr(), sizeof(hipIpcMemHandle_t)); + ipc_handles.push_back(ipc_handle); + } + fa->open_ipc_handles(ipc_handles); +} + +void qr_all_reduce(quickreduce::fptr_t _fa, torch::Tensor& inp, + torch::Tensor& out, int64_t quant_level, bool cast_bf2half) { + auto fa = reinterpret_cast(_fa); + const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); + auto stream = at::cuda::getCurrentHIPStreamMasqueradingAsCUDA(); + + TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); + TORCH_CHECK_EQ(inp.numel(), out.numel()); + TORCH_CHECK_LE(out.numel(), fa->kMaxProblemSize); + if (out.scalar_type() == at::ScalarType::Half) { + fa->allreduce(reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), + out.numel(), quant_level, stream); + } else if (out.scalar_type() == at::ScalarType::BFloat16) { + if (cast_bf2half) { + fa->allreduce(reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), + out.numel(), quant_level, stream); + } else { + fa->allreduce( + reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), + out.numel(), quant_level, stream); + } + } else { + throw std::runtime_error( + "quick allreduce only supports float16 and bfloat16"); + } +} + +int64_t qr_max_size() { + // The default is 2GB (2,147,483,648 bytes) + return static_cast(std::numeric_limits::max()) + 1; +} + + #define INSTANTIATE_FOR_WORLDSIZE(T, Codec, cast_bf2half) \ + template struct quickreduce::AllReduceTwoshot, \ + cast_bf2half>; \ + template struct quickreduce::AllReduceTwoshot, \ + cast_bf2half>; \ + template struct quickreduce::AllReduceTwoshot, cast_bf2half>; + +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecFP, false) +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ4, false) +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ6, false) +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ8, false) +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecFP, true) +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ4, true) +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ6, true) +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ8, true) + +INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecFP, false) +INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ4, false) +INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ6, false) +INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ8, false) + +#endif // USE_ROCM \ No newline at end of file diff --git a/csrc/ops.h b/csrc/ops.h index f02f5083ac197..52c264d64ccad 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -360,3 +360,14 @@ std::tuple allocate_shared_buffer_and_handle( int64_t size); int64_t open_mem_handle(torch::Tensor& mem_handle); void free_shared_buffer(int64_t buffer); + +#ifdef USE_ROCM +fptr_t init_custom_qr(int64_t rank, int64_t world_size, + std::optional qr_max_size = std::nullopt); +void qr_destroy(fptr_t _fa); +torch::Tensor qr_get_handle(fptr_t _fa); +void qr_open_handles(fptr_t _fa, const std::vector& handles); +void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, + int64_t quant_level, bool cast_bf2half = false); +int64_t qr_max_size(); +#endif \ No newline at end of file diff --git a/csrc/quickreduce/base.h b/csrc/quickreduce/base.h new file mode 100644 index 0000000000000..a2170e483207d --- /dev/null +++ b/csrc/quickreduce/base.h @@ -0,0 +1,338 @@ +#pragma once + +#include +#include +#include +#include + +#define __quickreduce_device_inline__ __device__ __forceinline__ +#define __quickreduce_launch_bounds_two_shot__ __launch_bounds__(256, 4) +#define __quickreduce_launch_bounds_one_shot__ __launch_bounds__(512, 4) + +namespace quickreduce { + +typedef __hip_bfloat16 nv_bfloat16; +typedef __hip_bfloat162 nv_bfloat162; + +using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int; +using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; + +// Setup acquire-release semantics for vector memory reads (mubuf instruction) +// as per architecture. +#if defined(__gfx942__) +// CDNA3: Scope bits sc0, sc1 + #define MUBUF_ACQUIRE 16 + #define MUBUF_RELEASE 16 +#elif (defined(__gfx908__) || defined(__gfx90a__)) +// CDNA1 and CDNA2 - glc bit + #define MUBUF_ACQUIRE 1 + #define MUBUF_RELEASE 0 +#endif + +static constexpr int kNegOne = 0xBC00BC00; // {-1, -1}, fp16x2_t + +// Number of atoms (4xf16x2_t) processed by a single thread +static constexpr int kAtoms = 8; + +// We use a workgroup of 256 threads +static constexpr int kBlockSize = 256; +static constexpr int kAtomStride = kBlockSize; + +// Size and atom stride of source/destination data that the block will +// process. +// Workgroup scope = Tile = (256 threads x 8 atoms x 16B) +static constexpr int kTileSize = kBlockSize * kAtoms * sizeof(int32x4_t); + +// Max number of blocks. 304 CUs on MI300 +static constexpr int kMaxNumBlocks = 304 * 4; + +// Standard CDNA wavefront size. +static constexpr int kWavefront = 64; + +// 256 thread, 4 wavefronts. +static dim3 constexpr kBlockTwoShot = {kWavefront, kBlockSize / kWavefront, 1}; + +// Number of threads in a group for quantization +// It corresponds to 32 F16 elements in quantization block +static constexpr int kThreadGroupSize = 8; + +// Methods +__quickreduce_device_inline__ __host__ unsigned long divceil(unsigned long x, + unsigned long y) { + return ((x + y - 1) / y); +} + +union BufferResource { + __quickreduce_device_inline__ constexpr BufferResource() + : config(0x00020000U) {} + + __quickreduce_device_inline__ constexpr BufferResource(void* buffer_address, + uint32_t buffer_size) + : address(buffer_address), range(buffer_size), config(0x00020000U) {} + + int32x4_t descriptor; + struct { + void* address; // 8B, out of which first 48b is address, and 16b is stride + // (unused) + uint32_t range; // Byte range for the buffer resource + uint32_t config; // Constant, DFMT=32b + }; +}; + +__quickreduce_device_inline__ static int32x4_t buffer_load_dwordx4( + int32x4_t srsrc, int32_t voffset, int32_t soffset, + int32_t aux) __asm("llvm.amdgcn.raw.buffer.load.v4i32"); + +__quickreduce_device_inline__ static void buffer_store_dwordx4( + int32x4_t data, int32x4_t srsrc, int32_t voffset, int32_t soffset, + int32_t aux) __asm("llvm.amdgcn.raw.buffer.store.v4i32"); + +__quickreduce_device_inline__ static void set_fp16_ovfl(bool const value) { +#if defined(__gfx942__) + if (value) { + asm volatile("s_setreg_imm32_b32 0xdc1, 1;" ::); + } else { + asm volatile("s_setreg_imm32_b32 0xdc1, 0;" ::); + } +#endif +} +union bf162_int_union { + int i; + nv_bfloat162 bf2; +}; + +template +__quickreduce_device_inline__ void packed_assign_add(int32x4_t* A, + int32x4_t* B); + +template <> +__quickreduce_device_inline__ void packed_assign_add(int32x4_t* A, + int32x4_t* B) { + int32x4_t& tR_fragment = A[0]; + int32x4_t& tA_fragment = B[0]; + + asm volatile("v_pk_add_f16 %0, %1, %2" + : "=v"(tR_fragment[0]) + : "v"(tR_fragment[0]), "v"(tA_fragment[0])); + asm volatile("v_pk_add_f16 %0, %1, %2" + : "=v"(tR_fragment[1]) + : "v"(tR_fragment[1]), "v"(tA_fragment[1])); + asm volatile("v_pk_add_f16 %0, %1, %2" + : "=v"(tR_fragment[2]) + : "v"(tR_fragment[2]), "v"(tA_fragment[2])); + asm volatile("v_pk_add_f16 %0, %1, %2" + : "=v"(tR_fragment[3]) + : "v"(tR_fragment[3]), "v"(tA_fragment[3])); +} + +template <> +__quickreduce_device_inline__ void packed_assign_add( + int32x4_t* A, int32x4_t* B) { + nv_bfloat162* tA = reinterpret_cast(A); + nv_bfloat162* tB = reinterpret_cast(B); +#pragma unroll + for (int i = 0; i < 4; i++) { + tA[i] = __hadd2(tA[i], tB[i]); + } +} + +template +__quickreduce_device_inline__ int packed_max(int a, int b); + +template <> +__quickreduce_device_inline__ int packed_max(int a, int b) { + int result; + asm volatile("v_pk_max_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b)); + return result; +} + +template <> +__quickreduce_device_inline__ int packed_max(int a, int b) { + bf162_int_union A, B, R; + A.i = a; + B.i = b; + R.bf2 = __hmax2(A.bf2, B.bf2); + return R.i; +} + +template +__quickreduce_device_inline__ int packed_min(int a, int b); + +template <> +__quickreduce_device_inline__ int packed_min(int a, int b) { + int result; + asm volatile("v_pk_min_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b)); + return result; +} + +template <> +__quickreduce_device_inline__ int packed_min(int a, int b) { + bf162_int_union A, B, R; + A.i = a; + B.i = b; + R.bf2 = __hmin2(A.bf2, B.bf2); + return R.i; +} + +template +__quickreduce_device_inline__ int packed_abs_max(int a, int b); + +template <> +__quickreduce_device_inline__ int packed_abs_max(int a, int b) { + half2 wmaxh2 = __builtin_bit_cast(half2, a); + half2 wminh2 = __builtin_bit_cast(half2, b); + half2 wblockmaxh2; + + wblockmaxh2.x = + __hgt(__habs(wmaxh2.x), __habs(wminh2.x)) ? wmaxh2.x : wminh2.x; + wblockmaxh2.y = + __hgt(__habs(wmaxh2.y), __habs(wminh2.y)) ? wmaxh2.y : wminh2.y; + return __builtin_bit_cast(int, wblockmaxh2); +} + +template <> +__quickreduce_device_inline__ int packed_abs_max(int a, int b) { + bf162_int_union A, B, R; + A.i = a; + B.i = b; + R.bf2.x = __hgt(__habs(A.bf2.x), __habs(B.bf2.x)) ? A.bf2.x : B.bf2.x; + R.bf2.y = __hgt(__habs(A.bf2.y), __habs(B.bf2.y)) ? A.bf2.y : B.bf2.y; + return R.i; +} + +template +__quickreduce_device_inline__ int packed_add(int a, int b); + +template <> +__quickreduce_device_inline__ int packed_add(int a, int b) { + int result; + asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b)); + return result; +} + +template <> +__quickreduce_device_inline__ int packed_add(int a, int b) { + bf162_int_union A, B, R; + A.i = a; + B.i = b; + R.bf2 = __hadd2(A.bf2, B.bf2); + return R.i; +} + +template <> +__quickreduce_device_inline__ int packed_add(int a, int b) { + int result; + asm volatile("v_pk_add_i16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b)); + return result; +} + +template +__quickreduce_device_inline__ int packed_sub(int a, int b); + +template <> +__quickreduce_device_inline__ int packed_sub(int a, int b) { + int result; + + // MI300 lacks packed fp16 sub instruction. So we do -1 * min + max + asm volatile("v_pk_fma_f16 %0, %1, %2 %3" + : "=v"(result) + : "v"(kNegOne), "v"(b), "v"(a)); + return result; +} + +template <> +__quickreduce_device_inline__ int packed_sub(int a, int b) { + bf162_int_union A, B, R; + A.i = a; + B.i = b; + R.bf2 = __hsub2(A.bf2, B.bf2); + return R.i; +} + +template +__quickreduce_device_inline__ int packed_mul(int a, int b); + +template <> +__quickreduce_device_inline__ int packed_mul(int a, int b) { + int result; + asm volatile("v_pk_mul_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b)); + return result; +} + +template <> +__quickreduce_device_inline__ int packed_mul(int a, int b) { + nv_bfloat162* tA = reinterpret_cast(&a); + nv_bfloat162* tB = reinterpret_cast(&b); + nv_bfloat162 tR = __hmul2(*tA, *tB); + return *(reinterpret_cast(&tR)); +} + +template +__quickreduce_device_inline__ int packed_rcp(int a); + +template <> +__quickreduce_device_inline__ int packed_rcp(int a) { + return __builtin_bit_cast(int, h2rcp(__builtin_bit_cast(half2, a))); +} + +template <> +__quickreduce_device_inline__ int packed_rcp(int a) { + bf162_int_union A, R; + A.i = a; + R.bf2 = h2rcp(A.bf2); + return R.i; +} + +// changes dtype +__quickreduce_device_inline__ float T2float_cast(half a) { + return __half2float(a); +} + +__quickreduce_device_inline__ float T2float_cast(nv_bfloat16 a) { + return __bfloat162float(a); +} + +template +__quickreduce_device_inline__ int group_abs_max(int32x4_t atom) { + const int group_leader = (threadIdx.x / kThreadGroupSize) * kThreadGroupSize; + + int wmax, wmin, wblockmax; + int a, b; + a = packed_max(atom[0], atom[1]); + b = packed_max(atom[2], atom[3]); + + wmax = packed_max(a, b); + + a = packed_min(atom[0], atom[1]); + b = packed_min(atom[2], atom[3]); + + wmin = packed_min(a, b); + + // Reduce the max among a group of threads + // Note: This is basically 2 blocks of values setup as the + // upper/lower halves of the f16x2_t + for (int i = 1; i < kThreadGroupSize; i <<= 1) { + int x = __shfl_down(wmax, i); + wmax = packed_max(wmax, x); + + int y = __shfl_down(wmin, i); + wmin = packed_min(wmin, y); + } + wblockmax = packed_abs_max(wmax, wmin); + // Share with the cohort + wblockmax = __shfl(wblockmax, group_leader); + return wblockmax; +} + +__quickreduce_device_inline__ void set_sync_flag(uint32_t* flag_ptr, + uint32_t flag) { + __atomic_store_n(flag_ptr, flag, __ATOMIC_RELEASE); +} + +__quickreduce_device_inline__ void wait_sync_flag(uint32_t* flag_ptr, + uint32_t flag) { + while (__atomic_load_n(flag_ptr, __ATOMIC_RELAXED) != flag) { + } +} + +} // namespace quickreduce \ No newline at end of file diff --git a/csrc/quickreduce/quick_reduce.h b/csrc/quickreduce/quick_reduce.h new file mode 100644 index 0000000000000..4fe4c44be7eb9 --- /dev/null +++ b/csrc/quickreduce/quick_reduce.h @@ -0,0 +1,196 @@ +#pragma once + +#include +#include +#include "quick_reduce_impl.cuh" + +#define HIP_CHECK(err) \ + do { \ + hipError_t err_ = (err); \ + if (err_ != hipSuccess) { \ + std::printf("HIP error %d at %s:%d. %s\n", err_, __FILE__, __LINE__, \ + hipGetErrorString(err_)); \ + throw std::runtime_error("HIP error"); \ + } \ + } while (0) + +namespace quickreduce { +using fptr_t = int64_t; +static_assert(sizeof(void*) == sizeof(fptr_t)); + +template +__global__ __quickreduce_launch_bounds_two_shot__ static void +allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, uint32_t num_blocks, + int rank, uint8_t** dbuffer_list, + uint32_t data_offset, uint32_t flag_color) { + int block = blockIdx.x; + int grid = gridDim.x; + + while (block < num_blocks) { + AllReduceKernel::run(A, B, N, block, rank, dbuffer_list, data_offset, + flag_color); + block += grid; + flag_color++; + } +} + +#define TWOSHOT_DISPATCH(__codec) \ + if (world_size == 2) { \ + using LineCodec = __codec; \ + using AllReduceKernel = AllReduceTwoshot; \ + hipLaunchKernelGGL((allreduce_prototype_twoshot), \ + dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \ + num_blocks, rank, dbuffer_list, data_offset, \ + flag_color); \ + } else if (world_size == 4) { \ + using LineCodec = __codec; \ + using AllReduceKernel = AllReduceTwoshot; \ + hipLaunchKernelGGL((allreduce_prototype_twoshot), \ + dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \ + num_blocks, rank, dbuffer_list, data_offset, \ + flag_color); \ + } else if (world_size == 8) { \ + using LineCodec = __codec; \ + using AllReduceKernel = AllReduceTwoshot; \ + hipLaunchKernelGGL((allreduce_prototype_twoshot), \ + dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \ + num_blocks, rank, dbuffer_list, data_offset, \ + flag_color); \ + } + +enum QuickReduceQuantLevel { + F16 = 0, + INT8 = 1, + INT6 = 2, + INT4 = 3, +}; + +struct DeviceComms { + // Max problem size is 2GB (in bytes) or half of uint32_t max value. + int64_t kMaxProblemSize = + static_cast(std::numeric_limits::max()) + 1; + + // Max TP-8 + static int constexpr kMaxWorldSize = 8; + + bool initialized = false; + uint32_t flag_color = 1; + int world_size; + int rank; + + uint8_t* dbuffer; + uint8_t** dbuffer_list; + hipIpcMemHandle_t buffer_ipc_handle; + std::vector all_buffer_ipc_handles; + std::vector buffer_list; + uint32_t data_offset; + + DeviceComms() : initialized(false), world_size(1), rank(0) {} + ~DeviceComms() { destroy(); } + + void init(int world_size, int rank, + std::optional max_problem_size = std::nullopt) { + destroy(); + this->world_size = world_size; + this->rank = rank; + if (max_problem_size.has_value() && max_problem_size.value() > 0) { + this->kMaxProblemSize = max_problem_size.value(); + } + // Allocate buffer size for worst case: F16 2-stage buffer. + uint32_t flags_buffer_size = + 2 * world_size * kMaxNumBlocks * sizeof(uint32_t); + static int64_t data_buffer_size = 2 * this->kMaxProblemSize; + int64_t total_buffer_size = flags_buffer_size + data_buffer_size; + data_offset = flags_buffer_size; + HIP_CHECK(hipExtMallocWithFlags((void**)&dbuffer, total_buffer_size, + hipDeviceMallocUncached)); + + // Clear the flags buffer. + HIP_CHECK(hipMemset(dbuffer, 0, flags_buffer_size)); + + // Device-side list of IPC buffers. + buffer_list.resize(world_size); + HIP_CHECK(hipMalloc(&dbuffer_list, world_size * sizeof(uint8_t*))); + + // Create IPC handles for rank's communication buffer. + all_buffer_ipc_handles.resize(world_size); + HIP_CHECK(hipIpcGetMemHandle(&buffer_ipc_handle, dbuffer)); + + initialized = true; + } + int get_world_size() { return world_size; } + int get_rank() { return rank; } + bool status() { return initialized; } + hipIpcMemHandle_t const get_handle() { return buffer_ipc_handle; } + + void destroy() { + if (initialized) { + for (int i = 0; i < world_size; i++) { + if (i != rank) { + HIP_CHECK(hipIpcCloseMemHandle(dbuffer_list[i])); + } + } + + HIP_CHECK(hipFree(dbuffer)); + HIP_CHECK(hipFree(dbuffer_list)); + + initialized = false; + } + } + + void open_ipc_handles(std::vector const& ipc_handles) { + assert(ipc_handles.size() == all_buffer_ipc_handles.size()); + for (int i = 0; i < world_size; i++) { + all_buffer_ipc_handles[i] = ipc_handles[i]; + } + + // Open device memory access to the IPC communication buffers. + // Note: For our own rank, we do not need to open a handle. + for (int i = 0; i < world_size; i++) { + if (i != rank) { + HIP_CHECK(hipIpcOpenMemHandle((void**)&buffer_list[i], + all_buffer_ipc_handles[i], + hipIpcMemLazyEnablePeerAccess)); + } else { + buffer_list[i] = dbuffer; + } + } + + HIP_CHECK(hipMemcpy(dbuffer_list, buffer_list.data(), + world_size * sizeof(uint8_t*), hipMemcpyHostToDevice)); + } + + template + void allreduce(T const* A, T* B, uint32_t N, int quant_level, + hipStream_t stream) { + if (world_size != 2 && world_size != 4 && world_size != 8) { + throw std::runtime_error("All Reduce not supported for world_size = " + + std::to_string(world_size)); + } + + // Configuration. + uint32_t msg_size = N * sizeof(T); + uint32_t num_blocks = divceil(msg_size, kTileSize); + uint32_t grid = min(kMaxNumBlocks, num_blocks); + auto quant_level_ = static_cast(quant_level); + switch (quant_level_) { + case QuickReduceQuantLevel::INT8: + TWOSHOT_DISPATCH(CodecQ8) + break; + case QuickReduceQuantLevel::INT6: + TWOSHOT_DISPATCH(CodecQ6) + break; + case QuickReduceQuantLevel::INT4: + TWOSHOT_DISPATCH(CodecQ4) + break; + default: + TWOSHOT_DISPATCH(CodecFP) + break; + } + HIP_CHECK(cudaGetLastError()); + // Rotate the flag color. + flag_color += divceil(N, grid); + } +}; + +} // namespace quickreduce \ No newline at end of file diff --git a/csrc/quickreduce/quick_reduce_impl.cuh b/csrc/quickreduce/quick_reduce_impl.cuh new file mode 100644 index 0000000000000..17816c552d25b --- /dev/null +++ b/csrc/quickreduce/quick_reduce_impl.cuh @@ -0,0 +1,698 @@ +#pragma once + +#include +#include "base.h" + +namespace quickreduce { + +struct CodecBase { + const int thread; + const int rank; + const int group_leader; + __quickreduce_device_inline__ CodecBase(int thread, int rank) + : thread(thread), + rank(rank), + group_leader((threadIdx.x / kThreadGroupSize) * kThreadGroupSize) { + set_fp16_ovfl(true); + } +}; + +// Default full precision codec. +template +struct CodecFP : public CodecBase { + static constexpr int kWorldSize = world_size; + static constexpr int kRankAtoms = kAtoms / kWorldSize; + + // Codec tile size process by this workgroup. + // Each thread processes atoms of f16x8_t (16B). + static constexpr int kRankTransmittedTileSize = + kBlockSize * kRankAtoms * sizeof(int32x4_t); + static_assert(kRankTransmittedTileSize % 16 == 0, + "kRankTransmittedTileSize must be 16B aligned."); + + // Total tile size for the collective communication. + static constexpr int kTransmittedTileSize = + kRankTransmittedTileSize * kWorldSize; + + __quickreduce_device_inline__ CodecFP(int thread, int rank) + : CodecBase(thread, rank) {} + + __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, + const int32x4_t* __restrict__ data) { + for (int i = 0; i < kRankAtoms; i++) { + __builtin_nontemporal_store(data[i], send_buffer + thread); + send_buffer += kAtomStride; + } + } + + __quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, + int32x4_t* __restrict__ data) { + for (int i = 0; i < kRankAtoms; i++) { + data[i] = __builtin_nontemporal_load(*recv_buffer + thread); + *recv_buffer += kAtomStride; + } + } +}; + +// Int4 symmetric quantization codec. +// We quantize the FP16 data to block-scaled Int4 in blocks of 4 * +// kThreadGroupSize. +template +struct CodecQ4 : public CodecBase { + static constexpr int kWorldSize = world_size; + + // Codec tile size process by this workgroup. + // Each threads processes a fragment of fp16x8_t (16B), + // into a int4x8_t (4B) and a fp16 scale shared among 32 values. + static constexpr int kRankAtoms = kAtoms / kWorldSize; + static constexpr int kRankTileStride = 1152; + static constexpr int kRankTileScaleOffset = 1024; + static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms; + static_assert(kRankTransmittedTileSize % 16 == 0, + "kRankTransmittedTileSize must be 16B aligned."); + + static constexpr int kRankBufferTileStride = + kRankTileStride / sizeof(int32x4_t); + + // Total tile size for the collective communication. + static constexpr int kTransmittedTileSize = + kRankTransmittedTileSize * kWorldSize; + + // Constants configuration + + // {-1/8.0h, -1/8.0h}, f16x2_t + static constexpr int kScaleFactor = + std::is_same::value ? 0xB000B000 : 0xBE00BE00; + + // {1e-7, 1e-7}, f16x2_t + static constexpr int kScaleEpsilon = + std::is_same::value ? 0x00010001 : 0x33D733D7; + + // {-8, -8}, f16x2_t + static constexpr int kRangeMin = + std::is_same::value ? 0xC800C800 : 0xC100C100; + + // {+7, +7}, f16x2_t + static constexpr int kRangeMax = + std::is_same::value ? 0x47004700 : 0x40E040E0; + + // {+8, +8}, int16x2_t + static constexpr int kRangeBias = 0x00080008; + + __quickreduce_device_inline__ CodecQ4(int thread, int rank) + : CodecBase(thread, rank) {} + + __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, + const int32x4_t* __restrict__ data) { + for (int k = 0; k < kRankAtoms; k++) { + int32x4_t const atom = data[k]; + + // Compute the absolute maximum of the atom in the thread group + // In 2 blocks of values, upper/lower halves of the f16x2_t + int wblockmax = group_abs_max(atom); + + // Derive scales + int decoding_scale; + int encoding_scale; + decoding_scale = packed_mul(wblockmax, kScaleFactor); + encoding_scale = packed_add(decoding_scale, kScaleEpsilon); + encoding_scale = packed_rcp(encoding_scale); + + // Apply scales to get quantized values + int32x4_t w; + for (int i = 0; i < 4; i++) { + w[i] = packed_mul(atom[i], encoding_scale); + w[i] = packed_max(w[i], kRangeMin); + w[i] = packed_min(w[i], kRangeMax); + } + + // Convert from f16x2_t to uint16x2_t + int32x4_t q; + { + int16_t* qi = reinterpret_cast(&q); + T* wh = reinterpret_cast(&w); + for (int i = 0; i < 8; i++) qi[i] = (int16_t)rintf(T2float_cast(wh[i])); + + for (int i = 0; i < 4; i++) { + q[i] = packed_add(q[i], kRangeBias); + } + } + + // Pack 8 x q4 into int32_t + int qw = q[0] | (q[1] << 4) | (q[2] << 8) | (q[3] << 12); + + // Write quantized atom to send_buffer + // note: only the group leader stores the scale + uint8_t* atom_ptr = + reinterpret_cast(send_buffer + k * kRankBufferTileStride); + int32_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + + (thread / 8); + + __builtin_nontemporal_store(qw, qw_ptr); + if (threadIdx.x == group_leader) { + __builtin_nontemporal_store(decoding_scale, qs_ptr); + } + } + } + + __quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, + int32x4_t* __restrict__ data) { + for (int k = 0; k < kRankAtoms; k++) { + // Directly read quantized atom from recv_buffer + uint8_t* atom_ptr = reinterpret_cast(*recv_buffer); + int32_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + + (thread / 8); + + int32_t qw = __builtin_nontemporal_load(qw_ptr); + int qs = __builtin_nontemporal_load(qs_ptr); + + *recv_buffer += kRankBufferTileStride; + + // Unpack q4 into f16x8_t + int32x4_t w; + { + static constexpr uint kMask000F = 0x000F000F; + static constexpr uint kHalf2_1024 = + 0x64006400; // {1024.0, 1024.0}, fp16x2_t + static uint constexpr kHalf2_1032 = + 0xE408E408; // {-1032.0, -1032.0}, fp16x2_t + + for (int i = 0; i < 4; i++) { + if constexpr (std::is_same::value) { + int32_t q4 = ((qw >> (i * 4)) & kMask000F) | kHalf2_1024; + w[i] = packed_add(q4, kHalf2_1032); + } else { + int32_t int16_2 = (qw >> (i * 4)) & kMask000F; + int16_t low = static_cast(int16_2 & 0xFFFF); + int16_t high = static_cast((int16_2 >> 16) & 0xFFFF); + nv_bfloat16 bf_low = __float2bfloat16(static_cast(low)); + nv_bfloat16 bf_high = __float2bfloat16(static_cast(high)); + nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high); + int32_t packed_bf16 = *reinterpret_cast(&bf2); + w[i] = packed_add(packed_bf16, kRangeMin); + } + } + } + + // Apply decoding scales + for (int i = 0; i < 4; i++) { + w[i] = packed_mul(w[i], qs); + } + + data[k] = w; + } + } +}; + +// Int6 symmetric quantization codec. +// We quantize the FP16 data to block-scaled Int6 in blocks of 4 * +// kThreadGroupSize. +template +struct CodecQ6 : public CodecBase { + static constexpr int kWorldSize = world_size; + + // Codec tile size process by this workgroup. + // Each threads processes a fragment of fp16x8_t (16B), + // into a int6x8_t (4B + 2B) and a fp16 scale shared among 32 values. + static constexpr int kRankAtoms = kAtoms / kWorldSize; + static constexpr int kRankTileStride = 1664; + static constexpr int kRankTileQ2Offset = 1024; + static constexpr int kRankTileScaleOffset = 1536; + static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms; + static_assert(kRankTransmittedTileSize % 16 == 0, + "kRankTransmittedTileSize must be 16B aligned."); + + static constexpr int kRankBufferTileStride = + kRankTileStride / sizeof(int32x4_t); + + // Total tile size for the collective communication. + static constexpr int kTransmittedTileSize = + kRankTransmittedTileSize * kWorldSize; + + // Constants configuration + + // {-1/32.0h, -1/32.0h}, fp16x2_t + static constexpr int kScaleFactor = + std::is_same::value ? 0xA800A800 : 0xBD00BD00; + + // {1e-7, 1e-7}, fp16x2_t + static constexpr int kScaleEpsilon = + std::is_same::value ? 0x00010001 : 0x33D733D7; + + // {-32, -32}, fp16x2_t + static constexpr int kRangeMin = + std::is_same::value ? 0xD000D000 : 0xC200C200; + + // {+31, +31}, fp16x2_t + static constexpr int kRangeMax = + std::is_same::value ? 0x4FC04FC0 : 0x41F841F8; + + // {+32, +32}, int16x2_t + static constexpr int kRangeBias = 0x00200020; + + __quickreduce_device_inline__ CodecQ6(int thread, int rank) + : CodecBase(thread, rank) {} + + __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, + const int32x4_t* __restrict__ data) { + for (int k = 0; k < kRankAtoms; k++) { + int32x4_t const atom = data[k]; + + // Compute the absolute maximum of the atom in the thread group + // In 2 blocks of values, upper/lower halves of the f16x2_t + int wblockmax = group_abs_max(atom); + + // Derive scales + int decoding_scale; + int encoding_scale; + decoding_scale = packed_mul(wblockmax, kScaleFactor); + encoding_scale = packed_add(decoding_scale, kScaleEpsilon); + encoding_scale = packed_rcp(encoding_scale); + + // Apply scales to get quantized values + int32x4_t w; + for (int i = 0; i < 4; i++) { + w[i] = packed_mul(atom[i], encoding_scale); + w[i] = packed_max(w[i], kRangeMin); + w[i] = packed_min(w[i], kRangeMax); + } + + // Convert from f16x2_t to uint16x2_t + int32x4_t q; + { + int16_t* qi = reinterpret_cast(&q); + T* wh = reinterpret_cast(&w); + for (int i = 0; i < 8; i++) qi[i] = (int16_t)rintf(T2float_cast(wh[i])); + + for (int i = 0; i < 4; i++) { + q[i] = packed_add(q[i], kRangeBias); + } + } + + // Pack 8 x q6 into int32_t + int16_t + uint32_t q4w; + uint16_t q2w = 0; + q4w = (q[0] & 0x000F000F) | ((q[1] & 0x000F000F) << 4) | + ((q[2] & 0x000F000F) << 8) | ((q[3] & 0x000F000F) << 12); + { + int16_t* tw = reinterpret_cast(&q); +#pragma unroll + for (int i = 0; i < 8; i++) { + q2w |= (tw[i] >> 4) << (i * 2); + } + } + // Write quantized atom to send_buffer + // note: only the group leader stores the scale + uint8_t* atom_ptr = + reinterpret_cast(send_buffer + k * kRankBufferTileStride); + uint32_t* q4w_ptr = reinterpret_cast(atom_ptr) + thread; + uint16_t* q2w_ptr = + reinterpret_cast(atom_ptr + kRankTileQ2Offset) + thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + + (thread / 8); + + __builtin_nontemporal_store(q4w, q4w_ptr); + __builtin_nontemporal_store(q2w, q2w_ptr); + if (threadIdx.x == group_leader) { + __builtin_nontemporal_store(decoding_scale, qs_ptr); + } + } + } + + __quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, + int32x4_t* __restrict__ data) { + for (int k = 0; k < kRankAtoms; k++) { + // Directly read quantized atom from recv_buffer + uint8_t* atom_ptr = reinterpret_cast(*recv_buffer); + uint32_t* q4w_ptr = reinterpret_cast(atom_ptr) + thread; + uint16_t* q2w_ptr = + reinterpret_cast(atom_ptr + kRankTileQ2Offset) + thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + + (thread / 8); + + uint32_t q4w = __builtin_nontemporal_load(q4w_ptr); + uint16_t q2w = __builtin_nontemporal_load(q2w_ptr); + int qs = __builtin_nontemporal_load(qs_ptr); + + *recv_buffer += kRankBufferTileStride; + + // Unpack q6 into fp16x8_t + int32x4_t w; + { + static uint constexpr kMask000F = 0x000F000F; + static uint constexpr kHalf2_1024 = + 0x64006400; // {1024.0, 1024.0}, fp16x2_t + static uint constexpr kHalf2_1056 = + 0xE420E420; // {-1056.0, -1056.0}, fp16x2_t + +#pragma unroll + for (int i = 0; i < 4; i++) { + int32_t q4 = q4w & kMask000F; + int32_t q2 = (q2w & 0x3) | ((q2w & 0xC) << 14); + q4w >>= 4; + q2w >>= 4; + if constexpr (std::is_same::value) { + int32_t q6 = q4 | (q2 << 4) | kHalf2_1024; + asm volatile("v_pk_add_f16 %0, %1, %2" + : "=v"(w[i]) + : "v"(q6), "v"(kHalf2_1056)); + } else { + int32_t int16_2 = q4 | (q2 << 4); + int16_t low = static_cast(int16_2 & 0xFFFF); + int16_t high = static_cast((int16_2 >> 16) & 0xFFFF); + + nv_bfloat16 bf_low = __float2bfloat16(static_cast(low)); + nv_bfloat16 bf_high = __float2bfloat16(static_cast(high)); + nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high); + int32_t packed_bf16 = *reinterpret_cast(&bf2); + w[i] = packed_add(packed_bf16, kRangeMin); + } + } + } + + // Apply decoding scales + for (int i = 0; i < 4; i++) { + w[i] = packed_mul(w[i], qs); + } + + // That's pretty much it... + data[k] = w; + } + } +}; + +// Int8 symmetric quantization codec. +// We quantize the FP16 data to block-scaled Int8 in blocks of 4 * +// kThreadGroupSize. +template +struct CodecQ8 : public CodecBase { + static constexpr int kWorldSize = world_size; + + // Codec tile size process by this workgroup. + // Each threads processes a fragment of f16x8_t (16B), + // into a int8x8_t (8B) and a f16 scale shared among 32 values. + static constexpr int kRankAtoms = kAtoms / kWorldSize; + static constexpr int kRankTileStride = 2176; + static constexpr int kRankTileScaleOffset = 2048; + static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms; + static_assert(kRankTransmittedTileSize % 16 == 0, + "kRankTileSize must be 16B aligned."); + + static constexpr int kRankBufferTileStride = + kRankTileStride / sizeof(int32x4_t); + + // Total tile size for the collective communication. + static constexpr int kTransmittedTileSize = + kRankTransmittedTileSize * kWorldSize; + + // Constants configuration + + // {-1/128.0h, -1/128.0h}, f16x2_t + static constexpr int kScaleFactor = + std::is_same::value ? 0xA000A000 : 0xBC00BC00; + + // {1e-7, 1e-7}, f16x2_t + static constexpr int kScaleEpsilon = + std::is_same::value ? 0x00010001 : 0x33D733D7; + + // {-128, -128}, f16x2_t + static constexpr int kRangeMin = + std::is_same::value ? 0xD800D800 : 0xC300C300; + // {+127, +127}, f16x2_t + static constexpr int kRangeMax = + std::is_same::value ? 0x57F057F0 : 0x42FE42FE; + + // {+128, +128}, int16x2_t + static constexpr int kRangeBias = 0x00800080; + + __quickreduce_device_inline__ CodecQ8(int thread, int rank) + : CodecBase(thread, rank) {} + + __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, + int32x4_t const* __restrict__ data) { + for (int k = 0; k < kRankAtoms; k++) { + int32x4_t const atom = data[k]; + // Compute the absolute maximum of the atom in the thread group + // In 2 blocks of values, upper/lower halves of the f16x2_t + int wblockmax = group_abs_max(atom); + + // Derive scales + int decoding_scale; + int encoding_scale; + decoding_scale = packed_mul(wblockmax, kScaleFactor); + encoding_scale = packed_add(decoding_scale, kScaleEpsilon); + encoding_scale = packed_rcp(encoding_scale); + + // Apply scales to get quantized values + int32x4_t w; + for (int i = 0; i < 4; i++) { + w[i] = packed_mul(atom[i], encoding_scale); + w[i] = packed_max(w[i], kRangeMin); + w[i] = packed_min(w[i], kRangeMax); + } + + // Convert from f16x2_t to uint16x2_t + int32x4_t q; + { + int16_t* qi = reinterpret_cast(&q); + T* wh = reinterpret_cast(&w); + for (int i = 0; i < 8; i++) qi[i] = (int16_t)rintf(T2float_cast(wh[i])); + + for (int i = 0; i < 4; i++) { + q[i] = packed_add(q[i], kRangeBias); + } + } + + // Pack 8 x q8 into int32x2_t + int32x2_t qw; + qw[0] = q[0] | (q[1] << 8); + qw[1] = q[2] | (q[3] << 8); + + // Write quantized atom to send_buffer + // note: only the group leader stores the scale + uint8_t* atom_ptr = + reinterpret_cast(send_buffer + k * kRankBufferTileStride); + int32x2_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + + (thread / 8); + + __builtin_nontemporal_store(qw, qw_ptr); + if (threadIdx.x == group_leader) { + __builtin_nontemporal_store(decoding_scale, qs_ptr); + } + } + } + + __quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, + int32x4_t* __restrict__ data) { + for (int k = 0; k < kRankAtoms; k++) { + // Directly read quantized atom from recv_buffer + uint8_t* atom_ptr = reinterpret_cast(*recv_buffer); + int32x2_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + + (thread / 8); + + int32x2_t qw = __builtin_nontemporal_load(qw_ptr); + int qs = __builtin_nontemporal_load(qs_ptr); + + *recv_buffer += kRankBufferTileStride; + + // Unpack q8 into fp16x8_t + int32x4_t w; + { + static uint constexpr kMask00FF = 0x00FF00FF; + + // {1024.0, 1024.0}, fp16x2_t + static uint constexpr kHalf2_1024 = 0x64006400; + + // {-1152.0, -1152.0}, fp16x2_t + static uint constexpr kHalf2_1152 = 0xE480E480; + +#pragma unroll + for (int i = 0; i < 4; i++) { + if constexpr (std::is_same::value) { + int32_t q8 = + ((qw[i / 2] >> ((i % 2) * 8)) & kMask00FF) | kHalf2_1024; + w[i] = packed_add(q8, kHalf2_1152); + } else { + int32_t int16_2 = (qw[i / 2] >> ((i % 2) * 8)) & kMask00FF; + int16_t low = static_cast(int16_2 & 0xFFFF); + int16_t high = static_cast((int16_2 >> 16) & 0xFFFF); + nv_bfloat16 bf_low = __float2bfloat16(static_cast(low)); + nv_bfloat16 bf_high = __float2bfloat16(static_cast(high)); + nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high); + int32_t packed_bf16 = *reinterpret_cast(&bf2); + w[i] = packed_add(packed_bf16, kRangeMin); + } + } + } + + // Apply decoding scales + for (int i = 0; i < 4; i++) { + w[i] = packed_mul(w[i], qs); + } + + data[k] = w; + } + } +}; + +// Twoshot All Reduce +template +struct AllReduceTwoshot { + static_assert(sizeof(T) == 2); + + static constexpr int kWorldSize = Codec::kWorldSize; + + __device__ static void run( + T const* __restrict__ input, T* __restrict__ output, + uint32_t const N, // number of elements + int const block, // block index + int const rank, // rank index + uint8_t** __restrict__ buffer_list, // communication buffers + uint32_t const data_offset, // offset to start of the data buffer + uint32_t flag_color) { + // Topology + int thread = threadIdx.x + threadIdx.y * kWavefront; + uint8_t* rank_buffer = buffer_list[rank]; + Codec codec(thread, rank); + int block_id = blockIdx.x; + int grid_size = gridDim.x; + // -------------------------------------------------------- + // Read input into registers + int32x4_t tA[kAtoms]; + + BufferResource src_buffer(const_cast(input), N * sizeof(T)); + uint32_t src_offset = block * kTileSize + thread * sizeof(int32x4_t); + + for (int i = 0; i < kAtoms; i++) { + tA[i] = buffer_load_dwordx4(src_buffer.descriptor, src_offset, 0, 0); + src_offset += kAtomStride * sizeof(int32x4_t); + if constexpr (cast_bf2half) { + const nv_bfloat162* bf_buf = + reinterpret_cast(&tA[i]); + half2 half_buf[4]; +#pragma unroll + for (int j = 0; j < 4; ++j) { + float2 f = __bfloat1622float2(bf_buf[j]); + half_buf[j] = __float22half2_rn(f); + } + tA[i] = *reinterpret_cast(half_buf); + } + } + + // -------------------------------------------------------- + // Phase-1A: Write segment data into the communication buffer of the target + // rank responsible for this segment. + uint32_t comm_data0_offset = + data_offset + block_id * Codec::kTransmittedTileSize; + uint32_t comm_data1_offset = + grid_size * Codec::kTransmittedTileSize + comm_data0_offset; + + uint32_t comm_flags0_offset = block_id * (kWorldSize * sizeof(uint32_t)); + uint32_t comm_flags1_offset = + grid_size * (kWorldSize * sizeof(uint32_t)) + comm_flags0_offset; + + for (int r = 0; r < kWorldSize; r++) { + int32x4_t* send_buffer = + reinterpret_cast(buffer_list[r] + comm_data0_offset + + rank * Codec::kRankTransmittedTileSize); + codec.send(send_buffer, &tA[r * Codec::kRankAtoms]); + } + + __syncthreads(); + if (thread < kWorldSize) { + int r = thread; + uint32_t* flag_ptr = reinterpret_cast( + buffer_list[r] + comm_flags0_offset + rank * sizeof(uint32_t)); + set_sync_flag(flag_ptr, flag_color); + } + // -------------------------------------------------------- + // Phase-1B: Reduce the segment data from the communication buffers. + int32x4_t tR[Codec::kRankAtoms] = {}; + { + // Read the data from the communication buffer. + int32x4_t* recv_buffer = + reinterpret_cast(rank_buffer + comm_data0_offset); + uint32_t* flag_ptr = + reinterpret_cast(rank_buffer + comm_flags0_offset); + + for (int r = 0; r < kWorldSize; r++) { + // Wait for the flags to be set. + if (thread == 0) { + wait_sync_flag(&flag_ptr[r], flag_color); + } + __syncthreads(); + + // note: we reuse tA as temp buffer here + codec.recv(&recv_buffer, tA); + + for (int i = 0; i < Codec::kRankAtoms; i++) { + packed_assign_add(&tR[i], &tA[i]); + } + } + } + + // Phase-2: Write the reduced segment to every other rank + for (int r = 0; r < kWorldSize; r++) { + int32x4_t* send_buffer = + reinterpret_cast(buffer_list[r] + comm_data1_offset + + rank * Codec::kRankTransmittedTileSize); + codec.send(send_buffer, tR); + } + + __syncthreads(); + if (thread < kWorldSize) { + int r = thread; + uint32_t* flag_ptr = reinterpret_cast( + buffer_list[r] + comm_flags1_offset + rank * sizeof(uint32_t)); + set_sync_flag(flag_ptr, flag_color); + } + + // Phase-2: Read the gather segments from the rank's communication buffer. + { + // Read the data from the communication buffer. + int32x4_t* recv_buffer = + reinterpret_cast(rank_buffer + comm_data1_offset); + uint32_t* flag_ptr = + reinterpret_cast(rank_buffer + comm_flags1_offset); + + for (int r = 0; r < kWorldSize; r++) { + // Wait for the flags to be set. + if (thread == 0) { + wait_sync_flag(&flag_ptr[r], flag_color); + } + __syncthreads(); + + // Gather all reduced and final rank segments into tA. + codec.recv(&recv_buffer, &tA[r * Codec::kRankAtoms]); + } + } + + // -------------------------------------------------------- + // Write the result to output. + BufferResource dst_buffer(output, N * sizeof(T)); + uint32_t dst_offset = block * kTileSize + thread * sizeof(int32x4_t); + + for (int i = 0; i < kAtoms; i++) { + if constexpr (cast_bf2half) { + const half2* half_buf = reinterpret_cast(&tA[i]); + nv_bfloat162 bf16_buf[4]; +#pragma unroll + for (int j = 0; j < 4; ++j) { + float2 f = __half22float2(half_buf[j]); + bf16_buf[j] = __float22bfloat162_rn(f); + } + buffer_store_dwordx4(*reinterpret_cast(bf16_buf), + dst_buffer.descriptor, dst_offset, 0, 0); + } else { + buffer_store_dwordx4(tA[i], dst_buffer.descriptor, dst_offset, 0, 0); + } + dst_offset += kAtomStride * sizeof(int32x4_t); + } + } +}; + +} // namespace quickreduce \ No newline at end of file diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 1a1896b4c1ee9..8bb71cad29dae 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -725,6 +725,24 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) { custom_ar.impl("open_mem_handle", torch::kCPU, &open_mem_handle); custom_ar.def("free_shared_buffer", &free_shared_buffer); +#ifdef USE_ROCM + // Quick Reduce all-reduce kernels + custom_ar.def( + "qr_all_reduce(int fa, Tensor inp, Tensor out, int quant_level, bool " + "cast_bf2half) -> ()"); + custom_ar.impl("qr_all_reduce", torch::kCUDA, &qr_all_reduce); + + custom_ar.def("init_custom_qr", &init_custom_qr); + custom_ar.def("qr_destroy", &qr_destroy); + + custom_ar.def("qr_get_handle", &qr_get_handle); + + custom_ar.def("qr_open_handles(int _fa, Tensor[](b!) handles) -> ()"); + custom_ar.impl("qr_open_handles", torch::kCPU, &qr_open_handles); + + // Max input size in bytes + custom_ar.def("qr_max_size", &qr_max_size); +#endif } REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/tests/distributed/test_quick_all_reduce.py b/tests/distributed/test_quick_all_reduce.py new file mode 100644 index 0000000000000..a4added29144e --- /dev/null +++ b/tests/distributed/test_quick_all_reduce.py @@ -0,0 +1,138 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import random + +import pytest +import ray +import torch +import torch.distributed as dist + +from vllm.distributed.communication_op import ( # noqa + tensor_model_parallel_all_reduce) +from vllm.distributed.parallel_state import (get_tensor_model_parallel_group, + get_tp_group, graph_capture) +from vllm.platforms import current_platform + +from ..utils import (ensure_model_parallel_initialized, + init_test_distributed_environment, multi_process_parallel) + +torch.manual_seed(42) +random.seed(44) +# Size over 8MB is sufficient for custom quick allreduce. +test_sizes = [ + random.randint(8 * 1024 * 1024, 10 * 1024 * 1024) for _ in range(8) +] +for i, v in enumerate(test_sizes): + test_sizes[i] -= v % 8 + + +@ray.remote(num_gpus=1, max_calls=1) +def graph_quickreduce( + monkeypatch: pytest.MonkeyPatch, + tp_size, + pp_size, + rank, + distributed_init_port, +): + with monkeypatch.context() as m: + m.delenv("CUDA_VISIBLE_DEVICES", raising=False) + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + init_test_distributed_environment(tp_size, pp_size, rank, + distributed_init_port) + ensure_model_parallel_initialized(tp_size, pp_size) + group = get_tensor_model_parallel_group().device_group + + # A small all_reduce for warmup. + # this is needed because device communicators might be created lazily + # (e.g. NCCL). This will ensure that the communicator is initialized + # before any communication happens, so that this group can be used for + # graph capture immediately. + data = torch.zeros(1) + data = data.to(device=device) + torch.distributed.all_reduce(data, group=group) + torch.cuda.synchronize() + del data + + # we use the first group to communicate once + # and the second group to communicate twice + # and so on + # this is used to demonstrate that each group can + # communicate independently + num_communication = rank // tp_size + 1 + + for sz in test_sizes: + for dtype in [torch.float16, torch.bfloat16]: + with graph_capture(device=device) as graph_capture_context: + inp1 = torch.randint(1, + 23, (sz, ), + dtype=dtype, + device=torch.cuda.current_device()) + inp2 = torch.randint(-23, + 1, (sz, ), + dtype=dtype, + device=torch.cuda.current_device()) + torch.cuda.synchronize() + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, + stream=graph_capture_context.stream): + for _ in range(num_communication): + out1 = tensor_model_parallel_all_reduce(inp1) + dist.all_reduce(inp1, group=group) + out2 = tensor_model_parallel_all_reduce(inp2) + dist.all_reduce(inp2, group=group) + graph.replay() + torch.testing.assert_close(out1, inp1, atol=2.5, rtol=0.1) + torch.testing.assert_close(out2, inp2, atol=2.5, rtol=0.1) + + +@ray.remote(num_gpus=1, max_calls=1) +def eager_quickreduce( + monkeypatch: pytest.MonkeyPatch, + tp_size, + pp_size, + rank, + distributed_init_port, +): + with monkeypatch.context() as m: + m.delenv("CUDA_VISIBLE_DEVICES", raising=False) + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + + init_test_distributed_environment(tp_size, pp_size, rank, + distributed_init_port) + + # Size over 8MB is sufficient for custom quick allreduce. + sz = 16 * 1024 * 1024 + fa = get_tp_group().device_communicator.qr_comm + inp = torch.tensor([1.0 * ((i) % 23) for i in range(sz)], + dtype=torch.float16, + device=device) + out = fa.quick_all_reduce(inp) + torch.testing.assert_close(out, inp * tp_size, atol=2.5, rtol=0.1) + + inp = torch.tensor([1.0 * ((i) % 23) for i in range(sz)], + dtype=torch.bfloat16, + device=device) + out = fa.quick_all_reduce(inp) + torch.testing.assert_close(out, inp * tp_size, atol=2.5, rtol=0.1) + + +@pytest.mark.skipif(not current_platform.is_rocm(), + reason="only test quick allreduce for rocm") +@pytest.mark.parametrize("quant_mode", ["FP", "INT8", "INT6", "INT4"]) +@pytest.mark.parametrize("tp_size", [2]) +@pytest.mark.parametrize("pipeline_parallel_size", [1, 2]) +@pytest.mark.parametrize("test_target", [graph_quickreduce, eager_quickreduce]) +def test_custom_quick_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size, + pipeline_parallel_size, test_target, + quant_mode): + world_size = tp_size * pipeline_parallel_size + if world_size > torch.cuda.device_count(): + pytest.skip("Not enough GPUs to run the test.") + + monkeypatch.setenv("VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", quant_mode) + + multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size, + test_target) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index d5a41284385e6..215f35bad34d9 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1748,6 +1748,38 @@ def free_shared_buffer(ptr: int) -> None: torch.ops._C_custom_ar.free_shared_buffer(ptr) +# quick all reduce +def init_custom_qr(rank: int, + world_size: int, + qr_max_size: Optional[int] = None) -> int: + return torch.ops._C_custom_ar.init_custom_qr(rank, world_size, qr_max_size) + + +def qr_destroy(fa: int) -> None: + torch.ops._C_custom_ar.qr_destroy(fa) + + +def qr_all_reduce(fa: int, + inp: torch.Tensor, + out: torch.Tensor, + quant_level: int, + cast_bf2half: bool = False) -> None: + torch.ops._C_custom_ar.qr_all_reduce(fa, inp, out, quant_level, + cast_bf2half) + + +def qr_get_handle(fa: int) -> torch.Tensor: + return torch.ops._C_custom_ar.qr_get_handle(fa) + + +def qr_open_handles(fa: int, handles: list[torch.Tensor]) -> None: + return torch.ops._C_custom_ar.qr_open_handles(fa, handles) + + +def qr_max_size() -> int: + return torch.ops._C_custom_ar.qr_max_size() + + def get_flash_mla_metadata( cache_seqlens: torch.Tensor, num_heads_per_head_k: int, diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 055d91690e676..3958d566b1745 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -8,6 +8,7 @@ from torch.distributed import ProcessGroup import vllm.envs as envs from vllm.logger import init_logger +from vllm.platforms import current_platform from .base_device_communicator import DeviceCommunicatorBase @@ -41,6 +42,8 @@ class CudaCommunicator(DeviceCommunicatorBase): CustomAllreduce) from vllm.distributed.device_communicators.pynccl import ( PyNcclCommunicator) + from vllm.distributed.device_communicators.quick_all_reduce import ( + QuickAllReduce) self.pynccl_comm: Optional[PyNcclCommunicator] = None if use_pynccl and self.world_size > 1: @@ -50,6 +53,7 @@ class CudaCommunicator(DeviceCommunicatorBase): ) self.ca_comm: Optional[CustomAllreduce] = None + self.qr_comm: Optional[QuickAllReduce] = None if use_custom_allreduce and self.world_size > 1: # Initialize a custom fast all-reduce implementation. self.ca_comm = CustomAllreduce( @@ -57,6 +61,14 @@ class CudaCommunicator(DeviceCommunicatorBase): device=self.device, ) + if current_platform.is_rocm(): + # Initialize a custom quick all-reduce implementation for AMD. + # Quick reduce is designed as a complement to custom allreduce. + # Based on quickreduce (https://github.com/mk1-project/quickreduce). + # If it's a rocm, 'use_custom_allreduce==True' means it must + # currently be an MI300 series. + self.qr_comm = QuickAllReduce(group=self.cpu_group, + device=self.device) if self.use_all2all: all2all_backend = envs.VLLM_ALL2ALL_BACKEND if all2all_backend == "naive": @@ -79,8 +91,14 @@ class CudaCommunicator(DeviceCommunicatorBase): raise ValueError(f"Unknown all2all backend: {all2all_backend}") def all_reduce(self, input_): - # always try custom allreduce first, - # and then pynccl. + # always try quick reduce first, then custom allreduce, + # and then pynccl. (quick reduce just for ROCM MI3*) + qr_comm = self.qr_comm + if qr_comm is not None and not qr_comm.disabled and \ + qr_comm.should_quick_allreduce(input_): + out = qr_comm.quick_all_reduce(input_) + assert out is not None + return out ca_comm = self.ca_comm if ca_comm is not None and not ca_comm.disabled and \ ca_comm.should_custom_ar(input_): diff --git a/vllm/distributed/device_communicators/quick_all_reduce.py b/vllm/distributed/device_communicators/quick_all_reduce.py new file mode 100644 index 0000000000000..c61231e2d33f4 --- /dev/null +++ b/vllm/distributed/device_communicators/quick_all_reduce.py @@ -0,0 +1,278 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from enum import Enum +from typing import Union + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.config import get_current_vllm_config +from vllm.distributed.parallel_state import in_the_same_node_as +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.utils import cuda_device_count_stateless + +logger = init_logger(__name__) + +try: + ops.qr_max_size() + quick_ar = True +except Exception: + # For CPUs and CUDA + quick_ar = False + + +def is_weak_contiguous(inp: torch.Tensor): + return inp.is_contiguous() or (inp.storage().nbytes() - + inp.storage_offset() * inp.element_size() + == inp.numel() * inp.element_size()) + + +class QuickReduceRegime(Enum): + FP = 0 + INT8 = 1 + INT6 = 2 + INT4 = 3 + NONE = 4 + + +MB = 1024 * 1024 + + +class QuickAllReduce: + + _SUPPORTED_WORLD_SIZES = [2, 4, 8] + _SUPPORTED_DTYPES = [torch.float16, torch.bfloat16] + # The following data is based on kernel tests. + # In this order [FP, INT8, INT6, INT4]. + _QR_MIN_SIZE = { + (torch.float16, 2): [1 * MB, 2 * MB, 2 * MB, 1 * MB], + (torch.float16, 4): [1 * MB, 16 * MB, 4 * MB, 2 * MB], + (torch.float16, 8): [16 * MB, 4 * MB, 4 * MB, 2 * MB], + (torch.bfloat16, 2): [2 * MB, 8 * MB, 8 * MB, 8 * MB], + (torch.bfloat16, 4): [8 * MB, 64 * MB, 64 * MB, 16 * MB], + (torch.bfloat16, 8): [16 * MB, 2048 * MB, 2048 * MB, 2048 * MB], + } + + def __init__(self, group: ProcessGroup, + device: Union[int, str, torch.device]) -> None: + """ + Custom allreduce provides non-destructive acceleration and is + available for CUDA and ROCm MI300 series. + + Custom quick allreduce leverages quantization for further + acceleration on ROCm. It currently supports Q8, Q6, and Q4 + quantization formats and FP(float16, bfloat16). + + Quick allreduce is designed as a complement to custom allreduce. + Its initialization requires even stricter conditions. + + Only the ROCm MI300 series is supported for quick allreduce at + this time. + + Args: + group: the process group to work on. If None, it will use the + default process group. + device: the device to bind the CustomAllreduce to. If None, + it will be bind to f"cuda:{local_rank}". + It is the caller's responsibility to make sure each communicator + is bind to a unique device, and all communicators in this group + are in the same node. + """ + self.disabled = True + if not self._rocm_arch_available(): + logger.debug( + "Custom quick allreduce is only supported on ROCm MI300 series." + ) + return + + if not quick_ar: + # disable because of missing quick reduce library + # e.g. in a cuda environment + logger.info("Custom quick allreduce is disabled because " + "of missing custom quick allreduce library") + return + + self.group = group + assert dist.get_backend(group) != dist.Backend.NCCL, ( + "Custom quick allreduce should be attached to a non-NCCL group.") + if not all(in_the_same_node_as(group, source_rank=0)): + # No need to initialize custom quick allreduce for + # multi-node case. + logger.warning("Custom quick allreduce is disabled because this " + "process group spans across nodes.") + return + rank = dist.get_rank(group=self.group) + world_size = dist.get_world_size(group=self.group) + self.rank = rank + self.world_size = world_size + if world_size == 1: + # No need to initialize QuickReduce for single GPU case. + return + + if world_size not in QuickAllReduce._SUPPORTED_WORLD_SIZES: + logger.warning( + "Custom quick allreduce is disabled due to an " + "unsupported world size: %d. Supported world sizes: %s.", + world_size, str(QuickAllReduce._SUPPORTED_WORLD_SIZES)) + return + + if isinstance(device, int): + device = torch.device(f"cuda:{device}") + elif isinstance(device, str): + device = torch.device(device) + assert isinstance(device, torch.device) + self.device = device + + cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES + if cuda_visible_devices: + device_ids = list(map(int, cuda_visible_devices.split(","))) + else: + device_ids = list(range(cuda_device_count_stateless())) + physical_device_id = device_ids[device.index] + tensor = torch.tensor([physical_device_id], + dtype=torch.int, + device="cpu") + gather_list = [ + torch.tensor([0], dtype=torch.int, device="cpu") + for _ in range(self.world_size) + ] + dist.all_gather(gather_list, tensor, group=self.group) + physical_device_ids = [t.item() for t in gather_list] + + # test nvlink first, this will filter out most of the cases + # where custom quick allreduce is not supported + # this checks hardware and driver support for NVLink + assert current_platform.is_cuda_alike() + self.fully_connected = current_platform.is_fully_connected( + physical_device_ids) + if self.world_size > 2 and not self.fully_connected: + logger.debug( + "Custom quick allreduce is disabled because it's not supported " + "on more than two PCIe-only GPUs. ") + return + + self.init_quick_all_reduce() + + def init_quick_all_reduce(self): + # On RocM, bfloat16 kernels are slower than fp16 + # due to slower match operations + # If environment variable is set to 1, we convert input to fp16 + self.use_fp16_kernels = envs.VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16 + regime_str = envs.VLLM_ROCM_QUICK_REDUCE_QUANTIZATION + if regime_str not in QuickReduceRegime.__members__: + logger.warning( + "Custom quick allreduce:", + f"Invalid quantization level: {regime_str}. " + "Supported levels: " + f"{list(QuickReduceRegime.__members__.keys())}") + return + + if regime_str == "NONE": + logger.debug("Custom quick allreduce is disabled based " + "on env variable " + "VLLM_ROCM_QUICK_REDUCE_QUANTIZATION='NONE'") + return + self.qr_quant_level = QuickReduceRegime[regime_str] + vllm_config = get_current_vllm_config() + if vllm_config is not None and \ + hasattr(vllm_config, "model_config") and \ + hasattr(vllm_config.model_config, "dtype"): + dtype = vllm_config.model_config.dtype + if dtype not in [torch.float16, torch.bfloat16]: + logger.debug( + "Custom quick allreduce disabled: only supports " + "float16 and float16, but get %s.", dtype) + return + + if dtype == torch.bfloat16 and self.use_fp16_kernels: + logger.info( + "Custom quick allreduce: BF16 inputs will be converted " + "to FP16 to improve performance. set " + "envs.VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16=0 " + "to turn off.") + + # VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB is specified in MB + qr_max_size = envs.VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB + if qr_max_size is not None: + if qr_max_size < 1: + logger.info( + "You should not set a max_size smaller than 1MB, which can " + "lead to error or degradation to custom allreduce or rccl." + ) + qr_max_size = qr_max_size * MB + self._ptr = ops.init_custom_qr(self.rank, self.world_size, qr_max_size) + self.qr_max_size = qr_max_size if qr_max_size is not None \ + else ops.qr_max_size() + self.create_shared_buffer() + self.disabled = False + + def _rocm_arch_available(self): + if not current_platform.is_rocm(): + return False + try: + props = torch.cuda.get_device_properties(0) + gcn_arch = getattr(props, "gcnArchName", "") + supported_archs = ['gfx94', 'gfx95'] + return any(gfx in gcn_arch for gfx in supported_archs) + except Exception as e: + logger.warning("Failed to determine ROCm for quick allreduce: %s", + e) + return False + + def create_shared_buffer(self): + """ + Creates a shared buffer for quickreduce. + Has to be called after init_custom_qr + """ + handle = ops.qr_get_handle(self._ptr) + world_size = dist.get_world_size(group=self.group) + handles = [None] * world_size + dist.all_gather_object(handles, handle, group=self.group) + ops.qr_open_handles(self._ptr, handles) + + def should_quick_allreduce(self, inp: torch.Tensor): + """ + Check if quickreduce is available + """ + if self.disabled: + return False + if inp.dtype not in self._SUPPORTED_DTYPES: + return False + inp_size = inp.numel() * inp.element_size() + # custom quick allreduce requires input byte size to be + # multiples of 16 + if inp_size % 16 != 0: + return False + if not is_weak_contiguous(inp): + return False + dtype = inp.dtype + if self.use_fp16_kernels: + dtype = torch.float16 + return inp_size <= self.qr_max_size and \ + inp_size >= self._QR_MIN_SIZE[(dtype, self.world_size)]\ + [self.qr_quant_level.value] + + def quick_all_reduce(self, inp: torch.Tensor, *, out: torch.Tensor = None): + """Performs an out-of-place custom quick all reduce.""" + # quick allreduce doesn't require a separate graph mode, + # as QR uses static IPC buffer. + if out is None: + out = torch.empty_like(inp) + ops.qr_all_reduce(self._ptr, inp, out, self.qr_quant_level.value, + self.use_fp16_kernels) + return out + + def close(self): + if not self.disabled and getattr(self, "_ptr", None): + if ops is not None: + ops.qr_destroy(self._ptr) + self._ptr = 0 + self.disabled = True + + def __del__(self): + self.close() diff --git a/vllm/envs.py b/vllm/envs.py index c9c81603a75a8..a3f19c7ee5c70 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -135,6 +135,9 @@ if TYPE_CHECKING: VLLM_KV_CACHE_LAYOUT: Optional[str] = None VLLM_COMPUTE_NANS_IN_LOGITS: bool = False VLLM_USE_NVFP4_CT_EMULATIONS: bool = False + VLLM_ROCM_QUICK_REDUCE_QUANTIZATION: str = "NONE" + VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True + VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None def get_default_cache_root(): @@ -690,6 +693,31 @@ environment_variables: dict[str, Callable[[], Any]] = { lambda: (os.getenv("VLLM_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in ("true", "1")), + # Custom quick allreduce kernel for MI3* cards + # Choice of quantization level: FP, INT8, INT6, INT4 or NONE + # Recommended for large models to get allreduce + "VLLM_ROCM_QUICK_REDUCE_QUANTIZATION": + lambda: os.getenv("VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", "NONE").upper(), + + # Custom quick allreduce kernel for MI3* cards + # Due to the lack of the bfloat16 asm instruction, bfloat16 + # kernels are slower than fp16, + # If environment variable is set to 1, the input is converted to fp16 + "VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16": + lambda: + (os.getenv("VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16", "True").lower() in + ("true", "1")), + + # Custom quick allreduce kernel for MI3* cards. + # Controls the maximum allowed number of data bytes(MB) for custom quick + # allreduce communication. + # Default: 2048 MB. + # Data exceeding this size will use either custom allreduce or RCCL + # communication. + "VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB": + lambda: maybe_convert_int( + os.environ.get("VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB", None)), + # If set, when running in Quark emulation mode, do not dequantize the # weights at load time. Instead, dequantize weights on-the-fly during # kernel execution. From 8b64c895c0c05d83458f1af67f81060d699d2526 Mon Sep 17 00:00:00 2001 From: Yang Wang Date: Thu, 26 Jun 2025 20:55:25 -0700 Subject: [PATCH 033/175] [CI] Sync test dependency with test.in for torch nightly (#19632) Signed-off-by: Yang Wang Signed-off-by: Yida Wu Signed-off-by: Nick Hill Co-authored-by: Concurrensee Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Cyrus Leung Co-authored-by: Nick Hill --- .buildkite/test-pipeline.yaml | 12 ++- .pre-commit-config.yaml | 5 ++ requirements/nightly_torch_test.txt | 77 ++++++++++--------- requirements/test.in | 3 +- .../pytorch_nightly_dependency.sh | 42 ++++++++++ tools/generate_nightly_torch_test.py | 34 ++++++++ 6 files changed, 134 insertions(+), 39 deletions(-) create mode 100644 tests/standalone_tests/pytorch_nightly_dependency.sh create mode 100644 tools/generate_nightly_torch_test.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 26f70ad457b67..7f1841b1c97cd 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -41,6 +41,16 @@ steps: # TODO: add `--strict` once warnings in docstrings are fixed - mkdocs build +- label: Pytorch Nightly Dependency Override Check # 2min + # if this test fails, it means the nightly torch version is not compatible with some + # of the dependencies. Please check the error message and add the package to whitelist + # in /vllm/tools/generate_nightly_torch_test.py + soft_fail: true + source_file_dependencies: + - requirements/nightly_torch_test.txt + commands: + - bash standalone_tests/pytorch_nightly_dependency.sh + - label: Async Engine, Inputs, Utils, Worker Test # 24min mirror_hardwares: [amdexperimental] source_file_dependencies: @@ -767,7 +777,7 @@ steps: - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models.txt - label: Weight Loading Multiple GPU Test - Large Models # optional - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 2 gpu: a100 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e62b623b4e114..15ef5defff69e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -53,6 +53,11 @@ repos: files: ^requirements/test\.(in|txt)$ - repo: local hooks: + - id: format-torch-nightly-test + name: reformat nightly_torch_test.txt to be in sync with test.in + language: python + entry: python tools/generate_nightly_torch_test.py + files: ^requirements/test\.(in|txt)$ - id: mypy-local name: Run mypy for local Python installation entry: tools/mypy.sh 0 "local" diff --git a/requirements/nightly_torch_test.txt b/requirements/nightly_torch_test.txt index 00acda3662608..fd0b0fac12a92 100644 --- a/requirements/nightly_torch_test.txt +++ b/requirements/nightly_torch_test.txt @@ -1,47 +1,50 @@ -# Dependency that able to run entrypoints test -# pytest and its extensions +# testing pytest -pytest-asyncio +tensorizer>=2.9.0 pytest-forked -pytest-mock +pytest-asyncio pytest-rerunfailures pytest-shard pytest-timeout -librosa # required by audio tests in entrypoints/openai -sentence-transformers # required for embedding tests -transformers==4.52.4 -transformers_stream_generator # required for qwen-vl test -numba == 0.61.2; python_version > '3.9' # testing utils -boto3 -botocore -datasets -ray >= 2.10.0 +backoff # required for phi4mm test +blobfile # required for kimi-vl test +einops # required for MPT, qwen-vl and Mamba +httpx +librosa # required for audio tests +vocos # required for minicpmo_26 test peft +pqdm +ray[cgraph,default]>=2.43.0, !=2.44.* # Ray Compiled Graph, required by pipeline parallelism tests +sentence-transformers # required for embedding tests +soundfile # required for audio tests +jiwer # required for audio tests +timm # required for internvl test +transformers_stream_generator # required for qwen-vl test +matplotlib # required for qwen-vl test +mistral_common[opencv] >= 1.6.2 # required for pixtral test +num2words # required for smolvlm test +opencv-python-headless >= 4.11.0 # required for video test +datamodel_code_generator # required for minicpm3 test +lm-eval[api]==0.4.8 # required for model evaluation test +mteb>=1.38.11, <2 # required for mteb test +transformers==4.52.4 +tokenizers==0.21.1 +huggingface-hub[hf_xet]>=0.30.0 # Required for Xet downloads. +schemathesis>=3.39.15 # Required for openai schema test. +# quantization +bitsandbytes>=0.45.3 +buildkite-test-collector==0.1.9 + + +genai_perf==0.0.8 +tritonclient==2.51.0 + +numba == 0.60.0; python_version == '3.9' # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding +numba == 0.61.2; python_version > '3.9' +numpy runai-model-streamer==0.11.0 runai-model-streamer-s3==0.11.0 -tensorizer>=2.9.0 -lm-eval==0.4.8 -buildkite-test-collector==0.1.9 -lm-eval[api]==0.4.8 # required for model evaluation test - -# required for quantization test -bitsandbytes>=0.45.3 - -# required for minicpmo_26 test -vector_quantize_pytorch -vocos - -# required for Basic Models Test -blobfile # required for kimi-vl test -matplotlib # required for qwen-vl test - -# required for Multi-Modal Models Test (Standard) -num2words # required for smolvlm test -pqdm -timm # required for internvl test -mistral-common==1.6.2 - -schemathesis==3.39.15 # Required for openai schema test. -mteb>=1.38.11, <2 # required for mteb test +fastsafetensors>=0.1.10 +pydantic>=2.10 # 2.9 leads to error on python 3.10 diff --git a/requirements/test.in b/requirements/test.in index e8f44059fcf87..85c96df8e8f4c 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -42,6 +42,7 @@ schemathesis>=3.39.15 # Required for openai schema test. bitsandbytes>=0.45.3 buildkite-test-collector==0.1.9 + genai_perf==0.0.8 tritonclient==2.51.0 @@ -51,4 +52,4 @@ numpy runai-model-streamer==0.11.0 runai-model-streamer-s3==0.11.0 fastsafetensors>=0.1.10 -pydantic>=2.10 # 2.9 leads to error on python 3.10 \ No newline at end of file +pydantic>=2.10 # 2.9 leads to error on python 3.10 diff --git a/tests/standalone_tests/pytorch_nightly_dependency.sh b/tests/standalone_tests/pytorch_nightly_dependency.sh new file mode 100644 index 0000000000000..cb531e13ecb81 --- /dev/null +++ b/tests/standalone_tests/pytorch_nightly_dependency.sh @@ -0,0 +1,42 @@ +#!/bin/sh +# This script tests if the nightly torch packages are not overridden by the dependencies + +set -e +set -x + +cd /vllm-workspace/ + +rm -rf .venv + +uv venv .venv + +source .venv/bin/activate + +# check the environment +uv pip freeze + +echo ">>> Installing nightly torch packages" +uv pip install --quiet torch torchvision torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu128 + +echo ">>> Capturing torch-related versions before requirements install" +uv pip freeze | grep -E '^torch|^torchvision|^torchaudio' | sort > before.txt +echo "Before:" +cat before.txt + +echo ">>> Installing requirements/nightly_torch_test.txt" +uv pip install --quiet -r requirements/nightly_torch_test.txt + +echo ">>> Capturing torch-related versions after requirements install" +uv pip freeze | grep -E '^torch|^torchvision|^torchaudio' | sort > after.txt +echo "After:" +cat after.txt + +echo ">>> Comparing versions" +if diff before.txt after.txt; then + echo "torch version not overridden." +else + echo "torch version overridden by nightly_torch_test.txt, \ + if the dependency is not triggered by the pytroch nightly test,\ + please add the dependency to the list 'white_list' in tools/generate_nightly_torch_test.py" + exit 1 +fi diff --git a/tools/generate_nightly_torch_test.py b/tools/generate_nightly_torch_test.py new file mode 100644 index 0000000000000..a3d7f7a609ba6 --- /dev/null +++ b/tools/generate_nightly_torch_test.py @@ -0,0 +1,34 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Generates specialized requirements files for nightly PyTorch testing. + +This script reads the main test requirements input file (`requirements/test.in`) +and splits its content into two files: +1. `requirements/nightly_torch_test.txt`: Contains dependencies +except PyTorch-related. +2. `torch_nightly_test.txt`: Contains only PyTorch-related packages. +""" + +input_file = "requirements/test.in" +output_file = "requirements/nightly_torch_test.txt" + +# white list of packages that are not compatible with PyTorch nightly directly +# with pip install. Please add your package to this list if it is not compatible +# or make the dependency test fails. +white_list = ["torch", "torchaudio", "torchvision", "mamba_ssm"] + +with open(input_file) as f: + lines = f.readlines() + +skip_next = False + +for line in lines: + if skip_next: + if line.startswith((" ", "\t")) or line.strip() == "": + continue + skip_next = False + + if any(k in line.lower() for k in white_list): + skip_next = True + continue From e11093068043a780ba8e778cdfcff8291d3f5b8c Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 27 Jun 2025 06:06:59 +0200 Subject: [PATCH 034/175] [Fix] Fix gemma CI test failing on main (#20124) Signed-off-by: Thomas Parnell --- .../models/language/generation/test_gemma.py | 29 ++++++++++++------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/tests/models/language/generation/test_gemma.py b/tests/models/language/generation/test_gemma.py index ed0f0c19a0411..5be4ae874e615 100644 --- a/tests/models/language/generation/test_gemma.py +++ b/tests/models/language/generation/test_gemma.py @@ -7,14 +7,21 @@ MODELS = ["google/gemma-2b", "google/gemma-2-2b", "google/gemma-3-4b-it"] @pytest.mark.parametrize("model", MODELS) -def test_dummy_loader(vllm_runner, model: str) -> None: - with vllm_runner( - model, - load_format="dummy", - ) as llm: - normalizers = llm.collective_rpc(lambda self: self.worker.model_runner. - model.model.normalizer.cpu().item()) - assert np.allclose( - normalizers, - llm.llm_engine.model_config.hf_config.hidden_size**0.5, - rtol=1e-3) +def test_dummy_loader(vllm_runner, monkeypatch, model: str) -> None: + with monkeypatch.context() as m: + m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") + with vllm_runner( + model, + load_format="dummy", + ) as llm: + if model == "google/gemma-3-4b-it": + normalizers = llm.model.collective_rpc( + lambda self: self.model_runner.model.language_model.model. + normalizer.cpu().item()) + config = llm.model.llm_engine.model_config.hf_config.text_config + else: + normalizers = llm.model.collective_rpc( + lambda self: self.model_runner.model.model.normalizer.cpu( + ).item()) + config = llm.model.llm_engine.model_config.hf_config + assert np.allclose(normalizers, config.hidden_size**0.5, rtol=2e-3) From cd4cfee68902dcad9498b3d9d4530b817499d592 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Fri, 27 Jun 2025 12:10:04 +0800 Subject: [PATCH 035/175] [Model][1/N] Automatic conversion of CrossEncoding model (#20012) Signed-off-by: wang.yuqi --- tests/models/language/pooling/mteb_utils.py | 11 +- vllm/config.py | 29 ++- vllm/model_executor/models/bert_with_rope.py | 149 +------------- vllm/model_executor/models/config.py | 200 +++++++++++++++++++ vllm/model_executor/models/qwen3.py | 17 +- 5 files changed, 239 insertions(+), 167 deletions(-) create mode 100644 vllm/model_executor/models/config.py diff --git a/tests/models/language/pooling/mteb_utils.py b/tests/models/language/pooling/mteb_utils.py index 21d55c418c363..0284e69f3f0e2 100644 --- a/tests/models/language/pooling/mteb_utils.py +++ b/tests/models/language/pooling/mteb_utils.py @@ -43,7 +43,7 @@ class VllmMtebEncoder(mteb.Encoder): # issues by randomizing the order. r = self.rng.permutation(len(sentences)) sentences = [sentences[i] for i in r] - outputs = self.model.encode(sentences, use_tqdm=False) + outputs = self.model.embed(sentences, use_tqdm=False) embeds = np.array(outputs) embeds = embeds[np.argsort(r)] return embeds @@ -250,16 +250,19 @@ def mteb_test_rerank_models(hf_runner, with vllm_runner(model_info.name, task="score", max_model_len=None, + max_num_seqs=8, **vllm_extra_kwargs) as vllm_model: + model_config = vllm_model.model.llm_engine.model_config + if model_info.architecture: - assert (model_info.architecture - in vllm_model.model.llm_engine.model_config.architectures) + assert (model_info.architecture in model_config.architectures) + assert model_config.hf_config.num_labels == 1 vllm_main_score = run_mteb_rerank(VllmMtebEncoder(vllm_model), tasks=MTEB_RERANK_TASKS, languages=MTEB_RERANK_LANGS) - vllm_dtype = vllm_model.model.llm_engine.model_config.dtype + vllm_dtype = model_config.dtype with hf_runner(model_info.name, is_cross_encoder=True, dtype="float32") as hf_model: diff --git a/vllm/config.py b/vllm/config.py index 856b361531168..7a3329aea5f78 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -569,6 +569,10 @@ class ModelConfig: else: self.truncation_side = "right" + model_info, arch = self.registry.inspect_model_cls(self.architectures) + self._model_info = model_info + self._architecture = arch + self.pooler_config = self._init_pooler_config() self.dtype = _get_and_verify_dtype( @@ -660,8 +664,18 @@ class ModelConfig: @property def architectures(self) -> list[str]: + # architectures in the model config. return getattr(self.hf_config, "architectures", []) + @property + def architecture(self) -> str: + # The architecture vllm actually used. + return self._architecture + + @property + def model_info(self) -> dict[str, Any]: + return self._model_info + def maybe_pull_model_tokenizer_for_s3(self, model: str, tokenizer: str) -> None: """Pull model/tokenizer from S3 to temporary directory when needed. @@ -4450,6 +4464,9 @@ class VllmConfig: def __post_init__(self): """Verify configs are valid & consistent with each other. """ + + self.try_verify_and_update_config() + if self.model_config is not None: self.model_config.verify_async_output_proc(self.parallel_config, self.speculative_config, @@ -4694,11 +4711,21 @@ class VllmConfig: batch_size_capture_list) def recalculate_max_model_len(self, max_model_len: int): + # Can only be called in try_verify_and_update_config model_config = self.model_config max_model_len = model_config.get_and_verify_max_len(max_model_len) self.model_config.max_model_len = max_model_len self.scheduler_config.max_model_len = max_model_len - self.compute_hash() + + def try_verify_and_update_config(self): + architecture = getattr(self.model_config, "architecture", None) + if architecture is None: + return + + from vllm.model_executor.models.config import MODELS_CONFIG_MAP + cls = MODELS_CONFIG_MAP.get(architecture, None) + if cls is not None: + cls.verify_and_update_config(self) def __str__(self): return ( diff --git a/vllm/model_executor/models/bert_with_rope.py b/vllm/model_executor/models/bert_with_rope.py index 0f22393c79d98..0b7350f07d3f6 100644 --- a/vllm/model_executor/models/bert_with_rope.py +++ b/vllm/model_executor/models/bert_with_rope.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable -from copy import deepcopy from typing import Optional import torch @@ -12,7 +11,6 @@ from vllm.attention import Attention, AttentionType from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.logger import init_logger from vllm.model_executor.layers.activation import (get_act_and_mul_fn, get_act_fn) from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -30,8 +28,6 @@ from vllm.model_executor.models.interfaces import SupportsQuant from vllm.model_executor.models.utils import WeightsMapper from vllm.sequence import IntermediateTensors -logger = init_logger(__name__) - class BertWithRopeEmbedding(nn.Module): @@ -408,7 +404,7 @@ class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() self.vllm_config = vllm_config - self.config = self.config_verify(vllm_config) + self.config = vllm_config.model_config.hf_config self.embeddings = BertWithRopeEmbedding(self.config) self.encoder = BertWithRopeEncoder( vllm_config=vllm_config, @@ -416,9 +412,6 @@ class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant): rotary_kwargs=self.config.rotary_kwargs, prefix=f"{prefix}.encoder") - def config_verify(self, vllm_config): - raise NotImplementedError - def forward( self, input_ids: Optional[torch.Tensor], @@ -490,95 +483,6 @@ class NomicBertModel(BertWithRope): "norm2": "mlp_ln", }) - def config_verify(self, vllm_config): - config = vllm_config.model_config.hf_config - - assert config.__class__.__name__ == "NomicBertConfig" - assert config.activation_function in ["swiglu", "gelu"] - config.position_embedding_type = getattr(config, - "position_embedding_type", - "rope") - - if config.activation_function == "swiglu": - config.hidden_act = "silu" - else: - config.hidden_act = config.activation_function - - assert (config.mlp_fc1_bias == config.mlp_fc2_bias == - config.qkv_proj_bias) - config.bias = config.qkv_proj_bias - - assert config.rotary_emb_scale_base is None - assert not config.rotary_emb_interleaved - - config.layer_norm_eps = config.layer_norm_epsilon - config.intermediate_size = config.n_inner - config.hidden_size = config.n_embd - config.num_hidden_layers = config.n_layer - - head_dim = config.hidden_size // config.num_attention_heads - rotary_emb_dim = head_dim * config.rotary_emb_fraction - max_trained_positions = getattr(config, "max_trained_positions", 2048) - config.rotary_kwargs = { - "head_size": head_dim, - "rotary_dim": rotary_emb_dim, - "max_position": max_trained_positions, - "base": getattr(config, "rope_theta", config.rotary_emb_base), - "rope_scaling": getattr(config, "rope_scaling", None) - } - - # we ignore config.rotary_scaling_factor so that for datasets shorter - # than max_trained_positions 2048, the results are consistent - # with SentenceTransformer. - # The context extension uses vllm style rope_theta and rope_scaling. - # See #17785 #18755 - if (not vllm_config.model_config.hf_overrides - and vllm_config.model_config.original_max_model_len is None): - # Default - # Reset max_model_len to max_trained_positions. - # nomic-embed-text-v2-moe the length is set to 512 - # by sentence_bert_config.json. - max_model_len_before = vllm_config.model_config.max_model_len - max_model_len = min(vllm_config.model_config.max_model_len, - max_trained_positions) - - vllm_config.recalculate_max_model_len(max_model_len) - logger.warning( - "Nomic context extension is disabled. " - "Changing max_model_len from %s to %s. " - "To enable context extension, see: " - "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/context_extension.html", - max_model_len_before, vllm_config.model_config.max_model_len) - else: - # We need to re-verify max_model_len to avoid lengths - # greater than position_embedding. - model_config = vllm_config.model_config - hf_text_config = model_config.hf_text_config - - if isinstance(model_config.hf_overrides, dict): - # hf_overrides_kw - max_model_len = model_config.hf_overrides.get( - "max_model_len", vllm_config.model_config.max_model_len) - else: - # hf_overrides_fn - # This might be overridden by sentence_bert_config.json. - max_model_len = vllm_config.model_config.max_model_len - - # reset hf_text_config for recalculate_max_model_len. - if hasattr(hf_text_config, "max_model_len"): - delattr(hf_text_config, "max_model_len") - hf_text_config.max_position_embeddings = max_trained_positions - hf_text_config.rope_scaling = config.rotary_kwargs["rope_scaling"] - - # The priority of sentence_bert_config.json is higher - # than max_position_embeddings - encoder_config = deepcopy(model_config.encoder_config) - encoder_config.pop("max_seq_length", None) - model_config.encoder_config = encoder_config - - vllm_config.recalculate_max_model_len(max_model_len) - return config - class GteNewModel(BertWithRope): # for https://huggingface.co/Alibaba-NLP/new-impl @@ -600,24 +504,6 @@ class GteNewModel(BertWithRope): layer.mlp.gate_up_proj.bias = None layer.mlp.gate_up_proj.skip_bias_add = True - def config_verify(self, vllm_config): - config = vllm_config.model_config.hf_config - - assert config.__class__.__name__ == "NewConfig" - assert config.hidden_act == "gelu" - - config.hidden_act = "geglu" - - head_dim = config.hidden_size // config.num_attention_heads - config.rotary_kwargs = { - "head_size": head_dim, - "rotary_dim": getattr(config, "rotary_emb_dim", head_dim), - "max_position": config.max_position_embeddings, - "base": config.rope_theta, - "rope_scaling": getattr(config, "rope_scaling", None) - } - return config - def split_up_gate_proj(self, weights: Iterable[tuple[str, torch.Tensor]]): n = "mlp.up_gate_proj" for name, weight in weights: @@ -652,24 +538,6 @@ class SnowflakeGteNewModel(GteNewModel): "attention.o_proj": "attn.out_proj", }) - def config_verify(self, vllm_config): - config = vllm_config.model_config.hf_config - - assert config.__class__.__name__ == "GteConfig" - assert config.hidden_act == "gelu" - - config.hidden_act = "geglu" - - head_dim = config.hidden_size // config.num_attention_heads - config.rotary_kwargs = { - "head_size": head_dim, - "rotary_dim": getattr(config, "rotary_emb_dim", head_dim), - "max_position": config.max_position_embeddings, - "base": config.rope_theta, - "rope_scaling": getattr(config, "rope_scaling", None) - } - return config - class JinaRobertaModel(BertWithRope): # for https://huggingface.co/jinaai/jina-embeddings-v3 @@ -685,21 +553,6 @@ class JinaRobertaModel(BertWithRope): "norm2": "mlp_ln", }) - def config_verify(self, vllm_config): - config = vllm_config.model_config.hf_config - - assert config.__class__.__name__ == "XLMRobertaFlashConfig" - - head_dim = config.hidden_size // config.num_attention_heads - config.rotary_kwargs = { - "head_size": head_dim, - "rotary_dim": getattr(config, "rotary_emb_dim", head_dim), - "max_position": config.max_position_embeddings, - "base": getattr(config, "rope_theta", config.rotary_emb_base), - "rope_scaling": getattr(config, "rope_scaling", None) - } - return config - def forward( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py new file mode 100644 index 0000000000000..7b5345704ad00 --- /dev/null +++ b/vllm/model_executor/models/config.py @@ -0,0 +1,200 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from copy import deepcopy +from typing import TYPE_CHECKING + +from vllm.logger import init_logger + +if TYPE_CHECKING: + from vllm.config import VllmConfig + +logger = init_logger(__name__) + + +class VerifyAndUpdateConfig: + + @staticmethod + def verify_and_update_config(vllm_config: "VllmConfig") -> None: + raise NotImplementedError + + +class GteNewModelConfig(VerifyAndUpdateConfig): + + @staticmethod + def verify_and_update_config(vllm_config: "VllmConfig") -> None: + config = vllm_config.model_config.hf_config + + assert config.__class__.__name__ == "NewConfig" + assert config.hidden_act == "gelu" + + config.hidden_act = "geglu" + + head_dim = config.hidden_size // config.num_attention_heads + config.rotary_kwargs = { + "head_size": head_dim, + "rotary_dim": getattr(config, "rotary_emb_dim", head_dim), + "max_position": config.max_position_embeddings, + "base": config.rope_theta, + "rope_scaling": getattr(config, "rope_scaling", None) + } + + +class JinaRobertaModelConfig(VerifyAndUpdateConfig): + + @staticmethod + def verify_and_update_config(vllm_config: "VllmConfig") -> None: + config = vllm_config.model_config.hf_config + + if config.position_embedding_type == "rotary": + assert config.__class__.__name__ == "XLMRobertaFlashConfig" + + head_dim = config.hidden_size // config.num_attention_heads + config.rotary_kwargs = { + "head_size": head_dim, + "rotary_dim": getattr(config, "rotary_emb_dim", head_dim), + "max_position": config.max_position_embeddings, + "base": getattr(config, "rope_theta", config.rotary_emb_base), + "rope_scaling": getattr(config, "rope_scaling", None) + } + + +class NomicBertModelConfig(VerifyAndUpdateConfig): + + @staticmethod + def verify_and_update_config(vllm_config: "VllmConfig") -> None: + config = vllm_config.model_config.hf_config + + assert config.__class__.__name__ == "NomicBertConfig" + assert config.activation_function in ["swiglu", "gelu"] + config.position_embedding_type = getattr(config, + "position_embedding_type", + "rope") + + if config.activation_function == "swiglu": + config.hidden_act = "silu" + else: + config.hidden_act = config.activation_function + + assert (config.mlp_fc1_bias == config.mlp_fc2_bias == + config.qkv_proj_bias) + config.bias = config.qkv_proj_bias + + assert config.rotary_emb_scale_base is None + assert not config.rotary_emb_interleaved + + config.layer_norm_eps = config.layer_norm_epsilon + config.intermediate_size = config.n_inner + config.hidden_size = config.n_embd + config.num_hidden_layers = config.n_layer + + head_dim = config.hidden_size // config.num_attention_heads + rotary_emb_dim = head_dim * config.rotary_emb_fraction + max_trained_positions = getattr(config, "max_trained_positions", 2048) + config.rotary_kwargs = { + "head_size": head_dim, + "rotary_dim": rotary_emb_dim, + "max_position": max_trained_positions, + "base": getattr(config, "rope_theta", config.rotary_emb_base), + "rope_scaling": getattr(config, "rope_scaling", None) + } + + # we ignore config.rotary_scaling_factor so that for datasets shorter + # than max_trained_positions 2048, the results are consistent + # with SentenceTransformer. + # The context extension uses vllm style rope_theta and rope_scaling. + # See #17785 #18755 + if (not vllm_config.model_config.hf_overrides + and vllm_config.model_config.original_max_model_len is None): + # Default + # Reset max_model_len to max_trained_positions. + # nomic-embed-text-v2-moe the length is set to 512 + # by sentence_bert_config.json. + max_model_len_before = vllm_config.model_config.max_model_len + max_model_len = min(vllm_config.model_config.max_model_len, + max_trained_positions) + + vllm_config.recalculate_max_model_len(max_model_len) + logger.warning( + "Nomic context extension is disabled. " + "Changing max_model_len from %s to %s. " + "To enable context extension, see: " + "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/context_extension.html", + max_model_len_before, vllm_config.model_config.max_model_len) + else: + # We need to re-verify max_model_len to avoid lengths + # greater than position_embedding. + model_config = vllm_config.model_config + hf_text_config = model_config.hf_text_config + + if isinstance(model_config.hf_overrides, dict): + # hf_overrides_kw + max_model_len = model_config.hf_overrides.get( + "max_model_len", vllm_config.model_config.max_model_len) + else: + # hf_overrides_fn + # This might be overridden by sentence_bert_config.json. + max_model_len = vllm_config.model_config.max_model_len + + # reset hf_text_config for recalculate_max_model_len. + if hasattr(hf_text_config, "max_model_len"): + delattr(hf_text_config, "max_model_len") + hf_text_config.max_position_embeddings = max_trained_positions + hf_text_config.rope_scaling = config.rotary_kwargs["rope_scaling"] + + # The priority of sentence_bert_config.json is higher + # than max_position_embeddings + encoder_config = deepcopy(model_config.encoder_config) + encoder_config.pop("max_seq_length", None) + model_config.encoder_config = encoder_config + + vllm_config.recalculate_max_model_len(max_model_len) + + +class Qwen3ForSequenceClassificationConfig(VerifyAndUpdateConfig): + + @staticmethod + def verify_and_update_config(vllm_config: "VllmConfig") -> None: + config = vllm_config.model_config.hf_config + + is_original_qwen3_reranker = getattr(config, + "is_original_qwen3_reranker", + False) + + if not is_original_qwen3_reranker: + return + + tokens = getattr(config, "classifier_from_token", None) + assert tokens is not None and len(tokens) == 2, \ + ("Try loading the original Qwen3 Reranker?, see: " + "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py") + config.num_labels = 1 + + +class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig): + + @staticmethod + def verify_and_update_config(vllm_config: "VllmConfig") -> None: + config = vllm_config.model_config.hf_config + + assert config.__class__.__name__ == "GteConfig" + assert config.hidden_act == "gelu" + + config.hidden_act = "geglu" + + head_dim = config.hidden_size // config.num_attention_heads + config.rotary_kwargs = { + "head_size": head_dim, + "rotary_dim": getattr(config, "rotary_emb_dim", head_dim), + "max_position": config.max_position_embeddings, + "base": config.rope_theta, + "rope_scaling": getattr(config, "rope_scaling", None) + } + + +MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = { + "GteModel": SnowflakeGteNewModelConfig, + "GteNewModel": GteNewModelConfig, + "NomicBertModel": NomicBertModelConfig, + "Qwen3ForSequenceClassification": Qwen3ForSequenceClassificationConfig, + "XLMRobertaModel": JinaRobertaModelConfig, +} diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py index 216c1f1c7ff74..1224ba7abc756 100644 --- a/vllm/model_executor/models/qwen3.py +++ b/vllm/model_executor/models/qwen3.py @@ -400,22 +400,10 @@ class Qwen3ForSequenceClassification(nn.Module, SupportsLoRA, def load_weights_from_original_qwen3_reranker( self, weights: Iterable[tuple[str, torch.Tensor]]): - tokens = getattr(self.config, "classifier_from_token", None) - assert tokens is not None and len(tokens) == 2, \ - ("Try loading the original Qwen3 Reranker?, see: " - "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py") - self.config.num_labels = 1 model_config = self.vllm_config.model_config - + tokens = getattr(self.config, "classifier_from_token", None) device = self.score.weight.device - self.score = RowParallelLinear(self.config.hidden_size, - self.config.num_labels, - quant_config=self.quant_config, - input_is_parallel=False, - bias=False, - prefix=maybe_prefix( - self.prefix, "score")).to(device) if self.config.tie_word_embeddings: self.lm_head = self.model.embed_tokens @@ -443,5 +431,6 @@ class Qwen3ForSequenceClassification(nn.Module, SupportsLoRA, self.score.weight.data.copy_(weight) del self.lm_head - loaded_weights.add("classifier.weight") + loaded_weights.add("score.weight") loaded_weights.discard("lm_head.weight") + return loaded_weights From 6e244ae09121b2c1cdcd4db51076decc4a724c5c Mon Sep 17 00:00:00 2001 From: Yazan Sharaya <97323283+Yazan-Sharaya@users.noreply.github.com> Date: Fri, 27 Jun 2025 07:44:14 +0300 Subject: [PATCH 036/175] [Perf][Frontend] eliminate api_key and x_request_id headers middleware overhead (#19946) Signed-off-by: Yazan-Sharaya --- docs/serving/openai_compatible_server.md | 5 - .../openai/test_optional_middleware.py | 116 ++++++++++++++++++ vllm/entrypoints/openai/api_server.py | 100 +++++++++++---- vllm/entrypoints/openai/cli_args.py | 2 +- 4 files changed, 190 insertions(+), 33 deletions(-) create mode 100644 tests/entrypoints/openai/test_optional_middleware.py diff --git a/docs/serving/openai_compatible_server.md b/docs/serving/openai_compatible_server.md index 00756e719992d..a3f1ef9fd8b6b 100644 --- a/docs/serving/openai_compatible_server.md +++ b/docs/serving/openai_compatible_server.md @@ -146,11 +146,6 @@ completion = client.chat.completions.create( Only `X-Request-Id` HTTP request header is supported for now. It can be enabled with `--enable-request-id-headers`. -> Note that enablement of the headers can impact performance significantly at high QPS -> rates. We recommend implementing HTTP headers at the router level (e.g. via Istio), -> rather than within the vLLM layer for this reason. -> See [this PR](https://github.com/vllm-project/vllm/pull/11529) for more details. - ??? Code ```python diff --git a/tests/entrypoints/openai/test_optional_middleware.py b/tests/entrypoints/openai/test_optional_middleware.py new file mode 100644 index 0000000000000..882fa0886ce30 --- /dev/null +++ b/tests/entrypoints/openai/test_optional_middleware.py @@ -0,0 +1,116 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Tests for middleware that's off by default and can be toggled through +server arguments, mainly --api-key and --enable-request-id-headers. +""" + +from http import HTTPStatus + +import pytest +import requests + +from ...utils import RemoteOpenAIServer + +# Use a small embeddings model for faster startup and smaller memory footprint. +# Since we are not testing any chat functionality, +# using a chat capable model is overkill. +MODEL_NAME = "intfloat/multilingual-e5-small" + + +@pytest.fixture(scope="module") +def server(request: pytest.FixtureRequest): + passed_params = [] + if hasattr(request, "param"): + passed_params = request.param + if isinstance(passed_params, str): + passed_params = [passed_params] + + args = [ + "--task", + "embed", + # use half precision for speed and memory savings in CI environment + "--dtype", + "float16", + "--max-model-len", + "512", + "--enforce-eager", + "--max-num-seqs", + "2", + *passed_params + ] + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest.mark.asyncio +async def test_no_api_token(server: RemoteOpenAIServer): + response = requests.get(server.url_for("v1/models")) + assert response.status_code == HTTPStatus.OK + + +@pytest.mark.asyncio +async def test_no_request_id_header(server: RemoteOpenAIServer): + response = requests.get(server.url_for("health")) + assert "X-Request-Id" not in response.headers + + +@pytest.mark.parametrize( + "server", + [["--api-key", "test"]], + indirect=True, +) +@pytest.mark.asyncio +async def test_missing_api_token(server: RemoteOpenAIServer): + response = requests.get(server.url_for("v1/models")) + assert response.status_code == HTTPStatus.UNAUTHORIZED + + +@pytest.mark.parametrize( + "server", + [["--api-key", "test"]], + indirect=True, +) +@pytest.mark.asyncio +async def test_passed_api_token(server: RemoteOpenAIServer): + response = requests.get(server.url_for("v1/models"), + headers={"Authorization": "Bearer test"}) + assert response.status_code == HTTPStatus.OK + + +@pytest.mark.parametrize( + "server", + [["--api-key", "test"]], + indirect=True, +) +@pytest.mark.asyncio +async def test_not_v1_api_token(server: RemoteOpenAIServer): + # Authorization check is skipped for any paths that + # don't start with /v1 (e.g. /v1/chat/completions). + response = requests.get(server.url_for("health")) + assert response.status_code == HTTPStatus.OK + + +@pytest.mark.parametrize( + "server", + ["--enable-request-id-headers"], + indirect=True, +) +@pytest.mark.asyncio +async def test_enable_request_id_header(server: RemoteOpenAIServer): + response = requests.get(server.url_for("health")) + assert "X-Request-Id" in response.headers + assert len(response.headers.get("X-Request-Id", "")) == 32 + + +@pytest.mark.parametrize( + "server", + ["--enable-request-id-headers"], + indirect=True, +) +@pytest.mark.asyncio +async def test_custom_request_id_header(server: RemoteOpenAIServer): + response = requests.get(server.url_for("health"), + headers={"X-Request-Id": "Custom"}) + assert "X-Request-Id" in response.headers + assert response.headers.get("X-Request-Id") == "Custom" diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 681633a2aff77..f3fd154862711 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -14,7 +14,7 @@ import socket import tempfile import uuid from argparse import Namespace -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Awaitable from contextlib import asynccontextmanager from functools import partial from http import HTTPStatus @@ -30,8 +30,9 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse from prometheus_client import make_asgi_app from prometheus_fastapi_instrumentator import Instrumentator from starlette.concurrency import iterate_in_threadpool -from starlette.datastructures import State +from starlette.datastructures import URL, Headers, MutableHeaders, State from starlette.routing import Mount +from starlette.types import ASGIApp, Message, Receive, Scope, Send from typing_extensions import assert_never import vllm.envs as envs @@ -1061,6 +1062,74 @@ def load_log_config(log_config_file: Optional[str]) -> Optional[dict]: return None +class AuthenticationMiddleware: + """ + Pure ASGI middleware that authenticates each request by checking + if the Authorization header exists and equals "Bearer {api_key}". + + Notes + ----- + There are two cases in which authentication is skipped: + 1. The HTTP method is OPTIONS. + 2. The request path doesn't start with /v1 (e.g. /health). + """ + + def __init__(self, app: ASGIApp, api_token: str) -> None: + self.app = app + self.api_token = api_token + + def __call__(self, scope: Scope, receive: Receive, + send: Send) -> Awaitable[None]: + if scope["type"] not in ("http", + "websocket") or scope["method"] == "OPTIONS": + # scope["type"] can be "lifespan" or "startup" for example, + # in which case we don't need to do anything + return self.app(scope, receive, send) + root_path = scope.get("root_path", "") + url_path = URL(scope=scope).path.removeprefix(root_path) + headers = Headers(scope=scope) + # Type narrow to satisfy mypy. + if url_path.startswith("/v1") and headers.get( + "Authorization") != f"Bearer {self.api_token}": + response = JSONResponse(content={"error": "Unauthorized"}, + status_code=401) + return response(scope, receive, send) + return self.app(scope, receive, send) + + +class XRequestIdMiddleware: + """ + Middleware the set's the X-Request-Id header for each response + to a random uuid4 (hex) value if the header isn't already + present in the request, otherwise use the provided request id. + """ + + def __init__(self, app: ASGIApp) -> None: + self.app = app + + def __call__(self, scope: Scope, receive: Receive, + send: Send) -> Awaitable[None]: + if scope["type"] not in ("http", "websocket"): + return self.app(scope, receive, send) + + # Extract the request headers. + request_headers = Headers(scope=scope) + + async def send_with_request_id(message: Message) -> None: + """ + Custom send function to mutate the response headers + and append X-Request-Id to it. + """ + if message["type"] == "http.response.start": + response_headers = MutableHeaders(raw=message["headers"]) + request_id = request_headers.get("X-Request-Id", + uuid.uuid4().hex) + response_headers.append("X-Request-Id", request_id) + await send(message) + + return self.app(scope, receive, send_with_request_id) + + def build_app(args: Namespace) -> FastAPI: if args.disable_fastapi_docs: app = FastAPI(openapi_url=None, @@ -1108,33 +1177,10 @@ def build_app(args: Namespace) -> FastAPI: # Ensure --api-key option from CLI takes precedence over VLLM_API_KEY if token := args.api_key or envs.VLLM_API_KEY: - - @app.middleware("http") - async def authentication(request: Request, call_next): - if request.method == "OPTIONS": - return await call_next(request) - url_path = request.url.path - if app.root_path and url_path.startswith(app.root_path): - url_path = url_path[len(app.root_path):] - if not url_path.startswith("/v1"): - return await call_next(request) - if request.headers.get("Authorization") != "Bearer " + token: - return JSONResponse(content={"error": "Unauthorized"}, - status_code=401) - return await call_next(request) + app.add_middleware(AuthenticationMiddleware, api_token=token) if args.enable_request_id_headers: - logger.warning( - "CAUTION: Enabling X-Request-Id headers in the API Server. " - "This can harm performance at high QPS.") - - @app.middleware("http") - async def add_request_id(request: Request, call_next): - request_id = request.headers.get( - "X-Request-Id") or uuid.uuid4().hex - response = await call_next(request) - response.headers["X-Request-Id"] = request_id - return response + app.add_middleware(XRequestIdMiddleware) if envs.VLLM_DEBUG_LOG_API_SERVER_RESPONSE: logger.warning("CAUTION: Enabling log response in the API Server. " diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index dd4bd53046a35..f9bec84518688 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -216,7 +216,7 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "--enable-request-id-headers", action="store_true", help="If specified, API server will add X-Request-Id header to " - "responses. Caution: this hurts performance at high QPS.") + "responses.") parser.add_argument( "--enable-auto-tool-choice", action="store_true", From dec197e3e5d14e1d4fbad61b565e151f52976c0f Mon Sep 17 00:00:00 2001 From: "Chendi.Xue" Date: Fri, 27 Jun 2025 00:48:13 -0500 Subject: [PATCH 037/175] Quick Fix by adding conditional import for flash_attn_varlen_func in flash_attn (#20143) Signed-off-by: Chendi.Xue --- vllm/attention/utils/fa_utils.py | 4 ++++ vllm/v1/attention/backends/flash_attn.py | 10 +++++++--- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/vllm/attention/utils/fa_utils.py b/vllm/attention/utils/fa_utils.py index 36fd2d231bc5f..f8b00565f0517 100644 --- a/vllm/attention/utils/fa_utils.py +++ b/vllm/attention/utils/fa_utils.py @@ -66,3 +66,7 @@ def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]: def flash_attn_supports_fp8() -> bool: return get_flash_attn_version() == 3 and \ current_platform.get_device_capability().major == 9 + + +def is_flash_attn_varlen_func_available() -> bool: + return current_platform.is_cuda() or current_platform.is_xpu() diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 42b5997f085b1..527b31153410b 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -14,10 +14,14 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.layer import Attention from vllm.attention.ops.merge_attn_states import merge_attn_states from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8, - flash_attn_varlen_func, get_flash_attn_version, - get_scheduler_metadata, - reshape_and_cache_flash) + is_flash_attn_varlen_func_available) + +if is_flash_attn_varlen_func_available(): + from vllm.attention.utils.fa_utils import (flash_attn_varlen_func, + get_scheduler_metadata, + reshape_and_cache_flash) + from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.logger import init_logger from vllm.utils import cdiv From d1c956dc0f8b64af7a43e23f9fc2850756dea1cb Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Date: Fri, 27 Jun 2025 03:16:26 -0400 Subject: [PATCH 038/175] Gemma3n (Text-only) (#20134) Signed-off-by: rshaw@neuralmagic.com Signed-off-by: Roger Wang Co-authored-by: Roger Wang --- docs/models/supported_models.md | 4 + tests/models/registry.py | 2 + vllm/model_executor/layers/activation.py | 51 ++ vllm/model_executor/models/gemma3n.py | 811 +++++++++++++++++++++++ vllm/model_executor/models/registry.py | 2 + 5 files changed, 870 insertions(+) create mode 100644 vllm/model_executor/models/gemma3n.py diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 04d9923f92105..9782fd1781512 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -336,6 +336,7 @@ Specified using `--task generate`. | `GemmaForCausalLM` | Gemma | `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Gemma2ForCausalLM` | Gemma 2 | `google/gemma-2-9b`, `google/gemma-2-27b`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Gemma3nForConditionalGeneration` | Gemma 3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ | | `GlmForCausalLM` | GLM-4 | `THUDM/glm-4-9b-chat-hf`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Glm4ForCausalLM` | GLM-4-0414 | `THUDM/GLM-4-32B-0414`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GPT2LMHeadModel` | GPT-2 | `gpt2`, `gpt2-xl`, etc. | | ✅︎ | ✅︎ | @@ -392,6 +393,9 @@ Specified using `--task generate`. !!! note Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096. +!!! note + Only text inputs are currently supported for `Gemma3nForConditionalGeneration`. To use this model, please upgrade Hugging Face Transformers to version 4.53.0. + ### Pooling Models See [this page](./pooling_models.md) for more information on how to use pooling models. diff --git a/tests/models/registry.py b/tests/models/registry.py index 4a587e39ad4cd..1bcb4f88a30ff 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -164,6 +164,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"), "Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"), "Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"), + "Gemma3nForConditionalGeneration": _HfExamplesInfo("google/gemma-3n-E2B-it", # noqa: E501 + min_transformers_version="4.53"), "GlmForCausalLM": _HfExamplesInfo("THUDM/glm-4-9b-chat-hf"), "Glm4ForCausalLM": _HfExamplesInfo("THUDM/GLM-4-9B-0414"), "GPT2LMHeadModel": _HfExamplesInfo("openai-community/gpt2", diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index cc9c8d445ab6c..1fd96fe405b9a 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -135,6 +135,57 @@ class MulAndSilu(CustomOp): # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: +@CustomOp.register("gelu_and_mul_sparse") +class GeluAndMulSparse(CustomOp): + """An activation function for GeluAndMulSparse. + This activation function is used in Gemma3n. It computes: + up_proj = self.up_proj(x) + gate_proj = self.gate_proj(x) + gate_proj = self._gaussian_topk(gate_proj) # sparsity + activations = self.act_fn(gate_proj) # gelu + down_proj = self.down_proj(activations * up_proj) + Shapes: + x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d) + return: (num_tokens, d) or (batch_size, seq_len, d) + """ + + def __init__(self, activation_sparsity: float, approximate: str = "none"): + super().__init__() + # Gelu. + self.approximate = approximate + if approximate not in ("none", "tanh"): + raise ValueError(f"Unknown approximate mode: {approximate}") + + # Sparsity. + if activation_sparsity == 0.0: + raise ValueError( + "activation_sparsity is 0.0. Please use GeluAndMul.") + target_sparsity_tensor = torch.tensor(activation_sparsity, + dtype=torch.float32) + normal_dist = torch.distributions.normal.Normal(0, 1) + self.std_multiplier = normal_dist.icdf(target_sparsity_tensor) + + def _gaussian_topk(self, x: torch.Tensor) -> torch.Tensor: + """Get % sparse percentile of the Gaussian distribution.""" + # NOTE(rob): for TP>1, we could all-gather to get the means/std. + # But we do not do this because in expectation they are the same + # and in practice the eval scores are good without gathering. + mean = torch.mean(x, dim=-1, keepdim=True) + std = torch.std(x, dim=-1, keepdim=True, unbiased=False) + cutoff_x = mean + std * self.std_multiplier + return nn.functional.relu(x - cutoff_x) + + def forward_native(self, x: torch.Tensor) -> torch.Tensor: + """PyTorch-native implementation equivalent to forward().""" + d = x.shape[-1] // 2 + out = self._gaussian_topk(x[..., :d]) + out = F.gelu(out, approximate=self.approximate) + return out * x[..., d:] + + def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: + return self.forward_native(x) + + @CustomOp.register("gelu_and_mul") class GeluAndMul(CustomOp): """An activation function for GeGLU. diff --git a/vllm/model_executor/models/gemma3n.py b/vllm/model_executor/models/gemma3n.py new file mode 100644 index 0000000000000..7d163320e0d6a --- /dev/null +++ b/vllm/model_executor/models/gemma3n.py @@ -0,0 +1,811 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright 2025 The vLLM team. +# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections.abc import Iterable +from typing import Optional, Union + +import torch +from torch import nn +from transformers.models.gemma3n.configuration_gemma3n import Gemma3nTextConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import (_ACTIVATION_REGISTRY, + GeluAndMul, + GeluAndMulSparse) +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .utils import (AutoWeightsLoader, extract_layer_index, + is_pp_missing_parameter, make_layers, maybe_prefix) + +logger = init_logger(__name__) + + +class Gemma3nAltUp(nn.Module): + """Alternating updates (Altup) + The AltUp module wraps transformer layers. The `predict` step modifies the + input to the transformer layer, and the `correct` step propagates the output + of the transformer layer to the sparsely updated dimensions. + See more in the research paper: + https://proceedings.neurips.cc/paper_files/paper/2023/file/f2059277ac6ce66e7e5543001afa8bb5-Paper-Conference.pdf + """ + + def __init__( + self, + hidden_size: int, + rms_norm_eps: float, + altup_num_inputs: int, + altup_coef_clip: float, + altup_active_idx: int, + prefix: str, + ): + super().__init__() + + self.altup_num_inputs = altup_num_inputs + self.altup_active_idx = altup_active_idx + self.altup_coef_clip = altup_coef_clip + + self.correction_coefs = ReplicatedLinear( + altup_num_inputs, + altup_num_inputs, + bias=False, + prefix=f"{prefix}.correction_coefs", + return_bias=False, + ) + self.prediction_coefs = ReplicatedLinear( + altup_num_inputs, + altup_num_inputs**2, + bias=False, + prefix=f"{prefix}.prediction_coefs", + return_bias=False, + ) + self.modality_router = ReplicatedLinear( + hidden_size, + altup_num_inputs, + bias=False, + prefix=f"{prefix}.modality_router", + return_bias=False, + ) + self.router_norm = RMSNorm( + hidden_size=hidden_size, + eps=rms_norm_eps, + ) + self.router_input_scale = torch.tensor( + hidden_size**-1.0, dtype=self.modality_router.weight.dtype) + self.correct_output_scale = nn.Parameter( + torch.zeros(hidden_size, dtype=torch.float32)) + + def _compute_router_modalities(self, x: torch.Tensor) -> torch.Tensor: + router_inputs = self.router_norm(x) * self.router_input_scale + routed = self.modality_router(router_inputs) + return torch.tanh(routed.float()).type_as(x) + + def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor: + return (corrected.type_as(self.correct_output_scale) * + self.correct_output_scale).type_as(corrected) + + def predict(self, hidden_states: torch.Tensor) -> torch.Tensor: + # hidden: [altup_num_inputs, num_tokens, hidden_size] + # modalities: [num_tokens, num_altup_inputs] + # all_coefs: [num_tokens, num_altup_inputs ** 2] + modalities = self._compute_router_modalities( + hidden_states[self.altup_active_idx]) + all_coefs = self.prediction_coefs(modalities) + + # Reshape and transpose the 2D matrix for the matmul. + # all_coefs_T: [num_tokens, num_altup_inputs, num_altup_inputs] + all_coefs_T = all_coefs.reshape( + -1, + self.altup_num_inputs, + self.altup_num_inputs, + ).permute(0, 2, 1) + + # hidden_states to [num_tokens, hidden_size, altup_num_inputs] + predictions = torch.matmul(hidden_states.permute(1, 2, 0), all_coefs_T) + # [altup_num_inputs, num_tokens, hidden_size] + predictions = predictions.permute(2, 0, 1) + predictions += hidden_states + return predictions.contiguous() + + def correct(self, predictions: torch.Tensor, + activated: torch.Tensor) -> torch.Tensor: + # predictions: [altup_num_inputs, num_tokens, hidden_size] + # activated: [num_tokens, hidden_size] + # modalities: [num_tokens, altup_num_inputs] + modalities = self._compute_router_modalities(activated) + # innovation: [num_tokens, altup_num_inputs] + innovation = activated - predictions[self.altup_active_idx] + # innovation: [altup_num_inputs, num_tokens, hidden_size] + innovation = innovation.repeat(self.altup_num_inputs, 1, 1) + + # Permute to [altup_num_inputs, num_tokens] as the last dim + # is a scalar applied to each altup input and expand on + # num_tokens dim for broadcastability over hidden_size. + # all_coefs: [num_tokens, altup_num_inputs] + all_coefs = self.correction_coefs(modalities) + 1.0 + # all_coefs: [altup_num_inputs, num_tokens, 1] + all_coefs = all_coefs.T.unsqueeze(-1) + + # Elementwise (broadcast over hidden_size). + corrected = torch.mul(innovation, all_coefs) + corrected += predictions + + return corrected.contiguous() + + +class Gemma3nLaurelBlock(nn.Module): + """Learned Augmented Residual Layer""" + + def __init__(self, hidden_size: int, laurel_rank: int, rms_norm_eps: float, + prefix: str): + super().__init__() + + self.linear_left = ColumnParallelLinear( + hidden_size, + laurel_rank, + bias=False, + prefix=f"{prefix}.linear_left", + return_bias=False, + ) + self.linear_right = RowParallelLinear(laurel_rank, + hidden_size, + bias=False, + prefix=f"{prefix}.linear_right", + return_bias=False) + self.post_laurel_norm = RMSNorm( + hidden_size=hidden_size, + eps=rms_norm_eps, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + laurel_x = self.linear_left(x) + laurel_x = self.linear_right(laurel_x) + normed_laurel_x = self.post_laurel_norm(laurel_x) + return x + normed_laurel_x + + +class Gemma3nMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_activation: str, + activation_sparsity: float = 0.0, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + if hidden_activation != "gelu_pytorch_tanh": + raise ValueError( + "Gemma3 uses `gelu_pytorch_tanh` as the hidden activation " + "function. Please set `hidden_act` and `hidden_activation` to " + "`gelu_pytorch_tanh`.") + + self.act_fn = GeluAndMulSparse( + activation_sparsity=activation_sparsity, + approximate="tanh") if activation_sparsity > 0.0 else GeluAndMul( + approximate="tanh") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class Gemma3nAttention(nn.Module): + + def __init__(self, + config: Gemma3nTextConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + max_position_embeddings: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__() + self.config = config + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=config.attention_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=config.attention_bias, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + self.q_norm = RMSNorm(hidden_size=self.head_dim, + eps=config.rms_norm_eps) + self.k_norm = RMSNorm(hidden_size=self.head_dim, + eps=config.rms_norm_eps) + self.v_norm = RMSNorm(hidden_size=self.head_dim, + eps=config.rms_norm_eps, + has_weight=False) + + layer_idx = extract_layer_index(prefix) + if config.layer_types[layer_idx] == "sliding_attention": + self.sliding_window = config.sliding_window + rope_theta = config.rope_local_base_freq + rope_scaling = {"rope_type": "default"} + else: + self.sliding_window = None + rope_theta = config.rope_theta + rope_scaling = config.rope_scaling + + first_kv_shared_layer_idx = (config.num_hidden_layers - + config.num_kv_shared_layers) + self.is_kv_shared = layer_idx >= first_kv_shared_layer_idx + + if self.is_kv_shared: + # Last full attention layer is 1 before sharing + # Last sliding attention layer is 2 before sharing + offset = 2 if self.sliding_window is not None else 1 + kv_shared_layer_index = first_kv_shared_layer_idx - offset + kv_sharing_target_layer_name = f"model.language_model.layers.{kv_shared_layer_index}.self_attn.attn" # noqa: E501 + else: + kv_sharing_target_layer_name = None + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + is_neox_style=True, + rope_scaling=rope_scaling, + ) + + self.attn = Attention( + num_heads=self.num_heads, + head_size=self.head_dim, + scale=1.0, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + per_layer_sliding_window=self.sliding_window, + kv_sharing_target_layer_name=kv_sharing_target_layer_name, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + q = q.unflatten(-1, (self.num_heads, self.head_dim)) + q = self.q_norm(q) + q = q.flatten(-2, -1) + k = k.unflatten(-1, (self.num_kv_heads, self.head_dim)) + k = self.k_norm(k) + k = k.flatten(-2, -1) + v = v.unflatten(-1, (self.num_kv_heads, self.head_dim)) + v = self.v_norm(v) + v = v.flatten(-2, -1) + + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + + output, _ = self.o_proj(attn_output) + return output + + +class Gemma3nDecoderLayer(nn.Module): + + def __init__( + self, + config: Gemma3nTextConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.altup_active_idx = config.altup_active_idx + assert config.altup_correct_scale + + self.altup = Gemma3nAltUp( + hidden_size=config.hidden_size, + rms_norm_eps=config.rms_norm_eps, + altup_num_inputs=config.altup_num_inputs, + altup_coef_clip=config.altup_coef_clip, + altup_active_idx=config.altup_active_idx, + prefix=f"{prefix}.altup", + ) + self.self_attn = Gemma3nAttention( + config=config, + hidden_size=config.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + head_dim=config.head_dim, + max_position_embeddings=config.max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + self.mlp = Gemma3nMLP( + hidden_size=config.hidden_size, + # NOTE: Matformer https://github.com/huggingface/transformers/blob/a52478253bbe522a420e88ea3940d4d98a935300/src/transformers/models/gemma3n/modular_gemma3n.py#L258 # noqa: E501 + intermediate_size=config.intermediate_size[extract_layer_index( + prefix)], + hidden_activation=config.hidden_activation, + quant_config=quant_config, + activation_sparsity=config.activation_sparsity_pattern[ + extract_layer_index(prefix)], + prefix=f"{prefix}.mlp", + ) + self.laurel = Gemma3nLaurelBlock( + hidden_size=config.hidden_size, + laurel_rank=config.laurel_rank, + rms_norm_eps=config.rms_norm_eps, + prefix=f"{prefix}.laurel", + ) + + # NOTE(rob): should be ColumnParallelLinear and RowParallelLinear + # But, we need to add per_layer_input_gate(x) to per_layer_input. + # per_layer_input cannot be sharded, so we replicate for now. + self.per_layer_input_gate = ReplicatedLinear( + config.hidden_size, + config.hidden_size_per_layer_input, + bias=False, + prefix=f"{prefix}.per_layer_input_gate", + return_bias=False, + ) + self.per_layer_projection = ReplicatedLinear( + config.hidden_size_per_layer_input, + config.hidden_size, + bias=False, + prefix=f"{prefix}.per_layer_projection", + return_bias=False, + ) + + # LayerNorms. + self.input_layernorm = RMSNorm( + config.hidden_size, + eps=config.rms_norm_eps, + ) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, + eps=config.rms_norm_eps, + ) + self.pre_feedforward_layernorm = RMSNorm( + config.hidden_size, + eps=config.rms_norm_eps, + ) + self.post_feedforward_layernorm = RMSNorm( + config.hidden_size, + eps=config.rms_norm_eps, + ) + self.post_per_layer_input_norm = RMSNorm( + config.hidden_size, + eps=config.rms_norm_eps, + ) + + self.act_fn = _ACTIVATION_REGISTRY[config.hidden_activation] + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + per_layer_input: torch.Tensor, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor]: + + # ActUp (predict). + predictions = self.altup.predict(hidden_states) + active_prediction = predictions[self.altup_active_idx] + active_prediction_normed = self.input_layernorm(active_prediction) + laurel_output = self.laurel(active_prediction_normed) + + # Attention. + attn = self.self_attn( + positions=positions, + hidden_states=active_prediction_normed, + **kwargs, + ) + attn = self.post_attention_layernorm(attn) + attn_gated = attn + active_prediction + attn_laurel = (attn_gated + laurel_output) / torch.sqrt( + torch.tensor(2.0)) + + # MLP. + attn_norm = self.pre_feedforward_layernorm(attn_laurel) + attn_ffw = self.mlp(attn_norm) + attn_ffw_norm = self.post_feedforward_layernorm(attn_ffw) + attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm + + # ActUp (connect). + corrected_predictions = self.altup.correct(predictions, + attn_ffw_laurel_gated) + first_prediction = corrected_predictions[self.altup_active_idx] + first_prediction = self.altup.scale_corrected_output(first_prediction) + + # per_layer_input_gate adapted from jax.numpy.einsum("btd,dp->btp", ...) + first_prediction = self.per_layer_input_gate(first_prediction) + first_prediction = self.act_fn(first_prediction) + first_prediction = torch.mul(first_prediction, per_layer_input) + + # per_layer_projection adapted from jax.numpy.einsum("btp,pd->btd", ...) + first_prediction = self.per_layer_projection(first_prediction) + first_prediction = self.post_per_layer_input_norm(first_prediction) + corrected_predictions[1:] += first_prediction + + return corrected_predictions + + +@support_torch_compile +class Gemma3nTextModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config.text_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + prefix=f"{prefix}.embed_tokens", + ) + self.embed_scale = torch.tensor( + config.hidden_size**0.5, + dtype=self.embed_tokens.weight.dtype, + ) + self.embed_tokens_per_layer = VocabParallelEmbedding( + config.vocab_size_per_layer_input, + config.num_hidden_layers * config.hidden_size_per_layer_input, + prefix=f"{prefix}.per_layer_embed_tokens", + ) + self.embed_scale_per_layer = torch.tensor( + config.hidden_size_per_layer_input**0.5, + dtype=self.embed_tokens.weight.dtype, + ) + self.per_layer_model_projection = ColumnParallelLinear( + config.hidden_size, + config.num_hidden_layers * config.hidden_size_per_layer_input, + bias=False, + gather_output=True, + return_bias=False, + prefix=f"{prefix}.per_layer_model_projection", + ) + self.per_layer_projection_norm = RMSNorm( + hidden_size=config.hidden_size_per_layer_input, + eps=config.rms_norm_eps, + ) + self.per_layer_input_scale = torch.rsqrt(torch.tensor(2.0)).to( + self.embed_tokens.weight.dtype) + self.per_layer_projection_scale = torch.tensor( + config.hidden_size**0.5, + dtype=self.embed_tokens.weight.dtype, + ) + self.altup_projections = nn.ModuleList([ + ColumnParallelLinear( + config.hidden_size, + config.hidden_size, + bias=False, + gather_output=True, + return_bias=False, + prefix=f"{prefix}.{idx-1}.altup_projections", + ) for idx in range(1, self.config.altup_num_inputs) + ]) + self.altup_unembed_projections = nn.ModuleList([ + ColumnParallelLinear( + config.hidden_size, + config.hidden_size, + bias=False, + gather_output=True, + return_bias=False, + prefix=f"{prefix}.{idx-1}.altup_unembed_projections", + ) for idx in range(1, self.config.altup_num_inputs) + ]) + + # Transformer blocks. + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Gemma3nDecoderLayer( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.layers") + self.norm = RMSNorm( + config.hidden_size, + eps=config.rms_norm_eps, + ) + self.eps = torch.tensor(torch.finfo().min) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) * self.embed_scale + + def get_per_layer_input_embeddings( + self, input_ids: torch.Tensor) -> torch.Tensor: + # Deal with the fact that vocab_size_per_layer_input < vocab_size + # which causes us to have some out of vocab tokens by setting + # those token ids to 0. This matches the HF implementation. + per_layer_inputs_mask = torch.logical_and( + input_ids >= 0, input_ids < self.config.vocab_size_per_layer_input) + per_layer_inputs_tokens = torch.where(per_layer_inputs_mask, input_ids, + torch.zeros_like(input_ids)) + return self.embed_tokens_per_layer( + per_layer_inputs_tokens) * self.embed_scale_per_layer + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[torch.Tensor, IntermediateTensors]: + if inputs_embeds is not None: + hidden_states_0 = inputs_embeds + else: + hidden_states_0 = self.get_input_embeddings(input_ids) + + # Per layer inputs. + if input_ids is None: + raise ValueError("Passing None for input ids is not supported.") + per_layer_inputs = self.get_per_layer_input_embeddings(input_ids) + per_layer_inputs = per_layer_inputs.reshape( + -1, self.config.num_hidden_layers, + self.config.hidden_size_per_layer_input) + per_layer_projection = self.per_layer_model_projection(hidden_states_0) + per_layer_projection = per_layer_projection.reshape( + *hidden_states_0.shape[:-1], + self.config.num_hidden_layers, + self.config.hidden_size_per_layer_input, + ) + per_layer_projection = self.per_layer_projection_norm( + per_layer_projection) + per_layer_inputs = per_layer_projection + per_layer_inputs + per_layer_inputs *= self.per_layer_input_scale + + # Altup embed. + hidden_states = [hidden_states_0] * self.config.altup_num_inputs + target_magnitude = torch.mean(hidden_states_0**2, dim=-1, + keepdim=True)**0.5 + for i in range(1, self.config.altup_num_inputs): + hidden_states[i] = self.altup_projections[i - 1](hidden_states[i]) + new_magnitude = torch.mean(hidden_states[i]**2, + dim=-1, + keepdim=True)**0.5 + hidden_states[i] *= target_magnitude / torch.maximum( + new_magnitude, self.eps) + hidden_states = torch.stack(hidden_states, dim=0) + + # Transformer blocks. + for layer_idx, layer in enumerate(self.layers): + # [altup_num_inputs, num_tokens, hidden_size] + hidden_states = layer( + positions=positions, + hidden_states=hidden_states, + per_layer_input=per_layer_inputs[:, layer_idx, :], + **kwargs, + ) + + # Altup unembed. + target_magnitude = torch.mean(hidden_states[0]**2, + dim=-1, + keepdim=True)**0.5 + for i in range(1, self.config.altup_num_inputs): + hidden_states[i] = self.altup_unembed_projections[i - 1]( + hidden_states[i]) + new_magnitude = torch.mean(hidden_states[i]**2, + dim=-1, + keepdim=True)**0.5 + hidden_states[i] *= target_magnitude / torch.maximum( + new_magnitude, self.eps) + # [altup_num_inputs,num_tokens,hidden_size] -> [num_tokens,hidden_size] + hidden_states = torch.mean(hidden_states, dim=0) + + return self.norm(hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache scales for compressed-tensors quantization + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = loaded_weight[0] + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + for (param_name, shard_name, shard_id) in stacked_params_mapping: + if shard_name not in name: + continue + # Avoid spurious match with ".up_proj". + if "altup_projections" in name: + continue + name = name.replace(shard_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + return loaded_params + + +class Gemma3nModel(nn.Module): + + def __init__(self, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.language_model = Gemma3nTextModel(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "language_model")) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + return self.language_model(input_ids=input_ids, + positions=positions, + inputs_embeds=inputs_embeds, + **kwargs) + + +class Gemma3nForConditionalGeneration(nn.Module): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + lora_config = vllm_config.lora_config + del lora_config # Unused. + super().__init__() + self.config = config + self.model = Gemma3nModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.logits_processor = LogitsProcessor( + config.text_config.vocab_size, + soft_cap=config.text_config.final_logit_softcapping) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.language_model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds, **kwargs) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: Optional[SamplingMetadata], + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.model.language_model.embed_tokens, + hidden_states, sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self, + skip_substrs=([ + "embed_audio.", "embed_vision.", + "audio_tower.", "vision_tower." + ])) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index faeaf6ef68ccc..ff605cae02ea4 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -58,6 +58,8 @@ _TEXT_GENERATION_MODELS = { "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"), "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"), "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"), + #TODO(ywang96): Support multimodal gemma3n + "Gemma3nForConditionalGeneration": ("gemma3n", "Gemma3nForConditionalGeneration"), # noqa: E501 "GlmForCausalLM": ("glm", "GlmForCausalLM"), "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"), "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"), From 4ab3ac285e824542831c4326d01ce84bd8b65aad Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Fri, 27 Jun 2025 16:30:53 +0900 Subject: [PATCH 039/175] [Bugfix] Fix flaky failure when getting DP ports (#20151) Signed-off-by: mgoin --- vllm/config.py | 39 +++++++++++++++++++++++++++++++-------- 1 file changed, 31 insertions(+), 8 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 7a3329aea5f78..623ba3aaf1093 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1878,18 +1878,41 @@ class ParallelConfig: return answer def stateless_init_dp_group(self) -> "ProcessGroup": + # NOTE: In high-concurrency scenarios multiple processes + # can pick the same (currently free) port through a race + # condition when calling `get_open_port()`. When the first + # process binds the port the others will subsequently fail + # with `torch.distributed.DistNetworkError: EADDRINUSE`. + # To make the initialization more robust we retry a few times + # with a fresh port whenever this specific error is observed. + from torch.distributed import DistNetworkError + from vllm.distributed.utils import ( stateless_init_torch_distributed_process_group) - # use gloo since the engine process might not have cuda device - dp_group = stateless_init_torch_distributed_process_group( - self.data_parallel_master_ip, - self.get_next_dp_init_port(), - self.data_parallel_rank, - self.data_parallel_size, - backend="gloo") + max_retries = 5 + last_exc: Optional[Exception] = None + for _ in range(max_retries): + try: + # use gloo since the engine process might not have cuda device + return stateless_init_torch_distributed_process_group( + self.data_parallel_master_ip, + self.get_next_dp_init_port(), + self.data_parallel_rank, + self.data_parallel_size, + backend="gloo") + except DistNetworkError as e: + # We only want to retry when the root cause is EADDRINUSE. + if "EADDRINUSE" in str(e): + logger.warning( + "Address already in use. Retrying with a new port.") + last_exc = e + continue # try again with a new port + raise e - return dp_group + # If we get here all retries have failed. + assert last_exc is not None + raise last_exc @staticmethod def has_unfinished_dp(dp_group: "ProcessGroup", From aa0dc77ef53b365ddf54be51748c166895a0bcd9 Mon Sep 17 00:00:00 2001 From: Ilya Lavrenov Date: Fri, 27 Jun 2025 13:16:41 +0400 Subject: [PATCH 040/175] [Perf] Improved perf for resolve_chat_template_content_format (#20065) Signed-off-by: Ilya Lavrenov --- vllm/entrypoints/chat_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 7951c49f5da05..35ee52ab4601d 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -293,6 +293,7 @@ def _try_extract_ast(chat_template: str) -> Optional[jinja2.nodes.Template]: return None +@lru_cache(maxsize=32) def _detect_content_format( chat_template: str, *, From 94a55c76813f97557a0d480c78f1488a360beee7 Mon Sep 17 00:00:00 2001 From: Hosang <156028780+hyoon1@users.noreply.github.com> Date: Fri, 27 Jun 2025 10:14:44 -0400 Subject: [PATCH 041/175] [Fix][ROCm] Remove unused variables to fix build error on GFX11/12 (#19891) Signed-off-by: Hosang Yoon --- csrc/rocm/attention.cu | 4 ---- 1 file changed, 4 deletions(-) diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index 39997030751b8..3bddd12cad077 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -1598,7 +1598,6 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( const int warpid = threadIdx.x / WARP_SIZE; const int laneid = threadIdx.x % WARP_SIZE; const int lane2id = laneid % 2; - const int lane4id = laneid % 4; const int lane16id = laneid % 16; const int rowid = laneid / 16; @@ -1745,7 +1744,6 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( const cache_t* k_ptr2 = k_ptr + kblock_number * kv_block_stride; const int klocal_token_idx = TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id; - const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx; const int kphysical_block_offset = klocal_token_idx % BLOCK_SIZE; const cache_t* k_ptr3 = k_ptr2 + kphysical_block_offset * KX; @@ -2368,7 +2366,6 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( const int warpid = threadIdx.x / WARP_SIZE; const int laneid = threadIdx.x % WARP_SIZE; const int lane2id = laneid % 2; - const int lane4id = laneid % 4; const int lane16id = laneid % 16; const int rowid = laneid / 16; @@ -2514,7 +2511,6 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( const cache_t* k_ptr2 = k_ptr + kblock_number * kv_block_stride; const int klocal_token_idx = TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id; - const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx; const int kphysical_block_offset = klocal_token_idx % BLOCK_SIZE; const cache_t* k_ptr3 = k_ptr2 + kphysical_block_offset * KX; From aafabaa0d5c87c283b366f81fdce55cf91ae980c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Fri, 27 Jun 2025 11:00:42 -0400 Subject: [PATCH 042/175] [Fix][torch.compile] Enable custom ops by default when Inductor off (#20102) Signed-off-by: luka --- .../model_executor/test_enabled_custom_ops.py | 45 +++++++++++-------- vllm/config.py | 27 ++++------- vllm/model_executor/custom_op.py | 12 ++--- 3 files changed, 41 insertions(+), 43 deletions(-) diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index a94215ee397bf..140f00294765d 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -28,42 +28,49 @@ class Relu3(ReLUSquaredActivation): @pytest.mark.parametrize( - "env, torch_level, ops_enabled, default_on", + "env, torch_level, use_inductor, ops_enabled, default_on", [ # Default values based on compile level - ("", 0, [True] * 4, True), - ("", 1, [True] * 4, True), - ("", 2, [True] * 4, True), # All by default - ("", 3, [False] * 4, False), - ("", 4, [False] * 4, False), # None by default + # - All by default (no Inductor compilation) + ("", 0, False, [True] * 4, True), + ("", 1, True, [True] * 4, True), + ("", 2, False, [True] * 4, True), + # - None by default (with Inductor) + ("", 3, True, [False] * 4, False), + ("", 4, True, [False] * 4, False), + # - All by default (without Inductor) + ("", 3, False, [True] * 4, True), + ("", 4, False, [True] * 4, True), # Explicitly enabling/disabling # # Default: all # # All but SiluAndMul - ("+rms_norm,-silu_and_mul", 0, [1, 0, 1, 1], True), + ("+rms_norm,-silu_and_mul", 0, True, [1, 0, 1, 1], True), # Only ReLU3 - ("none,-rms_norm,+relu3", 0, [0, 0, 0, 1], False), + ("none,-rms_norm,+relu3", 1, False, [0, 0, 0, 1], False), # All but SiluAndMul - ("all,-silu_and_mul", 1, [1, 0, 1, 1], True), + ("all,-silu_and_mul", 2, True, [1, 0, 1, 1], True), # All but ReLU3 (even if ReLU2 is on) - ("-relu3,relu2", 1, [1, 1, 1, 0], True), - # GeluAndMul and SiluAndMul - ("none,-relu3,+gelu_and_mul,+silu_and_mul", 2, [0, 1, 1, 0], False), + ("-relu3,relu2", 3, False, [1, 1, 1, 0], True), + # RMSNorm and SiluAndMul + ("none,-relu3,+rms_norm,+silu_and_mul", 4, False, [1, 1, 0, 0], False), # All but RMSNorm - ("-rms_norm", 2, [0, 1, 1, 1], True), + ("-rms_norm", 3, False, [0, 1, 1, 1], True), # # Default: none # # Only ReLU3 - ("-silu_and_mul,+relu3", 3, [0, 0, 0, 1], False), + ("-silu_and_mul,+relu3", 3, True, [0, 0, 0, 1], False), # All but RMSNorm - ("all,-rms_norm", 4, [0, 1, 1, 1], True), + ("all,-rms_norm", 4, True, [0, 1, 1, 1], True), ]) -def test_enabled_ops(env: str, torch_level: int, ops_enabled: list[int], - default_on: bool): - vllm_config = VllmConfig(compilation_config=CompilationConfig( - level=torch_level, custom_ops=env.split(","))) +def test_enabled_ops(env: str, torch_level: int, use_inductor: bool, + ops_enabled: list[int], default_on: bool): + vllm_config = VllmConfig( + compilation_config=CompilationConfig(use_inductor=bool(use_inductor), + level=torch_level, + custom_ops=env.split(","))) with set_current_vllm_config(vllm_config): assert CustomOp.default_on() == default_on diff --git a/vllm/config.py b/vllm/config.py index 623ba3aaf1093..84aa14b7c8605 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3994,7 +3994,8 @@ class CompilationConfig: - 'none,+op1,+op2' to enable only op1 and op2 By default, all custom ops are enabled when running without Inductor and - disabled when running with Inductor (compile_level >= Inductor).""" + disabled when running with Inductor: level>=PIECEWISE and use_inductor=True. + Inductor generates (fused) Triton kernels for disabled custom ops.""" splitting_ops: list[str] = field(default_factory=list) """A list of ops to split the full graph into subgraphs, used in piecewise compilation.""" @@ -4003,10 +4004,13 @@ class CompilationConfig: use_inductor: bool = True """Whether to use inductor compilation: - - False: inductor compilation is not used. graph runs in eager. - - True: inductor compilation is used. one graph for symbolic shape - is compiled. In addition, compile for compile_sizes, - using configurations in inductor_compile_config.""" + - False: inductor compilation is not used. graph runs in eager + (custom_ops enabled by default). + - True: inductor compilation is used (custom_ops disabled by default). + One graph for symbolic shape and one graph per size in compile_sizes + are compiled using configurations in inductor_compile_config. + + This setting is ignored if level 0 and \ diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index 1680b723d6a29..9c88721fb2782 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -141,16 +141,16 @@ class CustomOp(nn.Module): @staticmethod def default_on() -> bool: """ - On by default if level < CompilationLevel.PIECEWISE + On by default if PyTorch Inductor is not used. Specifying 'all' or 'none' in custom_op takes precedence. """ from vllm.config import CompilationLevel compilation_config = get_current_vllm_config().compilation_config - custom_ops = compilation_config.custom_ops - count_none = custom_ops.count("none") - count_all = custom_ops.count("all") - return compilation_config.level < CompilationLevel.PIECEWISE and \ - not count_none > 0 or count_all > 0 + default_on = (compilation_config.level < CompilationLevel.PIECEWISE + or not compilation_config.use_inductor) + count_none = compilation_config.custom_ops.count("none") + count_all = compilation_config.custom_ops.count("all") + return default_on and not count_none > 0 or count_all > 0 # Dictionary of all custom ops (classes, indexed by registered name). # To check if an op with a name is enabled, call .enabled() on the class. From c6c983053d1c430e9347797094c95bbed37bac2a Mon Sep 17 00:00:00 2001 From: bnellnm <49004751+bnellnm@users.noreply.github.com> Date: Fri, 27 Jun 2025 11:42:22 -0400 Subject: [PATCH 043/175] [Bugfix] Mark 'hidden_states' as mutable in moe_forward registration. (#20152) Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/layer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 6fe95d32a10e7..672244385e52c 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1743,7 +1743,8 @@ def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor, direct_register_custom_op( op_name="moe_forward", op_func=moe_forward, - mutates_args=[], + mutates_args=["hidden_states"], fake_impl=moe_forward_fake, dispatch_key=current_platform.dispatch_key, + tags=(torch.Tag.needs_fixed_stride_order, ), ) From e8c3bd2cd164786e764f7e3436bbaa6cd00ae64a Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Fri, 27 Jun 2025 12:01:28 -0400 Subject: [PATCH 044/175] [Bugfix] Fix some narrowing conversion warnings (#20141) Signed-off-by: Tyler Michael Smith --- csrc/attention/mla/cutlass_mla_kernels.cu | 2 +- csrc/mamba/causal_conv1d/causal_conv1d.cu | 8 ++------ csrc/mamba/mamba_ssm/selective_scan_fwd.cu | 4 +--- csrc/quantization/fp4/nvfp4_experts_quant.cu | 4 ++-- csrc/quantization/fp4/nvfp4_quant_kernels.cu | 2 +- csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu | 2 +- 6 files changed, 8 insertions(+), 14 deletions(-) diff --git a/csrc/attention/mla/cutlass_mla_kernels.cu b/csrc/attention/mla/cutlass_mla_kernels.cu index f4b6b19f4b232..9d05d910dd81f 100644 --- a/csrc/attention/mla/cutlass_mla_kernels.cu +++ b/csrc/attention/mla/cutlass_mla_kernels.cu @@ -207,7 +207,7 @@ void cutlass_mla_decode_sm100a(torch::Tensor const& out, "page_table must be a 32-bit integer tensor"); auto in_dtype = q_nope.dtype(); - at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()}; + const at::cuda::OptionalCUDAGuard device_guard(device_of(q_nope)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope.get_device()); if (in_dtype == at::ScalarType::Half) { diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu index f62d08c17c6d8..c83d72751a55c 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.cu +++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu @@ -185,9 +185,7 @@ void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, params.conv_states_ptr = nullptr; } - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - at::cuda::CUDAGuard device_guard{(char)x.get_device()}; + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); auto stream = at::cuda::getCurrentCUDAStream().stream(); DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] { causal_conv1d_fwd_cuda(params, stream); @@ -278,9 +276,7 @@ void causal_conv1d_update(const at::Tensor &x, params.conv_state_indices_ptr = nullptr; } - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - at::cuda::CUDAGuard device_guard{(char)x.get_device()}; + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); auto stream = at::cuda::getCurrentCUDAStream().stream(); DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] { causal_conv1d_update_cuda(params, stream); diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu index 0c9df925bdbf6..785d316025eca 100644 --- a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu +++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu @@ -647,9 +647,7 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, ); - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - at::cuda::CUDAGuard device_guard{(char)u.get_device()}; + const at::cuda::OptionalCUDAGuard device_guard(device_of(u)); auto stream = at::cuda::getCurrentCUDAStream().stream(); DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] { selective_scan_fwd_cuda(params, stream); diff --git a/csrc/quantization/fp4/nvfp4_experts_quant.cu b/csrc/quantization/fp4/nvfp4_experts_quant.cu index b51033c9b72c9..190d66f318a83 100644 --- a/csrc/quantization/fp4/nvfp4_experts_quant.cu +++ b/csrc/quantization/fp4/nvfp4_experts_quant.cu @@ -561,7 +561,7 @@ void scaled_fp4_experts_quant_sm100a( TORCH_CHECK(output_scale.size(1) * 4 == padded_k); auto in_dtype = input.dtype(); - at::cuda::CUDAGuard device_guard{(char)input.get_device()}; + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(input.get_device()); if (in_dtype == at::ScalarType::Half) { @@ -579,4 +579,4 @@ void scaled_fp4_experts_quant_sm100a( } else { TORCH_CHECK(false, "Expected input data type to be half or bfloat16"); } -} \ No newline at end of file +} diff --git a/csrc/quantization/fp4/nvfp4_quant_kernels.cu b/csrc/quantization/fp4/nvfp4_quant_kernels.cu index fef74111624f0..d32911357a953 100644 --- a/csrc/quantization/fp4/nvfp4_quant_kernels.cu +++ b/csrc/quantization/fp4/nvfp4_quant_kernels.cu @@ -347,7 +347,7 @@ void scaled_fp4_quant_sm100a(torch::Tensor const& output, auto input_sf_ptr = static_cast(input_sf.data_ptr()); auto sf_out = static_cast(output_sf.data_ptr()); auto output_ptr = static_cast(output.data_ptr()); - at::cuda::CUDAGuard device_guard{(char)input.get_device()}; + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); auto stream = at::cuda::getCurrentCUDAStream(input.get_device()); // We don't support e8m0 scales at this moment. diff --git a/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu b/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu index 97c0e0da7b1fb..7572a7eb3122d 100644 --- a/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu +++ b/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu @@ -267,7 +267,7 @@ void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A, B_sf.sizes()[1], ")"); auto out_dtype = D.dtype(); - at::cuda::CUDAGuard device_guard{(char)A.get_device()}; + const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device()); if (out_dtype == at::ScalarType::Half) { From 3c545c0c3b98ee642373a308197d750d0e449403 Mon Sep 17 00:00:00 2001 From: Fabien Dupont Date: Fri, 27 Jun 2025 18:04:39 +0200 Subject: [PATCH 045/175] [CI/Build] Allow hermetic builds (#18064) Signed-off-by: Fabien Dupont Signed-off-by: Tyler Michael Smith Signed-off-by: Fabien Dupont Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: Tyler Michael Smith Co-authored-by: Elias Levy Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- docker/Dockerfile | 188 ++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 158 insertions(+), 30 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 8d4375470adf9..a71b052bfca25 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -6,30 +6,106 @@ # docs/assets/contributing/dockerfile-stages-dependency.png ARG CUDA_VERSION=12.8.1 +ARG PYTHON_VERSION=3.12 + +# By parameterizing the base images, we allow third-party to use their own +# base images. One use case is hermetic builds with base images stored in +# private registries that use a different repository naming conventions. +# +# Example: +# docker build --build-arg BUILD_BASE_IMAGE=registry.acme.org/mirror/nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 +ARG BUILD_BASE_IMAGE=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 +ARG FINAL_BASE_IMAGE=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 + +# By parameterizing the Deadsnakes repository URL, we allow third-party to use +# their own mirror. When doing so, we don't benefit from the transparent +# installation of the GPG key of the PPA, as done by add-apt-repository, so we +# also need a URL for the GPG key. +ARG DEADSNAKES_MIRROR_URL +ARG DEADSNAKES_GPGKEY_URL + +# The PyPA get-pip.py script is a self contained script+zip file, that provides +# both the installer script and the pip base85-encoded zip archive. This allows +# bootstrapping pip in environment where a dsitribution package does not exist. +# +# By parameterizing the URL for get-pip.py installation script, we allow +# third-party to use their own copy of the script stored in a private mirror. +# We set the default value to the PyPA owned get-pip.py script. +# +# Reference: https://pip.pypa.io/en/stable/installation/#get-pip-py +ARG GET_PIP_URL="https://bootstrap.pypa.io/get-pip.py" + +# PIP supports fetching the packages from custom indexes, allowing third-party +# to host the packages in private mirrors. The PIP_INDEX_URL and +# PIP_EXTRA_INDEX_URL are standard PIP environment variables to override the +# default indexes. By letting them empty by default, PIP will use its default +# indexes if the build process doesn't override the indexes. +# +# Uv uses different variables. We set them by default to the same values as +# PIP, but they can be overridden. +ARG PIP_INDEX_URL +ARG PIP_EXTRA_INDEX_URL +ARG UV_INDEX_URL=${PIP_INDEX_URL} +ARG UV_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL} + +# PyTorch provides its own indexes for standard and nightly builds +ARG PYTORCH_CUDA_INDEX_BASE_URL=https://download.pytorch.org/whl +ARG PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL=https://download.pytorch.org/whl/nightly + +# PIP supports multiple authentication schemes, including keyring +# By parameterizing the PIP_KEYRING_PROVIDER variable and setting it to +# disabled by default, we allow third-party to use keyring authentication for +# their private Python indexes, while not changing the default behavior which +# is no authentication. +# +# Reference: https://pip.pypa.io/en/stable/topics/authentication/#keyring-support +ARG PIP_KEYRING_PROVIDER=disabled +ARG UV_KEYRING_PROVIDER=${PIP_KEYRING_PROVIDER} + #################### BASE BUILD IMAGE #################### # prepare basic build environment -FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 AS base -ARG CUDA_VERSION=12.8.1 -ARG PYTHON_VERSION=3.12 +FROM ${BUILD_BASE_IMAGE} AS base +ARG CUDA_VERSION +ARG PYTHON_VERSION ARG TARGETPLATFORM ENV DEBIAN_FRONTEND=noninteractive +ARG DEADSNAKES_MIRROR_URL +ARG DEADSNAKES_GPGKEY_URL +ARG GET_PIP_URL + # Install Python and other dependencies RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ && echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \ && apt-get update -y \ && apt-get install -y ccache software-properties-common git curl sudo \ - && for i in 1 2 3; do \ - add-apt-repository -y ppa:deadsnakes/ppa && break || \ - { echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \ - done \ + && if [ ! -z ${DEADSNAKES_MIRROR_URL} ] ; then \ + if [ ! -z "${DEADSNAKES_GPGKEY_URL}" ] ; then \ + mkdir -p -m 0755 /etc/apt/keyrings ; \ + curl -L ${DEADSNAKES_GPGKEY_URL} | gpg --dearmor > /etc/apt/keyrings/deadsnakes.gpg ; \ + sudo chmod 644 /etc/apt/keyrings/deadsnakes.gpg ; \ + echo "deb [signed-by=/etc/apt/keyrings/deadsnakes.gpg] ${DEADSNAKES_MIRROR_URL} $(lsb_release -cs) main" > /etc/apt/sources.list.d/deadsnakes.list ; \ + fi ; \ + else \ + for i in 1 2 3; do \ + add-apt-repository -y ppa:deadsnakes/ppa && break || \ + { echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \ + done ; \ + fi \ && apt-get update -y \ && apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \ && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \ && update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \ && ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \ - && curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \ + && curl -sS ${GET_PIP_URL} | python${PYTHON_VERSION} \ && python3 --version && python3 -m pip --version + +ARG PIP_INDEX_URL UV_INDEX_URL +ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL +ARG PYTORCH_CUDA_INDEX_BASE_URL +ARG PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL +ARG PIP_KEYRING_PROVIDER UV_KEYRING_PROVIDER + # Install uv for faster pip installs RUN --mount=type=cache,target=/root/.cache/uv \ python3 -m pip install uv @@ -63,15 +139,19 @@ WORKDIR /workspace # after this step RUN --mount=type=cache,target=/root/.cache/uv \ if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \ - uv pip install --system --index-url https://download.pytorch.org/whl/nightly/cu128 "torch==2.8.0.dev20250318+cu128" "torchvision==0.22.0.dev20250319"; \ - uv pip install --system --index-url https://download.pytorch.org/whl/nightly/cu128 --pre pytorch_triton==3.3.0+gitab727c40; \ + uv pip install --system \ + --index-url ${PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') \ + "torch==2.8.0.dev20250318+cu128" "torchvision==0.22.0.dev20250319"; \ + uv pip install --system \ + --index-url ${PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') \ + --pre pytorch_triton==3.3.0+gitab727c40; \ fi COPY requirements/common.txt requirements/common.txt COPY requirements/cuda.txt requirements/cuda.txt RUN --mount=type=cache,target=/root/.cache/uv \ uv pip install --system -r requirements/cuda.txt \ - --extra-index-url https://download.pytorch.org/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') + --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') # cuda arch list used by torch # can be useful for both `dev` and `test` @@ -88,6 +168,10 @@ ENV VLLM_FA_CMAKE_GPU_ARCHES=${vllm_fa_cmake_gpu_arches} FROM base AS build ARG TARGETPLATFORM +ARG PIP_INDEX_URL UV_INDEX_URL +ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL +ARG PYTORCH_CUDA_INDEX_BASE_URL + # install build dependencies COPY requirements/build.txt requirements/build.txt @@ -98,7 +182,7 @@ ENV UV_INDEX_STRATEGY="unsafe-best-match" RUN --mount=type=cache,target=/root/.cache/uv \ uv pip install --system -r requirements/build.txt \ - --extra-index-url https://download.pytorch.org/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') + --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') COPY . . ARG GIT_REPO_CHECK=0 @@ -113,6 +197,8 @@ ARG nvcc_threads=8 ENV NVCC_THREADS=$nvcc_threads ARG USE_SCCACHE +ARG SCCACHE_DOWNLOAD_URL=https://github.com/mozilla/sccache/releases/download/v0.8.1/sccache-v0.8.1-x86_64-unknown-linux-musl.tar.gz +ARG SCCACHE_ENDPOINT ARG SCCACHE_BUCKET_NAME=vllm-build-sccache ARG SCCACHE_REGION_NAME=us-west-2 ARG SCCACHE_S3_NO_CREDENTIALS=0 @@ -121,10 +207,11 @@ RUN --mount=type=cache,target=/root/.cache/uv \ --mount=type=bind,source=.git,target=.git \ if [ "$USE_SCCACHE" = "1" ]; then \ echo "Installing sccache..." \ - && curl -L -o sccache.tar.gz https://github.com/mozilla/sccache/releases/download/v0.8.1/sccache-v0.8.1-x86_64-unknown-linux-musl.tar.gz \ + && curl -L -o sccache.tar.gz ${SCCACHE_DOWNLOAD_URL} \ && tar -xzf sccache.tar.gz \ && sudo mv sccache-v0.8.1-x86_64-unknown-linux-musl/sccache /usr/bin/sccache \ && rm -rf sccache.tar.gz sccache-v0.8.1-x86_64-unknown-linux-musl \ + && if [ ! -z ${SCCACHE_ENDPOINT} ] ; then export SCCACHE_ENDPOINT=${SCCACHE_ENDPOINT} ; fi \ && export SCCACHE_BUCKET=${SCCACHE_BUCKET_NAME} \ && export SCCACHE_REGION=${SCCACHE_REGION_NAME} \ && export SCCACHE_S3_NO_CREDENTIALS=${SCCACHE_S3_NO_CREDENTIALS} \ @@ -162,6 +249,10 @@ RUN if [ "$RUN_WHEEL_CHECK" = "true" ]; then \ #################### DEV IMAGE #################### FROM base as dev +ARG PIP_INDEX_URL UV_INDEX_URL +ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL +ARG PYTORCH_CUDA_INDEX_BASE_URL + # This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out # Reference: https://github.com/astral-sh/uv/pull/1694 ENV UV_HTTP_TIMEOUT=500 @@ -176,21 +267,25 @@ COPY requirements/test.txt requirements/test.txt COPY requirements/dev.txt requirements/dev.txt RUN --mount=type=cache,target=/root/.cache/uv \ uv pip install --system -r requirements/dev.txt \ - --extra-index-url https://download.pytorch.org/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') + --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') #################### DEV IMAGE #################### #################### vLLM installation IMAGE #################### # image with vLLM installed # TODO: Restore to base image after FlashInfer AOT wheel fixed -FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 AS vllm-base -ARG CUDA_VERSION=12.8.1 -ARG PYTHON_VERSION=3.12 +FROM ${FINAL_BASE_IMAGE} AS vllm-base +ARG CUDA_VERSION +ARG PYTHON_VERSION WORKDIR /vllm-workspace ENV DEBIAN_FRONTEND=noninteractive ARG TARGETPLATFORM SHELL ["/bin/bash", "-c"] +ARG DEADSNAKES_MIRROR_URL +ARG DEADSNAKES_GPGKEY_URL +ARG GET_PIP_URL + RUN PYTHON_VERSION_STR=$(echo ${PYTHON_VERSION} | sed 's/\.//g') && \ echo "export PYTHON_VERSION_STR=${PYTHON_VERSION_STR}" >> /etc/environment @@ -200,17 +295,33 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ && apt-get update -y \ && apt-get install -y ccache software-properties-common git curl wget sudo vim python3-pip \ && apt-get install -y ffmpeg libsm6 libxext6 libgl1 \ - && for i in 1 2 3; do \ - add-apt-repository -y ppa:deadsnakes/ppa && break || \ - { echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \ - done \ + && if [ ! -z ${DEADSNAKES_MIRROR_URL} ] ; then \ + if [ ! -z "${DEADSNAKES_GPGKEY_URL}" ] ; then \ + mkdir -p -m 0755 /etc/apt/keyrings ; \ + curl -L ${DEADSNAKES_GPGKEY_URL} | gpg --dearmor > /etc/apt/keyrings/deadsnakes.gpg ; \ + sudo chmod 644 /etc/apt/keyrings/deadsnakes.gpg ; \ + echo "deb [signed-by=/etc/apt/keyrings/deadsnakes.gpg] ${DEADSNAKES_MIRROR_URL} $(lsb_release -cs) main" > /etc/apt/sources.list.d/deadsnakes.list ; \ + fi ; \ + else \ + for i in 1 2 3; do \ + add-apt-repository -y ppa:deadsnakes/ppa && break || \ + { echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \ + done ; \ + fi \ && apt-get update -y \ && apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv libibverbs-dev \ && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \ && update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \ && ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \ - && curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \ + && curl -sS ${GET_PIP_URL} | python${PYTHON_VERSION} \ && python3 --version && python3 -m pip --version + +ARG PIP_INDEX_URL UV_INDEX_URL +ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL +ARG PYTORCH_CUDA_INDEX_BASE_URL +ARG PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL +ARG PIP_KEYRING_PROVIDER UV_KEYRING_PROVIDER + # Install uv for faster pip installs RUN --mount=type=cache,target=/root/.cache/uv \ python3 -m pip install uv @@ -232,15 +343,19 @@ RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/ # after this step RUN --mount=type=cache,target=/root/.cache/uv \ if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \ - uv pip install --system --index-url https://download.pytorch.org/whl/nightly/cu128 "torch==2.8.0.dev20250318+cu128" "torchvision==0.22.0.dev20250319"; \ - uv pip install --system --index-url https://download.pytorch.org/whl/nightly/cu128 --pre pytorch_triton==3.3.0+gitab727c40; \ + uv pip install --system \ + --index-url ${PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') \ + "torch==2.8.0.dev20250318+cu128" "torchvision==0.22.0.dev20250319" ; \ + uv pip install --system \ + --index-url ${PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') \ + --pre pytorch_triton==3.3.0+gitab727c40 ; \ fi # Install vllm wheel first, so that torch etc will be installed. RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \ --mount=type=cache,target=/root/.cache/uv \ uv pip install --system dist/*.whl --verbose \ - --extra-index-url https://download.pytorch.org/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') + --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') # If we need to build FlashInfer wheel before its release: # $ # Note we remove 7.0 from the arch list compared to the list below, since FlashInfer only supports sm75+ @@ -254,15 +369,20 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist # -rw-rw-r-- 1 mgoin mgoin 205M Jun 9 18:03 flashinfer_python-0.2.6.post1-cp39-abi3-linux_x86_64.whl # $ # upload the wheel to a public location, e.g. https://wheels.vllm.ai/flashinfer/v0.2.6.post1/flashinfer_python-0.2.6.post1-cp39-abi3-linux_x86_64.whl +# Allow specifying a version, Git revision or local .whl file +ARG FLASHINFER_CUDA128_INDEX_URL="https://download.pytorch.org/whl/cu128/flashinfer" +ARG FLASHINFER_CUDA128_WHEEL="flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl" +ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git" +ARG FLASHINFER_GIT_REF="v0.2.6.post1" RUN --mount=type=cache,target=/root/.cache/uv \ . /etc/environment && \ if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \ # FlashInfer already has a wheel for PyTorch 2.7.0 and CUDA 12.8. This is enough for CI use if [[ "$CUDA_VERSION" == 12.8* ]]; then \ - uv pip install --system https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl; \ + uv pip install --system ${FLASHINFER_CUDA128_INDEX_URL}/${FLASHINFER_CUDA128_WHEEL} ; \ else \ export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0a 10.0a 12.0' && \ - git clone https://github.com/flashinfer-ai/flashinfer.git --single-branch --branch v0.2.6.post1 --recursive && \ + git clone ${FLASHINFER_GIT_REPO} --single-branch --branch ${FLASHINFER_GIT_REF} --recursive && \ # Needed to build AOT kernels (cd flashinfer && \ python3 -m flashinfer.aot && \ @@ -286,7 +406,7 @@ uv pip list COPY requirements/build.txt requirements/build.txt RUN --mount=type=cache,target=/root/.cache/uv \ uv pip install --system -r requirements/build.txt \ - --extra-index-url https://download.pytorch.org/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') + --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') #################### vLLM installation IMAGE #################### @@ -297,6 +417,11 @@ FROM vllm-base AS test ADD . /vllm-workspace/ +ARG PYTHON_VERSION + +ARG PIP_INDEX_URL UV_INDEX_URL +ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL + # This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out # Reference: https://github.com/astral-sh/uv/pull/1694 ENV UV_HTTP_TIMEOUT=500 @@ -307,7 +432,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ uv pip install --system --no-build-isolation "git+https://github.com/state-spaces/mamba@v2.2.4" # install development dependencies (for testing) -RUN --mount=type=cache,target=/root/.cache/uv \ +RUN --mount=type=cache,target=/root/.cache/uv \ CUDA_MAJOR="${CUDA_VERSION%%.*}"; \ if [ "$CUDA_MAJOR" -ge 12 ]; then \ uv pip install --system -r requirements/dev.txt; \ @@ -323,7 +448,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ ENV HF_HUB_ENABLE_HF_TRANSFER 1 # Copy in the v1 package for testing (it isn't distributed yet) -COPY vllm/v1 /usr/local/lib/python3.12/dist-packages/vllm/v1 +COPY vllm/v1 /usr/local/lib/python${PYTHON_VERSION}/dist-packages/vllm/v1 # doc requires source code # we hide them inside `test_docs/` , so that this source code @@ -340,6 +465,9 @@ RUN mv mkdocs.yaml test_docs/ FROM vllm-base AS vllm-openai-base ARG TARGETPLATFORM +ARG PIP_INDEX_URL UV_INDEX_URL +ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL + # This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out # Reference: https://github.com/astral-sh/uv/pull/1694 ENV UV_HTTP_TIMEOUT=500 From c329ceca6dd7263a65c7913a14de943266a38088 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Sat, 28 Jun 2025 14:43:06 +0900 Subject: [PATCH 046/175] [CI Fix] Pin tests/models/registry.py MiniMaxText01ForCausalLM to revision due to model changes (#20199) Signed-off-by: mgoin --- tests/models/registry.py | 9 ++++++++- tests/models/test_initialization.py | 1 + 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/models/registry.py b/tests/models/registry.py index 1bcb4f88a30ff..72e361e2637fd 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -70,6 +70,12 @@ class _HfExamplesInfo: length that is too large to fit into memory in CI. """ + revision: Optional[str] = None + """ + The specific revision (commit hash, tag, or branch) to use for the model. + If not specified, the default revision will be used. + """ + def check_transformers_version( self, *, @@ -207,7 +213,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "MiniCPM3ForCausalLM": _HfExamplesInfo("openbmb/MiniCPM3-4B", trust_remote_code=True), "MiniMaxText01ForCausalLM": _HfExamplesInfo("MiniMaxAI/MiniMax-Text-01", - trust_remote_code=True), + trust_remote_code=True, + revision="a59aa9cbc53b9fb8742ca4e9e1531b9802b6fdc3"), # noqa: E501 "MiniMaxM1ForCausalLM": _HfExamplesInfo("MiniMaxAI/MiniMax-M1-40k", trust_remote_code=True), "MistralForCausalLM": _HfExamplesInfo("mistralai/Mistral-7B-Instruct-v0.1"), diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index e56bc925c9c40..df72607767fdd 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -88,6 +88,7 @@ def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch): model_info.default, tokenizer=model_info.tokenizer, tokenizer_mode=model_info.tokenizer_mode, + revision=model_info.revision, speculative_config={ "model": model_info.speculative_model, "num_speculative_tokens": 1, From e53be6f00ab37c58d767e44d8ceef26e9defc743 Mon Sep 17 00:00:00 2001 From: Chales Xu <111160781+SHA-4096@users.noreply.github.com> Date: Sat, 28 Jun 2025 13:47:36 +0800 Subject: [PATCH 047/175] [Misc] Add type assertion of request_id for LLMEngine.add_request (#19700) Signed-off-by: n2ptr --- tests/mq_llm_engine/test_error_handling.py | 10 +++++----- vllm/engine/llm_engine.py | 4 ++++ vllm/v1/engine/llm_engine.py | 5 +++++ 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/tests/mq_llm_engine/test_error_handling.py b/tests/mq_llm_engine/test_error_handling.py index 49b02279d61bb..3feee01dadf73 100644 --- a/tests/mq_llm_engine/test_error_handling.py +++ b/tests/mq_llm_engine/test_error_handling.py @@ -66,7 +66,7 @@ async def test_evil_forward(tmp_socket): with pytest.raises(MQEngineDeadError): async for _ in client.generate(prompt="Hello my name is", sampling_params=SamplingParams(), - request_id=uuid.uuid4()): + request_id=str(uuid.uuid4())): pass assert client.errored @@ -115,7 +115,7 @@ async def test_failed_health_check(tmp_socket): with pytest.raises(MQEngineDeadError): async for _ in client.generate(prompt="Hello my name is", sampling_params=SamplingParams(), - request_id=uuid.uuid4()): + request_id=str(uuid.uuid4())): pass client.close() @@ -157,7 +157,7 @@ async def test_failed_abort(tmp_socket): async for _ in client.generate( prompt="Hello my name is", sampling_params=SamplingParams(max_tokens=10), - request_id=uuid.uuid4()): + request_id=str(uuid.uuid4())): pass assert "KeyError" in repr(execinfo.value) assert client.errored @@ -189,7 +189,7 @@ async def test_batch_error(tmp_socket): params = SamplingParams(min_tokens=2048, max_tokens=2048) async for _ in client.generate(prompt="Hello my name is", sampling_params=params, - request_id=uuid.uuid4()): + request_id=str(uuid.uuid4())): pass tasks = [asyncio.create_task(do_generate(client)) for _ in range(10)] @@ -289,7 +289,7 @@ async def test_engine_process_death(tmp_socket): with pytest.raises(MQEngineDeadError): async for _ in client.generate(prompt="Hello my name is", sampling_params=SamplingParams(), - request_id=uuid.uuid4()): + request_id=str(uuid.uuid4())): pass # And the health check should show the engine is dead diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 8fccf9bd2aa00..25fa1c3058bef 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -687,6 +687,10 @@ class LLMEngine: >>> # continue the request processing >>> ... """ + if not isinstance(request_id, str): + raise TypeError( + f"request_id must be a string, got {type(request_id)}") + if lora_request is not None and not self.lora_config: raise ValueError(f"Got lora_request {lora_request} but LoRA is " "not enabled!") diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 25fab27131142..a2328c37ba0c5 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -192,6 +192,11 @@ class LLMEngine: prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, ) -> None: + # Validate the request_id type. + if not isinstance(request_id, str): + raise TypeError( + f"request_id must be a string, got {type(request_id)}") + # Process raw inputs into the request. prompt_str, request = self.processor.process_inputs( request_id, prompt, params, arrival_time, lora_request, From a29e62ea3452bc6b1d4f3eeac2dc9a6b30357c4d Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Sat, 28 Jun 2025 14:48:13 +0900 Subject: [PATCH 048/175] Fix num_token_padding support for static per-tensor scaled_fp8_quant (#20188) Signed-off-by: mgoin --- vllm/_custom_ops.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 215f35bad34d9..51900de1cc099 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1274,8 +1274,7 @@ def scaled_fp8_quant( scale = torch.zeros(1, device=input.device, dtype=torch.float32) torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) else: - # num_token_padding not implemented for this case - assert (scale.numel() == 1 and num_token_padding is None) + assert scale.numel() == 1 torch.ops._C.static_scaled_fp8_quant(output, input, scale) return output, scale From d45417b804aaf7f90c9ae70a32f8f07d6b371a8c Mon Sep 17 00:00:00 2001 From: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Date: Sat, 28 Jun 2025 01:50:00 -0400 Subject: [PATCH 049/175] fix ci issue distributed 4 gpu test (#20204) Signed-off-by: yewentao256 --- examples/offline_inference/data_parallel.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/examples/offline_inference/data_parallel.py b/examples/offline_inference/data_parallel.py index 3eccb4e11ab6f..dbf8ed58cc477 100644 --- a/examples/offline_inference/data_parallel.py +++ b/examples/offline_inference/data_parallel.py @@ -64,6 +64,18 @@ def parse_args(): parser.add_argument( "--trust-remote-code", action="store_true", help="Trust remote code." ) + parser.add_argument( + "--max-num-seqs", + type=int, + default=64, + help=("Maximum number of sequences to be processed in a single iteration."), + ) + parser.add_argument( + "--gpu-memory-utilization", + type=float, + default=0.8, + help=("Fraction of GPU memory vLLM is allowed to allocate (0.0, 1.0]."), + ) return parser.parse_args() @@ -77,6 +89,8 @@ def main( GPUs_per_dp_rank, enforce_eager, trust_remote_code, + max_num_seqs, + gpu_memory_utilization, ): os.environ["VLLM_DP_RANK"] = str(global_dp_rank) os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank) @@ -127,6 +141,8 @@ def main( enforce_eager=enforce_eager, enable_expert_parallel=True, trust_remote_code=trust_remote_code, + max_num_seqs=max_num_seqs, + gpu_memory_utilization=gpu_memory_utilization, ) outputs = llm.generate(prompts, sampling_params) # Print the outputs. @@ -181,6 +197,8 @@ if __name__ == "__main__": tp_size, args.enforce_eager, args.trust_remote_code, + args.max_num_seqs, + args.gpu_memory_utilization, ), ) proc.start() From f719772281c76fee6a8641647aa40fbab8a0f3a4 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Sat, 28 Jun 2025 14:50:52 +0900 Subject: [PATCH 050/175] [Bugfix] Properly reject requests with empty list guided_choice (#20195) Signed-off-by: mgoin --- vllm/v1/engine/processor.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index a0b170ba55ad7..7e7703df2cf10 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -173,6 +173,12 @@ class Processor: params.guided_decoding.backend = engine_level_backend # Request content validation + if (isinstance(params.guided_decoding.choice, list) + and not params.guided_decoding.choice): + # It is invalid for choice to be an empty list + raise ValueError(f"Choice '{params.guided_decoding.choice}' " + "cannot be an empty list") + if engine_level_backend.startswith("xgrammar"): # xgrammar with no fallback validate_xgrammar_grammar(params) From 7b460c25f987dfb6ca7cebc7d1e2c26989801674 Mon Sep 17 00:00:00 2001 From: Jiayi Yan <66017932+1195343015@users.noreply.github.com> Date: Sat, 28 Jun 2025 13:51:16 +0800 Subject: [PATCH 051/175] [BugFix] Fix the incorrect func name in the comments. (config.py) (#20185) --- vllm/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index 84aa14b7c8605..57b9df2364775 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1484,7 +1484,7 @@ class CacheConfig: sizes up to 32 are supported. On HPU devices, block size defaults to 128. This config has no static default. If left unspecified by the user, it will - be set in `Platform.check_and_update_configs()` based on the current + be set in `Platform.check_and_update_config()` based on the current platform.""" gpu_memory_utilization: float = 0.9 """The fraction of GPU memory to be used for the model executor, which can From 8615d9776fb066e3b11284bb9e6871e5d8820463 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Sat, 28 Jun 2025 08:00:25 +0200 Subject: [PATCH 052/175] [CI/Build] Add new CI job to validate Hybrid Models for every PR (#20147) Signed-off-by: Thomas Parnell --- .buildkite/test-pipeline.yaml | 13 ++++++++++++- pyproject.toml | 1 + tests/models/language/generation/test_hybrid.py | 3 +++ 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 7f1841b1c97cd..a13e2cb782182 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -536,6 +536,17 @@ steps: - pip freeze | grep -E 'torch' - pytest -v -s models/language -m core_model +- label: Language Models Test (Hybrid) # 35 min + mirror_hardwares: [amdexperimental] + torch_nightly: true + source_file_dependencies: + - vllm/ + - tests/models/language/generation + commands: + # Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile. + - pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8' + - pytest -v -s models/language/generation -m hybrid_model + - label: Language Models Test (Extended Generation) # 1hr20min mirror_hardwares: [amdexperimental] optional: true @@ -545,7 +556,7 @@ steps: commands: # Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile. - pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8' - - pytest -v -s models/language/generation -m 'not core_model' + - pytest -v -s models/language/generation -m '(not core_model) and (not hybrid_model)' - label: Language Models Test (Extended Pooling) # 36min mirror_hardwares: [amdexperimental] diff --git a/pyproject.toml b/pyproject.toml index e8c2403af064f..fb45572d265b7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -150,6 +150,7 @@ skip_gitignore = true markers = [ "skip_global_cleanup", "core_model: enable this model test in each PR instead of only nightly", + "hybrid_model: models that contain mamba layers (including pure SSM and hybrid architectures)", "cpu_model: enable this model test in CPU tests", "split: run this test as part of a split", "distributed: run this test only in distributed GPU tests", diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index 90c4cd968e7a2..b2348e6449339 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -9,6 +9,9 @@ from vllm.sampling_params import SamplingParams from ...utils import check_logprobs_close, check_outputs_equal +# Mark all tests as hybrid +pytestmark = pytest.mark.hybrid_model + # NOTE: The first model in each list is taken as the primary model, # meaning that it will be used in all tests in this file # The rest of the models will only be tested by test_models From daceac57c7d79c4736d64621610076b3a98b0209 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicol=C3=B2=20Lucchesi?= Date: Sat, 28 Jun 2025 17:15:26 +0200 Subject: [PATCH 053/175] [Frontend] Generalize `v1/audio/transcriptions` endpoint (#20179) Signed-off-by: NickLucche --- vllm/entrypoints/openai/speech_to_text.py | 142 +++------------------- vllm/model_executor/models/interfaces.py | 11 ++ vllm/model_executor/models/whisper.py | 129 ++++++++++++++++++++ 3 files changed, 154 insertions(+), 128 deletions(-) diff --git a/vllm/entrypoints/openai/speech_to_text.py b/vllm/entrypoints/openai/speech_to_text.py index b23cf6cab0979..6c16e53245314 100644 --- a/vllm/entrypoints/openai/speech_to_text.py +++ b/vllm/entrypoints/openai/speech_to_text.py @@ -24,6 +24,7 @@ from vllm.entrypoints.openai.serving_engine import (OpenAIServing, from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.inputs.data import PromptType from vllm.logger import init_logger +from vllm.model_executor.model_loader.utils import get_model_architecture from vllm.outputs import RequestOutput from vllm.transformers_utils.processor import cached_get_processor from vllm.utils import PlaceholderModule @@ -38,118 +39,10 @@ T = TypeVar("T", bound=SpeechToTextResponse) logger = init_logger(__name__) -# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages -# TODO these configs should live somewhere with the model so we can support -# additional ones - -ISO639_1_SUPPORTED_LANGS = { - "af": "Afrikaans", - "ar": "Arabic", - "hy": "Armenian", - "az": "Azerbaijani", - "be": "Belarusian", - "bs": "Bosnian", - "bg": "Bulgarian", - "ca": "Catalan", - "zh": "Chinese", - "hr": "Croatian", - "cs": "Czech", - "da": "Danish", - "nl": "Dutch", - "en": "English", - "et": "Estonian", - "fi": "Finnish", - "fr": "French", - "gl": "Galician", - "de": "German", - "el": "Greek", - "he": "Hebrew", - "hi": "Hindi", - "hu": "Hungarian", - "is": "Icelandic", - "id": "Indonesian", - "it": "Italian", - "ja": "Japanese", - "kn": "Kannada", - "kk": "Kazakh", - "ko": "Korean", - "lv": "Latvian", - "lt": "Lithuanian", - "mk": "Macedonian", - "ms": "Malay", - "mr": "Marathi", - "mi": "Maori", - "ne": "Nepali", - "no": "Norwegian", - "fa": "Persian", - "pl": "Polish", - "pt": "Portuguese", - "ro": "Romanian", - "ru": "Russian", - "sr": "Serbian", - "sk": "Slovak", - "sl": "Slovenian", - "es": "Spanish", - "sw": "Swahili", - "sv": "Swedish", - "tl": "Tagalog", - "ta": "Tamil", - "th": "Thai", - "tr": "Turkish", - "uk": "Ukrainian", - "ur": "Urdu", - "vi": "Vietnamese", - "cy": "Welsh" -} -ISO639_1_OTHER_LANGS = { - "lo": "Lao", - "jw": "Javanese", - "tk": "Turkmen", - "yi": "Yiddish", - "so": "Somali", - "bn": "Bengali", - "nn": "Norwegian Nynorsk", - "si": "Sinhala", - "yo": "Yoruba", - "sa": "Sanskrit", - "mi": "Māori", - "fo": "Faroese", # codespell:ignore - "mt": "Maltese", - "tg": "Tajik", - "mg": "Malagasy", - "haw": "Hawaiian", - "km": "Khmer", - "br": "Breton", - "ps": "Pashto", - "ln": "Lingala", - "la": "Latin", - "ml": "Malayalam", - "sq": "Albanian", - "su": "Sundanese", - "eu": "Basque", - "ka": "Georgian", - "uz": "Uzbek", - "sn": "Shona", - "ht": "Haitian", - "as": "Assamese", - "mn": "Mongolian", - "te": "Telugu", - "pa": "Panjabi", - "tt": "Tatar", - "gu": "Gujarati", - "oc": "Occitan", - "ha": "Hausa", - "ba": "Bashkir", - "my": "Burmese", - "sd": "Sindhi", - "am": "Amharic", - "lb": "Luxembourgish", - "bo": "Tibetan" -} - # As per https://platform.openai.com/docs/guides/speech-to-text#overview. # TODO configurable MAX_AUDIO_CLIP_FILESIZE_MB = 25 +MAX_AUDIO_CLIP_SECONDS = 30 OVERLAP_CHUNK_SECOND = 1 MIN_ENERGY_WINDOW_SIZE = 1600 # 1600 ~ 100ms for 16000 Hz audio @@ -177,10 +70,13 @@ class OpenAISpeechToText(OpenAIServing): self.default_sampling_params = ( self.model_config.get_diff_sampling_param()) processor = cached_get_processor(model_config.model) - self.max_audio_clip_s = processor.feature_extractor.chunk_length + self.max_audio_clip_s = processor.feature_extractor.chunk_length \ + if hasattr(processor.feature_extractor, 'chunk_length') \ + else MAX_AUDIO_CLIP_SECONDS self.model_sr = processor.feature_extractor.sampling_rate self.hop_length = processor.feature_extractor.hop_length self.task_type = task_type + self.model_cls, _ = get_model_architecture(model_config) if self.default_sampling_params: logger.info( @@ -196,21 +92,8 @@ class OpenAISpeechToText(OpenAIServing): # TODO language should be optional and can be guessed. # For now we default to en. See # https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520 - lang_token = f"<|{request.language}|>" if request.language else "<|en|>" - if request.language: - if request.language in ISO639_1_SUPPORTED_LANGS: - pass - elif request.language in ISO639_1_OTHER_LANGS: - logger.warning( - "The selected language %s has limited accuracy with" - " reported WER>=0.5. Results may be less accurate " - "for this choice.", request.language) - else: - raise ValueError( - f"Unsupported language: {request.language}." - "Language should be one of:" + - f" {list(ISO639_1_SUPPORTED_LANGS.values())}" + - f"or {list(ISO639_1_OTHER_LANGS.values())}") + lang = request.language or "en" + self.model_cls.validate_language(lang) # type: ignore[attr-defined] if len(audio_data) / 1024**2 > MAX_AUDIO_CLIP_FILESIZE_MB: raise ValueError("Maximum file size exceeded.") @@ -221,7 +104,9 @@ class OpenAISpeechToText(OpenAIServing): y, sr = librosa.load(bytes_, sr=self.model_sr) duration = librosa.get_duration(y=y, sr=sr) - chunks = [y] if duration < 30 else self._split_audio(y, int(sr)) + chunks = [y + ] if duration < self.max_audio_clip_s else self._split_audio( + y, int(sr)) prompts = [] for chunk in chunks: prompt = { @@ -232,8 +117,9 @@ class OpenAISpeechToText(OpenAIServing): }, }, "decoder_prompt": - (f"<|startoftranscript|>{lang_token}" - f"<|{self.task_type}|><|notimestamps|>{request.prompt}") + self.model_cls. + get_decoder_prompt( # type: ignore[attr-defined] + lang, self.task_type, request.prompt) } prompts.append(cast(PromptType, prompt)) return prompts, duration diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 3ea424e44b62e..ad59fe79edcb1 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -599,6 +599,17 @@ class SupportsTranscription(Protocol): supports_transcription: ClassVar[Literal[True]] = True + @classmethod + def get_decoder_prompt(cls, language: str, task_type: str, + prompt: str) -> str: + """Get the decoder prompt for the ASR model.""" + ... + + @classmethod + def validate_language(cls, language: str) -> bool: + """Check if the model supports a specific ISO639_1 language.""" + ... + @overload def supports_transcription( diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 8cf2a009d6670..5a0094fa749fd 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -41,6 +41,113 @@ from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors, logger = init_logger(__name__) +# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages + +ISO639_1_SUPPORTED_LANGS = { + "af": "Afrikaans", + "ar": "Arabic", + "hy": "Armenian", + "az": "Azerbaijani", + "be": "Belarusian", + "bs": "Bosnian", + "bg": "Bulgarian", + "ca": "Catalan", + "zh": "Chinese", + "hr": "Croatian", + "cs": "Czech", + "da": "Danish", + "nl": "Dutch", + "en": "English", + "et": "Estonian", + "fi": "Finnish", + "fr": "French", + "gl": "Galician", + "de": "German", + "el": "Greek", + "he": "Hebrew", + "hi": "Hindi", + "hu": "Hungarian", + "is": "Icelandic", + "id": "Indonesian", + "it": "Italian", + "ja": "Japanese", + "kn": "Kannada", + "kk": "Kazakh", + "ko": "Korean", + "lv": "Latvian", + "lt": "Lithuanian", + "mk": "Macedonian", + "ms": "Malay", + "mr": "Marathi", + "mi": "Maori", + "ne": "Nepali", + "no": "Norwegian", + "fa": "Persian", + "pl": "Polish", + "pt": "Portuguese", + "ro": "Romanian", + "ru": "Russian", + "sr": "Serbian", + "sk": "Slovak", + "sl": "Slovenian", + "es": "Spanish", + "sw": "Swahili", + "sv": "Swedish", + "tl": "Tagalog", + "ta": "Tamil", + "th": "Thai", + "tr": "Turkish", + "uk": "Ukrainian", + "ur": "Urdu", + "vi": "Vietnamese", + "cy": "Welsh" +} +ISO639_1_OTHER_LANGS = { + "lo": "Lao", + "jw": "Javanese", + "tk": "Turkmen", + "yi": "Yiddish", + "so": "Somali", + "bn": "Bengali", + "nn": "Norwegian Nynorsk", + "si": "Sinhala", + "yo": "Yoruba", + "sa": "Sanskrit", + "mi": "Māori", + "fo": "Faroese", # codespell:ignore + "mt": "Maltese", + "tg": "Tajik", + "mg": "Malagasy", + "haw": "Hawaiian", + "km": "Khmer", + "br": "Breton", + "ps": "Pashto", + "ln": "Lingala", + "la": "Latin", + "ml": "Malayalam", + "sq": "Albanian", + "su": "Sundanese", + "eu": "Basque", + "ka": "Georgian", + "uz": "Uzbek", + "sn": "Shona", + "ht": "Haitian", + "as": "Assamese", + "mn": "Mongolian", + "te": "Telugu", + "pa": "Panjabi", + "tt": "Tatar", + "gu": "Gujarati", + "oc": "Occitan", + "ha": "Hausa", + "ba": "Bashkir", + "my": "Burmese", + "sd": "Sindhi", + "am": "Amharic", + "lb": "Luxembourgish", + "bo": "Tibetan" +} + class WhisperAudioInputs(TypedDict): input_features: NestedTensors @@ -731,6 +838,28 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, weights = _create_fake_bias_for_k_proj(weights) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + @classmethod + def validate_language(cls, language: str) -> bool: + if language in ISO639_1_SUPPORTED_LANGS: + return True + elif language in ISO639_1_OTHER_LANGS: + logger.warning( + "The selected language %s has limited accuracy with" + " reported WER>=0.5. Results may be less accurate " + "for this choice.", language) + return True + else: + raise ValueError(f"Unsupported language: {language}." + "Language should be one of:" + + f" {list(ISO639_1_SUPPORTED_LANGS.values())}" + + f"or {list(ISO639_1_OTHER_LANGS.values())}") + + @classmethod + def get_decoder_prompt(cls, language: str, task_type: str, + prompt: str) -> str: + return (f"<|startoftranscript|><|{language}|><|{task_type}|>" + f"<|notimestamps|>{prompt}") + def _create_fake_bias_for_k_proj( weights: Iterable[tuple[str, torch.Tensor]] From daec9dea6e9bb95cb12faa38e347196de4f672a0 Mon Sep 17 00:00:00 2001 From: Stan Wozniak <77159600+s3woz@users.noreply.github.com> Date: Sat, 28 Jun 2025 17:16:41 +0200 Subject: [PATCH 054/175] [Bugfix] Correct behavior of GraniteMoeHybrid for TensorParallel execution (#20137) Signed-off-by: Stanislaw Wozniak --- .../generation/test_granitemoehybrid.py | 42 ------- .../models/language/generation/test_hybrid.py | 5 +- .../model_executor/models/granitemoehybrid.py | 104 ++++++++++++------ 3 files changed, 73 insertions(+), 78 deletions(-) delete mode 100644 tests/models/language/generation/test_granitemoehybrid.py diff --git a/tests/models/language/generation/test_granitemoehybrid.py b/tests/models/language/generation/test_granitemoehybrid.py deleted file mode 100644 index 952449f284159..0000000000000 --- a/tests/models/language/generation/test_granitemoehybrid.py +++ /dev/null @@ -1,42 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest - -from ...utils import check_logprobs_close - -# Path of the checkpoints -MODELS = [ - "ibm-granite/granite-4.0-tiny-preview", -] - - -@pytest.mark.skip( - reason="Granite 4.0 is not yet available in huggingface transformers") -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) -@pytest.mark.parametrize("max_tokens", [64]) -@pytest.mark.parametrize("num_logprobs", [5]) -def test_model_equivalence_to_hf_greedy( - hf_runner, - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, - num_logprobs: int, -): - with vllm_runner(model, dtype=dtype) as vllm_model: - vllm_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) - - with hf_runner(model, dtype=dtype) as hf_model: - hf_outputs = hf_model.generate_greedy_logprobs_limit( - example_prompts, max_tokens, num_logprobs) - - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - ) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index b2348e6449339..e6dd6c35e64d6 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -28,8 +28,9 @@ SSM_MODELS = [ HYBRID_MODELS = [ "ai21labs/Jamba-tiny-dev", - # NOTE: ibm-granite/granite-4.0-tiny-preview are skipped currently as - # it is not yet available in huggingface transformers + # NOTE: Currently the test failes due to HF transformers issue fixed in: + # https://github.com/huggingface/transformers/pull/39033 + # We will enable vLLM test for Granite after next HF transformers release. # "ibm-granite/granite-4.0-tiny-preview", # NOTE: Running Plamo2 in transformers implementation requires to install # causal-conv1d package, which is not listed as a test dependency as it's diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py index 26b5b3ac15345..33e8626209d50 100644 --- a/vllm/model_executor/models/granitemoehybrid.py +++ b/vllm/model_executor/models/granitemoehybrid.py @@ -15,7 +15,8 @@ from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import get_forward_context from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.layers.linear import (QKVParallelLinear, + RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba2_metadata import ( Mamba2Metadata, prepare_mamba2_metadata) @@ -36,8 +37,9 @@ from .granitemoe import GraniteMoeMoE from .granitemoeshared import GraniteMoeSharedMLP from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, SupportsQuant, SupportsV0Only) -from .utils import (AutoWeightsLoader, make_empty_intermediate_tensors_factory, - make_layers, maybe_prefix) +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) class GraniteMoeHybridMambaDecoderLayer(nn.Module): @@ -220,35 +222,37 @@ class GraniteMoeHybridAttention(nn.Module): self.hidden_size = config.hidden_size self.attention_bias = config.attention_bias self.attention_multiplier = config.attention_multiplier - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads + self.total_num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.total_num_heads + self.total_num_kv_heads = config.num_key_value_heads - self.q_proj = ReplicatedLinear(self.hidden_size, - self.num_heads * self.head_dim, - bias=self.attention_bias, - quant_config=quant_config, - prefix=f"{prefix}.q_proj") + # TensorParallel logic + tp_size = get_tensor_model_parallel_world_size() + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_key_value_heads = max(1, self.total_num_kv_heads // tp_size) - self.k_proj = ReplicatedLinear(self.hidden_size, - self.num_key_value_heads * - self.head_dim, - bias=self.attention_bias, - quant_config=quant_config, - prefix=f"{prefix}.k_proj") + self.qkv_proj = QKVParallelLinear(self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=self.attention_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj") - self.v_proj = ReplicatedLinear(self.hidden_size, - self.num_key_value_heads * - self.head_dim, - bias=self.attention_bias, - quant_config=quant_config, - prefix=f"{prefix}.v_proj") - - self.o_proj = ReplicatedLinear(self.hidden_size, - self.hidden_size, - bias=self.attention_bias, - quant_config=quant_config, - prefix=f"{prefix}.o_proj") + self.o_proj = RowParallelLinear(self.hidden_size, + self.hidden_size, + bias=self.attention_bias, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") if config.position_embedding_type == "rope": self.rotary_emb = get_rope( @@ -278,9 +282,12 @@ class GraniteMoeHybridAttention(nn.Module): hidden_states: torch.Tensor, ) -> torch.Tensor: - query = self.q_proj(hidden_states)[0] - key = self.k_proj(hidden_states)[0] - value = self.v_proj(hidden_states)[0] + qkv, _ = self.qkv_proj(hidden_states) + query, key, value = qkv.split([ + self.num_heads * self.head_dim, self.num_key_value_heads * + self.head_dim, self.num_key_value_heads * self.head_dim + ], + dim=-1) if self.rotary_emb is not None: query, key = self.rotary_emb(positions, query, key) @@ -401,6 +408,12 @@ class GraniteMoeHybridModel(nn.Module): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + ] params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() @@ -411,6 +424,15 @@ class GraniteMoeHybridModel(nn.Module): weight_loader(param, p) loaded_params.add(n) + def _load_shard(n, p, shard_id): + # Skip layers on other devices. + if not is_pp_missing_parameter(n, self): + param = params_dict[n] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, p, shard_id) + loaded_params.add(n) + def _load_expert(n, p, name, shard_id, expert_id): param = params_dict[n] weight_loader = getattr(param, "weight_loader", @@ -465,7 +487,15 @@ class GraniteMoeHybridModel(nn.Module): ".block_sparse_moe.gate.weight") _load(gate_name, p) else: - _load(n, p) + loaded = False + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name in n: + _load_shard(n.replace(weight_name, param_name), + p, + shard_id=shard_id) + loaded = True + if not loaded: + _load(n, p) return loaded_params @@ -473,7 +503,13 @@ class GraniteMoeHybridModel(nn.Module): class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid, SupportsV0Only, SupportsQuant): - packed_modules_mapping = {} + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + } embedding_modules = { "embed_tokens": "input_embeddings", "lm_head": "output_embeddings", From 4d366936875330526908185ac93ed1e0e0eb7f40 Mon Sep 17 00:00:00 2001 From: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Date: Sat, 28 Jun 2025 18:06:38 -0400 Subject: [PATCH 055/175] [Refactor] Create a function util and cache the results for `has_deepgemm`, `has_deepep`, `has_pplx` (#20187) Signed-off-by: yewentao256 --- tests/kernels/moe/test_deepep_deepgemm_moe.py | 20 +++++-------- tests/kernels/moe/test_deepep_moe.py | 8 ++---- .../device_communicators/all2all.py | 10 +++---- .../layers/fused_moe/batched_deep_gemm_moe.py | 3 -- .../layers/fused_moe/deep_gemm_moe.py | 12 ++++---- vllm/model_executor/layers/fused_moe/layer.py | 10 ++----- vllm/model_executor/layers/fused_moe/utils.py | 2 +- .../compressed_tensors_moe.py | 8 ++---- .../layers/quantization/deepgemm.py | 6 ++-- .../model_executor/layers/quantization/fp8.py | 6 ++-- .../layers/quantization/utils/fp8_utils.py | 6 ++-- vllm/utils.py | 28 +++++++++++++++++++ 12 files changed, 61 insertions(+), 58 deletions(-) diff --git a/tests/kernels/moe/test_deepep_deepgemm_moe.py b/tests/kernels/moe/test_deepep_deepgemm_moe.py index f580dee4c9285..475427f439289 100644 --- a/tests/kernels/moe/test_deepep_deepgemm_moe.py +++ b/tests/kernels/moe/test_deepep_deepgemm_moe.py @@ -6,7 +6,6 @@ fp8 block-quantized case. """ import dataclasses -import importlib from typing import Optional import pytest @@ -21,18 +20,11 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8) from vllm.platforms import current_platform +from vllm.utils import has_deep_ep, has_deep_gemm from .utils import ProcessGroupInfo, parallel_launch -has_deep_ep = importlib.util.find_spec("deep_ep") is not None - -try: - import deep_gemm - has_deep_gemm = True -except ImportError: - has_deep_gemm = False - -if has_deep_ep: +if has_deep_ep(): from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501 DeepEPHTPrepareAndFinalize) from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501 @@ -40,19 +32,21 @@ if has_deep_ep: from .deepep_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a -if has_deep_gemm: +if has_deep_gemm(): + import deep_gemm + from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( BatchedDeepGemmExperts) from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( DeepGemmExperts) requires_deep_ep = pytest.mark.skipif( - not has_deep_ep, + not has_deep_ep(), reason="Requires deep_ep kernels", ) requires_deep_gemm = pytest.mark.skipif( - not has_deep_gemm, + not has_deep_gemm(), reason="Requires deep_gemm kernels", ) diff --git a/tests/kernels/moe/test_deepep_moe.py b/tests/kernels/moe/test_deepep_moe.py index 380eb43c42a40..80a36dc39712a 100644 --- a/tests/kernels/moe/test_deepep_moe.py +++ b/tests/kernels/moe/test_deepep_moe.py @@ -4,7 +4,6 @@ Test deepep dispatch-combine logic """ import dataclasses -import importlib from typing import Optional, Union import pytest @@ -22,12 +21,11 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8) from vllm.platforms import current_platform +from vllm.utils import has_deep_ep from .utils import ProcessGroupInfo, parallel_launch -has_deep_ep = importlib.util.find_spec("deep_ep") is not None - -if has_deep_ep: +if has_deep_ep(): from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501 DeepEPHTPrepareAndFinalize) from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501 @@ -36,7 +34,7 @@ if has_deep_ep: from .deepep_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a requires_deep_ep = pytest.mark.skipif( - not has_deep_ep, + not has_deep_ep(), reason="Requires deep_ep kernels", ) diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 35f2fd0ba9e22..85f87cb21edcd 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import importlib.util from typing import TYPE_CHECKING, Any import torch @@ -8,6 +7,7 @@ import torch.distributed as dist from vllm.forward_context import get_forward_context from vllm.logger import init_logger +from vllm.utils import has_deep_ep, has_pplx from .base_device_communicator import All2AllManagerBase, Cache @@ -80,8 +80,8 @@ class PPLXAll2AllManager(All2AllManagerBase): """ def __init__(self, cpu_group): - has_pplx = importlib.util.find_spec("pplx_kernels") is not None - assert has_pplx, "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels." # noqa + assert has_pplx( + ), "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels." # noqa super().__init__(cpu_group) if self.internode: @@ -133,8 +133,8 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase): """ def __init__(self, cpu_group): - has_deepep = importlib.util.find_spec("deep_ep") is not None - assert has_deepep, "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install DeepEP kernels." # noqa + assert has_deep_ep( + ), "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install DeepEP kernels." # noqa super().__init__(cpu_group) self.handle_cache = Cache() diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index 70836879d17c0..b54ac80535a42 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -1,5 +1,4 @@ # SPDX-License-Identifier: Apache-2.0 -import importlib.util from typing import Optional import torch @@ -11,8 +10,6 @@ from vllm.triton_utils import tl, triton logger = init_logger(__name__) -has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None - @triton.jit def _silu_mul_fp8_quant_deep_gemm( diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 050d9520ca013..321fb0351ad93 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import functools -import importlib.util from typing import Optional import torch @@ -12,14 +11,13 @@ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( _moe_permute) from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNoEP) -from vllm.model_executor.layers.fused_moe.utils import ( - _resize_cache, per_token_group_quant_fp8) -from vllm.utils import round_up +from vllm.model_executor.layers.fused_moe.utils import _resize_cache +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8) +from vllm.utils import has_deep_gemm, round_up logger = init_logger(__name__) -has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None - @functools.cache def deep_gemm_block_shape() -> list[int]: @@ -41,7 +39,7 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor, gemm kernel. All of M, N, K and the quantization block_shape must be aligned by `dg.get_m_alignment_for_contiguous_layout()`. """ - if not has_deep_gemm: + if not has_deep_gemm(): logger.debug("DeepGemm disabled: deep_gemm not available.") return False diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 672244385e52c..5408ef1f75e89 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import importlib from abc import abstractmethod from collections.abc import Iterable from dataclasses import dataclass @@ -32,10 +31,7 @@ from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum -from vllm.utils import direct_register_custom_op - -has_pplx = importlib.util.find_spec("pplx_kernels") is not None -has_deepep = importlib.util.find_spec("deep_ep") is not None +from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx if current_platform.is_cuda_alike(): from .fused_batched_moe import BatchedTritonExperts @@ -43,9 +39,9 @@ if current_platform.is_cuda_alike(): from .modular_kernel import (FusedMoEModularKernel, FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize) - if has_pplx: + if has_pplx(): from .pplx_prepare_finalize import PplxPrepareAndFinalize - if has_deepep: + if has_deep_ep(): from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SIZE, DeepEPLLPrepareAndFinalize) diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index 8f3191db680fd..4c91e697f8e97 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -104,4 +104,4 @@ def find_free_port(): with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: s.bind(('', 0)) s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - return s.getsockname()[1] \ No newline at end of file + return s.getsockname()[1] diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 7703b9e687c4a..4a19473005708 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import enum -import importlib from enum import Enum from typing import Callable, Optional @@ -29,13 +28,12 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.scalar_type import scalar_types - -has_pplx = importlib.util.find_spec("pplx_kernels") is not None +from vllm.utils import has_pplx if current_platform.is_cuda_alike(): from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( BatchedPrepareAndFinalize) - if has_pplx: + if has_pplx(): from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( PplxPrepareAndFinalize) @@ -577,7 +575,7 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod): use_batched_format=True, ) - if has_pplx and isinstance( + if has_pplx() and isinstance( prepare_finalize, (BatchedPrepareAndFinalize, PplxPrepareAndFinalize)): # no expert_map support in this case diff --git a/vllm/model_executor/layers/quantization/deepgemm.py b/vllm/model_executor/layers/quantization/deepgemm.py index 1d40f4915a1be..e4cf647407584 100644 --- a/vllm/model_executor/layers/quantization/deepgemm.py +++ b/vllm/model_executor/layers/quantization/deepgemm.py @@ -1,15 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 -import importlib.util import logging import torch from vllm.platforms import current_platform from vllm.triton_utils import triton -from vllm.utils import direct_register_custom_op +from vllm.utils import direct_register_custom_op, has_deep_gemm -has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None -if has_deep_gemm: +if has_deep_gemm(): import deep_gemm logger = logging.getLogger(__name__) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index d2eda541f7a40..93472207fbb86 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import functools -import importlib.util from typing import Any, Callable, Optional, Union import torch @@ -38,13 +37,12 @@ from vllm.model_executor.parameter import (BlockQuantScaleParameter, from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.scalar_type import scalar_types +from vllm.utils import has_deep_gemm ACTIVATION_SCHEMES = ["static", "dynamic"] logger = init_logger(__name__) -has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None - def _is_col_major(x: torch.Tensor) -> bool: assert x.dim() == 3 @@ -451,7 +449,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): # Check for DeepGemm support. self.allow_deep_gemm = False if envs.VLLM_USE_DEEP_GEMM: - if not has_deep_gemm: + if not has_deep_gemm(): logger.warning_once("Failed to import DeepGemm kernels.") elif not self.block_quant: logger.warning_once("Model is not block quantized. Not using " diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 3a0fb83d627af..c38a445c571b8 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -3,7 +3,6 @@ # Adapted from https://github.com/sgl-project/sglang/pull/2575 import functools -import importlib.util import json import os from typing import Any, Callable, Optional, Union @@ -19,10 +18,9 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( CUTLASS_BLOCK_FP8_SUPPORTED) from vllm.platforms import current_platform from vllm.triton_utils import tl, triton -from vllm.utils import cdiv, direct_register_custom_op +from vllm.utils import cdiv, direct_register_custom_op, has_deep_gemm logger = init_logger(__name__) -has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool: @@ -109,7 +107,7 @@ def should_use_deepgemm(output_dtype: torch.dtype, weight: torch.Tensor): """ return (current_platform.is_cuda() - and current_platform.is_device_capability(90) and has_deep_gemm + and current_platform.is_device_capability(90) and has_deep_gemm() and envs.VLLM_USE_DEEP_GEMM and output_dtype == torch.bfloat16 and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0) diff --git a/vllm/utils.py b/vllm/utils.py index fdefda901c4d2..7eb3c1e347cde 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -2929,3 +2929,31 @@ def is_torch_equal_or_newer(target: str) -> bool: def _is_torch_equal_or_newer(torch_version: str, target: str) -> bool: torch_version = version.parse(torch_version) return torch_version >= version.parse(target) + + +@cache +def _has_module(module_name: str) -> bool: + """Return True if *module_name* can be found in the current environment. + + The result is cached so that subsequent queries for the same module incur + no additional overhead. + """ + return importlib.util.find_spec(module_name) is not None + + +def has_pplx() -> bool: + """Whether the optional `pplx_kernels` package is available.""" + + return _has_module("pplx_kernels") + + +def has_deep_ep() -> bool: + """Whether the optional `deep_ep` package is available.""" + + return _has_module("deep_ep") + + +def has_deep_gemm() -> bool: + """Whether the optional `deep_gemm` package is available.""" + + return _has_module("deep_gemm") \ No newline at end of file From 7b1895e6ce4942091e16da790af8c12772a1d384 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Sun, 29 Jun 2025 11:31:37 +0900 Subject: [PATCH 056/175] [CI Fix] Try fixing eagle e2e test OOM by reducing block allocation (#20213) Signed-off-by: mgoin --- tests/spec_decode/e2e/test_eagle_correctness.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/spec_decode/e2e/test_eagle_correctness.py b/tests/spec_decode/e2e/test_eagle_correctness.py index fd838285aba7c..7c369feec4152 100644 --- a/tests/spec_decode/e2e/test_eagle_correctness.py +++ b/tests/spec_decode/e2e/test_eagle_correctness.py @@ -370,6 +370,10 @@ def test_llama2_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize( "common_llm_kwargs", [{ + # 2 for small prompt, 256//16 for generated. + "num_gpu_blocks_override": 2 + 256 // 16, + "max_model_len": (2 + 256 // 16) * 16, + # Skip cuda graph recording for fast test. "enforce_eager": True, @@ -420,6 +424,10 @@ def test_llama3_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize( "common_llm_kwargs", [{ + # 2 for small prompt, 256//16 for generated. + "num_gpu_blocks_override": 2 + 256 // 16, + "max_model_len": (2 + 256 // 16) * 16, + # Skip cuda graph recording for fast test. "enforce_eager": True, From 6f2f53a82dc9703abb389761bf9931e3c9a5a75b Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Mon, 30 Jun 2025 00:05:40 +0200 Subject: [PATCH 057/175] [Quantization] Add compressed-tensors NVFP4 MoE Support (#19990) Signed-off-by: Dipika Sikka Signed-off-by: Dipika --- tests/quantization/test_compressed_tensors.py | 6 +- vllm/model_executor/layers/fused_moe/layer.py | 4 +- .../compressed_tensors/compressed_tensors.py | 4 +- .../compressed_tensors_moe.py | 275 +++++++++++++++++- .../schemes/compressed_tensors_w4a4_nvfp4.py | 13 +- .../utils/nvfp4_emulation_utils.py | 15 +- 6 files changed, 295 insertions(+), 22 deletions(-) diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index 516bf4513816a..3646ad6c481b0 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -17,7 +17,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso CompressedTensorsW4A4Fp4, CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8, - CompressedTensorsWNA16) + CompressedTensorsWNA16, cutlass_fp4_supported) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( sparse_cutlass_supported) from vllm.platforms import current_platform @@ -668,8 +668,8 @@ def test_compressed_tensors_nvfp4(vllm_runner, args): assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod) if isinstance(qkv_proj.scheme, scheme) or isinstance( - qkv_proj.scheme, CompressedTensorsW4A16Fp4 - ) and not CompressedTensorsW4A4Fp4.cutlass_fp4_supported(): + qkv_proj.scheme, + CompressedTensorsW4A16Fp4) and not cutlass_fp4_supported(): assert True else: raise AssertionError("FP4 Scheme Mismatch") diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 5408ef1f75e89..e6f555d315d8e 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1246,6 +1246,7 @@ class FusedMoE(torch.nn.Module): param.materialize(final_shape, dtype=loaded_weight.dtype) expert_data = param.data if full_load else param.data[expert_id] + # Case input scale: input_scale loading is only supported for fp8 if "input_scale" in weight_name: # this is needed for compressed-tensors only @@ -1273,6 +1274,7 @@ class FusedMoE(torch.nn.Module): tp_rank=self.tp_rank) return True if return_success else None + # TODO @dsikka: ModelOpt should follow the proper MoE loading pattern if "ModelOpt" in quant_method_name: if ('weight_scale_2' in weight_name or 'input_scale' in weight_name): @@ -1289,7 +1291,7 @@ class FusedMoE(torch.nn.Module): tp_rank=self.tp_rank) return True if return_success else None - # Case weight scales, zero_points and offset + # Case weight scales, zero_points and offset, weight/input global scales if ("scale" in weight_name or "zero" in weight_name or "offset" in weight_name): # load the weight scales and zp based on the quantization scheme diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index d21abb2741a29..4f87b2a44f0ac 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -33,6 +33,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( find_matched_target, is_activation_quantization_format, should_ignore_layer) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501 + cutlass_fp4_supported) from vllm.platforms import current_platform logger = init_logger(__name__) @@ -375,7 +377,7 @@ class CompressedTensorsConfig(QuantizationConfig): if is_activation_quantization_format(self.quant_format): if self._is_fp4a4_nvfp4(weight_quant, input_quant): - if CompressedTensorsW4A4Fp4.cutlass_fp4_supported( + if cutlass_fp4_supported( ) or envs.VLLM_USE_NVFP4_CT_EMULATIONS: return CompressedTensorsW4A4Fp4() else: diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 4a19473005708..fa4ce5668091b 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -21,8 +21,12 @@ from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.marlin_utils import ( check_moe_marlin_supports_layer, marlin_make_workspace_new, marlin_moe_permute_scales) +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + prepare_moe_fp4_layer_for_marlin) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( prepare_moe_fp8_layer_for_marlin) +from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501 + cutlass_fp4_supported) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize) from vllm.model_executor.utils import set_weight_attrs @@ -46,12 +50,11 @@ class GPTQMarlinState(Enum): __all__ = [ - "CompressedTensorsMoEMethod", - "CompressedTensorsW8A8Fp8MoEMethod", + "CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod", "CompressedTensorsW8A8Fp8MoECutlassMethod", "CompressedTensorsW8A8Int8MoEMethod", - "CompressedTensorsWNA16MarlinMoEMethod", - "CompressedTensorsWNA16MoEMethod", + "CompressedTensorsWNA16MarlinMoEMethod", "CompressedTensorsWNA16MoEMethod", + "CompressedTensorsW4A4MoeMethod" ] @@ -84,6 +87,8 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): else: logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod") return CompressedTensorsWNA16MarlinMoEMethod(quant_config) + elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant): + return CompressedTensorsW4A4MoeMethod() elif quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant): return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config) elif quant_config._is_fp8_w8a8(weight_quant, input_quant): @@ -95,6 +100,268 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}") +class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): + + def __init__(self): + self.use_marlin = not cutlass_fp4_supported() + self.group_size = 16 + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + layer.num_experts = num_experts + layer.params_dtype = params_dtype + + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + # 2 fp4 items are packed in the input dimension + hidden_size // 2, + requires_grad=False, + dtype=torch.uint8), + requires_grad=False) + layer.register_parameter("w13_weight_packed", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + # 2 fp4 items are packed in the input dimension + intermediate_size_per_partition // 2, + dtype=torch.uint8), + requires_grad=False) + layer.register_parameter("w2_weight_packed", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # Weight Scales + w13_weight_scale = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + # 2 fp4 items are packed in the input dimension + hidden_size // self.group_size, + dtype=torch.float8_e4m3fn), + requires_grad=False) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.GROUP.value}) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + + w2_weight_scale = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + # 2 fp4 items are packed in the input dimension + intermediate_size_per_partition // self.group_size, + dtype=torch.float8_e4m3fn), + requires_grad=False) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.GROUP.value}) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # Weight Global Scales + w13_weight_scale_2 = torch.nn.Parameter(torch.empty( + num_experts, 2, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_weight_global_scale", w13_weight_scale_2) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + set_weight_attrs(w13_weight_scale_2, extra_weight_attrs) + + w2_weight_scale_2 = torch.nn.Parameter(torch.empty( + num_experts, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_weight_global_scale", w2_weight_scale_2) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + set_weight_attrs(w2_weight_scale_2, extra_weight_attrs) + + # Input Global Scales + w13_input_scale = torch.nn.Parameter(torch.empty(num_experts, + 2, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_input_global_scale", w13_input_scale) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + set_weight_attrs(w13_input_scale, extra_weight_attrs) + + w2_input_scale = torch.nn.Parameter(torch.empty(num_experts, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_input_global_scale", w2_input_scale) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + + def swizzle_blockscale(self, scale: torch.tensor): + assert (scale.dtype == torch.float8_e4m3fn) + # Pad and blockwise interleave weight_scale + scale_ndim = scale.ndim + if scale.ndim == 2: + scale = scale.unsqueeze(0) + assert scale.ndim == 3 + B, M, K = scale.shape + round_up_multiple = lambda x, m: (x + m - 1) // m * m + M_padded = round_up_multiple(M, 128) + K_padded = round_up_multiple(K, 4) + padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype) + padded_scale[:B, :M, :K] = scale + batches, rows, cols = padded_scale.shape + assert rows % 128 == 0 + assert cols % 4 == 0 + padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32, + cols // 4, 4) + swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5)) + swizzled_scale = swizzled_scale.contiguous().cuda() + return (swizzled_scale.reshape(M, K) + if scale_ndim == 2 else swizzled_scale.reshape(B, M, K)) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + + # From packed to weight + layer.w13_weight = torch.nn.Parameter(layer.w13_weight_packed.data, + requires_grad=False) + + layer.w2_weight = torch.nn.Parameter(layer.w2_weight_packed.data, + requires_grad=False) + + if not torch.allclose(layer.w13_weight_global_scale[:, 0], + layer.w13_weight_global_scale[:, 1]): + logger.warning_once( + "w1_weight_global_scale must match w3_weight_global_scale. " + "Accuracy may be affected.") + + # Take inverse of global scale saved to disk + layer.w13_weight_scale_2 = torch.nn.Parameter( + 1 / layer.w13_weight_global_scale[:, 0], requires_grad=False) + + layer.w2_weight_scale_2 = torch.nn.Parameter( + 1 / layer.w2_weight_global_scale.data, requires_grad=False) + + if self.use_marlin: + prepare_moe_fp4_layer_for_marlin(layer) + return + + # swizzle weight scales + layer.w13_blockscale_swizzled = torch.nn.Parameter( + self.swizzle_blockscale(layer.w13_weight_scale), + requires_grad=False) + + layer.w2_blockscale_swizzled = torch.nn.Parameter( + self.swizzle_blockscale(layer.w2_weight_scale), + requires_grad=False) + + # w13 + w13_input_global_scale = layer.w13_input_global_scale.max( + dim=1).values.to(torch.float32) + + layer.g1_alphas = torch.nn.Parameter( + ((1 / w13_input_global_scale) * layer.w13_weight_scale_2), + requires_grad=False) + + layer.w13_input_scale_quant = torch.nn.Parameter( + (w13_input_global_scale), requires_grad=False) + + # w2 + layer.g2_alphas = torch.nn.Parameter( + ((1 / layer.w2_input_global_scale) * layer.w2_weight_scale_2).to( + torch.float32), + requires_grad=False) + + layer.w2_input_scale_quant = torch.nn.Parameter( + (layer.w2_input_global_scale), requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError("EPLB not supported for " + "`CompressedTensorsW4A4MoeMethod` yet.") + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + ) + + if self.use_marlin: + return torch.ops.vllm.fused_marlin_moe( + x, + layer.w13_weight, + layer.w2_weight, + layer.w13_weight_scale, + layer.w2_weight_scale, + router_logits, + topk_weights, + topk_ids, + global_scale1=layer.w13_weight_scale_2, + global_scale2=layer.w2_weight_scale_2, + quant_type_id=scalar_types.float4_e2m1f.id, + global_num_experts=global_num_experts, + expert_map=expert_map) + + assert activation == "silu", "Only SiLU activation is supported." + assert not apply_router_weight_on_input, ( + "Router weight on input is not " + "supported for CompressedTensorsW4A4MoeMethod.") + assert expert_map is None, ("Expert Parallelism / expert_map " + "is currently not supported for " + "CompressedTensorsW4A4MoeMethod.") + + from vllm.model_executor.layers.fused_moe.cutlass_moe import ( + cutlass_moe_fp4) + + # Cutlass moe takes in activations in BF16/Half precision + # and fp4 quantized weights loaded from the checkpoint + return cutlass_moe_fp4(a=x, + w1_fp4=layer.w13_weight, + w1_blockscale=layer.w13_blockscale_swizzled, + w1_alphas=layer.g1_alphas, + w2_fp4=layer.w2_weight, + w2_blockscale=layer.w2_blockscale_swizzled, + w2_alphas=layer.g2_alphas, + topk_weights=topk_weights, + topk_ids=topk_ids, + m=x.shape[0], + n=layer.w2_weight.shape[2] * 2, + k=x.shape[1], + e=layer.w13_weight.shape[0], + a1_gscale=layer.w13_input_scale_quant, + a2_gscale=layer.w2_input_scale_quant, + device=x.device).to(x.dtype) + + class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): def __init__( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py index ec1d4a6c0efae..65cbc49d2640a 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py @@ -5,8 +5,7 @@ import torch from torch.nn.parameter import Parameter import vllm.envs as envs -from vllm._custom_ops import (cutlass_scaled_fp4_mm, - cutlass_scaled_mm_supports_fp4, scaled_fp4_quant) +from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm.logger import init_logger from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) @@ -15,7 +14,6 @@ from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import from vllm.model_executor.parameter import (GroupQuantScaleParameter, ModelWeightParameter, PerTensorScaleParameter) -from vllm.platforms import current_platform logger = init_logger(__name__) @@ -33,15 +31,6 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): return 80 return 100 - @classmethod - def cutlass_fp4_supported(cls) -> bool: - if not current_platform.is_cuda(): - return False - capability_tuple = current_platform.get_device_capability() - capability = -1 if capability_tuple is None else capability_tuple.to_int( # noqa: E501 - ) - return cutlass_scaled_mm_supports_fp4(capability) - def create_weights(self, layer: torch.nn.Module, output_partition_sizes: list[int], input_size_per_partition: int, diff --git a/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py b/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py index d5ce6d7ad757a..fb3287d3b89e6 100644 --- a/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py +++ b/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py @@ -2,9 +2,14 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch +from vllm._custom_ops import cutlass_scaled_mm_supports_fp4 +from vllm.platforms import current_platform from vllm.scalar_type import scalar_types -__all__ = ["break_fp4_bytes", "dequantize_to_dtype", "ref_nvfp4_quant"] +__all__ = [ + "break_fp4_bytes", "dequantize_to_dtype", "ref_nvfp4_quant", + "cutlass_fp4_supported" +] FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max() @@ -12,6 +17,14 @@ kE2M1ToFloat = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6.], dtype=torch.float32) +def cutlass_fp4_supported() -> bool: + if not current_platform.is_cuda(): + return False + capability_tuple = current_platform.get_device_capability() + capability = -1 if capability_tuple is None else capability_tuple.to_int() + return cutlass_scaled_mm_supports_fp4(capability) + + def break_fp4_bytes(a, dtype): assert a.dtype == torch.uint8 m, n = a.shape From 6c9837a761409aecf47e94ae4272879ce0c81590 Mon Sep 17 00:00:00 2001 From: Huy Do Date: Sun, 29 Jun 2025 16:52:34 -0700 Subject: [PATCH 058/175] Fix cuda_archs_loose_intersection when handling sm_*a (#20207) Signed-off-by: Huy Do --- CMakeLists.txt | 14 ++++++++++++-- cmake/utils.cmake | 33 ++++++++++++++------------------- 2 files changed, 26 insertions(+), 21 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index b1adeac586f2e..f6f8d59d28aef 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -562,7 +562,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "if you intend on running FP8 quantized MoE models on Hopper.") else() message(STATUS "Not building grouped_mm_c3x as no compatible archs found " - "in CUDA target architectures") + "in CUDA target architectures.") endif() endif() @@ -574,7 +574,17 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") SRCS "${SRCS}" CUDA_ARCHS "${CUTLASS_MOE_DATA_ARCHS}") list(APPEND VLLM_EXT_SRC "${SRCS}") - endif() + message(STATUS "Building moe_data for archs: ${CUTLASS_MOE_DATA_ARCHS}") + else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS) + message(STATUS "Not building moe_data as CUDA Compiler version is " + "not >= 12.3, we recommend upgrading to CUDA 12.3 or later " + "if you intend on running FP8 quantized MoE models on Hopper or Blackwell.") + else() + message(STATUS "Not building moe_data as no compatible archs found " + "in CUDA target architectures.") + endif() + endif() # # Machete kernels diff --git a/cmake/utils.cmake b/cmake/utils.cmake index 59c78950a1093..621179a701692 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -265,8 +265,8 @@ macro(set_gencode_flags_for_srcs) endmacro() # -# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form -# `.[letter]` compute the "loose intersection" with the +# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form +# `.[letter]` compute the "loose intersection" with the # `TGT_CUDA_ARCHS` list of gencodes. We also support the `+PTX` suffix in # `SRC_CUDA_ARCHS` which indicates that the PTX code should be built when there # is a CUDA_ARCH in `TGT_CUDA_ARCHS` that is equal to or larger than the @@ -278,7 +278,7 @@ endmacro() # in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`. # We have special handling for x.0a, if x.0a is in `SRC_CUDA_ARCHS` and x.0 is # in `TGT_CUDA_ARCHS` then we should remove x.0a from `SRC_CUDA_ARCHS` and add -# x.0a to the result (and remove x.0 from TGT_CUDA_ARCHS). +# x.0a to the result (and remove x.0 from TGT_CUDA_ARCHS). # The result is stored in `OUT_CUDA_ARCHS`. # # Example: @@ -313,21 +313,16 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR # if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should # remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS set(_CUDA_ARCHS) - if ("9.0a" IN_LIST _SRC_CUDA_ARCHS) - list(REMOVE_ITEM _SRC_CUDA_ARCHS "9.0a") - if ("9.0" IN_LIST TGT_CUDA_ARCHS) - list(REMOVE_ITEM _TGT_CUDA_ARCHS "9.0") - set(_CUDA_ARCHS "9.0a") + foreach(_arch ${_SRC_CUDA_ARCHS}) + if(_arch MATCHES "\\a$") + list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}") + string(REPLACE "a" "" _base "${_arch}") + if ("${_base}" IN_LIST TGT_CUDA_ARCHS) + list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_base}") + list(APPEND _CUDA_ARCHS "${_arch}") + endif() endif() - endif() - - if ("10.0a" IN_LIST _SRC_CUDA_ARCHS) - list(REMOVE_ITEM _SRC_CUDA_ARCHS "10.0a") - if ("10.0" IN_LIST TGT_CUDA_ARCHS) - list(REMOVE_ITEM _TGT_CUDA_ARCHS "10.0") - set(_CUDA_ARCHS "10.0a") - endif() - endif() + endforeach() list(SORT _SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING) @@ -359,7 +354,7 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR endforeach() list(REMOVE_DUPLICATES _CUDA_ARCHS) - + # reapply +PTX suffix to architectures that requested PTX set(_FINAL_ARCHS) foreach(_arch ${_CUDA_ARCHS}) @@ -370,7 +365,7 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR endif() endforeach() set(_CUDA_ARCHS ${_FINAL_ARCHS}) - + set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE) endfunction() From 65b1cbb1381bf2301a2441fd988bbe88b4b7865e Mon Sep 17 00:00:00 2001 From: redmoe-moutain Date: Mon, 30 Jun 2025 10:34:36 +0800 Subject: [PATCH 059/175] [Model] support dots1 (#18254) Signed-off-by: redmoe-moutain --- docs/models/supported_models.md | 1 + tests/models/registry.py | 2 + vllm/model_executor/models/dots1.py | 535 +++++++++++++++++++++++++ vllm/model_executor/models/registry.py | 1 + 4 files changed, 539 insertions(+) create mode 100644 vllm/model_executor/models/dots1.py diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 9782fd1781512..0248700292ae2 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -329,6 +329,7 @@ Specified using `--task generate`. | `DeepseekForCausalLM` | DeepSeek | `deepseek-ai/deepseek-llm-67b-base`, `deepseek-ai/deepseek-llm-7b-chat` etc. | | ✅︎ | ✅︎ | | `DeepseekV2ForCausalLM` | DeepSeek-V2 | `deepseek-ai/DeepSeek-V2`, `deepseek-ai/DeepSeek-V2-Chat` etc. | | ✅︎ | ✅︎ | | `DeepseekV3ForCausalLM` | DeepSeek-V3 | `deepseek-ai/DeepSeek-V3-Base`, `deepseek-ai/DeepSeek-V3` etc. | | ✅︎ | ✅︎ | +| `Dots1ForCausalLM` | dots.llm1 | `rednote-hilab/dots.llm1.base`, `rednote-hilab/dots.llm1.inst` etc. | | ✅︎ | ✅︎ | | `ExaoneForCausalLM` | EXAONE-3 | `LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `FalconForCausalLM` | Falcon | `tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc. | | ✅︎ | ✅︎ | | `FalconMambaForCausalLM` | FalconMamba | `tiiuae/falcon-mamba-7b`, `tiiuae/falcon-mamba-7b-instruct`, etc. | | ✅︎ | ✅︎ | diff --git a/tests/models/registry.py b/tests/models/registry.py index 72e361e2637fd..e56dd19bec670 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -268,6 +268,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "Zamba2ForCausalLM": _HfExamplesInfo("Zyphra/Zamba2-7B-instruct"), "MiMoForCausalLM": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL", trust_remote_code=True), + "Dots1ForCausalLM": _HfExamplesInfo("rednote-hilab/dots.llm1.inst", + min_transformers_version="4.53"), # [Encoder-decoder] "BartModel": _HfExamplesInfo("facebook/bart-base"), "BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"), diff --git a/vllm/model_executor/models/dots1.py b/vllm/model_executor/models/dots1.py new file mode 100644 index 0000000000000..01a27d02a3044 --- /dev/null +++ b/vllm/model_executor/models/dots1.py @@ -0,0 +1,535 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2025 The rednote-hilab team. +# Copyright 2023 The vLLM team. +# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only dots1 model.""" +from collections.abc import Iterable +from typing import Any, Optional, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, ModelConfig, VllmConfig +from vllm.distributed import (get_pp_group, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class Dots1MLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj") + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj") + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class Dots1MoE(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.routed_scaling_factor = config.routed_scaling_factor + self.n_shared_experts = config.n_shared_experts + + if config.hidden_act != "silu": + raise ValueError(f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now.") + + self.gate = ReplicatedLinear(config.hidden_size, + config.n_routed_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate") + if config.topk_method == "noaux_tc": + self.gate.e_score_correction_bias = (nn.Parameter( + torch.empty(config.n_routed_experts))) + else: + self.gate.e_score_correction_bias = None + + self.experts = FusedMoE( + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + prefix=f"{prefix}.experts", + scoring_func=config.scoring_func, + e_score_correction_bias=self.gate.e_score_correction_bias) + + if config.n_shared_experts is not None: + intermediate_size = (config.moe_intermediate_size * + config.n_shared_experts) + self.shared_experts = Dots1MLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=False, + prefix=f"{prefix}.shared_experts", + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + if self.n_shared_experts is not None: + shared_output = self.shared_experts(hidden_states) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts( + hidden_states=hidden_states, + router_logits=router_logits) * self.routed_scaling_factor + if shared_output is not None: + final_hidden_states = final_hidden_states + shared_output + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + return final_hidden_states.view(num_tokens, hidden_dim) + + +class Dots1Attention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + config: PretrainedConfig, + rope_theta: float = 10000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = getattr(config, "head_dim", + hidden_size // self.total_num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + attention_bias = config.attention_bias + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=attention_bias, + quant_config=quant_config, + ) + + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) + self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + + def forward(self, positions: torch.Tensor, + hidden_states: torch.Tensor) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = self.q_norm(q.reshape(-1, self.num_heads, + self.head_dim)).reshape(q.shape) + k = self.k_norm(k.reshape(-1, self.num_kv_heads, + self.head_dim)).reshape(k.shape) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class Dots1DecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + prefix: str, + model_config: ModelConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + layer_idx = int(prefix.split(sep='.')[-1]) + self.layer_idx = layer_idx + + self.self_attn = Dots1Attention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + config=config, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + if (config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0): + self.mlp = Dots1MoE(config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + else: + self.mlp = Dots1MLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.routed_scaling_factor = config.routed_scaling_factor + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn(positions=positions, + hidden_states=hidden_states) + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class Dots1Model(nn.Module): + + fall_back_to_pt_during_load = False + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.config = config + + self.vocab_size = config.vocab_size + + if get_pp_group().is_first_rank: + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens") + else: + self.embed_tokens = PPMissingLayer() + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Dots1DecoderLayer( + config, + prefix, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + ), + prefix=f"{prefix}.layers") + + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +@support_torch_compile +class Dots1ForCausalLM(nn.Module, SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = Dots1Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + else: + self.lm_head = PPMissingLayer() + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model( + input_ids, + positions, + intermediate_tensors, + inputs_embeds, + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + "residual": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + }) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts) + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + if (("mlp.experts." in name) and name not in params_dict): + continue + name = name.replace(weight_name, param_name) + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + if name.endswith(".bias") and name not in params_dict: + continue + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index ff605cae02ea4..d566146662b8d 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -52,6 +52,7 @@ _TEXT_GENERATION_MODELS = { "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"), "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"), "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"), + "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"), "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"), "FalconForCausalLM": ("falcon", "FalconForCausalLM"), "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"), From 5a52f389dde444b9e122a7fa903393fd64857b86 Mon Sep 17 00:00:00 2001 From: "Chendi.Xue" Date: Sun, 29 Jun 2025 21:46:19 -0500 Subject: [PATCH 060/175] [BUGFIX][DEEPSEEK][MODEL_LOAD] fix w13, w2 weight not initialized assert (#20202) Signed-off-by: Chendi Xue --- vllm/model_executor/models/deepseek_v2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index f712b626c74c3..2fa1294b79b95 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -889,6 +889,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): expert_id=expert_id, return_success=True) if success: + name = name_mapped break else: if is_expert_weight: From 19108ef31191e217766ffe52e8e382ddbec20fdb Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 29 Jun 2025 20:34:54 -0700 Subject: [PATCH 061/175] [Misc] Fix import (#20233) Signed-off-by: Woosuk Kwon --- vllm/v1/worker/gpu_model_runner.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 3c9de57204051..290b9a44a80e2 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -45,7 +45,7 @@ from vllm.sequence import IntermediateTensors from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, GiB_bytes, LazyLoader, async_tensor_h2d, cdiv, check_use_alibi, get_dtype_size, - is_pin_memory_available) + is_pin_memory_available, round_up) from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, CommonAttentionMetadata) @@ -1308,7 +1308,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): tp_size = self.vllm_config.parallel_config.tensor_parallel_size if self.compilation_config.pass_config. \ enable_sequence_parallelism and tp_size > 1: - from vllm.utils import round_up num_input_tokens = round_up(num_scheduled_tokens, tp_size) else: num_input_tokens = num_scheduled_tokens From 022c58b80f8ae2e27dc526860b769455ef4c5498 Mon Sep 17 00:00:00 2001 From: Reid <61492567+reidliu41@users.noreply.github.com> Date: Mon, 30 Jun 2025 15:53:45 +0800 Subject: [PATCH 062/175] [doc] Add Slack and Forum to the top navigation (#20208) Signed-off-by: reidliu41 --- docs/mkdocs/javascript/slack_and_forum.js | 56 +++++++++++++++++++++++ docs/mkdocs/stylesheets/extra.css | 26 +++++++++++ mkdocs.yaml | 1 + 3 files changed, 83 insertions(+) create mode 100644 docs/mkdocs/javascript/slack_and_forum.js diff --git a/docs/mkdocs/javascript/slack_and_forum.js b/docs/mkdocs/javascript/slack_and_forum.js new file mode 100644 index 0000000000000..9a92332238363 --- /dev/null +++ b/docs/mkdocs/javascript/slack_and_forum.js @@ -0,0 +1,56 @@ +/** + * slack_and_forum.js + * + * Adds a custom Slack and Forum button to the MkDocs Material header. + * + */ + +window.addEventListener('DOMContentLoaded', () => { + const headerInner = document.querySelector('.md-header__inner'); + + if (headerInner) { + const slackButton = document.createElement('button'); + slackButton.className = 'slack-button'; + slackButton.title = 'Join us on Slack'; + slackButton.style.border = 'none'; + slackButton.style.background = 'transparent'; + slackButton.style.cursor = 'pointer'; + + slackButton.innerHTML = ` + Slack + `; + + slackButton.addEventListener('click', () => { + window.open('https://slack.vllm.ai', '_blank', 'noopener'); + }); + + const forumButton = document.createElement('button'); + forumButton.className = 'forum-button'; + forumButton.title = 'Join the Forum'; + forumButton.style.border = 'none'; + forumButton.style.background = 'transparent'; + forumButton.style.cursor = 'pointer'; + + forumButton.innerHTML = ` + + + + `; + + forumButton.addEventListener('click', () => { + window.open('https://discuss.vllm.ai/', '_blank', 'noopener'); + }); + + const githubSource = document.querySelector('.md-header__source'); + if (githubSource) { + githubSource.parentNode.insertBefore(slackButton, githubSource.nextSibling); + githubSource.parentNode.insertBefore(forumButton, slackButton.nextSibling); + } + } +}); diff --git a/docs/mkdocs/stylesheets/extra.css b/docs/mkdocs/stylesheets/extra.css index 220657f83d5fc..248711f491b9d 100644 --- a/docs/mkdocs/stylesheets/extra.css +++ b/docs/mkdocs/stylesheets/extra.css @@ -108,3 +108,29 @@ body[data-md-color-scheme="slate"] .md-nav__item--section > label.md-nav__link . .md-content__button-wrapper a:hover { color: var(--md-accent-fg-color); } + +/* Slack and Forum css */ +.slack-button, +.forum-button { + display: inline-flex; + align-items: center; + justify-content: center; + margin-left: 0.4rem; + height: 24px; +} + +.slack-button img { + height: 18px; + filter: none !important; +} + +.slack-button:hover, +.forum-button:hover { + opacity: 0.7; +} + +.forum-button svg { + height: 28px; + opacity: 0.9; + transform: translateY(2px); +} diff --git a/mkdocs.yaml b/mkdocs.yaml index 9fb3fed8b8ac6..45b6ffadbeb71 100644 --- a/mkdocs.yaml +++ b/mkdocs.yaml @@ -127,6 +127,7 @@ extra_javascript: - mkdocs/javascript/run_llm_widget.js - https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS_HTML - mkdocs/javascript/edit_and_feedback.js + - mkdocs/javascript/slack_and_forum.js # Makes the url format end in .html rather than act as a dir # So index.md generates as index.html and is available under URL /index.html From f5dfa0753163530117b4766c4e79e8cb2dc7066e Mon Sep 17 00:00:00 2001 From: noiji <52301388+noiji@users.noreply.github.com> Date: Mon, 30 Jun 2025 18:21:56 +0900 Subject: [PATCH 063/175] [Bugfix] Skip loading extra parameters for modelopt Qwen3 MoE model (#19598) Signed-off-by: noiji <> --- vllm/model_executor/models/qwen3_moe.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index 417d7b22088bf..90a28192eccbc 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -386,6 +386,11 @@ class Qwen3MoeModel(nn.Module): ("gate_up_proj", "up_proj", 1), ] + # Skip loading extra parameters for GPTQ/modelopt models. + ignore_suffixes = (".bias", "_bias", ".k_scale", "_k_scale", + ".v_scale", "_v_scale", ".weight_scale", + "_weight_scale", ".input_scale", "_input_scale") + # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) expert_params_mapping = FusedMoE.make_expert_params_mapping( @@ -410,10 +415,11 @@ class Qwen3MoeModel(nn.Module): if "mlp.experts" in name: continue name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + + # Skip loading extra parameters for GPTQ/modelopt models. + if name.endswith(ignore_suffixes) and name not in params_dict: continue + # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue @@ -433,9 +439,9 @@ class Qwen3MoeModel(nn.Module): # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue - # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + # Skip loading extra parameters for GPTQ/modelopt models. + if name.endswith( + ignore_suffixes) and name not in params_dict: continue param = params_dict[name] weight_loader = param.weight_loader @@ -446,9 +452,9 @@ class Qwen3MoeModel(nn.Module): expert_id=expert_id) break else: - # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + # Skip loading extra parameters for GPTQ/modelopt models. + if name.endswith( + ignore_suffixes) and name not in params_dict: continue # Skip layers on other devices. if is_pp_missing_parameter(name, self): From e936e401debe7fba64d6462666d7dc632bc76357 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Mon, 30 Jun 2025 18:16:16 +0800 Subject: [PATCH 064/175] [Bugfix] Fix processor initialization in transformers 4.53.0 (#20244) Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/inputs/registry.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 66e78833f52af..fc6e190e54806 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -5,7 +5,9 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union import torch +from packaging.version import Version from transformers import BatchFeature, PretrainedConfig, ProcessorMixin +from transformers import __version__ as TRANSFORMERS_VERSION from typing_extensions import TypeVar from vllm.jsontree import JSONTree, json_map_leaves @@ -128,9 +130,13 @@ class InputProcessingContext(InputContext): /, **kwargs: object, ) -> _P: + # Transformers 4.53.0 has issue with passing tokenizer to + # initialize processor. We disable it for this version. + # See: https://github.com/vllm-project/vllm/issues/20224 + if Version(TRANSFORMERS_VERSION) != Version("4.53.0"): + kwargs["tokenizer"] = self.tokenizer return super().get_hf_processor( typ, - tokenizer=self.tokenizer, **kwargs, ) From 8fe7fc863481a7d48c6f5bcc7bb40b2c7ebb5476 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Mon, 30 Jun 2025 18:22:09 +0800 Subject: [PATCH 065/175] [Quantization] Improve BitsAndBytesModelLoader (#20242) Signed-off-by: Jee Jee Li --- .../model_loader/bitsandbytes_loader.py | 123 ++++++++++-------- 1 file changed, 72 insertions(+), 51 deletions(-) diff --git a/vllm/model_executor/model_loader/bitsandbytes_loader.py b/vllm/model_executor/model_loader/bitsandbytes_loader.py index 09857ef297f0a..0c46d170e88d5 100644 --- a/vllm/model_executor/model_loader/bitsandbytes_loader.py +++ b/vllm/model_executor/model_loader/bitsandbytes_loader.py @@ -20,8 +20,6 @@ from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) # yapf: enable from vllm.logger import init_logger -# yapf conflicts with isort for this block -# yapf: disable from vllm.model_executor.layers.linear import (LinearBase, MergedColumnParallelLinear, QKVParallelLinear, @@ -39,6 +37,8 @@ from vllm.model_executor.utils import (get_packed_modules_mapping, set_weight_attrs) from vllm.platforms import current_platform +# yapf conflicts with isort for this block + logger = init_logger(__name__) @@ -54,11 +54,17 @@ class BitsAndBytesModelLoader(BaseModelLoader): self.unsharded_weights_modules: list[str] = [] # Save the module names that are sharded by column. self.column_sharded_weights_modules: list[str] = [] + # Modules whose weights might have fused on disk + # we need their output_sizes to make shard in flight correctly with TP + self.maybe_fused_weights_modules: dict[str, list[int]] = {} # Store all module names (from transformers) that support # BNB quantization. self.target_modules: list[str] = [] # mapping weight names from transformers to vllm. self.weight_mapper: Callable = lambda name: name + self.pre_quant: bool = False + self.load_8bit: bool = False + self.is_pool_model: bool = False def _get_weight_files( self, @@ -134,13 +140,14 @@ class BitsAndBytesModelLoader(BaseModelLoader): return hf_weights_files, use_safetensors def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool): - def _maybe_pool_model(module_name:str): + + def _maybe_pool_model(module_name: str): # For pool model, we need to add the prefix `model.` # for the weight name if possible. if self.is_pool_model and self.target_modules[0]. \ startswith("model.") and not module_name.startswith( "model."): - return "model."+module_name + return "model." + module_name return module_name @@ -159,8 +166,7 @@ class BitsAndBytesModelLoader(BaseModelLoader): # mapping weight names from transformers to vllm while preserving # original names. mapped_name = self.weight_mapper(org_name) - mapped_name=_maybe_pool_model(mapped_name) - + mapped_name = _maybe_pool_model(mapped_name) yield org_name, mapped_name, param @@ -168,8 +174,6 @@ class BitsAndBytesModelLoader(BaseModelLoader): self, model_name_or_path: str, revision: Optional[str], - pre_quant: bool, - load_8bit: bool, ) -> tuple[Generator[tuple[str, torch.Tensor], None, None], dict[str, Any]]: """Get an iterator to the model weights with bitsandbytes quantization, @@ -192,8 +196,8 @@ class BitsAndBytesModelLoader(BaseModelLoader): quant_state_dict: dict[str, Any] = {} - if pre_quant: - if load_8bit: + if self.pre_quant: + if self.load_8bit: return self._quantized_8bit_generator( hf_weights_files, use_safetensors, quant_state_dict), quant_state_dict @@ -390,10 +394,13 @@ class BitsAndBytesModelLoader(BaseModelLoader): yield org_weight_name, processed_weight def _get_bnb_target_modules(self, model: nn.Module) -> None: - + """ + Identify and collect all modules that support BitsAndBytes + quantization. + """ for name, module in model.named_modules(): - if (isinstance(module, LinearBase) and - hasattr(module.quant_method, "quant_config")): + if (isinstance(module, LinearBase) + and hasattr(module.quant_method, "quant_config")): if modules_info := self.modules_mapping.get_sub_modules(name): # Map vllm's names to transformers's names. rep_name, sub_modules = modules_info @@ -409,29 +416,11 @@ class BitsAndBytesModelLoader(BaseModelLoader): ), "vllm currently does not support BNB quantization for" f" {type(model).__name__}" - def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None: - if not hasattr(model, "load_weights"): - raise AttributeError( - "The required method 'load_weights' is not defined in class" - f" {type(model).__name__}.") - - if not hasattr(model, "packed_modules_mapping"): - raise AttributeError( - f"Model {type(model).__name__} does not support BitsAndBytes " - "quantization yet. No 'packed_modules_mapping' found.") - self.is_pool_model=is_pooling_model(model) - - self.modules_mapping = ParamMapping(get_packed_modules_mapping(model)) - - # For some models like Molmo, we need to use hf_to_vllm_mapper - # to ensure correct loading of weights. - if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None): - self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name) - - # Modules whose weights might have fused on disk - # we need their output_sizes to make shard in flight correctly with TP - self.maybe_fused_weights_modules: dict[str, list[int]] = {} - self._get_bnb_target_modules(model) + def _classify_module_sharding(self, model: nn.Module): + """ + Categorize modules based on their weight sharding requirements + for tensor parallelism. + """ for name, module in model.named_modules(): # Some modules like `ReplicatedLinear` should not have their weights # sharded. The reason for implementing it this way is to avoid new @@ -449,19 +438,27 @@ class BitsAndBytesModelLoader(BaseModelLoader): elif isinstance(module, (RowParallelLinear, )): self.column_sharded_weights_modules.append(name) - self.model_type = type(model).__name__ + def _verify_model_compatibility(self, model: nn.Module, + model_config: ModelConfig) -> None: + """ + Verify that the model is compatible with BitsAndBytes quantization. + """ + if not hasattr(model, "load_weights"): + raise AttributeError( + "The required method 'load_weights' is not defined in class" + f" {type(model).__name__}.") - logger.info("Loading weights with BitsAndBytes quantization. " - "May take a while ...") + if not hasattr(model, "packed_modules_mapping"): + raise AttributeError( + f"Model {type(model).__name__} does not support BitsAndBytes " + "quantization yet. No 'packed_modules_mapping' found.") quant_config = getattr(model_config.hf_config, "quantization_config", None) - - pre_quant = False if quant_config is not None: quant_method = quant_config.get("quant_method") if quant_method == "bitsandbytes": - pre_quant = True + self.pre_quant = True else: raise ValueError( f"BitsAndBytes loader does not support {quant_method} " @@ -469,20 +466,43 @@ class BitsAndBytesModelLoader(BaseModelLoader): # The quant_states in pre_quantized models cannot work with a split # weight tensor. So TP does not work with pre_quantized bnb models. - if pre_quant and get_tensor_model_parallel_world_size() > 1: + if self.pre_quant and get_tensor_model_parallel_world_size() > 1: raise ValueError( "Prequant BitsAndBytes models with tensor parallelism is not " "supported. Please try with pipeline parallelism.") + if self.pre_quant: + self.load_8bit = quant_config.get("load_in_8bit", False) - load_8bit = False - if pre_quant: - load_8bit = quant_config.get("load_in_8bit", False) + def _initialize_loader_state(self, model: nn.Module, + model_config: ModelConfig) -> None: + """ + Initialize the loader's internal state based on the model and + configuration. + """ + self.is_pool_model = is_pooling_model(model) + self.modules_mapping = ParamMapping(get_packed_modules_mapping(model)) + # For some models like Molmo, we need to use hf_to_vllm_mapper + # to ensure correct loading of weights. + if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None): + self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name) + + self._get_bnb_target_modules(model) + self._classify_module_sharding(model) + + def load_weights(self, model: nn.Module, + model_config: ModelConfig) -> None: + + self._verify_model_compatibility(model, model_config) + self._initialize_loader_state(model, model_config) + + logger.info("Loading weights with BitsAndBytes quantization. " + "May take a while ...") qweight_iterator, quant_state_dict = ( - self._get_quantized_weights_iterator(model_config.model, - model_config.revision, - pre_quant, load_8bit)) - + self._get_quantized_weights_iterator( + model_config.model, + model_config.revision, + )) weights_to_load = {name for name, _ in model.named_parameters()} loaded_weights = model.load_weights(qweight_iterator) # Some models may have weights loading tracker unimplemented. @@ -562,10 +582,11 @@ class BitsAndBytesModelLoader(BaseModelLoader): offsets = torch.tensor(offsets).cpu() set_weight_attrs(param, {"bnb_shard_offsets": offsets}) - if load_8bit: + if self.load_8bit: set_weight_attrs( param, {"matmul_state": [None] * len(quant_states)}) torch.cuda.empty_cache() + def download_model(self, model_config: ModelConfig) -> None: self._prepare_weights(model_config.model, model_config.revision) From 3ee56e26be4cfddc17f7d2e5f38f15ab74ede1c2 Mon Sep 17 00:00:00 2001 From: Michael Yao Date: Mon, 30 Jun 2025 19:20:51 +0800 Subject: [PATCH 066/175] [Docs] Fix 1-2-3 list in v1/prefix_caching.md (#20243) Signed-off-by: windsonsea --- docs/design/v1/prefix_caching.md | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/docs/design/v1/prefix_caching.md b/docs/design/v1/prefix_caching.md index e87e4c6a48b73..2d3c8412894a6 100644 --- a/docs/design/v1/prefix_caching.md +++ b/docs/design/v1/prefix_caching.md @@ -117,8 +117,8 @@ There are two design points to highlight: 1. We allocate all KVCacheBlock when initializing the KV cache manager to be a block pool. This avoids Python object creation overheads and can easily track all blocks all the time. 2. We introduce doubly linked list pointers directly in the KVCacheBlock, so that we could construct a free queue directly. This gives us two benefits: - 1. We could have O(1) complexity moving elements in the middle to the tail. - 2. We could avoid introducing another Python queue (e.g., `deque`) which has a wrapper to the elements. + 1. We could have O(1) complexity moving elements in the middle to the tail. + 2. We could avoid introducing another Python queue (e.g., `deque`) which has a wrapper to the elements. As a result, we will have the following components when the KV cache manager is initialized: @@ -135,19 +135,19 @@ As a result, we will have the following components when the KV cache manager is **New request:** Workflow for the scheduler to schedule a new request with KV cache block allocation: -1. The scheduler calls `kv_cache_manager.get_computed_blocks()` to get a sequence of blocks that have already been computed. This is done by hashing the prompt tokens in the request and looking up Cache Blocks. +1. The scheduler calls `kv_cache_manager.get_computed_blocks()` to get a sequence of blocks that have already been computed. This is done by hashing the prompt tokens in the request and looking up cache blocks. 2. The scheduler calls `kv_cache_manager.allocate_slots()`. It does the following steps: - 1. Compute the number of new required blocks, and return if there are no sufficient blocks to allocate. - 2. “Touch” the computed blocks. It increases the reference count of the computed block by one, and removes the block from the free queue if the block wasn’t used by other requests. This is to avoid these computed blocks being evicted. See the example in the next section for illustration. - 3. Allocate new blocks by popping the heads of the free queue. If the head block is a cached block, this also “evicts” the block so that no other requests can reuse it anymore from now on. - 4. If an allocated block is already full of tokens, we immediately add it to the Cache Block, so that the block can be reused by other requests in the same batch. + 1. Compute the number of new required blocks, and return if there are no sufficient blocks to allocate. + 2. “Touch” the computed blocks. It increases the reference count of the computed block by one, and removes the block from the free queue if the block wasn’t used by other requests. This is to avoid these computed blocks being evicted. See the example in the next section for illustration. + 3. Allocate new blocks by popping the heads of the free queue. If the head block is a cached block, this also “evicts” the block so that no other requests can reuse it anymore from now on. + 4. If an allocated block is already full of tokens, we immediately add it to the cache block, so that the block can be reused by other requests in the same batch. **Running request:** Workflow for the scheduler to schedule a running request with KV cache block allocation: 1. The scheduler calls `kv_cache_manager.allocate_slots()`. It does the following steps: - 1. Compute the number of new required blocks, and return if there are no sufficient blocks to allocate. - 2. Allocate new blocks by popping the heads of the free queue. If the head block is a cached block, this also “evicts” the block so that no other requests can reuse it anymore from now on. - 3. Append token IDs to the slots in existing blocks as well as the new blocks. If a block is full, we add it to the Cache Block to cache it. + 1. Compute the number of new required blocks, and return if there are no sufficient blocks to allocate. + 2. Allocate new blocks by popping the heads of the free queue. If the head block is a cached block, this also “evicts” the block so that no other requests can reuse it anymore from now on. + 3. Append token IDs to the slots in existing blocks as well as the new blocks. If a block is full, we add it to the cache block to cache it. **Duplicated blocks** Assuming block size is 4 and you send a request (Request 1\) with prompt ABCDEF and decoding length 3: @@ -199,7 +199,7 @@ When a request is finished, we free all its blocks if no other requests are usin When the head block (least recently used block) of the free queue is cached, we have to evict the block to prevent it from being used by other requests. Specifically, eviction involves the following steps: 1. Pop the block from the head of the free queue. This is the LRU block to be evicted. -2. Remove the block ID from the Cache Block. +2. Remove the block ID from the cache block. 3. Remove the block hash. ## Example From 1c50e100a9c5dc439aceb9c4437b262d564baa53 Mon Sep 17 00:00:00 2001 From: li haoyang Date: Mon, 30 Jun 2025 21:24:50 +0800 Subject: [PATCH 067/175] [Bugfix] fix quark ptpc (#20251) Signed-off-by: Haoyang Li Co-authored-by: Haoyang Li <307790822@qq.com> --- .../layers/quantization/quark/quark.py | 6 +--- .../quark/schemes/quark_w8a8_fp8.py | 33 ++++++++++++------- 2 files changed, 23 insertions(+), 16 deletions(-) diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py index 6ae5f5c9ad46b..05dff4bae3957 100644 --- a/vllm/model_executor/layers/quantization/quark/quark.py +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -312,11 +312,7 @@ class QuarkConfig(QuantizationConfig): is_fp8_w8a8_supported = self._check_scheme_supported( QuarkW8A8Fp8.get_min_capability(), error=False) if is_fp8_w8a8_supported: - weight_qscheme = cast(str, weight_config.get("qscheme")) - input_static = (input_config is not None and - not cast(bool, input_config.get("is_dynamic"))) - return QuarkW8A8Fp8(qscheme=weight_qscheme, - is_static_input_scheme=input_static) + return QuarkW8A8Fp8(weight_config, input_config) elif self._is_static_tensor_w8a8(weight_config, input_config): weight_qscheme = cast(str, weight_config.get("qscheme")) return QuarkW8A8Int8(qscheme=weight_qscheme, diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py index 47e0a492b23b9..c7bc98184d0eb 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Callable, Optional +from typing import Any, Callable, Optional, cast import torch from torch.nn import Parameter @@ -19,10 +19,19 @@ __all__ = ["QuarkW8A8Fp8"] class QuarkW8A8Fp8(QuarkScheme): - def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool]): - self.qscheme = qscheme - self.is_static_input_scheme = is_static_input_scheme - self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=False) + def __init__(self, weight_config: dict[str, Any], + input_config: Optional[dict[str, Any]]): + self.weight_qscheme = cast(str, weight_config.get("qscheme")) + self.is_static_input_scheme: bool = False + self.input_qscheme: Optional[str] = None + if input_config is not None: + self.is_static_input_scheme = not cast( + bool, input_config.get("is_dynamic")) + self.input_qscheme = cast(str, input_config.get("qscheme")) + self.use_per_token_if_dynamic = (not self.is_static_input_scheme \ + and self.input_qscheme == "per_channel") + self.fp8_linear = Fp8LinearOp( + use_per_token_if_dynamic=self.use_per_token_if_dynamic) self.out_dtype = torch.get_default_dtype() @classmethod @@ -34,7 +43,7 @@ class QuarkW8A8Fp8(QuarkScheme): # If per tensor, when we have a fused module (e.g. QKV) with per # tensor scales (thus N scales being passed to the kernel), # requantize so we can always run per tensor - if self.qscheme == "per_tensor": + if self.weight_qscheme == "per_tensor": if current_platform.is_rocm(): input_scale = getattr(layer, 'input_scale', None) weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( @@ -58,7 +67,7 @@ class QuarkW8A8Fp8(QuarkScheme): layer.weight_scale = Parameter(max_w_scale, requires_grad=False) # If channelwise, scales are already lined up, so just transpose. - elif self.qscheme == "per_channel": + elif self.weight_qscheme == "per_channel": weight = layer.weight if current_platform.is_fp8_fnuz(): @@ -73,13 +82,15 @@ class QuarkW8A8Fp8(QuarkScheme): requires_grad=False) else: weight_scale = layer.weight_scale.data - + if self.use_per_token_if_dynamic: + weight_scale = weight_scale.view(-1, 1) layer.weight = Parameter(weight.t(), requires_grad=False) # required by torch.compile to be torch.nn.Parameter layer.weight_scale = Parameter(weight_scale, requires_grad=False) else: - raise ValueError(f"Unknown quantization scheme {self.qscheme}") + raise ValueError( + f"Unknown quantization scheme {self.weight_qscheme}") # INPUT SCALE if self.is_static_input_scheme: @@ -109,14 +120,14 @@ class QuarkW8A8Fp8(QuarkScheme): # WEIGHT SCALE # TODO: update create_xxx_parameter functions to return # the newly added parameters - if self.qscheme == "per_channel": + if self.weight_qscheme == "per_channel": weight_scale = ChannelQuantScaleParameter( data=torch.empty((sum(output_partition_sizes)), dtype=torch.float32), output_dim=0, weight_loader=weight_loader) else: - assert self.qscheme == "per_tensor" + assert self.weight_qscheme == "per_tensor" weight_scale = PerTensorScaleParameter(data=torch.empty( len(output_partition_sizes), dtype=torch.float32), weight_loader=weight_loader) From 2062c0723d38a8f4a8a7565b61a99e8c81b5cacd Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 30 Jun 2025 08:13:50 -0700 Subject: [PATCH 068/175] [Spec Decode] Refactor spec decoding into a separate function (#20238) Signed-off-by: Woosuk Kwon --- vllm/v1/worker/gpu_model_runner.py | 93 +++++++++++++++++++----------- 1 file changed, 60 insertions(+), 33 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 290b9a44a80e2..e063e44dabfa1 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1388,6 +1388,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): hidden_states, aux_hidden_states = model_output else: hidden_states = model_output + aux_hidden_states = None + # Broadcast PP output for external_launcher (torchrun) # to make sure we are synced across pp ranks # TODO: Support overlapping mirco-batches @@ -1510,25 +1512,67 @@ class GPUModelRunner(LoRAModelRunnerMixin): if not self.speculative_config: # Speculative decoding is not enabled. spec_token_ids = None - elif self.speculative_config.method == "ngram": + else: + spec_token_ids = self.propose_draft_token_ids( + scheduler_output, + valid_sampled_token_ids, + sampling_metadata, + hidden_states, + sample_hidden_states, + aux_hidden_states, + spec_decode_metadata, + attn_metadata, + ) + + # Clear KVConnector state after all KVs are generated. + if has_kv_transfer_group(): + get_kv_transfer_group().clear_connector_metadata() + + self.eplb_step() + + return ModelRunnerOutput( + req_ids=self.input_batch.req_ids, + req_id_to_index=self.input_batch.req_id_to_index, + sampled_token_ids=valid_sampled_token_ids, + spec_token_ids=spec_token_ids, + logprobs=logprobs_lists, + prompt_logprobs_dict=prompt_logprobs_dict, + pooler_output=[], + finished_sending=finished_sending, + finished_recving=finished_recving, + num_nans_in_logits=num_nans_in_logits, + ) + + def propose_draft_token_ids( + self, + scheduler_output: "SchedulerOutput", + sampled_token_ids: list[list[int]], + sampling_metadata: SamplingMetadata, + hidden_states: torch.Tensor, + sample_hidden_states: torch.Tensor, + aux_hidden_states: Optional[torch.Tensor], + spec_decode_metadata: Optional[SpecDecodeMetadata], + attn_metadata: dict[str, Any], + ) -> list[list[int]]: + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + if self.speculative_config.method == "ngram": assert isinstance(self.drafter, NgramProposer) - spec_token_ids = self.generate_draft_token_ids( - valid_sampled_token_ids, sampling_metadata) + spec_token_ids = self.propose_ngram_draft_token_ids( + sampled_token_ids) elif self.speculative_config.method == "medusa": assert isinstance(self.drafter, MedusaProposer) - if max_gen_len == 1: + if sample_hidden_states.shape[0] == len(sampled_token_ids): + # The input to the target model does not include draft tokens. hidden_states = sample_hidden_states else: indices = [] offset = 0 for num_draft, tokens in zip( spec_decode_metadata.num_draft_tokens, - valid_sampled_token_ids): + sampled_token_ids): indices.append(offset + len(tokens) - 1) offset += num_draft + 1 - - indices = torch.tensor(indices, - device=sample_hidden_states.device) + indices = torch.tensor(indices, device=self.device) hidden_states = sample_hidden_states[indices] spec_token_ids = self.drafter.propose( @@ -1539,7 +1583,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): assert isinstance(self.drafter, EagleProposer) # TODO(woosuk): Refactor the loop. next_token_ids: list[int] = [] - for i, token_ids in enumerate(valid_sampled_token_ids): + for i, token_ids in enumerate(sampled_token_ids): if token_ids: # Common case. next_token_id = token_ids[-1] @@ -1569,7 +1613,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): if spec_decode_metadata is None: # input_ids can be None for multimodal models. target_token_ids = self.input_ids[:num_scheduled_tokens] - target_positions = positions[:num_scheduled_tokens] + # TODO(woosuk): Support M-RoPE. + target_positions = self.positions[:num_scheduled_tokens] if self.use_aux_hidden_state_outputs: target_hidden_states = torch.cat( [h[:num_scheduled_tokens] for h in aux_hidden_states], @@ -1582,7 +1627,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): # TODO(woosuk): Refactor this. num_draft_tokens = spec_decode_metadata.num_draft_tokens num_rejected_tokens = [ - n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0 + n + 1 - len(sampled_token_ids[i]) if n > 0 else 0 for i, n in enumerate(num_draft_tokens) ] num_rejected_tokens_tensor = async_tensor_h2d( @@ -1597,7 +1642,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): num_tokens, ) target_token_ids = self.input_ids[token_indices] - target_positions = positions[token_indices] + # TODO(woosuk): Support M-RoPE. + target_positions = self.positions[token_indices] if self.use_aux_hidden_state_outputs: target_hidden_states = torch.cat( [h[token_indices] for h in aux_hidden_states], dim=-1) @@ -1616,25 +1662,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): sampling_metadata=sampling_metadata, ) spec_token_ids = draft_token_ids.tolist() - - # Clear KVConnector state after all KVs are generated. - if has_kv_transfer_group(): - get_kv_transfer_group().clear_connector_metadata() - - self.eplb_step() - - return ModelRunnerOutput( - req_ids=self.input_batch.req_ids, - req_id_to_index=self.input_batch.req_id_to_index, - sampled_token_ids=valid_sampled_token_ids, - spec_token_ids=spec_token_ids, - logprobs=logprobs_lists, - prompt_logprobs_dict=prompt_logprobs_dict, - pooler_output=[], - finished_sending=finished_sending, - finished_recving=finished_recving, - num_nans_in_logits=num_nans_in_logits, - ) + return spec_token_ids def kv_connector_no_forward( self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput: @@ -1682,10 +1710,9 @@ class GPUModelRunner(LoRAModelRunnerMixin): scheduler_output.finished_req_ids) return None, None - def generate_draft_token_ids( + def propose_ngram_draft_token_ids( self, sampled_token_ids: list[list[int]], - sampling_metadata: SamplingMetadata, ) -> list[list[int]]: # TODO(woosuk): Optimize. draft_token_ids: list[list[int]] = [] From 2965c99c86b460ee819e4805764d769c7b7d3d8e Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 30 Jun 2025 08:28:13 -0700 Subject: [PATCH 069/175] [Spec Decode] Clean up spec decode example (#20240) Signed-off-by: Woosuk Kwon --- examples/offline_inference/eagle.py | 144 ---------------------- examples/offline_inference/spec_decode.py | 40 +++--- 2 files changed, 21 insertions(+), 163 deletions(-) delete mode 100644 examples/offline_inference/eagle.py diff --git a/examples/offline_inference/eagle.py b/examples/offline_inference/eagle.py deleted file mode 100644 index f4193fdb8bd38..0000000000000 --- a/examples/offline_inference/eagle.py +++ /dev/null @@ -1,144 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import argparse -import json -import os - -from transformers import AutoTokenizer - -from vllm import LLM, SamplingParams -from vllm.v1.metrics.reader import Counter, Vector - - -def load_prompts(dataset_path, num_prompts): - if os.path.exists(dataset_path): - prompts = [] - try: - with open(dataset_path) as f: - for line in f: - data = json.loads(line) - prompts.append(data["turns"][0]) - except Exception as e: - print(f"Error reading dataset: {e}") - return [] - else: - prompts = ["The future of AI is", "The president of the United States is"] - - return prompts[:num_prompts] - - -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--dataset", - type=str, - default="./examples/data/gsm8k.jsonl", - help="downloaded from the eagle repo " - "https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/", - ) - parser.add_argument( - "--method", type=str, default="eagle", choices=["eagle", "eagle3"] - ) - parser.add_argument("--max_num_seqs", type=int, default=8) - parser.add_argument("--num_prompts", type=int, default=80) - parser.add_argument("--num_spec_tokens", type=int, default=2) - parser.add_argument("--tp", type=int, default=1) - parser.add_argument("--draft_tp", type=int, default=1) - parser.add_argument("--enforce_eager", action="store_true") - parser.add_argument("--enable_chunked_prefill", action="store_true") - parser.add_argument("--max_num_batched_tokens", type=int, default=2048) - parser.add_argument("--temp", type=float, default=0) - return parser.parse_args() - - -def main(): - args = parse_args() - - model_dir = "meta-llama/Llama-3.1-8B-Instruct" - - if args.method == "eagle": - eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B" - elif args.method == "eagle3": - eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" - else: - raise ValueError(f"unknown method: {args.method}") - - max_model_len = 2048 - - tokenizer = AutoTokenizer.from_pretrained(model_dir) - - prompts = load_prompts(args.dataset, args.num_prompts) - - prompt_ids = [ - tokenizer.apply_chat_template( - [{"role": "user", "content": prompt}], add_generation_prompt=True - ) - for prompt in prompts - ] - - llm = LLM( - model=model_dir, - trust_remote_code=True, - tensor_parallel_size=args.tp, - enable_chunked_prefill=args.enable_chunked_prefill, - max_num_batched_tokens=args.max_num_batched_tokens, - enforce_eager=args.enforce_eager, - max_model_len=max_model_len, - max_num_seqs=args.max_num_seqs, - gpu_memory_utilization=0.8, - speculative_config={ - "method": args.method, - "model": eagle_dir, - "num_speculative_tokens": args.num_spec_tokens, - "draft_tensor_parallel_size": args.draft_tp, - "max_model_len": max_model_len, - }, - disable_log_stats=False, - ) - - sampling_params = SamplingParams(temperature=args.temp, max_tokens=256) - - outputs = llm.generate(prompt_token_ids=prompt_ids, sampling_params=sampling_params) - - # print the generated text - for output in outputs: - print("-" * 50) - print(f"prompt: {output.prompt}") - print(f"generated text: {output.outputs[0].text}") - print("-" * 50) - - try: - metrics = llm.get_metrics() - except AssertionError: - print("Metrics are not supported in the V0 engine.") - return - - num_drafts = num_accepted = 0 - acceptance_counts = [0] * args.num_spec_tokens - for metric in metrics: - if metric.name == "vllm:spec_decode_num_drafts": - assert isinstance(metric, Counter) - num_drafts += metric.value - elif metric.name == "vllm:spec_decode_num_accepted_tokens": - assert isinstance(metric, Counter) - num_accepted += metric.value - elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos": - assert isinstance(metric, Vector) - for pos in range(len(metric.values)): - acceptance_counts[pos] += metric.values[pos] - - print("-" * 50) - print(f"mean acceptance length: {1 + (num_accepted / num_drafts):.2f}") - print("-" * 50) - - # print acceptance at each token position - for i in range(len(acceptance_counts)): - print(f"acceptance at token {i}:{acceptance_counts[i] / num_drafts:.2f}") - - -if __name__ == "__main__": - print( - "[WARNING] Use examples/offline_inference/spec_decode.py" - " instead of this script." - ) - main() diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py index 6fa68d2ecee1d..90d103e5cb05d 100644 --- a/examples/offline_inference/spec_decode.py +++ b/examples/offline_inference/spec_decode.py @@ -16,24 +16,17 @@ def parse_args(): parser = FlexibleArgumentParser() add_dataset_parser(parser) parser.add_argument( - "--dataset", + "--method", type=str, - default="./examples/data/gsm8k.jsonl", - help="downloaded from the eagle repo " - "https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/", + default="eagle", + choices=["ngram", "eagle", "eagle3", "mtp"], ) - parser.add_argument( - "--method", type=str, default="eagle", choices=["ngram", "eagle", "eagle3"] - ) - parser.add_argument("--max-num-seqs", type=int, default=8) parser.add_argument("--num-spec-tokens", type=int, default=2) parser.add_argument("--prompt-lookup-max", type=int, default=5) parser.add_argument("--prompt-lookup-min", type=int, default=2) parser.add_argument("--tp", type=int, default=1) - parser.add_argument("--draft-tp", type=int, default=1) parser.add_argument("--enforce-eager", action="store_true") parser.add_argument("--enable-chunked-prefill", action="store_true") - parser.add_argument("--max-num-batched-tokens", type=int, default=2048) parser.add_argument("--temp", type=float, default=0) parser.add_argument("--top-p", type=float, default=1.0) parser.add_argument("--top-k", type=int, default=-1) @@ -41,7 +34,6 @@ def parse_args(): parser.add_argument("--output-len", type=int, default=256) parser.add_argument("--model-dir", type=str, default=None) parser.add_argument("--eagle-dir", type=str, default=None) - parser.add_argument("--max-model-len", type=int, default=2048) return parser.parse_args() @@ -71,8 +63,6 @@ def main(): "method": args.method, "model": eagle_dir, "num_speculative_tokens": args.num_spec_tokens, - "draft_tensor_parallel_size": args.draft_tp, - "max_model_len": args.max_model_len, } elif args.method == "ngram": speculative_config = { @@ -80,7 +70,6 @@ def main(): "num_speculative_tokens": args.num_spec_tokens, "prompt_lookup_max": args.prompt_lookup_max, "prompt_lookup_min": args.prompt_lookup_min, - "max_model_len": args.max_model_len, } else: raise ValueError(f"unknown method: {args.method}") @@ -92,7 +81,6 @@ def main(): enable_chunked_prefill=args.enable_chunked_prefill, max_num_batched_tokens=args.max_num_batched_tokens, enforce_eager=args.enforce_eager, - max_model_len=args.max_model_len, max_num_seqs=args.max_num_seqs, gpu_memory_utilization=0.8, speculative_config=speculative_config, @@ -116,27 +104,41 @@ def main(): print("Metrics are not supported in the V0 engine.") return - num_drafts = num_accepted = 0 + total_num_output_tokens = sum( + len(output.outputs[0].token_ids) for output in outputs + ) + num_drafts = 0 + num_draft_tokens = 0 + num_accepted_tokens = 0 acceptance_counts = [0] * args.num_spec_tokens for metric in metrics: if metric.name == "vllm:spec_decode_num_drafts": assert isinstance(metric, Counter) num_drafts += metric.value + elif metric.name == "vllm:spec_decode_num_draft_tokens": + assert isinstance(metric, Counter) + num_draft_tokens += metric.value elif metric.name == "vllm:spec_decode_num_accepted_tokens": assert isinstance(metric, Counter) - num_accepted += metric.value + num_accepted_tokens += metric.value elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos": assert isinstance(metric, Vector) for pos in range(len(metric.values)): acceptance_counts[pos] += metric.values[pos] print("-" * 50) - print(f"mean acceptance length: {1 + (num_accepted / num_drafts):.2f}") + print(f"total_num_output_tokens: {total_num_output_tokens}") + print(f"num_drafts: {num_drafts}") + print(f"num_draft_tokens: {num_draft_tokens}") + print(f"num_accepted_tokens: {num_accepted_tokens}") + acceptance_length = 1 + (num_accepted_tokens / num_drafts) if num_drafts > 0 else 1 + print(f"mean acceptance length: {acceptance_length:.2f}") print("-" * 50) # print acceptance at each token position for i in range(len(acceptance_counts)): - print(f"acceptance at token {i}:{acceptance_counts[i] / num_drafts:.2f}") + acceptance_rate = acceptance_counts[i] / num_drafts if num_drafts > 0 else 0 + print(f"acceptance at token {i}: {acceptance_rate:.2f}") if __name__ == "__main__": From 2863befce359ee1a82afe02d1953252866aa3e96 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 30 Jun 2025 09:07:50 -0700 Subject: [PATCH 070/175] [Optimization] Use Shared `CachedRequestData` Instance Across All Requests (#20232) Signed-off-by: Woosuk Kwon --- tests/v1/core/test_scheduler.py | 130 +++++++++--------- .../unit/test_remote_decode_lifecycle.py | 4 +- .../unit/test_remote_prefill_lifecycle.py | 12 +- tests/v1/kv_connector/unit/utils.py | 1 - tests/v1/tpu/worker/test_tpu_model_runner.py | 22 +-- tests/v1/worker/test_gpu_model_runner.py | 22 +-- .../kv_connector/v1/p2p/p2p_nccl_connector.py | 43 +++--- .../v1/shared_storage_connector.py | 19 ++- vllm/v1/core/sched/output.py | 34 +++-- vllm/v1/core/sched/scheduler.py | 106 ++++++-------- vllm/v1/worker/gpu_model_runner.py | 34 ++--- vllm/v1/worker/tpu_model_runner.py | 24 ++-- 12 files changed, 220 insertions(+), 231 deletions(-) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 8994816a3017c..652a556659fe3 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -10,7 +10,7 @@ from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, SchedulerConfig, SpeculativeConfig, VllmConfig) from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.sampling_params import SamplingParams -from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec) @@ -198,7 +198,7 @@ def test_schedule(enable_prefix_caching: Optional[bool], # Test initial scheduling output = scheduler.schedule() assert len(output.scheduled_new_reqs) == len(requests) - assert len(output.scheduled_cached_reqs) == 0 + assert output.scheduled_cached_reqs.num_reqs == 0 assert len(output.finished_req_ids) == 0 # Verify all requests are scheduled. for req_id, num_tokens in output.num_scheduled_tokens.items(): @@ -225,7 +225,7 @@ def test_schedule_multimodal_requests(): output = scheduler.schedule() assert len(output.scheduled_new_reqs) == len(requests) - assert len(output.scheduled_cached_reqs) == 0 + assert output.scheduled_cached_reqs.num_reqs == 0 assert len(output.finished_req_ids) == 0 for req_id, num_tokens in output.num_scheduled_tokens.items(): assert num_tokens == len(requests[int(req_id)].prompt_token_ids) @@ -259,7 +259,7 @@ def test_schedule_partial_requests(): output = scheduler.schedule() assert len(output.scheduled_new_reqs) == 3 - assert len(output.scheduled_cached_reqs) == 0 + assert output.scheduled_cached_reqs.num_reqs == 0 assert len(output.finished_req_ids) == 0 assert scheduler.max_num_encoder_input_tokens == 1024 @@ -295,7 +295,7 @@ def test_schedule_partial_requests(): output = scheduler.schedule() assert len(scheduler.running) == 3 assert len(output.scheduled_new_reqs) == 0 - assert len(output.scheduled_cached_reqs) == 2 + assert output.scheduled_cached_reqs.num_reqs == 2 assert len(output.finished_req_ids) == 0 assert output.num_scheduled_tokens[requests[0].request_id] == 1 assert output.num_scheduled_tokens[requests[1].request_id] == 700 @@ -319,7 +319,7 @@ def test_no_mm_input_chunking(): output = scheduler.schedule() assert len(output.scheduled_new_reqs) == 1 - assert len(output.scheduled_cached_reqs) == 0 + assert output.scheduled_cached_reqs.num_reqs == 0 assert len(output.finished_req_ids) == 0 # We want to only see the 400 text tokens at the start scheduled assert output.num_scheduled_tokens[requests[0].request_id] == 400 @@ -342,7 +342,7 @@ def test_no_mm_input_chunking(): output = scheduler.schedule() assert len(scheduler.running) == 1 assert len(output.scheduled_new_reqs) == 0 - assert len(output.scheduled_cached_reqs) == 1 + assert output.scheduled_cached_reqs.num_reqs == 1 assert len(output.finished_req_ids) == 0 assert output.num_scheduled_tokens[requests[0].request_id] == 800 @@ -379,7 +379,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): output = scheduler.schedule() assert len(output.scheduled_new_reqs) == 3 - assert len(output.scheduled_cached_reqs) == 0 + assert output.scheduled_cached_reqs.num_reqs == 0 assert len(output.finished_req_ids) == 0 # The first request is scheduled partially - 400. @@ -408,7 +408,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): output1 = scheduler.schedule() assert len(scheduler.running) == 3 assert len(output1.scheduled_new_reqs) == 0 - assert len(output1.scheduled_cached_reqs) == 3 + assert output1.scheduled_cached_reqs.num_reqs == 3 assert len(output1.finished_req_ids) == 0 assert output1.num_scheduled_tokens[requests[0].request_id] == 400 assert output1.num_scheduled_tokens[requests[1].request_id] == 400 @@ -430,7 +430,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): output2 = scheduler.schedule() assert len(scheduler.running) == 3 assert len(output2.scheduled_new_reqs) == 0 - assert len(output2.scheduled_cached_reqs) == 3 + assert output2.scheduled_cached_reqs.num_reqs == 3 assert len(output2.finished_req_ids) == 0 assert output2.num_scheduled_tokens[requests[0].request_id] == 1 assert output2.num_scheduled_tokens[requests[1].request_id] == 1 @@ -449,23 +449,24 @@ def test_stop_via_update_from_output(): scheduler.requests[req.request_id] = req scheduler.running.append(req) - scheduler_output = SchedulerOutput(scheduled_new_reqs=[], - scheduled_cached_reqs=[], - num_scheduled_tokens={ - requests[0].request_id: 1, - requests[1].request_id: 2 - }, - total_num_scheduled_tokens=3, - scheduled_encoder_inputs={}, - scheduled_spec_decode_tokens={ - requests[0].request_id: [], - requests[1].request_id: [10] - }, - num_common_prefix_blocks=0, - finished_req_ids=set(), - free_encoder_input_ids=[], - structured_output_request_ids={}, - grammar_bitmask=None) + scheduler_output = SchedulerOutput( + scheduled_new_reqs=[], + scheduled_cached_reqs=CachedRequestData.make_empty(), + num_scheduled_tokens={ + requests[0].request_id: 1, + requests[1].request_id: 2 + }, + total_num_scheduled_tokens=3, + scheduled_encoder_inputs={}, + scheduled_spec_decode_tokens={ + requests[0].request_id: [], + requests[1].request_id: [10] + }, + num_common_prefix_blocks=0, + finished_req_ids=set(), + free_encoder_input_ids=[], + structured_output_request_ids={}, + grammar_bitmask=None) model_output = ModelRunnerOutput( req_ids=[req.request_id for req in requests], @@ -501,23 +502,25 @@ def test_stop_via_update_from_output(): scheduler.requests[req.request_id] = req scheduler.running.append(req) - scheduler_output = SchedulerOutput(scheduled_new_reqs=[], - scheduled_cached_reqs=[], - num_scheduled_tokens={ - requests[0].request_id: 3, - requests[1].request_id: 2 - }, - total_num_scheduled_tokens=5, - scheduled_encoder_inputs={}, - scheduled_spec_decode_tokens={ - requests[0].request_id: [10, 42], - requests[1].request_id: [13] - }, - num_common_prefix_blocks=0, - finished_req_ids=set(), - free_encoder_input_ids=[], - structured_output_request_ids={}, - grammar_bitmask=None) + scheduler_output = SchedulerOutput( + scheduled_new_reqs=[], + scheduled_cached_reqs=CachedRequestData.make_empty(), + num_scheduled_tokens={ + requests[0].request_id: 3, + requests[1].request_id: 2 + }, + total_num_scheduled_tokens=5, + scheduled_encoder_inputs={}, + scheduled_spec_decode_tokens={ + requests[0].request_id: [10, 42], + requests[1].request_id: [13] + }, + num_common_prefix_blocks=0, + finished_req_ids=set(), + free_encoder_input_ids=[], + structured_output_request_ids={}, + grammar_bitmask=None, + ) model_output = ModelRunnerOutput( req_ids=[req.request_id for req in requests], @@ -551,23 +554,25 @@ def test_stop_via_update_from_output(): scheduler.requests[req.request_id] = req scheduler.running.append(req) - scheduler_output = SchedulerOutput(scheduled_new_reqs=[], - scheduled_cached_reqs=[], - num_scheduled_tokens={ - requests[0].request_id: 3, - requests[1].request_id: 1 - }, - total_num_scheduled_tokens=4, - scheduled_encoder_inputs={}, - scheduled_spec_decode_tokens={ - requests[0].request_id: [10, 11], - requests[1].request_id: [] - }, - num_common_prefix_blocks=0, - finished_req_ids=set(), - free_encoder_input_ids=[], - structured_output_request_ids={}, - grammar_bitmask=None) + scheduler_output = SchedulerOutput( + scheduled_new_reqs=[], + scheduled_cached_reqs=CachedRequestData.make_empty(), + num_scheduled_tokens={ + requests[0].request_id: 3, + requests[1].request_id: 1 + }, + total_num_scheduled_tokens=4, + scheduled_encoder_inputs={}, + scheduled_spec_decode_tokens={ + requests[0].request_id: [10, 11], + requests[1].request_id: [] + }, + num_common_prefix_blocks=0, + finished_req_ids=set(), + free_encoder_input_ids=[], + structured_output_request_ids={}, + grammar_bitmask=None, + ) model_output = ModelRunnerOutput( req_ids=[req.request_id for req in requests], @@ -603,7 +608,7 @@ def test_stop_via_update_from_output(): scheduler_output = SchedulerOutput( scheduled_new_reqs=[], - scheduled_cached_reqs=[], + scheduled_cached_reqs=CachedRequestData.make_empty(), num_scheduled_tokens={requests[0].request_id: 3}, total_num_scheduled_tokens=3, scheduled_encoder_inputs={}, @@ -1208,7 +1213,6 @@ def assert_scheduler_empty(scheduler: Scheduler): assert len(scheduler.waiting) == 0 assert len(scheduler.running) == 0 assert len(scheduler.finished_req_ids) == 0 - assert len(scheduler._cached_reqs_data) == 0 # EncoderCacheManager. assert len(scheduler.encoder_cache_manager.freed) == 0 diff --git a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py index ff36a281c413d..12a71d97e8d29 100644 --- a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py +++ b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py @@ -66,7 +66,7 @@ def test_basic_lifecycle(): assert len(scheduler_output.finished_req_ids) == 1 assert request_id in scheduler_output.finished_req_ids assert len(scheduler_output.scheduled_new_reqs) == 0 - assert len(scheduler_output.scheduled_cached_reqs) == 0 + assert scheduler_output.scheduled_cached_reqs.num_reqs == 0 assert len(scheduler.finished_req_ids) == 0 # (2b): execute_model() @@ -81,7 +81,7 @@ def test_basic_lifecycle(): assert len(scheduler.running) == 0 assert len(scheduler_output.finished_req_ids) == 0 assert len(scheduler_output.scheduled_new_reqs) == 0 - assert len(scheduler_output.scheduled_cached_reqs) == 0 + assert scheduler_output.scheduled_cached_reqs.num_reqs == 0 assert len(scheduler.finished_req_ids) == 0 # (3b): execute_model() diff --git a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py index a1156306dc4bf..f89970bf2c807 100644 --- a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py +++ b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py @@ -36,7 +36,7 @@ def test_basic_lifecycle(): # Nothing running and empty scheduler output. assert len(scheduler.running) == 0 assert len(scheduler_output.scheduled_new_reqs) == 0 - assert len(scheduler_output.scheduled_cached_reqs) == 0 + assert scheduler_output.scheduled_cached_reqs.num_reqs == 0 assert len(scheduler_output.num_scheduled_tokens) == 0 assert scheduler_output.total_num_scheduled_tokens == 0 @@ -158,7 +158,7 @@ def test_interleaved_lifecycle(): assert len(scheduler.running) == 2 assert len(scheduler.waiting) == 1 assert len(scheduler_output.scheduled_new_reqs) == 1 - assert len(scheduler_output.scheduled_cached_reqs) == 1 + assert scheduler_output.scheduled_cached_reqs.num_reqs == 1 model_runner_output = create_model_runner_output( [request_local_a, request_local_b]) @@ -169,7 +169,7 @@ def test_interleaved_lifecycle(): assert len(scheduler.running) == 2 assert len(scheduler.waiting) == 1 assert len(scheduler_output.scheduled_new_reqs) == 0 - assert len(scheduler_output.scheduled_cached_reqs) == 2 + assert scheduler_output.scheduled_cached_reqs.num_reqs == 2 model_runner_output = create_model_runner_output( reqs=[request_local_a, request_local_b]) @@ -177,14 +177,14 @@ def test_interleaved_lifecycle(): assert len(scheduler.running) == 2 assert len(scheduler.waiting) == 1 assert len(scheduler_output.scheduled_new_reqs) == 0 - assert len(scheduler_output.scheduled_cached_reqs) == 2 + assert scheduler_output.scheduled_cached_reqs.num_reqs == 2 # STEP 4: KVs arrive. scheduler_output = scheduler.schedule() assert len(scheduler.running) == 2 assert len(scheduler.waiting) == 1 assert len(scheduler_output.scheduled_new_reqs) == 0 - assert len(scheduler_output.scheduled_cached_reqs) == 2 + assert scheduler_output.scheduled_cached_reqs.num_reqs == 2 model_runner_output = create_model_runner_output( [request_local_a, request_local_b], @@ -196,7 +196,7 @@ def test_interleaved_lifecycle(): assert len(scheduler.running) == 3 assert len(scheduler.waiting) == 0 assert len(scheduler_output.scheduled_new_reqs) == 1 - assert len(scheduler_output.scheduled_cached_reqs) == 2 + assert scheduler_output.scheduled_cached_reqs.num_reqs == 2 model_runner_output = create_model_runner_output( [request_local_a, request_local_b, request_remote]) diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index 61f59f35f75b9..983d900606fc9 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -25,7 +25,6 @@ def assert_scheduler_empty(scheduler: Scheduler): assert len(scheduler.running) == 0 assert len(scheduler.finished_req_ids) == 0 assert len(scheduler.finished_recving_kv_req_ids) == 0 - assert len(scheduler._cached_reqs_data) == 0 # EncoderCacheManager. assert len(scheduler.encoder_cache_manager.freed) == 0 diff --git a/tests/v1/tpu/worker/test_tpu_model_runner.py b/tests/v1/tpu/worker/test_tpu_model_runner.py index 25839d0897a4c..40db0b2afe0d9 100644 --- a/tests/v1/tpu/worker/test_tpu_model_runner.py +++ b/tests/v1/tpu/worker/test_tpu_model_runner.py @@ -82,7 +82,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: return SchedulerOutput( scheduled_new_reqs=new_reqs, - scheduled_cached_reqs=[], + scheduled_cached_reqs=CachedRequestData.make_empty(), num_scheduled_tokens=num_scheduled_tokens, total_num_scheduled_tokens=total_num_scheduled_tokens, scheduled_spec_decode_tokens={}, @@ -161,7 +161,7 @@ def test_update_states_request_finished(model_runner): # finish req scheduler_output = SchedulerOutput( scheduled_new_reqs=[], - scheduled_cached_reqs=[], + scheduled_cached_reqs=CachedRequestData.make_empty(), num_scheduled_tokens={}, total_num_scheduled_tokens=0, scheduled_spec_decode_tokens={}, @@ -191,7 +191,7 @@ def test_update_states_request_resumed(model_runner): # unschedule req scheduler_output = SchedulerOutput( scheduled_new_reqs=[], - scheduled_cached_reqs=[], + scheduled_cached_reqs=CachedRequestData.make_empty(), num_scheduled_tokens={}, total_num_scheduled_tokens=0, scheduled_spec_decode_tokens={}, @@ -209,16 +209,16 @@ def test_update_states_request_resumed(model_runner): # resume req cached_req_data = CachedRequestData( - req_id=req_id, - resumed_from_preemption=False, - new_token_ids=[], - new_block_ids=([], ), - num_computed_tokens=0, + req_ids=[req_id], + resumed_from_preemption=[False], + new_token_ids=[[]], + new_block_ids=[([], )], + num_computed_tokens=[0], ) scheduler_output = SchedulerOutput( scheduled_new_reqs=[], - scheduled_cached_reqs=[cached_req_data], + scheduled_cached_reqs=cached_req_data, num_scheduled_tokens={req_id: 1}, total_num_scheduled_tokens=1, scheduled_spec_decode_tokens={}, @@ -249,7 +249,7 @@ def test_update_states_no_changes(model_runner): # schedule req scheduler_output = SchedulerOutput( scheduled_new_reqs=[], - scheduled_cached_reqs=[], + scheduled_cached_reqs=CachedRequestData.make_empty(), num_scheduled_tokens={req_id: 1}, total_num_scheduled_tokens=1, scheduled_spec_decode_tokens={}, @@ -284,7 +284,7 @@ def test_update_states_request_unscheduled(model_runner): # unschedule req_1 scheduler_output = SchedulerOutput( scheduled_new_reqs=[], - scheduled_cached_reqs=[], + scheduled_cached_reqs=CachedRequestData.make_empty(), num_scheduled_tokens={req_ids[0]: 1}, total_num_scheduled_tokens=1, scheduled_spec_decode_tokens={}, diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 583a88d8e6ec6..c739b23b90dc8 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -133,7 +133,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: return SchedulerOutput( scheduled_new_reqs=new_reqs, - scheduled_cached_reqs=[], + scheduled_cached_reqs=CachedRequestData.make_empty(), num_scheduled_tokens=num_scheduled_tokens, total_num_scheduled_tokens=total_num_scheduled_tokens, scheduled_spec_decode_tokens={}, @@ -199,7 +199,7 @@ def test_update_states_request_finished(model_runner): # finish req scheduler_output = SchedulerOutput( scheduled_new_reqs=[], - scheduled_cached_reqs=[], + scheduled_cached_reqs=CachedRequestData.make_empty(), num_scheduled_tokens={}, total_num_scheduled_tokens=0, scheduled_spec_decode_tokens={}, @@ -231,7 +231,7 @@ def test_update_states_request_resumed(model_runner): # unschedule req scheduler_output = SchedulerOutput( scheduled_new_reqs=[], - scheduled_cached_reqs=[], + scheduled_cached_reqs=CachedRequestData.make_empty(), num_scheduled_tokens={}, total_num_scheduled_tokens=0, scheduled_spec_decode_tokens={}, @@ -249,16 +249,16 @@ def test_update_states_request_resumed(model_runner): # resume req cached_req_data = CachedRequestData( - req_id=req_id, - resumed_from_preemption=False, - new_token_ids=[], - new_block_ids=([], ), - num_computed_tokens=0, + req_ids=[req_id], + resumed_from_preemption=[False], + new_token_ids=[[]], + new_block_ids=([[0]], ), + num_computed_tokens=[0], ) scheduler_output = SchedulerOutput( scheduled_new_reqs=[], - scheduled_cached_reqs=[cached_req_data], + scheduled_cached_reqs=cached_req_data, num_scheduled_tokens={req_id: 1}, total_num_scheduled_tokens=1, scheduled_spec_decode_tokens={}, @@ -339,7 +339,7 @@ def test_update_states_no_changes(model_runner): # schedule req scheduler_output = SchedulerOutput( scheduled_new_reqs=[], - scheduled_cached_reqs=[], + scheduled_cached_reqs=CachedRequestData.make_empty(), num_scheduled_tokens={req_id: 1}, total_num_scheduled_tokens=1, scheduled_spec_decode_tokens={}, @@ -376,7 +376,7 @@ def test_update_states_request_unscheduled(model_runner): # unschedule req_1 scheduler_output = SchedulerOutput( scheduled_new_reqs=[], - scheduled_cached_reqs=[], + scheduled_cached_reqs=CachedRequestData.make_empty(), num_scheduled_tokens={req_ids[0]: 1}, total_num_scheduled_tokens=1, scheduled_spec_decode_tokens={}, diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py index a47deaf91272e..2f870971ded70 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py @@ -371,45 +371,48 @@ class P2pNcclConnector(KVConnectorBase_V1): block_size=self._block_size) self._requests_need_load.pop(new_req.req_id) - for cached_req in scheduler_output.scheduled_cached_reqs: + cached_reqs = scheduler_output.scheduled_cached_reqs + for i, req_id in enumerate(cached_reqs.req_ids): + num_computed_tokens = cached_reqs.num_computed_tokens[i] + new_block_ids = cached_reqs.new_block_ids[i] + resumed_from_preemption = cached_reqs.resumed_from_preemption[i] + if self.is_producer: num_scheduled_tokens = ( - scheduler_output.num_scheduled_tokens)[cached_req.req_id] - num_tokens = (num_scheduled_tokens + - cached_req.num_computed_tokens) - assert cached_req.req_id in self.chunked_prefill - block_ids = cached_req.new_block_ids[0] - if not cached_req.resumed_from_preemption: - block_ids = (self.chunked_prefill[cached_req.req_id][0] + - block_ids) - prompt_token_ids = self.chunked_prefill[cached_req.req_id][1] + scheduler_output.num_scheduled_tokens)[req_id] + num_tokens = (num_scheduled_tokens + num_computed_tokens) + assert req_id in self.chunked_prefill + block_ids = new_block_ids[0] + if not resumed_from_preemption: + block_ids = (self.chunked_prefill[req_id][0] + block_ids) + prompt_token_ids = self.chunked_prefill[req_id][1] # the request's prompt is chunked prefill again if num_tokens < len(prompt_token_ids): - self.chunked_prefill[cached_req.req_id] = ( - block_ids, prompt_token_ids) + self.chunked_prefill[req_id] = (block_ids, + prompt_token_ids) continue # the request's prompt is all prefilled finally - meta.add_request(request_id=cached_req.req_id, + meta.add_request(request_id=req_id, token_ids=prompt_token_ids, block_ids=block_ids, block_size=self._block_size) - self.chunked_prefill.pop(cached_req.req_id, None) + self.chunked_prefill.pop(req_id, None) continue # NOTE(rob): here we rely on the resumed requests being # the first N requests in the list scheduled_cache_reqs. - if not cached_req.resumed_from_preemption: + if not resumed_from_preemption: break - if cached_req.req_id in self._requests_need_load: - request, _ = self._requests_need_load.pop(cached_req.req_id) - total_tokens = cached_req.num_computed_tokens + 1 + if req_id in self._requests_need_load: + request, _ = self._requests_need_load.pop(req_id) + total_tokens = num_computed_tokens + 1 token_ids = request.all_token_ids[:total_tokens] # NOTE(rob): For resumed req, new_block_ids is all # of the block_ids for the request. - block_ids = cached_req.new_block_ids[0] + block_ids = new_block_ids[0] - meta.add_request(request_id=cached_req.req_id, + meta.add_request(request_id=req_id, token_ids=token_ids, block_ids=block_ids, block_size=self._block_size) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py index f86b92692a0e5..0bceee19f873d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py @@ -304,23 +304,28 @@ class SharedStorageConnector(KVConnectorBase_V1): block_size=self._block_size, is_store=True) - for cached_req in scheduler_output.scheduled_cached_reqs: + cached_reqs = scheduler_output.scheduled_cached_reqs + for i, req_id in enumerate(cached_reqs.req_ids): + num_computed_tokens = cached_reqs.num_computed_tokens[i] + new_token_ids = cached_reqs.new_token_ids[i] + new_block_ids = cached_reqs.new_block_ids[i] + resumed_from_preemption = cached_reqs.resumed_from_preemption[i] + # NOTE(rob): here we rely on the resumed requests being # the first N requests in the list scheduled_cache_reqs. - if not cached_req.resumed_from_preemption: + if not resumed_from_preemption: break - if cached_req.req_id in self._requests_need_load: + if req_id in self._requests_need_load: # NOTE(rob): cached_req_data does not have the full # list of token ids (only new tokens). So we look it # up in the actual request object. - request = self._requests_need_load[cached_req.req_id] - total_tokens = (len(cached_req.new_token_ids) + - cached_req.num_computed_tokens) + request = self._requests_need_load[req_id] + total_tokens = (len(new_token_ids) + num_computed_tokens) token_ids = request.all_token_ids[:total_tokens] # NOTE(rob): For resumed req, new_block_ids is all # of the block_ids for the request. - block_ids = cached_req.new_block_ids[0] + block_ids = new_block_ids[0] meta.add_request(token_ids=token_ids, block_ids=block_ids, diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index 6f31031a1086e..efc5b3012ec2f 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -83,29 +83,27 @@ class NewRequestData: @dataclass class CachedRequestData: - req_id: str + req_ids: list[str] # If resumed_from_preemption is False, new_block_ids will be appended to # the request's block IDs. If True, new_block_ids will be used as the # request's block IDs instead of appending to the existing block IDs. - resumed_from_preemption: bool - new_token_ids: list[int] - new_block_ids: tuple[list[int], ...] - num_computed_tokens: int + resumed_from_preemption: list[bool] + new_token_ids: list[list[int]] + new_block_ids: list[tuple[list[int], ...]] + num_computed_tokens: list[int] + + @property + def num_reqs(self) -> int: + return len(self.req_ids) @classmethod - def from_request( - cls, - request: Request, - resumed_from_preemption: bool, - new_token_ids: list[int], - new_block_ids: tuple[list[int], ...], - ) -> CachedRequestData: + def make_empty(cls) -> CachedRequestData: return cls( - req_id=request.request_id, - resumed_from_preemption=resumed_from_preemption, - new_token_ids=new_token_ids, - new_block_ids=new_block_ids, - num_computed_tokens=request.num_computed_tokens, + req_ids=[], + resumed_from_preemption=[], + new_token_ids=[], + new_block_ids=[], + num_computed_tokens=[], ) @@ -119,7 +117,7 @@ class SchedulerOutput: # list of the requests that have been scheduled before. # Since the request's data is already cached in the worker processes, # we only send the diff to minimize the communication cost. - scheduled_cached_reqs: list[CachedRequestData] + scheduled_cached_reqs: CachedRequestData # req_id -> num_scheduled_tokens # Number of tokens scheduled for each request. diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 00b0844a5660b..20a40d74f3118 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -3,8 +3,9 @@ from __future__ import annotations +import itertools import time -from collections import defaultdict, deque +from collections import defaultdict from collections.abc import Iterable from typing import Any, Optional, Union @@ -117,12 +118,6 @@ class Scheduler(SchedulerInterface): # KV Connector: requests in process of async KV loading or recving self.finished_recving_kv_req_ids: set[str] = set() - # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating - # them at each scheduling step. - # Request id -> deque of CachedRequestData - self._cached_reqs_data: dict[ - str, deque[CachedRequestData]] = defaultdict(deque) - # Encoder-related. # Calculate encoder cache size if applicable # NOTE: For now we use the same budget for both compute and space. @@ -547,27 +542,16 @@ class Scheduler(SchedulerInterface): req_to_new_block_ids[req.request_id]) for req in scheduled_new_reqs ] - resumed_reqs_data = [ - self._make_cached_request_data( - req, - num_scheduled_tokens[req.request_id], - len(scheduled_spec_decode_tokens.get(req.request_id, ())), - req_to_new_block_ids[req.request_id], - resumed_from_preemption=True, - ) for req in scheduled_resumed_reqs - ] - running_reqs_data = [ - self._make_cached_request_data( - req, - num_scheduled_tokens[req.request_id], - len(scheduled_spec_decode_tokens.get(req.request_id, ())), - req_to_new_block_ids[req.request_id], - resumed_from_preemption=False, - ) for req in scheduled_running_reqs - ] + cached_reqs_data = self._make_cached_request_data( + scheduled_running_reqs, + scheduled_resumed_reqs, + num_scheduled_tokens, + scheduled_spec_decode_tokens, + req_to_new_block_ids, + ) scheduler_output = SchedulerOutput( scheduled_new_reqs=new_reqs_data, - scheduled_cached_reqs=resumed_reqs_data + running_reqs_data, + scheduled_cached_reqs=cached_reqs_data, num_scheduled_tokens=num_scheduled_tokens, total_num_scheduled_tokens=total_num_scheduled_tokens, scheduled_spec_decode_tokens=scheduled_spec_decode_tokens, @@ -613,34 +597,39 @@ class Scheduler(SchedulerInterface): def _make_cached_request_data( self, - request: Request, - num_scheduled_tokens: int, - num_scheduled_spec_tokens: int, - new_block_ids: tuple[list[int], ...], - resumed_from_preemption: bool, + running_reqs: list[Request], + resumed_reqs: list[Request], + num_scheduled_tokens: dict[str, int], + spec_decode_tokens: dict[str, list[int]], + req_to_new_block_ids: dict[str, tuple[list[int], ...]], ) -> CachedRequestData: - # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating - # them at each scheduling step. - num_computed_tokens = request.num_computed_tokens - num_regular_tokens = num_scheduled_tokens - num_scheduled_spec_tokens - new_token_ids = request.all_token_ids[ - num_computed_tokens:num_computed_tokens + num_regular_tokens] + req_ids: list[str] = [] + new_token_ids: list[list[int]] = [] + new_block_ids: list[tuple[list[int], ...]] = [] + num_computed_tokens: list[int] = [] - req_data_queue = self._cached_reqs_data.get(request.request_id) - if req_data_queue: - req_data = req_data_queue.popleft() - req_data.resumed_from_preemption = resumed_from_preemption - req_data.new_token_ids = new_token_ids - req_data.new_block_ids = new_block_ids - req_data.num_computed_tokens = num_computed_tokens - else: - # No cached request data, or all cached request data has been - # used by the scheduled requests. - req_data = CachedRequestData.from_request(request, - resumed_from_preemption, - new_token_ids, - new_block_ids) - return req_data + for req in itertools.chain(running_reqs, resumed_reqs): + req_id = req.request_id + req_ids.append(req_id) + num_tokens = (num_scheduled_tokens[req_id] - + len(spec_decode_tokens.get(req_id, ()))) + token_ids = req.all_token_ids[req.num_computed_tokens:req. + num_computed_tokens + num_tokens] + new_token_ids.append(token_ids) + new_block_ids.append(req_to_new_block_ids[req_id]) + num_computed_tokens.append(req.num_computed_tokens) + # Because resumed_reqs is usually empty, it is more efficient to do + # in-place appending so that we don't need to allocate a new list. + resumed_from_preemption = [False] * len(running_reqs) + resumed_from_preemption += [True] * len(resumed_reqs) + + return CachedRequestData( + req_ids=req_ids, + resumed_from_preemption=resumed_from_preemption, + new_token_ids=new_token_ids, + new_block_ids=new_block_ids, + num_computed_tokens=num_computed_tokens, + ) def _try_schedule_encoder_inputs( self, @@ -870,19 +859,11 @@ class Scheduler(SchedulerInterface): if not stopped: new_running.append(request) + self.running = new_running # KV Connector: update state for finished KV Transfers. self._update_from_kv_xfer_finished(model_runner_output) - # Return the cached request data to the queue so they can be reused. - for req_data in scheduler_output.scheduled_cached_reqs: - # NOTE(rob): since we free stopped reqs above, adding stopped reqs - # to _cached_reqs_data will cause a memory leak. - if req_data.req_id not in self.finished_req_ids: - self._cached_reqs_data[req_data.req_id].append(req_data) - - self.running = new_running - # Create EngineCoreOutputs for all clients that have requests with # outputs in this step. engine_core_outputs = { @@ -965,13 +946,11 @@ class Scheduler(SchedulerInterface): self._free_request(request) def _free_request(self, request: Request) -> Optional[dict[str, Any]]: - assert request.is_finished() delay_free_blocks, kv_xfer_params = self._connector_finished(request) self.encoder_cache_manager.free(request) request_id = request.request_id - self._cached_reqs_data.pop(request_id, None) self.finished_req_ids.add(request_id) if self.finished_req_ids_dict is not None: self.finished_req_ids_dict[request.client_index].add(request_id) @@ -983,7 +962,6 @@ class Scheduler(SchedulerInterface): def _free_blocks(self, request: Request): assert request.is_finished() - assert request.request_id not in self._cached_reqs_data self.kv_cache_manager.free(request) self.kv_cache_manager.free_block_hashes(request) del self.requests[request.request_id] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e063e44dabfa1..29d39de212f88 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -470,34 +470,36 @@ class GPUModelRunner(LoRAModelRunnerMixin): req_ids_to_add.append(req_id) # Update the states of the running/resumed requests. - for req_data in scheduler_output.scheduled_cached_reqs: - req_id = req_data.req_id + req_data = scheduler_output.scheduled_cached_reqs + for i, req_id in enumerate(req_data.req_ids): req_state = self.requests[req_id] + num_computed_tokens = req_data.num_computed_tokens[i] + new_token_ids = req_data.new_token_ids[i] + new_block_ids = req_data.new_block_ids[i] + resumed_from_preemption = req_data.resumed_from_preemption[i] # Update the cached states. - num_computed_tokens = req_data.num_computed_tokens req_state.num_computed_tokens = num_computed_tokens # Add the sampled token(s) from the previous step (if any). # This doesn't include "unverified" tokens like spec decode tokens. - num_new_tokens = (num_computed_tokens + - len(req_data.new_token_ids) - + num_new_tokens = (num_computed_tokens + len(new_token_ids) - req_state.num_tokens) if num_new_tokens == 1: # Avoid slicing list in most common case. - req_state.output_token_ids.append(req_data.new_token_ids[-1]) + req_state.output_token_ids.append(new_token_ids[-1]) elif num_new_tokens > 0: req_state.output_token_ids.extend( - req_data.new_token_ids[-num_new_tokens:]) + new_token_ids[-num_new_tokens:]) # Update the block IDs. - if not req_data.resumed_from_preemption: + if not resumed_from_preemption: # Append the new blocks to the existing block IDs. - for block_ids, new_block_ids in zip(req_state.block_ids, - req_data.new_block_ids): - block_ids.extend(new_block_ids) + for block_ids, new_ids in zip(req_state.block_ids, + new_block_ids): + block_ids.extend(new_ids) else: # The request is resumed from preemption. # Replace the existing block IDs with the new ones. - req_state.block_ids = req_data.new_block_ids + req_state.block_ids = new_block_ids req_index = self.input_batch.req_id_to_index.get(req_id) if req_index is None: @@ -510,14 +512,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): # Update the persistent batch. self.input_batch.num_computed_tokens_cpu[req_index] = ( num_computed_tokens) - self.input_batch.block_table.append_row(req_data.new_block_ids, - req_index) + self.input_batch.block_table.append_row(new_block_ids, req_index) # Add new_token_ids to token_ids_cpu. start_token_index = num_computed_tokens - end_token_index = num_computed_tokens + len(req_data.new_token_ids) + end_token_index = num_computed_tokens + len(new_token_ids) self.input_batch.token_ids_cpu[ - req_index, - start_token_index:end_token_index] = req_data.new_token_ids + req_index, start_token_index:end_token_index] = new_token_ids self.input_batch.num_tokens_no_spec[req_index] = end_token_index # Add spec_token_ids to token_ids_cpu. spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get( diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index bc334419c4cec..0cc218bdb646f 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -418,21 +418,24 @@ class TPUModelRunner(LoRAModelRunnerMixin): req_ids_to_add.append(req_id) # Update the states of the running/resumed requests. - for req_data in scheduler_output.scheduled_cached_reqs: - req_id = req_data.req_id + req_data = scheduler_output.scheduled_cached_reqs + for i, req_id in enumerate(req_data.req_ids): req_state = self.requests[req_id] + num_computed_tokens = req_data.num_computed_tokens[i] + new_block_ids = req_data.new_block_ids[i] + resumed_from_preemption = req_data.resumed_from_preemption[i] # Update the cached states. - req_state.num_computed_tokens = req_data.num_computed_tokens - if not req_data.resumed_from_preemption: + req_state.num_computed_tokens = num_computed_tokens + if not resumed_from_preemption: # Append the new blocks to the existing block IDs. - for block_ids, new_block_ids in zip(req_state.block_ids, - req_data.new_block_ids): - block_ids.extend(new_block_ids) + for block_ids, new_ids in zip(req_state.block_ids, + new_block_ids): + block_ids.extend(new_ids) else: # The request is resumed from preemption. # Replace the existing block IDs with the new ones. - req_state.block_ids = req_data.new_block_ids + req_state.block_ids = new_block_ids req_index = self.input_batch.req_id_to_index.get(req_id) if req_index is None: @@ -444,9 +447,8 @@ class TPUModelRunner(LoRAModelRunnerMixin): # Update the persistent batch. self.input_batch.num_computed_tokens_cpu[req_index] = ( - req_data.num_computed_tokens) - self.input_batch.block_table.append_row(req_data.new_block_ids, - req_index) + num_computed_tokens) + self.input_batch.block_table.append_row(new_block_ids, req_index) # Add the new or resumed requests to the persistent batch. # The smaller empty indices are filled first. From 551ef1631a98d60fe9e82f0282e49c4a59a7887b Mon Sep 17 00:00:00 2001 From: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Date: Mon, 30 Jun 2025 12:26:42 -0400 Subject: [PATCH 071/175] [Unit Test] Add unit test for deep gemm (#20090) Signed-off-by: yewentao256 Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- tests/kernels/moe/test_deepgemm.py | 225 +++++++++++++++++++++++++++++ 1 file changed, 225 insertions(+) create mode 100644 tests/kernels/moe/test_deepgemm.py diff --git a/tests/kernels/moe/test_deepgemm.py b/tests/kernels/moe/test_deepgemm.py new file mode 100644 index 0000000000000..5d2690904cea2 --- /dev/null +++ b/tests/kernels/moe/test_deepgemm.py @@ -0,0 +1,225 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Unit-test DeepGEMM FP8 kernels (no DeepEP). +Compare DeepGEMM path against the Triton fallback inside vLLM's fused_experts. +""" + +import importlib +import math + +import pytest +import torch + +# vLLM fused-expert reference (Triton fallback + DeepGEMM option) +from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8) +from vllm.utils import cdiv + +has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None + +if has_deep_gemm: + import deep_gemm + BLOCK_M = deep_gemm.get_m_alignment_for_contiguous_layout() + BLOCK_SIZE = [BLOCK_M, BLOCK_M] + +requires_deep_gemm = pytest.mark.skipif( + not has_deep_gemm, + reason="Requires deep_gemm kernels", +) + + +def calc_diff(x: torch.Tensor, y: torch.Tensor): + x, y = x.double(), y.double() + denominator = (x * x + y * y).sum() + sim = 2 * (x * y).sum() / denominator + return 1 - sim + + +def per_block_cast_to_fp8( + x: torch.Tensor, + block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + x_padded = torch.zeros( + (cdiv(m, 128) * 128, cdiv(n, block_size_n) * block_size_n), + dtype=x.dtype, + device=x.device) + x_padded[:m, :n] = x + x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n) + x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) + x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) + x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous() + scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) + return x_scaled_sub, scales + + +def make_block_quant_fp8_weights( + e: int, + n: int, + k: int, + block_size: list[int], +): + """ + Generate (w1, w2) expert weights and their per-block scale tensors + in FP8 block-quantized format. + + w1 shape: (E, 2N, K) + w2 shape: (E, K, N) + """ + dtype = torch.bfloat16 + fp8_max, fp8_min = torch.finfo(torch.float8_e4m3fn).max, torch.finfo( + torch.float8_e4m3fn).min + + # bf16 reference weights + w1_bf16 = torch.randn(e, 2 * n, k, device="cuda", dtype=dtype) / 10 + w2_bf16 = torch.randn(e, k, n, device="cuda", dtype=dtype) / 10 + w1_bf16.clamp_(fp8_min, fp8_max) + w2_bf16.clamp_(fp8_min, fp8_max) + + block_n, block_k = block_size + n_tiles_w1 = math.ceil((2 * n) / block_n) + k_tiles_w1 = math.ceil(k / block_k) + n_tiles_w2 = math.ceil(k / block_n) + k_tiles_w2 = math.ceil(n / block_k) + + w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn) + w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn) + w1_s = torch.empty(e, + n_tiles_w1, + k_tiles_w1, + device="cuda", + dtype=torch.float32) + w2_s = torch.empty(e, + n_tiles_w2, + k_tiles_w2, + device="cuda", + dtype=torch.float32) + + for i in range(e): + w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i]) + w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i]) + + return w1, w2, w1_s, w2_s + + +def run_single_case(m, n, k, topk, num_experts, block_size): + """ + Run one (M,N,K) configuration on a single GPU and assert DeepGEMM == + Triton baseline within tolerance. + """ + tokens_bf16 = torch.randn( + m, k, device="cuda", dtype=torch.bfloat16).clamp_min_(-1).clamp_max_(1) + _, a1_scale = per_token_group_quant_fp8(tokens_bf16, block_size[1]) + + # expert weight tensors + w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(num_experts, n, k, + block_size) + + router_logits = torch.randn(m, + num_experts, + device="cuda", + dtype=torch.float32) + topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1) + topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1) + + # triton referrence + out_triton = fused_experts( + hidden_states=tokens_bf16, + w1=w1, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=False, + use_fp8_w8a8=True, + w1_scale=w1_s, + w2_scale=w2_s, + a1_scale=a1_scale, + block_shape=block_size, + allow_deep_gemm=False, + ) + + # DeepGemm + out_deepgemm = fused_experts( + hidden_states=tokens_bf16, + w1=w1, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=False, + use_fp8_w8a8=True, + w1_scale=w1_s, + w2_scale=w2_s, + a1_scale=a1_scale, + block_shape=block_size, + allow_deep_gemm=True, + ) + + base = out_triton.abs().mean() + atol = 0.1 * base.clamp(min=1e-2) # 10% of mean, but not lower than 1e-3 + rtol = 0.05 + # ----- Compare ----- + torch.testing.assert_close( + out_deepgemm.to(torch.float32), + out_triton.to(torch.float32), + rtol=rtol, + atol=float(atol), + ) + + +# Note: W1 has shape (E, 2N, K), so N = 512 +# can trigger the deepgemm path. +MNKs = [ + (1024, 512, 128), + (1024, 512, 512), + (2048, 512, 512), + (512, 1024, 1024), + (512, 2048, 2048), + (4096, 4096, 1024), +] + +TOPKS = [2, 6] +NUM_EXPERTS = [32] + + +@pytest.mark.parametrize("mnk", MNKs) +@pytest.mark.parametrize("topk", TOPKS) +@pytest.mark.parametrize("num_experts", NUM_EXPERTS) +@requires_deep_gemm +def test_deepgemm_vs_triton(mnk, topk, num_experts, monkeypatch): + + with monkeypatch.context() as m: + m.setenv("VLLM_USE_DEEP_GEMM", "1") + + _fused_moe_mod = importlib.import_module( + "vllm.model_executor.layers.fused_moe.fused_moe") + + call_counter = {"cnt": 0} + + orig_fn = _fused_moe_mod.deep_gemm_moe_fp8 + + def _spy_deep_gemm_moe_fp8(*args, **kwargs): + call_counter["cnt"] += 1 + return orig_fn(*args, **kwargs) + + monkeypatch.setattr(_fused_moe_mod, "deep_gemm_moe_fp8", + _spy_deep_gemm_moe_fp8) + + m, n, k = mnk + + if topk > num_experts: + pytest.skip(f"topk={topk} > num_experts={num_experts}") + + run_single_case( + m=m, + n=n, + k=k, + topk=topk, + num_experts=num_experts, + block_size=BLOCK_SIZE, + ) + + # ensure that the DeepGEMM path was indeed taken. + assert call_counter["cnt"] == 1, \ + f"DeepGEMM path was not executed during the test. " \ + f"Call counter: {call_counter['cnt']}" From d8cf819a9a337a578b7dfc7642617921cc468c17 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 30 Jun 2025 13:26:49 -0400 Subject: [PATCH 072/175] [Core] [Bugfix] [Multimodal] Fix multimodal profiling and generation for SFT/PTQed models (#20058) Signed-off-by: Kyle Sayers --- docs/contributing/model/multimodal.md | 7 +++ tests/multimodal/test_processing.py | 1 + vllm/entrypoints/llm.py | 8 +++ vllm/entrypoints/utils.py | 4 ++ vllm/inputs/preprocess.py | 29 +++++++--- vllm/model_executor/models/aya_vision.py | 2 + vllm/model_executor/models/blip2.py | 2 + vllm/model_executor/models/chameleon.py | 2 + vllm/model_executor/models/deepseek_vl2.py | 6 ++- vllm/model_executor/models/florence2.py | 4 +- vllm/model_executor/models/fuyu.py | 2 + vllm/model_executor/models/gemma3_mm.py | 2 + vllm/model_executor/models/glm4v.py | 1 + vllm/model_executor/models/granite_speech.py | 2 + vllm/model_executor/models/h2ovl.py | 3 ++ vllm/model_executor/models/idefics3.py | 2 + vllm/model_executor/models/internvl.py | 5 +- vllm/model_executor/models/llava.py | 5 +- vllm/model_executor/models/llava_onevision.py | 7 +++ vllm/model_executor/models/minicpmo.py | 10 ++-- vllm/model_executor/models/minicpmv.py | 18 +++++-- vllm/model_executor/models/minimax_vl_01.py | 2 + vllm/model_executor/models/mistral3.py | 2 + vllm/model_executor/models/mllama.py | 6 ++- vllm/model_executor/models/mllama4.py | 2 + vllm/model_executor/models/ovis.py | 2 + vllm/model_executor/models/paligemma.py | 5 +- vllm/model_executor/models/phi3v.py | 2 + vllm/model_executor/models/phi4mm.py | 3 +- vllm/model_executor/models/pixtral.py | 7 ++- .../models/prithvi_geospatial_mae.py | 1 + .../models/qwen2_5_omni_thinker.py | 7 +++ vllm/model_executor/models/qwen2_audio.py | 2 + vllm/model_executor/models/qwen2_vl.py | 4 +- vllm/model_executor/models/qwen_vl.py | 3 ++ vllm/model_executor/models/skyworkr1v.py | 2 + vllm/model_executor/models/ultravox.py | 6 +++ vllm/model_executor/models/whisper.py | 4 +- vllm/multimodal/processing.py | 54 +++++++++++++++---- vllm/multimodal/profiling.py | 7 ++- vllm/utils.py | 2 + 41 files changed, 207 insertions(+), 38 deletions(-) diff --git a/docs/contributing/model/multimodal.md b/docs/contributing/model/multimodal.md index 6ff2abbae6329..670d747b9ee7d 100644 --- a/docs/contributing/model/multimodal.md +++ b/docs/contributing/model/multimodal.md @@ -538,11 +538,13 @@ return a schema of the tensors outputted by the HF processor that are related to prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: processed_outputs = super()._call_hf_processor( prompt=prompt, mm_data=mm_data, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) image_patches = processed_outputs.get("image_patches") @@ -566,6 +568,11 @@ return a schema of the tensors outputted by the HF processor that are related to Our [actual code](gh-file:vllm/model_executor/models/fuyu.py) has special handling for text-only inputs to prevent unnecessary warnings from HF processor. + !!! note + The `_call_hf_processor` method specifies both `mm_kwargs` and `tok_kwargs` for + processing. `mm_kwargs` is used to both initialize and call the huggingface + processor, whereas `tok_kwargs` is only used to call the huggingface processor. + This lets us override [_get_mm_fields_config][vllm.multimodal.processing.BaseMultiModalProcessor._get_mm_fields_config] as follows: ```python diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py index 8b52911c6ccf3..2f97475f121a0 100644 --- a/tests/multimodal/test_processing.py +++ b/tests/multimodal/test_processing.py @@ -1086,6 +1086,7 @@ def test_hf_processor_kwargs(model_id, call_kwargs, expected_kwargs): prompt="", mm_data={}, mm_kwargs=call_kwargs, + tok_kwargs={}, ) assert out_kwargs == expected_kwargs diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 63967e4d2d4bc..f0404e0bc6eac 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -481,6 +481,13 @@ class LLM: # Use default sampling params. sampling_params = self.get_default_sampling_params() + tokenization_kwargs: dict[str, Any] = {} + truncate_prompt_tokens = None + if isinstance(sampling_params, SamplingParams): + truncate_prompt_tokens = sampling_params.truncate_prompt_tokens + _validate_truncation_size(self.llm_engine.model_config.max_model_len, + truncate_prompt_tokens, tokenization_kwargs) + self._validate_and_add_requests( prompts=parsed_prompts, params=sampling_params, @@ -488,6 +495,7 @@ class LLM: lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, guided_options=guided_options_request, + tokenization_kwargs=tokenization_kwargs, priority=priority, ) diff --git a/vllm/entrypoints/utils.py b/vllm/entrypoints/utils.py index 16ba2b4531acf..50f810afb8ccd 100644 --- a/vllm/entrypoints/utils.py +++ b/vllm/entrypoints/utils.py @@ -171,6 +171,10 @@ def _validate_truncation_size( tokenization_kwargs["truncation"] = True tokenization_kwargs["max_length"] = truncate_prompt_tokens + else: + if tokenization_kwargs is not None: + tokenization_kwargs["truncation"] = False + return truncate_prompt_tokens diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index a13e563f34a14..deda9bc23dafe 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -265,7 +265,8 @@ class InputPreprocessor: prompt: Union[str, list[int]], mm_data: MultiModalDataDict, mm_processor_kwargs: Optional[Mapping[str, object]], - lora_request: Optional[LoRARequest], + tokenization_kwargs: Optional[dict[str, Any]] = None, + lora_request: Optional[LoRARequest] = None, return_mm_hashes: bool = False, ) -> MultiModalInputs: """ @@ -280,15 +281,19 @@ class InputPreprocessor: if mm_processor_kwargs is None: mm_processor_kwargs = {} - return mm_processor.apply(prompt, mm_data, mm_processor_kwargs, - return_mm_hashes) + return mm_processor.apply(prompt, + mm_data, + hf_processor_mm_kwargs=mm_processor_kwargs, + tokenization_kwargs=tokenization_kwargs, + return_mm_hashes=return_mm_hashes) async def _process_multimodal_async( self, prompt: Union[str, list[int]], mm_data: MultiModalDataDict, mm_processor_kwargs: Optional[Mapping[str, object]], - lora_request: Optional[LoRARequest], + tokenization_kwargs: Optional[dict[str, Any]] = None, + lora_request: Optional[LoRARequest] = None, return_mm_hashes: bool = False, ) -> MultiModalInputs: """ @@ -302,8 +307,11 @@ class InputPreprocessor: if mm_processor_kwargs is None: mm_processor_kwargs = {} - return mm_processor.apply(prompt, mm_data, mm_processor_kwargs, - return_mm_hashes) + return mm_processor.apply(prompt, + mm_data, + hf_processor_mm_kwargs=mm_processor_kwargs, + tokenization_kwargs=tokenization_kwargs, + return_mm_hashes=return_mm_hashes) def _process_embeds( self, @@ -338,6 +346,7 @@ class InputPreprocessor: def _process_tokens( self, parsed_content: TokensPrompt, + tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, return_mm_hashes: bool = False, ) -> Union[TokenInputs, MultiModalInputs]: @@ -350,6 +359,7 @@ class InputPreprocessor: prompt_token_ids, multi_modal_data, parsed_content.get("mm_processor_kwargs"), + tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, return_mm_hashes=return_mm_hashes, ) @@ -367,6 +377,7 @@ class InputPreprocessor: async def _process_tokens_async( self, parsed_content: TokensPrompt, + tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, return_mm_hashes: bool = False, ) -> Union[TokenInputs, MultiModalInputs]: @@ -379,6 +390,7 @@ class InputPreprocessor: prompt_token_ids, multi_modal_data, parsed_content.get("mm_processor_kwargs"), + tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, return_mm_hashes=return_mm_hashes, ) @@ -408,6 +420,7 @@ class InputPreprocessor: prompt_text, multi_modal_data, parsed_content.get("mm_processor_kwargs"), + tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, return_mm_hashes=return_mm_hashes, ) @@ -442,6 +455,7 @@ class InputPreprocessor: prompt_text, multi_modal_data, parsed_content.get("mm_processor_kwargs"), + tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, return_mm_hashes=return_mm_hashes, ) @@ -860,7 +874,8 @@ class InputPreprocessor: "returned until they are supported on vLLM V1.") # Encoder-decoder model requires special mapping of # input prompts to encoder & decoder - return self._process_encoder_decoder_prompt(prompt) + return self._process_encoder_decoder_prompt( + prompt, tokenization_kwargs) if is_explicit_encoder_decoder_prompt(prompt): raise ValueError("Cannot pass encoder-decoder prompt " diff --git a/vllm/model_executor/models/aya_vision.py b/vllm/model_executor/models/aya_vision.py index a48631ad709f7..38daf995b8ca3 100644 --- a/vllm/model_executor/models/aya_vision.py +++ b/vllm/model_executor/models/aya_vision.py @@ -185,11 +185,13 @@ class AyaVisionMultiModalProcessor( prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: processed_outputs = super()._call_hf_processor( prompt, mm_data, mm_kwargs, + tok_kwargs, ) hf_processor = self.info.get_hf_processor(**mm_kwargs) image_processor = hf_processor.image_processor diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 3c3955161daaa..ecc12fa8d3727 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -454,6 +454,7 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]): prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: if not mm_data: # HF processor always adds placeholders even when there's no image @@ -465,6 +466,7 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]): prompt=prompt, mm_data=mm_data, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) def _get_mm_fields_config( diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index d538ba09c65cf..06e33ad7737e3 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -107,6 +107,7 @@ class ChameleonMultiModalProcessor( prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: if not mm_data: prompt_ids = self.info.get_tokenizer().encode(prompt) @@ -117,6 +118,7 @@ class ChameleonMultiModalProcessor( prompt=prompt, mm_data=mm_data, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) def _apply_hf_processor_tokens_only( diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py index da5452409d2f9..cdda9fb5a7490 100644 --- a/vllm/model_executor/models/deepseek_vl2.py +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -204,12 +204,13 @@ class DeepseekVL2MultiModalProcessor( prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: if mm_data: processed_outputs = self.info.ctx.call_hf_processor( self.info.get_hf_processor(**mm_kwargs), dict(prompt=prompt, **mm_data), - mm_kwargs, + dict(**mm_kwargs, **tok_kwargs), ) pixel_values = processed_outputs["pixel_values"] # split pixel values into patches corresponding to each image @@ -278,6 +279,7 @@ class DeepseekVL2MultiModalProcessor( prompt: Union[str, list[int]], mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], *, return_mm_hashes: bool, ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]: @@ -290,6 +292,7 @@ class DeepseekVL2MultiModalProcessor( prompt=prompt, mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, return_mm_hashes=return_mm_hashes, ) @@ -297,6 +300,7 @@ class DeepseekVL2MultiModalProcessor( prompt=prompt, mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, return_mm_hashes=return_mm_hashes, ) diff --git a/vllm/model_executor/models/florence2.py b/vllm/model_executor/models/florence2.py index 425407c19ab5d..bda552721eb23 100644 --- a/vllm/model_executor/models/florence2.py +++ b/vllm/model_executor/models/florence2.py @@ -794,6 +794,7 @@ class Florence2MultiModalProcessor( prompt_text: str, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], ) -> bool: return False @@ -828,10 +829,11 @@ class Florence2MultiModalProcessor( prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: if mm_data: processed_outputs = super()._call_hf_processor( - prompt, mm_data, mm_kwargs) + prompt, mm_data, mm_kwargs, tok_kwargs) else: hf_processor = self.info.get_hf_processor() tokenizer = hf_processor.tokenizer diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 7e03982e78e69..b3e055b966b08 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -153,6 +153,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]): prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: if not mm_data: # Avoid warning from HF logger for text-only input @@ -164,6 +165,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]): prompt=prompt, mm_data=mm_data, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) image_patches = processed_outputs.get("image_patches") diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index 3a1c14978b45b..e9c27674b8457 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -259,11 +259,13 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: processed_outputs = super()._call_hf_processor( prompt, mm_data, mm_kwargs, + tok_kwargs, ) # HF processor pops the `num_crops` kwarg, which is needed by vLLM diff --git a/vllm/model_executor/models/glm4v.py b/vllm/model_executor/models/glm4v.py index 70916c45c0e09..95e3fcfc02fab 100644 --- a/vllm/model_executor/models/glm4v.py +++ b/vllm/model_executor/models/glm4v.py @@ -481,6 +481,7 @@ class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]): prompt_text: str, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], ) -> bool: return False diff --git a/vllm/model_executor/models/granite_speech.py b/vllm/model_executor/models/granite_speech.py index f2dc5708028ba..77fbc4808b4a3 100644 --- a/vllm/model_executor/models/granite_speech.py +++ b/vllm/model_executor/models/granite_speech.py @@ -141,6 +141,7 @@ class GraniteSpeechMultiModalProcessor( prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: mm_data = dict(mm_data) audios = mm_data.pop("audios", []) @@ -153,6 +154,7 @@ class GraniteSpeechMultiModalProcessor( prompt=prompt, mm_data=mm_data, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) if "audio" in mm_data: diff --git a/vllm/model_executor/models/h2ovl.py b/vllm/model_executor/models/h2ovl.py index 8f7f359b75521..467b074f37753 100644 --- a/vllm/model_executor/models/h2ovl.py +++ b/vllm/model_executor/models/h2ovl.py @@ -490,6 +490,7 @@ class H2OVLMultiModalProcessor( prompt: Union[str, list[int]], mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], *, return_mm_hashes: bool, ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]: @@ -502,6 +503,7 @@ class H2OVLMultiModalProcessor( prompt=prompt, mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, return_mm_hashes=return_mm_hashes, ) @@ -509,6 +511,7 @@ class H2OVLMultiModalProcessor( prompt=prompt, mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, return_mm_hashes=return_mm_hashes, ) diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index b1d0626217a0a..36cfb5807d7d6 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -326,6 +326,7 @@ class Idefics3MultiModalProcessor( prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: # Text-only input not supported in composite processor if not (images := mm_data.get("images", [])): @@ -337,6 +338,7 @@ class Idefics3MultiModalProcessor( prompt, mm_data, mm_kwargs, + tok_kwargs, ) parsed_images = (self._get_data_parser().parse_mm_data({ diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index bb71177ecad8e..6abe6cd6965c8 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -758,11 +758,13 @@ class BaseInternVLMultiModalProcessor(BaseMultiModalProcessor[_I]): prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> Mapping[str, NestedTensors]: processed_outputs = super()._call_hf_processor( prompt=prompt, mm_data=mm_data, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) hf_processor = self.info.get_hf_processor(**mm_kwargs) @@ -941,9 +943,10 @@ class InternVLMultiModalProcessor( prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> Mapping[str, NestedTensors]: processed_outputs = super()._call_hf_processor(prompt, mm_data, - mm_kwargs) + mm_kwargs, tok_kwargs) hf_processor = self.info.get_hf_processor(**mm_kwargs) if self.info.supports_video and ( diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 1c35bf5206db7..7a7aefb267181 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -296,11 +296,13 @@ class PixtralHFMultiModalProcessor( prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: processed_outputs = super()._call_hf_processor( prompt=prompt, mm_data=mm_data, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) pixel_values = processed_outputs.get("pixel_values") @@ -797,6 +799,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): prompt: Union[str, list[int]], mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Optional[Mapping[str, object]] = None, return_mm_hashes: bool = False, ) -> MultiModalInputs: hf_config = self.info.get_hf_config() @@ -809,7 +812,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): ) result = super().apply(prompt, mm_data, hf_processor_mm_kwargs, - return_mm_hashes) + tokenization_kwargs, return_mm_hashes) mm_items = self._to_mm_items(mm_data) mm_item_counts = mm_items.get_all_counts() diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index c5403762f5390..7ff1026bfc94d 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -286,6 +286,7 @@ class LlavaOnevisionMultiModalProcessor( prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: mm_data = dict(mm_data) videos = mm_data.pop("videos", []) @@ -296,6 +297,7 @@ class LlavaOnevisionMultiModalProcessor( prompt=prompt, mm_data=mm_data, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) # LLaVA-OneVision processor doesn't support multiple videos @@ -310,6 +312,7 @@ class LlavaOnevisionMultiModalProcessor( prompt=prompt, mm_data={}, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) images = mm_data.pop("images", []) @@ -319,6 +322,7 @@ class LlavaOnevisionMultiModalProcessor( prompt=image_token * len(images), mm_data={"images": images}, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) image_outputs = { k: v @@ -334,6 +338,7 @@ class LlavaOnevisionMultiModalProcessor( prompt=video_token, mm_data={"videos": video}, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) pixel_values_videos.append(item_outputs["pixel_values_videos"][0]) @@ -352,11 +357,13 @@ class LlavaOnevisionMultiModalProcessor( prompt_text: str, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], ) -> bool: base_result = super()._hf_processor_applies_updates( prompt_text=prompt_text, mm_items=mm_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, ) return base_result and mm_items.get_count("video", strict=False) == 0 diff --git a/vllm/model_executor/models/minicpmo.py b/vllm/model_executor/models/minicpmo.py index ff5959ed196ea..112e0b91d3f17 100644 --- a/vllm/model_executor/models/minicpmo.py +++ b/vllm/model_executor/models/minicpmo.py @@ -260,6 +260,7 @@ class MiniCPMOMultiModalProcessor( self, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> Mapping[str, NestedTensors]: if (audios := mm_data.get("audios")) is None: return {} @@ -276,9 +277,9 @@ class MiniCPMOMultiModalProcessor( prompts=[self.info.audio_pattern] * len(parsed_audios), mm_data={"audios": [[audio] for audio in parsed_audios]}, mm_kwargs={ - **mm_kwargs, - "chunk_input": True, + **mm_kwargs, "chunk_input": True }, + tok_kwargs=tok_kwargs, out_keys={"audio_features", "audio_feature_lens"}, ) @@ -302,10 +303,11 @@ class MiniCPMOMultiModalProcessor( self, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> Mapping[str, NestedTensors]: return { - **super().process_mm_inputs(mm_data, mm_kwargs), - **self.process_audios(mm_data, mm_kwargs), + **super().process_mm_inputs(mm_data, mm_kwargs, tok_kwargs), + **self.process_audios(mm_data, mm_kwargs, tok_kwargs), } def _get_prompt_updates( diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 9dc03c8001824..1dba88be83500 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -534,6 +534,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): self, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> Mapping[str, NestedTensors]: if (images := mm_data.get("images")) is None: return {} @@ -550,6 +551,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): prompts=[self.info.image_pattern] * len(parsed_images), mm_data={"images": [[image] for image in parsed_images]}, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, out_keys={"pixel_values", "image_sizes", "tgt_sizes"}, ) @@ -563,6 +565,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): self, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> Mapping[str, NestedTensors]: if (videos := mm_data.get("videos")) is None: return {} @@ -586,6 +589,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): "max_slice_nums": self.info.get_video_max_slice_num(), }, + tok_kwargs=tok_kwargs, out_keys={"pixel_values", "image_sizes", "tgt_sizes"}, ) @@ -601,10 +605,11 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): self, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> Mapping[str, NestedTensors]: return { - **self.process_images(mm_data, mm_kwargs), - **self.process_videos(mm_data, mm_kwargs), + **self.process_images(mm_data, mm_kwargs, tok_kwargs), + **self.process_videos(mm_data, mm_kwargs, tok_kwargs), } def _base_call_hf_processor( @@ -612,6 +617,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): prompts: list[str], mm_data: Mapping[str, Sequence[object]], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], *, out_keys: set[str], ) -> dict[str, NestedTensors]: @@ -621,6 +627,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): prompt=prompts, # type: ignore mm_data=mm_data, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) else: inputs = defaultdict[str, list[torch.Tensor]](list) @@ -633,6 +640,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): for k, v in mm_data.items() }, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) for k, v in inputs_one.items(): @@ -646,11 +654,12 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: tokenizer = self.info.get_tokenizer() - input_ids = torch.tensor([tokenizer.encode(prompt)]) - mm_inputs = self.process_mm_inputs(mm_data, mm_kwargs) + input_ids = torch.tensor([tokenizer.encode(prompt, **tok_kwargs)]) + mm_inputs = self.process_mm_inputs(mm_data, mm_kwargs, tok_kwargs) return BatchFeature({ "input_ids": input_ids, @@ -662,6 +671,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): prompt_text: str, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], ) -> bool: return False diff --git a/vllm/model_executor/models/minimax_vl_01.py b/vllm/model_executor/models/minimax_vl_01.py index 8ce94540e87fe..a125454c0c060 100644 --- a/vllm/model_executor/models/minimax_vl_01.py +++ b/vllm/model_executor/models/minimax_vl_01.py @@ -113,11 +113,13 @@ class MiniMaxVL01MultiModalProcessor( prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: processed_outputs = super()._call_hf_processor( prompt=prompt, mm_data=mm_data, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) pixel_values = processed_outputs.get("pixel_values") diff --git a/vllm/model_executor/models/mistral3.py b/vllm/model_executor/models/mistral3.py index 04d6d347cb84f..6840c672a3299 100644 --- a/vllm/model_executor/models/mistral3.py +++ b/vllm/model_executor/models/mistral3.py @@ -228,11 +228,13 @@ class Mistral3MultiModalProcessor( prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: processed_outputs = super()._call_hf_processor( prompt=prompt, mm_data=mm_data, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) pixel_values = processed_outputs.get("pixel_values") diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 1b7e93fafad93..ead5a8e950f0f 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -166,10 +166,11 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo] prompt: Union[str, list[int]], mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Optional[Mapping[str, object]] = None, return_mm_hashes: bool = False, ) -> MultiModalEncDecInputs: mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs, - return_mm_hashes) + tokenization_kwargs, return_mm_hashes) image_token_id = self.info.get_hf_config().image_token_index # Check that the number of image tokens in the decoder prompt matches @@ -239,6 +240,7 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo] prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: tokenizer = self.info.get_tokenizer() if mm_data: @@ -247,7 +249,7 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo] for img in mm_data["images"] ] processed_outputs = super()._call_hf_processor( - prompt, mm_data, mm_kwargs) + prompt, mm_data, mm_kwargs, tok_kwargs) processed_outputs["num_tiles"] = torch.tensor(num_tiles) for k in ('pixel_values', 'aspect_ratio_ids', "aspect_ratio_mask"): processed_outputs[k] = processed_outputs[k].squeeze(0) diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index a420e757e2194..ea781e18db272 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -574,6 +574,7 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo] prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: tokenizer = self.info.get_tokenizer() @@ -583,6 +584,7 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo] prompt=prompt, mm_data=mm_data, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) processor = self.info.get_hf_processor(**mm_kwargs) diff --git a/vllm/model_executor/models/ovis.py b/vllm/model_executor/models/ovis.py index 6eecd4499fb96..5059b4e69f076 100644 --- a/vllm/model_executor/models/ovis.py +++ b/vllm/model_executor/models/ovis.py @@ -335,6 +335,7 @@ class OvisMultiModalProcessor(BaseMultiModalProcessor[OvisProcessingInfo]): prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: if not mm_data: # Avoid warning from HF logger for text-only input @@ -346,6 +347,7 @@ class OvisMultiModalProcessor(BaseMultiModalProcessor[OvisProcessingInfo]): prompt=prompt, mm_data=mm_data, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) hf_processor = self.info.get_hf_processor() diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index e1de8cf458780..29ffb62eeafd0 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -121,6 +121,7 @@ class PaliGemmaMultiModalProcessor( prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: tokenizer = self.info.get_tokenizer() if not mm_data: @@ -131,6 +132,7 @@ class PaliGemmaMultiModalProcessor( prompt=prompt, mm_data=mm_data, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) def _get_mm_fields_config( @@ -191,10 +193,11 @@ class PaliGemmaMultiModalProcessor( prompt: Union[str, list[int]], mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Optional[Mapping[str, object]] = None, return_mm_hashes: bool = False, ) -> MultiModalInputs: mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs, - return_mm_hashes) + tokenization_kwargs, return_mm_hashes) prompt_token_ids = mm_inputs["prompt_token_ids"] tokenizer = self.info.get_tokenizer() diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 0a7adf91e488f..a084e71f734c2 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -376,11 +376,13 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]): prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: processed_outputs = super()._call_hf_processor( prompt=prompt, mm_data=mm_data, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) input_ids = processed_outputs["input_ids"] diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py index 5d1f0775b07fb..3c4162507f03d 100644 --- a/vllm/model_executor/models/phi4mm.py +++ b/vllm/model_executor/models/phi4mm.py @@ -762,6 +762,7 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: if not mm_data: prompt_ids = self.info.get_tokenizer().encode(prompt) @@ -773,7 +774,7 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): mm_data['audios'] = [(data, sr) for data in audio_data] processed_outputs = super()._call_hf_processor(prompt, mm_data, - mm_kwargs) + mm_kwargs, tok_kwargs) num_img_tokens = [ self.info.get_num_image_tokens(image_width=img_size[0], diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 709ac1d9df945..a31c757f7d592 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -237,6 +237,7 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]): dummy_text = self.get_dummy_text(mm_counts) dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts) dummy_images = dummy_mm_data.get("image", []) + tokenization_kwargs = {"truncation": False} request = ChatCompletionRequest(messages=[ UserMessage(content=[ @@ -247,7 +248,9 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]): res = tokenizer.mistral.encode_chat_completion(request) dummy_tokens = res.tokens - return ProcessorInputs(prompt=dummy_tokens, mm_data=dummy_mm_data) + return ProcessorInputs(prompt=dummy_tokens, + mm_data=dummy_mm_data, + tokenization_kwargs=tokenization_kwargs) class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo] @@ -297,6 +300,7 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo] prompt: Union[str, list[int]], mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], *, return_mm_hashes: bool, ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]: @@ -309,6 +313,7 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo] prompt=prompt, mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, return_mm_hashes=return_mm_hashes, ) diff --git a/vllm/model_executor/models/prithvi_geospatial_mae.py b/vllm/model_executor/models/prithvi_geospatial_mae.py index 4fdcae5de644a..f89cf1b5274cf 100644 --- a/vllm/model_executor/models/prithvi_geospatial_mae.py +++ b/vllm/model_executor/models/prithvi_geospatial_mae.py @@ -92,6 +92,7 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor): prompt: Union[str, list[int]], mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Optional[Mapping[str, object]] = None, return_mm_hashes: bool = False, ) -> MultiModalInputs: mm_kwargs = {} diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py index 9497f15984b75..8980f386502fc 100644 --- a/vllm/model_executor/models/qwen2_5_omni_thinker.py +++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py @@ -244,6 +244,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor( prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: mm_data = dict(mm_data) audios = mm_data.pop("audios", []) @@ -258,6 +259,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor( prompt=prompt, mm_data=mm_data, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) input_features = hf_inputs.pop('input_features', None) @@ -453,6 +455,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor( prompt: Union[str, list[int]], mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], *, enable_hf_prompt_update: bool, ) -> tuple[list[int], MultiModalKwargs, bool]: @@ -465,6 +468,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor( prompt_text=prompt, mm_items=mm_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, ) tokenizer = self.info.get_tokenizer() prompt_ids = encode_tokens(tokenizer, prompt) @@ -474,6 +478,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor( mm_kwargs = self._apply_hf_processor_mm_only( mm_items=mm_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, ) return prompt_ids, mm_kwargs, False @@ -482,6 +487,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], ) -> MultiModalKwargs: """ Qwen2.5-Omni reimplements this function to handle `use_audio_in_video`. @@ -498,6 +504,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor( prompt_text=self.dummy_inputs.get_dummy_text(mm_counts), mm_items=mm_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, ) return mm_kwargs diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index aefa1db24628d..31b25ef0bc731 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -150,6 +150,7 @@ class Qwen2AudioMultiModalProcessor( prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, Any], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: # NOTE - we rename audios -> audio in mm data because transformers has # deprecated audios for the qwen2audio processor and will remove @@ -174,6 +175,7 @@ class Qwen2AudioMultiModalProcessor( prompt=prompt, mm_data=mm_data, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) def _get_mm_fields_config( diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 899fc57c7a0e5..dc7b08c65bb13 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -1027,11 +1027,13 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo] prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: + mm_kwargs = self.info._get_image_processor_kwargs(**mm_kwargs) return self.info.ctx.call_hf_processor( self.info.get_hf_processor(**mm_kwargs), dict(text=prompt, **mm_data), - self.info._get_image_processor_kwargs(**mm_kwargs), + dict(**mm_kwargs, **tok_kwargs), ) def _get_prompt_updates( diff --git a/vllm/model_executor/models/qwen_vl.py b/vllm/model_executor/models/qwen_vl.py index fc29785af95a0..563650a4f162c 100644 --- a/vllm/model_executor/models/qwen_vl.py +++ b/vllm/model_executor/models/qwen_vl.py @@ -580,6 +580,7 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]): prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: # Drops anything between / tags; encoding with the tokenizer # will automatically add the image pads for the context. @@ -600,6 +601,7 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]): prompt=prompt, mm_data=mm_data, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) def _hf_processor_applies_updates( @@ -607,6 +609,7 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]): prompt_text: str, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], ) -> bool: return False diff --git a/vllm/model_executor/models/skyworkr1v.py b/vllm/model_executor/models/skyworkr1v.py index 28f181dde2154..d362838dbb398 100644 --- a/vllm/model_executor/models/skyworkr1v.py +++ b/vllm/model_executor/models/skyworkr1v.py @@ -534,11 +534,13 @@ class SkyworkR1VMultiModalProcessor(BaseMultiModalProcessor[_I]): prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> Mapping[str, NestedTensors]: processed_outputs = super()._call_hf_processor( prompt=prompt, mm_data=mm_data, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) hf_processor = self.info.get_hf_processor(**mm_kwargs) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 94f5e03fd446e..5cccd6b8841b4 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -144,6 +144,7 @@ class UltravoxMultiModalProcessor( prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: # Text-only input not supported in composite processor if not mm_data.get("audios", []): @@ -165,10 +166,15 @@ class UltravoxMultiModalProcessor( item_processor_data = dict(**mm_data, audios=audios) + # some tokenizer kwargs are incompatible with UltravoxProcessor + tok_kwargs.pop("padding", None) + tok_kwargs.pop("truncation", None) + output = super()._call_hf_processor( prompt=prompt, mm_data=item_processor_data, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) output['audio_features'] = output.pop('audio_values') diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 5a0094fa749fd..568b81c4bbfa8 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -700,9 +700,10 @@ class WhisperMultiModalProcessor( prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: if mm_data: - feature_extractor = self.info.get_feature_extractor(**mm_kwargs) + feature_extractor = self.info.get_feature_extractor() mm_data = dict(audio=mm_data.pop("audios")) mm_kwargs = dict( **mm_kwargs, @@ -712,6 +713,7 @@ class WhisperMultiModalProcessor( prompt=prompt, mm_data=mm_data, mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) if "labels" in processed_outputs: processed_outputs["input_ids"] = processed_outputs.pop("labels") diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 38f3a7cb932f4..aa7889fc3cc59 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -1267,6 +1267,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): # This refers to the data to be passed to HF processor. mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> "BatchFeature": """ Call the HF processor on the prompt text and @@ -1275,7 +1276,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): return self.info.ctx.call_hf_processor( self.info.get_hf_processor(**mm_kwargs), dict(text=prompt, **mm_data), - mm_kwargs, + dict(**mm_kwargs, **tok_kwargs), ) def _hf_processor_applies_updates( @@ -1283,6 +1284,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): prompt_text: str, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], ) -> bool: """ Return whether the HF processor applies prompt updates. @@ -1300,6 +1302,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): prompt_text: str, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], ) -> tuple[list[int], MultiModalKwargs, bool]: """ Apply the HF processor on the prompt text and multi-modal data @@ -1313,6 +1316,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): prompt=prompt_text, mm_data=processor_data, mm_kwargs=hf_processor_mm_kwargs, + tok_kwargs=tokenization_kwargs, ) processed_data.update(passthrough_data) @@ -1327,11 +1331,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): prompt_text=prompt_text, mm_items=mm_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, ) return prompt_ids, mm_kwargs, is_update_applied - def _apply_hf_processor_text_only(self, prompt_text: str) -> list[int]: + def _apply_hf_processor_text_only( + self, prompt_text: str, + tokenization_kwargs: Mapping[str, object]) -> list[int]: """ Apply the HF processor on the prompt text only. @@ -1343,6 +1350,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): prompt_text=prompt_text, mm_items=MultiModalDataItems({}), hf_processor_mm_kwargs={}, + tokenization_kwargs=tokenization_kwargs, ) return prompt_ids @@ -1368,6 +1376,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], ) -> MultiModalKwargs: """ Apply the HF processor on the multi-modal data only. @@ -1383,6 +1392,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): prompt_text=self.dummy_inputs.get_dummy_text(mm_counts), mm_items=mm_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, ) return mm_kwargs @@ -1392,6 +1402,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): prompt: Union[str, list[int]], mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], *, enable_hf_prompt_update: bool, ) -> tuple[list[int], MultiModalKwargs, bool]: @@ -1412,15 +1423,18 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): prompt_text=prompt, mm_items=mm_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, ) - prompt_ids = self._apply_hf_processor_text_only(prompt) + prompt_ids = self._apply_hf_processor_text_only( + prompt, tokenization_kwargs) else: prompt_ids = self._apply_hf_processor_tokens_only(prompt) mm_kwargs = self._apply_hf_processor_mm_only( mm_items=mm_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, ) return prompt_ids, mm_kwargs, False @@ -1430,14 +1444,17 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): cache: ProcessingCache, mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], ) -> tuple[dict[str, list[ProcessingCacheOptionalItem]], dict[ str, list[object]]]: model_id = self.info.model_id mm_cache_items = { modality: [ - cache.get_item(model_id, modality, item, - hf_processor_mm_kwargs) for item in items + cache.get_item( + model_id, modality, item, + dict(**hf_processor_mm_kwargs, **tokenization_kwargs)) + for item in items ] for modality, items in mm_data_items.items() } @@ -1457,10 +1474,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): return mm_cache_items, mm_missing_data def _hash_mm_items( - self, - mm_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], - ) -> MultiModalHashes: + self, mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object]) -> MultiModalHashes: """Create MM hashes to be returned (only used in V1).""" model_id = self.info.model_id @@ -1468,7 +1484,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): modality: [ MultiModalHasher.hash_kwargs(model_id=model_id, **{modality: item}, - **hf_processor_mm_kwargs) + **hf_processor_mm_kwargs, + **tokenization_kwargs) for item in items ] for modality, items in mm_items.items() @@ -1513,6 +1530,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): prompt: Union[str, list[int]], mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], *, return_mm_hashes: bool, ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]: @@ -1524,10 +1542,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): prompt=prompt, mm_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, enable_hf_prompt_update=True, ) - mm_hashes = (self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs) + mm_hashes = (self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs, + tokenization_kwargs) if return_mm_hashes else None) return prompt_ids, mm_kwargs, mm_hashes, is_update_applied @@ -1537,6 +1557,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): prompt: Union[str, list[int]], mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], *, return_mm_hashes: bool, ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]: @@ -1552,6 +1573,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): prompt=prompt, mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, return_mm_hashes=return_mm_hashes, ) @@ -1562,6 +1584,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): cache=cache, mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, ) # NOTE: `prompt` does not correspond to `mm_missing_data_items`, @@ -1575,6 +1598,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): prompt=prompt, mm_items=self._to_mm_items(mm_missing_data), hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, enable_hf_prompt_update=False, ) @@ -1783,6 +1807,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): prompt: Union[str, list[int]], mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Optional[Mapping[str, object]] = None, return_mm_hashes: bool = False, ) -> MultiModalInputs: """ @@ -1800,6 +1825,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): """ mm_items = self._to_mm_items(mm_data) + if tokenization_kwargs is None: + tokenization_kwargs = {} + ( prompt_ids, mm_kwargs, @@ -1809,9 +1837,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): prompt, mm_items, hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, return_mm_hashes=return_mm_hashes, ) + # NOTE: tokenization_kwargs are not required to init processor prompt_ids, prompt, mm_placeholders = self._maybe_apply_prompt_updates( mm_items=mm_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, @@ -1892,6 +1922,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): prompt: Union[str, list[int]], mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Optional[Mapping[str, object]] = None, return_mm_hashes: bool = False, ) -> MultiModalEncDecInputs: """ @@ -1906,6 +1937,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): encoder_prompt, mm_data, hf_processor_mm_kwargs, + tokenization_kwargs, return_mm_hashes, ) diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py index 67bcb31f23f70..fb5a7b64c4199 100644 --- a/vllm/multimodal/profiling.py +++ b/vllm/multimodal/profiling.py @@ -30,6 +30,7 @@ class ProcessorInputs: prompt: Union[str, list[int]] mm_data: MultiModalDataDict hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict) + tokenization_kwargs: Mapping[str, object] = field(default_factory=dict) class DummyEncoderData(NamedTuple): @@ -90,8 +91,11 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]): """ dummy_text = self.get_dummy_text(mm_counts) dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts) + tokenization_kwargs = {"truncation": False} - return ProcessorInputs(prompt=dummy_text, mm_data=dummy_mm_data) + return ProcessorInputs(prompt=dummy_text, + mm_data=dummy_mm_data, + tokenization_kwargs=tokenization_kwargs) def _get_dummy_audios( self, @@ -170,6 +174,7 @@ class MultiModalProfiler(Generic[_I]): prompt=processor_inputs.prompt, mm_data=processor_inputs.mm_data, hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs, + tokenization_kwargs=processor_inputs.tokenization_kwargs, ) def _get_mm_num_tokens( diff --git a/vllm/utils.py b/vllm/utils.py index 7eb3c1e347cde..689102281c54f 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1729,6 +1729,7 @@ def supports_kw( last_param = params[next(reversed(params))] # type: ignore return (last_param.kind == inspect.Parameter.VAR_KEYWORD and last_param.name != kw_name) + return False @@ -1771,6 +1772,7 @@ def resolve_mm_processor_kwargs( # Merge the final processor kwargs, prioritizing inference # time values over the initialization time values. mm_processor_kwargs = {**init_mm_kwargs, **runtime_mm_kwargs} + return mm_processor_kwargs From 97d9524fe90ad5799cc11db4b4216fe3a30a07d6 Mon Sep 17 00:00:00 2001 From: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Date: Mon, 30 Jun 2025 14:15:24 -0400 Subject: [PATCH 073/175] [Refactor] Remove useless pdb comment (#20266) Signed-off-by: yewentao256 --- vllm/model_executor/layers/fused_moe/deep_gemm_moe.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 321fb0351ad93..818f6d345ba6d 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -141,7 +141,6 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, N // 2)) mm2_out = _resize_cache(workspace2, (M_sum, K)) - # import pdb; pdb.set_trace() dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( (a1q, a1q_scale), (w1, w1_scale), mm1_out, expert_ids) From ded1fb635b7c1504a83fc7c195a5bf47d31c1bef Mon Sep 17 00:00:00 2001 From: Zhonghua Deng Date: Tue, 1 Jul 2025 07:45:14 +0800 Subject: [PATCH 074/175] [Bugfix][V1][P/D]Fix the issue of occasional garbled output for P2pNcclConnector (#20263) Signed-off-by: Abatom --- .../kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py index 81f7a2525896e..35c26897fe3f4 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py @@ -310,10 +310,11 @@ class P2pNcclEngine: elif data["cmd"] == "PUT": tensor_id = data["tensor_id"] try: - tensor = torch.empty(data["shape"], - dtype=getattr( - torch, data["dtype"]), - device=self.device) + with torch.cuda.stream(self.recv_stream): + tensor = torch.empty(data["shape"], + dtype=getattr( + torch, data["dtype"]), + device=self.device) self.router_socket.send_multipart( [remote_address, b"0"]) comm, rank = self.comms[remote_address.decode()] From 6d42ce83155d42f04643c1fa54eaed8abf8170c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Mon, 30 Jun 2025 21:03:13 -0400 Subject: [PATCH 075/175] [CLI] Improve CLI arg parsing for `-O`/`--compilation-config` (#20156) Signed-off-by: luka --- tests/engine/test_arg_utils.py | 28 +++++++++------ tests/test_utils.py | 47 ++++++++++++++++++++++++ vllm/config.py | 19 +++++----- vllm/engine/arg_utils.py | 5 ++- vllm/utils.py | 65 ++++++++++++++++++++++++---------- 5 files changed, 124 insertions(+), 40 deletions(-) diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index cfbc7c245ffd4..847f150bd6443 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -239,32 +239,40 @@ def test_compilation_config(): assert args.compilation_config == CompilationConfig() # set to O3 - args = parser.parse_args(["-O3"]) - assert args.compilation_config.level == 3 + args = parser.parse_args(["-O0"]) + assert args.compilation_config.level == 0 # set to O 3 (space) - args = parser.parse_args(["-O", "3"]) - assert args.compilation_config.level == 3 + args = parser.parse_args(["-O", "1"]) + assert args.compilation_config.level == 1 # set to O 3 (equals) - args = parser.parse_args(["-O=3"]) + args = parser.parse_args(["-O=2"]) + assert args.compilation_config.level == 2 + + # set to O.level 3 + args = parser.parse_args(["-O.level", "3"]) assert args.compilation_config.level == 3 # set to string form of a dict args = parser.parse_args([ - "--compilation-config", - '{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}', + "-O", + '{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], ' + '"use_inductor": false}', ]) assert (args.compilation_config.level == 3 and - args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]) + args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8] + and not args.compilation_config.use_inductor) # set to string form of a dict args = parser.parse_args([ "--compilation-config=" - '{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}', + '{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], ' + '"use_inductor": true}', ]) assert (args.compilation_config.level == 3 and - args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]) + args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8] + and args.compilation_config.use_inductor) def test_prefix_cache_default(): diff --git a/tests/test_utils.py b/tests/test_utils.py index 913188455d8e6..36db8202ba622 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -5,6 +5,7 @@ import asyncio import hashlib import json +import logging import pickle import socket from collections.abc import AsyncIterator @@ -142,6 +143,7 @@ def parser(): parser.add_argument('--batch-size', type=int) parser.add_argument('--enable-feature', action='store_true') parser.add_argument('--hf-overrides', type=json.loads) + parser.add_argument('-O', '--compilation-config', type=json.loads) return parser @@ -265,6 +267,11 @@ def test_dict_args(parser): "val2", "--hf-overrides.key2.key4", "val3", + # Test compile config and compilation level + "-O.use_inductor=true", + "-O.backend", + "custom", + "-O1", # Test = sign "--hf-overrides.key5=val4", # Test underscore to dash conversion @@ -281,6 +288,13 @@ def test_dict_args(parser): "true", "--hf_overrides.key12.key13", "null", + # Test '-' and '.' in value + "--hf_overrides.key14.key15", + "-minus.and.dot", + # Test array values + "-O.custom_ops+", + "-quant_fp8", + "-O.custom_ops+=+silu_mul,-rms_norm", ] parsed_args = parser.parse_args(args) assert parsed_args.model_name == "something.something" @@ -301,7 +315,40 @@ def test_dict_args(parser): "key12": { "key13": None, }, + "key14": { + "key15": "-minus.and.dot", + } } + assert parsed_args.compilation_config == { + "level": 1, + "use_inductor": True, + "backend": "custom", + "custom_ops": ["-quant_fp8", "+silu_mul", "-rms_norm"], + } + + +def test_duplicate_dict_args(caplog_vllm, parser): + args = [ + "--model-name=something.something", + "--hf-overrides.key1", + "val1", + "--hf-overrides.key1", + "val2", + "-O1", + "-O.level", + "2", + "-O3", + ] + + parsed_args = parser.parse_args(args) + # Should be the last value + assert parsed_args.hf_overrides == {"key1": "val2"} + assert parsed_args.compilation_config == {"level": 3} + + assert len(caplog_vllm.records) == 1 + assert "duplicate" in caplog_vllm.text + assert "--hf-overrides.key1" in caplog_vllm.text + assert "-O.level" in caplog_vllm.text # yapf: enable diff --git a/vllm/config.py b/vllm/config.py index 57b9df2364775..46a5bf34f66e4 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -4140,9 +4140,9 @@ class CompilationConfig: @classmethod def from_cli(cls, cli_value: str) -> "CompilationConfig": - """Parse the CLI value for the compilation config.""" - if cli_value in ["0", "1", "2", "3"]: - return cls(level=int(cli_value)) + """Parse the CLI value for the compilation config. + -O1, -O2, -O3, etc. is handled in FlexibleArgumentParser. + """ return TypeAdapter(CompilationConfig).validate_json(cli_value) def __post_init__(self) -> None: @@ -4303,17 +4303,16 @@ class VllmConfig: """Quantization configuration.""" compilation_config: CompilationConfig = field( default_factory=CompilationConfig) - """`torch.compile` configuration for the model. + """`torch.compile` and cudagraph capture configuration for the model. - When it is a number (0, 1, 2, 3), it will be interpreted as the - optimization level. + As a shorthand, `-O` can be used to directly specify the compilation + level `n`: `-O3` is equivalent to `-O.level=3` (same as `-O='{"level":3}'`). + Currently, -O and -O= are supported as well but this will likely be + removed in favor of clearer -O syntax in the future. NOTE: level 0 is the default level without any optimization. level 1 and 2 are for internal testing only. level 3 is the recommended level for - production. - - Following the convention of traditional compilers, using `-O` without space - is also supported. `-O3` is equivalent to `-O 3`. + production, also default in V1. You can specify the full compilation config like so: `{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}` diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 6c908f88b9a92..2d3783363c00b 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -202,7 +202,10 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, Any]: passed individually. For example, the following sets of arguments are equivalent:\n\n - `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`\n - - `--json-arg.key1 value1 --json-arg.key2.key3 value2`\n\n""" + - `--json-arg.key1 value1 --json-arg.key2.key3 value2`\n + Additionally, list elements can be passed individually using '+': + - `--json-arg '{"key4": ["value3", "value4", "value5"]}'`\n + - `--json-arg.key4+ value3 --json-arg.key4+='value4,value5'`\n\n""" if dataclass_cls is not None: def parse_dataclass(val: str, cls=dataclass_cls) -> Any: diff --git a/vllm/utils.py b/vllm/utils.py index 689102281c54f..60e560c70ad3a 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -89,15 +89,15 @@ MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120 STR_NOT_IMPL_ENC_DEC_SWA = \ "Sliding window attention for encoder/decoder models " + \ - "is not currently supported." + "is not currently supported." STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \ "Prefix caching for encoder/decoder models " + \ - "is not currently supported." + "is not currently supported." STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL = \ "Chunked prefill for encoder/decoder models " + \ - "is not currently supported." + "is not currently supported." STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP = ( "Models with logits_soft_cap " @@ -752,7 +752,7 @@ def _generate_random_fp8( # to generate random data for fp8 data. # For example, s.11111.00 in fp8e5m2 format represents Inf. # | E4M3 | E5M2 - #-----|-------------|------------------- + # -----|-------------|------------------- # Inf | N/A | s.11111.00 # NaN | s.1111.111 | s.11111.{01,10,11} from vllm import _custom_ops as ops @@ -840,7 +840,6 @@ def create_kv_caches_with_random( seed: Optional[int] = None, device: Optional[str] = "cuda", ) -> tuple[list[torch.Tensor], list[torch.Tensor]]: - if cache_dtype == "fp8" and head_size % 16: raise ValueError( f"Does not support key cache of type fp8 with head_size {head_size}" @@ -1205,7 +1204,6 @@ def deprecate_args( is_deprecated: Union[bool, Callable[[], bool]] = True, additional_message: Optional[str] = None, ) -> Callable[[F], F]: - if not callable(is_deprecated): is_deprecated = partial(identity, is_deprecated) @@ -1355,7 +1353,7 @@ def weak_bind(bound_method: Callable[..., Any], ) -> Callable[..., None]: return weak_bound -#From: https://stackoverflow.com/a/4104188/2749989 +# From: https://stackoverflow.com/a/4104188/2749989 def run_once(f: Callable[P, None]) -> Callable[P, None]: def wrapper(*args: P.args, **kwargs: P.kwargs) -> None: @@ -1474,7 +1472,7 @@ class FlexibleArgumentParser(ArgumentParser): # Convert underscores to dashes and vice versa in argument names processed_args = list[str]() - for arg in args: + for i, arg in enumerate(args): if arg.startswith('--'): if '=' in arg: key, value = arg.split('=', 1) @@ -1483,10 +1481,17 @@ class FlexibleArgumentParser(ArgumentParser): else: key = pattern.sub(repl, arg, count=1) processed_args.append(key) - elif arg.startswith('-O') and arg != '-O' and len(arg) == 2: - # allow -O flag to be used without space, e.g. -O3 - processed_args.append('-O') - processed_args.append(arg[2:]) + elif arg.startswith('-O') and arg != '-O' and arg[2] != '.': + # allow -O flag to be used without space, e.g. -O3 or -Odecode + # -O.<...> handled later + # also handle -O= here + level = arg[3:] if arg[2] == '=' else arg[2:] + processed_args.append(f'-O.level={level}') + elif arg == '-O' and i + 1 < len(args) and args[i + 1] in { + "0", "1", "2", "3" + }: + # Convert -O to -O.level + processed_args.append('-O.level') else: processed_args.append(arg) @@ -1504,27 +1509,44 @@ class FlexibleArgumentParser(ArgumentParser): def recursive_dict_update( original: dict[str, Any], update: dict[str, Any], - ): - """Recursively updates a dictionary with another dictionary.""" + ) -> set[str]: + """Recursively updates a dictionary with another dictionary. + Returns a set of duplicate keys that were overwritten. + """ + duplicates = set[str]() for k, v in update.items(): if isinstance(v, dict) and isinstance(original.get(k), dict): - recursive_dict_update(original[k], v) + nested_duplicates = recursive_dict_update(original[k], v) + duplicates |= {f"{k}.{d}" for d in nested_duplicates} + elif isinstance(v, list) and isinstance(original.get(k), list): + original[k] += v else: + if k in original: + duplicates.add(k) original[k] = v + return duplicates delete = set[int]() dict_args = defaultdict[str, dict[str, Any]](dict) + duplicates = set[str]() for i, processed_arg in enumerate(processed_args): - if processed_arg.startswith("--") and "." in processed_arg: + if i in delete: # skip if value from previous arg + continue + + if processed_arg.startswith("-") and "." in processed_arg: if "=" in processed_arg: processed_arg, value_str = processed_arg.split("=", 1) if "." not in processed_arg: - # False positive, . was only in the value + # False positive, '.' was only in the value continue else: value_str = processed_args[i + 1] delete.add(i + 1) + if processed_arg.endswith("+"): + processed_arg = processed_arg[:-1] + value_str = json.dumps(list(value_str.split(","))) + key, *keys = processed_arg.split(".") try: value = json.loads(value_str) @@ -1533,12 +1555,17 @@ class FlexibleArgumentParser(ArgumentParser): # Merge all values with the same key into a single dict arg_dict = create_nested_dict(keys, value) - recursive_dict_update(dict_args[key], arg_dict) + arg_duplicates = recursive_dict_update(dict_args[key], + arg_dict) + duplicates |= {f'{key}.{d}' for d in arg_duplicates} delete.add(i) # Filter out the dict args we set to None processed_args = [ a for i, a in enumerate(processed_args) if i not in delete ] + if duplicates: + logger.warning("Found duplicate keys %s", ", ".join(duplicates)) + # Add the dict args back as if they were originally passed as JSON for dict_arg, dict_value in dict_args.items(): processed_args.append(dict_arg) @@ -2405,7 +2432,7 @@ def memory_profiling( The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` during profiling gives (b.). The increase of `non_torch_memory` from creating the current vLLM instance until after profiling to get (c.). - """ # noqa + """ # noqa gc.collect() torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() From e28533a16f73a4eae01c2b7b1b4ddf3fc1beedab Mon Sep 17 00:00:00 2001 From: fyuan1316 Date: Tue, 1 Jul 2025 09:30:14 +0800 Subject: [PATCH 076/175] [Bugfix] Fix include prompt in stream response when echo=true (#15233) Signed-off-by: Yuan Fang --- tests/entrypoints/openai/test_completion.py | 54 +++++++++++++++++++ vllm/entrypoints/openai/serving_completion.py | 21 ++++++-- 2 files changed, 71 insertions(+), 4 deletions(-) diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py index 7e54143f6e1c3..7933ca5cd6c6f 100644 --- a/tests/entrypoints/openai/test_completion.py +++ b/tests/entrypoints/openai/test_completion.py @@ -779,3 +779,57 @@ async def test_guided_decoding_type_error(client: openai.AsyncOpenAI, prompt="Give an example string that fits this regex", extra_body=dict(guided_regex=sample_regex, guided_json=sample_json_schema)) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name,stream,echo", + [ + (MODEL_NAME, False, False), + (MODEL_NAME, False, True), + (MODEL_NAME, True, False), + (MODEL_NAME, True, True) # should not raise BadRequestError error + ], +) +async def test_echo_stream_completion(client: openai.AsyncOpenAI, + model_name: str, stream: bool, + echo: bool): + saying: str = "Hello, my name is" + result = await client.completions.create(model=model_name, + prompt=saying, + max_tokens=10, + temperature=0.0, + echo=echo, + stream=stream) + + stop_reason = "length" + + if not stream: + completion = result + assert completion.id is not None + assert completion.choices is not None and len(completion.choices) == 1 + + choice = completion.choices[0] + assert len(choice.text) >= 5 + assert choice.finish_reason == stop_reason + + if echo: + assert choice.text is not None and saying in choice.text + else: + assert choice.text is not None and saying not in choice.text + + else: + chunks: list[str] = [] + final_finish_reason = None + async for chunk in result: + if chunk.choices and chunk.choices[0].text: + chunks.append(chunk.choices[0].text) + if chunk.choices and chunk.choices[0].finish_reason: + final_finish_reason = chunk.choices[0].finish_reason + + assert final_finish_reason == stop_reason + content = "".join(chunks) + if echo: + assert content is not None and saying in content + else: + assert content is not None and saying not in content diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index a19fde8d70a83..8171b491aafcc 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -25,10 +25,13 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs, ErrorResponse, RequestResponseMetadata, UsageInfo) -# yapf: enable +from vllm.entrypoints.openai.serving_engine import ( + EmbedsPrompt as ServingEngineEmbedsPrompt) from vllm.entrypoints.openai.serving_engine import (OpenAIServing, + TextTokensPrompt, clamp_prompt_logprobs, is_text_tokens_prompt) +# yapf: enable from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt, is_tokens_prompt) @@ -223,6 +226,7 @@ class OpenAIServingCompletion(OpenAIServing): if stream: return self.completion_stream_generator( request, + request_prompts, result_generator, request_id, created_time, @@ -285,6 +289,8 @@ class OpenAIServingCompletion(OpenAIServing): async def completion_stream_generator( self, request: CompletionRequest, + request_prompts: list[Union[TextTokensPrompt, + ServingEngineEmbedsPrompt]], result_generator: AsyncIterator[tuple[int, RequestOutput]], request_id: str, created_time: int, @@ -313,7 +319,15 @@ class OpenAIServingCompletion(OpenAIServing): async for prompt_idx, res in result_generator: prompt_token_ids = res.prompt_token_ids prompt_logprobs = res.prompt_logprobs - prompt_text = res.prompt + + if res.prompt is not None: + prompt_text = res.prompt + else: + request_prompt = request_prompts[prompt_idx] + if is_text_tokens_prompt(request_prompt): + prompt_text = request_prompt["prompt"] + else: + prompt_text = None # Prompt details are excluded from later streamed outputs if prompt_token_ids is not None: @@ -336,14 +350,13 @@ class OpenAIServingCompletion(OpenAIServing): delta_token_ids = prompt_token_ids out_logprobs = prompt_logprobs else: - assert prompt_logprobs is not None # echo the prompt and first token delta_text = prompt_text + output.text delta_token_ids = [ *prompt_token_ids, *output.token_ids ] out_logprobs = [ - *prompt_logprobs, + *(prompt_logprobs or []), *(output.logprobs or []), ] has_echoed[i] = True From 7151f92241db1bb6ef4eb0fcfed87256646d554e Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 30 Jun 2025 21:01:48 -0700 Subject: [PATCH 077/175] [Misc] Fix spec decode example (#20296) Signed-off-by: Woosuk Kwon --- examples/offline_inference/spec_decode.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py index 90d103e5cb05d..3f38aa9fcaa60 100644 --- a/examples/offline_inference/spec_decode.py +++ b/examples/offline_inference/spec_decode.py @@ -79,9 +79,7 @@ def main(): trust_remote_code=True, tensor_parallel_size=args.tp, enable_chunked_prefill=args.enable_chunked_prefill, - max_num_batched_tokens=args.max_num_batched_tokens, enforce_eager=args.enforce_eager, - max_num_seqs=args.max_num_seqs, gpu_memory_utilization=0.8, speculative_config=speculative_config, disable_log_stats=False, From 92ee7baaf9a5bf6c8132dde56e4056933c61f50f Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Mon, 30 Jun 2025 21:03:55 -0700 Subject: [PATCH 078/175] [Example] add one-click runnable example for P2P NCCL XpYd (#20246) Signed-off-by: KuntaiDu --- .../disagg_example_p2p_nccl_xpyd.sh | 245 ++++++++++++++++++ .../disagg_proxy_p2p_nccl_xpyd.py} | 0 2 files changed, 245 insertions(+) create mode 100644 examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_example_p2p_nccl_xpyd.sh rename examples/online_serving/{disagg_xpyd/disagg_prefill_proxy_xpyd.py => disaggregated_serving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd.py} (100%) diff --git a/examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_example_p2p_nccl_xpyd.sh b/examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_example_p2p_nccl_xpyd.sh new file mode 100644 index 0000000000000..2966f386c93a3 --- /dev/null +++ b/examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_example_p2p_nccl_xpyd.sh @@ -0,0 +1,245 @@ +#!/bin/bash + +# ============================================================================= +# vLLM Disaggregated Serving Script - P2P NCCL XpYd Architecture +# ============================================================================= +# This script demonstrates disaggregated prefill and decode serving using +# P2P NCCL communication. The architecture supports various XpYd configurations: +# +# - 1P3D: 1 Prefill server + 3 Decode servers (current default) +# - 3P1D: 3 Prefill servers + 1 Decode server +# - etc. +# +# Configuration can be customized via environment variables: +# MODEL: Model to serve +# PREFILL_GPUS: Comma-separated GPU IDs for prefill servers +# DECODE_GPUS: Comma-separated GPU IDs for decode servers +# PREFILL_PORTS: Comma-separated ports for prefill servers +# DECODE_PORTS: Comma-separated ports for decode servers +# PROXY_PORT: Proxy server port used to setup XpYd connection. +# TIMEOUT_SECONDS: Server startup timeout +# ============================================================================= + +# Configuration - can be overridden via environment variables +MODEL=${MODEL:-meta-llama/Llama-3.1-8B-Instruct} +TIMEOUT_SECONDS=${TIMEOUT_SECONDS:-1200} +PROXY_PORT=${PROXY_PORT:-30001} + +# Default 1P3D configuration (1 Prefill + 3 Decode) +PREFILL_GPUS=${PREFILL_GPUS:-0} +DECODE_GPUS=${DECODE_GPUS:-1,2,3} +PREFILL_PORTS=${PREFILL_PORTS:-20003} +DECODE_PORTS=${DECODE_PORTS:-20005,20007,20009} + +echo "Warning: P2P NCCL disaggregated prefill XpYd support for vLLM v1 is experimental and subject to change." +echo "" +echo "Architecture Configuration:" +echo " Model: $MODEL" +echo " Prefill GPUs: $PREFILL_GPUS, Ports: $PREFILL_PORTS" +echo " Decode GPUs: $DECODE_GPUS, Ports: $DECODE_PORTS" +echo " Proxy Port: $PROXY_PORT" +echo " Timeout: ${TIMEOUT_SECONDS}s" +echo "" + +PIDS=() + +# Switch to the directory of the current script +cd "$(dirname "${BASH_SOURCE[0]}")" + +check_required_files() { + local files=("disagg_proxy_p2p_nccl_xpyd.py") + for file in "${files[@]}"; do + if [[ ! -f "$file" ]]; then + echo "Required file $file not found in $(pwd)" + exit 1 + fi + done +} + +check_hf_token() { + if [ -z "$HF_TOKEN" ]; then + echo "HF_TOKEN is not set. Please set it to your Hugging Face token." + echo "Example: export HF_TOKEN=your_token_here" + exit 1 + fi + if [[ "$HF_TOKEN" != hf_* ]]; then + echo "HF_TOKEN is not a valid Hugging Face token. Please set it to your Hugging Face token." + exit 1 + fi + echo "HF_TOKEN is set and valid." +} + +check_num_gpus() { + # Check if the number of GPUs are >=2 via nvidia-smi + num_gpus=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) + if [ "$num_gpus" -lt 2 ]; then + echo "You need at least 2 GPUs to run disaggregated prefill." + exit 1 + else + echo "Found $num_gpus GPUs." + fi +} + +ensure_python_library_installed() { + echo "Checking if $1 is installed..." + if ! python3 -c "import $1" > /dev/null 2>&1; then + echo "$1 is not installed. Please install it via pip install $1." + exit 1 + else + echo "$1 is installed." + fi +} + +cleanup() { + echo "Stopping everything…" + trap - INT TERM # prevent re-entrancy + kill -- -$$ # negative PID == "this whole process-group" + wait # reap children so we don't leave zombies + exit 0 +} + +wait_for_server() { + local port=$1 + local timeout_seconds=$TIMEOUT_SECONDS + local start_time=$(date +%s) + + echo "Waiting for server on port $port..." + + while true; do + if curl -s "localhost:${port}/v1/completions" > /dev/null; then + echo "Server on port $port is ready." + return 0 + fi + + local now=$(date +%s) + if (( now - start_time >= timeout_seconds )); then + echo "Timeout waiting for server on port $port" + return 1 + fi + + sleep 1 + done +} + +main() { + check_required_files + check_hf_token + check_num_gpus + ensure_python_library_installed pandas + ensure_python_library_installed datasets + ensure_python_library_installed vllm + ensure_python_library_installed quart + + trap cleanup INT + trap cleanup USR1 + trap cleanup TERM + + echo "Launching disaggregated serving components..." + echo "Please check the log files for detailed output:" + echo " - prefill*.log: Prefill server logs" + echo " - decode*.log: Decode server logs" + echo " - proxy.log: Proxy server log" + + # ============================================================================= + # Launch Proxy Server + # ============================================================================= + echo "" + echo "Starting proxy server on port $PROXY_PORT..." + python3 disagg_proxy_p2p_nccl_xpyd.py & + PIDS+=($!) + + # Parse GPU and port arrays + IFS=',' read -ra PREFILL_GPU_ARRAY <<< "$PREFILL_GPUS" + IFS=',' read -ra DECODE_GPU_ARRAY <<< "$DECODE_GPUS" + IFS=',' read -ra PREFILL_PORT_ARRAY <<< "$PREFILL_PORTS" + IFS=',' read -ra DECODE_PORT_ARRAY <<< "$DECODE_PORTS" + + # ============================================================================= + # Launch Prefill Servers (X Producers) + # ============================================================================= + echo "" + echo "Starting ${#PREFILL_GPU_ARRAY[@]} prefill server(s)..." + for i in "${!PREFILL_GPU_ARRAY[@]}"; do + local gpu_id=${PREFILL_GPU_ARRAY[$i]} + local port=${PREFILL_PORT_ARRAY[$i]} + local kv_port=$((21001 + i)) + + echo " Prefill server $((i+1)): GPU $gpu_id, Port $port, KV Port $kv_port" + CUDA_VISIBLE_DEVICES=$gpu_id VLLM_USE_V1=1 vllm serve $MODEL \ + --enforce-eager \ + --host 0.0.0.0 \ + --port $port \ + --tensor-parallel-size 1 \ + --seed 1024 \ + --dtype float16 \ + --max-model-len 10000 \ + --max-num-batched-tokens 10000 \ + --max-num-seqs 256 \ + --trust-remote-code \ + --gpu-memory-utilization 0.9 \ + --disable-log-request \ + --kv-transfer-config \ + "{\"kv_connector\":\"P2pNcclConnector\",\"kv_role\":\"kv_producer\",\"kv_buffer_size\":\"1e1\",\"kv_port\":\"$kv_port\",\"kv_connector_extra_config\":{\"proxy_ip\":\"0.0.0.0\",\"proxy_port\":\"$PROXY_PORT\",\"http_port\":\"$port\",\"send_type\":\"PUT_ASYNC\",\"nccl_num_channels\":\"16\"}}" > prefill$((i+1)).log 2>&1 & + PIDS+=($!) + done + + # ============================================================================= + # Launch Decode Servers (Y Decoders) + # ============================================================================= + echo "" + echo "Starting ${#DECODE_GPU_ARRAY[@]} decode server(s)..." + for i in "${!DECODE_GPU_ARRAY[@]}"; do + local gpu_id=${DECODE_GPU_ARRAY[$i]} + local port=${DECODE_PORT_ARRAY[$i]} + local kv_port=$((22001 + i)) + + echo " Decode server $((i+1)): GPU $gpu_id, Port $port, KV Port $kv_port" + VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=$gpu_id vllm serve $MODEL \ + --enforce-eager \ + --host 0.0.0.0 \ + --port $port \ + --tensor-parallel-size 1 \ + --seed 1024 \ + --dtype float16 \ + --max-model-len 10000 \ + --max-num-batched-tokens 10000 \ + --max-num-seqs 256 \ + --trust-remote-code \ + --gpu-memory-utilization 0.7 \ + --disable-log-request \ + --kv-transfer-config \ + "{\"kv_connector\":\"P2pNcclConnector\",\"kv_role\":\"kv_consumer\",\"kv_buffer_size\":\"8e9\",\"kv_port\":\"$kv_port\",\"kv_connector_extra_config\":{\"proxy_ip\":\"0.0.0.0\",\"proxy_port\":\"$PROXY_PORT\",\"http_port\":\"$port\",\"send_type\":\"PUT_ASYNC\",\"nccl_num_channels\":\"16\"}}" > decode$((i+1)).log 2>&1 & + PIDS+=($!) + done + + # ============================================================================= + # Wait for All Servers to Start + # ============================================================================= + echo "" + echo "Waiting for all servers to start..." + for port in "${PREFILL_PORT_ARRAY[@]}" "${DECODE_PORT_ARRAY[@]}"; do + if ! wait_for_server $port; then + echo "Failed to start server on port $port" + cleanup + exit 1 + fi + done + + echo "" + echo "All servers are up. Starting benchmark..." + + # ============================================================================= + # Run Benchmark + # ============================================================================= + cd ../../../benchmarks/ + python3 benchmark_serving.py --port 10001 --seed $(date +%s) \ + --model $MODEL \ + --dataset-name random --random-input-len 7500 --random-output-len 200 \ + --num-prompts 200 --burstiness 100 --request-rate 2 | tee benchmark.log + + echo "Benchmarking done. Cleaning up..." + + cleanup +} + +main \ No newline at end of file diff --git a/examples/online_serving/disagg_xpyd/disagg_prefill_proxy_xpyd.py b/examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd.py similarity index 100% rename from examples/online_serving/disagg_xpyd/disagg_prefill_proxy_xpyd.py rename to examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd.py From a2f14dc8f9bb04bd782d1aa4d2e6364841d63d6c Mon Sep 17 00:00:00 2001 From: "Chendi.Xue" Date: Mon, 30 Jun 2025 23:17:07 -0500 Subject: [PATCH 079/175] [CI][Intel Gaudi][vllm-Plugin]Add CI for hpu-plugin-v1-test (#20196) Signed-off-by: Chendi Xue --- .../scripts/hardware_ci/run-hpu-test.sh | 48 +++++++++++++++---- 1 file changed, 40 insertions(+), 8 deletions(-) diff --git a/.buildkite/scripts/hardware_ci/run-hpu-test.sh b/.buildkite/scripts/hardware_ci/run-hpu-test.sh index 5efac3ddf469f..ae5b35a9ac6bd 100644 --- a/.buildkite/scripts/hardware_ci/run-hpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-hpu-test.sh @@ -2,10 +2,34 @@ # This script build the CPU docker image and run the offline inference inside the container. # It serves a sanity check for compilation and basic model usage. -set -ex +set -exuo pipefail # Try building the docker image -docker build -t hpu-test-env -f docker/Dockerfile.hpu . +cat <&2 +fi + +# The trap will handle the container removal and final exit. \ No newline at end of file From bd5038af076a2e299d4781c3885415639a1ed3a5 Mon Sep 17 00:00:00 2001 From: Ernest Wong Date: Mon, 30 Jun 2025 21:44:39 -0700 Subject: [PATCH 080/175] [Doc] add config and troubleshooting guide for NCCL & GPUDirect RDMA (#15897) Signed-off-by: Ernest Wong --- docs/serving/distributed_serving.md | 45 ++++++++++++++++++++++++++++- docs/usage/troubleshooting.md | 21 ++++++++++++++ 2 files changed, 65 insertions(+), 1 deletion(-) diff --git a/docs/serving/distributed_serving.md b/docs/serving/distributed_serving.md index 38dcb8c81caf7..6665955411ad5 100644 --- a/docs/serving/distributed_serving.md +++ b/docs/serving/distributed_serving.md @@ -100,7 +100,50 @@ vllm serve /path/to/the/model/in/the/container \ --tensor-parallel-size 16 ``` -To make tensor parallel performant, you should make sure the communication between nodes is efficient, e.g. using high-speed network cards like Infiniband. To correctly set up the cluster to use Infiniband, append additional arguments like `--privileged -e NCCL_IB_HCA=mlx5` to the `run_cluster.sh` script. Please contact your system administrator for more information on how to set up the flags. One way to confirm if the Infiniband is working is to run vLLM with `NCCL_DEBUG=TRACE` environment variable set, e.g. `NCCL_DEBUG=TRACE vllm serve ...` and check the logs for the NCCL version and the network used. If you find `[send] via NET/Socket` in the logs, it means NCCL uses raw TCP Socket, which is not efficient for cross-node tensor parallel. If you find `[send] via NET/IB/GDRDMA` in the logs, it means NCCL uses Infiniband with GPU-Direct RDMA, which is efficient. +To make tensor parallel performant, you should make sure the communication between nodes is efficient, e.g. using high-speed network cards like InfiniBand. To correctly set up the cluster to use InfiniBand, append additional arguments like `--privileged -e NCCL_IB_HCA=mlx5` to the `run_cluster.sh` script. Please contact your system administrator for more information on how to set up the flags. One way to confirm if the InfiniBand is working is to run vLLM with `NCCL_DEBUG=TRACE` environment variable set, e.g. `NCCL_DEBUG=TRACE vllm serve ...` and check the logs for the NCCL version and the network used. If you find `[send] via NET/Socket` in the logs, it means NCCL uses raw TCP Socket, which is not efficient for cross-node tensor parallel. If you find `[send] via NET/IB/GDRDMA` in the logs, it means NCCL uses InfiniBand with GPUDirect RDMA, which is efficient. + +### GPUDirect RDMA + +To enable GPUDirect RDMA with vLLM, specific configuration tweaks are needed. This setup ensures: + +- `IPC_LOCK` Security Context: Add the `IPC_LOCK` capability to the container’s security context to lock memory pages and prevent swapping to disk. +- Shared Memory with `/dev/shm`: Mount `/dev/shm` in the pod spec to provide shared memory for IPC. + +When using Docker, you can set up the container as follows: + +```bash +docker run --gpus all \ + --ipc=host \ + --shm-size=16G \ + -v /dev/shm:/dev/shm \ + vllm/vllm-openai +``` + +When using Kubernetes, you can set up the pod spec as follows: + +```yaml +... +spec: + containers: + - name: vllm + image: vllm/vllm-openai + securityContext: + capabilities: + add: ["IPC_LOCK"] + volumeMounts: + - mountPath: /dev/shm + name: dshm + resources: + limits: + nvidia.com/gpu: 8 + requests: + nvidia.com/gpu: 8 + volumes: + - name: dshm + emptyDir: + medium: Memory +... +``` !!! warning After you start the Ray cluster, you'd better also check the GPU-GPU communication between nodes. It can be non-trivial to set up. Please refer to the [sanity check script][troubleshooting-incorrect-hardware-driver] for more information. If you need to set some environment variables for the communication configuration, you can append them to the `run_cluster.sh` script, e.g. `-e NCCL_SOCKET_IFNAME=eth0`. Note that setting environment variables in the shell (e.g. `NCCL_SOCKET_IFNAME=eth0 vllm serve ...`) only works for the processes in the same node, not for the processes in the other nodes. Setting environment variables when you create the cluster is the recommended way. See for more information. diff --git a/docs/usage/troubleshooting.md b/docs/usage/troubleshooting.md index 82957d33b19e0..7f1f76ce3d2e3 100644 --- a/docs/usage/troubleshooting.md +++ b/docs/usage/troubleshooting.md @@ -273,6 +273,27 @@ But you are sure that the model is in the [list of supported models][supported-m If you see an error like `RuntimeError: Failed to infer device type`, it means that vLLM failed to infer the device type of the runtime environment. You can check [the code](gh-file:vllm/platforms/__init__.py) to see how vLLM infers the device type and why it is not working as expected. After [this PR](gh-pr:14195), you can also set the environment variable `VLLM_LOGGING_LEVEL=DEBUG` to see more detailed logs to help debug the issue. +## NCCL error: unhandled system error during `ncclCommInitRank` + +If your serving workload uses GPUDirect RDMA for distributed serving across multiple nodes and encounters an error during `ncclCommInitRank`, with no clear error message even with `NCCL_DEBUG=INFO` set, it might look like this: + +```text +Error executing method 'init_device'. This might cause deadlock in distributed execution. +Traceback (most recent call last): +... + File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/device_communicators/pynccl.py", line 99, in __init__ + self.comm: ncclComm_t = self.nccl.ncclCommInitRank( + ^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/device_communicators/pynccl_wrapper.py", line 277, in ncclCommInitRank + self.NCCL_CHECK(self._funcs["ncclCommInitRank"](ctypes.byref(comm), + File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/device_communicators/pynccl_wrapper.py", line 256, in NCCL_CHECK + raise RuntimeError(f"NCCL error: {error_str}") + RuntimeError: NCCL error: unhandled system error (run with NCCL_DEBUG=INFO for details) +... +``` + +This indicates vLLM failed to initialize the NCCL communicator, possibly due to a missing `IPC_LOCK` linux capability or an unmounted `/dev/shm`. Refer to [Distributed Inference and Serving](../serving/distributed_serving.md#running-vllm-on-multiple-nodes) for guidance on properly configuring the environment for distributed serving. + ## Known Issues - In `v0.5.2`, `v0.5.3`, and `v0.5.3.post1`, there is a bug caused by [zmq](https://github.com/zeromq/pyzmq/issues/2000) , which can occasionally cause vLLM to hang depending on the machine configuration. The solution is to upgrade to the latest version of `vllm` to include the [fix](gh-pr:6759). From 27949354faa06035645aa908cc73922500a80b17 Mon Sep 17 00:00:00 2001 From: Alex Kogan <82225080+sakogan@users.noreply.github.com> Date: Tue, 1 Jul 2025 01:44:38 -0400 Subject: [PATCH 081/175] [Feature] A calibration-free RTN-based quantization for accurate and accelerated INT4/INT8 inference (#18768) Signed-off-by: Alex Kogan Co-authored-by: Michael Goin --- tests/quantization/test_rtn.py | 28 ++ .../layers/quantization/__init__.py | 3 + .../model_executor/layers/quantization/rtn.py | 288 ++++++++++++++++++ 3 files changed, 319 insertions(+) create mode 100644 tests/quantization/test_rtn.py create mode 100644 vllm/model_executor/layers/quantization/rtn.py diff --git a/tests/quantization/test_rtn.py b/tests/quantization/test_rtn.py new file mode 100644 index 0000000000000..04c1f98a709e2 --- /dev/null +++ b/tests/quantization/test_rtn.py @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright © 2025, Oracle and/or its affiliates. +"""Tests RTN quantization startup and generation, +doesn't test correctness +""" +import pytest + +from tests.quantization.utils import is_quant_method_supported + +MODELS = ["microsoft/Phi-3-mini-4k-instruct"] + + +@pytest.mark.skipif(not is_quant_method_supported("rtn"), + reason="RTN is not supported on this GPU type.") +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [10]) +def test_model_rtn_startup( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + + with vllm_runner(model, dtype=dtype, quantization="rtn") as vllm_model: + vllm_model.generate_greedy(example_prompts, max_tokens) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 1cb23e7a18875..60217ee86ad1d 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -35,6 +35,7 @@ QuantizationMethods = Literal[ "moe_wna16", "torchao", "auto-round", + "rtn", ] QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods)) @@ -110,6 +111,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: from .neuron_quant import NeuronQuantConfig from .ptpc_fp8 import PTPCFp8Config from .qqq import QQQConfig + from .rtn import RTNConfig from .torchao import TorchAOConfig from .tpu_int8 import Int8TpuConfig @@ -142,6 +144,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: "moe_wna16": MoeWNA16Config, "torchao": TorchAOConfig, "auto-round": AutoRoundConfig, + "rtn": RTNConfig } # Update the `method_to_config` with customized quantization methods. method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG) diff --git a/vllm/model_executor/layers/quantization/rtn.py b/vllm/model_executor/layers/quantization/rtn.py new file mode 100644 index 0000000000000..7e7fd6d51fd32 --- /dev/null +++ b/vllm/model_executor/layers/quantization/rtn.py @@ -0,0 +1,288 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright © 2025, Oracle and/or its affiliates. + +import os +from typing import Any, Optional + +import torch +import torch.nn.functional as F +from torch.nn.parameter import Parameter + +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + set_weight_attrs) +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) + +logger = init_logger(__name__) +"""By default, use 8 bit as target precision, but it can be +overridden by setting the RTN_NUM_BITS envvar +""" +NUM_BITS = os.getenv('RTN_NUM_BITS', "8") +"""By default, use group size of 128 parameters, but it can be +overridden by setting the RTN_GROUP_SIZE envvar +""" +GROUP_SIZE = os.getenv('RTN_GROUP_SIZE', "128") + + +class RTNConfig(QuantizationConfig): + """Config class for RTN. + """ + + def __init__( + self, + weight_bits: int = int(NUM_BITS), + group_size: int = int(GROUP_SIZE), + ) -> None: + self.weight_bits = weight_bits + self.group_size = group_size + + if self.weight_bits != 4 and self.weight_bits != 8: + raise ValueError( + "Currently, only 4-bit or 8-bit weight quantization is " + f"supported for RTN, but got {self.weight_bits} bits.") + + def __repr__(self) -> str: + return (f"RTNConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size})") + + @classmethod + def get_name(cls) -> QuantizationMethods: + return "rtn" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> list[str]: + return [] + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "RTNConfig": + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + return cls(weight_bits, group_size) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["RTNLinearMethod"]: + if isinstance(layer, LinearBase): + return RTNLinearMethod(self) + return None + + +class RTNTensor: + """A wrapper over Tensor that enables quantization on-the-fly by + overloading the copy_ method. + """ + + def __init__(self, data: torch.Tensor, scale: torch.Tensor, + quant_config: RTNConfig) -> None: + self.data = data + self.scale = scale + self.quant_config = quant_config + + def narrow(self, dim, start, length): + factor = 1 if self.quant_config.weight_bits == 8 else 2 + return RTNTensor( + self.data.narrow(dim, start // factor, length // factor), + self.scale.narrow(dim, start, length), self.quant_config) + + @property + def shape(self): + shape = self.data.shape + factor = 1 if self.quant_config.weight_bits == 8 else 2 + return torch.Size((shape[0] * factor, shape[1])) + + def copy_(self, loaded_weight: torch.Tensor) -> None: + qweight, weight_scale = rtn_quantize(loaded_weight.cuda(), + self.quant_config.weight_bits, + self.quant_config.group_size) + + self.data.copy_(qweight) + self.scale.data.copy_(weight_scale) + + +class RTNParameter(Parameter): + """A wrapper over Parameter that returns RTNTensor (a wrapper over Tensor) + when its data is accessed. We need this wrapper for the data loading phase + only, so we can intercept a weight copying function (torch.Tensor.copy_) + and apply quantization on-the-fly. + """ + + def __new__(cls, data: torch.Tensor, **kwargs): + return super().__new__(cls, data=data, requires_grad=False) + + def __init__(self, data: torch.Tensor, scale: torch.Tensor, + quant_config: RTNConfig) -> None: + self.scale = scale + self.quant_config = quant_config + + @property + def data(self): + return RTNTensor(super().data, self.scale, self.quant_config) + + +class RTNLinearMethod(LinearMethodBase): + """Linear method for RTN. + + Args: + quant_config: The RTN quantization config. + """ + + def __init__(self, quant_config: RTNConfig): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + output_size_per_partition = sum(output_partition_sizes) + num_groups_per_col = (input_size_per_partition // + self.quant_config.group_size + if self.quant_config.group_size != -1 else 1) + + scale = Parameter( + torch.empty(output_size_per_partition, + num_groups_per_col, + dtype=params_dtype), + requires_grad=False, + ) + factor = 1 if self.quant_config.weight_bits == 8 else 2 + + weight = RTNParameter(data=torch.empty(output_size_per_partition // + factor, + input_size_per_partition, + dtype=torch.int8), + scale=scale, + quant_config=self.quant_config) + + layer.register_parameter("weight", weight) + set_weight_attrs(weight, { + **extra_weight_attrs, + "input_dim": 1, + "output_dim": 0, + }) + + layer.register_parameter("scale", scale) + layer.output_size_per_partition = output_size_per_partition + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + """torch.compile does not know how to deal with a Parameter subclass + (aka RTNParameter). As we don't really need RTNParameters for the + forward pass, we replace them with equivalent instances of Parameters. + """ + old_weight = layer.weight + assert isinstance(old_weight, RTNParameter) + data = old_weight.data.data + + delattr(layer, "weight") + + new_weight = Parameter(data=data, requires_grad=False) + layer.register_parameter("weight", new_weight) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + qweight = layer.weight + scale = layer.scale + + weight = rtn_dequantize(qweight, scale) + out = F.linear(x, weight) + del weight + if bias is not None: + out.add_(bias) + + return out + + +def rtn_quantize(tensor: torch.Tensor, num_bits: int, + group_size: int) -> tuple[torch.Tensor, torch.Tensor]: + """Quantize a tensor using per-group static scaling factor. + + Args: + tensor: The input tensor. + num_bits: Target precision for the result (supported values are + 8 or 4). + group_size: Quantization granularity. + If equal to -1, each row in the input tensor is treated + as one group. + """ + + q_range = 2**num_bits + num_groups = (tensor.shape[0] * tensor.shape[1] // + group_size if group_size != -1 else tensor.shape[0]) + """Calculate a scaling factor per input group. + """ + input_flat = tensor.reshape(num_groups, -1) + input_min = torch.min(input_flat, dim=1, keepdim=True)[0] + input_max = torch.max(input_flat, dim=1, keepdim=True)[0] + input_max_abs = torch.max(input_min.abs(), input_max.abs()) + scale = (input_max_abs * 2.0 / (q_range - 1)) + """Scale each input group, truncate and round to the nearest integer. + """ + scaled_input = input_flat / scale + scaled_input = scaled_input.clamp(-q_range // 2, q_range // 2 - 1) + scaled_input = scaled_input.round() + + scale = scale.reshape(tensor.shape[0], -1).contiguous() + inputs_q = scaled_input.reshape(tensor.shape).to(torch.int8) + inputs_q = inputs_q.contiguous() + + if num_bits == 4: + """Pack two 4-bit values into each byte. + """ + inputs_q = (inputs_q[:, 1::2] << 4) | (inputs_q[:, ::2] & 0xf) + inputs_q = inputs_q.reshape(tensor.shape[0] // 2, tensor.shape[1]) + inputs_q = inputs_q.contiguous() + + return inputs_q, scale + + +def rtn_dequantize(tensor: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + """Dequantize a tensor using per-group static scaling factors. + + Args: + tensor: The input tensor. + scale: The tensor with per-group scale factors. + """ + + num_groups = scale.size(0) * scale.size(1) + input_dim, output_dim = tensor.shape + + num_bits = 8 if input_dim == scale.size(0) else 4 + if num_bits == 4: + input_dim *= 2 + + data = torch.empty((input_dim, output_dim), + dtype=scale.dtype, + device=tensor.device) + + if num_bits == 8: + data.copy_(tensor) + else: + """Unpack two 4-bit values from each byte. + """ + tensor = tensor.reshape(input_dim, output_dim // 2) + for i in range(2): + data[:, i::2] = (tensor << 4 * (1 - i)) >> 4 + """Scale each input group with its scaling factor. + """ + scale = scale.reshape(num_groups, -1) + data = data.reshape(num_groups, -1) + data = torch.mul(data, scale) + + input_deq = data.reshape((input_dim, output_dim)).contiguous() + return input_deq From be250bbc67973766e546e0e3d8abb21e5caa2b1f Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Tue, 1 Jul 2025 15:02:09 +0900 Subject: [PATCH 082/175] [V1] Only print cudagraph tqdm on rank 0 with `is_global_first_rank` (#19516) Signed-off-by: mgoin --- vllm/distributed/parallel_state.py | 31 ++++++++++++++++++++++++++++++ vllm/v1/worker/gpu_model_runner.py | 11 +++++++---- 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 50dbbf50e9fcf..c53601a22f215 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -1315,6 +1315,37 @@ def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup], return [x == 1 for x in aggregated_data.tolist()] +def is_global_first_rank() -> bool: + """ + Check if the current process is the first rank globally across all + parallelism strategies (PP, TP, DP, EP, etc.). + + Unlike group-specific checks like `get_tensor_model_parallel_rank() == 0` + or `get_pp_group().is_first_rank`, this function checks the global rank + across all parallelism dimensions. + + Returns: + bool: True if this is the global first rank (rank 0), False otherwise. + Returns True if distributed is not initialized (single process). + """ + try: + # If world group is available, use it for the most accurate check + global _WORLD + if _WORLD is not None: + return _WORLD.is_first_rank + + # If torch distributed is not initialized, assume single process + if not torch.distributed.is_initialized(): + return True + + # Fallback to torch's global rank + return torch.distributed.get_rank() == 0 + + except Exception: + # If anything goes wrong, assume this is the first rank + return True + + def _node_count(pg: Union[ProcessGroup, StatelessProcessGroup]) -> int: """ Returns the total number of nodes in the process group. diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 29d39de212f88..5bdaf4b969e70 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -26,7 +26,7 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group, has_kv_transfer_group) from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 from vllm.distributed.parallel_state import ( - get_pp_group, get_tp_group, graph_capture, + get_pp_group, get_tp_group, graph_capture, is_global_first_rank, prepare_communication_buffer_for_model) from vllm.forward_context import (DPMetadata, get_forward_context, set_forward_context) @@ -2285,9 +2285,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): # can reuse the memory pool allocated for the large shapes. with graph_capture(device=self.device): full_cg = self.full_cuda_graph - for num_tokens in tqdm(reversed(self.cudagraph_batch_sizes), - desc="Capturing CUDA graphs", - total=len(self.cudagraph_batch_sizes)): + # Only rank 0 should print progress bar during capture + compilation_cases = reversed(self.cudagraph_batch_sizes) + if is_global_first_rank(): + compilation_cases = tqdm(list(compilation_cases), + desc="Capturing CUDA graph shapes") + for num_tokens in compilation_cases: # We skip EPLB here since we don't want to record dummy metrics for _ in range( self.compilation_config.cudagraph_num_of_warmups): From 86debab54c046232014b108d530a8c25d857e9a3 Mon Sep 17 00:00:00 2001 From: Richard Barnes Date: Tue, 1 Jul 2025 00:48:10 -0600 Subject: [PATCH 083/175] Fix `numel()` downcast in vllm/csrc/moe/moe_align_sum_kernels.cu +2 (#17082) Co-authored-by: mgoin --- csrc/moe/moe_align_sum_kernels.cu | 2 +- csrc/moe/topk_softmax_kernels.cu | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/moe/moe_align_sum_kernels.cu b/csrc/moe/moe_align_sum_kernels.cu index 9335e2333b0d9..462dbd1f8b380 100644 --- a/csrc/moe/moe_align_sum_kernels.cu +++ b/csrc/moe/moe_align_sum_kernels.cu @@ -239,7 +239,7 @@ void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size] torch::Tensor& output) // [num_tokens, hidden_size] { const int hidden_size = input.size(-1); - const int num_tokens = output.numel() / hidden_size; + const auto num_tokens = output.numel() / hidden_size; const int topk = input.size(1); dim3 grid(num_tokens); diff --git a/csrc/moe/topk_softmax_kernels.cu b/csrc/moe/topk_softmax_kernels.cu index dea5b1f21ec27..064b76c9cd427 100644 --- a/csrc/moe/topk_softmax_kernels.cu +++ b/csrc/moe/topk_softmax_kernels.cu @@ -492,7 +492,7 @@ void topk_softmax( torch::Tensor& gating_output) // [num_tokens, num_experts] { const int num_experts = gating_output.size(-1); - const int num_tokens = gating_output.numel() / num_experts; + const auto num_tokens = gating_output.numel() / num_experts; const int topk = topk_weights.size(-1); const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); From 22e9d42040f3ecf83da181cfd84ab4cea000c4af Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Tue, 1 Jul 2025 00:02:20 -0700 Subject: [PATCH 084/175] [Misc] add xgrammar for arm64 (#18359) Signed-off-by: Prashant Gupta --- requirements/common.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/common.txt b/requirements/common.txt index 6cc304e5b1f6d..97a35e05d38ab 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -23,7 +23,7 @@ lm-format-enforcer >= 0.10.11, < 0.11 llguidance >= 0.7.11, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64" outlines == 0.1.11 lark == 1.2.2 -xgrammar == 0.1.19; platform_machine == "x86_64" or platform_machine == "aarch64" +xgrammar == 0.1.19; platform_machine == "x86_64" or platform_machine == "aarch64" or platform_machine == "arm64" typing_extensions >= 4.10 filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317 partial-json-parser # used for parsing partial JSON outputs From 9909726d2a30d834d97efd7bf1c4fc0e52fa48b5 Mon Sep 17 00:00:00 2001 From: czhu-cohere Date: Tue, 1 Jul 2025 00:12:20 -0700 Subject: [PATCH 085/175] Enable ZP Support for Machete (#20268) Signed-off-by: czhu-cohere --- benchmarks/kernels/benchmark_machete.py | 2 ++ tests/kernels/quantization/test_machete_mm.py | 2 +- .../kernels/mixed_precision/machete.py | 20 +++++++++++++++---- 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/benchmarks/kernels/benchmark_machete.py b/benchmarks/kernels/benchmark_machete.py index 0f896f187ecb9..f73d0511e01fc 100644 --- a/benchmarks/kernels/benchmark_machete.py +++ b/benchmarks/kernels/benchmark_machete.py @@ -234,8 +234,10 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable: fn = lambda: ops.gptq_marlin_gemm( a=bt.a, + c=None, b_q_weight=w_q, b_scales=w_s, + global_scale=None, b_zeros=w_zp, g_idx=g_idx, perm=sort_indices, diff --git a/tests/kernels/quantization/test_machete_mm.py b/tests/kernels/quantization/test_machete_mm.py index 998171baaf2de..a4fb9874c4906 100644 --- a/tests/kernels/quantization/test_machete_mm.py +++ b/tests/kernels/quantization/test_machete_mm.py @@ -139,7 +139,7 @@ def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor): def group_size_valid(shape: tuple[int, int, int], group_size: Optional[int]) -> bool: - return group_size is None or group_size == -1 or group_size % shape[2] == 0 + return group_size is None or group_size == -1 or shape[2] % group_size == 0 def machete_quantize_and_pack(atype: torch.dtype, diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py index c7c45861875af..a75f3ac8d5033 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py @@ -33,8 +33,6 @@ class MacheteLinearKernel(MPLinearKernel): return False, "Act reordering currently not supported by Machete, "\ "when the input features are partitioned across "\ "devices" - if c.zero_points: - return False, "Zero points currently not supported by Machete" if c.weight_type not in query_machete_supported_quant_types( c.zero_points): @@ -53,6 +51,7 @@ class MacheteLinearKernel(MPLinearKernel): # note assumes that # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} # `weight_scale` is: {input_dim = 0, output_dim = 1} + # `weight_zp` is: {input_dim = 0, output_dim = 1, packed_dim = 1} def process_weights_after_loading(self, layer: torch.nn.Module): c = self.config @@ -90,16 +89,29 @@ class MacheteLinearKernel(MPLinearKernel): x.data = x.data.contiguous() return x + def transform_w_zp(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=1) + x_unpacked = unpack_quantized_values_into_int32(x.data, + c.weight_type, + packed_dim=1) + w_s = getattr(layer, self.w_s_name).data + # pre-apply scales to zero-points + x.data = (-1.0 * w_s * (x_unpacked.to(w_s.dtype))).contiguous() + return x + # Repack weights and scales for Machete self._transform_param(layer, self.w_q_name, transform_w_q) self._transform_param(layer, self.w_s_name, transform_w_s) + if c.zero_points: + self._transform_param(layer, self.w_zp_name, transform_w_zp) def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: c = self.config - w_q, w_s, _, _ = self._get_weight_params(layer) + w_q, w_s, w_zp, _ = self._get_weight_params(layer) x_2d = x.reshape(-1, x.shape[-1]) out_shape = x.shape[:-1] + (c.partition_weight_shape[1], ) @@ -110,7 +122,7 @@ class MacheteLinearKernel(MPLinearKernel): output = ops.machete_mm(a=x_2d, b_q=w_q, b_type=c.weight_type, - b_group_zeros=None, + b_group_zeros=w_zp, b_group_scales=w_s, b_group_size=c.group_size) From 6cc1e7d96dab6b9c344ec87dec6dc9ab07ad5d21 Mon Sep 17 00:00:00 2001 From: "Li, Jiang" Date: Tue, 1 Jul 2025 15:25:03 +0800 Subject: [PATCH 086/175] [CPU] Update custom ops for the CPU backend (#20255) Signed-off-by: jiang1.li --- .../scripts/hardware_ci/run-cpu-test.sh | 3 +- cmake/cpu_extension.cmake | 20 + csrc/cpu/sgl-kernels/common.h | 238 +++ csrc/cpu/sgl-kernels/gemm.cpp | 464 ++++++ csrc/cpu/sgl-kernels/gemm.h | 266 ++++ csrc/cpu/sgl-kernels/gemm_fp8.cpp | 530 +++++++ csrc/cpu/sgl-kernels/gemm_int8.cpp | 440 ++++++ csrc/cpu/sgl-kernels/moe.cpp | 1330 +++++++++++++++++ csrc/cpu/sgl-kernels/moe_fp8.cpp | 502 +++++++ csrc/cpu/sgl-kernels/moe_int8.cpp | 769 ++++++++++ csrc/cpu/sgl-kernels/vec.h | 308 ++++ csrc/cpu/shm.cpp | 178 +-- csrc/cpu/torch_bindings.cpp | 43 + docs/getting_started/installation/cpu.md | 1 + .../models/language/generation/test_common.py | 3 +- vllm/_custom_ops.py | 49 + vllm/envs.py | 5 + .../layers/fused_moe/cpu_fused_moe.py | 214 +++ vllm/model_executor/layers/fused_moe/layer.py | 41 +- vllm/model_executor/layers/linear.py | 25 +- vllm/model_executor/layers/utils.py | 25 +- .../layers/vocab_parallel_embedding.py | 2 +- vllm/platforms/cpu.py | 2 + 23 files changed, 5357 insertions(+), 101 deletions(-) create mode 100644 csrc/cpu/sgl-kernels/common.h create mode 100644 csrc/cpu/sgl-kernels/gemm.cpp create mode 100644 csrc/cpu/sgl-kernels/gemm.h create mode 100644 csrc/cpu/sgl-kernels/gemm_fp8.cpp create mode 100644 csrc/cpu/sgl-kernels/gemm_int8.cpp create mode 100644 csrc/cpu/sgl-kernels/moe.cpp create mode 100644 csrc/cpu/sgl-kernels/moe_fp8.cpp create mode 100644 csrc/cpu/sgl-kernels/moe_int8.cpp create mode 100644 csrc/cpu/sgl-kernels/vec.h create mode 100644 vllm/model_executor/layers/fused_moe/cpu_fused_moe.py diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test.sh b/.buildkite/scripts/hardware_ci/run-cpu-test.sh index 8db8c3a05fb30..42506730e868c 100644 --- a/.buildkite/scripts/hardware_ci/run-cpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-cpu-test.sh @@ -51,6 +51,7 @@ function cpu_tests() { pytest -v -s tests/kernels/attention/test_cache.py -m cpu_model pytest -v -s tests/kernels/attention/test_mla_decode_cpu.py -m cpu_model pytest -v -s tests/models/language/generation -m cpu_model + VLLM_CPU_SGL_KERNEL=1 pytest -v -s tests/models/language/generation -m cpu_model pytest -v -s tests/models/language/pooling -m cpu_model pytest -v -s tests/models/multimodal/generation \ --ignore=tests/models/multimodal/generation/test_mllama.py \ @@ -98,4 +99,4 @@ function cpu_tests() { # All of CPU tests are expected to be finished less than 40 mins. export -f cpu_tests -timeout 1h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE" +timeout 1.5h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE" diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index 5cd2c98f23438..264c970ef784a 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -96,12 +96,21 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED) if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3) list(APPEND CXX_COMPILE_FLAGS "-mavx512bf16") + set(ENABLE_AVX512BF16 ON) else() + set(ENABLE_AVX512BF16 OFF) message(WARNING "Disable AVX512-BF16 ISA support, requires gcc/g++ >= 12.3") endif() else() + set(ENABLE_AVX512BF16 OFF) message(WARNING "Disable AVX512-BF16 ISA support, no avx512_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512BF16=1.") endif() + + find_isa(${CPUINFO} "avx512_vnni" AVX512VNNI_FOUND) + if (AVX512VNNI_FOUND) + list(APPEND CXX_COMPILE_FLAGS "-mavx512vnni") + set(ENABLE_AVX512VNNI ON) + endif() elseif (AVX2_FOUND) list(APPEND CXX_COMPILE_FLAGS "-mavx2") @@ -231,6 +240,17 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED) "csrc/cpu/quant.cpp" "csrc/cpu/shm.cpp" ${VLLM_EXT_SRC}) + if (ENABLE_AVX512BF16 AND ENABLE_AVX512VNNI) + set(VLLM_EXT_SRC + "csrc/cpu/sgl-kernels/gemm.cpp" + "csrc/cpu/sgl-kernels/gemm_int8.cpp" + "csrc/cpu/sgl-kernels/gemm_fp8.cpp" + "csrc/cpu/sgl-kernels/moe.cpp" + "csrc/cpu/sgl-kernels/moe_int8.cpp" + "csrc/cpu/sgl-kernels/moe_fp8.cpp" + ${VLLM_EXT_SRC}) + add_compile_definitions(-DCPU_CAPABILITY_AVX512) + endif() elseif(POWER10_FOUND) set(VLLM_EXT_SRC "csrc/cpu/quant.cpp" diff --git a/csrc/cpu/sgl-kernels/common.h b/csrc/cpu/sgl-kernels/common.h new file mode 100644 index 0000000000000..20261c1ef3e87 --- /dev/null +++ b/csrc/cpu/sgl-kernels/common.h @@ -0,0 +1,238 @@ +// Adapted from +// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu + +#pragma once + +#include +#include +#include + +// clang-format off + +#if defined(_OPENMP) +#include +#endif + +namespace { + +// dispatch bool +#define AT_DISPATCH_BOOL(BOOL_V, BOOL_NAME, ...) \ + [&] { \ + if (BOOL_V) { \ + constexpr bool BOOL_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr bool BOOL_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() + +// dispatch: bfloat16, float16, int8_t, fp8_e4m3 +#define CPU_DISPATCH_PACKED_TYPES(TYPE, ...) \ + [&] { \ + switch (TYPE) { \ + case at::ScalarType::BFloat16 : { \ + using packed_t = at::BFloat16; \ + return __VA_ARGS__(); \ + } \ + case at::ScalarType::Half: { \ + using packed_t = at::Half; \ + return __VA_ARGS__(); \ + } \ + case at::ScalarType::Char : { \ + using packed_t = int8_t; \ + return __VA_ARGS__(); \ + } \ + case at::ScalarType::Float8_e4m3fn : { \ + using packed_t = at::Float8_e4m3fn; \ + return __VA_ARGS__(); \ + } \ + default: \ + TORCH_CHECK(false, "Unsupported floating data type.\n"); \ + } \ + }() + +#define UNUSED(x) (void)(x) + +#define CHECK_CPU(x) TORCH_CHECK(x.device().type() == at::kCPU, #x " must be a CPU tensor") + +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_LAST_DIM_CONTIGUOUS(x) \ + TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x "must be contiguous at last dimention") + +#define CHECK_INPUT(x) \ + CHECK_CPU(x); \ + CHECK_CONTIGUOUS(x) +#define CHECK_LAST_DIM_CONTIGUOUS_INPUT(x) \ + CHECK_CPU(x); \ + CHECK_LAST_DIM_CONTIGUOUS(x) + +#define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor") + +#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b) + +// parallel routines +constexpr int GRAIN_SIZE = 1024; + +template ::value, int>::type = 0> +inline T div_up(T x, T y) { return (x + y - 1) / y; } + +template +inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) { +#if 0 + // onednn partition pattern + T& n_my = n_end; + if (nth <= 1 || n == 0) { + n_start = 0; + n_my = n; + } else { + T n1 = div_up(n, nth); + T n2 = n1 - 1; + T T1 = n - n2 * nth; + n_my = ith < T1 ? n1 : n2; + n_start = ith <= T1 ? ith*n1 : T1 * n1 + (ith - T1) * n2; + } + n_end += n_start; +#else + // pytorch aten partition pattern + T n_my = div_up(n, nth); + n_start = ith * n_my; + n_end = std::min(n_start + n_my, n); +#endif +} + +template +inline void parallel_for(int n, const func_t& f) { +#if defined(_OPENMP) +#pragma omp parallel +{ + int nth = omp_get_num_threads(); + int ith = omp_get_thread_num(); + int tbegin, tend; + balance211(n, nth, ith, tbegin, tend); + f(tbegin, tend); +} +#else + f(0, n); +#endif +} + +// for 1d parallel, use `actual_nth` +// for 2d parallel, use even nths, e.g. 43->42 +int inline adjust_num_threads(int m) { + int actual_nth = at::get_num_threads(); + if (m == 1) { + return actual_nth; + } + return std::max(1, (actual_nth >> 1) * 2); +} + +template +inline void parallel_2d(int m, int n, const func_t& f) { + + // make sure we have even num_threads + int nth = adjust_num_threads(m); + + // [NOTE] thread blocking: + // + // 1) prefer square block per thread + // 2) use even number of CPU cores + // 3) use all `num_threads` cores + // + // we have: + // TM * TN = T + // BM / TM = BN / TN + // then: + // TM = ((BM / BN) * T) ^ 0.5 + // + float r = float(m) / n; + int nth_m = std::ceil(std::sqrt(r * nth)); + int nth_n = 1; + for (; nth_m > 0; --nth_m) { + nth_n = nth / nth_m; + if (nth_m * nth_n == nth) { + break; + } + } + +#if defined(_OPENMP) +#pragma omp parallel num_threads(nth) +{ + int ith = omp_get_thread_num(); + int ith_m = ith / nth_n; + int ith_n = ith % nth_n; + + int thread_block_m = div_up(m, nth_m); + int thread_block_n = div_up(n, nth_n); + + int begin_m = ith_m * thread_block_m; + int end_m = std::min(m, begin_m + thread_block_m); + int begin_n = ith_n * thread_block_n; + int end_n = std::min(n, begin_n + thread_block_n); + + f(begin_m, end_m, begin_n, end_n); +} +#else + f(0, m, 0, n); +#endif +} + +template +int get_cache_blocks(int BLOCK_SIZE, int K) { + // L2 2MB and ratio of 50% + const int L2_size = 2048 * 1024 >> 1; + return std::max(1, int(L2_size / (BLOCK_SIZE * K * sizeof(T)))); +} + +// data indexing for dimension collapse +template +inline T data_index_init(T offset) { + return offset; +} + +template +inline T data_index_init(T offset, T& x, const T& X, Args&&... args) { + offset = data_index_init(offset, std::forward(args)...); + x = offset % X; + return offset / X; +} + +inline bool data_index_step() { + return true; +} + +template +inline bool data_index_step(T& x, const T& X, Args&&... args) { + if (data_index_step(std::forward(args)...)) { + x = ((x + 1) == X) ? 0 : (x + 1); + return x == 0; + } + return false; +} + +// forced unroll for perf critical path + +#if __has_attribute(always_inline) +#define ALWAYS_INLINE __attribute__((__always_inline__)) inline +#else +#define ALWAYS_INLINE inline +#endif + +template +struct Unroll { + template + ALWAYS_INLINE void operator()(const Func& f, Args... args) const { + Unroll{}(f, args...); + f(std::integral_constant{}, args...); + } +}; + +template <> +struct Unroll<1> { + template + ALWAYS_INLINE void operator()(const Func& f, Args... args) const { + f(std::integral_constant{}, args...); + } +}; + +} // anonymous namespace diff --git a/csrc/cpu/sgl-kernels/gemm.cpp b/csrc/cpu/sgl-kernels/gemm.cpp new file mode 100644 index 0000000000000..c122d07185ddb --- /dev/null +++ b/csrc/cpu/sgl-kernels/gemm.cpp @@ -0,0 +1,464 @@ +// Adapted from +// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu + +#include "common.h" +#include "vec.h" +#include "gemm.h" + +// clang-format off + +namespace { + +// packed layout: +// quants {N, K} int8_t +// comp {N} int32_t +template +inline void s8s8_compensation(int8_t* __restrict__ packed, int K) { +#if defined(CPU_CAPABILITY_AVX512) + constexpr int COLS = BLOCK_N / 16; + __m512i vcomp[COLS]; + + for (int col = 0; col < COLS; ++col) { + vcomp[col] = _mm512_setzero_si512(); + } + + const int64_t offset = BLOCK_N * K; + const __m512i off = _mm512_set1_epi8(static_cast(0x80)); + for (int k = 0; k < K / 4; ++k) { + for (int col = 0; col < COLS; ++col) { + __m512i vb = _mm512_loadu_si512((const __m512i *)(packed + k * BLOCK_N * 4 + col * 64)); + vcomp[col] = _mm512_dpbusd_epi32(vcomp[col], off, vb); + } + } + + for (int col = 0; col < COLS; ++col) { + _mm512_storeu_si512((__m512i *)(packed + offset + col * 64), vcomp[col]); + } +#else + TORCH_CHECK(false, "s8s8_compensation not implemented!"); +#endif +} + +// convert to vnni format +// from [N, K] to [K/2, N, 2] for bfloat16 and float16 +template +inline void pack_vnni(packed_t* __restrict__ packed, const packed_t* __restrict__ weight, int N, int K) { + const int VNNI_BLK = 2; + for (int n = 0; n < N; ++n) { + for (int k = 0; k < K / VNNI_BLK; ++k) { + for (int d = 0; d < VNNI_BLK; ++d) { + packed[k * N * VNNI_BLK + n * VNNI_BLK + d] = weight[n * K + k * VNNI_BLK + d]; + } + } + } +} + +template <> +inline void pack_vnni(int8_t* __restrict__ packed, const int8_t* __restrict__ weight, int N, int K) { + constexpr int BLOCK_N = block_size_n(); + TORCH_CHECK(N == BLOCK_N); + + const int VNNI_BLK = 4; + for (int n = 0; n < N; ++n) { + for (int k = 0; k < K / VNNI_BLK; ++k) { + for (int d = 0; d < VNNI_BLK; ++d) { + packed[k * N * VNNI_BLK + n * VNNI_BLK + d] = weight[n * K + k * VNNI_BLK + d]; + } + } + } + s8s8_compensation(packed, K); +} + +template +inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input, int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + fVec data0 = fVec::loadu(input + d); + fVec data1 = fVec::loadu(input + d + fVec::size()); + bVec out_vec = convert_from_float_ext(data0, data1); + out_vec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d]); + } +} + +template +inline void copy_add_stub(scalar_t* __restrict__ out, const float* __restrict__ input, const float* __restrict__ bias, int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + fVec data0 = fVec::loadu(input + d) + fVec::loadu(bias + d); + fVec data1 = fVec::loadu(input + d + fVec::size()) + fVec::loadu(bias + d + fVec::size()); + bVec out_vec = convert_from_float_ext(data0, data1); + out_vec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d] + bias[d]); + } +} + +template +struct tinygemm_kernel_nn { + static inline void apply( + const scalar_t* __restrict__ A, const scalar_t* __restrict__ B, scalar_t* __restrict__ C, + const float* __restrict__ bias, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!"); + } +}; + +#if defined(CPU_CAPABILITY_AVX512) +template +struct tinygemm_kernel_nn { + static inline void apply( + const at::BFloat16* __restrict__ A, const at::BFloat16* __restrict__ B, at::BFloat16* __restrict__ C, + const float* __restrict__ bias, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + + constexpr int ROWS = BLOCK_M; + constexpr int COLS = BLOCK_N / 16; + + // prefetch distance + constexpr int PREFETCH_SIZE_K = 0; + + __m512bh va; + __m512bh vb[COLS]; + __m512 vc[ROWS * COLS]; + + auto loadc = [&](auto i) { + constexpr int col = i % COLS; + if constexpr (has_bias) { + vc[i] = _mm512_loadu_ps(bias + col * 16); + } else { + vc[i] = _mm512_set1_ps(0.f); + } + }; + Unroll{}(loadc); + + const int64_t K2 = K >> 1; + const int64_t lda2 = lda >> 1; + const int64_t ldb2 = ldb; // ldb * 2 >> 1; + const float* a_ptr = reinterpret_cast(A); + const float* b_ptr = reinterpret_cast(B); + + auto compute = [&](auto i, int64_t k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k])); + } + if constexpr (row == 0) { + vb[col] = (__m512bh)(_mm512_loadu_si512(b_ptr + k * ldb2 + col * 16)); + if constexpr (PREFETCH_SIZE_K > 0) { + _mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0); + } + } + vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]); + }; + for (int64_t k = 0; k < K2; ++k) { + Unroll{}(compute, k); + } + + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + // for COLS = 2, 4 use 512bit store + // for COLS = 1, 3 use 256bit store + if constexpr (COLS % 2 == 0) { + if constexpr (col % 2 == 0) { + _mm512_storeu_si512( + reinterpret_cast<__m512i*>((C + row * ldc + col * 16)), + (__m512i)(_mm512_cvtne2ps_pbh(vc[row * COLS + col + 1], vc[row * COLS + col]))); + } + } else { + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(C + row * ldc + col * 16), + (__m256i)(_mm512_cvtneps_pbh(vc[i]))); + } + }; + Unroll{}(storec); + } +}; +#endif + +#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \ + tinygemm_kernel_nn::apply( \ + A + mb_start * lda, B + nb_start * 2, C + mb_start * ldc + nb_start, \ + has_bias ? bias + nb_start : nullptr, K, lda, ldb, ldc); + +template +struct brgemm { + static inline void apply( + const scalar_t* __restrict__ A, const scalar_t* __restrict__ B, scalar_t* __restrict__ C, + float* __restrict__ Ctmp, const float* __restrict__ bias, + int64_t M, int64_t N, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + + constexpr int BLOCK_N = block_size_n(); + at::native::cpublas::brgemm( + M, N, K, lda, ldb, BLOCK_N, /* add_C */false, + A, B, Ctmp); + + // copy from Ctmp to C + for (int64_t m = 0; m < M; ++m) { + if constexpr (has_bias) { + copy_add_stub(C + m * ldc, Ctmp + m * BLOCK_N, bias, N); + } else { + copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N); + } + } + } +}; + +template +void tinygemm_kernel( + const scalar_t* __restrict__ A, + const scalar_t* __restrict__ B, + scalar_t* __restrict__ C, + float* __restrict__ Ctmp, + const float* __restrict__ bias, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + bool brg) { + + if (brg) { + brgemm::apply( + A, B, C, Ctmp, bias, + M, N, K, lda, ldb, ldc); + return; + } + + // pattern: 1-4-16 + constexpr int64_t BLOCK_M = 4; + constexpr int64_t BLOCK_N = 64; + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + for (int mb = 0; mb < MB; ++mb) { + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(BLOCK_M, M - mb_start); + for (int64_t nb = 0; nb < NB; ++nb) { + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(BLOCK_N, N - nb_start); + + switch(mb_size << 4 | nb_size >> 4) { + // mb_size = 1 + case 0x12: LAUNCH_TINYGEMM_KERNEL_NN(1, 32); break; + case 0x14: LAUNCH_TINYGEMM_KERNEL_NN(1, 64); break; + // mb_size = 2 + case 0x22: LAUNCH_TINYGEMM_KERNEL_NN(2, 32); break; + case 0x24: LAUNCH_TINYGEMM_KERNEL_NN(2, 64); break; + // mb_size = 3 + case 0x32: LAUNCH_TINYGEMM_KERNEL_NN(3, 32); break; + case 0x34: LAUNCH_TINYGEMM_KERNEL_NN(3, 64); break; + // mb_size = 4 + case 0x42: LAUNCH_TINYGEMM_KERNEL_NN(4, 32); break; + case 0x44: LAUNCH_TINYGEMM_KERNEL_NN(4, 64); break; + default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size"); + } + } + } +} + +template +void weight_packed_linear_kernel_impl( + scalar_t* __restrict__ out, + const scalar_t* __restrict__ mat1, + const scalar_t* __restrict__ mat2, + const float* __restrict__ bias, + int64_t M, + int64_t N, + int64_t K, + int64_t mat1_strideM, + int64_t out_strideM) { + + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + + // use avx512-bf16 when a) M is small; b) dtype is bfloat16, otherwise use amx + const bool use_brgemm = (M > 4) || (!std::is_same_v); + + // l2 cache block for n + int64_t cache_blocks_nb = get_cache_blocks(BLOCK_N, K); + + // parallel on [MB, NB] + AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] { + parallel_2d(MB, NB, [&](int64_t begin_mb, int64_t end_mb, int64_t begin_nb, int64_t end_nb) { + + // for brgemm, use float32 for accumulate + alignas(64) float Ctmp[BLOCK_M * BLOCK_N]; + + for (int64_t nbb = begin_nb; nbb < end_nb; nbb += cache_blocks_nb) { + for (int64_t mb = begin_mb; mb < end_mb; ++mb) { + for (int64_t nb = nbb; nb < std::min(nbb + cache_blocks_nb, end_nb); ++nb) { + + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(M - mb_start, BLOCK_M); + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(N - nb_start, BLOCK_N); + + tinygemm_kernel( + /* A */ mat1 + mb_start * mat1_strideM, + /* B */ mat2 + nb_start * K /* nb * BLOCK_N * K */, + /* C */ out + mb_start * out_strideM + nb_start, + /* Ctmp*/ Ctmp, + /* bias*/ bias + nb_start, + /* M */ mb_size, + /* N */ nb_size, + /* K */ K, + /* lda */ mat1_strideM, + /* ldb */ nb_size, + /* ldc */ out_strideM, + /* brg */ use_brgemm); + }}} + + if (use_brgemm) { + at::native::cpublas::brgemm_release(); + } + }); + }); +} + +} // anonymous namespace + +// tinygemm interface +template +void tinygemm_kernel(const scalar_t* __restrict__ A, const scalar_t* __restrict__ B, scalar_t* __restrict__ C, + float* __restrict__ Ctmp, int64_t M, int64_t N, int64_t K, int64_t lda, int64_t ldb, int64_t ldc, bool brg) { + tinygemm_kernel(A, B, C, Ctmp, nullptr, M, N, K, lda, ldb, ldc, brg); +} + +#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \ + template void tinygemm_kernel( \ + const TYPE* __restrict__ A, const TYPE* __restrict__ B, TYPE* __restrict__ C, \ + float* __restrict__ Ctmp, int64_t M, int64_t N, int64_t K, int64_t lda, \ + int64_t ldb, int64_t ldc, bool brg) + +INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16); +INSTANTIATE_TINYGEMM_TEMPLATE(at::Half); + +at::Tensor convert_weight_packed(at::Tensor& weight) { + // for 3d moe weights + // weight : [E, OC, IC] + // w1 : [E, 2N, K] + // w2 : [E, K, N] + CHECK_INPUT(weight); + + const int64_t ndim = weight.ndimension(); + TORCH_CHECK(ndim == 2 || ndim == 3, "expect weight to be 2d or 3d, got ", ndim, "d tensor."); + const auto st = weight.scalar_type(); + const int64_t E = ndim == 3 ? weight.size(0) : 1; + const int64_t OC = ndim == 3 ? weight.size(1) : weight.size(0); + const int64_t IC = ndim == 3 ? weight.size(2) : weight.size(1); + + // we handle 2 TILE_N at a time. + TORCH_CHECK(OC % TILE_N == 0, "invalid weight out features ", OC); + TORCH_CHECK(IC % TILE_K == 0, "invalid weight input features ", IC); + + constexpr int64_t BLOCK_N = block_size_n(); + const int64_t NB = div_up(OC, BLOCK_N); + + // use phony sizes here [E, OC, IC], for each [E], [OC, IC] -> [IC / 2, OC, 2] + auto packed_weight = at::empty({}, weight.options()); + const int64_t stride = OC * IC; + + TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf || st == at::kChar || st == at::kFloat8_e4m3fn, + "expect weight to be bfloat16, float16, int8 or fp8_e4m3."); + + CPU_DISPATCH_PACKED_TYPES(st, [&] { + // adjust most inner dimension size + const int packed_row_size = get_row_size(IC); + auto sizes = weight.sizes().vec(); + sizes[ndim - 1] = packed_row_size; + packed_weight.resize_(sizes); + + const packed_t* w_data = weight.data_ptr(); + packed_t* packed_data = packed_weight.data_ptr(); + + // parallel on {E, NB} + at::parallel_for(0, E * NB, 0, [&](int64_t begin, int64_t end) { + int64_t e{0}, nb{0}; + data_index_init(begin, e, E, nb, NB); + + for (int64_t i = begin; i < end; ++i) { + UNUSED(i); + + int64_t n = nb * BLOCK_N; + int64_t n_size = std::min(BLOCK_N, OC - n); + pack_vnni( + packed_data + e * OC * packed_row_size + n * packed_row_size, + w_data + e * stride + n * IC, + n_size, + IC); + + // move to the next index + data_index_step(e, E, nb, NB); + } + }); + }); + return packed_weight; +} + +// mat1 : [M, K] +// mat2 : [N, K] +// bias : [N] +// out : [M, N] +// +at::Tensor weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, + const std::optional& bias, bool is_vnni) { + RECORD_FUNCTION( + "sgl-kernel::weight_packed_linear", std::vector({mat1, mat2, bias})); + + auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2); + + CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1); + CHECK_INPUT(mat2); + + int64_t M = mat1.size(0); + int64_t N = mat2.size(0); + int64_t K = mat2.size(1); + CHECK_EQ(mat1.size(1), K); + CHECK_DIM(2, mat1); + CHECK_DIM(2, mat2); + + auto out = at::empty({M, N}, mat1.options()); + + // strides + int64_t mat1_strideM = mat1.stride(0); + int64_t out_strideM = out.stride(0); + + const bool has_bias = bias.has_value(); + const float* bias_data = nullptr; + if (has_bias) { + CHECK_EQ(bias.value().size(0), N); + bias_data = bias.value().data_ptr(); + } + + AT_DISPATCH_REDUCED_FLOATING_TYPES(mat1.scalar_type(), "weight_packed_linear_kernel_impl", [&] { + weight_packed_linear_kernel_impl( + out.data_ptr(), + mat1.data_ptr(), + packed_w.data_ptr(), + bias_data, + M, + N, + K, + mat1_strideM, + out_strideM); + }); + + return out; +} diff --git a/csrc/cpu/sgl-kernels/gemm.h b/csrc/cpu/sgl-kernels/gemm.h new file mode 100644 index 0000000000000..afae19721ae96 --- /dev/null +++ b/csrc/cpu/sgl-kernels/gemm.h @@ -0,0 +1,266 @@ +#pragma once + +#include + +// clang-format off + +// amx-bf16 +#define TILE_M 16 +#define TILE_N 16 +#define TILE_K 32 + +// block size for AMX gemm +constexpr int block_size_m() { return 2 * TILE_M; } +constexpr int block_size_n() { return 2 * TILE_N; } + +// define threshold using brgemm (intel AMX) +template inline bool can_use_brgemm(int M); +template <> inline bool can_use_brgemm(int M) { return M > 4; } +template <> inline bool can_use_brgemm(int M) { return true; } +// TODO: add u8s8 brgemm, this requires PyTorch 2.7 +template <> inline bool can_use_brgemm(int M) { return false; } +template <> inline bool can_use_brgemm(int M) { return M > 4; } +template <> inline bool can_use_brgemm(int M) { return M > 4; } + +// work around compiler internal error +#define BLOCK_K 128 // 4 * TILE_K + +// adjust leading dimension size for K +template +inline int64_t get_row_size(int64_t K) { + return K; +} + +template <> +inline int64_t get_row_size(int64_t K) { + return K + sizeof(int32_t); +} + +inline int64_t get_row_size(int64_t K, bool use_int8_w8a8) { + return use_int8_w8a8 ? K + sizeof(int32_t) : K; +} + +// pack weight to vnni format +at::Tensor convert_weight_packed(at::Tensor& weight); + +// moe implementations for int8 w8a8 +template +void fused_experts_int8_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic1, + scalar_t* __restrict__ ic2, + uint8_t* __restrict__ A_tmp, + float* __restrict__ C_tmp, + uint8_t* __restrict__ Aq_tmp, + float* __restrict__ As_tmp, + const scalar_t* __restrict__ input, + const int8_t* __restrict__ packed_w1, + const int8_t* __restrict__ packed_w2, + const float* __restrict__ w1s, + const float* __restrict__ w2s, + const float* __restrict__ topk_weights, + const int32_t* __restrict__ sorted_ids, + const int32_t* __restrict__ expert_ids, + const int32_t* __restrict__ offsets, + int64_t M, + int64_t N, + int64_t K, + int64_t E, + int64_t topk, + int64_t num_tokens_post_pad); + +// moe implementations for fp8 w8a16 +template +void fused_experts_fp8_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic0, + scalar_t* __restrict__ ic1, + scalar_t* __restrict__ ic2, + scalar_t* __restrict__ A_tmp, + scalar_t* __restrict__ B_tmp, + float* __restrict__ C_tmp, + const scalar_t* __restrict__ input, + const at::Float8_e4m3fn* __restrict__ packed_w1, + const at::Float8_e4m3fn* __restrict__ packed_w2, + const float* __restrict__ w1s, + const float* __restrict__ w2s, + int64_t block_size_N, + int64_t block_size_K, + const float* __restrict__ topk_weights, + const int32_t* __restrict__ sorted_ids, + const int32_t* __restrict__ expert_ids, + const int32_t* __restrict__ offsets, + int64_t M, + int64_t N, + int64_t K, + int64_t E, + int64_t topk, + int64_t num_tokens_post_pad); + +// moe implementations for int4 w4a16 +template +void fused_experts_int4_w4a16_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic0, + scalar_t* __restrict__ ic1, + scalar_t* __restrict__ ic2, + scalar_t* __restrict__ A_tmp, + scalar_t* __restrict__ B_tmp, + float* __restrict__ C_tmp, + const scalar_t* __restrict__ input, + const at::quint4x2* __restrict__ packed_w1, + const at::quint4x2* __restrict__ packed_w2, + const uint8_t* __restrict__ w1z, + const uint8_t* __restrict__ w2z, + const scalar_t* __restrict__ w1s, + const scalar_t* __restrict__ w2s, + int group_size, + const float* __restrict__ topk_weights, + const int32_t* __restrict__ sorted_ids, + const int32_t* __restrict__ expert_ids, + const int32_t* __restrict__ offsets, + int64_t M, + int64_t N, + int64_t K, + int64_t E, + int64_t topk, + int64_t num_tokens_post_pad); + +// shared expert implememntation for int8 w8a8 +template +void shared_expert_int8_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic1, + float* __restrict__ C_tmp, + uint8_t* __restrict__ Aq_tmp, + float* __restrict__ As_tmp, + const scalar_t* __restrict__ input, + const int8_t* __restrict__ packed_w1, + const int8_t* __restrict__ packed_w2, + const float* __restrict__ w1s, + const float* __restrict__ w2s, + const scalar_t* __restrict__ fused_experts_out, + float routed_scaling_factor, + int64_t M, + int64_t N, + int64_t K); + +template +void shared_expert_fp8_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic0, + scalar_t* __restrict__ ic1, + scalar_t* __restrict__ B_tmp, + float* __restrict__ C_tmp, + const scalar_t* __restrict__ input, + const at::Float8_e4m3fn* __restrict__ packed_w1, + const at::Float8_e4m3fn* __restrict__ packed_w2, + const float* __restrict__ w1s, + const float* __restrict__ w2s, + int64_t block_size_N, + int64_t block_size_K, + const scalar_t* __restrict__ fused_experts_out, + float routed_scaling_factor, + int64_t M, + int64_t N, + int64_t K); + +// tinygemm interface +template +void tinygemm_kernel( + const scalar_t* __restrict__ A, + const scalar_t* __restrict__ B, + scalar_t* __restrict__ C, + float* __restrict__ Ctmp, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + bool brg); + +template +void tinygemm_kernel( + const uint8_t* __restrict__ A, + const int8_t* __restrict__ B, + scalar_t* __restrict__ C, + int32_t* __restrict__ Ctmp, + const float* __restrict__ As, + const float* __restrict__ Bs, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + bool brg); + +template +void tinygemm_kernel( + const scalar_t* __restrict__ A, + const at::Float8_e4m3fn* __restrict__ B, + scalar_t* __restrict__ C, + scalar_t* __restrict__ Btmp, + float* __restrict__ Ctmp, + const float* __restrict__ scale, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + bool brg, + int64_t block_size_K); + +template +void tinygemm_kernel( + const scalar_t* __restrict__ A, + const at::quint4x2* __restrict__ B, + scalar_t* __restrict__ C, + const uint8_t* __restrict__ Bz, + const scalar_t* __restrict__ Bs, + scalar_t* __restrict__ Btmp, + float* __restrict__ Ctmp, + int64_t M, + int64_t N, + int64_t K, + int group_size, + int64_t lda, + int64_t ldb, + int64_t ldc, + int64_t strideBz, + int64_t strideBs, + bool brg); + +// TODO: debug print, remove me later +inline void print_16x32i(const __m512i x) { + int32_t a[16]; + _mm512_storeu_si512((__m512i *)a, x); + + for (int i = 0; i < 16; i++){ + std::cout << a[i] << " "; + } + std::cout << std::endl; +} + +inline void print_16x32(const __m512 x) { + float a[16]; + _mm512_storeu_ps((__m512 *)a, x); + + for (int i = 0; i < 16; i++){ + std::cout << a[i] << " "; + } + std::cout << std::endl; +} + + +inline void print_32x8u(const __m256i x) { + uint8_t a[32]; + _mm256_storeu_si256((__m256i *)a, x); + + for (int i = 0; i < 32; ++i) { + std::cout << int32_t(a[i]) << " "; + } + std::cout << std::endl; +} diff --git a/csrc/cpu/sgl-kernels/gemm_fp8.cpp b/csrc/cpu/sgl-kernels/gemm_fp8.cpp new file mode 100644 index 0000000000000..b5f2f07bad623 --- /dev/null +++ b/csrc/cpu/sgl-kernels/gemm_fp8.cpp @@ -0,0 +1,530 @@ +// Adapted from +// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu + +#include "common.h" +#include "vec.h" +#include "gemm.h" + +// clang-format off + +// we use 4x32 for BLOCK_M +#define BLOCK_SIZE_M_SCALE 4 + +namespace { + +template +inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input, int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + fVec data0 = fVec::loadu(input + d); + fVec data1 = fVec::loadu(input + d + fVec::size()); + bVec out_vec = convert_from_float_ext(data0, data1); + out_vec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d]); + } +} + +template +inline void copy_add_stub(scalar_t* __restrict__ out, const float* __restrict__ input, const float* __restrict__ bias, int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + fVec data0 = fVec::loadu(input + d) + fVec::loadu(bias + d); + fVec data1 = fVec::loadu(input + d + fVec::size()) + fVec::loadu(bias + d + fVec::size()); + bVec out_vec = convert_from_float_ext(data0, data1); + out_vec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d] + bias[d]); + } +} + +inline void unpack_B( + at::BFloat16* __restrict__ Btmp, + const at::Float8_e4m3fn* __restrict__ packed_B, + int N, + int K, + int ldb, + int ldb_tmp, + float scale) { +#if defined(CPU_CAPABILITY_AVX512) + // [K/2, N, 2] + const int K2 = K >> 1; + const int ldb2 = ldb; // ldb * 2 >> 1; + const uint16_t* b_ptr = reinterpret_cast(packed_B); + const __m512 vd = _mm512_set1_ps(scale); + + constexpr int BLOCK_N = block_size_n(); + static_assert(BLOCK_N == 32); + + // prefetch distance + constexpr int PREFETCH_SIZE_K = 64; + +#pragma GCC unroll 4 + for (int k = 0; k < K2; ++k) { + __m512i b8 = _mm512_loadu_si512(b_ptr + k * ldb2); + if constexpr (PREFETCH_SIZE_K > 0) { + _mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2, _MM_HINT_T0); + } + + __m256i b8_0 = _mm512_extracti32x8_epi32(b8, 0); + __m256i b8_1 = _mm512_extracti32x8_epi32(b8, 1); + + __m512bh bf16_0 = CVT_FP8_TO_BF16(b8_0); + __m512bh bf16_1 = CVT_FP8_TO_BF16(b8_1); + + // Apply scale + __m512 f0_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 0)); + __m512 f0_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 1)); + __m512 f1_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 0)); + __m512 f1_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 1)); + + f0_lo = _mm512_mul_ps(f0_lo, vd); + f0_hi = _mm512_mul_ps(f0_hi, vd); + f1_lo = _mm512_mul_ps(f1_lo, vd); + f1_hi = _mm512_mul_ps(f1_hi, vd); + + bf16_0 = _mm512_cvtne2ps_pbh(f0_hi, f0_lo); + bf16_1 = _mm512_cvtne2ps_pbh(f1_hi, f1_lo); + + _mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + 0, (__m512i)bf16_0); + _mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + 32, (__m512i)bf16_1); + } +#else + TORCH_CHECK(false, "unpack_B: scalar path not implemented!"); +#endif +} + +template +struct tinygemm_kernel_nn { + static inline void apply( + const scalar_t* __restrict__ A, const packed_t* __restrict__ B, scalar_t* __restrict__ C, + const float* __restrict__ bias, const float* __restrict__ scale, int K, int lda, int ldb, int ldc, int64_t block_size_K) { + TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!"); + } +}; + +#if defined(CPU_CAPABILITY_AVX512) +template +struct tinygemm_kernel_nn { + static inline void apply( + const at::BFloat16* __restrict__ A, const at::Float8_e4m3fn* __restrict__ B, at::BFloat16* __restrict__ C, + const float* __restrict__ bias, const float* __restrict__ scale, int K, int lda, int ldb, int ldc, int64_t block_size_K) { + + constexpr int ROWS = BLOCK_M; + constexpr int COLS = BLOCK_N / 16; + + const int KB = div_up(K, BLOCK_K); + + // prefetch distance + constexpr int PREFETCH_SIZE_K = 64; + constexpr int PREFETCH_SIZE_KB = 1; + + __m512bh va; + __m512bh vb[COLS]; + __m512 vc[ROWS * COLS]; + __m512 vsum[ROWS * COLS]; + + // block quant scale + __m512 vscale; + + auto loadc = [&](auto i) { + constexpr int col = i % COLS; + if constexpr (has_bias) { + vc[i] = _mm512_loadu_ps(bias + col * 16); + } else { + vc[i] = _mm512_setzero_ps(); + } + }; + Unroll{}(loadc); + + const int lda2 = lda >> 1; + const int ldb2 = ldb; // ldb * 2 >> 1; + const float* a_ptr = reinterpret_cast(A); + const uint16_t* b_ptr = reinterpret_cast(B); + + auto compute = [&](auto i, int k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k])); + if constexpr (PREFETCH_SIZE_K > 0) { + _mm_prefetch(a_ptr + row * lda2 + k + PREFETCH_SIZE_K, _MM_HINT_T0); + } + } + if constexpr (row == 0) { + if constexpr (col % 2 == 0) { + __m512i b8 = _mm512_loadu_si512(b_ptr + k * ldb2 + col * 16); + if constexpr (PREFETCH_SIZE_K > 0) { + _mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0); + } + vb[col + 0] = CVT_FP8_TO_BF16(_mm512_extracti32x8_epi32(b8, 0)); + vb[col + 1] = CVT_FP8_TO_BF16(_mm512_extracti32x8_epi32(b8, 1)); + } + } + vsum[i] = _mm512_dpbf16_ps(vsum[i], va, vb[col]); + }; + + constexpr int BLOCK_K2 = BLOCK_K >> 1; + for (int kb = 0; kb < KB; ++kb) { + int kb_start = kb * BLOCK_K2; + int kb_end = std::min(K, kb_start + BLOCK_K2); + // 1. load scale vector + vscale = _mm512_set1_ps(scale[kb]); + if constexpr (PREFETCH_SIZE_KB > 0) { + _mm_prefetch(scale + kb + PREFETCH_SIZE_KB, _MM_HINT_T0); + } + // 2. zero vsum for each block + Unroll{}([&](auto i) { + vsum[i] = _mm512_setzero_ps(); + }); + // 3. accumulate across each block + for (int k = kb_start; k < kb_end; ++k) { + Unroll{}(compute, k); + } + // 4. apply scale + Unroll{}([&](auto i) { + vc[i] = _mm512_fmadd_ps(vsum[i], vscale, vc[i]); + }); + } + + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + // for COLS = 2,4 use 512bit store + if constexpr (col % 2 == 0) { + _mm512_storeu_si512( + reinterpret_cast<__m512i*>((C + row * ldc + col * 16)), + (__m512i)(_mm512_cvtne2ps_pbh(vc[row * COLS + col + 1], vc[row * COLS + col]))); + } + }; + Unroll{}(storec); + } +}; +#endif + +#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \ + tinygemm_kernel_nn::apply( \ + A + mb_start * lda, B + nb_start * 2, C + mb_start * ldc + nb_start, \ + has_bias ? bias + nb_start : nullptr, scale, K, lda, ldb, ldc, block_size_K); + +template +struct brgemm { + static inline void apply( + const scalar_t* __restrict__ A, + const packed_t* __restrict__ B, + scalar_t* __restrict__ C, + scalar_t* __restrict__ Btmp, + float* __restrict__ Ctmp, + const float* __restrict__ bias, + const float* __restrict__ scale, + int M, + int N, + int K, + int lda, + int ldb, + int ldc) { + TORCH_CHECK(false, "struct brgemm: primary template not implemented!"); + } +}; + +template +struct brgemm { + static inline void apply( + const at::BFloat16* __restrict__ A, + const at::Float8_e4m3fn* __restrict__ B, + at::BFloat16* __restrict__ C, + at::BFloat16* __restrict__ Btmp, + float* __restrict__ Ctmp, + const float* __restrict__ bias, + const float* __restrict__ scale, + int M, + int N, + int K, + int lda, + int ldb, + int ldc) { + + constexpr int BLOCK_N = block_size_n(); + + // [K, BLOCK_N] -> [K / 2, BLOCK_N * 2] + const int ldb_tmp = BLOCK_N; + + for (int k = 0; k < K; k += BLOCK_K) { + int kb_size = std::min(BLOCK_K, K - k); + + int idx = k >> 7; // k / BLOCK_K where BLOCK_K = 128 + unpack_B(Btmp + k * ldb_tmp, B + k * ldb, N, kb_size, ldb, ldb_tmp, scale[idx]); + } + + at::native::cpublas::brgemm( + M, N, K, lda, ldb_tmp, BLOCK_N, /* add_C */ false, A, Btmp, Ctmp); + + // copy from Ctmp to C + for (int m = 0; m < M; ++m) { + if constexpr (has_bias) { + copy_add_stub(C + m * ldc, Ctmp + m * BLOCK_N, bias, N); + } else { + copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N); + } + } + } +}; + +template +void tinygemm_kernel( + const scalar_t* __restrict__ A, + const at::Float8_e4m3fn* __restrict__ B, + scalar_t* __restrict__ C, + scalar_t* __restrict__ Btmp, + float* __restrict__ Ctmp, + const float* __restrict__ scale, + const float* __restrict__ bias, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + bool brg, + int64_t block_size_K) { + + if (brg) { + brgemm::apply( + A, B, C, Btmp, Ctmp, bias, scale, M, N, K, lda, ldb, ldc); + return; + } + + // pattern: 1-4-16 + constexpr int64_t BLOCK_M = 4; + constexpr int64_t BLOCK_N = 64; + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + for (int mb = 0; mb < MB; ++mb) { + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(BLOCK_M, M - mb_start); + for (int64_t nb = 0; nb < NB; ++nb) { + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(BLOCK_N, N - nb_start); + + switch(mb_size << 4 | nb_size >> 4) { + case 0x12: LAUNCH_TINYGEMM_KERNEL_NN(1, 32); break; + case 0x22: LAUNCH_TINYGEMM_KERNEL_NN(2, 32); break; + case 0x32: LAUNCH_TINYGEMM_KERNEL_NN(3, 32); break; + case 0x42: LAUNCH_TINYGEMM_KERNEL_NN(4, 32); break; + default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size"); + } + } + } +} + +template +void fp8_scaled_mm_kernel_impl( + scalar_t* __restrict__ out, + const scalar_t* __restrict__ mat1, + const at::Float8_e4m3fn* __restrict__ mat2, + const float* __restrict__ scales2, + const float* __restrict__ bias, + scalar_t* __restrict__ buffer, + int64_t M, + int64_t N, + int64_t K, + int64_t mat1_strideM, + int64_t out_strideM, + int64_t block_size_N, + int64_t block_size_K, + int64_t buffer_size_per_thread) { + + constexpr int64_t BLOCK_M = block_size_m() * BLOCK_SIZE_M_SCALE; + constexpr int64_t BLOCK_N = block_size_n(); + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + + const int64_t scale_size_K = div_up(K, block_size_K); + const int64_t blocks_n_per_group = block_size_N / BLOCK_N; + + const bool use_brgemm = can_use_brgemm(M); + + // parallel on [MB, NB] + AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] { + at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + int64_t mb{0}, nb{0}; + data_index_init(begin, mb, MB, nb, NB); + + int tid = at::get_thread_num(); + scalar_t* __restrict__ Btmp = buffer + tid * buffer_size_per_thread; + float* __restrict__ Ctmp = (float*)((void*)(Btmp + BLOCK_N * K)); + + for (int64_t i = begin; i < end; ++i) { + UNUSED(i); + const float* scale_ptr = scales2 + (nb / blocks_n_per_group) * scale_size_K; + + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(M - mb_start, BLOCK_M); + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(N - nb_start, BLOCK_N); + + tinygemm_kernel( + /* A */ mat1 + mb_start * mat1_strideM, + /* B */ mat2 + nb_start * K, // nb * BLOCK_N * K + /* C */ out + mb_start * out_strideM + nb_start, + /* Btmp */ Btmp, + /* Ctmp */ Ctmp, + /* scale */ scale_ptr, + /* bias */ bias + nb_start, + /* M */ mb_size, + /* N */ nb_size, + /* K */ K, + /* lda */ mat1_strideM, + /* ldb */ nb_size, + /* ldc */ out_strideM, + /* brg */ use_brgemm, + /* block_size_K */ block_size_K); + + // move to the next index + data_index_step(mb, MB, nb, NB); + } + + if (use_brgemm) { + at::native::cpublas::brgemm_release(); + } + }); + }); +} + +} // anonymous namespace + +// tinygemm interface +template +void tinygemm_kernel( + const scalar_t* __restrict__ A, + const at::Float8_e4m3fn* __restrict__ B, + scalar_t* __restrict__ C, + scalar_t* __restrict__ Btmp, + float* __restrict__ Ctmp, + const float* __restrict__ scale, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + bool brg, + int64_t block_size_K) { + tinygemm_kernel(A, B, C, Btmp, Ctmp, scale, nullptr, M, N, K, lda, ldb, ldc, brg, block_size_K); +} + +#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \ + template void tinygemm_kernel( \ + const TYPE* __restrict__ A, \ + const at::Float8_e4m3fn* __restrict__ B, \ + TYPE* __restrict__ C, \ + TYPE* __restrict__ Btmp, \ + float* __restrict__ Ctmp, \ + const float* __restrict__ scale, \ + int64_t M, \ + int64_t N, \ + int64_t K, \ + int64_t lda, \ + int64_t ldb, \ + int64_t ldc, \ + bool brg, \ + int64_t block_size_K) + +INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16); +INSTANTIATE_TINYGEMM_TEMPLATE(at::Half); + +at::Tensor fp8_scaled_mm_cpu(at::Tensor& mat1, at::Tensor& mat2, at::Tensor& scales2, + std::vector block_size, std::optional& bias, + at::ScalarType out_dtype, bool is_vnni) { + RECORD_FUNCTION("sgl-kernel::fp8_scaled_mm_cpu", std::vector({mat1, mat2, scales2, block_size, bias})); + + auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2); + + CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1); + CHECK_INPUT(mat2); + CHECK_INPUT(scales2); + TORCH_CHECK(scales2.scalar_type() == at::kFloat, + "fp8_scaled_mm_cpu: expect scales2 to be float32."); + + int64_t M = mat1.size(0); + int64_t N = mat2.size(0); + int64_t K = mat2.size(1); + + CHECK_EQ(mat1.size(1), K); + CHECK_DIM(2, mat1); + CHECK_DIM(2, mat2); + + TORCH_CHECK(block_size.size() == 2, + "fp8_scaled_mm_cpu: expect block_size.size() to be 2."); + + int64_t block_size_N = block_size[0]; + int64_t block_size_K = block_size[1]; + + constexpr int64_t BLOCK_M = block_size_m() * BLOCK_SIZE_M_SCALE; + constexpr int64_t BLOCK_N = block_size_n(); + TORCH_CHECK(block_size_N % BLOCK_N == 0, "fp8_scaled_mm_cpu: expect block_size_N to be multiples of BLOCK_N"); + TORCH_CHECK(block_size_K == BLOCK_K, "fp8_scaled_mm_cpu: expect block_size_K equals to BLOCK_K"); + CHECK_EQ(scales2.size(0), div_up(N, block_size_N)); + CHECK_EQ(scales2.size(1), div_up(K, block_size_K)); + + const auto st = mat1.scalar_type(); + TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf, + "fp8_scaled_mm_cpu: expect A to be bfloat16 or half."); + TORCH_CHECK(st == out_dtype, + "fp8_scaled_mm_cpu: expect A has same dtype with out_dtype."); + TORCH_CHECK(mat2.scalar_type() == at::kFloat8_e4m3fn, + "fp8_scaled_mm_cpu: expect mat2 to be fp8_e4m3."); + TORCH_CHECK(scales2.scalar_type() == at::kFloat, + "fp8_scaled_mm_cpu: expect scales to be float32."); + auto out = at::empty({M, N}, mat1.options().dtype(out_dtype)); + + // strides + int64_t mat1_strideM = mat1.stride(0); + int64_t out_strideM = out.stride(0); + + const bool has_bias = bias.has_value(); + const float* bias_data = nullptr; + if (has_bias) { + CHECK_EQ(bias.value().size(0), N); + bias_data = bias.value().data_ptr(); + } + + // Btmp : [T, BLOCK_N * K] + // Ctmp : [T, BLOCK_M * BLOCK_N] + int num_threads = at::get_num_threads(); + int64_t size_per_thread = BLOCK_N * K + BLOCK_M * BLOCK_N * 2; + auto buffer = at::empty({num_threads, size_per_thread}, mat1.options()); + + AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "fp8_scaled_mm_kernel_impl", [&] { + fp8_scaled_mm_kernel_impl( + out.data_ptr(), + mat1.data_ptr(), + packed_w.data_ptr(), + scales2.data_ptr(), + bias_data, + buffer.data_ptr(), + M, + N, + K, + mat1_strideM, + out_strideM, + block_size_N, + block_size_K, + size_per_thread); + }); + + return out; +} diff --git a/csrc/cpu/sgl-kernels/gemm_int8.cpp b/csrc/cpu/sgl-kernels/gemm_int8.cpp new file mode 100644 index 0000000000000..5a0f65a9200d4 --- /dev/null +++ b/csrc/cpu/sgl-kernels/gemm_int8.cpp @@ -0,0 +1,440 @@ +// Adapted from +// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu + +#include "common.h" +#include "vec.h" +#include "gemm.h" + +// clang-format off + +namespace { + +template +struct tinygemm_kernel_nn { + static inline void apply( + const uint8_t* __restrict__ A, const int8_t* __restrict__ B, scalar_t* __restrict__ C, + const float* __restrict__ As, const float* __restrict__ Bs, const int32_t* __restrict__ Bcomp, + const float* __restrict__ bias, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!"); + } +}; + +#if defined(CPU_CAPABILITY_AVX512) +template +struct tinygemm_kernel_nn { + static inline void apply( + const uint8_t* __restrict__ A, const int8_t* __restrict__ B, at::BFloat16* __restrict__ C, + const float* __restrict__ As, const float* __restrict__ Bs, const int32_t* __restrict__ Bcomp, + const float* __restrict__ bias, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + + constexpr int ROWS = BLOCK_M; + constexpr int COLS = BLOCK_N / 16; + static_assert(COLS % 2 == 0); + + // prefetch distance + constexpr int PREFETCH_SIZE_K = 0; + + __m512i va; + __m512i vb[COLS]; + __m512i vc[ROWS * COLS]; + __m512i vcomp[COLS]; + __m512 vd0; + __m512 vd1[COLS]; + + // oops! 4x4 spills but luckly we use 4x2 + __m512 vbias[COLS]; + + // [NOTE]: s8s8 igemm compensation in avx512-vnni + // + // avx512-vnni has no s8s8, so we need to change s8s8 to u8s8 with compensate: + // + // a * b = (a + 128) * b - 128 * b + // s s u s u s + // + // 1) 128 * b is pre-computed when packing B to vnni formats + // 2) a + 128 is fused when dynamically quantize A + // + auto loadc = [&](auto i) { + vc[i] = _mm512_set1_epi32(0); + }; + Unroll{}(loadc); + + const int64_t K4 = K >> 2; + const int64_t lda4 = lda >> 2; + const int64_t ldb4 = ldb; // ldb * 4 >> 2; + const int32_t* a_ptr = reinterpret_cast(A); + const int32_t* b_ptr = reinterpret_cast(B); + + auto compute = [&](auto i, int64_t k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + va = _mm512_set1_epi32(a_ptr[row * lda4 + k]); + } + if constexpr (row == 0) { + vb[col] = _mm512_loadu_si512(b_ptr + k * ldb4 + col * 16); + if constexpr (PREFETCH_SIZE_K > 0) { + _mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb4 + col * 16, _MM_HINT_T0); + } + } + vc[i] = _mm512_dpbusd_epi32(vc[i], va, vb[col]); + }; + for (int64_t k = 0; k < K4; ++k) { + Unroll{}(compute, k); + } + + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + // load a scale + if constexpr(col == 0) { + vd0 = _mm512_set1_ps(As[row]); + } + // load b scale and vcomp per 2 vectors + // also load bias if any + if constexpr (row == 0) { + if constexpr (col % 2 == 0) { + vd1[col + 0] = _mm512_loadu_ps(Bs + col * 16); + vd1[col + 1] = _mm512_loadu_ps(Bs + col * 16 + 16); + vcomp[col + 0] = _mm512_loadu_si512(Bcomp + col * 16); + vcomp[col + 1] = _mm512_loadu_si512(Bcomp + col * 16 + 16); + if constexpr (has_bias) { + vbias[col + 0] = _mm512_loadu_ps(bias + col * 16); + vbias[col + 1] = _mm512_loadu_ps(bias + col * 16 + 16); + } + } + } + + // for COLS = 2, 4 use 512bit store + if constexpr (col % 2 == 0) { + __m512 vc0 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[row * COLS + col + 0], vcomp[col + 0])); + __m512 vc1 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[row * COLS + col + 1], vcomp[col + 1])); + if constexpr (has_bias) { + vc0 = _mm512_fmadd_ps(_mm512_mul_ps(vc0, vd0), vd1[col + 0], vbias[col + 0]); + vc1 = _mm512_fmadd_ps(_mm512_mul_ps(vc1, vd0), vd1[col + 1], vbias[col + 1]); + } else { + vc0 = _mm512_mul_ps(_mm512_mul_ps(vc0, vd0), vd1[col + 0]); + vc1 = _mm512_mul_ps(_mm512_mul_ps(vc1, vd0), vd1[col + 1]); + } + + _mm512_storeu_si512( + reinterpret_cast<__m512i*>((C + row * ldc + col * 16)), + (__m512i)(_mm512_cvtne2ps_pbh(vc1, vc0))); + } + }; + Unroll{}(storec); + } +}; +#endif + +#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \ + tinygemm_kernel_nn::apply( \ + A + mb_start * lda, B + nb_start * 4, C + mb_start * ldc + nb_start, \ + As + mb_start, Bs + nb_start, Bcomp + nb_start, \ + has_bias ? bias + nb_start : nullptr, K, lda, ldb, ldc); + +template +void tinygemm_kernel( + const uint8_t* __restrict__ A, + const int8_t* __restrict__ B, + scalar_t* __restrict__ C, + int32_t* __restrict__ Ctmp, + const float* __restrict__ As, + const float* __restrict__ Bs, + const float* __restrict__ bias, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + bool brg) { + + // B compensation + const int32_t* Bcomp = reinterpret_cast(B + block_size_n() * K); + + // pattern: 1-4-16 + constexpr int64_t BLOCK_M = 4; + constexpr int64_t BLOCK_N = 64; + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + for (int64_t mb = 0; mb < MB; ++mb) { + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(BLOCK_M, M - mb_start); + for (int64_t nb = 0; nb < NB; ++nb) { + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(BLOCK_N, N - nb_start); + + switch(mb_size << 4 | nb_size >> 4) { + // mb_size = 1 + case 0x12: LAUNCH_TINYGEMM_KERNEL_NN(1, 32); break; + case 0x14: LAUNCH_TINYGEMM_KERNEL_NN(1, 64); break; + // mb_size = 2 + case 0x22: LAUNCH_TINYGEMM_KERNEL_NN(2, 32); break; + case 0x24: LAUNCH_TINYGEMM_KERNEL_NN(2, 64); break; + // mb_size = 3 + case 0x32: LAUNCH_TINYGEMM_KERNEL_NN(3, 32); break; + case 0x34: LAUNCH_TINYGEMM_KERNEL_NN(3, 64); break; + // mb_size = 4 + case 0x42: LAUNCH_TINYGEMM_KERNEL_NN(4, 32); break; + case 0x44: LAUNCH_TINYGEMM_KERNEL_NN(4, 64); break; + default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size"); + } + } + } +} + +template +void int8_scaled_mm_kernel_impl( + scalar_t* __restrict__ out, + const uint8_t* __restrict__ mat1, + const int8_t* __restrict__ mat2, + const float* __restrict__ scales1, + const float* __restrict__ scales2, + const float* __restrict__ bias, + int64_t M, + int64_t N, + int64_t K) { + + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + + // TODO: brgemm u8s8 depends on PyTorch 2.7 release. + const bool use_brgemm = false; + + // K + 4 after compensation + const int64_t packed_row_size = get_row_size(K); + + AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] { + at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + int64_t mb{0}, nb{0}; + data_index_init(begin, mb, MB, nb, NB); + + // for brgemm, use int32_t for accumulate + alignas(64) int32_t Ctmp[BLOCK_M * BLOCK_N]; + + for (int i = begin; i < end; ++i) { + UNUSED(i); + int mb_start = mb * BLOCK_M; + int mb_size = std::min(M - mb_start, BLOCK_M); + int nb_start = nb * BLOCK_N; + int nb_size = std::min(N - nb_start, BLOCK_N); + + tinygemm_kernel( + /* A */ mat1 + mb_start * K, + /* B */ mat2 + nb_start * packed_row_size /* nb * BLOCK_N * (K + 4) */, + /* C */ out + mb_start * N + nb_start, + /* Ctmp*/ Ctmp, + /* As */ scales1 + mb_start, + /* Bs */ scales2 + nb_start, + /* bias*/ bias + nb_start, + /* M */ mb_size, + /* N */ nb_size, + /* K */ K, + /* lda */ K, + /* ldb */ nb_size, + /* ldc */ N, + /* brg */ use_brgemm); + + // move to the next index + data_index_step(mb, MB, nb, NB); + } + + if (use_brgemm) { + at::native::cpublas::brgemm_release(); + } + }); + }); +} + +} // anonymous namespace + +// tinygemm interface +template +void tinygemm_kernel(const uint8_t* __restrict__ A, const int8_t* __restrict__ B, scalar_t* __restrict__ C, + int32_t* __restrict__ Ctmp, const float* __restrict__ As, const float* __restrict__ Bs, + int64_t M, int64_t N, int64_t K, int64_t lda, int64_t ldb, int64_t ldc, bool brg) { + tinygemm_kernel(A, B, C, Ctmp, As, Bs, nullptr, M, N, K, lda, ldb, ldc, brg); +} + +#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \ + template void tinygemm_kernel( \ + const uint8_t* __restrict__ A, const int8_t* __restrict__ B, TYPE* __restrict__ C, \ + int32_t* __restrict__ Ctmp, const float* __restrict__ As, const float* __restrict__ Bs, \ + int64_t M, int64_t N, int64_t K, int64_t lda, int64_t ldb, int64_t ldc, bool brg) + +INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16); +INSTANTIATE_TINYGEMM_TEMPLATE(at::Half); + +std::tuple per_token_quant_int8_cpu(at::Tensor& A) { + RECORD_FUNCTION("sgl-kernel::per_token_quant_int8_cpu", std::vector({A})); + + CHECK_LAST_DIM_CONTIGUOUS_INPUT(A); + CHECK_DIM(2, A); + + int64_t M = A.size(0); + int64_t K = A.size(1); + int64_t lda = A.stride(0); + + const auto st = A.scalar_type(); + TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf, + "per_token_quant_int8: expect A to be bfloat16 or half."); + + auto Aq = at::empty({M, K}, A.options().dtype(at::kByte)); + auto As = at::empty({M}, A.options().dtype(at::kFloat)); + + AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "per_token_quant_int8", [&] { + uint8_t* __restrict__ Aq_data = Aq.data_ptr(); + float* __restrict__ As_data = As.data_ptr(); + const scalar_t* __restrict__ A_data = A.data_ptr(); + + at::parallel_for(0, M, 0, [&] (int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + quantize_row_int8( + Aq_data + m * K, + As_data[m], + A_data + m * lda, + K); + } + }); + }); + return std::make_tuple(Aq, As); +} + +// weight : static, per-channel, symmetric +// activation : dynamic, per-token, symmetric +// +// mat1 : [M, K] +// mat2 : [N, K] +// scales1 : [M] +// scales2 : [N] +// bias : [N] +// out : [M, N] +// +at::Tensor int8_scaled_mm_cpu(at::Tensor& mat1, at::Tensor& mat2, + at::Tensor& scales1, at::Tensor& scales2, + std::optional& bias, at::ScalarType out_dtype, bool is_vnni) { + RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", std::vector({mat1, mat2, scales1, scales2, bias})); + + auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2); + + CHECK_INPUT(mat1); + CHECK_INPUT(mat2); + CHECK_INPUT(scales1); + CHECK_INPUT(scales2); + CHECK_DIM(2, mat1); + CHECK_DIM(2, mat2); + + int64_t M = mat1.size(0); + int64_t N = mat2.size(0); + int64_t K = mat1.size(1); + + // see [NOTE]: s8s8 igemm compensation in avx512-vnni + CHECK_EQ(mat2.size(1), (int64_t)(is_vnni ? K + sizeof(int32_t) : K)); + CHECK_EQ(scales1.numel(), M); + CHECK_EQ(scales2.numel(), N); + + TORCH_CHECK(mat1.scalar_type() == at::kByte, "int8_scaled_mm: expect mat1 to be uint8."); + TORCH_CHECK(mat2.scalar_type() == at::kChar, "int8_scaled_mm: expect mat2 to be int8."); + TORCH_CHECK(scales1.scalar_type() == at::kFloat && scales2.scalar_type() == at::kFloat, + "int8_scaled_mm: expect scales to be float32."); + + auto out = at::empty({M, N}, mat1.options().dtype(out_dtype)); + + const bool has_bias = bias.has_value(); + const float* bias_data = nullptr; + if (has_bias) { + CHECK_EQ(bias.value().size(0), N); + bias_data = bias.value().data_ptr(); + } + + AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "int8_scaled_mm_kernel_impl", [&] { + int8_scaled_mm_kernel_impl( + out.data_ptr(), + mat1.data_ptr(), + packed_w.data_ptr(), + scales1.data_ptr(), + scales2.data_ptr(), + bias_data, + M, + N, + K); + }); + return out; +} + +// fused `per_token_quant_int8_cpu` and `int8_scaled_mm_cpu` +at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2, at::Tensor& scales2, + const std::optional& bias, at::ScalarType out_dtype, bool is_vnni) { + RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", std::vector({mat1, mat2, scales2, bias})); + + auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2); + + CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1); + CHECK_INPUT(mat2); + CHECK_INPUT(scales2); + CHECK_DIM(2, mat1); + CHECK_DIM(2, mat2); + + int64_t M = mat1.size(0); + int64_t N = mat2.size(0); + int64_t K = mat1.size(1); + int64_t lda = mat1.stride(0); + + // see [NOTE]: s8s8 igemm compensation in avx512-vnni + CHECK_EQ(mat2.size(1), (int64_t)(is_vnni ? K + sizeof(int32_t) : K)); + CHECK_EQ(scales2.numel(), N); + + const auto st = mat1.scalar_type(); + TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf, + "int8_scaled_mm_with_quant: expect A to be bfloat16 or half."); + TORCH_CHECK(st == out_dtype, + "int8_scaled_mm_with_quant: expect A has same dtype with out_dtype."); + TORCH_CHECK(mat2.scalar_type() == at::kChar, + "int8_scaled_mm_with_quant: expect mat2 to be int8."); + TORCH_CHECK(scales2.scalar_type() == at::kFloat, + "int8_scaled_mm_with_quant: expect scales to be float32."); + + const int64_t buffer_size = M * K + M * sizeof(float); + auto buffer = at::empty({buffer_size}, mat1.options().dtype(at::kByte)); + auto out = at::empty({M, N}, mat1.options().dtype(out_dtype)); + + const bool has_bias = bias.has_value(); + const float* bias_data = nullptr; + if (has_bias) { + CHECK_EQ(bias.value().size(0), N); + bias_data = bias.value().data_ptr(); + } + + AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "int8_scaled_mm_with_quant_kernel_impl", [&] { + uint8_t* __restrict__ Aq_data = buffer.data_ptr(); + float* __restrict__ As_data = (float*)((void*)(Aq_data + M * K)); + const scalar_t* __restrict__ A_data = mat1.data_ptr(); + + at::parallel_for(0, M, 0, [&] (int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + quantize_row_int8( + Aq_data + m * K, + As_data[m], + A_data + m * lda, + K); + } + }); + + int8_scaled_mm_kernel_impl( + out.data_ptr(), + Aq_data, + packed_w.data_ptr(), + As_data, + scales2.data_ptr(), + bias_data, + M, + N, + K); + }); + return out; +} diff --git a/csrc/cpu/sgl-kernels/moe.cpp b/csrc/cpu/sgl-kernels/moe.cpp new file mode 100644 index 0000000000000..beeccff783ea0 --- /dev/null +++ b/csrc/cpu/sgl-kernels/moe.cpp @@ -0,0 +1,1330 @@ +// Adapted from +// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu + +#include "common.h" +#include "vec.h" +#include "gemm.h" + +// clang-format off + +namespace { + +// [NOTE]: Fused MoE kernel with AMX +// +// This file contains implementations for +// * `moe_align_block_size` +// * `fused_moe` +// +// The functionality is identical to triton kernel, excepts: +// * fuse silu_and_mul with gemm1, therefore this kernel +// allocates 2 intermediate_caches instead of 3 +// * add `offsets` in `moe_align_block_size` which keeps track +// of starting offset for each M block. this is for keeping +// output of silu_and_mul in sorted order, thus load_A for +// the 2nd gemm would be contiguous, therefore we can directly +// load A from intermediate_cache1. +// +// TODO: +// 1. tune BLOCK_M and BLOCK_N (BLOCK_N * K fit L2) +// 2. add prefetch for load A which is indexed access +// 3. abstract at::native::cpublas::brgemm with WoQ gemm (M = 1 & M != 1) +// + +template +inline void fill_stub(scalar_t* __restrict__ out, scalar_t val, int64_t size) { + using Vec = at::vec::Vectorized; + const Vec data_vec(val); + at::vec::map([data_vec](Vec out) { return out = data_vec; }, out, out, size); +} + +template +inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t size) { + using Vec = at::vec::Vectorized; + // no remainder + #pragma GCC unroll 4 + for (int64_t d = 0; d < size; d += Vec::size()) { + Vec data = Vec::loadu(input + d); + data.store(out + d); + } +} + +template +inline void copy_mul_stub(scalar_t* __restrict__ out, const float* __restrict__ input, float weight, int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + const fVec weight_vec = fVec(weight); + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + fVec data0 = fVec::loadu(input + d) * weight_vec; + fVec data1 = fVec::loadu(input + d + fVec::size()) * weight_vec; + bVec out_vec = convert_from_float_ext(data0, data1); + out_vec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d] * weight); + } +} + +// acc from [topk, K] to [K] +template +inline void sum_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t topk, int64_t K) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + if (topk == 1) { + // do copy for topk = 1 + copy_stub(out, input, K); + } else { + // do sum for topk != 1 + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= K - kVecSize; d += kVecSize) { + fVec sum_fvec0 = fVec(0.f); + fVec sum_fvec1 = fVec(0.f); + for (int t = 0; t < topk; ++t) { + bVec x_bvec = bVec::loadu(input + t * K + d); + fVec x_fvec0, x_fvec1; + std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec); + + sum_fvec0 += x_fvec0; + sum_fvec1 += x_fvec1; + } + bVec out_bvec = convert_from_float_ext(sum_fvec0, sum_fvec1); + out_bvec.store(out + d); + } + for (; d < K; ++d) { + float sum_val = 0.f; + for (int t = 0; t < topk; ++t) { + sum_val += static_cast(input[t * K + d]); + } + out[d] = static_cast(sum_val); + } + } +} + +// out = input + input2 * scale +template +inline void add_mul_stub(scalar_t* __restrict__ out, const float* __restrict__ input, + const scalar_t* __restrict__ input2, float scale, int64_t size) { + + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + const fVec s_vec = fVec(scale); + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + fVec x0 = fVec::loadu(input + d); + fVec x1 = fVec::loadu(input + d + fVec::size()); + + bVec y_bvec = bVec::loadu(input2 + d); + fVec y0, y1; + std::tie(y0, y1) = at::vec::convert_to_float(y_bvec); + + x0 = x0 + y0 * s_vec; + x1 = x1 + y1 * s_vec; + bVec out_vec = convert_from_float_ext(x0, x1); + out_vec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d] + float(input2[d]) * scale); + } +} + +template +int moe_align_block_size( + int32_t* __restrict__ sorted_ids, + int32_t* __restrict__ expert_ids, + int32_t* __restrict__ topk_ids, + int32_t* __restrict__ total_cnts, + int32_t* __restrict__ cumsums, + int32_t* __restrict__ offsets, + int num_experts, + int numel, + int num_threads) { + + #define T_INDEX(tt) total_cnts + (tt) * num_experts + + // accumulate count of expert ids locally + at::parallel_for(0, numel, 0, [&](int begin, int end) { + int tid = at::get_thread_num(); + int32_t* __restrict__ local_cnts = T_INDEX(tid + 1); + + for (int i = begin; i < end; ++i) { + local_cnts[topk_ids[i]]++; + } + }); + + using iVec = at::vec::Vectorized; + for (int t = 0; t < num_threads; ++t) { + at::vec::map2( + [](iVec x, iVec y) { return x + y; }, + T_INDEX(t + 1), T_INDEX(t + 1), T_INDEX(t), num_experts); + } + + // the last row holds sums of each experts + int32_t* total_cnts_t_1 = T_INDEX(num_threads); + + cumsums[0] = 0; + for (int e = 0; e < num_experts; ++e) { + // accumulate `num_tokens_post_pad`, also as the expert offset + cumsums[e + 1] = cumsums[e] + div_up(total_cnts_t_1[e], BLOCK_M) * BLOCK_M; + + for (int k = cumsums[e]; k < cumsums[e + 1]; k += BLOCK_M) { + expert_ids[k / BLOCK_M] = e; + } + } + int num_tokens_post_pad = cumsums[num_experts]; + + at::parallel_for(0, numel, 0, [&](int begin, int end) { + int tid = at::get_thread_num(); + // thread tid offsets in `total_cnts` + int32_t* __restrict__ offsets = T_INDEX(tid); + + for (int i = begin; i < end; ++i) { + int32_t expert_id = topk_ids[i]; + int32_t b_offset = cumsums[expert_id]; + int32_t t_offset = offsets[expert_id]; + sorted_ids[b_offset + t_offset] = i; + offsets[expert_id]++; + } + }); + + // debug: the offset for thread t_1 should be identical to t_2 + int32_t* total_cnts_t_2 = T_INDEX(num_threads - 1); + for (int e = 0; e < num_experts; ++e) { + TORCH_CHECK(total_cnts_t_1[e] == total_cnts_t_2[e]); + } + + // padding value for sorted_ids: numel + auto sorted_id_size = [=](const int32_t* sorted_ids_ptr) { + for (int d = 0; d < BLOCK_M; ++d) { + if (sorted_ids_ptr[d] == numel) { return d; } + } + return BLOCK_M; + }; + + // offsets holds starting offset for each valida M blocks + // shape : [num_token_blocks + 1] + offsets[0] = 0; + const int num_token_blocks = num_tokens_post_pad / BLOCK_M; + at::parallel_for(0, num_token_blocks, GRAIN_SIZE / BLOCK_M, [&](int begin, int end) { + for (int mb = begin; mb < end; ++mb) { + offsets[mb + 1] = sorted_id_size(sorted_ids + mb * BLOCK_M); + } + }); + // TODO: do we need to vecterize this ? + for (int mb = 0; mb < num_token_blocks; ++mb) { + offsets[mb + 1] += offsets[mb]; + } + // debug: the last value of offsets should be `numel` + TORCH_CHECK(offsets[num_token_blocks] == numel); + + return num_tokens_post_pad; +} + +// silu : shape leading dimension +// input0 [m_size, BLOCK_N] BLOCK_N +// input1 [m_size, BLOCK_N] BLOCK_N +// output [M * topk, N] N +template +inline void silu_and_mul( + scalar_t* __restrict__ output, + const float* __restrict__ input0, // x: x0, x1 + const float* __restrict__ input1, // y: y0, y1 + int64_t m_size, + int64_t N) { + + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + + const fVec one = fVec(1.f); + + // no remainder + for (int64_t m = 0; m < m_size; ++m) { + scalar_t* __restrict__ out = output + m * N; + const float* __restrict__ x = input0 + m * BLOCK_N; + const float* __restrict__ y = input1 + m * BLOCK_N; + + for (int64_t d = 0; d < BLOCK_N; d += bVec::size()) { + fVec x0 = fVec::loadu(x + d); + fVec x1 = fVec::loadu(x + d + fVec::size()); + fVec y0 = fVec::loadu(y + d); + fVec y1 = fVec::loadu(y + d + fVec::size()); + // silu + x0 = x0 / (one + x0.neg().exp_u20()); + x1 = x1 / (one + x1.neg().exp_u20()); + // mul + x0 = x0 * y0; + x1 = x1 * y1; + // convert + bVec out_vec = convert_from_float_ext(x0, x1); + out_vec.store(out + d); + } + } +} + +template +struct tinygemm_kernel_nn2 { + static inline void apply( + const scalar_t* __restrict__ A, const scalar_t* __restrict__ B0, const scalar_t* __restrict__ B1, + scalar_t* __restrict__ C, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!"); + } +}; + +#if defined(CPU_CAPABILITY_AVX512) +template +struct tinygemm_kernel_nn2 { + static inline void apply( + const at::BFloat16* __restrict__ A, const at::BFloat16* __restrict__ B0, const at::BFloat16* __restrict__ B1, + at::BFloat16* __restrict__ C, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + + constexpr int ROWS = BLOCK_M; + constexpr int COLS = BLOCK_N / 16; + + static_assert(COLS % 2 == 0); + + // prefetch distance + constexpr int PREFETCH_SIZE_K = 0; + + __m512bh va; + __m512bh vb0[COLS]; + __m512bh vb1[COLS]; + __m512 vc0[ROWS * COLS]; + __m512 vc1[ROWS * COLS]; + + auto loadc = [&](auto i) { + vc0[i] = _mm512_set1_ps(0.f); + vc1[i] = _mm512_set1_ps(0.f); + }; + Unroll{}(loadc); + + const int64_t K2 = K >> 1; + const int64_t lda2 = lda >> 1; + const int64_t ldb2 = ldb; // ldb * 2 >> 1; + const float* a_ptr = reinterpret_cast(A); + const float* b0_ptr = reinterpret_cast(B0); + const float* b1_ptr = reinterpret_cast(B1); + + auto compute = [&](auto i, int64_t k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k])); + } + if constexpr (row == 0) { + vb0[col] = (__m512bh)(_mm512_loadu_si512(b0_ptr + k * ldb2 + col * 16)); + vb1[col] = (__m512bh)(_mm512_loadu_si512(b1_ptr + k * ldb2 + col * 16)); + if constexpr (PREFETCH_SIZE_K > 0) { + _mm_prefetch(b0_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0); + _mm_prefetch(b1_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0); + } + } + vc0[i] = _mm512_dpbf16_ps(vc0[i], va, vb0[col]); + vc1[i] = _mm512_dpbf16_ps(vc1[i], va, vb1[col]); + }; + for (int64_t k = 0; k < K2; ++k) { + Unroll{}(compute, k); + } + + using Vec = at::vec::Vectorized; + const Vec one = Vec(1.f); + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + // for COLS = 2, 4 use 512bit store + if constexpr (col % 2 == 0) { + Vec x0 = vc0[row * COLS + col + 0]; + Vec x1 = vc0[row * COLS + col + 1]; + Vec y0 = vc1[row * COLS + col + 0]; + Vec y1 = vc1[row * COLS + col + 1]; + // silu + x0 = x0 / (one + x0.neg().exp_u20()); + x1 = x1 / (one + x1.neg().exp_u20()); + // mul + x0 = x0 * y0; + x1 = x1 * y1; + + _mm512_storeu_si512( + reinterpret_cast<__m512i*>((C + row * ldc + col * 16)), + (__m512i)(_mm512_cvtne2ps_pbh(__m512(x1), __m512(x0)))); + } + }; + Unroll{}(storec); + } +}; +#endif + +#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \ + tinygemm_kernel_nn2::apply( \ + A + mb_start * lda, B0 + nb_start * 2, B1 + nb_start * 2, \ + C + mb_start * ldc + nb_start, K, lda, ldb, ldc); + +template +void tinygemm_kernel( + const scalar_t* __restrict__ A, + const scalar_t* __restrict__ B0, + const scalar_t* __restrict__ B1, + scalar_t* __restrict__ C, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc) { + + // pattern: 1-(2+2)-(8+8) + constexpr int64_t BLOCK_M = 4; + constexpr int64_t BLOCK_N = 32; + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + for (int mb = 0; mb < MB; ++mb) { + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(BLOCK_M, M - mb_start); + for (int64_t nb = 0; nb < NB; ++nb) { + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(BLOCK_N, N - nb_start); + + switch(mb_size << 4 | nb_size >> 4) { + // mb_size = 1 + case 0x12: LAUNCH_TINYGEMM_KERNEL_NN(1, 32); break; + // mb_size = 2 + case 0x22: LAUNCH_TINYGEMM_KERNEL_NN(2, 32); break; + // mb_size = 3 + case 0x32: LAUNCH_TINYGEMM_KERNEL_NN(3, 32); break; + // mb_size = 4 + case 0x42: LAUNCH_TINYGEMM_KERNEL_NN(4, 32); break; + default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size"); + } + } + } +} + +template +struct tinygemm_kernel_nn { + static inline void apply( + const scalar_t* __restrict__ A, const scalar_t* __restrict__ B, float* __restrict__ C, + int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!"); + } +}; + +#if defined(CPU_CAPABILITY_AVX512) +template +struct tinygemm_kernel_nn { + static inline void apply( + const at::BFloat16* __restrict__ A, const at::BFloat16* __restrict__ B, float* __restrict__ C, + int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + + constexpr int ROWS = BLOCK_M; + constexpr int COLS = BLOCK_N / 16; + + static_assert(COLS % 2 == 0); + + // prefetch distance + constexpr int PREFETCH_SIZE_K = 0; + + __m512bh va; + __m512bh vb[COLS]; + __m512 vc[ROWS * COLS]; + + auto loadc = [&](auto i) { + vc[i] = _mm512_set1_ps(0.f); + }; + Unroll{}(loadc); + + const int64_t K2 = K >> 1; + const int64_t lda2 = lda >> 1; + const int64_t ldb2 = ldb; // ldb * 2 >> 1; + const float* a_ptr = reinterpret_cast(A); + const float* b_ptr = reinterpret_cast(B); + + auto compute = [&](auto i, int64_t k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k])); + } + if constexpr (row == 0) { + vb[col] = (__m512bh)(_mm512_loadu_si512(b_ptr + k * ldb2 + col * 16)); + if constexpr (PREFETCH_SIZE_K > 0) { + _mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0); + } + } + vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]); + }; + for (int64_t k = 0; k < K2; ++k) { + Unroll{}(compute, k); + } + + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + _mm512_storeu_ps(reinterpret_cast<__m512*>(C + row * ldc + col * 16), vc[i]); + + }; + Unroll{}(storec); + } +}; +#endif + +#define LAUNCH_TINYGEMM_KERNEL_NN2(MB_SIZE, NB_SIZE) \ + tinygemm_kernel_nn::apply( \ + A + mb_start * lda, B + nb_start * 2, C + mb_start * ldc + nb_start, \ + K, lda, ldb, ldc); + +template +void tinygemm_kernel( + const scalar_t* __restrict__ A, + const scalar_t* __restrict__ B, + float* __restrict__ C, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc) { + + // pattern: 1-2-8 + constexpr int64_t BLOCK_M = 4; + constexpr int64_t BLOCK_N = 32; + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + for (int mb = 0; mb < MB; ++mb) { + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(BLOCK_M, M - mb_start); + for (int64_t nb = 0; nb < NB; ++nb) { + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(BLOCK_N, N - nb_start); + + switch(mb_size << 4 | nb_size >> 4) { + // mb_size = 1 + case 0x12: LAUNCH_TINYGEMM_KERNEL_NN2(1, 32); break; + // mb_size = 2 + case 0x22: LAUNCH_TINYGEMM_KERNEL_NN2(2, 32); break; + // mb_size = 3 + case 0x32: LAUNCH_TINYGEMM_KERNEL_NN2(3, 32); break; + // mb_size = 4 + case 0x42: LAUNCH_TINYGEMM_KERNEL_NN2(4, 32); break; + default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size"); + } + } + } +} + +template +void fused_experts_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic1, + scalar_t* __restrict__ ic2, + scalar_t* __restrict__ A_tmp, + float* __restrict__ C_tmp, + const scalar_t* __restrict__ input, + const scalar_t* __restrict__ packed_w1, + const scalar_t* __restrict__ packed_w2, + const float* __restrict__ topk_weights, + const int32_t* __restrict__ sorted_ids, + const int32_t* __restrict__ expert_ids, + const int32_t* __restrict__ offsets, + int64_t M, + int64_t N, + int64_t K, + int64_t E, + int64_t topk, + int64_t num_tokens_post_pad) { + + // handle 2 tiles per block + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + + // stage 1: intermediate_cache1 = silu(hidden_states @ w1) + const int64_t MB = div_up(num_tokens_post_pad, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + + // strides for w1: [E, 2N, K] + TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not multiples of ", BLOCK_N); + + const int64_t stride_e = 2 * N * K; + const int64_t stride_n = K; + + // here we only parallel on half of 2N to fuse silu_and_mul with gemm + at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + // get local pointers + int tid = at::get_thread_num(); + scalar_t* __restrict__ A = A_tmp + tid * BLOCK_M * K; + float* __restrict__ C0 = C_tmp + tid * 2 * BLOCK_M * BLOCK_N; + float* __restrict__ C1 = C0 + BLOCK_M * BLOCK_N; + + bool is_brgemm_used = false; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB; + int64_t nb = i % NB; + + // nb0 from top half and nb1 from bottom half + int64_t nb0 = nb, nb1 = nb + NB; + int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N); + + // B shape [K, n_size] in vnni format + int32_t expert_id = expert_ids[mb]; + const scalar_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb0 * BLOCK_N * stride_n; + const scalar_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb1 * BLOCK_N * stride_n; + + // 1.a load A + const int32_t* A_ids = sorted_ids + mb * BLOCK_M; + int64_t m_size = offsets[mb + 1] - offsets[mb]; + + const bool use_brgemm = can_use_brgemm(m_size); + is_brgemm_used = is_brgemm_used || use_brgemm; + + for (int64_t m = 0; m < m_size; ++m) { + int32_t index = A_ids[m] / topk; + copy_stub(A + m * K, input + index * K, K); + } + + if (use_brgemm) { + // 1.b gemm: C0 = A @ B0 + at::native::cpublas::brgemm( + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ BLOCK_N, + /* add_C */ false, + /* A */ A, + /* B */ B0, + /* C */ C0); + + // 1.c gemm: C1 = A @ B1 + at::native::cpublas::brgemm( + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ BLOCK_N, + /* add_C */ false, + /* A */ A, + /* B */ B1, + /* C */ C1); + + // 1.d silu and mul + const int64_t offset = offsets[mb]; + silu_and_mul( + ic1 + offset * N + nb * BLOCK_N, + C0, + C1, + m_size, + N); + } else { + // fused 1.bcd: silu_and_mul(A @ B0, A @ B1) + const int64_t offset = offsets[mb]; + tinygemm_kernel( + /* A */ A, + /* B0 */ B0, + /* B1 */ B1, + /* C */ ic1 + offset * N + nb * BLOCK_N, + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ N); + } + } + + if (is_brgemm_used) { + at::native::cpublas::brgemm_release(); + } + }); + + // stage 2: intermediate_cache2 = intermediate_cache1 @ w2 + // w2 : [E, K, N] as [E, OC, IC] + const int64_t OC = K; // rename K as OC + const int64_t IC = N; // rename N as IC + const int64_t MB2 = MB; + const int64_t NB2 = div_up(OC, BLOCK_N); + const int64_t stride_e2 = OC * IC; + const int64_t stride_oc = IC; + + // parallel on [MB2, NB2] + at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { + // get local pointers + int tid = at::get_thread_num(); + // we won't be using C1 for gemm2 + float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N; + + bool is_brgemm_used = false; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB2; + int64_t nb = i % NB2; + + int64_t m_size = offsets[mb + 1] - offsets[mb]; + int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); + + const bool use_brgemm = can_use_brgemm(m_size); + is_brgemm_used = is_brgemm_used || use_brgemm; + + // A ptr from ic1 of [M * topk, N] in sorted order + // so as to avoid copy A to tmp buffer again + const scalar_t* __restrict__ A = ic1 + offsets[mb] * N; + const int32_t* A_ids = sorted_ids + mb * BLOCK_M; + + // B shape [IC, n_size] in vnni format + int32_t expert_id = expert_ids[mb]; + const scalar_t* __restrict__ B = packed_w2 + expert_id * stride_e2 + nb * BLOCK_N * stride_oc; + + // 2.a gemm: C = A @ B + if (use_brgemm) { + at::native::cpublas::brgemm( + /* M */ m_size, + /* N */ n_size, + /* K */ IC, + /* lda */ IC, + /* ldb */ n_size, + /* ldc */ BLOCK_N, + /* add_C */ false, + /* A */ A, + /* B */ B, + /* C */ C); + } else { + tinygemm_kernel( + /* A */ A, + /* B */ B, + /* C */ C, + /* M */ m_size, + /* N */ n_size, + /* K */ IC, + /* lda */ IC, + /* ldb */ n_size, + /* ldc */ BLOCK_N); + } + + // 2.b copy from C to ic2 in original order + // and also mul topk_weights in float32 + for (int64_t m = 0; m < m_size; ++m) { + int32_t index = A_ids[m]; + float weight = topk_weights[index]; + copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size); + } + } + + if (is_brgemm_used) { + at::native::cpublas::brgemm_release(); + } + }); + + // stage 3: out = intermediate_cache2.sum(dim=1) + // from [M, topk, K] to [M, K] + at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + sum_stub(output + m * K, ic2 + m * topk * K, topk, K); + } + }); +} + +template +void shared_expert_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic1, + float* __restrict__ C_tmp, + scalar_t* __restrict__ input, + const scalar_t* __restrict__ packed_w1, + const scalar_t* __restrict__ packed_w2, + const scalar_t* __restrict__ fused_experts_out, + float routed_scaling_factor, + int64_t M, + int64_t N, + int64_t K) { + + // handle 2 tiles per block + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + + // stage 1: intermediate_cache1 = silu(hidden_states @ w1) + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + + TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not multiples of ", BLOCK_N); + const int64_t stride_n = K; + + // here we only parallel on half of 2N to fuse silu_and_mul with gemm + at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + // get local pointers + int tid = at::get_thread_num(); + float* __restrict__ C0 = C_tmp + tid * 2 * BLOCK_M * BLOCK_N; + float* __restrict__ C1 = C0 + BLOCK_M * BLOCK_N; + + bool is_brgemm_used = false; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB; + int64_t nb = i % NB; + + // nb0 from top half and nb1 from bottom half + int64_t nb0 = nb, nb1 = nb + NB; + int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N); + int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); + + //int64_t mb_start = mb * BLOCK_M; + //int64_t mb_size = std::min(M - mb_start, BLOCK_M); + + // A shape [m_size, K] + const scalar_t* A = input + mb * BLOCK_M * K; + + // B shape [K, n_size] in vnni format + const scalar_t* __restrict__ B0 = packed_w1 + nb0 * BLOCK_N * stride_n; + const scalar_t* __restrict__ B1 = packed_w1 + nb1 * BLOCK_N * stride_n; + + const bool use_brgemm = can_use_brgemm(m_size); + is_brgemm_used = is_brgemm_used || use_brgemm; + + if (use_brgemm) { + // 1.b gemm: C0 = A @ B0 + at::native::cpublas::brgemm( + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ BLOCK_N, + /* add_C */ false, + /* A */ A, + /* B */ B0, + /* C */ C0); + + // 1.c gemm: C1 = A @ B1 + at::native::cpublas::brgemm( + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ BLOCK_N, + /* add_C */ false, + /* A */ A, + /* B */ B1, + /* C */ C1); + + // 1.d silu and mul + silu_and_mul( + ic1 + mb * BLOCK_M * N + nb * BLOCK_N, + C0, + C1, + m_size, + N); + } else { + // fused 1.bcd: silu_and_mul(A @ B0, A @ B1) + tinygemm_kernel( + /* A */ A, + /* B0 */ B0, + /* B1 */ B1, + /* C */ ic1 + mb * BLOCK_M * N + nb * BLOCK_N, + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ N); + } + } + + if (is_brgemm_used) { + at::native::cpublas::brgemm_release(); + } + }); + + // stage 2: output = intermediate_cache1 @ w2 + // w2 : [K, N] as [OC, IC] + const int64_t OC = K; // rename K as OC + const int64_t IC = N; // rename N as IC + const int64_t MB2 = MB; + const int64_t NB2 = div_up(OC, BLOCK_N); + const int64_t stride_oc = IC; + + // parallel on [MB2, NB2] + at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { + // get local pointers + int tid = at::get_thread_num(); + // we won't be using C1 for gemm2 + float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N; + + bool is_brgemm_used = false; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB2; + int64_t nb = i % NB2; + + int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); + int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); + + const bool use_brgemm = can_use_brgemm(m_size); + is_brgemm_used = is_brgemm_used || use_brgemm; + + // A shape [m_size, IC] + const scalar_t* __restrict__ A = ic1 + mb * BLOCK_M * N; + + // B shape [IC, n_size] in vnni format + const scalar_t* __restrict__ B = packed_w2 + nb * BLOCK_N * stride_oc; + + // 2.a gemm: C = A @ B + if (use_brgemm) { + at::native::cpublas::brgemm( + /* M */ m_size, + /* N */ n_size, + /* K */ IC, + /* lda */ IC, + /* ldb */ n_size, + /* ldc */ BLOCK_N, + /* add_C */ false, + /* A */ A, + /* B */ B, + /* C */ C); + } else { + tinygemm_kernel( + /* A */ A, + /* B */ B, + /* C */ C, + /* M */ m_size, + /* N */ n_size, + /* K */ IC, + /* lda */ IC, + /* ldb */ n_size, + /* ldc */ BLOCK_N); + } + + // 2.b copy from C to output and add fused_experts_out + scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N; + const scalar_t* __restrict__ fused_out = fused_experts_out + mb * BLOCK_M * K + nb * BLOCK_N; + for (int64_t m = 0; m < m_size; ++m) { + add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size); + } + } + + if (is_brgemm_used) { + at::native::cpublas::brgemm_release(); + } + }); +} + +} // anonymous namespace + +// common checks +static inline void check_moe_scales( + bool use_int8_w8a8, + bool use_fp8_w8a16, + const std::optional& w1_scale, + const std::optional& w2_scale, + const std::optional> block_size, + const std::optional& a1_scale, + const std::optional& a2_scale) { + if (use_int8_w8a8) { + TORCH_CHECK(w1_scale.has_value(), "missing w1_scale for int8 w8a8."); + TORCH_CHECK(w2_scale.has_value(), "missing w2_scale for int8 w8a8."); + TORCH_CHECK(!a1_scale.has_value(), "static quantization for activation not supported."); + TORCH_CHECK(!a2_scale.has_value(), "static quantization for activation not supported."); + } + if (use_fp8_w8a16) { + TORCH_CHECK(w1_scale.has_value(), "missing w1_scale for fp8 w8a16."); + TORCH_CHECK(w2_scale.has_value(), "missing w2_scale for fp8 w8a16."); + TORCH_CHECK(block_size.has_value(), "missing block_size for fp8 w8a16."); + TORCH_CHECK(block_size.value().size() == 2, "expect block_size.size() to be 2."); + } +} + +#define CHECK_MOE_SCALES_FP8(DIM0, DIM1) \ + auto w1s = w1_scale.value(); \ + auto w2s = w2_scale.value(); \ + auto block_size_val = block_size.value(); \ + int64_t block_size_N = block_size_val[0]; \ + int64_t block_size_K = block_size_val[1]; \ + TORCH_CHECK(w1s.size(DIM0) == 2 * N / block_size_N); \ + TORCH_CHECK(w1s.size(DIM1) == K / block_size_K); \ + TORCH_CHECK(w2s.size(DIM0) == K / block_size_N); \ + TORCH_CHECK(w2s.size(DIM1) == N / block_size_K) + +// hidden_states: [M, K] +// w1: [E, 2N, K] +// w2: [E, K, N] +// topk_weights: [M, topk] +// topk_ids: [M, topk] (int32_t) +// +at::Tensor fused_experts_cpu( + at::Tensor& hidden_states, + at::Tensor& w1, + at::Tensor& w2, + at::Tensor& topk_weights, + at::Tensor& topk_ids, + bool inplace, + bool use_int8_w8a8, + bool use_fp8_w8a16, + const std::optional& w1_scale, + const std::optional& w2_scale, + const std::optional> block_size, + const std::optional& a1_scale, + const std::optional& a2_scale, + bool is_vnni) { + RECORD_FUNCTION("sgl-kernel::fused_experts_cpu", std::vector({hidden_states, w1, w2, topk_weights, topk_ids})); + + auto packed_w1 = is_vnni ? w1 : convert_weight_packed(w1); + auto packed_w2 = is_vnni ? w2 : convert_weight_packed(w2); + + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + + const auto st = hidden_states.scalar_type(); + CHECK_INPUT(hidden_states); + CHECK_INPUT(w1); + CHECK_INPUT(w2); + CHECK_EQ(topk_weights.sizes(), topk_ids.sizes()); + CHECK_DIM(2, hidden_states); + CHECK_DIM(3, w1); + CHECK_DIM(3, w2); + CHECK_DIM(2, topk_weights); + CHECK_DIM(2, topk_ids); + + CHECK_EQ(topk_ids.scalar_type(), at::kInt); + CHECK_EQ(topk_weights.scalar_type(), at::kFloat); + + int64_t M = hidden_states.size(0); + int64_t K = hidden_states.size(1); + int64_t N = w1.size(1) / 2; + int64_t E = w1.size(0); + int64_t topk = topk_weights.size(1); + + // we use int32_t compensation for int8 w8a8 + int64_t packed_K = get_row_size(K, use_int8_w8a8); + int64_t packed_N = get_row_size(N, use_int8_w8a8); + + // check weight shapes + CHECK_EQ(w2.size(0), E); + CHECK_EQ(w2.size(1), K); + CHECK_EQ(packed_w1.size(2), packed_K); + CHECK_EQ(packed_w2.size(2), packed_N); + + // check scales + check_moe_scales(use_int8_w8a8, use_fp8_w8a16, w1_scale, w2_scale, block_size, a1_scale, a2_scale); + + at::Tensor out_hidden_states = inplace ? hidden_states : at::empty_like(hidden_states); + + // NB: worst case is each expert holds a block with remainder of 1 + // 1. sorted_ids : [M * topk + E * (BLOCK_M - 1)] + // 2. expert_ids : [max_num_blocks] + // 3. total_cnts : [T + 1, E] + // 4. cumsums : [E + 1] + // 5. offsets : [max_num_blocks + 1] + // + int num_threads = at::get_num_threads(); + int64_t max_num_tokens_padded = M * topk + E * (BLOCK_M - 1); + int64_t max_num_blocks = div_up(max_num_tokens_padded, BLOCK_M); + auto buffer = at::empty( + {max_num_tokens_padded + max_num_blocks + (num_threads + 1) * E + (E + 1) + (max_num_blocks + 1)}, + topk_ids.options()); + + int32_t* __restrict__ sorted_ids = buffer.data_ptr(); + int32_t* __restrict__ expert_ids = sorted_ids + max_num_tokens_padded; + int32_t* __restrict__ total_cnts = expert_ids + max_num_blocks; + int32_t* __restrict__ cumsums = total_cnts + (num_threads + 1) * E; + int32_t* __restrict__ offsets = cumsums + (E + 1); + + // init sorted_ids with `numel` as the padding number + // init expert_ids with `num_experts` + int64_t numel = M * topk; + at::parallel_for(0, max_num_blocks, GRAIN_SIZE / BLOCK_M, [&](int64_t begin, int64_t end) { + int64_t m_start = begin * BLOCK_M; + int64_t m_size = std::min((end - begin) * BLOCK_M, max_num_tokens_padded - m_start); + fill_stub(sorted_ids + m_start, (int32_t)numel, m_size); + fill_stub(expert_ids + begin, (int32_t)E, end - begin); + }); + // zero total_cnts and cumsums + at::parallel_for(0, (num_threads + 1) * E + (E + 1), GRAIN_SIZE, [&](int64_t begin, int64_t end) { + fill_stub(total_cnts + begin, 0, end - begin); + }); + + // align experts index + int64_t num_tokens_post_pad = moe_align_block_size( + sorted_ids, expert_ids, topk_ids.data_ptr(), total_cnts, cumsums, offsets, E, numel, num_threads); + + // unlike triton kernel, we fuse silu with gemm1 so only need 2 intermediate_caches: + // 1. intermediate_cache1 : [M * topk, N] + // 2. intermediate_cache2 : [M * topk, K] + // 3. A_tmp : [T, BLOCK_M * K] + // 4. C_tmp : [T, 2 * BLOCK_M * BLOCK_N] + // + // for int8 w8a8: + // 5. Aq_tmp : [M, K] or [M * topk, N] + // 6. As_tmp : [M * topk] + // + // for fp8 w8a16: + // 7. intermediate_cache0 : [M * topk, 2N] + // 8. B_tmp : [T, BLOCK_N, std::max(K, N)] + // + int64_t buffer_size_nbytes = M * topk * N * 2 + M * topk * K * 2 + + num_threads * BLOCK_M * K * (use_int8_w8a8 ? 1 : 2) + + num_threads * 2 * BLOCK_M * BLOCK_N * sizeof(float); + + if (use_int8_w8a8) { + buffer_size_nbytes += std::max(M * K, M * topk * N) + M * topk * sizeof(float); + } + if (use_fp8_w8a16) { + buffer_size_nbytes += M * topk * 2 * N * 2 + num_threads * BLOCK_N * std::max(K, N) * 2; + } + + auto buffer2 = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar)); + + AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "fused_experts_kernel_impl", [&] { + scalar_t* __restrict__ intermediate_cache1 = (scalar_t*)((void*)(buffer2.data_ptr())); + scalar_t* __restrict__ intermediate_cache2 = intermediate_cache1 + M * topk * N; + + if (use_int8_w8a8) { + uint8_t* __restrict__ A_tmp = (uint8_t*)((void*)(intermediate_cache2 + M * topk * K)); + float* __restrict__ C_tmp = (float*)((void*)(A_tmp + num_threads * BLOCK_M * K)); + uint8_t* __restrict__ Aq_tmp = (uint8_t*)((void*)(C_tmp + num_threads * 2 * BLOCK_M * BLOCK_N)); + float* __restrict__ As_tmp = (float*)((void*)(Aq_tmp + std::max(M * K, M * topk * N))); + + auto w1s = w1_scale.value(); + auto w2s = w2_scale.value(); + TORCH_CHECK(w1s.numel() == E * 2 * N); + TORCH_CHECK(w2s.numel() == E * K); + + fused_experts_int8_kernel_impl( + out_hidden_states.data_ptr(), + intermediate_cache1, + intermediate_cache2, + A_tmp, + C_tmp, + Aq_tmp, + As_tmp, + hidden_states.data_ptr(), + packed_w1.data_ptr(), + packed_w2.data_ptr(), + w1s.data_ptr(), + w2s.data_ptr(), + topk_weights.data_ptr(), + sorted_ids, + expert_ids, + offsets, + M, + N, + K, + E, + topk, + num_tokens_post_pad); + } else if (use_fp8_w8a16) { + // here we just ignore C_tmp as it is not used + scalar_t* __restrict__ A_tmp = (scalar_t*)((void*)(intermediate_cache2 + M * topk * K)); + float* __restrict__ C_tmp = (float*)((void*)(A_tmp + num_threads * BLOCK_M * K)); + scalar_t* __restrict__ intermediate_cache0 = (scalar_t*)((void*)(C_tmp + num_threads * 2 * BLOCK_M * BLOCK_N)); + scalar_t* __restrict__ B_tmp = (scalar_t*)((void*)(intermediate_cache0 + M * topk * 2 * N)); + + CHECK_MOE_SCALES_FP8(1, 2); + fused_experts_fp8_kernel_impl( + out_hidden_states.data_ptr(), + intermediate_cache0, + intermediate_cache1, + intermediate_cache2, + A_tmp, + B_tmp, + C_tmp, + hidden_states.data_ptr(), + packed_w1.data_ptr(), + packed_w2.data_ptr(), + w1s.data_ptr(), + w2s.data_ptr(), + block_size_N, + block_size_K, + topk_weights.data_ptr(), + sorted_ids, + expert_ids, + offsets, + M, + N, + K, + E, + topk, + num_tokens_post_pad); + } else { + scalar_t* __restrict__ A_tmp = intermediate_cache2 + M * topk * K; + float* __restrict__ C_tmp = (float*)((void*)(A_tmp + num_threads * BLOCK_M * K)); + + fused_experts_kernel_impl( + out_hidden_states.data_ptr(), + intermediate_cache1, + intermediate_cache2, + A_tmp, + C_tmp, + hidden_states.data_ptr(), + packed_w1.data_ptr(), + packed_w2.data_ptr(), + topk_weights.data_ptr(), + sorted_ids, + expert_ids, + offsets, + M, + N, + K, + E, + topk, + num_tokens_post_pad); + } + }); + return out_hidden_states; +} + +// shared expert kernel +// +// hidden_states: [M, K] +// w1: [2N, K] +// w2: [K, N] +// fused_experts_out +at::Tensor shared_expert_cpu( + at::Tensor& hidden_states, + at::Tensor& w1, + at::Tensor& w2, + at::Tensor& fused_experts_out, + double routed_scaling_factor, + bool inplace, + bool use_int8_w8a8, + bool use_fp8_w8a16, + std::optional& w1_scale, + std::optional& w2_scale, + std::optional> block_size, + std::optional& a1_scale, + std::optional& a2_scale, + bool is_vnni) { + RECORD_FUNCTION("sgl-kernel::shared_expert_cpu", std::vector({hidden_states, w1, w2})); + + auto packed_w1 = is_vnni ? w1 : convert_weight_packed(w1); + auto packed_w2 = is_vnni ? w2 : convert_weight_packed(w2); + + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + + const auto st = hidden_states.scalar_type(); + CHECK_INPUT(hidden_states); + CHECK_INPUT(fused_experts_out); + CHECK_INPUT(w1); + CHECK_INPUT(w2); + CHECK_DIM(2, hidden_states); + CHECK_DIM(2, w1); + CHECK_DIM(2, w2); + CHECK_EQ(hidden_states.sizes(), fused_experts_out.sizes()); + CHECK_EQ(hidden_states.scalar_type(), st); + + int64_t M = hidden_states.size(0); + int64_t K = hidden_states.size(1); + int64_t N = w1.size(0) / 2; + + // we use int32_t compensation for int8 w8a8 + int64_t packed_K = get_row_size(K, use_int8_w8a8); + int64_t packed_N = get_row_size(N, use_int8_w8a8); + + // check weight shapes + CHECK_EQ(w2.size(0), K); + CHECK_EQ(packed_w1.size(1), packed_K); + CHECK_EQ(packed_w2.size(1), packed_N); + + // check scales + check_moe_scales(use_int8_w8a8, use_fp8_w8a16, w1_scale, w2_scale, block_size, a1_scale, a2_scale); + + at::Tensor out_hidden_states = inplace ? hidden_states : at::empty_like(hidden_states); + + // unlike triton kernel, we fuse silu with gemm1 so only need 2 intermediate_caches: + // 1. intermediate_cache1 : [M, N] + // 2. C_tmp : [T, 2 * BLOCK_M * BLOCK_N] + // + // for int8 w8a8: + // 3. Aq_tmp : [M, K] or [M, N] + // 4. As_tmp : [M] + // + // for fp8 w8a16: + // 5. intermediate_cache0 : [M, 2N] + // 6. B_tmp: [T, BLOCK_M, max(K, N)] + // + int num_threads = at::get_num_threads(); + int64_t buffer_size_nbytes = M * N * 2 + num_threads * 2 * BLOCK_M * BLOCK_N * sizeof(float); + + if (use_int8_w8a8) { + buffer_size_nbytes += std::max(M * K, M * N) + M * sizeof(float); + } + if (use_fp8_w8a16) { + buffer_size_nbytes += M * 2 * N * 2 + num_threads * BLOCK_M * std::max(K, N) * 2; + } + + auto buffer = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar)); + AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "share_experts_kernel_impl", [&] { + scalar_t* __restrict__ intermediate_cache1 = (scalar_t*)((void*)(buffer.data_ptr())); + float* __restrict__ C_tmp = (float*)((void*)(intermediate_cache1 + M * N)); + + if (use_int8_w8a8) { + uint8_t* __restrict__ Aq_tmp = (uint8_t*)((void*)(C_tmp + num_threads * 2 * BLOCK_M * BLOCK_N)); + float* __restrict__ As_tmp = (float*)((void*)(Aq_tmp + std::max(M * K, M * N))); + + auto w1s = w1_scale.value(); + auto w2s = w2_scale.value(); + TORCH_CHECK(w1s.numel() == 2 * N); + TORCH_CHECK(w2s.numel() == K); + + shared_expert_int8_kernel_impl( + out_hidden_states.data_ptr(), + intermediate_cache1, + C_tmp, + Aq_tmp, + As_tmp, + hidden_states.data_ptr(), + packed_w1.data_ptr(), + packed_w2.data_ptr(), + w1s.data_ptr(), + w2s.data_ptr(), + fused_experts_out.data_ptr(), + routed_scaling_factor, + M, + N, + K); + } else if (use_fp8_w8a16) { + scalar_t* __restrict__ intermediate_cache0 = (scalar_t*)((void*)(C_tmp + num_threads * 2 * BLOCK_M * BLOCK_N)); + scalar_t* __restrict__ B_tmp = (scalar_t*)((void*)(intermediate_cache0 + M * 2 * N)); + + CHECK_MOE_SCALES_FP8(0, 1); + shared_expert_fp8_kernel_impl( + out_hidden_states.data_ptr(), + intermediate_cache0, + intermediate_cache1, + B_tmp, + C_tmp, + hidden_states.data_ptr(), + packed_w1.data_ptr(), + packed_w2.data_ptr(), + w1s.data_ptr(), + w2s.data_ptr(), + block_size_N, + block_size_K, + fused_experts_out.data_ptr(), + routed_scaling_factor, + M, + N, + K); + } else { + shared_expert_kernel_impl( + out_hidden_states.data_ptr(), + intermediate_cache1, + C_tmp, + hidden_states.data_ptr(), + packed_w1.data_ptr(), + packed_w2.data_ptr(), + fused_experts_out.data_ptr(), + routed_scaling_factor, + M, + N, + K); + } + }); + return out_hidden_states; +} diff --git a/csrc/cpu/sgl-kernels/moe_fp8.cpp b/csrc/cpu/sgl-kernels/moe_fp8.cpp new file mode 100644 index 0000000000000..84a6af267740a --- /dev/null +++ b/csrc/cpu/sgl-kernels/moe_fp8.cpp @@ -0,0 +1,502 @@ +// Adapted from +// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu + +#include "common.h" +#include "gemm.h" +#include "vec.h" + +// clang-format off + +namespace { + +template +inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t size) { + using Vec = at::vec::Vectorized; + // no remainder + #pragma GCC unroll 4 + for (int64_t d = 0; d < size; d += Vec::size()) { + Vec data = Vec::loadu(input + d); + data.store(out + d); + } +} + +template +inline void copy_mul_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, float weight, int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + const fVec weight_vec = fVec(weight); + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + bVec x = bVec::loadu(input + d); + fVec x0, x1; + std::tie(x0, x1) = at::vec::convert_to_float(x); + x0 = x0 * weight_vec; + x1 = x1 * weight_vec; + bVec out_vec = convert_from_float_ext(x0, x1); + out_vec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d] * weight); + } +} + +// acc from [topk, K] to [K] +template +inline void sum_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t topk, int64_t K) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + if (topk == 1) { + // do copy for topk = 1 + copy_stub(out, input, K); + } else { + // do sum for topk != 1 + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= K - kVecSize; d += kVecSize) { + fVec sum_fvec0 = fVec(0.f); + fVec sum_fvec1 = fVec(0.f); + for (int t = 0; t < topk; ++t) { + bVec x_bvec = bVec::loadu(input + t * K + d); + fVec x_fvec0, x_fvec1; + std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec); + + sum_fvec0 += x_fvec0; + sum_fvec1 += x_fvec1; + } + bVec out_bvec = convert_from_float_ext(sum_fvec0, sum_fvec1); + out_bvec.store(out + d); + } + for (; d < K; ++d) { + float sum_val = 0.f; + for (int t = 0; t < topk; ++t) { + sum_val += static_cast(input[t * K + d]); + } + out[d] = static_cast(sum_val); + } + } +} + +// out = input + input2 * scale +template +inline void add_mul_stub( + scalar_t* __restrict__ out, + const scalar_t* __restrict__ input, + const scalar_t* __restrict__ input2, + float scale, + int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + const fVec s_vec = fVec(scale); + + int64_t d; +#pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + bVec x_bvec = bVec::loadu(input + d); + fVec x0, x1; + std::tie(x0, x1) = at::vec::convert_to_float(x_bvec); + + bVec y_bvec = bVec::loadu(input2 + d); + fVec y0, y1; + std::tie(y0, y1) = at::vec::convert_to_float(y_bvec); + + x0 = x0 + y0 * s_vec; + x1 = x1 + y1 * s_vec; + bVec out_vec = convert_from_float_ext(x0, x1); + out_vec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d] + float(input2[d]) * scale); + } +} + +template +inline void silu_and_mul_stub( + scalar_t* __restrict__ out, + const scalar_t* __restrict__ input, + const scalar_t* __restrict__ input2, + int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + const fVec one = fVec(1.f); + + // no remainder +#pragma GCC unroll 4 + for (int64_t d = 0; d < size; d += bVec::size()) { + bVec x = bVec::loadu(input + d); + fVec x0, x1; + std::tie(x0, x1) = at::vec::convert_to_float(x); + bVec y = bVec::loadu(input2 + d); + fVec y0, y1; + std::tie(y0, y1) = at::vec::convert_to_float(y); + x0 = x0 / (one + x0.neg().exp_u20()); + x1 = x1 / (one + x1.neg().exp_u20()); + x0 = x0 * y0; + x1 = x1 * y1; + bVec out_vec = convert_from_float_ext(x0, x1); + out_vec.store(out + d); + } +} + +} // anonymous namespace + +template +void fused_experts_fp8_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic0, + scalar_t* __restrict__ ic1, + scalar_t* __restrict__ ic2, + scalar_t* __restrict__ A_tmp, + scalar_t* __restrict__ B_tmp, + float* __restrict__ C_tmp, + const scalar_t* __restrict__ input, + const at::Float8_e4m3fn* __restrict__ packed_w1, + const at::Float8_e4m3fn* __restrict__ packed_w2, + const float* __restrict__ w1s, + const float* __restrict__ w2s, + int64_t block_size_N, + int64_t block_size_K, + const float* __restrict__ topk_weights, + const int32_t* __restrict__ sorted_ids, + const int32_t* __restrict__ expert_ids, + const int32_t* __restrict__ offsets, + int64_t M, + int64_t N, + int64_t K, + int64_t E, + int64_t topk, + int64_t num_tokens_post_pad) { + + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + + // stage 1: intermediate_cache0 = hidden_states @ w1 + const int64_t MB = div_up(num_tokens_post_pad, BLOCK_M); + const int64_t NB = div_up(2 * N, BLOCK_N); + int64_t scale_size_N = div_up(2 * N, block_size_N); + int64_t scale_size_K = div_up(K, block_size_K); + int64_t blocks_n_per_group = block_size_N / BLOCK_N; + + const int64_t stride_e = 2 * N * K; + const int64_t stride_n = K; + + // here we only parallel on half of 2N to fuse silu_and_mul with gemm + at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + // get local pointers + int tid = at::get_thread_num(); + scalar_t* __restrict__ A = A_tmp + tid * BLOCK_M * K; + + bool is_brgemm_used = false; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB; + int64_t nb = i % NB; + + int64_t n_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N); + + // B shape [K, n_size] in vnni format + int32_t expert_id = expert_ids[mb]; + const at::Float8_e4m3fn* __restrict__ B = packed_w1 + expert_id * stride_e + nb * BLOCK_N * stride_n; + const float* __restrict__ Bs = w1s + expert_id * scale_size_N * scale_size_K + (nb / blocks_n_per_group) * scale_size_K; + + // 1.a load A + const int32_t* A_ids = sorted_ids + mb * BLOCK_M; + int64_t m_size = offsets[mb + 1] - offsets[mb]; + + const bool use_brgemm = can_use_brgemm(m_size); + is_brgemm_used = is_brgemm_used || use_brgemm; + + for (int64_t m = 0; m < m_size; ++m) { + int32_t index = A_ids[m] / topk; + copy_stub(A + m * K, input + index * K, K); + } + + const int64_t offset = offsets[mb]; + tinygemm_kernel( + /* A */ A, + /* B */ B, + /* C */ ic0 + offset * 2 * N + nb * BLOCK_N, + /* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N), + /* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N, + /* scale */ Bs, + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ 2 * N, + /* brg */ use_brgemm, + /* block_size_K */ block_size_K); + } + + if (is_brgemm_used) { + at::native::cpublas::brgemm_release(); + } + }); + + // stage 1.5: intermediate_cache1 = silu(intermediate_cache0) + at::parallel_for(0, M * topk, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + silu_and_mul_stub( + ic1 + m * N, + ic0 + m * 2 * N, + ic0 + m * 2 * N + N, + N); + } + }); + + // stage 2: intermediate_cache2 = intermediate_cache1 @ w2 + // w2 : [E, K, N] as [E, OC, IC] + const int64_t OC = K; // rename K as OC + const int64_t IC = N; // rename N as IC + const int64_t MB2 = MB; + const int64_t NB2 = div_up(OC, BLOCK_N); + scale_size_N = div_up(K, block_size_N); + scale_size_K = div_up(N, block_size_K); + const int64_t stride_e2 = OC * IC; + const int64_t stride_oc = IC; + + // parallel on [MB2, NB2] + at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { + int tid = at::get_thread_num(); + alignas(64) scalar_t C[BLOCK_M * BLOCK_K]; + + bool is_brgemm_used = false; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB2; + int64_t nb = i % NB2; + + int64_t m_size = offsets[mb + 1] - offsets[mb]; + int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); + + const bool use_brgemm = can_use_brgemm(m_size); + is_brgemm_used = is_brgemm_used || use_brgemm; + + // A ptr from ic1 of [M * topk, N] in sorted order + // so as to avoid copy A to tmp buffer again + const scalar_t* __restrict__ A = ic1 + offsets[mb] * N; + const int32_t* A_ids = sorted_ids + mb * BLOCK_M; + + // B shape [IC, n_size] in vnni format + int32_t expert_id = expert_ids[mb]; + const at::Float8_e4m3fn* __restrict__ B = packed_w2 + expert_id * stride_e2 + nb * BLOCK_N * stride_oc; + const float* __restrict__ Bs = w2s + expert_id * scale_size_N * scale_size_K + (nb / blocks_n_per_group) * scale_size_K; + + tinygemm_kernel( + /* A */ A, + /* B */ B, + /* C */ C, + /* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N), + /* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N, + /* scale */ Bs, + /* M */ m_size, + /* N */ n_size, + /* K */ IC, + /* lda */ IC, + /* ldb */ n_size, + /* ldc */ BLOCK_N, + /* brg */ use_brgemm, + /* block_size_K */ block_size_K); + + // 2.b copy from C to ic2 in original order + // and also mul topk_weights in float32 + for (int64_t m = 0; m < m_size; ++m) { + int32_t index = A_ids[m]; + float weight = topk_weights[index]; + copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size); + } + } + + if (is_brgemm_used) { + at::native::cpublas::brgemm_release(); + } + }); + + // stage 3: out = intermediate_cache2.sum(dim=1) + // from [M, topk, K] to [M, K] + at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + sum_stub(output + m * K, ic2 + m * topk * K, topk, K); + } + }); +} + +#define INSTANTIATE_MOE_FP8_TEMPLATE(TYPE) \ + template void fused_experts_fp8_kernel_impl( \ + TYPE* __restrict__ output, \ + TYPE* __restrict__ ic0, \ + TYPE* __restrict__ ic1, \ + TYPE* __restrict__ ic2, \ + TYPE* __restrict__ A_tmp, \ + TYPE* __restrict__ B_tmp, \ + float* __restrict__ C_tmp, \ + const TYPE* __restrict__ input, \ + const at::Float8_e4m3fn* __restrict__ packed_w1, \ + const at::Float8_e4m3fn* __restrict__ packed_w2, \ + const float* __restrict__ w1s, \ + const float* __restrict__ w2s, \ + int64_t block_size_N, \ + int64_t block_size_K, \ + const float* __restrict__ topk_weights, \ + const int32_t* __restrict__ sorted_ids, \ + const int32_t* __restrict__ expert_ids, \ + const int32_t* __restrict__ offsets, \ + int64_t M, \ + int64_t N, \ + int64_t K, \ + int64_t E, \ + int64_t topk, \ + int64_t num_tokens_post_pad) + +INSTANTIATE_MOE_FP8_TEMPLATE(at::BFloat16); +INSTANTIATE_MOE_FP8_TEMPLATE(at::Half); + +template +void shared_expert_fp8_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic0, + scalar_t* __restrict__ ic1, + scalar_t* __restrict__ B_tmp, + float* __restrict__ C_tmp, + const scalar_t* __restrict__ input, + const at::Float8_e4m3fn* __restrict__ packed_w1, + const at::Float8_e4m3fn* __restrict__ packed_w2, + const float* __restrict__ w1s, + const float* __restrict__ w2s, + int64_t block_size_N, + int64_t block_size_K, + const scalar_t* __restrict__ fused_experts_out, + float routed_scaling_factor, + int64_t M, + int64_t N, + int64_t K) { + + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + + // stage 1: intermediate_cache0 = hidden_states @ w1 + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(2 * N, BLOCK_N); + int64_t scale_size_K = div_up(K, block_size_K); + int64_t blocks_n_per_group = block_size_N / BLOCK_N; + + const bool use_brgemm = can_use_brgemm(M); + + at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + int tid = at::get_thread_num(); + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB; + int64_t nb = i % NB; + int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); + int64_t n_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N); + + tinygemm_kernel( + /* A */ input + mb * BLOCK_M * K, + /* B */ packed_w1 + nb * BLOCK_N * K, + /* C */ ic0 + mb * BLOCK_M * 2 * N + nb * BLOCK_N, + /* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N), + /* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N, + /* scale */ w1s + (nb / blocks_n_per_group) * scale_size_K, + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ 2 * N, + /* brg */ use_brgemm, + /* block_size_K */ block_size_K); + } + + if (use_brgemm) { + at::native::cpublas::brgemm_release(); + } + }); + + // stage 1.5: intermediate_cache1 = silu(intermediate_cache0) + at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + silu_and_mul_stub( + ic1 + m * N, + ic0 + m * 2 * N, + ic0 + m * 2 * N + N, + N); + } + }); + + // stage 2: intermediate_cache2 = intermediate_cache1 @ w2 + // w2 : [K, N] as [OC, IC] + const int64_t OC = K; // rename K as OC + const int64_t IC = N; // rename N as IC + const int64_t MB2 = MB; + const int64_t NB2 = div_up(K, BLOCK_N); + scale_size_K = div_up(N, block_size_K); + + // parallel on [MB2, NB2] + at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { + int tid = at::get_thread_num(); + alignas(64) scalar_t C[BLOCK_M * BLOCK_K]; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB2; + int64_t nb = i % NB2; + int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); + int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); + + // 2.a gemm: C = A @ B + tinygemm_kernel( + /* A */ ic1 + mb * BLOCK_M * N, + /* B */ packed_w2 + nb * BLOCK_N * N, + /* C */ C, + /* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N), + /* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N, + /* scale */ w2s + (nb / blocks_n_per_group) * scale_size_K, + /* M */ m_size, + /* N */ n_size, + /* K */ IC, + /* lda */ IC, + /* ldb */ n_size, + /* ldc */ BLOCK_N, + /* brg */ use_brgemm, + /* block_size_K */ block_size_K); + + // 2.b copy from C to output and add fused_experts_out + scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N; + const scalar_t* __restrict__ fused_out = fused_experts_out + mb * BLOCK_M * K + nb * BLOCK_N; + for (int64_t m = 0; m < m_size; ++m) { + add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size); + } + } + }); + + if (use_brgemm) { + at::native::cpublas::brgemm_release(); + } +} + +#define INSTANTIATE_SHARED_EXPERT_FP8_TEMPLATE(TYPE) \ + template void shared_expert_fp8_kernel_impl( \ + TYPE* __restrict__ output, \ + TYPE* __restrict__ ic0, \ + TYPE* __restrict__ ic1, \ + TYPE* __restrict__ B_tmp, \ + float* __restrict__ C_tmp, \ + const TYPE* __restrict__ input, \ + const at::Float8_e4m3fn* __restrict__ packed_w1, \ + const at::Float8_e4m3fn* __restrict__ packed_w2, \ + const float* __restrict__ w1s, \ + const float* __restrict__ w2s, \ + int64_t block_size_N, \ + int64_t block_size_K, \ + const TYPE* __restrict__ fused_experts_out, \ + float routed_scaling_factor, \ + int64_t M, \ + int64_t N, \ + int64_t K) + +INSTANTIATE_SHARED_EXPERT_FP8_TEMPLATE(at::BFloat16); +INSTANTIATE_SHARED_EXPERT_FP8_TEMPLATE(at::Half); diff --git a/csrc/cpu/sgl-kernels/moe_int8.cpp b/csrc/cpu/sgl-kernels/moe_int8.cpp new file mode 100644 index 0000000000000..89d0fb5d9f3b7 --- /dev/null +++ b/csrc/cpu/sgl-kernels/moe_int8.cpp @@ -0,0 +1,769 @@ +// Adapted from +// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu + +#include "common.h" +#include "vec.h" +#include "gemm.h" + +// clang-format off + +namespace { + +template +inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t size) { + using Vec = at::vec::Vectorized; + // no remainder + #pragma GCC unroll 4 + for (int64_t d = 0; d < size; d += Vec::size()) { + Vec data = Vec::loadu(input + d); + data.store(out + d); + } +} + +template <> +inline void copy_stub(uint8_t* __restrict__ out, const uint8_t* __restrict__ input, int64_t size) { + // size might be 64x + 32 + std::memcpy(out, input, size * sizeof(uint8_t)); +} + +template +inline void copy_mul_stub(scalar_t* __restrict__ out, const float* __restrict__ input, float weight, int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + const fVec weight_vec = fVec(weight); + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + fVec data0 = fVec::loadu(input + d) * weight_vec; + fVec data1 = fVec::loadu(input + d + fVec::size()) * weight_vec; + bVec out_vec = convert_from_float_ext(data0, data1); + out_vec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d] * weight); + } +} + +// acc from [topk, K] to [K] +template +inline void sum_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t topk, int64_t K) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + if (topk == 1) { + // do copy for topk = 1 + copy_stub(out, input, K); + } else { + // do sum for topk != 1 + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= K - kVecSize; d += kVecSize) { + fVec sum_fvec0 = fVec(0.f); + fVec sum_fvec1 = fVec(0.f); + for (int t = 0; t < topk; ++t) { + bVec x_bvec = bVec::loadu(input + t * K + d); + fVec x_fvec0, x_fvec1; + std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec); + + sum_fvec0 += x_fvec0; + sum_fvec1 += x_fvec1; + } + bVec out_bvec = convert_from_float_ext(sum_fvec0, sum_fvec1); + out_bvec.store(out + d); + } + for (; d < K; ++d) { + float sum_val = 0.f; + for (int t = 0; t < topk; ++t) { + sum_val += static_cast(input[t * K + d]); + } + out[d] = static_cast(sum_val); + } + } +} + +// out = input + input2 * scale +template +inline void add_mul_stub(scalar_t* __restrict__ out, const float* __restrict__ input, + const scalar_t* __restrict__ input2, float scale, int64_t size) { + + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + const fVec s_vec = fVec(scale); + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + fVec x0 = fVec::loadu(input + d); + fVec x1 = fVec::loadu(input + d + fVec::size()); + + bVec y_bvec = bVec::loadu(input2 + d); + fVec y0, y1; + std::tie(y0, y1) = at::vec::convert_to_float(y_bvec); + + x0 = x0 + y0 * s_vec; + x1 = x1 + y1 * s_vec; + bVec out_vec = convert_from_float_ext(x0, x1); + out_vec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d] + float(input2[d]) * scale); + } +} + +/// gemm for w13 +template +struct tinygemm_kernel_vnni { + static inline void apply( + const uint8_t* __restrict__ A, const int8_t* __restrict__ B0, const int8_t* __restrict__ B1, scalar_t* __restrict__ C, + const float* __restrict__ As, const float* __restrict__ Bs0, const float* __restrict__ Bs1, + const int32_t* __restrict__ Bcomp0, const int32_t* __restrict__ Bcomp1, + int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!"); + } +}; + +#if defined(CPU_CAPABILITY_AVX512) +template +struct tinygemm_kernel_vnni { + static inline void apply( + const uint8_t* __restrict__ A, const int8_t* __restrict__ B0, const int8_t* __restrict__ B1, at::BFloat16* __restrict__ C, + const float* __restrict__ As, const float* __restrict__ Bs0, const float* __restrict__ Bs1, + const int32_t* __restrict__ Bcomp0, const int32_t* __restrict__ Bcomp1, + int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + + constexpr int ROWS = BLOCK_M; + constexpr int COLS = BLOCK_N / 16; + static_assert(COLS % 2 == 0); + + __m512i va; + __m512i vb0[COLS]; + __m512i vb1[COLS]; + __m512i vc0[ROWS * COLS]; + __m512i vc1[ROWS * COLS]; + __m512i vcomp0[COLS]; + __m512i vcomp1[COLS]; + __m512 was; + __m512 vbs0[COLS]; + __m512 vbs1[COLS]; + + auto loadc = [&](auto i) { + vc0[i] = _mm512_set1_epi32(0); + vc1[i] = _mm512_set1_epi32(0); + }; + Unroll{}(loadc); + + const int64_t K4 = K >> 2; + const int64_t lda4 = lda >> 2; + const int64_t ldb4 = ldb; // ldb * 4 >> 2; + const int32_t* a_ptr = reinterpret_cast(A); + const int32_t* b0_ptr = reinterpret_cast(B0); + const int32_t* b1_ptr = reinterpret_cast(B1); + + auto compute = [&](auto i, int64_t k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + va = _mm512_set1_epi32(a_ptr[row * lda4 + k]); + } + if constexpr (row == 0) { + vb0[col] = _mm512_loadu_si512(b0_ptr + k * ldb4 + col * 16); + vb1[col] = _mm512_loadu_si512(b1_ptr + k * ldb4 + col * 16); + } + vc0[i] = _mm512_dpbusd_epi32(vc0[i], va, vb0[col]); + vc1[i] = _mm512_dpbusd_epi32(vc1[i], va, vb1[col]); + }; + for (int64_t k = 0; k < K4; ++k) { + Unroll{}(compute, k); + } + + auto scalec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + // load a scale + if constexpr(col == 0) { + was = _mm512_set1_ps(As[row]); + } + // load b scale and vcomp + if constexpr (row == 0) { + vbs0[col] = _mm512_loadu_ps(Bs0 + col * 16); + vbs1[col] = _mm512_loadu_ps(Bs1 + col * 16); + vcomp0[col] = _mm512_loadu_si512(Bcomp0 + col * 16); + vcomp1[col] = _mm512_loadu_si512(Bcomp1 + col * 16); + } + __m512 c0 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc0[i], vcomp0[col])); + __m512 c1 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc1[i], vcomp1[col])); + vc0[i] = _mm512_castps_si512(_mm512_mul_ps(_mm512_mul_ps(c0, was), vbs0[col])); + vc1[i] = _mm512_castps_si512(_mm512_mul_ps(_mm512_mul_ps(c1, was), vbs1[col])); + }; + Unroll{}(scalec); + + using Vec = at::vec::Vectorized; + const Vec one = Vec(1.f); + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + // for COLS = 2, 4 use 512bit store + if constexpr (col % 2 == 0) { + Vec x0 = _mm512_castsi512_ps(vc0[row * COLS + col + 0]); + Vec x1 = _mm512_castsi512_ps(vc0[row * COLS + col + 1]); + Vec y0 = _mm512_castsi512_ps(vc1[row * COLS + col + 0]); + Vec y1 = _mm512_castsi512_ps(vc1[row * COLS + col + 1]); + // silu + x0 = x0 / (one + x0.neg().exp_u20()); + x1 = x1 / (one + x1.neg().exp_u20()); + // mul + x0 = x0 * y0; + x1 = x1 * y1; + + _mm512_storeu_si512( + reinterpret_cast<__m512i*>((C + row * ldc + col * 16)), + (__m512i)(_mm512_cvtne2ps_pbh(__m512(x1), __m512(x0)))); + } + }; + Unroll{}(storec); + } +}; +#endif + +#define LAUNCH_TINYGEMM_KERNEL_VNNI(MB_SIZE, NB_SIZE) \ + tinygemm_kernel_vnni::apply( \ + A + mb_start * lda, B0 + nb_start * 4, B1 + nb_start * 4, \ + C + mb_start * ldc + nb_start, As + mb_start, \ + Bs0 + nb_start, Bs1 + nb_start, Bcomp0 + nb_start, Bcomp1 + nb_start,\ + K, lda, ldb, ldc); + +template +void tinygemm_kernel( + const uint8_t* __restrict__ A, + const int8_t* __restrict__ B0, + const int8_t* __restrict__ B1, + scalar_t* __restrict__ C, + const float* __restrict__ As, + const float* __restrict__ Bs0, + const float* __restrict__ Bs1, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc) { + + const int32_t* Bcomp0 = reinterpret_cast(B0 + block_size_n() * K); + const int32_t* Bcomp1 = reinterpret_cast(B1 + block_size_n() * K); + + // pattern: 1-(2+2)-(8+8) + constexpr int64_t BLOCK_M = 4; + constexpr int64_t BLOCK_N = 32; + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + for (int mb = 0; mb < MB; ++mb) { + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(BLOCK_M, M - mb_start); + for (int64_t nb = 0; nb < NB; ++nb) { + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(BLOCK_N, N - nb_start); + + switch(mb_size << 4 | nb_size >> 4) { + case 0x12: LAUNCH_TINYGEMM_KERNEL_VNNI(1, 32); break; + case 0x22: LAUNCH_TINYGEMM_KERNEL_VNNI(2, 32); break; + case 0x32: LAUNCH_TINYGEMM_KERNEL_VNNI(3, 32); break; + case 0x42: LAUNCH_TINYGEMM_KERNEL_VNNI(4, 32); break; + default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size"); + } + } + } +} + +/// gemm for w2 +template +struct tinygemm_kernel_vnni2 { + static inline void apply( + const uint8_t* __restrict__ A, const int8_t* __restrict__ B, float* __restrict__ C, + const float* __restrict__ As, const float* __restrict__ Bs, const int32_t* __restrict__ Bcomp, + int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!"); + } +}; + +#if defined(CPU_CAPABILITY_AVX512) +template +struct tinygemm_kernel_vnni2 { + static inline void apply( + const uint8_t* __restrict__ A, const int8_t* __restrict__ B, float* __restrict__ C, + const float* __restrict__ As, const float* __restrict__ Bs, const int32_t* __restrict__ Bcomp, + int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + + constexpr int ROWS = BLOCK_M; + constexpr int COLS = BLOCK_N / 16; + static_assert(COLS % 2 == 0); + + __m512i va; + __m512i vb[COLS]; + __m512i vc[ROWS * COLS]; + __m512i vcomp[COLS]; + __m512 was; + __m512 vbs[COLS]; + + auto loadc = [&](auto i) { + vc[i] = _mm512_set1_epi32(0); + }; + Unroll{}(loadc); + + const int64_t K4 = K >> 2; + const int64_t lda4 = lda >> 2; + const int64_t ldb4 = ldb; // ldb * 4 >> 2; + const int32_t* a_ptr = reinterpret_cast(A); + const int32_t* b_ptr = reinterpret_cast(B); + + auto compute = [&](auto i, int64_t k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + va = _mm512_set1_epi32(a_ptr[row * lda4 + k]); + } + if constexpr (row == 0) { + vb[col] = _mm512_loadu_si512(b_ptr + k * ldb4 + col * 16); + } + vc[i] = _mm512_dpbusd_epi32(vc[i], va, vb[col]); + }; + for (int64_t k = 0; k < K4; ++k) { + Unroll{}(compute, k); + } + + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + // load a scale + if constexpr(col == 0) { + was = _mm512_set1_ps(As[row]); + } + // load b scale and vcomp per 2 vectors + // also load bias if any + if constexpr (row == 0) { + if constexpr (col % 2 == 0) { + vbs[col + 0] = _mm512_loadu_ps(Bs + col * 16); + vbs[col + 1] = _mm512_loadu_ps(Bs + col * 16 + 16); + vcomp[col + 0] = _mm512_loadu_si512(Bcomp + col * 16); + vcomp[col + 1] = _mm512_loadu_si512(Bcomp + col * 16 + 16); + } + } + __m512 x = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[i], vcomp[col])); + x = _mm512_mul_ps(_mm512_mul_ps(x, was), vbs[col]); + _mm512_storeu_ps(reinterpret_cast<__m512*>(C + row * ldc + col * 16), x); + }; + Unroll{}(storec); + } +}; +#endif + +#define LAUNCH_TINYGEMM_KERNEL_VNNI2(MB_SIZE, NB_SIZE) \ + tinygemm_kernel_vnni2::apply( \ + A + mb_start * lda, B + nb_start * 4, C + mb_start * ldc + nb_start, \ + As + mb_start, Bs + nb_start, Bcomp + nb_start, \ + K, lda, ldb, ldc); + +template +void tinygemm_kernel( + const uint8_t* __restrict__ A, + const int8_t* __restrict__ B, + float* __restrict__ C, + const float* __restrict__ As, + const float* __restrict__ Bs, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc) { + + // B compensation + const int32_t* Bcomp = reinterpret_cast(B + block_size_n() * K); + + // pattern: 1-4-16 + constexpr int64_t BLOCK_M = 4; + constexpr int64_t BLOCK_N = 64; + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + for (int64_t mb = 0; mb < MB; ++mb) { + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(BLOCK_M, M - mb_start); + for (int64_t nb = 0; nb < NB; ++nb) { + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(BLOCK_N, N - nb_start); + + switch(mb_size << 4 | nb_size >> 4) { + case 0x12: LAUNCH_TINYGEMM_KERNEL_VNNI2(1, 32); break; + case 0x22: LAUNCH_TINYGEMM_KERNEL_VNNI2(2, 32); break; + case 0x32: LAUNCH_TINYGEMM_KERNEL_VNNI2(3, 32); break; + case 0x42: LAUNCH_TINYGEMM_KERNEL_VNNI2(4, 32); break; + default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size"); + } + } + } +} + +} // anonymous namespace + +template +void fused_experts_int8_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic1, + scalar_t* __restrict__ ic2, + uint8_t* __restrict__ A_tmp, + float* __restrict__ C_tmp, + uint8_t* __restrict__ Aq_tmp, + float* __restrict__ As_tmp, + const scalar_t* __restrict__ input, + const int8_t* __restrict__ packed_w1, + const int8_t* __restrict__ packed_w2, + const float* __restrict__ w1s, + const float* __restrict__ w2s, + const float* __restrict__ topk_weights, + const int32_t* __restrict__ sorted_ids, + const int32_t* __restrict__ expert_ids, + const int32_t* __restrict__ offsets, + int64_t M, + int64_t N, + int64_t K, + int64_t E, + int64_t topk, + int64_t num_tokens_post_pad) { + + // handle 2 tiles per block + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + + // stage 0: quantize input to uint8, [M, K] + at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + quantize_row_int8( + Aq_tmp + m * K, + As_tmp[m], + input + m * K, + K); + } + }); + + // stage 1: intermediate_cache1 = silu(hidden_states @ w1) + const int64_t MB = div_up(num_tokens_post_pad, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + + // strides for w1: [E, 2N, K] + TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not multiples of ", BLOCK_N); + + // K and N are packed for int8 + const int64_t packed_K = get_row_size(K); + const int64_t packed_N = get_row_size(N); + + const int64_t stride_e = 2 * N * packed_K; + const int64_t stride_n = packed_K; + // here we only parallel on half of 2N to fuse silu_and_mul with gemm + at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + // get local pointers + int tid = at::get_thread_num(); + uint8_t* __restrict__ A = A_tmp + tid * BLOCK_M * K; + + alignas(64) float As[BLOCK_M]; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB; + int64_t nb = i % NB; + + // nb0 from top half and nb1 from bottom half + int64_t nb0 = nb, nb1 = nb + NB; + int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N); + + // B shape [K, n_size] in vnni format + int32_t expert_id = expert_ids[mb]; + const int8_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb0 * BLOCK_N * stride_n; + const int8_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb1 * BLOCK_N * stride_n; + const float* __restrict__ Bs0 = w1s + expert_id * 2 * N + nb0 * BLOCK_N; + const float* __restrict__ Bs1 = w1s + expert_id * 2 * N + nb1 * BLOCK_N; + + // 1.a load A + const int32_t* A_ids = sorted_ids + mb * BLOCK_M; + int64_t m_size = offsets[mb + 1] - offsets[mb]; + + for (int64_t m = 0; m < m_size; ++m) { + int32_t index = A_ids[m] / topk; + copy_stub(A + m * K, Aq_tmp + index * K, K); + As[m] = As_tmp[index]; + } + + // fused 1.b: silu_and_mul(A @ B0, A @ B1) + const int64_t offset = offsets[mb]; + tinygemm_kernel( + /* A */ A, + /* B0 */ B0, + /* B1 */ B1, + /* C */ ic1 + offset * N + nb * BLOCK_N, + /* As */ As, + /* Bs0 */ Bs0, + /* Bs1 */ Bs1, + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ N); + } + }); + + // stage 1.5: quantize ic1 to uint8, [M * topk, N] + at::parallel_for(0, M * topk, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + quantize_row_int8( + Aq_tmp + m * N, + As_tmp[m], + ic1 + m * N, + N); + } + }); + + // stage 2: intermediate_cache2 = intermediate_cache1 @ w2 + // w2 : [E, K, N] as [E, OC, IC] + const int64_t OC = K; // rename K as OC + const int64_t IC = N; // rename N as IC + const int64_t MB2 = MB; + const int64_t NB2 = div_up(OC, BLOCK_N); + const int64_t stride_e2 = OC * packed_N; + const int64_t stride_oc = packed_N; + + // parallel on [MB2, NB2] + at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { + // get local pointers + int tid = at::get_thread_num(); + // we won't be using C1 for gemm2 + float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB2; + int64_t nb = i % NB2; + + int64_t m_size = offsets[mb + 1] - offsets[mb]; + int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); + + // A ptr from ic1 of [M * topk, N] in sorted order + // so as to avoid copy A to tmp buffer again + const uint8_t* __restrict__ A = Aq_tmp + offsets[mb] * N; + const float* __restrict__ As = As_tmp + offsets[mb]; + const int32_t* A_ids = sorted_ids + mb * BLOCK_M; + + // B shape [IC, n_size] in vnni format + int32_t expert_id = expert_ids[mb]; + const int8_t* __restrict__ B = packed_w2 + expert_id * stride_e2 + nb * BLOCK_N * stride_oc; + const float* __restrict__ Bs = w2s + expert_id * K + nb * BLOCK_N; + + // 2.a gemm: C = A @ B + tinygemm_kernel( + /* A */ A, + /* B */ B, + /* C */ C, + /* As */ As, + /* Bs */ Bs, + /* M */ m_size, + /* N */ n_size, + /* K */ IC, + /* lda */ IC, + /* ldb */ n_size, + /* ldc */ BLOCK_N); + + // 2.b copy from C to ic2 in original order + // and also mul topk_weights in float32 + for (int64_t m = 0; m < m_size; ++m) { + int32_t index = A_ids[m]; + float weight = topk_weights[index]; + copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size); + } + } + }); + + // stage 3: out = intermediate_cache2.sum(dim=1) + // from [M, topk, K] to [M, K] + at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + sum_stub(output + m * K, ic2 + m * topk * K, topk, K); + } + }); +} + +#define INSTANTIATE_MOE_INT8_TEMPLATE(TYPE) \ + template void fused_experts_int8_kernel_impl ( \ + TYPE* __restrict__ output, TYPE* __restrict__ ic1, \ + TYPE* __restrict__ ic2, uint8_t* __restrict__ A_tmp, \ + float* __restrict__ C_tmp, uint8_t* __restrict__ Aq_tmp, \ + float* __restrict__ As_tmp, const TYPE* __restrict__ input, \ + const int8_t* __restrict__ packed_w1, const int8_t* __restrict__ packed_w2, \ + const float* __restrict__ w1s, const float* __restrict__ w2s, \ + const float* __restrict__ topk_weights, const int32_t* __restrict__ sorted_ids, \ + const int32_t* __restrict__ expert_ids, const int32_t* __restrict__ offsets, \ + int64_t M, int64_t N, int64_t K, int64_t E, int64_t topk, int64_t num_tokens_post_pad) + +INSTANTIATE_MOE_INT8_TEMPLATE(at::BFloat16); +INSTANTIATE_MOE_INT8_TEMPLATE(at::Half); + +template +void shared_expert_int8_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic1, + float* __restrict__ C_tmp, + uint8_t* __restrict__ Aq_tmp, + float* __restrict__ As_tmp, + const scalar_t* __restrict__ input, + const int8_t* __restrict__ packed_w1, + const int8_t* __restrict__ packed_w2, + const float* __restrict__ w1s, + const float* __restrict__ w2s, + const scalar_t* __restrict__ fused_experts_out, + float routed_scaling_factor, + int64_t M, + int64_t N, + int64_t K) { + + // handle 2 tiles per block + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + + // stage 0: quantize input to uint8, [M, K] + at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + quantize_row_int8( + Aq_tmp + m * K, + As_tmp[m], + input + m * K, + K); + } + }); + + // stage 1: intermediate_cache1 = silu(hidden_states @ w1) + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + + TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not multiples of ", BLOCK_N); + + // K and N are packed for int8 + const int64_t packed_K = get_row_size(K); + const int64_t packed_N = get_row_size(N); + const int64_t stride_n = packed_K; + + // here we only parallel on half of 2N to fuse silu_and_mul with gemm + at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB; + int64_t nb = i % NB; + + // nb0 from top half and nb1 from bottom half + int64_t nb0 = nb, nb1 = nb + NB; + int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N); + int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); + + // A shape [m_size, K] + const uint8_t* A = Aq_tmp + mb * BLOCK_M * K; + const float* As = As_tmp + mb * BLOCK_M; + + // B shape [K, n_size] in vnni format + const int8_t* __restrict__ B0 = packed_w1 + nb0 * BLOCK_N * stride_n; + const int8_t* __restrict__ B1 = packed_w1 + nb1 * BLOCK_N * stride_n; + const float* __restrict__ Bs0 = w1s + nb0 * BLOCK_N; + const float* __restrict__ Bs1 = w1s + nb1 * BLOCK_N; + + // fused 1.b: silu_and_mul(A @ B0, A @ B1) + tinygemm_kernel( + /* A */ A, + /* B0 */ B0, + /* B1 */ B1, + /* C */ ic1 + mb * BLOCK_M * N + nb * BLOCK_N, + /* As */ As, + /* Bs0 */ Bs0, + /* Bs1 */ Bs1, + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ N); + } + }); + + // stage 1.5: quantize ic1 to uint8, [M * topk, N] + at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + quantize_row_int8( + Aq_tmp + m * N, + As_tmp[m], + ic1 + m * N, + N); + } + }); + + // stage 2: intermediate_cache2 = intermediate_cache1 @ w2 + // w2 : [K, N] as [OC, IC] + const int64_t OC = K; // rename K as OC + const int64_t IC = N; // rename N as IC + const int64_t MB2 = MB; + const int64_t NB2 = div_up(OC, BLOCK_N); + const int64_t stride_oc = packed_N; + + // parallel on [MB2, NB2] + at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { + // get local pointers + int tid = at::get_thread_num(); + // we won't be using C1 for gemm2 + float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB2; + int64_t nb = i % NB2; + + int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); + int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); + + // A shape [m_size, IC] + const uint8_t* __restrict__ A = Aq_tmp + mb * BLOCK_M * N; + const float* __restrict__ As = As_tmp + mb * BLOCK_M; + + // B shape [IC, n_size] in vnni format + const int8_t* __restrict__ B = packed_w2 + nb * BLOCK_N * stride_oc; + const float* __restrict__ Bs = w2s + nb * BLOCK_N; + + // 2.a gemm: C = A @ B + tinygemm_kernel( + /* A */ A, + /* B */ B, + /* C */ C, + /* As */ As, + /* Bs */ Bs, + /* M */ m_size, + /* N */ n_size, + /* K */ IC, + /* lda */ IC, + /* ldb */ n_size, + /* ldc */ BLOCK_N); + + // 2.b copy from C to output and add fused_experts_out + scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N; + const scalar_t* __restrict__ fused_out = fused_experts_out + mb * BLOCK_M * K + nb * BLOCK_N; + for (int64_t m = 0; m < m_size; ++m) { + add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size); + } + } + }); +} + +#define INSTANTIATE_SHARED_EXPERT_INT8_TEMPLATE(TYPE) \ + template void shared_expert_int8_kernel_impl ( \ + TYPE* __restrict__ output, TYPE* __restrict__ ic1, \ + float* __restrict__ C_tmp, uint8_t* __restrict__ Aq_tmp, \ + float* __restrict__ As_tmp, const TYPE* __restrict__ input, \ + const int8_t* __restrict__ packed_w1, const int8_t* __restrict__ packed_w2, \ + const float* __restrict__ w1s, const float* __restrict__ w2s, \ + const TYPE* __restrict__ fused_experts_out, float routed_scaling_factor, \ + int64_t M, int64_t N, int64_t K) + +INSTANTIATE_SHARED_EXPERT_INT8_TEMPLATE(at::BFloat16); +INSTANTIATE_SHARED_EXPERT_INT8_TEMPLATE(at::Half); diff --git a/csrc/cpu/sgl-kernels/vec.h b/csrc/cpu/sgl-kernels/vec.h new file mode 100644 index 0000000000000..87955cfb2922c --- /dev/null +++ b/csrc/cpu/sgl-kernels/vec.h @@ -0,0 +1,308 @@ +// Adapted from +// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu + +#pragma once + +// clang-format off + +#if defined(__AVX512F__) && defined(__AVX512BF16__) && defined(__AMX_BF16__) +#define CPU_CAPABILITY_AVX512 +#endif + +#include +#include + +namespace { + +using namespace at::vec; + +template , int> = 0> +inline Vectorized convert_from_float_ext(const Vectorized& a, const Vectorized& b) { + return at::vec::convert_from_float(a, b); +} + +#if defined(CPU_CAPABILITY_AVX512) + +// `at::vec::convert_from_float<>` from PyTorch doesn't have avx512-bf16 intrinsics +// use native instruction for bfloat16->float32 conversion +template <> +inline Vectorized convert_from_float_ext(const Vectorized& a, const Vectorized& b) { + return (__m512i)(_mm512_cvtne2ps_pbh(__m512(b), __m512(a))); +} + +#define CVT_BF16_TO_FP32(a) \ + _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16)) + +#define CVT_FP16_TO_FP32(a) \ + _mm512_cvtps_ph(a, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) + +// this doesn't hanel NaN. +inline __m512bh cvt_e4m3_bf16_intrinsic_no_nan(__m256i fp8_vec) { + const __m512i x = _mm512_cvtepu8_epi16(fp8_vec); + + const __m512i mant = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x07)), 4); + const __m512i raw_exp = _mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x78)), 3); + const __m512i exp = _mm512_slli_epi16(_mm512_add_epi16(raw_exp, _mm512_set1_epi16(120)), 7); + const __m512i nonsign = _mm512_or_si512(exp, mant); + + const __m512i sign = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x80)), 8); + const __m512i combined = _mm512_or_si512(nonsign, sign); + + const __mmask32 is_nonzero = _mm512_cmpneq_epi16_mask(x, _mm512_setzero_si512()); + return (__m512bh)_mm512_maskz_mov_epi16(is_nonzero, combined); +} + +inline __m512bh cvt_e4m3_bf16_intrinsic_without_denorm(__m256i fp8_vec) { + // The following conversion is without denorm behavior, that is to say, + // Max subnorm : S.0000.111 = 0.875 ∗ 2**(−6) + // Min subnorm : S.0000.001 = 2**(−9) + // 0.0019 ~ 0.0137 cannot be converted correctly. + __m512i x = _mm512_cvtepu8_epi16(fp8_vec); + auto mask = _mm512_cmpneq_epi16_mask( + _mm512_and_si512(x, _mm512_set1_epi16(127)), + _mm512_setzero_si512()); // mask = x & 0x7f + auto mask_nan = _mm512_cmpneq_epi16_mask( + _mm512_and_si512(x, _mm512_set1_epi16(127)), + _mm512_set1_epi16(127)); // mask_nan = x & 0x7f + auto mantissa = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(7)), 4); // mantissa = (x & 7) << 4 + auto exponent = _mm512_add_epi16( + _mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(120)), 3), + _mm512_set1_epi16(120)); // exponent = (((x >> 3) & 15) + 120) + auto nonsign = _mm512_maskz_mov_epi16(mask, _mm512_or_si512(mantissa, _mm512_slli_epi16(exponent, 7))); + nonsign = _mm512_mask_mov_epi16(_mm512_set1_epi16(0x7fff), mask_nan, nonsign); // deal with Nan + return (__m512bh)(_mm512_or_si512( + nonsign, + _mm512_slli_epi16( + _mm512_and_si512(x, _mm512_set1_epi16(128)), + 8))); // add sign (x & 128) << 8 +} + +inline __m512bh cvt_e4m3_bf16_intrinsic_with_denorm(__m256i fp8_vec) { + __m512i x = _mm512_cvtepu8_epi16(fp8_vec); + __m512i lg2mant = _mm512_mask_mov_epi16( + _mm512_mask_mov_epi16( + _mm512_setzero_si512(), _mm512_test_epi16_mask(x, _mm512_set1_epi16(2)), _mm512_set1_epi16(1)), + _mm512_test_epi16_mask(x, _mm512_set1_epi16(4)), + _mm512_set1_epi16(2)); + return (__m512bh)(_mm512_or_si512( + _mm512_maskz_mov_epi16( + _mm512_cmpneq_epi16_mask(_mm512_and_si512(x, _mm512_set1_epi16(127)), _mm512_setzero_si512()), + _mm512_mask_blend_epi16( + _mm512_test_epi16_mask(x, _mm512_set1_epi16(120)), + _mm512_or_si512( + _mm512_and_si512( + _mm512_sllv_epi16( + _mm512_and_si512(x, _mm512_set1_epi16(3)), _mm512_sub_epi16(_mm512_set1_epi16(7), lg2mant)), + _mm512_set1_epi16(0x007f)), + _mm512_slli_epi16(_mm512_add_epi16(lg2mant, _mm512_set1_epi16(118)), 7)), + _mm512_or_si512( + _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(7)), 4), + _mm512_slli_epi16( + _mm512_add_epi16( + _mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(120)), 3), _mm512_set1_epi16(120)), + 7)))), + _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(128)), 8))); +} + +inline __m512bh CVT_FP8_TO_BF16(__m256i a) { +#ifdef SGLANG_CPU_FP8_CVT_FTZ + return cvt_e4m3_bf16_intrinsic_no_nan(a); +#else + return cvt_e4m3_bf16_intrinsic_with_denorm(a); +#endif +} + +#endif + +// vector to scalar reduction +#if defined(CPU_CAPABILITY_AVX512) && 0 +inline float vec_reduce_sum(const Vectorized& a) { + return _mm512_reduce_add_ps(__m512(a)); +} + +inline float vec_reduce_max(const Vectorized& a) { + return _mm512_reduce_max_ps(__m512(a)); +} +#else +inline float vec_reduce_sum(const Vectorized& a) { + return vec_reduce_all([](Vectorized& x, Vectorized& y) { return x + y; }, a); +} + +inline float vec_reduce_max(const Vectorized& a) { + return vec_reduce_all([](Vectorized& x, Vectorized& y) { return maximum(x, y); }, a); +} +#endif + +// https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282 +template +inline void quantize_row_int8(uint8_t* __restrict__ Aq, float& As, + const scalar_t* __restrict__ A, int64_t K, float eps = 1e-7) { + + float amax = 0.f; // absolute max + for (int64_t k = 0; k < K; ++k) { + const float val = static_cast(A[k]); + amax = std::max(amax, std::abs(val)); + } + + amax = std::max(amax, eps); + const float scale = amax / 127; + const float inv_scale = 127 / amax; + + for (int64_t k = 0; k < K; ++k) { + const float val = static_cast(A[k]) * inv_scale; + Aq[k] = (uint8_t)(std::round(val)) + 128; + } + As = scale; +} + +#if defined(CPU_CAPABILITY_AVX512) +template <> +inline void quantize_row_int8(uint8_t* __restrict__ Aq, float& As, + const at::BFloat16* __restrict__ A, int64_t K, float eps) { + + const __m512 signBit = _mm512_set1_ps(-0.0f); + const __m512i off = _mm512_set1_epi32(128); + + // K is 32x, no remainder + float amax = 0.f; + __m512 vamax0 = _mm512_set1_ps(0.f); + __m512 vamax1 = _mm512_set1_ps(0.f); + for (int64_t k = 0; k < K; k += 32) { + __m512i va = _mm512_loadu_si512((void*)(A + k)); + __m512 va0 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 0)); + __m512 va1 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 1)); + vamax0 = _mm512_max_ps(vamax0, _mm512_andnot_ps(signBit, va0)); + vamax1 = _mm512_max_ps(vamax1, _mm512_andnot_ps(signBit, va1)); + } + amax = _mm512_reduce_max_ps(_mm512_max_ps(vamax0, vamax1)); + amax = std::max(amax, eps); + const float scale = amax / 127; + const float inv_scale = 127 / amax; + const __m512 vd = _mm512_set1_ps(inv_scale); + + for (int64_t k = 0; k < K; k += 32) { + __m512i va = _mm512_loadu_si512((void*)(A + k)); + __m512 va0 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 0)); + __m512 va1 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 1)); + va0 = _mm512_mul_ps(va0, vd); + va1 = _mm512_mul_ps(va1, vd); + va0 = _mm512_roundscale_ps(va0, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + va1 = _mm512_roundscale_ps(va1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m128i i0 = _mm512_cvtepi32_epi8(_mm512_add_epi32(_mm512_cvtps_epi32(va0), off)); + __m128i i1 = _mm512_cvtepi32_epi8(_mm512_add_epi32(_mm512_cvtps_epi32(va1), off)); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(Aq + k), _mm256_set_m128i(i1, i0)); + } + As = scale; +} +#endif + +// transpose utils +// taken from my PR in ggml: https://github.com/ggml-org/llama.cpp/pull/8998 +#if defined(CPU_CAPABILITY_AVX512) +inline void transpose_16x16_32bit(__m512i * v) { + __m512i v1[16]; + v1[0] = _mm512_unpacklo_epi32(v[0], v[1]); + v1[1] = _mm512_unpackhi_epi32(v[0], v[1]); + v1[2] = _mm512_unpacklo_epi32(v[2], v[3]); + v1[3] = _mm512_unpackhi_epi32(v[2], v[3]); + v1[4] = _mm512_unpacklo_epi32(v[4], v[5]); + v1[5] = _mm512_unpackhi_epi32(v[4], v[5]); + v1[6] = _mm512_unpacklo_epi32(v[6], v[7]); + v1[7] = _mm512_unpackhi_epi32(v[6], v[7]); + v1[8] = _mm512_unpacklo_epi32(v[8], v[9]); + v1[9] = _mm512_unpackhi_epi32(v[8], v[9]); + v1[10] = _mm512_unpacklo_epi32(v[10], v[11]); + v1[11] = _mm512_unpackhi_epi32(v[10], v[11]); + v1[12] = _mm512_unpacklo_epi32(v[12], v[13]); + v1[13] = _mm512_unpackhi_epi32(v[12], v[13]); + v1[14] = _mm512_unpacklo_epi32(v[14], v[15]); + v1[15] = _mm512_unpackhi_epi32(v[14], v[15]); + + v[0] = _mm512_unpacklo_epi64(v1[0], v1[2]); + v[1] = _mm512_unpackhi_epi64(v1[0], v1[2]); + v[2] = _mm512_unpacklo_epi64(v1[1], v1[3]); + v[3] = _mm512_unpackhi_epi64(v1[1], v1[3]); + v[4] = _mm512_unpacklo_epi64(v1[4], v1[6]); + v[5] = _mm512_unpackhi_epi64(v1[4], v1[6]); + v[6] = _mm512_unpacklo_epi64(v1[5], v1[7]); + v[7] = _mm512_unpackhi_epi64(v1[5], v1[7]); + v[8] = _mm512_unpacklo_epi64(v1[8], v1[10]); + v[9] = _mm512_unpackhi_epi64(v1[8], v1[10]); + v[10] = _mm512_unpacklo_epi64(v1[9], v1[11]); + v[11] = _mm512_unpackhi_epi64(v1[9], v1[11]); + v[12] = _mm512_unpacklo_epi64(v1[12], v1[14]); + v[13] = _mm512_unpackhi_epi64(v1[12], v1[14]); + v[14] = _mm512_unpacklo_epi64(v1[13], v1[15]); + v[15] = _mm512_unpackhi_epi64(v1[13], v1[15]); + + v1[0] = _mm512_shuffle_i32x4(v[0], v[4], 0x88); + v1[1] = _mm512_shuffle_i32x4(v[1], v[5], 0x88); + v1[2] = _mm512_shuffle_i32x4(v[2], v[6], 0x88); + v1[3] = _mm512_shuffle_i32x4(v[3], v[7], 0x88); + v1[4] = _mm512_shuffle_i32x4(v[0], v[4], 0xdd); + v1[5] = _mm512_shuffle_i32x4(v[1], v[5], 0xdd); + v1[6] = _mm512_shuffle_i32x4(v[2], v[6], 0xdd); + v1[7] = _mm512_shuffle_i32x4(v[3], v[7], 0xdd); + v1[8] = _mm512_shuffle_i32x4(v[8], v[12], 0x88); + v1[9] = _mm512_shuffle_i32x4(v[9], v[13], 0x88); + v1[10] = _mm512_shuffle_i32x4(v[10], v[14], 0x88); + v1[11] = _mm512_shuffle_i32x4(v[11], v[15], 0x88); + v1[12] = _mm512_shuffle_i32x4(v[8], v[12], 0xdd); + v1[13] = _mm512_shuffle_i32x4(v[9], v[13], 0xdd); + v1[14] = _mm512_shuffle_i32x4(v[10], v[14], 0xdd); + v1[15] = _mm512_shuffle_i32x4(v[11], v[15], 0xdd); + + v[0] = _mm512_shuffle_i32x4(v1[0], v1[8], 0x88); + v[1] = _mm512_shuffle_i32x4(v1[1], v1[9], 0x88); + v[2] = _mm512_shuffle_i32x4(v1[2], v1[10], 0x88); + v[3] = _mm512_shuffle_i32x4(v1[3], v1[11], 0x88); + v[4] = _mm512_shuffle_i32x4(v1[4], v1[12], 0x88); + v[5] = _mm512_shuffle_i32x4(v1[5], v1[13], 0x88); + v[6] = _mm512_shuffle_i32x4(v1[6], v1[14], 0x88); + v[7] = _mm512_shuffle_i32x4(v1[7], v1[15], 0x88); + v[8] = _mm512_shuffle_i32x4(v1[0], v1[8], 0xdd); + v[9] = _mm512_shuffle_i32x4(v1[1], v1[9], 0xdd); + v[10] = _mm512_shuffle_i32x4(v1[2], v1[10], 0xdd); + v[11] = _mm512_shuffle_i32x4(v1[3], v1[11], 0xdd); + v[12] = _mm512_shuffle_i32x4(v1[4], v1[12], 0xdd); + v[13] = _mm512_shuffle_i32x4(v1[5], v1[13], 0xdd); + v[14] = _mm512_shuffle_i32x4(v1[6], v1[14], 0xdd); + v[15] = _mm512_shuffle_i32x4(v1[7], v1[15], 0xdd); +} + +// remove warning : ignoring attributes on template argument ‘__m512i’ [-Wignored-attributes] +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wignored-attributes" + +// transpose from [2, 32] to [32, 2] +inline std::tuple<__m512i, __m512i> transpose_2x32_16bit(__m512i r0, __m512i r1) { + // r0: {a0, a1, ..., a31} + // r1: {b0, b1, ..., b31} + // + // d0: {a0, b0, ..., a15, b15} + // d1: {a16, b16, ..., a31, b31} + // + __m512i d0 = _mm512_unpacklo_epi16(r0, r1); + __m512i d1 = _mm512_unpackhi_epi16(r0, r1); + r0 = _mm512_shuffle_i32x4(d0, d1, 0x88); + r1 = _mm512_shuffle_i32x4(d0, d1, 0xdd); + d0 = _mm512_shuffle_i32x4(r0, r1, 0x88); + d1 = _mm512_shuffle_i32x4(r0, r1, 0xdd); + return std::make_tuple(d0, d1); +} +#pragma GCC diagnostic pop + +#endif + +// TODO: debug print, remove me later +template +void print_array(scalar_t* ptr, int size) { + for (int d = 0; d < size; ++d) { + if (d % 16 == 0) { std::cout << std::endl; } + std::cout << ptr[d] << " "; + } + std::cout << std::endl; +} + +} // anonymous namespace diff --git a/csrc/cpu/shm.cpp b/csrc/cpu/shm.cpp index f55e96de251d0..9adb6f27ec411 100644 --- a/csrc/cpu/shm.cpp +++ b/csrc/cpu/shm.cpp @@ -7,9 +7,10 @@ namespace { #define MAX_SHM_RANK_NUM 8 -#define MAX_THREAD_NUM 12 -#define PER_THREAD_SHM_BUFFER_BYTES (4 * 1024 * 1024) -#define MIN_THREAD_PROCESS_SIZE (8 * 1024) +#define PER_THREAD_SHM_BUFFER_BYTES (2 * 1024 * 1024) +static_assert(PER_THREAD_SHM_BUFFER_BYTES % 2 == 0); +#define PER_THREAD_SHM_BUFFER_OFFSET (PER_THREAD_SHM_BUFFER_BYTES >> 1) +#define MIN_THREAD_PROCESS_SIZE (256) #define MAX_P2P_SEND_TENSOR_NUM 8 template @@ -32,10 +33,10 @@ struct KernelVecType { using scalar_vec_t = vec_op::FP16Vec16; }; -enum class ThreadSHMStat : char { THREAD_READY = 0, SHM_DATA_READY, DONE }; - struct ThreadSHMContext { - volatile ThreadSHMStat thread_stats[MAX_SHM_RANK_NUM]; + volatile char _curr_thread_stamp; + volatile char _ready_thread_stamp; + char _padding1[6]; int thread_id; int thread_num; int rank; @@ -44,14 +45,19 @@ struct ThreadSHMContext { int swizzled_ranks[MAX_SHM_RANK_NUM]; void* thread_shm_ptrs[MAX_SHM_RANK_NUM]; ThreadSHMContext* shm_contexts[MAX_SHM_RANK_NUM]; + size_t _thread_buffer_mask; + char _padding2[56]; ThreadSHMContext(const int thread_id, const int thread_num, const int rank, const int group_size, void* thread_shm_ptr) - : thread_id(thread_id), + : _curr_thread_stamp(1), + _ready_thread_stamp(0), + thread_id(thread_id), thread_num(thread_num), rank(rank), group_size(group_size), - _spinning_count(0) { + _spinning_count(0), + _thread_buffer_mask(0) { static_assert(sizeof(ThreadSHMContext) % 64 == 0); TORCH_CHECK(group_size <= MAX_SHM_RANK_NUM); TORCH_CHECK((size_t)this % 64 == 0); @@ -60,7 +66,6 @@ struct ThreadSHMContext { shm_contexts[i] = nullptr; thread_shm_ptrs[i] = nullptr; swizzled_ranks[i] = (i + rank) % group_size; - thread_stats[i] = ThreadSHMStat::DONE; } set_context(rank, this, thread_shm_ptr); } @@ -77,59 +82,66 @@ struct ThreadSHMContext { template T* get_thread_shm_ptr(int rank) { - return reinterpret_cast(thread_shm_ptrs[rank]); + return reinterpret_cast( + reinterpret_cast(thread_shm_ptrs[rank]) + + (PER_THREAD_SHM_BUFFER_OFFSET & _thread_buffer_mask)); + } + + void next_buffer() { _thread_buffer_mask ^= 0xFFFFFFFFFFFFFFFF; } + + char get_curr_stamp() const { return _curr_thread_stamp; } + + char get_ready_stamp() const { return _ready_thread_stamp; } + + void next_stamp() { + _mm_mfence(); + _curr_thread_stamp += 1; + } + + void commit_ready_stamp() { + _mm_mfence(); + _ready_thread_stamp = _curr_thread_stamp; } int get_swizzled_rank(int idx) { return swizzled_ranks[idx]; } - void wait_for_all(ThreadSHMStat prev_stat) { - for (int idx = 0; idx < group_size; ++idx) { + template + void wait_for_all(Cond&& cond) { + for (int idx = 1; idx < group_size; ++idx) { int rank = get_swizzled_rank(idx); - while (thread_stats[rank] == prev_stat) { - ++_spinning_count; - _mm_pause(); - } + wait_for_one(rank, std::forward(cond)); } - vec_op::mem_barrier(); } - void wait_for_one(int rank, ThreadSHMStat prev_stat) { - while (thread_stats[rank] == prev_stat) { + template + void wait_for_one(int rank, Cond&& cond) { + ThreadSHMContext* rank_ctx = shm_contexts[rank]; + for (;;) { + char local_curr_stamp = get_curr_stamp(); + char local_ready_stamp = get_ready_stamp(); + char rank_curr_stamp = rank_ctx->get_curr_stamp(); + char rank_ready_stamp = rank_ctx->get_ready_stamp(); + if (cond(local_curr_stamp, local_ready_stamp, rank_curr_stamp, + rank_ready_stamp)) { + break; + } ++_spinning_count; _mm_pause(); } - vec_op::mem_barrier(); } - void set_thread_stat(ThreadSHMStat stat) { - for (int idx = 0; idx < group_size; ++idx) { - int rank = get_swizzled_rank(idx); - shm_contexts[rank]->thread_stats[this->rank] = stat; - } + static bool check_no_buffer_conflict(char local_curr_stamp, + char local_ready_stamp, + char rank_curr_stamp, + char rank_ready_stamp) { + char temp = rank_curr_stamp + 2; + return local_curr_stamp != temp; } - void set_thread_stat(int target_rank, ThreadSHMStat stat) { - for (int idx = 0; idx < group_size; ++idx) { - int rank = get_swizzled_rank(idx); - shm_contexts[rank]->thread_stats[target_rank] = stat; - } - } - - // barrier for all ranks in the group, used for all2all ops - // DONE -> THREAD_READY -> SHM_DATA_READY -> DONE -> ... - void barrier(ThreadSHMStat next_stat) { - if (next_stat == ThreadSHMStat::THREAD_READY) { - set_thread_stat(ThreadSHMStat::THREAD_READY); - wait_for_all(ThreadSHMStat::DONE); - } else if (next_stat == ThreadSHMStat::SHM_DATA_READY) { - set_thread_stat(ThreadSHMStat::SHM_DATA_READY); - wait_for_all(ThreadSHMStat::THREAD_READY); - } else if (next_stat == ThreadSHMStat::DONE) { - set_thread_stat(ThreadSHMStat::DONE); - wait_for_all(ThreadSHMStat::SHM_DATA_READY); - } else { - TORCH_CHECK(false, "Invalid next_stat to barrier."); - } + static bool check_stamp_ready(char local_curr_stamp, char local_ready_stamp, + char rank_curr_stamp, char rank_ready_stamp) { + char temp = local_curr_stamp + 1; + return (local_curr_stamp == rank_ready_stamp) || (temp == rank_ready_stamp); } std::string to_string() const { @@ -164,7 +176,7 @@ class SHMManager { const int group_size) : _rank(rank), _group_size(group_size), - _thread_num(std::min(torch::get_num_threads(), MAX_THREAD_NUM)), + _thread_num(torch::get_num_threads()), _shm_names({""}), _shared_mem_ptrs({nullptr}), _shm_ctx(nullptr) { @@ -326,7 +338,8 @@ void shm_cc_loop(ThreadSHMContext* ctx, int64_t elem_num, F&& inner_func) { (total_units_num + thread_num - 1) / thread_num; int64_t per_unit_elem_num = MIN_THREAD_PROCESS_SIZE / sizeof(scalar_t); int64_t max_per_thread_iteration_elem_num = - PER_THREAD_SHM_BUFFER_BYTES / sizeof(scalar_t); + (PER_THREAD_SHM_BUFFER_BYTES >> 1) / + sizeof(scalar_t); // Note: double buffer int64_t per_thread_elem_num = per_unit_elem_num * per_thread_units_num; #pragma omp parallel for schedule(static, 1) @@ -336,10 +349,13 @@ void shm_cc_loop(ThreadSHMContext* ctx, int64_t elem_num, F&& inner_func) { int64_t curr_elem_num = std::min(max_per_thread_iteration_elem_num, end - offset); ThreadSHMContext* thread_ctx = ctx + i; + bool fast_mode = ((end - offset) <= max_per_thread_iteration_elem_num); while (curr_elem_num > 0) { - inner_func(thread_ctx, offset, curr_elem_num); + inner_func(thread_ctx, offset, curr_elem_num, fast_mode); + thread_ctx->next_stamp(); + thread_ctx->next_buffer(); offset += max_per_thread_iteration_elem_num; curr_elem_num = std::min(max_per_thread_iteration_elem_num, end - offset); } @@ -397,7 +413,7 @@ void all_reduce_sum_impl(ThreadSHMContext* ctx, scalar_t* data, shm_cc_ops::shm_cc_loop( ctx, elem_num, [&](ThreadSHMContext* thread_ctx, int64_t data_offset, - int64_t data_elem_num) { + int64_t data_elem_num, bool fast_mode) { int rank = thread_ctx->rank; scalar_t* thread_shm_ptr = thread_ctx->get_thread_shm_ptr(rank); @@ -410,16 +426,17 @@ void all_reduce_sum_impl(ThreadSHMContext* ctx, scalar_t* data, thread_ctx->get_swizzled_rank(idx + 1)); }); - thread_ctx->barrier(ThreadSHMStat::THREAD_READY); + if (!fast_mode) { + thread_ctx->wait_for_all(ThreadSHMContext::check_no_buffer_conflict); + } shm_cc_ops::memcpy_to_shm(thread_shm_ptr, thread_data_ptr, thread_data_elem_num); - - thread_ctx->barrier(ThreadSHMStat::SHM_DATA_READY); - + thread_ctx->commit_ready_stamp(); int64_t aligned_data_elem_num = (data_elem_num / vec_elem_num) * vec_elem_num; int64_t i = 0; + thread_ctx->wait_for_all(ThreadSHMContext::check_stamp_ready); #pragma GCC unroll 4 for (; i < aligned_data_elem_num; i += vec_elem_num) { vec_t local_data(thread_data_ptr + i); // load from cache @@ -447,8 +464,6 @@ void all_reduce_sum_impl(ThreadSHMContext* ctx, scalar_t* data, reduced_data.save(thread_data_ptr + i, data_elem_num - aligned_data_elem_num); } - - thread_ctx->barrier(ThreadSHMStat::DONE); }); return; @@ -488,18 +503,18 @@ void shm_gather_impl(ThreadSHMContext* ctx, scalar_t* data, size_t elem_num, shm_cc_ops::shm_cc_loop( ctx, elem_num, [&](ThreadSHMContext* thread_ctx, int64_t data_offset, - int64_t data_elem_num) { + int64_t data_elem_num, bool fast_mode) { int rank = thread_ctx->rank; scalar_t* thread_shm_ptr = thread_ctx->get_thread_shm_ptr(rank); - thread_ctx->barrier(ThreadSHMStat::THREAD_READY); - - shm_cc_ops::memcpy_to_shm(thread_shm_ptr, data + data_offset, - data_elem_num * sizeof(scalar_t)); - - thread_ctx->barrier(ThreadSHMStat::SHM_DATA_READY); + if (!fast_mode) { + thread_ctx->wait_for_all(ThreadSHMContext::check_no_buffer_conflict); + } + shm_cc_ops::memcpy(thread_shm_ptr, data + data_offset, + data_elem_num * sizeof(scalar_t)); + thread_ctx->commit_ready_stamp(); if (rank == dst) { shm_cc_ops::memcpy(outputs[rank] + data_offset, data + data_offset, data_elem_num * sizeof(scalar_t)); @@ -508,12 +523,12 @@ void shm_gather_impl(ThreadSHMContext* ctx, scalar_t* data, size_t elem_num, scalar_t* src_ptr = thread_ctx->get_thread_shm_ptr(src_rank); // shm scalar_t* dst_ptr = outputs[src_rank] + data_offset; - shm_cc_ops::memcpy_from_shm(dst_ptr, src_ptr, - data_elem_num * sizeof(scalar_t)); + thread_ctx->wait_for_one(src_rank, + ThreadSHMContext::check_stamp_ready); + shm_cc_ops::memcpy(dst_ptr, src_ptr, + data_elem_num * sizeof(scalar_t)); } } - - thread_ctx->barrier(ThreadSHMStat::DONE); }); return; @@ -599,7 +614,7 @@ struct TensorListMeta { int8_t _padding[40]; }; -void shm_send_tensor_list_impl(ThreadSHMContext* ctx, +void shm_send_tensor_list_impl(ThreadSHMContext* ctx, int64_t dst, const std::vector& tensor_list) { CPU_KERNEL_GUARD_IN(shm_send_tensor_list_impl) std::vector tensor_list_with_metadata; @@ -620,12 +635,11 @@ void shm_send_tensor_list_impl(ThreadSHMContext* ctx, shm_cc_ops::shm_cc_loop( ctx, metadata->total_bytes, [&](ThreadSHMContext* thread_ctx, int64_t data_offset, - int64_t data_elem_num) { + int64_t data_elem_num, bool fast_mode) { int rank = thread_ctx->rank; - // Wait until the receiver set the stat to DONE - thread_ctx->wait_for_one(rank, ThreadSHMStat::SHM_DATA_READY); - int64_t curr_shm_offset = 0; + thread_ctx->wait_for_one(dst, + ThreadSHMContext::check_no_buffer_conflict); while (curr_shm_offset < data_elem_num) { MemPiece frag = metadata->get_data(data_offset + curr_shm_offset); frag.size = std::min(frag.size, data_elem_num - curr_shm_offset); @@ -634,8 +648,7 @@ void shm_send_tensor_list_impl(ThreadSHMContext* ctx, frag.ptr, frag.size); curr_shm_offset += frag.size; } - - thread_ctx->set_thread_stat(rank, ThreadSHMStat::SHM_DATA_READY); + thread_ctx->commit_ready_stamp(); }); } @@ -646,8 +659,7 @@ std::vector shm_recv_tensor_list_impl(ThreadSHMContext* ctx, torch::Tensor metadata_tensor = torch::empty({sizeof(TensorListMeta)}, options); - // Wait until the sender set the stat of the thread 0 to SHM_DATA_READY - ctx->wait_for_one(src, ThreadSHMStat::DONE); + ctx->wait_for_one(src, ThreadSHMContext::check_stamp_ready); shm_cc_ops::memcpy(metadata_tensor.data_ptr(), ctx->get_thread_shm_ptr(src), sizeof(TensorListMeta)); @@ -664,9 +676,8 @@ std::vector shm_recv_tensor_list_impl(ThreadSHMContext* ctx, shm_cc_ops::shm_cc_loop( ctx, metadata.total_bytes, [&](ThreadSHMContext* thread_ctx, int64_t data_offset, - int64_t data_elem_num) { - // Wait until the sender set the stat to SHM_DATA_READY - thread_ctx->wait_for_one(src, ThreadSHMStat::DONE); + int64_t data_elem_num, bool fast_mode) { + ctx->wait_for_one(src, ThreadSHMContext::check_stamp_ready); int64_t curr_shm_offset = 0; while (curr_shm_offset < data_elem_num) { MemPiece frag = metadata.get_data(data_offset + curr_shm_offset); @@ -677,8 +688,6 @@ std::vector shm_recv_tensor_list_impl(ThreadSHMContext* ctx, frag.size); curr_shm_offset += frag.size; } - - thread_ctx->set_thread_stat(src, ThreadSHMStat::DONE); }); std::vector tensor_list; @@ -756,7 +765,8 @@ void shm_send_tensor_list(int64_t handle, int64_t dst) { CPU_KERNEL_GUARD_IN(shm_send_tensor_list) shm_send_tensor_list_impl( - SHMManager::get_singleton_instance(handle)->get_shm_ctx(), tensor_list); + SHMManager::get_singleton_instance(handle)->get_shm_ctx(), dst, + tensor_list); CPU_KERNEL_GUARD_OUT(shm_send_tensor_list) } @@ -778,4 +788,4 @@ std::string join_shm_manager(int64_t handle, const std::string& name) { TORCH_CHECK(shm_manager); shm_manager->join(name); return shm_manager->get_shm_ctx()->to_string(); -} \ No newline at end of file +} diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index 60304d229a8f5..ebfc81f858367 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -50,6 +50,27 @@ void shm_send_tensor_list(int64_t handle, std::vector shm_recv_tensor_list(int64_t handle, int64_t src); +at::Tensor weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, + const std::optional& bias, + bool is_vnni); + +at::Tensor convert_weight_packed(at::Tensor& weight); + +at::Tensor fused_experts_cpu( + at::Tensor& hidden_states, at::Tensor& w1, at::Tensor& w2, + at::Tensor& topk_weights, at::Tensor& topk_ids, bool inplace, + bool use_int8_w8a8, bool use_fp8_w8a16, + const std::optional& w1_scale, + const std::optional& w2_scale, + const std::optional> block_size, + const std::optional& a1_scale, + const std::optional& a2_scale, bool is_vnni); + +at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2, + at::Tensor& scales2, + const std::optional& bias, + at::ScalarType out_dtype, bool is_vnni); + TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // vLLM custom ops @@ -214,6 +235,28 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("shm_recv_tensor_list(int handle, int src) -> Tensor[](a)", &shm_recv_tensor_list); #endif + + // sgl-kernels +#if defined(__AVX512BF16__) && defined(__AVX512F__) && defined(__AVX512VNNI__) + ops.def( + "weight_packed_linear(Tensor(a0!) mat1, Tensor(a1!) mat2, Tensor(a2!)? " + "bias, bool is_vnni) -> Tensor"); + ops.impl("weight_packed_linear", torch::kCPU, &weight_packed_linear); + ops.def("convert_weight_packed(Tensor! weight) -> Tensor"); + ops.impl("convert_weight_packed", torch::kCPU, &convert_weight_packed); + ops.def( + "fused_experts_cpu(Tensor! hidden_states, Tensor w1, Tensor w2, Tensor " + "topk_weights, Tensor topk_ids, bool inplace, bool use_int8_w8a8, bool " + "use_fp8_w8a16, Tensor? w1_scale, Tensor? w2_scale, SymInt[]? " + "block_size, Tensor? a1_scale, Tensor? a2_scale, bool is_vnni) -> " + "Tensor"); + ops.impl("fused_experts_cpu", torch::kCPU, &fused_experts_cpu); + ops.def( + "int8_scaled_mm_with_quant(Tensor mat1, Tensor mat2, Tensor scales2, " + "Tensor? bias, ScalarType out_dtype, bool is_vnni) -> Tensor"); + ops.impl("int8_scaled_mm_with_quant", torch::kCPU, + &int8_scaled_mm_with_quant); +#endif } TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { diff --git a/docs/getting_started/installation/cpu.md b/docs/getting_started/installation/cpu.md index 370b854def0f3..5f2d0dbe27d34 100644 --- a/docs/getting_started/installation/cpu.md +++ b/docs/getting_started/installation/cpu.md @@ -118,6 +118,7 @@ vLLM CPU backend supports the following vLLM features: - `VLLM_CPU_OMP_THREADS_BIND`: specify the CPU cores dedicated to the OpenMP threads. For example, `VLLM_CPU_OMP_THREADS_BIND=0-31` means there will be 32 OpenMP threads bound on 0-31 CPU cores. `VLLM_CPU_OMP_THREADS_BIND=0-31|32-63` means there will be 2 tensor parallel processes, 32 OpenMP threads of rank0 are bound on 0-31 CPU cores, and the OpenMP threads of rank1 are bound on 32-63 CPU cores. By setting to `auto`, the OpenMP threads of each rank are bound to the CPU cores in each NUMA node. By setting to `all`, the OpenMP threads of each rank uses all CPU cores available on the system. Default value is `auto`. - `VLLM_CPU_NUM_OF_RESERVED_CPU`: specify the number of CPU cores which are not dedicated to the OpenMP threads for each rank. The variable only takes effect when VLLM_CPU_OMP_THREADS_BIND is set to `auto`. Default value is `0`. - `VLLM_CPU_MOE_PREPACK`: whether to use prepack for MoE layer. This will be passed to `ipex.llm.modules.GatedMLPMOE`. Default is `1` (True). On unsupported CPUs, you might need to set this to `0` (False). +- `VLLM_CPU_SGL_KERNEL` (Experimental): whether to use small-batch optimized kernels for linear layer and MoE layer, especially for low-latency requirements like online serving. The kernels require AMX instruction set, BFloat16 weight type and weight shapes divisible by 32. Default is `0` (False). ## Performance tips diff --git a/tests/models/language/generation/test_common.py b/tests/models/language/generation/test_common.py index f656f90c4bd37..7d7a62eec118a 100644 --- a/tests/models/language/generation/test_common.py +++ b/tests/models/language/generation/test_common.py @@ -78,7 +78,7 @@ AITER_MODEL_LIST = [ ), pytest.param( "Qwen/Qwen2.5-0.5B-Instruct", # qwen2 - marks=[pytest.mark.core_model], + marks=[pytest.mark.core_model, pytest.mark.cpu_model], ), pytest.param( "Qwen/Qwen3-8B", # qwen (text-only) @@ -87,6 +87,7 @@ AITER_MODEL_LIST = [ pytest.param("bigcode/starcoder2-3b"), # starcoder2 pytest.param( "TitanML/tiny-mixtral", # mixtral + marks=[pytest.mark.core_model, pytest.mark.cpu_model], ) ]) @pytest.mark.parametrize("max_tokens", [32]) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 51900de1cc099..36a0395ccdc93 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1850,3 +1850,52 @@ def cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor, torch.ops._C.cutlass_mla_decode(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, scale) return out + + +if hasattr(torch.ops._C, "weight_packed_linear"): + + @register_fake("_C::weight_packed_linear") + def weight_packed_linear_fake(mat1: torch.Tensor, mat2: torch.Tensor, + bias: Optional[torch.Tensor], + is_vnni: bool) -> torch.Tensor: + return torch.empty((mat1.size(0), mat2.size(0)), + dtype=mat1.dtype, + device=mat2.device) + + +if hasattr(torch.ops._C, "fused_experts_cpu"): + + @register_fake("_C::fused_experts_cpu") + def fused_experts_cpu_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool, + use_int8_w8a8: bool, + use_fp8_w8a16: bool, + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + block_size: Optional[list[int]], + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + is_vnni: bool, + ) -> torch.Tensor: + return torch.empty_like(hidden_states) + + +if hasattr(torch.ops._C, "int8_scaled_mm_with_quant"): + + @register_fake("_C::int8_scaled_mm_with_quant") + def int8_scaled_mm_with_quant_fake( + mat1: torch.Tensor, + mat2: torch.Tensor, + scales2: torch.Tensor, + bias: Optional[torch.Tensor], + out_dtype: torch.dtype, + is_vnni: bool, + ) -> torch.Tensor: + M = mat1.size(0) + N = mat2.size(0) + return torch.empty((M, N), dtype=out_dtype) diff --git a/vllm/envs.py b/vllm/envs.py index a3f19c7ee5c70..c73dbb0a8446f 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -46,6 +46,7 @@ if TYPE_CHECKING: VLLM_CPU_OMP_THREADS_BIND: str = "" VLLM_CPU_NUM_OF_RESERVED_CPU: int = 0 VLLM_CPU_MOE_PREPACK: bool = True + VLLM_CPU_SGL_KERNEL: bool = False VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache") VLLM_XLA_CHECK_RECOMPILATION: bool = False VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024 @@ -447,6 +448,10 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_CPU_MOE_PREPACK": lambda: bool(int(os.getenv("VLLM_CPU_MOE_PREPACK", "1"))), + # (CPU backend only) whether to use SGL kernels, optimized for small batch. + "VLLM_CPU_SGL_KERNEL": + lambda: bool(int(os.getenv("VLLM_CPU_SGL_KERNEL", "0"))), + # If the env var is set, then all workers will execute as separate # processes from the engine, and we use the same mechanism to trigger # execution on all workers. diff --git a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py new file mode 100644 index 0000000000000..68ce6bcccb5d4 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py @@ -0,0 +1,214 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Callable, Optional + +import torch + +from vllm import envs + + +class IPEXFusedMOE: + + def __init__(self, layer: torch.nn.Module) -> None: + import intel_extension_for_pytorch as ipex + layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE( + layer.w13_weight, + layer.w2_weight, + use_prepack=envs.VLLM_CPU_MOE_PREPACK, + ) + + def __call__( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + ) -> torch.Tensor: + assert activation == "silu", f"{activation} is not supported." + assert not apply_router_weight_on_input + return layer.ipex_fusion( + x, + use_grouped_topk, + top_k, + router_logits, + renormalize, + topk_group, + num_expert_group, + custom_routing_function, + scoring_func, + e_score_correction_bias, + ) + + +class SGLFusedMOE: + + def __init__(self, layer: torch.nn.Module) -> None: + pass + + @staticmethod + def _grouped_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int = 0, + topk_group: int = 0, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None + ) -> tuple[torch.Tensor, torch.Tensor]: + assert hidden_states.shape[0] == gating_output.shape[0], ( + "Number of tokens mismatch") + + gating_output = gating_output.float() + if scoring_func == "softmax": + scores = torch.softmax(gating_output, dim=-1) + elif scoring_func == "sigmoid": + scores = gating_output.sigmoid() + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + + num_token = scores.shape[0] + if e_score_correction_bias is not None: + # Store original scores before applying correction bias. We use + # biased scores for expert selection but original scores for + # routing weights + original_scores = scores + scores = scores + e_score_correction_bias.unsqueeze(0) + group_scores = (scores.view(num_token, num_expert_group, + -1).topk(2, dim=-1)[0].sum(dim=-1)) + else: + group_scores = scores.view(num_token, num_expert_group, + -1).max(dim=-1).values # [n, n_group] + group_idx = torch.topk(group_scores, + k=topk_group, + dim=-1, + sorted=False)[1] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = group_mask.unsqueeze(-1).expand( + num_token, num_expert_group, + scores.shape[-1] // num_expert_group).reshape(num_token, + -1) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), + float("-inf")) # [n, e] + + if e_score_correction_bias is not None: + topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1] + # Use original unbiased scores for the routing weights + topk_weights = original_scores.gather(1, topk_ids) + else: + topk_weights, topk_ids = torch.topk(tmp_scores, + k=topk, + dim=-1, + sorted=False) + + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, + keepdim=True) + + return topk_weights, topk_ids.to(torch.int32) + + @staticmethod + def _select_experts( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + use_grouped_topk: bool, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + # DeekSeekv2 uses grouped_top_k + if use_grouped_topk: + assert topk_group is not None + assert num_expert_group is not None + topk_weights, topk_ids = SGLFusedMOE._grouped_topk( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + num_expert_group=num_expert_group, + topk_group=topk_group, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + elif custom_routing_function is None: + assert scoring_func == "softmax" + topk_weights = torch.nn.functional.softmax(router_logits, + dim=1, + dtype=torch.float32) + topk_weights, topk_ids = torch.topk(topk_weights, top_k, dim=-1) + if renormalize: + topk_weights /= topk_weights.sum(dim=-1, keepdim=True) + topk_ids = topk_ids.to(torch.int32) + else: + topk_weights, topk_ids = custom_routing_function( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize) + + return topk_weights, topk_ids + + def __call__( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + ) -> torch.Tensor: + assert activation == "silu", f"{activation} is not supported." + assert not apply_router_weight_on_input + topk_weights, topk_ids = SGLFusedMOE._select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + ) + + torch.ops._C.fused_experts_cpu( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + True, + False, + False, + None, + None, + None, + None, + None, + True, + ) + return x diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index e6f555d315d8e..d6ead084af99c 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -550,12 +550,23 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): if current_platform.is_cpu(): if current_platform.get_cpu_architecture() == CpuArchEnum.X86: - import intel_extension_for_pytorch as ipex - layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE( - layer.w13_weight, - layer.w2_weight, - use_prepack=envs.VLLM_CPU_MOE_PREPACK, - ) + from vllm.model_executor.layers.fused_moe import cpu_fused_moe + dtype = layer.w13_weight.dtype + if (envs.VLLM_CPU_SGL_KERNEL + and torch._C._cpu._is_amx_tile_supported() + and dtype == torch.bfloat16): + packed_w13_weight = torch.ops._C.convert_weight_packed( + layer.w13_weight) + assert packed_w13_weight.size() == layer.w13_weight.size() + layer.w13_weight.copy_(packed_w13_weight) + del packed_w13_weight + packed_w2_weight = torch.ops._C.convert_weight_packed( + layer.w2_weight) + assert packed_w2_weight.size() == layer.w2_weight.size() + layer.w2_weight.copy_(packed_w2_weight) + layer.cpu_fused_moe = cpu_fused_moe.SGLFusedMOE(layer) + else: + layer.cpu_fused_moe = cpu_fused_moe.IPEXFusedMOE(layer) else: raise NotImplementedError("CPU MOE only supports x86 arch.") @@ -673,13 +684,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, - activation: str = "silu", apply_router_weight_on_input: bool = False, + activation: str = "silu", **kwargs, ): - assert activation == "silu", f"{activation} is not supported." - assert apply_router_weight_on_input is False - return layer.ipex_fusion( + return layer.cpu_fused_moe( + layer, x, use_grouped_topk, top_k, @@ -687,9 +697,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): renormalize, topk_group, num_expert_group, + global_num_experts, + expert_map, custom_routing_function, scoring_func, e_score_correction_bias, + apply_router_weight_on_input, + activation, ) def forward_hpu( @@ -764,7 +778,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): expert_map=expert_map, renormalize=renormalize) - forward_native = forward_tpu if current_platform.is_tpu() else forward_cuda + if current_platform.is_tpu(): + forward_native = forward_tpu + elif current_platform.is_cpu(): + forward_native = forward_cpu + else: + forward_native = forward_cuda def determine_expert_map( diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 588aa8deb1832..a05ae0edbd775 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -9,6 +9,7 @@ import torch import torch.nn as nn from torch.nn.parameter import Parameter, UninitializedParameter +from vllm import envs from vllm.distributed import (divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, split_tensor_along_last_dim, @@ -27,6 +28,7 @@ from vllm.model_executor.parameter import (BasevLLMParameter, RowvLLMParameter) # yapf: enable from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform logger = init_logger(__name__) @@ -195,12 +197,33 @@ class UnquantizedLinearMethod(LinearMethodBase): layer.register_parameter("weight", weight) set_weight_attrs(weight, extra_weight_attrs) + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if current_platform.is_cpu() and envs.VLLM_CPU_SGL_KERNEL: + N, K = layer.weight.size() + dtype = layer.weight.dtype + if (torch._C._cpu._is_amx_tile_supported() + and dtype == torch.bfloat16 and N % 32 == 0 + and K % 32 == 0): + packed_weight = torch.ops._C.convert_weight_packed( + layer.weight) + assert packed_weight.size() == layer.weight.size() + layer.weight.copy_(packed_weight) + if layer.bias is not None: + layer.bias = Parameter(layer.bias.to(torch.float32), + requires_grad=False) + layer.use_cpu_sgl = True + else: + logger.warning( + "CPU SGL kernels require Intel AMX support," + " bfloat16 weight, IC and OC are divisible by 32.") + layer.use_cpu_sgl = False + def apply(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - return dispatch_unquantized_gemm()(x, layer.weight, bias) + return dispatch_unquantized_gemm()(layer, x, layer.weight, bias) class LinearBase(torch.nn.Module): diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index 41b5253dca048..ad4ba9c0b827a 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -63,7 +63,15 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, return logits -def rocm_unquantized_gemm(x: torch.Tensor, +def default_unquantized_gemm(layer: torch.nn.Module, + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None): + return torch.nn.functional.linear(x, weight, bias) + + +def rocm_unquantized_gemm(layer: torch.nn.Module, + x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None): from vllm.platforms.rocm import on_gfx9 @@ -89,7 +97,20 @@ def rocm_unquantized_gemm(x: torch.Tensor, return torch.nn.functional.linear(x, weight, bias) +def cpu_unquantized_gemm(layer: torch.nn.Module, + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None): + if getattr(layer, "use_cpu_sgl", False): + return torch.ops._C.weight_packed_linear(x, weight, bias, True) + else: + return torch.nn.functional.linear(x, weight, bias) + + def dispatch_unquantized_gemm() -> Callable[..., torch.Tensor]: if current_platform.is_rocm(): return rocm_unquantized_gemm - return torch.nn.functional.linear + elif current_platform.is_cpu(): + return cpu_unquantized_gemm + else: + return default_unquantized_gemm diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index 9ff3a7a7327d9..f35f969781bd1 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -43,7 +43,7 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase): layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - return dispatch_unquantized_gemm()(x, layer.weight, bias) + return dispatch_unquantized_gemm()(layer, x, layer.weight, bias) def embedding(self, layer: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor: diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 106bce162003f..dccd60f4463aa 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -194,6 +194,8 @@ class CpuPlatform(Platform): "epilogue_fusion": True, }) + if compilation_config.use_inductor: + compilation_config.custom_ops = ["none"] if vllm_config.lora_config is not None: compilation_config.level = CompilationLevel.NO_COMPILATION From 08d81f1014d174d4dd96518914c7ed9767c67a3f Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Tue, 1 Jul 2025 03:29:08 -0400 Subject: [PATCH 087/175] [Bugfix] Fix deepep tests (#20288) Signed-off-by: Varun Sundar Rabindranath Co-authored-by: Varun Sundar Rabindranath --- tests/kernels/moe/test_deepep_deepgemm_moe.py | 2 +- tests/kernels/moe/test_deepep_moe.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/kernels/moe/test_deepep_deepgemm_moe.py b/tests/kernels/moe/test_deepep_deepgemm_moe.py index 475427f439289..008406c3f1593 100644 --- a/tests/kernels/moe/test_deepep_deepgemm_moe.py +++ b/tests/kernels/moe/test_deepep_deepgemm_moe.py @@ -30,7 +30,7 @@ if has_deep_ep(): from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501 DeepEPLLPrepareAndFinalize) - from .deepep_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a + from .utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a if has_deep_gemm(): import deep_gemm diff --git a/tests/kernels/moe/test_deepep_moe.py b/tests/kernels/moe/test_deepep_moe.py index 80a36dc39712a..94947c809e3a3 100644 --- a/tests/kernels/moe/test_deepep_moe.py +++ b/tests/kernels/moe/test_deepep_moe.py @@ -31,7 +31,7 @@ if has_deep_ep(): from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501 DeepEPLLPrepareAndFinalize) - from .deepep_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a + from .utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a requires_deep_ep = pytest.mark.skipif( not has_deep_ep(), From b1c1fe35a599cfd3c0404702c65c2381b025bc6a Mon Sep 17 00:00:00 2001 From: Kebe Date: Tue, 1 Jul 2025 15:33:22 +0800 Subject: [PATCH 088/175] [Misc] remove redundant char (#20287) Signed-off-by: Kebe --- benchmarks/benchmark_serving.py | 2 +- vllm/benchmarks/serve.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 886a51e1cbd9a..9b235266dff1a 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -551,7 +551,7 @@ async def benchmark( "total_input_tokens": metrics.total_input, "total_output_tokens": metrics.total_output, "request_throughput": metrics.request_throughput, - "request_goodput:": metrics.request_goodput if goodput_config_dict else None, + "request_goodput": metrics.request_goodput if goodput_config_dict else None, "output_throughput": metrics.output_throughput, "total_token_throughput": metrics.total_token_throughput, "input_lens": [output.prompt_len for output in outputs], diff --git a/vllm/benchmarks/serve.py b/vllm/benchmarks/serve.py index 419284cca042e..8b16fea9e3d3c 100644 --- a/vllm/benchmarks/serve.py +++ b/vllm/benchmarks/serve.py @@ -498,7 +498,7 @@ async def benchmark( "total_input_tokens": metrics.total_input, "total_output_tokens": metrics.total_output, "request_throughput": metrics.request_throughput, - "request_goodput:": + "request_goodput": metrics.request_goodput if goodput_config_dict else None, "output_throughput": metrics.output_throughput, "total_token_throughput": metrics.total_token_throughput, From 96453cfa831340788ef72c42bc2a1a2b4496a27f Mon Sep 17 00:00:00 2001 From: TY-AMD Date: Tue, 1 Jul 2025 16:12:19 +0800 Subject: [PATCH 089/175] [BugFix][V1][ROCm] Triton MLA uses V0 backend on V1 engine (#19067) Signed-off-by: Tianyuan Wu --- .../attention/test_attention_selector.py | 6 +- .../attention/test_rocm_attention_selector.py | 6 +- vllm/platforms/rocm.py | 10 +++- vllm/v1/attention/backends/mla/common.py | 9 ++- vllm/v1/attention/backends/mla/triton_mla.py | 57 +++++++++++++++++++ 5 files changed, 78 insertions(+), 10 deletions(-) diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py index f3e64155703c2..a8ed749ba13b5 100644 --- a/tests/kernels/attention/test_attention_selector.py +++ b/tests/kernels/attention/test_attention_selector.py @@ -106,10 +106,8 @@ def test_env( block_size, False, use_mla=use_mla) - if use_v1 and name != "TRITON_MLA": - assert backend.get_name() == f"{name}_VLLM_V1" - else: - assert backend.get_name() == name + expected = f"{name}_VLLM_V1" if use_v1 else name + assert backend.get_name() == expected else: with pytest.raises(ValueError) as exc_info: get_attn_backend(16, diff --git a/tests/kernels/attention/test_rocm_attention_selector.py b/tests/kernels/attention/test_rocm_attention_selector.py index ed58880cc9e6c..34311b9ccd767 100644 --- a/tests/kernels/attention/test_rocm_attention_selector.py +++ b/tests/kernels/attention/test_rocm_attention_selector.py @@ -35,7 +35,8 @@ def test_selector(monkeypatch: pytest.MonkeyPatch): m.setenv(STR_BACKEND_ENV_VAR, "TRITON_MLA") backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, False, True) - assert backend.get_name() == "TRITON_MLA" + assert (backend.get_name() == "TRITON_MLA" + or backend.get_name() == "TRITON_MLA_VLLM_V1") # If attention backend is None # If use_mla is true @@ -43,7 +44,8 @@ def test_selector(monkeypatch: pytest.MonkeyPatch): m.setenv(STR_BACKEND_ENV_VAR, None) backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, False, True) - assert backend.get_name() == "TRITON_MLA" + assert (backend.get_name() == "TRITON_MLA" + or backend.get_name() == "TRITON_MLA_VLLM_V1") # change the attention backend to AITER MLA m.setenv(STR_BACKEND_ENV_VAR, "ROCM_AITER_MLA") diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 08d471d5a983c..ee53a76ceb6db 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -186,8 +186,14 @@ class RocmPlatform(Platform): if selected_backend == _Backend.TRITON_MLA: if block_size != 1: - logger.info("Using Triton MLA backend.") - return "vllm.attention.backends.triton_mla.TritonMLABackend" # noqa: E501 + if use_v1: + logger.info_once( + "Using Triton MLA backend on V1 engine.") + return ("vllm.v1.attention.backends.mla." + "triton_mla.TritonMLABackend") + else: + logger.info("Using Triton MLA backend.") + return "vllm.attention.backends.triton_mla.TritonMLABackend" # noqa: E501 else: raise ValueError( f" The selected backend, {selected_backend.name}," diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 1878ae74dbc6f..d45ec04472a69 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -640,7 +640,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): self.qk_head_dim = qk_head_dim self.v_head_dim = v_head_dim self.kv_b_proj = kv_b_proj - self.vllm_flash_attn_version = get_flash_attn_version() # Handle the differences between the flash_attn_varlen from flash_attn # and the one from vllm_flash_attn. The former is used on RoCM and the @@ -672,11 +671,17 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): maybe_padded_v = torch.nn.functional.pad( v, [0, q.shape[-1] - v.shape[-1]], value=0) + if is_vllm_fa: + kwargs["return_softmax_lse"] = return_softmax_lse + else: + # ROCm leverages the upstream flash_attn, which takes a parameter + # called "return_attn_probs" instead of return_softmax_lse + kwargs["return_attn_probs"] = return_softmax_lse + attn_out = self.flash_attn_varlen_func( q=q, k=k, v=maybe_padded_v, - return_softmax_lse=return_softmax_lse, softmax_scale=softmax_scale, **kwargs, ) diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py index e26d7909184b5..99938f22f108c 100644 --- a/vllm/v1/attention/backends/mla/triton_mla.py +++ b/vllm/v1/attention/backends/mla/triton_mla.py @@ -5,10 +5,14 @@ from typing import Any, Optional import torch +from vllm import envs from vllm.attention.backends.abstract import (AttentionType, is_quantized_kv_cache) from vllm.attention.ops.triton_decode_attention import decode_attention_fwd +from vllm.attention.ops.triton_flash_attention import triton_attention from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.triton_utils import HAS_TRITON from vllm.v1.attention.backends.mla.common import (MLACommonBackend, MLACommonImpl, MLACommonMetadata) @@ -68,6 +72,59 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): raise NotImplementedError( "TritonMLA V1 with FP8 KV cache not yet supported") + self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN + self.triton_fa_func = triton_attention if HAS_TRITON else None + + def _flash_attn_varlen_diff_headdims_rocm(self, + q, + k, + v, + softmax_scale=None, + **kwargs): + assert self.triton_fa_func is not None + + # Triton Attention requires a padded V + padded_v = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], + value=0) + # The output of triton_attention is a tuple of + # [output_tensor, encoded_softmax] where encoded_softmax is always None + output_tensor, _ = self.triton_fa_func( + q, + k, + padded_v, + None, # output + kwargs["cu_seqlens_q"], + kwargs["cu_seqlens_k"], + kwargs["max_seqlen_q"], + kwargs["max_seqlen_k"], + kwargs["causal"], + softmax_scale, + None, # bias + ) + + return output_tensor + + def _flash_attn_varlen_diff_headdims(self, + q, + k, + v, + return_softmax_lse=False, + softmax_scale=None, + **kwargs): + if current_platform.is_rocm() \ + and self.use_triton_flash_attn \ + and not return_softmax_lse: + return self._flash_attn_varlen_diff_headdims_rocm( + q, k, v, softmax_scale=softmax_scale, **kwargs) + else: + return super()._flash_attn_varlen_diff_headdims( + q, + k, + v, + return_softmax_lse=return_softmax_lse, + softmax_scale=softmax_scale, + **kwargs) + def _forward_decode( self, q_nope: torch.Tensor, From 787b13389e2c0b114074f0a0d715eeb6c0a2b0c5 Mon Sep 17 00:00:00 2001 From: Reid <61492567+reidliu41@users.noreply.github.com> Date: Tue, 1 Jul 2025 16:18:09 +0800 Subject: [PATCH 090/175] [doc] fix the incorrect logo in dark mode (#20289) Signed-off-by: reidliu41 --- docs/README.md | 3 ++- docs/mkdocs/stylesheets/extra.css | 9 +++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/docs/README.md b/docs/README.md index 9fb3137b31928..e1d1046951a59 100644 --- a/docs/README.md +++ b/docs/README.md @@ -1,7 +1,8 @@ # Welcome to vLLM
- ![](./assets/logos/vllm-logo-text-light.png){ align="center" alt="vLLM" class="no-scaled-link" width="60%" } + ![](./assets/logos/vllm-logo-text-light.png){ align="center" alt="vLLM Light" class="logo-light" width="60%" } + ![](./assets/logos/vllm-logo-text-dark.png){ align="center" alt="vLLM Dark" class="logo-dark" width="60%" }

diff --git a/docs/mkdocs/stylesheets/extra.css b/docs/mkdocs/stylesheets/extra.css index 248711f491b9d..892013c1cddfa 100644 --- a/docs/mkdocs/stylesheets/extra.css +++ b/docs/mkdocs/stylesheets/extra.css @@ -134,3 +134,12 @@ body[data-md-color-scheme="slate"] .md-nav__item--section > label.md-nav__link . opacity: 0.9; transform: translateY(2px); } + +/* For logo css */ +[data-md-color-scheme="default"] .logo-dark { + display: none; +} + +[data-md-color-scheme="slate"] .logo-light { + display: none; +} From c05596f1a350f3d993c467959ed02492141c2527 Mon Sep 17 00:00:00 2001 From: Lionel Villard Date: Tue, 1 Jul 2025 05:10:28 -0400 Subject: [PATCH 091/175] [Perf] Validate @config in pre-commit instead of dynamically (#20200) Signed-off-by: Lionel Villard --- .pre-commit-config.yaml | 7 ++ tests/test_config.py | 35 +----- tests/tools/__init__.py | 0 tests/tools/test_config_validator.py | 49 +++++++++ tools/validate_config.py | 158 +++++++++++++++++++++++++++ vllm/config.py | 28 +---- 6 files changed, 220 insertions(+), 57 deletions(-) create mode 100644 tests/tools/__init__.py create mode 100644 tests/tools/test_config_validator.py create mode 100644 tools/validate_config.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 15ef5defff69e..d962252eb3dd8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -160,6 +160,13 @@ repos: types: [python] pass_filenames: false additional_dependencies: [pathspec, regex] + - id: validate-config + name: Validate configuration has default values and that each field has a docstring + entry: python tools/validate_config.py + language: python + types: [python] + pass_filenames: true + files: vllm/config.py|tests/test_config.py # Keep `suggestion` last - id: suggestion name: Suggestion diff --git a/tests/test_config.py b/tests/test_config.py index 5d5c4453d30d2..cb7654c26afc8 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -2,49 +2,16 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import MISSING, Field, asdict, dataclass, field -from typing import Literal, Union import pytest from vllm.compilation.backends import VllmBackend from vllm.config import (LoadConfig, ModelConfig, PoolerConfig, VllmConfig, - config, get_field) + get_field) from vllm.model_executor.layers.pooler import PoolingType from vllm.platforms import current_platform -class _TestConfig1: - pass - - -@dataclass -class _TestConfig2: - a: int - """docstring""" - - -@dataclass -class _TestConfig3: - a: int = 1 - - -@dataclass -class _TestConfig4: - a: Union[Literal[1], Literal[2]] = 1 - """docstring""" - - -@pytest.mark.parametrize(("test_config", "expected_error"), [ - (_TestConfig1, "must be a dataclass"), - (_TestConfig2, "must have a default"), - (_TestConfig3, "must have a docstring"), - (_TestConfig4, "must use a single Literal"), -]) -def test_config(test_config, expected_error): - with pytest.raises(Exception, match=expected_error): - config(test_config) - - def test_compile_config_repr_succeeds(): # setup: VllmBackend mutates the config object config = VllmConfig() diff --git a/tests/tools/__init__.py b/tests/tools/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/tools/test_config_validator.py b/tests/tools/test_config_validator.py new file mode 100644 index 0000000000000..b0475894a114e --- /dev/null +++ b/tests/tools/test_config_validator.py @@ -0,0 +1,49 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import ast + +import pytest + +from tools.validate_config import validate_ast + +_TestConfig1 = ''' +@config +class _TestConfig1: + pass +''' + +_TestConfig2 = ''' +@config +@dataclass +class _TestConfig2: + a: int + """docstring""" +''' + +_TestConfig3 = ''' +@config +@dataclass +class _TestConfig3: + a: int = 1 +''' + +_TestConfig4 = ''' +@config +@dataclass +class _TestConfig4: + a: Union[Literal[1], Literal[2]] = 1 + """docstring""" +''' + + +@pytest.mark.parametrize(("test_config", "expected_error"), [ + (_TestConfig1, "must be a dataclass"), + (_TestConfig2, "must have a default"), + (_TestConfig3, "must have a docstring"), + (_TestConfig4, "must use a single Literal"), +]) +def test_config(test_config, expected_error): + tree = ast.parse(test_config) + with pytest.raises(Exception, match=expected_error): + validate_ast(tree) diff --git a/tools/validate_config.py b/tools/validate_config.py new file mode 100644 index 0000000000000..8b1e955c653d7 --- /dev/null +++ b/tools/validate_config.py @@ -0,0 +1,158 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Ensures all fields in a config dataclass have default values +and that each field has a docstring. +""" + +import ast +import inspect +import sys + + +def get_attr_docs(cls_node: ast.ClassDef) -> dict[str, str]: + """ + Get any docstrings placed after attribute assignments in a class body. + + Adapted from https://davidism.com/attribute-docstrings/ + https://davidism.com/mit-license/ + """ + + def pairwise(iterable): + """ + Manually implement https://docs.python.org/3/library/itertools.html#itertools.pairwise + + Can be removed when Python 3.9 support is dropped. + """ + iterator = iter(iterable) + a = next(iterator, None) + + for b in iterator: + yield a, b + a = b + + out = {} + + # Consider each pair of nodes. + for a, b in pairwise(cls_node.body): + # Must be an assignment then a constant string. + if (not isinstance(a, (ast.Assign, ast.AnnAssign)) + or not isinstance(b, ast.Expr) + or not isinstance(b.value, ast.Constant) + or not isinstance(b.value.value, str)): + continue + + doc = inspect.cleandoc(b.value.value) + + # An assignment can have multiple targets (a = b = v), but an + # annotated assignment only has one target. + targets = a.targets if isinstance(a, ast.Assign) else [a.target] + + for target in targets: + # Must be assigning to a plain name. + if not isinstance(target, ast.Name): + continue + + out[target.id] = doc + + return out + + +class ConfigValidator(ast.NodeVisitor): + + def __init__(self): + ... + + def visit_ClassDef(self, node): + # Validate class with both @config and @dataclass decorators + decorators = [ + id for d in node.decorator_list if (isinstance(d, ast.Name) and ( + (id := d.id) == 'config' or id == 'dataclass')) or + (isinstance(d, ast.Call) and (isinstance(d.func, ast.Name) and + (id := d.func.id) == 'dataclass')) + ] + + if set(decorators) == {'config', 'dataclass'}: + validate_class(node) + elif set(decorators) == {'config'}: + fail( + f"Class {node.name} with config decorator must be a dataclass.", + node) + + self.generic_visit(node) + + +def validate_class(class_node: ast.ClassDef): + attr_docs = get_attr_docs(class_node) + + for stmt in class_node.body: + # A field is defined as a class variable that has a type annotation. + if isinstance(stmt, ast.AnnAssign): + # Skip ClassVar + # see https://docs.python.org/3/library/dataclasses.html#class-variables + if isinstance(stmt.annotation, ast.Subscript) and isinstance( + stmt.annotation.value, + ast.Name) and stmt.annotation.value.id == "ClassVar": + continue + + if isinstance(stmt.target, ast.Name): + field_name = stmt.target.id + if stmt.value is None: + fail( + f"Field '{field_name}' in {class_node.name} must have " + "a default value.", stmt) + + if field_name not in attr_docs: + fail( + f"Field '{field_name}' in {class_node.name} must have " + "a docstring.", stmt) + + if isinstance(stmt.annotation, ast.Subscript) and \ + isinstance(stmt.annotation.value, ast.Name) \ + and stmt.annotation.value.id == "Union" and \ + isinstance(stmt.annotation.slice, ast.Tuple): + args = stmt.annotation.slice.elts + literal_args = [ + arg for arg in args + if isinstance(arg, ast.Subscript) and isinstance( + arg.value, ast.Name) and arg.value.id == "Literal" + ] + if len(literal_args) > 1: + fail( + f"Field '{field_name}' in {class_node.name} must " + "use a single " + "Literal type. Please use 'Literal[Literal1, " + "Literal2]' instead of 'Union[Literal1, Literal2]'" + ".", stmt) + + +def validate_ast(tree: ast.stmt): + ConfigValidator().visit(tree) + + +def validate_file(file_path: str): + try: + print(f"validating {file_path} config dataclasses ", end="") + with open(file_path, encoding="utf-8") as f: + source = f.read() + + tree = ast.parse(source, filename=file_path) + validate_ast(tree) + except ValueError as e: + print(e) + SystemExit(2) + else: + print("✅") + + +def fail(message: str, node: ast.stmt): + raise ValueError(f"❌ line({node.lineno}): {message}") + + +def main(): + for filename in sys.argv[1:]: + validate_file(filename) + + +if __name__ == "__main__": + main() diff --git a/vllm/config.py b/vllm/config.py index 46a5bf34f66e4..6412e6e293b45 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -18,7 +18,7 @@ from functools import cached_property from importlib.util import find_spec from pathlib import Path from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional, - Protocol, TypeVar, Union, cast, get_args, get_origin) + Protocol, TypeVar, Union, cast, get_args) import regex as re import torch @@ -193,28 +193,10 @@ def config(cls: ConfigT) -> ConfigT: (i.e. `ConfigT(**json.loads(cli_arg))`). However, if a particular `ConfigT` requires custom construction from CLI (i.e. `CompilationConfig`), it can have a `from_cli` method, which will be called instead. + + Config validation is performed by the tools/validate_config.py + script, which is invoked during the pre-commit checks. """ - if not is_dataclass(cls): - raise TypeError("The decorated class must be a dataclass.") - attr_docs = get_attr_docs(cls) - for f in fields(cls): - if f.init and f.default is MISSING and f.default_factory is MISSING: - raise ValueError( - f"Field '{f.name}' in {cls.__name__} must have a default value." - ) - - if f.name not in attr_docs: - raise ValueError( - f"Field '{f.name}' in {cls.__name__} must have a docstring.") - - if get_origin(f.type) is Union: - args = get_args(f.type) - literal_args = [arg for arg in args if get_origin(arg) is Literal] - if len(literal_args) > 1: - raise ValueError( - f"Field '{f.name}' in {cls.__name__} must use a single " - "Literal type. Please use 'Literal[Literal1, Literal2]' " - "instead of 'Union[Literal1, Literal2]'.") return cls @@ -1798,7 +1780,7 @@ class ParallelConfig: eplb_step_interval: int = 3000 """ Interval for rearranging experts in expert parallelism. - + Note that if this is greater than the EPLB window size, only the metrics of the last `eplb_window_size` steps will be used for rearranging experts. """ From 9025a9a7050253678431b2c20e6dd0be55f0dcc2 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 1 Jul 2025 06:20:34 -0400 Subject: [PATCH 092/175] [Quant] [Bugfix] Fix quantization config matching with `hf_to_vllm_mapper` (#20046) --- .../test_register_quantization_config.py | 1 + vllm/lora/models.py | 2 +- vllm/lora/worker_manager.py | 5 +--- .../layers/quantization/base_config.py | 13 +++++++++++ .../layers/quantization/bitblas.py | 1 + .../compressed_tensors/compressed_tensors.py | 17 +++++++++++++- .../model_executor/layers/quantization/fp8.py | 10 +++++++- .../layers/quantization/gptq_bitblas.py | 1 + .../layers/quantization/marlin.py | 2 ++ .../layers/quantization/modelopt.py | 1 + .../layers/quantization/torchao.py | 1 + vllm/model_executor/model_loader/utils.py | 22 ++++++++++-------- vllm/model_executor/models/interfaces.py | 23 ++++++++++++++++--- vllm/model_executor/models/qwen2_5_vl.py | 14 +++++------ vllm/model_executor/models/transformers.py | 1 + vllm/model_executor/models/utils.py | 15 +++++++++++- vllm/model_executor/utils.py | 7 ++++-- 17 files changed, 107 insertions(+), 29 deletions(-) diff --git a/tests/quantization/test_register_quantization_config.py b/tests/quantization/test_register_quantization_config.py index 42081a8c68cdc..6c541fdbeeae2 100644 --- a/tests/quantization/test_register_quantization_config.py +++ b/tests/quantization/test_register_quantization_config.py @@ -53,6 +53,7 @@ class CustomQuantConfig(QuantizationConfig): def __init__(self, num_bits: int = 8) -> None: """Initialize the quantization config.""" + super().__init__() self.num_bits = num_bits def get_name(self) -> QuantizationMethods: diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 262e6799583ae..9e1ed3a771798 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -805,7 +805,7 @@ def create_lora_manager( lora_manager_cls: type[LoRAModelManager] = LoRAModelManager, **kwargs) -> LoRAModelManager: """Create a LoRA adapter for a given model.""" - if not hasattr(model, "packed_modules_mapping"): + if not isinstance(model, SupportsLoRA): raise ValueError(f"Model {type(model)} is not supported for LoRA.") lora_manager = lora_manager_cls( model=model, diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index 7da44569f4086..7a4af74cbeb12 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -111,10 +111,7 @@ class WorkerLoRAManager(AbstractWorkerManager): # For some models like Qwen2VL, we need to use hf_to_vllm_mapper # to ensure correct loading of lora weights. model = self._adapter_manager.model - hf_to_vllm_mapper = None - if (hasattr(model, "hf_to_vllm_mapper") - and model.hf_to_vllm_mapper is not None): - hf_to_vllm_mapper = model.hf_to_vllm_mapper + hf_to_vllm_mapper = getattr(model, "hf_to_vllm_mapper", None) lora = self._lora_model_cls.from_local_checkpoint( lora_path, diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py index 78c5c75c06515..4a43351260e9f 100644 --- a/vllm/model_executor/layers/quantization/base_config.py +++ b/vllm/model_executor/layers/quantization/base_config.py @@ -10,6 +10,7 @@ from torch import nn if TYPE_CHECKING: from vllm.model_executor.layers.quantization import QuantizationMethods + from vllm.model_executor.models.utils import WeightsMapper else: QuantizationMethods = str @@ -149,3 +150,15 @@ class QuantizationConfig(ABC): def get_cache_scale(self, name: str) -> Optional[str]: return None + + def apply_vllm_mapper( # noqa: B027 + self, hf_to_vllm_mapper: "WeightsMapper"): + """ + Interface for models to update module names referenced in + quantization configs in order to reflect the vllm model structure + + :param hf_to_vllm_mapper: maps from hf model structure (the assumed + structure of the qconfig) to vllm model structure + """ + # TODO (@kylesayrs): add implementations for all subclasses + pass diff --git a/vllm/model_executor/layers/quantization/bitblas.py b/vllm/model_executor/layers/quantization/bitblas.py index 9e5ce39ec8f2e..aa8eee88a9f9e 100644 --- a/vllm/model_executor/layers/quantization/bitblas.py +++ b/vllm/model_executor/layers/quantization/bitblas.py @@ -63,6 +63,7 @@ class BitBLASConfig(QuantizationConfig): # (since we have only one group per output channel) desc_act = False + super().__init__() self.weight_bits = weight_bits self.group_size = group_size self.desc_act = desc_act diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 4f87b2a44f0ac..e7f65d13181d8 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from contextlib import suppress -from typing import Any, Literal, Optional, cast +from typing import TYPE_CHECKING, Any, Literal, Optional, cast import torch from compressed_tensors.config import (CompressionFormat, @@ -37,6 +37,9 @@ from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import cutlass_fp4_supported) from vllm.platforms import current_platform +if TYPE_CHECKING: + from vllm.model_executor.models.utils import WeightsMapper + logger = init_logger(__name__) __all__ = ["CompressedTensorsLinearMethod"] @@ -80,6 +83,18 @@ class CompressedTensorsConfig(QuantizationConfig): def get_name(self) -> QuantizationMethods: return "compressed-tensors" + def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): + self.target_scheme_map = hf_to_vllm_mapper.apply_dict( + self.target_scheme_map) + self.ignore = hf_to_vllm_mapper.apply_list(self.ignore) + self.sparsity_scheme_map = hf_to_vllm_mapper.apply_dict( + self.sparsity_scheme_map) + self.sparsity_ignore_list = hf_to_vllm_mapper.apply_list( + self.sparsity_ignore_list) + if self.kv_cache_scheme is not None: + self.kv_cache_scheme = hf_to_vllm_mapper.apply_dict( + self.kv_cache_scheme) + def get_quant_method( self, layer: torch.nn.Module, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 93472207fbb86..60df679a74bda 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import functools -from typing import Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, Union import torch import torch.nn.functional as F @@ -39,6 +39,9 @@ from vllm.platforms import current_platform from vllm.scalar_type import scalar_types from vllm.utils import has_deep_gemm +if TYPE_CHECKING: + from vllm.model_executor.models.utils import WeightsMapper + ACTIVATION_SCHEMES = ["static", "dynamic"] logger = init_logger(__name__) @@ -100,6 +103,11 @@ class Fp8Config(QuantizationConfig): def get_config_filenames(cls) -> list[str]: return [] + def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): + if self.ignored_layers is not None: + self.ignored_layers = hf_to_vllm_mapper.apply_list( + self.ignored_layers) + @classmethod def from_config(cls, config: dict[str, Any]) -> "Fp8Config": quant_method = cls.get_from_keys(config, ["quant_method"]) diff --git a/vllm/model_executor/layers/quantization/gptq_bitblas.py b/vllm/model_executor/layers/quantization/gptq_bitblas.py index 78e0f59fa4bee..caeb266d0b933 100644 --- a/vllm/model_executor/layers/quantization/gptq_bitblas.py +++ b/vllm/model_executor/layers/quantization/gptq_bitblas.py @@ -81,6 +81,7 @@ class GPTQBitBLASConfig(QuantizationConfig): # (since we have only one group per output channel) desc_act = False + super().__init__() self.weight_bits = weight_bits self.group_size = group_size self.desc_act = desc_act diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py index 62667db26b669..18d1c13373df9 100644 --- a/vllm/model_executor/layers/quantization/marlin.py +++ b/vllm/model_executor/layers/quantization/marlin.py @@ -32,6 +32,8 @@ class MarlinConfig(QuantizationConfig): group_size: int, lm_head_quantized: bool, ) -> None: + super().__init__() + # Group size for the quantization. self.group_size = group_size self.lm_head_quantized = lm_head_quantized diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index e35db5b31dba7..a10911b84afc4 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -181,6 +181,7 @@ class ModelOptNvFp4Config(QuantizationConfig): exclude_modules: list[str], group_size: int = 16, ) -> None: + super().__init__() self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized if is_checkpoint_nvfp4_serialized: logger.warning( diff --git a/vllm/model_executor/layers/quantization/torchao.py b/vllm/model_executor/layers/quantization/torchao.py index a4e0356c02689..63b2ab6bab063 100644 --- a/vllm/model_executor/layers/quantization/torchao.py +++ b/vllm/model_executor/layers/quantization/torchao.py @@ -55,6 +55,7 @@ class TorchAOConfig(QuantizationConfig): os.environ["VLLM_DISABLE_COMPILE_CACHE"] = "1" logger.info("Using TorchAO: Setting VLLM_DISABLE_COMPILE_CACHE=1") """ + super().__init__() self.torchao_config = torchao_config self.skip_modules = skip_modules or [] diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 79e6fa7b16dc7..159e7b1e6b01a 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -24,6 +24,7 @@ from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models.adapters import (as_classification_model, as_embedding_model, as_reward_model) +from vllm.model_executor.models.interfaces import SupportsQuant from vllm.utils import is_pin_memory_available logger = init_logger(__name__) @@ -294,13 +295,16 @@ def configure_quant_config(quant_config: QuantizationConfig, Note that model attributes are passed by reference to quant_config, enabling them to be updated by model_class.__new__ (ex. chatglm, qwen) + + Once the `SupportsQuant` mixin has been added to all models, this + function can be removed """ - packed_mapping = getattr(model_class, "packed_modules_mapping", None) - if packed_mapping is not None: - # pass packed_modules_mapping by reference to quant_config - quant_config.packed_modules_mapping = packed_mapping - else: - logger.warning( - "The model class %s has not defined `packed_modules_mapping`, " - "this may lead to incorrect mapping of quantized or ignored " - "modules", model_class.__name__) + if not issubclass(model_class, SupportsQuant): + hf_to_vllm_mapper = getattr(model_class, "hf_to_vllm_mapper", None) + packed_mapping = getattr(model_class, "packed_modules_mapping", None) + + # pass mappings by reference to quant_config + if hf_to_vllm_mapper is not None: + quant_config.apply_vllm_mapper(hf_to_vllm_mapper) + if packed_mapping is not None: + quant_config.packed_modules_mapping = packed_mapping diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index ad59fe79edcb1..d234632ef1b75 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -18,6 +18,7 @@ from .interfaces_base import is_pooling_model if TYPE_CHECKING: from vllm.attention import AttentionMetadata + from vllm.model_executor.models.utils import WeightsMapper from vllm.sequence import IntermediateTensors logger = init_logger(__name__) @@ -566,20 +567,36 @@ def has_step_pooler(model: Union[type[object], object]) -> bool: class SupportsQuant: """The interface required for all models that support quantization.""" - packed_modules_mapping: ClassVar[dict[str, list[str]]] = {} + hf_to_vllm_mapper: ClassVar[Optional["WeightsMapper"]] = None + packed_modules_mapping: ClassVar[Optional[dict[str, list[str]]]] = None quant_config: Optional[QuantizationConfig] = None def __new__(cls, *args, **kwargs) -> Self: instance = super().__new__(cls) + + # find config passed in arguments quant_config = cls._find_quant_config(*args, **kwargs) if quant_config is not None: + + # attach config to model for general use instance.quant_config = quant_config - instance.quant_config.packed_modules_mapping.update( - cls.packed_modules_mapping) + + # apply model mappings to config for proper config-model matching + # NOTE: `TransformersForCausalLM` is not supported due to how this + # class defines `hf_to_vllm_mapper` as a post-init `@property`. + # After this is fixed, get `instance.hf_to_vllm_mapper` directly + if getattr(instance, "hf_to_vllm_mapper", None) is not None: + instance.quant_config.apply_vllm_mapper( + instance.hf_to_vllm_mapper) + if getattr(instance, "packed_modules_mapping", None) is not None: + instance.quant_config.packed_modules_mapping.update( + instance.packed_modules_mapping) + return instance @staticmethod def _find_quant_config(*args, **kwargs) -> Optional[QuantizationConfig]: + """Find quant config passed through model constructor args""" from vllm.config import VllmConfig # avoid circular import args_values = list(args) + list(kwargs.values()) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index ff53a2775e3d4..1b64b61a1e5cf 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -61,7 +61,7 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import uses_mrope from .interfaces import (MultiModalEmbeddings, SupportsLoRA, - SupportsMultiModal, SupportsPP) + SupportsMultiModal, SupportsPP, SupportsQuant) from .qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder from .qwen2_vl import (Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo, apply_rotary_pos_emb_vision) @@ -821,7 +821,8 @@ class Qwen2_5_VLMultiModalProcessor(Qwen2VLMultiModalProcessor): info=Qwen2_5_VLProcessingInfo, dummy_inputs=Qwen2_5_VLDummyInputsBuilder) class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsLoRA, SupportsPP): + SupportsLoRA, SupportsPP, + SupportsQuant): # To ensure correct weight loading and mapping. hf_to_vllm_mapper = WeightsMapper( @@ -837,7 +838,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config self.config = config @@ -846,7 +846,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, self.visual = Qwen2_5_VisionTransformer( config.vision_config, norm_eps=getattr(config, "rms_norm_eps", 1e-6), - quant_config=self._maybe_ignore_quant_config(quant_config), + quant_config=self._maybe_ignore_quant_config(self.quant_config), prefix=maybe_prefix(prefix, "visual"), ) @@ -859,12 +859,12 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): + def _maybe_ignore_quant_config(self, config: Optional[QuantizationConfig]): # GPTQ configs do not have a list of ignored modules, however AutoGPTQ # seems to avoid vision encoder sections for some models. - if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)): + if isinstance(config, (GPTQConfig, GPTQMarlinConfig)): return None - return quant_config + return config def _validate_and_reshape_mm_tensor(self, mm_input: object, name: str) -> torch.Tensor: diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 2f78d9d4cc065..04ee3a454f9d8 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -467,6 +467,7 @@ class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA, # FIXME(Isotr0py): Don't use any weights mapper for Transformers backend, # this makes thing complicated. We need to remove this mapper after refactor # `TransformersModel` in the future. + # NOTE: `SupportsQuant` can be updated after property decorator is removed @property def hf_to_vllm_mapper(self): prefix_mapper = { diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index aa88f42101605..62deb68035b92 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -4,7 +4,7 @@ import itertools from collections.abc import Iterable, Mapping from dataclasses import dataclass, field -from typing import Callable, Literal, Optional, Protocol, Union, overload +from typing import Any, Callable, Literal, Optional, Protocol, Union, overload import torch import torch.nn as nn @@ -64,6 +64,19 @@ class WeightsMapper: return ((out_name, data) for name, data in weights if (out_name := self._map_name(name)) is not None) + def apply_list(self, values: list[str]) -> list[str]: + return [ + out_name for name in values + if (out_name := self._map_name(name)) is not None + ] + + def apply_dict(self, values: dict[str, Any]) -> dict[str, Any]: + return { + out_name: value + for name, value in values.items() + if (out_name := self._map_name(name)) is not None + } + class AutoWeightsLoader: """ diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py index cbaa34bfc30b2..2b20ca2a3ba3f 100644 --- a/vllm/model_executor/utils.py +++ b/vllm/model_executor/utils.py @@ -58,7 +58,8 @@ def _make_synced_weight_loader(original_weight_loader): def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]: - parent_map = copy.deepcopy(getattr(model, "packed_modules_mapping", {})) + parent_map = getattr(model, "packed_modules_mapping", None) + parent_map = copy.deepcopy(parent_map) if parent_map is not None else {} # don't infer mapping if the model has defined it explicitly. if parent_map: @@ -66,7 +67,9 @@ def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]: # We only check main components instead of whole model submodules for child in model.children(): - child_map = getattr(child, "packed_modules_mapping", {}) + child_map = getattr(child, "packed_modules_mapping", None) + child_map = copy.deepcopy(child_map) if child_map is not None else {} + if any((k in parent_map and parent_map[k] != v) for k, v in child_map.items()): raise ValueError( From 650d5dbd04e92f5043a11e4a4d86d4f39ee1b694 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicol=C3=B2=20Lucchesi?= Date: Tue, 1 Jul 2025 13:40:14 +0200 Subject: [PATCH 093/175] [Misc] Minor refactor of NIXL background handshake (#20068) Signed-off-by: NickLucche --- .../kv_connector/v1/nixl_connector.py | 60 ++++++++++--------- 1 file changed, 32 insertions(+), 28 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 7a077dce7706c..56ae1acf8571f 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -515,6 +515,33 @@ class NixlConnectorWorker: # Remote rank -> agent name. return {p_remote_rank: handshake(path, p_remote_rank)} + def _background_nixl_handshake(self, req_id: str, + remote_engine_id: EngineId, meta: ReqMeta): + # Do NIXL handshake in background and add to _ready_requests when done. + fut = self._handshake_futures.get(remote_engine_id) + if fut is None: + fut = self._handshake_initiation_executor.submit( + self._nixl_handshake, meta.remote_host, meta.remote_port, + meta.tp_size) + self._handshake_futures[remote_engine_id] = fut + + def done_callback(f: Future[dict[int, str]], eid=remote_engine_id): + with self._handshake_lock: + del self._handshake_futures[eid] + try: + self._remote_agents[eid] = f.result() + except Exception: + logger.exception("Handshake with %s failed", eid) + + fut.add_done_callback(done_callback) + + # TODO: handle failure state of future in the + # callback, we want to fail the request in this case. + def request_ready(_f: Future[Any], entry=(req_id, meta)): + self._ready_requests.put(entry) + + fut.add_done_callback(request_ready) + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """Register the KV Cache data in nixl.""" @@ -902,37 +929,14 @@ class NixlConnectorWorker: remote_engine_id, len(meta.local_block_ids), len(meta.remote_block_ids)) if remote_engine_id not in self._remote_agents: - # Being optimistic to assume engine is usually ready, apply - # lock only when the optimistic check fails. + # Initiate handshake with remote engine to exchange metadata. with self._handshake_lock: if remote_engine_id not in self._remote_agents: - fut = self._handshake_futures.get(remote_engine_id) - if fut is None: - fut = self._handshake_initiation_executor.submit( - self._nixl_handshake, meta.remote_host, - meta.remote_port, meta.tp_size) - self._handshake_futures[remote_engine_id] = fut - - def done_callback(f: Future[dict[int, str]], - eid=remote_engine_id): - with self._handshake_lock: - del self._handshake_futures[eid] - try: - self._remote_agents[eid] = f.result() - except Exception: - logger.exception( - "Handshake with %s failed", eid) - - fut.add_done_callback(done_callback) - - # TODO: handle failure state of future in the - # callback, we want to fail the request in this case. - def request_ready(_f: Future[Any], - entry=(req_id, meta)): - self._ready_requests.put(entry) - - fut.add_done_callback(request_ready) + self._background_nixl_handshake( + req_id, remote_engine_id, meta) continue + + # Handshake already completed, start async read xfer. self._read_blocks_for_req(req_id, meta) # Start transfers for requests whose handshakes have now finished. From ed70f3c64f684750edea087e286cbf264e7cc3f3 Mon Sep 17 00:00:00 2001 From: Yuxuan Zhang <2448370773@qq.com> Date: Tue, 1 Jul 2025 20:48:26 +0800 Subject: [PATCH 094/175] Add GLM4.1V model (Draft) (#19331) Signed-off-by: zRzRzRzRzRzRzR <2448370773@qq.com> Signed-off-by: Isotr0py Co-authored-by: Isotr0py --- docs/models/supported_models.md | 3 +- examples/offline_inference/vision_language.py | 40 +- tests/entrypoints/openai/test_video.py | 2 +- .../multimodal/generation/test_common.py | 28 + .../generation/vlm_utils/custom_inputs.py | 20 + .../generation/vlm_utils/model_utils.py | 24 + .../multimodal/processing/test_common.py | 24 + tests/models/registry.py | 1 + tests/multimodal/test_utils.py | 4 +- vllm/assets/video.py | 26 +- vllm/entrypoints/chat_utils.py | 4 + .../model_executor/layers/rotary_embedding.py | 119 ++ vllm/model_executor/models/glm4_1v.py | 1589 +++++++++++++++++ vllm/model_executor/models/registry.py | 1 + vllm/multimodal/inputs.py | 8 +- vllm/multimodal/parse.py | 42 +- vllm/multimodal/video.py | 27 +- 17 files changed, 1946 insertions(+), 16 deletions(-) create mode 100644 vllm/model_executor/models/glm4_1v.py diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 0248700292ae2..db650b37a38db 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -553,6 +553,7 @@ Specified using `--task generate`. | `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b` etc. | | ✅︎ | ✅︎ | | `Gemma3ForConditionalGeneration` | Gemma 3 | T + I+ | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ | | `GLM4VForCausalLM`^ | GLM-4V | T + I | `THUDM/glm-4v-9b`, `THUDM/cogagent-9b-20241220` etc. | ✅︎ | ✅︎ | ✅︎ | +| `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + IE+ + VE+ | `THUDM/GLM-4.1V-9B-Thinkg`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ | | `H2OVLChatModel` | H2OVL | T + IE+ | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | ✅︎\* | | `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3` etc. | ✅︎ | | ✅︎ | @@ -583,7 +584,7 @@ Specified using `--task generate`. | `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | ✅︎ | | `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | | ✅︎ | | `TarsierForConditionalGeneration` | Tarsier | T + IE+ | `omni-search/Tarsier-7b`,`omni-search/Tarsier-34b` | | ✅︎ | ✅︎ | -| `Tarsier2ForConditionalGeneration`^ | Tarsier2 | T + IE+ + VE+ | `omni-research/Tarsier2-Recap-7b`,`omni-research/Tarsier2-7b-0115` | | ✅︎ | ✅︎ | +| `Tarsier2ForConditionalGeneration`^ | Tarsier2 | T + IE+ + VE+ | `omni-research/Tarsier2-Recap-7b`,`omni-research/Tarsier2-7b-0115` | | ✅︎ | ✅︎ | ^ You need to set the architecture name via `--hf-overrides` to match the one in vLLM.     • For example, to use DeepSeek-VL2 series models: diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 57b042ed013b1..b9e8bef26eb24 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -248,6 +248,42 @@ def run_glm4v(questions: list[str], modality: str) -> ModelRequestData: ) +# GLM-4.1V +def run_glm4_1v(questions: list[str], modality: str) -> ModelRequestData: + model_name = "THUDM/GLM-4.1V-9B-Thinking" + + engine_args = EngineArgs( + model=model_name, + max_model_len=4096, + max_num_seqs=2, + mm_processor_kwargs={ + "size": {"shortest_edge": 12544, "longest_edge": 47040000}, + "fps": 1, + }, + limit_mm_per_prompt={modality: 1}, + enforce_eager=True, + ) + + if modality == "image": + placeholder = "<|begin_of_image|><|image|><|end_of_image|>" + elif modality == "video": + placeholder = "<|begin_of_video|><|video|><|end_of_video|>" + + prompts = [ + ( + "[gMASK]<|system|>\nYou are a helpful assistant.<|user|>\n" + f"{placeholder}" + f"{question}<|assistant|>assistant\n" + ) + for question in questions + ] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + # H2OVL-Mississippi def run_h2ovl(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -1114,6 +1150,7 @@ model_example_map = { "fuyu": run_fuyu, "gemma3": run_gemma3, "glm4v": run_glm4v, + "glm4_1v": run_glm4_1v, "h2ovl_chat": run_h2ovl, "idefics3": run_idefics3, "internvl_chat": run_internvl, @@ -1172,10 +1209,11 @@ def get_multi_modal_input(args): if args.modality == "video": # Input video and question video = VideoAsset(name="baby_reading", num_frames=args.num_frames).np_ndarrays + metadata = VideoAsset(name="baby_reading", num_frames=args.num_frames).metadata vid_questions = ["Why is this video funny?"] return { - "data": video, + "data": [(video, metadata)] if args.model_type == "glm4_1v" else video, "questions": vid_questions, } diff --git a/tests/entrypoints/openai/test_video.py b/tests/entrypoints/openai/test_video.py index 990ea3579291d..b68e08556ee96 100644 --- a/tests/entrypoints/openai/test_video.py +++ b/tests/entrypoints/openai/test_video.py @@ -50,7 +50,7 @@ async def client(server): @pytest.fixture(scope="session") def base64_encoded_video() -> dict[str, str]: return { - video_url: encode_video_base64(fetch_video(video_url)) + video_url: encode_video_base64(fetch_video(video_url)[0]) for video_url in TEST_VIDEO_URLS } diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py index 9d63339737ce6..6ecf6db56cb39 100644 --- a/tests/models/multimodal/generation/test_common.py +++ b/tests/models/multimodal/generation/test_common.py @@ -309,6 +309,34 @@ VLM_TEST_SETTINGS = { num_logprobs=10, marks=[large_gpu_mark(min_gb=32)], ), + "glm4_1v": VLMTestInfo( + models=["THUDM/GLM-4.1V-9B-Thinking"], + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + prompt_formatter=lambda img_prompt: f"<|user|>\n{img_prompt}<|assistant|>", # noqa: E501 + img_idx_to_prompt=lambda idx: "<|begin_of_image|><|image|><|end_of_image|>", # noqa: E501 + video_idx_to_prompt=lambda idx: "<|begin_of_video|><|video|><|end_of_video|>", # noqa: E501 + max_model_len=2048, + max_num_seqs=2, + get_stop_token_ids=lambda tok: [151329, 151336, 151338], + num_logprobs=10, + image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], + auto_cls=AutoModelForImageTextToText, + ), + "glm4_1v-video": VLMTestInfo( + models=["THUDM/GLM-4.1V-9B-Thinking"], + # GLM4.1V require include video metadata for input + test_type=VLMTestType.CUSTOM_INPUTS, + max_model_len=4096, + max_num_seqs=2, + auto_cls=AutoModelForImageTextToText, + patch_hf_runner=model_utils.glm4_1v_patch_hf_runner, + custom_test_opts=[CustomTestOptions( + inputs=custom_inputs.video_with_metadata_glm4_1v(), + limit_mm_per_prompt={"video": 1}, + )], + # This is needed to run on machine with 24GB VRAM + vllm_runner_kwargs={"gpu_memory_utilization": 0.95}, + ), "h2ovl": VLMTestInfo( models = [ "h2oai/h2ovl-mississippi-800m", diff --git a/tests/models/multimodal/generation/vlm_utils/custom_inputs.py b/tests/models/multimodal/generation/vlm_utils/custom_inputs.py index aa5835243e042..c53243b42e384 100644 --- a/tests/models/multimodal/generation/vlm_utils/custom_inputs.py +++ b/tests/models/multimodal/generation/vlm_utils/custom_inputs.py @@ -129,3 +129,23 @@ def windows_attention_image_qwen2_5_vl(): wrapped_sf = ImageSizeWrapper(type=SizeType.SIZE_FACTOR, data=[0.5]) return build_single_image_inputs([image], [prompt], wrapped_sf) + + +def video_with_metadata_glm4_1v(): + video_array = VIDEO_ASSETS[0].np_ndarrays + metadata = VIDEO_ASSETS[0].metadata + question = "Describe the video." + video_prompt = "<|begin_of_video|><|video|><|end_of_video|>" + formatted_prompt = f"<|user|>\n{video_prompt}{question}<|assistant|>\n" + + scales = [0.1, 0.2, 0.25] + video_input = [[(rescale_video_size(video_array, scale), metadata)] + for scale in scales] + prompts = [formatted_prompt] * len(video_input) + + return [ + PromptWithMultiModalInput( + prompts=prompts, + video_data=video_input, + ) + ] diff --git a/tests/models/multimodal/generation/vlm_utils/model_utils.py b/tests/models/multimodal/generation/vlm_utils/model_utils.py index af4c72f44b676..c1a2aa0dcafbb 100644 --- a/tests/models/multimodal/generation/vlm_utils/model_utils.py +++ b/tests/models/multimodal/generation/vlm_utils/model_utils.py @@ -16,9 +16,11 @@ import torch from PIL.Image import Image from transformers import (AutoConfig, AutoTokenizer, BatchFeature, GenerationConfig, GenerationMixin) +from transformers.video_utils import VideoMetadata from vllm.sequence import SampleLogprobs from vllm.transformers_utils.tokenizer import patch_padding_side +from vllm.utils import is_list_of from .....conftest import HfRunner, ImageAsset, ImageTestAssets from .types import RunnerOutput @@ -373,6 +375,28 @@ def glm4v_patch_hf_runner(hf_model: HfRunner) -> HfRunner: return hf_model +def glm4_1v_patch_hf_runner(hf_model: HfRunner) -> HfRunner: + """Patches and returns an instance of the HfRunner to use for GLM4.1V.""" + hf_processor = hf_model.processor + + def processor(*args, videos=None, **kwargs): + if videos is not None and is_list_of(videos, tuple): + # If videos is a list of tuples, we assume each tuple contains + # (video_array, metadata) as in the case of GLM4.1V. + video_metadata = [[VideoMetadata(**video[1])] for video in videos] + videos = [[video[0]] for video in videos] + else: + video_metadata = None + + return hf_processor(*args, + videos=videos, + video_metadata=video_metadata, + **kwargs) + + hf_model.processor = processor + return hf_model + + def h2ovl_patch_hf_runner(hf_model: HfRunner) -> HfRunner: """Patches and returns an instance of the HfRunner to use for H2OVL.""" diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 1ba60178c13db..0f33225eda2da 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -24,6 +24,22 @@ from ....multimodal.utils import random_audio, random_image, random_video from ...registry import HF_EXAMPLE_MODELS +def glm4_1v_patch_mm_data(mm_data: MultiModalDataDict) -> MultiModalDataDict: + """ + Patch the multimodal data for GLM4.1V model. + """ + # Ensure video metadata is included + if "video" in mm_data: + video = mm_data["video"] + mm_data["video"] = (video, { + "total_num_frames": len(video), + "fps": len(video), + "duration": 1, + "video_backend": "opencv" + }) + return mm_data + + def _test_processing_correctness( model_id: str, hit_rate: float, @@ -154,6 +170,11 @@ _IGNORE_MM_KEYS = { "ultravox": {"audio_features"}, } +MM_DATA_PATCHES = { + # GLM4.1V requires video metadata to be included in the input + "glm4v": glm4_1v_patch_mm_data, +} + def _test_processing_correctness_one( model_config: ModelConfig, @@ -166,6 +187,8 @@ def _test_processing_correctness_one( ): model_type = model_config.hf_config.model_type ignore_mm_keys = _IGNORE_MM_KEYS.get(model_type, set[str]()) + if model_type in MM_DATA_PATCHES: + mm_data = MM_DATA_PATCHES[model_type](mm_data) if isinstance(prompt, str): text_prompt = prompt @@ -245,6 +268,7 @@ def _test_processing_correctness_one( "adept/fuyu-8b", "google/gemma-3-4b-it", "THUDM/glm-4v-9b", + "THUDM/GLM-4.1V-9B-Thinking", "ibm-granite/granite-speech-3.3-2b", "h2oai/h2ovl-mississippi-800m", "OpenGVLab/InternVL2-1B", diff --git a/tests/models/registry.py b/tests/models/registry.py index e56dd19bec670..affe2e88b2b94 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -338,6 +338,7 @@ _MULTIMODAL_EXAMPLE_MODELS = { "GLM4VForCausalLM": _HfExamplesInfo("THUDM/glm-4v-9b", trust_remote_code=True, hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501 + "Glm4vForConditionalGeneration": _HfExamplesInfo("THUDM/GLM-4.1V-9B-Thinking", min_transformers_version="4.53"), # noqa: E501 "H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m", extras={"2b": "h2oai/h2ovl-mississippi-2b"}, # noqa: E501 max_transformers_version="4.48", # noqa: E501 diff --git a/tests/multimodal/test_utils.py b/tests/multimodal/test_utils.py index 5ac0a90f50473..a48542cec3f87 100644 --- a/tests/multimodal/test_utils.py +++ b/tests/multimodal/test_utils.py @@ -172,7 +172,9 @@ async def test_fetch_video_http(video_url: str, num_frames: int): video_sync = connector.fetch_video(video_url, num_frames=num_frames) video_async = await connector.fetch_video_async(video_url, num_frames=num_frames) - assert np.array_equal(video_sync, video_async) + # Check that the video frames are equal and metadata are same + assert np.array_equal(video_sync[0], video_async[0]) + assert video_sync[1] == video_async[1] # Used for the next two tests related to `merge_and_sort_multimodal_metadata`. diff --git a/vllm/assets/video.py b/vllm/assets/video.py index 01834aeeb6c12..16412121cf0a8 100644 --- a/vllm/assets/video.py +++ b/vllm/assets/video.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from functools import lru_cache -from typing import ClassVar, Literal, Optional +from typing import Any, ClassVar, Literal, Optional import cv2 import numpy as np @@ -77,6 +77,24 @@ def video_to_pil_images_list(path: str, ] +def video_get_metadata(path: str) -> dict[str, Any]: + cap = cv2.VideoCapture(path) + if not cap.isOpened(): + raise ValueError(f"Could not open video file {path}") + + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + fps = cap.get(cv2.CAP_PROP_FPS) + duration = total_frames / fps if fps > 0 else 0 + + metadata = { + "total_num_frames": total_frames, + "fps": fps, + "duration": duration, + "video_backend": "opencv" + } + return metadata + + VideoAssetName = Literal["baby_reading"] @@ -105,6 +123,12 @@ class VideoAsset: ret = video_to_ndarrays(video_path, self.num_frames) return ret + @property + def metadata(self) -> dict[str, Any]: + video_path = download_video_asset(self.filename) + ret = video_get_metadata(video_path) + return ret + def get_audio(self, sampling_rate: Optional[float] = None) -> npt.NDArray: """ Read audio data from the video asset, used in Qwen2.5-Omni examples. diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 35ee52ab4601d..45f1894d022b3 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -515,6 +515,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): if modality in ("image", "image_embeds"): if model_type == "chatglm": return "<|begin_of_image|><|endoftext|><|end_of_image|>" + if model_type == "glm4v": + return "<|begin_of_image|><|image|><|end_of_image|>" if model_type in ("phi3_v", "phi4mm"): return f"<|image_{current_count}|>" if model_type in ("minicpmo", "minicpmv"): @@ -563,6 +565,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): elif modality == "video": if model_type == "internvl_chat": return "