[Bugfix][Misc]: fix graph capture for decoder (#9549)

This commit is contained in:
yudian0504 2024-10-22 01:33:30 +08:00 committed by GitHub
parent f6b97293aa
commit 8ca8954841
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -828,7 +828,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
cuda_graph_pad_size = self._get_cuda_graph_pad_size(
num_seqs=len(seq_lens),
max_decode_seq_len=max_encoder_seq_len,
max_decode_seq_len=max_decode_seq_len,
max_encoder_seq_len=max_encoder_seq_len)
batch_size = len(input_tokens)