Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Kernel] unified_attention for Attention.forward (#11967)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
parent 5340a30d01
commit 0f8cafe2d1
@@ -134,15 +134,10 @@ class Attention(nn.Module):
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
+        _kv_cache: torch.Tensor,
+        _attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
-
-        if self.use_direct_call:
-            return self.impl.forward(query, key, value, kv_cache,
-                                     attn_metadata, self._k_scale,
-                                     self._v_scale)
-        elif self.use_output:
+        if self.use_output:
             output = torch.empty_like(query)
             hidden_size = query.size(-1)
             # Reshape the query, key, and value tensors.
@@ -154,12 +149,19 @@ class Attention(nn.Module):
             key = key.view(-1, self.num_kv_heads, self.head_size)
             if value is not None:
                 value = value.view(-1, self.num_kv_heads, self.head_size)
-            torch.ops.vllm.unified_attention_with_output(
-                query, key, value, output, self.layer_name)
+            if self.use_direct_call:
+                unified_attention_with_output(query, key, value, output,
+                                              self.layer_name)
+            else:
+                torch.ops.vllm.unified_attention_with_output(
+                    query, key, value, output, self.layer_name)
             return output.view(-1, hidden_size)
         else:
-            return torch.ops.vllm.unified_attention(query, key, value,
-                                                    self.layer_name)
+            if self.use_direct_call:
+                return unified_attention(query, key, value, self.layer_name)
+            else:
+                return torch.ops.vllm.unified_attention(
+                    query, key, value, self.layer_name)

     def extra_repr(self) -> str:
         s = f"head_size={self.impl.head_size}"  # type: ignore
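The two hunks above are the core of the change: Attention.forward no longer threads kv_cache and attn_metadata through to the backend. It renames those parameters to unused placeholders and dispatches to unified_attention / unified_attention_with_output, either as a plain Python call (use_direct_call) or through the registered torch.ops.vllm custom op; the op resolves the KV cache and metadata from the forward context by layer_name. Below is a minimal, hedged sketch of that dispatch pattern, not the vLLM source: the "toylib" namespace, the naive SDPA backend, and the free-standing forward() are illustrative stand-ins, and torch.library.custom_op assumes PyTorch 2.4 or newer.

import torch
from torch.nn.functional import scaled_dot_product_attention


def _toy_attention(query: torch.Tensor, key: torch.Tensor,
                   value: torch.Tensor) -> torch.Tensor:
    # Naive stand-in for the real attention backend.
    return scaled_dot_product_attention(query, key, value)


# Register the same computation as a custom op so a compiled graph sees one
# opaque node instead of tracing into the backend.
@torch.library.custom_op("toylib::toy_attention", mutates_args=())
def toy_attention_op(query: torch.Tensor, key: torch.Tensor,
                     value: torch.Tensor) -> torch.Tensor:
    return _toy_attention(query, key, value)


@toy_attention_op.register_fake
def _(query, key, value):
    # Shape/dtype propagation for tracing; no real computation.
    return torch.empty_like(query)


def forward(query, key, value, use_direct_call: bool) -> torch.Tensor:
    if use_direct_call:
        # Eager path: a plain Python call, easier to step through.
        return _toy_attention(query, key, value)
    # Compiled path: go through the registered op.
    return torch.ops.toylib.toy_attention(query, key, value)

Either branch returns the same result, e.g. forward(q, k, v, use_direct_call=False) with 4-D q/k/v tensors; the split only changes whether the call appears to torch.compile as a single opaque op.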
@@ -2171,5 +2171,4 @@ def bind_kv_cache(
         forward_ctx = ctx[layer_name]
         assert len(forward_ctx.kv_cache) == len(kv_cache)
         for ve, ve_kv_cache in enumerate(kv_cache):
-            assert forward_ctx.kv_cache[ve].numel() == 0
             forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx]
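For context, bind_kv_cache fills each attention layer's slot in the static forward context with the per-virtual-engine KV cache tensor; this hunk drops the assertion that the slot is still empty, so a layer can be re-bound (the OpenVINO worker further down relies on that to swap the profiling cache for a placeholder). The following is a hedged, simplified sketch of that behaviour rather than the real helper, which also maps layer names to cache indices and handles deduplication; LayerForwardContext and bind_kv_cache_sketch are illustrative names only.

from dataclasses import dataclass
from typing import Dict, List

import torch


@dataclass
class LayerForwardContext:
    # One KV-cache slot per virtual engine (pipeline-parallel copy).
    kv_cache: List[torch.Tensor]


def bind_kv_cache_sketch(ctx: Dict[str, LayerForwardContext],
                         kv_cache: List[List[torch.Tensor]]) -> None:
    # kv_cache[ve][i] is the cache of the i-th attention layer on
    # virtual engine ve.
    for i, layer_name in enumerate(sorted(ctx)):
        forward_ctx = ctx[layer_name]
        assert len(forward_ctx.kv_cache) == len(kv_cache)
        for ve, ve_kv_cache in enumerate(kv_cache):
            # No "numel() == 0" check any more: re-binding is allowed.
            forward_ctx.kv_cache[ve] = ve_kv_cache[i]


# Example: two layers, a single virtual engine.
ctx = {
    f"model.layers.{i}.attn": LayerForwardContext(kv_cache=[torch.tensor([])])
    for i in range(2)
}
bind_kv_cache_sketch(ctx, [[torch.zeros(8), torch.zeros(8)]])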
@@ -28,6 +28,7 @@ from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler,
 from vllm.attention import AttentionMetadata, get_attn_backend
 from vllm.config import DeviceConfig, VllmConfig
 from vllm.distributed.parallel_state import get_world_group
+from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.lora.layers import LoRAMapping
 from vllm.lora.request import LoRARequest
@@ -40,7 +41,8 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import (IntermediateTensors, SequenceData,
                            SequenceGroupMetadata)
-from vllm.utils import is_pin_memory_available, make_tensor_with_pad
+from vllm.utils import (bind_kv_cache, is_pin_memory_available,
+                        make_tensor_with_pad)
 from vllm.worker.model_runner_base import (
     ModelRunnerBase, ModelRunnerInputBase,
     _add_attn_metadata_broadcastable_dict,
@@ -1286,6 +1288,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
     def profile_run(self) -> None:
         num_layers = self.model_config.get_num_layers(self.parallel_config)
         kv_caches = [None] * num_layers
+        bind_kv_cache(
+            self.vllm_config.compilation_config.static_forward_context,
+            [kv_caches])
         max_seq_len = self.bucketing_global_state.prompt_seq_bucket_cfg[-1]
         max_batch_size = min(self.max_num_batched_tokens // max_seq_len,
                              self.scheduler_config.max_num_seqs)
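The profile_run hunk binds a list of None placeholders, one per layer, before any real cache exists, so the dummy profiling forward can still resolve every layer through the static forward context. A hedged toy illustration of that placeholder binding follows; the structures are stand-ins, not vLLM types.

from types import SimpleNamespace

num_layers = 4                               # stand-in for get_num_layers(...)
static_forward_context = {
    f"model.layers.{i}.attn": SimpleNamespace(kv_cache=[None])
    for i in range(num_layers)
}
kv_caches = [None] * num_layers              # placeholders, no memory allocated
for i, name in enumerate(sorted(static_forward_context)):
    # Each layer's single virtual-engine slot points at the placeholder.
    static_forward_context[name].kv_cache[0] = kv_caches[i]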
@@ -1943,7 +1948,11 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
                 f"graphs{'T' if use_graphs else 'F'}")
         else:
             model_event_name = 'model_executable'
-        with self.profiler.record_event('internal', model_event_name):
+        with set_forward_context(
+                model_input.attn_metadata, self.vllm_config,
+                model_input.virtual_engine), \
+                self.profiler.record_event(
+                    'internal', model_event_name):
             hidden_states = self.model.forward(
                 **execute_model_kwargs,
                 selected_token_indices=sampling_metadata.selected_token_indices
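This hunk, like the analogous ones in the Neuron, OpenVINO, TPU, and XPU runners below, wraps the model call in set_forward_context so that the attention metadata and virtual-engine index are published for exactly one forward pass; the unified attention ops read them back from there instead of receiving them as arguments. Below is a hedged sketch of that context-manager pattern, simplified relative to the real helper (which also carries the vllm_config and the per-layer static context); all names are illustrative.

from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class ForwardContext:
    attn_metadata: Any
    virtual_engine: int


_forward_context: Optional[ForwardContext] = None


def get_forward_context() -> ForwardContext:
    assert _forward_context is not None, "not inside set_forward_context()"
    return _forward_context


@contextmanager
def set_forward_context_sketch(attn_metadata: Any, virtual_engine: int = 0):
    global _forward_context
    prev = _forward_context
    _forward_context = ForwardContext(attn_metadata, virtual_engine)
    try:
        yield
    finally:
        # Restore on exit, even if the forward pass raised.
        _forward_context = prev


# Usage, mirroring the runner change above:
#     with set_forward_context_sketch(attn_metadata, virtual_engine):
#         hidden_states = model(input_ids, positions, ...)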
@@ -20,6 +20,7 @@ from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sequence import ExecuteModelRequest
+from vllm.utils import bind_kv_cache
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.hpu_model_runner import HPUModelRunner
 from vllm.worker.model_runner_base import ModelRunnerBase
@@ -215,6 +216,8 @@ class HPUWorker(LocalOrDistributedWorkerBase):
             self.cache_engine[ve].gpu_cache
             for ve in range(self.parallel_config.pipeline_parallel_size)
         ]
+        bind_kv_cache(self.compilation_config.static_forward_context,
+                      self.hpu_cache)

     def _warm_up_model(self) -> None:
         # NOTE(kzawora): We should use virtual engine index here
@@ -8,6 +8,7 @@ from torch import nn
 from transformers_neuronx.config import GenerationConfig

 from vllm.config import VllmConfig
+from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.sampler import SamplerOutput
@@ -314,13 +315,15 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
             raise ValueError(
                 "NeuronModelRunner does not support multi-step execution.")

-        hidden_states = self.model(
-            input_ids=model_input.input_tokens,
-            positions=model_input.input_positions,
-            input_block_ids=model_input.input_block_ids,
-            **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
-                                         device=self.device),
-        )
+        with set_forward_context(None, self.vllm_config, 0):
+            hidden_states = self.model(
+                input_ids=model_input.input_tokens,
+                positions=model_input.input_positions,
+                input_block_ids=model_input.input_block_ids,
+                **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs
+                                             or {},
+                                             device=self.device),
+            )

         # Compute the logits only if the on-device sampling is turned off as
         # on-device sampling outputs the token ids.
@@ -8,6 +8,7 @@ from torch import nn
 from vllm.attention import get_attn_backend
 from vllm.attention.backends.openvino import OpenVINOAttentionMetadata
 from vllm.config import VllmConfig
+from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.sampler import SamplerOutput
@@ -350,7 +351,8 @@ class OpenVINOModelRunner(ModelRunnerBase):
                                          device=self.device),
         }

-        hidden_states = model_executable(**execute_model_kwargs)
+        with set_forward_context(attn_metadata, self.vllm_config, 0):
+            hidden_states = model_executable(**execute_model_kwargs)

         # Compute the logits.
         logits = self.model.compute_logits(hidden_states, sampling_metadata)
@@ -20,6 +20,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.platforms import current_platform
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
+from vllm.utils import bind_kv_cache
 from vllm.worker.openvino_model_runner import OpenVINOModelRunner
 from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase

@@ -339,6 +340,8 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
             ov_device,
         )
         self.kv_cache = self.cache_engine.kv_cache
+        bind_kv_cache(self.compilation_config.static_forward_context,
+                      [self.kv_cache])
         self.model_runner.block_size = self.cache_engine.block_size

         assert self.kv_cache is not None
@@ -507,12 +510,18 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):

         self.model_runner.block_size = tmp_cache_config.block_size

+        bind_kv_cache(self.compilation_config.static_forward_context,
+                      profiling_cache_engine.kv_cache)
         # Run the model with the dummy inputs.
         self.model_runner.execute_model(seqs,
                                         profiling_cache_engine.kv_cache)

-        # explicitly delete temporary KV cache manager to free KV cache
-        # when real inputs will be passed to OV
+        # Explicitly revert bind_kv_cache and delete temporary KV cache
+        # manager to free KV cache when real inputs will be passed to OV
+        bind_kv_cache(self.compilation_config.static_forward_context, [[
+            torch.tensor([])
+            for _ in range(len(profiling_cache_engine.kv_cache))
+        ]])
         del profiling_cache_engine

         logger.info(
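The second bind_kv_cache call above is the revert step: every layer slot is pointed back at an empty tensor so the forward context stops referencing the profiling cache and the subsequent del can actually free it. A hedged toy version of that re-binding, with simplified stand-in structures rather than vLLM types:

from types import SimpleNamespace

import torch

# Pretend two layers were bound to a profiling cache.
static_forward_context = {
    f"model.layers.{i}.attn": SimpleNamespace(kv_cache=[torch.zeros(16)])
    for i in range(2)
}
# Rebind every slot to an empty placeholder ...
for layer_ctx in static_forward_context.values():
    layer_ctx.kv_cache[0] = torch.tensor([])
# ... so nothing keeps the profiling cache alive once it is deleted.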
@@ -13,6 +13,7 @@ import torch_xla.runtime as xr

 from vllm.attention import AttentionMetadata, get_attn_backend
 from vllm.config import VllmConfig
+from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.model_loader import get_model
@@ -265,8 +266,9 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
             torch._dynamo.mark_dynamic(t, 0)
             torch._dynamo.mark_dynamic(p, 0)
         # Dummy run.
-        self.model(token_ids, position_ids, attn_metadata, input_lens, t, p,
-                   num_samples, kv_caches)
+        with set_forward_context(attn_metadata, self.vllm_config, 0):
+            self.model(token_ids, position_ids, attn_metadata, input_lens, t,
+                       p, num_samples, kv_caches)

     def warmup_model(
         self,
@@ -663,10 +665,13 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
                 input_lens = model_input.input_lens[i:i + 1].to(self.device)
                 t = model_input.t[i:i + 1].to(self.device)
                 p = model_input.p[i:i + 1].to(self.device)
-                output_token_ids = self.model(token_ids, position_ids,
-                                              attn_metadata, input_lens, t, p,
-                                              model_input.num_samples,
-                                              kv_caches)
+                with set_forward_context(model_input.attn_metadata,
+                                         self.vllm_config,
+                                         model_input.virtual_engine):
+                    output_token_ids = self.model(token_ids, position_ids,
+                                                  attn_metadata, input_lens, t,
+                                                  p, model_input.num_samples,
+                                                  kv_caches)
                 next_token_ids.append(output_token_ids[0])
                 start_idx = end_idx

@@ -711,10 +716,13 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
             input_lens = model_input.input_lens.to(self.device)
             for i in range(num_steps):
                 slot_mapping = attn_metadata.slot_mapping
-                output_token_ids = self.model(token_ids, position_ids,
-                                              attn_metadata, input_lens, t, p,
-                                              model_input.num_samples,
-                                              kv_caches)
+                with set_forward_context(model_input.attn_metadata,
+                                         self.vllm_config,
+                                         model_input.virtual_engine):
+                    output_token_ids = self.model(token_ids, position_ids,
+                                                  attn_metadata, input_lens, t,
+                                                  p, model_input.num_samples,
+                                                  kv_caches)
                 self.cached_step_outputs.append(output_token_ids)

                 if i < num_steps - 1:
@@ -12,7 +12,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
 from vllm.logger import init_logger
 from vllm.model_executor import set_random_seed
 from vllm.sequence import ExecuteModelRequest
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, bind_kv_cache, get_dtype_size
 from vllm.worker.tpu_model_runner import ExecutionMode, TPUModelRunner
 from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
                                      LoraNotSupportedWorkerBase, WorkerBase,
@@ -108,6 +108,8 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
                       torch.tensor([], dtype=torch.float32,
                                    device=self.device))
                      for _ in range(num_layers)]
+        bind_kv_cache(self.compilation_config.static_forward_context,
+                      [kv_caches])
         self.model_runner._dummy_run(
             batch_size=1,
             seq_len=self.scheduler_config.max_num_batched_tokens,
@@ -170,6 +172,8 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
                                           device="cpu")
                 cpu_v_cache = torch.zeros_like(cpu_k_cache)
                 self.cpu_cache.append((cpu_k_cache, cpu_v_cache))
+        bind_kv_cache(self.compilation_config.static_forward_context,
+                      [self.tpu_cache])
         self._warmup_model()

     def _warmup_model(self) -> None:
@@ -12,6 +12,7 @@ import torch.nn as nn
 from vllm.attention import get_attn_backend
 from vllm.config import VllmConfig
 from vllm.distributed import get_pp_group
+from vllm.forward_context import set_forward_context
 from vllm.inputs import INPUT_REGISTRY, InputRegistry
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadataCache
@@ -562,15 +563,17 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
         if (self.observability_config is not None
                 and self.observability_config.collect_model_forward_time):
             model_forward_start_time = time.time()

-        hidden_or_intermediate_states = model_executable(
-            input_ids=model_input.input_tokens,
-            positions=model_input.input_positions,
-            kv_caches=kv_caches,
-            attn_metadata=model_input.attn_metadata,
-            intermediate_tensors=intermediate_tensors,
-            **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
-                                         device=self.device))
+        with set_forward_context(model_input.attn_metadata, self.vllm_config,
+                                 model_input.virtual_engine):
+            hidden_or_intermediate_states = model_executable(
+                input_ids=model_input.input_tokens,
+                positions=model_input.input_positions,
+                kv_caches=kv_caches,
+                attn_metadata=model_input.attn_metadata,
+                intermediate_tensors=intermediate_tensors,
+                **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs
+                                             or {},
+                                             device=self.device))
         # Compute the logits in the last pipeline stage.
         if not get_pp_group().is_last_rank:
             return hidden_or_intermediate_states
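Taken together, the pattern repeated across these backends is: the worker binds the KV caches into a static per-layer context once (bind_kv_cache), the runner opens set_forward_context around each forward pass, and Attention.forward calls a unified op that looks both up by layer name. A hedged end-to-end toy of that flow, where every name and structure below is illustrative rather than a vLLM API:

from contextlib import contextmanager
from types import SimpleNamespace

import torch

_LAYER_CTX = {"model.layers.0.attn": SimpleNamespace(kv_cache=[torch.zeros(4)])}
_FWD_CTX = {"attn_metadata": None, "virtual_engine": 0}


@contextmanager
def toy_set_forward_context(attn_metadata, virtual_engine=0):
    prev = dict(_FWD_CTX)
    _FWD_CTX.update(attn_metadata=attn_metadata, virtual_engine=virtual_engine)
    try:
        yield
    finally:
        _FWD_CTX.update(prev)


def toy_unified_attention(query, key, value, layer_name):
    # Resolve cache and metadata from context instead of taking them as args.
    kv_cache = _LAYER_CTX[layer_name].kv_cache[_FWD_CTX["virtual_engine"]]
    _ = (_FWD_CTX["attn_metadata"], kv_cache)  # a real backend would use these
    return torch.nn.functional.scaled_dot_product_attention(query, key, value)


q = k = v = torch.randn(1, 2, 3, 8)
with toy_set_forward_context(attn_metadata={"dummy": True}):
    out = toy_unified_attention(q, k, v, "model.layers.0.attn")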