ops: Put weight cast on the offload stream (#10697)

This needs to be on the offload stream. This reproduced a black screen
with low resolution images on a slow bus when using FP8.
This commit is contained in:
rattus 2025-11-10 13:52:11 +10:00 committed by GitHub
parent dea899f221
commit c350009236
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -110,9 +110,9 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
for f in s.bias_function:
bias = f(bias)
weight = weight.to(dtype=dtype)
if weight_has_function:
if weight_has_function or weight.dtype != dtype:
with wf_context:
weight = weight.to(dtype=dtype)
for f in s.weight_function:
weight = f(weight)