Fix IntermediateTensors initialization and add type hints (#28743)

Signed-off-by: Mohammad Othman <Mo@MohammadOthman.com>
Co-authored-by: Mohammad Othman <Mo@MohammadOthman.com>
This commit is contained in:
Mohammad Othman 2025-11-15 06:31:36 +02:00 committed by GitHub
parent ac86bff8cb
commit 363aaeef0f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -60,12 +60,17 @@ class IntermediateTensors:
tensors: dict[str, torch.Tensor]
kv_connector_output: KVConnectorOutput | None
def __init__(self, tensors):
def __init__(
self,
tensors: dict[str, torch.Tensor],
kv_connector_output: KVConnectorOutput | None = None,
) -> None:
# manually define this function, so that
# Dynamo knows `IntermediateTensors()` comes from this file.
# Otherwise, dataclass will generate this function by evaluating
# a string, and we will lose the information about the source file.
self.tensors = tensors
self.kv_connector_output = kv_connector_output
def __getitem__(self, key: str | slice):
if isinstance(key, str):