[Misc] Various simplifications and typing fixes (#5368)

Author: Nick Hill, 2024-06-10 19:29:02 -07:00 (committed by GitHub)
Parent: 76477a93b7
Commit: a008629807
8 changed files with 62 additions and 89 deletions

View File

@@ -78,7 +78,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
         # Since there's only one sequence per sequence group, we can take the
         # first sample.
-        samples = [outputs[step].samples[0] for step in range(len(outputs))]
+        samples = [output.samples[0] for output in outputs]
         # -1 means the output token is not valid (eg. due to spec decode
         # rejecting tokens).
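
The simplification above replaces `range(len(...))` indexing with direct iteration over the outputs. A minimal standalone sketch of the same pattern, using a made-up `Step` container rather than the real vLLM output type:

from dataclasses import dataclass
from typing import List


@dataclass
class Step:  # hypothetical stand-in for a per-step output object
    samples: List[str]


outputs = [Step(samples=["a1", "a2"]), Step(samples=["b1"])]

# Index-based form (before): walks indices via range(len(...)).
first_by_index = [outputs[step].samples[0] for step in range(len(outputs))]

# Direct iteration (after): same result without the redundant indexing.
first_direct = [output.samples[0] for output in outputs]

assert first_by_index == first_direct == ["a1", "b1"]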

View File

@@ -306,8 +306,10 @@ class RejectionSampler(nn.Module):
         # Fill in the first k columns of the output tensor using masks and data
         # tensors.
-        output[:, :k] = torch.where(accepted_mask, draft_token_ids,
-                                    -torch.ones_like(draft_token_ids))
+        torch.where(accepted_mask,
+                    draft_token_ids,
+                    -torch.ones_like(draft_token_ids),
+                    out=output)
         # Fill the last column.
         # We check output directly as accepted may have True values inconsistent
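
The rewritten call passes the preallocated destination to `torch.where` through its `out=` argument instead of materializing a temporary tensor and assigning it into a slice. A small sketch of that pattern with arbitrary shapes (not the actual RejectionSampler inputs):

import torch

batch_size, k = 4, 3
draft_token_ids = torch.randint(0, 100, (batch_size, k))
accepted_mask = torch.rand(batch_size, k) > 0.5

# Preallocated destination, e.g. a view into a larger result tensor.
output = torch.empty(batch_size, k, dtype=draft_token_ids.dtype)

# Before-style: allocate the result of torch.where, then copy it over.
result_tmp = torch.where(accepted_mask, draft_token_ids,
                         -torch.ones_like(draft_token_ids))

# After-style: write the selected values straight into the existing buffer.
torch.where(accepted_mask,
            draft_token_ids,
            -torch.ones_like(draft_token_ids),
            out=output)

assert torch.equal(output, result_tmp)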

View File

@@ -80,7 +80,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         target_sampler_output = self._scorer_worker.execute_model(
             execute_model_req=execute_model_req.clone(
-                seq_group_metadata_list=target_seq_group_metadata_list, ))
+                seq_group_metadata_list=target_seq_group_metadata_list))

         assert len(target_sampler_output) == 1, "expected single-step output"
         target_sampler_output = target_sampler_output[0]
@@ -140,8 +140,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
                 num_scoring_tokens)

     def _contract_batch(
-            self, contracted_bs: int,
-            target_sampler_output: List[SamplerOutput],
+            self, contracted_bs: int, target_sampler_output: SamplerOutput,
             proposals: SpeculativeProposals, num_scoring_tokens: int,
             non_spec_indices: List[int], spec_indices: List[int],
             k: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -167,30 +166,16 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         non_spec_expanded_bs, _ = non_spec_target_token_ids.shape
         spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs

-        target_token_ids = target_token_ids.squeeze().reshape(
-            spec_expanded_bs, k + 1)
-        target_probs = target_probs.squeeze().reshape(spec_expanded_bs, k + 1,
-                                                      self._vocab_size)
-        target_logprobs = target_logprobs.squeeze().reshape(
-            spec_expanded_bs, k + 1, self._vocab_size)
+        target_token_ids = target_token_ids.reshape(spec_expanded_bs, k + 1)
+        target_probs = target_probs.reshape(*target_token_ids.shape,
+                                            self._vocab_size)
+        target_logprobs = target_logprobs.reshape(target_probs.shape)

-        all_tokens = torch.full(size=(contracted_bs, k + 1),
-                                fill_value=-1,
-                                device=self._device,
-                                dtype=torch.long)
-        all_probs = torch.zeros(contracted_bs,
-                                k + 1,
-                                self._vocab_size,
-                                device=self._device,
-                                dtype=torch.float32)
-        all_logprobs = torch.full(size=(
-            contracted_bs,
-            k + 1,
-            self._vocab_size,
-        ),
-                                  fill_value=-float("inf"),
-                                  device=self._device,
-                                  dtype=torch.float32)
+        all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1),
+                                               fill_value=-1)
+        all_probs = target_probs.new_zeros(*all_tokens.shape, self._vocab_size)
+        all_logprobs = target_logprobs.new_full(size=all_probs.shape,
+                                                fill_value=-float("inf"))

         if non_spec_indices:
             all_tokens[non_spec_indices, :1] = non_spec_target_token_ids
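
The replacement constructors `Tensor.new_full` and `Tensor.new_zeros` inherit dtype and device from the tensor they are called on, which is why the explicit `device=` and `dtype=` arguments could be dropped. A short illustrative sketch with made-up sizes:

import torch

contracted_bs, k, vocab_size = 3, 2, 16

# Existing tensors whose dtype/device should be propagated.
target_token_ids = torch.randint(0, vocab_size, (5, k + 1))   # int64
target_probs = torch.rand(5, k + 1, vocab_size)               # float32

# Explicit construction (before-style): dtype and device spelled out by hand.
all_tokens_explicit = torch.full((contracted_bs, k + 1),
                                 -1,
                                 dtype=torch.long,
                                 device=target_token_ids.device)

# new_full / new_zeros (after-style): dtype and device come from the source.
all_tokens = target_token_ids.new_full((contracted_bs, k + 1), -1)
all_probs = target_probs.new_zeros(contracted_bs, k + 1, vocab_size)

assert all_tokens.dtype == target_token_ids.dtype
assert all_probs.dtype == target_probs.dtype
assert torch.equal(all_tokens, all_tokens_explicit)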

View File

@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple

 import torch

+from vllm.config import SpeculativeConfig
 from vllm.distributed.communication_op import broadcast_tensor_dict
 from vllm.logger import init_logger
 from vllm.model_executor.layers.rejection_sampler import RejectionSampler
@@ -30,7 +31,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
     WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config.
     """
     assert "speculative_config" in kwargs
-    speculative_config = kwargs.get("speculative_config")
+    speculative_config: SpeculativeConfig = kwargs.get("speculative_config")
     assert speculative_config is not None

     target_worker = Worker(*args, **kwargs)
@@ -109,12 +110,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         logger.info("Configuring SpecDecodeWorker with proposer=%s",
                     type(proposer_worker))

-        return SpecDecodeWorker(
-            proposer_worker,
-            scorer_worker,
-            disable_by_batch_size=disable_by_batch_size,
-            rejection_sampler=RejectionSampler(
-                disable_bonus_tokens=disable_bonus_tokens, ))
+        return SpecDecodeWorker(proposer_worker,
+                                scorer_worker,
+                                disable_by_batch_size=disable_by_batch_size,
+                                rejection_sampler=RejectionSampler(
+                                    disable_bonus_tokens=disable_bonus_tokens))

     def __init__(
         self,
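
The `speculative_config: SpeculativeConfig = kwargs.get(...)` annotation above mainly helps type checkers, since a value pulled out of `**kwargs` is otherwise untyped. A minimal sketch of the idea with a stand-in config class (not the real vLLM signature):

from dataclasses import dataclass
from typing import Optional


@dataclass
class DraftConfig:  # hypothetical stand-in for a config class such as SpeculativeConfig
    num_lookahead_slots: int = 4


def create_worker(**kwargs) -> int:
    # Without the annotation the value from kwargs.get is effectively untyped;
    # annotating it restores attribute checking for everything downstream.
    config: Optional[DraftConfig] = kwargs.get("draft_config")
    assert config is not None
    return config.num_lookahead_slots


print(create_worker(draft_config=DraftConfig(num_lookahead_slots=8)))  # prints 8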

View File

@@ -148,7 +148,8 @@ class Top1Proposer(SpeculativeProposer):
             nonzero_proposal_len_indices,
         )

-    def _remove_no_proposal_seqs(self, proposal_lens, maybe_sampler_output,
+    @staticmethod
+    def _remove_no_proposal_seqs(proposal_lens, maybe_sampler_output,
                                  nonzero_proposal_len_indices, transposed):
         """Remove sequences from nonzero_proposal_len_indices and reset
         their proposal_len to 0 the draft worker does not provide a proposal
@@ -207,7 +208,7 @@ class Top1Proposer(SpeculativeProposer):
         self,
         batch_size: int,
         proposal_len: int,
-        maybe_sampler_output: Optional[SamplerOutput],
+        maybe_sampler_output: Optional[List[SamplerOutput]],
         proposal_lens: List[int],
         nonzero_proposal_len_indices: List[int],
         sampler_transposed: bool,
@@ -218,25 +219,19 @@ class Top1Proposer(SpeculativeProposer):
         if maybe_sampler_output is None:
             # If no speculative tokens, the sampler output will be None.
             # In this case we return empty proposals.
-            proposal_tokens = torch.full(
-                size=(
-                    batch_size,
-                    proposal_len,
-                ),
-                fill_value=-1,
-                dtype=torch.long,
-                device=self._device,
-            )
-            proposal_probs = torch.zeros(
-                batch_size,
-                proposal_len,
-                self._vocab_size,
-                dtype=torch.float32,
-                device=self._device,
-            )
-            proposal_lens_tensor = torch.zeros(len(proposal_lens),
-                                               dtype=torch.long,
-                                               device=self._device)
+            proposal_tokens = torch.tensor(-1,
+                                           dtype=torch.long,
+                                           device=self._device).expand(
+                                               batch_size, proposal_len)
+            proposal_probs = torch.tensor(0,
+                                          dtype=torch.float32,
+                                          device=self._device).expand(
+                                              batch_size, proposal_len,
+                                              self._vocab_size)
+            proposal_lens_tensor = torch.tensor(0,
+                                                dtype=torch.long,
+                                                device=self._device).expand(
+                                                    len(proposal_lens))
             return proposal_tokens, proposal_probs, proposal_lens_tensor

         sampler_output = maybe_sampler_output
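
In the empty-proposal branch above, `torch.tensor(scalar).expand(...)` now returns a stride-0 broadcast view over a single element instead of allocating full placeholder tensors. A small sketch of the difference, with arbitrary sizes:

import torch

batch_size, proposal_len, vocab_size = 8, 4, 32000

# Full allocation: batch_size * proposal_len * vocab_size float32 elements.
full = torch.zeros(batch_size, proposal_len, vocab_size, dtype=torch.float32)

# Broadcast view: a single float32 element expanded to the same shape.
view = torch.tensor(0, dtype=torch.float32).expand(batch_size, proposal_len,
                                                   vocab_size)

assert view.shape == full.shape
assert view.stride() == (0, 0, 0)   # no storage beyond the one element
assert torch.equal(view, full)      # reads as all zeros everywhere
# Expanded views alias one element across all positions, so treat them as
# read-only placeholders rather than buffers to write into.
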
@@ -246,18 +241,14 @@ class Top1Proposer(SpeculativeProposer):
         # Now, reformat the output GPU tensors such that each sequence has
         # a proposal. the proposal can be empty, e.g. [-1, -1, -1]
-        entire_proposal_tokens = torch.full(
+        entire_proposal_tokens = proposal_tokens.new_full(
             size=(batch_size, *proposal_tokens.shape[1:]),
             fill_value=-1,
-            dtype=torch.long,
-            device=self._device,
         )
         entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens
-        entire_proposal_probs = torch.zeros(
+        entire_proposal_probs = proposal_probs.new_zeros(
             batch_size,
             *proposal_probs.shape[1:],
-            dtype=torch.float32,
-            device=self._device,
         )
         entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs

View File

@@ -1,12 +1,11 @@
 from contextlib import contextmanager
-from itertools import chain
 from typing import Dict, List, Tuple

 import torch

 from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
                            SamplerOutput, SequenceGroupMetadata,
-                           SequenceGroupOutput, SequenceOutput)
+                           SequenceOutput)

 SeqId = int
@@ -16,11 +15,7 @@ def get_all_seq_ids(
     """Given a list of SequenceGroupMetadata, create a list of all
     sequence ids.
     """
-    return list(
-        chain.from_iterable([
-            seq_group_metadata.seq_data.keys()
-            for seq_group_metadata in seq_group_metadata_list
-        ]))
+    return [seq_id for sg in seq_group_metadata_list for seq_id in sg.seq_data]


 def get_all_num_logprobs(
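
The one-line `get_all_seq_ids` relies on two idioms: a nested list comprehension flattens one level of nesting without `itertools.chain`, and iterating a dict (`sg.seq_data`) yields its keys, so `.keys()` is redundant. A standalone sketch with toy data (the dicts below are fabricated, not real `SequenceGroupMetadata`):

from itertools import chain

# Toy stand-ins: each "sequence group" maps seq_id -> per-sequence data.
seq_group_metadata_list = [
    {"seq_data": {0: "d0", 1: "d1"}},
    {"seq_data": {7: "d7"}},
]

# Before-style: chain.from_iterable over the per-group key views.
via_chain = list(
    chain.from_iterable(sg["seq_data"].keys()
                        for sg in seq_group_metadata_list))

# After-style: a flattened comprehension; iterating a dict yields its keys.
via_comprehension = [
    seq_id for sg in seq_group_metadata_list for seq_id in sg["seq_data"]
]

assert via_chain == via_comprehension == [0, 1, 7]
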
@@ -68,7 +63,7 @@ def create_sequence_group_output(
     seq_id: SeqId,
     topk_token_ids: List[int],
     topk_logprobs: List[float],
-) -> SequenceGroupOutput:
+) -> CompletionSequenceGroupOutput:
     """Create a SequenceGroupOutput given the sampling results.

     Args:

View File

@@ -1,4 +1,4 @@
-from typing import Dict, Optional
+from typing import Dict, Optional, Type

 from transformers import PretrainedConfig
@@ -9,7 +9,7 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
 logger = init_logger(__name__)

-_CONFIG_REGISTRY: Dict[str, PretrainedConfig] = {
+_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     "chatglm": ChatGLMConfig,
     "dbrx": DbrxConfig,
     "mpt": MPTConfig,
@@ -68,4 +68,4 @@ def get_hf_text_config(config: PretrainedConfig):
         assert hasattr(config.text_config, "num_attention_heads")
         return config.text_config
     else:
         return config

View File

@@ -527,16 +527,6 @@ class ModelRunner:
         )
         assert max_query_len > 0, ("query_lens: {}".format(query_lens))

-        context_lens_tensor = torch.tensor(context_lens,
-                                           dtype=torch.int,
-                                           device=self.device)
-        query_lens_tensor = torch.tensor(query_lens,
-                                         dtype=torch.long,
-                                         device=self.device)
-        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
-                                      dtype=torch.int32,
-                                      device=self.device)
         seq_lens_tensor = torch.tensor(seq_lens,
                                        dtype=torch.int,
                                        device=self.device)
@@ -544,11 +534,6 @@ class ModelRunner:
                                     dtype=torch.int32,
                                     device=self.device)

-        torch.cumsum(query_lens_tensor,
-                     dim=0,
-                     dtype=query_start_loc.dtype,
-                     out=query_start_loc[1:])
         torch.cumsum(seq_lens_tensor,
                      dim=0,
                      dtype=seq_start_loc.dtype,
@@ -601,6 +586,21 @@ class ModelRunner:
                 seq_start_loc=seq_start_loc,
                 data_type=kv_cache_dtype)
         else:
+            context_lens_tensor = torch.tensor(context_lens,
+                                               dtype=torch.int,
+                                               device=self.device)
+            query_lens_tensor = torch.tensor(query_lens,
+                                             dtype=torch.long,
+                                             device=self.device)
+            query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
+                                          dtype=torch.int32,
+                                          device=self.device)
+            torch.cumsum(query_lens_tensor,
+                         dim=0,
+                         dtype=query_start_loc.dtype,
+                         out=query_start_loc[1:])
             attn_metadata = self.attn_backend.make_metadata(
                 num_prefills=num_prefills,
                 slot_mapping=slot_mapping_tensor,
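
The last hunk moves the `context_lens_tensor`/`query_start_loc` construction into the `else:` branch so those tensors are built only on the code path that uses them. The `query_start_loc` idiom itself, a zero-initialized buffer of length `N + 1` whose tail is filled by `torch.cumsum(..., out=...)`, produces the usual prefix-sum offsets (each query's start index plus the total). A minimal sketch with made-up lengths:

import torch

query_lens = [3, 1, 4]
query_lens_tensor = torch.tensor(query_lens, dtype=torch.long)

# Length N + 1: position 0 stays 0 and the last entry is the total length.
query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                              dtype=torch.int32)

# Cumulative sum written directly into the tail of the preallocated buffer.
torch.cumsum(query_lens_tensor,
             dim=0,
             dtype=query_start_loc.dtype,
             out=query_start_loc[1:])

print(query_start_loc.tolist())  # [0, 3, 4, 8]
# Query i then occupies flat positions query_start_loc[i]:query_start_loc[i+1].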