mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-28 20:57:18 +08:00
feat: support for penalty and badwords for async + spec
Signed-off-by: izhuhaoran <izhuhaoran@qq.com>
This commit is contained in:
parent
6f1f6a916a
commit
461e7b3e32
@ -147,22 +147,15 @@ class InputProcessor:
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"vLLM V1 does not support per request user provided logits processors."
|
"vLLM V1 does not support per request user provided logits processors."
|
||||||
)
|
)
|
||||||
# Async scheduling + spec decode currently incompatible with some
|
# Async scheduling + spec decode currently incompatible with structured outputs
|
||||||
# sampling parameters.
|
|
||||||
if (
|
if (
|
||||||
self.vllm_config.speculative_config is not None
|
self.vllm_config.speculative_config is not None
|
||||||
and self.vllm_config.scheduler_config.async_scheduling
|
and self.vllm_config.scheduler_config.async_scheduling
|
||||||
and (
|
and params.structured_outputs
|
||||||
params.frequency_penalty != 0.0
|
|
||||||
or params.presence_penalty != 0.0
|
|
||||||
or params.repetition_penalty != 1.0
|
|
||||||
or params.bad_words_token_ids
|
|
||||||
or params.structured_outputs
|
|
||||||
)
|
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"async scheduling with spec decoding doesn't yet support "
|
"async scheduling with spec decoding doesn't yet support "
|
||||||
"penalties, bad words or structured outputs in sampling parameters."
|
"structured outputs in sampling parameters."
|
||||||
)
|
)
|
||||||
|
|
||||||
def _validate_params(
|
def _validate_params(
|
||||||
|
|||||||
@ -954,6 +954,31 @@ class InputBatch:
|
|||||||
if sampled_ids := sampled_token_ids[prev_index]:
|
if sampled_ids := sampled_token_ids[prev_index]:
|
||||||
req_output_token_ids[-len(sampled_ids) :] = sampled_ids
|
req_output_token_ids[-len(sampled_ids) :] = sampled_ids
|
||||||
|
|
||||||
|
def update_async_spec_token_ids(
|
||||||
|
self, draft_token_ids_cpu: list[list[int]] | None
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
In async scheduling case, update spec_token_ids in sampling metadata
|
||||||
|
with real draft token ids from prior step.
|
||||||
|
This is called right before they are needed by the rejection sampler
|
||||||
|
for penalty/bad_words computation.
|
||||||
|
"""
|
||||||
|
if draft_token_ids_cpu is None or self.prev_req_id_to_index is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
spec_token_ids = self.sampling_metadata.spec_token_ids
|
||||||
|
if not spec_token_ids:
|
||||||
|
return
|
||||||
|
|
||||||
|
for index, req_id in enumerate(self.req_ids):
|
||||||
|
prev_index = self.prev_req_id_to_index.get(req_id, default=None)
|
||||||
|
if prev_index is None:
|
||||||
|
continue
|
||||||
|
assert prev_index < len(draft_token_ids_cpu)
|
||||||
|
draft_ids = draft_token_ids_cpu[prev_index]
|
||||||
|
assert index < len(spec_token_ids)
|
||||||
|
spec_token_ids[index] = draft_ids
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_reqs(self) -> int:
|
def num_reqs(self) -> int:
|
||||||
return len(self.req_id_to_index)
|
return len(self.req_id_to_index)
|
||||||
|
|||||||
@ -591,15 +591,32 @@ class GPUModelRunner(
|
|||||||
# with dedicated stream for overlapping and event for coordination.
|
# with dedicated stream for overlapping and event for coordination.
|
||||||
self.valid_sampled_token_count_event: torch.Event | None = None
|
self.valid_sampled_token_count_event: torch.Event | None = None
|
||||||
self.valid_sampled_token_count_copy_stream: torch.cuda.Stream | None = None
|
self.valid_sampled_token_count_copy_stream: torch.cuda.Stream | None = None
|
||||||
|
self.valid_sampled_token_count_cpu: torch.Tensor | None = None
|
||||||
|
# Pre-allocated tensor for copying draft token ids to CPU,
|
||||||
|
# with dedicated stream for overlapping and event for coordination.
|
||||||
|
self.draft_token_ids_copy_event: torch.Event | None = None
|
||||||
|
self.draft_token_ids_copy_stream: torch.cuda.Stream | None = None
|
||||||
|
self.draft_token_ids_cpu: torch.Tensor | None = None
|
||||||
if self.use_async_scheduling and self.num_spec_tokens:
|
if self.use_async_scheduling and self.num_spec_tokens:
|
||||||
self.valid_sampled_token_count_event = torch.Event()
|
self.valid_sampled_token_count_event = torch.Event()
|
||||||
self.valid_sampled_token_count_copy_stream = torch.cuda.Stream()
|
self.valid_sampled_token_count_copy_stream = torch.cuda.Stream()
|
||||||
self.valid_sampled_token_count_cpu = torch.empty(
|
self.valid_sampled_token_count_cpu = torch.empty(
|
||||||
self.max_num_reqs,
|
self.max_num_reqs,
|
||||||
dtype=torch.int64,
|
dtype=torch.int64,
|
||||||
device="cpu",
|
device="cpu",
|
||||||
pin_memory=self.pin_memory,
|
pin_memory=self.pin_memory,
|
||||||
)
|
)
|
||||||
|
self.draft_token_ids_copy_event = torch.Event()
|
||||||
|
self.draft_token_ids_copy_stream = torch.cuda.Stream()
|
||||||
|
self.draft_token_ids_cpu = torch.empty(
|
||||||
|
(self.max_num_reqs, self.num_spec_tokens),
|
||||||
|
dtype=torch.int64,
|
||||||
|
device="cpu",
|
||||||
|
pin_memory=self.pin_memory,
|
||||||
|
)
|
||||||
|
# Flag to track if valid draft tokens were copied this step.
|
||||||
|
# Reset to False at step start, set True in _copy_draft_token_ids.
|
||||||
|
self._has_draft_tokens: bool = False
|
||||||
|
|
||||||
# Ephemeral state transferred between execute_model() and sample_tokens().
|
# Ephemeral state transferred between execute_model() and sample_tokens().
|
||||||
self.execute_model_state: ExecuteModelState | None = None
|
self.execute_model_state: ExecuteModelState | None = None
|
||||||
@ -2555,15 +2572,19 @@ class GPUModelRunner(
|
|||||||
) -> SamplerOutput:
|
) -> SamplerOutput:
|
||||||
# Sample the next token and get logprobs if needed.
|
# Sample the next token and get logprobs if needed.
|
||||||
sampling_metadata = self.input_batch.sampling_metadata
|
sampling_metadata = self.input_batch.sampling_metadata
|
||||||
|
# Update output token ids with tokens sampled in last step
|
||||||
|
# if async scheduling and required by current sampling params.
|
||||||
|
self.input_batch.update_async_output_token_ids()
|
||||||
if spec_decode_metadata is None:
|
if spec_decode_metadata is None:
|
||||||
# Update output token ids with tokens sampled in last step
|
|
||||||
# if async scheduling and required by current sampling params.
|
|
||||||
self.input_batch.update_async_output_token_ids()
|
|
||||||
return self.sampler(
|
return self.sampler(
|
||||||
logits=logits,
|
logits=logits,
|
||||||
sampling_metadata=sampling_metadata,
|
sampling_metadata=sampling_metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Update spec_token_ids with real draft tokens from previous step
|
||||||
|
draft_token_ids_cpu = self._get_draft_token_ids_cpu()
|
||||||
|
self.input_batch.update_async_spec_token_ids(draft_token_ids_cpu)
|
||||||
|
|
||||||
sampler_output = self.rejection_sampler(
|
sampler_output = self.rejection_sampler(
|
||||||
spec_decode_metadata,
|
spec_decode_metadata,
|
||||||
None, # draft_probs
|
None, # draft_probs
|
||||||
@ -3370,6 +3391,55 @@ class GPUModelRunner(
|
|||||||
self.valid_sampled_token_count_event.synchronize()
|
self.valid_sampled_token_count_event.synchronize()
|
||||||
return counts_cpu[: prev_sampled_token_ids.shape[0]].tolist()
|
return counts_cpu[: prev_sampled_token_ids.shape[0]].tolist()
|
||||||
|
|
||||||
|
def _copy_draft_token_ids(
|
||||||
|
self, draft_token_ids: torch.Tensor, num_reqs: int
|
||||||
|
) -> None:
|
||||||
|
"""Copy draft token ids to CPU asynchronously.
|
||||||
|
|
||||||
|
This is used for async scheduling with spec decode + penalty/bad_words.
|
||||||
|
The draft_token_ids will be used in the next step to update
|
||||||
|
input_batch.spec_token_ids for correct penalty/bad_words computation.
|
||||||
|
"""
|
||||||
|
if self.draft_token_ids_copy_event is None or not isinstance(
|
||||||
|
draft_token_ids, torch.Tensor
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
|
default_stream = torch.cuda.current_stream()
|
||||||
|
with torch.cuda.stream(self.draft_token_ids_copy_stream):
|
||||||
|
self.draft_token_ids_copy_stream.wait_stream(default_stream) # type: ignore
|
||||||
|
# Copy draft_token_ids [num_reqs, num_spec_tokens] to pinned CPU
|
||||||
|
assert self.draft_token_ids_cpu is not None
|
||||||
|
self.draft_token_ids_cpu[:num_reqs].copy_(
|
||||||
|
draft_token_ids[:num_reqs], non_blocking=True
|
||||||
|
)
|
||||||
|
self.draft_token_ids_copy_event.record()
|
||||||
|
self._has_draft_tokens = True
|
||||||
|
|
||||||
|
def _get_draft_token_ids_cpu(self) -> list[list[int]] | None:
|
||||||
|
"""Get previously copied draft token ids from CPU.
|
||||||
|
|
||||||
|
Called at the beginning of the next step to update spec_token_ids
|
||||||
|
for async scheduling with spec decode + penalty/bad_words.
|
||||||
|
Returns None if no draft tokens were copied in previous step.
|
||||||
|
"""
|
||||||
|
if isinstance(self._draft_token_ids, list):
|
||||||
|
return self._draft_token_ids
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.draft_token_ids_copy_event is None
|
||||||
|
or self.draft_token_ids_cpu is None
|
||||||
|
or not self._has_draft_tokens
|
||||||
|
):
|
||||||
|
return None
|
||||||
|
|
||||||
|
self._has_draft_tokens = False
|
||||||
|
self.draft_token_ids_copy_event.synchronize()
|
||||||
|
num_reqs = self.input_batch.num_reqs
|
||||||
|
if num_reqs == 0:
|
||||||
|
return None
|
||||||
|
return self.draft_token_ids_cpu[:num_reqs].tolist()
|
||||||
|
|
||||||
def propose_draft_token_ids(
|
def propose_draft_token_ids(
|
||||||
self,
|
self,
|
||||||
scheduler_output: "SchedulerOutput",
|
scheduler_output: "SchedulerOutput",
|
||||||
@ -3530,6 +3600,9 @@ class GPUModelRunner(
|
|||||||
mm_embed_inputs=mm_embed_inputs,
|
mm_embed_inputs=mm_embed_inputs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._copy_draft_token_ids(
|
||||||
|
self._draft_token_ids, self.input_batch.num_reqs
|
||||||
|
)
|
||||||
return draft_token_ids
|
return draft_token_ids
|
||||||
|
|
||||||
def update_config(self, overrides: dict[str, Any]) -> None:
|
def update_config(self, overrides: dict[str, Any]) -> None:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user