[Bugfix] fix IntermediateTensors equal method (#23027)

Signed-off-by: Andy Xie <andy.xning@gmail.com>
This commit is contained in:
Ning Xie 2025-08-18 17:58:11 +08:00 committed by GitHub
parent 27e8d1ea3e
commit 5a30bd10d8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 45 additions and 3 deletions

View File

@ -2,10 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import (CompletionSequenceGroupOutput, SequenceData,
SequenceOutput)
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
SequenceData, SequenceOutput)
from .core.utils import create_dummy_prompt
@ -98,3 +99,38 @@ def test_sequence_group_stage():
assert seq_group.is_prefill() is True
seq_group.update_num_computed_tokens(1)
assert seq_group.is_prefill() is False
def test_sequence_intermediate_tensors_equal():
class AnotherIntermediateTensors(IntermediateTensors):
pass
intermediate_tensors = IntermediateTensors({})
another_intermediate_tensors = AnotherIntermediateTensors({})
assert intermediate_tensors != another_intermediate_tensors
empty_intermediate_tensors_1 = IntermediateTensors({})
empty_intermediate_tensors_2 = IntermediateTensors({})
assert empty_intermediate_tensors_1 == empty_intermediate_tensors_2
different_key_intermediate_tensors_1 = IntermediateTensors(
{"1": torch.zeros([2, 4], dtype=torch.int32)})
difference_key_intermediate_tensors_2 = IntermediateTensors(
{"2": torch.zeros([2, 4], dtype=torch.int32)})
assert (different_key_intermediate_tensors_1
!= difference_key_intermediate_tensors_2)
same_key_different_value_intermediate_tensors_1 = IntermediateTensors(
{"1": torch.zeros([2, 4], dtype=torch.int32)})
same_key_different_value_intermediate_tensors_2 = IntermediateTensors(
{"1": torch.zeros([2, 5], dtype=torch.int32)})
assert (same_key_different_value_intermediate_tensors_1
!= same_key_different_value_intermediate_tensors_2)
same_key_same_value_intermediate_tensors_1 = IntermediateTensors(
{"1": torch.zeros([2, 4], dtype=torch.int32)})
same_key_same_value_intermediate_tensors_2 = IntermediateTensors(
{"1": torch.zeros([2, 4], dtype=torch.int32)})
assert (same_key_same_value_intermediate_tensors_1 ==
same_key_same_value_intermediate_tensors_2)

View File

@ -1163,7 +1163,13 @@ class IntermediateTensors:
return len(self.tensors)
def __eq__(self, other: object):
return isinstance(other, self.__class__) and self
if not isinstance(other, self.__class__):
return False
if self.tensors.keys() != other.tensors.keys():
return False
return all(
torch.equal(self.tensors[k], other.tensors[k])
for k in self.tensors)
def __repr__(self) -> str:
return f"IntermediateTensors(tensors={self.tensors})"