commit 0c7e6c1e36
parent 2ca830dbaa

    Hacky hacky

    Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
@@ -2305,8 +2305,8 @@ class CompilationConfig(BaseModel):
     backend: str = ""
     custom_ops: List[str] = Field(default_factory=list)
     splitting_ops: List[str] = Field(default_factory=lambda: [
-        "vllm.unified_attention",
-        "vllm.unified_attention_with_output",
+        # "vllm.unified_attention",
+        # "vllm.unified_attention_with_output",
     ])

     use_inductor: bool = True
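The hunk above flips the default splitting_ops to an empty list: the two attention ops survive only as comments inside the factory, so the default now evaluates to []. In vLLM's piecewise compilation these names mark the custom ops at which the traced graph is cut, so attention runs outside the compiled (and CUDA-graph-captured) pieces; with no splitting ops, the graph is no longer split at attention. A self-contained miniature of the field (not vLLM's actual class) showing what the commented-out default produces:

from typing import List
from pydantic import BaseModel, Field

# Miniature of the field touched above, for illustration only.
# With the entries commented out, the factory returns an empty list.
class MiniCompilationConfig(BaseModel):
    splitting_ops: List[str] = Field(default_factory=lambda: [
        # "vllm.unified_attention",
        # "vllm.unified_attention_with_output",
    ])

print(MiniCompilationConfig().splitting_ops)  # []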
@@ -118,6 +118,20 @@ class GPUModelRunner:
             dtype=self.dtype,
             device=self.device)

+        # Attention metadata related persistent buffers
+        self.query_start_loc = torch.zeros(self.max_num_reqs + 1,
+                                           dtype=torch.int32,
+                                           device=self.device)
+        self.seq_start_loc = torch.zeros(self.max_num_reqs + 1,
+                                         dtype=torch.int32,
+                                         device=self.device)
+        self.slot_mapping = torch.zeros(
+            self.max_num_tokens,
+            # CPU slot_mapping is int32, but
+            # this one must be int64
+            dtype=torch.int64,
+            device=self.device)
+
         # OPTIMIZATION: Cache the tensors rather than creating them every step.
         self.arange_np = np.arange(max(self.max_num_reqs, self.max_model_len),
                                    dtype=np.int32)
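This hunk preallocates device-side buffers for query_start_loc, seq_start_loc, and slot_mapping, sized for the maximum batch rather than the current one. A minimal sketch of the pattern, with illustrative sizes and names rather than vLLM's actual values:

import torch

# Allocate once at startup; the buffer keeps a stable device address.
device = "cuda" if torch.cuda.is_available() else "cpu"
max_num_reqs = 256
query_start_loc = torch.zeros(max_num_reqs + 1,
                              dtype=torch.int32,
                              device=device)

# Per step: write into a prefix slice with copy_ instead of creating a
# fresh device tensor with .to(device). non_blocking=True lets the H2D
# copy overlap with CPU work when the source is in pinned host memory.
num_reqs = 8
src = torch.arange(num_reqs + 1, dtype=torch.int32)
if device == "cuda":
    src = src.pin_memory()
query_start_loc[:num_reqs + 1].copy_(src, non_blocking=True)

The commit message does not spell out the motivation, but the usual reasons for this pattern are avoiding a fresh allocation every step and, more relevant to the splitting_ops change above, keeping the tensors at stable device addresses so a captured CUDA graph can be replayed against the same buffers.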
@@ -337,27 +351,30 @@ class GPUModelRunner:
             self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
         self.positions[:total_num_scheduled_tokens].copy_(
             self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True)
-        query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to(
-            self.device, non_blocking=True)
-        seq_start_loc = self.seq_start_loc_cpu[:num_reqs + 1].to(
-            self.device, non_blocking=True)
-        slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to(
-            self.device, non_blocking=True).long()
+        self.query_start_loc[:num_reqs + 1].copy_(
+            self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
+        self.seq_start_loc[:num_reqs + 1].copy_(
+            self.seq_start_loc_cpu[:num_reqs + 1], non_blocking=True)
+        self.slot_mapping[:total_num_scheduled_tokens].copy_(
+            self.slot_mapping_cpu[:total_num_scheduled_tokens],
+            non_blocking=True)

         attn_metadata = FlashAttentionMetadata(
             num_actual_tokens=total_num_scheduled_tokens,
             max_query_len=max_num_scheduled_tokens,
-            query_start_loc=query_start_loc,
+            query_start_loc=self.query_start_loc,
             max_seq_len=max_seq_len,
-            seq_start_loc=seq_start_loc,
+            seq_start_loc=self.seq_start_loc,
             block_table=self.input_batch.block_table[:num_reqs],
-            slot_mapping=slot_mapping,
+            slot_mapping=self.slot_mapping,
         )
         # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
         # request in the batch. While we should not sample any token from this
         # partial request, we do so for simplicity. We will ignore the sampled
         # token from the partial request.
         # TODO: Support prompt logprobs.
-        logits_indices = query_start_loc[1:] - 1
+        logits_indices = self.query_start_loc[1:num_reqs + 1] - 1
         return attn_metadata, logits_indices

     def _prepare_sampling(
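The last hunk both retargets FlashAttentionMetadata at the persistent buffers and tightens the logits_indices slice. The old query_start_loc tensor was exactly num_reqs + 1 elements long, so [1:] was safe; the persistent buffer is longer and may carry stale entries past num_reqs + 1 from an earlier, larger batch. A worked example with invented values:

import torch

# query_start_loc holds a running sum: request i owns tokens
# [query_start_loc[i], query_start_loc[i + 1]) of the flattened batch,
# so query_start_loc[i + 1] - 1 indexes its last token.
num_query_tokens = torch.tensor([3, 1, 4])  # 3 requests this step
num_reqs = num_query_tokens.numel()

# Persistent buffer sized for a larger max batch; entries beyond
# num_reqs + 1 may hold stale values from a previous step.
query_start_loc = torch.zeros(9, dtype=torch.int64)
query_start_loc[1:num_reqs + 1] = torch.cumsum(num_query_tokens, dim=0)

# The slice must stop at num_reqs + 1, or stale tail entries leak in.
logits_indices = query_start_loc[1:num_reqs + 1] - 1
print(logits_indices)  # tensor([2, 3, 7]): last token of each request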