rattus128 e42682b24e
Reduce Peak WAN inference VRAM usage (#9898)
* flux: Do the xq and xk ropes one at a time

This was doing independent, interleaved tensor math on the q and k
tensors, holding more than the minimum set of intermediates in VRAM.
On a bad day, it would OOM VRAM on the xk intermediates.

Do everything for q and then everything for k, so torch can garbage
collect all of q's intermediates before k allocates its own.

This reduces peak VRAM usage for some WAN2.2 inferences (at least).
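
A minimal sketch of the restructuring, assuming a flux-style rope helper
over PyTorch tensors (names and shapes here are illustrative, not the
exact ComfyUI code):

    import torch

    def apply_rope_single(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        # Rope ONE tensor at a time, so its float()/reshape temporaries
        # become collectable before the next tensor allocates its own.
        x_ = x.float().reshape(*x.shape[:-1], -1, 1, 2)
        out = freqs_cis[..., 0] * x_[..., 0] + freqs_cis[..., 1] * x_[..., 1]
        return out.reshape(*x.shape).type_as(x)

    # Toy shapes: (batch, heads, seq, head_dim) with a (seq, head_dim/2, 2, 2)
    # rotation table broadcast over batch and heads.
    xq = torch.randn(1, 8, 16, 64)
    xk = torch.randn(1, 8, 16, 64)
    freqs_cis = torch.randn(16, 32, 2, 2)

    # Old pattern: a fused apply_rope(xq, xk, freqs_cis) interleaved q and k
    # math, keeping both tensors' intermediates live at once.
    # New pattern: finish q entirely, then k, lowering the VRAM peak.
    xq = apply_rope_single(xq, freqs_cis)
    xk = apply_rope_single(xk, freqs_cis)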

* wan: Optimize qkv intermediates on attention

As commented in the code. The former logic computed independent pieces
of QKV in parallel, which held more inference intermediates in VRAM and
spiked peak usage. Fully roping Q, and garbage collecting its
intermediates before touching K, reduces peak inference VRAM usage.
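
The same sequencing idea applied to the attention qkv path, again as a
hedged sketch (the projection weights and helper below are hypothetical;
apply_rope_single is the helper from the sketch above):

    # Complete each projection + rope before starting the next, rather than
    # materializing the q, k and v intermediates side by side. q's temporaries
    # can then be reused by the allocator before k asks for memory.
    def qkv_rope_sequential(x, wq, wk, wv, freqs_cis):
        q = apply_rope_single(x @ wq, freqs_cis)  # q done; its scratch is freeable
        k = apply_rope_single(x @ wk, freqs_cis)  # k only now allocates
        v = x @ wv                                # v is not roped
        return q, k, v

    b, s, dim = 1, 16, 64
    x = torch.randn(b, s, dim)
    wq, wk, wv = (torch.randn(dim, dim) for _ in range(3))
    freqs = torch.randn(s, dim // 2, 2, 2)
    q, k, v = qkv_rope_sequential(x, wq, wk, wv, freqs)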
2025-09-16 19:21:14 -04:00