[Core] Zero-copy asdict for InputMetadata (#3475)

This commit is contained in:
Antoni Baum 2024-03-18 15:56:40 -07:00 committed by GitHub
parent 9fdf3de346
commit 49eedea373
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 12 additions and 4 deletions

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass from dataclasses import dataclass, fields
from typing import Optional from typing import Optional, Any, Dict
import torch import torch
@ -31,3 +31,12 @@ class InputMetadata:
def __post_init__(self): def __post_init__(self):
# will not appear in the __repr__ and __init__ # will not appear in the __repr__ and __init__
self.attn_bias = None self.attn_bias = None
def asdict_zerocopy(self) -> Dict[str, Any]:
"""Similar to dataclasses.asdict, but avoids deepcopying."""
# Note that if we add dataclasses as fields, they will need
# similar handling.
return {
field.name: getattr(self, field.name)
for field in fields(self)
}

View File

@ -1,5 +1,4 @@
import contextlib import contextlib
import dataclasses
import time import time
from typing import Dict, List, Optional, Tuple, Set, Union from typing import Dict, List, Optional, Tuple, Set, Union
@ -527,7 +526,7 @@ class ModelRunner:
"lora_requests": lora_requests, "lora_requests": lora_requests,
"lora_mapping": lora_mapping, "lora_mapping": lora_mapping,
} }
metadata_dict.update(dataclasses.asdict(input_metadata)) metadata_dict.update(input_metadata.asdict_zerocopy())
broadcast_tensor_dict(metadata_dict, src=0) broadcast_tensor_dict(metadata_dict, src=0)
else: else:
metadata_dict = broadcast_tensor_dict(src=0) metadata_dict = broadcast_tensor_dict(src=0)