LLMs/3b. Hybrid RAG.ipynb

239 lines
10 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "213b722c-b0d3-489c-a736-521e0d34dade",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import nltk\n",
"from langchain_community.document_loaders import PDFPlumberLoader, WebBaseLoader\n",
"from langchain_text_splitters import RecursiveCharacterTextSplitter\n",
"from langchain_core.vectorstores import InMemoryVectorStore\n",
"from langchain_core.prompts import ChatPromptTemplate\n",
"from langchain_ollama import OllamaEmbeddings, ChatOllama\n",
"from langchain_community.retrievers import BM25Retriever\n",
"from langchain.retrievers import EnsembleRetriever\n",
"from typing_extensions import TypedDict\n",
"from langgraph.graph import START, END, StateGraph\n",
"\n",
"# Ensure NLTK tokenizer is available\n",
"try:\n",
" nltk.data.find('tokenizers/punkt')\n",
"except LookupError:\n",
" nltk.download('punkt')\n",
"\n",
"# Initialize model and embeddings\n",
"model = ChatOllama(model=\"gemma3:12b\", temperature=0.2)\n",
"embeddings = OllamaEmbeddings(model=\"gemma3:12b\")\n",
"\n",
"# Vector store\n",
"vector_store = InMemoryVectorStore(embeddings)\n",
"\n",
"# Templates\n",
"qa_template = \"\"\"\n",
"You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. \n",
"If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise.\n",
"Question: {question} \n",
"Context: {context} \n",
"Answer:\n",
"\"\"\"\n",
"\n",
"# Text splitter\n",
"def split_text(documents, chunk_size=1000, chunk_overlap=200):\n",
"    \"\"\"Split documents into overlapping chunks for retrieval.\n",
"\n",
"    Args:\n",
"        documents: Documents to split, e.g. the output of load_pdf() or load_webpage().\n",
"        chunk_size: Value passed to the splitter as ``chunk_size``. Defaults to 1000.\n",
"        chunk_overlap: Value passed to the splitter as ``chunk_overlap``. Defaults to 200.\n",
"\n",
"    Returns:\n",
"        The chunked documents returned by ``split_documents``.\n",
"    \"\"\"\n",
"    text_splitter = RecursiveCharacterTextSplitter(\n",
"        chunk_size=chunk_size,\n",
"        chunk_overlap=chunk_overlap,\n",
"        add_start_index=True,\n",
"    )\n",
"    return text_splitter.split_documents(documents)\n",
"\n",
"# PDF handling\n",
"def load_pdf(file_path):\n",
" if not os.path.exists(file_path):\n",
" raise FileNotFoundError(f\"File not found: {file_path}\")\n",
" loader = PDFPlumberLoader(file_path)\n",
" documents = loader.load()\n",
" return documents\n",
"\n",
"# Web page handling (using WebBaseLoader)\n",
"def load_webpage(url):\n",
"    \"\"\"Load the page at ``url`` with WebBaseLoader and return its documents.\"\"\"\n",
"    return WebBaseLoader(url).load()\n",
"\n",
"# Hybrid retriever\n",
"def build_hybrid_retriever(documents, k=3, weights=(0.7, 0.3)):\n",
"    \"\"\"Build an ensemble retriever (vector search + BM25) over ``documents``.\n",
"\n",
"    The module-level ``vector_store`` is emptied and refilled with ``documents``\n",
"    on every call, so its previous contents are discarded.\n",
"\n",
"    Args:\n",
"        documents: Chunked documents to index.\n",
"        k: Number of hits requested from each underlying retriever. Defaults to 3.\n",
"        weights: Ensemble weights for (vector retriever, BM25 retriever).\n",
"            Defaults to (0.7, 0.3).\n",
"\n",
"    Returns:\n",
"        An EnsembleRetriever combining both retrievers.\n",
"    \"\"\"\n",
"    # InMemoryVectorStore has no clear() method; drop the previous entries by\n",
"    # passing the ids currently held in its ``store`` mapping to delete().\n",
"    stale_ids = list(vector_store.store)\n",
"    if stale_ids:\n",
"        vector_store.delete(ids=stale_ids)\n",
"    vector_store.add_documents(documents)\n",
"\n",
"    semantic_retriever = vector_store.as_retriever(search_kwargs={\"k\": k})\n",
"    bm25_retriever = BM25Retriever.from_documents(documents)\n",
"    bm25_retriever.k = k\n",
"\n",
"    return EnsembleRetriever(\n",
"        retrievers=[semantic_retriever, bm25_retriever],\n",
"        weights=list(weights),\n",
"    )\n",
"\n",
"# DuckDuckGo search implementation\n",
"def search_ddg(query, num_results=3):\n",
"    \"\"\"Query DuckDuckGo for ``query``, passing ``num_results`` as the result limit.\n",
"\n",
"    Returns whatever ``DuckDuckGoSearchAPIWrapper.results`` returns.\n",
"    \"\"\"\n",
"    # Imported inside the function, so the import runs on call rather than at cell start.\n",
"    from langchain_community.utilities import DuckDuckGoSearchAPIWrapper\n",
"\n",
"    return DuckDuckGoSearchAPIWrapper().results(query, num_results)\n",
"\n",
"# Answer question with error handling\n",
"def answer_question(question, documents):\n",
"    \"\"\"Answer ``question`` with the chat model, using ``documents`` as context.\n",
"\n",
"    Args:\n",
"        question: The user's question.\n",
"        documents: Retrieved documents; their ``page_content`` values are joined\n",
"            with blank lines to form the prompt context.\n",
"\n",
"    Returns:\n",
"        The model's answer text, or an \"Error generating answer: ...\" string\n",
"        if anything in prompt building or model invocation raises.\n",
"    \"\"\"\n",
"    try:\n",
"        context = \"\\n\\n\".join(doc.page_content for doc in documents)\n",
"        prompt = ChatPromptTemplate.from_template(qa_template)\n",
"        chain = prompt | model\n",
"        return chain.invoke({\"question\": question, \"context\": context}).content\n",
"    except Exception as e:\n",
"        # Deliberate best effort: the error is reported as the answer text instead of raising.\n",
"        return f\"Error generating answer: {e}\"\n",
"\n",
"# Simple RAG node for web search\n",
"class WebSearchState(TypedDict):\n",
" query: str\n",
" results: list\n",
" response: str\n",
"\n",
"def web_search(state):\n",
"    \"\"\"Graph node: search DuckDuckGo for the state's query and return the hits under ``results``.\"\"\"\n",
"    return {\"results\": search_ddg(state[\"query\"])}\n",
"\n",
"def generate_search_response(state):\n",
"    \"\"\"Graph node: answer the state's query from the collected search results.\n",
"\n",
"    Each result's ``title`` and ``snippet`` are formatted as \"title: snippet\" and\n",
"    joined with blank lines to form the prompt context.\n",
"\n",
"    Returns:\n",
"        A dict with a single ``response`` key holding the model's answer, or an\n",
"        \"Error generating response: ...\" string if anything raises (including a\n",
"        result entry that lacks ``title`` or ``snippet``).\n",
"    \"\"\"\n",
"    try:\n",
"        context = \"\\n\\n\".join(f\"{r['title']}: {r['snippet']}\" for r in state[\"results\"])\n",
"        prompt = ChatPromptTemplate.from_template(qa_template)\n",
"        chain = prompt | model\n",
"        response = chain.invoke({\"question\": state[\"query\"], \"context\": context})\n",
"        return {\"response\": response.content}\n",
"    except Exception as e:\n",
"        # Deliberate best effort: the error is surfaced as the response text instead of raising.\n",
"        return {\"response\": f\"Error generating response: {e}\"}\n",
"\n",
"# Build search graph\n",
"search_graph = StateGraph(WebSearchState)\n",
"search_graph.add_node(\"search\", web_search)\n",
"search_graph.add_node(\"generate\", generate_search_response)\n",
"search_graph.add_edge(START, \"search\")\n",
"search_graph.add_edge(\"search\", \"generate\")\n",
"search_graph.add_edge(\"generate\", END)\n",
"search_workflow = search_graph.compile()\n",
"\n",
"# Main command-line interface\n",
"# Interactive menu with three flows: PDF question answering, web page question\n",
"# answering, and internet search. Each flow validates its input step by step and\n",
"# prints a message instead of raising when a step yields nothing.\n",
"# NOTE: inside a notebook kernel __name__ is \"__main__\", so the menu runs whenever this cell executes.\n",
"if __name__ == \"__main__\":\n",
"    print(\"Welcome to the Advanced RAG System\")\n",
"    print(\"Choose an option:\")\n",
"    print(\"1. Analyze PDF\")\n",
"    print(\"2. Crawl URL\")\n",
"    print(\"3. Search Internet\")\n",
"    # Strip surrounding whitespace so an entry such as \"1 \" is still accepted.\n",
"    choice = input(\"Enter your choice (1/2/3): \").strip()\n",
"\n",
"    if choice == \"1\":\n",
"        pdf_path = input(\"Enter the path to the PDF file: \").strip()\n",
"        if not pdf_path:\n",
"            print(\"Please enter a valid file path.\")\n",
"        else:\n",
"            try:\n",
"                print(\"Processing PDF...\")\n",
"                documents = load_pdf(pdf_path)\n",
"                if not documents:\n",
"                    print(\"No documents were loaded from the PDF. The file might be empty or not contain extractable text.\")\n",
"                else:\n",
"                    chunked_documents = split_text(documents)\n",
"                    if not chunked_documents:\n",
"                        print(\"No text chunks were created. The PDF might not contain any text.\")\n",
"                    else:\n",
"                        retriever = build_hybrid_retriever(chunked_documents)\n",
"                        print(f\"Processed {len(chunked_documents)} chunks\")\n",
"                        question = input(\"Ask a question about the PDF: \").strip()\n",
"                        if not question:\n",
"                            print(\"Please enter a valid question.\")\n",
"                        else:\n",
"                            print(\"Searching document...\")\n",
"                            # invoke() replaces the deprecated get_relevant_documents().\n",
"                            related_documents = retriever.invoke(question)\n",
"                            if not related_documents:\n",
"                                print(\"No relevant documents found for the question.\")\n",
"                            else:\n",
"                                answer = answer_question(question, related_documents)\n",
"                                print(\"Answer:\", answer)\n",
"            except Exception as e:\n",
"                print(f\"Error: {e}\")\n",
"\n",
"    elif choice == \"2\":\n",
"        url = input(\"Enter the URL to analyze: \").strip()\n",
"        if not url:\n",
"            print(\"Please enter a valid URL.\")\n",
"        else:\n",
"            try:\n",
"                print(\"Loading webpage...\")\n",
"                web_documents = load_webpage(url)\n",
"                if not web_documents:\n",
"                    print(\"No documents were loaded from the webpage. The page might be empty or not contain extractable text.\")\n",
"                else:\n",
"                    web_chunks = split_text(web_documents)\n",
"                    if not web_chunks:\n",
"                        print(\"No text chunks were created. The webpage might not contain any text.\")\n",
"                    else:\n",
"                        web_retriever = build_hybrid_retriever(web_chunks)\n",
"                        print(f\"Processed {len(web_chunks)} chunks from webpage\")\n",
"                        question = input(\"Ask a question about the webpage: \").strip()\n",
"                        if not question:\n",
"                            print(\"Please enter a valid question.\")\n",
"                        else:\n",
"                            print(\"Analyzing content...\")\n",
"                            # invoke() replaces the deprecated get_relevant_documents().\n",
"                            web_results = web_retriever.invoke(question)\n",
"                            if not web_results:\n",
"                                print(\"No relevant documents found for the question.\")\n",
"                            else:\n",
"                                answer = answer_question(question, web_results)\n",
"                                print(\"Answer:\", answer)\n",
"            except Exception as e:\n",
"                print(f\"Error loading webpage: {e}\")\n",
"\n",
"    elif choice == \"3\":\n",
"        query = input(\"Enter your search query: \").strip()\n",
"        if not query:\n",
"            print(\"Please enter a valid search query.\")\n",
"        else:\n",
"            try:\n",
"                print(\"Searching the web...\")\n",
"                search_result = search_workflow.invoke({\"query\": query})\n",
"                print(\"Response:\", search_result[\"response\"])\n",
"                print(\"Sources:\")\n",
"                for result in search_result[\"results\"]:\n",
"                    print(f\"- {result['title']}: {result['link']}\")\n",
"            except Exception as e:\n",
"                print(f\"Error during search: {e}\")\n",
"\n",
"    else:\n",
"        print(\"Invalid choice\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:base] *",
"language": "python",
"name": "conda-base-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}