[doc] explain common errors around torch.compile (#12340)

Signed-off-by: youkaichao <youkaichao@gmail.com>
youkaichao 2025-01-23 14:56:02 +08:00 committed by GitHub
parent f0ef37233e
commit 511627445e

@@ -197,6 +197,27 @@ if __name__ == '__main__':
llm = vllm.LLM(...)
```
## `torch.compile` Error
vLLM heavily depends on `torch.compile` to optimize the model for better performance, which introduces a dependency on `torch.compile` and the `triton` library. By default, we use `torch.compile` to [optimize some functions](https://github.com/vllm-project/vllm/pull/10406) in the model. Before running vLLM, you can check whether `torch.compile` is working as expected by running the following script:
```python
import torch

@torch.compile
def f(x):
    # a simple function to test torch.compile
    x = x + 1
    x = x * 2
    x = x.sin()
    return x

x = torch.randn(4, 4).cuda()
print(f(x))
```
If it raises errors from the `torch/_inductor` directory, it usually means you have a custom `triton` library that is not compatible with the version of PyTorch you are using. See [this issue](https://github.com/vllm-project/vllm/issues/12219) for an example.
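As a quick sanity check, you can compare the versions and install locations of `torch` and `triton` to spot a custom or mismatched `triton` build. This is a minimal diagnostic sketch, not an exhaustive check:
```python
import torch
import triton

# PyTorch pins a compatible triton version; a custom triton build installed
# in a different location can break torch._inductor compilation.
print("torch  :", torch.__version__, "from", torch.__file__)
print("triton :", triton.__version__, "from", triton.__file__)
```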
## Known Issues
- In `v0.5.2`, `v0.5.3`, and `v0.5.3.post1`, there is a bug caused by [zmq](https://github.com/zeromq/pyzmq/issues/2000), which can occasionally cause vLLM to hang depending on the machine configuration. The solution is to upgrade to the latest version of `vllm` to include the [fix](gh-pr:6759).
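To confirm whether your installation is affected, you can check the installed version before upgrading. A minimal sketch; the affected-version set below is taken from the list above:
```python
import vllm

# versions affected by the zmq hang described above
affected = {"0.5.2", "0.5.3", "0.5.3.post1"}
print("vllm version:", vllm.__version__)
if vllm.__version__ in affected:
    print("This version is affected; upgrade vllm to pick up the fix.")
```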