Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-16 10:45:45 +08:00)
[Misc] Various simplifications and typing fixes (#5368)

parent: 76477a93b7
commit: a008629807
@@ -78,7 +78,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):

         # Since there's only one sequence per sequence group, we can take the
         # first sample.
-        samples = [outputs[step].samples[0] for step in range(len(outputs))]
+        samples = [output.samples[0] for output in outputs]

         # -1 means the output token is not valid (eg. due to spec decode
         # rejecting tokens).
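Note: this hunk is a readability cleanup, iterating over the outputs list directly instead of indexing by position. A minimal sketch of the equivalence (the StepOutput class and token values below are made-up stand-ins for vLLM's SamplerOutput, not code from this commit):

    from dataclasses import dataclass
    from typing import List

    @dataclass
    class StepOutput:  # hypothetical stand-in for vLLM's SamplerOutput
        samples: List[int]

    outputs = [StepOutput(samples=[7, 8]), StepOutput(samples=[9])]

    # Index-based form (before the change).
    samples_old = [outputs[step].samples[0] for step in range(len(outputs))]
    # Direct iteration (after the change): same result, less indexing noise.
    samples_new = [output.samples[0] for output in outputs]

    assert samples_old == samples_new == [7, 9]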
@@ -306,8 +306,10 @@ class RejectionSampler(nn.Module):

         # Fill in the first k columns of the output tensor using masks and data
         # tensors.
-        output[:, :k] = torch.where(accepted_mask, draft_token_ids,
-                                    -torch.ones_like(draft_token_ids))
+        torch.where(accepted_mask,
+                    draft_token_ids,
+                    -torch.ones_like(draft_token_ids),
+                    out=output)

         # Fill the last column.
         # We check output directly as accepted may have True values inconsistent
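Note: the rejection-sampler hunk replaces an assignment into output[:, :k] with torch.where(..., out=...), writing the result directly into the preallocated tensor. A small self-contained sketch of the out= pattern, with made-up shapes and token ids:

    import torch

    k = 3
    # Preallocated result with one extra column for a bonus token.
    output_with_bonus = torch.full((2, k + 1), -2, dtype=torch.long)
    output = output_with_bonus[:, :k]  # view over the first k columns

    accepted_mask = torch.tensor([[True, True, False],
                                  [True, False, False]])
    draft_token_ids = torch.tensor([[11, 12, 13],
                                    [21, 22, 23]])

    # out= writes through the view, so the parent tensor ends up holding the
    # accepted draft tokens and -1 for rejected positions.
    torch.where(accepted_mask, draft_token_ids,
                -torch.ones_like(draft_token_ids), out=output)

    assert output_with_bonus[0].tolist() == [11, 12, -1, -2]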
@@ -80,7 +80,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):

         target_sampler_output = self._scorer_worker.execute_model(
             execute_model_req=execute_model_req.clone(
-                seq_group_metadata_list=target_seq_group_metadata_list, ))
+                seq_group_metadata_list=target_seq_group_metadata_list))

         assert len(target_sampler_output) == 1, "expected single-step output"
         target_sampler_output = target_sampler_output[0]
@@ -140,8 +140,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
             num_scoring_tokens)

     def _contract_batch(
-            self, contracted_bs: int,
-            target_sampler_output: List[SamplerOutput],
+            self, contracted_bs: int, target_sampler_output: SamplerOutput,
             proposals: SpeculativeProposals, num_scoring_tokens: int,
             non_spec_indices: List[int], spec_indices: List[int],
             k: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -167,30 +166,16 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         non_spec_expanded_bs, _ = non_spec_target_token_ids.shape
         spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs

-        target_token_ids = target_token_ids.squeeze().reshape(
-            spec_expanded_bs, k + 1)
-        target_probs = target_probs.squeeze().reshape(spec_expanded_bs, k + 1,
-                                                      self._vocab_size)
-        target_logprobs = target_logprobs.squeeze().reshape(
-            spec_expanded_bs, k + 1, self._vocab_size)
+        target_token_ids = target_token_ids.reshape(spec_expanded_bs, k + 1)
+        target_probs = target_probs.reshape(*target_token_ids.shape,
+                                            self._vocab_size)
+        target_logprobs = target_logprobs.reshape(target_probs.shape)

-        all_tokens = torch.full(size=(contracted_bs, k + 1),
-                                fill_value=-1,
-                                device=self._device,
-                                dtype=torch.long)
-        all_probs = torch.zeros(contracted_bs,
-                                k + 1,
-                                self._vocab_size,
-                                device=self._device,
-                                dtype=torch.float32)
-        all_logprobs = torch.full(size=(
-            contracted_bs,
-            k + 1,
-            self._vocab_size,
-        ),
-                                  fill_value=-float("inf"),
-                                  device=self._device,
-                                  dtype=torch.float32)
+        all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1),
+                                               fill_value=-1)
+        all_probs = target_probs.new_zeros(*all_tokens.shape, self._vocab_size)
+        all_logprobs = target_logprobs.new_full(size=all_probs.shape,
+                                                fill_value=-float("inf"))

         if non_spec_indices:
             all_tokens[non_spec_indices, :1] = non_spec_target_token_ids
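Note: the batch-contraction hunk switches from torch.full/torch.zeros with explicit dtype= and device= arguments to Tensor.new_full/new_zeros, which inherit both from the tensor they are called on. A rough sketch with illustrative shapes (not the real vocab size or batch size):

    import torch

    vocab_size, contracted_bs, k = 10, 4, 2

    # Illustrative stand-ins for tensors that already carry the right dtype
    # and device.
    target_token_ids = torch.tensor([[5, 6, 7], [8, 9, 1]], dtype=torch.long)
    target_probs = torch.rand(2, 3, vocab_size, dtype=torch.float32)

    # new_full / new_zeros inherit dtype and device from the source tensor,
    # so the explicit dtype=/device= arguments can be dropped.
    all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1),
                                           fill_value=-1)
    all_probs = target_probs.new_zeros(*all_tokens.shape, vocab_size)
    all_logprobs = target_probs.new_full(size=all_probs.shape,
                                         fill_value=-float("inf"))

    assert all_tokens.dtype == torch.long
    assert all_probs.dtype == all_logprobs.dtype == torch.float32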
@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple

 import torch

+from vllm.config import SpeculativeConfig
 from vllm.distributed.communication_op import broadcast_tensor_dict
 from vllm.logger import init_logger
 from vllm.model_executor.layers.rejection_sampler import RejectionSampler
@@ -30,7 +31,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
     WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config.
     """
     assert "speculative_config" in kwargs
-    speculative_config = kwargs.get("speculative_config")
+    speculative_config: SpeculativeConfig = kwargs.get("speculative_config")
     assert speculative_config is not None

     target_worker = Worker(*args, **kwargs)
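Note: the only change here is a type annotation on the value pulled out of kwargs, so static checkers see a SpeculativeConfig rather than Any. A hypothetical sketch of the same pattern, using a stand-in config class rather than vllm.config.SpeculativeConfig:

    from typing import Any


    class StubSpeculativeConfig:  # hypothetical stand-in, not vllm.config
        num_speculative_tokens: int = 5


    def create_worker(**kwargs: Any) -> int:
        assert "speculative_config" in kwargs
        # Annotating the local narrows kwargs.get()'s Any result, so the
        # attribute access below is checked by mypy/pyright.
        speculative_config: StubSpeculativeConfig = kwargs.get(
            "speculative_config")
        assert speculative_config is not None
        return speculative_config.num_speculative_tokens


    print(create_worker(speculative_config=StubSpeculativeConfig()))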
@@ -109,12 +110,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         logger.info("Configuring SpecDecodeWorker with proposer=%s",
                     type(proposer_worker))

-        return SpecDecodeWorker(
-            proposer_worker,
-            scorer_worker,
-            disable_by_batch_size=disable_by_batch_size,
-            rejection_sampler=RejectionSampler(
-                disable_bonus_tokens=disable_bonus_tokens, ))
+        return SpecDecodeWorker(proposer_worker,
+                                scorer_worker,
+                                disable_by_batch_size=disable_by_batch_size,
+                                rejection_sampler=RejectionSampler(
+                                    disable_bonus_tokens=disable_bonus_tokens))

     def __init__(
         self,
@@ -148,7 +148,8 @@ class Top1Proposer(SpeculativeProposer):
             nonzero_proposal_len_indices,
         )

-    def _remove_no_proposal_seqs(self, proposal_lens, maybe_sampler_output,
+    @staticmethod
+    def _remove_no_proposal_seqs(proposal_lens, maybe_sampler_output,
                                  nonzero_proposal_len_indices, transposed):
         """Remove sequences from nonzero_proposal_len_indices and reset
         their proposal_len to 0 the draft worker does not provide a proposal
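Note: since _remove_no_proposal_seqs never touches self, it becomes a @staticmethod. A toy sketch of the general idea (the helper body below is invented for illustration and does not mirror the real logic):

    class Proposer:

        @staticmethod
        def _keep_nonzero(proposal_lens, indices):
            # No `self` parameter: the helper only transforms its arguments,
            # which the decorator makes explicit.
            return [i for i, plen in zip(indices, proposal_lens) if plen > 0]


    # Callable on the class or on an instance; no instance state is involved.
    assert Proposer._keep_nonzero([2, 0, 3], [0, 1, 2]) == [0, 2]
    assert Proposer()._keep_nonzero([2, 0, 3], [0, 1, 2]) == [0, 2]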
@@ -207,7 +208,7 @@ class Top1Proposer(SpeculativeProposer):
         self,
         batch_size: int,
         proposal_len: int,
-        maybe_sampler_output: Optional[SamplerOutput],
+        maybe_sampler_output: Optional[List[SamplerOutput]],
         proposal_lens: List[int],
         nonzero_proposal_len_indices: List[int],
         sampler_transposed: bool,
@@ -218,25 +219,19 @@ class Top1Proposer(SpeculativeProposer):
         if maybe_sampler_output is None:
             # If no speculative tokens, the sampler output will be None.
             # In this case we return empty proposals.
-            proposal_tokens = torch.full(
-                size=(
-                    batch_size,
-                    proposal_len,
-                ),
-                fill_value=-1,
-                dtype=torch.long,
-                device=self._device,
-            )
-            proposal_probs = torch.zeros(
-                batch_size,
-                proposal_len,
-                self._vocab_size,
-                dtype=torch.float32,
-                device=self._device,
-            )
-            proposal_lens_tensor = torch.zeros(len(proposal_lens),
-                                               dtype=torch.long,
-                                               device=self._device)
+            proposal_tokens = torch.tensor(-1,
+                                           dtype=torch.long,
+                                           device=self._device).expand(
+                                               batch_size, proposal_len)
+            proposal_probs = torch.tensor(0,
+                                          dtype=torch.float32,
+                                          device=self._device).expand(
+                                              batch_size, proposal_len,
+                                              self._vocab_size)
+            proposal_lens_tensor = torch.tensor(0,
+                                                dtype=torch.long,
+                                                device=self._device).expand(
+                                                    len(proposal_lens))
             return proposal_tokens, proposal_probs, proposal_lens_tensor

         sampler_output = maybe_sampler_output
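Note: for the no-proposal case, the hunk replaces fully materialized placeholder tensors with scalar tensors expanded to the target shape; expand() returns a stride-0 broadcast view, so nothing of size batch_size x proposal_len x vocab_size is actually allocated. A small sketch with made-up sizes:

    import torch

    batch_size, proposal_len, vocab_size = 4, 3, 8

    # Before: materialize a full tensor of -1s.
    full_tokens = torch.full((batch_size, proposal_len), -1, dtype=torch.long)

    # After: expand a scalar to the same shape. expand() returns a broadcast
    # view with stride 0, so no per-element storage is allocated.
    expanded_tokens = torch.tensor(-1, dtype=torch.long).expand(
        batch_size, proposal_len)

    assert torch.equal(full_tokens, expanded_tokens)
    assert expanded_tokens.stride() == (0, 0)

    # The same trick covers the (batch, proposal_len, vocab) probability
    # placeholder, which would otherwise be the largest allocation here.
    expanded_probs = torch.tensor(0, dtype=torch.float32).expand(
        batch_size, proposal_len, vocab_size)
    assert expanded_probs.shape == (batch_size, proposal_len, vocab_size)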
@@ -246,18 +241,14 @@ class Top1Proposer(SpeculativeProposer):
         # Now, reformat the output GPU tensors such that each sequence has
         # a proposal. the proposal can be empty, e.g. [-1, -1, -1]

-        entire_proposal_tokens = torch.full(
+        entire_proposal_tokens = proposal_tokens.new_full(
             size=(batch_size, *proposal_tokens.shape[1:]),
             fill_value=-1,
-            dtype=torch.long,
-            device=self._device,
         )
         entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens
-        entire_proposal_probs = torch.zeros(
+        entire_proposal_probs = proposal_probs.new_zeros(
             batch_size,
             *proposal_probs.shape[1:],
-            dtype=torch.float32,
-            device=self._device,
         )
         entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs
@@ -1,12 +1,11 @@
 from contextlib import contextmanager
-from itertools import chain
 from typing import Dict, List, Tuple

 import torch

 from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
                            SamplerOutput, SequenceGroupMetadata,
-                           SequenceGroupOutput, SequenceOutput)
+                           SequenceOutput)

 SeqId = int
@@ -16,11 +15,7 @@ def get_all_seq_ids(
     """Given a list of SequenceGroupMetadata, create a list of all
     sequence ids.
     """
-    return list(
-        chain.from_iterable([
-            seq_group_metadata.seq_data.keys()
-            for seq_group_metadata in seq_group_metadata_list
-        ]))
+    return [seq_id for sg in seq_group_metadata_list for seq_id in sg.seq_data]


 def get_all_num_logprobs(
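Note: get_all_seq_ids drops chain.from_iterable in favor of a nested list comprehension; iterating a dict yields its keys, so the .keys() call becomes implicit. A minimal sketch of the equivalence, using plain dicts as stand-ins for SequenceGroupMetadata:

    from itertools import chain

    # Plain dicts standing in for SequenceGroupMetadata.seq_data
    # (seq_id -> SequenceData); iterating a dict yields its keys.
    seq_group_metadata_list = [{"seq_data": {0: "a", 1: "b"}},
                               {"seq_data": {5: "c"}}]

    old = list(
        chain.from_iterable(
            [sg["seq_data"].keys() for sg in seq_group_metadata_list]))
    new = [seq_id for sg in seq_group_metadata_list
           for seq_id in sg["seq_data"]]

    assert old == new == [0, 1, 5]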
@@ -68,7 +63,7 @@ def create_sequence_group_output(
     seq_id: SeqId,
     topk_token_ids: List[int],
     topk_logprobs: List[float],
-) -> SequenceGroupOutput:
+) -> CompletionSequenceGroupOutput:
     """Create a SequenceGroupOutput given the sampling results.

     Args:
@@ -1,4 +1,4 @@
-from typing import Dict, Optional
+from typing import Dict, Optional, Type

 from transformers import PretrainedConfig
@@ -9,7 +9,7 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,

 logger = init_logger(__name__)

-_CONFIG_REGISTRY: Dict[str, PretrainedConfig] = {
+_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     "chatglm": ChatGLMConfig,
     "dbrx": DbrxConfig,
     "mpt": MPTConfig,
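Note: the registry maps model types to config classes, not config instances, so the value type should be Type[PretrainedConfig]. A generic sketch of the same annotation with a stand-in base class:

    from typing import Dict, Type


    class BaseConfig:
        pass


    class ChatGLMLikeConfig(BaseConfig):  # illustrative stand-in class
        pass


    # The values are classes, so the annotation is Type[BaseConfig];
    # Dict[str, BaseConfig] would claim the values are instances.
    _REGISTRY: Dict[str, Type[BaseConfig]] = {
        "chatglm_like": ChatGLMLikeConfig,
    }

    config_cls = _REGISTRY["chatglm_like"]
    config = config_cls()  # checkers know this constructs a BaseConfig
    assert isinstance(config, BaseConfig)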
@@ -68,4 +68,4 @@ def get_hf_text_config(config: PretrainedConfig):
         assert hasattr(config.text_config, "num_attention_heads")
         return config.text_config
-    else:
-        return config
+    return config
@@ -527,16 +527,6 @@ class ModelRunner:
             )
         assert max_query_len > 0, ("query_lens: {}".format(query_lens))

-        context_lens_tensor = torch.tensor(context_lens,
-                                           dtype=torch.int,
-                                           device=self.device)
-        query_lens_tensor = torch.tensor(query_lens,
-                                         dtype=torch.long,
-                                         device=self.device)
-        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
-                                      dtype=torch.int32,
-                                      device=self.device)
-
         seq_lens_tensor = torch.tensor(seq_lens,
                                        dtype=torch.int,
                                        device=self.device)
@@ -544,11 +534,6 @@
                                      dtype=torch.int32,
                                      device=self.device)

-        torch.cumsum(query_lens_tensor,
-                     dim=0,
-                     dtype=query_start_loc.dtype,
-                     out=query_start_loc[1:])
-
         torch.cumsum(seq_lens_tensor,
                      dim=0,
                      dtype=seq_start_loc.dtype,
@@ -601,6 +586,21 @@
                 seq_start_loc=seq_start_loc,
                 data_type=kv_cache_dtype)
         else:
+            context_lens_tensor = torch.tensor(context_lens,
+                                               dtype=torch.int,
+                                               device=self.device)
+            query_lens_tensor = torch.tensor(query_lens,
+                                             dtype=torch.long,
+                                             device=self.device)
+            query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
+                                          dtype=torch.int32,
+                                          device=self.device)
+
+            torch.cumsum(query_lens_tensor,
+                         dim=0,
+                         dtype=query_start_loc.dtype,
+                         out=query_start_loc[1:])
+
             attn_metadata = self.attn_backend.make_metadata(
                 num_prefills=num_prefills,
                 slot_mapping=slot_mapping_tensor,
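Note: the model-runner hunks move the context/query tensor construction into the else branch where it is actually used; the cumsum-into-a-slice pattern builds a start-offset (prefix-sum) tensor with a leading zero. A standalone sketch with illustrative query lengths:

    import torch

    # Illustrative per-sequence query lengths.
    query_lens = [3, 1, 4]
    query_lens_tensor = torch.tensor(query_lens, dtype=torch.long)

    # One extra leading zero so that query_start_loc[i]:query_start_loc[i+1]
    # brackets the tokens of sequence i.
    query_start_loc = torch.zeros(len(query_lens) + 1, dtype=torch.int32)
    torch.cumsum(query_lens_tensor,
                 dim=0,
                 dtype=query_start_loc.dtype,
                 out=query_start_loc[1:])

    assert query_start_loc.tolist() == [0, 3, 4, 8]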