LLMs/4. Doc Search Agent.ipynb

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "0d92bf90-2548-4a24-87f7-2d87a7dbbd4c",
"metadata": {},
"outputs": [],
"source": [
"from langchain.vectorstores import Chroma\n",
"from langchain.embeddings import SentenceTransformerEmbeddings\n",
"from langchain.schema import Document\n",
"from ollama import chat\n",
"import os\n",
"import re\n",
"# CHANGED THE ORDER OF PATHS AND IT RETRIEVED THE RESPONSES CORRECTLY. \n",
"DOCUMENT_PATHS = [\n",
" r'/home/masih/rag_data/vape.txt',\n",
" r'/home/masih/rag_data/Hamrah.txt',\n",
" r'/home/masih/rag_data/Shah.txt',\n",
" r'/home/masih/rag_data/Khalife.txt',\n",
" r'/home/masih/rag_data/takapoo.txt',\n",
" r'/home/masih/rag_data/carbon.txt',\n",
"\n",
"]\n",
"\n",
"EMBEDDING_MODEL = 'sentence-transformers/paraphrase-multilingual-mpnet-base-v2'\n",
"LLM_MODEL = 'gemma2:9b'\n",
"CHUNK_SIZE = 1000\n",
"OVERLAP = 200\n",
"CHROMA_PERSIST_DIR = r'\\home\\Masih\\chroma_db\\chroma_db'\n",
"\n",
"class ChromaRAGSystem:\n",
" def __init__(self):\n",
" # Init embedding model\n",
" self.embeddings = SentenceTransformerEmbeddings(model_name=EMBEDDING_MODEL)\n",
" # Vector store instance\n",
" self.vector_db = None\n",
" \n",
" def build_vector_store(self):\n",
" \"\"\"Process documents and create Chroma vector store\"\"\"\n",
" all_docs = []\n",
"\n",
" for doc_idx, path in enumerate(DOCUMENT_PATHS):\n",
" with open(path, 'r', encoding='utf-8') as f:\n",
" text = re.sub(r'\\s+', ' ', f.read()).strip()\n",
" # sliding window chunking\n",
" chunks = [\n",
" text[i:i+CHUNK_SIZE] \n",
" for i in range(0, len(text), CHUNK_SIZE - OVERLAP)\n",
" ]\n",
" # LangChain documents with metadata\n",
" for chunk in chunks:\n",
" all_docs.append(Document(\n",
" page_content=chunk,\n",
" metadata={\"source_doc\": doc_idx}\n",
" ))\n",
"\n",
" # Chroma vector store\n",
" self.vector_db = Chroma.from_documents(\n",
" documents=all_docs,\n",
" embedding=self.embeddings,\n",
" persist_directory=CHROMA_PERSIST_DIR\n",
" )\n",
" self.vector_db.persist()\n",
" \n",
" def load_vector_store(self):\n",
" \"\"\"Load existing Chroma vector store\"\"\"\n",
" self.vector_db = Chroma(\n",
" persist_directory=CHROMA_PERSIST_DIR,\n",
" embedding_function=self.embeddings\n",
" )\n",
" \n",
" def document_query(self, query, top_k=5):\n",
" \"\"\"Retrieve context from all documents based on query\"\"\"\n",
" # Perform similarity search across all documents\n",
" results = self.vector_db.similarity_search(query=query, k=top_k)\n",
" return [doc.page_content for doc in results]\n",
"\n",
"class AnswerGenerator:\n",
" def __init__(self, rag_system):\n",
" self.rag = rag_system\n",
" \n",
" def generate_response(self, question):\n",
" \"\"\"Generate context-aware answer using LLM\"\"\"\n",
" # Retrieve relevant context from the best matching documents\n",
" context_chunks = self.rag.document_query(question)\n",
" context = \"\\n\".join(context_chunks)\n",
" \n",
" prompt = f\"\"\"با استفاده از متن زیر به سوال پاسخ دهید:\n",
"{context}\n",
"\n",
"اگر پاسخ در متن وجود ندارد عبارت 'پاسخی یافت نشد' را برگردانید\n",
"\n",
"سوال: {question}\n",
"پاسخ:\"\"\"\n",
" \n",
" response = chat(model=LLM_MODEL, messages=[{'role': 'user', 'content': prompt}])\n",
" return response['message']['content']\n",
"\n",
"if __name__ == \"__main__\":\n",
" rag_system = ChromaRAGSystem()\n",
" \n",
" # Init vector store\n",
" if not os.path.exists(CHROMA_PERSIST_DIR):\n",
" print(\"Creating new vector store...\")\n",
" rag_system.build_vector_store()\n",
" else:\n",
" print(\"Loading existing vector store...\")\n",
" rag_system.load_vector_store()\n",
" \n",
" # Init answer generator\n",
" answer_engine = AnswerGenerator(rag_system)\n",
"\n",
" queries = [\n",
" \"چرا اینترنت همراه اول گوشی وصل نمیشود؟\",\n",
" \"چطوری ویپ مورد نظرمو پیدا کنم؟\",\n",
" \"شاه عباس که بود؟\",\n",
" \"خلیفه سلطان که بود و چه کرد؟\",\n",
" \"کربن اکتیو و کربن بلک چه هستند و چه تفاوتی دارند و برای چه استفاده میشن؟\",\n",
" \"شرکت تکاپو صنعت نامی چه محصولاتی ارایه میدهد؟ چه چیزی این شرکت را منحصر به فرد میسازد؟ سهام این شرکت صعودی است یا نزولی؟\"\n",
" ]\n",
" \n",
" with open( r'/home/masih/rag_data/response2.txt', 'w', encoding='utf-8') as output_file: \n",
" for q_num, query in enumerate(queries):\n",
" answer = answer_engine.generate_response(query)\n",
" output_file.write(f\"سوال {q_num+1}:\\n{query}\\n\\nپاسخ:\\n{answer}\\n\\n{'='*50}\\n\\n\")\n",
" print(f\"پردازش سوال {q_num+1}/{len(queries)} تکمیل شد\")\n",
"\n",
" print(\"تمامی سوالات با موفقیت پردازش شدند!\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}