[BugFix][Minor] Fix full cuda graph bug when max_num_seqs < 512 (#19171)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
commit af7fc84fd2
parent 0678b52251
@@ -1737,7 +1737,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # has num_tokens in total.
         assert num_tokens <= self.scheduler_config.max_num_batched_tokens
         max_num_reqs = self.scheduler_config.max_num_seqs
-        num_reqs = max_num_reqs if num_tokens >= max_num_reqs else num_tokens
+        num_reqs = min(num_tokens, max_num_reqs)
         min_tokens_per_req = num_tokens // num_reqs
         num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
         num_scheduled_tokens_list[-1] += num_tokens % num_reqs
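For a concrete sense of the sizing logic in this hunk, here is a minimal standalone sketch of how the dummy batch splits num_tokens across at most max_num_reqs requests (distribute_tokens and the sample values are illustrative, not part of vLLM):

    def distribute_tokens(num_tokens: int, max_num_reqs: int) -> list[int]:
        # Cap the dummy request count at the scheduler limit; min() is
        # equivalent to the old conditional but states the intent directly.
        num_reqs = min(num_tokens, max_num_reqs)
        min_tokens_per_req = num_tokens // num_reqs
        num_scheduled_tokens = [min_tokens_per_req] * num_reqs
        # The last request absorbs the remainder so the split sums to num_tokens.
        num_scheduled_tokens[-1] += num_tokens % num_reqs
        return num_scheduled_tokens

    # 10 tokens over at most 4 requests -> [2, 2, 2, 4], which sums to 10.
    assert distribute_tokens(10, 4) == [2, 2, 2, 4]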
@@ -1765,7 +1765,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                     self.kv_cache_config.kv_cache_groups):
                 attn_metadata_i = (
                     self.attn_metadata_builders[kv_cache_group_id].build(
-                        num_reqs=num_tokens,
+                        num_reqs=num_reqs,
                         num_actual_tokens=num_tokens,
                         max_query_len=num_tokens,
                         common_prefix_len=0,
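The second hunk is the actual fix: the metadata builder was told there were num_tokens requests, but during full CUDA graph capture num_tokens can reach the largest capture size (512 by default), which overruns per-request state whenever max_num_seqs < 512. A hedged illustration of the failure mode (the buffer name query_start_loc and its max_num_reqs + 1 sizing are assumptions for this sketch, not verified vLLM internals):

    import numpy as np

    max_num_reqs = 256   # scheduler_config.max_num_seqs, smaller than 512
    num_tokens = 512     # largest default CUDA graph capture size

    # Hypothetical per-request buffer with one entry per request plus a
    # leading zero, as query_start_loc-style arrays are commonly laid out.
    query_start_loc = np.zeros(max_num_reqs + 1, dtype=np.int32)

    num_reqs_old = num_tokens                     # pre-fix: num_reqs=num_tokens
    num_reqs_new = min(num_tokens, max_num_reqs)  # post-fix: capped count

    assert num_reqs_old + 1 > query_start_loc.shape[0]   # would overrun the buffer
    assert num_reqs_new + 1 <= query_start_loc.shape[0]  # stays within bounds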