diff --git a/tests/test_sequence.py b/tests/test_sequence.py index c734c8514a6da..1b019be9e56dc 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -2,10 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest +import torch from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.sequence import (CompletionSequenceGroupOutput, SequenceData, - SequenceOutput) +from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, + SequenceData, SequenceOutput) from .core.utils import create_dummy_prompt @@ -98,3 +99,38 @@ def test_sequence_group_stage(): assert seq_group.is_prefill() is True seq_group.update_num_computed_tokens(1) assert seq_group.is_prefill() is False + + +def test_sequence_intermediate_tensors_equal(): + + class AnotherIntermediateTensors(IntermediateTensors): + pass + + intermediate_tensors = IntermediateTensors({}) + another_intermediate_tensors = AnotherIntermediateTensors({}) + assert intermediate_tensors != another_intermediate_tensors + + empty_intermediate_tensors_1 = IntermediateTensors({}) + empty_intermediate_tensors_2 = IntermediateTensors({}) + assert empty_intermediate_tensors_1 == empty_intermediate_tensors_2 + + different_key_intermediate_tensors_1 = IntermediateTensors( + {"1": torch.zeros([2, 4], dtype=torch.int32)}) + difference_key_intermediate_tensors_2 = IntermediateTensors( + {"2": torch.zeros([2, 4], dtype=torch.int32)}) + assert (different_key_intermediate_tensors_1 + != difference_key_intermediate_tensors_2) + + same_key_different_value_intermediate_tensors_1 = IntermediateTensors( + {"1": torch.zeros([2, 4], dtype=torch.int32)}) + same_key_different_value_intermediate_tensors_2 = IntermediateTensors( + {"1": torch.zeros([2, 5], dtype=torch.int32)}) + assert (same_key_different_value_intermediate_tensors_1 + != same_key_different_value_intermediate_tensors_2) + + same_key_same_value_intermediate_tensors_1 = IntermediateTensors( + {"1": torch.zeros([2, 4], dtype=torch.int32)}) + same_key_same_value_intermediate_tensors_2 = IntermediateTensors( + {"1": torch.zeros([2, 4], dtype=torch.int32)}) + assert (same_key_same_value_intermediate_tensors_1 == + same_key_same_value_intermediate_tensors_2) diff --git a/vllm/sequence.py b/vllm/sequence.py index 347015c7ef3d1..43d5c8beef270 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1163,7 +1163,13 @@ class IntermediateTensors: return len(self.tensors) def __eq__(self, other: object): - return isinstance(other, self.__class__) and self + if not isinstance(other, self.__class__): + return False + if self.tensors.keys() != other.tensors.keys(): + return False + return all( + torch.equal(self.tensors[k], other.tensors[k]) + for k in self.tensors) def __repr__(self) -> str: return f"IntermediateTensors(tensors={self.tensors})"