diff --git a/Hybrid RAG.ipynb b/Hybrid RAG.ipynb
new file mode 100644
index 0000000..3006a80
--- /dev/null
+++ b/Hybrid RAG.ipynb
@@ -0,0 +1,236 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "213b722c-b0d3-489c-a736-521e0d34dade",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import nltk\n",
+    "from langchain_community.document_loaders import PDFPlumberLoader, WebBaseLoader\n",
+    "from langchain_text_splitters import RecursiveCharacterTextSplitter\n",
+    "from langchain_core.vectorstores import InMemoryVectorStore\n",
+    "from langchain_core.prompts import ChatPromptTemplate\n",
+    "from langchain_ollama import OllamaEmbeddings, ChatOllama\n",
+    "from langchain_community.retrievers import BM25Retriever\n",
+    "from langchain.retrievers import EnsembleRetriever\n",
+    "from typing_extensions import TypedDict\n",
+    "from langgraph.graph import START, END, StateGraph\n",
+    "\n",
+    "# Ensure NLTK tokenizer is available\n",
+    "try:\n",
+    "    nltk.data.find('tokenizers/punkt')\n",
+    "except LookupError:\n",
+    "    nltk.download('punkt')\n",
+    "\n",
+    "# Initialize model and embeddings (gemma3:12b serves both roles here; a dedicated embedding model can be substituted)\n",
+    "model = ChatOllama(model=\"gemma3:12b\", temperature=0.2)\n",
+    "embeddings = OllamaEmbeddings(model=\"gemma3:12b\")\n",
+    "\n",
+    "# Prompt template for question answering\n",
+    "qa_template = \"\"\"\n",
+    "You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question.\n",
+    "If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise.\n",
+    "Question: {question}\n",
+    "Context: {context}\n",
+    "Answer:\n",
+    "\"\"\"\n",
+    "\n",
+    "# Text splitter\n",
+    "def split_text(documents):\n",
+    "    text_splitter = RecursiveCharacterTextSplitter(\n",
+    "        chunk_size=1000,\n",
+    "        chunk_overlap=200,\n",
+    "        add_start_index=True\n",
+    "    )\n",
+    "    return text_splitter.split_documents(documents)\n",
+    "\n",
+    "# PDF handling\n",
+    "def load_pdf(file_path):\n",
+    "    if not os.path.exists(file_path):\n",
+    "        raise FileNotFoundError(f\"File not found: {file_path}\")\n",
+    "    loader = PDFPlumberLoader(file_path)\n",
+    "    documents = loader.load()\n",
+    "    return documents\n",
+    "\n",
+    "# Web page handling (using WebBaseLoader)\n",
+    "def load_webpage(url):\n",
+    "    loader = WebBaseLoader(url)\n",
+    "    documents = loader.load()\n",
+    "    return documents\n",
+    "\n",
+    "# Hybrid retriever: semantic vector search combined with BM25 keyword search\n",
+    "def build_hybrid_retriever(documents):\n",
+    "    # Build a fresh in-memory vector store for this document set\n",
+    "    vector_store = InMemoryVectorStore(embeddings)\n",
+    "    vector_store.add_documents(documents)\n",
+    "    semantic_retriever = vector_store.as_retriever(search_kwargs={\"k\": 3})\n",
+    "    bm25_retriever = BM25Retriever.from_documents(documents)\n",
+    "    bm25_retriever.k = 3\n",
+    "    hybrid_retriever = EnsembleRetriever(\n",
+    "        retrievers=[semantic_retriever, bm25_retriever],\n",
+    "        weights=[0.7, 0.3]\n",
+    "    )\n",
+    "    return hybrid_retriever\n",
+    "\n",
+    "# DuckDuckGo search implementation\n",
+    "def search_ddg(query, num_results=3):\n",
+    "    from langchain_community.utilities import DuckDuckGoSearchAPIWrapper\n",
+    "    search = DuckDuckGoSearchAPIWrapper()\n",
+    "    results = search.results(query, num_results)\n",
+    "    return results\n",
+    "\n",
+    "# Answer question with error handling\n",
+    "def answer_question(question, documents):\n",
+    "    try:\n",
+    "        context = \"\\n\\n\".join([doc.page_content for doc in documents])\n",
+    "        prompt = ChatPromptTemplate.from_template(qa_template)\n",
+    "        chain = prompt | model\n",
+    "        return chain.invoke({\"question\": question, \"context\": context}).content\n",
+    "    except Exception as e:\n",
+    "        return f\"Error generating answer: {e}\"\n",
+    "\n",
+    "# Web search workflow: state and nodes\n",
+    "class WebSearchState(TypedDict):\n",
+    "    query: str\n",
+    "    results: list\n",
+    "    response: str\n",
+    "\n",
+    "def web_search(state):\n",
+    "    results = search_ddg(state[\"query\"])\n",
+    "    return {\"results\": results}\n",
+    "\n",
+    "def generate_search_response(state):\n",
+    "    try:\n",
+    "        context = \"\\n\\n\".join([f\"{r['title']}: {r['snippet']}\" for r in state[\"results\"]])\n",
+    "        prompt = ChatPromptTemplate.from_template(qa_template)\n",
+    "        chain = prompt | model\n",
+    "        response = chain.invoke({\"question\": state[\"query\"], \"context\": context})\n",
+    "        return {\"response\": response.content}\n",
+    "    except Exception as e:\n",
+    "        return {\"response\": f\"Error generating response: {e}\"}\n",
+    "\n",
+    "# Build search graph\n",
+    "search_graph = StateGraph(WebSearchState)\n",
+    "search_graph.add_node(\"search\", web_search)\n",
+    "search_graph.add_node(\"generate\", generate_search_response)\n",
+    "search_graph.add_edge(START, \"search\")\n",
+    "search_graph.add_edge(\"search\", \"generate\")\n",
+    "search_graph.add_edge(\"generate\", END)\n",
+    "search_workflow = search_graph.compile()\n",
+    "\n",
+    "# Main command-line interface\n",
+    "if __name__ == \"__main__\":\n",
+    "    print(\"Welcome to the Advanced RAG System\")\n",
+    "    print(\"Choose an option:\")\n",
+    "    print(\"1. Analyze PDF\")\n",
+    "    print(\"2. Crawl URL\")\n",
+    "    print(\"3. Search Internet\")\n",
+    "    choice = input(\"Enter your choice (1/2/3): \")\n",
+    "\n",
+    "    if choice == \"1\":\n",
+    "        pdf_path = input(\"Enter the path to the PDF file: \").strip()\n",
+    "        if not pdf_path:\n",
+    "            print(\"Please enter a valid file path.\")\n",
+    "        else:\n",
+    "            try:\n",
+    "                print(\"Processing PDF...\")\n",
+    "                documents = load_pdf(pdf_path)\n",
+    "                if not documents:\n",
+    "                    print(\"No documents were loaded from the PDF. The file might be empty or not contain extractable text.\")\n",
+    "                else:\n",
+    "                    chunked_documents = split_text(documents)\n",
+    "                    if not chunked_documents:\n",
+    "                        print(\"No text chunks were created. The PDF might not contain any text.\")\n",
+    "                    else:\n",
+    "                        retriever = build_hybrid_retriever(chunked_documents)\n",
+    "                        print(f\"Processed {len(chunked_documents)} chunks\")\n",
+    "                        question = input(\"Ask a question about the PDF: \").strip()\n",
+    "                        if not question:\n",
+    "                            print(\"Please enter a valid question.\")\n",
+    "                        else:\n",
+    "                            print(\"Searching document...\")\n",
+    "                            related_documents = retriever.invoke(question)\n",
+    "                            if not related_documents:\n",
+    "                                print(\"No relevant documents found for the question.\")\n",
+    "                            else:\n",
+    "                                answer = answer_question(question, related_documents)\n",
+    "                                print(\"Answer:\", answer)\n",
+    "            except Exception as e:\n",
+    "                print(f\"Error: {e}\")\n",
+    "\n",
+    "    elif choice == \"2\":\n",
+    "        url = input(\"Enter the URL to analyze: \").strip()\n",
+    "        if not url:\n",
+    "            print(\"Please enter a valid URL.\")\n",
+    "        else:\n",
+    "            try:\n",
+    "                print(\"Loading webpage...\")\n",
+    "                web_documents = load_webpage(url)\n",
+    "                if not web_documents:\n",
+    "                    print(\"No documents were loaded from the webpage. The page might be empty or not contain extractable text.\")\n",
+    "                else:\n",
+    "                    web_chunks = split_text(web_documents)\n",
+    "                    if not web_chunks:\n",
+    "                        print(\"No text chunks were created. The webpage might not contain any text.\")\n",
+    "                    else:\n",
+    "                        web_retriever = build_hybrid_retriever(web_chunks)\n",
+    "                        print(f\"Processed {len(web_chunks)} chunks from webpage\")\n",
+    "                        question = input(\"Ask a question about the webpage: \").strip()\n",
+    "                        if not question:\n",
+    "                            print(\"Please enter a valid question.\")\n",
+    "                        else:\n",
+    "                            print(\"Analyzing content...\")\n",
+    "                            web_results = web_retriever.invoke(question)\n",
+    "                            if not web_results:\n",
+    "                                print(\"No relevant documents found for the question.\")\n",
+    "                            else:\n",
+    "                                answer = answer_question(question, web_results)\n",
+    "                                print(\"Answer:\", answer)\n",
+    "            except Exception as e:\n",
+    "                print(f\"Error loading webpage: {e}\")\n",
+    "\n",
+    "    elif choice == \"3\":\n",
+    "        query = input(\"Enter your search query: \").strip()\n",
+    "        if not query:\n",
+    "            print(\"Please enter a valid search query.\")\n",
+    "        else:\n",
+    "            try:\n",
+    "                print(\"Searching the web...\")\n",
+    "                search_result = search_workflow.invoke({\"query\": query})\n",
+    "                print(\"Response:\", search_result[\"response\"])\n",
+    "                print(\"Sources:\")\n",
+    "                for result in search_result[\"results\"]:\n",
+    "                    print(f\"- {result['title']}: {result['link']}\")\n",
+    "            except Exception as e:\n",
+    "                print(f\"Error during search: {e}\")\n",
+    "\n",
+    "    else:\n",
+    "        print(\"Invalid choice\")"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python [conda env:base] *",
+   "language": "python",
+   "name": "conda-base-py"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.3"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}