[fix] lora benchmarks pass no_lora_flag_cpu (#23774)

Signed-off-by: Dylan Maloy <34420038+dolpm@users.noreply.github.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
dolpm 2025-09-17 06:22:25 -07:00 committed by GitHub
parent bfe9380161
commit 1b962e2457
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -464,7 +464,11 @@ class BenchmarkTensors:
for field_name in LoRAKernelMeta.__dataclass_fields__: for field_name in LoRAKernelMeta.__dataclass_fields__:
field = getattr(self.lora_kernel_meta, field_name) field = getattr(self.lora_kernel_meta, field_name)
assert isinstance(field, torch.Tensor) assert isinstance(field, torch.Tensor)
setattr(self.lora_kernel_meta, field_name, to_device(field)) setattr(
self.lora_kernel_meta,
field_name,
to_device(field) if field_name != "no_lora_flag_cpu" else field,
)
def metadata(self) -> tuple[int, int, int]: def metadata(self) -> tuple[int, int, int]:
""" """
@ -512,6 +516,7 @@ class BenchmarkTensors:
"lora_token_start_loc": self.lora_kernel_meta.lora_token_start_loc, "lora_token_start_loc": self.lora_kernel_meta.lora_token_start_loc,
"lora_ids": self.lora_kernel_meta.active_lora_ids, "lora_ids": self.lora_kernel_meta.active_lora_ids,
"scaling": 1.0, "scaling": 1.0,
"no_lora_flag_cpu": self.lora_kernel_meta.no_lora_flag_cpu,
} }
def as_lora_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]: def as_lora_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]:
@ -552,6 +557,7 @@ class BenchmarkTensors:
"lora_ids": self.lora_kernel_meta.active_lora_ids, "lora_ids": self.lora_kernel_meta.active_lora_ids,
"offset_start": 0, "offset_start": 0,
"add_inputs": add_inputs, "add_inputs": add_inputs,
"no_lora_flag_cpu": self.lora_kernel_meta.no_lora_flag_cpu,
} }
def bench_fn_kwargs( def bench_fn_kwargs(