# ---
# jupyter:
#   kernelspec:
#     display_name: Python 3
#     language: python
#     name: python3
# ---

# %% [markdown]
# # State-of-the-Art RAG Implementation
#
# Features:
# - Hybrid retrieval (BM25 + vector search)
# - Multi-stage retrieval with reranking
# - Advanced chunking strategies
# - Multi-document support
# - Metadata filtering
# - Contextual compression
# - Web search integration
#
# **Note:** reranking, metadata filtering and contextual compression are not
# wired into the pipeline below yet. The `ContextualCompressionRetriever` and
# `DocumentCompressorPipeline` imports are kept for that follow-up work.

# %%
# Import required libraries
import os
import re
import shutil
from typing import Any, Dict, List, Optional, Union

import httpx
import numpy as np
import requests

# LangChain imports
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
# DocumentCompressorPipeline lives in `langchain`, not in `langchain_community.retrievers`.
from langchain.retrievers.document_compressors import DocumentCompressorPipeline
from langchain_community.document_loaders import DirectoryLoader, PyPDFLoader, TextLoader
from langchain_community.retrievers import BM25Retriever
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_ollama import ChatOllama, OllamaEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter

# %%
# Configuration
MODEL_NAME = "gemma3:12b"
# Embeddings reuse the chat model by default; point this at a dedicated
# embedding model if one is available in the local Ollama instance.
EMBEDDING_MODEL_NAME = MODEL_NAME
DOCS_DIR = "documents"
CHUNK_SIZE = 1000
CHUNK_OVERLAP = 200
VECTOR_DB_PATH = "chroma_db"
# Extensions that are read as plain text (PDFs are handled separately).
TEXT_EXTENSIONS = (".txt", ".md", ".html")

# Create documents directory if it doesn't exist
os.makedirs(DOCS_DIR, exist_ok=True)

# %% [markdown]
# ## Document Loading and Processing

# %%
class DocumentProcessor:
    """Handles document loading, chunking, and embedding.

    Attributes:
        docs_dir: Directory scanned by ``load_documents``.
        embeddings: Ollama embedding model shared with the vector store.
        text_splitter: Recursive character splitter configured from
            ``CHUNK_SIZE`` / ``CHUNK_OVERLAP``; records each chunk's start index.
    """

    def __init__(self, docs_dir=DOCS_DIR):
        self.docs_dir = docs_dir
        self.embeddings = OllamaEmbeddings(model=EMBEDDING_MODEL_NAME)
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=CHUNK_SIZE,
            chunk_overlap=CHUNK_OVERLAP,
            add_start_index=True,
        )

    @staticmethod
    def _is_supported(file_path: str) -> bool:
        """Return True when the file extension has a loader (case-insensitive)."""
        return file_path.lower().endswith((".pdf",) + TEXT_EXTENSIONS)

    def load_single_document(self, file_path: str) -> List[Document]:
        """Load a document based on its file extension.

        Args:
            file_path: Path to a ``.pdf``, ``.txt``, ``.md`` or ``.html`` file.
                The extension is matched case-insensitively.

        Returns:
            The documents produced by the matching loader.

        Raises:
            ValueError: If the extension is not supported.
        """
        lowered = file_path.lower()
        if lowered.endswith(".pdf"):
            loader = PyPDFLoader(file_path)
        elif lowered.endswith(TEXT_EXTENSIONS):
            loader = TextLoader(file_path)
        else:
            raise ValueError(f"Unsupported file type: {file_path}")
        return loader.load()

    def load_documents(self) -> List[Document]:
        """Load all documents from the documents directory.

        Files that fail to load are reported on stdout and skipped so a single
        bad file does not abort the whole run. Each loaded document gets its
        ``source`` metadata set to the bare file name.
        """
        documents = []
        # Sorted so chunk order (and therefore retrieval ties) is reproducible.
        for filename in sorted(os.listdir(self.docs_dir)):
            file_path = os.path.join(self.docs_dir, filename)
            if not os.path.isfile(file_path):
                continue
            try:
                docs = self.load_single_document(file_path)
            except Exception as e:
                print(f"Error loading {file_path}: {e}")
                continue
            for doc in docs:
                doc.metadata['source'] = filename
            documents.extend(docs)
        return documents

    def process_documents(self) -> List[Document]:
        """Load and chunk documents.

        Returns:
            The chunked documents, or an empty list when nothing could be loaded.
        """
        documents = self.load_documents()
        if not documents:
            print(f"No documents found. Please add documents to the '{self.docs_dir}' directory.")
            return []
        return self.text_splitter.split_documents(documents)

    def create_document_from_text(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Document]:
        """Create a document from text content and split it into chunks."""
        doc = Document(page_content=text, metadata=metadata or {})
        return self.text_splitter.split_documents([doc])

    def add_document(self, file_path: str) -> List[Document]:
        """Add a new document to the documents directory.

        Args:
            file_path: Path of the file to copy into ``docs_dir``.

        Returns:
            The loaded (unchunked) documents, tagged with the same ``source``
            metadata that ``load_documents`` assigns.

        Raises:
            FileNotFoundError: If ``file_path`` does not exist.
            ValueError: If the file type is unsupported (raised before copying,
                so no unusable file is left behind in ``docs_dir``).
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")
        if not self._is_supported(file_path):
            raise ValueError(f"Unsupported file type: {file_path}")

        filename = os.path.basename(file_path)
        destination = os.path.join(self.docs_dir, filename)

        # Opening the destination for writing truncates it, so copying a file
        # that already lives in docs_dir onto itself would wipe its contents.
        already_in_place = os.path.exists(destination) and os.path.samefile(file_path, destination)
        if not already_in_place:
            shutil.copyfile(file_path, destination)

        docs = self.load_single_document(destination)
        for doc in docs:
            doc.metadata['source'] = filename
        return docs

# %% [markdown]
# ## Web Search Integration

# %%
class WebSearchTool:
    """Handles web search integration using DuckDuckGo.

    The DuckDuckGo API endpoint is queried first; when it returns neither an
    abstract nor related topics, a third-party fallback endpoint is tried.
    All failures are reported on stdout and yield an empty result so callers
    can treat web search as best-effort.
    """

    PRIMARY_URL = "https://api.duckduckgo.com/"
    FALLBACK_URL = "https://ddg-api.herokuapp.com/search"
    TIMEOUT_SECONDS = 10.0

    def __init__(self, processor):
        self.processor = processor

    def search(self, query: str, num_results: int = 3) -> List[Document]:
        """Search the web for information and convert results to documents.

        Args:
            query: Search string.
            num_results: Maximum number of search hits to turn into documents.

        Returns:
            Chunked documents with ``source`` (``web_search_<i>``) and ``query``
            metadata; an empty list on any error or when nothing was found.
        """
        try:
            snippets = self._fetch_snippets(query, num_results)

            # Convert to documents
            documents = []
            for i, snippet in enumerate(snippets):
                chunks = self.processor.create_document_from_text(
                    snippet,
                    metadata={"source": f"web_search_{i}", "query": query},
                )
                documents.extend(chunks)
            return documents
        except Exception as e:
            # Deliberately broad: web search is optional and must never break retrieval.
            print(f"Error during web search: {str(e)}")
            return []

    def _fetch_snippets(self, query: str, num_results: int) -> List[str]:
        """Query the primary endpoint, then the fallback, and return text snippets."""
        response = httpx.get(
            self.PRIMARY_URL,
            params={
                "q": query,
                "format": "json",
                "no_html": 1,
                "no_redirect": 1,
            },
            timeout=self.TIMEOUT_SECONDS,
        )
        if response.status_code != 200:
            print(f"Error searching the web: {response.status_code}")
            return []

        payload = response.json()
        if payload.get('AbstractText') or payload.get('RelatedTopics'):
            return self._parse_primary(payload, num_results)

        # Fallback to a simpler HTTP request to ddg-api
        response = httpx.get(
            self.FALLBACK_URL,
            params={"query": query, "limit": num_results},
            timeout=self.TIMEOUT_SECONDS,
        )
        if response.status_code != 200:
            print(f"Error with fallback search: {response.status_code}")
            return []
        return self._parse_fallback(response.json(), num_results)

    @staticmethod
    def _parse_primary(payload: Dict[str, Any], num_results: int) -> List[str]:
        """Extract up to ``num_results`` snippets from a DuckDuckGo API payload."""
        snippets = []
        if payload.get('AbstractText'):
            snippets.append(
                f"Abstract: {payload['AbstractText']}\nSource: {payload.get('AbstractSource', '')}"
            )

        for topic in payload.get('RelatedTopics', []):
            if not isinstance(topic, dict):
                continue
            # Some entries are groups that nest their items under 'Topics'
            # instead of carrying 'Text' themselves; flatten one level.
            nested = topic.get('Topics')
            candidates = nested if isinstance(nested, list) else [topic]
            for candidate in candidates:
                if isinstance(candidate, dict) and 'Text' in candidate:
                    snippets.append(candidate['Text'])

        # Filter first, cut afterwards: entries without text must not use up slots.
        return snippets[:max(num_results, 0)]

    @staticmethod
    def _parse_fallback(results: Any, num_results: int) -> List[str]:
        """Format up to ``num_results`` hits from the fallback endpoint."""
        if not isinstance(results, list):
            print("Error with fallback search: unexpected response format")
            return []
        snippets = []
        for result in results[:max(num_results, 0)]:
            title = result.get('title', '')
            snippet = result.get('snippet', '')
            url = result.get('link', '')
            snippets.append(f"Title: {title}\nURL: {url}\nContent: {snippet}")
        return snippets

# %% [markdown]
# ## Advanced Retrieval System

# %%
class AdvancedRetriever:
    """Manages the hybrid retrieval system combining multiple techniques.

    Local retrieval is an ensemble of a Chroma vector retriever and a BM25
    retriever; web search results can optionally be mixed in.
    """

    VECTOR_K = 4                   # chunks requested from the vector retriever
    BM25_K = 4                     # chunks requested from the BM25 retriever
    ENSEMBLE_WEIGHTS = [0.7, 0.3]  # vector vs. BM25 weighting

    def __init__(self, processor, web_search=None):
        self.processor = processor
        self.web_search = web_search
        self.vector_store = None
        self.retriever = None

    def _reset_vector_store(self) -> None:
        """Drop the persisted collection so a rebuild does not stack duplicate chunks.

        ``Chroma.from_documents`` appends to whatever is already stored under
        ``VECTOR_DB_PATH``; without this reset every rebuild (or notebook rerun)
        would index the same chunks again.
        """
        store = self.vector_store
        if store is None:
            store = Chroma(
                persist_directory=VECTOR_DB_PATH,
                embedding_function=self.processor.embeddings,
            )
        store.delete_collection()
        self.vector_store = None

    def build_retriever(self, documents=None):
        """Build a hybrid retriever incorporating multiple retrieval methods.

        Args:
            documents: Pre-chunked documents. When ``None`` the processor's
                documents directory is loaded and chunked.

        Returns:
            The ensemble retriever, or ``None`` when there is nothing to index
            (a previously built retriever is left untouched in that case).
        """
        if documents is None:
            documents = self.processor.process_documents()

        if not documents:
            print("No documents to build retriever from.")
            return None

        # Create the vector store from a clean collection
        self._reset_vector_store()
        self.vector_store = Chroma.from_documents(
            documents=documents,
            embedding=self.processor.embeddings,
            persist_directory=VECTOR_DB_PATH,
        )
        vector_retriever = self.vector_store.as_retriever(search_kwargs={"k": self.VECTOR_K})

        # Create BM25 retriever
        bm25_retriever = BM25Retriever.from_documents(documents)
        bm25_retriever.k = self.BM25_K

        # Combine retrievers
        self.retriever = EnsembleRetriever(
            retrievers=[vector_retriever, bm25_retriever],
            weights=list(self.ENSEMBLE_WEIGHTS),
        )

        return self.retriever

    @staticmethod
    def _merge_results(local_docs, web_docs, k, web_k):
        """Merge local and web hits, de-duplicated by content, local first.

        Up to ``web_k`` of the ``k`` slots are reserved for web hits so that
        they are not crowded out whenever local retrieval alone returns ``k``
        or more chunks. Unused slots on either side are given to the other.
        """
        seen_content = set()

        def unique(docs):
            kept = []
            for doc in docs:
                if doc.page_content not in seen_content:
                    seen_content.add(doc.page_content)
                    kept.append(doc)
            return kept

        local_unique = unique(local_docs)
        web_unique = unique(web_docs)
        reserved = max(0, min(web_k, k, len(web_unique)))
        return (local_unique[:k - reserved] + web_unique)[:k]

    def search(self, query, use_web=True, k=5, web_k=2):
        """Perform a search using the retriever and optionally web search.

        Args:
            query: Search string.
            use_web: Mix in web search results when a web search tool is set.
            k: Maximum number of documents to return.
            web_k: Maximum number of the ``k`` slots reserved for web results.

        Returns:
            Up to ``k`` documents, local hits first.
        """
        if self.retriever is None:
            self.build_retriever()

        web_enabled = bool(use_web and self.web_search)

        if self.retriever is None:
            # If build_retriever failed, web search is the only source left.
            if web_enabled:
                return self.web_search.search(query, num_results=k)[:k]
            return []

        # Get results from document retriever
        results = self.retriever.invoke(query)

        # Optionally add web search results
        if web_enabled:
            web_results = self.web_search.search(query)
            if web_results:
                return self._merge_results(results, web_results, k, web_k)

        return results[:k]

# %% [markdown]
# ## RAG Question Answering

# %%
class RAGSystem:
    """Main RAG system that integrates all components."""

    def __init__(self):
        self.processor = DocumentProcessor()
        self.web_search = WebSearchTool(self.processor)
        self.retriever = AdvancedRetriever(self.processor, self.web_search)
        self.llm = ChatOllama(model=MODEL_NAME, temperature=0.1)

        # Create a sample document if the documents directory is empty
        docs_dir = self.processor.docs_dir
        if not os.listdir(docs_dir):
            sample_path = os.path.join(docs_dir, "sample.txt")
            with open(sample_path, "w", encoding="utf-8") as f:
                f.write("This is a sample document for testing the RAG system.\n")
                f.write("The system combines vector search, BM25, and web search capabilities.\n")
                f.write("You can add your own documents to the 'documents' directory.\n")

    def initialize(self):
        """Initialize the RAG system by indexing the documents directory.

        Returns:
            ``self``, so construction and initialization can be chained.
        """
        documents = self.processor.process_documents()
        self.retriever.build_retriever(documents)
        return self

    def answer(self, query, use_web=True):
        """Generate an answer for the query using retrieved context.

        Args:
            query: The user's question.
            use_web: Whether web search results may be added to the context.

        Returns:
            The model's answer text, or a fixed apology when nothing was retrieved.
        """
        # Get relevant documents
        docs = self.retriever.search(query, use_web=use_web)

        if not docs:
            return "I couldn't find any relevant information to answer your question."

        # Create context from documents
        context = "\n\n".join(
            f"Document {i+1}:\n{doc.page_content}" for i, doc in enumerate(docs)
        )

        # Generate answer
        prompt = ChatPromptTemplate.from_template("""
        Answer the following question based on the provided context.
        If the answer is not in the context, say "I don't have enough information to answer this question."
        
        Context:
        {context}
        
        Question: {query}
        
        Answer:
        """)

        chain = prompt | self.llm
        response = chain.invoke({"context": context, "query": query})

        return response.content

    def add_document(self, file_path):
        """Add a new document and update the retriever.

        Args:
            file_path: Path of the file to copy into the documents directory.

        Returns:
            Number of chunks the new document was split into.
        """
        documents = self.processor.add_document(file_path)
        chunks = self.processor.text_splitter.split_documents(documents)

        # The rebuild re-indexes the whole documents directory, which now
        # contains the new file. Adding the chunks to the vector store first
        # would index them twice, so the rebuild is the only update step.
        self.retriever.build_retriever()

        return len(chunks)

# %% [markdown]
# ## Usage Example

# %%
# Initialize the RAG system
rag_system = RAGSystem().initialize()

# Test with a sample query
query = "What is a hybrid RAG system?"
answer = rag_system.answer(query)
print(f"Query: {query}")
print(f"Answer: {answer}")

# %%
# Test with web search
query = "What are the latest developments in large language models?"
answer = rag_system.answer(query, use_web=True)
print(f"Query: {query}")
print(f"Answer: {answer}")

# %% [markdown]
# ## Adding Your Own Documents

# %%
# Example: Add your own document
# Replace with the path to your document
# document_path = "/path/to/your/document.pdf"
# num_chunks = rag_system.add_document(document_path)
# print(f"Added document with {num_chunks} chunks")