Defensively copy sampling_params (#2881)

If the SamplingParams object passed to LLMEngine.add_request() is mutated after it returns, it could affect the async sampling process for that request.

Suggested by @Yard1 https://github.com/vllm-project/vllm/pull/2514#discussion_r1490106059
This commit is contained in:
Nick Hill 2024-02-17 11:18:04 -08:00 committed by GitHub
parent 5f08050d8d
commit 185b2c29e2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -464,6 +464,9 @@ class LLMEngine:
prompt_token_ids[:prefix_pos], lora_request.lora_int_id
if lora_request else 0) if prefix_pos is not None else None
# Defensive copy of SamplingParams, which are used by the sampler
sampling_params = copy.deepcopy(sampling_params)
# Create the sequence group.
seq_group = SequenceGroup(request_id, [seq], sampling_params,
arrival_time, lora_request, prefix)