# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

from vllm.sequence import IntermediateTensors


def test_sequence_intermediate_tensors_equal():
    """Exercise ``==`` / ``!=`` on ``IntermediateTensors``.

    Five cases are asserted, in order: a base instance versus a subclass
    instance (unequal), two empty instances (equal), differing keys
    (unequal), same key with differently shaped tensors (unequal), and
    same key with identical tensors (equal).
    """

    class AnotherIntermediateTensors(IntermediateTensors):
        pass

    def int32_zeros(rows, cols):
        # Every non-empty case below uses an all-zero int32 matrix.
        return torch.zeros([rows, cols], dtype=torch.int32)

    # Base class vs. subclass, both wrapping an empty dict.
    base = IntermediateTensors({})
    derived = AnotherIntermediateTensors({})
    assert base != derived

    # Two independently built empty containers.
    empty_a = IntermediateTensors({})
    empty_b = IntermediateTensors({})
    assert empty_a == empty_b

    # Identical tensors stored under different keys.
    keyed_one = IntermediateTensors({"1": int32_zeros(2, 4)})
    keyed_two = IntermediateTensors({"2": int32_zeros(2, 4)})
    assert keyed_one != keyed_two

    # Same key, tensors of different shape.
    narrow = IntermediateTensors({"1": int32_zeros(2, 4)})
    wide = IntermediateTensors({"1": int32_zeros(2, 5)})
    assert narrow != wide

    # Same key, tensors with matching shape, dtype, and contents.
    left = IntermediateTensors({"1": int32_zeros(2, 4)})
    right = IntermediateTensors({"1": int32_zeros(2, 4)})
    assert left == right