mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-28 23:57:13 +08:00
[Bugfix] fix IntermediateTensors equal method (#23027)
Signed-off-by: Andy Xie <andy.xning@gmail.com>
This commit is contained in:
parent
27e8d1ea3e
commit
5a30bd10d8
@ -2,10 +2,11 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.sequence import (CompletionSequenceGroupOutput, SequenceData,
|
||||
SequenceOutput)
|
||||
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
|
||||
SequenceData, SequenceOutput)
|
||||
|
||||
from .core.utils import create_dummy_prompt
|
||||
|
||||
@ -98,3 +99,38 @@ def test_sequence_group_stage():
|
||||
assert seq_group.is_prefill() is True
|
||||
seq_group.update_num_computed_tokens(1)
|
||||
assert seq_group.is_prefill() is False
|
||||
|
||||
|
||||
def test_sequence_intermediate_tensors_equal():
|
||||
|
||||
class AnotherIntermediateTensors(IntermediateTensors):
|
||||
pass
|
||||
|
||||
intermediate_tensors = IntermediateTensors({})
|
||||
another_intermediate_tensors = AnotherIntermediateTensors({})
|
||||
assert intermediate_tensors != another_intermediate_tensors
|
||||
|
||||
empty_intermediate_tensors_1 = IntermediateTensors({})
|
||||
empty_intermediate_tensors_2 = IntermediateTensors({})
|
||||
assert empty_intermediate_tensors_1 == empty_intermediate_tensors_2
|
||||
|
||||
different_key_intermediate_tensors_1 = IntermediateTensors(
|
||||
{"1": torch.zeros([2, 4], dtype=torch.int32)})
|
||||
difference_key_intermediate_tensors_2 = IntermediateTensors(
|
||||
{"2": torch.zeros([2, 4], dtype=torch.int32)})
|
||||
assert (different_key_intermediate_tensors_1
|
||||
!= difference_key_intermediate_tensors_2)
|
||||
|
||||
same_key_different_value_intermediate_tensors_1 = IntermediateTensors(
|
||||
{"1": torch.zeros([2, 4], dtype=torch.int32)})
|
||||
same_key_different_value_intermediate_tensors_2 = IntermediateTensors(
|
||||
{"1": torch.zeros([2, 5], dtype=torch.int32)})
|
||||
assert (same_key_different_value_intermediate_tensors_1
|
||||
!= same_key_different_value_intermediate_tensors_2)
|
||||
|
||||
same_key_same_value_intermediate_tensors_1 = IntermediateTensors(
|
||||
{"1": torch.zeros([2, 4], dtype=torch.int32)})
|
||||
same_key_same_value_intermediate_tensors_2 = IntermediateTensors(
|
||||
{"1": torch.zeros([2, 4], dtype=torch.int32)})
|
||||
assert (same_key_same_value_intermediate_tensors_1 ==
|
||||
same_key_same_value_intermediate_tensors_2)
|
||||
|
||||
@ -1163,7 +1163,13 @@ class IntermediateTensors:
|
||||
return len(self.tensors)
|
||||
|
||||
def __eq__(self, other: object):
|
||||
return isinstance(other, self.__class__) and self
|
||||
if not isinstance(other, self.__class__):
|
||||
return False
|
||||
if self.tensors.keys() != other.tensors.keys():
|
||||
return False
|
||||
return all(
|
||||
torch.equal(self.tensors[k], other.tensors[k])
|
||||
for k in self.tensors)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"IntermediateTensors(tensors={self.tensors})"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user