diff --git a/docs/source/deployment/frameworks/index.md b/docs/source/deployment/frameworks/index.md
index 683fa8217a80..d1c058eafa4c 100644
--- a/docs/source/deployment/frameworks/index.md
+++ b/docs/source/deployment/frameworks/index.md
@@ -11,6 +11,7 @@ helm
 lws
 modal
 open-webui
+retrieval_augmented_generation
 skypilot
 streamlit
 triton
diff --git a/docs/source/deployment/frameworks/retrieval_augmented_generation.md b/docs/source/deployment/frameworks/retrieval_augmented_generation.md
new file mode 100644
index 000000000000..f84451fafe91
--- /dev/null
+++ b/docs/source/deployment/frameworks/retrieval_augmented_generation.md
@@ -0,0 +1,84 @@
+(deployment-retrieval-augmented-generation)=
+
+# Retrieval-Augmented Generation
+
+[Retrieval-augmented generation (RAG)](https://en.wikipedia.org/wiki/Retrieval-augmented_generation) is a technique that enables generative artificial intelligence (Gen AI) models to retrieve and incorporate new information. It modifies interactions with a large language model (LLM) so that the model answers user queries with reference to a specified set of documents, using them to supplement the knowledge in its pre-existing training data. This allows LLMs to use domain-specific and/or up-to-date information. Use cases include giving a chatbot access to internal company data or generating responses grounded in authoritative sources.
+
+The following integrations are covered:
+- vLLM + [langchain](https://github.com/langchain-ai/langchain) + [milvus](https://github.com/milvus-io/milvus)
+- vLLM + [llamaindex](https://github.com/run-llama/llama_index) + [milvus](https://github.com/milvus-io/milvus)
+
+## vLLM + langchain
+
+### Prerequisites
+
+- Set up the vLLM and langchain environment:
+
+```console
+pip install -U vllm \
+    langchain_milvus langchain_openai \
+    langchain_community beautifulsoup4 \
+    langchain-text-splitters
+```
+
+### Deploy
+
+- Start the vLLM server with a supported embedding model, e.g.:
+
+```console
+# Start embedding service (port 8000)
+vllm serve ssmits/Qwen2-7B-Instruct-embed-base
+```
+
+- Start the vLLM server with a supported chat completion model, e.g.:
+
+```console
+# Start chat service (port 8001)
+vllm serve qwen/Qwen1.5-0.5B-Chat --port 8001
+```
+
+- Use the script: <gh-file:examples/online_serving/retrieval_augmented_generation_with_langchain.py>
+
+- Run the script:
+
+```console
+python retrieval_augmented_generation_with_langchain.py
+```
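+
+Optionally, before running the script, you can sanity-check that both vLLM services are reachable through their OpenAI-compatible APIs. The snippet below is a minimal sketch that assumes the default ports and model names shown above, and it relies on the `openai` Python package (pulled in as a dependency of `langchain_openai`):
+
+```python
+from openai import OpenAI
+
+# Both services expose the OpenAI-compatible API; vLLM accepts a dummy key.
+embed_client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")
+chat_client = OpenAI(api_key="EMPTY", base_url="http://localhost:8001/v1")
+
+# Embedding service: request one embedding and report its dimension.
+emb = embed_client.embeddings.create(
+    model="ssmits/Qwen2-7B-Instruct-embed-base",
+    input=["vLLM is a fast and easy-to-use library for LLM inference."],
+)
+print(f"Embedding dimension: {len(emb.data[0].embedding)}")
+
+# Chat service: request a short completion.
+chat = chat_client.chat.completions.create(
+    model="qwen/Qwen1.5-0.5B-Chat",
+    messages=[{"role": "user", "content": "Say hello in one sentence."}],
+)
+print(chat.choices[0].message.content)
+```
+
+The same check applies to the llamaindex setup below, which uses the same two services.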
+
+## vLLM + llamaindex
+
+### Prerequisites
+
+- Set up the vLLM and llamaindex environment:
+
+```console
+pip install vllm \
+    llama-index llama-index-readers-web \
+    llama-index-llms-openai-like \
+    llama-index-embeddings-openai-like \
+    llama-index-vector-stores-milvus
+```
+
+### Deploy
+
+- Start the vLLM server with a supported embedding model, e.g.:
+
+```console
+# Start embedding service (port 8000)
+vllm serve ssmits/Qwen2-7B-Instruct-embed-base
+```
+
+- Start the vLLM server with a supported chat completion model, e.g.:
+
+```console
+# Start chat service (port 8001)
+vllm serve qwen/Qwen1.5-0.5B-Chat --port 8001
+```
+
+- Use the script: <gh-file:examples/online_serving/retrieval_augmented_generation_with_llamaindex.py>
+
+- Run the script:
+
+```console
+python retrieval_augmented_generation_with_llamaindex.py
+```
diff --git a/examples/online_serving/retrieval_augmented_generation_with_langchain.py b/examples/online_serving/retrieval_augmented_generation_with_langchain.py
new file mode 100644
index 000000000000..73063065cb36
--- /dev/null
+++ b/examples/online_serving/retrieval_augmented_generation_with_langchain.py
@@ -0,0 +1,249 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Retrieval Augmented Generation (RAG) Implementation with LangChain
+==================================================================
+
+This script demonstrates a RAG implementation using LangChain, Milvus
+and vLLM. RAG enhances LLM responses by retrieving relevant context
+from a document collection.
+
+Features:
+- Web content loading and chunking
+- Vector storage with Milvus
+- Embedding generation with vLLM
+- Question answering with context
+
+Prerequisites:
+1. Install dependencies:
+    pip install -U vllm \
+        langchain_milvus langchain_openai \
+        langchain_community beautifulsoup4 \
+        langchain-text-splitters
+
+2. Start services:
+    # Start embedding service (port 8000)
+    vllm serve ssmits/Qwen2-7B-Instruct-embed-base
+
+    # Start chat service (port 8001)
+    vllm serve qwen/Qwen1.5-0.5B-Chat --port 8001
+
+Usage:
+    python retrieval_augmented_generation_with_langchain.py
+
+Notes:
+    - Ensure both vLLM services are running before executing
+    - Default ports: 8000 (embedding), 8001 (chat)
+    - First run may take time to download models
+"""
+
+import argparse
+from argparse import Namespace
+from typing import Any
+
+from langchain_community.document_loaders import WebBaseLoader
+from langchain_core.documents import Document
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import PromptTemplate
+from langchain_core.runnables import RunnablePassthrough
+from langchain_milvus import Milvus
+from langchain_openai import ChatOpenAI, OpenAIEmbeddings
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+
+
+def load_and_split_documents(config: dict[str, Any]):
+    """
+    Load documents from a web URL and split them into chunks
+    """
+    try:
+        loader = WebBaseLoader(web_paths=(config["url"], ))
+        docs = loader.load()
+
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=config["chunk_size"],
+            chunk_overlap=config["chunk_overlap"],
+        )
+        return text_splitter.split_documents(docs)
+    except Exception as e:
+        print(f"Error loading document from {config['url']}: {str(e)}")
+        raise
+
+
+def init_vectorstore(config: dict[str, Any], documents: list[Document]):
+    """
+    Initialize the Milvus vector store with the given documents
+    """
+    return Milvus.from_documents(
+        documents=documents,
+        embedding=OpenAIEmbeddings(
+            model=config["embedding_model"],
+            openai_api_key=config["vllm_api_key"],
+            openai_api_base=config["vllm_embedding_endpoint"],
+        ),
+        connection_args={"uri": config["uri"]},
+        drop_old=True,
+    )
+
+
+def init_llm(config: dict[str, Any]):
+    """
+    Initialize the chat LLM client
+    """
+    return ChatOpenAI(
+        model=config["chat_model"],
+        openai_api_key=config["vllm_api_key"],
+        openai_api_base=config["vllm_chat_endpoint"],
+    )
+
+
+def get_qa_prompt():
+    """
+    Get the question-answering prompt template
+    """
+    template = """You are an assistant for question-answering tasks.
+Use the following pieces of retrieved context to answer the question.
+If you don't know the answer, just say that you don't know.
+Use three sentences maximum and keep the answer concise. +Question: {question} +Context: {context} +Answer: +""" + return PromptTemplate.from_template(template) + + +def format_docs(docs: list[Document]): + """ + Format documents for prompt + """ + return "\n\n".join(doc.page_content for doc in docs) + + +def create_qa_chain(retriever: Any, llm: ChatOpenAI, prompt: PromptTemplate): + """ + Set up question answering chain + """ + return ({ + "context": retriever | format_docs, + "question": RunnablePassthrough(), + } + | prompt + | llm + | StrOutputParser()) + + +def get_parser() -> argparse.ArgumentParser: + """ + Parse command line arguments + """ + parser = argparse.ArgumentParser(description='RAG with vLLM and langchain') + + # Add command line arguments + parser.add_argument('--vllm-api-key', + default="EMPTY", + help='API key for vLLM compatible services') + parser.add_argument('--vllm-embedding-endpoint', + default="http://localhost:8000/v1", + help='Base URL for embedding service') + parser.add_argument('--vllm-chat-endpoint', + default="http://localhost:8001/v1", + help='Base URL for chat service') + parser.add_argument('--uri', + default="./milvus.db", + help='URI for Milvus database') + parser.add_argument( + '--url', + default=("https://docs.vllm.ai/en/latest/getting_started/" + "quickstart.html"), + help='URL of the document to process') + parser.add_argument('--embedding-model', + default="ssmits/Qwen2-7B-Instruct-embed-base", + help='Model name for embeddings') + parser.add_argument('--chat-model', + default="qwen/Qwen1.5-0.5B-Chat", + help='Model name for chat') + parser.add_argument('-i', + '--interactive', + action='store_true', + help='Enable interactive Q&A mode') + parser.add_argument('-k', + '--top-k', + type=int, + default=3, + help='Number of top results to retrieve') + parser.add_argument('-c', + '--chunk-size', + type=int, + default=1000, + help='Chunk size for document splitting') + parser.add_argument('-o', + '--chunk-overlap', + type=int, + default=200, + help='Chunk overlap for document splitting') + + return parser + + +def init_config(args: Namespace): + """ + Initialize configuration settings from command line arguments + """ + + return { + "vllm_api_key": args.vllm_api_key, + "vllm_embedding_endpoint": args.vllm_embedding_endpoint, + "vllm_chat_endpoint": args.vllm_chat_endpoint, + "uri": args.uri, + "embedding_model": args.embedding_model, + "chat_model": args.chat_model, + "url": args.url, + "chunk_size": args.chunk_size, + "chunk_overlap": args.chunk_overlap, + "top_k": args.top_k + } + + +def main(): + # Parse command line arguments + args = get_parser().parse_args() + + # Initialize configuration + config = init_config(args) + + # Load and split documents + documents = load_and_split_documents(config) + + # Initialize vector store and retriever + vectorstore = init_vectorstore(config, documents) + retriever = vectorstore.as_retriever(search_kwargs={"k": config["top_k"]}) + + # Initialize llm and prompt + llm = init_llm(config) + prompt = get_qa_prompt() + + # Set up QA chain + qa_chain = create_qa_chain(retriever, llm, prompt) + + # Interactive mode + if args.interactive: + print("\nWelcome to Interactive Q&A System!") + print("Enter 'q' or 'quit' to exit.") + + while True: + question = input("\nPlease enter your question: ") + if question.lower() in ['q', 'quit']: + print("\nThank you for using! 
Goodbye!") + break + + output = qa_chain.invoke(question) + print(output) + else: + # Default single question mode + question = ("How to install vLLM?") + output = qa_chain.invoke(question) + print("-" * 50) + print(output) + print("-" * 50) + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/retrieval_augmented_generation_with_llamaindex.py b/examples/online_serving/retrieval_augmented_generation_with_llamaindex.py new file mode 100644 index 000000000000..a8f76dfe4c69 --- /dev/null +++ b/examples/online_serving/retrieval_augmented_generation_with_llamaindex.py @@ -0,0 +1,217 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +RAG (Retrieval Augmented Generation) Implementation with LlamaIndex +================================================================ + +This script demonstrates a RAG system using: +- LlamaIndex: For document indexing and retrieval +- Milvus: As vector store backend +- vLLM: For embedding and text generation + +Features: +1. Document Loading & Processing +2. Embedding & Storage +3. Query Processing + +Requirements: +1. Install dependencies: +pip install llama-index llama-index-readers-web \ + llama-index-llms-openai-like \ + llama-index-embeddings-openai-like \ + llama-index-vector-stores-milvus \ + +2. Start services: + # Start embedding service (port 8000) + vllm serve ssmits/Qwen2-7B-Instruct-embed-base + + # Start chat service (port 8001) + vllm serve qwen/Qwen1.5-0.5B-Chat --port 8001 + +Usage: + python retrieval_augmented_generation_with_llamaindex.py + +Notes: + - Ensure both vLLM services are running before executing + - Default ports: 8000 (embedding), 8001 (chat) + - First run may take time to download models +""" +import argparse +from argparse import Namespace +from typing import Any + +from llama_index.core import Settings, StorageContext, VectorStoreIndex +from llama_index.core.node_parser import SentenceSplitter +from llama_index.embeddings.openai_like import OpenAILikeEmbedding +from llama_index.llms.openai_like import OpenAILike +from llama_index.readers.web import SimpleWebPageReader +from llama_index.vector_stores.milvus import MilvusVectorStore + + +def init_config(args: Namespace): + """Initialize configuration with command line arguments""" + return { + "url": args.url, + "embedding_model": args.embedding_model, + "chat_model": args.chat_model, + "vllm_api_key": args.vllm_api_key, + "embedding_endpoint": args.embedding_endpoint, + "chat_endpoint": args.chat_endpoint, + "db_path": args.db_path, + "chunk_size": args.chunk_size, + "chunk_overlap": args.chunk_overlap, + "top_k": args.top_k + } + + +def load_documents(url: str) -> list: + """Load and process web documents""" + return SimpleWebPageReader(html_to_text=True).load_data([url]) + + +def setup_models(config: dict[str, Any]): + """Configure embedding and chat models""" + Settings.embed_model = OpenAILikeEmbedding( + api_base=config["embedding_endpoint"], + api_key=config["vllm_api_key"], + model_name=config["embedding_model"], + ) + + Settings.llm = OpenAILike( + model=config["chat_model"], + api_key=config["vllm_api_key"], + api_base=config["chat_endpoint"], + context_window=128000, + is_chat_model=True, + is_function_calling_model=False, + ) + + Settings.transformations = [ + SentenceSplitter( + chunk_size=config["chunk_size"], + chunk_overlap=config["chunk_overlap"], + ) + ] + + +def setup_vector_store(db_path: str) -> MilvusVectorStore: + """Initialize vector store""" + sample_emb = Settings.embed_model.get_text_embedding("test") + print(f"Embedding dimension: 
{len(sample_emb)}") + return MilvusVectorStore(uri=db_path, dim=len(sample_emb), overwrite=True) + + +def create_index(documents: list, vector_store: MilvusVectorStore): + """Create document index""" + storage_context = StorageContext.from_defaults(vector_store=vector_store) + return VectorStoreIndex.from_documents( + documents, + storage_context=storage_context, + ) + + +def query_document(index: VectorStoreIndex, question: str, top_k: int): + """Query document with given question""" + query_engine = index.as_query_engine(similarity_top_k=top_k) + return query_engine.query(question) + + +def get_parser() -> argparse.ArgumentParser: + """Parse command line arguments""" + parser = argparse.ArgumentParser( + description='RAG with vLLM and LlamaIndex') + + # Add command line arguments + parser.add_argument( + '--url', + default=("https://docs.vllm.ai/en/latest/getting_started/" + "quickstart.html"), + help='URL of the document to process') + parser.add_argument('--embedding-model', + default="ssmits/Qwen2-7B-Instruct-embed-base", + help='Model name for embeddings') + parser.add_argument('--chat-model', + default="qwen/Qwen1.5-0.5B-Chat", + help='Model name for chat') + parser.add_argument('--vllm-api-key', + default="EMPTY", + help='API key for vLLM compatible services') + parser.add_argument('--embedding-endpoint', + default="http://localhost:8000/v1", + help='Base URL for embedding service') + parser.add_argument('--chat-endpoint', + default="http://localhost:8001/v1", + help='Base URL for chat service') + parser.add_argument('--db-path', + default="./milvus_demo.db", + help='Path to Milvus database') + parser.add_argument('-i', + '--interactive', + action='store_true', + help='Enable interactive Q&A mode') + parser.add_argument('-c', + '--chunk-size', + type=int, + default=1000, + help='Chunk size for document splitting') + parser.add_argument('-o', + '--chunk-overlap', + type=int, + default=200, + help='Chunk overlap for document splitting') + parser.add_argument('-k', + '--top-k', + type=int, + default=3, + help='Number of top results to retrieve') + + return parser + + +def main(): + # Parse command line arguments + args = get_parser().parse_args() + + # Initialize configuration + config = init_config(args) + + # Load documents + documents = load_documents(config["url"]) + + # Setup models + setup_models(config) + + # Setup vector store + vector_store = setup_vector_store(config["db_path"]) + + # Create index + index = create_index(documents, vector_store) + + if args.interactive: + print("\nEntering interactive mode. Type 'quit' to exit.") + while True: + # Get user question + question = input("\nEnter your question: ") + + # Check for exit command + if question.lower() in ['quit', 'exit', 'q']: + print("Exiting interactive mode...") + break + + # Get and print response + print("\n" + "-" * 50) + print("Response:\n") + response = query_document(index, question, config["top_k"]) + print(response) + print("-" * 50) + else: + # Single query mode + question = "How to install vLLM?" + response = query_document(index, question, config["top_k"]) + print("-" * 50) + print("Response:\n") + print(response) + print("-" * 50) + + +if __name__ == "__main__": + main()
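+
+# Optional follow-up (a sketch, not executed by this script): the Response
+# object returned by `query_engine.query()` also exposes the retrieved chunks
+# via `response.source_nodes`, so an answer can be printed together with its
+# sources. This assumes the default llama-index Response type:
+#
+#     response = query_document(index, "How to install vLLM?", 3)
+#     for node_with_score in response.source_nodes:
+#         print(f"score={node_with_score.score}")
+#         print(node_with_score.node.get_content()[:200])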