Add option to restrict media domains (#25783)

Signed-off-by: Chenheli Hua <huachenheli@outlook.com>
Signed-off-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Chenheli Hua <huachenheli@outlook.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
This commit is contained in:
Russell Bryant 2025-09-26 21:23:52 -04:00 committed by simon-mo
parent 04c2b26972
commit 32335c8b34
11 changed files with 80 additions and 1 deletions

View File

@ -6,6 +6,10 @@ This page teaches you how to pass multi-modal inputs to [multi-modal models][sup
We are actively iterating on multi-modal support. See [this RFC](gh-issue:4194) for upcoming changes, We are actively iterating on multi-modal support. See [this RFC](gh-issue:4194) for upcoming changes,
and [open an issue on GitHub](https://github.com/vllm-project/vllm/issues/new/choose) if you have any feedback or feature requests. and [open an issue on GitHub](https://github.com/vllm-project/vllm/issues/new/choose) if you have any feedback or feature requests.
!!! tip
When serving multi-modal models, consider setting `--allowed-media-domains` to restrict domain that vLLM can access to prevent it from accessing arbitrary endpoints that can potentially be vulnerable to Server-Side Request Forgery (SSRF) attacks. You can provide a list of domains for this arg. For example: `--allowed-media-domains upload.wikimedia.org github.com www.bogotobogo.com`
This restriction is especially important if you run vLLM in a containerized environment where the vLLM pods may have unrestricted access to internal networks.
## Offline Inference ## Offline Inference
To input multi-modal data, follow this schema in [vllm.inputs.PromptType][]: To input multi-modal data, follow this schema in [vllm.inputs.PromptType][]:

View File

@ -60,6 +60,12 @@ Key points from the PyTorch security guide:
- Implement proper authentication and authorization for management interfaces - Implement proper authentication and authorization for management interfaces
- Follow the principle of least privilege for all system components - Follow the principle of least privilege for all system components
### 4. **Restrict Domains Access for Media URLs:**
Restrict domains that vLLM can access for media URLs by setting
`--allowed-media-domains` to prevent Server-Side Request Forgery (SSRF) attacks.
(e.g. `--allowed-media-domains upload.wikimedia.org github.com www.bogotobogo.com`)
## Security and Firewalls: Protecting Exposed vLLM Systems ## Security and Firewalls: Protecting Exposed vLLM Systems
While vLLM is designed to allow unsafe network services to be isolated to While vLLM is designed to allow unsafe network services to be isolated to

View File

@ -45,6 +45,7 @@ class MockModelConfig:
logits_processor_pattern: Optional[str] = None logits_processor_pattern: Optional[str] = None
diff_sampling_param: Optional[dict] = None diff_sampling_param: Optional[dict] = None
allowed_local_media_path: str = "" allowed_local_media_path: str = ""
allowed_media_domains: Optional[list[str]] = None
encoder_config = None encoder_config = None
generation_config: str = "auto" generation_config: str = "auto"
skip_tokenizer_init: bool = False skip_tokenizer_init: bool = False

View File

@ -240,6 +240,7 @@ class MockModelConfig:
logits_processor_pattern = None logits_processor_pattern = None
diff_sampling_param: Optional[dict] = None diff_sampling_param: Optional[dict] = None
allowed_local_media_path: str = "" allowed_local_media_path: str = ""
allowed_media_domains: Optional[list[str]] = None
encoder_config = None encoder_config = None
generation_config: str = "auto" generation_config: str = "auto"
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict) media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)

View File

@ -66,7 +66,12 @@ async def test_fetch_image_http(image_url: str):
@pytest.mark.parametrize("suffix", get_supported_suffixes()) @pytest.mark.parametrize("suffix", get_supported_suffixes())
async def test_fetch_image_base64(url_images: dict[str, Image.Image], async def test_fetch_image_base64(url_images: dict[str, Image.Image],
raw_image_url: str, suffix: str): raw_image_url: str, suffix: str):
connector = MediaConnector() connector = MediaConnector(
# Domain restriction should not apply to data URLs.
allowed_media_domains=[
"www.bogotobogo.com",
"github.com",
])
url_image = url_images[raw_image_url] url_image = url_images[raw_image_url]
try: try:
@ -387,3 +392,29 @@ def test_argsort_mm_positions(case):
modality_idxs = argsort_mm_positions(mm_positions) modality_idxs = argsort_mm_positions(mm_positions)
assert modality_idxs == expected_modality_idxs assert modality_idxs == expected_modality_idxs
@pytest.mark.asyncio
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
@pytest.mark.parametrize("num_frames", [-1, 32, 1800])
async def test_allowed_media_domains(video_url: str, num_frames: int):
connector = MediaConnector(
media_io_kwargs={"video": {
"num_frames": num_frames,
}},
allowed_media_domains=[
"www.bogotobogo.com",
"github.com",
])
video_sync, metadata_sync = connector.fetch_video(video_url)
video_async, metadata_async = await connector.fetch_video_async(video_url)
assert np.array_equal(video_sync, video_async)
assert metadata_sync == metadata_async
disallowed_url = "https://upload.wikimedia.org/wikipedia/commons/4/47/PNG_transparency_demonstration_1.png"
with pytest.raises(ValueError):
_, _ = connector.fetch_video(disallowed_url)
with pytest.raises(ValueError):
_, _ = await connector.fetch_video_async(disallowed_url)

View File

@ -137,6 +137,9 @@ class ModelConfig:
"""Allowing API requests to read local images or videos from directories """Allowing API requests to read local images or videos from directories
specified by the server file system. This is a security risk. Should only specified by the server file system. This is a security risk. Should only
be enabled in trusted environments.""" be enabled in trusted environments."""
allowed_media_domains: Optional[list[str]] = None
"""If set, only media URLs that belong to this domain can be used for
multi-modal inputs. """
revision: Optional[str] = None revision: Optional[str] = None
"""The specific model version to use. It can be a branch name, a tag name, """The specific model version to use. It can be a branch name, a tag name,
or a commit id. If unspecified, will use the default version.""" or a commit id. If unspecified, will use the default version."""

View File

@ -281,6 +281,8 @@ class SpeculativeConfig:
trust_remote_code, trust_remote_code,
allowed_local_media_path=self.target_model_config. allowed_local_media_path=self.target_model_config.
allowed_local_media_path, allowed_local_media_path,
allowed_media_domains=self.target_model_config.
allowed_media_domains,
dtype=self.target_model_config.dtype, dtype=self.target_model_config.dtype,
seed=self.target_model_config.seed, seed=self.target_model_config.seed,
revision=self.revision, revision=self.revision,

View File

@ -297,6 +297,8 @@ class EngineArgs:
tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode
trust_remote_code: bool = ModelConfig.trust_remote_code trust_remote_code: bool = ModelConfig.trust_remote_code
allowed_local_media_path: str = ModelConfig.allowed_local_media_path allowed_local_media_path: str = ModelConfig.allowed_local_media_path
allowed_media_domains: Optional[
list[str]] = ModelConfig.allowed_media_domains
download_dir: Optional[str] = LoadConfig.download_dir download_dir: Optional[str] = LoadConfig.download_dir
safetensors_load_strategy: str = LoadConfig.safetensors_load_strategy safetensors_load_strategy: str = LoadConfig.safetensors_load_strategy
load_format: Union[str, LoadFormats] = LoadConfig.load_format load_format: Union[str, LoadFormats] = LoadConfig.load_format
@ -531,6 +533,8 @@ class EngineArgs:
**model_kwargs["hf_config_path"]) **model_kwargs["hf_config_path"])
model_group.add_argument("--allowed-local-media-path", model_group.add_argument("--allowed-local-media-path",
**model_kwargs["allowed_local_media_path"]) **model_kwargs["allowed_local_media_path"])
model_group.add_argument("--allowed-media-domains",
**model_kwargs["allowed_media_domains"])
model_group.add_argument("--revision", **model_kwargs["revision"]) model_group.add_argument("--revision", **model_kwargs["revision"])
model_group.add_argument("--code-revision", model_group.add_argument("--code-revision",
**model_kwargs["code_revision"]) **model_kwargs["code_revision"])
@ -997,6 +1001,7 @@ class EngineArgs:
tokenizer_mode=self.tokenizer_mode, tokenizer_mode=self.tokenizer_mode,
trust_remote_code=self.trust_remote_code, trust_remote_code=self.trust_remote_code,
allowed_local_media_path=self.allowed_local_media_path, allowed_local_media_path=self.allowed_local_media_path,
allowed_media_domains=self.allowed_media_domains,
dtype=self.dtype, dtype=self.dtype,
seed=self.seed, seed=self.seed,
revision=self.revision, revision=self.revision,

View File

@ -637,6 +637,10 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
def allowed_local_media_path(self): def allowed_local_media_path(self):
return self._model_config.allowed_local_media_path return self._model_config.allowed_local_media_path
@property
def allowed_media_domains(self):
return self._model_config.allowed_media_domains
@property @property
def mm_registry(self): def mm_registry(self):
return MULTIMODAL_REGISTRY return MULTIMODAL_REGISTRY
@ -837,6 +841,7 @@ class MultiModalContentParser(BaseMultiModalContentParser):
self._connector = MediaConnector( self._connector = MediaConnector(
media_io_kwargs=media_io_kwargs, media_io_kwargs=media_io_kwargs,
allowed_local_media_path=tracker.allowed_local_media_path, allowed_local_media_path=tracker.allowed_local_media_path,
allowed_media_domains=tracker.allowed_media_domains,
) )
def parse_image( def parse_image(
@ -921,6 +926,7 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
self._connector = MediaConnector( self._connector = MediaConnector(
media_io_kwargs=media_io_kwargs, media_io_kwargs=media_io_kwargs,
allowed_local_media_path=tracker.allowed_local_media_path, allowed_local_media_path=tracker.allowed_local_media_path,
allowed_media_domains=tracker.allowed_media_domains,
) )
def parse_image( def parse_image(

View File

@ -86,6 +86,8 @@ class LLM:
or videos from directories specified by the server file system. or videos from directories specified by the server file system.
This is a security risk. Should only be enabled in trusted This is a security risk. Should only be enabled in trusted
environments. environments.
allowed_media_domains: If set, only media URLs that belong to this
domain can be used for multi-modal inputs.
tensor_parallel_size: The number of GPUs to use for distributed tensor_parallel_size: The number of GPUs to use for distributed
execution with tensor parallelism. execution with tensor parallelism.
dtype: The data type for the model weights and activations. Currently, dtype: The data type for the model weights and activations. Currently,
@ -169,6 +171,7 @@ class LLM:
skip_tokenizer_init: bool = False, skip_tokenizer_init: bool = False,
trust_remote_code: bool = False, trust_remote_code: bool = False,
allowed_local_media_path: str = "", allowed_local_media_path: str = "",
allowed_media_domains: Optional[list[str]] = None,
tensor_parallel_size: int = 1, tensor_parallel_size: int = 1,
dtype: ModelDType = "auto", dtype: ModelDType = "auto",
quantization: Optional[QuantizationMethods] = None, quantization: Optional[QuantizationMethods] = None,
@ -264,6 +267,7 @@ class LLM:
skip_tokenizer_init=skip_tokenizer_init, skip_tokenizer_init=skip_tokenizer_init,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
allowed_local_media_path=allowed_local_media_path, allowed_local_media_path=allowed_local_media_path,
allowed_media_domains=allowed_media_domains,
tensor_parallel_size=tensor_parallel_size, tensor_parallel_size=tensor_parallel_size,
dtype=dtype, dtype=dtype,
quantization=quantization, quantization=quantization,

View File

@ -50,6 +50,7 @@ class MediaConnector:
connection: HTTPConnection = global_http_connection, connection: HTTPConnection = global_http_connection,
*, *,
allowed_local_media_path: str = "", allowed_local_media_path: str = "",
allowed_media_domains: Optional[list[str]] = None,
) -> None: ) -> None:
""" """
Args: Args:
@ -82,6 +83,9 @@ class MediaConnector:
allowed_local_media_path_ = None allowed_local_media_path_ = None
self.allowed_local_media_path = allowed_local_media_path_ self.allowed_local_media_path = allowed_local_media_path_
if allowed_media_domains is None:
allowed_media_domains = []
self.allowed_media_domains = allowed_media_domains
def _load_data_url( def _load_data_url(
self, self,
@ -115,6 +119,14 @@ class MediaConnector:
return media_io.load_file(filepath) return media_io.load_file(filepath)
def _assert_url_in_allowed_media_domains(self, url_spec) -> None:
if self.allowed_media_domains and url_spec.hostname not in \
self.allowed_media_domains:
raise ValueError(
f"The URL must be from one of the allowed domains: "
f"{self.allowed_media_domains}. Input URL domain: "
f"{url_spec.hostname}")
def load_from_url( def load_from_url(
self, self,
url: str, url: str,
@ -125,6 +137,8 @@ class MediaConnector:
url_spec = urlparse(url) url_spec = urlparse(url)
if url_spec.scheme.startswith("http"): if url_spec.scheme.startswith("http"):
self._assert_url_in_allowed_media_domains(url_spec)
connection = self.connection connection = self.connection
data = connection.get_bytes(url, timeout=fetch_timeout) data = connection.get_bytes(url, timeout=fetch_timeout)
@ -150,6 +164,8 @@ class MediaConnector:
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
if url_spec.scheme.startswith("http"): if url_spec.scheme.startswith("http"):
self._assert_url_in_allowed_media_domains(url_spec)
connection = self.connection connection = self.connection
data = await connection.async_get_bytes(url, timeout=fetch_timeout) data = await connection.async_get_bytes(url, timeout=fetch_timeout)
future = loop.run_in_executor(global_thread_pool, future = loop.run_in_executor(global_thread_pool,