Hacky hacky

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Tyler Michael Smith 2024-12-18 15:13:26 -05:00
parent 2ca830dbaa
commit 0c7e6c1e36
2 changed files with 29 additions and 12 deletions


@@ -2305,8 +2305,8 @@ class CompilationConfig(BaseModel):
     backend: str = ""
     custom_ops: List[str] = Field(default_factory=list)
     splitting_ops: List[str] = Field(default_factory=lambda: [
-        "vllm.unified_attention",
-        "vllm.unified_attention_with_output",
+        # "vllm.unified_attention",
+        # "vllm.unified_attention_with_output",
     ])
     use_inductor: bool = True
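
The change above empties the default splitting_ops list, so piecewise compilation no longer splits the model graph at the unified attention ops. A minimal usage sketch, assuming CompilationConfig is importable from vllm.config and accepts splitting_ops as a field, of how the previous splitting behaviour could still be requested explicitly:

# Hypothetical sketch, not part of this commit: re-enable splitting at the
# attention ops by passing them explicitly instead of relying on defaults.
from vllm.config import CompilationConfig  # import path assumed

compilation_config = CompilationConfig(
    splitting_ops=[
        "vllm.unified_attention",
        "vllm.unified_attention_with_output",
    ],
)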


@@ -118,6 +118,20 @@ class GPUModelRunner:
             dtype=self.dtype,
             device=self.device)

+        # Attention metadata related persistent buffers
+        self.query_start_loc = torch.zeros(self.max_num_reqs + 1,
+                                           dtype=torch.int32,
+                                           device=self.device)
+        self.seq_start_loc = torch.zeros(self.max_num_reqs + 1,
+                                         dtype=torch.int32,
+                                         device=self.device)
+        self.slot_mapping = torch.zeros(
+            self.max_num_tokens,
+            # CPU slot_mapping is int32, but
+            # this one must be int64
+            dtype=torch.int64,
+            device=self.device)
+
         # OPTIMIZATION: Cache the tensors rather than creating them every step.
         self.arange_np = np.arange(max(self.max_num_reqs, self.max_model_len),
                                    dtype=np.int32)
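
The new tensors follow a persistent-buffer pattern: each attention-metadata tensor is allocated on the GPU once, at its maximum possible size, and every step copies into a prefix of it instead of allocating a fresh device tensor. A standalone sketch of that pattern, with illustrative names and sizes that are not taken from vLLM:

import torch

MAX_TOKENS = 8192  # made-up capacity for the example
use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"

# Allocate the device buffer once, at its maximum size.
slot_mapping_gpu = torch.zeros(MAX_TOKENS, dtype=torch.int64, device=device)
# Keep a CPU staging buffer (pinned when CUDA is available) so the
# host-to-device copy can run asynchronously.
slot_mapping_cpu = torch.zeros(MAX_TOKENS, dtype=torch.int32,
                               pin_memory=use_cuda)

def fill_for_step(num_tokens: int) -> torch.Tensor:
    # Per step, copy only the prefix that is actually used; copy_() also
    # performs the int32 -> int64 widening mentioned in the diff comment.
    slot_mapping_gpu[:num_tokens].copy_(slot_mapping_cpu[:num_tokens],
                                        non_blocking=True)
    # The result aliases the same device memory every step, so its address
    # stays stable instead of being reallocated.
    return slot_mapping_gpu

Because the device buffers never move, downstream code can hold on to them, as the metadata construction in the next hunk does with self.query_start_loc, self.seq_start_loc, and self.slot_mapping, rather than receiving a new tensor each step.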
@@ -337,27 +351,30 @@ class GPUModelRunner:
             self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
         self.positions[:total_num_scheduled_tokens].copy_(
             self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True)
-        query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to(
-            self.device, non_blocking=True)
-        seq_start_loc = self.seq_start_loc_cpu[:num_reqs + 1].to(
-            self.device, non_blocking=True)
-        slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to(
-            self.device, non_blocking=True).long()
+
+        self.query_start_loc[:num_reqs + 1].copy_(
+            self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
+        self.seq_start_loc[:num_reqs + 1].copy_(
+            self.seq_start_loc_cpu[:num_reqs + 1], non_blocking=True)
+        self.slot_mapping[:total_num_scheduled_tokens].copy_(
+            self.slot_mapping_cpu[:total_num_scheduled_tokens],
+            non_blocking=True)
+
         attn_metadata = FlashAttentionMetadata(
             num_actual_tokens=total_num_scheduled_tokens,
             max_query_len=max_num_scheduled_tokens,
-            query_start_loc=query_start_loc,
+            query_start_loc=self.query_start_loc,
             max_seq_len=max_seq_len,
-            seq_start_loc=seq_start_loc,
+            seq_start_loc=self.seq_start_loc,
             block_table=self.input_batch.block_table[:num_reqs],
-            slot_mapping=slot_mapping,
+            slot_mapping=self.slot_mapping,
         )
         # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
         # request in the batch. While we should not sample any token from this
         # partial request, we do so for simplicity. We will ignore the sampled
         # token from the partial request.
         # TODO: Support prompt logprobs.
-        logits_indices = query_start_loc[1:] - 1
+        logits_indices = self.query_start_loc[1:num_reqs + 1] - 1
         return attn_metadata, logits_indices

     def _prepare_sampling(
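
Since query_start_loc is now a persistent buffer sized for max_num_reqs + 1 entries, only the first num_reqs + 1 of them are valid for the current batch, which is why the logits-index computation slices explicitly instead of using [1:]. A small self-contained sketch of that computation, with made-up request lengths:

import torch

num_reqs = 3
query_lens = torch.tensor([4, 1, 2], dtype=torch.int32)  # tokens scheduled per request

# Persistent buffer, deliberately longer than the batch; entries past
# num_reqs may hold stale values from earlier steps.
query_start_loc = torch.zeros(8, dtype=torch.int32)
query_start_loc[1:num_reqs + 1] = torch.cumsum(query_lens, dim=0)

# Index of each request's last scheduled token in the packed token stream.
logits_indices = query_start_loc[1:num_reqs + 1] - 1
print(logits_indices.tolist())  # [3, 4, 6]

# The old form, query_start_loc[1:] - 1, was fine for a freshly allocated
# tensor but would now also pick up the stale tail of the persistent buffer.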