[Misc] Various simplifications and typing fixes (#5368)

Nick Hill 2024-06-10 19:29:02 -07:00 committed by GitHub
parent 76477a93b7
commit a008629807
8 changed files with 62 additions and 89 deletions


@@ -78,7 +78,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
         # Since there's only one sequence per sequence group, we can take the
         # first sample.
-        samples = [outputs[step].samples[0] for step in range(len(outputs))]
+        samples = [output.samples[0] for output in outputs]

         # -1 means the output token is not valid (eg. due to spec decode
         # rejecting tokens).


@@ -306,8 +306,10 @@ class RejectionSampler(nn.Module):
         # Fill in the first k columns of the output tensor using masks and data
         # tensors.
-        output[:, :k] = torch.where(accepted_mask, draft_token_ids,
-                                    -torch.ones_like(draft_token_ids))
+        torch.where(accepted_mask,
+                    draft_token_ids,
+                    -torch.ones_like(draft_token_ids),
+                    out=output)

         # Fill the last column.
         # We check output directly as accepted may have True values inconsistent
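
Note on the change above: `torch.where` accepts an `out=` argument, so the selected tokens can be written directly into the preallocated output tensor instead of building a temporary result and copying it in through a slice assignment. A minimal standalone sketch of the pattern (the shapes and names below are illustrative, not taken from the vLLM code):

    import torch

    batch_size, k = 4, 3
    draft_token_ids = torch.randint(0, 100, (batch_size, k))
    accepted_mask = torch.rand(batch_size, k) > 0.5

    # Preallocated destination tensor; -1 marks rejected positions.
    output = torch.full((batch_size, k), -1, dtype=draft_token_ids.dtype)

    # Writes accepted draft tokens (or -1) straight into `output`, avoiding the
    # temporary tensor that a slice assignment of torch.where(...) would create.
    torch.where(accepted_mask,
                draft_token_ids,
                -torch.ones_like(draft_token_ids),
                out=output)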


@@ -80,7 +80,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         target_sampler_output = self._scorer_worker.execute_model(
             execute_model_req=execute_model_req.clone(
-                seq_group_metadata_list=target_seq_group_metadata_list, ))
+                seq_group_metadata_list=target_seq_group_metadata_list))

         assert len(target_sampler_output) == 1, "expected single-step output"
         target_sampler_output = target_sampler_output[0]
@@ -140,8 +140,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
             num_scoring_tokens)

     def _contract_batch(
-            self, contracted_bs: int,
-            target_sampler_output: List[SamplerOutput],
+            self, contracted_bs: int, target_sampler_output: SamplerOutput,
             proposals: SpeculativeProposals, num_scoring_tokens: int,
             non_spec_indices: List[int], spec_indices: List[int],
             k: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -167,30 +166,16 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         non_spec_expanded_bs, _ = non_spec_target_token_ids.shape
         spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs

-        target_token_ids = target_token_ids.squeeze().reshape(
-            spec_expanded_bs, k + 1)
-        target_probs = target_probs.squeeze().reshape(spec_expanded_bs, k + 1,
-                                                      self._vocab_size)
-        target_logprobs = target_logprobs.squeeze().reshape(
-            spec_expanded_bs, k + 1, self._vocab_size)
+        target_token_ids = target_token_ids.reshape(spec_expanded_bs, k + 1)
+        target_probs = target_probs.reshape(*target_token_ids.shape,
+                                            self._vocab_size)
+        target_logprobs = target_logprobs.reshape(target_probs.shape)

-        all_tokens = torch.full(size=(contracted_bs, k + 1),
-                                fill_value=-1,
-                                device=self._device,
-                                dtype=torch.long)
-        all_probs = torch.zeros(contracted_bs,
-                                k + 1,
-                                self._vocab_size,
-                                device=self._device,
-                                dtype=torch.float32)
-        all_logprobs = torch.full(size=(
-            contracted_bs,
-            k + 1,
-            self._vocab_size,
-        ),
-                                  fill_value=-float("inf"),
-                                  device=self._device,
-                                  dtype=torch.float32)
+        all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1),
+                                               fill_value=-1)
+        all_probs = target_probs.new_zeros(*all_tokens.shape, self._vocab_size)
+        all_logprobs = target_logprobs.new_full(size=all_probs.shape,
+                                                fill_value=-float("inf"))

         if non_spec_indices:
             all_tokens[non_spec_indices, :1] = non_spec_target_token_ids
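
Note on the `new_full`/`new_zeros` rewrite above: the `Tensor.new_*` constructors inherit the dtype and device of the tensor they are called on, which is why the explicit `device=`/`dtype=` arguments can be dropped. A rough illustration (toy shapes, not the real batch-expansion dimensions):

    import torch

    vocab_size = 8
    target_token_ids = torch.zeros(2, 4, dtype=torch.long)  # stand-in tensor
    target_probs = torch.zeros(2, 4, vocab_size)  # float32 by default

    # Same dtype/device as target_token_ids, filled with -1.
    all_tokens = target_token_ids.new_full(size=(3, 4), fill_value=-1)
    # Same dtype/device as target_probs, zero-filled.
    all_probs = target_probs.new_zeros(*all_tokens.shape, vocab_size)

    assert all_tokens.dtype == torch.long
    assert all_probs.dtype == torch.float32 and all_probs.shape == (3, 4, 8)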


@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple

 import torch

+from vllm.config import SpeculativeConfig
 from vllm.distributed.communication_op import broadcast_tensor_dict
 from vllm.logger import init_logger
 from vllm.model_executor.layers.rejection_sampler import RejectionSampler
@@ -30,7 +31,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
     WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config.
     """
     assert "speculative_config" in kwargs
-    speculative_config = kwargs.get("speculative_config")
+    speculative_config: SpeculativeConfig = kwargs.get("speculative_config")
     assert speculative_config is not None

     target_worker = Worker(*args, **kwargs)
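
Note on the annotation added above: a value pulled out of `**kwargs` is typed as `Any`, so annotating the local variable is what lets a type checker validate later attribute access. A generic sketch of the pattern (the class below is a placeholder, not vLLM's `SpeculativeConfig`):

    from typing import Any


    class FakeSpeculativeConfig:
        """Placeholder config class used only for this illustration."""
        num_speculative_tokens: int = 5


    def create_worker(**kwargs: Any) -> None:
        # Without the annotation the variable would be `Any`; with it, tools
        # like mypy check that the attributes used below actually exist.
        speculative_config: FakeSpeculativeConfig = kwargs.get(
            "speculative_config")
        assert speculative_config is not None
        print(speculative_config.num_speculative_tokens)


    create_worker(speculative_config=FakeSpeculativeConfig())
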
@@ -109,12 +110,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         logger.info("Configuring SpecDecodeWorker with proposer=%s",
                     type(proposer_worker))

-        return SpecDecodeWorker(
-            proposer_worker,
-            scorer_worker,
-            disable_by_batch_size=disable_by_batch_size,
-            rejection_sampler=RejectionSampler(
-                disable_bonus_tokens=disable_bonus_tokens, ))
+        return SpecDecodeWorker(proposer_worker,
+                                scorer_worker,
+                                disable_by_batch_size=disable_by_batch_size,
+                                rejection_sampler=RejectionSampler(
+                                    disable_bonus_tokens=disable_bonus_tokens))

     def __init__(
         self,


@@ -148,7 +148,8 @@ class Top1Proposer(SpeculativeProposer):
             nonzero_proposal_len_indices,
         )

-    def _remove_no_proposal_seqs(self, proposal_lens, maybe_sampler_output,
+    @staticmethod
+    def _remove_no_proposal_seqs(proposal_lens, maybe_sampler_output,
                                  nonzero_proposal_len_indices, transposed):
         """Remove sequences from nonzero_proposal_len_indices and reset
         their proposal_len to 0 if the draft worker does not provide a proposal
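
Note on the `@staticmethod` change above: the helper uses none of the instance state, so dropping `self` and marking it static makes that explicit and allows calling it on either the class or an instance. A generic sketch of the pattern (names are illustrative):

    class Proposer:

        @staticmethod
        def _keep_nonzero(proposal_lens, indices):
            # Pure function of its arguments; it needs no instance state,
            # hence no `self` parameter.
            return [i for i in indices if proposal_lens[i] > 0]


    # Works on the class and on an instance alike.
    assert Proposer._keep_nonzero([0, 2, 0, 1], [0, 1, 2, 3]) == [1, 3]
    assert Proposer()._keep_nonzero([0, 2], [0, 1]) == [1]
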
@@ -207,7 +208,7 @@ class Top1Proposer(SpeculativeProposer):
         self,
         batch_size: int,
         proposal_len: int,
-        maybe_sampler_output: Optional[SamplerOutput],
+        maybe_sampler_output: Optional[List[SamplerOutput]],
         proposal_lens: List[int],
         nonzero_proposal_len_indices: List[int],
         sampler_transposed: bool,
@@ -218,25 +219,19 @@ class Top1Proposer(SpeculativeProposer):
         if maybe_sampler_output is None:
             # If no speculative tokens, the sampler output will be None.
             # In this case we return empty proposals.
-            proposal_tokens = torch.full(
-                size=(
-                    batch_size,
-                    proposal_len,
-                ),
-                fill_value=-1,
-                dtype=torch.long,
-                device=self._device,
-            )
-            proposal_probs = torch.zeros(
-                batch_size,
-                proposal_len,
-                self._vocab_size,
-                dtype=torch.float32,
-                device=self._device,
-            )
-            proposal_lens_tensor = torch.zeros(len(proposal_lens),
-                                               dtype=torch.long,
-                                               device=self._device)
+            proposal_tokens = torch.tensor(-1,
+                                           dtype=torch.long,
+                                           device=self._device).expand(
+                                               batch_size, proposal_len)
+            proposal_probs = torch.tensor(0,
+                                          dtype=torch.float32,
+                                          device=self._device).expand(
+                                              batch_size, proposal_len,
+                                              self._vocab_size)
+            proposal_lens_tensor = torch.tensor(0,
+                                                dtype=torch.long,
+                                                device=self._device).expand(
+                                                    len(proposal_lens))
             return proposal_tokens, proposal_probs, proposal_lens_tensor

         sampler_output = maybe_sampler_output
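
Note on the `.expand(...)` rewrite above: expanding a scalar tensor produces a zero-stride broadcast view, so the empty-proposal placeholders no longer allocate full `batch_size x proposal_len (x vocab_size)` buffers. The expanded view has to be treated as read-only, which holds here because the placeholders are returned as-is. A small sketch of the behaviour:

    import torch

    batch_size, proposal_len, vocab_size = 4, 3, 10

    # A single scalar expanded to the target shape: no per-element storage.
    proposal_tokens = torch.tensor(-1, dtype=torch.long).expand(
        batch_size, proposal_len)
    proposal_probs = torch.tensor(0, dtype=torch.float32).expand(
        batch_size, proposal_len, vocab_size)

    assert proposal_tokens.shape == (batch_size, proposal_len)
    # Zero strides show that no extra memory was materialized.
    assert proposal_tokens.stride() == (0, 0)
    assert proposal_probs.stride() == (0, 0, 0)
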
@@ -246,18 +241,14 @@ class Top1Proposer(SpeculativeProposer):
         # Now, reformat the output GPU tensors such that each sequence has
         # a proposal. the proposal can be empty, e.g. [-1, -1, -1]
-        entire_proposal_tokens = torch.full(
+        entire_proposal_tokens = proposal_tokens.new_full(
             size=(batch_size, *proposal_tokens.shape[1:]),
             fill_value=-1,
-            dtype=torch.long,
-            device=self._device,
         )
         entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens
-        entire_proposal_probs = torch.zeros(
+        entire_proposal_probs = proposal_probs.new_zeros(
             batch_size,
             *proposal_probs.shape[1:],
-            dtype=torch.float32,
-            device=self._device,
         )
         entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs


@@ -1,12 +1,11 @@
 from contextlib import contextmanager
-from itertools import chain
 from typing import Dict, List, Tuple

 import torch

 from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
                            SamplerOutput, SequenceGroupMetadata,
-                           SequenceGroupOutput, SequenceOutput)
+                           SequenceOutput)

 SeqId = int
@@ -16,11 +15,7 @@ def get_all_seq_ids(
     """Given a list of SequenceGroupMetadata, create a list of all
     sequence ids.
     """
-    return list(
-        chain.from_iterable([
-            seq_group_metadata.seq_data.keys()
-            for seq_group_metadata in seq_group_metadata_list
-        ]))
+    return [seq_id for sg in seq_group_metadata_list for seq_id in sg.seq_data]


 def get_all_num_logprobs(
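
Note on the `get_all_seq_ids` rewrite above: iterating a dict yields its keys, so the nested list comprehension over `sg.seq_data` is equivalent to flattening `seq_data.keys()` with `itertools.chain.from_iterable`, minus the extra import and intermediate list. A quick equivalence check with stand-in objects (plain namespaces instead of real `SequenceGroupMetadata`):

    from itertools import chain
    from types import SimpleNamespace

    # Only the seq_data mapping matters for this helper.
    seq_group_metadata_list = [
        SimpleNamespace(seq_data={0: "a", 1: "b"}),
        SimpleNamespace(seq_data={7: "c"}),
    ]

    old = list(
        chain.from_iterable(sg.seq_data.keys()
                            for sg in seq_group_metadata_list))
    new = [seq_id for sg in seq_group_metadata_list for seq_id in sg.seq_data]

    assert old == new == [0, 1, 7]
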
@@ -68,7 +63,7 @@ def create_sequence_group_output(
         seq_id: SeqId,
         topk_token_ids: List[int],
         topk_logprobs: List[float],
-) -> SequenceGroupOutput:
+) -> CompletionSequenceGroupOutput:
     """Create a SequenceGroupOutput given the sampling results.

     Args:


@@ -1,4 +1,4 @@
-from typing import Dict, Optional
+from typing import Dict, Optional, Type

 from transformers import PretrainedConfig
@@ -9,7 +9,7 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,

 logger = init_logger(__name__)

-_CONFIG_REGISTRY: Dict[str, PretrainedConfig] = {
+_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     "chatglm": ChatGLMConfig,
     "dbrx": DbrxConfig,
     "mpt": MPTConfig,
@@ -68,4 +68,4 @@ def get_hf_text_config(config: PretrainedConfig):
         assert hasattr(config.text_config, "num_attention_heads")
         return config.text_config
     else:
-        return config
\ No newline at end of file
+        return config


@@ -527,16 +527,6 @@ class ModelRunner:
         )
         assert max_query_len > 0, ("query_lens: {}".format(query_lens))

-        context_lens_tensor = torch.tensor(context_lens,
-                                           dtype=torch.int,
-                                           device=self.device)
-        query_lens_tensor = torch.tensor(query_lens,
-                                         dtype=torch.long,
-                                         device=self.device)
-        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
-                                      dtype=torch.int32,
-                                      device=self.device)
         seq_lens_tensor = torch.tensor(seq_lens,
                                        dtype=torch.int,
                                        device=self.device)
@@ -544,11 +534,6 @@ class ModelRunner:
                                      dtype=torch.int32,
                                      device=self.device)

-        torch.cumsum(query_lens_tensor,
-                     dim=0,
-                     dtype=query_start_loc.dtype,
-                     out=query_start_loc[1:])
         torch.cumsum(seq_lens_tensor,
                      dim=0,
                      dtype=seq_start_loc.dtype,
@@ -601,6 +586,21 @@ class ModelRunner:
                 seq_start_loc=seq_start_loc,
                 data_type=kv_cache_dtype)
         else:
+            context_lens_tensor = torch.tensor(context_lens,
+                                               dtype=torch.int,
+                                               device=self.device)
+            query_lens_tensor = torch.tensor(query_lens,
+                                             dtype=torch.long,
+                                             device=self.device)
+            query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
+                                          dtype=torch.int32,
+                                          device=self.device)
+
+            torch.cumsum(query_lens_tensor,
+                         dim=0,
+                         dtype=query_start_loc.dtype,
+                         out=query_start_loc[1:])
+
             attn_metadata = self.attn_backend.make_metadata(
                 num_prefills=num_prefills,
                 slot_mapping=slot_mapping_tensor,
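
Note on the block moved into the `else:` branch above: `query_start_loc` is built by cumulatively summing the per-sequence query lengths into a zero-initialized tensor with one extra leading slot, so entry i is the start offset of sequence i in the flattened token batch. A minimal sketch of that pattern with toy lengths:

    import torch

    query_lens = [3, 1, 4]
    query_lens_tensor = torch.tensor(query_lens, dtype=torch.long)

    # One extra leading zero; cumsum fills positions 1..n with running totals.
    query_start_loc = torch.zeros(len(query_lens) + 1, dtype=torch.int32)
    torch.cumsum(query_lens_tensor,
                 dim=0,
                 dtype=query_start_loc.dtype,
                 out=query_start_loc[1:])

    # Sequence i occupies flattened positions
    # [query_start_loc[i], query_start_loc[i + 1]).
    assert query_start_loc.tolist() == [0, 3, 4, 8]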