[Kernel] unified_attention for Attention.forward (#11967)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Authored by Chen Zhang on 2025-01-13 19:28:53 +08:00, committed by GitHub
parent 5340a30d01
commit 0f8cafe2d1
10 changed files with 87 additions and 45 deletions

View File

@@ -134,15 +134,10 @@ class Attention(nn.Module):
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
+        _kv_cache: torch.Tensor,
+        _attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
-        if self.use_direct_call:
-            return self.impl.forward(query, key, value, kv_cache,
-                                     attn_metadata, self._k_scale,
-                                     self._v_scale)
-        elif self.use_output:
+        if self.use_output:
             output = torch.empty_like(query)
             hidden_size = query.size(-1)
             # Reshape the query, key, and value tensors.
@@ -154,12 +149,19 @@ class Attention(nn.Module):
                 key = key.view(-1, self.num_kv_heads, self.head_size)
             if value is not None:
                 value = value.view(-1, self.num_kv_heads, self.head_size)
-            torch.ops.vllm.unified_attention_with_output(
-                query, key, value, output, self.layer_name)
+            if self.use_direct_call:
+                unified_attention_with_output(query, key, value, output,
+                                              self.layer_name)
+            else:
+                torch.ops.vllm.unified_attention_with_output(
+                    query, key, value, output, self.layer_name)
             return output.view(-1, hidden_size)
         else:
-            return torch.ops.vllm.unified_attention(query, key, value,
-                                                    self.layer_name)
+            if self.use_direct_call:
+                return unified_attention(query, key, value, self.layer_name)
+            else:
+                return torch.ops.vllm.unified_attention(
+                    query, key, value, self.layer_name)
 
     def extra_repr(self) -> str:
         s = f"head_size={self.impl.head_size}"  # type: ignore

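Note (not part of the diff): the reworked forward() above no longer threads kv_cache and attn_metadata through the call; the unified_attention helpers are expected to recover both from the current forward context, which is what the set_forward_context and bind_kv_cache changes in the files below set up. A minimal illustrative sketch of that lookup follows. The import of get_forward_context is real, but the attribute names on the context (attn_layers, virtual_engine) and the helper name are assumptions made for illustration, not code from this commit.

```python
# Illustrative sketch only: how a unified_attention-style helper can resolve
# the KV cache and metadata from the forward context instead of explicit
# arguments. Attribute names on the context are assumptions.
import torch
from vllm.forward_context import get_forward_context


def unified_attention_sketch(query: torch.Tensor, key: torch.Tensor,
                             value: torch.Tensor,
                             layer_name: str) -> torch.Tensor:
    ctx = get_forward_context()
    attn_metadata = ctx.attn_metadata              # set by set_forward_context(...)
    layer = ctx.attn_layers[layer_name]            # assumed per-layer registry
    kv_cache = layer.kv_cache[ctx.virtual_engine]  # populated by bind_kv_cache(...)
    # Mirrors the direct-call path removed from Attention.forward above.
    return layer.impl.forward(query, key, value, kv_cache, attn_metadata,
                              layer._k_scale, layer._v_scale)
```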
View File

@@ -2171,5 +2171,4 @@ def bind_kv_cache(
         forward_ctx = ctx[layer_name]
         assert len(forward_ctx.kv_cache) == len(kv_cache)
         for ve, ve_kv_cache in enumerate(kv_cache):
-            assert forward_ctx.kv_cache[ve].numel() == 0
             forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx]

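Note (not part of the diff): dropping the numel() == 0 assert is what allows the same forward context to be re-bound more than once, which the HPU/TPU workers and the OpenVINO profiling path below rely on. Below is a simplified, self-contained sketch of the binding loop with stand-in types; the real helper in vllm.utils also handles layers that share a kv_cache index, which is omitted here.

```python
# Simplified sketch of bind_kv_cache after this change: each layer's
# per-virtual-engine slot is overwritten unconditionally, so re-binding
# (e.g. after a profiling run) is allowed. Types are stand-ins.
from typing import Dict, List

import torch


class LayerForwardCtx:  # stand-in for the per-layer forward context entry
    def __init__(self, num_virtual_engines: int):
        self.kv_cache: List[torch.Tensor] = [
            torch.tensor([]) for _ in range(num_virtual_engines)
        ]


def bind_kv_cache_sketch(ctx: Dict[str, LayerForwardCtx],
                         kv_cache: List[List[torch.Tensor]]) -> None:
    # kv_cache is indexed as [virtual_engine][layer_index].
    for layer_idx, layer_name in enumerate(sorted(ctx)):
        forward_ctx = ctx[layer_name]
        assert len(forward_ctx.kv_cache) == len(kv_cache)
        for ve, ve_kv_cache in enumerate(kv_cache):
            # No numel() == 0 assert: an existing binding is simply replaced.
            forward_ctx.kv_cache[ve] = ve_kv_cache[layer_idx]
```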
View File

@@ -28,6 +28,7 @@ from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler,
 from vllm.attention import AttentionMetadata, get_attn_backend
 from vllm.config import DeviceConfig, VllmConfig
 from vllm.distributed.parallel_state import get_world_group
+from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.lora.layers import LoRAMapping
 from vllm.lora.request import LoRARequest
@@ -40,7 +41,8 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import (IntermediateTensors, SequenceData,
                            SequenceGroupMetadata)
-from vllm.utils import is_pin_memory_available, make_tensor_with_pad
+from vllm.utils import (bind_kv_cache, is_pin_memory_available,
+                        make_tensor_with_pad)
 from vllm.worker.model_runner_base import (
     ModelRunnerBase, ModelRunnerInputBase,
     _add_attn_metadata_broadcastable_dict,
@@ -1286,6 +1288,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
     def profile_run(self) -> None:
         num_layers = self.model_config.get_num_layers(self.parallel_config)
         kv_caches = [None] * num_layers
+        bind_kv_cache(
+            self.vllm_config.compilation_config.static_forward_context,
+            [kv_caches])
         max_seq_len = self.bucketing_global_state.prompt_seq_bucket_cfg[-1]
         max_batch_size = min(self.max_num_batched_tokens // max_seq_len,
                              self.scheduler_config.max_num_seqs)
@@ -1943,7 +1948,11 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
                                 f"graphs{'T' if use_graphs else 'F'}")
         else:
             model_event_name = 'model_executable'
-        with self.profiler.record_event('internal', model_event_name):
+        with set_forward_context(
+                model_input.attn_metadata, self.vllm_config,
+                model_input.virtual_engine), \
+                self.profiler.record_event(
+                    'internal', model_event_name):
             hidden_states = self.model.forward(
                 **execute_model_kwargs,
                 selected_token_indices=sampling_metadata.selected_token_indices

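Note (not part of the diff): this hunk and the Neuron, OpenVINO, TPU, and XPU hunks that follow all apply the same pattern: wrap the model call in set_forward_context(...) so that the unified_attention custom op can read the step's attention metadata and virtual engine. A minimal mock of such a context manager is sketched below; it is not the vLLM implementation, just the shape of the mechanism, and all names here are hypothetical.

```python
# Minimal mock (not the vLLM implementation) of the set_forward_context
# pattern: stash the step's attention metadata and virtual engine where
# attention helpers can later find them, and restore the previous state
# when the model call returns.
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Iterator, Optional


@dataclass
class _ForwardContext:
    attn_metadata: Optional[Any]
    virtual_engine: int


_current_ctx: Optional[_ForwardContext] = None


@contextmanager
def set_forward_context_sketch(attn_metadata: Optional[Any],
                               vllm_config: Any,
                               virtual_engine: int = 0) -> Iterator[None]:
    # vllm_config is accepted only to mirror the call sites above; the mock
    # does not use it.
    global _current_ctx
    prev = _current_ctx
    _current_ctx = _ForwardContext(attn_metadata, virtual_engine)
    try:
        yield
    finally:
        _current_ctx = prev  # restore the previous context on exit


def get_forward_context_sketch() -> _ForwardContext:
    assert _current_ctx is not None, "model ran outside set_forward_context"
    return _current_ctx
```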
View File

@@ -20,6 +20,7 @@ from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sequence import ExecuteModelRequest
+from vllm.utils import bind_kv_cache
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.hpu_model_runner import HPUModelRunner
 from vllm.worker.model_runner_base import ModelRunnerBase
@@ -215,6 +216,8 @@ class HPUWorker(LocalOrDistributedWorkerBase):
             self.cache_engine[ve].gpu_cache
             for ve in range(self.parallel_config.pipeline_parallel_size)
         ]
+        bind_kv_cache(self.compilation_config.static_forward_context,
+                      self.hpu_cache)
 
     def _warm_up_model(self) -> None:
         # NOTE(kzawora): We should use virtual engine index here

View File

@@ -8,6 +8,7 @@ from torch import nn
 from transformers_neuronx.config import GenerationConfig
 
 from vllm.config import VllmConfig
+from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.sampler import SamplerOutput
@@ -314,13 +315,15 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
             raise ValueError(
                 "NeuronModelRunner does not support multi-step execution.")
 
-        hidden_states = self.model(
-            input_ids=model_input.input_tokens,
-            positions=model_input.input_positions,
-            input_block_ids=model_input.input_block_ids,
-            **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
-                                         device=self.device),
-        )
+        with set_forward_context(None, self.vllm_config, 0):
+            hidden_states = self.model(
+                input_ids=model_input.input_tokens,
+                positions=model_input.input_positions,
+                input_block_ids=model_input.input_block_ids,
+                **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs
+                                             or {},
+                                             device=self.device),
+            )
 
         # Compute the logits only if the on-device sampling is turned off as
         # on-device sampling outputs the token ids.

View File

@@ -8,6 +8,7 @@ from torch import nn
 from vllm.attention import get_attn_backend
 from vllm.attention.backends.openvino import OpenVINOAttentionMetadata
 from vllm.config import VllmConfig
+from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.sampler import SamplerOutput
@@ -350,7 +351,8 @@ class OpenVINOModelRunner(ModelRunnerBase):
                                          device=self.device),
         }
 
-        hidden_states = model_executable(**execute_model_kwargs)
+        with set_forward_context(attn_metadata, self.vllm_config, 0):
+            hidden_states = model_executable(**execute_model_kwargs)
 
         # Compute the logits.
         logits = self.model.compute_logits(hidden_states, sampling_metadata)

View File

@@ -20,6 +20,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.platforms import current_platform
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
+from vllm.utils import bind_kv_cache
 from vllm.worker.openvino_model_runner import OpenVINOModelRunner
 from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
@@ -339,6 +340,8 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
             ov_device,
         )
         self.kv_cache = self.cache_engine.kv_cache
+        bind_kv_cache(self.compilation_config.static_forward_context,
+                      [self.kv_cache])
         self.model_runner.block_size = self.cache_engine.block_size
 
         assert self.kv_cache is not None
@@ -507,12 +510,18 @@
         self.model_runner.block_size = tmp_cache_config.block_size
+        bind_kv_cache(self.compilation_config.static_forward_context,
+                      profiling_cache_engine.kv_cache)
 
         # Run the model with the dummy inputs.
         self.model_runner.execute_model(seqs,
                                         profiling_cache_engine.kv_cache)
 
-        # explicitly delete temporary KV cache manager to free KV cache
-        # when real inputs will be passed to OV
+        # Explicitly revert bind_kv_cache and delete temporary KV cache
+        # manager to free KV cache when real inputs will be passed to OV
+        bind_kv_cache(self.compilation_config.static_forward_context, [[
+            torch.tensor([])
+            for _ in range(len(profiling_cache_engine.kv_cache))
+        ]])
         del profiling_cache_engine
 
         logger.info(

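Note (not part of the diff): the revert above re-binds every layer to an empty tensor before the profiling cache engine is dropped, so no layer keeps a reference that would pin the profiling KV cache in memory. Using the hypothetical LayerForwardCtx and bind_kv_cache_sketch stand-ins defined after the bind_kv_cache hunk earlier, the idea looks roughly like this:

```python
# Hypothetical usage of the earlier bind_kv_cache_sketch example: bind a
# profiling cache, then re-bind empty tensors so the profiling buffers can
# actually be freed once they are deleted.
import torch

num_layers, num_virtual_engines = 2, 1
ctx = {f"layer.{i}": LayerForwardCtx(num_virtual_engines)
       for i in range(num_layers)}

profiling_kv_cache = [torch.zeros(16) for _ in range(num_layers)]
bind_kv_cache_sketch(ctx, [profiling_kv_cache])      # bind for the dummy run

bind_kv_cache_sketch(ctx, [[torch.tensor([])          # revert the binding
                            for _ in range(num_layers)]])
del profiling_kv_cache                                # buffers are now freeable
```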
View File

@@ -13,6 +13,7 @@ import torch_xla.runtime as xr
 from vllm.attention import AttentionMetadata, get_attn_backend
 from vllm.config import VllmConfig
+from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.model_loader import get_model
@@ -265,8 +266,9 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
             torch._dynamo.mark_dynamic(t, 0)
             torch._dynamo.mark_dynamic(p, 0)
         # Dummy run.
-        self.model(token_ids, position_ids, attn_metadata, input_lens, t, p,
-                   num_samples, kv_caches)
+        with set_forward_context(attn_metadata, self.vllm_config, 0):
+            self.model(token_ids, position_ids, attn_metadata, input_lens, t,
+                       p, num_samples, kv_caches)
 
     def warmup_model(
         self,
@@ -663,10 +665,13 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
                 input_lens = model_input.input_lens[i:i + 1].to(self.device)
                 t = model_input.t[i:i + 1].to(self.device)
                 p = model_input.p[i:i + 1].to(self.device)
-                output_token_ids = self.model(token_ids, position_ids,
-                                              attn_metadata, input_lens, t, p,
-                                              model_input.num_samples,
-                                              kv_caches)
+                with set_forward_context(model_input.attn_metadata,
+                                         self.vllm_config,
+                                         model_input.virtual_engine):
+                    output_token_ids = self.model(token_ids, position_ids,
+                                                  attn_metadata, input_lens, t,
+                                                  p, model_input.num_samples,
+                                                  kv_caches)
                 next_token_ids.append(output_token_ids[0])
                 start_idx = end_idx
@@ -711,10 +716,13 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
             input_lens = model_input.input_lens.to(self.device)
             for i in range(num_steps):
                 slot_mapping = attn_metadata.slot_mapping
-                output_token_ids = self.model(token_ids, position_ids,
-                                              attn_metadata, input_lens, t, p,
-                                              model_input.num_samples,
-                                              kv_caches)
+                with set_forward_context(model_input.attn_metadata,
+                                         self.vllm_config,
+                                         model_input.virtual_engine):
+                    output_token_ids = self.model(token_ids, position_ids,
+                                                  attn_metadata, input_lens, t,
+                                                  p, model_input.num_samples,
+                                                  kv_caches)
                 self.cached_step_outputs.append(output_token_ids)
                 if i < num_steps - 1:

View File

@@ -12,7 +12,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
 from vllm.logger import init_logger
 from vllm.model_executor import set_random_seed
 from vllm.sequence import ExecuteModelRequest
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, bind_kv_cache, get_dtype_size
 from vllm.worker.tpu_model_runner import ExecutionMode, TPUModelRunner
 from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
                                      LoraNotSupportedWorkerBase, WorkerBase,
@@ -108,6 +108,8 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
                       torch.tensor([], dtype=torch.float32,
                                    device=self.device))
                      for _ in range(num_layers)]
+        bind_kv_cache(self.compilation_config.static_forward_context,
+                      [kv_caches])
         self.model_runner._dummy_run(
             batch_size=1,
             seq_len=self.scheduler_config.max_num_batched_tokens,
@@ -170,6 +172,8 @@
                                       device="cpu")
             cpu_v_cache = torch.zeros_like(cpu_k_cache)
             self.cpu_cache.append((cpu_k_cache, cpu_v_cache))
+        bind_kv_cache(self.compilation_config.static_forward_context,
+                      [self.tpu_cache])
         self._warmup_model()
 
     def _warmup_model(self) -> None:

View File

@@ -12,6 +12,7 @@ import torch.nn as nn
 from vllm.attention import get_attn_backend
 from vllm.config import VllmConfig
 from vllm.distributed import get_pp_group
+from vllm.forward_context import set_forward_context
 from vllm.inputs import INPUT_REGISTRY, InputRegistry
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadataCache
@@ -562,15 +563,17 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
         if (self.observability_config is not None
                 and self.observability_config.collect_model_forward_time):
             model_forward_start_time = time.time()
-        hidden_or_intermediate_states = model_executable(
-            input_ids=model_input.input_tokens,
-            positions=model_input.input_positions,
-            kv_caches=kv_caches,
-            attn_metadata=model_input.attn_metadata,
-            intermediate_tensors=intermediate_tensors,
-            **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
-                                         device=self.device))
+        with set_forward_context(model_input.attn_metadata, self.vllm_config,
+                                 model_input.virtual_engine):
+            hidden_or_intermediate_states = model_executable(
+                input_ids=model_input.input_tokens,
+                positions=model_input.input_positions,
+                kv_caches=kv_caches,
+                attn_metadata=model_input.attn_metadata,
+                intermediate_tensors=intermediate_tensors,
+                **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs
+                                             or {},
+                                             device=self.device))
 
         # Compute the logits in the last pipeline stage.
         if not get_pp_group().is_last_rank:
             return hidden_or_intermediate_states