From 0f8cafe2d1550a33803fb64b2224e6bf3f913449 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 13 Jan 2025 19:28:53 +0800 Subject: [PATCH] [Kernel] unified_attention for Attention.forward (#11967) Signed-off-by: Chen Zhang --- vllm/attention/layer.py | 26 ++++++++++++++------------ vllm/utils.py | 1 - vllm/worker/hpu_model_runner.py | 13 +++++++++++-- vllm/worker/hpu_worker.py | 3 +++ vllm/worker/neuron_model_runner.py | 17 ++++++++++------- vllm/worker/openvino_model_runner.py | 4 +++- vllm/worker/openvino_worker.py | 13 +++++++++++-- vllm/worker/tpu_model_runner.py | 28 ++++++++++++++++++---------- vllm/worker/tpu_worker.py | 6 +++++- vllm/worker/xpu_model_runner.py | 21 ++++++++++++--------- 10 files changed, 87 insertions(+), 45 deletions(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index b8afd428f2cc0..c7e7a4d52e5a7 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -134,15 +134,10 @@ class Attention(nn.Module): query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, + _kv_cache: torch.Tensor, + _attn_metadata: AttentionMetadata, ) -> torch.Tensor: - - if self.use_direct_call: - return self.impl.forward(query, key, value, kv_cache, - attn_metadata, self._k_scale, - self._v_scale) - elif self.use_output: + if self.use_output: output = torch.empty_like(query) hidden_size = query.size(-1) # Reshape the query, key, and value tensors. @@ -154,12 +149,19 @@ class Attention(nn.Module): key = key.view(-1, self.num_kv_heads, self.head_size) if value is not None: value = value.view(-1, self.num_kv_heads, self.head_size) - torch.ops.vllm.unified_attention_with_output( - query, key, value, output, self.layer_name) + if self.use_direct_call: + unified_attention_with_output(query, key, value, output, + self.layer_name) + else: + torch.ops.vllm.unified_attention_with_output( + query, key, value, output, self.layer_name) return output.view(-1, hidden_size) else: - return torch.ops.vllm.unified_attention(query, key, value, - self.layer_name) + if self.use_direct_call: + return unified_attention(query, key, value, self.layer_name) + else: + return torch.ops.vllm.unified_attention( + query, key, value, self.layer_name) def extra_repr(self) -> str: s = f"head_size={self.impl.head_size}" # type: ignore diff --git a/vllm/utils.py b/vllm/utils.py index 217ccb25cef6d..9a509da3c1ef1 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -2171,5 +2171,4 @@ def bind_kv_cache( forward_ctx = ctx[layer_name] assert len(forward_ctx.kv_cache) == len(kv_cache) for ve, ve_kv_cache in enumerate(kv_cache): - assert forward_ctx.kv_cache[ve].numel() == 0 forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx] diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 9d479f412af46..3e5105f3b62e3 100644 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -28,6 +28,7 @@ from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler, from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import DeviceConfig, VllmConfig from vllm.distributed.parallel_state import get_world_group +from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.lora.layers import LoRAMapping from vllm.lora.request import LoRARequest @@ -40,7 +41,8 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, from vllm.sampling_params import SamplingParams from vllm.sequence import (IntermediateTensors, SequenceData, 
SequenceGroupMetadata) -from vllm.utils import is_pin_memory_available, make_tensor_with_pad +from vllm.utils import (bind_kv_cache, is_pin_memory_available, + make_tensor_with_pad) from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, _add_attn_metadata_broadcastable_dict, @@ -1286,6 +1288,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): def profile_run(self) -> None: num_layers = self.model_config.get_num_layers(self.parallel_config) kv_caches = [None] * num_layers + bind_kv_cache( + self.vllm_config.compilation_config.static_forward_context, + [kv_caches]) max_seq_len = self.bucketing_global_state.prompt_seq_bucket_cfg[-1] max_batch_size = min(self.max_num_batched_tokens // max_seq_len, self.scheduler_config.max_num_seqs) @@ -1943,7 +1948,11 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]): f"graphs{'T' if use_graphs else 'F'}") else: model_event_name = 'model_executable' - with self.profiler.record_event('internal', model_event_name): + with set_forward_context( + model_input.attn_metadata, self.vllm_config, + model_input.virtual_engine), \ + self.profiler.record_event( + 'internal', model_event_name): hidden_states = self.model.forward( **execute_model_kwargs, selected_token_indices=sampling_metadata.selected_token_indices diff --git a/vllm/worker/hpu_worker.py b/vllm/worker/hpu_worker.py index cca7cd50bfc7b..8b2d8aaed2803 100644 --- a/vllm/worker/hpu_worker.py +++ b/vllm/worker/hpu_worker.py @@ -20,6 +20,7 @@ from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import ExecuteModelRequest +from vllm.utils import bind_kv_cache from vllm.worker.cache_engine import CacheEngine from vllm.worker.hpu_model_runner import HPUModelRunner from vllm.worker.model_runner_base import ModelRunnerBase @@ -215,6 +216,8 @@ class HPUWorker(LocalOrDistributedWorkerBase): self.cache_engine[ve].gpu_cache for ve in range(self.parallel_config.pipeline_parallel_size) ] + bind_kv_cache(self.compilation_config.static_forward_context, + self.hpu_cache) def _warm_up_model(self) -> None: # NOTE(kzawora): We should use virtual engine index here diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index ae4eb6ba6eaec..a35f5467e1a1f 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -8,6 +8,7 @@ from torch import nn from transformers_neuronx.config import GenerationConfig from vllm.config import VllmConfig +from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.sampler import SamplerOutput @@ -314,13 +315,15 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): raise ValueError( "NeuronModelRunner does not support multi-step execution.") - hidden_states = self.model( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - input_block_ids=model_input.input_block_ids, - **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, - device=self.device), - ) + with set_forward_context(None, self.vllm_config, 0): + hidden_states = self.model( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + input_block_ids=model_input.input_block_ids, + **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs + or {}, + device=self.device), + ) # Compute the logits only if the 
on-device sampling is turned off as # on-device sampling outputs the token ids. diff --git a/vllm/worker/openvino_model_runner.py b/vllm/worker/openvino_model_runner.py index 6000e5dfe4e30..a38b5a4e6e8d5 100644 --- a/vllm/worker/openvino_model_runner.py +++ b/vllm/worker/openvino_model_runner.py @@ -8,6 +8,7 @@ from torch import nn from vllm.attention import get_attn_backend from vllm.attention.backends.openvino import OpenVINOAttentionMetadata from vllm.config import VllmConfig +from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.sampler import SamplerOutput @@ -350,7 +351,8 @@ class OpenVINOModelRunner(ModelRunnerBase): device=self.device), } - hidden_states = model_executable(**execute_model_kwargs) + with set_forward_context(attn_metadata, self.vllm_config, 0): + hidden_states = model_executable(**execute_model_kwargs) # Compute the logits. logits = self.model.compute_logits(hidden_states, sampling_metadata) diff --git a/vllm/worker/openvino_worker.py b/vllm/worker/openvino_worker.py index 0bf522d5333ed..3482073566215 100644 --- a/vllm/worker/openvino_worker.py +++ b/vllm/worker/openvino_worker.py @@ -20,6 +20,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.platforms import current_platform from vllm.sampling_params import SamplingParams from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata +from vllm.utils import bind_kv_cache from vllm.worker.openvino_model_runner import OpenVINOModelRunner from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase @@ -339,6 +340,8 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase): ov_device, ) self.kv_cache = self.cache_engine.kv_cache + bind_kv_cache(self.compilation_config.static_forward_context, + [self.kv_cache]) self.model_runner.block_size = self.cache_engine.block_size assert self.kv_cache is not None @@ -507,12 +510,18 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase): self.model_runner.block_size = tmp_cache_config.block_size + bind_kv_cache(self.compilation_config.static_forward_context, + profiling_cache_engine.kv_cache) # Run the model with the dummy inputs. self.model_runner.execute_model(seqs, profiling_cache_engine.kv_cache) - # explicitly delete temporary KV cache manager to free KV cache - # when real inputs will be passed to OV + # Explicitly revert bind_kv_cache and delete temporary KV cache + # manager to free KV cache when real inputs will be passed to OV + bind_kv_cache(self.compilation_config.static_forward_context, [[ + torch.tensor([]) + for _ in range(len(profiling_cache_engine.kv_cache)) + ]]) del profiling_cache_engine logger.info( diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 7bdb7f0e2d6a9..52c577bccab9c 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -13,6 +13,7 @@ import torch_xla.runtime as xr from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import VllmConfig +from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader import get_model @@ -265,8 +266,9 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): torch._dynamo.mark_dynamic(t, 0) torch._dynamo.mark_dynamic(p, 0) # Dummy run. 
- self.model(token_ids, position_ids, attn_metadata, input_lens, t, p, - num_samples, kv_caches) + with set_forward_context(attn_metadata, self.vllm_config, 0): + self.model(token_ids, position_ids, attn_metadata, input_lens, t, + p, num_samples, kv_caches) def warmup_model( self, @@ -663,10 +665,13 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): input_lens = model_input.input_lens[i:i + 1].to(self.device) t = model_input.t[i:i + 1].to(self.device) p = model_input.p[i:i + 1].to(self.device) - output_token_ids = self.model(token_ids, position_ids, - attn_metadata, input_lens, t, p, - model_input.num_samples, - kv_caches) + with set_forward_context(model_input.attn_metadata, + self.vllm_config, + model_input.virtual_engine): + output_token_ids = self.model(token_ids, position_ids, + attn_metadata, input_lens, t, + p, model_input.num_samples, + kv_caches) next_token_ids.append(output_token_ids[0]) start_idx = end_idx @@ -711,10 +716,13 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): input_lens = model_input.input_lens.to(self.device) for i in range(num_steps): slot_mapping = attn_metadata.slot_mapping - output_token_ids = self.model(token_ids, position_ids, - attn_metadata, input_lens, t, p, - model_input.num_samples, - kv_caches) + with set_forward_context(model_input.attn_metadata, + self.vllm_config, + model_input.virtual_engine): + output_token_ids = self.model(token_ids, position_ids, + attn_metadata, input_lens, t, + p, model_input.num_samples, + kv_caches) self.cached_step_outputs.append(output_token_ids) if i < num_steps - 1: diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index 8754f7538f251..ea0e700545b16 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -12,7 +12,7 @@ from vllm.distributed import (ensure_model_parallel_initialized, from vllm.logger import init_logger from vllm.model_executor import set_random_seed from vllm.sequence import ExecuteModelRequest -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, bind_kv_cache, get_dtype_size from vllm.worker.tpu_model_runner import ExecutionMode, TPUModelRunner from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, LoraNotSupportedWorkerBase, WorkerBase, @@ -108,6 +108,8 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): torch.tensor([], dtype=torch.float32, device=self.device)) for _ in range(num_layers)] + bind_kv_cache(self.compilation_config.static_forward_context, + [kv_caches]) self.model_runner._dummy_run( batch_size=1, seq_len=self.scheduler_config.max_num_batched_tokens, @@ -170,6 +172,8 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): device="cpu") cpu_v_cache = torch.zeros_like(cpu_k_cache) self.cpu_cache.append((cpu_k_cache, cpu_v_cache)) + bind_kv_cache(self.compilation_config.static_forward_context, + [self.tpu_cache]) self._warmup_model() def _warmup_model(self) -> None: diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 9cf25387560da..82b8f22a5af33 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -12,6 +12,7 @@ import torch.nn as nn from vllm.attention import get_attn_backend from vllm.config import VllmConfig from vllm.distributed import get_pp_group +from vllm.forward_context import set_forward_context from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.logger import init_logger from vllm.model_executor import SamplingMetadataCache @@ 
-562,15 +563,17 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): if (self.observability_config is not None and self.observability_config.collect_model_forward_time): model_forward_start_time = time.time() - - hidden_or_intermediate_states = model_executable( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, - intermediate_tensors=intermediate_tensors, - **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, - device=self.device)) + with set_forward_context(model_input.attn_metadata, self.vllm_config, + model_input.virtual_engine): + hidden_or_intermediate_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, + **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs + or {}, + device=self.device)) # Compute the logits in the last pipeline stage. if not get_pp_group().is_last_rank: return hidden_or_intermediate_states
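
The refactor above hinges on kv_cache and attn_metadata no longer being threaded through Attention.forward: each worker binds its allocated caches into the static forward context once (bind_kv_cache) and wraps every model step in set_forward_context, while use_direct_call decides whether the layer calls unified_attention as a plain Python function or dispatches through the torch.ops.vllm custom op. The sketch below is a minimal, self-contained illustration of that pattern only, not vLLM's actual forward_context implementation; _LayerContext, _STATIC_FORWARD_CONTEXT, _CURRENT_ATTN_METADATA and the toy unified_attention body are hypothetical stand-ins.

# Illustrative sketch only: simplified stand-ins for the forward-context
# plumbing this patch relies on (vllm.forward_context.set_forward_context,
# vllm.utils.bind_kv_cache, and the unified_attention entry points).
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Dict, Optional

import torch


@dataclass
class _LayerContext:
    """Per-layer slot that a worker fills in once at startup."""
    kv_cache: torch.Tensor = field(default_factory=lambda: torch.tensor([]))


# Stand-in for compilation_config.static_forward_context.
_STATIC_FORWARD_CONTEXT: Dict[str, _LayerContext] = {}
# Stand-in for the per-step dynamic forward context (attn_metadata etc.).
_CURRENT_ATTN_METADATA: Optional[object] = None


def bind_kv_cache(static_ctx: Dict[str, _LayerContext],
                  kv_caches: Dict[str, torch.Tensor]) -> None:
    """Attach allocated KV caches to each layer's static context slot."""
    for layer_name, cache in kv_caches.items():
        static_ctx.setdefault(layer_name, _LayerContext()).kv_cache = cache


@contextmanager
def set_forward_context(attn_metadata):
    """Make attn_metadata visible to attention layers for one model step."""
    global _CURRENT_ATTN_METADATA
    _CURRENT_ATTN_METADATA = attn_metadata
    try:
        yield
    finally:
        _CURRENT_ATTN_METADATA = None


def unified_attention(query: torch.Tensor, key: torch.Tensor,
                      value: torch.Tensor, layer_name: str) -> torch.Tensor:
    """Toy attention body: reads kv_cache/metadata from context, not args."""
    kv_cache = _STATIC_FORWARD_CONTEXT[layer_name].kv_cache
    attn_metadata = _CURRENT_ATTN_METADATA
    del kv_cache, attn_metadata  # a real kernel would use both here
    return query  # placeholder output


class Attention(torch.nn.Module):

    def __init__(self, layer_name: str, use_direct_call: bool):
        super().__init__()
        self.layer_name = layer_name
        self.use_direct_call = use_direct_call

    def forward(self, query, key, value, _kv_cache=None, _attn_metadata=None):
        # The kv_cache / attn_metadata arguments are kept only for signature
        # compatibility; both now come from the bound forward context.
        if self.use_direct_call:
            # Eager path: plain Python call, no custom-op dispatch needed.
            return unified_attention(query, key, value, self.layer_name)
        # A compiled path would dispatch through the registered custom op,
        # i.e. torch.ops.vllm.unified_attention(...); this sketch registers
        # no op, so it reuses the direct call.
        return unified_attention(query, key, value, self.layer_name)


# Usage mirroring the worker-side changes in this patch: bind caches once,
# then enter the forward context around each model step.
layer = Attention("model.layers.0.attn", use_direct_call=True)
bind_kv_cache(_STATIC_FORWARD_CONTEXT,
              {"model.layers.0.attn": torch.zeros(2, 16, 8)})
with set_forward_context(attn_metadata={"dummy": True}):
    out = layer(torch.randn(4, 8), torch.randn(4, 8), torch.randn(4, 8),
                None, None)
print(out.shape)

Decoupling the cache and metadata from the call signature is, as far as this diff shows, what lets the eager-mode runners touched here (HPU, Neuron, OpenVINO, TPU, XPU) reuse the same unified entry point without the custom op, which is why each of those workers now calls bind_kv_cache at cache-initialization time and wraps execute_model in set_forward_context.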